# Shakespeare Language Model

In [1]:
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn
from torch.utils.data import Dataset, DataLoader, TensorDataset
import numpy as np
import time

import shakespeare_data as sh

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE

'cuda'

## Fixed length input

In [2]:
# Data - refer to shakespeare_data.py for details
corpus = sh.read_corpus()
print("First 203 characters...Last 50 characters")
print("{}...{}".format(corpus[:203], corpus[-50:]))
print("Total character count: {}".format(len(corpus)))
chars, charmap = sh.get_charmap(corpus)
charcount = len(chars)
print("Unique character count: {}\n".format(len(chars)))
shakespeare_array = sh.map_corpus(corpus, charmap)
print("shakespeare_array.shape: {}\n".format(shakespeare_array.shape))
small_example = shakespeare_array[:17]
print("First 17 characters as indices", small_example)
print("First 17 characters as characters:", [chars[c] for c in small_example])
print("First 17 character indices as text:\n", sh.to_text(small_example,chars))

First 203 characters...Last 50 characters
1609
 THE SONNETS
 by William Shakespeare
                      1
   From fairest creatures we desire increase,
   That thereby beauty's rose might never die,
   But as the riper should by time decease,
...,
   And new pervert a reconciled maid.'
 THE END

Total character count: 5551930
Unique character count: 84

shakespeare_array.shape: (5551930,)

First 17 characters as indices [12 17 11 20  0  1 45 33 30  1 44 40 39 39 30 45 44]
First 17 characters as characters: ['1', '6', '0', '9', '\n', ' ', 'T', 'H', 'E', ' ', 'S', 'O', 'N', 'N', 'E', 'T', 'S']
First 17 character indices as text:
 1609
 THE SONNETS


In [3]:
# Dataset class. Transform raw text into a set of sequences of fixed length, and extracts inputs and targets
class TextDataset(Dataset):
    def __init__(self,text, seq_len = 200):
        n_seq = len(text) // seq_len
        text = text[:n_seq * seq_len]
        self.data = torch.tensor(text).view(-1,seq_len)
    def __getitem__(self,i):
        txt = self.data[i]
        return txt[:-1],txt[1:]
    def __len__(self):
        return self.data.size(0)

# Collate function. Transform a list of sequences into a batch. Passed as an argument to the DataLoader.
# Returns data on the format seq_len x batch_size
def collate(seq_list):
    inputs = torch.cat([s[0].unsqueeze(1) for s in seq_list],dim=1)
    targets = torch.cat([s[1].unsqueeze(1) for s in seq_list],dim=1)
    return inputs,targets


In [4]:
# Model
class CharLanguageModel(nn.Module):

    def __init__(self,vocab_size,embed_size,hidden_size, nlayers):
        super(CharLanguageModel,self).__init__()
        self.vocab_size=vocab_size
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.nlayers=nlayers
        self.embedding = nn.Embedding(vocab_size,embed_size) # Embedding layer
        self.rnn = nn.LSTM(input_size = embed_size,hidden_size=hidden_size,num_layers=nlayers) # Recurrent network
        self.scoring = nn.Linear(hidden_size,vocab_size) # Projection layer
        
    def forward(self,seq_batch): #L x N
        # returns 3D logits
        batch_size = seq_batch.size(1)
        embed = self.embedding(seq_batch) #L x N x E
        hidden = None
        output_lstm,hidden = self.rnn(embed,hidden) #L x N x H
        output_lstm_flatten = output_lstm.view(-1,self.hidden_size) #(L*N) x H
        output_flatten = self.scoring(output_lstm_flatten) #(L*N) x V
        return output_flatten.view(-1,batch_size,self.vocab_size)
    
    def generate(self,seq, n_words): # L x V
        # performs greedy search to extract and return words (one sequence).
        generated_words = []
        embed = self.embedding(seq).unsqueeze(1) # L x 1 x E
        hidden = None
        output_lstm, hidden = self.rnn(embed,hidden) # L x 1 x H
        output = output_lstm[-1] # 1 x H
        scores = self.scoring(output) # 1 x V
        _,current_word = torch.max(scores,dim=1) # 1 x 1
        generated_words.append(current_word)
        if n_words > 1:
            for i in range(n_words-1):
                embed = self.embedding(current_word).unsqueeze(0) # 1 x 1 x E
                output_lstm, hidden = self.rnn(embed,hidden) # 1 x 1 x H
                output = output_lstm[0] # 1 x H
                scores = self.scoring(output) # V
                _,current_word = torch.max(scores,dim=1) # 1
                generated_words.append(current_word)
        return torch.cat(generated_words,dim=0)
        
        

In [13]:
def train_epoch(model, optimizer, train_loader, val_loader):
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(DEVICE)
    before = time.time()
    print("training", len(train_loader), "number of batches")
    for batch_idx, (inputs,targets) in enumerate(train_loader):
        if batch_idx == 0:
            first_time = time.time()
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)
        outputs = model(inputs) # 3D
        loss = criterion(outputs.view(-1,outputs.size(2)),targets.view(-1)) # Loss of the flattened outputs
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if batch_idx == 0:
            print("Time elapsed", time.time() - first_time)
            
        if batch_idx % 100 == 0 and batch_idx != 0:
            after = time.time()
            print("Time: ", after - before)
            print("Loss per word: ", loss.item() / batch_idx)
            print("Perplexity: ", np.exp(loss.item() / batch_idx))
            after = before
    
    val_loss = 0
    batch_id=0
    for inputs,targets in val_loader:
        batch_id+=1
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)
        outputs = model(inputs)
        loss = criterion(outputs.view(-1,outputs.size(2)),targets.view(-1))
        val_loss+=loss.item()
    val_lpw = val_loss / batch_id
    print("\nValidation loss per word:",val_lpw)
    print("Validation perplexity :",np.exp(val_lpw),"\n")
    return val_lpw
    

In [14]:
model = CharLanguageModel(charcount,256,256,3)
model = model.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(),lr=0.001, weight_decay=1e-6)
split = 5000000
train_dataset = TextDataset(shakespeare_array[:split])
val_dataset = TextDataset(shakespeare_array[split:])
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64, collate_fn = collate)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=64, collate_fn = collate, drop_last=True)

In [15]:
for i in range(3):
    train_epoch(model, optimizer, train_loader, val_loader)

training 391 number of batches
Time elapsed 0.14118456840515137
Time:  5.480123996734619
Loss per word:  0.06285614013671875
Perplexity:  1.0648736356360111

Validation loss per word: 1.7341472437215406
Validation perplexity : 5.664095650645451 

training 391 number of batches
Time elapsed 0.10048866271972656
Time:  5.435199022293091
Loss per word:  0.03288536071777344
Perplexity:  1.0334320605406284

Validation loss per word: 1.5287958910298902
Validation perplexity : 4.612619380716595 

training 391 number of batches
Time elapsed 0.10138869285583496
Time:  5.433004140853882
Loss per word:  0.0290824031829834
Perplexity:  1.029509425833496

Validation loss per word: 1.4559050049892692
Validation perplexity : 4.288362699613226 



In [16]:
def generate(model, seed,nwords):
    seq = sh.map_corpus(seed, charmap)
    seq = torch.tensor(seq).to(DEVICE)
    out = model.generate(seq,nwords)
    return sh.to_text(out.cpu().detach().numpy(),chars)

In [17]:
print(generate(model, "To be, or not to be, that is the q",8))

uiet of 


In [18]:
print(generate(model, "Richard ", 1000))

the world to the world to the world
     The world to the world to the world to the world
     That the world to the world to the world to the world
     That the world to the world to the world to the world
     That the world to the world to the world to the world
     That the world to the world to the world to the world
     That the world to the world to the world to the world
     That the world to the world to the world to the world
     That the world to the world to the world to the world
     That the world to the world to the world to the world
     That the world to the world to the world to the world
     That the world to the world to the world to the world
     That the world to the world to the world to the world
     That the world to the world to the world to the world
     That the world to the world to the world to the world
     That the world to the world to the world to the world
     That the world to the world to the world to the world
     That the world to th

## Packed sequences

In [19]:
stop_character = charmap['\n']
space_character = charmap[" "]
lines = np.split(shakespeare_array, np.where(shakespeare_array == stop_character)[0]+1) # split the data in lines
shakespeare_lines = []
for s in lines:
    s_trimmed = np.trim_zeros(s-space_character)+space_character # remove space-only lines
    if len(s_trimmed)>1:
        shakespeare_lines.append(s)
for i in range(10):
    print(sh.to_text(shakespeare_lines[i],chars))
print(len(shakespeare_lines))

1609

 THE SONNETS

 by William Shakespeare

                      1

   From fairest creatures we desire increase,

   That thereby beauty's rose might never die,

   But as the riper should by time decease,

   His tender heir might bear his memory:

   But thou contracted to thine own bright eyes,

   Feed'st thy light's flame with self-substantial fuel,

114638


In [20]:
class LinesDataset(Dataset):
    def __init__(self,lines):
        self.lines=[torch.tensor(l) for l in lines]
    def __getitem__(self,i):
        line = self.lines[i]
        return line[:-1].to(DEVICE),line[1:].to(DEVICE)
    def __len__(self):
        return len(self.lines)

# collate fn lets you control the return value of each batch
# for packed_seqs, you want to return your data sorted by length
def collate_lines(seq_list):
    inputs,targets = zip(*seq_list)
    lens = [len(seq) for seq in inputs]
    seq_order = sorted(range(len(lens)), key=lens.__getitem__, reverse=True)
    inputs = [inputs[i] for i in seq_order]
    targets = [targets[i] for i in seq_order]
    return inputs,targets

In [21]:
# Model that takes packed sequences in training
class PackedLanguageModel(nn.Module):
    
    def __init__(self,vocab_size,embed_size,hidden_size, nlayers, stop):
        super(PackedLanguageModel,self).__init__()
        self.vocab_size=vocab_size
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.nlayers=nlayers
        self.embedding = nn.Embedding(vocab_size,embed_size)
        self.rnn = nn.LSTM(input_size = embed_size,hidden_size=hidden_size,num_layers=nlayers) # 1 layer, batch_size = False
        self.scoring = nn.Linear(hidden_size,vocab_size)
        self.stop = stop # stop line character (\n)
    
    def forward(self,seq_list): # list
        batch_size = len(seq_list)
        lens = [len(s) for s in seq_list] # lens of all lines (already sorted)
        bounds = [0]
        for l in lens:
            bounds.append(bounds[-1]+l) # bounds of all lines in the concatenated sequence
        seq_concat = torch.cat(seq_list) # concatenated sequence
        embed_concat = self.embedding(seq_concat) # concatenated embeddings
        embed_list = [embed_concat[bounds[i]:bounds[i+1]] for i in range(batch_size)] # embeddings per line
        packed_input = rnn.pack_sequence(embed_list) # packed version
        hidden = None
        output_packed,hidden = self.rnn(packed_input,hidden)
        output_padded, _ = rnn.pad_packed_sequence(output_packed) # unpacked output (padded)
        output_flatten = torch.cat([output_padded[:lens[i],i] for i in range(batch_size)]) # concatenated output
        scores_flatten = self.scoring(output_flatten) # concatenated logits
        return scores_flatten # return concatenated logits
    
    def generate(self,seq, n_words): # L x V
        generated_words = []
        embed = self.embedding(seq).unsqueeze(1) # L x 1 x E
        hidden = None
        output_lstm, hidden = self.rnn(embed,hidden) # L x 1 x H
        output = output_lstm[-1] # 1 x H
        scores = self.scoring(output) # 1 x V
        _,current_word = torch.max(scores,dim=1) # 1 x 1
        generated_words.append(current_word)
        if n_words > 1:
            for i in range(n_words-1):
                embed = self.embedding(current_word).unsqueeze(0) # 1 x 1 x E
                output_lstm, hidden = self.rnn(embed,hidden) # 1 x 1 x H
                output = output_lstm[0] # 1 x H
                scores = self.scoring(output) # V
                _,current_word = torch.max(scores,dim=1) # 1
                generated_words.append(current_word)
                if current_word[0].item()==self.stop: # If end of line
                    break
        return torch.cat(generated_words,dim=0)

In [43]:
def train_epoch_packed(model, optimizer, train_loader, val_loader):
    criterion = nn.CrossEntropyLoss(reduction="sum") # sum instead of averaging, to take into account the different lengths
    criterion = criterion.to(DEVICE)
    batch_id=0
    before = time.time()
    print("Training", len(train_loader), "number of batches")
    for inputs,targets in train_loader: # lists, presorted, preloaded on GPU
        batch_id+=1
        outputs = model(inputs)
        loss = criterion(outputs,torch.cat(targets)) # criterion of the concatenated output
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch_id % 100 == 0:
            after = time.time()
            nwords = np.sum(np.array([len(l) for l in inputs]))
            lpw = loss.item() / nwords
            print("Time elapsed: ", after - before)
            print("At batch",batch_id)
            print("Training loss per word:",lpw)
            print("Training perplexity :",np.exp(lpw))
            before = after
    
    val_loss = 0
    batch_id=0
    nwords = 0
    for inputs,targets in val_loader:
        nwords += np.sum(np.array([len(l) for l in inputs]))
        batch_id+=1
        outputs = model(inputs)
        loss = criterion(outputs,torch.cat(targets))
        val_loss+=loss.item()
    val_lpw = val_loss / nwords
    print("\nValidation loss per word:",val_lpw)
    print("Validation perplexity :",np.exp(val_lpw),"\n")
    return val_lpw

In [44]:
model = PackedLanguageModel(charcount,256,256,3, stop=stop_character)
model = model.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(),lr=0.001, weight_decay=1e-6)
split = 100000
train_dataset = LinesDataset(shakespeare_lines[:split])
val_dataset = LinesDataset(shakespeare_lines[split:])
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64, collate_fn = collate_lines)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=64, collate_fn = collate_lines, drop_last=True)

In [45]:
for i in range(20):
    train_epoch_packed(model, optimizer, train_loader, val_loader)

Training 1563 number of batches
Time elapsed:  8.641294002532959
At batch 100
Training loss per word: 2.742883137914904
Training perplexity : 15.531700639823379
Time elapsed:  8.498454809188843
At batch 200
Training loss per word: 2.3086599491908766
Training perplexity : 10.060933455568435
Time elapsed:  8.486337661743164
At batch 300
Training loss per word: 2.0011692090510236
Training perplexity : 7.39770050277287
Time elapsed:  8.501046419143677
At batch 400
Training loss per word: 1.8189667917986516
Training perplexity : 6.1654849282699615
Time elapsed:  8.46365237236023
At batch 500
Training loss per word: 1.8059468757545878
Training perplexity : 6.085731152505224
Time elapsed:  8.46724796295166
At batch 600
Training loss per word: 1.769131321431426
Training perplexity : 5.86575569132725
Time elapsed:  8.495564222335815
At batch 700
Training loss per word: 1.7194178916612746
Training perplexity : 5.581278609664651
Time elapsed:  8.486192464828491
At batch 800
Training loss per word

Time elapsed:  8.70168399810791
At batch 100
Training loss per word: 1.2418934268228337
Training perplexity : 3.4621626141972115
Time elapsed:  8.626674890518188
At batch 200
Training loss per word: 1.1778679438481208
Training perplexity : 3.2474430857409047
Time elapsed:  8.595263719558716
At batch 300
Training loss per word: 1.2199829023362427
Training perplexity : 3.3871298211194683
Time elapsed:  8.714223146438599
At batch 400
Training loss per word: 1.1951881521072014
Training perplexity : 3.3041794003361007
Time elapsed:  8.538036823272705
At batch 500
Training loss per word: 1.2012900288384414
Training perplexity : 3.324402733129015
Time elapsed:  8.50441312789917
At batch 600
Training loss per word: 1.229068253391473
Training perplexity : 3.418043301884305
Time elapsed:  8.59624719619751
At batch 700
Training loss per word: 1.2383571552498367
Training perplexity : 3.449941089015638
Time elapsed:  8.592359781265259
At batch 800
Training loss per word: 1.226766673653238
Training 

Time elapsed:  8.480568885803223
At batch 100
Training loss per word: 1.192062441325104
Training perplexity : 3.29386761538276
Time elapsed:  8.735455751419067
At batch 200
Training loss per word: 1.2190057841399802
Training perplexity : 3.38382181135961
Time elapsed:  8.494828462600708
At batch 300
Training loss per word: 1.1886550210449385
Training perplexity : 3.282663124116132
Time elapsed:  8.451841354370117
At batch 400
Training loss per word: 1.129733033466058
Training perplexity : 3.094830173744251
Time elapsed:  8.485805988311768
At batch 500
Training loss per word: 1.1904845342010437
Training perplexity : 3.2886742965716342
Time elapsed:  8.504512786865234
At batch 600
Training loss per word: 1.1841058020113673
Training perplexity : 3.2677634871033923
Time elapsed:  8.495721340179443
At batch 700
Training loss per word: 1.166028470087904
Training perplexity : 3.2092217751073555
Time elapsed:  8.496970415115356
At batch 800
Training loss per word: 1.1948274427925882
Training p

Time elapsed:  8.602280616760254
At batch 100
Training loss per word: 1.1490344318115462
Training perplexity : 3.1551449308356894
Time elapsed:  8.57916784286499
At batch 200
Training loss per word: 1.1811499477844756
Training perplexity : 3.258118715878707
Time elapsed:  8.544205904006958
At batch 300
Training loss per word: 1.108135528908009
Training perplexity : 3.0287061917481424
Time elapsed:  8.637030601501465
At batch 400
Training loss per word: 1.1722897607514646
Training perplexity : 3.2293786837248213
Time elapsed:  8.527188062667847
At batch 500
Training loss per word: 1.0946651855164904
Training perplexity : 2.988182029263203
Time elapsed:  8.53776741027832
At batch 600
Training loss per word: 1.1453837968438745
Training perplexity : 3.143647647384427
Time elapsed:  8.647263765335083
At batch 700
Training loss per word: 1.1851636104313987
Training perplexity : 3.2712219837256526
Time elapsed:  8.71172308921814
At batch 800
Training loss per word: 1.072200887678157
Training 

KeyboardInterrupt: 

In [47]:
torch.save(model, "trained_model.pt")

  "type " + obj.__name__ + ". It won't be checked "


In [55]:
print(generate(model, "To be, or not to be, that is the q",20))

uarrel



In [71]:
print(generate(model, "Richard ", 1000))

    The sea of the sea of the strength of the streets



In [84]:
print(generate(model, "Hello", 1000))

res.

