In [199]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from datetime import datetime
from torch.utils.data import Dataset, DataLoader, random_split

In [2]:
# --- Local directory layout (all paths are relative to the notebook) ---------
DATA_DIR = "data/"
STOCKFISH_DIR = "stockfish/"
ARCHIVE_DIR = f"{DATA_DIR}archives/"
BITBOARD_DIR = f"{DATA_DIR}bitboards/"

# --- Remote download roots ---------------------------------------------------
ELITE_DATA_BASE_URL = "https://database.nikonoel.fr/"
STOCKFISH_DOWNSTREAM = "https://github.com/official-stockfish/Stockfish/releases/latest/download/"

# --- Sample game archive: file names, download URL and local locations -------
SAMPLE_ZIP = "lichess_elite_2021-11.zip"
SAMPLE_PGN = "lichess_elite_2021-11.pgn"
SAMPLE_BITBOARD = "elite_bitboard.csv"
ELITE_DATA_SAMPLE_URL = f"{ELITE_DATA_BASE_URL}{SAMPLE_ZIP}"
SAMPLE_ZIP_FILE = f"{ARCHIVE_DIR}{SAMPLE_ZIP}"
SAMPLE_PGN_FILE = f"{DATA_DIR}{SAMPLE_PGN}"
SAMPLE_BITBOARD_FILE = f"{BITBOARD_DIR}{SAMPLE_BITBOARD}"

# --- Stockfish release asset: archive name, download URL, local binary path --
STOCKFISH_AVX512_TAR = "stockfish-ubuntu-x86-64-avx512.tar"
STOCKFISH_AVX512 = "stockfish-ubuntu-x86-64-avx512"
STOCKFISH_AVX512_URL = f"{STOCKFISH_DOWNSTREAM}{STOCKFISH_AVX512_TAR}"
STOCKFISH_AVX512_EXE = f"{STOCKFISH_DIR}{STOCKFISH_AVX512}"

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [10]:
# Scratch check: parse only the two trailing columns (positions 12 and 13) and
# keep the 'draw' one. Note the result is a Series, despite the variable name.
metadata_df = pd.read_csv(
    SAMPLE_BITBOARD_FILE,
    sep=";",
    dtype="int64",
    usecols=range(12, 14),
)["draw"]

In [12]:
# Confirm that the int64 labels convert to 32-bit floats (np.single is np.float32).
metadata_df.to_numpy(dtype=np.float32).dtype

dtype('float32')

In [23]:
class BitboardDrawDataset(Dataset):
    """In-memory dataset pairing a 768-element bit vector with a draw label.

    The source is a ';'-separated CSV with a header row. Columns 0-11 are
    parsed as unsigned 64-bit integers and columns 12-13 as signed 64-bit
    integers; one of the latter must be named ``draw``.

    Attributes:
        bitboards: float32 array of shape (rows, 768). Each row is the twelve
            64-bit words reinterpreted as 96 bytes (native byte order) and
            expanded to individual bits, most significant bit of each byte first.
        is_draw: float32 array of shape (rows,) holding the ``draw`` column.
        length: number of rows.
    """

    def __init__(self, bitboard_file):
        """Load ``bitboard_file`` completely and pre-compute the bit vectors."""
        # The file is parsed twice because the two column groups need different dtypes.
        board_words = pd.read_csv(
            bitboard_file, sep=";", dtype="uint64", usecols=range(12)
        ).to_numpy()
        label_frame = pd.read_csv(
            bitboard_file, sep=";", dtype="int64", usecols=range(12, 14)
        )

        # A byte view needs contiguous memory: 12 words * 8 bytes = 96 bytes per row.
        board_bytes = np.ascontiguousarray(board_words).view(np.uint8)
        self.bitboards = np.unpackbits(board_bytes, axis=1).astype(np.float32)
        self.is_draw = label_frame["draw"].to_numpy(dtype=np.float32)
        self.length = self.is_draw.size

    def __len__(self):
        """Number of samples."""
        return self.length

    def __getitem__(self, idx):
        """Return ``(bit_vector, draw_label)`` for position ``idx``."""
        features = self.bitboards[idx]
        target = self.is_draw[idx]
        return features, target

In [24]:
# Load the whole sample bitboard CSV into memory as a Dataset.
dataset = BitboardDrawDataset(bitboard_file=SAMPLE_BITBOARD_FILE)

In [26]:
# Page-locked (pinned) host memory only helps host-to-GPU copies, so enable it
# only when batches are headed for CUDA. The previous form always passed
# pin_memory=True together with pin_memory_device=device.type, which on the CPU
# fallback asked the loader to pin against "cpu".
dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    pin_memory=(device.type == "cuda"),
)

In [203]:
class SimpleModel(nn.Module):
    """Fully connected binary classifier that emits one raw logit per input row.

    The network is two ``Linear -> BatchNorm1d -> ReLU -> Dropout`` stages
    followed by a final ``Linear`` to a single output. No sigmoid is applied
    here, so pair it with ``nn.BCEWithLogitsLoss`` for training and apply
    ``torch.sigmoid`` to obtain probabilities.

    Args:
        in_features: width of each input row (default 768).
        hidden1: width of the first hidden layer (default 2000).
        hidden2: width of the second hidden layer (default 256).
        dropout: drop probability used after both hidden layers (default 0.5).
    """

    def __init__(self, in_features=768, hidden1=2000, hidden2=256, dropout=0.5):
        super().__init__()
        # The layer order is kept fixed so the state_dict keys ("model.0.weight",
        # ..., "model.8.bias") stay compatible with already-saved checkpoints.
        self.model = nn.Sequential(
            nn.Linear(in_features, hidden1),
            nn.BatchNorm1d(hidden1),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(hidden1, hidden2),
            nn.BatchNorm1d(hidden2),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(hidden2, 1),
        )

    def forward(self, x):
        """Map a ``(batch, in_features)`` float tensor to ``(batch, 1)`` logits."""
        # Call the module itself rather than .forward(): going through __call__
        # is what runs any hooks registered on the inner Sequential.
        return self.model(x)
        

In [204]:
class Train:
    """Training and validation pipeline for ``SimpleModel`` on the draw dataset.

    Construction loads (or receives) a dataset, splits it into train and
    validation parts, and creates the model, an SGD optimizer and a
    ``BCEWithLogitsLoss``. ``train`` then runs epochs, prints the metrics and
    writes a checkpoint whenever the validation loss improves.
    """

    def __init__(self, dataset=None, val_fraction=0.1, batch_size=64):
        """Build datasets, loaders, model, optimizer and loss.

        Args:
            dataset: a sized dataset yielding ``(features, label)`` pairs. When
                omitted, ``BitboardDrawDataset(SAMPLE_BITBOARD_FILE)`` is loaded.
            val_fraction: share of the samples held out for validation. The
                default 0.1 reproduces the former fixed 90000/10000 split on a
                100000-row file.
            batch_size: batch size for both loaders.

        Raises:
            ValueError: if ``val_fraction`` is not strictly between 0 and 1.
        """
        if not 0.0 < val_fraction < 1.0:
            raise ValueError(f"val_fraction must be in (0, 1), got {val_fraction}")
        if dataset is None:
            dataset = BitboardDrawDataset(SAMPLE_BITBOARD_FILE)

        # Derive the split from the actual size; hard-coded lengths make
        # random_split raise for any dataset that is not exactly that large.
        n_val = int(round(len(dataset) * val_fraction))
        n_train = len(dataset) - n_val
        self.train_dataset, self.validate_dataset = random_split(dataset, [n_train, n_val])

        self.device = device
        # Pinned host memory only helps host-to-GPU copies.
        pin = self.device.type == "cuda"
        self.train_dataloader = DataLoader(
            self.train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin
        )
        # Validation only accumulates sums, so shuffling buys nothing.
        self.validate_dataloader = DataLoader(
            self.validate_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin
        )

        self.model = SimpleModel().to(self.device)
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9)
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.total_batches = len(self.train_dataloader)
        # Size (in batches) of the window over which the training loss is averaged.
        self.print_every = 100

    def train_one_epoch(self) -> float:
        """Run one optimisation pass over the training loader.

        Returns:
            The mean batch loss of the most recent reporting window: the last
            complete window of ``print_every`` batches, or the trailing partial
            window when the batch count is not a multiple of ``print_every``.
            This is not the mean over the whole epoch. Returns 0.0 if the
            loader yields no batches.
        """
        running_loss = 0.0
        last_loss = 0.0

        for i, (inputs, labels) in enumerate(self.train_dataloader):
            inputs = inputs.to(self.device)
            # Labels arrive as shape (batch,); the model emits (batch, 1).
            labels = labels.unsqueeze(1).to(self.device)

            self.optimizer.zero_grad()
            loss = self.loss_fn(self.model(inputs), labels)
            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()
            if i % self.print_every == self.print_every - 1:
                # A full window has just completed: report it and start a new one.
                last_loss = running_loss / self.print_every
                running_loss = 0.0
            elif i == self.total_batches - 1:
                # Final batch inside a partial window: average what is there.
                last_loss = running_loss / (i % self.print_every + 1)

        return last_loss

    def _validate(self):
        """Return ``(mean loss per sample, accuracy)`` on the validation split."""
        loss_sum = 0.0
        correct = 0
        seen = 0

        self.model.eval()
        with torch.no_grad():
            for vinputs, vlabels in self.validate_dataloader:
                vinputs = vinputs.to(self.device)
                vlabels = vlabels.unsqueeze(1).to(self.device)
                voutputs = self.model(vinputs)

                batch = vlabels.size(0)
                # loss_fn returns the batch mean; weight it by the batch size so a
                # short final batch does not count as much as a full one.
                loss_sum += self.loss_fn(voutputs, vlabels).item() * batch
                predicted = torch.sigmoid(voutputs).round()
                correct += int((predicted == vlabels).sum().item())
                seen += batch

        if seen == 0:
            # Nothing to validate on: report NaN instead of dividing by zero.
            return float("nan"), float("nan")
        return loss_sum / seen, correct / seen

    def train(self, epochs):
        """Train for ``epochs`` epochs, validating and checkpointing after each.

        After every epoch the training loss (see ``train_one_epoch``), the
        validation loss and the validation accuracy are printed. When the
        validation loss is the lowest seen so far, the model's ``state_dict`` is
        saved to ``model_<timestamp>_<epoch>`` in the working directory.
        """
        best_vloss = float("inf")
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

        for epoch in range(1, epochs + 1):
            print(f'EPOCH {epoch}')

            # Back to training mode: _validate() left the model in eval mode.
            self.model.train(True)
            avg_loss = self.train_one_epoch()

            avg_vloss, acc = self._validate()
            print(f"LOSS train {avg_loss} valid {avg_vloss}")
            print(f"ACC {(acc * 100):.2f}%")

            if avg_vloss < best_vloss:
                best_vloss = avg_vloss
                model_path = f"model_{timestamp}_{epoch}"
                torch.save(self.model.state_dict(), model_path)
        

In [205]:
# Build the training pipeline: loads and splits the sample dataset and creates
# the model, optimizer and loss.
pipe = Train()

In [206]:
# Train for 50 epochs; each epoch prints the train/validation loss and the
# validation accuracy.
pipe.train(epochs=50)

EPOCH 1
LOSS train 0.5031549675124032 valid 0.5284397006034851
ACC 73.54%
EPOCH 2
LOSS train 0.5292552879878453 valid 0.48217472434043884
ACC 76.75%
EPOCH 3
LOSS train 0.4683122677462442 valid 0.44431421160697937
ACC 79.50%
EPOCH 4
LOSS train 0.4317083401339395 valid 0.41948866844177246
ACC 80.18%
EPOCH 5
LOSS train 0.4360984095505306 valid 0.4067915678024292
ACC 81.49%
EPOCH 6
LOSS train 0.48176683698381695 valid 0.38727742433547974
ACC 82.11%
EPOCH 7
LOSS train 0.3506034995828356 valid 0.37940219044685364
ACC 82.64%
EPOCH 8
LOSS train 0.42542765395981924 valid 0.37466031312942505
ACC 83.04%
EPOCH 9
LOSS train 0.3581426015922001 valid 0.36087125539779663
ACC 83.74%
EPOCH 10
LOSS train 0.34757310152053833 valid 0.3533051311969757
ACC 84.23%
EPOCH 11
LOSS train 0.35279917291232515 valid 0.3504669666290283
ACC 84.56%
EPOCH 12
LOSS train 0.40338647791317533 valid 0.35275694727897644
ACC 84.28%
EPOCH 13
LOSS train 0.37180652788707186 valid 0.34645166993141174
ACC 85.06%
EPOCH 14
LOSS train

In [207]:
# Load a saved checkpoint. The file name follows the model_<timestamp>_<epoch>
# pattern that Train.train writes, so this is presumably a state_dict (a mapping
# of parameter names to tensors), not a model; restore it with
# SimpleModel().load_state_dict(m).
# weights_only=True limits unpickling to tensors and plain containers instead of
# arbitrary pickled objects; map_location lets a checkpoint saved on a GPU load
# on a CPU-only machine.
m = torch.load("model_20240503_002000_33", map_location=device, weights_only=True)