# Recurrent Neural Networks and Language Models

You are probably very excited about ChatGPT.  In today's class, we will implement a very simple language model.  ChatGPT is, at its core, a language model too; ours just uses a simple LSTM.  You may be surprised that it is not difficult at all.

The paper we build on is *Regularizing and Optimizing LSTM Language Models*: https://arxiv.org/abs/1708.02182

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchtext, datasets, math
from tqdm import tqdm


In [6]:
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
print(device)

#fix the seed so that results stay comparable after restarting the kernel
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# torch.cuda.get_device_name(0)

cuda:2


## 1. Load data - Wiki Text

We will use WikiText, which contains a large corpus of text and is well suited to the language modeling task.  This time, we will load it with the `datasets` library from HuggingFace.

In [5]:
import os
os.environ['http_proxy']  = 'http://192.41.170.23:3128'
os.environ['https_proxy'] = 'http://192.41.170.23:3128'

#there are raw and preprocessed versions; we use the raw one and preprocess it ourselves for fun
dataset = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1')
print(dataset)

Found cached dataset wikitext (/root/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)
100%|██████████| 3/3 [00:00<00:00, 224.22it/s]

DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 4358
    })
    train: Dataset({
        features: ['text'],
        num_rows: 36718
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 3760
    })
})





In [7]:
print(dataset['train'][333]['text'])

# If you change the index, you may notice that some rows hold an empty string
# instead of a paragraph, so we will have to take care of that later.

 During the same time frame as the Hitchcock rumors , goaltender Curtis Sanford returned from his groin injury on November 13 . He made his first start of the season against the Boston Bruins , losing 2 – 1 in a shootout . Sanford continued his strong play , posting a 3 – 1 – 2 record , 1 @.@ 38 goals against average and .947 save percentage over his next six games . Sanford started 12 consecutive games before Steve Mason made his next start . The number of starts might not have been as numerous , but prior to the November 23 game , Mason was hit in the head by a shot from Rick Nash during pre @-@ game warm @-@ ups and suffered a concussion . Mason returned from his concussion after two games , making a start against the Vancouver Canucks . Mason allowed only one goal in the game despite suffering from cramping in the third period , temporarily being replaced by Sanford for just over three minutes . Columbus won the game 2 – 1 in a shootout , breaking a nine @-@ game losing streak to t

## 2. Preprocessing

### Tokenizing

We simply split each text into tokens.

In [8]:
tokenizer = torchtext.data.utils.get_tokenizer('basic_english')

#function to tokenize
tokenize_data = lambda example, tokenizer: {'tokens': tokenizer(example['text'])}  

#map the function to each example
tokenized_dataset = dataset.map(tokenize_data, remove_columns=['text'], fn_kwargs={'tokenizer': tokenizer})
print(tokenized_dataset['train'][333]['tokens'])

100%|██████████| 4358/4358 [00:00<00:00, 5889.72ex/s]
100%|██████████| 36718/36718 [00:05<00:00, 6396.75ex/s]
100%|██████████| 3760/3760 [00:00<00:00, 4830.63ex/s]

['during', 'the', 'same', 'time', 'frame', 'as', 'the', 'hitchcock', 'rumors', ',', 'goaltender', 'curtis', 'sanford', 'returned', 'from', 'his', 'groin', 'injury', 'on', 'november', '13', '.', 'he', 'made', 'his', 'first', 'start', 'of', 'the', 'season', 'against', 'the', 'boston', 'bruins', ',', 'losing', '2', '–', '1', 'in', 'a', 'shootout', '.', 'sanford', 'continued', 'his', 'strong', 'play', ',', 'posting', 'a', '3', '–', '1', '–', '2', 'record', ',', '1', '@', '.', '@', '38', 'goals', 'against', 'average', 'and', '.', '947', 'save', 'percentage', 'over', 'his', 'next', 'six', 'games', '.', 'sanford', 'started', '12', 'consecutive', 'games', 'before', 'steve', 'mason', 'made', 'his', 'next', 'start', '.', 'the', 'number', 'of', 'starts', 'might', 'not', 'have', 'been', 'as', 'numerous', ',', 'but', 'prior', 'to', 'the', 'november', '23', 'game', ',', 'mason', 'was', 'hit', 'in', 'the', 'head', 'by', 'a', 'shot', 'from', 'rick', 'nash', 'during', 'pre', '@-@', 'game', 'warm', '@-@




### Numericalizing

We tell torchtext to add only words that occur at least three times in the training set to the vocabulary, because otherwise the vocabulary would be too big.  We also make sure to add the special tokens `<unk>` and `<eos>`.

In [9]:
## numericalizing
vocab = torchtext.vocab.build_vocab_from_iterator(tokenized_dataset['train']['tokens'], 
min_freq=3) 
vocab.insert_token('<unk>', 0)           
vocab.insert_token('<eos>', 1)            
vocab.set_default_index(vocab['<unk>'])   
print(len(vocab))                         
print(vocab.get_itos()[:10])       

29473
['<unk>', '<eos>', 'the', ',', '.', 'of', 'and', 'in', 'to', 'a']
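To make the role of `min_freq` and the special tokens concrete, here is a minimal pure-Python sketch of the same idea.  The helper `build_vocab` and the toy corpus are made up for illustration; this is not how torchtext implements it internally, only an approximation of the behaviour we rely on.

```python
from collections import Counter

def build_vocab(token_lists, min_freq, specials=('<unk>', '<eos>')):
    # Hypothetical helper: count every token, keep those seen at least
    # min_freq times, and put the special tokens at the front.
    counts = Counter(tok for toks in token_lists for tok in toks)
    itos = list(specials) + [tok for tok, c in counts.most_common() if c >= min_freq]
    stoi = {tok: i for i, tok in enumerate(itos)}
    return itos, stoi

# Toy corpus: 'the' occurs 3 times, 'cat' twice, 'dog' once.
toy_corpus = [['the', 'cat', 'the'], ['the', 'dog', 'cat']]
itos, stoi = build_vocab(toy_corpus, min_freq=2)

# Rare or unseen words fall back to the index of <unk>,
# which is what set_default_index does for us in torchtext.
encoded = [stoi.get(tok, stoi['<unk>']) for tok in ['the', 'dog', 'cat']]
print(itos)
print(encoded)
```

Here `dog` is dropped from the vocabulary because it appears only once, so it is encoded with the `<unk>` index.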


## 3. Prepare the batch loader

### Prepare data

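Given "Chaky loves eating at AIT" and "I really love deep learning", we first join them into one long token stream with `<eos>` after each sentence.  With batch size = 3, the stream is cut into three equal rows that the model reads in parallel: "Chaky loves eating at", "AIT `<eos>` I really", "love deep learning `<eos>`".  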
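The example with the two sentences can be reproduced on a toy scale.  The sketch below assumes a made-up vocabulary (`stoi`, `itos`) built on the spot, and uses the same trim-and-`view` steps as the function in the next cell.

```python
import torch

# The two example sentences joined into one stream, with <eos> after each.
tokens = ("Chaky loves eating at AIT <eos> "
          "I really love deep learning <eos>").split()

# Toy vocabulary for illustration only (index order is arbitrary).
stoi = {tok: i for i, tok in enumerate(dict.fromkeys(tokens))}
itos = {i: tok for tok, i in stoi.items()}

batch_size = 3
data = torch.tensor([stoi[t] for t in tokens])
tokens_per_row = data.shape[0] // batch_size          # 12 // 3 = 4
data = data[:tokens_per_row * batch_size]             # drop any leftover tokens
data = data.view(batch_size, tokens_per_row)          # [batch_size, tokens_per_row]

rows = [" ".join(itos[i] for i in row.tolist()) for row in data]
for row in rows:
    print(row)
```

Each printed row is one of the three parallel streams; a row may start or end in the middle of a sentence.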

In [10]:
def get_data(dataset, vocab, batch_size):
    data = []
    for example in dataset:
        if example['tokens']:  #skip the empty rows
            #append <eos> so the model can learn where a line ends
            #(list.append returns None, so we build a new list instead)
            tokens = example['tokens'] + ['<eos>']
            #numericalize
            data.extend(vocab[token] for token in tokens)
    data = torch.LongTensor(data)
    #num_batches is the number of tokens in each of the batch_size rows
    num_batches = data.shape[0] // batch_size
    data = data[:num_batches * batch_size]  #drop the leftover tokens
    data = data.view(batch_size, num_batches)
    return data


In [11]:
batch_size = 128
train_data = get_data(tokenized_dataset['train'], vocab, batch_size)
valid_data = get_data(tokenized_dataset['validation'], vocab, batch_size)
test_data  = get_data(tokenized_dataset['test'], vocab, batch_size)

In [12]:
train_data.shape #[batch_size, number of tokens per row]

torch.Size([128, 16214])

## 4. Modeling 

In [13]:
class LSTMLanguageModel(nn.Module):
    def __init__(self, vocab_size, emb_dim, hid_dim, num_layers, dropout_rate):
                
        super().__init__()
        self.hid_dim = hid_dim
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size,emb_dim)
        self.lstm = nn.LSTM(emb_dim, hid_dim, num_layers=num_layers, 
                                        dropout = dropout_rate, batch_first = True)
        self.dropout = nn.Dropout(dropout_rate)
        #a language model predicts the next token from left to right, so a bidirectional LSTM does not make sense
        self.fc = nn.Linear(hid_dim,vocab_size)

    def init_hidden(self, batch_size, device):
        #this function is run at the beginning of each epoch
        hidden = torch.zeros(self.num_layers, batch_size, self.hid_dim).to(device)
        cell   = torch.zeros(self.num_layers, batch_size, self.hid_dim).to(device)

        return hidden, cell #return as tuple

    def detach_hidden(self, hidden):
        hidden, cell = hidden
        hidden = hidden.detach() #remove the hidden state from the gradient graph
        cell =  cell.detach() #remove the cell state from the gradient graph
        return hidden, cell

    def forward(self, src, hidden):
        #src: [batch_size, seq_len]

        #embed 
        embedded = self.embedding(src)
        #embed : [batch_size, seq_len, emb_dim]

        #send this to the lstm
        #we pass hidden in explicitly so that we control when it is carried over and when it is reset
        output, hidden = self.lstm(embedded, hidden)
        #output : [batch_size, seq_len, hid_dim] ==> hidden states of the last layer at every time step
        #hidden : tuple (h, c), each [num_layers, batch_size, hid_dim] ==> last states from each layer

        output = self.dropout(output)
        prediction = self.fc(output)
        #prediction: [batch size, seq_len, vocab_size]
        return prediction, hidden
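To double-check the shapes that come out of `nn.LSTM` with `batch_first = True`, here is a small stand-alone sketch.  The sizes are made up and much smaller than the ones we train with.

```python
import torch
import torch.nn as nn

# Made-up sizes, for illustration only.
batch_size, seq_len, emb_dim, hid_dim, num_layers = 4, 5, 8, 16, 2

lstm = nn.LSTM(emb_dim, hid_dim, num_layers=num_layers, batch_first=True)
embedded = torch.randn(batch_size, seq_len, emb_dim)

# Without an explicit initial state, the LSTM starts from zeros.
output, (h, c) = lstm(embedded)

print(output.shape)  # hidden states of the last layer at every time step
print(h.shape)       # last hidden state of each layer
print(c.shape)       # last cell state of each layer
```

Note that `batch_first` only affects the input and `output`; the states `h` and `c` always have the layer dimension first.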

## 5. Training 

Training follows a very basic procedure.  One note: a sequence fed to the model may span parts of several paragraphs in the original dataset, or be only a part of one (depending on the sequence length).  For this reason we reset the hidden state only at the start of every epoch and carry it over, detached, from one batch to the next.  This amounts to assuming that the next batch of sequences is always a continuation of the previous one in the original dataset.
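Carrying the hidden state across batches only works if we cut it out of the old computation graph first; otherwise backpropagation would try to reach back through every earlier batch.  The sketch below shows the idea on a tiny LSTM with made-up sizes: the values of the state survive `detach()`, but its link to the previous chunk is gone.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=4, hidden_size=8, batch_first=True)  # toy sizes
x = torch.randn(2, 6, 4)                                       # [batch, seq, features]
hidden = (torch.zeros(1, 2, 8), torch.zeros(1, 2, 8))          # (h, c)

# First chunk: the returned state is attached to the graph of this chunk.
out1, hidden = lstm(x[:, :3], hidden)
attached = hidden[0].grad_fn is not None
h_before = hidden[0].clone()

# Detach: same values, but no history.
hidden = tuple(state.detach() for state in hidden)
detached = hidden[0].grad_fn is None
same_values = torch.equal(h_before, hidden[0])

# Second chunk: backward stops at the detached state.
out2, hidden = lstm(x[:, 3:], hidden)
out2.sum().backward()

print(attached, detached, same_values)
```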

In [14]:
vocab_size = len(vocab)
emb_dim = 1024                # 400 in the paper
hid_dim = 1024                # 1150 in the paper
num_layers = 2                # 3 in the paper
dropout_rate = 0.65              
lr = 1e-3                     

In [15]:
model = LSTMLanguageModel(vocab_size, emb_dim, hid_dim, num_layers, dropout_rate).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'The model has {num_params:,} trainable parameters')

The model has 77,183,777 trainable parameters


In [16]:
def get_batch(data, seq_len, idx):
    #this data is from get_data()
    #data: [batch_size, number of tokens per row]
    src    = data[:, idx:idx+seq_len]                   
    target = data[:, idx+1:idx+seq_len+1]  #target simply is ahead of src by 1            
    return src, target
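A quick way to see what `get_batch` returns is to run it on toy data in which every token is just its own position.  The function is repeated here so that the snippet runs on its own.

```python
import torch

def get_batch(data, seq_len, idx):
    src    = data[:, idx:idx+seq_len]
    target = data[:, idx+1:idx+seq_len+1]  # shifted one step ahead of src
    return src, target

# Toy data: 2 rows of 7 "tokens" each.
data = torch.arange(14).view(2, 7)

src, target = get_batch(data, seq_len=3, idx=0)
print(src.tolist())     # first 3 tokens of each row
print(target.tolist())  # the same window moved one token to the right
```

For every position, the target is the token that follows the source token, which is exactly the next-word prediction task.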

In [17]:
def train(model, data, optimizer, criterion, batch_size, seq_len, clip, device):
    
    epoch_loss = 0
    model.train()
    # trim the columns so that (num_batches - 1) is a multiple of seq_len; the one extra column supplies the last target
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches -1) % seq_len]
    num_batches = data.shape[-1]

    hidden = model.init_hidden(batch_size, device)
    
    for idx in tqdm(range(0, num_batches - 1, seq_len), desc='Training: ',leave=False):
        optimizer.zero_grad()
        hidden = model.detach_hidden(hidden)

        src, target = get_batch(data, seq_len, idx) #src, target: [batch size, seq len]
        src, target = src.to(device), target.to(device)
        batch_size = src.shape[0]
        prediction, hidden = model(src, hidden)               

        prediction = prediction.reshape(batch_size * seq_len, -1)  #prediction: [batch size * seq len, vocab size]  
        target = target.reshape(-1)
        loss = criterion(prediction, target)
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip) #prevent gradient explosion: rescale the gradients if their total norm exceeds clip
        optimizer.step()
        epoch_loss += loss.item() * seq_len
    return epoch_loss / num_batches
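The effect of `clip_grad_norm_` in the training loop is easy to check on a toy model.  In the sketch below the linear layer and the exaggerated loss are made up so that the gradients are deliberately huge; the call returns the total norm before clipping and rescales the gradients in place.

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 1)                       # toy model
loss = (model(torch.ones(5, 3)) * 100).sum()  # exaggerated loss -> large gradients
loss.backward()

clip = 0.25
# Returns the total gradient norm measured before clipping.
norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

# Recompute the total norm after clipping.
norm_after = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))

print(float(norm_before), float(norm_after))
```

The direction of the gradient is unchanged; only its length is capped at `clip`.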

In [18]:
def evaluate(model, data, criterion, batch_size, seq_len, device):

    epoch_loss = 0
    model.eval()
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches -1) % seq_len]
    num_batches = data.shape[-1]

    hidden = model.init_hidden(batch_size, device)

    with torch.no_grad():
        for idx in range(0, num_batches - 1, seq_len):
            hidden = model.detach_hidden(hidden)
            src, target = get_batch(data, seq_len, idx)
            src, target = src.to(device), target.to(device)
            batch_size= src.shape[0]

            prediction, hidden = model(src, hidden)
            prediction = prediction.reshape(batch_size * seq_len, -1)
            target = target.reshape(-1)

            loss = criterion(prediction, target)
            epoch_loss += loss.item() * seq_len
    return epoch_loss / num_batches

In [19]:
n_epochs = 50
seq_len  = 50
clip    = 0.25

lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=0)

best_valid_loss = float('inf')

for epoch in range(n_epochs):
    train_loss = train(model, train_data, optimizer, criterion, 
                batch_size, seq_len, clip, device)
    valid_loss = evaluate(model, valid_data, criterion, batch_size, 
                seq_len, device)

    lr_scheduler.step(valid_loss)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'best-val-lstm_lm.pt')

    print(f'\tTrain Perplexity: {math.exp(train_loss):.3f}')
    print(f'\tValid Perplexity: {math.exp(valid_loss):.3f}')

                                                           

	Train Perplexity: 585.331
	Valid Perplexity: 287.178

	Train Perplexity: 299.731
	Valid Perplexity: 218.282

	Train Perplexity: 223.134
	Valid Perplexity: 184.422

	Train Perplexity: 180.191
	Valid Perplexity: 166.492

	Train Perplexity: 151.608
	Valid Perplexity: 152.975

	Train Perplexity: 131.643
	Valid Perplexity: 144.683

	Train Perplexity: 116.931
	Valid Perplexity: 141.777

	Train Perplexity: 105.754
	Valid Perplexity: 139.468

	Train Perplexity: 96.521
	Valid Perplexity: 134.585

	Train Perplexity: 88.916
	Valid Perplexity: 133.029

	Train Perplexity: 82.543
	Valid Perplexity: 135.352

	Train Perplexity: 73.955
	Valid Perplexity: 132.622

	Train Perplexity: 69.369
	Valid Perplexity: 132.661

	Train Perplexity: 65.103
	Valid Perplexity: 129.927

	Train Perplexity: 62.816
	Valid Perplexity: 130.531

	Train Perplexity: 61.166
	Valid Perplexity: 131.216

	Train Perplexity: 60.566
	Valid Perplexity: 130.808

	Train Perplexity: 60.991
	Valid Perplexity: 129.733

	Train Perplexity: 60.475
	Valid Perplexity: 129.597

	Train Perplexity: 60.154
	Valid Perplexity: 129.981

	Train Perplexity: 60.495
	Valid Perplexity: 129.214

	Train Perplexity: 60.486
	Valid Perplexity: 129.112

	Train Perplexity: 59.921
	Valid Perplexity: 129.132

	Train Perplexity: 60.068
	Valid Perplexity: 128.956

	Train Perplexity: 60.018
	Valid Perplexity: 128.892

	Train Perplexity: 59.981
	Valid Perplexity: 128.922

	Train Perplexity: 60.195
	Valid Perplexity: 128.929

	Train Perplexity: 60.402
	Valid Perplexity: 128.896

	Train Perplexity: 60.467
	Valid Perplexity: 128.889

	Train Perplexity: 60.504
	Valid Perplexity: 128.878

	Train Perplexity: 60.537
	Valid Perplexity: 128.878

	Train Perplexity: 60.444
	Valid Perplexity: 128.878

	Train Perplexity: 60.527
	Valid Perplexity: 128.878

	Train Perplexity: 60.535
	Valid Perplexity: 128.878

	Train Perplexity: 60.506
	Valid Perplexity: 128.878

	Train Perplexity: 60.521
	Valid Perplexity: 128.878

	Train Perplexity: 60.479
	Valid Perplexity: 128.878

	Train Perplexity: 60.448
	Valid Perplexity: 128.879

	Train Perplexity: 60.433
	Valid Perplexity: 128.879

	Train Perplexity: 60.482
	Valid Perplexity: 128.879

	Train Perplexity: 60.470
	Valid Perplexity: 128.879

	Train Perplexity: 60.416
	Valid Perplexity: 128.879

KeyboardInterrupt: 

## 6. Testing

In [20]:
model.load_state_dict(torch.load('best-val-lstm_lm.pt',  map_location=device))
test_loss = evaluate(model, test_data, criterion, batch_size, seq_len, device)
print(f'Test Perplexity: {math.exp(test_loss):.3f}')

Test Perplexity: 121.783
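
The perplexities we report are simply the exponential of the average per-token cross-entropy loss.  As a sanity check of that relationship, the sketch below uses made-up uniform logits: a model that spreads its probability evenly over a vocabulary of size V should have a perplexity of V.

```python
import math
import torch
import torch.nn as nn

vocab_size = 10                       # toy vocabulary size
logits = torch.zeros(4, vocab_size)   # identical scores -> uniform distribution
targets = torch.tensor([1, 3, 5, 7])  # arbitrary "correct" next words

criterion = nn.CrossEntropyLoss()     # averages the loss over the 4 tokens
loss = criterion(logits, targets)     # equals log(vocab_size) for uniform predictions
perplexity = math.exp(loss.item())

print(f'loss = {loss.item():.4f}, perplexity = {perplexity:.4f}')
```

Read this way, a perplexity of about 120 means the model is, on average, as uncertain as if it were choosing uniformly among roughly 120 words.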


## 7. Real-world inference

Here we take the prompt, tokenize it, encode it, and feed it into the model to get the predictions.  We keep only the output at the last position of the sequence, which is the prediction for the next word, and apply softmax to it.  Before the softmax, we divide the logits by a temperature value, which changes the model's confidence by sharpening or flattening the probability distribution.

Once we have the softmax distribution, we randomly sample from it to predict the next word.  If we get `<unk>`, we sample again.  Once we get `<eos>`, we stop predicting.

Finally, we decode the predicted indices back into strings.
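
To see what the temperature does before using it in the real model, here is a small sketch on made-up logits for a three-word vocabulary.  A temperature below 1 makes the distribution sharper, and a temperature above 1 makes it flatter.

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.0])  # made-up scores for 3 words

probs_by_temperature = {}
for temperature in [0.5, 1.0, 2.0]:
    # Dividing by the temperature rescales the gaps between the logits.
    probs = torch.softmax(logits / temperature, dim=-1)
    probs_by_temperature[temperature] = probs
    print(temperature, [round(p, 3) for p in probs.tolist()])

# Sampling: words with a higher probability are drawn more often.
torch.manual_seed(0)
next_word = torch.multinomial(probs_by_temperature[1.0], num_samples=1).item()
print(next_word)
```

With a low temperature the model almost always picks its top choice; with a high temperature, less likely words get sampled more often.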

In [21]:
def generate(prompt, max_seq_len, temperature, model, tokenizer, vocab, device, seed=None):
    if seed is not None:
        torch.manual_seed(seed)
    model.eval()
    tokens = tokenizer(prompt)
    indices = [vocab[t] for t in tokens]
    batch_size = 1
    hidden = model.init_hidden(batch_size, device)
    with torch.no_grad():
        for i in range(max_seq_len):
            src = torch.LongTensor([indices]).to(device)
            prediction, hidden = model(src, hidden)
            
            #prediction: [batch size, seq len, vocab size]
            #prediction[:, -1]: [batch size, vocab size] #logits for the word that follows the last position
            
            probs = torch.softmax(prediction[:, -1] / temperature, dim=-1)  
            prediction = torch.multinomial(probs, num_samples=1).item()    
            
            while prediction == vocab['<unk>']: #if it is unk, we sample again
                prediction = torch.multinomial(probs, num_samples=1).item()

            if prediction == vocab['<eos>']:    #if it is eos, we stop
                break

            indices.append(prediction) #autoregressive, thus output becomes input

    itos = vocab.get_itos()
    tokens = [itos[i] for i in indices]
    return tokens
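
A note on the loop in `generate`: it feeds the whole token history into the model at every step while also carrying `hidden` forward, so earlier tokens pass through the LSTM more than once.  A possible alternative is to feed only the newest token once the prompt has been consumed.  The sketch below uses a tiny untrained embedding and LSTM with made-up sizes, not our trained model, and checks the property that this alternative relies on: stepping through a sequence one token at a time with a carried state gives the same outputs as a single pass over the full sequence.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embedding = nn.Embedding(20, 8)         # toy vocabulary of 20 words
lstm = nn.LSTM(8, 16, batch_first=True)
lstm.eval()

tokens = torch.tensor([[3, 7, 1, 12]])  # [batch size = 1, seq len = 4]

with torch.no_grad():
    # One pass over the full sequence, starting from a zero state.
    full_output, _ = lstm(embedding(tokens))

    # Token-by-token pass, carrying the state from step to step.
    hidden = None                       # None also means a zero initial state
    step_outputs = []
    for t in range(tokens.shape[1]):
        out, hidden = lstm(embedding(tokens[:, t:t+1]), hidden)
        step_outputs.append(out)
    step_output = torch.cat(step_outputs, dim=1)

match = torch.allclose(full_output, step_output, atol=1e-5)
print(match)
```

Feeding one token at a time would also make each generation step cheaper, since the history is summarized by the state and does not have to be re-read.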

In [22]:
prompt = 'Harry Potter is '
max_seq_len = 30
seed = 0
#lower temperature -> more conservative text; higher temperature -> more diverse text
temperatures = [0.5, 0.7, 0.75, 0.8, 1.0] 
#we sample from the distribution, so words with a higher probability have a higher chance of being picked
for temperature in temperatures:
    generation = generate(prompt, max_seq_len, temperature, model, tokenizer, 
                          vocab, device, seed)
    print(str(temperature)+'\n'+' '.join(generation)+'\n')

0.5
harry potter is a member of the church of st . louis . the cathedral is the first of the cathedral ' s first part of the 19th century .

0.7
harry potter is a reference to the church of the holy church . the main cathedral is the st . louis rail , which is the site of the cathedral . the cathedral

0.75
harry potter is a reference to the church of the holy church . the main cathedral is the st . louis rail , which is the site of the cathedral . the cathedral

0.8
harry potter is said to be a bit stronger .

1.0
harry potter is said to be hell .

