# In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

def generate_dataset(seq_length, num_sample, vocab_size):
  """Build a random copy-task dataset.

  Draws ``num_sample`` sequences of ``seq_length`` integer tokens from
  ``[1, vocab_size)`` (token 0 is never generated) and pairs every sequence
  with a copy of itself as the target.

  Returns:
    A ``TensorDataset`` of ``(inputs, targets)``; both tensors have shape
    ``(num_sample, seq_length)`` and hold identical values.
  """
  tokens = torch.randint(low=1, high=vocab_size, size=(num_sample, seq_length))
  # The target is a separate tensor (a clone, not an alias), so mutating one
  # side later does not silently change the other.
  return TensorDataset(tokens, tokens.clone())

class Encoder(nn.Module):
  """Recurrent-style encoder that folds a token sequence into one hidden vector.

  Every token is one-hot encoded, projected by a single linear layer, added
  to the running hidden state and squashed with tanh:
  ``h_t = tanh(W x_t + b + h_{t-1})``. The previous state is added as-is;
  there is no hidden-to-hidden weight matrix.

  Args:
    input_size: Number of token classes (width of the one-hot vectors).
    hidden_size: Width of the hidden state.
  """

  def __init__(self, input_size, hidden_size):
    super().__init__()
    self.hidden_size = hidden_size
    self.linear = nn.Linear(input_size, hidden_size)
    self.activation = nn.Tanh()

  def forward(self, input_seq):
    """Encode a batch of token sequences.

    Args:
      input_seq: Long tensor of shape ``(batch_size, seq_length)`` with
        values in ``[0, input_size)``.

    Returns:
      Tensor of shape ``(batch_size, hidden_size)`` holding the hidden state
      after the last token (all zeros when ``seq_length`` is 0).
    """
    batch_size, seq_length = input_seq.size()
    # Follow the layer's parameter dtype instead of hard-coding float32 so
    # the module keeps working after e.g. ``.double()``.
    dtype = self.linear.weight.dtype
    # Allocate directly on the input's device (no CPU allocation + copy).
    hidden = torch.zeros(
        batch_size, self.hidden_size, dtype=dtype, device=input_seq.device)

    # One-hot encode the whole batch once rather than once per time step.
    one_hot = F.one_hot(input_seq, num_classes=self.linear.in_features).to(dtype)

    for step in range(seq_length):
      hidden = self.activation(self.linear(one_hot[:, step]) + hidden)
    return hidden
class Decoder(nn.Module):
  """Teacher-forced decoder that emits one score vector per target position.

  At step ``t`` the decoder receives the one-hot encoding of the ground-truth
  token at position ``t - 1`` (an all-zero vector at ``t = 0``), updates its
  state as ``h_t = tanh(W1 y_{t-1} + b1 + h_{t-1})`` and outputs
  ``W2 h_t + b2``. The previous state is added as-is; there is no
  hidden-to-hidden weight matrix.

  Because the ground-truth sequence is always what gets fed back, calling
  this module at evaluation time still uses teacher forcing.

  Args:
    input_size: Number of token classes fed back into the decoder.
    hidden_size: Width of the hidden state (must match the encoder's).
    output_size: Number of scores produced per position.
  """

  def __init__(self, input_size, hidden_size, output_size):
    super().__init__()
    self.hidden_size = hidden_size
    self.input_size = input_size
    self.output_size = output_size

    self.linear1 = nn.Linear(input_size, hidden_size)
    self.activation = nn.Tanh()
    self.linear2 = nn.Linear(hidden_size, output_size)

  def forward(self, target_seq, hidden):
    """Decode with teacher forcing.

    Args:
      target_seq: Long tensor of shape ``(batch_size, seq_len)`` with values
        in ``[0, input_size)``.
      hidden: Initial hidden state of shape ``(batch_size, hidden_size)``.

    Returns:
      Tensor of shape ``(batch_size, seq_len, output_size)`` with the
      unnormalised scores for every position.
    """
    batch_size, seq_len = target_seq.size()
    device = target_seq.device
    # Follow the layer's parameter dtype instead of hard-coding float32 so
    # the module keeps working after e.g. ``.double()``.
    dtype = self.linear1.weight.dtype

    # Decoder inputs are the targets shifted right by one step: an all-zero
    # "start" vector followed by the one-hot of tokens 0 .. seq_len - 2.
    # The last target token is never fed back, so it is not encoded.
    start = torch.zeros(
        batch_size, 1, self.input_size, dtype=dtype, device=device)
    shifted = F.one_hot(
        target_seq[:, :-1], num_classes=self.input_size).to(dtype)
    step_inputs = torch.cat([start, shifted], dim=1)

    outputs = torch.zeros(
        batch_size, seq_len, self.output_size, dtype=dtype, device=device)
    for step in range(seq_len):
      hidden = self.activation(self.linear1(step_inputs[:, step]) + hidden)
      outputs[:, step, :] = self.linear2(hidden)
    return outputs
class Seq2Seq(nn.Module):
  """Chain an encoder and a decoder into a single module.

  Args:
    encoder: Module called as ``encoder(input_seq)``; its result is passed
      to the decoder as the initial state.
    decoder: Module called as ``decoder(target_seq, state)``.
  """

  def __init__(self, encoder, decoder):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder

  def forward(self, input_seq, target_seq):
    """Encode ``input_seq`` and return whatever the decoder produces.

    The decoder is given ``target_seq`` together with the encoder's result,
    so the target sequence is required even outside training.
    """
    context = self.encoder(input_seq)
    return self.decoder(target_seq, context)


def train_model(model, dataloader, criterion, optimizer, num_epochs, device):
  """Train ``model`` and print the mean per-batch loss after every epoch.

  Args:
    model: Module called as ``model(inputs, targets)``. Its output is
      flattened so that the last dimension is kept as the class dimension.
    dataloader: Iterable yielding ``(inputs, targets)`` tensor batches. It
      does not need to support ``len()``.
    criterion: Callable ``criterion(scores, targets)`` returning a scalar
      loss tensor. It receives scores reshaped to ``(N, num_classes)`` and
      targets reshaped to ``(N,)``.
    optimizer: Optimizer that updates the model's parameters.
    num_epochs: Number of passes over ``dataloader``.
    device: Device the model and every batch are moved to.

  The model is moved to ``device`` in place and switched to training mode at
  the start of every epoch. Nothing is returned.
  """
  model.to(device)
  for epoch in range(1, num_epochs + 1):
    model.train()
    epoch_loss = 0.0
    # Count batches while iterating instead of calling len(dataloader):
    # loaders over iterable-style datasets and plain generators have no
    # length.
    num_batches = 0

    for inputs, targets in dataloader:
      # inputs are expected to be (batch_size, sequence_length)
      inputs, targets = inputs.to(device), targets.to(device)
      optimizer.zero_grad()

      outputs = model(inputs, targets)
      # reshape (unlike view) also accepts non-contiguous tensors.
      outputs = outputs.reshape(-1, outputs.size(-1))
      targets = targets.reshape(-1)

      loss = criterion(outputs, targets)
      loss.backward()
      optimizer.step()

      epoch_loss += loss.item()
      num_batches += 1

    # An epoch without batches has no defined loss; report NaN rather than
    # dividing by zero.
    avg_loss = epoch_loss / num_batches if num_batches else float("nan")
    print(f"Epoch : {epoch}, Loss: {avg_loss:.4f}")


def evaluate_model(model, dataloader, device):
  """Return the fraction of target positions the model predicts correctly.

  Args:
    model: Module called as ``model(inputs, targets)`` that returns scores
      with the class dimension at index 2.
    dataloader: Iterable yielding ``(inputs, targets)`` tensor batches.
    device: Device the model and every batch are moved to.

  Returns:
    ``correct / total`` as a float in ``[0, 1]``, or ``0.0`` when the
    dataloader yields no target elements.

  The model is moved to ``device`` (as ``train_model`` does) and left in
  evaluation mode. The targets are passed to the model as well as compared
  against, so the score reflects whatever the model does with them.
  """
  model.to(device)
  model.eval()

  correct = 0
  total = 0
  with torch.no_grad():
    for inputs, targets in dataloader:
      inputs, targets = inputs.to(device), targets.to(device)

      outputs = model(inputs, targets)
      predicted = torch.argmax(outputs, dim=2)
      correct += (predicted == targets).sum().item()
      total += targets.numel()

  # Avoid ZeroDivisionError when there was nothing to evaluate.
  if total == 0:
    return 0.0
  return correct / total


# --- Experiment configuration ------------------------------------------------
seq_length = 10
num_samples = 1000
vocab_size = 5
# Tokens are one-hot encoded, so the model's input and output widths both
# equal the vocabulary size.
input_size = vocab_size
hidden_size = 130
output_size = vocab_size
learning_rate = 0.001

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)

# --- Data --------------------------------------------------------------------
dataset = generate_dataset(seq_length, num_samples, vocab_size)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# --- Model, loss, optimizer --------------------------------------------------
encoder = Encoder(input_size, hidden_size)
decoder = Decoder(input_size, hidden_size, output_size)
model = Seq2Seq(encoder, decoder).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# --- Training ----------------------------------------------------------------
train_model(model, dataloader, criterion, optimizer, num_epochs=155, device=device)

# Accuracy is measured on the same data the model was trained on; there is
# no held-out split in this script.
acc = evaluate_model(model, dataloader, device)
print(f"Training Accuracy: {acc*100:.2f}%\n")

# --- Qualitative check on a single training example --------------------------
with torch.no_grad():
  test_input, test_target = dataset[0]
  # Add a leading batch dimension of size 1.
  test_input = test_input.unsqueeze(0).to(device)
  test_target = test_target.unsqueeze(0).to(device)

  # The target sequence is passed to the model here too.
  output = model(test_input, test_target)

  predicted = torch.argmax(output, dim=2)
  print("Sample Input Sequence : ", test_input.squeeze().tolist())
  print("Sample Target Sequence:", test_target.squeeze().tolist())
  print("Predicted Sequence    : ", predicted.squeeze().tolist())