In [1]:
import sys
sys.path.append('../')

import torch
from tqdm import tqdm
from lib.service import SamplesService

In [2]:
import numpy as np

def decode_int64_bitset(x):
    """
    Expand every 64-bit integer in ``x`` into 64 float flags, one per bit.

    A trailing dimension of size 64 is appended to the input. Element ``k`` of
    that dimension is 1.0 when bit ``k`` of the integer is set and 0.0
    otherwise, with ``k = 0`` being the least significant bit and ``k = 63``
    the int64 sign bit.

    Args:
        x: integer tensor of any shape, on any device.

    Returns:
        float32 tensor of shape ``x.shape + (64,)`` on the same device as ``x``.
    """
    # Build the masks on the input's own device instead of assuming CUDA, so
    # the function also works for CPU tensors and on machines without a GPU.
    # 2**63 wraps around to the int64 minimum, whose only set bit is bit 63,
    # so the sign bit is still decoded correctly.
    masks = torch.tensor(2 ** np.arange(64, dtype=np.int64), device=x.device)
    # Broadcasting (..., 1) against (64,) tests every bit of every integer.
    expanded = torch.bitwise_and(x.unsqueeze(-1), masks).ne(0).to(torch.float32)
    return expanded

class ChessModel(torch.nn.Module):
    """Fully connected network that maps a feature vector to a single scalar.

    Layer widths are ``num_features -> 1024 -> 64 -> 64 -> 1``. A ReLU follows
    every linear layer except the last, so the output is an unbounded score.
    """

    def __init__(self, num_features):
        """Create the layers; ``num_features`` is the width of the input vector."""
        super().__init__()

        nn = torch.nn
        # Registration order is kept (activation first, then linear1..linear4)
        # so state-dict keys and the printed module summary stay the same.
        self.activation = nn.ReLU()
        self.linear1 = nn.Linear(num_features, 1024)
        self.linear2 = nn.Linear(1024, 64)
        self.linear3 = nn.Linear(64, 64)
        self.linear4 = nn.Linear(64, 1)

    def forward(self, x):
        """Return a ``(..., 1)`` tensor of scores for inputs of shape ``(..., num_features)``."""
        hidden = x
        for layer in (self.linear1, self.linear2, self.linear3):
            hidden = self.activation(layer(hidden))
        # The output layer is left linear (no activation).
        return self.linear4(hidden)

# 768 input features: the training loop reshapes the decoded bitsets to
# (-1, 768), i.e. 12 int64 bitsets x 64 bits per sample.
# .cuda() moves the parameters to the GPU and returns the module itself.
chessmodel = ChessModel(num_features=768).cuda()

# Show the layer summary as a quick sanity check of the architecture.
print(chessmodel)

ChessModel(
  (activation): ReLU()
  (linear1): Linear(in_features=768, out_features=1024, bias=True)
  (linear2): Linear(in_features=1024, out_features=64, bias=True)
  (linear3): Linear(in_features=64, out_features=64, bias=True)
  (linear4): Linear(in_features=64, out_features=1, bias=True)
)


In [4]:
import math
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

class PQRLoss(torch.nn.Module):
    """Label-free loss over consecutive (p, q, r) triples of model outputs.

    The input is reshaped to ``(-1, 3)``; the three columns are called p, q
    and r. The loss is the sum of two batch means:

    * ``-log(sigmoid(r - q))`` -- small when r is scored well above q;
    * ``(p + q) ** 2``         -- small when q is close to ``-p``.

    NOTE(review): the p/q/r naming presumably follows a parent position /
    played move / random move triple -- confirm against the sample layout
    produced by SamplesService.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        """Compute the scalar loss.

        Args:
            inputs: tensor whose total number of elements is a multiple of 3,
                laid out so that every three consecutive values form one
                (p, q, r) triple.

        Returns:
            0-dim tensor holding the loss.
        """
        triples = inputs.reshape(-1, 3)
        p, q, r = triples.unbind(dim=1)

        # logsigmoid is used instead of log(sigmoid(.)): for a large negative
        # (r - q) the sigmoid underflows to 0 and its log becomes -inf, which
        # turns the loss and its gradients into inf/NaN.
        ranking_term = -torch.mean(torch.nn.functional.logsigmoid(r - q))
        consistency_term = torch.mean(torch.square(p + q))

        return ranking_term + consistency_term

# --- Training configuration -------------------------------------------------
EPOCHS = 1_000
BATCHES_PER_EPOCH = 3_000  # optimisation steps per epoch
BATCH_SIZE = 4_096         # batch size requested from the samples service

# --- Training objects ---------------------------------------------------------
samples_service = SamplesService(batch_size=BATCH_SIZE)
optimizer = torch.optim.Adam(params=chessmodel.parameters(), lr=1e-5)
loss_fn = PQRLoss()

# Every run logs to its own timestamped TensorBoard directory under runs/.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter(log_dir='runs/{}'.format(timestamp))

def train_one_epoch(epoch_index):
    """Run ``BATCHES_PER_EPOCH`` optimisation steps and return the mean loss.

    Works on the module-level ``samples_service``, ``chessmodel``,
    ``optimizer`` and ``loss_fn``.

    Args:
        epoch_index: epoch number; only used to label the progress bar.

    Returns:
        float: mean of the per-batch losses over the epoch. A NaN in any
        batch loss propagates into this value.
    """
    running_loss = 0.0

    for _ in tqdm(range(BATCHES_PER_EPOCH), desc=f'Epoch {epoch_index}'):
        # Each int64 from the service expands to 64 float flags, which are
        # then flattened into rows of 768 features for the model.
        batch = decode_int64_bitset(samples_service.next_batch())
        batch = batch.reshape(-1, 768)

        # Standard step: reset gradients, forward, backward, update.
        optimizer.zero_grad()
        outputs = chessmodel(batch)
        # The loss takes no targets; it reshapes the outputs to (-1, 3) itself.
        loss = loss_fn(outputs)
        loss.backward()
        optimizer.step()

        # .item() converts the 0-dim loss tensor to a Python float.
        running_loss += loss.item()

    return running_loss / BATCHES_PER_EPOCH

for epoch in range(EPOCHS):
    # Put the model in training mode. The model only has Linear and ReLU
    # layers, so this does not change its outputs today; it keeps the loop
    # correct if dropout or batch-norm layers are added later.
    chessmodel.train()

    avg_loss = train_one_epoch(epoch)

    # A NaN epoch loss means training has diverged; stop instead of logging it.
    if math.isnan(avg_loss):
        print("Loss is NaN, exiting")
        break

    # Scalars: epoch loss and the current learning rate.
    current_lr = optimizer.param_groups[0]["lr"]
    writer.add_scalar('Train/loss', avg_loss, epoch)
    writer.add_scalar('Train/lr', current_lr, epoch)

    # Mean weight of each linear layer, as a cheap drift indicator.
    tracked_layers = (
        ('Params/mean-l1', chessmodel.linear1),
        ('Params/mean-l2', chessmodel.linear2),
        ('Params/mean-l3', chessmodel.linear3),
        ('Params/mean-l4', chessmodel.linear4),
    )
    for tag, layer in tracked_layers:
        writer.add_scalar(tag, torch.mean(layer.weight), epoch)

    # Full distribution of every parameter tensor.
    for name, param in chessmodel.named_parameters():
        writer.add_histogram(name, param, epoch)

    # Flush so the epoch shows up in TensorBoard even if the run is interrupted.
    writer.flush()


Samples service cleanup


Epoch 0:   0%|          | 0/3000 [00:00<?, ?it/s]

Epoch 0: 100%|██████████| 3000/3000 [00:57<00:00, 52.09it/s]
Epoch 1: 100%|██████████| 3000/3000 [01:00<00:00, 49.61it/s]
Epoch 2: 100%|██████████| 3000/3000 [01:49<00:00, 27.42it/s]
Epoch 3: 100%|██████████| 3000/3000 [01:58<00:00, 25.22it/s]
Epoch 4:  63%|██████▎   | 1888/3000 [01:08<00:40, 27.67it/s]


KeyboardInterrupt: 