In [3]:
import torch
import numpy as np
import torch.nn.functional as F
import sys

In [4]:
# Toy next-token task: the target sequence is the input shifted left by one
# position, followed by the id 1 (presumably an end-of-sequence marker -- confirm).
inputString = [2, 45, 30, 55, 10]
outputString = inputString[1:] + [1]

In [5]:
# Sizes shared by the rest of the notebook.
numFeatures = 100  # length of each input embedding column vector
vocabSize = 80     # length of the one-hot targets / softmax outputs

# Seed NumPy's global RNG so the random embeddings and weight initialisation
# drawn from np.random later in the notebook are reproducible on Restart & Run All.
SEED = 0
np.random.seed(SEED)

In [6]:
# One standard-normal column vector of shape (numFeatures, 1) per position of
# inputString. NOTE(review): vectors are drawn per position, not looked up by
# token id, so repeated tokens would get unrelated embeddings -- confirm intent.
embeddings = [np.random.randn(numFeatures, 1) for _ in range(len(inputString))]

In [7]:
# Inspect the first embedding: a (numFeatures, 1) column of random normal values.
embeddings[0]

array([[ 0.20772516],
       [-0.81063148],
       [ 1.50362721],
       [-0.57405679],
       [-1.96460283],
       [-1.99446029],
       [ 0.13580091],
       [ 0.64470894],
       [-1.74638023],
       [-0.46962407],
       [ 0.19451241],
       [ 0.02340528],
       [ 0.86344434],
       [-1.56646334],
       [ 0.88270315],
       [-1.10215462],
       [ 0.80593327],
       [-0.36996518],
       [-0.66560058],
       [-0.88165356],
       [-0.13666683],
       [ 0.82212044],
       [ 0.22932528],
       [ 0.09085516],
       [-1.00262998],
       [-0.45574193],
       [ 0.62028393],
       [-0.53315786],
       [ 0.01270585],
       [ 0.03115569],
       [-0.08032797],
       [-0.55654269],
       [ 1.20972335],
       [-1.75759275],
       [ 1.99320867],
       [-0.88539996],
       [ 0.16901367],
       [ 1.12533927],
       [-1.64170586],
       [ 1.97624444],
       [ 0.08666987],
       [ 0.35096204],
       [ 1.18536575],
       [ 1.86179144],
       [ 1.30178689],
       [ 0

In [8]:
# Sanity check: one embedding was created per element of inputString.
len(embeddings)

5

In [9]:
def getOneHot(idx, size=None):
  """Return a one-hot column vector with a 1 at row ``idx``.

  Args:
    idx: Row to set to 1. It is passed to NumPy indexing unchanged, so a
      negative value counts from the end of the vector.
    size: Number of rows in the vector. When omitted, the notebook-level
      ``vocabSize`` is used, which is the original behaviour.

  Returns:
    A NumPy array of shape ``(size, 1)`` filled with zeros except for the
    selected row.

  Raises:
    IndexError: raised by NumPy when ``idx`` is out of bounds for ``size``.
  """
  if size is None:
    # Resolved at call time so later changes to vocabSize are picked up.
    size = vocabSize
  one_hot = np.zeros((size, 1))
  one_hot[idx] = 1
  return one_hot


In [10]:
# Expect a (vocabSize, 1) column of zeros with a single 1 at row 2.
print(getOneHot(2))

[[0.]
 [0.]
 [1.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]]


In [11]:
numUnits = 50  # size of the hidden state

# Initial hidden state: a zero column vector, float64 to match the
# NumPy-backed weights below.
h0 = torch.zeros((numUnits, 1), dtype=torch.float64)

# Trainable weights, each drawn from U[0, 1) with NumPy and wrapped as a leaf
# tensor that tracks gradients. Draw order (Wh, Wx, Wy) is kept so the global
# RNG is consumed in the same sequence.
Wh = torch.from_numpy(np.random.uniform(0, 1, (numUnits, numUnits))).requires_grad_(True)     # hidden -> hidden
Wx = torch.from_numpy(np.random.uniform(0, 1, (numUnits, numFeatures))).requires_grad_(True)  # input  -> hidden
Wy = torch.from_numpy(np.random.uniform(0, 1, (vocabSize, numUnits))).requires_grad_(True)    # hidden -> output

In [12]:
# Expected shapes: Wh (numUnits, numUnits), Wx (numUnits, numFeatures), Wy (vocabSize, numUnits).
print(Wh.shape,Wx.shape,Wy.shape)

torch.Size([50, 50]) torch.Size([50, 100]) torch.Size([80, 50])


In [13]:
def stepForward(xt,Wx,Wy,Wh,prevMem):
  """Run one RNN time step.

  Computes ``ht = tanh(Wx @ xt + Wh @ prevMem)`` and
  ``yt_hat = softmax(Wy @ ht)`` with the softmax taken over dim 0.

  Args:
    xt: Input for this step, as a NumPy array or a torch tensor. It is
      converted to ``Wx``'s dtype and device, so e.g. float32 input works
      with float64 weights. In this notebook it is a (numFeatures, 1) column.
    Wx: Input-to-hidden weight matrix.
    Wy: Hidden-to-output weight matrix.
    Wh: Hidden-to-hidden weight matrix.
    prevMem: Hidden state from the previous step.

  Returns:
    Tuple ``(ht, yt_hat)``: the new hidden state and the output
    distribution for this step.
  """
  # as_tensor accepts ndarrays and tensors alike; it avoids a copy when the
  # dtype/device already match (the original from_numpy path).
  xt = torch.as_tensor(xt, dtype=Wx.dtype, device=Wx.device)
  pre_activation = torch.matmul(Wx, xt) + torch.matmul(Wh, prevMem)
  ht = torch.tanh(pre_activation)
  # dim=0 normalises down the column, i.e. across the output classes.
  yt_hat = F.softmax(torch.matmul(Wy, ht), dim=0)
  return ht, yt_hat

In [14]:
# Run a single time step on the first embedding, starting from the zero state h0.
ht,yt_hat = stepForward(embeddings[0],Wx,Wy,Wh,h0)

In [15]:
# The hidden state should be a (numUnits, 1) column.
ht.shape

torch.Size([50, 1])

In [16]:
def fullForwardRNN(X,Wx,Wh,Wy,prevMem):
  """Unroll the RNN over every element of ``X``.

  The hidden state starts at ``prevMem`` and is threaded from one step to
  the next through ``stepForward``.

  Args:
    X: Sequence of per-step inputs, each accepted by ``stepForward``.
    Wx, Wh, Wy: Weight matrices forwarded to ``stepForward``.
    prevMem: Initial hidden state.

  Returns:
    List with the output distribution produced at each step, in order.
    The final hidden state is not returned.
  """
  hidden = prevMem
  outputs = []
  for step_input in X:
    hidden, step_output = stepForward(step_input, Wx, Wy, Wh, hidden)
    outputs.append(step_output)
  return outputs


In [17]:
# Forward pass over the whole sequence; y_hat holds one output distribution per step.
y_hat = fullForwardRNN(embeddings,Wx,Wh,Wy,h0)

In [18]:
# Sanity check: one prediction per input step.
len(y_hat)

5

In [19]:
def computeLoss(y,y_hat):
  loss=0
  for yi,yi_hat in zip(y,y_hat):
    Li=-torch.log2(yi_hat[yi==1])
    loss += Li
  return loss/len(y)

In [20]:
# One-hot encode every target token id; y[t] is the target for step t.
y = [getOneHot(token_id) for token_id in outputString]

In [21]:
# Average loss of the forward pass computed above, i.e. before trainRNN is run.
print(computeLoss(y,y_hat))

tensor([8.9742], dtype=torch.float64, grad_fn=<DivBackward0>)


In [22]:
def updateParams(Wx,Wh,Wy,dWx,dWh,dWy,lr):
  """Apply one gradient-descent step, ``W <- W - lr * dW``, to each weight.

  The update is done in place inside ``torch.no_grad()`` so it is not
  recorded by autograd.

  Args:
    Wx, Wh, Wy: Weights to update (modified in place).
    dWx, dWh, dWy: Gradients matching each weight.
    lr: Learning rate.

  Returns:
    Tuple ``(Wx, Wh, Wy)`` of the same, now updated, objects.
  """
  with torch.no_grad():
    for weight, gradient in ((Wx, dWx), (Wh, dWh), (Wy, dWy)):
      weight -= lr * gradient
  return Wx, Wh, Wy


In [27]:
def trainRNN(X,y,Wx,Wh,Wy,prevMem,lr,nepoch):
  """Train the RNN weights with full-sequence gradient descent.

  Each epoch runs ``fullForwardRNN`` over ``X``, scores the predictions with
  ``computeLoss``, backpropagates, and applies one ``updateParams`` step.
  Every epoch starts again from ``prevMem``.

  Args:
    X: Sequence of per-step inputs.
    y: Sequence of per-step targets, as expected by ``computeLoss``.
    Wx, Wh, Wy: Leaf tensors with ``requires_grad=True``; updated in place.
    prevMem: Initial hidden state used at the start of every epoch.
    lr: Learning rate.
    nepoch: Number of epochs.

  Returns:
    Tuple ``(Wx, Wh, Wy, losses)`` where ``losses`` holds one detached loss
    tensor per epoch.
  """
  # Clear gradients left behind by any earlier backward() call so they do not
  # leak into the first update (autograd accumulates into .grad).
  for weight in (Wx, Wh, Wy):
    if weight.grad is not None:
      weight.grad.zero_()

  losses = []
  for epoch in range(nepoch):
    y_hat = fullForwardRNN(X, Wx, Wh, Wy, prevMem)
    loss = computeLoss(y, y_hat)
    loss.backward()
    # Store a detached value: appending `loss` itself would keep a reference
    # to every epoch's autograd graph for as long as the list lives.
    losses.append(loss.detach())
    print("loss after epoch %d is %f" % (epoch, loss.item()), flush=True)
    Wx, Wh, Wy = updateParams(Wx, Wh, Wy, Wx.grad, Wh.grad, Wy.grad, lr)
    # Reset for the next epoch (and leave gradients zeroed on return).
    for weight in (Wx, Wh, Wy):
      if weight.grad is not None:
        weight.grad.zero_()
  return Wx, Wh, Wy, losses

In [28]:
# Train for 100 epochs with learning rate 0.001. The weights are updated in
# place, so re-running this cell continues from the already-trained values.
Wx,Wh,Wy,losses=trainRNN(embeddings,y,Wx,Wh,Wy,h0,0.001,100)

loss after epoch 0 is 8.974202
loss after epoch 1 is 8.899103
loss after epoch 2 is 8.824780
loss after epoch 3 is 8.751247
loss after epoch 4 is 8.678520
loss after epoch 5 is 8.606611
loss after epoch 6 is 8.535533
loss after epoch 7 is 8.465298
loss after epoch 8 is 8.395916
loss after epoch 9 is 8.327397
loss after epoch 10 is 8.259750
loss after epoch 11 is 8.192983
loss after epoch 12 is 8.127102
loss after epoch 13 is 8.062113
loss after epoch 14 is 7.998021
loss after epoch 15 is 7.934829
loss after epoch 16 is 7.872537
loss after epoch 17 is 7.811148
loss after epoch 18 is 7.750661
loss after epoch 19 is 7.691074
loss after epoch 20 is 7.632384
loss after epoch 21 is 7.574588
loss after epoch 22 is 7.517680
loss after epoch 23 is 7.461654
loss after epoch 24 is 7.406504
loss after epoch 25 is 7.352220
loss after epoch 26 is 7.298795
loss after epoch 27 is 7.246217
loss after epoch 28 is 7.194478
loss after epoch 29 is 7.143564
loss after epoch 30 is 7.093465
loss after epoch 3