# Applying Computation in Superposition to transcoders

# Setup

In [None]:
import sys
import os

# Put the absolute path of '../transcoder_circuits/' at the front of sys.path;
# presumably this is what lets the project-local imports in later cells resolve.
sys.path.insert(0, os.path.abspath('../transcoder_circuits/'))

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from transcoder_circuits.circuit_analysis import *
from transcoder_circuits.feature_dashboards import *
from transcoder_circuits.replacement_ctx import *

from sae_training.sparse_autoencoder import SparseAutoencoder
from utils import tokenize_and_concatenate

In [None]:
import torch
import numpy as np
from einops import *
import matplotlib.pyplot as plt
import plotly.express as px
import pandas as pd

# Pick the compute device: CUDA when present, otherwise Apple MPS, otherwise CPU.
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

## Load model & data

In [None]:
from transformer_lens import HookedTransformer, utils
# Load the pretrained 'gpt2-small' HookedTransformer and move it to the selected device.
model = HookedTransformer.from_pretrained('gpt2-small').to(device)

# Load transcoder

In [None]:
transcoder_template = "../dufensky_transcoders/final_sparse_autoencoder_gpt2-small_blocks.{}.ln2.hook_normalized_24576"
transcoders, sparsities = [], []
# For each layer index 0..11, load the transcoder checkpoint (put in eval mode)
# and the matching log-feature-sparsity tensor.
for i in range(12):
    layer_stem = transcoder_template.format(i)
    layer_transcoder = SparseAutoencoder.load_from_pretrained(f"{layer_stem}.pt").eval()
    transcoders.append(layer_transcoder)
    sparsities.append(torch.load(f"{layer_stem}_log_feature_sparsity.pt"))

In [None]:
# Load the single transcoder analysed in the rest of the notebook, together
# with its per-feature log-sparsity tensor.
transcoder_layer = 8
transcoder_template = "../dufensky_transcoders/final_sparse_autoencoder_gpt2-small_blocks.{}.ln2.hook_normalized_24576"
transcoder_stem = transcoder_template.format(transcoder_layer)
transcoder = SparseAutoencoder.load_from_pretrained(f"{transcoder_stem}.pt").to(device).eval()
# map_location='cpu': the tensor is restored on CPU regardless of the device it
# was saved from, so loading does not depend on that device being available and
# later cells can hand `sparsity` straight to matplotlib / CPU index tensors.
sparsity = torch.load(f"{transcoder_stem}_log_feature_sparsity.pt", map_location='cpu')
# Indices of features whose log-sparsity exceeds -4.
live_features = np.arange(len(sparsity))[utils.to_numpy(sparsity > -4)]

# Find transcoder's intermediate vectors
For encoder directions $\mathbf{e}_i$ with activations $z_i(x)$, we can find directions $\mathbf{d}_i$ such that $Wx \approx \sum_i z_i(x) \mathbf{d}_i$.

In [None]:
import gc

# Collect unreachable Python objects first so the tensors they hold can be released.
gc.collect()
# Only touch the allocator cache of a backend that is actually present: the
# notebook chooses its device dynamically, and torch.mps.empty_cache() can
# raise on machines without the MPS backend.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
if torch.backends.mps.is_available():
    torch.mps.empty_cache()

In [None]:
from midcoder import MidcoderConfig, Midcoder

# Start from the default MidcoderConfig and override the fields listed below,
# then build a Midcoder for `transcoder_layer` from the model and transcoder.
config = MidcoderConfig()
training_overrides = {
    'device': device,
    'batch_size': 64,
    'steps_per_epoch': 10,
    'train_tokens': 100_000,
    'log': False,
}
for option, value in training_overrides.items():
    setattr(config, option, value)
midcoder = Midcoder({'model': model, 'transcoder': transcoder}, transcoder_layer, config)

In [None]:
# Run the midcoder's fit() using the configuration it was constructed with.
midcoder.fit()

# Compare decoder features to predicted values

In [None]:
from midcoder import MidcoderConfig, Midcoder

# Rebuild the midcoder with mid_dim='d_mlp' and restore saved weights from disk.
config = MidcoderConfig()
# Use the device selected during setup; the previous hard-coded 'mps' only
# works on Apple-silicon machines.
config.device = device
config.mid_dim = 'd_mlp'
midcoder = Midcoder({'model':model, 'transcoder':transcoder}, transcoder_layer, config)

path = '../data/midcoder_layer8_v4_200M_tokens.pt'
# NOTE(review): strict=False presumably tolerates missing/unexpected keys in the
# checkpoint — confirm against Midcoder.load_weights.
midcoder.load_weights(path, strict=False)

In [None]:
# CPU copy of the midcoder's feat_ids; later cells use it to index `sparsity`
# and the midcoder/transcoder weight matrices.
feat_ids = midcoder.feat_ids.cpu()

In [None]:
# Boolean mask over the midcoder's features: log-sparsity above -4.
sparsity_mask = sparsity[feat_ids] > -4

# Rank plot of all log feature sparsities in ascending order.
x = np.arange(len(sparsity))
y = torch.sort(sparsity).values
plt.plot(x, y)
plt.title('Feature sparsity')
plt.xlabel('rank')
plt.ylabel('Log freq')

In [None]:
# Per-feature statistics restricted to the features selected by sparsity_mask.
live_act_counts = midcoder.act_counts[sparsity_mask].unsqueeze(1)
act_gate_counts = midcoder.gate_act_counts[sparsity_mask]
# act_gate_sum divided by the number of activations of each feature.
act_gate_mean = (midcoder.act_gate_sum[sparsity_mask] / live_act_counts).detach()


# Histogram of the elementwise averages for the first 500 selected features.
plt.figure(figsize=(5, 4), dpi=150)
first_500 = act_gate_mean[:500].flatten().cpu()
plt.hist(first_500, 100)
plt.yscale('log')
plt.title('Distribution of activation*gate averages')
plt.xlabel('Elementwise E[act * gate | act > 0]')
plt.ylabel('Counts')

In [None]:
# Midcoder rows for the selected features, detached from autograd.
mid_vector = midcoder.W_mid[feat_ids][sparsity_mask].detach()
# Detached like every other tensor read from the midcoder in this cell.
b_mid = midcoder.b_mid.detach()

# gate_act_counts divided by the number of activations of each feature.
gate_mean = midcoder.gate_act_counts[sparsity_mask].detach() / midcoder.act_counts[sparsity_mask].unsqueeze(1)

pre_relu_mean = midcoder.pre_relu_sum.detach() / (midcoder.counts)
# mlp_out_mean = midcoder.mlp_out_sum.detach() / (midcoder.counts)
# b_out = midcoder.b_dec_out.detach()
# b_mid_out = midcoder.b_mid_out.detach()

# Predicted decoder direction: gate-weighted mid vector pushed through W_out.
# W_out is detached so pred_decoder (and the cosine similarities derived from
# it) carries no autograd graph and can be converted to NumPy for plotting.
pred_decoder = (act_gate_mean * mid_vector) @ midcoder.W_out.detach()
decoder = midcoder.W_dec[feat_ids][sparsity_mask].detach()



In [None]:
# Row-wise cosine similarity between the real and the predicted decoder vectors.
cos_sims = torch.nn.functional.cosine_similarity(decoder, pred_decoder, dim=1)

# Fraction of tokens on which each selected feature was active.
freq = midcoder.act_counts[sparsity_mask] / midcoder.counts

cos_sims_cpu = cos_sims.cpu()

plt.figure(figsize=(8,3.5), dpi=150, layout='tight')

# Left panel: histogram of the cosine similarities.
plt.subplot(1,2,1)
plt.hist(cos_sims_cpu, 500)
plt.yscale('log')
plt.xlabel('Cosine Similarity')
plt.ylabel('Counts')
plt.suptitle('Comparison of decoder and predicted decoder')

# Right panel: cosine similarity against activation frequency (log x-axis).
plt.subplot(1,2,2)
plt.plot(freq.cpu(), cos_sims_cpu, '.', alpha=0.1)
plt.xlabel('Activation Frequency')
plt.ylabel('Cosine Similarity')
plt.xscale('log')

In [None]:
# High-frequency features: log-sparsity above -2 among the selected features.
mask = sparsity[feat_ids][sparsity_mask] > -2

# Apply the boolean mask once instead of on every loop iteration.
dec_sel = decoder[mask]
pred_sel = pred_decoder[mask]
sims_sel = cos_sims[mask]
counts_sel = act_gate_counts[mask]

plt.figure(figsize=(9,3.5), dpi=150, layout='tight')
plt.subplot(1,2,1)
sims = []

# Show at most 10 features, and never index past what the mask selected.
idxs = range(min(10, len(dec_sel)))
for k in idxs:
    # Unit-normalise both vectors so they can be compared elementwise.
    dec = dec_sel[k] / dec_sel[k].norm()
    pred = pred_sel[k] / pred_sel[k].norm()
    sim_label = f'{sims_sel[k].item():.2f}'
    sims.append(sim_label)
    plt.plot(dec.cpu(), pred.cpu(),'.', label=sim_label)
# Dashed y = x reference line.
vec = np.arange(-0.15, 0.15, 0.01)
plt.plot(vec, vec, 'k--')
plt.legend(title='Cos Sims', prop={'size':7})
plt.ylabel('Predicted decoder')
plt.xlabel('Decoder')
plt.title('Elementwise comparison')
plt.xlim((-0.15,0.15))
plt.ylim((-0.15, 0.15))


plt.subplot(1,2,2)
for k in idxs:
    # +0.1 keeps log10 finite where a count is zero.
    log_counts = torch.log10(counts_sel[k] + 1e-1)
    plt.hist(log_counts.cpu(), 50, histtype='step')
plt.yscale('log')
plt.ylabel('')
plt.xlabel('Log10( Counts for Act > 0 and Gate = 1)', fontsize=12)
plt.title('Distribution of counts')

plt.suptitle('High freq features')
print(sims)

In [None]:
# Medium-frequency features: log-sparsity strictly between -3 and -2.5.
live_sparsity = sparsity[feat_ids][sparsity_mask]
mask = (live_sparsity > -3) & (live_sparsity < -2.5)

# Apply the boolean mask once instead of on every loop iteration.
dec_sel = decoder[mask]
pred_sel = pred_decoder[mask]
sims_sel = cos_sims[mask]
counts_sel = act_gate_counts[mask]

plt.figure(figsize=(9,3.5), dpi=150, layout='tight')
plt.subplot(1,2,1)
sims = []

# Show at most 10 features, and never index past what the mask selected.
idxs = range(min(10, len(dec_sel)))
for k in idxs:
    # Unit-normalise both vectors so they can be compared elementwise.
    dec = dec_sel[k] / dec_sel[k].norm()
    pred = pred_sel[k] / pred_sel[k].norm()
    sim_label = f'{sims_sel[k].item():.2f}'
    sims.append(sim_label)
    plt.plot(dec.cpu(), pred.cpu(),'.', label=sim_label)
# Dashed y = x reference line.
vec = np.arange(-0.15, 0.15, 0.01)
plt.plot(vec, vec, 'k--')
plt.legend(title='Cos Sims', prop={'size':7})
plt.ylabel('Predicted decoder')
plt.xlabel('Decoder')
plt.title('Elementwise comparison')
plt.xlim((-0.15,0.15))
plt.ylim((-0.15, 0.15))


plt.subplot(1,2,2)
for k in idxs:
    # +0.1 keeps log10 finite where a count is zero.
    log_counts = torch.log10(counts_sel[k] + 1e-1)
    plt.hist(log_counts.cpu(), 50, histtype='step')
plt.yscale('log')
plt.ylabel('')
plt.xlabel('Log10( Counts for Act > 0 and Gate = 1)', fontsize=12)
plt.title('Distribution of counts')

plt.suptitle('Med freq features')

print(sims)

In [None]:
# Features whose predicted decoder matches the real one poorly.
mask = cos_sims < 0.7

# Apply the boolean mask once instead of on every loop iteration.
dec_sel = decoder[mask]
pred_sel = pred_decoder[mask]
sims_sel = cos_sims[mask]
counts_sel = act_gate_counts[mask]

plt.figure(figsize=(8,3.5), dpi=150, layout='tight')
plt.subplot(1,2,1)
sims = []

# Show at most 8 features; there may be fewer below the threshold.
idxs = range(min(8, len(dec_sel)))
for k in idxs:
    # Unit-normalise both vectors so they can be compared elementwise.
    dec = dec_sel[k] / dec_sel[k].norm()
    pred = pred_sel[k] / pred_sel[k].norm()
    sim_label = f'{sims_sel[k].item():.2f}'
    sims.append(sim_label)
    plt.plot(dec.cpu(), pred.cpu(),'.', label=sim_label)
# Dashed y = x reference line.
vec = np.arange(-0.15, 0.15, 0.01)
plt.plot(vec, vec, 'k--')
plt.legend(title='Cos Sims', prop={'size':7})


plt.subplot(1,2,2)
for k in idxs:
    # +0.1 keeps log10 finite where a count is zero.
    log_counts = torch.log10(counts_sel[k] + 1e-1)
    plt.hist(log_counts.cpu(), 50, histtype='step')
plt.yscale('log')

plt.suptitle('Features with low cosine similarity')

print(sims)

In [None]:
# Display the shape of sing_prob.
# NOTE(review): in file order sing_prob is first assigned in a later cell, so
# this cell only works after that cell has been run.
sing_prob.shape

In [None]:
# Display the first 30 entries of row 0 of prob_corr_log.
# NOTE(review): prob_corr_log is not assigned anywhere in the visible notebook;
# this cell raises NameError unless it is defined elsewhere.
prob_corr_log[0][:30]

In [None]:
# Largest correlations along the last dimension of prob_corr. topk() requires
# `k`; 3 matches the k used by the x.topk(...) inspection cell below.
# NOTE(review): confirm the intended k.
prob_corr.topk(k=3)

In [None]:
# Pairwise activation probabilities divided by the (column-vector) single
# activation probabilities, with the identity subtracted.
x = pair_prob / sing_prob
# Build the identity on x's own device instead of a hard-coded 'mps'.
x = x - torch.eye(*x.shape, device=x.device)
px.imshow(x.cpu(), color_continuous_scale = 'RdBu', color_continuous_midpoint = 0)

In [None]:
# For rows 250..269 of x, show the column indices of the 3 largest entries in each row.
x.topk(dim=1, k=3).indices[250:270]

In [None]:
# Single-feature activation probabilities (as a column vector) and pairwise
# co-activation probabilities, both normalised by midcoder.counts.
sing_prob = (midcoder.act_counts / midcoder.counts).unsqueeze(1)
pair_prob = midcoder.act_cov_counts / midcoder.counts

# Keep the selected features and order them by descending log-sparsity.
sort_idxs = sparsity[feat_ids][sparsity_mask].argsort(descending=True)
sing_prob = sing_prob[sparsity_mask][sort_idxs]
pair_prob = pair_prob[sparsity_mask,:][:,sparsity_mask][sort_idxs,:][:,sort_idxs]

# Bernoulli variance p(1 - p) of each activation indicator.
sing_prob_var = sing_prob * (1 - sing_prob)
# Correlation of the activation indicators:
#   (p_ij - p_i p_j) / sqrt(var_i * var_j).
# The denominator needs the variance of BOTH features; the previous
# `sing_prob_var @ sing_prob.T` used var_i * p_j, which is asymmetric.
prob_corr = (pair_prob - sing_prob @ sing_prob.T) / torch.sqrt(sing_prob_var @ sing_prob_var.T)
px.imshow(prob_corr.cpu(), color_continuous_scale = 'RdBu', color_continuous_midpoint = 0, range_color=[-0.3,0.3])

In [None]:
# Row-wise cosine similarity between the real and the predicted decoder vectors.
cos_sims = torch.nn.functional.cosine_similarity(decoder, pred_decoder, dim=1)

pretrain_counts = midcoder.act_counts[sparsity_mask]

plt.figure(figsize=(8,3.5), dpi=150, layout='tight')
plt.subplot(1,2,1)
plt.hist(cos_sims.cpu(), 500)
plt.yscale('log')
plt.xlabel('Cosine Similarity')
plt.ylabel('Counts')
plt.suptitle('Comparison of decoder and predicted decoder')

plt.subplot(1,2,2)
# x values are positions in the sort_idxs ordering (0, 1, 2, ...), not counts,
# so the axis is labelled as a rank; only the first 400 positions are shown.
plt.plot(torch.arange(len(cos_sims)), cos_sims[sort_idxs].cpu(), '.', alpha=0.1)
plt.xlim((-1,400))
plt.xlabel('Feature rank')
plt.ylabel('Cosine Similarity')

In [None]:
# Single-feature activation probability against position in the sorted order (log y-axis).
feature_rank = torch.arange(len(sing_prob))
first_column = sing_prob[:, 0].cpu()
plt.plot(feature_rank, first_column, '.')
plt.yscale('log')

In [None]:
# Compare decoder directions of high-frequency features (mask1) with up to
# `num` lower-frequency features (mask2).
mask1 = sparsity[feat_ids][sparsity_mask] > -2
mask2 = (sparsity[feat_ids][sparsity_mask] > -3.5) & (sparsity[feat_ids][sparsity_mask] < -3)
num = 100

# Boolean-mask indexing returns copies, so the in-place normalisation below
# does not modify `decoder`.
dec1 = decoder[mask1]
dec1 /= dec1.norm(dim=1, keepdim=True)
dec2 = decoder[mask2][:num]
dec2 /= dec2.norm(dim=1, keepdim=True)

# Pairwise cosine similarities (rows are unit vectors).
sims11 = einsum(dec1, dec1, "f1 d, f2 d -> f1 f2")
sims12 = einsum(dec1, dec2, "f1 d, f2 d -> f1 f2")
sims22 = einsum(dec2, dec2, "f1 d, f2 d -> f1 f2")

# Remove the self-similarity diagonal; the identity is created on the same
# device as the similarity matrix instead of a hard-coded 'mps'.
sims11_nodiag = sims11 - torch.eye(*sims11.shape, device=sims11.device)
sims22_nodiag = sims22 - torch.eye(*sims22.shape, device=sims22.device)
print(sims11_nodiag.max(dim=0))
print(sims12.max(dim=0))
print(sims22_nodiag.max(dim=0))

In [None]:
# Same overlap analysis as for the real decoders, applied to the predicted
# decoder directions.
mask1 = sparsity[feat_ids][sparsity_mask] > -2
mask2 = (sparsity[feat_ids][sparsity_mask] > -3.5) & (sparsity[feat_ids][sparsity_mask] < -3)
num = 100

# Boolean-mask indexing returns copies, so the in-place normalisation below
# does not modify `pred_decoder`.
dec1 = pred_decoder[mask1]
dec1 /= dec1.norm(dim=1, keepdim=True)
dec2 = pred_decoder[mask2][:num]
dec2 /= dec2.norm(dim=1, keepdim=True)

# Pairwise cosine similarities (rows are unit vectors).
sims11 = einsum(dec1, dec1, "f1 d, f2 d -> f1 f2")
sims12 = einsum(dec1, dec2, "f1 d, f2 d -> f1 f2")
sims22 = einsum(dec2, dec2, "f1 d, f2 d -> f1 f2")

# Remove the self-similarity diagonal; the identity is created on the same
# device as the similarity matrix instead of a hard-coded 'mps'.
sims11_nodiag = sims11 - torch.eye(*sims11.shape, device=sims11.device)
sims22_nodiag = sims22 - torch.eye(*sims22.shape, device=sims22.device)
print(sims11_nodiag.max(dim=0))
print(sims12.max(dim=0))
print(sims22_nodiag.max(dim=0))

# Decoder directions from gaussian noise

In [None]:
feat_ids = midcoder.feat_ids.cpu()
# Features with log-sparsity above -3.5.
sparsity_mask = sparsity[feat_ids] > -3.5
# Activation frequency of every midcoder feature, sorted from most to least frequent;
# sort_idxs holds the permutation that produces that order.
raw_act_freqs = midcoder.act_counts.detach() / midcoder.counts.detach()
sorted_freqs = raw_act_freqs.sort(descending=True)
act_freqs = sorted_freqs.values
sort_idxs = sorted_freqs.indices

In [None]:
# First and second moments of each feature's activation, conditioned on the
# feature being active (sums divided by act_counts), in sort_idxs order.
n_active = midcoder.act_counts[sort_idxs].detach()
act_mean = midcoder.act_sum[sort_idxs].detach() / n_active
act_sq_mean = midcoder.act_sq_sum[sort_idxs].detach() / n_active
act_var = act_sq_mean - act_mean**2

mid_vec = midcoder.W_mid[feat_ids][sort_idxs].detach()
b_mid = midcoder.b_mid.detach()

# Mean, covariance, variance and correlation of the pre-ReLU values,
# normalised by the total count.
n_total = midcoder.counts.detach()
mid_mean = midcoder.pre_relu_sum.detach() / n_total
mid_cov = midcoder.pre_relu_cov_sum.detach() / n_total
mid_cov = mid_cov - mid_mean[:, None] @ mid_mean[None]
mid_var = torch.diag(mid_cov)
mid_corr = mid_cov / torch.sqrt(mid_var[:, None] @ mid_var[None])

## Assuming covariances are negligible

In [None]:
# Display the shape of mid_vec.
mid_vec.shape

In [None]:
# Mean and variance left over per feature (rows) once that feature's own
# contribution along mid_vec is removed; covariances are ignored here.
own_mean = act_mean[:, None] * mid_vec
own_var = act_var[:, None] * mid_vec**2
noise_mean = b_mid[None] + mid_mean[None] - own_mean
noise_var = mid_var[None] - own_var

In [None]:
# Activation variance against activation mean, on log-log axes.
plt.plot(act_mean.cpu(), act_var.cpu(), '.', alpha=0.3)
# Reference curve var = mean**2, the relation satisfied by an exponential distribution.
vec = torch.arange(0.1, 15, 0.01)
plt.plot(vec, vec**2, '--', label='Exponential Dist.')
plt.legend()
plt.xscale('log')
plt.yscale('log')

In [None]:
# Reference distributions: an exponential with rate 0.1 and a standard normal.
exp_dist = torch.distributions.Exponential(rate=0.1)
gauss_dist = torch.distributions.Normal(loc=0, scale=1)



## Check covariance structure of pre-relu values

In [None]:
# Heatmap of the top-left 100x100 block of mid_cov, with the colour scale centred on 0.
px.imshow(mid_cov[:100,:100].cpu(), color_continuous_scale = 'RdBu', color_continuous_midpoint = 0)

In [None]:
# Gram matrix W_in^T W_in of the midcoder's input weights. Detached (like the
# other midcoder tensors used for analysis) so the result carries no autograd
# graph and can be converted for plotting.
weight_cov = (midcoder.W_in.T @ midcoder.W_in).detach()
# Heatmap of the top-left 100x100 block, with the colour scale centred on 0.
px.imshow(weight_cov[:100,:100].cpu(), color_continuous_scale = 'RdBu', color_continuous_midpoint = 0)

# Numerology for sparsities

In [None]:
# Log10 sparsities in descending order, converted back to linear scale.
log_freqs_desc = torch.sort(sparsity, descending=True).values
s = 10 ** log_freqs_desc

In [None]:
# Display the shape of the first 200 entries of s with singleton axes added before and after.
s[None,:200,None].shape

In [None]:
# All pairwise and triple products of the 200 largest values in s, flattened.
top_s = s[:200]
pairs = (top_s[:, None] * top_s[None, :]).flatten()
trips = (top_s[:, None, None] * top_s[None, :, None] * top_s[None, None, :]).flatten()

# Sort singles, pairs and triples together in descending order; `t` is the
# (values, indices) result of sort().
combined = torch.cat((top_s, pairs, trips))
t = combined.sort(descending=True)

In [None]:
# Sorted single-feature frequencies ...
plt.plot(s)
# ... against the 25000 largest single/pair/triple products. `t` is the
# (values, indices) result of sort(), so the values must be selected before
# slicing; `t[:25000]` would slice the two-element tuple instead.
plt.plot(t.values[:25000])
plt.yscale('log')


In [None]:
# Outer product of the first 200 entries of s with themselves ...
prod = s[:200].unsqueeze(0) * s[:200].unsqueeze(1)
# ... displayed flattened and sorted in ascending order.
prod.flatten().sort().values

# Decoder similarities

In [None]:
# Order features by descending log-sparsity.
sort_ids = sparsity.argsort(descending=True)
# Detach before normalising, as the transcoder-decoder cell later in the
# notebook does, so the similarities computed from `decoders` carry no autograd
# graph and can be plotted.
decoders = midcoder.W_dec[sort_ids].detach()
# Scale every row to unit norm.
decoders /= decoders.norm(dim=1, keepdim=True)

In [None]:
# Position (in the sort_ids ordering) of the decoder to compare against all others.
id = 2000
# Rows of `decoders` are unit vectors, so this is the cosine similarity to every decoder.
cos_sims = decoders[id] @ decoders.T

plt.hist(cos_sims.cpu(), 100);
# `decoders` is ordered by sort_ids, so the sparsity of row `id` must be looked
# up through sort_ids (sort_idxs is a different ordering from another cell).
plt.title(f'Id = {id}, log10(freq) = {sparsity[sort_ids[id]]:.2f}')
plt.yscale('log')

In [None]:
# The 30 entries of cos_sims with the largest absolute value (values and indices).
topk = torch.topk(cos_sims.abs(), k=30)
topk

In [None]:
# Pairwise similarities among the decoders picked out by topk.indices.
nearest_decoders = decoders[topk.indices]
cos_sims2 = nearest_decoders @ nearest_decoders.T
px.imshow(cos_sims2.cpu(), range_color=[-1,1], color_continuous_scale = 'RdBu')


In [None]:
# Refresh the CPU copy of the midcoder's feat_ids and display it.
feat_ids = midcoder.feat_ids.cpu()
feat_ids

In [None]:
# Position of every midcoder feature id inside sort_ids. sort_ids is an argsort
# result, i.e. a permutation, so its own argsort is the inverse permutation
# (feature id -> position). One vectorised lookup replaces a full scan of
# sort_ids per feature.
rank_of_feature = sort_ids.argsort()
feat_sort_ids = rank_of_feature[feat_ids.long()].tolist()

In [None]:
# Convert the list of positions to a tensor and display it.
feat_sort_ids = torch.tensor(feat_sort_ids)
feat_sort_ids

In [None]:
from bidict import bidict

# Two-way mapping between a feature id and its position in sort_ids.
feat2sort_ids = bidict()
for feat_id in feat_ids.tolist():
    matches = (sort_ids == feat_id).nonzero()
    sort_id = matches.item()
    feat2sort_ids[feat_id] = sort_id

In [None]:
# Sorted-order positions of the saved features, in ascending order.
saved_sort_ids = sorted(feat2sort_ids.values())

# Factorizing Decoders

In [None]:
# Transcoder decoder rows ordered by descending log-sparsity, scaled to unit
# norm, keeping the first 20000.
sort_ids = torch.argsort(sparsity, descending=True)
sorted_W_dec = transcoder.W_dec[sort_ids].detach()
sorted_W_dec /= sorted_W_dec.norm(dim=1, keepdim=True)
decoders = sorted_W_dec[:20000]

In [None]:
# Log-sparsity of the features in sort_ids order.
plt.plot(sparsity[sort_ids])

In [None]:
from factorizer import Factorizer, FactorizerConfig

# Configure and build a Factorizer over the normalised decoder directions.
cfg = FactorizerConfig()
cfg.theta = 0.3
cfg.factor_param = 0.5
cfg.factors = 10000

cfg.batch_size = 1000
cfg.epochs = 100
cfg.log = False
# Use the device selected during setup; the previous hard-coded 'mps' only
# works on Apple-silicon machines.
cfg.device = device

factorizer = Factorizer(cfg, decoders)

In [None]:
# Run the factorizer's fit() with the configuration it was constructed with.
factorizer.fit()

In [None]:
theta = 0.5
top_feat = 10000

n_feat = decoders.shape[0]
# Gram matrix of the first top_feat decoder directions, with every entry
# below theta set to zero.
top_decoders = decoders[:top_feat]
C = top_decoders @ top_decoders.T
C[C < theta] = 0
C = C.reshape(top_feat, top_feat)

In [None]:
# NumPy copy of C on the CPU, for use with numpy.linalg.
C_cpu = C.cpu().numpy()

In [None]:
# Eigendecomposition of C_cpu via eigh (which treats the matrix as symmetric/Hermitian).
eig = np.linalg.eigh(C_cpu)

In [None]:
# Plot the eigenvalues in the order eigh returned them.
plt.plot(eig.eigenvalues, '.-')

In [None]:
# Plot the singular values on a log y-axis.
# NOTE(review): `svd` is not assigned anywhere in the visible notebook; the cell
# that computed it appears to be missing.
plt.plot(svd.S.cpu(),'.-')
plt.yscale('log')

In [None]:
# Keep the components whose singular value is at least s_cutoff and scale each
# by 1/sqrt(singular value).
# NOTE(review): `svd` is not assigned anywhere in the visible notebook;
# presumably it is an SVD of the top_feat x top_feat matrix C — confirm.
s_cutoff = 1e-2
S = svd.S
mask = S >= s_cutoff
Sinv = 1/S[mask].sqrt()
# Singular vectors are the COLUMNS of U, so the mask must select columns.
# `svd.U[mask].T` selected rows and produced an (n, m) matrix that cannot be
# multiplied by the (m, m) diagonal unless every component is kept.
Ainv = torch.diag(Sinv) @ svd.U[:, mask].T
features = Ainv @ decoders[:top_feat]

In [None]:
# Elementwise ratio of the U and V columns whose singular value exceeds 1.
# NOTE(review): relies on `svd` exposing a `V` attribute — confirm which SVD
# routine produced it.
mask = svd.S > 1
svd.U[:,mask] / svd.V[:,mask]

In [None]:
# Row sums of U, kept as a column vector.
svd.U.sum(dim=1, keepdim=True)