In [None]:
import pickle
import string

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

In [None]:
from mininlp.data import Tokenizer

# The tokenizer instance is built from the token vocabulary a couple of cells below;
# nothing uses a tokenizer before that point.

In [None]:
"""The tokens are all ASCII characters plus special tokens for start of sentence,
end of sentence, padding, unknown, and mask."""
# `assci_tokens` (sic) is the helper's actual name in mininlp.data.
from mininlp.data import assci_tokens
print(assci_tokens())

In [None]:
# Build the tokenizer over the full vocabulary returned by `assci_tokens()`.
tokens = assci_tokens()
tokenizer = Tokenizer(tokens)

In [None]:
"""Check that the id -> token table and the token -> id table are inverses of each other."""

# `token_id` rather than `id`: a top-level loop variable named `id` would shadow the
# builtin for the rest of the kernel session.
for token_id in tokenizer._tokens:
    assert tokenizer._token_ids[tokenizer._tokens[token_id]] == token_id

# The loop only walks id -> token -> id. Because that round trip succeeds for every id,
# equal table sizes rule out extra entries in `_token_ids` with no counterpart in `_tokens`.
assert len(tokenizer._tokens) == len(tokenizer._token_ids)

In [None]:
"""Round-trip check: decoding the encoding of a string must give back that string."""

test_string = "Hello, World! \nLovely day, isn't it?"

# string -> token ids -> decoded tokens, which are joined back into a string
test_encoded = tokenizer.encode(test_string)
test_decoded = tokenizer.decode(test_encoded)
assert "".join(test_decoded) == test_string

for shown in (test_string, test_encoded, test_decoded):
    print(shown)

In [None]:
"""Persist the tokenizer to disk, then restore it into a brand-new instance."""

# `save` gets a bare name while `load` gets "tokenizer.pkl".
# NOTE(review): this presumably relies on `save` appending ".pkl" itself — confirm in mininlp.data.
tokenizer.save("tokenizer")

tokenizer2 = Tokenizer()
tokenizer2.load("tokenizer.pkl")

In [None]:
"""The reloaded tokenizer must carry exactly the same lookup tables as the original."""

for table_name in ("_tokens", "_token_ids"):
    assert getattr(tokenizer, table_name) == getattr(tokenizer2, table_name)

In [None]:
"""Test that the original and the reloaded tokenizer are interchangeable: a string
encoded by either one decodes back to the same string with the other."""

# reloaded tokenizer encodes, original decodes
test_encoded = tokenizer2.encode(test_string)
test_decoded = tokenizer.decode(test_encoded)
assert test_string == "".join(test_decoded)

# original tokenizer encodes, reloaded decodes
test_encoded = tokenizer.encode(test_string)
test_decoded = tokenizer2.decode(test_encoded)
assert test_string == "".join(test_decoded)

print(test_decoded)
print(test_encoded)

In [None]:
from mininlp.data import SequenceDataset

DATA_PATH = '../data/anna.txt'

# NOTE(review): `encoded_document` is not read by any later cell in this notebook.
encoded_document = tokenizer.tokenize_document(DATA_PATH)
# The third positional argument is the sequence length, judging by the later cell that
# passes SEQ_LEN in that slot; the meaning of 1000 is not visible here — TODO document it.
dataset = SequenceDataset(DATA_PATH, tokenizer, 32, 1000)

In [None]:
# Show the first sample the dataset yields.
dataset[0]

In [None]:
# Decode both elements of a few samples to eyeball what the dataset produces.
for sample_idx in (50, 46, 5):
    print(
        tokenizer.decode(dataset[sample_idx][0]),
        tokenizer.decode(dataset[sample_idx][1]),
    )

In [None]:
from mininlp.transformer import DTransformer

MODEL_NAME = 'decoder_transformer_v1'
SEQ_LEN = 128
EMBEDDING_DIM = 512
HEADS = 8
LAYERS = 4
FACTOR = 4
BATCH_SIZE = 256

# Forward slashes work on Windows and POSIX alike and avoid backslash escape
# sequences inside the literal; the file name is derived from MODEL_NAME.
MODEL_PATH = f"../models/{MODEL_NAME}.pt"

# NOTE(review): `len(tokenizer)` (presumably the vocabulary size) comes from the tokenizer
# built from `assci_tokens()` earlier, not from the one saved in ../models that the next
# cell loads. If the two differ, the parameter shapes likely won't match the checkpoint
# and `load_state_dict` will fail.
model = DTransformer(LAYERS, EMBEDDING_DIM, len(tokenizer), SEQ_LEN, HEADS, FACTOR)
# The model is never moved off the CPU in this notebook, so map the checkpoint's tensors
# to CPU. This also lets a checkpoint saved from a GPU load on a machine without one.
state_dict = torch.load(MODEL_PATH, map_location="cpu")
model.load_state_dict(state_dict)

In [None]:
# Portable path (forward slashes) with no backslash escapes to get wrong.
TOKEN_PATH = "../models/tokenizer.pkl"

tokenizer = Tokenizer()
# NOTE(review): a ".pkl" file is presumably unpickled by `load`; only load files you trust.
tokenizer.load(TOKEN_PATH)

dataset = SequenceDataset('../data/anna.txt', tokenizer, SEQ_LEN, 1)

model.eval()
with torch.no_grad():
    # `model_input` rather than `input`, so the builtin is not shadowed.
    # unsqueeze(0) adds a leading size-1 dimension, i.e. a batch of one sample.
    model_input = dataset[0][0].unsqueeze(0)
    print(tokenizer.decode(model_input[0]), tokenizer.decode(dataset[0][1]))
    output = model(model_input)


In [None]:
# Probability distribution over the vocabulary at the last position of the model output.
probs = F.softmax(output[0, -1, :], dim=0)

positions = range(len(probs))
fig, ax = plt.subplots(figsize=(20, 10))
# Bars sit at numeric positions and the decoded tokens become tick labels, so two ids that
# decode to the same string still get separate bars instead of sharing one category.
ax.bar(positions, probs)
ax.set_xticks(positions)
ax.set_xticklabels(tokenizer.decode(torch.arange(len(probs))), rotation=90)
ax.set(title="Next-token probabilities for dataset[0]", xlabel="token", ylabel="probability")

# NOTE(review): if dataset[0][1] is a whole target sequence, this prints the probability of
# every target token at the final position. The prediction plotted above would correspond
# to its last element only (probs[dataset[0][1][-1]]) — confirm which is intended.
print(probs[dataset[0][1]])

In [None]:
# Greedy generation: repeatedly append the most likely next token, sliding the context
# window so that it keeps its original length.
N_GENERATED_TOKENS = 50

model.eval()
text = tokenizer.decode(dataset[0][0])
text += ["<msk>"]  # marks where the prompt ends and the generated continuation begins
with torch.no_grad():
    # `context` / `logits` instead of `input` / `output`: this avoids shadowing the builtin
    # and leaves the earlier cell's `output` intact, so the plot cell can still be re-run.
    context = dataset[0][0].unsqueeze(0)
    for _ in range(N_GENERATED_TOKENS):
        logits = model(context)
        # softmax is monotonic, so taking argmax over the raw logits picks the same token.
        # It also can't have float saturation turn near-ties into exact ties.
        new_token = logits[0, -1, :].argmax().unsqueeze(0)
        # Append the new id, then drop the oldest one so the window length stays fixed.
        context = torch.cat((context, new_token[:, None]), dim=1)[:, 1:]
        text += tokenizer.decode(new_token)

In [None]:
# Strip padding tokens, then stitch the prompt and its continuation into one string.
text = list(filter(lambda tok: tok != "<pad>", text))
print("".join(text))