In [1]:
import torch
from torch.utils.data import DataLoader, TensorDataset

from pathlib import Path
from dataclasses import dataclass

from MaskedDiffusion.model import MaskedDiffusion
from tokenization.tokenization import tokens_to_fen, FENTokens

In [2]:
base_path = Path("./dataset")
fen_tokens = torch.load(base_path / "fen_tokens.pt")
theme_tokens = torch.load(base_path / "theme_tokens.pt")
ratings = torch.load(base_path / "ratings.pt")

print(fen_tokens.shape, theme_tokens.shape, ratings.shape)
print(fen_tokens, theme_tokens, ratings)

torch.Size([5600086, 76]) torch.Size([5600086, 66]) torch.Size([5600086])
tensor([[10,  0,  0,  ..., 37, 40, 43],
        [ 0,  0,  0,  ..., 37, 40, 45],
        [ 0,  0,  0,  ..., 37, 44, 42],
        ...,
        [ 0,  0,  0,  ..., 37, 40, 47],
        [ 0,  0,  0,  ..., 37, 41, 44],
        [ 0,  0,  0,  ..., 37, 42, 41]], dtype=torch.int8) tensor([[1, 1, 1,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [1, 0, 0,  ..., 0, 0, 0],
        [1, 0, 0,  ..., 0, 0, 0]], dtype=torch.int8) tensor([0.4934, 0.3679, 0.3191,  ..., 0.2593, 0.1713, 0.3110],
       dtype=torch.float16)


In [3]:
@dataclass
class Config:
    # tokenization
    n_fen_tokens = 48 + 1  # 48 real tokens and one mask token
    n_themes = 66
    rating_dim = 1
    fen_length = 76
    mask_token = FENTokens.mask

    # model
    n_heads = 8
    n_layers = 16
    embed_dim = 1024

    # optimizer and training
    lr = 1e-4
    weight_decay = 1e-4
    batch_size = 3  # 1024
    n_steps = 100_000

In [4]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

config = Config()
model = MaskedDiffusion(config)
model.to(device=device, dtype=torch.float32)  # bfloat16

dataset = TensorDataset(fen_tokens, theme_tokens, ratings)
dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True, num_workers=4)  #, batch_size=config.batch_size, pin_memory=True

In [5]:
total = 0
for name, param in model.named_parameters():
    if param.requires_grad:
        # print(f"{name:40} | {param.numel():,}")
        total += param.numel()
print(f"Total parameters: {total:,}")

Total parameters: 201,671,680


In [6]:
for fen, theme, rating in dataloader:
    print(fen.shape, theme.shape, rating.shape)
    fen, theme, rating = fen.to(dtype=torch.int32, device=device), theme.to(dtype=torch.int32, device=device), rating.to(dtype=torch.float32, device=device)
    print(fen[0])
    result = model(fen, theme, rating)
    print(result[:, :, :49].argmax(dim=2)[0])
    print(result.shape)
    break


torch.Size([3, 76]) torch.Size([3, 66]) torch.Size([3])
tensor([ 0,  0,  0,  4,  0,  0,  0,  0,  4,  0,  0,  0,  0,  7,  7,  7,  0,  7,
         0,  0,  0,  8,  0,  0,  0,  3, 12,  0,  0,  0,  0,  0,  0,  0,  1, 10,
         7,  0,  0,  0,  0,  0,  2,  0,  0,  0,  0,  1,  0, 11,  0,  0,  0,  1,
         1,  0,  0,  0,  0,  0,  0,  0,  6,  0, 13, 15, 15, 15, 15, 20, 20, 37,
        44, 37, 40, 47], dtype=torch.int32)
tensor([30, 30, 30, 47, 30, 30, 30, 30, 47, 30, 30, 30, 30, 47, 47, 47, 30, 47,
        30, 30, 30, 47, 30, 30, 30, 30, 47, 30, 30, 30, 30, 30, 30, 30, 47, 47,
        47, 30, 30, 30, 30, 30, 47, 30, 30, 30, 30, 47, 30, 47, 30, 30, 30, 47,
        47, 30, 30, 30, 30, 30, 30, 30,  0, 30, 47, 47, 47, 47, 47, 47, 47, 47,
        30, 47, 47, 47])
torch.Size([3, 76, 49])


In [7]:
tokens = result[:, :, :49].argmax(dim=2)
print(tokens_to_fen(tokens[0]))
print(tokens_to_fen(tokens[1]))
print(tokens_to_fen(tokens[2]))

NotLegalPositionError: The tokens do not yield a legal position

In [14]:
tokens = fen_tokens[0:1].clone().to(int)
tokens2 = tokens.clone().to(int)
tokens[0, 0] = FENTokens.mask
tokens[0, 1] = FENTokens.mask

logits = torch.zeros(1, 76, 49)
logits[0, 0, 0] = 0.5
logits[0, 0, 1] = 0.5
logits[0, 1, 0] = 0.7
logits[0, 1, 1] = 0.3

print(tokens)
print(tokens2)
print(logits.shape, tokens.shape, tokens2.shape)
print(model.elbo_loss(0.5, logits, tokens2, tokens))


tokens = fen_tokens[0:1].clone().to(int)
tokens2 = tokens.clone().to(int)
tokens[0, 0] = FENTokens.mask
tokens[0, 1] = FENTokens.mask

logits = torch.zeros(1, 76, 49)
logits[0, 0, 0] = 0.5
logits[0, 0, 10] = -10000
logits[0, 1, 0] = 1000
logits[0, 1, 1] = -1

import torch.nn.functional as F
print(F.softmax(logits, dim=-1))

print(model.elbo_loss(0.5, logits, tokens2, tokens))


tokens = fen_tokens[0:1].clone().to(int)
tokens2 = tokens.clone().to(int)
tokens[0, 0] = FENTokens.mask
tokens[0, 1] = FENTokens.mask

logits = torch.zeros(1, 76, 49)
logits[0, 0, 0] = 100
logits[0, 0, 10] = 0
logits[0, 1, 0] = 1000
logits[0, 1, 1] = -1

print(model.elbo_loss(0.5, logits, tokens2, tokens))


tokens = fen_tokens[0:1].clone().to(int)
tokens2 = tokens.clone().to(int)
tokens[0, 0] = FENTokens.mask
tokens[0, 1] = FENTokens.mask

logits = torch.zeros(1, 76, 49)
logits[0, 0, 0] = 100
logits[0, 0, 10] = 0
logits[0, 10, 0] = 1000
logits[0, 10, 1] = -1

print(model.elbo_loss(0.5, logits, tokens2, tokens))

tokens = fen_tokens[0:1].clone().to(int)
tokens2 = tokens.clone().to(int)
tokens[0, 0] = FENTokens.mask
tokens[0, 1] = FENTokens.mask

logits = torch.zeros(1, 76, 49)
logits[0, 0, 0] = 100
logits[0, 0, 10] = 0

print(model.elbo_loss(0.5, logits, tokens2, tokens))

tensor([[48, 48,  0,  0,  0,  0,  0, 12,  7,  7,  0,  0, 10,  0,  0,  7,  0,  0,
          0,  0,  4,  7,  0,  5,  0,  0,  0,  7,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  2,  0,  1,  0,  0,  9,  0,  1, 11,  1,  0,  0,  0,
          1,  1,  0,  0,  0,  0,  0,  0,  0,  6, 13, 15, 15, 15, 15, 20, 20, 37,
         38, 37, 40, 43]])
tensor([[10,  0,  0,  0,  0,  0,  0, 12,  7,  7,  0,  0, 10,  0,  0,  7,  0,  0,
          0,  0,  4,  7,  0,  5,  0,  0,  0,  7,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  2,  0,  1,  0,  0,  9,  0,  1, 11,  1,  0,  0,  0,
          1,  1,  0,  0,  0,  0,  0,  0,  0,  6, 13, 15, 15, 15, 15, 20, 20, 37,
         38, 37, 40, 43]])
torch.Size([1, 76, 49]) torch.Size([1, 76]) torch.Size([1, 76])
tensor(-14.2744)
tensor([[[0.0339, 0.0206, 0.0206,  ..., 0.0206, 0.0206, 0.0206],
         [1.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0204, 0.0204, 0.0204,  ..., 0.0204, 0.0204, 0.0204],
         ...,
        