In [None]:
"""
Decompose Phi-3 into its constituent parts and run each part piece by piece. Tests the logit lens, position loss, and feature clustering via the MLP layer.
"""
None

In [None]:
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
import numpy as np
import pandas as pd
from IPython.core.display import HTML, Markdown
import plotly.express as px 
import os
import wandb 
from dotenv import load_dotenv
from datetime import datetime
import pytz
from torch.utils.data import DataLoader
import importlib
import pathlib

from transformers import AutoModelForCausalLM, AutoTokenizer
from helpers.phi3.phi3 import Phi3Config, Phi3ForCausalLM, _prepare_4d_causal_attention_mask
from helpers.phi3.parse import parse_phi
from helpers.memory import check_memory

# Load environment variables from secrets.env into os.environ
load_dotenv('secrets.env')
# Target device for the model and all tensors created below
device = 'cuda'

## Initialize Model

In [None]:
attn_implementation = None  # Either None or 'flash_attention_2'

# Tokenizer. The padding side only matters for flash attention, which needs left padding.
tokenizer = AutoTokenizer.from_pretrained(
    'microsoft/Phi-3-mini-4k-instruct',
    padding_side = 'left',
    add_bos_token = False,
    add_eos_token = False,
)

# Standard HF transformers model. attn_implementation = None disables flash attention.
my_model = (
    AutoModelForCausalLM
    .from_pretrained(
        'microsoft/Phi-3-mini-4k-instruct',
        attn_implementation = attn_implementation,
        torch_dtype = torch.bfloat16,
        trust_remote_code = True,
        device_map = device,
    )
    .to(device)
    .eval()
)

## Initialize FENCE Parameters

In [None]:
# Topic -> (first, last) FENCE hidden-dimension indices, presumably inclusive. Indices start at 1.
fence_dict = {
    'animals': (3061, 3072),
    'cats': (3061, 3064),
    'dogs': (3065, 3068),
    'food': (3020, 3028),
    'programming': (2961, 2972)
}

# FENCE is applied to layers Kfstart..Kfend (1-indexed, inclusive)
Kfstart = 1
Kfend = 32
assert 1 <= Kfstart <= Kfend, 'Need 1 <= Kfstart <= Kfend'

# Target FENCE activation per layer, keyed by 1-indexed layer number.
# The k-th FENCE layer (k = layer - Kfstart, starting at 0) gets:
#   - k * 0.25 + 0.125 for 'hkrs'
#   - k * 0.25 + 0.25 for 'hks'
# Keys are the layer numbers themselves, so the mapping stays correct for any Kfstart.
Kf_target_values = {
    'hkrs': {layer: (layer - Kfstart) * .25 + .25/2 for layer in range(Kfstart, Kfend + 1)},
    'hks': {layer: (layer - Kfstart) * .25 + .25 for layer in range(Kfstart, Kfend + 1)},
}

print(fence_dict)
Kf_target_values

## Test Inference & Visualizations with Base Model

In [None]:
# Reload the helper module so edits to helpers/fence/eval.py apply without a kernel restart
importlib.reload(importlib.import_module('helpers.fence.eval'))
from helpers.fence.eval import generate_fence

# Smoke test 1: chat-formatted prompt built with parse_phi (presumably the Phi-3 chat template)
dog_prompt = parse_phi(
    [{'role': 'user', 'content': "What's your favorite animal?"}],
    True,
)
dog_gens = generate_fence(my_model, tokenizer, prompt = dog_prompt, max_tokens = 12)

# Smoke test 2: plain completion prompt.
# BOS is written explicitly because the tokenizer was created with add_bos_token = False.
animal_prompt = '<s>Animals are,'
animal_gens = generate_fence(my_model, tokenizer, prompt = animal_prompt, max_tokens = 12)

In [1]:
# Reload so edits to helpers/fence/visualize.py take effect without a kernel restart
importlib.reload(importlib.import_module('helpers.fence.visualize'))
from helpers.fence.visualize import visualize_fence

# Plot dims 2900-3072 of dog_gens['hks'] for each selected (1-indexed) layer.
# max_range is set to that layer's 'hks' target value.
for l in [1]:
    (
        visualize_fence(
            dog_gens['text'],
            dog_gens['hks'],
            [l],
            fence_dict,
            start_dim = 2900,
            end_dim = 3072,
            min_range = 0,
            max_range = Kf_target_values['hks'][l],
        )
        .update_layout(title = f'H<sub>{l}</sub>', height = 300)
        .show('colab')
    )

NameError: name 'importlib' is not defined

## Test Component-by-Component Inference

In [None]:
from helpers.misc import is_notebook
from helpers.phi3.phi3 import _prepare_4d_causal_attention_mask, apply_rotary_pos_emb
import math

@torch.no_grad()
def generate_fence_with_force(model, tokenizer, prompt, echo_output = True, max_tokens = 128, device = 'cuda'):
    """
    Greedily generates from `prompt` by running the Phi-3 forward pass component by component.
    Stores FENCE-relevant intermediate hidden states along the way.

    Contains hooks for forced-FENCE (overwriting FENCE dimensions of the residual stream with their
    target values); these are currently disabled (`pass`). Position loss is NOT calculated, as there
    are no true-FENCE states. No modularity loss is calculated either.

    There is no KV cache: every step re-runs the whole sequence. The saved states therefore come from
    the final pass and cover every token except the last generated one.

    Params:
        @model: A Phi-3 causal LM exposing `.model.layers`, `.model.norm` and `.lm_head`.
        @tokenizer: The matching tokenizer.
        @prompt: The prompt string; any BOS/chat tokens must already be included.
        @echo_output: Whether to display the prompt and the generated continuation.
        @max_tokens: Maximum number of tokens to generate.
        @device: Device to run on.

    Returns a dictionary with keys:
        - `text`: The decoded output (prompt + generation), as a list with one string per token
        - `hksas`: Per-layer outputs of the SA component (after o_proj), each N x D
        - `hkrs`: Per-layer residual stream after adding the SA output, each N x D
        - `hkmlps`: Per-layer outputs of the MLP component, each N x D
        - `hks`: Per-layer residual stream after adding the MLP output, each N x D
      All states are float16 numpy arrays taken from batch element 0.
    """
    from html import escape  # Decoded text can contain markup such as <s>, which HTML would interpret

    model.eval()
    use_flash = model.model._attn_implementation == 'flash_attention_2'
    end_token_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|end|>")]
    generated_tokens = 0

    input_ids_0 = tokenizer(prompt, return_tensors = 'pt').to(device)['input_ids']
    input_ids = input_ids_0

    while True:
        embeds_output = model.model.embed_tokens(input_ids)
        hidden_state = embeds_output

        B, N, D = embeds_output.shape
        H = model.model.config.num_attention_heads  # 32 for Phi-3-mini
        Dh = D // H

        position_ids = torch.arange(0, N, dtype = torch.long, device = device).unsqueeze(0)  # 1 x N

        if use_flash:
            # Flash attention uses the default (2d) attention mask
            attention_mask = None
        else:
            # Non flash-attention: triangular attention mask hiding the right context
            attention_mask = _prepare_4d_causal_attention_mask(None, (B, N), embeds_output, 0, sliding_window = model.model.config.sliding_window)

        saved_sa_outputs = []
        saved_hkrs = []
        saved_mlp_outputs = []
        saved_hks = []
        ### Transformer Blocks (i is the 0-indexed layer) ###
        for i, layer in enumerate(model.model.layers):
            residual = hidden_state
            sa_input = layer.input_layernorm(hidden_state)

            ### SA: manual version of sa_module(sa_input, attention_mask, position_ids)[0] ###
            sa_module = layer.self_attn
            qkv = sa_module.qkv_proj(sa_input)
            # Fused projection laid out as [Q | K | V], each D wide.
            # This assumes as many KV heads as query heads.
            queries = qkv[:, :, :D].view(B, N, H, Dh).transpose(1, 2)  # B x H x N x Dh
            keys = qkv[:, :, D:2*D].view(B, N, H, Dh).transpose(1, 2)
            values = qkv[:, :, 2*D:].view(B, N, H, Dh).transpose(1, 2)

            if use_flash:
                # Because the input can be padded, the absolute sequence length depends on the max position id
                rotary_seq_len = max(N, position_ids[:, -1].max().item()) + 1
                cos, sin = sa_module.rotary_emb(values, position_ids, seq_len = rotary_seq_len)
                queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin, position_ids)
                # Flash attention expects B x N x H x Dh inputs
                queries = queries.transpose(1, 2)
                keys = keys.transpose(1, 2)
                values = values.transpose(1, 2)
                sa_output = sa_module._flash_attention_forward(queries, keys, values, attention_mask, N)
                sa_output = sa_output.reshape(B, N, D).contiguous()
            else:
                cos, sin = sa_module.rotary_emb(values, position_ids, seq_len = N)
                queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin, position_ids)
                attn_weights = torch.matmul(queries, keys.transpose(2, 3)) / math.sqrt(Dh)  # B x H x N x N
                attn_weights = attn_weights + attention_mask  # Upper triangle holds large negative values
                attn_weights = F.softmax(attn_weights, dim = -1, dtype = torch.float32).to(values.dtype)
                sa_output = torch.matmul(attn_weights, values)  # B x H x N x Dh
                sa_output = sa_output.transpose(1, 2).contiguous()  # B x N x H x Dh
                sa_output = sa_output.reshape(B, N, D)  # Concatenate heads back into B x N x D

            # Post-concatenation output projection
            sa_output = sa_module.o_proj(sa_output)
            saved_sa_outputs.append(sa_output[0, :, :].detach())

            ### add residual -> store residual -> layernorm -> mlp -> add residual
            hidden_state = residual + sa_output

            # Forced-FENCE hook (disabled). Kf_target_values is keyed by 1-indexed layer, i.e. i + 1.
            if Kfstart - 1 <= i <= Kfend - 1:
                # hidden_state[:, 1:, 3064:3072] = Kf_target_values['hkrs'][i + 1] # B x N x Df
                pass

            residual = hidden_state
            saved_hkrs.append(hidden_state[0, :, :].detach())

            ## MLP
            hidden_state = layer.post_attention_layernorm(hidden_state)
            up_state = layer.mlp.gate_up_proj(hidden_state)  # B x N x 2I, I = intermediate MLP dimension
            gate, up_state = up_state.chunk(2, dim = -1)  # B x N x I each
            up_state = up_state * layer.mlp.activation_fn(gate)  # Elementwise gating
            hidden_state = layer.mlp.down_proj(up_state)  # Back to B x N x D
            ## End MLP

            saved_mlp_outputs.append(hidden_state[0, :, :].detach())

            hidden_state = residual + hidden_state

            if Kfstart - 1 <= i <= Kfend - 1:
                # hidden_state[:, 1:, 3064:3072] = Kf_target_values['hks'][i + 1] # B x N x Df
                pass

            saved_hks.append(hidden_state[0, :, :].detach())

        hidden_state = model.model.norm(hidden_state)
        logits = model.lm_head(hidden_state)

        # Greedy decoding on the last position.
        # argmax of the raw logits avoids ties introduced by a bf16 softmax.
        # Indexing [0, -1] (instead of squeeze()) also works when the sequence is a single token.
        output_token = logits[0, -1, :].argmax()
        input_ids = torch.cat((input_ids, output_token.view(1, 1)), dim = 1)

        # Stop on EOS / <|end|> or once max_tokens have been generated
        generated_tokens += 1
        if output_token.item() in end_token_ids or generated_tokens >= max_tokens:
            break

    final_output = input_ids[0]
    decoded_text = tokenizer.batch_decode(final_output)
    decoded_output = tokenizer.decode(final_output[input_ids_0.shape[1]:])

    if echo_output:
        # [3:] presumably strips the leading '<s>' from the decoded prompt
        prompt_text = tokenizer.batch_decode(input_ids_0)[0][3:]
        if is_notebook():
            display(HTML(
                '<div style="padding: 1rem 2rem; background-color:honeydew">' +
                    '<h4>Modified model output</h4>' +
                    '<span style="color:green">' + escape(prompt_text) + '</span> ' +
                    '<span style="color:red">' + escape(decoded_output) + '</span>' +
                '</div>'
            ))
        else:
            # `colored` is never imported in this notebook, so fall back to plain text
            print(prompt_text, decoded_output)

    # States from the final pass, moved to CPU as float16 numpy arrays
    return {
        'text': decoded_text,
        'hksas': [h.cpu().to(torch.float16).numpy() for h in saved_sa_outputs],
        'hkrs': [h.cpu().to(torch.float16).numpy() for h in saved_hkrs],
        'hkmlps': [h.cpu().to(torch.float16).numpy() for h in saved_mlp_outputs],
        'hks': [h.cpu().to(torch.float16).numpy() for h in saved_hks],
    }

dog_forced_gens = generate_fence_with_force(my_model, tokenizer, prompt = dog_prompt, max_tokens = 12)

In [None]:
# TEST - SINGLE PASS
"""
Runs a single component-by-component forward pass over the prompt, with no generation loop.
Keeps the intermediate states, including the last layer's MLP input and its split gate / value weights.
No position loss or modularity loss is computed here.
"""
from helpers.phi3.phi3 import _prepare_4d_causal_attention_mask, apply_rotary_pos_emb
import math

model = my_model
prompt = dog_prompt

model.eval()
generated_tokens = 0

input_ids_0 = tokenizer(prompt, return_tensors = 'pt').to(device)['input_ids']
input_ids = input_ids_0

embeds_output = model.model.embed_tokens(input_ids)
hidden_state = embeds_output

B, N, D = embeds_output.shape
H = model.model.config.num_attention_heads  # 32 for Phi-3-mini
Dh = D // H

position_ids = torch.arange(0, N, dtype = torch.long, device = device).unsqueeze(0)  # 1 x N

# Flash attention = use default attention mask 2d
if model.model._attn_implementation == 'flash_attention_2':
    attention_mask = None
# Non flash-attention: Make a triangular attention mask to hide right context
else:
    attention_mask = _prepare_4d_causal_attention_mask(None, (B, N), embeds_output, 0, sliding_window = model.model.config.sliding_window)

saved_sa_outputs = []
saved_hkrs = []
saved_mlp_outputs = []
saved_hks = []
### Transformer Blocks (i is the 0-indexed layer) ###
for i, layer in enumerate(model.model.layers):
    residual = hidden_state
    sa_input = layer.input_layernorm(hidden_state)

    ### SA: manual version of sa_module(sa_input, attention_mask, position_ids)[0] ###
    sa_module = layer.self_attn
    qkv = sa_module.qkv_proj(sa_input)
    # Fused projection laid out as [Q | K | V], each D wide.
    # This assumes as many KV heads as query heads.
    queries = qkv[:, :, :D].view(B, N, H, Dh).transpose(1, 2)  # B x H x N x Dh
    keys = qkv[:, :, D:2*D].view(B, N, H, Dh).transpose(1, 2)
    values = qkv[:, :, 2*D:].view(B, N, H, Dh).transpose(1, 2)

    if model.model._attn_implementation == 'flash_attention_2':
        # Because the input can be padded, the absolute sequence length depends on the max position id
        rotary_seq_len = max(N, position_ids[:, -1].max().item()) + 1
        cos, sin = sa_module.rotary_emb(values, position_ids, seq_len = rotary_seq_len)
        queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin, position_ids)
        # Flash attention expects B x N x H x Dh inputs
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        sa_output = sa_module._flash_attention_forward(queries, keys, values, attention_mask, N)
        sa_output = sa_output.reshape(B, N, D).contiguous()
    else:
        cos, sin = sa_module.rotary_emb(values, position_ids, seq_len = N)
        queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin, position_ids)
        attn_weights = torch.matmul(queries, keys.transpose(2, 3)) / math.sqrt(Dh)  # B x H x N x N
        attn_weights = attn_weights + attention_mask  # Upper triangle holds large negative values
        attn_weights = F.softmax(attn_weights, dim = -1, dtype = torch.float32).to(values.dtype)
        sa_output = torch.matmul(attn_weights, values)  # B x H x N x Dh
        sa_output = sa_output.transpose(1, 2).contiguous()  # B x N x H x Dh
        sa_output = sa_output.reshape(B, N, D)  # Concatenate heads back into B x N x D

    # Post-concatenation output projection
    sa_output = sa_module.o_proj(sa_output)
    saved_sa_outputs.append(sa_output[0, :, :].detach())

    ### add residual -> store residual -> layernorm -> mlp -> add residual
    hidden_state = residual + sa_output

    # Forced-FENCE hook (disabled). Kf_target_values is keyed by 1-indexed layer, i.e. i + 1.
    if Kfstart - 1 <= i <= Kfend - 1:
        # hidden_state[:, 1:, 3064:3072] = Kf_target_values['hkrs'][i + 1] # B x N x Df
        pass

    residual = hidden_state
    saved_hkrs.append(hidden_state[0, :, :].detach())

    hidden_state = layer.post_attention_layernorm(hidden_state)
    # Used by the loss experiments below; after the loop this is the LAST layer's MLP input (B x N x D)
    hidden_state_pre_mlp = hidden_state

    ## MLP, with the fused projection split into its two halves. The fused equivalent is:
    ##   gate, hv = layer.mlp.gate_up_proj(hidden_state).chunk(2, dim = -1)
    ## The manual matmuls match it only if gate_up_proj has no bias.
    weight_matrix = layer.mlp.gate_up_proj.weight  # 2I x D (Linear weights are out x in)
    gate_weight, vmat_weight = weight_matrix.chunk(2, dim = 0)  # I x D each
    gate = torch.matmul(hidden_state, gate_weight.T)  # B x N x I
    hv = torch.matmul(hidden_state, vmat_weight.T)  # B x N x I

    # hv plays the role of the values (see Geva et al), and the gate plays the role of the keys
    up_state = hv * layer.mlp.activation_fn(gate)  # Elementwise
    hidden_state = layer.mlp.down_proj(up_state)  # Back to B x N x D
    ## End MLP

    saved_mlp_outputs.append(hidden_state[0, :, :].detach())

    hidden_state = residual + hidden_state

    if Kfstart - 1 <= i <= Kfend - 1:
        # hidden_state[:, 1:, 3064:3072] = Kf_target_values['hks'][i + 1] # B x N x Df
        pass

    saved_hks.append(hidden_state[0, :, :].detach())

hidden_state = model.model.norm(hidden_state)
logits = model.lm_head(hidden_state)

# Greedy next token from the last position.
# Uses argmax of the raw logits; indexing [0, -1] also works for a one-token prompt.
output_token = logits[0, -1, :].argmax()
input_ids = torch.cat((input_ids, output_token.view(1, 1)), dim = 1)


tokenizer.batch_decode([output_token])

In [None]:
# Project memory helper, presumably reporting current (GPU) memory usage.
# It is re-imported here so this cell also works on its own.
from helpers.memory import check_memory
check_memory()

In [None]:
# `vmat` is never defined. The up-projection ("value") activations from the single-pass cell are stored in `hv`.
hv

In [None]:
# vmat_weight is the second half of the last layer's gate_up_proj weight.
# Its transpose is the D x I matrix passed as V to the losses below.
vmat_weight.T.shape

In [None]:
import torch

def normalized_L1_loss(H, V, dim_range=None, device=None):
    """
    Computes a normalized, distance-weighted interaction loss over a range of hidden dimensions of H.

    For every pair of dimensions i < j in [start_dim, end_dim):
    - the interaction |h_i * h_j| is summed over all batches and tokens;
    - it is then weighted by the index distance |i - j|.
    The weighted total is divided by the unweighted total, giving the interaction-weighted mean distance
    between co-active dimensions. Small values mean large activations cluster in nearby dimensions.

    Parameters:
    - H (torch.Tensor): The input tensor of shape (B, N, D), where B is the batch size, N is the token
      length, and D is the dimension.
    - V (torch.Tensor): The learnable weight matrix of shape (D, I). Accepted for interface compatibility
      with the other loss variants; this loss depends only on H.
    - dim_range (tuple[int, int] | None): Half-open range (start_dim, end_dim) of dimensions to include.
      Defaults to all D dimensions.
    - device (torch.device | str | None): Device of the returned loss. Defaults to H's device.

    Returns:
    - loss (torch.Tensor): Scalar float32 tensor. It is 0 when there are no nonzero interactions
      (e.g. a single-dimension range or an all-zero H).

    Raises:
    - AssertionError: If dim_range is not a non-empty sub-range of [0, D].
    """
    B, N, D = H.shape  # B: batch size, N: token length, D: dimension

    # If dim_range is None, use the full range of dimensions
    if dim_range is None:
        dim_range = (0, D)
    start_dim, end_dim = dim_range
    assert 0 <= start_dim < end_dim <= D, "Invalid dimension range."

    # |h_d| for the selected dimensions, flattened over batch and tokens: (B*N, K).
    # float32 keeps the large sums accurate for bf16/fp16 inputs.
    abs_h = H[:, :, start_dim:end_dim].abs().reshape(B * N, -1).float()
    K = abs_h.shape[1]

    # pair_sums[i, j] = sum over batches and tokens of |h_i * h_j|.
    # Keeping only the strict upper triangle (i < j) avoids double counting.
    pair_sums = torch.triu(abs_h.T @ abs_h, diagonal=1)

    idx = torch.arange(K, device=abs_h.device, dtype=abs_h.dtype)
    distances = (idx.unsqueeze(0) - idx.unsqueeze(1)).abs()  # |i - j|, K x K

    weighted_total = (pair_sums * distances).sum()
    sum_interactions = pair_sums.sum()

    if sum_interactions > 0:
        loss = weighted_total / sum_interactions
    else:
        loss = torch.zeros((), dtype=abs_h.dtype, device=abs_h.device)

    return loss if device is None else loss.to(device)

normalized_L1_loss(H = hidden_state_pre_mlp, V = vmat_weight.T, dim_range = (2900, 3072))

In [None]:
# Input to the last layer's MLP (after post_attention_layernorm), B x N x D
hidden_state_pre_mlp.shape

In [None]:
def normalized_L1_loss_memory_efficient(H, V, chunk_size=1024):
    """
    Calculate the normalized distance-weighted interaction loss for a batch of input matrices H (B x N x D)
    and a learnable weight matrix V (D x I), without materializing any D x D interaction tensor.

    For each batch b, token n and output i, let x_d = H[b, n, d] * V[d, i] be the contribution of input
    dimension d. Then:
    - the raw penalty is sum_{d1, d2} x_{d1} * x_{d2} * |d1 - d2|;
    - it is normalized by sum_{d1, d2} x_{d1} * x_{d2} = (sum_d x_d)^2, clamped to >= 1e-8;
    - the loss is the mean normalized penalty over b, n and i.

    Because |d1 - d2| counts the k with min(d1, d2) < k <= max(d1, d2), the raw penalty equals
    2 * sum_k P_k * (T - P_k), where P_k = x_0 + ... + x_{k-1} are prefix sums and T is the total.
    Peak memory is therefore O(B * N * D * chunk_size), instead of the O(B * N * D^2 * I) needed
    for the dense pairwise tensor.

    Arguments:
    H -- Input matrix of shape (B x N x D)
    V -- Learnable weight matrix of shape (D x I)
    chunk_size -- Number of columns of V processed at once; lower it to reduce peak memory

    Returns:
    Loss value -- Scalar tensor (dtype promoted from H and V) holding the mean normalized penalty
    """
    B, N, D = H.shape  # B = batch size, N = token length, D = dimension
    I = V.shape[1]     # I = number of outputs (same as number of columns in V)
    out_dtype = torch.promote_types(H.dtype, V.dtype)

    # Accumulate in float32: prefix sums over thousands of dimensions are too lossy in bf16/fp16
    H32 = H.float()
    V32 = V.to(device=H.device, dtype=torch.float32)

    penalty_sum = torch.zeros((), dtype=torch.float32, device=H.device)
    for i_start in range(0, I, chunk_size):
        V_chunk = V32[:, i_start:i_start + chunk_size]              # (D, c)
        contributions = H32.unsqueeze(-1) * V_chunk                  # (B, N, D, c): x_d per output column
        prefix = torch.cumsum(contributions, dim=2)                  # index k holds P_{k+1}
        total = prefix[:, :, -1:, :]                                 # (B, N, 1, c): T = sum_d x_d
        # The last index is P_D = T, whose term T * (T - T) is exactly zero, so summing all indices is safe
        raw_penalty = 2 * (prefix * (total - prefix)).sum(dim=2)     # (B, N, c)
        # (sum_d x_d)^2 equals the sum of all pairwise interactions; clamp to avoid division by zero
        sum_interactions = torch.clamp(total.squeeze(2) ** 2, min=1e-8)
        penalty_sum = penalty_sum + (raw_penalty / sum_interactions).sum()

    # Mean over batches, tokens and output dimensions
    return (penalty_sum / (B * N * I)).to(out_dtype)

    
normalized_L1_loss_memory_efficient(H = hidden_state_pre_mlp, V = vmat_weight.T)

In [None]:
def calculate_memory_usage(H, V):
    """
    Estimate, in megabytes, the memory held by the tensors of the dense normalized L1 loss computation.

    Counts the following:
    - the inputs H and V, at their own element sizes;
    - the B x N x D x I products;
    - two B x N x D x D x I interaction tensors;
    - two B x N x I reductions.
    All intermediates are assumed to use H's element size.
    """
    B, N, D = H.shape
    I = V.shape[1]
    bytes_per_elem = H.element_size()

    input_bytes = H.numel() * H.element_size() + V.numel() * V.element_size()

    intermediate_elems = (
        B * N * D * I              # HV
        + 2 * B * N * D * D * I    # interaction matrix and its distance-weighted copy
        + 2 * B * N * I            # raw penalty and sum of interactions
    )

    total_bytes = input_bytes + intermediate_elems * bytes_per_elem
    return total_bytes / 1024 ** 2  # bytes -> MB

calculate_memory_usage(H = hidden_state_pre_mlp, V = vmat_weight.T)

In [None]:
def normalized_L1_loss_chunked(H, V, chunk_size=64):
    """
    Calculate the normalized distance-weighted interaction loss for a batch of input matrices H (B x N x D)
    and a learnable weight matrix V (D x I), processing D in chunks to reduce memory usage.

    For each batch b, token n and output i, let x_d = H[b, n, d] * V[d, i]. Then:
    - the raw penalty is sum_{d1, d2} x_{d1} * x_{d2} * |d1 - d2| over ALL dimension pairs,
      including pairs that fall in different chunks;
    - it is normalized by (sum_d x_d)^2, clamped to >= 1e-8;
    - the result is averaged over b, n and i.

    The raw penalty is computed as 2 * sum_{k=1}^{D-1} P_k * (T - P_k). Here P_k are the prefix sums of x
    (carried from one chunk to the next) and T = sum_d x_d = (H @ V)[b, n, i]. Peak memory is
    O(B * N * chunk_size * I).

    Arguments:
    H -- Input matrix of shape (B x N x D)
    V -- Learnable weight matrix of shape (D x I)
    chunk_size -- The number of dimensions to process at once to save memory.

    Returns:
    Loss value -- Scalar tensor (dtype promoted from H and V) holding the mean normalized penalty
    """
    B, N, D = H.shape  # B = batch size, N = token length, D = dimension
    I = V.shape[1]     # I = number of outputs (same as number of columns in V)
    out_dtype = torch.promote_types(H.dtype, V.dtype)

    # Accumulate in float32: prefix sums over thousands of dimensions are too lossy in bf16/fp16
    H32 = H.float()
    V32 = V.to(device=H.device, dtype=torch.float32)

    total = torch.matmul(H32, V32)                 # (B, N, I): T = sum_d x_d
    running_prefix = torch.zeros_like(total)       # prefix sum carried across chunk boundaries
    prefix_penalty = torch.zeros_like(total)       # accumulates sum_k P_k * (T - P_k)

    for d_start in range(0, D, chunk_size):
        d_end = min(d_start + chunk_size, D)

        contributions = H32[:, :, d_start:d_end].unsqueeze(-1) * V32[d_start:d_end, :]  # (B, N, c, I)
        prefix = running_prefix.unsqueeze(2) + torch.cumsum(contributions, dim=2)       # P_{k+1} for k in chunk

        # The identity sums k = 1..D-1, so the final prefix (P_D = T) is excluded
        terms = prefix if d_end < D else prefix[:, :, :-1, :]
        prefix_penalty = prefix_penalty + (terms * (total.unsqueeze(2) - terms)).sum(dim=2)

        running_prefix = prefix[:, :, -1, :]

    raw_penalty = 2 * prefix_penalty
    # (sum_d x_d)^2 equals the sum of all pairwise interactions; clamp to avoid division by zero
    sum_interactions = torch.clamp(total ** 2, min=1e-8)
    normalized_penalty = raw_penalty / sum_interactions

    # Return the averaged loss across batches, tokens, and outputs
    return (normalized_penalty.sum() / (B * N * I)).to(out_dtype)

normalized_L1_loss_chunked(H = hidden_state_pre_mlp, V = vmat_weight.T)

In [None]:
# `check_memory_usage` is not defined anywhere; the project helper imported at the top is `check_memory`
check_memory()