In [1]:
%load_ext autoreload
%autoreload 2
import os 
os.environ["CUDA_VISIBLE_DEVICES"]="2"
import json
import sys
import gc
from typing import Dict, List, Union

project_root = os.path.abspath(os.path.join(os.getcwd()))
if project_root not in sys.path:
    sys.path.append(project_root)

from transformers import AutoModel, AutoTokenizer, PretrainedConfig, AutoConfig
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from nnsight import LanguageModel

from data.utils import get_entity_idx, get_prompts
from mi_toolbox.utils.data_types import DataDict
from mi_toolbox.utils.collate import TensorCollator
from mi_toolbox.transformer_caching import caching_wrapper, decompose_attention_to_neuron, decompose_glu_to_neuron, TransformerCache
from mi_toolbox.causal_tracing import get_mass_mean_vectors
from mi_toolbox.contribution_tracing import get_dot_prod_contribution, get_top_x_contribution_values


# Hugging Face model ids to analyse. Only the first entry is used below
# (tokenizer and config are loaded from model_ids[0]); the other sizes are
# kept in the trailing comments for reference.
model_ids = ["Qwen/Qwen3-4B"]#, "Qwen/Qwen3-8B", "Qwen/Qwen3-14B"]  # "Qwen/Qwen3-0.6B","Qwen/Qwen3-1.7B"
# Tokenizer and model config are loaded up front; the model weights are not loaded here.
tokenizer = AutoTokenizer.from_pretrained(model_ids[0])
config = AutoConfig.from_pretrained(model_ids[0])

### Get Concept Vectors

In [2]:
sample_id = 0
num_layers = 10
data_dir = f'/raid/dacslab/CONCEPT_FORMATION/homograph_small/Qwen_Qwen3-4B/{sample_id}/'

#TODO incorporate into TransformerCache -> also init from pretrained (AutoConfig)
# Cache entries to read: the embedding output plus the 'mid' and 'post'
# entry of each of the first `num_layers` layers.
data_keys = ['emb', *(f"{layer}.{module}" for layer in range(num_layers) for module in ('mid', 'post'))]
data_dir_files = os.listdir(data_dir)
target_token_res_states = {}
for key in data_keys:
    json_name = f"{key}.json"
    tensor_name = f"{key}.safetensors"

    # A key may be stored as JSON and/or as a tensor file. When both exist
    # the tensor file wins because it is assigned last. Keys with no file
    # on disk are skipped silently.
    if json_name in data_dir_files:
        file_path = os.path.join(data_dir, json_name)
        with open(file_path, 'r') as f:
            target_token_res_states[key] = json.load(f)
    if tensor_name in data_dir_files:
        # Despite the extension, the file is read with torch.load and the
        # loaded sequence of tensors is stacked into a single tensor.
        file_path = os.path.join(data_dir, tensor_name)
        target_token_res_states[key] = torch.stack(torch.load(file_path))

target_token_res_states = TransformerCache.from_dict(target_token_res_states, model_config=config)


In [None]:
# Build one mass-mean vector per cached entry ('emb' and every layer's
# 'mid' / 'post' entry).
#
# Each cached tensor is viewed as (num_homographs, 2, num_examples, -1)
# before being handed to get_mass_mean_vectors.
# NOTE(review): the axis of size 2 presumably separates the two senses of
# a homograph and num_examples counts examples per sense - confirm.
num_homographs = 1
num_examples = 5
# Keep the legacy name `no_layers` (later cells read it) but tie it to
# `num_layers` so the two layer counts cannot drift apart.
no_layers = num_layers

concept_vector_keys = ['emb'] + [
    f'{layer}.{module}' for layer in range(no_layers) for module in ('mid', 'post')
]

concept_vectors = {}
for key in concept_vector_keys:
    paired_states = target_token_res_states[key].view(num_homographs, 2, num_examples, -1)
    concept_vectors[key] = get_mass_mean_vectors(paired_states)

# The cached states are no longer needed once the concept vectors exist.
del target_token_res_states
gc.collect()

207

: 

### Get direct contributions

In [None]:
# Load the entity position index, the embedding output and the 'd_attn' /
# 'd_mlp' cache entries of the first `num_layers` layers.
# NOTE(review): 'd_attn' / 'd_mlp' are presumably the decomposed per-unit
# attention / MLP outputs - confirm against the caching code.
data_keys = ['ent_pos_idx', 'emb'] + [f"{layer}.{module}" for layer in range(num_layers) for module in ['d_attn', 'd_mlp']]
data_dir_files = os.listdir(data_dir)
model_cache = {}
for key in data_keys:
    # JSON and tensor files are both accepted; a tensor file overrides a
    # JSON file with the same key. Missing keys are skipped silently.
    if f"{key}.json" in data_dir_files:
        file_path = os.path.join(data_dir, f"{key}.json")
        with open(file_path, 'r') as f:
            model_cache[key] = json.load(f)
    if f"{key}.safetensors" in data_dir_files:
        # Read with torch.load even though the extension says safetensors.
        file_path  = os.path.join(data_dir, f"{key}.safetensors")
        model_cache[key] = torch.load(file_path)

model_cache = TransformerCache.from_dict(model_cache, model_config=config)

# Reduce every per-layer entry to the row at the entity position.
# `data_keys` is rebound first because the lambda looks it up when called.
# NOTE(review): the return value of `map` is discarded, so this relies on
# `map` updating the cache in place - confirm.
data_keys = [f"{layer}.{module}" for layer in range(num_layers) for module in ['d_attn', 'd_mlp']]
model_cache.map(
    fn = lambda row_id, row: {key:row[key][row['ent_pos_idx']] for key in data_keys}
)
model_cache.stack_tensors(padding=True)

# Note: this rebinds `num_examples`, which an earlier cell set to 5.
num_examples = model_cache.length
num_tokens = model_cache['0.d_attn'].size(1)
target_token_pos = model_cache['ent_pos_idx']

# Concatenate all parts along dim 1 in the order: emb, then per layer the
# flattened attention entry followed by the MLP entry. get_module_list
# emits names in the same order.
parts = [model_cache['emb'][:, None]]

# Iterate over the same `num_layers` that was used to build the keys above
# (this loop previously used the separately defined `no_layers`).
for layer in range(num_layers):
    parts.append(model_cache[f'{layer}.d_attn'].flatten(1, -2))
    parts.append(model_cache[f"{layer}.d_mlp"])
# After the transpose the part axis comes first.
parts = torch.concat(parts, dim=1).transpose(0,1)

del model_cache
gc.collect()

In [None]:
def get_module_list(llm, target_layer, num_tokens):
    """Return the names of all contributing units up to ``target_layer``.

    The list starts with ``'emb'``. For every layer in
    ``0..target_layer`` (inclusive) it then holds one
    ``"{layer}.attn.{token_pos}.{unit}"`` entry per token position and
    attention unit (``num_attention_heads * head_dim`` units), followed by
    one ``"{layer}.mlp.-1.{unit}"`` entry per MLP unit
    (``intermediate_size`` units).

    Args:
        llm: object exposing ``config.num_attention_heads``,
            ``config.head_dim`` and ``config.intermediate_size``.
        target_layer: index of the last layer to include.
        num_tokens: number of token positions per attention unit.

    Returns:
        list[str]: the unit names in the order described above.
    """
    cfg = llm.config
    attn_width = cfg.num_attention_heads * cfg.head_dim
    mlp_width = cfg.intermediate_size

    names = ['emb']
    for layer_id in range(target_layer + 1):
        # Attention units: token position is the outer index, unit the inner.
        for token_pos in range(num_tokens):
            names.extend(f"{layer_id}.attn.{token_pos}.{unit}" for unit in range(attn_width))
        # MLP units carry the placeholder token position -1.
        names.extend(f"{layer_id}.mlp.-1.{unit}" for unit in range(mlp_width))

    return names

def get_first_module_group_id_lookup(module_list):
    """Map each ``"{layer}.{module_type}"`` group to the index where it starts.

    ``module_list`` holds names of the form
    ``"{layer}.{module_type}.{token_pos}.{unit}"`` plus optional ``'emb'``
    entries, which are ignored. An index is recorded whenever the
    (layer, module_type) pair differs from that of the previous non-``emb``
    entry, so a group that re-appears later overwrites its earlier index.

    Args:
        module_list: sequence of module name strings.

    Returns:
        dict[str, int]: group key -> index of the first entry of its
        most recent run in ``module_list``.
    """
    group_starts = {}
    previous_group = None  # (layer, module_type) of the last non-'emb' entry

    for index, name in enumerate(module_list):
        if name == 'emb':
            continue

        layer, module_type, _, _ = name.split('.')
        current_group = (layer, module_type)
        if current_group != previous_group:
            group_starts[f"{layer}.{module_type}"] = index
        previous_group = current_group

    return group_starts

def get_initial_tasks(concept_vectors, target_module, target_layer, num_examples):
    """Create the depth-0 tracing tasks, one per example.

    The concept vector of every homograph is repeated once per example of
    that homograph. The first half of the copies keeps its sign and the
    second half is negated. The number of copies is derived from
    ``num_examples`` and the number of concept vectors; for one homograph
    and ten examples this gives five positive and five negated copies.

    Args:
        concept_vectors: mapping from module name to a tensor whose first
            dimension indexes homographs.
            NOTE(review): assumed to be (num_homographs, hidden) - confirm.
        target_module: key into ``concept_vectors`` (e.g. ``'9.post'``).
        target_layer: layer id stored with every task.
        num_examples: total number of examples, i.e. number of tasks.

    Returns:
        DataDict: tasks with the columns ``target_vectors``,
        ``target_sample_id``, ``target_contribution_factor``,
        ``target_layer``, ``target_module`` and ``depth``.

    Raises:
        ValueError: if ``num_examples`` cannot be split into an even number
            of examples per homograph.
    """
    target_vectors = concept_vectors[target_module]

    num_homographs = target_vectors.size(0)
    if num_homographs == 0:
        raise ValueError(f"no concept vectors available for module {target_module!r}")
    examples_per_homograph, remainder = divmod(num_examples, num_homographs)
    if remainder or examples_per_homograph % 2:
        raise ValueError(
            f"num_examples={num_examples} cannot be split into an even number of "
            f"examples for each of the {num_homographs} homograph(s)"
        )
    half = examples_per_homograph // 2

    # +1 for the first half of a homograph's examples, -1 for the second.
    # Created with the vectors' dtype and device so the product also works
    # when the concept vectors are not on the CPU.
    signs = torch.tensor(
        [[1]] * half + [[-1]] * half,
        dtype=target_vectors.dtype,
        device=target_vectors.device,
    )
    target_vectors = torch.flatten(
        target_vectors[:, None].repeat(1, examples_per_homograph, 1) * signs, 0, 1
    )

    tasks = DataDict(length=num_examples)
    tasks.attach('target_vectors', target_vectors, force= True)

    tasks.attach('target_sample_id', range(num_examples), force= True)
    tasks.attach('target_contribution_factor', torch.tensor([1] * num_examples), force= True)
    tasks.attach('target_layer', [target_layer] * num_examples, force= True)
    tasks.attach('target_module', [target_module] * num_examples, force= True)
    # One independent depth list per task, so that changing one task's list
    # in place cannot affect the others.
    tasks.attach('depth', [[0] for _ in range(num_examples)], force= True)

    return tasks

In [None]:
# Tracing target: the concept vector stored under '9.post'.
target_module = '9.post'
# `target_layer` is used below as the first layer index that is dropped
# from the model, and is passed on to get_module_list / get_initial_tasks.
target_layer = 10
# Not used in this cell.
contribution_cache = []
batch_size = 16

# Load the base model in bfloat16 with eager attention.
# NOTE(review): trust_remote_code=True allows execution of code shipped
# with the checkpoint - confirm that it is actually required here.
llm = AutoModel.from_pretrained(
    model_ids[0],
    trust_remote_code=True,
    device_map='auto',
    dtype=torch.bfloat16,
    attn_implementation = 'eager'
)
# Move every layer from index `target_layer` onwards to the 'meta' device,
# then release cached GPU memory.
# NOTE(review): presumably done because only earlier layers are traced - confirm.
llm.layers[target_layer:].to('meta')
gc.collect()
torch.cuda.empty_cache()

# NOTE(review): get_module_list iterates range(target_layer + 1), so with
# target_layer = 10 the list also names units of layer 10, which was just
# moved to 'meta' above - confirm whether target_layer - 1 is intended.
module_list = get_module_list(llm, target_layer, num_tokens)
first_module_group_id_lookup = get_first_module_group_id_lookup(module_list)
curr_tasks = get_initial_tasks(concept_vectors, target_module, target_layer, num_examples)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
# Compute, for every task, which parts contribute to its target vector and
# keep only the entries selected by get_top_x_contribution_values.
collate_fn = TensorCollator()
dl = DataLoader(curr_tasks, batch_size=batch_size, collate_fn=collate_fn, shuffle=False)
num_batches = len(dl)
direct_contributions = DataDict()

for batch in dl:
    target_vectors = batch['target_vectors']
    # Use a batch-local name: rebinding `sample_id` / `target_layer` here
    # would overwrite the notebook-level configuration values of the same
    # name.
    batch_sample_ids = batch['target_sample_id']

    # NOTE(review): `parts` is passed unsliced, i.e. for all examples, while
    # `whole` only holds this batch's target vectors. This presumably only
    # lines up while all examples fit into a single batch - confirm.
    contributions = get_dot_prod_contribution(parts = parts, whole= target_vectors).transpose(0, 1)
    top_x_contributions = get_top_x_contribution_values(contributions, 0.9)

    # Per task: indices of the non-zero entries, their values and their names.
    contribution_idx = [el.nonzero(as_tuple=True)[0] for el in top_x_contributions]
    contribution_values = [top_x_contributions[i][idx] for i, idx in enumerate(contribution_idx)]
    contribution_modules = [[module_list[i] for i in idx] for idx in contribution_idx]

    # Attach only the token positions of the samples in this batch. The
    # full `target_token_pos` has one entry per example and would not match
    # the other columns once there is more than one batch.
    row_ids = [int(i) for i in batch_sample_ids]
    if torch.is_tensor(target_token_pos):
        batch_token_pos = target_token_pos[row_ids]
    else:
        batch_token_pos = [target_token_pos[i] for i in row_ids]

    direct_contributions.extend({
        'contribution_idx': contribution_idx,
        'contribution_values':contribution_values,
        'contribution_modules':contribution_modules,
        'target_sample_id': batch_sample_ids,
        'target_token_pos': batch_token_pos,
        'target_module': batch['target_module'],
        'depth': batch['depth'],
        })

# `parts` is no longer needed once all direct contributions are collected.
del parts
gc.collect()

1361

### Get indirect contributions

In [None]:
@torch.no_grad()
def get_target_vectors_and_postions(llm, cache_item, model_cache):
    """Turn the contributing units of one cache item into new tracing targets.

    For every non-``'emb'`` entry of ``cache_item['contribution_modules']``
    (names of the form ``"{layer}.{attn|mlp}.{token_pos}.{unit}"``) this
    computes

    * a target vector ``weight_row * norm_weight / norm_var`` where
      ``weight_row`` is the unit's row of ``v_proj`` (attn) or ``up_proj``
      (mlp) and ``norm_weight`` is the weight of the layer norm in front of
      that sub-module, and
    * a contribution factor ``contribution_value / activation`` where
      ``activation`` is the unit's cached ``v_proj`` / ``up_proj`` value.

    A token position of ``-1`` in the name is replaced by the cache item's
    own ``target_token_pos``.

    NOTE(review): ``norm_var`` is read at the source token position even for
    attention units whose activation is read at another position - confirm
    this is intended.
    NOTE(review): the factor divides by the cached activation with no
    guard against a zero activation - confirm whether one is needed.

    Args:
        llm: model exposing ``config`` and ``layers``.
        cache_item: row with ``contribution_modules``,
            ``contribution_values``, ``target_sample_id`` and
            ``target_token_pos``.
        model_cache: mapping with ``"{layer}.v_proj"``,
            ``"{layer}.attn_norm_var"``, ``"{layer}.up_proj"`` and
            ``"{layer}.mlp_norm_var"`` entries indexable by
            ``[sample_id, token_pos]``.

    Returns:
        tuple: ``(target_vectors, token_positions, layers, modules,
        contribution_factors, valid_contribution_idx)``. ``target_vectors``
        is a stacked CPU tensor, with shape ``(0, hidden_size)`` when there
        is nothing to trace. The other five are lists aligned with its rows;
        ``valid_contribution_idx`` holds the positions of the used entries
        in ``contribution_modules``.

    Raises:
        ValueError: if a module name has a type other than ``attn``/``mlp``.
    """
    contribution_modules = cache_item['contribution_modules']
    contribution_values = cache_item['contribution_values']
    sample_id = cache_item['target_sample_id']
    source_token_pos = cache_item['target_token_pos']

    cfg = llm.config
    num_k_v_heads = cfg.num_key_value_heads
    num_head_groups = cfg.num_attention_heads // num_k_v_heads
    head_dim = cfg.head_dim
    hidden_dim = cfg.hidden_size

    new_target_vectors = []
    new_target_token_pos = []
    new_target_module = []
    new_target_layer = []
    new_target_contribution_factor = []
    valid_contribution_idx = []

    for i, (module_desc, cont_value) in enumerate(zip(contribution_modules, contribution_values)):
        if module_desc == 'emb':
            continue

        layer_id, module_type, target_token_pos, head_pos = module_desc.split('.')
        layer_id, head_pos, target_token_pos = int(layer_id), int(head_pos), int(target_token_pos)

        if target_token_pos == -1:
            target_token_pos = source_token_pos

        layer = llm.layers[layer_id]
        if module_type == 'attn':
            # Map the flat unit index to (key/value head, dim inside head);
            # several attention heads share one key/value head.
            kv_head = head_pos // head_dim // num_head_groups
            dim_in_head = head_pos % head_dim

            weight_row = layer.self_attn.v_proj.weight.data.view(num_k_v_heads, head_dim, hidden_dim)[kv_head, dim_in_head]
            norm = layer.input_layernorm.weight.data
            activation = model_cache[f"{layer_id}.v_proj"][sample_id, target_token_pos].view(num_k_v_heads, head_dim)[kv_head, dim_in_head]
            norm_var = model_cache[f"{layer_id}.attn_norm_var"][sample_id, source_token_pos]
        elif module_type == 'mlp':
            weight_row = layer.mlp.up_proj.weight.data[head_pos]
            norm = layer.post_attention_layernorm.weight.data
            activation = model_cache[f"{layer_id}.up_proj"][sample_id, target_token_pos][head_pos]
            norm_var = model_cache[f"{layer_id}.mlp_norm_var"][sample_id, source_token_pos]
        else:
            # Without this guard a stale vector from the previous iteration
            # would be appended and the output lists would get out of sync.
            raise ValueError(f"unknown module type {module_type!r} in {module_desc!r}")

        # Compute on the device that holds the weights instead of assuming CUDA.
        device = weight_row.device
        target_vector = weight_row * norm / norm_var.to(device)
        target_contribution_factor = cont_value / activation.to(device)

        # Record everything only after the computation succeeded so that all
        # output lists always have the same length.
        new_target_vectors.append(target_vector.cpu())
        new_target_contribution_factor.append(target_contribution_factor.cpu())
        new_target_token_pos.append(target_token_pos)
        new_target_module.append(f"{layer_id}.{module_type}")
        new_target_layer.append(layer_id)
        valid_contribution_idx.append(i)

    if new_target_vectors:
        new_target_vectors = torch.stack(new_target_vectors)
    else:
        # torch.stack rejects an empty list; return an empty batch instead.
        new_target_vectors = torch.empty((0, hidden_dim))

    return new_target_vectors, new_target_token_pos, new_target_layer, new_target_module, new_target_contribution_factor, valid_contribution_idx

@torch.no_grad()
def trache_through_layer_norm(target: torch.Tensor, norm_W: torch.Tensor, norm_var: torch.Tensor):
    """Scale ``target`` by a layer-norm weight and divide by ``norm_var``.

    This is the same expression that get_target_vectors_and_postions
    applies inline (``weight_row * norm / norm_var``).
    NOTE(review): the original was an unfinished stub with a broken
    signature; the body is inferred from that inline expression - confirm.

    Args:
        target: vector (or batch of vectors) to transform.
        norm_W: weight of the normalisation layer, broadcastable to ``target``.
        norm_var: cached normaliser to divide by, broadcastable to ``target``.
            NOTE(review): presumably the cached "*_norm_var" entry of the
            model cache - confirm.

    Returns:
        torch.Tensor: ``target * norm_W / norm_var`` on ``target``'s device.
    """
    # Move the operands next to `target` so CPU-cached statistics can be
    # combined with weights that live on an accelerator.
    return target * norm_W.to(target.device) / norm_var.to(target.device)

def get_current_tasks(tasks, contribution_cache):
    """Queue follow-up tasks for all cached contributions and pick the next batch.

    Every item of ``contribution_cache`` that has contribution modules and
    whose contribution indices are not all zero is expanded with
    get_target_vectors_and_postions into new tasks, which are appended to
    ``tasks`` with every depth id increased by one. The task list is then
    merged with merge_tasks and split with get_next_curr_tasks.

    Reads the notebook-level names ``llm`` and ``model_cache``; both must
    exist when this function is called.

    Args:
        tasks: DataDict of pending tasks; it is extended in place.
        contribution_cache: object with ``to_list()`` yielding rows that hold
            ``contribution_modules``, ``contribution_idx``, ``depth`` and
            ``target_sample_id``.

    Returns:
        tuple: ``(curr_tasks, tasks)`` - the tasks to process next and the
        remaining tasks.
    """
    for item in contribution_cache.to_list():
        # Nothing to trace: no contributing modules at all ...
        if not item['contribution_modules']:
            continue
        # ... or every contribution index is 0.
        if torch.all(item['contribution_idx'] == 0):
            continue

        traced = get_target_vectors_and_postions(llm, item, model_cache)
        vectors, token_positions, layers, modules, factors, kept = traced

        task_count = len(vectors)
        origin_idx = [item['contribution_idx'][k] for k in kept]
        # increase all depth ids of the previous contribution by one
        next_depth = [level + 1 for level in item['depth']]

        tasks.extend({
            'source_idx': origin_idx,
            'target_module': modules,
            'target_vectors': vectors,
            'target_contribution_factor': factors,
            'target_layer': layers,
            'target_sample_id': [item['target_sample_id']] * task_count,
            'target_token_pos': token_positions,
            'depth': [next_depth] * task_count
        })

    # create depth list with all depths of querying tokens
    merged = merge_tasks(tasks)
    curr_tasks, tasks = get_next_curr_tasks(merged)

    return curr_tasks, tasks

def merge_tasks(tasks):
    """Merge tasks that target the same unit of the same sample and token.

    ``tasks`` is sorted in place by ``target_sample_id``,
    ``target_token_pos`` and ``source_idx`` (descending). Consecutive rows
    that agree on all three keys are collapsed into one row: its
    ``target_contribution_factor`` is the sum over the group and its
    ``depth`` is the sorted concatenation of the group's depth lists. All
    other columns are taken from the first row of the group.

    NOTE(review): grouping relies on the three successive sorts leaving
    equal keys adjacent, i.e. on ``DataDict.sort`` being stable - confirm.

    Args:
        tasks: DataDict of tasks; it is re-ordered in place.

    Returns:
        DataDict: the merged tasks (empty if ``tasks`` is empty).
    """
    tasks.sort(by='target_sample_id')
    tasks.sort(by='target_token_pos')
    tasks.sort(by='source_idx', descending=True)

    merged_tasks = DataDict()

    rows = tasks.to_list()
    if not rows:
        return merged_tasks

    group_head = rows[0]
    factor_sum = group_head['target_contribution_factor']
    group_depths = list(group_head['depth'])

    for task in rows[1:]:
        if task['target_sample_id'] == group_head['target_sample_id'] \
            and task['source_idx'] == group_head['source_idx'] \
            and task['target_token_pos'] == group_head['target_token_pos']:

            # Out-of-place addition: never modify a factor tensor that is
            # still referenced by a row of `tasks`.
            factor_sum = factor_sum + task['target_contribution_factor']
            group_depths += task['depth']
            continue

        merged_tasks.append(
            group_head | {'target_contribution_factor': factor_sum, 'depth': sorted(group_depths)}
        )
        # Start a new group. The depth accumulator is reset as well, so
        # depths of earlier groups do not leak into later merged tasks.
        group_head = task
        factor_sum = task['target_contribution_factor']
        group_depths = list(task['depth'])

    merged_tasks.append(
        group_head | {'target_contribution_factor': factor_sum, 'depth': sorted(group_depths)}
    )

    return merged_tasks

def get_next_curr_tasks(tasks):
    """Split off the leading run of tasks that share one target module.

    ``tasks`` is sorted in place by ``target_module`` in descending order.
    All rows with the same module as the first row form the current batch;
    the rest stay pending.

    NOTE(review): ``target_module`` values are strings such as ``'9.mlp'``,
    so if the sort compares them as strings ``'10.mlp'`` orders below
    ``'9.mlp'`` - confirm the ordering is as intended for ten or more layers.

    Args:
        tasks: non-empty DataDict of pending tasks.

    Returns:
        tuple: ``(curr_tasks, remaining_tasks)``. When every task shares the
        same module, all tasks are returned as current and the remainder is
        an empty DataDict.
    """
    tasks.sort(by='target_module', descending=True)
    leading_module = tasks[0]['target_module']

    # Position of the first row that belongs to a different module, if any.
    split_at = None
    for position, module in enumerate(tasks['target_module']):
        if module != leading_module:
            split_at = position
            break

    if not split_at:
        return tasks, DataDict()

    return tasks[:split_at], tasks[split_at:]

torch.Size([343041, 10, 2560])