In [1]:
# Enable the autoreload extension and reload already-imported project modules once, now.
%load_ext autoreload
%autoreload
from representations_across_sizes.gride import calculate_gride_id, get_sequences
import torch

# Load the "pile" sequences. Later cells index this object with [0] and iterate over it
# as a collection of partitions.
sequences_pile = get_sequences(dataset_name="pile")

In [None]:
# Sanity check: number of items in the first partition.
len(sequences_pile[0])

In [3]:
# Small debug subset: the first 10 items of the first partition. It is the dataset
# passed to get_activation_cache for the single-run experiment.
sequences_pile_debug = sequences_pile[0][:10]

In [None]:
# Inspect the debug subset.
sequences_pile_debug

In [None]:
# NOTE(review): `lm` is first assigned in the model-loading cell further down, so this
# cell raises NameError on a fresh kernel (Restart & Run All). Run it after that cell.
len(lm.model.layers)


In [None]:
%autoreload

from nnsight import LanguageModel

from representations_across_sizes.utils import get_activation_cache

model = "meta-llama/Llama-3.2-1B"
lm = LanguageModel(model, device_map="auto")
remote = False

activations = get_activation_cache(lm, layer_idxs=list(range(len(lm.model.layers))), dataset=sequences_pile_debug, llm_batch_size=64)


In [None]:
from lovely_tensors import monkey_patch
# Patch tensor display before showing the activation cache
# (presumably compact summaries instead of raw values; confirm against lovely_tensors docs).
monkey_patch()
activations


In [None]:
from typing import List
from torch import Tensor


def _last_token_activations(acts) -> Tensor:
    """Reduce one layer's cached activations to a single 2-D tensor.

    Args:
        acts: Either a list of per-batch tensors that are indexed with three
            dimensions (presumably (batch, position, hidden); confirm against
            get_activation_cache), or a Tensor that was already reduced by a
            previous run of this cell.

    Returns:
        The slice at position -1 of every batch, concatenated along dim 0.
        An already-reduced Tensor is returned unchanged.
    """
    # Re-running the cell used to iterate over an already-concatenated tensor and
    # fail with IndexError on `act[:, -1, :]`; this guard makes the cell idempotent.
    if isinstance(acts, Tensor):
        return acts
    per_batch: List[Tensor] = [act[:, -1, :] for act in acts]
    return torch.cat(per_batch, dim=0)


# Keep only the last token position of every sequence: one tensor per layer.
# NOTE(review): index -1 is the last *real* token only if batches are left-padded. Confirm.
activations = {layer: _last_token_activations(acts) for layer, acts in activations.items()}

activations

In [None]:
%load_ext autoreload
%autoreload

from representations_across_sizes.gride import calculate_gride_id

ids = [float(calculate_gride_id(activations[layer].to('cpu'))) for layer in activations.keys()]
ids

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Plotting context applied to this and all subsequent figures.
sns.set_context("talk")

# Per-layer ID curve for the single run, drawn on an explicit figure/axes pair.
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(ids, marker='o', label='llama', color='#1f77b4')  # same blue as in the reference plot

ax.set_title("Average ID over Layers")
ax.set_xlabel("layer")
ax.set_ylabel("intrinsic dimension")

ax.grid(True, alpha=0.3)
ax.set_ylim(5, 45)  # y-range chosen to resemble the reference plot
ax.legend()

fig.tight_layout()
plt.show()

# for many models over all partitions

In [None]:
import numpy as np

# Reload every partition for the full run.
sequences_pile = get_sequences(dataset_name="pile")
num_partitions = len(sequences_pile)

# The set of layers does not change between partitions, so build the index list once.
layer_idxs = list(range(len(lm.model.layers)))

# One row of per-layer IDs per partition. Kept as a plain list under its own name so
# `all_partition_ids` is only ever an ndarray.
partition_id_rows = []

for partition_number, partition in enumerate(sequences_pile, start=1):
    # Report progress against the real partition count instead of a hard-coded 5.
    print(f"Processing partition {partition_number}/{num_partitions}")

    # Activations for this partition, keyed by layer.
    activations = get_activation_cache(
        lm,
        layer_idxs=layer_idxs,
        dataset=partition,
        llm_batch_size=64,
    )

    # Keep only position -1 of every sequence and concatenate the batches per layer.
    # NOTE(review): index -1 is the last *real* token only if batches are left-padded. Confirm.
    activations = {
        layer: torch.cat([act[:, -1, :] for act in acts], dim=0)
        for layer, acts in activations.items()
    }

    # One ID estimate per layer, computed on CPU.
    partition_ids = [
        float(calculate_gride_id(layer_acts.to('cpu')))
        for layer_acts in activations.values()
    ]
    partition_id_rows.append(partition_ids)
    print(partition_ids)

# Rows are partitions, columns are layers.
all_partition_ids = np.array(partition_id_rows)

# Per-layer mean and standard deviation across partitions
# (np.std default, i.e. population std with ddof=0).
mean_ids = np.mean(all_partition_ids, axis=0)
std_ids = np.std(all_partition_ids, axis=0)


In [None]:
# Show every partition's per-layer IDs at once (rows = partitions, columns = layers).
# This replaces five copy-pasted cells that indexed rows 0..4 by hand, which raised
# IndexError with fewer than five partitions and hid rows when there were more.
all_partition_ids


In [None]:

# Mean per-layer ID across partitions, drawn with error bars on an explicit figure/axes pair.
fig, ax = plt.subplots(figsize=(8, 6))

errorbar_style = dict(
    fmt='o-',
    label='llama',
    color='#1f77b4',
    capsize=3,
    markersize=4,
    linewidth=1,
    elinewidth=1,
)
ax.errorbar(
    range(len(mean_ids)),
    mean_ids,
    yerr=2*std_ids,  # 2 standard deviations like in paper
    **errorbar_style,
)

ax.set_title("Average ID over Layers")
ax.set_xlabel("layer")
ax.set_ylabel("intrinsic dimension")

ax.grid(True, alpha=0.3)
ax.set_ylim(5, 45)  # y-range chosen to resemble the reference plot
ax.legend()

fig.tight_layout()
plt.show()