In [None]:
import torch
import torch.nn as nn
import torchvision
import numpy as np
import pandas as pd

In [None]:
def generate_sequences(n=128, variable_len=False, seed=13):
    """Generate noisy walks around the corners of a square.

    Each sequence starts at a random corner and visits the corners in the
    cyclic order of `basic_corners` (direction 1) or in the opposite order
    (direction 0), with Gaussian noise (std 0.1) added to every point.

    Args:
        n: number of sequences to generate.
        variable_len: if True every sequence gets a random length in {2, 3, 4};
            otherwise all sequences contain the four corners.
        seed: value passed to `np.random.seed` (re-seeds NumPy's global RNG).

    Returns:
        (points, directions): a list of `n` float arrays of shape (length, 2)
        and an integer array of `n` direction labels (0 or 1).
    """
    basic_corners = np.array([[-1, -1], [-1, 1], [1, 1], [1, -1]])
    np.random.seed(seed)

    # The order of the random draws below determines the generated data.
    bases = np.random.randint(4, size=n)
    if variable_len:
        lengths = np.random.randint(3, size=n) + 2
    else:
        lengths = [4] * n
    directions = np.random.randint(2, size=n)

    points = []
    for start, direction, length in zip(bases, directions, lengths):
        # Corners in cyclic order, beginning at the chosen starting corner.
        walk = basic_corners[[(start + offset) % 4 for offset in range(4)]]
        # direction 1 -> step +1 (keep order); direction 0 -> step -1 (reverse).
        step = direction * 2 - 1
        walk = walk[::step][:length]
        jitter = np.random.randn(length, 2) * 0.1
        points.append(walk + jitter)
    return points, directions

In [None]:
# Seed torch's global RNG so this sample input, and the randomly initialised
# RNN weights created in later cells, come out the same on every
# Restart & Run All.
torch.manual_seed(13)

# One sequence (N=1) of four 2-D points (L=4, features=2), batch-first layout.
x = torch.randn((1, 4, 2), dtype=torch.float32)
x

tensor([[[ 2.1046,  0.9141],
         [-0.5512,  1.3308],
         [-0.9752, -0.9401],
         [ 1.0307,  0.1510]]])

In [None]:
# Two Elman RNN layers stacked on top of each other; inputs are (N, L, features).
rnn_stacked = nn.RNN(
    input_size=2,
    hidden_size=2,
    num_layers=2,
    batch_first=True,
)

In [None]:
# Parameter name -> tensor mapping of the stacked RNN. Keys end in `_l0` for the
# first layer and `_l1` for the second; the cells below split it per layer.
state = rnn_stacked.state_dict()
state

OrderedDict([('weight_ih_l0',
              tensor([[-0.0368,  0.6444],
                      [-0.1122,  0.2986]])),
             ('weight_hh_l0',
              tensor([[ 0.0425,  0.0104],
                      [ 0.0133, -0.1654]])),
             ('bias_ih_l0', tensor([0.4213, 0.4087])),
             ('bias_hh_l0', tensor([-0.0919, -0.2846])),
             ('weight_ih_l1',
              tensor([[ 0.5321, -0.4178],
                      [-0.5924,  0.5688]])),
             ('weight_hh_l1',
              tensor([[-0.3473,  0.1479],
                      [-0.4552,  0.2986]])),
             ('bias_ih_l1', tensor([-0.0691, -0.4289])),
             ('bias_hh_l1', tensor([-0.6838,  0.2752]))])

In [None]:
# Rebuild the stacked RNN as two independent single-layer RNNs.
rnn_layer0 = nn.RNN(input_size=2, hidden_size=2, batch_first=True)
rnn_layer1 = nn.RNN(input_size=2, hidden_size=2, batch_first=True)

# The first four entries of `state` are the layer-0 tensors; they load as-is.
rnn_layer0.load_state_dict(dict(list(state.items())[:4]))
# The last four belong to layer 1. A single-layer RNN only has `*_l0` keys,
# so the trailing layer digit is replaced with '0' before loading.
rnn_layer1.load_state_dict(
    {name[:-1] + '0': tensor for name, tensor in list(state.items())[4:]}
)

<All keys matched successfully>

In [None]:
# First layer: consumes the raw input (already float32, so no cast is needed).
out0, hidden0 = rnn_layer0(x)

In [None]:
# Per-step outputs vs. final hidden state of layer 0.
print(out0.shape, hidden0.shape)

torch.Size([1, 4, 2]) torch.Size([1, 1, 2])


In [None]:
# Second layer: its input sequence is the per-step output of layer 0.
out1, hidden1 = rnn_layer1(out0)
# Shapes produced by layer 1 (per-step outputs and final hidden state).
print(out1.shape, hidden1.shape)

torch.Size([1, 4, 2]) torch.Size([1, 1, 2])


In [None]:
# Manual equivalent of the stacked RNN's result: the last layer's outputs,
# plus the final hidden state of every layer stacked along dim 0.
out1, torch.cat((hidden0, hidden1), dim=0)

(tensor([[[-0.4255, -0.4378],
          [-0.4090, -0.2910],
          [-0.6097, -0.0062],
          [-0.3619, -0.0522]]], grad_fn=<TransposeBackward1>),
 tensor([[[ 0.3620,  0.0707]],
 
         [[-0.3619, -0.0522]]], grad_fn=<CatBackward0>))

In [None]:
# One final hidden state per layer, concatenated along dim 0.
torch.cat((hidden0, hidden1), dim=0).shape

torch.Size([2, 1, 2])

In [None]:
# Run the real two-layer RNN on the same input, for comparison with the
# manually chained `rnn_layer0` -> `rnn_layer1` result.
out, h = rnn_stacked(x)
print(out, h)
print(out.shape, h.shape)

tensor([[[-0.4255, -0.4378],
         [-0.4090, -0.2910],
         [-0.6097, -0.0062],
         [-0.3619, -0.0522]]], grad_fn=<TransposeBackward1>) tensor([[[ 0.3620,  0.0707]],

        [[-0.3619, -0.0522]]], grad_fn=<StackBackward0>)
torch.Size([1, 4, 2]) torch.Size([2, 1, 2])


**Bidirectional RNN**

In [None]:
# Single-layer bidirectional RNN. Its state dict holds one set of `*_l0`
# tensors for the forward pass and one set of `*_l0_reverse` tensors for
# the backward pass.
rnn_bid = nn.RNN(
    input_size=2,
    hidden_size=2,
    batch_first=True,
    bidirectional=True,
)
state = rnn_bid.state_dict()
state

OrderedDict([('weight_ih_l0',
              tensor([[ 0.6687, -0.6328],
                      [-0.1181, -0.0797]])),
             ('weight_hh_l0',
              tensor([[-0.1407,  0.2852],
                      [ 0.3560,  0.0532]])),
             ('bias_ih_l0', tensor([-0.0754,  0.6367])),
             ('bias_hh_l0', tensor([ 0.1014, -0.2713])),
             ('weight_ih_l0_reverse',
              tensor([[ 0.6612, -0.3805],
                      [ 0.6779, -0.6743]])),
             ('weight_hh_l0_reverse',
              tensor([[-0.4944, -0.6651],
                      [-0.0349,  0.5139]])),
             ('bias_ih_l0_reverse', tensor([ 0.1482, -0.6304])),
             ('bias_hh_l0_reverse', tensor([ 0.5152, -0.4710]))])

In [None]:
# Split the bidirectional RNN into two ordinary single-layer RNNs.
rnn_forward = nn.RNN(input_size=2, hidden_size=2, batch_first=True)
rnn_reverse = nn.RNN(input_size=2, hidden_size=2, batch_first=True)

# The first four entries of `state` are the forward-direction tensors.
rnn_forward.load_state_dict(dict(list(state.items())[:4]))
# The remaining four carry a '_reverse' suffix, which a plain RNN does not
# know about, so it is stripped from each key before loading.
rnn_reverse.load_state_dict(
    {name[:-len('_reverse')]: tensor for name, tensor in list(state.items())[4:]}
)

<All keys matched successfully>

In [None]:
# x is laid out (N, L, features); flipping dim 1 reverses the order of the
# sequence steps, which is the input the reverse direction processes.
x_rev = torch.flip(x, dims=[1])

print(x)
print(x_rev)

tensor([[[ 2.1046,  0.9141],
         [-0.5512,  1.3308],
         [-0.9752, -0.9401],
         [ 1.0307,  0.1510]]])
tensor([[[ 1.0307,  0.1510],
         [-0.9752, -0.9401],
         [-0.5512,  1.3308],
         [ 2.1046,  0.9141]]])


In [None]:
# Forward direction reads the sequence as given; the reverse direction reads
# the flipped copy.
out, h = rnn_forward(x)
out_rev, h_rev = rnn_reverse(x_rev)

In [None]:
# Both directions produce per-step outputs and one final hidden state.
print(out.shape, h.shape)
print(out_rev.shape, h_rev.shape)

torch.Size([1, 4, 2]) torch.Size([1, 1, 2])
torch.Size([1, 4, 2]) torch.Size([1, 1, 2])


In [None]:
# Note: `out_rev` is still in flipped (last-step-first) order at this point.
print(out)
print(out_rev)

tensor([[[ 0.6937,  0.0438],
         [-0.8537,  0.5180],
         [ 0.2324,  0.2721],
         [ 0.5814,  0.3174]]], grad_fn=<TransposeBackward1>)
tensor([[[ 0.8584, -0.4658],
         [ 0.2558, -0.8849],
         [ 0.2493, -0.9931],
         [ 0.9778, -0.6697]]], grad_fn=<TransposeBackward1>)


In [None]:
# Flip the reverse outputs back along the sequence axis so that row i lines up
# with step i of the original input.
out_rev_back = torch.flip(out_rev, dims=[1])
out_rev_back

tensor([[[ 0.9778, -0.6697],
         [ 0.2493, -0.9931],
         [ 0.2558, -0.8849],
         [ 0.8584, -0.4658]]], grad_fn=<FlipBackward0>)

In [None]:
# Concatenating both directions along the feature axis doubles the last dim.
torch.cat((out, out_rev_back), dim=2).shape

torch.Size([1, 4, 4])

In [None]:
# Manual reconstruction of a bidirectional result: per-step features of both
# directions side by side, and the two final hidden states stacked on dim 0.
torch.cat((out, out_rev_back), dim=2), torch.cat((h, h_rev), dim=0)

(tensor([[[ 0.6937,  0.0438,  0.9778, -0.6697],
          [-0.8537,  0.5180,  0.2493, -0.9931],
          [ 0.2324,  0.2721,  0.2558, -0.8849],
          [ 0.5814,  0.3174,  0.8584, -0.4658]]], grad_fn=<CatBackward0>),
 tensor([[[ 0.5814,  0.3174]],
 
         [[ 0.9778, -0.6697]]], grad_fn=<CatBackward0>))

**LSTM**

In [None]:
b