In [None]:
import numpy as np
import torch
import torch.nn.functional as torchFunctional
import sys

In [None]:
input_string = [2, 45, 30, 55, 10]
output_string = [45, 30, 55, 10, 1]

In [None]:
num_features = 100 #number of features of our embeddings
vocab_size = 80

In [None]:
embeddings = list()
for i in range(len(input_string)):
  x = np.random.randn(num_features, 1)
  embeddings.append(x)

In [None]:
embeddings[0].shape

(100, 1)

In [None]:
len(embeddings)

5

In [None]:
def get_one_hot(idx):
  one_hot = np.zeros((vocab_size, 1))
  one_hot[idx] = 1
  return one_hot

In [None]:
num_units = 50
h0 = torch.tensor(np.zeros((num_units, 1)))
Wh = torch.tensor(np.random.uniform(0, 1, (num_units, num_units)), requires_grad=True)
Wx = torch.tensor(np.random.uniform(0, 1, (num_units, num_features)), requires_grad=True)
Wy = torch.tensor(np.random.uniform(0, 1, (vocab_size, num_units)), requires_grad=True)

In [None]:
print(Wh.shape, Wx.shape, Wh.shape, h0.shape)

torch.Size([50, 50]) torch.Size([50, 100]) torch.Size([50, 50]) torch.Size([50, 1])


In [None]:
def stepForward(xt, Wx, Wh, Wy, prev_memory):
  x_frwrd = torch.matmul(Wx, torch.from_numpy(xt))
  h_frwrd = torch.matmul(Wh, prev_memory)
  ht = torch.tanh(x_frwrd + h_frwrd)
  yt_hat = torchFunctional.softmax(torch.matmul(Wy, ht), dim = 0)
  return ht, yt_hat

Testing for just one embedding input

In [None]:
ht, yt_hat = stepForward(embeddings[0], Wx, Wh, Wy, h0)

In [None]:
ht.shape

torch.Size([50, 1])

In [None]:
yt_hat.shape

torch.Size([80, 1])

In [None]:
yt_hat.sum()

tensor(1., dtype=torch.float64, grad_fn=<SumBackward0>)

In [None]:
def full_forward_RNN(X, Wx, Wh, Wy, prev_memory):
  y_hat = []
  for t in range(len(X)): #X is the list of embeddings
    ht, yt_hat = stepForward(X[t], Wx, Wh, Wy, prev_memory)
    prev_memory = ht
    y_hat.append(yt_hat)
  return y_hat

In [None]:
y_hat = full_forward_RNN(embeddings, Wx, Wh, Wy, h0)

In [None]:
len(y_hat)

5

In [None]:
y_hat[0].shape

torch.Size([80, 1])

In [None]:
def compute_loss(y, y_hat):
  loss = 0
  for yi, yi_hat in zip(y, y_hat):
    Li = -torch.log2(yi_hat[yi == 1])
    loss += Li
  return loss / len(y)

In [None]:
y = list()
for idx in output_string:
  y.append(get_one_hot(idx))

In [None]:
print(compute_loss(y, y_hat))

tensor([9.4092], dtype=torch.float64, grad_fn=<DivBackward0>)


In [None]:
def update_params(Wx, Wh, Wy, dWx, dWh, dWy, lr):
  with torch.no_grad():
    Wx -= lr * dWx
    Wh -= lr * dWh
    Wy -= lr * dWy

  return Wx, Wh, Wy

In [1]:
def train_RNN(X, y, Wx, Wh, Wy, prev_memory, lr, n_epochs):
  losses = []
  for epoch in range(n_epochs):
    y_hat = full_forward_RNN(X, Wx, Wh, Wy, prev_memory)
    loss = compute_loss(y, y_hat)
    loss.backward()
    losses.append(loss)
    print("Loss after epoch=%d: %f" %(epoch, loss))
    sys.stdout.flush()
    dWx = Wx.grad.data
    dWh = Wh.grad.data
    dWy = Wy.grad.data
    Wx, Wh, Wy = update_params(Wx, Wh, Wy, dWx, dWh, dWy, lr)
    Wx.grad.data.zero_()
    Wh.grad.data.zero_()
    Wy.grad.data.zero_()

  return Wx, Wh, Wy, losses

In [None]:
Wx, Wh, Wy, losses = train_RNN(embeddings, y, Wx, Wh, Wy, h0, 0.001, 100)

Loss after epoch=0: 9.409180
Loss after epoch=1: 9.378634
Loss after epoch=2: 9.348218
Loss after epoch=3: 9.317933
Loss after epoch=4: 9.287780
Loss after epoch=5: 9.257760
Loss after epoch=6: 9.227871
Loss after epoch=7: 9.198115
Loss after epoch=8: 9.168489
Loss after epoch=9: 9.138993
Loss after epoch=10: 9.109625
Loss after epoch=11: 9.080384
Loss after epoch=12: 9.051267
Loss after epoch=13: 9.022272
Loss after epoch=14: 8.993398
Loss after epoch=15: 8.964641
Loss after epoch=16: 8.936000
Loss after epoch=17: 8.907472
Loss after epoch=18: 8.879055
Loss after epoch=19: 8.850747
Loss after epoch=20: 8.822544
Loss after epoch=21: 8.794444
Loss after epoch=22: 8.766446
Loss after epoch=23: 8.738547
Loss after epoch=24: 8.710744
Loss after epoch=25: 8.683036
Loss after epoch=26: 8.655419
Loss after epoch=27: 8.627893
Loss after epoch=28: 8.600455
Loss after epoch=29: 8.573102
Loss after epoch=30: 8.545833
Loss after epoch=31: 8.518647
Loss after epoch=32: 8.491540
Loss after epoch=33: