# Assigment 5

**Submission deadlines**:

* last lab before 27.06.2022 

**Points:** Aim to get 12 out of 15+ possible points

All needed data files are on Drive: <https://drive.google.com/drive/folders/1uufpGn46Mwv4oBwajIeOj4rvAK96iaS-?usp=sharing> (or will be soon :) )

## Task 2 (6 points)


This task is about text generation. You have to:


**C**. write text generation procedure. The procedure should fulfill the following requirements:

1. it should use the RNN language model (trained on sub-word tokens)
2. generated tokens should be presented as a text containing words (without extra spaces, or other extra characters, as begin-of-word introduced during tokenization)
3. all words in a generated text should belond to the corpora (note that this is not guaranteed by LSTM)
4. in generation Top-P sampling should be used (see NN-NLP.6, slide X) 
5. in generated texts every token 3-gram should be uniq
6. *(optionally, +1 point)* all token bigrams in generated texts occur in the corpora

In [1]:
import pickle
from collections import defaultdict
from pathlib import Path
import torch.nn.functional as F
import torch
from nltk.tokenize import word_tokenize

from utils import PrusDataset, PrusModule

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
FINAL_TOKENS_FILEPATH = Path("data/tokens_final.pickle")
POSSIBLE_COMBINATIONS_FILEPATH = Path("data/possible_combinations.pickle")
MODEL_FILEPATH = Path("model_checkpoints/model-epoch=10-train_loss=6.39.ckpt")

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [4]:
with FINAL_TOKENS_FILEPATH.open("rb") as f:
    final_tokens = pickle.load(f)
v = torch.load("vocab.pth")

In [5]:
dataset = PrusDataset(final_tokens, v)
model = PrusModule.load_from_checkpoint(MODEL_FILEPATH)
model.to(device)

PrusModule(
  (vocab): Vocab()
  (embedding): Embedding(29186, 100)
  (lstm): LSTM(100, 256, num_layers=2, batch_first=True, dropout=0.2)
  (fc): Linear(in_features=256, out_features=29186, bias=True)
)

In [6]:
anywhere_idxs = set()
for token_idx in range(len(v)):
    if not v.lookup_token(token_idx).startswith("$"):
        anywhere_idxs.add(token_idx)

In [7]:
if POSSIBLE_COMBINATIONS_FILEPATH.exists():
    with POSSIBLE_COMBINATIONS_FILEPATH.open("rb") as f:
        possible_idxs_per_idx = pickle.load(f)
else:
    possible_idxs_per_idx = defaultdict(set)
    for idx, token_idx in enumerate(dataset.data):
        token = v.lookup_token(token_idx)
        if token.endswith("$"):
            possible_idxs_per_idx[token_idx.item()].add(dataset.data[idx + 1].item())
            if token.startswith("$"):
                possible_idxs_per_idx[dataset.data[idx - 1].item(), token_idx.item()].add(dataset.data[idx + 1].item())
    with POSSIBLE_COMBINATIONS_FILEPATH.open("wb") as f:
        pickle.dump(possible_idxs_per_idx, f)

In [8]:
def _generate(seed_tokens, next_words=1000, top_p=10):
    words = v.lookup_tokens(seed_tokens)
    for seed_token in seed_tokens:
        y_pred, (state_h, state_c) = model(torch.tensor(seed_token, device=device).reshape(1, 1))

    next_token = F.softmax(y_pred.flatten(), -1).argmax()
    words.append(v.lookup_token(next_token.item()))
    tokens = seed_tokens

    unique_trigrams = defaultdict(set)
    for t1, t2, t3 in zip(tokens, tokens[1:], tokens[2:]):
        unique_trigrams[(t1, t2)].add(t3)

    forbidden_tokens = unique_trigrams[tuple(tokens[-2:])]
    possible_idxs = torch.tensor(sorted(anywhere_idxs - forbidden_tokens), device=device)

    for _ in range(next_words - 1):
        y_pred, (state_h, state_c) = model(next_token.reshape(1, 1), (state_h, state_c))

        probs = F.softmax(y_pred.flatten()[possible_idxs], -1)
        if len(probs) > top_p:
            top_p_probs, top_p_idxs = probs.topk(top_p)
            selected = torch.multinomial(top_p_probs, 1)
            next_token = possible_idxs[top_p_idxs[selected]]
        else:
            selected = torch.multinomial(probs, 1)
            next_token = possible_idxs[selected]

        next_word = v.lookup_token(next_token)
        words.append(next_word)
        unique_trigrams[tuple(tokens[-2:])].add(next_token.item())
        tokens.append(next_token.item())

        forbidden_tokens = unique_trigrams[tuple(tokens[-2:])]
        if next_word.endswith("$"):
            if words[-2].endswith("$"):
                possible_idxs = torch.tensor(
                    sorted(possible_idxs_per_idx[tokens[-2], next_token.item()] - forbidden_tokens),
                    device=device,
                )
            else:
                possible_idxs = torch.tensor(
                    sorted(possible_idxs_per_idx[next_token.item()] - forbidden_tokens),
                    device=device,
                )
        else:
            possible_idxs = torch.tensor(sorted(anywhere_idxs - forbidden_tokens), device=device)
    return tokens

In [9]:
def tokens_to_text(tokens):
    text = ""
    for word in v.lookup_tokens(tokens):
        if not word.startswith("$") and not word in {".", ",", "?", "!"}:
            text += " "
        text += word.strip("$")
    return text

In [10]:
def get_subtokens(token):
    rest = token
    subtokens = []
    for idx in range(len(rest), 0, -1):
        possible_token = f"{rest[:idx]}$"
        if possible_token in v:
            subtokens.append(possible_token)
            # print(possible_token)
            rest = rest[idx:]
            break
    else:
        raise ValueError(f"Unable to tokenize '{token}'")

    for idx in range(len(rest), 0, -1):
        possible_token = f"${token[-idx:]}"
        if possible_token in v:
            subtokens.append(possible_token)
            # print(possible_token)
            rest = rest[:-idx]
            break
    else:
        raise ValueError(f"Unable to tokenize '{token}'")
    if rest:
        possible_token = f"${rest}$"
        if possible_token in v:
            subtokens.append(possible_token)
        else:
            raise ValueError(f"Unable to tokenize '{token}'")
    return subtokens

In [11]:
def tokenize(seed_text):
    tokens = []
    for base_token in word_tokenize(seed_text.lower()):
        for preprocessed_token in filter(lambda s: s, base_token.replace("…", "$$…$$").split("$$")):
            if preprocessed_token in v:
                tokens.append(preprocessed_token)
            else:
                subtokens = get_subtokens(preprocessed_token)
                tokens.extend(subtokens)
    return tokens

In [12]:
def generate(seed_text, next_words=1000, top_p=10):
    tokens = tokenize(seed_text)
    token_idxs = v.lookup_indices(tokens)
    all_tokens = _generate(token_idxs, next_words, top_p)
    return tokens_to_text(all_tokens)

In [13]:
generate("Wokulski powiedział do Łęckiej", top_p=5)

' wokulski powiedział do łęckiej że, nie wiem, że nie nie, a to, nie, co to, że to pan nie wiem. — ja to nie to, co, a co to ja pan, że już ja, to to nie wiem — rzekł, że w tej chwili, że, że ja ja nie wiem!  …  — co nie ma, a ja to pan!  — a to ja, co nie to jest nie ma i nie było nie nie ma!   — spytał, to ja to jest, że na co nie nie nie było się do mnie, co ja, a już ja ja jest, to mi, a nie to pan, nie nie wiem ... — nie ma nie wiem? ... …  nie to — a nie ma mi, że z tej mnie nie wiem  — odparł! i to nie było to, jak to pan jest nie nie mnie na niego, a on nie ma. — to to to, a pan, to pan? ... — odparł … — rzekł — odparł?  … ... — to nie jest, co się nie wiem do niej. …  i już nie, i ja nie nie pani — mówił …  [ przypis redakcyjny. — co to mi!  [ sobie i w chwili, co już nie ma : — to ja jest to, to jest na tej chwili? ... ”! ... ”, nie jest na niego i i na mnie nie nie to nie ma w tym, nie wiadomo i nie jest z mnie. [ przypis ] ] ], który w tej nim na mnie w tej razie jest w tej