In [28]:
import re
import json
import pickle
import os
import sys
import torch
from transformer_lens import HookedTransformer
import plotly.io as pio
import pandas as pd
from itertools import chain
import plotly.graph_objects as go
import torch.nn.functional as F

# Use plotly's "notebook_connected" renderer for every fig.show() call in this notebook.
# NOTE(review): per the plotly docs this renderer pulls plotly.js from a CDN, so figures
# presumably need network access to display — confirm for offline use.
pio.renderers.default = "notebook_connected"
# Pick the compute backend in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Autograd is switched off by default for the whole notebook; cells that need
# gradients turn it back on explicitly. Both calls below toggle the same global flag.
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

sys.path.append('../')  # Add the parent directory to the system path

import utils.haystack_utils as haystack_utils
from sparse_coding.train_autoencoder import AutoEncoder
from utils.autoencoder_utils import custom_forward, AutoEncoderConfig, evaluate_autoencoder_reconstruction, get_encoder_feature_frequencies, load_encoder, generate_with_encoder
import utils.haystack_utils as haystack_utils
from process_tiny_stories_data import load_tinystories_validation_prompts


%reload_ext autoreload
%autoreload 2

In [3]:
# Load global state: the base model, the layer-0 and layer-1 encoders, and the prompts.
model_name = "tiny-stories-2L-33M"

# Options forwarded to HookedTransformer.from_pretrained when loading the checkpoint.
pretrained_kwargs = dict(
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device,
)
model = HookedTransformer.from_pretrained(model_name, **pretrained_kwargs)

# load_encoder hands back an (encoder, config) pair for the named run.
l0_encoder, l0_config = load_encoder('18_morning_sun', model_name, model)
l1_encoder, l1_config = load_encoder('8_deep_brook', model_name, model)

prompts = load_tinystories_validation_prompts()

In [50]:
# For a single prompt: compare the derivative of the L1 feature wrt the L0 MLP activations
# against the linear map between the two layers (i.e. assuming GELU is the identity).

# Autograd was disabled globally in the setup cell; turn it back on for this analysis.
torch.autograd.set_grad_enabled(True)
torch.set_grad_enabled(True)

# Flag every model and encoder parameter as requiring grad so the forward pass records a graph.
for param in chain(model.parameters(), l0_encoder.parameters(), l1_encoder.parameters()):
    param.requires_grad_(True)

prompt = "He had to help Daisy. He was about"

def get_run_encoder_hook(encoder, config):
    """Build a forward hook that swaps an activation for ``encoder``'s reconstruction.

    Args:
        encoder: Callable returning a 5-tuple; element 1 is used as the reconstruction
            and element 2 as the feature activations.
        config: Object exposing ``layer`` and ``act_name``; together they name the
            hook point ``blocks.{layer}.{act_name}``.

    Returns:
        A one-element list of ``(hook_name, hook_fn)``, ready to be concatenated with
        other hook lists.

    Note:
        The hook strips axis 0 with ``squeeze(0)`` and restores it with ``unsqueeze(0)``,
        so it is written for a leading (batch) axis of size 1.
    """
    hook_name = f'blocks.{config.layer}.{config.act_name}'

    def splice_in_reconstruction(value, hook):
        unbatched = value.squeeze(0)
        _0, reconstruction, feature_acts, _3, _4 = encoder(unbatched)
        # Leave the feature activations on the hook point so callers can read them
        # back after the forward pass.
        hook.ctx['acts'] = feature_acts
        return reconstruction.unsqueeze(0)

    return [(hook_name, splice_in_reconstruction)]

def access_grad(value, hook):
    """Backward hook: record the gradient at batch index 0, last position.

    The slice is stored in ``hook.ctx['grad']``. Nothing is returned, so the
    gradient flowing through the hook point is left as it is.
    """
    final_position_grad = value[0, -1]
    hook.ctx['grad'] = final_position_grad


# Attach the recorder to layer 0's MLP post-activation hook point.
access_grads = [('blocks.0.mlp.hook_post', access_grad)]

# Run the prompt with both encoders spliced into the forward pass and a backward hook
# recording the gradient that reaches layer 0's MLP activations.
encoder_hooks = get_run_encoder_hook(l0_encoder, l0_config) + get_run_encoder_hook(l1_encoder, l1_config)
with model.hooks(fwd_hooks=encoder_hooks, bwd_hooks=access_grads):
    model(prompt)
    l1_hook_point = model.hook_dict['blocks.1.mlp.hook_post']
    acts = l1_hook_point.ctx['acts']
    print('L1 trigram feature act:', acts[-1, 880])

    # Negate the activation so that descending on the loss raises the feature.
    loss = -acts[-1, 880]
    loss.backward()
    # Populated by access_grad during the backward pass above.
    grad = model.hook_dict['blocks.0.mlp.hook_post'].ctx['grad']

L1 trigram feature act: tensor(7.4160, device='cuda:0', grad_fn=<SelectBackward0>)


In [39]:
with torch.no_grad():
    # One value per L0 decoder row: the captured gradient projected onto each decoder direction.
    # NOTE(review): presumably this equals d(loss)/d(feature act) if the encoder reconstructs
    # as acts @ W_dec (+ bias) — confirm against AutoEncoder.
    decoder_grads = (l0_encoder.W_dec @ grad).cpu()
    marked_grad = decoder_grads[7105]

    # Histogram of feature gradients, feature believed to contribute to trigram circuit marked in red
    fig = go.Figure()
    fig.add_trace(go.Histogram(x=decoder_grads))
    fig.add_trace(
        go.Scatter(
            x=[marked_grad, marked_grad],
            y=[0, 500],
            mode='lines',
            name='Line',
            line=dict(color='red'),
        )
    )
    fig.update_layout(
        title='Gradients of decoder features in Layer 0 wrt feature 880 in Layer 1',
        xaxis_title='Gradient',
        yaxis_title='Feature count',
    )
    fig.show()

    # Scatter of feature gradients against cosine sims with layer 1 feature believed to contribute to trigram circuit
    W_out = model.W_out[0]
    W_in = model.W_in[1]
    # Each L0 decoder row pushed through layer 0's W_out, unit-normalised per row.
    l0_directions = F.normalize(l0_encoder.W_dec @ W_out, dim=-1)
    # Encoder column 880 pulled back through layer 1's W_in, unit-normalised.
    l1_direction = F.normalize(W_in @ l1_encoder.W_enc[:, 880], dim=0)
    cosine_sims = l0_directions @ l1_direction

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=decoder_grads, y=cosine_sims.cpu(), mode='markers'))
    fig.update_layout(
        title="Gradients of decoder features in Layer 0 wrt a Layer 1 feature against cosine similarities of same",
        xaxis_title='Gradient',
        yaxis_title='Cosine similarity',
    )
    fig.show()

In [42]:
# The loss is the negated L1 feature act, so lowering the loss means raising the (+ve) L1 feature act.
# A negative gradient of the loss wrt an L0 act means the loss falls as that L0 act increases, so we want to increase the L0 act.
# The sign convention is arbitrary: it comes down to the choice to perform gradient descent on a loss rather than gradient ascent on the act.
# Picture gradient descent along one dimension, with the activation on the x axis and the loss on the y axis: a slope that descends into a
# basin from the left (loss falling as x increases) has a negative gradient, and one that descends into the basin from the right has a positive gradient.
# So when we encounter a negative gradient we want to move right along the x axis, i.e. increase the activation value.
# The descent update (act -= step * grad) does exactly this: it subtracts something negative.

def change_l0_grad(value, hook, step_size=0.05):
    """Forward hook: take one gradient-descent step on the largest final-position activation.

    The largest entry of ``value[0, -1]`` is moved by ``-step_size * grad[index]``,
    where ``grad`` is the module-level gradient vector captured by the single-feature
    backward pass. The nudged activation is then passed through ``l0_encoder`` and the
    encoder's reconstruction is returned in its place. Note that the index is chosen
    on the raw hooked activation, not on the encoder's feature activations.

    Args:
        value: Hooked activation; indexed as ``[batch, position, unit]`` with a leading
            axis of size 1 (it is squeezed before being handed to the encoder).
        hook: The hook point (unused; required by the hook calling convention).
        step_size: Multiplier applied to the gradient entry. Defaults to the original
            hard-coded 0.05.

    Returns:
        The encoder reconstruction of the nudged activation, with the leading axis restored.
    """
    # Work on a copy: `value` is the live activation owned by the model's forward pass,
    # and editing it in place can trip autograd's in-place checks when backward hooks
    # are attached and silently alters the tensor for anything else holding a reference.
    nudged = value.clone()
    # Only the single largest entry is adjusted, so argmax is enough (the previous
    # topk(..., 10) also failed on inputs with fewer than 10 entries).
    top_index = torch.argmax(nudged[0, -1])
    nudged[0, -1, top_index] -= grad[top_index] * step_size
    _, reconstruct, _, _, _ = l0_encoder(nudged.squeeze(0))
    return reconstruct.unsqueeze(0)
# Replace the layer-0 encoder hook with the nudging variant; layer 1 keeps the plain encoder hook.
l0_hook_name = f'blocks.{l0_config.layer}.{l0_config.act_name}'
change_l0_grads = [(l0_hook_name, change_l0_grad)]

nudged_run_hooks = change_l0_grads + get_run_encoder_hook(l1_encoder, l1_config)
with model.hooks(nudged_run_hooks):
    model(prompt)
    # Feature activations left on the hook point by the layer-1 encoder hook.
    l1_feature_act = model.hook_dict['blocks.1.mlp.hook_post'].ctx['acts']
    print('L1 trigram feature act:', l1_feature_act[-1, 880])

L1 trigram feature act: tensor(7.4174, device='cuda:0', grad_fn=<SelectBackward0>)


In [66]:
# Do the same plots for way more data points by repeating the exercise for all the active features in turn.
#
# A single forward pass is enough: the graph is kept alive with retain_graph=True and
# back-propagated once per active Layer 1 feature.
encoder_hooks = get_run_encoder_hook(l0_encoder, l0_config) + get_run_encoder_hook(l1_encoder, l1_config)

all_grads = []
with model.hooks(fwd_hooks=encoder_hooks, bwd_hooks=access_grads):
    model(prompt)
    acts = model.hook_dict['blocks.1.mlp.hook_post'].ctx['acts']  # indexed [position, feature]

    # Start with just the final position. The backward hook only records the gradient at
    # the final position, and (under causal attention) features at earlier positions cannot
    # depend on the final position's layer-0 MLP output, so including them would only add
    # all-zero gradient rows to the plot.
    final_pos = acts.shape[0] - 1
    active_features = torch.argwhere(acts[final_pos].detach() > 5).flatten()
    # Keep the (position, feature) row layout so downstream code can unpack `i, j`.
    active_indices = torch.stack(
        [torch.full_like(active_features, final_pos), active_features], dim=1
    )

    for i, j in active_indices.tolist():
        loss = -acts[i, j]
        loss.backward(retain_graph=True)
        # Use a separate name so the single-feature `grad` read by other cells is not
        # clobbered, and copy the slice so each entry is an independent snapshot.
        feature_grad = model.hook_dict['blocks.0.mlp.hook_post'].ctx['grad']
        all_grads.append(feature_grad.detach().clone())

print(len(all_grads), "active directions")

with torch.no_grad():
    if not all_grads:
        # torch.stack would fail on an empty list with an opaque error; say what happened instead.
        print("No active Layer 1 features; nothing to plot")
    else:
        # [n_active, d_mlp] @ [d_mlp, n_feature] -> [n_active, n_feature].
        # Flattening this layout is active-feature-major, which matches the order in which
        # cosine_sims is concatenated below. (The previous [n_feature, n_active] layout
        # flattened feature-major, pairing each gradient with the wrong cosine similarity.)
        decoder_grads = torch.stack(all_grads).cpu() @ l0_encoder.W_dec.cpu().T
        decoder_grads = decoder_grads.flatten()

        # Scatter of feature gradients against cosine sims with each active layer 1 feature
        W_out = model.W_out[0]
        W_in = model.W_in[1]
        # Loop-invariant: every L0 decoder row pushed through W_out and unit-normalised.
        l0_directions = F.normalize(l0_encoder.W_dec @ W_out, dim=-1)
        cosine_sims = []
        for i, j in active_indices.tolist():
            l1_direction = F.normalize(W_in @ l1_encoder.W_enc[:, j], dim=0)
            cosine_sims.append(l0_directions @ l1_direction)
        cosine_sims = torch.concat(cosine_sims).cpu()

        # Both series must line up point-for-point.
        assert decoder_grads.shape == cosine_sims.shape, (decoder_grads.shape, cosine_sims.shape)

        fig = go.Figure(data=go.Scatter(x=decoder_grads, y=cosine_sims, mode='markers'))
        fig.update_layout(
            title="Gradients of decoder features in Layer 0 wrt each active Layer 1 feature against cosine similarities of same",
            xaxis_title='Gradient',
            yaxis_title='Cosine similarity')
        fig.show()

5
