In [1]:
from sched import scheduler

import torch
from accelerate import Accelerator, DistributedDataParallelKwargs
from torch.optim import AdamW as Opt
from torch.optim.lr_scheduler import ReduceLROnPlateau as ReduceLR
from torch.utils.data import DataLoader

from src.datasets.dataset_helper import make_collate_fn
from src.datasets.shakespeare.shakespeare import ShakespeareDataset as Ds
from src.nn.discrete_model import DiscreteModel as Model
from src.schedule.vanilla import VanillaScheduler as Scheduler
from src.tokenizers.byte.byte import ByT5Tokenizer as Tk
from src.training.train import TrainingContext as Context
from src.training.train import train
from src.checkpointing.checkpointing import load_checkpoint

In [2]:
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(
    log_with="tensorboard", project_dir="./runs", kwargs_handlers=[ddp_kwargs]
)
checkpoint_name = "shakespeare_byt5_packed_ebt"
checkpoint_dir = "./checkpoints"
batch_size = 256
seq_len = 128
min_t = 1e-8
num_workers = 3
hidden_size = 768
layers = 6
heads = 12
tk = Tk()
vocab_size = tk.vocab_size()
scheduler = Scheduler(20.4054 / vocab_size)

In [3]:
model = Model(
    max_seq_len=seq_len,
    K=vocab_size,
    hidden_dim=hidden_size,
    num_heads=heads,
    layers=layers,
    dropout=0.1,
    use_chunkers=False,
)

In [4]:
model, _, _, _ = load_checkpoint(model, None, None, accelerator, checkpoint_dir + f"/{checkpoint_name}")

In [5]:
debug_data = torch.load("debug.pt")

In [6]:
debug_data

{'theta': tensor([[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039],
          [0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039],
          [0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039]]],
        device='cuda:1'),
 't': tensor([[1.0000e-05, 1.0000e-05, 1.0000e-05,  ..., 1.0000e-05, 1.0000e-05,
          1.0000e-05]], device='cuda:1'),
 'mask': tensor([[False, False, False,  ...,  True,  True,  True]], device='cuda:1'),
 'doc_ids': tensor([[  0,   0,   0,  ..., 255, 255, 255]], device='cuda:1')}

In [7]:
theta, t, mask, doc_ids = debug_data["theta"], debug_data["t"], debug_data["mask"], debug_data["doc_ids"]

In [9]:
theta = theta.to(accelerator.device)
t = t.to(accelerator.device)
mask = mask.to(accelerator.device)
doc_ids = doc_ids.to(accelerator.device)

In [10]:
logits, l = model(theta, t, mask, doc_ids)

  return _C._get_float32_matmul_precision()
