In [1]:
import os
import json
import glob
import torch
import pickle
import re
import plotly.graph_objects as go
import einops
import pandas as pd
from functools import partial
from torch import Tensor
from torchtyping import TensorType as TT
from utils.visualization import imshow_p

In [85]:
MODEL_TO_VIEW = "pythia-160m"

In [86]:
filename = f"backup_research/FINAL_figure_files/sr_over_time/data/{MODEL_TO_VIEW}/sample_results_dict.pt"
data = torch.load(filename)

In [87]:

# sort the dictionary by the key
data = dict(sorted(data.items(), key=lambda item: int(item[0])))
data.keys()

dict_keys([256, 512, 1000, 2000, 3000, 5000, 10000, 30000, 60000, 90000, 143000])

### Single Checkpoint

In [88]:
CHECKPOINT = 143000

In [89]:
repair_score = 1 - (data[CHECKPOINT]['thresholded_cil'][0] / -data[CHECKPOINT]['thresholded_de'][0])
#weighted_repair_score = (data[CHECKPOINT]['thresholded_de'][0] / -data[CHECKPOINT]['thresholded_cil'][0]) * data[CHECKPOINT]['thresholded_de'][0]

In [90]:
layer = 11
for head in range(12):
    print((data[143000]['thresholded_de'][0][layer, head], data[143000]['thresholded_cil'][0][layer, head], repair_score[layer, head]))

(tensor(0.3570), tensor(-0.0820), tensor(0.7702))
(tensor(0.1418), tensor(-0.0197), tensor(0.8609))
(tensor(0.1189), tensor(-0.0263), tensor(0.7787))
(tensor(0.0934), tensor(-0.0224), tensor(0.7601))
(tensor(0.0256), tensor(-0.0113), tensor(0.5582))
(tensor(0.2075), tensor(-0.0144), tensor(0.9304))
(tensor(0.1735), tensor(-0.0452), tensor(0.7394))
(tensor(0.0588), tensor(-0.0193), tensor(0.6724))
(tensor(0.0525), tensor(-0.0232), tensor(0.5584))
(tensor(0.1587), tensor(-0.0405), tensor(0.7451))
(tensor(0.0821), tensor(-0.0182), tensor(0.7782))
(tensor(0.1963), tensor(-0.0666), tensor(0.6608))


In [91]:
from typing import Union
from jaxtyping import Float
import numpy as np
import itertools

def create_layered_scatter(
    heads_x: Float[Tensor, "layer head"],
    heads_y: Float[Tensor, "layer head"], 
    x_title: str, 
    y_title: str, 
    plot_title: str,
    mlp_x: Union[Float[Tensor, "layer"], None] = None,
    mlp_y: Union[Float[Tensor, "layer"], None] = None,
    x_range: Union[list, None] = None,  # New parameter for x-range
    y_range: Union[list, None] = None   # New parameter for y-range
):
    """
    This function now also accepts x_data and y_data for MLP layers and manual x- and y-ranges. 
    It plots properties of transformer heads and MLP layers with layered coloring and annotations.
    Additionally, it plots a dotted line where the negative value of the y-axis equals the positive value of the x-axis.
    """
    num_layers = 12
    num_heads = 12
    layer_colors = np.linspace(0, num_layers, num_layers, endpoint=False)
    
    # Annotations and colors for transformer heads
    head_annotations = [f"Layer {layer}, Head {head}" for layer, head in itertools.product(range(num_layers), range(num_heads))]
    head_marker_colors = [layer_colors[layer] for layer in range(num_layers) for _ in range(num_heads)]

    # Prepare MLP data if provided
    mlp_annotations = []
    mlp_marker_colors = []
    if mlp_x is not None and mlp_y is not None:
        mlp_annotations = [f"MLP Layer {layer}" for layer in range(num_layers)]
        mlp_marker_colors = [layer_colors[layer] for layer in range(num_layers)]

    # Flatten data
    heads_x = heads_x.flatten().cpu().numpy() if heads_x.ndim > 1 else heads_x.cpu().numpy()
    heads_y = heads_y.flatten().cpu().numpy() if heads_y.ndim > 1 else heads_y.cpu().numpy()
    if mlp_x is not None and mlp_y is not None:
        mlp_x = mlp_x.flatten().cpu().numpy() if mlp_x.ndim > 1 else mlp_x.cpu().numpy()
        mlp_y = mlp_y.flatten().cpu().numpy() if mlp_y.ndim > 1 else mlp_y.cpu().numpy()

    # Create scatter plots
    scatter_heads = go.Scatter(
        x=heads_x,
        y=heads_y,
        text=head_annotations,
        mode='markers',
        marker=dict(
            size=8,
            opacity=0.8,
            color=head_marker_colors,
            colorscale='Viridis',
            colorbar=dict(
                title='Layer',
                #tickvals=[0, num_layers - 1],
                #ticktext=[0, 1,2,1,1,1,1,1,1,1,1,1,1,11,1,1,3,4,5,5,num_layers - 1],
                orientation="h"
            ),
            line=dict(width=0.5, color='DarkSlateGrey')
        ),
        name="Attention Heads"
    )

    scatter_mlp = go.Scatter(
        x=mlp_x,
        y=mlp_y,
        text=mlp_annotations,
        mode='markers',
        name='MLP Layers',
        marker=dict(
            size=10,
            opacity=0.6,
            color=mlp_marker_colors,
            colorscale='Viridis',
            symbol='diamond',
            line=dict(width=1, color='Black')
        )
    ) if mlp_x is not None and mlp_y is not None else None

    # Create the figure and add the traces
    fig = go.Figure()
    fig.add_trace(scatter_heads)
    if scatter_mlp:
        fig.add_trace(scatter_mlp)

    # Add a dotted line where the negative y-value equals the positive x-value
    if x_range and y_range:
        # Ensuring the line covers the entire visible range by finding the min and max
        line_range = [min(x_range[0], y_range[0]), max(x_range[1], y_range[1])]
        fig.add_trace(go.Scatter(x=line_range, y=[-x for x in line_range], mode='lines', line=dict(color='grey', dash='dot'), name='Y=-X'))

    # Update the layout with the manual x- and y-range
    fig.update_layout(
        title=f"{plot_title}",
        title_x=0.5,
        xaxis_title=x_title,
        yaxis_title=y_title,
        legend_title="Component",
        # do not show legend
        showlegend=False,
        width=500,
        height=500,
        #xaxis_range=x_range,
        #yaxis_range=y_range
    )

    return fig
fig = create_layered_scatter(
    data[CHECKPOINT]['thresholded_de'][0], 
    data[CHECKPOINT]['thresholded_cil'][0], 
    "Direct Effect of Component", 
    "Change in Logits Upon Ablation", 
    f"Self-Repair in {MODEL_TO_VIEW} at Checkpoint {CHECKPOINT}",
    x_range=[-0.1, 0.2],
    y_range=[-0.4, 0.1]    
)
fig.show()

In [92]:
imshow_p(
    repair_score * 100,
    title="Self-Repair Score",
    labels={"x": "Head", "y": "Layer", "color": "Self-Repair Score"},
    coloraxis=dict(colorbar_ticksuffix = "%", cmin=-500, cmax=500),
    border=True,
    width=600,
    margin={"r": 100, "l": 100}
)

In [93]:
data[CHECKPOINT]['thresholded_cil'][0].sum(), data[CHECKPOINT]['thresholded_de'][0].sum(), repair_score.sum()

(tensor(-15.2078), tensor(4.4502), tensor(-341.2779))

### All Checkpoints

In [94]:
subset_checkpoints = [512, 3000, 60000, 143000]
subset_data = {checkpoint: data[checkpoint] for checkpoint in subset_checkpoints}

In [95]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
import torch

from utils.visualization import convert_title_to_filename


# Specify your desired x and y axis range
x_range = [-0.1, 0.4]
y_range = [-0.4, 0.1]

# Determine the grid size (rows x columns)
total_checkpoints = len(subset_data.keys())
num_columns = 4  # Example: 3 columns in your grid
num_rows = (total_checkpoints + num_columns - 1) // num_columns  # Calculate rows needed

# Initialize the subplot figure with specified rows and columns
fig = make_subplots(
    rows=num_rows, cols=num_columns, 
    subplot_titles=[f"Checkpoint {checkpoint}" for checkpoint in data.keys()],
    horizontal_spacing=0.05,  # Reduce horizontal spacing
    vertical_spacing=0.05     # Reduce vertical spacing
)

subplot_index = 1  # Initialize subplot index

for checkpoint in subset_data.keys():
    # Calculate the current row and column position
    row = (subplot_index - 1) // num_columns + 1
    col = (subplot_index - 1) % num_columns + 1

    # Generate the plot for the current checkpoint
    current_fig = create_layered_scatter(
        subset_data[checkpoint]['thresholded_de'][0], 
        subset_data[checkpoint]['thresholded_cil'][0], 
        "Direct Effect of Component", 
        "Change in Logits Upon Ablation", 
        f"Self-Repair in {MODEL_TO_VIEW} at Checkpoint {checkpoint}",
    )

    # Add each trace from the current figure to the subplot
    for trace in current_fig.data:
        fig.add_trace(trace, row=row, col=col)

    # Explicitly set x and y ranges for this subplot
    fig.update_xaxes(range=x_range, row=row, col=col)
    fig.update_yaxes(range=y_range, row=row, col=col)
    
    # Add the dotted line where the negative value of the y-axis equals the positive value of the x-axis
    line_range = [min(x_range[0], y_range[0]), max(x_range[1], y_range[1])]
    fig.add_trace(
        go.Scatter(x=line_range, y=[-x for x in line_range], mode='lines',
                   line=dict(color='grey', dash='dot'), name='Y=-X'),
        row=row, col=col
    )

    subplot_index += 1

title = f"Self-Repair for {MODEL_TO_VIEW}"

# Update layout with centered axis titles, increased legend-subplot spacing, and specified label size
fig.update_layout(
    height=400*num_rows+100, 
    width=350*num_columns, 
    title_text=title, 
    showlegend=False,
    xaxis_title="Original Direct Effect",  # Centered x-axis title
    yaxis_title="Change in Logits",  # Centered y-axis title
    margin=dict(t=75),  # Increase top margin to provide space for the legend
    xaxis_title_font=dict(size=16),  # Set x-axis label size
    yaxis_title_font=dict(size=16)   # Set y-axis label size
)

filename = "results/plots/" + convert_title_to_filename(title) + ".pdf"
fig.write_image(filename, format='pdf', width=350*num_columns, height=400*num_rows+100, engine="kaleido")

fig.show()

In [53]:
def convert_title_to_filename(title: str):
    # replace spaces with dashes, remove parentheses, and make lowercase
    return title.replace(' ', '-').replace('(', '').replace(')', '').lower()

import numpy as np
import pandas as pd
import plotly.express as px
from typing import Dict

def plot_all_heads(
        model_name: str,
        checkpoint_dict: Dict[int, Dict[str, np.ndarray]], 
        plot_everything: bool = False, 
        top_k_per_checkpoint: int = 5, 
        top_k: int = 5
    ) -> pd.DataFrame:
    """
    Plot the head attributions across checkpoints, with the option to plot all heads or only the top ones.

    Args:
        model_name (str): Name of the model for title display.
        checkpoint_dict (Dict[int, Dict[str, np.ndarray]]): A dictionary mapping checkpoints to a dictionary
            that includes a key "self_repair_score" pointing to numpy arrays of head attributions.
        plot_everything (bool, optional): If True, plots all heads without applying top_k filters. Defaults to False.
        top_k_per_checkpoint (int, optional): The number of top heads to consider per checkpoint. Effective only if plot_everything is False.
        top_k (int, optional): The number of overall top heads to plot. Effective only if plot_everything is False.

    Returns:
        pd.DataFrame: A DataFrame containing the plot data.
    """
    plot_data = []

    for checkpoint, data in checkpoint_dict.items():
        array = data['self_repair_score'][0].numpy()
        print(array.shape)
        
        if plot_everything:
            indices = np.indices(array.shape)
            selected_heads = [(layer, head) for layer, head in zip(indices[0].flatten(), indices[1].flatten())]
        else:
            # Use argpartition to get the indices of the top heads in the entire array
            flat_indices = np.argpartition(array.flatten(), -top_k_per_checkpoint)[-top_k_per_checkpoint:]
            # Convert flat indices to 2D indices
            indices = np.unravel_index(flat_indices, array.shape)
            selected_heads = [(layer, head) for layer, head in zip(indices[0], indices[1])]

        for layer, head in selected_heads:
            plot_data.append(
                {
                    'Checkpoint': checkpoint,
                    'Layer-Head': f'Layer {layer}-Head {head}',
                    'Layer': layer,
                    'Head': head,
                    'Value': float(array[layer, head])  # Ensure conversion to float
                }
            )

    # Convert to DataFrame
    df = pd.DataFrame(plot_data)

    if not plot_everything:
        # Ensure 'Value' is numeric for aggregation functions
        df['Value'] = pd.to_numeric(df['Value'], errors='coerce')  # Converts non-numeric to NaN, can handle errors

        # Calculate sum of values over all checkpoints for each head
        summary_df = df.groupby(['Layer-Head', 'Layer', 'Head']).sum().reset_index()

        # Label the top_k items in summary_df based on their sum
        summary_df['Top K'] = summary_df['Layer-Head'].isin(df.groupby('Layer-Head').mean().nlargest(top_k, 'Value').index)

        # Filter the DataFrame to include only the top_k heads across all checkpoints
        df = df.merge(summary_df, on=['Layer-Head', 'Layer', 'Head'], how='inner').query('`Top K`')

    # Step 3: Plot the data
    fig = px.line(
        df, 
        x='Checkpoint',  # Corrected column name for Checkpoint
        y='Value',       # Assuming Value_x is the correct column for self-repair scores
        color='Layer-Head', 
        # specify y_range
        range_y=[-300, 300],
        title=f'Self Repair Across Checkpoints (DE+CIL/DE) ({model_name})', 
        height=500,
        labels={'Checkpoint': 'Checkpoint', 'Value': 'Self Repair Score'}  # Correct labels for axes
    )
    fig.show()

    return df

In [145]:
for checkpoint in data.keys():
    data[checkpoint]['self_repair_score'] = (data[checkpoint]['thresholded_de'] + data[checkpoint]['thresholded_cil']) / data[checkpoint]['thresholded_de']

In [133]:
df = plot_all_heads(MODEL_TO_VIEW, data, plot_everything=True, top_k_per_checkpoint=5, top_k=5)

(12, 12)
(12, 12)
(12, 12)


In [40]:
data[143000]['self_repair_score'].shape

torch.Size([1, 12, 12])

In [48]:
df.head(50)

Unnamed: 0,Checkpoint_x,Layer-Head,Layer,Head,Value_x,Checkpoint_y,Value_y,Top K
1,512,Layer 3-Head 8,3,8,25.7409,30512,173.673319,True
23,5000,Layer 1-Head 10,1,10,89.960876,5000,89.960876,True
26,10000,Layer 1-Head 6,1,6,15.675889,40000,226.239076,True
29,10000,Layer 3-Head 5,3,5,116.355156,10000,116.355156,True
33,30000,Layer 3-Head 8,3,8,147.932419,30512,173.673319,True
34,30000,Layer 1-Head 6,1,6,210.563187,40000,226.239076,True
35,90000,Layer 0-Head 8,0,8,21.291231,233000,207.862093,True
43,143000,Layer 0-Head 8,0,8,186.570862,233000,207.862093,True
