# FractaFormer - Model Merging

For this one i'm thinking i do the thing where i train tiny models and then combine them into a FractalFormer

The question is, should i let the marger models have their own dense weight matrices or should they only get to use their upper-right and lower-left quadrants? gonna have to experiment with how matmul looks under the latter

In [1]:
# Importing pytorch
import torch
import torch.nn as nn
from torch.nn import functional as F

# used for the tokenizer
import pickle
import os

# Imports used for the config
import dataclasses 
from typing import Optional

# Imports used for the model
import re
from typing import Any, List, Sequence, Tuple, Union
import numpy as np

# used in the training loop
import time

# The Dataset

the dataset we'll be using is just TinyShakespeare for sake of simplicity & ability to do run/train locally on any computer

In [2]:
# load the dataset
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# the first 200 characters. It's just one continuous text document with all of the works of shakespeare back-to-back
print(text[:200])

# here are all the unique characters that occur in this text and how many there are
chars = sorted(list(set(text)))
v = len(chars)
print('\n', chars, v)

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you

 ['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'] 65


# The Tokenizer

We'll be using a very simple tokenizer I previoiusly trained off of the TinyShakespeare dataset that has 128 total tokens and ignores stuff like special tokens & regex. 

In [3]:
# Load the tokenizer data using pickle
with open('./tokenizers/tokenizer.model', 'rb') as f:
    loaded_tokenizer_data = pickle.load(f)

# Extract the stoi mapping and merges from the loaded data
loaded_stoi = loaded_tokenizer_data['stoi']
loaded_merges = loaded_tokenizer_data['merges']

class SimpleTokenizer:
    def __init__(self, stoi, merges):
        self.stoi = stoi
        self.merges = merges
        self.itos = {i: s for s, i in stoi.items()}  # Inverse mapping for decoding

        self.vocab_len = len(stoi) + len(merges)

    def encode(self, text):
        # Convert the text to a list of token IDs, using space for unknown characters
        tokens = [self.stoi.get(c, self.stoi[' ']) for c in text]

        # Perform merging with the possibility of nested merges
        i = 0
        while i < len(tokens) - 1:
            pair = (tokens[i], tokens[i + 1])
            if pair in self.merges:
                # Replace the current pair with its merged token
                merged_token = self.merges[pair]
                tokens[i] = merged_token
                del tokens[i + 1]

                # Move back to handle possible nested merges
                if i > 0:
                    i -= 1
            else:
                i += 1

        return tokens

    def decode(self, tokens):
        def expand_token(token):
            # Base case: if the token is a direct mapping, return its character
            if token in self.itos:
                return self.itos[token]
            # Recursive case: if the token is a merged token, expand its constituents
            elif token in self.merges.values():
                pair = next(key for key, value in self.merges.items() if value == token)
                return ''.join(expand_token(t) for t in pair)
            # Fallback for unknown tokens
            else:
                return ''

        # Decode each token in the list, handling nested merges recursively
        return ''.join(expand_token(token) for token in tokens)
        
# Example usage
# Assuming loaded_stoi and loaded_merges are already loaded from the tokenizer.model file

tokenizer = SimpleTokenizer(loaded_stoi, loaded_merges)
print("vocab length: ", tokenizer.vocab_len)

# Encoding text
encoded_text = tokenizer.encode("JULIET:\nO Romeo, Romeo! wherefore art thou Romeo?")
print("Encoded:", encoded_text, len(encoded_text))

# Decoding back
decoded_text = tokenizer.decode(encoded_text)
print("Decoded:", decoded_text, len(decoded_text))

vocab length:  128
Encoded: [22, 33, 24, 21, 17, 32, 71, 27, 1, 30, 53, 83, 53, 66, 30, 53, 83, 53, 2, 1, 61, 87, 93, 105, 43, 1, 77, 58, 1, 65, 67, 1, 30, 53, 83, 53, 12] 37
Decoded: JULIET:
O Romeo, Romeo! wherefore art thou Romeo? 49


# Config

In [18]:
@dataclasses.dataclass # a class meant specifically to just hold data
class Config:
    """ 
    The default configuration & hyperparameters for FractalFormer
    In this case, it's the hyperparameters of the largest model and the hyperparameters that define relative size of smaller models
    """
    # The number of tokens in the vocabulary.
    vocab_size: int = tokenizer.vocab_len
    
    # The maximum sequence length that this model might ever be used with.
    max_position_embeddings: int = 256
    
    # The number of layers in the model.
    num_hidden_layers: int = 4
    
    # The number of attention heads used in the attention layers of the model.
    num_attention_heads: int = 4
    
    # The number of key-value heads for implementing multi-query attention.
    num_key_value_heads: int = 1
    # Ensures that the number of query heads is evenly divisible by the number of KV heads.
    assert num_attention_heads % num_key_value_heads == 0
    
    # The hidden size of the model, AKA the embedding dimension
    hidden_size: int = 128
    # the attention heads need to cleanly divide up the hidden_size of the model for MQA
    assert hidden_size % num_attention_heads == 0

    # how much larger the inner dimension of the MLP should be than the hidden size of the model
    intermediate_multiplier = 4
    # The inner dimension of the MLP part of the decoder layer
    @property
    def intermediate_size(self):
        return self.intermediate_multiplier * self.hidden_size
    
    # The number of head dimensions
    head_dim: int = 32
    
    # The epsilon used by the rms normalization layers.
    rms_norm_eps: float = 1e-6 # this is to promote numerical stability & prevent dividing by 0
    
    # the scaling factor that determines the frequencies for the rotary positional encodings
    rope_theta = 100.0
    # smaller models should use a smaller theta, but I'm just guessing here. 1000 might work too. 10,000 is the usual

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # the % of neurons to dropout in the MLP
    dropout = 0.1

    ####### for debugging & visualization
    verbose = {
    'RMSNorm': False,
    'MLP': False,
    'MQA': False,
    'Layer': False,
    'OutputLayer': False,
    'FractalLoss': False,
    'FractalFormer': False,
    'Sampler': False,
    'Generate': False
    }

    ####### FractalFormer-specific hyperparameters

    # the number of levels for sub-models to exist on
    levels = 3
    
    # the number of splits to make at a given level
    split = 2 # i don't recommend choosing any value other than 2
    # needs to be divisible by 2 in order to splice cleanly
    assert split % 2 == 0
    # RoPE requires a head dimension of length larger than 1 in order to work
    assert head_dim // (split * (levels-1)) > 1
    # really though you shouldn't be getting anywhere near that small of a head dimension even at the lowest level, that'd be useless

    @property
    def model_count(self):
        return [self.split**i for i in range(self.levels)]

    @property
    def model_dim_list(self):
        return [self.hidden_size // (self.split**i) for i in range(self.levels)]

    @property
    def head_dim_list(self):
        return [self.head_dim // (self.split**i) for i in range(self.levels)]

configs = {0: Config()}

print("single large model -> hierarchy of many smaller models inside")
print(f"model_count: {configs[0].model_count}")
print(f"model_dim_list: {configs[0].model_dim_list}")
print(f"head_dim_list: {configs[0].head_dim_list}")
print(configs[0])

single large model -> hierarchy of many smaller models inside
model_count: [1, 2, 4]
model_dim_list: [128, 64, 32]
head_dim_list: [32, 16, 8]
Config(vocab_size=128, max_position_embeddings=256, num_hidden_layers=4, num_attention_heads=4, num_key_value_heads=1, hidden_size=128, head_dim=32, rms_norm_eps=1e-06)


### defining the smaller models' configs
so that we can train the smaller models of that size first and then merge them all later

In [19]:
for i in range(1, configs[0].levels):
    # Create a new Config instance for the smaller configuration
    configs[i] = dataclasses.replace(configs[i-1])

    # adjust levels of new config accordingly
    configs[i].levels = configs[i-1].levels - 1

    # Update the hidden_size and other dependent properties
    configs[i].hidden_size = configs[i-1].hidden_size // configs[i-1].split
    configs[i].head_dim = configs[i-1].head_dim // configs[i-1].split

    # Ensure hidden_size is divisible by the number of attention heads
    assert configs[i].hidden_size % configs[i].num_attention_heads == 0, "hidden_size must be divisible by num_attention_heads"

    # Ensure that the new head_dim values are valid
    assert all(hd > 1 for hd in configs[i].head_dim_list), "All head dimensions must be greater than 1"

    print(configs[i])
    print(configs[i].model_count)
    print(configs[i].model_dim_list)
    print(configs[i].head_dim_list)

Config(vocab_size=128, max_position_embeddings=256, num_hidden_layers=4, num_attention_heads=4, num_key_value_heads=1, hidden_size=64, head_dim=16, rms_norm_eps=1e-06)
[1, 2]
[64, 32]
[16, 8]
Config(vocab_size=128, max_position_embeddings=256, num_hidden_layers=4, num_attention_heads=4, num_key_value_heads=1, hidden_size=32, head_dim=8, rms_norm_eps=1e-06)
[1]
[32]
[8]


Now that i've got relevant configs for each model size, i need to redefine all my functions such that i can train the smaller models first, concatenate them, and then train the bigger models on top with the smaller models frozen. Let's start just by splitting up the dataset into relevant portions. In reality training a legit model i'd want to be using different datasets (or doing different finetunings of the same base model) but for our experiments here it'll be easier to just split up TinyShakespeare for now.

In [20]:
l = len(text)
datasets = ((text,),)
for i in range(1, configs[0].levels):
    l = l // configs[i-1].split
    datasets += (tuple([text[l*j:l*(j+1)] for j in range(configs[i-1].split**i)]),)

print(0, 0, len(datasets[0][0]), datasets[0][0][:200])
print("---------")
print(1, 0, len(datasets[1][0]), datasets[1][0][:200])
print("---------")
print(1, 1, len(datasets[1][1]), datasets[1][1][:200])
print("---------")
print(2, 0, len(datasets[2][0]), datasets[2][0][:200])
print("---------")
print(2, 1, len(datasets[2][1]), datasets[2][1][:200])
print("---------")
print(2, 2, len(datasets[2][2]), datasets[2][2][:200])
print("---------")
print(2, 3, len(datasets[2][3]), datasets[2][3][:200])

0 0 1115394 First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you
---------
1 0 557697 First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you
---------
1 1 557697  flattering truth of sleep,
My dreams presage some joyful news at hand:
My bosom's lord sits lightly in his throne;
And all this day an unaccustom'd spirit
Lifts me above the ground with cheerful thou
---------
2 0 278848 First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you
---------
2 1 278848  how to curse.

QUEEN ELIZABETH:
My words are dull; O, quicken them with thine!

QUEEN MARGARET:
Thy

In [23]:
# should the bigger models get their own full weight matrices or just use the quadrants?
class CustomWeightedLayer(nn.Module):
    def __init__(self, m, n):
        super(CustomWeightedLayer, self).__init__()
        self.m, self.n = m, n
        self.weights = nn.Parameter(torch.Tensor(m, n))
        self.reset_parameters()
        
    def reset_parameters(self):
        # Initialize weights, for example, using Xavier initialization
        nn.init.xavier_uniform_(self.weights)
        
        # Create and apply the mask
        mask = self.create_mask(self.m, self.n)
        self.weights.data *= mask

        # Apply mask to zero out gradients where the mask is zero
        self.weights.register_hook(lambda grad: grad * mask)

    def create_mask(self, m, n):
        mask = torch.ones(m, n)
        m_mid = m // 2
        n_mid = n // 2

        # Adjust the mask based on the matrix dimensions
        mask[:m_mid, :n_mid] = 0  # Upper-left quadrant
        mask[-(m - m_mid):, -(n - n_mid):] = 0  # Lower-right quadrant
        return mask

    def forward(self, x):
        return torch.mm(x, self.weights)

# Example Usage
m, n = 2, 4  # Example dimensions of the weight matrix
layer = CustomWeightedLayer(m, n)
print("Weight Matrix:\n", layer.weights)

Weight Matrix:
 Parameter containing:
tensor([[ 0.0000,  0.0000, -0.3292,  0.1937],
        [-0.9672, -0.2706, -0.0000,  0.0000]], requires_grad=True)


In [24]:
x = torch.randn(3,2)
print(x)

tensor([[-0.4619, -2.6933],
        [ 0.8262, -0.3242],
        [-0.3523,  0.4842]])


In [26]:
layer(x)

tensor([[ 2.6050,  0.7288,  0.1521, -0.0895],
        [ 0.3136,  0.0877, -0.2720,  0.1601],
        [-0.4683, -0.1310,  0.1160, -0.0683]], grad_fn=<MmBackward0>)

# Multi-Layer Perceptron


<p align="center">
<img src="./images/ffwd.jpeg" width="512"/>
</p>

In [8]:
class MLP(nn.Module):
    """
    This class implements a multi-layer perceptron with a GeGLU gating mechanism. The GeGLU
    activation combines a standard GeLU activation with a learned gating mechanism, enabling
    the network to control the flow of information more dynamically.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        dropout: float = 0.1,
    ):
        """
        Initializes the GemmaMLP module.

        Parameters:
            hidden_size (int): The size of the input and output tensors.
            intermediate_size (int): The size of the tensor after the initial transformation
                                     and before the gating and final projection. This is typically
                                     larger than the hidden size to allow for a richer representation.
            dropout (float): the dropout rate to use during training in forwardTuple()
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        assert intermediate_size % hidden_size == 0
        self.intermediate_multiplier = intermediate_size // hidden_size

        # Linear transformation for the gating mechanism, projecting input to an intermediate size.
        self.Wgate = nn.Parameter(torch.Tensor(hidden_size, intermediate_size))
        self.Bgate = nn.Parameter(torch.Tensor(intermediate_size))

        # Linear transformation for the input tensor, also projecting to the intermediate size but
        # intended for element-wise multiplication with the gated output.
        self.Wup = nn.Parameter(torch.Tensor(hidden_size, intermediate_size))
        self.Bup = nn.Parameter(torch.Tensor(intermediate_size))

        # Linear transformation to project the gated and combined tensor back to the original
        # hidden size, completing the MLP structure.
        self.Wdown = nn.Parameter(torch.Tensor(intermediate_size, hidden_size))
        self.Bdown = nn.Parameter(torch.Tensor(hidden_size))

        # Initialize weights with uniform distribution
        # For gate & up, where in_features is hidden_size
        limit_gateup = 1 / np.sqrt(hidden_size)
        nn.init.uniform_(self.Wgate, -limit_gateup, limit_gateup)
        nn.init.uniform_(self.Bgate, -limit_gateup, limit_gateup)
        nn.init.uniform_(self.Wup, -limit_gateup, limit_gateup)
        nn.init.uniform_(self.Bup, -limit_gateup, limit_gateup)
        
        # For down, where in_features is intermediate_size
        limit_down = 1 / np.sqrt(intermediate_size)
        nn.init.uniform_(self.Wdown, -limit_down, limit_down)
        nn.init.uniform_(self.Bdown, -limit_down, limit_down)
        
        # defining our dropout for training in forwardTuple()
        self.drop = nn.Dropout(dropout)

    def forwardTensor(self, x, model:int=0):
        """
        Defines the forward pass of the MLP module during inference.

        Parameters:
            x (Tensor): The input tensor to the MLP. 
                        shape (batch size, sequence length, hidden dimension) where hidden dimension changes by which model was used
            model (int): the indicator of which model we're using. 
                        used in calculating our skip length for splicing. 
                        defaults to the equivalent of what's used in MatFormer+, meaning no skip, aka we use the top-left-most splice

        Returns:
            Tensor: The output tensor after applying the GeGLU gating mechanism and the MLP transformations.
        """
        global verbose
        if verbose: 
            print("------------- MLP.forwardTensor() ------------")
            print(f"x: {x.shape}\n{x}")
            
        # figuring out how we should do our splicing
        d_dim = x.shape[-1]
        d_skip = model * d_dim
        i_dim = d_dim * self.intermediate_multiplier
        i_skip = model * i_dim
        if verbose: 
            print(f"d_dim: {d_dim}")
            print(f"d_skip: {d_skip}")
            print(f"i_dim: {i_dim}")
            print(f"i_skip: {i_skip}")
        
        # Applies linear transformation for gating.
        Wgate = self.Wgate[d_skip:d_skip + d_dim, i_skip:i_skip + i_dim]
        Bgate = self.Bgate[i_skip:i_skip + i_dim]
        Xgate = x @ Wgate + Bgate
        if verbose: 
            print(f"Wgate: {self.Wgate.shape}\n{self.Wgate}")
            print(f"Wgate spliced: {Wgate.shape}\n{Wgate}")
            print(f"Bgate: {self.Bgate.shape}\n{self.Bgate}")
            print(f"Bgate spliced: {Bgate.shape}\n{Bgate}")
            print(f"Xgate: {Xgate.shape}\n{Xgate}")

        # Applies GeLU activation to the gate, introducing non-linearity and enabling the gating mechanism.
        Xgate = F.gelu(Xgate)
        if verbose: print(f"GeLU'ed Xgate: {Xgate.shape}\n{Xgate}")

        # Applies another linear transformation to the input tensor for subsequent combination with the gate.
        Wup = self.Wup[d_skip:d_skip + d_dim, i_skip:i_skip + i_dim]
        Bup = self.Bup[i_skip:i_skip + i_dim]
        Xup = x @ Wup + Bup
        if verbose: 
            print(f"Wup: {self.Wup.shape}\n{self.Wup}")
            print(f"Wup spliced: {Wup.shape}\n{Wup}")
            print(f"Bup: {self.Bup.shape}\n{self.Bup}")
            print(f"Bup spliced: {Bup.shape}\n{Bup}")
            print(f"Xup: {Xup.shape}\n{Xup}")

        # Element-wise multiplication of the gated tensor with the transformed input tensor, modulating
        # the input based on the gate's activation.
        Xfuse = Xgate * Xup
        if verbose: print(f"Xfuse: {Xfuse.shape}\n{Xfuse}")

        # Applies the final linear transformation to project the modulated tensor back to the hidden size.
        Wdown = self.Wdown[i_skip:i_skip + i_dim, d_skip:d_skip + d_dim]
        Bdown = self.Bdown[d_skip:d_skip + d_dim]
        outputs = Xfuse @ Wdown + Bdown
        if verbose: 
            print(f"Wdown: {self.Wdown.shape}\n{self.Wdown}")
            print(f"Wdown spliced: {Wdown.shape}\n{Wdown}")
            print(f"Bdown: {self.Bdown.shape}\n{self.Bdown}")
            print(f"Bdown spliced: {Bdown.shape}\n{Bdown}")
            print(f"outputs: {outputs.shape}\n{outputs}") 
            print("------------- END MLP.forwardTensor() ------------")

        # Returns the final output tensor of the MLP, after gating and modulation.
        return outputs

    def forwardTuple(self, x, drop_bool: bool = True):
        """
        Defines the forward pass of the MLP module during training.

        Parameters:
            x (Tuple[Tuple[Tensor]]): 
                The input tuple of tuples of tensors to the MLP. 
                first tuple is of length config.levels and second layer of tuples have lengths of config.model_count
                tensors are shape (batch size, sequence length, hidden dimension) where hidden dimension changes by which model was used

        Returns:
            Tuple[Tuple[Tensor]]: 
                The output tuple of tuples of tensors after applying the GeGLU gating mechanism and the MLP transformations.
        """
        global verbose
        if verbose: 
            print("------------- MLP.forwardTuple() ------------")
            print(f"x: {x}")

        # if we had sent through the config we could've just grabbed these values from there but too late now
        num_levels = len(x)
        models_per_level = [len(x[i]) for i in range(num_levels)]
        if verbose: 
            print(f"num_levels: {num_levels}")
            print(f"models_per_level: {models_per_level}")
        
        out = ()
        for i in range(num_levels):
            if verbose: print(f"i: {i}")
            
            out_lvl = ()
            for j in range(models_per_level[i]):
                if verbose: print(f"j: {j}")

                output = self.forwardTensor(x[i][j], model=j)
                if verbose: print(f"forwardTensor() output: {output.shape}\n{output}")
                    
                out_lvl += (self.drop(output),) if drop_bool else (output,)

            # pretty sure i have to save & store everything without overwriting to prevent in-place arguments. so annoying
            if verbose: print(f"out_lvl: {out_lvl}")
            out += (out_lvl,)
        
        if verbose:
            print(f"out: {out}")
            print("------------- END MLP.forwardTuple() ------------")
        return out
        
    def forward(self, x, model=0, drop_bool = True):
        train = True if type(x) == tuple else False
        if verbose: print(f"---------- MLP Input: {'Tuple' if train else 'torch.Tensor'} ------------")
        return self.forwardTuple(x, drop_bool) if train else self.forwardTensor(x, model)

The following two cells are designed to help you comprehend what's happening. If you walk through every single print statement and follow along even down to watching what happens to each weight, you'll be able to clearly see what's happening with the odd splicing behavior. In order to make this somewhat feasible, I've set very small matrices for these examples. However I will admit it is still inevitably a pain, which is why I included the drawings above.

In [9]:
# Testing our MLP's forwardTensor()
verbose = True
print("--------- Micro Hyperparameters -------")
hold = config.hidden_size
config.hidden_size = 4
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the big model |-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,2,4)
print(f"x: {x.shape}\n{x}")
mlp = MLP(4,8)
y = mlp(x)
print(f"y: {y.shape}\n{y}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the first sub-model |-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,2,2)
print(f"x: {x.shape}\n{x}")
mlp = MLP(4,8)
y = mlp(x)
print(f"y: {y.shape}\n{y}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the second sub-model |-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,2,2)
print(f"x: {x.shape}\n{x}")
mlp = MLP(4,8)
y = mlp(x, model=1)
print(f"y: {y.shape}\n{y}")

verbose = False
print("---------- RESET CONFIG --------")
config.hidden_size = hold
print("model_count: ", config.model_count)

# clear up memory
del hold, x, y, mlp

--------- Micro Hyperparameters -------
model_count:  [1, 2, 4]
model_dim_list:  [4, 2, 1]
|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the big model |-|-|-|-|-|-|-|-|-|-|-|-
x: torch.Size([1, 2, 4])
tensor([[[0.9590, 0.9053, 0.5484, 0.0616],
         [0.3631, 0.7774, 0.6103, 0.4077]]])
---------- MLP Input: torch.Tensor ------------
------------- MLP.forwardTensor() ------------
x: torch.Size([1, 2, 4])
tensor([[[0.9590, 0.9053, 0.5484, 0.0616],
         [0.3631, 0.7774, 0.6103, 0.4077]]])
d_dim: 4
d_skip: 0
i_dim: 8
i_skip: 0
Wgate: torch.Size([4, 8])
Parameter containing:
tensor([[-0.3171,  0.0909,  0.0684, -0.1795, -0.0728, -0.2897,  0.2435,  0.3105],
        [-0.3894, -0.1471,  0.2837,  0.4253,  0.4213,  0.4507,  0.3852,  0.1857],
        [-0.4413,  0.4736,  0.2402, -0.0791,  0.2685, -0.2349,  0.1284,  0.1757],
        [-0.3339,  0.3974,  0.1393, -0.0508,  0.0225, -0.4128, -0.1585, -0.2609]],
       requires_grad=True)
Wgate spliced: torch.Size([4, 8])
tensor([[-0.3171,  0.090

In [10]:
# Testing our MLP's forwardTuple()
verbose = True
print("--------- Micro Hyperparameters -------")
hold1, hold2 = config.hidden_size, config.levels
config.hidden_size = 4
config.levels = 2
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)

mlp = MLP(4,8)
x = ((torch.randn((1,2,4)),),
     (torch.randn((1,2,2)),torch.randn((1,2,2)))
    )
print(f"x: {x}")
out = mlp(x)
print(f"out: {out}")

verbose = False
print("---------- RESET CONFIG --------")
config.hidden_size = hold1
config.levels = hold2
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)

# clear up memory
del hold1, hold2, x, out, mlp

--------- Micro Hyperparameters -------
model_count:  [1, 2]
model_dim_list:  [4, 2]
x: ((tensor([[[-1.2954, -0.2177,  0.5420,  0.8073],
         [ 0.7121,  0.8843,  0.1992,  1.4276]]]),), (tensor([[[ 0.5660,  0.3989],
         [ 0.2794, -0.2898]]]), tensor([[[-0.7002, -1.2412],
         [-0.5178, -0.3491]]])))
---------- MLP Input: Tuple ------------
------------- MLP.forwardTuple() ------------
x: ((tensor([[[-1.2954, -0.2177,  0.5420,  0.8073],
         [ 0.7121,  0.8843,  0.1992,  1.4276]]]),), (tensor([[[ 0.5660,  0.3989],
         [ 0.2794, -0.2898]]]), tensor([[[-0.7002, -1.2412],
         [-0.5178, -0.3491]]])))
num_levels: 2
models_per_level: [1, 2]
i: 0
j: 0
------------- MLP.forwardTensor() ------------
x: torch.Size([1, 2, 4])
tensor([[[-1.2954, -0.2177,  0.5420,  0.8073],
         [ 0.7121,  0.8843,  0.1992,  1.4276]]])
d_dim: 4
d_skip: 0
i_dim: 8
i_skip: 0
Wgate: torch.Size([4, 8])
Parameter containing:
tensor([[-0.0606, -0.1885,  0.0784, -0.0998, -0.4402, -0.0970,  0.095

# Attention

To subset the attention heads, we have to not only splice according to the model's embedding dimension but also take into account new smaller head sizes and how they're spaced throughout the matrix. I'm assuming you know how self-attention works well enough to look at this weight matrix and get the idea

<p align="center">
<img src="./images/sa.jpeg" width="512"/>
</p>

then we've gotta concatenate the outputs of each head

<p align="center">
<img src="./images/mha_concat.jpeg" width="512"/>
</p>

and after that linearly project them

<p align="center">
<img src="./images/mha_proj.jpeg" width="512"/>
</p>

this is the place where our splicing gets conceptually annoying. instead of just grabbing the matrix in the upper corner, because of the way attention head output concatenation works we actually need to skip over certain parts of the linear projection matrix and then concatenate them together in order to use them. Here's an example of what the matrix multiplication looks like. on the left is a simplified version of the concatenated attention heads where i just showed it as a matrix rather than a tensor, and then on the right is the actual projection matrix. notice how the numbers in the pink output matrix look similar to the first column of the purple output matrix with a positive number, its negative, and then a smaller positive number; that's the self-similarity in action. the yellow arrows point to the parts that get skipped over. obviously this would look a lot uglier with bigger matrices & incorporating the blue/green layer

<p align="center">
<img src="./images/mha_proj_matmul.jpeg" width="512"/>
</p>

In [11]:
class MultiQueryAttention(nn.Module):
    """
    Implements Multi-Query Attention which supports a distinct number of attention heads for queries and key-values (KV).
    In the case where the same number of queries and key-values are used, this implemenation is equivalent to regular Multi-Head Attention.  
    """
    
    def __init__(self, config: Config):
        super().__init__()

        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        
        # Determines the number of query heads associated with each KV head.
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.hidden_size = config.hidden_size
        self.head_dim = config.head_dim
        self.theta = config.rope_theta

        # Calculates the total size for all query projections.
        self.q_size = self.num_heads * self.head_dim
        # Calculates the total size for all key and value projections.
        self.kv_size = self.num_kv_heads * self.head_dim
        
        # Initialize our learnable matrices
        # the linear projection layer for queries, keys, and values
        # no real reason why we're creating one matrix instead of separate ones. cleaner model summary view?
        self.Wqkv = nn.Parameter(torch.Tensor(self.hidden_size,
                                              (self.num_heads + 2 * self.num_kv_heads) * self.head_dim))
        # the output projection layer, mapping the concatenated attention outputs back to the hidden size.
        self.Wo = nn.Parameter(torch.Tensor(self.num_heads * self.head_dim, self.hidden_size))
        
        # Initialize weights with uniform distribution
        # For qkv_proj, where in_features is hidden_size
        limit_Wqkv = 1 / np.sqrt(self.hidden_size)
        nn.init.uniform_(self.Wqkv, -limit_Wqkv, limit_Wqkv)
        # for o_proj, where in_features is self.num_heads * self.head_dim
        limit_Wo = 1 / np.sqrt(self.num_heads * self.head_dim)
        nn.init.uniform_(self.Wo, -limit_Wo, limit_Wo)
        
        # for our attention mask we'll use very large negative values to prevent attending to certain tokens
        mask_negatives = torch.full((1, 1, config.max_position_embeddings, config.max_position_embeddings),
                                 -2.3819763e38).to(torch.float)
        # then we'll replace the lower triangular ones with 0's to allow attention to see past tokens
        mask = torch.triu(mask_negatives, diagonal=1).to(config.device)
        # to define self.mask as a tensor that shouldn't undergo gradient descent
        self.register_buffer('mask', mask)
        
        # defining our dropout
        self.drop = nn.Dropout(config.dropout)

    def forwardTensor(self,
                      x: torch.Tensor,
                      model: int = 0,
                     ) -> torch.Tensor:
        """
        Inputs:
            x (torch.Tensor): Te input tensor to the attention mechanism.
                        shape (batch_size, input_len, hidden_size)
            model (int): the indicator of which model we're using. 
                        used in calculating our skip length for splicing. 
                        defaults to the equivalent of what's used in MatFormer+, meaning no skip, aka we use the top-left-most splice
        
        Returns:
            Tensor: The output tensor after applying the attention mechanism
        """
        global verbose
        if verbose: print("----------------- MultiQueryAttention.forwardTensor() --------------------")
        
        # Ensures the input tensor is 3-dimensional (batch_size, input_len, hidden_size).
        x_shape = x.shape
        assert len(x_shape) == 3
        if verbose: print(f"x shape: {x_shape}")

        # Extracts input sequence length and embedding dimension length from the hidden states tensor.
        batch_size, input_len, d_dim = x_shape
        
        # figuring out how we should do our splicing
        # first along the embedding dimension
        d_skip = model * d_dim  # the size of our skip along the model's embedding dimension
        if verbose: print(f"d_skip: {d_skip}")
        
        # then for splicing along the head sizes dimension
        index = config.model_dim_list.index(d_dim)
        models_in_this_level = config.model_count[index] # how many models are in this level
        h_dim = config.head_dim_list[index] # the head dimension size of this model in this level
        h_skip = model * h_dim # the size of our skip along the head dimension
        if verbose: 
            print(f"models_in_this_level: {models_in_this_level}")
            print(f"h_dim: {h_dim}")
            print(f"h_skip: {h_skip}")

        # Splits the Wqkv tensor into separate tensors for queries, keys, and values based on their respective sizes.
        if verbose: print(f"self.Wqkv: {self.Wqkv.shape}\n{self.Wqkv}")
        Wq, Wk, Wv = self.Wqkv.split([self.q_size,
                                      self.kv_size,
                                      self.kv_size],dim=-1)
        if verbose: 
            print(f"Wq: {Wq.shape}\n{Wq}")
            print(f"Wk: {Wk.shape}\n{Wk}")
            print(f"Wv: {Wv.shape}\n{Wv}")
        
        # splicing to get our correct weight matrices for each respective head
        # d_dim is relatively self-explanatory
        # i*self.head_dim is bc we initialized one single q, k, and v matrix for all heads so we have to
        # iterate through said matrix to get to the correct head
        Wq = torch.cat([Wq[d_skip:d_skip + d_dim,\
                               i*self.head_dim + h_skip:i*self.head_dim + h_skip + h_dim] \
                               for i in range(self.num_heads)], dim=1)
        Wk = torch.cat([Wk[d_skip:d_skip + d_dim,\
                               i*self.head_dim + h_skip:i*self.head_dim + h_skip + h_dim] \
                               for i in range(self.num_kv_heads)], dim=1)
        Wv = torch.cat([Wv[d_skip:d_skip + d_dim,\
                               i*self.head_dim + h_skip:i*self.head_dim + h_skip + h_dim] \
                               for i in range(self.num_kv_heads)], dim=1)
        if verbose:
            print(f"Wq spliced: {Wq.shape}\n{Wq}")
            print(f"Wk spliced: {Wk.shape}\n{Wk}")
            print(f"Wv spliced: {Wv.shape}\n{Wv}")
        
        # this needs to be size (d_dim, (self.num_heads + 2 * self.num_kv_heads) * h_dim) aka (32,24)
        # recombine the spliced Wq Wk and Wv. Now they're the right size for matmul against x
        Wqkv_spliced = torch.cat((Wq, Wk, Wv), dim=-1)
        if verbose:
            print(f"Wqkv_spliced: {Wqkv_spliced.shape}\n{Wqkv_spliced}")
        

        # finally we can project x to get our queries, keys and values
        xqkv = x @ Wqkv_spliced
        if verbose: print(f"xqkv: {xqkv.shape}\n{xqkv}")
            
        # Splits the combined Xqkv tensor into separate tensors for queries (xq), keys (xk), and values (xv) based on their respective sizes.
        xq, xk, xv = xqkv.split([self.q_size // models_in_this_level,
                                 self.kv_size // models_in_this_level,
                                 self.kv_size // models_in_this_level],dim=-1)
        if verbose:
            print(f"xq: {xq.shape}\n{xq}")
            print(f"xk: {xk.shape}\n{xk}")
            print(f"xv: {xv.shape}\n{xv}")

        # Reshapes each of the Q, K, and V tensors to separate the heads and align the dimensions for attention operations.
        xq = xq.view(batch_size, input_len, self.num_heads, h_dim)#, self.head_dim)
        xk = xk.view(batch_size, input_len, self.num_kv_heads, h_dim)#, self.head_dim)
        xv = xv.view(batch_size, input_len, self.num_kv_heads, h_dim)#, self.head_dim)
        if verbose:
            print(f"xq reshaped: {xq.shape}\n{xq}")
            print(f"xk reshaped: {xk.shape}\n{xk}")
            print(f"xv reshaped: {xv.shape}\n{xv}")

        # Applies rotary positional embeddings to queries and keys to incorporate positional information.
        xq = apply_rotary_emb(xq, h_dim, self.theta)#self.head_dim
        xk = apply_rotary_emb(xk, h_dim, self.theta)#self.head_dim
        # is the differring head dimension going to mess with RoPE? Not sure
        if verbose:
            print(f"rotated xq: {xq.shape}\n{xq}")
            print(f"rotated xk: {xk.shape}\n{xk}")

        # If the number of KV heads is different from the number of query heads, adjusts keys and values to match the query heads count.
        if self.num_kv_heads != self.num_heads:
            # [batch_size, input_len, n_local_heads, head_dim]
            xk = torch.repeat_interleave(xk, self.num_queries_per_kv, dim=2)
            xv = torch.repeat_interleave(xv, self.num_queries_per_kv, dim=2)
            if verbose:
                print(f"repeat_interleaved xk: {xk.shape}\n{xk}")
                print(f"repeat_interleaved xv: {xv.shape}\n{xv}")

        # Transposes Q, K, and V tensors to align them for the batch matrix multiplication in attention calculation.
        # [batch_size, n_local_heads, input_len, head_dim]
        q = xq.transpose(1, 2)
        k = xk.transpose(1, 2)
        v = xv.transpose(1, 2)
        if verbose:
            print(f"transposed xq: {q.shape}\n{q}")
            print(f"transposed xk: {k.shape}\n{k}")
            print(f"transposed xv: {v.shape}\n{v}")

        # Calculates attention scores by performing a batch matrix multiplication between queries and keys, followed by scaling.
        # [batch_size, n_local_heads, input_len, input_len]
        scores = torch.matmul(q, k.transpose(2, 3)) * h_dim**-0.5#self.scaling
        if verbose: print(f"scores: {scores.shape}\n{scores}")
        
        # Applies the lower-triangular mask to the attention scores
        if verbose: print(f"mask: {self.mask[...,:input_len, :input_len].shape}\n{self.mask[...,:input_len, :input_len]}")
        scores = scores + self.mask[...,:input_len, :input_len] # make sure mask is the correct size. input_len <= max_seq_len
        if verbose: print(f"masked scores: {scores.shape}\n{scores}")

        # Applies softmax to the scores to obtain attention probabilities
        scores = F.softmax(scores, dim=-1)
        if verbose: print(f"softmaxed scores: {scores.shape}\n{scores}")
        
        # Computes the weighted sum of values based on the attention scores to obtain the output of the attention mechanism.
        # [batch_size, n_local_heads, input_len, head_dim]
        attention = torch.matmul(scores, v)
        if verbose: print(f"attention: {attention.shape}\n{attention}")

        # Reshapes the attention output to match the expected output dimensions, combining the heads back into the hidden dimension.
        # [batch_size, input_len, hidden_dim]
        attention = attention.transpose(1, 2).contiguous().view(batch_size, input_len, -1)
        if verbose: print(f"reshaped attention: {attention.shape}\n{attention}")

        # Splice the output projection
        Wo = torch.cat([self.Wo[i*self.head_dim + h_skip:i*self.head_dim + h_skip + h_dim,\
                                d_skip:d_skip + d_dim,\
                               ] for i in range(self.num_heads)], dim=0)
        if verbose: 
            print(f"self.Wo: {self.Wo.shape}\n{self.Wo}")
            print(f"spliced Wo: {Wo.shape}\n{Wo}")
            
        # Applies the final linear projection to the attention output, mapping it back to the hidden size dimension.
        output = attention @ Wo
        if verbose: 
            print(f"projected output: {output.shape}\n{output}")
            print("----------------- END MultiQueryAttention.forwardTensor() --------------------")
            
        return output

    def forwardTuple(self,
                     x: Tuple[Tuple[torch.Tensor]],
                     drop_bool: bool = True
                    ) -> torch.Tensor:
        """
        Defines the forward pass of the Attention module during training.

        Parameters:
            x (Tuple[Tuple[Tensor]]): 
                The input tuple of tuples of tensors 
                first tuple is of length config.levels and second layer of tuples have lengths of config.model_count
                tensors are shape (batch size, sequence length, hidden dimension) where hidden dimension changes by which model was used

        Returns:
            Tuple[Tuple[Tensor]]: 
                The output tuple of tuples of tensors after applying the MQA mechanism
        """
        global verbose
        if verbose: 
            print("------------- MultiQueryAttention.forwardTuple() ------------")
            print(f"x: {x}")
            
        # forwardTuple() should only be used during training, so we assert input_len == max_position_embeddings
        input_len = x[0][0].shape[1]
        if verbose: print(f"input_len: {input_len}")
        assert input_len == config.max_position_embeddings

        # we could define these from the config but this way the method is more flexible to testing
        num_levels = len(x)
        models_per_level = [len(x[i]) for i in range(num_levels)]
        if verbose: 
            print(f"num_levels: {num_levels}")
            print(f"models_per_level: {models_per_level}")

        # the loop that iterates over levels, aka the different potential sizes of models
        out = ()
        for i in range(num_levels):
            if verbose: print(f"Level {i} from range({num_levels})")

            # now for the loop that iterates over models in this level
            out_lvl = ()
            for j in range(models_per_level[i]):
                if verbose: print(f"Model {j} from range({models_per_level[i]})")

                output = self.forwardTensor(x[i][j], model=j)
                if verbose: print(f"forwardTensor() output: {output.shape}\n{output}")
                
                out_lvl += (self.drop(output),) if drop_bool else (output,)
            
            out += (out_lvl,)
        
        if verbose:
            print(f"final output: {out}")
            print("------------- END MultiQueryAttention.forwardTuple() ------------")

        return out
        
    def forward(self, x, model=0, drop_bool = True):
        train = True if type(x) == tuple else False
        if verbose: print(f"---------- Attention Input: {'Tuple' if train else 'torch.Tensor'} ------------")
        return self.forwardTuple(x, drop_bool) if train else self.forwardTensor(x, model)

And here are the detailed print statements for the attention mechanism

In [12]:
# Testing our Attention's forwardTensor()
verbose = True

print("--------- Micro Hyperparameters -------")
hold1, hold2, hold3, hold4 = config.hidden_size, config.num_attention_heads, config.head_dim, config.max_position_embeddings
config.hidden_size = 8
config.num_attention_heads = 2
config.head_dim = 4
config.max_position_embeddings = 3
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)
print("head_dim_list: ", config.head_dim_list)

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the big model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,3,8)
print(f"x: {x.shape}\n{x}")
att = MultiQueryAttention(config)
y = att(x)
print(f"y: {y.shape}\n{y}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the first sub-model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,3,4)
print(f"x: {x.shape}\n{x}")
att = MultiQueryAttention(config)
y = att(x)
print(f"y: {y.shape}\n{y}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the second sub-model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,3,4)
print(f"x: {x.shape}\n{x}")
att = MultiQueryAttention(config)
y = att(x, model=1)
print(f"y: {y.shape}\n{y}")

verbose = False
print("---------- RESET CONFIG --------")
config.hidden_size = hold1
config.num_attention_heads = hold2
config.head_dim = hold3
config.max_position_embeddings = hold4
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)
print("head_dim_list: ", config.head_dim_list)

# clear up memory
del hold1, hold2, hold3, hold4, x, att, y

--------- Micro Hyperparameters -------
model_count:  [1, 2, 4]
model_dim_list:  [8, 4, 2]
head_dim_list:  [4, 2, 1]
|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the big model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-
x: torch.Size([1, 3, 8])
tensor([[[0.7055, 0.8466, 0.5308, 0.1871, 0.2744, 0.8346, 0.3970, 0.4018],
         [0.1970, 0.3540, 0.4749, 0.4411, 0.6331, 0.9159, 0.9953, 0.9859],
         [0.5595, 0.4854, 0.6228, 0.1502, 0.8666, 0.9870, 0.2195, 0.0108]]])
---------- Attention Input: torch.Tensor ------------
----------------- MultiQueryAttention.forwardTensor() --------------------
x shape: torch.Size([1, 3, 8])
d_skip: 0
models_in_this_level: 1
h_dim: 4
h_skip: 0
self.Wqkv: torch.Size([8, 16])
Parameter containing:
tensor([[ 0.0084,  0.0866, -0.2544,  0.0649, -0.0546,  0.0121, -0.1667, -0.0016,
         -0.3344, -0.3115, -0.0106,  0.3216, -0.2331,  0.0319,  0.0506,  0.3433],
        [-0.2864,  0.2378,  0.3431, -0.3004,  0.0780,  0.3157, -0.1521,  0.1160,
         

In [13]:
# Testing our Attention's forwardTuple()
verbose = True

print("--------- Micro Hyperparameters -------")
hold1, hold2, hold3, hold4, hold5 = config.hidden_size, config.num_attention_heads, config.head_dim, config.levels, config.max_position_embeddings
config.hidden_size = 8
config.num_attention_heads = 2
config.head_dim = 4
config.levels = 2
config.max_position_embeddings = 3
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)
print("head_dim_list: ", config.head_dim_list)

att = MultiQueryAttention(config)
# we need to make sure to send in a tuple of the expected size. above we set hidden_size=4 and levels=2
x = ((torch.randn((1,3,8)),),(torch.randn((1,3,4)),torch.randn((1,3,4))))
print(f"x: {x}")
out = att(x)
print(f"out: {out}")

verbose = False
print("---------- RESET CONFIG --------")
config.hidden_size = hold1
config.num_attention_heads = hold2
config.head_dim = hold3
config.levels = hold4
config.max_position_embeddings = hold5
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)
print("head_dim_list: ", config.head_dim_list)

# clear up memory
del hold1, hold2, hold3, hold4, hold5, x, att, out

--------- Micro Hyperparameters -------
model_count:  [1, 2]
model_dim_list:  [8, 4]
head_dim_list:  [4, 2]
x: ((tensor([[[ 1.2369,  0.7309,  2.6780, -1.0167, -0.7010,  0.2938, -0.0129,
           0.9855],
         [ 0.9342,  0.2813, -1.8521,  1.3869, -0.4117, -0.8592, -0.1207,
           1.4316],
         [ 0.2269,  1.2348,  0.3012,  1.9918,  0.5064,  0.1692, -0.0296,
          -0.5073]]]),), (tensor([[[-0.2603, -0.4997, -0.3798,  1.5055],
         [ 0.4200, -0.1175,  0.0111, -2.7566],
         [ 0.5740, -0.1371, -0.4050,  0.2510]]]), tensor([[[ 0.2608, -2.1316, -0.3448,  1.6580],
         [-1.6757, -1.1787, -0.4115,  0.2764],
         [ 0.9920,  0.2771, -0.2763, -0.0181]]])))
---------- Attention Input: Tuple ------------
------------- MultiQueryAttention.forwardTuple() ------------
x: ((tensor([[[ 1.2369,  0.7309,  2.6780, -1.0167, -0.7010,  0.2938, -0.0129,
           0.9855],
         [ 0.9342,  0.2813, -1.8521,  1.3869, -0.4117, -0.8592, -0.1207,
           1.4316],
         [ 0.

Wk spliced: torch.Size([4, 2])
tensor([[ 0.2420, -0.1638],
        [ 0.0628, -0.0496],
        [-0.0106,  0.1647],
        [ 0.0380,  0.1159]], grad_fn=<CatBackward0>)
Wv spliced: torch.Size([4, 2])
tensor([[ 0.2726,  0.0764],
        [-0.2457,  0.3075],
        [ 0.0943,  0.0897],
        [ 0.3341, -0.1543]], grad_fn=<CatBackward0>)
Wqkv_spliced: torch.Size([4, 8])
tensor([[ 0.2527, -0.2941, -0.0427, -0.3478,  0.2420, -0.1638,  0.2726,  0.0764],
        [ 0.3438,  0.2528, -0.0657, -0.0009,  0.0628, -0.0496, -0.2457,  0.3075],
        [ 0.2754,  0.0175,  0.0330,  0.1987, -0.0106,  0.1647,  0.0943,  0.0897],
        [-0.3143,  0.0516, -0.2421,  0.1039,  0.0380,  0.1159,  0.3341, -0.1543]],
       grad_fn=<CatBackward0>)
xqkv: torch.Size([1, 3, 8])
tensor([[[-0.8154,  0.0212, -0.3331,  0.1719, -0.0332,  0.1793,  0.5189,
          -0.4398],
         [ 0.9353, -0.2952,  0.6575, -0.4301, -0.0105, -0.3805, -0.7765,
           0.4222],
         [-0.0926, -0.1976, -0.0896, -0.2539,  0.1441, -0

# Layer

nothing too interesting here besides the absurd amount of memory we're probably taking up with these tuples

In [14]:
class Layer(nn.Module):
    """
    A decoder layer that integrates the MultiQueryAttention and MLP. It includes
    normalization steps both before and after the attention mechanism to stabilize and accelerate training.
    """

    def __init__(self, config: Config):
        super().__init__()

        # Initializes the GemmaAttention mechanism with parameters from the config, enabling self-attention within the decoder layer.
        self.self_attn = MultiQueryAttention(config)
        
        # Initializes the GemmaMLP module, providing a non-linear transformation after the attention mechanism.
        self.mlp = MLP(
            # the hidden dimension of the model
            hidden_size = config.hidden_size,
            # the number of nodes in the center of the two feedforward layers
            intermediate_size = config.intermediate_size,
            # the % of neurons to set to 0 during training
            dropout = config.dropout,
        )
        
        # Applies RMSNorm normalization to the input of the decoder layer for stable training dynamics.
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps = config.rms_norm_eps)
        
        # Applies RMSNorm after the attention mechanism and before the MLP to ensure the output is well-conditioned for further processing.
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps = config.rms_norm_eps)

    def forwardTensor(self,
                # The input tensor to the decoder layer. shape (batch_size, input_len, hidden_size)
                x: torch.Tensor,
                model: int = 0,
                drop_bool: bool = False
                ) -> torch.Tensor:
        global verbose
        if verbose: print("----------------- Layer.forwardTensor() --------------------")
        
        # Self Attention Block
        # Stores the original input for use as a residual connection, aiding in mitigating the vanishing gradient problem
        residual_connection = x
        # Normalizes the input before processing by the attention mechanism.
        x = self.input_layernorm(x, model)
        # Processes the normalized input through the GemmaAttention mechanism
        x = self.self_attn(x, model, drop_bool)
        # The aforementioned residual connection
        x = residual_connection + x
        if verbose: print(f"x in layer after MQA & resid connection and before MLP:\n{x}")

        # MLP Block
        # Again, stores the output of the attention block for use as a residual connection before processing by the MLP.
        residual_connection = x
        # Normalizes the output of the attention block before passing it to the MLP, ensuring a stable input distribution.
        x = self.post_attention_layernorm(x, model)
        # Transforms the normalized attention output through the MLP, introducing additional non-linearity and capacity to the model.
        x = self.mlp(x, model, drop_bool)
        # Another residual connection
        x = residual_connection + x
        if verbose: 
            print(f"layer's final residual state:\n{x}")
            print("----------------- END Layer.forwardTensor() --------------------")

        return x

    def forwardTuple(self,
                     x: Tuple[Tuple[torch.Tensor]],
                    ) -> torch.Tensor:
        """
        Defines the forward pass of a decoder layer during training.

        Parameters:
            x (Tuple[Tuple[Tensor]]): 
                The input tuple of tuples of tensors 
                first tuple is of length config.levels and second layer of tuples have lengths of config.model_count
                tensors are shape (batch size, sequence length, hidden dimension) where hidden dimension changes by which model was used

        Returns:
            Tuple[Tuple[Tensor]]: 
                The output tuple of tuples of tensors after applying the decoder layer
        """
        global verbose
        if verbose: 
            print("------------- Layer.forwardTuple() ------------")
            print(f"x:\n{x}")
            
        # forwardTuple() should only be used during training, so we assert input_len == max_position_embeddings
        input_len = x[0][0].shape[1]
        if verbose: print(f"input_len: {input_len}")
        assert input_len == config.max_position_embeddings

        # we could define these from the config but this way the method is more flexible to testing
        num_levels = len(x)
        models_per_level = [len(x[i]) for i in range(num_levels)]
        if verbose: 
            print(f"num_levels: {num_levels}")
            print(f"models_per_level: {models_per_level}")

        # the loop that iterates over levels, aka the different potential sizes of models
        out = ()
        for i in range(num_levels):
            if verbose: print(f"Level {i} from range({num_levels})")

            # now for the loop that iterates over models in this level
            out_lvl = ()
            for j in range(models_per_level[i]):
                if verbose: print(f"Model {j} from range({models_per_level[i]})")

                output = self.forwardTensor(x[i][j], model = j, drop_bool = True)
                if verbose: print(f"forwardTensor() output: {output.shape}\n{output}")
                
                out_lvl += (output,)
            
            out += (out_lvl,)
        
        if verbose:
            print(f"final output: {out}")
            print("------------- END Layer.forwardTuple() ------------")

        return out
        
    def forward(self, x, model=0):
        train = True if type(x) == tuple else False
        if verbose: print(f"---------- Layer Input: {'Tuple' if train else 'torch.Tensor'} ------------")
        return self.forwardTuple(x) if train else self.forwardTensor(x, model)

In [15]:
# Testing our Layer's forwardTensor()
verbose = True

print("--------- Micro Hyperparameters -------")
hold1, hold2, hold3, hold4 = config.hidden_size, config.num_attention_heads, config.head_dim, config.max_position_embeddings
config.hidden_size = 8
config.num_attention_heads = 2
config.head_dim = 4
config.max_position_embeddings = 3
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)
print("head_dim_list: ", config.head_dim_list)

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the big model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,3,config.hidden_size)
print(f"x: {x.shape}\n{x}")
layer = Layer(config)
y = layer(x)
print(f"y: {y.shape}\n{y}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the first sub-model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,3,config.hidden_size//config.split)
print(f"x: {x.shape}\n{x}")
y = layer(x)
print(f"y: {y.shape}\n{y}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the second sub-model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,3,config.hidden_size//config.split)
print(f"x: {x.shape}\n{x}")
y = layer(x, model=1)
print(f"y: {y.shape}\n{y}")

verbose = False
print("---------- RESET CONFIG --------")
config.hidden_size = hold1
config.num_attention_heads = hold2
config.head_dim = hold3
config.max_position_embeddings = hold4
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)
print("head_dim_list: ", config.head_dim_list)

# clear up memory
del hold1, hold2, hold3, hold4, x, layer, y

--------- Micro Hyperparameters -------
model_count:  [1, 2, 4]
model_dim_list:  [8, 4, 2]
head_dim_list:  [4, 2, 1]
|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the big model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-
x: torch.Size([1, 3, 8])
tensor([[[0.3649, 0.7453, 0.5253, 0.7963, 0.7138, 0.5880, 0.2640, 0.4247],
         [0.2535, 0.5296, 0.5828, 0.2522, 0.4695, 0.7472, 0.0419, 0.7992],
         [0.4768, 0.7702, 0.6017, 0.1678, 0.9640, 0.4123, 0.5156, 0.4232]]])
---------- Layer Input: torch.Tensor ------------
----------------- Layer.forwardTensor() --------------------
------------- RMSNorm.forward() ------------
x: torch.Size([1, 3, 8])
tensor([[[0.3649, 0.7453, 0.5253, 0.7963, 0.7138, 0.5880, 0.2640, 0.4247],
         [0.2535, 0.5296, 0.5828, 0.2522, 0.4695, 0.7472, 0.0419, 0.7992],
         [0.4768, 0.7702, 0.6017, 0.1678, 0.9640, 0.4123, 0.5156, 0.4232]]])
normed x: torch.Size([1, 3, 8])
tensor([[[0.6276, 1.2819, 0.9036, 1.3697, 1.2278, 1.0113, 0.4542, 0.7305],
    

In [16]:
# Testing our Layer's forwardTuple()
verbose = True

print("--------- Micro Hyperparameters -------")
hold1, hold2, hold3, hold4, hold5 = config.hidden_size, config.num_attention_heads, config.head_dim, config.levels, config.max_position_embeddings
config.hidden_size = 8
config.num_attention_heads = 2
config.head_dim = 4
config.levels = 2
config.max_position_embeddings = 3
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)
print("head_dim_list: ", config.head_dim_list)

layer = Layer(config)
# we need to make sure to send in a tuple of the expected size. above we set hidden_size=4 and levels=2
x = ((torch.randn((1,3,8)),),(torch.randn((1,3,4)),torch.randn((1,3,4))))
print(f"x: {x}")
out = layer(x)
print(f"out: {out}")

verbose = False
print("---------- RESET CONFIG --------")
config.hidden_size = hold1
config.num_attention_heads = hold2
config.head_dim = hold3
config.levels = hold4
config.max_position_embeddings = hold5
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)
print("head_dim_list: ", config.head_dim_list)

# clear up memory
del hold1, hold2, hold3, hold4, hold5, x, layer, out

--------- Micro Hyperparameters -------
model_count:  [1, 2]
model_dim_list:  [8, 4]
head_dim_list:  [4, 2]
x: ((tensor([[[-0.7193, -0.6913,  0.2116,  0.1197,  1.4033,  0.2525, -1.4282,
          -1.9589],
         [ 1.1528, -0.1403,  3.1788, -0.0754, -0.8046, -1.1919,  1.9537,
          -0.4478],
         [-1.7944,  0.7290, -1.7912,  0.9675, -0.5966,  0.3567,  1.2787,
          -0.5683]]]),), (tensor([[[ 0.4044,  0.9881,  0.9838, -2.1977],
         [-0.7348,  1.6922, -0.8843,  0.1139],
         [-0.7336, -0.3367, -0.9742, -1.2100]]]), tensor([[[-0.7594, -0.8046, -0.5859, -1.3914],
         [ 1.1770,  0.8434,  1.4686,  0.0103],
         [ 1.6388,  0.5452,  0.6675,  0.4108]]])))
---------- Layer Input: Tuple ------------
------------- Layer.forwardTuple() ------------
x:
((tensor([[[-0.7193, -0.6913,  0.2116,  0.1197,  1.4033,  0.2525, -1.4282,
          -1.9589],
         [ 1.1528, -0.1403,  3.1788, -0.0754, -0.8046, -1.1919,  1.9537,
          -0.4478],
         [-1.7944,  0.7290, -1.

reshaped attention: torch.Size([1, 3, 4])
tensor([[[-0.4757,  0.2614, -0.4757,  0.2614],
         [-0.3545,  0.1510, -0.3354,  0.1336],
         [-0.4063,  0.0756, -0.4070,  0.0868]]], grad_fn=<ViewBackward0>)
self.Wo: torch.Size([8, 8])
Parameter containing:
tensor([[ 0.1812, -0.1638, -0.1178,  0.1815, -0.0778, -0.0560,  0.3211,  0.3358],
        [-0.1059,  0.2363,  0.1847, -0.2948,  0.2672,  0.3174,  0.3080,  0.2044],
        [ 0.2060,  0.0377, -0.2959, -0.2276, -0.2968,  0.1777, -0.0755, -0.1395],
        [ 0.3207,  0.0063,  0.1376, -0.2537,  0.3473, -0.3356,  0.0987, -0.2497],
        [-0.3443,  0.2761, -0.2423, -0.0421,  0.0527, -0.2862,  0.0705,  0.1438],
        [ 0.1588, -0.2870,  0.0738, -0.2562,  0.0793,  0.2388,  0.3065,  0.1916],
        [ 0.0715, -0.2669,  0.0129,  0.3258, -0.1468, -0.0368, -0.2941,  0.0135],
        [-0.0419, -0.0367, -0.1927, -0.0744,  0.0314, -0.2928,  0.1988, -0.1777]],
       requires_grad=True)
spliced Wo: torch.Size([4, 4])
tensor([[ 0.1812, -0.1638

# Output Layer

In [17]:
class OutputLayer(nn.Module):
    def __init__(self, embedding: torch.Tensor, config: Config):
        super().__init__()
        self.embedding = embedding
        self.v = config.vocab_size
        self.model_dim_list = config.model_dim_list

        # applies RMSNorm to the embedding matrix
        self.embedding_norm = RMSNorm(config.hidden_size,
                                      eps = config.rms_norm_eps)
        
        # Applies RMSNorm to the model's final residual state before we use the embedding matrix to get logits
        self.final_norm = RMSNorm(config.hidden_size,
                                  eps = config.rms_norm_eps)

    def forwardTensor(self, x, model=0):
        global verbose
        if verbose: 
            print("------------- OutputLayer.forwardTensor() ------------")
            print(f"x: {x.shape}\n{x}")

        # setting up our splicing logic
        d_i = x.shape[-1]
        skip = model * d_i
        if verbose:
            print(f"d_i: {d_i}")
            print(f"skip: {skip}")
            print(f"embedding: {self.embedding.shape}\n{self.embedding}")

        # splice out our embedding matrix according to what model we're using
        sliced_embed = self.embedding[:,skip:skip + d_i]
        if verbose: print(f"sliced_embed: {sliced_embed.shape}\n{sliced_embed}")

        # normalize our sliced embedding matrix
        normed_sliced_embed = self.embedding_norm(sliced_embed)
        if verbose: print(f"normed & sliced embedding: {normed_sliced_embed.shape}\n{normed_sliced_embed}")

        # normalize the residual state before the final linear layer
        x = self.final_norm(x, model)
        if verbose: print(f"normed x: {x.shape}\n{x}")

        # calculating the final output logits of the model
        logits = x @ normed_sliced_embed.t()
        if verbose: 
            print(f"final logits: {logits.shape}\n{logits}")
            print("------------- END OutputLayer.forwardTensor() ------------")

        return logits

    def forwardTuple(self, x):
        """
        Defines the forward pass of the final embedding classification layer during training.

        Parameters:
            x (Tuple[Tuple[Tensor]]): 
                The input tuple of tuples of tensors 
                first tuple is of length config.levels and second layer of tuples have lengths of config.model_count
                tensors are shape (batch size, sequence length, hidden dimension) where hidden dimension changes by which model was used

        Returns:
            output (Tuple[Tuple[Tensor]]): 
                The output tuple of tuples of tensors after applying the final embedding classification
        """
        global verbose
        if verbose: 
            print("------------- OutputLayer.forwardTuple() ------------")
            print(f"x:\n{x}")
            
        # forwardTuple() should only be used during training, so we assert input_len == max_position_embeddings
        assert type(x) == tuple
        input_len = x[0][0].shape[1]
        if verbose: print(f"input_len: {input_len}")
        assert input_len == config.max_position_embeddings

        # we could define these from the config but this way the method is more flexible to testing
        num_levels = len(x)
        models_per_level = [len(x[i]) for i in range(num_levels)]
        if verbose: 
            print(f"num_levels: {num_levels}")
            print(f"models_per_level: {models_per_level}")

        # the loop that iterates over levels, aka the different potential sizes of models
        out = ()
        for i in range(num_levels):
            if verbose: print(f"Level {i} from range({num_levels})")

            # now for the loop that iterates over models in this level
            out_lvl = ()
            for j in range(models_per_level[i]):
                if verbose: print(f"Model {j} from range({models_per_level[i]})")

                output = self.forwardTensor(x[i][j], model = j)
                if verbose: print(f"forwardTensor() output: {output.shape}\n{output}")
                
                out_lvl += (output,)
            
            out += (out_lvl,)
        
        if verbose:
            print(f"final output: {out}")
            print("------------- END Layer.forwardTuple() ------------")
        
        return out
        
    def forward(self, x, model=0):
        train = True if type(x) == tuple else False
        if verbose: print(f"---------- Layer Input: {'Tuple' if train else 'torch.Tensor'} ------------")
        return self.forwardTuple(x) if train else self.forwardTensor(x, model)

In [18]:
# Testing our OutputLayer's forwardTensor()
verbose = True

print("--------- Micro Hyperparameters -------")
hold1, hold2 = config.hidden_size, config.vocab_size
config.hidden_size = 4
config.vocab_size = 5
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)

embedding = torch.randn(config.vocab_size, config.hidden_size)
print(f"embedding: {embedding.shape}\n{embedding}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the big model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,3,config.hidden_size)
print(f"x: {x.shape}\n{x}")
layer = OutputLayer(embedding, config)
y = layer(x)
print(f"y: {y.shape}\n{y}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the first sub-model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,3,config.hidden_size//config.split)
print(f"x: {x.shape}\n{x}")
y = layer(x)
print(f"y: {y.shape}\n{y}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the second sub-model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,3,config.hidden_size//config.split)
print(f"x: {x.shape}\n{x}")
y = layer(x, model=1)
print(f"y: {y.shape}\n{y}")

verbose = False
print("---------- RESET CONFIG --------")
config.hidden_size = hold1
config.vocab_size = hold2
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)

# clear up memory
del hold1, hold2, x, layer, y

--------- Micro Hyperparameters -------
model_count:  [1, 2, 4]
model_dim_list:  [4, 2, 1]
embedding: torch.Size([5, 4])
tensor([[ 1.1402,  0.2886, -0.5796, -0.3033],
        [-0.2475, -0.0909,  0.8898, -1.5566],
        [ 0.4822,  1.4014, -0.2881, -0.0487],
        [ 0.0914, -0.2257,  0.1666, -0.1041],
        [-0.0122, -0.8733,  1.2139, -0.5787]])
|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the big model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-
x: torch.Size([1, 3, 4])
tensor([[[0.5298, 0.9770, 0.7230, 0.6953],
         [0.3665, 0.9225, 0.7337, 0.1682],
         [0.1566, 0.0722, 0.5491, 0.9139]]])
---------- Layer Input: torch.Tensor ------------
------------- OutputLayer.forwardTensor() ------------
x: torch.Size([1, 3, 4])
tensor([[[0.5298, 0.9770, 0.7230, 0.6953],
         [0.3665, 0.9225, 0.7337, 0.1682],
         [0.1566, 0.0722, 0.5491, 0.9139]]])
d_i: 4
skip: 0
embedding: torch.Size([5, 4])
tensor([[ 1.1402,  0.2886, -0.5796, -0.3033],
        [-0.2475, -0.0909,  

In [19]:
# Testing our Layer's forwardTuple()
verbose = True

print("--------- Micro Hyperparameters -------")
hold1, hold2, hold3, hold4 = config.hidden_size, config.levels, config.max_position_embeddings, config.hidden_size
config.hidden_size = 4
config.levels = 2
config.max_position_embeddings = 3
config.vocab_size = 5
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)

embedding = torch.randn(config.vocab_size, config.hidden_size)
print(f"embedding: {embedding.shape}\n{embedding}")

layer = Layer(config)
# we need to make sure to send in a tuple of the expected size. above we set hidden_size=4 and levels=2
x = ((torch.randn((1,3,config.hidden_size)),),
     (torch.randn((1,3,config.hidden_size//config.split)),torch.randn((1,3,config.hidden_size//config.split))))
print(f"x: {x}")
out = layer(x)
print(f"out: {out}")

verbose = False
print("---------- RESET CONFIG --------")
config.hidden_size = hold1
config.levels = hold2
config.max_position_embeddings = hold3
config.vocab_size = hold4
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)

# clear up memory
del hold1, hold2, hold3, hold4, x, layer, out, embedding

--------- Micro Hyperparameters -------
model_count:  [1, 2]
model_dim_list:  [4, 2]
embedding: torch.Size([5, 4])
tensor([[ 1.0172, -0.1558, -0.1834,  0.0112],
        [ 0.4118,  0.4647, -0.0051,  0.9454],
        [-0.1046, -1.1461,  0.3192,  0.1052],
        [ 0.3178, -0.1745,  1.4441, -0.8474],
        [ 0.7193,  1.2875,  0.2692,  1.9234]])
x: ((tensor([[[ 0.9494, -0.2405,  0.5879,  1.4847],
         [ 1.0608, -1.3170,  1.0158,  1.9927],
         [ 1.5784,  1.0044, -1.6931, -0.0397]]]),), (tensor([[[ 0.6289,  0.4586],
         [-0.6905,  1.2571],
         [ 0.7723,  3.0443]]]), tensor([[[-0.4059, -0.8723],
         [ 0.0212,  1.2475],
         [ 0.7738,  0.3308]]])))
---------- Layer Input: Tuple ------------
------------- Layer.forwardTuple() ------------
x:
((tensor([[[ 0.9494, -0.2405,  0.5879,  1.4847],
         [ 1.0608, -1.3170,  1.0158,  1.9927],
         [ 1.5784,  1.0044, -1.6931, -0.0397]]]),), (tensor([[[ 0.6289,  0.4586],
         [-0.6905,  1.2571],
         [ 0.7723,  

Wk spliced: torch.Size([4, 32])
tensor([[ 0.1247,  0.4458,  0.0637, -0.4035, -0.1888, -0.3908,  0.4974,  0.0636,
         -0.3287,  0.3867, -0.3518, -0.0131,  0.0452, -0.4103, -0.0953,  0.3811,
          0.1975,  0.4332,  0.1069,  0.0880, -0.3776,  0.1469, -0.3684,  0.2586,
          0.0258,  0.2367, -0.4173, -0.4539,  0.0401,  0.3268,  0.4671,  0.3535],
        [ 0.4636, -0.2067, -0.2724,  0.2813,  0.3083, -0.1645,  0.4072, -0.3367,
         -0.4844,  0.4903, -0.2458,  0.3983,  0.4745,  0.1093,  0.3478, -0.3051,
          0.2002, -0.0206,  0.2579, -0.2674, -0.3509,  0.4186, -0.1168,  0.3344,
          0.4014, -0.1908,  0.1122, -0.2635, -0.2765, -0.4911,  0.3065,  0.2265],
        [-0.0153, -0.3781,  0.4079,  0.0030,  0.4418, -0.4477,  0.2635, -0.3938,
          0.0555, -0.3003, -0.2474, -0.3613,  0.0365,  0.0919,  0.0444, -0.0326,
          0.1988, -0.0043,  0.4960, -0.2726,  0.2577, -0.1399,  0.0490, -0.2605,
         -0.2334,  0.1966, -0.3514, -0.0269, -0.2524, -0.0522,  0.3142, -0.

Bdown: torch.Size([4])
Parameter containing:
tensor([-0.0242,  0.1740,  0.0925,  0.2142], requires_grad=True)
Bdown spliced: torch.Size([2])
tensor([-0.0242,  0.1740], grad_fn=<SliceBackward0>)
outputs: torch.Size([1, 3, 2])
tensor([[[-0.0218,  0.1717],
         [-0.0091,  0.1531],
         [-0.0124,  0.1585]]], grad_fn=<AddBackward0>)
------------- END MLP.forwardTensor() ------------
layer's final residual state:
tensor([[[ 0.8293,  0.5314],
         [-0.4652,  1.3387],
         [ 1.0220,  3.1334]]], grad_fn=<AddBackward0>)
----------------- END Layer.forwardTensor() --------------------
forwardTensor() output: torch.Size([1, 3, 2])
tensor([[[ 0.8293,  0.5314],
         [-0.4652,  1.3387],
         [ 1.0220,  3.1334]]], grad_fn=<AddBackward0>)
Model 1 from range(2)
----------------- Layer.forwardTensor() --------------------
------------- RMSNorm.forward() ------------
x: torch.Size([1, 3, 2])
tensor([[[-0.4059, -0.8723],
         [ 0.0212,  1.2475],
         [ 0.7738,  0.3308]]])
no

# Loss Function

In [20]:
class FractalLoss(nn.Module):
    def __init__(self, config: Config):
        super().__init__()

        self.criterion = nn.CrossEntropyLoss()

    def forward(self, logits, target):
        """
        input: 
            - logits are a tuple of tuples of tensors each of shape [batch_size, max_seq_len, vocab_size]
            - target is a shape [batch_size, max_seq_len] tensor of the integer indices of the correct tokens
        output: a tensor containing a single float of the loss value
        """
        global verbose
        if verbose: 
            print("------------- FractalLoss.forward() ------------")
            print(f"logits:\n{logits}")
            
        assert type(logits) == tuple # since this function should only be used during training
            
        # should only be used during training, so we assert input_len == max_position_embeddings
        b,t,v = logits[0][0].shape
        if verbose: print(f"b:{b}, t:{t}, v:{v}, b*t:{b*t}")
        assert t == config.max_position_embeddings
        
        # Calculate losses for each output and stack them. 
        # i apologize for the weird format instead of regular for loops, but it feels better in my head
        loss = torch.stack([ # stacks across levels
                            torch.stack( # stacks across models in level
                                        [self.criterion(logits_ij.view(b*t, v), # reshapes for CELoss
                                                        target.view(b*t)) 
                                         for logits_ij in logits[i]] # iterates across models in level
                            ).sum() # sums across models in level
                            for i in range(len(logits))] # iterates across levels
                          ).sum() # sums across levels

        if verbose:
            print(f"final loss: {loss}")
            print("------------- END FractalLoss.forward() ------------")

        return loss

In [21]:
# Testing our FractalLoss
verbose = True

print("--------- Micro Hyperparameters -------")
hold1, hold2, hold3, hold4 = config.hidden_size, config.levels, config.max_position_embeddings, config.hidden_size
config.hidden_size = 4
config.levels = 2
config.max_position_embeddings = 3
config.vocab_size = 5
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)

embedding = torch.randn(config.vocab_size, config.hidden_size)
print(f"embedding: {embedding.shape}\n{embedding}")

loss = FractalLoss(config)
# we need to make sure to send in a tuple of the expected size. above we set hidden_size=4 and levels=2
logits = ((torch.randn((2,3,config.vocab_size)),),
     (torch.randn((2,3,config.vocab_size)),torch.randn((2,3,config.vocab_size))))
print(f"logits: {logits}")
target = torch.randint(config.vocab_size, (2,3)).unsqueeze(0)
print(f"target: {target}")
out = loss(logits, target)
print(f"out: {out}")

verbose = False
print("---------- RESET CONFIG --------")
config.hidden_size = hold1
config.levels = hold2
config.max_position_embeddings = hold3
config.vocab_size = hold4
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)

# clear up memory
del hold1, hold2, hold3, hold4, embedding, loss, logits, target, out

--------- Micro Hyperparameters -------
model_count:  [1, 2]
model_dim_list:  [4, 2]
embedding: torch.Size([5, 4])
tensor([[ 0.7798, -1.2821,  1.4271, -1.5708],
        [ 0.5209,  0.8941, -0.8535,  1.2495],
        [-0.2803,  0.2022,  0.6984,  0.7937],
        [-1.9587, -0.5391,  0.3903, -0.3077],
        [-0.1755, -0.0080, -1.2165,  0.8529]])
logits: ((tensor([[[ 0.1422, -0.3405,  0.2353,  0.6880, -0.8742],
         [ 0.1979,  0.9890,  0.0070, -1.8563, -1.5080],
         [ 0.0152,  1.4316,  1.1355,  0.5441,  0.2532]],

        [[-0.7838, -0.7004, -0.1135, -0.6035,  0.8739],
         [-0.7674, -0.7594, -1.9494,  0.0333, -0.8108],
         [ 0.6092, -0.2454,  1.0809,  1.0434,  1.2647]]]),), (tensor([[[ 1.4386, -0.1488, -0.2277,  0.3608,  0.3705],
         [ 1.2773,  0.7523,  0.6605,  0.9307,  1.2303],
         [-1.2080,  0.8718,  0.9394, -0.6804,  0.2274]],

        [[-1.4003,  0.4723,  0.7164, -1.9626,  0.6976],
         [-3.2143,  0.1564,  1.4848,  0.4646,  1.2435],
         [-0.6979,

# The Model itself

In [22]:
class FractalFormer_base(nn.Module):
    def __init__(self, config: Config, tokenizer: tokenizer):
        super().__init__()
        self.config = config
        self.tokenizer = tokenizer

        # hyperparameters
        self.hidden_size = config.hidden_size
        self.max_seq_len = config.max_position_embeddings
        self.head_dim = config.head_dim
        self.vocab_size = config.vocab_size

        ### FractalFormer-specific hyperparameters
        self.num_levels = config.levels # the number of levels for sub-models to exist on
        self.split = config.split # the number of splits to make at a given level
        self.model_count = config.model_count # list of number of models at a given level
        self.model_dim_list = config.model_dim_list # list of hidden dimensions corresponding to each given level
        self.head_dim_list = config.head_dim_list # list of attention head dimensions corresponding to each given level    

        # the embedding matrix. for converting tokens to the first residual state, and the last residual state to logits
        self.embedder = nn.Embedding(config.vocab_size, config.hidden_size)

        # for normalizing the initial embeddings
        self.embedder_norm = RMSNorm(config.hidden_size)

        # Initialize a sequence of DecoderLayer instances as specified by the number of hidden layers in the config
        self.layers = nn.ModuleList(Layer(config) for _ in range(config.num_hidden_layers))

        # initializing output layer
        self.output_layer = OutputLayer(self.embedder.weight, config)
        # i think i need to do this bc in the above version you can't use `self.` inside the init
        #@property 
        #def output_layer(self):
            #return OutputLayer(self.embedder.weight, config)

        # the loss function
        self.criterion = FractalLoss(config)

    def forwardTensor(self,
                      input_token_ids: torch.Tensor,
                      level: int = 0, # integer designating the level of model to use. 0 is largest model, -1 is smallest
                      model: int = 0, # integer designating the model in that level to use. 0 is top-left, -1 is bottom right
                     ) -> torch.Tensor:
        """
        inputs: 
            - input_token_ids (torch.Tensor): a tensor of integers size (batch_size, sequence_length)
            - level: integer designating the level of model to use. 0 is largest model, -1 is smallest
            - model: integer designating the model in that level to use. 0 is top-left, -1 is bottom right
        output: a torch.Tensor shape (batch_size, sequence_length, vocab_size)
        """
        global verbose
        if verbose: 
            print("------------- FractalFormer.forwardTensor() ------------")
            print(f"input_token_ids: {input_token_ids.shape}\n{input_token_ids}")
        
        # adjusting everything to the specified level & model
        d_dim = self.hidden_size // (2**level)
        d_skip = model * d_dim
        if verbose:
            print(f"d_dim: {d_dim}")
            print(f"d_skip: {d_skip}")
        
        # turn the input tokens into the first residual state using the embedding matrix
        # (batch_size, input_len) & (vocab_size, hidden_size) -> (batch_size, input_len, hidden_size) -> (batch_size, input_len, d_dim)
        x = self.embedder(input_token_ids)
        if verbose: print(f"x0: {x.shape}\n{x}")

        x = x[:,:, d_skip:d_skip + d_dim]
        if verbose: print(f"spliced x0: {x0.shape}\n{x0}")
        
        # Gemma normalizes the embedding by sqrt(hidden_size)
        # the question is, should I do this with the full sized hidden_size or do it at the splice size????
        # imma do it at the splice size and change it later if i think the models aren't learning well
        #x = x * (d_dim**0.5)
        # alternatively i could just switch to doing a regular RMSNorm which would be more like me
        # if i figure out this different sizes of hyperspheres thing it'd be more in line with that
        x = self.embedder_norm(x, model)
        if verbose: print(f"normalized initial x: {x.shape}\n{x}")

        # Iteratively process the input through each Layer
        for i, layer in enumerate(self.layers):
            if verbose: print(f"begin layer {i}")
            x = layer(x, model)
            if verbose: print(f"output of layer {i}: {x.shape}\n{x}")

        logits = self.output_layer(x, model)
        if verbose: 
            print(f"output logits: {logits.shape}\n{logits}")
            print("------------- END FractalFormer.forwardTensor() ------------")

        return logits

    def forwardTuple(self,
                     input_token_ids: torch.Tensor,
                     target_token_ids: torch.Tensor,
                    ) -> torch.Tensor:
        global verbose
        if verbose: 
            print("------------- FractalFormer.forwardTuple() ------------")
            print(f"input_token_ids: {input_token_ids.shape}\n{input_token_ids}")
            print(f"target_token_ids: {target_token_ids.shape}\n{target_token_ids}")
        
        # use the embedding matrix to turn the input tokens into the first residual state of the largest model
        # (batch_size, input_len) & (vocab_size, hidden_size) -> (batch_size, input_len, hidden_size)
        x0 = self.embedder(input_token_ids)
        if verbose: print(f"initial x: {x.shape}\n{x}")

        # create the first fractal tuple of residual states
        x = ()
        for i, models_in_level in enumerate(config.model_count):
            if verbose: print(f"i: {i}, models_in_level: {models_in_level}, iterating over {config.model_count}")
            
            x_lvl = ()
            for j, d_dim in enumerate(config.model_dim_list):
                if verbose: print(f"j: {j}, d_dim: {d_dim}, iterating over {config.model_dim_list}")

                skip = j * d_dim
                if verbose: print(f"skip: {skip}")
                
                x_ij_spliced = x0[:,:,skip:skip + d_dim]
                if verbose: print(f"initial x[{i}][{j}] spliced: {x_ij_spliced.shape}\n{x_ij_spliced}")
                    
                x_ij_spliced_normed = self.embedder_norm(x_ij_spliced, model=j) # * (d_dim**0.5) # if i want to do Gemma normalization instead
                if verbose: print(f"initial x[{i}][{j}] spliced & normed: {x_ij_spliced_normed.shape}\n{x_ij_spliced_normed}")
                
                x_lvl += (x_ij_spliced_normed,)  
            x += (x_lvl,)
        if verbose: print(f"full tuple initial x: {x0}")

        # Iteratively process the input through each Layer
        for i, layer in enumerate(self.layers):
            if verbose: print(f"begin layer {i}")
            
            x = layer(x)
            if verbose: print(f"output of layer {i}: {x}")

        logits = self.output_layer(x)
        if verbose: 
            print(f"output logits: {logits}")
            print("------------- END FractalFormer.forwardTuple() ------------")

        return logits

    def forward(self,
                input_token_ids: torch.Tensor, # a shape (batch_size, input_seq_len OR max_seq_len)list of integer token ids
                target_token_ids: torch.Tensor = None, # a shape (batch_size, max_seq_len) list of token ids to train on
                level: int = 0, # integer designating the level of model to use. 0 is largest model
                model: int = 0, # integer designating the model in that level to use. 0 is top-left model in level
                ):
        global verbose
        if verbose: 
            print("------------- FractalFormer.forward() ------------")
            print(f"input_token_ids: {input_token_ids.shape}\n{input_token_ids}")
            print(f"target_token_ids: {target_token_ids}")
            print(f"level: {level}")
            print(f"model: {model}")
        
        if target_token_ids is None: # if we're not training, then we don't need to calculate loss
            logits = self.forwardTensor(input_token_ids, level, model)
            loss = None
        else:
            # if we are training
            # training uses a tuple of tuples of tensors
            logits = self.forwardTuple(input_token_ids, target_token_ids) # -> Tuple[Tuple[Tensor shape (batch_size, max_seq_len, vocab_size)]]
            
            # custom Fractal CELoss function
            loss = self.criterion(logits, target_token_ids) 
        
        if verbose: 
            print(f"logits: {logits}")
            print(f"loss: {loss}")
            print("------------- END FractalFormer.forward() ------------")
        
        return logits, loss

    @torch.no_grad() # no need to keep track of gradients during inference
    def Sampler(
        self,
        logits: torch.Tensor, # shape (batch_size, input_len, vocab_size)
        temperature: float, # controls how boring vs random the outputs should be
        top_p: float, # the maximum cumulative probability of output options we're willing to consider
        top_k: int, # the maximum number of output options we're willing to consider
    ) -> torch.Tensor:
        """
        The Sampler function is responsible for generating token predictions from Gemma's output.
        It supports temperature scaling, top-p (nucleus) sampling, and top-k sampling 
        The class operates as follows:
    
        1. Selects the last hidden state for each sequence in the batch
    
        2. Computes logits by multiplying the selected hidden states with the transposed embedding matrix. 
    
        3. Temperature is used to scale the logits, making the distribution over tokens sharper (lower temperature) 
        or flatter (higher temperature), which affects the randomness of the sampling (flatter -> more random)
    
        4. The softmax function is applied to the scaled logits to obtain a probability distribution over the vocabulary.
    
        5. For top-p sampling, the function computes the cumulative sum of the sorted probabilities and masks out tokens until the 
        cumulative probability exceeds the threshold defined by `top_ps`. This allows the model to focus on a subset of the most 
        probable tokens while ignoring the long tail of less likely tokens. 
        We to ignore long tail probabilities to avoid nonsensical output
    
        7. For top-k sampling, the function masks out all tokens except the `k` most likely ones, as specified by `top_ks`. 
        This ensures that the model only considers a fixed number of the most probable tokens for the next token prediction.
    
        8. After applying both the top-p and top-k masks, the probabilities are re-normalized so that they sum up to 1
    
        9. The function then samples from the re-normalized probability distribution to select the next token. 
        """
        if config.verbose['Sampler']:
            print("----------------- FractalFormer.Sampler() --------------")
            print(f"temperature: {temperature}, top_p: {top_p}, top_k: {top_k}")
            
        # Select the last element for each sequence.
        # (batch_size, input_len, vocab_size) -> (batch_size, vocab_size)
        logits = logits[:,-1,:]
        if config.verbose['Sampler']: print(f"logits: {logits.shape}\n{logits}")
        
        # Apply temperature scaling
        # (batch_size, vocab_size) / float -> (batch_size, vocab_size)
        logits.clone().div_(temperature) # the clone() is because i didn't properly prevent gradient tracking and i'm too lazy to fix the issue at its cause
        if config.verbose['Sampler']: print(f"logits w temperature: {logits.shape}\n{logits}")

        # Calculate probabilities with softmax.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float) # dim=-1 is the vocab_size dimension that we calculate along
        if config.verbose['Sampler']: print(f"probs: {probs.shape}\n{probs}")

        # sort the probabilities to for use in top-p & top-k
        # both are (batch_size, vocab_size)
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        # probs_sort contains float probabilities while probs_idx contains integer indices
        if config.verbose['Sampler']: 
            print(f"probs_sort: {probs_sort.shape}\n{probs_sort}")
            print(f"probs_idx: {probs_idx.shape}\n{probs_idx}")

        # calculating top-p
        # creates same-size tensor of cumulatve probabilities instead of indivdiual probs
        probs_sum = torch.cumsum(probs_sort, dim=-1) 
        if config.verbose['Sampler']: print(f"probs_sum: {probs_sum.shape}\n{probs_sum}")
        # mask where 0's are top-p selections & 1's are to be excluded
        top_ps_mask = (probs_sum - probs_sort) > top_p
        if config.verbose['Sampler']: print(f"top_ps_mask: {top_ps_mask.shape}\n{top_ps_mask}")
        # the original probabilities with excluded tokens changed to 0.0
        probs_sort = torch.where(top_ps_mask, 0, probs_sort) 
        if config.verbose['Sampler']: print(f"probs_sort: {probs_sort.shape}\n{probs_sort}")

        # calculating top_k
        # create a shape (vocab_size) tensor that just iterates up by 1's
        top_ks_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device) 
        if config.verbose['Sampler']: print(f"top_ks_mask: {top_ks_mask.shape}\n{top_ks_mask}")
        # expand our mask along the batch_size dimension to become size (batch_size, vocab_size)
        # "expand" means copy the original into this new size, so each length vocab_size row is the same
        top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
        if config.verbose['Sampler']: print(f"top_ks_mask: {top_ks_mask.shape}\n{top_ks_mask}")
        # top_ks is a list of integers. we keep whichever entries in top_ks_mask are greater than their corresponding entries in top_ks
        top_ks_mask = top_ks_mask >= top_k
        if config.verbose['Sampler']: print(f"top_ks_mask: {top_ks_mask.shape}\n{top_ks_mask}")

        # we'll be combining top-p with top-k and using whichever gives us fewer tokens. a very conservative approach
        # this trims probs_sort to also fit within our top_k requirement
        probs_sort = torch.where(top_ks_mask, 0, probs_sort)
        if config.verbose['Sampler']: print(f"probs_sort: {probs_sort.shape}\n{probs_sort}")

        # Re-normalization so that total probabilities add up to 1
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        if config.verbose['Sampler']: print(f"probs_sort: {probs_sort.shape}\n{probs_sort}")
        
        # now we rearrange the modified probabilities in probs_sort back to their original order according to probs_idx
        probs = torch.gather(probs_sort,
                             dim=-1,
                             index=torch.argsort(probs_idx, dim=-1))
        if config.verbose['Sampler']: print(f"probs: {probs.shape}\n{probs}")
        
        # samples from the distribution
        next_token_id = torch.multinomial(probs, num_samples=1)
        if config.verbose['Sampler']: print(f"next_token_id: {next_token_id.shape}\n{next_token_id}")
        
        return next_token_id # returns the predicted token
        
    def generate(
        self,
        prompt: str,
        output_len: int = 100, # the model will output 100 tokens
        temperature: float = 0.7, # 0.95 is pretty close to not even using temperature at all (1.0 would be no effect)
        top_p: float = 1.0, # defaulting to 1 means we essentially don't use top-p
        top_k: int = config.vocab_size, # setting top_k = vocab_size means we're effectively not using top_k at all
        level: int = 0, # which size model we want to perform inference with
        model: int = 0, # which model in that level we want to perform inference with
    ) -> str: 
        
        # encoding the prompt into token indices
        tokens = self.tokenizer.encode(prompt)

        # turning it into the right tensor shape
        tokens = torch.tensor(tokens, device=config.device).unsqueeze(0)
        
        # we wouldn't want to go past the maximum context length we trained on
        assert len(tokens) + output_len <= self.config.max_position_embeddings

        for i in range(output_len):
            # get the model's output logits and ignore the loss, which would be a NoneType object
            logits, _ = self(tokens[:,:self.max_seq_len], level=level, model=model)
            
            next_token = self.Sampler(
                logits = logits, # the actual output of the model
                temperature = temperature,
                top_p = top_p,
                top_k = top_k
            )
            #print(next_token)

            # add our new token to the sequence
            tokens = torch.cat((tokens, next_token), dim=1)

        # decode our list of tokens to an actual string
        output = self.tokenizer.decode(tokens.squeeze(0).tolist())

        return output

# Training-related Functions

In [23]:
# Train and test splits
data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be our training dataset, the rest for validation
train_data = data[:n]
val_data = data[n:]

In [24]:
# data loading for training which generates a small batch of data of inputs x and targets y
def get_batch(split, batch_size):
    # whether we grab from our training or validation dataset
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - config.max_position_embeddings, (batch_size,))
    x = torch.stack([data[i:i+config.max_position_embeddings] for i in ix])
    y = torch.stack([data[i+1:i+config.max_position_embeddings+1] for i in ix])
    x, y = x.to(config.device), y.to(config.device)
    return x, y

In [25]:
# a demonstration of what a batch with batch_size=1 looks like. Notice the one-token offset in characters
xb, yb = get_batch('train', 1)
print(xb)
print(tokenizer.decode(xb.squeeze(0).tolist()))
print("-------")
print(yb)
print(tokenizer.decode(yb.squeeze(0).tolist()))

tensor([[ 83,  80,   1,  57,  54, 112,  58,  57,   1,  94,   1,  72,   1,  52,
          47, 122,  58,   0,  17,  52,  95,  50,  53,  54,  43,   1,  88,  66,
          45, 126,  42,   1,  28, 115,  60,  53,  80,   2,   1,  35,  46,  53,
           1,  41,  39,  81,   5,  42,   1,  87,  93,   1,  94,   1,  50,  78,
          43,  12,  75,  28, 115,  60,  53,  80,  71,  26,  79,  43,  66,  57,
          69, 104,   1,  72,   1,  41,  59,  56,  44,  43,  61,   1,  56,  59,
          52,  45,  85,  16,  33,  23,  17,   1,  34,  21,  26,  15,  17,  26,
          32,  21,  27,  71,  26,  53,  58,   1,  21,  57,  39,  98,  50,  12,
          75,  28, 115,  60,  53,  80,  71,  26,  53,  85,  16,  33,  23,  17,
           1,  34,  21,  26,  15,  17,  26,  32,  21,  27,  71,  32,  87,  63,
           1, 100,  81,  66,  72,  52,  66, 110,   5,  58,   1,  98,   1,  50,
          79,  45,  85,  28, 115,  60,  53,  80,  71,  35,  92,  58,   1,  41,
          53,  51, 105,  58,   1,  74,   1, 105,   1

In [26]:
@torch.no_grad()
def estimate_loss(model, batch_size, eval_iters = 10): # to estimate loss during the training loop
    out = {}
    model.eval() # sets model to eval mode
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, batch_size)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train() # just resets to training mode
    return out

# Instantiating a brand new model

In [None]:
# just to make sure nothing got messed up above. 
# if an error gets thrown in one of the test cells then the config values won't reset
print(config)

model = FractalFormer_base(config, tokenizer).to(config.device)

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

print(model)

# Training

In [36]:
# create a PyTorch optimizer
# this is not what they used, but this learning rate & weight decay work for our tiny minGemma
learning_rate = 3e-5
weight_decay = 0.01
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# how long we want to train for
max_iters = 5000

# how often we want to check & see how our loss is doing
eval_interval = 250

# batch size to use
batch_size = 12

# if you want to do debugging
config.verbose['RMSNorm'] = False
config.verbose['MLP'] = False
config.verbose['MQA'] = False
config.verbose['Layer'] = False
config.verbose['OutputLayer'] = False
config.verbose['FractalLoss'] = False
config.verbose['FractalFormer'] = False
config.verbose['Sampler'] = False
config.verbose['Generate'] = False

# ------------ BOOKMARK ----------------

In [37]:
model.train()
start_time = time.time()

# Enable anomaly detection. uncomment these lines if you need to do extensive debugging
#torch.autograd.set_detect_anomaly(True)

for iter in range(max_iters):

    # sample a batch of data
    xb, yb = get_batch('train', batch_size)
    
    # train
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    
    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        current_time = time.time()
        elapsed_time = current_time - start_time
        losses = estimate_loss(model, batch_size)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time elapsed: {elapsed_time:.2f} seconds")

# Disable anomaly detection after the training loop
#torch.autograd.set_detect_anomaly(False)

step 0: train loss 19.1140, val loss 22.1463, time elapsed: 1.54 seconds
step 250: train loss 19.2682, val loss 22.2341, time elapsed: 354.09 seconds
step 500: train loss 18.8027, val loss 22.1553, time elapsed: 688.91 seconds
step 750: train loss 18.8095, val loss 21.7493, time elapsed: 1021.93 seconds
step 1000: train loss 19.0256, val loss 22.6200, time elapsed: 1356.87 seconds
step 1250: train loss 18.8943, val loss 22.5097, time elapsed: 1692.22 seconds
step 1500: train loss 18.8110, val loss 22.2320, time elapsed: 2028.29 seconds
step 1750: train loss 18.9138, val loss 21.8615, time elapsed: 2363.70 seconds
step 2000: train loss 18.8240, val loss 22.2428, time elapsed: 2709.21 seconds
step 2250: train loss 19.0780, val loss 21.9875, time elapsed: 3042.35 seconds
step 2500: train loss 18.7679, val loss 22.1570, time elapsed: 3372.90 seconds
step 2750: train loss 18.6269, val loss 22.4533, time elapsed: 3703.29 seconds
step 3000: train loss 18.8560, val loss 21.8843, time elapsed: 

# Saving your model

In [38]:
# save the model currently held in memory
# the filename specifies the model's class, hyperparameters, and date/time it was saved
import os

# Ensure the directory exists
model_dir = 'models'
if not os.path.exists(model_dir):
    os.makedirs(model_dir)

# Create a shorter, more concise filename
filename = (f'{model.__class__.__name__}'
           f'-v{config.vocab_size}'
           f'-max_t{config.max_position_embeddings}'
           f'-layers{config.num_hidden_layers}'
           f'-heads{config.num_attention_heads}'
           f'-kv_heads{config.num_key_value_heads}'
           f'-hidden{config.hidden_size}'
           f'-intermediate{config.intermediate_size}'
           f'-head_dim{config.head_dim}'
           f'-theta{config.rope_theta}'
           f'-levels{config.levels}'
           f'-split{config.split}'
           f'-lr{learning_rate}'
           f'-decay{weight_decay}'
           f'-batch{batch_size}'
            f'-train_iter{15000}'
           f'--{time.strftime("%Y-%m-%d|%H-%M-%S")}.pth')

# Save the model
model_path = os.path.join(model_dir, filename)
torch.save(model.state_dict(), model_path)

# Load a Pretrained Model

In [27]:
# Initialize a blank model
model = FractalFormer_base(config, tokenizer).to(config.device)  

# here's the path to a minGemma model that i've trained with roughly 1m parameters
path = 'models/FractalFormer_base-v128-max_t256-layers4-heads4-kv_heads1-hidden128-intermediate512-head_dim32-theta100.0-levels3-split2-lr0.0003-decay0.01-batch12--2024-03-06|07-14-57.pth'

# Load the saved state dictionary
model.load_state_dict(torch.load(path))
# REMEMBER TO CHANGE VALUES IN CONFIG TO MATCH THE MODEL YOU'VE LOADED

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

# If you only plan to do inference, switch to evaluation mode
model.eval()

# If you plan to continue training the model, switch to training mode
#model.train()

972.672 K parameters


FractalFormer_base(
  (embedder): Embedding(128, 128)
  (embedder_norm): RMSNorm()
  (layers): ModuleList(
    (0-3): 4 x Layer(
      (self_attn): MultiQueryAttention(
        (drop): Dropout(p=0.1, inplace=False)
      )
      (mlp): MLP(
        (drop): Dropout(p=0.1, inplace=False)
      )
      (input_layernorm): RMSNorm()
      (post_attention_layernorm): RMSNorm()
    )
  )
  (output_layer): OutputLayer(
    (embedding_norm): RMSNorm()
    (final_norm): RMSNorm()
  )
  (criterion): FractalLoss(
    (criterion): CrossEntropyLoss()
  )
)

# Inference

In [40]:
model.eval() # sets model to eval mode
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou R" # the classic line
max_useable_output_len = config.max_position_embeddings - len(input_str)

for i in range(config.levels):
    for j in range(config.model_count[i]):
        print(f"level: {i}, model: {j}")
        output = model.generate(input_str, 
                                output_len = max_useable_output_len, 
                                temperature=0.7, 
                                top_k = 3, 
                                top_p = 0.95,
                               level = i,
                               model = j)
        print(output)

level: 0, model: 0
JULIET:
O Romeo, Romeo! wherefore art thou Romeo shall have the wars,
Than some prince then the world of this side,
Which all the sacribe of all the presence
Than thou hast be a substance of his pride.

BUSHY:
The grandship of the people of you, sir?

FRIAR LAURENCE:
The suppose spirit to thee thee, and I should have
distraction of the gentle b
level: 1, model: 0
JULIET:
O Romeo, Romeo! wherefore art thou RRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRR
level: 1, model: 1
JULIET:
O Romeo, Romeo! wherefore art thou Romeo,
The cause a madam, and the prince of thee to must,
The can some to be a spossion and some to so,
To must the prince of your from the compant
Where is the peoples and some in his so made
To pluck'd the people supposed the pair suppose
Is all you shall be propter of the substance
T
level: 2

so there's almost definitely something wrong happening here