In [1]:
import time
import torch
import torch.nn as nn
import numpy as np
import random
from torch import optim
import matplotlib.pyplot as plt
from typing import List
from utils import *
from torch.utils.data import Dataset, DataLoader, RandomSampler
import tqdm
from scipy.stats import ttest_ind
# from bus_transformer import *
from datasets import load_dataset
import datasets
from transformers import AutoTokenizer

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
seq_len = 32
batch_size = 32 
print(DEVICE)

cuda


- limit sequences to 128
- limit tasks to sentence classification
- use single sequence training without NSP
- 


In [91]:
class AttentionHead(nn.Module):
    def __init__(self, d_model, d_internal):
        super().__init__()

        self.W_Q = torch.nn.Linear(d_model, d_internal, False)
        self.W_K = torch.nn.Linear(d_model, d_internal, False)
        self.W_V = torch.nn.Linear(d_model, d_internal, False)

        self.SoftMax = torch.nn.Softmax(dim=-1)


        self.d_model = d_model
        self.d_internal = d_internal
        self.norm = torch.tensor(d_model**-0.5)
        self.tril = torch.tril(torch.ones(seq_len, seq_len, device=DEVICE))

    def expand(self, d_mnew, d_inew):

        with torch.no_grad():
            self.W_Q.weight.data = torch.nn.Parameter(torch.cat([self.W_Q.weight.data, torch.zeros(d_inew - self.d_internal, self.d_model, device=DEVICE)], dim=0))
            self.W_Q.weight.data = torch.nn.Parameter(torch.cat([self.W_Q.weight.data, torch.zeros(d_inew, d_mnew - self.d_model, device=DEVICE)], dim=1))
            for i in range(self.d_internal, d_inew):
                self.W_Q.weight.data[i][i] = self.W_Q.weight.data[i][i] if self.W_Q.weight.data[i][i] != 0 else 1

            self.W_K.weight.data = torch.nn.Parameter(torch.cat([self.W_K.weight.data, torch.zeros(d_inew - self.d_internal, self.d_model, device=DEVICE)], dim=0))
            self.W_K.weight.data = torch.nn.Parameter(torch.cat([self.W_K.weight.data, torch.zeros(d_inew, d_mnew - self.d_model, device=DEVICE)], dim=1))
            for i in range(self.d_internal, d_inew):
                self.W_K.weight.data[i][i] = self.W_K.weight.data[i][i] if self.W_K.weight.data[i][i] != 0 else 1

            self.W_V.weight.data = torch.nn.Parameter(torch.cat([self.W_V.weight.data, torch.zeros(d_inew - self.d_internal, self.d_model, device=DEVICE)], dim=0))
            self.W_V.weight.data = torch.nn.Parameter(torch.cat([self.W_V.weight.data, torch.zeros(d_inew, d_mnew - self.d_model, device=DEVICE)], dim=1))
            for i in range(self.d_internal, d_inew):
                self.W_V.weight.data[i][i] = self.W_V.weight.data[i][i] if self.W_V.weight.data[i][i] != 0 else 1

        self.d_internal = d_inew
        self.d_model = d_mnew 
        self.SoftMax = torch.nn.Softmax(dim=-1)
        self.norm = torch.tensor(d_mnew**-0.5)



    def forward(self, input_vecs):
        # print(input_vecs.shape)
        # print(self.d_model, self.d_internal)
        # print(self.W_Q.weight.data.shape)
        Q = self.W_Q(input_vecs)
        K = self.W_K(input_vecs)
        V = self.W_V(input_vecs)

        weights = Q @ K.transpose(-2, -1) * self.norm
        weights = weights.masked_fill(self.tril == 0, float('-inf'))
        Attn = self.SoftMax(weights)

        out = Attn @ V

        return out

In [83]:
class Transformer(nn.Module):
    def __init__(self, d_model, vocab_size, num_heads):
        super().__init__()
        self.d_model = d_model
        self.d_internal = d_model//num_heads
        self.num_heads = num_heads
        self.vocab_size = vocab_size 

        self.heads = nn.ModuleList([AttentionHead(d_model, self.d_internal) for _ in range(num_heads)])
        self.Softmax = torch.nn.LogSoftmax(dim=-1)
        self.FFN = torch.nn.Sequential(
            torch.nn.Linear(d_model, d_model),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(d_model, d_model),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(d_model, d_model)
        )
        self.W_O = torch.nn.Linear(d_model, d_model, False)
        self.layernorm1 = torch.nn.LayerNorm(d_model, device=DEVICE)
        self.layernorm2 = torch.nn.LayerNorm(d_model, device=DEVICE)
        # self.embed = torch.nn.Embedding(vocab_size, d_model).to(DEVICE)



    def forward(self, x):
        """
        :param x: input embeddings 
        :return: output of decoder block, same shape as input
        """
        t = torch.cat([head(x) for head in self.heads], dim=-1)
        t = self.W_O(t)
        t = self.layernorm1(t + x)
        t = self.FFN(t) 
        t = self.layernorm2(t + x)

        return t



    def expand(self, d_mnew, d_inew):

        # TODO: / room for future experiments, how can we expand this ffn to not erase it everytime we expand?
        self.FFN = torch.nn.Sequential(
            torch.nn.Linear(d_mnew, d_mnew),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(d_mnew, d_mnew),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(d_mnew, d_mnew)
        )
        with torch.no_grad():

            self.W_O.weight.data = torch.cat([self.W_O.weight.data, torch.zeros(d_mnew-self.d_model, self.d_model, device=DEVICE)], dim=0)
            self.W_O.weight.data = torch.cat([self.W_O.weight.data, torch.zeros(d_mnew, d_mnew-self.d_model, device=DEVICE)], dim=1)
            self.layernorm1 = torch.nn.LayerNorm(d_mnew, device=DEVICE)
            self.layernorm2 = torch.nn.LayerNorm(d_mnew, device=DEVICE)
            for i in range(self.d_model+1, d_mnew):
                self.W_O.weight.data[i][i] = 1

        for head in self.heads:
            head.expand(d_mnew, d_inew)

        self.Softmax = torch.nn.LogSoftmax(dim=-1)
        self.d_model = d_mnew
        self.d_internal = d_inew
        self.to(DEVICE)




In [88]:
class Decoder(nn.Module):
    def __init__(self, num_blocks, d_model, d_hidden, vocab_size, num_heads, final_dmodel):
        super().__init__()
        self.num_blocks = num_blocks
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.num_heads = num_heads
        self.SoftMax = torch.nn.LogSoftmax(dim=-1)
        self.blocks = torch.nn.ModuleList([Transformer(d_model, vocab_size, num_heads) for _ in range(num_blocks)])
        self.d_hidden = d_hidden

        self.connection = torch.nn.Linear(d_model, d_hidden)
        self.FFN = torch.nn.Sequential(
            torch.nn.Dropout(0.1),
            torch.nn.ReLU(),
            torch.nn.Linear(d_hidden, 4*d_hidden), 
            torch.nn.Dropout(0.1),
            torch.nn.ReLU(),
            torch.nn.Linear(4*d_hidden, vocab_size),
            torch.nn.LogSoftmax(dim=-1),
        )
        
        self.dropout = torch.nn.Dropout(0.1)
        self.final_dmodel = final_dmodel
        self.embeddings = torch.nn.Embedding(vocab_size, d_model, device=DEVICE)
        self.pos_embedding = torch.nn.Embedding(seq_len, d_model, device=DEVICE)

        # self.rotation_matrix = torch.zeros(d_model, d_model, device=DEVICE, dtype=torch.double)
        # self.pos_embedding = torch.zeros(seq_len, d_model, device=DEVICE, dtype=torch.double)

        # for i in range(d_model):
        #     for j in range(d_model):
        #         z = torch.cos(torch.tensor(i*j*0.01))
        #         self.rotation_matrix[i,j] = z
        #         if i < seq_len: self.pos_embedding[i,j] = z
        
        self.layernorm = torch.nn.LayerNorm(d_model, device=DEVICE)

    def forward(self, x):
        # t = self.pos_embedding + self.embeddings(x) @ self.rotation_matrix
        t = self.embeddings(x) + self.pos_embedding(torch.arange(x.shape[-1], device=DEVICE))
        t = self.dropout(t)
        for head in self.blocks:
            t = head(t) + t 

        t = self.layernorm(t)

        t = self.connection(t)
        ret = self.FFN(t)

        return ret


    def expand(self, d_mnew):
        d_inew = d_mnew // self.num_heads
        self.connection = torch.nn.Linear(d_mnew, self.d_hidden, device=DEVICE)
        self.layernorm = torch.nn.LayerNorm(d_mnew, device=DEVICE)
        for block in self.blocks:
            block.expand(d_mnew, d_inew)

        with torch.no_grad():
            self.embeddings = torch.nn.Embedding.from_pretrained(torch.cat([self.embeddings.weight, torch.zeros(self.vocab_size, d_mnew-self.d_model, device=DEVICE)], dim=1))
            self.pos_embedding = torch.nn.Embedding.from_pretrained(torch.cat([self.pos_embedding.weight, torch.zeros(seq_len, d_mnew-self.d_model, device=DEVICE)], dim=1))
            # self.embeddings.weight = torch.nn.Parameter(torch.cat([self.embeddings.weight, torch.zeros(self.vocab_size, d_mnew-self.d_model, device=DEVICE)], dim=1))
            # self.pos_embedding.weight = torch.nn.Parameter(torch.cat([self.pos_embedding.weight, torch.zeros(seq_len, d_mnew-self.d_model, device=DEVICE)], dim=1)) 

        self.d_model = d_mnew
        self.d_internal = d_inew
        self.to(DEVICE)

In [5]:
# data = load_dataset('Salesforce/wikitext', 'wikitext-103-v1')
data = load_dataset('tiny_shakespeare')
train = data['train']
validation = data['validation']
test = data['test']

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [6]:
train.column_names

['text']

In [7]:
chars = sorted(set(next(iter(train['text']))))
len(chars)


65

In [8]:
chars

['\n',
 ' ',
 '!',
 '$',
 '&',
 "'",
 ',',
 '-',
 '.',
 '3',
 ':',
 ';',
 '?',
 'A',
 'B',
 'C',
 'D',
 'E',
 'F',
 'G',
 'H',
 'I',
 'J',
 'K',
 'L',
 'M',
 'N',
 'O',
 'P',
 'Q',
 'R',
 'S',
 'T',
 'U',
 'V',
 'W',
 'X',
 'Y',
 'Z',
 'a',
 'b',
 'c',
 'd',
 'e',
 'f',
 'g',
 'h',
 'i',
 'j',
 'k',
 'l',
 'm',
 'n',
 'o',
 'p',
 'q',
 'r',
 's',
 't',
 'u',
 'v',
 'w',
 'x',
 'y',
 'z']

In [9]:
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

In [10]:
train = data['train']['text'][0]
validation = data['validation']['text'][0]
test = data['test']['text'][0]

In [11]:
test



In [12]:
train_data = encode(train)
train_data[:10]

[18, 47, 56, 57, 58, 1, 15, 47, 58, 47]

In [13]:
val_data = encode(validation)
test_data = encode(test)

In [14]:
model = Decoder(num_blocks=4, d_model=64, vocab_size=len(chars), num_heads=4, d_hidden=256, final_dmodel=1024)
model.to(DEVICE)

Decoder(
  (SoftMax): LogSoftmax(dim=-1)
  (blocks): ModuleList(
    (0-3): 4 x Transformer(
      (heads): ModuleList(
        (0-3): 4 x AttentionHead(
          (W_Q): Linear(in_features=64, out_features=16, bias=False)
          (W_K): Linear(in_features=64, out_features=16, bias=False)
          (W_V): Linear(in_features=64, out_features=16, bias=False)
          (SoftMax): Softmax(dim=-1)
        )
      )
      (Softmax): LogSoftmax(dim=-1)
      (FFN): Sequential(
        (0): Linear(in_features=64, out_features=64, bias=True)
        (1): ReLU()
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=64, out_features=64, bias=True)
        (4): ReLU()
        (5): Dropout(p=0.1, inplace=False)
        (6): Linear(in_features=64, out_features=64, bias=True)
      )
      (W_O): Linear(in_features=64, out_features=64, bias=False)
      (layernorm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (layernorm2): LayerNorm((64,), eps=1e-05, elementwise

In [15]:
print(sum(p.numel() for p in model.parameters())/1e6, "M parameters")

0.469249 M parameters


In [16]:
def batch(s):
    data = train_data if s == 'train' else val_data
    ix = torch.randint(len(data) - seq_len, (batch_size,))
    x = torch.stack([torch.tensor(data[i:i+seq_len], device=DEVICE) for i in ix])
    y = torch.stack([torch.tensor(data[i+1:i+seq_len+1], device=DEVICE) for i in ix])
    return x, y

In [17]:
xb, yb = batch('train')
xb

tensor([[64, 43, 52,  ..., 45, 43,  6],
        [42, 47, 60,  ..., 58, 39, 40],
        [57, 46, 53,  ..., 58, 46, 43],
        ...,
        [53, 51, 57,  ..., 41, 58, 47],
        [39, 42, 51,  ..., 24, 24, 13],
        [ 0, 13, 52,  ..., 47, 57,  1]], device='cuda:0')

In [18]:
yb

tensor([[43, 52, 10,  ..., 43,  6,  1],
        [47, 60, 53,  ..., 39, 40, 50],
        [46, 53, 59,  ..., 46, 43, 56],
        ...,
        [51, 57,  6,  ..., 58, 47, 53],
        [42, 51, 47,  ..., 24, 13, 10],
        [13, 52, 42,  ..., 57,  1, 39]], device='cuda:0')

In [19]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
max_iters = 5000
eval_interval = 100
eval_iters = 200

In [20]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = batch(split)
            logits= model(X)
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = Y.view(B*T)
            loss = torch.nn.functional.cross_entropy(logits, targets)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out


In [21]:
def train(model, optimizer):
    for iter in range(max_iters):

        # every once in a while evaluate the loss on train and val sets
        if iter % eval_interval == 0 or iter == max_iters - 1:
            losses = estimate_loss()
            print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

        # sample a batch of data
        xb, yb = batch('train')

        # evaluate the loss
        logits = model(xb)
        B, T, C = logits.shape
        logits = logits.view(B*T, C)
        targets = yb.view(B*T)
        loss = torch.nn.functional.cross_entropy(logits, targets)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()


In [None]:
train(model, optimizer)

In [22]:
def generate(model, idx, max_new_tokens):
    for _ in range(max_new_tokens):
        idx_cond = idx[:,-seq_len:]
        logits = model(idx_cond)
        logits= logits[:,-1,:]
        probs = torch.nn.functional.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        idx = torch.cat((idx, idx_next), dim=1)
    return idx

In [None]:
# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)
print(decode(generate(model, context, max_new_tokens=2000)[0].tolist()))

In [78]:
def train(model, optimizer):
    loss_func = torch.nn.CrossEntropyLoss()
    for iter in range(max_iters):

        # every once in a while evaluate the loss on train and val sets
        if iter % eval_interval == 0 or iter == max_iters - 1:
            losses = estimate_loss()
            print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

        if iter == 100:
            print('expand')
            model.expand(64)
            print('expanded model to: {} M parameters'.format(sum(p.numel() for p in model.parameters())/1e6))
            optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
            loss_func = torch.nn.CrossEntropyLoss()
        # sample a batch of data
        xb, yb = batch('train')

        # evaluate the loss
        optimizer.zero_grad(set_to_none=True)
        logits = model(xb)
        B, T, C = logits.shape
        logits = logits.view(B*T, C)
        targets = yb.view(B*T)
        loss = loss_func(logits, targets)
        loss.backward()
        optimizer.step()

In [89]:
model = Decoder(num_blocks=4, d_model=32, vocab_size=len(chars), num_heads=4, d_hidden=128, final_dmodel=1024)
model.to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
print(sum(p.numel() for p in model.parameters())/1e6, "M parameters")

0.136353 M parameters


In [90]:
train(model, optimizer)

step 0: train loss 4.1889, val loss 4.1888
step 100: train loss 2.7145, val loss 2.7253
expand
expanded model to: 0.230529 M parameters


RuntimeError: Function TBackward0 returned an invalid gradient at index 0 - got [64, 64] but expected shape compatible with [32, 32]