In [None]:
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.io as pio

# Register a "custom" Plotly template based on the built-in "plotly" template and make it
# the default, so every Plotly figure created below uses it.
# NOTE(review): "custom" may alias the built-in "plotly" template object rather than copy it,
# in which case the colorway change below also affects "plotly" — confirm.
pio.templates["custom"] = pio.templates["plotly"]
# Use the RdBu color list as the default trace color cycle (colorway).
pio.templates["custom"]["layout"]["colorway"] = px.colors.sequential.RdBu
pio.templates.default = "custom"


from matplotlib.pyplot import ylabel

In [None]:
# Train/test loss curves stored in the 'output.pt' checkpoint (log-scale y axis).
# map_location keeps the load working on CPU-only machines, matching the load of
# 'output_plus_and_minus.pt' further down.
results = torch.load('output.pt', map_location=torch.device('cpu'))
plt.plot(results['train_losses'])
plt.plot(results['test_losses'])

plt.yscale('log')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train_loss', 'test_loss'])

In [None]:
# Display the names of the entries stored in the loaded checkpoint mapping.
results.keys()

In [None]:
# Train/test loss curves from the 'output_plus_and_minus.pt' checkpoint, loaded onto the CPU.
results = torch.load('output_plus_and_minus.pt', map_location=torch.device('cpu'))
for curve_key in ('train_losses', 'test_losses'):
    plt.plot(results[curve_key])

plt.xlabel('epoch')
plt.ylabel('loss')
plt.yscale('log')
plt.legend(['train_loss', 'test_loss'])


In [None]:
# Train/test precision per epoch for the currently loaded checkpoint.
# (The leading bare `results.keys()` was a no-op here: only a cell's last expression is displayed.)
plt.plot(results['train_precision_scores'])
plt.plot(results['test_precision_scores'])
plt.ylabel('precision')
plt.xlabel('epoch')
plt.legend(['train', 'test'])


In [None]:
# Per-operation loss curves: all train curves are drawn first, then all test curves,
# which is the order the legend entries below rely on.
for curve_key in ('train_losses', 'test_losses'):
    for operation in results['operations_losses']:
        plt.plot(operation[curve_key])
plt.xlabel('epoch')
plt.ylabel('loss')
plt.yscale('log')
plt.legend(['+ train', '- train', '+ test', '- test'])


In [None]:
# Per-operation precision curves: train curves for every operation, then test curves,
# matching the legend order below.
for curve_key in ('train_precisions', 'test_precisions'):
    for operation in results['operations_losses']:
        plt.plot(operation[curve_key])
plt.xlabel('epoch')
plt.ylabel('precision')
plt.legend(['+ train', '- train', '+ test', '- test'])
plt.yscale('log')

In [None]:
from transformer_lens import HookedTransformer, HookedTransformerConfig, HookedEncoderDecoder

In [None]:
# Point the stored model config at the CPU before the model is built from it
# (presumably so the notebook runs without a GPU).
results['config'].device = 'cpu'

In [None]:
# Rebuild the transformer from the checkpoint's config, then restore its trained weights.
model = HookedTransformer(results['config'])
model.load_state_dict(results['model'])

## Show model works

In [None]:
# (1 + 15) % 113
# Prompt tokens: a=1, token 114, b=15, token 113; read the prediction at position 3.
plus_example_tokens = torch.tensor([1, 114, 15, 113])
plus_example_logits = model(plus_example_tokens)
plus_example_logits[0, 3].argmax().item()

In [None]:
# (1 - 15) % 113
# Prompt tokens: a=1, token 115, b=15, token 113; read the prediction at position 3.
minus_example_tokens = torch.tensor([1, 115, 15, 113])
minus_example_logits = model(minus_example_tokens)
minus_example_logits[0, 3].argmax().item()

## Run model on full dataset to look at activations/attention patterns

In [None]:
import einops
from modular_addition import ModularOperationsDataset
# The two operations, in the order (plus, minus); the last axis of dataset.data follows this order.
operations = (
    lambda x, y: x + y,
    lambda x, y: x - y,
)
dataset = ModularOperationsDataset(base=113, train_fraction=0.25, operations=operations)
# Fold the operation axis into the batch axis to get a single 2-D token matrix.
full_dataset = einops.rearrange(dataset.data, "i j k -> (i k) j")
plus_dataset, minus_dataset = dataset.data[..., 0], dataset.data[..., 1]
print(f"plus: {plus_dataset.shape}, minus: {minus_dataset.shape}, full: {full_dataset.shape}")

In [None]:
# Forward pass over the full dataset, keeping every cached activation for the analysis below.
# This is inference only, so autograd tracking is disabled to avoid holding a graph in memory.
with torch.no_grad():
    output, cache = model.run_with_cache(full_dataset)

In [None]:
# Average the cached attention pattern (index 0, presumably the first layer) over dim 0, the batch,
# and move the result to the CPU for plotting; the result is indexed per head below.
av_attention = cache["pattern", 0].mean(dim=0).detach().cpu()

In [None]:
# One annotated heatmap per attention head (4 in total) of the batch-averaged attention pattern.
labels = ['a', 'operation', 'b', '=']
fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(16, 4))
fig.suptitle("Attention Patterns for 4 Heads", fontsize=16)

# Styling shared by every panel: cell annotations, token tick labels, no color bar.
shared_heatmap_style = dict(
    annot=True,
    xticklabels=labels,
    yticklabels=labels,
    cmap="viridis",
    cbar=False,
)
for i, ax in enumerate(axes):
    sns.heatmap(av_attention[i], ax=ax, **shared_heatmap_style)
    ax.set_title(f'Head {i + 1}', fontsize=12)
    ax.set_xlabel('source token', fontsize=12)
    ax.set_ylabel('destination token', fontsize=12)

fig.tight_layout()
fig.subplots_adjust(top=0.85)  # leave room for the figure title
plt.show()

In [None]:
# Same plot restricted to the last destination row ('='), the only position we care about.
labels = ['a', 'operation', 'b', '=']
fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(16, 4))
fig.suptitle("Attention Patterns for 4 Heads", fontsize=16)

# Styling shared by every panel; the single y tick is the '=' destination.
equals_row_style = dict(
    annot=True,
    xticklabels=labels,
    yticklabels=["="],
    cmap="viridis",
    cbar=False,
)
for i, ax in enumerate(axes):
    # Slicing with -1: keeps a 2-D (1 x tokens) array so it still renders as a heatmap.
    sns.heatmap(av_attention[i, -1:], ax=ax, **equals_row_style)
    ax.set_title(f'Head {i + 1}', fontsize=12)
    ax.set_xlabel('source token', fontsize=12)
    ax.set_ylabel('destination token', fontsize=12)

fig.tight_layout()
fig.subplots_adjust(top=0.85)  # leave room for the figure title
plt.show()

### Attention patterns for +

In [None]:
# Run only the '+' prompts and plot the batch-averaged attention pattern of each head.
# Inference only, so autograd tracking is disabled.
with torch.no_grad():
    output_plus, cache_plus = model.run_with_cache(plus_dataset)
av_attention = cache_plus["pattern", 0].mean(dim=0).detach().cpu()
labels = ['a', 'operation', 'b', '=']
# Create a figure to hold the 4 attention head plots
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
# Name the operation in the title so this figure can be told apart from the all-data and '-' ones.
fig.suptitle("Attention Patterns for 4 Heads (+)", fontsize=16)

# Loop through each head and plot its attention pattern
for i, ax in enumerate(axes):
    sns.heatmap(
        av_attention[i],
        annot=True,  # Annotate each cell with its value
        xticklabels=labels,  # Set x-axis labels
        yticklabels=labels,  # Set y-axis labels
        cmap="viridis",  # Use a color map
        cbar=False,  # Disable color bar to reduce clutter
        ax=ax  # Plot on the current axis
    )
    ax.set_title(f'Head {i + 1}', fontsize=12)
    # Same axis labels as the all-data attention plot.
    ax.set_ylabel('destination token', fontsize=12)
    ax.set_xlabel('source token', fontsize=12)

# Show the plot
plt.tight_layout()
plt.subplots_adjust(top=0.85)  # Adjust to make space for the title
plt.show()

### Attention patterns for -

In [None]:
# Run only the '-' prompts and plot the batch-averaged attention pattern of each head.
# Inference only, so autograd tracking is disabled.
with torch.no_grad():
    output_minus, cache_minus = model.run_with_cache(minus_dataset)
av_attention = cache_minus["pattern", 0].mean(dim=0).detach().cpu()
labels = ['a', 'operation', 'b', '=']
# Create a figure to hold the 4 attention head plots
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
# Name the operation in the title so this figure can be told apart from the all-data and '+' ones.
fig.suptitle("Attention Patterns for 4 Heads (-)", fontsize=16)

# Loop through each head and plot its attention pattern
for i, ax in enumerate(axes):
    sns.heatmap(
        av_attention[i],
        annot=True,  # Annotate each cell with its value
        xticklabels=labels,  # Set x-axis labels
        yticklabels=labels,  # Set y-axis labels
        cmap="viridis",  # Use a color map
        cbar=False,  # Disable color bar to reduce clutter
        ax=ax  # Plot on the current axis
    )
    ax.set_title(f'Head {i + 1}', fontsize=12)
    # Same axis labels as the all-data attention plot.
    ax.set_ylabel('destination token', fontsize=12)
    ax.set_xlabel('source token', fontsize=12)

# Show the plot
plt.tight_layout()
plt.subplots_adjust(top=0.85)  # Adjust to make space for the title
plt.show()

Next we plot the attention patterns for all combinations of a and b. We expect these to be very similar to what was found in the original paper, with potentially some interesting structure in the attention that uses the operation token as the source.

In [None]:
# Original model from the paper
import einops
# Load onto the CPU and pin the config to the CPU, mirroring how the plus/minus checkpoint is
# handled elsewhere in this notebook, so this cell also runs on CPU-only machines.
original_results = torch.load('output.pt', map_location=torch.device('cpu'))
original_results['config'].device = 'cpu'
original_model = HookedTransformer(original_results['config'])
original_model.load_state_dict(original_results['model'])
# Enumerate every (a, b) pair: a varies slowest, b fastest, followed by token 113.
a_s = einops.repeat(torch.arange(113), "i -> (i j)", j=113)
b_s = einops.repeat(torch.arange(113), "j -> (i j)", i=113)
equals = einops.repeat(
    torch.tensor(113), " -> (i j)", i=113, j=113
)
original_dataset_correctly_ordered = torch.stack([a_s, b_s, equals], dim=1)
# Inference only, so autograd tracking is disabled.
with torch.no_grad():
    original_output, original_cache = original_model.run_with_cache(original_dataset_correctly_ordered)
# Attention from the last position (destination) to position 0 (source 'a'); shape (batch, head).
a_to_equals_attn_patterns = original_cache["pattern", 0][:,:,-1,0].detach().cpu()
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
fig.suptitle("attention Patterns for a -> = for original model", fontsize=16)

# Loop through each head and plot its attention pattern

for i, ax in enumerate(axes):
    sns.heatmap(
        a_to_equals_attn_patterns[:, i].reshape(113, 113),
        cmap="viridis",  # Use a color map
        cbar=False,  # Disable color bar to reduce clutter
        ax=ax  # Plot on the current axis
    )
    ax.set_title(f'Head {i + 1}', fontsize=12)

# Show the plot
plt.tight_layout()
plt.subplots_adjust(top=0.85)  # Adjust to make space for the title
plt.show()

In [None]:
# Attention from the last position ('=') to source position 0 ('a') for every head,
# reshaped to a 113 x 113 grid per head: top row for '+', bottom row for '-'.
a_to_equals_attn_patterns_plus = cache_plus["pattern", 0][:, :, -1, 0].detach().cpu()
a_to_equals_attn_patterns_minus = cache_minus["pattern", 0][:, :, -1, 0].detach().cpu()
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(16, 8))
fig.suptitle("Attention Patterns a -> =", fontsize=16)

rows_to_draw = (
    ('+', a_to_equals_attn_patterns_plus),
    ('-', a_to_equals_attn_patterns_minus),
)
for row_axes, (symbol, patterns) in zip(axes, rows_to_draw):
    for i, ax in enumerate(row_axes):
        sns.heatmap(patterns[:, i].reshape(113, 113), ax=ax, cbar=False, cmap="viridis")
        ax.set_title(f'{symbol}: Head {i + 1}', fontsize=12)

fig.tight_layout()
fig.subplots_adjust(top=0.85)  # leave room for the figure title
plt.show()

In [None]:
# Attention from the last position ('=') to source position 1 (the operation token) for every head,
# reshaped to a 113 x 113 grid per head: top row for '+', bottom row for '-'.
operation_to_equals_attn_patterns_plus = cache_plus["pattern", 0][:, :, -1, 1].detach().cpu()
operation_to_equals_attn_patterns_minus = cache_minus["pattern", 0][:, :, -1, 1].detach().cpu()
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(16, 8))
fig.suptitle("Attention Patterns operation -> =", fontsize=16)

rows_to_draw = (
    ('+', operation_to_equals_attn_patterns_plus),
    ('-', operation_to_equals_attn_patterns_minus),
)
for row_axes, (symbol, patterns) in zip(axes, rows_to_draw):
    for i, ax in enumerate(row_axes):
        sns.heatmap(patterns[:, i].reshape(113, 113), ax=ax, cbar=False, cmap="viridis")
        ax.set_title(f'{symbol}: Head {i + 1}', fontsize=12)

fig.tight_layout()
fig.subplots_adjust(top=0.85)  # leave room for the figure title
plt.show()

In [None]:
# Attention from the last position ('=') to source position 2 ('b') for every head,
# reshaped to a 113 x 113 grid per head: top row for '+', bottom row for '-'.
b_to_equals_attn_patterns_plus = cache_plus["pattern", 0][:, :, -1, 2].detach().cpu()
b_to_equals_attn_patterns_minus = cache_minus["pattern", 0][:, :, -1, 2].detach().cpu()
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(16, 8))
fig.suptitle("Attention Patterns b -> =", fontsize=16)

rows_to_draw = (
    ('+', b_to_equals_attn_patterns_plus),
    ('-', b_to_equals_attn_patterns_minus),
)
for row_axes, (symbol, patterns) in zip(axes, rows_to_draw):
    for i, ax in enumerate(row_axes):
        sns.heatmap(patterns[:, i].reshape(113, 113), ax=ax, cbar=False, cmap="viridis")
        ax.set_title(f'{symbol}: Head {i + 1}', fontsize=12)

fig.tight_layout()
fig.subplots_adjust(top=0.85)  # leave room for the figure title
plt.show()

In [None]:
# Attention from the last position ('=') to source position 3 ('=' itself) for every head,
# reshaped to a 113 x 113 grid per head: top row for '+', bottom row for '-'.
equals_to_equals_attn_patterns_plus = cache_plus["pattern", 0][:, :, -1, 3].detach().cpu()
equals_to_equals_attn_patterns_minus = cache_minus["pattern", 0][:, :, -1, 3].detach().cpu()
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(16, 8))
fig.suptitle("Attention Patterns = -> =", fontsize=16)

rows_to_draw = (
    ('+', equals_to_equals_attn_patterns_plus),
    ('-', equals_to_equals_attn_patterns_minus),
)
for row_axes, (symbol, patterns) in zip(axes, rows_to_draw):
    for i, ax in enumerate(row_axes):
        sns.heatmap(patterns[:, i].reshape(113, 113), ax=ax, cbar=False, cmap="viridis")
        ax.set_title(f'{symbol}: Head {i + 1}', fontsize=12)

fig.tight_layout()
fig.subplots_adjust(top=0.85)  # leave room for the figure title
plt.show()

There is definite periodicity, but also some interesting extra structure. Maybe the model is learning an algorithm similar to the original one, but with more frequency components so that it can accurately compute both operations.

## Check the embedding matrix

In [None]:
# Keep only the first 113 rows of the embedding matrix and show them as a heatmap.
W_E = model.embed.W_E.detach().cpu()[:113]
sns.heatmap(W_E.numpy(), cbar=False, cmap="viridis")

In [None]:
# Reduced SVD of the embedding matrix; plot the singular values.
# torch.svd is deprecated and returns V (not its conjugate transpose), so the old `Vh` name was
# misleading. torch.linalg.svd(..., full_matrices=False) is the documented replacement and
# really returns Vh; U and S keep the same meaning and shapes.
U, S, Vh = torch.linalg.svd(W_E, full_matrices=False)
plt.plot(S)
plt.title('W_E singular values')

In [None]:


# Heatmap of U, the left singular vectors of W_E computed in the previous cell.
px.imshow(U)

In [None]:
# First 8 left singular vectors of W_E, one line each, over the input vocabulary.
# Plotly Express treats a 2-D input as wide-form data: each COLUMN becomes a trace and the row
# index is the x axis. The array is therefore passed as (vocab, 8) without a transpose; the
# transposed form drew 113 eight-point traces, which contradicts the x-axis title.
px.line(U[:, :8], title="Principle Components of Embedding").update_layout(
    xaxis_title="Input Vocabulary")

In [None]:
# Build the real Fourier basis over 113 inputs: a constant row followed by a (sin, cos) pair of
# rows for each frequency 1..56, so row 0 is 'Constant', row 2k-1 is 'Sin k' and row 2k is 'Cos k'.
fourier_basis = []
fourier_basis_names = []
fourier_basis.append(torch.ones(113))
fourier_basis_names.append('Constant')
for freq in range(1, 113//2 + 1):
    fourier_basis.append(torch.sin(torch.arange(113)*2 * torch.pi*freq /113))
    fourier_basis_names.append(f'Sin {freq}')
    fourier_basis.append(torch.cos(torch.arange(113)*2 * torch.pi*freq /113))
    fourier_basis_names.append(f'Cos {freq}')
fourier_basis = torch.stack(fourier_basis, dim=0)
# Normalize every ROW to unit length. keepdim=True is essential: without it the (113,) norm
# vector broadcasts along the last axis and divides column j by the norm of row j, which only
# runs at all because the matrix is square and leaves the basis non-orthonormal.
fourier_basis = fourier_basis/fourier_basis.norm(dim=-1, keepdim=True)
# Heatmap of the Fourier basis: one row per named component, one column per input value.
# The y-axis title previously read "Cofourier_basis_namesmponent", a paste accident inside "Component".
px.imshow(fourier_basis, y=fourier_basis_names,color_continuous_scale='RdBu',).update_layout(xaxis_title="Input", yaxis_title="Component")

In [None]:
# Gram matrix of the basis rows; it is the identity when the rows are orthonormal.
px.imshow(fourier_basis @ fourier_basis.T)

In [None]:
# Project the embedding rows onto the Fourier basis: rows are Fourier components,
# columns are residual-stream dimensions.
embedding_in_fourier_basis = fourier_basis @ W_E
px.imshow(
    embedding_in_fourier_basis,
    y=fourier_basis_names,
    title='embedding_in_fourier_basis',
    color_continuous_scale='RdBu',
).update_layout(yaxis_title="fourier_component", xaxis_title="Residual Stream")

In [None]:
# Norm of the embedding along each Fourier component (norm taken over the residual-stream axis).
# The x axis is the Fourier component and the y axis its norm; the previous axis titles were
# copied from the heatmap above and named the wrong quantities.
px.line(y=(fourier_basis @ W_E).norm(dim=-1), x=fourier_basis_names, title='embedding_in_fourier_basis').update_layout(xaxis_title="fourier_component", yaxis_title="norm")

In [None]:
# Raw values behind the line plot above: one norm per Fourier component.
(fourier_basis @ W_E).norm(dim=-1)

In [None]:
# Average of basis rows 32, 48 and 100 — 'Cos 16', 'Cos 24' and 'Cos 50' under the
# (Constant, Sin k, Cos k, ...) row ordering built above — plotted over the inputs.
px.line(fourier_basis[[32, 48, 100]].mean(0))


In [None]:
# Rows of the Fourier-basis embedding for six hand-picked components (sin/cos pairs at
# basis rows 31-32, 47-48 and 99-100) and the Gram matrix of those rows.
key_freq_indicies = [31, 32, 47, 48, 99, 100]
key_fourier_embed = (fourier_basis @ W_E)[key_freq_indicies]
key_fourier_gram = key_fourier_embed @ key_fourier_embed.T
px.imshow(key_fourier_gram, color_continuous_midpoint=0, color_continuous_scale='RdBu')

# Look at frequencies in the mlp hidden activations

In [None]:
# MLP activations at the last sequence position, after ("post") and before ("pre") the
# nonlinearity, for the '+' and '-' runs.
# Detach and move to the CPU, as is done for the attention patterns: plotting converts these
# tensors to numpy, which fails for tensors that still track gradients.
neuron_acts_plus = cache_plus["post", 0, "mlp"][:, -1, :].detach().cpu()
neuron_pre_acts_plus = cache_plus["pre", 0, "mlp"][:, -1, :].detach().cpu()
neuron_acts_minus = cache_minus["post", 0, "mlp"][:, -1, :].detach().cpu()
neuron_pre_acts_minus = cache_minus["pre", 0, "mlp"][:, -1, :].detach().cpu()

In [None]:
# Shape check: first axis is the batch of prompts, second axis the neurons.
neuron_acts_plus.shape

In [None]:
# Activation of neuron 1 on the '+' prompts, laid out as a 113 x 113 grid.
neuron_1_acts_plus_grid = neuron_acts_plus[:, 1].reshape(113, 113)
px.imshow(
    neuron_1_acts_plus_grid,
    color_continuous_midpoint=0,
    color_continuous_scale='RdBu',
).update_layout(yaxis_title="a", xaxis_title="b")

In [None]:
# Activation of neuron 1 on the '-' prompts, laid out as a 113 x 113 grid.
neuron_1_acts_minus_grid = neuron_acts_minus[:, 1].reshape(113, 113)
px.imshow(
    neuron_1_acts_minus_grid,
    color_continuous_midpoint=0,
    color_continuous_scale='RdBu',
).update_layout(yaxis_title="a", xaxis_title="b")

In [None]:
# 2-D Fourier transform of neuron 1's '+' activation grid: the basis is applied on both axes.
neuron_1_fourier_plus = fourier_basis @ neuron_acts_plus[:, 1].reshape(113, 113) @ fourier_basis.T
px.imshow(
    neuron_1_fourier_plus,
    x=fourier_basis_names,
    y=fourier_basis_names,
    title='2d transform of neuron 1 for plus op',
    color_continuous_midpoint=0,
    color_continuous_scale='RdBu',
).update_layout(yaxis_title="a", xaxis_title="b")

In [None]:
# 2-D Fourier transform of neuron 1's '-' activation grid: the basis is applied on both axes.
neuron_1_fourier_minus = fourier_basis @ neuron_acts_minus[:, 1].reshape(113, 113) @ fourier_basis.T
px.imshow(
    neuron_1_fourier_minus,
    x=fourier_basis_names,
    y=fourier_basis_names,
    title='2d transform of neuron 1 for minus op',
    color_continuous_midpoint=0,
    color_continuous_scale='RdBu',
).update_layout(yaxis_title="a", xaxis_title="b")


### neuron clusters

In [None]:
# 2-D Fourier transform of every neuron's '+' activation grid, giving (neuron, a, b) components.
neuron_acts_plus_grid = einops.rearrange(neuron_acts_plus, "(a b) neuron -> neuron a b", a=113, b=113)
fourier_neuron_acts_plus = fourier_basis @ neuron_acts_plus_grid @ fourier_basis.T
# Drop the constant-constant component (the mean); centring does not matter for what follows.
fourier_neuron_acts_plus[:, 0, 0] = 0.
print("fourier_neuron_acts", fourier_neuron_acts_plus.shape)

In [None]:
# For each frequency, the share of every neuron's squared Fourier mass that lies on the
# constant / sin / cos components of that frequency, on both axes ('+' activations).
neuron_freq_norm_plus = torch.zeros(113//2, model.cfg.d_mlp)
for freq in range(113//2):
    sin_index = 2*(freq+1) - 1
    cos_index = 2*(freq+1)
    components = [0, sin_index, cos_index]
    for x in components:
        for y in components:
            neuron_freq_norm_plus[freq] += fourier_neuron_acts_plus[:, x, y]**2
# Normalise by each neuron's total squared mass so the values are fractions.
total_mass_plus = fourier_neuron_acts_plus.pow(2).sum(dim=[-1, -2])
neuron_freq_norm_plus = neuron_freq_norm_plus / total_mass_plus[None, :]
px.imshow(
    neuron_freq_norm_plus,
    y=torch.arange(1, 113//2+1),
    title="Neuron Frac Explained by Freq plus",
    color_continuous_scale='RdBu',
    color_continuous_midpoint=0,
    aspect="auto",
).update_layout(xaxis_title="Neuron", yaxis_title="Freq")

In [None]:
# Best single-frequency fraction for each neuron ('+'), sorted in ascending order.
max_frac_per_neuron_plus = neuron_freq_norm_plus.max(dim=0).values
px.line(
    max_frac_per_neuron_plus.sort().values,
    title="Max Neuron Frac Explained over Freqs plus",
).update_layout(xaxis_title="Neuron")

In [None]:
# 2-D Fourier transform of every neuron's '-' activation grid, giving (neuron, a, b) components.
neuron_acts_minus_grid = einops.rearrange(neuron_acts_minus, "(a b) neuron -> neuron a b", a=113, b=113)
fourier_neuron_acts_minus = fourier_basis @ neuron_acts_minus_grid @ fourier_basis.T
# Drop the constant-constant component (the mean); centring does not matter for what follows.
fourier_neuron_acts_minus[:, 0, 0] = 0.
print("fourier_neuron_acts", fourier_neuron_acts_minus.shape)

In [None]:
# For each frequency, the share of every neuron's squared Fourier mass that lies on the
# constant / sin / cos components of that frequency, on both axes ('-' activations).
neuron_freq_norm_minus = torch.zeros(113//2, model.cfg.d_mlp)
for freq in range(0, 113//2):
    # Basis rows for this frequency: 0 = constant, 2k-1 = sin, 2k = cos (k = freq + 1).
    for x in [0, 2*(freq+1) - 1, 2*(freq+1)]:
        for y in [0, 2*(freq+1) - 1, 2*(freq+1)]:
            neuron_freq_norm_minus[freq] += fourier_neuron_acts_minus[:, x, y]**2
# Normalise by each neuron's total squared mass so the values are fractions.
neuron_freq_norm_minus = neuron_freq_norm_minus / fourier_neuron_acts_minus.pow(2).sum(dim=[-1, -2])[None, :]
# Title fixed: this figure shows the minus operation (it previously said "plus", copied from the cell above).
px.imshow(neuron_freq_norm_minus, y=torch.arange(1, 113//2+1), title="Neuron Frac Explained by Freq minus", color_continuous_scale='RdBu', color_continuous_midpoint=0, aspect="auto",).update_layout(xaxis_title="Neuron", yaxis_title="Freq")

In [None]:
# Best single-frequency fraction for each neuron ('-'), sorted in ascending order.
# Title fixed: the data is for the minus operation (it previously said "plus").
px.line(neuron_freq_norm_minus.max(dim=0).values.sort().values, title="Max Neuron Frac Explained over Freqs minus").update_layout(xaxis_title="Neuron")


### TODO: Maybe try to combine the above?

## Look at Unembedding

In [None]:
# CPU copies, without gradient tracking, of the unembedding matrix and the MLP output weights of block 0.
W_U = model.unembed.W_U.detach().cpu()
W_O = model.blocks[0].mlp.W_out.detach().cpu()

In [None]:
# Compose the MLP output weights with the unembedding: a direct map from MLP neurons to output logits.
W_OU = W_O @ W_U

In [None]:
# Heatmap of the neuron-to-logit map, centred at zero on a diverging color scale.
px.imshow(W_OU, color_continuous_scale='RdBu', color_continuous_midpoint=0, aspect="auto", height=800, width=800)

In [None]:
# Fourier transform of W_OU along its output axis: rows are Fourier components and columns
# are MLP neurons, so the x axis is labelled "Neuron" (it previously said "output").
px.imshow(fourier_basis@W_OU.T, y=fourier_basis_names,color_continuous_scale='RdBu', color_continuous_midpoint=0, title="frequency componenets of W_OU").update_layout(xaxis_title="Neuron", yaxis_title="Freq")

In [None]:
# Norm over all neurons of each Fourier component of W_OU.
# The x axis carries the Fourier components and the y axis the norm; the previous axis titles
# were copied from the heatmap and named the wrong quantities.
px.line(y=(fourier_basis@W_OU.T).T.norm(dim=0), x=fourier_basis_names,title="normed frequency componenets of W_OU").update_layout(xaxis_title="Freq", yaxis_title="norm")

In [None]:
## Show that, for the neurons whose activations have frequency 16, the corresponding frequency component of the W_OU output is also 16 (this is for the final cos(w) output in the trig addition)

In [None]:
# Neurons whose '+' activations are more than 70% explained by frequency 16 (row 16-1 of the table).
neurons_16 = neuron_freq_norm_plus[16-1] > 0.7
# Norm, over those neurons only, of each Fourier component of W_OU.
# The x axis carries the Fourier components and the y axis the norm; the previous axis titles
# named the wrong quantities.
px.line(y=(fourier_basis@W_OU[neurons_16].T).T.norm(dim=0), x=fourier_basis_names,title="normed frequency componenets of W_OU").update_layout(xaxis_title="Freq", yaxis_title="norm")


## Looking at plus and minus

In [None]:
# Activation of neuron 2 on the '+' prompts, laid out as a 113 x 113 grid.
neuron_2_acts_plus_grid = neuron_acts_plus[:, 2].reshape(113, 113)
px.imshow(
    neuron_2_acts_plus_grid,
    title='plus operation neuron activations [neuron 2]',
    color_continuous_midpoint=0,
    color_continuous_scale='RdBu',
).update_layout(yaxis_title="a", xaxis_title="b")

In [None]:
# Activation of neuron 2 on the '-' prompts, laid out as a 113 x 113 grid.
neuron_2_acts_minus_grid = neuron_acts_minus[:, 2].reshape(113, 113)
px.imshow(
    neuron_2_acts_minus_grid,
    title='minus operation neuron activations [neuron 2]',
    color_continuous_midpoint=0,
    color_continuous_scale='RdBu',
).update_layout(yaxis_title="a", xaxis_title="b")

In [None]:
import plotly.subplots as sp
import numpy as np
# Side-by-side '+' / '-' activation grids for the first `max_plots` neurons, one neuron per row.
max_plots = 25

num_plots = min(max_plots, neuron_acts_plus.shape[1])
grid_size = int(2)

# Subplot titles are consumed row by row: plus first, then minus.
panel_titles = []
for i in range(num_plots):
    panel_titles.extend([f"Plus Neuron {i}", f"Minus Neuron {i}"])
fig = sp.make_subplots(rows=num_plots, cols=2, subplot_titles=panel_titles)

for i in range(num_plots):
    reshaped_data_plus = neuron_acts_plus[:, i].reshape(113, 113)
    reshaped_data_minus = neuron_acts_minus[:, i].reshape(113, 113)
    max_val_plus = reshaped_data_plus.max()
    max_val_minus = reshaped_data_minus.max()
    # Scale both panels by the same value so the two operations stay comparable.
    overall_max = max(max_val_plus, max_val_minus)
    normalized_data_plus = reshaped_data_plus / (overall_max + 1e-8)
    normalized_data_minus = reshaped_data_minus / (overall_max + 1e-8)
    for column, normalized_data in enumerate((normalized_data_plus, normalized_data_minus), start=1):
        heatmap = px.imshow(normalized_data)
        for trace in heatmap.data:
            fig.add_trace(trace, row=i+1, col=column)

fig.update_layout(
    title_text="Neuron Activations",
    width=grid_size * 500,
    height=500 * max_plots,
    coloraxis=dict(cmid=0, colorscale='RdBu'),
)
fig.show()



In [None]:
# Flag neurons whose summed activation under one operation is less than 10% of the sum under
# the other: `plus_not_minus` fire (almost) only for '+', `minus_not_plus` only for '-'.
plus_not_minus = []
minus_not_plus = []
for i in range(512):
    plus_total = neuron_acts_plus[:, i].sum()
    minus_total = neuron_acts_minus[:, i].sum()
    if minus_total / plus_total < 0.1:
        plus_not_minus.append(i)
    if plus_total / minus_total < 0.1:
        minus_not_plus.append(i)


# Side-by-side '+' / '-' activation grids for every operation-specific neuron, one neuron per row.
all_ = plus_not_minus + minus_not_plus
num_plots = len(all_)
grid_size = int(2)

if num_plots == 0:
    # make_subplots rejects a zero-row grid, so report the empty selection instead of crashing.
    print("No operation-specific neurons to plot")
else:
    # Titles carry the real neuron index; previously they used the row counter, which mislabelled
    # every panel.
    fig = sp.make_subplots(
        rows=num_plots,
        cols=2,
        subplot_titles=[
            title
            for neuron_index in all_
            for title in [f"Plus Neuron {neuron_index}", f"Minus Neuron {neuron_index}"]
        ],
    )

    for i, neuron_index in enumerate(all_):
        reshaped_data_plus = neuron_acts_plus[:, neuron_index].reshape(113, 113)
        max_val_plus = reshaped_data_plus.max()
        reshaped_data_minus = neuron_acts_minus[:, neuron_index].reshape(113, 113)
        max_val_minus = reshaped_data_minus.max()
        # Scale both panels by the same value so the two operations stay comparable.
        overall_max = max([max_val_plus, max_val_minus])
        normalized_data_plus = reshaped_data_plus / (overall_max + 1e-8)
        normalized_data_minus = reshaped_data_minus / (overall_max + 1e-8)
        heatmap = px.imshow(
            normalized_data_plus,
        )
        for trace in heatmap.data:
            fig.add_trace(trace, row=i+1, col=1)

        heatmap = px.imshow(
            normalized_data_minus,
        )
        for trace in heatmap.data:
            fig.add_trace(trace, row=i+1, col=2)

    # Height follows the number of rows actually drawn, instead of `max_plots` from another cell.
    fig.update_layout(
        height=500 * num_plots, width=grid_size * 500,
        title_text="Neuron Activations",
        coloraxis=dict(cmid=0, colorscale='RdBu')
    )
    fig.show()

## ablating neurons that only fire for one operation

In [67]:
from transformer_lens import utils
def plus_mlp_neuron_ablation_hook(value, hook):
    """Zero the activations of the neurons listed in ``plus_not_minus``.

    ``value`` is the hooked activation tensor, indexed by neuron on its last axis. It is
    modified in place and returned. ``hook`` is the hook point and is not used. The tensor
    shape is printed on every call.
    """
    print(f"Shape of the value tensor: {value.shape}")
    value[..., plus_not_minus] = 0.
    return value

# Run both datasets with the plus-only neurons zeroed at the post-MLP hook point of block 0.
post_mlp_act_name = utils.get_act_name("post", 0, "mlp")
plus_neurons_ablated_ablated_plus_results = model.run_with_hooks(
    plus_dataset,
    fwd_hooks=[(post_mlp_act_name, plus_mlp_neuron_ablation_hook)],
)
plus_neurons_ablated_ablated_minus_results = model.run_with_hooks(
    minus_dataset,
    fwd_hooks=[(post_mlp_act_name, plus_mlp_neuron_ablation_hook)],
)

def minus_mlp_neuron_ablation_hook(value, hook):
    """Zero the activations of the neurons listed in ``minus_not_plus``.

    ``value`` is the hooked activation tensor, indexed by neuron on its last axis. It is
    modified in place and returned. ``hook`` is the hook point and is not used.
    """
    value[..., minus_not_plus] = 0.
    return value

# Run both datasets with the minus-only neurons zeroed at the post-MLP hook point of block 0.
post_mlp_act_name = utils.get_act_name("post", 0, "mlp")
minus_neurons_ablated_ablated_plus_results = model.run_with_hooks(
    plus_dataset,
    fwd_hooks=[(post_mlp_act_name, minus_mlp_neuron_ablation_hook)],
)

minus_neurons_ablated_ablated_minus_results = model.run_with_hooks(
    minus_dataset,
    fwd_hooks=[(post_mlp_act_name, minus_mlp_neuron_ablation_hook)],
)

Shape of the value tensor: torch.Size([12769, 4, 512])
Shape of the value tensor: torch.Size([12769, 4, 512])


In [68]:
from sklearn.metrics import precision_score
# Ground-truth answers: (a + b) mod 113 and (a - b) mod 113, with a at position 0 and b at position 2.
plus_dataset_labels = (plus_dataset[:,0] + plus_dataset[:,2]) % 113
minus_dataset_labels = (minus_dataset[:,0] - minus_dataset[:,2]) % 113

# Predictions are the argmax of the logits at the last position; precision is macro-averaged.
plus_neurons_ablated_ablated_plus_predictions = plus_neurons_ablated_ablated_plus_results[:,-1,:].argmax(dim=-1).cpu().numpy()
plus_neurons_ablated_ablated_plus_precision = precision_score(plus_dataset_labels, plus_neurons_ablated_ablated_plus_predictions, average='macro')

plus_neurons_ablated_ablated_minus_predictions = plus_neurons_ablated_ablated_minus_results[:,-1,:].argmax(dim=-1).cpu().numpy()
plus_neurons_ablated_ablated_minus_precision = precision_score(minus_dataset_labels, plus_neurons_ablated_ablated_minus_predictions, average='macro')

minus_neurons_ablated_ablated_plus_predictions = minus_neurons_ablated_ablated_plus_results[:,-1,:].argmax(dim=-1).cpu().numpy()
minus_neurons_ablated_ablated_plus_precision = precision_score(plus_dataset_labels, minus_neurons_ablated_ablated_plus_predictions, average='macro')

minus_neurons_ablated_ablated_minus_predictions = minus_neurons_ablated_ablated_minus_results[:,-1,:].argmax(dim=-1).cpu().numpy()
minus_neurons_ablated_ablated_minus_precision = precision_score(minus_dataset_labels, minus_neurons_ablated_ablated_minus_predictions, average='macro')


# These runs zero the operation-specific neurons themselves. The messages previously said
# "all but", which describes the later experiment that keeps only those neurons.
print(f'after ablating neurons only active in plus: plus precision {plus_neurons_ablated_ablated_plus_precision} minus precision {plus_neurons_ablated_ablated_minus_precision}')
print(f'after ablating neurons only active in minus: plus precision {minus_neurons_ablated_ablated_plus_precision} minus precision {minus_neurons_ablated_ablated_minus_precision}')

after ablating all but neurons only active in plus: plus precision 0.22997348393722042 minus precision 0.9917927362577618
after ablating all but neurons only active in minus: plus precision 0.9990725177699925 minus precision 0.3939233817539317


In [69]:
def all_but_plus_mlp_neuron_ablation_hook(value, hook):
    """Zero every one of the 512 neurons except those listed in ``plus_not_minus``.

    ``value`` is the hooked activation tensor, indexed by neuron on its last axis. It is
    modified in place and returned. ``hook`` is the hook point and is not used.
    """
    neurons_to_keep = set(plus_not_minus)
    neurons_to_ablate = [i for i in range(512) if i not in neurons_to_keep]
    value[:, :, neurons_to_ablate] = 0.
    return value

# Run both datasets keeping only the plus-only neurons active at the post-MLP hook point of block 0.
post_mlp_act_name = utils.get_act_name("post", 0, "mlp")
all_but_plus_neurons_ablated_ablated_plus_results = model.run_with_hooks(
    plus_dataset,
    fwd_hooks=[(post_mlp_act_name, all_but_plus_mlp_neuron_ablation_hook)],
)
all_but_plus_neurons_ablated_ablated_minus_results = model.run_with_hooks(
    minus_dataset,
    fwd_hooks=[(post_mlp_act_name, all_but_plus_mlp_neuron_ablation_hook)],
)

def all_but_minus_mlp_neuron_ablation_hook(value, hook):
    """Zero every one of the 512 neurons except those listed in ``minus_not_plus``.

    ``value`` is the hooked activation tensor, indexed by neuron on its last axis. It is
    modified in place and returned. ``hook`` is the hook point and is not used.
    """
    neurons_to_keep = set(minus_not_plus)
    neurons_to_ablate = [i for i in range(512) if i not in neurons_to_keep]
    value[:, :, neurons_to_ablate] = 0.
    return value

# Run both datasets keeping only the minus-only neurons active at the post-MLP hook point of block 0.
post_mlp_act_name = utils.get_act_name("post", 0, "mlp")
all_but_minus_neurons_ablated_ablated_plus_results = model.run_with_hooks(
    plus_dataset,
    fwd_hooks=[(post_mlp_act_name, all_but_minus_mlp_neuron_ablation_hook)],
)

all_but_minus_neurons_ablated_ablated_minus_results = model.run_with_hooks(
    minus_dataset,
    fwd_hooks=[(post_mlp_act_name, all_but_minus_mlp_neuron_ablation_hook)],
)



In [71]:
# Predictions are the argmax of the logits at the last position; precision is macro-averaged.
# With almost every neuron ablated, many classes are never predicted. zero_division=0 scores
# those classes as 0 explicitly, giving the same values as the default without the
# "Precision is ill-defined" warning on every call.
all_but_plus_neurons_ablated_ablated_plus_predictions = all_but_plus_neurons_ablated_ablated_plus_results[:,-1,:].argmax(dim=-1).cpu().numpy()
all_but_plus_neurons_ablated_ablated_plus_precision = precision_score(plus_dataset_labels, all_but_plus_neurons_ablated_ablated_plus_predictions, average='macro', zero_division=0)

all_but_plus_neurons_ablated_ablated_minus_predictions = all_but_plus_neurons_ablated_ablated_minus_results[:,-1,:].argmax(dim=-1).cpu().numpy()
all_but_plus_neurons_ablated_ablated_minus_precision = precision_score(minus_dataset_labels, all_but_plus_neurons_ablated_ablated_minus_predictions, average='macro', zero_division=0)

all_but_minus_neurons_ablated_ablated_plus_predictions = all_but_minus_neurons_ablated_ablated_plus_results[:,-1,:].argmax(dim=-1).cpu().numpy()
all_but_minus_neurons_ablated_ablated_plus_precision = precision_score(plus_dataset_labels, all_but_minus_neurons_ablated_ablated_plus_predictions, average='macro', zero_division=0)

all_but_minus_neurons_ablated_ablated_minus_predictions = all_but_minus_neurons_ablated_ablated_minus_results[:,-1,:].argmax(dim=-1).cpu().numpy()
all_but_minus_neurons_ablated_ablated_minus_precision = precision_score(minus_dataset_labels, all_but_minus_neurons_ablated_ablated_minus_predictions, average='macro', zero_division=0)


print(f'after ablating all but neurons only active in plus: plus precision {all_but_plus_neurons_ablated_ablated_plus_precision} minus precision {all_but_plus_neurons_ablated_ablated_minus_precision}')
print(f'after ablating all but neurons only active in minus: plus precision {all_but_minus_neurons_ablated_ablated_plus_precision} minus precision {all_but_minus_neurons_ablated_ablated_minus_precision}')

after ablating all but neurons only active in plus: plus precision 0.00806736360449042 minus precision 0.0005453249533054084
after ablating all but neurons only active in minus: plus precision 0.0011633032147423442 minus precision 0.014751908804845129



Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.


Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.


Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.


Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.



In [None]:
# TODO: do some recursive ablation studies

# Staged Training

In [None]:
# Test/train loss curves from the 'output_plus_and_minus_staged.pt' checkpoint.
# map_location keeps the load working on CPU-only machines, matching the load of
# 'output_plus_and_minus.pt' earlier in the notebook.
# NOTE: this rebinds `results`, so cells above that read `results` must not be re-run after it.
results = torch.load('output_plus_and_minus_staged.pt', map_location=torch.device('cpu'))
plt.plot(results['test_losses'])
plt.plot(results['train_losses'])
plt.yscale('log')
# Axis labels and a legend (in plotting order) so the two curves can be told apart.
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['test_loss', 'train_loss'])



In [None]:
# Per-operation test/train loss curves for the staged-training run.
for operation in results['operations_losses']:
    plt.plot(operation['test_losses'])
    plt.plot(operation['train_losses'])
plt.yscale('log')
# Axis labels and a legend so the curves can be told apart. Legend entries follow plotting order
# (test then train, per operation).
# NOTE(review): assumes the operations are stored as ('+', '-'), as in the earlier
# per-operation loss plot — confirm for the staged run.
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['+ test', '+ train', '- test', '- train'])