In [None]:
import torch
import numpy as np
import argparse
from utils import dotdict
from activation_dataset import setup_token_data
import wandb
import json
from datetime import datetime
from tqdm import tqdm
from einops import rearrange
import matplotlib.pyplot as plt


# ---------------------------------------------------------------------------
# Experiment configuration (attribute-style dict from utils.dotdict)
# ---------------------------------------------------------------------------
cfg = dotdict()

# Model pair. Checkpoint names that have been used with this notebook:
#   "EleutherAI/pythia-6.9b", "usvsnsp/pythia-6.9b-ppo",
#   "lomahony/eleuther-pythia6.9b-hh-sft", "reciprocate/dahoas-gptj-rm-static",
#   "EleutherAI/pythia-70m", "lomahony/eleuther-pythia70m-hh-sft"
cfg.model_name = "EleutherAI/pythia-70m"
cfg.target_name = "lomahony/eleuther-pythia70m-hh-sft"

# Layers to analyse and the name templates formatted with each layer number.
cfg.layers = [0, 1, 2, 3, 4, 5]
cfg.setting = "residual"
# Template options: "gpt_neox.layers.{layer}" (pythia), "transformer.h.{layer}" (rm)
cfg.tensor_name = "gpt_neox.layers.{layer}"
cfg.target_tensor_name = "gpt_neox.layers.{layer}"

# L1 coefficients; each entry of cfg.l1_alphas is interpolated into the
# checkpoint file names when the dictionaries are loaded.
original_l1_alpha = 2e-3
cfg.l1_alpha = original_l1_alpha
cfg.l1_alphas = [2e-3,]
# Full sweep used previously:
# cfg.l1_alphas=[0, 1e-5, 1e-4, 2e-4, 4e-4, 8e-4, 1e-3, 2e-3, 4e-3, 8e-3]

# Remaining settings.
cfg.sparsity = None
cfg.num_epochs = 10
cfg.model_batch_size = 8 * 8
cfg.lr = 1e-3
cfg.kl = False
cfg.reconstruction = False
# Alternative dataset: "NeelNanda/pile-10k"
cfg.dataset_name = "Elriggs/openwebtext-100k"
# Use "cpu" when no GPU is available.
cfg.device = "cuda:0"
cfg.ratio = 4
cfg.seed = 0
cfg.max_length = 256

In [None]:
# Short labels for the model pair; they are interpolated into plot titles.
#
# Current direction: base -> SFT.
# To analyse the reverse direction (SFT as the starting point) use instead:
#   base_name = "sft", capitalized_base_name = "SFT'd",
#   target_name = "base", finetuning_type = "Base"
#
# Accepted values:
#   base_name: base, rm             capitalized_base_name: Base, RM, SFT'd
#   target_name: sft, ppo           finetuning_type: SFT'd, RLHF'd
base_name, capitalized_base_name = "base", "Base"
target_name, finetuning_type = "sft", "SFT'd"


In [None]:
# Expand the per-layer name templates into one concrete name per configured layer.
tensor_names = []
target_tensor_names = []
for layer_number in cfg.layers:
    tensor_names.append(cfg.tensor_name.format(layer=layer_number))
    target_tensor_names.append(cfg.target_tensor_name.format(layer=layer_number))

In [None]:
# Load, for every configured layer and L1 coefficient, the base SAE, the
# retrained "autoTED" dictionary and the two per-feature statistics tensors
# (summed activations and firing frequency).
# from autoencoders.learned_dict import TiedSAE, UntiedSAE, AnthropicSAE, TransferSAE
from torch import nn

modes = ["scale", "rotation", "bias", "free"]
# NOTE(review): "dfree" is not one of `modes`; it is used verbatim in the
# checkpoint file names below — confirm it is the intended value.
mode = "dfree"

checkpoint = 0

# mode_tsae = torch.load(f"trained_models/transfer_{base_name}_{target_name}_70m_{mode}.pt")

# Every list below is nested as [position in cfg.layers][position in cfg.l1_alphas].
base_saes = []
mode_tsaes = []
autoTEDs = []
dead_features = []
frequencies = []

model_dir = "/root/sparse_coding/trained_models"
for layer in cfg.layers:
    l1_saes = []
    # The transfer-SAE load is commented out below, so this list stays empty
    # and mode_tsaes ends up holding one empty list per layer.
    l1_tsaes = []
    l1_autoTEDS = []
    l1_dead_features = []
    l1_freq = []
    for l1_alpha in cfg.l1_alphas:
        # SECURITY: torch.load unpickles arbitrary Python objects; only load
        # checkpoints from a trusted source.
        sae = torch.load(f"{model_dir}/base_sae_70m/base_sae_70m_{layer}_{l1_alpha}.pt")
        sae.to_device(cfg.device)
        l1_saes.append(sae)

        # mode_tsae = torch.load(f"{model_dir}/transfer_base_chess_70m/transfer_base_chess_70m_{mode}_{layer}_{l1_alpha}.pt")
        # mode_tsae.to_device(cfg.device)
        # l1_tsaes.append(mode_tsae)

        autoTED = torch.load(f"{model_dir}/base_retrain_70m/base_autoTED_70m_{mode}_{layer}_{l1_alpha}.pt")
        autoTED.to_device(cfg.device)
        l1_autoTEDS.append(autoTED)

        # The statistics are consumed on the CPU further down (``.numpy()``,
        # boolean masks applied to CPU tensors, matplotlib). The previous
        # ``c.to(cfg.device)`` / ``f.to(cfg.device)`` calls discarded their
        # result (Tensor.to is not in-place), so they never moved anything.
        # Load straight onto the CPU instead, regardless of the saved device.
        c = torch.load(
            f"{model_dir}/base_dead_features_70m/base_dead_features_70m_{mode}_{layer}_{l1_alpha}.pt",
            map_location="cpu",
        )
        l1_dead_features.append(c)

        f = torch.load(
            f"{model_dir}/base_frequency_70m/base_frequency_70m_{mode}_{layer}_{l1_alpha}.pt",
            map_location="cpu",
        )
        l1_freq.append(f)

    base_saes.append(l1_saes)
    mode_tsaes.append(l1_tsaes)
    autoTEDs.append(l1_autoTEDS)
    dead_features.append(l1_dead_features)
    frequencies.append(l1_freq)


In [None]:
# Per-layer mean of the row-wise dot product between the base SAE dictionary
# and the autoTED dictionary (reported as cosine similarity), split into
# frequently firing ("strong"), rarely firing ("weak") and all live features.
strong_cosims = []
weak_cosims = []
all_cosims = []
l1_id = 0 # corresponds to cfg.l1_alphas[l1_id]
# The loaded lists are ordered like cfg.layers, so index them by position;
# indexing by the layer number only works when cfg.layers == [0, 1, 2, ...].
for layer_idx, layer in enumerate(cfg.layers):
    base_sae = base_saes[layer_idx][l1_id]
    autoTED = autoTEDs[layer_idx][l1_id]
    # mode_tsae = mode_tsaes[layer_idx][l1_id]
    dead_feature = dead_features[layer_idx][l1_id]
    frequency = frequencies[layer_idx][l1_id]
    sorted_acts = np.sort(frequency.numpy())

    # Number of features firing on fewer than 1e-5 of samples; the sorted
    # frequency at that position is the weak/strong boundary.
    cutoff = int((frequency < 1e-5).sum())
    upper_boundary = sorted_acts[cutoff]
    vip_features = frequency > upper_boundary
    # transition_features = (frequency > lower_boundary) * (frequency <= upper_boundary)
    # Weak = fires at least once but not above the boundary (dead features excluded).
    weak_features = (frequency > 0) & (frequency <= upper_boundary)

    auto_rotation = (base_sae.get_learned_dict() * autoTED.get_learned_dict()).sum(dim=-1).detach().cpu()

    strong_cosims.append(auto_rotation[vip_features].numpy().mean())
    weak_cosims.append(auto_rotation[weak_features].numpy().mean())
    auto_rotation_f = auto_rotation[frequency > 0].numpy()
    all_cosims.append(auto_rotation_f.mean())

    print(f"Layer {layer} strong_cosim {strong_cosims[-1]:.3f} weak_cosim {weak_cosims[-1]:.3f} all_cosim {all_cosims[-1]:.3f}")
    

In [None]:
# Select a single (layer, L1) combination that the detailed plots below use.
layer = 4
l1_id = 0 # corresponds to cfg.l1_alphas[l1_id]
base_sae, autoTED = base_saes[layer][l1_id], autoTEDs[layer][l1_id]
# mode_tsae = mode_tsaes[layer][l1_id]
dead_feature, frequency = dead_features[layer][l1_id], frequencies[layer][l1_id]

# Ascending list of the per-feature frequencies.
sorted_acts = list(frequency.numpy().copy())
sorted_acts.sort()
# How many features have a frequency above 1/100.
print(sum(frequency> 1/100))

# Position in the sorted frequencies used as the weak/strong boundary.
# lower_index = 1260 # 800 for 8e-4
upper_index = 300 # 950 for 8e-4

# lower_boundary = sorted_acts[lower_index]
upper_boundary = sorted_acts[upper_index]

# Features strictly above the boundary are "vip"; everything else
# (including features that never fire) is "weak".
# transition_features = (frequency > lower_boundary) * (frequency <= upper_boundary)
vip_features = frequency > upper_boundary
weak_features = frequency <= upper_boundary

In [None]:
# Plot styling: grab the seaborn-v0_8 colour cycle, assemble a custom palette,
# then leave the "ggplot" style active for the plotting cells that follow.
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.style as style

# To list every available style with the length of its colour cycle:
# for st in style.available:
#     style.use(st)
#     colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
#     print(f"{len(colors)}: {st}")

style.use("seaborn-v0_8")
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

# Three source palettes, each shown once in the cell output.
set3_palette = sns.color_palette("Set3", n_colors=8)
display(set3_palette)

palette2 = sns.color_palette(n_colors=6)
palette3 = sns.color_palette("RdPu", 10)

# Final palette: one Set3 colour, the six default colours, one RdPu colour.
palette = [set3_palette[5]] + palette2 + [palette3[4]]
# sns.set_palette(palette)
display(palette2)
display(palette3)

style.use("ggplot")

In [None]:
# torch.save(weak_features, "weak_features.p")
# torch.save(transition_features, "transition_features.p")
# torch.save(vip_features, "vip_features.p")


In [None]:
# Linear fit of dictionary similarity against the absolute autoTED feature
# scale, restricted to the frequently firing ("vip") features.
import numpy as np
from sklearn.linear_model import LinearRegression

reference_scales = autoTED.get_feature_scales().abs().detach().cpu()
dictionary_product = base_sae.get_learned_dict() * autoTED.get_learned_dict()
auto_rotation = dictionary_product.sum(dim=-1).detach().cpu()
filtering = vip_features

x = reference_scales[filtering]
y = auto_rotation[filtering]

# scikit-learn expects a 2-D design matrix, hence the added column axis.
design = x[:,None]
model = LinearRegression()
model.fit(design, y)
r_sq = model.score(design, y)

# Fitted line over a fixed scale range.
xseq = np.linspace(0.8, 2, num=100)
fitted = model.predict(xseq[:,None])
plt.plot(xseq, fitted, color="black")

plt.scatter(x, y, s=3)
# plt.colorbar()
plt.ylabel("Cosine Similarity")
plt.xlabel("Scaling")
plt.title("Cosine Similarity vs Feature Scaling")
print(r_sq)

In [None]:
# Scatter of log10 frequency vs log10 summed activation for the vip features,
# coloured by base-vs-autoTED dictionary similarity.
auto_rotation = (base_sae.get_learned_dict() * autoTED.get_learned_dict()).sum(dim=-1).detach().cpu()
filtering = vip_features

# Select the style before the figure exists so the figure itself picks it up.
style.use("ggplot")
plt.figure(dpi=1200)

# plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# plt.get_cmap is the pyplot-level accessor that remains available.
orig_map = plt.get_cmap('viridis')

# Reverse the colormap so its colour order is flipped relative to viridis.
reversed_map = orig_map.reversed()

plt.scatter(frequency[filtering].log10(), (dead_feature[filtering]).log10(), c=auto_rotation[filtering], cmap=reversed_map, vmin=0.94)
plt.colorbar().ax.set_ylabel('Cosine Similarity', rotation=270, labelpad=15)
plt.xlabel("Log10 Frequency")
plt.ylabel("Log10 Sum of Activations")
plt.title("Cosine Similarity vs Frequency and Sum of Activations")
plt.tight_layout()


In [None]:
# Scatter of log10 frequency vs log10 activation strength for the vip
# features, coloured by base-vs-autoTED dictionary similarity.
dictionary_product = base_sae.get_learned_dict() * autoTED.get_learned_dict()
auto_rotation = dictionary_product.sum(dim=-1).detach().cpu()
filtering = vip_features

# Summed activation divided by (100_000_000 * frequency).
# NOTE(review): 100_000_000 is presumably the number of samples the statistics
# were collected over — confirm.
strength_per_feature = dead_feature / (100_000_000 * frequency)

log_frequency = frequency[filtering].log10()
log_strength = strength_per_feature[filtering].log10()
similarity = auto_rotation[filtering]

plt.scatter(log_frequency, log_strength, c=similarity, vmin=0.95, vmax=1)
plt.colorbar()
plt.xlabel("Log10 Frequency")
plt.ylabel("Log10 Activation Strength")
plt.title("Cosine Similarity vs Frequency and Strength of Activations")

In [None]:
# Sorted base-vs-autoTED dictionary similarity of the vip features.
dictionary_product = base_sae.get_learned_dict() * autoTED.get_learned_dict()
auto_rotation = dictionary_product.sum(dim=-1).detach().cpu()
# transfer_rotation = (autoTED.get_learned_dict() * mode_tsae.get_learned_dict()).sum(dim=-1).detach().cpu()
# total_rotation = (base_sae.get_learned_dict() * mode_tsae.get_learned_dict()).sum(dim=-1).detach().cpu()

filtering = vip_features
auto_rotation_f = auto_rotation[filtering].numpy()
# transfer_rotation_f = transfer_rotation[filtering].numpy()
# total_rotation_f = total_rotation[filtering].numpy()

print(f"Mean dot product (auto): {sum(auto_rotation_f)/len(auto_rotation_f)}")
# print(f"Mean dot product (transfer): {sum(transfer_rotation_f)/len(transfer_rotation_f)}")

# Ascending curve of the per-feature similarities.
ascending_similarity = sorted(auto_rotation_f)
plt.plot(ascending_similarity)
plt.title("AutoTED Feature Similarity, Pythia 70m Layer 5")
plt.xlabel("Strong Features")
# plt.plot(scales.detach().cpu().numpy())

In [None]:
# Scatter of base-vs-autoTED dictionary similarity against log10 firing
# frequency, with the weak/strong frequency boundary marked.
# reference_scales = autoTED.get_feature_scales().abs().detach().cpu()
# list_scales = list((reference_scales).numpy())

auto_rotation = (base_sae.get_learned_dict() * autoTED.get_learned_dict()).sum(dim=-1).detach().cpu()

# Summed activation divided by (100_000_000 * frequency). Kept available for
# later use; it is no longer the x-axis of this plot (see below).
act_strength = (dead_feature/(100_000_000*frequency))

# transfer_rotation = (autoTED.get_learned_dict() * mode_tsae.get_learned_dict()).sum(dim=-1).detach().cpu()
# total_rotation = (base_sae.get_learned_dict() * mode_tsae.get_learned_dict()).sum(dim=-1).detach().cpu()

for filtering in [weak_features, vip_features]:
    auto_rotation_f = auto_rotation[filtering].numpy()
    print(f"Filtering leads to average rotation of: {auto_rotation_f.mean()}")

# The x-axis label is "Log10 Frequency" and the dotted vertical line is the
# frequency boundary separating weak from vip features, so the x data must be
# frequency as well. Plotting act_strength here put the points and the
# boundary line in different units.
# NOTE(review): if activation strength was the intended x-axis, change the
# label and drop the boundary line instead.
plt.scatter(frequency.log10().numpy(), auto_rotation, s=5)
# plt.vlines([np.log10(lower_boundary),np.log10(upper_boundary)],0,6, colors=['black'], linestyles='dashed')
# plt.axvline(np.log10(lower_boundary),color='black', linestyle=':')
plt.axvline(np.log10(upper_boundary),color='black', linestyle=':')
plt.axhline(1, color='black', linestyle='--')
# plt.ylim(0.7,1.9)

plt.title("Retrained Decoder Feature Similarity, Pythia 70m Layer 5")
plt.xlabel("Log10 Frequency")
plt.ylabel("Dictionary Cosine Similarity")
plt.show()

In [None]:
# Scatter of autoTED-vs-transfer-SAE dictionary similarity against log10
# summed activation for the selected layer.
#
# `transfer_rotation` and `lower_boundary` were only ever assigned in
# commented-out code, so this cell raised NameError on a fresh kernel. The
# rotation is now computed here from the transfer SAE, and only the boundary
# that is actually defined (`upper_boundary`) is drawn.
layer_tsaes = mode_tsaes[layer]
if l1_id < len(layer_tsaes):
    mode_tsae = layer_tsaes[l1_id]
    transfer_rotation = (autoTED.get_learned_dict() * mode_tsae.get_learned_dict()).sum(dim=-1).detach().cpu()

    plt.scatter(dead_feature.log10(), transfer_rotation, s=5)
    plt.vlines([np.log10(upper_boundary)],0.6,1, colors=['black'],linestyles='dashed')
    plt.ylim(0.6,1.0)
    plt.title("Adjusted Finetuned Feature Similarity, Pythia 70m Layer 5")
    plt.xlabel("Log10 Feature Activations")
    plt.show()
else:
    # No transfer SAE is available for this (layer, L1); nothing to compare.
    print(f"Skipping transfer similarity plot: no transfer SAE loaded for layer index {layer}, l1 index {l1_id}")

In [None]:
# NOTE(review): `live_features` is not assigned in any visible cell, and
# `dead_features` is built as a nested list (layer -> L1) of loaded objects,
# which cannot be indexed with a mask. This looks like a leftover that fails
# on a fresh kernel — confirm the intended names (possibly the per-layer
# `dead_feature` tensor with a `frequency > 0` mask).
dead_features[live_features]

In [None]:
# One scatter plot per (layer, L1) of the transfer-SAE feature scale relative
# to the autoTED feature scale, against log10 summed activation.
for lay, layer in enumerate(cfg.layers):
    for l1, l1_alpha in enumerate(cfg.l1_alphas):
        # Skip combinations for which no transfer SAE has been loaded instead
        # of failing with an IndexError on the first one.
        if l1 >= len(mode_tsaes[lay]):
            print(f"Skipping layer {layer}, L1 {l1_alpha}: no transfer SAE loaded")
            continue

        mode_tsae = mode_tsaes[lay][l1]
        autoTED = autoTEDs[lay][l1]
        acts = dead_features[lay][l1]

        scales = mode_tsae.get_feature_scales()
        reference_scales = autoTED.get_feature_scales()
        list_scales = list((scales/reference_scales).abs().detach().cpu().numpy())

        # The title names a single (layer, L1), so each combination gets its
        # own figure; previously every iteration drew into the same axes and
        # only the last title survived.
        plt.figure()
        plt.scatter(acts.log10(), list_scales, s=5)
        plt.title(f"Feature Scaling from {capitalized_base_name} to {finetuning_type} 70m LLM vs Activations (Layer {layer}, L1 {l1_alpha})")
        plt.xlabel("Log10 of Total Feature Activations")
        plt.ylabel("Feature Scaling")
        plt.ylim(0,6)
        # plt.savefig(f"feature_scaling_vs_activations_{base_name}_{target_name}_{mode}_ckpt{checkpoint}_70m")
        plt.show()

In [None]:
# AutoTED feature scale against log10 summed activation for one layer,
# overlaying every third L1 coefficient on the same axes.
lay = 3
# l1 = 12
for l1 in range(0,len(cfg.l1_alphas), 3):
    layer = cfg.layers[lay]
    l1_alpha = cfg.l1_alphas[l1]
    print(l1_alpha)

    autoTED = autoTEDs[lay][l1]
    acts = dead_features[lay][l1]

    # Feature counts below / above 10**6 summed activation.
    print(sum(acts < 10**6))
    print(sum(acts > 10**6))

    # Only the autoTED scales are plotted (the title says "AutoTED Scaling").
    # The earlier scales/reference_scales ratio was immediately overwritten,
    # and computing it was the only reason this cell needed a transfer SAE,
    # so that dead computation has been removed.
    reference_scales = autoTED.get_feature_scales()
    list_scales = list((reference_scales).abs().detach().cpu().numpy())
    # Alternatives that were tried:
    # list_scales = list((scales/reference_scales).abs().detach().cpu().numpy())
    # list_scales = list((scales).abs().detach().cpu().numpy())

    plt.scatter(acts.log10(), list_scales, s=5)
    plt.title(f"AutoTED Scaling from {capitalized_base_name} to {finetuning_type} 70m LLM vs Activations (Layer {layer})")
    plt.xlabel("Log10 of Total Feature Activations")
    plt.ylabel("Feature Scaling")
    plt.ylim(0,6)
    # plt.savefig(f"feature_scaling_vs_activations_{base_name}_{target_name}_{mode}_ckpt{checkpoint}_70m")

In [None]:
# Transfer-SAE feature scale relative to the autoTED scale against log10
# summed activation for one layer, overlaying every third L1 coefficient.
# Reads mode_tsaes[lay][l1], so a transfer SAE must have been loaded for
# each plotted (layer, L1).
lay = 4
# l1 = 12
for l1, l1_alpha in list(enumerate(cfg.l1_alphas))[::3]:
    layer = cfg.layers[lay]

    mode_tsae = mode_tsaes[lay][l1]
    autoTED = autoTEDs[lay][l1]
    acts = dead_features[lay][l1]

    # Feature counts in three summed-activation bands:
    # below 10**4, strictly between 10**4 and 10**6, above 10**6.
    low_band = acts < 10**4
    mid_band = (acts > 10**4) * (acts < 10**6)
    high_band = acts > 10**6
    for band in (low_band, mid_band, high_band):
        print(sum(band))

    scales = mode_tsae.get_feature_scales()
    reference_scales = autoTED.get_feature_scales()
    relative_scales = (scales/reference_scales).abs()
    list_scales = list(relative_scales.detach().cpu().numpy())
    # list_scales = list((reference_scales).abs().detach().cpu().numpy())
    # list_scales = list((scales).abs().detach().cpu().numpy())

    plt.scatter(acts.log10(), list_scales, s=2)
    plt.title(f"Feature Scaling from {capitalized_base_name} to {finetuning_type} 70m LLM vs Activations (Layer {layer})")
    plt.xlabel("Log10 of Total Feature Activations")
    plt.ylabel("Feature Scaling")
    plt.ylim(0,6)
    # plt.savefig(f"feature_scaling_vs_activations_{base_name}_{target_name}_{mode}_ckpt{checkpoint}_70m")

In [None]:
# Rank the high-activation features of one (layer, L1) by how much the
# transfer SAE rescaled them relative to the autoTED scale.
# NOTE(review): l1 is a position in cfg.l1_alphas and lay a position in
# cfg.layers; both must be in range for the current configuration.
lay, l1 = 4, 8
layer = cfg.layers[lay]
l1_alpha = cfg.l1_alphas[l1]

print(f"Layer {lay} (0 indexed), l1 {l1_alpha}")

mode_tsae = mode_tsaes[lay][l1]
autoTED = autoTEDs[lay][l1]
acts = dead_features[lay][l1]

# Feature counts in three summed-activation bands.
print(sum(acts < 10**4))
print(sum((acts > 10**4) * (acts < 10**6)))
print(sum(acts > 10**6))

# Only features with more than 10**6 summed activation are ranked.
large_filter = acts > 10**6

scales = mode_tsae.get_feature_scales()
reference_scales = autoTED.get_feature_scales()
relative_scales = (scales/reference_scales).abs()
list_scales = relative_scales.detach().cpu().numpy()

# Feature indices in ascending order of relative scale, restricted to the
# large-activation features.
ascending_order = np.argsort(list_scales)
argsorted_scales = [feature_idx for feature_idx in ascending_order if large_filter[feature_idx]==1]

most_scaled = argsorted_scales[-10:][::-1]
least_scaled = argsorted_scales[:10]
print(f"Top 10 scaled are: {most_scaled}")
print(f"Bottom 10 scaled are: {least_scaled}")

In [None]:
# Index of the least-scaled feature among those that passed large_filter.
argsorted_scales[0]

In [None]:
# Spot check: does feature 662 pass the large-activation filter?
# NOTE(review): 662 is a hard-coded feature index — confirm it is still the one of interest.
large_filter[662]

In [None]:
# Spot check: relative scale of feature 662.
# NOTE(review): 662 is a hard-coded feature index — confirm it is still the one of interest.
list_scales[662]