In [20]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from tqdm.auto import tqdm
import plotly.io as pio
import numpy as np
import random
import torch.nn as nn
import torch.nn.functional as F
import wandb
import plotly.express as px
import pandas as pd
import torch.nn.init as init
import pickle
import os
from pathlib import Path


# Use plotly's "notebook_connected" renderer as the default for figures.
pio.renderers.default = "notebook_connected"
# Pick the compute device: CUDA if available, else Apple MPS, else CPU.
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
# Gradient tracking stays enabled globally (both switches below are commented
# out); the evaluation cells wrap their forward passes in torch.no_grad().
#torch.autograd.set_grad_enabled(False)
#torch.set_grad_enabled(False)

import sys
sys.path.append('../')  # Add the parent directory to the system path
# Project-local modules; presumably resolved through the parent directory
# appended to sys.path above.
import utils.haystack_utils as haystack_utils
from sparse_coding.train_autoencoder import AutoEncoder

%reload_ext autoreload
%autoreload 2

In [22]:
def pickle_pt(name: str, path: Path, suffix: str = '_2'):
    """Re-save a torch checkpoint as a plain pickle file.

    Reads ``<path>/<name>.pt`` with ``torch.load`` and writes the loaded
    object to ``<path>/<name><suffix>.pkl`` with ``pickle.dump``.

    Args:
        name: Checkpoint file stem, without the ``.pt`` extension.
        path: Directory holding the checkpoint; the pickle is written next
            to it. A plain string is accepted as well as a ``Path``.
        suffix: Text appended to ``name`` in the output filename. The
            default ``'_2'`` produces ``<name>_2.pkl``.

    Raises:
        FileNotFoundError: If ``<path>/<name>.pt`` does not exist.
    """
    directory = Path(path)
    # Map every tensor to CPU before pickling: pickle.load offers no
    # map_location equivalent, so a pickle holding CUDA tensors could not be
    # opened on a machine without a GPU.
    checkpoint = torch.load(directory / f'{name}.pt', map_location='cpu')
    with open(directory / f'{name}{suffix}.pkl', 'wb') as f:
        pickle.dump(checkpoint, f)

# One-off conversion: reads pythia-70m/mlp.hook_post_l5.pt and writes
# pythia-70m/mlp.hook_post_l5_2.pkl, the file the "Load 70m dict" cell opens.
pickle_pt(name='mlp.hook_post_l5', path=Path('pythia-70m'))


In [23]:
# Evaluation prompts, read by the project helper from data/german_europarl.json.
german_data = haystack_utils.load_json_data("data/german_europarl.json")

# The language model every evaluation cell below runs: Pythia-70m loaded
# through TransformerLens with LayerNorm folding and weight/unembed centering
# requested at load time, placed on `device`.
model = HookedTransformer.from_pretrained("EleutherAI/pythia-70m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device)

data/german_europarl.json: Loaded 2000 examples with 152 to 2000 characters each.


Using pad_token, but it is not set yet.


In [14]:
# Load 70m dict
# Builds the autoencoder that is evaluated further down on
# "blocks.5.mlp.hook_post" and fills it from the pickled state dict.
d_in = model.cfg.d_mlp  # autoencoder input width = the model's d_mlp
expansion_factor = 4
autoencoder_dim = d_in * expansion_factor  # dictionary is 4x wider than its input
# NOTE(review): presumably these constructor arguments must match the values
# used when the checkpoint was trained — confirm.
l1_coeff = 0.01

autoencoder_70m = AutoEncoder(autoencoder_dim, l1_coeff, d_in)
autoencoder_70m_filename = "pythia-70m/mlp.hook_post_l5_2.pkl"
# SECURITY: pickle.load can execute arbitrary code while unpickling; only open
# pickle files that come from a trusted source.
with open(autoencoder_70m_filename, 'rb') as f:
    autoencoder_70m_state_dict = pickle.load(f)
autoencoder_70m.load_state_dict(autoencoder_70m_state_dict)
autoencoder_70m.to(device)

def evaluate_dict(autoencoder: AutoEncoder, encoded_hook_name: str, german_data: list, num_prompts: int = 200):
    """Print how much the model's loss rises when an activation is replaced by its reconstruction.

    Each prompt is run through the notebook-level ``model`` twice: once
    unmodified, and once with a forward hook on ``encoded_hook_name`` that
    swaps the activation for the autoencoder's reconstruction of it. The
    difference between the two mean losses is printed.

    Note:
        ``model`` is read from the notebook's global namespace, so the cell
        that loads it must have run first.

    Args:
        autoencoder: Called on the hooked activation; must return a 5-tuple
            whose second element is the reconstruction.
        encoded_hook_name: Name of the hook point whose activation is replaced.
        german_data: Prompts to evaluate on.
        num_prompts: How many prompts from the start of ``german_data`` are
            evaluated. Defaults to 200.

    Returns:
        None. The result is reported with ``print``.
    """
    def encode_activations_hook(value, hook):
        # Drop the leading dimension before calling the autoencoder and restore
        # it afterwards; assumes one prompt per forward pass — TODO confirm.
        value = value.squeeze(0)
        _, x_reconstruct, _, _, _ = autoencoder(value)
        return x_reconstruct.unsqueeze(0)

    hooks = [(encoded_hook_name, encode_activations_hook)]

    original_losses = []
    reconstruct_losses = []
    # Evaluation only: no gradients are needed for either forward pass.
    with torch.no_grad():
        for prompt in tqdm(german_data[:num_prompts]):
            original_loss = model(prompt, return_type="loss")
            with model.hooks(hooks):
                reconstruct_loss = model(prompt, return_type="loss")
            original_losses.append(original_loss.item())
            reconstruct_losses.append(reconstruct_loss.item())

    print(f"Average loss increase after encoding: {(np.mean(reconstruct_losses) - np.mean(original_losses)):.4f}")

# Score the 70m dictionary on the hook point it is applied to
# ("blocks.5.mlp.hook_post"), using the German prompts.
evaluate_dict(autoencoder_70m, "blocks.5.mlp.hook_post", german_data=german_data)

  0%|          | 0/200 [00:00<?, ?it/s]

Average loss increase after encoding: 0.8281


In [None]:
# Load our dictionary for "hook_mlp_out" at layer 8 and re-save it as a pickle.
# NOTE(review): the checkpoint lives under pythia-160m/, but `model` is loaded
# as EleutherAI/pythia-70m earlier in this notebook and d_in is read from it —
# confirm the intended model is in memory before running this cell.
d_in = model.cfg.d_model
expansion_factor = 4
autoencoder_dim = d_in * expansion_factor
l1_coeff = 0.01

our_autoencoder = AutoEncoder(autoencoder_dim, l1_coeff, d_in)
our_autoencoder_filename = "pythia-160m/hook_mlp_out_l8.pt"
# map_location="cpu" lets a checkpoint that was saved on a GPU load on any
# machine; the module is moved to `device` further down.
our_autoencoder.load_state_dict(torch.load(our_autoencoder_filename, map_location="cpu"))

# Pickle while the weights are still on the CPU so the file can be unpickled on
# machines without a GPU (pickle.load cannot remap devices).
# NOTE(review): this pickles the whole module, not a state dict — confirm that
# is what consumers of the .pkl expect.
with open("pythia-160m/hook_mlp_out_l8.pkl", "wb") as f:
    pickle.dump(our_autoencoder, f)

our_autoencoder.to(device)

In [None]:
# Evaluate our dict
# Uses evaluate_dict (defined alongside the 70m dictionary) instead of a copy of
# its loop: same hook, same first 200 prompts, same printed summary.
# NOTE(review): the hook point is in block 8, while `model` is loaded as
# EleutherAI/pythia-70m earlier in this notebook — confirm the intended model
# is in memory before running this cell.
evaluate_dict(our_autoencoder, "blocks.8.hook_mlp_out", german_data=german_data)

In [None]:
# Evaluate Logan's one
# Same loss comparison as the cells above, but this autoencoder is driven
# through encode()/decode() rather than a single forward call.
# NOTE(review): `autoencoder2` is not defined in any cell visible here — confirm
# it is loaded before this cell runs.
autoencoder2.to_device(device)
model.to(device)

# Evaluation only: no gradients are needed for either forward pass.
with torch.no_grad():
    def encode_mlp_activations_hook(value, hook):
        # Replace the hooked activation with its encode -> decode round trip.
        # The leading dimension is dropped first and restored at the end;
        # assumes one prompt per forward pass — TODO confirm.
        value = value.squeeze(0)
        acts = autoencoder2.encode(value)
        out = autoencoder2.decode(acts)
        return out.unsqueeze(0)

    # NOTE(review): block 8 is hooked while `model` is loaded as
    # EleutherAI/pythia-70m earlier in this notebook — confirm the intended
    # model is in memory.
    hooks = [("blocks.8.hook_mlp_out", encode_mlp_activations_hook)]

    original_losses = []
    reconstruct_losses = []
    # Only the first 200 prompts are scored.
    for prompt in tqdm(german_data[:200]):
        # Loss without intervention, then loss with the reconstruction hook active.
        original_loss = model(prompt, return_type="loss")
        with model.hooks(hooks):
            reconstruct_loss = model(prompt, return_type="loss")
        original_losses.append(original_loss.item())
        reconstruct_losses.append(reconstruct_loss.item())

# Difference of the two mean losses; positive means the reconstruction hurts.
print(f"Average loss increase after encoding: {(np.mean(reconstruct_losses) - np.mean(original_losses)):.4f}")