## Task List

* Create a graph-based version
    - Each module is a node
    - Each contribution is an edge
    - Update each node for every incoming (and outgoing) edge
* Change to bi-directional batching
    - Start with the last token -> load its parts -> calculate the contributions for each layer -> repeat for the previous token until no open nodes are left.
    - Less memory intensive (only the cache for one token is loaded at a time)

In [1]:
%load_ext autoreload
%autoreload 2
import os 
os.environ["CUDA_VISIBLE_DEVICES"]="2"
import json
import sys
import gc
from typing import Dict, List, Union, Tuple, Iterator, Generator
from contextlib import contextmanager
from typing import Generator, Tuple, Iterator, Union, List, Dict
from transformers import PreTrainedModel, PretrainedConfig
from torch.utils.data import DataLoader
import time

# Make the notebook's working directory importable so local packages
# (e.g. `data`, imported below) resolve.
project_root = os.path.abspath(os.getcwd())
if project_root not in sys.path:
    sys.path.append(project_root)

from transformers import AutoModel, AutoTokenizer, PretrainedConfig, AutoConfig
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from nnsight import LanguageModel
import pandas as pd
import networkx as nx

from data.utils import get_entity_idx, get_prompts
from mi_toolbox.utils.data_types import DataDict
from mi_toolbox.utils.collate import TensorCollator
from mi_toolbox.transformer_caching import caching_wrapper, decompose_attention_to_neuron, decompose_glu_to_neuron, TransformerCache
from mi_toolbox.causal_tracing import get_mass_mean_vectors
from mi_toolbox.contribution_tracing import get_dot_prod_contribution, get_top_x_contribution_values
from mi_toolbox.contribution_tracing.utils import mean_center_tensor, trace_through_layer_norm

# Root directory of the per-model activation caches (note the trailing slash).
DATA_DIR = "/raid/dacslab/CONCEPT_FORMATION/homograph_small/"
# Models to analyse. Other candidates: "Qwen/Qwen3-0.6B", "Qwen/Qwen3-1.7B",
# "Qwen/Qwen3-8B", "Qwen/Qwen3-14B".
model_ids = [
    "Qwen/Qwen3-4B",
]

### Create parts cache

In [2]:
# NOTE: `sample_id` is used here as the index of the per-homograph cache
# directory (the same path component the tracer class calls `homograph_id`).
sample_id = 0
num_layers = 10
# Derive the paths from DATA_DIR / model_ids instead of repeating the literal root.
_model_dir = os.path.join(DATA_DIR, model_ids[0].replace('/', '_'))
data_dir = os.path.join(_model_dir, f'{sample_id}', '')
out_dir = os.path.join(_model_dir, '.cache', f'{sample_id}', '')

In [3]:
%%script false --no-raise-error
def cache_direct_parts(in_path, out_path, num_layers, save=False, model_id=None):
    """Gather the direct-contribution parts at each sample's entity token.

    For every sample, the embedding (``full_emb``), the attention decomposition
    (``{layer}.d_attn``) and the MLP decomposition (``{layer}.d_mlp``) of layers
    ``0..num_layers-1`` are indexed at the sample's entity position
    (``ent_pos_idx``) and concatenated along dim 1 (the "parts" axis).

    Args:
        in_path: directory of the TransformerCache to read.
        out_path: root directory for the per-sample output files.
        num_layers: number of layers whose parts are included.
        save: when True, write ``{out_path}/{i}/direct_parts.safetensor`` for each
            sample ``i``. The file name is kept for the loader; note that
            ``torch.save`` writes a pickle, not a safetensors file.
        model_id: model id passed to ``TransformerCache.load``; defaults to
            ``model_ids[0]``.

    Returns:
        The concatenated parts tensor (first axis = sample).
    """
    if model_id is None:
        model_id = model_ids[0]

    data_keys = ['ent_pos_idx', 'full_emb'] + [
        f"{layer}.{module}" for layer in range(num_layers) for module in ['d_attn', 'd_mlp']
    ]
    model_cache = TransformerCache.load(in_path, data_keys, model_id=model_id)

    # Advanced index selecting (sample i, entity position of sample i) for all samples.
    target_token_pos = model_cache['ent_pos_idx']
    batch_ent_idx = (range(model_cache.length), target_token_pos)

    # The embedding is a single part: add a parts axis so it can be concatenated.
    direct_parts = [model_cache['full_emb'][batch_ent_idx][:, None]]
    # Drop each source tensor as soon as it has been sliced to keep peak memory low.
    del model_cache.data['full_emb']
    gc.collect()

    for layer in range(num_layers):
        # Collapse all middle axes into one parts axis
        # (presumably source token x head neuron — TODO confirm cache layout).
        direct_parts.append(model_cache[f'{layer}.d_attn'][batch_ent_idx].flatten(1, -2))
        del model_cache.data[f'{layer}.d_attn']
        gc.collect()

        direct_parts.append(model_cache[f"{layer}.d_mlp"][batch_ent_idx])
        del model_cache.data[f'{layer}.d_mlp']
        gc.collect()

    direct_parts = torch.concat(direct_parts, dim=1)
    print(direct_parts.shape)

    if save:
        for i, sample_parts in enumerate(direct_parts):
            sample_dir = os.path.join(out_path, f'{i}')
            os.makedirs(sample_dir, exist_ok=True)
            # clone(): torch.save of a view serialises the whole backing storage,
            # which would put every sample's parts into every file.
            torch.save(sample_parts.clone(), os.path.join(sample_dir, 'direct_parts.safetensor'))

    return direct_parts

cache_direct_parts(in_path=data_dir, out_path=out_dir, num_layers=num_layers)

In [4]:
%%script false --no-raise-error
def cache_indirect_parts(in_path, out_path, num_layers, model_id=None):
    """Write mean-centred parts to one file per (sample, token position).

    The cache is stacked with padding. For every sample, the mean-centred
    embedding, attention decomposition (middle axes flattened) and MLP
    decomposition of layers ``0..num_layers-1`` are concatenated along dim 1.
    Each token-position slice is saved to
    ``{out_path}/{sample_id}/indirect_parts__pos_{token_pos}.safetensor``.

    Args:
        in_path: directory of the TransformerCache to read.
        out_path: root directory for the per-sample output directories.
        num_layers: number of layers whose parts are included.
        model_id: model id passed to ``TransformerCache.load``; defaults to
            ``model_ids[0]``.
    """
    if model_id is None:
        model_id = model_ids[0]

    data_keys = ['full_emb'] + [
        f"{layer}.{module}" for layer in range(num_layers) for module in ['d_attn', 'd_mlp']
    ]
    model_cache = TransformerCache.load(in_path, data_keys, model_id=model_id)
    # Padding means padded positions are written too (see the TODO in the tracer).
    model_cache.stack_tensors(padding=True)

    num_samples = model_cache['full_emb'].size(0)

    # Loop-invariant: the embedding is centred over the full stacked tensor once,
    # instead of once per sample.
    centered_emb = mean_center_tensor(model_cache['full_emb'])

    for sample_id in range(num_samples):
        sample_dir = os.path.join(out_path, f'{sample_id}')
        os.makedirs(sample_dir, exist_ok=True)

        indirect_parts_cache = [centered_emb[sample_id, :, None]]
        for layer in range(num_layers):
            indirect_parts_cache.append(
                mean_center_tensor(model_cache[f'{layer}.d_attn'][sample_id]).flatten(1, -2)
            )
            indirect_parts_cache.append(mean_center_tensor(model_cache[f"{layer}.d_mlp"][sample_id]))
        indirect_parts_cache = torch.concat(indirect_parts_cache, dim=1)

        for token_pos, parts in enumerate(indirect_parts_cache):
            file_path = os.path.join(sample_dir, f'indirect_parts__pos_{token_pos}.safetensor')
            # clone(): torch.save of a view serialises the whole backing storage, which
            # would put every token's parts into every per-token file.
            torch.save(parts.clone(), file_path)

        del indirect_parts_cache
        gc.collect()

    del centered_emb
    gc.collect()

cache_indirect_parts(in_path=data_dir, out_path=out_dir, num_layers=num_layers)

In [5]:
def sanitize_model_id(model_id):
    """Return *model_id* with every ``/`` replaced by ``_`` (usable as a directory name).

    >>> sanitize_model_id("Qwen/Qwen3-4B")
    'Qwen_Qwen3-4B'
    """
    return '_'.join(model_id.split('/'))


@contextmanager
def load_language_model(
    model_id: Union[str, List[str]],
    **kwargs
) -> Generator[Tuple[PreTrainedModel, PretrainedConfig], None, None]:
    """Load a Hugging Face model for the duration of a ``with`` block.

    Yields ``(model, model.config)``. On exit, normal or via an exception, the
    model is moved to CPU and dropped. Garbage collection runs, the CUDA cache
    is emptied, and the currently allocated/reserved CUDA memory is printed.

    Args:
        model_id: first argument to ``AutoModel.from_pretrained``.
        **kwargs: forwarded unchanged to ``AutoModel.from_pretrained``.
    """
    model = AutoModel.from_pretrained(model_id, **kwargs)
    try:
        yield model, model.config
    finally:
        # Release the weights first so the printed numbers reflect the freed state.
        model.cpu()
        del model
        gc.collect()
        torch.cuda.empty_cache()
        print(f"Memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
        print(f"Memory reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
        
def batch_iterator(dl: DataLoader, silent=False) -> Iterator[Dict]:
    """Yield the batches of *dl* and optionally report how long each one took.

    The timer starts just before a batch is handed to the caller. It stops when
    control returns to the generator: the next batch is requested, the loop
    ends, or the generator is closed. The reported time therefore covers the
    caller's work on that batch, not the loading time. ``len(dl)`` is only
    queried when reporting is enabled.
    """
    for batch_number, batch in enumerate(dl, start=1):
        started = time.perf_counter()
        try:
            yield batch
        finally:
            elapsed = time.perf_counter() - started
            if not silent:
                print(f"Batch {batch_number}/{len(dl)}: {elapsed:.2f} seconds")

### Trace direct and indirect contributions

In [6]:


class BiDirectionalContributionTracer:
    """Build a contribution graph for one sample by tracing back from a target direction.

    Node names are ``"{layer}.{module}.{token_pos}.{head}"`` with ``module`` one of
    ``emb``, ``attn`` or ``mlp``, plus the special root node ``'target'``. An edge
    ``source -> node`` stores in ``weight`` the scaled dot-product contribution of
    ``node``'s part to ``source``.

    Unvisited nodes are processed one module at a time: highest token position
    first and, within a token, highest layer first (MLP before attention). As a
    result, only one token position's parts cache has to be held at once.

    Args:
        homograph_id: index of the per-homograph cache directory under ``DATA_DIR``.
        sample_id: index of the sample inside that cache.
        model_id: Hugging Face model id (e.g. ``"Qwen/Qwen3-4B"``).
        target_pos: cache key ``"{layer}.{module}"`` of the target activations;
            its layer decides how many layers are traced (``layer + 1``).
        batch_size: number of nodes whose contributions are computed per batch.
        device: torch device for the model and the contribution math.
    """

    def __init__(self, homograph_id: int, sample_id: int, model_id: str, target_pos: str,
                 batch_size = 16, device='cuda'):
        self.homograph_id = homograph_id
        self.sample_id = sample_id
        self.model_id = model_id
        self.target_pos = target_pos
        self.batch_size = batch_size
        self.device = device

        # Activation cache of this homograph and the derived per-sample parts cache.
        self.cache_path = f"{DATA_DIR}/{sanitize_model_id(model_id)}/{homograph_id}"
        self.parts_cache_path = f"{DATA_DIR}/{sanitize_model_id(model_id)}/.cache/{homograph_id}/{sample_id}"

        # Layers 0..target_layer are traced.
        self.num_layers = int(target_pos.split('.')[0]) + 1
        self.num_tokens = self.load_num_token_pos()

        self.model_config = AutoConfig.from_pretrained(model_id)

        self.ent_token_pos = self.load_ent_token_pos()
        # Token position of the currently loaded *indirect* parts; None means nothing
        # reusable is loaded (initially, and right after the direct parts were used).
        self.parts_token_pos = None

        # Module currently being processed; tracing starts at the target.
        self.curr_token_pos = self.ent_token_pos
        self.curr_layer = self.num_layers - 1
        self.curr_module = None

        # Maps a flat part index to its (out-of-context) node name.
        self.module_lookup = self.get_module_lookup()

    def __enter__(self):
        # Placeholder for caching logic; return self so `with tracer as t:` works.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Placeholder for caching cleanup; returning None never suppresses exceptions.
        return None

    def load_ent_token_pos(self) -> int:
        """Return the entity token position of this sample from ``ent_pos_idx``."""
        model_cache = DataDict.load(self.cache_path, ['ent_pos_idx'])
        # int() so node names and f-strings never embed a tensor/numpy repr.
        return int(model_cache['ent_pos_idx'][self.sample_id])

    def load_num_token_pos(self) -> int:
        """Return the (padded) number of token positions of this sample."""
        model_cache = TransformerCache.load(self.cache_path, ['0.d_attn'])
        model_cache.stack_tensors(padding=True) # TODO due to the padding we store bigger tensors.
        return model_cache['0.d_attn'][self.sample_id].size(0)

    def get_module_lookup(self) -> List[str]:
        """List the node name of every flat part index.

        The order is: the embedding, then per layer every attention neuron for every
        token position (token-major), then every MLP neuron. This must match the
        concatenation order used when the parts caches were written.
        """
        num_attn_heads = self.model_config.num_attention_heads
        attn_head_dim = self.model_config.head_dim
        num_mlp_heads = self.model_config.intermediate_size
        module_lookup = ['0.emb.0.0']

        for layer_id in range(self.num_layers):
            module_lookup.extend([f"{layer_id}.attn.{token_pos}.{head_pos}" for token_pos in range(self.num_tokens) for head_pos in range(num_attn_heads * attn_head_dim)])
            # MLP/emb entries carry token 0 here; they are re-contextualised to the
            # current token in `contextualize_contribution_node`.
            module_lookup.extend([f"{layer_id}.mlp.0.{head_pos}" for head_pos in range(num_mlp_heads)])

        return module_lookup

    def trace(self) -> nx.DiGraph:
        """Run the full backward trace and return the contribution graph."""
        G, unvisited_nodes = self.init_contribution_graph()

        with load_language_model(
            model_id=self.model_id,
            trust_remote_code=True,
            device_map=self.device,
            dtype=torch.bfloat16,
            attn_implementation = 'eager'
        ) as (llm, config):

            projection_cache = self.load_projection_cache()
            parts_cache = None

            while unvisited_nodes:
                next_module_nodes, unvisited_nodes = self.select_next_module_nodes(unvisited_nodes)
                # Also moves curr_token_pos / curr_layer / curr_module to this module.
                next_module_tasks = self.get_next_module_tasks(G, llm, projection_cache, next_module_nodes)

                if self.parts_token_pos != self.curr_token_pos:
                    print(f'Loading token {self.curr_token_pos} cache.')
                    # Drop the old parts before loading new ones to lower peak device memory.
                    parts_cache = None
                    parts_cache = self.load_token_pos_parts(next_module_nodes)

                contribution_end_idx = self.get_contribution_end_idx()

                dl = DataLoader(next_module_tasks, batch_size=self.batch_size, collate_fn=TensorCollator(), shuffle=False)

                for batch in batch_iterator(dl, silent=True):
                    contributions = self.get_contributions(batch, parts_cache, contribution_end_idx)
                    unvisited_nodes = self.update_contribution_graph(G, batch, contributions, unvisited_nodes)

        return G

    def init_contribution_graph(self) -> Tuple[nx.DiGraph, List[str]]:
        """Create a graph holding only the ``'target'`` node; it is the first node to visit."""
        layer_id, module_type = self.target_pos.split('.')

        G = nx.DiGraph()
        G.add_node('target', token_pos=self.ent_token_pos, layer_id=int(layer_id), module_type=module_type)

        return G, list(G.nodes())

    def load_projection_cache(self):
        """Load the v_proj / up_proj activations and norm variances of the traced layers."""
        data_keys = [f"{layer}.{module}" for layer in range(self.num_layers) for module in ['v_proj', 'attn_norm_var', 'up_proj', 'mlp_norm_var']]

        projection_cache = TransformerCache.load(self.cache_path, data_keys, model_id=self.model_id)
        projection_cache.stack_tensors(padding=True)

        return projection_cache

    def select_next_module_nodes(self, unvisited_nodes) -> Tuple[List[str], List[str]]:
        """Split off all unvisited nodes of the next module to process.

        Returns ``(nodes of the highest-ordered "{layer}.{module}.{token_pos}" group,
        remaining unvisited nodes)``.
        """
        if unvisited_nodes == ['target']:
            return unvisited_nodes, []

        sorted_unvisited_nodes = sorted(unvisited_nodes, key=self.unvisited_nodes_sorter) # TODO skipp 'emb' modules
        next_module_type = sorted_unvisited_nodes[-1].rsplit('.', 1)[0]

        # Compare the exact "{layer}.{module}.{token_pos}" key rather than a string
        # prefix, so e.g. token "3" never matches token "31".
        next_module_nodes = []
        while sorted_unvisited_nodes and sorted_unvisited_nodes[-1].rsplit('.', 1)[0] == next_module_type:
            next_module_nodes.append(sorted_unvisited_nodes.pop())

        return next_module_nodes, sorted_unvisited_nodes

    @staticmethod
    def unvisited_nodes_sorter(node: str) -> int:
        """Sort key: token position, then layer, then MLP after attention (ascending).

        Assumes fewer than 100 layers, otherwise keys of different tokens overlap.
        """
        layer_id, module, token_pos, _ = node.split('.')
        return int(token_pos) * 10000 + int(layer_id) * 100 + int(module == 'mlp')

    def get_next_module_tasks(self, G: nx.DiGraph, llm: PreTrainedModel, projection_cache: TransformerCache, next_module_nodes: List[str]) -> DataDict:
        """Build the per-node projection vectors and contribution factors for one module."""
        if next_module_nodes == ['target']:
            target_vector = self.get_target_vector()

            return DataDict.from_dict({
                "nodes":['target'],
                "target_vectors": [target_vector],
                "contribution_factors": torch.tensor([1])
                })

        self.update_current_module(G, next_module_nodes[0])

        projection_vectors = self.get_projection_vectors(G, llm, projection_cache, next_module_nodes)
        contribution_factors = self.get_contribution_factors(G, projection_cache, next_module_nodes)

        return DataDict.from_dict({
            "nodes": next_module_nodes,
            "target_vectors": projection_vectors,
            "contribution_factors": contribution_factors
        })

    def get_target_vector(self) -> torch.Tensor:
        """Return the mass-mean direction of ``target_pos``, oriented for this sample.

        The samples are split into two equal halves (assumes an even sample count).
        For samples in the second half the direction is negated.
        """
        model_cache = TransformerCache.load(self.cache_path, [self.target_pos], model_id=self.model_id)
        model_cache.stack_tensors()

        num_samples, model_dim = model_cache[self.target_pos].shape
        target_residual_states = model_cache[self.target_pos].view(2, num_samples // 2, model_dim).unsqueeze(0)
        target_vector = get_mass_mean_vectors(target_residual_states)[0]

        if self.sample_id >= num_samples / 2:
            target_vector = -target_vector

        return target_vector

    def update_current_module(self, G: nx.DiGraph, example_node: str) -> None:
        """Set curr_token_pos / curr_layer / curr_module from a node's attributes."""
        self.curr_token_pos = G.nodes[example_node]['token_pos']
        self.curr_layer = G.nodes[example_node]['layer_id']
        self.curr_module = G.nodes[example_node]['module_type']

    def get_projection_vectors(self, G: nx.DiGraph, llm: PreTrainedModel, projection_cache: TransformerCache, next_module_nodes: List[str]) -> torch.Tensor:
        """Dispatch to the attention or MLP projection-vector computation."""
        projection_heads = [G.nodes[node]['head_id'] for node in next_module_nodes]

        if self.curr_module == 'attn':
            return self.get_attn_projection_vectors(llm, projection_cache, projection_heads)
        if self.curr_module == 'mlp':
            return self.get_mlp_projection_vectors(llm, projection_cache, projection_heads)
        raise ValueError(f"Unexpected module type: {self.curr_module}")

    def get_attn_projection_vectors(self, llm: PreTrainedModel, projection_cache: TransformerCache, projection_heads: List[int]) -> torch.Tensor:
        """Return the v_proj rows of the given attention neurons, traced through the input norm.

        ``projection_heads`` are flat indices over ``num_attention_heads * head_dim``.
        Each query head maps to key/value head ``q_head // num_head_groups`` (GQA).
        """
        num_attn_heads = self.model_config.num_attention_heads
        num_k_v_heads = self.model_config.num_key_value_heads
        num_head_groups = num_attn_heads // num_k_v_heads
        head_dim = self.model_config.head_dim
        hidden_dim = self.model_config.hidden_size

        k_v_head_pos = torch.tensor(projection_heads) // head_dim // num_head_groups
        neuron_pos = torch.tensor(projection_heads) % head_dim

        v_proj_W = llm.layers[self.curr_layer].self_attn.v_proj.weight.data.view(num_k_v_heads, head_dim, hidden_dim)[(k_v_head_pos, neuron_pos)]
        norm_W = llm.layers[self.curr_layer].input_layernorm.weight.data
        norm_var = projection_cache[f"{self.curr_layer}.attn_norm_var"][(self.sample_id, self.curr_token_pos)].unsqueeze(0)

        projection_vectors = trace_through_layer_norm(v_proj_W, norm_W, norm_var, device=self.device)

        return projection_vectors.cpu()

    def get_mlp_projection_vectors(self, llm: PreTrainedModel, projection_cache: TransformerCache, projection_heads: List[int]) -> torch.Tensor:
        """Return the up_proj rows of the given MLP neurons, traced through the post-attention norm."""
        up_proj_W = llm.layers[self.curr_layer].mlp.up_proj.weight.data[projection_heads]
        norm_W = llm.layers[self.curr_layer].post_attention_layernorm.weight.data
        norm_var = projection_cache[f"{self.curr_layer}.mlp_norm_var"][(self.sample_id, self.curr_token_pos)].unsqueeze(0)

        projection_vectors = trace_through_layer_norm(up_proj_W, norm_W, norm_var, device=self.device)

        return projection_vectors.cpu()

    def get_contribution_factors(self, G: nx.DiGraph, projection_cache: TransformerCache, next_module_nodes: List[str]) -> torch.Tensor:
        """Compute, per node, its total incoming contribution divided by its projection activation."""
        projection_heads = [G.nodes[node]['head_id'] for node in next_module_nodes]
        contribution_values = self.get_node_contribution_values(G, next_module_nodes)

        if self.curr_module == 'attn':
            return self.get_attn_contribution_factors(G, projection_cache, contribution_values, projection_heads)
        if self.curr_module == 'mlp':
            return self.get_mlp_contribution_factors(G, projection_cache, contribution_values, projection_heads)
        raise ValueError(f"Unexpected module type: {self.curr_module}")

    def get_node_contribution_values(self, G: nx.DiGraph, next_module_nodes: List[str]) -> torch.Tensor:
        """Sum each node's incoming edge weights and store the sum as ``contribution_value``.

        All parents precede the node in processing order, so the sum is complete here.
        """
        contribution_values = []
        for projection_node in next_module_nodes:
            contribution_value = sum(data['weight'] for _, _, data in G.in_edges(projection_node, data=True))

            G.nodes[projection_node]['contribution_value'] = contribution_value
            contribution_values.append(contribution_value)

        return torch.tensor(contribution_values)

    def get_attn_contribution_factors(self, G: nx.DiGraph, projection_cache: Dict, contribution_values: torch.Tensor, projection_heads: List[int]) -> torch.Tensor:
        """Divide contribution values by the cached v_proj activations of the given neurons.

        NOTE(review): a zero activation yields inf/nan factors.
        """
        num_attn_heads = self.model_config.num_attention_heads
        num_k_v_heads = self.model_config.num_key_value_heads
        num_head_groups = num_attn_heads // num_k_v_heads
        head_dim = self.model_config.head_dim

        k_v_head_pos = torch.tensor(projection_heads) // head_dim // num_head_groups
        neuron_pos = torch.tensor(projection_heads) % head_dim

        v_proj = projection_cache[f"{self.curr_layer}.v_proj"][self.sample_id, self.curr_token_pos].view(num_k_v_heads, head_dim)[k_v_head_pos, neuron_pos]
        return contribution_values / v_proj

    def get_mlp_contribution_factors(self, G: nx.DiGraph, projection_cache: Dict, contribution_values: torch.Tensor, projection_heads: List[int]) -> torch.Tensor:
        """Divide contribution values by the cached up_proj activations of the given neurons.

        NOTE(review): a zero activation yields inf/nan factors.
        """
        up_proj = projection_cache[f"{self.curr_layer}.up_proj"][self.sample_id, self.curr_token_pos, projection_heads]
        return contribution_values / up_proj

    def load_token_pos_parts(self, next_module_nodes: List[str]) -> torch.Tensor:
        """Load the parts for the current token: direct parts for the target, indirect otherwise."""
        if next_module_nodes == ['target']:
            parts_path = f"{self.parts_cache_path}/direct_parts.safetensor"
        else:
            parts_path = f"{self.parts_cache_path}/indirect_parts__pos_{self.curr_token_pos}.safetensor"

        parts = torch.load(parts_path).unsqueeze(1)
        # Direct parts are only valid for the target: leave parts_token_pos unset so
        # the next module reloads the indirect parts, even at the same position.
        self.parts_token_pos = None if next_module_nodes == ['target'] else self.curr_token_pos
        return parts.to(self.device)

    def get_contribution_end_idx(self):
        """Return the first part index *not* upstream of the current module (None = all)."""
        if self.curr_module is None:
            return None
        first_contribution_of_type = f"{self.curr_layer}.{self.curr_module}.0.0"
        return self.module_lookup.index(first_contribution_of_type)

    def get_contributions(self, batch: Dict, batch_parts: torch.Tensor, contribution_end_idx: int) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Score the upstream parts against every node in the batch and keep the top contributors.

        Returns per-node lists of the kept part indices and their scaled values.
        """
        target_vectors = batch['target_vectors'].to(self.device)
        contribution_factors = batch['contribution_factors'].to(self.device)

        contributions = get_dot_prod_contribution(parts=batch_parts[:contribution_end_idx], whole=target_vectors).transpose(0, 1)

        scaled_contributions = contributions * contribution_factors[:, None]

        top_x_contributions = get_top_x_contribution_values(scaled_contributions, 0.9)

        contribution_idx = [el.nonzero(as_tuple=True)[0] for el in top_x_contributions]
        contribution_values = [top_x_contributions[i][idx] for i, idx in enumerate(contribution_idx)]

        return contribution_idx, contribution_values

    def update_contribution_graph(self, G:nx.DiGraph, batch: Dict, contributions: Tuple[torch.Tensor, torch.Tensor], unvisited_nodes: List[str]) -> List[str]:
        """Add edges for the batch's contributions and queue newly discovered nodes."""
        source_nodes = batch['nodes']
        unvisited_nodes = set(unvisited_nodes)

        for source_node, idx, value in self.flatten_contributions(source_nodes, contributions):
            contribution_node = self.contextualize_contribution_node(self.module_lookup[idx])

            # Check the graph, not just the queue: a node that is already known
            # (queued, visited, or emb) must never be re-queued.
            if contribution_node not in G:
                unvisited_nodes = self.add_new_contribution_node(G, unvisited_nodes, contribution_node)

            G.add_edge(source_node, contribution_node, weight=value.item())

        return list(unvisited_nodes)

    def flatten_contributions(self, source_nodes: List[str], contributions: Tuple[torch.Tensor, torch.Tensor]) -> List[Tuple[str, torch.Tensor, torch.Tensor]]:
        """Turn per-node (indices, values) pairs into flat ``(source, idx, value)`` triples."""
        return [
            (source_node, idx, value)
            for source_node, contribution_idx, contribution_values in zip(source_nodes, *contributions)
            for idx, value in zip(contribution_idx, contribution_values)
        ]

    def contextualize_contribution_node(self, contribution_node: str) -> str:
        """Replace the placeholder token position of MLP/emb nodes with the current token."""
        contribution_layer, contribution_module, contribution_pos, contribution_head = contribution_node.split('.')

        if contribution_module in ('mlp', 'emb'):
            contribution_pos = self.curr_token_pos

        return f"{contribution_layer}.{contribution_module}.{contribution_pos}.{contribution_head}"

    def add_new_contribution_node(self, G: nx.DiGraph, unvisited_nodes: List[str], contribution_node: str) -> List[str]:
        """Add a node to G; queue it unless it is an embedding leaf. Returns the (mutated) queue."""
        contribution_layer, contribution_module, contribution_pos, contribution_head = contribution_node.split('.')

        G.add_node(contribution_node, token_pos=int(contribution_pos), layer_id=int(contribution_layer), module_type=contribution_module, head_id=int(contribution_head))

        if contribution_module != 'emb':
            unvisited_nodes.add(contribution_node)

        return unvisited_nodes

In [7]:
tracer = BiDirectionalContributionTracer(
    homograph_id=0,
    sample_id=0,
    model_id=model_ids[0],
    target_pos='9.post',
)
G = tracer.trace()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading token 3 cache.
Loading token 2 cache.
Loading token 1 cache.
Loading token 0 cache.
Memory allocated: 1.76 GB
Memory reserved: 1.76 GB


In [10]:
# One row per edge: source node, target node and the `weight` edge attribute.
df = nx.to_pandas_edgelist(G, source='source', target='target')

'/raid/dacslab/CONCEPT_FORMATION/homograph_small/Qwen_Qwen3-4B/.cache/0/'

In [12]:
# Derive the target tag from the tracer instead of hard-coding '9_post', and make
# sure the output directory exists before writing.
target_tag = tracer.target_pos.replace('.', '_')
data_path = os.path.join(project_root, f'data/contribution_cache/cache_{target_tag}_bidir_{sample_id}.parquet')
os.makedirs(os.path.dirname(data_path), exist_ok=True)
df.to_parquet(data_path, index=False)