# Recurrent layer

This page explains the concept of a recurrent layer.

The key idea is a mechanism through which each input influences how subsequent inputs are processed and what they produce.

![](recurent_layer_files/recurent_schema.svg)

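An RNN (recurrent neural network) layer is essentially a single layer applied repeatedly, with the same weights at every step of the sequence. At step $t$ it combines the input $x_t$ with $h_{t-1}$, the hidden state vector produced at the previous step.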

Formally, the hidden state at step $t$ is computed as:

$$h_t = f(x_t W^T_1 + b_1 + h_{t-1} W^T_2 + b_2)$$  

Where:  
- $x_t$: input at the $t$-th step.  
- $h_t$: hidden state vector at the $t$-th step; the initial state $h_0$ is usually a zero vector.  
- $W_1$: weights associated with the input.  
- $W_2$: weights associated with the state.  
- $b_1$: bias associated with the input.  
- $b_2$: bias associated with the state.  
- $f$: activation function, typically a hyperbolic tangent.  
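To make the recurrence concrete, here is a minimal pure-Python sketch of the formula with one-dimensional inputs and states, so every product is a plain multiplication. The parameter values and the two-element input sequence are arbitrary choices for illustration, and $h_0$ is taken to be zero.

```python
import math

# Hypothetical scalar parameters (input size 1, state size 1)
w_1, b_1 = 0.5, 0.1   # input weight and bias
w_2, b_2 = 0.8, 0.0   # state weight and bias

xs = [1.0, 0.5]       # a two-step input sequence
h = 0.0               # h_0 starts at zero

hs = []
for x in xs:
    # h_t = tanh(x_t * w_1 + b_1 + h_{t-1} * w_2 + b_2)
    h = math.tanh(x * w_1 + b_1 + h * w_2 + b_2)
    hs.append(h)

print(hs)
```

The second state depends on the first input only through $h_1 = \tanh(0.6) \approx 0.537$, which enters $h_2 = \tanh(0.35 + 0.8 \cdot h_1) \approx 0.65$.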

## Implementation in Python

In this section, we implement the computations of a recurrent layer step by step and compare the result against `torch.nn.RNN` as the reference.

```python
import torch
```

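```python
samples_size = 10    # batch size
element_size = 5     # size of each input vector x_t
sequence_size = 15   # number of time steps
state_size = 3       # size of the hidden state h_t
activation = torch.nn.Tanh()

# Input of shape (batch, sequence, features), i.e. the batch_first layout
input_data = torch.rand(samples_size, sequence_size, element_size)
```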

```python
# Note: this layer is not used by the cells below
test_linear = torch.nn.Linear(
    in_features=element_size,
    out_features=state_size
)
```
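Each term of the form $x W^T + b$ in the formula is exactly what `torch.nn.Linear` computes, with its `weight` stored as `(out_features, in_features)`. A short self-contained check of that correspondence, with sizes chosen to match `element_size` and `state_size` above:

```python
import torch

torch.manual_seed(0)  # fixed seed so the check is reproducible

linear = torch.nn.Linear(in_features=5, out_features=3)
x = torch.rand(10, 5)  # (batch, in_features)

# Manual computation: x @ W^T + b, the same form as x_t W_1^T + b_1
manual = x @ linear.weight.T + linear.bias

torch.testing.assert_close(linear(x), manual)
print(linear.weight.shape)  # torch.Size([3, 5])
```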

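```python
# Parameters in the same layout as torch.nn.RNN:
# W_1 is (state_size, element_size), W_2 is (state_size, state_size)
W_1 = torch.rand(state_size, element_size)
b_1 = torch.rand(state_size)
W_2 = torch.rand(state_size, state_size)
b_2 = torch.rand(state_size)

# Initial hidden state h_0: zeros, one row per sample
state = torch.zeros(samples_size, state_size)
```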
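Because $b_1$ and $b_2$ are always added together, and the two matrix products act on different vectors, the update can equivalently be written with a single weight matrix applied to the concatenation $[x_t, h_{t-1}]$ plus a single bias $b_1 + b_2$. The sketch below checks that equivalence numerically on fresh random tensors; the names `W`, `b` and `xh` are introduced only for illustration.

```python
import torch

torch.manual_seed(0)
batch, element_size, state_size = 4, 5, 3

x = torch.rand(batch, element_size)        # x_t
h = torch.rand(batch, state_size)          # h_{t-1}
W_1 = torch.rand(state_size, element_size)
W_2 = torch.rand(state_size, state_size)
b_1 = torch.rand(state_size)
b_2 = torch.rand(state_size)

# Original form: two products and two biases
two_terms = x @ W_1.T + b_1 + h @ W_2.T + b_2

# Merged form: one matrix of shape (state_size, element_size + state_size)
W = torch.cat([W_1, W_2], dim=1)
b = b_1 + b_2
xh = torch.cat([x, h], dim=1)              # (batch, element_size + state_size)
merged = xh @ W.T + b

torch.testing.assert_close(torch.tanh(two_terms), torch.tanh(merged))
```

PyTorch nonetheless keeps two separate bias vectors, `bias_ih_l0` and `bias_hh_l0`, which is why the comparison with `torch.nn.RNN` copies both.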

```python
# Pre-activation for the first step (t = 0), before applying tanh
(input_data[:, 0, :] @ W_1.T) + b_1 + (state @ W_2.T) + b_2
```

```
tensor([[4.1659, 3.3416, 2.9369],
        [3.4290, 2.6908, 2.7988],
        [2.7855, 2.2476, 2.2464],
        [3.3770, 2.6954, 2.2828],
        [3.6830, 2.9414, 2.7285],
        [3.4822, 2.7814, 2.6768],
        [4.2522, 3.3817, 3.1177],
        [3.1929, 2.5440, 2.3744],
        [3.6737, 2.9679, 3.2228],
        [2.9061, 2.3930, 2.4780]])
```
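The pre-activations printed above all lie roughly between 2.2 and 4.3: `torch.rand` draws every weight, bias and input from $[0, 1)$, so every term is positive. (No random seed is set, so the exact numbers change from run to run.) In that range $\tanh$ is already close to its upper limit of 1, as a quick check on the smallest and largest printed values shows:

```python
import torch

# Smallest and largest pre-activations from the output above
pre = torch.tensor([2.2464, 4.2522])

post = torch.tanh(pre)
print(post)
```

For comparison, `torch.nn.RNN` by default initializes its weights and biases from $\mathcal{U}(-\sqrt{k}, \sqrt{k})$ with $k = 1/\text{hidden\_size}$, a symmetric range that typically keeps pre-activations away from this saturated region.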

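```python
# Unroll the recurrence over the sequence, starting from h_0
states = [state]

for i in range(input_data.shape[1]):
    # h_t = tanh(x_t W_1^T + b_1 + h_{t-1} W_2^T + b_2)
    res = activation(
        (input_data[:, i, :] @ W_1.T) + b_1
        + (states[-1] @ W_2.T) + b_2
    )
    states.append(res)

# Drop h_0, stack to (sequence, batch, state), then permute to (batch, sequence, state)
my_ans = torch.stack(states[1:]).permute((1, 0, 2))
```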
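The same loop can be wrapped in a reusable function. The sketch below is a hypothetical helper, `rnn_forward`, not part of PyTorch: it takes the four parameters and an optional initial state, and stacks along `dim=1` so that no `permute` is needed. It also returns the final state, mirroring the `(output, h_n)` pair returned by `torch.nn.RNN` (where `h_n` carries an extra leading layer dimension).

```python
import torch

def rnn_forward(x, W_1, b_1, W_2, b_2, h_0=None):
    """Hypothetical helper: run a tanh RNN over x of shape (batch, seq, features).

    Returns (outputs, h_last) with shapes (batch, seq, state) and (batch, state).
    """
    batch, seq_len, _ = x.shape
    if h_0 is None:
        # Zero initial state, like `state` in the notebook
        h_0 = torch.zeros(batch, W_2.shape[0], dtype=x.dtype)
    h = h_0
    outputs = []
    for t in range(seq_len):
        h = torch.tanh(x[:, t, :] @ W_1.T + b_1 + h @ W_2.T + b_2)
        outputs.append(h)
    # Stack along the time axis directly: (batch, seq, state)
    return torch.stack(outputs, dim=1), h


torch.manual_seed(0)
x = torch.rand(10, 15, 5)
W_1, b_1 = torch.rand(3, 5), torch.rand(3)
W_2, b_2 = torch.rand(3, 3), torch.rand(3)

outputs, h_last = rnn_forward(x, W_1, b_1, W_2, b_2)
print(outputs.shape, h_last.shape)  # torch.Size([10, 15, 3]) torch.Size([10, 3])
```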

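```python
# Reference implementation; batch_first=True matches the (batch, sequence, features) layout
rnn = torch.nn.RNN(element_size, state_size, batch_first=True)

# Copy our parameters into the PyTorch layer (no gradient tracking needed)
with torch.no_grad():
    rnn.weight_ih_l0.copy_(W_1)
    rnn.bias_ih_l0.copy_(b_1)
    rnn.weight_hh_l0.copy_(W_2)
    rnn.bias_hh_l0.copy_(b_2)

    # Returns (output, h_n); h_0 defaults to zeros
    torch_ans = rnn(input_data)
```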
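`torch.nn.RNN` stores its parameters in the same layout as `W_1`, `b_1`, `W_2` and `b_2`, which is why a plain `copy_` works without transposing, and its default activation is $\tanh$. A standalone check with the sizes used here:

```python
import torch

rnn = torch.nn.RNN(input_size=5, hidden_size=3, batch_first=True)

# Parameter name -> shape for the single (layer 0) recurrent layer
shapes = {name: tuple(p.shape) for name, p in rnn.named_parameters()}
print(shapes)

# Default activation of torch.nn.RNN
print(rnn.nonlinearity)
```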

```python
# torch_ans[0] holds the hidden state at every step, shape (batch, sequence, state)
torch.testing.assert_close(torch_ans[0], my_ans)
```
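`torch.nn.RNN` also returns `h_n`, the hidden state after the last step, with shape `(num_layers, batch, hidden_size)` even when `batch_first=True`. For a single layer it should equal the last time step of the output, which can be checked in isolation:

```python
import torch

torch.manual_seed(0)
rnn = torch.nn.RNN(5, 3, batch_first=True)
x = torch.rand(10, 15, 5)  # (batch, sequence, features)

with torch.no_grad():
    output, h_n = rnn(x)

print(output.shape, h_n.shape)  # torch.Size([10, 15, 3]) torch.Size([1, 10, 3])

# For one layer, h_n[0] is the state after the final step
torch.testing.assert_close(h_n[0], output[:, -1, :])
```

With the variables in this notebook, the analogous check compares `torch_ans[1][0]` with `states[-1]`.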