In [2]:
import torch 
import torch.nn as nn 
import torch.functional as F 

# Pick the compute device once for the whole notebook: CUDA when PyTorch can
# see a GPU, otherwise the CPU. Later cells read this module-level name.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"device:{device}")

class EncoderRNN(nn.Module):
    """One-step GRU encoder.

    Each call to ``forward`` embeds the given index tensor, reshapes the
    result to ``(1, 1, -1)`` and runs it through a single GRU together with
    the supplied hidden state.

    Args:
        input_size: number of rows in the embedding table, i.e. how many
            distinct input indices can be looked up.
        hidden_size: size of each embedding vector; also used as both the
            input width and the state width of the GRU.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

        # Submodules are created in the same order as before (embedding first,
        # then GRU) so that seeded parameter initialisation is unchanged.
        self.embedding = nn.Embedding(num_embeddings=input_size,
                                      embedding_dim=hidden_size)
        self.gru = nn.GRU(input_size=hidden_size, hidden_size=hidden_size)

    def forward(self, input, hidden):
        """Embed ``input`` and advance the GRU by one step.

        Args:
            input: index tensor passed to the embedding layer. The embedded
                result is flattened into a single ``(1, 1, -1)`` tensor, so
                its total size must match the GRU input width.
            hidden: hidden state handed to the GRU.

        Returns:
            The ``(output, hidden)`` pair produced by the GRU.
        """
        token_vectors = self.embedding(input)
        gru_input = token_vectors.view(1, 1, -1)
        return self.gru(gru_input, hidden)

    def initHidden(self):
        """Return an all-zero ``(1, 1, hidden_size)`` tensor on the module-level ``device``."""
        state_shape = (1, 1, self.hidden_size)
        return torch.zeros(state_shape, device=device)

device:cuda


In [3]:
class DecoderRNN(nn.Module):
    """One-step GRU decoder that emits log-probabilities over output indices.

    Each call to ``forward`` embeds the given index tensor, applies ReLU,
    advances a single GRU by one step, and projects the GRU output through a
    linear layer followed by ``LogSoftmax``.

    Args:
        hidden_size: size of each embedding vector; also used as both the
            input width and the state width of the GRU.
        output_size: number of rows in the embedding table and number of
            units produced by the final linear layer.
    """

    def __init__(self, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        # forward() feeds a 2-D tensor here, so dim=1 normalises across the
        # output_size units.
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Run one decoding step.

        Args:
            input: index tensor passed to the embedding layer. The embedded
                result is flattened into a single ``(1, 1, -1)`` tensor, so
                its total size must match the GRU input width.
            hidden: hidden state handed to the GRU.

        Returns:
            A tuple ``(log_probs, hidden)``: ``log_probs`` is the log-softmax
            of the linear projection of the first GRU output step, and
            ``hidden`` is the updated GRU state.
        """
        output = self.embedding(input).view(1, 1, -1)
        # Call ReLU through ``nn.functional``. The notebook's ``F`` alias is
        # bound to ``torch.functional``, which has no ``relu`` attribute, so
        # ``F.relu`` raised AttributeError here.
        output = nn.functional.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.softmax(self.out(output[0]))
        return output, hidden

    def initHidden(self):
        """Return an all-zero ``(1, 1, hidden_size)`` tensor on the module-level ``device``."""
        return torch.zeros(1, 1, self.hidden_size, device=device)
    