In [None]:
import torch 
import argparse
from utils import dotdict
from activation_dataset import setup_token_data
import wandb
import json
from datetime import datetime
from tqdm import tqdm
from einops import rearrange
import matplotlib.pyplot as plt

# Experiment configuration shared by the cells below.
cfg = dotdict()

# --- Models --------------------------------------------------------------------
# Other checkpoints that have been used here:
#   "EleutherAI/pythia-70m-deduped", "EleutherAI/pythia-6.9b", "usvsnsp/pythia-6.9b-ppo",
#   "lomahony/eleuther-pythia6.9b-hh-sft", "reciprocate/dahoas-gptj-rm-static"
cfg.target_name = "EleutherAI/pythia-70m"
cfg.model_name = "lomahony/pythia-70m-helpful-sft"

# --- Which activations ---------------------------------------------------------
cfg.layers = [4]
cfg.setting = "residual"
# Module-path templates, formatted with each entry of cfg.layers.
# Alternative template used for other architectures: "transformer.h.{layer}".
cfg.tensor_name = "gpt_neox.layers.{layer}"
cfg.target_tensor_name = "gpt_neox.layers.{layer}"

# --- Optimisation --------------------------------------------------------------
original_l1_alpha = 8e-4  # kept as a separate name so the base value stays visible
cfg.l1_alpha = original_l1_alpha
cfg.sparsity = None
cfg.num_epochs = 10
cfg.model_batch_size = 8
cfg.lr = 1e-3
cfg.kl = False
cfg.reconstruction = False

# --- Data ----------------------------------------------------------------------
# Alternative dataset: "NeelNanda/pile-10k"
cfg.dataset_name = "Elriggs/openwebtext-100k"

# --- Runtime -------------------------------------------------------------------
cfg.device = "cuda:0"  # switch to "cpu" when no GPU is available
cfg.ratio = 4
cfg.seed = 0  # NOTE(review): not fed to any RNG in the cells shown here

In [None]:
# Which checkpoint pair to analyse. `base_name` / `target_name` select the
# `transfer_{base_name}_{target_name}_70m_{mode}.pt` checkpoint and the output
# file names; the human-readable labels used in plot titles are derived from
# DISPLAY_NAMES so they can never drift out of sync with the keys.
# Previously used pairs: ("sft", "base") and ("base", "sft").
DISPLAY_NAMES = {
    "base": "Base",
    "sft": "SFT'd",
    "rm": "RM",
    "ppo": "RLHF'd",
}

base_name = "sft"
target_name = "base"

capitalized_base_name = DISPLAY_NAMES[base_name]
finetuning_type = DISPLAY_NAMES[target_name]


In [None]:
# Expand both module-path templates once per configured layer index.
tensor_names, target_tensor_names = (
    [template.format(layer=layer_idx) for layer_idx in cfg.layers]
    for template in (cfg.tensor_name, cfg.target_tensor_name)
)

In [None]:
# Load the trained transfer autoencoder for the selected base/target pair and mode
# from autoencoders.learned_dict import TiedSAE, UntiedSAE, AnthropicSAE, TransferSAE
from torch import nn

# Transfer-autoencoder variants with saved checkpoints; `mode` picks one.
modes = ["scale", "rotation", "bias", "free"]
mode = "free"
assert mode in modes, f"unknown mode {mode!r}; expected one of {modes}"

# The checkpoint is a whole pickled object (it exposes .to_device and
# .get_feature_scales), not a state_dict, so it needs weights_only=False;
# recent PyTorch versions default to weights_only=True and would refuse it.
# SECURITY: this unpickles arbitrary code - only load trusted, locally produced files.
# map_location lets the checkpoint load even when cfg.device differs from the
# device it was saved on (e.g. "cpu" on a machine without CUDA).
mode_tsae = torch.load(
    f"trained_models/transfer_{base_name}_{target_name}_70m_{mode}.pt",
    map_location=cfg.device,
    weights_only=False,
)
mode_tsae.to_device(cfg.device)



In [None]:
# Per-feature statistic for the base dictionary. Despite the file/variable name,
# it is thresholded as a count here and later plotted as
# "Total Feature Activations" (presumably how often each feature fired).
dead_features = torch.load(f"trained_models/{base_name}_dead_features_70m.pt")

# Boolean mask of features that fired more than ACTIVATION_THRESHOLD times.
ACTIVATION_THRESHOLD = 10
live_features = torch.gt(dead_features, ACTIVATION_THRESHOLD)

In [None]:
# Number of live features (True entries in the mask).
torch.count_nonzero(live_features)

In [None]:
import matplotlib.pyplot as plt
# |scale| of every live feature. Both this list and dead_features[live_features]
# are selected with the same mask, so later cells pair them element-wise.
# Assumes get_feature_scales() returns one value per dictionary feature - TODO confirm.
scales = mode_tsae.get_feature_scales()[live_features]
abs_scales = scales.abs().detach().cpu()
list_scales = list(abs_scales.numpy())

# Print a compact summary rather than dumping every value into the notebook.
if abs_scales.numel() == 0:
    print("No live features selected.")
else:
    print(
        f"{abs_scales.numel()} live features | "
        f"|scale| min={abs_scales.min().item():.4g} "
        f"median={abs_scales.median().item():.4g} "
        f"max={abs_scales.max().item():.4g}"
    )

In [None]:
# Sorted absolute scales of the live features (one point per feature, ascending).
fig, ax = plt.subplots()
ax.plot(sorted(list_scales))
ax.set_title(f"Feature Scaling from {capitalized_base_name} to {finetuning_type} 70m LLM")
ax.set_xlabel("Feature (sorted by |scale|)")
# list_scales holds abs() values, so label the axis accordingly.
ax.set_ylabel("Absolute Feature Scaling")
fig.savefig(f"feature_scaling_{base_name}_{target_name}_{mode}_70m")

In [None]:
# Activation statistic restricted to the live features.
torch.masked_select(dead_features, live_features)

In [None]:
# Does a feature's scaling depend on how often it fires?
# x: log10 of each live feature's activation total; y: its |scale| (same mask, same order).
# Move to CPU/NumPy explicitly: matplotlib cannot consume CUDA tensors.
log_activations = dead_features[live_features].log10().cpu().numpy()

fig, ax = plt.subplots()
ax.scatter(log_activations, list_scales, s=5)
ax.set_title(f"Feature Scaling from {capitalized_base_name} to {finetuning_type} 70m LLM vs Activations")
ax.set_xlabel("Log10 of Total Feature Activations")
# list_scales holds abs() values, so label the axis accordingly.
ax.set_ylabel("Absolute Feature Scaling")
fig.savefig(f"feature_scaling_vs_activations_{base_name}_{target_name}_{mode}_70m")

In [None]:
# Live-feature activation totals in ascending order. Sorting the tensor directly
# avoids Python-level comparisons of thousands of 0-d tensors, which are slow
# and force a device sync per comparison on CUDA.
dead_features[live_features].sort().values