In [None]:
%load_ext autoreload
%autoreload 2
import os 
os.environ["CUDA_VISIBLE_DEVICES"]="1"

import torch
import sys
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)
    
tool_root = '/home/dacslab/lassejantsch/master_thesis/mi_toolbox/'
if tool_root not in sys.path:
    sys.path.append(tool_root)

import gc
import pandas as pd
from torch.utils.data import DataLoader
import torch.nn.functional as F
from nnsight import LanguageModel
import json
import time
from typing import *
from transformers import AutoTokenizer, AutoModel

from data.utils import get_entity_indices, get_prompts,

# from tools.mixin_measures import get_dot_prod_contribution
# from tools.mixin_measures.decomposition import decompose_attention_to_neuron, decompose_glu_mlp
# from tools.mixin_measures.contribution import get_top_x_contribution_values
# from tools.mixin_measures.utils import distributed
# from tools.concept_detection import get_mass_mean_vectors
# from lib.custom_types import DataDict
# from lib.utils import max_pad_sequence

data_path = os.path.join(project_root, 'data/homograph_data/homograph_small.json')
with open(data_path) as f:
    data = json.load(f)  
    

ImportError: cannot import name 'get_entity_indices' from 'data.utils' (/home/dacslab/lasse_jantsch/concept_formation/data/utils.py)

In [2]:
# Keep only the first entry of the loaded dataset for this run.
data = data[:1]

## Extract residual contribution

In [3]:
# Models to analyse; the ids after the `#` are other Qwen3 sizes that are currently disabled.
model_ids = ["Qwen/Qwen3-4B"]#, "Qwen/Qwen3-8B", "Qwen/Qwen3-14B"]  # "Qwen/Qwen3-0.6B","Qwen/Qwen3-1.7B"
# Only the tokenizer of the first model is loaded; it is the one used by the collate function below.
tokenizer = AutoTokenizer.from_pretrained(model_ids[0])

# Number of prompts per batch for the extraction DataLoader.
batch_size = 16

In [4]:
# Build the prompts with the 'minimal_context' variant and look up one entity index per prompt.
# presumably `ent_pos_idx` holds the token position of the entity in each tokenized prompt —
# TODO confirm against data.utils.
prompts = get_prompts(data, context_type='minimal_context')
ent_pos_idx = get_entity_indices(tokenizer, prompts)

def extract_collate_fn(batch):
    """Collate (prompt, entity_index) pairs into a padded, tokenized batch.

    Args:
        batch: sequence of ``(prompt, entity_pos_idx)`` pairs.

    Returns:
        The tokenizer output merged with a ``'batch_ent_pos_idx'`` entry: a pair
        ``(row_ids, entity_positions)`` where ``row_ids`` is ``[0, ..., len(batch) - 1]``.
        The pair is later used to index tensors as ``tensor[batch_ent_pos_idx]``.

    Note:
        Uses the module-level ``tokenizer``.
        NOTE(review): the entity positions are only valid after padding if the
        tokenizer pads on the right — confirm the tokenizer's padding side.
    """
    batch_prompts, entity_positions = zip(*batch)
    encoded = tokenizer(
        batch_prompts,
        padding=True,
        return_tensors='pt',
    )
    row_ids = list(range(len(entity_positions)))
    extra = {'batch_ent_pos_idx': (row_ids, entity_positions)}
    return encoded | extra

# Iterate over (prompt, entity index) pairs in their original order (no shuffling).
extract_dl = DataLoader(list(zip(prompts, ent_pos_idx)), batch_size=batch_size, shuffle= False, collate_fn=extract_collate_fn)

In [5]:
# Per-model cache of extracted activations, keyed by model id.
big_cache = {}
# Only the layers selected by this slice are traced.
layer_slice = slice(0,10)

for model_id in model_ids:
    # Bind the names before the try block so the cleanup in `finally` never raises a
    # NameError (and never deletes stale objects from an earlier run) when loading fails.
    llm = None
    model_cache = None
    try:
        llm = LanguageModel(
            model_id,
            trust_remote_code=True,
            device_map='auto',
            dtype=torch.bfloat16,
            attn_implementation = 'eager',
            dispatch=True
        )
        head_dim, num_attn_heads, num_k_v_heads = llm.config.head_dim, llm.config.num_attention_heads, llm.config.num_key_value_heads
        model_cache = DataDict() # TODO: create a special tensor cache dictionary that also holds config etc.
        num_batches = len(extract_dl)
        
        for batch_id, batch in enumerate(extract_dl):
            batch_ent_pos_idx = batch['batch_ent_pos_idx']
            batch_cache = {}
            batch_cache['attention_mask'] = batch['attention_mask']
            batch_cache['ent_pos_idx'] = batch_ent_pos_idx[1]
            # NOTE: process_time() reports CPU time of this process, not wall-clock time.
            start = time.process_time()
            # Pre-bound for the same reason as `llm` above: `del tracer` must not fail
            # if entering the trace context raises.
            tracer = None
            try:
                with torch.no_grad(), llm.trace(batch) as tracer:
                    
                    emb = llm.model.embed_tokens.output
                    batch_cache['emb'] = emb[batch_ent_pos_idx].cpu().save()
                    batch_cache['full_emb'] = emb.cpu().save()
                    
                    for i, layer in enumerate(llm.model.layers[layer_slice]):
                        attn_norm_var = torch.var(layer.input, dim=-1)
                        
                        # decompose attention out
                        v_proj = layer.self_attn.v_proj.output
                        _, attn_weight = layer.self_attn.output
                        o_proj_WT = layer.self_attn.o_proj.weight.T
                        d_attn = decompose_attention_to_neuron(
                            attn_weight, 
                            v_proj, 
                            o_proj_WT,
                            num_attn_heads,
                            num_k_v_heads,
                            head_dim
                        ) 
                        
                        # extract mid residual state
                        mid = layer.post_attention_layernorm.input[batch_ent_pos_idx]
                        mlp_norm_var = torch.var(layer.post_attention_layernorm.input, dim=-1)

                        # decomposed mlp out    
                        up_proj = layer.mlp.up_proj.output
                        z = layer.mlp.down_proj.input
                        down_proj_WT = layer.mlp.down_proj.weight.T
                        d_mlp = decompose_glu_mlp(z=z, down_proj_WT=down_proj_WT)

                        # extract post residual state
                        post = layer.output[batch_ent_pos_idx]
                        
                        batch_cache[f'{i}.d_attn'] = d_attn.cpu().save()
                        batch_cache[f'{i}.v_proj'] = v_proj.cpu().save()
                        batch_cache[f'{i}.attn_norm_var'] = attn_norm_var.cpu().save()
                        batch_cache[f'{i}.mid'] = mid.cpu().save()
                        batch_cache[f'{i}.d_mlp'] = d_mlp.cpu().save()
                        batch_cache[f'{i}.up_proj'] = up_proj.cpu().save()
                        batch_cache[f'{i}.mlp_norm_var'] = mlp_norm_var.cpu().save()
                        batch_cache[f'{i}.post'] = post.cpu().save()
                    
                    model_cache.extend(batch_cache)
            finally:
                # Drop per-batch references before asking CUDA to release cached memory.
                del tracer
                del batch_cache
                gc.collect()
                torch.cuda.empty_cache()
                end = time.process_time()
                print(f"Batch {batch_id + 1}/{num_batches}: {(end - start):.2f} seconds")
                
        # Entries collected as lists of tensors are padded into a single tensor per key.
        for key, value in model_cache.items():
            if isinstance(value[0], torch.Tensor):
                model_cache.replace(key, max_pad_sequence(value))
        big_cache[model_id] = model_cache
    finally:
        # `llm` is still None when the constructor itself failed.
        if llm is not None:
            llm.cpu()
        del llm
        del model_cache
        gc.collect()
        torch.cuda.empty_cache()
        print(f"Memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
        print(f"Memory reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Batch 1/1: 22.99 seconds
Memory allocated: 0.03 GB
Memory reserved: 6.22 GB


In [6]:
def get_concept_vectors(big_cache, num_homographs, num_examples, no_layers=10):
    """Compute mass-mean concept vectors for the embedding and each cached layer state.

    For every model in ``big_cache`` the cached ``'emb'``, ``'{layer}.mid'`` and
    ``'{layer}.post'`` tensors are reshaped with
    ``view(num_homographs, 2, num_examples, -1)`` and passed to
    ``get_mass_mean_vectors``.

    Args:
        big_cache: mapping ``model id -> model cache`` (indexable by the keys above).
        num_homographs: size of the first axis of the reshaped tensors.
        num_examples: size of the third axis of the reshaped tensors.
            (The fixed middle axis of size 2 is presumably the two senses of a
            homograph — TODO confirm against the dataset layout.)
        no_layers: number of layers to process (layers ``0 .. no_layers - 1``).
            Defaults to 10, the value that used to be hard-coded here.

    Returns:
        dict: ``{model id: {'emb' | '{layer}.mid' | '{layer}.post': concept vectors}}``.
    """
    concept_vectors = {}

    for model in big_cache:
        model_cache = big_cache[model]
        model_vectors = {}

        emb_pairs = model_cache['emb'].view(num_homographs, 2, num_examples, -1)
        model_vectors['emb'] = get_mass_mean_vectors(emb_pairs)

        for layer in range(no_layers):
            for state in ('mid', 'post'):
                key = f'{layer}.{state}'
                pairs = model_cache[key].view(num_homographs, 2, num_examples, -1)
                model_vectors[key] = get_mass_mean_vectors(pairs)

        concept_vectors[model] = model_vectors

    return concept_vectors

# num_homographs=1, num_examples=4: the cached tensors are viewed as (1, 2, 4, -1).
concept_vectors = get_concept_vectors(big_cache, 1, 4)

### Trace Contributions

In [44]:
@torch.no_grad()
def get_contribution_parts(cache: dict, stop_layer: int, batch_token_pos_idx: List[Tuple]):
    """Gather every cached contribution part below ``stop_layer`` for selected tokens.

    Args:
        cache (dict): cache providing ``'full_emb'`` and, per layer,
            ``'{layer}.d_attn'`` and ``'{layer}.d_mlp'``.
        stop_layer (int): number of layers to include; layers ``0 .. stop_layer - 1``
            are used (``stop_layer`` itself is excluded).
        batch_token_pos_idx (List[Tuple]): index pair ``(sample_idx, token_pos_idx)``
            applied to the first two axes of every cached tensor.

    Returns:
        Tensor: the embedding part followed, layer by layer, by the flattened
        attention parts and the MLP parts. Parts are concatenated along dim 1 and
        dims 0 and 1 are then swapped, so the parts axis comes first:
        ``(num_parts, num_selected, ...)``.
    """
    # The embedding is a single part: add a length-1 parts axis.
    collected = [cache['full_emb'][batch_token_pos_idx].unsqueeze(1)]

    for layer_id in range(stop_layer):
        attn_parts = cache[f'{layer_id}.d_attn'][batch_token_pos_idx]
        mlp_parts = cache[f"{layer_id}.d_mlp"][batch_token_pos_idx]
        # Collapse all attention axes between the selection axis and the last axis.
        collected.extend((attn_parts.flatten(1, -2), mlp_parts))

    return torch.cat(collected, dim=1).transpose(0, 1)


@torch.no_grad()
def get_target_vectors_and_postions(llm, cache_item, model_cache):
    """Derive the next trace targets from one item of the contribution cache.

    Every contribution module other than ``'emb'`` is described as
    ``"{layer}.{attn|mlp}.{token_pos}.{head_pos}"``. For each of them this builds
    the vector ``weight_row * norm_weight / norm_var`` and the factor
    ``contribution_value / cached_activation``.

    Args:
        llm: model exposing ``config`` and ``layers[i]`` with ``self_attn.v_proj``,
            ``input_layernorm``, ``mlp.up_proj`` and ``post_attention_layernorm``.
        cache_item (dict): needs ``'contribution_modules'``, ``'contribution_values'``,
            ``'target_sample_id'`` and ``'target_token_pos'``.
        model_cache: cache providing ``'{layer}.v_proj'``, ``'{layer}.attn_norm_var'``,
            ``'{layer}.up_proj'`` and ``'{layer}.mlp_norm_var'``.

    Returns:
        tuple: ``(target_vectors, token_positions, layers, modules, contribution_factors,
        valid_contribution_idx)``. ``target_vectors`` is a stacked CPU tensor; the other
        entries are lists aligned with it. ``valid_contribution_idx`` holds the positions
        of the used entries within ``cache_item['contribution_modules']``.

    Raises:
        ValueError: if a module description has a type other than ``attn`` or ``mlp``.
        RuntimeError: from ``torch.stack`` if every contribution is ``'emb'``.

    NOTE(review): the vector is divided by the cached *variance* (``torch.var`` of the
    norm input), not by ``sqrt(mean(x**2) + eps)`` as an RMSNorm would — confirm that
    this is the intended scaling.
    """
    contribution_modules = cache_item['contribution_modules']
    contribution_values = cache_item['contribution_values']
    sample_id = cache_item['target_sample_id']
    source_token_pos = cache_item['target_token_pos']

    num_attn_heads = llm.config.num_attention_heads
    num_k_v_heads = llm.config.num_key_value_heads
    num_head_groups = num_attn_heads // num_k_v_heads
    head_dim = llm.config.head_dim
    hidden_dim = llm.config.hidden_size

    new_target_vectors = []
    new_target_token_pos = []
    new_target_module = []
    new_target_layer = []
    new_target_contribution_factor = []
    valid_contribution_idx = []

    for i, (module_desc, cont_value) in enumerate(zip(contribution_modules, contribution_values)):
        if module_desc == 'emb':
            continue

        layer_id, module_type, target_token_pos, head_pos = module_desc.split('.')
        layer_id, head_pos, target_token_pos = int(layer_id), int(head_pos), int(target_token_pos)

        # A token position of -1 means "same token as the one being explained".
        if target_token_pos == -1:
            target_token_pos = source_token_pos

        layer = llm.layers[layer_id]
        if module_type == 'attn':
            # Map the flat neuron index onto (key/value head, offset within the head).
            kv_head = head_pos // head_dim // num_head_groups
            head_offset = head_pos % head_dim
            weight_row = layer.self_attn.v_proj.weight.data.view(num_k_v_heads, head_dim, hidden_dim)[kv_head, head_offset]
            norm = layer.input_layernorm.weight.data
            activation = model_cache[f"{layer_id}.v_proj"][sample_id, target_token_pos].view(num_k_v_heads, head_dim)[kv_head, head_offset]
            norm_var = model_cache[f"{layer_id}.attn_norm_var"][sample_id, source_token_pos]
        elif module_type == 'mlp':
            weight_row = layer.mlp.up_proj.weight.data[head_pos]
            norm = layer.post_attention_layernorm.weight.data
            activation = model_cache[f"{layer_id}.up_proj"][sample_id, target_token_pos][head_pos]
            norm_var = model_cache[f"{layer_id}.mlp_norm_var"][sample_id, source_token_pos]
        else:
            # Previously an unknown type silently reused the vector of the prior iteration.
            raise ValueError(f"Unknown module type '{module_type}' in '{module_desc}'")

        # Follow the device of the layer weights instead of assuming the default CUDA
        # device, so this also works on CPU and with layers placed on other devices.
        device = weight_row.device
        target_vector = weight_row * norm / norm_var.to(device)
        target_contribution_factor = cont_value / activation.to(device)

        new_target_token_pos.append(target_token_pos)
        new_target_module.append(f"{layer_id}.{module_type}")
        new_target_layer.append(layer_id)
        valid_contribution_idx.append(i)
        new_target_contribution_factor.append(target_contribution_factor.cpu())
        new_target_vectors.append(target_vector.cpu())

    new_target_vectors = torch.stack(new_target_vectors)

    return new_target_vectors, new_target_token_pos, new_target_layer, new_target_module, new_target_contribution_factor, valid_contribution_idx

def apply_target_mask(contributions, first_target_module_group_idx):
    """Zero, in place, every contribution at or beyond a per-row cutoff column.

    Args:
        contributions: 2-D tensor ``(batch, num_contributions)``; modified in place.
        first_target_module_group_idx: one cutoff column index per row.

    Returns:
        The same ``contributions`` tensor, where row ``r`` keeps only the columns
        with index ``< first_target_module_group_idx[r]``.
    """
    num_rows, num_cols = contributions.shape

    column_ids = torch.arange(0, num_cols).unsqueeze(0).expand(num_rows, -1)
    cutoffs = torch.tensor(first_target_module_group_idx).unsqueeze(1)

    # Columns at or past the cutoff are the ones to discard.
    discard = column_ids >= cutoffs
    contributions[discard] = 0

    return contributions

def get_module_list(llm, target_layer, num_tokens):
    """List the names of all contribution modules up to and including ``target_layer``.

    The list starts with ``'emb'``. For each layer it then holds one
    ``"{layer}.attn.{token_pos}.{head_pos}"`` entry per token position and attention
    neuron (``num_attention_heads * head_dim`` per token, token position varying
    slowest), followed by one ``"{layer}.mlp.-1.{head_pos}"`` entry per MLP neuron
    (``intermediate_size`` of them).

    Args:
        llm: object whose ``config`` provides ``num_attention_heads``, ``head_dim``
            and ``intermediate_size``.
        target_layer: index of the last layer to include.
        num_tokens: number of token positions per attention block.

    Returns:
        list[str]: the module names in the order described above.
    """
    config = llm.config
    attn_width = config.num_attention_heads * config.head_dim
    mlp_width = config.intermediate_size

    module_list = ['emb']
    for layer_id in range(target_layer + 1):
        for token_pos in range(num_tokens):
            for head_pos in range(attn_width):
                module_list.append(f"{layer_id}.attn.{token_pos}.{head_pos}")
        module_list += [f"{layer_id}.mlp.-1.{head_pos}" for head_pos in range(mlp_width)]

    return module_list

def get_first_module_group_id_lookup(module_list):
    """Map each ``"{layer}.{module_type}"`` group to the index where a run of it starts.

    ``'emb'`` entries are skipped. Every other entry must have the form
    ``"{layer}.{module_type}.{token_pos}.{head_pos}"``. Whenever the group differs
    from that of the previous non-``'emb'`` entry, the current index is recorded
    (so a group that re-appears later overwrites its earlier index).

    Args:
        module_list: list of module names.

    Returns:
        dict[str, int]: group name -> index into ``module_list``.
    """
    lookup = {}
    last_group = None

    for position, module in enumerate(module_list):
        if module == 'emb':
            continue

        layer, module_type, _, _ = module.split('.')
        group = f"{layer}.{module_type}"
        if group != last_group:
            lookup[group] = position
        last_group = group

    return lookup

def get_initial_tasks(concept_vectors, target_module, target_layer, num_examples):
    """Create the depth-0 tasks: one signed concept vector per example.

    The concept vectors for ``target_module`` are repeated 8 times along a new second
    axis, multiplied by ``+1`` for the first four copies and ``-1`` for the last four,
    and flattened over the first two axes.

    Args:
        concept_vectors: mapping ``model id -> {module name -> concept vectors}``.
        target_module: module name whose concept vectors are the targets.
        target_layer: layer value stored for every task.
        num_examples: number of tasks to create.

    Returns:
        DataDict with ``target_vectors``, ``target_sample_id``,
        ``target_contribution_factor``, ``target_token_pos``, ``target_layer``,
        ``target_module`` and ``depth``.

    Note:
        Reads the module-level names ``model_id`` and ``model_cache``; they must be set
        before this is called.
        NOTE(review): the constants 8 and 4 are hard-coded (presumably 2 groups of
        4 examples each) — confirm before using other dataset sizes.
    """
    initial_tasks = DataDict(length=num_examples)

    # +1 for the first four copies of each concept vector, -1 for the last four.
    sign_column = torch.tensor([[1]] * 4 + [[-1]] * 4)
    base_vectors = concept_vectors[model_id][target_module]
    signed_vectors = torch.flatten(base_vectors[:, None].repeat(1, 8, 1) * sign_column, 0, 1)

    fields = (
        ('target_vectors', signed_vectors),
        ('target_sample_id', range(num_examples)),
        ('target_contribution_factor', torch.tensor([1] * num_examples)),
        ('target_token_pos', model_cache['ent_pos_idx']),
        ('target_layer', [target_layer] * num_examples),
        ('target_module', [target_module] * num_examples),
        ('depth', [[0]] * num_examples),
    )
    for field_name, field_value in fields:
        initial_tasks.attach(field_name, field_value, force= True)

    return initial_tasks

def merge_tasks(tasks):
    """Merge tasks that share sample id, source index and target token position.

    ``tasks`` is sorted in place (by sample id, then token position, then source index
    descending) so that tasks of one group are adjacent — this relies on
    ``tasks.sort`` being stable. Each group is collapsed into its first task, with
    ``target_contribution_factor`` replaced by the sum over the group and ``depth``
    by the sorted concatenation of the group's depth lists.

    Fixes over the previous version:
      * the depth accumulator is restarted for every group; before, depths leaked from
        earlier groups and the depth of a group's first task was dropped;
      * the running sum starts from a copy of the first task's factor, so the in-place
        ``+=`` no longer modifies the tensor stored in ``tasks``.

    Args:
        tasks: non-empty DataDict of tasks; ``target_contribution_factor`` values
            must be tensors.

    Returns:
        DataDict holding one merged task per group.
    """
    tasks.sort(by='target_sample_id')
    tasks.sort(by='target_token_pos')
    tasks.sort(by='source_idx', descending=True)

    merged_tasks = DataDict()

    prev_task = tasks[0]
    contribution_factor_sum = torch.tensor(0, dtype=prev_task['target_contribution_factor'].dtype)
    acc_depth_list = []

    for task in tasks.to_list():
        if task['target_sample_id'] == prev_task['target_sample_id'] \
            and task['source_idx'] == prev_task['source_idx'] \
            and task['target_token_pos'] == prev_task['target_token_pos']:

            contribution_factor_sum += task['target_contribution_factor']
            acc_depth_list += task['depth']
            continue

        # `task` opens a new group: flush the finished one and restart both accumulators.
        merged_tasks.append(
            prev_task | {'target_contribution_factor': contribution_factor_sum, 'depth': sorted(acc_depth_list)}
        )
        prev_task = task
        contribution_factor_sum = task['target_contribution_factor'].clone()
        acc_depth_list = list(task['depth'])

    merged_tasks.append(
        prev_task | {'target_contribution_factor': contribution_factor_sum, 'depth': sorted(acc_depth_list)}
    )

    return merged_tasks

def get_next_curr_tasks(tasks):
    """Split off the leading run of tasks that share the first ``target_module``.

    ``tasks`` is sorted in place by ``target_module`` in descending order first.

    Args:
        tasks: non-empty DataDict of pending tasks.

    Returns:
        tuple: ``(curr_tasks, remaining_tasks)``. When every task has the same target
        module, all tasks are returned as ``curr_tasks`` together with an empty DataDict.

    NOTE(review): module names are strings such as ``"9.mlp"``; if the sort compares
    them as strings, layer 10 and above would order before layer 9 — verify against
    ``DataDict.sort``.
    """
    tasks.sort(by='target_module', descending=True)
    leading_module = tasks[0]['target_module']

    # Index of the first task that belongs to a different module, if there is one.
    split_at = None
    for position, module in enumerate(tasks['target_module']):
        if module != leading_module:
            split_at = position
            break

    if not split_at:
        return tasks, DataDict()

    return tasks[:split_at], tasks[split_at:]

def get_contribution_pars_cache(model_cache, stop_layer):
    """Build the mean-centred tensor of all contribution parts for every token.

    Concatenates, along dim 2, the embedding (as a single part), and for each layer the
    attention parts (flattened over dims 2..-2) followed by the MLP parts. The mean over
    the last axis is then subtracted from every part.

    Args:
        model_cache: cache providing ``'full_emb'`` and per-layer ``'{layer}.d_attn'``
            and ``'{layer}.d_mlp'`` entries; must support ``items()`` and key lookup.
        stop_layer: index of the last layer to include (inclusive). The argument used
            to be overwritten with a hard-coded 9; it is now honoured, but capped at the
            last layer actually present in ``model_cache`` so that passing a larger
            value yields all cached layers instead of a lookup error.

    Returns:
        Tensor: ``(samples, tokens, num_parts, ...)`` with the last axis mean-centred.
    """
    available_keys = {key for key, _ in model_cache.items()}

    parts = [model_cache['full_emb'][:, :, None]]
    for layer in range(stop_layer + 1):
        if f'{layer}.d_attn' not in available_keys:
            # No cached data from this layer on; stop at the last available layer.
            break
        parts.append(model_cache[f'{layer}.d_attn'].flatten(2, -2))
        parts.append(model_cache[f"{layer}.d_mlp"])
    parts = torch.concat(parts, dim=2)

    # `parts` is a fresh tensor from the concatenation, so centring in place is safe.
    parts -= parts.mean(dim=-1, keepdim=True)

    return parts

In [8]:
def collate_fn(batch):
    """Collate a list of task dicts into one dict of per-key batches.

    The first sample decides how each key is handled: if its value is a tensor, the
    values of all samples are combined with ``max_pad_sequence``; otherwise they are
    collected into a plain list.

    Args:
        batch: non-empty list of dicts that share the same keys (at least one key).

    Returns:
        dict: one entry per key plus ``'length'``, the length of the batch entry of the
        last key.
    """
    batch_dict = {}

    for key, example_value in batch[0].items():
        column = [sample[key] for sample in batch]
        if isinstance(example_value, torch.Tensor):
            column = max_pad_sequence(column)
        batch_dict[key] = column

    # `key` still refers to the last key seen in the loop above.
    return batch_dict | {'length': len(batch_dict[key])}


In [59]:
target_module = '9.post'
target_layer = 10
contribution_cache = []
batch_size = 16
max_depth = 20

for model_id in big_cache:
    # Bind `llm` before the try block so the cleanup in `finally` cannot raise a
    # NameError (or delete a stale model from an earlier cell) when loading fails.
    llm = None
    try:
        llm = AutoModel.from_pretrained(
            model_id,
            trust_remote_code=True,
            device_map='auto',
            dtype=torch.bfloat16,
            attn_implementation = 'eager'
        )
        # Layers from `target_layer` onwards are moved to the 'meta' device
        # (presumably to free memory; only earlier layers are read below).
        llm.layers[target_layer:].to('meta')
        model_cache = big_cache[model_id]
        num_examples = model_cache.length
        num_tokens = model_cache['0.d_attn'].size(1)
        
        module_list = get_module_list(llm, target_layer, num_tokens)
        first_module_group_id_lookup = get_first_module_group_id_lookup(module_list)
        tasks = DataDict()
        curr_tasks = get_initial_tasks(concept_vectors, target_module, target_layer, num_examples)
        
        parts_cache = get_contribution_pars_cache(model_cache, target_layer)
        
        for depth in range(max_depth):
            print(f"\nStep {depth}/{max_depth}:")
            contribution_cache.append(DataDict())
            
            dl = DataLoader(curr_tasks, batch_size=batch_size, collate_fn=collate_fn, shuffle=False)
            num_batches = len(dl)
            
            for batch_id, batch in enumerate(dl):
                batch_start_time = time.time()
                try:
                    target_vectors = batch['target_vectors']
                    sample_id = batch['target_sample_id']
                    token_pos = batch['target_token_pos']
                    # Kept under its own name: re-using `target_layer` here replaced the
                    # integer configured above with a list, which broke the next model
                    # iteration (`llm.layers[target_layer:]`, `get_module_list`).
                    batch_target_layers = batch['target_layer']
                    
                    max_target_layer = max(batch_target_layers)
                    
                    if depth==0:
                        parts = get_contribution_parts(model_cache, max_target_layer, batch_token_pos_idx=(sample_id, token_pos))
                    else: 
                        # NOTE(review): the ceiling is always the start of the layer's
                        # *attention* group, so for '{layer}.mlp' targets the attention
                        # parts of that same layer are excluded — confirm this is intended.
                        contribution_ceiling = first_module_group_id_lookup[f"{max_target_layer}.attn"]
                        parts = parts_cache[sample_id, token_pos, :contribution_ceiling].transpose(0, 1)
                        
                    
                    print(f"Parts time: ({(time.time()-batch_start_time):.2f}s)")
                    contributions = get_dot_prod_contribution(parts = parts, whole= target_vectors).transpose(0, 1)
                    # contributions = distributed(
                    #     mixin_measure = get_dot_prod_contribution,
                    #     parts = parts, whole= target_vectors, device='cuda', chunk_size=50000).transpose(0, 1)
                    
                    if depth > 0:
                        # Zero every contribution at or after the start of the task's own module group.
                        first_target_module_group_idx = [first_module_group_id_lookup[module] for module in batch['target_module']]
                        contributions = apply_target_mask(contributions, first_target_module_group_idx)
                    
                    scaled_contributions = contributions * batch['target_contribution_factor'][:, None]
                    
                    top_x_contributions = get_top_x_contribution_values(scaled_contributions, 0.9) # does this make sense anymore?

                    contribution_idx = [el.nonzero(as_tuple=True)[0] for el in top_x_contributions]
                    contribution_values = [top_x_contributions[i][idx] for i, idx in enumerate(contribution_idx)]
                    contribution_modules = [[module_list[i] for i in idx] for idx in contribution_idx]
                    
                    contribution_cache[-1].extend({
                        'contribution_idx': contribution_idx,
                        'contribution_values':contribution_values,
                        'contribution_modules':contribution_modules,
                        'target_sample_id': sample_id,
                        'target_token_pos': token_pos,
                        'target_module': batch['target_module'],
                        'depth': batch['depth'],
                        })
                finally:
                    batch_end_time = time.time()
                    print(f"\tBatch {batch_id + 1}/{num_batches}: {(batch_end_time - batch_start_time):.2f} seconds")
            
            # Turn every selected contribution of this step into a task for the next steps.
            for cache_item in contribution_cache[-1].to_list():
                # Index 0 is 'emb'; items without modules or with only 'emb' yield no new targets.
                if not cache_item['contribution_modules'] or torch.all(cache_item['contribution_idx'] == 0): continue
                new_target_vectors, new_target_token_pos, new_target_layer, new_target_module, new_target_contribution_factor, valid_contribution_idx = get_target_vectors_and_postions(llm, cache_item, model_cache)

                num_tasks = len(new_target_vectors)
                source_idx = [cache_item['contribution_idx'][i] for i in valid_contribution_idx]
                # Increase all depth ids of the previous contribution by one. A separate name
                # is used so the loop counter `depth` is not overwritten.
                next_depth = [depth_id+1 for depth_id in cache_item['depth']]
                
                tasks.extend({
                    'source_idx': source_idx,
                    'target_module': new_target_module,
                    'target_vectors':new_target_vectors,
                    'target_contribution_factor': new_target_contribution_factor,
                    'target_layer': new_target_layer,
                    'target_sample_id': [cache_item['target_sample_id']] * num_tasks,
                    'target_token_pos':new_target_token_pos,
                    'depth': [next_depth] * num_tasks
                })
            if tasks.length == 0: break
            tasks = merge_tasks(tasks) # create depth list with all depth of querying tokens
            curr_tasks, tasks = get_next_curr_tasks(tasks)
    finally:
        del llm
        gc.collect()
        torch.cuda.empty_cache()
        print(f"Memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
        print(f"Memory reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]


Step 0/20:
Parts time: (1.82s)
	Batch 1/1: 2.88 seconds

Step 1/20:
Parts time: (1.30s)
	Batch 1/7: 3.14 seconds
Parts time: (1.74s)
	Batch 2/7: 3.64 seconds
Parts time: (1.69s)
	Batch 3/7: 3.53 seconds
Parts time: (1.74s)
	Batch 4/7: 3.63 seconds
Parts time: (1.69s)
	Batch 5/7: 3.53 seconds
Parts time: (1.74s)
	Batch 6/7: 3.65 seconds
Parts time: (1.68s)
	Batch 7/7: 3.41 seconds

Step 2/20:
Parts time: (1.19s)
	Batch 1/1: 1.91 seconds

Step 3/20:
Parts time: (1.07s)
	Batch 1/9: 2.72 seconds
Parts time: (1.56s)
	Batch 2/9: 3.26 seconds
Parts time: (1.51s)
	Batch 3/9: 3.17 seconds
Parts time: (1.57s)
	Batch 4/9: 3.26 seconds
Parts time: (1.51s)
	Batch 5/9: 3.19 seconds
Parts time: (1.56s)
	Batch 6/9: 3.24 seconds
Parts time: (1.50s)
	Batch 7/9: 3.12 seconds
Parts time: (1.56s)
	Batch 8/9: 3.24 seconds
Parts time: (1.18s)
	Batch 9/9: 2.02 seconds

Step 4/20:
Parts time: (0.47s)
	Batch 1/1: 0.58 seconds

Step 5/20:
Parts time: (0.73s)
	Batch 1/13: 2.20 seconds
Parts time: (1.38s)
	Batch 

In [60]:
# Collect one frame per cached contribution item and concatenate once at the end;
# growing the DataFrame with pd.concat inside the loop copies all rows every time.
contribution_frames = []

for cache in contribution_cache:
    for contribution in cache.to_list():
        if not contribution['contribution_modules']: continue
        num_contributions = len(contribution['contribution_modules'])
        contribution_frames.append(pd.DataFrame(data={
            'contribution_idx': contribution['contribution_idx'].tolist(),
            'contribution_values':contribution['contribution_values'].tolist(),
            'contribution_modules':contribution['contribution_modules'],
            'target_sample_id':[contribution['target_sample_id']] * num_contributions,
            'target_token_pos':[contribution['target_token_pos']] * num_contributions,
            'depth':[contribution['depth']] * num_contributions,
            }))

# pd.concat raises on an empty list, so fall back to an empty frame as before.
df = pd.concat(contribution_frames, ignore_index=True) if contribution_frames else pd.DataFrame()
        
# Module names are 'emb' or "{layer}.{module_type}.{token_pos}.{head_pos}".
df['layer'] = df['contribution_modules'].apply(lambda x: int(x.split('.')[0]) if x != 'emb' else 0)
df['module_type'] = pd.Categorical(df['contribution_modules'].apply(lambda x: x.split('.')[1] if x != 'emb' else 'emb'), categories=['emb', 'attn', 'mlp'], ordered=True)
# A token position of -1 (and the 'emb' module) refers to the target token itself.
df['token_pos'] = df.apply(lambda x: int(x['contribution_modules'].split('.')[2]) if x['contribution_modules'] != 'emb' and int(x['contribution_modules'].split('.')[2]) != -1  else x['target_token_pos'], axis=1)
df['head_id'] = df['contribution_modules'].apply(lambda x: int(x.split('.')[3]) if x != 'emb' else 0)
df = df.drop('contribution_modules', axis=1)

# NOTE: this re-binds `data_path`, which the first cell used for the input dataset.
data_path = os.path.join(project_root, 'data/contribution_cache/cache_9_post_with_norm_hom_1.csv')
df.to_csv(data_path)
df.head()

Unnamed: 0,contribution_idx,contribution_values,target_sample_id,target_token_pos,depth,layer,module_type,token_pos,head_id
0,113378,1.359375,0,4,[0],3,mlp,4,2273
1,142508,0.261719,0,4,[0],4,mlp,4,1195
2,180371,3.125,0,4,[0],5,mlp,4,8850
3,203391,0.12207,0,4,[0],6,mlp,4,1662
4,207013,0.464844,0,4,[0],6,mlp,4,5284
