In [None]:
try:
  import google.colab
  !nvidia-smi -L
  google.colab.drive.mount('/content/drive/')               # For saving the model.
  #%cd ../content/drive/SaveLocation
except:
  IN_COLAB = False

GPU 0: Tesla T4 (UUID: GPU-5dfdc142-7fe8-17fe-b6f6-0b59ab820e47)
Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

from nltk import sent_tokenize
import numpy as np

import math

from tqdm import tqdm

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(0)

<torch._C.Generator at 0x7fc308049150>

In [None]:
import nltk
nltk.download("punkt")

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [None]:
with open("Ulysses - James Joyce.txt", "r") as f:
    corpus = f.read()


print(corpus[:1000])

[ 1 ]

Stately, plump Buck Mulligan came from the stairhead, bearing a bowl of
lather on which a mirror and a razor lay crossed. A yellow
dressinggown, ungirdled, was sustained gently behind him on the mild
morning air. He held the bowl aloft and intoned:

—_Introibo ad altare Dei_.

Halted, he peered down the dark winding stairs and called out coarsely:

—Come up, Kinch! Come up, you fearful jesuit!

Solemnly he came forward and mounted the round gunrest. He faced about
and blessed gravely thrice the tower, the surrounding land and the
awaking mountains. Then, catching sight of Stephen Dedalus, he bent
towards him and made rapid crosses in the air, gurgling in his throat
and shaking his head. Stephen Dedalus, displeased and sleepy, leaned
his arms on the top of the staircase and looked coldly at the shaking
gurgling face that blessed him, equine in its length, and at the light
untonsured hair, grained and hued like pale oak.

Buck Mulligan peeped an instant under the mirror and then c

In [None]:
# Tokenizer function returns clean data
def clean_text(text):
    import re
    text = re.sub(
        r'(Mr\.|Mrs\.|Ms\.)[a-zA-Z]*', '<TITLE>', text)
    text = re.sub(r'https?:\/\/\S+\b(?!\.)?', '<URL>', text)
    text = re.sub(r'@\w+', '<MENTION>', text)
    text = re.sub(r'#\w+', '<HASHTAG>', text)
    text = re.sub(r'\S*[\w\~\-]\@[\w\~\-]\S*', r'<EMAIL>', text)

    text = re.sub(r'([a-zA-Z]+)n[\'’]t', r'\1 not', text)
    text = re.sub(r'([iI])[\'’]m', r'\1 am', text)
    text = re.sub(r'([a-zA-Z]+)[\'’]s', r'\1 is', text)

    text = re.sub(r'\*{2,}.*?\*{2,}', '', text, flags=re.DOTALL)

    text = re.sub(r"_(.*?)_", r"\1", text)
    text = text.split()
    text = " ".join(text)
#     text = "<s> " + text + " <e>"

    text = re.sub(r'[^\w\s<>.]', ' ', text)
    text = text.lower()

    return text

In [None]:
# corpus = sent_tokenize(corpus)
corpus = clean_text(corpus)
# corpus[:1000]

In [None]:
text = []
corpus = corpus.split(". ")
for sent in corpus:
    split = sent.split(" ")
    split = [s for s in split if len(s) > 0]
    if len(split) > 0:
      sent = " ".join(split)
      text.append(sent)

text_len = len(text)
text[:100]

['1 stately plump buck mulligan came from the stairhead bearing a bowl of lather on which a mirror and a razor lay crossed',
 'a yellow dressinggown ungirdled was sustained gently behind him on the mild morning air',
 'he held the bowl aloft and intoned introibo ad altare dei',
 'halted he peered down the dark winding stairs and called out coarsely come up kinch come up you fearful jesuit solemnly he came forward and mounted the round gunrest',
 'he faced about and blessed gravely thrice the tower the surrounding land and the awaking mountains',
 'then catching sight of stephen dedalus he bent towards him and made rapid crosses in the air gurgling in his throat and shaking his head',
 'stephen dedalus displeased and sleepy leaned his arms on the top of the staircase and looked coldly at the shaking gurgling face that blessed him equine in its length and at the light untonsured hair grained and hued like pale oak',
 'buck mulligan peeped an instant under the mirror and then covered the 

In [None]:
train_split = int(0.7*text_len)
dev_split = int(0.15*text_len)

train_data = text[:train_split]
dev_data = text[train_split:train_split+dev_split]
test_data = text[train_split+dev_split:]


In [None]:
vocab = set()
for sent in train_data:
    vocab.update(sent.split())
len(vocab)

18691

In [None]:
vocab.add("<unk>")
vocab.add("<eos>")
len(vocab)
vocab = list(vocab)
vocab_dict = {}
for i, word in enumerate(vocab):
    vocab_dict[word] = i
    

In [None]:
word_count = {}
for sent in text:
    for word in sent.split():
        if word_count.get(word) is None:
            word_count[word] = 1
        else:
            word_count[word]+=1
           

In [None]:
word_count


{'1': 90,
 'stately': 3,
 'plump': 19,
 'buck': 116,
 'mulligan': 155,
 'came': 194,
 'from': 1080,
 'the': 14877,
 'stairhead': 3,
 'bearing': 25,
 'a': 6425,
 'bowl': 12,
 'of': 8115,
 'lather': 4,
 'on': 2036,
 'which': 511,
 'mirror': 36,
 'and': 7197,
 'razor': 5,
 'lay': 66,
 'crossed.': 1,
 'yellow': 47,
 'dressinggown': 1,
 'ungirdled': 2,
 'was': 2112,
 'sustained': 11,
 'gently': 27,
 'behind': 129,
 'him': 1312,
 'mild': 11,
 'morning': 88,
 'air.': 30,
 'he': 4090,
 'held': 60,
 'aloft': 4,
 'intoned': 1,
 'introibo': 2,
 'ad': 29,
 'altare': 2,
 'dei.': 1,
 'halted': 34,
 'peered': 11,
 'down': 412,
 'dark': 114,
 'winding': 11,
 'stairs': 14,
 'called': 83,
 'out': 824,
 'coarsely': 1,
 'come': 230,
 'up': 728,
 'kinch': 22,
 'you': 1873,
 'fearful': 2,
 'jesuit': 9,
 'solemnly': 6,
 'forward': 62,
 'mounted': 4,
 'round': 231,
 'gunrest.': 1,
 'faced': 4,
 'about': 509,
 'blessed': 46,
 'gravely': 12,
 'thrice': 4,
 'tower': 23,
 'surrounding': 2,
 'land': 72,
 'awaking'

In [None]:
vocab[5234]

'goggles'

In [None]:
def get_data(corpus, vocab, batch_size):
    data = []                                                   
    for sent in corpus:                                      
        tokens = sent.split()
        tokens.append("<eos>")
        temptokens = []
        for word in tokens:
          try:
            if word != "<eos>" and word_count[word] >= 3:
                    temptokens.append(vocab[word])
            else:
                temptokens.append(vocab["<unk>"]) 
          except KeyError:
              temptokens.append(vocab["<unk>"])
        data.extend(temptokens)   
    data = torch.LongTensor(data)                                 
    num_batches = data.shape[0] // batch_size 
    data = data[:num_batches * batch_size]                       
    data = data.view(batch_size, num_batches)          
    return data


In [None]:
train_data = get_data(train_data, vocab_dict, 128)
dev_data = get_data(dev_data, vocab_dict, 128)
test_data = get_data(test_data, vocab_dict, 128)


In [None]:
class LSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, dropout_rate, 
                tie_weights):
                
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers, 
                    dropout=dropout_rate, batch_first=True)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        
        if tie_weights:
            assert embedding_dim == hidden_dim, 'cannot tie, check dims'
            self.embedding.weight = self.fc.weight
        self.init_weights()

    def forward(self, src, hidden):
        embedding = self.dropout(self.embedding(src))
        output, hidden = self.lstm(embedding, hidden)          
        output = self.dropout(output) 
        prediction = self.fc(output)
        return prediction, hidden
    
    def init_weights(self):
        init_range_emb = 0.1
        init_range_other = 1/math.sqrt(self.hidden_dim)
        self.embedding.weight.data.uniform_(-init_range_emb, init_range_emb)
        self.fc.weight.data.uniform_(-init_range_other, init_range_other)
        self.fc.bias.data.zero_()
        for i in range(self.num_layers):
            self.lstm.all_weights[i][0] = torch.FloatTensor(self.embedding_dim,
                    self.hidden_dim).uniform_(-init_range_other, init_range_other) 
            self.lstm.all_weights[i][1] = torch.FloatTensor(self.hidden_dim, 
                    self.hidden_dim).uniform_(-init_range_other, init_range_other) 

    def init_hidden(self, batch_size, device):
        hidden = torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(device)
        cell = torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(device)
        return hidden, cell
    
    def detach_hidden(self, hidden):
        hidden, cell = hidden
        hidden = hidden.detach()
        cell = cell.detach()
        return hidden, cell


In [None]:
vocab_size = len(vocab_dict)
embedding_dim = 1150             # 400 in the paper
hidden_dim = 1150                # 1150 in the paper
num_layers = 2                   # 3 in the paper
dropout_rate = 0.5              
tie_weights = True                  
lr = 1e-3  

In [None]:
model = LSTM(vocab_size, embedding_dim, hidden_dim, num_layers, dropout_rate, tie_weights).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'The model has {num_params:,} trainable parameters')

The model has 42,694,043 trainable parameters


In [None]:
def get_batch(data, seq_len, num_batches, idx):
    src = data[:, idx:idx+seq_len]                   
    target = data[:, idx+1:idx+seq_len+1]             
    return src, target


In [None]:
def train(model, data, optimizer, criterion, batch_size, seq_len, clip, device):
    
    epoch_loss = 0
    model.train()
    # drop all batches that are not a multiple of seq_len
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches -1) % seq_len]
    num_batches = data.shape[-1]

    hidden = model.init_hidden(batch_size, device)
    
    for idx in tqdm(range(0, num_batches - 1, seq_len), desc='Training: ',leave=False):  # The last batch can't be a src
        optimizer.zero_grad()
        hidden = model.detach_hidden(hidden)

        src, target = get_batch(data, seq_len, num_batches, idx)
        src, target = src.to(device), target.to(device)
        batch_size = src.shape[0]
        prediction, hidden = model(src, hidden)               

        prediction = prediction.reshape(batch_size * seq_len, -1)   
        target = target.reshape(-1)
        loss = criterion(prediction, target)
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item() * seq_len
    return epoch_loss / num_batches


In [None]:
def evaluate(model, data, criterion, batch_size, seq_len, device):

    epoch_loss = 0
    model.eval()
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches -1) % seq_len]
    num_batches = data.shape[-1]

    hidden = model.init_hidden(batch_size, device)

    with torch.no_grad():
        for idx in range(0, num_batches - 1, seq_len):
            hidden = model.detach_hidden(hidden)
            src, target = get_batch(data, seq_len, num_batches, idx)
            src, target = src.to(device), target.to(device)
            batch_size= src.shape[0]

            prediction, hidden = model(src, hidden)
            prediction = prediction.reshape(batch_size * seq_len, -1)
            target = target.reshape(-1)

            loss = criterion(prediction, target)
            epoch_loss += loss.item() * seq_len
    return epoch_loss / num_batches

In [None]:
n_epochs = 50
seq_len = 50
clip = 0.25
saved = False

lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=0)

if saved:
    model.load_state_dict(torch.load('best-val-lstm_lm.pt',  map_location=device))
    test_loss = evaluate(model, test_data, criterion, 128, seq_len, device)
    print(f'Test Perplexity: {math.exp(test_loss):.3f}')
else:
    best_valid_loss = float('inf')

    for epoch in range(n_epochs):
        train_loss = train(model, train_data, optimizer, criterion, 
                    128, seq_len, clip, device)
        valid_loss = evaluate(model, dev_data, criterion, 128, 
                    seq_len, device)
        
        lr_scheduler.step(valid_loss)

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), 'best-val-lstm_lm.pt')

        print(f'\t{epoch} Train Perplexity: {math.exp(train_loss):.3f}')
        print(f'\t{epoch} Valid Perplexity: {math.exp(valid_loss):.3f}')



	0 Train Perplexity: 868.115
	0 Valid Perplexity: 377.119




	1 Train Perplexity: 451.417
	1 Valid Perplexity: 326.972




	2 Train Perplexity: 409.413
	2 Valid Perplexity: 309.833




	3 Train Perplexity: 374.239
	3 Valid Perplexity: 295.125




	4 Train Perplexity: 333.837
	4 Valid Perplexity: 269.115




	5 Train Perplexity: 293.663
	5 Valid Perplexity: 257.679




	6 Train Perplexity: 261.657
	6 Valid Perplexity: 246.698




	7 Train Perplexity: 232.933
	7 Valid Perplexity: 235.730




	8 Train Perplexity: 211.661
	8 Valid Perplexity: 229.571




	9 Train Perplexity: 193.443
	9 Valid Perplexity: 227.291




	10 Train Perplexity: 178.252
	10 Valid Perplexity: 223.472




	11 Train Perplexity: 165.532
	11 Valid Perplexity: 224.953




	12 Train Perplexity: 152.322
	12 Valid Perplexity: 223.156




	13 Train Perplexity: 144.399
	13 Valid Perplexity: 222.999




	14 Train Perplexity: 137.772
	14 Valid Perplexity: 224.640




	15 Train Perplexity: 131.011
	15 Valid Perplexity: 226.603




	16 Train Perplexity: 128.335
	16 Valid Perplexity: 225.374




	17 Train Perplexity: 125.184
	17 Valid Perplexity: 226.404




	18 Train Perplexity: 124.086
	18 Valid Perplexity: 226.154




	19 Train Perplexity: 123.808
	19 Valid Perplexity: 226.318




	20 Train Perplexity: 123.272
	20 Valid Perplexity: 226.449




	21 Train Perplexity: 123.390
	21 Valid Perplexity: 226.491




	22 Train Perplexity: 123.191
	22 Valid Perplexity: 226.502




	23 Train Perplexity: 123.101
	23 Valid Perplexity: 226.506




	24 Train Perplexity: 123.200
	24 Valid Perplexity: 226.509




	25 Train Perplexity: 123.140
	25 Valid Perplexity: 226.509




	26 Train Perplexity: 122.739
	26 Valid Perplexity: 226.510




	27 Train Perplexity: 123.175
	27 Valid Perplexity: 226.510




	28 Train Perplexity: 123.036
	28 Valid Perplexity: 226.510




	29 Train Perplexity: 122.759
	29 Valid Perplexity: 226.511




	30 Train Perplexity: 123.147
	30 Valid Perplexity: 226.511




	31 Train Perplexity: 123.191
	31 Valid Perplexity: 226.511




	32 Train Perplexity: 123.037
	32 Valid Perplexity: 226.511




	33 Train Perplexity: 122.976
	33 Valid Perplexity: 226.511




	34 Train Perplexity: 123.399
	34 Valid Perplexity: 226.511




	35 Train Perplexity: 123.223
	35 Valid Perplexity: 226.511




	36 Train Perplexity: 123.040
	36 Valid Perplexity: 226.511




	37 Train Perplexity: 123.103
	37 Valid Perplexity: 226.511




	38 Train Perplexity: 122.857
	38 Valid Perplexity: 226.511




	39 Train Perplexity: 122.861
	39 Valid Perplexity: 226.511




	40 Train Perplexity: 123.147
	40 Valid Perplexity: 226.511




	41 Train Perplexity: 123.187
	41 Valid Perplexity: 226.511




	42 Train Perplexity: 123.188
	42 Valid Perplexity: 226.511




	43 Train Perplexity: 123.258
	43 Valid Perplexity: 226.511




	44 Train Perplexity: 123.092
	44 Valid Perplexity: 226.511




	45 Train Perplexity: 123.147
	45 Valid Perplexity: 226.511




	46 Train Perplexity: 123.017
	46 Valid Perplexity: 226.511




	47 Train Perplexity: 123.018
	47 Valid Perplexity: 226.511




	48 Train Perplexity: 123.129
	48 Valid Perplexity: 226.511




	49 Train Perplexity: 123.288
	49 Valid Perplexity: 226.511


In [None]:
n_epochs = 50
seq_len = 50
clip = 0.25
saved = True

lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=0)

if saved:
    model.load_state_dict(torch.load('best-val-lstm_lm.pt',  map_location=device))
    test_loss = evaluate(model, test_data, criterion, 128, seq_len, device)
    print(f'Test Perplexity: {math.exp(test_loss):.3f}')
    import json

    with open(f"ptvocab{math.exp(test_loss):.3f}.json", "w") as f:
        json.dump(vocab_dict, f)

    with open(f"word_count{math.exp(test_loss):.3f}.json", "w") as f:
        json.dump(word_count, f)


else:
    best_valid_loss = float('inf')

    for epoch in range(n_epochs):
        train_loss = train(model, train_data, optimizer, criterion, 
                    128, seq_len, clip, device)
        valid_loss = evaluate(model, dev_data, criterion, 128, 
                    seq_len, device)
        
        lr_scheduler.step(valid_loss)

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), 'best-val-lstm_lm.pt')

        print(f'\tTrain Perplexity: {math.exp(train_loss):.3f}')
        print(f'\tValid Perplexity: {math.exp(valid_loss):.3f}')
        


Test Perplexity: 277.126


In [None]:
def calculate_perp(data, vocab_dict):
    data = clean_text(data)
    data = data.split(". ")
    data = get_data(data, vocab_dict, 16)
    print(data)
    model.load_state_dict(torch.load('best-val-lstm_lm.pt',  map_location=device))
    test_loss = evaluate(model, data, criterion, 16, 1, device)
    print(test_loss)
    print(f'Test Perplexity: {math.exp(test_loss):.3f}')

In [None]:
with open("ptvocab277.126.json", "r") as f:
  import json
  vocab_dict = json.load(f)

with open("word_count277.126.json", "r") as f:
  import json
  word_count = json.load(f)

calculate_perp("I am is Bhanuj. Bhanuj is, good. one new good word. We're careful about orange ping pong balls because people might think they're fruit. She cried diamonds. Getting up at dawn is for the birds. The beach was crowded with snow leopards.", vocab_dict)

tensor([[11169, 16356, 16803],
        [13031, 13031, 13031],
        [16803,  9999, 13031],
        [15553,  2516,  9999],
        [ 7663, 13031, 10178],
        [ 8116,  1458, 12267],
        [ 4128, 13031, 13031],
        [ 7553,  6970,  4605],
        [15353,  4200,  7865],
        [ 8116, 10499, 13031],
        [18392, 14781, 12606],
        [13031,  8733,  8400],
        [12901,  7064, 16803],
        [18403,  5319,  5286],
        [13031,  5319,  6832],
        [10736, 13972,  6390]])
3.699110190073649
Test Perplexity: 40.411
