In [1]:
%load_ext autoreload
%autoreload 2
import os 
os.environ["CUDA_VISIBLE_DEVICES"]="2"
import json
import sys
import gc
from typing import Dict, List, Union

project_root = os.path.abspath(os.path.join(os.getcwd()))
if project_root not in sys.path:
    sys.path.append(project_root)

from transformers import AutoModel, AutoTokenizer, PretrainedConfig, AutoConfig
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from nnsight import LanguageModel

from data.utils import get_entity_idx, get_prompts
from mi_toolbox.utils.data_types import DataDict
from mi_toolbox.utils.collate import TensorCollator
from mi_toolbox.transformer_caching import caching_wrapper, decompose_attention_to_neuron, decompose_glu_to_neuron, TransformerCache
from mi_toolbox.causal_tracing import get_mass_mean_vectors
from mi_toolbox.contribution_tracing import get_dot_prod_contribution, get_top_x_contribution_values


model_ids = ["Qwen/Qwen3-4B"]#, "Qwen/Qwen3-8B", "Qwen/Qwen3-14B"]  # "Qwen/Qwen3-0.6B","Qwen/Qwen3-1.7B"
tokenizer = AutoTokenizer.from_pretrained(model_ids[0])

### Get Concept Vectors

In [2]:
sample_id = 0
num_layers = 10
data_dir = f'/raid/dacslab/CONCEPT_FORMATION/homograph_small/Qwen_Qwen3-4B/{sample_id}/'

data_keys = ['emb'] + [f"{layer}.{module}" for layer in range(num_layers) for module in ['mid', 'post']]
target_token_res_states = TransformerCache.load(data_dir, data_keys, model_id=model_ids[0])
target_token_res_states.stack_tensors()

concept_vectors = {}
no_layers = 10
num_homographs = 1
num_examples = 5

concept_vectors = {}

emb_pairs = target_token_res_states[f'emb'].view(num_homographs, 2, num_examples, -1)
concept_vectors['emb'] = get_mass_mean_vectors(emb_pairs)

for layer in range(no_layers):
    mid_pairs = target_token_res_states[f'{layer}.mid'].view(num_homographs, 2, num_examples, -1)
    post_pairs = target_token_res_states[f'{layer}.post'].view(num_homographs, 2, num_examples, -1)
    
    concept_vectors[f'{layer}.mid'] = get_mass_mean_vectors(mid_pairs)
    concept_vectors[f'{layer}.post'] = get_mass_mean_vectors(post_pairs)

del target_token_res_states
gc.collect()

237

### Get direct contributions

In [3]:
def get_module_lookup(model_config: PretrainedConfig, target_layer: int, num_tokens: int) -> List[str]:
    num_attn_heads, attn_head_dim, num_mlp_heads = model_config.num_attention_heads, model_config.head_dim, model_config.intermediate_size
    module_list = ['emb']

    for layer_id in range(target_layer):
        module_list.extend([f"{layer_id}.attn.{token_pos}.{head_pos}" for token_pos in range(num_tokens) for head_pos in range(num_attn_heads * attn_head_dim)])
        module_list.extend([f"{layer_id}.mlp.-1.{head_pos}" for head_pos in range(num_mlp_heads)])
    
    return module_list

def get_first_module_group_id_lookup(module_lookup: List[str]) -> Dict[str, int]:
    first_module_group_id_lookup = {}
    prev_layer, prev_module_type = 0, ''
    for i, module in enumerate(module_lookup):
        if module == 'emb': continue
        
        curr_layer, curr_module_type, _, _ = module.split('.')
        if not first_module_group_id_lookup or \
            curr_layer != prev_layer or curr_module_type != prev_module_type:
            first_module_group_id_lookup[f"{curr_layer}.{curr_module_type}"] = i
        
        prev_layer, prev_module_type = curr_layer, curr_module_type
        
    return first_module_group_id_lookup

def get_initial_tasks(concept_vectors: torch.Tensor, target_module: str, num_examples: int, target_token_pos: List[int]) -> DataDict:
    tasks = DataDict(length=num_examples)
        
    target_vectors = concept_vectors[target_module]
    target_vectors = torch.flatten(target_vectors[:, None].repeat(1, 10, 1) * torch.tensor([[1]] * 5 + [[-1]] * 5), 0, 1)
    tasks.attach('target_idx', [None] * num_examples)
    tasks.attach('target_vectors', target_vectors, force= True)
    tasks.attach('target_contribution_factor', torch.tensor([1] * num_examples), force= True)
    tasks.attach('target_sample_ids', range(num_examples), force= True)
    tasks.attach('target_token_pos', target_token_pos, force= True)
    tasks.attach('depth', [[0]] * num_examples, force= True)

    return tasks

def mean_shift(tensor: torch.Tensor) -> torch.Tensor:
    return tensor - tensor.mean(dim=-1, keepdim=True)

In [4]:
# load model cache
data_keys = ['ent_pos_idx', 'full_emb'] + [f"{layer}.{module}" for layer in range(num_layers) for module in ['d_attn', 'd_mlp']]
model_cache = TransformerCache.load(data_dir, data_keys, model_id=model_ids[0])
model_cache.stack_tensors(padding=True)

num_examples = model_cache.length
num_tokens = model_cache['0.d_attn'].size(1)
target_token_pos = model_cache['ent_pos_idx']
batch_ent_idx = (range(model_cache.length), target_token_pos)

# get parts for direct contributions
direct_parts = [model_cache['full_emb'][batch_ent_idx][:, None]]

for layer in range(no_layers):
    direct_parts.append(model_cache[f'{layer}.d_attn'][batch_ent_idx].flatten(1, -2))
    direct_parts.append(model_cache[f"{layer}.d_mlp"][batch_ent_idx])
direct_parts = torch.concat(direct_parts, dim=1).transpose(0,1)

# get projection cache
data_keys = [f"{layer}.{module}" for layer in range(num_layers) for module in ['v_proj', 'attn_norm_var', 'up_proj', 'mlp_norm_var']]
projection_cache = TransformerCache.load(data_dir, data_keys, model_id=model_ids[0])
projection_cache.stack_tensors(padding=True)

# get indirect parts cache
indirect_parts_cache = [mean_shift(model_cache['full_emb'])[:, :, None]]

for layer in range(no_layers):
    indirect_parts_cache.append(mean_shift(model_cache[f'{layer}.d_attn']).flatten(2, -2))
    indirect_parts_cache.append(mean_shift(model_cache[f"{layer}.d_mlp"]))
indirect_parts_cache = torch.concat(indirect_parts_cache, dim=2)

# free memory 
del model_cache
gc.collect()

184

In [5]:
from contextlib import contextmanager
from typing import Generator, Tuple, Iterator, Union, List, Dict
from transformers import PreTrainedModel, PretrainedConfig
from torch.utils.data import DataLoader
import time

@contextmanager
def load_language_model(
    model_id: Union[str, List[str]],
    **kwargs
) -> Generator[Tuple[PreTrainedModel, PretrainedConfig], None, None]:
    llm = AutoModel.from_pretrained(
        model_id,
        **kwargs
    )
    try:
        yield llm, llm.config
    finally:
        llm.cpu()
        del llm
        gc.collect()
        torch.cuda.empty_cache()
        print(f"Memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
        print(f"Memory reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")

def batch_iterator(dl: DataLoader) -> Iterator[Dict, ]:
    for batch_id, batch in enumerate(dl):
        start_time = time.perf_counter()
        try:
            yield batch
        finally:
            end_time = time.perf_counter()
            print(f"Batch {batch_id + 1}/{len(dl)}: {(end_time - start_time):.2f} seconds")

In [6]:
def get_contribution_ceiling_id(target_module: str, module_lookup: List[str]):
    module_type = '.'.join(target_module.split('.')[:2])
    first_module_id = next(idx for idx, module in enumerate(module_lookup) if module.startswith(module_type))
    return first_module_id

def get_batch_parts(batch: Dict, parts_cache: torch.Tensor, contribution_ceiling_id: int):
    target_sample_ids = batch['target_sample_ids']
    target_token_pos = batch['target_token_pos']

    batch_parts = parts_cache[target_sample_ids, target_token_pos, :contribution_ceiling_id].transpose(0, 1)
    return batch_parts

def get_contributions(batch: Dict, batch_parts: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    target_vectors = batch['target_vectors']
    target_contribution_factor = batch['target_contribution_factor']

    contributions = get_dot_prod_contribution(parts=batch_parts, whole=target_vectors).transpose(0, 1)
    scaled_contributions = contributions * target_contribution_factor[:, None]

    print(scaled_contributions)

    top_x_contributions = get_top_x_contribution_values(scaled_contributions, 0.9) 

    contribution_idx = [el.nonzero(as_tuple=True)[0] for el in top_x_contributions]
    contribution_values = [top_x_contributions[i][idx] for i, idx in enumerate(contribution_idx)]

    return contribution_idx, contribution_values

def get_direct_contributions(direct_parts: torch.Tensor, concept_vectors: torch.Tensor, module_lookup: List[str], target_module: str, target_token_pos: List[int], batch_size: int = 16) -> DataDict:
    active_tasks = get_initial_tasks(concept_vectors, target_module, num_examples, target_token_pos)

    dl = DataLoader(active_tasks, batch_size=batch_size, collate_fn=TensorCollator(), shuffle=False)
    direct_contributions = DataDict()

    for batch in batch_iterator(dl):
        contribution_idx, contribution_values = get_contributions(batch, direct_parts)
        contribution_modules = [[module_lookup[id] for id in idx] for idx in contribution_idx]

        direct_contributions.extend({
            'contribution_idx': contribution_idx,
            'contribution_values':contribution_values,
            'contribution_modules': contribution_modules,
            'target_sample_ids': batch['target_sample_ids'],
            'target_token_pos': batch['target_token_pos'],
            'target_idx': batch['target_idx'],
            'depth': batch['depth'],
            })
            
    return direct_contributions


In [7]:
from tqdm import tqdm
def update_task_queue(task_queue: DataDict, new_contributions: DataDict) -> DataDict:
    for cache_item in tqdm(new_contributions.to_list()):
        contribution_idx = cache_item['contribution_idx']

        non_emb_contribution_ids = [i for i, idx in enumerate(contribution_idx) if idx != 0]
        num_new_tasks = len(non_emb_contribution_ids)

        if not num_new_tasks: continue

        non_emb_contribution_idx = [contribution_idx[i] for i in non_emb_contribution_ids]
        non_emb_contribution_values = [cache_item['contribution_values'][i] for i in non_emb_contribution_ids]
        non_emb_target_modules = [cache_item['contribution_modules'][i] for i in non_emb_contribution_ids]

        task_queue.extend({
            'target_idx': non_emb_contribution_idx,
            'target_values': non_emb_contribution_values,
            'target_modules': non_emb_target_modules,
            'target_sample_ids': [cache_item['target_sample_ids']] * num_new_tasks,
            'source_token_pos': [cache_item['target_token_pos']] * num_new_tasks,
            'depth': [cache_item['depth']] * num_new_tasks
        })
        
    print(task_queue.length)
    task_queue = merge_task_queue_items(task_queue)

    return task_queue

def merge_task_queue_items(task_queue: DataDict) -> DataDict:
    task_queue.sort(by='target_sample_ids')
    task_queue.sort(by='target_idx', descending=True)

    prev_task = task_queue[0]
    merge_idx = []
    value_accumulator = [[]]
    acc_depth_list = [[]]

    for i, task in tqdm(enumerate(task_queue.to_list())):
        if task['target_sample_ids'] == prev_task['target_sample_ids'] \
            and task['target_idx'] == prev_task['target_idx']:
            
            value_accumulator[-1].append(task['target_values'])
            acc_depth_list[-1].append(task['depth'])
            continue
        
        merge_idx.append(i - 1)
        prev_task = task
        value_accumulator.append([task['target_values']])
        acc_depth_list.append([task['depth']])

    merge_idx.append(i - 1)

    merged_tasks = task_queue[merge_idx]
    merged_tasks['target_values'] = [sum(task_values) for task_values in value_accumulator]
    merged_tasks['depth'] = acc_depth_list 

    return merged_tasks

@torch.no_grad()
def get_active_tasks(llm: PreTrainedModel, task_queue: DataDict, projection_cache: Dict) -> Tuple[DataDict, DataDict]:
    task_queue.sort(by='target_modules', descending=True)
    next_module_type = '.'.join(task_queue[0]['target_modules'].split('.')[:2])
    end_idx = next((i for i, module in enumerate(task_queue['target_modules']) if not module.startswith(next_module_type)), None)

    if not end_idx:
        new_active_tasks = task_queue
        new_task_queue = DataDict.from_dict({k:[] for k in task_queue})
    else:
        new_active_tasks = task_queue[:end_idx]
        new_task_queue = task_queue[end_idx:]
    
    target_token_pos = get_target_token_pos(new_active_tasks)
    new_active_tasks.attach('target_token_pos', target_token_pos)

    target_vectors = get_target_vectors(llm, new_active_tasks, projection_cache)
    target_contribution_factor = get_contribution_factor(llm, new_active_tasks, projection_cache)

    new_active_tasks.attach('target_vectors', target_vectors)
    new_active_tasks.attach('target_contribution_factor', target_contribution_factor)

    return new_active_tasks, new_task_queue

def get_target_vectors(llm: PreTrainedModel, tasks: DataDict, projection_cache: Dict) -> List[torch.Tensor]:

    target_layer, module_type = tasks[0]['target_modules'].split('.')[:2]
    target_heads = [int(module.split('.')[-1]) for module in tasks['target_modules']]

    sample_ids = tasks['target_sample_ids']
    token_pos = tasks['target_token_pos']
    if module_type == 'attn':
        target_vectors = get_attn_target_vectors(llm, projection_cache, int(target_layer), sample_ids, token_pos, target_heads)
        return target_vectors
    if module_type == 'mlp':
        target_vectors = get_mlp_target_vectors(llm, projection_cache, int(target_layer), sample_ids, token_pos, target_heads)
        return target_vectors
    raise ValueError(f"Unexpected module type: {module_type}")

def get_attn_target_vectors(llm: PreTrainedModel, projection_cache: Dict, target_layer: int, sample_ids: List[int], token_pos: List[int], target_heads: List[int]) -> torch.Tensor:
    num_attn_heads = llm.config.num_attention_heads
    num_k_v_heads = llm.config.num_key_value_heads
    num_head_groups = num_attn_heads // num_k_v_heads
    head_dim = llm.config.head_dim
    hidden_dim = llm.config.hidden_size

    k_v_head_pos = torch.tensor(target_heads) // head_dim // num_head_groups
    neuron_pos = torch.tensor(target_heads) % head_dim

    v_proj_W = llm.layers[target_layer].self_attn.v_proj.weight.data.view(num_k_v_heads, head_dim, hidden_dim)[(k_v_head_pos, neuron_pos)]
    norm_W = llm.layers[target_layer].input_layernorm.weight.data
    norm_var = projection_cache[f"{target_layer}.attn_norm_var"][(sample_ids, token_pos)]

    target_vectors = trace_through_layer_norm(v_proj_W, norm_W, norm_var)

    return target_vectors.cpu()

def get_mlp_target_vectors(llm: PreTrainedModel, projection_cache: Dict, target_layer: int, sample_ids: List[int], token_pos: List[int], target_heads: List[int]) -> torch.Tensor:
    up_proj_W = llm.layers[target_layer].mlp.up_proj.weight.data[target_heads]
    norm_W = llm.layers[target_layer].post_attention_layernorm.weight.data
    norm_var = projection_cache[f"{target_layer}.mlp_norm_var"][sample_ids, token_pos]

    target_vectors = trace_through_layer_norm(up_proj_W, norm_W, norm_var)

    return target_vectors.cpu()

@torch.no_grad()
def trace_through_layer_norm(target: torch.Tensor, norm_W: torch.Tensor, norm_var: torch.Tensor, device: Union[str, torch.device] = 'cuda') -> torch.Tensor:
    input_device = target.device
    if not device:
        device = input_device
    
    var_scaled_target = target.to(device) * norm_var[:, None].to(device)
    norm_w_scaled_target = var_scaled_target / norm_W.to(device)

    return norm_w_scaled_target.to(input_device)

def get_contribution_factor(llm: PreTrainedModel, tasks: DataDict, projection_cache: Dict):
    target_layer, module_type = tasks[0]['target_modules'].split('.')[:2]
    target_heads = [int(module.split('.')[-1]) for module in tasks['target_modules']]

    sample_ids = tasks['target_sample_ids']
    target_values = torch.tensor(tasks['target_values'])
    token_pos = tasks['target_token_pos']
    if module_type == 'attn':
        contribution_factor = get_attn_contribution_factor(llm, target_values, projection_cache, int(target_layer), sample_ids, token_pos, target_heads)
        return contribution_factor
    if module_type == 'mlp':
        contribution_factor = get_mlp_contribution_factor(target_values, projection_cache, int(target_layer), sample_ids, token_pos, target_heads)
        return contribution_factor
    raise ValueError(f"Unexpected module type: {module_type}")

def get_attn_contribution_factor(llm: PreTrainedModel, target_values: torch.Tensor, projection_cache: Dict, target_layer: int, sample_ids: List[int], token_pos: List[int], target_heads: List[int]) -> torch.Tensor:
    num_attn_heads = llm.config.num_attention_heads
    num_k_v_heads = llm.config.num_key_value_heads
    num_head_groups = num_attn_heads // num_k_v_heads
    head_dim = llm.config.head_dim
    num_targets = len(sample_ids)

    k_v_head_pos = torch.tensor(target_heads) // head_dim // num_head_groups
    neuron_pos = torch.tensor(target_heads) % head_dim

    v_proj = projection_cache[f"{target_layer}.v_proj"][sample_ids, token_pos].view(num_targets, num_k_v_heads, head_dim)[range(num_targets), k_v_head_pos, neuron_pos]
    contribution_factor = target_values / v_proj

    return contribution_factor

def get_mlp_contribution_factor(target_values: torch.Tensor, projection_cache: Dict, target_layer: int, sample_ids: List[int], token_pos: List[int], target_heads: List[int]) -> torch.Tensor:
    up_proj =  projection_cache[f"{target_layer}.up_proj"][sample_ids, token_pos, target_heads]
    contribution_factor = target_values / up_proj

    return contribution_factor

def get_target_token_pos(tasks: DataDict):
    module_token_pos = torch.tensor([int(mod.split('.')[-2]) for mod in tasks['target_modules']])
    source_token_pos = torch.tensor(tasks['source_token_pos'])
    target_token_pos = torch.where(module_token_pos != -1, module_token_pos, source_token_pos).tolist()

    return target_token_pos

In [8]:
target_module = '9.post'
target_layer = 10
batch_size = 32

with load_language_model(
    model_id=model_ids[0],
    trust_remote_code=True,
    device_map='auto',
    dtype=torch.bfloat16,
    attn_implementation = 'eager'
) as (llm, config):
    module_lookup = get_module_lookup(config, target_layer, num_tokens)

    direct_contributions = get_direct_contributions(direct_parts, concept_vectors, module_lookup, target_module, target_token_pos)

    task_queue = update_task_queue(DataDict(), direct_contributions)
    active_tasks, task_queue = get_active_tasks(llm, task_queue, projection_cache)

    indirect_contributions = []
    while active_tasks.length:
        indirect_contributions.append(DataDict())
        dl = DataLoader(active_tasks, batch_size=batch_size, collate_fn=TensorCollator(), shuffle=False)
        
        for batch in batch_iterator(dl):
            target_module = batch['target_modules'][0]
            print(target_module)
            contribution_ceiling_id = get_contribution_ceiling_id(target_module, module_lookup)
            batch_parts = get_batch_parts(batch, indirect_parts_cache, contribution_ceiling_id)
    
            contribution_idx, contribution_values = get_contributions(batch, batch_parts)
            contribution_modules = [[module_lookup[id] for id in idx] for idx in contribution_idx]

            indirect_contributions[-1].extend({
                'contribution_idx': contribution_idx,
                'contribution_values':contribution_values,
                'contribution_modules': contribution_modules,
                'target_sample_ids': batch['target_sample_ids'],
                'target_token_pos': batch['target_token_pos'],
                'target_idx': batch['target_idx'],
                'depth': batch['depth'],
                })
        print('end cont')
        prev = time.time()
        task_queue = update_task_queue(task_queue, indirect_contributions[-1])
        now = time.time()
        print(f'end queue update: {now - prev}')
        prev = time.time()
        active_tasks, task_queue = get_active_tasks(llm, task_queue, projection_cache)
        now = time.time()
        print(f'end active tasks selection: {now - prev}')

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

tensor([[-9.9487e-03, -5.7220e-04,  4.9210e-04,  ...,  4.2915e-04,
          7.0333e-06, -4.0245e-04],
        [-9.9487e-03, -3.7384e-04,  3.2425e-04,  ...,  1.2779e-04,
         -1.2293e-06, -1.5182e-03],
        [-9.9487e-03, -4.6730e-04,  4.0054e-04,  ..., -2.0905e-03,
          9.1642e-07, -6.0654e-04],
        ...,
        [ 9.9487e-03,  5.8365e-04, -5.0354e-04,  ...,  1.1368e-03,
         -5.2452e-06,  6.0654e-04],
        [ 9.9487e-03,  5.0354e-04, -4.3678e-04,  ..., -1.3580e-03,
          8.5682e-07,  9.5367e-04],
        [ 9.9487e-03,  4.4250e-04, -3.8147e-04,  ...,  1.7319e-03,
          7.5698e-06,  9.5367e-04]], dtype=torch.bfloat16,
       grad_fn=<MulBackward0>)
Batch 1/1: 3.02 seconds


100%|██████████| 10/10 [00:00<00:00, 1052.29it/s]


532


532it [00:00, 111024.47it/s]

9.mlp.-1.975





tensor([[ 3.7842e-02,  3.0212e-03,  2.0905e-03,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 9.9609e-02,  5.1880e-03,  3.6316e-03,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 1.3965e-01,  1.0132e-02,  7.0190e-03,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        ...,
        [ 2.3906e+00,  4.9072e-02,  2.2217e-02,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 2.0000e+00,  5.1270e-02,  2.3193e-02,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 1.8828e+00,  2.4048e-02,  1.1108e-02,  ...,  1.0864e-02,
         -2.0020e-02,  6.0120e-03]], dtype=torch.bfloat16,
       grad_fn=<MulBackward0>)
Batch 1/6: 8.05 seconds
9.mlp.-1.8759
tensor([[-7.5391e-01, -2.1362e-02, -9.7656e-03,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        [ 3.4180e-01,  1.0925e-02,  5.4016e-03,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        [ 1.9629e-01,  1.0376e-02,  6.4087e-03,  ..., -0.0000e+00,
    

100%|██████████| 170/170 [00:00<00:00, 12545.64it/s]

1250



1250it [00:00, 363105.48it/s]


end queue update: 0.06739354133605957
end active tasks selection: 0.015015363693237305
9.attn.3.2710
tensor([[-3.8250e+01, -3.7812e+00, -2.7031e+00,  ..., -1.8375e+01,
          3.2812e+00, -7.1250e+00],
        [-3.4000e+01, -3.6875e+00, -2.6719e+00,  ..., -1.8250e+01,
          7.3438e-01, -7.9688e+00],
        [-6.9500e+01, -8.5000e+00, -6.0938e+00,  ..., -2.7500e+01,
         -1.0312e+01, -1.1500e+01],
        ...,
        [-1.5155e+06,  3.4406e+05,  2.4883e+05,  ...,  3.7888e+05,
         -4.6500e+01, -9.7280e+04],
        [-5.3248e+05,  1.2083e+05,  8.7552e+04,  ...,  1.3312e+05,
         -1.7250e+01, -3.4304e+04],
        [-1.2042e+06,  2.7238e+05,  1.9661e+05,  ...,  3.0106e+05,
         -3.8750e+01, -7.7312e+04]], dtype=torch.bfloat16,
       grad_fn=<MulBackward0>)
Batch 1/1: 5.76 seconds
end cont


100%|██████████| 17/17 [00:00<00:00, 27466.55it/s]


530


530it [00:00, 684204.72it/s]

end queue update: 0.019420146942138672
end active tasks selection: 0.0017070770263671875
8.mlp.-1.9698





tensor([[-3.2715e-02, -3.6011e-03, -2.7008e-03,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        [-4.3701e-02, -3.0365e-03, -2.0905e-03,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        [ 1.8652e-01,  1.9897e-02,  1.4648e-02,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        ...,
        [ 1.0781e+00, -1.6797e-01, -1.5039e-01,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 2.9531e+00, -5.5859e-01, -4.9219e-01,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 6.7871e-02, -6.8359e-03, -6.0730e-03,  ..., -1.3046e-03,
         -4.3945e-03,  4.1260e-02]], dtype=torch.bfloat16,
       grad_fn=<MulBackward0>)
Batch 1/5: 8.61 seconds
8.mlp.-1.8268
tensor([[ 7.6904e-03, -1.8616e-03, -1.6556e-03,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 3.7891e-01,  3.2715e-02,  2.3682e-02,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        [ 3.5938e-01,  3.1006e-02,  2.2461e-02,  ..., -0.0000e+00,
    

100%|██████████| 147/147 [00:00<00:00, 7112.19it/s]


1814


1814it [00:00, 323626.86it/s]

end queue update: 0.09186625480651855
end active tasks selection: 0.007964611053466797
8.attn.5.413





tensor([[ 1.9653e-02,  8.6784e-05,  2.8038e-04,  ..., -2.4319e-05,
         -2.5635e-03, -9.7275e-05],
        [ 3.3906e+00,  3.3447e-02,  8.0078e-02,  ..., -3.2227e-02,
         -7.1484e-01, -1.8066e-02],
        [ 1.0750e+02,  1.7188e+00,  4.0938e+00,  ..., -4.8438e-01,
         -1.6500e+01,  4.6250e+00],
        ...,
        [-5.5312e+07, -4.2598e+06, -6.5536e+06,  ..., -2.5805e+05,
         -1.6794e+06, -1.9497e+06],
        [-1.7777e+06, -5.6576e+04, -1.3619e+05,  ..., -1.1520e+03,
         -2.6880e+04, -4.1984e+04],
        [-1.0547e+05, -8.9600e+03, -1.0944e+04,  ..., -4.2800e+02,
         -3.1200e+03, -3.2320e+03]], dtype=torch.bfloat16,
       grad_fn=<MulBackward0>)
Batch 1/1: 4.58 seconds
end cont


100%|██████████| 13/13 [00:00<00:00, 14942.71it/s]


609


609it [00:00, 647321.63it/s]

end queue update: 0.02876424789428711
end active tasks selection: 0.001725912094116211
7.mlp.-1.9679





tensor([[ 1.7900e+02,  2.5000e+00,  7.3125e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 2.5940e-03,  1.1139e-03,  5.4169e-04,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        [-2.3926e-02,  2.2278e-03, -7.6675e-04,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        ...,
        [-7.4707e-02, -4.0245e-04, -3.0670e-03,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        [ 4.1920e+03,  1.1100e+02,  2.1400e+02,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [-1.2891e-01,  2.7466e-03, -4.6082e-03,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00]], dtype=torch.bfloat16,
       grad_fn=<MulBackward0>)
Batch 1/5: 6.19 seconds
7.mlp.-1.8265
tensor([[-2.4292e-02, -1.2131e-03, -6.5613e-04,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        [-1.4709e-02, -1.4038e-03, -4.4632e-04,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [-1.5015e-02, -1.7700e-03, -1.2512e-03,  ..., -0.0000e+00,
    

100%|██████████| 155/155 [00:00<00:00, 5262.70it/s]


3018


3018it [00:00, 267506.54it/s]

end queue update: 0.13957428932189941
end active tasks selection: 0.0024411678314208984
7.attn.5.3381





tensor([[ 3.7695e-01, -1.0254e-02, -4.8828e-03,  ...,  6.8359e-03,
          4.5776e-03, -2.0599e-04],
        [ 8.3203e-01, -1.4099e-02, -5.6458e-03,  ..., -1.6602e-02,
          1.5198e-02, -7.6294e-04],
        [ 6.5600e+02,  2.5000e+01,  1.2875e+01,  ..., -3.0312e+00,
          3.3203e-02, -3.2031e-01],
        ...,
        [ 3.2512e+04, -2.8320e+03, -1.1920e+03,  ...,  5.8000e+01,
          1.8800e+02,  1.5200e+02],
        [ 6.7994e+05, -5.9136e+04, -2.4832e+04,  ...,  1.2160e+03,
          3.9200e+03,  3.1840e+03],
        [ 2.7392e+04, -2.3840e+03, -1.0000e+03,  ...,  4.9000e+01,
          1.5800e+02,  1.2800e+02]], dtype=torch.bfloat16,
       grad_fn=<MulBackward0>)
Batch 1/1: 6.00 seconds
end cont


100%|██████████| 19/19 [00:00<00:00, 13963.86it/s]


1088


1088it [00:00, 698515.65it/s]

end queue update: 0.039216041564941406
end active tasks selection: 0.003609895706176758
6.mlp.-1.9716





tensor([[ 8.7402e-02,  3.4943e-03,  3.7689e-03,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        [ 1.7334e-02, -6.2866e-03, -3.1281e-03,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        [ 2.9541e-02, -1.8539e-03,  2.6131e-04,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        ...,
        [ 5.2000e+02,  9.1250e+00,  9.3750e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 3.6875e+00,  2.2656e-01,  2.1484e-01,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [-7.2754e-02,  3.6163e-03,  1.0147e-03,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00]], dtype=torch.bfloat16,
       grad_fn=<MulBackward0>)
Batch 1/8: 6.60 seconds
6.mlp.-1.8178
tensor([[-3.6865e-02,  4.5586e-04, -2.4796e-04,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        [ 2.8521e+09,  2.1810e+09,  1.5854e+09,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        [ 3.9649e+06,  3.0310e+06,  2.2118e+06,  ..., -0.0000e+00,
    

100%|██████████| 250/250 [00:00<00:00, 4576.64it/s]


5129


5129it [00:00, 305199.33it/s]


end queue update: 0.22641229629516602
end active tasks selection: 0.0058858394622802734
6.attn.5.1273
tensor([[-8.7040e+03, -8.3750e+00, -2.8750e+01,  ..., -1.6700e+02,
         -4.6200e+02, -1.1000e+02],
        [ 6.5200e+02, -1.3000e+01, -7.9375e+00,  ..., -4.4750e+01,
         -7.5500e+01, -1.5875e+01],
        [ 1.2544e+04,  5.6500e+01,  4.0500e+01,  ...,  3.6800e+02,
         -7.0000e+02, -2.1300e+02],
        ...,
        [ 1.1264e+05,  7.8720e+03,  5.8240e+03,  ..., -1.1840e+04,
         -2.7520e+03, -4.6400e+02],
        [-7.2188e+00, -2.0020e-01, -1.6992e-01,  ...,  1.5938e+00,
         -3.1641e-01,  1.2891e+00],
        [-1.1264e+04,  5.8800e+02,  3.6600e+02,  ...,  3.0400e+02,
         -2.7840e+03, -2.0469e+00]], dtype=torch.bfloat16,
       grad_fn=<MulBackward0>)
Batch 1/10: 6.02 seconds
6.attn.3.3641
tensor([[-2.3101e+06,  9.8304e+04,  6.0160e+04,  ...,  6.1184e+04,
         -7.5366e+05, -1.7000e+02],
        [ 2.6953e-01, -5.6152e-03, -4.0283e-03,  ..., -4.5898e-02,
    

100%|██████████| 289/289 [00:00<00:00, 1540.34it/s]


18554


18554it [00:00, 67953.40it/s]


end queue update: 1.0291364192962646
end active tasks selection: 0.01345205307006836
5.mlp.-1.995
tensor([[ 3.6659e+05, -1.7664e+04, -1.3376e+04,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        [-2.3069e+07, -2.1299e+06, -1.5319e+06,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [-2.9327e+06, -2.9901e+05, -2.1811e+05,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        ...,
        [-1.4008e+06, -2.5600e+05, -1.3619e+05,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        [ 7.0451e+05, -9.1136e+04, -6.0928e+04,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 6.3281e-01,  8.3984e-02,  2.9419e-02,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00]], dtype=torch.bfloat16,
       grad_fn=<MulBackward0>)
Batch 1/39: 6.15 seconds
5.mlp.-1.9530
tensor([[-1.9923e+07,  1.2370e+06,  8.6835e+05,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        [-3.4560e+03, -3.2800e+02, -2.3200e+02,  ..., -3.0400e+02,
        

100%|██████████| 1234/1234 [00:01<00:00, 1101.37it/s]


74931


74931it [00:00, 251008.43it/s]


end queue update: 3.588798761367798
end active tasks selection: 0.03189444541931152
5.attn.5.995
tensor([[ 2.1100e+02, -1.3828e+00, -5.1562e-01,  ..., -7.3750e+00,
          8.3984e-01,  3.2227e-01],
        [-4.5056e+04,  6.6400e+02,  1.5400e+02,  ..., -2.1200e+02,
         -2.7800e+02,  2.8400e+02],
        [ 2.8467e+05, -3.6864e+04, -1.4592e+04,  ...,  9.1136e+04,
          1.1648e+04, -1.2736e+04],
        ...,
        [ 3.1248e+08, -6.2915e+06, -1.2452e+06,  ..., -4.2189e+05,
          2.8672e+06, -4.7841e+06],
        [-2.5559e+07,  9.9942e+05,  3.7274e+05,  ..., -5.1814e+05,
         -4.8128e+05,  5.8573e+05],
        [ 1.2499e+09, -3.9059e+07, -1.0748e+07,  ..., -6.5274e+07,
          1.1796e+07, -1.3697e+07]], dtype=torch.bfloat16,
       grad_fn=<MulBackward0>)
Batch 1/21: 5.04 seconds
5.attn.4.131
tensor([[-1.3025e+06,  7.3728e+04, -2.3808e+04,  ...,  7.2499e+05,
         -9.8816e+04, -2.4678e+05],
        [-8.9702e+05,  5.5552e+04,  2.8800e+04,  ...,  1.3312e+05,
         -

100%|██████████| 644/644 [00:00<00:00, 1463.55it/s]


25973


25973it [00:00, 285479.41it/s]


end queue update: 1.4729723930358887
end active tasks selection: 0.02614116668701172
4.mlp.-1.994
tensor([[ 3.1293e+06,  7.3216e+04,  4.2752e+04,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 4.6557e+08, -6.7633e+07, -9.2275e+07,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        [ 2.1135e+06,  9.4208e+04,  6.7584e+04,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        ...,
        [-4.9280e+03, -1.2450e+02, -3.1500e+01,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        [ 3.0467e+10, -2.9360e+08, -7.1303e+07,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        [-4.1779e+06, -8.5504e+04, -2.1248e+04,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00]], dtype=torch.bfloat16,
       grad_fn=<MulBackward0>)
Batch 1/53: 3.83 seconds
4.mlp.-1.941
tensor([[ 4.0550e+05, -1.7152e+04, -3.9360e+03,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        [ 4.6976e+08, -1.4785e+08, -3.5127e+07,  ..., -0.0000e+00,
         

100%|██████████| 1695/1695 [00:00<00:00, 2150.46it/s]


52809


52809it [00:00, 258906.99it/s]


end queue update: 2.694352388381958
end active tasks selection: 0.02860403060913086
4.attn.5.2942
tensor([[-3.2200e+02,  3.4844e+00,  1.2250e+01,  ..., -7.3853e-03,
         -1.1094e+00, -7.6172e-01],
        [ 3.2000e+02, -2.1750e+01, -2.5750e+01,  ...,  8.1787e-03,
          9.1797e-01,  6.0547e-01],
        [-2.3680e+04,  4.3000e+02,  1.9900e+02,  ...,  6.3705e-04,
          4.2188e+00,  3.2812e+00],
        ...,
        [-5.0594e+07, -8.6835e+05, -7.8643e+05,  ..., -8.7200e+02,
          1.5552e+04,  3.0880e+03],
        [-6.3439e+07, -1.5073e+06, -1.3599e+06,  ..., -1.3300e+02,
          3.3024e+04,  5.5680e+03],
        [ 6.4960e+03, -1.6700e+02, -1.5200e+02,  ...,  4.2725e-02,
          1.0312e+00,  2.0605e-01]], dtype=torch.bfloat16,
       grad_fn=<MulBackward0>)
Batch 1/13: 4.03 seconds
4.attn.3.3759
tensor([[-3.9059e+07, -1.0240e+06, -9.3798e+05,  ..., -3.2400e+02,
          1.6000e+04,  2.8000e+03],
        [-3.5127e+07, -6.0211e+05, -5.4886e+05,  ..., -6.0800e+02,
        

100%|██████████| 404/404 [00:00<00:00, 3416.79it/s]


14203


14203it [00:00, 368030.96it/s]


end queue update: 0.9299931526184082
end active tasks selection: 0.01660633087158203
3.mlp.-1.994
tensor([[ 5.3268e+08, -1.4877e+07, -1.4025e+07,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        [-7.0656e+04,  8.1920e+03,  7.1680e+03,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        [ 1.2778e+11,  3.0031e+09,  2.6172e+09,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        ...,
        [-6.5536e+05, -9.2160e+04, -7.4752e+04,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        [-3.0540e+07, -1.1338e+07, -9.1750e+06,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        [-7.5497e+08, -9.4372e+07, -7.6546e+07,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00]], dtype=torch.bfloat16,
       grad_fn=<MulBackward0>)
Batch 1/41: 3.29 seconds
3.mlp.-1.9564
tensor([[ 6.2976e+04, -3.6640e+03, -2.9760e+03,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        [ 3.5021e+05, -1.6320e+04, -1.3248e+04,  ..., -0.0000e+00,
        

100%|██████████| 1300/1300 [00:00<00:00, 3802.71it/s]


33146


33146it [00:00, 235072.40it/s]


end queue update: 1.6805026531219482
end active tasks selection: 0.03036785125732422
3.attn.5.1486
tensor([[-1.3120e+04,  3.6200e+02,  6.0000e+02,  ..., -9.0000e+00,
          4.5625e+00, -2.9688e+00],
        [-2.2200e+02,  1.8500e+01, -4.8000e+01,  ...,  7.8516e-01,
         -8.3594e-01, -4.1260e-02],
        [ 2.3962e+05, -2.4000e+03, -8.7200e+02,  ...,  1.7750e+01,
          2.6562e+00,  1.4312e+01],
        ...,
        [-5.8368e+04,  2.6720e+03,  1.4320e+03,  ..., -1.2266e+00,
          1.3047e+00, -5.8750e+00],
        [-3.5635e+05, -4.5440e+03,  1.3632e+04,  ..., -3.3500e+01,
         -1.3438e+01,  1.5391e+00],
        [ 6.0211e+05,  7.2000e+03, -2.1632e+04,  ...,  6.5000e+01,
          2.4500e+01,  2.9531e+00]], dtype=torch.bfloat16,
       grad_fn=<MulBackward0>)
Batch 1/12: 3.02 seconds
3.attn.3.3282
tensor([[ 1.5811e+06,  4.9664e+04,  1.0560e+04,  ..., -2.2750e+01,
          2.2750e+01, -9.6500e+01],
        [ 2.9000e+06,  7.4752e+04, -7.9680e+03,  ..., -1.0391e+00,
       

100%|██████████| 370/370 [00:00<00:00, 3756.22it/s]


11066


11066it [00:00, 44722.80it/s]


end queue update: 0.7920353412628174
end active tasks selection: 0.013776063919067383
2.mlp.-1.996
tensor([[ 7.0080e+03,  3.8000e+02,  1.3200e+02,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 1.1674e+05, -9.7280e+03, -3.3120e+03,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        [ 1.9608e+08,  2.0185e+07,  7.4383e+06,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        ...,
        [ 2.9120e+03,  1.2300e+02,  5.0750e+01,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        [-2.7853e+05,  8.4480e+03,  3.5040e+03,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        [ 6.9206e+07,  6.6765e+05,  2.0787e+05,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00]], dtype=torch.bfloat16,
       grad_fn=<MulBackward0>)
Batch 1/44: 2.71 seconds
2.mlp.-1.964
tensor([[ 8.9339e+08,  1.2386e+07,  4.4892e+06,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [-1.3770e-01, -7.3242e-03, -2.6093e-03,  ..., -0.0000e+00,
        

100%|██████████| 1379/1379 [00:00<00:00, 5486.07it/s]


24614


24614it [00:00, 276112.12it/s]


end queue update: 1.2344813346862793
end active tasks selection: 0.016874074935913086
2.attn.5.2551
tensor([[-2.0320e+03,  2.3200e+02,  1.6600e+02,  ...,  4.5938e+00,
          8.5449e-03,  3.0312e+00],
        [-6.1200e+02,  4.7000e+01,  7.0938e+00,  ...,  1.2891e+00,
          3.0859e-01,  2.7344e-01],
        [ 6.1932e+06,  5.8368e+04,  1.0048e+04,  ..., -4.5120e+03,
         -1.8400e+03, -7.8800e+02],
        ...,
        [ 6.6500e+01, -3.1375e+01, -6.5000e+01,  ...,  1.0498e-02,
          7.4219e-02, -4.0771e-02],
        [-6.9200e+02, -9.8125e+00,  4.8250e+01,  ..., -2.4316e-01,
         -1.7285e-01,  1.2634e-02],
        [ 9.9219e-01,  3.2806e-04, -2.4170e-02,  ..., -6.3705e-04,
         -1.4400e-04, -6.6683e-07]], dtype=torch.bfloat16,
       grad_fn=<MulBackward0>)
Batch 1/9: 1.64 seconds
2.attn.3.3334
tensor([[ 2.9883e-01, -6.6833e-03, -1.5076e-02,  ..., -3.4094e-05,
          4.8876e-05, -1.8835e-05],
        [ 8.6528e+04,  6.7840e+03,  3.0720e+03,  ...,  3.1375e+01,
       

100%|██████████| 266/266 [00:00<00:00, 2580.61it/s]

12796



12796it [00:00, 345022.46it/s]


end queue update: 0.8687710762023926
end active tasks selection: 0.013762474060058594
1.mlp.-1.995
tensor([[ 2.3757e+06,  8.8576e+04,  1.0086e+05,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 8.9657e+10,  6.8719e+10,  7.8383e+10,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [-1.7613e+05,  2.0352e+04,  2.3424e+04,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        ...,
        [ 1.5974e+16, -9.0600e+14, -1.5305e+15,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 7.3183e+16, -4.1166e+15, -6.9665e+15,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 5.8203e-01, -3.2959e-02, -5.5664e-02,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00]], dtype=torch.bfloat16,
       grad_fn=<MulBackward0>)
Batch 1/41: 1.61 seconds
1.mlp.-1.945
tensor([[ 2.6457e+12, -1.4925e+11, -2.5340e+11,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 5.4888e+15, -3.1006e+14, -5.2557e+14,  ...,  0.0000e+00,
        

100%|██████████| 1286/1286 [00:00<00:00, 6332.28it/s]


21395


21395it [00:00, 289109.97it/s]


end queue update: 1.0369188785552979
end active tasks selection: 0.02296137809753418
1.attn.5.3931
tensor([[ 2.9983e+06,  4.6400e+03, -1.9456e+04,  ...,  1.4746e+05,
         -7.3728e+05,  6.2976e+04],
        [-5.2800e+02,  9.0000e+00,  6.3125e+00,  ..., -4.7188e+00,
         -3.4500e+01, -8.6875e+00],
        [-6.0621e+06, -1.9072e+04, -9.4720e+04,  ...,  1.3005e+05,
          5.2838e+05,  1.1930e+05],
        ...,
        [-7.5000e+01,  4.1250e+00,  2.5750e+01,  ..., -4.0312e+00,
          6.1035e-02, -5.3125e-01],
        [-1.3200e+02, -2.5156e+00, -4.4062e+00,  ...,  7.7148e-02,
         -5.4932e-02,  5.1270e-02],
        [-3.7849e+10,  2.2230e+08,  5.3687e+08,  ..., -2.0709e+07,
         -4.7312e+09, -2.2754e+08]], dtype=torch.bfloat16,
       grad_fn=<MulBackward0>)
Batch 1/23: 1.27 seconds
1.attn.4.3931
tensor([[ 1.6643e+10,  4.2729e+07, -1.7092e+08,  ...,  1.4221e+07,
         -6.2747e+09,  1.1220e+08],
        [ 5.8787e+10,  4.5928e+08,  1.1241e+09,  ..., -1.0813e+07,
       

100%|██████████| 709/709 [00:00<00:00, 11011.57it/s]

9613



9613it [00:00, 412998.91it/s]


end queue update: 0.4742410182952881
end active tasks selection: 0.025357484817504883
0.mlp.-1.996
tensor([[ 1.4848e+05,  1.3696e+04,  1.4720e+04,  ..., -1.6500e+01,
          5.2750e+01,  4.4750e+01],
        [-1.4234e+05,  6.2400e+03,  6.7200e+03,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        [ 1.0630e+11, -7.0129e+09, -2.9863e+09,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        ...,
        [ 2.4371e+05,  1.4016e+04,  1.4208e+04,  ..., -1.2812e+01,
          7.7500e+01,  8.1500e+01],
        [-6.7584e+05,  2.1094e+05,  2.1402e+05,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        [-9.0000e+02,  2.7375e+01,  2.7750e+01,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00]], dtype=torch.bfloat16,
       grad_fn=<MulBackward0>)
Batch 1/101: 0.68 seconds
0.mlp.-1.9602
tensor([[ 3.6962e+07,  4.0468e+06,  2.1955e+06,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 1.0905e+08,  6.9468e+06,  1.6896e+05,  ...,  0.0000e+00,
      

100%|██████████| 3206/3206 [00:00<00:00, 5121.39it/s]


32879


32879it [00:00, 246313.93it/s]


end queue update: 1.9145746231079102
Memory allocated: 0.00 GB
Memory reserved: 0.00 GB


IndexError: list index out of range

In [12]:
import pandas as pd

def flatten_contributions(contributions: DataDict) -> DataDict:
    flat_contributions = DataDict()
    if not contributions.length: return flat_contributions
    for cache_item in tqdm(contributions.to_list(), total=len(contributions)):
        num_contributions = len(cache_item['contribution_idx'])

        flat_contributions.extend({
            'contribution_idx': cache_item['contribution_idx'].tolist(),
            'contribution_values': cache_item['contribution_values'].tolist(),
            'contribution_modules': cache_item['contribution_modules'],
            'sample_ids': [cache_item['target_sample_ids']] * num_contributions,
            'source_idx': [cache_item['target_idx'].item() if cache_item['target_idx'] else None] * num_contributions,
            'source_token_pos': [cache_item['target_token_pos']] * num_contributions,
            'depth': [cache_item['depth']] * num_contributions
        })
    return flat_contributions

In [13]:
direct_flat_contributions = flatten_contributions(direct_contributions)
df = pd.DataFrame(data=direct_flat_contributions.to_dict())

for i, contribution_cache in enumerate(indirect_contributions):
    flat_contributions = flatten_contributions(contribution_cache)
    if not flat_contributions.length: continue
    for batch_start in range(0, len(flat_contributions), 2000):
        contribution_df = pd.DataFrame(data=flat_contributions[batch_start: batch_start + 2000].to_dict())
        df = pd.concat([df, contribution_df])

df['layer'] = df['contribution_modules'].apply(lambda x: int(x.split('.')[0]) if x != 'emb' else 0)
df['module_type'] = pd.Categorical(df['contribution_modules'].apply(lambda x: x.split('.')[1] if x != 'emb' else 'emb'), categories=['emb', 'attn', 'mlp'], ordered=True)
df['token_pos'] = df.apply(lambda x: int(x['contribution_modules'].split('.')[2]) if x['contribution_modules'] != 'emb' and int(x['contribution_modules'].split('.')[2]) != -1  else x['source_token_pos'], axis=1)
df['head_id'] = df['contribution_modules'].apply(lambda x: int(x.split('.')[3]) if x != 'emb' else 0)
df = df.drop('contribution_modules', axis=1)
     
df.head()

100%|██████████| 10/10 [00:00<00:00, 36986.81it/s]


100%|██████████| 170/170 [00:00<00:00, 98839.99it/s]
100%|██████████| 17/17 [00:00<00:00, 59468.86it/s]
100%|██████████| 147/147 [00:00<00:00, 94261.23it/s]
100%|██████████| 13/13 [00:00<00:00, 69905.07it/s]
100%|██████████| 155/155 [00:00<00:00, 91282.94it/s]
100%|██████████| 19/19 [00:00<00:00, 76553.10it/s]
100%|██████████| 250/250 [00:00<00:00, 105247.01it/s]
100%|██████████| 289/289 [00:00<00:00, 67631.19it/s]
100%|██████████| 1234/1234 [00:00<00:00, 67300.84it/s]
100%|██████████| 644/644 [00:00<00:00, 79583.15it/s]
100%|██████████| 1695/1695 [00:00<00:00, 83966.33it/s]
100%|██████████| 404/404 [00:00<00:00, 92940.92it/s]
100%|██████████| 1300/1300 [00:00<00:00, 86226.12it/s]
100%|██████████| 370/370 [00:00<00:00, 91309.28it/s]
100%|██████████| 1379/1379 [00:00<00:00, 92999.94it/s]
100%|██████████| 266/266 [00:00<00:00, 83142.18it/s]
100%|██████████| 1286/1286 [00:00<00:00, 90787.63it/s]
100%|██████████| 709/709 [00:00<00:00, 111052.41it/s]
100%|██████████| 3206/3206 [00:00<00:00

Unnamed: 0,contribution_idx,contribution_values,sample_ids,source_idx,source_token_pos,depth,layer,module_type,token_pos,head_id
0,29315,0.070312,0,,3,[0],0,mlp,3,4738
1,129762,0.527344,0,,3,[0],3,mlp,3,2273
2,162988,0.09668,0,,3,[0],4,mlp,3,1195
3,199601,0.076172,0,,3,[0],5,mlp,3,3504
4,201607,0.104004,0,,3,[0],5,mlp,3,5510


In [14]:
data_path = os.path.join(project_root, 'data/contribution_cache/cache_9_post_with_norm_hom_1_no_depth.parquet')
df.to_parquet(data_path, index=False)
df.head()

ArrowInvalid: ('cannot mix list and non-list, non-null values', 'Conversion failed for column depth with type object')