In [5]:
import json
import plotly_express as px
import numpy as np
import plotly.graph_objects as go

# Result files being compared: our runs vs. the Gemma Scope runs
# (file names suggest layer 20, 16k width, 200 eval inputs — per the names only).
our_path = "results/eval_results_layer20_width16k_ninputs200.json"
gemma_path = "results/eval_results_gemmascope_layer20_width16k_l0s_139_22_294_38_71_ninputs200.json"

# Load json. Each file is expected to be a JSON object whose values are
# per-run metric dicts (the plotting cells read 'l0', 'frac_variance_explained',
# 'l2_loss', ... from them) — assumed from usage, confirm against the writer.
# JSON is UTF-8 by spec, so pass the encoding explicitly instead of relying on
# the platform's locale-dependent default.
with open(our_path, 'r', encoding='utf-8') as f:
    our_data = json.load(f)

with open(gemma_path, 'r', encoding='utf-8') as f:
    gemma_data = json.load(f)

In [17]:
# Plot L0 vs. fraction of variance explained for both result sets.

# One (l0, metric) point per entry of our_data; iterating .items() pairs each
# key with its metrics dict so the dict is looked up once per entry.
l0s, frac_variance_explained = [], []
for trainer, trainer_stats in our_data.items():
    l0s.append(trainer_stats['l0'])
    frac_variance_explained.append(trainer_stats['frac_variance_explained'])

# Sorting the (l0, metric) tuples orders points by L0 (ties broken by metric)
# so the connecting line is drawn left-to-right; unzipping transposes back
# into two column lists.
sorted_our = sorted(zip(l0s, frac_variance_explained))
l0s_sorted, frac_variance_explained_sorted = map(list, zip(*sorted_our))

# Same extraction and ordering for the Gemma results.
gemma_l0s, gemma_frac_variance_explained = [], []
for key, gemma_stats in gemma_data.items():
    gemma_l0s.append(gemma_stats['l0'])
    gemma_frac_variance_explained.append(gemma_stats['frac_variance_explained'])

sorted_gemma = sorted(zip(gemma_l0s, gemma_frac_variance_explained))
gemma_l0s_sorted, gemma_frac_variance_explained_sorted = map(list, zip(*sorted_gemma))

# One lines+markers trace per result set, ours first, then Gemma.
fig = go.Figure()
for trace_x, trace_y, trace_name, trace_color in (
    (l0s_sorted, frac_variance_explained_sorted, 'Our Model', 'blue'),
    (gemma_l0s_sorted, gemma_frac_variance_explained_sorted, 'Gemma Model', 'red'),
):
    fig.add_trace(go.Scatter(
        x=trace_x,
        y=trace_y,
        mode='lines+markers',
        name=trace_name,
        marker={'symbol': 'circle'},
        line={'color': trace_color},
    ))

fig.update_layout(
    title="L0 vs Fraction of Variance Explained",
    xaxis_title="L0",
    yaxis_title="Fraction of Variance Explained",
    width=600,
    showlegend=True,
)
fig.show()

In [18]:
# Plot L0 vs. L2 loss for both result sets.

# One (l0, l2_loss) point per entry of our_data.
l0s, l2_loss = [], []
for trainer, trainer_stats in our_data.items():
    l0s.append(trainer_stats['l0'])
    l2_loss.append(trainer_stats['l2_loss'])

# Order points by L0 (ties broken by loss) so the line is drawn left-to-right,
# then transpose the sorted pairs back into two column lists.
sorted_our = sorted(zip(l0s, l2_loss))
l0s_sorted, l2_loss_sorted = map(list, zip(*sorted_our))

# Same extraction and ordering for the Gemma results.
gemma_l0s, gemma_l2_loss = [], []
for key, gemma_stats in gemma_data.items():
    gemma_l0s.append(gemma_stats['l0'])
    gemma_l2_loss.append(gemma_stats['l2_loss'])

sorted_gemma = sorted(zip(gemma_l0s, gemma_l2_loss))
gemma_l0s_sorted, gemma_l2_loss_sorted = map(list, zip(*sorted_gemma))

# One lines+markers trace per result set, ours first, then Gemma.
fig = go.Figure()
for trace_x, trace_y, trace_name, trace_color in (
    (l0s_sorted, l2_loss_sorted, 'Our Model', 'blue'),
    (gemma_l0s_sorted, gemma_l2_loss_sorted, 'Gemma Model', 'red'),
):
    fig.add_trace(go.Scatter(
        x=trace_x,
        y=trace_y,
        mode='lines+markers',
        name=trace_name,
        marker={'symbol': 'circle'},
        line={'color': trace_color},
    ))

fig.update_layout(
    title="L0 vs L2 Loss",
    xaxis_title="L0",
    yaxis_title="L2 Loss",
    width=600,
    showlegend=True,
)
fig.show()


In [23]:
# Plot L0 vs. a loss-based score for both result sets.
#
# NOTE(review): despite the `loss_increase` names, each value is
#     1 - (loss_reconstructed - loss_original) / loss_original
# i.e. one minus the relative loss increase (1.0 when the two losses match,
# lower when the reconstructed loss is higher). The title calls this
# "Loss Recovered"; confirm this is the intended metric rather than a
# baseline-normalised loss-recovered score. A loss_original of 0 would raise
# ZeroDivisionError.

l0s, loss_increase = [], []
for trainer, trainer_stats in our_data.items():
    l0s.append(trainer_stats['l0'])
    orig, recon = trainer_stats['loss_original'], trainer_stats['loss_reconstructed']
    increase = 1 - ((recon - orig) / orig)
    loss_increase.append(increase)

# Order points by L0 (ties broken by score), then transpose to column lists.
sorted_our = sorted(zip(l0s, loss_increase))
l0s_sorted, loss_increase_sorted = map(list, zip(*sorted_our))

# Same computation and ordering for the Gemma results.
gemma_l0s, gemma_loss_increase = [], []
for key, gemma_stats in gemma_data.items():
    gemma_l0s.append(gemma_stats['l0'])
    orig, recon = gemma_stats['loss_original'], gemma_stats['loss_reconstructed']
    increase = 1 - ((recon - orig) / orig)
    gemma_loss_increase.append(increase)

sorted_gemma = sorted(zip(gemma_l0s, gemma_loss_increase))
gemma_l0s_sorted, gemma_loss_increase_sorted = map(list, zip(*sorted_gemma))

# One lines+markers trace per result set, ours first, then Gemma.
fig = go.Figure()
for trace_x, trace_y, trace_name, trace_color in (
    (l0s_sorted, loss_increase_sorted, 'Our Model', 'blue'),
    (gemma_l0s_sorted, gemma_loss_increase_sorted, 'Gemma Model', 'red'),
):
    fig.add_trace(go.Scatter(
        x=trace_x,
        y=trace_y,
        mode='lines+markers',
        name=trace_name,
        marker={'symbol': 'circle'},
        line={'color': trace_color},
    ))

fig.update_layout(
    title="Loss Recovered",
    xaxis_title="L0",
    yaxis_title="1 - ((Recon - Original) / Original)",
    width=600,
    showlegend=True,
)
fig.show()


In [24]:
# Plot L0 vs. relative reconstruction bias for both result sets.

# One (l0, bias) point per entry of our_data.
l0s, rel_recon_bias = [], []
for trainer, trainer_stats in our_data.items():
    l0s.append(trainer_stats['l0'])
    rel_recon_bias.append(trainer_stats['relative_reconstruction_bias'])

# Order points by L0 (ties broken by bias) so the line is drawn left-to-right,
# then transpose the sorted pairs back into two column lists.
sorted_our = sorted(zip(l0s, rel_recon_bias))
l0s_sorted, rel_recon_bias_sorted = map(list, zip(*sorted_our))

# Same extraction and ordering for the Gemma results.
gemma_l0s, gemma_rel_recon_bias = [], []
for key, gemma_stats in gemma_data.items():
    gemma_l0s.append(gemma_stats['l0'])
    gemma_rel_recon_bias.append(gemma_stats['relative_reconstruction_bias'])

sorted_gemma = sorted(zip(gemma_l0s, gemma_rel_recon_bias))
gemma_l0s_sorted, gemma_rel_recon_bias_sorted = map(list, zip(*sorted_gemma))

# One lines+markers trace per result set, ours first, then Gemma.
fig = go.Figure()
for trace_x, trace_y, trace_name, trace_color in (
    (l0s_sorted, rel_recon_bias_sorted, 'Our Model', 'blue'),
    (gemma_l0s_sorted, gemma_rel_recon_bias_sorted, 'Gemma Model', 'red'),
):
    fig.add_trace(go.Scatter(
        x=trace_x,
        y=trace_y,
        mode='lines+markers',
        name=trace_name,
        marker={'symbol': 'circle'},
        line={'color': trace_color},
    ))

fig.update_layout(
    title="L0 vs Relative Reconstruction Bias",
    xaxis_title="L0",
    yaxis_title="Relative Reconstruction Bias",
    width=600,
    showlegend=True,
)
fig.show()