# Timing Plots

### Get SAE Training Timing
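Retrieve the SAE training time and validation loss from Weights & Biases (WandB). Enter the `entity`, `project`, `run_id`, `model_type`, and `sae_type` below. Run this cell once for each SAE type (`normal` and `KL_e2e`), because the next cell loads both result files.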



In [None]:
import wandb
import json
import matplotlib.pyplot as plt
import itertools
import os

api = wandb.Api()

entity = "" # WandB Entity
project = "" # WandB Project Name
run_id = "" # WandB Run ID
model_type = "" # either "gemma" or "llama"
sae_type = "" # either "normal" or "KL_e2e" (the LoRA timing cell reads both file names)

# Map the short model tag to the directory name used under data/TopK/.
# Fail fast on an unknown tag instead of silently falling back to Llama.
MODEL_NAMES = {"gemma": "gemma-2-2b", "llama": "Llama-3.2-1b"}
if model_type not in MODEL_NAMES:
    raise ValueError(f"model_type must be one of {sorted(MODEL_NAMES)}, got {model_type!r}")
MODEL_NAME = MODEL_NAMES[model_type]

run = api.run(f"{entity}/{project}/{run_id}")
metric_names = ["training_minutes_between_evals", "val_loss"]
# scan_history yields every logged row that has all requested keys. run.history()
# down-samples (default 500 rows), which would drop intervals and make the
# cumulative training time below too small.
metric_history = list(run.scan_history(keys=metric_names))

time_between = [row["training_minutes_between_evals"] for row in metric_history]
val_loss = [row["val_loss"] for row in metric_history]

# Running total of the per-interval minutes = elapsed training minutes at each eval.
cumulative_time = list(itertools.accumulate(time_between))

# Write data/TopK/{time,val_loss}/<MODEL_NAME>/<model_type>_<sae_type>_SAE.json,
# creating the directories if they do not exist yet.
for kind, values in (("time", cumulative_time), ("val_loss", val_loss)):
    out_dir = f"data/TopK/{kind}/{MODEL_NAME}"
    os.makedirs(out_dir, exist_ok=True)
    with open(f"{out_dir}/{model_type}_{sae_type}_SAE.json", 'w') as f:
        json.dump(values, f, indent=4)


### Get LoRA Training Timing
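Supply the exact paths to the PEFT (LoRA) training times and validation losses; they depend on the experiment parameters used during LoRA training. The paths are set inside the loop over `pct`, so each SAE checkpoint can point to its own run.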

In [None]:
import json

model_type = "" # either "gemma" or "llama"

# Map the short model tag to the directory name used under data/TopK/.
MODEL_NAMES = {"gemma": "gemma-2-2b", "llama": "Llama-3.2-1b"}
if model_type not in MODEL_NAMES:
    raise ValueError(f"model_type must be one of {sorted(MODEL_NAMES)}, got {model_type!r}")
MODEL_NAME = MODEL_NAMES[model_type]


def _load_sae_curve(sae_type):
    """Load the saved cumulative training times (minutes) and validation losses for one SAE type.

    Reads data/TopK/{time,val_loss}/<MODEL_NAME>/<model_type>_<sae_type>_SAE.json
    and returns the two lists as ``(times, losses)``.
    """
    with open(f"data/TopK/time/{MODEL_NAME}/{model_type}_{sae_type}_SAE.json", 'r') as f:
        times = json.load(f)
    with open(f"data/TopK/val_loss/{MODEL_NAME}/{model_type}_{sae_type}_SAE.json", 'r') as f:
        losses = json.load(f)
    return times, losses


def _print_head(title, curve, n=10):
    """Print the first ``n`` (hours, loss) entries of ``curve`` under ``title``."""
    print(title)
    # dict views are not sliceable, so materialize before taking the first n.
    for hours, loss in list(curve.items())[:n]:
        print(f"Time: {hours:.4f} hours, Loss: {loss:.4f}")


# dict[time in hours: val loss]; the saved times are in minutes.
local_val_times, local_val_losses = _load_sae_curve("normal")
local_SAE = {t / 60: loss for t, loss in zip(local_val_times, local_val_losses)}

kl_e2e_times, kl_e2e_losses = _load_sae_curve("KL_e2e")
kl_e2e_SAE = {t / 60: loss for t, loss in zip(kl_e2e_times, kl_e2e_losses)}

_print_head("Local SAE:", local_SAE)
_print_head("\nKL E2E SAE:", kl_e2e_SAE)

peft_data = {}
pct_steps = range(10, 101, 10)
pcts = range(10, 101, 10)
# Indexed by `step` below; requires at least max(pct_steps) + 1 SAE eval points.
sae_times = list(local_SAE.keys())

for pct, step in zip(pcts, pct_steps):
    # PEFT times are offset by the SAE's elapsed time at eval index `step`, so
    # the PEFT curve is expressed in total training hours.
    sae_time = sae_times[step]

    # These are set inside the loop so each `pct` can point to its own run.
    time_path = "" # path to peft times - e.g. "data/TopK/time/gemma-2-2b/expansion_8_L0_64/peft_0-25_rank_64_time_15k.json"
    val_path = "" # path to peft val_losses

    with open(time_path, 'r') as f:
        peft_times = json.load(f)
    with open(val_path, 'r') as f:
        peft_losses = json.load(f)
    peft_data[pct] = {sae_time + t / 60: loss for t, loss in zip(peft_times, peft_losses)}


### Plotting Time Graph

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# Small fonts sized for a 3.25-inch-wide figure.
plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.size': 5,
    'axes.labelsize': 7.25,
    'axes.titlesize': 7.25,
    'xtick.labelsize': 6,
    'ytick.labelsize': 6,
    'legend.fontsize': 5.5,
    'figure.dpi': 300
})

# Per-model constants: (original-model CE loss, y-axis lower limit, y-axis upper limit).
# A single lookup keeps all three consistent for a given model_type.
PLOT_CONSTANTS = {
    "gemma": (2.476, 2.465, 2.70),
    "llama": (2.5481, 2.52, 3.0),
}
model_loss, lower_lim, upper_lim = PLOT_CONSTANTS[model_type]

# Create figure and axis objects explicitly
fig, ax = plt.subplots(figsize=(3.25, 2))

# SAE validation-loss curves against training time (hours).
ax.plot(
    list(local_SAE.keys()), list(local_SAE.values()),
    label='TopK SAE', color='blue', linewidth=0.6
)
ax.plot(
    list(kl_e2e_SAE.keys()), list(kl_e2e_SAE.values()),
    label='e2e SAE', color='crimson', linewidth=0.6
)

# Last (time, loss) point of each PEFT run; joined below into one dashed curve.
final_times = []
final_losses = []

# TopK SAE eval times, indexed by pct_steps to find where each PEFT curve starts.
sae_times = list(local_SAE.keys())
for pct, step in zip(peft_data.keys(), pct_steps):
    peft_times = list(peft_data[pct].keys())
    peft_losses = list(peft_data[pct].values())

    sae_time = sae_times[step]
    sae_loss = local_SAE[sae_time]

    # Connect the SAE point to the start of the drawn PEFT curve.
    # NOTE(review): PEFT points are drawn from index 1 in steps of 2 ([1::2]);
    # why index 0 and every other point are skipped is not visible here - confirm.
    if sae_loss is not None:
        ax.plot(
            [sae_time, peft_times[1]], [sae_loss, peft_losses[1]],
            color='green', linestyle='-', alpha=1, linewidth=0.5
        )

    ax.plot(
        peft_times[1::2], peft_losses[1::2],
        color='green', linestyle='-', linewidth=0.5, alpha=1
    )

    final_times.append(peft_times[-1])
    final_losses.append(peft_losses[-1])

# Curve through the final point of every PEFT run
ax.plot(
    final_times, final_losses, '--', color='darkgreen',
    linewidth=0.5, alpha=1, label='TopK + LoRA'
)

# Legend entries for the curves plotted so far
handles, labels = ax.get_legend_handles_labels()

# Horizontal reference line at the original model's loss; text x-position is in hours
model_loss_color = 'dimgray'
ax.axhline(y=model_loss, color=model_loss_color, linestyle='--', linewidth=0.5, alpha=1)
ax.text(6, model_loss + 0.003, f'Original Model Loss (best achievable) = {model_loss:.4f}', fontsize=5, va='bottom', color=model_loss_color)

# Axis labels
ax.set_xlabel('Training Time (hours)', fontsize=7.5)
ax.set_ylabel('CE Loss', fontsize=7.5)

# Adjust the y-axis range
ax.set_ylim(lower_lim, upper_lim)

# Legend placement
ax.legend(
    handles,
    labels,
    loc='upper right',
    fontsize=5.5,
    frameon=True,
    edgecolor='gray',  # Edge color of the legend box
    fancybox=False,     # Rectangular legend box
    framealpha=1,       # Fully opaque frame
    borderpad=0.3,      # Padding inside the legend box
    labelspacing=0.2,   # Spacing between legend entries
    handlelength=1.25,  # Length of the legend handle
).get_frame().set_linewidth(0.3)  # Legend border linewidth

# Grid and axis limits
ax.grid(True, which='both', linestyle='--', linewidth=0.25, alpha=0.7)
ax.set_xlim(left=0)

# Tick labels with reduced padding
ax.tick_params(labelsize=5.5, width=0.5, pad=1)

for spine in ax.spines.values():
    spine.set_linewidth(0.5)  # Thin axis spines
# Finalize layout with minimal padding
plt.tight_layout(pad=0.1)

# Save with minimal borders
plt.savefig(f'plots/{model_type}_timing_plot.pdf', 
            bbox_inches='tight',
            pad_inches=0.01)
plt.show()


# Scaling Law Plots

### Sparsity, Width, and Layers Frontier
Set `NUM_TRAINING` to the number of training examples, in thousands (15, 30, or 100), used for low-rank adaptation.

In [None]:
import json
import matplotlib.pyplot as plt
import numpy as np

# CE loss of the unmodified base models; the reference point for percent_diff.
BASE_2B = 1.8913393317646476
BASE_9B = 1.7262719157499813
BASE_27B = 1.5732088888573061

NUM_TRAINING = 15 # 15, 30 or 100 (thousands; appears in file names as "{NUM_TRAINING}k")

RANKS = [1, 4, 16, 64, 256]  # LoRA ranks evaluated for every SAE


def abs_diff(L_base, L0_sae, L_sae):
    """Absolute CE improvement: loss before adaptation minus loss after.

    ``L_base`` is unused; it is accepted so the signature matches ``percent_diff``.
    """
    return L0_sae - L_sae


def percent_diff(L_base, L0_sae, L_sae):
    """Fraction of the gap between the pre-adaptation loss and the base loss that adaptation closed."""
    return (L0_sae - L_sae) / (L0_sae - L_base)


def _load_improvements(sae_paths, peft_range):
    """Load rank-sweep CE results for each SAE and compute improvements against BASE_2B.

    Args:
        sae_paths: mapping from the swept value (sparsity / width / layer) to the
            SAE's results directory; ``None`` entries are skipped.
        peft_range: callable mapping the swept value to the ``"start-end"`` string
            in the PEFT result file name (presumably the adapted layer span).

    Returns:
        ``{value: {rank: {"abs_diff": float, "percent_diff": float}}}``
    """
    results = {}
    for key, sae_path in sae_paths.items():
        if sae_path is None:
            continue
        results[key] = {}
        for rank in RANKS:
            filename = f"{sae_path}/peft_{peft_range(key)}_rank_{rank}_CE_increase_{NUM_TRAINING}k.json"
            with open(filename, 'r') as f:
                data = json.load(f)
            L0_sae, L_sae = data["initial"], data["converged"]
            results[key][rank] = {
                "abs_diff": abs_diff(BASE_2B, L0_sae, L_sae),
                "percent_diff": percent_diff(BASE_2B, L0_sae, L_sae),
            }
    return results


SAE_ROOT = "data/scaling/CE_increase/gemma-2-2b"

# Different Sparsity (layer 12, width 16k)
sparsity_SAEs = {
    22: f"{SAE_ROOT}/layer_12_width_16k_average_l0_22",
    41: f"{SAE_ROOT}/layer_12_width_16k_average_l0_41",
    82: f"{SAE_ROOT}/layer_12_width_16k_average_l0_82",
    176: f"{SAE_ROOT}/layer_12_width_16k_average_l0_176",
    445: f"{SAE_ROOT}/layer_12_width_16k_average_l0_445",
}
sparsities = [22, 41, 82, 176, 445]
diff_sparsity_data = _load_improvements(sparsity_SAEs, lambda _: "13-25")

# Different Widths (layer 12)
width_SAEs = {
    16: f"{SAE_ROOT}/layer_12_width_16k_average_l0_82",
    32: f"{SAE_ROOT}/layer_12_width_32k_average_l0_76",
    65: f"{SAE_ROOT}/layer_12_width_65k_average_l0_72",
    131: f"{SAE_ROOT}/layer_12_width_131k_average_l0_67",
    262: f"{SAE_ROOT}/layer_12_width_262k_average_l0_67",
    524: f"{SAE_ROOT}/layer_12_width_524k_average_l0_65",
}
diff_width_data = _load_improvements(width_SAEs, lambda _: "13-25")

# Different Layers (width 16k); the PEFT range starts one layer after the SAE.
layer_SAEs = {
    6: f"{SAE_ROOT}/layer_6_width_16k_average_l0_70",
    9: f"{SAE_ROOT}/layer_9_width_16k_average_l0_73",
    12: f"{SAE_ROOT}/layer_12_width_16k_average_l0_82",
    15: f"{SAE_ROOT}/layer_15_width_16k_average_l0_78",
    18: f"{SAE_ROOT}/layer_18_width_16k_average_l0_74",
}
diff_layer_data = _load_improvements(layer_SAEs, lambda layer: f"{layer + 1}-25")


# Set style parameters
plt.style.use('seaborn-v0_8-paper')
plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.size': 5,
    'axes.labelsize': 9.5,
    'axes.titlesize': 9.5,
    'xtick.labelsize': 7.5,
    'ytick.labelsize': 7.5,
    'legend.fontsize': 6.25,
    'figure.figsize': (6.5, 2.75), # 6.5 in wide x 2.75 in tall
    'figure.dpi': 300
})

# 2x3 grid: columns sweep sparsity / width / layer;
# top row = absolute CE improvement, bottom row = percent CE improvement.
fig, ((ax1, ax2, ax3), (ax4, ax5, ax6)) = plt.subplots(
    2, 3,
    gridspec_kw={'wspace': 0.04, 'hspace': 0.06},
)
all_axes = [ax1, ax2, ax3, ax4, ax5, ax6]

widths = list(diff_width_data.keys())
layers = list(diff_layer_data.keys())

# One color per rank, shared by every column.
rank_colors = plt.cm.plasma(np.linspace(0.1, 0.8, len(RANKS)))
columns = [
    (ax1, ax4, diff_sparsity_data, sparsities),
    (ax2, ax5, diff_width_data, widths),
    (ax3, ax6, diff_layer_data, layers),
]
for top_ax, bottom_ax, data, xs in columns:
    for color, rank in zip(rank_colors, RANKS):
        style = dict(marker='o', color=color, linewidth=0.8, markersize=3, label=f'Rank {rank}')
        top_ax.plot(xs, [data[x][rank]["abs_diff"] for x in xs], **style)
        bottom_ax.plot(xs, [100 * data[x][rank]["percent_diff"] for x in xs], **style)

LABELPAD = 4
ax1.set_ylabel('CE Improvement', labelpad=LABELPAD)
ax4.set_ylabel('CE Improvement (%)', labelpad=11.6)  # Larger pad so both row labels line up
ax4.set_xlabel('Sparsity (L0)', labelpad=LABELPAD)
ax5.set_xlabel('Width (k)', labelpad=LABELPAD)
ax6.set_xlabel('Layer', labelpad=LABELPAD)

# x-axis limits, identical for both rows of each column
width_margin = (max(widths) - min(widths)) * 0.08
layer_margin = 1.0
for ax in (ax1, ax4):
    ax.set_xlim([0, max(sparsities) + 30])
for ax in (ax2, ax5):
    ax.set_xlim([0, max(widths) + width_margin])
for ax in (ax3, ax6):
    ax.set_xlim([min(layers) - layer_margin, max(layers) + layer_margin])

# Share y-limits across each row (union of the autoscaled ranges)
for row in ((ax1, ax2, ax3), (ax4, ax5, ax6)):
    limits = [ax.get_ylim() for ax in row]
    shared = (min(lo for lo, _ in limits), max(hi for _, hi in limits))
    for ax in row:
        ax.set_ylim(shared)

# Grid on major and minor ticks; legends only in the top row
for ax in all_axes:
    ax.grid(True, which='both', linestyle=':', alpha=0.65, linewidth=0.35)
for ax in (ax1, ax2):
    ax.legend(frameon=True, fancybox=True, shadow=False,
              loc='upper right', bbox_to_anchor=(0.98, 0.98))
ax3.legend(frameon=True, fancybox=True, shadow=False,
           loc='upper left', bbox_to_anchor=(0.02, 0.98))

# Hide x labels on the top row and y labels/ticks on the inner columns.
# tick_params hides labels without pinning a FixedFormatter to an auto locator.
for ax in (ax1, ax2, ax3):
    ax.tick_params(axis='x', labelbottom=False)
for ax in (ax2, ax3, ax5, ax6):
    ax.tick_params(axis='y', labelleft=False, length=0)

# Width ticks every 100k, shared by both rows so the vertical grid lines agree
width_ticks = np.arange(0, max(widths) + width_margin, 100)
ax2.set_xticks(width_ticks)
ax5.set_xticks(width_ticks)

# Layer ticks every 3rd layer
layer_ticks = np.arange(int(min(layers)), int(max(layers)) + 1, 3)
ax3.set_xticks(layer_ticks)
ax6.set_xticks(layer_ticks)

# Thin spines on every axis
for ax in all_axes:
    for spine in ax.spines.values():
        spine.set_linewidth(0.3)

# ax1: no x ticks
ax1.tick_params(axis='x', which='major', length=0, width=0.35, pad=1)
ax1.tick_params(axis='x', which='minor', length=0, width=0.3, pad=1)
ax1.tick_params(axis='y', which='major', length=4, width=0.35, pad=1)

# ax4: normal ticks
ax4.tick_params(axis='x', which='major', length=4, width=0.35, pad=1)
ax4.tick_params(axis='x', which='minor', length=2.5, width=0.3, pad=1)
ax4.tick_params(axis='y', which='major', length=4, width=0.35, pad=1)

# ax2, ax3: no x or y ticks
for ax in (ax2, ax3):
    ax.tick_params(axis='x', which='major', length=0, width=0.35, pad=1)
    ax.tick_params(axis='x', which='minor', length=0, width=0.3, pad=1)
    ax.tick_params(axis='y', which='major', length=0, width=0.35, pad=1)

# ax5, ax6: no y ticks
for ax in (ax5, ax6):
    ax.tick_params(axis='x', which='major', length=4, width=0.35, pad=1)
    ax.tick_params(axis='x', which='minor', length=2.5, width=0.3, pad=1)
    ax.tick_params(axis='y', which='major', length=0, width=0.35, pad=1)

# Tight layout; negative spacing pulls the panels together
plt.subplots_adjust(left=0.01, right=0.99, top=0.99, bottom=0.01, wspace=-0.15, hspace=-0.15)
plt.savefig('plots/combined_scaling_laws.pdf', bbox_inches='tight', pad_inches=0.025, dpi=300)
plt.show()


### Model Size Scaling

In [None]:
import json
import matplotlib.pyplot as plt
import numpy as np

# Compact fonts for a 3.25-inch-wide, 2x1 figure
plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.size': 5,
    'axes.labelsize': 6,
    'axes.titlesize': 6,
    'xtick.labelsize': 5.5,
    'ytick.labelsize': 5.5,
    'legend.fontsize': 4.9,
    'figure.dpi': 300
})

# Top: absolute CE improvement; bottom: percent CE improvement
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(3.25, 2.25))

# One SAE per Gemma-2 model size; the lists below are aligned by position.
models = ["2b", "9b", "27b"]
model_sizes = [2, 9, 27]  # Actual sizes in billions
layers = [12, 20, 22]
sparsities = [67, 62, 82]
widths = [131, 131, 131]
total_layers = [26, 42, 46]

# Base-model loss per size. An unknown model now raises KeyError instead of
# leaving base_val unbound or reusing the previous model's value.
base_losses = {"2b": BASE_2B, "9b": BASE_9B, "27b": BASE_27B}

ranks = [1, 4, 16, 64, 256]
colors = plt.cm.magma(np.linspace(0.15, 0.82, len(ranks)))

for rank, color in zip(ranks, colors):
    abs_increases = []
    pct_increases = []

    for model, layer, L0, width, total_layer in zip(models, layers, sparsities, widths, total_layers):
        folder_path = f"data/scaling/CE_increase/gemma-2-{model}/layer_{layer}_width_{width}k_average_l0_{L0}"
        filename = f"{folder_path}/peft_{layer+1}-{total_layer-1}_rank_{rank}_CE_increase_{NUM_TRAINING}k.json"
        with open(filename, 'r') as f:
            data = json.load(f)

        base_val = base_losses[model]
        abs_increases.append(abs_diff(base_val, data["initial"], data["converged"]))
        pct_increases.append(100 * percent_diff(base_val, data["initial"], data["converged"]))

    # Absolute CE improvement vs model size
    ax1.plot(model_sizes, abs_increases, 'o-',
             label=f'Rank {rank}', color=color, linewidth=1, markersize=3)

    # Percent CE improvement vs model size
    ax2.plot(model_sizes, pct_increases, 'o-',
             label=f'Rank {rank}', color=color, linewidth=1, markersize=3)

# Configure axes
ax1.set_ylabel('CE Improvement')
ax2.set_ylabel('CE Improvement (%)', labelpad=9.75)
ax2.set_xlabel('Gemma Size')

for ax in [ax1, ax2]:
    ax.grid(False)  # Turn off default grid; manual dotted lines are drawn instead
    # Vertical lines at each model size
    for size in model_sizes:
        ax.axvline(x=size, color='gray', linestyle=':', alpha=0.65, linewidth=0.35)

    # Horizontal lines at the current y ticks.
    # NOTE(review): get_yticks() can include ticks just outside the view, and
    # axhline autoscales to include them, which may widen the y-range - confirm intended.
    for y in ax.get_yticks():
        ax.axhline(y=y, color='gray', linestyle=':', alpha=0.65, linewidth=0.35)

    if ax is ax1:
        ax.set_xticks([])  # No x ticks on the top subplot
        ax.set_xticklabels([])
    else:
        ax.legend(frameon=True, fancybox=True, shadow=False,
                  loc='lower right', bbox_to_anchor=(1.01, -0.03))
        ax.set_xticks(model_sizes)  # Ticks at model sizes
        ax.set_xticklabels([f'{size}B' for size in model_sizes])
    for spine in ax.spines.values():
        spine.set_linewidth(0.3)
    ax.tick_params(axis='x', which='major', length=4, width=0.35, pad=1)
    ax.tick_params(axis='y', which='major', length=4, width=0.35, pad=1)

# Adjust layout and save
plt.subplots_adjust(left=0.15, right=0.95, top=0.95, bottom=0.1, hspace=0.07)  # Reduced hspace
plt.savefig('plots/model_size_scaling.pdf', bbox_inches='tight', pad_inches=0.025, dpi=300)
plt.show()


### One layer at a time

In [None]:
import json
import matplotlib.pyplot as plt

# 1x1 figure with a twin y-axis: left = absolute, right = percent CE improvement
plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.size': 5,
    'axes.labelsize': 7.25,
    'axes.titlesize': 7.25,
    'xtick.labelsize': 6,
    'ytick.labelsize': 6,
    'legend.fontsize': 5.5,
    'figure.dpi': 300
})

fig, ax1 = plt.subplots(figsize=(3.25, 1.15))
ax2 = ax1.twinx()  # Twin axis sharing x

# Number of layers after the SAE (layer 12) that were adapted
layers = range(1, 13)  # Adjust based on model size
color1 = 'black'
color2 = 'green'

folder_path = "data/scaling/CE_increase/gemma-2-2b/layer_12_width_16k_average_l0_82"

abs_increases = []
pct_increases = []

just_layer_after_abs = None
all_layer_after_abs = None

for layer in layers:
    filename = f"{folder_path}/peft_{12+layer}_rank_64_CE_increase_{NUM_TRAINING}k.json"
    with open(filename, 'r') as f:
        data = json.load(f)
    abs_increase = abs_diff(BASE_2B, data["initial"], data["converged"])
    if layer == 1:
        just_layer_after_abs = abs_increase
    pct_increase = percent_diff(BASE_2B, data["initial"], data["converged"])
    abs_increases.append(abs_increase)
    pct_increases.append(100 * pct_increase)

# Reference run over layers 13-25. The path must be an f-string so NUM_TRAINING
# is substituted; a plain string would look for a literal "{NUM_TRAINING}k" file.
with open(f"{folder_path}/peft_13-25_rank_64_CE_increase_{NUM_TRAINING}k.json", 'r') as f:
    data = json.load(f)
all_layer_after_abs = abs_diff(BASE_2B, data["initial"], data["converged"])

print(f"Absolute CE improvement after 1 layer: {just_layer_after_abs:.4f}")
print(f"Absolute CE improvement after all layers: {all_layer_after_abs:.4f}")
print("PERCENT ACHIEVED", 100*just_layer_after_abs/all_layer_after_abs)

line_color = 'blue'

# Absolute CE improvement per layer count, as bars
ax1.bar(layers, abs_increases, color=line_color, alpha=0.67, label='Rank 64')

# Horizontal line for the all-layers reference
ax1.axhline(y=all_layer_after_abs, color='crimson', linestyle='--', alpha=0.8, linewidth=0.6, label='All layers after')

# Near-invisible bars that give the percent axis (ax2) its data range
ax2.bar(layers, pct_increases, color=line_color, alpha=0.01)

# Configure axes
ax1.set_xlabel('Number of Layers after SAE')
ax1.set_ylabel('CE Improvement', color=color1)
ax2.set_ylabel('CE Improvement (%)', color=color2)

# y-limits from 0 to the max value, plus a small top margin so the reference line is visible
y1_min, y1_max = 0, max(all_layer_after_abs, max(abs_increases))
y2_min, y2_max = 0, max(pct_increases)
padding_bottom = 0.0
padding_top = 0.05
ax1.set_ylim(y1_min - padding_bottom*(y1_max-y1_min), y1_max + padding_top*(y1_max-y1_min))
ax2.set_ylim(y2_min - padding_bottom*(y2_max-y2_min), y2_max + padding_top*(y2_max-y2_min))

# Tick colors match their axes
ax1.tick_params(axis='y', colors=color1)
ax2.tick_params(axis='y', colors=color2)

# Dotted grid on the left axis (major and minor ticks)
ax1.grid(True, which='both', linestyle=':', alpha=0.5, linewidth=0.35)

# Spine and tick parameters
for ax in [ax1, ax2]:
    for spine in ax.spines.values():
        spine.set_linewidth(0.3)
    ax.tick_params(axis='x', which='major', length=3, width=0.35, pad=1)
    ax.tick_params(axis='x', which='minor', length=1.5, width=0.3, pad=1)
    ax.tick_params(axis='y', which='major', length=3, width=0.35, pad=1)

# Color the spines
ax1.spines['left'].set_color(color1)
ax2.spines['right'].set_color(color2)

# Legend
ax1.legend(frameon=True, fancybox=True, shadow=False,
           loc='upper right', bbox_to_anchor=(0.99, 0.99))

# Thin the spines of the figure's current axes (presumably the twin ax2 - confirm)
ax = fig.gca()
for axis in ['top','bottom','left','right']:
    ax.spines[axis].set_linewidth(0.2)

# Adjust layout and save
plt.subplots_adjust(left=0.15, right=0.85, top=0.95, bottom=0.15)
plt.savefig('plots/peft_layer_interp.pdf', bbox_inches='tight', pad_inches=0.025, dpi=300)
plt.show()
