In [None]:
import sys
sys.path.append('../')

import torch
from tqdm import tqdm
from lib.service import SamplesService

In [None]:
import numpy as np

def decode_int64_bitset(x):
    """
    Convert 64-bit integers into 64-element float tensors, one element per bit.

    Args:
        x: integer tensor (expected to be int64) of any shape; each element
            is treated as a bitset.

    Returns:
        A float32 tensor with one extra trailing dimension of size 64, on the
        same device as ``x``. Element ``[..., i]`` is 1.0 when bit ``i`` of the
        corresponding input integer is set (bit 0 is the least-significant
        bit) and 0.0 otherwise.
    """
    # Build the single-bit masks with unsigned shifts and reinterpret them as
    # int64, so bit 63 maps onto the sign bit without relying on signed
    # integer overflow of 2 ** 63.
    bit_masks = (np.uint64(1) << np.arange(64, dtype=np.uint64)).view(np.int64)
    # Keep the masks on the input's device rather than hard-coding 'cuda', so
    # CPU tensors (and machines without a GPU) work as well.
    masks = torch.tensor(bit_masks, device=x.device)
    expanded = torch.bitwise_and(x.unsqueeze(-1), masks).ne(0).to(torch.float32)
    return expanded

class ChessModel(torch.nn.Module):
    """
    Fully connected network mapping ``num_features`` inputs to one output.

    Layer widths are num_features -> 256 -> 64 -> 64 -> 1. A ReLU is applied
    after the first linear layer only.
    """

    def __init__(self, num_features):
        super().__init__()

        # Layers are declared in forward order; attribute names double as the
        # state_dict key prefixes.
        self.linear1 = torch.nn.Linear(num_features, 256)
        self.linear2 = torch.nn.Linear(256, 64)
        self.linear3 = torch.nn.Linear(64, 64)
        self.linear4 = torch.nn.Linear(64, 1)
        self.activation = torch.nn.ReLU()

    def forward(self, x):
        """Return the network output for ``x`` (last dimension ``num_features``)."""
        hidden = self.activation(self.linear1(x))
        # NOTE(review): no activation is applied after linear2 or linear3, so
        # linear2 -> linear3 -> linear4 compose into a single affine map.
        # Confirm this is intentional before relying on the extra depth.
        hidden = self.linear3(self.linear2(hidden))
        return self.linear4(hidden)
        
# Build the model for 768 input features and move its parameters to the GPU
# (Module.cuda() returns the module itself).
chessmodel = ChessModel(num_features=768).cuda()

print(chessmodel)

In [None]:
import math
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

class PQRLoss(torch.nn.Module):
    """
    Loss over (p, q, r) score triples.

    The input is reshaped to rows of three values ``(p, q, r)`` and the loss is

        -mean(log(sigmoid(r - q))) + mean((p + q) ** 2)

    NOTE(review): the meaning of p, q and r (which positions they score) is
    not visible in this file — confirm against the batch layout produced by
    the data service.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        """
        Compute the scalar loss.

        Args:
            inputs: tensor whose total element count is a multiple of 3;
                consecutive elements are grouped as (p, q, r).

        Returns:
            A 0-dimensional tensor holding the loss.
        """
        inputs = inputs.reshape(-1, 3)

        p = inputs[:, 0]
        q = inputs[:, 1]
        r = inputs[:, 2]

        # logsigmoid is the numerically stable form of log(sigmoid(.)). The
        # naive composition underflows to log(0) = -inf for large negative
        # margins, which makes the loss infinite and the gradients NaN.
        a = -torch.mean(torch.nn.functional.logsigmoid(r - q))
        b = torch.mean(torch.square(p + q))

        loss = a + b

        return loss

# Training schedule.
EPOCHS = 10
BATCHES_PER_EPOCH = 1000
BATCH_SIZE = 4096
LOSS_EVERY = 20  # number of batches per averaged-loss window

# Data source, loss and optimiser for the model defined above.
samples_service = SamplesService(batch_size=BATCH_SIZE)
loss_fn = PQRLoss()
optimizer = torch.optim.Adam(chessmodel.parameters(), lr=0.01)

# Each run logs to its own TensorBoard directory, named after the start time.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/{}'.format(timestamp))

def train_one_epoch(epoch_index):
    """
    Run BATCHES_PER_EPOCH optimisation steps on ``chessmodel``.

    Args:
        epoch_index: epoch number; used only in the progress-bar label.

    Returns:
        The mean loss over the most recently completed window of LOSS_EVERY
        batches, or 0.0 if no full window completed during the epoch.
    """
    window_total = 0.0
    window_mean = 0.0

    progress = tqdm(range(BATCHES_PER_EPOCH), desc=f'Epoch {epoch_index}')
    for step in progress:
        # Pull the next batch, expand every integer into 64 float bits and
        # flatten to rows of 768 features.
        packed = samples_service.next_batch()
        features = decode_int64_bitset(packed).reshape(-1, 768)

        # Standard step: reset gradients, forward, backward, update.
        optimizer.zero_grad()
        loss = loss_fn(chessmodel(features))
        loss.backward()
        optimizer.step()

        # Accumulate the loss and refresh the reported mean at the end of
        # each LOSS_EVERY-batch window.
        window_total += loss.item()
        if LOSS_EVERY - 1 == step % LOSS_EVERY:
            window_mean = window_total / LOSS_EVERY
            window_total = 0.0

    return window_mean

for epoch in range(EPOCHS):
    # Put the model in training mode before each epoch.
    chessmodel.train()

    avg_loss = train_one_epoch(epoch)

    # A NaN average means training has diverged: stop without logging it.
    if math.isnan(avg_loss):
        print("Loss is NaN, exiting")
        break

    # Record the epoch's loss and the current learning rate of the first
    # parameter group, then force the events to disk.
    current_lr = optimizer.param_groups[0]["lr"]
    writer.add_scalar('Loss/train', avg_loss, epoch)
    writer.add_scalar('LR', current_lr, epoch)
    writer.flush()
