In [None]:
# Intentionally empty: all environment setup (autoreload, sys.path, imports, device selection)
# is done once in the next cell. Repeating it here would only re-run `%load_ext autoreload`
# (which then prints an "already loaded" notice) and set a `device` that the next cell overwrites.

In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("..")
import torch as th
from pathlib import Path
import json
import numpy as np
import plotly.express as px
import plotly.graph_objects as go

BASE_DIR = Path("..")

CHECKPOINT_DIR = BASE_DIR / "checkpoints" / "feature_scaler"
PLOTS_DIR = BASE_DIR / "plots"
device = "cuda" if th.cuda.is_available() else "cpu"
device

In [25]:
# Run names = sub-directories of CHECKPOINT_DIR, one run per seed (S<n> prefix, presumably the seed).
# The shared suffix "L13-mu0.0e+00-lr1e-02" encodes the common hyper-parameters of all runs.

# Scaler over the selected feature set S_I, initialised to 1
one_init_names = [
    "S4-L13-mu0.0e+00-lr1e-02",
    "S42-L13-mu0.0e+00-lr1e-02",
    "S666-L13-mu0.0e+00-lr1e-02",
]
# Same as above, but initialised to 0
zero_init_names = [
    "S4-L13-mu0.0e+00-lr1e-02-ZeroInit",
    "S42-L13-mu0.0e+00-lr1e-02-ZeroInit",
    "S666-L13-mu0.0e+00-lr1e-02-ZeroInit",
]

# Control baselines ("Random Set" / "Random Vectors" in the plots below)
random_indices_names = [
    "RandomIndicesL13-mu0.0e+00-lr1e-02",
    "RandomIndicesS4-L13-mu0.0e+00-lr1e-02",
    "RandomIndicesS666-L13-mu0.0e+00-lr1e-02",
]
random_sources_names = [
    "RandomSourceL13-mu0.0e+00-lr1e-02",
    "RandomSourceS4-L13-mu0.0e+00-lr1e-02",
    "RandomSourceS666-L13-mu0.0e+00-lr1e-02",
]

# Single run of the "full" scaler, drawn as a reference line in the FVE plot
full_scaler_name = "L13-mu0.0e+00-lr1e-02-full"

# Validation FVE of the CrossCoder baseline; all "relative improvement" numbers are w.r.t. this.
# NOTE(review): hard-coded, presumably copied from a separate evaluation — confirm the source.
base_model_fve = 0.83579

def load_fve(name, checkpoint_dir=None, metric="val/frac_variance_explained"):
    """Return a metric from the last evaluation log of run `name`.

    Reads `<checkpoint_dir>/<name>/last_eval_logs.json` and returns the value stored under
    `metric` (by default the validation fraction of variance explained).

    Args:
        name: run sub-directory name.
        checkpoint_dir: directory containing the runs (str or Path); defaults to the
            notebook-level CHECKPOINT_DIR.
        metric: key to read from the eval-log JSON.

    Raises:
        FileNotFoundError: if the run has no last_eval_logs.json.
        KeyError: if `metric` is missing from the log.
    """
    if checkpoint_dir is None:
        checkpoint_dir = CHECKPOINT_DIR
    log_path = Path(checkpoint_dir) / name / "last_eval_logs.json"
    with open(log_path, "r", encoding="utf-8") as f:
        return json.load(f)[metric]
# Validation FVE of every run, grouped by scaler variant (one entry per seed, same order as the names).
one_init_fve, zero_init_fve, random_indices_fve, random_sources_fve = (
    [load_fve(run) for run in runs]
    for runs in (one_init_names, zero_init_names, random_indices_names, random_sources_names)
)
# The full scaler is a single run, so this is a scalar rather than a per-seed list.
full_scaler_fve = load_fve(full_scaler_name)

In [None]:
# Relative FVE improvement (%) of each scaler variant over the CrossCoder baseline `base_model_fve`.
# Bars: mean over seeds with std error bars (np.std default, ddof=0). Black dots: individual seed
# runs. Dashed blue line: improvement of the full scaler.
one_init_fve_mean = np.mean(one_init_fve)
one_init_fve_std = np.std(one_init_fve)
zero_init_fve_mean = np.mean(zero_init_fve)
zero_init_fve_std = np.std(zero_init_fve)
random_indices_fve_mean = np.mean(random_indices_fve)
random_indices_fve_std = np.std(random_indices_fve)
random_sources_fve_mean = np.mean(random_sources_fve)
random_sources_fve_std = np.std(random_sources_fve)

# Relative improvements, in % of the baseline FVE
rel_one_init = (one_init_fve_mean - base_model_fve) / base_model_fve * 100
rel_zero_init = (zero_init_fve_mean - base_model_fve) / base_model_fve * 100
rel_random_indices = (random_indices_fve_mean - base_model_fve) / base_model_fve * 100
rel_random_sources = (random_sources_fve_mean - base_model_fve) / base_model_fve * 100
rel_full_scaler = (full_scaler_fve - base_model_fve) / base_model_fve * 100

# Std converted to the same relative (%) scale
rel_one_init_std = one_init_fve_std / base_model_fve * 100
rel_zero_init_std = zero_init_fve_std / base_model_fve * 100
rel_random_indices_std = random_indices_fve_std / base_model_fve * 100
rel_random_sources_std = random_sources_fve_std / base_model_fve * 100

bar_labels = ["S_I 1-init", "S_I Zero Init", "Random Set <br> 1-init", "Random Vectors <br> 1-init"]
bar_means = [rel_one_init, rel_zero_init, rel_random_indices, rel_random_sources]
bar_stds = [rel_one_init_std, rel_zero_init_std, rel_random_indices_std, rel_random_sources_std]
# Per-seed relative improvements, in the same order as bar_labels
per_seed_rel = [
    [(fve - base_model_fve) / base_model_fve * 100 for fve in fves]
    for fves in (one_init_fve, zero_init_fve, random_indices_fve, random_sources_fve)
]

fig = px.bar(
    x=bar_labels,
    y=bar_means,
    title="Relative FVE Improvement Over CrossCoder (%)",
    error_y=bar_stds,
)

# horizontal dashed blue line at full scaler improvement
fig.add_hline(y=rel_full_scaler, line_dash="dash", line_color="blue")
# Value labels on the bars (only the bar trace exists at this point, so the dots added below get none)
fig.update_traces(texttemplate='%{y:.2f}%', textposition='auto')

# Label the full-scaler line; on the categorical x-axis, x=2 is the position of the third bar
fig.add_annotation(
    x=2,
    y=rel_full_scaler,
    text="Full Scaler Improvement",
    font=dict(size=16, color="blue"),
)

# Overlay each individual seed run as a black dot on its bar
for label, seed_values in zip(bar_labels, per_seed_rel):
    fig.add_trace(go.Scatter(
        x=[label] * len(seed_values),
        y=seed_values,
        mode='markers',
        marker=dict(color='black', size=8),
        showlegend=False,
    ))

fig.update_layout(font=dict(size=20), width=800, height=600)
fig.update_yaxes(title="Relative FVE Improvement (%)", title_font_size=20)
# no x-axis title: the category labels are self-explanatory
fig.update_xaxes(title="", title_font_size=20)

# Fit the y-range to everything that is drawn (error bars, individual seed dots and the full-scaler
# line), not just the bar means, so no error bar or seed run is clipped out of view.
drawn_values = (
    [m - s for m, s in zip(bar_means, bar_stds)]
    + [m + s for m, s in zip(bar_means, bar_stds)]
    + [v for seed_values in per_seed_rel for v in seed_values]
    + [rel_full_scaler]
)
min_y = min(drawn_values)
max_y = max(drawn_values)
fig.update_layout(yaxis_range=[min_y - 0.1, max_y + 0.1])
fig.show()
fig.write_image(PLOTS_DIR / "scaler_fve_improvement.png")


In [30]:
def act_func(x):
    """Positive activation applied to the raw learned scaler: ELU(x) + 1.

    Equals x + 1 for x > 0 and exp(x) for x <= 0, so the output is never negative
    (it tends to 0 as x -> -inf) and a raw value of 0 maps to 1.
    """
    shifted = th.nn.functional.elu(x, alpha=1.0)
    return 1 + shifted


In [31]:
def load_scalers(names, filename="scaler_0_1.pt"):
    """Load and activate the learned feature scaler of every run in `names`.

    For each run, loads `CHECKPOINT_DIR/<name>/<filename>` onto the CPU, takes its "scaler"
    entry, applies `act_func` (ELU + 1) and converts it to numpy. Results are stacked along a
    new leading axis, so row k belongs to names[k] (presumably one value per feature per row).

    Note: th.load unpickles the file; only use it on trusted, locally produced checkpoints.
    """
    activated = []
    for name in names:
        checkpoint = th.load(CHECKPOINT_DIR / name / filename, map_location="cpu")
        # detach() so .numpy() also works if the stored scaler is a grad-requiring Parameter
        activated.append(act_func(checkpoint["scaler"]).detach().cpu().numpy())
    return np.stack(activated)


one_init_scalers = load_scalers(one_init_names)
zero_init_scalers = load_scalers(zero_init_names)
random_indices_scalers = load_scalers(random_indices_names)
random_sources_scalers = load_scalers(random_sources_names)


In [None]:
# Fraction of "dead" feature scalars, i.e. activated scaler values beta_i below `threshold`.
# Computed per run (axis=1), then summarised over seeds as mean ± std (np.std, ddof=0), in percent.
threshold = 5e-2

thres_one_init = (one_init_scalers < threshold).mean(axis=1)
thres_zero_init = (zero_init_scalers < threshold).mean(axis=1)
thres_random_indices = (random_indices_scalers < threshold).mean(axis=1)
thres_random_sources = (random_sources_scalers < threshold).mean(axis=1)

one_init_mean = thres_one_init.mean() * 100
zero_init_mean = thres_zero_init.mean() * 100
random_indices_mean = thres_random_indices.mean() * 100
random_sources_mean = thres_random_sources.mean() * 100

one_init_std = thres_one_init.std() * 100
zero_init_std = thres_zero_init.std() * 100
random_indices_std = thres_random_indices.std() * 100
random_sources_std = thres_random_sources.std() * 100

bar_labels = ["S_I 1-init", "S_I 0-init", "Random Set <br> 1-init", "Random Vectors <br> 1-init"]
bar_means = [one_init_mean, zero_init_mean, random_indices_mean, random_sources_mean]
bar_stds = [one_init_std, zero_init_std, random_indices_std, random_sources_std]

# `:g` prints the threshold without rounding it to one significant digit (0.025 stays "0.025")
fig = px.bar(
    x=bar_labels,
    y=bar_means,
    error_y=bar_stds,
    title=f"Percentage of Dead Feature Scalars (beta_i < {threshold:g})",
)

# Value labels on the bars
fig.update_traces(texttemplate='%{y:.1f}%', textposition='auto')

fig.update_layout(
    font=dict(size=20),
    width=800,
    height=600,
    yaxis_title="Percentage of Dead Scalars",
    xaxis_title=""
)

# y-axis from 0 to 10% above the tallest bar + error bar; fall back to [0, 1] when no scalar is
# dead so the axis range is not the degenerate [0, 0].
max_y = max(m + s for m, s in zip(bar_means, bar_stds))
fig.update_layout(yaxis_range=[0, max_y * 1.1 if max_y > 0 else 1])

fig.show()
fig.write_image(PLOTS_DIR / "scaler_dead_scalars.png")

In [None]:
from plotly.subplots import make_subplots

# Are the same features "dying" across runs? For each pair of runs, 2D-histogram the activated
# scaler values feature by feature (colour = log1p of the count).
#   Row 1: 1-init seed i vs 1-init seed j.   Row 2: 1-init seed i vs 0-init seed j.
# Seed pairs (i, j) with i < j, one subplot column each: (0, 1), (0, 2), (1, 2)
pairs = [(i, j) for i in range(3) for j in range(i + 1, 3)]

fig = make_subplots(
    rows=2, cols=3,
    subplot_titles=[f'1-init Seeds {i+1} vs {j+1}' for i, j in pairs]
    + [f'1-init Seed {i+1} vs 0-init Seed {j+1}' for i, j in pairs],
    vertical_spacing=0.1
)


def add_log_hist2d(fig, x, y, row, col, bins=30):
    """Add a heatmap of log1p(2D-histogram counts of (x, y)) to subplot (row, col) of `fig`.

    Every heatmap uses the layout's shared colour axis, so colours are comparable across panels
    and the single colour bar is valid for all of them.
    """
    counts, x_edges, y_edges = np.histogram2d(np.ravel(x), np.ravel(y), bins=bins)
    fig.add_trace(
        go.Heatmap(
            # Plotly treats n_bins + 1 coordinates as cell boundaries; passing only the left edges
            # would make them cell centres and shift every cell by half a bin.
            x=x_edges,
            y=y_edges,
            # histogram2d indexes counts[x_bin, y_bin]; heatmap z is z[y][x], hence the transpose
            z=np.log1p(counts.T),
            coloraxis="coloraxis",
        ),
        row=row, col=col,
    )


# First row: one-init seed comparisons
for col, (i, j) in enumerate(pairs, start=1):
    add_log_hist2d(fig, one_init_scalers[i], one_init_scalers[j], row=1, col=col)

# Second row: one-init vs zero-init comparisons
for col, (i, j) in enumerate(pairs, start=1):
    add_log_hist2d(fig, one_init_scalers[i], zero_init_scalers[j], row=2, col=col)

fig.update_layout(
    title='<b>Are the same features "dying"?</b> <br>2D Distributions of Feature Scaling Values',
    width=1200,
    height=800,
    showlegend=False,
    # One shared colour scale (and a single colour bar) for all six heatmaps
    coloraxis=dict(colorscale='Viridis', colorbar=dict(title='Log Count')),
)

fig.write_image(PLOTS_DIR / "scaler_2d_histo.png")
fig.show()

# Compute refined feature indices

In [34]:
from tools.feature_utils import mask_to_indices

In [None]:
# Boolean mask (despite the name `indices`) of features whose activated scaler is below 1e-4
# in *every* run, across all 1-init and 0-init seeds; the cell shows how many such features exist.
all_scalers = np.concatenate([one_init_scalers, zero_init_scalers])
indices = (all_scalers < 1e-4).all(axis=0)
indices.sum()

In [113]:
# Dump the indices of the dead features (threshold 1e-4) to JSON.
# Serialise first, then open the file, so a serialisation error cannot leave an empty or
# truncated JSON file behind.
# NOTE(review): assumes mask_to_indices returns a JSON-serialisable sequence — confirm.
dead_feature_json = json.dumps(mask_to_indices(th.tensor(indices)))
with open(PLOTS_DIR / "dead_feature_indices_1e-4.json", "w") as f:
    f.write(dead_feature_json)

In [None]:
# Boolean mask (despite the name `indices`) of features whose activated scaler is below 1e-3
# in *every* run, across all 1-init and 0-init seeds; the cell shows how many such features exist.
all_scalers = np.concatenate([one_init_scalers, zero_init_scalers])
indices = (all_scalers < 1e-3).all(axis=0)
indices.sum()

In [115]:
# Dump the indices of the dead features (threshold 1e-3) to JSON.
# Serialise first, then open the file, so a serialisation error cannot leave an empty or
# truncated JSON file behind.
# NOTE(review): assumes mask_to_indices returns a JSON-serialisable sequence — confirm.
dead_feature_json = json.dumps(mask_to_indices(th.tensor(indices)))
with open(PLOTS_DIR / "dead_feature_indices_1e-3.json", "w") as f:
    f.write(dead_feature_json)

In [116]:
# Boolean mask (despite the name `indices`) of features whose activated scaler is below 1e-2
# in *every* run, across all 1-init and 0-init seeds, dumped as feature indices to JSON.
all_scalers = np.concatenate([one_init_scalers, zero_init_scalers])
indices = (all_scalers < 1e-2).all(axis=0)
# Serialise before opening the file so a failure cannot leave a truncated JSON behind.
# NOTE(review): assumes mask_to_indices returns a JSON-serialisable sequence — confirm.
dead_feature_json = json.dumps(mask_to_indices(th.tensor(indices)))
with open(PLOTS_DIR / "dead_feature_indices_1e-2.json", "w") as f:
    f.write(dead_feature_json)
# Last expression of the cell, so the number of dead features is actually displayed
indices.sum()


In [None]:
# Per-feature scatter of activated scaler values: 1-init run vs 0-init run with the same seed
# (index 0 of both one_init_names and zero_init_names is seed S4).
# NOTE(review): `scaled_values_one` / `scaled_values_zero` were not defined anywhere in this
# notebook (NameError on a fresh kernel); the seed-matched choice below is an assumption — confirm.
scaled_values_one = np.ravel(one_init_scalers[0])
scaled_values_zero = np.ravel(zero_init_scalers[0])

# Create scatter plot
fig = go.Figure(
    data=[
        go.Scatter(
            x=scaled_values_one,
            y=scaled_values_zero,
            mode="markers",
            marker=dict(color="blue", opacity=0.5),
        )
    ]
)

# Update layout
fig.update_layout(
    title="Scatter Plot of Feature Scaling Values",
    xaxis_title="One Init Scaling Factor",
    yaxis_title="Zero Init Scaling Factor",
)

fig.show()