# Assignment 5

**Submission deadlines**:

* last lab before 27.06.2022 

**Points:** Aim to get 12 out of 15+ possible points

All needed data files are on Drive: <https://drive.google.com/drive/folders/1uufpGn46Mwv4oBwajIeOj4rvAK96iaS-?usp=sharing> (or will be soon :) )

## Task 1 (5 points)

Consider the vowel reconstruction task -- i.e. inserting missing vowels (aeuioy) to obtain proper English text. For instance for the input sentence:

<pre>
h m gd smbd hs stln ll m vwls
</pre>

the best result is

<pre>
oh my god somebody has stolen all my vowels
</pre>

In this task both dev and test data come from the two books about Winnie-the-Pooh. You have to train two RNN language models on *pooh-train.txt*. For the first model use the code below; for the second, choose different hyperparameters (e.g. different dropout, a smaller number of units or layers, or any other modification you want).

The code below is based on
https://www.kdnuggets.com/2020/07/pytorch-lstm-text-generation-tutorial.html

In [1]:
from collections import Counter
from collections import defaultdict as dd
from typing import List

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import DataLoader

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


In [3]:
SEQUENCE_LENGTH = 15


class PoohDataset(torch.utils.data.Dataset):
    """Next-word prediction windows over the (pre-tokenized) Pooh training text.

    Item ``i`` is a pair of index tensors of length ``sequence_length``:
    ``words[i : i+L]`` (inputs) and the same window shifted by one word,
    ``words[i+1 : i+L+1]`` (targets).

    Args:
        sequence_length: window length ``L``.
        device: device on which ``__getitem__`` creates its tensors.
        path: training text, lower-cased and split on whitespace.
    """

    def __init__(self, sequence_length, device, path='data/pooh_train.txt'):
        # Close the file deterministically and decode explicitly instead of
        # relying on the platform's default encoding.
        with open(path, encoding='utf-8') as f:
            txt = f.read()

        self.words = txt.lower().split()  # The text is already tokenized

        # Vocabulary ordered by descending frequency (index 0 = most common word).
        self.uniq_words = self.get_uniq_words()

        self.index_to_word = dict(enumerate(self.uniq_words))
        self.word_to_index = {word: index for index, word in self.index_to_word.items()}

        self.words_indexes = [self.word_to_index[w] for w in self.words]
        self.sequence_length = sequence_length
        self.device = device

    def get_uniq_words(self):
        """Return the distinct words, most frequent first.

        Ties keep first-occurrence order: ``Counter`` preserves insertion order
        and ``sorted`` is stable even with ``reverse=True``.
        """
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        # Each item needs L input words plus one more for the shifted target.
        # Clamp at 0: a negative __len__ makes len() raise ValueError.
        return max(len(self.words_indexes) - self.sequence_length, 0)

    def __getitem__(self, index):
        end = index + self.sequence_length
        return (
            torch.tensor(self.words_indexes[index:end], device=self.device),
            torch.tensor(self.words_indexes[index + 1:end + 1], device=self.device),
        )

In [4]:
class LSTMModel(nn.Module):
    """Word-level LSTM language model: embedding -> stacked LSTM -> vocabulary logits.

    Args:
        dataset: object exposing ``uniq_words``; its length is the vocabulary size.
        device: device on which ``init_state`` allocates the zero state.
        lstm_size, embedding_dim, num_layers, dropout: architecture
            hyper-parameters; the defaults reproduce the original 2x512 model.

    NOTE(review): ``nn.LSTM`` is built without ``batch_first``, so it treats
    dim 0 of its input as time and dim 1 as batch. ``train`` feeds DataLoader
    batches shaped ``(batch, seq)``, so the two roles are swapped there and
    ``init_state`` must receive the size of dim 1 of the input. Kept as-is so
    existing callers keep working.
    """

    def __init__(self, dataset, device, lstm_size=512, embedding_dim=100,
                 num_layers=2, dropout=0.2):
        super().__init__()
        self.lstm_size = lstm_size
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers
        self.device = device

        n_vocab = len(dataset.uniq_words)
        # Submodules are created in the original order, so parameter
        # initialisation consumes the RNG exactly as before.
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=dropout,
        )
        self.fc = nn.Linear(self.lstm_size, n_vocab)

    def forward(self, x, prev_state):
        """Return ``(logits, state)``; logits have ``x``'s shape plus a trailing vocab dim."""
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        return self.fc(output), state

    def init_state(self, sequence_length):
        """Return zero ``(h, c)``, each of shape ``(num_layers, sequence_length, lstm_size)``.

        ``sequence_length`` must equal dim 1 of the input passed to ``forward``
        (the LSTM's batch axis; see the class note).
        """
        shape = (self.num_layers, sequence_length, self.lstm_size)
        # Allocate directly on the target device instead of CPU-then-copy.
        return (torch.zeros(shape, device=self.device),
                torch.zeros(shape, device=self.device))

In [5]:
# Build the training windows and the baseline 2x512 model, then move the
# model to the selected device. `model.to(device)` is the cell's last
# expression, so the notebook displays the module summary.
pooh_dataset = PoohDataset(SEQUENCE_LENGTH, device)
model = LSTMModel(pooh_dataset, device) 
model.to(device)

LSTMModel(
  (embedding): Embedding(2548, 100)
  (lstm): LSTM(100, 512, num_layers=2, dropout=0.2)
  (fc): Linear(in_features=512, out_features=2548, bias=True)
)

In [6]:
def train(dataset, model, batch_size=None, max_epochs=None, lr=0.001):
    """Train *model* as a next-word predictor on *dataset* (Adam + cross-entropy).

    Batches are taken in order (no shuffling) and the recurrent state is
    carried from one batch to the next, detached so gradients stop at batch
    borders. After every epoch a dict with the epoch number, the index of the
    last batch and that batch's loss is printed.

    Args:
        dataset: map-style dataset of ``(x, y)`` index-tensor pairs that also
            exposes ``sequence_length`` (e.g. ``PoohDataset``).
        model: module with ``forward(x, state)`` and ``init_state(n)``.
        batch_size: DataLoader batch size; ``None`` falls back to the
            module-level ``batch_size`` (the original notebook behaviour).
        max_epochs: number of epochs; ``None`` falls back to the module-level
            ``max_epochs``.
        lr: Adam learning rate.
    """
    # Keep the old implicit-global behaviour for existing notebook cells.
    if batch_size is None:
        batch_size = globals()['batch_size']
    if max_epochs is None:
        max_epochs = globals()['max_epochs']

    model.train()

    dataloader = DataLoader(dataset, batch_size=batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(max_epochs):
        # With LSTMModel (no batch_first), dim 1 of each (batch, seq) input is
        # the LSTM's batch axis, so size the state from the dataset's own
        # window length rather than a separate global.
        state_h, state_c = model.init_state(dataset.sequence_length)
        batch, loss = None, None  # stay defined even if the loader is empty

        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()

            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            # CrossEntropyLoss expects the class dimension second.
            loss = criterion(y_pred.transpose(1, 2), y)

            # Keep the state's values but cut its autograd history.
            state_h = state_h.detach()
            state_c = state_c.detach()

            loss.backward()
            optimizer.step()

        print({'epoch': epoch, 'batch': batch,
               'loss': None if loss is None else loss.item()})

In [7]:
# Training schedule read by `train` when it is not given explicit values.
batch_size, max_epochs = 512, 30

In [8]:
# Reuse a cached checkpoint when present, otherwise train from scratch and cache it.
# map_location remaps the saved tensors onto the current device, so a checkpoint
# written on a GPU machine still loads on a CPU-only one.
filename = "pooh_2x512_30ep.model"
try:
    model.load_state_dict(torch.load(filename, map_location=device))
except FileNotFoundError:
    train(pooh_dataset, model)
    torch.save(model.state_dict(), filename)

In [9]:
def predict(dataset, model, text, next_words=15):
    """Extend *text* by sampling *next_words* words from *model*.

    At each step the model is fed the most recent ``len(text.split())`` words,
    and the next word is drawn from the softmax over the logits at the last
    position.

    Args:
        dataset: provides ``word_to_index`` and ``index_to_word``.
        model: module with ``forward(x, state)`` and ``init_state(n)``.
        text: whitespace-separated prompt; every word must be in the vocabulary.
        next_words: number of words to generate.

    Returns:
        The prompt words followed by the generated words, joined by spaces.

    Raises:
        ValueError: if *text* contains no words.
        KeyError: if a prompt word is not in ``dataset.word_to_index``.
    """
    model.eval()
    # Build inputs where the model's weights live instead of relying on a global.
    model_device = next(model.parameters()).device

    words = text.split()
    if not words:
        raise ValueError("text must contain at least one word")
    unknown = [w for w in words if w not in dataset.word_to_index]
    if unknown:
        raise KeyError(f"prompt words not in vocabulary: {unknown}")

    state_h, state_c = model.init_state(len(words))

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for i in range(next_words):
            # Sliding window of the last len(prompt) words (one word is appended per step).
            x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]], device=model_device)

            y_pred, (state_h, state_c) = model(x, (state_h, state_c))

            last_word_logits = y_pred[0][-1]
            p = F.softmax(last_word_logits, dim=0).cpu().numpy()
            word_index = np.random.choice(len(last_word_logits), p=p)
            words.append(dataset.index_to_word[word_index])

    return ' '.join(words)

In [10]:
# Sample one 50-word continuation per character from an "in the morning <name>" prompt.
speakers = ['pooh', 'piglet', 'christopher robin', 'rabbit', 'owl', 'tigger', 'eeyore']
for speaker in speakers:
    speaker_prompt = 'in the morning ' + speaker
    print(predict(pooh_dataset, model, speaker_prompt, 50))

in the morning pooh and planted to piglet again . and then he had a friendly picked two and helpful . it silence in which kanga and roo are about .... small just here it is n't . it is n't sense , '' he added , `` because where you know what what
in the morning piglet generally a present , and there was those , `` here there is too late much better where it . '' `` but let are me , '' said piglet . there was a long silence . `` now then , '' said piglet again . `` it 's a
in the morning christopher robin i could think of something . '' `` why ? '' said rabbit . `` now , '' said piglet . `` what sort of a lesson , piglet ? '' `` well , '' said pooh , `` because i expect if i say 'now ! ' pooh can
in the morning rabbit ca n't seem to this there ? '' `` just for a moment . '' `` perhaps owl , '' said piglet comfortingly . `` hallo and rabbit and rabbit ! '' `` so do i ? '' `` yes . '' everybody gave ; and two days later rabbit
in the morning owl own jumped if the forest begin really can r

In [11]:
vowels = set("aoiuye'")


def devowelize(s):
    """Return the consonant skeleton of *s* (vowels, 'y' and apostrophes removed).

    Words with no consonants at all map to ``'_'`` so every word gets a
    non-empty key.
    """
    skeleton = ''.join(filter(lambda ch: ch not in vowels, s))
    return skeleton or '_'  # Symbol for words without consonants

In [12]:
# Map every consonant skeleton to the set of known words that produce it.
# Use a context manager so the file is closed, and decode explicitly.
with open('data/pooh_words.txt', encoding='utf-8') as f:
    pooh_words = set(f.read().split())

representation = dd(set)
for w in pooh_words:
    representation[devowelize(w)].add(w)

# "Hard" words share their skeleton with at least one other word, so the
# consonants alone do not determine them.
hard_words = {w for ws in representation.values() if len(ws) > 1 for w in ws}

print(len(hard_words))

863


In [13]:
def _reconstruct(words: List[str], model, start: str, T: float):
    """Sample one vowel reconstruction of *words*, conditioned on *start*.

    Each devowelized token is replaced by a word sampled from the model's
    temperature-scaled next-word distribution. Sampling is restricted to the
    in-vocabulary words that share the token's consonant skeleton (per
    ``representation``); if there are none, the whole vocabulary is used.
    The sampled word is fed back as the next input.

    Args:
        words: devowelized tokens that follow *start*.
        model: module with ``forward(x, state)`` and ``init_state(n)``.
        start: first, already chosen word; must be in the training vocabulary.
        T: softmax temperature; lower values make sampling greedier.

    Returns:
        ``(sentence, plog)``: *start* plus the sampled words joined by spaces,
        and the summed log-probability of the sampled choices (the
        probability of *start* itself is not included).
    """
    model.eval()
    word_to_index = pooh_dataset.word_to_index
    corrected = []
    state = model.init_state(1)
    x = torch.tensor([[word_to_index[start]]], device=device)
    plog = 0.0
    # Pure inference. Without no_grad, the carried (undetached) state would
    # chain one autograd graph across the entire input.
    with torch.no_grad():
        for word in words:
            # .get() avoids inserting empty sets into the global defaultdict.
            candidates = [word_to_index[k] for k in representation.get(word, ())
                          if k in word_to_index]
            y_pred, (new_h, new_c) = model(x, state)
            logits = y_pred[:, -1:].flatten()
            state = (new_h[:, -1:, ...].contiguous(), new_c[:, -1:, ...].contiguous())
            if candidates:
                possible_idxs = torch.tensor(candidates, dtype=torch.long, device=device)
                preds = F.softmax(logits[possible_idxs] / T, -1)
            else:
                preds = F.softmax(logits / T, -1)

            selected = torch.multinomial(preds, 1)

            # Map the sampled position back to a vocabulary index.
            x = (possible_idxs[selected] if candidates else selected).reshape(1, -1)

            corrected.append(pooh_dataset.index_to_word[x.item()])
            plog += torch.log(preds[selected]).item()
    return " ".join([start] + corrected), plog

def reconstruct(text: List[str], model, T: float, n_iter: int = 10):
    """Reconstruct devowelized *text* by best-of-N sampling with ``_reconstruct``.

    For every in-vocabulary word that matches the first token's consonant
    skeleton, draw *n_iter* samples and keep the one with the highest summed
    log-probability.

    Args:
        text: devowelized tokens.
        model: language model passed through to ``_reconstruct``.
        T: sampling temperature.
        n_iter: samples drawn per candidate start word.

    Returns:
        The best reconstruction as a space-separated string ("" for empty *text*).

    Raises:
        ValueError: if no vocabulary word matches the first token.
    """
    if not text:
        return ""  # a str, like every other return, so callers can .split()
    model.eval()
    # Only words the model knows can seed it (others would raise KeyError);
    # .get() avoids inserting into the global defaultdict.
    starts = [w for w in representation.get(text[0], ())
              if w in pooh_dataset.word_to_index]
    if not starts:
        raise ValueError(f"no in-vocabulary word matches first token {text[0]!r}")
    best_correction, max_plog = None, float("-inf")
    for start in starts:
        for _ in range(n_iter):
            correction, plog = _reconstruct(text[1:], model, start, T)
            if best_correction is None or plog > max_plog:
                max_plog = plog
                best_correction = correction
    return best_correction

In [14]:
# Gold (original, vowelized) test tokens; decode explicitly rather than
# relying on the platform's default encoding.
with open("data/pooh_test.txt", "rt", encoding="utf-8") as f:
    test_tokens = f.read().split()

In [15]:
test_input = [devowelize(token) for token in test_tokens]  # model input: consonant skeletons only

In [16]:
# NOTE(review): the task suggests a low temperature; T=1 leaves the softmax unscaled.
reconstructed = reconstruct(test_input, model, T=1)

You can assume that only words from *pooh_words.txt* can occur in the reconstructed text. For decoding you have two options (choose one, or implement both and get a **+1** bonus point):

1. Sample reconstructed text several times (with quite a low temperature), choose the most likely result.
2. Perform beam search.

Of course in the sampling procedure you should consider only words matching the given consonants.

Report the accuracy of your methods (for both language models). The accuracy should be computed by the following function; it should be *greater than 0.25*.

In [17]:
def accuracy(original_sequence, reconstructed_sequence):
    """Return the fraction of positions where the reconstruction equals the original.

    Positions are compared pairwise up to the shorter sequence, and the count
    is divided by the length of *original_sequence*.
    """
    hits = sum(1 for expected, got in zip(original_sequence, reconstructed_sequence)
               if expected == got)
    return hits / len(original_sequence)

In [18]:
# Fraction of test tokens recovered by the baseline model (shown as the cell result).
accuracy(test_tokens, reconstructed.split())

0.7992090827911723

In [25]:
class MyLSTMModel(nn.Module):
    """Second word-level LSTM language model: a deeper variant of the baseline.

    Same architecture as the baseline (embedding -> stacked LSTM -> vocabulary
    logits), but the defaults use 3 LSTM layers and dropout 0.3.

    Args:
        dataset: object exposing ``uniq_words``; its length is the vocabulary size.
        device: device on which ``init_state`` allocates the zero state.
        lstm_size, embedding_dim, num_layers, dropout: architecture
            hyper-parameters; the defaults reproduce the original 3x512 model.

    NOTE(review): ``nn.LSTM`` is built without ``batch_first``, so it treats
    dim 0 of its input as time and dim 1 as batch. ``init_state`` must
    receive the size of dim 1 of the input passed to ``forward``.
    """

    def __init__(self, dataset, device, lstm_size=512, embedding_dim=100,
                 num_layers=3, dropout=0.3):
        super().__init__()
        self.lstm_size = lstm_size
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers
        self.device = device

        n_vocab = len(dataset.uniq_words)
        # Submodules are created in the original order, so parameter
        # initialisation consumes the RNG exactly as before.
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=dropout,
        )
        self.fc = nn.Linear(self.lstm_size, n_vocab)

    def forward(self, x, prev_state):
        """Return ``(logits, state)``; logits have ``x``'s shape plus a trailing vocab dim."""
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        return self.fc(output), state

    def init_state(self, sequence_length):
        """Return zero ``(h, c)``, each of shape ``(num_layers, sequence_length, lstm_size)``."""
        shape = (self.num_layers, sequence_length, self.lstm_size)
        # Allocate directly on the target device instead of CPU-then-copy.
        return (torch.zeros(shape, device=self.device),
                torch.zeros(shape, device=self.device))

In [26]:
# Second model: 3 LSTM layers and dropout 0.3 (vs. 2 layers / 0.2 for the baseline).
my_model = MyLSTMModel(pooh_dataset, device)
my_model.to(device)

# Reuse a cached checkpoint when present, otherwise train and cache it.
# map_location remaps the saved tensors onto the current device, so a checkpoint
# written on a GPU machine still loads on a CPU-only one.
my_filename = "pooh_my_3x512_30ep.model"
try:
    my_model.load_state_dict(torch.load(my_filename, map_location=device))
except FileNotFoundError:
    train(pooh_dataset, my_model)
    torch.save(my_model.state_dict(), my_filename)

{'epoch': 0, 'batch': 113, 'loss': 5.569244861602783}
{'epoch': 1, 'batch': 113, 'loss': 5.493239402770996}
{'epoch': 2, 'batch': 113, 'loss': 5.4845290184021}
{'epoch': 3, 'batch': 113, 'loss': 5.467758655548096}
{'epoch': 4, 'batch': 113, 'loss': 5.461180210113525}
{'epoch': 5, 'batch': 113, 'loss': 5.4551825523376465}
{'epoch': 6, 'batch': 113, 'loss': 5.451696872711182}
{'epoch': 7, 'batch': 113, 'loss': 5.446776866912842}
{'epoch': 8, 'batch': 113, 'loss': 5.443249225616455}
{'epoch': 9, 'batch': 113, 'loss': 5.441751480102539}
{'epoch': 10, 'batch': 113, 'loss': 5.439263343811035}
{'epoch': 11, 'batch': 113, 'loss': 5.437617778778076}
{'epoch': 12, 'batch': 113, 'loss': 5.436402320861816}
{'epoch': 13, 'batch': 113, 'loss': 5.433779239654541}
{'epoch': 14, 'batch': 113, 'loss': 5.432648181915283}
{'epoch': 15, 'batch': 113, 'loss': 5.439981460571289}
{'epoch': 16, 'batch': 113, 'loss': 5.032291412353516}
{'epoch': 17, 'batch': 113, 'loss': 4.681049823760986}
{'epoch': 18, 'batch'

In [27]:
# Reconstruct the test text with the second model (best of `n_iter` samples per
# start word) and show its accuracy as the cell result.
# NOTE(review): the task suggests a low temperature; T=1 leaves the softmax unscaled.
my_reconstructed = reconstruct(test_input, my_model, 1)
accuracy(test_tokens, my_reconstructed.split())

0.8192371475953566