# v11.3

whereas the last version used matryoshka models, i think it makes more sense to go all-in on the self-similarity thing and simplify the whole process by using the same decoder for all levels

todo's:
- [x] copy & paste stuff that doesn't need to be changed
- [x] copy, paste, & remove references to matryoshka embeddings for modules w/ em
- [x] bring over todo items that haven't been completed from v11.2 and from notes app
- [ ] ~~setup single attention mechanism module for both self & cross?~~ i feel like separate is prolly still easier. and this is kinda just unnecessary extra work
- [x] setup single decoder layer body
- [x] setup concept loss to use MSE and COS simultaneously
- [ ] get predictive attention mask working in crossMQA
- [x] add different types of pooling options to embedding combiner
- [ ] pull concept loss out of model and into training loop so that i can progressively introduce the concept embedding prediction problem rather than deal with moving goalposts from the get-go
- [x] make option for linear layer vs MLP in output from final concept residual states to actual concepts
- [ ] write an inference algorithm that sacrifices quality for efficiency & see how well it works
- [ ] ~~add multiple <|bos|> tokens to put at the beginning of each sequence, one for each level. So that way the model can know what level it's working with and have somewhere to sink attention.~~ the whole "language is fractal" thing means that the model shouldn't need to know. also if an MLP or linear is used for combining concepts then this would be happening implicitly anyways. better to just let a future better tokenizer do its job. 
- [x] make get_batch output values that'd result in actually using the padding token so that it gets trained
- [ ] ~~switch to GaLore optimizer to save ram~~
- [x] implement basic stuff i'm missing like cosine learning rate scheduling (gradient clipping?)
- [ ] move everything into .py files without the LoggingModule parent
- [ ] create big hyperparameter testing script to send to ben
- [ ] would it make sense to have the RoPE theta adjust for the different sequence levels? like divide it by config.combine_factor? i really need to learn more about how RoPE works at an intuitive level rather than just "uuuuuh trig rotations"
- [ ] ~~maybe add learned positional encodings to pooling?~~ nah ur doing too much
- [ ] get the current version running & training so that i can replace TinyShakespeare with TinyStories and make corresponding edits to the dataloader
- [x] make concepts defined recursively
- [ ] is the mlp / linear layer / affine on the norm that gets used for each level->up combination shared across level bridges? If not it should be. or there should at least be an option to make it so.
- [ ] setup inference concept selection

      - [ ] add option to select based on cos, mae, and/or mse

further ideas likely for v11.4:
- [ ] setup effective_seq_len_mult and seq_len_list and a bunch of downstream stuff like weird batching at lower sub-levels to allow for hella long effective context lengths. i think what i do here is setup batches such that for a given context length T the concept level above will see nT before and after. i guess not really nT before and after, rather imagine if the context length of the token level model is T and the context length of the concept leel is nT then the token level will attend to a random splice of length T, and ofc it's prolly gonna be super annoying adapting the predictive mask to work with this
- [ ] figure out a way to make concepts dynamic instead of predetermined length. I might be able to do something similar to [Quiet-STAR](https://arxiv.org/abs/2403.09629) to get this to work in a way that's not exactly what i was expecting. Like allow the model to generate multiple outputs and if those outputs help in the actual concept vector generation then keep them in the sequence

# Setup

In [1]:
# my virtual environments are rarely properly connected to jupyter so this fixes that
import sys
import os
current_dir = os.getcwd()  # Get the current working directory
venv_dir = os.path.join(current_dir, 'venv') 
python_version = str(sys.version_info.major) + '.' + str(sys.version_info.minor)
site_packages_path = os.path.join(venv_dir, 'lib', 'python' + python_version, 'site-packages')
sys.path.append(site_packages_path) 

In [2]:
# Importing pytorch
import torch
import torch.nn as nn
from torch.nn import functional as F

# imports for the debugging/demonstration setup
import functools
import inspect

# imports for the tokenizer
from tiny_stories_tokenizer import *

# Imports used for the config
from dataclasses import dataclass
from typing import Optional

# Imports used for the model
import re
from typing import Any, List, Sequence, Tuple, Union

# used for training
import random
import time

# used to save & load models
import json
from dataclasses import asdict

# Demonstration/Debugging wrapper & Module

In [3]:
# this function will be used throughout for debugging/demonstration purposes
# using this is way cleaner than cluttering up our code with print statements
def log_io(func):
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # Check if logging is enabled globally and for the specific function
        if not self.logging_enabled or func.__name__ in self.disabled_logging_functions:
            return func(self, *args, **kwargs)
        #if not self.logging_enabled:
            #return func(self, *args, **kwargs)

        def log_item(item, name, level=0, is_root=False):
            indent = "    " * level
            if isinstance(item, torch.Tensor):
                print(f"{indent}Tensor '{name}' shape: {item.shape}")
            elif isinstance(item, tuple):
                if is_root and level == 0:
                    # Root level tuple, don't print it as a tuple unless it's a "true" tuple
                    for idx, sub_item in enumerate(item):
                        log_item(sub_item, f"{name}[{idx}]", level)
                else:
                    print(f"{indent}Tuple '{name}':")
                    for idx, sub_item in enumerate(item):
                        log_item(sub_item, f"{name}[{idx}]", level + 1)
            elif isinstance(item, int):
                print(f"{indent}Integer '{name}': Value={item}")
            elif isinstance(item, float):
                print(f"{indent}Float '{name}': Value={item}")
            else:
                print(f"{indent}Other-type '{name}': Type={type(item).__name__}, Value={item}")

        print(f"\n{'='*10}Entering {self.__class__.__name__}.{func.__name__}{'='*10}")
        print("Inputs:")
        arg_names = inspect.getfullargspec(func).args[1:]  # Excluding 'self'
        arg_values = args + tuple(kwargs.values())
        for name, value in zip(arg_names, arg_values):
            log_item(value, name)

        result = func(self, *args, **kwargs)
        print("\nOutputs:")
        if isinstance(result, tuple):
            log_item(result, "output", is_root=True)
        else:
            log_item(result, "output")

        print(f"{'='*10}Exiting {self.__class__.__name__}.{func.__name__}{'='*10}")
        return result
    return wrapper

class LoggingModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.logging_enabled = False
        self.disabled_logging_functions = set()

    def enable_logging(self):
        self.logging_enabled = True

    def disable_logging(self):
        self.logging_enabled = False

    def disable_function_logging(self, func_name):
        self.disabled_logging_functions.add(func_name)

    def enable_function_logging(self, func_name):
        self.disabled_logging_functions.discard(func_name)

# Config & tokenizer

In [4]:
# load the dataset
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
# we'll replace this with the TinyStories dataset and the corresponding funcitons below once we've got a regular running model again

# and the tokenizer
tokenizer = get_tokenizer(512)

In [5]:
@dataclass # a class meant specifically to just hold data
class Config:
    """ 
    The default configuration & hyperparameters for my next-concept predictor
    """
    ### boring hyperparameters
    vocab_size: int = tokenizer.vocab_len
    max_seq_len: int = 128
    embed_dim: int =128
    num_layers: int = 8
    sa_q_heads: int = 2
    sa_kv_heads: int = 1
    #attn_bias: bool = False
    sa_head_dim: int = 32
    mlp_multiplier: int = 2
    theta: float = 100.0
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dropout_rate: float = 0.1
    eps = 1e-6
    norm_type = "RMSNorm"  # Options are RMSNorm, CosineNorm and LayerNorm. defaults to RMSNorm
    norm_affine: bool = True # whether norms should have a linear & bias after them. likely less necessary if you're using RMSNorm
    learning_rate_init: float = 1e-1
    learning_rate_end: float = 1e-4
    weight_decay: float = 0.01
    max_iters: int = 1000

    ### Concept embedding vectors
    levels: int = 3
    combine_factor: int = 4 # how many lower-level tokens/concepts to combine into the next level's concept
    combine_type: str = 'reshape->mlp_post_reshape->norm' # options to combine w/ '->' are 'sum', 'mean', 'max', 'linear', 'mlp', 'reshape', 'linear_post_reshape', and 'mlp_post_reshape'
    @property
    def seq_len_list(self):
        return [(self.max_seq_len // (self.combine_factor ** (i-1))) for i in range(1, self.levels + 1)]

    ### Dualcoder cross-attention
    ca_q_heads: int = sa_q_heads
    ca_kv_heads: int = sa_kv_heads
    ca_head_dim: int = sa_head_dim
    ca_use_RoPE: bool = False # True: expands out k & v tensors to be usable with rope. False: leaves k & v same size but no positional encodings
    predictive_mask: bool = False # True: upper-triangular predictive mask to focus model's attention (not currently working). False: no mask like a regular encoder
    predictive_mask_noise: float = None # float: sd of noise to add to predictively masked concept vecs. None: don't implement noise

    ### concept output
    output_layer = 'mlp' # options are 'linear' and 'mlp' which uses the default mlp_multiplier
    # how much to discount each higher level in the loss function compared to the last
    level_loss_weight: float = 1.0 
    # multiple losses can act on the concept vectors at once
    cos_loss: bool = True
    mse_loss: bool = True
    mae_loss: bool = True
    
    ### assertions
    assert sa_q_heads % sa_kv_heads == 0, 'the number of query heads must be divisible by the number of key-value heads in self-attention'
    assert ca_q_heads % ca_kv_heads == 0, 'the number of query heads must be divisible by the number of key-value heads in cross-attention'     
        
config = Config()
print(config)
print(f"sequence length of each model: {config.seq_len_list}")
print(f"loss discounts starting from lowest level: {[config.level_loss_weight**i for i in range(config.levels)]}")

Config(vocab_size=512, max_seq_len=128, embed_dim=128, num_layers=8, sa_q_heads=2, sa_kv_heads=1, sa_head_dim=32, mlp_multiplier=2, theta=100.0, dropout_rate=0.1, norm_affine=True, learning_rate_init=0.1, learning_rate_end=0.0001, weight_decay=0.01, max_iters=1000, levels=3, combine_factor=4, combine_type='reshape->mlp_post_reshape->norm', ca_q_heads=2, ca_kv_heads=1, ca_head_dim=32, ca_use_RoPE=False, predictive_mask=False, predictive_mask_noise=None, level_loss_weight=1.0, cos_loss=True, mse_loss=True, mae_loss=True)
sequence length of each model: [128, 32, 8]
loss discounts starting from lowest level: [1.0, 1.0, 1.0]


# RoPE

In [6]:
def RoPE(x: torch.Tensor, dim: int, theta: float = 10000.0) -> torch.Tensor:
    """Applies the rotary embedding to the inputted query or key tensor"""
    # Validate that dim is even since we split it by 2 for real and imaginary parts
    if dim % 2 != 0: raise ValueError("Dimension 'dim' must be an even number.")
            
    # Get sequence length
    seq_len = x.size(1)
    device = x.device

    # Dynamically compute frequency cis based on the input sequence length
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
    t = torch.arange(seq_len, device=device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    # it's important to train on a wide variety of sequence lengths within your context length so that the model learns to generalize

    # Apply rotary embeddings to the input tensor
    x_ = torch.view_as_complex(torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
    x_out = torch.view_as_real(x_ * freqs_cis.unsqueeze(0)).type_as(x)  # Ensure batch dimension is handled
    x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
    x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], -1).transpose(1, 2)

    return x_out

## maybe precompute frequencies later

i like the idea of pre-computing RoPE embeddings but at this point i don't think it's worth the effort bc i'd have to not only use this code but also pipe the two instantiations of this class through from `Body` all the way to `selfMQA` and `crossMQA` and I'm not even sure if it matters. I really should learn more about RoPE

```Python
# this class has not been implemented nor even tested. on my todo list
class RoPE(nn.Module):
    def __init__(self, 
                 dim: int, 
                 max_seq_len:int = config.max_seq_len, 
                 device: str = config.device):
        super().__init__()
        # Validate that dim is even since we split it by 2 for real and imaginary parts
        if dim % 2 != 0: raise ValueError("Dimension 'dim' must be an even number.")
            
        # Precompute frequencies based on configuration
        theta = config.theta if hasattr(config, 'theta') else 10000.0
        
        freqs = 1.0 / (theta ** (torch.arange(0, config.dim, 2, device=config.device).float() / config.dim))
        t = torch.arange(config.max_seq_len, device=config.device)
        freqs = torch.outer(t, freqs).to(config.device).float()
        
        # Register as buffer to prevent gradient tracking
        self.register_buffer('freqs_cis', torch.polar(torch.ones_like(freqs), freqs)) # complex64

    def forward(self, x):
        # Apply rotary embeddings to the input tensor
        x_ = torch.view_as_complex(torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
        x_out = torch.view_as_real(x_ * self.freqs_cis.unsqueeze(0)).type_as(x)
        x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
        x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], -1).transpose(1, 2)
        return x_out
```

# selfMQA

In [7]:
class selfMQA(LoggingModule): # notice thoughout we'll be inheriting from LoggingModule instead of nn.Module
    def __init__(self, config: Config):
        super().__init__()

        self.sa_q_heads = config.sa_q_heads
        self.sa_kv_heads = config.sa_kv_heads
        assert self.sa_q_heads % self.sa_kv_heads == 0
        self.num_queries_per_kv = self.sa_q_heads // self.sa_kv_heads

        self.embed_dim = config.embed_dim
        self.sa_head_dim = config.sa_head_dim
        self.theta = config.theta
        self.dropout_rate = config.dropout_rate

        #self.Wqkv = nn.Linear(self.embed_dim, 
        #                      (self.sa_q_heads + 2 * self.sq_kv_heads) * self.sa_head_dim, 
        #                      bias = config.attn_bias)
        #self.Wo = nn.Linear(self.sa_q_heads * self.sa_head_dim,
        #                    self.embed_dim,
        #                    bias = config.attn_bias)
        
        self.Wqkv = nn.Parameter(torch.Tensor(self.embed_dim, (self.sa_q_heads + 2 * self.sa_kv_heads) * self.sa_head_dim))
        nn.init.uniform_(self.Wqkv, -((1 / self.embed_dim) ** 0.5), (1 / self.embed_dim) ** 0.5)

        self.Wo = nn.Parameter(torch.Tensor(self.sa_q_heads * self.sa_head_dim, self.embed_dim))
        nn.init.uniform_(self.Wo, -((1 / (self.sa_q_heads * self.sa_head_dim)) ** 0.5), (1 / (self.sa_q_heads * self.sa_head_dim)) ** 0.5)

        # for our attention mask we'll create a boolean mask that'll later be turned into large negative values
        self.mask = torch.tril(torch.ones((config.max_seq_len, config.max_seq_len), dtype=torch.uint8)
                              ).view(1, 1, config.max_seq_len, config.max_seq_len).to(dtype=torch.bool)

    @log_io
    def forward(self, x: torch.Tensor, training: bool = False) -> torch.Tensor:
        # Extracts batch size and input sequence length from the hidden states tensor.
        batch_size, input_len, _ = x.shape

        # splicing our primary projection to get the correct sub-matrices
        Wq, Wk, Wv = self.weight_splicing(self.Wqkv)
        # technically self.weight_splicing has access to self.Wqkv & Wo but this way our debugger can see them

        # Applies the linear projection to the hidden state to retrieve our q, k & v projections
        xq = F.dropout(x @ Wq, p=self.dropout_rate, training=training) # also dropout if we're training
        xk = F.dropout(x @ Wk, p=self.dropout_rate, training=training)
        xv = F.dropout(x @ Wv, p=self.dropout_rate, training=training)

        # Reshapes each of the Q, K, and V tensors to separate the heads and align the dimensions for attention operations.
        xq = xq.view(batch_size, -1, self.sa_q_heads, self.sa_head_dim)
        xk = xk.view(batch_size, -1, self.sa_kv_heads, self.sa_head_dim)
        xv = xv.view(batch_size, -1, self.sa_kv_heads, self.sa_head_dim)

        # Applies rotary positional embeddings to queries and keys to incorporate positional information.
        xq, xk = self.RoPE(xq, xk)

        # If the number of KV heads is different from the number of query heads, adjusts keys and values to match the query heads count.
        if self.sa_kv_heads != self.sa_q_heads:
            xk, xv = self.match_headcount(xk, xv) # [batch_size, input_len, n_local_heads, sa_head_dim]

        # Transposes Q, K, and V tensors to align them for the batch matrix multiplication in attention calculation.
        xq = xq.transpose(1, 2) # [batch_size, n_local_heads, input_len, sa_head_dim]
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        # Calculates attention scores by performing a batch matrix multiplication between queries and keys, followed by scaling.
        logits = self.attend(xq, xk) # [batch_size, n_local_heads, input_len, input_len]
        
        # Applies the lower-triangular mask to the attention logits
        logits = self.apply_mask(logits, input_len)

        # applies values to get final output
        output = self.calc_output(logits, xv, batch_size, input_len, training) 

        # Applies the final linear projection to the attention output, mapping it back to self.embed_dim
        return F.dropout(output @ self.Wo, p=self.dropout_rate, training=training) # also dropout if we're training

    @log_io
    def weight_splicing(self, Wqkv):
        Wq, Wk, Wv = Wqkv.split([self.sa_q_heads * self.sa_head_dim,
                                 self.sa_kv_heads * self.sa_head_dim,
                                 self.sa_kv_heads * self.sa_head_dim],dim = -1)
        return Wq, Wk, Wv

    @log_io
    def RoPE(self, xq, xk):
        xq = RoPE(xq, self.sa_head_dim, self.theta)
        xk = RoPE(xk, self.sa_head_dim, self.theta)
        return xq, xk

    @log_io
    def match_headcount(self, xk, xv):
        xk = torch.repeat_interleave(xk, self.num_queries_per_kv, dim=2)
        xv = torch.repeat_interleave(xv, self.num_queries_per_kv, dim=2)
        return xk, xv

    @log_io
    def attend(self, xq, xk):
        return torch.matmul(xq, xk.transpose(2, 3)) * (self.sa_head_dim ** -0.5)
        
    @log_io
    def apply_mask(self, logits, input_len):
        return torch.where(self.mask[..., :input_len, :input_len].expand_as(logits),
                           logits,
                           torch.tensor(-1e30, device=logits.device, dtype=logits.dtype))
    
    @log_io
    def calc_output(self, logits, xv, batch_size, input_len, training):
        # Applies softmax to the logits to obtain attention probabilities
        scores = F.softmax(logits, dim=-1)

        # also applies dropout if we're training
        scores = F.dropout(scores, p=self.dropout_rate, training=training)
        
        # Computes the weighted sum of values based on the attention scores to obtain the output of the attention mechanism.
        output = scores @ xv # [batch_size, n_local_heads, input_len, sa_head_dim]

        # Reshapes the attention output to match the expected output dimensions, combining the heads back into the hidden dimension.
        return output.transpose(1, 2).contiguous().view(batch_size, input_len, -1)

## demonstration/debugging
I've setup these little snippets after each nn.Module to help you see what's happening and for my own debugging

In [8]:
# Create an instance of selfMQA
module = selfMQA(config)

# Initially, logging is disabled
# Enable logging
module.enable_logging()

### Optionally disabling printing for sub-functions
#module.disable_function_logging('weight_splicing')
#module.disable_function_logging('RoPE')
#module.disable_function_logging('match_headcount')
#module.disable_function_logging('attend')
#module.disable_function_logging('apply_mask')
#module.disable_function_logging('calc_output')

# Call the forward method - logging will occur
output = module(torch.randn(32,config.max_seq_len,config.embed_dim))#, training=True)

# Disable logging. 
# This isn't actually necessary since we won't be using this object again but that's how you'd do it
module.disable_logging()

# clearing up ram jic we're training later
del module, output


Inputs:
Tensor 'x' shape: torch.Size([32, 256, 256])

Inputs:
Tensor 'Wqkv' shape: torch.Size([256, 256])

Outputs:
Tensor 'output[0]' shape: torch.Size([256, 128])
Tensor 'output[1]' shape: torch.Size([256, 64])
Tensor 'output[2]' shape: torch.Size([256, 64])

Inputs:
Tensor 'xq' shape: torch.Size([32, 256, 2, 64])
Tensor 'xk' shape: torch.Size([32, 256, 1, 64])

Outputs:
Tensor 'output[0]' shape: torch.Size([32, 256, 2, 64])
Tensor 'output[1]' shape: torch.Size([32, 256, 1, 64])

Inputs:
Tensor 'xk' shape: torch.Size([32, 256, 1, 64])
Tensor 'xv' shape: torch.Size([32, 256, 1, 64])

Outputs:
Tensor 'output[0]' shape: torch.Size([32, 256, 2, 64])
Tensor 'output[1]' shape: torch.Size([32, 256, 2, 64])

Inputs:
Tensor 'xq' shape: torch.Size([32, 2, 256, 64])
Tensor 'xk' shape: torch.Size([32, 2, 256, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 2, 256, 256])

Inputs:
Tensor 'logits' shape: torch.Size([32, 2, 256, 256])
Integer 'input_len': Value=256

Outputs:
Tensor 'output' s

# MLP

In [8]:
class MLP(LoggingModule):
    def __init__(self,
                 embed_dim: int,
                 mlp_multiplier: int,
                 output_dim: int,
                 dropout_rate: float = 0.1):
        super().__init__()
        self.mlp_multiplier = mlp_multiplier
        self.hidden_size = embed_dim * mlp_multiplier
        self.dropout_rate = dropout_rate

        # the gate, up and down projections
        self.gate_proj = nn.Linear(embed_dim, self.hidden_size)
        self.up_proj = nn.Linear(embed_dim, self.hidden_size)
        self.down_proj = nn.Linear(self.hidden_size, output_dim)
        
    @log_io
    def forward(self, x: torch.Tensor, training: bool = False ) -> torch.Tensor:
        output = self.down_proj(F.gelu(self.gate_proj(x)) * self.up_proj(x))
        return F.dropout(output, p=self.dropout_rate, training=training)

## demonstration/debugging

In [10]:
module = MLP(config.embed_dim, config.mlp_multiplier, config.embed_dim, config.dropout_rate)
module.enable_logging()
output = module(torch.randn(32,config.max_seq_len,config.embed_dim))
del module, output


Inputs:
Tensor 'x' shape: torch.Size([32, 256, 256])

Outputs:
Tensor 'output' shape: torch.Size([32, 256, 256])


# Norms

In [9]:
class Norm(LoggingModule):
    def __init__(self, config):
        super().__init__()
        self.eps = config.eps
        self.affine = config.norm_affine
        self.dropout_rate = config.dropout_rate
        self.type = config.norm_type

        # Initialize weight and bias parameters for affine transformation
        # We start with ones for weight to keep the original scale initially, and zeros for bias.
        self.w = nn.Parameter(torch.ones(config.embed_dim))
        self.b = nn.Parameter(torch.zeros(config.embed_dim))

        # Mapping norm types to their respective methods
        self.norm_methods = {
            "CosineNorm": self.CosineNorm,
            "LayerNorm": self.LayerNorm,
            "RMSNorm": self.RMSNorm}

        # Ensure the specified norm type exists, default to RMSNorm if not found
        if self.type not in self.norm_methods:
            self.type = "RMSNorm"
            print(f'{self.type} not found. defaulting to RMSNorm')

    @log_io
    def forward(self, x: torch.Tensor, training: bool = False) -> torch.Tensor:
        # Apply normalization
        norm_method = self.norm_methods[self.type]
        x = norm_method(x)

        if self.affine: # Optionally apply the affine transformation and dropout if we're training
            x = F.dropout(x * self.w + self.b, p=self.dropout_rate, training=training)
            
        return x

    @log_io
    def CosineNorm(self, x):
        # normalize x by dividing by its L2 norm along the last dimension.
        # this places x on the unit hypersphere centered at the origin
        # Add a small constant to the denominator to avoid division by zero.
        return x / torch.norm(x, p=2, dim=-1, keepdim=True).clamp(min=self.eps)

    @log_io
    def LayerNorm(self, x): # nn.LayerNorm() exists but might as well make it from scratch if we have to do the other two
        # normalize x by subtracting by its mean then dividing by its variance
        # this places x on a hypersphere of radius sqrt(dimension) centered at the origin
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        return (x - mean) / torch.sqrt(var + self.eps)

    @log_io
    def RMSNorm(self, x):
        # normalize x by dividing by its root-mean-square along the last dimension
        # this places x on a hypersphere of radius sqrt(dimension) with no certain center
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

## demonstration/debugging

In [10]:
module = Norm(config)
module.enable_logging()

### disabling printing for sub-functions
#module.disable_function_logging('RMSNorm')
#module.disable_function_logging('LayerNorm')
#module.disable_function_logging('CosineNorm')

output = module(torch.randn(32, config.max_seq_len, config.embed_dim))
del module, output


Inputs:
Tensor 'x' shape: torch.Size([32, 128, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 64])


# crossMQA

In [10]:
class crossMQA(LoggingModule):
    def __init__(self, config: Config):
        super().__init__()

        self.ca_q_heads = config.ca_q_heads
        self.ca_kv_heads = config.ca_kv_heads
        assert self.ca_q_heads % self.ca_kv_heads == 0
        self.num_queries_per_kv = self.ca_q_heads // self.ca_kv_heads
        self.embed_dim = config.embed_dim
        self.ca_head_dim = config.ca_head_dim
        self.sa_head_dim = config.sa_head_dim # used only for an assertion to make sure sizes will fit
        self.combine_factor = config.combine_factor # used only for an assertion to make sure sizes will fit
        self.theta = config.theta
        self.use_RoPE = config.ca_use_RoPE
        self.dropout_rate = config.dropout_rate
        self.predictive_mask = config.predictive_mask

        self.Wqkv = nn.Parameter(torch.Tensor(self.embed_dim, (self.ca_q_heads + 2 * self.ca_kv_heads) * self.ca_head_dim))
        nn.init.uniform_(self.Wqkv, -((1 / self.embed_dim) ** 0.5), (1 / self.embed_dim) ** 0.5)
        
        self.Wo = nn.Parameter(torch.Tensor(self.ca_q_heads * self.ca_head_dim, self.embed_dim))
        nn.init.uniform_(self.Wo, -((1 / (self.ca_q_heads * self.ca_head_dim)) ** 0.5), (1 / (self.ca_q_heads * self.ca_head_dim)) ** 0.5)

    @log_io
    def forward(self, 
                x: torch.Tensor, # the current level tensor, sometimes a resid state full of tokens & sometimes concepts
                c: torch.Tensor, # the upper level tensor, always a resid state full of concept vecs
                training: bool = False
               ) -> torch.Tensor:
        
        # Extracts batch size and input sequence length from the hidden states tensor.
        batch_size, input_len_x, _ = x.shape
        batch_size_c, input_len_c, _ = c.shape
        assert batch_size == batch_size_c

        # splicing our projection to get the correct sub-matrices
        Wq, Wk, Wv = self.weight_splicing(self.Wqkv)

        # Applies the linear projection to the hidden state to retrieve our q, k & v projections
        xq = F.dropout(x @ Wq, p=self.dropout_rate, training=training) # also applies dropout if we're training
        ck = F.dropout(c @ Wk, p=self.dropout_rate, training=training)
        cv = F.dropout(c @ Wv, p=self.dropout_rate, training=training)

        # Reshapes each of the Q, K, and V tensors to separate the heads and align the dimensions for attention operations.
        xq = xq.view(batch_size, -1, self.ca_q_heads, self.ca_head_dim)
        ck = ck.view(batch_size, -1, self.ca_kv_heads, self.ca_head_dim)
        cv = cv.view(batch_size, -1, self.ca_kv_heads, self.ca_head_dim)

        # IF we want to use RoPE (doesn't fully make sense to i don't think)
        if self.use_RoPE:
            expand = input_len_x // input_len_c
            ck = ck.repeat_interleave(expand, dim=1) 
            cv = cv.repeat_interleave(expand, dim=1) # values need to be expanded for their use later on if we do this

            # Applies rotary positional embeddings to queries and keys to incorporate positional information.
            xq, ck = self.RoPE(xq, ck)

        # If the number of KV heads is different from the number of query heads, adjusts keys and values to match the query heads count.
        if self.ca_kv_heads != self.ca_q_heads:
            ck, cv = self.match_headcount(ck, cv) # [batch_size, input_len, n_local_heads, sa_head_dim]

        # Transposes Q, K, and V tensors to align them for the batch matrix multiplication in attention calculation.
        xq = xq.transpose(1, 2) # [batch_size, n_local_heads, input_len, ca_head_dim]
        ck = ck.transpose(1, 2)
        cv = cv.transpose(1, 2)

        # Calculates attention scores by performing a batch matrix multiplication between queries and keys, followed by scaling.
        logits = self.attend(xq, ck) # [batch_size, n_local_heads, input_len, input_len]
        
        # Optionally applies the upper-triangular mask to the attention logits
        if self.predictive_mask:
            logits = self.apply_mask(logits, input_len_x, input_len_c)

        # applies values to get final output
        output = self.calc_output(logits, cv, batch_size, input_len_x, training)

        # Applies the final linear projection to the attention output, mapping it back to d_i. 
        return F.dropout(output @ self.Wo, p=self.dropout_rate, training=training) # and dropout if we're training

    @log_io
    def weight_splicing(self, Wqkv):
        Wq, Wk, Wv = Wqkv.split([self.ca_q_heads * self.ca_head_dim,
                                 self.ca_kv_heads * self.ca_head_dim,
                                 self.ca_kv_heads * self.ca_head_dim],dim = -1)
        return Wq, Wk, Wv
        
    @log_io
    def RoPE(self, xq, xk):
        xq = RoPE(xq, self.ca_head_dim, self.theta)
        xk = RoPE(xk, self.ca_head_dim, self.theta)
        return xq, xk

    @log_io
    def match_headcount(self, xk, xv):
        xk = torch.repeat_interleave(xk, self.num_queries_per_kv, dim=2)
        xv = torch.repeat_interleave(xv, self.num_queries_per_kv, dim=2)
        return xk, xv

    @log_io
    def attend(self, xq, ck):
        return torch.matmul(xq, ck.transpose(2, 3)) * (self.ca_head_dim ** -0.5)
        
    @log_io
    def apply_mask(self, logits, input_len_x, input_len_c):
        self.mask = torch.triu(torch.ones((config.max_seq_len, input_len_c), dtype=torch.uint8)
                                  ).view(1, 1, config.max_seq_len, input_len_c).to(dtype=torch.bool)
        
        return torch.where(self.mask[..., :input_len_x, :input_len_c].expand_as(logits),
                           logits,
                           torch.tensor(-1e30, device=logits.device, dtype=logits.dtype))
    
    @log_io
    def calc_output(self, logits, cv, batch_size, input_len_x, training):
        # Applies softmax to the logits to obtain attention probabilities
        scores = F.softmax(logits, dim=-1)

        # also applies dropout if we're training
        scores = F.dropout(scores, p=self.dropout_rate, training=training)
        
        # Computes the weighted sum of values based on the attention scores to obtain the output of the attention mechanism.
        output = scores @ cv # [batch_size, n_local_heads, input_len, sa_head_dim]

        # Reshapes the attention output to match the expected output dimensions, combining the heads back into the hidden dimension.
        return output.transpose(1, 2).contiguous().view(batch_size, input_len_x, -1)

## demonstration/debugging

at one point this one had been giving me trouble so here's multiple different config.levels setups for ya

In [14]:
hold1 = config.levels
config.levels = 2
hold2 = config.predictive_mask
config.predictive_mask = False
module = crossMQA(config)
module.enable_logging()

### Optionally disabling printing for sub-functions
#module.disable_function_logging('weight_splicing')
#module.disable_function_logging('RoPE')
#module.disable_function_logging('match_headcount')
#module.disable_function_logging('attend')
#module.disable_function_logging('apply_mask')
#module.disable_function_logging('calc_output')

x0 = torch.randn(32, config.max_seq_len, config.embed_dim)
c1 = torch.randn(32, config.max_seq_len // config.combine_factor, config.embed_dim)
output = module(x0, c1)
config.levels = hold1
config.predictive_mask = hold2
del hold1, hold2, module, x0, c1, output


Inputs:
Tensor 'x' shape: torch.Size([32, 256, 256])
Tensor 'c' shape: torch.Size([32, 64, 256])

Inputs:
Tensor 'Wqkv' shape: torch.Size([256, 256])

Outputs:
Tensor 'output[0]' shape: torch.Size([256, 128])
Tensor 'output[1]' shape: torch.Size([256, 64])
Tensor 'output[2]' shape: torch.Size([256, 64])

Inputs:
Tensor 'xk' shape: torch.Size([32, 64, 1, 64])
Tensor 'xv' shape: torch.Size([32, 64, 1, 64])

Outputs:
Tensor 'output[0]' shape: torch.Size([32, 64, 2, 64])
Tensor 'output[1]' shape: torch.Size([32, 64, 2, 64])

Inputs:
Tensor 'xq' shape: torch.Size([32, 2, 256, 64])
Tensor 'ck' shape: torch.Size([32, 2, 64, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 2, 256, 64])

Inputs:
Tensor 'logits' shape: torch.Size([32, 2, 256, 64])
Tensor 'cv' shape: torch.Size([32, 2, 64, 64])
Integer 'batch_size': Value=32
Integer 'input_len_x': Value=256
Integer 'training': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 256, 128])

Outputs:
Tensor 'output' shape: torch.Size

In [15]:
hold1 = config.levels
config.levels = 3
hold2 = config.predictive_mask
config.predictive_mask = False
module = crossMQA(config)
module.enable_logging()

### Optionally disabling printing for sub-functions
#module.disable_function_logging('weight_splicing')
#module.disable_function_logging('RoPE')
#module.disable_function_logging('match_headcount')
#module.disable_function_logging('attend')
#module.disable_function_logging('apply_mask')
#module.disable_function_logging('calc_output')

c1 = torch.randn(32, config.max_seq_len // (config.combine_factor**1), config.embed_dim)
c2 = torch.randn(32, config.max_seq_len // (config.combine_factor**2), config.embed_dim)
output = module(c1, c2)
config.levels = hold1
config.predictive_mask = hold2
del hold1, hold2, module, c1, c2, output


Inputs:
Tensor 'x' shape: torch.Size([32, 64, 256])
Tensor 'c' shape: torch.Size([32, 16, 256])

Inputs:
Tensor 'Wqkv' shape: torch.Size([256, 256])

Outputs:
Tensor 'output[0]' shape: torch.Size([256, 128])
Tensor 'output[1]' shape: torch.Size([256, 64])
Tensor 'output[2]' shape: torch.Size([256, 64])

Inputs:
Tensor 'xk' shape: torch.Size([32, 16, 1, 64])
Tensor 'xv' shape: torch.Size([32, 16, 1, 64])

Outputs:
Tensor 'output[0]' shape: torch.Size([32, 16, 2, 64])
Tensor 'output[1]' shape: torch.Size([32, 16, 2, 64])

Inputs:
Tensor 'xq' shape: torch.Size([32, 2, 64, 64])
Tensor 'ck' shape: torch.Size([32, 2, 16, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 2, 64, 16])

Inputs:
Tensor 'logits' shape: torch.Size([32, 2, 64, 16])
Tensor 'cv' shape: torch.Size([32, 2, 16, 64])
Integer 'batch_size': Value=32
Integer 'input_len_x': Value=64
Integer 'training': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 64, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 

## working on crossMQA predictive mask

notice how when i do the live-splicing of the mask, only the lower level actually needs live splicing since the upper level will always be witnessing the full context length

In [11]:
combine_factor = 4
max_seq_len = 16
random_number = random.randint(0, combine_factor-1)
print(f'random_number {random_number}')
input_len = max_seq_len - random_number
h = 1
head_dim = 4
b = 1

offset = combine_factor - (input_len % combine_factor)
x = torch.randn(b, offset + input_len, h, head_dim)
print(f'x: {x.shape}\n{x}')

c = torch.randn(b, (input_len + combine_factor) // combine_factor, h, head_dim)
print(f'c: {c.shape}\n{c}')

x = x.transpose(1, 2)
c = c.transpose(1, 2)

logits = x @ c.transpose(2,3)
print(f'logits: {logits.shape}\n{logits}')

mask_init = torch.triu(torch.ones(b,h, max_seq_len // combine_factor, max_seq_len // combine_factor))#, diagonal=1)
print(f'mask_init: {mask_init.shape}\n{mask_init}')

mask_init = mask_init.repeat_interleave(combine_factor, dim=2)
print(f'mask_init: {mask_init.shape}\n{mask_init}')

mask = mask_init.view(1,1, offset + input_len, (input_len + combine_factor) // combine_factor).to(dtype=torch.bool)
print(f'mask: {mask.shape}\n{mask}')

masked_logits = torch.where(mask.expand_as(logits),#[-1,-1,logits.shape[2],logits.shape[3]]
                            logits,
                            torch.tensor(-1e30, device=logits.device, dtype=logits.dtype))
print(f'masked_logits: {masked_logits.shape}\n{masked_logits}')

scores = F.softmax(masked_logits, dim=-1)
print(f'scores: {scores.shape}\n{scores}')

random_number 0
x: torch.Size([1, 20, 1, 4])
tensor([[[[-1.7051,  0.4757,  1.3667,  2.3230]],

         [[-0.4236,  1.1244, -0.2000,  1.1232]],

         [[ 1.2987, -1.2002,  0.8761,  0.3183]],

         [[-0.5328,  0.0147,  1.2374,  1.2683]],

         [[ 1.7717,  2.0656,  0.9818, -0.2135]],

         [[ 0.2059, -0.3381, -0.3381,  0.7264]],

         [[ 0.1274,  0.1418, -0.8047, -1.3830]],

         [[ 2.6674, -0.1535,  0.0233,  0.0684]],

         [[-1.5625,  1.1472, -1.7107, -0.6834]],

         [[ 0.1548, -0.6723,  0.9502, -0.0697]],

         [[-1.1599, -0.6320,  0.7380, -0.8934]],

         [[ 1.0836,  0.3302, -0.7806, -0.3407]],

         [[ 1.4597, -0.9352,  0.2811, -0.3279]],

         [[-0.2260,  0.1323, -0.6450,  1.1478]],

         [[ 0.6980, -1.5263, -0.1852,  0.4709]],

         [[ 1.1317, -0.1370, -0.3752, -1.3378]],

         [[-0.9330,  0.6437, -0.4607,  0.0485]],

         [[-1.1409, -0.7679, -0.8752,  0.1486]],

         [[ 0.3966,  0.3267,  0.1775, -0.8950]],

     

RuntimeError: shape '[1, 1, 20, 5]' is invalid for input of size 64

In [58]:
hold1 = config.levels
config.levels = 2
hold2 = config.predictive_mask
config.predictive_mask = True
module = crossMQA(config)
module.enable_logging()

### Optionally disabling printing for sub-functions
#module.disable_function_logging('weight_splicing')
#module.disable_function_logging('RoPE')
#module.disable_function_logging('match_headcount')
#module.disable_function_logging('attend')
#module.disable_function_logging('apply_mask')
#module.disable_function_logging('calc_output')

x0 = torch.randn(32, config.max_seq_len, config.embed_dim)
c1 = torch.randn(32, config.max_seq_len // config.combine_factor, config.embed_dim)
output = module(x0, c1)
config.levels = hold1
config.predictive_mask = hold2
del hold1, hold2, module, x0, c1, output


Inputs:
Tensor 'x' shape: torch.Size([32, 256, 256])
Tensor 'c' shape: torch.Size([32, 64, 256])

Inputs:
Tensor 'Wqkv' shape: torch.Size([256, 256])

Outputs:
Tensor 'output[0]' shape: torch.Size([256, 128])
Tensor 'output[1]' shape: torch.Size([256, 64])
Tensor 'output[2]' shape: torch.Size([256, 64])

Inputs:
Tensor 'xn' shape: torch.Size([32, 64, 1, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 64, 2, 64])

Inputs:
Tensor 'xn' shape: torch.Size([32, 64, 1, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 64, 2, 64])

Inputs:
Tensor 'xq' shape: torch.Size([32, 2, 256, 64])
Tensor 'ck' shape: torch.Size([32, 2, 64, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 2, 256, 64])

Inputs:
Tensor 'logits' shape: torch.Size([32, 2, 256, 64])
Integer 'input_len': Value=256


RuntimeError: The expanded size of the tensor (64) must match the existing size (256) at non-singleton dimension 3.  Target sizes: [32, 2, 256, 64].  Tensor sizes: [1, 1, 256, 256]

# Layer

we implement cross-attention inbetween self-attention and MLP like is done in Attention is All You Need

In [11]:
class Layer(LoggingModule):
    def __init__(self, config: Config):
        super().__init__()

        self.pre_self_mqa_norm = Norm(config)
        self.self_mqa = selfMQA(config)
        
        self.pre_cross_mqa_x_norm = Norm(config)
        self.pre_cross_mqa_c_norm = Norm(config)
        self.cross_mqa = crossMQA(config)
        
        self.pre_mlp_norm = Norm(config)
        self.mlp = MLP(config.embed_dim, config.mlp_multiplier, config.embed_dim, config.dropout_rate)

    @log_io
    def forward(self, 
                x: torch.Tensor,
                c: torch.Tensor = None,
                training: bool = False,
               ) -> torch.Tensor:
        x = x + self.self_mqa_connection(x, training)
        if c is not None:
            x = x + self.cross_mqa_connection(x, c, training)
        x = x + self.mlp_connection(x, training)
        return x

    @log_io
    def self_mqa_connection(self, x, training):
        return self.self_mqa(self.pre_self_mqa_norm(x, training), training)

    @log_io
    def cross_mqa_connection(self, x, c, training):
        return self.cross_mqa(self.pre_cross_mqa_x_norm(x, training), 
                              self.pre_cross_mqa_c_norm(c, training), training)
    @log_io
    def mlp_connection(self, x, training):
        return self.mlp(self.pre_mlp_norm(x, training), training)

## demonstration/debugging

### no predictive mask

In [29]:
module = Layer(config)
module.enable_logging()

### enabling printing for sub-modules
#module.pre_self_mqa_norm.enable_logging()
#module.self_mqa.enable_logging()
#module.post_self_mqa_norm.enable_logging()
#module.pre_cross_mqa_norm.enable_logging()
#module.cross_mqa.enable_logging()
#module.post_cross_mqa_norm.enable_logging()
#module.pre_mlp_norm.enable_logging()
#module.mlp.enable_logging()
#module.post_mlp_norm.enable_logging()

### disabling printing for sub-functions
#module.disable_function_logging('self_mqa_connection')
#module.disable_function_logging('cross_mqa_connection')
#module.disable_function_logging('mlp_connection')

output = module(torch.randn(32,config.max_seq_len,config.embed_dim))
del module, output


Inputs:
Tensor 'x' shape: torch.Size([32, 128, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 64])
Integer 'training': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 64])
Integer 'training': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 64])


In [30]:
hold = config.predictive_mask
config.predictive_mask = False
module = Layer(config)
module.enable_logging()

### enabling printing for sub-modules
#module.pre_self_mqa_norm.enable_logging()
#module.self_mqa.enable_logging()
#module.post_self_mqa_norm.enable_logging()
#module.pre_cross_mqa_norm.enable_logging()
#module.cross_mqa.enable_logging()
#module.post_cross_mqa_norm.enable_logging()
#module.pre_mlp_norm.enable_logging()
#module.mlp.enable_logging()
#module.post_mlp_norm.enable_logging()

### disabling printing for sub-functions
#module.disable_function_logging('self_mqa_connection')
#module.disable_function_logging('cross_mqa_connection')
#module.disable_function_logging('mlp_connection')

x = torch.randn(32, config.max_seq_len, config.embed_dim)
c = torch.randn(32, config.max_seq_len // config.combine_factor, config.embed_dim)
output = module(x, c)
config.predictive_mask = hold
del hold, module, output


Inputs:
Tensor 'x' shape: torch.Size([32, 128, 64])
Tensor 'c' shape: torch.Size([32, 32, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 64])
Integer 'training': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 64])
Tensor 'c' shape: torch.Size([32, 32, 64])
Integer 'training': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 64])
Integer 'training': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 64])


### prolly need to do more debugging once predictive mask works

# embedding vector combination function

In [12]:
class SumModule(LoggingModule): 
    @log_io
    def forward(self, x):
        return x.sum(dim=2)
class MeanModule(LoggingModule):    
    @log_io
    def forward(self, x):
        return x.mean(dim=2)
class MaxModule(LoggingModule):    
    @log_io
    def forward(self, x):
        return x.max(dim=2).values
class ReshapeModule(LoggingModule):
    @log_io
    def forward(self, x):
        return x.reshape(x.shape[0], x.shape[1], -1)

class CombineEmbeddings(LoggingModule):
    def __init__(self, config: Config, padding_vector: torch.Tensor):
        super().__init__()
        self.embed_dim = config.embed_dim
        self.combine_factor = config.combine_factor
        #self.padding_vector = nn.Parameter(torch.zeros(config.embed_dim), requires_grad=True)
        self.padding_vector = padding_vector

        # Initialize operation chain
        self.operation_chain = nn.Sequential()
        operation_map = {
            "sum": SumModule(), # cannot go before or after reshape
            "mean": MeanModule(), # cannot go before or after reshape
            "max": MaxModule(), # cannot go before or after reshape
            "norm": Norm(config), # cannot go before or after reshape
            "linear_post_reshape": nn.Linear(config.embed_dim * config.combine_factor, config.embed_dim), # must go after reshape
            "linear": nn.Linear(config.embed_dim, config.embed_dim), # cannot go after reshape
            "mlp_post_reshape": MLP(config.embed_dim * config.combine_factor, config.mlp_multiplier, config.embed_dim, config.dropout_rate), # must go after reshape
            "mlp": MLP(config.embed_dim, config.mlp_multiplier, config.embed_dim, config.dropout_rate), # cannot go after reshape
            "reshape": ReshapeModule()}
        index = 0  # Initialize a counter to create unique module names
        for op in config.combine_type.split("->"):
            if op in operation_map:
                unique_op_name = f"{op}_{index}"  # Append the index to the operation name
                self.operation_chain.add_module(unique_op_name, operation_map[op])
                index += 1  # Increment the index for the next operation
            else:
                raise ValueError(f"Pooling operation {op} is not a valid operation")

    
    @log_io
    def forward(self, tensor, combine_factor):
        # this function will apply itself recursively for higher levels
        assert combine_factor % self.combine_factor == 0, f'combine factor={combine_factor} is not a multiple of config.combine_factor={config.combine_factor}'
        if combine_factor // self.combine_factor != 1:
            tensor = self.forward(tensor, combine_factor // self.combine_factor)
            
        b, t, d = tensor.shape
            
        # Calculate the necessary amount of padding
        remainder = t % self.combine_factor
        padding_needed = 0 if remainder == 0 else self.combine_factor - remainder
        
        if padding_needed > 0:
            # Replicate the padding vector the necessary number of times
            padding = self.padding_vector.repeat(padding_needed, 1).unsqueeze(0).expand(b, -1, -1)
            tensor = torch.cat([padding, tensor], dim=1)
        
        # Update t after padding
        t_padded = t + padding_needed
        
        # Reshape the tensor to group 'combine_factor' entries along the t dimension
        reshaped_tensor = tensor.view(b, t_padded // self.combine_factor, self.combine_factor, d)
        
        # the actual combination operation
        combined_tensor = self.operation_chain(reshaped_tensor)
        assert combined_tensor.shape[0] == b, f"b={b}; pooling operation order is invalid. output shape: {combined_tensor.shape}"
        assert combined_tensor.shape[1] == t_padded // self.combine_factor, f"t={t_padded // self.combine_factor}; pooling operation order is invalid. output shape: {combined_tensor.shape}"
        assert combined_tensor.shape[2] == self.embed_dim, f"d={self.embed_dim}; pooling operation order is invalid. output shape: {combined_tensor.shape}"
        
        return combined_tensor

## demonstration/debugging

### testing different pooling types

In [25]:
hold1 = config.combine_type
hold2 = config.combine_factor
config.combine_type = 'mlp->sum->mlp->norm'
config.combine_factor = 4
padding_vector = nn.Parameter(torch.zeros(config.embed_dim), requires_grad=True)
module = CombineEmbeddings(config, padding_vector)

module.enable_logging()
module.operation_chain.mlp_0.enable_logging()
module.operation_chain.sum_1.enable_logging()
module.operation_chain.mlp_2.enable_logging()
module.operation_chain.norm_3.enable_logging()

x = torch.randn(32, config.max_seq_len, config.embed_dim)
output = module(x, config.combine_factor)
config.combine_type = hold1
config.combine_factor = hold2
del padding_vector, module, hold1, hold2, x, output
module = None


Inputs:
Tensor 'tensor' shape: torch.Size([32, 128, 128])
Integer 'combine_factor': Value=4

Inputs:
Tensor 'x' shape: torch.Size([32, 32, 4, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 32, 4, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 32, 4, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 32, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 32, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 32, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 32, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 32, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 32, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 32, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 32, 128])


In [26]:
hold1 = config.combine_type
hold2 = config.combine_factor
config.combine_type = 'reshape->mlp_post_reshape'
config.combine_factor = 2
padding_vector = nn.Parameter(torch.zeros(config.embed_dim), requires_grad=True)
module = CombineEmbeddings(config, padding_vector)

module.enable_logging()
module.operation_chain.reshape_0.enable_logging()
module.operation_chain.mlp_post_reshape_1.enable_logging()

x = torch.randn(32, config.max_seq_len, config.embed_dim)
output = module(x, config.combine_factor)
config.combine_type = hold1
config.combine_factor = hold2
del padding_vector, module, hold1, hold2, x, output
module = None


Inputs:
Tensor 'tensor' shape: torch.Size([32, 128, 128])
Integer 'combine_factor': Value=2

Inputs:
Tensor 'x' shape: torch.Size([32, 64, 2, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 64, 256])

Inputs:
Tensor 'x' shape: torch.Size([32, 64, 256])

Outputs:
Tensor 'output' shape: torch.Size([32, 64, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 64, 128])


### off-lengths that'll require padding

In [27]:
hold1 = config.combine_type
hold2 = config.combine_factor
config.combine_type = 'reshape->linear_post_reshape'
config.combine_factor = 2
padding_vector = nn.Parameter(torch.zeros(config.embed_dim), requires_grad=True)
module = CombineEmbeddings(config, padding_vector)

module.enable_logging()
module.operation_chain.reshape_0.enable_logging()
# the linear is here it's just not compatible with our printing and i'm too lazy to build a wrapper for it

x = torch.randn(32, config.max_seq_len-1, config.embed_dim)
output = module(x, config.combine_factor)
config.combine_type = hold1
config.combine_factor = hold2
del padding_vector, module, hold1, hold2, x, output
module = None


Inputs:
Tensor 'tensor' shape: torch.Size([32, 127, 128])
Integer 'combine_factor': Value=2

Inputs:
Tensor 'x' shape: torch.Size([32, 64, 2, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 64, 256])

Outputs:
Tensor 'output' shape: torch.Size([32, 64, 128])


In [28]:
hold1 = config.combine_type
hold2 = config.combine_factor
config.combine_type = 'max'
config.combine_factor = 4
padding_vector = nn.Parameter(torch.zeros(config.embed_dim), requires_grad=True)
module = CombineEmbeddings(config, padding_vector)

module.enable_logging()
module.operation_chain.max_0.enable_logging()

x = torch.randn(32, config.max_seq_len-3, config.embed_dim)
print(x[0,0,:8])
output = module(x, config.combine_factor)
print(output[0,0,:8])
config.combine_type = hold1
config.combine_factor = hold2
del padding_vector, module, hold1, hold2, x, output
module = None

tensor([ 0.2387, -1.1476,  0.0282, -1.5950, -1.6658,  0.6434, -0.7313,  0.3435])

Inputs:
Tensor 'tensor' shape: torch.Size([32, 125, 128])
Integer 'combine_factor': Value=4

Inputs:
Tensor 'x' shape: torch.Size([32, 32, 4, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 32, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 32, 128])
tensor([0.2387, 0.0000, 0.0282, 0.0000, 0.0000, 0.6434, 0.0000, 0.3435],
       grad_fn=<SliceBackward0>)


In [29]:
hold1 = config.combine_type
hold2 = config.combine_factor
config.combine_type = 'mean'
config.combine_factor = 4
padding_vector = nn.Parameter(torch.zeros(config.embed_dim), requires_grad=True)
module = CombineEmbeddings(config, padding_vector)

module.enable_logging()
module.operation_chain.mean_0.enable_logging()

x = torch.randn(32, config.max_seq_len-4, config.embed_dim)
output = module(x, config.combine_factor)
config.combine_type = hold1
config.combine_factor = hold2
del padding_vector, module, hold1, hold2, x, output
module = None


Inputs:
Tensor 'tensor' shape: torch.Size([32, 124, 128])
Integer 'combine_factor': Value=4

Inputs:
Tensor 'x' shape: torch.Size([32, 31, 4, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 31, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 31, 128])


### it has to calculate recursively for multiple levels whenever reshape is involved

In [39]:
hold1 = config.combine_type
hold2 = config.combine_factor
hold3 = config.levels
config.combine_type = 'mlp->sum->mlp->norm'
config.combine_factor = 4
config.levels = 3
padding_vector = nn.Parameter(torch.zeros(config.embed_dim), requires_grad=True)
module = CombineEmbeddings(config, padding_vector)

module.enable_logging()
module.operation_chain.mlp_0.enable_logging()
module.operation_chain.sum_1.enable_logging()
module.operation_chain.mlp_2.enable_logging()
module.operation_chain.norm_3.enable_logging()

x = torch.randn(32, config.max_seq_len, config.embed_dim)
output = module(x, config.combine_factor**(config.levels-1))
config.combine_type = hold1
config.combine_factor = hold2
del padding_vector, module, hold1, hold2, hold3, x, output
module = None


Inputs:
Tensor 'tensor' shape: torch.Size([32, 128, 128])
Integer 'combine_factor': Value=16
combine_factor: 16, combine_factor // self.combine_factor: 4

Inputs:
Tensor 'tensor' shape: torch.Size([32, 128, 128])
Integer 'combine_factor': Value=4
combine_factor: 4, combine_factor // self.combine_factor: 1

Inputs:
Tensor 'x' shape: torch.Size([32, 32, 4, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 32, 4, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 32, 4, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 32, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 32, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 32, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 32, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 32, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 32, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 32, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 32, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 8, 4, 128])

Outputs:

### recursive and off-length sequence

In [42]:
hold1 = config.combine_type
hold2 = config.combine_factor
hold3 = config.levels
config.combine_type = 'reshape->linear_post_reshape'
config.combine_factor = 2
config.levels = 3
padding_vector = nn.Parameter(torch.zeros(config.embed_dim), requires_grad=True)
module = CombineEmbeddings(config, padding_vector)

module.enable_logging()
module.operation_chain.reshape_0.enable_logging()

x = torch.randn(32, config.max_seq_len - 1, config.embed_dim)
output = module(x, config.combine_factor**(config.levels-1))
config.combine_type = hold1
config.combine_factor = hold2
del padding_vector, module, hold1, hold2, hold3, x, output
module = None


Inputs:
Tensor 'tensor' shape: torch.Size([32, 127, 128])
Integer 'combine_factor': Value=4
combine_factor: 4, combine_factor // self.combine_factor: 2

Inputs:
Tensor 'tensor' shape: torch.Size([32, 127, 128])
Integer 'combine_factor': Value=2
combine_factor: 2, combine_factor // self.combine_factor: 1

Inputs:
Tensor 'x' shape: torch.Size([32, 64, 2, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 64, 256])

Outputs:
Tensor 'output' shape: torch.Size([32, 64, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 32, 2, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 32, 256])

Outputs:
Tensor 'output' shape: torch.Size([32, 32, 128])


In [43]:
hold1 = config.combine_type
hold2 = config.combine_factor
hold3 = config.levels
config.combine_type = 'reshape->linear_post_reshape'
config.combine_factor = 2
config.levels = 3
padding_vector = nn.Parameter(torch.zeros(config.embed_dim), requires_grad=True)
module = CombineEmbeddings(config, padding_vector)

module.enable_logging()
module.operation_chain.reshape_0.enable_logging()

x = torch.randn(32, config.max_seq_len - 3, config.embed_dim)
output = module(x, config.combine_factor**(config.levels-1))
config.combine_type = hold1
config.combine_factor = hold2
del padding_vector, module, hold1, hold2, hold3, x, output
module = None


Inputs:
Tensor 'tensor' shape: torch.Size([32, 125, 128])
Integer 'combine_factor': Value=4
combine_factor: 4, combine_factor // self.combine_factor: 2

Inputs:
Tensor 'tensor' shape: torch.Size([32, 125, 128])
Integer 'combine_factor': Value=2
combine_factor: 2, combine_factor // self.combine_factor: 1

Inputs:
Tensor 'x' shape: torch.Size([32, 63, 2, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 63, 256])

Outputs:
Tensor 'output' shape: torch.Size([32, 63, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 32, 2, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 32, 256])

Outputs:
Tensor 'output' shape: torch.Size([32, 32, 128])


### This is how Body.concept_matchup() will use this class once it's up and working

basically it just wants a single vector and there definitely won't be any padding to be done so it can skip most of the logic & get right to the operation chain

In [16]:
hold1 = config.combine_type
hold2 = config.combine_factor
config.combine_type = 'mlp->sum->mlp->norm'
config.combine_factor = 4
padding_vector = nn.Parameter(torch.zeros(config.embed_dim), requires_grad=True)
module = CombineEmbeddings(config, padding_vector)

module.enable_logging()
#module.operation_chain.mlp_0.enable_logging()
#module.operation_chain.sum_1.enable_logging()
#module.operation_chain.mlp_2.enable_logging()
#module.operation_chain.norm_3.enable_logging()

x = torch.randn(32, config.combine_factor, config.embed_dim)
output = module(x, config.combine_factor)
config.combine_type = hold1
config.combine_factor = hold2
del padding_vector, module, hold1, hold2, x, output
module = None


Inputs:
Tensor 'tensor' shape: torch.Size([32, 4, 64])
Integer 'combine_factor': Value=4

Outputs:
Tensor 'output' shape: torch.Size([32, 1, 64])


In [17]:
hold1 = config.combine_type
hold2 = config.combine_factor
hold3 = config.levels
config.combine_type = 'reshape->mlp_post_reshape'
config.combine_factor = 2
config.levels = 3
padding_vector = nn.Parameter(torch.zeros(config.embed_dim), requires_grad=True)
module = CombineEmbeddings(config, padding_vector)

module.enable_logging()
module.operation_chain.reshape_0.enable_logging()
module.operation_chain.mlp_post_reshape_1.enable_logging()

x = torch.randn(32, config.combine_factor**(config.levels-1), config.embed_dim)
output = module(x, config.combine_factor**(config.levels-1))
config.combine_type = hold1
config.combine_factor = hold2
config.levels = hold3
del padding_vector, module, hold1, hold2, hold3, x, output
module = None


Inputs:
Tensor 'tensor' shape: torch.Size([32, 4, 64])
Integer 'combine_factor': Value=4

Inputs:
Tensor 'tensor' shape: torch.Size([32, 4, 64])
Integer 'combine_factor': Value=2

Inputs:
Tensor 'x' shape: torch.Size([32, 2, 2, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 2, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 2, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 2, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 2, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 1, 2, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 1, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 1, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 1, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 1, 64])


# Body

In [13]:
# to prevent the warning statement from printing hella times
cvec_warning = False

class Body(LoggingModule):
    def __init__(self, config: Config, embedding: torch.Tensor, embedding_combiner: LoggingModule):
        super().__init__()
        self.max_seq_len = config.max_seq_len
        self.combine_factor = config.combine_factor
        self.levels = config.levels
        self.embedding = embedding
        
        # Initialize a sequence of Layer instances as specified by the number of hidden layers in the config
        self.layers = nn.ModuleList(Layer(config) for _ in range(config.num_layers))

        # initialize the final normalizations to stabilize residual outputs before output layer
        self.final_norms = nn.ModuleList(Norm(config) for _ in range(config.levels))

        # initialize the concept output layers as either full MLPs or simple linear layers. maybe later make an option to make them shared?
        if config.output_layer == 'mlp':
            self.concept_output_layers = nn.ModuleList(MLP(config.embed_dim,
                                                           config.mlp_multiplier,
                                                           config.embed_dim,
                                                           config.dropout_rate) for _ in range(config.levels-1))
        else: # defaults to a linear layer
            self.concept_output_layers = nn.ModuleList(nn.Linear(config.embed_dim, config.embed_dim) for _ in range(config.levels-1))

        # gets used during inference for the matchup to concept embedding vectors
        self.concept_creator = embedding_combiner
        
    @log_io
    def forward(self,
                x0s: Tuple[torch.Tensor], # ordered from tokens -> highest concepts
                targets: Tuple[torch.Tensor] = None,
                cvec_topk: int = None,
                cvec_greedy: bool = False,
                cvec_temp: float = 1.0,
               ) -> Tuple[torch.Tensor]:
        return self.forward_training(x0s, targets) if targets is not None else self.forward_inference(x0s, cvec_topk, cvec_greedy, cvec_temp)
        
    @log_io
    def forward_training(self,
                         x0s: Tuple[torch.Tensor], # ordered from tokens -> highest concepts
                         targets: Tuple[torch.Tensor],
                        ) -> Tuple[torch.Tensor]:
        # initiate tuple to hold final residual states
        xfs = ()
        # iterate through model levels, starting from highest level concepts & ending at lowest level tokens
        for i, x in enumerate(reversed(x0s)): # reversed() makes us start at highest level

            if i == 0:
                # if we're dealing with the highest level concepts, there's nothing to cross-attend to
                x = self.layers_loop(x, i, c = None, training = True)
            else:
                # the current level x will cross-attend to the higher level c
                x = self.layers_loop(x, i, c = targets[-i], training = True)
                
            # every level gets its own norm? sure why not. only does anything if config.norm_affine == True
            x = self.final_norms[i](x, training = True)

            if i == len(x0s)-1:
                # if we're dealing with the token residual state then we use the transposed embedding matrix as our output
                x = x @ self.embedding.t()
            else:
                # all the concept levels get their own output layer to help with their regression goal
                x = self.concept_output_layers[i](x) # not setting training=True because nn.Linear doesn't know what that is
            
            # add the final residual state of the level to our tuple
            xfs += (x,)
                
        return xfs # return all final residual states
        
    @log_io
    def forward_inference(self,
                          x0s: Tuple[torch.Tensor], # ordered from tokens -> highest concepts
                          cvec_topk: int = None,
                          cvec_greedy: bool = False,
                          cvec_temp: float = 1.0,
                         ) -> Tuple[torch.Tensor]:
        # initiate tuple to hold final residual states
        xfs = ()
        # iterate through model levels, starting from highest level concepts & ending at lowest level tokens
        for i, x in enumerate(reversed(x0s)): # reversed() makes us start at highest level
            
            # if we're dealing with a concept level, then we need to fill in entire concept sequences for lower level to attend to. for token level we only want one pass
            if i != len(x0s)-1:

                # figuring out how many times we need to run this concept level in order for it to be usable for cross-attention to the layer below
                effective_max_seq_len = self.max_seq_len // (self.combine_factor ** (self.levels-1-i))
                assert x.shape[1] <= effective_max_seq_len, f'at level {i} a too-long sequence ({x.shape[1]} vs {effective_max_seq_len}) made it to Body'
                extra_runs = effective_max_seq_len - x.shape[1]

                # if extra_runs == 0 then this will only run once
                for k in range(extra_runs+1): 
                    
                    # run through layers. xfs[i-1] are the higher level concepts to pay attention to
                    c_regression = self.layers_loop(x, i, training=False) if i == 0 else self.layers_loop(x, i, xfs[i-1], training=False)
                    # splice out the final prediction
                    c_regression = c_regression[:,-1,:]
                    # normalize it
                    c_regression = self.final_norms[i](c_regression, training=False)
                    # our concept output layer
                    c_regression = self.concept_output_layers[i](c_regression)
                    # either select most similar concept vectors to be appended to the sequence or use the raw regression output
                    c_vec = self.concept_matchup(c_regression, cvec_topk, cvec_greedy, cvec_temp) if cvec_topk is not None else c_regression
                    # append to x
                    x = torch.concat([x, c_vec.unsqueeze(1)], dim=1)

                # this should only trim off one concept vector from the beginning of the sequence at most
                x = x[:,-effective_max_seq_len:,:]
                
            else: # if we're dealing with the token layer rather than a concept layer, we only want to run once
                # run the model
                x = self.layers_loop(x, i, xfs[i-1], training=False)
                
                # normalize
                x = self.final_norms[i](x, training=False)
                # the output layer is the transposed embedding matrix
                
                x = x @ self.embedding.t()
            
            # add the final residual state of the level to our tuple
            xfs += (x,)
            
        return xfs # we could return just xfs[-1] since it's inference but i have a feeling i'll want to analyze the concepts later

    @log_io
    def layers_loop(self, x: torch.Tensor,
                    i: int,
                    c: torch.Tensor = None,
                    training : bool = False,
                   ) -> torch.Tensor:
        
        # Iteratively process the input through each Layer of the model
        for layer in self.layers:
            
            # run through layers. at i==0 there's no higher level to attend to
            x = layer(x, training=training) if i == 0 else layer(x, c, training)
                
        return x
    
    @log_io 
    def concept_matchup(self,
                        c_regression: torch.Tensor,
                        cvec_topk: int,
                        cvec_greedy: bool,
                        cvec_temp: float,
                        ) -> torch.Tensor:
        """
        NOT CURRENTLY WORKING
        
        1. check similarity of inputted c_regression against all tokens in vocabulary
        2. select topk tokens where k = cvec_topk
        3. create every order of combinations of length combine_factor of the selected tokens
        4. run every combination through concept embedding function
        5. check similarity of c_regression against all created concept embeddings
        6. Repeat steps 2-4 until correct level is reached
        7. if cvec_greedy==True then select highest similarity vector; otherwise softmax, apply cvec_temp, and sample from distribution
        8. Return c_vec

        hold up, would this fall apart above the second level? 
        like would the concepts that get created at a given level be unlikely to create the correct next-higher concepts because there'd be a lot of repetition?
        also, wouldn't linear layer and MLP in the concept creation mess with the cosine similarity metric?
        maybe this function should only be used when the concept construction method does NOT include MLP or linear layer, like only the normed version can use it
        in that case then if we're restricted to using sum, max, and mean wouldn't these be the steps?:
        1. assert that the concept combination method is either mean, sum, or max (and optionally norm can be in there)?
        2. check similarity of inputted c_regression against all tokens in vocabulary
        3. select topk tokens where k = cvec_topk
        4. create every order of combinations of length combine_factor of the selected tokens
        5. run every combination through concept embedding function
        6. check similarity of c_regression against all created concept embeddings
        7. if cvec_greedy==True then select highest similarity vector; otherwise softmax, apply cvec_temp, and sample from distribution
        8. Return c_vec
        """
        global cvec_warning
        batch_size, d = c_regression.size()
        vocab_size = self.embedding.size(0)
    
        # Batch cosine similarity
        # Reshape c: (batch_size x 1 x embedding_dim)
        # Reshape embedding: (1 x vocab_size x embedding_dim)
        # Resulting similarity: (batch_size x vocab_size)
        token_similarities = F.cosine_similarity(c_regression.unsqueeze(1), self.embedding.unsqueeze(0), dim=-1)
        
        # how many tokens will we sample to build up our chosen concept vector?
        if cvec_topk is None:
            cvec_topk = self.combine_factor ** (self.levels-1)
            if (cvec_warning == False) or (cvec_warning is None):
                print(f"cvec_topk not defined. defaulting to highest level's minimum size: combo**(levels-1) = {cvec_topk}")
                cvec_warning = True
        assert cvec_topk >= self.combine_factor ** (self.levels-1), f'cvec_topk = {cvec_topk} needs to be >= self.combine_factor ** (self.levels-1) = {self.combine_factor ** (self.levels-1)}'
        
        # Select top-k token embeddings for each concept vector
        topk_token_indices = torch.topk(token_similarities, k=cvec_topk, dim=1).indices  # (batch_size x sample)

        ##################
        ### UNFINISHED ###
        ##################
        
        return c_regression # change to c_vec once finished
        

## demonstrations/debugging

### training

In [20]:
with torch.no_grad():
    # first let's do 2 levels training
    embedder = nn.Embedding(config.vocab_size, config.embed_dim)
    padding_vector = nn.Parameter(torch.zeros(config.embed_dim), requires_grad=True)
    embedding_combiner = CombineEmbeddings(config, padding_vector)
    hold = config.levels
    config.levels = 2
    module = Body(config, embedder.weight, embedding_combiner)
    module.enable_logging()
    
    ### enabling logging for sub-modules
    #module.layers[0].enable_logging()
    #module.final_norms[0].enable_logging()
    
    ### disabling logging for sub-functions
    module.disable_function_logging('forward')
    #module.disable_function_logging('forward_training')
    #module.disable_function_logging('forward_inference')
    #module.disable_function_logging('layers_loop')
    #module.disable_function_logging('concept_matchup')
    #module.disable_function_logging('create_concept_embeddings')
    
    x = torch.randn(32, config.max_seq_len, config.embed_dim)
    c = torch.randn(32, config.max_seq_len // config.combine_factor, config.embed_dim)
    x0s = (x,c)
    targets = (x + torch.randn_like(x), c + torch.randn_like(c))
    output = module(x0s, targets)
    config.levels = hold
del embedder, embedding_combiner, hold, module, x, c, x0s, targets, output
embedder, embedding_combiner, module = None, None, None


Inputs:
Tuple 'x0s':
    Tensor 'x0s[0]' shape: torch.Size([32, 128, 64])
    Tensor 'x0s[1]' shape: torch.Size([32, 32, 64])
Tuple 'targets':
    Tensor 'targets[0]' shape: torch.Size([32, 128, 64])
    Tensor 'targets[1]' shape: torch.Size([32, 32, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 32, 64])
Integer 'i': Value=0
Other-type 'c': Type=NoneType, Value=None
Integer 'training': Value=True

Outputs:
Tensor 'output' shape: torch.Size([32, 32, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 64])
Integer 'i': Value=1
Tensor 'c' shape: torch.Size([32, 32, 64])
Integer 'training': Value=True

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 64])

Outputs:
Tensor 'output[0]' shape: torch.Size([32, 32, 64])
Tensor 'output[1]' shape: torch.Size([32, 128, 128])


In [21]:
with torch.no_grad():
    # 3 levels training
    embedder = nn.Embedding(config.vocab_size, config.embed_dim)
    padding_vector = nn.Parameter(torch.zeros(config.embed_dim), requires_grad=True)
    embedding_combiner = CombineEmbeddings(config, padding_vector)
    hold = config.levels
    config.levels = 3
    module = Body(config, embedder.weight, embedding_combiner)
    module.enable_logging()
    
    ### enabling logging for sub-modules
    #module.layers[0].enable_logging()
    #module.final_norms[0].enable_logging()
    
    ### disabling logging for sub-functions
    module.disable_function_logging('forward')
    #module.disable_function_logging('forward_training')
    #module.disable_function_logging('forward_inference')
    #module.disable_function_logging('layers_loop')
    #module.disable_function_logging('concept_matchup')
    #module.disable_function_logging('create_concept_embeddings')
    
    x0 = torch.randn(32, config.max_seq_len // (config.combine_factor**0), config.embed_dim)
    c1 = torch.randn(32, config.max_seq_len // (config.combine_factor**1), config.embed_dim)
    c2 = torch.randn(32, config.max_seq_len // (config.combine_factor**2), config.embed_dim)
    x0s = (x0, c1, c2)
    targets = (x0 + torch.randn_like(x0), c1 + torch.randn_like(c1), c2 + torch.randn_like(c2))
    output = module(x0s, targets)
    config.levels = hold
del embedder, embedding_combiner, hold, module, x0, c1, c2, x0s, targets, output
embedder, embedding_combiner, module = None, None, None


Inputs:
Tuple 'x0s':
    Tensor 'x0s[0]' shape: torch.Size([32, 128, 64])
    Tensor 'x0s[1]' shape: torch.Size([32, 32, 64])
    Tensor 'x0s[2]' shape: torch.Size([32, 8, 64])
Tuple 'targets':
    Tensor 'targets[0]' shape: torch.Size([32, 128, 64])
    Tensor 'targets[1]' shape: torch.Size([32, 32, 64])
    Tensor 'targets[2]' shape: torch.Size([32, 8, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 8, 64])
Integer 'i': Value=0
Other-type 'c': Type=NoneType, Value=None
Integer 'training': Value=True

Outputs:
Tensor 'output' shape: torch.Size([32, 8, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 32, 64])
Integer 'i': Value=1
Tensor 'c' shape: torch.Size([32, 8, 64])
Integer 'training': Value=True

Outputs:
Tensor 'output' shape: torch.Size([32, 32, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 64])
Integer 'i': Value=2
Tensor 'c' shape: torch.Size([32, 32, 64])
Integer 'training': Value=True

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 64])

Outputs:
Tensor 'output

In [22]:
with torch.no_grad():
    # 4 levels training
    embedder = nn.Embedding(config.vocab_size, config.embed_dim)
    padding_vector = nn.Parameter(torch.zeros(config.embed_dim), requires_grad=True)
    embedding_combiner = CombineEmbeddings(config, padding_vector)
    hold = config.levels
    config.levels = 4
    module = Body(config, embedder.weight, embedding_combiner)
    module.enable_logging()
    
    ### enabling logging for sub-modules
    #module.layers[0].enable_logging()
    #module.final_norms[0].enable_logging()
    
    ### disabling logging for sub-functions
    module.disable_function_logging('forward')
    #module.disable_function_logging('forward_training')
    #module.disable_function_logging('forward_inference')
    #module.disable_function_logging('layers_loop')
    #module.disable_function_logging('concept_matchup')
    #module.disable_function_logging('create_concept_embeddings')
    
    x0 = torch.randn(32, config.max_seq_len // (config.combine_factor**0), config.embed_dim)
    c1 = torch.randn(32, config.max_seq_len // (config.combine_factor**1), config.embed_dim)
    c2 = torch.randn(32, config.max_seq_len // (config.combine_factor**2), config.embed_dim)
    c3 = torch.randn(32, config.max_seq_len // (config.combine_factor**3), config.embed_dim)
    x0s = (x0, c1, c2, c3)
    targets = (x0 + torch.randn_like(x0), c1 + torch.randn_like(c1), c2 + torch.randn_like(c2), c3 + torch.randn_like(c3))
    output = module(x0s, targets)
    config.levels = hold
del embedder, embedding_combiner, hold, module, x0, c1, c2, c3, x0s, targets, output
embedder, embedding_combiner, module = None, None, None


Inputs:
Tuple 'x0s':
    Tensor 'x0s[0]' shape: torch.Size([32, 128, 64])
    Tensor 'x0s[1]' shape: torch.Size([32, 32, 64])
    Tensor 'x0s[2]' shape: torch.Size([32, 8, 64])
    Tensor 'x0s[3]' shape: torch.Size([32, 2, 64])
Tuple 'targets':
    Tensor 'targets[0]' shape: torch.Size([32, 128, 64])
    Tensor 'targets[1]' shape: torch.Size([32, 32, 64])
    Tensor 'targets[2]' shape: torch.Size([32, 8, 64])
    Tensor 'targets[3]' shape: torch.Size([32, 2, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 2, 64])
Integer 'i': Value=0
Other-type 'c': Type=NoneType, Value=None
Integer 'training': Value=True

Outputs:
Tensor 'output' shape: torch.Size([32, 2, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 8, 64])
Integer 'i': Value=1
Tensor 'c' shape: torch.Size([32, 2, 64])
Integer 'training': Value=True

Outputs:
Tensor 'output' shape: torch.Size([32, 8, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 32, 64])
Integer 'i': Value=2
Tensor 'c' shape: torch.Size([32, 8, 64])
Integer 'tra

### inference w/out concept matchup

In [23]:
with torch.no_grad():
    # now 2 levels partial sequence length, so like we're doing inference
    # our partial sequence will always be a clean interval of config.combine_factor because Model.create_x0s() uses CombineEmbeddings() to make the intervals clean
    embedder = nn.Embedding(config.vocab_size, config.embed_dim)
    padding_vector = nn.Parameter(torch.zeros(config.embed_dim), requires_grad=True)
    embedding_combiner = CombineEmbeddings(config, padding_vector)
    hold = config.levels
    config.levels = 2
    module = Body(config, embedder.weight, embedding_combiner)
    module.enable_logging()
    
    ### enabling logging for sub-modules
    #module.layers[0].enable_logging()
    #module.final_norms[0].enable_logging()
    
    ### disabling logging for sub-functions
    #module.disable_function_logging('forward')
    #module.disable_function_logging('forward_training')
    #module.disable_function_logging('forward_inference')
    #module.disable_function_logging('layers_loop')
    #module.disable_function_logging('concept_matchup')
    module.disable_function_logging('create_concept_embeddings') # rn this one is hella inefficient using a for loop over batch so not fun to print
    
    x = torch.randn(32, config.max_seq_len - (config.combine_factor**2), config.embed_dim)
    c = torch.randn(32, (config.max_seq_len // config.combine_factor) - config.combine_factor, config.embed_dim)
    x0s = (x,c)
    output = module(x0s, cvec_topk = None)
    config.levels = hold
del embedder, embedding_combiner, hold, module, x, c, x0s, output
embedder, embedding_combiner, module = None, None, None



Inputs:
Tuple 'x0s':
    Tensor 'x0s[0]' shape: torch.Size([32, 112, 64])
    Tensor 'x0s[1]' shape: torch.Size([32, 28, 64])
Other-type 'targets': Type=NoneType, Value=None

Inputs:
Tuple 'x0s':
    Tensor 'x0s[0]' shape: torch.Size([32, 112, 64])
    Tensor 'x0s[1]' shape: torch.Size([32, 28, 64])
Other-type 'cvec_topk': Type=NoneType, Value=None
Integer 'cvec_greedy': Value=False
Float 'cvec_temp': Value=1.0

Inputs:
Tensor 'x' shape: torch.Size([32, 28, 64])
Integer 'i': Value=0
Integer 'c': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 28, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 29, 64])
Integer 'i': Value=0
Integer 'c': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 29, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 30, 64])
Integer 'i': Value=0
Integer 'c': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 30, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 31, 64])
Integer 'i': Value=0
Integer 'c': Value=False

Outputs:
Tensor '

In [24]:
with torch.no_grad():
    # now 3 levels partial sequence length, so like we're doing inference
    # our partial sequence will always be a clean interval of config.combine_factor because Model.create_x0s() uses CombineEmbeddings() to make the intervals clean
    embedder = nn.Embedding(config.vocab_size, config.embed_dim)
    padding_vector = nn.Parameter(torch.zeros(config.embed_dim), requires_grad=True)
    embedding_combiner = CombineEmbeddings(config, padding_vector)
    hold = config.levels
    config.levels = 3
    module = Body(config, embedder.weight, embedding_combiner)
    module.enable_logging()
    
    ### enabling logging for sub-modules
    #module.layers[0].enable_logging()
    #module.final_norms[0].enable_logging()
    
    ### disabling logging for sub-functions
    #module.disable_function_logging('forward')
    #module.disable_function_logging('forward_training')
    #module.disable_function_logging('forward_inference')
    #module.disable_function_logging('layers_loop')
    module.disable_function_logging('concept_matchup') # this one is pretty boring and our prints are already long enough
    module.disable_function_logging('create_concept_embeddings') # rn this one is hella inefficient using a for loop over batch so not fun to print
    
    x0 = torch.randn(32, config.max_seq_len // (config.combine_factor ** 0) - (config.combine_factor**3), config.embed_dim)
    c1 = torch.randn(32, config.max_seq_len // (config.combine_factor ** 1) - (config.combine_factor**2), config.embed_dim)
    c2 = torch.randn(32, config.max_seq_len // (config.combine_factor ** 2) - (config.combine_factor**1), config.embed_dim)
    x0s = (x0, c1, c2)
    output = module(x0s, cvec_topk = None)
    config.levels = hold
del embedder, embedding_combiner, hold, module, x0, c1, c2, x0s, output
embedder, embedding_combiner, module = None, None, None


Inputs:
Tuple 'x0s':
    Tensor 'x0s[0]' shape: torch.Size([32, 64, 64])
    Tensor 'x0s[1]' shape: torch.Size([32, 16, 64])
    Tensor 'x0s[2]' shape: torch.Size([32, 4, 64])
Other-type 'targets': Type=NoneType, Value=None

Inputs:
Tuple 'x0s':
    Tensor 'x0s[0]' shape: torch.Size([32, 64, 64])
    Tensor 'x0s[1]' shape: torch.Size([32, 16, 64])
    Tensor 'x0s[2]' shape: torch.Size([32, 4, 64])
Other-type 'cvec_topk': Type=NoneType, Value=None
Integer 'cvec_greedy': Value=False
Float 'cvec_temp': Value=1.0

Inputs:
Tensor 'x' shape: torch.Size([32, 4, 64])
Integer 'i': Value=0
Integer 'c': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 4, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 5, 64])
Integer 'i': Value=0
Integer 'c': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 5, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 6, 64])
Integer 'i': Value=0
Integer 'c': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 6, 64])

Inputs:
Tensor 'x' sh

### infernce w/ matchup

# Concept Loss

In [14]:
class ConceptLoss(LoggingModule):
    def __init__(self, config: Config):
        super().__init__()
        self.combine_factor = config.combine_factor
        self.levels = config.levels
        self.level_loss_weight = config.level_loss_weight

        self.MAE_loss = nn.L1Loss() if config.mae_loss else None
        self.MSE_loss = nn.MSELoss() if config.mse_loss else None
        self.COS_loss = nn.CosineSimilarity(dim=-1, eps=1e-6) if config.cos_loss else None
    
    @log_io
    def forward(self, 
                xfs: Tuple[torch.Tensor], # xfs are ordered highest concept level -> token level
                targets: Tuple[torch.Tensor], # targets are ordered highest concept level -> token level
               ) -> torch.Tensor:
        # initialize loss value
        concept_loss = torch.tensor(0.0)
        
        # iterate through all concept-embedding layers and calculate loss
        for i in range(self.levels - 1):
            # select our relevant final residual state and target vectors
            lvl_output = xfs[i]
            lvl_targets = targets[i].detach().clone()
            
            # calculate the decay value placed on this level's total amount of loss
            lambadada = (self.level_loss_weight ** (self.levels -1 -i))
            
            # setup flattening for if we're doing MAE or MSE
            if (self.MAE_loss is not None) or (self.MSE_loss is not None):
                # Reshape output and target_vectors to combine batch and seq_len dimensions
                lvl_output_flat = lvl_output.view(-1, lvl_output.size(-1))
                lvl_targets_flat = lvl_targets.view(-1, lvl_targets.size(-1))

            # calculate loss values. notice multiple might occur or even none at all
            if self.MAE_loss is not None:
                concept_loss = concept_loss + self.MAE_loss(lvl_output_flat, lvl_targets_flat) * lambadada
            if self.MSE_loss is not None:
                concept_loss = concept_loss + self.MSE_loss(lvl_output_flat, lvl_targets_flat) * lambadada
            if self.COS_loss is not None:
                cosine_loss = (1 - self.COS_loss(lvl_output, lvl_targets)).mean()
                concept_loss = concept_loss + cosine_loss * lambadada

        return concept_loss

## Demonstration/Debugging

In [46]:
with torch.no_grad():
    hold1 = config.levels
    hold2 = config.cos_loss
    hold3 = config.mse_loss
    hold4 = config.mae_loss
    config.levels = 3
    config.cos_loss = True
    config.mse_loss = False
    config.mae_loss = False
    module = ConceptLoss(config)
    module.enable_logging()
    
    x0 = torch.randn(32, config.max_seq_len // (config.combine_factor**0), config.embed_dim)
    c1 = torch.randn(32, config.max_seq_len // (config.combine_factor**1), config.embed_dim)
    c2 = torch.randn(32, config.max_seq_len // (config.combine_factor**2), config.embed_dim)
    xfs = (x0, c1, c2)
    targets = (x0 + torch.randn_like(x0), c1 + torch.randn_like(c1), c2 + torch.randn_like(c2))
    output = module(xfs, targets)
    print(output)
    config.levels = hold1
    config.cos_loss = hold2
    config.mse_loss = hold3
    config.mae_loss = hold4
del hold1, hold2, hold3, hold4, module, x0, c1, c2, xfs, targets, output
module = None


Inputs:
Tuple 'xfs':
    Tensor 'xfs[0]' shape: torch.Size([32, 256, 128])
    Tensor 'xfs[1]' shape: torch.Size([32, 64, 128])
    Tensor 'xfs[2]' shape: torch.Size([32, 16, 128])
Tuple 'targets':
    Tensor 'targets[0]' shape: torch.Size([32, 256, 128])
    Tensor 'targets[1]' shape: torch.Size([32, 64, 128])
    Tensor 'targets[2]' shape: torch.Size([32, 16, 128])

Outputs:
Tensor 'output' shape: torch.Size([])
tensor(0.5887)


In [47]:
with torch.no_grad():
    hold1 = config.levels
    hold2 = config.cos_loss
    hold3 = config.mse_loss
    hold4 = config.mae_loss
    config.levels = 3
    config.cos_loss = False
    config.mse_loss = True
    config.mae_loss = False
    module = ConceptLoss(config)
    module.enable_logging()
    
    x0 = torch.randn(32, config.max_seq_len // (config.combine_factor**0), config.embed_dim)
    c1 = torch.randn(32, config.max_seq_len // (config.combine_factor**1), config.embed_dim)
    c2 = torch.randn(32, config.max_seq_len // (config.combine_factor**2), config.embed_dim)
    xfs = (x0, c1, c2)
    targets = (x0 + torch.randn_like(x0), c1 + torch.randn_like(c1), c2 + torch.randn_like(c2))
    output = module(xfs, targets)
    print(output)
    config.levels = hold1
    config.cos_loss = hold2
    config.mse_loss = hold3
    config.mae_loss = hold4
del hold1, hold2, hold3, hold4, module, x0, c1, c2, xfs, targets, output
module = None


Inputs:
Tuple 'xfs':
    Tensor 'xfs[0]' shape: torch.Size([32, 256, 128])
    Tensor 'xfs[1]' shape: torch.Size([32, 64, 128])
    Tensor 'xfs[2]' shape: torch.Size([32, 16, 128])
Tuple 'targets':
    Tensor 'targets[0]' shape: torch.Size([32, 256, 128])
    Tensor 'targets[1]' shape: torch.Size([32, 64, 128])
    Tensor 'targets[2]' shape: torch.Size([32, 16, 128])

Outputs:
Tensor 'output' shape: torch.Size([])
tensor(2.0027)


In [48]:
with torch.no_grad():
    hold1 = config.levels
    hold2 = config.cos_loss
    hold3 = config.mse_loss
    hold4 = config.mae_loss
    config.levels = 3
    config.cos_loss = False
    config.mse_loss = False
    config.mae_loss = True
    module = ConceptLoss(config)
    module.enable_logging()
    
    x0 = torch.randn(32, config.max_seq_len // (config.combine_factor**0), config.embed_dim)
    c1 = torch.randn(32, config.max_seq_len // (config.combine_factor**1), config.embed_dim)
    c2 = torch.randn(32, config.max_seq_len // (config.combine_factor**2), config.embed_dim)
    xfs = (x0, c1, c2)
    targets = (x0 + torch.randn_like(x0), c1 + torch.randn_like(c1), c2 + torch.randn_like(c2))
    output = module(xfs, targets)
    print(output)
    config.levels = hold1
    config.cos_loss = hold2
    config.mse_loss = hold3
    config.mae_loss = hold4
del hold1, hold2, hold3, hold4, module, x0, c1, c2, xfs, targets, output
module = None


Inputs:
Tuple 'xfs':
    Tensor 'xfs[0]' shape: torch.Size([32, 256, 128])
    Tensor 'xfs[1]' shape: torch.Size([32, 64, 128])
    Tensor 'xfs[2]' shape: torch.Size([32, 16, 128])
Tuple 'targets':
    Tensor 'targets[0]' shape: torch.Size([32, 256, 128])
    Tensor 'targets[1]' shape: torch.Size([32, 64, 128])
    Tensor 'targets[2]' shape: torch.Size([32, 16, 128])

Outputs:
Tensor 'output' shape: torch.Size([])
tensor(1.5958)


In [49]:
with torch.no_grad():
    hold1 = config.levels
    hold2 = config.cos_loss
    hold3 = config.mse_loss
    hold4 = config.mae_loss
    config.levels = 3
    config.cos_loss = True
    config.mse_loss = True
    config.mae_loss = True
    module = ConceptLoss(config)
    module.enable_logging()
    
    x0 = torch.randn(32, config.max_seq_len // (config.combine_factor**0), config.embed_dim)
    c1 = torch.randn(32, config.max_seq_len // (config.combine_factor**1), config.embed_dim)
    c2 = torch.randn(32, config.max_seq_len // (config.combine_factor**2), config.embed_dim)
    xfs = (x0, c1, c2)
    targets = (x0 + torch.randn_like(x0), c1 + torch.randn_like(c1), c2 + torch.randn_like(c2))
    output = module(xfs, targets)
    print(output)
    config.levels = hold1
    config.cos_loss = hold2
    config.mse_loss = hold3
    config.mae_loss = hold4
del hold1, hold2, hold3, hold4, module, x0, c1, c2, xfs, targets, output
module = None


Inputs:
Tuple 'xfs':
    Tensor 'xfs[0]' shape: torch.Size([32, 256, 128])
    Tensor 'xfs[1]' shape: torch.Size([32, 64, 128])
    Tensor 'xfs[2]' shape: torch.Size([32, 16, 128])
Tuple 'targets':
    Tensor 'targets[0]' shape: torch.Size([32, 256, 128])
    Tensor 'targets[1]' shape: torch.Size([32, 64, 128])
    Tensor 'targets[2]' shape: torch.Size([32, 16, 128])

Outputs:
Tensor 'output' shape: torch.Size([])
tensor(4.1888)


# Model

In [15]:
cvec_warning = False

class Model(LoggingModule):
    def __init__(self, config: Config, tokenizer: tokenizer):
        super().__init__()
        self.config = config
        self.tokenizer = tokenizer
        
        ### hyperparameters
        self.max_seq_len = config.max_seq_len
        self.sa_head_dim = config.sa_head_dim
        self.vocab_size = config.vocab_size
        self.embed_dim = config.embed_dim
        self.combine_factor = config.combine_factor
        self.levels = config.levels
        
        ### embedding
        # the embedding matrix. for converting tokens to the first residual state, and the last residual state to logits
        self.embedder = nn.Embedding(config.vocab_size, config.embed_dim)
        # the padding vector to get used when sequence length isn't perfect
        self.padding_vector = nn.Parameter(torch.zeros(config.embed_dim), requires_grad=True)
        
        # the function that combines embeddings into higher level concept residual states
        self.embedding_combiner = CombineEmbeddings(config, self.padding_vector)

        ### the actual bulk of the model
        self.body = Body(config, self.embedder.weight, self.embedding_combiner)
        
        ### the loss functions
        # lowest-level token model
        self.ce_loss_fn = nn.CrossEntropyLoss()
        # concept models
        self.concept_loss_fn = ConceptLoss(config)
        
    @log_io
    def forward(self,
                input_token_ids: torch.Tensor, # a shape (batch_size, input_seq_len) list of integer token ids to run forward pass on
                target_token_ids: torch.Tensor = None, # a shape (batch_size, input_seq_len + combo ** (levels-1)) list of token ids to train on
                cvec_topk: int = None,
                cvec_greedy: bool = False,
                cvec_temp: float = 1.0,
               ) -> torch.Tensor:

        # create the tuple of initial residual states to calculate on
        x0s = self.create_x0s(input_token_ids) # x0s are ordered token level -> highest concept level
        
        if target_token_ids is None: ### if we're doing inference
            # the body of the model that iterates through the decoder & cross-attention layers
            xfs = self.body(x0s, cvec_topk=cvec_topk, cvec_greedy=cvec_greedy, cvec_temp=cvec_temp) 

            # the actual token output logits we care about
            logits = xfs[-1]
            
            # if we're not training, then we don't need to calculate loss
            loss = None
        else: ### if we're training
            assert input_token_ids.shape[1] == target_token_ids.shape[1] - (self.combine_factor ** (self.levels - 1)), f'inputs:{input_token_ids.shape[1]} and targets:{target_token_ids.shape[1]} have unexpected shapes'

            # create the tuple of target embedding vectors
            targets = self.create_targets(target_token_ids, input_token_ids.shape[1]) # targets are ordered token level -> highest concept level

            # the body of the model that iterates through the decoder & cross-attention layers
            xfs = self.body(x0s, targets) # xfs are ordered highest concept level -> token level

            ### first up is regular CE token loss
            logits = xfs[-1]
            batch_size, input_len, vocab_size = logits.shape
            # splice target tokens to exclude the ones that were only to be used by concept levels
            target_token_ids_spliced = target_token_ids[:,:input_len]
            # we reshape our logits & targets before calculating cross-entropy loss
            ce_loss = self.ce_loss_fn(logits.view(batch_size*input_len, vocab_size),
                                      target_token_ids_spliced.reshape(batch_size*input_len))

            ### the new thing, a regression loss for all our concept-embedding layers
            concept_loss = self.concept_loss_fn(xfs, tuple(reversed(targets)))
            # adding it all together
            loss = ce_loss + concept_loss
        
        return logits, loss
        
    @log_io
    def create_x0s(self, input_token_ids: torch.Tensor) -> Tuple[torch.Tensor]:
        #print(f'input_token_ids: {input_token_ids.shape}\n{input_token_ids[0,:32]}')
        
        # turn the input tokens into the first residual state using the embedding matrix
        x0 = self.embedder(input_token_ids) # (batch_size, input_len, embed_dim)
        #print(f'x0: {x0.shape}\n{x0[0,0:(self.combine_factor**(config.levels-1))*2,:4]}')

        # finding the number of padding vectos we have to use at the token level to ensure the cross-attention predictive mask will line up
        remainder = x0.shape[1] % self.combine_factor
        padding_needed = 0 if remainder == 0 else self.combine_factor - remainder

        # do the actual padding for the token level
        # once i get a more complicatead tokenizer would i replace this with a <|bos|> token? Would that token be unique to each level?
        if padding_needed > 0:
            # Replicate the padding vector the necessary number of times
            padding = self.padding_vector.repeat(padding_needed, 1).unsqueeze(0).expand(x0.shape[0], -1, -1)
            #print(f'padding: {padding.shape}')
            
            x0 = torch.cat([padding, x0], dim=1)
            #print(f'x0 after padding: {x0.shape}\n{x0[0,0:(self.combine_factor**(config.levels-1))*2,:4]}')
        
        # instantiate the tuple that'll hold all the residual states
        x0s = (x0 * (self.embed_dim ** 0.5),) 
        
        ### iterating through levels to create each higher-level concept residual state
        for i in range(self.levels-1):
            # combine into smaller tensor by adding token (or lower level concept) embeddings together
            lvl_combo = self.combine_factor ** (i+1)
            x0c = self.embedding_combiner(x0, lvl_combo) # c stands for concept
            #print(f'x0c: {x0c.shape}\n{x0c[0,0:self.combine_factor,:4]}')
            
            # finally scale & add it to the tuple of residual states
            x0s += (x0c * (self.embed_dim ** 0.5),)
        
        return x0s

    @log_io
    def create_targets(self, target_token_ids: torch.Tensor, input_len: int) -> Tuple[torch.Tensor]:
        #print(f'target_token_ids: {target_token_ids.shape}\n{target_token_ids[0,:32]}')
        
        # turn the target tokens into the first residual state using the embedding matrix
        token_lvl_target_token_ids = target_token_ids[:,1:1+input_len]
        t0 = self.embedder(token_lvl_target_token_ids) # (batch_size, input_len, embed_dim)
        #print(f'token_lvl_target_token_ids: {token_lvl_target_token_ids.shape}\n{token_lvl_target_token_ids[0,:32]}')
        #print(f't0: {t0.shape}\n{t0[0,0:(self.combine_factor**(config.levels-1))*2,:4]}')
        
        # need to account for offsets in sequence length, which includes both an offset correction & padding vectors
        remainder = t0.shape[1] % self.combine_factor
        padding_needed = 0 if remainder == 0 else self.combine_factor - remainder
        if padding_needed == 1:
            t0 = torch.cat([self.embedder(target_token_ids[:,0].unsqueeze(1)), t0], dim=1)
            #print(f't0 after padding: {t0.shape}\n{t0[0,0:(self.combine_factor**(config.levels-1))*2,:4]}')
        elif padding_needed > 1:
            padding = self.padding_vector.repeat(padding_needed - 1, 1).unsqueeze(0).expand(t0.shape[0], -1, -1)
            #print(f'padding: {padding.shape}')
            t0 = torch.cat([padding, self.embedder(target_token_ids[:,0].unsqueeze(1)), t0], dim=1)
            #print(f't0 after padding: {t0.shape}\n{t0[0,0:(self.combine_factor**(config.levels-1))*2,:4]}')
        
        # instantiate the tuple that'll hold all the residual states
        targets = (t0,) 
        
        ### iterating through levels to create each higher-level concepts
        for i in range(1, self.levels):
            # calculate the correct combo factor for this level
            lvl_combo = self.combine_factor ** i

            # my subsetting here is all messy. doesn't properly take into account off-sequences & the padding token
            # i think maybe i can fix this in the predictive mask once i make that part

            # how many tokens off are we from a perfectly sized (multiple of lvl_combo) sequence, meaning how many padding vectors do we need?
            remainder = input_len % self.combine_factor # will only ever be self.combine_factor -1 at most
            offset = 0 if remainder == 0 else self.combine_factor - remainder
            
            # adjust input_len to ceiling the size necessary for this level
            #input_len_adj = input_len + lvl_combo

            # subset the currect targets to be predicted at this level
            concept_lvl_target_token_ids = target_token_ids[:, lvl_combo - offset:lvl_combo + input_len]# - offset]
            #print(f'concept_lvl_target_token_ids: {concept_lvl_target_token_ids.shape}\n{concept_lvl_target_token_ids[0,:32]}')

            # turn them into embeddings
            t0c = self.embedder(concept_lvl_target_token_ids)

            # combine the token embeddings into concepts
            t0c = self.embedding_combiner(t0c, lvl_combo)
            #print(f't0c: {t0c.shape}\n{t0c[0,0:self.combine_factor,:4]}')
            
            # append to tuple
            targets += (t0c,)
        
        return targets
        
    @log_io
    def generate(self,
                 prompt: str,
                 output_len: int = 1, # the model will output 1 token by default
                 temperature: float = 1.0, # 1.0 would be no effect
                 top_p: float = 1.0,
                 top_k: int = config.vocab_size,
                ) -> str: 
        """ Wrapper around sampler() that deals with manipulation of the sequence """
        # encoding the prompt into token indices
        tokens = self.tokenizer.encode(prompt)

        # turning it into the right tensor shape
        tokens = torch.tensor(tokens, device=config.device).unsqueeze(0)
        
        # we wouldn't want to go past the maximum context length we trained on
        if len(tokens) + output_len > self.config.max_seq_len:
            output_len = self.max_seq_len - len(tokens)
            print("capping output at maximum sequence length")

        for i in range(output_len):
            # get the model's output logits and ignore the loss, which would be a NoneType object
            logits, _ = self(tokens[:,:self.max_seq_len])
            
            next_token = self.Sampler(logits, temperature, top_p, top_k)

            # add our new token to the sequence
            tokens = torch.cat((tokens, next_token), dim=1)

        # resets this variable so that the corresponding warning in Body.concept_matchup can come up next time we perform inference
        global cvec_warning 
        cvec_warning = False

        # decode our list of tokens to an actual string
        return self.tokenizer.decode(tokens.squeeze(0).tolist())

    @torch.no_grad() # no need to keep track of gradients during inference
    @log_io
    def Sampler(
        self,
        logits: torch.Tensor, # shape (batch_size, input_len, vocab_size)
        temperature: float, # controls how boring vs random the outputs should be
        top_p: float, # the maximum cumulative probability of output options we're willing to consider
        top_k: int, # the maximum number of output options we're willing to consider
    ) -> torch.Tensor:
        
        # Select the last element for each sequence & apply temperature scaling
        logits = logits[:,-1,:].div_(temperature) # -> (batch_size, vocab_size)
        
        # Calculate probabilities with softmax.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float) # dim=-1 is the vocab_size dimension that we calculate along
        #print('first probs: ', probs)
        
        # sort the probabilities to for use in top-p & top-k. both are (batch_size, vocab_size)
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)

        ### calculating top-p
        # creates same-size tensor of cumulatve probabilities instead of indivdiual probs
        probs_sum = torch.cumsum(probs_sort, dim=-1) 
        # mask where 0's are top-p selections & 1's are to be excluded
        top_ps_mask = (probs_sum - probs_sort) > top_p
        # the original probabilities with excluded tokens changed to 0.0
        probs_sort = torch.where(top_ps_mask, 0, probs_sort) 

        ### calculating top_k
        # create a shape (vocab_size) tensor that just iterates up by 1's
        top_ks_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device) 
        # expand our mask along the batch_size dimension to become size (batch_size, vocab_size)
        top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
        # top_ks is a list of integers. we keep whichever entries in top_ks_mask are greater than their corresponding entries in top_ks
        top_ks_mask = top_ks_mask >= top_k

        # we'll be combining top-p with top-k and using whichever gives us fewer tokens. a very conservative approach
        # this trims probs_sort to also fit within our top_k requirement
        probs_sort = torch.where(top_ks_mask, 0, probs_sort)

        # Re-normalization so that total probabilities add up to 1
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        
        # now we rearrange the modified probabilities in probs_sort back to their original order according to probs_idx
        probs = torch.gather(probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
        #print('probs after topp & k: ', probs)
        
        # samples from the distribution
        next_token_id = torch.multinomial(probs, num_samples=1)
        
        return next_token_id # returns the predicted token

## demonstration/debugging

### training w/ regular length sequences

In [40]:
hold = config.levels
config.levels = 2
module = Model(config, tokenizer)
module.enable_logging()

### enabling logging for sub-modules
module.embedding_combiner.enable_logging()
#module.body.enable_logging()
module.concept_loss_fn.enable_logging()

### disabling logging for sub-functions
#module.disable_function_logging('create_x0s')
#module.disable_function_logging('create_targets')
#module.disable_function_logging('generate')
#module.disable_function_logging('sampler')

token_ids = torch.randint(config.vocab_size, (32, config.max_seq_len + config.combine_factor))
input_token_ids = token_ids[:,:config.max_seq_len]
target_token_ids = token_ids
output, loss = module(input_token_ids, target_token_ids)
config.levels = hold
del hold, module, token_ids, input_token_ids, target_token_ids, output, loss
module = None


Inputs:
Tensor 'input_token_ids' shape: torch.Size([32, 128])
Tensor 'target_token_ids' shape: torch.Size([32, 132])

Inputs:
Tensor 'input_token_ids' shape: torch.Size([32, 128])

Inputs:
Tensor 'tensor' shape: torch.Size([32, 128, 32])
Integer 'combine_factor': Value=4

Outputs:
Tensor 'output' shape: torch.Size([32, 32, 32])

Outputs:
Tensor 'output[0]' shape: torch.Size([32, 128, 32])
Tensor 'output[1]' shape: torch.Size([32, 32, 32])

Inputs:
Tensor 'target_token_ids' shape: torch.Size([32, 132])
Integer 'input_len': Value=128
target_token_ids: torch.Size([32, 132])
tensor([399, 323,  78, 360, 263,  85, 353,  57,  82, 347, 143,  58, 158, 195,
        357, 271, 478, 187, 283, 458, 342,  11, 390, 427, 160, 114, 317, 251,
        459, 355, 210, 391])
token_lvl_target_token_ids: torch.Size([32, 128])
tensor([323,  78, 360, 263,  85, 353,  57,  82, 347, 143,  58, 158, 195, 357,
        271, 478, 187, 283, 458, 342,  11, 390, 427, 160, 114, 317, 251, 459,
        355, 210, 391, 397])
t

In [41]:
hold = config.levels
config.levels = 3
module = Model(config, tokenizer)
module.enable_logging()

### enabling logging for sub-modules
module.embedding_combiner.enable_logging()
#module.body.enable_logging()
module.concept_loss_fn.enable_logging()

### disabling logging for sub-functions
#module.disable_function_logging('create_x0s')
#module.disable_function_logging('create_targets')
#module.disable_function_logging('generate')
#module.disable_function_logging('sampler')

token_ids = torch.randint(config.vocab_size, (32, config.max_seq_len + (config.combine_factor ** (config.levels-1))))
input_token_ids = token_ids[:,:config.max_seq_len]
target_token_ids = token_ids
output, loss = module(input_token_ids, target_token_ids)
config.levels = hold
del hold, module, token_ids, input_token_ids, target_token_ids, output, loss
module = None


Inputs:
Tensor 'input_token_ids' shape: torch.Size([32, 128])
Tensor 'target_token_ids' shape: torch.Size([32, 144])

Inputs:
Tensor 'input_token_ids' shape: torch.Size([32, 128])

Inputs:
Tensor 'tensor' shape: torch.Size([32, 128, 32])
Integer 'combine_factor': Value=4

Outputs:
Tensor 'output' shape: torch.Size([32, 32, 32])

Inputs:
Tensor 'tensor' shape: torch.Size([32, 128, 32])
Integer 'combine_factor': Value=16

Inputs:
Tensor 'tensor' shape: torch.Size([32, 128, 32])
Integer 'combine_factor': Value=4

Outputs:
Tensor 'output' shape: torch.Size([32, 32, 32])

Outputs:
Tensor 'output' shape: torch.Size([32, 8, 32])

Outputs:
Tensor 'output[0]' shape: torch.Size([32, 128, 32])
Tensor 'output[1]' shape: torch.Size([32, 32, 32])
Tensor 'output[2]' shape: torch.Size([32, 8, 32])

Inputs:
Tensor 'target_token_ids' shape: torch.Size([32, 144])
Integer 'input_len': Value=128
target_token_ids: torch.Size([32, 144])
tensor([ 21, 476, 421, 280, 382, 385, 111, 199, 434, 413, 330,  94, 455

In [42]:
hold = config.levels
config.levels = 4
module = Model(config, tokenizer)
module.enable_logging()

### enabling logging for sub-modules
#module.embedding_combiner.enable_logging()
#module.body.enable_logging()
#module.concept_loss_fn.enable_logging()

### disabling logging for sub-functions
#module.disable_function_logging('create_x0s')
#module.disable_function_logging('create_targets')
#module.disable_function_logging('generate')
#module.disable_function_logging('sampler')

token_ids = torch.randint(config.vocab_size, (32, config.max_seq_len + (config.combine_factor ** (config.levels-1))))
input_token_ids = token_ids[:,:config.max_seq_len]
target_token_ids = token_ids
output, loss = module(input_token_ids, target_token_ids)
config.levels = hold
del hold, module, token_ids, input_token_ids, target_token_ids, output, loss
module = None


Inputs:
Tensor 'input_token_ids' shape: torch.Size([32, 128])
Tensor 'target_token_ids' shape: torch.Size([32, 192])

Inputs:
Tensor 'input_token_ids' shape: torch.Size([32, 128])

Outputs:
Tensor 'output[0]' shape: torch.Size([32, 128, 32])
Tensor 'output[1]' shape: torch.Size([32, 32, 32])
Tensor 'output[2]' shape: torch.Size([32, 8, 32])
Tensor 'output[3]' shape: torch.Size([32, 2, 32])

Inputs:
Tensor 'target_token_ids' shape: torch.Size([32, 192])
Integer 'input_len': Value=128
target_token_ids: torch.Size([32, 192])
tensor([366, 358, 297, 222, 158,  23, 247, 358, 313, 279, 504, 393, 387, 370,
        425, 215, 107, 163,  14,  78, 350, 155, 260, 115, 128, 166, 125, 385,
        235, 311,  55, 138])
token_lvl_target_token_ids: torch.Size([32, 128])
tensor([358, 297, 222, 158,  23, 247, 358, 313, 279, 504, 393, 387, 370, 425,
        215, 107, 163,  14,  78, 350, 155, 260, 115, 128, 166, 125, 385, 235,
        311,  55, 138, 156])
t0: torch.Size([32, 128, 32])
tensor([[ 1.8047e+00,

### training w/ offset length sequences

In [43]:
hold = config.levels
config.levels = 2
module = Model(config, tokenizer)
module.enable_logging()

### enabling logging for sub-modules
module.embedding_combiner.enable_logging()
module.body.enable_logging()
module.concept_loss_fn.enable_logging()

### disabling logging for sub-functions
#module.disable_function_logging('create_x0s')
#module.disable_function_logging('create_targets')
#module.disable_function_logging('generate')
#module.disable_function_logging('sampler')

offset = random.randint(1, config.combine_factor-1)
print(f'randomly chosen offset: {offset}')
token_ids = torch.randint(config.vocab_size, (32, config.max_seq_len + config.combine_factor - offset))
input_token_ids = token_ids[:,:config.max_seq_len - offset]
target_token_ids = token_ids
output, loss = module(input_token_ids, target_token_ids)
config.levels = hold
del hold, module, offset, token_ids, input_token_ids, target_token_ids, output, loss
module = None

randomly chosen offset: 3

Inputs:
Tensor 'input_token_ids' shape: torch.Size([32, 125])
Tensor 'target_token_ids' shape: torch.Size([32, 129])

Inputs:
Tensor 'input_token_ids' shape: torch.Size([32, 125])

Inputs:
Tensor 'tensor' shape: torch.Size([32, 128, 32])
Integer 'combine_factor': Value=4

Outputs:
Tensor 'output' shape: torch.Size([32, 32, 32])

Outputs:
Tensor 'output[0]' shape: torch.Size([32, 128, 32])
Tensor 'output[1]' shape: torch.Size([32, 32, 32])

Inputs:
Tensor 'target_token_ids' shape: torch.Size([32, 129])
Integer 'input_len': Value=125
target_token_ids: torch.Size([32, 129])
tensor([175, 356,  55, 126, 457, 470, 335, 151,  84, 287, 473, 126, 346, 320,
         59, 476, 340,  35, 410, 248, 244, 170, 107, 458, 130, 240, 423, 457,
        326,   1, 379, 494])
token_lvl_target_token_ids: torch.Size([32, 125])
tensor([356,  55, 126, 457, 470, 335, 151,  84, 287, 473, 126, 346, 320,  59,
        476, 340,  35, 410, 248, 244, 170, 107, 458, 130, 240, 423, 457, 326,
    

In [44]:
hold = config.levels
config.levels = 3
module = Model(config, tokenizer)
module.enable_logging()

### enabling logging for sub-modules
module.embedding_combiner.enable_logging()
#module.body.enable_logging()
module.concept_loss_fn.enable_logging()

### disabling logging for sub-functions
#module.disable_function_logging('create_x0s')
#module.disable_function_logging('create_targets')
#module.disable_function_logging('generate')
#module.disable_function_logging('sampler')

offset = random.randint(1, config.combine_factor-1)
print(f'randomly chosen offset: {offset}')
token_ids = torch.randint(config.vocab_size, (32, config.max_seq_len + (config.combine_factor ** (config.levels-1)) - offset))
input_token_ids = token_ids[:,:config.max_seq_len - offset]
target_token_ids = token_ids
output, loss = module(input_token_ids, target_token_ids)
config.levels = hold
del hold, module, offset, token_ids, input_token_ids, target_token_ids, output, loss
module = None

randomly chosen offset: 3

Inputs:
Tensor 'input_token_ids' shape: torch.Size([32, 125])
Tensor 'target_token_ids' shape: torch.Size([32, 141])

Inputs:
Tensor 'input_token_ids' shape: torch.Size([32, 125])

Inputs:
Tensor 'tensor' shape: torch.Size([32, 128, 32])
Integer 'combine_factor': Value=4

Outputs:
Tensor 'output' shape: torch.Size([32, 32, 32])

Inputs:
Tensor 'tensor' shape: torch.Size([32, 128, 32])
Integer 'combine_factor': Value=16

Inputs:
Tensor 'tensor' shape: torch.Size([32, 128, 32])
Integer 'combine_factor': Value=4

Outputs:
Tensor 'output' shape: torch.Size([32, 32, 32])

Outputs:
Tensor 'output' shape: torch.Size([32, 8, 32])

Outputs:
Tensor 'output[0]' shape: torch.Size([32, 128, 32])
Tensor 'output[1]' shape: torch.Size([32, 32, 32])
Tensor 'output[2]' shape: torch.Size([32, 8, 32])

Inputs:
Tensor 'target_token_ids' shape: torch.Size([32, 141])
Integer 'input_len': Value=125
target_token_ids: torch.Size([32, 141])
tensor([498, 279,  63, 196, 416,  72, 256, 35

### inference

In [45]:
module = Model(config, tokenizer)
module.enable_logging()

### enabling logging for sub-modules
#module.embedding_combiner.enable_logging()
module.body.enable_logging()
#module.concept_loss_fn.enable_logging()

### disabling logging for sub-functions
#module.disable_function_logging('create_x0s')
#module.disable_function_logging('create_targets')
#module.disable_function_logging('generate')
#module.disable_function_logging('sampler')

input_token_ids = torch.randint(config.vocab_size, 
                                (32, config.max_seq_len-11))
output, loss = module(input_token_ids)
del module, input_token_ids, output, loss


Inputs:
Tensor 'input_token_ids' shape: torch.Size([32, 117])

Inputs:
Tensor 'input_token_ids' shape: torch.Size([32, 117])

Outputs:
Tensor 'output[0]' shape: torch.Size([32, 120, 32])
Tensor 'output[1]' shape: torch.Size([32, 30, 32])

Inputs:
Tuple 'x0s':
    Tensor 'x0s[0]' shape: torch.Size([32, 120, 32])
    Tensor 'x0s[1]' shape: torch.Size([32, 30, 32])
Other-type 'targets': Type=NoneType, Value=None
Integer 'cvec_topk': Value=False
Float 'cvec_greedy': Value=1.0

Inputs:
Tuple 'x0s':
    Tensor 'x0s[0]' shape: torch.Size([32, 120, 32])
    Tensor 'x0s[1]' shape: torch.Size([32, 30, 32])
Other-type 'cvec_topk': Type=NoneType, Value=None
Integer 'cvec_greedy': Value=False
Float 'cvec_temp': Value=1.0

Inputs:
Tensor 'x' shape: torch.Size([32, 30, 32])
Integer 'i': Value=0
Integer 'c': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 30, 32])

Inputs:
Tensor 'x' shape: torch.Size([32, 31, 32])
Integer 'i': Value=0
Integer 'c': Value=False

Outputs:
Tensor 'output' s

In [46]:
module = Model(config, tokenizer)
module.enable_logging()

### enabling logging for sub-modules
#module.embedding_combiner.enable_logging()
module.body.enable_logging()
#module.concept_loss_fn.enable_logging()

### disabling logging for sub-functions
#module.disable_function_logging('create_x0s')
#module.disable_function_logging('create_targets')
#module.disable_function_logging('generate')
#module.disable_function_logging('sampler')

input_token_ids = torch.randint(config.vocab_size, 
                                (32, config.max_seq_len - 15))
output, loss = module(input_token_ids)
del module, input_token_ids, output, loss


Inputs:
Tensor 'input_token_ids' shape: torch.Size([32, 113])

Inputs:
Tensor 'input_token_ids' shape: torch.Size([32, 113])

Outputs:
Tensor 'output[0]' shape: torch.Size([32, 116, 32])
Tensor 'output[1]' shape: torch.Size([32, 29, 32])

Inputs:
Tuple 'x0s':
    Tensor 'x0s[0]' shape: torch.Size([32, 116, 32])
    Tensor 'x0s[1]' shape: torch.Size([32, 29, 32])
Other-type 'targets': Type=NoneType, Value=None
Integer 'cvec_topk': Value=False
Float 'cvec_greedy': Value=1.0

Inputs:
Tuple 'x0s':
    Tensor 'x0s[0]' shape: torch.Size([32, 116, 32])
    Tensor 'x0s[1]' shape: torch.Size([32, 29, 32])
Other-type 'cvec_topk': Type=NoneType, Value=None
Integer 'cvec_greedy': Value=False
Float 'cvec_temp': Value=1.0

Inputs:
Tensor 'x' shape: torch.Size([32, 29, 32])
Integer 'i': Value=0
Integer 'c': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 29, 32])

Inputs:
Tensor 'x' shape: torch.Size([32, 30, 32])
Integer 'i': Value=0
Integer 'c': Value=False

Outputs:
Tensor 'output' s

# Instantiate a brand new model

In [16]:
model = Model(config, tokenizer).to(config.device)

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

print(model)

2639.36 K parameters
Model(
  (embedder): Embedding(512, 128)
  (embedding_combiner): CombineEmbeddings(
    (operation_chain): Sequential(
      (reshape_0): ReshapeModule()
      (mlp_post_reshape_1): MLP(
        (gate_proj): Linear(in_features=512, out_features=1024, bias=True)
        (up_proj): Linear(in_features=512, out_features=1024, bias=True)
        (down_proj): Linear(in_features=1024, out_features=128, bias=True)
      )
      (norm_2): Norm()
    )
  )
  (body): Body(
    (layers): ModuleList(
      (0-7): 8 x Layer(
        (pre_self_mqa_norm): Norm()
        (self_mqa): selfMQA()
        (pre_cross_mqa_x_norm): Norm()
        (pre_cross_mqa_c_norm): Norm()
        (cross_mqa): crossMQA()
        (pre_mlp_norm): Norm()
        (mlp): MLP(
          (gate_proj): Linear(in_features=128, out_features=256, bias=True)
          (up_proj): Linear(in_features=128, out_features=256, bias=True)
          (down_proj): Linear(in_features=256, out_features=128, bias=True)
        )

# Training

In [17]:
# Train and test splits
data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be our training dataset, the rest for validation
train_data = data[:n]
val_data = data[n:]

In [18]:
# data loading for training which generates a small batch of data of inputs x and targets y
def get_batch(split, batch_size):
    # whether we grab from our training or validation dataset
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - (config.max_seq_len + (config.combine_factor ** (config.levels-1))), (batch_size,))
    # some training batches need to be offset so it learns how to use the padding vector
    offset = random.randint(0, config.combine_factor-1)
    x = torch.stack([data[i:i+config.max_seq_len - offset] for i in ix])
    ### i actually need the y tensor to be + (config.combine_factor ** (config.levels-1)) to fit the future concepts
    y = torch.stack([data[i+1:i+1+(config.max_seq_len + (config.combine_factor ** (config.levels-1))) - offset] for i in ix])
    x, y = x.to(config.device), y.to(config.device)
    return x, y

In [19]:
@torch.no_grad()
def estimate_loss(model, batch_size, eval_iters = 5): # to estimate loss during the training loop
    out = {}
    model.eval() # sets model to eval mode
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, batch_size)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train() # just resets to training mode
    return out

In [20]:
# create a PyTorch optimizer
# this is not what they used, but this learning rate & weight decay work for our tiny minGemma
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate_init, weight_decay=config.weight_decay)

# how long we want to train for
config.max_iters = 100

# Learning rate scheduler setup
lr_start = config.learning_rate_init
lr_end = config.learning_rate_end  # Final learning rate after decay
lr_lambda = lambda iter: (lr_end + (lr_start - lr_end) * (1 - iter / config.max_iters))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# how often we want to check & see how our loss is doing
eval_interval = 10

# batch size to use
config.batch_size = 32

# ---- BOOKMARK -----
so it looks like the wrong number of target concept vectors are getting created during training. once i've fixed that i'd like to confirm training & inference are working and then switch to the new dataset

In [21]:
start_time = time.time()

# Enable anomaly detection. uncomment these lines if you need to do extensive debugging
#torch.autograd.set_detect_anomaly(True)

for iter in range(config.max_iters):

    # sample a batch of data
    xb, yb = get_batch('train', config.batch_size)
    
    # train
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

    # Update the learning rate
    scheduler.step()
    
    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == config.max_iters - 1:
        current_time = time.time()
        elapsed_time = current_time - start_time
        losses = estimate_loss(model, config.batch_size)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time elapsed: {elapsed_time:.2f} seconds")

# Disable anomaly detection after the training loop
#torch.autograd.set_detect_anomaly(False)

step 0: train loss 112.2529, val loss 112.0210, time elapsed: 0.97 seconds
step 10: train loss 16.2049, val loss 16.3779, time elapsed: 14.41 seconds
step 20: train loss 9.7793, val loss 9.7198, time elapsed: 27.87 seconds
step 30: train loss 7.8275, val loss 7.8874, time elapsed: 41.20 seconds
step 40: train loss 6.8737, val loss 6.8584, time elapsed: 54.32 seconds
step 50: train loss 6.1504, val loss 6.2550, time elapsed: 67.58 seconds
step 60: train loss 5.7682, val loss 5.7188, time elapsed: 80.71 seconds
step 70: train loss 5.3119, val loss 5.4951, time elapsed: 93.80 seconds
step 80: train loss 5.2514, val loss 5.2149, time elapsed: 106.94 seconds
step 90: train loss 5.1099, val loss 5.1737, time elapsed: 120.10 seconds
step 99: train loss 5.0733, val loss 5.0451, time elapsed: 132.55 seconds


# Saving your model

In [22]:
name = f'models/{model.__class__.__name__}_{time.strftime("%Y-%m-%d|%H-%M-%S")}'
torch.save(model.state_dict(), f'{name}.pth')

# Convert the dataclass object to a dictionary
config_dict = asdict(config)

# Serialize the dictionary to a JSON file
with open(f'{name}.json', 'w') as f:
    json.dump(config_dict, f)

# Load a Pretrained Model

In [None]:
name = '?'

# Deserialize the JSON file back to a dictionary
with open(f'models/{name}.json', 'r') as f:
    config_dict = json.load(f)

# Convert the dictionary back to a dataclass object
config = Config(**config_dict)

# Initialize a blank model
model = Model(config, tokenizer).to(config.device)  

# here's the path to a minGemma model that i've trained with roughly 1m parameters
path = f'models/{name}.pth'

# Load the saved state dictionary
model.load_state_dict(torch.load(path)) 
# REMEMBER TO CHANGE VALUES IN CONFIG TO MATCH THE MODEL YOU'VE LOADED

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

# If you only plan to do inference, switch to evaluation mode
model.eval()

# If you plan to continue training the model, switch to training mode
#model.train()

# Inference

In [22]:
model.eval()
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou Romeo?\n" # the classic line
max_useable_output_len = config.max_seq_len - len(input_str)
output = model.generate(input_str, output_len = max_useable_output_len, temperature=1.0)
print(output)

JULIET:
O Romeo, Romeo! wherefore art thou Romeo?
ds,  ye   es nERoron bs     n  br ningNs lth A weing
AObird ,
 saidceheer  ararmydbca oncee t    r-look t  
