In [1]:
!pip install datasets



In [27]:
import time
import torch
import torch.nn as nn
import numpy as np
import random
from torch import optim
import matplotlib.pyplot as plt
from typing import List

from torch.utils.data import Dataset, DataLoader, RandomSampler
import tqdm
from scipy.stats import ttest_ind
# from bus_transformer import *
from datasets import load_dataset
from transformers import AutoTokenizer
from collections import defaultdict
import tensorflow as tf

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
seq_len = 256
batch_size = 32
print(DEVICE)

2024-11-01 11:20:57.361085: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-11-01 11:20:57.458873: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-11-01 11:20:57.486879: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-01 11:20:57.664954: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


cuda


In [3]:
!nvidia-smi

Fri Nov  1 10:48:57 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 3080        Off | 00000000:2D:00.0  On |                  N/A |
|  0%   52C    P8              38W / 320W |   1072MiB / 10240MiB |     10%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [4]:
from graphviz import Digraph
from tqdm.notebook import trange, tqdm
from torch.autograd import Variable

# make_dot was moved to https://github.com/szagoruyko/pytorchviz
# from torchviz import make_dot

- limit sequences to 128
- limit tasks to sentence classification
- use single sequence training without NSP
-


In [5]:
class AttentionHead(nn.Module):
    def __init__(self, d_model, d_internal):
        super().__init__()

        self.W_Q = torch.nn.Linear(d_model, d_internal, False)
        self.W_K = torch.nn.Linear(d_model, d_internal, False)
        self.W_V = torch.nn.Linear(d_model, d_internal, False)

        self.SoftMax = torch.nn.Softmax(dim=-1)


        self.d_model = d_model
        self.d_internal = d_internal
        self.norm = torch.tensor(d_model**-0.5)
        self.tril = torch.tril(torch.ones(seq_len, seq_len, device=DEVICE))

    def expand(self, d_mnew, d_inew):

        self.W_Q.weight.data = torch.cat([self.W_Q.weight.data, torch.zeros(d_inew - self.d_internal, self.d_model, device=DEVICE)], dim=0)
        self.W_Q.weight.data = torch.cat([self.W_Q.weight.data, torch.zeros(d_inew, d_mnew - self.d_model, device=DEVICE)], dim=1)
        for i in range(self.d_internal, d_inew):
            self.W_Q.weight.data[i][i] = self.W_Q.weight.data[i][i] if self.W_Q.weight.data[i][i] != 0 else 1

        self.W_K.weight.data = torch.cat([self.W_K.weight.data, torch.zeros(d_inew - self.d_internal, self.d_model, device=DEVICE)], dim=0)
        self.W_K.weight.data = torch.cat([self.W_K.weight.data, torch.zeros(d_inew, d_mnew - self.d_model, device=DEVICE)], dim=1)
        for i in range(self.d_internal, d_inew):
            self.W_K.weight.data[i][i] = self.W_K.weight.data[i][i] if self.W_K.weight.data[i][i] != 0 else 1

        self.W_V.weight.data = torch.cat([self.W_V.weight.data, torch.zeros(d_inew - self.d_internal, self.d_model, device=DEVICE)], dim=0)
        self.W_V.weight.data = torch.cat([self.W_V.weight.data, torch.zeros(d_inew, d_mnew - self.d_model, device=DEVICE)], dim=1)
        for i in range(self.d_internal, d_inew):
            self.W_V.weight.data[i][i] = self.W_V.weight.data[i][i] if self.W_V.weight.data[i][i] != 0 else 1

        self.d_internal = d_inew
        self.d_model = d_mnew
        self.SoftMax = torch.nn.Softmax(dim=-1)
        self.tril = torch.tril(torch.ones(seq_len, seq_len, device=DEVICE))



    def forward(self, input_vecs):
        B, T, C = input_vecs.shape

        Q = self.W_Q(input_vecs)
        K = self.W_K(input_vecs)
        V = self.W_V(input_vecs)

        weights = Q @ K.transpose(-2, -1) * C**-0.5
        weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        Attn = self.SoftMax(weights)


        out = Attn @ V

        return out

In [6]:
class Transformer(nn.Module):
    def __init__(self, d_model, vocab_size, num_heads, d_hidden):
        super().__init__()
        self.d_model = d_model
        self.d_internal = d_model//num_heads
        self.num_heads = num_heads
        self.vocab_size = vocab_size
        self.d_hidden = d_hidden

        self.heads = nn.ModuleList([AttentionHead(d_model, self.d_internal) for _ in range(num_heads)])
        self.Softmax = torch.nn.LogSoftmax(dim=-1)
        self.FFN = torch.nn.Sequential(
            torch.nn.Linear(self.d_model, self.d_hidden),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(self.d_hidden, self.d_model),
        )
        self.W_O = torch.nn.Linear(d_model, d_model, False)
        self.layernorm = torch.nn.LayerNorm(d_model)



    def forward(self, x):
        """
        :param x: input embeddings
        :return: output of decoder block, same shape as input
        """
        t = x
        t = torch.cat([head(t) for head in self.heads], dim=-1)
        t = self.W_O(t)
        t1 = self.layernorm(t + x)
        # t = self.relu(self.cout(self.FFN(self.connection(t1))))
        t = self.FFN(t1)
        t = self.layernorm(t + t1)

        return t



    def expand(self, d_mnew, d_inew):

        # self.connection = torch.nn.Linear(d_mnew, self.d_hidden)
        # self.cout = torch.nn.Linear(self.d_hidden, d_mnew)

        self.FFN = torch.nn.Sequential(
            torch.nn.Linear(d_mnew, self.d_hidden),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(self.d_hidden, d_mnew),
        )
        self.W_O.weight.data = torch.cat([self.W_O.weight.data, torch.zeros(d_mnew-self.d_model, self.d_model, device=DEVICE)], dim=0)
        self.W_O.weight.data = torch.cat([self.W_O.weight.data, torch.zeros(d_mnew, d_mnew-self.d_model,  device=DEVICE)], dim=1)
        self.layernorm = torch.nn.LayerNorm(d_mnew)
        for i in range(self.d_model+1, d_mnew):
            self.W_O.weight.data[i][i] = 1

        for head in self.heads:
            head.expand(d_mnew, d_inew)

        self.Softmax = torch.nn.LogSoftmax(dim=-1)
        self.d_model = d_mnew
        self.d_internal = d_inew
        self.to(DEVICE)




In [7]:
class Decoder(nn.Module):
    def __init__(self, num_blocks, d_model, d_hidden, vocab_size, num_heads):
        super().__init__()
        self.num_blocks = num_blocks
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.num_heads = num_heads
        self.SoftMax = torch.nn.LogSoftmax(dim=-1)
        self.blocks = torch.nn.ModuleList([Transformer(d_model, vocab_size, num_heads, d_hidden) for _ in range(num_blocks)])
        self.d_hidden = d_hidden

        # self.connection = torch.nn.Linear(d_model, d_hidden)
        self.FFN = torch.nn.Sequential(
            torch.nn.Linear(d_model, vocab_size),
            torch.nn.LogSoftmax(dim=-1),
        )
        self.dout = torch.nn.Dropout(0.1)

        self.embeddings = torch.nn.Embedding(vocab_size, d_model, device=DEVICE)
        self.pos_embedding = None
        # self.pos_embedding = torch.nn.Embedding(seq_len, d_model, device=DEVICE)
        self.generate_pos_embed(d_model)


    def forward(self, x):
        x = self.embeddings(x) + self.pos_embedding(torch.arange(x.shape[-1], device=DEVICE))
        x = self.dout(x)
        t = x
        for head in self.blocks:
            t = head(t) + t

        ret = self.FFN(t)

        return ret

    def generate_pos_embed(self, d_model):
        # TODO: make more efficient
        pos_em = torch.zeros((seq_len, d_model))
        for pos in range(seq_len):
            for i in range(d_model):
                if i % 2 == 0:
                    pos_em[pos][i] += torch.sin(torch.tensor(pos/(10000**(2*i/d_model))))
                else:
                    pos_em[pos][i] += torch.cos(torch.tensor(pos/(10000** (2*i/d_model))))

        self.pos_embedding = torch.nn.Embedding.from_pretrained(pos_em, freeze=True)




    def expand(self, d_mnew):
        d_inew = d_mnew // self.num_heads
        self.FFN = torch.nn.Sequential(
            torch.nn.Linear(d_mnew, self.vocab_size),
            torch.nn.LogSoftmax(dim=-1),
        )

        self.layernorm = torch.nn.LayerNorm(d_mnew, device=DEVICE)
        for block in self.blocks:
            block.expand(d_mnew, d_inew)

        self.embeddings = torch.nn.Embedding.from_pretrained(torch.cat([self.embeddings.weight, torch.zeros(self.vocab_size, d_mnew-self.d_model, device=DEVICE).uniform_()], dim=1))
        # self.pos_embedding = torch.nn.Embedding.from_pretrained(torch.cat([self.pos_embedding.weight, torch.zeros(seq_len, d_mnew-self.d_model, device=DEVICE).uniform_()], dim=1))
        self.generate_pos_embed(d_mnew)
        # self.embeddings = torch.nn.Embedding(self.vocab_size, d_mnew)
        # self.pos_embedding = torch.nn.Embedding(seq_len, d_mnew)

        self.d_model = d_mnew
        self.d_internal = d_inew
        self.to(DEVICE)

In [8]:
model = Decoder(num_blocks=12, d_model=768, d_hidden=768*4, vocab_size=50257, num_heads=12)
model.to(DEVICE)
print(sum(p.numel() for p in model.parameters())/1e6, "M parameters")

162.440785 M parameters


In [9]:
data = load_dataset('Salesforce/wikitext', 'wikitext-103-raw-v1')
# data = load_dataset('tiny_shakespeare')
train = data['train']
validation = data['validation']
test = data['test']

In [10]:
train.column_names

['text']

In [11]:
train['text'][0:100]

['',
 ' = Valkyria Chronicles III = \n',
 '',
 ' Senjō no Valkyria 3 : Unrecorded Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . Released in January 2011 in Japan , it is the third game in the Valkyria series . Employing the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the " Nameless " , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " Calamaty Raven " . \n',
 " The game began development in 2010 , carrying over a large portion of the work done on Valkyria Chronicles II . While it retained the standard features of the series , it also underwent multiple adjustments , such as making the game more f

In [12]:
train_join = ' '.join(train['text'])
val_join = " ".join(validation['text'])
test_join = " ".join(test['text'])

In [13]:
train_join[0:100]

'  = Valkyria Chronicles III = \n   Senjō no Valkyria 3 : Unrecorded Chronicles ( Japanese : 戦場のヴァルキュリ'

In [14]:
import re
train_join = re.sub(r'\n', "", train_join)
val_join = re.sub(r'\n', "", val_join)
test_join = re.sub(r'\n', "", test_join)

In [15]:
train_join[:1000]

'  = Valkyria Chronicles III =    Senjō no Valkyria 3 : Unrecorded Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . Released in January 2011 in Japan , it is the third game in the Valkyria series . Employing the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the " Nameless " , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " Calamaty Raven " .   The game began development in 2010 , carrying over a large portion of the work done on Valkyria Chronicles II . While it retained the standard features of the series , it also underwent multiple adjustments , such as making the game more forgiving for series n

In [17]:
train_join = re.sub(r'[^a-zA-Z0-9\s]+', '', train_join)
val_join = re.sub(r'[^a-zA-Z0-9\s]+', '', val_join)
test_join = re.sub(r'[^a-zA-Z0-9\s]+', '', test_join)

In [18]:
train_join[:1000]

'   Valkyria Chronicles III     Senj no Valkyria 3  Unrecorded Chronicles  Japanese  3  lit  Valkyria of the Battlefield 3   commonly referred to as Valkyria Chronicles III outside Japan  is a tactical role  playing video game developed by Sega and MediaVision for the PlayStation Portable  Released in January 2011 in Japan  it is the third game in the Valkyria series  Employing the same fusion of tactical and real  time gameplay as its predecessors  the story runs parallel to the first game and follows the  Nameless   a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit  Calamaty Raven     The game began development in 2010  carrying over a large portion of the work done on Valkyria Chronicles II  While it retained the standard features of the series  it also underwent multiple adjustments  such as making the game more forgiving for series newcomers  Character designer Raita Honjou 

In [19]:
train_join = re.sub(r'\s+', " ",  train_join)
val_join = re.sub(r'\s+', " ",  val_join)
test_join = re.sub(r'\s+', " ",  test_join)

In [20]:
train_join[:1000]

' Valkyria Chronicles III Senj no Valkyria 3 Unrecorded Chronicles Japanese 3 lit Valkyria of the Battlefield 3 commonly referred to as Valkyria Chronicles III outside Japan is a tactical role playing video game developed by Sega and MediaVision for the PlayStation Portable Released in January 2011 in Japan it is the third game in the Valkyria series Employing the same fusion of tactical and real time gameplay as its predecessors the story runs parallel to the first game and follows the Nameless a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit Calamaty Raven The game began development in 2010 carrying over a large portion of the work done on Valkyria Chronicles II While it retained the standard features of the series it also underwent multiple adjustments such as making the game more forgiving for series newcomers Character designer Raita Honjou and composer Hitoshi Sakimoto bot

### BPE Tokenization

In [21]:
tokenizer = AutoTokenizer.from_pretrained('gpt2')



tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [25]:
word_freqs = defaultdict(int)

for text in train_join:
    words_with_offset = tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text)

    new_words = [word for word, offset in words_with_offset]
    for word in new_words:
        word_freqs[word] += 1

print(word_freqs)

defaultdict(<class 'int'>, {'Ġ': 86446761, 'V': 259051, 'a': 33724387, 'l': 16223446, 'k': 2511620, 'y': 6026960, 'r': 26331592, 'i': 28992651, 'C': 1342412, 'h': 19044643, 'o': 28818255, 'n': 29289422, 'c': 12135913, 'e': 49091225, 's': 25200925, 'I': 1012318, 'S': 1630502, 'j': 382822, '3': 546727, 'U': 316120, 'd': 16030336, 'J': 508138, 'p': 7513193, 't': 34095601, 'f': 8448546, 'B': 1061661, 'm': 9441795, 'u': 10275908, 'g': 7719415, 'v': 3996501, 'b': 5436369, 'M': 1117297, 'P': 795441, 'R': 738484, '2': 1080240, '0': 1620524, '1': 1803724, 'E': 532202, 'w': 6528108, 'N': 612081, 'G': 610946, 'W': 648356, 'T': 1724329, 'H': 852063, 'O': 471124, 'z': 420398, 'A': 1542090, 'x': 751270, 'D': 737617, '4': 500512, 'q': 342605, 'Z': 56138, 'K': 351189, 'F': 660293, 'L': 629185, '7': 447199, 'Y': 146648, '9': 948062, '5': 543627, 'X': 39174, 'Q': 47174, '8': 538930, '6': 451024})


In [26]:
alphabet = []

for word in word_freqs.keys():
    for letter in word: 
        if letter not in alphabet:
            alphabet.append(letter)

alphabet.sort()

print(alphabet)

['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', 'Ġ']


In [28]:
vocab = ["<|endoftext|>"] + alphabet.copy()

splits = {word: [c for c in word] for word in word_freqs.keys()}

In [29]:
def compute_pair_freqs(splits):
    pair_freqs = defaultdict(int)
    for word, freq in word_freqs.items():
        split = splits[word]
        if len(split) == 1:
            continue
        for i in range(len(split) - 1):
            pair = (split[i], split[i+1])
            pair_freqs[pair] += freq

    return pair_freqs

In [30]:
pair_freqs = compute_pair_freqs(splits)

In [31]:
def merge_pair(a, b, splits):
    for word in word_freqs:
        split = splits[word]
        if len(split) == 1:
            continue

        i = 0 
        while i < len(split) -1:
            if split[i] == a and split[i+1] == b:
                split = split[:i] + [a+b] + split[i+2:]
            else:
                i += 1

        splits[word] = split

    return splits

In [None]:
vocab_size = 1000
merges = {}

while len(vocab) < vocab_size:
    pair_freqs = compute_pair_freqs(splits)
    best_pair = ""
    max_freq = None
    for pair, freq in pair_freqs.items():
        if max_freq is None or max_freq < freq:
            best_pair = pair
            max_freq = freq
    splits = merge_pair(*best_pair, splits)
    merges[best_pair] = best_pair[0] + best_pair[1]
    vocab.append(best_pair[0] + best_pair[1])

TypeError: merge_pair() missing 2 required positional arguments: 'b' and 'splits'

In [None]:
print(merges)

In [37]:
def tokenize(text):
    pre_tokenize_result = tokenizer._tokenizer.pre_tokenizer.pre_tokenize_str(text)
    pre_tokenize_text = [word for word, offset in pre_tokenize_result]
    splits = [[l for l in word] for word in pre_tokenize_text]
    for pair, merge in merges.items():
        for idx, split in enumerate(splits):
            i = 0
            while i < len(split) -1:
                if split[i] == pair[0] and split[i+1] == pair:
                    split = split[:i] + [merge] + split[i+2:]
                else:
                    i += 1
    return sum(splits, [])

In [38]:
print(vocab)

['<|endoftext|>', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', 'Ġ']


In [None]:
train_tok = tokenize(train_join)

In [None]:
print(train_tok[:200])

In [None]:
val_tok = tokenize(val_join)
test_tok = tokenize(test_join)

In [20]:
chars = sorted(set(next(iter([train_join]))))
len(chars)


63

In [21]:
chars.append('@')   # @ = start of sen token

In [22]:
print(chars)

[' ', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '@']


In [23]:
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

In [26]:
n = len(train_join)
t_len = int(n*0.8)

In [27]:
test[:100]

'ientists and his writings and lectures played an important part in the formulation of public attitud'

In [28]:
train_data = encode(train)
train_data[:10]

[0, 32, 37, 48, 47, 61, 54, 45, 37, 0]

In [29]:
val_data = encode(val_join)
test_data = encode(test_join)

In [39]:
encode('@')

[63]

In [43]:
def batch(s):
    if s == 'train':
        data = train_data 
    elif s == 'val':
        data = val_data
    elif s == 'test':
        data = test_data
    ix = torch.randint(len(data) - seq_len, (batch_size,))
    x = torch.stack([torch.tensor([63]+data[i:i+seq_len-1], device=DEVICE) for i in ix])
    y = torch.stack([torch.tensor(data[i:i+seq_len], device=DEVICE) for i in ix])
    return x, y

In [44]:
xb, yb = batch('train')
xb

tensor([[63, 51,  0,  ..., 37, 56, 41],
        [63,  0, 54,  ..., 37, 40, 57],
        [63, 56, 54,  ..., 41, 43,  0],
        ...,
        [63, 29, 44,  ..., 50, 40,  0],
        [63, 41, 54,  ..., 54, 41, 40],
        [63, 51,  0,  ..., 40,  0, 56]], device='cuda:0')

In [45]:
yb

tensor([[51,  0, 55,  ..., 56, 41,  0],
        [ 0, 54, 41,  ..., 40, 57, 48],
        [56, 54, 45,  ..., 43,  0, 21],
        ...,
        [29, 44, 37,  ..., 40,  0, 18],
        [41, 54, 45,  ..., 41, 40,  0],
        [51,  0, 38,  ...,  0, 56, 44]], device='cuda:0')

In [60]:
max_iters = 100000
eval_interval = 1000
eval_iters = 200
test_iters = 1000

In [47]:
@torch.no_grad()
def estimate_loss(s=['train', 'val']):
    out = {}
    model.eval()
    for split in s:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = batch(split)
            logits= model(X)
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = Y.view(B*T)
            loss = torch.nn.functional.cross_entropy(logits, targets)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out


In [48]:
def acc(model):
    model.eval()
    a = []

    for k in range(test_iters):
        X, Y = batch("test")
        logits = torch.argmax(model(X), dim=-1)
        x = [[0. if logits[sample][i] != Y[sample][i] else 1. for i in range(seq_len)] for sample in range(batch_size)]
        a.append(x)
    return torch.mean(torch.tensor(a))




In [None]:
def train(model, lr=1e-3, min_lr=1e-6, max_it=max_iters):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iters, min_lr)
    writer = tf.summary.create_file_writer('/logs', flush_millis=100000)
    for iter in tqdm.tqdm(range(max_it)):

        # every once in a while evaluate the loss on train and val sets
        # if iter % eval_interval == 0 or iter == max_iters - 1:
        #     losses = estimate_loss()
        #     print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, lr {optimizer.param_groups[0]['lr']}")

        # sample a batch of data
        xb, yb = batch('train')

        # evaluate the loss
        logits = model(xb)
        B, T, C = logits.shape
        logits = logits.view(B*T, C)
        targets = yb.view(B*T)
        loss = torch.nn.functional.cross_entropy(logits, targets)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        scheduler.step()

        tf.summary.scalar('Loss', loss)

    writer.add_summary('summary_ops', global_step=max_it)
    
    writer.close()


In [65]:
model = Decoder(num_blocks=4, d_model=128, vocab_size=len(chars), num_heads=4, d_hidden=64*4)
model.to(DEVICE)
print(sum(p.numel() for p in model.parameters())/1e6, "M parameters")

0.576064 M parameters


In [66]:
train(model)

  1%|          | 1212/100000 [00:35<48:04, 34.24it/s]


KeyboardInterrupt: 

In [52]:
def eval(model):
    losses = estimate_loss(['train', 'val', 'test'])

    print(f"step {iter}:\t train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, test loss {losses['test']:.4f}")

In [53]:
eval(model)

step <built-in function iter>:	 train loss 1.8413, val loss 1.8377, test loss 1.8437


In [56]:
len(train_data)/seq_len

1569197.71484375

In [None]:
def train_transfer(model, transfer_step=900, target_size=1024, lr=1e-3, min_lr=1e-6):
    loss_func = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iters, min_lr)
    for iter in tqdm(range(1, max_iters)):

        # every once in a while evaluate the loss on train and val sets
        # if iter % eval_interval == 0 or iter == max_iters - 1 or iter == 1:
        #     losses = estimate_loss()
        #     print(f"step {iter}:\t train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, lr {optimizer.param_groups[0]['lr']}")

        if iter == transfer_step:
        # if iter <= 1000 and iter % 500 == 0:
            eval(model)
            model.expand(target_size)
            optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
            # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 100, 0.5)
            print('at step {}: expanded model to: {} M parameters'.format(iter, sum(p.numel() for p in model.parameters())/1e6))
            model.to('cpu')
            model.to(DEVICE)    # Shortcut to recompile gradient backprop since the model changed sizes
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iters-transfer_step, min_lr)
            loss_func = torch.nn.CrossEntropyLoss()
        # sample a batch of data
        xb, yb = batch('train')

        # evaluate the loss
        logits = model(xb)
        B, T, C = logits.shape
        logits = logits.view(B*T, C)
        targets = yb.view(B*T)
        loss = loss_func(logits, targets)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        scheduler.step()

In [None]:
model = Decoder(num_blocks=4, d_model=64, vocab_size=len(chars), num_heads=4, d_hidden=64*4)
model.to(DEVICE)
print(sum(p.numel() for p in model.parameters())/1e6, "M parameters")

0.214977 M parameters


In [None]:
def train_transfer_gradual(model, transfer_step=600, final_size=128, start_size=64, final_bus_step=1200,  lr=1e-3, min_lr=1e-5):
    loss_func = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 1000, min_lr)
    step = final_bus_step // transfer_step
    step_size = (final_size-start_size)//step
    for iter in tqdm(range(1, max_iters)):

        # every once in a while evaluate the loss on train and val sets
        # if iter % eval_interval == 0 or iter == max_iters - 1 or iter == 1:
        #     losses = estimate_loss()
        #     print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

        # if iter <= 1000 and iter % 500 == 0:
        if iter % transfer_step == 0 and iter <= final_bus_step:
            start_size += step_size
            model.expand(start_size)
            optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
            # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 100, 0.5)
            print('at step {}: expanded model to: {} M parameters\tmodel_size: {}'.format(iter, sum(p.numel() for p in model.parameters())/1e6, start_size))
            model.to('cpu')
            model.to(DEVICE)    # Shortcut to recompile gradient backprop since the model changed sizes

            loss_func = torch.nn.CrossEntropyLoss()
        # sample a batch of data
        xb, yb = batch('train')

        # evaluate the loss
        logits = model(xb)
        B, T, C = logits.shape
        logits = logits.view(B*T, C)
        targets = yb.view(B*T)
        loss = loss_func(logits, targets)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        scheduler.step()

# WORK BENCH

### Proof of concept


In [None]:
model = Decoder(num_blocks=4, d_model=384, vocab_size=len(chars), num_heads=8, d_hidden=512*4)
model.to(DEVICE)
print(sum(p.numel() for p in model.parameters())/1e6, "M parameters")
train_transfer(model, transfer_step=800, target_size=512, lr=1e-3, min_lr=1e-4)
eval(model)

8.762689 M parameters


  0%|          | 0/4999 [00:00<?, ?it/s]

at step 800: expanded model to: 12.730433 M parameters
step <built-in function iter>:	 train loss 1.3517, val loss 1.5096, test loss 1.5172


In [None]:

model = Decoder(num_blocks=4, d_model=512, vocab_size=len(chars), num_heads=8, d_hidden=512*4)
model.to(DEVICE)
print(sum(p.numel() for p in model.parameters())/1e6, "M parameters")
train(model, lr=1e-3, min_lr=1e-4)
eval(model)

12.729409 M parameters


  0%|          | 0/5000 [00:00<?, ?it/s]

step <built-in function iter>:	 train loss 1.3611, val loss 1.5403, test loss 1.5416


In [None]:
model = Decoder(num_blocks=4, d_model=512, vocab_size=len(chars), num_heads=8, d_hidden=512*4)
model.to(DEVICE)
print(sum(p.numel() for p in model.parameters())/1e6, "M parameters")
train(model, lr=1e-3, min_lr=1e-4, max_it=4200)
eval(model)

12.729409 M parameters


  0%|          | 0/4200 [00:00<?, ?it/s]

step <built-in function iter>:	 train loss 1.3972, val loss 1.5633, test loss 1.5603


In [None]:
model = Decoder(num_blocks=12, d_model=384, vocab_size=len(chars), num_heads=8, d_hidden=512*4)
model.to(DEVICE)
print(sum(p.numel() for p in model.parameters())/1e6, "M parameters")
train_transfer(model, transfer_step=800, target_size=512, lr=1e-3, min_lr=1e-4)
eval(model)

26.089793 M parameters


  0%|          | 0/4999 [00:00<?, ?it/s]

at step 800: expanded model to: 37.924929 M parameters
step <built-in function iter>:	 train loss 1.3910, val loss 1.5503, test loss 1.5542


In [None]:

model = Decoder(num_blocks=12, d_model=512, vocab_size=len(chars), num_heads=8, d_hidden=512*4)
model.to(DEVICE)
print(sum(p.numel() for p in model.parameters())/1e6, "M parameters")
train(model, lr=1e-3, min_lr=1e-4)
eval(model)

37.923905 M parameters


  0%|          | 0/5000 [00:00<?, ?it/s]

step <built-in function iter>:	 train loss 1.3977, val loss 1.5536, test loss 1.5493


In [None]:
model = Decoder(num_blocks=4, d_model=512, vocab_size=len(chars), num_heads=8, d_hidden=512*4)
model.to(DEVICE)
print(sum(p.numel() for p in model.parameters())/1e6, "M parameters")
train(model, lr=1e-3, min_lr=1e-4, max_it=5000)
eval(model)

12.729409 M parameters


  0%|          | 0/5000 [00:00<?, ?it/s]

step <built-in function iter>:	 train loss 1.3588, val loss 1.5308, test loss 1.5267


In [None]:
train(model, lr=5e-4, min_lr=1e-4, max_it=5000)
eval(model)

  0%|          | 0/5000 [00:00<?, ?it/s]

step <built-in function iter>:	 train loss 1.2162, val loss 1.4678, test loss 1.4707


In [None]:
train(model, lr=1e-4, min_lr=1e-4, max_it=5000)
eval(model)

  0%|          | 0/5000 [00:00<?, ?it/s]

step <built-in function iter>:	 train loss 1.1403, val loss 1.4538, test loss 1.4568


In [None]:
train(model, lr=1e-4, min_lr=1e-4, max_it=5000)
eval(model)

  0%|          | 0/5000 [00:00<?, ?it/s]

step <built-in function iter>:	 train loss 1.0699, val loss 1.4872, test loss 1.4800


In [None]:
train(model, lr=1e-4, min_lr=1e-4, max_it=5000)
eval(model)

  0%|          | 0/5000 [00:00<?, ?it/s]

step <built-in function iter>:	 train loss 0.9942, val loss 1.5202, test loss 1.5071


In [None]:
train(model, lr=1e-4, min_lr=1e-4, max_it=5000)
eval(model)

  0%|          | 0/5000 [00:00<?, ?it/s]

step <built-in function iter>:	 train loss 0.9154, val loss 1.5392, test loss 1.5573


In [None]:
train(model, lr=5e-4, min_lr=1e-5, max_it=5000)
eval(model)

  0%|          | 0/5000 [00:00<?, ?it/s]

step <built-in function iter>:	 train loss 0.8550, val loss 1.5745, test loss 1.5831


long training with a transfer first

In [None]:
model = Decoder(num_blocks=4, d_model=384, vocab_size=len(chars), num_heads=8, d_hidden=512*4)
model.to(DEVICE)
print(sum(p.numel() for p in model.parameters())/1e6, "M parameters")
train_transfer(model, transfer_step=4500, target_size=512, lr=1e-3, min_lr=1e-4)
eval(model)

8.762689 M parameters


  0%|          | 0/4999 [00:00<?, ?it/s]

step <built-in function iter>:	 train loss 1.3308, val loss 1.4914, test loss 1.4965
at step 4500: expanded model to: 12.730433 M parameters
step <built-in function iter>:	 train loss 1.7626, val loss 1.8929, test loss 1.8901


In [None]:
train(model, lr=5e-4, min_lr=1e-4)
eval(model)

  0%|          | 0/5000 [00:00<?, ?it/s]

step <built-in function iter>:	 train loss 1.2705, val loss 1.4714, test loss 1.4686


In [None]:
train(model, lr=1e-4, min_lr=1e-4, max_it=5000)
eval(model)

  0%|          | 0/5000 [00:00<?, ?it/s]

step <built-in function iter>:	 train loss 1.1942, val loss 1.4544, test loss 1.4493


In [None]:
train(model, lr=1e-4, min_lr=1e-4, max_it=5000)
eval(model)

  0%|          | 0/5000 [00:00<?, ?it/s]

step <built-in function iter>:	 train loss 1.1257, val loss 1.4656, test loss 1.4595


In [None]:
train(model, lr=1e-4, min_lr=1e-4, max_it=5000)
eval(model)

  0%|          | 0/5000 [00:00<?, ?it/s]

step <built-in function iter>:	 train loss 1.0627, val loss 1.4846, test loss 1.4792


In [None]:
train(model, lr=1e-4, min_lr=1e-4, max_it=5000)
eval(model)

  0%|          | 0/5000 [00:00<?, ?it/s]

step <built-in function iter>:	 train loss 0.9944, val loss 1.5142, test loss 1.5011


In [None]:
train(model, lr=1e-4, min_lr=5e-5, max_it=5000)
eval(model)

  0%|          | 0/5000 [00:00<?, ?it/s]

step <built-in function iter>:	 train loss 0.9002, val loss 1.5354, test loss 1.5442


In [None]:
# GPT-3-small model params ~125M params
model = Decoder(num_blocks=12, d_model=768, vocab_size=50257, num_heads=12, d_hidden=512*4)
model.to(DEVICE)
print(sum(p.numel() for p in model.parameters())/1e6, "M parameters")

143.455825 M parameters


In [None]:
model

Decoder(
  (SoftMax): LogSoftmax(dim=-1)
  (blocks): ModuleList(
    (0-3): 4 x Transformer(
      (heads): ModuleList(
        (0-7): 8 x AttentionHead(
          (W_Q): Linear(in_features=512, out_features=64, bias=False)
          (W_K): Linear(in_features=512, out_features=64, bias=False)
          (W_V): Linear(in_features=512, out_features=64, bias=False)
          (SoftMax): Softmax(dim=-1)
        )
      )
      (Softmax): LogSoftmax(dim=-1)
      (FFN): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.1, inplace=False)
        (2): Linear(in_features=1024, out_features=1024, bias=True)
        (3): ReLU()
        (4): Dropout(p=0.1, inplace=False)
        (5): Linear(in_features=1024, out_features=1024, bias=True)
        (6): ReLU()
        (7): Dropout(p=0.1, inplace=False)
      )
      (W_O): Linear(in_features=512, out_features=512, bias=False)
      (layernorm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (layernorm2): LayerNorm((512,), eps=1e-0