In [10]:
import torch as t
from torch import nn
from torch.utils.data import Dataset, DataLoader
import plotly.express as px
from IPython.display import display
import pandas as pd
import numpy as np
import copy
from fancy_einsum import einsum
from dataclasses import dataclass
from tqdm.notebook import tqdm_notebook

from einops import rearrange, reduce, repeat

import sys 
sys.path.append('../common')



In [2]:
import general_modules as cm
import transformer_modules as tm
from transformer_modules import TransformerConfig


In [11]:
class WordsDataset(Dataset):
    '''
    Next-token-prediction dataset made of sliding windows over a tokenized text file.

    Sample `i` is the pair `(tokens[i : i+seq_len], tokens[i+1 : i+seq_len+1])`,
    i.e. the target is the input shifted one token to the right.

    Args:
        seq_len: number of tokens in every input / target sequence.
        filename: path of the text file to read (decoded as UTF-8).
        tokenizer: object exposing `build_dict(text, save=...)` and `encode(text)`.
            Its vocabulary is rebuilt from the whole file with `save=True`.
        truncate: optional fraction of the token stream to keep (e.g. 0.01).
            Falsy values keep everything; values above 1 are clamped to the
            full stream.

    Raises:
        ValueError: if fewer than `seq_len + 1` tokens remain after truncation,
            so that not even one (input, target) pair can be formed.
    '''

    def __init__(self, seq_len, filename, tokenizer, truncate=None):
        self.seq_len = seq_len
        self.filename = filename

        # Explicit encoding so the result does not depend on the platform locale.
        with open(filename, 'r', encoding='utf-8') as textfile:
            text = textfile.read()

        # The vocabulary always covers the full text, even when only a
        # fraction of the tokens is turned into training windows below.
        tokenizer.build_dict(text, save=True)
        self.tokens = tokenizer.encode(text)

        word_count = len(self.tokens)
        if truncate:
            word_count = min(int(word_count * truncate), len(self.tokens))

        # Each sample consumes seq_len + 1 consecutive tokens (input plus the
        # shifted target), so `word_count - seq_len` windows fit exactly.
        num_windows = word_count - seq_len
        if num_windows < 1:
            raise ValueError(
                f"Need at least seq_len + 1 = {seq_len + 1} tokens to build a sample, "
                f"but only {word_count} are available."
            )

        # unfold yields every length-(seq_len + 1) window with stride 1; the
        # first seq_len columns are the inputs, the last seq_len the targets.
        windows = t.tensor(self.tokens[:word_count]).unfold(0, seq_len + 1, 1)
        self.x_seqs = windows[:, :-1].contiguous()
        self.y_seqs = windows[:, 1:].contiguous()

    def __len__(self):
        '''Number of (input, target) windows.'''
        return len(self.x_seqs)

    def __getitem__(self, idx):
        '''Returns the `idx`-th `(input_ids, target_ids)` pair.'''
        return self.x_seqs[idx], self.y_seqs[idx]

In [12]:
from typing import Optional, Union
import re
import pickle

class WordsTokenizer():
    '''
    Minimal word-level tokenizer.

    Text is split at regex word boundaries (`\\b`), so words, whitespace runs and
    punctuation runs each become one token. Ids are assigned in sorted token
    order, which makes the vocabulary reproducible across Python sessions.
    '''
    model_max_length: int

    # File names (relative to the current working directory) used by
    # build_dict(save=True) and load_saved().
    WORD_ID_MAP_FILE = "word_id_map.pkl"
    ID_WORD_MAP_FILE = "id_word_map.pkl"

    def __init__(self, model_max_length):
        self.word_id_map = dict()
        self.id_word_map = dict()
        self.model_max_length = model_max_length

    @staticmethod
    def _split(text):
        '''Splits text at word boundaries, dropping the empty pieces re.split leaves at the edges.'''
        return [piece for piece in re.split(r"\b", text) if piece]

    def build_dict(self, initial_text, save=False):
        '''
        Builds the token <-> id mappings from `initial_text`, replacing any previous ones.

        If `save` is true, both mappings are also pickled to WORD_ID_MAP_FILE
        and ID_WORD_MAP_FILE.
        '''
        # Sorting makes the ids deterministic; iterating a bare set of strings
        # gives a different order in every interpreter session, which would
        # silently invalidate any model trained with an earlier vocabulary.
        unique_tokens = sorted(set(self._split(initial_text)))
        self.word_id_map = {word: idx for idx, word in enumerate(unique_tokens)}
        self.id_word_map = {idx: word for word, idx in self.word_id_map.items()}

        if save:
            with open(self.WORD_ID_MAP_FILE, "wb") as file:
                pickle.dump(self.word_id_map, file)

            with open(self.ID_WORD_MAP_FILE, "wb") as file:
                pickle.dump(self.id_word_map, file)

    def load_saved(self):
        '''Restores both mappings from the pickle files written by build_dict(save=True).'''
        # NOTE: pickle.load can execute arbitrary code; only load files this
        # notebook wrote itself.
        with open(self.WORD_ID_MAP_FILE, "rb") as file:
            self.word_id_map = pickle.load(file)

        with open(self.ID_WORD_MAP_FILE, "rb") as file:
            self.id_word_map = pickle.load(file)

    def encode(self, text: str, return_tensors: Optional[str] = None) -> Union[list, t.Tensor]:
        '''
        Tokenizes text, then returns the token ids.

        Return type is list by default; return_tensors="pt" gives a torch tensor
        and return_tensors="np" a numpy array.

        Raises KeyError if the text contains a token that is not in the vocabulary.
        '''
        encoded = [self.word_id_map[word] for word in self._split(text)]

        if return_tensors == "pt":
            encoded = t.tensor(encoded)
        elif return_tensors == "np":
            encoded = np.array(encoded)

        return encoded

    def decode(self, list_of_ids: Union[t.Tensor, list]) -> str:
        '''
        Converts a 1-D sequence of ids to tokens, then joins them into a single string.
        '''
        # int() is required for tensor input: iterating a tensor yields 0-d
        # tensors, which hash by identity and never match the int dict keys.
        words = [self.id_word_map[int(token_id)] for token_id in list_of_ids]
        return "".join(words)

    def __call__(self, initial_text: str, return_tensors: Optional[str] = None) -> Union[list, t.Tensor]:
        '''
        Returns results of self.encode.
        '''
        return self.encode(initial_text, return_tensors)

In [13]:
# Sequence length shared by the tokenizer's model_max_length and the dataset windows.
SEQ_LEN = 16
# Fraction of the token stream that WordsDataset turns into training windows.
TRUNCATE_FRACTION = 0.01

tokenizer = WordsTokenizer(SEQ_LEN)
words_ds = WordsDataset(seq_len=SEQ_LEN, filename='100-0.txt', tokenizer=tokenizer, truncate=TRUNCATE_FRACTION)

In [14]:
# vocab_size is taken from the tokenizer built above instead of a hard-coded
# number, so the model's vocabulary dimension always matches the ids the
# tokenizer can actually produce, even if the corpus changes.
config = TransformerConfig(
    num_layers=12,
    num_heads=8,
    vocab_size=len(tokenizer.word_id_map),
    hidden_size=256,
    max_seq_len=128,
    dropout=0.1)

In [15]:
from typing import Callable


# Training hyperparameters.
batch_size = 16
epochs = 1
loss_fn = nn.CrossEntropyLoss()

# Path the trained model is saved to (and later reloaded from).
MODEL_FILENAME = "./w1d3_transformer_shakespeare.pt"

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
if t.cuda.is_available():
    device = t.device("cuda:0")
else:
    device = t.device("cpu")

# Batches are reshuffled on every pass over the dataset.
trainloader = DataLoader(words_ds, shuffle=True, batch_size=batch_size)

def train_transformer(trainloader: DataLoader, epochs: int, loss_fn: Callable) -> tuple:
    '''
    Trains a fresh `tm.DecoderOnlyTransformer` on next-token prediction and saves it.

    The model is built from the module-level `config`, moved to the module-level
    `device`, and optimized with Adam at its default settings. After the last
    epoch the whole model object is written to `MODEL_FILENAME` with `t.save`.

    Args:
        trainloader: yields `(x, y)` batches; `y` is flattened from (B, S) to
            (B * S) before the loss is computed.
        epochs: number of full passes over `trainloader`.
        loss_fn: called as `loss_fn(logits, targets)` with logits flattened
            from (B, S, V) to (B * S, V).

    Returns:
        `(loss_list, accuracy_list)`: `loss_list` holds one loss value per
        batch. `accuracy_list` is always empty because no accuracy is computed;
        it is kept so existing callers can keep unpacking two values.
    '''

    model = tm.DecoderOnlyTransformer(config).to(device).train()
    optimizer = t.optim.Adam(model.parameters())
    loss_list = []
    accuracy_list = []

    for epoch in range(epochs):

        progress_bar = tqdm_notebook(trainloader)
        for (x, y) in progress_bar:

            # NOTE(review): the token ids are cast to float32 before being fed
            # to the model; presumably tm's embedding layer expects that.
            # Confirm against transformer_modules.
            x = x.to(t.float32).to(device)
            y = y.to(device)

            # Flatten batch and sequence dims so every position is one
            # classification example for the loss.
            logits = rearrange(model(x), 'B S V -> (B S) V')
            y = rearrange(y, 'B S -> (B S)')

            loss = loss_fn(logits, y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            loss_list.append(loss.item())

            progress_bar.set_description(f"Epoch = {epoch}, Loss = {loss.item():.4f}")

    print(f"Saving model to: {MODEL_FILENAME}")
    t.save(model, MODEL_FILENAME)
    return loss_list, accuracy_list



In [16]:
loss_list, accuracy_list = train_transformer(trainloader, epochs, loss_fn)

# One point per training batch. `default=0` keeps the cell from crashing with
# "max() arg is an empty sequence" when no batch was trained.
fig = px.line(y=loss_list, template="simple_white")
fig.update_layout(
    title="Cross entropy loss on Shakespeare",
    xaxis_title="Batch",
    yaxis_title="Cross entropy loss",
    yaxis_range=[0, max(loss_list, default=0)],
)
fig.show()

  0%|          | 0/1242 [00:00<?, ?it/s]

Saving model to: ./w1d3_transformer_shakespeare.pt


In [17]:
import sample_methods as s

# Reload the model saved by the training cell onto the CPU and switch it to
# evaluation mode before sampling.
model = t.load(MODEL_FILENAME, map_location=t.device('cpu'))
model.eval()

# Prompt that the generated text continues from.
initial_text = "turn down for what"

text_output = s.sample_tokens(
    model,
    tokenizer,
    initial_text,
    max_tokens_generated=100,
    temperature=1.0,
    top_k=10,
)

print(text_output)

# turn down for what you do you think,
# That take the last, of many, which is so much I
# As this blows along than my life thou say’st, which makes thy hand,
# Thou wilt be given, or more
# Entitled in thy great world’s fresh blood will,
# To answer th’ alluring countenance, beauty

turn down for what the world that should not that which that it that which it the he day,
  In that which it, nor that it love that your sweet love?
Why with more love for me, let thy love’s love shall be,
  To thee in my love that me thou art.
When
