In [3]:
import numpy as np
from torch import Tensor
from src.datasets.shakespeare.shakespeare import ShakespeareDataset
from torch.nn import functional as F
from src.training.discrete_loss import (
    #alpha_variance_loss,
    divergence_loss,
    format_loss,
    loss,
    variance_loss,
)

In [4]:
import os

import torch
from accelerate import Accelerator
from matplotlib import pyplot as plt
from safetensors.torch import load_file
from tqdm.auto import tqdm

from src.datasets.discrete_helper import collate_fn
from src.inference.discrete_inference import bayesian_inference, dis_t
from src.nn.layers.learnable_schedule import LearnableBetaScheduleNI
from src.nn.models.discrete_model import DiscreteModel
from src.tokenizers.ascii.ascii_tokenizer import ASCIITokenizer as Tokenizer
from src.training.checkpoint import CheckpointManager, CheckpointMetadata
from src.training.training import train_discrete_model

# Sequence/fold settings shared by the dataset and (max_seq_len) the model config.
max_seq_len = 32
folds = 8

# Accelerate entry point; run artifacts are rooted at ./runs/shakespeare.
accelerator = Accelerator(project_dir="./runs/shakespeare")

# The tokenizer is passed to the dataset and later supplies the vocabulary size.
tokenizer = Tokenizer()
dataset = ShakespeareDataset(tokenizer=tokenizer, max_length=max_seq_len, folds=folds)

# Constructor arguments for DiscreteModel. Kept as a dict because the same
# mapping is also recorded in the checkpoint metadata.
model_kwargs = dict(
    max_seq_len=max_seq_len,
    K=tokenizer.vocab_size(),
    hidden_dim=512,
    num_heads=8,
    layers=5,
    # beta_1 from https://arxiv.org/html/2407.20294v2 equation 5
    reference_beta_1=20.4054 / tokenizer.vocab_size(),
    learner_weight=1.0,
    freeze_body=False,
)
model = DiscreteModel(**model_kwargs)

# Both optimizers share one set of hyper-parameters; the same dict is also
# stored in the checkpoint metadata.
optimizer_kwargs = dict(lr=3e-5)

# One Adam instance for the model body ...
body_opt = torch.optim.Adam(
    model.body.parameters(),
    **optimizer_kwargs,  # pyright: ignore[reportArgumentType]
)
# ... and a separate one for the learnable beta schedule.
schedule_opt = torch.optim.Adam(
    model.learnable_beta.parameters(), **optimizer_kwargs  # pyright: ignore[reportArgumentType]
)

# Metadata describing this run's configuration for the checkpoint manager.
metadata = CheckpointMetadata(
    model_kwargs=model_kwargs,
    optimizer_kwargs=optimizer_kwargs,
    # True only when the accelerator state exposes a non-None fsdp_plugin.
    is_fsdp=getattr(accelerator.state, "fsdp_plugin", None) is not None,
    num_accelerators=accelerator.num_processes,
)

# Directory the checkpoint is loaded from (relative to the working directory).
checkpoint_dir = "./checkpoint/shakespeare_ASCII_shannon"
checkpoint_manager = CheckpointManager()
print("Preparing model...")
# Register the model, both optimizers, the accelerator and the run metadata
# with the manager before loading.
# NOTE(review): presumably this wraps the objects via Accelerate -- confirm in
# CheckpointManager.prepare.
checkpoint_manager.prepare(model, body_opt, schedule_opt, accelerator, metadata)
print("Starting checkpoint loading process...")
# NOTE(review): error_if_not_exists=True presumably makes a missing checkpoint
# an error rather than a silent fresh start -- confirm in CheckpointManager.load.
checkpoint_manager.load(checkpoint_dir, error_if_not_exists=True)
print("Finished loading checkpoint")

# From here on, use the objects held by the checkpoint manager rather than
# the ones constructed before loading.
model = checkpoint_manager.model
opt = checkpoint_manager.body_optimizer

# Fail fast (and narrow the static type) if the manager did not hand back
# a DiscreteModel.
assert model is not None
assert isinstance(model, DiscreteModel)

# The learnable beta schedule is a sub-module of the model.
schedule: LearnableBetaScheduleNI = model.learnable_beta
assert isinstance(schedule, LearnableBetaScheduleNI)

Using the latest cached version of the module from /media/john/Tertiary/Data/huggingface/modules/datasets_modules/datasets/karpathy--tiny_shakespeare/b5b13969f09fe8707337f6cb296314fbe06960bd9a868dca39e713e163d27b5e (last modified on Sun Jul 20 14:43:33 2025) since it couldn't be found locally at karpathy/tiny_shakespeare, or remotely on the Hugging Face Hub.


Preparing model...
Starting checkpoint loading process...
Attempting to load checkpoint from epoch 11
Finished loading checkpoint


In [7]:
# Evaluate the schedule for the input 0.0 and the tokenizer's vocabulary size
# (NOTE(review): the first argument is presumably the time t -- confirm against
# LearnableBetaScheduleNI.forward).
# The input tensor is created on the device the schedule's parameters live on,
# instead of a hard-coded "cuda", so the cell also runs without CUDA.
# The module is called directly rather than through .forward so that any
# registered hooks are executed.
schedule(
    torch.tensor([0.0], device=next(schedule.parameters()).device),
    tokenizer.vocab_size(),
)

tensor([0.], device='cuda:0', grad_fn=<SqueezeBackward1>)