In [1]:
from src.inference.generate import generative_prior, bayesian_inference, inference
from src.inference.conditional import half_callback_maker
import torch
from src.common.data_prep import dis_t
from src.datasets.dataset_helper import make_collate_fn
from src.datasets.shakespeare.shakespeare import ShakespeareDataset as Ds
from torch.nn.functional import one_hot
from torch.distributions import Categorical
from torch.nn import functional as F

In [2]:
from accelerate import Accelerator
from tqdm.auto import tqdm

In [3]:
from src.nn.discrete_model import DiscreteModel as Model
from src.tokenizers.byte.byte import ByT5Tokenizer as Tk
from src.schedule.vanilla import VanillaScheduler as Scheduler

In [4]:
from src.checkpointing.checkpointing import load_checkpoint

In [5]:
# --- Run configuration -------------------------------------------------------
# Seed torch's global RNG so that the shuffled example drawn from the
# DataLoader and any random sampling done during inference are repeatable
# after a kernel restart.
seed = 42
torch.manual_seed(seed)

accelerator = Accelerator(log_with="tensorboard", project_dir="./runs")

# Checkpoint location: the load cell joins these as
# checkpoint_dir + "/" + checkpoint_name.
checkpoint_name = "shakespeare_byt5_packed_toggleable_chunker"
checkpoint_dir = "./checkpoints"

# Data settings.
# NOTE(review): batch_size is not referenced by the DataLoader cell in this
# notebook, which hard-codes batch_size=1 -- confirm whether it is still needed.
batch_size = 256
seq_len = 128
min_t = 1e-8
num_workers = 3

# Model hyperparameters passed to the Model constructor.
# NOTE(review): presumably these must match the architecture the checkpoint
# was trained with -- confirm against the training configuration.
hidden_size = 768
layers = 6
heads = 12
use_chunkers = False
dropout = 0.1

# Tokenizer and the vocabulary size derived from it.
tk = Tk()
vocab_size = tk.vocab_size()

# NOTE(review): 20.4054 is an undocumented magic constant that is scaled by
# the vocabulary size before being handed to the scheduler -- document its
# origin.
scheduler = Scheduler(20.4054 / vocab_size)

In [6]:
# Collect the constructor arguments in one place so they can be inspected,
# then build the network from them. The following cell passes this instance
# to load_checkpoint.
model_kwargs = {
    "max_seq_len": seq_len,
    "K": vocab_size,
    "hidden_dim": hidden_size,
    "num_heads": heads,
    "layers": layers,
    "dropout": dropout,
    "use_chunkers": use_chunkers,
}
model = Model(**model_kwargs)

In [7]:
# Load the checkpoint into `model`. load_checkpoint is unpacked as a 4-tuple;
# only the first element (the model) is kept, and the two None arguments
# mean nothing else is supplied to be restored here.
checkpoint_path = f"{checkpoint_dir}/{checkpoint_name}"
model, _, _, _ = load_checkpoint(model, None, None, accelerator, checkpoint_path)

In [8]:
# Shakespeare dataset built with the tokenizer and sequence length from the
# configuration cell.
# NOTE(review): train=True presumably selects the training split, so the
# sample inspected below may be text the model was trained on -- confirm.
ds = Ds(
    tk,
    seq_len,
    min_t=min_t,
    train=True,
)

# Collate function created from the scheduler and vocabulary size; it is
# handed to the DataLoader as `collate_fn`.
collate_fn = make_collate_fn(scheduler, vocab_size)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [9]:
# Loader that yields one shuffled example per batch (batch_size=1), collated
# by `collate_fn` and read by `num_workers` worker processes.
dl = torch.utils.data.DataLoader(
    dataset=ds,
    batch_size=1,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,
)

In [10]:
# Hand the DataLoader to Accelerate; the prepared loader replaces the plain
# one under the same name `dl`.
dl = accelerator.prepare(dl)

In [20]:
# Draw a single batch: a fresh iterator is created and only its first item is
# kept. The loader was built with shuffle=True, so which example comes back
# is random.
batch = next(iter(dl))

In [21]:
# Unpack the fields of the collated batch by name.
model_input = batch["model_input"]
t = batch["t"]
ground_truth = batch["ground_truth"]
scheduler_output = batch["scheduler_output"]
mask = batch["mask"]
doc_ids = batch["document_id"]

# Number of steps handed to the sampler.
steps = 100

# The model is constructed with dropout=0.1, so switch it to evaluation mode
# before sampling; otherwise dropout may stay active and perturb the result.
# NOTE(review): harmless if `inference` already calls model.eval() itself.
model.eval()

# Sampling needs no gradients; disabling autograd avoids building a graph
# across all sampling steps.
# NOTE(review): assumes `inference` does not rely on autograd internally --
# confirm.
with torch.no_grad():
    inference_result = inference(
        model=model,
        scheduler=scheduler,
        num_steps=steps,
        # Batch and sequence sizes are read from the input tensor itself.
        batch_size=model_input.shape[0],
        seq_len=model_input.shape[1],
        K=vocab_size,
        mask=mask,
        masked_input=model_input,
        doc_ids=doc_ids,
        device=model_input.device,
        dtype=model_input.dtype,
    )

In [25]:
# Decode the first sequence of the batch's model input back to text.
tk.decode(model_input[0])

'atMLÆª\x14nxu, 2\x1fe Warck iownsa\x1a hi.Y\nERIgCE E"AAU:\nNay, mark hoR ~kQsstamp\x0e, as h\'were n`ttled:\nOoNbl\'s fv^the b'

In [26]:
# Decode the first ground-truth sequence of the batch, for comparison.
tk.decode(ground_truth[0])

"at her news, while Warwick frowns at his.\n\nPRINCE EDWARD:\nNay, mark how Lewis stamps, as he were nettled:\nI hope all's for the b"

In [27]:
# Decode the first sequence returned by `inference`.
tk.decode(inference_result[0])

"ath princes, there Warwick crowns to him.\n\nPRINCE EDWARD:\nNay, mark him, here stamps, as he were nettled:\nAnd there 's for the b"