In [42]:
import torch
import torch
from torch.nn import functional as F

from devinterp.optim import SGLD
from devinterp.slt.sampler import estimate_learning_coeff_with_summary
from devinterp.utils import plot_trace, default_nbeta

from fractok import check_fractok
from tracr.haiku_to_pytorch import haiku_to_pytorch


from dataloaders import makeFractokDataLoader
from torchinfo import summary

In [43]:
# Data loader built with makeFractokDataLoader's default arguments. It is the
# `loader` handed to the LLC estimation call and to default_nbeta further down.
loader = makeFractokDataLoader()

In [44]:
# Use the GPU when PyTorch reports one is available; otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Build the source model, convert it with haiku_to_pytorch, then place the
# converted model on DEVICE.
# NOTE(review): check_fractok presumably returns a Haiku/tracr model — confirm.
model = check_fractok()
torch_model = haiku_to_pytorch(model)
torch_model = torch_model.to(DEVICE)

In [45]:
def evaluate(model, batch, loss_scale=1000):
    """Compute the scaled MSE loss of ``model`` on a single batch.

    This is the ``evaluate(model, batch)`` callback handed to
    ``estimate_learning_coeff_with_summary`` in this notebook. The optional
    ``loss_scale`` keyword keeps the two-argument call working unchanged.

    Args:
        model: The model to evaluate. It is called directly on ``inputs``.
        batch: Indexable whose first element is an ``(inputs, outputs)`` pair.
            ``inputs`` is passed to the model untouched; the notebook's data
            is a list of token strings such as ``['BOS', 'd', 'x', 'c', 'e']``.
            ``outputs`` is the target tensor and is moved to ``DEVICE``.
        loss_scale: Factor the MSE loss is multiplied by. Defaults to 1000,
            the constant previously hard-coded here.
            NOTE(review): this factor also scales the loss the sampler sees —
            confirm reported LLC values are meant to be on this scale.

    Returns:
        A ``(loss, extras)`` tuple where ``loss`` is the scaled MSE loss tensor
        and ``extras`` is ``{"logits": model_output}``.
    """
    inputs, targets = batch[0]
    targets = targets.to(DEVICE)

    # Call the module itself rather than ``.forward`` so that any registered
    # forward hooks run as part of the call.
    model_output = model(inputs)

    # ``.to`` leaves the tensor untouched when it already lives on DEVICE, so
    # no explicit "is it on CUDA yet?" check is needed.
    model_output = model_output.to(DEVICE)

    # MSE is the active objective; cross-entropy was tried as an alternative
    # in an earlier revision of this cell.
    loss = F.mse_loss(model_output, targets) * loss_scale

    return loss, {"logits": model_output}

In [46]:
# summary(torch_model)

In [47]:
# print(torch_model)

In [48]:
# Repeat the LLC estimation nine times, printing each run's average LLC.
# `learning_coeff_stats` and `trace` are overwritten on every pass, so after
# the loop they hold the values from the last completed run only.
for _ in range(9):
    # Sampler configuration, rebuilt on each pass so default_nbeta(loader) is
    # evaluated once per run.
    sampler_options = {
        "sampling_method": SGLD,
        "optimizer_kwargs": {
            "lr": 1e-5,
            "localization": 1.0,
            "nbeta": default_nbeta(loader),
        },
        "num_chains": 10,  # independent chains to run
        "num_draws": 100,  # samples drawn per chain
        "num_burnin_steps": 1,  # samples discarded at the start of each chain
        "num_steps_bw_draws": 1,  # steps taken between consecutive samples
        "device": DEVICE,
        "online": True,
    }
    learning_coeff_stats = estimate_learning_coeff_with_summary(
        torch_model,
        loader=loader,
        evaluate=evaluate,
        **sampler_options,
    )
    trace = learning_coeff_stats["loss/trace"]

    # Arithmetic mean of the 'llc/means' entries, rounded to two decimals.
    llc_means = learning_coeff_stats['llc/means']
    print(round(sum(llc_means) / len(llc_means), 2))

Chain 0: 100%|██████████| 101/101 [00:01<00:00, 74.95it/s]
Chain 1: 100%|██████████| 101/101 [00:01<00:00, 94.17it/s]
Chain 2: 100%|██████████| 101/101 [00:01<00:00, 96.25it/s]
Chain 3: 100%|██████████| 101/101 [00:01<00:00, 98.96it/s]
Chain 4: 100%|██████████| 101/101 [00:01<00:00, 99.34it/s]
Chain 5: 100%|██████████| 101/101 [00:01<00:00, 99.13it/s]
Chain 6: 100%|██████████| 101/101 [00:01<00:00, 95.90it/s]
Chain 7: 100%|██████████| 101/101 [00:01<00:00, 98.58it/s]
Chain 8: 100%|██████████| 101/101 [00:01<00:00, 97.04it/s]
Chain 9: 100%|██████████| 101/101 [00:01<00:00, 98.61it/s]


1.87


Chain 0: 100%|██████████| 101/101 [00:01<00:00, 98.12it/s]
Chain 1: 100%|██████████| 101/101 [00:01<00:00, 87.40it/s]
Chain 2: 100%|██████████| 101/101 [00:01<00:00, 96.62it/s]
Chain 3: 100%|██████████| 101/101 [00:01<00:00, 99.55it/s]
Chain 4: 100%|██████████| 101/101 [00:01<00:00, 99.46it/s]
Chain 5: 100%|██████████| 101/101 [00:01<00:00, 93.22it/s]
Chain 6: 100%|██████████| 101/101 [00:01<00:00, 97.77it/s]
Chain 7: 100%|██████████| 101/101 [00:01<00:00, 99.50it/s]
Chain 8: 100%|██████████| 101/101 [00:01<00:00, 99.28it/s]
Chain 9: 100%|██████████| 101/101 [00:01<00:00, 99.56it/s]


1.81


Chain 0: 100%|██████████| 101/101 [00:01<00:00, 98.96it/s]
Chain 1: 100%|██████████| 101/101 [00:01<00:00, 75.68it/s]
Chain 2: 100%|██████████| 101/101 [00:01<00:00, 98.63it/s]
Chain 3: 100%|██████████| 101/101 [00:01<00:00, 95.85it/s]
Chain 4:  90%|█████████ | 91/101 [00:00<00:00, 93.74it/s]


KeyboardInterrupt: 

In [None]:
# Plot the loss trace currently held in `trace`, with the average of the
# 'llc/means' entries from `learning_coeff_stats` shown in the title.
trace_plot_options = dict(
    x_axis="Step",
    title=f"Loss Trace, avg LLC = {sum(learning_coeff_stats['llc/means']) / len(learning_coeff_stats['llc/means']):.2f}",
    plot_mean=False,
    plot_std=False,
    fig_size=(12, 9),
    true_lc=None,
)
plot_trace(trace, "Loss", **trace_plot_options)

In [None]:
# import pandas as pd
# from torchinfo import summary
# import itertools

# # Define experiment parameters
# max_seq_lens = [10, 50]
# vocab_ranges = ['small', 'medium', 'large']
# num_trials = 5

# # Storage for results
# results = []

# total_configs = len(max_seq_lens) * len(vocab_ranges)
# config_num = 0

# # Run all combinations
# for max_seq_len, vocab_range in itertools.product(max_seq_lens, vocab_ranges):
#     config_num += 1
#     print(f"\n{'='*60}")
#     print(f"Configuration {config_num}/{total_configs}: vocab={vocab_range}, max_seq_len={max_seq_len}")
#     print(f"{'='*60}")
    
#     # Create model and loader
#     loader = makeFractokDataLoader(max_seq_len=max_seq_len, vocab_size=vocab_range)
#     model = check_fractok(max_seq_len=max_seq_len, vocab_size=vocab_range)
#     torch_model = haiku_to_pytorch(model).to(DEVICE)
    
#     # Get model info
#     model_summary = summary(torch_model, verbose=0)
#     num_params = model_summary.total_params
#     num_blocks = len([m for m in torch_model.modules() if 'block' in str(type(m)).lower()])
    
#     print(f"Model: {num_params} params, {num_blocks} blocks")
    
#     # Convert vocab_range to readable format
#     if vocab_range == 'small':
#         vocab_str = 'a-e'
#     elif vocab_range == 'medium':
#         vocab_str = 'a-m'
#     elif vocab_range == 'large':
#         vocab_str = 'a-z'
    
#     # Run trials
#     for trial in range(num_trials):
#         print(f"  Trial {trial + 1}/{num_trials}...", end=" ", flush=True)
        
#         learning_coeff_stats = estimate_learning_coeff_with_summary(
#             torch_model,
#             loader=loader,
#             evaluate=evaluate,
#             sampling_method=SGLD,
#             optimizer_kwargs=dict(lr=1e-5, localization=1.0, nbeta=default_nbeta(loader)),
#             num_chains=10,
#             num_draws=100,
#             num_burnin_steps=0,
#             num_steps_bw_draws=1,
#             device=DEVICE,
#             online=True,
#         )
#         avg_llc = sum(learning_coeff_stats['llc/means']) / len(learning_coeff_stats['llc/means'])
#         print(f"LLC = {avg_llc:.4f}")
        
#         # Append one row per trial
#         results.append({
#             'Average LLC': round(avg_llc, 4),
#             'num_blocks': num_blocks,
#             'num_params': num_params,
#             'vocab': vocab_str,
#             'max_input_size': max_seq_len
#         })

# # Create and display table
# print("\n" + "="*60)
# print("FINAL RESULTS")
# print("="*60)
# df = pd.DataFrame(results)
# df = df[['Average LLC', 'num_blocks', 'num_params', 'vocab', 'max_input_size']]

# # Format to show 4 decimal places
# pd.options.display.float_format = '{:.4f}'.format
# print(df.to_string(index=False))

# df.to_csv('fractok_results2.csv', index=False, float_format='%.4f')
# print("\nSaved to fractok_results2.csv")