# FractaFormer

this base version is going to be absurdly terribly no-good inefficient because we're taking the biggest computational issue ($O(t^2)$ attention) and making it way worse by doing MANY of them at once and then having to keep track of each parameter's gradient from MANY different perspectives. This is basically just an extension of [MatFormer+](https://github.com/evintunador/matryoshkaGPT/blob/main/MatFormer%2B.ipynb) where instead of one inner model, we have 2 (or whatever number you specify) models inside 1 at each layer

# TODO
- ~config~
- ~RMSNorm~
    - ~test~
- ~mlp~
    - ~tensor~
    - ~tuple~
    - ~triple-check test~
- ~mqa~
    - ~tensor~
    - ~tuple~
    - ~triple-check test~
- ~layer~
- ~output~
    - ~tensor~
    - ~tuple~
    - triple check test
- loss~~
    - ~tuple~
    - triple check test
- model itself
    - ~tensor~
    - tuple
    - triple check test

In [1]:
# Importing pytorch
import torch
import torch.nn as nn
from torch.nn import functional as F

# used for the tokenizer
import pickle
import os

# Imports used for the config
import dataclasses 
from typing import Optional

# Imports used for the model
import re
from typing import Any, List, Sequence, Tuple, Union
import numpy as np

# used in the training loop
import time

# The Dataset

the dataset we'll be using is just TinyShakespeare for sake of simplicity & ability to do run/train locally on any computer

In [2]:
# load the dataset
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# the first 200 characters. It's just one continuous text document with all of the works of shakespeare back-to-back
print(text[:200])

# here are all the unique characters that occur in this text and how many there are
chars = sorted(list(set(text)))
v = len(chars)
print('\n', chars, v)

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you

 ['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'] 65


# The Tokenizer

We'll be using a very simple tokenizer I previoiusly trained off of the TinyShakespeare dataset that has 128 total tokens and ignores stuff like special tokens & regex. 

In [3]:
# Load the tokenizer data using pickle
with open('./tokenizers/tokenizer.model', 'rb') as f:
    loaded_tokenizer_data = pickle.load(f)

# Extract the stoi mapping and merges from the loaded data
loaded_stoi = loaded_tokenizer_data['stoi']
loaded_merges = loaded_tokenizer_data['merges']

class SimpleTokenizer:
    def __init__(self, stoi, merges):
        self.stoi = stoi
        self.merges = merges
        self.itos = {i: s for s, i in stoi.items()}  # Inverse mapping for decoding

        self.vocab_len = len(stoi) + len(merges)

    def encode(self, text):
        # Convert the text to a list of token IDs, using space for unknown characters
        tokens = [self.stoi.get(c, self.stoi[' ']) for c in text]

        # Perform merging with the possibility of nested merges
        i = 0
        while i < len(tokens) - 1:
            pair = (tokens[i], tokens[i + 1])
            if pair in self.merges:
                # Replace the current pair with its merged token
                merged_token = self.merges[pair]
                tokens[i] = merged_token
                del tokens[i + 1]

                # Move back to handle possible nested merges
                if i > 0:
                    i -= 1
            else:
                i += 1

        return tokens

    def decode(self, tokens):
        def expand_token(token):
            # Base case: if the token is a direct mapping, return its character
            if token in self.itos:
                return self.itos[token]
            # Recursive case: if the token is a merged token, expand its constituents
            elif token in self.merges.values():
                pair = next(key for key, value in self.merges.items() if value == token)
                return ''.join(expand_token(t) for t in pair)
            # Fallback for unknown tokens
            else:
                return ''

        # Decode each token in the list, handling nested merges recursively
        return ''.join(expand_token(token) for token in tokens)
        
# Example usage
# Assuming loaded_stoi and loaded_merges are already loaded from the tokenizer.model file

tokenizer = SimpleTokenizer(loaded_stoi, loaded_merges)
print("vocab length: ", tokenizer.vocab_len)

# Encoding text
encoded_text = tokenizer.encode("JULIET:\nO Romeo, Romeo! wherefore art thou Romeo?")
print("Encoded:", encoded_text, len(encoded_text))

# Decoding back
decoded_text = tokenizer.decode(encoded_text)
print("Decoded:", decoded_text, len(decoded_text))

vocab length:  128
Encoded: [22, 33, 24, 21, 17, 32, 71, 27, 1, 30, 53, 83, 53, 66, 30, 53, 83, 53, 2, 1, 61, 87, 93, 105, 43, 1, 77, 58, 1, 65, 67, 1, 30, 53, 83, 53, 12] 37
Decoded: JULIET:
O Romeo, Romeo! wherefore art thou Romeo? 49


# Config

In [4]:
@dataclasses.dataclass # a class meant specifically to just hold data
class Config:
    """ 
    The default configuration & hyperparameters for FractalFormer
    """
    # The number of tokens in the vocabulary.
    vocab_size: int = tokenizer.vocab_len
    
    # The maximum sequence length that this model might ever be used with.
    max_position_embeddings: int = 256
    
    # The number of layers in the model.
    num_hidden_layers: int = 4
    
    # The number of attention heads used in the attention layers of the model.
    num_attention_heads: int = 4
    
    # The number of key-value heads for implementing multi-query attention.
    num_key_value_heads: int = 1
    # Ensures that the number of query heads is evenly divisible by the number of KV heads.
    assert num_attention_heads % num_key_value_heads == 0
    
    # The hidden size of the model, AKA the embedding dimension
    hidden_size: int = 128
    # the attention heads need to cleanly divide up the hidden_size of the model for MQA
    assert hidden_size % num_attention_heads == 0

    # how much larger the inner dimension of the MLP should be than the hidden size of the model
    intermediate_multiplier = 4
    # The inner dimension of the MLP part of the decoder layer
    @property
    def intermediate_size(self):
        return self.intermediate_multiplier * self.hidden_size
    
    # The number of head dimensions
    head_dim: int = 32
    
    # The epsilon used by the rms normalization layers.
    rms_norm_eps: float = 1e-6 # this is to promote numerical stability & prevent dividing by 0
    
    # the scaling factor that determines the frequencies for the rotary positional encodings
    rope_theta = 100.0
    # smaller models should use a smaller theta, but I'm just guessing here. 1000 might work too. 10,000 is the usual

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # the % of neurons to dropout in the MLP
    dropout = 0.1

    ####### FractalFormer-specific hyperparameters

    # the number of levels for sub-models to exist on
    levels = 3
    
    # the number of splits to make at a given level
    split = 2 # i don't recommend choosing any value other than 2
    # needs to be divisible by 2 in order to splice cleanly
    assert split % 2 == 0
    # RoPE requires a head dimension of length larger than 1 in order to work
    assert head_dim // (split * (levels-1)) > 1
    # really though you shouldn't be getting anywhere near that small of a head dimension even at the lowest level, that'd be useless

    @property
    def model_count(self):
        return [self.split**i for i in range(self.levels)]

    @property
    def model_dim_list(self):
        return [self.hidden_size // (self.split**i) for i in range(self.levels)]

    @property
    def head_dim_list(self):
        return [self.head_dim // (self.split**i) for i in range(self.levels)]

config = Config()

print("single large model -> hierarchy of many smaller models inside")
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)
print("head_dim_list: ", config.head_dim_list)

single large model -> hierarchy of many smaller models inside
model_count:  [1, 2, 4]
model_dim_list:  [128, 64, 32]
head_dim_list:  [32, 16, 8]


# Rotary Positional Encoding (RoPE)

i don't think i need to adjust the code for this one as long as i always call it individually

In [5]:
def apply_rotary_emb(x: torch.Tensor, dim: int, theta: float = 10000.0) -> torch.Tensor:
    """Applies the rotary embedding to the inputted query or key tensor"""
    # Get sequence length
    seq_len = x.size(1)
    device = x.device
    
    # Dynamically compute frequency cis based on the input sequence length
    # dynamic is less efficient but pre-computed was giving me trouble so whatever
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
    t = torch.arange(seq_len, device=device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64

    # Apply rotary embeddings to the input tensor
    x_ = torch.view_as_complex(torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
    x_out = torch.view_as_real(x_ * freqs_cis.unsqueeze(0)).type_as(x)  # Ensure batch dimension is handled
    x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
    x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], -1).transpose(1, 2)

    return x_out

# RMSNorm

Layernorm is relatively simple code-wise. However, of note is the fact that during training, the entire full length vector gets normalized whereas during inference we only layernorm the sub-vector we've been given if we're not using the full model size. This is interesting because RMSNorm puts a vector of length $d$ onto a hypersphere of radius $\sqrt{d}$ which means that while the embeddings of the largest model exist on a hypersphere of the aforementioned size, for each number of layers $i\in\mathbb{N}$ s.t. $0 < i \leq$ `config.model_count` the embeddings are placed onto a hypersphere of radius $\sqrt{\frac{d}{s^i}}$ where $s=$`config.split`. I'm not sure yet exactly how to interpret this concatenation of vectors geometrically. When you combine the entries of two hypserspheres to make a larger hypserspheres, what happens to the feature groupings on the surface of the smaller hyperspheres? I presume there are some type of interaction effects or something. 

In [6]:
class RMSNorm(torch.nn.Module):
    """
    Implements the RMS Normalization (Root Mean Square Normalization) layer.
    RMSNorm is a variant of layer normalization that normalizes the activations
    of the previous layer based on their root mean square value.

    Parameters:
    - dim (int): The dimension of the input features the normalization is applied to.
    - eps (float): A small value added to the denominator for numerical stability. Default is 1e-6.
    - add_unit_offset (bool): If True, adds a unit (1) to the learned scaling coefficient, effectively
      starting with no scaling. If False, the scaling coefficient starts from zero. Default is True.
    """

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        #add_unit_offset: bool = True,
    ):
        super().__init__() 
        self.eps = eps  # Small epsilon value for numerical stability since you can't divide by 0
        #self.add_unit_offset = add_unit_offset  # Flag to determine if a unit should be added to the weight
        
        # Initialize the weight parameter with zeros, which will be learned during training.
        # The shape of the weight is [dim], meaning one weight per feature dimension.
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        """
        Private helper function to normalize the input tensor.

        Parameters:
        - x (Tensor): The input tensor to normalize.

        Returns:
        - Tensor: The normalized tensor.
        """
        # Calculate the root mean square value for each feature (across the last dimension),
        # then use reciprocal square root (rsqrt) for normalization.
        # Add self.eps to the denominator for numerical stability.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor, model: int = 0) -> torch.Tensor:
        """
        Forward pass of the RMSNorm layer

        Parameters:
        - x (Tensor): The input tensor to normalize.
        - model (int): the index indicating the model being used in this layer. used for splicing self.weight

        Returns:
        - output: The normalized and scaled tensor.
        """
        global verbose
        if verbose: 
            print("------------- RMSNorm.forward() ------------")
            print(f"x: {x.shape}\n{x}")
            
        # Normalize the input tensor using the _norm function and ensure the data type matches the input.
        x = self._norm(x.float()).type_as(x)
        if verbose: print(f"normed x: {x.shape}\n{x}")
        
        # grabbing x's dimension to use for splicing
        dim = x.shape[-1]
        
        # calculating skip for our splice
        skip = model * dim
        if verbose: 
            print(f"dim: {dim}")
            print(f"skip: {skip}")
        
        # scale the normalized tensor by (1 + self.weight), which effectively starts with no scaling
        spliced_scale = self.weight[skip:skip + dim]
        output = x * (1 + spliced_scale)
        if verbose:
            print(f"spliced scale: {spliced_scale.shape}\n{spliced_scale}")
            print(f"scaled normed x: {output.shape}\n{output}")
            print("------------- END RMSNorm.forward() ------------")
                          
        return output

The following cell was designed to help you visualize what's happening with RMSNorm's splicing. With RMSNorm we'll only have to think about doing this with individual tensors, but with future methods like MLP and MQA we'll have to create an entirely separate forward method used during training that deals with tuples of tensors. The thing to pay attention to here is the size of the scale weights. scale_weights' entries are 0's because we've not yet undergone training

In [7]:
# Testing our RMSNorm's forward()
verbose = True
print("--------- Micro Hyperparameters -------")
hold = config.hidden_size
config.hidden_size = 4
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the big model |-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,2,config.hidden_size)
print(f"x: {x.shape}\n{x}")
norm = RMSNorm(config.hidden_size)
y = norm(x)
print(f"y: {y.shape}\n{y}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the first sub-model |-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,2,config.hidden_size//2)
print(f"x: {x.shape}\n{x}")
norm = RMSNorm(config.hidden_size)
y = norm(x)
print(f"y: {y.shape}\n{y}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the second sub-model |-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,2,config.hidden_size//2)
print(f"x: {x.shape}\n{x}")
norm = RMSNorm(config.hidden_size)
y = norm(x, model=1)
print(f"y: {y.shape}\n{y}")

verbose = False
print("---------- RESET CONFIG --------")
config.hidden_size = hold
print("model_count: ", config.model_count)

--------- Micro Hyperparameters -------
model_count:  [1, 2, 4]
model_dim_list:  [4, 2, 1]
|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the big model |-|-|-|-|-|-|-|-|-|-|-|-
x: torch.Size([1, 2, 4])
tensor([[[0.5476, 0.4260, 0.2188, 0.1814],
         [0.3320, 0.7403, 0.4972, 0.3718]]])
------------- RMSNorm.forward() ------------
x: torch.Size([1, 2, 4])
tensor([[[0.5476, 0.4260, 0.2188, 0.1814],
         [0.3320, 0.7403, 0.4972, 0.3718]]])
normed x: torch.Size([1, 2, 4])
tensor([[[1.4608, 1.1363, 0.5837, 0.4840],
         [0.6500, 1.4493, 0.9733, 0.7279]]])
dim: 4
skip: 0
spliced scale: torch.Size([4])
tensor([0., 0., 0., 0.], grad_fn=<SliceBackward0>)
scaled normed x: torch.Size([1, 2, 4])
tensor([[[1.4608, 1.1363, 0.5837, 0.4840],
         [0.6500, 1.4493, 0.9733, 0.7279]]], grad_fn=<MulBackward0>)
------------- END RMSNorm.forward() ------------
y: torch.Size([1, 2, 4])
tensor([[[1.4608, 1.1363, 0.5837, 0.4840],
         [0.6500, 1.4493, 0.9733, 0.7279]]], grad_fn=<MulBackward

# Multi-Layer Perceptron


<p align="center">
<img src="./images/ffwd.jpeg" width="512"/>
</p>

In [8]:
class MLP(nn.Module):
    """
    This class implements a multi-layer perceptron with a GeGLU gating mechanism. The GeGLU
    activation combines a standard GeLU activation with a learned gating mechanism, enabling
    the network to control the flow of information more dynamically.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        dropout: float = 0.1,
    ):
        """
        Initializes the GemmaMLP module.

        Parameters:
            hidden_size (int): The size of the input and output tensors.
            intermediate_size (int): The size of the tensor after the initial transformation
                                     and before the gating and final projection. This is typically
                                     larger than the hidden size to allow for a richer representation.
            dropout (float): the dropout rate to use during training in forwardTuple()
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        assert intermediate_size % hidden_size == 0
        self.intermediate_multiplier = intermediate_size // hidden_size

        # Linear transformation for the gating mechanism, projecting input to an intermediate size.
        self.Wgate = nn.Parameter(torch.Tensor(hidden_size, intermediate_size))
        self.Bgate = nn.Parameter(torch.Tensor(intermediate_size))

        # Linear transformation for the input tensor, also projecting to the intermediate size but
        # intended for element-wise multiplication with the gated output.
        self.Wup = nn.Parameter(torch.Tensor(hidden_size, intermediate_size))
        self.Bup = nn.Parameter(torch.Tensor(intermediate_size))

        # Linear transformation to project the gated and combined tensor back to the original
        # hidden size, completing the MLP structure.
        self.Wdown = nn.Parameter(torch.Tensor(intermediate_size, hidden_size))
        self.Bdown = nn.Parameter(torch.Tensor(hidden_size))

        # Initialize weights with uniform distribution
        # For gate & up, where in_features is hidden_size
        limit_gateup = 1 / np.sqrt(hidden_size)
        nn.init.uniform_(self.Wgate, -limit_gateup, limit_gateup)
        nn.init.uniform_(self.Bgate, -limit_gateup, limit_gateup)
        nn.init.uniform_(self.Wup, -limit_gateup, limit_gateup)
        nn.init.uniform_(self.Bup, -limit_gateup, limit_gateup)
        
        # For down, where in_features is intermediate_size
        limit_down = 1 / np.sqrt(intermediate_size)
        nn.init.uniform_(self.Wdown, -limit_down, limit_down)
        nn.init.uniform_(self.Bdown, -limit_down, limit_down)
        
        # defining our dropout for training in forwardTuple()
        self.drop = nn.Dropout(dropout)

    def forwardTensor(self, x, model:int=0):
        """
        Defines the forward pass of the MLP module during inference.

        Parameters:
            x (Tensor): The input tensor to the MLP. 
                        shape (batch size, sequence length, hidden dimension) where hidden dimension changes by which model was used
            model (int): the indicator of which model we're using. 
                        used in calculating our skip length for splicing. 
                        defaults to the equivalent of what's used in MatFormer+, meaning no skip, aka we use the top-left-most splice

        Returns:
            Tensor: The output tensor after applying the GeGLU gating mechanism and the MLP transformations.
        """
        global verbose
        if verbose: 
            print("------------- MLP.forwardTensor() ------------")
            print(f"x: {x.shape}\n{x}")
            
        # figuring out how we should do our splicing
        d_dim = x.shape[-1]
        d_skip = model * d_dim
        i_dim = d_dim * self.intermediate_multiplier
        i_skip = model * i_dim
        if verbose: 
            print(f"d_dim: {d_dim}")
            print(f"d_skip: {d_skip}")
            print(f"i_dim: {i_dim}")
            print(f"i_skip: {i_skip}")
        
        # Applies linear transformation for gating.
        Wgate = self.Wgate[d_skip:d_skip + d_dim, i_skip:i_skip + i_dim]
        Bgate = self.Bgate[i_skip:i_skip + i_dim]
        Xgate = x @ Wgate + Bgate
        if verbose: 
            print(f"Wgate: {self.Wgate.shape}\n{self.Wgate}")
            print(f"Wgate spliced: {Wgate.shape}\n{Wgate}")
            print(f"Bgate: {self.Bgate.shape}\n{self.Bgate}")
            print(f"Bgate spliced: {Bgate.shape}\n{Bgate}")
            print(f"Xgate: {Xgate.shape}\n{Xgate}")

        # Applies GeLU activation to the gate, introducing non-linearity and enabling the gating mechanism.
        Xgate = F.gelu(Xgate)
        if verbose: print(f"GeLU'ed Xgate: {Xgate.shape}\n{Xgate}")

        # Applies another linear transformation to the input tensor for subsequent combination with the gate.
        Wup = self.Wup[d_skip:d_skip + d_dim, i_skip:i_skip + i_dim]
        Bup = self.Bup[i_skip:i_skip + i_dim]
        Xup = x @ Wup + Bup
        if verbose: 
            print(f"Wup: {self.Wup.shape}\n{self.Wup}")
            print(f"Wup spliced: {Wup.shape}\n{Wup}")
            print(f"Bup: {self.Bup.shape}\n{self.Bup}")
            print(f"Bup spliced: {Bup.shape}\n{Bup}")
            print(f"Xup: {Xup.shape}\n{Xup}")

        # Element-wise multiplication of the gated tensor with the transformed input tensor, modulating
        # the input based on the gate's activation.
        Xfuse = Xgate * Xup
        if verbose: print(f"Xfuse: {Xfuse.shape}\n{Xfuse}")

        # Applies the final linear transformation to project the modulated tensor back to the hidden size.
        Wdown = self.Wdown[i_skip:i_skip + i_dim, d_skip:d_skip + d_dim]
        Bdown = self.Bdown[d_skip:d_skip + d_dim]
        outputs = Xfuse @ Wdown + Bdown
        if verbose: 
            print(f"Wdown: {self.Wdown.shape}\n{self.Wdown}")
            print(f"Wdown spliced: {Wdown.shape}\n{Wdown}")
            print(f"Bdown: {self.Bdown.shape}\n{self.Bdown}")
            print(f"Bdown spliced: {Bdown.shape}\n{Bdown}")
            print(f"outputs: {outputs.shape}\n{outputs}") 
            print("------------- END MLP.forwardTensor() ------------")

        # Returns the final output tensor of the MLP, after gating and modulation.
        return outputs

    def forwardTuple(self, x, drop_bool: bool = True):
        """
        Defines the forward pass of the MLP module during training.

        Parameters:
            x (Tuple[Tuple[Tensor]]): 
                The input tuple of tuples of tensors to the MLP. 
                first tuple is of length config.levels and second layer of tuples have lengths of config.model_count
                tensors are shape (batch size, sequence length, hidden dimension) where hidden dimension changes by which model was used

        Returns:
            Tuple[Tuple[Tensor]]: 
                The output tuple of tuples of tensors after applying the GeGLU gating mechanism and the MLP transformations.
        """
        global verbose
        if verbose: 
            print("------------- MLP.forwardTuple() ------------")
            print(f"x: {x}")

        # if we had sent through the config we could've just grabbed these values from there but too late now
        num_levels = len(x)
        models_per_level = [len(x[i]) for i in range(num_levels)]
        if verbose: 
            print(f"num_levels: {num_levels}")
            print(f"models_per_level: {models_per_level}")
        
        out = ()
        for i in range(num_levels):
            if verbose: print(f"i: {i}")
            
            out_lvl = ()
            for j in range(models_per_level[i]):
                if verbose: print(f"j: {j}")

                output = self.forwardTensor(x[i][j], model=j)
                if verbose: print(f"forwardTensor() output: {output.shape}\n{output}")
                    
                out_lvl += (self.drop(output),) if drop_bool else (output,)

            # pretty sure i have to save & store everything without overwriting to prevent in-place arguments. so annoying
            if verbose: print(f"out_lvl: {out_lvl}")
            out += (out_lvl,)
        
        if verbose:
            print(f"out: {out}")
            print("------------- END MLP.forwardTuple() ------------")
        return out
        
    def forward(self, x, model=0, drop_bool = True):
        train = True if type(x) == tuple else False
        print(f"---------- MLP Input: {'Tuple' if train else 'torch.Tensor'} ------------")
        return self.forwardTuple(x, drop_bool) if train else self.forwardTensor(x, model)

The following two cells are designed to help you comprehend what's happening. If you walk through every single print statement and follow along even down to watching what happens to each weight, you'll be able to clearly see what's happening with the odd splicing behavior. In order to make this somewhat feasible, I've set very small matrices for these examples. However I will admit it is still inevitably a pain, which is why I included the drawings above.

In [9]:
# Testing our MLP's forwardTensor()
verbose = True
print("--------- Micro Hyperparameters -------")
hold = config.hidden_size
config.hidden_size = 4
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the big model |-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,2,4)
print(f"x: {x.shape}\n{x}")
mlp = MLP(4,8)
y = mlp(x)
print(f"y: {y.shape}\n{y}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the first sub-model |-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,2,2)
print(f"x: {x.shape}\n{x}")
mlp = MLP(4,8)
y = mlp(x)
print(f"y: {y.shape}\n{y}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the second sub-model |-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,2,2)
print(f"x: {x.shape}\n{x}")
mlp = MLP(4,8)
y = mlp(x, model=1)
print(f"y: {y.shape}\n{y}")

verbose = False
print("---------- RESET CONFIG --------")
config.hidden_size = hold
print("model_count: ", config.model_count)

--------- Micro Hyperparameters -------
model_count:  [1, 2, 4]
model_dim_list:  [4, 2, 1]
|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the big model |-|-|-|-|-|-|-|-|-|-|-|-
x: torch.Size([1, 2, 4])
tensor([[[0.8357, 0.6235, 0.2187, 0.8647],
         [0.0509, 0.1876, 0.5219, 0.2856]]])
---------- MLP Input: torch.Tensor ------------
------------- MLP.forwardTensor() ------------
x: torch.Size([1, 2, 4])
tensor([[[0.8357, 0.6235, 0.2187, 0.8647],
         [0.0509, 0.1876, 0.5219, 0.2856]]])
d_dim: 4
d_skip: 0
i_dim: 8
i_skip: 0
Wgate: torch.Size([4, 8])
Parameter containing:
tensor([[-0.3891,  0.1615, -0.3803,  0.0968, -0.3116,  0.0135, -0.3844, -0.0793],
        [-0.4522, -0.0294, -0.4864, -0.2767,  0.0902,  0.2481, -0.2995, -0.4016],
        [-0.3055,  0.2483,  0.0267, -0.2415,  0.2368, -0.4932,  0.4063, -0.0098],
        [ 0.3214, -0.3187,  0.0295, -0.3992, -0.1088, -0.3484, -0.0786,  0.1379]],
       requires_grad=True)
Wgate spliced: torch.Size([4, 8])
tensor([[-0.3891,  0.161

In [10]:
# Testing our MLP's forwardTuple()
verbose = True
print("--------- Micro Hyperparameters -------")
hold1, hold2 = config.hidden_size, config.levels
config.hidden_size = 4
config.levels = 2
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)

mlp = MLP(4,8)
x = ((torch.randn((1,2,4)),),
     (torch.randn((1,2,2)),torch.randn((1,2,2)))
    )
print(f"x: {x}")
out = mlp(x)
print(f"out: {out}")

verbose = False
print("---------- RESET CONFIG --------")
config.hidden_size = hold1
config.levels = hold2
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)

--------- Micro Hyperparameters -------
model_count:  [1, 2]
model_dim_list:  [4, 2]
x: ((tensor([[[ 0.2583,  0.1370, -1.1712,  1.6180],
         [-0.6207, -0.9031,  1.2158,  0.2890]]]),), (tensor([[[-0.1530, -1.2461],
         [ 1.2806, -0.1644]]]), tensor([[[ 0.2353, -1.0113],
         [-1.0704, -0.7673]]])))
---------- MLP Input: Tuple ------------
------------- MLP.forwardTuple() ------------
x: ((tensor([[[ 0.2583,  0.1370, -1.1712,  1.6180],
         [-0.6207, -0.9031,  1.2158,  0.2890]]]),), (tensor([[[-0.1530, -1.2461],
         [ 1.2806, -0.1644]]]), tensor([[[ 0.2353, -1.0113],
         [-1.0704, -0.7673]]])))
num_levels: 2
models_per_level: [1, 2]
i: 0
j: 0
------------- MLP.forwardTensor() ------------
x: torch.Size([1, 2, 4])
tensor([[[ 0.2583,  0.1370, -1.1712,  1.6180],
         [-0.6207, -0.9031,  1.2158,  0.2890]]])
d_dim: 4
d_skip: 0
i_dim: 8
i_skip: 0
Wgate: torch.Size([4, 8])
Parameter containing:
tensor([[ 0.1863, -0.2803,  0.4625, -0.4580,  0.3121,  0.0467, -0.079

# Attention

To subset the attention heads, we have to not only splice according to the model's embedding dimension but also take into account new smaller head sizes and how they're spaced throughout the matrix. I'm assuming you know how self-attention works well enough to look at this weight matrix and get the idea

<p align="center">
<img src="./images/sa.jpeg" width="512"/>
</p>

then we've gotta concatenate the outputs of each head

<p align="center">
<img src="./images/mha_concat.jpeg" width="512"/>
</p>

and after that linearly project them

<p align="center">
<img src="./images/mha_proj.jpeg" width="512"/>
</p>

this is the place where our splicing gets conceptually annoying. instead of just grabbing the matrix in the upper corner, because of the way attention head output concatenation works we actually need to skip over certain parts of the linear projection matrix and then concatenate them together in order to use them. Here's an example of what the matrix multiplication looks like. on the left is a simplified version of the concatenated attention heads where i just showed it as a matrix rather than a tensor, and then on the right is the actual projection matrix. notice how the numbers in the pink output matrix look similar to the first column of the purple output matrix with a positive number, its negative, and then a smaller positive number; that's the self-similarity in action. the yellow arrows point to the parts that get skipped over. obviously this would look a lot uglier with bigger matrices & incorporating the blue/green layer

<p align="center">
<img src="./images/mha_proj_matmul.jpeg" width="512"/>
</p>

In [11]:
class MultiQueryAttention(nn.Module):
    """
    Implements Multi-Query Attention which supports a distinct number of attention heads for queries and key-values (KV).
    In the case where the same number of queries and key-values are used, this implemenation is equivalent to regular Multi-Head Attention.  
    """
    
    def __init__(self, config: Config):
        super().__init__()

        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        
        # Determines the number of query heads associated with each KV head.
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.hidden_size = config.hidden_size
        self.head_dim = config.head_dim
        self.theta = config.rope_theta

        # Calculates the total size for all query projections.
        self.q_size = self.num_heads * self.head_dim
        # Calculates the total size for all key and value projections.
        self.kv_size = self.num_kv_heads * self.head_dim
        
        # Initialize our learnable matrices
        # the linear projection layer for queries, keys, and values
        # no real reason why we're creating one matrix instead of separate ones. cleaner model summary view?
        self.Wqkv = nn.Parameter(torch.Tensor(self.hidden_size,
                                              (self.num_heads + 2 * self.num_kv_heads) * self.head_dim))
        # the output projection layer, mapping the concatenated attention outputs back to the hidden size.
        self.Wo = nn.Parameter(torch.Tensor(self.num_heads * self.head_dim, self.hidden_size))
        
        # Initialize weights with uniform distribution
        # For qkv_proj, where in_features is hidden_size
        limit_Wqkv = 1 / np.sqrt(self.hidden_size)
        nn.init.uniform_(self.Wqkv, -limit_Wqkv, limit_Wqkv)
        # for o_proj, where in_features is self.num_heads * self.head_dim
        limit_Wo = 1 / np.sqrt(self.num_heads * self.head_dim)
        nn.init.uniform_(self.Wo, -limit_Wo, limit_Wo)
        
        # for our attention mask we'll use very large negative values to prevent attending to certain tokens
        mask_negatives = torch.full((1, 1, config.max_position_embeddings, config.max_position_embeddings),
                                 -2.3819763e38).to(torch.float)
        # then we'll replace the lower triangular ones with 0's to allow attention to see past tokens
        mask = torch.triu(mask_negatives, diagonal=1).to(config.device)
        # to define self.mask as a tensor that shouldn't undergo gradient descent
        self.register_buffer('mask', mask)
        
        # defining our dropout
        self.drop = nn.Dropout(config.dropout)

    def forwardTensor(self,
                      x: torch.Tensor,
                      model: int = 0,
                     ) -> torch.Tensor:
        """
        Inputs:
            x (torch.Tensor): Te input tensor to the attention mechanism.
                        shape (batch_size, input_len, hidden_size)
            model (int): the indicator of which model we're using. 
                        used in calculating our skip length for splicing. 
                        defaults to the equivalent of what's used in MatFormer+, meaning no skip, aka we use the top-left-most splice
        
        Returns:
            Tensor: The output tensor after applying the attention mechanism
        """
        global verbose
        if verbose: print("----------------- MultiQueryAttention.forwardTensor() --------------------")
        
        # Ensures the input tensor is 3-dimensional (batch_size, input_len, hidden_size).
        x_shape = x.shape
        assert len(x_shape) == 3
        if verbose: print(f"x shape: {x_shape}")

        # Extracts input sequence length and embedding dimension length from the hidden states tensor.
        batch_size, input_len, d_dim = x_shape
        
        # figuring out how we should do our splicing
        # first along the embedding dimension
        d_skip = model * d_dim  # the size of our skip along the model's embedding dimension
        if verbose: print(f"d_skip: {d_skip}")
        
        # then for splicing along the head sizes dimension
        index = config.model_dim_list.index(d_dim)
        models_in_this_level = config.model_count[index] # how many models are in this level
        h_dim = config.head_dim_list[index] # the head dimension size of this model in this level
        h_skip = model * h_dim # the size of our skip along the head dimension
        if verbose: 
            print(f"models_in_this_level: {models_in_this_level}")
            print(f"h_dim: {h_dim}")
            print(f"h_skip: {h_skip}")

        # Splits the Wqkv tensor into separate tensors for queries, keys, and values based on their respective sizes.
        if verbose: print(f"self.Wqkv: {self.Wqkv.shape}\n{self.Wqkv}")
        Wq, Wk, Wv = self.Wqkv.split([self.q_size,
                                      self.kv_size,
                                      self.kv_size],dim=-1)
        if verbose: 
            print(f"Wq: {Wq.shape}\n{Wq}")
            print(f"Wk: {Wk.shape}\n{Wk}")
            print(f"Wv: {Wv.shape}\n{Wv}")
        
        # splicing to get our correct weight matrices for each respective head
        # d_dim is relatively self-explanatory
        # i*self.head_dim is bc we initialized one single q, k, and v matrix for all heads so we have to
        # iterate through said matrix to get to the correct head
        Wq = torch.cat([Wq[d_skip:d_skip + d_dim,\
                               i*self.head_dim + h_skip:i*self.head_dim + h_skip + h_dim] \
                               for i in range(self.num_heads)], dim=1)
        Wk = torch.cat([Wk[d_skip:d_skip + d_dim,\
                               i*self.head_dim + h_skip:i*self.head_dim + h_skip + h_dim] \
                               for i in range(self.num_kv_heads)], dim=1)
        Wv = torch.cat([Wv[d_skip:d_skip + d_dim,\
                               i*self.head_dim + h_skip:i*self.head_dim + h_skip + h_dim] \
                               for i in range(self.num_kv_heads)], dim=1)
        if verbose:
            print(f"Wq spliced: {Wq.shape}\n{Wq}")
            print(f"Wk spliced: {Wk.shape}\n{Wk}")
            print(f"Wv spliced: {Wv.shape}\n{Wv}")
        
        # this needs to be size (d_dim, (self.num_heads + 2 * self.num_kv_heads) * h_dim) aka (32,24)
        # recombine the spliced Wq Wk and Wv. Now they're the right size for matmul against x
        Wqkv_spliced = torch.cat((Wq, Wk, Wv), dim=-1)
        if verbose:
            print(f"Wqkv_spliced: {Wqkv_spliced.shape}\n{Wqkv_spliced}")
        

        # finally we can project x to get our queries, keys and values
        xqkv = x @ Wqkv_spliced
        if verbose: print(f"xqkv: {xqkv.shape}\n{xqkv}")
            
        # Splits the combined Xqkv tensor into separate tensors for queries (xq), keys (xk), and values (xv) based on their respective sizes.
        xq, xk, xv = xqkv.split([self.q_size // models_in_this_level,
                                 self.kv_size // models_in_this_level,
                                 self.kv_size // models_in_this_level],dim=-1)
        if verbose:
            print(f"xq: {xq.shape}\n{xq}")
            print(f"xk: {xk.shape}\n{xk}")
            print(f"xv: {xv.shape}\n{xv}")

        # Reshapes each of the Q, K, and V tensors to separate the heads and align the dimensions for attention operations.
        xq = xq.view(batch_size, input_len, self.num_heads, h_dim)#, self.head_dim)
        xk = xk.view(batch_size, input_len, self.num_kv_heads, h_dim)#, self.head_dim)
        xv = xv.view(batch_size, input_len, self.num_kv_heads, h_dim)#, self.head_dim)
        if verbose:
            print(f"xq reshaped: {xq.shape}\n{xq}")
            print(f"xk reshaped: {xk.shape}\n{xk}")
            print(f"xv reshaped: {xv.shape}\n{xv}")

        # Applies rotary positional embeddings to queries and keys to incorporate positional information.
        xq = apply_rotary_emb(xq, h_dim, self.theta)#self.head_dim
        xk = apply_rotary_emb(xk, h_dim, self.theta)#self.head_dim
        # is the differring head dimension going to mess with RoPE? Not sure
        if verbose:
            print(f"rotated xq: {xq.shape}\n{xq}")
            print(f"rotated xk: {xk.shape}\n{xk}")

        # If the number of KV heads is different from the number of query heads, adjusts keys and values to match the query heads count.
        if self.num_kv_heads != self.num_heads:
            # [batch_size, input_len, n_local_heads, head_dim]
            xk = torch.repeat_interleave(xk, self.num_queries_per_kv, dim=2)
            xv = torch.repeat_interleave(xv, self.num_queries_per_kv, dim=2)
            if verbose:
                print(f"repeat_interleaved xk: {xk.shape}\n{xk}")
                print(f"repeat_interleaved xv: {xv.shape}\n{xv}")

        # Transposes Q, K, and V tensors to align them for the batch matrix multiplication in attention calculation.
        # [batch_size, n_local_heads, input_len, head_dim]
        q = xq.transpose(1, 2)
        k = xk.transpose(1, 2)
        v = xv.transpose(1, 2)
        if verbose:
            print(f"transposed xq: {q.shape}\n{q}")
            print(f"transposed xk: {k.shape}\n{k}")
            print(f"transposed xv: {v.shape}\n{v}")

        # Calculates attention scores by performing a batch matrix multiplication between queries and keys, followed by scaling.
        # [batch_size, n_local_heads, input_len, input_len]
        scores = torch.matmul(q, k.transpose(2, 3)) * h_dim**-0.5#self.scaling
        if verbose: print(f"scores: {scores.shape}\n{scores}")
        
        # Applies the lower-triangular mask to the attention scores
        if verbose: print(f"mask: {self.mask[...,:input_len, :input_len].shape}\n{self.mask[...,:input_len, :input_len]}")
        scores = scores + self.mask[...,:input_len, :input_len] # make sure mask is the correct size. input_len <= max_seq_len
        if verbose: print(f"masked scores: {scores.shape}\n{scores}")

        # Applies softmax to the scores to obtain attention probabilities
        scores = F.softmax(scores, dim=-1)
        if verbose: print(f"softmaxed scores: {scores.shape}\n{scores}")
        
        # Computes the weighted sum of values based on the attention scores to obtain the output of the attention mechanism.
        # [batch_size, n_local_heads, input_len, head_dim]
        attention = torch.matmul(scores, v)
        if verbose: print(f"attention: {attention.shape}\n{attention}")

        # Reshapes the attention output to match the expected output dimensions, combining the heads back into the hidden dimension.
        # [batch_size, input_len, hidden_dim]
        attention = attention.transpose(1, 2).contiguous().view(batch_size, input_len, -1)
        if verbose: print(f"reshaped attention: {attention.shape}\n{attention}")

        # Splice the output projection
        Wo = torch.cat([self.Wo[i*self.head_dim + h_skip:i*self.head_dim + h_skip + h_dim,\
                                d_skip:d_skip + d_dim,\
                               ] for i in range(self.num_heads)], dim=0)
        if verbose: 
            print(f"self.Wo: {self.Wo.shape}\n{self.Wo}")
            print(f"spliced Wo: {Wo.shape}\n{Wo}")
            
        # Applies the final linear projection to the attention output, mapping it back to the hidden size dimension.
        output = attention @ Wo
        if verbose: 
            print(f"projected output: {output.shape}\n{output}")
            print("----------------- END MultiQueryAttention.forwardTensor() --------------------")
            
        return output

    def forwardTuple(self,
                     x: Tuple[Tuple[torch.Tensor]],
                     drop_bool: bool = True
                    ) -> torch.Tensor:
        """
        Defines the forward pass of the Attention module during training.

        Parameters:
            x (Tuple[Tuple[Tensor]]): 
                The input tuple of tuples of tensors 
                first tuple is of length config.levels and second layer of tuples have lengths of config.model_count
                tensors are shape (batch size, sequence length, hidden dimension) where hidden dimension changes by which model was used

        Returns:
            Tuple[Tuple[Tensor]]: 
                The output tuple of tuples of tensors after applying the MQA mechanism
        """
        global verbose
        if verbose: 
            print("------------- MultiQueryAttention.forwardTuple() ------------")
            print(f"x: {x}")
            
        # forwardTuple() should only be used during training, so we assert input_len == max_position_embeddings
        input_len = x[0][0].shape[1]
        if verbose: print(f"input_len: {input_len}")
        assert input_len == config.max_position_embeddings

        # we could define these from the config but this way the method is more flexible to testing
        num_levels = len(x)
        models_per_level = [len(x[i]) for i in range(num_levels)]
        if verbose: 
            print(f"num_levels: {num_levels}")
            print(f"models_per_level: {models_per_level}")

        # the loop that iterates over levels, aka the different potential sizes of models
        out = ()
        for i in range(num_levels):
            if verbose: print(f"Level {i} from range({num_levels})")

            # now for the loop that iterates over models in this level
            out_lvl = ()
            for j in range(models_per_level[i]):
                if verbose: print(f"Model {j} from range({models_per_level[i]})")

                output = self.forwardTensor(x[i][j], model=j)
                if verbose: print(f"forwardTensor() output: {output.shape}\n{output}")
                
                out_lvl += (self.drop(output),) if drop_bool else (output,)
            
            out += (out_lvl,)
        
        if verbose:
            print(f"final output: {out}")
            print("------------- END MultiQueryAttention.forwardTuple() ------------")

        return out
        
    def forward(self, x, model=0, drop_bool = True):
        train = True if type(x) == tuple else False
        print(f"---------- Attention Input: {'Tuple' if train else 'torch.Tensor'} ------------")
        return self.forwardTuple(x, drop_bool) if train else self.forwardTensor(x, model)

And here are the detailed print statements for the attention mechanism

In [12]:
# Testing our Attention's forwardTensor()
verbose = True

print("--------- Micro Hyperparameters -------")
hold1, hold2, hold3, hold4 = config.hidden_size, config.num_attention_heads, config.head_dim, config.max_position_embeddings
config.hidden_size = 8
config.num_attention_heads = 2
config.head_dim = 4
config.max_position_embeddings = 3
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)
print("head_dim_list: ", config.head_dim_list)

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the big model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,3,8)
print(f"x: {x.shape}\n{x}")
att = MultiQueryAttention(config)
y = att(x)
print(f"y: {y.shape}\n{y}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the first sub-model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,3,4)
print(f"x: {x.shape}\n{x}")
att = MultiQueryAttention(config)
y = att(x)
print(f"y: {y.shape}\n{y}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the second sub-model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,3,4)
print(f"x: {x.shape}\n{x}")
att = MultiQueryAttention(config)
y = att(x, model=1)
print(f"y: {y.shape}\n{y}")

verbose = False
print("---------- RESET CONFIG --------")
config.hidden_size = hold1
config.num_attention_heads = hold2
config.head_dim = hold3
config.max_position_embeddings = hold4
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)
print("head_dim_list: ", config.head_dim_list)

--------- Micro Hyperparameters -------
model_count:  [1, 2, 4]
model_dim_list:  [8, 4, 2]
head_dim_list:  [4, 2, 1]
|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the big model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-
x: torch.Size([1, 3, 8])
tensor([[[0.0562, 0.9052, 0.6449, 0.7861, 0.4129, 0.1102, 0.5653, 0.5609],
         [0.0178, 0.8824, 0.6357, 0.0244, 0.5049, 0.4234, 0.2152, 0.0481],
         [0.8340, 0.2033, 0.0968, 0.9891, 0.6792, 0.4305, 0.8989, 0.7066]]])
---------- Attention Input: torch.Tensor ------------
----------------- MultiQueryAttention.forwardTensor() --------------------
x shape: torch.Size([1, 3, 8])
d_skip: 0
models_in_this_level: 1
h_dim: 4
h_skip: 0
self.Wqkv: torch.Size([8, 16])
Parameter containing:
tensor([[-0.0521,  0.3254,  0.3317,  0.0196, -0.0625, -0.1371, -0.2930,  0.0693,
          0.0161,  0.2415,  0.0518, -0.2021, -0.1875,  0.0807, -0.3144,  0.2513],
        [-0.0458,  0.1770, -0.1642, -0.3496,  0.1726, -0.3313, -0.2308,  0.1149,
         

In [13]:
# Testing our Attention's forwardTuple()
verbose = True

print("--------- Micro Hyperparameters -------")
hold1, hold2, hold3, hold4, hold5 = config.hidden_size, config.num_attention_heads, config.head_dim, config.levels, config.max_position_embeddings
config.hidden_size = 8
config.num_attention_heads = 2
config.head_dim = 4
config.levels = 2
config.max_position_embeddings = 3
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)
print("head_dim_list: ", config.head_dim_list)

att = MultiQueryAttention(config)
# we need to make sure to send in a tuple of the expected size. above we set hidden_size=4 and levels=2
x = ((torch.randn((1,3,8)),),(torch.randn((1,3,4)),torch.randn((1,3,4))))
print(f"x: {x}")
out = att(x)
print(f"out: {out}")

verbose = False
print("---------- RESET CONFIG --------")
config.hidden_size = hold1
config.num_attention_heads = hold2
config.head_dim = hold3
config.levels = hold4
config.max_position_embeddings = hold5
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)
print("head_dim_list: ", config.head_dim_list)

--------- Micro Hyperparameters -------
model_count:  [1, 2]
model_dim_list:  [8, 4]
head_dim_list:  [4, 2]
x: ((tensor([[[-0.3497, -2.9496, -0.4524,  1.9168, -0.0945, -0.4339, -0.2205,
           0.3369],
         [ 0.0089,  0.9981,  0.1863, -1.5402,  0.5119,  0.0959, -0.8294,
           1.6965],
         [ 1.3860, -1.5527,  0.2305, -2.2388,  0.9591, -0.5175,  0.9746,
           1.2292]]]),), (tensor([[[-0.8564, -0.5784,  0.6057, -0.8484],
         [ 1.2820, -0.1706,  1.0822,  0.0600],
         [ 0.1781, -0.9001,  0.9855, -0.2710]]]), tensor([[[ 3.0712,  0.1522, -0.1656,  0.7491],
         [ 0.8894,  0.9551, -0.5230, -0.2508],
         [-1.0178, -0.4089, -0.9741, -0.9321]]])))
---------- Attention Input: Tuple ------------
------------- MultiQueryAttention.forwardTuple() ------------
x: ((tensor([[[-0.3497, -2.9496, -0.4524,  1.9168, -0.0945, -0.4339, -0.2205,
           0.3369],
         [ 0.0089,  0.9981,  0.1863, -1.5402,  0.5119,  0.0959, -0.8294,
           1.6965],
         [ 1.

# Layer

nothing too interesting here besides the absurd amount of memory we're probably taking up with these tuples

In [14]:
class Layer(nn.Module):
    """
    A decoder layer that integrates the MultiQueryAttention and MLP. It includes
    normalization steps both before and after the attention mechanism to stabilize and accelerate training.
    """

    def __init__(self, config: Config):
        super().__init__()

        # Initializes the GemmaAttention mechanism with parameters from the config, enabling self-attention within the decoder layer.
        self.self_attn = MultiQueryAttention(config)
        
        # Initializes the GemmaMLP module, providing a non-linear transformation after the attention mechanism.
        self.mlp = MLP(
            # the hidden dimension of the model
            hidden_size = config.hidden_size,
            # the number of nodes in the center of the two feedforward layers
            intermediate_size = config.intermediate_size,
            # the % of neurons to set to 0 during training
            dropout = config.dropout,
        )
        
        # Applies RMSNorm normalization to the input of the decoder layer for stable training dynamics.
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps = config.rms_norm_eps)
        
        # Applies RMSNorm after the attention mechanism and before the MLP to ensure the output is well-conditioned for further processing.
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps = config.rms_norm_eps)

    def forwardTensor(self,
                # The input tensor to the decoder layer. shape (batch_size, input_len, hidden_size)
                x: torch.Tensor,
                model: int = 0,
                drop_bool: bool = False
                ) -> torch.Tensor:
        global verbose
        if verbose: print("----------------- Layer.forwardTensor() --------------------")
        
        # Self Attention Block
        # Stores the original input for use as a residual connection, aiding in mitigating the vanishing gradient problem
        residual_connection = x
        # Normalizes the input before processing by the attention mechanism.
        x = self.input_layernorm(x, model)
        # Processes the normalized input through the GemmaAttention mechanism
        x = self.self_attn(x, model, drop_bool)
        # The aforementioned residual connection
        x = residual_connection + x
        if verbose: print(f"x in layer after MQA & resid connection and before MLP:\n{x}")

        # MLP Block
        # Again, stores the output of the attention block for use as a residual connection before processing by the MLP.
        residual_connection = x
        # Normalizes the output of the attention block before passing it to the MLP, ensuring a stable input distribution.
        x = self.post_attention_layernorm(x, model)
        # Transforms the normalized attention output through the MLP, introducing additional non-linearity and capacity to the model.
        x = self.mlp(x, model, drop_bool)
        # Another residual connection
        x = residual_connection + x
        if verbose: 
            print(f"layer's final residual state:\n{x}")
            print("----------------- END Layer.forwardTensor() --------------------")

        return x

    def forwardTuple(self,
                     x: Tuple[Tuple[torch.Tensor]],
                    ) -> torch.Tensor:
        """
        Defines the forward pass of a decoder layer during training.

        Parameters:
            x (Tuple[Tuple[Tensor]]): 
                The input tuple of tuples of tensors 
                first tuple is of length config.levels and second layer of tuples have lengths of config.model_count
                tensors are shape (batch size, sequence length, hidden dimension) where hidden dimension changes by which model was used

        Returns:
            Tuple[Tuple[Tensor]]: 
                The output tuple of tuples of tensors after applying the decoder layer
        """
        global verbose
        if verbose: 
            print("------------- Layer.forwardTuple() ------------")
            print(f"x:\n{x}")
            
        # forwardTuple() should only be used during training, so we assert input_len == max_position_embeddings
        input_len = x[0][0].shape[1]
        if verbose: print(f"input_len: {input_len}")
        assert input_len == config.max_position_embeddings

        # we could define these from the config but this way the method is more flexible to testing
        num_levels = len(x)
        models_per_level = [len(x[i]) for i in range(num_levels)]
        if verbose: 
            print(f"num_levels: {num_levels}")
            print(f"models_per_level: {models_per_level}")

        # the loop that iterates over levels, aka the different potential sizes of models
        out = ()
        for i in range(num_levels):
            if verbose: print(f"Level {i} from range({num_levels})")

            # now for the loop that iterates over models in this level
            out_lvl = ()
            for j in range(models_per_level[i]):
                if verbose: print(f"Model {j} from range({models_per_level[i]})")

                output = self.forwardTensor(x[i][j], model = j, drop_bool = True)
                if verbose: print(f"forwardTensor() output: {output.shape}\n{output}")
                
                out_lvl += (output,)
            
            out += (out_lvl,)
        
        if verbose:
            print(f"final output: {out}")
            print("------------- END Layer.forwardTuple() ------------")

        return out
        
    def forward(self, x, model=0):
        train = True if type(x) == tuple else False
        print(f"---------- Layer Input: {'Tuple' if train else 'torch.Tensor'} ------------")
        return self.forwardTuple(x) if train else self.forwardTensor(x, model)

In [15]:
# Testing our Layer's forwardTensor()
verbose = True

print("--------- Micro Hyperparameters -------")
hold1, hold2, hold3, hold4 = config.hidden_size, config.num_attention_heads, config.head_dim, config.max_position_embeddings
config.hidden_size = 8
config.num_attention_heads = 2
config.head_dim = 4
config.max_position_embeddings = 3
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)
print("head_dim_list: ", config.head_dim_list)

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the big model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,3,config.hidden_size)
print(f"x: {x.shape}\n{x}")
layer = Layer(config)
y = layer(x)
print(f"y: {y.shape}\n{y}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the first sub-model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,3,config.hidden_size//config.split)
print(f"x: {x.shape}\n{x}")
layer = Layer(config)
y = layer(x)
print(f"y: {y.shape}\n{y}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the second sub-model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,3,config.hidden_size//config.split)
print(f"x: {x.shape}\n{x}")
layer = Layer(config)
y = layer(x, model=1)
print(f"y: {y.shape}\n{y}")

verbose = False
print("---------- RESET CONFIG --------")
config.hidden_size = hold1
config.num_attention_heads = hold2
config.head_dim = hold3
config.max_position_embeddings = hold4
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)
print("head_dim_list: ", config.head_dim_list)

--------- Micro Hyperparameters -------
model_count:  [1, 2, 4]
model_dim_list:  [8, 4, 2]
head_dim_list:  [4, 2, 1]
|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the big model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-
x: torch.Size([1, 3, 8])
tensor([[[0.9469, 0.5231, 0.5780, 0.0538, 0.4399, 0.5627, 0.7584, 0.9368],
         [0.0660, 0.3250, 0.0122, 0.8770, 0.5177, 0.9000, 0.7999, 0.7379],
         [0.6697, 0.2385, 0.2210, 0.2782, 0.0758, 0.9189, 0.5016, 0.0058]]])
---------- Layer Input: torch.Tensor ------------
----------------- Layer.forwardTensor() --------------------
------------- RMSNorm.forward() ------------
x: torch.Size([1, 3, 8])
tensor([[[0.9469, 0.5231, 0.5780, 0.0538, 0.4399, 0.5627, 0.7584, 0.9368],
         [0.0660, 0.3250, 0.0122, 0.8770, 0.5177, 0.9000, 0.7999, 0.7379],
         [0.6697, 0.2385, 0.2210, 0.2782, 0.0758, 0.9189, 0.5016, 0.0058]]])
normed x: torch.Size([1, 3, 8])
tensor([[[1.4377, 0.7942, 0.8776, 0.0817, 0.6679, 0.8545, 1.1516, 1.4223],
    

In [16]:
# Testing our Layer's forwardTuple()
verbose = True

print("--------- Micro Hyperparameters -------")
hold1, hold2, hold3, hold4, hold5 = config.hidden_size, config.num_attention_heads, config.head_dim, config.levels, config.max_position_embeddings
config.hidden_size = 8
config.num_attention_heads = 2
config.head_dim = 4
config.levels = 2
config.max_position_embeddings = 3
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)
print("head_dim_list: ", config.head_dim_list)

layer = Layer(config)
# we need to make sure to send in a tuple of the expected size. above we set hidden_size=4 and levels=2
x = ((torch.randn((1,3,8)),),(torch.randn((1,3,4)),torch.randn((1,3,4))))
print(f"x: {x}")
out = layer(x)
print(f"out: {out}")

verbose = False
print("---------- RESET CONFIG --------")
config.hidden_size = hold1
config.num_attention_heads = hold2
config.head_dim = hold3
config.levels = hold4
config.max_position_embeddings = hold5
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)
print("head_dim_list: ", config.head_dim_list)

--------- Micro Hyperparameters -------
model_count:  [1, 2]
model_dim_list:  [8, 4]
head_dim_list:  [4, 2]
x: ((tensor([[[ 2.2088,  0.2584,  0.7021, -0.9308,  0.1443,  1.9641,  1.0878,
           0.3879],
         [-0.2463, -1.1001,  0.0890,  0.6983,  0.2876,  0.4139, -2.8226,
           0.2855],
         [ 0.2658, -1.2915, -0.3992, -0.9281, -0.8413,  0.3240,  1.2216,
          -0.2044]]]),), (tensor([[[ 0.2293, -0.9171, -0.9802, -0.5067],
         [ 0.1680, -1.1694,  0.7582,  0.1348],
         [-1.9322, -0.7696, -1.3604, -0.6910]]]), tensor([[[ 2.0527,  1.9986,  1.9487, -0.8838],
         [-1.3443,  0.4112, -1.1328,  0.0875],
         [-2.1688, -1.0016, -0.4252, -0.2276]]])))
---------- Layer Input: Tuple ------------
------------- Layer.forwardTuple() ------------
x:
((tensor([[[ 2.2088,  0.2584,  0.7021, -0.9308,  0.1443,  1.9641,  1.0878,
           0.3879],
         [-0.2463, -1.1001,  0.0890,  0.6983,  0.2876,  0.4139, -2.8226,
           0.2855],
         [ 0.2658, -1.2915, -0.

# Output Layer

In [17]:
class OutputLayer(nn.Module):
    def __init__(self, embedding: torch.Tensor, config: Config):
        super().__init__()
        self.embedding = embedding
        self.v = config.vocab_size
        self.model_dim_list = config.model_dim_list

        # applies RMSNorm to the embedding matrix
        self.embedding_norm = RMSNorm(config.hidden_size,
                                      eps = config.rms_norm_eps)
        
        # Applies RMSNorm to the model's final residual state before we use the embedding matrix to get logits
        self.final_norm = RMSNorm(config.hidden_size,
                                  eps = config.rms_norm_eps)

    def forwardTensor(self, x, model=0):
        global verbose
        if verbose: 
            print("------------- OutputLayer.forwardTensor() ------------")
            print(f"x: {x.shape}\n{x}")

        # setting up our splicing logic
        d_i = x.shape[-1]
        skip = model * d_i
        if verbose:
            print(f"d_i: {d_i}")
            print(f"skip: {skip}")
            print(f"embedding: {self.embedding.shape}\n{self.embedding}")

        # splice out our embedding matrix according to what model we're using
        sliced_embed = self.embedding[:,skip:skip + d_i]
        if verbose: print(f"sliced_embed: {sliced_embed.shape}\n{sliced_embed}")

        # normalize our sliced embedding matrix
        normed_sliced_embed = self.embedding_norm(sliced_embed)
        if verbose: print(f"normed & sliced embedding: {normed_sliced_embed.shape}\n{normed_sliced_embed}")

        # normalize the residual state before the final linear layer
        x = self.final_norm(x, model)
        if verbose: print(f"normed x: {x.shape}\n{x}")

        # calculating the final output logits of the model
        logits = x @ normed_sliced_embed.t()
        if verbose: 
            print(f"final logits: {logits.shape}\n{logits}")
            print("------------- END OutputLayer.forwardTensor() ------------")

        return logits

    def forwardTuple(self, x):
        """
        Defines the forward pass of the final embedding classification layer during training.

        Parameters:
            x (Tuple[Tuple[Tensor]]): 
                The input tuple of tuples of tensors 
                first tuple is of length config.levels and second layer of tuples have lengths of config.model_count
                tensors are shape (batch size, sequence length, hidden dimension) where hidden dimension changes by which model was used

        Returns:
            output (Tuple[Tuple[Tensor]]): 
                The output tuple of tuples of tensors after applying the final embedding classification
        """
        global verbose
        if verbose: 
            print("------------- OutputLayer.forwardTuple() ------------")
            print(f"x:\n{x}")
            
        # forwardTuple() should only be used during training, so we assert input_len == max_position_embeddings
        assert x is tuple
        input_len = x[0][0].shape[1]
        if verbose: print(f"input_len: {input_len}")
        assert input_len == config.max_position_embeddings

        # we could define these from the config but this way the method is more flexible to testing
        num_levels = len(x)
        models_per_level = [len(x[i]) for i in range(num_levels)]
        if verbose: 
            print(f"num_levels: {num_levels}")
            print(f"models_per_level: {models_per_level}")

        # the loop that iterates over levels, aka the different potential sizes of models
        out = ()
        for i in range(num_levels):
            if verbose: print(f"Level {i} from range({num_levels})")

            # now for the loop that iterates over models in this level
            out_lvl = ()
            for j in range(models_per_level[i]):
                if verbose: print(f"Model {j} from range({models_per_level[i]})")

                output = self.forwardTensor(x[i][j], model = j)
                if verbose: print(f"forwardTensor() output: {output.shape}\n{output}")
                
                out_lvl += (output,)
            
            out += (out_lvl,)
        
        if verbose:
            print(f"final output: {out}")
            print("------------- END Layer.forwardTuple() ------------")
        
        return out
        
    def forward(self, x, model=0):
        train = True if type(x) == tuple else False
        print(f"---------- Layer Input: {'Tuple' if train else 'torch.Tensor'} ------------")
        return self.forwardTuple(x) if train else self.forwardTensor(x, model)

In [18]:
# Testing our OutputLayer's forwardTensor()
verbose = True

print("--------- Micro Hyperparameters -------")
hold1, hold2 = config.hidden_size, config.vocab_size
config.hidden_size = 4
config.vocab_size = 5
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)

embedding = torch.randn(config.vocab_size, config.hidden_size)
print(f"embedding: {embedding.shape}\n{embedding}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the big model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,3,config.hidden_size)
print(f"x: {x.shape}\n{x}")
layer = OutputLayer(embedding, config)
y = layer(x)
print(f"y: {y.shape}\n{y}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the first sub-model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,3,config.hidden_size//config.split)
print(f"x: {x.shape}\n{x}")
layer = OutputLayer(embedding, config)
y = layer(x)
print(f"y: {y.shape}\n{y}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the second sub-model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,3,config.hidden_size//config.split)
print(f"x: {x.shape}\n{x}")
layer = OutputLayer(embedding, config)
y = layer(x, model=1)
print(f"y: {y.shape}\n{y}")

verbose = False
print("---------- RESET CONFIG --------")
config.hidden_size = hold1
config.vocab_size = hold2
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)

--------- Micro Hyperparameters -------
model_count:  [1, 2, 4]
model_dim_list:  [4, 2, 1]
embedding: torch.Size([5, 4])
tensor([[-1.0187,  0.9947, -0.0253, -3.0200],
        [ 2.4933, -0.1411,  1.8727, -0.3439],
        [ 1.6053,  1.1763, -0.3581, -1.0166],
        [ 0.5760,  0.5348,  1.8814,  0.2625],
        [-1.8697, -1.8058,  1.7278, -0.0814]])
|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the big model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-
x: torch.Size([1, 3, 4])
tensor([[[0.0579, 0.5822, 0.7695, 0.3253],
         [0.5348, 0.7898, 0.7203, 0.9398],
         [0.5933, 0.6965, 0.8708, 0.9116]]])
---------- Layer Input: torch.Tensor ------------
------------- OutputLayer.forwardTensor() ------------
x: torch.Size([1, 3, 4])
tensor([[[0.0579, 0.5822, 0.7695, 0.3253],
         [0.5348, 0.7898, 0.7203, 0.9398],
         [0.5933, 0.6965, 0.8708, 0.9116]]])
d_i: 4
skip: 0
embedding: torch.Size([5, 4])
tensor([[-1.0187,  0.9947, -0.0253, -3.0200],
        [ 2.4933, -0.1411,  

In [19]:
# Testing our Layer's forwardTuple()
verbose = True

print("--------- Micro Hyperparameters -------")
hold1, hold2, hold3, hold4 = config.hidden_size, config.levels, config.max_position_embeddings, config.hidden_size
config.hidden_size = 4
config.levels = 2
config.max_position_embeddings = 3
config.vocab_size = 5
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)

embedding = torch.randn(config.vocab_size, config.hidden_size)
print(f"embedding: {embedding.shape}\n{embedding}")

layer = Layer(config)
# we need to make sure to send in a tuple of the expected size. above we set hidden_size=4 and levels=2
x = ((torch.randn((1,3,config.hidden_size)),),
     (torch.randn((1,3,config.hidden_size//config.split)),torch.randn((1,3,config.hidden_size//config.split))))
print(f"x: {x}")
out = layer(x)
print(f"out: {out}")

verbose = False
print("---------- RESET CONFIG --------")
config.hidden_size = hold1
config.levels = hold2
config.max_position_embeddings = hold3
config.vocab_size = hold4
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)

--------- Micro Hyperparameters -------
model_count:  [1, 2]
model_dim_list:  [4, 2]
embedding: torch.Size([5, 4])
tensor([[-0.8400, -0.5197, -1.2188,  0.5423],
        [-0.2075,  0.7968, -0.4593, -1.0467],
        [ 1.3413, -0.5771, -1.0626,  0.6517],
        [-0.7116, -1.1119, -0.8719,  1.4338],
        [-1.2388, -0.6028,  1.4139,  0.1831]])
x: ((tensor([[[-0.1462,  0.2696, -1.3195, -1.2945],
         [-0.9201, -1.2686, -1.1653,  2.0797],
         [-0.4185, -0.1431,  0.7092,  1.2014]]]),), (tensor([[[-0.2275, -0.4399],
         [ 2.3621,  0.2800],
         [-0.2841,  1.2598]]]), tensor([[[-1.0265, -1.4249],
         [ 0.9723,  1.2800],
         [ 1.3253, -1.7510]]])))
---------- Layer Input: Tuple ------------
------------- Layer.forwardTuple() ------------
x:
((tensor([[[-0.1462,  0.2696, -1.3195, -1.2945],
         [-0.9201, -1.2686, -1.1653,  2.0797],
         [-0.4185, -0.1431,  0.7092,  1.2014]]]),), (tensor([[[-0.2275, -0.4399],
         [ 2.3621,  0.2800],
         [-0.2841,  

# Loss Function

In [20]:
class FractalLoss(nn.Module):
    def __init__(self, config: Config):
        super().__init__()

        self.criterion = nn.CrossEntropyLoss()

    def forward(self, logits, target):
        """
        input: 
            - logits are a tuple of tuples of tensors each of shape [batch_size, max_seq_len, vocab_size]
            - target is a shape [batch_size, max_seq_len] tensor of the integer indices of the correct tokens
        output: a tensor containing a single float of the loss value
        """
        global verbose
        if verbose: 
            print("------------- FractalLoss.forward() ------------")
            print(f"logits:\n{logits}")
            
        assert type(x) == tuple # since this function should only be used during training
            
        # should only be used during training, so we assert input_len == max_position_embeddings
        b,t,v = logits[0][0].shape
        if verbose: print(f"b:{b}, t:{t}, v:{v}, b*t:{b*t}")
        assert t == config.max_position_embeddings
        
        # Calculate losses for each output and stack them. i apologize for the weird format instead of regular for loops
        loss = torch.stack([ # stacks across levels
                            torch.stack( # stacks across models in level
                                        [self.criterion(logits_ij.view(b*t, v), # reshapes for CELoss
                                                        target.view(b*t)) 
                                         for logits_ij in logits[i]] # iterates across models in level
                            ).sum() # sums across models in level
                            for i in range(len(logits))] # iterates across levels
                          ).sum() # sums across levels

        if verbose:
            print(f"final loss: {loss}")
            print("------------- END FractalLoss.forward() ------------")

        return loss

In [21]:
# Testing our FractalLoss
verbose = True

print("--------- Micro Hyperparameters -------")
hold1, hold2, hold3, hold4 = config.hidden_size, config.levels, config.max_position_embeddings, config.hidden_size
config.hidden_size = 4
config.levels = 2
config.max_position_embeddings = 3
config.vocab_size = 5
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)

embedding = torch.randn(config.vocab_size, config.hidden_size)
print(f"embedding: {embedding.shape}\n{embedding}")

loss = FractalLoss(config)
# we need to make sure to send in a tuple of the expected size. above we set hidden_size=4 and levels=2
logits = ((torch.randn((2,3,config.vocab_size)),),
     (torch.randn((2,3,config.vocab_size)),torch.randn((2,3,config.vocab_size))))
print(f"logits: {logits}")
target = torch.randint(config.vocab_size, (2,3)).unsqueeze(0)
print(f"target: {target}")
out = loss(logits, target)
print(f"out: {out}")

verbose = False
print("---------- RESET CONFIG --------")
config.hidden_size = hold1
config.levels = hold2
config.max_position_embeddings = hold3
config.vocab_size = hold4
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)

--------- Micro Hyperparameters -------
model_count:  [1, 2]
model_dim_list:  [4, 2]
embedding: torch.Size([5, 4])
tensor([[-1.3753, -0.7115, -1.8179, -0.5396],
        [ 0.2276,  1.7926, -1.4529,  1.0655],
        [ 0.5938,  0.5983,  0.4228,  0.5722],
        [ 1.0451,  0.0733, -0.4143, -0.4269],
        [ 0.0781, -0.5721, -0.1324, -0.7755]])
logits: ((tensor([[[-1.8138,  1.2353,  0.5865,  0.6241, -0.8431],
         [ 0.5524,  1.0414, -0.9325, -0.9413,  0.1517],
         [ 2.0922, -1.1598,  0.0593, -0.7795,  0.3292]],

        [[-0.6612,  1.0559,  0.7058,  0.1552,  0.9851],
         [-2.2625, -0.3402, -0.6354,  0.8113,  1.6704],
         [ 1.1679, -0.0254,  0.5453, -0.0048, -0.5645]]]),), (tensor([[[ 0.6266, -0.2598,  0.9878,  0.1268,  0.2202],
         [-2.1192, -1.1076, -1.2282, -0.7977, -1.2385],
         [ 0.7055, -0.5674,  1.9132,  0.1368,  1.4071]],

        [[-0.8021, -1.8081, -0.6775, -1.6985, -0.7555],
         [ 2.6726, -0.7442,  0.1797, -0.6538,  0.7061],
         [ 1.1845,

# ------------ BOOKMARK ----------------

# The Model itself

In [12]:
class FractalFormer(nn.Module):
    def __init__(self, config: Config, tokenizer: tokenizer):
        super().__init__()
        self.config = config
        self.tokenizer = tokenizer

        # hyperparameters
        self.hidden_size = config.hidden_size
        self.max_seq_len = config.max_position_embeddings
        self.head_dim = config.head_dim
        self.vocab_size = config.vocab_size

        ### FractalFormer-specific hyperparameters
        self.num_levels = config.levels # the number of levels for sub-models to exist on
        self.split = config.split # the number of splits to make at a given level
        self.model_count = config.model_count # list of number of models at a given level
        self.model_dim_list = config.model_dim_list # list of hidden dimensions corresponding to each given level
        self.head_dim_list = config.head_dim_list # list of attention head dimensions corresponding to each given level    

        # the embedding matrix. for converting tokens to the first residual state, and the last residual state to logits
        self.embedder = nn.Embedding(config.vocab_size, config.hidden_size)

        # for normalizing the initial embeddings
        self.embedder_norm = RMSNorm(config.hidden_size)

        # Initialize a sequence of DecoderLayer instances as specified by the number of hidden layers in the config
        self.layers = nn.ModuleList(Layer(config) for _ in range(config.num_hidden_layers))

        self.output_layer = OutputLayer(self.embedder.weight, config)
        # i think i need to do this bc in the above version you can't use `self.` inside the init
        #@property 
        #def output_layer(self):
            #return OutputLayer(self.embedder.weight, config)

        # the loss function
        self.criterion = FractalLoss()

    def forwardTensor(self,
                      input_token_ids: torch.Tensor,
                      level: int = 0, # integer designating the level of model to use. 0 is largest model, -1 is smallest
                      model: int = 0, # integer designating the model in that level to use. 0 is top-left, -1 is bottom right
                     ) -> torch.Tensor:
        """
        inputs: 
            - input_token_ids (torch.Tensor): a tensor of integers size (batch_size, sequence_length)
            - level: integer designating the level of model to use. 0 is largest model, -1 is smallest
            - model: integer designating the model in that level to use. 0 is top-left, -1 is bottom right
        output: a torch.Tensor shape (batch_size, sequence_length, vocab_size)
        """
        global verbose
        if verbose: 
            print("------------- FractalFormer.forwardTensor() ------------")
            print(f"input_token_ids: {input_token_ids.shape}\n{input_token_ids}")
        
        # adjusting everything to the specified level & model
        d_dim = self.hidden_size / (2**level)
        d_skip = model * h_dim
        if verbose:
            print(f"d_dim: {d_dim}")
            print(f"d_skip: {d_skip}")
        
        # turn the input tokens into the first residual state using the embedding matrix
        # (batch_size, input_len) & (vocab_size, hidden_size) -> (batch_size, input_len, hidden_size) -> (batch_size, input_len, h_dim)
        x = self.embedder(input_token_ids)[:,:, d_skip:d_skip + d_dim]
        if verbose: print(f"initial x: {x.shape}\n{x}")
        
        # Gemma normalizes the embedding by sqrt(hidden_size)
        # the question is, should I do this with the full sized hidden_size or do it at the splice size????
        # imma do it at the splice size and change it later if i think the models aren't learning well
        #x = x * (d_dim**0.5)
        # alternatively i could just switch to doing a regular RMSNorm which would be more like me
        # if i figure out this different sizes of hyperspheres thing it'd be more in line with that
        x = self.embedder_norm(x, model)
        if verbose: print(f"normalized initial x: {x.shape}\n{x}")

        # Iteratively process the input through each Layer
        for i, layer in enumerate(self.layers):
            if verbose: print(f"begin layer {i}")
            x = layer(x, model)
            if verbose: print(f"output of layer {i}: {x.shape}\n{x}")

        logits = self.output_layer(x, model)
        if verbose: print(f"output logits: {logits.shape}\n{logits}")

        return logits

    def forwardTuple(self,
                     input_token_ids: torch.Tensor,
                     target_token_ids: torch.Tensor,
                    ) -> torch.Tensor:
        
        x0 = ()
        for i in range(self.num_levels):

            # our splicing setup
            h_dim = self.model_dim_list[i]
            i_dim = h_dim * self.intermediate_multiplier
            
            x0_lvl = ()#, (), (), ()
            for j in range(self.model_count[i]):

                # splicing specific to this model
                h_skip = j * h_dim
                i_skip = j * i_dim

                x0_lvl += (elf.embedder(input_token_ids)[:,:, h_skip:h_skip + h_dim] * (h_dim**0.5),)

            x0 += (x0_lvl)
    
    def forward(self,
                input_token_ids: torch.Tensor, # a shape (batch_size, input_seq_len OR max_seq_len)list of integer token ids
                target_token_ids: torch.Tensor = None, # a shape (batch_size, max_seq_len) list of token ids to train on
                level: int = 0, # integer designating the level of model to use. 0 is largest model
                model: int = 0, # integer designating the model in that level to use. 0 is top-left model in level
                ):
        
        if target_token_ids is None: # if we're not training, then we don't need to calculate loss
            logits = self.forwardTensor(input_token_ids, level, mode)
            loss = None
        else:
            # if we are training
            # training uses a tuple of tuples of tensors
            logits = self.forwardTuple(input_token_ids) # -> Tuple[Tuple[Tensor shape (batch_size, max_seq_len, vocab_size)]]
            
            # custom Fractal CE loss function
            loss = self.criterion(logits, targets) 
        
        return logits, loss

    @torch.no_grad() # no need to keep track of gradients during inference
    def Sampler(
        self,
        logits: torch.Tensor, # shape (batch_size, input_len, vocab_size)
        temperature: float, # controls how boring vs random the outputs should be
        top_p: float, # the maximum cumulative probability of output options we're willing to consider
        top_k: int, # the maximum number of output options we're willing to consider
    ) -> torch.Tensor:
        """
        The Sampler function is responsible for generating token predictions from Gemma's output.
        It supports temperature scaling, top-p (nucleus) sampling, and top-k sampling 
        The class operates as follows:
    
        1. Selects the last hidden state for each sequence in the batch
    
        2. Computes logits by multiplying the selected hidden states with the transposed embedding matrix. 
    
        3. Temperature is used to scale the logits, making the distribution over tokens sharper (lower temperature) 
        or flatter (higher temperature), which affects the randomness of the sampling (flatter -> more random)
    
        4. The softmax function is applied to the scaled logits to obtain a probability distribution over the vocabulary.
    
        5. For top-p sampling, the function computes the cumulative sum of the sorted probabilities and masks out tokens until the 
        cumulative probability exceeds the threshold defined by `top_ps`. This allows the model to focus on a subset of the most 
        probable tokens while ignoring the long tail of less likely tokens. 
        We to ignore long tail probabilities to avoid nonsensical output
    
        7. For top-k sampling, the function masks out all tokens except the `k` most likely ones, as specified by `top_ks`. 
        This ensures that the model only considers a fixed number of the most probable tokens for the next token prediction.
    
        8. After applying both the top-p and top-k masks, the probabilities are re-normalized so that they sum up to 1
    
        9. The function then samples from the re-normalized probability distribution to select the next token. 
        """
        # Select the last element for each sequence.
        # (batch_size, input_len, vocab_size) -> (batch_size, vocab_size)
        logits = logits[:,-1,:]
        
        # Apply temperature scaling
        # (batch_size, vocab_size) / float -> (batch_size, vocab_size)
        logits.div_(temperature) # div_ is an in-place operation which is ok since we don't record gradients during inference

        # Calculate probabilities with softmax.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float) # dim=-1 is the vocab_size dimension that we calculate along

        # sort the probabilities to for use in top-p & top-k
        # both are (batch_size, vocab_size)
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        # probs_sort contains float probabilities while probs_idx contains integer indices

        # calculating top-p
        # creates same-size tensor of cumulatve probabilities instead of indivdiual probs
        probs_sum = torch.cumsum(probs_sort, dim=-1) 
        # mask where 0's are top-p selections & 1's are to be excluded
        top_ps_mask = (probs_sum - probs_sort) > top_p
        # the original probabilities with excluded tokens changed to 0.0
        probs_sort = torch.where(top_ps_mask, 0, probs_sort) 

        # calculating top_k
        # create a shape (vocab_size) tensor that just iterates up by 1's
        top_ks_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device) 
        # expand our mask along the batch_size dimension to become size (batch_size, vocab_size)
        # "expand" means copy the original into this new size, so each length vocab_size row is the same
        top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
        # top_ks is a list of integers. we keep whichever entries in top_ks_mask are greater than their corresponding entries in top_ks
        top_ks_mask = top_ks_mask >= top_k

        # we'll be combining top-p with top-k and using whichever gives us fewer tokens. a very conservative approach
        # this trims probs_sort to also fit within our top_k requirement
        probs_sort = torch.where(top_ks_mask, 0, probs_sort)

        # Re-normalization so that total probabilities add up to 1
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        
        # now we rearrange the modified probabilities in probs_sort back to their original order according to probs_idx
        probs = torch.gather(probs_sort,
                             dim=-1,
                             index=torch.argsort(probs_idx, dim=-1))
        
        # samples from the distribution
        next_token_id = torch.multinomial(probs, num_samples=1)
        
        return next_token_id # returns the predicted token
        
    def generate(
        self,
        prompt: str,
        output_len: int = 100, # the model will output 100 tokens
        temperature: float = 0.95, # 0.95 is pretty close to not even using temperature at all (1.0 would be no effect)
        top_p: float = 1.0, # defaulting to 1 means we essentially don't use top-p
        top_k: int = 65, # setting top_k = vocab_size means we're effectively not using top_k at all
    ) -> str: 
        """Generates responses for given prompts using Gemma model."""
        
        # encoding the prompt into token indices
        tokens = self.tokenizer.encode(prompt)

        # turning it into the right tensor shape
        tokens = torch.tensor(tokens, device=config.device).unsqueeze(0)
        
        # we wouldn't want to go past the maximum context length we trained on
        assert len(tokens) + output_len <= self.config.max_position_embeddings

        for i in range(output_len):
            # get the model's output logits and ignore the loss, which would be a NoneType object
            logits, _ = self(tokens[:,:self.max_seq_len])
            
            next_token = self.Sampler(
                logits = logits, # the actual output of the model
                temperature = temperature,
                top_p = top_p,
                top_k = top_k
            )

            # add our new token to the sequence
            tokens = torch.cat((tokens, next_token), dim=1)

        # decode our list of tokens to an actual string
        output = self.tokenizer.decode(tokens.squeeze(0).tolist())

        return output

# Training-related Functions

In [13]:
# Train and test splits
data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be our training dataset, the rest for validation
train_data = data[:n]
val_data = data[n:]

In [14]:
# data loading for training which generates a small batch of data of inputs x and targets y
def get_batch(split, batch_size):
    # whether we grab from our training or validation dataset
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - config.max_position_embeddings, (batch_size,))
    x = torch.stack([data[i:i+config.max_position_embeddings] for i in ix])
    y = torch.stack([data[i+1:i+config.max_position_embeddings+1] for i in ix])
    x, y = x.to(config.device), y.to(config.device)
    return x, y

In [15]:
# a demonstration of what a batch with batch_size=1 looks like. Notice the one-token offset in characters
xb, yb = get_batch('train', 1)
print(xb)
print(tokenizer.decode(xb.squeeze(0).tolist()))
print("-------")
print(yb)
print(tokenizer.decode(yb.squeeze(0).tolist()))

tensor([[ 88,   1,  72,  52,   0,  14,  43,   1,  54,  68,  44,  43,  41,  58,
          85,  24,  33,  15,  21,  27,  71,  21,   1,  61,  77,  56,  70,  58,
           1,  88,  56,   1,  46,  79, 127,  85,  16,  33,  23,  17,   1,  34,
          21,  26,  15,  17,  26,  32,  21,  27,  71,  32,  87,   1,  61,  77,
          56,  70,  58,  57,   1, 105,   1,  88,  56,  91,  50,  44, 125,  58,
          39, 124,   1,  87, 113,   1,  84,   5,  58,  85,  21,  31,  13,  14,
          17,  24,  24,  13,  71,  32, 102,  57,   1,  45,  76,  58,  99,  51,
          70,   1,  84, 111,   1,  57,  53,  83,  61,  92,  58,   1,  94,   1,
         101,   1,  58,  39,  99,   6,   7,   7,  75,  24,  33,  15,  21,  27,
          71,  30,  47, 122,  58,  85,  16,  33,  23,  17,   1,  34,  21,  26,
          15,  17,  26,  32,  21,  27,  71,  21,  58,   1,  51, 106,   1,  98,
           1,  56,  47, 122,  58, 125,  40, 114,   1,  88,   1,  77,  43,   1,
          47,   5,   1,  72,   1,  61, 115,  52,  45

In [16]:
@torch.no_grad()
def estimate_loss(model, batch_size, eval_iters = 10): # to estimate loss during the training loop
    out = {}
    model.eval() # sets model to eval mode
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, batch_size)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train() # just resets to training mode
    return out

# Instantiating a brand new model

In [17]:
model = minGemma(config, tokenizer).to(config.device)

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

print(model)

972.416 K parameters
minGemma(
  (embedder): Embedding(128, 128)
  (model): Body(
    (layers): ModuleList(
      (0-3): 4 x Layer(
        (self_attn): Attention(
          (qkv_proj): Linear(in_features=128, out_features=192, bias=False)
          (o_proj): Linear(in_features=128, out_features=128, bias=False)
        )
        (mlp): MLP(
          (gate_proj): Linear(in_features=128, out_features=512, bias=True)
          (up_proj): Linear(in_features=128, out_features=512, bias=True)
          (down_proj): Linear(in_features=512, out_features=128, bias=True)
        )
        (input_layernorm): RMSNorm()
        (post_attention_layernorm): RMSNorm()
      )
    )
    (norm): RMSNorm()
  )
  (criterion): CrossEntropyLoss()
)


# Load a Pretrained Model

In [18]:
# Initialize a blank model
model = minGemma(config, tokenizer).to(config.device)  

# here's the path to a minGemma model that i've trained with roughly 1m parameters
path = 'models/minGemma-vocab_size128-max_position_embeddings256-num_hidden_layers4-num_attention_heads4-num_key_value_heads1-hidden_size128-intermediate_size512-head_dim32-rms_norm_eps1e-06-rope_theta100.0--2024-02-26|11-10-53.pth'

# Load the saved state dictionary
model.load_state_dict(torch.load(path))
# REMEMBER TO CHANGE VALUES IN CONFIG TO MATCH THE MODEL YOU'VE LOADED

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

# If you only plan to do inference, switch to evaluation mode
model.eval()

# If you plan to continue training the model, switch to training mode
#model.train()

964.352 K parameters


minGemma(
  (embedder): Embedding(65, 128)
  (model): GemmaBody(
    (layers): ModuleList(
      (0-3): 4 x GemmaDecoderLayer(
        (self_attn): GemmaAttention(
          (qkv_proj): Linear(in_features=128, out_features=192, bias=False)
          (o_proj): Linear(in_features=128, out_features=128, bias=False)
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=128, out_features=512, bias=True)
          (up_proj): Linear(in_features=128, out_features=512, bias=True)
          (down_proj): Linear(in_features=512, out_features=128, bias=True)
        )
        (input_layernorm): RMSNorm()
        (post_attention_layernorm): RMSNorm()
      )
    )
    (norm): RMSNorm()
  )
  (criterion): CrossEntropyLoss()
)

# Training

In [18]:
# create a PyTorch optimizer
# this is not what they used, but this learning rate & weight decay work for our tiny minGemma
learning_rate = 3e-4
weight_decay = 0.01
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# how long we want to train for
max_iters = 5000

# how often we want to check & see how our loss is doing
eval_interval = 250

# batch size to use
batch_size = 32

In [19]:
start_time = time.time()

# Enable anomaly detection. uncomment these lines if you need to do extensive debugging
#torch.autograd.set_detect_anomaly(True)

for iter in range(max_iters):

    # sample a batch of data
    xb, yb = get_batch('train', batch_size)
    
    # train
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    
    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        current_time = time.time()
        elapsed_time = current_time - start_time
        losses = estimate_loss(model, batch_size)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time elapsed: {elapsed_time:.2f} seconds")

# Disable anomaly detection after the training loop
#torch.autograd.set_detect_anomaly(False)

step 0: train loss 130.0421, val loss 130.0103, time elapsed: 0.73 seconds
step 250: train loss 4.8103, val loss 4.9765, time elapsed: 138.16 seconds
step 500: train loss 3.6559, val loss 3.6816, time elapsed: 342.70 seconds
step 750: train loss 3.2894, val loss 3.3416, time elapsed: 566.02 seconds
step 1000: train loss 3.1266, val loss 3.1717, time elapsed: 769.44 seconds
step 1250: train loss 3.0514, val loss 3.1126, time elapsed: 903.18 seconds
step 1500: train loss 2.9887, val loss 3.0574, time elapsed: 1036.83 seconds
step 1750: train loss 2.9147, val loss 3.0104, time elapsed: 1171.64 seconds
step 2000: train loss 2.8687, val loss 2.9626, time elapsed: 1336.68 seconds
step 2250: train loss 2.8162, val loss 2.9178, time elapsed: 1470.40 seconds
step 2500: train loss 2.7705, val loss 2.8822, time elapsed: 1652.25 seconds
step 2750: train loss 2.7071, val loss 2.8136, time elapsed: 1785.60 seconds
step 3000: train loss 2.6603, val loss 2.7935, time elapsed: 1918.52 seconds
step 3250

# Saving your model

In [20]:
# save the model currently held in memory
# the filename specifies the model's class, hyperparameters, and date/time it was saved
torch.save(model.state_dict(),
           f'models/{model.__class__.__name__}'
           f'-vocab_size{config.vocab_size}'
           f'-max_position_embeddings{config.max_position_embeddings}'
           f'-num_hidden_layers{config.num_hidden_layers}'
           f'-num_attention_heads{config.num_attention_heads}'
           f'-num_key_value_heads{config.num_key_value_heads}'
           f'-hidden_size{config.hidden_size}'
           f'-intermediate_size{config.intermediate_size}'
           f'-head_dim{config.head_dim}'
           f'-rms_norm_eps{config.rms_norm_eps}'
           f'-rope_theta{config.rope_theta}'
           f'--{time.strftime("%Y-%m-%d|%H-%M-%S")}.pth')

# Inference

In [23]:
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou " # the classic line
max_useable_output_len = config.max_position_embeddings - len(input_str)
output = model.generate(input_str, output_len = max_useable_output_len)
print(output)

JULIET:
O Romeo, Romeo! wherefore art thou lord,
Bol, am the lad we vowerly her lastion greathe!

Voddon:
He his latter in my slould is fleeck frideed
Or sperate so placel
And mot to which conour barksag his light
And see as please mene meanner.
This scied what is ontued to my lead
How I dod, me wit destrined have fain and
do by 
