In [2]:
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

(seq_len, batch_size, hidden_size)

In [3]:
out_pre = torch.ones(3, 5, 2, dtype=torch.float64)
out_pre[1,:,:] *= 2
out_pre[2,:,:] *= 3
hidden_post = torch.randint(-10, -1, size=(5,2), dtype=torch.float64)

In [4]:
print(hidden_post)
print(out_pre)

tensor([[-4., -8.],
        [-4., -6.],
        [-3., -7.],
        [-7., -5.],
        [-5., -7.]], dtype=torch.float64)
tensor([[[1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.]],

        [[2., 2.],
         [2., 2.],
         [2., 2.],
         [2., 2.],
         [2., 2.]],

        [[3., 3.],
         [3., 3.],
         [3., 3.],
         [3., 3.],
         [3., 3.]]], dtype=torch.float64)


## Dot product

In [4]:
out_pre * hidden_post

tensor([[[ -4.,  -3.],
         [ -7., -10.],
         [ -9.,  -9.],
         [ -8.,  -3.],
         [ -3.,  -5.]],

        [[ -8.,  -6.],
         [-14., -20.],
         [-18., -18.],
         [-16.,  -6.],
         [ -6., -10.]],

        [[-12.,  -9.],
         [-21., -30.],
         [-27., -27.],
         [-24.,  -9.],
         [ -9., -15.]]], dtype=torch.float64)

In [5]:
pre_soft = torch.sum((out_pre * hidden_post), 2)

In [6]:
pre_soft

tensor([[ -7., -17., -18., -11.,  -8.],
        [-14., -34., -36., -22., -16.],
        [-21., -51., -54., -33., -24.]], dtype=torch.float64)

## Feedforward NN

In [7]:
hidden_post

tensor([[ -4.,  -3.],
        [ -7., -10.],
        [ -9.,  -9.],
        [ -8.,  -3.],
        [ -3.,  -5.]], dtype=torch.float64)

In [8]:
w_a, u_a, v_a = torch.randn(2,2, dtype=torch.float64), torch.randn(2,2, dtype=torch.float64), torch.randn(2,1, dtype=torch.float64)

In [9]:
w_a

tensor([[ 0.6756, -1.3255],
        [ 0.8864,  0.1925]], dtype=torch.float64)

In [10]:
out_pre

tensor([[[1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.]],

        [[2., 2.],
         [2., 2.],
         [2., 2.],
         [2., 2.],
         [2., 2.]],

        [[3., 3.],
         [3., 3.],
         [3., 3.],
         [3., 3.],
         [3., 3.]]], dtype=torch.float64)

In [11]:
torch.matmul(out_pre, w_a)

tensor([[[ 1.5619, -1.1331],
         [ 1.5619, -1.1331],
         [ 1.5619, -1.1331],
         [ 1.5619, -1.1331],
         [ 1.5619, -1.1331]],

        [[ 3.1238, -2.2662],
         [ 3.1238, -2.2662],
         [ 3.1238, -2.2662],
         [ 3.1238, -2.2662],
         [ 3.1238, -2.2662]],

        [[ 4.6857, -3.3992],
         [ 4.6857, -3.3992],
         [ 4.6857, -3.3992],
         [ 4.6857, -3.3992],
         [ 4.6857, -3.3992]]], dtype=torch.float64)

In [12]:
out_pre.shape

torch.Size([3, 5, 2])

In [13]:
torch.matmul(hidden_post, u_a)

tensor([[ 6.0929,  1.9116],
        [17.2361,  7.2713],
        [16.8229,  6.1607],
        [ 8.0342,  1.3435],
        [ 8.3754,  3.7067]], dtype=torch.float64)

In [14]:
pre_soft = torch.matmul(out_pre, w_a) + torch.matmul(hidden_post, u_a)
pre_soft = torch.tanh(pre_soft).matmul(v_a).squeeze()

In [15]:
F.softmax(pre_soft, dim=0)

tensor([[0.2542, 0.3333, 0.3330, 0.2674, 0.3070],
        [0.3417, 0.3333, 0.3331, 0.3532, 0.3158],
        [0.4041, 0.3334, 0.3338, 0.3794, 0.3772]], dtype=torch.float64)

## Softmax

In [16]:
post_soft = F.softmax(pre_soft, dim=0)

In [17]:
post_soft

tensor([[0.2542, 0.3333, 0.3330, 0.2674, 0.3070],
        [0.3417, 0.3333, 0.3331, 0.3532, 0.3158],
        [0.4041, 0.3334, 0.3338, 0.3794, 0.3772]], dtype=torch.float64)

In [18]:
out_pre.permute(2,0,1)

tensor([[[1., 1., 1., 1., 1.],
         [2., 2., 2., 2., 2.],
         [3., 3., 3., 3., 3.]],

        [[1., 1., 1., 1., 1.],
         [2., 2., 2., 2., 2.],
         [3., 3., 3., 3., 3.]]], dtype=torch.float64)

In [19]:
attention = torch.sum(out_pre.permute(2, 0, 1) * post_soft, 1)

In [20]:
attention.t()

tensor([[2.1498, 2.1498],
        [2.0001, 2.0001],
        [2.0008, 2.0008],
        [2.1121, 2.1121],
        [2.0701, 2.0701]], dtype=torch.float64)

# Testing architectures

In [21]:
from amphibian.networks.LSTM import LSTMModel
from amphibian.networks.attention import AttentionModel

In [22]:
batch_size = 256
seq_len = 30
input_size = 120
my_batch = torch.randn(batch_size, seq_len, input_size)

## LSTM

In [23]:
lstm = LSTMModel(batch_size=batch_size, seq_len=seq_len, input_size=input_size, hidden_size=10, n_outputs=3,
                 num_layers=2, dropout=0.1)

In [24]:
lstm_out = lstm(my_batch.permute(1,0,2))

## Attention

We test by checking if all gradients change during a backward pass. We define a dummy target tensor and use simple L2 loss.

In [25]:
y = torch.randn(batch_size, 3)
loss_fun = nn.MSELoss()

#### RNN dotprod

In [26]:
attn_rnn = AttentionModel(batch_size=batch_size, seq_len=seq_len, input_size=input_size, hidden_size=10, n_outputs=3,
                          num_layers=2, dropout=0.1)

optimizer = optim.SGD(attn_rnn.parameters(), lr=1)
attn_rnn_out = attn_rnn(my_batch.permute(1, 0, 2))

In [27]:
list(attn_rnn.parameters())[-1]

Parameter containing:
tensor([0.2920, 0.2192, 0.2202], requires_grad=True)

In [28]:
pre_opt = copy.deepcopy(attn_rnn.state_dict())

In [29]:
for name, ten in attn_rnn.state_dict().items():
    print(name, ten.shape)

recurrent_pre.weight_ih_l0 torch.Size([10, 120])
recurrent_pre.weight_hh_l0 torch.Size([10, 10])
recurrent_pre.bias_ih_l0 torch.Size([10])
recurrent_pre.bias_hh_l0 torch.Size([10])
recurrent_pre.weight_ih_l1 torch.Size([10, 10])
recurrent_pre.weight_hh_l1 torch.Size([10, 10])
recurrent_pre.bias_ih_l1 torch.Size([10])
recurrent_pre.bias_hh_l1 torch.Size([10])
recurrent_cell_post.weight_ih torch.Size([10, 10])
recurrent_cell_post.weight_hh torch.Size([10, 10])
recurrent_cell_post.bias_ih torch.Size([10])
recurrent_cell_post.bias_hh torch.Size([10])
fc.weight torch.Size([3, 10])
fc.bias torch.Size([3])


In [30]:
loss = loss_fun(attn_rnn_out, y)
attn_rnn.zero_grad()
loss.backward()
optimizer.step()
# Print differences in gradients
for pre, post in zip(pre_opt.items(), attn_rnn.state_dict().items()):
    print(pre[0], torch.sum(torch.abs(pre[1] - post[1])))

recurrent_pre.weight_ih_l0 tensor(0.8174)
recurrent_pre.weight_hh_l0 tensor(0.1195)
recurrent_pre.bias_ih_l0 tensor(0.1143)
recurrent_pre.bias_hh_l0 tensor(0.1143)
recurrent_pre.weight_ih_l1 tensor(0.9262)
recurrent_pre.weight_hh_l1 tensor(1.4349)
recurrent_pre.bias_ih_l1 tensor(0.5261)
recurrent_pre.bias_hh_l1 tensor(0.5261)
recurrent_cell_post.weight_ih tensor(2.8797)
recurrent_cell_post.weight_hh tensor(2.3095)
recurrent_cell_post.bias_ih tensor(0.8826)
recurrent_cell_post.bias_hh tensor(0.8826)
fc.weight tensor(1.7004)
fc.bias tensor(0.6380)


In [31]:
list(attn_rnn.parameters())[-1]

Parameter containing:
tensor([ 0.2020, -0.0618, -0.0469], requires_grad=True)

#### RNN FFNN

In [32]:
attn_rnn_ffnn = AttentionModel(batch_size=batch_size, seq_len=seq_len, input_size=input_size, hidden_size=10, n_outputs=3,
                               num_layers=2, dropout=0.1, alignment='ffnn')

optimizer = optim.SGD(attn_rnn_ffnn.parameters(), lr=1)
attn_rnn_ffnn_out = attn_rnn_ffnn(my_batch.permute(1, 0, 2))

In [33]:
list(attn_rnn_ffnn.parameters())[-1]

Parameter containing:
tensor([-2.6149e-01,  1.4842e-05,  5.5488e-02], requires_grad=True)

In [34]:
pre_opt = copy.deepcopy(attn_rnn_ffnn.state_dict())

In [35]:
for name, ten in attn_rnn_ffnn.state_dict().items():
    print(name, ten.shape)

w_a torch.Size([10, 10])
u_a torch.Size([10, 10])
v_a torch.Size([10, 1])
recurrent_pre.weight_ih_l0 torch.Size([10, 120])
recurrent_pre.weight_hh_l0 torch.Size([10, 10])
recurrent_pre.bias_ih_l0 torch.Size([10])
recurrent_pre.bias_hh_l0 torch.Size([10])
recurrent_pre.weight_ih_l1 torch.Size([10, 10])
recurrent_pre.weight_hh_l1 torch.Size([10, 10])
recurrent_pre.bias_ih_l1 torch.Size([10])
recurrent_pre.bias_hh_l1 torch.Size([10])
recurrent_cell_post.weight_ih torch.Size([10, 10])
recurrent_cell_post.weight_hh torch.Size([10, 10])
recurrent_cell_post.bias_ih torch.Size([10])
recurrent_cell_post.bias_hh torch.Size([10])
fc.weight torch.Size([3, 10])
fc.bias torch.Size([3])


In [36]:
loss = loss_fun(attn_rnn_ffnn_out, y)
attn_rnn_ffnn.zero_grad()
loss.backward()
optimizer.step()
# Print differences in gradients
for pre, post in zip(pre_opt.items(), attn_rnn_ffnn.state_dict().items()):
    print(pre[0], torch.sum(torch.abs(pre[1] - post[1])))

w_a tensor(0.1741)
u_a tensor(0.0188)
v_a tensor(0.0480)
recurrent_pre.weight_ih_l0 tensor(0.4393)
recurrent_pre.weight_hh_l0 tensor(0.0607)
recurrent_pre.bias_ih_l0 tensor(0.0657)
recurrent_pre.bias_hh_l0 tensor(0.0657)
recurrent_pre.weight_ih_l1 tensor(0.3176)
recurrent_pre.weight_hh_l1 tensor(0.6208)
recurrent_pre.bias_ih_l1 tensor(0.2742)
recurrent_pre.bias_hh_l1 tensor(0.2742)
recurrent_cell_post.weight_ih tensor(1.4279)
recurrent_cell_post.weight_hh tensor(2.1454)
recurrent_cell_post.bias_ih tensor(0.5945)
recurrent_cell_post.bias_hh tensor(0.5945)
fc.weight tensor(1.8058)
fc.bias tensor(0.4990)


In [37]:
list(attn_rnn_ffnn.parameters())[-1]

Parameter containing:
tensor([-0.0390, -0.0107, -0.2104], requires_grad=True)

#### LSTM FFNN

In [47]:
attn_lstm = AttentionModel(batch_size=batch_size, seq_len=seq_len, input_size=input_size, hidden_size=10, n_outputs=3,
                           num_layers=2, dropout=0.1, alignment='ffnn', recurrent_type='lstm')

optimizer = optim.SGD(attn_lstm.parameters(), lr=1)
attn_lstm_out = attn_lstm(my_batch.permute(1, 0, 2))

In [48]:
list(attn_lstm.parameters())[-1]

Parameter containing:
tensor([ 0.0163, -0.1729,  0.0629], requires_grad=True)

In [49]:
pre_opt = copy.deepcopy(attn_lstm.state_dict())

In [50]:
for name, ten in attn_lstm.state_dict().items():
    print(name, ten.shape)

w_a torch.Size([10, 10])
u_a torch.Size([10, 10])
v_a torch.Size([10, 1])
recurrent_pre.weight_ih_l0 torch.Size([40, 120])
recurrent_pre.weight_hh_l0 torch.Size([40, 10])
recurrent_pre.bias_ih_l0 torch.Size([40])
recurrent_pre.bias_hh_l0 torch.Size([40])
recurrent_pre.weight_ih_l1 torch.Size([40, 10])
recurrent_pre.weight_hh_l1 torch.Size([40, 10])
recurrent_pre.bias_ih_l1 torch.Size([40])
recurrent_pre.bias_hh_l1 torch.Size([40])
recurrent_cell_post.weight_ih torch.Size([40, 10])
recurrent_cell_post.weight_hh torch.Size([40, 10])
recurrent_cell_post.bias_ih torch.Size([40])
recurrent_cell_post.bias_hh torch.Size([40])
fc.weight torch.Size([3, 10])
fc.bias torch.Size([3])


In [51]:
loss = loss_fun(attn_lstm_out, y)
attn_lstm.zero_grad()
loss.backward()
optimizer.step()
# Print differences in gradients
for pre, post in zip(pre_opt.items(), attn_lstm.state_dict().items()):
    print(pre[0], torch.sum(torch.abs(pre[1] - post[1])))

w_a tensor(0.0002)
u_a tensor(2.7465e-06)
v_a tensor(3.3846e-05)
recurrent_pre.weight_ih_l0 tensor(0.0968)
recurrent_pre.weight_hh_l0 tensor(0.0025)
recurrent_pre.bias_ih_l0 tensor(0.0009)
recurrent_pre.bias_hh_l0 tensor(0.0009)
recurrent_pre.weight_ih_l1 tensor(0.0127)
recurrent_pre.weight_hh_l1 tensor(0.0086)
recurrent_pre.bias_ih_l1 tensor(0.0110)
recurrent_pre.bias_hh_l1 tensor(0.0110)
recurrent_cell_post.weight_ih tensor(0.0518)
recurrent_cell_post.weight_hh tensor(0.0533)
recurrent_cell_post.bias_ih tensor(0.0636)
recurrent_cell_post.bias_hh tensor(0.0636)
fc.weight tensor(0.1401)
fc.bias tensor(0.1688)


In [52]:
list(attn_rnn_ffnn.parameters())[-1]

Parameter containing:
tensor([-0.0390, -0.0107, -0.2104], requires_grad=True)