## Inner Workings of RNNs

Now that we have a basic understanding and a bird's-eye view of how RNNs (Recurrent Neural Networks) work, let's explore some basic computations that the RNN cells have to do to produce the hidden states and outputs.

In the first step, a hidden state will usually be seeded as a matrix of zeros, so that it can be fed into the RNN cell together with the first input in the sequence. In the simplest RNNs, the hidden state and the input data will be multiplied with weight matrices initialized via a scheme such as Xavier or Kaiming (you can read more on this topic [here](link-to-more-info)). The result of these multiplications will then be passed through an activation function (such as a tanh function) to introduce non-linearity.

$$
hidden_t = \tanh(weight_{hidden} \cdot hidden_{t-1} + weight_{input} \cdot input_t)
$$

Additionally, if we require an output at the end of each time step, we can pass the hidden state that we just produced through a linear layer or just multiply it by another weight matrix to obtain the desired shape of the result.

$$
output_t = weight_{output} \cdot hidden_t
$$

The hidden state that we just produced will then be fed back into the RNN cell together with the next input, and this process continues until we run out of input or the model is programmed to stop producing outputs.

As mentioned earlier, these computations presented above are just simple representations of how RNN cells do their calculations. For the more advanced RNN structures such as LSTMs, GRUs, etc., the computations are generally much more complicated.


In the Markdown (MDX) format, I've used backticks to represent the mathematical equations inline. Please replace `link-to-more-info` with the appropriate link to read more about Xavier and Kaiming initialization if you have one.

In [179]:
import numpy as np
class DummyRNN():

  def __init__(self, input_features, output_features):
    self.Whi = np.random.random((output_features, output_features))
    self.Wxi = np.random.random((output_features, input_features))
    self.Who = np.random.random((output_features, output_features))
    self.b1 = np.random.random((output_features,))
    self.b2 = np.random.random((output_features,))
    self.init_hidden = np.zeros(output_features)

  def forward(self, input_data):
    hidden_in = self.init_hidden
    memory = []  
    for input_t in input_data:
      hidden_activation_t =  hidden_in @ self.Whi+ self.Wxi @ input_t + self.b1
      print(hidden_activation_t.shape)
      output = np.tanh(self.Who @ hidden_activation_t + self.b2)
      memory.append(output)
      hidden_in = hidden_activation_t
    return np.array(memory)

time_steps = 10
input_features = 1
output_features = 10

model = DummyRNN(input_features, output_features)
input_data = np.random.random((time_steps, input_features))
print(f"Input Shape: {input_data.shape}")
result = model.forward(input_data)
print(f"Output Shape: {result.shape}")
# Output shape: (output_features, time_steps)
result

Input Shape: (10, 1)
(10,)
(10,)
(10,)
(10,)
(10,)
(10,)
(10,)
(10,)
(10,)
(10,)
Output Shape: (10, 10)


array([[0.99627527, 0.99943913, 0.99878452, 0.99980872, 0.99975101,
        0.99669862, 0.99972958, 0.99994519, 0.99989894, 0.99967911],
       [1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ],
       [1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ],
       [1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ],
       [1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ],
       [1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ],
       [1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ],
       [1.        , 1.        , 1.       

In [3]:
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split, \
TensorDataset
from torch.nn.utils import rnn as rnn_utils


In [252]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# Your dataset
# Create the sequence data
X = np.arange(1, 46).reshape(-1, 3,1)
Y = np.array([y.sum() for y in X])

# Convert data to PyTorch tensors
X_tensor = torch.tensor(X, dtype=torch.float32)
Y_tensor = torch.tensor(Y, dtype=torch.float32)
X_tensor

tensor([[[ 1.],
         [ 2.],
         [ 3.]],

        [[ 4.],
         [ 5.],
         [ 6.]],

        [[ 7.],
         [ 8.],
         [ 9.]],

        [[10.],
         [11.],
         [12.]],

        [[13.],
         [14.],
         [15.]],

        [[16.],
         [17.],
         [18.]],

        [[19.],
         [20.],
         [21.]],

        [[22.],
         [23.],
         [24.]],

        [[25.],
         [26.],
         [27.]],

        [[28.],
         [29.],
         [30.]],

        [[31.],
         [32.],
         [33.]],

        [[34.],
         [35.],
         [36.]],

        [[37.],
         [38.],
         [39.]],

        [[40.],
         [41.],
         [42.]],

        [[43.],
         [44.],
         [45.]]])

In [23]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from tqdm import tqdm 
import numpy as np 
import pandas as pd 
import random
import matplotlib.pyplot as plt 
import seaborn as sns
import numpy as np
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, seq_len=5, max_len=1000):
        # Initialize the base class (Dataset)
        super(CustomDataset, self).__init__()
        
        # Create a list of numbers from 0 to max_len
        self.datalist = np.arange(0, max_len)
        
        # Split the data into sequences of length seq_len and their corresponding targets
        self.data, self.targets = self.timeseries(self.datalist, seq_len)
        
    def __len__(self):
        # Return the total number of samples in the dataset
        return len(self.data)
    
    def timeseries(self, data, window):
        # Create sequences and corresponding targets from the data
        temp = []
        targ = data[window:]
        for i in range(len(data) - window):
            temp.append(data[i:i + window])

        return np.array(temp), targ
    
    def __getitem__(self, index):
        # Get a single sample from the dataset
        x = torch.tensor(self.data[index]).type(torch.Tensor)
        y = torch.tensor(self.targets[index]).type(torch.Tensor)
        return x, y

# Create an instance of the CustomDataset with specified parameters
dataset = CustomDataset(seq_len=5, max_len=1000)
for x,y in dataset:
    print(x,y)
    break   
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)


tensor([0., 1., 2., 3., 4.]) tensor(5.)


In [32]:
np.random.randn(2,3)

array([[-1.48960531,  1.17588816,  0.64131081],
       [-0.19888345,  0.79987576, -1.24990254]])

In [54]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from tqdm import tqdm 
import numpy as np 
import pandas as pd 
import random
import matplotlib.pyplot as plt 
import seaborn as sns
class CustomDataset(Dataset):
    def __init__(self,input_seq = 5,sample  = 100):
        super(CustomDataset).__init__()
        self.data = np.arange(1, (input_seq*sample)+1).reshape(-1, input_seq, 1)
        self.targets = self.sum(self.data)
        
    def __len__(self):
        return len(self.data)
    
    def sum(self, data):
        Y = np.array([y.sum() for y in data])
        return  Y 
    
    def __getitem__(self, index):
        x = torch.tensor(self.data[index]).type(torch.Tensor)
        y = torch.tensor(self.targets[index]).type(torch.Tensor)
        return x,y
    
dataset = CustomDataset(input_seq = 3,sample  = 100)
for x,y in dataset:
    print(x,y)
    break
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)


tensor([[1.],
        [2.],
        [3.]]) tensor(6.)


In [55]:
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        
        self.rnn = nn.LSTM(input_size,hidden_size,num_layers,batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # hidden states not defnined hence the value of h0,c0 == (0,0)
        out, (hn, cn) = self.rnn(x)
        # OUT: B,L,F
        
        # as the diagram suggest to take the last output in many to one 

        # all batch, last seq, all hidden values
        out = out[:, -1, :]
        # OUT: B,L[last],F ->final output of last time step
        out = self.fc(out)
        
        return out
    
    
model = RNN(input_size=1, hidden_size=256, num_layers=2)
t = torch.tensor([11,12,13]).type(torch.Tensor).view(1,-1,1)
print(f"Input: {t.shape}")
model(t)

Input: torch.Size([1, 3, 1])


tensor([[0.0176]], grad_fn=<AddmmBackward0>)

In [56]:
loss_function = nn.MSELoss()
learning_rate = 1e-3
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

for e in tqdm(range(1000)):
    i = 0
    avg_loss = []
    for x,y in dataloader:
        optimizer.zero_grad()
        # forward
        predictions = model(x)
        
        # loss
        loss = loss_function(predictions, y)
        
        # backward
        loss.backward()

        # optimization
        optimizer.step()
        avg_loss.append(loss.detach().numpy())

        i+=1
    if e%5==0:
        avg_loss = np.array(avg_loss)
        print(avg_loss.mean())

  0%|          | 1/1000 [00:00<14:06,  1.18it/s]

265941.56


  1%|          | 6/1000 [00:03<10:00,  1.66it/s]

234120.02


  1%|          | 11/1000 [00:06<09:31,  1.73it/s]

210579.56


  2%|▏         | 16/1000 [00:10<10:16,  1.60it/s]

189915.81


  2%|▏         | 21/1000 [00:13<09:58,  1.64it/s]

171707.31


  3%|▎         | 26/1000 [00:16<09:39,  1.68it/s]

155732.53


  3%|▎         | 31/1000 [00:19<09:40,  1.67it/s]

141674.22


  4%|▎         | 36/1000 [00:22<09:48,  1.64it/s]

129429.32


  4%|▍         | 41/1000 [00:25<09:44,  1.64it/s]

118883.38


  5%|▍         | 46/1000 [00:28<10:03,  1.58it/s]

109857.03


  5%|▌         | 51/1000 [00:31<10:51,  1.46it/s]

101994.92


  6%|▌         | 56/1000 [00:35<10:02,  1.57it/s]

95469.65


  6%|▌         | 61/1000 [00:38<09:36,  1.63it/s]

89889.6


  7%|▋         | 66/1000 [00:41<09:25,  1.65it/s]

85266.95


  7%|▋         | 71/1000 [00:44<10:07,  1.53it/s]

81477.75


  8%|▊         | 76/1000 [00:47<09:38,  1.60it/s]

78368.625


  8%|▊         | 81/1000 [00:50<09:22,  1.64it/s]

75843.29


  9%|▊         | 86/1000 [00:53<09:15,  1.65it/s]

73848.555


  9%|▉         | 91/1000 [00:57<09:11,  1.65it/s]

72264.95


 10%|▉         | 96/1000 [01:00<09:17,  1.62it/s]

71082.75


 10%|█         | 101/1000 [01:03<09:12,  1.63it/s]

70102.11


 11%|█         | 106/1000 [01:06<09:04,  1.64it/s]

69411.52


 11%|█         | 111/1000 [01:09<08:57,  1.65it/s]

68851.87


 12%|█▏        | 116/1000 [01:13<12:12,  1.21it/s]

68477.47


 12%|█▏        | 121/1000 [01:16<09:34,  1.53it/s]

68146.734


 13%|█▎        | 126/1000 [01:20<09:13,  1.58it/s]

69238.52


 13%|█▎        | 131/1000 [01:23<09:32,  1.52it/s]

68522.44


 14%|█▎        | 136/1000 [01:26<08:54,  1.62it/s]

68079.73


 14%|█▍        | 141/1000 [01:29<09:14,  1.55it/s]

67352.92


 15%|█▍        | 146/1000 [01:32<08:44,  1.63it/s]

66776.51


 15%|█▌        | 151/1000 [01:36<08:31,  1.66it/s]

65892.51


 16%|█▌        | 156/1000 [01:39<09:09,  1.54it/s]

65251.684


 16%|█▌        | 161/1000 [01:42<08:40,  1.61it/s]

65992.02


 17%|█▋        | 166/1000 [01:45<08:29,  1.64it/s]

65407.32


 17%|█▋        | 171/1000 [01:48<08:19,  1.66it/s]

64554.254


 18%|█▊        | 176/1000 [01:51<08:18,  1.65it/s]

65124.816


 18%|█▊        | 181/1000 [01:54<08:06,  1.68it/s]

64367.547


 19%|█▊        | 186/1000 [01:57<08:01,  1.69it/s]

64132.227


 19%|█▉        | 191/1000 [02:00<08:03,  1.67it/s]

65469.023


 20%|█▉        | 196/1000 [02:03<08:03,  1.66it/s]

65193.98


 20%|██        | 201/1000 [02:06<07:53,  1.69it/s]

65715.65


 21%|██        | 206/1000 [02:10<07:51,  1.68it/s]

63635.11


 21%|██        | 211/1000 [02:13<08:14,  1.60it/s]

64254.215


 22%|██▏       | 216/1000 [02:16<08:03,  1.62it/s]

63897.48


 22%|██▏       | 221/1000 [02:19<08:14,  1.58it/s]

62858.06


 23%|██▎       | 226/1000 [02:22<08:00,  1.61it/s]

62511.99


 23%|██▎       | 231/1000 [02:25<07:46,  1.65it/s]

66610.164


 24%|██▎       | 236/1000 [02:28<07:38,  1.67it/s]

64387.754


 24%|██▍       | 241/1000 [02:32<08:02,  1.57it/s]

63608.58


 25%|██▍       | 246/1000 [02:35<07:44,  1.62it/s]

64035.984


 25%|██▌       | 251/1000 [02:38<07:37,  1.64it/s]

63954.645


 26%|██▌       | 256/1000 [02:41<07:30,  1.65it/s]

64363.31


 26%|██▌       | 261/1000 [02:44<07:40,  1.61it/s]

63682.23


 27%|██▋       | 266/1000 [02:47<07:26,  1.64it/s]

65670.23


 27%|██▋       | 271/1000 [02:50<07:13,  1.68it/s]

62853.203


 28%|██▊       | 276/1000 [02:53<07:09,  1.68it/s]

62522.395


 28%|██▊       | 281/1000 [02:56<07:19,  1.64it/s]

64951.31


 29%|██▊       | 286/1000 [02:59<07:35,  1.57it/s]

63932.695


 29%|██▉       | 291/1000 [03:02<07:18,  1.62it/s]

60148.26


 30%|██▉       | 296/1000 [03:05<06:57,  1.69it/s]

62155.047


 30%|███       | 301/1000 [03:08<06:52,  1.69it/s]

61174.703


 31%|███       | 306/1000 [03:12<08:29,  1.36it/s]

66140.46


 31%|███       | 311/1000 [03:15<07:28,  1.54it/s]

64908.6


 32%|███▏      | 316/1000 [03:19<07:02,  1.62it/s]

62766.88


 32%|███▏      | 321/1000 [03:22<07:12,  1.57it/s]

62412.547


 33%|███▎      | 326/1000 [03:25<06:27,  1.74it/s]

61900.71


 33%|███▎      | 331/1000 [03:28<06:15,  1.78it/s]

64077.855


 34%|███▎      | 336/1000 [03:30<06:07,  1.81it/s]

65687.99


 34%|███▍      | 341/1000 [03:33<06:40,  1.65it/s]

65647.7


 35%|███▍      | 346/1000 [03:37<06:57,  1.57it/s]

62320.234


 35%|███▌      | 351/1000 [03:40<06:12,  1.74it/s]

65341.023


 36%|███▌      | 356/1000 [03:42<06:00,  1.79it/s]

64325.54


 36%|███▌      | 361/1000 [03:45<06:09,  1.73it/s]

64103.14


 37%|███▋      | 366/1000 [03:48<05:53,  1.79it/s]

62919.17


 37%|███▋      | 371/1000 [03:51<05:48,  1.80it/s]

73777.97


 38%|███▊      | 376/1000 [03:54<05:52,  1.77it/s]

64180.613


 38%|███▊      | 381/1000 [03:57<05:45,  1.79it/s]

60874.227


 39%|███▊      | 386/1000 [04:00<05:43,  1.79it/s]

63744.04


 39%|███▉      | 391/1000 [04:02<05:38,  1.80it/s]

63638.363


 40%|███▉      | 396/1000 [04:05<05:43,  1.76it/s]

63413.19


 40%|████      | 401/1000 [04:08<05:44,  1.74it/s]

64252.69


 41%|████      | 406/1000 [04:11<05:38,  1.75it/s]

65612.44


 41%|████      | 411/1000 [04:14<05:40,  1.73it/s]

64412.33


 42%|████▏     | 416/1000 [04:17<05:44,  1.69it/s]

63549.7


 42%|████▏     | 421/1000 [04:20<05:33,  1.74it/s]

64965.977


 43%|████▎     | 426/1000 [04:23<05:22,  1.78it/s]

62504.31


 43%|████▎     | 431/1000 [04:26<05:14,  1.81it/s]

62551.164


 44%|████▎     | 436/1000 [04:28<05:14,  1.79it/s]

63535.605


 44%|████▍     | 441/1000 [04:31<05:19,  1.75it/s]

65260.68


 45%|████▍     | 446/1000 [04:34<05:11,  1.78it/s]

64086.953


 45%|████▌     | 451/1000 [04:37<05:09,  1.77it/s]

64709.88


 46%|████▌     | 456/1000 [04:40<05:33,  1.63it/s]

64942.7


 46%|████▌     | 461/1000 [04:43<05:19,  1.69it/s]

65531.66


 47%|████▋     | 466/1000 [04:46<05:08,  1.73it/s]

65325.01


 47%|████▋     | 471/1000 [04:49<05:01,  1.76it/s]

64356.83


 48%|████▊     | 476/1000 [04:52<05:55,  1.47it/s]

63207.977


 48%|████▊     | 481/1000 [04:55<05:13,  1.66it/s]

64936.086


 49%|████▊     | 486/1000 [04:58<04:51,  1.76it/s]

64896.086


 49%|████▉     | 491/1000 [05:01<04:46,  1.78it/s]

60189.637


 50%|████▉     | 496/1000 [05:04<04:45,  1.77it/s]

64124.2


 50%|█████     | 501/1000 [05:07<04:41,  1.77it/s]

64043.83


 51%|█████     | 506/1000 [05:10<04:38,  1.77it/s]

63760.605


 51%|█████     | 511/1000 [05:13<04:40,  1.74it/s]

59581.24


 52%|█████▏    | 516/1000 [05:16<04:32,  1.78it/s]

67036.734


 52%|█████▏    | 521/1000 [05:18<04:34,  1.75it/s]

63528.28


 53%|█████▎    | 526/1000 [05:21<04:29,  1.76it/s]

62470.05


 53%|█████▎    | 531/1000 [05:24<04:27,  1.75it/s]

62131.08


 54%|█████▎    | 536/1000 [05:27<04:18,  1.80it/s]

63024.215


 54%|█████▍    | 541/1000 [05:30<05:21,  1.43it/s]

65086.58


 55%|█████▍    | 546/1000 [05:33<04:12,  1.80it/s]

61399.805


 55%|█████▌    | 551/1000 [05:36<04:01,  1.86it/s]

63682.145


 56%|█████▌    | 556/1000 [05:39<03:57,  1.87it/s]

64984.895


 56%|█████▌    | 561/1000 [05:42<04:40,  1.57it/s]

60773.76


 57%|█████▋    | 566/1000 [05:45<04:29,  1.61it/s]

62611.605


 57%|█████▋    | 571/1000 [05:48<04:37,  1.55it/s]

65985.77


 58%|█████▊    | 576/1000 [05:51<03:50,  1.84it/s]

65835.48


 58%|█████▊    | 581/1000 [05:54<04:09,  1.68it/s]

61576.87


 59%|█████▊    | 586/1000 [05:57<03:48,  1.81it/s]

61686.29


 59%|█████▉    | 591/1000 [06:00<03:38,  1.87it/s]

64641.01


 60%|█████▉    | 596/1000 [06:02<03:31,  1.91it/s]

64337.44


 60%|██████    | 601/1000 [06:06<04:14,  1.57it/s]

66290.49


 61%|██████    | 606/1000 [06:10<04:30,  1.46it/s]

62439.24


 61%|██████    | 611/1000 [06:13<04:24,  1.47it/s]

61952.15


 62%|██████▏   | 616/1000 [06:16<03:51,  1.66it/s]

61803.266


 62%|██████▏   | 621/1000 [06:19<04:13,  1.50it/s]

66592.42


 63%|██████▎   | 626/1000 [06:23<04:09,  1.50it/s]

64513.984


 63%|██████▎   | 631/1000 [06:26<04:19,  1.42it/s]

64122.54


 64%|██████▎   | 636/1000 [06:29<03:21,  1.80it/s]

63904.09


 64%|██████▍   | 641/1000 [06:32<03:22,  1.77it/s]

64468.36


 65%|██████▍   | 646/1000 [06:34<03:12,  1.84it/s]

65212.37


 65%|██████▌   | 651/1000 [06:37<03:21,  1.73it/s]

62922.1


 66%|██████▌   | 656/1000 [06:41<03:23,  1.69it/s]

63427.613


 66%|██████▌   | 661/1000 [06:44<03:19,  1.70it/s]

64652.875


 67%|██████▋   | 666/1000 [06:48<04:06,  1.35it/s]

64381.79


 67%|██████▋   | 671/1000 [06:51<03:30,  1.56it/s]

66711.74


 68%|██████▊   | 676/1000 [06:54<03:04,  1.76it/s]

60394.23


 68%|██████▊   | 681/1000 [06:57<02:50,  1.87it/s]

65804.664


 69%|██████▊   | 686/1000 [07:00<03:14,  1.62it/s]

66017.02


 69%|██████▉   | 691/1000 [07:03<02:50,  1.81it/s]

64411.36


 70%|██████▉   | 696/1000 [07:06<02:56,  1.72it/s]

65264.516


 70%|███████   | 701/1000 [07:09<02:54,  1.72it/s]

62679.05


 71%|███████   | 706/1000 [07:12<02:50,  1.73it/s]

62807.67


 71%|███████   | 711/1000 [07:15<02:48,  1.72it/s]

63609.38


 72%|███████▏  | 716/1000 [07:18<02:47,  1.70it/s]

64257.797


 72%|███████▏  | 721/1000 [07:21<02:43,  1.70it/s]

64786.625


 73%|███████▎  | 726/1000 [07:24<02:43,  1.67it/s]

65219.035


 73%|███████▎  | 731/1000 [07:27<02:35,  1.73it/s]

61076.55


 74%|███████▎  | 736/1000 [07:30<02:49,  1.56it/s]

66359.38


 74%|███████▍  | 741/1000 [07:33<02:34,  1.68it/s]

62432.39


 75%|███████▍  | 746/1000 [07:36<02:22,  1.78it/s]

63418.51


 75%|███████▌  | 751/1000 [07:39<02:17,  1.81it/s]

62930.523


 76%|███████▌  | 756/1000 [07:42<02:15,  1.80it/s]

63269.56


 76%|███████▌  | 761/1000 [07:45<02:22,  1.67it/s]

62323.8


 77%|███████▋  | 766/1000 [07:47<02:13,  1.76it/s]

64827.64


 77%|███████▋  | 771/1000 [07:50<02:10,  1.76it/s]

64434.19


 78%|███████▊  | 776/1000 [07:53<02:07,  1.76it/s]

65270.87


 78%|███████▊  | 781/1000 [07:56<02:02,  1.78it/s]

67501.88


 79%|███████▊  | 786/1000 [07:59<01:59,  1.78it/s]

63347.05


 79%|███████▉  | 791/1000 [08:02<01:57,  1.78it/s]

61100.02


 80%|███████▉  | 796/1000 [08:05<01:55,  1.76it/s]

59776.816


 80%|████████  | 801/1000 [08:08<01:53,  1.75it/s]

64716.875


 81%|████████  | 806/1000 [08:10<01:49,  1.77it/s]

64695.27


 81%|████████  | 811/1000 [08:13<01:48,  1.74it/s]

61472.67


 82%|████████▏ | 816/1000 [08:16<01:48,  1.70it/s]

63556.645


 82%|████████▏ | 821/1000 [08:19<01:41,  1.76it/s]

65344.35


 83%|████████▎ | 826/1000 [08:23<01:54,  1.52it/s]

63728.086


 83%|████████▎ | 831/1000 [08:26<01:38,  1.72it/s]

63867.9


 84%|████████▎ | 836/1000 [08:28<01:34,  1.73it/s]

64591.64


 84%|████████▍ | 841/1000 [08:31<01:31,  1.74it/s]

65557.22


 85%|████████▍ | 846/1000 [08:34<01:29,  1.72it/s]

66049.875


 85%|████████▌ | 851/1000 [08:37<01:24,  1.77it/s]

64536.766


 86%|████████▌ | 856/1000 [08:40<01:21,  1.77it/s]

65301.93


 86%|████████▌ | 861/1000 [08:43<01:22,  1.69it/s]

65675.73


 87%|████████▋ | 866/1000 [08:46<01:15,  1.77it/s]

65856.84


 87%|████████▋ | 871/1000 [08:49<01:12,  1.78it/s]

61237.094


 88%|████████▊ | 876/1000 [08:52<01:09,  1.78it/s]

66146.81


 88%|████████▊ | 881/1000 [08:55<01:15,  1.58it/s]

64320.613


 89%|████████▊ | 886/1000 [08:58<01:05,  1.75it/s]

64113.066


 89%|████████▉ | 891/1000 [09:01<01:01,  1.77it/s]

64915.324


 90%|████████▉ | 896/1000 [09:03<00:58,  1.79it/s]

66032.52


 90%|█████████ | 901/1000 [09:06<00:56,  1.76it/s]

66681.56


 91%|█████████ | 906/1000 [09:09<00:54,  1.74it/s]

62384.414


 91%|█████████ | 911/1000 [09:12<00:51,  1.73it/s]

65153.68


 92%|█████████▏| 916/1000 [09:15<00:49,  1.69it/s]

63505.63


 92%|█████████▏| 921/1000 [09:18<00:44,  1.77it/s]

62366.047


 93%|█████████▎| 926/1000 [09:21<00:41,  1.78it/s]

65855.19


 93%|█████████▎| 931/1000 [09:24<00:39,  1.76it/s]

63661.41


 94%|█████████▎| 936/1000 [09:27<00:36,  1.77it/s]

64053.766


 94%|█████████▍| 941/1000 [09:30<00:33,  1.79it/s]

62934.766


 95%|█████████▍| 946/1000 [09:33<00:31,  1.73it/s]

63927.3


 95%|█████████▌| 951/1000 [09:35<00:27,  1.76it/s]

66731.52


 96%|█████████▌| 956/1000 [09:38<00:25,  1.74it/s]

64168.703


 96%|█████████▌| 961/1000 [09:41<00:22,  1.73it/s]

64915.4


 97%|█████████▋| 966/1000 [09:44<00:20,  1.65it/s]

63335.94


 97%|█████████▋| 971/1000 [09:47<00:17,  1.69it/s]

63383.395


 98%|█████████▊| 976/1000 [09:50<00:13,  1.76it/s]

64025.906


 98%|█████████▊| 981/1000 [09:53<00:10,  1.77it/s]

62746.53


 99%|█████████▊| 986/1000 [09:56<00:07,  1.81it/s]

67263.3


 99%|█████████▉| 991/1000 [09:59<00:05,  1.79it/s]

63222.99


100%|█████████▉| 996/1000 [10:02<00:02,  1.76it/s]

64033.81


100%|██████████| 1000/1000 [10:04<00:00,  1.65it/s]


In [59]:
input_tensor = torch.tensor([1,11,12]).type(torch.Tensor).view(1,-1,1)
model(input_tensor)

tensor([[338.7949]], grad_fn=<AddmmBackward0>)

In [3]:
import torch
import torch.nn as nn

N = 100
D = 3
hidden_unit = 50

X = torch.randn(N,D)
W = torch.randn(D,hidden_unit)
WX = torch.matmul(X,W)
print(WX.shape)
linear_layer = nn.Linear(in_features=D,out_features=hidden_unit)
wx = linear_layer(X)
print(wx.shape)

H = torch.randn(N,hidden_unit)
U = torch.randn(hidden_unit,hidden_unit)
HU = torch.matmul(H,U)
print(HU.shape)
linear_layer = nn.Linear(in_features=hidden_unit,out_features=hidden_unit)
hu = linear_layer(H)
print(hu.shape)

print((WX+HU).shape)

torch.Size([100, 50])
torch.Size([100, 50])
torch.Size([100, 50])
torch.Size([100, 50])
torch.Size([100, 50])


In [4]:
WX+HU + torch.randn(1,hidden_unit)

tensor([[-7.7264, -0.2913, -1.9945,  ...,  6.0025, -2.9757,  9.1364],
        [-2.2795, -4.2909,  6.3641,  ...,  7.9033, -3.2580, 16.0137],
        [12.5918,  0.8077,  2.4117,  ...,  0.7654, -8.3708,  1.3911],
        ...,
        [ 4.6413, -4.3244,  0.4369,  ...,  5.2540, -7.5529, -3.1038],
        [10.0056,  0.2125, -5.2950,  ..., -2.8770,  3.0147, -8.1516],
        [ 9.3609,  2.3222,  7.0069,  ...,  1.0851,  1.7609,  0.2755]])

In [519]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class RNNCell(nn.Module):
    def __init__(self, in_features, hidden_unit, output_unit):
        super(RNNCell, self).__init__()
        self.hidden_unit = hidden_unit

        # Define the parameters as PyTorch Parameters
        self.Wax = nn.Parameter(torch.randn(in_features, hidden_unit))
        self.Waa = nn.Parameter(torch.randn(hidden_unit, hidden_unit))
        self.Wya = nn.Parameter(torch.randn(hidden_unit, output_unit))
        self.ba = nn.Parameter(torch.zeros(1, hidden_unit))  
        self.by = nn.Parameter(torch.zeros(1, output_unit))  

    def forward(self, X, H):
        # Compute next activation state
        a_next = torch.tanh(torch.matmul(X, self.Wax) + torch.matmul(H, self.Waa) + self.ba)
        # print(a_next.shape)
        # Compute output of the current cell
        yt_pred = F.softmax(torch.matmul(a_next, self.Wya) + self.by, dim=1)  # Corrected dim=0 to dim=1

        return a_next, yt_pred


# Create an instance of the RNNCell module
N = 100
D = 1
hidden_unit = 200
rnn_cell = RNNCell(in_features=D , hidden_unit=hidden_unit, output_unit=2)

# Define the input data and previous hidden state as PyTorch tensors
Xt = torch.randn(N, D)
a_prev = torch.randn(N, hidden_unit)

# Perform forward pass through the RNNCell
a_next, yt_pred = rnn_cell(Xt, a_prev)

print("Next hidden state (a_next):")
print(a_next.shape)
print("Predicted output (yt_pred):")
print(yt_pred.shape)
print(Xt.shape)

Next hidden state (a_next):
torch.Size([100, 200])
Predicted output (yt_pred):
torch.Size([100, 2])
torch.Size([100, 1])


In [638]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
torch.random.manual_seed(123)
class CustomDataset(Dataset):
    def __init__(self, input_seq=5, sample=100):
        super(CustomDataset, self).__init__()
        self.data = np.arange(1, (input_seq * sample) + 1).reshape(-1, input_seq, 1)
        self.targets = self.sum(self.data)

    def __len__(self):
        return len(self.data)

    def sum(self, data):
        Y = np.array([y.sum() for y in data])
        return Y

    def __getitem__(self, index):
        x = torch.tensor(self.data[index]).type(torch.Tensor)
        y = torch.tensor(self.targets[index]).type(torch.Tensor)
        return x, y

# Create the CustomDataset and DataLoader
dataset = CustomDataset(input_seq=3, sample=10000)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# Define the RNNCell module
class RNNCell(nn.Module):
    def __init__(self, in_features, hidden_unit):
        super(RNNCell, self).__init__()
        self.hidden_unit = hidden_unit

        # Define the parameters as PyTorch Parameters
        self.Wax = nn.Parameter(torch.randn(in_features, hidden_unit))
        self.Waa = nn.Parameter(torch.randn(hidden_unit, hidden_unit))
        self.ba = nn.Parameter(torch.zeros(1, hidden_unit))
        self.fc = nn.Linear(hidden_unit, 1)  # Fully connected layer for Many-to-One

    def forward(self, X):
        # Get the batch size and sequence length from the input tensor
        batch_size, seq_length, _ = X.size()

        # Initialize the hidden state as zeros
        a_prev = torch.zeros(batch_size, self.hidden_unit)

        # List to store the output at each time step
        outputs = []

        # Loop through the sequence
        for t in range(seq_length):
            # Extract the current time step's input
            Xt = X[:, t, :]

            # Compute the next hidden state and predicted output
            a_next = torch.tanh(torch.matmul(Xt, self.Wax) + torch.matmul(a_prev, self.Waa) + self.ba)
            # Append the output to the list
            outputs.append(a_next)
            a_prev =  a_next

        last_hidden_state = outputs[-1]  # Take the final hidden state
        # print(last_hidden_state.shape) # 1,HU
        # Pass the final hidden state through the fully connected layer
        out = self.fc(last_hidden_state)
        return out 

model = RNNCell(in_features=1, hidden_unit=300)
input_tensor = torch.tensor([1,2,3,4,5]).type(torch.Tensor).view(1,-1,1)
model(input_tensor)


tensor([[-0.0731]], grad_fn=<AddmmBackward0>)

In [639]:
loss_function = nn.MSELoss()
learning_rate = 1e-3
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

for e in tqdm(range(100)):
    for x,y in dataloader:
        optimizer.zero_grad()


        predictions = model(x)

        loss = loss_function(predictions.view(-1), y) # [N,1] -> [N]
        
        # backward
        loss.backward()

        # optimization
        optimizer.step()


100%|██████████| 100/100 [14:46<00:00,  8.86s/it]


In [640]:
input_tensor = torch.tensor([1,2,3,4,5]).type(torch.Tensor).view(1,-1,1)
model(input_tensor)

tensor([[257.0251]], grad_fn=<AddmmBackward0>)