In [1]:
import sys
sys.path.append('../')

import torch
from tqdm import tqdm
from lib.service import SamplesService


In [2]:
def decode_int64_bitset(x):
    """
    Convert a 64-bit integer into a 64-element float tensor

    Every integer in ``x`` is expanded along a new trailing axis of length 64.
    Element ``i`` of that axis is 1.0 when bit ``i`` of the integer is set
    (bit 0 = least-significant bit) and 0.0 otherwise.

    Args:
        x: integer tensor (int64 expected) of any shape, on any device.

    Returns:
        float32 tensor of shape ``x.shape + (64,)`` living on the same device
        as ``x``.
    """
    # Create the bit indices on the input's own device instead of a hard-coded
    # 'cuda', so CPU tensors and tensors on a non-default GPU are accepted.
    bit_index = torch.arange(64, dtype=torch.int64, device=x.device)
    # Shift-then-mask never materialises 2**63 (which overflows signed int64).
    # The right shift is arithmetic, so negative inputs (bit 63 set) still
    # yield the correct two's-complement bit after masking with 1.
    bits = (x.unsqueeze(-1) >> bit_index) & 1
    return bits.to(torch.float32)

class ChessModel(torch.nn.Module):
    """
    Fully connected regressor: num_features -> 1024 -> 64 -> 64 -> 1.

    A ReLU follows each of the first three linear layers; the final layer's
    output is returned without an activation.
    """

    def __init__(self, num_features):
        """Create the shared ReLU and the four linear layers.

        Args:
            num_features: size of the last dimension of the input tensor.
        """
        super().__init__()

        # Registration order is kept (activation first, then linear1..linear4)
        # so the printed summary and parameter ordering stay the same.
        self.activation = torch.nn.ReLU()
        self.linear1 = torch.nn.Linear(in_features=num_features, out_features=1024)
        self.linear2 = torch.nn.Linear(in_features=1024, out_features=64)
        self.linear3 = torch.nn.Linear(in_features=64, out_features=64)
        self.linear4 = torch.nn.Linear(in_features=64, out_features=1)

    def forward(self, x):
        """Map a ``(..., num_features)`` tensor to a ``(..., 1)`` tensor."""
        hidden = x
        # Three hidden layers, each followed by the ReLU
        for layer in (self.linear1, self.linear2, self.linear3):
            hidden = self.activation(layer(hidden))
        # Output layer is left linear
        return self.linear4(hidden)

# Build the network for 768 input features and move it to the GPU in one go
# (Module.cuda() returns the module itself), then show its layer summary.
chessmodel = ChessModel(num_features=768).cuda()

print(chessmodel)

ChessModel(
  (activation): ReLU()
  (linear1): Linear(in_features=768, out_features=1024, bias=True)
  (linear2): Linear(in_features=1024, out_features=64, bias=True)
  (linear3): Linear(in_features=64, out_features=64, bias=True)
  (linear4): Linear(in_features=64, out_features=1, bias=True)
)


In [3]:
class PQRLoss(torch.nn.Module):
    """
    Loss over (p, q, r) triples of scalar model outputs.

    The input is viewed as rows of three values ``(p, q, r)`` and the loss is

        mean(-log(sigmoid(r - q))) + mean((p + q) ** 2)

    The first term is evaluated with ``logsigmoid`` so that a strongly
    negative ``r - q`` does not underflow to ``log(0) = -inf``.
    """

    def __init__(self):
        super(PQRLoss, self).__init__()

    def forward(self, inputs):
        """
        Compute the loss for a batch of triples.

        Args:
            inputs: tensor whose element count is a multiple of 3; after
                ``reshape(-1, 3)`` each row is read as ``(p, q, r)``.

        Returns:
            Scalar tensor holding the sum of the two mean terms.
        """
        inputs = inputs.reshape(-1, 3)

        p = inputs[:, 0]
        q = inputs[:, 1]
        r = inputs[:, 2]

        # Numerically stable form of -mean(log(sigmoid(r - q)))
        a = -torch.mean(torch.nn.functional.logsigmoid(r - q))
        # Penalise p + q being away from zero
        b = torch.mean(torch.square(p + q))

        loss = a + b

        return loss

In [4]:
import math
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

# Training schedule: number of passes of the outer loop, and how many batches
# are drawn from the samples service within each pass.
EPOCHS = 1000
BATCHES_PER_EPOCH = 3_000
# Forwarded to SamplesService as batch_size
# NOTE(review): presumably the number of samples per next_batch() call — confirm in lib.service
BATCH_SIZE = 4096

# Source of training batches, the Adam optimizer over all model parameters,
# and the loss applied to the raw model outputs.
samples_service = SamplesService(batch_size=BATCH_SIZE)
optimizer = torch.optim.Adam(chessmodel.parameters(), lr=0.001)
loss_fn = PQRLoss()

# Each run logs to its own TensorBoard directory, runs/<YYYYmmdd_HHMMSS>
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/{}'.format(timestamp))

# @torch.compile
def train_step(batch):
    """
    Run one optimisation step on ``batch`` and return its loss.

    Uses the module-level ``chessmodel``, ``loss_fn`` and ``optimizer``.

    Args:
        batch: input tensor passed straight to ``chessmodel``.

    Returns:
        The loss tensor computed for this batch (before the parameter update).
    """
    # Start from zeroed gradients so nothing carries over from the last step
    optimizer.zero_grad()

    # Forward pass and loss in one expression
    batch_loss = loss_fn(chessmodel(batch))

    # Back-propagate, then let the optimizer apply the update
    batch_loss.backward()
    optimizer.step()

    return batch_loss

# Put the model into training mode before the loop starts
chessmodel.train()

for epoch in range(EPOCHS):
    # Running sum of batch losses; divided into a mean once the epoch is done
    avg_loss = 0.0

    for _ in tqdm(range(BATCHES_PER_EPOCH), desc=f'Epoch {epoch}'):
        # Pull the next packed batch, expand every int64 into 64 floats and
        # flatten the result into rows of 768 features
        batch = samples_service.next_batch()
        batch = decode_int64_bitset(batch).reshape(-1, 768)

        loss = train_step(batch)
        avg_loss += loss.item()

        # Stop immediately once the accumulated loss has become NaN
        if math.isnan(avg_loss):
            raise Exception("Loss is NaN, exiting")

    avg_loss /= BATCHES_PER_EPOCH

    # Per-epoch scalars: mean loss and the current learning rate
    writer.add_scalar('Train/loss', avg_loss, epoch)
    writer.add_scalar('Train/lr', optimizer.param_groups[0]["lr"], epoch)

    # Mean weight of each linear layer, logged in layer order
    for tag, layer in (
        ('Params/mean-l1', chessmodel.linear1),
        ('Params/mean-l2', chessmodel.linear2),
        ('Params/mean-l3', chessmodel.linear3),
        ('Params/mean-l4', chessmodel.linear4),
    ):
        writer.add_scalar(tag, torch.mean(layer.weight), epoch)

    # Full distribution of every parameter tensor
    for name, param in chessmodel.named_parameters():
        writer.add_histogram(name, param, epoch)

    # Push this epoch's events to disk
    writer.flush()


2024-02-28 15:34:16.254562: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-28 15:34:16.254686: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-28 15:34:16.265653: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-02-28 15:34:16.308352: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Epoch 0:   0%|          | 6/3000 [00:00<04:12, 11.87i

tensor(0.6971, device='cuda:0', grad_fn=<AddBackward0>)
0.6970961093902588
tensor(0.6942, device='cuda:0', grad_fn=<AddBackward0>)
0.6942439079284668
tensor(0.6935, device='cuda:0', grad_fn=<AddBackward0>)
0.6935494542121887
tensor(0.6931, device='cuda:0', grad_fn=<AddBackward0>)
0.693077802658081
tensor(0.6932, device='cuda:0', grad_fn=<AddBackward0>)
0.6932233572006226
tensor(0.6932, device='cuda:0', grad_fn=<AddBackward0>)
0.6931798458099365
tensor(0.6929, device='cuda:0', grad_fn=<AddBackward0>)
0.6929364800453186
tensor(0.6928, device='cuda:0', grad_fn=<AddBackward0>)
0.6928325295448303
tensor(0.6929, device='cuda:0', grad_fn=<AddBackward0>)
0.6928582787513733
tensor(0.6928, device='cuda:0', grad_fn=<AddBackward0>)
0.6928172707557678
tensor(0.6926, device='cuda:0', grad_fn=<AddBackward0>)
0.6926230192184448


Epoch 0:   1%|          | 18/3000 [00:00<01:36, 30.82it/s]

tensor(0.6925, device='cuda:0', grad_fn=<AddBackward0>)
0.6924653649330139
tensor(0.6924, device='cuda:0', grad_fn=<AddBackward0>)
0.6924282312393188
tensor(0.6923, device='cuda:0', grad_fn=<AddBackward0>)
0.6923004984855652
tensor(0.6920, device='cuda:0', grad_fn=<AddBackward0>)
0.6920188069343567
tensor(0.6919, device='cuda:0', grad_fn=<AddBackward0>)
0.6919196248054504
tensor(0.6918, device='cuda:0', grad_fn=<AddBackward0>)
0.691828191280365
tensor(0.6916, device='cuda:0', grad_fn=<AddBackward0>)
0.6916108131408691
tensor(0.6913, device='cuda:0', grad_fn=<AddBackward0>)
0.6913023591041565
tensor(0.6910, device='cuda:0', grad_fn=<AddBackward0>)
0.6909537315368652
tensor(0.6906, device='cuda:0', grad_fn=<AddBackward0>)
0.6905672550201416
tensor(0.6902, device='cuda:0', grad_fn=<AddBackward0>)
0.6902071833610535
tensor(0.6894, device='cuda:0', grad_fn=<AddBackward0>)
0.6894310712814331


Epoch 0:   1%|          | 31/3000 [00:01<01:08, 43.39it/s]

tensor(0.6889, device='cuda:0', grad_fn=<AddBackward0>)
0.6888781785964966
tensor(0.6884, device='cuda:0', grad_fn=<AddBackward0>)
0.6884173154830933
tensor(0.6879, device='cuda:0', grad_fn=<AddBackward0>)
0.6878847479820251
tensor(0.6869, device='cuda:0', grad_fn=<AddBackward0>)
0.6869296431541443
tensor(0.6862, device='cuda:0', grad_fn=<AddBackward0>)
0.68623948097229
tensor(0.6858, device='cuda:0', grad_fn=<AddBackward0>)
0.6858070492744446
tensor(0.6853, device='cuda:0', grad_fn=<AddBackward0>)
0.6853053569793701
tensor(0.6846, device='cuda:0', grad_fn=<AddBackward0>)
0.6846150755882263
tensor(0.6833, device='cuda:0', grad_fn=<AddBackward0>)
0.6833410263061523
tensor(0.6815, device='cuda:0', grad_fn=<AddBackward0>)
0.6814746260643005
tensor(0.6814, device='cuda:0', grad_fn=<AddBackward0>)
0.6813890337944031
tensor(0.6800, device='cuda:0', grad_fn=<AddBackward0>)
0.6799975037574768


Epoch 0:   1%|▏         | 43/3000 [00:01<00:58, 50.42it/s]

tensor(0.6815, device='cuda:0', grad_fn=<AddBackward0>)
0.6815463900566101
tensor(0.6818, device='cuda:0', grad_fn=<AddBackward0>)
0.681839108467102
tensor(0.6788, device='cuda:0', grad_fn=<AddBackward0>)
0.6788350343704224
tensor(0.6789, device='cuda:0', grad_fn=<AddBackward0>)
0.6789047122001648
tensor(0.6780, device='cuda:0', grad_fn=<AddBackward0>)
0.677987277507782
tensor(0.6758, device='cuda:0', grad_fn=<AddBackward0>)
0.6757647395133972
tensor(0.6777, device='cuda:0', grad_fn=<AddBackward0>)
0.677666425704956
tensor(0.6759, device='cuda:0', grad_fn=<AddBackward0>)
0.6759375929832458
tensor(0.6751, device='cuda:0', grad_fn=<AddBackward0>)
0.6750903129577637
tensor(0.6732, device='cuda:0', grad_fn=<AddBackward0>)
0.6731860637664795
tensor(0.6738, device='cuda:0', grad_fn=<AddBackward0>)
0.6738492250442505
tensor(0.6733, device='cuda:0', grad_fn=<AddBackward0>)
0.6733296513557434


Epoch 0:   2%|▏         | 52/3000 [00:01<01:23, 35.22it/s]


tensor(0.6723, device='cuda:0', grad_fn=<AddBackward0>)
0.6722923517227173
tensor(0.6722, device='cuda:0', grad_fn=<AddBackward0>)
0.6722132563591003
tensor(0.6725, device='cuda:0', grad_fn=<AddBackward0>)
0.6725433468818665
tensor(0.6702, device='cuda:0', grad_fn=<AddBackward0>)
0.670182466506958
tensor(0.6693, device='cuda:0', grad_fn=<AddBackward0>)
0.6693065762519836


KeyboardInterrupt: 