In [6]:
import sys
sys.path.append('../')

import torch
from tqdm import tqdm
from lib.service import SamplesService


In [7]:
def decode_int64_bitset(x):
    """
    Expand packed 64-bit integers into their individual bits.

    Every integer element of `x` becomes a vector of 64 floats in which
    entry `i` is 1.0 when bit `i` of the integer is set (least-significant
    bit first) and 0.0 otherwise. Bit 63 is the sign bit, so a negative
    value has a 1.0 in the last position.

    Args:
        x: integer tensor of any shape (the training loop passes int64).

    Returns:
        A float32 tensor of shape `x.shape + (64,)` located on the same
        device as `x`.
    """
    # Create the bit positions on the input's own device instead of a
    # hard-coded 'cuda', so the helper also works for CPU tensors.
    shifts = torch.arange(64, dtype=torch.int64, device=x.device)

    # Shift each bit down to position 0 and mask it out. This avoids building
    # 2 ** 63, which does not fit in a signed 64-bit integer; the arithmetic
    # right shift still exposes the sign bit at position 63.
    bits = (x.unsqueeze(-1) >> shifts) & 1
    return bits.to(torch.float32)

class ChessModel(torch.nn.Module):
    """
    Fully connected network that maps a feature vector to one scalar.

    Layer widths are num_features -> 1024 -> 64 -> 64 -> 1, with a ReLU
    after every layer except the last one.
    """

    def __init__(self, num_features):
        """
        Args:
            num_features: size of the input feature vector.
        """
        super().__init__()

        # Sub-modules are registered in this exact order and under these
        # names; saved state dicts and the logging code refer to them.
        self.activation = torch.nn.ReLU()
        self.linear1 = torch.nn.Linear(num_features, 1024)
        self.linear2 = torch.nn.Linear(1024, 64)
        self.linear3 = torch.nn.Linear(64, 64)
        self.linear4 = torch.nn.Linear(64, 1)

    def forward(self, x):
        """Return the network output for `x`; the last dimension becomes 1."""
        # Hidden layers: linear transform followed by the shared ReLU.
        for hidden in (self.linear1, self.linear2, self.linear3):
            x = self.activation(hidden(x))

        # Output layer is left linear (no activation).
        return self.linear4(x)

In [8]:
class PQRLoss(torch.nn.Module):
    """
    Loss over triplets of model outputs (p, q, r).

    The input is reshaped to rows of three values and the loss is

        -mean(log(sigmoid(r - q))) + mean((p + q) ** 2)

    The first term pushes r above q; the second pushes p + q toward zero.
    NOTE(review): what p, q and r represent (which positions they score) is
    decided by the sample producer and is not visible here - confirm there.
    """

    def __init__(self):
        super(PQRLoss, self).__init__()

    def forward(self, inputs):
        """
        Args:
            inputs: tensor whose element count is a multiple of 3, laid out
                so that consecutive values form one (p, q, r) triplet.

        Returns:
            Scalar tensor holding the loss.
        """
        inputs = inputs.reshape(-1, 3)

        p = inputs[:, 0]
        q = inputs[:, 1]
        r = inputs[:, 2]

        # logsigmoid is the numerically stable form of log(sigmoid(.)).
        # The naive composition underflows to log(0) = -inf when r - q is
        # strongly negative, which gives an infinite loss and NaN gradients.
        a = -torch.mean(torch.nn.functional.logsigmoid(r - q))
        b = torch.mean(torch.square(p + q))

        return a + b

In [9]:
import os
import math
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

# --- Training schedule ---
EPOCHS = 100000
BATCHES_PER_EPOCH = 100
BATCH_SIZE = 4096

# --- Input encoding ---
FEATURE_SET = "basic"
NUM_FEATURES = 768

# Each run gets its own timestamped directory for TensorBoard logs and
# model checkpoints.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
folder = 'runs/{}'.format(timestamp)
os.makedirs(f'{folder}/models', exist_ok=True)

# A batch holds BATCH_SIZE groups of 3 samples; every sample is
# NUM_FEATURES // 64 packed words that are expanded to 64 bits each later on.
samples_service = SamplesService(
    feature_set=FEATURE_SET,
    batch_shape=(BATCH_SIZE, 3, NUM_FEATURES // 64),
)

# Module.cuda() returns the module itself, so construction and the move to
# the GPU can be chained.
chessmodel = ChessModel(num_features=NUM_FEATURES).cuda()

optimizer = torch.optim.Adam(chessmodel.parameters(), lr=0.001)

# Multiply the learning rate by 0.7 once the monitored loss has stopped
# improving (by the 0.001 threshold) for more than 10 scheduler steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.7,
    patience=10,
    threshold=0.001,
)

loss_fn = PQRLoss()
writer = SummaryWriter(folder)

# @torch.compile
def train_step(batch):
    """
    Perform a single optimisation step on `batch`.

    Works on the module-level `chessmodel`, `loss_fn` and `optimizer`.

    Returns:
        The loss tensor computed for this batch (before the update).
    """
    # Start from zeroed gradients so nothing carries over from the last step
    optimizer.zero_grad()

    # Forward pass straight into the loss
    loss = loss_fn(chessmodel(batch))

    # Backpropagate, then let the optimizer apply the update
    loss.backward()
    optimizer.step()

    return loss

# Put the model in training mode. This switches mode-dependent layers such as
# dropout/batch-norm; gradient tracking itself is on by default.
chessmodel.train()

try:
    for epoch in range(EPOCHS):
        avg_loss = 0.0

        for _ in tqdm(range(BATCHES_PER_EPOCH), desc=f'Epoch {epoch}'):
            # Fetch packed samples, expand every 64-bit word into 64 floats
            # and flatten so that each row is one NUM_FEATURES-wide input.
            batch = samples_service.next_batch()
            batch = decode_int64_bitset(batch)
            batch = batch.reshape(-1, NUM_FEATURES)

            loss = train_step(batch)
            avg_loss += loss.item()

            # A NaN never recovers once it is in the running sum, so abort
            # immediately instead of finishing the epoch.
            if math.isnan(avg_loss):
                raise RuntimeError("Loss is NaN, exiting")

        avg_loss /= BATCHES_PER_EPOCH

        # Step the scheduler on the epoch's mean training loss
        scheduler.step(avg_loss)

        writer.add_scalar('Train/loss', avg_loss, epoch)
        # Read the learning rate from the optimizer (public API) rather than
        # the scheduler's private `_last_lr` attribute.
        writer.add_scalar('Train/lr', optimizer.param_groups[0]['lr'], epoch)

        # Mean weight of every linear layer, logged as Params/mean-l1 .. l4
        layers = (
            chessmodel.linear1,
            chessmodel.linear2,
            chessmodel.linear3,
            chessmodel.linear4,
        )
        for index, layer in enumerate(layers, start=1):
            writer.add_scalar(f'Params/mean-l{index}', torch.mean(layer.weight), epoch)

        for name, param in chessmodel.named_parameters():
            writer.add_histogram(name, param, epoch)
        writer.flush()

        # Save a checkpoint of the model after every epoch
        torch.save(chessmodel.state_dict(), f'{folder}/models/{epoch}.pth')
finally:
    # The loop is normally stopped by interrupting the kernel; close the
    # writer so the event file is flushed and released in that case too.
    writer.close()

Epoch 0:   5%|▌         | 5/100 [00:00<00:02, 42.04it/s]

Epoch 0: 100%|██████████| 100/100 [00:01<00:00, 57.66it/s]
Epoch 1: 100%|██████████| 100/100 [00:01<00:00, 60.04it/s]
Epoch 2:  55%|█████▌    | 55/100 [00:00<00:00, 57.82it/s]


KeyboardInterrupt: 

In [None]:
# Pull 1,000,000 batches from the sample service without training on them;
# only the tqdm progress bar is produced as output.
# NOTE(review): presumably a throughput check of `next_batch()` - confirm.
for i in tqdm(range(1000000)):
    # The batch is only bound to `a` and overwritten on the next iteration.
    a = samples_service.next_batch()

  3%|▎         | 25979/1000000 [00:08<05:25, 2992.88it/s]


KeyboardInterrupt: 