In [None]:
from chessf.parser import FilePGN
from chessf.engine import Stockfish
from chessf.convert import eval_to_white_win_p, win_p_to_one_hot_bin, elo_to_one_hot, elo_diff_to_one_hot, move_number_to_one_hot

import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
# Seed every RNG this notebook draws from: torch (weight initialisation) and
# NumPy's global generator (get_data picks the analysed move with
# np.random.randint). Seeding torch alone left the sampling non-reproducible.
torch.manual_seed(0)
np.random.seed(0)

import plotly.express as px
from time import perf_counter

## Prepare data

In [None]:
# Path (relative to the notebook's working directory) of the PGN dump that
# all samples are read from.
pgn_2017_02 = "pgn/lichess_db_standard_rated_2017-02.pgn"
# Shared reader used by get_data(); presumably it keeps a cursor into the
# file so successive calls yield successive games -- TODO confirm in chessf.parser.
file = FilePGN(pgn_2017_02)

In [None]:
# Location of the engine binary, relative to the working directory.
# NOTE(review): the ".exe" suffix ties this cell to Windows.
stockfish_path = "stockfish/stockfish.exe"
# Single engine wrapper shared by every get_data() call.
stockfish = Stockfish(stockfish_path)

In [None]:
def get_data(n_games=10):
    """Build one training sample from each of ``n_games`` games.

    Games are pulled from the module-level ``file`` reader. For every game
    with at least 21 entries in its move list, one move index is drawn
    uniformly from ``[20, len(game_moves) - 1]`` (1-based), the game is
    replayed on the module-level ``stockfish`` engine up to that move, and
    the position reached just before it is turned into a flat feature vector.

    Args:
        n_games: number of samples (one per accepted game) to collect.

    Returns:
        ``(X, y)`` as float tensors. ``X`` has one row per sample; ``y`` has
        shape ``(n_games, 1)`` and holds 1.0 when the game result is "1-0"
        and 0.0 otherwise (draws and black wins share the 0 label).
    """
    first_candidate_move = 2 * 10  # never sample before this (1-based) move index

    XX = []
    YY = []

    n_found = 0
    while n_found < n_games:

        game_moves, info = file.get_and_parse_next_good_game()
        # Too short to contain a candidate move; this also keeps the
        # randint range below non-empty.
        if len(game_moves) < first_candidate_move + 1:
            continue

        stockfish.start_new_game()

        target = int(info["Result"] == "1-0")

        elo_white = elo_to_one_hot(info['WhiteElo'])
        elo_black = elo_to_one_hot(info['BlackElo'])
        elo_diff = elo_diff_to_one_hot(info['WhiteElo'] - info['BlackElo'])

        # Upper bound is exclusive, so the last move can never be chosen.
        move_to_analyze = np.random.randint(first_candidate_move, len(game_moves))

        for move_number, game_move in enumerate(game_moves, start=1):
            if move_number != move_to_analyze:
                # Not the sampled move yet: just advance the board.
                stockfish.make_pgn_move(game_move)
                continue

            # --- features of the position before the sampled move is played ---
            bbm_flat = stockfish.get_bitboard_matrix().flatten()
            where_can_move_flat = stockfish.get_where_can_move().flatten()
            pseudolegal_moves_flat = stockfish.get_pseudolegal_moves().flatten()

            win_p = eval_to_white_win_p(*stockfish.get_eval(depth=5))
            win_p_ohe = win_p_to_one_hot_bin(win_p)

            move_number_ohe = move_number_to_one_hot(move_number)

            if stockfish.side_to_move == 'w':
                side_to_move_arr = np.array([1, 0])
            else:
                side_to_move_arr = np.array([0, 1])

            # Play the sampled move so the engine can be queried again with
            # the other side to move (presumably giving the opponent's
            # mobility map -- TODO confirm against chessf.engine).
            stockfish.make_pgn_move(game_move)

            # Fix: flatten the map fetched *after* the move. The previous code
            # re-flattened the pre-move map, so this feature was simply twice
            # the mover's map and the second query was never used.
            where_can_move_opponent_flat = stockfish.get_where_can_move().flatten()
            where_can_move_all_flat = where_can_move_flat + where_can_move_opponent_flat

            # np.concatenate instead of np.concat: the latter is only
            # available as an alias from NumPy 2.0 onwards.
            features = np.concatenate([
                bbm_flat, where_can_move_all_flat, pseudolegal_moves_flat, side_to_move_arr,
                elo_white, elo_black, elo_diff, move_number_ohe, win_p_ohe
            ])

            XX.append(features)
            YY.append(target)
            n_found += 1

            # The sample is complete; replaying the rest of the game would
            # only cost engine calls, and start_new_game() is called for the
            # next game anyway.
            break

    XX = np.array(XX)
    YY = np.array(YY)

    return (
        torch.Tensor(XX),
        torch.Tensor(YY).view(-1, 1),
    )

In [None]:
%%time
# Fixed evaluation set, drawn once before training starts and reused
# unchanged for every evaluation pass in the training loop.
test_data = get_data(n_games=2000)

In [None]:
# Width of one feature vector. Read straight from the data: the N_FEATURES
# constant is only assigned in a later cell, so referencing it here raised
# NameError on a fresh kernel.
test_data[0].shape[1]

## NN 1

In [None]:
# Hyperparameters read by SimpleFC at construction time.
N_CLASSES = 1                       # width of the output layer (one logit)
N_FEATURES = test_data[0].shape[1]  # input width, taken from the sampled data
INIT_STD = 0.01                     # std of the zero-mean normal used to initialise Linear weights
N_HIDDEN_1 = 128                    # width of the first hidden layer
N_HIDDEN_2 = 128                    # width of the second hidden layer
# N_HIDDEN_3 = 16

class SimpleFC(nn.Module):
    """Fully connected network with two hidden layers that returns raw logits.

    Each hidden stage is ``Linear -> ReLU -> BatchNorm1d``; the output stage
    is a bare ``Linear`` with no activation, so the result must be passed
    through a sigmoid (or a logit-based loss) by the caller.

    All ``Linear`` weights are re-initialised from a zero-mean normal with
    standard deviation ``init_std``; biases keep the ``nn.Linear`` default.

    Args:
        n_features: input width. Defaults to the module-level ``N_FEATURES``.
        n_hidden_1: first hidden width. Defaults to ``N_HIDDEN_1``.
        n_hidden_2: second hidden width. Defaults to ``N_HIDDEN_2``.
        n_classes: output width. Defaults to ``N_CLASSES``.
        init_std: weight-init standard deviation. Defaults to ``INIT_STD``.

    Calling ``SimpleFC()`` with no arguments builds exactly the network the
    module-level constants describe.
    """

    def __init__(self, n_features=None, n_hidden_1=None, n_hidden_2=None,
                 n_classes=None, init_std=None):
        super().__init__()

        # Fall back to the notebook-level constants only for arguments that
        # were not supplied, so the class can also be built without them.
        n_features = N_FEATURES if n_features is None else n_features
        n_hidden_1 = N_HIDDEN_1 if n_hidden_1 is None else n_hidden_1
        n_hidden_2 = N_HIDDEN_2 if n_hidden_2 is None else n_hidden_2
        n_classes = N_CLASSES if n_classes is None else n_classes
        init_std = INIT_STD if init_std is None else init_std

        # Layers are created and initialised in a fixed order so that, for a
        # given torch seed, the drawn weights do not change.
        self.layer_1 = nn.Linear(n_features, n_hidden_1)
        self.norm_1 = nn.BatchNorm1d(n_hidden_1)
        nn.init.normal_(self.layer_1.weight, 0, init_std)

        self.layer_2 = nn.Linear(n_hidden_1, n_hidden_2)
        self.norm_2 = nn.BatchNorm1d(n_hidden_2)
        nn.init.normal_(self.layer_2.weight, 0, init_std)

        self.layer_f = nn.Linear(n_hidden_2, n_classes)
        nn.init.normal_(self.layer_f.weight, 0, init_std)

    def forward(self, x):
        """Map a ``(batch, n_features)`` tensor to ``(batch, n_classes)`` logits."""
        # Batch norm is applied after the activation in both hidden stages.
        x = self.norm_1(F.relu(self.layer_1(x)))
        x = self.norm_2(F.relu(self.layer_2(x)))
        return self.layer_f(x)

In [None]:
# Fresh, untrained network. Re-running this cell discards any learned
# weights; the optimizer cell must then be re-run as well, because the
# optimizer holds references to this instance's parameters.
model = SimpleFC()

In [None]:
# Binary cross-entropy on raw logits: the loss applies the sigmoid itself,
# which matches SimpleFC.forward returning un-activated outputs.
criterion = nn.BCEWithLogitsLoss()
# Adam over all model parameters with a fixed learning rate.
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# Per-epoch history, appended to by the training loop and read by the plot
# cells. Re-running this cell clears the history.
train_losses, test_losses = [], []    # loss on the training batch / on the test set
test_accuracy, roc_scores = [], []    # test-set accuracy / test-set ROC AUC

## Training

In [None]:
for epoch in range(1000):
    # <train>
    
    model.train()
    running_loss = 0.0

    train_data = get_data(n_games=2000)
    
    inputs, labels = train_data
    
    optimizer.zero_grad()

    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    
    running_loss += loss.item()

    # print(f"Running loss: {current_mean_loss:.6f}")

    mean_loss = running_loss / 1
    train_losses.append(mean_loss)

    # <test>
    
    model.eval()
    correct, total = 0, 0
    running_loss = 0.0

    with torch.no_grad():
            
        images, labels = test_data
        
        outputs = model(images)
        loss = criterion(outputs, labels)
        running_loss += loss.item()
        
        
        predicted_probability = F.sigmoid(outputs)
        roc_score = roc_auc_score(labels[:, 0], predicted_probability[:, 0])
        roc_scores.append(roc_score)
        
        total += labels.size(0)
        correct += (F.sigmoid(outputs).round() == labels).sum()

    mean_loss_test = running_loss / 1
    test_losses.append(mean_loss_test)

    accuracy = correct / total
    test_accuracy.append(accuracy)

    print(f"Epoch: {str(epoch+1).zfill(4)}, Loss: {mean_loss:.6f}, Test: {mean_loss_test:.6f}, Acc: {accuracy:.6f} Roc: {roc_score:.6f}")

In [None]:
# # train_losses

# [0.691733181476593,
#  0.6694008708000183,
#  0.6458799839019775,
#  0.6277754306793213,
#  0.6094498038291931,
#  0.594309389591217,
#  0.5838368535041809,
#  0.5799633860588074,
#  0.566249668598175,
#  0.5617096424102783,
#  0.5487838387489319,
#  0.5591247081756592,
#  0.5485094785690308,
#  0.5402898192405701,
#  0.541603147983551,
#  0.548852801322937,
#  0.5364051461219788,
#  0.5117275714874268,
#  0.5235812664031982,
#  0.5227475762367249,
#  0.5251681804656982,
#  0.5328614115715027,
#  0.5318834781646729,
#  0.5135568380355835,
#  0.5328724980354309,
#  0.5133646726608276,
#  0.5374367833137512,
#  0.5200607776641846,
#  0.5306217074394226,
#  0.513527512550354,
#  0.5085349678993225,
#  0.5305171608924866,
#  0.5086538195610046,
#  0.5157380700111389,
#  0.5052306056022644,
#  0.49326279759407043,
#  0.52902752161026,
#  0.4963837265968323,
#  0.5028964281082153,
#  0.5004900693893433,
#  0.49111098051071167,
#  0.47132009267807007,
#  0.5001329779624939,
#  0.48526984453201294,
#  0.4964050054550171,
#  0.4808172285556793,
#  0.4823331832885742,
#  0.5014909505844116,
#  0.4874553382396698,
#  0.46981778740882874,
#  0.5110610127449036,
#  0.5105932950973511,
#  0.4978426396846771,
#  0.47142738103866577,
#  0.4974954426288605,
#  0.49606066942214966,
#  0.48007693886756897,
#  0.485073983669281,
#  0.4954683780670166,
#  0.5035555958747864,
#  0.503817617893219,
#  0.5108762383460999,
#  0.4931943714618683,
#  0.5026611089706421,
#  0.4729540944099426,
#  0.47707122564315796,
#  0.4784761369228363,
#  0.5099768042564392,
#  0.4695850908756256,
#  0.4732595682144165,
#  0.4851847290992737,
#  0.46555960178375244,
#  0.5025174021720886,
#  0.48600733280181885,
#  0.49645334482192993,
#  0.48143982887268066,
#  0.47670796513557434,
#  0.48489803075790405,
#  0.4976862668991089,
#  0.4905087947845459,
#  0.4746129512786865,
#  0.4899434745311737,
#  0.5033617615699768,
#  0.48479366302490234,
#  0.49080052971839905,
#  0.4928801953792572,
#  0.5123609900474548,
#  0.47152674198150635,
#  0.4908825755119324,
#  0.49158453941345215,
#  0.4948044419288635,
#  0.48049044609069824,
#  0.4923814535140991,
#  0.4626178443431854,
#  0.47783464193344116,
#  0.4835050106048584,
#  0.44817712903022766,
#  0.4918004274368286,
#  0.4877481758594513,
#  0.4846593141555786,
#  0.4911070466041565,
#  0.4912460744380951,
#  0.48282578587532043,
#  0.48970648646354675,
#  0.4707571268081665,
#  0.4851202964782715,
#  0.49378326535224915,
#  0.473067045211792,
#  0.4741015136241913,
#  0.48202401399612427,
#  0.4981482923030853,
#  0.4751274287700653,
#  0.46983882784843445,
#  0.4764077663421631,
#  0.49840784072875977,
#  0.46911418437957764,
#  0.4542337954044342,
#  0.4591987729072571,
#  0.47105520963668823,
#  0.4665891230106354,
#  0.48890095949172974,
#  0.4552322030067444,
#  0.4672274589538574,
#  0.48117613792419434,
#  0.49081549048423767]

In [None]:
# # test_losses

# [0.692962646484375,
#  0.692556619644165,
#  0.6918733716011047,
#  0.6910200119018555,
#  0.6890814304351807,
#  0.6858342289924622,
#  0.6817464232444763,
#  0.6769980788230896,
#  0.6721311807632446,
#  0.6672349572181702,
#  0.6608647704124451,
#  0.6532121896743774,
#  0.643156111240387,
#  0.6339752078056335,
#  0.6261454820632935,
#  0.6182607412338257,
#  0.6115203499794006,
#  0.6022678017616272,
#  0.591109573841095,
#  0.5800097584724426,
#  0.5714279413223267,
#  0.5649736523628235,
#  0.5600365996360779,
#  0.5533901453018188,
#  0.5437690019607544,
#  0.533027172088623,
#  0.5256180763244629,
#  0.521976113319397,
#  0.5184096693992615,
#  0.515976071357727,
#  0.5152158141136169,
#  0.51487797498703,
#  0.5108316540718079,
#  0.5040802955627441,
#  0.498773992061615,
#  0.4966137707233429,
#  0.49769532680511475,
#  0.4954218864440918,
#  0.49312666058540344,
#  0.4906705319881439,
#  0.48722192645072937,
#  0.48189231753349304,
#  0.4770474135875702,
#  0.47595593333244324,
#  0.4745844006538391,
#  0.47580546140670776,
#  0.48288729786872864,
#  0.48800063133239746,
#  0.4884986877441406,
#  0.48083609342575073,
#  0.47277745604515076,
#  0.4709300398826599,
#  0.47313910722732544,
#  0.47504952549934387,
#  0.47759824991226196,
#  0.48128506541252136,
#  0.483161985874176,
#  0.4787048101425171,
#  0.4728732705116272,
#  0.4685536026954651,
#  0.46931469440460205,
#  0.4715568721294403,
#  0.4684484302997589,
#  0.46315428614616394,
#  0.4632052183151245,
#  0.4658080041408539,
#  0.46737828850746155,
#  0.4645152986049652,
#  0.4601116478443146,
#  0.46221643686294556,
#  0.4670182466506958,
#  0.4616984724998474,
#  0.4539903700351715,
#  0.470284104347229,
#  0.47853830456733704,
#  0.46352460980415344,
#  0.4531189799308777,
#  0.4667767882347107,
#  0.46849575638771057,
#  0.45553529262542725,
#  0.45461153984069824,
#  0.46357208490371704,
#  0.4683086574077606,
#  0.46762514114379883,
#  0.4596749246120453,
#  0.4540982246398926,
#  0.45919543504714966,
#  0.46410322189331055,
#  0.45779234170913696,
#  0.4522702097892761,
#  0.4512341320514679,
#  0.45230740308761597,
#  0.45103347301483154,
#  0.4492197632789612,
#  0.4479631781578064,
#  0.4484831690788269,
#  0.45041683316230774,
#  0.45094189047813416,
#  0.44776666164398193,
#  0.44591447710990906,
#  0.4503496587276459,
#  0.4661242663860321,
#  0.47063469886779785,
#  0.45724597573280334,
#  0.4482657313346863,
#  0.44570526480674744,
#  0.4520362615585327,
#  0.4580342471599579,
#  0.46195128560066223,
#  0.4589138925075531,
#  0.45725998282432556,
#  0.46178093552589417,
#  0.4650282561779022,
#  0.4616328179836273,
#  0.4514215290546417,
#  0.4486832916736603,
#  0.4482627213001251,
#  0.44792407751083374,
#  0.44766032695770264,
#  0.44700464606285095,
#  0.4475146532058716,
#  0.448143869638443,
#  0.45045432448387146,
#  0.45117247104644775,
#  0.4515133500099182]

In [None]:
# # test_accuracy

# [tensor(0.4945),
#  tensor(0.4945),
#  tensor(0.5065),
#  tensor(0.5305),
#  tensor(0.5720),
#  tensor(0.6675),
#  tensor(0.7010),
#  tensor(0.7065),
#  tensor(0.7115),
#  tensor(0.7165),
#  tensor(0.7235),
#  tensor(0.7315),
#  tensor(0.7095),
#  tensor(0.7030),
#  tensor(0.7140),
#  tensor(0.7305),
#  tensor(0.7430),
#  tensor(0.7450),
#  tensor(0.7440),
#  tensor(0.7370),
#  tensor(0.7250),
#  tensor(0.7065),
#  tensor(0.7050),
#  tensor(0.7130),
#  tensor(0.7280),
#  tensor(0.7505),
#  tensor(0.7545),
#  tensor(0.7575),
#  tensor(0.7630),
#  tensor(0.7595),
#  tensor(0.7570),
#  tensor(0.7530),
#  tensor(0.7545),
#  tensor(0.7620),
#  tensor(0.7645),
#  tensor(0.7545),
#  tensor(0.7510),
#  tensor(0.7500),
#  tensor(0.7540),
#  tensor(0.7585),
#  tensor(0.7650),
#  tensor(0.7610),
#  tensor(0.7585),
#  tensor(0.7635),
#  tensor(0.7620),
#  tensor(0.7645),
#  tensor(0.7605),
#  tensor(0.7600),
#  tensor(0.7540),
#  tensor(0.7645),
#  tensor(0.7710),
#  tensor(0.7715),
#  tensor(0.7685),
#  tensor(0.7640),
#  tensor(0.7685),
#  tensor(0.7680),
#  tensor(0.7680),
#  tensor(0.7710),
#  tensor(0.7690),
#  tensor(0.7775),
#  tensor(0.7745),
#  tensor(0.7690),
#  tensor(0.7705),
#  tensor(0.7810),
#  tensor(0.7760),
#  tensor(0.7690),
#  tensor(0.7690),
#  tensor(0.7705),
#  tensor(0.7740),
#  tensor(0.7715),
#  tensor(0.7665),
#  tensor(0.7710),
#  tensor(0.7820),
#  tensor(0.7680),
#  tensor(0.7590),
#  tensor(0.7755),
#  tensor(0.7865),
#  tensor(0.7725),
#  tensor(0.7710),
#  tensor(0.7880),
#  tensor(0.7800),
#  tensor(0.7770),
#  tensor(0.7685),
#  tensor(0.7675),
#  tensor(0.7825),
#  tensor(0.7835),
#  tensor(0.7765),
#  tensor(0.7710),
#  tensor(0.7805),
#  tensor(0.7820),
#  tensor(0.7840),
#  tensor(0.7865),
#  tensor(0.7840),
#  tensor(0.7875),
#  tensor(0.7895),
#  tensor(0.7875),
#  tensor(0.7880),
#  tensor(0.7895),
#  tensor(0.7875),
#  tensor(0.7880),
#  tensor(0.7855),
#  tensor(0.7780),
#  tensor(0.7790),
#  tensor(0.7815),
#  tensor(0.7865),
#  tensor(0.7905),
#  tensor(0.7850),
#  tensor(0.7760),
#  tensor(0.7775),
#  tensor(0.7760),
#  tensor(0.7760),
#  tensor(0.7765),
#  tensor(0.7750),
#  tensor(0.7775),
#  tensor(0.7775),
#  tensor(0.7860),
#  tensor(0.7885),
#  tensor(0.7870),
#  tensor(0.7835),
#  tensor(0.7815),
#  tensor(0.7815),
#  tensor(0.7805),
#  tensor(0.7800),
#  tensor(0.7770),
#  tensor(0.7760)]

In [None]:
# # roc_scores

# [np.float64(0.6194809571958206),
#  np.float64(0.6956976794192098),
#  np.float64(0.7292482390369235),
#  np.float64(0.7091478068846331),
#  np.float64(0.7553974030857734),
#  np.float64(0.7786297141954177),
#  np.float64(0.7867751997991757),
#  np.float64(0.7940710826009948),
#  np.float64(0.8010129225636302),
#  np.float64(0.804967901116035),
#  np.float64(0.8071616665616539),
#  np.float64(0.8108651146788762),
#  np.float64(0.8162547668267861),
#  np.float64(0.8149941142878288),
#  np.float64(0.8171208716254665),
#  np.float64(0.8230875935988254),
#  np.float64(0.823482141339102),
#  np.float64(0.8268720515182338),
#  np.float64(0.8315336155674837),
#  np.float64(0.8307435199659159),
#  np.float64(0.8311435683717732),
#  np.float64(0.8360471617065665),
#  np.float64(0.837140293975571),
#  np.float64(0.836124171024694),
#  np.float64(0.8403606836427208),
#  np.float64(0.8435920746410316),
#  np.float64(0.8413187995747486),
#  np.float64(0.8379183881249632),
#  np.float64(0.8422239090930002),
#  np.float64(0.8462003902472199),
#  np.float64(0.8445151863375469),
#  np.float64(0.8410217636333996),
#  np.float64(0.8428099800075809),
#  np.float64(0.8471125006125741),
#  np.float64(0.8474215380060988),
#  np.float64(0.8456153194536539),
#  np.float64(0.8458213443826703),
#  np.float64(0.8484346605939319),
#  np.float64(0.8493527716853739),
#  np.float64(0.8484516626511808),
#  np.float64(0.8487376972613687),
#  np.float64(0.8509019591370555),
#  np.float64(0.8515240344081636),
#  np.float64(0.8511569899957896),
#  np.float64(0.8526491705496364),
#  np.float64(0.8536542921693524),
#  np.float64(0.854193357396245),
#  np.float64(0.8506079235587505),
#  np.float64(0.8474005354647912),
#  np.float64(0.850136866560854),
#  np.float64(0.8552684874869858),
#  np.float64(0.8555465211290567),
#  np.float64(0.8555055161674564),
#  np.float64(0.8572297247967002),
#  np.float64(0.8581528364932156),
#  np.float64(0.8558985637262109),
#  np.float64(0.8528721975359017),
#  np.float64(0.853418763670404),
#  np.float64(0.8562686085016287),
#  np.float64(0.8587904136400505),
#  np.float64(0.8578818036982475),
#  np.float64(0.8563496183038148),
#  np.float64(0.8587999147896895),
#  np.float64(0.8615332455227083),
#  np.float64(0.8610511871936504),
#  np.float64(0.8591409560556826),
#  np.float64(0.8574497514199217),
#  np.float64(0.8587839128534553),
#  np.float64(0.8619853002213268),
#  np.float64(0.8622113275706359),
#  np.float64(0.8616172556879382),
#  np.float64(0.8627418917689041),
#  np.float64(0.8666048591879618),
#  np.float64(0.8653277046522629),
#  np.float64(0.8648166428137806),
#  np.float64(0.8674889661649059),
#  np.float64(0.8677179938772592),
#  np.float64(0.86575575644653),
#  np.float64(0.8658117632233501),
#  np.float64(0.867937020379466),
#  np.float64(0.8692466788481407),
#  np.float64(0.868165548031312),
#  np.float64(0.8660307897255568),
#  np.float64(0.8664258375263407),
#  np.float64(0.8686421056947891),
#  np.float64(0.8698772551478728),
#  np.float64(0.8680920391367355),
#  np.float64(0.8673329472866216),
#  np.float64(0.8695792190855094),
#  np.float64(0.8712064159763332),
#  np.float64(0.871122405811103),
#  np.float64(0.8697052343333543),
#  np.float64(0.8700682782616698),
#  np.float64(0.8718024881010602),
#  np.float64(0.8723525546591138),
#  np.float64(0.8707803644240953),
#  np.float64(0.8698422509123603),
#  np.float64(0.8703283097254767),
#  np.float64(0.8712064159763331),
#  np.float64(0.8721055247684969),
#  np.float64(0.8726795942309019),
#  np.float64(0.8726455901164041),
#  np.float64(0.8724315642192705),
#  np.float64(0.8727946081475859),
#  np.float64(0.8734256845078254),
#  np.float64(0.8741987780521443),
#  np.float64(0.8743157922108575),
#  np.float64(0.8733916803933276),
#  np.float64(0.8722855465511328),
#  np.float64(0.8703933175914285),
#  np.float64(0.871324430256061),
#  np.float64(0.8720435172655893),
#  np.float64(0.872203536627932),
#  np.float64(0.8716764728532153),
#  np.float64(0.87097738826398),
#  np.float64(0.870758361761773),
#  np.float64(0.8711084041168982),
#  np.float64(0.8714824493763745),
#  np.float64(0.8715934628089999),
#  np.float64(0.8718694962090413),
#  np.float64(0.8717934870119284),
#  np.float64(0.8715164534908724),
#  np.float64(0.8708273701117835),
#  np.float64(0.8699062586572976),
#  np.float64(0.8690531554318071)]

## Plot results

In [None]:
# Training-loss curve: shown inline, then exported as a PNG.
fig = (
    px.line(y=train_losses, template='plotly_white')
    .update_traces(mode="lines", line_color='navy', line_width=4)
    .update_layout(
        xaxis_title="Epoch", yaxis_title="Train Loss",
        height=1080, width=1920, font_size=24,
    )
)
fig.show()
fig.write_image('plots/fnn_train_loss.png')

In [None]:
# Test-loss curve: shown inline, then exported as a PNG.
fig = (
    px.line(y=test_losses, template='plotly_white')
    .update_traces(mode="lines", line_color='navy', line_width=4)
    .update_layout(
        xaxis_title="Epoch", yaxis_title="Test Loss",
        height=1080, width=1920, font_size=24,
    )
)
fig.show()
fig.write_image('plots/fnn_test_loss.png')

In [None]:
# Test-accuracy curve: shown inline, then exported as a PNG.
fig = (
    px.line(y=test_accuracy, template='plotly_white')
    .update_traces(mode="lines", line_color='navy', line_width=4)
    .update_layout(
        xaxis_title="Epoch", yaxis_title="Test Accuracy",
        height=1080, width=1920, font_size=24,
    )
)
fig.show()
fig.write_image('plots/fnn_test_accuracy.png')

In [None]:
# Test ROC AUC curve: shown inline, then exported as a PNG.
fig = (
    px.line(y=roc_scores, template='plotly_white')
    .update_traces(mode="lines", line_color='navy', line_width=4)
    .update_layout(
        xaxis_title="Epoch", yaxis_title="Test ROC AUC",
        height=1080, width=1920, font_size=24,
    )
)
fig.show()
fig.write_image('plots/fnn_test_rocauc.png')

In [None]:
# Score the fixed test set with the current weights and pair every predicted
# probability with its true label for the calibration plot.
model.eval()
images, labels = test_data
with torch.no_grad():
    outputs = model(images)

# Column 0 is the only output column; sigmoid turns the logit into a probability.
pred = F.sigmoid(outputs)[:, 0].numpy()
y = labels[:, 0].numpy()

df = pd.DataFrame(dict(pred=pred, y=y))

In [None]:
# Calibration check: bucket predictions into 0.1-wide bins (labelled by the
# bin centre) and compare each bin with the observed mean label.
pred_bin_centre = df["pred"] // 0.10 * 0.10 + 0.05
observed_rate = df.groupby(pred_bin_centre).agg({"y": "mean"}).squeeze()

calibration_fig = px.scatter(
    observed_rate,
    template='plotly_white',
    # trendline='ols'
)
calibration_fig.update_layout(
    height=800, width=800,
    yaxis_range=(0, 1), xaxis_range=(0, 1),
    showlegend=False,
    xaxis_title="Predicted group", yaxis_title="Real average",
)
# Last expression of the cell: update_traces returns the figure, which the
# notebook then renders.
calibration_fig.update_traces(
    mode="lines+markers", line_color='navy', marker_color='navy'
)