In [31]:
import einops
from torch.distributions import Categorical

In [32]:
import os

import numpy as np
from torch import Tensor
from src.datasets.shakespeare.shakespeare import ShakespeareDataset
from torch.nn import functional as F
from src.training.discrete_loss import (
    #alpha_variance_loss,
    divergence_loss,
    format_loss,
    loss,
    variance_loss,
)

import torch
from accelerate import Accelerator
from matplotlib import pyplot as plt
from safetensors.torch import load_file
from tqdm.auto import tqdm

from src.datasets.discrete_helper import collate_fn
from src.inference.discrete_inference import bayesian_inference, dis_t
from src.nn.layers.learnable_schedule import LearnableBetaScheduleNI
from src.nn.models.discrete_model import DiscreteModel
from src.tokenizers.ascii.ascii_tokenizer import ASCIITokenizer as Tokenizer
from src.training.checkpoint import CheckpointManager, CheckpointMetadata
from src.training.training import train_discrete_model

# Rebuild the Shakespeare training setup (tokenizer, dataset, model, optimizers,
# checkpoint manager) so the debug tensors dumped during training can be
# replayed against the same model configuration.
accelerator = Accelerator(project_dir="./runs/shakespeare")
tokenizer = Tokenizer()
batch_size = 64 * 2
max_seq_len = 32
folds = 8
# Later cells split the batch with einops patterns such as "(batch_size folds)",
# which only work when batch_size == effective_batch_size * folds. Floor
# division would silently drop a remainder, so fail fast here instead.
assert batch_size % folds == 0, "batch_size must be divisible by folds"
effective_batch_size = batch_size // folds
dataset = ShakespeareDataset(
    tokenizer=tokenizer, max_length=max_seq_len, folds=folds,
)

model_kwargs = {
    "max_seq_len": max_seq_len,
    "K": tokenizer.vocab_size(),
    "hidden_dim": 512,
    "num_heads": 8,
    "layers": 5,
    # beta_1 from https://arxiv.org/html/2407.20294v2 equation 5
    "reference_beta_1": 20.4054 / tokenizer.vocab_size(),
    "learner_weight": 1.0,
    "freeze_body": False,
}
model = DiscreteModel(**model_kwargs)

optimizer_kwargs = {
    "lr": 3e-5,
}
# Separate optimizers for the network body and for the learnable beta schedule.
body_opt = torch.optim.Adam(
    model.body.parameters(), **optimizer_kwargs  # pyright: ignore[reportArgumentType]
)
schedule_opt = torch.optim.Adam(
    model.learnable_beta.parameters(),
    **optimizer_kwargs,  # pyright: ignore[reportArgumentType]
)

metadata = CheckpointMetadata(
    model_kwargs=model_kwargs,
    optimizer_kwargs=optimizer_kwargs,
    # The FSDP plugin attribute may be missing or None depending on the launch config.
    is_fsdp=getattr(accelerator.state, "fsdp_plugin", None) is not None,
    num_accelerators=accelerator.num_processes,
)

# NOTE(review): checkpoint_dir is not passed to CheckpointManager() or
# prepare() in this cell. Confirm whether it is supposed to be used.
checkpoint_dir = "./checkpoint/shakespeare_shannon_ASCII"
checkpoint_manager = CheckpointManager()
print("Preparing model...")
checkpoint_manager.prepare(model, body_opt, schedule_opt, accelerator, metadata)

# Continue with the model/optimizer held by the manager after prepare()
# (presumably the accelerator-prepared versions).
model, opt = checkpoint_manager.model, checkpoint_manager.body_optimizer

assert model is not None

Using the latest cached version of the module from /media/john/Tertiary/Data/huggingface/modules/datasets_modules/datasets/karpathy--tiny_shakespeare/b5b13969f09fe8707337f6cb296314fbe06960bd9a868dca39e713e163d27b5e (last modified on Sun Jul 20 14:43:33 2025) since it couldn't be found locally at karpathy/tiny_shakespeare, or remotely on the Hugging Face Hub.


Preparing model...


In [33]:
# Load the tensors and model weights dumped during the training step being debugged.
# NOTE(review): torch.load unpickles the file. Only load trusted, locally produced dumps.
debug_data_current_epoch = torch.load(f="debug_data_current_epoch.pt")

In [34]:
# Restore the exact weights that were live when the debug dump was written
# (strict: every key must match).
model.load_state_dict(
    state_dict=debug_data_current_epoch["model_state_dict"],
    strict=True,
)

<All keys matched successfully>

In [35]:
# The learnable beta-schedule sub-module whose gradients are inspected below.
schedule: LearnableBetaScheduleNI = model.learnable_beta

# Runtime type check; also narrows the type for static checkers.
assert isinstance(schedule, LearnableBetaScheduleNI)

In [36]:
# List what the training-time debug dump contains.
debug_data_current_epoch.keys()

dict_keys(['x', 't', 'output', 'alpha', 'formatted_loss', 'l_infty_loss', 'var_loss', 'div_loss', 'l', 'model_state_dict'])

In [37]:
# Replay the saved batch inputs and timesteps from the debug dump.
x = debug_data_current_epoch["x"]
t = debug_data_current_epoch["t"]

In [38]:
# (batch, seq_len, last axis); the einops patterns further down name the last axis K.
x.shape

torch.Size([128, 32, 38])

In [39]:
# Number of noise draws per example: each row is repeated this many times
# before sampling, and the entropy is averaged back over them further down.
samples = 64

In [40]:
# Full batch size, before splitting into folds.
batch_size

128

In [41]:
# Per-fold batch size (batch_size // folds).
effective_batch_size

16

In [42]:
# Re-run the model forward on the saved batch. No torch.no_grad() here, so the
# losses computed from these outputs can be backpropagated.
output, alpha = model(x, t)

In [43]:
# Bundle alpha, the inputs and the model logits into the form consumed by the
# loss helpers in the next cell (folds=folds controls the fold split; see format_loss).
formatted_loss = format_loss(alpha, x, model_output_logits=output, folds=folds)

In [44]:
# Training loss without the divergence term: the L-infinity loss plus the
# variance penalty (divergence_loss is not included here).
l_infty_loss = loss(formatted_loss)
var_loss = variance_loss(formatted_loss)
# NOTE(review): `l` is reassigned later to the entropy-MSE loss; a more specific
# name would avoid confusing the two.
l = l_infty_loss + var_loss

In [45]:
# Inspect the first layer of the schedule's integrand network: its weight and
# its .grad (None until a backward pass reaches this parameter).
schedule.monotonic_nn.integrand.net[0].weight, schedule.monotonic_nn.integrand.net[0].weight.grad

(Parameter containing:
 tensor([[ 0.0608,  0.0501,  0.0179,  ...,  0.0516,  0.0122, -0.0392],
         [-0.0136, -0.0085, -0.0086,  ..., -0.0421, -0.0559, -0.0354],
         [-0.0175, -0.0024,  0.0445,  ..., -0.0073, -0.0203,  0.0003],
         ...,
         [-0.0553,  0.0384, -0.0604,  ..., -0.0348, -0.0207, -0.0231],
         [-0.0095, -0.0529, -0.0431,  ..., -0.0589, -0.0607, -0.0184],
         [-0.0608,  0.0548,  0.0080,  ...,  0.0199, -0.0049,  0.0091]],
        device='cuda:0', requires_grad=True),
 None)

In [46]:
# Disabled on purpose: the gradient problem under investigation was found not
# to be caused by `l` (L-infinity + variance loss).
# l.backward()

In [47]:
# Disabled: without a backward pass the weight's .grad is None (see the
# inspection output above), and torch.isnan(None) would raise.
# torch.isnan(schedule.monotonic_nn.integrand.net[0].weight.grad).any()

In [48]:
# Disabled: without a backward pass the weight's .grad is None, and
# torch.isinf(None) would raise.
# torch.isinf(schedule.monotonic_nn.integrand.net[0].weight.grad).any()

In [49]:
# Load the tensors captured inside the divergence-loss computation; the next
# cell lists its keys.
# NOTE(review): torch.load unpickles the file. Only load trusted, locally produced dumps.
div_loss_debug = torch.load(f="debug_div_loss.pt")

In [50]:
# List what the divergence-loss debug dump contains.
div_loss_debug.keys()

dict_keys(['expected_entropy', 'target_entropy', 'beta_t', 't', 'logits', 'eps'])

In [51]:
# Re-check the same weight/grad after loading the divergence-loss dump
# (.grad is still None because no backward pass has run yet).
schedule.monotonic_nn.integrand.net[0].weight, schedule.monotonic_nn.integrand.net[0].weight.grad

(Parameter containing:
 tensor([[ 0.0608,  0.0501,  0.0179,  ...,  0.0516,  0.0122, -0.0392],
         [-0.0136, -0.0085, -0.0086,  ..., -0.0421, -0.0559, -0.0354],
         [-0.0175, -0.0024,  0.0445,  ..., -0.0073, -0.0203,  0.0003],
         ...,
         [-0.0553,  0.0384, -0.0604,  ..., -0.0348, -0.0207, -0.0231],
         [-0.0095, -0.0529, -0.0431,  ..., -0.0589, -0.0607, -0.0184],
         [-0.0608,  0.0548,  0.0080,  ...,  0.0199, -0.0049,  0.0091]],
        device='cuda:0', requires_grad=True),
 None)

In [52]:
# Recompute beta_t for the timesteps stored in the divergence-loss dump.
# Invoke the module instead of calling .forward() directly so that any
# registered nn.Module hooks run, as with normal module invocation.
# x.shape[-1] is the last (vocabulary) axis, which the einops patterns call K.
beta_t = schedule(div_loss_debug['t'], x.shape[-1])

In [53]:
# Repeat each of the (effective_batch_size * folds) rows `samples` times so that
# every noise draw gets its own row (the entropy is averaged back over
# `samples` further down).
#
# Each expansion is guarded by a row-count check so this cell is idempotent.
# Re-running it would otherwise try to expand already-expanded tensors, which
# einops rejects because the leading axis no longer equals batch_size * folds.
# When samples == 1 the expansion is the identity, so skipping it is harmless.
expanded_rows = effective_batch_size * folds * samples

if beta_t.shape[0] != expanded_rows:
    beta_t = einops.repeat(
        beta_t,
        "(batch_size folds) -> (batch_size folds samples)",
        batch_size=effective_batch_size,
        folds=folds,
        samples=samples,
    )

if x.shape[0] != expanded_rows:
    x = einops.repeat(
        x,
        "(batch_size folds) seq_len K -> (batch_size folds samples) seq_len K",
        batch_size=effective_batch_size,
        folds=folds,
        samples=samples,
    )

In [54]:
# Which inputs carry autograd history: only beta_t (from the schedule) should.
x.requires_grad, beta_t.requires_grad

(False, True)

In [55]:
# Reshape beta_t to (rows, 1, 1) so it broadcasts over the trailing axes of x,
# then form the variance beta_t * K and the mean beta_t * (K * x - 1) of the
# noisy logits, with K = tokenizer.vocab_size() (presumably the BFN sender
# distribution; confirm against the training code).
beta_t_y_dist = beta_t.view(-1, 1, 1)
var = beta_t_y_dist * tokenizer.vocab_size()
mean = beta_t_y_dist * (x * tokenizer.vocab_size() - 1)

In [56]:
# Rule out non-finite values in the inputs used to sample the noisy logits.
print(f"beta_t has nan: {torch.isnan(beta_t).any()}, beta_t has inf: {torch.isinf(beta_t).any()}")
print(f"eps has nan: {torch.isnan(div_loss_debug['eps']).any()}, eps has inf: {torch.isinf(div_loss_debug['eps']).any()}")

beta_t has nan: False, beta_t has inf: False
eps has nan: False, eps has inf: False


In [57]:
# Draw noisy logits as mean + std * eps, reusing the eps stored in the debug
# dump (presumably so the draw matches the original run). var is clamped
# before the square root, presumably to keep the sqrt gradient finite at 0.
# NOTE(review): `dist` holds sampled logits, not a distribution object.
dist = mean + (var.clamp(min=1e-8)**0.5) * div_loss_debug['eps']

In [58]:
# Rule out non-finite values in the mean/variance of the noisy logits.
print(f"mean has nan: {torch.isnan(mean).any()}, mean has inf: {torch.isinf(mean).any()}")
print(f"var has nan: {torch.isnan(var).any()}, var has inf: {torch.isinf(var).any()}")

mean has nan: False, mean has inf: False
var has nan: False, var has inf: False


In [59]:
# Categorical over the unclamped sampled logits.
# NOTE(review): `cat` is immediately overwritten by the clamped variant in the
# next cell, so this only checks that construction succeeds.
cat = Categorical(logits=dist)

In [60]:
# Experiment: clamp the sampled logits to [-10, 10] before building the categorical.
# NOTE(review): the logits observed in this run span roughly [-15.4, 21.0], so
# the clamp is active, and torch.clamp passes zero gradient to clamped entries.
# The entropy/gradient checks below therefore test this clamped variant, not
# necessarily what the training loss computes. Confirm which one is intended.
dist_clamped = torch.clamp(dist, -10, 10)
cat = Categorical(logits=dist_clamped)

In [61]:
# Range of the unclamped sampled logits; values outside [-10, 10] mean the clamp experiment is active.
print(f"dist min: {dist.min()}, dist max: {dist.max()}, dist mean: {dist.mean()}")

dist min: -15.377331733703613, dist max: 21.009212493896484, dist mean: -0.00028902949998155236


In [62]:
# Non-finite check on the unclamped sampled logits.
print(torch.isnan(dist).any(), torch.isinf(dist).any())

tensor(False, device='cuda:0') tensor(False, device='cuda:0')


In [63]:
# Entropy of whichever Categorical was last assigned to `cat` (the clamped one
# when cells run in order): one value per row and sequence position.
entropy = cat.entropy()

In [64]:
# Non-finite check on the per-position entropies.
print(f"entropy has nan: {torch.isnan(entropy).any()}, entropy has inf: {torch.isinf(entropy).any()}")

entropy has nan: False, entropy has inf: False


In [65]:
# Expected (effective_batch_size * folds * samples, seq_len).
entropy.shape

torch.Size([8192, 32])

In [66]:
# Monte-Carlo estimate of the expected entropy: average over the `samples`
# noise draws and over sequence positions, leaving one value per original row.
expected_entropy = einops.reduce(
    entropy,
    "(batch_size folds samples) seq_len -> (batch_size folds)",
    reduction="mean",
    batch_size=effective_batch_size,
    folds=folds,
    samples=samples,
)

In [67]:
# Target entropies saved by the divergence-loss debug dump.
target_entropy = div_loss_debug['target_entropy']

In [68]:
# Non-finite check on the saved target entropies.
print(f"target_entropy has nan: {torch.isnan(target_entropy).any()}, target_entropy has inf: {torch.isinf(target_entropy).any()}")

target_entropy has nan: False, target_entropy has inf: False


In [69]:
# Recompute the entropy-matching loss (mirroring the divergence loss dumped in
# debug_div_loss.pt). F.mse_loss broadcasts mismatched shapes and only emits a
# warning, which would silently yield a wrong loss, so check the shapes first.
assert expected_entropy.shape == target_entropy.shape, (
    f"shape mismatch: {tuple(expected_entropy.shape)} vs {tuple(target_entropy.shape)}"
)
l = F.mse_loss(expected_entropy, target_entropy)

In [70]:
# Gradients accumulate across backward passes, so clear any left on the schedule
# (e.g. from an earlier backward of another loss) before inspecting them below.
# Re-running this cell also requires recomputing `l`: backward frees the graph.
schedule.zero_grad(set_to_none=True)
l.backward()

In [71]:
# Report, per named schedule parameter, whether its gradient contains NaN/inf,
# or was never populated. Naming each line makes a bad parameter identifiable.
for name, param in schedule.named_parameters():
    if param.grad is None:
        print(f"{name}: grad is None")
        continue
    has_nan = torch.isnan(param.grad).any().item()
    has_inf = torch.isinf(param.grad).any().item()
    print(f"{name}: nan={has_nan}, inf={has_inf}")

tensor(False, device='cuda:0') tensor(False, device='cuda:0')
tensor(False, device='cuda:0') tensor(False, device='cuda:0')
tensor(False, device='cuda:0') tensor(False, device='cuda:0')
tensor(False, device='cuda:0') tensor(False, device='cuda:0')
tensor(False, device='cuda:0') tensor(False, device='cuda:0')
tensor(False, device='cuda:0') tensor(False, device='cuda:0')
tensor(False, device='cuda:0') tensor(False, device='cuda:0')
tensor(False, device='cuda:0') tensor(False, device='cuda:0')
tensor(False, device='cuda:0') tensor(False, device='cuda:0')
tensor(False, device='cuda:0') tensor(False, device='cuda:0')
tensor(False, device='cuda:0') tensor(False, device='cuda:0')
tensor(False, device='cuda:0') tensor(False, device='cuda:0')
