In [1]:
import sys
sys.path.append('../')

import torch
from tqdm import tqdm
from lib.service import SamplesService


In [2]:
def decode_int64_bitset(x):
    """
    Expand each 64-bit integer in `x` into 64 float flags, one per bit.

    Args:
        x: integer tensor of any shape, on any device. Assumed to hold int64
            values (TODO confirm for other integer dtypes).

    Returns:
        float32 tensor of shape `x.shape + (64,)`, on the same device as `x`.
        Element `[..., i]` is 1.0 when bit `i` (value 2**i, least-significant
        bit first) of the corresponding integer is set, otherwise 0.0.
    """
    # Build the masks on the input's own device so both CPU and GPU inputs work.
    shifts = torch.arange(64, dtype=torch.int64, device=x.device)
    # 1 << i rather than 2 ** i: shifting by 63 yields the int64 sign bit
    # directly, instead of relying on integer overflow inside pow().
    masks = torch.ones_like(shifts) << shifts
    return torch.bitwise_and(x.unsqueeze(-1), masks).ne(0).to(torch.float32)

class ChessModel(torch.nn.Module):
    """Fully-connected network mapping `num_features` inputs to a single output.

    Layer widths: num_features -> 1024 -> 64 -> 64 -> 1. One shared ReLU is
    applied after each of the first three linear layers; the final linear
    layer's output is returned without an activation.
    """

    def __init__(self, num_features):
        super().__init__()
        # Submodules are registered in this exact order (activation first, then
        # linear1..linear4); the printed repr and parameter initialisation
        # order depend on it.
        self.activation = torch.nn.ReLU()
        self.linear1 = torch.nn.Linear(num_features, 1024)
        self.linear2 = torch.nn.Linear(1024, 64)
        self.linear3 = torch.nn.Linear(64, 64)
        self.linear4 = torch.nn.Linear(64, 1)

    def forward(self, x):
        hidden = x
        # linear -> ReLU for every hidden layer ...
        for layer in (self.linear1, self.linear2, self.linear3):
            hidden = self.activation(layer(hidden))
        # ... and a bare linear projection for the output.
        return self.linear4(hidden)

# 768 inputs matches the `reshape(-1, 768)` applied to decoded batches in the
# training loop (presumably 12 int64 bitsets x 64 bits per sample — TODO confirm).
# Module.cuda() moves the parameters in place and returns the module itself.
chessmodel = ChessModel(num_features=768).cuda()

print(chessmodel)

ChessModel(
  (activation): ReLU()
  (linear1): Linear(in_features=768, out_features=1024, bias=True)
  (linear2): Linear(in_features=1024, out_features=64, bias=True)
  (linear3): Linear(in_features=64, out_features=64, bias=True)
  (linear4): Linear(in_features=64, out_features=1, bias=True)
)


In [3]:
class PQRLoss(torch.nn.Module):
    """Loss over model outputs grouped into consecutive (p, q, r) triplets.

    The outputs are flattened to shape (-1, 3): column 0 is p, column 1 is q and
    column 2 is r. The total number of elements must therefore be a multiple
    of 3. NOTE(review): the meaning of p/q/r depends on how SamplesService lays
    out each batch — confirm there.

        loss = mean(-log(sigmoid(r - q))) + mean((p + q) ** 2)
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        """Compute the PQR loss.

        Args:
            inputs: tensor whose elements, read in row-major order, form
                (p, q, r) triplets (e.g. an (N, 1) model output with N % 3 == 0).

        Returns:
            0-dim tensor with the scalar loss.
        """
        triplets = inputs.reshape(-1, 3)
        p = triplets[:, 0]
        q = triplets[:, 1]
        r = triplets[:, 2]

        # logsigmoid(z) == log(sigmoid(z)), but evaluated stably. The naive form
        # underflows sigmoid(z) to 0 for very negative z, giving log(0) = -inf
        # (an infinite loss) and NaN gradients.
        a = -torch.mean(torch.nn.functional.logsigmoid(r - q))
        b = torch.mean(torch.square(p + q))

        return a + b

In [4]:
import math
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

# Training schedule.
EPOCHS = 1_000
BATCHES_PER_EPOCH = 3000
BATCH_SIZE = 4096

# Data source, optimiser and objective.
samples_service = SamplesService(batch_size=BATCH_SIZE)
optimizer = torch.optim.Adam(params=chessmodel.parameters(), lr=1e-3)
loss_fn = PQRLoss()

# One TensorBoard run directory per launch, named after the start time.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter(log_dir='runs/{}'.format(timestamp))

# @torch.compile
def train_step(batch):
    """Run one optimisation step of `chessmodel` on `batch`.

    Reads the module-level `chessmodel`, `loss_fn` and `optimizer`. Gradients
    are reset, the loss of the model's predictions is back-propagated, and
    the optimizer applies one update.

    Returns:
        The loss tensor for this batch, computed before the parameter update.
    """
    optimizer.zero_grad()  # discard gradients from the previous step

    predictions = chessmodel(batch)
    loss = loss_fn(predictions)
    loss.backward()

    optimizer.step()
    return loss

# Make sure gradient tracking is on
chessmodel.train()

try:
    for epoch in range(EPOCHS):
        avg_loss = 0.0

        for step in tqdm(range(BATCHES_PER_EPOCH), desc=f'Epoch {epoch}'):
            batch = samples_service.next_batch()
            batch = decode_int64_bitset(batch)
            batch = batch.reshape(-1, 768)

            loss = train_step(batch)
            loss_value = loss.item()

            # Check each step's loss, not the running sum: a +/-Inf loss is not
            # NaN, so `isnan` on the accumulated sum can miss it entirely.
            if not math.isfinite(loss_value):
                raise RuntimeError(
                    f"Loss is {loss_value} at epoch {epoch}, batch {step}, exiting"
                )
            avg_loss += loss_value

        avg_loss /= BATCHES_PER_EPOCH

        writer.add_scalar('Train/loss', avg_loss, epoch)
        writer.add_scalar('Train/lr', optimizer.param_groups[0]["lr"], epoch)
        # Mean weight of each linear layer, logged as Params/mean-l1 .. mean-l4.
        layers = (chessmodel.linear1, chessmodel.linear2, chessmodel.linear3, chessmodel.linear4)
        for layer_idx, layer in enumerate(layers, start=1):
            writer.add_scalar(f'Params/mean-l{layer_idx}', layer.weight.mean().item(), epoch)
        for name, param in chessmodel.named_parameters():
            writer.add_histogram(name, param, epoch)
        writer.flush()
finally:
    # Release the event file even when training is interrupted
    # (e.g. KeyboardInterrupt) or aborted by the non-finite-loss check.
    writer.close()


2024-02-28 16:43:37.818314: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-28 16:43:37.818363: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-28 16:43:37.819745: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-02-28 16:43:37.827994: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Epoch 0:  31%|███       | 921/3000 [00:15<00:36, 57.6

KeyboardInterrupt: 