# Next-Concept Prediction v11.2

So the main plan here is to
1. [x] turn minGemma into a kind of MatFormer
2. [x] have the smallest level deal with tokens, medium with first-level concepts, and large with 2nd-level concepts
3. [x] dynamically generate concept vectors based on regressive cosine similarity rather than storing huge vector combination vocabularies
4. [x] allow lower levels to attend to the vector representations at the levels above them using cross-attention
5. [ ] during inference, the higher levels will run first and the lower levels attend to them; each time the lowest level finishes a prediction of length config.combo, that value gets pushed up to the higher levels, which then have to recalculate their predictions
    - [ ] rather than actually using the raw regression outputs, we can use cosine similarity to search for the token combination at that level that is most similar to the current regression output, and then add that to the sequence for our next run of the decoder instead. That way the model is always working in-distribution. Would be interesting to see which one performs better. Also, if this option is chosen then config.ca_noise_sd would be useless bc we'd no longer have to deal with lower levels attending to upper level concepts that aren't in-distribution.
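A minimal sketch of the cosine-similarity "snap" from the sub-bullet above. It assumes the candidate concept vectors for a level are already available as a `[num_candidates, d]` tensor. How those candidates would actually be built (e.g. from the token combinations seen so far) isn't decided yet, so `candidates` and `snap_to_nearest` are hypothetical names, not part of the notebook.

```python
import torch
import torch.nn.functional as F

def snap_to_nearest(pred: torch.Tensor, candidates: torch.Tensor):
    """Hypothetical helper: swap each regressed concept vector for its most cosine-similar candidate.

    pred:       [batch, seq_len, d] raw regression outputs from one level
    candidates: [num_candidates, d] concept vectors the level could legitimately emit
    """
    # cosine similarity = dot product of unit vectors -> [batch, seq_len, num_candidates]
    sims = F.normalize(pred, dim=-1) @ F.normalize(candidates, dim=-1).T
    idx = sims.argmax(dim=-1)       # [batch, seq_len] index of the best match
    return candidates[idx], idx     # snapped vectors are always in-distribution

torch.manual_seed(0)
candidates = torch.randn(10, 16)
# pretend the model regressed noisy, rescaled versions of candidates 3 and 7
pred = (candidates[[3, 7]] * 2.5 + 0.05 * torch.randn(2, 16)).unsqueeze(0)
snapped, idx = snap_to_nearest(pred, candidates)
print(idx)  # indices of the chosen candidates, shape [1, 2]
```

Since cosine similarity ignores magnitude, the 2.5x rescaling in the fake prediction doesn't affect which candidate wins; only the direction matters.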

less important todo list
- [x] clean up the primary forward() by putting things into their own classes
- [ ] ~~change the get_batch() function to give us sequence lengths that are not multiples of config.combo so that the thing generalizes to prompts that aren't a perfect length. likely need to add a <|padding|> embedding vector for this to work well. should it be learnable or maybe set it to all zeros?~~ nvm i don't think that's a good idea bc i don't want to encourage the model to predict simpler concepts (concepts made up of fewer tokens/concepts) so i think instead i have to just hope that if they give me a prompt whose length isn't a multiple of config.combo then it can generalize.
    - [x] yoooo what if instead of padding the end i pad the beginning???? models tend to use the first embedding in the sequence as an attention sink anyways and that way it'd never be encouraged to predict dumb shit. ok i think this is the way to go. should be a simple edit to the combine embeddings class and then i can do the thing where i give varying sequence lengths. wait but how would that mess with what gets paid attention to? i think it doesn't matter with our current version of cross-attention. so just adjust the current implementation of padding and then the get_batch() function
- [x] add an option to norm the initial residual state instead of scaling it by sqrt(embed_dim). would it be interesting to both norm and scale? for example, with cosine norm that would put the initial vectors on a sqrt(d) radius hypersphere centered at the origin i guess
- [ ] add noise to the upper-triangular part of the cross-attention; maybe another option where this "noise mask" is stronger over the course of each row to represent the likely inaccuracy that the model will experience during inference for progressive concept embeddings. or alternatively, add a mask that only allows the model to see the concept it's currently working on, so kinda like a mask of diagonal 1's except with blocks of size config.combo rather than entries.
- [ ] pump `training` variable throughout for stuff like dropout
- [ ] add a ~~config.verbose dictionary~~ debug logger
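For the block-diagonal idea at the end of the noise/mask item above, here's a sketch of what that mask could look like. It assumes lower-level position `t` belongs to upper-level concept `t // combo`, and that the `ca_use_RoPE` expansion repeats each upper-level vector `combo` times in order; both are assumptions, since `crossMQA` doesn't build any mask yet. `concept_block_mask` is a hypothetical name.

```python
import torch

def concept_block_mask(seq_len: int, combo: int, expanded: bool = True) -> torch.Tensor:
    """Hypothetical helper: True where a lower-level query may attend to an upper-level position.

    expanded=True  -> [seq_len, seq_len] mask for keys/values repeated `combo` times
    expanded=False -> [seq_len, seq_len // combo] mask over the raw upper-level sequence
    """
    assert seq_len % combo == 0, "seq_len must be a multiple of combo"
    if expanded:
        # kron(I, ones) places a combo x combo block of ones in each diagonal slot
        blocks = torch.eye(seq_len // combo)
        return torch.kron(blocks, torch.ones(combo, combo)).bool()
    # each query row t only sees concept t // combo
    rows = torch.arange(seq_len)[:, None] // combo
    cols = torch.arange(seq_len // combo)[None, :]
    return rows == cols

print(concept_block_mask(6, 2).int())
print(concept_block_mask(6, 2, expanded=False).int())
```

Either version could be applied the same way `selfMQA.apply_mask` uses its causal mask, via `torch.where(mask, logits, -1e30)`.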

further ideas for v11.3 (maybe call it 11.2b?):
- [ ] do it without the matryoshka-ness
    - [ ] add multiple <|bos|> tokens to put at the beginning of each sequence, one for each level. So that way the model can know what level it's working with and have somewhere to sink attention.
- [ ] set up `effective_seq_len_mult` and `seq_len_list` and a bunch of downstream stuff like weird batching at lower sub-levels to allow for hella long effective contexts

In [1]:
# my virtual environments are rarely properly connected to jupyter so this fixes that
import sys
#sys.path.append('./venv/lib/python3.10/site-packages')
import os
current_dir = os.getcwd()  # Get the current working directory
venv_dir = os.path.join(current_dir, 'venv') 
python_version = str(sys.version_info.major) + '.' + str(sys.version_info.minor)
site_packages_path = os.path.join(venv_dir, 'lib', 'python' + python_version, 'site-packages')
sys.path.append(site_packages_path) 

# Importing pytorch
import torch
import torch.nn as nn
from torch.nn import functional as F

# imports for the debugging/demonstration setup
import functools
import inspect

# imports for the tokenizer
from tokenizer import SimpleTokenizer, loaded_stoi, loaded_merges

# Imports used for the config
import dataclasses 
from typing import Optional

# Imports used for the model
import re
from typing import Any, List, Sequence, Tuple, Union

# used for training
import random
import time

In [69]:
def log_io(func):
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        if not self.logging_enabled:
            return func(self, *args, **kwargs)

        print(f"\n---- Entering {self.__class__.__name__}.{func.__name__} ----")

        # Getting the names and values of the arguments. Binding to the signature keeps
        # keyword arguments paired with the right names even if they're passed out of order
        bound = inspect.signature(func).bind(self, *args, **kwargs)
        named_args = list(bound.arguments.items())[1:]  # Excluding 'self'

        # Print input information
        for name, value in named_args:
            if isinstance(value, torch.Tensor):
                print(f"Input tensor '{name}' shape: {value.shape}")
            elif isinstance(value, int):
                print(f"Input integer '{name}': Value={value}")
            else:
                value_type = type(value).__name__
                value_len = len(value) if hasattr(value, '__len__') else 'N/A'
                print(f"Input '{name}': Type={value_type}, Length={value_len}")

        result = func(self, *args, **kwargs)

        # Handle multiple outputs
        if isinstance(result, tuple):
            for idx, item in enumerate(result):
                if isinstance(item, torch.Tensor):
                    print(f"Output tensor {idx} shape: {item.shape}")
                elif isinstance(item, int):
                    print(f"Output integer {idx}: Value={item}")
                else:
                    item_type = type(item).__name__
                    item_len = len(item) if hasattr(item, '__len__') else 'N/A'
                    print(f"Output {idx}: Type={item_type}, Length={item_len}")
        elif isinstance(result, torch.Tensor):
            print(f"Output tensor shape: {result.shape}")
        elif isinstance(result, int):
            print(f"Output integer: Value={result}")
        else:
            result_type = type(result).__name__
            result_len = len(result) if hasattr(result, '__len__') else 'N/A'
            print(f"Output: Type={result_type}, Length={result_len}")

        print(f"---- Exiting {self.__class__.__name__}.{func.__name__} ----")

        return result
    return wrapper

In [70]:
# load the dataset
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# and the tokenizer
tokenizer = SimpleTokenizer(loaded_stoi, loaded_merges)

# RoPE

In [71]:
def RoPE(x: torch.Tensor, dim: int, theta: float = 10000.0) -> torch.Tensor:
    """Applies the rotary embedding to the inputted query or key tensor"""
    # Get sequence length
    seq_len = x.size(1)
    device = x.device

    # Dynamically compute frequency cis based on the input sequence length
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
    t = torch.arange(seq_len, device=device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    # it's important to train on a wide variety of sequence lengths within your context length so that the model learns to generalize

    # Apply rotary embeddings to the input tensor
    x_ = torch.view_as_complex(torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
    x_out = torch.view_as_real(x_ * freqs_cis.unsqueeze(0)).type_as(x)  # freqs_cis broadcasts over the batch and head dims
    x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
    x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], -1).transpose(1, 2)

    return x_out
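A quick way to sanity-check `RoPE` is the property rotary embeddings exist for: the query/key dot product should depend only on the offset between the two positions, and vector lengths shouldn't change. The sketch below restates the same half-split rotation in real arithmetic (first half of each head = real part, second half = imaginary part), assuming inputs shaped `[batch, seq_len, heads, dim]` as `selfMQA` passes them. `rope_real` is a hypothetical name meant to mirror the complex-number version above, not replace it.

```python
import torch

def rope_real(x: torch.Tensor, dim: int, theta: float = 10000.0) -> torch.Tensor:
    """Half-split rotary embedding in real arithmetic; x is [batch, seq_len, heads, dim]."""
    seq_len = x.size(1)
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))  # [dim/2]
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)         # [seq_len, dim/2]
    cos = angles.cos()[None, :, None, :]  # broadcast over batch and heads
    sin = angles.sin()[None, :, None, :]
    x1, x2 = x[..., : dim // 2], x[..., dim // 2 :]  # "real" half, "imaginary" half
    # complex multiply (x1 + i*x2) * (cos + i*sin), written out in reals
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

torch.manual_seed(0)
dim, seq_len = 8, 16
q_vec, k_vec = torch.randn(dim), torch.randn(dim)
# the same query/key content placed at every position
q = q_vec.expand(1, seq_len, 1, dim)
k = k_vec.expand(1, seq_len, 1, dim)
rq, rk = rope_real(q, dim, theta=100.0), rope_real(k, dim, theta=100.0)

score_3_1 = (rq[0, 3, 0] * rk[0, 1, 0]).sum()    # query at 3, key at 1: offset 2
score_10_8 = (rq[0, 10, 0] * rk[0, 8, 0]).sum()  # query at 10, key at 8: also offset 2
print(torch.allclose(score_3_1, score_10_8, atol=1e-4))           # only m - n matters
print(torch.allclose(rq.norm(dim=-1), q.norm(dim=-1), atol=1e-4))  # rotations keep lengths
```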

# Norms

In [72]:
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6, affine = True, verbose = False):
        super().__init__() 
        self.eps = eps
        self.affine = affine
        
        # Initialize the affine parameters with zeros; they'll be learned during training.
        # The shape of each is [dim], meaning one value per feature dimension.
        # The weight is applied as (1 + self.w), so the zero init starts out as no scaling.
        self.w = nn.Parameter(torch.zeros(dim))
        self.b = nn.Parameter(torch.zeros(dim))
        
        self.logging_enabled = False
    def enable_logging(self):
        self.logging_enabled = True
    def disable_logging(self):
        self.logging_enabled = False

    @log_io
    def norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    @log_io
    def scale(self, x, d_i):
        # scale the normalized tensor by (1 + self.w), which effectively starts with no scaling
        return x * (1 + self.w[:d_i])

    @log_io
    def splice_affine(self, weight, bias, d_i):
        return weight[:d_i], bias[:d_i]

    @log_io
    def forward(self, x: torch.Tensor, model: int = 0) -> torch.Tensor:
        # Normalize the input tensor using root mean square normalization
        x = self.norm(x)

        # Optionally apply the affine transformation with splicing
        if self.affine:
            d_i = x.shape[-1]  # grab x's dimension for splicing
            w, b = self.splice_affine(self.w, self.b, d_i)
            x = x * (1 + w) + b  # (1 + w) since w is zero-initialized; plain x * w would zero out x
            
        return x

In [73]:
# Create an instance of RMSNorm
norm_layer = RMSNorm(dim=10)

# Initially, logging is disabled
# Enable logging
norm_layer.enable_logging()

# Call the forward method - logging will occur
output = norm_layer(torch.randn(5, 5))

# Disable logging. 
# This isn't actually necessary since we won't be using this object again but that's how you'd do it
norm_layer.disable_logging()

# clearing up ram jic we're training later
del norm_layer, output


---- Entering RMSNorm.forward ----
Input tensor 'x' shape: torch.Size([5, 5])

---- Entering RMSNorm.norm ----
Input tensor 'x' shape: torch.Size([5, 5])
Output tensor shape: torch.Size([5, 5])
---- Exiting RMSNorm.norm ----

---- Entering RMSNorm.splice_affine ----
Input tensor 'weight' shape: torch.Size([10])
Input tensor 'bias' shape: torch.Size([10])
Input integer 'd_i': Value=5
Output tensor 0 shape: torch.Size([5])
Output tensor 1 shape: torch.Size([5])
---- Exiting RMSNorm.splice_affine ----
Output tensor shape: torch.Size([5, 5])
---- Exiting RMSNorm.forward ----


In [74]:
class CosineNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6, affine=True):
        super().__init__()
        self.eps = eps
        self.affine = affine

        # Initialize weight and bias parameters for affine transformation
        # We start with ones for weight to keep the original scale initially, and zeros for bias.
        self.w = nn.Parameter(torch.ones(dim))
        self.b = nn.Parameter(torch.zeros(dim))
        
        self.logging_enabled = False
    def enable_logging(self):
        self.logging_enabled = True
    def disable_logging(self):
        self.logging_enabled = False

    @log_io
    def norm(self, x):
        # Cosine normalization: divide x by its L2 norm along the last dimension.
        # The norm is clamped to at least eps to avoid division by zero.
        l2_norm = torch.norm(x, p=2, dim=-1, keepdim=True).clamp(min=self.eps)
        return x / l2_norm

    @log_io
    def splice_affine(self, weight, bias, d_i):
        return weight[:d_i], bias[:d_i]

    @log_io
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize the input tensor using cosine normalization
        x = self.norm(x)

        # Optionally apply the affine transformation with splicing
        if self.affine:
            d_i = x.shape[-1]  # grab x's dimension for splicing
            w, b = self.splice_affine(self.w, self.b, d_i)
            x = x * w + b
            
        return x

In [75]:
# demonstration/debugging
norm_layer = CosineNorm(dim=10)
norm_layer.enable_logging()
output = norm_layer(torch.randn(5, 5))
del norm_layer, output


---- Entering CosineNorm.forward ----
Input tensor 'x' shape: torch.Size([5, 5])

---- Entering CosineNorm.norm ----
Input tensor 'x' shape: torch.Size([5, 5])
Output tensor shape: torch.Size([5, 5])
---- Exiting CosineNorm.norm ----

---- Entering CosineNorm.splice_affine ----
Input tensor 'weight' shape: torch.Size([10])
Input tensor 'bias' shape: torch.Size([10])
Input integer 'd_i': Value=5
Output tensor 0 shape: torch.Size([5])
Output tensor 1 shape: torch.Size([5])
---- Exiting CosineNorm.splice_affine ----
Output tensor shape: torch.Size([5, 5])
---- Exiting CosineNorm.forward ----
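On the todo question of whether to both norm and scale the initial residual: cosine-normalizing and then scaling by sqrt(d) puts every vector on the radius-sqrt(d) sphere, which (up to eps) is exactly what RMSNorm already does, because RMS(x) = ||x|| / sqrt(d). A quick numeric check with the affine parts left off, re-doing the two `norm` methods inline:

```python
import torch

torch.manual_seed(0)
d = 128
x = torch.randn(4, d)

rms = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)  # RMSNorm.norm
cos = x / x.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-6)     # CosineNorm.norm
cos_scaled = cos * d ** 0.5                                      # norm, then scale by sqrt(d)

print(cos_scaled.norm(dim=-1))                     # every row has length sqrt(128)
print(torch.allclose(cos_scaled, rms, atol=1e-4))  # same vectors as RMSNorm
```

So `norm_init_embed=True` with `CosineNorm` plus `scale_init_embed=True` with `init_scale_degree=2` should behave like RMSNorm on the initial residual, at least before any affine parameters are learned.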


In [76]:
class LayerNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6, affine = True):
        super().__init__()
        self.eps = eps
        self.affine = affine

        # Define weight (scale) and bias (shift) as learnable parameters
        self.w = nn.Parameter(torch.ones(dim))
        self.b = nn.Parameter(torch.zeros(dim))
        
        self.logging_enabled = False
    def enable_logging(self):
        self.logging_enabled = True
    def disable_logging(self):
        self.logging_enabled = False

    @log_io
    def calc_mean(self, x):
        return x.mean(dim=-1, keepdim=True)

    @log_io
    def calc_var(self, x):
        return x.var(dim=-1, keepdim=True, unbiased=False)

    @log_io
    def norm(self, x, mean, var):
        return (x - mean) / torch.sqrt(var + self.eps)

    @log_io
    def splice_affine(self, w, b, d_i):
        return w[:d_i], b[:d_i]
        
    @log_io
    def forward(self, x):
        # Normalize the input
        mean = self.calc_mean(x)
        var = self.calc_var(x)
        x = self.norm(x, mean, var)

        # Optionally apply the affine transformation with splicing
        if self.affine:
            d_i = x.shape[-1]  # grab x's dimension for splicing
            w, b = self.splice_affine(self.w, self.b, d_i)
            x = x * w + b

        return x

In [77]:
# demonstration/debugging
norm_layer = LayerNorm(dim=10)
norm_layer.enable_logging()
output = norm_layer(torch.randn(5, 5))
del norm_layer, output


---- Entering LayerNorm.forward ----
Input tensor 'x' shape: torch.Size([5, 5])

---- Entering LayerNorm.calc_mean ----
Input tensor 'x' shape: torch.Size([5, 5])
Output tensor shape: torch.Size([5, 1])
---- Exiting LayerNorm.calc_mean ----

---- Entering LayerNorm.calc_var ----
Input tensor 'x' shape: torch.Size([5, 5])
Output tensor shape: torch.Size([5, 1])
---- Exiting LayerNorm.calc_var ----

---- Entering LayerNorm.norm ----
Input tensor 'x' shape: torch.Size([5, 5])
Input tensor 'mean' shape: torch.Size([5, 1])
Input tensor 'var' shape: torch.Size([5, 1])
Output tensor shape: torch.Size([5, 5])
---- Exiting LayerNorm.norm ----

---- Entering LayerNorm.splice_affine ----
Input tensor 'w' shape: torch.Size([10])
Input tensor 'b' shape: torch.Size([10])
Input integer 'd_i': Value=5
Output tensor 0 shape: torch.Size([5])
Output tensor 1 shape: torch.Size([5])
---- Exiting LayerNorm.splice_affine ----
Output tensor shape: torch.Size([5, 5])
---- Exiting LayerNorm.forward ----
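One way to convince yourself the spliced affine path is right: normalizing a `d_i`-wide input and then applying `w[:d_i]`, `b[:d_i]` should match PyTorch's built-in `F.layer_norm` run at width `d_i` with those same sliced parameters. The sketch re-does the `LayerNorm` math above inline (biased variance, eps=1e-6) rather than importing the class, and uses random stand-ins for trained affine weights.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
full_dim, d_i, eps = 10, 5, 1e-6
w, b = torch.randn(full_dim), torch.randn(full_dim)  # stand-ins for trained affine params
x = torch.randn(5, d_i)                              # a sub-model's residual state

# the notebook's LayerNorm: biased variance, then spliced affine
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
manual = (x - mean) / torch.sqrt(var + eps) * w[:d_i] + b[:d_i]

# PyTorch's reference implementation at the smaller width
reference = F.layer_norm(x, (d_i,), weight=w[:d_i], bias=b[:d_i], eps=eps)
print(torch.allclose(manual, reference, atol=1e-5))
```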


# Config

In [78]:
@dataclasses.dataclass # a class meant specifically to just hold data
class Config:
    """ 
    The default configuration & hyperparameters for my next-concept predictor
    """
    ### boring hyperparameters
    vocab_size: int = tokenizer.vocab_len
    max_seq_len: int = 256
    num_layers: int = 4
    sa_q_heads: int = 4
    sa_kv_heads: int = 1
    embed_dim: int = 128 
    mlp_multiplier: int = 4
    sa_head_dim: int = 32
    theta = 100.0
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    ### Normalization
    norm_init_embed: bool = False # whether or not to normalize right after turning input indices into the first residual state. default False
    scale_init_embed: bool = True # whether or not to scale the initial residual state by the nth root of embed_dim. default True
    init_scale_degree: int = 2 # what nth root to take of embed_dim to use as the scale for the initial residual. default 2
    norm_final_embed: bool = True # whether or not to normalize the embeddings at the output layer before multiplying for logits. default False
    norm_affine: bool = False # whether norms should have a linear & bias after them. default True
    # Directly assign the class of the norm to be used
    def __post_init__(self): # you can't define this one until the norm classes have already been defined
        self.norm = CosineNorm  # Options are RMSNorm, CosineNorm and LayerNorm

    ### MatFormer
    levels = 3
    split = 2
    @property
    def embed_dim_list(self):
        return [self.embed_dim // (self.split ** (i-1)) for i in range(self.levels, 0, -1)]
    @property
    def sa_head_dim_list(self):
        return [self.sa_head_dim // (self.split ** (i-1)) for i in range(self.levels, 0, -1)]

    ### concept embedding vectors
    combo = split#*2 
    concept_loss = "cos" # defaults to cosine. other options are 'mae' and 'mse'
    # gotta figure out how to define seq_len_list s.t. this effective_seq_len_mult works. 
    # also then i'd have to change get_batch accordingly. I think this is more for v11.3
    #effective_seq_len_mult = 1 
    @property
    def seq_len_list(self):
        #return [(self.max_seq_len // (self.split ** (self.levels-1))) // (self.split ** (i-1)) for i in range(self.levels, 0, -1)]
        return [(self.max_seq_len // (self.combo ** (i-1))) for i in range(1, self.levels + 1)]
    # how much to discount each higher level in the loss function compared to the last. gives the smaller models more say in the gradient
    level_loss_weight = 0.5

    ### cross-attention
    # how many times to do the cross-attention
    ca_connections: int = 2 
    ca_q_heads: int = sa_q_heads
    ca_kv_heads: int = sa_kv_heads
    ca_head_dim: int = sa_head_dim // split
    shared_ca: bool = False # True: same weights shared for each ca connection. False: separately instantiated module for each connection
    ca_use_RoPE: bool = True # True: expands out k & v tensors to be usable with rope. False: leaves k & v same size but no positional encodings
    ca_noise_sd: Optional[float] = None # float: adds gaussian noise with that sd to upper-level logits. None: no noise added
    @property
    def ca_head_dim_list(self):
        return [self.ca_head_dim // (self.split ** (i-1)) for i in range(self.levels-1, 0, -1)]
    @property
    def ca_interval(self):
        return self.num_layers // self.ca_connections

    ### assertions
    assert sa_q_heads % sa_kv_heads == 0, 'the number of query heads must be divisible by the number of key-value heads in self-attention'
    assert ca_q_heads % ca_kv_heads == 0, 'the number of query heads must be divisible by the number of key-value heads in cross-attention'
    assert max_seq_len % (split ** levels) == 0, 'your context length must be a multiple of split ** levels'
    assert embed_dim % (split ** levels) == 0, 'your embedding dimension must be a multiple of split ** levels'
    assert sa_head_dim % (split ** levels) == 0, 'your self-attention head dimension must be divisible by split ** levels'
    assert ca_head_dim % (split ** levels) == 0, 'your cross-attention head dimension must be divisible by split ** levels'
    assert num_layers % ca_connections == 0, 'the total number of layers must be divisible by the number of cross-attention connections'

config = Config()
print(config)
print(f"embedding dimension size of each model: {config.embed_dim_list}")
print(f"attention head size of each model: {config.sa_head_dim_list}")
print(f"sequence length of each model: {config.seq_len_list}")
print(f"how many embeddings will be combined at each layer: {config.combo}")
print(f"base head dimension of cross-attention to higher levels: {config.ca_head_dim}")
print(f"attention head size of each connection starting from lowest level: {config.ca_head_dim_list}")
print(f"cross-attention to a higher level will be applied every {config.ca_interval} layers")
print(f"loss discounts starting from lowest level: {[config.level_loss_weight**i for i in range(config.levels)]}")

Config(vocab_size=128, max_seq_len=256, num_layers=4, sa_q_heads=4, sa_kv_heads=1, embed_dim=128, mlp_multiplier=4, sa_head_dim=32, norm_init_embed=False, scale_init_embed=True, init_scale_degree=2, norm_final_embed=True, norm_affine=False, ca_connections=2, ca_q_heads=4, ca_kv_heads=1, ca_head_dim=16, shared_ca=False, ca_use_RoPE=True, ca_noise_sd=None)
embedding dimension size of each model: [32, 64, 128]
attention head size of each model: [8, 16, 32]
sequence length of each model: [256, 128, 64]
how many embeddings will be combined at each layer: 2
base head dimension of cross-attention to higher levels: 16
attention head size of each connection starting from lowest level: [8, 16]
cross-attention to a higher level will be applied every 2 layers
loss discounts starting from lowest level: [1.0, 0.5, 0.25]
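The config only names the concept-loss options (`'cos'`, `'mae'`, `'mse'`) and the per-level discount, so here's a hedged sketch of how they could combine. It assumes level 0 is scored with cross-entropy over the vocabulary and each higher level with a regression loss on its predicted concept vectors; `concept_loss` and `total_loss` are hypothetical names, and the real training loop may reduce or weight things differently. The toy shapes follow the printed config (widths double and sequence lengths halve per level).

```python
import torch
import torch.nn.functional as F

def concept_loss(pred: torch.Tensor, target: torch.Tensor, kind: str = "cos") -> torch.Tensor:
    """Hypothetical regression loss between predicted and target concept vectors ([..., d])."""
    if kind == "cos":
        return (1 - F.cosine_similarity(pred, target, dim=-1)).mean()
    if kind == "mae":
        return F.l1_loss(pred, target)
    if kind == "mse":
        return F.mse_loss(pred, target)
    raise ValueError(f"unknown concept_loss: {kind!r}")

def total_loss(level_losses, level_loss_weight: float = 0.5) -> torch.Tensor:
    """Discount level i by level_loss_weight ** i so the lower levels dominate the gradient."""
    return sum(level_loss_weight ** i * loss for i, loss in enumerate(level_losses))

torch.manual_seed(0)
vocab_size = 128
token_logits = torch.randn(2, 8, vocab_size)
token_targets = torch.randint(0, vocab_size, (2, 8))
losses = [
    F.cross_entropy(token_logits.view(-1, vocab_size), token_targets.view(-1)),  # level 0: tokens
    concept_loss(torch.randn(2, 4, 64), torch.randn(2, 4, 64), "cos"),            # level 1 concepts
    concept_loss(torch.randn(2, 2, 128), torch.randn(2, 2, 128), "cos"),          # level 2 concepts
]
print(total_loss(losses))
```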


# selfMQA

In [79]:
class selfMQA(nn.Module):
    def __init__(self, config: Config):
        super().__init__()

        self.sa_q_heads = config.sa_q_heads
        self.sa_kv_heads = config.sa_kv_heads
        assert self.sa_q_heads % self.sa_kv_heads == 0
        self.num_queries_per_kv = self.sa_q_heads // self.sa_kv_heads

        self.embed_dim = config.embed_dim
        self.sa_head_dim = config.sa_head_dim
        self.theta = config.theta

        self.Wqkv = nn.Parameter(torch.Tensor(self.embed_dim, (self.sa_q_heads + 2 * self.sa_kv_heads) * self.sa_head_dim))
        nn.init.uniform_(self.Wqkv, -((1 / self.embed_dim) ** 0.5), (1 / self.embed_dim) ** 0.5)

        self.Wo = nn.Parameter(torch.Tensor(self.sa_q_heads * self.sa_head_dim, self.embed_dim))
        nn.init.uniform_(self.Wo, -((1 / (self.sa_q_heads * self.sa_head_dim)) ** 0.5), (1 / (self.sa_q_heads * self.sa_head_dim)) ** 0.5)

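        # for our attention mask we'll create a boolean mask that'll later be turned into large negative values.
        # registering it as a non-persistent buffer means it follows the module through .to(device) calls
        mask = torch.tril(torch.ones((config.max_seq_len, config.max_seq_len), dtype=torch.uint8)
                          ).view(1, 1, config.max_seq_len, config.max_seq_len).to(dtype=torch.bool)
        self.register_buffer('mask', mask, persistent=False)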
        
        self.logging_enabled = False
    def enable_logging(self):
        self.logging_enabled = True
    def disable_logging(self):
        self.logging_enabled = False

    @log_io
    def weight_splicing(self, Wqkv, Wo, d_i, h_i):
        Wq, Wk, Wv = Wqkv.split([self.sa_q_heads * self.sa_head_dim,
                                 self.sa_kv_heads * self.sa_head_dim,
                                 self.sa_kv_heads * self.sa_head_dim],dim = -1)
        Wq = torch.cat([Wq[:d_i, j*self.sa_head_dim:j*self.sa_head_dim + h_i] for j in range(self.sa_q_heads)], dim = 1)
        Wk = torch.cat([Wk[:d_i, j*self.sa_head_dim:j*self.sa_head_dim + h_i] for j in range(self.sa_kv_heads)], dim = 1)
        Wv = torch.cat([Wv[:d_i, j*self.sa_head_dim:j*self.sa_head_dim + h_i] for j in range(self.sa_kv_heads)], dim = 1)
        Wo = torch.cat([Wo[j*self.sa_head_dim :j*self.sa_head_dim + h_i, :d_i] for j in range(self.sa_q_heads)], dim=0)
        return Wq, Wk, Wv, Wo

    @log_io
    def match_headcount(self, xn):
        return torch.repeat_interleave(xn, self.num_queries_per_kv, dim=2)

    @log_io
    def RoPE(self, x, h_i):
        return RoPE(x, h_i, self.theta)

    @log_io
    def attend(self, xq, xk, h_i):
        return torch.matmul(xq, xk.transpose(2, 3)) * (h_i ** -0.5)
        
    @log_io
    def apply_mask(self, logits, input_len):
        return torch.where(self.mask[..., :input_len, :input_len].expand_as(logits),
                           logits,
                           torch.tensor(-1e30, device=logits.device, dtype=logits.dtype))
    
    @log_io
    def calc_output(self, logits, xv, batch_size, input_len):
        # Applies softmax to the logits to obtain attention probabilities
        scores = F.softmax(logits, dim=-1)
        
        # Computes the weighted sum of values based on the attention scores to obtain the output of the attention mechanism.
        output = scores @ xv # [batch_size, n_local_heads, input_len, sa_head_dim]

        # Reshapes the attention output to match the expected output dimensions, combining the heads back into the hidden dimension.
        return output.transpose(1, 2).contiguous().view(batch_size, input_len, -1)

    @log_io
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Extracts batch size and input sequence length from the hidden states tensor.
        batch_size, input_len, d_i = x.shape
        h_i = self.sa_head_dim // (self.embed_dim // d_i)

        # splicing our primary projection to get the correct sub-matrices
        Wq, Wk, Wv, Wo = self.weight_splicing(self.Wqkv, self.Wo, d_i, h_i)
        # technically self.weight_splicing has access to self.Wqkv & Wo but this way our debugger can see them

        # Applies the linear projection to the hidden state to retrieve our q, k & v projections
        xq, xk, xv = x @ Wq, x @ Wk, x @ Wv

        # Reshapes each of the Q, K, and V tensors to separate the heads and align the dimensions for attention operations.
        xq = xq.view(batch_size, -1, self.sa_q_heads, h_i)
        xk = xk.view(batch_size, -1, self.sa_kv_heads, h_i)
        xv = xv.view(batch_size, -1, self.sa_kv_heads, h_i)

        # Applies rotary positional embeddings to queries and keys to incorporate positional information.
        xq = self.RoPE(xq, h_i)
        xk = self.RoPE(xk, h_i)

        # If the number of KV heads is different from the number of query heads, adjusts keys and values to match the query heads count.
        if self.sa_kv_heads != self.sa_q_heads:
            xk = self.match_headcount(xk) # [batch_size, input_len, n_local_heads, sa_head_dim]
            xv = self.match_headcount(xv)

        # Transposes Q, K, and V tensors to align them for the batch matrix multiplication in attention calculation.
        xq = xq.transpose(1, 2) # [batch_size, n_local_heads, input_len, sa_head_dim]
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        # Calculates attention scores by performing a batch matrix multiplication between queries and keys, followed by scaling.
        logits = self.attend(xq, xk, h_i) # [batch_size, n_local_heads, input_len, input_len]
        
        # Applies the lower-triangular mask to the attention logits
        logits = self.apply_mask(logits, input_len)

        # applies values to get final output
        output = self.calc_output(logits, xv, batch_size, input_len) 

        # Applies the final linear projection to the attention output, mapping it back to d_i
        return output @ Wo
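The `weight_splicing` step is what makes this MatFormer-style: a level of width `d_i` uses the first `d_i` rows of `Wqkv` and the first `h_i` columns of every head. One property that follows is that a `d_i`-wide input produces the same per-head query prefixes as the full matrix would on that input zero-padded to `embed_dim`. The sketch checks this for the query projection only, with toy sizes (4 heads of 8 at width 16, spliced to 4 heads of 4 at width 8) and a hypothetical `splice_q` helper that mirrors the list comprehension above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
embed_dim, n_heads, head_dim = 16, 4, 8
d_i, h_i = 8, 4                                  # a level at half width and half head size
Wq = torch.randn(embed_dim, n_heads * head_dim)  # full-size query projection

def splice_q(Wq, d_i, h_i):
    # same slicing as selfMQA.weight_splicing: first d_i rows, first h_i columns of each head
    return torch.cat([Wq[:d_i, j * head_dim: j * head_dim + h_i] for j in range(n_heads)], dim=1)

x_small = torch.randn(3, d_i)               # three tokens at the smaller width
q_small = x_small @ splice_q(Wq, d_i, h_i)  # [3, n_heads * h_i]

# full projection on the zero-padded input, then keep each head's first h_i dims
x_padded = F.pad(x_small, (0, embed_dim - d_i))  # [3, embed_dim]
q_full = (x_padded @ Wq).view(3, n_heads, head_dim)[..., :h_i].reshape(3, -1)

print(torch.allclose(q_small, q_full, atol=1e-5))
```

Because each smaller level's weights are literally a sub-block of the full matrices, gradients from every level land on the shared top-left corner of `Wqkv` and `Wo`.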

In [80]:
self_attention = selfMQA(config)
self_attention.enable_logging()
output = self_attention(torch.randn(32,config.max_seq_len,config.embed_dim))
del self_attention, output


---- Entering selfMQA.forward ----
Input tensor 'x' shape: torch.Size([32, 256, 128])

---- Entering selfMQA.weight_splicing ----
Input tensor 'Wqkv' shape: torch.Size([128, 192])
Input tensor 'Wo' shape: torch.Size([128, 128])
Input integer 'd_i': Value=128
Input integer 'h_i': Value=32
Output tensor 0 shape: torch.Size([128, 128])
Output tensor 1 shape: torch.Size([128, 32])
Output tensor 2 shape: torch.Size([128, 32])
Output tensor 3 shape: torch.Size([128, 128])
---- Exiting selfMQA.weight_splicing ----

---- Entering selfMQA.RoPE ----
Input tensor 'x' shape: torch.Size([32, 256, 4, 32])
Input integer 'h_i': Value=32
Output tensor shape: torch.Size([32, 256, 4, 32])
---- Exiting selfMQA.RoPE ----

---- Entering selfMQA.RoPE ----
Input tensor 'x' shape: torch.Size([32, 256, 1, 32])
Input integer 'h_i': Value=32
Output tensor shape: torch.Size([32, 256, 1, 32])
---- Exiting selfMQA.RoPE ----

---- Entering selfMQA.match_headcount ----
Input tensor 'xn' shape: torch.Size([32, 256, 1,

# MLP

In [81]:
class MLP(nn.Module):
    def __init__(self, embed_dim: int, mlp_multiplier: int):
        super().__init__()
        self.mlp_multiplier = mlp_multiplier
        self.hidden_size = embed_dim * mlp_multiplier

        # the gate
        self.Wgate = nn.Parameter(torch.Tensor(embed_dim, self.hidden_size))
        self.Bgate = nn.Parameter(torch.Tensor(self.hidden_size))
        torch.nn.init.uniform_(self.Wgate, -((1/embed_dim) ** 0.5), (1/embed_dim) ** 0.5)
        torch.nn.init.uniform_(self.Bgate, -((1/embed_dim) ** 0.5), (1/embed_dim) ** 0.5)

        # the up projection
        self.Wup = nn.Parameter(torch.Tensor(embed_dim, self.hidden_size))
        self.Bup = nn.Parameter(torch.Tensor(self.hidden_size))
        torch.nn.init.uniform_(self.Wup, -((1/embed_dim) ** 0.5), (1/embed_dim) ** 0.5)
        torch.nn.init.uniform_(self.Bup, -((1/embed_dim) ** 0.5), (1/embed_dim) ** 0.5)

        # the down projection
        self.Wdown = nn.Parameter(torch.Tensor(self.hidden_size, embed_dim))
        self.Bdown = nn.Parameter(torch.Tensor(embed_dim))
        torch.nn.init.uniform_(self.Wdown, -((1/self.hidden_size) ** 0.5), (1/self.hidden_size) ** 0.5)
        torch.nn.init.uniform_(self.Bdown, -((1/self.hidden_size) ** 0.5), (1/self.hidden_size) ** 0.5)
        
        self.logging_enabled = False
    def enable_logging(self):
        self.logging_enabled = True
    def disable_logging(self):
        self.logging_enabled = False
        
    @log_io
    def forward(self, x):
        d_i = x.shape[-1]
        gate = F.gelu(x @ self.Wgate[:d_i, :d_i * self.mlp_multiplier] + self.Bgate[:d_i * self.mlp_multiplier])
        fuse = gate * (x @ self.Wup[:d_i, :d_i * self.mlp_multiplier] + self.Bup[:d_i * self.mlp_multiplier])
        return fuse @ self.Wdown[:d_i * self.mlp_multiplier, :d_i] + self.Bdown[:d_i]

In [82]:
mlp = MLP(config.embed_dim, config.mlp_multiplier)
mlp.enable_logging()
output = mlp(torch.randn(32,config.max_seq_len,config.embed_dim))
del mlp, output


---- Entering MLP.forward ----
Input tensor 'x' shape: torch.Size([32, 256, 128])
Output tensor shape: torch.Size([32, 256, 128])
---- Exiting MLP.forward ----
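Because `MLP.forward` slices its weights down to `d_i` and `d_i * mlp_multiplier`, each level only touches a corner of the full matrices. With `mlp_multiplier=4`, the three weight matrices contribute `3 * d_i * 4*d_i = 12*d_i**2` and the biases `4*d_i + 4*d_i + d_i = 9*d_i`, so each halving of width roughly quarters the MLP. The sketch counts the sliced elements directly for the config's widths `[32, 64, 128]`; `sliced_mlp_params` is a hypothetical helper.

```python
import torch

def sliced_mlp_params(embed_dim: int, mlp_multiplier: int, d_i: int) -> int:
    """Count the parameters a width-d_i level actually uses out of the full MLP (hypothetical helper)."""
    hidden = embed_dim * mlp_multiplier
    full = {
        "Wgate": torch.empty(embed_dim, hidden), "Bgate": torch.empty(hidden),
        "Wup": torch.empty(embed_dim, hidden), "Bup": torch.empty(hidden),
        "Wdown": torch.empty(hidden, embed_dim), "Bdown": torch.empty(embed_dim),
    }
    hidden_i = d_i * mlp_multiplier
    used = [  # the same slices MLP.forward takes
        full["Wgate"][:d_i, :hidden_i], full["Bgate"][:hidden_i],
        full["Wup"][:d_i, :hidden_i], full["Bup"][:hidden_i],
        full["Wdown"][:hidden_i, :d_i], full["Bdown"][:d_i],
    ]
    return sum(t.numel() for t in used)

for d_i in [32, 64, 128]:
    count = sliced_mlp_params(128, 4, d_i)
    print(d_i, count, count == 12 * d_i**2 + 9 * d_i)
```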


# Layer

In [83]:
class Layer(nn.Module):
    def __init__(self, config: Config):
        super().__init__()

        self.mqa = selfMQA(config)
        self.mlp = MLP(config.embed_dim, config.mlp_multiplier)

        # Initialize normalization layers using the class reference from config
        self.pre_mqa_norm = config.norm(config.embed_dim, affine=config.norm_affine)
        self.post_mqa_norm = config.norm(config.embed_dim, affine=config.norm_affine)
        self.pre_mlp_norm = config.norm(config.embed_dim, affine=config.norm_affine)
        self.post_mlp_norm = config.norm(config.embed_dim, affine=config.norm_affine)
        
        self.logging_enabled = False
    def enable_logging(self):
        self.logging_enabled = True
    def disable_logging(self):
        self.logging_enabled = False

    @log_io
    def mqa_connection(self, x):
        return self.post_mqa_norm(self.mqa(self.pre_mqa_norm(x)))

    @log_io
    def mlp_connection(self, x):
        return self.post_mlp_norm(self.mlp(self.pre_mlp_norm(x)))

    @log_io
    def forward(self, x: torch.Tensor ) -> torch.Tensor:
        x = x + self.mqa_connection(x)
        x = x + self.mlp_connection(x)
        return x

In [84]:
layer = Layer(config)
layer.enable_logging()
# an example of how you can selectively enable sub-classes to also print
layer.post_mqa_norm.enable_logging()
output = layer(torch.randn(32,config.max_seq_len,config.embed_dim))
del layer, output


---- Entering Layer.forward ----
Input tensor 'x' shape: torch.Size([32, 256, 128])

---- Entering Layer.mqa_connection ----
Input tensor 'x' shape: torch.Size([32, 256, 128])

---- Entering CosineNorm.forward ----
Input tensor 'x' shape: torch.Size([32, 256, 128])

---- Entering CosineNorm.norm ----
Input tensor 'x' shape: torch.Size([32, 256, 128])
Output tensor shape: torch.Size([32, 256, 128])
---- Exiting CosineNorm.norm ----
Output tensor shape: torch.Size([32, 256, 128])
---- Exiting CosineNorm.forward ----
Output tensor shape: torch.Size([32, 256, 128])
---- Exiting Layer.mqa_connection ----

---- Entering Layer.mlp_connection ----
Input tensor 'x' shape: torch.Size([32, 256, 128])
Output tensor shape: torch.Size([32, 256, 128])
---- Exiting Layer.mlp_connection ----
Output tensor shape: torch.Size([32, 256, 128])
---- Exiting Layer.forward ----


# crossMQA

In [85]:
class crossMQA(nn.Module):
    def __init__(self, config: Config):
        super().__init__()

        self.ca_q_heads = config.ca_q_heads
        self.ca_kv_heads = config.ca_kv_heads
        assert self.ca_q_heads % self.ca_kv_heads == 0
        self.num_queries_per_kv = self.ca_q_heads // self.ca_kv_heads

        self.embed_dim = config.embed_dim
        self.ca_head_dim = config.ca_head_dim
        self.sa_head_dim = config.sa_head_dim # used for an assertion to make sure sizes will fit
        self.theta = config.theta
        self.use_RoPE = config.ca_use_RoPE
        self.noise_sd = config.ca_noise_sd

        self.Wqkv = nn.Parameter(torch.Tensor(self.embed_dim, (self.ca_q_heads + 2 * self.ca_kv_heads) * self.ca_head_dim))
        nn.init.uniform_(self.Wqkv, -((1 / self.embed_dim) ** 0.5), (1 / self.embed_dim) ** 0.5)
        
        self.Wo = nn.Parameter(torch.Tensor(self.ca_q_heads * self.ca_head_dim, self.embed_dim))
        nn.init.uniform_(self.Wo, -((1 / (self.ca_q_heads * self.ca_head_dim)) ** 0.5), (1 / (self.ca_q_heads * self.ca_head_dim)) ** 0.5)

        # for now we'll try no attention mask and then maybe (probably) edit that later 
        
        self.logging_enabled = False
    def enable_logging(self):
        self.logging_enabled = True
    def disable_logging(self):
        self.logging_enabled = False

    @log_io
    def weight_splicing(self, Wqkv, Wo, xd_i, cd_i, h_i):
        Wq, Wk, Wv = Wqkv.split([self.ca_q_heads * self.ca_head_dim,
                                 self.ca_kv_heads * self.ca_head_dim,
                                 self.ca_kv_heads * self.ca_head_dim],dim = -1)
        Wq = torch.cat([Wq[:xd_i, j*self.ca_head_dim:j*self.ca_head_dim + h_i] for j in range(self.ca_q_heads)], dim = 1)
        Wk = torch.cat([Wk[:cd_i, j*self.ca_head_dim:j*self.ca_head_dim + h_i] for j in range(self.ca_kv_heads)], dim = 1)
        Wv = torch.cat([Wv[:cd_i, j*self.ca_head_dim:j*self.ca_head_dim + h_i] for j in range(self.ca_kv_heads)], dim = 1)
        Wo = torch.cat([Wo[j*self.ca_head_dim:j*self.ca_head_dim + h_i, :xd_i] for j in range(self.ca_q_heads)], dim=0)
        return Wq, Wk, Wv, Wo

    @log_io
    def match_headcount(self, xn):
        return torch.repeat_interleave(xn, self.num_queries_per_kv, dim=2)

    @log_io
    def RoPE(self, x, h_i):
        return RoPE(x, h_i, self.theta)

    @log_io
    def attend(self, xq, ck, h_i):
        return torch.matmul(xq, ck.transpose(2, 3)) * (h_i ** -0.5)
        
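    # not called anywhere yet: self.mask is never defined in __init__, so this would raise AttributeError until a mask is added
    @log_io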
    def apply_mask(self, logits, input_len):
        return torch.where(self.mask[..., :input_len, :input_len].expand_as(logits),
                           logits,
                           torch.tensor(-1e30, device=logits.device, dtype=logits.dtype))
    
    @log_io
    def calc_output(self, logits, cv, batch_size, input_len_x):
        # Applies softmax to the logits to obtain attention probabilities
        scores = F.softmax(logits, dim=-1)
        
        # Computes the weighted sum of values based on the attention scores to obtain the output of the attention mechanism.
        output = scores @ cv # [batch_size, n_local_heads, input_len_x, h_i]

        # Reshapes the attention output to match the expected output dimensions, combining the heads back into the hidden dimension.
        return output.transpose(1, 2).contiguous().view(batch_size, input_len_x, -1)

    @log_io
    def forward(self, 
                x: torch.Tensor, # the lower level tensor, sometimes a resid state full of concepts & sometimes full of tokens
                c: torch.Tensor, # the upper level tensor, always a resid state full of concept vecs
               ) -> torch.Tensor:
        # optionally adds some random gaussian noise to the upper-level concepts to simulate how bad they'll prolly be
        # i think this should only be applied during training, but imma have to go & spread that bool throughout later
        # this version does the entire t x t matrix but later i'd like to switch it to just upper-triangular
        if self.noise_sd is not None:
            noise = torch.randn_like(c, requires_grad = False)
            c = c + noise * self.noise_sd
        
        # Extracts batch size and input sequence length from the hidden states tensor.
        batch_size, input_len_x, xd_i = x.shape
        batch_size_c, input_len_c, cd_i = c.shape
        assert batch_size == batch_size_c
        #assert input_len_x % input_len_c == 0 # i feel like there's a better assert I could use for this one relating to combo or something

        h_i = self.ca_head_dim // (self.embed_dim // cd_i)
        assert h_i == self.sa_head_dim // (self.embed_dim // xd_i), 'head_dim was not same bw levels in cross attention'

        # splicing our projection to get the correct sub-matrices
        Wq, Wk, Wv, Wo = self.weight_splicing(self.Wqkv, self.Wo, xd_i, cd_i, h_i)

        # Applies the linear projection to the hidden state to retrieve our q, k & v projections
        xq = x @ Wq
        ck = c @ Wk
        cv = c @ Wv

        # Reshapes each of the Q, K, and V tensors to separate the heads and align the dimensions for attention operations.
        xq = xq.view(batch_size, -1, self.ca_q_heads, h_i)
        ck = ck.view(batch_size, -1, self.ca_kv_heads, h_i)
        cv = cv.view(batch_size, -1, self.ca_kv_heads, h_i)

        # IF we want to use RoPE (doesn't fully make sense to)
        if self.use_RoPE:
            expand = input_len_x // input_len_c
            ck = ck.repeat_interleave(expand, dim=1) 
            cv = cv.repeat_interleave(expand, dim=1) # values need to be expanded for their use later on if we do this

            # Applies rotary positional embeddings to queries and keys to incorporate positional information.
            xq = self.RoPE(xq, h_i) 
            ck = self.RoPE(ck, h_i)

        # If the number of KV heads is different from the number of query heads, adjusts keys and values to match the query heads count.
        if self.ca_kv_heads != self.ca_q_heads:
            ck = self.match_headcount(ck)
            cv = self.match_headcount(cv) # [batch_size, input_len, n_local_heads, ca_head_dim]

        # Transposes Q, K, and V tensors to align them for the batch matrix multiplication in attention calculation.
        xq = xq.transpose(1, 2) # [batch_size, n_local_heads, input_len, ca_head_dim]
        ck = ck.transpose(1, 2)
        cv = cv.transpose(1, 2)

        # Calculates attention scores by performing a batch matrix multiplication between queries and keys, followed by scaling.
        logits = self.attend(xq, ck, h_i) # [batch_size, n_local_heads, input_len_x, input_len_c] (input_len_x for both if RoPE expanded the keys)

        # applies values to get final output
        output = self.calc_output(logits, cv, batch_size, input_len_x)

        # Applies the final linear projection to the attention output, mapping it back to d_i
        return output @ Wo
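
As written, `crossMQA` lets every lower-level position attend to every upper-level concept, including concepts built from tokens that come after it. The `__init__` comment anticipates adding a mask. One option is a block-causal mask, sketched below. It assumes a lower position should see the concept for its own block (the "plan" from the level above) but no later ones, and it targets the non-RoPE path, where the keys keep length `input_len_c`. The helper names `concept_mask` and `apply_concept_mask` are hypothetical. A stricter variant that also hides the current block would leave the first rows with nothing to attend to, and would need separate handling before the softmax.

```python
import torch

def concept_mask(len_x: int, len_c: int, combo: int) -> torch.Tensor:
    """(len_x, len_c) bool mask, True where lower-level position t may attend to concept k.

    Assumes the lower sequence is front-padded to a multiple of `combo` (as CombineEmbeddings
    does), so position t sits at padded index t + pad and concept k covers padded indices
    [k * combo, (k + 1) * combo). A concept becomes visible once its block has started.
    """
    pad = (combo - len_x % combo) % combo
    t = torch.arange(len_x).unsqueeze(1) + pad  # (len_x, 1)
    k = torch.arange(len_c).unsqueeze(0)        # (1, len_c)
    return k * combo <= t

def apply_concept_mask(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # logits: (batch, heads, len_x, len_c); concept 0 is always visible, so no row is fully masked
    return logits.masked_fill(~mask, float("-inf"))

mask = concept_mask(len_x=5, len_c=3, combo=2)
print(mask.int())  # token 0 sees concept 0 only, tokens 1-2 see concepts 0-1, tokens 3-4 see all three
scores = torch.softmax(apply_concept_mask(torch.zeros(1, 1, 5, 3), mask), dim=-1)
```

Using `-inf` instead of the `-1e30` in `apply_mask` is safe here only because every row keeps at least one visible concept.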

In [86]:
cross_attention = crossMQA(config)
cross_attention.enable_logging()
x = torch.randn(32, config.max_seq_len, config.embed_dim // config.combo)
c = torch.randn(32, config.max_seq_len // config.combo, config.embed_dim)
output = cross_attention(x, c)
del cross_attention, output


---- Entering crossMQA.forward ----
Input tensor 'x' shape: torch.Size([32, 256, 64])
Input tensor 'c' shape: torch.Size([32, 128, 128])

---- Entering crossMQA.weight_splicing ----
Input tensor 'Wqkv' shape: torch.Size([128, 96])
Input tensor 'Wo' shape: torch.Size([64, 128])
Input integer 'xd_i': Value=64
Input integer 'cd_i': Value=128
Input integer 'h_i': Value=16
Output tensor 0 shape: torch.Size([64, 64])
Output tensor 1 shape: torch.Size([128, 16])
Output tensor 2 shape: torch.Size([128, 16])
Output tensor 3 shape: torch.Size([64, 64])
---- Exiting crossMQA.weight_splicing ----

---- Entering crossMQA.RoPE ----
Input tensor 'x' shape: torch.Size([32, 256, 4, 16])
Input integer 'h_i': Value=16
Output tensor shape: torch.Size([32, 256, 4, 16])
---- Exiting crossMQA.RoPE ----

---- Entering crossMQA.RoPE ----
Input tensor 'x' shape: torch.Size([32, 256, 1, 16])
Input integer 'h_i': Value=16
Output tensor shape: torch.Size([32, 256, 1, 16])
---- Exiting crossMQA.RoPE ----

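The per-head splicing in `weight_splicing` can also be written without the Python loop. You expose the head axis with `reshape`, slice, and flatten back. This is a sketch with hypothetical helper names (`splice_cat`, `splice_view`, `splice_view_out`). The test uses `h_i = 8`, smaller than the head dim of 16. In the logged run above, `h_i` equals the full head dim (16), so the per-head slice there is a no-op.

```python
import torch

def splice_cat(W, n_heads, head_dim, d_in, h_i):
    # loop version: same pattern weight_splicing uses for Wq, Wk and Wv
    return torch.cat([W[:d_in, j * head_dim:j * head_dim + h_i] for j in range(n_heads)], dim=1)

def splice_view(W, n_heads, head_dim, d_in, h_i):
    # vectorized version: split the columns into (heads, head_dim), slice both axes, flatten back
    return W.reshape(W.shape[0], n_heads, head_dim)[:d_in, :, :h_i].reshape(d_in, n_heads * h_i)

def splice_view_out(Wo, n_heads, head_dim, d_out, h_i):
    # Wo keeps its head axis on the rows instead of the columns
    return Wo.reshape(n_heads, head_dim, Wo.shape[1])[:, :h_i, :d_out].reshape(n_heads * h_i, d_out)

# toy sizes: 4 query heads, 1 kv head, head_dim 16, spliced down to half the head width
torch.manual_seed(0)
Wqkv = torch.randn(128, (4 + 2 * 1) * 16)
Wq, Wk, Wv = Wqkv.split([4 * 16, 16, 16], dim=-1)  # non-contiguous column slices, like in crossMQA
Wo = torch.randn(4 * 16, 128)

print(torch.equal(splice_cat(Wq, 4, 16, 64, 8), splice_view(Wq, 4, 16, 64, 8)))
print(splice_view(Wk, 1, 16, 128, 8).shape)  # torch.Size([128, 8])
```

`reshape` returns a view whenever the strides allow it, including on the column slices produced by `split`, so only the final flatten copies data.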

# crossLayer

In [87]:
class crossLayer(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.mqa = crossMQA(config)

        # Initialize normalization layers using the class reference from config
        self.pre_mqa_norm_x = config.norm(config.embed_dim, affine=config.norm_affine)
        self.pre_mqa_norm_c = config.norm(config.embed_dim, affine=config.norm_affine)
        self.post_mqa_norm = config.norm(config.embed_dim, affine=config.norm_affine)
        
        # jic i want to add an MLP later
        #self.mlp = MLP(config.embed_dim, config.mlp_multiplier)
        #self.pre_mlp_norm = config.norm(config.embed_dim, affine=config.norm_affine)
        #self.post_mlp_norm = config.norm(config.embed_dim, affine=config.norm_affine)
        
        self.logging_enabled = False
    def enable_logging(self):
        self.logging_enabled = True
    def disable_logging(self):
        self.logging_enabled = False

    @log_io
    def mqa_connection(self, x, c):
        return self.post_mqa_norm(self.mqa(x = self.pre_mqa_norm_x(x), c = self.pre_mqa_norm_c(c)))

    #@log_io
    #def mlp_connection(self, x):
        #return self.post_mlp_norm(self.mlp(self.pre_mlp_norm(x)))
        
    @log_io
    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        x = x + self.mqa_connection(x, c)
        #x = x + self.mlp_connection(x)
        return x

In [88]:
layer = crossLayer(config)
layer.enable_logging()
x = torch.randn(32, config.max_seq_len, config.embed_dim // config.combo)
c = torch.randn(32, config.max_seq_len // config.combo, config.embed_dim)
output = layer(x, c)
del layer, output


---- Entering crossLayer.forward ----
Input tensor 'x' shape: torch.Size([32, 256, 64])
Input tensor 'c' shape: torch.Size([32, 128, 128])

---- Entering crossLayer.mqa_connection ----
Input tensor 'x' shape: torch.Size([32, 256, 64])
Input tensor 'c' shape: torch.Size([32, 128, 128])
Output tensor shape: torch.Size([32, 256, 64])
---- Exiting crossLayer.mqa_connection ----
Output tensor shape: torch.Size([32, 256, 64])
---- Exiting crossLayer.forward ----


# Body

In [58]:
class Body(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.ca_interval = config.ca_interval
        
        # Initialize a sequence of Layer instances as specified by the number of hidden layers in the config
        self.layers = nn.ModuleList(Layer(config) for _ in range(config.num_layers))
        # initialize the (single / sequence of) cross-attention layer(s). nn.ModuleList needs an iterable of modules, hence the list around the shared layer
        self.ca_layers = nn.ModuleList([crossLayer(config)]) if config.shared_ca else nn.ModuleList(crossLayer(config) for _ in range(config.ca_connections))
        
        # Initialize normalization layers to be applied after the last decoder layers, stabilizing the output
        self.final_norm = config.norm(config.embed_dim, affine=config.norm_affine)
        
        self.logging_enabled = False
    def enable_logging(self):
        self.logging_enabled = True
    def disable_logging(self):
        self.logging_enabled = False
        
    @log_io
    def forward(self, x0s: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]:
        # initiate tuple to hold final residual states
        xfs = ()
        # iterate through model levels, starting from highest level concepts & ending at lowest level tokens
        for i, x in enumerate(reversed(x0s)): # reversed() makes us start at highest level
            
            # Iteratively process the input through each Layer of the model
            for j in range(len(self.layers)):
                
                # our occasional cross-attention connection to the upper level
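                if i != 0 and j % self.ca_interval == 0: # i can't equal zero bc there'd be no higher level model to pay attention to
                    # a shared cross-attention layer sits alone at index 0, otherwise each connection has its own
                    ca_layer = self.ca_layers[0] if len(self.ca_layers) == 1 else self.ca_layers[j // self.ca_interval]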
                    x = ca_layer(x, x0s[len(x0s)-i])

                # our current-level model's work
                layer = self.layers[j]
                x = layer(x)

            # add the final residual state of the level to our tuple, normed
            xfs += (self.final_norm(x),) # should i be using separate final norms? nah

        return xfs
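
The index arithmetic in `Body.forward` decides which level each cross-attention call reads and which `crossLayer` it uses. It can be replayed on its own, as in the sketch below (`ca_schedule` is a hypothetical helper). The example values `levels=3`, `num_layers=4` and `ca_interval=2` are assumptions, consistent with the 4 `Layer`s and 2 `crossLayer`s in the model printout further down. Note that the upper-level tensor is the initial `x0s` entry of the level directly above, not that level's processed output in `xfs`.

```python
def ca_schedule(levels: int, num_layers: int, ca_interval: int, shared_ca: bool = False):
    """Replays Body.forward's loops, returning one (level, layer_j, ca_layer_idx, upper_level)
    tuple per cross-attention call. Levels index into x0s: 0 = tokens, levels - 1 = top concepts."""
    calls = []
    for i in range(levels):                      # i = 0 is the highest level, as with reversed(x0s)
        level = levels - 1 - i
        for j in range(num_layers):
            if i != 0 and j % ca_interval == 0:  # the highest level has nothing above it
                ca_idx = 0 if shared_ca else j // ca_interval
                calls.append((level, j, ca_idx, levels - i))  # levels - i == len(x0s) - i
    return calls

for call in ca_schedule(levels=3, num_layers=4, ca_interval=2):
    print(call)
```

Every call has `upper_level == level + 1`, so each level only ever looks one level up.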

# embedding vector combination function

In [14]:
class CombineEmbeddings(nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()
        self.padding_vector = nn.Parameter(torch.zeros(embedding_dim), requires_grad=True)
        
        self.logging_enabled = False
    def enable_logging(self):
        self.logging_enabled = True
    def disable_logging(self):
        self.logging_enabled = False

    @log_io
    def forward(self, tensor, combine_factor):
        b, t, d = tensor.shape

        # Calculate the necessary amount of padding
        remainder = t % combine_factor
        padding_needed = 0 if remainder == 0 else combine_factor - remainder
        
        if padding_needed > 0:
            # Replicate the padding vector the necessary number of times
            padding = self.padding_vector.repeat(padding_needed, 1).unsqueeze(0).expand(b, -1, -1)
            tensor = torch.cat([padding[...,:d], tensor], dim=1) # subset padding to fit with matryoshka size
        
        # Update t after padding
        t_padded = t + padding_needed
        
        # Reshape the tensor to group 'combine_factor' entries along the t dimension
        reshaped_tensor = tensor.view(b, t_padded // combine_factor, combine_factor, d)
        
        # Sum over the groups
        combined_tensor = reshaped_tensor.sum(dim=2)

        return combined_tensor
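
A small worked example of the front padding. The standalone `combine` function below is a hypothetical re-implementation of `CombineEmbeddings.forward`. It uses a zero padding vector, which is where the learned one starts. Because padding goes at the front, the last concept is always a full group of real tokens, and padding can only show up in the first concept.

```python
import torch

def combine(tensor: torch.Tensor, combine_factor: int, padding_vector: torch.Tensor) -> torch.Tensor:
    """Standalone version of CombineEmbeddings.forward: front-pad, group, sum."""
    b, t, d = tensor.shape
    pad = (combine_factor - t % combine_factor) % combine_factor
    if pad:
        # the padding vector is sliced to the current matryoshka width d
        padding = padding_vector[:d].expand(b, pad, d)
        tensor = torch.cat([padding, tensor], dim=1)
    return tensor.reshape(b, (t + pad) // combine_factor, combine_factor, d).sum(dim=2)

# five 2-dim "embeddings" whose entries equal their position: [0,0], [1,1], ..., [4,4]
x = torch.arange(5.0).repeat_interleave(2).view(1, 5, 2)
pad_vec = torch.zeros(4)  # full-width padding vector

print(combine(x, 2, pad_vec))  # groups [pad, 0], [1, 2], [3, 4] -> sums 0, 3, 7
print(combine(x, 4, pad_vec))  # groups [pad, pad, pad, 0], [1, 2, 3, 4] -> sums 0, 10
```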

# Matryoshka Concept Loss

In [15]:
class matryoshkaConceptLoss(nn.Module):
    def __init__(self, config: Config, embedding_combiner):
        super().__init__()
        self.combo = config.combo
        self.levels = config.levels
        self.split = config.split
        self.level_loss_weight = config.level_loss_weight
        self.embedding_combiner = embedding_combiner

        if config.concept_loss == "mae":
            self.concept_loss_fn = nn.L1Loss()
        elif config.concept_loss == "mse":
            self.concept_loss_fn = nn.MSELoss()
        else: # defaults to cosine similarity loss
            self.concept_loss_fn = nn.CosineSimilarity(dim=-1, eps=1e-6)
        
        self.logging_enabled = False
    def enable_logging(self):
        self.logging_enabled = True
    def disable_logging(self):
        self.logging_enabled = False
            
    @log_io
    def forward(self, 
                x: Tuple[torch.Tensor], 
                target_token_ids: torch.Tensor, 
                input_len: int,
                embedder: nn.Module,
               ) -> torch.Tensor:
        # iterate through all concept-embedding layers and calculate loss
        concept_loss = torch.zeros((), device=target_token_ids.device) # start on the same device as the model's tensors
        for i in range(self.levels - 1):
            # select our relevant final residual state
            lvl_output = x[i]
            
            # calculate the decay value placed on this level's total amount of loss
            lambadada = (self.level_loss_weight ** (self.levels -1 -i))
            
            # create the target vectors for this level
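            lvl_combo = self.combo ** (self.levels-1-i)
            # the concept covering block k of the front-padded input should regress onto block k+1.
            # x covers data[j:j+input_len] & target_token_ids covers data[j+1:...], so block k+1 starts at
            # target index (k+1)*lvl_combo - pad - 1. starting `pad` tokens early & taking `pad` extra tokens
            # gives a window that's already a multiple of lvl_combo, so the combiner adds no padding to targets
            pad = (lvl_combo - input_len % lvl_combo) % lvl_combo
            start = lvl_combo - 1 - pad
            target_token_ids_adj = target_token_ids[:, start:start + input_len + pad]
            raw_target_vectors = embedder(target_token_ids_adj)[...,:lvl_output.shape[-1]]
            target_vectors = self.embedding_combiner(raw_target_vectors, lvl_combo).detach().clone()
            # remember to detach/clone so that we don't mess with the embeddings we want to train on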
            
            # now choose whether we're doing mse/mae vs cos sim loss bc they get calc'd differently
            if not isinstance(self.concept_loss_fn, nn.CosineSimilarity): # mse or mae, as chosen in __init__
                # Reshape output and target_vectors to combine batch and seq_len dimensions
                lvl_output_flat = lvl_output.view(-1, lvl_output.size(-1))
                target_vectors_flat = target_vectors.view(-1, target_vectors.size(-1))
                
                # Calculate MSE or MAE & add it to the total
                concept_loss = concept_loss + self.concept_loss_fn(lvl_output_flat, target_vectors_flat) * lambadada
            else: # defaults to cosine similarity loss
                cosine_loss = (1 - self.concept_loss_fn(lvl_output, target_vectors)).mean()
                concept_loss = concept_loss + cosine_loss * lambadada

        return concept_loss
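
The index bookkeeping for concept targets is easy to get off by one, so here is a toy check. Token ids equal positions, so every block can be read directly. The sketch assumes the target of the concept covering block `k` is block `k+1` of the same front-padded input. It builds the target window from `y` starting at `combo - 1 - pad` with length `seq_len + pad`. The helper names `blocks` and `check_targets` are hypothetical.

```python
import torch

def blocks(ids, combo, pad_id=-1):
    """Front-pad a list of ids to a multiple of combo (pad_id marks padding) and split into blocks."""
    pad = (combo - len(ids) % combo) % combo
    padded = [pad_id] * pad + list(ids)
    return [padded[k:k + combo] for k in range(0, len(padded), combo)]

def check_targets(seq_len: int, combo: int, max_combo: int) -> bool:
    data = torch.arange(1000)                # token id == position
    x = data[:seq_len]                       # one input row, laid out like get_batch's x
    y = data[1:1 + seq_len + max_combo]      # matching target row, laid out like get_batch's y
    pad = (combo - seq_len % combo) % combo
    start = combo - 1 - pad
    window = y[start:start + seq_len + pad]  # already a multiple of combo, so blocks() adds no padding
    next_blocks = [[b[-1] + 1 + m for m in range(combo)] for b in blocks(x.tolist(), combo)]
    return blocks(window.tolist(), combo) == next_blocks

print(blocks(list(range(5)), 2))  # [[-1, 0], [1, 2], [3, 4]]
print(all(check_targets(n, c, 4) for n in range(8, 64) for c in (2, 4)))
```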

# Model

In [16]:
class NCP_MatFormer(nn.Module):
    def __init__(self, config: Config, tokenizer: tokenizer):
        super().__init__()
        self.config = config
        self.tokenizer = tokenizer

        self.max_seq_len = config.max_seq_len
        self.sa_head_dim = config.sa_head_dim
        self.vocab_size = config.vocab_size
        self.embed_dim = config.embed_dim

        self.combo = config.combo
        self.levels = config.levels
        self.split = config.split

        self.ca_connections = config.ca_connections
        self.ca_interval = config.ca_interval

        # the embedding matrix. for converting tokens to the first residual state, and the last residual state to logits
        self.embedder = nn.Embedding(config.vocab_size, config.embed_dim)
        # the function that combines embeddings into higher level concept residual states
        self.embedding_combiner = CombineEmbeddings(config.embed_dim)
        self.norm_init_embed = config.norm_init_embed # whether to norm the first residual states
        if config.norm_init_embed: # if so, define its norm
            self.init_embed_norm = config.norm(config.embed_dim, affine=config.norm_affine)
        self.scale_init_embed = config.scale_init_embed # whether to scale first residuals by nth root of respective matryoshka dimension
        self.init_scale_degree = config.init_scale_degree # the nth root to take

        # the actual bulk of the model
        self.body = Body(config)

        # should we norm the final embed before calculating logits? idk i can't decide
        self.norm_final_embed = config.norm_final_embed
        if config.norm_final_embed:
            self.final_embed_norm = config.norm(config.embed_dim, affine=config.norm_affine)
        
        ### the loss functions
        # lowest-level token model
        self.ce_loss_fn = nn.CrossEntropyLoss()
        # concept models
        self.concept_loss_fn = matryoshkaConceptLoss(config, self.embedding_combiner)
        
        self.logging_enabled = False
    def enable_logging(self):
        self.logging_enabled = True
    def disable_logging(self):
        self.logging_enabled = False
        
    @log_io
    def forward(
        self,
        input_token_ids: torch.Tensor, # a shape (batch_size, input_seq_len) list of integer token ids to run forward pass on
        target_token_ids: torch.Tensor = None, # a shape (batch_size, input_seq_len + combo ** (levels-1)) list of token ids to train on
        ) -> torch.Tensor:
        training = target_token_ids is not None # not used below yet

        # turn the input tokens into the first residual state using the embedding matrix
        x0 = self.embedder(input_token_ids) # (batch_size, input_len, embed_dim)
        
        ### prepping our first token-wise residual state
        # find what matryoshka dimension it should be
        first_dim = self.embed_dim // (self.split ** (self.levels-1)) 
        # splice to said matryoshka dimension
        x0t = x0[...,:first_dim] # t stands for token. need to separate it from the raw x0 which will be used later
        # optionally norm it
        x0t = self.init_embed_norm(x0t) if self.norm_init_embed else x0t
        # optionally scale by n'th root of dimension
        x0t = x0t * (first_dim ** (1/self.init_scale_degree)) if self.scale_init_embed else x0t 
        # finally instantiate the tuple that'll hold all the residual states
        x0s = (x0t,) 
        
        ### iterating through levels to create each higher-level concept residual state
        for i in range(self.levels-1):
            # combine into smaller tensor by adding token (or lower level concept) embeddings together
            lvl_combo = self.combo ** (i+1)
            x0c = self.embedding_combiner(x0, lvl_combo) # c stands for concept
            
            # find the correct matryoshka dimension for this level
            this_level_dim = self.embed_dim // (self.split ** (self.levels - 2 - i))
            # splice to said matryoshka dimension
            x0c = x0c[...,:this_level_dim]
            # optionally norm it
            x0c = self.init_embed_norm(x0c) if self.norm_init_embed else x0c
            # optionally scale by n'th root of dimension
            x0c = x0c * (this_level_dim ** (1/self.init_scale_degree)) if self.scale_init_embed else x0c
            # finally add it to the tuple of residual states
            x0s += (x0c,)

        # the body of the model that iterates through the decoder & cross-attention layers
        xfs = self.body(x0s)

        # grabbing the weights of the embedding matrix shape (vocab_size, embed_dim) for use as the output layer
        embedder_weight = self.embedder.weight
        # optionally norm it
        embedder_weight = self.final_embed_norm(embedder_weight) if self.norm_final_embed else embedder_weight
        # calculating output logits
        logits = xfs[-1] @ embedder_weight[:,:first_dim].t()
        
        if target_token_ids is None: # if we're not training, then we don't need to calculate loss
            loss = None
        else: # if we are training
            ### first up is regular CE token loss
            batch_size, input_len, vocab_size = logits.shape
            # splice target tokens to exclude the ones that were only to be used by concept levels
            target_token_ids_spliced = target_token_ids[:,:input_len]
            # we reshape our logits & targets before calculating cross-entropy loss
            ce_loss = self.ce_loss_fn(logits.view(batch_size*input_len, vocab_size),
                                      target_token_ids_spliced.reshape(batch_size*input_len))

            ### the new thing, a regression loss for all our concept-embedding layers
            concept_loss = self.concept_loss_fn(xfs, target_token_ids, 
                                                input_len, self.embedder)
            loss = ce_loss + concept_loss
            
        return logits, loss

    @torch.no_grad() # no need to keep track of gradients during inference
    @log_io
    def Sampler(
        self,
        logits: torch.Tensor, # shape (batch_size, input_len, vocab_size)
        temperature: float, # controls how boring vs random the outputs should be
        top_p: float, # the maximum cumulative probability of output options we're willing to consider
        top_k: int, # the maximum number of output options we're willing to consider
    ) -> torch.Tensor:
        # Select the last element for each sequence & apply temperature scaling
        logits = logits[:,-1,:].div_(temperature) # -> (batch_size, vocab_size)

        # Calculate probabilities with softmax.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float) # dim=-1 is the vocab_size dimension that we calculate along

        # sort the probabilities for use in top-p & top-k. both are (batch_size, vocab_size)
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)

        ### calculating top-p
        # creates same-size tensor of cumulative probabilities instead of individual probs
        probs_sum = torch.cumsum(probs_sort, dim=-1) 
        # mask where 0's are top-p selections & 1's are to be excluded
        top_ps_mask = (probs_sum - probs_sort) > top_p
        # the original probabilities with excluded tokens changed to 0.0
        probs_sort = torch.where(top_ps_mask, 0, probs_sort) 

        ### calculating top_k
        # create a shape (vocab_size) tensor that just iterates up by 1's
        top_ks_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device) 
        # expand our mask along the batch_size dimension to become size (batch_size, vocab_size)
        top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
        # True marks tokens ranked top_k or worse (ranks start at 0), which get excluded below
        top_ks_mask = top_ks_mask >= top_k

        # we'll be combining top-p with top-k and using whichever gives us fewer tokens. a very conservative approach
        # this trims probs_sort to also fit within our top_k requirement
        probs_sort = torch.where(top_ks_mask, 0, probs_sort)

        # Re-normalization so that total probabilities add up to 1
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        
        # now we rearrange the modified probabilities in probs_sort back to their original order according to probs_idx
        probs = torch.gather(probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
        
        # samples from the distribution
        next_token_id = torch.multinomial(probs, num_samples=1)
        
        return next_token_id # returns the predicted token
        
    @torch.no_grad() # no need to keep track of gradients while generating either
    @log_io
    def generate(self,
                 prompt: str,
                 output_len: int = 1, # the model will output 1 token by default
                 temperature: float = 0.7, # 1.0 would be no effect
                 top_p: float = 0.8,
                 top_k: int = 4,
                ) -> str: 
        """ Wrapper around sampler() that deals with manipulation of the sequence """
        # encoding the prompt into token indices
        tokens = self.tokenizer.encode(prompt)

        # turning it into the right tensor shape
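        tokens = torch.tensor(tokens, device=self.config.device).unsqueeze(0) # shape (1, prompt_len)
        
        # we wouldn't want to go past the maximum context length we trained on
        # tokens has shape (1, prompt_len), so the prompt length is tokens.shape[1] rather than len(tokens)
        if tokens.shape[1] + output_len > self.config.max_seq_len:
            output_len = self.max_seq_len - tokens.shape[1]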
            print("capping output at maximum sequence length")

        for i in range(output_len):
            # get the model's output logits and ignore the loss, which would be a NoneType object
            logits, _ = self(tokens[:, -self.max_seq_len:]) # only ever feed the most recent max_seq_len tokens
            
            next_token = self.Sampler(logits, temperature, top_p, top_k)

            # add our new token to the sequence
            tokens = torch.cat((tokens, next_token), dim=1)

        # decode our list of tokens to an actual string
        return self.tokenizer.decode(tokens.squeeze(0).tolist())
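
The top-p/top-k interaction in `Sampler` is easiest to see on a four-token vocabulary. The sketch below repeats its masking steps on a probability tensor. `filter_probs` is a hypothetical name, and it uses `masked_fill` where `Sampler` uses `torch.where`, with the same logic. Top-p compares the cumulative probability *before* each token against `top_p`, so the token that pushes the total past `top_p` is still kept.

```python
import torch

def filter_probs(probs: torch.Tensor, top_p: float, top_k: int) -> torch.Tensor:
    """Sampler's top-p + top-k filtering applied to a (batch, vocab) probability tensor."""
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # top-p: drop a token once the probability mass before it already exceeds top_p
    probs_sort = probs_sort.masked_fill((probs_sum - probs_sort) > top_p, 0.0)
    # top-k: drop every token ranked top_k or worse (ranks start at 0)
    ranks = torch.arange(probs.shape[-1]).expand(probs.shape[0], -1)
    probs_sort = probs_sort.masked_fill(ranks >= top_k, 0.0)
    probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
    # undo the sort so each probability lines up with its original token id
    return torch.gather(probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))

probs = torch.tensor([[0.05, 0.5, 0.15, 0.3]])
print(filter_probs(probs, top_p=0.7, top_k=3))  # top-p is stricter here: only ids 1 and 3 survive
print(filter_probs(probs, top_p=0.7, top_k=1))  # top-k is stricter here: only id 1 survives
```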

# Instantiate a brand new model

In [18]:
model = NCP_MatFormer(config, tokenizer).to(config.device)

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

print(model)

1018.496 K parameters
NCP_MatFormer(
  (embedder): Embedding(128, 128)
  (embedding_combiner): CombineEmbeddings()
  (body): Body(
    (layers): ModuleList(
      (0-3): 4 x Layer(
        (mqa): selfMQA()
        (mlp): MLP()
        (pre_mqa_norm): CosineNorm()
        (post_mqa_norm): CosineNorm()
        (pre_mlp_norm): CosineNorm()
        (post_mlp_norm): CosineNorm()
      )
    )
    (ca_layers): ModuleList(
      (0-1): 2 x crossLayer(
        (mqa): crossMQA()
        (pre_mqa_norm_x): CosineNorm()
        (pre_mqa_norm_c): CosineNorm()
        (post_mqa_norm): CosineNorm()
      )
    )
    (final_norm): CosineNorm()
  )
  (final_embed_norm): CosineNorm()
  (ce_loss_fn): CrossEntropyLoss()
  (concept_loss_fn): matryoshkaConceptLoss(
    (embedding_combiner): CombineEmbeddings()
    (concept_loss_fn): CosineSimilarity()
  )
)
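
For reference, the `first_dim` / `this_level_dim` arithmetic in `NCP_MatFormer.forward` fixes each entry of `x0s`: how many tokens are merged per vector, its matryoshka width, and its sequence length. The sketch below replays that arithmetic (`level_layout` is a hypothetical helper). The example values are assumptions. `embed_dim=128` matches `Embedding(128, 128)` above. `combo=2` and `levels=3` are consistent with the `crossMQA` test shapes and the 116 vs 120 lengths from `get_batch`. `split=2` is a guess.

```python
def level_layout(embed_dim: int, split: int, combo: int, levels: int, seq_len: int):
    """(tokens per vector, matryoshka width, sequence length) for each x0s entry, tokens first."""
    layout = [(1, embed_dim // split ** (levels - 1), seq_len)]
    for i in range(levels - 1):
        lvl_combo = combo ** (i + 1)
        width = embed_dim // split ** (levels - 2 - i)
        length = -(-seq_len // lvl_combo)  # ceil division, since CombineEmbeddings front-pads to a multiple
        layout.append((lvl_combo, width, length))
    return layout

for row in level_layout(embed_dim=128, split=2, combo=2, levels=3, seq_len=116):
    print(row)
```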


# Load a Pretrained Model

In [19]:
# Initialize a blank model
model = NCP_MatFormer(config, tokenizer).to(config.device)  

# here's the path to a minGemma model that i've trained with roughly 1m parameters
path = 'models/?.pth'

# Load the saved state dictionary
model.load_state_dict(torch.load(path, map_location=config.device)) # map_location lets a GPU-trained checkpoint load on CPU & vice versa
# REMEMBER TO CHANGE VALUES IN CONFIG TO MATCH THE MODEL YOU'VE LOADED

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

# If you only plan to do inference, switch to evaluation mode
model.eval()

# If you plan to continue training the model, switch to training mode
#model.train()

FileNotFoundError: [Errno 2] No such file or directory: 'models/?.pth'

# Training

In [19]:
# Train and test splits
data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be our training dataset, the rest for validation
train_data = data[:n]
val_data = data[n:]

In [20]:
# data loading for training which generates a small batch of data of inputs x and targets y
def get_batch(split, batch_size):
    # whether we grab from our training or validation dataset
    data = train_data if split == 'train' else val_data
    # for regular data you'd have a variety of sequence lengths that would just be padded to the max_seq_len
    # but for TinyShakespeare we need to artificially grab data of a variety of sequence lengths to ensure RoPE generalizes
    random_seq_len = random.randint(config.combo ** config.levels, config.max_seq_len)
    # length explained below
    ix = torch.randint(len(data) - (random_seq_len + (config.combo ** (config.levels-1))), (batch_size,))
    x = torch.stack([data[i:i+random_seq_len] for i in ix])
    ### i actually need the y tensor to be length random_seq_len + (config.combo ** (config.levels-1)) to fit the future concepts
    y = torch.stack([data[i+1:i+1+(random_seq_len + (config.combo ** (config.levels-1)))] for i in ix])
    x, y = x.to(config.device), y.to(config.device)
    return x, y

x,y = get_batch(split = 'train', batch_size = 32)
x.shape, y.shape

(torch.Size([32, 116]), torch.Size([32, 120]))
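
The printed shapes show `y` running `combo ** (levels-1)` tokens past `x` (120 vs 116), while its first `random_seq_len` entries keep the usual one-token shift. Below is a quick check on data where each token id equals its position. `toy_get_batch` is a hypothetical copy of `get_batch` that takes the config values as explicit arguments instead of reading globals.

```python
import random
import torch

def toy_get_batch(data: torch.Tensor, batch_size: int, combo: int, levels: int, max_seq_len: int):
    """get_batch's slicing with the config values passed in explicitly."""
    extra = combo ** (levels - 1)  # tokens beyond x that the highest concept level's targets need
    seq_len = random.randint(combo ** levels, max_seq_len)
    ix = torch.randint(len(data) - (seq_len + extra), (batch_size,))
    x = torch.stack([data[i:i + seq_len] for i in ix])
    y = torch.stack([data[i + 1:i + 1 + seq_len + extra] for i in ix])
    return x, y

data = torch.arange(500)  # token id == position, so the shift is visible in the values
x, y = toy_get_batch(data, batch_size=8, combo=2, levels=3, max_seq_len=64)
print(x.shape, y.shape)                       # y is always 4 tokens longer than x for combo=2, levels=3
print(torch.equal(y[:, :x.shape[1]], x + 1))  # the usual one-token next-token shift
```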

In [21]:
@torch.no_grad()
def estimate_loss(model, batch_size, eval_iters = 5): # to estimate loss during the training loop
    out = {}
    model.eval() # sets model to eval mode
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, batch_size)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train() # just resets to training mode
    return out

In [22]:
# create a PyTorch optimizer
# this is not what they used, but this learning rate & weight decay work for our tiny minGemma
learning_rate = 3e-4
weight_decay = 0.05
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# how long we want to train for
max_iters = 2

# how often we want to check & see how our loss is doing
eval_interval = 1

# batch size to use
batch_size = 32

# ------ BOOKMARK ------- 

In [25]:
start_time = time.time()

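# Enable anomaly detection. Uncomment this line (and the matching one after the loop) if you need to do extensive debugging
#torch.autograd.set_detect_anomaly(True)
# turns on verbose shape logging in the final embedding norm; comment this out for normal training
model.final_embed_norm.enable_logging()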
for iter in range(max_iters):

    # sample a batch of data
    xb, yb = get_batch('train', batch_size)
    
    # train
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    
    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        current_time = time.time()
        elapsed_time = current_time - start_time
        losses = estimate_loss(model, batch_size)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time elapsed: {elapsed_time:.2f} seconds")

# Disable anomaly detection after the training loop
#torch.autograd.set_detect_anomaly(False)

Entering CosineNorm.forward
Input shape: torch.Size([128, 128])
Entering CosineNorm._normalize
Input shape: torch.Size([128, 128])
Output shape: torch.Size([128, 128])
Exiting CosineNorm._normalize
Output shape: torch.Size([128, 128])
Exiting CosineNorm.forward

# Saving your model

In [20]:
# save the model currently held in memory
# the filename specifies the model's class, hyperparameters, and date/time it was saved
torch.save(model.state_dict(),
           f'models/{model.__class__.__name__}-v{config.vocab_size}-seq{config.max_seq_len}'
           f'-num_layers{config.num_layers}'
           f'-sa_qs{config.sa_q_heads}-sa_kvs{config.sa_kv_heads}'
           f'-embed_d{config.embed_dim}'
           f'-mlp_mult{config.mlp_multiplier}'
           f'-sa_head_d{config.sa_head_dim}'
           f'-theta{config.theta}'
           f'--{time.strftime("%Y-%m-%d|%H-%M-%S")}.pth')
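The save cell assembles the filename inline. Below is a sketch that factors it into a hypothetical `checkpoint_path` helper, swaps the `|` in the timestamp for `_` (`|` is not allowed in Windows filenames), and also records `config.combo` and `config.levels`, which the original name leaves out. The `SimpleNamespace` config holds placeholder values, not the notebook's settings.

```python
import os
import time
from types import SimpleNamespace

def checkpoint_path(model_name, config, directory="models"):
    """Build a checkpoint filename from hyperparameters plus a timestamp (hypothetical helper).

    Uses '_' between date and time instead of '|' so the name is valid on Windows too,
    and adds combo/levels, which the original filename omits.
    """
    stamp = time.strftime("%Y-%m-%d_%H-%M-%S")
    name = (
        f"{model_name}-v{config.vocab_size}-seq{config.max_seq_len}"
        f"-num_layers{config.num_layers}"
        f"-sa_qs{config.sa_q_heads}-sa_kvs{config.sa_kv_heads}"
        f"-embed_d{config.embed_dim}"
        f"-mlp_mult{config.mlp_multiplier}"
        f"-sa_head_d{config.sa_head_dim}"
        f"-theta{config.theta}"
        f"-combo{config.combo}-levels{config.levels}"
        f"--{stamp}.pth"
    )
    # only builds the path; the caller is responsible for creating `directory`
    return os.path.join(directory, name)

# placeholder values for illustration, not the notebook's actual config
placeholder_config = SimpleNamespace(
    vocab_size=65, max_seq_len=256, num_layers=4, sa_q_heads=4, sa_kv_heads=1,
    embed_dim=128, mlp_multiplier=4, sa_head_dim=32, theta=10000.0,
    combo=4, levels=3,
)
print(checkpoint_path("ToyModel", placeholder_config))
```

`torch.save` does not create missing folders, so in the notebook this would pair with `os.makedirs('models', exist_ok=True)` before `torch.save(model.state_dict(), checkpoint_path(model.__class__.__name__, config))`.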

# Inference

In [23]:
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou " # the classic line
max_useable_output_len = config.max_seq_len - len(input_str)
output = model.generate(input_str, output_len = max_useable_output_len)
print(output)

RuntimeError: shape '[32, -1, 4, 16]' is invalid for input of size 1024
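One plausible reading of the `RuntimeError` above, based on the error message alone: the trailing `4, 16` look like a head count and head dimension, and the leading `32` matches the training `batch_size`, so some reshape inside the model may be using a fixed batch size instead of the input's. A single prompt with 1024 elements (e.g. 1 * 16 * 64) cannot be viewed as `[32, -1, 4, 16]`, which needs a multiple of 32 * 4 * 16 = 2048. The sketch below uses a hypothetical `split_heads` function to show the pattern of reading the batch size off the tensor itself; it is not the model's actual attention code.

```python
import torch

def split_heads(x: torch.Tensor, num_heads: int, head_dim: int) -> torch.Tensor:
    """(batch, seq, num_heads * head_dim) -> (batch, seq, num_heads, head_dim).

    The batch size comes from the tensor, so the same code handles a training
    batch of 32 and a single prompt at inference.
    """
    batch_size = x.shape[0]
    return x.view(batch_size, -1, num_heads, head_dim)

train_x = torch.randn(32, 16, 64)
infer_x = torch.randn(1, 16, 64)           # 1 * 16 * 64 = 1024 elements
print(split_heads(train_x, 4, 16).shape)   # torch.Size([32, 16, 4, 16])
print(split_heads(infer_x, 4, 16).shape)   # torch.Size([1, 16, 4, 16])

try:
    infer_x.view(32, -1, 4, 16)            # hard-coded training batch size
except RuntimeError as err:
    print(err)                             # same kind of shape error as above
```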

In [21]:
len(output)

256
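`max_useable_output_len` subtracts `len(input_str)`, a character count, from `config.max_seq_len`, a token count. Those agree only if the tokenizer maps one character to one token, which is the usual setup for TinyShakespeare but is an assumption here. The sketch below uses a hypothetical character-level vocabulary (`build_char_vocab`, `output_budget`) and counts encoded ids instead of characters, which keeps the arithmetic correct even if the tokenizer changes.

```python
def build_char_vocab(text):
    """Character-level vocabulary: one token id per distinct character (assumed tokenizer)."""
    return {ch: i for i, ch in enumerate(sorted(set(text)))}

def output_budget(prompt_ids, max_seq_len):
    """Number of new tokens that fit before the sequence reaches max_seq_len."""
    budget = max_seq_len - len(prompt_ids)
    if budget <= 0:
        raise ValueError("prompt already fills the context window")
    return budget

corpus = "JULIET:\nO Romeo, Romeo! wherefore art thou Romeo?\n"  # stand-in for the training text
stoi = build_char_vocab(corpus)
prompt = "JULIET:\nO Romeo, Romeo! wherefore art thou "
prompt_ids = [stoi[ch] for ch in prompt]
# with one id per character, counting ids and counting characters agree
print(len(prompt), len(prompt_ids), output_budget(prompt_ids, max_seq_len=256))
```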