### RNN

```
class torch.nn.RNN(*args, **kwargs)
```
Applies a multi-layer Elman RNN with $\tanh$ or $\text{ReLU}$ non-linearity to an input sequence.

For each element in the input sequence, each layer computes the following function:

 $$ h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{t-1} + b_{hh}) $$
 
 where $h_t$ is the hidden state at time $t$, $x_t$ is the input at time $t$, and $h_{t-1}$ is the hidden state of the same layer at time $t-1$ (or the initial hidden state at time $0$). For the first layer $x_t$ is the input sequence itself; for every deeper layer $x_t$ is the hidden state $h_t$ produced by the layer below it. If `nonlinearity` is `'relu'`, then $\text{ReLU}$ is used instead of $\tanh$.

In [3]:
import torch
from torch import nn, optim

In [4]:
# Seed PyTorch's global RNG so the random weight initialisation below, and the
# torch.randn draws in later cells, are reproducible on Restart & Run All.
torch.manual_seed(0)

# 2-layer Elman RNN: 10 input features -> 20 hidden units per layer.
# Everything else stays at its default (tanh non-linearity, biases, batch_first=False).
rnn = nn.RNN(input_size=10, hidden_size=20, num_layers=2)

In [8]:
# Display the module; its repr lists input_size, hidden_size and num_layers
rnn

RNN(10, 20, num_layers=2)

In [5]:
# Random input sequence in the (seq_len, batch, input_size) = (5, 3, 10) layout
# that nn.RNN expects when batch_first is left at False.
# NOTE: this name shadows Python's built-in input(); it is kept because later
# cells feed `input` to the RNN under this name.
input = torch.randn((5, 3, 10))

In [9]:
# Inspect the randomly generated input: 5 time steps, each a (3, 10) slice
input

tensor([[[ 1.3091, -0.1126,  0.5788,  0.1221,  0.7953, -1.3723, -0.1174,
           0.6686, -0.5618, -0.5063],
         [ 1.0827, -0.1737,  1.1196, -0.6322, -0.1385,  1.6695,  1.2338,
          -1.0422, -0.2826, -1.1955],
         [ 0.1112, -0.1399,  0.2403, -1.7350, -0.1036, -0.8991,  0.6302,
          -1.1892,  2.0100,  0.1442]],

        [[ 0.0570, -1.5936,  0.1694, -1.5635, -1.0278, -1.7197,  0.4143,
          -0.8215,  0.9182,  0.6281],
         [ 0.1824, -0.2326,  0.5547, -0.8105,  0.5354, -0.1058,  0.3179,
          -1.1965, -0.0733, -0.8007],
         [-0.1541,  0.3355, -1.0153,  0.4665, -0.6644, -0.2509, -0.9390,
          -0.8885, -0.2003,  1.1564]],

        [[-1.0642,  0.3859,  2.0870, -0.6164, -0.4606,  1.2554, -1.2888,
           0.1322, -0.4710,  0.1168],
         [-0.9570,  1.3968,  0.9398, -0.7282,  1.1125,  0.5747, -0.0849,
          -0.8303, -0.4675, -1.3660],
         [-0.7828, -0.4571, -0.9021,  0.2786, -1.1754, -0.3934, -0.5557,
           1.7946,  0.2576, -1.5560

In [6]:
# Initial hidden state in the (num_layers, batch, hidden_size) = (2, 3, 20) layout:
# it must agree with the 2-layer, 20-unit RNN and with the batch of 3 sequences.
h0 = torch.randn((2, 3, 20))

In [10]:
# Inspect the initial hidden state: one (3, 20) slice per RNN layer
h0

tensor([[[ 1.7005,  2.0622,  0.2596,  1.1244, -0.0030, -0.8713,  1.3854,
          -0.3751, -0.4079, -0.8769, -0.5297, -1.0469,  0.7056, -0.5951,
          -1.5194, -0.6982,  1.7719,  0.8174, -0.0060, -1.8900],
         [-0.7312,  0.7778, -0.5811, -0.4257, -0.8034, -0.8660,  1.5157,
          -0.7403,  1.1677,  0.5809,  0.4522, -0.4620, -0.9990,  0.0904,
           0.8183, -0.2081,  0.7247, -0.7790,  0.9893, -0.8280],
         [-1.0413, -0.3022, -1.2148, -1.7313,  0.2203, -0.5146,  0.6175,
          -0.6839, -0.0307,  1.6719,  0.4437, -0.2275,  1.3126, -0.7540,
          -0.0356, -0.8927,  2.0019,  0.2936,  0.8306, -0.5191]],

        [[ 0.8618,  2.8280, -1.4003,  0.4045, -0.8874,  1.1740,  2.6002,
           0.8753,  1.3151, -1.4312,  2.0962,  0.8373,  0.3204, -0.2359,
           0.7882, -1.2130,  1.5117, -2.1493, -2.0501,  1.1786],
         [ 0.6501, -1.2137, -0.4433,  0.6859, -0.3681, -1.8112,  1.0084,
           0.2967,  1.4405, -0.7480, -0.3950, -1.4791,  0.0060,  0.4411,
        

In [11]:
# Run the whole sequence through both layers in one call.
#   output: top-layer hidden state at every time step -> (seq_len, batch, hidden_size)
#   hn:     final hidden state of every layer         -> (num_layers, batch, hidden_size)
output, hn = rnn(input, h0)

# Inline sanity checks: returned shapes match the inputs, and the last time step
# of `output` is exactly the top layer's final hidden state.
assert output.shape == (input.shape[0], input.shape[1], rnn.hidden_size)
assert hn.shape == h0.shape
assert torch.allclose(output[-1], hn[-1])

In [13]:
# Final hidden state after the last time step: one (3, 20) slice per layer
hn

tensor([[[ 1.0671e-01,  6.3203e-01, -1.4205e-01, -4.5679e-01,  4.9430e-01,
           4.1664e-04, -2.2096e-01,  1.6742e-01, -1.5266e-01, -2.8774e-01,
           5.0801e-01, -3.2709e-01,  1.5137e-01, -1.9688e-01,  6.6191e-01,
           3.7101e-01, -1.6490e-01, -1.5713e-01,  3.3369e-02,  5.4381e-03],
         [-4.0392e-01,  7.4687e-01, -3.6658e-01, -6.0945e-01,  1.7619e-01,
          -5.1696e-01, -1.3466e-01,  4.9498e-01, -1.8346e-01, -2.3167e-01,
          -1.2358e-02,  1.6410e-01, -2.3573e-02, -1.1894e-01,  4.8816e-01,
           7.9178e-01,  3.0626e-01,  1.4387e-02,  8.6952e-02, -6.5329e-02],
         [-4.9836e-01,  4.7920e-02, -4.2825e-01, -2.1469e-01,  2.6080e-01,
          -1.2441e-01, -9.6232e-03,  4.6136e-01, -5.9921e-01, -3.0484e-01,
          -2.2815e-01, -3.4020e-02, -9.3360e-01,  4.2591e-01,  2.8626e-01,
           5.1743e-01,  1.1075e-01,  8.6932e-01, -1.5866e-02,  2.9688e-01]],

        [[ 4.9944e-01, -3.9954e-02, -1.2648e-01,  9.3358e-03,  3.0234e-02,
          -2.1879e-0

In [14]:
# Top-layer hidden state at every time step: one (3, 20) slice per step
output

tensor([[[-0.5825,  0.7571, -0.4633, -0.8861, -0.4963,  0.6599,  0.8740,
           0.5617, -0.6960, -0.0864,  0.3847,  0.0574,  0.0355,  0.5923,
           0.9324, -0.6560, -0.9243, -0.6949, -0.6292,  0.4132],
         [-0.4413,  0.1195, -0.6355, -0.0180, -0.6298,  0.0630,  0.7187,
           0.1938, -0.7252,  0.5678, -0.8940, -0.2355,  0.5314,  0.4671,
          -0.3629, -0.9565, -0.8289,  0.6665, -0.4162,  0.6056],
         [ 0.4981, -0.0386,  0.4446,  0.6393,  0.2247, -0.3710,  0.7947,
           0.8055,  0.2164,  0.8039,  0.1271, -0.2050, -0.6548, -0.8605,
          -0.4928,  0.4536,  0.6726,  0.0681,  0.2140, -0.4608]],

        [[ 0.6615,  0.5527, -0.5354, -0.6154,  0.1578, -0.1347,  0.1343,
           0.5307, -0.1963, -0.1883,  0.6901, -0.1787, -0.0157, -0.6794,
           0.3095,  0.4369,  0.5052,  0.1298, -0.5970, -0.0251],
         [ 0.7091,  0.5332, -0.1371, -0.0753,  0.3985, -0.5981, -0.0085,
           0.7076,  0.0055, -0.3812,  0.2159, -0.1415, -0.5214, -0.2694,
        