# Language modeling with LSTM in Pytorch

based on
https://www.kdnuggets.com/2020/07/pytorch-lstm-text-generation-tutorial.html

In [1]:
import torch
from collections import Counter
from nltk.tokenize import word_tokenize

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

SEQUENCE_LENGTH = 15

class PoohDataset(torch.utils.data.Dataset):
    def __init__(self, sequence_length, device):
        txt = open('../data/pooh1.txt').read()
        txt += open('../data/pooh2.txt').read()
        
        self.words = word_tokenize(txt.lower())
        #self.words = word_tokenize(txt)
        
        self.uniq_words = self.get_uniq_words()

        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        self.words_indexes = [self.word_to_index[w] for w in self.words]
        self.sequence_length = sequence_length
        self.device = device


    def get_uniq_words(self):
        #return list(set(self.words))
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        return len(self.words_indexes) - self.sequence_length

    def __getitem__(self, index):
        return (
            torch.tensor(self.words_indexes[index:index+self.sequence_length], device=self.device),
            torch.tensor(self.words_indexes[index+1:index+self.sequence_length+1], device=self.device)
        )
        
pooh_dataset = PoohDataset(SEQUENCE_LENGTH, device)        

In [2]:
from torch.utils.data import DataLoader
dataloader = DataLoader(pooh_dataset, batch_size=512)

for x,y in dataloader:
    print (x.shape, y.shape)
    break

torch.Size([512, 15]) torch.Size([512, 15])


In [3]:
from torch import nn, optim

class LSTMModel(nn.Module):
    def __init__(self, dataset, device):
        super(LSTMModel, self).__init__()
        self.lstm_size = 512
        self.embedding_dim = 100
        self.num_layers = 2
        self.device = device
        

        n_vocab = len(dataset.uniq_words)
        
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=0.2,
            batch_first=True
        )
        self.fc = nn.Linear(self.lstm_size, n_vocab)

    def forward(self, x, prev_state):
        embed = self.embedding(x)
        
        #Remember: state is a pair (h,c)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state_batched(self, batch_size):
        return (torch.zeros(self.num_layers, batch_size, self.lstm_size).to(self.device),
                torch.zeros(self.num_layers, batch_size, self.lstm_size).to(self.device))
        
    def init_state_unbatched(self):
        return (torch.zeros(self.num_layers, self.lstm_size).to(self.device),
                torch.zeros(self.num_layers, self.lstm_size).to(self.device))
        
        
model = LSTMModel(pooh_dataset, device) 
model.to(device)

LSTMModel(
  (embedding): Embedding(2791, 100)
  (lstm): LSTM(100, 512, num_layers=2, batch_first=True, dropout=0.2)
  (fc): Linear(in_features=512, out_features=2791, bias=True)
)

In [4]:
def predict(dataset, model, text, next_words=15):
    model.eval()

    words = word_tokenize(text.lower())
    state_h, state_c = model.init_state_unbatched()

    for i in range(0, next_words):
        x = torch.tensor([dataset.word_to_index[w] for w in words[i:]])
        x = x.to(device)
        
        y_pred, (state_h, state_c) = model(x, (state_h, state_c))

        last_word_logits = y_pred[-1]
        p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().cpu().numpy()
        word_index = np.random.choice(len(last_word_logits), p=p)
        words.append(dataset.index_to_word[word_index])

    return ' '.join(words)

In [5]:
import numpy as np


batch_size = 512
max_epochs = 30

def train(dataset, model):
    model.train()

    dataloader = DataLoader(dataset, batch_size=batch_size)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(max_epochs):
        state_h, state_c = model.init_state_batched(batch_size)
        
        for batch, (x, y) in enumerate(dataloader):
            if x.shape[0] != batch_size:
                continue # better option: change the size of state_h, state_c, according to the last batch size
                
            optimizer.zero_grad()

            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            
            loss = criterion(y_pred.transpose(1, 2), y)
                        
            state_h = state_h.detach()
            state_c = state_c.detach()
            
            loss.backward()
            optimizer.step()

        print({ 'epoch': epoch, 'batch': batch, 'loss': loss.item() })
        print (predict(pooh_dataset, model, 'and then'))
        model.train()
            
train(pooh_dataset, model)       

{'epoch': 0, 'batch': 128, 'loss': 5.454943656921387}
and then first 5. on rabbit had away say haven't to blue-bells that first or that him
{'epoch': 1, 'batch': 128, 'loss': 5.040148735046387}
and then you working owl before the reply to on to . '' if had come were
{'epoch': 2, 'batch': 128, 'loss': 4.6826066970825195}
and then written that he had anything a letters best things and did happy the room .
{'epoch': 3, 'batch': 128, 'loss': 4.476163387298584}
and then was piglet to answer , '' '' explained a point with as why it is
{'epoch': 4, 'batch': 128, 'loss': 4.31948709487915}
and then 's finished yet , his which once , which i he knew he said nothing
{'epoch': 5, 'batch': 128, 'loss': 4.20013952255249}
and then '' chapter face used . `` i wish , '' said pooh , `` if
{'epoch': 6, 'batch': 128, 'loss': 4.091150760650635}
and then and piglet knocked to his . `` nobody is , '' said pooh , only
{'epoch': 7, 'batch': 128, 'loss': 3.954676389694214}
and then piglet ; `` i 'll want the fr

In [6]:
torch.save(model.state_dict(), 'pooh_2x512_30ep_2023.model')

In [7]:
def predict2(dataset, model, text, next_words=15):
    model.eval()

    words = word_tokenize(text.lower())
    state_h, state_c = model.init_state_unbatched()

    for i in range(0, next_words):
        x = torch.tensor([dataset.word_to_index[w] for w in words[i:]])
        x = x.to(device)
        
        y_pred, (state_h, state_c) = model(x, (state_h, state_c))

        last_word_logits = y_pred[-1] / 0.001
        p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().cpu().numpy()
        word_index = np.random.choice(len(last_word_logits), p=p)
        words.append(dataset.index_to_word[word_index])

    return ' '.join(words)


In [9]:
speakers = ['pooh', 'piglet', 'christopher robin', 'rabbit', 'owl', 'tigger', 'eeyore']

for s in speakers:
    prompt = 'i am tired ' + s 
    for i in range(1):
        print (predict2(pooh_dataset, model, prompt, 50))
    print ()

i am tired pooh , who was n't quite sure about a little smackerel of something , and then he thought that perhaps he had been a little while ago , and he was so excited that he had been opening and said `` yes , yes , '' said pooh , `` i

i am tired piglet , `` i think it was a very particular animal that it was a very good thing to do , but it was a very good thing to do , but it was a very good thing to do , but it was a good thing to do , but

i am tired christopher robin , who was just beginning to think , '' said pooh , `` i am not quite sure , '' said pooh , `` i am not quite sure , '' said pooh , `` i am not quite sure , '' said pooh , `` i am not quite

i am tired rabbit , `` i shall have to go on , `` i am not sure , '' said pooh , `` i think it 's a very funny thing , '' said pooh , `` i think it 's a very funny thing , '' said pooh , `` i think

i am tired owl , `` i think it 's a very funny thing , '' said pooh , `` i think it 's a very funny thing , '' said pooh , `` i thi

In [26]:
E = model.embedding.weight.cpu().detach().numpy()

In [27]:
len(E)

2791

In [28]:
def length(a):
    return a.dot(a) ** 0.5

def cos(a, b):
    return a.dot(b) / (length(a) * length(b))

def best_line(a, K):
    res = [(cos(E[a], E[b]), b) for b in range(len(E)) if b != a]
    res.sort(reverse=True)
    return [pooh_dataset.index_to_word[i] for (v, i) in res[:K]]

special_words = 'eeyore tigger pooh piglet long honey door silent path house forest thought brain little'.split()

for w in special_words:
    print ('WORD', w)
    for b in best_line(pooh_dataset.word_to_index[w], 5):
        print ('    ', b)
    print ()



WORD eeyore
     exciting
     branches
     slowly
     fir-cone
     spikes

WORD tigger
     good-morning
     portrait
     fir-cones
     wanting
     hill

WORD pooh
     gloomily
     walking
     crying
     jaws
     coming

WORD piglet
     pop
     5.
     laugh
     up
     tiddely-poms

WORD long
     turned
     smile
     where
     ing
     hummy

WORD honey
     deed
     clapped
     pin
     organize
     surface

WORD door
     revolving
     notice-board
     ached
     gon
     gorse-bush

WORD silent
     w-what
     handkerchief
     deception
     roo's
     hurriedly

WORD path
     downstairs
     living
     dishes
     weight
     temporary

WORD house
     wednesday
     helpfully
     misses
     meekly
     buffeted

WORD forest
     smallest-of-all
     glowed
     explains
     own
     sends

WORD thought
     seat
     'help
     'all
     singers
     holding

WORD brain
     dorsal
     forward
     usual
     sparkle
     'help

WORD little
     p

In [40]:
from collections import defaultdict as dd

vowels = set("aoiuye'")
def devowelize(s):
    rv = ''.join(a for a in s if a not in vowels)
    if rv:
        return rv
    return '_'

representation = dd(set)

for w in pooh_dataset.words:
    r = devowelize(w)
    representation[r].add(w)
    
hard_words = set()
for r, ws in representation.items():
    if len(ws) > 1:
        hard_words.update(ws)
        
print (len(hard_words))    

893


In [43]:
dataloader = DataLoader(pooh_dataset, batch_size=batch_size)
for batch, (x, y) in enumerate(dataloader):
    print (x.shape, y.shape)
    break

RuntimeError: CUDA error: unspecified launch failure
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.