In [None]:
import torch
import numpy as np
import argparse
from utils import dotdict
from activation_dataset import setup_token_data
import wandb
import json
from datetime import datetime
from tqdm import tqdm
from einops import rearrange
import matplotlib.pyplot as plt


# Experiment configuration shared by every cell below.
# `dotdict` comes from the project's `utils` module; presumably a dict with attribute
# access -- TODO confirm against utils.py.
cfg = dotdict()
# models: "EleutherAI/pythia-6.9b", "usvsnsp/pythia-6.9b-ppo", "lomahony/eleuther-pythia6.9b-hh-sft", "reciprocate/dahoas-gptj-rm-static"
# "EleutherAI/pythia-70m", "lomahony/eleuther-pythia70m-hh-sft"
cfg.model_name="EleutherAI/pythia-70m"
cfg.target_name="BlueSunflower/Pythia-70M-chess"
# Layers analysed; later cells build one entry per layer in autoTEDs / dead_features / frequencies.
cfg.layers=[0,1,2,3,4,5]
cfg.setting="residual"
# cfg.tensor_name="gpt_neox.layers.{layer}"
# Name templates; `{layer}` is filled in with each entry of cfg.layers.
cfg.tensor_name="gpt_neox.layers.{layer}" # "gpt_neox.layers.{layer}" (pythia), "transformer.h.{layer}" (rm)
cfg.target_tensor_name="gpt_neox.layers.{layer}"
original_l1_alpha = 2e-3
cfg.l1_alpha=original_l1_alpha
# The l1 coefficients are also interpolated into the checkpoint file names when loading.
cfg.l1_alphas=[2e-3,]
# cfg.l1_alphas=[0, 1e-5, 1e-4, 2e-4, 4e-4, 8e-4, 1e-3, 2e-3, 4e-3, 8e-3]
cfg.sparsity=None
cfg.num_epochs=10
cfg.model_batch_size=8 * 8
cfg.lr=1e-3
cfg.kl=False
cfg.reconstruction=False
# cfg.dataset_name="NeelNanda/pile-10k"
cfg.dataset_name="Elriggs/openwebtext-100k"
# Device that loaded models and tensors are moved to.
cfg.device="cpu"
# NOTE(review): ratio is presumably dictionary size / activation width (4 * 512 = 2048 would
# match the constants used further down) -- confirm.
cfg.ratio = 4
cfg.seed = 0
cfg.max_length = 256
# cfg.device="cpu"

In [None]:
# Human-readable labels. In the visible cells they are only interpolated into plot titles
# and the savefig file name; they do not select which checkpoints are loaded.
# Alternative labelling (SFT model as the base):
# base_name = "sft"  # base, rm
# capitalized_base_name = "SFT'd" # Base, RM, SFT'd
# target_name = "base"  # sft, ppo
# finetuning_type = "Base"  # SFT'd, RLHF'd

# NOTE(review): cfg.target_name points at a chess fine-tune while these labels say "sft" /
# "SFT'd" -- confirm the titles describe the model actually analysed.
base_name = "base"  # base, rm
capitalized_base_name = "Base" # Base, RM, SFT'd
target_name = "sft"  # sft, ppo
finetuning_type = "SFT'd"  # SFT'd, RLHF'd


In [None]:
# Fill the `{layer}` placeholder of both name templates for every configured layer,
# giving two lists aligned with cfg.layers.
tensor_names = []
target_tensor_names = []
for layer_id in cfg.layers:
    tensor_names.append(cfg.tensor_name.format(layer=layer_id))
    target_tensor_names.append(cfg.target_tensor_name.format(layer=layer_id))

In [None]:
# Load the saved artifacts for every (layer, l1_alpha) pair into nested lists indexed as
# [position in cfg.layers][position in cfg.l1_alphas]:
#   autoTEDs      - pickled autoencoder objects (project type; exposes to_device,
#                   get_feature_scales and get_learned_dict in the cells below)
#   dead_features - per-feature tensors; despite the name, later cells divide them by
#                   (100_000_000 * frequency) and plot them as "Total Feature Activations"
#   frequencies   - per-feature activation frequencies
# from autoencoders.learned_dict import TiedSAE, UntiedSAE, AnthropicSAE, TransferSAE
from torch import nn

# modes = ["scale", "rotation", "bias", "free"]
mode = "dfree"

checkpoint = 0

# mode_tsae = torch.load(f"trained_models/transfer_{base_name}_{target_name}_70m_{mode}.pt")
mode_tsaes = []
autoTEDs = []
dead_features = []
frequencies = []

# NOTE(review): absolute, machine-specific path; consider making it configurable.
model_dir = "/root/sparse_coding/trained_models"
for layer in cfg.layers:
    l1_tsaes = []
    l1_autoTEDS = []
    l1_dead_features = []
    l1_freq = []
    for l1_alpha in cfg.l1_alphas:
        # Transfer-SAE loading is disabled, so every entry of `mode_tsaes` stays an empty
        # list; cells that index mode_tsaes[lay][l1] will raise until this is re-enabled.
        # mode_tsae = torch.load(f"{model_dir}/transfer_base_chess_70m/transfer_base_chess_70m_{mode}_{layer}_{l1_alpha}.pt")
        # mode_tsae.to_device(cfg.device)
        # l1_tsaes.append(mode_tsae)

        # SECURITY: torch.load unpickles arbitrary objects -- only load trusted checkpoints.
        # map_location lets checkpoints saved on another device load on cfg.device.
        autoTED = torch.load(
            f"{model_dir}/base_retrain_70m/base_autoTED_70m_{mode}_{layer}_{l1_alpha}.pt",
            map_location=cfg.device,
        )
        autoTED.to_device(cfg.device)
        l1_autoTEDS.append(autoTED)

        c = torch.load(
            f"{model_dir}/base_dead_features_70m/base_dead_features_70m_{mode}_{layer}_{l1_alpha}.pt",
            map_location=cfg.device,
        )
        # Tensor.to returns a new tensor; the result has to be kept for the move to take effect.
        c = c.to(cfg.device)
        l1_dead_features.append(c)

        f = torch.load(
            f"{model_dir}/base_frequency_70m/base_frequency_70m_{mode}_{layer}_{l1_alpha}.pt",
            map_location=cfg.device,
        )
        f = f.to(cfg.device)
        l1_freq.append(f)

    mode_tsaes.append(l1_tsaes)
    autoTEDs.append(l1_autoTEDS)
    dead_features.append(l1_dead_features)
    frequencies.append(l1_freq)


In [None]:
# Set Style and Colors
# The order of the statements matters: `colors` and `palette2` are read while the
# "seaborn-v0_8" style is active, and only afterwards is the style switched to "ggplot".
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.style as style

# for st in style.available:
#     style.use(st)
#     colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] 
#     # print('\n'.join(color for color in colors))  
#     print(f"{len(colors)}: {st}")
    
style.use("seaborn-v0_8")
# Colour cycle of the style that is active right now.
colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] 

# colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] 
# mode_colors = colors + []

# style.use("default")
# sns.set_style(style=None, rc = {})
palette = sns.color_palette(palette="Set3", n_colors=8, desat=None, as_cmap=False)
# `display` is the IPython notebook built-in; it renders the palette as colour swatches.
display(palette)


# palette=None presumably returns the current matplotlib colour cycle (see the seaborn
# color_palette docs) -- which is why this must run before the switch to ggplot.
palette2 = sns.color_palette(palette=None, n_colors=6, desat=None, as_cmap=False) 

palette3 = sns.color_palette("RdPu", 10)

# Final 8-colour list: one Set3 colour, the six cycle colours, one RdPu colour.
# It is not installed as the default palette (the set_palette call stays commented out).
palette = [palette[5]] + palette2 + [palette3[4]]
# sns.set_palette(palette)
display(palette2)
display(palette3)
style.use("ggplot")

In [None]:
# Plot the sorted log10 feature frequencies of every configured layer on one figure,
# and print how many of the (assumed) 2048 features never fired in each layer.
# lower_index = 240 # 800 for 8e-4
plt.figure(dpi=1200) 
upper_index = 300 # 950 for 8e-4
l1_id = 0 # corresponds to cfg.l1_alphas[l1_id]
# `frequencies` is indexed by position in cfg.layers, not by the layer id itself, so
# enumerate keeps this correct when cfg.layers is not exactly [0, 1, 2, ...].
for lay, layer in enumerate(cfg.layers):
    frequency = frequencies[lay][l1_id]
    # NOTE(review): 2048 is presumably the dictionary size -- confirm it equals frequency.numel().
    print(2048 - (frequency > 0).sum())

    # Features with zero frequency become -inf and sort to the front.
    plt.plot(np.sort(frequency.log10().numpy()))
    # plt.vlines([lower_index,upper_index],-6,0, colors=['black'], linestyles='dashed')

# plt.axvline(lower_index,color='black', linestyle=':')
# plt.axvline(upper_index,color='black', linestyle=':')
# Legend labels are 1-indexed ("Layer 1" is layer id 0).
plt.legend([f"Layer {layer_id+1}" for layer_id in cfg.layers])
# plt.title("Log10 of Feature Activations, Pythia 70m Layer 5")
plt.title("Empirical Feature Frequency, Pythia 70m")
plt.xlabel("Feature Index, Sorted")
plt.ylabel("Log10 Frequency")
plt.show()

In [None]:
# Pick just one to focus on
# Selects the artifacts of a single layer and splits its features into a frequent ("vip")
# and an infrequent ("weak") group by a rank threshold on frequency.
# NOTE(review): `layer` is used as a position into the per-layer lists; that equals the
# layer id only while cfg.layers == [0, 1, 2, ...].
layer = 4
l1_id = 0 # corresponds to cfg.l1_alphas[l1_id]
autoTED = autoTEDs[layer][l1_id]
# mode_tsae = mode_tsaes[layer][l1_id]
dead_feature = dead_features[layer][l1_id]
frequency = frequencies[layer][l1_id]

# Frequencies in ascending order (a Python list of numpy scalars).
sorted_acts = sorted((frequency.numpy().copy()))
# Number of features that fire on more than 1% of samples.
print(sum(frequency> 1/100))


# Rank threshold: the feature at this position in ascending order sets the boundary.
upper_index = 300

# lower_boundary = sorted_acts[lower_index]
upper_boundary = sorted_acts[upper_index]
# Boolean masks over features: strictly above the boundary vs. at or below it.
vip_features = frequency > upper_boundary
# transition_features = (frequency > lower_boundary) * (frequency <= upper_boundary)
weak_features = frequency <= upper_boundary

In [None]:
# Per-layer summary of the mean absolute feature scale for frequent ("strong"), rare
# ("weak") and all live features. The results go into strong_scales / weak_scales /
# all_scales, aligned with cfg.layers.
#
# Loop-local names are used on purpose: the focus-layer globals chosen in the previous
# cell (autoTED, dead_feature, frequency, vip_features, weak_features, layer) must keep
# pointing at that layer, otherwise the "Layer 5" plots below would silently use
# whichever layer this loop visited last.
strong_scales = []
weak_scales = []
all_scales = []
l1_id = 0 # corresponds to cfg.l1_alphas[l1_id]
for lay, layer_id in enumerate(cfg.layers):
    # The per-layer lists are indexed by position in cfg.layers, not by layer id.
    layer_autoTED = autoTEDs[lay][l1_id]
    layer_frequency = frequencies[lay][l1_id]
    layer_sorted = sorted(layer_frequency.numpy().copy())

    # Boundary = smallest frequency that is not below 1e-5. If every feature is rarer
    # than that, there is no strong feature at all.
    cutoff = int((layer_frequency < 1e-5).sum())
    layer_boundary = layer_sorted[cutoff] if cutoff < len(layer_sorted) else float("inf")
    strong_mask = layer_frequency > layer_boundary
    # Weak = fired at least once, but not above the boundary.
    weak_mask = (layer_frequency > 0) & (layer_frequency <= layer_boundary)

    layer_scales = layer_autoTED.get_feature_scales().abs().detach().cpu()
    # The mean of an empty selection is NaN; it is reported as such rather than hidden.
    strong_scale = layer_scales[strong_mask].mean()
    weak_scale = layer_scales[weak_mask].mean()
    all_scale = layer_scales[layer_frequency > 0].mean()

    strong_scales.append(strong_scale)
    weak_scales.append(weak_scale)
    all_scales.append(all_scale)
    print(f"Layer {layer_id} strong_scale {strong_scale:.3f} weak_scale {weak_scale:.3f} all_scale {all_scale:.3f}")
    

In [None]:
# Scatter of absolute feature scale against log10 frequency for the frequent ("vip")
# features, with a dotted vertical line at the rank-300 frequency and a dashed horizontal
# line at scale = 1.
reference_scales = autoTED.get_feature_scales().abs().detach().cpu()
list_scales = list(reference_scales.numpy())
# Despite its name, `dead_feature` is plotted elsewhere as total feature activation;
# dividing by (100_000_000 * frequency) presumably gives the mean activation while
# active -- TODO confirm that 100_000_000 is the number of tokens seen.
act_strength = dead_feature / (100_000_000 * frequency)

additive_suppression = (autoTED.get_feature_scales().abs().detach().cpu() - 1) * act_strength

x_axis = frequency
sorted_acts = sorted(x_axis.numpy().copy())
upper_index = 300
upper_bound = sorted_acts[upper_index]

fig, ax = plt.subplots()
# ax.scatter(frequency.log10().numpy(), list_scales, s=5)
ax.scatter(x_axis[vip_features].log10().numpy(), reference_scales[vip_features], s=5)

# ax.axvline(np.log10(lower_boundary), color='black', linestyle=':')
ax.axvline(np.log10(upper_bound), color='black', linestyle=':')
ax.axhline(1, color='black', linestyle='--')
# ax.set_ylim(0.7, 1.9)

ax.set_title("Feature Scaling, Pythia 70m Layer 5")
ax.set_xlabel("Log10 Frequency")
ax.set_ylabel("Feature Scaling")
plt.show()

In [None]:
# Three side-by-side scatters over the frequent ("vip") features:
#   1. scale vs log10 activation strength
#   2. scale vs log10 frequency
#   3. additive activation change vs log10 activation strength
reference_scales = autoTED.get_feature_scales().abs().detach().cpu()
list_scales = list((reference_scales).numpy())
# Despite its name, `dead_feature` is plotted elsewhere as total feature activation;
# dividing by (100_000_000 * frequency) presumably gives the mean activation while
# active -- TODO confirm that 100_000_000 is the number of tokens seen.
act_strength = (dead_feature/(100_000_000*frequency))

# (scale - 1) * strength: how much the activation changes in absolute terms.
additive_suppression = (reference_scales - 1) * act_strength

x_axis = frequency
sorted_acts = sorted((x_axis.numpy().copy()))
upper_index = 300
upper_bound = sorted_acts[upper_index]

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 12))

# Named `feature_mask` rather than `filter`, so the builtin `filter` is not shadowed for
# the rest of the kernel session.
feature_mask = vip_features

log_strength = act_strength[feature_mask].log10().numpy()
log_frequency = frequency[feature_mask].log10().numpy()

# (x values, y values, title, x label, y label) for each panel, left to right.
panels = [
    (log_strength, reference_scales[feature_mask],
     "Scale vs Activation Strength", "Log10 Activation Strength", "Scale"),
    (log_frequency, reference_scales[feature_mask],
     "Scale vs Frequency", "Log10 Frequency", "Scale"),
    (log_strength, additive_suppression[feature_mask],
     "Activation Change vs Activation Strength", "Log10 Activation Strength", "Activation Change"),
]
for panel_ax, (xs, ys, title, xlabel, ylabel) in zip(axes, panels):
    panel_ax.scatter(xs, ys, s=5, alpha=0.8)
    panel_ax.set_title(title)
    panel_ax.set_xlabel(xlabel)
    panel_ax.set_ylabel(ylabel)

plt.show()

In [None]:
# Mean over the vip features of (scale - 1) / ((l1_alpha / 2) / act_strength), shown as
# the cell output. If scale = 1 + (l1_alpha / 2) * 512 / act_strength held exactly (the
# relation in the "Math for scaling" note), this value would be 512.
((reference_scales[vip_features] - 1)/(cfg.l1_alphas[0]/2/act_strength[vip_features])).mean()

In [None]:
def tmp(a, strength=1, c=1):
    """Toy objective as a function of the kept fraction `a` of an activation.

    The first term is the squared shortfall ``strength - a * strength`` divided by 512,
    the second is a penalty ``c * a`` divided by 2048.
    NOTE(review): 512 and 2048 are presumably the activation width and the dictionary
    size -- confirm.

    Args:
        a: Fraction of the activation that is kept (scalar or array).
        strength: Unsuppressed activation strength.
        c: Penalty coefficient.

    Returns:
        The value of the toy objective, with the same shape as `a`.
    """
    shortfall_term = (strength - a * strength) ** 2 / 512
    penalty_term = c * a / 2048
    return shortfall_term + penalty_term

# Plot the toy objective on 1000 evenly spaced points of [0, 1]; `tmp` only uses
# elementwise arithmetic, so the whole grid is evaluated in a single call.
xseq = np.linspace(0, 1, num=1000)
plt.plot(xseq, tmp(xseq, strength=1, c=1))

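Math for scaling, writing $\alpha$ for `cfg.l1_alphas[0]` ($= 2\times10^{-3}$), $a$ for `act_strength`, and $s$ for the quantity being solved for:

$$\left(1 - \frac{\alpha}{2s}\cdot 512\right) s = a$$

$$s - \frac{\alpha}{2}\cdot 512 = a$$

$$s = a + \frac{\alpha}{2}\cdot 512 = a + 0.512$$

$$\text{scale} = \frac{s}{a} = 1 + \frac{0.512}{a}$$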

In [None]:
# compute average collision
# Restrict the learned dictionary and the frequency vector to the frequent ("vip")
# features. Rows of `dictionary` are indexed by feature, since the boolean mask selects
# along the first axis.
dictionary = autoTED.get_learned_dict().cpu().detach()[vip_features]
# p: activation frequency of each selected feature.
p = frequency[vip_features]


In [None]:
# Linear system being set up (author's notation, with F = dictionary.T):
#   (F Eye(-pi/(1-pi)) F^T - I) K = Sum[ -pi/(1-pi) * (cd/2) * (fi) ]
# p_frac holds the per-feature weight -p / (1 - p).
p_frac = torch.neg(p) / (1 - p)
# Each column of dictionary.mT (one feature) is scaled by its weight and then
# multiplied back onto the dictionary.
weighted_features = dictionary.mT * p_frac
M = torch.matmul(weighted_features, dictionary)


In [None]:
# Shape check of the weighted, transposed dictionary used to build M (cell output only).
(dictionary.mT * p_frac).shape

In [None]:
# Solve the linear system for K, then turn it into a predicted per-feature scale.
# The l1 term is taken from the configuration instead of a hard-coded 2e-3, so it stays
# consistent with the checkpoints that were loaded.
# NOTE(review): 512 is presumably the activation width -- confirm.
l1_term = cfg.l1_alphas[0] * 512 / 2

Before_K = M - torch.eye(M.shape[0])
# The pseudo-inverse is used so that a singular system still yields a least-squares solution.
K = torch.linalg.pinv(Before_K) @ (l1_term * dictionary.mT @ p_frac)

# ai = si + 1/(1-pi) * (-cd/2 + fi * K)
# si = ai - 1/(1-pi) * (-cd/2 + fi * K)
rhs = -1/(1-p) * (dictionary @ K - l1_term)
# NOTE(review): the final division by 3 is an unexplained constant -- confirm it is
# intended and document where it comes from.
expected_scale = 1 + rhs/act_strength[vip_features]/3

# dotprods = dictionary @ dictionary.mT
# n = dictionary.shape[0]
# epsilon = (dotprods.sum() - n)/(n*(n-1)) # remove diagonal and take mean
# epsilon

In [None]:
# `epsilon` and `av_sparsity` used to exist only in commented-out code, so this cell
# raised NameError on a fresh kernel. They are defined here from those same formulas.
# epsilon: mean off-diagonal dot product between the selected dictionary rows.
# Subtracting n removes the diagonal only if the rows have unit norm -- TODO confirm.
dotprods = dictionary @ dictionary.mT
n = dictionary.shape[0]
epsilon = (dotprods.sum() - n)/(n*(n-1)) # remove diagonal and take mean
# NOTE(review): value copied from the commented-out estimate in this notebook; its
# provenance (presumably the mean number of active features per token) should be confirmed.
av_sparsity = 49.77802658
(1+epsilon*av_sparsity)

In [None]:
# Recompute the activation strength with NaNs (0/0 for features that never fired)
# replaced by 0. `expected_scale` was computed before this reassignment and does not
# use the cleaned values.
act_strength = (dead_feature / (100_000_000 * frequency)).nan_to_num(0)
# suppression = 1 - cfg.l1_alphas[0]/2/act_strength * 512
# av_sparsity = 49.77802658
# av_e = 1 # average collision f_i*f_j
# expected_scale = 1 + 512 * cfg.l1_alphas[0]/2/act_strength/(1+epsilon*av_sparsity)

# Predicted scale against measured scale for the frequent ("vip") features.
fig, ax = plt.subplots()
ax.scatter(expected_scale, reference_scales[vip_features], s=2)
ax.set_xlabel("Expected Scale")
ax.set_ylabel("Measured Scale")
ax.set_title("Measured vs Expected Scale")

In [None]:
# List the matplotlib style names available in this environment (cell output only).
style.available

In [None]:
# Frequent ("vip") features placed by log10 frequency (x) and log10 activation
# strength (y), coloured by their measured scale factor.
style.use("ggplot")

log_frequency = frequency[vip_features].log10()
log_strength = (dead_feature / (100_000_000 * frequency))[vip_features].log10()

plt.figure(dpi=200)
plt.scatter(log_frequency, log_strength, c=reference_scales[vip_features], cmap='viridis')
# `cb` receives the return value of set_ylabel, not the colorbar object itself.
cb = plt.colorbar().ax.set_ylabel('Scale Factor', rotation=270, labelpad=15)
# cb.set_label("Scaling", rotation=270)
plt.xlabel("Log10 Frequency")
plt.ylabel("Log10 Activation Strength")
plt.title("Feature Scaling")
plt.tight_layout()


# Compute Partial Correlations

In [None]:
# Relationship between feature scale (y) and two predictors over the frequent ("vip")
# features: x1 = log10 frequency and x2 = log10 activation strength. Reports partial
# correlations, semi-partial correlations, plain correlations, a two-variable linear
# fit, and univariate isotonic fits.
import numpy as np
from scipy.stats import pearsonr
from sklearn.isotonic import IsotonicRegression
from sklearn.linear_model import LinearRegression


def _to_float_array(values):
    """Return `values` as a 1-D float64 NumPy array, detaching torch tensors first."""
    if isinstance(values, torch.Tensor):
        values = values.detach().cpu().numpy()
    return np.asarray(values, dtype=np.float64).reshape(-1)


def _residualize(target, control):
    """Return what is left of `target` after a least-squares fit on the single regressor `control`."""
    design = control.reshape(-1, 1)
    return target - LinearRegression().fit(design, target).predict(design)


# x1, x2 and y stay torch tensors because later cells pass them to binned_cor.
x1 = frequency[vip_features].log10() # frequency
x2 = (dead_feature/(100_000_000*frequency))[vip_features].log10() # activation strength
# y = additive_suppression[vip_features] # reference_scales[vip_features], additive_suppression[vip_features]
y = reference_scales[vip_features]

# All statistics run on NumPy copies, so torch tensors are never mixed with the NumPy
# arrays returned by scikit-learn.
x1_np = _to_float_array(x1)
x2_np = _to_float_array(x2)
y_np = _to_float_array(y)

# --- Partial correlations: the control variable is removed from BOTH sides. ---
residuals_y_x2 = _residualize(y_np, x2_np)
residuals_x1_x2 = _residualize(x1_np, x2_np)
residuals_y_x1 = _residualize(y_np, x1_np)
residuals_x2_x1 = _residualize(x2_np, x1_np)

partial_corr_x1_y = np.corrcoef(residuals_x1_x2, residuals_y_x2)[0, 1]
partial_corr_x2_y = np.corrcoef(residuals_x2_x1, residuals_y_x1)[0, 1]

print(partial_corr_x1_y, partial_corr_x2_y)

# --- Semi-partial (part) correlations: the control is removed from y only. ---
# These were previously printed as "Partial correlation", which they are not, because
# the predictor itself is left unresidualized.
y_res_x2 = residuals_y_x2
corr_x1_y_given_x2 = pearsonr(x1_np, y_res_x2)[0]
print(f"Semi-partial correlation between x1 and y (x2 removed from y only): {corr_x1_y_given_x2}")

y_res_x1 = residuals_y_x1
corr_x2_y_given_x1 = pearsonr(x2_np, y_res_x1)[0]
print(f"Semi-partial correlation between x2 and y (x1 removed from y only): {corr_x2_y_given_x1}")

# normal corrs:
print("normal corrs", pearsonr(x1_np, y_np)[0], pearsonr(x2_np, y_np)[0])

# --- Joint linear fit y ~ x1 + x2. ---
model = LinearRegression()
X = np.column_stack((x1_np, x2_np))
model.fit(X, y_np)

# Print the coefficients and intercept
print(f"Coefficient for x1: {model.coef_[0]}")
print(f"Coefficient for x2: {model.coef_[1]}")
print(f"Intercept: {model.intercept_}")

# --- Univariate isotonic fits f(x1) = y and g(x2) = y, compared by MSE. ---
# NOTE(review): IsotonicRegression is left at its defaults, as before; confirm the
# default monotonic direction suits both predictors.
iso_reg_x1 = IsotonicRegression().fit(x1_np, y_np)
y_pred_x1 = iso_reg_x1.predict(x1_np)

iso_reg_x2 = IsotonicRegression().fit(x2_np, y_np)
y_pred_x2 = iso_reg_x2.predict(x2_np)

mse_x1_isotonic = np.mean((y_np - y_pred_x1) ** 2)
mse_x2_isotonic = np.mean((y_np - y_pred_x2) ** 2)

print(f"Isotonic MSEs: {mse_x1_isotonic}, {mse_x2_isotonic}")





In [None]:

# BINNING
def binned_cor(x, y, binning_val, num_bins = 10, exclude_bottom=True):
    """Average the Pearson correlation of `x` and `y` inside equal-sized bins of `binning_val`.

    Samples are ordered by `binning_val` and cut into `num_bins` consecutive bins of
    equal size. Correlating within a bin holds `binning_val` roughly constant, so the
    mean of the per-bin correlations approximates the x-y relationship with
    `binning_val` controlled for.

    Args:
        x, y: 1-D tensors (or array-likes) of equal length.
        binning_val: 1-D tensor (or array-like) whose order defines the bins.
        num_bins: Number of bins; must be positive.
        exclude_bottom: When the sample count is not a multiple of `num_bins`, drop the
            surplus samples with the lowest `binning_val` (True) or the highest (False).

    Returns:
        The mean of the per-bin Pearson correlation coefficients.

    Raises:
        ValueError: If `num_bins` is not positive or a bin would hold fewer than two samples.
    """
    # Local import: the function must not depend on another cell having imported pearsonr.
    from scipy.stats import pearsonr

    if num_bins <= 0:
        raise ValueError("num_bins must be a positive integer")

    x = torch.as_tensor(x)
    y = torch.as_tensor(y)
    order = torch.argsort(torch.as_tensor(binning_val))

    bin_size = len(order) // num_bins
    if bin_size < 2:
        raise ValueError("each bin needs at least two samples to compute a correlation")

    # Trim only when there really is a surplus: slicing with [:-0] would drop everything.
    surplus = len(order) % num_bins
    if surplus:
        order = order[surplus:] if exclude_bottom else order[:-surplus]
    binned_ids = order.reshape((num_bins, bin_size))

    cors = []
    for bin_ids in binned_ids:
        bin_x = x[bin_ids].detach().cpu().numpy()
        bin_y = y[bin_ids].detach().cpu().numpy()
        cors.append(pearsonr(bin_x, bin_y)[0])

    return sum(cors)/len(cors)

In [None]:
# Correlation of log10 frequency (x1) with scale (y) within 50 bins of log10 activation
# strength (x2); surplus samples are dropped from the low-x2 end.
binned_cor(x1, y, x2, num_bins = 50, exclude_bottom=True)

In [None]:
# Correlation of log10 activation strength (x2) with scale (y) within 50 bins of log10
# frequency (x1); surplus samples are dropped from the high-x1 end.
binned_cor(x2, y, x1, num_bins = 50, exclude_bottom=False)

In [None]:
# Frequent ("vip") features placed by log10 frequency (x) and log10 activation
# strength (y), coloured by the additive change in activation strength. The colour
# range is clipped to [0, 0.3].
style.use("ggplot")

log_frequency = frequency[vip_features].log10()
log_strength = (dead_feature / (100_000_000 * frequency))[vip_features].log10()

plt.figure(dpi=1200)
plt.scatter(
    log_frequency,
    log_strength,
    c=additive_suppression[vip_features],
    vmin=0,
    vmax=0.3,
    cmap='viridis',
)
# `cb` receives the return value of set_ylabel, not the colorbar object itself.
cb = plt.colorbar().ax.set_ylabel('Activation Strength Difference', rotation=270, labelpad=15)
# cb.set_label("Scaling", rotation=270)
plt.xlabel("Log10 Frequency")
plt.ylabel("Log10 Activation Strength")
plt.title("Additive Increase in Activation Strength")
plt.tight_layout()


In [None]:
# Same scatter as the previous cell (log10 frequency vs log10 activation strength,
# coloured by additive change clipped to [0, 0.3]) with the default colormap and
# figure resolution.
plt.scatter(frequency[vip_features].log10(), (dead_feature/(100_000_000*frequency))[vip_features].log10(), c=additive_suppression[vip_features], vmin=0, vmax=0.3)
plt.colorbar()
plt.xlabel("Log10 Frequency")
plt.ylabel("Log10 Activation Strength")
plt.title("Change in Activation Strength")


In [None]:
# Mean scale over the frequent ("vip") features (cell output only).
reference_scales[vip_features].mean()

In [None]:
# Mean scale over the infrequent ("weak") features (cell output only).
reference_scales[weak_features].mean()

In [None]:
# Shape of the full learned dictionary (cell output only).
autoTED.get_learned_dict().shape

In [None]:
# Displays the whole per-feature tensor; consider showing a slice or summary instead.
dead_feature

In [None]:
# Ratio of the transfer-SAE feature scales to the autoTED reference scales, plotted
# against log10 of `dead_feature`.
# NOTE(review): `mode_tsae` and `lower_boundary` are not assigned by any active code in
# this notebook (only by commented-out lines), so this cell raises NameError on a fresh
# kernel until the transfer SAE is loaded and a lower boundary is defined.
# plt.plot(reference_scales)
# mode_tsae = mode_tsaes[lay][l1]
# autoTED = autoTEDs[lay][l1]
# acts = dead_features[lay][l1]

scales = mode_tsae.get_feature_scales()
# This rebinds the global `reference_scales` to the raw (signed, attached) scales.
reference_scales = autoTED.get_feature_scales()
list_scales = list((scales/reference_scales).abs().detach().cpu().numpy())

plt.scatter(dead_feature.log10(), list_scales, s=5)
plt.vlines([np.log10(lower_boundary),np.log10(upper_boundary)],0.7,400, colors=['black'], linestyles='dashed')
# plt.ylim(0.7,1.9)
plt.title("Adjusted Chess Feature Scaling, Pythia 70m Layer 5") # AutoTED Scaling, Pythia 70m Layer 5
plt.xlabel("Log10 Feature Activations")

In [None]:
# Sorted curve of the transfer-SAE / autoTED scale ratio, restricted to one feature group.
# NOTE(review): depends on `mode_tsae`, which no active code in this notebook assigns.
filtering = vip_features # weak_features # transition_features # vip_features
scales = mode_tsae.get_feature_scales()[filtering]
# This rebinds the global `reference_scales` to the filtered, raw scales.
reference_scales = autoTED.get_feature_scales()[filtering]
# feature_vectors = transfer_free.get_learned_dict()[live_features]
list_scales = list((scales/reference_scales).abs().detach().cpu().numpy())
print(f"Mean scaling: {sum(list_scales)/len(list_scales)}")
plt.plot(sorted(list_scales))
plt.title("Adjusted Finetuned Feature Scaling, Pythia 70m Layer 5")
plt.xlabel("Strong Features")
# plt.plot(scales.detach().cpu().numpy())

In [None]:
# Re-plot the sorted scale ratios from the previous cell with a title built from the
# label strings, and save the figure to the current working directory. No file
# extension is given, so matplotlib chooses its default format.
plt.plot(sorted(list_scales))
plt.title(f"Feature Scaling from {capitalized_base_name} to {finetuning_type} 70m LLM (Checkpoint {checkpoint})")
plt.xlabel("Feature")
plt.savefig(f"feature_scaling_{base_name}_{target_name}_{mode}_ckpt{checkpoint}_70m")

In [None]:
# NOTE(review): `live_features` is not defined anywhere in this notebook, and
# `dead_features` is a list of lists here, so this cell cannot run as written.
dead_features[live_features]

In [None]:
# For every (layer, l1_alpha) pair, scatter the transfer-SAE / autoTED scale ratio
# against log10 of the per-feature activation totals. No new figure is opened inside
# the loop, so all pairs are drawn onto the same axes.
for lay, layer in enumerate(cfg.layers):
    for l1, l1_alpha in enumerate(cfg.l1_alphas):
        # The nested lists are indexed by position: [layer position][l1 position].
        mode_tsae = mode_tsaes[lay][l1]
        autoTED = autoTEDs[lay][l1]
        acts = dead_features[lay][l1]

        scales = mode_tsae.get_feature_scales()
        reference_scales = autoTED.get_feature_scales()
        scale_ratio = (scales / reference_scales).abs().detach().cpu().numpy()
        list_scales = list(scale_ratio)

        plt.scatter(acts.log10(), list_scales, s=5)
        plt.title(f"Feature Scaling from {capitalized_base_name} to {finetuning_type} 70m LLM vs Activations (Layer {layer}, L1 {l1_alpha})")
        plt.xlabel("Log10 of Total Feature Activations")
        plt.ylabel("Feature Scaling")
        plt.ylim(0,10)
        # plt.savefig(f"feature_scaling_vs_activations_{base_name}_{target_name}_{mode}_ckpt{checkpoint}_70m")

In [None]:
# AutoTED scale against log10 activation totals for the layer at position 3, visiting
# every third l1 coefficient.
# NOTE(review): mode_tsaes[lay][l1] raises IndexError while the transfer-SAE loading
# stays commented out.
lay = 3
# l1 = 12
for l1 in range(0,len(cfg.l1_alphas), 3):
    layer = cfg.layers[lay]
    l1_alpha = cfg.l1_alphas[l1]
    print(l1_alpha)

    mode_tsae = mode_tsaes[lay][l1]
    autoTED = autoTEDs[lay][l1]
    acts = dead_features[lay][l1]

    # Number of features below / above 10**6 total activation.
    print(sum(acts < 10**6))
    print(sum(acts > 10**6))

    scales = mode_tsae.get_feature_scales()
    reference_scales = autoTED.get_feature_scales()
    # The ratio computed on the next line is immediately overwritten: only the autoTED
    # reference scales are plotted. The alternatives are switched by (un)commenting.
    list_scales = list((scales/reference_scales).abs().detach().cpu().numpy())
    list_scales = list((reference_scales).abs().detach().cpu().numpy())
    # list_scales = list((scales).abs().detach().cpu().numpy())

    plt.scatter(acts.log10(), list_scales, s=5)
    plt.title(f"AutoTED Scaling from {capitalized_base_name} to {finetuning_type} 70m LLM vs Activations (Layer {layer})")
    plt.xlabel("Log10 of Total Feature Activations")
    plt.ylabel("Feature Scaling")
    plt.ylim(0,6)
    # plt.savefig(f"feature_scaling_vs_activations_{base_name}_{target_name}_{mode}_ckpt{checkpoint}_70m")

In [None]:
# Transfer-SAE / autoTED scale ratio against log10 activation totals for the layer at
# position 4, visiting every third l1 coefficient.
# NOTE(review): mode_tsaes[lay][l1] raises IndexError while the transfer-SAE loading
# stays commented out.
lay = 4
# l1 = 12
for l1 in range(0,len(cfg.l1_alphas), 3):
    layer = cfg.layers[lay]
    l1_alpha = cfg.l1_alphas[l1]

    mode_tsae = mode_tsaes[lay][l1]
    autoTED = autoTEDs[lay][l1]
    acts = dead_features[lay][l1]
    # Feature counts in three bands of total activation: < 1e4, between 1e4 and 1e6, > 1e6.
    print(sum(acts < 10**4))
    print(sum((acts > 10**4) * (acts < 10**6)))
    print(sum(acts > 10**6))

    scales = mode_tsae.get_feature_scales()
    reference_scales = autoTED.get_feature_scales()
    list_scales = list((scales/reference_scales).abs().detach().cpu().numpy())
    # list_scales = list((reference_scales).abs().detach().cpu().numpy())
    # list_scales = list((scales).abs().detach().cpu().numpy())

    plt.scatter(acts.log10(), list_scales, s=2)
    plt.title(f"Feature Scaling from {capitalized_base_name} to {finetuning_type} 70m LLM vs Activations (Layer {layer})")
    plt.xlabel("Log10 of Total Feature Activations")
    plt.ylabel("Feature Scaling")
    plt.ylim(0,6)
    # plt.savefig(f"feature_scaling_vs_activations_{base_name}_{target_name}_{mode}_ckpt{checkpoint}_70m")

Outputting Most Scaled Features

In [None]:
# most scaled features
# Rank the frequent ("vip") features of one layer by how strongly the transfer SAE
# rescales them relative to the autoTED reference, and print both ends of the ranking.
# NOTE(review): mode_tsaes[lay][l1] raises IndexError while the transfer-SAE loading
# stays commented out.
lay = 4
l1 = 0
layer = cfg.layers[lay]
l1_alpha = cfg.l1_alphas[l1]

print(f"Layer {lay} (0 indexed), l1 {l1_alpha}")

mode_tsae = mode_tsaes[lay][l1]
autoTED = autoTEDs[lay][l1]
acts = dead_features[lay][l1]

# Only features with large activations are ranked. The mask is kept under this name
# because later cells index it, and the builtins `filter` / `id` are no longer shadowed.
large_filter = vip_features

scales = mode_tsae.get_feature_scales()
reference_scales = autoTED.get_feature_scales()
list_scales = ((scales/reference_scales).abs().detach().cpu().numpy())

# Feature indices in ascending order of scale ratio, restricted to the masked features.
ranking = np.argsort(list_scales)
keep = np.asarray(large_filter, dtype=bool)
argsorted_scales = ranking[keep[ranking]].tolist()

# The messages used to say "10" while five entries were shown; the count now comes from
# one variable so that text and slice cannot drift apart.
num_shown = 5
print(f"Top {num_shown} scaled are: {argsorted_scales[-num_shown:][::-1]}")
print(f"Bottom {num_shown} scaled are: {argsorted_scales[:num_shown]}")

In [None]:
# Index of the least-scaled feature among the ranked features (cell output only).
argsorted_scales[0]

In [None]:
# Spot check: is feature 662 in the frequent group? The index is hard-coded,
# presumably taken from an earlier run's output -- confirm it is still the one of interest.
large_filter[662]

In [None]:
# Spot check: scale ratio of the same hard-coded feature 662.
list_scales[662]