In [None]:
%load_ext autoreload
%autoreload 2
import os 
os.environ["CUDA_VISIBLE_DEVICES"]="2"
import json
import sys
import gc
from typing import Dict, List

project_root = os.path.abspath(os.path.join(os.getcwd()))
if project_root not in sys.path:
    sys.path.append(project_root)

from transformers import AutoModel, AutoTokenizer, PretrainedConfig
import torch
from torch.utils.data import DataLoader
from nnsight import LanguageModel

from data.utils import get_entity_idx, get_prompts
from mi_toolbox.utils.data_types import DataDict
from mi_toolbox.utils.collate import TokenizeCollator
from mi_toolbox.transformer_caching import caching_wrapper, decompose_attention_to_neuron, decompose_glu_to_neuron


model_ids = ["Qwen/Qwen3-4B"]#, "Qwen/Qwen3-8B", "Qwen/Qwen3-14B"]  # "Qwen/Qwen3-0.6B","Qwen/Qwen3-1.7B"
tokenizer = AutoTokenizer.from_pretrained(model_ids[0])

data_path = os.path.join(project_root, 'data/homograph_data/homograph_small.json')
with open(data_path) as f:
    data = json.load(f) 

In [None]:
# Per-sample extraction: for each of the first NUM_SAMPLES entries of `data`,
# run every model in `model_ids` over that sample's prompts, cache per-layer
# activations/decompositions for `layer_target`, and write them to disk under
# <output_root>/<model id with '/' -> '_'>/<sample_id>/.
NUM_SAMPLES = 5
batch_size = 10
layer_target = slice(0, 10)
output_root = '/raid/dacslab/CONCEPT_FORMATION/homograph_small'

# Collate prompts and turn the per-prompt entity positions into an
# advanced-indexing pair (row indices, positions) so that `x[batch_ent_pos_idx]`
# picks the entity token of each batch row (presumably one position per prompt).
extract_collate_fn = TokenizeCollator(tokenizer, collate_fn={
    'ent_idx': lambda key, value: {'batch_ent_pos_idx': (list(range(len(value))), value)}
})


def caching_function(llm: LanguageModel, config: PretrainedConfig, batch: Dict[str, List]) -> Dict:
    """Trace one batch through `llm` and collect activations for `layer_target`.

    Saved values (all moved to CPU):
      - 'attention_mask', 'ent_pos_idx': copied from the batch.
      - 'emb' (entity positions only) and 'full_emb' (all positions).
      - per selected layer i: 'd_attn' (decompose_attention_to_neuron),
        'v_proj', 'attn_norm_var', 'mid' (residual entering
        post_attention_layernorm, entity positions), 'd_mlp'
        (decompose_glu_to_neuron), 'up_proj', 'mlp_norm_var', 'post' (layer
        output, entity positions).

    Note: `i` is the enumerate index within `layer_target`, i.e. relative to
    `layer_target.start`, not the absolute layer number.

    Args:
        llm: nnsight-wrapped model.
        config: model config; `num_attention_heads`, `num_key_value_heads`
            and `head_dim` are read.
        batch: collated batch; must contain 'attention_mask' and
            'batch_ent_pos_idx' (a (row indices, positions) pair).

    Returns:
        Dict mapping cache keys ('emb', 'full_emb', f'{i}.<name>', ...) to the
        saved values.
    """
    # Bind before `try` so the `finally` below cannot raise UnboundLocalError
    # (masking the real error) when `llm.trace` fails before `tracer` exists.
    tracer = None
    try:
        batch_cache = {}
        batch_ent_pos_idx = batch['batch_ent_pos_idx']

        batch_cache['attention_mask'] = batch['attention_mask']
        batch_cache['ent_pos_idx'] = batch_ent_pos_idx[1]

        # Keep module accesses inside the trace in forward-execution order.
        with llm.trace(batch) as tracer:
            emb = llm.model.embed_tokens.output
            batch_cache['emb'] = emb[batch_ent_pos_idx].cpu().save()
            batch_cache['full_emb'] = emb.cpu().save()

            for i, layer in enumerate(llm.model.layers[layer_target]):
                # NOTE(review): torch.var subtracts the mean and applies Bessel's
                # correction. If these layers use RMSNorm, the statistic the norm
                # actually divides by is mean(x**2) — confirm which is intended.
                attn_norm_var = torch.var(layer.input, dim=-1)

                # decompose attention out
                v_proj = layer.self_attn.v_proj.output
                # Attention weights are presumably only returned when the model runs
                # with eager attention / output_attentions (set up by caching_wrapper?) — confirm.
                _, attn_weight = layer.self_attn.output
                o_proj_WT = layer.self_attn.o_proj.weight.T
                d_attn = decompose_attention_to_neuron(
                    attn_weight,
                    v_proj,
                    o_proj_WT,
                    config.num_attention_heads,
                    config.num_key_value_heads,
                    config.head_dim
                )

                # extract mid residual state
                mid = layer.post_attention_layernorm.input[batch_ent_pos_idx]
                mlp_norm_var = torch.var(layer.post_attention_layernorm.input, dim=-1)

                # decomposed mlp out
                up_proj = layer.mlp.up_proj.output
                act_prod = layer.mlp.down_proj.input
                down_proj_WT = layer.mlp.down_proj.weight.T
                d_mlp = decompose_glu_to_neuron(act_prod=act_prod, down_proj_WT=down_proj_WT)

                # extract post residual state
                post = layer.output[batch_ent_pos_idx]

                # save cache
                batch_cache[f'{i}.d_attn'] = d_attn.cpu().save()
                batch_cache[f'{i}.v_proj'] = v_proj.cpu().save()
                batch_cache[f'{i}.attn_norm_var'] = attn_norm_var.cpu().save()
                batch_cache[f'{i}.mid'] = mid.cpu().save()
                batch_cache[f'{i}.d_mlp'] = d_mlp.cpu().save()
                batch_cache[f'{i}.up_proj'] = up_proj.cpu().save()
                batch_cache[f'{i}.mlp_norm_var'] = mlp_norm_var.cpu().save()
                batch_cache[f'{i}.post'] = post.cpu().save()
    finally:
        # Explicitly drop the tracer reference, also on the error path.
        del tracer
    return batch_cache


# Guard against datasets with fewer than NUM_SAMPLES entries (avoids IndexError).
for sample_id in range(min(NUM_SAMPLES, len(data))):
    sample_data = [data[sample_id]]

    prompts = get_prompts(sample_data, context_type='minimal_context')
    ent_idx = get_entity_idx(tokenizer, prompts)

    extract_dd = DataDict.from_dict({'prompts': prompts, "ent_idx": ent_idx})
    extract_dl = DataLoader(extract_dd, batch_size=batch_size, shuffle=False, collate_fn=extract_collate_fn)

    big_cache = caching_wrapper(model_ids, extract_dl, caching_function)

    # TODO: add saving mechanism to transformer cache object
    # Assumes caching_wrapper keys its result by model id (as the original
    # big_cache['Qwen/Qwen3-4B'] lookup did).
    for model_id in model_ids:
        model_cache = big_cache[model_id]
        # e.g. 'Qwen/Qwen3-4B' -> '<output_root>/Qwen_Qwen3-4B/<sample_id>'
        output_dir = os.path.join(output_root, model_id.replace('/', '_'), str(sample_id))
        os.makedirs(output_dir, exist_ok=True)

        for key, default_value in model_cache.default_entry.items():
            if isinstance(default_value, torch.Tensor):
                # NOTE(review): torch.save writes a pickle-based file, NOT the
                # safetensors format, despite the extension — load these with
                # torch.load. The extension is kept so existing readers still
                # find the files.
                output_path = os.path.join(output_dir, f"{key}.safetensors")
                torch.save(model_cache[key], output_path)
            else:
                output_path = os.path.join(output_dir, f"{key}.json")
                with open(output_path, 'w', encoding='utf-8') as fh:
                    json.dump(model_cache[key], fh)

        # Drop this reference too; otherwise `del big_cache` below would leave
        # the cached tensors alive until the next iteration rebinds the name.
        del model_cache

    del big_cache
    gc.collect()


Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/3.99G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/99.6M [00:00<?, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/3.96G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

Batch 1/3: 10.51 seconds
Batch 2/3: 13.54 seconds
Batch 3/3: 3.81 seconds
Memory allocated: 0.04 GB
Memory reserved: 0.05 GB


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Batch 1/3: 22.27 seconds
Batch 2/3: 10.00 seconds
Batch 3/3: 2.36 seconds
Memory allocated: 0.08 GB
Memory reserved: 8.11 GB


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Batch 1/3: 6.88 seconds
Batch 2/3: 10.08 seconds
Batch 3/3: 6.83 seconds
Memory allocated: 0.08 GB
Memory reserved: 8.12 GB


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Batch 1/3: 10.26 seconds
Batch 2/3: 7.08 seconds
Batch 3/3: 5.06 seconds
Memory allocated: 0.06 GB
Memory reserved: 0.08 GB


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Batch 1/3: 7.04 seconds
Batch 2/3: 10.00 seconds
Batch 3/3: 2.27 seconds
Memory allocated: 0.06 GB
Memory reserved: 0.08 GB
