# Projections across different datasets

In [None]:
import torch
import os
import json
import sys
from collections import defaultdict
import pandas as pd
import numpy as np

sys.path.append('.')
sys.path.append('..')

from utils.pca_utils import L2MeanScaler, MeanScaler
from utils.inference_utils import *
from utils.probing_utils import *
from utils.steering_utils import *

torch.set_float32_matmul_precision('high')

In [None]:
# Run configuration shared by the cells below.
model_name = "llama-3.3-70b"
layer = 40  # layer used for the single-layer plots and tables further down
total_layers = 80
base_dir = f"/workspace/{model_name}"

# Path of the contrast-vector file; it is loaded with torch.load in the next cell.
output_file = f"{base_dir}/capped/configs/multi_contrast_vectors.pt"

# Make sure the directory holding that file exists.
os.makedirs(os.path.dirname(output_file), exist_ok=True)

# Directory that receives the HTML plots written by later cells.
plot_dir = f"/root/git/plots/{model_name}/capped/projections"
os.makedirs(plot_dir, exist_ok=True)

In [None]:
# Load the default activations and the contrast vectors, then project the
# default activations onto the unit-normalized contrast vectors.
# NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted files.
default_vectors = torch.load(f"{base_dir}/roles_240/default_vectors.pt", weights_only=False)
contrast_vectors = torch.load(output_file, weights_only=False)

# Default activations as float32.
# presumably shape [n_layers, hidden_dim] with one row per layer — TODO confirm
# (an earlier comment said [64, 5120], which does not match total_layers = 80).
assistant_layer_activation = default_vectors['activations']['default_1'].float()

# Stack every entry's 'vector' into one 2-D tensor (one row per contrast vector).
contrast_vecs_stacked = torch.stack([cv['vector'].float() for cv in contrast_vectors])

# L2 normalize all contrast vectors at once (row-wise; an all-zero row would yield NaN/inf).
contrast_vecs_normalized = contrast_vecs_stacked / contrast_vecs_stacked.norm(p=2, dim=1, keepdim=True)

# Vectorized projection: element-wise multiply and sum along dim=1, i.e. one dot
# product per row. This relies on row i of the activations pairing with contrast
# vector i — assumes contrast_vectors is ordered by layer; verify against the file.
default_projections = (assistant_layer_activation * contrast_vecs_normalized).sum(dim=1)

## Role/trait rollout projections

In [None]:
# Paths of the role and trait projection files (JSON Lines, one record per line).
role_file = f"{base_dir}/capped/projections/roles_projections.jsonl"
trait_file = f"{base_dir}/capped/projections/traits_projections.jsonl"

# Load each file directly into pandas
df_role = pd.read_json(role_file, lines=True)
df_trait = pd.read_json(trait_file, lines=True)

# Add source_file column to track origin; determine_type() falls back to it below.
df_role['source_file'] = 'role'
df_trait['source_file'] = 'trait'

# Concatenate the dataframes; ignore_index gives a fresh 0..n-1 index, which the
# later column-wise concat with the expanded projections relies on.
df_rt = pd.concat([df_role, df_trait], ignore_index=True)

# Add type column based on source file, name format, and prompt_label
def determine_type(row):
    """
    Classify one record as 'default', 'role' or 'trait'.

    A record is 'default' when its 'prompt_label' equals 'default', or when
    its 'role' is a string of the exact form '<digits>_default'. Every other
    record keeps the value of its 'source_file' field.

    Parameters:
    -----------
    row : mapping-like (e.g. a pd.Series row)
        Must provide .get() and item access for 'role' and 'source_file'

    Returns:
    --------
    str
        'default' or the row's 'source_file' value
    """
    if row.get('prompt_label') == 'default':
        return 'default'

    role_name = row['role']
    if isinstance(role_name, str):
        # '<digits>_default': exactly one underscore, numeric prefix.
        index_part, separator, remainder = role_name.partition('_')
        if separator and remainder == 'default' and index_part.isdigit():
            return 'default'

    return row['source_file']

# Classify every row (row-wise apply of determine_type).
df_rt['type'] = df_rt.apply(determine_type, axis=1)

# Remove the temporary source_file column
df_rt = df_rt.drop('source_file', axis=1)

# Display the type distribution
print("Type distribution:")
print(df_rt['type'].value_counts())


In [None]:
# Expand projections into separate columns: each key of the per-row
# 'projections' mapping becomes its own column.
projections_df = pd.json_normalize(df_rt['projections'])
# Column-wise concat aligns on the index; both frames carry a default 0..n-1 index here.
df_rt = pd.concat([df_rt.drop('projections', axis=1), projections_df], axis=1)

# Quick sanity report on the expanded frame.
print(f"Expanded shape: {df_rt.shape}")
print(f"\nSample projection columns: {[col for col in df_rt.columns if col.startswith('layer_')][:5]}")
print(f"\nDataFrame info:")
print(f"  - Types: {df_rt['type'].value_counts().to_dict()}")
print(f"  - Score distribution: {df_rt['score'].value_counts(dropna=False).to_dict()}")

In [None]:
def bin_scores(df):
    """
    Add a score_bin column to the dataframe that bins scores consistently:
    - Bin -1: Role or Trait rows whose score is missing (null)
    - Bin 0: Role score 0, Trait 0-25, Default (all)
    - Bin 1: Role score 1, Trait 25-50
    - Bin 2: Role score 2, Trait 50-75
    - Bin 3: Role score 3, Trait 75-100

    The 'score' column of the returned frame is normalized to int:
    'REFUSAL' becomes 0 and null becomes -1. The input frame is not modified.

    Parameters:
    -----------
    df : pd.DataFrame
        DataFrame containing 'type' and 'score' columns

    Returns:
    --------
    pd.DataFrame
        Copy of df with normalized 'score' and added 'score_bin' column
    """
    df = df.copy()

    # Handle REFUSAL and null scores; remember which rows were null so that
    # the -1 sentinel is not mistaken for a low trait score below.
    cleaned = df['score'].replace('REFUSAL', 0)
    missing = cleaned.isna().to_numpy()
    df['score'] = cleaned.fillna(-1).astype(int)

    scores = df['score'].to_numpy()
    is_role = (df['type'] == 'role').to_numpy()
    is_trait = (df['type'] == 'trait').to_numpy()

    # Trait scores: <25 -> 0, 25-49 -> 1, 50-74 -> 2, >=75 -> 3.
    # Missing trait scores stay at -1, matching how missing role scores behave.
    trait_bins = np.where(missing, -1, np.digitize(scores, [25, 50, 75]))

    # Role scores are already 0-3 (or -1 when missing); any other type is
    # treated as default and goes to bin 0.
    df['score_bin'] = np.select([is_role, is_trait], [scores, trait_bins], default=0)

    return df

In [None]:
# Bin the scores (returns a copy; df_rt itself is left unchanged)
df_rt_binned = bin_scores(df_rt)

# Show how many rows landed in each bin, ordered by bin value.
print(f"Score bin distribution:")
print(df_rt_binned['score_bin'].value_counts().sort_index())

In [None]:
def compute_projection_stats(df, percentiles=(1, 5, 10, 25, 50, 75, 90, 95, 99)):
    """
    Compute statistics for all projection columns in the role/trait dataframe.

    Parameters:
    -----------
    df : pd.DataFrame
        DataFrame containing projection columns (columns starting with 'layer_')
    percentiles : sequence of numbers, optional
        Percentile values to compute (default: 1, 5, 10, 25, 50, 75, 90, 95, 99)

    Returns:
    --------
    dict
        {'per_vector_stats': {column_name: stats}}, where each stats dict has
        'count' (int), 'mean', 'std', 'min', 'max' (float) and 'percentiles',
        a dict mapping str(p) to the p-th percentile (float). Columns that
        contain only null values are omitted.

    Raises:
    -------
    ValueError
        If df has no column whose name starts with 'layer_'
    """
    # Get all projection columns (columns starting with 'layer_')
    projection_cols = [col for col in df.columns if col.startswith('layer_')]

    if not projection_cols:
        raise ValueError("No projection columns found (columns starting with 'layer_')")

    # Materialize once so any iterable works and the keys line up with the values.
    percentile_points = list(percentiles)

    per_vector_stats = {}

    for col in projection_cols:
        # Statistics are computed over non-null values only.
        values = df[col].dropna()

        if len(values) == 0:
            continue

        # One call computes every requested percentile for this column.
        percentile_values = np.percentile(values.to_numpy(), percentile_points)

        per_vector_stats[col] = {
            'count': int(len(values)),
            'mean': float(values.mean()),
            'std': float(values.std()),  # pandas default: sample std (ddof=1)
            'min': float(values.min()),
            'max': float(values.max()),
            'percentiles': {
                str(p): float(v) for p, v in zip(percentile_points, percentile_values)
            },
        }

    return {'per_vector_stats': per_vector_stats}

In [None]:
# Compute statistics for role/trait projections (one entry per 'layer_*' column)
rt_stats = compute_projection_stats(df_rt_binned)


In [None]:

# Display sample stats for one layer (the `layer` set in the config cell)
vector_name = f'layer_{layer}/contrast_role_pos3_default1'
sample_stats = rt_stats['per_vector_stats'][vector_name]

print(f"Statistics for {vector_name}:")
print(f"  Count: {sample_stats['count']:,}")
print(f"  Mean: {sample_stats['mean']:.6f}")
print(f"  Std: {sample_stats['std']:.6f}")
print(f"  Min: {sample_stats['min']:.6f}")
print(f"  Max: {sample_stats['max']:.6f}")
print(f"  Percentiles:")
# Keys are the percentile numbers as strings, values the projection at that percentile.
for p, val in sample_stats['percentiles'].items():
    print(f"    p{p}: {val:.6f}")

In [None]:
import plotly.graph_objects as go

def plot_score_histogram(df, vector_name, title, subtitle, stats_data, nbinsx=50, default_projection=None):
    """
    Build an overlaid histogram of one projection column, one trace per score bin.

    Parameters:
    -----------
    df : pd.DataFrame
        DataFrame containing the projection column and a 'score_bin' column
    vector_name : str
        Name of the projection column to plot; also the key looked up in
        stats_data['per_vector_stats']
    title : str
        Title for the plot
    subtitle : str
        Subtitle for the plot
    stats_data : dict
        Dictionary with a 'per_vector_stats' entry holding 'count', 'mean',
        'std' and 'percentiles' for vector_name
    nbinsx : int, optional
        Number of bins for each histogram trace (default: 50)
    default_projection : float, optional
        Default projection value to show as a solid vertical line

    Returns:
    --------
    plotly.graph_objects.Figure

    Raises:
    -------
    ValueError
        If vector_name or 'score_bin' is not a column of df
    """
    if vector_name not in df.columns:
        raise ValueError(f"Projection column {vector_name} not found in dataframe")

    if 'score_bin' not in df.columns:
        raise ValueError("DataFrame must have 'score_bin' column. Call bin_scores() first.")

    # Legend label and bar color for each known score bin.
    bin_styles = {
        0: ('None (0-25)', '#e74c3c'),       # red
        1: ('Slightly (25-50)', '#3498db'),  # blue
        2: ('Somewhat (50-75)', '#f39c12'),  # orange
        3: ('Fully (75-100)', '#2ecc71'),    # green
    }
    fallback_color = '#95a5a6'

    stats = stats_data['per_vector_stats'][vector_name]

    fig = go.Figure()

    # One histogram trace per score bin, in ascending bin order; bin -1 is skipped.
    for bin_val in sorted(df['score_bin'].unique()):
        if bin_val == -1:
            continue

        bin_values = df.loc[df['score_bin'] == bin_val, vector_name]
        if len(bin_values) == 0:
            continue

        label, color = bin_styles.get(bin_val, (str(bin_val), fallback_color))
        trace = go.Histogram(
            x=bin_values,
            name=label,
            marker_color=color,
            opacity=0.6,
            nbinsx=nbinsx,
        )
        fig.add_trace(trace)

    # Dashed red line per pre-computed percentile, labelled at the top of the
    # plot with the text placed to the left of the line.
    for pct, value in stats['percentiles'].items():
        percentile_label = dict(
            text=f"p{pct}:<br>{value:.1f}",
            align="right",
            font=dict(color="red", size=9),
            xanchor="right",
            yanchor="top",
            y=1,
            yref="paper",
            showarrow=False,
            opacity=0.5,
        )
        fig.add_vline(
            x=value,
            line_width=1,
            line_dash="dash",
            line_color="red",
            opacity=0.5,
            annotation=percentile_label,
        )

    # Optional solid black line marking the default projection.
    if default_projection is not None:
        default_label = dict(
            text=f"Default:<br>{default_projection:.1f}",
            align="left",
            font=dict(color="black", size=10),
            xanchor="right",
            yanchor="top",
            y=0.9,
            yref="paper",
            showarrow=False,
            opacity=0.8,
        )
        fig.add_vline(
            x=default_projection,
            line_width=2,
            line_dash="solid",
            line_color="black",
            opacity=0.8,
            annotation=default_label,
        )

    # Overlay the per-bin bars and put a horizontal legend above the plot.
    legend_spec = dict(
        title="Expression",
        orientation="h",
        yanchor="bottom",
        y=1.01,
        xanchor="right",
        x=1.01,
    )
    fig.update_layout(
        title=dict(text=title, subtitle=dict(text=subtitle)),
        xaxis_title="Projection",
        yaxis_title="Count",
        width=1000,
        height=600,
        barmode='overlay',
        bargap=0,
        showlegend=True,
        legend=legend_spec,
    )

    # Summary box (n, mean, std) in the top-right corner.
    fig.add_annotation(
        text=f"n={stats['count']:,}<br>μ={stats['mean']:.4f}<br>σ={stats['std']:.4f}",
        xref="paper",
        yref="paper",
        x=0.98,
        y=0.98,
        xanchor="right",
        yanchor="top",
        showarrow=False,
        bgcolor="white",
        bordercolor="black",
        borderwidth=1,
    )

    return fig

In [None]:
# Plot projection histogram for the configured `layer`
fig = plot_score_histogram(
    df_rt_binned, 
    vector_name=f'layer_{layer}/contrast_role_pos3_default1',
    title='Role and Trait Responses Projected on Assistant Contrast Vector',
    subtitle=f'{model_name.replace("-", " ").title()}, Layer {layer}',
    stats_data=rt_stats,
    # assumes index `layer` of default_projections corresponds to this layer's
    # contrast vector — TODO confirm against the ordering of contrast_vectors
    default_projection=default_projections[layer].item()
)
fig.show()
# Saving the HTML is currently disabled:
#fig.write_html(f'{plot_dir}/layer{layer}_role_trait.html')

## Chat dataset

In [None]:
# Dataset identifiers: `name` is used in file names, `readable_name` in plot subtitles.
name = "lmsys_10000"
readable_name = "LMSYS-Chat-1M"
chat_file = f"{base_dir}/capped/projections/{name}.json"

# Load the JSON file directly into pandas; later code indexes the result by
# 'per_vector_stats' and 'metadata'.
df_chat = pd.read_json(chat_file)


In [None]:
def plot_chat_histogram(data, vector_name, title, subtitle):
    """
    Plot a histogram from pre-computed chat projection data.

    The stored histogram (bin edges + counts) is turned back into samples
    placed at each bin center so that plotly can draw it with the same bin
    edges as the original.

    Parameters:
    -----------
    data : dict-like
        Object whose 'per_vector_stats' entry maps vector names to stats with
        'histogram' ('bins', 'counts'), 'percentiles', 'count', 'mean', 'std'
    vector_name : str
        Name of the vector to plot (e.g., 'layer_32/contrast_role_pos3_default1')
    title : str
        Title for the plot
    subtitle : str
        Subtitle for the plot

    Returns:
    --------
    plotly.graph_objects.Figure

    Raises:
    -------
    ValueError
        If vector_name is not in data['per_vector_stats'], or if its histogram
        has no counts or fewer than len(counts) + 1 bin edges
    """
    if vector_name not in data['per_vector_stats']:
        raise ValueError(f"Vector {vector_name} not found in data")

    stats = data['per_vector_stats'][vector_name]
    bins = stats['histogram']['bins']
    counts = stats['histogram']['counts']

    # Number of bins of the stored histogram; guard against an empty or
    # malformed histogram, which would otherwise fail with a division by zero
    # or an index error further down.
    nbins = len(counts)
    if nbins == 0 or len(bins) < nbins + 1:
        raise ValueError(
            f"Histogram for {vector_name} needs at least one count and len(counts) + 1 bin edges"
        )

    # Reconstruct sample data: each bin center repeated by its (integer) count.
    edges = np.asarray(bins[:nbins + 1], dtype=float)
    bin_centers = (edges[:-1] + edges[1:]) / 2
    repeats = np.clip(np.asarray(counts).astype(int), 0, None)
    reconstructed_data = np.repeat(bin_centers, repeats)

    # Create figure
    fig = go.Figure()

    fig.add_trace(
        go.Histogram(
            x=reconstructed_data,
            marker_color='#3498db',  # blue
            opacity=0.8,
            name='Chat Dataset',
            nbinsx=nbins,
            # Explicit bin spec spanning the stored edges in nbins equal steps.
            xbins=dict(
                start=bins[0],
                end=bins[-1],
                size=(bins[-1] - bins[0]) / nbins
            )
        )
    )

    # Dashed red line per pre-computed percentile, labelled at the top of the
    # plot with the text placed to the left of the line.
    percentiles = stats['percentiles']
    for pct, value in percentiles.items():
        fig.add_vline(
            x=value,
            line_width=1,
            line_dash="dash",
            line_color="red",
            opacity=0.5,
            annotation=dict(
                text=f"p{pct}:<br>{value:.1f}",
                align="right",
                font=dict(color="red", size=9),
                xanchor="right",   # position text to the left of the line
                yanchor="top",     # text hangs down from y so it stays inside the plot
                y=1,               # normalized coordinate (top of plot)
                yref="paper",      # relative to the full plot height
                showarrow=False,
                opacity=0.5
            )
        )

    # Update layout
    fig.update_layout(
        title={
            'text': title,
            'subtitle': {
                'text': subtitle,
            }
        },
        xaxis_title="Projection",
        yaxis_title="Count",
        width=1000,
        height=600,
        bargap=0,  # No gap between bars
    )

    # Add text annotation with statistics
    stats_text = f"n={stats['count']:,}<br>μ={stats['mean']:.4f}<br>σ={stats['std']:.4f}"
    fig.add_annotation(
        text=stats_text,
        xref="paper", yref="paper",
        x=0.98, y=0.98,
        xanchor="right", yanchor="top",
        showarrow=False,
        bgcolor="white",
        bordercolor="black",
        borderwidth=1
    )

    return fig

In [None]:
# Plot chat histogram for the selected layer (re-set here so this cell can be
# re-run on its own with a different layer)
layer = 40
fig = plot_chat_histogram(
    df_chat,
    vector_name=f'layer_{layer}/contrast_role_pos3_default1',
    title='Chat Dataset Responses Projected on Assistant Contrast Vector',
    subtitle=f'{model_name.replace("-", " ").title()}, Layer {layer} - 10000 Conversations Sampled from {readable_name}'
)
fig.show()
# Persist the interactive plot next to the other projection plots.
fig.write_html(f'{plot_dir}/layer{layer}_{name}.html')

In [None]:
# Same plot as the previous cell, for the WildChat sample.
name = "wildchat_10000"
readable_name = "WildChat"
chat_file = f"{base_dir}/capped/projections/{name}.json"

# Load the JSON file directly into pandas.
# NOTE(review): this rebinds `df_chat`, which until now held the LMSYS data;
# any later code that reads `df_chat` expecting LMSYS sees WildChat unless
# `df_chat` is reloaded first.
df_chat = pd.read_json(chat_file)

fig = plot_chat_histogram(
    df_chat,
    vector_name=f'layer_{layer}/contrast_role_pos3_default1',
    title='Chat Dataset Responses Projected on Assistant Contrast Vector',
    subtitle=f'{model_name.replace("-", " ").title()}, Layer {layer} - 10000 Conversations Sampled from {readable_name}'
)
fig.show()
fig.write_html(f'{plot_dir}/layer{layer}_{name}.html')

## Comparing caps with datasets

In [None]:
# Load existing caps config if it exists
cfg_file = f"{base_dir}/capped/configs/multi_contrast_layers_config.pt"

# Experiments 4-7 are all layers (layers_0:64) with different thresholds:
# - layers_0:64-harm_0.01: Block 99% of harmful responses
# - layers_0:64-harm_0.25: Block 75% of harmful responses  
# - layers_0:64-safe_0.01: Block 99% of safe-default responses
# - layers_0:64-safe_0.50: Block 50% of safe-default responses
# NOTE(review): the "0:64" ids above may be stale for this model
# (total_layers = 80) — confirm against cfg['experiments'].

if os.path.exists(cfg_file):
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted files.
    cfg = torch.load(cfg_file, weights_only=False)
    
    # Extract the 4 all-layer experiments (indices 4-7)
    experiments_all_layers = cfg['experiments'][4:8]
    
    # Rows of the caps dataframe: one dict per vector with its cap per threshold
    caps_data = []
    
    # Get all vector names from the first experiment
    vector_names = [interv['vector'] for interv in experiments_all_layers[0]['interventions']]
    
    for vec_name in vector_names:
        # Extract layer number from vector name (e.g., "layer_32/contrast_role_pos3_default1" -> 32)
        layer_num = int(vec_name.split('/')[0].replace('layer_', ''))
        
        # Extract cap values from each of the 4 experiments
        caps = {}
        for exp in experiments_all_layers:
            # Second '-'-separated field of the id is the threshold label.
            exp_id = exp['id'].split('-')[1]  # e.g., "layers_0:64-harm_0.01" -> "harm_0.01"
            
            # Find this vector's cap in this experiment (first match wins)
            for interv in exp['interventions']:
                if interv['vector'] == vec_name:
                    caps[exp_id] = interv['cap']
                    break
        
        # Thresholds without a matching intervention end up as None.
        caps_data.append({
            'layer': layer_num,
            'vector_name': vec_name,
            'harm_0.01': caps.get('harm_0.01'),
            'harm_0.25': caps.get('harm_0.25'),
            'safe_0.01': caps.get('safe_0.01'),
            'safe_0.50': caps.get('safe_0.50'),
        })
    
    df_caps = pd.DataFrame(caps_data).sort_values('layer')
    
    print(f"Loaded caps config from {cfg_file}")
    print(f"  - {len(cfg['vectors'])} vectors")
    print(f"  - {len(cfg['experiments'])} experiments")
    print(f"  - Extracted caps for {len(df_caps)} vectors across 4 threshold levels")
else:
    # Later cells treat df_caps = None as "no caps loaded".
    print(f"Config file not found: {cfg_file}")
    df_caps = None

In [None]:
def print_caps_for_layer(layer_num, caps_df=df_caps):
    """
    Print a small table with the four cap values stored for one layer.

    Parameters:
    -----------
    layer_num : int
        The layer number to display caps for
    caps_df : pd.DataFrame, optional
        Caps dataframe with 'layer', 'vector_name' and one column per
        threshold. Defaults to the value df_caps had when this function
        was defined.
    """
    if caps_df is None:
        print("No caps data loaded. Run the config loading cell first.")
        return

    matches = caps_df[caps_df['layer'] == layer_num]
    if matches.shape[0] == 0:
        print(f"No cap data found for layer {layer_num}")
        return

    # Only the first matching row is reported.
    row = matches.iloc[0]

    heavy_rule = '=' * 70
    light_rule = '-' * 70

    # Threshold column -> description, in display order.
    table_rows = (
        ('safe_0.01', 'Block 99% of safe-default'),
        ('safe_0.50', 'Block 50% of safe-default'),
        ('harm_0.01', 'Block 99% of harmful'),
        ('harm_0.25', 'Block 75% of harmful'),
    )

    print(heavy_rule)
    print(f"Cap Values for Layer {layer_num}")
    print(f"Vector: {row['vector_name']}")
    print(heavy_rule)
    print()
    print(f"{'Threshold':<20} {'Cap Value':>15} {'Description':<30}")
    print(light_rule)
    for threshold_key, description in table_rows:
        print(f"{threshold_key:<20} {row[threshold_key]:>15.4f} {description:<30}")
    print(heavy_rule)

# Display caps for the configured `layer` (only when a caps config was loaded)
if df_caps is not None:
    print_caps_for_layer(layer)

In [None]:
# Load both chat datasets explicitly.
# The WildChat plotting cell above rebinds `df_chat` to the WildChat file, so
# without this reload the "LMSYS-Chat-1M" line below and every later
# `lmsys_data=df_chat` default would silently use WildChat data.
lmsys_file = f"{base_dir}/capped/projections/lmsys_10000.json"
df_chat = pd.read_json(lmsys_file)

wildchat_file = f"{base_dir}/capped/projections/wildchat_10000.json"
df_wildchat = pd.read_json(wildchat_file)

print(f"Loaded chat datasets:")
print(f"  - LMSYS-Chat-1M: {df_chat['metadata']['n_assistant_turns']:,} assistant turns")
print(f"  - WildChat: {df_wildchat['metadata']['n_assistant_turns']:,} assistant turns")

In [None]:
def compute_percentile_from_data(value, data):
    """
    Compute the percentile rank of a value in a dataset.

    The rank is the percentage of entries strictly below `value`.

    Parameters:
    -----------
    value : float
        The value to find the percentile for
    data : array-like
        The dataset (list, tuple, numpy array, pandas Series, ...)

    Returns:
    --------
    float
        Percentile rank (0-100), or NaN when the dataset is empty
    """
    # Convert up front so plain lists/tuples support the element-wise comparison.
    values = np.asarray(data)

    # An empty dataset has no defined rank; avoid dividing by zero.
    if values.size == 0:
        return float('nan')

    return float(np.count_nonzero(values < value) / values.size * 100)


def interpolate_percentile_from_stats(value, percentiles_dict):
    """
    Estimate the percentile rank of a value from a table of known percentiles.

    Values at or beyond the outermost known percentiles are clamped to the
    lowest/highest percentile in the table; values in between are linearly
    interpolated between the two surrounding entries.

    Parameters:
    -----------
    value : float
        The value to find the percentile for
    percentiles_dict : dict
        Dictionary mapping integer percentile strings (e.g., '1', '5', '10')
        to the data value at that percentile

    Returns:
    --------
    float
        Estimated percentile rank; 50.0 if no surrounding interval is found
    """
    ranks = sorted(int(key) for key in percentiles_dict.keys())
    cutoffs = [percentiles_dict[str(rank)] for rank in ranks]

    # Clamp outside the tabulated range.
    if value <= cutoffs[0]:
        return ranks[0]
    if value >= cutoffs[-1]:
        return ranks[-1]

    # Walk neighbouring (rank, cutoff) pairs and interpolate in the first
    # interval that contains the value.
    lower_points = zip(ranks, cutoffs)
    upper_points = zip(ranks[1:], cutoffs[1:])
    for (lo_rank, lo_val), (hi_rank, hi_val) in zip(lower_points, upper_points):
        if not (lo_val <= value <= hi_val):
            continue
        fraction = (value - lo_val) / (hi_val - lo_val)
        return lo_rank + fraction * (hi_rank - lo_rank)

    return 50.0  # Fallback


def print_cap_percentiles_for_layer(layer_num, caps_df=df_caps, rt_df=df_rt_binned, 
                                     lmsys_data=df_chat, wildchat_data=df_wildchat):
    """
    Print, for one layer, where each cap value falls in three datasets.

    For the role/trait data the rank is computed from the raw projections;
    for the two chat datasets it is interpolated from their stored
    percentile tables.

    Parameters:
    -----------
    layer_num : int
        The layer number to analyze
    caps_df : pd.DataFrame
        DataFrame with 'layer', 'vector_name' and one column per threshold
    rt_df : pd.DataFrame
        Role/trait dataframe with projection columns
    lmsys_data : dict-like
        LMSYS chat dataset with per_vector_stats
    wildchat_data : dict-like
        WildChat dataset with per_vector_stats

    All defaults are the values the corresponding globals had when this
    function was defined.
    """
    if caps_df is None:
        print("No caps data loaded. Run the config loading cell first.")
        return

    matches = caps_df[caps_df['layer'] == layer_num]
    if matches.shape[0] == 0:
        print(f"No cap data found for layer {layer_num}")
        return

    # Only the first matching row is reported.
    row = matches.iloc[0]
    vector_name = row['vector_name']

    heavy_rule = '=' * 90
    light_rule = '-' * 90

    print(heavy_rule)
    print(f"Cap Percentiles for Layer {layer_num}")
    print(f"Vector: {vector_name}")
    print(heavy_rule)
    print()

    # Table header
    print(f"{'Threshold':<15} {'Cap Value':>12} {'Role/Trait':>15} {'LMSYS-Chat':>15} {'WildChat':>15}")
    print(light_rule)

    # Raw role/trait projections and the chat datasets' stats for this vector.
    rt_values = rt_df[vector_name].dropna().values
    lmsys_stats = lmsys_data['per_vector_stats'][vector_name]
    wildchat_stats = wildchat_data['per_vector_stats'][vector_name]
    lmsys_percentiles = lmsys_stats['percentiles']
    wildchat_percentiles = wildchat_stats['percentiles']

    # One table line per cap threshold, in display order.
    for threshold_name in ('safe_0.01', 'safe_0.50', 'harm_0.01', 'harm_0.25'):
        cap_value = row[threshold_name]

        rt_pct = compute_percentile_from_data(cap_value, rt_values)
        lmsys_pct = interpolate_percentile_from_stats(cap_value, lmsys_percentiles)
        wildchat_pct = interpolate_percentile_from_stats(cap_value, wildchat_percentiles)

        print(f"{threshold_name:<15} {cap_value:>12.4f} {rt_pct:>14.1f}% {lmsys_pct:>14.1f}% {wildchat_pct:>14.1f}%")

    print(heavy_rule)
    print()
    print("Interpretation:")
    print("  - Percentile shows what % of responses fall BELOW the cap threshold")
    print("  - Higher percentile = more responses blocked by the cap")
    print(f"  - Role/Trait dataset: {len(rt_values):,} samples")
    print(f"  - LMSYS-Chat dataset: {lmsys_data['per_vector_stats'][vector_name]['count']:,} samples")
    print(f"  - WildChat dataset: {wildchat_data['per_vector_stats'][vector_name]['count']:,} samples")


In [None]:
# Compute percentiles for all layers across all datasets and thresholds
def compute_all_layer_percentiles(caps_df=df_caps, rt_df=df_rt_binned, 
                                   lmsys_data=df_chat, wildchat_data=df_wildchat):
    """
    Compute percentiles for all layers, datasets, and cap thresholds.

    For every row of caps_df and every threshold column, the cap value is
    ranked in the role/trait projections (from raw data) and in both chat
    datasets (interpolated from their stored percentile tables).

    Parameters:
    -----------
    caps_df : pd.DataFrame or None
        DataFrame with 'layer', 'vector_name' and one column per threshold;
        None means no caps config was loaded
    rt_df : pd.DataFrame
        Role/trait dataframe with projection columns
    lmsys_data : dict-like
        LMSYS chat dataset with per_vector_stats
    wildchat_data : dict-like
        WildChat dataset with per_vector_stats

    All defaults are the values the corresponding globals had when this
    function was defined.

    Returns:
    --------
    pd.DataFrame
        DataFrame with columns: layer, dataset, threshold, percentile
        (empty, but with those columns, when there are no caps)
    """
    columns = ['layer', 'dataset', 'threshold', 'percentile']

    # Consistent with the print helpers: a missing caps config is reported,
    # not treated as a crash.
    if caps_df is None:
        print("No caps data loaded. Run the config loading cell first.")
        return pd.DataFrame(columns=columns)

    results = []

    thresholds = ['safe_0.01', 'safe_0.50', 'harm_0.01', 'harm_0.25']

    for _, row in caps_df.iterrows():
        layer_num = row['layer']
        vector_name = row['vector_name']

        # Get data for this layer
        rt_values = rt_df[vector_name].dropna().values
        lmsys_percentiles = lmsys_data['per_vector_stats'][vector_name]['percentiles']
        wildchat_percentiles = wildchat_data['per_vector_stats'][vector_name]['percentiles']

        for threshold in thresholds:
            cap_value = row[threshold]

            # Rank of the cap in each dataset, in output order.
            per_dataset = (
                ('Role/Trait', compute_percentile_from_data(cap_value, rt_values)),
                ('LMSYS-Chat', interpolate_percentile_from_stats(cap_value, lmsys_percentiles)),
                ('WildChat', interpolate_percentile_from_stats(cap_value, wildchat_percentiles)),
            )

            results.extend(
                {
                    'layer': layer_num,
                    'dataset': dataset_name,
                    'threshold': threshold,
                    'percentile': pct,
                }
                for dataset_name, pct in per_dataset
            )

    # Passing the columns explicitly keeps them present even with zero rows.
    return pd.DataFrame(results, columns=columns)

# Compute all percentiles using the function's default (globally loaded) datasets
print("Computing percentiles for all layers...")
df_percentiles = compute_all_layer_percentiles()
print(f"Computed {len(df_percentiles)} data points")

In [None]:
# Create line plot showing percentiles across all layers:
# one line per (dataset, threshold) pair, x = layer, y = percentile rank of the cap.
fig = go.Figure()

# Define colors for datasets (line and marker color)
dataset_colors = {
    'Role/Trait': '#3498db',    # blue
    'LMSYS-Chat': '#2ecc71',    # green
    'WildChat': '#e67e22'       # orange
}

# Define markers for thresholds (marker symbol distinguishes lines of one dataset)
threshold_markers = {
    'safe_0.01': 'circle',
    'safe_0.50': 'square',
    'harm_0.01': 'diamond',
    'harm_0.25': 'triangle-up'
}

# Define threshold labels (human-readable legend text)
threshold_labels = {
    'safe_0.01': 'Block 99% safe',
    'safe_0.50': 'Block 50% safe',
    'harm_0.01': 'Block 99% harm',
    'harm_0.25': 'Block 75% harm'
}

# Add a trace for each dataset-threshold combination
for dataset in ['Role/Trait', 'LMSYS-Chat', 'WildChat']:
    for threshold in ['safe_0.01', 'safe_0.50', 'harm_0.01', 'harm_0.25']:
        # Filter data for this combination, ordered by layer for a clean line
        mask = (df_percentiles['dataset'] == dataset) & (df_percentiles['threshold'] == threshold)
        subset = df_percentiles[mask].sort_values('layer')
        
        # Create trace name
        trace_name = f"{dataset} - {threshold_labels[threshold]}"
        
        fig.add_trace(
            go.Scatter(
                x=subset['layer'],
                y=subset['percentile'],
                mode='lines+markers',
                name=trace_name,
                line=dict(color=dataset_colors[dataset], width=2),
                marker=dict(
                    symbol=threshold_markers[threshold],
                    size=6,
                    color=dataset_colors[dataset],
                    line=dict(width=1, color='white')
                ),
                # Traces of one dataset share a legend group
                legendgroup=dataset,
                showlegend=True
            )
        )

# Update layout
fig.update_layout(
    title={
        'text': 'Percentage of Responses Blocked by Caps',
        'subtitle': {
            'text': f'{model_name.replace("-", " ").title()} - Caps from Jailbreak Rollouts',
        }
    },
    xaxis_title="Layer",
    yaxis_title="Percentile (%)",
    width=1200,
    height=600,
    hovermode='closest',
    # Vertical legend placed just outside the right edge of the plot
    legend=dict(
        title="Dataset - Threshold",
        orientation="v",
        yanchor="top",
        y=1.0,
        xanchor="left",
        x=1.01,
        font=dict(size=10)
    ),
    # A tick every 4 layers, starting at layer 0
    xaxis=dict(
        tickmode='linear',
        tick0=0,
        dtick=4
    ),
    # Slight padding around 0-100 so markers at the extremes are not clipped
    yaxis=dict(
        range=[-2, 102],
        ticksuffix='%'
    )
)

fig.show()
# Persist the interactive plot and report where it went.
fig.write_html(f'{plot_dir}/cap_percentile_comparison.html')
print(f"Saved plot to {plot_dir}/cap_percentile_comparison.html")