# TODO
- train model
- DONT FORGET TO SAVE MODEL
- fix bug where inference requires input to be at least as long as the maximum context length

In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import time
from typing import List
import math
import numpy as np
import matplotlib.pyplot as plt

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
#device = 'cpu'

# MatryoshkaGPT

the idea here is to have a bunch of tiny models inside the main model like russian nesting dolls. it's based on [this paper](https://arxiv.org/abs/2205.13147) which only created nesting doll embeddings, and then they later expanded the concept to also incorporate the feedforward network in [this paper](https://arxiv.org/pdf/2310.07707.pdf). however their implementation was lame because they didn't bother doing it also with the multi-head attention mechanism, which is what we'll be doing today

to give you a better idea of what i mean by "nesting dolls" take a look at this graphic. for a given embedding vector $z\in \mathbb{R}^d$, we can subset into smaller vectors. In this case we've chosen to cut it in half each time, but really you could do this with any sized subsets. By "subsets" i mean we're literally just splicing. Then as you'll see later, we simultaneously train the model at all of these sizes at once, giving us an embedding representation that's self-similar at each level

<p align="center">
<img src="./images/Screenshot from 2024-02-12 19-27-42.png" width="512"/>
</p>

In [3]:
# hyperparameters
b = 16 # how many independent sequences will we process in parallel?
t = 64 # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 100
lr = 3e-4 # learning rate for each backprop step
eval_iters = 20
h = 4 # number of attention heads
l = 8 # number of transormer layers
dropout = 0.1 # % of parameters to ignore every iteration
l2 = 0.01 # multiplier for our L2 norm to encourage sparsity

# embedding aka hidden dimension. this is the largest that the model will have
d = 128 # make sure it is a power of 2
power_of_d = int(math.log2(d))

# the smallest power of 2 we'll be considering as a matryoshka embedding
min_power = 5 # Starting from 2^min_power
nesting_list = [2**i for i in range(min_power, int(power_of_d) + 1)]
print("embedding sizes: ", nesting_list)
print("number of nesting doll models: ", len(nesting_list), " (I will frequently refer to this number as 'g')")

embedding sizes:  [32, 64, 128]
number of nesting doll models:  3  (I will frequently refer to this number as 'g')


In [4]:
# the dataset we'll be using is just TinyShakespeare
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
print(text[:200])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you


In [5]:
# here are all the unique characters that occur in this text. we'll do character-wise tokenization
chars = sorted(list(set(text)))
v = len(chars)
print(chars, v)

['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'] 65


In [6]:
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

In [7]:
# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest validation
train_data = data[:n]
val_data = data[n:]

In [8]:
# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - t, (b,))
    x = torch.stack([data[i:i+t] for i in ix])
    y = torch.stack([data[i+1:i+t+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

In [9]:
@torch.no_grad()
def estimate_loss(): # to use later during the training loop
    out = {}
    model.eval() # sets model to eval mode
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train() # just resets to training mode
    return out

# FEEDFORWARD

this is the part that was done in [MATFORMER](https://arxiv.org/pdf/2310.07707.pdf), however they were lame and ended it here

basically we're subsetting the feedforward matrices such that they match with the size of each model

<p align="center">
<img src="./images/drawings/ffwd.png" width="512"/>
</p>

In [10]:
class matryoshkaFeedFoward(nn.Module):
    def __init__(self, nesting_list: List, dropout):
        super().__init__()

        # our list of different potential embedding sizes
        self.nesting_list = nesting_list
        
        # the embedding dimension of the largest model
        self.d = nesting_list[-1]

        # Initialize only the largest weights and biases
        # this is more efficient than using regualr nn.Linear() because we need to splice them frequently
        self.w1 = nn.Parameter(torch.Tensor(self.d, 4 * self.d))
        self.b1 = nn.Parameter(torch.Tensor(4 * self.d))
        self.w2 = nn.Parameter(torch.Tensor(4 * self.d, self.d))
        self.b2 = nn.Parameter(torch.Tensor(self.d))

        # Initializing parameters
        nn.init.normal_(self.w1, std=0.02)  
        nn.init.normal_(self.b1, std=0.02)
        nn.init.normal_(self.w2, std=0.02)
        nn.init.normal_(self.b2, std=0.02)
        
        # the other parts
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(dropout)
        # notice how in forwardTuple(), the dropout mechanism will actually randomly drop out different weights for
        # each model size during a given forward pass. My intuition says this will actually be good for generalizability
        # but I suppose the opposite could be true. It'd be interesting to create a custom dropout method that ensures 
        # consistency in what parameters are dropped out across model sizes, but that's really not worth the effort
    
    def forwardTuple(self, x):
        """
        input: tuple of length g with tensors of shape (b,t,d_i) for d_i in nesting_list
        operation: 2 linear layers with a 4-times larger hidden depth, a relu nonlinearity in between, and then a dropout
        output: tuple of length g with tensors of shape (b,t,d_i) for d_i in nesting_list
        """
        out = ()
        for i, d_i in enumerate(self.nesting_list):
            out += (self.drop(self.relu(x[i] @ self.w1[:d_i,:4*d_i] + self.b1[:4*d_i]) @ self.w2[:4*d_i,:d_i] + self.b2[:d_i]),)
        return out

    def forwardTensor(self, x):
        """
        input: tensor of shape (b,t,d_i)
        operation: 2 linear layers with a 4-times depth, a relu nonlinearity in between, and then a dropout
        output: tensor of shape (b,t,d_i)
        """
        d_i = x.shape[-1]
        return self.relu(x @ self.w1[:d_i, :4*d_i] + self.b1[:4*d_i]) @ self.w2[:4*d_i, :d_i] + self.b2[:d_i]
    
    def forward(self, x):
        # forwardTuple() is for training and forwardTensor() is for inference
        # that will remain true for the rest of the code as well
        return self.forwardTuple(x) if type(x) == tuple else self.forwardTensor(x)

# ATTENTION

Now this is where the annoying part began. To subset the attention heads, we have to not only splice according to the model's embedding dimension but also take into account new smaller head sizes. sorry i drew the output so small but it's too late now. I'm assuming you know how self-attention works well enough to look at this and get the idea

<p align="center">
<img src="./images/drawings/head.png" width="512"/>
</p>

In [11]:
class matryoshkaHead(nn.Module):
    def __init__(self, nesting_list: List, head_sizes: List):
        super().__init__()

        # to be used for iterating in forward()
        self.nesting_list = nesting_list
        self.head_sizes = head_sizes
        
        # the largest embedding dimension of the model
        self.d = nesting_list[-1]
        # the largest head size
        self.h = head_sizes[-1]

        # initialize only the largest. we'll subset later during forward()
        self.key = nn.Parameter(torch.Tensor(self.d, self.h)).to(device)
        self.query = nn.Parameter(torch.Tensor(self.d, self.h)).to(device)
        self.value = nn.Parameter(torch.Tensor(self.d, self.h)).to(device)
        
        # Initializing parameters
        nn.init.normal_(self.key, std=0.02)  
        nn.init.normal_(self.query, std=0.02)
        nn.init.normal_(self.value, std=0.02)

        # the mask so they only look into the past
        self.register_buffer('tril', torch.tril(torch.ones(t, t)))

    def forwardTuple(self, x):
        """
        input: tuple length g with tensors of shape (b,t,d_i) for d_i in nesting_list
        operation: masked self-attention
        output: tuple length g with tensors of shape (b,t,h_i) for h_i in head_sizes where h_i = d_i / h
        """
        k,q,v,wei,out = (),(),(),[],() # wei is a list so i can edit it in-place
        for i, (d_i, h_i) in enumerate(zip(self.nesting_list, self.head_sizes)):
            k += (x[i] @ self.key[:d_i, :h_i],)
            q += (x[i] @ self.query[:d_i, :h_i],)
            v += (x[i] @ self.value[:d_i, :h_i],)

            wei.append(q[i] @ k[i].transpose(-2,-1) * h_i ** -0.5) # k[i].shape[-1]**-0.5)
            wei[i] = wei[i].masked_fill(self.tril[:t,:t] == 0, float('-inf'))
            wei[i] = F.softmax(wei[i],dim=-1)
            
            out += (wei[i]@v[i],)
        return out

    def forwardTensor(self, x, h):
        """
        input: 
            - tensor of shape (b,t,d_i)
            - number of heads h
        operation: masked self-attention
        output: tensor of shape (b,t,h_i) where h_i = d_i / h
        """
        d_i = x.shape[-1]
        h_i = d_i // h # the second / ensures it's an int rather than a float

        k = x @ self.key[:d_i, :h_i]
        q = x @ self.query[:d_i, :h_i]
        v = x @ self.value[:d_i, :h_i]

        wei = q @ k.transpose(-2,-1) * h_i ** -0.5 # k.shape[-1]**-0.5
        wei = wei.masked_fill(self.tril[:t,:t] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)

        return wei @ v
        
    def forward(self, x, h=None):
        return self.forwardTuple(x) if type(x) == tuple else self.forwardTensor(x, h)

# MHA

then we've gotta concatenate the outputs of each head

<p align="center">
<img src="./images/drawings/mha_concat.png" width="512"/>
</p>

and after that linearly project them

<p align="center">
<img src="./images/drawings/mha_proj.png" width="512"/>
</p>

this is the place where our splicing gets conceptually annoying. instead of just grabbing the matrix in the upper corner, because of the way attention head output concatenation works we actually need to skip over certain parts of the linear projection matrix and then concatenate them together in order to use them. Here's an example of what the matrix multiplication looks like. on the left is a simplified version of the concatenated attention heads where i just showed it as a matrix rather than a tensor, and then on the right is the actual projection matrix. notice how the numbers in the pink output matrix look similar to the first column of the purple output matrix with a positive number, its negative, and then a smaller positive number; that's the self-similarity in action. the yellow arrows point to the parts that get skipped over. obviously this would look a lot uglier with bigger matrices & incorporating the blue/green layer

<p align="center">
<img src="./images/drawings/mha_proj_matmul.png" width="512"/>
</p>

In [12]:
class matryoshkaMultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, h, nesting_list: List, head_sizes: List, dropout):
        super().__init__()
        
        self.nesting_list = nesting_list
        self.head_sizes = head_sizes
        self.h_count = h # number of heads
        self.d_count = len(nesting_list) # number of nesting doll sizes
        self.h_max = head_sizes[-1] # size of largest head
        self.d_max = nesting_list[-1] # size of largest embedding
        
        # creating all of our different attention heads, then storing them in a list for use later
        self.headsList = nn.ModuleList([matryoshkaHead(self.nesting_list, self.head_sizes) for _ in range(self.h_count)])
        
        # the linear projection that combines the outputs of all the heads
        self.weight = nn.Parameter(torch.Tensor(self.h_max * self.h_count, self.d_max)).to(device)
        self.bias = nn.Parameter(torch.Tensor(self.h_max * self.h_count)).to(device)
        
        # Initializing parameters
        nn.init.normal_(self.weight, std=0.02)  
        nn.init.normal_(self.bias, std=0.02)
        
        self.dropout = nn.Dropout(dropout)

    def forwardTuple(self, x):
        """
        input: tuple of length g with tensors of shape (b,t,d_i) for d_i=nesting_list[i]
        operation: 
            - perform self-attention w each head
                - input to each head: tuple of length g with tensors of shape (b,t,d_i) for d_i=nesting_list[i]
                - output from each head: tuple length g with tensors of shape (b,t,h_i) for h_i=head_sizes[i] where h_i = d_i / h
            - then concatenate into tuple of length g with tensors of shape (b,t,h*h_i). here by design h*h_i=d_i but it need not be that way
            - then linearly project each of the g tensors in the tuple back to (b,t,d_i)
        output: tuple of length g with tensors of shape (b,t,d_i) for d_i=nesting_list[i]
        """
        # let's get the outputs of each attention head
        # list length h of tuples length g of tensors shape (b,t,h_i) for h_i=d_i/h where d_i = nesting_list[i]
        head_outputs = [head(x) for head in self.headsList]

        # now let's reformat our ugly list of tuples into our usual expected tuple length g containing tensors shape (b,t,d_i)
        mid = ()
        for i in range(self.d_count):
            level = [] # where will store the output of each head for this model size
            for j, head in enumerate(head_outputs):
                level.append(head[i]) # this head's output for the d_i layer of the model
            
            # appending the concatenation of all the heads for this d_i layer of the model
            mid += (torch.cat(level, dim=-1),) # tuple length g with tensors of shape (b,t,d_i) for d_i=nesting_list[i]
        # mid is now a length g tuple of tensors shape (b,t,h*h_i)
        
        # now let's do our linear projection, which is not similar to how we did the matryoshkaFeedForward()
        # because we can't just select nested matrices within the primary matrix, we also have to account for the head 
        # concatenation which means skipping throughout and grabbing specific parts from the projection that match up
        #
        # so along the vertical of the matrix we want to iterate through self.nesting_list 
        # and along the horizontal we need to make skips the size of self.h
        # and then from those skips as starting points iteratively slice using self.head_sizes
        # then we concatenate those multiple spliced pieces along the horizontal
        # then we multiply a given output level by its respective projection
        out = ()
        for i, (d_i, h_i) in enumerate(zip(self.nesting_list, self.head_sizes)):
            # h_i is the head size of this iteration
            # j*self.h_max is our skip length
            this_levels_proj_w = torch.cat([self.weight[j*self.h_max:j*self.h_max+h_i, :d_i] for j in range(self.h_count)], dim=0)

            # bias is only one dimension so a bit simpler
            this_levels_proj_b = torch.cat([self.bias[j*self.h_max:j*self.h_max+h_i] for j in range(self.h_count)])

            # select correct level & multiply by weights then add bias
            # and can't forget to dropout
            out += (self.dropout(mid[i]@this_levels_proj_w + this_levels_proj_b),)
            
        return out

    def forwardTensor(self, x):
        """
        input: tensor of shape (b,t,d_i)
        operation: 
            - perform self-attention w each head
                - input to each head: tensor of shape (b,t,d_i)
                - output from each head: tensor of shape (b,t,h_i) where h_i = d_i / h
            - then concatenate the head outputs
            - then linearly project
        output: tensor of shape (b,t,d_i) 
        """
        d_i = x.shape[-1]
        h_i = d_i // self.h_count
        
        head_outputs = torch.cat([head(x, h=self.h_count) for head in self.headsList], dim=-1) # (b,t,h*h_i)

        spliced_projection_w = torch.cat([self.weight[j*self.h_max:j*self.h_max+h_i,:d_i] for j in range(self.h_count)], dim=0)
        spliced_projection_b = torch.cat([self.bias[j*self.h_max:j*self.h_max+h_i] for j in range(self.h_count)])

        return head_outputs @ spliced_projection_w + spliced_projection_b
        
    def forward(self, x):
        return self.forwardTuple(x) if type(x) == tuple else self.forwardTensor(x)

# LAYERNORM

Layernorm is relatively simple code-wise. However, of note is the fact that during training, the entire full length vector gets normalized whereas during inference we only layernorm the sub-vector we've been given if we're not using the full model size. This probably isn't a big deal since the sub-vectors are still hopefully being drawn from the same distribution during training. However, it wouldn't be surprising if the logits going into the small vectors are characteristically different from the full super-vectors, in which case this certainly might be a difficulty for the model. It might be worth changing this algorithm such that during training sub-vectors get normalized first and then held constant while super-vectors are normalized. something to think about. 

In [13]:
class matryoshkaLayerNorm(nn.Module):
    def __init__(self, nesting_list: List):
        super().__init__()

        self.nesting_list = nesting_list
        self.d_count = len(nesting_list)

        # we need layernorm attributes for each dimension size
        for d_i in nesting_list:
            setattr(self, f"ln_{d_i}", nn.LayerNorm(d_i, elementwise_affine=False))
            # we do elementwise_affine=False to remove the linear projection at the end
            # the linear projection would be counterproductive since we're layernorming in so many different places
        
    def forward(self, x):
        """
        a layernorm module that is dynamic to the input of either a single tensor or a tuple of tensors
        only works if the dimensions in question are in self.nesting_list

        input: either 
        - a tensor with last dimension equal to some value in self.nesting_list
        - a tuple of tensors where the last dimensions of each matches the values in self.nesting_list IN ORDER

        output: either of the above, but normalized
        """
        if type(x) == tuple:
            out = ()
            for i, d_i in enumerate(self.nesting_list):
                out += (getattr(self, f"ln_{d_i}")(x[i]),)
        else:
            d_i = x.shape[-1]
            out = getattr(self, f"ln_{d_i}")(x)
            
        return out

# RESIDUAL BLOCK

not a whole lot to say here other than the fact that i've chosen to pass everything through in the form of a tuple means that this block structure is HELLA inefficient in terms of memory. that's like 6 different copies of the tensors being forced to stay in memory goddamn

In [14]:
class matryoshkaBlock(nn.Module):
    def __init__(self, h, nesting_list: List, dropout):
        super().__init__()
        
        self.nesting_list = nesting_list
        self.head_sizes = [d_i // h for d_i in nesting_list]
        
        self.ln = matryoshkaLayerNorm(nesting_list)
        self.mha = matryoshkaMultiHeadAttention(h, nesting_list, self.head_sizes, dropout) 
        self.ffwd = matryoshkaFeedFoward(nesting_list, dropout)
    
    def forwardTuple(self, x_i):
        """
        input: length g tuple of shape (b,t,d_i) tensors for d_i in nesting_list
        output: length g tuple of shape (b,t,d_i) tensors for d_i in nesting_list
        """
        # please forgive my weird variable naming scheme

        # layernorming the input
        x_iplus1quart = self.ln(x_i)

        # the full multi-head attention
        attn = self.mha(x_iplus1quart)

        # residual connection for every residual state in our list of models
        x_iplus1half = tuple(x_i[j] + attn[j] for j in range(len(self.nesting_list)))

        # another layernorm
        x_iplus3quart = self.ln(x_iplus1half)

        # the feeforward
        ffwd = self.ffwd(x_iplus3quart)

        # the next residual connection for every residual state in our list of models
        x_iplus1 = tuple(x_iplus1half[j] + ffwd[j] for j in range(len(self.nesting_list)))
            
        return x_iplus1

    def forwardTensor(self, x):
        """
        input: tensor of shape (b,t,d_i)
        output: tensor of shape (b,t,d_i)
        """
        return x + self.ffwd(self.ln(x + self.mha(self.ln(x))))
        
    def forward(self, x):
        return self.forwardTuple(x) if type(x) == tuple else self.forwardTensor(x)

# OUTPUT

this output layer is similar to what you'll find in in [the original paper](https://arxiv.org/abs/2205.13147) except 
1) i use one output matrix instead of multiple
2) that output matrix i use is the transposed token embedding matrix
3) i add the option to perform inference rather than just training, which is something they did do in the [matformer paper](https://arxiv.org/pdf/2310.07707.pdf)

and then the loss function is the exact same

In [15]:
class matryoshkaOutputLayer(nn.Module):
    def __init__(self, embedding, nesting_list: List, num_classes):
        super().__init__()
        self.nesting_list = nesting_list
        self.num_classes = num_classes  # Number of tokens in the vocabulary
        
        self.embedding = embedding  # Store reference to the embedding matrix

        self.norm = matryoshkaLayerNorm(nesting_list)

    def forwardTuple(self, x):
        """
        input: length g tuple of tensors shape (b,t,d_i) for d_i in nesting_list
        operation: layernorm then multiply the final residual state by the transposed embedding matrix to get final logits
        output: length g tuple of tensors shape (b,t,v) where v is token vocabulary length
        """
        normed_logits = self.norm(x)
        normed_embeddings = self.norm(self.embedding).t()
        
        out = ()
        for i, d_i in enumerate(self.nesting_list):
            out += (normed_logits[i] @ normed_embeddings[:d_i,:],) 
            
        return out

    def forwardTensor(self, x):
        """
        input: tensor shape (b,t,d_i)
        operation: layernorm then multiply the final residual state by the transposed embedding matrix to get final logits
        output: tensor shape (b,t,v) where v is token vocabulary length
        """
        d_i = x.shape[-1]
        normed_logits = self.norm(x)
        normed_embeddings = self.norm(self.embedding[:,:d_i]).t()
        return normed_logits @ normed_embeddings
        
    def forward(self, x):
        return self.forwardTuple(x) if type(x) == tuple else self.forwardTensor(x)

In [16]:
class matryoshkaCEL(nn.Module):
    '''
    Loss function for Matryoshka Representation Learning
    we don't need to create a tensor version of the loss function bc training always involves all nesting levels
    '''
    def __init__(self):
        super().__init__()

        # the loss function
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, logits, target):
        """
        input: 
            - logits are a length g tuple each of shape [b batch size, t sequence length, v number of classes]
            - target is a shape [b batch size, t sequence length] tensor of the indices of the correct tokens
        output: a tensor containing a single float
        """
        g = len(logits)
        b,t,v = logits[0].shape

        # Calculate losses for each output and stack them
        losses = torch.stack([self.criterion(logits_i.view(b*t, v), target.view(b*t)) for logits_i in logits])

        return losses.sum()

# THE MODEL

In [17]:
class matryoshkaGPT(nn.Module):
    def __init__(self, nesting_list: List, v, t, h, dropout):
        super().__init__()

        # the list of dimensions we'll be using
        self.nesting_list = nesting_list
        
        # the embedding size of the largest model
        self.d = nesting_list[-1]
        
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(v, self.d).to(device)
        
        # simple learned positional encodings rather than sine or RoPE
        # at some point i'm gonna write up a new model using all the stuff like RoPE that I should be using
        self.position_embedding_table = nn.Embedding(t, self.d).to(device)
        self.context_len = t

        # our special implementation of layernorm
        self.ln = matryoshkaLayerNorm(nesting_list)

        # bulk of the beast
        self.blocks = nn.Sequential(*[matryoshkaBlock(h, nesting_list, dropout) for _ in range(l)]) 

        # MATRYOSHKA OUTPUT HEADS
        self.out_heads = matryoshkaOutputLayer(self.token_embedding_table.weight, nesting_list, num_classes=v)
        
        # MATRYOSHKA LOSS
        self.loss = matryoshkaCEL()

    def forward(self, idx, targets=None, desired_d=nesting_list[-1]): 
        # desired_d is the desired dimension to use when performing inference (not used during training)

        # idx and targets are both (b,t) tensor of integers
        b, t = idx.shape
        
        pos_emb = self.position_embedding_table(torch.arange(t, device=device)) # (t,d)
        tok_emb = self.token_embedding_table(idx) # (b,t,d)
    
        if targets is None: 
            # if we are NOT training AKA just performing inference
            # send in a single matrix using desired_d
            x_0 = self.ln(tok_emb[:,:,:desired_d]) + pos_emb[:,:desired_d] # (b,t,d) + (t,d) -> (b,t,d)
        else:
            # if we ARE training
            # create tuple of residual states & send it thru
            x_0 = ()
            for d_i in self.nesting_list:
                x_0 += (self.ln(tok_emb[:,:,:d_i]) + pos_emb[:,:d_i],)
            # so in total the for loop gives us (b,t,d) & (t,d) -> g*(b,t,d_i) for d_i in nesting_list

        # most of the model is here
        x_f = self.blocks(x_0)

        # Matryoshka output head
        # self.out_heads includes within it the final layernorm
        logits = self.out_heads(x_f)

        loss = None if targets is None else self.loss(logits, targets) # g*(b,t,d) & (b,t) -> float

        return logits, loss

    def generate(self, idx, max_new_tokens=100, degree=-1):
        """
        input: 
            - idx is (b, ?) tensor of indices from the current context
            - max_new_tokens sets generation length
            - degree determines which model to use. 0 for smallest & -1 for largest
        output: idx is (b,?+max_new_tokens) tensor of indices
        """
        # making sure the user specified an actual existing model. 0 is the smallest model
        assert degree >= -1 & degree < len(nesting_list)

        # getting the actual embedding size of the model we've chosen
        desired_d = self.nesting_list[degree]
        
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -self.context_len:]
            
            # get the predictions
            logits, loss = self(idx_cond, desired_d=desired_d)
            
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (b, d)
            
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (b, d)
            
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (b, 1)
            
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (b, t+1)
            
        return idx

# TRAINING

In [18]:
model = matryoshkaGPT(nesting_list, v, t, h, dropout).to(device)
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=l2)
# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

1595.52 K parameters


In [23]:
start_time = time.time()

# Enable anomaly detection
#torch.autograd.set_detect_anomaly(True)

for iter in range(max_iters):

    # sample a batch of data
    xb, yb = get_batch('train')
    
    # train
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    
    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        current_time = time.time()
        elapsed_time = current_time - start_time
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time elapsed: {elapsed_time:.2f} seconds")

# Disable anomaly detection after the training loop
#torch.autograd.set_detect_anomaly(False)

step 0: train loss 6.8818, val loss 6.9728, time elapsed: 0.23 seconds
step 100: train loss 6.8676, val loss 6.9765, time elapsed: 23.09 seconds
step 200: train loss 6.8080, val loss 6.8736, time elapsed: 45.84 seconds
step 300: train loss 6.8021, val loss 6.9008, time elapsed: 68.68 seconds
step 400: train loss 6.8075, val loss 6.9001, time elapsed: 91.29 seconds
step 500: train loss 6.7299, val loss 6.8127, time elapsed: 113.99 seconds
step 600: train loss 6.7066, val loss 6.8978, time elapsed: 136.54 seconds
step 700: train loss 6.7235, val loss 6.8059, time elapsed: 159.59 seconds
step 800: train loss 6.7163, val loss 6.7379, time elapsed: 182.33 seconds
step 900: train loss 6.6745, val loss 6.8096, time elapsed: 204.89 seconds
step 1000: train loss 6.7024, val loss 6.7507, time elapsed: 227.48 seconds
step 1100: train loss 6.6677, val loss 6.7020, time elapsed: 250.05 seconds
step 1200: train loss 6.6395, val loss 6.6584, time elapsed: 272.63 seconds
step 1300: train loss 6.6142, 

In [25]:
## save the trained model
torch.save(model.state_dict(), f'models/{model.__class__.__name__}_b{b}_t{t}_d{d}_h{h}_l{l}_lr{lr}_drop{dropout}_l2-{l2}_min_power{min_power}_{time.strftime("%Y-%m-%d|%H-%M-%S")}.pth')

# Load a saved model

In [None]:
model = matryoshkaGPT().to(device)  # Initialize a model with the same architecture

# Load the saved state dictionary
model.load_state_dict(torch.load('models/matryoshkaGPT_b16_t64_d2_h4_l8_lr0.0003_drop0.1_l2-0.01_min_power5_2024-02-13|01-55-54.pth'))

# If you plan to continue training the model, switch to training mode
#model.train()

# If you only plan to do inference, switch to evaluation mode
model.eval()

# Inference

In [24]:
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou Romeo?\nDeny thy fathe" # the classic line
context_tensor = torch.tensor([encode(input_str)], dtype=torch.long, device=device)
for d in range(len(nesting_list)):
    print("-----------------model: ", d, "------------------")
    output = model.generate(context_tensor, max_new_tokens=100, degree=d) # -1 for biggest model size
    output_str = decode(output[0].tolist())
    print(output_str)

-----------------model:  0 ------------------
JULIET:
O Romeo, Romeo! wherefore art thou Romeo?
Deny thy fathe, d.

Th,
Cw.
BBUBBt,
Mur pwnd, ming wid wit?
Thifck,
Y:
Thow, wwive,
GUBun,
ICowid, IUCowe ITELOUUB
-----------------model:  1 ------------------
JULIET:
O Romeo, Romeo! wherefore art thou Romeo?
Deny thy fathend in ch cow aws hy;;;:
Sur tht chak'd wn.
Twior y--bixck- d lifurer nd mived,
ANGt,
O, craver hath,
-----------------model:  2 ------------------
JULIET:
O Romeo, Romeo! wherefore art thou Romeo?
Deny thy fathere thener preaks.
Be'd her:
Proute:
K.
War.
A nothinghts;

Give canggeme
Hirt tivesbe-d ZZZUp, JUpea


### obviously given the size of this model it's not very good. oh well
idk about you but it looks to me like the biggest model is the best, as you'd expect. it seems to have a better understanding of the length of a word. also these outputs would prolly be better if i scaled the logits with a temperature but it's late and i'm tired