In [2]:
import torch
from torch.utils.data import Dataset
import pickle
import string

In [3]:
# Tokenizer is provided by the project's own mininlp package.
from mininlp.data import Tokenizer

# Construct a tokenizer without passing a vocabulary. A later cell rebinds
# `tokenizer` to an instance built from assci_tokens().
tokenizer = Tokenizer()

In [4]:
"""Token are every ascii character and special tokens for start of sentence, 
end of sentence, padding, unknown and mask."""
from mininlp.data import assci_tokens
print(assci_tokens())

{'z', 'X', '?', 'm', '<', '2', '5', 'T', "'", '*', 'n', 'g', '>', '8', 'k', '9', 'Q', ')', '^', '<sos>', 'a', 'Z', 'J', '$', 'f', '\x0c', 'v', 'Y', '3', 'r', '&', '0', '7', 'b', ';', '4', '<eos>', 'A', '6', 'N', '<unk>', 'F', '\r', 'u', '"', 'p', 'j', '/', 'V', ':', '|', '\t', 'e', 'M', '`', 'R', '-', ',', 'I', 'O', 'l', 'L', '=', 'c', 'K', '+', 'w', '{', '.', 's', '@', 'H', ' ', 'B', '#', 'i', 'h', 'D', '\\', 'o', '}', '_', '~', '%', ']', 'E', '<pad>', 'S', '<mask>', 'y', 't', '(', '\x0b', '!', 'P', 'U', '[', 'x', 'C', '\n', 'd', 'G', 'W', 'q', '1'}


In [5]:
# Rebuild the tokenizer, this time passing it the vocabulary returned by
# assci_tokens().
tokens = assci_tokens()
tokenizer = Tokenizer(tokens)

In [6]:
"""Test if the tokenizer is able to convert tokens to ids and vice versa."""

for id in tokenizer._tokens:
    assert tokenizer._token_ids[tokenizer._tokens[id]] == id

In [7]:
"""Test if the tokenizer is able to encode and decode a string."""

test_string = "Hello, World! \nLovely day, isn't it?"

test_encoded = tokenizer.encode(test_string)
test_decoded = tokenizer.decode(test_encoded)
assert test_string == "".join(test_decoded)

print(test_string)
print(test_encoded)
print(test_decoded)

Hello, World! 
Lovely day, isn't it?
tensor([ 71,  52,  60,  60,  79,  57,  72, 102,  79,  29,  60, 100,  93,  72,
         99,  61,  79,  26,  52,  60,  89,  72, 100,  20,  89,  57,  72,  75,
         69,  10,   8,  90,  72,  75,  90,   2], dtype=torch.int32)
['H', 'e', 'l', 'l', 'o', ',', ' ', 'W', 'o', 'r', 'l', 'd', '!', ' ', '\n', 'L', 'o', 'v', 'e', 'l', 'y', ' ', 'd', 'a', 'y', ',', ' ', 'i', 's', 'n', "'", 't', ' ', 'i', 't', '?']


In [8]:
"""Test if the tokenizer is able to save and load itself."""

tokenizer.save("tokenizer")

tokenizer2 = Tokenizer()
tokenizer2.load("tokenizer.pkl")

In [9]:
"""Test if the loaded tokenizer is the same as the original tokenizer."""

assert tokenizer._tokens == tokenizer2._tokens
assert tokenizer._token_ids == tokenizer2._token_ids

In [10]:
"""Test if the orginal tokenizer and the loaded one encodes and decodes a string to the same tokens ids."""

test_encoded = tokenizer2.encode(test_string)
test_decoded = tokenizer.decode(test_encoded)
assert test_string == "".join(test_decoded)

test_encoded = tokenizer.encode(test_string)
test_decoded = tokenizer2.decode(test_encoded)
assert test_string == "".join(test_decoded)

print(test_decoded)
print(test_encoded)

['H', 'e', 'l', 'l', 'o', ',', ' ', 'W', 'o', 'r', 'l', 'd', '!', ' ', '\n', 'L', 'o', 'v', 'e', 'l', 'y', ' ', 'd', 'a', 'y', ',', ' ', 'i', 's', 'n', "'", 't', ' ', 'i', 't', '?']
tensor([ 71,  52,  60,  60,  79,  57,  72, 102,  79,  29,  60, 100,  93,  72,
         99,  61,  79,  26,  52,  60,  89,  72, 100,  20,  89,  57,  72,  75,
         69,  10,   8,  90,  72,  75,  90,   2], dtype=torch.int32)


In [11]:
from mininlp.data import SequenceDataset

# NOTE(review): `encoded_document` is not referenced again in the cells shown
# here — confirm whether this call is still needed.
encoded_document = tokenizer.tokenize_document("../data/anna.txt")
# The dataset is given the file path itself, not `encoded_document`.
# The positional 32 and 1000 are presumably the sequence length and the number
# of samples — TODO confirm against SequenceDataset's signature.
dataset = SequenceDataset('../data/anna.txt', tokenizer, 32, 1000)

In [12]:
# Display the first dataset item as the cell's output (a pair of tensors).
dataset[0]

(tensor([86, 86, 86, 86, 19, 52, 29, 72, 69, 66, 52, 52, 45, 72, 79, 24, 72, 76,
         75, 69, 72, 69, 63, 89, 90, 76, 52, 57, 72, 10, 79, 90],
        dtype=torch.int32),
 tensor([86, 86, 86, 86, 52, 29, 72, 69, 66, 52, 52, 45, 72, 79, 24, 72, 76, 75,
         69, 72, 69, 63, 89, 90, 76, 52, 57, 72, 10, 79, 90, 72],
        dtype=torch.int32))

In [13]:
# Decode and print both tensors of a few dataset items, one item per line.
for sample_index in (50, 46, 5):
    print(tokenizer.decode(dataset[sample_index][0]), tokenizer.decode(dataset[sample_index][1]))

['<pad>', '<pad>', '<pad>', '<pad>', '<sos>', 'o', 'w', ' ', 'h', 'e', 'l', 'p', ' ', 'h', 'i', 'm', '?', ' ', 'W', 'h', 'a', 't', '\n', 'c', 'a', 'n', ' ', 'I', ' ', 's', 'a', 'y'] ['<pad>', '<pad>', '<pad>', '<pad>', 'o', 'w', ' ', 'h', 'e', 'l', 'p', ' ', 'h', 'i', 'm', '?', ' ', 'W', 'h', 'a', 't', '\n', 'c', 'a', 'n', ' ', 'I', ' ', 's', 'a', 'y', ' ']
['<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<sos>', 'd', ' ', 'h', 'a', 'n', 'd', 's', 'o', 'm', 'e', ',', ' '] ['<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', 'd', ' ', 'h', 'a', 'n', 'd', 's', 'o', 'm', 'e', ',', ' ', 'w']
['<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>',

# Model

In [23]:
from mininlp.transformer import DTransformer
import json

# The model name selects both the JSON config and the weights file under
# ../models/.
VERSION = 0.1
MODEL_NAME = f'decoder_transformer_v{VERSION}'

# Read the configuration inside a context manager so the file handle is closed
# deterministically instead of being left open for the garbage collector.
with open(f"../models/{MODEL_NAME}.json", encoding="utf-8") as config_file:
    config = json.load(config_file)

# Arguments are positional. NOTE(review): the literal 128 appears to be the
# vocabulary size (the model repr shows Embedding(128, 512)) and the meaning of
# the trailing True is not visible here — confirm against DTransformer.
model = DTransformer(
    config['layers'],
    config['embedding_dim'],
    128,
    config['seq_len'],
    config['heads'],
    config['factor'],
    True)

# Restore the trained weights, then move the model to the GPU.
state_dict = torch.load(f"../models/{MODEL_NAME}.pt")
model.load_state_dict(state_dict)
# Kept as the last expression so the cell displays the module summary.
model.to('cuda')

DTransformer(
  (_embedding): Embedding(
    (_token_embedding): Embedding(128, 512)
  )
  (_decoders): ModuleList(
    (0-5): 6 x Decoder(
      (_laynorm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (_laynorm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (_mmha): MultiHeadAttention(
        (_projection): ModuleList(
          (0-2): 3 x Linear(in_features=512, out_features=512, bias=True)
        )
        (_reprojection): Linear(in_features=512, out_features=512, bias=True)
      )
      (_mha): MultiHeadAttention(
        (_projection): ModuleList(
          (0-2): 3 x Linear(in_features=512, out_features=512, bias=True)
        )
        (_reprojection): Linear(in_features=512, out_features=512, bias=True)
      )
      (_ff): FeedForward(
        (_laynorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (_dropout): Dropout(p=0.2, inplace=False)
        (_ff): Sequential(
          (0): Linear(in_features=512, out_features=2048, 

In [24]:
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Load the tokenizer stored next to the model and build a dataset that uses the
# model's configured sequence length.
tokenizer = Tokenizer()
tokenizer.load("../models/tokenizer.pkl")
dataset = SequenceDataset('../data/anna.txt', tokenizer, config['seq_len'], 1)

model.eval()
with torch.no_grad():
    # Named `context` rather than `input` so the builtin `input()` is not
    # shadowed in the notebook's global namespace.
    context = dataset[0][0].unsqueeze(0)  # add a leading batch dimension
    output = model(context.to('cuda'))
    # Softmax over the scores at the last position of the first batch element.
    # No .detach() is needed: nothing tracks gradients under torch.no_grad().
    probs = F.softmax(output[0, -1, :], dim=0).cpu()

# One bar per token id, labelled with the decoded token.
plt.figure(figsize=(20, 5))
plt.bar(tokenizer.decode(torch.arange(len(probs))), probs)
plt.title("Predicted next-token distribution")
plt.xlabel("Token")
plt.ylabel("Probability")
plt.xticks(rotation=90)
plt.show()

In [25]:
# Start from the decoded prompt tokens and add a marker token that separates
# the prompt from the model's continuation.
text = tokenizer.decode(dataset[0][0])
text.append("<msk>")

model.eval()
with torch.no_grad():
    prompt = dataset[0][0].unsqueeze(0).to('cuda')
    generated = model.generate(prompt, 1000)
    # Append the decoded continuation to the prompt tokens.
    text.extend(tokenizer.decode(generated))

In [26]:
# Drop the padding tokens, then print the remaining tokens as one string.
text = list(filter(lambda token: token != "<pad>", text))
print("".join(text))

<sos>to the usual report, he had the most innocent and inoffensive
air. No one, looking at his white hands, with their swollen veins and
long fingers, so softly stroking the edges of the white paper that lay
before him, and at the air of weariness with which his head drooped on
one side, would have suspected that in a few minutes a torrent of words
would flow from his lips that would arouse a fearful storm, set the
members shouting and attacking one another, and force the president to
call for order. When the report was <msk>a slal the room, who had palent he neccept mishappy. You
comment quieter, Sergey Ivanovitch's earful to see for the interest of a heart,
Levin had completely his great hand to him. Mihainin? Scrumpost cond of the Tvildent
would not hungerstand, it when he had no gone the horses words himself: "Well, not
expectively. My thinking! why I have his last about founded?"

"Why go?" he said, simply glooking to be. "Shere, you've." shall grunning said. When
he he could symi