In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('..')

from datasets import load_dataset
import torch as t
from nnsight import LanguageModel
from tqdm import tqdm
from collections import defaultdict
import json
import pickle
from typing import List, Dict, Any

from buffer import AllActivationBuffer
from trainers.scae import SCAESuite
from utils import load_model_with_folded_ln2, load_iterable_dataset

# Global numeric/device configuration shared by every later cell.
DTYPE = t.bfloat16
device = "cuda:0" if t.cuda.is_available() else "cpu"
# This notebook only evaluates, so autograd is switched off globally.
t.set_grad_enabled(False)
# Fix the torch RNG seed for repeatable runs.
t.manual_seed(42)

# Model and dataset come from project helpers in `utils`.
# NOTE(review): the helper name suggests the second layernorm is folded into the weights — confirm in utils.
model = load_model_with_folded_ln2("gpt2", device=device, torch_dtype=DTYPE)
data = load_iterable_dataset('Skylion007/openwebtext')

In [2]:
# Dictionary width multiplier: num_features = n_embd * expansion.
expansion = 16
# NOTE(review): presumably the top-k sparsity of the SAEs; not referenced in the cells shown here — confirm.
k = 128

num_features = model.config.n_embd * expansion
# Number of transformer blocks; drives the per-layer loops below.
n_layer = model.config.n_layer

In [3]:
# Load the pretrained SCAE suite onto `device` in `DTYPE`.
# NOTE(review): the first argument is presumably a Hugging Face Hub repo id — confirm against SCAESuite.from_pretrained.
suite = SCAESuite.from_pretrained(
    'jacobcd52/gpt2_suite_folded_ln',
    device=device,
    dtype=DTYPE,
    )

  checkpoint = t.load(checkpoint_path, map_location='cpu')


In [4]:
# First transformer block, handed to the buffer as `initial_submodule`.
initial_submodule = model.transformer.h[0]

# Per-layer hook points, keyed "mlp_<i>" / "attn_<i>". Each `submodules` value is a
# (module, io-mode string) pair; `layernorm_submodules` maps each MLP key to that block's ln_2.
submodules = {}
layernorm_submodules = {}
for layer in range(n_layer):
    block = model.transformer.h[layer]
    mlp_key, attn_key = f"mlp_{layer}", f"attn_{layer}"

    # Insertion order (mlp_i before attn_i) is kept as-is: later cells iterate these keys.
    submodules[mlp_key] = (block.mlp, "in_and_out")
    submodules[attn_key] = (block.attn, "out")
    layernorm_submodules[mlp_key] = block.ln_2

# Activation buffer that yields batches for evaluation.
buffer = AllActivationBuffer(
    data=data,
    model=model,
    submodules=submodules,
    initial_submodule=initial_submodule,
    layernorm_submodules=layernorm_submodules,
    d_submodule=model.config.n_embd,
    n_ctxs=128,
    out_batch_size=32,
    refresh_batch_size=256,
    device=device,
    dtype=DTYPE,
)

  t.cuda.amp.autocast(dtype=self.dtype)
You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [5]:
from tqdm import tqdm
import psutil
import sys

def get_top_c_indices(top_connections_dict: Dict[str, t.Tensor], c: int, chunk_size: int = 100, 
                      memory_threshold_gb: float = 32) -> Dict[str, t.Tensor]:
    """
    For every row, pick the `c` largest-magnitude nonzero entries pooled across all
    tensors in the dictionary, and report which columns of each tensor were picked.

    Args:
        top_connections_dict: Dictionary mapping strings to sparse COO tensors, each of shape [M, N]
        c: Number of top indices to return per row
        chunk_size: Number of rows to process at once to manage memory
        memory_threshold_gb: Maximum allowed CPU memory usage in gigabytes
        
    Returns:
        Dictionary mapping strings to long tensors of shape [M, c]. Row `i` of entry `key`
        holds, left-aligned and in descending order of magnitude, the column indices of
        `top_connections_dict[key][i]` that rank in the top c by magnitude across all
        dictionary entries combined; unused slots are -1. The order among equal magnitudes
        is unspecified. Results live on the CUDA device when one is available, otherwise
        on the CPU. An empty input dictionary yields an empty dictionary.
        
    Raises:
        MemoryError: If CPU memory usage exceeds memory_threshold_gb
    """
    if not top_connections_dict:
        return {}

    # One process handle is enough; building a new one for every check is needless overhead.
    process = psutil.Process()

    def get_memory_usage_gb():
        """Get current memory usage in GB"""
        return process.memory_info().rss / (1024 ** 3)
    
    def check_memory_usage():
        """Check if memory usage exceeds threshold"""
        current_usage = get_memory_usage_gb()
        if current_usage > memory_threshold_gb:
            raise MemoryError(f"Memory usage ({current_usage:.2f}GB) exceeded threshold ({memory_threshold_gb}GB)")
    
    # Initial memory check
    check_memory_usage()
    
    # Convert all tensors to dense and get shapes
    print("Converting sparse tensors to dense...")
    dense_dict = {key: tensor.to_dense() for key, tensor in top_connections_dict.items()}
    check_memory_usage()
    
    dict_keys = list(dense_dict.keys())
    num_dicts = len(dict_keys)
    M, N = dense_dict[dict_keys[0]].shape
    # Heavy per-chunk work runs on the GPU when there is one, and falls back to the CPU otherwise.
    compute_device = t.device("cuda" if t.cuda.is_available() else "cpu")
    
    print(f"Processing {M} rows in chunks of {chunk_size}")
    print(f"Current memory usage: {get_memory_usage_gb():.2f}GB")
    
    # Initialize result dictionary with -1s on CPU
    result_dict = {key: t.full((M, c), -1, dtype=t.long) for key in dict_keys}
    check_memory_usage()
    
    # A row cannot yield more candidates than it has entries.
    k = min(c, num_dicts * N)
    
    # Process chunks
    chunk_pbar = tqdm(range(0, M, chunk_size), desc="Processing chunks")
    for start_idx in chunk_pbar:
        end_idx = min(start_idx + chunk_size, M)
        rows = end_idx - start_idx
        chunk_pbar.set_postfix({'mem_usage': f'{get_memory_usage_gb():.2f}GB'})
        
        # [rows, num_dicts, N] -> [rows, num_dicts * N]: flat position p encodes
        # dictionary p // N and column p % N.
        chunk_values = t.stack(
            [dense_dict[key][start_idx:end_idx] for key in dict_keys], dim=1
        ).to(compute_device)
        magnitudes = chunk_values.abs().reshape(rows, num_dicts * N)
        
        # One batched top-k per chunk replaces a Python-level sort per row.
        top_vals, top_idx = magnitudes.topk(k, dim=1)
        valid = top_vals != 0  # zeros are padding from densification, never real picks
        top_dict = t.div(top_idx, N, rounding_mode="floor")
        top_col = top_idx - top_dict * N
        
        for d, key in enumerate(dict_keys):
            mask = valid & (top_dict == d)
            if not mask.any():
                continue
            # Running count of hits along each row gives the left-aligned output slot.
            slots = mask.cumsum(dim=1) - 1
            row_ids = mask.nonzero(as_tuple=True)[0]
            chunk_out = t.full((rows, c), -1, dtype=t.long, device=compute_device)
            chunk_out[row_ids, slots[mask]] = top_col[mask]
            result_dict[key][start_idx:end_idx] = chunk_out.cpu()
        
        check_memory_usage()
        
        # Release per-chunk buffers before the next iteration
        del chunk_values, magnitudes, top_vals, top_idx, valid, top_dict, top_col
        if compute_device.type == "cuda":
            t.cuda.empty_cache()
    
    return {key: indices.to(compute_device) for key, indices in result_dict.items()}

In [24]:
# Number of upstream connections kept per downstream feature.
c = 100

# Directory holding one pickled importance tensor per (upstream, downstream) module pair.
importance_dir = "/root/dictionary_learning/notebooks/importance_scores"

connections = {}

# Iterate over every layer of the loaded model rather than a hard-coded 12.
for down_layer in range(n_layer):
    down_name = f"mlp_{down_layer}"

    # Upstream candidates for mlp_L: all earlier MLPs, plus attention layers 0..L inclusive.
    up_names = [f"mlp_{i}" for i in range(down_layer)] + [f"attn_{i}" for i in range(down_layer + 1)]

    conns = {}
    for up_name in up_names:
        # NOTE: pickle.load can execute arbitrary code; only load files produced by this project.
        with open(f"{importance_dir}/importance_{up_name}_to_{down_name}.pkl", "rb") as f:
            conns[up_name] = pickle.load(f)

    connections[down_name] = get_top_c_indices(conns, c)

Converting sparse tensors to dense...
Processing 12288 rows in chunks of 100
Current memory usage: 2.88GB


Processing chunks: 100%|██████████| 123/123 [00:14<00:00,  8.71it/s, mem_usage=2.88GB]


Converting sparse tensors to dense...
Processing 12288 rows in chunks of 100
Current memory usage: 4.01GB


Processing chunks: 100%|██████████| 123/123 [00:14<00:00,  8.30it/s, mem_usage=4.01GB]


Converting sparse tensors to dense...
Processing 12288 rows in chunks of 100
Current memory usage: 5.13GB


Processing chunks: 100%|██████████| 123/123 [00:14<00:00,  8.60it/s, mem_usage=5.13GB]


Converting sparse tensors to dense...
Processing 12288 rows in chunks of 100
Current memory usage: 6.43GB


Processing chunks: 100%|██████████| 123/123 [00:17<00:00,  7.01it/s, mem_usage=6.43GB]


Converting sparse tensors to dense...
Processing 12288 rows in chunks of 100
Current memory usage: 7.65GB


Processing chunks: 100%|██████████| 123/123 [00:22<00:00,  5.55it/s, mem_usage=7.66GB]


Converting sparse tensors to dense...
Processing 12288 rows in chunks of 100
Current memory usage: 8.88GB


Processing chunks: 100%|██████████| 123/123 [00:25<00:00,  4.85it/s, mem_usage=8.87GB]


Converting sparse tensors to dense...
Processing 12288 rows in chunks of 100
Current memory usage: 10.07GB


Processing chunks: 100%|██████████| 123/123 [00:24<00:00,  5.01it/s, mem_usage=10.07GB]


Converting sparse tensors to dense...
Processing 12288 rows in chunks of 100
Current memory usage: 11.33GB


Processing chunks: 100%|██████████| 123/123 [00:25<00:00,  4.82it/s, mem_usage=11.33GB]


Converting sparse tensors to dense...
Processing 12288 rows in chunks of 100
Current memory usage: 12.62GB


Processing chunks: 100%|██████████| 123/123 [00:27<00:00,  4.50it/s, mem_usage=12.62GB]


Converting sparse tensors to dense...
Processing 12288 rows in chunks of 100
Current memory usage: 13.86GB


Processing chunks: 100%|██████████| 123/123 [00:30<00:00,  4.03it/s, mem_usage=13.87GB]


Converting sparse tensors to dense...
Processing 12288 rows in chunks of 100
Current memory usage: 14.96GB


Processing chunks: 100%|██████████| 123/123 [00:31<00:00,  3.90it/s, mem_usage=15.03GB]


Converting sparse tensors to dense...
Processing 12288 rows in chunks of 100
Current memory usage: 15.95GB


Processing chunks: 100%|██████████| 123/123 [00:31<00:00,  3.96it/s, mem_usage=16.06GB]


In [25]:
# Save connections as pickle file; the file name records the `c` they were built with
# (the original f-string had no placeholder and hard-coded "100").
with open(f"/root/dictionary_learning/notebooks/connections_{c}.pkl", "wb") as f:
    pickle.dump(connections, f)

In [None]:
# Attach the computed connection indices to the suite. The original line assigned to the
# misspelled `conncetions`, which only creates a new, unused attribute.
# NOTE(review): assumes SCAESuite reads the attribute `connections` — confirm in trainers/scae.py.
suite.connections = connections

In [20]:
def run_evaluation(
        suite, 
        buffer, 
        n_batches=10, 
        ce_batch_size=32,
        use_sparse_connections=False
        ):
    '''Average the suite's variance-explained metrics over `n_batches` batches drawn from `buffer`.

    Returns a pair `(varexp_metrics, ce_metrics)`, each keyed by the names in
    `buffer.submodules`. Cross-entropy evaluation is currently disabled, so
    `ce_batch_size` is unused and every `ce_metrics[name]` is returned as an empty dict.
    '''
    module_names = list(buffer.submodules.keys())
    varexp_metrics = {name: {} for name in module_names}
    ce_metrics = {name: {} for name in module_names}

    for _ in tqdm(range(n_batches)):
        # One buffer batch -> one set of per-module variance-explained metrics.
        initial_acts, input_acts, output_acts, layernorm_scales = next(buffer)
        batch_varexp_metrics = suite.evaluate_varexp_batch(
            initial_acts,
            input_acts,
            output_acts,
            layernorm_scales,
            use_sparse_connections=use_sparse_connections,
        )

        # Running mean: each batch contributes value / n_batches.
        for name in module_names:
            batch_values = batch_varexp_metrics[name]
            running = varexp_metrics[name]
            for metric in batch_values.keys():
                running[metric] = running.get(metric, 0) + batch_values.get(metric, 0) / n_batches

    return varexp_metrics, ce_metrics

In [21]:
# Quick evaluation over 2 batches with sparse connections enabled.
# `ce_batch_size` has no effect while the CE section of run_evaluation is disabled,
# and `ce_metrics` comes back as a dict of empty dicts.
varexp_metrics, ce_metrics = run_evaluation(
    suite, 
    buffer, 
    n_batches=2, 
    ce_batch_size=1,
    use_sparse_connections=True
    )

  0%|          | 0/2 [00:00<?, ?it/s]

100%|██████████| 2/2 [00:02<00:00,  1.49s/it]


In [22]:
# Fill every module's CE metrics with the constant 1 so the reporting cell below can
# format them; these are placeholder values, not measurements.
for name in ce_metrics:
    ce_metrics[name].update(
        loss_reconstructed=1,
        loss_original=1,
        frac_recovered=1,
    )

In [23]:
# Report table: one row per module, MLP modules first, then a blank line, then attention.
print(f"Clean loss = {ce_metrics['mlp_0']['loss_original']:.3f}\n")

print("Module  CE increase  CE expl FVU")
for position, kind in enumerate(('mlp', 'attn')):
    if position > 0:
        print()
    for name in ce_metrics.keys():
        if kind not in name:
            continue
        module_ce = ce_metrics[name]
        ce_increase = module_ce['loss_reconstructed'] - module_ce['loss_original']
        recovered_pct = module_ce['frac_recovered'] * 100
        fvu_pct = varexp_metrics[name]['FVU'] * 100
        print(f"{name}   {ce_increase:.3f}        {recovered_pct:.0f}%     {fvu_pct:.0f}%")

Clean loss = 1.000

Module  CE increase  CE expl FVU
mlp_0   0.000        100%     46%
mlp_1   0.000        100%     232%
mlp_2   0.000        100%     193%
mlp_3   0.000        100%     175%
mlp_4   0.000        100%     186%
mlp_5   0.000        100%     216%
mlp_6   0.000        100%     150%
mlp_7   0.000        100%     126%
mlp_8   0.000        100%     118%
mlp_9   0.000        100%     116%
mlp_10   0.000        100%     142%
mlp_11   0.000        100%     90%

attn_0   0.000        100%     1%
attn_1   0.000        100%     3%
attn_2   0.000        100%     4%
attn_3   0.000        100%     6%
attn_4   0.000        100%     8%
attn_5   0.000        100%     6%
attn_6   0.000        100%     7%
attn_7   0.000        100%     7%
attn_8   0.000        100%     8%
attn_9   0.000        100%     6%
attn_10   0.000        100%     5%
attn_11   0.000        100%     1%


In [36]:
# Report table: one row per module, MLP modules first, then a blank line, then attention.
print(f"Clean loss = {ce_metrics['mlp_0']['loss_original']:.3f}\n")

print("Module  CE increase  CE expl FVU")
for position, kind in enumerate(('mlp', 'attn')):
    if position > 0:
        print()
    for name in ce_metrics.keys():
        if kind not in name:
            continue
        module_ce = ce_metrics[name]
        ce_increase = module_ce['loss_reconstructed'] - module_ce['loss_original']
        recovered_pct = module_ce['frac_recovered'] * 100
        fvu_pct = varexp_metrics[name]['FVU'] * 100
        print(f"{name}   {ce_increase:.3f}        {recovered_pct:.0f}%     {fvu_pct:.0f}%")

Clean loss = 1.000

Module  CE increase  CE expl FVU
mlp_0   0.000        100%     4%
mlp_1   0.000        100%     11%
mlp_2   0.000        100%     19%
mlp_3   0.000        100%     12%
mlp_4   0.000        100%     15%
mlp_5   0.000        100%     16%
mlp_6   0.000        100%     17%
mlp_7   0.000        100%     17%
mlp_8   0.000        100%     17%
mlp_9   0.000        100%     15%
mlp_10   0.000        100%     11%
mlp_11   0.000        100%     8%

attn_0   0.000        100%     1%
attn_1   0.000        100%     3%
attn_2   0.000        100%     4%
attn_3   0.000        100%     6%
attn_4   0.000        100%     8%
attn_5   0.000        100%     7%
attn_6   0.000        100%     8%
attn_7   0.000        100%     7%
attn_8   0.000        100%     8%
attn_9   0.000        100%     7%
attn_10   0.000        100%     6%
attn_11   0.000        100%     1%


In [38]:
# NOTE(review): `top_connections` is not defined in any cell shown here (the cells above
# build `connections`); presumably it is left over from an earlier kernel session —
# confirm its origin, since this cell would fail on a fresh Restart & Run All.
top_connections['mlp_10']

{'mlp_0': tensor([[   -1,    -1,    -1,  ...,    -1,    -1,    -1],
         [  250,    -1, 11139,  ...,    -1,  7932,    -1],
         [   -1,    -1,    -1,  ...,    -1,    -1,    -1],
         ...,
         [  250,    -1, 11139,  ..., 12242,    -1,    -1],
         [  250,    -1,    -1,  ..., 11132,    -1,    -1],
         [   -1,    -1,    -1,  ...,    -1,    -1,    -1]], device='cuda:0'),
 'mlp_1': tensor([[   -1,    -1,    -1,  ...,    -1,    -1,    -1],
         [   -1,  3208,    -1,  ...,    -1,    -1,  2056],
         [   -1,    -1,    -1,  ...,    -1,    -1,    -1],
         ...,
         [   -1,  3208,    -1,  ...,    -1,  3180,    -1],
         [   -1,    -1, 11580,  ...,    -1,    -1,    -1],
         [ 3208,    -1,    -1,  ...,    -1,    -1,    -1]], device='cuda:0'),
 'attn_0': tensor([[   -1,    -1,    -1,  ...,    -1,    -1,    -1],
         [   -1,    -1,    -1,  ...,    -1,    -1,    -1],
         [   -1,    -1,    -1,  ...,    -1,    -1,    -1],
         ...,
       