## Task List

* Create Graph based version
    - Each module is a node
    - Each contribution is an edge
    - Update each node for every incoming (and outgoing) edge
* Change to bi-directional batching
    - Start with last token -> load parts -> calculate contribution for each layer -> repeat for previous token until no open nodes are left.
    - Less memory intensive (Only load cache for one token)

In [1]:
%load_ext autoreload
%autoreload 2
import os 
os.environ["CUDA_VISIBLE_DEVICES"]="3"
import json
import sys
import gc
from typing import Dict, List, Union

# Put the current working directory on sys.path (once); presumably so the
# project-local packages imported further down resolve when the notebook is
# started from the project root.
project_root = os.path.abspath(os.path.join(os.getcwd()))
if project_root not in sys.path:
    sys.path.append(project_root)

from transformers import AutoModel, AutoTokenizer, PretrainedConfig, AutoConfig
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from nnsight import LanguageModel
import pandas as pd


from data.utils import get_entity_idx, get_prompts
from mi_toolbox.utils.data_types import DataDict
from mi_toolbox.utils.collate import TensorCollator
from mi_toolbox.transformer_caching import caching_wrapper, decompose_attention_to_neuron, decompose_glu_to_neuron, TransformerCache
from mi_toolbox.causal_tracing import get_mass_mean_vectors
from mi_toolbox.contribution_tracing import get_dot_prod_contribution, get_top_x_contribution_values


# Model identifiers to analyse; only the first entry is used below. The other
# candidates are kept in the trailing comment.
model_ids = ["Qwen/Qwen3-4B"]#, "Qwen/Qwen3-8B", "Qwen/Qwen3-14B"]  # "Qwen/Qwen3-0.6B","Qwen/Qwen3-1.7B"
tokenizer = AutoTokenizer.from_pretrained(model_ids[0])

### Get Concept Vectors

### Get direct contributions

In [2]:
def get_module_lookup(model_config: PretrainedConfig, target_layer: int, num_tokens: int) -> List[str]:
    """Return the ordered list of module names for all layers below ``target_layer``.

    The list starts with ``'emb'``. For every layer it then holds one
    ``"{layer}.attn.{token}.{neuron}"`` entry per token position and attention
    output neuron (``num_attention_heads * head_dim`` neurons), followed by one
    ``"{layer}.mlp.-1.{neuron}"`` entry per MLP neuron (``intermediate_size``).
    The position of a name in the list is the module's index.
    """
    attn_neurons = model_config.num_attention_heads * model_config.head_dim
    mlp_neurons = model_config.intermediate_size

    lookup = ['emb']
    for layer in range(target_layer):
        # attention entries: token position is the outer loop, neuron the inner one
        for token in range(num_tokens):
            for neuron in range(attn_neurons):
                lookup.append(f"{layer}.attn.{token}.{neuron}")
        # MLP entries carry -1 as their token position
        lookup += [f"{layer}.mlp.-1.{neuron}" for neuron in range(mlp_neurons)]

    return lookup

def get_first_module_group_id_lookup(module_lookup: List[str]) -> Dict[str, int]:
    """Map each ``"{layer}.{module_type}"`` group to the index where a run of it starts.

    ``'emb'`` entries are skipped. Every other entry must split into exactly
    four dot-separated fields. Whenever the (layer, module type) pair differs
    from the previous non-``emb`` entry, the current index is stored under the
    group key, so a group that re-appears later overwrites its earlier index.
    """
    group_starts = {}
    previous_group = None

    for position, name in enumerate(module_lookup):
        if name == 'emb':
            continue

        layer, module_type, _, _ = name.split('.')
        group = (layer, module_type)
        if group != previous_group:
            group_starts[f"{layer}.{module_type}"] = position
        previous_group = group

    return group_starts

def get_initial_tasks(concept_vectors: torch.Tensor, target_module: str, num_examples: int, target_token_pos: List[int]) -> DataDict:
    """Create the first batch of tasks, one per example, for the direct contributions.

    ``concept_vectors`` is indexed with ``target_module`` (``main`` passes a
    dict keyed by module name). Each selected vector is copied ten times; the
    first five copies keep their sign and the last five are negated. The 10 and
    the 5/5 split are hard-coded here.
    NOTE(review): presumably ``num_examples`` must equal 10 x the number of
    concept vectors for the attached columns to line up - confirm.
    """
    initial_tasks = DataDict(length=num_examples)

    # column of +1 (five rows) followed by -1 (five rows), broadcast over the copies
    signs = torch.tensor([[1]] * 5 + [[-1]] * 5)
    copies = concept_vectors[target_module][:, None].repeat(1, 10, 1)
    signed_targets = (copies * signs).flatten(0, 1)

    initial_tasks.attach('target_idx', [None] * num_examples)
    initial_tasks.attach('target_vectors', signed_targets, force=True)
    initial_tasks.attach('target_contribution_factor', torch.tensor([1] * num_examples), force=True)
    initial_tasks.attach('target_sample_ids', range(num_examples), force=True)
    initial_tasks.attach('target_token_pos', target_token_pos, force=True)
    initial_tasks.attach('depth', [[0]] * num_examples, force=True)

    return initial_tasks

def mean_shift(tensor: torch.Tensor) -> torch.Tensor:
    """Subtract the mean over the last dimension from every entry of ``tensor``."""
    last_dim_mean = tensor.mean(dim=-1, keepdim=True)
    return tensor - last_dim_mean

In [3]:
from contextlib import contextmanager
from typing import Generator, Tuple, Iterator, Union, List, Dict
from transformers import PreTrainedModel, PretrainedConfig
from torch.utils.data import DataLoader
import time

@contextmanager
def load_language_model(
    model_id: Union[str, List[str]],
    **kwargs
) -> Generator[Tuple[PreTrainedModel, PretrainedConfig], None, None]:
    """Context manager that loads a model and cleans up after the ``with`` body.

    ``model_id`` and ``kwargs`` are forwarded to ``AutoModel.from_pretrained``.
    The managed value is the pair ``(model, model.config)``. On exit (also on
    error) the model is moved to the CPU, the local reference is dropped, the
    garbage collector and ``torch.cuda.empty_cache()`` are run, and the current
    CUDA memory statistics are printed.
    """
    model = AutoModel.from_pretrained(model_id, **kwargs)
    try:
        yield model, model.config
    finally:
        # Only this function's reference is removed here; names bound by the
        # caller's `with ... as` target still keep the model alive.
        model.cpu()
        del model
        gc.collect()
        torch.cuda.empty_cache()
        allocated_gb = torch.cuda.memory_allocated() / 1e9
        print(f"Memory allocated: {allocated_gb:.2f} GB")
        reserved_gb = torch.cuda.memory_reserved() / 1e9
        print(f"Memory reserved: {reserved_gb:.2f} GB")

def batch_iterator(dl: DataLoader) -> Iterator[Dict, ]:
    """Yield the batches of ``dl`` and print how long the consumer spent on each.

    The timer starts right before a batch is handed out and stops when the
    generator is resumed (or closed), so the printed duration covers the work
    the caller did with that batch.
    """
    for batch_number, batch in enumerate(dl, start=1):
        started = time.perf_counter()
        try:
            yield batch
        finally:
            elapsed = time.perf_counter() - started
            print(f"Batch {batch_number}/{len(dl)}: {elapsed:.2f} seconds")

In [4]:
def get_contribution_ceiling_id(target_module: str, module_lookup: List[str]):
    """Return the index of the first entry in ``module_lookup`` of the target's group.

    The group is the first two dot-separated fields of ``target_module``
    (``"{layer}.{module_type}"``). The caller uses the returned index as an
    exclusive upper bound when slicing the parts cache.

    Raises:
        ValueError: if no entry of ``module_lookup`` starts with the group
            prefix (previously this surfaced as a bare ``StopIteration``).
    """
    group_prefix = '.'.join(target_module.split('.')[:2])
    for position, name in enumerate(module_lookup):
        if name.startswith(group_prefix):
            return position
    raise ValueError(f"No module starting with '{group_prefix}' found in module_lookup")

def get_batch_parts(batch: Dict, parts_cache: torch.Tensor, contribution_ceiling_id: int):
    """Select the cached parts for every task in ``batch``.

    ``parts_cache`` is indexed pairwise with the batch's sample ids and token
    positions, cut to the first ``contribution_ceiling_id`` entries of the third
    axis, and returned with the first two axes swapped (parts first, tasks second).
    """
    sample_ids = batch['target_sample_ids']
    token_pos = batch['target_token_pos']

    selected = parts_cache[sample_ids, token_pos, :contribution_ceiling_id]
    return torch.transpose(selected, 0, 1)

def get_contributions(batch: Dict, batch_parts: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Compute the retained contributions of ``batch_parts`` for every task in ``batch``.

    The dot-product contributions of the parts towards ``batch['target_vectors']``
    are transposed, multiplied per task by ``batch['target_contribution_factor']``
    and passed to ``get_top_x_contribution_values`` with a threshold of 0.9.

    Returns:
        Two lists with one tensor per task: the indices of the non-zero entries
        left by the top-x selection, and the values at those indices.
    """
    whole = batch['target_vectors']
    factor = batch['target_contribution_factor']

    per_task = get_dot_prod_contribution(parts=batch_parts, whole=whole).transpose(0, 1)
    retained = get_top_x_contribution_values(per_task * factor[:, None], 0.9)

    contribution_idx = []
    contribution_values = []
    for row in retained:
        kept = row.nonzero(as_tuple=True)[0]
        contribution_idx.append(kept)
        contribution_values.append(row[kept])

    return contribution_idx, contribution_values

def get_direct_contributions(direct_parts: torch.Tensor, concept_vectors: torch.Tensor, module_lookup: List[str], target_module: str, target_token_pos: List[int], batch_size: int = 16) -> DataDict:
    """Collect the contributions of ``direct_parts`` towards the concept vectors.

    One initial task per entry of ``target_token_pos`` is created, the tasks are
    processed in batches of ``batch_size``, and for each batch the retained
    contribution indices, their values and the matching module names (looked up
    in ``module_lookup``) are appended to the result together with the task
    metadata.
    """
    initial_tasks = get_initial_tasks(concept_vectors, target_module, len(target_token_pos), target_token_pos)
    loader = DataLoader(initial_tasks, batch_size=batch_size, collate_fn=TensorCollator(), shuffle=False)

    direct_contributions = DataDict()
    for batch in batch_iterator(loader):
        contribution_idx, contribution_values = get_contributions(batch, direct_parts)

        # translate every retained index into its module name
        contribution_modules = []
        for task_idx in contribution_idx:
            contribution_modules.append([module_lookup[module_id] for module_id in task_idx])

        batch_result = {
            'contribution_idx': contribution_idx,
            'contribution_values': contribution_values,
            'contribution_modules': contribution_modules,
            'target_sample_ids': batch['target_sample_ids'],
            'target_token_pos': batch['target_token_pos'],
            'target_idx': batch['target_idx'],
            'depth': batch['depth'],
        }
        direct_contributions.extend(batch_result)

    return direct_contributions


In [5]:
from tqdm import tqdm
def update_task_queue(task_queue: DataDict, new_contributions: DataDict) -> DataDict:
    """Append one follow-up task per non-embedding contribution and merge duplicates.

    Contributions whose index is 0 are dropped (index 0 is ``'emb'`` in the
    lookups built by ``get_module_lookup``). Each remaining contribution becomes
    a task that inherits the sample id of the item it came from, records that
    item's token position as ``source_token_pos`` and gets every depth
    incremented by one. ``task_queue`` is extended in place; the merged queue is
    returned.
    """
    for item in new_contributions.to_list():
        all_idx = item['contribution_idx']
        keep = [pos for pos, module_idx in enumerate(all_idx) if module_idx != 0]
        if not keep:
            continue

        num_new_tasks = len(keep)
        next_depth = [depth + 1 for depth in item['depth']]

        task_queue.extend({
            'target_idx': [all_idx[pos] for pos in keep],
            'target_values': [item['contribution_values'][pos] for pos in keep],
            'target_modules': [item['contribution_modules'][pos] for pos in keep],
            'target_sample_ids': [item['target_sample_ids']] * num_new_tasks,
            'source_token_pos': [item['target_token_pos']] * num_new_tasks,
            'depth': [next_depth] * num_new_tasks
        })

    return merge_task_queue_items(task_queue)

def merge_task_queue_items(task_queue: DataDict) -> DataDict:
    """Collapse tasks that share sample id and target index into a single task.

    The queue is sorted by ``target_sample_ids`` and then by ``target_idx``
    (descending), so equal (sample id, target index) pairs are expected to end
    up adjacent. For each run of equal pairs the last row is kept as the
    representative, its ``target_values`` becomes the sum over the run, and its
    ``depth`` becomes the sorted list of the first depth entry of every merged
    task.

    Fixes over the previous version:
    * the representative of the final run is now the last row of the queue;
      before, ``i - 1`` was used, which picked a row of the *previous* run
      whenever the final run consisted of a single task;
    * an empty queue is returned unchanged instead of failing on
      ``task_queue[0]``.
    """
    if not task_queue.length:
        return task_queue

    task_queue.sort(by='target_sample_ids')
    task_queue.sort(by='target_idx', descending=True)

    tasks = task_queue.to_list()
    merge_idx = []
    value_accumulator = []
    acc_depth_list = []

    run_head = None  # first task of the run currently being accumulated
    for i, task in enumerate(tasks):
        same_run = (
            run_head is not None
            and task['target_sample_ids'] == run_head['target_sample_ids']
            and task['target_idx'] == run_head['target_idx']
        )
        if not same_run:
            if run_head is not None:
                # the row before a new run is the last row of the finished run
                merge_idx.append(i - 1)
            run_head = task
            value_accumulator.append([])
            acc_depth_list.append([])

        value_accumulator[-1].append(task['target_values'])
        acc_depth_list[-1].append(task['depth'][0])

    # close the final run with the last row of the queue
    merge_idx.append(len(tasks) - 1)

    merged_tasks = task_queue[merge_idx]
    merged_tasks['target_values'] = [sum(task_values) for task_values in value_accumulator]
    merged_tasks['depth'] = [sorted(acc_depth) for acc_depth in acc_depth_list]

    return merged_tasks

@torch.no_grad()
def get_active_tasks(llm: PreTrainedModel, task_queue: DataDict, projection_cache: Dict) -> Tuple[DataDict, DataDict]:
    """Split off the next module group from ``task_queue`` and prepare it for tracing.

    The queue is sorted by ``target_modules`` in descending order; the leading
    run of tasks sharing the first task's ``"{layer}.{module_type}"`` prefix
    becomes the active set. If every task belongs to that group, the remaining
    queue is a ``DataDict`` with the same keys and a single ``None`` per key.
    The active tasks get ``target_token_pos``, ``target_vectors`` and
    ``target_contribution_factor`` attached.

    Returns:
        ``(active_tasks, remaining_queue)``.
    """
    task_queue.sort(by='target_modules', descending=True)
    group_prefix = '.'.join(task_queue[0]['target_modules'].split('.')[:2])

    # first position whose module no longer belongs to the leading group
    split_at = None
    for position, name in enumerate(task_queue['target_modules']):
        if not name.startswith(group_prefix):
            split_at = position
            break

    if split_at:
        new_active_tasks, new_task_queue = task_queue[:split_at], task_queue[split_at:]
    else:
        new_active_tasks = task_queue
        new_task_queue = DataDict.from_dict({key: [None] for key in task_queue})

    new_active_tasks.attach('target_token_pos', get_target_token_pos(new_active_tasks))

    vectors = get_target_vectors(llm, new_active_tasks, projection_cache)
    factors = get_contribution_factor(llm, new_active_tasks, projection_cache)

    new_active_tasks.attach('target_vectors', vectors)
    new_active_tasks.attach('target_contribution_factor', factors)

    return new_active_tasks, new_task_queue

def get_target_vectors(llm: PreTrainedModel, tasks: DataDict, projection_cache: Dict) -> List[torch.Tensor]:
    """Build the target vectors for ``tasks`` with the builder matching their module type.

    Layer and module type are read from the first task's module name, so all
    tasks are expected to belong to the same ``"{layer}.{module_type}"`` group.
    The last field of every module name is used as the neuron index.

    Raises:
        ValueError: if the module type is neither ``'attn'`` nor ``'mlp'``.
    """
    layer, module_type = tasks[0]['target_modules'].split('.')[:2]
    neuron_ids = [int(name.split('.')[-1]) for name in tasks['target_modules']]

    sample_ids = tasks['target_sample_ids']
    token_pos = tasks['target_token_pos']

    if module_type == 'attn':
        build = get_attn_target_vectors
    elif module_type == 'mlp':
        build = get_mlp_target_vectors
    else:
        raise ValueError(f"Unexpected module type: {module_type}")

    return build(llm, projection_cache, int(layer), sample_ids, token_pos, neuron_ids)

def get_attn_target_vectors(llm: PreTrainedModel, projection_cache: Dict, target_layer: int, sample_ids: List[int], token_pos: List[int], target_heads: List[int]) -> torch.Tensor:
    """Return one target vector per task for attention neurons of ``target_layer``.

    Every entry of ``target_heads`` is a flat neuron index. Integer division by
    ``head_dim`` and then by the number of attention heads per key/value head
    gives the key/value head; the remainder modulo ``head_dim`` gives the
    position inside the head. The matching rows of the value projection weight
    are combined with the input layer-norm weight and the cached
    ``"{layer}.attn_norm_var"`` entries by ``trace_through_layer_norm``; the
    result is returned on the CPU.
    """
    cfg_heads = llm.config.num_attention_heads
    cfg_kv_heads = llm.config.num_key_value_heads
    heads_per_kv = cfg_heads // cfg_kv_heads
    head_dim = llm.config.head_dim
    hidden_dim = llm.config.hidden_size

    neuron_index = torch.tensor(target_heads)
    kv_head = neuron_index // head_dim // heads_per_kv
    within_head = neuron_index % head_dim

    layer = llm.layers[target_layer]
    # view the value projection as (kv head, neuron in head, hidden) and pick one row per task
    per_head_weight = layer.self_attn.v_proj.weight.data.view(cfg_kv_heads, head_dim, hidden_dim)
    v_rows = per_head_weight[kv_head, within_head]
    norm_weight = layer.input_layernorm.weight.data
    norm_stat = projection_cache[f"{target_layer}.attn_norm_var"][sample_ids, token_pos]

    return trace_through_layer_norm(v_rows, norm_weight, norm_stat).cpu()

def get_mlp_target_vectors(llm: PreTrainedModel, projection_cache: Dict, target_layer: int, sample_ids: List[int], token_pos: List[int], target_heads: List[int]) -> torch.Tensor:
    """Return one target vector per task for MLP neurons of ``target_layer``.

    The rows of the up-projection weight selected by ``target_heads`` are
    combined with the post-attention layer-norm weight and the cached
    ``"{layer}.mlp_norm_var"`` entries by ``trace_through_layer_norm``; the
    result is returned on the CPU.
    """
    layer = llm.layers[target_layer]
    up_rows = layer.mlp.up_proj.weight.data[target_heads]
    norm_weight = layer.post_attention_layernorm.weight.data
    norm_stat = projection_cache[f"{target_layer}.mlp_norm_var"][sample_ids, token_pos]

    return trace_through_layer_norm(up_rows, norm_weight, norm_stat).cpu()

@torch.no_grad()
def trace_through_layer_norm(target: torch.Tensor, norm_W: torch.Tensor, norm_var: torch.Tensor, device: Union[str, torch.device] = 'cuda') -> torch.Tensor:
    """Rescale ``target`` rows by a per-row norm statistic and the norm weight.

    Each row ``i`` of ``target`` is multiplied by ``norm_var[i]`` and then
    divided element-wise by ``norm_W``. The arithmetic runs on ``device`` and
    the result is moved back to the device ``target`` came from.

    Args:
        target: rows to rescale.
        norm_W: weight of the normalisation layer, broadcast against the rows.
        norm_var: one cached statistic per row of ``target``.
        device: device used for the computation. A falsy value, or a CUDA
            device on a host without CUDA, falls back to ``target``'s device
            (previously a CUDA-less host made the default raise).

    NOTE(review): the norm weight is applied by division. If the cached
    statistic is the reciprocal RMS and the aim is to fold the norm into the
    projection row, a multiplication by ``norm_W`` would be expected - confirm
    against how ``*_norm_var`` is cached.
    """
    input_device = target.device
    if not device:
        device = input_device
    elif torch.device(device).type == 'cuda' and not torch.cuda.is_available():
        # keep the function usable on CPU-only machines
        device = input_device

    var_scaled_target = target.to(device) * norm_var[:, None].to(device)
    norm_w_scaled_target = var_scaled_target / norm_W.to(device)

    return norm_w_scaled_target.to(input_device)

def get_contribution_factor(llm: PreTrainedModel, tasks: DataDict, projection_cache: Dict):
    """Compute the per-task contribution factor with the helper for the tasks' module type.

    Layer and module type are read from the first task's module name; the last
    field of every module name is used as the neuron index and
    ``tasks['target_values']`` is converted to a tensor before dispatching.

    Raises:
        ValueError: if the module type is neither ``'attn'`` nor ``'mlp'``.
    """
    layer, module_type = tasks[0]['target_modules'].split('.')[:2]
    neuron_ids = [int(name.split('.')[-1]) for name in tasks['target_modules']]

    sample_ids = tasks['target_sample_ids']
    values = torch.tensor(tasks['target_values'])
    token_pos = tasks['target_token_pos']

    if module_type == 'attn':
        return get_attn_contribution_factor(llm, values, projection_cache, int(layer), sample_ids, token_pos, neuron_ids)
    elif module_type == 'mlp':
        # the MLP helper does not need the model
        return get_mlp_contribution_factor(values, projection_cache, int(layer), sample_ids, token_pos, neuron_ids)
    else:
        raise ValueError(f"Unexpected module type: {module_type}")

def get_attn_contribution_factor(llm: PreTrainedModel, target_values: torch.Tensor, projection_cache: Dict, target_layer: int, sample_ids: List[int], token_pos: List[int], target_heads: List[int]) -> torch.Tensor:
    """Divide each target value by the cached value-projection activation of its neuron.

    The cached ``"{layer}.v_proj"`` entry for each (sample, token) pair is
    viewed as (key/value head, position in head). Every flat index in
    ``target_heads`` is mapped to its key/value head (integer division by
    ``head_dim``, then by the number of attention heads per key/value head) and
    to its position inside the head (modulo ``head_dim``). The factor, the
    target values and the selected activations are printed before returning.
    """
    heads_per_kv = llm.config.num_attention_heads // llm.config.num_key_value_heads
    kv_heads = llm.config.num_key_value_heads
    head_dim = llm.config.head_dim
    num_targets = len(sample_ids)

    neuron_index = torch.tensor(target_heads)
    kv_head = neuron_index // head_dim // heads_per_kv
    within_head = neuron_index % head_dim

    cached = projection_cache[f"{target_layer}.v_proj"][sample_ids, token_pos]
    per_head = cached.view(num_targets, kv_heads, head_dim)
    # one activation per task: row i, its key/value head, its position in the head
    v_proj = per_head[range(num_targets), kv_head, within_head]

    contribution_factor = target_values / v_proj
    print(contribution_factor, target_values, v_proj)

    return contribution_factor

def get_mlp_contribution_factor(target_values: torch.Tensor, projection_cache: Dict, target_layer: int, sample_ids: List[int], token_pos: List[int], target_heads: List[int]) -> torch.Tensor:
    """Divide each target value by the cached up-projection activation of its neuron.

    The activation for task ``i`` is read from ``"{layer}.up_proj"`` at
    ``(sample_ids[i], token_pos[i], target_heads[i])``. The factor, the target
    values and the selected activations are printed before returning.
    """
    cache_key = f"{target_layer}.up_proj"
    up_proj = projection_cache[cache_key][sample_ids, token_pos, target_heads]

    contribution_factor = target_values / up_proj
    print(contribution_factor, target_values, up_proj)

    return contribution_factor

def get_target_token_pos(tasks: DataDict):
    """Return the token position each task refers to, as a plain list.

    The second-to-last field of a module name is its token position. Where that
    field is ``-1`` (as in the MLP names built by ``get_module_lookup``) the
    task's ``source_token_pos`` is used instead.
    """
    name_positions = [int(name.split('.')[-2]) for name in tasks['target_modules']]
    module_token_pos = torch.tensor(name_positions)
    source_token_pos = torch.tensor(tasks['source_token_pos'])

    has_own_position = module_token_pos != -1
    resolved = torch.where(has_own_position, module_token_pos, source_token_pos)

    return resolved.tolist()

In [6]:

def flatten_contributions(contributions: DataDict) -> DataDict:
    """Expand per-task contribution lists into one row per single contribution.

    Every item of ``contributions`` holds several contribution indices, values
    and module names. Each of them becomes its own row; the item's sample id,
    target index (as ``source_idx``), token position (as ``source_token_pos``)
    and depth are repeated for all of its rows. An empty input gives an empty
    ``DataDict``.

    ``source_idx`` is ``None`` only when the item has no target index. The
    previous truthiness test would also have mapped a target index of 0 to
    ``None``.
    """
    flat_contributions = DataDict()
    if not contributions.length:
        return flat_contributions

    for cache_item in contributions.to_list():
        num_contributions = len(cache_item['contribution_idx'])

        target_idx = cache_item['target_idx']
        # explicit None check: 0 is a valid index and must not be dropped
        source_idx = target_idx.item() if target_idx is not None else None

        flat_contributions.extend({
            'contribution_idx': cache_item['contribution_idx'].tolist(),
            'contribution_values': cache_item['contribution_values'].tolist(),
            'contribution_modules': cache_item['contribution_modules'],
            'sample_ids': [cache_item['target_sample_ids']] * num_contributions,
            'source_idx': [source_idx] * num_contributions,
            'source_token_pos': [cache_item['target_token_pos']] * num_contributions,
            'depth': [cache_item['depth']] * num_contributions
        })
    return flat_contributions

In [7]:
def main(sample_id):
    """Trace contributions towards the '9.post' concept vector and store them as parquet.

    Steps:
    1. load the residual states for ``sample_id`` and build the concept vector,
    2. compute the direct contributions from the notebook-level ``direct_parts``,
    3. load the projection cache and, with the language model loaded, follow
       the contributions module group by module group until a ``0.attn`` group
       has been processed or no tasks are left,
    4. flatten everything into one DataFrame and write it to
       ``data/contribution_cache/`` below ``project_root``.

    The function relies on names defined in other notebook cells:
    ``num_tokens``, ``direct_parts``, ``model_cache``, ``indirect_parts_cache``,
    ``model_ids`` and ``project_root``. ``direct_parts`` and ``model_cache`` are
    deleted from the notebook namespace to free memory, so those cells must be
    re-run before calling ``main`` again.

    Fixes over the previous version:
    * ``direct_parts`` / ``model_cache`` are declared ``global``; the ``del``
      statements used to turn them into unassigned locals (UnboundLocalError),
    * the first ``get_active_tasks`` call now happens after the model is loaded
      (``llm`` was referenced before the ``with`` statement bound it),
    * the unused ``no_layers`` is now the ``num_layers`` that the body reads,
    * the bare ``raise`` that made the saving code unreachable is removed,
    * the direct contributions are flattened once, after the tracing loop,
    * the DataFrame is assembled with a single ``pd.concat``, and the output
      directory is created if missing.
    """
    # Needed so that `del` below removes the notebook-level references.
    global direct_parts, model_cache

    num_layers = 10
    data_dir = f'/raid/dacslab/CONCEPT_FORMATION/homograph_small/Qwen_Qwen3-4B/{sample_id}/'

    data_keys = ['ent_pos_idx', '9.post']
    target_token_res_states = TransformerCache.load(data_dir, data_keys, model_id=model_ids[0])
    target_token_res_states.stack_tensors()

    num_homographs = 1
    num_examples = 5

    concept_vectors = {}
    post_pairs = target_token_res_states['9.post'].view(num_homographs, 2, num_examples, -1)
    concept_vectors['9.post'] = get_mass_mean_vectors(post_pairs)

    target_token_pos = target_token_res_states['ent_pos_idx']

    del target_token_res_states
    gc.collect()

    target_module = '9.post'
    target_layer = 10
    batch_size = 32

    config = AutoConfig.from_pretrained(model_ids[0])
    module_lookup = get_module_lookup(config, target_layer, num_tokens)

    direct_contributions = get_direct_contributions(direct_parts, concept_vectors, module_lookup, target_module, target_token_pos)

    del direct_parts
    gc.collect()

    # get projection cache
    data_keys = [f"{layer}.{module}" for layer in range(num_layers) for module in ['v_proj', 'attn_norm_var', 'up_proj', 'mlp_norm_var']]
    projection_cache = TransformerCache.load(data_dir, data_keys, model_id=model_ids[0])
    projection_cache.stack_tensors(padding=True)

    # free memory
    del model_cache
    gc.collect()

    task_queue = update_task_queue(DataDict(), direct_contributions)

    indirect_contributions = []
    with load_language_model(
        model_id=model_ids[0],
        trust_remote_code=True,
        device_map='auto',
        dtype=torch.bfloat16,
        attn_implementation='eager'
    ) as (llm, config):
        # The model weights are needed to build the target vectors, so the
        # first active set can only be prepared once the model is available.
        active_tasks, task_queue = get_active_tasks(llm, task_queue, projection_cache)

        while active_tasks.length:
            print(active_tasks['target_modules'][0])
            indirect_contributions.append(DataDict())
            dl = DataLoader(active_tasks, batch_size=batch_size, collate_fn=TensorCollator(), shuffle=False)

            for batch in batch_iterator(dl):
                target_module = batch['target_modules'][0]
                contribution_ceiling_id = get_contribution_ceiling_id(target_module, module_lookup)
                batch_parts = get_batch_parts(batch, indirect_parts_cache, contribution_ceiling_id)

                contribution_idx, contribution_values = get_contributions(batch, batch_parts)
                contribution_modules = [[module_lookup[module_id] for module_id in idx] for idx in contribution_idx]
                indirect_contributions[-1].extend({
                    'contribution_idx': contribution_idx,
                    'contribution_values': contribution_values,
                    'contribution_modules': contribution_modules,
                    'target_sample_ids': batch['target_sample_ids'],
                    'target_token_pos': batch['target_token_pos'],
                    'target_modules': batch['target_modules'],
                    'target_idx': batch['target_idx'],
                    'depth': batch['depth'],
                    })

            # Stop after a layer-0 attention group has been processed.
            if target_module.startswith('0.attn'):
                break
            task_queue = update_task_queue(task_queue, indirect_contributions[-1])
            active_tasks, task_queue = get_active_tasks(llm, task_queue, projection_cache)

    direct_flat_contributions = flatten_contributions(direct_contributions)
    frames = [pd.DataFrame(data=direct_flat_contributions.to_dict())]

    for contribution_cache in indirect_contributions:
        flat_contributions = flatten_contributions(contribution_cache)
        if not flat_contributions.length:
            continue
        # convert in slices of 2000 rows, as before, but concatenate only once
        for batch_start in range(0, len(flat_contributions), 2000):
            frames.append(pd.DataFrame(data=flat_contributions[batch_start: batch_start + 2000].to_dict()))

    df = pd.concat(frames)

    # split the module name into separate columns; 'emb' gets layer 0 / head 0
    df['layer'] = df['contribution_modules'].apply(lambda x: int(x.split('.')[0]) if x != 'emb' else 0)
    df['module_type'] = pd.Categorical(df['contribution_modules'].apply(lambda x: x.split('.')[1] if x != 'emb' else 'emb'), categories=['emb', 'attn', 'mlp'], ordered=True)
    df['token_pos'] = df.apply(lambda x: int(x['contribution_modules'].split('.')[2]) if x['contribution_modules'] != 'emb' and int(x['contribution_modules'].split('.')[2]) != -1  else x['source_token_pos'], axis=1)
    df['head_id'] = df['contribution_modules'].apply(lambda x: int(x.split('.')[3]) if x != 'emb' else 0)
    df = df.drop('contribution_modules', axis=1)

    data_path = os.path.join(project_root, f'data/contribution_cache/cache_9_post_with_norm_hom_{sample_id}.parquet')
    os.makedirs(os.path.dirname(data_path), exist_ok=True)
    df.to_parquet(data_path, index=False)

In [8]:
%%script false --no-raise-error

# Builds `direct_parts` (plus `num_tokens` and `target_token_pos`) for one sample.
# NOTE(review): this cell is disabled by the `%%script false` magic above, yet
# `main` reads `direct_parts` and `num_tokens` - confirm it is meant to run.
sample_id = 0
num_layers = 10
data_dir = f'/raid/dacslab/CONCEPT_FORMATION/homograph_small/Qwen_Qwen3-4B/{sample_id}/'
out_dir = f'/raid/dacslab/CONCEPT_FORMATION/homograph_small/Qwen_Qwen3-4B/.cache/{sample_id}/'

# load model cache
data_keys = ['ent_pos_idx', 'full_emb'] + [f"{layer}.{module}" for layer in range(num_layers) for module in ['d_attn', 'd_mlp']]
model_cache = TransformerCache.load(data_dir, data_keys, model_id=model_ids[0])
model_cache.stack_tensors(padding=True)

num_examples = model_cache.length
num_tokens = model_cache['0.d_attn'].size(1)
target_token_pos = model_cache['ent_pos_idx']
# index pair selecting, for every example, the entry at its entity token position
batch_ent_idx = (range(model_cache.length), target_token_pos)

# get parts for direct contributions
direct_parts = [model_cache['full_emb'][batch_ent_idx][:, None]]

# per layer: attention parts (middle axes flattened) followed by MLP parts
for layer in range(num_layers):
    direct_parts.append(model_cache[f'{layer}.d_attn'][batch_ent_idx].flatten(1, -2))
    direct_parts.append(model_cache[f"{layer}.d_mlp"][batch_ent_idx])
# join all parts along axis 1, then swap the first two axes
direct_parts = torch.concat(direct_parts, dim=1).transpose(0,1)




In [None]:
# Loads the embedding and the per-layer attention / MLP entries for one sample;
# the next cell consumes `model_cache` to build the indirect parts cache.
sample_id = 0
num_layers = 10
data_dir = f'/raid/dacslab/CONCEPT_FORMATION/homograph_small/Qwen_Qwen3-4B/{sample_id}/'
out_dir = f'/raid/dacslab/CONCEPT_FORMATION/homograph_small/Qwen_Qwen3-4B/.cache/{sample_id}/'

# load model cache
data_keys = ['full_emb'] + [f"{layer}.{module}" for layer in range(num_layers) for module in ['d_attn', 'd_mlp']]
model_cache = TransformerCache.load(data_dir, data_keys, model_id=model_ids[0])
model_cache.stack_tensors(padding=True)
print('done loading')


done loading


: 

In [None]:
# Assemble one tensor holding, for every example and token, all parts in the
# order: embedding, then per layer the attention parts followed by the MLP parts.
# Each source entry is mean-shifted over its last axis before being copied in
# and is removed from `model_cache.data` right afterwards to limit peak memory.

# number of slots each kind of part occupies along the parts axis
num_emb_cont = model_cache['full_emb'][:, :, None].shape[-2]
num_attn_cont = model_cache[f'0.d_attn'].flatten(2, -2).shape[-2]
num_mlp_cont = model_cache[f"0.d_mlp"].shape[-2]
total_cont = num_emb_cont + num_mlp_cont * num_layers + num_attn_cont * num_layers

num_example, num_token, _, model_dim = model_cache['full_emb'][:, :, None].shape

# uninitialised buffer; every slot is expected to be written below
indirect_parts_cache = torch.empty((num_example, num_token, total_cont, model_dim))

# write cursor along the parts axis
curr_start = 0
# get indirect parts cache
indirect_parts_cache[:, :, curr_start:num_emb_cont] = mean_shift(model_cache['full_emb'])[:, :, None]
curr_start += num_emb_cont
del model_cache.data[f'full_emb']
gc.collect()

for layer in range(num_layers):
    indirect_parts_cache[:, :, curr_start: curr_start + num_attn_cont] = mean_shift(model_cache[f'{layer}.d_attn']).flatten(2, -2)
    curr_start += num_attn_cont
    del model_cache.data[f'{layer}.d_attn']
    gc.collect()

    indirect_parts_cache[:, :, curr_start: curr_start + num_mlp_cont] = mean_shift(model_cache[f"{layer}.d_mlp"])
    curr_start += num_mlp_cont
    del model_cache.data[f'{layer}.d_mlp']
    gc.collect()

print('done adding')

os.makedirs(out_dir, exist_ok=True)
# NOTE(review): torch.save writes PyTorch's own serialisation format; the
# '.safetensor' suffix is only a file name and does not make this a safetensors
# file - confirm readers load it with torch.load.
torch.save(indirect_parts_cache, os.path.join(out_dir, 'indirect_parts.safetensor'))

RuntimeError: No active exception to reraise

In [None]:
# NOTE(review): a bare `raise` with no active exception raises RuntimeError, so
# `main(0)` below never runs - presumably a deliberate guard against
# accidental execution of this cell; remove it to run the trace.
raise 
main(0)

Batch 1/1: 2.27 seconds


In [None]:
# for sample_id in range(4, 5):
#     main(sample_id)
#     gc.collect()
#     torch.cuda.empty_cache()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Batch 1/1: 1.63 seconds
9.mlp.-1.9705
Batch 1/5: 7.25 seconds
Batch 2/5: 9.65 seconds
Batch 3/5: 9.09 seconds
Batch 4/5: 9.47 seconds
Batch 5/5: 5.69 seconds
9.attn.3.2710
Batch 1/1: 5.35 seconds
8.mlp.-1.9668
Batch 1/6: 8.60 seconds
Batch 2/6: 8.60 seconds
Batch 3/6: 9.00 seconds
Batch 4/6: 9.06 seconds
Batch 5/6: 8.42 seconds
Batch 6/6: 2.71 seconds
8.attn.4.2033
Batch 1/1: 2.52 seconds
7.mlp.-1.9679
Batch 1/7: 6.39 seconds
Batch 2/7: 7.23 seconds
Batch 3/7: 7.41 seconds
Batch 4/7: 7.21 seconds
Batch 5/7: 7.29 seconds
Batch 6/7: 7.24 seconds
Batch 7/7: 4.82 seconds
7.attn.4.1487
Batch 1/1: 5.71 seconds
6.mlp.-1.9695
Batch 1/9: 6.14 seconds
Batch 2/9: 6.03 seconds
Batch 3/9: 6.40 seconds
Batch 4/9: 6.39 seconds
Batch 5/9: 6.26 seconds
Batch 6/9: 6.21 seconds
Batch 7/9: 6.00 seconds
Batch 8/9: 6.24 seconds
Batch 9/9: 3.98 seconds
6.attn.4.1278
Batch 1/10: 4.96 seconds
Batch 2/10: 5.79 seconds
Batch 3/10: 5.49 seconds
Batch 4/10: 5.50 seconds
Batch 5/10: 5.49 seconds
Batch 6/10: 7.21 se