In [9]:
## Implementing RNN using RNNCell
## Initializing inputs and models. Insert manual_seeds carefully. We will be comparing outputs
## to see if our implementation is correct

import torch.nn as nn
import torch

input_size = 32
hidden_size = 16
num_layers = 1
bidirectional = True
cell_type = 'GRU'

torch.manual_seed(3407)
rnncell = []
if bidirectional:
  rnncell_rev = []
for layer in range(num_layers):
  ip_arg = input_size if layer == 0 else 2*hidden_size if bidirectional else hidden_size
  if cell_type == 'Vanilla_RNN':
    rnncell.append(nn.RNNCell(input_size=ip_arg, hidden_size=hidden_size))
  elif cell_type == 'GRU':
    rnncell.append(nn.GRUCell(input_size=ip_arg, hidden_size=hidden_size))
  
  if bidirectional:
    if cell_type == 'Vanilla_RNN':
      rnncell_rev.append(nn.RNNCell(input_size=ip_arg, hidden_size=hidden_size))
    elif cell_type == 'GRU':
      rnncell_rev.append(nn.GRUCell(input_size=ip_arg, hidden_size=hidden_size))
  


torch.manual_seed(3407)
if cell_type == 'Vanilla_RNN':
  rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bidirectional=bidirectional)
elif cell_type == 'GRU':
  rnn = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bidirectional=bidirectional)

## total rnn params
# rnn_total_params = 0
# for param in rnn.parameters():
#   rnn_total_params += param.numel()

## total rnncell params
# rnncell_total_params = 0
# for cell in (rnncell + rnncell_rev):
#   for param in cell.parameters():
#     rnncell_total_params += param.numel()


## Comparing initialized model weights to check if we have same initialization
# param_list = ['weight_ih', 'weight_hh']
# for param in param_list:
#   w1 = getattr(rnncell[0], param)
#   w2 = getattr(rnncell_rev[0], param)
#   w3 = getattr(rnncell[1], param)
#   w4 = getattr(rnncell_rev[1], param)

#   w5 = getattr(rnn, f'{param}_l0')
#   w6 = getattr(rnn, f'{param}_l0_reverse')
#   w7 = getattr(rnn, f'{param}_l1')
#   w8 = getattr(rnn, f'{param}_l1_reverse')
  
  
#   print(((w1-w5)**2).mean(), ((w2-w6)**2).mean(), ((w3-w7)**2).mean(), ((w4-w8)**2).mean()) # should be zero
#   print(((w1-w2)**2).mean(), ((w1-w6)**2).mean())

torch.manual_seed(3407)
ip = torch.randn((5,input_size))
rnn_op,_ = rnn(ip)

In [6]:
## Single Layer Unidirectional RNN using RNNCell
h = []
h.append(torch.zeros((hidden_size))) ## initial hidden state
for t in range(len(ip)):
  h.append(rnncell[0](ip[t], h[t]))
h = torch.stack(h[1:])

## Compare outputs. Should be 0 (or something like e^{-15} type super small value ~ effectively zero)
print(((rnn_op - h)**2).mean())

tensor(1.2998e-15, grad_fn=<MeanBackward0>)


In [8]:
## MultiLayer Unidirectional RNN using RNNCell. (Left to Right Unfolding)
h = []
h_prev = [torch.zeros((hidden_size))]*num_layers ## initial hidden state
for t in range(len(ip)):
  h_t = []
  for layer in range(num_layers):
    if layer == 0:
      h_t.append(rnncell[layer](ip[t], h_prev[layer]))
    else:
      h_t.append(rnncell[layer](h_t[layer-1], h_prev[layer]))
  h.append(h_t[num_layers-1])
  h_prev = h_t

h = torch.stack(h)
print(((rnn_op-h)**2).mean())

tensor(3.1135e-16, grad_fn=<MeanBackward0>)


In [10]:
# Single Layer Bi-Directional RNN using RNNCell
h, h_rev = [],[]
h.append(torch.zeros((hidden_size))); h_rev.append(torch.zeros((hidden_size)))
for t in range(len(ip)):
  h.append(rnncell[0](ip[t],h[t]))
  h_rev.append(rnncell_rev[0](ip[len(ip)-1-t],h_rev[t]))
h = h[1:]; h_rev = h_rev[1:]
h = torch.stack([torch.cat([x,y]) for x,y in zip(h,reversed(h_rev))])
print(((rnn_op-h)**2).mean())

tensor(1.3092e-15, grad_fn=<MeanBackward0>)


In [4]:
# Multi Layer Bi-Directional RNN using RNNCell (Down to Up unfolding)
h = []
# h_prev = [torch.zeros((32))]*num_layers ## initial hidden state
for layer in range(num_layers):
  h_l = []; h_l.append(torch.zeros((hidden_size)))
  h_l_rev = []; h_l_rev.append(torch.zeros((hidden_size)))
  for t in range(len(ip)):
    if layer == 0:
      h_l.append(rnncell[layer](ip[t], h_l[t]))
      h_l_rev.append(rnncell_rev[layer](ip[len(ip)-1-t], h_l_rev[t]))
    else:
      h_l_input = torch.cat([h_below[t], h_below_rev[len(ip)-1-t]]) # torch.cat([h_below_rev[len(ip)-1-t], h_below[t]])
      h_l_rev_input = torch.cat([h_below[len(ip)-1-t], h_below_rev[t]]) # torch.cat([h_below_rev[t], h_below[len(ip)-1-t]])
      h_l.append(rnncell[layer](h_l_input, h_l[t]))
      h_l_rev.append(rnncell_rev[layer](h_l_rev_input, h_l_rev[t]))
    
  h_below = h_l[1:] ## below = one layer below; prev = one time step previous
  h_below_rev = h_l_rev[1:]

h_l = h_l[1:]; h_l_rev = h_l_rev[1:]
h = torch.stack([torch.cat([x,y]) for x,y in zip(h_l,reversed(h_l_rev))])

print(((rnn_op-h)**2).mean())


tensor(8.4273e-16, grad_fn=<MeanBackward0>)
