# TODO
- annotate well so someone reading can understand
- make parameter initializations consistent
- double check that the feeforward doesn't need a skip on the indices
- figure out generate() to use the submodels
- copy the cosine similarity visual exploration tools from `matryoshka_embeddings_gpt`?
- train & save a model

#### !!!! DO NOT RUN THIS FIRST CELL UNLESS YOU HAVE THE SAME VENV PATH ISSUE THAT I DO

In [1]:
import sys
sys.path.append('/Users/tunadorable/local-repos/ng-video-lecture/venv/lib/python3.11/site-packages')

In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import time
from typing import List
import math
import numpy as np
import matplotlib.pyplot as plt

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

# MatryoshkaGPT

the idea here is to have a bunch of tiny models inside the main model like russian nesting dolls

In [3]:
# hyperparameters
b = 4 # how many independent sequences will we process in parallel?
t = 16 # what is the maximum context length for predictions?
max_iters = 10
eval_interval = 2
lr = 3e-4 # learning rate for each backprop step
eval_iters = 20
h = 4 # number of attention heads
l = 4 # number of transormer layers
dropout = 0.1 # % of parameters to ignore every iteration
l2 = 0.01 # multiplier for our L2 norm to encourage sparsity

# embedding aka hidden dimension. this is the largest that th emodel will have
d = 32
power_of_d = int(math.log2(d))
# the smallest power of 2 we'll be considering as a matryoshka embedding
min_power = 4 # Starting from 2^min_power
nesting_list = [2**i for i in range(min_power, int(power_of_d) + 1)]
print(nesting_list)

[16, 32]


In [4]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [5]:
# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
v = len(chars)
print(chars, v)

['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'] 65


In [6]:
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

In [7]:
# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

In [8]:
# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - t, (b,))
    x = torch.stack([data[i:i+t] for i in ix])
    y = torch.stack([data[i+1:i+t+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

In [9]:
# so you can see what the tokenized data looks like
x,y = get_batch('train')
print("x ", x.shape, "\n", x)
print("y ", y.shape, "\n", y)

x  torch.Size([4, 16]) 
 tensor([[ 1, 59, 57,  2,  1, 15, 53, 51, 43,  6,  1, 51, 39, 57, 58, 43],
        [ 0, 21,  1, 39, 51,  1, 58, 46, 43,  1, 45, 56, 43, 39, 58, 43],
        [43, 51, 40, 50, 39, 52, 41, 43,  6,  1, 40, 59, 58,  1, 39,  1],
        [ 1, 58, 53,  1, 46, 47, 57,  1, 51, 39, 48, 43, 57, 58, 63,  8]])
y  torch.Size([4, 16]) 
 tensor([[59, 57,  2,  1, 15, 53, 51, 43,  6,  1, 51, 39, 57, 58, 43, 56],
        [21,  1, 39, 51,  1, 58, 46, 43,  1, 45, 56, 43, 39, 58, 43, 57],
        [51, 40, 50, 39, 52, 41, 43,  6,  1, 40, 59, 58,  1, 39,  1, 41],
        [58, 53,  1, 46, 47, 57,  1, 51, 39, 48, 43, 57, 58, 63,  8,  0]])


In [10]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval() # sets model to eval mode
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train() # just resets to training mode
    return out

# FEEDFORWARD

chatGPT's attempt at the matryoshkaFeedForward()

supposedly it's more efficient bc my use of .weight on the linear layers was stupid, and it also has the benefit of bringing back bias vectors. 

Honestly yeah it does look better than mine i think i should try it after i confirm that mine bare-minimum functions for sake of my own pride. for the record tho it got to look at mine before making its edits so it's not like it could've understood the concept from scratch
```
class matryoshkaFeedForward_chatGPT(nn.Module):
    def __init__(self, nesting_list, dropout_rate):
        super().__init__()
        
        # The largest embedding dimension of the model
        self.d = nesting_list[-1]

        # Initialize only the largest weights and biases
        self.weight_w1 = nn.Parameter(torch.Tensor(4 * self.d, self.d))
        self.bias_w1 = nn.Parameter(torch.Tensor(4 * self.d))
        self.weight_w2 = nn.Parameter(torch.Tensor(self.d, 4 * self.d))
        self.bias_w2 = nn.Parameter(torch.Tensor(self.d))

        # Initialize weights and biases
        self.reset_parameters()

        self.nesting_list = nesting_list
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(dropout_rate)

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight_w1, a=math.sqrt(5))  # or any other initialization
        nn.init.kaiming_uniform_(self.weight_w2, a=math.sqrt(5))  # or any other initialization
        fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_w1)
        bound1 = 1 / math.sqrt(fan_in1)
        nn.init.uniform_(self.bias_w1, -bound1, bound1)
        fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_w2)
        bound2 = 1 / math.sqrt(fan_in2)
        nn.init.uniform_(self.bias_w2, -bound2, bound2)

    def forward(self, x_0):
        x_f = ()
        for i, d_i in enumerate(self.nesting_list):
            # Subset the weights and biases
            weight_w1_sub = self.weight_w1[:4 * d_i, :d_i]
            bias_w1_sub = self.bias_w1[:4 * d_i]
            weight_w2_sub = self.weight_w2[:d_i, :4 * d_i]
            bias_w2_sub = self.bias_w2[:d_i]

            # Apply the linear transformations using the subset weights and biases
            x = F.linear(x_0[i], weight_w1_sub, bias_w1_sub)
            x = self.relu(x)
            x = F.linear(x, weight_w2_sub, bias_w2_sub)
            x = self.drop(x)
            x_f += (x,)

        return x_f
```

In [113]:
class matryoshkaFeedFoward(nn.Module):
    def __init__(self, nesting_list: List, dropout):
        super().__init__()
        
        # the largest embedding dimension of the model
        self.d = nesting_list[-1]

        # initialize only the largest. we'll subset later during forward()
        self.w1 = nn.Linear(self.d, 4 * self.d).to(device)
        self.w2 = nn.Linear(4 * self.d, self.d).to(device)

        # Initialize only the largest weights and biases
        self.w1 = nn.Parameter(torch.Tensor(self.d, 4 * self.d)) # need to double check correct sizes
        self.b1 = nn.Parameter(torch.Tensor(4 * self.d))
        self.w2 = nn.Parameter(torch.Tensor(4 * self.d, self.d))
        self.b2 = nn.Parameter(torch.Tensor(self.d))

        # Initializing weights
        nn.init.normal_(self.w1, std=0.02)  
        nn.init.normal_(self.b1, std=0.02)
        nn.init.normal_(self.w2, std=0.02)
        nn.init.normal_(self.b2, std=0.02)

        # to be used for iterating in forward()
        self.nesting_list = nesting_list
        
        # the other parts
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(dropout)
        # so dropout might become an issue
        # dropping out 10% of an 8x8 matrix will have a different effect from dropping out 10% of a 1024x1024 one
        # and potentially more importantly, different weights will get dropped out for each nesting doll
        # this may actually be beneficial in terms of the model's generalizability, but maybe it'll be bad idk

                    
    def forward(self, x):
        """
        input: tuple of length g with tensors of shape (b,t,d_i) for d_i=nesting_list[i]
        output: tuple of length g with tensors of shape (b,t,d_i) for d_i=nesting_list[i]
        """
        # old
        #return self.drop(self.w2(self.relu(self.w1(x))))
        print("ffwd")
        #print("x: ", x[-1].shape)
        #print("w1: ", self.w1.shape)
        #print("b1: ", self.b1.shape)
        #print("w2: ", self.w2.shape)
        #print("b2: ", self.b2.shape)
        out = ()
        for i, d_i in enumerate(self.nesting_list): # i is int from 0 to g-1 while d_i=nesting_list[i]
            #print(x[i].shape)
            #print((x[i]@self.w1[:d_i,:4*d_i]).shape)
            #print((x[i] @ self.w1[:d_i,:4*d_i] + self.b1[:4*d_i]).shape)
            #print((self.relu(x[i] @ self.w1[:d_i,:4*d_i] + self.b1[:4*d_i])).shape)
            #print((self.relu(x[i] @ self.w1[:d_i,:4*d_i] + self.b1[:4*d_i]) @ self.w2[:4*d_i,:d_i]).shape)
            #print((self.relu(x[i] @ self.w1[:d_i,:4*d_i] + self.b1[:4*d_i]) @ self.w2[:4*d_i,:d_i] + self.b2[:d_i]).shape)
            #print((self.drop(self.relu(x[i] @ self.w1[:d_i,:4*d_i] + self.b1[:4*d_i]) @ self.w2[:4*d_i,:d_i] + self.b2[:d_i])).shape)
            out += (self.drop(self.relu(x[i] @ self.w1[:d_i,:4*d_i] + self.b1[:4*d_i]) @ self.w2[:4*d_i,:d_i] + self.b2[:d_i]),)
            #print(f"ffwd out {i}: {out[i].shape}")
        return out

# ATTENTION

In [104]:
class matryoshkaHead(nn.Module):
    def __init__(self, nesting_list: List, head_sizes: List):
        super().__init__()
        
        # the largest embedding dimension of the model
        self.d = nesting_list[-1]
        # the largest head size
        self.h = head_sizes[-1]

        # to be used for iterating in forward()
        self.nesting_list = nesting_list
        self.head_sizes = head_sizes

        # initialize only the largest. we'll subset later during forward()
        #self.key = nn.Linear(self.d, self.h, bias=False)
        self.key = nn.Parameter(torch.Tensor(self.d, self.h)).to(device)
        #self.query = nn.Linear(self.d, self.h, bias=False)
        self.query = nn.Parameter(torch.Tensor(self.d, self.h)).to(device)
        #self.value = nn.Linear(self.d, self.h, bias=False)
        self.value = nn.Parameter(torch.Tensor(self.d, self.h)).to(device)

        # the mask so they only look into the past
        self.register_buffer('tril', torch.tril(torch.ones(t, t))) # mask future timestesps

        # Initialize weights and biases
        self.reset_parameters()

    def reset_parameters(self): # i need to make all my parameter initializations consistent instead of using this
        a = math.sqrt(5)
        nn.init.kaiming_uniform_(self.key, a)  # or any other initialization
        nn.init.kaiming_uniform_(self.query, a)
        nn.init.kaiming_uniform_(self.value, a)

    def forward(self, x_0):
        """
        input: tuple length g with tensors of shape (b,t,d_i) for d_i=nesting_list[i]
        output: tuple length g with tensors of shape (b,t,h_i) for h_i=head_sizes[i] where h_i = d_i / h
        """
        #print("head")
        #print("Wk full: ", self.key.shape) #.weight.shape)
        #print("Wq full: ", self.query.shape)
        #print("Wv full: ", self.value.shape)
        #b,t,d = x.shape
        k,q,v,wei,out = (),(),(),[],()
        for i, (d_i, h_i) in enumerate(zip(self.nesting_list, self.head_sizes)):
            #print(i)
            #print(f"x_0[{i}]: ", x_0[i].shape)
            #print(f"d_i: {d_i} h_i: {h_i}")
            #Wk = self.key[:d_i, :h_i] #.weight[:d_i, :h_i]
            #print("Wk: ", Wk.shape)
            k += (torch.matmul(x_0[i],self.key[:d_i, :h_i]),)
            #Wq = self.query[:d_i, :h_i] #.weight[:d_i, :h_i]
            #print("Wq: ", Wq.shape)
            q += (x_0[i] @ self.query[:d_i, :h_i],)
            #Wv = self.value[:d_i, :h_i] #.weight[:d_i, :h_i]
            #print("Wv: ", Wv.shape)
            v += (x_0[i] @ self.value[:d_i, :h_i],)

            # not sure if this is a bunch of "in-place" operations
            # if i get an error about that then what i gotta do is make it separate variables instead of repeatedly editing wei
            wei.append(q[i] @ k[i].transpose(-2,-1) * k[i].shape[-1]**-0.5)
            wei[i] = wei[i].masked_fill(self.tril[:t,:t] == 0, float('-inf'))
            wei[i] = F.softmax(wei[i],dim=-1)
            
            out += (wei[i]@v[i],)
        #print("out: ", out[0].shape, out[1].shape)
        return out

# MHA

In [105]:
class matryoshkaMultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, h, nesting_list: List, head_sizes: List, dropout):
        super().__init__()
        
        self.nesting_list = nesting_list
        self.head_sizes = head_sizes
        self.h_count = h # number of heads
        self.d_count = len(nesting_list) # number of nesting doll sizes
        self.h_max = head_sizes[-1] # size of largest head
        self.d_max = nesting_list[-1]

        # thinking maybe i should mess with the forward() to figure out what i need here
        #self.headsDict = nn.ModuleDict()
        #for head_idx in range(h):
            #self.headsDict[f'head_{head_idx}'] = matryoshkaHead(self.nesting_list, self.head_sizes)

        # can you have tuples inside a module list? i hope so
        self.headsList = nn.ModuleList([matryoshkaHead(self.nesting_list, self.head_sizes) for _ in range(self.h_count)])
        
        # can i even use module list if i'm listing tuples rather than tensors? idk prolly not
        # maybe i can create different module lists using setattr() and selections from the outputs of the heads?
        #self.heads = nn.ModuleList([matryoshkaHead(self.nesting_list, self.head_sizes) for _ in range(self.h)])
        
        #self.proj = nn.Linear(head_sizes[-1] * h, nesting_list[-1])
        # the linear projection that combines the outputs of all the heads
        self.weight = nn.Parameter(torch.Tensor(self.h_max * self.h_count, self.d_max)).to(device)
        self.bias = nn.Parameter(torch.Tensor(self.h_max * self.h_count)).to(device)

        # Initialize weights and biases
        self.reset_parameters()
        
        self.dropout = nn.Dropout(dropout)

    def reset_parameters(self): # i need to make all my parameter initializations consistent instead of using this
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))  # or any other initialization
        fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
        bound1 = 1 / math.sqrt(fan_in1)
        nn.init.uniform_(self.bias, -bound1, bound1)
        
    def forward(self, x):
        """
        input: tuple of length g with tensors of shape (b,t,d_i) for d_i=nesting_list[i]
            input to each head: tuple of length g with tensors of shape (b,t,d_i) for d_i=nesting_list[i]
            output from each head: tuple length g with tensors of shape (b,t,h_i) for h_i=head_sizes[i] where h_i = d_i / h
        output: tuple of length g with tensors of shape (b,t,d_i) for d_i=nesting_list[i]
        """
        print("mha")
        # let's get the outputs of each attention head
        # list length h of tuples length g of tensors shape (b,t,h_i) for h_i=d_i/h where d_i=nesting_list[i]
        head_outputs = [head(x) for head in self.headsList]

        # now let's reformat our ugly list of tuples into our usual expected tuple legnth g containing tensors shape (b,t,d_i)
        mid = ()
        for i in range(self.d_count):
            level = [] # where will store the output of each head for this size d_i
            for j, head in enumerate(head_outputs):
                level.append(head[i]) # this head's output for the d_i layer of the model
            
            # appending the concatenation of all the heads for this d_i layer of the model
            mid += (torch.cat(level, dim=-1),) # tuple length g with tensors of shape (b,t,d_i) for d_i=nesting_list[i]
            #print(f"mha_before_projection[{i}]: ", mid[i].shape)

        # now let's do our linear projection, which is not similar to how we did the matryoshkaFeedForward()
        # because we can't just select nested matrices within the primary matrix, we also have to account for the head concatenation
        # and this means skipping throughout and grabbing specific parts from the projection that match up
        #
        # so along the vertical of the matrix we want to iterate through self.nesting_list 
        # and along the horizontal we need to make skips the size of self.h the largest head
        # and then from those skips as starting points iteratively slice using self.head_sizes
        # then we concatenate those multiple spliced pieces along the horizontal
        # then we multiply a given output level by its respective projection
        out = ()
        for i, (d_i, h_i) in enumerate(zip(self.nesting_list, self.head_sizes)):
            skip = j*self.h_max
            # h_i is the head size of this iteration
            this_levels_proj_w = torch.cat([self.weight[skip:skip+h_i,:d_i] for j in range(self.h_count)], dim=0)

            # bias is only one dimension so a bit simpler
            this_levels_proj_b = torch.cat([self.bias[skip:skip+h_i] for j in range(self.h_count)])

            # select correct level & multiply by weights then add bias
            # and can't forget to dropout
            out += (self.dropout(mid[i]@this_levels_proj_w + this_levels_proj_b),)
            #print(f"mha_after_projection[{i}]: ", out[i].shape)
            
        return out




        
        #temp = [self.heads_dict[f'head_{i}'](x) for i in range(len(x))]
        
        #attn = ()
        #or h in range(self.h_count):
            #temp2 = []
            #for i in range(len(self.nesting_list)):
                #head_out = temp[h][i]
                #temp2.append(head_out)

            #next = torch.cat(temp2, dim-1)
            #attn += (next,)

        # i think i gotta create `out` & do selections to make these
        #out = [] # do i turn into tuple later or just concatenate?
        #for i, (d_i, h_i) in enumerate(zip(nesting_list, head_sizes)):
        
        #out = torch.cat([head(x) for head in self.heads], dim=-1)
        #out = self.dropout(self.proj(out))

- *these are a bunch of drafts & ChatGPT attempts at MHA*

class matryoshkaMultiHeadAttention(nn.Module):
    def __init__(self, h, nesting_list, head_sizes, dropout=0.1):
        super().__init__()
        
        self.nesting_list = nesting_list
        self.head_sizes = head_sizes
        self.h = h  # number of heads

        # Initialize heads for each combination of granularity and head size
        self.heads = nn.ModuleDict()
        for head_idx in range(h):
            for i, (d_i, h_i) in enumerate(zip(nesting_list, head_sizes)):
                self.heads[f'head_{head_idx}_level_{i}'] = matryoshkaHead([d_i], [h_i])

        # Output transformation layers for each granularity level to ensure output shape consistency
        self.output_transforms = nn.ModuleList([nn.Linear(h_i * h, d_i) for d_i, h_i in zip(nesting_list, head_sizes)])
        
        self.dropout = nn.Dropout(dropout)

    def forward(self, x_0):
        out = []
        for i, (d_i, h_i) in enumerate(zip(self.nesting_list, self.head_sizes)):
            # Aggregate outputs from all heads for the current granularity level
            head_outputs = []
            for head_idx in range(self.h):
                head = self.heads[f'head_{head_idx}_level_{i}']
                head_output = head((x_0[i],))  # matryoshkaHead expects a tuple input
                head_outputs.append(head_output[0])  # Unpack the single-element tuple

            # Concatenate along the last dimension and apply the output transformation
            combined_output = torch.cat(head_outputs, dim=-1)
            transformed_output = self.output_transforms[i](self.dropout(combined_output))
            out.append(transformed_output)

        # Return a tuple of tensors to maintain consistency with other components
        return tuple(out)


class matryoshkaMultiHeadAttention(nn.Module):
    def __init__(self, h, nesting_list, dropout):
        super().__init__()
        self.heads = nn.ModuleList([AdaptiveHead(max(nesting_list), nesting_list, dropout) for _ in range(h)])
        self.projections = nn.ModuleDict({
            str(d_i): nn.Linear(d_i * h, d_i) for d_i in nesting_list
        })
        self.dropout = nn.Dropout(dropout)

    def forward(self, x_tuple):
        head_outputs = [head(x_tuple) for head in self.heads]  # List of tuples

        # Concatenate outputs from all heads
        concatenated = tuple(torch.cat([head_output[i] for head_output in head_outputs], dim=-1) for i in range(len(x_tuple)))

        # Project concatenated outputs back to original dimensions
        projected = tuple(self.dropout(self.projections[str(x.size(-1))](concatenated[i])) for i, x in enumerate(x_tuple))

        return projected

chatGPT's rough sketch of the whole attention process

```
class AdaptiveHead(nn.Module):
    def __init__(self, max_head_size, nesting_list):
        super().__init__()
        # Initialize weights for the largest dimension
        self.query_weights = nn.Parameter(torch.Tensor(max_head_size, max_head_size))
        self.key_weights = nn.Parameter(torch.Tensor(max_head_size, max_head_size))
        self.value_weights = nn.Parameter(torch.Tensor(max_head_size, max_head_size))
        self.nesting_list = nesting_list
        # Other initializations (dropout, etc.)

    def forward(self, x_tuple):
        outputs = []
        for x, d_i in zip(x_tuple, self.nesting_list):
            # Adjust weights and operations for d_i
            # Compute attention and add to outputs
            outputs.append(adjusted_output)
        return tuple(outputs)

class MatryoshkaMultiHeadAttention(nn.Module):
    def __init__(self, h, nesting_list):
        super().__init__()
        self.heads = nn.ModuleList([AdaptiveHead(max(nesting_list), nesting_list) for _ in range(h)])
        # Projection layers for each d_i in nesting_list
        self.projections = nn.ModuleDict({str(d_i): nn.Linear(d_i * h, d_i) for d_i in nesting_list})

    def forward(self, x_tuple):
        head_outputs = [head(x_tuple) for head in self.heads]
        # Concatenate, project, and return outputs for each d_i
```

then chatGPT's first attempt at the attention heads
```
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaptiveHead(nn.Module):
    def __init__(self, max_head_size, nesting_list, dropout):
        super().__init__()
        self.max_head_size = max_head_size
        self.nesting_list = nesting_list
        self.dropout = nn.Dropout(dropout)

        # Parameters for the largest possible head size
        self.query_weights = nn.Parameter(torch.Tensor(max_head_size, max_head_size))
        self.key_weights = nn.Parameter(torch.Tensor(max_head_size, max_head_size))
        self.value_weights = nn.Parameter(torch.Tensor(max_head_size, max_head_size))
        nn.init.normal_(self.query_weights, std=0.02)  # Initializing weights as per common practice
        nn.init.normal_(self.key_weights, std=0.02)
        nn.init.normal_(self.value_weights, std=0.02)

        # Pre-compute masks for efficiency
        self.masks = {}
        for d_i in nesting_list:
            self.masks[d_i] = torch.tril(torch.ones(d_i, d_i)).unsqueeze(0)

    def forward(self, x_tuple):
        outputs = []
        for x, d_i in zip(x_tuple, self.nesting_list):
            b, t, _ = x.size()

            # Slice weights according to the current dimension
            query_w = self.query_weights[:d_i, :d_i]
            key_w = self.key_weights[:d_i, :d_i]
            value_w = self.value_weights[:d_i, :d_i]

            # Compute Q, K, V
            Q = torch.matmul(x, query_w)
            K = torch.matmul(x, key_w)
            V = torch.matmul(x, value_w)

            # Scaled Dot-Product Attention
            attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_i ** 0.5)
            if t <= d_i:  # Ensure the mask fits the time dimension
                mask = self.masks[d_i][:, :t, :t].to(x.device)
                attention_scores = attention_scores.masked_fill(mask == 0, float('-inf'))
            attention_probs = F.softmax(attention_scores, dim=-1)
            attention_probs = self.dropout(attention_probs)

            # Weighted sum of values
            out = torch.matmul(attention_probs, V)
            outputs.append(out)

        return tuple(outputs)
```

and then multi-head attention
```
class MatryoshkaMultiHeadAttention(nn.Module):
    def __init__(self, h, nesting_list, dropout):
        super().__init__()
        self.heads = nn.ModuleList([AdaptiveHead(max(nesting_list), nesting_list, dropout) for _ in range(h)])
        self.projections = nn.ModuleDict({
            str(d_i): nn.Linear(d_i * h, d_i) for d_i in nesting_list
        })
        self.dropout = nn.Dropout(dropout)

    def forward(self, x_tuple):
        head_outputs = [head(x_tuple) for head in self.heads]  # List of tuples

        # Concatenate outputs from all heads
        concatenated = tuple(torch.cat([head_output[i] for head_output in head_outputs], dim=-1) for i in range(len(x_tuple)))

        # Project concatenated outputs back to original dimensions
        projected = tuple(self.dropout(self.projections[str(x.size(-1))](concatenated[i])) for i, x in enumerate(x_tuple))

        return projected
```

# LAYERNORM

In [114]:
class matryoshkaLayerNorm(nn.Module):
    def __init__(self, nesting_list: List):
        super().__init__()

        self.nesting_list = nesting_list
        self.d_count = len(nesting_list)

        # we need layernorm attributes for each dimension size
        for d_i in nesting_list:
            setattr(self, f"ln_{d_i}", nn.LayerNorm(d_i, elementwise_affine=False)) 
        
    def forward(self, x):
        """
        a layernorm module that is dynamic to the input of either a single tensor or a tuple of tensors
        only works if the dimensions in question are in self.nesting_list

        input: either 
        - a tensor with last dimension equal to some value in self.nesting_list
        - a tuple of tensors where the last dimensions of each matches the values in self.nesting_list IN ORDER

        output: either of the above, but normalized

        NOTE: later i might do a weird scheme where it layernorms the smallest embedding dimension first, then holds that constant
        and layernorms all the remaining values in the next sized embedding dimension, and then so on. this might help w stability
        depending on how the rest of the model ends up looking
        """
        print("layernorm")
        if type(x) == tuple:
            out = ()
            for i, d_i in enumerate(self.nesting_list):
                out += (getattr(self, f"ln_{d_i}")(x[i]),)
        elif type(x) == torch.Tensor:
            d = x.shape[-1]
            out = getattr(self, f"ln_{d}")(x)
        else:
            print("ERROR: LAYERNORM NEEDED TUPLE or TENSOR BUT RECEIVED ", type(x))
        return out

# BLOCK

In [115]:
class matryoshkaBlock(nn.Module):
    """
    Transformer block: communication followed by computation
    
    input: length g tuple of shape (b,t,d_i) tensors for d_i in nesting_list
    output: length g tuple of shape (b,t,d_i) tensors for d_i in nesting_list
    """

    def __init__(self, h, nesting_list: List, dropout):
        # d: the biggest embedding dimension, h: the number of heads we'd like
        super().__init__()
        
        self.nesting_list = nesting_list
        self.head_sizes = [d_i // h for d_i in nesting_list] # the second / forces the value to be an int isntead of a float
        
        self.norm = matryoshkaLayerNorm(nesting_list)
        self.mha = matryoshkaMultiHeadAttention(h, nesting_list, self.head_sizes, dropout) 
        self.ffwd = matryoshkaFeedFoward(nesting_list, dropout)
        
        # originals
        #head_size = d // h # the double backslash just makes the output an int instead of float
        #self.ln = nn.LayerNorm(d, elementwise_affine=False)
    
    def forward(self, x_i):
        print("block")
        print("head_sizes: ", self.head_sizes)
        #x = x_i + self.mha(self.ln(x_i))
        #x = x + self.ffwd(self.ln(x))

        x_iplus1quart = self.norm(x_i)
        print("x_iplus1quart: ", x_iplus1quart[0].shape, x_iplus1quart[1].shape)
        
        attn = self.mha(x_iplus1quart)

        x_iplus1half = ()
        for j in range(len(self.nesting_list)):
            x_iplus1half += (x_i[j] + attn[j],)

        x_iplus3quart = self.norm(x_iplus1half)

        ffwd = self.ffwd(x_iplus3quart)

        x_iplus1 = ()
        for j in range(len(self.nesting_list)):
            x_iplus1 += (x_iplus1half[j] + ffwd[j],)

        # i can make this all prettier later by changing every single function to either take in a tensor or a tuple
        # i think at that point i might be able to reuse the code below \/

        #x_iplus1quart, x_iplus1half, x_iplus3quart, x_iplus1 = (), (), (), ()
        #for j, d_j in enumerate(self.nesting_list):
            #x_iplus1quart += (self.norm(x_i[j]),)
            #x_iplus1half += (x_i[j] + self.mha(x_iplus1quart[j]),)
            #x_iplus3quart += (self.norm(x_iplus1half[j]),)
            #x_iplus1 += (x_iplus1half[j]  + self.ffwd(x_iplus3quart[j]),)
        # this is so inefficient it's absurd
            
        return x_iplus1

# OUTPUT

In [116]:
class matryoshkaOutputLayer(nn.Module):
    def __init__(self, embedding, nesting_list: List, num_classes): # , **kwargs # <- not sure why that was an argument
        super().__init__() # matryoshkaOutputLayer, self # <- not sure why those were inside super()
        self.nesting_list = nesting_list
        self.num_classes = num_classes  # Number of classes for classification
        
        self.embedding = embedding  # Store reference to the embedding layer

        self.norm = matryoshkaLayerNorm(nesting_list)
            
        # Initialize layer normalization
        #self.layer_norm = nn.LayerNorm(nesting_list[-1], elementwise_affine=False)

    def forward(self, x):
        """
        the output layer. we've gotta layernorm each size then use the transposed embedding matrix as our linear layer to multiply by
        input: length g tuple of tensors shape (b,t,d_i) for d_i in nesting_list
        output: length g tuple of tensors shape (b,t,v) where v is token vocabulary length
        """
        normed_logits = self.norm(x)
        normed_embeddings = self.norm(self.embedding).t().to(x.device) # can i put this in the __init__???

        out = ()
        for i, d_i in enumerate(self.nesting_list):
            out += (x[i] @ normed_embeddings[:d_i,:],) 
            
        return out

In [117]:
class matryoshkaCEL(nn.Module):
    '''
    Loss function for Matryoshka Representation Learning 
    '''
    def __init__(self, relative_importance: List[float]=None): #, **kwargs
        super().__init__() # matryoshkaCEL, self # not sure why those were in super()
        self.criterion = nn.CrossEntropyLoss()
        
        # relative importance shape: [G]
        # this is optional for if you want to weight them differently
        self.relative_importance = relative_importance

    def forward(self, logits, target):
        # logits are a length g tuple each of shape [b batch size, t sequence length, v number of classes]
        # target shape: [b batch size, t sequence length]
        
        g = len(logits)
        b,t,v = logits[0].shape

        # Calculate losses for each output and stack them
        # might need to do .view() or .reshape() to make sure these go in well
        losses = torch.stack([self.criterion(logits_i.view(b*t, v), target.view(b*t)) for logits_i in logits])

        # Set relative_importance to 1 if not specified
        # I don't think i'm gonna be messing around with this part
        rel_importance = torch.ones_like(losses) if self.relative_importance is None else torch.tensor(self.relative_importance)

        # Apply relative importance weights
        weighted_losses = rel_importance * losses
        return weighted_losses.sum()

# THE MODEL

In [118]:
class matryoshkaGPT(nn.Module):
    def __init__(self, nesting_list: List, v, t, h, dropout):
        super().__init__()

        # the list of dimensions we'll be using
        self.nesting_list = nesting_list
        
        # the largest embedding size
        self.d = nesting_list[-1]
        
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(v, self.d).to(device)
        
        # simple learned positional encodings rather than sine or RoPE
        self.position_embedding_table = nn.Embedding(t, self.d).to(device)
        self.context_len = t

        # our special implementation of layernorm
        self.ln = matryoshkaLayerNorm(nesting_list)

        # bulk of the beast
        self.blocks = nn.Sequential(*[matryoshkaBlock(h, nesting_list, dropout) for _ in range(l)]) 

        ### MATRYOSHKA OUTPUT HEADS
        self.out_heads = matryoshkaOutputLayer(self.token_embedding_table.weight, nesting_list, num_classes=v)
        
        ### MATRYOSHKA LOSS
        self.loss = matryoshkaCEL()


        # i commented out the weight initialization bc i don't think i ever actually called this function
        # initialize weights
        #self.apply(self._init_weights)

    #def _init_weights(self, module):
        #if isinstance(module, nn.Linear):
            #torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            #if module.bias is not None:
                #torch.nn.init.zeros_(module.bias)
        #elif isinstance(module, nn.Embedding):
            #torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        print("forward")
        # i think later imma have to separate between full_forward() which is what's here rn and subset_forward() which will
        # be an option that lets us choose which of the russian nesting dolls to use
        
        b, t = idx.shape
        
        # idx and targets are both (b,t) tensor of integers
        pos_emb = self.position_embedding_table(torch.arange(t, device=device)) # (t,d)
        print("pos_emb: ", pos_emb.shape)
        tok_emb = self.token_embedding_table(idx) # (b,t,d)
        print("tok_emb: ", tok_emb.shape)

        #x = self.ln(tok_emb) + pos_emb # (b,t,d) + (t,d) = (b,t,d)
        # our first nested thingy
        x_0 = ()
        for d_i in self.nesting_list:
            # notice how we're layernorming the specific size not the whole thing
            x_0 += (self.ln(tok_emb[...,:d_i]) + pos_emb[...,:d_i],) # (b,t,d) + (t,d) = (b,t,d)
        # so in total the for loop gives us (b,t,d) & (t,d) -> g*(b,t,d_i) for d_i in nesting_list
        print("x_0: ", x_0[0].shape, x_0[1].shape)

        x_f = self.blocks(x_0)

        # Matryoshka output head
        # self.out_heads includes within it the final layernorm
        logits = self.out_heads(x_f)

        loss = None if targets is None else self.loss(logits, targets) # g*(b,t,d) & (b,t) -> float

        return logits, loss

    def generate(self, idx, max_new_tokens):
        ### CURRENTLY THIS FUNCTION DOESN"T SELECT A SPECIFIC NESTING DOLL LAYER TO COMPUTE
        # RATHER IT JUST DOES ALL OF THEM, WHICH IS OBVIOUSLY NOT COMPUTATIONALLY IDEAL
        # and it means that they all keep moving based on the biggest one's output so it's not a true test of the smaller ones
        # I"LL FIX IT LATER. I THINK IT"LL BE ANNOYING TO DO UGH
        # actually i'm thinking maybe all it'll take is splitting up each class's forward() into two versions like in matryoshkaLayerNorm
        """
        input: idx is (b, t) array of indices in the current context
        output: each_idx is a length g list of (b,t) tensors with indices in the current context
        """
        
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -self.context_len:]
            
            # get the predictions
            logits, loss = self(idx_cond)
            
            # select the largest model
            logits = logits[-1]
            
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (b, d)
            
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (b, d)
            
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (b, 1)
            
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (b, t+1)
            
        return idx

# TRAINING

In [119]:
model = matryoshkaGPT(nesting_list, v, t, h, dropout).to(device)
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=l2)
# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

52.512 K parameters


# -------------------- BOOKMARK -----------------------------------

In [120]:
start_time = time.time()

# Enable anomaly detection
#torch.autograd.set_detect_anomaly(True)

for iter in range(max_iters):

    # sample a batch of data
    xb, yb = get_batch('train')
    
    # train
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    
    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        current_time = time.time()
        elapsed_time = current_time - start_time
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time elapsed: {elapsed_time:.2f} seconds")

# Disable anomaly detection after the training loop
#torch.autograd.set_detect_anomaly(False)

forward
pos_emb:  torch.Size([16, 32])
tok_emb:  torch.Size([4, 16, 32])
layernorm
layernorm
x_0:  torch.Size([4, 16, 16]) torch.Size([4, 16, 32])
block
head_sizes:  [4, 8]
layernorm
x_iplus1quart:  torch.Size([4, 16, 16]) torch.Size([4, 16, 32])
mha
layernorm
ffwd
block
head_sizes:  [4, 8]
layernorm
x_iplus1quart:  torch.Size([4, 16, 16]) torch.Size([4, 16, 32])
mha
layernorm
ffwd
block
head_sizes:  [4, 8]
layernorm
x_iplus1quart:  torch.Size([4, 16, 16]) torch.Size([4, 16, 32])
mha
layernorm
ffwd
block
head_sizes:  [4, 8]
layernorm
x_iplus1quart:  torch.Size([4, 16, 16]) torch.Size([4, 16, 32])
mha
layernorm
ffwd
layernorm
layernorm
ERROR: LAYERNORM NEEDED TUPLE or TENSOR BUT RECEIVED  <class 'torch.nn.parameter.Parameter'>


UnboundLocalError: local variable 'out' referenced before assignment

In [160]:
## save the trained model
torch.save(model.state_dict(), f'models/{model.__class__.__name__}_b{b}_t{t}_d{d}_h{h}_l{l}_lr{lr}_drop{dropout}_l2-{l2}_min_power{min_power}_{time.strftime("%Y-%m-%d|%H-%M-%S")}.pth')

# Load a saved model

In [None]:
model = matryoshka GPT().to(device)  # Initialize a model with the same architecture

# Load the saved state dictionary
model.load_state_dict(torch.load('models/GPT_b24_t128_d128_h8_l8_lr0.0003_drop0.2_l2-0.01_2024-01-25|23-31-12.pth'))

# If you plan to continue training the model, switch to training mode
#model.train()

# If you only plan to do inference, switch to evaluation mode
model.eval()

# Inference

In [49]:
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou R" # the classic line
context_tensor = torch.tensor([encode(input_str)], dtype=torch.long, device=device)
print(context_tensor)

tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30]])


In [None]:
output = model.generate(context_tensor, max_new_tokens=100)
output_str = decode(output[0].tolist())
print(output_str)