In [8]:
# set up logging
import logging
logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
)

In [9]:
# make deterministic
from mingpt.utils import set_seed
set_seed(42)

In [10]:
import urllib
import re
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

from collections import OrderedDict, Counter

In [11]:
# Read Shakespeare texts from URL. Texts are translated to modern English
text = urllib.request.urlopen('https://raw.githubusercontent.com/emukans/shakespeare-texts/master/all_modern.txt').read().decode("utf-8", "ignore")

In [12]:
class Tokenizer:
    def __init__(self, data, vocab_size):
        self.vocab_size = vocab_size
        self.vocab = self.build_vocab(data)
        
        self.stoi = { ch:i for i,ch in enumerate(self.vocab) }
        self.itos = { i:ch for i,ch in enumerate(self.vocab) }
    
    def sort_vocab(self, vocab):
        """
        Vocab should have the followind order: hashtag, numbers, characters sorted by length.
        Hashtags should go first, because they will be used as dividers on tokenization step.
        Numbers should go before characters, because token ids are numbers. Otherwise token ids will be considered as usual numbers and replaced twice.
        """
        sorted_vocab = sorted(vocab, key=lambda x: len(x), reverse=True)
        tag = [int(s) for s in sorted_vocab if s == '#']
        
        numeric = [int(s) for s in sorted_vocab if s.isnumeric()]
        numeric = [str(s) for s in sorted(numeric, reverse=True)]
        rest = [s for s in sorted_vocab if not s.isnumeric()]
        
        sorted_vocab = tag + numeric + rest
        
        return sorted_vocab
    
    def build_vocab(self, data):
        """
        Build vocabluary using BPE alghorithm.
        """
        vocab = set(data)
        if len(vocab) > self.vocab_size:
            raise ValueError('Vocab size should be greater than unique char count')

        # check all available characters
        char_set = {c for c in vocab if c.isalpha()}
        
        # candidates dictionary will contain a set of all available tokens to search
        candidate_dict = dict().fromkeys(char_set, 0)
        
        # occurrences will contain all matched tokens and the count, how many times the token has been found.
        token_occurrences = OrderedDict()
        while len(vocab) < self.vocab_size:
            for candidate in candidate_dict.keys():
                occurrences = data.count(candidate)
                candidate_dict[candidate] = occurrences

            candidate_dict = {candidate: count for candidate, count in candidate_dict.items() if count}
            vocab.update(set(candidate_dict.keys()))
            token_occurrences.update(candidate_dict)

            # build new candidates
            temp_candidate_set = set()
            for char in char_set:
                # don't test candidates with occurency <= 2. New candidates won't have occurency higher than 2
                temp_candidate_set.update({candidate + char for candidate in candidate_dict.keys() if token_occurrences[candidate] > 2})

            candidate_dict = dict().fromkeys(temp_candidate_set, 0)

        tokens_to_remove = len(vocab) - self.vocab_size
        token_occurrences = OrderedDict(sorted(token_occurrences.items(), key=lambda x: x[1], reverse=True))
        for _ in range(tokens_to_remove):
            token, _ = token_occurrences.popitem()
            vocab.remove(token)

        sorted_vocab = self.sort_vocab(vocab)
        
        # add a special token for unknown tokens
        sorted_vocab.append('<unk>')
        self.vocab_size += 1 # plus <unk> special token
        
        return sorted_vocab
    
    def tokenize(self, data):
        for token in self.vocab:
            data = data.replace(token, f'#{self.stoi[token]}#')

        # If everything went well, first and last characters won't have # pair. Need to trim them
        data = data[1:-1]
        # Split by ## pairs
        tokenized_text = data.split('##')
        # Filter empty strings
        tokenized_text = [x for x in tokenized_text if x]
        result = []
        for tokenized in tokenized_text:
            # In case other single # found, replace them with <unk> special token, marking the element as unknown
            if '#' in tokenized:
                for unknown_candidate in tokenized.split('#'):
                    if unknown_candidate.isnumeric():
                        result.append(self.itos[int(unknown_candidate)])
                    else:
                        result.append('<unk>')
            else:
                result.append(self.itos[int(tokenized)])

        return result
    
    def encode(self, data):
        return [self.stoi[s] for s in data]
    
    def decode(self, data):
        return ''.join([self.itos[int(i)] for i in data])

In [13]:
vocab_size = 10000

# building vocabluary can take some time. ~5 minutes for 10_000 tokens. 
tokenizer = Tokenizer(text, vocab_size)

In [14]:
import math
from torch.utils.data import Dataset

class WordDataset(Dataset):

    def __init__(self, data, block_size, tokenizer):
        self.tokenizer = tokenizer
        self.block_size = block_size
        #self.vocab_size = self.tokenizer.vocab_size
        self.data = self.tokenizer.tokenize(data)
    
    def __len__(self):
        return len(self.data) - self.block_size

    def __getitem__(self, idx):
        # grab a chunk of (block_size + 1) characters from the data
        chunk = self.data[idx:idx + self.block_size + 1]
        # encode every character to an integer
        dix = self.tokenizer.encode(chunk)
        """
        arrange data and targets so that the first i elements of x
        will be asked to predict the i-th element of y. Notice that
        the eventual language model will actually make block_size
        individual predictions at the same time based on this data,
        so we are being clever and amortizing the cost of the forward
        pass of the network. So for example if block_size is 4, then
        we could e.g. sample a chunk of text "hello", the integers in
        x will correspond to "hell" and in y will be "ello". This will
        then actually "multitask" 4 separate examples at the same time
        in the language model:
        - given just "h", please predict "e" as next
        - given "he" please predict "l" next
        - given "hel" predict "l" next
        - given "hell" predict "o" next
        
        In addition, because the DataLoader will create batches of examples,
        every forward/backward pass during traning will simultaneously train
        a LOT of predictions, amortizing a lot of computation. In particular,
        for a batched input of integers X (B, T) where B is batch size and
        T is block_size and Y (B, T), the network will during training be
        simultaneously training to make B*T predictions, all at once! Of course,
        at test time we can paralellize across batch B, but unlike during training
        we cannot parallelize across the time dimension T - we have to run
        a forward pass of the network to recover the next single character of the 
        sequence along each batch dimension, and repeatedly always feed in a next
        character to get the next one.
        
        So yes there is a big asymmetry between train/test time of autoregressive
        models. During training we can go B*T at a time with every forward pass,
        but during test time we can only go B at a time, T times, with T forward 
        passes.
        """
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        return x, y


In [15]:
block_size = 50 # spatial extent of the model for its context

train_dataset = WordDataset(text, block_size, tokenizer) # one line of poem is roughly 50 characters

In [16]:
from mingpt.model import GPT, GPTConfig
mconf = GPTConfig(train_dataset.tokenizer.vocab_size, train_dataset.block_size,
                  n_layer=8, n_head=8, n_embd=512)
model = GPT(mconf)

03/30/2021 16:01:04 - INFO - mingpt.model -   number of parameters: 3.548672e+07


In [17]:
from mingpt.trainer import Trainer, TrainerConfig

# initialize a trainer instance and kick off training
tconf = TrainerConfig(max_epochs=2, batch_size=256, learning_rate=6e-4,
                      lr_decay=True, warmup_tokens=512*20, final_tokens=2*len(train_dataset)*block_size,
                      num_workers=2)
trainer = Trainer(model, train_dataset, None, tconf)
trainer.train()

epoch 1 iter 3260: train loss 0.55002. lr 3.000578e-04: 100%|██████████| 3261/3261 [40:39<00:00,  1.34it/s]
epoch 2 iter 3260: train loss 0.33295. lr 6.000000e-05: 100%|██████████| 3261/3261 [40:49<00:00,  1.33it/s]


In [18]:
# alright, let's sample some word-level Shakespeare
from mingpt.utils import sample

context = "O God, O God!"
x = torch.tensor(tokenizer.encode(context), dtype=torch.long)[None,...].to(trainer.device)
y = sample(model, x, 2000, temperature=1.0, sample=True, top_k=10)[0]
completion = tokenizer.decode(y)
print(completion)

O God, O God!
O church, O world Now lies here, and then there is the first time he plays his part.
Welcome.
Set the honorable old man down and let him eat.
I thank you very much on his behalf.
Welcome.
Eat.
I won’t trouble you yet with questions about your situation.—Some music, please, and, good friend, sing.
Time’s not the one in debt.
Your logic is so foolish.
Take it to your master and bring him home immediately.
Every person I meet greets me like an old friend, and every one of them knows my name.
Some of them give me money, some invite me places, some thank me for the kind things I’ve done for them, some try to sell me things.
Just now a tailor showed me fabrics he bought especially for me and then started to take my measurements.
These are tricks of the imagination, and this place is filled with magicians.
Here’s the money you wanted, master.
Who’s this Adam you speak of?
Not the Adam from the garden of Eden, but the Adam from the jailhouse.
I don’t know what you’re talking abou