In [326]:
import torch
import os
import numpy as np
import torchvision as tv
import matplotlib.pyplot as plt
from torchvision import datasets, models, transforms
%matplotlib inline

In [327]:
seq_count = 1000
seq_len = 20

# Dedicated seeded generator: a fresh "Restart & Run All" reproduces the same
# dataset, and the global NumPy random state is left untouched.
rng = np.random.default_rng(0)

# Explicit 64-bit integers. The platform-default `int` is 32-bit on some
# setups (e.g. Windows with NumPy 1.x), and torch.nn.CrossEntropyLoss needs
# int64 class-index targets once these arrays are converted to tensors.
X = rng.integers(10, size=(seq_count, seq_len), dtype=np.int64)  # digits 0..9
y = np.zeros((seq_count, seq_len), dtype=np.int64)
X[:10]

array([[6, 5, 6, 9, 7, 1, 0, 3, 5, 8, 5, 5, 9, 4, 5, 6, 4, 0, 1, 1],
       [9, 9, 0, 0, 9, 0, 0, 2, 3, 0, 5, 9, 9, 7, 8, 8, 9, 5, 7, 1],
       [8, 7, 2, 2, 5, 2, 6, 0, 1, 4, 0, 4, 7, 7, 3, 3, 1, 2, 3, 7],
       [2, 4, 1, 9, 4, 3, 7, 2, 4, 2, 5, 7, 6, 2, 6, 7, 8, 2, 2, 1],
       [0, 7, 9, 3, 9, 9, 6, 8, 9, 3, 4, 2, 7, 1, 7, 8, 7, 5, 1, 3],
       [3, 4, 1, 0, 2, 7, 7, 3, 7, 7, 8, 3, 7, 2, 0, 9, 8, 7, 3, 4],
       [2, 1, 4, 0, 5, 5, 5, 8, 1, 0, 0, 4, 6, 7, 6, 2, 2, 4, 2, 8],
       [9, 1, 8, 1, 3, 6, 2, 3, 1, 2, 1, 6, 4, 1, 6, 1, 9, 0, 2, 4],
       [2, 5, 3, 5, 2, 6, 1, 1, 1, 9, 6, 8, 5, 2, 5, 5, 5, 4, 1, 2],
       [4, 9, 3, 4, 5, 9, 8, 4, 0, 9, 5, 7, 7, 5, 2, 1, 9, 7, 7, 2]])

In [328]:
# Labels: position 0 copies the first digit of the sequence; every later
# position is (digit + first digit), brought back into 0..9 by subtracting 10
# whenever the sum reaches 10. Vectorized over the whole (seq_count, seq_len)
# array instead of looping element by element in Python.
first_digit = X[:, :1]             # one column per sequence; broadcasts over time
shifted = X[:, 1:] + first_digit
y[:, 0] = X[:, 0]
y[:, 1:] = np.where(shifted >= 10, shifted - 10, shifted)

X[0:1], y[:1]

(array([[6, 5, 6, 9, 7, 1, 0, 3, 5, 8, 5, 5, 9, 4, 5, 6, 4, 0, 1, 1]]),
 array([[6, 1, 2, 5, 3, 7, 6, 9, 1, 4, 1, 1, 5, 0, 1, 2, 0, 6, 7, 7]]))

In [329]:
# Keep only the label of the final timestep (one target per sequence).
# The ndim guard makes the cell safe to re-run: once y is already 1-D,
# slicing it again with [:, -1] would raise IndexError.
if y.ndim == 2:
    y = y[:, -1]
y.shape

(1000,)

In [330]:
# torch.as_tensor accepts both NumPy arrays and tensors, so re-running this
# cell is harmless (torch.from_numpy raises TypeError on an existing tensor).
# dtype=torch.long guarantees the int64 indices/targets that the embedding
# lookup and CrossEntropyLoss are used with, whatever integer width NumPy chose.
X = torch.as_tensor(X, dtype=torch.long)
y = torch.as_tensor(y, dtype=torch.long)

In [331]:
batch_size = 100

# Pair every input sequence with its target, then serve the pairs in
# reshuffled mini-batches of `batch_size`.
dataset = torch.utils.data.TensorDataset(X, y)
dataset_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [332]:
class Network(torch.nn.Module):
    """Embedding -> recurrent layer -> linear classifier on the final hidden state.

    Args:
        network_type: Callable that builds the recurrent module. It is invoked
            as ``network_type(embed_dim, hidden_dim, batch_first=True)``
            (e.g. ``torch.nn.RNN``, ``torch.nn.GRU``, ``torch.nn.LSTM``).
        vocab_size: Number of distinct input tokens; also the number of
            output classes produced by the linear head.
        embed_dim: Size of each token embedding.
        hidden_dim: Size of the recurrent hidden state.
    """

    def __init__(self, network_type, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, embed_dim)
        self.rnn = network_type(embed_dim, hidden_dim, batch_first=True)
        self.linear = torch.nn.Linear(hidden_dim, vocab_size)

    def forward(self, inp):
        """Map token ids ``(batch, seq_len)`` to class logits ``(batch, vocab_size)``."""
        embedding = self.embed(inp)
        _, state = self.rnn(embedding)

        # LSTM returns the pair (h_n, c_n); RNN and GRU return h_n directly.
        h_n = state[0] if isinstance(state, tuple) else state

        # h_n carries a leading layer axis. Indexing the last entry yields the
        # top layer's final hidden state and always drops that axis, so the
        # result is (batch, vocab_size) for single-layer and stacked modules
        # alike. (Bidirectional modules are not handled specially.)
        return self.linear(h_n[-1])


In [333]:
def train_model(model, loader, loss_fn, optimizer, epochs=10):
    """Train ``model`` for ``epochs`` passes over ``loader`` and report the loss.

    Args:
        model: Module to optimise; called as ``model(X_batch)``.
        loader: Sized iterable yielding ``(X_batch, y_batch)`` pairs.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar tensor.
        optimizer: Optimizer already bound to ``model``'s parameters.
        epochs: Number of full passes over ``loader``.

    Returns:
        List with one entry per epoch: the mean per-batch training loss.
    """
    train_losses = []
    for epoch in range(epochs):
        train_loss = 0.0

        model.train()
        for X_batch, y_batch in loader:
            optimizer.zero_grad()
            # Call the module itself rather than .forward() so that any
            # registered forward hooks are executed.
            y_pred = model(X_batch)
            loss = loss_fn(y_pred, y_batch)
            train_loss += loss.item()
            loss.backward()
            optimizer.step()

        # Average over the number of batches, not the number of samples.
        train_loss /= len(loader)
        train_losses.append(train_loss)
        print(f'Epoch: {epoch}, loss: {train_loss:.3f}')

    # Hand the history back to the caller; previously it was built and discarded.
    return train_losses


In [337]:
# Hyperparameters shared by every recurrent variant trained below.
vocab_size = 10     # inputs and targets are the digits 0..9
embed_dim = 64
hidden_dim = 128

# Baseline: plain RNN cell, Adam with lr=0.001, 100 epochs.
model = Network(torch.nn.RNN, vocab_size, embed_dim, hidden_dim)
loss_fn1 = torch.nn.CrossEntropyLoss()
optimizer1 = torch.optim.Adam(model.parameters(), lr=0.001)
# The trailing ';' keeps the notebook from echoing the call's return value.
train_model(model, dataset_loader, loss_fn1, optimizer1, epochs=100);

Epoch: 0, loss: 2.340
Epoch: 1, loss: 2.260
Epoch: 2, loss: 2.227
Epoch: 3, loss: 2.205
Epoch: 4, loss: 2.185
Epoch: 5, loss: 2.165
Epoch: 6, loss: 2.145
Epoch: 7, loss: 2.127
Epoch: 8, loss: 2.103
Epoch: 9, loss: 2.084
Epoch: 10, loss: 2.059
Epoch: 11, loss: 2.038
Epoch: 12, loss: 2.010
Epoch: 13, loss: 1.986
Epoch: 14, loss: 1.954
Epoch: 15, loss: 1.925
Epoch: 16, loss: 1.894
Epoch: 17, loss: 1.858
Epoch: 18, loss: 1.820
Epoch: 19, loss: 1.782
Epoch: 20, loss: 1.735
Epoch: 21, loss: 1.690
Epoch: 22, loss: 1.639
Epoch: 23, loss: 1.584
Epoch: 24, loss: 1.518
Epoch: 25, loss: 1.461
Epoch: 26, loss: 1.392
Epoch: 27, loss: 1.320
Epoch: 28, loss: 1.251
Epoch: 29, loss: 1.168
Epoch: 30, loss: 1.097
Epoch: 31, loss: 1.016
Epoch: 32, loss: 0.933
Epoch: 33, loss: 0.870
Epoch: 34, loss: 0.792
Epoch: 35, loss: 0.719
Epoch: 36, loss: 0.650
Epoch: 37, loss: 0.600
Epoch: 38, loss: 0.541
Epoch: 39, loss: 0.500
Epoch: 40, loss: 0.451
Epoch: 41, loss: 0.412
Epoch: 42, loss: 0.360
Epoch: 43, loss: 0.31

In [338]:
# GRU variant of the same classifier; Adam with lr=0.01, 100 epochs.
gru = Network(torch.nn.GRU, vocab_size, embed_dim, hidden_dim)
loss_fn2 = torch.nn.CrossEntropyLoss()
optimizer2 = torch.optim.Adam(gru.parameters(), lr=0.01)
# The trailing ';' keeps the notebook from echoing the call's return value.
train_model(gru, dataset_loader, loss_fn2, optimizer2, epochs=100);

Epoch: 0, loss: 2.378
Epoch: 1, loss: 2.292
Epoch: 2, loss: 2.225
Epoch: 3, loss: 2.115
Epoch: 4, loss: 1.891
Epoch: 5, loss: 1.540
Epoch: 6, loss: 1.035
Epoch: 7, loss: 0.560
Epoch: 8, loss: 0.249
Epoch: 9, loss: 0.104
Epoch: 10, loss: 0.045
Epoch: 11, loss: 0.021
Epoch: 12, loss: 0.012
Epoch: 13, loss: 0.008
Epoch: 14, loss: 0.006
Epoch: 15, loss: 0.005
Epoch: 16, loss: 0.004
Epoch: 17, loss: 0.004
Epoch: 18, loss: 0.003
Epoch: 19, loss: 0.003
Epoch: 20, loss: 0.003
Epoch: 21, loss: 0.003
Epoch: 22, loss: 0.002
Epoch: 23, loss: 0.002
Epoch: 24, loss: 0.002
Epoch: 25, loss: 0.002
Epoch: 26, loss: 0.002
Epoch: 27, loss: 0.002
Epoch: 28, loss: 0.002
Epoch: 29, loss: 0.002
Epoch: 30, loss: 0.001
Epoch: 31, loss: 0.001
Epoch: 32, loss: 0.001
Epoch: 33, loss: 0.001
Epoch: 34, loss: 0.001
Epoch: 35, loss: 0.001
Epoch: 36, loss: 0.001
Epoch: 37, loss: 0.001
Epoch: 38, loss: 0.001
Epoch: 39, loss: 0.001
Epoch: 40, loss: 0.001
Epoch: 41, loss: 0.001
Epoch: 42, loss: 0.001
Epoch: 43, loss: 0.00

In [340]:
# LSTM variant of the same classifier; Adam with lr=0.01, 100 epochs.
lstm = Network(torch.nn.LSTM, vocab_size, embed_dim, hidden_dim)
loss_fn3 = torch.nn.CrossEntropyLoss()
optimizer3 = torch.optim.Adam(lstm.parameters(), lr=0.01)
# The trailing ';' keeps the notebook from echoing the call's return value.
train_model(lstm, dataset_loader, loss_fn3, optimizer3, epochs=100);

Epoch: 0, loss: 2.336
Epoch: 1, loss: 2.273
Epoch: 2, loss: 2.236
Epoch: 3, loss: 2.157
Epoch: 4, loss: 2.001
Epoch: 5, loss: 1.736
Epoch: 6, loss: 1.386
Epoch: 7, loss: 0.967
Epoch: 8, loss: 0.582
Epoch: 9, loss: 0.322
Epoch: 10, loss: 0.151
Epoch: 11, loss: 0.074
Epoch: 12, loss: 0.037
Epoch: 13, loss: 0.023
Epoch: 14, loss: 0.016
Epoch: 15, loss: 0.011
Epoch: 16, loss: 0.009
Epoch: 17, loss: 0.007
Epoch: 18, loss: 0.006
Epoch: 19, loss: 0.005
Epoch: 20, loss: 0.005
Epoch: 21, loss: 0.004
Epoch: 22, loss: 0.004
Epoch: 23, loss: 0.004
Epoch: 24, loss: 0.003
Epoch: 25, loss: 0.003
Epoch: 26, loss: 0.003
Epoch: 27, loss: 0.003
Epoch: 28, loss: 0.002
Epoch: 29, loss: 0.002
Epoch: 30, loss: 0.002
Epoch: 31, loss: 0.002
Epoch: 32, loss: 0.002
Epoch: 33, loss: 0.002
Epoch: 34, loss: 0.002
Epoch: 35, loss: 0.002
Epoch: 36, loss: 0.002
Epoch: 37, loss: 0.002
Epoch: 38, loss: 0.001
Epoch: 39, loss: 0.001
Epoch: 40, loss: 0.001
Epoch: 41, loss: 0.001
Epoch: 42, loss: 0.001
Epoch: 43, loss: 0.00