In [1]:
import re
import json
import pickle
import os
import sys
import torch
from transformer_lens import HookedTransformer
import plotly.io as pio
import pandas as pd
from itertools import chain
import plotly.graph_objects as go
import torch.nn.functional as F
import einops

pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

sys.path.append('../')  # Add the parent directory to the system path

import utils.haystack_utils as haystack_utils
from sparse_coding.train_autoencoder import AutoEncoder
from utils.autoencoder_utils import custom_forward, AutoEncoderConfig, evaluate_autoencoder_reconstruction, get_encoder_feature_frequencies, load_encoder, generate_with_encoder
import utils.haystack_utils as haystack_utils
from process_tiny_stories_data import load_tinystories_validation_prompts


%reload_ext autoreload
%autoreload 2

In [2]:
# Load global state
model_name = "tiny-stories-2L-33M"
model = HookedTransformer.from_pretrained(
    model_name,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device,
)

l0_encoder, l0_config = load_encoder('18_morning_sun', model_name, model)
l1_encoder, l1_config = load_encoder('8_deep_brook', model_name, model)
prompts = load_tinystories_validation_prompts()

Loaded pretrained model tiny-stories-2L-33M into HookedTransformer


In [3]:
# For single prompt: compare the derivative of the L1 feature wrt L0 mlp input to layer, to linear map (assuming GELU is identity)
torch.autograd.set_grad_enabled(True)
torch.set_grad_enabled(True)
for param in chain(model.parameters(), l0_encoder.parameters(), l1_encoder.parameters()):
    param.requires_grad = True

prompt = "He had to help Daisy. He was about"

def get_run_encoder_hook(encoder, config):
    def hook(value, hook):
        value = value.squeeze(0)
        _, reconstruct, acts, _, _ = encoder(value)
        hook.ctx['mlp_acts'] = value#.detach()
        hook.ctx["encoder_acts"] = acts#.detach()
        return reconstruct.unsqueeze(0)
    return [(f'blocks.{config.layer}.{config.act_name}', hook)]

def access_grad(value, hook):
    hook.ctx['grad'] = value[0, -1]
access_grads = [('blocks.0.mlp.hook_post', access_grad)]

with model.hooks(fwd_hooks=get_run_encoder_hook(l0_encoder, l0_config) + get_run_encoder_hook(l1_encoder, l1_config), bwd_hooks=access_grads):
    model(prompt)
    acts = model.hook_dict['blocks.1.mlp.hook_post'].ctx['encoder_acts']
    print('L1 trigram feature act:', acts[-1, 880])

    loss = -acts[-1, 880]
    loss.backward(retain_graph=True)
    grad = model.hook_dict['blocks.0.mlp.hook_post'].ctx['grad']

L1 trigram feature act: tensor(7.4160, device='cuda:0', grad_fn=<SelectBackward0>)


In [4]:
with torch.no_grad():
    acts = model.hook_dict['blocks.0.mlp.hook_post'].ctx['encoder_acts']
    active_features = torch.nonzero(acts[-1, :] > 0.1).squeeze(1)
    decoder_grads = (l0_encoder.W_dec @ grad).cpu()

    # Histogram of feature gradients, feature believed to contribute to trigram circuit marked in red
    fig = go.Figure(data=[go.Histogram(x=decoder_grads)])
    fig.add_trace(go.Scatter(x=[decoder_grads[7105], decoder_grads[7105]], y=[0, 500], mode='lines', name='Line', line=dict(color='red')))
    fig.update_layout(
        title='Gradients of decoder features in Layer 0 wrt feature 880 in Layer 1',
        xaxis_title='Gradient',
        yaxis_title='Feature count'
    )
    fig.show()

    # Scatter of feature gradients against cosine sims with layer 1 feature believed to contribute to trigram circuit
    W_out = model.W_out[0]
    W_in = model.W_in[1]
    cosine_sims = F.normalize(l0_encoder.W_dec[active_features] @ W_out, dim=-1) @ F.normalize(W_in @ l1_encoder.W_enc[:, 880], dim=0)
    # print(cosine_sims[7105])

    fig = go.Figure(data=go.Scatter(x=decoder_grads[active_features.cpu()], y=cosine_sims.cpu(), mode='markers'))
    fig.update_layout(
        title="Gradients of decoder features in Layer 0 wrt a Layer 1 feature against cosine similarities of same",
        xaxis_title='Gradient',
        yaxis_title='Cosine similarity')
    fig.show()

In [5]:
# Model wants loss to be lower and therefore the (+ve) L1 feature act to be higher. A negative gradient means you can do this by increasing the 
# L0 act, for arbitrary reasons that come down to the choice to perform gradient descent rather than ascent. 
# With gradient descent along one dimension on the x axis with loss on the y axis, a loss gradient that goes down and to the left 
# (increasing along the x axis) into a basin is negative and one that goes down into the basin from the right is positive. 
# So when we encounter a negative gradient we want to traverse along the x axis to the right, i.e. by increasing the activation value on the x axis. 
# This can correspond to subtracting something negative.

def change_l0_grad(value, hook):
    values, indices = torch.topk(value[0, -1], 10)
    value[0, -1, indices[0]] -= (grad[indices[0]] * 0.05)
    _, reconstruct, _, _, _ = l0_encoder(value.squeeze(0))
    return reconstruct.unsqueeze(0)
change_l0_grads = [(f'blocks.{l0_config.layer}.{l0_config.act_name}', change_l0_grad)]

with model.hooks(change_l0_grads + get_run_encoder_hook(l1_encoder, l1_config)):
    model(prompt)
    l1_feature_act = model.hook_dict['blocks.1.mlp.hook_post'].ctx['encoder_acts']
    print('L1 trigram feature act:', l1_feature_act[-1, 880])

L1 trigram feature act: tensor(7.4174, device='cuda:0', grad_fn=<SelectBackward0>)


In [6]:
def access_grad_all_pos(value, hook):
    hook.ctx['grad'] = value[0]
access_grads_all_pos = [('blocks.0.mlp.hook_post', access_grad_all_pos)]

# Do the same plots for way more data points by repeating the exercise for all the active features in turn
with model.hooks(fwd_hooks=get_run_encoder_hook(l0_encoder, l0_config) + get_run_encoder_hook(l1_encoder, l1_config)):
    model(prompt)
    acts = model.hook_dict['blocks.1.mlp.hook_post'].ctx['encoder_acts']
    #print(acts.shape, torch.argwhere(acts > 2).shape, torch.argwhere(acts > 2))
    # start with just final position
    active_indices = torch.argwhere(acts > 2)

all_grads = []
all_masks = []
with model.hooks(fwd_hooks=get_run_encoder_hook(l0_encoder, l0_config) + get_run_encoder_hook(l1_encoder, l1_config), bwd_hooks=access_grads_all_pos):
    model(prompt)
    acts = model.hook_dict['blocks.1.mlp.hook_post'].ctx['encoder_acts']

    for i, j in [row.tolist() for row in active_indices]:
        loss = -acts[i, j]
        loss.backward(retain_graph=True)
        grad = model.hook_dict['blocks.0.mlp.hook_post'].ctx['grad']
        l0_acts = model.hook_dict["blocks.0.mlp.hook_post"].ctx["encoder_acts"]
        l0_mask = l0_acts > 1
        print(l0_acts.shape, grad.shape)
        #print(grad.shape, acts.shape)
        all_grads.append(grad)
        all_masks.append(l0_mask)

print(len(all_grads), "active directions")

with torch.no_grad():
    # n_feature d_mlp @ n_active_features d_mlp
    print(l0_encoder.W_dec.shape, torch.stack(all_grads).shape)
    decoder_weights_masked = einops.repeat(l0_encoder.W_dec.cpu(), "d_enc d_mlp -> n_pos d_enc d_mlp", n_pos=acts.shape[0])[torch.stack(all_masks).cpu()]
    print(decoder_weights_masked.shape)
    decoder_grads = einops.einsum(l0_encoder.W_dec.cpu(), torch.stack(all_grads).cpu(), "d_enc d_mlp, n_feature n_pos d_mlp -> d_enc n_feature n_pos")
    #decoder_grads = l0_encoder.W_dec.cpu() @ torch.stack(all_grads).cpu().T.reshape(-1, model.cfg.d_mlp)
    decoder_grads = decoder_grads.flatten().cpu()

    # Scatter of feature gradients against cosine sims with layer 1 feature believed to contribute to trigram circuit
    W_out = model.W_out[0]
    W_in = model.W_in[1]
    cosine_sims = []
    for i, j in [row.tolist() for row in active_indices]:
        cosine_sims.append(F.normalize(l0_encoder.W_dec @ W_out, dim=-1) @ F.normalize(W_in @ l1_encoder.W_enc[:, j], dim=0))
    cosine_sims = torch.concat(cosine_sims).squeeze(-1).cpu()

    # fig = go.Figure(data=go.Scatter(x=decoder_grads, y=cosine_sims, mode='markers'))
    # fig.update_layout(
    #     title="Gradients of decoder features in Layer 0 wrt each active Layer 1 feature against cosine similarities of same",
    #     xaxis_title='Gradient',
    #     yaxis_title='Cosine similarity')
    # fig.show()

torch.Size([10, 16384]) torch.Size([10, 4096])
torch.Size([10, 16384]) torch.Size([10, 4096])
torch.Size([10, 16384]) torch.Size([10, 4096])
torch.Size([10, 16384]) torch.Size([10, 4096])
torch.Size([10, 16384]) torch.Size([10, 4096])
torch.Size([10, 16384]) torch.Size([10, 4096])
torch.Size([10, 16384]) torch.Size([10, 4096])
torch.Size([10, 16384]) torch.Size([10, 4096])
torch.Size([10, 16384]) torch.Size([10, 4096])
torch.Size([10, 16384]) torch.Size([10, 4096])
torch.Size([10, 16384]) torch.Size([10, 4096])
torch.Size([10, 16384]) torch.Size([10, 4096])
torch.Size([10, 16384]) torch.Size([10, 4096])
13 active directions
torch.Size([16384, 4096]) torch.Size([13, 10, 4096])


IndexError: The shape of the mask [13, 10, 16384] at index 0 does not match the shape of the indexed tensor [10, 16384, 4096] at index 0

In [65]:
# Get L0 active features
# Get L1 active features
# For each position
import torch.nn.functional as F

prompt = prompts[0]

def access_grad_all_pos(value, hook):
    hook.ctx['grad'] = value[0]
access_grads_all_pos = [('blocks.0.mlp.hook_post', access_grad_all_pos)]

def get_run_encoder_hook(encoder, config):
    def hook(value, hook):
        value = value.squeeze(0)
        _, reconstruct, acts, _, _ = encoder(value)
        hook.ctx["encoder_acts"] = acts
        return reconstruct.unsqueeze(0)
    return [(f'blocks.{config.layer}.{config.act_name}', hook)]

# Run first to avoid reconstructing through l0 encoder
with model.hooks(fwd_hooks=get_run_encoder_hook(l0_encoder, l0_config)):
    model(prompt)
    l0_acts = model.hook_dict["blocks.0.mlp.hook_post"].ctx["encoder_acts"]

all_grads = []
all_sims = []
with model.hooks(fwd_hooks=get_run_encoder_hook(l1_encoder, l1_config), bwd_hooks=access_grads_all_pos):
    model(prompt)
    l1_acts = model.hook_dict['blocks.1.mlp.hook_post'].ctx['encoder_acts']
    num_pos = l1_acts.shape[0]

    for pos in range(1, num_pos):
        active_l1_features = torch.argwhere(l1_acts[pos] > 1).flatten()
        active_l0_features = torch.argwhere(l0_acts[pos] > 1).flatten()
        for active_l1_feature in active_l1_features:
            # Calculate gradient w.r.t. to single active L1 feature
            loss = -l1_acts[pos, active_l1_feature]
            loss.backward(retain_graph=True)

            # Retrieve L0 gradients only for active L0 features
            grad = model.hook_dict['blocks.0.mlp.hook_post'].ctx['grad'][pos]
            active_decoder_weights = l0_encoder.W_dec[active_l0_features]
            decoder_grads = einops.einsum(active_decoder_weights, grad, "d_enc_active d_mlp, d_mlp -> d_enc_active")
            
            # Cosine sims between active L0 features and current active L1 feature
            decoder_w_out = active_decoder_weights @ model.W_out[0]
            encoder_w_in = model.W_in[1] @ l1_encoder.W_enc[:, active_l1_feature]
            cosine_sims = F.cosine_similarity(decoder_w_out, encoder_w_in.unsqueeze(0), dim=1).cpu()
            all_grads.append(decoder_grads.cpu().detach())
            all_sims.append(cosine_sims.detach())

all_grads = torch.cat(all_grads)
all_sims = torch.cat(all_sims)
print(all_grads.shape, all_sims.shape)

        

torch.Size([2435]) torch.Size([2435])


In [66]:
fig = go.Figure(data=go.Scatter(x=all_grads.abs(), y=all_sims.abs(), mode='markers'))
fig.update_layout(
    title="Comparison of absolute cosine sims and gradients for active feature pairs in L0 and L1",
    xaxis_title='Gradient',
    yaxis_title='Cosine similarity',
    width = 1000)
fig.show()