Notebook for estimating the local learning coefficient (LLC) of a model on the palindrome task.

In [None]:
import torch
import torch
from torch.nn import functional as F

from devinterp.optim import SGLD
from devinterp.slt.sampler import estimate_learning_coeff_with_summary
from devinterp.utils import plot_trace, default_nbeta

from rasp_models.palindrome import check_palindrome
from tracr.haiku_to_pytorch import haiku_to_pytorch, apply


from datasets.dataloaders import makePalindromeDataLoader
from torchinfo import summary

In [None]:
# Build the data loader that serves the palindrome dataset; it is reused by
# the LLC estimation further down.
loader = makePalindromeDataLoader()

In [None]:
# Pick the compute device: CUDA when torch reports it as available, CPU otherwise.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Build the haiku palindrome checker, convert it to a PyTorch model, and move
# the converted model onto the chosen device.
model = check_palindrome()
torch_model = haiku_to_pytorch(model)
torch_model = torch_model.to(DEVICE)

In [None]:
# Loss callback passed to the LLC estimator as its `evaluate` argument.
def evaluate(model, data):
    """Compute the cross-entropy loss of ``model`` on one batch.

    Args:
        model: Object exposing ``forward(inputs)``; its output is used as the
            logits for the loss.
        data: Indexable batch whose first element unpacks into an
            ``(inputs, targets)`` pair.

    Returns:
        A ``(loss, extras)`` tuple where ``extras`` is ``{"logits": ...}``
        holding the raw model output.
    """
    batch_inputs, batch_targets = data[0]

    # NOTE(review): per the original author's note, inputs arrive as a token
    # list such as ['BOS', 'd', 'x', 'c', 'e'] and the model accepts that
    # format directly, so no preprocessing happens here -- confirm against
    # the data loader.
    logits = model.forward(batch_inputs)

    # Classification task, hence cross-entropy.
    return F.cross_entropy(logits, batch_targets), {"logits": logits}

In [None]:
# Get model parameter counts via torchinfo's `summary`.
# The call is left as the bare last expression of the cell so the notebook
# displays whatever it returns.
summary(torch_model)

In [None]:
# Show the converted model's architecture (its printed representation).
print(torch_model)

In [None]:
# Estimate LLC 10 times.
# Each pass runs a fresh SGLD-based estimate on `torch_model` and prints the
# average of the reported 'llc/means' values, rounded to two decimals.
# range(10) so that the number of runs matches the 10 repetitions stated above.
for _ in range(10):
    learning_coeff_stats = estimate_learning_coeff_with_summary(
        torch_model,
        loader=loader,
        evaluate=evaluate,
        sampling_method=SGLD,
        optimizer_kwargs=dict(lr=1e-5, localization=1.0, nbeta=default_nbeta(loader)),
        num_chains=10,  # How many independent chains to run
        num_draws=100,  # How many samples to draw per chain
        num_burnin_steps=0,  # How many samples to discard at the beginning of each chain
        num_steps_bw_draws=1,  # How many steps to take between each sample
        device=DEVICE,
        online=True,
    )
    # Kept at module scope: after the loop, `trace` and `learning_coeff_stats`
    # hold the results of the final run, which the plotting cell reads.
    trace = learning_coeff_stats["loss/trace"]
    llc_means = learning_coeff_stats['llc/means']
    print(round(sum(llc_means) / len(llc_means), 2))

In [None]:
# Plot the loss trace of the most recent LLC run to see how the loss behaves
# as the sampler moves around the loss landscape. The title carries the
# average of that run's 'llc/means' values.
final_llc_means = learning_coeff_stats['llc/means']
final_avg_llc = sum(final_llc_means) / len(final_llc_means)
plot_trace(
    trace,
    "Loss",
    x_axis="Step",
    title=f"Loss Trace, avg LLC = {final_avg_llc:.2f}",
    plot_mean=False,
    plot_std=False,
    fig_size=(12, 9),
    true_lc=None,
)