### Load the data

In [4]:
import torch
from transformer_lens import HookedTransformer, HookedTransformerConfig
from neel_plotly.plot import line
from helpers import loss_fn
from devinterp.slt.sampler import estimate_learning_coeff_with_summary, estimate_learning_coeff, SGLD
from devinterp.utils import plot_trace

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): RAND_SEED is not referenced by any other cell shown here.
RAND_SEED = 42

LOAD_LOCATION = "../saves/check_point_50/grokking_add_multi_0.7.pth"

# map_location remaps every stored tensor onto `device`, so a checkpoint that
# was written on a CUDA machine can still be opened on a CPU-only one.
# NOTE(security): weights_only=False unpickles arbitrary Python objects, which
# can execute code on load; only open checkpoint files from a trusted source.
cached_data = torch.load(LOAD_LOCATION, map_location=device, weights_only=False)

# Unpack the saved training run into notebook-level names.
state_dict = cached_data['model']
model_checkpoints = cached_data["checkpoints"]
checkpoint_epochs = cached_data["checkpoint_epochs"]
test_losses = cached_data['test_losses']
train_losses = cached_data['train_losses']
add_test_losses = cached_data['add_test_losses']
multi_test_losses = cached_data['multi_test_losses']
max_nums = cached_data['max_nums']
mod_value = cached_data['mod_value']
train_frac = cached_data['train_frac']
addition_frac = cached_data['addition_frac']
train_data = cached_data['train_data']

# Quick sanity check of what was loaded.
print(f"train_frac = {train_frac} addition_frac = {addition_frac}")
print(f"len(train_losses) = {len(train_losses)} len(test_losses) = {len(test_losses)} len(model_checkpoints) = {len(model_checkpoints)}")

train_frac = 0.5 addition_frac = 0.7
len(train_losses) = 6000 len(test_losses) = 6000 len(model_checkpoints) = 120


In [7]:
# NOTE(review): the recorded output of this cell is an ImportError -- the
# `helpers` module that was picked up does not export `get_dataloader`, so
# `train_loader` is never created and the sampling cells that pass
# `loader=train_loader` cannot run until this is resolved.
from helpers import get_dataloader

# Requests shuffled batches of 64 over the cached training data (per the
# keyword arguments; the helper's implementation is not visible here).
train_loader = get_dataloader(train_data, batch_size=64, shuffle=True)

ImportError: cannot import name 'get_dataloader' from 'helpers' (/Users/james/dev/ArithmeticTransformer/src/helpers.py)

### Set Up the Model

In [5]:
# Configuration of the transformer that the saved checkpoints are loaded into.
cfg = HookedTransformerConfig(
    # Initialisation and placement
    seed=999,
    init_weights=True,
    device=device,
    # Token interface, sized from values stored alongside the checkpoints
    n_ctx=3,
    d_vocab=max_nums + 1,
    d_vocab_out=mod_value,
    # Network dimensions
    n_layers=1,
    n_heads=4,
    d_head=32,
    d_model=128,
    d_mlp=512,
    act_fn="relu",
    normalization_type="LN",
)

model = HookedTransformer(cfg)

### Local Learning Coefficient (RLCT) Estimation

#### Config

In [6]:
# Settings for the single-checkpoint LLC estimate below.
# `lr` and `localization` are forwarded to the SGLD sampler as optimizer kwargs.
lr = 5e-4
localization = 100.0
num_chains = 3          # how many independent chains to run
num_draws = 1000        # how many samples to draw per chain
num_burnin_steps = 0    # how many samples to discard at the start of each chain
num_steps_bw_draws = 1  # how many steps to take between each sample

def evaluate(model, data):
    """Compute the loss for one batch and expose the logits it was computed from.

    Args:
        model: Callable that maps a batch of inputs to logits.
        data: A two-element batch that unpacks as ``(inputs, outputs)``.

    Returns:
        A tuple ``(loss, extras)`` where ``loss`` is ``loss_fn(logits, outputs)``
        and ``extras`` is ``{"logits": logits}``.
    """
    inputs, outputs = data

    # Run the forward pass once and reuse the result for both the loss and the
    # returned logits, instead of calling the model twice per batch.
    logits = model(inputs)

    return loss_fn(logits, outputs), {
        "logits": logits
    }

In [None]:
print(f"len(model_checkpoints) = {len(model_checkpoints)}")

# Restore the weights stored at index 10 of the checkpoint list.
model.load_state_dict(model_checkpoints[10])

# Gather the sampler settings from the config cell before making the call.
sgld_options = {"lr": lr, "localization": localization}
chain_options = {
    "num_chains": num_chains,                  # independent chains to run
    "num_draws": num_draws,                    # samples drawn per chain
    "num_burnin_steps": num_burnin_steps,      # samples discarded at the start of each chain
    "num_steps_bw_draws": num_steps_bw_draws,  # steps taken between samples
}

results = estimate_learning_coeff_with_summary(
    model,
    loader=train_loader,
    evaluate=evaluate,
    sampling_method=SGLD,
    optimizer_kwargs=sgld_options,
    device=device,
    online=True,
    **chain_options,
)

# NOTE(review): the series plotted is results["llc/trace"], while the axis
# label and title say "Loss" -- confirm which of the two is intended.
llc_trace = results["llc/trace"]
trace_title = f"Loss Trace, avg LLC = {sum(results['llc/means']) / len(results['llc/means']):.2f}"

plot_trace(
    llc_trace,
    "Loss",
    x_axis="Step",
    title=trace_title,
    plot_mean=False,
    plot_std=False,
    fig_size=(12, 9),
    true_lc=None,
)

In [None]:
from devinterp.slt.sampler import estimate_learning_coeff_with_summary, estimate_learning_coeff, SGLD
from devinterp.utils import plot_trace

def evaluate(model, data):
    """Compute the loss for one batch and expose the logits it was computed from.

    Args:
        model: Callable that maps a batch of inputs to logits.
        data: A two-element batch that unpacks as ``(inputs, outputs)``.

    Returns:
        A tuple ``(loss, extras)`` where ``loss`` is ``loss_fn(logits, outputs)``
        and ``extras`` is ``{"logits": logits}``.
    """
    inputs, outputs = data

    # Run the forward pass once and reuse the result for both the loss and the
    # returned logits, instead of calling the model twice per batch.
    logits = model(inputs)

    return loss_fn(logits, outputs), {
        "logits": logits
    }

# The progress bar below uses tqdm.tqdm, but no cell shown in this notebook
# imports tqdm; import it here so the loop does not raise NameError on a
# fresh kernel.
import tqdm

# One LLC estimate per saved checkpoint, in checkpoint order.
llc_estimates = []

for saved_model in tqdm.tqdm(model_checkpoints):
    # Swap the checkpoint's weights into the shared model before sampling.
    model.load_state_dict(saved_model)
    result = estimate_learning_coeff(
        model,
        loader=train_loader,
        evaluate=evaluate,
        sampling_method=SGLD,
        optimizer_kwargs=dict(lr=5e-4, localization=100.0),
        num_chains=15,          # How many independent chains to run
        num_draws=5,            # How many samples to draw per chain
        num_burnin_steps=20,    # How many samples to discard at the beginning of each chain
        num_steps_bw_draws=1,   # How many steps to take between each sample
        device=device,
        #online=True,
    )
    llc_estimates.append(result)
    print(f"result = {result}")

# Graph the LLC estimates
# NOTE(review): only the estimates are passed, with the x-axis labelled
# "Epoch" -- confirm whether the plot is indexed by checkpoint or by epoch.
line(llc_estimates, xaxis="Epoch", yaxis="LLC", title="Learning Coefficient Estimates", log_y=False, toggle_x=True, toggle_y=True)

### Save the data

In [None]:
# Destination file for the per-checkpoint LLC estimates.
SAVE_LOCATION = "../saves/check_point_50/llc_estimates.pth"
# Writing is opt-in: nothing is saved unless this flag is switched to True.
SAVE = False

if SAVE:
    llc_payload = {"llc_estimates": llc_estimates}
    torch.save(llc_payload, SAVE_LOCATION)

## Make some cool graphs

In [None]:
# Plot the LLC estimates collected by the checkpoint sweep (same call as the
# one at the end of the sweep cell).
line(llc_estimates, xaxis="Epoch", yaxis="LLC", title="Learning Coefficient Estimates", log_y=False, toggle_x=True, toggle_y=True)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Combined figure: loss curves on the left axis, LLC estimates on a second
# y-axis sharing the same epoch axis.
#
# NOTE(review): `test_losses_avg`, `train_losses_avg` and `colors` are not
# defined in any cell shown in this notebook; they must already exist in the
# kernel (presumably rolling averages of the loaded losses and a colour
# palette with at least five entries) -- confirm where they come from.

# The loss curves are thinned to every LOSS_STRIDE-th entry before plotting.
LOSS_STRIDE = 10
test_curve = test_losses_avg[::LOSS_STRIDE]
train_curve = train_losses_avg[::LOSS_STRIDE]

# Derive the x-axes from the data instead of hard-coding a 4000-epoch run:
# the loaded run reports 6000 losses and 120 checkpoints, so fixed ranges of
# 400 and 40 points would not match the y-values handed to `plot`.
epochs = np.arange(len(test_curve)) * LOSS_STRIDE
train_epochs = np.arange(len(train_curve)) * LOSS_STRIDE

# Pair each LLC estimate with its checkpoint's entry in `checkpoint_epochs`
# (assumes that list holds the epoch of each checkpoint, in the same order as
# the estimates -- TODO confirm). Truncate to the shorter of the two so a
# partially completed sweep still plots.
n_llc = min(len(checkpoint_epochs), len(llc_estimates))
lc_epochs = checkpoint_epochs[:n_llc]

# Set up the plot
plt.figure(figsize=(12, 6))
sns.set_style("ticks")

# Create the main axis
ax1 = plt.gca()

# Plot loss curves
ax1.plot(epochs, test_curve, color=colors[0], label='Test Loss', linewidth=2)
ax1.plot(train_epochs, train_curve, color=colors[1], label='Train Loss', linewidth=2)
#ax1.plot(epochs, add_test_losses_avg, color=colors[2], label='Addition Only Test Loss', linewidth=2)
#ax1.plot(epochs, multi_test_losses_avg, color=colors[3], label='Multiplication Only Test Loss', linewidth=2)

# Set up the second y-axis for learning coefficient
ax2 = ax1.twinx()
ax2.plot(lc_epochs, llc_estimates[:n_llc], color=colors[4], linestyle='--', label='Learning Coefficient', linewidth=2)

# Customize the plot
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Rolling Avg Loss', fontsize=12)
ax2.set_ylabel('Learning Coefficient', fontsize=12)
plt.title(f"Loss and LLC During Training - {addition_frac*100:.0f}% Addition", fontsize=14)

# Adjust tick parameters
ax1.tick_params(axis='both', which='major', labelsize=10)
ax2.tick_params(axis='y', which='major', labelsize=10)

# Create a combined legend so both axes share one legend box
lines1, labels1 = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax1.legend(lines1 + lines2, labels1 + labels2, loc='center right') # bbox_to_anchor=(1.1, 1), fontsize=10)

plt.tight_layout()
plt.show()