In [None]:
import torch
import walnut
import numpy as np

In [None]:
# Problem sizes shared by the walnut model and the torch reference.
batches = 10
sequence = 8
in_channels = 3
hidden_channels = 5
num_layers = 2

# Seed so the random inputs/weights are reproducible across kernel restarts.
# NOTE(review): assumes walnut.randn draws from NumPy's global RNG — TODO confirm.
SEED = 0
np.random.seed(SEED)

X = walnut.randn((batches, sequence, in_channels))

# walnut weights are laid out (fan_in, fan_out); each is scaled by fan_in**-0.5.
W_in1 = walnut.randn((in_channels, hidden_channels)) * in_channels**-0.5
B_in1 = walnut.randn((hidden_channels,))
W_hidden1 = walnut.randn((hidden_channels, hidden_channels)) * hidden_channels**-0.5
B_hidden1 = walnut.randn((hidden_channels,))

# Layer 2 consumes layer 1's hidden state, so its input fan-in is hidden_channels.
W_in2 = walnut.randn((hidden_channels, hidden_channels)) * hidden_channels**-0.5
B_in2 = walnut.randn((hidden_channels,))
W_hidden2 = walnut.randn((hidden_channels, hidden_channels)) * hidden_channels**-0.5
B_hidden2 = walnut.randn((hidden_channels,))


def to_torch_param(array):
    """Return a float32 ``torch.nn.Parameter`` (requires_grad=True) holding a copy of ``array``.

    ``torch.tensor`` always copies its input, so the torch reference never shares
    memory with the walnut arrays it was built from (``torch.from_numpy`` would).
    """
    return torch.nn.Parameter(torch.tensor(array, dtype=torch.float32), requires_grad=True)


t_x = to_torch_param(X.data)

# torch.nn.RNN stores weights as (fan_out, fan_in), hence the transposes.
t_w_in1 = to_torch_param(W_in1.data.T)
t_b_in1 = to_torch_param(B_in1.data)
t_w_hidden1 = to_torch_param(W_hidden1.data.T)
t_b_hidden1 = to_torch_param(B_hidden1.data)

t_w_in2 = to_torch_param(W_in2.data.T)
t_b_in2 = to_torch_param(B_in2.data)
t_w_hidden2 = to_torch_param(W_hidden2.data.T)
t_b_hidden2 = to_torch_param(B_hidden2.data)

### Forward

In [None]:
import walnut.nn as nn

# The walnut RNN under test and a torch reference with identical sizes.
# batch_first=True lets torch consume X's (batch, sequence, features) layout directly.
rnn = nn.RNN(in_channels, hidden_channels, num_layers=num_layers)
rnn_t = torch.nn.RNN(
    in_channels,
    hidden_channels,
    num_layers=num_layers,
    batch_first=True,
)

In [None]:
# NOTE(review): training mode is presumably required for walnut to support backward
# after this forward call — TODO confirm what training_mode() toggles.
rnn.training_mode()

# Shared initial weights, in walnut sub-layer order:
# [input l1, hidden l1, input l2, hidden l2].
walnut_weights = [
    (W_in1, B_in1),
    (W_hidden1, B_hidden1),
    (W_in2, B_in2),
    (W_hidden2, B_hidden2),
]
for layer_idx, (weight, bias) in enumerate(walnut_weights):
    rnn.layers[layer_idx].w = weight
    rnn.layers[layer_idx].b = bias

out = rnn(X)
out[0]

In [None]:
# Load the same weights into the torch reference as nn.Parameter objects, so the
# gradient cells below can read .grad from them.
torch_weights = {
    "weight_ih_l0": t_w_in1,
    "bias_ih_l0": t_b_in1,
    "weight_hh_l0": t_w_hidden1,
    "bias_hh_l0": t_b_hidden1,
    "weight_ih_l1": t_w_in2,
    "bias_ih_l1": t_b_in2,
    "weight_hh_l1": t_w_hidden2,
    "bias_hh_l1": t_b_hidden2,
}
for param_name, param in torch_weights.items():
    setattr(rnn_t, param_name, param)

# torch.nn.RNN returns (output, h_n); keep only the per-timestep output.
t_out = rnn_t(t_x)[0]
t_out[0]

### Backward

In [None]:
# Upstream gradient dL/d(out) of all ones, i.e. the gradient of out.sum().
dy = walnut.ones(out.shape).data
# The upstream gradient is a constant, not a trainable parameter, so use a plain
# tensor. torch.tensor copies, so t_dy never aliases the dy array handed to walnut.
t_dy = torch.tensor(dy, dtype=torch.float32)

In [None]:
# walnut side: reset gradients (presumably clears values accumulated by earlier
# backward calls — TODO confirm reset_grads semantics), then backpropagate dy.
# x_grad is the gradient w.r.t. the input X; it is compared against t_x.grad below.
rnn.reset_grads()
x_grad = rnn.backward(dy)

In [None]:
# torch side: clear gradients left over from an earlier run of this section (torch
# accumulates into .grad, mirroring rnn.reset_grads() on the walnut side), then
# backpropagate the same upstream gradient.
rnn_t.zero_grad(set_to_none=True)
t_x.grad = None  # t_x is not a parameter of rnn_t, so zero_grad() does not reset it
t_out.backward(t_dy)

#### X — gradient w.r.t. the input

**Known issue (walnut):** `backward` cannot be called multiple times or with different inputs. Because `backward` is defined during the forward call, it always uses the data (x, y, etc.) from the most recent forward pass.

In [None]:
# walnut: gradient w.r.t. the input X, first batch element.
x_grad[0]

In [None]:
# torch reference for the input gradient. Assert agreement with walnut's x_grad
# rather than relying on a visual comparison of the printed arrays.
# NOTE(review): assumes x_grad is a NumPy array (or array-like) — TODO confirm.
np.testing.assert_allclose(
    np.asarray(x_grad), t_x.grad.numpy(),
    rtol=1e-4, atol=1e-5, err_msg="input gradient: walnut vs torch",
)
t_x.grad[0]

#### W hidden — hidden-to-hidden weight gradients

In [None]:
# walnut: gradient of the first RNN layer's hidden-to-hidden weights (W_hidden1).
rnn.layers[1].w.grad

In [None]:
# torch reference, transposed to walnut's (fan_in, fan_out) layout; assert it
# matches walnut's gradient for W_hidden1.
# NOTE(review): assumes walnut .grad values are NumPy arrays — TODO confirm.
np.testing.assert_allclose(
    np.asarray(rnn.layers[1].w.grad), rnn_t.weight_hh_l0.grad.T.numpy(),
    rtol=1e-4, atol=1e-5, err_msg="W_hidden1 gradient: walnut vs torch",
)
rnn_t.weight_hh_l0.grad.T

In [None]:
# walnut: gradient of the second RNN layer's hidden-to-hidden weights (W_hidden2).
rnn.layers[3].w.grad

In [None]:
# torch reference, transposed to walnut's (fan_in, fan_out) layout; assert it
# matches walnut's gradient for W_hidden2.
# NOTE(review): assumes walnut .grad values are NumPy arrays — TODO confirm.
np.testing.assert_allclose(
    np.asarray(rnn.layers[3].w.grad), rnn_t.weight_hh_l1.grad.T.numpy(),
    rtol=1e-4, atol=1e-5, err_msg="W_hidden2 gradient: walnut vs torch",
)
rnn_t.weight_hh_l1.grad.T

#### B hidden — hidden-to-hidden bias gradients

In [None]:
# walnut: gradient of the first RNN layer's hidden-to-hidden bias (B_hidden1).
rnn.layers[1].b.grad

In [None]:
# torch reference; assert it matches walnut's gradient for B_hidden1.
# NOTE(review): assumes walnut .grad values are NumPy arrays — TODO confirm.
np.testing.assert_allclose(
    np.asarray(rnn.layers[1].b.grad), rnn_t.bias_hh_l0.grad.numpy(),
    rtol=1e-4, atol=1e-5, err_msg="B_hidden1 gradient: walnut vs torch",
)
rnn_t.bias_hh_l0.grad

In [None]:
# walnut: gradient of the second RNN layer's hidden-to-hidden bias (B_hidden2).
rnn.layers[3].b.grad

In [None]:
# torch reference; assert it matches walnut's gradient for B_hidden2.
# NOTE(review): assumes walnut .grad values are NumPy arrays — TODO confirm.
np.testing.assert_allclose(
    np.asarray(rnn.layers[3].b.grad), rnn_t.bias_hh_l1.grad.numpy(),
    rtol=1e-4, atol=1e-5, err_msg="B_hidden2 gradient: walnut vs torch",
)
rnn_t.bias_hh_l1.grad

#### W input — input-to-hidden weight gradients

In [None]:
# walnut: gradient of the first RNN layer's input-to-hidden weights (W_in1).
rnn.layers[0].w.grad

In [None]:
# torch reference, transposed to walnut's (fan_in, fan_out) layout. Assert both
# layers' input-to-hidden weight gradients match walnut; layer 2 (W_in2) is
# otherwise never compared in this notebook.
# NOTE(review): assumes walnut .grad values are NumPy arrays — TODO confirm.
np.testing.assert_allclose(
    np.asarray(rnn.layers[0].w.grad), rnn_t.weight_ih_l0.grad.T.numpy(),
    rtol=1e-4, atol=1e-5, err_msg="W_in1 gradient: walnut vs torch",
)
np.testing.assert_allclose(
    np.asarray(rnn.layers[2].w.grad), rnn_t.weight_ih_l1.grad.T.numpy(),
    rtol=1e-4, atol=1e-5, err_msg="W_in2 gradient: walnut vs torch",
)
rnn_t.weight_ih_l0.grad.T

#### B input — input-to-hidden bias gradients

In [None]:
# walnut: gradient of the first RNN layer's input-to-hidden bias (B_in1).
rnn.layers[0].b.grad

In [None]:
# torch reference. Assert both layers' input-to-hidden bias gradients match
# walnut; layer 2 (B_in2) is otherwise never compared in this notebook.
# NOTE(review): assumes walnut .grad values are NumPy arrays — TODO confirm.
np.testing.assert_allclose(
    np.asarray(rnn.layers[0].b.grad), rnn_t.bias_ih_l0.grad.numpy(),
    rtol=1e-4, atol=1e-5, err_msg="B_in1 gradient: walnut vs torch",
)
np.testing.assert_allclose(
    np.asarray(rnn.layers[2].b.grad), rnn_t.bias_ih_l1.grad.numpy(),
    rtol=1e-4, atol=1e-5, err_msg="B_in2 gradient: walnut vs torch",
)
rnn_t.bias_ih_l0.grad