In [None]:
from chessf.parser import FilePGN
from chessf.engine import Stockfish
from chessf.convert import eval_to_white_win_p, win_p_to_one_hot_bin, elo_to_one_hot, elo_diff_to_one_hot, move_number_to_one_hot

import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
torch.manual_seed(0)

import plotly.express as px
from time import perf_counter

## Prepare data

In [None]:
# PGN source (file name suggests the Lichess standard-rated dump for 2017-02).
# get_data() pulls games from `file` one at a time via get_and_parse_next_good_game().
pgn_2017_02 = "pgn/lichess_db_standard_rated_2017-02.pgn"
file = FilePGN(pgn_2017_02)

In [None]:
# Stockfish executable (a Windows `.exe` here — adjust the path on other platforms).
# get_data() uses this engine to replay games, extract board features and evaluate positions.
stockfish_path = "stockfish/stockfish.exe"
stockfish = Stockfish(stockfish_path)

In [None]:
def get_data(n_games=10, pgn_file=None, engine=None, skip_plies=20):
    """Sample one position per game and build (board, extra-features, label) tensors.

    Games are pulled via ``get_and_parse_next_good_game()``; games with
    ``skip_plies`` or fewer plies are skipped. For each accepted game one
    1-based ply number is drawn uniformly from
    ``[skip_plies, len(game_moves) - 1]`` (so the final ply is never chosen),
    the game is replayed on the engine up to that point, and features are
    taken from the position *before* that ply is played:

    * board features: flattened bitboard matrix, the mover's "where can move"
      map summed with the map returned after the move is played, and the
      pseudo-legal move maps; reshaped to ``(n_games, -1, 8, 8)``;
    * extra features: side to move as ``[1, 0]`` (white) / ``[0, 1]`` (black),
      one-hot encodings of both Elo ratings and their difference, the ply
      number, and a binned white-win probability from a depth-5 evaluation.

    The label is 1 when the result is "1-0" and 0 otherwise (draws and black
    wins share the negative class).

    The ply is drawn from the global NumPy RNG; only torch is seeded in the
    import cell, so call ``np.random.seed`` for reproducible sampling.

    Args:
        n_games: number of samples to return (one per accepted game); must be >= 1.
        pgn_file: game source exposing ``get_and_parse_next_good_game()``;
            defaults to the notebook-level ``file``.
        engine: Stockfish wrapper; defaults to the notebook-level ``stockfish``.
        skip_plies: minimum ply number that may be sampled; shorter games are skipped.

    Returns:
        Tuple of tensors ``(board, extra, labels)`` with shapes
        ``(n_games, C, 8, 8)``, ``(n_games, F)`` and ``(n_games, 1)``.

    Raises:
        ValueError: if ``n_games < 1``.
    """
    if n_games < 1:
        raise ValueError(f"n_games must be a positive integer, got {n_games!r}")

    pgn = file if pgn_file is None else pgn_file
    sf = stockfish if engine is None else engine

    board_rows = []
    extra_rows = []
    targets = []

    while len(targets) < n_games:
        game_moves, info = pgn.get_and_parse_next_good_game()
        if len(game_moves) <= skip_plies:
            continue

        sf.start_new_game()

        # Only a white win is positive; draws and black wins are both 0.
        target = int(info["Result"] == "1-0")

        elo_white = elo_to_one_hot(info['WhiteElo'])
        elo_black = elo_to_one_hot(info['BlackElo'])
        elo_diff = elo_diff_to_one_hot(info['WhiteElo'] - info['BlackElo'])

        # 1-based ply to sample; randint excludes its upper bound.
        move_to_analyze = np.random.randint(skip_plies, len(game_moves))

        for move_number, game_move in enumerate(game_moves, start=1):
            if move_number != move_to_analyze:
                sf.make_pgn_move(game_move)
                continue

            # Features of the position before `game_move` is played.
            bbm_flat = sf.get_bitboard_matrix().flatten()
            where_can_move_flat = sf.get_where_can_move().flatten()
            pseudolegal_moves_flat = sf.get_pseudolegal_moves().flatten()

            win_p = eval_to_white_win_p(*sf.get_eval(depth=5))
            win_p_ohe = win_p_to_one_hot_bin(win_p)

            move_number_ohe = move_number_to_one_hot(move_number)

            if sf.side_to_move == 'w':
                side_to_move_arr = np.array([1, 0])
            else:
                side_to_move_arr = np.array([0, 1])

            # Play the move and add the map reported afterwards (presumably the
            # opponent's, now to move) to the mover's map.
            sf.make_pgn_move(game_move)
            where_can_move_opponent_flat = sf.get_where_can_move().flatten()
            where_can_move_all_flat = where_can_move_flat + where_can_move_opponent_flat

            board_rows.append(np.concatenate([
                bbm_flat,
                where_can_move_all_flat,
                pseudolegal_moves_flat,
            ]))
            extra_rows.append(np.concatenate([
                side_to_move_arr, elo_white, elo_black, elo_diff, move_number_ohe, win_p_ohe,
            ]))
            targets.append(target)
            # The rest of the game is not needed; start_new_game() resets the engine.
            break

    XX_1 = np.array(board_rows).reshape(n_games, -1, 8, 8)
    XX_2 = np.array(extra_rows)
    YY = np.array(targets)

    return (
        torch.Tensor(XX_1),
        torch.Tensor(XX_2),
        torch.Tensor(YY).view(-1, 1),
    )

In [None]:
%%time
test_data = get_data(n_games=2000)

In [None]:
# Board tensor size: (n_games, channels, 8, 8) per get_data's reshape.
test_data[0].size()

## NN 1: 1x1-convolution board encoder + fully connected head

In [None]:
N_CLASSES = 1
N_CONV_CHANNELS = 5   # output channels of both 1x1 convolutions
N_FC_FEATURES = 10    # width of the dense embedding of the non-board features
# Flattened conv output (64 squares of an 8x8 board per channel) plus the dense embedding.
N_FEATURES = (64*N_CONV_CHANNELS + N_FC_FEATURES)
INIT_STD = 0.01       # not used by the active code in this cell

N_HIDDEN_1 = 128
N_HIDDEN_2 = 128

class SimpleFC(nn.Module):
    """Binary classifier over board planes plus a vector of extra features.

    ``x1`` (board planes, expected ``(batch, in_channels, 8, 8)``) passes through
    two 1x1 convolutions, which keep the 8x8 layout; ``x2``
    (``(batch, n_extra_features)``) passes through one dense layer. The two
    are concatenated and fed through a batch-normalised MLP that outputs one
    raw logit per sample (pair it with ``BCEWithLogitsLoss``).

    Args:
        in_channels: channel count of ``x1``. The default 42 is the size this
            notebook was built for; presumably it matches ``get_data``'s board
            tensor — verify against ``test_data[0].shape``.
        n_extra_features: width of ``x2``. The default is 38; presumably it
            matches ``get_data``'s extra-feature tensor — verify against
            ``test_data[1].shape``.
    """

    def __init__(self, in_channels=42, n_extra_features=38):
        super().__init__()

        # Board branch: per-square mixing of the input planes.
        self.conv_1 = nn.Conv2d(in_channels=in_channels, out_channels=N_CONV_CHANNELS, kernel_size=(1, 1))
        self.conv_2 = nn.Conv2d(in_channels=N_CONV_CHANNELS, out_channels=N_CONV_CHANNELS, kernel_size=(1, 1))

        # Extra-feature branch and the shared MLP head.
        self.fc_layer = nn.Linear(n_extra_features, N_FC_FEATURES)
        self.layer_1 = nn.Linear(N_FEATURES, N_HIDDEN_1)
        self.layer_2 = nn.Linear(N_HIDDEN_1, N_HIDDEN_2)
        self.layer_f = nn.Linear(N_HIDDEN_2, N_CLASSES)

        self.norm_0 = nn.BatchNorm1d(N_FEATURES)
        self.norm_1 = nn.BatchNorm1d(N_HIDDEN_1)
        self.norm_2 = nn.BatchNorm1d(N_HIDDEN_2)

    def forward(self, x1, x2):
        """Return logits of shape ``(batch, N_CLASSES)``."""
        x1 = F.relu(self.conv_1(x1))
        x1 = F.relu(self.conv_2(x1))
        x2 = F.relu(self.fc_layer(x2))

        x = torch.cat((x1.flatten(1), x2), dim=1)
        x = self.norm_0(x)

        x = F.relu(self.layer_1(x))
        x = self.norm_1(x)
        x = F.relu(self.layer_2(x))
        x = self.norm_2(x)
        return self.layer_f(x)

In [None]:
# Fresh model with default initialisation; weights depend on torch.manual_seed(0)
# from the import cell, assuming cells are run top to bottom.
model = SimpleFC()

In [None]:
# Number of trainable parameters in the model.
sum(param.numel() for param in model.parameters() if param.requires_grad)

In [None]:
# The model's last layer is a plain Linear (no sigmoid), so the loss takes raw logits.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

In [None]:
# Per-epoch metric histories, appended to by the training loop and plotted afterwards.
train_losses, test_losses, test_accuracy, roc_scores = [], [], [], []

## Training

In [None]:
N_EPOCHS = 5
N_GAMES_PER_EPOCH = 2000
# Mini-batches give many optimizer steps per freshly generated chunk, and many
# train-mode forward passes so BatchNorm's running statistics (used after
# model.eval()) are fitted to the data rather than left near their initial values.
BATCH_SIZE = 64


def _train_epoch(model, optimizer, criterion, data, batch_size):
    """Run one shuffled mini-batch pass over ``data``; return the per-sample mean loss."""
    dataset = TensorDataset(*data)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        # BatchNorm1d cannot normalise a single-sample batch in train mode.
        drop_last=len(dataset) % batch_size == 1,
    )
    model.train()
    total_loss, n_seen = 0.0, 0
    for inputs_1, inputs_2, labels in loader:
        optimizer.zero_grad()
        loss = criterion(model(inputs_1, inputs_2), labels)
        loss.backward()
        optimizer.step()
        # criterion averages over the batch; re-weight so the epoch mean is per sample.
        total_loss += loss.item() * labels.size(0)
        n_seen += labels.size(0)
    return total_loss / n_seen if n_seen else float("nan")


def _evaluate(model, criterion, data):
    """Score ``data`` in eval mode; return (loss, accuracy at 0.5, ROC AUC) as floats."""
    inputs_1, inputs_2, labels = data
    model.eval()
    with torch.no_grad():
        logits = model(inputs_1, inputs_2)
        loss = criterion(logits, labels).item()
        probabilities = torch.sigmoid(logits)
    accuracy = (probabilities.round() == labels).float().mean().item()
    roc = roc_auc_score(labels[:, 0].numpy(), probabilities[:, 0].numpy())
    return loss, accuracy, roc


for epoch in range(N_EPOCHS):
    # A new chunk of games each epoch; the fixed `test_data` was drawn before training.
    train_data = get_data(n_games=N_GAMES_PER_EPOCH)

    mean_loss = _train_epoch(model, optimizer, criterion, train_data, BATCH_SIZE)
    train_losses.append(mean_loss)

    mean_loss_test, accuracy, roc_score = _evaluate(model, criterion, test_data)
    test_losses.append(mean_loss_test)
    test_accuracy.append(accuracy)
    roc_scores.append(roc_score)

    print(f"Epoch: {str(epoch+1).zfill(4)}, Loss: {mean_loss:.6f}, Test: {mean_loss_test:.6f}, Acc: {accuracy:.6f} Roc: {roc_score:.6f}")

## Plot results

In [None]:
# Training loss per epoch; x starts at 1 to match the "Epoch: 0001" numbering in the training log.
fig = (
    px.line(x=list(range(1, len(train_losses) + 1)), y=train_losses, template='plotly_white')
    .update_traces(mode="lines", line_color='navy', line_width=4)
    .update_layout(xaxis_title="Epoch", yaxis_title="Train Loss",
                   height=1080, width=1920, font_size=24)
)
fig.show()
fig.write_image('plots/cnn_train_loss.png')

In [None]:
# Held-out loss per epoch; x starts at 1 to match the "Epoch: 0001" numbering in the training log.
fig = (
    px.line(x=list(range(1, len(test_losses) + 1)), y=test_losses, template='plotly_white')
    .update_traces(mode="lines", line_color='navy', line_width=4)
    .update_layout(xaxis_title="Epoch", yaxis_title="Test Loss",
                   height=1080, width=1920, font_size=24)
)
fig.show()
fig.write_image('plots/cnn_test_loss.png')

In [None]:
# Held-out accuracy (threshold 0.5) per epoch; x starts at 1 to match the training log.
fig = (
    px.line(x=list(range(1, len(test_accuracy) + 1)), y=test_accuracy, template='plotly_white')
    .update_traces(mode="lines", line_color='navy', line_width=4)
    .update_layout(xaxis_title="Epoch", yaxis_title="Test Accuracy",
                   height=1080, width=1920, font_size=24)
)
fig.show()
fig.write_image('plots/cnn_test_accuracy.png')

In [None]:
# Held-out ROC AUC per epoch; x starts at 1 to match the training log.
fig = (
    px.line(x=list(range(1, len(roc_scores) + 1)), y=roc_scores, template='plotly_white')
    .update_traces(mode="lines", line_color='navy', line_width=4)
    .update_layout(xaxis_title="Epoch", yaxis_title="Test ROC AUC",
                   height=1080, width=1920, font_size=24)
)
fig.show()
fig.write_image('plots/cnn_test_rocauc.png')

In [None]:
# Score 10,000 freshly sampled positions and pair each predicted white-win
# probability with the actual outcome, for the calibration plots below.
model.eval()
with torch.no_grad():
    images_1, images_2, labels = get_data(10_000)
    outputs = model(images_1, images_2)
    pred = torch.sigmoid(outputs).numpy()[:, 0]

y = labels[:, 0].numpy()
df = pd.DataFrame(dict(pred=pred, y=y))

In [None]:
# Calibration curve: predictions bucketed into bins of width 0.10 (labelled by
# bin centre) against the observed mean label in each bin.
fig = px.scatter(
    df.groupby(df["pred"] // 0.10 * 0.10 + 0.05).agg({"y": "mean"}).squeeze(),
    template='plotly_white',
)
fig.update_layout(
    height=1080, width=1080,
    xaxis_range=(0, 1), yaxis_range=(0, 1),
    showlegend=False,
    xaxis_title="Predicted group", yaxis_title="Real percent",
    font_size=24,
)
fig.update_traces(
    mode="lines+markers", line_color='navy', marker_color='navy',
    line_width=4, marker_size=12,
)
fig.show()
fig.write_image('plots/cnn_groups_10.png')

In [None]:
# Finer calibration curve: bins of width 0.05 (labelled by bin centre)
# against the observed mean label in each bin.
fig = px.scatter(
    df.groupby(df["pred"] // 0.05 * 0.05 + 0.025).agg({"y": "mean"}).squeeze(),
    template='plotly_white',
)
fig.update_layout(
    height=1080, width=1080,
    xaxis_range=(0, 1), yaxis_range=(0, 1),
    showlegend=False,
    xaxis_title="Predicted group", yaxis_title="Real percent",
    font_size=24,
)
fig.update_traces(
    mode="lines+markers", line_color='navy', marker_color='navy',
    line_width=4, marker_size=12,
)
fig.show()
fig.write_image('plots/cnn_groups_20.png')