# FractaFormer

this base version is going to be absurdly terribly no-good inefficient because we're taking the biggest computational issue ($O(t^2)$ attention) and making it way worse by doing multiple of them at once. plan is to fix / work around this later such that we only double our memory requirement or even keep it the same at the sacrifice of a marginal amount of context length. 

# TODO
- ~config~
- ~layernorm~
- embedding
    - ~decide whether to normalize embedding according to largest hidden size or that model's hidden size. i think i've gotta do the latter~
- mlp
- mha
- block
- output
- loss
- body
- batch splitup?
- inference choice?

In [1]:
# Importing pytorch
import torch
import torch.nn as nn
from torch.nn import functional as F

# used for the tokenizer
import pickle
import os

# Imports used for the config
import dataclasses 
from typing import Optional

# Imports used for the model
import re
from typing import Any, List, Sequence, Tuple, Union
import numpy as np

# used in the training loop
import time

# The Dataset

the dataset we'll be using is just TinyShakespeare for sake of simplicity & ability to do run/train locally on any computer

In [2]:
# load the dataset
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# the first 200 characters. It's just one continuous text document with all of the works of shakespeare back-to-back
print(text[:200])

# here are all the unique characters that occur in this text and how many there are
chars = sorted(list(set(text)))
v = len(chars)
print('\n', chars, v)

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you

 ['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'] 65


# The Tokenizer

We'll be using a very simple tokenizer I previoiusly trained off of the TinyShakespeare dataset that has 128 total tokens and ignores stuff like special tokens & regex. 

In [3]:
# Load the tokenizer data using pickle
with open('./tokenizers/tokenizer.model', 'rb') as f:
    loaded_tokenizer_data = pickle.load(f)

# Extract the stoi mapping and merges from the loaded data
loaded_stoi = loaded_tokenizer_data['stoi']
loaded_merges = loaded_tokenizer_data['merges']

class SimpleTokenizer:
    def __init__(self, stoi, merges):
        self.stoi = stoi
        self.merges = merges
        self.itos = {i: s for s, i in stoi.items()}  # Inverse mapping for decoding

        self.vocab_len = len(stoi) + len(merges)

    def encode(self, text):
        # Convert the text to a list of token IDs, using space for unknown characters
        tokens = [self.stoi.get(c, self.stoi[' ']) for c in text]

        # Perform merging with the possibility of nested merges
        i = 0
        while i < len(tokens) - 1:
            pair = (tokens[i], tokens[i + 1])
            if pair in self.merges:
                # Replace the current pair with its merged token
                merged_token = self.merges[pair]
                tokens[i] = merged_token
                del tokens[i + 1]

                # Move back to handle possible nested merges
                if i > 0:
                    i -= 1
            else:
                i += 1

        return tokens

    def decode(self, tokens):
        def expand_token(token):
            # Base case: if the token is a direct mapping, return its character
            if token in self.itos:
                return self.itos[token]
            # Recursive case: if the token is a merged token, expand its constituents
            elif token in self.merges.values():
                pair = next(key for key, value in self.merges.items() if value == token)
                return ''.join(expand_token(t) for t in pair)
            # Fallback for unknown tokens
            else:
                return ''

        # Decode each token in the list, handling nested merges recursively
        return ''.join(expand_token(token) for token in tokens)
        
# Example usage
# Assuming loaded_stoi and loaded_merges are already loaded from the tokenizer.model file

tokenizer = SimpleTokenizer(loaded_stoi, loaded_merges)
print("vocab length: ", tokenizer.vocab_len)

# Encoding text
encoded_text = tokenizer.encode("JULIET:\nO Romeo, Romeo! wherefore art thou Romeo?")
print("Encoded:", encoded_text, len(encoded_text))

# Decoding back
decoded_text = tokenizer.decode(encoded_text)
print("Decoded:", decoded_text, len(decoded_text))

vocab length:  128
Encoded: [22, 33, 24, 21, 17, 32, 71, 27, 1, 30, 53, 83, 53, 66, 30, 53, 83, 53, 2, 1, 61, 87, 93, 105, 43, 1, 77, 58, 1, 65, 67, 1, 30, 53, 83, 53, 12] 37
Decoded: JULIET:
O Romeo, Romeo! wherefore art thou Romeo? 49


# Config

In [4]:
@dataclasses.dataclass # a class meant specifically to just hold data
class Config:
    """ 
    The default configuration & hyperparameters for FractalFormer
    """
    # The number of tokens in the vocabulary.
    vocab_size: int = tokenizer.vocab_len
    
    # The maximum sequence length that this model might ever be used with.
    max_position_embeddings: int = 256
    
    # The number of layers in the model.
    num_hidden_layers: int = 4
    
    # The number of attention heads used in the attention layers of the model.
    num_attention_heads: int = 4
    
    # The number of key-value heads for implementing multi-query attention.
    num_key_value_heads: int = 1
    # Ensures that the number of query heads is evenly divisible by the number of KV heads.
    assert num_attention_heads % num_key_value_heads == 0
    
    # The hidden size of the model, AKA the embedding dimension
    hidden_size: int = 128
    # the attention heads need to cleanly divide up the hidden_size of the model for MQA
    assert hidden_size % num_attention_heads == 0

    # how much larger the inner dimension of the MLP should be than the hidden size of the model
    intermediate_multiplier = 4
    # The inner dimension of the MLP part of the decoder layer
    @property
    def intermediate_size(self):
        return self.intermediate_multiplier * self.hidden_size
    
    # The number of head dimensions
    head_dim: int = 32
    
    # The epsilon used by the rms normalization layers.
    rms_norm_eps: float = 1e-6 # this is to promote numerical stability & prevent dividing by 0
    
    # the scaling factor that determines the frequencies for the rotary positional encodings
    rope_theta = 100.0
    # smaller models should use a smaller theta, but I'm just guessing here. 1000 might work too. 10,000 is the usual

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # the % of neurons to dropout in the MLP
    dropout = 0.1

    ####### FractalFormer-specific hyperparameters

    # the number of levels for sub-models to exist on
    levels = 3
    
    # the number of splits to make at a given level
    split = 2 # i don't recommend choosing any value other than 2
    assert split % 2 == 0

    @property
    def model_count(self):
        return [self.split**i for i in range(self.levels)]

    @property
    def model_dim_list(self):
        return [self.hidden_size // (self.split**i) for i in range(self.levels)]

    @property
    def head_dim_list(self):
        return [self.head_dim // (self.split**i) for i in range(self.levels)]

config = Config()

print("single large model -> hierarchy of many smaller models inside")
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)
print("head_dim_list: ", config.head_dim_list)

single large model -> hierarchy of many smaller models inside
model_count:  [1, 2, 4]
model_dim_list:  [128, 64, 32]
head_dim_list:  [32, 16, 8]


# Rotary Positional Encoding (RoPE)

i don't think i need to adjust the code for this one as long as i always call it individually

In [5]:
def apply_rotary_emb(x: torch.Tensor, dim: int, theta: float = 10000.0) -> torch.Tensor:
    """Applies the rotary embedding to the inputted query or key tensor"""
    # Get sequence length
    seq_len = x.size(1)
    device = x.device
    
    # Dynamically compute frequency cis based on the input sequence length
    # dynamic is less efficient but pre-computed was giving me trouble so whatever
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
    t = torch.arange(seq_len, device=device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64

    # Apply rotary embeddings to the input tensor
    x_ = torch.view_as_complex(torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
    x_out = torch.view_as_real(x_ * freqs_cis.unsqueeze(0)).type_as(x)  # Ensure batch dimension is handled
    x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
    x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], -1).transpose(1, 2)

    return x_out

# RMSNorm

Layernorm is relatively simple code-wise. However, of note is the fact that during training, the entire full length vector gets normalized whereas during inference we only layernorm the sub-vector we've been given if we're not using the full model size. This probably isn't a big deal since the sub-vectors are still hopefully being drawn from the same distribution during training. However, it wouldn't be surprising if the logits going into the small vectors are characteristically different from the full super-vectors, in which case this certainly might be a difficulty for the model. It might be worth changing this algorithm such that during training sub-vectors get normalized first and then held constant while super-vectors are normalized. something to think about. 

In [6]:
class RMSNorm(torch.nn.Module):
    """
    Implements the RMS Normalization (Root Mean Square Normalization) layer.
    RMSNorm is a variant of layer normalization that normalizes the activations
    of the previous layer based on their root mean square value.

    Parameters:
    - dim (int): The dimension of the input features the normalization is applied to.
    - eps (float): A small value added to the denominator for numerical stability. Default is 1e-6.
    - add_unit_offset (bool): If True, adds a unit (1) to the learned scaling coefficient, effectively
      starting with no scaling. If False, the scaling coefficient starts from zero. Default is True.
    """

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        #add_unit_offset: bool = True,
    ):
        super().__init__() 
        self.eps = eps  # Small epsilon value for numerical stability since you can't divide by 0
        #self.add_unit_offset = add_unit_offset  # Flag to determine if a unit should be added to the weight
        
        # Initialize the weight parameter with zeros, which will be learned during training.
        # The shape of the weight is [dim], meaning one weight per feature dimension.
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        """
        Private helper function to normalize the input tensor.

        Parameters:
        - x (Tensor): The input tensor to normalize.

        Returns:
        - Tensor: The normalized tensor.
        """
        # Calculate the root mean square value for each feature (across the last dimension),
        # then use reciprocal square root (rsqrt) for normalization.
        # Add self.eps to the denominator for numerical stability.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forwardTensor(self, x):
        """
        Forward pass of the RMSNorm layer in the case of inference.

        Parameters:
        - x (Tensor): The input tensor to normalize.

        Returns:
        - output: The normalized and scaled tensor.
        """
        # Normalize the input tensor using the _norm function and ensure the data type matches the input.
        x = self._norm(x.float()).type_as(x)
        
        # If add_unit_offset is True, scale the normalized tensor by (1 + self.weight),
        # effectively starting with a scaling factor of 1 (no scaling).
        # Otherwise, scale by self.weight only.
        #if self.add_unit_offset:
        # nah let's just make that the default to make my code simpler
        output = x * (1 + self.weight)
        #else:
            #output = x * self.weight
            
        # Return the scaled output tensor.
        return output

    def forwardTuple(self, x):
        """
        Forward pass of the RMSNorm layer in the case of training.

        Parameters:
        - x (Tuple[Tuple[Tensor]]): The input tuple containing tuples containing tensors to normalize.

        Returns:
        - output (Tuple[Tuple[Tensor]]): The normalized and scaled tuple of tuples of tensors.
        """
        # if we had sent through the config we could've just grabbed these values from there but too late now
        levels = len(x)
        models_per_level = [len(x[i]) for i in range(levels)]
        
        output = ()
        for i in range(levels):
            
            level = ()
            for j in range(models_per_level[i]):
                # grab the right tensor from the hierarchy, make sure it's the right type, normalize, then multiply by 1+learned weight
                level += (self._norm(x[i][j].float()).type_as(x[i][j]) * (1 + self.weight),)
            
            # adding that level to the final output tensor
            output += (level,)
        
        return output
        
    def forward(self, x):
        return self.forwardTuple(x) if type(x) == tuple else self.forwardTensor(x)

# Multi-Layer Perceptron

ugh this is gonna be extra complicated taking into account the GeGLU gate. Might consider taking that out entirely

<p align="center">
<img src="./images/ffwd.jpeg" width="512"/>
</p>

In [7]:
class MLP(nn.Module):
    """
    This class implements a multi-layer perceptron with a GeGLU gating mechanism. The GeGLU
    activation combines a standard GeLU activation with a learned gating mechanism, enabling
    the network to control the flow of information more dynamically.

    Attributes:
        gate_proj (nn.Linear): A linear layer that transforms the input tensor to an intermediate
                               representation, which is then passed through a GeLU activation for
                               gating purposes.
        up_proj (nn.Linear): Another linear layer that transforms the input tensor to the same
                             intermediate representation but without gating. This representation
                             is element-wise multiplied by the gated tensor from `gate_proj`.
        down_proj (nn.Linear): A linear layer that transforms the gated and combined tensor back
                               to the original dimension of the hidden size, producing the final output.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        dropout: float = 0.1,
    ):
        """
        Initializes the GemmaMLP module.

        Parameters:
            hidden_size (int): The size of the input and output tensors.
            intermediate_size (int): The size of the tensor after the initial transformation
                                     and before the gating and final projection. This is typically
                                     larger than the hidden size to allow for a richer representation.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.intermediate_multiplier = intermediate_size // hidden_size

        # Linear transformation for the gating mechanism, projecting input to an intermediate size.
        #self.gate_proj = nn.Linear(hidden_size, intermediate_size)
        self.gate_w = nn.Parameter(torch.Tensor(hidden_size, intermediate_size))
        self.gate_b = nn.Parameter(torch.Tensor(intermediate_size))

        # Linear transformation for the input tensor, also projecting to the intermediate size but
        # intended for element-wise multiplication with the gated output.
        #self.up_proj = nn.Linear(hidden_size, intermediate_size)
        self.up_w = nn.Parameter(torch.Tensor(hidden_size, intermediate_size))
        self.up_b = nn.Parameter(torch.Tensor(intermediate_size))

        # Linear transformation to project the gated and combined tensor back to the original
        # hidden size, completing the MLP structure.
        #self.down_proj = nn.Linear(intermediate_size, hidden_size)
        self.down_w = nn.Parameter(torch.Tensor(intermediate_size, hidden_size))
        self.down_b = nn.Parameter(torch.Tensor(hidden_size))

        # Initialize weights with uniform distribution
        # For gate & up, where in_features is hidden_size
        limit_gateup = 1 / np.sqrt(hidden_size)
        #print(type(limit_gateup))
        nn.init.uniform_(self.gate_w, -limit_gateup, limit_gateup)
        nn.init.uniform_(self.gate_b, -limit_gateup, limit_gateup)
        nn.init.uniform_(self.up_w, -limit_gateup, limit_gateup)
        nn.init.uniform_(self.up_b, -limit_gateup, limit_gateup)
        
        # For down, where in_features is intermediate_size
        limit_down = 1 / np.sqrt(intermediate_size)
        nn.init.uniform_(self.down_w, -limit_down, limit_down)
        nn.init.uniform_(self.down_b, -limit_down, limit_down)
        
        self.drop = nn.Dropout(dropout)

    def forwardTensor(self, x, model):
        """
        Defines the forward pass of the MLP module during inference.

        Parameters:
            x (Tensor): The input tensor to the MLP. 
                        shape (batch size, sequence length, hidden dimension) where hidden dimension changes by which model was used
            model (int): the indicator of which model we're using. 
                        used in calculating our skip length for splicing. 
                        defaults to the equivalent of what's used in MatFormer+, meaning no skip, aka we use the top-left-most splice

        Returns:
            Tensor: The output tensor after applying the GeGLU gating mechanism and the MLP transformations.
        """
        # figuring out how we should do our splicing
        h_dim = x.shape[-1]
        h_skip = model * h_dim
        i_dim = h_dim * self.intermediate_multiplier
        i_skip = model * i_dim
        
        # Applies linear transformation for gating.
        gate = x @ self.gate_w[h_skip:h_skip + h_dim, i_skip:i_skip + i_dim] + self.gate_b[i_skip:i_skip + i_dim]

        # Applies GeLU activation to the gate, introducing non-linearity and enabling the gating mechanism.
        gate = F.gelu(gate)

        # Applies another linear transformation to the input tensor for subsequent combination with the gate.
        up = x @ self.up_w[h_skip:h_skip + h_dim, i_skip:i_skip + i_dim] + self.up_b[i_skip:i_skip + i_dim]

        # Element-wise multiplication of the gated tensor with the transformed input tensor, modulating
        # the input based on the gate's activation.
        fuse = gate * up

        # Applies the final linear transformation to project the modulated tensor back to the hidden size.
        outputs = fuse @ self.down_w[i_skip:i_skip + i_dim, h_skip:h_skip + h_dim] + self.down_b[h_skip:h_skip + h_dim]

        # Returns the final output tensor of the MLP, after gating and modulation.
        return outputs

    def forwardTuple(self, x):
        """
        Defines the forward pass of the MLP module during training.

        Parameters:
            x (Tuple[Tuple[Tensor]]): 
                The input tuple of tuples of tensors to the MLP. 
                first tuple is of length config.levels and second layer of tuples have lengths of config.model_count
                tensors are shape (batch size, sequence length, hidden dimension) where hidden dimension changes by which model was used

        Returns:
            Tuple[Tuple[Tensor]]: 
                The output tuple of tuples of tensors after applying the GeGLU gating mechanism and the MLP transformations.
        """
        global verbose

        # if we had sent through the config we could've just grabbed these values from there but too late now
        num_levels = len(x)
        if verbose == True: print(f"num_levels: {num_levels}")
        models_per_level = [len(x[i]) for i in range(num_levels)]
        if verbose == True: print(f"models_per_level: {models_per_level}")
        
        gate, up, fuse, out = (), (), (), ()
        for i in range(num_levels):
            if verbose == True: print(f"i: {i}")

            # our splicing setup
            h_dim = x[i][0].shape[-1]
            if verbose == True: print(f"h_dim: {h_dim}")
            i_dim = h_dim * self.intermediate_multiplier
            if verbose == True: print(f"i_dim: {i_dim}")
            
            gate_lvl, up_lvl, fuse_lvl, out_lvl = (), (), (), ()
            for j in range(models_per_level[i]):
                if verbose == True: print(f"j: {j}")

                # splicing specific to this model
                h_skip = j * h_dim
                if verbose == True: print(f"h_skip: {h_skip}")
                i_skip = j * i_dim
                if verbose == True: print(f"i_skip: {i_skip}")

                # Applies linear transformation and then GeLU activation for the gate
                gate_lvl += (F.gelu(x[i][j] @ \
                                self.gate_w[h_skip:h_skip + h_dim, i_skip:i_skip + i_dim] + \
                                self.gate_b[i_skip:i_skip + i_dim]),)
                if verbose == True: print(gate_lvl)
        
                # Applies another linear transformation to the input tensor for subsequent combination with the gate.
                up_lvl += (x[i][j] @ \
                       self.up_w[h_skip:h_skip + h_dim, i_skip:i_skip + i_dim] + \
                       self.up_b[i_skip:i_skip + i_dim],)
                if verbose == True: print(up_lvl)
        
                # Element-wise multiplication of the gated tensor with the transformed input tensor, modulating
                # the input based on the gate's activation.
                fuse_lvl += (gate_lvl[j] * up_lvl[j],)
                if verbose == True: print(fuse_lvl)
            
                # Applies the final linear transformation to project the modulated tensor back to the hidden size.
                
                out_lvl += (self.drop(fuse_lvl[j] @ \
                            self.down_w[i_skip:i_skip + i_dim, h_skip:h_skip + h_dim] + \
                            self.down_b[h_skip:h_skip + h_dim]),)

            # pretty sure i have to save & store everything without overwriting to prevent in-place arguments. so annoying
            gate += (gate_lvl,)
            up += (up_lvl,)
            fuse += (fuse_lvl,)
            out += (out_lvl,)

        return out
        
    def forward(self, x, model=0):
        return self.forwardTuple(x) if type(x) == tuple else self.forwardTensor(x, model)

In [8]:
verbose = False
mlp = MLP(4,16)
x = ((torch.randn((2,3,4)),),(torch.randn((2,3,2)),torch.randn((2,3,2))))
print(f"x: {x}")
out = mlp(x)
print(f"out: {out}")

x: ((tensor([[[-0.0187, -0.3739, -1.3734, -0.2533],
         [ 0.9594,  2.6050, -0.9794, -1.1327],
         [ 1.0331, -2.2028,  0.5739,  1.8022]],

        [[ 0.0093, -1.1242,  0.7825,  1.2958],
         [-1.7637, -1.3356,  0.8032, -1.7656],
         [ 1.6628,  1.5548,  1.4336,  2.7481]]]),), (tensor([[[ 0.6105,  1.1087],
         [ 0.2508, -0.6298],
         [-0.3436,  0.8128]],

        [[ 0.7089, -0.2783],
         [-0.1910,  1.5686],
         [-0.0410,  1.2335]]]), tensor([[[-0.3503,  0.2396],
         [ 1.0995,  1.8221],
         [ 0.8182,  0.0181]],

        [[ 1.3782, -2.1652],
         [ 1.7282,  1.5256],
         [-0.0046, -1.6631]]])))
out: ((tensor([[[-0.0000,  0.0811,  0.0683, -0.2686],
         [-0.4018, -0.0000,  0.3014, -0.2573],
         [-0.2988,  0.2092,  0.0114, -0.0000]],

        [[-0.2822,  0.1442, -0.0011, -0.0000],
         [ 0.6052,  0.3796,  0.7482, -0.2609],
         [ 0.0081, -0.2525, -0.0445, -0.0000]]], grad_fn=<MulBackward0>),), (tensor([[[-0.1524, -0.175

# Attention

To subset the attention heads, we have to not only splice according to the model's embedding dimension but also take into account new smaller head sizes and how they're spaced throughout the matrix. I'm assuming you know how self-attention works well enough to look at this weight matrix and get the idea

<p align="center">
<img src="./images/sa.jpeg" width="512"/>
</p>

Might be annoying taking into account MQA in our splicing but i don't think it should be. moreso it'll be annoying to deal with the weird tensor format they use that i'm not used to


then we've gotta concatenate the outputs of each head

<p align="center">
<img src="./images/mha_concat.jpeg" width="512"/>
</p>

and after that linearly project them

<p align="center">
<img src="./images/mha_proj.jpeg" width="512"/>
</p>

this is the place where our splicing gets conceptually annoying. instead of just grabbing the matrix in the upper corner, because of the way attention head output concatenation works we actually need to skip over certain parts of the linear projection matrix and then concatenate them together in order to use them. Here's an example of what the matrix multiplication looks like. on the left is a simplified version of the concatenated attention heads where i just showed it as a matrix rather than a tensor, and then on the right is the actual projection matrix. notice how the numbers in the pink output matrix look similar to the first column of the purple output matrix with a positive number, its negative, and then a smaller positive number; that's the self-similarity in action. the yellow arrows point to the parts that get skipped over. obviously this would look a lot uglier with bigger matrices & incorporating the blue/green layer

<p align="center">
<img src="./images/mha_proj_matmul.jpeg" width="512"/>
</p>

# ------------ BOOKMARK ----------------

In [89]:
class Attention(nn.Module):
    """
    Implements Multi-Query Attention which supports a distinct number of attention heads for queries and key-values (KV).
    In the case where the same number of queries and key-values are used, this implemenation is equivalent to regular Multi-Head Attention.  
    """
    
    def __init__(self, config: Config):
        super().__init__()

        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        
        # Determines the number of query heads associated with each KV head.
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.hidden_size = config.hidden_size
        self.head_dim = config.head_dim
        self.theta = config.rope_theta

        # Calculates the total size for all query projections.
        self.q_size = self.num_heads * self.head_dim
        # Calculates the total size for all key and value projections.
        self.kv_size = self.num_kv_heads * self.head_dim

        # Defines the scaling factor for the attention scores.
        #self.scaling = self.head_dim**-0.5
        # NEED TO DYNAMICALLY DEFINE SCALING
        
        # Initialize our learnable matrices
        # the linear projection layer for queries, keys, and values
        #self.Wqkv = nn.Linear(self.hidden_size, (self.num_heads + 2 * self.num_kv_heads) * self.head_dim, bias=False)
        # let's try to instead split them up
        #self.Wq = nn.Parameter(torch.Tensor(self.hidden_size, self.num_heads * self.head_dim))
        #self.Wk = nn.Parameter(torch.Tensor(self.hidden_size, self.num_kv_heads * self.head_dim))
        #self.Wv = nn.Parameter(torch.Tensor(self.hidden_size, self.num_kv_heads * self.head_dim))
        # nvm, back to plan a where qkv are defined together
        self.Wqkv = nn.Parameter(torch.Tensor(self.hidden_size, 
                                                  (self.num_heads + 2 * self.num_kv_heads) * self.head_dim))
        # the output projection layer, mapping the concatenated attention outputs back to the hidden size.
        #self.Wo = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.Wo = nn.Parameter(torch.Tensor(self.num_heads * self.head_dim, self.hidden_size))
        
        # Initialize weights with uniform distribution
        # For qkv_proj, where in_features is hidden_size
        limit_Wqkv = 1 / np.sqrt(self.hidden_size)
        #nn.init.uniform_(self.Wq, -limit_Wqkv, limit_Wqkv)
        #nn.init.uniform_(self.Wk, -limit_Wqkv, limit_Wqkv)
        #nn.init.uniform_(self.Wv, -limit_Wqkv, limit_Wqkv)
        # nvm, back to plan a where qkv are defined together
        nn.init.uniform_(self.Wqkv, -limit_Wqkv, limit_Wqkv)
        # for o_proj, where in_features is self.num_heads * self.head_dim
        limit_Wo = 1 / np.sqrt(self.num_heads * self.head_dim)
        nn.init.uniform_(self.Wo, -limit_Wo, limit_Wo)
        
        # let's do a multiplicative next-token prediction attention mask
        mask = torch.tril(torch.ones((1, 1, config.max_position_embeddings, config.max_position_embeddings), device=config.device))
        self.register_buffer('mask', mask)

    def forwardTensor(self,
                      x: torch.Tensor,
                      model: int = 0,
                     ) -> torch.Tensor:
        """
        Inputs:
            x (torch.Tensor): Te input tensor to the attention mechanism.
                        shape (batch_size, input_len, hidden_size)
            model (int): the indicator of which model we're using. 
                        used in calculating our skip length for splicing. 
                        defaults to the equivalent of what's used in MatFormer+, meaning no skip, aka we use the top-left-most splice
        
        Returns:
            Tensor: The output tensor after applying the attention mechanism
        """
        global verbose
        if verbose: print("----------------- Attention.forwardTensor() --------------------")
        
        # Ensures the input tensor is 3-dimensional (batch_size, input_len, hidden_size).
        hidden_states_shape = x.shape
        assert len(hidden_states_shape) == 3
        if verbose: print(f"hidden_states_shape: {hidden_states_shape}")

        # Extracts batch size and input sequence length from the hidden states tensor.
        batch_size, input_len, d_dim = hidden_states_shape
        
        ### figuring out how we should do our splicing
        # first along the embedding dimension
        d_skip = model * d_dim  # the size of our skip along the model's embedding dimension
        if verbose: print(f"d_skip: {d_skip}")
        # then along the head sizes
        index = config.model_dim_list.index(d_dim)
        models_in_this_level = config.model_count[index] # how many models are in this level
        if verbose: print(f"models_in_this_level: {models_in_this_level}")
        h_dim = config.head_dim_list[index] # the head dimension size of this model in this level
        if verbose: print(f"h_dim: {h_dim}")
        h_skip = model * h_dim # the size of our skip along the head dimension
        if verbose: print(f"h_skip: {h_skip}")

        if verbose: print(f"self.Wqkv: {self.Wqkv.shape}\n{self.Wqkv}")
            #print(f"self.Wq: {self.Wq.shape}\n{self.Wq}")
            #print(f"self.Wk: {self.Wk.shape}\n{self.Wk}")
            #print(f"self.Wv: {self.Wv.shape}\n{self.Wv}")

        # Applies the linear projection to the hidden state to retrieve our q, k & v projections
        #xqkv = x @ self.Wqkv
        
        # Splits the Wqkv tensor into separate tensors for queries, keys, and values based on their respective sizes.
        Wq, Wk, Wv = self.Wqkv.split([self.q_size,
                                      self.kv_size,
                                      self.kv_size],dim=-1)
        if verbose: 
            print(f"Wq: {Wq.shape}\n{Wq}")
            print(f"Wk: {Wk.shape}\n{Wk}")
            print(f"Wv: {Wv.shape}\n{Wv}")
        
        # splicing to get our correct weight matrices for each respective head
        # d_dim is relatively self-explanatory
        # i*self.head_dim is bc we initialized one single q, k, and v matrix for all heads so we have to
        # iterate through said matrix to get to the correct head
        Wq = torch.cat([Wq[d_skip:d_skip + d_dim,\
                               i*self.head_dim + h_skip:i*self.head_dim + h_skip + h_dim] \
                               for i in range(self.num_heads)], dim=0)
        Wk = torch.cat([Wk[d_skip:d_skip + d_dim,\
                               i*self.head_dim + h_skip:i*self.head_dim + h_skip + h_dim] \
                               for i in range(self.num_kv_heads)], dim=0)
        Wv = torch.cat([Wv[d_skip:d_skip + d_dim,\
                               i*self.head_dim + h_skip:i*self.head_dim + h_skip + h_dim] \
                               for i in range(self.num_kv_heads)], dim=0)
        if verbose:
            print(f"Wq spliced: {Wq.shape}\n{Wq}")
            print(f"Wk spliced: {Wk.shape}\n{Wk}")
            print(f"Wv spliced: {Wv.shape}\n{Wv}")

        # need to rearrange in terms of heads so that the first dimension matches x's last dimension
        Wq = Wq.view(d_dim, self.num_heads * h_dim)
        Wk = Wk.view(d_dim, self.num_kv_heads * h_dim)
        Wv = Wv.view(d_dim, self.num_kv_heads * h_dim)
        if verbose:
            print(f"Wq reshaped: {Wq.shape}\n{Wq}")
            print(f"Wk reshaped: {Wk.shape}\n{Wk}")
            print(f"Wv reshaped: {Wv.shape}\n{Wv}")
        
        # this needs to be size (d_dim, (self.num_heads + 2 * self.num_kv_heads) * h_dim) aka (32,24)
        # recombine the spliced Wq Wk and Wv. Now they're the right size for matmul against x
        Wqkv_spliced = torch.cat((Wq, Wk, Wv), dim=-1)
        if verbose:
            print(f"Wqkv_spliced: {Wqkv_spliced.shape}\n{Wqkv_spliced}")
        

        # finally we can project x to get our queries, keys and values
        xqkv = x @ Wqkv_spliced
        if verbose: print(f"xqkv: {xqkv.shape}\n{xqkv}")
            
        # Splits the combined Xqkv tensor into separate tensors for queries (xq), keys (xk), and values (xv) based on their respective sizes.
        xq, xk, xv = xqkv.split([self.q_size // models_in_this_level,
                                 self.kv_size // models_in_this_level,
                                 self.kv_size // models_in_this_level],dim=-1)
        if verbose:
            print(f"xq: {xq.shape}\n{xq}")
            print(f"xk: {xk.shape}\n{xk}")
            print(f"xv: {xv.shape}\n{xv}")
        
        # Reshapes each of the Q, K, and V tensors to separate the heads and align the dimensions for attention operations.
        #q_proj_spliced_reshaped = q_proj_spliced.view(d_dim, self.num_heads, h_dim)
        #if verbose:
            #print(f"q_proj_spliced_reshaped: {q_proj_spliced_reshaped.shape}\n{q_proj_spliced_reshaped}")
        
        # Applies the linear projection to the hidden state to retrieve our q, k & v projections
        #xq = x @ q_proj_spliced_reshaped
        #xk = x @ k_proj_spliced
        #xv = x @ v_proj_spliced
        #if verbose:
            #print(f"xq: {xq.shape}\n{xq}")
            #print(f"xk: {xk.shape}\n {xk}")
            #print(f"xv: {xv.shape}\n{xv}")
        
        # Splits the combined QKV tensor into separate tensors for queries (xq), keys (xk), and values (xv) based on their respective sizes.
        #xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size],dim=-1)
        # for readability's sake it would've made more sense to do these separately; this is just more efficient

        # Reshapes each of the Q, K, and V tensors to separate the heads and align the dimensions for attention operations.
        xq = xq.view(batch_size, -1, self.num_heads, h_dim)#, self.head_dim)
        xk = xk.view(batch_size, -1, self.num_kv_heads, h_dim)#, self.head_dim)
        xv = xv.view(batch_size, -1, self.num_kv_heads, h_dim)#, self.head_dim)
        if verbose:
            print(f"xq reshaped: {xq.shape}\n{xq}")
            print(f"xk reshaped: {xk.shape}\n{xk}")
            print(f"xv reshaped: {xv.shape}\n{xv}")

        # Applies rotary positional embeddings to queries and keys to incorporate positional information.
        xq = apply_rotary_emb(xq, h_dim, self.theta)#self.head_dim
        xk = apply_rotary_emb(xk, h_dim, self.theta)#self.head_dim
        # is the differring head dimension going to mess with RoPE? Not sure
        if verbose:
            print(f"rotated xq: {xq.shape}\n{xq}")
            print(f"rotated xk: {xk.shape}\n{xk}")

        # If the number of KV heads is different from the number of query heads, adjusts keys and values to match the query heads count.
        if self.num_kv_heads != self.num_heads:
            # [batch_size, input_len, n_local_heads, head_dim]
            xk = torch.repeat_interleave(xk, self.num_queries_per_kv, dim=2)
            xv = torch.repeat_interleave(xv, self.num_queries_per_kv, dim=2)
            if verbose:
                print(f"repeat_interleaved xk: {xk.shape}\n{xk}")
                print(f"repeat_interleaved xv: {xv.shape}\n{xv}")

        # Transposes Q, K, and V tensors to align them for the batch matrix multiplication in attention calculation.
        # [batch_size, n_local_heads, input_len, head_dim]
        q = xq.transpose(1, 2)
        k = xk.transpose(1, 2)
        v = xv.transpose(1, 2)
        if verbose:
            print(f"transposed xq: {q.shape}\n{q}")
            print(f"transposed xk: {k.shape}\n{k}")
            print(f"transposed xv: {v.shape}\n{v}")

        # Calculates attention scores by performing a batch matrix multiplication between queries and keys, followed by scaling.
        # [batch_size, n_local_heads, input_len, input_len]
        scores = torch.matmul(q, k.transpose(2, 3)) * h_dim**-0.5#self.scaling
        if verbose: print(f"scores: {scores.shape}\n{scores}")
        
        # Applies the lower-triangular mask to the attention scores
        #scores = scores + self.mask[...,:input_len, :input_len] # make sure mask is the correct size. input_len <= max_seq_len
        # is it weird that we're masking with addition of 0's & big negatives instead of multiplication of 1's and 0's or -inf's? 
        # as far as i'm aware it's weird, although my knowledge could just be out of date
        # Let's do multiplicative masking instead
        scores = scores * self.mask[...,:input_len, :input_len]
        if verbose: print(f"masked scores: {scores.shape}\n{scores}")

        # Applies softmax to the scores to obtain attention probabilities
        scores = F.softmax(scores, dim=-1)
        if verbose: print(f"softmaxed scores: {scores.shape}\n{scores}")
        
        # Computes the weighted sum of values based on the attention scores to obtain the output of the attention mechanism.
        # [batch_size, n_local_heads, input_len, head_dim]
        attention = torch.matmul(scores, v)
        if verbose: print(f"attention: {attention.shape}\n{attention}")

        # Reshapes the attention output to match the expected output dimensions, combining the heads back into the hidden dimension.
        # [batch_size, input_len, hidden_dim]
        attention = attention.transpose(1, 2).contiguous().view(batch_size, input_len, -1)
        if verbose: print(f"reshaped attention: {attention.shape}\n{attention}")

        # Applies the final linear projection to the attention output, mapping it back to the hidden size dimension.
        #output = attention @ self.Wo
        ############## ----------- BOOKMARK ------------- #####################
        # SPLICE THE OUTPUT PROJECTION PROPERLY
        if verbose: print(f"projected output: {output.shape}\n{output}")

        return attention#output
        
    def forward(self, x, model=0):
        return self.forwardTuple(x) if type(x) == tuple else self.forwardTensor(x, model)

In [90]:
verbose = True
x = torch.rand(2,3,32)
print(f"x: {x.shape}\n{x}")
att = Attention(config)
y = att(x, model=1)
print(f"y: {y.shape}\n{y}")

x: torch.Size([2, 3, 32])
tensor([[[0.2831, 0.5838, 0.2652, 0.1360, 0.2022, 0.8635, 0.9388, 0.4171,
          0.6227, 0.0468, 0.5525, 0.2578, 0.5256, 0.4055, 0.6851, 0.2224,
          0.0740, 0.8306, 0.8897, 0.6061, 0.8700, 0.2752, 0.2746, 0.2401,
          0.9001, 0.8101, 0.6042, 0.9132, 0.9799, 0.4491, 0.5714, 0.7236],
         [0.4844, 0.3744, 0.1428, 0.3294, 0.9646, 0.9561, 0.5341, 0.6435,
          0.6770, 0.3129, 0.6677, 0.8829, 0.3896, 0.0560, 0.7192, 0.8599,
          0.0610, 0.5095, 0.0233, 0.0155, 0.1114, 0.7996, 0.8959, 0.0939,
          0.0083, 0.0232, 0.6743, 0.2577, 0.7960, 0.0439, 0.1652, 0.2437],
         [0.6087, 0.2721, 0.0491, 0.5058, 0.1912, 0.8581, 0.6730, 0.8118,
          0.0719, 0.0231, 0.7280, 0.9515, 0.3354, 0.5369, 0.2869, 0.9954,
          0.2482, 0.2393, 0.6235, 0.0687, 0.0782, 0.0092, 0.3660, 0.1168,
          0.9901, 0.2285, 0.8573, 0.8750, 0.5281, 0.2803, 0.5326, 0.9085]],

        [[0.4712, 0.3832, 0.2123, 0.1769, 0.8063, 0.5218, 0.2128, 0.5727,
       

NameError: name 'output' is not defined

# Layer

nothing too interesting here besides the absurd amount of memory we're probably taking up with these tuples

In [10]:
class Layer(nn.Module):
    """
    A decoder layer that integrates the GemmaAttention mechanism and multi-layer perceptron (MLP). It includes
    normalization steps both before and after the attention mechanism to stabilize and accelerate training.
    """

    def __init__(self, config: Config):
        super().__init__()

        # Initializes the GemmaAttention mechanism with parameters from the config, enabling self-attention within the decoder layer.
        self.self_attn = Attention(config)
        
        # Initializes the GemmaMLP module, providing a non-linear transformation after the attention mechanism.
        self.mlp = MLP(
            # the hidden dimension of the model
            hidden_size = config.hidden_size,
            # the number of nodes in the center of the two feedforward layers
            intermediate_size = config.intermediate_size,
        )
        
        # Applies RMSNorm normalization to the input of the decoder layer for stable training dynamics.
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps = config.rms_norm_eps)
        
        # Applies RMSNorm after the attention mechanism and before the MLP to ensure the output is well-conditioned for further processing.
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps = config.rms_norm_eps)

    def forward(self,
                # The input tensor to the decoder layer. shape (batch_size, input_len, hidden_size)
                hidden_states: torch.Tensor
                ) -> torch.Tensor:
        
        # Self Attention Block
        # Stores the original input for use as a residual connection, aiding in mitigating the vanishing gradient problem
        residual = hidden_states
        # Normalizes the input before processing by the attention mechanism.
        hidden_states = self.input_layernorm(hidden_states)
        # Processes the normalized input through the GemmaAttention mechanism
        hidden_states = self.self_attn(hidden_states=hidden_states)
        # The aforementioned residual connection
        hidden_states = residual + hidden_states

        # MLP Block
        # Again, stores the output of the attention block for use as a residual connection before processing by the MLP.
        residual = hidden_states
        # Normalizes the output of the attention block before passing it to the MLP, ensuring a stable input distribution.
        hidden_states = self.post_attention_layernorm(hidden_states)
        # Transforms the normalized attention output through the MLP, introducing additional non-linearity and capacity to the model.
        hidden_states = self.mlp(hidden_states)
        # Another residual connection
        hidden_states = residual + hidden_states

        return hidden_states

# The Body of the Model

In [11]:
class Body(nn.Module):
    """ the class that loops through each layer of Gemma """

    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        # Initialize a sequence of DecoderLayer instances as specified by the number of hidden layers in the config
        self.layers = nn.ModuleList(Layer(config) for _ in range(config.num_hidden_layers))

        # Initialize a normalization layer to be applied after the last decoder layer, stabilizing the output
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self,
                # The first residual state of the model. shape (batch_size, input_len, hidden_size)
                hidden_states: torch.Tensor
               ) -> torch.Tensor:

        # Iteratively process the input through each DecoderLayer, passing along necessary parameters for attention.
        # The hidden states are updated at each layer, progressively incorporating more complex dependencies and transformations.
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states = layer(hidden_states=hidden_states)
        
        # Apply normalization to the output of the final decoder layer, ensuring the model's output is well-conditioned for subsequent use.
        hidden_states = self.norm(hidden_states)
        
        return hidden_states

# Loss Function

In [None]:
class FractalLoss(nn.Module):
    def __init__(self, config: Config):
        super().__init__()

        self.criterion = nn.CrossEntropyLoss()
        

# The Model itself

So this is where i start it all. ugh

In [12]:
class FractalFormer(nn.Module):
    def __init__(self, config: Config, tokenizer: tokenizer):
        super().__init__()
        self.config = config
        self.tokenizer = tokenizer

        # hyperparameters
        self.hidden_size = config.hidden_size
        self.max_seq_len = config.max_position_embeddings
        self.head_dim = config.head_dim
        self.vocab_size = config.vocab_size

        ### FractalFormer-specific hyperparameters
        self.num_levels = config.levels # the number of levels for sub-models to exist on
        self.split = config.split # the number of splits to make at a given level
        self.model_count = config.model_count # list of number of models at a given level
        self.model_dim_list = config.model_dim_list # list of hidden dimensions corresponding to each given level
        self.head_dim_list = config.head_dim_list # list of attention head dimensions corresponding to each given level    

        # the embedding matrix. for converting tokens to the first residual state, and the last residual state to logits
        self.embedder = nn.Embedding(config.vocab_size, config.hidden_size)
        
        # the body of the model; all the transformer decoder layers
        self.body = Body(config)

        # the loss function
        #self.criterion = nn.CrossEntropyLoss()
        self.criterion = FractalLoss()

    def forwardTensor(self,
                      input_token_ids: torch.Tensor,
                      level: int = 0, # integer designating the level of model to use. 0 is largest model
                      model: int = 0, # integer designating the model in that level to use. 0 is top-left model in level
                     ) -> torch.Tensor:
        
        # adjusting everything to the specified level & model
        h_dim = self.hidden_size / (2**level)
        h_skip = model * h_dim
        
        # turn the input tokens into the first resudial state using the embedding matrix
        # (batch_size, input_len) & (vocab_size, hidden_size) -> (batch_size, input_len, hidden_size) -> (batch_size, input_len, h_dim)
        hidden_states = self.embedder(input_token_ids)[:,:, h_skip:h_skip + h_dim]
        
        # Gemma normalizes the embedding by sqrt(hidden_size)
        # the question is, should I do this with the full sized hidden_size or do it at the splice size????
        # imma do it at the splice size and change it later if i think the model isn't learning well
        hidden_states = hidden_states * (h_dim**0.5)

        # this is where the bulk of the calculations are performed, the actual decoder layers
        hidden_states = self.body(hidden_states, level, model) # -> (batch_size, input_len, h_dim)

        # grabbing the weights of the spliced embedding matrix shape (vocab_size, h_dim) for use as the output layer
        embedder_weight = self.embedder.weight[:, h_skip:h_skip + h_dim]

        # the embedding matrix is also used as the output layer
        # this saves on parameters & makes sense for interpretability
        # (batch_size, input_len, h_dim) @ (h_dim, vocab_size) -> (batch_size, input_len, vocab_size)
        logits = torch.matmul(hidden_states, embedder_weight.t())

        return logits

    def forwardTuple(self,
                     input_token_ids: torch.Tensor,
                     target_token_ids: torch.Tensor,
                    ) -> torch.Tensor:
        
        x0 = ()
        for i in range(self.num_levels):

            # our splicing setup
            h_dim = self.model_dim_list[i]
            i_dim = h_dim * self.intermediate_multiplier
            
            x0_lvl = ()#, (), (), ()
            for j in range(self.model_count[i]):

                # splicing specific to this model
                h_skip = j * h_dim
                i_skip = j * i_dim

                x0_lvl += (elf.embedder(input_token_ids)[:,:, h_skip:h_skip + h_dim] * (h_dim**0.5),)

            x0 += (x0_lvl)
    
    def forward(self,
                input_token_ids: torch.Tensor, # a shape (batch_size, input_seq_len OR max_seq_len)list of integer token ids
                target_token_ids: torch.Tensor = None, # a shape (batch_size, max_seq_len) list of token ids to train on
                level: int = 0, # integer designating the level of model to use. 0 is largest model
                mode: int = 0, # integer designating the model in that level to use. 0 is top-left model in level
                ) -> torch.Tensor:
        
        if target_token_ids is None: # if we're not training, then we don't need to calculate loss
            logits = self.forwardTensor(input_token_ids, level, mode)
            loss = None
        else:
            # if we are training
            # training uses a tuple of tuples of tensors
            logits = self.forwardTuple(input_token_ids) # -> Tuple[Tuple[Tensor shape (batch_size, max_seq_len, vocab_size)]]
            batch_size, _, _ = logits[0][0].shape
            
            # then we reshape our logits & targets before calculating cross-entropy loss
            #loss = self.criterion(logits.view(batch_size * self.max_seq_len, self.vocab_size), 
            #                      target_token_ids.view(batch_size * self.max_seq_len))
            
            # custom Fractal CE loss function
            loss = self.criterion(logits, targets) 
        
        return logits, loss

    @torch.no_grad() # no need to keep track of gradients during inference
    def Sampler(
        self,
        logits: torch.Tensor, # shape (batch_size, input_len, vocab_size)
        temperature: float, # controls how boring vs random the outputs should be
        top_p: float, # the maximum cumulative probability of output options we're willing to consider
        top_k: int, # the maximum number of output options we're willing to consider
    ) -> torch.Tensor:
        """
        The Sampler function is responsible for generating token predictions from Gemma's output.
        It supports temperature scaling, top-p (nucleus) sampling, and top-k sampling 
        The class operates as follows:
    
        1. Selects the last hidden state for each sequence in the batch
    
        2. Computes logits by multiplying the selected hidden states with the transposed embedding matrix. 
    
        3. Temperature is used to scale the logits, making the distribution over tokens sharper (lower temperature) 
        or flatter (higher temperature), which affects the randomness of the sampling (flatter -> more random)
    
        4. The softmax function is applied to the scaled logits to obtain a probability distribution over the vocabulary.
    
        5. For top-p sampling, the function computes the cumulative sum of the sorted probabilities and masks out tokens until the 
        cumulative probability exceeds the threshold defined by `top_ps`. This allows the model to focus on a subset of the most 
        probable tokens while ignoring the long tail of less likely tokens. 
        We to ignore long tail probabilities to avoid nonsensical output
    
        7. For top-k sampling, the function masks out all tokens except the `k` most likely ones, as specified by `top_ks`. 
        This ensures that the model only considers a fixed number of the most probable tokens for the next token prediction.
    
        8. After applying both the top-p and top-k masks, the probabilities are re-normalized so that they sum up to 1
    
        9. The function then samples from the re-normalized probability distribution to select the next token. 
        """
        # Select the last element for each sequence.
        # (batch_size, input_len, vocab_size) -> (batch_size, vocab_size)
        logits = logits[:,-1,:]
        
        # Apply temperature scaling
        # (batch_size, vocab_size) / float -> (batch_size, vocab_size)
        logits.div_(temperature) # div_ is an in-place operation which is ok since we don't record gradients during inference

        # Calculate probabilities with softmax.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float) # dim=-1 is the vocab_size dimension that we calculate along

        # sort the probabilities to for use in top-p & top-k
        # both are (batch_size, vocab_size)
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        # probs_sort contains float probabilities while probs_idx contains integer indices

        # calculating top-p
        # creates same-size tensor of cumulatve probabilities instead of indivdiual probs
        probs_sum = torch.cumsum(probs_sort, dim=-1) 
        # mask where 0's are top-p selections & 1's are to be excluded
        top_ps_mask = (probs_sum - probs_sort) > top_p
        # the original probabilities with excluded tokens changed to 0.0
        probs_sort = torch.where(top_ps_mask, 0, probs_sort) 

        # calculating top_k
        # create a shape (vocab_size) tensor that just iterates up by 1's
        top_ks_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device) 
        # expand our mask along the batch_size dimension to become size (batch_size, vocab_size)
        # "expand" means copy the original into this new size, so each length vocab_size row is the same
        top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
        # top_ks is a list of integers. we keep whichever entries in top_ks_mask are greater than their corresponding entries in top_ks
        top_ks_mask = top_ks_mask >= top_k

        # we'll be combining top-p with top-k and using whichever gives us fewer tokens. a very conservative approach
        # this trims probs_sort to also fit within our top_k requirement
        probs_sort = torch.where(top_ks_mask, 0, probs_sort)

        # Re-normalization so that total probabilities add up to 1
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        
        # now we rearrange the modified probabilities in probs_sort back to their original order according to probs_idx
        probs = torch.gather(probs_sort,
                             dim=-1,
                             index=torch.argsort(probs_idx, dim=-1))
        
        # samples from the distribution
        next_token_id = torch.multinomial(probs, num_samples=1)
        
        return next_token_id # returns the predicted token
        
    def generate(
        self,
        prompt: str,
        output_len: int = 100, # the model will output 100 tokens
        temperature: float = 0.95, # 0.95 is pretty close to not even using temperature at all (1.0 would be no effect)
        top_p: float = 1.0, # defaulting to 1 means we essentially don't use top-p
        top_k: int = 65, # setting top_k = vocab_size means we're effectively not using top_k at all
    ) -> str: 
        """Generates responses for given prompts using Gemma model."""
        
        # encoding the prompt into token indices
        tokens = self.tokenizer.encode(prompt)

        # turning it into the right tensor shape
        tokens = torch.tensor(tokens, device=config.device).unsqueeze(0)
        
        # we wouldn't want to go past the maximum context length we trained on
        assert len(tokens) + output_len <= self.config.max_position_embeddings

        for i in range(output_len):
            # get the model's output logits and ignore the loss, which would be a NoneType object
            logits, _ = self(tokens[:,:self.max_seq_len])
            
            next_token = self.Sampler(
                logits = logits, # the actual output of the model
                temperature = temperature,
                top_p = top_p,
                top_k = top_k
            )

            # add our new token to the sequence
            tokens = torch.cat((tokens, next_token), dim=1)

        # decode our list of tokens to an actual string
        output = self.tokenizer.decode(tokens.squeeze(0).tolist())

        return output

# Training-related Functions

In [13]:
# Train and test splits
data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be our training dataset, the rest for validation
train_data = data[:n]
val_data = data[n:]

In [14]:
# data loading for training which generates a small batch of data of inputs x and targets y
def get_batch(split, batch_size):
    # whether we grab from our training or validation dataset
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - config.max_position_embeddings, (batch_size,))
    x = torch.stack([data[i:i+config.max_position_embeddings] for i in ix])
    y = torch.stack([data[i+1:i+config.max_position_embeddings+1] for i in ix])
    x, y = x.to(config.device), y.to(config.device)
    return x, y

In [15]:
# a demonstration of what a batch with batch_size=1 looks like. Notice the one-token offset in characters
xb, yb = get_batch('train', 1)
print(xb)
print(tokenizer.decode(xb.squeeze(0).tolist()))
print("-------")
print(yb)
print(tokenizer.decode(yb.squeeze(0).tolist()))

tensor([[ 88,   1,  72,  52,   0,  14,  43,   1,  54,  68,  44,  43,  41,  58,
          85,  24,  33,  15,  21,  27,  71,  21,   1,  61,  77,  56,  70,  58,
           1,  88,  56,   1,  46,  79, 127,  85,  16,  33,  23,  17,   1,  34,
          21,  26,  15,  17,  26,  32,  21,  27,  71,  32,  87,   1,  61,  77,
          56,  70,  58,  57,   1, 105,   1,  88,  56,  91,  50,  44, 125,  58,
          39, 124,   1,  87, 113,   1,  84,   5,  58,  85,  21,  31,  13,  14,
          17,  24,  24,  13,  71,  32, 102,  57,   1,  45,  76,  58,  99,  51,
          70,   1,  84, 111,   1,  57,  53,  83,  61,  92,  58,   1,  94,   1,
         101,   1,  58,  39,  99,   6,   7,   7,  75,  24,  33,  15,  21,  27,
          71,  30,  47, 122,  58,  85,  16,  33,  23,  17,   1,  34,  21,  26,
          15,  17,  26,  32,  21,  27,  71,  21,  58,   1,  51, 106,   1,  98,
           1,  56,  47, 122,  58, 125,  40, 114,   1,  88,   1,  77,  43,   1,
          47,   5,   1,  72,   1,  61, 115,  52,  45

In [16]:
@torch.no_grad()
def estimate_loss(model, batch_size, eval_iters = 10): # to estimate loss during the training loop
    out = {}
    model.eval() # sets model to eval mode
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, batch_size)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train() # just resets to training mode
    return out

# Instantiating a brand new model

In [17]:
model = minGemma(config, tokenizer).to(config.device)

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

print(model)

972.416 K parameters
minGemma(
  (embedder): Embedding(128, 128)
  (model): Body(
    (layers): ModuleList(
      (0-3): 4 x Layer(
        (self_attn): Attention(
          (qkv_proj): Linear(in_features=128, out_features=192, bias=False)
          (o_proj): Linear(in_features=128, out_features=128, bias=False)
        )
        (mlp): MLP(
          (gate_proj): Linear(in_features=128, out_features=512, bias=True)
          (up_proj): Linear(in_features=128, out_features=512, bias=True)
          (down_proj): Linear(in_features=512, out_features=128, bias=True)
        )
        (input_layernorm): RMSNorm()
        (post_attention_layernorm): RMSNorm()
      )
    )
    (norm): RMSNorm()
  )
  (criterion): CrossEntropyLoss()
)


# Load a Pretrained Model

In [18]:
# Initialize a blank model
model = minGemma(config, tokenizer).to(config.device)  

# here's the path to a minGemma model that i've trained with roughly 1m parameters
path = 'models/minGemma-vocab_size128-max_position_embeddings256-num_hidden_layers4-num_attention_heads4-num_key_value_heads1-hidden_size128-intermediate_size512-head_dim32-rms_norm_eps1e-06-rope_theta100.0--2024-02-26|11-10-53.pth'

# Load the saved state dictionary
model.load_state_dict(torch.load(path))
# REMEMBER TO CHANGE VALUES IN CONFIG TO MATCH THE MODEL YOU'VE LOADED

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

# If you only plan to do inference, switch to evaluation mode
model.eval()

# If you plan to continue training the model, switch to training mode
#model.train()

964.352 K parameters


minGemma(
  (embedder): Embedding(65, 128)
  (model): GemmaBody(
    (layers): ModuleList(
      (0-3): 4 x GemmaDecoderLayer(
        (self_attn): GemmaAttention(
          (qkv_proj): Linear(in_features=128, out_features=192, bias=False)
          (o_proj): Linear(in_features=128, out_features=128, bias=False)
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=128, out_features=512, bias=True)
          (up_proj): Linear(in_features=128, out_features=512, bias=True)
          (down_proj): Linear(in_features=512, out_features=128, bias=True)
        )
        (input_layernorm): RMSNorm()
        (post_attention_layernorm): RMSNorm()
      )
    )
    (norm): RMSNorm()
  )
  (criterion): CrossEntropyLoss()
)

# Training

In [18]:
# create a PyTorch optimizer
# this is not what they used, but this learning rate & weight decay work for our tiny minGemma
learning_rate = 3e-4
weight_decay = 0.01
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# how long we want to train for
max_iters = 5000

# how often we want to check & see how our loss is doing
eval_interval = 250

# batch size to use
batch_size = 32

In [19]:
start_time = time.time()

# Enable anomaly detection. uncomment these lines if you need to do extensive debugging
#torch.autograd.set_detect_anomaly(True)

for iter in range(max_iters):

    # sample a batch of data
    xb, yb = get_batch('train', batch_size)
    
    # train
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    
    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        current_time = time.time()
        elapsed_time = current_time - start_time
        losses = estimate_loss(model, batch_size)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time elapsed: {elapsed_time:.2f} seconds")

# Disable anomaly detection after the training loop
#torch.autograd.set_detect_anomaly(False)

step 0: train loss 130.0421, val loss 130.0103, time elapsed: 0.73 seconds
step 250: train loss 4.8103, val loss 4.9765, time elapsed: 138.16 seconds
step 500: train loss 3.6559, val loss 3.6816, time elapsed: 342.70 seconds
step 750: train loss 3.2894, val loss 3.3416, time elapsed: 566.02 seconds
step 1000: train loss 3.1266, val loss 3.1717, time elapsed: 769.44 seconds
step 1250: train loss 3.0514, val loss 3.1126, time elapsed: 903.18 seconds
step 1500: train loss 2.9887, val loss 3.0574, time elapsed: 1036.83 seconds
step 1750: train loss 2.9147, val loss 3.0104, time elapsed: 1171.64 seconds
step 2000: train loss 2.8687, val loss 2.9626, time elapsed: 1336.68 seconds
step 2250: train loss 2.8162, val loss 2.9178, time elapsed: 1470.40 seconds
step 2500: train loss 2.7705, val loss 2.8822, time elapsed: 1652.25 seconds
step 2750: train loss 2.7071, val loss 2.8136, time elapsed: 1785.60 seconds
step 3000: train loss 2.6603, val loss 2.7935, time elapsed: 1918.52 seconds
step 3250

# Saving your model

In [20]:
# save the model currently held in memory
# the filename specifies the model's class, hyperparameters, and date/time it was saved
torch.save(model.state_dict(),
           f'models/{model.__class__.__name__}'
           f'-vocab_size{config.vocab_size}'
           f'-max_position_embeddings{config.max_position_embeddings}'
           f'-num_hidden_layers{config.num_hidden_layers}'
           f'-num_attention_heads{config.num_attention_heads}'
           f'-num_key_value_heads{config.num_key_value_heads}'
           f'-hidden_size{config.hidden_size}'
           f'-intermediate_size{config.intermediate_size}'
           f'-head_dim{config.head_dim}'
           f'-rms_norm_eps{config.rms_norm_eps}'
           f'-rope_theta{config.rope_theta}'
           f'--{time.strftime("%Y-%m-%d|%H-%M-%S")}.pth')

# Inference

In [23]:
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou " # the classic line
max_useable_output_len = config.max_position_embeddings - len(input_str)
output = model.generate(input_str, output_len = max_useable_output_len)
print(output)

JULIET:
O Romeo, Romeo! wherefore art thou lord,
Bol, am the lad we vowerly her lastion greathe!

Voddon:
He his latter in my slould is fleeck frideed
Or sperate so placel
And mot to which conour barksag his light
And see as please mene meanner.
This scied what is ontued to my lead
How I dod, me wit destrined have fain and
do by 
