In [2]:
from torch.utils.data import DataLoader, Dataset
from typing import List
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import gc

In [13]:
from einops import einsum

In [17]:
# Load the probe weight matrix and cast it to float32 for the PCA fit below.
probes_weights = torch.load(
    '/workspace/llm-progress-monitor/qwen3_4b_weight_tensor.pt'
).float()

In [81]:
# Fit a PCA to the probe weights and keep ALL fitted components as the projection
# basis used by `process_activations` below.
from sklearn.decomposition import PCA
import numpy as np
import matplotlib.pyplot as plt

N_PCA_COMPONENTS = 9
# random_state only matters if sklearn's 'auto' solver selects randomized SVD;
# pinning it keeps the fitted components (and every downstream plot) reproducible.
pca = PCA(n_components=N_PCA_COMPONENTS, random_state=0)
pca.fit(probes_weights.detach().cpu().numpy())
# pca.components_ is (N_PCA_COMPONENTS, n_features). It is cast to bfloat16 on the GPU,
# presumably to match the dtype/device of the saved activations it is einsum'd
# with (TODO confirm).
selected_components = torch.tensor(pca.components_, dtype=torch.float32).to('cuda', dtype=torch.bfloat16)
print(f"Selected PCA components shape: {selected_components.shape}")

def get_ema_coords(coords, alpha=0.99):
    """Exponentially smooth a sequence of coordinate vectors along the first axis.

    The first row seeds the running average unchanged; every later output row is
    ``alpha * previous_output + (1 - alpha) * current_row``, so a larger ``alpha``
    gives heavier smoothing. The arithmetic is done on Python floats.

    Args:
        coords: 2-D array-like with ``.tolist()`` and ``.shape`` (e.g. a torch
            tensor or numpy array) of shape ``(seq_len, n_dims)``.
        alpha: Weight kept on the running average at each step.

    Returns:
        A torch tensor of shape ``(seq_len, n_dims)`` in torch's default float
        dtype. An empty input yields an empty ``(0, n_dims)`` tensor.
    """
    rows = coords.tolist()
    if not rows:
        # torch.tensor([]) would collapse to shape (0,); keep the column count.
        n_dims = coords.shape[1] if len(coords.shape) > 1 else 0
        return torch.empty((0, n_dims))

    running = list(rows[0])
    smoothed = [running]
    for row in rows[1:]:
        running = [alpha * prev + (1 - alpha) * cur for prev, cur in zip(running, row)]
        smoothed.append(running)
    return torch.tensor(smoothed)

def process_activations(activations, components=None, alpha=0.99):
    """Project per-token activations onto PCA components, then EMA-smooth them.

    Args:
        activations: Tensor of shape ``(seq_len, hidden)``; the einsum pattern
            ``'s h, n h -> s n'`` requires exactly this rank.
        components: Tensor of shape ``(n_components, hidden)`` to project onto.
            Defaults to the notebook-global ``selected_components``.
        alpha: Smoothing factor forwarded to ``get_ema_coords``.

    Returns:
        numpy array of shape ``(seq_len, n_components)``.
    """
    if components is None:
        components = selected_components
    # einsum requires matching device and dtype; this is a no-op when they already
    # match, and otherwise avoids a hard failure on e.g. CPU/float32 activations.
    components = components.to(device=activations.device, dtype=activations.dtype)
    seq_coords = einsum(activations, components, 's h, n h -> s n')
    ema_seq_coords = get_ema_coords(seq_coords, alpha=alpha)
    return ema_seq_coords.detach().cpu().numpy()

# Project and smooth the activations of the first N_ROLLOUT_FILES rollout files.
N_ROLLOUT_FILES = 30
# Index selected from each saved activations file; NOTE(review): presumably a
# layer index — confirm against the script that wrote these files.
ACTIVATION_LAYER = 15

all_coords = []
for file_idx in range(N_ROLLOUT_FILES):
    activations = torch.load(f'/workspace/llm-progress-monitor/rollouts/activations/{file_idx}.pt')[ACTIVATION_LAYER]
    np_coords = process_activations(activations)
    all_coords.append(np_coords)
    print(f"Processed file {file_idx}, shape: {np_coords.shape}")
    # Drop the reference before the next load: if the indexed entry is a view, it
    # keeps the whole loaded file's storage alive. Only the small projection is kept.
    del activations
gc.collect()


Selected PCA components shape: torch.Size([9, 2560])
Processed file 0, shape: (733, 9)
Processed file 1, shape: (2536, 9)
Processed file 2, shape: (490, 9)
Processed file 3, shape: (425, 9)
Processed file 4, shape: (674, 9)
Processed file 5, shape: (1193, 9)
Processed file 6, shape: (1599, 9)
Processed file 7, shape: (466, 9)
Processed file 8, shape: (1476, 9)
Processed file 9, shape: (1277, 9)
Processed file 10, shape: (383, 9)
Processed file 11, shape: (981, 9)
Processed file 12, shape: (1563, 9)
Processed file 13, shape: (1751, 9)
Processed file 14, shape: (853, 9)
Processed file 15, shape: (938, 9)
Processed file 16, shape: (1353, 9)
Processed file 17, shape: (642, 9)
Processed file 18, shape: (1453, 9)
Processed file 19, shape: (1380, 9)
Processed file 20, shape: (496, 9)
Processed file 21, shape: (797, 9)
Processed file 22, shape: (979, 9)
Processed file 23, shape: (1330, 9)
Processed file 24, shape: (1405, 9)
Processed file 25, shape: (1252, 9)
Processed file 26, shape: (707, 9)

In [71]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np

# Plot the first three PCA coordinates of each sequence as a 3-D trajectory,
# coloured by token position, in a grid of 3-D subplots.
N_ROWS, N_COLS = 2, 5
# Guard against IndexError when fewer sequences than grid cells were loaded.
n_plots = min(N_ROWS * N_COLS, len(all_coords))

fig = make_subplots(
    rows=N_ROWS, cols=N_COLS,
    specs=[[{'type': 'scatter3d'} for _ in range(N_COLS)] for _ in range(N_ROWS)],
    subplot_titles=[f'Sequence {i}' for i in range(n_plots)]
)

for seq_idx in range(n_plots):
    coords = all_coords[seq_idx]
    row, col = divmod(seq_idx, N_COLS)
    is_first = seq_idx == 0  # a single shared colorbar, attached to the first subplot

    fig.add_trace(
        go.Scatter3d(
            x=coords[:, 0],
            y=coords[:, 1],
            z=coords[:, 2],
            mode='markers+lines',
            marker=dict(
                size=3,
                color=np.arange(len(coords)),  # token position
                colorscale='viridis',
                showscale=is_first,
                colorbar=dict(title="Token Position", x=1.02, len=0.4) if is_first else None,
            ),
            line=dict(color='rgba(0,0,0,0.3)', width=2),
            showlegend=False,
        ),
        row=row + 1, col=col + 1,
    )

fig.update_layout(
    title=f'3D Trajectories (PC1-PC3) for the First {n_plots} Sequences',
    width=1600,
    height=800
)
# With no row/col given, update_scenes applies to every 3-D scene in the figure.
fig.update_scenes(xaxis_title='PC1', yaxis_title='PC2', zaxis_title='PC3')

fig.show()

In [84]:
# Plot how each PCA component varies with token position, averaged across all sequences.
# NOTE: sequences differ in length, so later positions average over fewer sequences
# (at the very end only the longest sequence contributes).
import numpy as np

n_components = all_coords[0].shape[1]
max_seq_len = max(len(coords) for coords in all_coords)

# Per-position sums and counts, accumulated with slice addition (no per-element loop).
component_sums = np.zeros((max_seq_len, n_components))
position_counts = np.zeros(max_seq_len)
for coords in all_coords:
    seq_len = len(coords)
    component_sums[:seq_len] += coords
    position_counts[:seq_len] += 1

# Averages only for positions that have data.
component_averages = np.zeros((max_seq_len, n_components))
valid_positions = position_counts > 0
component_averages[valid_positions] = component_sums[valid_positions] / position_counts[valid_positions, np.newaxis]
valid_pos_indices = np.where(valid_positions)[0]

fig = go.Figure()

# Plot every component (the previous range(8) silently dropped the last one).
for i in range(n_components):
    fig.add_trace(
        go.Scatter(
            x=valid_pos_indices,
            y=component_averages[valid_positions, i],
            mode='lines+markers',
            name=f'PC{i+1}',
            line=dict(width=2),
            marker=dict(size=4)
        )
    )

fig.update_layout(
    title='All PCA Components vs Token Position (Averaged Across All Sequences)',
    xaxis_title='Token Position',
    yaxis_title='Average PCA Component Value',
    width=1000,
    height=600,
    showlegend=True
)

fig.show()
