In [1]:
import torch
from transformers import AutoTokenizer
from boardGPT.models import GameGPT
from boardGPT.datasets import BoardDataset, collate_fn_board
from boardGPT.games.othello import game_to_board

In [2]:
# Every later cell hard-codes the 'cuda' device, so confirm a GPU is visible first.
has_cuda = torch.cuda.is_available()
print(has_cuda)

True


## Load the pretrained model and tokenizer

In [3]:
REPO_ID = "theartificialis/OthelloGPT-Synthetic-20m"

# Pretrained GameGPT weights, plus the tokenizer stored in the repo's "tokenizer" subfolder.
model, model_config = GameGPT.from_pretrained(repo_id=REPO_ID)
tokenizer = AutoTokenizer.from_pretrained(REPO_ID, subfolder="tokenizer")

# The model is only used to produce activations; eval() returns the module, so as the
# last expression it displays the architecture below.
model = model.to('cuda')
model.eval()

GameGPT(
  (token_emb_hook): HookPoint()
  (pos_emb_hook): HookPoint()
  (pre_logits_hook): HookPoint()
  (transformer): ModuleDict(
    (wte): Embedding(61, 512)
    (wpe): Embedding(60, 512)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-7): 8 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (qk_hook): HookPoint()
          (v_hook): HookPoint()
          (c_attn): Linear(in_features=512, out_features=1536, bias=False)
          (c_proj): Linear(in_features=512, out_features=512, bias=False)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=512, out_features=2048, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=2048, out_features=512, bias=False)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    

In [4]:
# NOTE(review): despite its name, this loads the "train" split, and it is what
# `train_dataloader` iterates over — confirm which split the probes should be fit on.
OTHELLO_DATA_DIR = "../../data/othello/othello-synthetic"
val_dataset = BoardDataset(data_dir=OTHELLO_DATA_DIR, board_func=game_to_board, split="train")

In [5]:
# Create a dataloader: shuffled batches of 512 games, collated with the model's tokenizer.
# persistent_workers keeps the 8 worker processes alive between epochs; without it every
# pass of the multi-epoch probe-training loop tears them down and re-spawns them.
# NOTE(review): a lambda collate_fn cannot be pickled, so this only works with the "fork"
# multiprocessing start method; under "spawn"/"forkserver" use a picklable callable instead.
train_dataloader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=512,
    num_workers=8,
    pin_memory=True,
    shuffle=True,
    drop_last=False,
    persistent_workers=True,
    collate_fn=lambda b: collate_fn_board(b, tokenizer)
)

In [6]:
# Pull a single batch to inspect tensor shapes; the temporary iterator is discarded immediately.
batch = next(iter(train_dataloader))

In [7]:
# Move token ids (LongTensor): [batch_size, seq_len]
batch[0].size()

torch.Size([512, 60])

In [8]:
# Board states: a single flattened 64-square board per game, [batch_size, 64]
batch[1].size()

torch.Size([512, 64])

In [9]:
# Inspection-only forward pass: under no_grad no autograd graph is kept alive for the
# 512x60 batch just to look at activations (the training loop does the same).
move_ids = batch[0].to('cuda')
with torch.no_grad():
    x, logits, loss, residuals = model(move_ids, to_return=['residuals7'])

In [10]:
# Presumably one entry per name passed in `to_return` (only 'residuals7' here) — confirm against GameGPT.forward.
len(residuals)

1

In [11]:
# Layer-7 residual stream; presumably [batch_size, seq_len, d_model] (d_model = 512 per the model repr).
residuals[0].size()

torch.Size([512, 60, 512])

In [12]:
import torch
import torch.nn as nn

class StackedBoardProbes(nn.Module):
    """
    Independent linear probes, one per layer, over stacked residual-stream activations.

    Probe ``i`` reads ``x[:, i]`` and maps every ``d_model`` activation to
    ``n_squares * n_classes`` logits, reshaped to one class distribution per square.

    Input:
        x: [B, n_layers, T, d_model]
    Output:
        logits: [B, n_layers, T, n_squares, n_classes]   (defaults: [B, n_layers, T, 64, 3])
    """
    def __init__(
            self,
            d_model: int,
            n_layers: int = 8,
            n_squares: int = 64,
            n_classes: int = 3
    ):
        """
        Args:
            d_model (int): width of the activations fed to each probe.
            n_layers (int): number of stacked layers expected along dim 1; one probe each.
            n_squares (int): number of board squares predicted per position.
            n_classes (int): number of possible states per square.
        """
        super().__init__()
        self.n_layers = n_layers
        self.n_squares = n_squares
        self.n_classes = n_classes
        self.linears = nn.ModuleList([
            nn.Linear(d_model, n_squares * n_classes) for _ in range(n_layers)
        ])
    # end def __init__

    def forward(
            self,
            x: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            x (Tensor): [B, n_layers, T, d_model]
        Returns:
            logits (Tensor): [B, n_layers, T, n_squares, n_classes]
        Raises:
            ValueError: if ``x`` is not 4-D or its layer dimension differs from ``n_layers``
                (otherwise a shorter stack would silently be probed with the wrong subset).
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected input of shape [B, n_layers, T, d_model], got {tuple(x.shape)}"
            )
        # end if
        B, L, T, _ = x.shape
        if L != self.n_layers:
            raise ValueError(
                f"input stacks {L} layers but the module has {self.n_layers} probes"
            )
        # end if
        outs = [
            probe(x[:, i]).view(B, T, self.n_squares, self.n_classes)   # [B, T, S, C]
            for i, probe in enumerate(self.linears)
        ]
        return torch.stack(outs, dim=1)                                   # [B, L, T, S, C]
    # end def forward

# end class StackedBoardProbes

In [13]:
import os
import warnings

# CUDA_LAUNCH_BLOCKING is only honoured if it is set before CUDA is initialized.
# Cell [3] already moved `model` to 'cuda', so in a top-to-bottom run this assignment
# cannot make kernel launches synchronous. For debugging, set it in the first cell
# (or via `%env CUDA_LAUNCH_BLOCKING=1`) and restart the kernel.
if torch.cuda.is_initialized():
    warnings.warn(
        "CUDA_LAUNCH_BLOCKING is being set after CUDA was initialized and will have no effect; "
        "set it before the first CUDA call and restart the kernel.",
        RuntimeWarning,
    )
# end if
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Create model: one linear probe per transformer block (8 blocks of width 512 in the model above).
probes = StackedBoardProbes(d_model=512, n_layers=8).to('cuda')

In [14]:
# Only the probe weights are optimized; GameGPT's parameters are never handed to Adam.
learning_rate = 1e-3
optimizer = torch.optim.Adam(params=probes.parameters(), lr=learning_rate)
# Per-square classification loss over the probe's class logits (default mean reduction).
criterion = nn.CrossEntropyLoss()

In [15]:
# Train the per-layer probes to read the board state out of GameGPT's residual stream.
# `model` is only run under no_grad, so gradients reach `probes` alone.
n_epochs = 100
residual_names = [f'residuals{i}' for i in range(probes.n_layers)]   # 'residuals0' ... 'residuals7'

for epoch in range(n_epochs):
    probes.train()
    total_loss = 0.0      # sum of per-square losses (weighted by element count, not per batch)
    total_correct = 0
    total_samples = 0

    for move_idx, game_x in train_dataloader:
        # Moves [B, T] and board states, moved to the GPU
        move_idx, game_x = move_idx.to('cuda'), game_x.to('cuda')

        # -------------
        # 1. Residuals from the model (no gradients needed)
        # -------------
        with torch.no_grad():
            _, _, _, residuals = model(idx=move_idx, to_return=residual_names)
        # end with
        stacked = torch.stack(residuals, dim=1)           # [B, L, T, d_model]

        # -------------
        # 2. Forward probes
        # -------------
        preds = probes(stacked)                           # [B, L, T, S, C]
        B, L, T, S, C = preds.shape

        # -------------
        # 3. Targets aligned element-for-element with preds
        # -------------
        # NOTE(review): the dataset currently yields one board per game ([B, 64]), which is
        # broadcast to every position t, so each position is trained on the same board.
        # If per-move boards are intended, have the dataset return [B, T, 64]; both are handled.
        board = game_x.long()
        if board.dim() == 2:
            board = board.unsqueeze(1).expand(B, T, S)    # [B, T, S]
        elif board.dim() != 3:
            raise ValueError(
                f"expected board targets of shape [B, {S}] or [B, {T}, {S}], got {tuple(board.shape)}"
            )
        # end if
        targets = board.unsqueeze(1).expand(B, L, T, S)   # [B, L, T, S]

        flat_preds = preds.reshape(-1, C)                 # [B*L*T*S, C]
        flat_targets = targets.reshape(-1)                # [B*L*T*S]

        # -------------
        # 4. Compute loss
        # -------------
        loss = criterion(flat_preds, flat_targets)

        # -------------
        # 5. Backprop
        # -------------
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # -------------
        # 6. Running loss / accuracy
        # -------------
        n_targets = flat_targets.numel()
        # criterion returns a mean, so scale back to a sum; a smaller final batch
        # (drop_last=False) then gets its proper weight in the epoch average.
        total_loss += loss.item() * n_targets
        with torch.no_grad():
            total_correct += (flat_preds.argmax(dim=-1) == flat_targets).sum().item()
        # end with
        total_samples += n_targets
    # end for

    # Epoch stats (guarded against an empty dataloader)
    epoch_loss = total_loss / max(total_samples, 1)
    epoch_acc = total_correct / max(total_samples, 1)

    print(f"Epoch {epoch+1:02d} | Loss: {epoch_loss:.4f} | Accuracy: {epoch_acc:.4f}")
# end for

In [20]:
# `game_x` is left over from the last iteration of the training loop (hidden kernel state).
first_board = game_x[0]
print(first_board)

tensor([0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 1, 0, 2, 0, 2, 0, 0, 0, 1, 1, 2, 2, 0,
        2, 0, 0, 1, 1, 2, 2, 0, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 0, 2, 1, 0,
        0, 1, 1, 2, 0, 2, 1, 1, 0, 1, 0, 2, 1, 0, 0, 0], device='cuda:0')


In [58]:
# NOTE(review): repeats the CUDA check from cell [2] (stale execution count); safe to delete.
torch.cuda.is_available()

True