In [67]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
import fancy_einsum
from tqdm import tqdm
import re
from sklearn.metrics import roc_curve, auc
import transformer_lens.utils as utils
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
import transformer_lens.utils as utils

from sparse_autoencoder import SparseAutoencoder

from transformer_lens import HookedTransformer, HookedTransformerConfig

In [21]:
def gen_co_occurrence_matrix(all_indices, n_heads, n_feat):
    co_occurrence_matrix = np.zeros((n_heads, n_heads, n_feat, n_feat))

    for e in range(all_indices.shape[0]):  # For each example
        for h1 in range(n_heads):  # For each head
            c1 = all_indices[e, h1]  # Code in head h1
            for h2 in range(n_heads):  # For each other head
                if h1 != h2:  # Skip counting co-occurrence of a head with itself
                    c2 = all_indices[e, h2]  # Code in head h2
                    # Increment co-occurrence count for (h1, h2)
                    co_occurrence_matrix[h1, h2, c1, c2] += 1

    return co_occurrence_matrix

def normalize_co_occurrence_matrix(co_occurrence_matrix):
    # Assuming co_occurrence_matrix is of shape (n_heads, n_heads, n_feat, n_feat)
    n_heads, _, n_feat, _ = co_occurrence_matrix.shape
    normalized_matrix = np.zeros_like(co_occurrence_matrix)

    for h1 in range(n_heads):
        for h2 in range(n_heads):
            if h1 != h2:  # Skip self co-occurrences
                total_co_occurrences = np.sum(co_occurrence_matrix[h1, h2, :, :])
                if total_co_occurrences > 0:  # Avoid division by zero
                    normalized_matrix[h1, h2, :, :] = co_occurrence_matrix[h1, h2, :, :] / total_co_occurrences

    return normalized_matrix

def unique_co_occurrences(positive_matrix, negative_matrix, normalise=True):
    # Normalize matrices
    if normalise:
        positive_matrix = normalize_co_occurrence_matrix(positive_matrix)
        negative_matrix = normalize_co_occurrence_matrix(negative_matrix)

    n_heads, _, n_feat, _ = positive_matrix.shape
    unique_co_occurrence_counts = np.zeros((n_heads, n_heads))
    
    for h1 in range(n_heads):
        for h2 in range(n_heads):
            if h1 != h2:  # Skip self co-occurrences
                # Find co-occurrences in positive not present in negative
                unique_positives = positive_matrix[h1, h2, :, :] > 0
                print(f"Pos: {unique_positives}")
                negatives = negative_matrix[h1, h2, :, :] > 0
                print(f"Neg: {negatives}")
                # Boolean array of unique positives
                unique = unique_positives & ~negatives
                print(f"Unique: {unique}")
                print()
                if normalise:
                    # Normalize count by total co-occurrences for this head pair in positive matrix
                    total_co_occurrences = np.sum(positive_matrix[h1, h2, :, :] > 0) + np.sum(negative_matrix[h1, h2, :, :] > 0)
                    if total_co_occurrences > 0:  # Avoid division by zero
                        unique_count_normalized = np.sum(unique) / total_co_occurrences
                    else:
                        unique_count_normalized = 0
                    # Set normalized unique counts for this head pair
                    unique_co_occurrence_counts[h1, h2] = unique_count_normalized
                else:
                    # Count unique co-occurrences
                    unique_co_occurrence_counts[h1, h2] = np.sum(unique)

    return unique_co_occurrence_counts

In [103]:
# Random integers for all indices
np.random.seed(42)
n_examples = 100
n_heads = 6
n_feat = 5
all_indices = np.random.randint(0, n_feat, (n_examples, n_heads))
all_indices

array([[3, 4, 2, 4, 4, 1],
       [2, 2, 2, 4, 3, 2],
       [4, 1, 3, 1, 3, 4],
       [0, 3, 1, 4, 3, 0],
       [0, 2, 2, 1, 3, 3],
       [2, 3, 3, 0, 2, 4],
       [2, 4, 0, 1, 3, 0],
       [3, 1, 1, 0, 1, 4],
       [1, 3, 3, 3, 3, 4],
       [2, 0, 3, 1, 3, 1],
       [1, 3, 4, 1, 1, 3],
       [1, 1, 3, 3, 0, 4],
       [4, 1, 4, 1, 0, 3],
       [3, 3, 4, 0, 4, 4],
       [0, 0, 0, 0, 3, 2],
       [2, 0, 2, 2, 0, 2],
       [4, 1, 1, 0, 3, 0],
       [3, 1, 0, 4, 2, 3],
       [2, 2, 0, 2, 4, 2],
       [0, 4, 1, 2, 0, 1],
       [1, 3, 4, 2, 0, 3],
       [4, 3, 4, 4, 2, 4],
       [3, 4, 2, 2, 3, 1],
       [1, 4, 0, 4, 3, 3],
       [3, 3, 3, 2, 1, 3],
       [0, 0, 0, 0, 2, 0],
       [3, 4, 0, 2, 2, 0],
       [4, 0, 2, 1, 3, 2],
       [0, 3, 0, 0, 1, 3],
       [3, 1, 2, 0, 4, 0],
       [0, 2, 0, 1, 1, 3],
       [4, 0, 0, 2, 1, 4],
       [3, 1, 3, 2, 2, 0],
       [4, 3, 1, 2, 0, 0],
       [3, 2, 4, 2, 3, 3],
       [2, 3, 2, 1, 2, 2],
       [3, 3, 0, 0, 1, 0],
 

In [104]:
fig = go.Figure(data=go.Heatmap(
                   z=all_indices,
                   text=[[str(y) for y in x] for x in all_indices],
                   texttemplate="%{text}",
                   textfont={"size": 30, "family": "Palatino"},
                   colorscale='Blues'))

# Change font to Palatino for whole plot
fig.update_layout(font={"family": "Palatino"})

fig.update_coloraxes(showscale=False)

fig.update_layout(
    xaxis=dict(title='Head', tickvals=list(range(6))),
    yaxis=dict(title='Examples', tickvals=list(range(6)), autorange="reversed"),
    width=600, height=600,
    title_x=0.5,
    # Hide colorbar
    coloraxis_showscale=False
)


# Change fontsize for axis titles
fig.update_layout(xaxis_title_font_size=30, yaxis_title_font_size=30)

# Change fontsize for axis ticks
fig.update_layout(xaxis_tickfont=dict(size=24), yaxis_tickfont=dict(size=24))

fig.show()

In [61]:
C_plus = gen_co_occurrence_matrix(all_indices[:3], n_heads, n_feat)
C_minus = gen_co_occurrence_matrix(all_indices[3:], n_heads, n_feat)
print(C_plus)
print()
print(C_minus)

[[[[0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]]

  [[0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 1. 0. 0.]
   [0. 0. 0. 0. 1.]
   [0. 1. 0. 0. 0.]]

  [[0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 1. 0. 0.]
   [0. 0. 1. 0. 0.]
   [0. 0. 0. 1. 0.]]

  [[0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 1.]
   [0. 0. 0. 0. 1.]
   [0. 1. 0. 0. 0.]]

  [[0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 1. 0.]
   [0. 0. 0. 0. 1.]
   [0. 0. 0. 1. 0.]]

  [[0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 1. 0. 0.]
   [0. 1. 0. 0. 0.]
   [0. 0. 0. 0. 1.]]]


 [[[0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 1.]
   [0. 0. 1. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 1. 0.]]

  [[0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]]

  [[0. 0. 0. 0. 0.]
   [0. 0. 0. 1. 0.]
   [0. 0. 1. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 1. 0. 0.]]

  [[0. 0. 0. 0. 0.]
   [0. 1. 0. 0. 0.]
   [0. 0. 0. 0. 1.]
   [0. 0. 0. 0. 0.]


In [70]:
for i in range(n_feat):
    for j in range(n_feat):
        px.imshow(
            utils.to_numpy(C_minus[:, :, i, j]),
            color_continuous_midpoint=0.0,
            color_continuous_scale="Reds",
            title=f"C_minus[{i},{j}]",
        ).show()

In [68]:
for i in range(n_feat):
    for j in range(n_feat):
        px.imshow(
            utils.to_numpy(C_plus[:, :, i, j]),
            color_continuous_midpoint=0.0,
            color_continuous_scale="Greens",
            title=f"C_plus[{i},{j}]",
        ).show()

In [71]:
U = unique_co_occurrences(C_plus, C_minus, normalise=False)

Pos: [[False False False False False]
 [False False False False False]
 [False False  True False False]
 [False False False False  True]
 [False  True False False False]]
Neg: [[False False  True  True False]
 [False False False False False]
 [False False False  True False]
 [False False False False False]
 [False False False False False]]
Unique: [[False False False False False]
 [False False False False False]
 [False False  True False False]
 [False False False False  True]
 [False  True False False False]]

Pos: [[False False False False False]
 [False False False False False]
 [False False  True False False]
 [False False  True False False]
 [False False False  True False]]
Neg: [[False  True  True False False]
 [False False False False False]
 [False False False  True False]
 [False False False False False]
 [False False False False False]]
Unique: [[False False False False False]
 [False False False False False]
 [False False  True False False]
 [False False  True False False]
 

In [100]:
fig = go.Figure(data=go.Heatmap(
                   z=U,
                   text=U,
                   texttemplate="%{text}",
                   textfont={"size": 30, "family": "Palatino"},
                   colorscale='Blues'))

# Change font to Palatino for whole plot
fig.update_layout(font={"family": "Palatino"})

fig.update_coloraxes(showscale=False)

fig.update_layout(
    xaxis=dict(title='Head', tickvals=list(range(6))),
    yaxis=dict(title='Head', tickvals=list(range(6)), autorange="reversed"),
    width=600, height=600,
    title_x=0.5,
    # Hide colorbar
    coloraxis_showscale=False
)


# Change fontsize for axis titles
fig.update_layout(xaxis_title_font_size=30, yaxis_title_font_size=30)

# Change fontsize for axis ticks
fig.update_layout(xaxis_tickfont=dict(size=24), yaxis_tickfont=dict(size=24))

fig.show()

In [80]:
# Sort (head, head) pairs by descending unique co-occurrence counts
sorted_indices = np.argsort(U.flatten())[::-1]
sorted_indices = np.unravel_index(sorted_indices, U.shape)
# Zip them together to create a list of (head, head) pairs
sorted_head_pairs = list(zip(sorted_indices[0], sorted_indices[1]))
print(sorted_head_pairs)

[(1, 0), (5, 4), (3, 1), (0, 5), (0, 2), (1, 5), (3, 5), (4, 0), (0, 3), (1, 3), (2, 0), (0, 1), (4, 5), (5, 0), (5, 1), (0, 4), (5, 3), (3, 0), (1, 4), (1, 2), (2, 1), (2, 5), (2, 3), (2, 4), (4, 1), (4, 2), (5, 2), (3, 2), (3, 4), (4, 3), (5, 5), (1, 1), (2, 2), (3, 3), (4, 4), (0, 0)]


In [82]:
circuit_components = []
for i, (h1, h2) in enumerate(sorted_head_pairs):
    circuit_components.append(h1)
    circuit_components.append(h2)

circuit_components

[1,
 0,
 5,
 4,
 3,
 1,
 0,
 5,
 0,
 2,
 1,
 5,
 3,
 5,
 4,
 0,
 0,
 3,
 1,
 3,
 2,
 0,
 0,
 1,
 4,
 5,
 5,
 0,
 5,
 1,
 0,
 4,
 5,
 3,
 3,
 0,
 1,
 4,
 1,
 2,
 2,
 1,
 2,
 5,
 2,
 3,
 2,
 4,
 4,
 1,
 4,
 2,
 5,
 2,
 3,
 2,
 3,
 4,
 4,
 3,
 5,
 5,
 1,
 1,
 2,
 2,
 3,
 3,
 4,
 4,
 0,
 0]

In [92]:
y_pred = np.zeros(n_heads)
k = len(circuit_components) // 2
for h in circuit_components[:k]:
    y_pred[h] += 1

y_pred

array([10.,  6.,  2.,  6.,  4.,  8.])

In [94]:
fig = go.Figure(data=go.Heatmap(
                   z=y_pred.reshape(1, -1).T,
                   text=y_pred.reshape(1, -1).T,
                   texttemplate="%{text}",
                   textfont={"size": 30, "family": "Palatino"},
                   colorscale='Blues'))

# Change font to Palatino for whole plot
fig.update_layout(font={"family": "Palatino"})

fig.update_coloraxes(showscale=False)

fig.update_layout(
    xaxis=dict(title='Head', tickvals=list(range(6))),
    yaxis=dict(title='Examples', tickvals=list(range(6)), autorange="reversed"),
    width=250, height=600,
    title_x=0.5,
    # Hide colorbar
    coloraxis_showscale=False
)


# Change fontsize for axis titles
fig.update_layout(xaxis_title_font_size=30, yaxis_title_font_size=30)

# Change fontsize for axis ticks
fig.update_layout(xaxis_tickfont=dict(size=24), yaxis_tickfont=dict(size=24))

fig.show()

In [99]:
# Softmax y_pred
def softmax(x):
    return np.exp(x) / np.sum(np.exp(x))

y_pred_softmax = softmax(y_pred)

fig = go.Figure(data=go.Heatmap(
                     z=y_pred_softmax.reshape(1, -1).T,
                     text=y_pred_softmax.reshape(1, -1).T,
                     texttemplate="%{text:.2f}",
                     textfont={"size": 30, "family": "Palatino"},
                     colorscale='Blues'))

# Change font to Palatino for whole plot
fig.update_layout(font={"family": "Palatino"})

fig.update_coloraxes(showscale=False)

fig.update_layout(
    xaxis=dict(title='Head', tickvals=list(range(6))),
    yaxis=dict(title='Examples', tickvals=list(range(6)), autorange="reversed"),
    width=250, height=600,
    title_x=0.5,
    # Hide colorbar
    coloraxis_showscale=False
)

fig.show()