In [4]:
import random
from tracemalloc import start

import numpy as np
import torch
from accelerate import Accelerator
from accelerate.utils import TorchDynamoPlugin

from src.datasets.discrete_helper import collate_fn
from src.datasets.shakespeare.shakespeare import ShakespeareDataset
from src.inference.discrete_inference import bayesian_inference, dis_t
from src.nn.models.discrete_model import DiscreteModel
from src.tokenizers.byt5.byt5_tokenizer import ByT5Tokenizer as Tokenizer
from src.training.checkpoint import CheckpointManager, CheckpointMetadata
from src.training.training import TrainingContext, train_discrete_model

# Seed every RNG the pipeline may touch so runs are reproducible.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Accelerator pinned to CPU (cpu=True).
accelerator = Accelerator(cpu=True)
# Evaluate the FSDP check once and reuse it for both the log line and the
# checkpoint metadata so the two can never disagree. getattr with a None
# default is equivalent to `hasattr(...) and ... is not None`.
is_fsdp = getattr(accelerator.state, "fsdp_plugin", None) is not None
print(f"Using device: {accelerator.device}")
print(f"Num processes: {accelerator.num_processes}")
print(f"Using fsdp: {is_fsdp}")

# Data loading.
tokenizer = Tokenizer()
max_seq_len = 56
batch_size = 256
num_workers = 3  # DataLoader worker processes, shared by both loaders
train_ds = ShakespeareDataset(tokenizer=tokenizer, max_length=max_seq_len)
train_dl = torch.utils.data.DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,
)
# NOTE(review): test_ds is built with exactly the same arguments as train_ds, so
# the "test" loader presumably iterates over the training text. TODO: switch to a
# held-out split once ShakespeareDataset exposes one (its API is not visible here).
test_ds = ShakespeareDataset(tokenizer=tokenizer, max_length=max_seq_len)
test_dl = torch.utils.data.DataLoader(
    test_ds,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=num_workers,
)


# Model hyperparameters are kept in a dict so they can be stored in the
# checkpoint metadata; K is set to the tokenizer's vocabulary size.
model_kwargs = {
    "max_seq_len": max_seq_len,
    "K": tokenizer.vocab_size(),
    "hidden_dim": 256,
    "num_heads": 8,
    "layers": 3,
}
model = DiscreteModel(**model_kwargs)

num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model has {num_trainable} parameters")

# Optimizer kwargs are likewise recorded in the checkpoint metadata.
optimizer_kwargs = {
    "lr": 1e-5,
}
opt = torch.optim.Adam(
    model.parameters(), **optimizer_kwargs  # pyright: ignore[reportArgumentType]
)

metadata = CheckpointMetadata(
    model_kwargs=model_kwargs,
    optimizer_kwargs=optimizer_kwargs,
    is_fsdp=is_fsdp,
    num_accelerators=accelerator.num_processes,
)

checkpoint_name = "shakespeare_dynamo_f32"
checkpoint_dir = f"./checkpoint/{checkpoint_name}"

# load() is called with error_if_not_exists=False, presumably so a missing
# checkpoint is tolerated; start_epoch falls back to 0 when no metadata exists.
checkpoint_manager = CheckpointManager()
checkpoint_manager.prepare(model, opt, accelerator, metadata)
checkpoint_manager.load(checkpoint_dir, error_if_not_exists=False)
start_epoch = (
    checkpoint_manager.metadata.current_epoch if checkpoint_manager.metadata else 0
)

# From here on use the manager's (possibly wrapped) model and optimizer.
model, opt = checkpoint_manager.model, checkpoint_manager.optimizer
train_dl, test_dl = accelerator.prepare(train_dl, test_dl)

Using device: cpu
Num processes: 1
Using fsdp: False


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Model has 2637824 parameters


In [5]:
# Persistent iterator over the (unshuffled) test DataLoader. The interactive
# inference cell advances it with next() and re-creates it from test_dl once it
# is exhausted, so successive runs walk through the test batches in order.
test_iter = iter(test_dl)

In [15]:
from src.datasets.discrete_helper import beta_t, theta, y_distribution

# Interactive conditional-generation demo: the first half of a test sample is
# held fixed as a prompt and the model generates the second half over a
# user-chosen number of Bayesian update iterations.
model.eval()
while True:
    # Ask for the iteration count *before* pulling a batch, so an invalid reply
    # does not silently skip a test sample.
    num_iterations_str = input(
        "Enter number of inference iterations (e.g., 100): "
    ).strip()
    try:
        if not num_iterations_str:
            print("Defaulting to 100 iterations.")
            num_iterations = 100
        else:
            num_iterations = int(num_iterations_str)
        if num_iterations <= 0:
            # No trailing period: the handler below appends ". Please ...".
            raise ValueError("Number of iterations must be positive")
    except ValueError as e:
        print(f"Invalid input: {e}. Please enter a positive integer.")
        continue

    # Get the next test batch, restarting the iterator once it is exhausted.
    try:
        test_data = next(test_iter)
    except StopIteration:
        test_iter = iter(test_dl)
        test_data = next(test_iter, None)
        if test_data is None:
            raise RuntimeError(
                "test_dl produced no batches; nothing to run inference on."
            ) from None

    with torch.no_grad():
        # Use only the first sample; [:1] keeps the batch dimension.
        x = test_data["x"][:1]
        t = test_data["t"][:1]
        beta_1 = test_data["beta_1"][:1]

        _, seq_len, K = x.shape
        total_iterations = torch.ones_like(t) * num_iterations

        # Prompt mask: True for the first half of the sequence. Shape
        # (1, seq_len, 1) so it broadcasts over the last (K) axis.
        indices = torch.arange(seq_len, device=x.device)
        mask = (indices < (seq_len // 2)).unsqueeze(0).unsqueeze(-1)

        # Initial model input: theta(y_distribution(...)) at beta_t(beta_1, 0).
        beta_0 = beta_t(beta_1, t * 0)
        model_input_acc = theta(y_distribution(beta_0, K, x))

        # Prompt "logits": every 0 entry of x becomes -inf; non-zero entries are
        # kept. Assumes x is one-hot over K — TODO confirm.
        x_zero = x.clone().float()
        x_zero[x_zero == 0] = float("-inf")

        # Iterative generation loop (iterations are 1-indexed).
        for i in range(1, num_iterations + 1):
            current_iteration = torch.ones_like(t) * i
            t_curr = dis_t(current_iteration, total_iterations)
            output = model(model_input_acc, t_curr)

            # Prompt positions take the fixed ground-truth logits; the rest
            # take the model's prediction.
            output = torch.where(mask, x_zero, output)

            model_input_acc = bayesian_inference(
                model_input_acc, output, current_iteration, total_iterations, beta_1
            )

        # [0] drops only the batch dimension; .squeeze() would also drop the
        # seq_len or K axis if either happened to be 1.
        original_ids = x[0]
        generated_ids = model_input_acc[0]

        expected_sequence = tokenizer.decode(original_ids)
        generated_sequence = tokenizer.decode(generated_ids)

        print("\n--- Inference Result ---")
        print(f"Expected Sequence:\n{expected_sequence}")
        print("-" * 20)
        print(f"Generated Sequence:\n{generated_sequence}")
        print("--- End of Result ---\n")

    # Ask to continue; anything other than y/yes (case/space-insensitive) stops.
    another_run = input("Perform another inference? (y/n): ").strip().lower()
    if another_run not in ("y", "yes"):
        break



--- Inference Result ---
Expected Sequence:
First Citizen:
Before we proceed any further, hear me sp
--------------------
Generated Sequence:
First Citizen:
Before we proceeanHfu0heV,heacme&}p
--- End of Result ---



In [16]:
# Debug view of the last `output` tensor left in the kernel by the inference cell
# (logits after the prompt positions were overwritten with the -inf-masked ground
# truth). NOTE(review): depends on leaked kernel state — fails on a fresh kernel
# unless the inference cell has run first.
output

tensor([[[   -inf,    -inf,    -inf,  ...,    -inf,    -inf,    -inf],
         [   -inf,    -inf,    -inf,  ...,    -inf,    -inf,    -inf],
         [   -inf,    -inf,    -inf,  ...,    -inf,    -inf,    -inf],
         ...,
         [-3.6965, 15.1197, -2.9202,  ..., 15.1056, 15.1407, 15.1035],
         [-3.7855, 15.1616, -2.9738,  ..., 15.1814, 15.1272, 15.1405],
         [-3.4867, 13.4419, -2.7438,  ..., 15.0606, 13.6649, 14.2573]]])