In [None]:
import os

download_name = "helper.py"
if not os.path.exists(download_name):
    import requests
    response = requests.get(f"https://raw.githubusercontent.com/bzitko/nlp_repo/main/demos/{download_name}")
    with open(download_name, "wb") as fp:
        fp.write(response.content)
    response.close()

In [3]:
import torch
import numpy as np
from helper import look

# RNN cell single forward

In [5]:
torch.manual_seed(42)

rnn_cell = torch.nn.RNNCell(input_size=5, hidden_size=3)

look("## RNN Cell (5x3)")
w_ih = rnn_cell.weight_ih.data
b_ih = rnn_cell.bias_ih.data
look("$W_{ih}=$", w_ih, "$b_{ih}=$", b_ih)

w_hh = rnn_cell.weight_hh.data
b_hh = rnn_cell.bias_hh.data
look("$W_{hh}=$", w_hh, "$b_{hh}=$", b_hh)

look("## Data")
h0 = torch.rand(3)
look("$h_0=$", h0)
x1 = torch.rand(5)
look("$x_1=$", x1)

hx1 = torch.sum(w_ih * x1, dim=1) + b_ih  
hh1 = torch.sum(w_hh * h0, dim=1) + b_hh

look("$W_{ih} x_1 + b_{ih} =$", hx1)
look("$W_{hh} x_1 + b_{hh} =$", hh1)

man_h1 = torch.tanh(hx1 + hh1)
fw_h1 = rnn_cell.forward(x1, h0)

look("$h_1 = tanh(W_{ih} x_1 + b_{ih} + W_{ih} x_1 + b_{ih})=$", man_h1)

look("Manual $h_1=$", man_h1)
look("Framework $h_1=$", fw_h1)

## RNN Cell (5x3)

$W_{ih}=$ $\begin{bmatrix} 0.441 & 0.479 & -0.135 & 0.53 & -0.126 \\ 0.117 & -0.281 & 0.339 & 0.509 & -0.424 \\ 0.502 & 0.108 & 0.427 & 0.0782 & 0.278\end{bmatrix}$ $b_{ih}=$ $\begin{bmatrix} -0.456 & -0.266 & -0.163\end{bmatrix}$

$W_{hh}=$ $\begin{bmatrix} -0.0815 & 0.445 & 0.0853 \\ -0.27 & 0.147 & -0.266 \\ -0.0677 & -0.234 & 0.383\end{bmatrix}$ $b_{hh}=$ $\begin{bmatrix} -0.347 & 0.0545 & -0.57\end{bmatrix}$

## Data

$h_0=$ $\begin{bmatrix} 0.952 & 0.0753 & 0.886\end{bmatrix}$

$x_1=$ $\begin{bmatrix} 0.583 & 0.338 & 0.809 & 0.578 & 0.904\end{bmatrix}$

$W_{ih} x_1 + b_{ih} =$ $\begin{bmatrix} 0.0462 & -0.108 & 0.808\end{bmatrix}$

$W_{hh} x_1 + b_{hh} =$ $\begin{bmatrix} -0.316 & -0.427 & -0.313\end{bmatrix}$

$h_1 = tanh(W_{ih} x_1 + b_{ih} + W_{ih} x_1 + b_{ih})=$ $\begin{bmatrix} -0.263 & -0.489 & 0.458\end{bmatrix}$

Manual $h_1=$ $\begin{bmatrix} -0.263 & -0.489 & 0.458\end{bmatrix}$

Framework $h_1=$ $\begin{bmatrix} -0.263 & -0.489 & 0.458\end{bmatrix}$

# RNN cell multiple forward propagations

In [9]:
N = 4
X = torch.rand(N, 5)
look(r"$X=$", np.array([["x_1", "x_2", "x_3"]]).T, "$=$", X)
#look(r"$X=\begin{bmatrix}x_1 \\ x_2 \\ x_3\end{bmatrix}=$", X)

h0 = torch.rand(3)
look("$h_0=$", h0)

fw_h = h0
man_h = h0
for i, x in enumerate(X, 1):
    
    fw_h_next = rnn_cell(x, fw_h)
    look(f"$h_{i}=RNN(x_{i}, h_{i-1})=RNN($", x, "$,$", fw_h,"$)=$", fw_h_next)
    
    man_h = torch.tanh(torch.sum(w_ih * x, dim=1) + b_ih + torch.sum(w_hh * man_h, dim=1) + b_hh)
    fw_h = fw_h_next

look(f"Manual $h_{N}=$", man_h)
look(f"Framework $h_{N}=$", fw_h)


$X=$ $\begin{bmatrix} x_1 \\ x_2 \\ x_3\end{bmatrix}$ $=$ $\begin{bmatrix} 0.225 & 0.0624 & 0.182 & 1.0 & 0.594 \\ 0.654 & 0.0337 & 0.172 & 0.334 & 0.578 \\ 0.06 & 0.285 & 0.201 & 0.501 & 0.314 \\ 0.465 & 0.161 & 0.157 & 0.208 & 0.329\end{bmatrix}$

$h_0=$ $\begin{bmatrix} 0.105 & 0.919 & 0.401\end{bmatrix}$

$h_1=RNN(x_1, h_0)=RNN($ $\begin{bmatrix} 0.225 & 0.0624 & 0.182 & 1.0 & 0.594\end{bmatrix}$ $,$ $\begin{bmatrix} 0.105 & 0.919 & 0.401\end{bmatrix}$ $)=$ $\begin{bmatrix} 0.189 & 0.115 & -0.347\end{bmatrix}$

$h_2=RNN(x_2, h_1)=RNN($ $\begin{bmatrix} 0.654 & 0.0337 & 0.172 & 0.334 & 0.578\end{bmatrix}$ $,$ $\begin{bmatrix} 0.189 & 0.115 & -0.347\end{bmatrix}$ $)=$ $\begin{bmatrix} -0.389 & -0.103 & -0.304\end{bmatrix}$

$h_3=RNN(x_3, h_2)=RNN($ $\begin{bmatrix} 0.06 & 0.285 & 0.201 & 0.501 & 0.314\end{bmatrix}$ $,$ $\begin{bmatrix} -0.389 & -0.103 & -0.304\end{bmatrix}$ $)=$ $\begin{bmatrix} -0.447 & 0.0761 & -0.482\end{bmatrix}$

$h_4=RNN(x_4, h_3)=RNN($ $\begin{bmatrix} 0.465 & 0.161 & 0.157 & 0.208 & 0.329\end{bmatrix}$ $,$ $\begin{bmatrix} -0.447 & 0.0761 & -0.482\end{bmatrix}$ $)=$ $\begin{bmatrix} -0.416 & 0.077 & -0.446\end{bmatrix}$

Manual $h_4=$ $\begin{bmatrix} -0.416 & 0.077 & -0.446\end{bmatrix}$

Framework $h_4=$ $\begin{bmatrix} -0.416 & 0.077 & -0.446\end{bmatrix}$

# LSTM cell 

In [12]:
torch.manual_seed(42)

I_SIZE=3
H_SIZE=2
print(I_SIZE, H_SIZE)


look("## Weights")
lstmcell = torch.nn.LSTMCell(input_size=I_SIZE, hidden_size=H_SIZE)

w_ih = lstmcell.weight_ih.data
b_ih = lstmcell.bias_ih.data
#w_ii, w_if, w_ig, w_io = torch.chunk(w_ih, 4)
#b_ii, b_if, b_ig, b_io = torch.chunk(b_ih, 4)
assert w_ih.shape == (4 * H_SIZE, I_SIZE)
assert b_ih.shape == (4 * H_SIZE,)
look("$W_{ih}=$", w_ih)
look("$b_{ih}=$", b_ih)

w_hh = lstmcell.weight_hh.data
b_hh = lstmcell.bias_hh.data
#w_hi, w_hf, w_hg, w_ho = torch.chunk(w_hh, 4)
#b_hi, b_hf, b_hg, b_ho = torch.chunk(b_hh, 4)
assert w_hh.shape == (4 * H_SIZE, H_SIZE)
assert b_hh.shape == (4 * H_SIZE,)
look("$W_{hh}=$", w_hh)
look("$b_{hh}=$", b_hh)


look("## Input")
x = torch.rand(I_SIZE)
look("$x=$", x)
h = torch.rand(H_SIZE)
look("$h_0=$", h)
c = torch.rand(H_SIZE)
look("$c_0=$", c)

look("## Iterim (applying weights)")
wxb = torch.sum(w_ih * x, dim=1) + b_ih
whb = torch.sum(w_hh * h, dim=1) + b_hh

i_t, f_t, g_t, o_t = torch.chunk(wxb + whb, 4)

i_t = torch.sigmoid(i_t)
f_t = torch.sigmoid(f_t)
g_t = torch.tanh(g_t)
o_t = torch.sigmoid(o_t)

look("$i_t=$", i_t)
look("$f_t=$", f_t)
look("$o_t=$", o_t)
look("$g_t=$", g_t)

look("## Output (combining gates)")
c_t = torch.mul(c, f_t) + torch.mul(i_t, g_t)
look("$c_t=$", c_t)
h_t = torch.mul(o_t, torch.tanh(c_t))
look("$h_t=$", h_t)

# framework
fw_h_t, fw_c_t = lstmcell(x, (h, c))

look("framework $h_t=$", fw_h_t)
look("framework $c_t=$", fw_c_t)

3 2


## Weights

$W_{ih}=$ $\begin{bmatrix} 0.541 & 0.587 & -0.166 \\ 0.65 & -0.155 & 0.143 \\ -0.344 & 0.415 & 0.623 \\ -0.519 & 0.615 & 0.132 \\ 0.522 & 0.0958 & 0.341 \\ -0.0998 & 0.545 & 0.105 \\ -0.33 & 0.18 & -0.326 \\ -0.0829 & -0.287 & 0.469\end{bmatrix}$

$b_{ih}=$ $\begin{bmatrix} 0.19 & -0.192 & 0.298 & 0.631 & 0.409 & -0.309 & 0.408 & 0.127\end{bmatrix}$

$W_{hh}=$ $\begin{bmatrix} -0.558 & -0.326 \\ -0.2 & -0.425 \\ 0.0667 & -0.698 \\ 0.639 & -0.601 \\ 0.546 & 0.118 \\ -0.23 & 0.437 \\ 0.11 & 0.571 \\ 0.0773 & -0.223\end{bmatrix}$

$b_{hh}=$ $\begin{bmatrix} 0.359 & -0.431 & -0.7 & -0.273 & -0.542 & 0.58 & 0.204 & 0.293\end{bmatrix}$

## Input

$x=$ $\begin{bmatrix} 0.658 & 0.491 & 0.891\end{bmatrix}$

$h_0=$ $\begin{bmatrix} 0.145 & 0.531\end{bmatrix}$

$c_0=$ $\begin{bmatrix} 0.159 & 0.654\end{bmatrix}$

## Iterim (applying weights)

$i_t=$ $\begin{bmatrix} 0.688 & 0.402\end{bmatrix}$

$f_t=$ $\begin{bmatrix} 0.443 & 0.552\end{bmatrix}$

$o_t=$ $\begin{bmatrix} 0.625 & 0.631\end{bmatrix}$

$g_t=$ $\begin{bmatrix} 0.606 & 0.644\end{bmatrix}$

## Output (combining gates)

$c_t=$ $\begin{bmatrix} 0.487 & 0.62\end{bmatrix}$

$h_t=$ $\begin{bmatrix} 0.283 & 0.347\end{bmatrix}$

framework $h_t=$ $\begin{bmatrix} 0.283 & 0.347\end{bmatrix}$

framework $c_t=$ $\begin{bmatrix} 0.487 & 0.62\end{bmatrix}$

# GRU Cell

In [13]:
torch.manual_seed(42)

I_SIZE=3
H_SIZE=2

look("## Weights")
gru_cell = torch.nn.GRUCell(input_size=I_SIZE, hidden_size=H_SIZE)

w_ih = gru_cell.weight_ih.data
b_ih = gru_cell.bias_ih.data
assert w_ih.shape == (3 * H_SIZE, I_SIZE)
assert b_ih.shape == (3 * H_SIZE,)
look("$W_{ih}=$", w_ih)
look("$b_{ih}=$", b_ih)

w_hh = gru_cell.weight_hh.data
b_hh = gru_cell.bias_hh.data
assert w_hh.shape == (3 * H_SIZE, H_SIZE)
assert b_hh.shape == (3 * H_SIZE,)
look("$W_{hh}=$", w_hh)
look("$b_{hh}=$", b_hh)

w_ih_r, w_ih_z, w_ih_n = torch.chunk(w_ih, 3)
b_ih_r, b_ih_z, b_ih_n = torch.chunk(b_ih, 3)
w_hh_r, w_hh_z, w_hh_n = torch.chunk(w_hh, 3)
b_hh_r, b_hh_z, b_hh_n = torch.chunk(b_hh, 3)

look("## Input")
x = torch.rand(I_SIZE)
look("$x=$", x)
h = torch.rand(H_SIZE)
look("$h_0=$", h)

def apply_wb(w, b, x):
    return torch.sum(w * x, dim=1) + b

look("## Iterim (applying weights)")
r_t = torch.sigmoid(torch.sum(w_ih_r * x, dim=1) + b_ih_r + torch.sum(w_hh_r * h, dim=1) + b_hh_r)
z_t = torch.sigmoid(torch.sum(w_ih_z * x, dim=1) + b_ih_z + torch.sum(w_hh_z * h, dim=1) + b_hh_z)
n_t = torch.tanh(torch.sum(w_ih_n * x, dim=1) + b_ih_n + r_t * (torch.sum(w_hh_n * h, dim=1) + b_hh_n))
h_t = (1 - z_t) * n_t + z_t * h
look("$h_t$", h_t)

# framework
fw_h_t = gru_cell(x, h)

look("framework $h_t=$", fw_h_t)

## Weights

$W_{ih}=$ $\begin{bmatrix} 0.541 & 0.587 & -0.166 \\ 0.65 & -0.155 & 0.143 \\ -0.344 & 0.415 & 0.623 \\ -0.519 & 0.615 & 0.132 \\ 0.522 & 0.0958 & 0.341 \\ -0.0998 & 0.545 & 0.105\end{bmatrix}$

$b_{ih}=$ $\begin{bmatrix} 0.639 & -0.601 & 0.546 & 0.118 & -0.23 & 0.437\end{bmatrix}$

$W_{hh}=$ $\begin{bmatrix} -0.33 & 0.18 \\ -0.326 & -0.0829 \\ -0.287 & 0.469 \\ -0.558 & -0.326 \\ -0.2 & -0.425 \\ 0.0667 & -0.698\end{bmatrix}$

$b_{hh}=$ $\begin{bmatrix} 0.11 & 0.571 & 0.0773 & -0.223 & 0.19 & -0.192\end{bmatrix}$

## Input

$x=$ $\begin{bmatrix} 0.71 & 0.946 & 0.789\end{bmatrix}$

$h_0=$ $\begin{bmatrix} 0.281 & 0.789\end{bmatrix}$

## Iterim (applying weights)

$h_t$ $\begin{bmatrix} 0.288 & 0.634\end{bmatrix}$

framework $h_t=$ $\begin{bmatrix} 0.288 & 0.634\end{bmatrix}$

## RNN and RNNCell (one layer)

In [14]:
torch.manual_seed(42)

rnn = torch.nn.RNN(input_size=5, hidden_size=3, num_layers=1, batch_first=False)
rnn_cell = torch.nn.RNNCell(input_size=5, hidden_size=3)

rnn_cell.weight_hh.data = rnn.weight_hh_l0.data
rnn_cell.bias_hh.data = rnn.bias_hh_l0.data
rnn_cell.weight_ih.data = rnn.weight_ih_l0.data
rnn_cell.bias_ih.data = rnn.bias_ih_l0.data

x = torch.rand(5)
h = torch.rand(3)
look("$x=$", x)
look("$h=$", h)

y, h_t = rnn(x.view(1,-1), h.view(1, -1))
rnn_y = y.squeeze()
rnn_h_t = h_t.squeeze()

rnn_cell_y = rnn_cell(x, h)

look("rnn $y=$", rnn_y)
look("rnn cell $y=$", rnn_cell_y)


$x=$ $\begin{bmatrix} 0.531 & 0.159 & 0.654 & 0.328 & 0.653\end{bmatrix}$

$h=$ $\begin{bmatrix} 0.396 & 0.915 & 0.204\end{bmatrix}$

rnn $y=$ $\begin{bmatrix} -0.097 & -0.108 & -0.126\end{bmatrix}$

rnn cell $y=$ $\begin{bmatrix} -0.097 & -0.108 & -0.126\end{bmatrix}$

## BRNN and bidirectional RNN cells (one layer)

In [15]:
torch.manual_seed(42)

birnn = torch.nn.RNN(input_size=5, hidden_size=3, bidirectional=True)

birnn_cell_left = torch.nn.RNNCell(input_size=5, hidden_size=3)
birnn_cell_right = torch.nn.RNNCell(input_size=5, hidden_size=3)

birnn_cell_left.weight_hh.data = birnn.weight_hh_l0.data
birnn_cell_left.bias_hh.data = birnn.bias_hh_l0.data
birnn_cell_left.weight_ih.data = birnn.weight_ih_l0.data
birnn_cell_left.bias_ih.data = birnn.bias_ih_l0.data

birnn_cell_right.weight_hh.data = birnn.weight_hh_l0_reverse.data
birnn_cell_right.bias_hh.data = birnn.bias_hh_l0_reverse.data
birnn_cell_right.weight_ih.data = birnn.weight_ih_l0_reverse.data
birnn_cell_right.bias_ih.data = birnn.bias_ih_l0_reverse.data

x = torch.rand(5)
h_l = torch.rand(3)
h_r = torch.rand(3)

y_c, y_s = birnn(x.view(1, -1), torch.stack([h_l, h_r]))

cell_y_l = birnn_cell_left(x, h_l)
cell_y_r = birnn_cell_right(x, h_r)

look("RNN concatenated $y_L=$", y_c)
look("RNN stacked $y_R=$", y_s)

look("RNN cell left $y_L=$", cell_y_l, round=2)
look("RNN cell right $y_R=$", cell_y_r, round=2)

RNN concatenated $y_L=$ $\begin{bmatrix} -0.282 & -0.103 & -0.239 & 0.426 & 0.6 & 0.107\end{bmatrix}$

RNN stacked $y_R=$ $\begin{bmatrix} -0.282 & -0.103 & -0.239 \\ 0.426 & 0.6 & 0.107\end{bmatrix}$

RNN cell left $y_L=$ $\begin{bmatrix} -0.28 & -0.1 & -0.24\end{bmatrix}$

RNN cell right $y_R=$ $\begin{bmatrix} 0.43 & 0.6 & 0.11\end{bmatrix}$

# RNN and RNN cell (N > 1 layers)

In [16]:
torch.manual_seed(42)

INPUT_SIZE = 5
HIDDEN_SIZE = 3
NUM_LAYERS = 4

rnn = torch.nn.RNN(input_size=INPUT_SIZE, 
                   hidden_size=HIDDEN_SIZE, 
                   num_layers=NUM_LAYERS)
                   
rnn_cells = [torch.nn.RNNCell(input_size=INPUT_SIZE, hidden_size=HIDDEN_SIZE)] 
rnn_cells += [torch.nn.RNNCell(input_size=HIDDEN_SIZE, hidden_size=HIDDEN_SIZE) 
             for _ in range(NUM_LAYERS - 1)] 

for layer, rnn_cell in enumerate(rnn_cells):
    rnn_cell.weight_ih.data = getattr(rnn, f"weight_ih_l{layer}").data
    rnn_cell.bias_ih.data = getattr(rnn, f"bias_ih_l{layer}").data
    rnn_cell.weight_hh.data = getattr(rnn, f"weight_hh_l{layer}").data
    rnn_cell.bias_hh.data = getattr(rnn, f"bias_hh_l{layer}").data

x = torch.rand(INPUT_SIZE)
hiddens = [torch.rand(HIDDEN_SIZE)
           for _ in range(NUM_LAYERS)]

# layered approach
look("### RNN")
y, h_t = rnn(x.view(1,-1), torch.stack(hiddens))
look("out =", y)
look("hiddens=", h_t)

# manual cell by cell
look("### RNN Cell")
vec = x
for layer, (rnn_cell, h) in enumerate(zip(rnn_cells, hiddens)):
    vec = rnn_cell(vec, h)
    look(f"layer {layer} out = ", vec)


### RNN

out = $\begin{bmatrix} -0.623 & -0.482 & 0.011\end{bmatrix}$

hiddens= $\begin{bmatrix} -0.165 & -0.475 & -0.245 \\ 0.0621 & -0.502 & -0.249 \\ -0.188 & -0.687 & -0.152 \\ -0.623 & -0.482 & 0.011\end{bmatrix}$

### RNN Cell

layer 0 out =  $\begin{bmatrix} -0.165 & -0.475 & -0.245\end{bmatrix}$

layer 1 out =  $\begin{bmatrix} 0.0621 & -0.502 & -0.249\end{bmatrix}$

layer 2 out =  $\begin{bmatrix} -0.188 & -0.687 & -0.152\end{bmatrix}$

layer 3 out =  $\begin{bmatrix} -0.623 & -0.482 & 0.011\end{bmatrix}$

## BRNN and bidirectional RNN cells (N>1 layers)

In [17]:
torch.manual_seed(42)

NUM_LAYERS = 4

INPUT_SIZE = 5
HIDDEN_SIZE = 3

# creating networks
brnn = torch.nn.RNN(input_size=INPUT_SIZE, 
                    hidden_size=HIDDEN_SIZE, 
                    num_layers=NUM_LAYERS, 
                    bidirectional=True)

brnn_cells = [(torch.nn.RNNCell(input_size=INPUT_SIZE, hidden_size=HIDDEN_SIZE),
               torch.nn.RNNCell(input_size=INPUT_SIZE, hidden_size=HIDDEN_SIZE))]
brnn_cells += [(torch.nn.RNNCell(input_size=HIDDEN_SIZE * 2, hidden_size=HIDDEN_SIZE),
                torch.nn.RNNCell(input_size=HIDDEN_SIZE * 2, hidden_size=HIDDEN_SIZE))
               for _ in range(NUM_LAYERS - 1)]

def set_param(brnn_cell, brnn, name, layer, reverse):
    target_weight = getattr(brnn_cell, name).data
    if not reverse:
        source_weight = getattr(brnn, f"{name}_l{layer}").data
    else:
        source_weight = getattr(brnn, f"{name}_l{layer}_reverse").data
    assert target_weight.shape == source_weight.shape

    setattr(brnn_cell, name, torch.nn.parameter.Parameter(source_weight))
    assert torch.allclose(getattr(brnn_cell, name), source_weight)
    
# equalizing parameters
for layer, (brnn_cell_left, brnn_cell_right) in enumerate(brnn_cells):
    set_param(brnn_cell_left, brnn, "weight_ih", layer, reverse=False)
    set_param(brnn_cell_left, brnn, "bias_ih", layer, reverse=False)
    set_param(brnn_cell_left, brnn, "weight_hh", layer, reverse=False)
    set_param(brnn_cell_left, brnn, "bias_hh", layer, reverse=False)

    set_param(brnn_cell_right, brnn, "weight_ih", layer, reverse=True)
    set_param(brnn_cell_right, brnn, "bias_ih", layer, reverse=True)
    set_param(brnn_cell_right, brnn, "weight_hh", layer, reverse=True)
    set_param(brnn_cell_right, brnn, "bias_hh", layer, reverse=True)

# initial input and hidden vector
x = torch.rand(INPUT_SIZE)
hiddens = [(torch.rand(HIDDEN_SIZE), torch.rand(HIDDEN_SIZE))
           for _ in range(NUM_LAYERS)]

look("### BRNN")
brnn_hiddens = torch.stack([torch.stack(h) for h in hiddens]).flatten(0, 1)
y, h_t = brnn(x.view(1, -1), brnn_hiddens)

look("out =", y)
look("hiddens=", h_t)

# manual cell by cell
look("### BRNN Cell")
vec = x
for layer, ((brnn_cell_left, brnn_cell_right), (h_l, h_r)) in enumerate(zip(brnn_cells, hiddens)):
    h = torch.concat([h_l, h_r])
    vec_l = brnn_cell_left(vec, h_l)
    vec_r = brnn_cell_right(vec, h_r)
    vec = torch.cat([vec_l, vec_r])
    look(f"layer {layer} out = ", vec_l, vec_r)


### BRNN

out = $\begin{bmatrix} 0.0126 & -0.164 & -0.494 & -0.911 & 0.196 & -0.895\end{bmatrix}$

hiddens= $\begin{bmatrix} -0.0901 & -0.122 & 0.401 \\ 0.697 & 0.84 & 0.558 \\ -0.0241 & -0.427 & 0.102 \\ -0.482 & -0.805 & -0.858 \\ -0.223 & 0.976 & -0.0125 \\ 0.373 & 0.526 & 0.389 \\ 0.0126 & -0.164 & -0.494 \\ -0.911 & 0.196 & -0.895\end{bmatrix}$

### BRNN Cell

layer 0 out =  $\begin{bmatrix} -0.0901 & -0.122 & 0.401\end{bmatrix}$ $\begin{bmatrix} 0.697 & 0.84 & 0.558\end{bmatrix}$

layer 1 out =  $\begin{bmatrix} -0.0241 & -0.427 & 0.102\end{bmatrix}$ $\begin{bmatrix} -0.482 & -0.805 & -0.858\end{bmatrix}$

layer 2 out =  $\begin{bmatrix} -0.223 & 0.976 & -0.0125\end{bmatrix}$ $\begin{bmatrix} 0.373 & 0.526 & 0.389\end{bmatrix}$

layer 3 out =  $\begin{bmatrix} 0.0126 & -0.164 & -0.494\end{bmatrix}$ $\begin{bmatrix} -0.911 & 0.196 & -0.895\end{bmatrix}$

In [36]:
cell = torch.nn.RNNCell(5, 3)
x = torch.rand(5)
h = torch.rand(3)

for i in range(10):
    h = cell(x, h)
    l = torch.mean(h)
    print(i, l, x.grad)
    l.backward()



0 tensor(0.2082, grad_fn=<MeanBackward0>) None
1 tensor(0.0963, grad_fn=<MeanBackward0>) None


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.