In [3]:
import time
import torch
import torch.nn as nn
import numpy as np
import random
from torch import optim
import matplotlib.pyplot as plt
from typing import List
from utils import *
from torch.utils.data import Dataset, DataLoader, RandomSampler
import tqdm
from scipy.stats import ttest_ind
# from bus_transformer import *
from datasets import load_dataset
from transformers import AutoTokenizer

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
seq_len = 128 
batch_size = 16 
print(DEVICE)

ModuleNotFoundError: No module named 'pyarrow'

In [5]:
max_iters = 5000
eval_interval = 100
eval_iters = 200
test_iters = 1000
vocab_size = 9128

In [3]:
class AttentionHead(nn.Module):
    def __init__(self, d_model, d_internal):
        super().__init__()

        self.W_Q = torch.nn.Linear(d_model, d_internal, False)
        self.W_K = torch.nn.Linear(d_model, d_internal, False)
        self.W_V = torch.nn.Linear(d_model, d_internal, False)

        self.SoftMax = torch.nn.Softmax(dim=-1)


        self.d_model = d_model
        self.d_internal = d_internal
        self.norm = torch.tensor(d_model**-0.5)
        self.tril = torch.tril(torch.ones(seq_len, seq_len, device=DEVICE))

    def expand(self, d_mnew, d_inew):

        self.W_Q.weight.data = torch.cat([self.W_Q.weight.data, torch.zeros(d_inew - self.d_internal, self.d_model, device=DEVICE)], dim=0)
        self.W_Q.weight.data = torch.cat([self.W_Q.weight.data, torch.zeros(d_inew, d_mnew - self.d_model, device=DEVICE)], dim=1)
        for i in range(self.d_internal, d_inew):
            self.W_Q.weight.data[i][i] = self.W_Q.weight.data[i][i] if self.W_Q.weight.data[i][i] != 0 else 1

        self.W_K.weight.data = torch.cat([self.W_K.weight.data, torch.zeros(d_inew - self.d_internal, self.d_model, device=DEVICE)], dim=0)
        self.W_K.weight.data = torch.cat([self.W_K.weight.data, torch.zeros(d_inew, d_mnew - self.d_model, device=DEVICE)], dim=1)
        for i in range(self.d_internal, d_inew):
            self.W_K.weight.data[i][i] = self.W_K.weight.data[i][i] if self.W_K.weight.data[i][i] != 0 else 1

        self.W_V.weight.data = torch.cat([self.W_V.weight.data, torch.zeros(d_inew - self.d_internal, self.d_model, device=DEVICE)], dim=0)
        self.W_V.weight.data = torch.cat([self.W_V.weight.data, torch.zeros(d_inew, d_mnew - self.d_model, device=DEVICE)], dim=1)
        for i in range(self.d_internal, d_inew):
            self.W_V.weight.data[i][i] = self.W_V.weight.data[i][i] if self.W_V.weight.data[i][i] != 0 else 1

        self.d_internal = d_inew
        self.d_model = d_mnew 
        self.SoftMax = torch.nn.Softmax(dim=-1)



    def forward(self, input_vecs):
        B, T, C = input_vecs.shape

        Q = self.W_Q(input_vecs)
        K = self.W_K(input_vecs)
        V = self.W_V(input_vecs)

        weights = Q @ K.transpose(-2, -1) * C**-0.5
        weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        Attn = self.SoftMax(weights)


        out = Attn @ V

        return out

In [5]:
class Transformer(nn.Module):
    def __init__(self, d_model, num_heads, d_hidden):
        super().__init__()
        self.d_model = d_model
        self.d_internal = d_model//num_heads
        self.num_heads = num_heads
        self.d_hidden = d_hidden

        self.heads = nn.ModuleList([AttentionHead(d_model, self.d_internal) for _ in range(num_heads)])
        self.Softmax = torch.nn.LogSoftmax(dim=-1)
        self.FFN = torch.nn.Sequential(
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(self.d_hidden, self.d_hidden),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1),
        )
        self.W_O = torch.nn.Linear(d_model, d_model, False)
        self.layernorm1 = torch.nn.LayerNorm(d_model)
        self.layernorm2 = torch.nn.LayerNorm(d_model)

        self.connection = torch.nn.Linear(d_model, self.d_hidden)
        self.cout = torch.nn.Linear(self.d_hidden, d_model)
        self.relu = torch.nn.ReLU()


    def forward(self, x):
        """
        :param x: input embeddings 
        :return: output of decoder block, same shape as input
        """
        t = self.layernorm1(x)
        t = torch.cat([head(t) for head in self.heads], dim=-1)
        t = self.W_O(t)
        t1 = self.layernorm2(t + x)
        t = self.relu(self.cout(self.FFN(self.connection(t1))))
        # t = self.layernorm2(t + t1)

        return t



    def expand(self, d_mnew, d_inew):

        self.connection = torch.nn.Linear(d_mnew, self.d_hidden)
        self.cout = torch.nn.Linear(self.d_hidden, d_mnew)

        self.W_O.weight.data = torch.cat([self.W_O.weight.data, torch.zeros(d_mnew-self.d_model, self.d_model, device=DEVICE)], dim=0)
        self.W_O.weight.data = torch.cat([self.W_O.weight.data, torch.zeros(d_mnew, d_mnew-self.d_model,  device=DEVICE)], dim=1)
        self.layernorm1 = torch.nn.LayerNorm(d_mnew)
        self.layernorm2 = torch.nn.LayerNorm(d_mnew)
        for i in range(self.d_model+1, d_mnew):
            self.W_O.weight.data[i][i] = 1

        for head in self.heads:
            head.expand(d_mnew, d_inew)

        self.Softmax = torch.nn.LogSoftmax(dim=-1)
        self.d_model = d_mnew
        self.d_internal = d_inew
        self.to(DEVICE)




In [6]:
class Decoder(nn.Module):
    def __init__(self, num_blocks, d_model, d_hidden, vocab_size, num_heads, final_dmodel):
        super().__init__()
        self.num_blocks = num_blocks
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.num_heads = num_heads
        self.SoftMax = torch.nn.LogSoftmax(dim=-1)
        self.blocks = torch.nn.ModuleList([Transformer(d_model, num_heads, d_hidden) for _ in range(num_blocks)])
        self.d_hidden = d_hidden

        self.connection = torch.nn.Linear(d_model, d_hidden)
        self.FFN = torch.nn.Sequential(
            torch.nn.Dropout(0.1),
            torch.nn.ReLU(),
            torch.nn.Linear(d_hidden, d_hidden), 
            torch.nn.Dropout(0.1),
            torch.nn.ReLU(),
            torch.nn.Linear(d_hidden, vocab_size),
            torch.nn.LogSoftmax(dim=-1),
        )
        self.dout = torch.nn.Dropout(0.1)
        
        self.final_dmodel = final_dmodel
        self.embeddings = torch.nn.Embedding(vocab_size, d_model, device=DEVICE)
        self.pos_embedding = None
        # self.pos_embedding = torch.nn.Embedding(seq_len, d_model, device=DEVICE)
        self.generate_pos_embed(d_model)
        
        self.layernorm = torch.nn.LayerNorm(d_model, device=DEVICE)

    def forward(self, x):
        x = self.embeddings(x) + self.pos_embedding(torch.arange(x.shape[-1], device=DEVICE))
        x = self.dout(x)
        t = x
        for head in self.blocks:
            t = head(t) + t 

        t = self.layernorm(t)

        t = self.connection(t)
        ret = self.FFN(t)

        return ret

    def generate_pos_embed(self, d_model):
        # TODO: make more efficient 
        pos_em = torch.zeros((seq_len, d_model))
        for pos in range(seq_len):
            for i in range(d_model):
                if i % 2 == 0:
                    pos_em[pos][i] += torch.sin(torch.tensor(pos/(10000**(2*i/d_model))))
                else:
                    pos_em[pos][i] += torch.cos(torch.tensor(pos/(10000** (2*i/d_model))))

        self.pos_embedding = torch.nn.Embedding.from_pretrained(pos_em, freeze=True)




    def expand(self, d_mnew):
        d_inew = d_mnew // self.num_heads
        self.connection = torch.nn.Linear(d_mnew, self.d_hidden, device=DEVICE)
        self.layernorm = torch.nn.LayerNorm(d_mnew, device=DEVICE)
        for block in self.blocks:
            block.expand(d_mnew, d_inew)

        self.embeddings = torch.nn.Embedding.from_pretrained(torch.cat([self.embeddings.weight, torch.zeros(self.vocab_size, d_mnew-self.d_model, device=DEVICE).uniform_()], dim=1))
        # self.pos_embedding = torch.nn.Embedding.from_pretrained(torch.cat([self.pos_embedding.weight, torch.zeros(seq_len, d_mnew-self.d_model, device=DEVICE).uniform_()], dim=1))
        self.generate_pos_embed(d_mnew)
        # self.embeddings = torch.nn.Embedding(self.vocab_size, d_mnew)
        # self.pos_embedding = torch.nn.Embedding(seq_len, d_mnew)

        self.d_model = d_mnew
        self.d_internal = d_inew
        self.to(DEVICE)

In [7]:
data = load_dataset('Salesforce/wikitext', 'wikitext-103-v1')

In [14]:
data.column_names

{'test': ['text'], 'train': ['text'], 'validation': ['text']}

In [15]:
train = data['train']
test = data['test']
valid = data['validation']

In [20]:
tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-uncased')

In [152]:
def train(model, lr=1e-3, min_lr=1e-6):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iters, min_lr)
    for iter in range(max_iters):

        # every once in a while evaluate the loss on train and val sets
        if iter % eval_interval == 0 or iter == max_iters - 1:
            losses = estimate_loss()
            print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, lr {optimizer.param_groups[0]['lr']}")

        # sample a batch of data
        xb, yb = batch('train')

        # evaluate the loss
        logits = model(xb)
        B, T, C = logits.shape
        logits = logits.view(B*T, C)
        targets = yb.view(B*T)
        loss = torch.nn.functional.cross_entropy(logits, targets)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        scheduler.step()


In [24]:
model = Decoder(num_blocks=4, d_model=128, vocab_size=len(chars), num_heads=4, d_hidden=64*4, final_dmodel=1024)
model.to(DEVICE)
print(sum(p.numel() for p in model.parameters())/1e6, "M parameters")

3.345857 M parameters


In [25]:
train(model)

step 0: train loss 4.1688, val loss 4.1690
step 100: train loss 2.5251, val loss 2.5221
step 200: train loss 2.4296, val loss 2.4312
step 300: train loss 2.3301, val loss 2.3314
step 400: train loss 2.1980, val loss 2.2196
step 500: train loss 2.0995, val loss 2.1247
step 600: train loss 2.0298, val loss 2.0811
step 700: train loss 1.9629, val loss 2.0257
step 800: train loss 1.8915, val loss 1.9733
step 900: train loss 1.8468, val loss 1.9321
step 1000: train loss 1.8003, val loss 1.8927
step 1100: train loss 1.7630, val loss 1.8664
step 1200: train loss 1.7296, val loss 1.8358
step 1300: train loss 1.6972, val loss 1.8303
step 1400: train loss 1.6773, val loss 1.8195
step 1500: train loss 1.6501, val loss 1.7725
step 1600: train loss 1.6232, val loss 1.7389
step 1700: train loss 1.6052, val loss 1.7483
step 1800: train loss 1.6014, val loss 1.7330
step 1900: train loss 1.5839, val loss 1.7124
step 2000: train loss 1.5635, val loss 1.6877
step 2100: train loss 1.5527, val loss 1.6970


In [83]:
from graphviz import Digraph
import torch
from torch.autograd import Variable


# make_dot was moved to https://github.com/szagoruyko/pytorchviz
from torchviz import make_dot

In [148]:
def train_transfer(model, transfer_step=900, target_size=1024, lr=1e-3, min_lr=1e-6):
    loss_func = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iters, min_lr)
    for iter in range(1, max_iters):

        # every once in a while evaluate the loss on train and val sets
        if iter % eval_interval == 0 or iter == max_iters - 1 or iter == 1:
            losses = estimate_loss()
            print(f"step {iter}:\t train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, lr {optimizer.param_groups[0]['lr']}")

        if iter == transfer_step:
        # if iter <= 1000 and iter % 500 == 0:
            model.expand(target_size)
            optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
            # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 100, 0.5)
            print('at step {}: expanded model to: {} M parameters'.format(iter, sum(p.numel() for p in model.parameters())/1e6))
            model.to('cpu')
            model.to(DEVICE)    # Shortcut to recompile gradient backprop since the model changed sizes
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iters-transfer_step, min_lr)
            loss_func = torch.nn.CrossEntropyLoss()
        # sample a batch of data
        xb, yb = batch('train')

        # evaluate the loss
        logits = model(xb)
        B, T, C = logits.shape
        logits = logits.view(B*T, C)
        targets = yb.view(B*T)
        loss = loss_func(logits, targets)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        scheduler.step()

In [28]:
model = Decoder(num_blocks=4, d_model=64, vocab_size=len(chars), num_heads=4, d_hidden=64*4, final_dmodel=1024)
model.to(DEVICE)
print(sum(p.numel() for p in model.parameters())/1e6, "M parameters")

0.535041 M parameters


In [29]:
train_transfer(model)

step 1: train loss 4.1782, val loss 4.1792
step 100: train loss 2.5954, val loss 2.5947
step 200: train loss 2.4439, val loss 2.4388
step 300: train loss 2.3527, val loss 2.3512
step 400: train loss 2.3144, val loss 2.3242
step 500: train loss 2.2663, val loss 2.2731
step 600: train loss 2.2449, val loss 2.2586
step 700: train loss 2.2024, val loss 2.2258
step 800: train loss 2.1698, val loss 2.1884
step 900: train loss 2.1484, val loss 2.1620
at step 900: expanded model to: 0.903617 M parameters
step 1000: train loss 2.3213, val loss 2.3277
step 1100: train loss 2.2354, val loss 2.2447
step 1200: train loss 2.1777, val loss 2.1962
step 1300: train loss 2.1376, val loss 2.1698
step 1400: train loss 2.1093, val loss 2.1410
step 1500: train loss 2.0825, val loss 2.1281
step 1600: train loss 2.0704, val loss 2.1064
step 1700: train loss 2.0356, val loss 2.0877
step 1800: train loss 2.0014, val loss 2.0619
step 1900: train loss 1.9868, val loss 2.0340
step 2000: train loss 1.9596, val loss

In [26]:
def train_transfer_gradual(model, transfer_step=600, final_size=128, start_size=64, final_bus_step=1200,  lr=1e-3):
    loss_func = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1000, 0.5)
    step = final_bus_step // transfer_step
    step_size = (final_size-start_size)//step
    for iter in range(1, max_iters):

        # every once in a while evaluate the loss on train and val sets
        if iter % eval_interval == 0 or iter == max_iters - 1 or iter == 1:
            losses = estimate_loss()
            print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

        # if iter <= 1000 and iter % 500 == 0:
        if iter % transfer_step == 0 and iter <= final_bus_step:
            start_size += step_size
            model.expand(start_size)
            optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
            # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 100, 0.5)
            print('at step {}: expanded model to: {} M parameters\tmodel_size: {}'.format(iter, sum(p.numel() for p in model.parameters())/1e6, start_size))
            model.to('cpu')
            model.to(DEVICE)    # Shortcut to recompile gradient backprop since the model changed sizes

            loss_func = torch.nn.CrossEntropyLoss()
        # sample a batch of data
        xb, yb = batch('train')

        # evaluate the loss
        logits = model(xb)
        B, T, C = logits.shape
        logits = logits.view(B*T, C)
        targets = yb.view(B*T)
        loss = loss_func(logits, targets)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        # scheduler.step()

In [67]:
model = Decoder(num_blocks=4, d_model=128, vocab_size=len(chars), num_heads=4, d_hidden=128, final_dmodel=1024)
model.to(DEVICE)
print(sum(p.numel() for p in model.parameters())/1e6, "M parameters")
train_transfer_gradual(model, transfer_step=300, final_bus_step=1500, start_size=128, final_size=512)

0.685761 M parameters
step 1: train loss 4.1678, val loss 4.1682
step 100: train loss 2.5144, val loss 2.5178
step 200: train loss 2.4015, val loss 2.3948
step 300: train loss 2.2808, val loss 2.2922
at step 300: expanded model to: 1.193365 M parameters	model_size: 204
step 400: train loss 2.3386, val loss 2.3418
step 500: train loss 2.2482, val loss 2.2624
step 600: train loss 2.1824, val loss 2.2040
at step 600: expanded model to: 1.885801 M parameters	model_size: 280
step 700: train loss 2.3070, val loss 2.3106
step 800: train loss 2.2090, val loss 2.2302
step 900: train loss 2.1307, val loss 2.1514
at step 900: expanded model to: 2.763069 M parameters	model_size: 356
step 1000: train loss 2.2723, val loss 2.2772
step 1100: train loss 2.1761, val loss 2.1989
step 1200: train loss 2.0827, val loss 2.1294
at step 1200: expanded model to: 3.825169 M parameters	model_size: 432
step 1300: train loss 2.1690, val loss 2.1862
step 1400: train loss 2.0768, val loss 2.1126
step 1500: train lo

# WORK BENCH

In [155]:
model = Decoder(num_blocks=6, d_model=360, vocab_size=len(chars), num_heads=8, d_hidden=1024, final_dmodel=1024)
model.to(DEVICE)
print(sum(p.numel() for p in model.parameters())/1e6, "M parameters")
train_transfer(model, transfer_step=800, target_size=512, lr=1e-3)

15.404713 M parameters
step 1:	 train loss 4.1814, val loss 4.1809, lr 0.001
step 100:	 train loss 2.4459, val loss 2.4551, lr 0.000999033958943103
step 200:	 train loss 2.3195, val loss 2.3500, lr 0.0009961005307061032
step 300:	 train loss 2.2446, val loss 2.2977, lr 0.0009912111935927526
step 400:	 train loss 2.1711, val loss 2.2149, lr 0.0009843852435829092
step 500:	 train loss 2.1285, val loss 2.1885, lr 0.0009756496195827828
step 600:	 train loss 2.0560, val loss 2.1359, lr 0.0009650387971093778
step 700:	 train loss 2.0061, val loss 2.0846, lr 0.0009525946522313713
step 800:	 train loss 1.9437, val loss 2.0346, lr 0.0009383662963034076
at step 800: expanded model to: 20.643393 M parameters
step 900:	 train loss 2.2716, val loss 2.3072, lr 0.0009986032966919998
step 1000:	 train loss 2.1260, val loss 2.1910, lr 0.0009944209976994521
step 1100:	 train loss 2.0491, val loss 2.1144, lr 0.0009874764921348207
step 1200:	 train loss 1.9814, val loss 2.0648, lr 0.0009778086164901761
st

In [154]:

model = Decoder(num_blocks=6, d_model=512, vocab_size=len(chars), num_heads=8, d_hidden=1024, final_dmodel=1024)
model.to(DEVICE)
print(sum(p.numel() for p in model.parameters())/1e6, "M parameters")
train(model, lr=1e-4, min_lr=1e-5)

20.643393 M parameters
step 0: train loss 4.1703, val loss 4.1703, lr 0.0001
step 100: train loss 2.5414, val loss 2.5442, lr 9.991120277927216e-05
step 200: train loss 2.4371, val loss 2.4484, lr 9.964516155915146e-05
step 300: train loss 2.2989, val loss 2.3068, lr 9.920292628279099e-05
step 400: train loss 2.2434, val loss 2.2575, lr 9.858624225078842e-05
step 500: train loss 2.1763, val loss 2.1936, lr 9.779754323328188e-05
step 600: train loss 2.1150, val loss 2.1282, lr 9.683994186497123e-05
step 700: train loss 2.0623, val loss 2.0977, lr 9.571721736097081e-05
step 800: train loss 2.0129, val loss 2.0532, lr 9.44338006019738e-05
step 900: train loss 1.9783, val loss 2.0294, lr 9.299475664759062e-05
step 1000: train loss 1.9239, val loss 1.9899, lr 9.140576474687264e-05
step 1100: train loss 1.8907, val loss 1.9573, lr 8.967309592491052e-05
step 1200: train loss 1.8533, val loss 1.9423, lr 8.78035882339636e-05
step 1300: train loss 1.8167, val loss 1.9096, lr 8.580461976679112e-0

In [127]:
model

Decoder(
  (SoftMax): LogSoftmax(dim=-1)
  (blocks): ModuleList(
    (0-3): 4 x Transformer(
      (heads): ModuleList(
        (0-7): 8 x AttentionHead(
          (W_Q): Linear(in_features=512, out_features=64, bias=False)
          (W_K): Linear(in_features=512, out_features=64, bias=False)
          (W_V): Linear(in_features=512, out_features=64, bias=False)
          (SoftMax): Softmax(dim=-1)
        )
      )
      (Softmax): LogSoftmax(dim=-1)
      (FFN): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.1, inplace=False)
        (2): Linear(in_features=1024, out_features=1024, bias=True)
        (3): ReLU()
        (4): Dropout(p=0.1, inplace=False)
        (5): Linear(in_features=1024, out_features=1024, bias=True)
        (6): ReLU()
        (7): Dropout(p=0.1, inplace=False)
      )
      (W_O): Linear(in_features=512, out_features=512, bias=False)
      (layernorm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (layernorm2): LayerNorm((512,), eps=1e-0