# v11.3

primary todo's:
- [ ] copy & paste stuff that doesn't need to be changed
- [ ] copy, paste, & remove references to matryoshka embeddings for modules w/ em
- [ ] bring over todo items that haven't been completed from v11.2 and from notes app
- [ ] setup single attention mechanism module for both self & cross? i feel like separate is prolly still easier
- [ ] setup single decoder layer body
- [ ] mess with model.forward() to work with new setup
- [ ] mess with model.inference() to work with new setup

other todo's:
- [ ] setup concept loss to use MSE and COS simultaneously
- [ ] get predictive attention mask working in crossMQA

# Setup

In [1]:
# my virtual environments are rarely properly connected to jupyter so this fixes that
import sys
import os
# Locate ./venv relative to wherever the notebook kernel was started
current_dir = os.getcwd()
venv_dir = os.path.join(current_dir, 'venv')
# "<major>.<minor>" of the running interpreter, which names the venv's lib sub-directory
python_version = '.'.join(str(part) for part in sys.version_info[:2])
# i.e. venv/lib/python<major>.<minor>/site-packages
site_packages_path = os.path.join(venv_dir, 'lib', f'python{python_version}', 'site-packages')
# appended (not prepended), so anything already importable keeps priority
sys.path.append(site_packages_path)

In [2]:
# Importing pytorch
import torch
import torch.nn as nn
from torch.nn import functional as F

# imports for the debugging/demonstration setup
import functools
import inspect

# imports for the tokenizer
from tokenizer import SimpleTokenizer, loaded_stoi, loaded_merges

# Imports used for the config
from dataclasses import dataclass
from typing import Optional

# Imports used for the model
import re
from typing import Any, List, Sequence, Tuple, Union

# used for training
import random
import time

# used to save & load models
import json
from dataclasses import asdict

# Demonstration/Debugging wrapper

In [3]:
# this function will be used throughout for debugging/demonstration purposes
# using this is way cleaner than cluttering up our code with print statements
def log_io(func):
    """Wrap a method so that each call prints a summary of its inputs and outputs.

    The object the method belongs to must expose two attributes:
    `logging_enabled` (truthy to turn printing on) and
    `disabled_logging_functions` (a collection of method names to keep quiet).
    When logging is off, or off for this particular method, the call is passed
    straight through and nothing is printed.

    Tensors are summarised by shape, tuples are expanded element by element, and
    bools/ints/floats/anything else are printed with their value.
    """
    # resolved once at decoration time; lets us pair every passed value with its real parameter name
    signature = inspect.signature(func)

    def log_item(item, name, level=0, is_root=False):
        indent = "    " * level
        if isinstance(item, torch.Tensor):
            print(f"{indent}Tensor '{name}' shape: {item.shape}")
        elif isinstance(item, tuple):
            if is_root and level == 0:
                # Root level tuple, don't print it as a tuple unless it's a "true" tuple
                for idx, sub_item in enumerate(item):
                    log_item(sub_item, f"{name}[{idx}]", level)
            else:
                print(f"{indent}Tuple '{name}':")
                for idx, sub_item in enumerate(item):
                    log_item(sub_item, f"{name}[{idx}]", level + 1)
        elif isinstance(item, bool):
            # bool is a subclass of int, so it must be tested first or flags show up as "Integer"
            print(f"{indent}Boolean '{name}': Value={item}")
        elif isinstance(item, int):
            print(f"{indent}Integer '{name}': Value={item}")
        elif isinstance(item, float):
            print(f"{indent}Float '{name}': Value={item}")
        else:
            print(f"{indent}Other-type '{name}': Type={type(item).__name__}, Value={item}")

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # Check if logging is enabled globally and for the specific function
        if not self.logging_enabled or func.__name__ in self.disabled_logging_functions:
            return func(self, *args, **kwargs)

        print(f"\n{'='*10}Entering {self.__class__.__name__}.{func.__name__}{'='*10}")
        print("Inputs:")
        # bind() matches positional AND keyword arguments to their parameter names, so a
        # keyword that skips over a defaulted parameter is still labelled correctly.
        # Only arguments that were actually passed are listed (defaults are not filled in).
        bound = signature.bind(self, *args, **kwargs)
        passed = list(bound.arguments.items())[1:]  # Excluding 'self'
        for name, value in passed:
            log_item(value, name)

        result = func(self, *args, **kwargs)
        print("\nOutputs:")
        if isinstance(result, tuple):
            log_item(result, "output", is_root=True)
        else:
            log_item(result, "output")

        print(f"{'='*10}Exiting {self.__class__.__name__}.{func.__name__}{'='*10}")
        return result
    return wrapper

# Config & tokenizer

In [4]:
# load the dataset
# the whole file is read into memory as a single string; utf-8 is given explicitly
# so decoding doesn't depend on the platform's default encoding
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# and the tokenizer
# built from the two objects imported from tokenizer.py (presumably the string-to-index
# vocabulary and the merge rules -- confirm against tokenizer.py).
# Config below reads tokenizer.vocab_len for its default vocab_size.
tokenizer = SimpleTokenizer(loaded_stoi, loaded_merges)

In [90]:
@dataclass # a class meant specifically to just hold data
class Config:
    """ 
    The default configuration & hyperparameters for my next-concept predictor

    Every annotated attribute below is a dataclass field (settable through the
    constructor and included in repr()/asdict()). Head-count divisibility is
    checked in __post_init__, so it is enforced for every instance that gets
    constructed rather than only for the class-level defaults.
    """
    ### boring hyperparameters
    vocab_size: int = tokenizer.vocab_len
    max_seq_len: int = 256
    num_layers: int = 6
    sa_q_heads: int = 2
    sa_kv_heads: int = 1
    attn_bias: bool = False
    embed_dim: int = 256
    mlp_multiplier: int = 4
    sa_head_dim: int = 64
    theta: float = 100.0
    # NOTE(review): device, eps and norm_type (below) have no type annotation, so @dataclass
    # treats them as plain class attributes, not fields: they can't be passed to Config(...)
    # and they are absent from repr()/asdict(). Left as-is to keep the constructor and any
    # saved configs unchanged -- confirm whether that's intended.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dropout_rate: float = 0.05
    eps = 1e-6
    norm_affine: bool = True # whether norms should have a linear & bias after them
    norm_type = "RMSNorm"  # Options are RMSNorm, CosineNorm and LayerNorm

    ### Concept embedding vectors
    levels: int = 3
    combo: int = 4 # how many lower-level tokens/concepts to combine into the next level's concept
    concept_loss: str = "cos" # options are 'mae', 'mse', and 'cos'(default)
    @property
    def seq_len_list(self):
        """Sequence length at each level: max_seq_len shrunk by a factor of combo per level, lowest level first."""
        return [self.max_seq_len // (self.combo ** i) for i in range(self.levels)]
    level_loss_weight: float = 1.0 # how much to discount each higher level in the loss function compared to the last

    ### Dualcoder cross-attention
    # how many times to do the cross-attention
    # NOTE: these three defaults are bound to the self-attention *defaults* when the class body
    # runs; passing e.g. sa_q_heads=4 to the constructor does not change ca_q_heads.
    ca_q_heads: int = sa_q_heads
    ca_kv_heads: int = sa_kv_heads
    ca_head_dim: int = sa_head_dim
    ca_use_RoPE: bool = False # True: expands out k & v tensors to be usable with rope. False: leaves k & v same size but no positional encodings
    predictive_mask: bool = True # True: upper-triangular predictive mask to focus model's attention. False: no mask like a regular encoder

    ### assertions
    def __post_init__(self):
        """Validate the values this instance was actually built with.

        Raises:
            ValueError: if the query head count is not a multiple of the key-value
                head count, for either self-attention or cross-attention.
        """
        if self.sa_q_heads % self.sa_kv_heads != 0:
            raise ValueError('the number of query heads must be divisible by the number of key-value heads in self-attention')
        if self.ca_q_heads % self.ca_kv_heads != 0:
            raise ValueError('the number of query heads must be divisible by the number of key-value heads in cross-attention')

config = Config()
print(config)
print(f"sequence length of each model: {config.seq_len_list}")
print(f"loss discounts starting from lowest level: {[config.level_loss_weight**i for i in range(config.levels)]}")

Config(vocab_size=128, max_seq_len=256, num_layers=6, sa_q_heads=2, sa_kv_heads=1, attn_bias=False, embed_dim=256, mlp_multiplier=4, sa_head_dim=64, theta=100.0, dropout_rate=0.05, norm_affine=True, levels=3, combo=4, concept_loss='cos', level_loss_weight=1.0, ca_q_heads=2, ca_kv_heads=1, ca_head_dim=64, ca_use_RoPE=False, predictive_mask=True)
sequence length of each model: [256, 64, 16]
loss discounts starting from lowest level: [1.0, 1.0, 1.0]


# RoPE

I like the idea of pre-computing RoPE embeddings, but at this point I don't think it's worth the effort: I'd have to not only use this code but also pipe the two instantiations of this class from `Body` all the way through to `selfMQA` and `crossMQA`, and I'm not even sure it matters. I really should learn more about RoPE.

```Python
# this class has not been implemented nor even tested. on my todo list
class RoPE(nn.Module):
    def __init__(self, 
                 dim: int, 
                 max_seq_len:int = config.max_seq_len, 
                 device: str = config.device):
        super().__init__()
        # Validate that dim is even since we split it by 2 for real and imaginary parts
        if dim % 2 != 0: raise ValueError("Dimension 'dim' must be an even number.")
            
        # Precompute frequencies based on configuration
        theta = config.theta if hasattr(config, 'theta') else 10000.0
        
        freqs = 1.0 / (theta ** (torch.arange(0, config.dim, 2, device=config.device).float() / config.dim))
        t = torch.arange(config.max_seq_len, device=config.device)
        freqs = torch.outer(t, freqs).to(config.device).float()
        
        # Register as buffer to prevent gradient tracking
        self.register_buffer('freqs_cis', torch.polar(torch.ones_like(freqs), freqs)) # complex64

    def forward(self, x):
        # Apply rotary embeddings to the input tensor
        x_ = torch.view_as_complex(torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
        x_out = torch.view_as_real(x_ * self.freqs_cis.unsqueeze(0)).type_as(x)
        x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
        x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], -1).transpose(1, 2)
        return x_out
```

```Python
# demonstration/debugging
module = RoPE(dim=10)
module.enable_logging()
output = module(torch.randn(, 5))
del module, output
```

In [91]:
def RoPE(x: torch.Tensor, dim: int, theta: float = 10000.0) -> torch.Tensor:
    """Rotate a query or key tensor by position-dependent angles (rotary embedding).

    Args:
        x: tensor whose axis 1 is the sequence position and whose last axis is the
           per-head feature axis; callers in this notebook pass
           [batch, seq_len, heads, head_dim].
        dim: size of the feature axis being rotated; must be even.
        theta: base of the geometric frequency schedule.

    Returns:
        A tensor with the same layout and dtype as `x`.

    Raises:
        ValueError: if `dim` is odd.
    """
    # the feature axis is split in half (real part / imaginary part), so it has to be even
    if dim % 2 != 0:
        raise ValueError("Dimension 'dim' must be an even number.")

    # Angles are rebuilt on every call from the actual sequence length, on x's device.
    # it's important to train on a wide variety of sequence lengths within your context length so that the model learns to generalize
    positions = torch.arange(x.size(1), device=x.device)
    exponents = torch.arange(0, dim, 2, device=x.device).float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    angles = torch.outer(positions, inv_freq).float()
    rotations = torch.polar(torch.ones_like(angles), angles)  # unit-magnitude complex numbers, complex64

    # Pair the first half of the features with the second half as (real, imag)
    halves = torch.chunk(x.transpose(1, 2).float(), 2, dim=-1)
    as_complex = torch.view_as_complex(torch.stack(halves, dim=-1))

    # Complex multiplication performs the rotation; unsqueeze lets it broadcast over batch
    rotated = torch.view_as_real(as_complex * rotations.unsqueeze(0)).type_as(x)

    # Undo the pairing: all real parts first, then all imaginary parts, along the feature axis
    real_part, imag_part = torch.chunk(rotated, 2, dim=-1)
    merged = torch.cat((real_part, imag_part), dim=-2)
    d0, d1, d2 = merged.shape[:3]
    return merged.reshape(d0, d1, d2, -1).transpose(1, 2)

# selfMQA

In [92]:
class selfMQA(nn.Module):
    """Causal self-attention in which `sa_kv_heads` key/value heads are shared by `sa_q_heads` query heads.

    Queries and keys get rotary position embeddings; a lower-triangular mask keeps
    each position from attending to later positions. Dropout is driven by the
    `training` flag passed to forward(), not by the module's train()/eval() state.

    The projections are bare weight matrices with no bias term, so `config.attn_bias`
    is not consulted here.
    """
    def __init__(self, config: Config):
        super().__init__()

        self.sa_q_heads = config.sa_q_heads
        self.sa_kv_heads = config.sa_kv_heads
        assert self.sa_q_heads % self.sa_kv_heads == 0
        self.num_queries_per_kv = self.sa_q_heads // self.sa_kv_heads

        self.embed_dim = config.embed_dim
        self.sa_head_dim = config.sa_head_dim
        self.theta = config.theta
        self.dropout_rate = config.dropout_rate

        # one fused matrix holding the q, k and v projections side by side (see weight_splicing)
        self.Wqkv = nn.Parameter(torch.Tensor(self.embed_dim, (self.sa_q_heads + 2 * self.sa_kv_heads) * self.sa_head_dim))
        nn.init.uniform_(self.Wqkv, -((1 / self.embed_dim) ** 0.5), (1 / self.embed_dim) ** 0.5)

        # output projection back to embed_dim
        self.Wo = nn.Parameter(torch.Tensor(self.sa_q_heads * self.sa_head_dim, self.embed_dim))
        nn.init.uniform_(self.Wo, -((1 / (self.sa_q_heads * self.sa_head_dim)) ** 0.5), (1 / (self.sa_q_heads * self.sa_head_dim)) ** 0.5)

        # for our attention mask we'll create a boolean mask that'll later be turned into large negative values
        causal = torch.tril(torch.ones((config.max_seq_len, config.max_seq_len), dtype=torch.uint8))
        causal = causal.view(1, 1, config.max_seq_len, config.max_seq_len).to(dtype=torch.bool)
        # Registered as a buffer so that module.to(device) moves the mask together with the
        # weights (a plain tensor attribute would stay on the CPU and torch.where would then be
        # handed tensors on different devices). persistent=False keeps it out of state_dict, so
        # checkpoints look exactly as they did before.
        self.register_buffer('mask', causal, persistent=False)

        self.logging_enabled = False
        self.disabled_logging_functions = set()
    def enable_logging(self):
        self.logging_enabled = True
    def disable_logging(self):
        self.logging_enabled = False
    def disable_function_logging(self, func_name):
        self.disabled_logging_functions.add(func_name)
    def enable_function_logging(self, func_name):
        self.disabled_logging_functions.discard(func_name)

    @log_io
    def forward(self, x: torch.Tensor, training: bool = False) -> torch.Tensor:
        """Run causal self-attention over x.

        Args:
            x: [batch_size, input_len, embed_dim] hidden states.
            training: when True, dropout is applied to the q/k/v projections, the
                attention probabilities and the output projection.

        Returns:
            [batch_size, input_len, embed_dim]
        """
        # Extracts batch size and input sequence length from the hidden states tensor.
        batch_size, input_len, _ = x.shape

        # splicing our primary projection to get the correct sub-matrices
        Wq, Wk, Wv = self.weight_splicing(self.Wqkv)
        # technically self.weight_splicing has access to self.Wqkv & Wo but this way our debugger can see them

        # Applies the linear projection to the hidden state to retrieve our q, k & v projections
        xq = F.dropout(x @ Wq, p=self.dropout_rate, training=training) # also dropout if we're training
        xk = F.dropout(x @ Wk, p=self.dropout_rate, training=training)
        xv = F.dropout(x @ Wv, p=self.dropout_rate, training=training)

        # Reshapes each of the Q, K, and V tensors to separate the heads and align the dimensions for attention operations.
        xq = xq.view(batch_size, -1, self.sa_q_heads, self.sa_head_dim)
        xk = xk.view(batch_size, -1, self.sa_kv_heads, self.sa_head_dim)
        xv = xv.view(batch_size, -1, self.sa_kv_heads, self.sa_head_dim)

        # Applies rotary positional embeddings to queries and keys to incorporate positional information.
        xq, xk = self.RoPE(xq, xk)

        # If the number of KV heads is different from the number of query heads, adjusts keys and values to match the query heads count.
        if self.sa_kv_heads != self.sa_q_heads:
            xk, xv = self.match_headcount(xk, xv) # [batch_size, input_len, sa_q_heads, sa_head_dim]

        # Transposes Q, K, and V tensors to align them for the batch matrix multiplication in attention calculation.
        xq = xq.transpose(1, 2) # [batch_size, sa_q_heads, input_len, sa_head_dim]
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        # Calculates attention scores by performing a batch matrix multiplication between queries and keys, followed by scaling.
        logits = self.attend(xq, xk) # [batch_size, sa_q_heads, input_len, input_len]

        # Applies the lower-triangular mask to the attention logits
        logits = self.apply_mask(logits, input_len)

        # applies values to get final output
        output = self.calc_output(logits, xv, batch_size, input_len, training)

        # Applies the final linear projection to the attention output, mapping it back to self.embed_dim
        return F.dropout(output @ self.Wo, p=self.dropout_rate, training=training) # also dropout if we're training

    @log_io
    def weight_splicing(self, Wqkv):
        """Split the fused projection along its last axis into (Wq, Wk, Wv)."""
        Wq, Wk, Wv = Wqkv.split([self.sa_q_heads * self.sa_head_dim,
                                 self.sa_kv_heads * self.sa_head_dim,
                                 self.sa_kv_heads * self.sa_head_dim],dim = -1)
        return Wq, Wk, Wv

    @log_io
    def RoPE(self, xq, xk):
        """Apply the module-level RoPE function to queries and keys; values are never rotated."""
        xq = RoPE(xq, self.sa_head_dim, self.theta)
        xk = RoPE(xk, self.sa_head_dim, self.theta)
        return xq, xk

    @log_io
    def match_headcount(self, xk, xv):
        """Repeat each key/value head num_queries_per_kv times along the head axis (dim 2)."""
        xk = torch.repeat_interleave(xk, self.num_queries_per_kv, dim=2)
        xv = torch.repeat_interleave(xv, self.num_queries_per_kv, dim=2)
        return xk, xv

    @log_io
    def attend(self, xq, xk):
        """Scaled dot-product logits: q @ k^T scaled by sa_head_dim ** -0.5."""
        return torch.matmul(xq, xk.transpose(2, 3)) * (self.sa_head_dim ** -0.5)

    @log_io
    def apply_mask(self, logits, input_len):
        """Replace logits at masked-out (future) positions with a very large negative number.

        input_len must not exceed the max_seq_len the mask was built with, otherwise
        the sliced mask is smaller than logits and expand_as fails.
        """
        return torch.where(self.mask[..., :input_len, :input_len].expand_as(logits),
                           logits,
                           torch.tensor(-1e30, device=logits.device, dtype=logits.dtype))

    @log_io
    def calc_output(self, logits, xv, batch_size, input_len, training):
        """Softmax the logits, weight the values with them, and merge the heads back together."""
        # Applies softmax to the logits to obtain attention probabilities
        scores = F.softmax(logits, dim=-1)

        # also applies dropout if we're training
        scores = F.dropout(scores, p=self.dropout_rate, training=training)

        # Computes the weighted sum of values based on the attention scores to obtain the output of the attention mechanism.
        output = scores @ xv # [batch_size, sa_q_heads, input_len, sa_head_dim]

        # Reshapes the attention output to match the expected output dimensions, combining the heads back into the hidden dimension.
        return output.transpose(1, 2).contiguous().view(batch_size, input_len, -1)

### demonstration/debugging
I've set up these little snippets after each nn.Module to help you see what's happening and to help with my own debugging.

In [93]:
# Build a throwaway selfMQA just for this demonstration
attn_demo = selfMQA(config)

# Logging starts out switched off on every module, so turn it on
attn_demo.enable_logging()

### Optionally disabling printing for sub-functions
#attn_demo.disable_function_logging('weight_splicing')
#attn_demo.disable_function_logging('RoPE')
#attn_demo.disable_function_logging('match_headcount')
#attn_demo.disable_function_logging('attend')
#attn_demo.disable_function_logging('apply_mask')
#attn_demo.disable_function_logging('calc_output')

# Push a random batch through forward(); every logged method prints its input/output shapes
demo_batch = torch.randn(32, config.max_seq_len, config.embed_dim)
attn_out = attn_demo(demo_batch)#, training=True)

# Switching logging back off.
# Not strictly needed because the object is about to be thrown away, but this is how it's done
attn_demo.disable_logging()

# free the memory in case we train later in the session
del attn_demo, demo_batch, attn_out


Inputs:
Tensor 'x' shape: torch.Size([32, 256, 256])

Inputs:
Tensor 'Wqkv' shape: torch.Size([256, 256])

Outputs:
Tensor 'output[0]' shape: torch.Size([256, 128])
Tensor 'output[1]' shape: torch.Size([256, 64])
Tensor 'output[2]' shape: torch.Size([256, 64])

Inputs:
Tensor 'xq' shape: torch.Size([32, 256, 2, 64])
Tensor 'xk' shape: torch.Size([32, 256, 1, 64])

Outputs:
Tensor 'output[0]' shape: torch.Size([32, 256, 2, 64])
Tensor 'output[1]' shape: torch.Size([32, 256, 1, 64])

Inputs:
Tensor 'xk' shape: torch.Size([32, 256, 1, 64])
Tensor 'xv' shape: torch.Size([32, 256, 1, 64])

Outputs:
Tensor 'output[0]' shape: torch.Size([32, 256, 2, 64])
Tensor 'output[1]' shape: torch.Size([32, 256, 2, 64])

Inputs:
Tensor 'xq' shape: torch.Size([32, 2, 256, 64])
Tensor 'xk' shape: torch.Size([32, 2, 256, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 2, 256, 256])

Inputs:
Tensor 'logits' shape: torch.Size([32, 2, 256, 256])
Integer 'input_len': Value=256

Outputs:
Tensor 'output' s

# MLP

In [94]:
class MLP(nn.Module):
    """Gated feed-forward block: down_proj(gelu(gate_proj(x)) * up_proj(x)), followed by dropout.

    Args:
        embed_dim: width of the input and output.
        mlp_multiplier: the hidden width is embed_dim * mlp_multiplier.
        dropout_rate: dropout probability applied to the output when forward() is
            called with training=True.
    """
    def __init__(self,
                 embed_dim: int,
                 mlp_multiplier: int,
                 dropout_rate: float = 0.1):
        super().__init__()
        hidden = embed_dim * mlp_multiplier
        self.mlp_multiplier = mlp_multiplier
        self.hidden_size = hidden
        self.dropout_rate = dropout_rate

        # gate and up both widen to the hidden size; down brings it back to embed_dim
        self.gate_proj = nn.Linear(embed_dim, hidden)
        self.up_proj = nn.Linear(embed_dim, hidden)
        self.down_proj = nn.Linear(hidden, embed_dim)

        # state read by the @log_io decorator
        self.logging_enabled = False
        self.disabled_logging_functions = set()

    def enable_logging(self):
        self.logging_enabled = True

    def disable_logging(self):
        self.logging_enabled = False

    def disable_function_logging(self, func_name):
        self.disabled_logging_functions.add(func_name)

    def enable_function_logging(self, func_name):
        self.disabled_logging_functions.discard(func_name)

    @log_io
    def forward(self, x: torch.Tensor, training: bool = False ) -> torch.Tensor:
        """Apply the gated MLP to x; dropout only happens when training is True."""
        gate = F.gelu(self.gate_proj(x))
        lifted = self.up_proj(x)
        projected = self.down_proj(gate * lifted)
        return F.dropout(projected, p=self.dropout_rate, training=training)

### demonstration/debugging

In [95]:
# throwaway MLP, with logging on so the forward pass prints its input/output shapes
mlp_demo = MLP(config.embed_dim, config.mlp_multiplier, config.dropout_rate)
mlp_demo.enable_logging()
demo_batch = torch.randn(32, config.max_seq_len, config.embed_dim)
mlp_out = mlp_demo(demo_batch)
del mlp_demo, demo_batch, mlp_out


Inputs:
Tensor 'x' shape: torch.Size([32, 256, 256])

Outputs:
Tensor 'output' shape: torch.Size([32, 256, 256])


# Norms

In [96]:
class Norm(torch.nn.Module):
    """Normalisation layer whose flavour is picked by config.norm_type.

    "CosineNorm" and "LayerNorm" select those variants; any other value falls
    back to RMSNorm. If config.norm_affine is set, a learned per-feature scale
    and shift are applied afterwards, and dropout is applied last when forward()
    is called with training=True.
    """
    def __init__(self, config):
        super().__init__()
        self.eps = config.eps
        self.affine = config.norm_affine
        self.dropout_rate = config.dropout_rate
        self.type = config.norm_type

        # Affine parameters: scale starts at one and shift at zero, so the affine step
        # is initially the identity. They are created even when affine is off.
        width = config.embed_dim
        self.w = nn.Parameter(torch.ones(width))
        self.b = nn.Parameter(torch.zeros(width))

        # state read by the @log_io decorator
        self.logging_enabled = False
        self.disabled_logging_functions = set()

    def enable_logging(self):
        self.logging_enabled = True

    def disable_logging(self):
        self.logging_enabled = False

    def disable_function_logging(self, func_name):
        self.disabled_logging_functions.add(func_name)

    def enable_function_logging(self, func_name):
        self.disabled_logging_functions.discard(func_name)

    @log_io
    def forward(self, x: torch.Tensor, training: bool = False) -> torch.Tensor:
        """Normalise x over its last dimension, then optionally apply affine and dropout."""
        # look the normaliser up by name; anything unrecognised defaults to RMSNorm
        # bc that's the most commonly used nowadays
        by_name = {"CosineNorm": self.CosineNorm, "LayerNorm": self.LayerNorm}
        normalise = by_name.get(self.type, self.RMSNorm)
        normed = normalise(x)

        if self.affine: # optional learned scale & shift
            normed = normed * self.w + self.b

        return F.dropout(normed, p=self.dropout_rate, training=training) # and dropout if we're training

    @log_io
    def CosineNorm(self, x):
        """Divide x by its L2 norm over the last dimension, i.e. project it onto the unit hypersphere."""
        # the norm is clamped from below at eps so an all-zero vector doesn't divide by zero
        length = torch.norm(x, p=2, dim=-1, keepdim=True)
        return x / length.clamp(min=self.eps)

    @log_io
    def LayerNorm(self, x): # nn.LayerNorm() exists but might as well make it from scratch if we have to do the other two
        """Subtract the mean, then divide by the standard deviation, both over the last dimension."""
        # biased variance (unbiased=False); eps is added under the square root
        centred = x - x.mean(dim=-1, keepdim=True)
        variance = x.var(dim=-1, keepdim=True, unbiased=False)
        return centred / torch.sqrt(variance + self.eps)

    @log_io
    def RMSNorm(self, x):
        """Divide x by its root-mean-square over the last dimension (no mean subtraction)."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

### demonstration/debugging

In [97]:
# throwaway Norm, with logging on so forward() and the selected normaliser print their shapes
norm_demo = Norm(config)
norm_demo.enable_logging()

### disabling printing for sub-functions
#norm_demo.disable_function_logging('RMSNorm')
#norm_demo.disable_function_logging('LayerNorm')
#norm_demo.disable_function_logging('CosineNorm')

demo_batch = torch.randn(32, config.max_seq_len, config.embed_dim)
norm_out = norm_demo(demo_batch)
del norm_demo, demo_batch, norm_out


Inputs:
Tensor 'x' shape: torch.Size([32, 256, 256])

Inputs:
Tensor 'x' shape: torch.Size([32, 256, 256])

Outputs:
Tensor 'output' shape: torch.Size([32, 256, 256])

Outputs:
Tensor 'output' shape: torch.Size([32, 256, 256])


# crossMQA

In [68]:
class crossMQA(nn.Module):
    """Cross-attention from the current level (queries) to the level above it (keys/values).

    `ca_kv_heads` key/value heads are shared by `ca_q_heads` query heads. Queries
    are projected from x; keys and values are projected from c. Positional
    (rotary) embeddings and the "predictive" mask are both optional and
    controlled by config.ca_use_RoPE and config.predictive_mask. Dropout is
    driven by the `training` flag passed to forward().
    """
    def __init__(self, config: Config):
        super().__init__()

        self.ca_q_heads = config.ca_q_heads
        self.ca_kv_heads = config.ca_kv_heads
        assert self.ca_q_heads % self.ca_kv_heads == 0
        self.num_queries_per_kv = self.ca_q_heads // self.ca_kv_heads
        self.embed_dim = config.embed_dim
        self.ca_head_dim = config.ca_head_dim
        self.sa_head_dim = config.sa_head_dim # used only for an assertion to make sure sizes will fit
        self.combo = config.combo # used only for an assertion to make sure sizes will fit
        self.max_seq_len = config.max_seq_len # kept so apply_mask uses THIS module's config, not the notebook-global one
        self.theta = config.theta
        self.use_RoPE = config.ca_use_RoPE
        self.dropout_rate = config.dropout_rate
        self.predictive_mask = config.predictive_mask

        # one fused matrix holding the q, k and v projections side by side (see weight_splicing)
        self.Wqkv = nn.Parameter(torch.Tensor(self.embed_dim, (self.ca_q_heads + 2 * self.ca_kv_heads) * self.ca_head_dim))
        nn.init.uniform_(self.Wqkv, -((1 / self.embed_dim) ** 0.5), (1 / self.embed_dim) ** 0.5)

        # output projection back to embed_dim
        self.Wo = nn.Parameter(torch.Tensor(self.ca_q_heads * self.ca_head_dim, self.embed_dim))
        nn.init.uniform_(self.Wo, -((1 / (self.ca_q_heads * self.ca_head_dim)) ** 0.5), (1 / (self.ca_q_heads * self.ca_head_dim)) ** 0.5)

        self.logging_enabled = False
        self.disabled_logging_functions = set()
    def enable_logging(self):
        self.logging_enabled = True
    def disable_logging(self):
        self.logging_enabled = False
    def disable_function_logging(self, func_name):
        self.disabled_logging_functions.add(func_name)
    def enable_function_logging(self, func_name):
        self.disabled_logging_functions.discard(func_name)

    @log_io
    def forward(self,
                x: torch.Tensor, # the current level tensor, sometimes a resid state full of tokens & sometimes concepts
                c: torch.Tensor, # the upper level tensor, always a resid state full of concept vecs
                training: bool = False
               ) -> torch.Tensor:
        """Attend from x to c.

        Args:
            x: [batch_size, input_len_x, embed_dim], source of the queries.
            c: [batch_size, input_len_c, embed_dim], source of the keys and values.
            training: when True, dropout is applied to the projections, the attention
                probabilities and the output.

        Returns:
            [batch_size, input_len_x, embed_dim]
        """
        # Extracts batch size and input sequence length from the hidden states tensor.
        batch_size, input_len_x, _ = x.shape
        batch_size_c, input_len_c, _ = c.shape
        assert batch_size == batch_size_c

        # splicing our projection to get the correct sub-matrices
        Wq, Wk, Wv = self.weight_splicing(self.Wqkv)

        # Applies the linear projection to the hidden state to retrieve our q, k & v projections
        xq = F.dropout(x @ Wq, p=self.dropout_rate, training=training) # also applies dropout if we're training
        ck = F.dropout(c @ Wk, p=self.dropout_rate, training=training)
        cv = F.dropout(c @ Wv, p=self.dropout_rate, training=training)

        # Reshapes each of the Q, K, and V tensors to separate the heads and align the dimensions for attention operations.
        xq = xq.view(batch_size, -1, self.ca_q_heads, self.ca_head_dim)
        ck = ck.view(batch_size, -1, self.ca_kv_heads, self.ca_head_dim)
        cv = cv.view(batch_size, -1, self.ca_kv_heads, self.ca_head_dim)

        # IF we want to use RoPE (doesn't fully make sense to)
        if self.use_RoPE:
            # each upper-level position is repeated so keys/values line up with the query positions
            expand = input_len_x // input_len_c
            ck = ck.repeat_interleave(expand, dim=1)
            cv = cv.repeat_interleave(expand, dim=1) # values need to be expanded for their use later on if we do this

            # Applies rotary positional embeddings to queries and keys to incorporate positional information.
            xq, ck = self.RoPE(xq, ck)

        # If the number of KV heads is different from the number of query heads, adjusts keys and values to match the query heads count.
        if self.ca_kv_heads != self.ca_q_heads:
            ck, cv = self.match_headcount(ck, cv) # [batch_size, key_len, ca_q_heads, ca_head_dim]

        # Transposes Q, K, and V tensors to align them for the batch matrix multiplication in attention calculation.
        xq = xq.transpose(1, 2) # [batch_size, ca_q_heads, input_len_x, ca_head_dim]
        ck = ck.transpose(1, 2)
        cv = cv.transpose(1, 2)

        # Calculates attention scores by performing a batch matrix multiplication between queries and keys, followed by scaling.
        logits = self.attend(xq, ck) # [batch_size, ca_q_heads, input_len_x, key_len]

        # Optionally applies the upper-triangular mask to the attention logits
        if self.predictive_mask:
            logits = self.apply_mask(logits, input_len_x, input_len_c)

        # applies values to get final output
        output = self.calc_output(logits, cv, batch_size, input_len_x, training)

        # Applies the final linear projection to the attention output, mapping it back to embed_dim.
        return F.dropout(output @ self.Wo, p=self.dropout_rate, training=training) # and dropout if we're training

    @log_io
    def weight_splicing(self, Wqkv):
        """Split the fused projection along its last axis into (Wq, Wk, Wv)."""
        Wq, Wk, Wv = Wqkv.split([self.ca_q_heads * self.ca_head_dim,
                                 self.ca_kv_heads * self.ca_head_dim,
                                 self.ca_kv_heads * self.ca_head_dim],dim = -1)
        return Wq, Wk, Wv

    @log_io
    def RoPE(self, xq, xk):
        """Apply the module-level RoPE function to queries and keys; values are never rotated."""
        xq = RoPE(xq, self.ca_head_dim, self.theta)
        xk = RoPE(xk, self.ca_head_dim, self.theta)
        return xq, xk

    @log_io
    def match_headcount(self, xk, xv):
        """Repeat each key/value head num_queries_per_kv times along the head axis (dim 2)."""
        xk = torch.repeat_interleave(xk, self.num_queries_per_kv, dim=2)
        xv = torch.repeat_interleave(xv, self.num_queries_per_kv, dim=2)
        return xk, xv

    @log_io
    def attend(self, xq, ck):
        """Scaled dot-product logits: q @ k^T scaled by ca_head_dim ** -0.5."""
        return torch.matmul(xq, ck.transpose(2, 3)) * (self.ca_head_dim ** -0.5)

    @log_io
    def apply_mask(self, logits, input_len_x, input_len_c):
        """Keep logits where key index >= query index; overwrite the rest with a very large negative number.

        The mask is rebuilt on every call, on the same device as `logits`, from the
        max_seq_len this module was constructed with, and is left on self.mask.

        NOTE(review): the mask is upper-triangular on a non-square grid, so query rows
        with index >= input_len_c have no permitted key at all; softmax over such a row
        of identical values comes out uniform. Also, when use_RoPE is on the keys have
        been repeated along the sequence axis, so the last dim of logits is presumably
        no longer input_len_c and expand_as would fail. Both look like part of the
        still-open "get predictive attention mask working" todo -- confirm intended design.
        """
        self.mask = torch.triu(torch.ones((self.max_seq_len, input_len_c), dtype=torch.uint8, device=logits.device)
                                  ).view(1, 1, self.max_seq_len, input_len_c).to(dtype=torch.bool)

        return torch.where(self.mask[..., :input_len_x, :input_len_c].expand_as(logits),
                           logits,
                           torch.tensor(-1e30, device=logits.device, dtype=logits.dtype))

    @log_io
    def calc_output(self, logits, cv, batch_size, input_len_x, training):
        """Softmax the logits, weight the values with them, and merge the heads back together."""
        # Applies softmax to the logits to obtain attention probabilities
        scores = F.softmax(logits, dim=-1)

        # also applies dropout if we're training
        scores = F.dropout(scores, p=self.dropout_rate, training=training)

        # Computes the weighted sum of values based on the attention scores to obtain the output of the attention mechanism.
        output = scores @ cv # [batch_size, ca_q_heads, input_len_x, ca_head_dim]

        # Reshapes the attention output to match the expected output dimensions, combining the heads back into the hidden dimension.
        return output.transpose(1, 2).contiguous().view(batch_size, input_len_x, -1)

### demonstration/debugging

At one point this one had been giving me trouble, so here are multiple different `config.levels` setups for ya.

In [69]:
# Temporarily override two settings on the shared config for this demo only
hold1 = config.levels
hold2 = config.predictive_mask
config.levels = 2
config.predictive_mask = False
try:
    module = crossMQA(config)
    module.enable_logging()

    ### Optionally disabling printing for sub-functions
    #module.disable_function_logging('weight_splicing')
    #module.disable_function_logging('RoPE')
    #module.disable_function_logging('match_headcount')
    #module.disable_function_logging('attend')
    #module.disable_function_logging('apply_mask')
    #module.disable_function_logging('calc_output')

    # level-0 residual state attends to the level-1 concepts (combo times shorter)
    x0 = torch.randn(32, config.max_seq_len, config.embed_dim)
    c1 = torch.randn(32, config.max_seq_len // config.combo, config.embed_dim)
    output = module(x0, c1)
finally:
    # restore the shared config even if the demo raises, so later cells never
    # silently run with the demo's overrides
    config.levels = hold1
    config.predictive_mask = hold2
del hold1, hold2, module, x0, c1, output


Inputs:
Tensor 'x' shape: torch.Size([32, 256, 256])
Tensor 'c' shape: torch.Size([32, 64, 256])

Inputs:
Tensor 'Wqkv' shape: torch.Size([256, 256])

Outputs:
Tensor 'output[0]' shape: torch.Size([256, 128])
Tensor 'output[1]' shape: torch.Size([256, 64])
Tensor 'output[2]' shape: torch.Size([256, 64])

Inputs:
Tensor 'xk' shape: torch.Size([32, 64, 1, 64])
Tensor 'xv' shape: torch.Size([32, 64, 1, 64])

Outputs:
Tensor 'output[0]' shape: torch.Size([32, 64, 2, 64])
Tensor 'output[1]' shape: torch.Size([32, 64, 2, 64])

Inputs:
Tensor 'xq' shape: torch.Size([32, 2, 256, 64])
Tensor 'ck' shape: torch.Size([32, 2, 64, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 2, 256, 64])

Inputs:
Tensor 'logits' shape: torch.Size([32, 2, 256, 64])
Tensor 'cv' shape: torch.Size([32, 2, 64, 64])
Integer 'batch_size': Value=32
Integer 'input_len_x': Value=256
Integer 'training': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 256, 128])

Outputs:
Tensor 'output' shape: torch.Size

In [70]:
# Temporarily override two settings on the shared config for this demo only
hold1 = config.levels
hold2 = config.predictive_mask
config.levels = 3
config.predictive_mask = False
try:
    module = crossMQA(config)
    module.enable_logging()

    ### Optionally disabling printing for sub-functions
    #module.disable_function_logging('weight_splicing')
    #module.disable_function_logging('RoPE')
    #module.disable_function_logging('match_headcount')
    #module.disable_function_logging('attend')
    #module.disable_function_logging('apply_mask')
    #module.disable_function_logging('calc_output')

    # level-1 concepts attend to the level-2 concepts (another factor of combo shorter)
    c1 = torch.randn(32, config.max_seq_len // (config.combo**1), config.embed_dim)
    c2 = torch.randn(32, config.max_seq_len // (config.combo**2), config.embed_dim)
    output = module(c1, c2)
finally:
    # restore the shared config even if the demo raises, so later cells never
    # silently run with the demo's overrides
    config.levels = hold1
    config.predictive_mask = hold2
del hold1, hold2, module, c1, c2, output


Inputs:
Tensor 'x' shape: torch.Size([32, 64, 256])
Tensor 'c' shape: torch.Size([32, 16, 256])

Inputs:
Tensor 'Wqkv' shape: torch.Size([256, 256])

Outputs:
Tensor 'output[0]' shape: torch.Size([256, 128])
Tensor 'output[1]' shape: torch.Size([256, 64])
Tensor 'output[2]' shape: torch.Size([256, 64])

Inputs:
Tensor 'xk' shape: torch.Size([32, 16, 1, 64])
Tensor 'xv' shape: torch.Size([32, 16, 1, 64])

Outputs:
Tensor 'output[0]' shape: torch.Size([32, 16, 2, 64])
Tensor 'output[1]' shape: torch.Size([32, 16, 2, 64])

Inputs:
Tensor 'xq' shape: torch.Size([32, 2, 64, 64])
Tensor 'ck' shape: torch.Size([32, 2, 16, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 2, 64, 16])

Inputs:
Tensor 'logits' shape: torch.Size([32, 2, 64, 16])
Tensor 'cv' shape: torch.Size([32, 2, 16, 64])
Integer 'batch_size': Value=32
Integer 'input_len_x': Value=64
Integer 'training': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 64, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 

### working on crossMQA predictive mask

In [50]:
# Scratch prototype for the crossMQA predictive mask.
# Start from a 3x3 upper-triangular matrix of ones (zeros below the diagonal)
original_tensor = torch.ones(3, 3).triu()

input_len = 8
# Repeat every row 3 times in place, so 3 consecutive rows share the same mask row (9x3)
expanded_tensor = torch.repeat_interleave(original_tensor, 3, dim=0)

# Keep only the first input_len rows
out = expanded_tensor[:input_len]
out

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [0., 1., 1.],
        [0., 1., 1.],
        [0., 1., 1.],
        [0., 0., 1.],
        [0., 0., 1.]])

In [58]:
# Demo of crossMQA with config.levels = 2 and the predictive mask turned ON.
# This forward pass currently raises, so the overrides live inside a try/finally:
# without it the shared config would be left with levels = 2 and
# predictive_mask = True for every cell that runs afterwards.
hold1 = config.levels
hold2 = config.predictive_mask
try:
    config.levels = 2
    config.predictive_mask = True
    module = crossMQA(config)
    module.enable_logging()

    ### Optionally disabling printing for sub-functions
    #module.disable_function_logging('weight_splicing')
    #module.disable_function_logging('RoPE')
    #module.disable_function_logging('match_headcount')
    #module.disable_function_logging('attend')
    #module.disable_function_logging('apply_mask')
    #module.disable_function_logging('calc_output')

    # x0 is the sequence doing the attending, c1 is the (combo-times shorter) sequence being attended to
    x0 = torch.randn(32, config.max_seq_len, config.embed_dim)
    c1 = torch.randn(32, config.max_seq_len // config.combo, config.embed_dim)
    output = module(x0, c1)
finally:
    config.levels = hold1
    config.predictive_mask = hold2
    del hold1, hold2
del module, x0, c1, output


Inputs:
Tensor 'x' shape: torch.Size([32, 256, 256])
Tensor 'c' shape: torch.Size([32, 64, 256])

Inputs:
Tensor 'Wqkv' shape: torch.Size([256, 256])

Outputs:
Tensor 'output[0]' shape: torch.Size([256, 128])
Tensor 'output[1]' shape: torch.Size([256, 64])
Tensor 'output[2]' shape: torch.Size([256, 64])

Inputs:
Tensor 'xn' shape: torch.Size([32, 64, 1, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 64, 2, 64])

Inputs:
Tensor 'xn' shape: torch.Size([32, 64, 1, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 64, 2, 64])

Inputs:
Tensor 'xq' shape: torch.Size([32, 2, 256, 64])
Tensor 'ck' shape: torch.Size([32, 2, 64, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 2, 256, 64])

Inputs:
Tensor 'logits' shape: torch.Size([32, 2, 256, 64])
Integer 'input_len': Value=256


RuntimeError: The expanded size of the tensor (64) must match the existing size (256) at non-singleton dimension 3.  Target sizes: [32, 2, 256, 64].  Tensor sizes: [1, 1, 256, 256]

# Layer

We implement cross-attention in between self-attention and the MLP, as is done in "Attention Is All You Need".

In [98]:
class Layer(nn.Module):
    """
    One residual layer made of three sub-blocks applied in order:
    self-attention, cross-attention (only when a sequence `c` is supplied), and an MLP.
    Every sub-block is wrapped in its own pre-norm and post-norm, and its output is
    added back onto the residual stream.
    """
    def __init__(self, config: Config):
        super().__init__()

        # self-attention sub-block
        self.pre_self_mqa_norm = Norm(config)
        self.self_mqa = selfMQA(config)
        self.post_self_mqa_norm = Norm(config)

        # cross-attention sub-block; x and c each get their own pre-norm
        self.pre_cross_mqa_x_norm = Norm(config)
        self.pre_cross_mqa_c_norm = Norm(config)
        self.cross_mqa = crossMQA(config)
        self.post_cross_mqa_norm = Norm(config)

        # feed-forward sub-block
        self.pre_mlp_norm = Norm(config)
        self.mlp = MLP(config.embed_dim, config.mlp_multiplier, config.dropout_rate)
        self.post_mlp_norm = Norm(config)

        # state for the demonstration/debugging toggles below
        self.logging_enabled = False
        self.disabled_logging_functions = set()

    def enable_logging(self):
        """Switch the module-wide logging flag on."""
        self.logging_enabled = True

    def disable_logging(self):
        """Switch the module-wide logging flag off."""
        self.logging_enabled = False

    def disable_function_logging(self, func_name):
        """Add `func_name` to the set of methods whose logging is disabled."""
        self.disabled_logging_functions.add(func_name)

    def enable_function_logging(self, func_name):
        """Remove `func_name` from the disabled set (no error if it was not in it)."""
        self.disabled_logging_functions.discard(func_name)

    @log_io
    def forward(self, 
                x: torch.Tensor,
                c: torch.Tensor = None,
                training: bool = False,
               ) -> torch.Tensor:
        """
        Run `x` through the layer.

        x: the residual stream to update
        c: optional sequence to cross-attend to; when None the cross-attention sub-block is skipped
        training: forwarded to every norm / attention / MLP call
        """
        resid = x + self.self_mqa_connection(x, training)
        if c is not None:
            resid = resid + self.cross_mqa_connection(resid, c, training)
        return resid + self.mlp_connection(resid, training)

    @log_io
    def self_mqa_connection(self, x, training):
        """pre-norm -> self-attention -> post-norm (the residual add happens in forward)."""
        normed = self.pre_self_mqa_norm(x, training)
        attended = self.self_mqa(normed, training)
        return self.post_self_mqa_norm(attended, training)

    @log_io
    def cross_mqa_connection(self, x, c, training):
        """separate pre-norms for x and c -> cross-attention -> post-norm."""
        x_normed = self.pre_cross_mqa_x_norm(x, training)
        c_normed = self.pre_cross_mqa_c_norm(c, training)
        attended = self.cross_mqa(x_normed, c_normed, training)
        return self.post_cross_mqa_norm(attended, training)

    @log_io
    def mlp_connection(self, x, training):
        """pre-norm -> MLP -> post-norm (the residual add happens in forward)."""
        normed = self.pre_mlp_norm(x, training)
        hidden = self.mlp(normed, training)
        return self.post_mlp_norm(hidden, training)

### demonstration/debugging

In [84]:
# Demo of a Layer without a concept sequence: only the self-attention and MLP sub-blocks run
module = Layer(config)
module.enable_logging()

### enabling printing for sub-modules
#module.pre_self_mqa_norm.enable_logging()
#module.self_mqa.enable_logging()
#module.post_self_mqa_norm.enable_logging()
#module.pre_cross_mqa_x_norm.enable_logging()
#module.pre_cross_mqa_c_norm.enable_logging()
#module.cross_mqa.enable_logging()
#module.post_cross_mqa_norm.enable_logging()
#module.pre_mlp_norm.enable_logging()
#module.mlp.enable_logging()
#module.post_mlp_norm.enable_logging()

### disabling printing for sub-functions
#module.disable_function_logging('self_mqa_connection')
#module.disable_function_logging('cross_mqa_connection')
#module.disable_function_logging('mlp_connection')

demo_shape = (32, config.max_seq_len, config.embed_dim)
output = module(torch.randn(demo_shape))
del module, output, demo_shape


Inputs:
Tensor 'x' shape: torch.Size([32, 256, 256])

Inputs:
Tensor 'x' shape: torch.Size([32, 256, 256])
Integer 'training': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 256, 256])

Inputs:
Tensor 'x' shape: torch.Size([32, 256, 256])
Integer 'training': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 256, 256])

Outputs:
Tensor 'output' shape: torch.Size([32, 256, 256])


In [86]:
# Demo of a Layer that also cross-attends to a concept sequence, with the predictive mask off.
# The config is shared by the whole notebook, so the override lives inside a
# try/finally: the original value is put back even if the forward pass raises.
hold = config.predictive_mask
try:
    config.predictive_mask = False
    module = Layer(config)
    module.enable_logging()

    ### enabling printing for sub-modules
    #module.pre_self_mqa_norm.enable_logging()
    #module.self_mqa.enable_logging()
    #module.post_self_mqa_norm.enable_logging()
    #module.pre_cross_mqa_x_norm.enable_logging()
    #module.pre_cross_mqa_c_norm.enable_logging()
    #module.cross_mqa.enable_logging()
    #module.post_cross_mqa_norm.enable_logging()
    #module.pre_mlp_norm.enable_logging()
    #module.mlp.enable_logging()
    #module.post_mlp_norm.enable_logging()

    ### disabling printing for sub-functions
    #module.disable_function_logging('self_mqa_connection')
    #module.disable_function_logging('cross_mqa_connection')
    #module.disable_function_logging('mlp_connection')

    x = torch.randn(32, config.max_seq_len, config.embed_dim)
    c = torch.randn(32, config.max_seq_len // config.combo, config.embed_dim)
    output = module(x, c)
finally:
    config.predictive_mask = hold
    del hold
# NOTE(review): unlike the other demo cells, x and c are not deleted here -- confirm whether that is intentional
del module, output


Inputs:
Tensor 'x' shape: torch.Size([32, 256, 256])
Tensor 'c' shape: torch.Size([32, 64, 256])

Inputs:
Tensor 'x' shape: torch.Size([32, 256, 256])
Integer 'training': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 256, 256])

Inputs:
Tensor 'x' shape: torch.Size([32, 256, 256])
Tensor 'c' shape: torch.Size([32, 64, 256])
Integer 'training': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 256, 256])

Inputs:
Tensor 'x' shape: torch.Size([32, 256, 256])
Integer 'training': Value=False

Outputs:
Tensor 'output' shape: torch.Size([32, 256, 256])

Outputs:
Tensor 'output' shape: torch.Size([32, 256, 256])


### TODO: do more debugging once the predictive mask works

# Body

In [None]:
# module-level flag so the "cvec_samples not defined" notice in Body.concept_matchup only prints once
cvec_warning = False

class Body(nn.Module):
    def __init__(self, config: Config, embedder: torch.Tensor):
        """
        Build the stack of layers shared by every level of the model.

        config: supplies ca_interval, max_seq_len, combo, levels and num_layers
        embedder: object whose `.weight` is kept as the embedding matrix
        """
        super().__init__()

        # values copied off the config
        self.ca_interval = config.ca_interval
        self.max_seq_len = config.max_seq_len
        self.combo = config.combo
        self.levels = config.levels

        # NOTE(review): the parameter is annotated as torch.Tensor, but `.weight` is read from it,
        # so callers presumably pass the embedding module itself -- confirm
        self.embedding = embedder.weight

        # one Layer per hidden layer requested in the config
        layer_stack = [Layer(config) for _ in range(config.num_layers)]
        self.layers = nn.ModuleList(layer_stack)

        # normalization applied to the output of the last layer
        self.final_norm = Norm(config)

        # state for the demonstration/debugging toggles
        self.logging_enabled = False
        self.disabled_logging_functions = set()
    def enable_logging(self):
        """Set the module-wide `logging_enabled` flag (presumably what the `log_io` wrapper checks)."""
        self.logging_enabled = True
    def disable_logging(self):
        """Clear the module-wide `logging_enabled` flag."""
        self.logging_enabled = False
    def disable_function_logging(self, func_name):
        """Add `func_name` to the set of methods whose logging is disabled."""
        self.disabled_logging_functions.add(func_name)
    def enable_function_logging(self, func_name):
        """Remove `func_name` from the disabled set; `discard` means no error if it was not there."""
        self.disabled_logging_functions.discard(func_name)
        
    @log_io
    def forward(self,
                x0s: Tuple[torch.Tensor], # ordered from tokens -> highest concepts
                targets: Tuple[torch.Tensor] = None,
                cvec_samples: int = None,
                cvec_greedy: bool = False,
                cvec_temp: float = 1.0,
               ) -> Tuple[torch.Tensor]:
        """
        Dispatch to the training path when `targets` is given, otherwise to the inference path.

        NOTE(review): `targets` is handed to forward_training positionally, where the second
        parameter is currently named `training` -- confirm the intended signature.
        NOTE(review): another `forward` is defined further down in this class body; the later
        definition is the one Python keeps, so this dispatcher is currently shadowed.
        """
        if targets is None:
            return self.forward_inference(x0s, cvec_samples, cvec_greedy, cvec_temp)
        return self.forward_training(x0s, targets)
        
    @log_io
    def forward_training(self,
                         x0s: Tuple[torch.Tensor], # ordered from tokens -> highest concepts
                         training: bool = False,
                        ) -> Tuple[torch.Tensor]:
        """
        Training-time pass over every level (not written yet).

        The plan sketched so far: keep a tuple of final residual states and iterate through
        `reversed(x0s)`, i.e. from the highest-level concepts down to the tokens.

        The previous stub was a `for` loop with no body, which is an IndentationError and
        stopped the whole class from being defined. Until the loop is written this raises
        explicitly instead of silently returning None.

        Raises:
            NotImplementedError: always, until the training path is implemented.
        """
        raise NotImplementedError("Body.forward_training has not been implemented yet")
        
    @log_io
    def forward_inference(self,
                        x0s: Tuple[torch.Tensor], # ordered from tokens -> highest concepts
                        cvec_samples: int = None,
                        cvec_greedy: bool = False,
                        cvec_temp: float = 1.0,
                        training: bool = False,
                       ) -> torch.Tensor:
        """
        Run every level from the highest concepts down to the tokens and return the normed
        final residual state of the token level.

        Concept levels are first extended, one concept vector at a time, until they reach
        their maximum length for that level; the token level only gets a single pass.
        Every level below the top cross-attends to the normed output of the level above it.

        x0s: initial residual states, ordered from tokens -> highest concepts
        cvec_samples, cvec_greedy, cvec_temp: forwarded to concept_matchup
        training: forwarded to layers_loop and final_norm
        """
        # final residual states of the levels processed so far, highest level first
        xfs = ()
        token_level = len(x0s) - 1
        # iterate through model levels, starting from highest level concepts & ending at lowest level tokens
        for i, x in enumerate(reversed(x0s)): # reversed() makes us start at highest level

            # the level directly above is the one appended last. the highest level (i == 0) has
            # nothing above it; indexing xfs[i-1] there raised IndexError on the empty tuple
            # NOTE(review): for i > 0 layers_loop has to treat this argument as the tensor to
            # cross-attend to, not as the tuple of all levels -- confirm
            upper = xfs[-1] if xfs else None

            # if we're dealing with a concept level, then we need to create entire concept sequences. for token level we only want one pass
            if i != token_level:
                effective_max_seq_len = self.max_seq_len // (self.combo ** (self.levels-1-i))
                assert x.shape[1] <= effective_max_seq_len, f'somehow at level {i} a too-long sequence ({x.shape[1]} vs {effective_max_seq_len}) made it all the way to Body'
                extra_runs = effective_max_seq_len - x.shape[1]
                for _ in range(extra_runs): # if extra_runs == 0 then this just won't do anything
                    # run, then splice out the final position's prediction
                    x_ = self.layers_loop(x, i, upper, training)[:,-1,:]
                    # select most similar concept vectors to be appended to the sequence
                    next_c = self.concept_matchup(x_, cvec_samples, cvec_greedy, cvec_temp)
                    # append to x
                    x = torch.concat([x, next_c.unsqueeze(1)], dim=1)

            # the final run. this one will actually be logits to get used with crossMQA
            x = self.layers_loop(x, i, upper, training)

            # add the final residual state of the level to our tuple, normed
            xfs += (self.final_norm(x, training),)

        return xfs[-1] # for inference, we only need to return the resid state of the token level
        
    @log_io
    def forward(self, 
                x0s: Tuple[torch.Tensor], # ordered from tokens -> highest concepts
                cvec_samples: int = None,
                cvec_greedy: bool = False,
                cvec_temp: float = 1.0,
                training: bool = False,
               ) -> Tuple[torch.Tensor]:
        """
        Run every level from the highest concepts down to the tokens and return all of the
        normed final residual states, ordered from highest concepts -> tokens.

        NOTE(review): this is the second `forward` in this class body, so it replaces the
        dispatcher defined above it -- confirm which one is meant to stay.
        """
        # final residual states, filled in from the highest level downwards
        xfs = ()
        token_level = len(x0s) - 1
        for i, x in enumerate(reversed(x0s)): # reversed() makes us start at highest level

            # concept levels are grown to their full length first; the token level gets a single pass
            if i != token_level:
                effective_max_seq_len = self.max_seq_len // (self.combo ** (self.levels-1-i))
                assert x.shape[1] <= effective_max_seq_len, f'somehow a too-long sequence ({x.shape[1]} vs {effective_max_seq_len}) made it all the way to Body'
                # a full-length sequence gives range(0), so nothing is generated
                for _ in range(effective_max_seq_len - x.shape[1]):
                    # run the layers and keep only the prediction at the final position.
                    # no norm here: it was not part of this process during training
                    last_position = self.layers_loop(x, i, x0s, training)[:,-1,:]
                    # pick the matching concept vector and append it to the sequence
                    c = self.concept_matchup(last_position, cvec_samples, cvec_greedy, cvec_temp)
                    x = torch.concat([x, c.unsqueeze(1)], dim=1)

            # the final run over the (now complete) sequence for this level
            x = self.layers_loop(x, i, x0s, training)

            # store the normed final residual state; one shared final norm is used for every level
            xfs = (*xfs, self.final_norm(x, training))

        return xfs # now it's ordered from highest concepts -> token

    @log_io
    def layers_loop(self, x, i, x0s, training):
        """
        Pass `x` through every Layer in order and return the result.

        x: residual state of the level being processed
        i: index of that level counted from the top (0 = highest level)
        x0s: what to cross-attend to. Either the tuple/list of all levels ordered from
             tokens -> highest concepts (the level above is then x0s[len(x0s)-i]), or the
             tensor of the level above itself. Ignored when i == 0.
        training: forwarded to every layer
        """
        # i can't equal zero bc there'd be no higher level model to pay attention to
        attend = (i != 0)
        upper = None
        if attend:
            # a tensor is taken as the upper level itself; anything else is indexed as before
            upper = x0s if isinstance(x0s, torch.Tensor) else x0s[len(x0s)-i]

        # Iteratively process the input through each Layer of the model
        for layer in self.layers:
            if attend:
                # run through layer while cross-attending to the upper level
                x = layer(x, upper, training)
            else:
                # run through layer when there is no higher level to attend to
                x = layer(x, training=training)
        return x

    @log_io
    def concept_matchup(self,
                        c: torch.Tensor,
                        cvec_samples: int,
                        cvec_greedy: bool,
                        cvec_temp: float,
                        ) -> torch.Tensor:
        """
        Replace each predicted concept vector with a concept embedding built from token embeddings.

        For every row of `c`: take the `cvec_samples` token embeddings with the highest cosine
        similarity, form the sum of every pair of them (create_concept_embeddings), and pick one
        of those pair-sums either greedily (highest cosine similarity to the row) or by sampling
        from a temperature-scaled softmax over the similarities.

        c: (batch_size x d) predicted concept vectors
        cvec_samples: how many tokens to draw pairs from; None falls back to combo**(levels-1)
        cvec_greedy: True picks the most similar candidate, False samples
        cvec_temp: softmax temperature used when sampling
        returns: (batch_size x d) chosen concept embeddings
        """
        global cvec_warning
        batch_size, d = c.size()
        # only the first d columns of the embedding matrix are compared against
        embedding = self.embedding[:,:d]

        # Batch cosine similarity
        # Reshape c: (batch_size x 1 x embedding_dim)
        # Reshape embedding: (1 x vocab_size x embedding_dim)
        # Resulting similarity: (batch_size x vocab_size)
        token_similarities = F.cosine_similarity(c.unsqueeze(1), embedding.unsqueeze(0), dim=-1)

        # how many tokens will we sample to build up our chosen concept vector?
        min_samples = self.combo ** (self.levels-1)
        if cvec_samples is None:
            cvec_samples = min_samples
            if not cvec_warning:
                print(f"cvec_samples not defined. defaulting to highest level's minimum size: combo**(levels-1) = {cvec_samples}")
                cvec_warning = True
        assert cvec_samples >= min_samples, f'cvec_samples = {cvec_samples} needs to be >= self.combo ** (self.levels-1) = {min_samples}'

        # Select top-k token embeddings for each concept vector
        topk_token_indices = torch.topk(token_similarities, k=cvec_samples, dim=1).indices  # (batch_size x sample)

        # Build the candidates for the whole batch in one call. every row has exactly
        # cvec_samples indices, so the result is (batch_size x num_pairs x d) with no padding
        concept_embeddings_batch, _ = self.create_concept_embeddings(embedding, topk_token_indices.tolist())

        # Similarity of each predicted vector to each of its own candidates: (batch_size x num_pairs)
        concept_similarities_batch = F.cosine_similarity(c.unsqueeze(1), concept_embeddings_batch, dim=-1)

        # Select the matching concept embedding for each concept vector in the batch
        if cvec_greedy:
            chosen = concept_similarities_batch.argmax(dim=1)
        else:
            # Apply softmax with temperature and sample
            concept_probs = F.softmax(concept_similarities_batch / cvec_temp, dim=1)
            chosen = torch.multinomial(concept_probs, num_samples=1).squeeze(1)

        rows = torch.arange(batch_size, device=chosen.device)
        return concept_embeddings_batch[rows, chosen]

    @log_io
    def create_concept_embeddings(self, E: torch.Tensor, indices: Sequence[Sequence[int]]):
        """
        Create concept embeddings for a batch of indices.

        For every index list, one candidate is produced per unordered pair of indices: the
        sum of the two corresponding rows of E. Pairs are ordered (0,1), (0,2), ..., (1,2), ...
        Index lists that yield fewer pairs than the longest one are padded with zero rows.

        E: Embedding matrix (vocab_size x embedding_dim)
        indices: A list of lists of indices (batch_size x num_indices)
        returns: (X, X_sizes) where X is (batch_size x max_num_pairs x embedding_dim), created
                 with E's dtype and on E's device, and X_sizes lists the real (unpadded)
                 number of pairs for each batch entry
        """
        batch_size = len(indices)
        d = E.size(1)
        # n indices give n*(n-1)/2 unordered pairs
        X_sizes = [(len(ind) - 1) * len(ind) // 2 for ind in indices]
        max_X_size = max(X_sizes, default=0)  # default keeps an empty batch from raising

        # start from zeros so the padding rows need no separate fill, and allocate on E's
        # device (the buffer used to always land on the CPU, whatever device E was on)
        X = torch.zeros((batch_size, max_X_size, d), dtype=E.dtype, device=E.device)

        for b, (ind, n_pairs) in enumerate(zip(indices, X_sizes)):
            if n_pairs == 0:
                continue
            idx = torch.as_tensor(ind, dtype=torch.long, device=E.device)
            # positions (i, j) with i < j, in row-major order
            first, second = torch.triu_indices(len(ind), len(ind), offset=1, device=E.device)
            X[b, :n_pairs] = E[idx[first]] + E[idx[second]]

        # X_sizes is not useful rn but i think it may be later when we switch away from TinyShakespeare
        # and over to data that actually has variable sequence lengths
        return X, X_sizes