# RNN and Long Short-Term Memory (LSTM) recurrent neural network cells

We will implement RNN and LSTM cells from scratch and use them to build a language model. We will then train the model on a dataset of Shakespeare's writing.

Like it says, practice makes perfect. So, let's get started!

In [1]:
# import packages that are not related to torch
import os
import math
import time
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt


# torch import
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tu_data
from torchvision.datasets import FashionMNIST


### --------- environment setup --------- ###
# set up the data path
DATA_PATH = "../GPT-2/data"

# function for setting seed
def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        
# set up seed globally and deterministically
set_seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [13]:
print("Device: ", device)

Device:  cuda


In [3]:
# read the data
with open(os.path.join(DATA_PATH, "input.txt"), "r") as f:
    text = f.read()

In [5]:
print(text[:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [7]:
chars = sorted(list(set(text)))
print(chars)
print("The number of unique characters: {}".format(len(chars)))
print("".join(chars))

['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
The number of unique characters: 65

 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [8]:
# create index
char2idx = {ch: i for i, ch in enumerate(chars)}
idx2char = {i: ch for i, ch in enumerate(chars)}
print(char2idx)
print(idx2char)

{'\n': 0, ' ': 1, '!': 2, '$': 3, '&': 4, "'": 5, ',': 6, '-': 7, '.': 8, '3': 9, ':': 10, ';': 11, '?': 12, 'A': 13, 'B': 14, 'C': 15, 'D': 16, 'E': 17, 'F': 18, 'G': 19, 'H': 20, 'I': 21, 'J': 22, 'K': 23, 'L': 24, 'M': 25, 'N': 26, 'O': 27, 'P': 28, 'Q': 29, 'R': 30, 'S': 31, 'T': 32, 'U': 33, 'V': 34, 'W': 35, 'X': 36, 'Y': 37, 'Z': 38, 'a': 39, 'b': 40, 'c': 41, 'd': 42, 'e': 43, 'f': 44, 'g': 45, 'h': 46, 'i': 47, 'j': 48, 'k': 49, 'l': 50, 'm': 51, 'n': 52, 'o': 53, 'p': 54, 'q': 55, 'r': 56, 's': 57, 't': 58, 'u': 59, 'v': 60, 'w': 61, 'x': 62, 'y': 63, 'z': 64}
{0: '\n', 1: ' ', 2: '!', 3: '$', 4: '&', 5: "'", 6: ',', 7: '-', 8: '.', 9: '3', 10: ':', 11: ';', 12: '?', 13: 'A', 14: 'B', 15: 'C', 16: 'D', 17: 'E', 18: 'F', 19: 'G', 20: 'H', 21: 'I', 22: 'J', 23: 'K', 24: 'L', 25: 'M', 26: 'N', 27: 'O', 28: 'P', 29: 'Q', 30: 'R', 31: 'S', 32: 'T', 33: 'U', 34: 'V', 35: 'W', 36: 'X', 37: 'Y', 38: 'Z', 39: 'a', 40: 'b', 41: 'c', 42: 'd', 43: 'e', 44: 'f', 45: 'g', 46: 'h', 47: 'i',

In [9]:
# encdoe and decode functions
encode = lambda text: [char2idx[ch] for ch in text]
decode = lambda tnsr: "".join([idx2char[i] for i in tnsr])

# test the encode and decode functions
print(encode("hello"))
print(decode(encode("hello")))

[46, 43, 50, 50, 53]
hello


In [27]:
# data loader for each batch
# return a generator
def get_batch(encoded_text, seq_len, batch_size):
    """
    encoded_text: the encoded text
    seq_len: the length of the sequence
    batch_size: the size of the batch
    """
    # calculate the number of batches
    n_batches = encoded_text.shape[0] // (seq_len * batch_size)
    # reshape the encoded text
    encoded_text = encoded_text[:n_batches * batch_size * seq_len]
    encoded_text = encoded_text.reshape((batch_size, -1))
    # loop through the encoded text
    for i in range(0, encoded_text.shape[1], seq_len):
        # get the input and the target
        X = encoded_text[:, i:i+seq_len]
        Y = torch.zeros_like(X)
        # this part is tricky, we need to shift the input and the target
        # by one character and avoid the loop (meaning the last character
        # of the target should not be the first character of the input)
        Y[:, :-1], Y[:, -1] = X[:, 1:], encoded_text[:, i+seq_len]
        yield X, Y

In [28]:
# test it
encoded_text = torch.tensor(encode(text), dtype=torch.long, device=device)
foo_x, foo_y = next(get_batch(encoded_text, 10, 32))
print(foo_x.shape, foo_y.shape)
print(foo_x[0])
print(foo_y[0])

torch.Size([32, 10]) torch.Size([32, 10])
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47], device='cuda:0')
tensor([47, 56, 57, 58,  1, 15, 47, 58, 47, 64], device='cuda:0')


In [75]:
F.one_hot(foo_x[0], num_classes=len(chars)).shape

torch.Size([10, 65])

In [12]:
# for this kind of model, we will not split the data into train and test
# as we are not doing classification, we are doing generation
# so we will use all the data for training

In [79]:
# we will wrint RNN again but this time we will use PyTorch's framework
# it is important to do inheritance here from nn.Module
class RNN(nn.Module):

    def __init__(self, vocab_size, hidden_size, dropout=0.5, bias=True):
        super().__init__()

        # we will not use embedding layer here
        # instead, we will use one-hot encoding
        # so the input size will be the vocab size
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.dropout = dropout

        # initialize the weights by following the rule in the paper
        # or you can use nn.init.xavier_uniform_ or kaime_uniform_
        self.Wxh = nn.init.kaiming_uniform_(torch.empty(vocab_size, hidden_size, device=device))
        self.Whh = nn.init.kaiming_uniform_(torch.empty(hidden_size, hidden_size, device=device))
        self.bh = torch.zeros(hidden_size, device=device)

        if bias:
            self.bh = torch.zeros(hidden_size, device=device)
            # no need to do by as we will use a linear layer
        else:
            self.bh = None


        # define a dropout layer
        self.dropout_layer = nn.Dropout(dropout)
        # define a linear layer
        self.linear_layer = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, h_t):
        """
        x: the input, shape = (batch_size, seq_len, embedding_size)
            if no embedding layer, shape = (batch_size, seq_len, vocab_size)
        h_t: the hidden state at time t, shape = (batch_size, hidden_size)
        """
        batch_size, seq_len, _ = x.shape
        hidden_sequences = []

        # loop through the sequence
        for t in range(seq_len):
            # get the current input
            x_t = x[:, t, :]
            # calculate the hidden state
            h_t = torch.tanh(x_t @ self.Wxh + h_t @ self.Whh + self.bh)
            # append the hidden state, shape = (batch_size, hidden_size)
            hidden_sequences.append(h_t) 
        
        # stack the hidden states, shape = (seq_len, batch_size, hidden_size)
        hidden_sequences = torch.stack(hidden_sequences)
        # transpose the hidden states, shape = (batch_size, seq_len, hidden_size)
        hidden_sequences = hidden_sequences.view(batch_size, seq_len, -1)

        # stack the hidden states, shape = (batch_size * seq_len, hidden_size)
        # we need to do this because we will use a linear layer
        hidden_sequences = hidden_sequences.view(-1, self.hidden_size)
        # hidden_sequences.shape = (batch_size * seq_len, hidden_size)
        # apply dropout
        hidden_sequences = self.dropout_layer(hidden_sequences)
        # apply the linear layer
        logits = self.linear_layer(hidden_sequences)
        # logits.shape = (batch_size * seq_len, vocab_size)

        # we need to return the logits and the last hidden state
        return logits, h_t
    
    
    def init_hidden(self, batch_size):
        """
        initialize the hidden state
        """
        h0 = torch.zeros(batch_size, self.hidden_size)
        # push to device
        h0 = h0.to(device)
        return h0
    
    def parameters(self):
        """
        return all the parameters
        """
        params = [self.Wxh, self.Whh]
        if self.bh is not None:
            params.append(self.bh)
        # add the parameters of the linear layer and dropout layer
        params += list(self.linear_layer.parameters())
        params += list(self.dropout_layer.parameters())
        return params


In [83]:
# test the RNN
max_steps = 10000
vocab_size = len(chars)
hidden_size = 512
batch_size = 128
seq_len = 100
learning_rate = 1e-3

# encode the text
encoded_text = torch.tensor(encode(text), dtype=torch.long, device=device)

# initialize the model
rnn = RNN(vocab_size, hidden_size, dropout=0.5, bias=True)
rrnn = rnn.to(device)

# initialize the optimizer
optimizer = torch.optim.Adam(rnn.parameters(), lr=learning_rate)

# initialize the hidden state
h_t = rnn.init_hidden(batch_size)

# begin training
for i in range(max_steps):
    # get the input and the target
    X, Y = next(get_batch(encoded_text, seq_len, batch_size))
    # turn X, Y into one-hot encoding
    X = F.one_hot(X, vocab_size).float()
    Y = F.one_hot(Y, vocab_size).float()
    # X.shape = (batch_size, seq_len, vocab_size)
    # Y.shape = (batch_size, seq_len, vocab_size)
    # put X, Y on the device
    X = X.to(device)
    Y = Y.to(device)
    # get the logits and the last hidden state
    logits, h_t = rnn(X, h_t)
    # logits.shape = (batch_size * seq_len, vocab_size)
    # h_t.shape = (batch_size, hidden_size)
    # y.shape = (batch_size * seq_len, vocab_size)
    # reshape Y
    Y = Y.view(-1, vocab_size)
    # calculate the loss
    loss = F.cross_entropy(logits, Y)
    # zero the gradients
    optimizer.zero_grad()
    # backprop
    loss.backward()
    # clip the gradients
    torch.nn.utils.clip_grad_norm_(rnn.parameters(), 5)
    # update the parameters
    optimizer.step()

    if i % 2000 == 0:
        print(f"step: {i}, loss: {loss.item():.4f}")


step: 0, loss: 4.2684
step: 2000, loss: 3.3287
step: 4000, loss: 3.2865
step: 6000, loss: 3.2795
step: 8000, loss: 3.2803


In [84]:
def predict(model, char, device, h=None, top_k=5):
        ''' Given a character & hidden state, predict the next character.
            Returns the predicted character and the hidden state.
        '''
        
        # tensor inputs
        x = np.array([[char2int[char]]])
        x = F.one_hot(torch.from_numpy(x), len(model.chars)).float().to(device)
        inputs = torch.from_numpy(x).to(device)
        
        with torch.no_grad():
            # get the output of the model
            out, h = model(inputs, h)

            # get the character probabilities
            # move to cpu for further processing with numpy etc. 
            p = F.softmax(out, dim=1).data.cpu()

            # get the top characters with highest likelihood
            p, top_ch = p.topk(top_k)
            top_ch = top_ch.numpy().squeeze()

            # select the likely next character with some element of randomness
            # for more variability
            p = p.numpy().squeeze()
            char = np.random.choice(top_ch, p=p/p.sum())
        
        # return the encoded value of the predicted char and the hidden state
        return int2char[char], h

In [85]:
def sample(model, size, device, prime='A', top_k=None):
    # method to generate new text based on a "prime"/initial sequence. 
    # Basically, the outer loop convenience function that calls the above
    # defined predict method. 
    model.eval() # eval mode
    
    # Calculate model for the initial prime characters
    chars = [ch for ch in prime]
    with torch.no_grad():
        # initialize hidden with 0 in the beginning. Set our batch size to 1 
        # as we wish to generate one sequence only. 
        h = model.init_hidden(batch_size=1)
        for ch in prime:
            char, h = predict(model, ch, device, h=h, top_k=top_k)

        # append the characters to the sequence
        chars.append(char)

        # Now pass in the previous/last character and get a new one
        # Repeat this process for the desired length of the sequence to be 
        # generated
        for ii in range(size):
            char, h = predict(model, chars[-1], device, h=h, top_k=top_k)
            chars.append(char)

    return ''.join(chars)