In [5]:
from collections import Counter
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

## RNN Basic
### Token
For a language model, we need to convert our words (or characters) into numbers, a step called tokenization. Here each distinct token is assigned an integer index according to its frequency, so the most frequent tokens receive the smallest indices. We order tokens by frequency because this allows the computer to cache the more frequently used words and improves training efficiency:

In [2]:
import string
import unicodedata


# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
def unicode_to_ascii(s):
    """Reduce a Unicode string to plain ASCII letters.

    The input is NFD-normalised so accented characters split into a base
    letter plus combining marks. Combining marks (category ``Mn``) and every
    character that is not an ASCII letter are then dropped.
    """
    kept = []
    for ch in unicodedata.normalize('NFD', s):
        if unicodedata.category(ch) == 'Mn':
            # Combining accent left over from the decomposition.
            continue
        if ch not in string.ascii_letters:
            continue
        kept.append(ch)
    return ''.join(kept)

class Tokenizer:
    """Vocabulary that maps tokens to integer indices and back.

    Index 0 is always the unknown token ``'<unk>'``. The reserved tokens come
    next, followed by the corpus tokens ordered from most to least frequent
    (ties keep their first-seen order).
    """

    def __init__(self, tokens:List[str], reserved_tokens:Optional[List[str]]=None, min_freq:int=0) -> None:
        """Build the vocabulary from a flat list of tokens.

        Args:
            tokens: Corpus tokens; their frequencies decide the index order.
            reserved_tokens: Extra special tokens placed right after ``'<unk>'``.
            min_freq: A corpus token is kept only when its count is strictly
                greater than this value.
        """
        self.unk = '<unk>'
        # Copy the argument so no list is shared between instances or with the caller.
        self.reserved_tokens = list(reserved_tokens) if reserved_tokens is not None else []

        # dict.fromkeys removes duplicates while keeping order, so every special
        # token owns exactly one index.
        special_tokens = list(dict.fromkeys([self.unk] + self.reserved_tokens))
        self._num_special = len(special_tokens)

        counter = Counter(tokens)
        tokens_freq = sorted(counter.items(), key=lambda token:token[1], reverse=True)

        # Skip corpus tokens that are already special tokens; otherwise they
        # would receive a second index and len(self) would under-count.
        already_indexed = set(special_tokens)
        sorted_tokens = special_tokens + [
            token for token, freq in tokens_freq
            if freq > min_freq and token not in already_indexed
        ]

        self.token_to_idx = {
            token:index for index, token in enumerate(sorted_tokens)
        }

        self.idx_to_token = dict(enumerate(sorted_tokens))

    def to_idx(self, token:str) -> int:
        """Return the index of ``token``, or the ``'<unk>'`` index if it is unknown."""
        return self.token_to_idx.get(token, self.token_to_idx[self.unk])

    def to_token(self, idx:int) -> str:
        """Return the token stored at ``idx``, or ``'<unk>'`` if the index is unknown."""
        return self.idx_to_token.get(idx, self.unk)

    def get_most_frequent(self, n: int) -> List[str]:
        """Return up to ``n`` corpus tokens, most frequent first.

        Special tokens (``'<unk>'`` and the reserved tokens) are skipped. If
        fewer than ``n`` corpus tokens exist, all of them are returned.
        """
        # Corpus tokens start right after the special tokens.
        start_idx = self._num_special
        stop_idx = min(start_idx + max(n, 0), len(self))
        return [self.idx_to_token[idx] for idx in range(start_idx, stop_idx)]

    def __getitem__(self, index):
        """Map a token, or a list/tuple of tokens, to indices.

        Raises:
            TypeError: If ``index`` is a slice, which has no meaning for
                token lookup.
        """
        if isinstance(index, slice):
            raise TypeError('Tokenizer indices must be tokens or lists/tuples of tokens, not slice')
        if isinstance(index, (list, tuple)):
            return [self.__getitem__(i) for i in index]
        return self.to_idx(index)

    def __len__(self) -> int:
        """Return the vocabulary size, special tokens included."""
        return len(self.token_to_idx)

### Construct Dataset
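We divide the input text into overlapping segments of length `num_inputs`; each segment is paired with the token that immediately follows it, which serves as the prediction target.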

In [3]:
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
class TextDataset(Dataset):
    """Sliding-window next-token dataset built from a whitespace-split text file.

    Each sample is a pair ``(X, Y)``: ``X`` holds ``num_inputs`` consecutive
    token indices and ``Y`` holds the single token index that follows them.
    """

    def __init__(self, text_file:Path, num_inputs:int) -> None:
        """Read, clean and tokenize ``text_file`` and build the windows.

        Args:
            text_file: Path of the text file to read (decoded as UTF-8).
            num_inputs: Number of tokens in each input window.

        Raises:
            ValueError: If the file has too few tokens to form one window.
        """
        super().__init__()

        # An explicit encoding keeps the result independent of the platform default.
        with text_file.open('r', encoding='utf-8') as f:
            raw_text = f.read()

        # NOTE(review): words made only of digits/punctuation become '' after
        # preprocessing and stay in the vocabulary as an empty token --
        # confirm whether they should be filtered out.
        words = [self.preprocess_text(word).lower() for word in raw_text.split()]

        if len(words) <= num_inputs:
            raise ValueError(
                f'{text_file} yields {len(words)} tokens; more than '
                f'num_inputs={num_inputs} are needed to build one sample'
            )

        self.tokenizer = Tokenizer(words)
        tokenized = torch.tensor([self.tokenizer.to_idx(word) for word in words])

        # unfold yields every window of num_inputs + 1 consecutive tokens
        # (stride 1) as a view, without building a Python list of lists.
        windows = tokenized.unfold(0, num_inputs + 1, 1)
        # The last column is the target; everything before it is the input.
        self.X, self.Y = windows[:, :-1], windows[:, -1:]

    def preprocess_text(self, text) -> str:
        """Return ``text`` reduced to ASCII letters via ``unicode_to_ascii``."""
        return unicode_to_ascii(text)

    def __getitem__(self, index):
        """Return one ``(X, Y)`` pair, or a list of pairs for a list/tuple/slice."""
        if isinstance(index, (list, tuple)):
            return [self.__getitem__(idx) for idx in index]
        elif isinstance(index, slice):
            # Resolve the slice against the dataset length before iterating.
            return [self.__getitem__(idx) for idx in range(*index.indices(len(self)))]
        return (self.X[index], self.Y[index])

    def __len__(self) -> int:
        """Return the number of ``(X, Y)`` samples."""
        return len(self.X)

# Data-pipeline settings, named so they are easy to find and tune.
NUM_STEPS = 16        # tokens per input window
TRAIN_FRACTION = 0.8  # share of samples used for training
BATCH_SIZE = 64
NUM_WORKERS = 2       # DataLoader worker processes (must be an int, not a bool)

dataset = TextDataset(Path(r'shakespeare.txt'), NUM_STEPS)
train_num = int(len(dataset) * TRAIN_FRACTION)

# Slicing a TextDataset returns a plain list of (X, Y) pairs, which
# DataLoader accepts as a map-style dataset.
train_dataset = dataset[:train_num]
valid_dataset = dataset[train_num:]

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          num_workers=NUM_WORKERS)

test_loader = DataLoader(dataset=valid_dataset,
                         batch_size=BATCH_SIZE,
                         shuffle=False,
                         num_workers=NUM_WORKERS)

len(train_dataset), len(valid_dataset)




(14058, 3515)

### Language Model
Let the input at time step $t$ be denoted $x_t$. Our goal is to predict $x_t$ given $x_1, x_2, \ldots, x_{t-1}$.
The probability of a sequence of $T$ tokens is then:
$$
P(x_1, x_2, \ldots, x_T) = P(x_1) \prod_{t=2}^T P(x_t \mid x_1, \ldots, x_{t-1})
$$

The probability of $x_1$ is $P(x_1)$.

The probability of $x_1$ and $x_2$ is $P(x_1, x_2) = P(x_2 | x_1) * P(x_1)$.

That is, the joint probability of $x_1$ and $x_2$ is just the probability of $x_1$ times the probability of $x_2$ given $x_1$.

We can treat the output of the model at each stage as the probability of $x_t$ given $x_1, ..., x_{t-1}$

### Perplexity
We can measure the quality of a language model by the cross-entropy loss averaged over all $n$ tokens of a sequence:
$$\frac{1}{n} \sum_{t=1}^n -\log P(x_t \mid x_{t-1}, \ldots, x_1),$$
where $P$ is given by a language model and $x_t$ is the actual token observed at time step $t$ from the sequence.
This makes the performance on documents of different lengths comparable. For historical reasons, scientists in natural language processing prefer to use a quantity called *perplexity*:

$$\exp\left(-\frac{1}{n} \sum_{t=1}^n \log P(x_t \mid x_{t-1}, \ldots, x_1)\right)$$

### RNN Model
An RNN is similar to an MLP. The key difference is that, in addition to a hidden layer, an RNN keeps a hidden state that summarizes the features of the previous time steps. At each time step, the hidden state from the previous step is multiplied by a recurrent weight matrix and added to the weighted current input (plus a bias) to compute the current hidden state. The current hidden state is then multiplied by an output weight matrix to compute the output of the current step.


![screenshot](resources/rnn_with_hidden_state.png)

The calculation of the hidden layer output of the current time step is determined by the input of the current time step together with the hidden layer output of the previous time step:

$$\mathbf{H}_t = \phi(\mathbf{X}_t \mathbf{W}_{\textrm{xh}} + \mathbf{H}_{t-1} \mathbf{W}_{\textrm{hh}}  + \mathbf{b}_\textrm{h}).$$

For time step $t$,
the output of the output layer is similar to the computation in the MLP:

$$\mathbf{O}_t = \mathbf{H}_t \mathbf{W}_{\textrm{hq}} + \mathbf{b}_\textrm{q}.$$


In [None]:
class RNN_Scratch(nn.Module):
    """Single-layer vanilla RNN implemented from scratch.

    Each time step computes
    ``H_t = tanh(X_t @ W_xh + H_{t-1} @ W_hh + b_h)``.
    """

    def __init__(self, num_inputs:int, num_hiddens:int, sigma:float=0.01):
        """Create the RNN parameters.

        Args:
            num_inputs: Size of each input vector.
            num_hiddens: Size of the hidden state.
            sigma: Standard deviation of the normal weight initialisation.
                Small values keep ``tanh`` out of its saturated region at
                the start of training.
        """
        super().__init__()

        self.num_inputs = num_inputs
        self.num_hiddens = num_hiddens
        self.sigma = sigma

        self.w_xh = nn.Parameter(
            torch.randn(num_inputs, num_hiddens) * sigma
        )

        self.w_hh = nn.Parameter(
            torch.randn(num_hiddens, num_hiddens) * sigma
        )

        self.b_h = nn.Parameter(
            torch.zeros(num_hiddens)
        )

    def forward(self, X, state=None):
        """Run the recurrence over every time step of ``X``.

        Args:
            X: Tensor of shape ``(num_steps, batch, num_inputs)``.
            state: Initial hidden state of shape ``(batch, num_hiddens)``,
                either as a tensor or wrapped in a 1-tuple/list. ``None``
                starts from zeros.

        Returns:
            ``(outputs, state)``: ``outputs`` is a list with one
            ``(batch, num_hiddens)`` hidden state per time step, and
            ``state`` is the last of them.
        """
        if state is None:
            # The initial state is plain data, not a learnable parameter.
            state = torch.zeros(
                (X.shape[1], self.num_hiddens),
                dtype=self.w_xh.dtype,
                device=self.w_xh.device,
            )
        elif isinstance(state, (tuple, list)):
            state, = state

        # One-hot inputs arrive as integers; matmul needs the weights' dtype.
        X = X.to(self.w_xh.dtype)

        outputs = []
        for step in X:
            state = torch.tanh(
                torch.matmul(step, self.w_xh)
                + torch.matmul(state, self.w_hh)
                + self.b_h
            )
            outputs.append(state)

        return outputs, state

class RNN_LM_Scratch(nn.Module):
    """Language model: a from-scratch RNN followed by a linear output layer.

    Token indices are one-hot encoded and fed through ``rnn``. Every hidden
    state is then projected to vocabulary logits with
    ``O_t = H_t @ W_hq + b_q``.
    """

    def __init__(self, rnn: "RNN_Scratch", num_vocabs:int):
        """Wrap ``rnn`` with an output projection.

        Args:
            rnn: Recurrent module exposing ``num_inputs`` and ``num_hiddens``.
                Its forward takes ``(X, state)`` with ``X`` shaped
                ``(num_steps, batch, num_inputs)`` and returns
                ``(list_of_hidden_states, state)``.
            num_vocabs: Vocabulary size; also the one-hot width.

        Raises:
            ValueError: If ``rnn.num_inputs`` differs from ``num_vocabs``.
        """
        super().__init__()

        # The one-hot vectors are num_vocabs wide, so the RNN must accept that width.
        if rnn.num_inputs != num_vocabs:
            raise ValueError(
                f'rnn.num_inputs ({rnn.num_inputs}) must equal num_vocabs ({num_vocabs})'
            )

        self.rnn = rnn
        self.num_vocabs = num_vocabs

        self.w_hidden_out = nn.Parameter(
            torch.randn((rnn.num_hiddens, self.num_vocabs))
        )

        self.b_out = nn.Parameter(
            torch.randn(self.num_vocabs)
        )

    def forward(self, X, state=None):
        """Compute next-token logits for every time step.

        Args:
            X: Integer tensor of token indices, shape ``(batch, num_steps)``.
            state: Optional initial hidden state, passed through to ``rnn``.

        Returns:
            ``(logits, state)``: ``logits`` has shape
            ``(batch, num_steps, num_vocabs)`` and ``state`` is the state
            returned by ``rnn``.
        """
        # Transpose so time is the leading axis: (num_steps, batch, num_vocabs).
        # one_hot returns integers, so cast to the weights' dtype for matmul.
        emb = F.one_hot(X.T, self.num_vocabs).to(self.w_hidden_out.dtype)

        hiddens, state = self.rnn(emb, state)

        # Stack the per-step hidden states batch-first: (batch, num_steps, num_hiddens).
        stacked = torch.stack(hiddens, dim=1)
        logits = torch.matmul(stacked, self.w_hidden_out) + self.b_out

        return logits, state



SyntaxError: expected ':' (3858558754.py, line 14)

In [None]:
# Shape: (batch, num_steps)
test_tensor = torch.tensor([
    [1, 2, 3, 4, 5],
    [1, 2, 3, 4, 5],
])

# Transposing first puts time on the leading axis, so the one-hot result has
# shape (num_steps, batch, num_classes) -- here (5, 2, 6).
emb = F.one_hot(test_tensor.T, 6)

# Show the shape next to the values.
emb.shape, emb

(torch.Size([5, 2, 6]),
 tensor([[[0, 1, 0, 0, 0, 0],
          [0, 1, 0, 0, 0, 0]],
 
         [[0, 0, 1, 0, 0, 0],
          [0, 0, 1, 0, 0, 0]],
 
         [[0, 0, 0, 1, 0, 0],
          [0, 0, 0, 1, 0, 0]],
 
         [[0, 0, 0, 0, 1, 0],
          [0, 0, 0, 0, 1, 0]],
 
         [[0, 0, 0, 0, 0, 1],
          [0, 0, 0, 0, 0, 1]]]))