# Next-Concept Prediction v11.2

So the main plan here is to
1. [x] turn minGemma into a kind of MatFormer
2. [x] have the smallest level deal with tokens, medium with first-level concepts, and large with 2nd-level concepts
3. [x] dynamically generate concept vectors based on regressive cosine similarity rather than storing huge vector combination vocabularies
4. [x] allow lower levels to attend to the vector representations at the levels above them using cross-attention
5. [x] ~~with inference, higher levels will run first, lower levels attend to them, and then each time a lowest level finishes its predictions of length config.combo then that value gets pushed up to the higher levels which now have to recalculate their predictions~~ -- ~~i'm an idiot, the code written for training actually does work perfectly well for performing inference. i mean it's not as efficient as it could be but it is peak quality given how the model was trained.~~ 
    - [ ] no i'm an idiot again. so i think when you pump it through it only predicts one concept vector ahead and then suddenly uses that for lower levels. this is made worse by the fact that i've been training it to do exactly that rather than training it to finish to the full context length just because i was concerned about dynamic RoPE. what i need to do is precompute RoPE, return get_batch() to its previous state where it always called a full length, and then mess with i guess Body to make it call our cosine similarity search thing and build out the full future sequence
    - [ ] rather than actually use the raw regression outputs, we can use cosine similarity to search for the token combination at that level that is most similar to the current regression output, and then add that to the sequence for our next run of the decoder instead. That way the model is always working in-distribution. Would be interesting to see which one performs better. Also, if this option is chosen then config.noise_sd would be useless bc we'd no longer have to deal with lower levels attending to upper level concepts that aren't in-distribution. this is a similarly big project to the above bullet point; really they should be implemented at the same time
6. [ ] I think the moving goalposts problem is likely too difficult for the model to have to deal with straight from the get-go. Rather, i should set it up to train the token level first and then progressively engage higher levels. i think all i have to do to make this happen is pull the loss additions out of the model and into the training loop, no? 
    - [ ] ~~honestly now that i'm thinking about it, it might make sense to pull the concept loss out entirely since the model can train just through the lowest level's backpropagation through crossMQA.~~ no that makes no sense because then how would i do autoregressive decoding of concepts
    - [ ] Also, might it make sense to project the final output of a conceptual layer to the dimension size below? that way it's really focusing on the task at hand rather than having this entire second half of the vector whose goal is to just predict itself after starting from a random initialization. doing this would require that i mess with crossMQA. 
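
A minimal sketch of the "snap the regression output to the nearest in-distribution concept" idea from item 5. The function name `snap_to_nearest_concept` and the `candidates` pool are hypothetical; how the candidate concept vectors actually get generated at each level is not shown here, so this only illustrates the cosine-similarity search step itself.

```python
import torch
import torch.nn.functional as F

def snap_to_nearest_concept(pred: torch.Tensor, candidates: torch.Tensor):
    """
    pred:       [batch, dim] raw regression outputs from a concept level
    candidates: [num_candidates, dim] concept vectors assumed to be in-distribution
    returns:    (snapped [batch, dim], idx [batch])
    """
    # cosine similarity = dot product of L2-normalized vectors
    sims = F.normalize(pred, dim=-1) @ F.normalize(candidates, dim=-1).T  # [batch, num_candidates]
    idx = sims.argmax(dim=-1)
    # return the candidate itself (not the raw prediction) so the next decoder pass sees a vector it was trained on
    return candidates[idx], idx

torch.manual_seed(0)
candidates = torch.randn(10, 16)
# pretend the model regressed rescaled, slightly noisy versions of candidates 3 and 7
pred = candidates[[3, 7]] * 2.5 + 0.01 * torch.randn(2, 16)
snapped, idx = snap_to_nearest_concept(pred, candidates)
print(idx)
```

Since the search only looks at direction, the scale of the raw regression output doesn't matter, which fits the default `'cos'` concept loss.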

less important todo list
- [x] clean up the primary forward() by putting things into their own classes
- [ ] ~~change the get_batch() function to give us sequence lengths that are not multiples of config.combo so that the thing generalizes to prompts that aren't a perfect length. likely need to add a <|padding|> embedding vector for this to work well. should it be learnable or maybe set it to all zeros?~~ nvm i don't think that's a good idea bc i don't want to encourage the model to predict simpler concepts (concepts made up of fewer tokens/concepts) so i think instead i have to just hope that if they give me a prompt that's an off number then it can generalize.
    - [x] yoooo what if instead of padding the end i pad the beginning???? models tend to use the first embedding in the sequence as an attention sink anyways and that way it'd never be encouraged to predict dumb shit. ok i think this is the way to go. should be a simple edit to the combine embeddings class and then i can do the thing where i give varying sequence lengths. wait but how would that mess with what gets paid attention to? i think it doesn't matter with our current version of cross-attention. so just adjust the current implementation of padding and then the get_batch() function
- [x] add an option to norm the initial residual state instead of scaling it by sqrt(embed_dim). would it be interesting to both norm and scale? for example, with cosine norm that would put the initial vectors on a sqrt(d) radius hypersphere centered at the origin i guess
- [ ] add noise to the upper-triangular part of the cross-attention; maybe another option where this "noise mask" is stronger over the course of each row to represent the likely inaccuracy that the model will experience during inference for progressive concept embeddings. or alternatively, add a mask that only allows the model to see the concept it's currently working on, so kinda like a mask of diagonal 1's except with blocks of size config.combo rather than entries. this won't be necessary if i go down the route of replacing the model's guesses with the nearest cosine similarity concept combos.
- [x] pump `training` variable throughout for stuff like dropout
- [x] add a ~~config.verbose dictionary~~ debug/demonstrate logger
- [ ] ~~switch the cross-attention to be inbetween self-attention and MLP (more like vaswani, which i assume is how encoder-decoder structures usually work). while we're making it more like vaswani, i should just let every layer use cross-attention instead of doing it in intervals. not sure if they used the same cross-attention mechanism weights at every layer or created an nn.ModuleList~~ 
    - [x] actually i'm thinking it shouldn't matter and doing this would require me to tear up `Layer` and make `Body` more important which i think isn't worth the effort. like it really shouldn't matter at all tbh. I should prolly set it to activate every layer tho.
- [x] the code in the norms have gotten so similar it'd prolly be more visually appealing for me to combine them all into one object plus then i wouldn't have to be picky about where in the stack of cells `config` gets defined
- [ ] write a more efficient inference algorithm. specifically, have each level predict config.combo vectors rather than 1 vector. this *i think* should be just as ideal in terms of accuracy but obviously more efficient 
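
A rough sketch of the pad-the-beginning idea from the list above, assuming the padding happens on token ids before embedding. The notebook does it inside the combine-embeddings class instead, so `left_pad_to_multiple` and `pad_id` are hypothetical names for illustration only.

```python
import torch

def left_pad_to_multiple(tokens: torch.Tensor, combo: int, pad_id: int) -> torch.Tensor:
    """tokens: [batch, seq_len] -> [batch, padded_len] with padded_len % combo == 0"""
    seq_len = tokens.shape[1]
    pad_len = (-seq_len) % combo  # how many slots are missing to reach the next multiple of combo
    if pad_len == 0:
        return tokens
    pad = torch.full((tokens.shape[0], pad_len), pad_id, dtype=tokens.dtype, device=tokens.device)
    # padding goes in FRONT so the most recent tokens still close out a full concept
    return torch.cat([pad, tokens], dim=1)

tokens = torch.arange(1, 11).unsqueeze(0)  # a prompt of length 10
padded = left_pad_to_multiple(tokens, combo=4, pad_id=0)
print(padded.shape)
```

With the padding at the front, every concept the model is asked to predict is still made of `combo` real tokens, which is the reason for not padding the end.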

further ideas for v11.3 (maybe call it 11.2b?):
- [ ] do it without the matryoshka-ness
    - [ ] add multiple <|bos|> tokens to put at the beginning of each sequence, one for each level. So that way the model can know what level it's working with and have somewhere to sink attention.
- [ ] set up `effective_seq_len_mult` and `seq_len_list` and a bunch of downstream stuff like weird batching at lower sub-levels to allow for hella long effective context lengths

In [1]:
# my virtual environments are rarely properly connected to jupyter so this fixes that
import sys
import os
current_dir = os.getcwd()  # Get the current working directory
venv_dir = os.path.join(current_dir, 'venv') 
python_version = str(sys.version_info.major) + '.' + str(sys.version_info.minor)
site_packages_path = os.path.join(venv_dir, 'lib', 'python' + python_version, 'site-packages')
sys.path.append(site_packages_path) 

In [2]:
# Importing pytorch
import torch
import torch.nn as nn
from torch.nn import functional as F

# imports for the debugging/demonstration setup
import functools
import inspect

# imports for the tokenizer
from tokenizer import SimpleTokenizer, loaded_stoi, loaded_merges

# Imports used for the config
from dataclasses import dataclass
from typing import Optional

# Imports used for the model
import re
from typing import Any, List, Sequence, Tuple, Union

# used for training
import random
import time

# used to save & load models
import json
from dataclasses import asdict

In [3]:
def log_io(func):
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        if not self.logging_enabled:
            return func(self, *args, **kwargs)

        def log_item(item, name, level=0, is_root=False):
            indent = "    " * level
            if isinstance(item, torch.Tensor):
                print(f"{indent}Tensor '{name}' shape: {item.shape}")
            elif isinstance(item, tuple):
                if is_root and level == 0:
                    # Root level tuple, don't print it as a tuple unless it's a "true" tuple
                    for idx, sub_item in enumerate(item):
                        log_item(sub_item, f"{name}[{idx}]", level)
                else:
                    print(f"{indent}Tuple '{name}':")
                    for idx, sub_item in enumerate(item):
                        log_item(sub_item, f"{name}[{idx}]", level + 1)
            elif isinstance(item, int):
                print(f"{indent}Integer '{name}': Value={item}")
            else:
                print(f"{indent}Other-type '{name}': Type={type(item).__name__}, Value={item}")

        print(f"\n{'='*10}Entering {self.__class__.__name__}.{func.__name__}{'='*10}")
        print("Inputs:")
        arg_names = inspect.getfullargspec(func).args[1:]  # Excluding 'self'
        arg_values = args + tuple(kwargs.values())
        for name, value in zip(arg_names, arg_values):
            log_item(value, name)

        result = func(self, *args, **kwargs)
        print("\nOutputs:")
        if isinstance(result, tuple):
            log_item(result, "output", is_root=True)
        else:
            log_item(result, "output")

        print(f"{'='*10}Exiting {self.__class__.__name__}.{func.__name__}{'='*10}")
        return result
    return wrapper


In [4]:
# load the dataset
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# and the tokenizer
tokenizer = SimpleTokenizer(loaded_stoi, loaded_merges)

# Config

In [5]:
@dataclass # a class meant specifically to just hold data
class Config:
    """ 
    The default configuration & hyperparameters for my next-concept predictor
    """
    ### boring hyperparameters
    vocab_size: int = tokenizer.vocab_len
    max_seq_len: int = 256
    num_layers: int = 6
    sa_q_heads: int = 2
    sa_kv_heads: int = 1
    embed_dim: int = 256
    mlp_multiplier: int = 4
    sa_head_dim: int = 64
    theta: float = 100.0
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dropout_rate: float = 0.05
    eps = 1e-6

    ### Normalization related (also boring)
    norm_init_embed: bool = True # whether or not to normalize right after turning input indices into the first residual state
    scale_init_embed: bool = True # whether or not to scale the initial residual state by the nth root of embed_dim. default True
    init_scale_degree: int = 2 # what nth root to take of embed_dim to use as the scale for the initial residual. default 2
    norm_final_embed: bool = True # whether or not to normalize the embeddings at the output layer before multiplying for logits
    norm_affine: bool = False # whether norms should have a linear & bias after them
    norm_type = "CosineNorm"  # Options are RMSNorm, CosineNorm and LayerNorm

    ### MatFormer+ related
    levels = 2
    split = 2
    @property
    def embed_dim_list(self):
        return [self.embed_dim // (self.split ** (i-1)) for i in range(self.levels, 0, -1)]
    @property
    def sa_head_dim_list(self):
        return [self.sa_head_dim // (self.split ** (i-1)) for i in range(self.levels, 0, -1)]

    ### Concept embedding vectors
    combo = 4
    concept_loss = "cos" # options are 'mae', 'mse', and 'cos'(default)
    @property
    def seq_len_list(self):
        return [(self.max_seq_len // (self.combo ** (i-1))) for i in range(1, self.levels + 1)]
    # how much to discount each higher level in the loss function compared to the last. gives the smaller models more say in the gradient
    level_loss_weight: float = 2.0

    ### Dualcoder cross-attention
    # how many times to do the cross-attention
    ca_connections: int = num_layers # for simplicity let's just keep them the same & not mess with it
    ca_q_heads: int = sa_q_heads
    ca_kv_heads: int = sa_kv_heads
    shared_ca: bool = False # True: same weights shared for each ca connection. False: separately instantiated module for each connection
    ca_use_RoPE: bool = False # True: expands out k & v tensors to be usable with rope. False: leaves k & v same size but no positional encodings
    ca_noise_sd: Optional[float] = None # float: adds gaussian noise with that sd to higher-level concept vecs. None: no noise added
    # i don't think i can change this one because
    @property
    def ca_head_dim(self): 
        return self.sa_head_dim // self.split 
    @property
    def ca_head_dim_list(self):
        return [self.ca_head_dim // (self.split ** (i-1)) for i in range(self.levels-1, 0, -1)]
    @property
    def ca_interval(self):
        return self.num_layers // self.ca_connections

    ### assertions
    assert sa_q_heads % sa_kv_heads == 0, 'the number of query heads must be divisible by the number of key-value heads in self-attention'
    assert ca_q_heads % ca_kv_heads == 0, 'the number of query heads must be divisible by the number of key-value heads in cross-attention'
    assert max_seq_len % (split ** levels) == 0, 'your context length must be a multiple of split ** levels'
    assert embed_dim % (split ** levels) == 0, 'your embedding dimension must be a multiple of split ** levels'
    assert sa_head_dim % (split ** levels) == 0, 'your self-attention head dimension must be divisible by split ** levels'
    #assert self.ca_head_dim % (split ** levels) == 0, 'your cross-attentionhead dimension must be divisible by split ** levels'
    assert num_layers % ca_connections == 0, 'the total number of layers must be divisible by the number of cross-attention connections'        
        
config = Config()
print(config)
print(f"embedding dimension size of each model: {config.embed_dim_list}")
print(f"attention head size of each model: {config.sa_head_dim_list}")
print(f"sequence length of each model: {config.seq_len_list}")
print(f"base head dimension of cross-attention to higher levels: {config.ca_head_dim}")
print(f"head size of each cross-attention connection: {config.ca_head_dim_list}")
print(f"cross-attention to a higher level will be applied every {config.ca_interval} layers")
print(f"loss discounts starting from lowest level: {[config.level_loss_weight**i for i in range(config.levels)]}")

Config(vocab_size=128, max_seq_len=256, num_layers=6, sa_q_heads=2, sa_kv_heads=1, embed_dim=256, mlp_multiplier=4, sa_head_dim=64, theta=100.0, dropout_rate=0.05, norm_init_embed=True, scale_init_embed=True, init_scale_degree=2, norm_final_embed=True, norm_affine=False, level_loss_weight=2.0, ca_connections=6, ca_q_heads=2, ca_kv_heads=1, shared_ca=False, ca_use_RoPE=False, ca_noise_sd=None)
embedding dimension size of each model: [128, 256]
attention head size of each model: [32, 64]
sequence length of each model: [256, 64]
base head dimension of cross-attention to higher levels: 32
head size of each cross-attention connection: [32]
cross-attention to a higher level will be applied every 1 layers
loss discounts starting from lowest level: [1.0, 2.0]
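
To see how the per-level properties scale past the default two levels, here's a standalone restatement of the `embed_dim_list` and `seq_len_list` arithmetic. `level_shapes` is a hypothetical helper, not part of `Config`; it just copies the two list comprehensions so they can be run with other settings.

```python
def level_shapes(embed_dim, max_seq_len, levels, split, combo):
    # smallest (token-level) model first, full-width top concept level last
    embed_dims = [embed_dim // (split ** (i - 1)) for i in range(levels, 0, -1)]
    # each level up sees a sequence that is `combo` times shorter
    seq_lens = [max_seq_len // (combo ** (i - 1)) for i in range(1, levels + 1)]
    return embed_dims, seq_lens

dims2, lens2 = level_shapes(256, 256, levels=2, split=2, combo=4)
dims3, lens3 = level_shapes(256, 256, levels=3, split=2, combo=4)
print(dims2, lens2)  # matches the printout above: [128, 256] [256, 64]
print(dims3, lens3)  # [64, 128, 256] [256, 64, 16]
```

So the narrowest model always gets the longest sequence, and the widest model gets the shortest one.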


# Norms

In [17]:
class Norm(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.eps = config.eps
        self.affine = config.norm_affine
        self.dropout_rate = config.dropout_rate
        self.type = config.norm_type

        # Initialize weight and bias parameters for affine transformation
        # We start with ones for weight to keep the original scale initially, and zeros for bias.
        self.w = nn.Parameter(torch.ones(config.embed_dim))
        self.b = nn.Parameter(torch.zeros(config.embed_dim))
        
        self.logging_enabled = False
    def enable_logging(self):
        self.logging_enabled = True
    def disable_logging(self):
        self.logging_enabled = False

    @log_io
    def forward(self, x: torch.Tensor, training: bool = False) -> torch.Tensor:
        # Normalize the input tensor
        if self.type == "CosineNorm":
            x = self.CosineNorm(x)
        elif self.type == "LayerNorm":
            x = self.LayerNorm(x)
        else: # defaults to RMSNorm bc that's the most commonly used nowadays
            x = self.RMSNorm(x)

        if self.affine: # Optionally apply the affine transformation with splicing
            w, b = self.splice_affine(self.w, self.b, x.shape[-1])
            x = x * w + b
            # and dropout the linear projection if we're training
            x = F.dropout(x, p=self.dropout_rate, training=training) 
            
        return x 

    @log_io
    def CosineNorm(self, x):
        # normalize x by dividing by its L2 norm along the last dimension.
        # this places x on the unit hypersphere centered at the origin
        # Add a small constant to the denominator to avoid division by zero.
        return x / torch.norm(x, p=2, dim=-1, keepdim=True).clamp(min=self.eps)

    @log_io
    def LayerNorm(self, x):
        # normalize x by subtracting its mean then dividing by its standard deviation
        # this places x on a hypersphere of radius sqrt(dimension) centered at the origin
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        return (x - mean) / torch.sqrt(var + self.eps)

    @log_io
    def RMSNorm(self, x):
        # normalize x by dividing by its root-mean-square along the last dimension
        # this places x on a hypersphere of radius sqrt(dimension) centered at the origin, but unlike LayerNorm the entries are not mean-centered first
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    @log_io
    def splice_affine(self, weight, bias, d_i):
        return weight[:d_i], bias[:d_i]

### demonstration/debugging
I've set up these little snippets after each nn.Module to help you see what's happening and for my own debugging

In [18]:
# Create an instance of Norm
module = Norm(config)

# Initially, logging is disabled
# Enable logging
module.enable_logging()

# Call the forward method - logging will occur
output = module(torch.randn(32, config.max_seq_len, config.embed_dim // config.combo))

# Disable logging. 
# This isn't actually necessary since we won't be using this object again but that's how you'd do it
module.disable_logging()

# clearing up ram jic we're training later
del module, output


Inputs:
Tensor 'x' shape: torch.Size([32, 256, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 256, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 256, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 256, 64])
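
A quick numeric check of the geometry claims in the comments of the three norm methods. This is a standalone sketch that re-types the same formulas rather than calling `Norm`, and `eps` is hard-coded to match the config default.

```python
import torch

torch.manual_seed(0)
eps = 1e-6
x = torch.randn(4, 8, 64)
d = x.shape[-1]

# CosineNorm: divide by the L2 norm -> unit hypersphere
cosine = x / torch.norm(x, p=2, dim=-1, keepdim=True).clamp(min=eps)

# RMSNorm: divide by the root-mean-square -> radius sqrt(d)
rms = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

# LayerNorm (no affine): subtract mean, divide by std -> radius sqrt(d) with zero-mean entries
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
layer = (x - mean) / torch.sqrt(var + eps)

print(cosine.norm(dim=-1).mean(), rms.norm(dim=-1).mean(), layer.norm(dim=-1).mean(), d ** 0.5)
```

This also covers the norm-and-scale question from the todo list: multiplying the CosineNorm output by `d ** 0.5` puts it on the same sqrt(d)-radius sphere that RMSNorm lands on.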


# RoPE

i like the idea of pre-computing RoPE embeddings but at this point i don't think it's worth the effort bc i'd have to not only use this code but also pipe the two instantiations of this class through from `Body` all the way to `selfMQA` and `crossMQA` and I'm not even sure if it matters. I really should learn more about RoPE

class RoPE(nn.Module):
    def __init__(self, 
                 dim: int, 
                 max_seq_len:int = config.max_seq_len, 
                 device: str = config.device):
        super().__init__()
        # Validate that dim is even since we split it by 2 for real and imaginary parts
        if dim % 2 != 0: raise ValueError("Dimension 'dim' must be an even number.")
            
        # Precompute frequencies based on configuration
        theta = config.theta if hasattr(config, 'theta') else 10000.0
        
        # use the constructor arguments here; `config` has no `dim` attribute
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
        t = torch.arange(max_seq_len, device=device)
        freqs = torch.outer(t, freqs).to(device).float()
        
        # Register as buffer to prevent gradient tracking
        self.register_buffer('freqs_cis', torch.polar(torch.ones_like(freqs), freqs)) # complex64

    def forward(self, x):
        # Apply rotary embeddings to the input tensor
        x_ = torch.view_as_complex(torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
        x_out = torch.view_as_real(x_ * self.freqs_cis.unsqueeze(0)).type_as(x)
        x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
        x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], -1).transpose(1, 2)
        return x_out

\# demonstration/debugging
module = RoPE(dim=10)
module.enable_logging()
output = module(torch.randn(, 5))
del module, output

In [19]:
def RoPE(x: torch.Tensor, dim: int, theta: float = 10000.0) -> torch.Tensor:
    """Applies the rotary embedding to the inputted query or key tensor"""
    # Validate that dim is even since we split it by 2 for real and imaginary parts
    if dim % 2 != 0: raise ValueError("Dimension 'dim' must be an even number.")
            
    # Get sequence length
    seq_len = x.size(1)
    device = x.device

    # Dynamically compute frequency cis based on the input sequence length
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
    t = torch.arange(seq_len, device=device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    # it's important to train on a wide variety of sequence lengths within your context length so that the model learns to generalize

    # Apply rotary embeddings to the input tensor
    x_ = torch.view_as_complex(torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
    x_out = torch.view_as_real(x_ * freqs_cis.unsqueeze(0)).type_as(x)  # Ensure batch dimension is handled
    x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
    x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], -1).transpose(1, 2)

    return x_out
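
A small sanity check of what RoPE is supposed to buy us: rotating queries and keys should leave their norms alone and make the q·k score depend only on the relative offset between positions. The `rope` function below is an independent re-implementation using explicit sin/cos, written to follow the same half-split pairing as the function above (first half of the features as the real part, second half as the imaginary part); treat it as a sketch, not a drop-in replacement.

```python
import torch

def rope(x: torch.Tensor, theta: float = 10000.0) -> torch.Tensor:
    """x: [batch, seq_len, heads, dim]; rotates each (first-half, second-half) feature pair by a position-dependent angle"""
    _, seq_len, _, dim = x.shape
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    angles = torch.outer(torch.arange(seq_len).float(), freqs)  # [seq_len, dim/2]
    cos = angles.cos()[None, :, None, :]
    sin = angles.sin()[None, :, None, :]
    x1, x2 = x[..., : dim // 2], x[..., dim // 2 :]
    # (x1 + i*x2) * (cos + i*sin), written out in real arithmetic
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

torch.manual_seed(0)
seq_len, dim = 8, 16
# the SAME query/key vector placed at every position, so only position can change the score
q = torch.randn(dim).expand(1, seq_len, 1, dim)
k = torch.randn(dim).expand(1, seq_len, 1, dim)
rq, rk = rope(q), rope(k)
scores = rq[0, :, 0] @ rk[0, :, 0].T  # scores[m, n] = rotated q at position m dotted with rotated k at position n

print(scores[3, 1], scores[5, 3])  # same offset of 2, so these should agree up to float error
```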

# selfMQA

In [20]:
class selfMQA(nn.Module):
    def __init__(self, config: Config):
        super().__init__()

        self.sa_q_heads = config.sa_q_heads
        self.sa_kv_heads = config.sa_kv_heads
        assert self.sa_q_heads % self.sa_kv_heads == 0
        self.num_queries_per_kv = self.sa_q_heads // self.sa_kv_heads

        self.embed_dim = config.embed_dim
        self.sa_head_dim = config.sa_head_dim
        self.theta = config.theta
        self.dropout_rate = config.dropout_rate

        self.Wqkv = nn.Parameter(torch.Tensor(self.embed_dim, (self.sa_q_heads + 2 * self.sa_kv_heads) * self.sa_head_dim))
        nn.init.uniform_(self.Wqkv, -((1 / self.embed_dim) ** 0.5), (1 / self.embed_dim) ** 0.5)

        self.Wo = nn.Parameter(torch.Tensor(self.sa_q_heads * self.sa_head_dim, self.embed_dim))
        nn.init.uniform_(self.Wo, -((1 / (self.sa_q_heads * self.sa_head_dim)) ** 0.5), (1 / (self.sa_q_heads * self.sa_head_dim)) ** 0.5)

        # for our attention mask we'll create a boolean mask that'll later be turned into large negative values
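        # registered as a non-persistent buffer so it follows the module across devices without being saved in the state_dict
        self.register_buffer('mask', torch.tril(torch.ones((config.max_seq_len, config.max_seq_len), dtype=torch.uint8)
                              ).view(1, 1, config.max_seq_len, config.max_seq_len).to(dtype=torch.bool), persistent=False)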
        
        self.logging_enabled = False
    def enable_logging(self):
        self.logging_enabled = True
    def disable_logging(self):
        self.logging_enabled = False

    @log_io
    def forward(self, x: torch.Tensor, training: bool = False) -> torch.Tensor:
        # Extracts batch size and input sequence length from the hidden states tensor.
        batch_size, input_len, d_i = x.shape
        h_i = self.sa_head_dim // (self.embed_dim // d_i)

        # splicing our primary projection to get the correct sub-matrices
        Wq, Wk, Wv, Wo = self.weight_splicing(self.Wqkv, self.Wo, d_i, h_i)
        # technically self.weight_splicing has access to self.Wqkv & Wo but this way our debugger can see them

        # Applies the linear projection to the hidden state to retrieve our q, k & v projections
        xq = F.dropout(x @ Wq, p=self.dropout_rate, training=training) # also dropout if we're training
        xk = F.dropout(x @ Wk, p=self.dropout_rate, training=training)
        xv = F.dropout(x @ Wv, p=self.dropout_rate, training=training)

        # Reshapes each of the Q, K, and V tensors to separate the heads and align the dimensions for attention operations.
        xq = xq.view(batch_size, -1, self.sa_q_heads, h_i)
        xk = xk.view(batch_size, -1, self.sa_kv_heads, h_i)
        xv = xv.view(batch_size, -1, self.sa_kv_heads, h_i)

        # Applies rotary positional embeddings to queries and keys to incorporate positional information.
        xq = self.RoPE(xq, h_i)
        xk = self.RoPE(xk, h_i)

        # If the number of KV heads is different from the number of query heads, adjusts keys and values to match the query heads count.
        if self.sa_kv_heads != self.sa_q_heads:
            xk = self.match_headcount(xk) # [batch_size, input_len, n_local_heads, sa_head_dim]
            xv = self.match_headcount(xv)

        # Transposes Q, K, and V tensors to align them for the batch matrix multiplication in attention calculation.
        xq = xq.transpose(1, 2) # [batch_size, n_local_heads, input_len, sa_head_dim]
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        # Calculates attention scores by performing a batch matrix multiplication between queries and keys, followed by scaling.
        logits = self.attend(xq, xk, h_i) # [batch_size, n_local_heads, input_len, input_len]
        
        # Applies the lower-triangular mask to the attention logits
        logits = self.apply_mask(logits, input_len)

        # applies values to get final output
        output = self.calc_output(logits, xv, batch_size, input_len) 

        # Applies the final linear projection to the attention output, mapping it back to d_i
        return F.dropout(output @ Wo, p=self.dropout_rate, training=training) # also dropout if we're training

    @log_io
    def weight_splicing(self, Wqkv, Wo, d_i, h_i):
        Wq, Wk, Wv = Wqkv.split([self.sa_q_heads * self.sa_head_dim,
                                 self.sa_kv_heads * self.sa_head_dim,
                                 self.sa_kv_heads * self.sa_head_dim],dim = -1)
        Wq = torch.cat([Wq[:d_i, j*self.sa_head_dim:j*self.sa_head_dim + h_i] for j in range(self.sa_q_heads)], dim = 1)
        Wk = torch.cat([Wk[:d_i, j*self.sa_head_dim:j*self.sa_head_dim + h_i] for j in range(self.sa_kv_heads)], dim = 1)
        Wv = torch.cat([Wv[:d_i, j*self.sa_head_dim:j*self.sa_head_dim + h_i] for j in range(self.sa_kv_heads)], dim = 1)
        Wo = torch.cat([Wo[j*self.sa_head_dim :j*self.sa_head_dim + h_i, :d_i] for j in range(self.sa_q_heads)], dim=0)
        return Wq, Wk, Wv, Wo

    @log_io
    def RoPE(self, x, h_i):
        return RoPE(x, h_i, self.theta)

    @log_io
    def match_headcount(self, xn):
        return torch.repeat_interleave(xn, self.num_queries_per_kv, dim=2)

    @log_io
    def attend(self, xq, xk, h_i):
        return torch.matmul(xq, xk.transpose(2, 3)) * (h_i ** -0.5)
        
    @log_io
    def apply_mask(self, logits, input_len):
        return torch.where(self.mask[..., :input_len, :input_len].expand_as(logits),
                           logits,
                           torch.tensor(-1e30, device=logits.device, dtype=logits.dtype))
    
    @log_io
    def calc_output(self, logits, xv, batch_size, input_len):
        # Applies softmax to the logits to obtain attention probabilities
        scores = F.softmax(logits, dim=-1)
        
        # Computes the weighted sum of values based on the attention scores to obtain the output of the attention mechanism.
        output = scores @ xv # [batch_size, n_local_heads, input_len, sa_head_dim]

        # Reshapes the attention output to match the expected output dimensions, combining the heads back into the hidden dimension.
        return output.transpose(1, 2).contiguous().view(batch_size, input_len, -1)
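
To make the per-head slicing in `weight_splicing` easier to picture, here's a toy version on a tiny matrix filled with consecutive integers. `splice_heads` is a hypothetical standalone helper that copies the list-comprehension pattern used for `Wq`, `Wk` and `Wv`; the sizes are made up so the result is readable.

```python
import torch

def splice_heads(W: torch.Tensor, d_i: int, n_heads: int, head_dim: int, h_i: int) -> torch.Tensor:
    """Keep the first d_i input rows and, for each head, only the first h_i of its head_dim columns."""
    return torch.cat([W[:d_i, j * head_dim : j * head_dim + h_i] for j in range(n_heads)], dim=1)

embed_dim, head_dim, n_heads = 8, 4, 2
# row 0 holds 0..7, row 1 holds 8..15, and so on, so it's easy to see which columns survive
W = torch.arange(embed_dim * n_heads * head_dim, dtype=torch.float32).view(embed_dim, n_heads * head_dim)

d_i = embed_dim // 2                 # the sub-model runs at half width
h_i = head_dim // (embed_dim // d_i) # same rule as in forward(): the head shrinks by the same factor
W_small = splice_heads(W, d_i, n_heads, head_dim, h_i)
print(W_small)
```

The point is that the smaller model doesn't take the top-left block of the whole projection. It takes the top-left block of each head, so every head stays present at the smaller width.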

### demonstration/debugging

In [21]:
module = selfMQA(config)
module.enable_logging()
output = module(torch.randn(32,config.max_seq_len,config.embed_dim))
del module, output


Inputs:
Tensor 'x' shape: torch.Size([32, 256, 256])

Inputs:
Tensor 'Wqkv' shape: torch.Size([256, 256])
Tensor 'Wo' shape: torch.Size([128, 256])
Integer 'd_i': Value=256
Integer 'h_i': Value=64

Outputs:
Tensor 'output[0]' shape: torch.Size([256, 128])
Tensor 'output[1]' shape: torch.Size([256, 64])
Tensor 'output[2]' shape: torch.Size([256, 64])
Tensor 'output[3]' shape: torch.Size([128, 256])

Inputs:
Tensor 'x' shape: torch.Size([32, 256, 2, 64])
Integer 'h_i': Value=64

Outputs:
Tensor 'output' shape: torch.Size([32, 256, 2, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 256, 1, 64])
Integer 'h_i': Value=64

Outputs:
Tensor 'output' shape: torch.Size([32, 256, 1, 64])

Inputs:
Tensor 'xn' shape: torch.Size([32, 256, 1, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 256, 2, 64])

Inputs:
Tensor 'xn' shape: torch.Size([32, 256, 1, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 256, 2, 64])

Inputs:
Tensor 'xq' shape: torch.Size([32, 2, 256, 64])
Tensor 'xk' shape: 

# MLP

In [22]:
class MLP(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 mlp_multiplier: int,
                 dropout_rate: float = 0.1):
        super().__init__()
        self.mlp_multiplier = mlp_multiplier
        self.hidden_size = embed_dim * mlp_multiplier

        # the gate
        self.Wgate = nn.Parameter(torch.Tensor(embed_dim, self.hidden_size))
        self.Bgate = nn.Parameter(torch.Tensor(self.hidden_size))
        torch.nn.init.uniform_(self.Wgate, -((1/embed_dim) ** 0.5), (1/embed_dim) ** 0.5)
        torch.nn.init.uniform_(self.Bgate, -((1/embed_dim) ** 0.5), (1/embed_dim) ** 0.5)

        # the up projection
        self.Wup = nn.Parameter(torch.Tensor(embed_dim, self.hidden_size))
        self.Bup = nn.Parameter(torch.Tensor(self.hidden_size))
        torch.nn.init.uniform_(self.Wup, -((1/embed_dim) ** 0.5), (1/embed_dim) ** 0.5)
        torch.nn.init.uniform_(self.Bup, -((1/embed_dim) ** 0.5), (1/embed_dim) ** 0.5)

        # the down projection
        self.Wdown = nn.Parameter(torch.Tensor(self.hidden_size, embed_dim))
        self.Bdown = nn.Parameter(torch.Tensor(embed_dim))
        torch.nn.init.uniform_(self.Wdown, -((1/self.hidden_size) ** 0.5), (1/self.hidden_size) ** 0.5)
        torch.nn.init.uniform_(self.Bdown, -((1/self.hidden_size) ** 0.5), (1/self.hidden_size) ** 0.5)
        
        self.dropout_rate = dropout_rate
        
        self.logging_enabled = False
    def enable_logging(self):
        self.logging_enabled = True
    def disable_logging(self):
        self.logging_enabled = False
        
    @log_io
    def forward(self,
                x: torch.Tensor,
                training: bool = False
               ) -> torch.Tensor:
        d_i = x.shape[-1]
        gate = x @ self.Wgate[:d_i, :d_i * self.mlp_multiplier] + self.Bgate[:d_i * self.mlp_multiplier]
        up = x @ self.Wup[:d_i, :d_i * self.mlp_multiplier] + self.Bup[:d_i * self.mlp_multiplier]
        fuse = F.dropout(F.gelu(gate) * up, p=self.dropout_rate, training=training)
        down = fuse @ self.Wdown[:d_i * self.mlp_multiplier, :d_i] + self.Bdown[:d_i]
        return F.dropout(down, p=self.dropout_rate, training=training)
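
A sketch that checks the nesting property of this MLP: when the input has the smaller width `d_i`, only the top-left blocks of the weights (and the leading entries of the biases) are used, so anything outside those blocks can change without affecting the small model's output. `nested_geglu` is a hypothetical function-style copy of the forward pass above with dropout left out.

```python
import torch
import torch.nn.functional as F

def nested_geglu(x, Wgate, Bgate, Wup, Bup, Wdown, Bdown, mult):
    d_i = x.shape[-1]
    h = d_i * mult  # hidden width scales with the input width
    gate = x @ Wgate[:d_i, :h] + Bgate[:h]
    up = x @ Wup[:d_i, :h] + Bup[:h]
    return (F.gelu(gate) * up) @ Wdown[:h, :d_i] + Bdown[:d_i]

torch.manual_seed(0)
D, mult = 8, 4
Wgate, Bgate = torch.randn(D, D * mult), torch.randn(D * mult)
Wup, Bup = torch.randn(D, D * mult), torch.randn(D * mult)
Wdown, Bdown = torch.randn(D * mult, D), torch.randn(D)

x_small = torch.randn(2, 3, D // 2)
out_before = nested_geglu(x_small, Wgate, Bgate, Wup, Bup, Wdown, Bdown, mult)

# scramble everything OUTSIDE the blocks the small model reads
d_i, h = D // 2, (D // 2) * mult
for W in (Wgate, Wup):
    W[d_i:, :] += 100.0
    W[:d_i, h:] += 100.0
Wdown[h:, :] += 100.0
Wdown[:h, d_i:] += 100.0
for B, n in ((Bgate, h), (Bup, h), (Bdown, d_i)):
    B[n:] += 100.0

out_after = nested_geglu(x_small, Wgate, Bgate, Wup, Bup, Wdown, Bdown, mult)
print(torch.allclose(out_before, out_after))
```

The reverse doesn't hold: the full-width model reads the shared top-left blocks too, which is what ties the levels together during training.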

### demonstration/debugging

In [23]:
module = MLP(config.embed_dim, config.mlp_multiplier, config.dropout_rate)
module.enable_logging()
output = module(torch.randn(32,config.max_seq_len,config.embed_dim))
del module, output


Inputs:
Tensor 'x' shape: torch.Size([32, 256, 256])

Outputs:
Tensor 'output' shape: torch.Size([32, 256, 256])


# Layer

In [24]:
class Layer(nn.Module):
    def __init__(self, config: Config):
        super().__init__()

        self.mqa = selfMQA(config)
        self.mlp = MLP(config.embed_dim, config.mlp_multiplier, config.dropout_rate)

        # Initialize normalization layers using the class reference from config
        self.pre_mqa_norm = Norm(config)
        self.post_mqa_norm = Norm(config)
        self.pre_mlp_norm = Norm(config)
        self.post_mlp_norm = Norm(config)
        
        self.logging_enabled = False
    def enable_logging(self):
        self.logging_enabled = True
    def disable_logging(self):
        self.logging_enabled = False

    @log_io
    def forward(self, 
                x: torch.Tensor,
                training: bool = False
               ) -> torch.Tensor:
        x = x + self.mqa_connection(x, training)
        x = x + self.mlp_connection(x, training)
        return x

    @log_io
    def mqa_connection(self, x, training):
        return self.post_mqa_norm(self.mqa(self.pre_mqa_norm(x, training), training), training)

    @log_io
    def mlp_connection(self, x, training):
        return self.post_mlp_norm(self.mlp(self.pre_mlp_norm(x, training), training), training)

### demonstration/debugging

In [25]:
module = Layer(config)
module.enable_logging()
#module.pre_mqa_norm.enable_logging()
#module.post_mqa_norm.enable_logging()
#module.pre_mlp_norm.enable_logging()
#module.post_mlp_norm.enable_logging()
output = module(torch.randn(32,config.max_seq_len,config.embed_dim))
del module, output


Inputs:
Tensor 'x' shape: torch.Size([32, 256, 256])

Inputs:
Tensor 'x' shape: torch.Size([32, 256, 256])
Integer 'training': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 256, 256])

Inputs:
Tensor 'x' shape: torch.Size([32, 256, 256])
Integer 'training': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 256, 256])

Outputs:
Tensor 'output' shape: torch.Size([32, 256, 256])


# crossMQA

In [26]:
class crossMQA(nn.Module):
    def __init__(self, config: Config):
        super().__init__()

        self.ca_q_heads = config.ca_q_heads
        self.ca_kv_heads = config.ca_kv_heads
        assert self.ca_q_heads % self.ca_kv_heads == 0
        self.num_queries_per_kv = self.ca_q_heads // self.ca_kv_heads

        self.embed_dim = config.embed_dim
        self.ca_head_dim = config.ca_head_dim
        self.sa_head_dim = config.sa_head_dim # used only for an assertion to make sure sizes will fit
        self.combo = config.combo # used only for an assertion to make sure sizes will fit
        self.split = config.split # used only for an assertion to make sure sizes will fit
        self.theta = config.theta
        self.use_RoPE = config.ca_use_RoPE
        self.noise_sd = config.ca_noise_sd
        self.dropout_rate = config.dropout_rate

        self.Wqkv = nn.Parameter(torch.Tensor(self.embed_dim, (self.ca_q_heads + 2 * self.ca_kv_heads) * self.ca_head_dim))
        nn.init.uniform_(self.Wqkv, -((1 / self.embed_dim) ** 0.5), (1 / self.embed_dim) ** 0.5)
        
        self.Wo = nn.Parameter(torch.Tensor(self.ca_q_heads * self.ca_head_dim, self.embed_dim))
        nn.init.uniform_(self.Wo, -((1 / (self.ca_q_heads * self.ca_head_dim)) ** 0.5), (1 / (self.ca_q_heads * self.ca_head_dim)) ** 0.5)

        # for now we'll try no attention mask and then maybe (probably) edit that later 
        
        self.logging_enabled = False
    def enable_logging(self):
        self.logging_enabled = True
    def disable_logging(self):
        self.logging_enabled = False

    @log_io
    def forward(self, 
                x: torch.Tensor, # the lower level tensor, sometimes a resid state full of tokens & sometimes concepts
                c: torch.Tensor, # the upper level tensor, always a resid state full of concept vecs
                training: bool = False
               ) -> torch.Tensor:
        # optionally adds random gaussian noise to the upper-level concepts to simulate how imperfect predicted concepts will prolly be
        # NOTE: right now this runs whenever noise_sd is set, even with training=False; i think it should be gated on `training`
        # this version adds noise to the entire matrix but later i'd like to switch it to just upper-triangular
        if self.noise_sd is not None:
            noise = torch.randn_like(c, requires_grad = False)
            c = c + noise * self.noise_sd
        
        # Extracts batch size and input sequence length from the hidden states tensor.
        batch_size, input_len_x, xd_i = x.shape
        batch_size_c, input_len_c, cd_i = c.shape
        assert batch_size == batch_size_c
        
        # finds the appropriate head dimension and ensures that they match
        h_i = self.ca_head_dim // (self.embed_dim // cd_i)
        assert h_i == self.sa_head_dim // (self.embed_dim // xd_i), 'head_dim was not same bw levels in cross attention'
        
        # splicing our projection to get the correct sub-matrices
        Wq, Wk, Wv, Wo = self.weight_splicing(self.Wqkv, self.Wo, xd_i, cd_i, h_i)

        # Applies the linear projection to the hidden state to retrieve our q, k & v projections
        xq = F.dropout(x @ Wq, p=self.dropout_rate, training=training) # also applies dropout if we're training
        ck = F.dropout(c @ Wk, p=self.dropout_rate, training=training)
        cv = F.dropout(c @ Wv, p=self.dropout_rate, training=training)

        # Reshapes each of the Q, K, and V tensors to separate the heads and align the dimensions for attention operations.
        xq = xq.view(batch_size, -1, self.ca_q_heads, h_i)
        ck = ck.view(batch_size, -1, self.ca_kv_heads, h_i)
        cv = cv.view(batch_size, -1, self.ca_kv_heads, h_i)

        # IF we want to use RoPE (doesn't fully make sense to)
        if self.use_RoPE:
            expand = input_len_x // input_len_c
            ck = ck.repeat_interleave(expand, dim=1) 
            cv = cv.repeat_interleave(expand, dim=1) # values need to be expanded for their use later on if we do this

            # Applies rotary positional embeddings to queries and keys to incorporate positional information.
            xq = self.RoPE(xq, h_i) 
            ck = self.RoPE(ck, h_i)

        # If the number of KV heads is different from the number of query heads, adjusts keys and values to match the query heads count.
        if self.ca_kv_heads != self.ca_q_heads:
            ck = self.match_headcount(ck)
            cv = self.match_headcount(cv) # [batch_size, input_len, n_local_heads, ca_head_dim]

        # Transposes Q, K, and V tensors to align them for the batch matrix multiplication in attention calculation.
        xq = xq.transpose(1, 2) # [batch_size, n_local_heads, input_len, ca_head_dim]
        ck = ck.transpose(1, 2)
        cv = cv.transpose(1, 2)

        # Calculates attention scores by performing a batch matrix multiplication between queries and keys, followed by scaling.
        logits = self.attend(xq, ck, h_i) # [batch_size, ca_q_heads, input_len_x, input_len_c]

        # applies values to get final output
        output = self.calc_output(logits, cv, batch_size, input_len_x)

        # Applies the final linear projection to the attention output, mapping it back to d_i. 
        return F.dropout(output @ Wo, p=self.dropout_rate, training=training) # and dropout if we're training

    @log_io
    def weight_splicing(self, Wqkv, Wo, xd_i, cd_i, h_i):
        Wq, Wk, Wv = Wqkv.split([self.ca_q_heads * self.ca_head_dim,
                                 self.ca_kv_heads * self.ca_head_dim,
                                 self.ca_kv_heads * self.ca_head_dim],dim = -1)
        Wq = torch.cat([Wq[:xd_i, j*self.ca_head_dim:j*self.ca_head_dim + h_i] for j in range(self.ca_q_heads)], dim = 1)
        Wk = torch.cat([Wk[:cd_i, j*self.ca_head_dim:j*self.ca_head_dim + h_i] for j in range(self.ca_kv_heads)], dim = 1)
        Wv = torch.cat([Wv[:cd_i, j*self.ca_head_dim:j*self.ca_head_dim + h_i] for j in range(self.ca_kv_heads)], dim = 1)
        Wo = torch.cat([Wo[j*self.ca_head_dim:j*self.ca_head_dim + h_i, :xd_i] for j in range(self.ca_q_heads)], dim=0)
        return Wq, Wk, Wv, Wo

    @log_io
    def RoPE(self, x, h_i):
        return RoPE(x, h_i, self.theta)

    @log_io
    def match_headcount(self, xn):
        return torch.repeat_interleave(xn, self.num_queries_per_kv, dim=2)

    @log_io
    def attend(self, xq, ck, h_i):
        return torch.matmul(xq, ck.transpose(2, 3)) * (h_i ** -0.5)
        
    @log_io
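    def apply_mask(self, logits, input_len):
        # NOTE: not called anywhere yet, and self.mask is never defined in __init__ (no attention mask for now)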
        return torch.where(self.mask[..., :input_len, :input_len].expand_as(logits),
                           logits,
                           torch.tensor(-1e30, device=logits.device, dtype=logits.dtype))
    
    @log_io
    def calc_output(self, logits, cv, batch_size, input_len_x):
        # Applies softmax to the logits to obtain attention probabilities
        scores = F.softmax(logits, dim=-1)
        
        # Computes the weighted sum of values based on the attention scores to obtain the output of the attention mechanism.
        output = scores @ cv # [batch_size, ca_q_heads, input_len_x, h_i]

        # Reshapes the attention output to match the expected output dimensions, combining the heads back into the hidden dimension.
        return output.transpose(1, 2).contiguous().view(batch_size, input_len_x, -1)

### demonstration/debugging

at one point this one had been giving me trouble so here's multiple different config.levels setups for ya

In [27]:
hold = config.levels
config.levels = 2
module = crossMQA(config)
module.enable_logging()
x0 = torch.randn(32, config.max_seq_len, config.embed_dim // config.split)
c1 = torch.randn(32, config.max_seq_len // config.combo, config.embed_dim)
output = module(x0, c1)
config.levels = hold
del hold, module, x0, c1, output


Inputs:
Tensor 'x' shape: torch.Size([32, 256, 128])
Tensor 'c' shape: torch.Size([32, 64, 256])

Inputs:
Tensor 'Wqkv' shape: torch.Size([256, 128])
Tensor 'Wo' shape: torch.Size([64, 256])
Integer 'xd_i': Value=128
Integer 'cd_i': Value=256
Integer 'h_i': Value=32

Outputs:
Tensor 'output[0]' shape: torch.Size([128, 64])
Tensor 'output[1]' shape: torch.Size([256, 32])
Tensor 'output[2]' shape: torch.Size([256, 32])
Tensor 'output[3]' shape: torch.Size([64, 128])

Inputs:
Tensor 'xn' shape: torch.Size([32, 64, 1, 32])

Outputs:
Tensor 'output' shape: torch.Size([32, 64, 2, 32])

Inputs:
Tensor 'xn' shape: torch.Size([32, 64, 1, 32])

Outputs:
Tensor 'output' shape: torch.Size([32, 64, 2, 32])

Inputs:
Tensor 'xq' shape: torch.Size([32, 2, 256, 32])
Tensor 'ck' shape: torch.Size([32, 2, 64, 32])
Integer 'h_i': Value=32

Outputs:
Tensor 'output' shape: torch.Size([32, 2, 256, 64])

Inputs:
Tensor 'logits' shape: torch.Size([32, 2, 256, 64])
Tensor 'cv' shape: torch.Size([32, 2, 64, 32])

In [28]:
hold = config.levels
config.levels = 3
module = crossMQA(config)
module.enable_logging()
x0 = torch.randn(32, config.max_seq_len, config.embed_dim // (config.split**2))
c1 = torch.randn(32, config.max_seq_len // config.combo, config.embed_dim // config.split)
c2 = torch.randn(32, config.max_seq_len // (config.combo**2), config.embed_dim)
output = module(x0, c1)
output = module(c1, c2)
config.levels = hold
del hold, module, x0, c1, c2, output


Inputs:
Tensor 'x' shape: torch.Size([32, 256, 64])
Tensor 'c' shape: torch.Size([32, 64, 128])

Inputs:
Tensor 'Wqkv' shape: torch.Size([256, 128])
Tensor 'Wo' shape: torch.Size([64, 256])
Integer 'xd_i': Value=64
Integer 'cd_i': Value=128
Integer 'h_i': Value=16

Outputs:
Tensor 'output[0]' shape: torch.Size([64, 32])
Tensor 'output[1]' shape: torch.Size([128, 16])
Tensor 'output[2]' shape: torch.Size([128, 16])
Tensor 'output[3]' shape: torch.Size([32, 64])

Inputs:
Tensor 'xn' shape: torch.Size([32, 64, 1, 16])

Outputs:
Tensor 'output' shape: torch.Size([32, 64, 2, 16])

Inputs:
Tensor 'xn' shape: torch.Size([32, 64, 1, 16])

Outputs:
Tensor 'output' shape: torch.Size([32, 64, 2, 16])

Inputs:
Tensor 'xq' shape: torch.Size([32, 2, 256, 16])
Tensor 'ck' shape: torch.Size([32, 2, 64, 16])
Integer 'h_i': Value=16

Outputs:
Tensor 'output' shape: torch.Size([32, 2, 256, 64])

Inputs:
Tensor 'logits' shape: torch.Size([32, 2, 256, 64])
Tensor 'cv' shape: torch.Size([32, 2, 64, 16])
In

In [29]:
hold = config.levels
config.levels = 4
module = crossMQA(config)
module.enable_logging()
x0 = torch.randn(32, config.max_seq_len, config.embed_dim // (config.split**3))
c1 = torch.randn(32, config.max_seq_len // config.combo, config.embed_dim // (config.split**2))
c2 = torch.randn(32, config.max_seq_len // (config.combo**2), config.embed_dim // config.split)
c3 = torch.randn(32, config.max_seq_len // (config.combo**3), config.embed_dim)
output = module(x0, c1)
output = module(c1, c2)
output = module(c2, c3)
config.levels = hold
del hold, module, x0, c1, c2, c3, output


Inputs:
Tensor 'x' shape: torch.Size([32, 256, 32])
Tensor 'c' shape: torch.Size([32, 64, 64])

Inputs:
Tensor 'Wqkv' shape: torch.Size([256, 128])
Tensor 'Wo' shape: torch.Size([64, 256])
Integer 'xd_i': Value=32
Integer 'cd_i': Value=64
Integer 'h_i': Value=8

Outputs:
Tensor 'output[0]' shape: torch.Size([32, 16])
Tensor 'output[1]' shape: torch.Size([64, 8])
Tensor 'output[2]' shape: torch.Size([64, 8])
Tensor 'output[3]' shape: torch.Size([16, 32])

Inputs:
Tensor 'xn' shape: torch.Size([32, 64, 1, 8])

Outputs:
Tensor 'output' shape: torch.Size([32, 64, 2, 8])

Inputs:
Tensor 'xn' shape: torch.Size([32, 64, 1, 8])

Outputs:
Tensor 'output' shape: torch.Size([32, 64, 2, 8])

Inputs:
Tensor 'xq' shape: torch.Size([32, 2, 256, 8])
Tensor 'ck' shape: torch.Size([32, 2, 64, 8])
Integer 'h_i': Value=8

Outputs:
Tensor 'output' shape: torch.Size([32, 2, 256, 64])

Inputs:
Tensor 'logits' shape: torch.Size([32, 2, 256, 64])
Tensor 'cv' shape: torch.Size([32, 2, 64, 8])
Integer 'batch_si
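
A small sketch of what `weight_splicing` does to the query projection, using the sizes visible in the logs above (2 query heads, full head dim 32, embed dim 256). The `splice_heads` helper is hypothetical; it just isolates the list-comprehension-plus-`torch.cat` pattern so the resulting shapes are easy to check against the logged `output[0]` values.

```python
import torch

embed_dim, q_heads, head_dim = 256, 2, 32      # sizes read off the logged shapes
Wq_full = torch.randn(embed_dim, q_heads * head_dim)

def splice_heads(W: torch.Tensor, d_i: int, h_i: int,
                 n_heads: int, full_head_dim: int) -> torch.Tensor:
    """Keep the first d_i input rows and the first h_i columns of every head's block."""
    blocks = [W[:d_i, j * full_head_dim : j * full_head_dim + h_i] for j in range(n_heads)]
    return torch.cat(blocks, dim=1)

# (d_i, h_i) pairs as logged for the lowest level in the 2-, 3- and 4-level demos
for d_i, h_i in [(128, 32), (64, 16), (32, 8)]:
    Wq = splice_heads(Wq_full, d_i, h_i, q_heads, head_dim)
    print(d_i, h_i, tuple(Wq.shape))
```

Slicing per head (rather than taking one contiguous column slice) keeps every head represented at the smaller width, just with fewer dimensions each.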

# crossLayer

In [30]:
class crossLayer(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.mqa = crossMQA(config)

        # jic i want to add an MLP later
        #self.mlp = MLP(config.embed_dim, config.mlp_multiplier)

        # Initialize normalization layers using the class reference from config
        self.pre_mqa_norm_x = Norm(config)
        self.pre_mqa_norm_c = Norm(config)
        self.post_mqa_norm = Norm(config)
        #self.pre_mlp_norm = Norm(config)
        #self.post_mlp_norm = Norm(config)
        
        self.logging_enabled = False
    def enable_logging(self):
        self.logging_enabled = True
    def disable_logging(self):
        self.logging_enabled = False
        
    @log_io
    def forward(self,
                x: torch.Tensor,
                c: torch.Tensor,
                training: bool = False
               ) -> torch.Tensor:
        x = x + self.mqa_connection(x, c, training)
        #x = x + self.mlp_connection(x, training)
        return x

    @log_io
    def mqa_connection(self, x, c, training):
        return self.post_mqa_norm(self.mqa(x = self.pre_mqa_norm_x(x, training), 
                                           c = self.pre_mqa_norm_c(c, training), training=training), training)

    #@log_io
    #def mlp_connection(self, x, training):
        #return self.post_mlp_norm(self.mlp(self.pre_mlp_norm(x, training), training), training)

### demonstration/debugging

In [31]:
module = crossLayer(config)
module.enable_logging()
#module.mqa.enable_logging()
#module.pre_mqa_norm_x.enable_logging()
#module.pre_mqa_norm_c.enable_logging()
#module.post_mqa_norm.enable_logging()
x = torch.randn(32, config.max_seq_len, config.embed_dim // config.split)
c = torch.randn(32, config.max_seq_len // config.combo, config.embed_dim)
output = module(x, c)
del module, x, c, output


Inputs:
Tensor 'x' shape: torch.Size([32, 256, 128])
Tensor 'c' shape: torch.Size([32, 64, 256])

Inputs:
Tensor 'x' shape: torch.Size([32, 256, 128])
Tensor 'c' shape: torch.Size([32, 64, 256])
Integer 'training': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 256, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 256, 128])


# Body

In [63]:
# to prevent the warning statement from printing hella times
cvec_warning = False

class Body(nn.Module):
    def __init__(self, config: Config, embedder: torch.Tensor):
        super().__init__()
        self.ca_interval = config.ca_interval
        self.max_seq_len = config.max_seq_len
        self.combo = config.combo
        self.levels = config.levels
        self.embedding = embedder.weight
        
        # Initialize a sequence of Layer instances as specified by the number of hidden layers in the config
        self.layers = nn.ModuleList(Layer(config) for _ in range(config.num_layers))
        # initialize the (single / sequence of) cross-attention layer(s)
        # nn.ModuleList needs an iterable, so the single shared layer gets wrapped in a list
        self.ca_layers = nn.ModuleList([crossLayer(config)]) if config.shared_ca else nn.ModuleList(crossLayer(config) for _ in range(config.ca_connections))
        
        # Initialize normalization layers to be applied after the last decoder layers, stabilizing the output
        self.final_norm = Norm(config)
        
        self.logging_enabled = False
    def enable_logging(self):
        self.logging_enabled = True
    def disable_logging(self):
        self.logging_enabled = False

    @log_io
    def forward(self, 
                x0s: Tuple[torch.Tensor], # ordered from tokens -> highest concepts
                cvec_samples: int = None,
                cvec_greedy: bool = False,
                cvec_temp: float = 1.0,
                training: bool = False,
               ) -> Tuple[torch.Tensor]:
        # initiate tuple to hold final residual states
        xfs = ()
        # iterate through model levels, starting from highest level concepts & ending at lowest level tokens
        for i, x in enumerate(reversed(x0s)): # reversed() makes us start at highest level
            
            effective_max_seq_len = self.max_seq_len // (self.combo ** (self.levels-1-i))
            assert x.shape[1] <= effective_max_seq_len, f'somehow a too-long sequence ({x.shape[1]} vs {effective_max_seq_len}) made it all the way to Body'
            extra_runs = effective_max_seq_len - x.shape[1]
            for k in range(extra_runs): # if extra_runs == 0 then this just won't do anything
                # run
                x_ = self.layers_loop(x, i, x0s, training)
                # subset -1
                x_ = x_[:,-1,:]
                # norm? no, i'm worried about affine transformations messing with it since this process wasn't present during training
                # select most similar concept vectors to be appended to the sequence
                c = self.concept_matchup(x_, cvec_samples, cvec_greedy, cvec_temp)
                # append to x
                x = torch.concat([x, c.unsqueeze(1)], dim=1)

            # the final run. this one will actually be logits to get used with crossMQA
            x = self.layers_loop(x, i, x0s, training)
            
            # add the final residual state of the level to our tuple, normed
            xfs += (self.final_norm(x, training),) # should i be using separate final norms? nah

        return xfs # now it's ordered from highest concepts -> token

    @log_io
    def layers_loop(self, x, i, x0s, training):
        
        # Iteratively process the input through each Layer of the model
        for j in range(len(self.layers)):
            
            # our occasional cross-attention connection to the upper level
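            if (i != 0) and (j % self.ca_interval == 0): # i can't equal zero bc there'd be no higher level model to pay attention to
                # modulo so that a single shared cross-attention layer (len == 1) gets reused at every connection
                ca_layer = self.ca_layers[(j // self.ca_interval) % len(self.ca_layers)]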
                x = ca_layer(x, x0s[len(x0s)-i], training)

            # our current-level model's work
            layer = self.layers[j]
            x = layer(x, training)
        return x

    @log_io
    def concept_matchup(self,
                        c: torch.Tensor,
                        cvec_samples: int,
                        cvec_greedy: bool,
                        cvec_temp: float,
                        ) -> torch.Tensor:
        global cvec_warning
        batch_size, d = c.size()
        vocab_size = self.embedding.size(0)
    
        # Batch cosine similarity
        # Reshape c: (batch_size x 1 x embedding_dim)
        # Reshape embedding: (1 x vocab_size x embedding_dim)
        # Resulting similarity: (batch_size x vocab_size)
        token_similarities = F.cosine_similarity(c.unsqueeze(1), self.embedding.unsqueeze(0), dim=-1)
        
        # how many tokens will we sample to build up our chosen concept vector?
        if cvec_samples is None:
            cvec_samples = self.combo ** (self.levels-1)
            if not cvec_warning:
                print(f"cvec_samples not defined. defaulting to highest level's minimum size: combo**(levels-1) = {cvec_samples}")
                cvec_warning = True
        assert cvec_samples >= self.combo ** (self.levels-1)
        
        # Select top-k token embeddings for each concept vector
        topk_token_indices = torch.topk(token_similarities, k=cvec_samples, dim=1).indices  # (batch_size x sample)
    
        # Generate concept embeddings for each set of top-k token embeddings
        concept_embeddings_batch = []
        X_sizes_batch = []
        for i in range(batch_size):
            # Pass the list of indices for each concept
            concept_embeddings, X_sizes = self.create_concept_embeddings(self.embedding, 
                                                                         [topk_token_indices[i].tolist()])
            concept_embeddings_batch.append(concept_embeddings.squeeze(0))  # Remove the extra batch dimension
            X_sizes_batch.append(X_sizes)
    
        # Convert list of tensors to a tensor
        concept_embeddings_batch = torch.stack(concept_embeddings_batch)  # (batch_size x max_X_size x d)
    
        # Calculate concept similarities for each concept in the batch
        concept_similarities_batch = F.cosine_similarity(c.unsqueeze(1), concept_embeddings_batch, dim=-1)
    
        # Select the best matching concept embedding for each concept vector in the batch
        if cvec_greedy:
            best_concept_indices = concept_similarities_batch.argmax(dim=1)
            matched_concepts = concept_embeddings_batch[torch.arange(batch_size), best_concept_indices]
        else:
            # Apply softmax with temperature and sample
            topk_concept_probs = F.softmax(concept_similarities_batch / cvec_temp, dim=1)
            concept_topk_idx = torch.multinomial(topk_concept_probs, num_samples=1).squeeze(1)
            matched_concepts = concept_embeddings_batch[torch.arange(batch_size), concept_topk_idx]
    
        return matched_concepts

    @log_io
    def create_concept_embeddings(self, E: torch.Tensor, indices: torch.Tensor):
        """
        Create concept embeddings for a batch of indices.
    
        E: Embedding matrix (vocab_size x embedding_dim)
        indices: A list of lists of indices (batch_size x num_indices)
        """
        batch_size = len(indices)
        d = E.size(1)
        X_sizes = [(len(ind) - 1) * len(ind) // 2 for ind in indices]
        max_X_size = max(X_sizes)
        X = torch.empty((batch_size, max_X_size, d), dtype=E.dtype, device=E.device) # keep X on the same device as E
        
        # this could prolly be done way more efficiently with tensor operations
        for b in range(batch_size):
            count = 0
            for i in range(len(indices[b])):
                for j in range(i + 1, len(indices[b])):
                    X[b, count] = E[indices[b][i]] + E[indices[b][j]]
                    count += 1
            # Padding the rest if necessary
            if count < max_X_size:
                X[b, count:] = 0
        
        # X_sizes is not useful rn but i think it may be later when we switch away from TinyShakespeare
        # and over to data that actually has variable sequence lengths
        return X, X_sizes
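
On the "could prolly be done way more efficiently with tensor operations" note: a possible vectorized version of the inner double loop, sketched for a single list of indices. `pairwise_sums` is a hypothetical helper, not part of `Body`; it uses `torch.triu_indices` to enumerate the same `i < j` pairs in the same order as the nested loops.

```python
import torch

def pairwise_sums(E: torch.Tensor, ids: torch.Tensor) -> torch.Tensor:
    """Return E[a] + E[b] for every pair of positions a < b in ids, row-major order."""
    n = ids.numel()
    i, j = torch.triu_indices(n, n, offset=1, device=ids.device)
    return E[ids[i]] + E[ids[j]]

torch.manual_seed(0)
E = torch.randn(128, 256)                  # stand-in embedding matrix
ids = torch.tensor([80, 97, 85, 23])       # 4 token ids -> 4*3/2 = 6 pairs
X = pairwise_sums(E, ids)

# reference: the same nested loop used in create_concept_embeddings
ref = torch.stack([E[ids[a]] + E[ids[b]]
                   for a in range(len(ids)) for b in range(a + 1, len(ids))])
print(X.shape, torch.allclose(X, ref))
```

Since every row in a batch has the same number of sampled tokens right now, the same index pair could in principle be reused across the batch, which would also remove the per-example Python loop in `concept_matchup`.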

### demonstration/debugging

In [64]:
# first let's do 2 levels with full sequence length
embedding = nn.Embedding(config.vocab_size, config.embed_dim)
hold = config.levels
config.levels = 2
module = Body(config, embedding)
module.enable_logging()
#module.layers.enable_logging()
#module.ca_layers.enable_logging()
#module.final_norm.enable_logging()
x = torch.randn(32, config.max_seq_len, config.embed_dim // config.split)
c = torch.randn(32, config.max_seq_len // config.combo, config.embed_dim)
x0s = (x,c)
output = module(x0s)
config.levels = hold
del embedding, hold, module, x, c, x0s, output


Inputs:
Tuple 'x0s':
    Tensor 'x0s[0]' shape: torch.Size([32, 256, 128])
    Tensor 'x0s[1]' shape: torch.Size([32, 64, 256])

Inputs:
Tensor 'x' shape: torch.Size([32, 64, 256])
Integer 'i': Value=0
Tuple 'x0s':
    Tensor 'x0s[0]' shape: torch.Size([32, 256, 128])
    Tensor 'x0s[1]' shape: torch.Size([32, 64, 256])
Integer 'training': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 64, 256])

Inputs:
Tensor 'x' shape: torch.Size([32, 256, 128])
Integer 'i': Value=1
Tuple 'x0s':
    Tensor 'x0s[0]' shape: torch.Size([32, 256, 128])
    Tensor 'x0s[1]' shape: torch.Size([32, 64, 256])
Integer 'training': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 256, 128])

Outputs:
Tensor 'output[0]' shape: torch.Size([32, 64, 256])
Tensor 'output[1]' shape: torch.Size([32, 256, 128])
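
For reference, a tiny sketch of the per-level sizes that `Body.forward` expects, in the order it iterates (highest concepts first). The sequence length follows the `effective_max_seq_len` formula in the code; the width formula and `split = 2` are assumptions read off the demo tensors, and `level_shapes` is a made-up helper.

```python
def level_shapes(max_seq_len: int, embed_dim: int, combo: int, split: int, levels: int):
    """(sequence length, width) per level, ordered highest concepts -> tokens."""
    shapes = []
    for i in range(levels):
        seq_len = max_seq_len // (combo ** (levels - 1 - i))  # same formula as effective_max_seq_len
        width = embed_dim // (split ** i)                      # assumed from the demo tensors
        shapes.append((seq_len, width))
    return shapes

print(level_shapes(max_seq_len=256, embed_dim=256, combo=4, split=2, levels=3))
# → [(16, 256), (64, 128), (256, 64)]
```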


In [65]:
# now 3 levels full sequence length
embedding = nn.Embedding(config.vocab_size, config.embed_dim)
hold = config.levels
config.levels = 3
module = Body(config, embedding)
module.enable_logging()
#module.layers.enable_logging()
#module.ca_layers.enable_logging()
#module.final_norm.enable_logging()
x0 = torch.randn(32, config.max_seq_len // (config.combo ** 0), config.embed_dim // (config.split ** 2))
c1 = torch.randn(32, config.max_seq_len // (config.combo ** 1), config.embed_dim // (config.split ** 1))
c2 = torch.randn(32, config.max_seq_len // (config.combo ** 2), config.embed_dim // (config.split ** 0))
x0s = (x0, c1, c2)
output = module(x0s)
config.levels = hold
del embedding, hold, module, x0, c1, c2, x0s, output


Inputs:
Tuple 'x0s':
    Tensor 'x0s[0]' shape: torch.Size([32, 256, 64])
    Tensor 'x0s[1]' shape: torch.Size([32, 64, 128])
    Tensor 'x0s[2]' shape: torch.Size([32, 16, 256])

Inputs:
Tensor 'x' shape: torch.Size([32, 16, 256])
Integer 'i': Value=0
Tuple 'x0s':
    Tensor 'x0s[0]' shape: torch.Size([32, 256, 64])
    Tensor 'x0s[1]' shape: torch.Size([32, 64, 128])
    Tensor 'x0s[2]' shape: torch.Size([32, 16, 256])
Integer 'training': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 16, 256])

Inputs:
Tensor 'x' shape: torch.Size([32, 64, 128])
Integer 'i': Value=1
Tuple 'x0s':
    Tensor 'x0s[0]' shape: torch.Size([32, 256, 64])
    Tensor 'x0s[1]' shape: torch.Size([32, 64, 128])
    Tensor 'x0s[2]' shape: torch.Size([32, 16, 256])
Integer 'training': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 64, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 256, 64])
Integer 'i': Value=2
Tuple 'x0s':
    Tensor 'x0s[0]' shape: torch.Size([32, 256, 64])
    T

In [66]:
# now 2 levels partial sequence (so like we're doing inference)
embedding = nn.Embedding(config.vocab_size, config.embed_dim)
hold = config.levels
config.levels = 2
module = Body(config, embedding)
module.enable_logging()
#module.layers.enable_logging()
#module.ca_layers.enable_logging()
#module.final_norm.enable_logging()
x = torch.randn(32, config.max_seq_len - (config.combo**2), config.embed_dim // config.split)
c = torch.randn(32, (config.max_seq_len // config.combo) - config.combo, config.embed_dim)
x0s = (x,c)
output = module(x0s, cvec_samples = config.combo ** config.levels)
config.levels = hold
del embedding, hold, module, x, c, x0s, output




Inputs:
Tuple 'x0s':
    Tensor 'x0s[0]' shape: torch.Size([32, 240, 128])
    Tensor 'x0s[1]' shape: torch.Size([32, 60, 256])
Integer 'cvec_samples': Value=16

Inputs:
Tensor 'x' shape: torch.Size([32, 60, 256])
Integer 'i': Value=0
Tuple 'x0s':
    Tensor 'x0s[0]' shape: torch.Size([32, 240, 128])
    Tensor 'x0s[1]' shape: torch.Size([32, 60, 256])
Integer 'training': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 60, 256])

Inputs:
Tensor 'c' shape: torch.Size([32, 256])
Integer 'cvec_samples': Value=16
Integer 'cvec_greedy': Value=False
Other-type 'cvec_temp': Type=float, Value=1.0

Inputs:
Tensor 'E' shape: torch.Size([128, 256])
Other-type 'indices': Type=list, Value=[[80, 97, 85, 23, 51, 118, 110, 73, 44, 33, 86, 46, 87, 106, 103, 38]]

Outputs:
Tensor 'output[0]' shape: torch.Size([1, 120, 256])
Other-type 'output[1]': Type=list, Value=[120]

Inputs:
Tensor 'E' shape: torch.Size([128, 256])
Other-type 'indices': Type=list, Value=[[121, 12, 42, 52, 57, 72, 108, 


Outputs:
Tensor 'output' shape: torch.Size([32, 61, 256])

Inputs:
Tensor 'c' shape: torch.Size([32, 256])
Integer 'cvec_samples': Value=16
Integer 'cvec_greedy': Value=False
Other-type 'cvec_temp': Type=float, Value=1.0

Inputs:
Tensor 'E' shape: torch.Size([128, 256])
Other-type 'indices': Type=list, Value=[[51, 85, 111, 71, 61, 69, 106, 29, 73, 38, 65, 72, 113, 43, 75, 33]]

Outputs:
Tensor 'output[0]' shape: torch.Size([1, 120, 256])
Other-type 'output[1]': Type=list, Value=[120]

Inputs:
Tensor 'E' shape: torch.Size([128, 256])
Other-type 'indices': Type=list, Value=[[112, 5, 93, 78, 81, 66, 105, 95, 15, 121, 22, 58, 20, 68, 110, 52]]

Outputs:
Tensor 'output[0]' shape: torch.Size([1, 120, 256])
Other-type 'output[1]': Type=list, Value=[120]

Inputs:
Tensor 'E' shape: torch.Size([128, 256])
Other-type 'indices': Type=list, Value=[[21, 121, 110, 81, 63, 38, 14, 28, 127, 92, 27, 108, 47, 112, 17, 49]]

Outputs:
Tensor 'output[0]' shape: torch.Size([1, 120, 256])
Other-type 'output[


Outputs:
Tensor 'output' shape: torch.Size([32, 62, 256])

Inputs:
Tensor 'c' shape: torch.Size([32, 256])
Integer 'cvec_samples': Value=16
Integer 'cvec_greedy': Value=False
Other-type 'cvec_temp': Type=float, Value=1.0

Inputs:
Tensor 'E' shape: torch.Size([128, 256])
Other-type 'indices': Type=list, Value=[[113, 71, 85, 111, 100, 6, 108, 48, 51, 106, 122, 84, 65, 112, 31, 61]]

Outputs:
Tensor 'output[0]' shape: torch.Size([1, 120, 256])
Other-type 'output[1]': Type=list, Value=[120]

Inputs:
Tensor 'E' shape: torch.Size([128, 256])
Other-type 'indices': Type=list, Value=[[22, 78, 111, 5, 57, 106, 107, 9, 0, 112, 104, 68, 8, 116, 33, 83]]

Outputs:
Tensor 'output[0]' shape: torch.Size([1, 120, 256])
Other-type 'output[1]': Type=list, Value=[120]

Inputs:
Tensor 'E' shape: torch.Size([128, 256])
Other-type 'indices': Type=list, Value=[[63, 47, 122, 121, 35, 119, 93, 46, 25, 51, 83, 100, 44, 67, 26, 38]]

Outputs:
Tensor 'output[0]' shape: torch.Size([1, 120, 256])
Other-type 'output


Outputs:
Tensor 'output' shape: torch.Size([32, 63, 256])

Inputs:
Tensor 'c' shape: torch.Size([32, 256])
Integer 'cvec_samples': Value=16
Integer 'cvec_greedy': Value=False
Other-type 'cvec_temp': Type=float, Value=1.0

Inputs:
Tensor 'E' shape: torch.Size([128, 256])
Other-type 'indices': Type=list, Value=[[84, 71, 122, 0, 80, 39, 61, 51, 74, 17, 53, 10, 86, 47, 100, 77]]

Outputs:
Tensor 'output[0]' shape: torch.Size([1, 120, 256])
Other-type 'output[1]': Type=list, Value=[120]

Inputs:
Tensor 'E' shape: torch.Size([128, 256])
Other-type 'indices': Type=list, Value=[[106, 8, 11, 46, 34, 69, 78, 60, 79, 36, 95, 57, 125, 114, 22, 61]]

Outputs:
Tensor 'output[0]' shape: torch.Size([1, 120, 256])
Other-type 'output[1]': Type=list, Value=[120]

Inputs:
Tensor 'E' shape: torch.Size([128, 256])
Other-type 'indices': Type=list, Value=[[122, 46, 26, 63, 47, 74, 125, 1, 75, 90, 65, 106, 28, 87, 38, 93]]

Outputs:
Tensor 'output[0]' shape: torch.Size([1, 120, 256])
Other-type 'output[1]': T


Outputs:
Tensor 'output' shape: torch.Size([32, 64, 256])

Inputs:
Tensor 'x' shape: torch.Size([32, 240, 128])
Integer 'i': Value=1
Tuple 'x0s':
    Tensor 'x0s[0]' shape: torch.Size([32, 240, 128])
    Tensor 'x0s[1]' shape: torch.Size([32, 60, 256])
Integer 'training': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 240, 128])

Inputs:
Tensor 'c' shape: torch.Size([32, 128])
Integer 'cvec_samples': Value=16
Integer 'cvec_greedy': Value=False
Other-type 'cvec_temp': Type=float, Value=1.0


RuntimeError: The size of tensor a (128) must match the size of tensor b (256) at non-singleton dimension 2

In [62]:
tensor = torch.randn(32, 256)
tensor.unsqueeze(1).shape

torch.Size([32, 1, 256])

# Embedding vector combination function

In [14]:
class CombineEmbeddings(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.padding_vector = nn.Parameter(torch.zeros(embed_dim), requires_grad=True)
        
        self.logging_enabled = False
    def enable_logging(self):
        self.logging_enabled = True
    def disable_logging(self):
        self.logging_enabled = False

    @log_io
    def forward(self, tensor, combine_factor):
        b, t, d = tensor.shape

        # Calculate the necessary amount of padding
        remainder = t % combine_factor
        padding_needed = 0 if remainder == 0 else combine_factor - remainder
        
        if padding_needed > 0:
            # Replicate the padding vector the necessary number of times
            padding = self.padding_vector.repeat(padding_needed, 1).unsqueeze(0).expand(b, -1, -1)
            # Prepend the padding (sliced to the current matryoshka width d) so that the most recent tokens always fill complete groups
            tensor = torch.cat([padding[...,:d], tensor], dim=1)
        
        # Update t after padding
        t_padded = t + padding_needed
        
        # Reshape the tensor to group 'combine_factor' entries along the t dimension
        reshaped_tensor = tensor.view(b, t_padded // combine_factor, combine_factor, d)
        
        # Sum over the groups
        combined_tensor = reshaped_tensor.sum(dim=2)

        return combined_tensor

### demonstration/debugging

In [None]:
module = CombineEmbeddings(config.embed_dim)
module.enable_logging()
x = torch.randn(32, config.max_seq_len, config.embed_dim // config.combo)
output = module(x, config.combo)
del module, x, output
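
To make the front-padding and group-sum behaviour concrete, here is a minimal standalone sketch of the same idea on a tiny tensor. The function name `combine_front_padded` and the all-zeros padding vector are assumptions for illustration; the real module uses a learned `padding_vector` parameter.

```python
import torch

def combine_front_padded(tensor: torch.Tensor, combine_factor: int,
                         padding_vector: torch.Tensor) -> torch.Tensor:
    """Front-pad the sequence to a multiple of combine_factor, then sum each group."""
    b, t, d = tensor.shape
    remainder = t % combine_factor
    padding_needed = 0 if remainder == 0 else combine_factor - remainder
    if padding_needed > 0:
        # slice the padding vector to width d and broadcast it to (b, padding_needed, d)
        padding = padding_vector[:d].expand(b, padding_needed, d)
        tensor = torch.cat([padding, tensor], dim=1)
    # group consecutive positions and add them together
    grouped = tensor.reshape(b, (t + padding_needed) // combine_factor, combine_factor, d)
    return grouped.sum(dim=2)

# toy input: one sequence of 5 "embeddings" of width 2, holding the values 1..5
x = torch.arange(1.0, 6.0).reshape(1, 5, 1).repeat(1, 1, 2)
pad = torch.zeros(2)  # hypothetical stand-in for the learned padding vector
combined = combine_front_padded(x, 2, pad)
print(combined.shape)             # torch.Size([1, 3, 2])
print(combined[0, :, 0].tolist()) # [1.0, 5.0, 9.0]
```

With 5 positions and a factor of 2, one padding vector is prepended, so the groups are (pad, 1), (2, 3), (4, 5).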

# Matryoshka Concept Loss

In [15]:
class matryoshkaConceptLoss(nn.Module):
    def __init__(self, config: Config, embedding_combiner):
        super().__init__()
        self.combo = config.combo
        self.levels = config.levels
        self.split = config.split
        self.level_loss_weight = config.level_loss_weight
        self.embedding_combiner = embedding_combiner

        if config.concept_loss == "mae":
            self.concept_loss_fn = nn.L1Loss()
        elif config.concept_loss == "mse":
            self.concept_loss_fn = nn.MSELoss()
        else: # defaults to cosine similarity loss
            self.concept_loss_fn = nn.CosineSimilarity(dim=-1, eps=1e-6)
        
        self.logging_enabled = False
    def enable_logging(self):
        self.logging_enabled = True
    def disable_logging(self):
        self.logging_enabled = False
        
    @log_io
    def create_target_vecs(self,
                           lvl_output: torch.Tensor,
                           target_token_ids: torch.Tensor,
                           input_len: int,
                           embedder: nn.Module,
                           i: int
                          ) -> torch.Tensor:
        lvl_combo = self.combo ** (self.levels-1-i)
        target_token_ids_adj = target_token_ids[:, lvl_combo:lvl_combo + input_len]
        raw_target_vectors = embedder(target_token_ids_adj)[...,:lvl_output.shape[-1]]
        # remember to detach/clone so that we don't mess with the embeddings we want to train on
        return self.embedding_combiner(raw_target_vectors, lvl_combo).detach().clone()
            
    @log_io
    def forward(self, 
                xfs: Tuple[torch.Tensor], 
                target_token_ids: torch.Tensor, 
                input_len: int,
                embedder: nn.Module,
               ) -> torch.Tensor:
        # iterate through all concept-embedding layers and calculate loss
        # start the running total on the same device as the model outputs
        concept_loss = torch.tensor(0.0, device=xfs[0].device)
        for i in range(self.levels - 1):
            # select our relevant final residual state
            lvl_output = xfs[i]
            
            # calculate the decay value placed on this level's total amount of loss
            lambadada = (self.level_loss_weight ** (self.levels -1 -i))
            
            # create the target vectors for this level
            target_vectors = self.create_target_vecs(lvl_output,
                                                     target_token_ids,
                                                     input_len,
                                                     embedder,
                                                     i)
            
            # now choose whether we're doing mse/mae vs cos sim loss bc they get calc'd differently
            # (check the loss module picked in __init__ rather than reading the global config)
            if isinstance(self.concept_loss_fn, (nn.L1Loss, nn.MSELoss)):
                # Reshape output and target_vectors to combine batch and seq_len dimensions
                lvl_output_flat = lvl_output.reshape(-1, lvl_output.size(-1))
                target_vectors_flat = target_vectors.reshape(-1, target_vectors.size(-1))
                
                # Calculate MSE or MAE & add it to the total
                concept_loss = concept_loss + self.concept_loss_fn(lvl_output_flat, target_vectors_flat) * lambadada
            else: # defaults to cosine similarity loss
                cosine_loss = (1 - self.concept_loss_fn(lvl_output, target_vectors)).mean()
                concept_loss = concept_loss + cosine_loss * lambadada

        return concept_loss
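
The practical difference between the cosine option and MSE/MAE is that cosine loss only looks at direction, not magnitude. A small sketch of that, using made-up vectors (the tensors and the `cosine_loss` helper are illustrative, not part of the class):

```python
import torch
import torch.nn as nn

cos = nn.CosineSimilarity(dim=-1, eps=1e-6)
mse = nn.MSELoss()

def cosine_loss(pred: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
    # same formula as the default branch above: 1 - cosine similarity, averaged
    return (1 - cos(pred, tgt)).mean()

target = torch.tensor([[[1.0, 2.0, 3.0]]])  # (batch, seq, dim)
same_direction = 10.0 * target               # right direction, 10x the norm
opposite = -target                           # exactly the wrong direction

loss_same = cosine_loss(same_direction, target).item()
loss_opp = cosine_loss(opposite, target).item()
mse_same = mse(same_direction, target).item()
print(loss_same, loss_opp, mse_same)
```

Here the cosine loss is roughly 0 for the rescaled vector and roughly 2 for the flipped one, while MSE penalizes the rescaled vector heavily. Which behaviour is preferable for concept vectors built by summing embeddings is an open question in this notebook.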

### demonstration/debugging

In [None]:
hold = config.levels
config.levels = 3 # i'm too lazy to make this part dynamic so let's just set this to 3
embedding_combiner = CombineEmbeddings(config.embed_dim)
module = matryoshkaConceptLoss(config, embedding_combiner)
module.enable_logging()
#module.embedding_combiner.enable_logging()

embedder = nn.Embedding(config.vocab_size, config.embed_dim)
target_token_ids = torch.randint(config.vocab_size, (32, config.max_seq_len + (config.combo**(config.levels-1))))
x = torch.randn(32, config.max_seq_len, config.embed_dim // (config.combo**(config.levels-1)))
c2 = torch.randn(32, config.max_seq_len // config.combo, config.embed_dim // config.combo)
c4 = torch.randn(32, config.max_seq_len // (config.combo**(config.levels-1)), config.embed_dim)
xfs = (c4, c2, x)
output = module(xfs, target_token_ids, config.max_seq_len, embedder)
config.levels = hold
del hold, embedding_combiner, module, embedder, target_token_ids, x, c2, c4, xfs, output
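
As a sanity check on the indexing in `create_target_vecs` and `forward`, the per-level arithmetic can be tabulated without any tensors. This is a sketch under assumed values (`combo=4`, `levels=3`, `level_loss_weight=0.5`, `input_len=240`); `concept_level_plan` is a hypothetical helper, not part of the notebook.

```python
from math import ceil

def concept_level_plan(combo: int, levels: int, level_loss_weight: float, input_len: int):
    """List, for each concept level, the target slice, group count and loss weight."""
    plan = []
    for i in range(levels - 1):  # xfs[i], highest concept level first
        lvl_combo = combo ** (levels - 1 - i)
        plan.append({
            "xfs_index": i,
            "lvl_combo": lvl_combo,
            # columns of target_token_ids used for this level
            "target_slice": (lvl_combo, lvl_combo + input_len),
            # number of combined target vectors after grouping
            "num_targets": ceil(input_len / lvl_combo),
            "weight": level_loss_weight ** (levels - 1 - i),
        })
    return plan

plan = concept_level_plan(combo=4, levels=3, level_loss_weight=0.5, input_len=240)
for row in plan:
    print(row)
```

With these values the top level reads target columns 16 to 256, which is why `target_token_ids` needs `combo ** (levels-1)` extra tokens beyond `max_seq_len`.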

# Model

In [16]:
class NCP_MatFormer(nn.Module):
    def __init__(self, config: Config, tokenizer: tokenizer):
        super().__init__()
        self.config = config
        self.tokenizer = tokenizer
        
        ### general hyperparameters
        self.max_seq_len = config.max_seq_len
        self.sa_head_dim = config.sa_head_dim
        self.vocab_size = config.vocab_size
        self.embed_dim = config.embed_dim
        
        ### matryoshka hyperparameters
        self.combo = config.combo
        self.levels = config.levels
        self.split = config.split
        
        ### cross-attention
        self.ca_connections = config.ca_connections
        self.ca_interval = config.ca_interval
        
        ### embedding
        # the embedding matrix. for converting tokens to the first residual state, and the last residual state to logits
        self.embedder = nn.Embedding(config.vocab_size, config.embed_dim)
        # the function that combines embeddings into higher level concept residual states
        self.embedding_combiner = CombineEmbeddings(config.embed_dim)
        self.norm_init_embed = config.norm_init_embed # whether to norm the first residual states
        if config.norm_init_embed: # if so, define its norm
            self.init_embed_norm = Norm(config)
        self.scale_init_embed = config.scale_init_embed # whether to scale first residuals by nth root of respective matryoshka dimension
        self.init_scale_degree = config.init_scale_degree # the nth root to take
        # should we norm the final embed before calculating logits? idk i can't decide
        self.norm_final_embed = config.norm_final_embed
        if config.norm_final_embed:
            self.final_embed_norm = Norm(config)

        # the actual bulk of the model
        self.body = Body(config, self.embedder)
                
        ### the loss functions
        # lowest-level token model
        self.ce_loss_fn = nn.CrossEntropyLoss()
        # concept models
        self.concept_loss_fn = matryoshkaConceptLoss(config, self.embedding_combiner)
        
        self.logging_enabled = False
    def enable_logging(self):
        self.logging_enabled = True
    def disable_logging(self):
        self.logging_enabled = False
        
    @log_io
    def create_x0s(self, input_token_ids: torch.Tensor, training: bool = False) -> Tuple[Tuple[torch.Tensor, ...], int]:
        # turn the input tokens into the first residual state using the embedding matrix
        x0 = self.embedder(input_token_ids) # (batch_size, input_len, embed_dim)
        
        ### prepping our first token-wise residual state
        # find what matryoshka dimension it should be
        first_dim = self.embed_dim // (self.split ** (self.levels-1)) 
        # splice to said matryoshka dimension
        x0t = x0[...,:first_dim] # t stands for token. need to separate it from the raw x0 which will be used later
        # optionally norm it
        if self.norm_init_embed: x0t = self.init_embed_norm(x0t, training) 
        # optionally scale by n'th root of dimension
        if self.scale_init_embed: x0t = x0t * (first_dim ** (1/self.init_scale_degree)) 
        # finally instantiate the tuple that'll hold all the residual states
        x0s = (x0t,) 
        
        ### iterating through levels to create each higher-level concept residual state
        for i in range(self.levels-1):
            # combine into smaller tensor by adding token (or lower level concept) embeddings together
            lvl_combo = self.combo ** (i+1)
            x0c = self.embedding_combiner(x0, lvl_combo) # c stands for concept
            
            # find the correct matryoshka dimension for this level
            this_level_dim = self.embed_dim // (self.split ** (self.levels - 2 - i))
            # splice to said matryoshka dimension
            x0c = x0c[...,:this_level_dim]
            # optionally norm it
            if self.norm_init_embed: x0c = self.init_embed_norm(x0c, training) 
            # optionally scale by n'th root of dimension
            if self.scale_init_embed: x0c = x0c * (this_level_dim ** (1/self.init_scale_degree)) 
            # finally add it to the tuple of residual states
            x0s += (x0c,)
        
        return x0s, first_dim
        
    @log_io
    def forward(
        self,
        input_token_ids: torch.Tensor, # a shape (batch_size, input_seq_len) list of integer token ids to run forward pass on
        target_token_ids: torch.Tensor = None, # a shape (batch_size, input_seq_len + combo ** (levels-1)) list of token ids to train on
        ) -> Tuple[torch.Tensor, torch.Tensor]: # (logits, loss); loss is None when no targets are given
        training = target_token_ids is not None

        # create the tuple of initial residual states to calculate on
        x0s, first_dim = self.create_x0s(input_token_ids, training) # also, first_dim will be used later

        # the body of the model that iterates through the decoder & cross-attention layers
        xfs = self.body(x0s, training)

        # grabbing the weights of the embedding matrix shape (vocab_size, embed_dim) for use as the output layer
        embedder_weight = self.embedder.weight
        # optionally norm it
        if self.norm_final_embed: embedder_weight = self.final_embed_norm(embedder_weight, training)
        # calculating output logits
        logits = xfs[-1] @ embedder_weight[:,:first_dim].t()
        
        if target_token_ids is None: # if we're not training, then we don't need to calculate loss
            loss = None
        else: # if we are training
            ### first up is regular CE token loss
            batch_size, input_len, vocab_size = logits.shape
            # splice target tokens to exclude the ones that were only to be used by concept levels
            target_token_ids_spliced = target_token_ids[:,:input_len]
            # we reshape our logits & targets before calculating cross-entropy loss
            ce_loss = self.ce_loss_fn(logits.view(batch_size*input_len, vocab_size),
                                      target_token_ids_spliced.reshape(batch_size*input_len))

            ### the new thing, a regression loss for all our concept-embedding layers
            concept_loss = self.concept_loss_fn(xfs, target_token_ids, 
                                                input_len, self.embedder)
            loss = ce_loss + concept_loss
            
        return logits, loss

    @torch.no_grad() # no need to keep track of gradients during inference
    @log_io
    def Sampler(
        self,
        logits: torch.Tensor, # shape (batch_size, input_len, vocab_size)
        temperature: float, # controls how boring vs random the outputs should be
        top_p: float, # the maximum cumulative probability of output options we're willing to consider
        top_k: int, # the maximum number of output options we're willing to consider
    ) -> torch.Tensor:
        # Select the last element for each sequence & apply temperature scaling
        logits = logits[:,-1,:].div_(temperature) # -> (batch_size, vocab_size)

        # Calculate probabilities with softmax.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float) # dim=-1 is the vocab_size dimension that we calculate along

        # sort the probabilities for use in top-p & top-k. both are (batch_size, vocab_size)
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)

        ### calculating top-p
        # creates same-size tensor of cumulative probabilities instead of individual probs
        probs_sum = torch.cumsum(probs_sort, dim=-1) 
        # mask where 0's are top-p selections & 1's are to be excluded
        top_ps_mask = (probs_sum - probs_sort) > top_p
        # the original probabilities with excluded tokens changed to 0.0
        probs_sort = torch.where(top_ps_mask, 0, probs_sort) 

        ### calculating top_k
        # create a shape (vocab_size) tensor that just iterates up by 1's
        top_ks_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device) 
        # expand our mask along the batch_size dimension to become size (batch_size, vocab_size)
        top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
        # top_k is a single integer. entries whose rank in the sorted order is >= top_k get marked True, meaning excluded
        top_ks_mask = top_ks_mask >= top_k

        # we'll be combining top-p with top-k and using whichever gives us fewer tokens. a very conservative approach
        # this trims probs_sort to also fit within our top_k requirement
        probs_sort = torch.where(top_ks_mask, 0, probs_sort)

        # Re-normalization so that total probabilities add up to 1
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        
        # now we rearrange the modified probabilities in probs_sort back to their original order according to probs_idx
        probs = torch.gather(probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
        
        # samples from the distribution
        next_token_id = torch.multinomial(probs, num_samples=1)
        
        return next_token_id # returns the predicted token
        
    @torch.no_grad() # the forward passes inside generate() don't need gradients either
    @log_io
    def generate(self,
                 prompt: str,
                 output_len: int = 1, # the model will output 1 token by default
                 temperature: float = 0.7, # 1.0 would be no effect
                 top_p: float = 0.8,
                 top_k: int = 4,
                ) -> str: 
        """ Wrapper around Sampler() that deals with manipulation of the sequence """
        # encoding the prompt into token indices
        tokens = self.tokenizer.encode(prompt)

        # turning it into the right tensor shape: (1, prompt_len)
        tokens = torch.tensor(tokens, device=self.config.device).unsqueeze(0)
        
        # we wouldn't want to go past the maximum context length we trained on
        # note: len(tokens) would return the batch dimension (1), so read the sequence dimension instead
        prompt_len = tokens.shape[1]
        if prompt_len + output_len > self.max_seq_len:
            output_len = self.max_seq_len - prompt_len
            print("capping output at maximum sequence length")

        for i in range(output_len):
            # get the model's output logits and ignore the loss, which would be a NoneType object
            logits, _ = self(tokens[:,:self.max_seq_len])
            
            next_token = self.Sampler(logits, temperature, top_p, top_k)

            # add our new token to the sequence
            tokens = torch.cat((tokens, next_token), dim=1)

        # decode our list of tokens to an actual string
        return self.tokenizer.decode(tokens.squeeze(0).tolist())
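
The top-p and top-k filtering inside `Sampler` is easier to follow on a hand-checkable distribution. The sketch below pulls just that step into a standalone function; `filter_top_p_top_k` and the example probabilities are made up for illustration, and it stops before the multinomial draw so the result is deterministic.

```python
import torch

def filter_top_p_top_k(probs: torch.Tensor, top_p: float, top_k: int) -> torch.Tensor:
    """Zero out everything outside both the top-p nucleus and the top-k set, then renormalize."""
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # exclude an entry once the mass *before* it already exceeds top_p
    top_p_mask = (probs_sum - probs_sort) > top_p
    # exclude entries ranked at position top_k or later
    ranks = torch.arange(probs.shape[-1], device=probs.device).expand_as(probs_sort)
    top_k_mask = ranks >= top_k
    probs_sort = probs_sort.masked_fill(top_p_mask | top_k_mask, 0.0)
    probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
    # undo the sort so each probability sits at its original vocabulary index
    return torch.gather(probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))

probs = torch.tensor([[0.1, 0.5, 0.05, 0.3, 0.05]])
filtered = filter_top_p_top_k(probs, top_p=0.7, top_k=4)
print(filtered)  # only indices 1 and 3 survive: 0.625 and 0.375
```

Sorted, the probabilities are 0.5, 0.3, 0.1, ... and the mass preceding the third entry is 0.8, which exceeds 0.7, so top-p keeps two tokens even though top-k would allow four. Whichever constraint is stricter wins.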

### demonstration/debugging

In [None]:
module = NCP_MatFormer(config, tokenizer)
module.enable_logging()
#module.embedding_combiner.enable_logging()
#module.init_embed_norm.enable_logging()
#module.body.enable_logging()
#module.final_embed_norm.enable_logging()
#module.concept_loss_fn.enable_logging()
input_token_ids = torch.randint(config.vocab_size, 
                                (32, config.max_seq_len))
target_token_ids = torch.randint(config.vocab_size, 
                                 (32, config.max_seq_len + (config.combo ** (config.levels-1))))
output, loss = module(input_token_ids, target_token_ids)
del module, input_token_ids, target_token_ids, output, loss
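
The shapes that `create_x0s` produces follow directly from `embed_dim`, `split`, `combo` and `levels`. A tensor-free sketch of that arithmetic (the helper `level_shapes` is hypothetical; the first call uses values inferred from the logged shapes `[32, 240, 128]` and `[32, 60, 256]`):

```python
def level_shapes(embed_dim: int, split: int, combo: int, levels: int, seq_len: int):
    """Return (sequence_length, width) for each level, token level first."""
    shapes = []
    for level in range(levels):
        # width shrinks by a factor of `split` for every level below the top
        dim = embed_dim // (split ** (levels - 1 - level))
        # sequence shrinks by a factor of `combo` per level (ceil because of front-padding)
        n_positions = -(-seq_len // (combo ** level))
        shapes.append((n_positions, dim))
    return shapes

print(level_shapes(embed_dim=256, split=2, combo=4, levels=2, seq_len=240))
# [(240, 128), (60, 256)]
print(level_shapes(embed_dim=256, split=2, combo=4, levels=3, seq_len=240))
# [(240, 64), (60, 128), (15, 256)]
```

So lower levels are long and narrow, higher levels are short and wide, and only the top level uses the full embedding width.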

# Instantiate a brand new model

In [17]:
model = NCP_MatFormer(config, tokenizer).to(config.device)

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

print(model)

5673.216 K parameters
NCP_MatFormer(
  (embedder): Embedding(128, 256)
  (embedding_combiner): CombineEmbeddings()
  (init_embed_norm): Norm()
  (final_embed_norm): Norm()
  (body): Body(
    (layers): ModuleList(
      (0-5): 6 x Layer(
        (mqa): selfMQA()
        (mlp): MLP()
        (pre_mqa_norm): Norm()
        (post_mqa_norm): Norm()
        (pre_mlp_norm): Norm()
        (post_mlp_norm): Norm()
      )
    )
    (ca_layers): ModuleList(
      (0-5): 6 x crossLayer(
        (mqa): crossMQA()
        (pre_mqa_norm_x): Norm()
        (pre_mqa_norm_c): Norm()
        (post_mqa_norm): Norm()
      )
    )
    (final_norm): Norm()
  )
  (ce_loss_fn): CrossEntropyLoss()
  (concept_loss_fn): matryoshkaConceptLoss(
    (embedding_combiner): CombineEmbeddings()
    (concept_loss_fn): CosineSimilarity()
  )
)


# Training

In [18]:
# Train and test splits
data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be our training dataset, the rest for validation
train_data = data[:n]
val_data = data[n:]

In [23]:
# data loading for training which generates a small batch of data of inputs x and targets y
def get_batch(split, batch_size):
    # whether we grab from our training or validation dataset
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - (config.max_seq_len + (config.combo ** (config.levels-1))), (batch_size,))
    x = torch.stack([data[i:i+config.max_seq_len] for i in ix])
    # y is shifted one token ahead of x and extended by config.combo ** (config.levels-1) extra tokens,
    # so that the highest concept level has a full group of future tokens to predict
    y = torch.stack([data[i+1:i+1+(config.max_seq_len + (config.combo ** (config.levels-1)))] for i in ix])
    x, y = x.to(config.device), y.to(config.device)
    return x, y
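
To see how the inputs and the extended targets line up, here is a toy version of the same sampling logic run on `torch.arange` data, where every value equals its own position. The function `toy_get_batch` and the small sizes are assumptions for illustration only.

```python
import torch

def toy_get_batch(data: torch.Tensor, max_seq_len: int, combo: int, levels: int, batch_size: int):
    extra = combo ** (levels - 1)  # future tokens needed by the top concept level
    ix = torch.randint(len(data) - (max_seq_len + extra), (batch_size,))
    x = torch.stack([data[i:i + max_seq_len] for i in ix])
    y = torch.stack([data[i + 1:i + 1 + max_seq_len + extra] for i in ix])
    return x, y

data = torch.arange(100)  # value == position, so offsets are easy to read
x, y = toy_get_batch(data, max_seq_len=8, combo=2, levels=3, batch_size=4)
print(x.shape, y.shape)  # torch.Size([4, 8]) torch.Size([4, 12])
print(x[0].tolist())
print(y[0].tolist())
```

The first 8 columns of `y` are `x` shifted by one token (the usual next-token targets), and the remaining `combo ** (levels-1) = 4` columns are the extra lookahead that only the concept levels consume.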

In [24]:
@torch.no_grad()
def estimate_loss(model, batch_size, eval_iters = 5): # to estimate loss during the training loop
    out = {}
    model.eval() # sets model to eval mode
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, batch_size)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train() # just resets to training mode
    return out

In [25]:
# create a PyTorch optimizer
# this is not what they used, but this learning rate & weight decay work for our tiny minGemma
config.learning_rate = 1e-5
config.weight_decay = 0.02
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)

# how long we want to train for
config.max_iters = 2

# how often we want to check & see how our loss is doing
eval_interval = 1

# batch size to use
config.batch_size = 32

In [26]:
start_time = time.time()

# Enable anomaly detection. uncomment these lines if you need to do extensive debugging
#torch.autograd.set_detect_anomaly(True)

for iter in range(config.max_iters):

    # sample a batch of data
    xb, yb = get_batch('train', config.batch_size)
    
    # train
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    
    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == config.max_iters - 1:
        current_time = time.time()
        elapsed_time = current_time - start_time
        losses = estimate_loss(model, config.batch_size)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time elapsed: {elapsed_time:.2f} seconds")

# Disable anomaly detection after the training loop
#torch.autograd.set_detect_anomaly(False)

step 0: train loss 6.5168, val loss 6.5191, time elapsed: 1.88 seconds
step 1: train loss 6.5127, val loss 6.5208, time elapsed: 9.99 seconds


# Saving your model

In [23]:
name = f'models/{model.__class__.__name__}_{time.strftime("%Y-%m-%d|%H-%M-%S")}'
torch.save(model.state_dict(), f'{name}.pth')

# Convert the dataclass object to a dictionary
config_dict = asdict(config)

# Serialize the dictionary to a JSON file
with open(f'{name}.json', 'w') as f:
    json.dump(config_dict, f)
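
The save/load cells rely on the config dataclass surviving a round trip through JSON. A minimal sketch of that round trip, using a hypothetical `ToyConfig` in place of the notebook's `Config` and a temporary directory instead of `models/`:

```python
import json
import os
import tempfile
from dataclasses import dataclass, asdict

@dataclass
class ToyConfig:  # hypothetical stand-in for the notebook's Config
    embed_dim: int = 256
    combo: int = 4
    levels: int = 3
    concept_loss: str = "cos"

cfg = ToyConfig(levels=2)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_config.json")
    # dataclass -> dict -> JSON file
    with open(path, "w") as f:
        json.dump(asdict(cfg), f)
    # JSON file -> dict -> dataclass
    with open(path, "r") as f:
        restored = ToyConfig(**json.load(f))

print(restored == cfg)  # True
```

This only works cleanly when every field is a JSON-friendly type (numbers, strings, booleans, lists). Tuples, for example, come back as lists, so a config with such fields would need converting after loading.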

# Load a Pretrained Model

In [None]:
name = 'NCP_MatFormer_2024-03-27|01-39-19'

# Deserialize the JSON file back to a dictionary
with open(f'models/{name}.json', 'r') as f:
    config_dict = json.load(f)

# Convert the dictionary back to a dataclass object
config = Config(**config_dict)

# Initialize a blank model
model = NCP_MatFormer(config, tokenizer).to(config.device)  

# here's the path to a minGemma model that i've trained with roughly 1m parameters
path = f'models/{name}.pth'

# Load the saved state dictionary
model.load_state_dict(torch.load(path, map_location=config.device)) # map_location lets weights saved on another device load here
# REMEMBER TO CHANGE VALUES IN CONFIG TO MATCH THE MODEL YOU'VE LOADED

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

# If you only plan to do inference, switch to evaluation mode
model.eval()

# If you plan to continue training the model, switch to training mode
#model.train()

# Inference

In [24]:
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou " # the classic line
max_useable_output_len = config.max_seq_len - len(tokenizer.encode(input_str)) # count tokens, not characters
output = model.generate(input_str, output_len = max_useable_output_len)
print(output)

JULIET:
O Romeo, Romeo! wherefore art thou  s thy s to to, s su te thut sots, so tearts susts s,  e  ss triss sis,  tr ts  sut ,  t  so,  so,  ,  t  s, ti  s   tearss,  trus tr e t  s sist s  so, t t tess, s ,   , sost, shouous sit,  sif sotttset, es,  sis ttett  sur sos,  eart, t t


In [None]:
len(output)