In [1]:
from boardGPT.datasets import BoardDataset, collate_fn_board
from boardGPT.games.othello import game_to_board

## Loading auto-encoder

In [17]:
from boardGPT.models import GameAutoEncoder
from transformers import AutoTokenizer
autoencoder, model_config = GameAutoEncoder.from_pretrained(repo_id="theartificialis/Othello-Synthetic-AutoEncoder-20m-S")
autoencoder = autoencoder.to('cuda')
autoencoder.eval()

GameAutoEncoder(
  (encoder): ModuleDict(
    (wte): Embedding(61, 512)
    (wpe): Embedding(60, 512)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-3): 4 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (qk_hook): HookPoint()
          (v_hook): HookPoint()
          (c_attn): Linear(in_features=512, out_features=1536, bias=False)
          (c_proj): Linear(in_features=512, out_features=512, bias=False)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=512, out_features=2048, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=2048, out_features=512, bias=False)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (to_latent_token): Linear(in_features=512, out_features=16, bias=

In [19]:
tokenizer = AutoTokenizer.from_pretrained("theartificialis/Othello-Synthetic-AutoEncoder-20m-S", subfolder="tokenizer")

## Loading dataset

In [4]:
train_dataset = BoardDataset(
    data_dir="../../data/othello/othello-synthetic",
    board_func=game_to_board,
    split="train"
)

In [5]:
import torch

# Create a dataloader
dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=512,
    num_workers=8,
    pin_memory=True,
    shuffle=True,
    drop_last=False,
    collate_fn=lambda b: collate_fn_board(b, tokenizer)
)

In [6]:
batch = next(iter(dataloader))

In [7]:
len(batch)

2

In [8]:
batch[0].shape

torch.Size([512, 60])

In [9]:
batch[1].shape

torch.Size([512, 64])

In [11]:
import torch.nn as nn

class BoardProbe(nn.Module):
    def __init__(self, d_latent):
        super().__init__()
        self.linear = nn.Linear(d_latent, 64 * 3)
    # end def __init__

    def forward(self, x):
        logits = self.linear(x)           # [N, 192]
        logits = logits.view(-1, 64, 3)   # [N, 64, 3]
        return logits
    # end def forward
# end class BoardProbe

In [21]:
probe = BoardProbe(64)
probe = probe.to('cuda')

In [13]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(probe.parameters(), lr=1e-4)

In [23]:
len(dataloader)

39063

In [24]:
for epoch in range(10):
    total_loss = 0
    s_count = 0
    for idx, board in dataloader:
        idx = idx.to('cuda')
        board = board.to('cuda')  # [B, 64]

        with torch.no_grad():
            latent = autoencoder.encode(idx)   # [B, d_latent]
        # end with

        # Predict board state
        logits = probe(latent)    # [B, 64, 3]

        # CE
        loss = criterion(
            logits.view(-1, 3),
            board.view(-1)
        )
        # print(f"Loss: {loss}")
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        s_count += 1
        print(f"{s_count} / {len(dataloader)}")
    # end for
    print(f"Epoch {epoch+1}: loss={total_loss/len(dataloader):.4f}")
# end for

512 / 39063
1024 / 39063
1536 / 39063
2048 / 39063
2560 / 39063
3072 / 39063
3584 / 39063
4096 / 39063
4608 / 39063
5120 / 39063
5632 / 39063
6144 / 39063
6656 / 39063
7168 / 39063
7680 / 39063
8192 / 39063
8704 / 39063
9216 / 39063
9728 / 39063
10240 / 39063
10752 / 39063
11264 / 39063
11776 / 39063
12288 / 39063
12800 / 39063
13312 / 39063
13824 / 39063
14336 / 39063
14848 / 39063
15360 / 39063
15872 / 39063
16384 / 39063
16896 / 39063
17408 / 39063
