In [33]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import AdamW, Muon

from pathlib import Path

from MaskedDiffusion.model import MaskedDiffusion
from MaskingSchedule.MaskingSchedule import LinearSchedule, PolynomialSchedule, GeometricSchedule, CosineSchedule
from Config import Config
from tokenization.tokenization import tokens_to_fen, FENTokens


In [4]:
# Load the pre-tokenised dataset. Row i of each tensor describes the same example
# (checked by the assertion below). Shapes/dtypes from the saved run:
# fen_tokens (N, 76) int8, theme_tokens (N, 66) int8, ratings (N,) float16.
base_path = Path("./dataset")
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered .pt file cannot execute arbitrary code (this is also the torch>=2.6 default).
fen_tokens = torch.load(base_path / "fen_tokens.pt", weights_only=True)
theme_tokens = torch.load(base_path / "theme_tokens.pt", weights_only=True)
ratings = torch.load(base_path / "ratings.pt", weights_only=True)

# Fail early if the files were produced from different row sets.
assert len(fen_tokens) == len(theme_tokens) == len(ratings), "row counts differ between dataset tensors"

print(fen_tokens.shape, theme_tokens.shape, ratings.shape)
print(fen_tokens, theme_tokens, ratings)

torch.Size([5600086, 76]) torch.Size([5600086, 66]) torch.Size([5600086])
tensor([[10,  0,  0,  ..., 37, 40, 43],
        [ 0,  0,  0,  ..., 37, 40, 45],
        [ 0,  0,  0,  ..., 37, 44, 42],
        ...,
        [ 0,  0,  0,  ..., 37, 40, 47],
        [ 0,  0,  0,  ..., 37, 41, 44],
        [ 0,  0,  0,  ..., 37, 42, 41]], dtype=torch.int8) tensor([[1, 1, 1,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [1, 0, 0,  ..., 0, 0, 0],
        [1, 0, 0,  ..., 0, 0, 0]], dtype=torch.int8) tensor([0.4934, 0.3679, 0.3191,  ..., 0.2593, 0.1713, 0.3110],
       dtype=torch.float16)


In [32]:
# Prefer the GPU when one is visible to PyTorch, otherwise run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Small configuration: a single layer and a batch size of 3.
config = Config(n_layers=1, batch_size=3)
model = MaskedDiffusion(config)
model.to(device=device, dtype=torch.float32)  # float32 for now; bfloat16 is an option to try

# Report the number of trainable (requires_grad) parameters.
total = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f"Total parameters: {total:,}")

Total parameters: 12,834,816


In [30]:
# Zip the three aligned tensors into (fen, theme, rating) samples and serve
# them in shuffled mini-batches of config.batch_size.
dataset = TensorDataset(fen_tokens, theme_tokens, ratings)
dataloader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=config.batch_size,
)

In [34]:
# AdamW over all model parameters; learning rate and weight decay come from the config.
optimizer = AdamW(
    model.parameters(),
    lr=config.lr,
    weight_decay=config.weight_decay,
)
# Alternative optimizer to experiment with (same hyper-parameters):
#   optimizer = Muon(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
# NOTE(review): Muon is intended for 2-D weight matrices; confirm it accepts every
# tensor in model.parameters() (biases, embeddings, norms) before switching.

In [35]:
# Smoke-test training loop: run MAX_TRAIN_STEPS optimisation steps, then stop.
MAX_TRAIN_STEPS = 1

# Put layers such as dropout/normalisation in training mode; another cell may
# have changed the model's mode before this one is (re-)run.
model.train()

for step, (fen, theme, rating) in enumerate(dataloader):
    # Move the batch to the model's device: integer ids for the FEN tokens,
    # float32 for the theme vector and rating.
    fen = fen.to(dtype=torch.int32, device=device)
    theme = theme.to(dtype=torch.float32, device=device)
    rating = rating.to(dtype=torch.float32, device=device)

    batch_size = len(fen)
    # Stratified diffusion times: one shared uniform offset plus evenly spaced
    # shifts k / batch_size (k = 1..batch_size), wrapped into [0, 1).
    # Built on `device` so every tensor below shares fen's device.
    t = (torch.rand(1, device=device) + torch.arange(1, batch_size + 1, device=device) / batch_size) % 1
    # Assumes masking_schedule returns one value per entry of t; unsqueeze to
    # (batch, 1) so it broadcasts across the token dimension.
    alpha_t = config.masking_schedule(t).unsqueeze(1)
    # True with probability alpha_t: those positions KEEP their original token,
    # the remaining positions are replaced by the mask token.
    random_mask = torch.rand(fen.shape, device=device) < alpha_t
    masked_fen = torch.where(random_mask, fen, FENTokens.mask)

    optimizer.zero_grad()

    logits = model(masked_fen, theme, rating)
    # NOTE(review): the last run of this cell raised "The size of tensor a (76)
    # must match the size of tensor b (3) at non-singleton dimension 1". Here `t`
    # is 1-D (batch,) while per-token tensors are (batch, 76); check whether
    # elbo_loss expects t as (batch, 1) before changing this call.
    loss = model.elbo_loss(t, logits, fen, masked_fen)

    loss.backward()
    optimizer.step()

    if step + 1 >= MAX_TRAIN_STEPS:
        break


RuntimeError: The size of tensor a (76) must match the size of tensor b (3) at non-singleton dimension 1

In [18]:
# Draw samples from the model given the first `n` theme vectors and ratings.
# The conditioning tensors are moved to the same device the model was moved to,
# so this cell also works when the model lives on CUDA (a no-op on CPU).
n = 2
model.sample(
    theme_tokens[:n].to(device=device, dtype=torch.float32),
    ratings[:n].to(device=device, dtype=torch.float32),
    steps=50,
)

tensor([[19,  3, 13,  6, 21, 11, 43, 27, 17, 29, 40,  7, 45, 33, 18, 22, 41, 46,
         17, 41,  6, 14, 15, 23, 25, 37, 18, 27,  0, 25, 34, 37, 41, 41,  1,  7,
          6, 43, 29, 41,  8, 42,  3,  6, 38, 25,  3, 27, 45, 11,  9, 14, 41, 33,
         15,  8, 17, 13, 18, 33,  4, 42, 22, 27, 25, 39, 30, 27, 41, 17,  4, 29,
         31, 18, 46,  5],
        [35, 45,  0,  7, 31, 31, 30, 15, 12, 41, 11, 24, 38,  0, 36, 25, 45, 27,
         22, 23,  8, 42, 26, 14, 23,  0, 15, 10, 37,  8, 18,  6, 17, 42,  9,  7,
         18,  1, 43, 39,  3, 13, 10,  8,  3, 14, 36, 19,  4, 28, 33, 34, 17,  0,
         26, 40,  6, 29, 41, 34, 26,  6,  4,  7,  9, 41, 26, 29, 18, 27, 12,  2,
         44, 29, 39, 43]])