In [1]:
"""
Main script for collecting and analyzing CrossCoder layer activations, 
logits, and decoder directions.

This script:
1. Collects layer activations and logits from both base and IT models
2. Computes metrics for decoder directions (norms, cosine similarities)
3. Calculates KL divergence between model logits
4. Analyzes and visualizes the results
5. Identifies interesting decoder directions for further study
"""

import os
import torch as th
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import pickle
import gc


In [2]:

# Import your modules
from dictionary_learning.dictionary import BatchTopKCrossCoder
from nnsight import LanguageModel

# Import the utility functions we created
# You'll need to save these as separate Python modules
import importlib
import data_collection
importlib.reload(data_collection)

from data_collection import collect_data

from analysis_utils import (
    find_interesting_directions, 
    analyze_feature_occurrence, 
    plot_decoder_stats, 
    plot_kl_divergence_analysis,
    generate_feature_report
)


In [3]:
from huggingface_hub import login

# Authenticate against the Hugging Face Hub without embedding a secret in the
# notebook. The access token is read from the HF_TOKEN environment variable;
# when it is unset, `token=None` lets `login` fall back to its own
# interactive prompt.
# SECURITY: a literal token used to be committed on this line. Treat that
# token as leaked and revoke it in the Hugging Face account settings.
login(token=os.environ.get("HF_TOKEN"))


In [4]:

# Make sure every output folder exists before anything gets written to it.
for _output_dir in ('saved_data', 'plots', 'reports'):
    os.makedirs(_output_dir, exist_ok=True)

# Load the pretrained crosscoder from the Hub onto the first GPU.
print("Loading models...")
crosscoder = BatchTopKCrossCoder.from_pretrained(
    "science-of-finetuning/gemma-2-2b-L13-k100-lr1e-04-local-shuffling-CCLoss",
    from_hub=True,
    device="cuda:0",
)

# Wrap both Gemma checkpoints with nnsight. The plain checkpoint is mapped to
# cuda:0 (alongside the crosscoder) and the "-it" checkpoint to cuda:1.
gemma_2 = LanguageModel("google/gemma-2-2b", device_map="cuda:0")
gemma_2_it = LanguageModel("google/gemma-2-2b-it", device_map="cuda:1")


Loading models...


In [5]:
# # Clear CUDA memory and print stats for both GPUs
# import gc
# gc.collect()  # Run garbage collection first
# import torch
# print("CUDA Memory Stats for GPU 0:")
# print(f"Allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")
# print(f"Reserved: {torch.cuda.memory_reserved(0) / 1024**2:.2f} MB")
# print(f"Max Allocated: {torch.cuda.max_memory_allocated(0) / 1024**2:.2f} MB")

# print("\nCUDA Memory Stats for GPU 1:")
# print(f"Allocated: {torch.cuda.memory_allocated(1) / 1024**2:.2f} MB")
# print(f"Reserved: {torch.cuda.memory_reserved(1) / 1024**2:.2f} MB")
# print(f"Max Allocated: {torch.cuda.max_memory_allocated(1) / 1024**2:.2f} MB")
# th.cuda.empty_cache()
# th.cuda.synchronize(0)  # Force clear by waiting for all CUDA operations to finish
# th.cuda.empty_cache()
# th.cuda.empty_cache()
# th.cuda.synchronize(1)  # Force clear by waiting for all CUDA operations to finish
# th.cuda.empty_cache()


In [7]:
# Read the jokes CSV and pull the "Joke" column out as a plain Python list.
print("Loading jokes data...")
jokes_df = pd.read_csv('shortjokes_500.csv')
jokes = jokes_df["Joke"].tolist()

# Token positions to analyse, counted from the end of each sequence:
# a single position (-5) and the window of the last ten positions (-10 .. -1).
token_index = -5
token_index_range = [*range(-10, 0)]


Loading jokes data...


In [8]:

# Run the collection pass over all jokes with both models and the crosscoder.
# The token window comes from `token_index_range` (defined in the data-loading
# cell) instead of a second copy of the same literal, so the two cannot drift.
# Note that `jokes_df` is rebound here to the frame returned by `collect_data`,
# replacing the frame read from the CSV.
# NOTE(review): the recorded run of this cell failed for every joke
# ("Accessing value before it's been set") and then hit a CUDA
# "index out of bounds" assert. Possibly negative token indices reach past
# the start of short jokes -- confirm inside `data_collection.collect_data`.
jokes_df, features_df, global_df = collect_data(
    jokes=jokes,
    gemma_2=gemma_2,
    gemma_2_it=gemma_2_it,
    crosscoder=crosscoder,
    token_index_range=token_index_range,
    save_dir='saved_data'
)


Processing jokes:   0%|          | 0/500 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Processing jokes:   0%|          | 1/500 [00:05<48:47,  5.87s/it]

Error processing joke 0: Accessing value before it's been set.
Error processing joke 1: Accessing value before it's been set.


Processing jokes:   1%|          | 3/500 [00:06<12:27,  1.50s/it]

Error processing joke 2: Accessing value before it's been set.
Error processing joke 3: Accessing value before it's been set.


Processing jokes:   1%|          | 5/500 [00:06<05:58,  1.38it/s]

Error processing joke 4: Accessing value before it's been set.
Error processing joke 5: Accessing value before it's been set.


Processing jokes:   1%|▏         | 7/500 [00:07<03:40,  2.24it/s]

Error processing joke 6: Accessing value before it's been set.
Error processing joke 7: Accessing value before it's been set.


Processing jokes:   2%|▏         | 9/500 [00:07<02:42,  3.02it/s]

Error processing joke 8: Accessing value before it's been set.
Error processing joke 9: Accessing value before it's been set.


Processing jokes:   2%|▏         | 11/500 [00:08<02:15,  3.61it/s]

Error processing joke 10: Accessing value before it's been set.
Error processing joke 11: Accessing value before it's been set.


Processing jokes:   3%|▎         | 13/500 [00:08<02:03,  3.95it/s]

Error processing joke 12: Accessing value before it's been set.
Error processing joke 13: Accessing value before it's been set.


Processing jokes:   3%|▎         | 15/500 [00:09<01:55,  4.18it/s]

Error processing joke 14: Accessing value before it's been set.
Error processing joke 15: Accessing value before it's been set.


Processing jokes:   3%|▎         | 17/500 [00:09<01:51,  4.33it/s]

Error processing joke 16: Accessing value before it's been set.
Error processing joke 17: Accessing value before it's been set.


Processing jokes:   4%|▍         | 19/500 [00:10<01:51,  4.31it/s]

Error processing joke 18: Accessing value before it's been set.
Error processing joke 19: Accessing value before it's been set.


Processing jokes:   4%|▍         | 21/500 [00:10<01:49,  4.37it/s]

Error processing joke 20: Accessing value before it's been set.
Error processing joke 21: Accessing value before it's been set.


Processing jokes:   5%|▍         | 23/500 [00:10<01:47,  4.43it/s]

Error processing joke 22: Accessing value before it's been set.
Error processing joke 23: Accessing value before it's been set.


Processing jokes:   5%|▌         | 25/500 [00:11<01:47,  4.42it/s]

Error processing joke 24: Accessing value before it's been set.
Error processing joke 25: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.



/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:94: operator(): block: [4,0,0], thread: [0,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:94: operator(): block: [4,0,0], thread: [1,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:94: operator(): block: [4,0,0], thread: [2,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:94: operator(): block: [4,0,0], thread: [3,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:94: operator(): block: [4,0,0], thread: [4,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:94: operator(): block: [4,0,0], thread: 

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


ative/cuda/IndexKernel.cu:94: operator(): block: [3,0,0], thread: [10,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:94: operator(): block: [3,0,0], thread: [11,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:94: operator(): block: [3,0,0], thread: [12,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:94: operator(): block: [3,0,0], thread: [13,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:94: operator(): block: [3,0,0], thread: [14,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:94: operator(): block: [3,0,0], thread: [15,0,0] Assertion 

da/IndexKernel.cu:94: operator(): block: [2,0,0], thread: [46,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:94: operator(): block: [2,0,0], thread: [47,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:94: operator(): block: [2,0,0], thread: [48,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:94: operator(): block: [2,0,0], thread: [49,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:94: operator(): block: [2,0,0], thread: [50,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:94: operator(): block: [2,0,0], thread: [51,0,0] Assertion `-sizes[

In [None]:

# Pick out the decoder directions worth a closer look from the global stats,
# passing a cosine threshold of 0.8 and a norm threshold of 0.1.
# (How the thresholds are applied is decided inside `find_interesting_directions`.)
interesting_features = find_interesting_directions(global_df, cosine_threshold=0.8, norm_threshold=0.1)


In [None]:

# Count how often each interesting feature shows up in the per-token feature
# table. Rows of `features_df` are kept only when their feature index appears
# in `interesting_features`, then tallied per feature index.
feature_counts = features_df[features_df['feature_index'].isin(interesting_features['feature_index'])]
# Name both output columns explicitly rather than relabelling by position.
feature_occurrence = (
    feature_counts['feature_index']
    .value_counts()
    .rename_axis('feature_index')
    .reset_index(name='occurrence_count')
)

# Scatter cosine similarity against the base-model L2 norm for all rows of
# `global_df`. `plt` comes from the notebook's import cell, so it is not
# re-imported here; the explicit Axes interface keeps the figure self-contained.
fig, ax = plt.subplots(figsize=(12, 6))
ax.scatter(global_df['cosine_similarity'], global_df['l2_norm_base'], alpha=0.5)
ax.set_xlabel('Cosine Similarity')
ax.set_ylabel('L2 Norm (Base Model)')
ax.set_title('Relationship between Cosine Similarity and Norm')
plt.show()

In [5]:


# Collect data and compute metrics
# NOTE(review): `collect_model_data` is not imported or defined in any cell
# visible in this notebook (only `collect_data` is imported from
# `data_collection`), and this cell's execution count is out of sequence.
# It looks like a leftover from an earlier version -- confirm whether it
# should be removed or the function imported. The analysis cell that follows
# reads `df`, which is only assigned here.
# Passes the single `token_index` position (not the range) and a batch size of 1.
print("Collecting model data and computing metrics...")
df, raw_data = collect_model_data(
    jokes, 
    gemma_2, 
    gemma_2_it, 
    crosscoder, 
    token_index=token_index,
    save_dir='saved_data',
    batch_size=1
)


Loading jokes data...
Collecting model data and computing metrics...


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Error processing joke 0: CUDA out of memory. Tried to allocate 82.00 MiB. GPU 0 has a total capacity of 79.25 GiB of which 79.19 MiB is free. Process 1879730 has 59.33 GiB memory in use. Process 2134231 has 19.83 GiB memory in use. Of the allocated memory 19.31 GiB is allocated by PyTorch, and 31.40 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)


NameError: name 'gc' is not defined

In [None]:

# Files this cell expects the collection step to have written:
#   saved_data/crosscoder_metrics.csv      -- DataFrame with metrics
#   saved_data/raw_activations_logits.pkl  -- raw model outputs
#   saved_data/global_decoder_stats.csv    -- stats for all decoder directions
#
# To reuse an earlier run instead of collecting again:
#   df = pd.read_csv('saved_data/crosscoder_metrics.csv')
#   with open('saved_data/raw_activations_logits.pkl', 'rb') as f:
#       raw_data = pickle.load(f)
#   global_df = pd.read_csv('saved_data/global_decoder_stats.csv')

# Reload the per-direction decoder statistics from disk.
global_df = pd.read_csv('saved_data/global_decoder_stats.csv')

# Select candidate directions; the thresholds are meant to be tuned and are
# intended to pick out low cosine similarity with non-small norms.
print("Finding interesting decoder directions...")
interesting_df = find_interesting_directions(global_df, cosine_threshold=0.3, norm_threshold=0.1)

# Flatten the selected rows to a plain list of feature indices.
interesting_features = interesting_df['feature_index'].tolist()
print(f"Found {len(interesting_features)} interesting features.")

# Measure how often the selected features occur in the collected metrics.
print("Analyzing feature occurrence...")
feature_df = analyze_feature_occurrence(df, interesting_features)

# Produce the figures; both helpers are pointed at the 'plots' directory.
print("Generating visualizations...")
plot_decoder_stats(global_df, save_dir='plots')
plot_kl_divergence_analysis(df, save_dir='plots')

# Build the report for the top 50 features, pointed at the 'reports' directory.
print("Generating feature reports...")
feature_report = generate_feature_report(
    global_df,
    feature_df,
    top_n=50,
    save_dir='reports',
)

# Echo the first ten report rows as one summary line each.
print("\nTop 10 most interesting features:")
interesting_active = feature_report.head(10)
for _, row in interesting_active.iterrows():
    feat_idx = int(row['feature_index'])
    cos_sim, l2_base, l2_it, occurrences = (
        row[column]
        for column in ('cosine_similarity', 'l2_norm_base', 'l2_norm_it', 'occurrence_count')
    )

    print(f"Feature {feat_idx}: cos_sim={cos_sim:.4f}, l2_base={l2_base:.4f}, "
          f"l2_it={l2_it:.4f}, occurrences={occurrences}")

print("\nAnalysis complete! Results saved to 'saved_data', 'plots', and 'reports' directories.")