### set seed 

In [1]:
import random
import numpy as np
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

#### BPE encoder and decoder

In [2]:
import os
import json
import regex as re
import requests

def bytes_to_unicode():
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:] 
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    d = dict(zip(bs, cs))
    return d

def get_pairs(word):
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs

def get_file(local_file, remote_file):
    if not os.path.isfile(local_file):
        print(f"downloading {remote_file} to {local_file}")
        response = requests.get(remote_file)
        open(local_file, "wb").write(response.content)


class Encoder:

    def __init__(self, encoder, bpe_merges):
        # byte encoder/decoder
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
        
        # bpe token encoder/decoder
        self.encoder = encoder
        self.decoder = {v:k for k,v in self.encoder.items()}
        # bpe merge list that defines the bpe "tree", of tuples (a,b) that are to merge to token ab
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))

        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
        self.cache = {}

    def bpe(self, token):
        # token is a string of one individual 'word', after byte encoding, e.g. 'Ġthere'

        # memoization, for efficiency
        if token in self.cache:
            return self.cache[token]

        word = tuple(token) # individual characters that make up the token, in a tuple
        pairs = get_pairs(word) # get all bigrams

        if not pairs:
            return token

        while True:

            # find the next lowest rank bigram that can be merged
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break # no more bigrams are eligible to be merged
            first, second = bigram

            # we will now replace all occurences of (first, second) in the list of current
            # words into one merged token first_second, in the output list new_words
            new_word = []
            i = 0
            while i < len(word):

                # find the next occurence of first in the sequence of current words
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:
                    new_word.extend(word[i:])
                    break

                # if this occurence is also followed by second, then merge them into one
                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1

            # all occurences of (first, second) have been merged to first_second
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)

        # concat all words into a string, and use ' ' as the separator. Note that
        # by now all characters have been byte encoded, guaranteeing that ' ' is
        # not used in the actual data and is a 'special' delimiter character
        word = ' '.join(word)

        # cache the result and return
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_idx = []
        # pre-tokenize the input text into string tokens (words, roughly speaking)
        tokens = re.findall(self.pat, text)
        # process each token into BPE integers
        for token in tokens:
            # encode the token as a bytes (b'') object
            token_bytes = token.encode('utf-8')
            # translate all bytes to their unicode string representation and flatten
            token_translated = ''.join(self.byte_encoder[b] for b in token_bytes)
            # perform all the applicable bpe merges according to self.bpe_ranks
            token_merged = self.bpe(token_translated).split(' ')
            # translate all bpe tokens to integers
            token_ix = [self.encoder[bpe_token] for bpe_token in token_merged]
            # extend our running list of all output integers
            bpe_idx.extend(token_ix)
        return bpe_idx

    def encode_and_show_work(self, text):
        bpe_idx = []
        parts = []
        tokens = re.findall(self.pat, text)
        for token in tokens:
            token_bytes = token.encode('utf-8')
            token_translated = ''.join(self.byte_encoder[b] for b in token_bytes)
            token_merged = self.bpe(token_translated).split(' ')
            token_ix = [self.encoder[bpe_token] for bpe_token in token_merged]
            bpe_idx.extend(token_ix)
            parts.append({
                'token': token,
                'token_bytes': token_bytes,
                'token_translated': token_translated,
                'token_merged': token_merged,
                'token_ix': token_ix,
            })
        out = {
            'bpe_idx': bpe_idx, # the actual output sequence
            'tokens': tokens, # result of pre-tokenization
            'parts': parts, # intermediates for each token part
        }
        return out

    def decode(self, bpe_idx):
        # inverse map the integers to get the tokens
        tokens_merged = [self.decoder[token] for token in bpe_idx]
        # inverse the byte encoder, e.g. recovering 'Ġ' -> ' ', and get the bytes
        tokens_flat = ''.join(tokens_merged)
        tokens_bytes = bytearray([self.byte_decoder[c] for c in tokens_flat])
        # recover the full utf-8 string
        text = tokens_bytes.decode('utf-8', errors='replace')
        return text

class BPE:
    def __init__(self):
        self.encoder, self.decoder, self.vocab_size = self.get_encoder()
        # self.decoder = enc_obj
    
    def get_file(self, local_file, remote_file):
        if not os.path.isfile(local_file):
            print(f"downloading {remote_file} to {local_file}")
            response = requests.get(remote_file)
            open(local_file, "wb").write(response.content)

    def get_the_encoder(self):

        directory = "./temp/"
        
        if not os.path.exists(directory):
            os.makedirs(directory)
        
        # downloading the pairs, which is used in GPT-2 model
        encoder_local_file = os.path.join(directory, 'encoder.json')
        encoder_remote_file = 'https://openaipublic.blob.core.windows.net/gpt-2/models/124M/encoder.json'
        self.get_file(encoder_local_file, encoder_remote_file)
        
        with open("./temp/encoder.json", 'r') as f:
            encoder = json.load(f)
        
        assert (len(encoder) == 50257), "Encoder length donwloaded is not matching 50257"
        
        # donwloading the vocab.bpe rules
        vocab_local_file = os.path.join(directory, 'vocab.bpe')
        vocab_remote_file = 'https://openaipublic.blob.core.windows.net/gpt-2/models/124M/vocab.bpe'
        self.get_file(vocab_local_file, vocab_remote_file)
        with open(vocab_local_file, 'r', encoding="utf-8") as f:
            bpe_data = f.read()
        
        bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
        
        assert (len(bpe_merges) == 50000), "BPE length donwloaded is not matching 50000"
        
        encoder = Encoder(encoder, bpe_merges)

        return encoder

    def get_encoder(self):
        enc_dec_obj = self.get_the_encoder()
        encoder = enc_dec_obj.encode
        decoder = enc_dec_obj.decode
        return encoder, decoder, len(enc_dec_obj.encoder)

### Model code

In [3]:
import math
import torch
import torch.nn as nn
from torch.nn import functional as F

In [4]:
class NewGELU(nn.Module):
    def forward(self, x):
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

In [5]:
class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config['n_embd'] % config['n_head'] == 0
        
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config['n_embd'], 3 * config['n_embd'])
        
        # output projection
        self.c_proj = nn.Linear(config['n_embd'], config['n_embd'])
        
        # regularization
        self.attn_dropout = nn.Dropout(config['attn_pdrop'])
        self.resid_dropout = nn.Dropout(config['resid_pdrop'])
        
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("bias", torch.tril(torch.ones(config['block_size'], config['block_size']))
                                     .view(1, 1, config['block_size'], config['block_size']))
        self.n_head = config['n_head']
        self.n_embd = config['n_embd']

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k ,v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y

In [6]:
class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config['n_embd'])
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config['n_embd'])
        self.mlp = nn.ModuleDict(dict(
            c_fc    = nn.Linear(config['n_embd'], 4 * config['n_embd']),
            c_proj  = nn.Linear(4 * config['n_embd'], config['n_embd']),
            act     = NewGELU(),
            dropout = nn.Dropout(config['resid_pdrop']),
        ))
        m = self.mlp
        self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x)))) # MLP forward

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlpf(self.ln_2(x))
        return x

In [7]:
class GPT(nn.Module):

    def __init__(self, config):
        super().__init__()

        self.block_size = config['block_size']
        
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config['vocab_size'], config['n_embd']),
            wpe = nn.Embedding(config['block_size'], config['n_embd']),
            drop = nn.Dropout(config['embd_pdrop']),
            h = nn.ModuleList([Block(config) for _ in range(config['n_layer'])]),
            ln_f = nn.LayerNorm(config['n_embd']),
        ))
        
        self.lm_head = nn.Linear(config['n_embd'], config['vocab_size'], bias=False)

        # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config['n_layer']))

        # report number of parameters (note we don't count the decoder parameters in lm_head)
        n_params = sum(p.numel() for p in self.transformer.parameters())
        print("number of parameters: %.2fM" % (n_params/1e6,))

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def configure_optimizers(self, train_config):
        decay = set()
        no_decay = set()
        
        whitelist_weight_modules = (torch.nn.Linear, )
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
                # random note: because named_modules and named_parameters are recursive
                # we will see the same tensors p many many times. but doing it this way
                # allows us to know which parent module any tensor p belongs to...
                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                    % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config['weight_decay']},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        
        optimizer = torch.optim.AdamW(optim_groups, lr=train_config['learning_rate'], betas=train_config['betas'])
        
        return optimizer

    def forward(self, idx, targets=None):
        device = idx.device

        b, t = idx.size()

        assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)


        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)

        
        x = self.transformer.drop(tok_emb + pos_emb)
      
        for block in self.transformer.h:
            x = block(x)

        x = self.transformer.ln_f(x)
        
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)


        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
    
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
            # forward the model to get the logits for the index in the sequence
            logits, _ = self(idx_cond)
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # either sample from the distribution or take the most likely element
            if do_sample:
                idx_next = torch.multinomial(probs, num_samples=1)
            else:
                _, idx_next = torch.topk(probs, k=1, dim=-1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

### Trainer

In [8]:
import time
from collections import defaultdict

import torch
from torch.utils.data.dataloader import DataLoader

class Trainer:
    def __init__(self, config, model, train_dataset, test_dataset):
        self.config = config
        self.model = model
        self.optimizer = None
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.callbacks = defaultdict(list)

        self.device = config['device']
        self.model = self.model.to(self.device)
        print("running on device", self.device)

        # variables that will be assigned to trainer class later for logging and etc
        self.iter_num = 0
        self.iter_time = 0.0
        self.iter_dt = 0.0


    def add_callback(self, onevent: str, callback):
        self.callbacks[onevent].append(callback)

    def set_callback(self, onevent: str, callback):
        self.callbacks[onevent] = [callback]

    def trigger_callbacks(self, onevent: str):
        for callback in self.callbacks.get(onevent, []):
            callback(self)

    def validate(self, val_loader):
        model = self.model
        model.eval()
        total_loss = 0
        count = 0
        val_loader = DataLoader(self.test_dataset, batch_size=batch_size, shuffle=False)
        with torch.no_grad():
            for x, y in val_loader:
                
                x, y = x.to(self.device), y.to(self.device)
                _, loss = model(x, y)
                total_loss += loss.item()
                count += 1

        avg_val_loss = total_loss / count
        # print(f"[Validation] Iter {self.iter_num}: val loss = {avg_val_loss:.4f}")
        model.train()
        return avg_val_loss


    def run(self):
        model, config = self.model, self.config
       
        self.optimizer = model.configure_optimizers(config)
       
        train_loader = DataLoader(self.train_dataset, shuffle=True, batch_size=batch_size)
     
        
        model.train()
        
        self.iter_num = 0

        self.iter_time = time.time()
        
        data_iter = iter(train_loader)
        
        while True:

            try:
                batch = next(data_iter)
                
            except StopIteration:
                data_iter = iter(train_loader)
                batch = next(data_iter)
               
            batch = [t.to(self.device) for t in batch]
            
            x, y = batch

            logits, self.loss = model(x, y)

            model.zero_grad(set_to_none=True)

            self.loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config['grad_norm_clip'])
            
            self.optimizer.step()
            
            self.trigger_callbacks('on_batch_end')
            
            self.iter_num += 1
            
            tnow = time.time()
            self.iter_dt = tnow - self.iter_time
            self.iter_time = tnow

            if self.iter_num % 100 == 0:
                avg_val_loss = self.validate(self.test_dataset)
                print(self.iter_num, f": loss: {self.loss.item()} val loss: {avg_val_loss:.4f}")

            # termination conditions
            if config['max_iters'] is not None and self.iter_num >= config['max_iters']:
                break
              # Optional: Run validation every 500 steps
            if self.iter_num % 100 == 0:
                self.validate(self.test_dataset)

In [9]:
class ShakespeareDataset(torch.utils.data.Dataset):
    def __init__(self, data, block_size):
        self.data = data
        self.block_size = block_size

    def __len__(self):
        return len(self.data) - self.block_size

    def __getitem__(self, idx):
        # x = self.data[idx : idx + self.block_size]
        # y = self.data[idx + 1 : idx + 1 + self.block_size]
        x = torch.tensor(self.data[idx : idx + self.block_size], dtype=torch.long)
        y = torch.tensor(self.data[idx + 1 : idx + 1 + self.block_size], dtype=torch.long)

        return x, y

In [10]:
config = {
    "n_layer": 8,
    "n_head": 16,
    "n_embd": 512,

    
    "vocab_size" : None,
    "block_size" : None,

    
    "embd_pdrop" : 0.1,
    "resid_pdrop" : 0.1,
    "attn_pdrop" : 0.1,

    "device" : 'cuda' if torch.cuda.is_available() else 'cpu',
    "num_workers" : 3,
        
    # optimizer parameters
    "max_iters" : None,
    "batch_size" : 64,
    "learning_rate" : 3e-4,
    "betas" : (0.9, 0.95),
    "weight_decay" : 0.1, # only applied on matmul weights
    "grad_norm_clip" : 1.0
}

In [11]:
set_seed(42)

In [12]:
config['device']

'cpu'

In [13]:
with open("../data/shakespeare.txt", 
          "r", encoding = 'utf-8') as f:
    text  = f.read()

text = text.lower()

In [14]:
bpe = BPE()
encoder = bpe.encoder
decoder = bpe.decoder
vocab_size = bpe.vocab_size

downloading https://openaipublic.blob.core.windows.net/gpt-2/models/124M/encoder.json to ./temp/encoder.json
downloading https://openaipublic.blob.core.windows.net/gpt-2/models/124M/vocab.bpe to ./temp/vocab.bpe


In [15]:
e = bpe.get_the_encoder()

In [16]:
encoded_dataset = encoder(text)

In [17]:
block_size = 128
batch_size = config['batch_size']
dataset = ShakespeareDataset(encoded_dataset, block_size)

In [18]:
from torch.utils.data import random_split

In [19]:
# Calculate split sizes (e.g., 90% train, 10% test)
train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size

In [20]:
# Split the dataset
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

In [21]:
config['vocab_size'] = vocab_size
config['block_size'] = block_size

In [22]:
model = GPT(config)

number of parameters: 51.02M


In [24]:
trainer = Trainer(config, model, train_dataset, test_dataset)

running on device cuda


In [25]:
trainer.run()

100 : loss: 5.036053657531738 val loss: 4.9268
200 : loss: 4.283403396606445 val loss: 4.3206
300 : loss: 3.9661524295806885 val loss: 3.9668
400 : loss: 3.8937549591064453 val loss: 3.7188
500 : loss: 3.461179494857788 val loss: 3.4843
600 : loss: 3.4044785499572754 val loss: 3.2734
700 : loss: 3.0049102306365967 val loss: 3.0334
800 : loss: 2.960927963256836 val loss: 2.7833
900 : loss: 2.708482503890991 val loss: 2.5152
1000 : loss: 2.444145441055298 val loss: 2.2310
1100 : loss: 2.1505088806152344 val loss: 1.9292
1200 : loss: 1.9313178062438965 val loss: 1.6313
1300 : loss: 1.6364212036132812 val loss: 1.3412
1400 : loss: 1.508245587348938 val loss: 1.0886
1500 : loss: 1.2650781869888306 val loss: 0.8704
1600 : loss: 1.0747039318084717 val loss: 0.6844
1700 : loss: 0.9347623586654663 val loss: 0.5423
1800 : loss: 0.7267189621925354 val loss: 0.4402
1900 : loss: 0.6485491394996643 val loss: 0.3674
2000 : loss: 0.6270428895950317 val loss: 0.3145
2100 : loss: 0.5415418148040771 val 

KeyboardInterrupt: 

In [26]:
torch.save(model.state_dict(), '../saved_models/model_shakespeare_new_v5_latest.pth')

In [27]:
import pickle

# Let's say your encoder/tokenizer is in a variable called `encoder`
with open("../saved_models/encoder_shakespeare_v5.pkl", "wb") as f:
    pickle.dump(bpe, f)

In [13]:
import pickle

In [None]:
pickle._