<a href="https://colab.research.google.com/github/inderpreetsingh01/PyTorch/blob/main/D2l_6_RNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import collections
import torch
from torch import nn
from torch.nn import functional as F

In [38]:
class Vocab:
    """Vocabulary mapping between tokens and integer indices.

    The index order is the lexicographic sort order of the unique tokens
    (including ``'<unk>'`` and any reserved tokens), so ``'<unk>'`` is not
    guaranteed to sit at index 0; use the ``unk`` property to look it up.

    Args:
        tokens: A flat list of tokens, or a list of token lists (which is
            flattened). ``None`` or an empty list yields a vocabulary that
            holds only ``'<unk>'`` and the reserved tokens.
        min_freq: Tokens occurring fewer than ``min_freq`` times are dropped.
        reserved_tokens: Extra tokens that are always included.

    Attributes:
        token_freqs: ``(token, count)`` pairs sorted by descending count.
        idx_to_token: Sorted list of the unique tokens.
        token_to_idx: Inverse mapping of ``idx_to_token``.
    """

    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        # ``None`` sentinels instead of mutable ``[]`` defaults.
        tokens = [] if tokens is None else tokens
        reserved_tokens = [] if reserved_tokens is None else list(reserved_tokens)
        # Flatten a 2D list if needed
        if tokens and isinstance(tokens[0], list):
            tokens = [token for line in tokens for token in line]
        # Count token frequencies
        counter = collections.Counter(tokens)
        self.token_freqs = sorted(counter.items(), key=lambda x: x[1],
                                  reverse=True)
        # The list of unique tokens
        kept = [token for token, freq in self.token_freqs if freq >= min_freq]
        self.idx_to_token = sorted(set(['<unk>'] + reserved_tokens + kept))
        self.token_to_idx = {token: idx
                             for idx, token in enumerate(self.idx_to_token)}

    def __len__(self):
        """Return the number of distinct tokens, including ``'<unk>'``."""
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        """Map a token, or a (nested) list/tuple of tokens, to indices.

        Unknown tokens map to the index of ``'<unk>'``.
        """
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self.__getitem__(token) for token in tokens]

    def to_tokens(self, indices):
        """Map an index, or a sequence of indices, back to tokens.

        A list or tuple always yields a list of tokens, including the
        one-element and empty cases. Any other object exposing ``__len__``
        with more than one element is mapped element-wise; everything else
        is treated as a single index.
        """
        if isinstance(indices, (list, tuple)):
            return [self.idx_to_token[int(index)] for index in indices]
        if hasattr(indices, '__len__') and len(indices) > 1:
            return [self.idx_to_token[int(index)] for index in indices]
        return self.idx_to_token[indices]

    @property
    def unk(self):  # Index for the unknown token
        return self.token_to_idx['<unk>']

In [34]:
# Toy corpus: one unpunctuated English passage used to build the character vocabulary.
text = 'The vast ocean stretched endlessly under the bright blue sky shimmering in the sunlight Waves crashed softly along the shore creating a rhythm that felt timeless Seagulls soared high their cries echoing above while the wind carried a salty scent across the land On the horizon ships sailed gracefully their white sails billowing in the gentle breeze The beauty of nature reminded everyone there of the serenity and wonder that existed in the world if one paused to see and appreciate it'

In [35]:
# Split on whitespace; `text` is rebound from a single string to a list of words.
text = text.split()


In [56]:
# Character-level token stream: every character of every word, lower-cased,
# followed by one trailing ' ' token.
text1 = []
for word in text:
    text1.extend(ch.lower() for ch in word)
text1.append(' ')

In [48]:
text1

['t',
 'h',
 'e',
 'v',
 'a',
 's',
 't',
 'o',
 'c',
 'e',
 'a',
 'n',
 's',
 't',
 'r',
 'e',
 't',
 'c',
 'h',
 'e',
 'd',
 'e',
 'n',
 'd',
 'l',
 'e',
 's',
 's',
 'l',
 'y',
 'u',
 'n',
 'd',
 'e',
 'r',
 't',
 'h',
 'e',
 'b',
 'r',
 'i',
 'g',
 'h',
 't',
 'b',
 'l',
 'u',
 'e',
 's',
 'k',
 'y',
 's',
 'h',
 'i',
 'm',
 'm',
 'e',
 'r',
 'i',
 'n',
 'g',
 'i',
 'n',
 't',
 'h',
 'e',
 's',
 'u',
 'n',
 'l',
 'i',
 'g',
 'h',
 't',
 'w',
 'a',
 'v',
 'e',
 's',
 'c',
 'r',
 'a',
 's',
 'h',
 'e',
 'd',
 's',
 'o',
 'f',
 't',
 'l',
 'y',
 'a',
 'l',
 'o',
 'n',
 'g',
 't',
 'h',
 'e',
 's',
 'h',
 'o',
 'r',
 'e',
 'c',
 'r',
 'e',
 'a',
 't',
 'i',
 'n',
 'g',
 'a',
 'r',
 'h',
 'y',
 't',
 'h',
 'm',
 't',
 'h',
 'a',
 't',
 'f',
 'e',
 'l',
 't',
 't',
 'i',
 'm',
 'e',
 'l',
 'e',
 's',
 's',
 's',
 'e',
 'a',
 'g',
 'u',
 'l',
 'l',
 's',
 's',
 'o',
 'a',
 'r',
 'e',
 'd',
 'h',
 'i',
 'g',
 'h',
 't',
 'h',
 'e',
 'i',
 'r',
 'c',
 'r',
 'i',
 'e',
 's',
 'e',
 'c',
 'h'

In [57]:
# Build the character-level vocabulary from the token stream `text1`.
vocab = Vocab(text1)

In [58]:
vocab.token_freqs[:10]

[('e', 64),
 ('t', 39),
 ('h', 31),
 ('s', 30),
 ('i', 30),
 ('a', 27),
 ('n', 26),
 ('r', 26),
 ('l', 21),
 ('o', 19)]

In [59]:
len(vocab)

26

In [None]:
# RNN

In [53]:
class RNNScratch(nn.Module):
    """Character-level RNN language model implemented from scratch.

    Token indices are one-hot encoded, passed through a tanh recurrence and
    projected to vocabulary logits at every time step.

    Args:
        num_inputs: Width of each input vector. Inputs are one-hot vectors of
            width ``vocab_size``, so this must equal ``vocab_size``.
        num_hiddens: Size of the hidden state.
        vocab_size: Number of tokens; also the width of the output logits.
        sigma: Scale applied to the standard-normal weight initialisation.
    """

    def __init__(self, num_inputs, num_hiddens, vocab_size, sigma=0.01):
        super().__init__()
        self.vocab_size = vocab_size
        self.num_hiddens = num_hiddens
        self.W_xh = nn.Parameter(
            torch.randn(num_inputs, num_hiddens) * sigma)
        self.W_hh = nn.Parameter(
            torch.randn(num_hiddens, num_hiddens) * sigma)
        self.b_h = nn.Parameter(torch.zeros(num_hiddens))
        self.W_hq = nn.Parameter(
            torch.randn(num_hiddens, vocab_size) * sigma)
        self.b_q = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, inputs, state=None):
        """Run the RNN over a batch of token-index sequences.

        Args:
            inputs: Integer token indices of shape (batch_size, num_steps).
            state: Optional initial hidden state. Either a tensor of shape
                (batch_size, num_hiddens) -- i.e. exactly what this method
                returns -- or a one-element tuple/list wrapping such a tensor.

        Returns:
            A pair ``(logits, state)`` where ``logits`` has shape
            (batch_size, num_steps, vocab_size) and ``state`` is the final
            hidden state of shape (batch_size, num_hiddens).
        """
        inputs = self.one_hot(inputs)

        if state is None:
            # Initial state with shape: (batch_size, num_hiddens)
            state = torch.zeros((inputs.shape[1], self.num_hiddens),
                                device=inputs.device)
        elif isinstance(state, (tuple, list)):
            # Only unwrap real containers: unpacking a tensor would iterate
            # over its batch dimension and fail for batch_size != 1.
            state, = state

        hiddens = []
        # Iterating over time steps (num_steps)
        for X in inputs:  # Shape of inputs: (num_steps, batch_size, num_inputs)
            state = torch.tanh(torch.matmul(X, self.W_xh) +
                               torch.matmul(state, self.W_hh) + self.b_h)
            hiddens.append(state)

        outputs = [torch.matmul(H, self.W_hq) + self.b_q for H in hiddens]
        return torch.stack(outputs, 1), state

    def one_hot(self, X):
        """One-hot encode (batch_size, num_steps) indices, time-major.

        Output shape: (num_steps, batch_size, vocab_size), dtype float32.
        """
        return F.one_hot(X.T, self.vocab_size).type(torch.float32)

    def clip_gradients(self, grad_clip_val, model):
        """Rescale the gradients of ``model`` in place to a maximum global L2 norm.

        Parameters without a gradient (``grad is None``) are ignored; if no
        parameter has a gradient the call is a no-op.

        Args:
            grad_clip_val: Maximum allowed global gradient norm.
            model: Module whose parameter gradients are clipped.
        """
        params = [p for p in model.parameters()
                  if p.requires_grad and p.grad is not None]
        if not params:
            return
        norm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in params))
        if norm > grad_clip_val:
            scale = grad_clip_val / norm
            for param in params:
                param.grad[:] *= scale

    def predict(self, prefix, num_preds, vocab, device=None):
        """Greedily generate ``num_preds`` tokens following ``prefix``.

        Args:
            prefix: Non-empty sequence of tokens used to warm up the state.
            num_preds: Number of tokens to generate after the prefix.
            vocab: Object supporting ``vocab[token] -> index`` and exposing
                ``idx_to_token``.
            device: Device on which the input tensors are created.

        Returns:
            The prefix tokens followed by the generated tokens, joined into
            a single string.
        """
        state, outputs = None, [vocab[prefix[0]]]
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for i in range(len(prefix) + num_preds - 1):
                X = torch.tensor([[outputs[-1]]], device=device)
                Y, state = self.forward(X, state)
                if i < len(prefix) - 1:  # Warm-up period
                    outputs.append(vocab[prefix[i + 1]])
                else:  # Predict num_preds steps
                    outputs.append(int(Y.argmax(dim=2).item()))
        return ''.join(vocab.idx_to_token[i] for i in outputs)

In [62]:
# Smoke test of the untrained model.
batch_size, num_hiddens, num_steps = 2, 32, 10
# The model one-hot encodes token indices to width `vocab_size` and multiplies
# them by a (num_inputs, num_hiddens) matrix, so both sizes must match the
# vocabulary; derive them from `vocab` instead of hard-coding 26.
vocab_size = len(vocab)
num_inputs = vocab_size
rnn = RNNScratch(num_inputs, num_hiddens, vocab_size)
# X = torch.ones((num_steps, batch_size, num_inputs))
out = rnn.predict('who the fuck are you ', 10, vocab)
# X = torch.ones((batch_size, num_steps), dtype=torch.int64)
# outputs, state = rnn(X)

In [63]:
out

'who the fuck are you a<unk>zekonnnn'

In [52]:
vocab['a']

1

In [32]:
outputs.shape

torch.Size([2, 10, 30])

In [33]:
state.shape

torch.Size([2, 32])

In [3]:
# LSTM

class LSTMScratch(nn.Module):
    """Long short-term memory layer implemented from scratch.

    Args:
        num_inputs: Size of each input vector.
        num_hiddens: Size of the hidden state and of the memory cell.
        sigma: Scale applied to the standard-normal weight initialisation.
    """

    def __init__(self, num_inputs, num_hiddens, sigma=0.01):
        super().__init__()
        # forward() reads this to size the zero initial state.
        self.num_hiddens = num_hiddens

        init_weight = lambda *shape: nn.Parameter(torch.randn(*shape) * sigma)
        triple = lambda: (init_weight(num_inputs, num_hiddens),
                          init_weight(num_hiddens, num_hiddens),
                          nn.Parameter(torch.zeros(num_hiddens)))
        self.W_xi, self.W_hi, self.b_i = triple()  # Input gate
        self.W_xf, self.W_hf, self.b_f = triple()  # Forget gate
        self.W_xo, self.W_ho, self.b_o = triple()  # Output gate
        self.W_xc, self.W_hc, self.b_c = triple()  # Input node

    def forward(self, inputs, H_C=None):
        """Run the LSTM over a time-major input sequence.

        Args:
            inputs: Tensor of shape (num_steps, batch_size, num_inputs).
            H_C: Optional ``(H, C)`` pair of initial hidden state and memory
                cell, each of shape (batch_size, num_hiddens). Zeros are used
                when omitted.

        Returns:
            ``(outputs, (H, C))`` where ``outputs`` is a list with one
            (batch_size, num_hiddens) hidden state per time step and
            ``(H, C)`` is the final state.
        """
        if H_C is None:
            # Initial state with shape: (batch_size, num_hiddens)
            state_shape = (inputs.shape[1], self.num_hiddens)
            H = torch.zeros(state_shape, device=inputs.device)
            C = torch.zeros(state_shape, device=inputs.device)
        else:
            H, C = H_C
        outputs = []
        for X in inputs:
            # Gate names are spelled out so the forget gate does not shadow
            # the file-level `F` alias for torch.nn.functional.
            input_gate = torch.sigmoid(torch.matmul(X, self.W_xi) +
                                       torch.matmul(H, self.W_hi) + self.b_i)
            forget_gate = torch.sigmoid(torch.matmul(X, self.W_xf) +
                                        torch.matmul(H, self.W_hf) + self.b_f)
            output_gate = torch.sigmoid(torch.matmul(X, self.W_xo) +
                                        torch.matmul(H, self.W_ho) + self.b_o)
            C_tilde = torch.tanh(torch.matmul(X, self.W_xc) +
                                 torch.matmul(H, self.W_hc) + self.b_c)
            C = forget_gate * C + input_gate * C_tilde
            H = output_gate * torch.tanh(C)
            outputs.append(H)
        return outputs, (H, C)

In [None]:
# GRU

class GRUScratch(nn.Module):
    """Gated recurrent unit layer implemented from scratch.

    Args:
        num_inputs: Size of each input vector.
        num_hiddens: Size of the hidden state.
        sigma: Scale applied to the standard-normal weight initialisation.
    """

    def __init__(self, num_inputs, num_hiddens, sigma=0.01):
        super().__init__()
        # forward() reads this to size the zero initial state.
        self.num_hiddens = num_hiddens

        init_weight = lambda *shape: nn.Parameter(torch.randn(*shape) * sigma)
        triple = lambda: (init_weight(num_inputs, num_hiddens),
                          init_weight(num_hiddens, num_hiddens),
                          nn.Parameter(torch.zeros(num_hiddens)))
        self.W_xz, self.W_hz, self.b_z = triple()  # Update gate
        self.W_xr, self.W_hr, self.b_r = triple()  # Reset gate
        self.W_xh, self.W_hh, self.b_h = triple()  # Candidate hidden state

    def forward(self, inputs, H=None):
        """Run the GRU over a time-major input sequence.

        Args:
            inputs: Tensor of shape (num_steps, batch_size, num_inputs).
            H: Optional initial hidden state of shape
                (batch_size, num_hiddens). Zeros are used when omitted.

        Returns:
            ``(outputs, H)`` where ``outputs`` is a list with one
            (batch_size, num_hiddens) hidden state per time step and ``H``
            is the final hidden state.
        """
        if H is None:
            # Initial state with shape: (batch_size, num_hiddens)
            H = torch.zeros((inputs.shape[1], self.num_hiddens),
                            device=inputs.device)
        outputs = []
        for X in inputs:
            Z = torch.sigmoid(torch.matmul(X, self.W_xz) +
                              torch.matmul(H, self.W_hz) + self.b_z)
            R = torch.sigmoid(torch.matmul(X, self.W_xr) +
                              torch.matmul(H, self.W_hr) + self.b_r)
            # The reset gate scales the previous state before it enters the
            # candidate computation.
            H_tilde = torch.tanh(torch.matmul(X, self.W_xh) +
                                 torch.matmul(R * H, self.W_hh) + self.b_h)
            # The update gate interpolates between the old state and the candidate.
            H = Z * H + (1 - Z) * H_tilde
            outputs.append(H)
        return outputs, H

In [1]:
# Bidirectional recurrent neural network (BiRNN)

class _ScratchRNNLayer(nn.Module):
    """Plain tanh recurrence that returns the hidden state of every step.

    Private helper for BiRNNScratch. It takes time-major feature tensors of
    shape (num_steps, batch_size, num_inputs) and returns a list of hidden
    states, which is the contract BiRNNScratch.forward relies on.
    """

    def __init__(self, num_inputs, num_hiddens, sigma=0.01):
        super().__init__()
        self.num_hiddens = num_hiddens
        self.W_xh = nn.Parameter(torch.randn(num_inputs, num_hiddens) * sigma)
        self.W_hh = nn.Parameter(torch.randn(num_hiddens, num_hiddens) * sigma)
        self.b_h = nn.Parameter(torch.zeros(num_hiddens))

    def forward(self, inputs, state=None):
        if state is None:
            # Initial state with shape: (batch_size, num_hiddens)
            state = torch.zeros((inputs.shape[1], self.num_hiddens),
                                device=inputs.device)
        outputs = []
        for X in inputs:
            state = torch.tanh(torch.matmul(X, self.W_xh) +
                               torch.matmul(state, self.W_hh) + self.b_h)
            outputs.append(state)
        return outputs, state


class BiRNNScratch(nn.Module):
    """Bidirectional RNN layer implemented from scratch.

    One recurrent layer reads the sequence forwards and a second one reads
    it backwards; their per-step hidden states are concatenated.

    The sub-layers are private ``_ScratchRNNLayer`` instances rather than
    ``RNNScratch``: that class takes ``(num_inputs, num_hiddens, vocab_size,
    sigma)``, consumes token indices and returns stacked logits, so passing
    ``sigma`` as its third positional argument would bind it to
    ``vocab_size`` and its output could not be zipped step by step.

    Args:
        num_inputs: Size of each input vector.
        num_hiddens: Hidden size of each direction.
        sigma: Scale applied to the standard-normal weight initialisation.

    Attributes:
        f_rnn: Forward-direction recurrent layer.
        b_rnn: Backward-direction recurrent layer.
        num_hiddens: Output width, i.e. twice the per-direction hidden size.
    """

    def __init__(self, num_inputs, num_hiddens, sigma=0.01):
        super().__init__()
        self.f_rnn = _ScratchRNNLayer(num_inputs, num_hiddens, sigma)
        self.b_rnn = _ScratchRNNLayer(num_inputs, num_hiddens, sigma)
        # The output dimension is doubled by the concatenation in forward().
        self.num_hiddens = num_hiddens * 2

    def forward(self, inputs, Hs=None):
        """Run both directions over a time-major input sequence.

        Args:
            inputs: Tensor of shape (num_steps, batch_size, num_inputs).
            Hs: Optional ``(f_H, b_H)`` pair of initial states for the
                forward and backward layers; zeros are used when omitted.

        Returns:
            ``(outputs, (f_H, b_H))`` where ``outputs`` is a list with one
            (batch_size, 2 * per-direction hidden size) tensor per time step
            and ``(f_H, b_H)`` are the final states of the two directions.
        """
        f_H, b_H = Hs if Hs is not None else (None, None)
        f_outputs, f_H = self.f_rnn(inputs, f_H)
        # reversed() on a tensor flips it along dim 0, i.e. the time axis.
        b_outputs, b_H = self.b_rnn(reversed(inputs), b_H)
        # Re-reverse the backward outputs so both lists are aligned in time.
        outputs = [torch.cat((f, b), -1) for f, b in zip(
            f_outputs, reversed(b_outputs))]
        return outputs, (f_H, b_H)

NameError: name 'nn' is not defined