## 0. Dependencies

In [None]:
!pip install python-chess
!pip install tqdm

## 0.2 Global Vars

In [None]:
# Root directory of the dataset; samples are read from its 'x' and 'y' sub-folders.
DATA_DIR = '/Drive/chess_data'
# Directory intended for saved plots.
PLOTS_DIR = '/Drive/plots'
# Root directory for saved models ('models' sub-folder) and state dicts ('state' sub-folder).
CHECKPOINTS_DIR = '/Drive/checkpoints'

## 0.5 Import dependencies

In [None]:
import os
import random

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import numpy as np

from tqdm import tqdm
import matplotlib.pyplot as plt

import chess

In [None]:
# Make sure the working sub-folders exist before any later cell reads or writes them.
# os.makedirs(..., exist_ok=True) creates missing parent folders as well and is a
# no-op when the folder is already there, so this cell can be re-run safely and no
# longer fails when DATA_DIR / CHECKPOINTS_DIR themselves have not been created yet.
for _subdir in ('x', 'y'):
    os.makedirs(os.path.join(DATA_DIR, _subdir), exist_ok=True)

for _subdir in ('models', 'state'):
    os.makedirs(os.path.join(CHECKPOINTS_DIR, _subdir), exist_ok=True)

## 1. Utility Functions

In [None]:
def convert_to_bb(board):
    """Encode a textual board diagram as an 8x8x12 one-hot array.

    Args:
        board: Board diagram as a string. Spaces are stripped and the remaining
            text is split on whitespace, so each resulting token is one board row
            and each character of a token is one square.

    Returns:
        A float ``numpy`` array of shape ``(8, 8, 12)`` indexed as
        ``[row, column, plane]``. Plane ``i`` corresponds to the i-th character of
        ``'PRNBQKprnbqk'``; squares holding any other character stay all-zero.
    """
    # Map every piece symbol to its plane index once instead of scanning per square.
    plane_of = {symbol: plane for plane, symbol in enumerate('PRNBQKprnbqk')}

    bb = np.zeros((8, 8, 12))
    ranks = board.replace(' ', '').split()

    for r_idx, rank in enumerate(ranks):
        for c_idx, symbol in enumerate(rank):
            plane = plane_of.get(symbol)
            if plane is not None:
                bb[r_idx, c_idx, plane] = 1

    return bb


## 2. Data

In [None]:
class chessData(Dataset):
    """Dataset of (input, target) pairs stored as paired ``.npy`` files.

    Inputs are read from ``<data_dir>/x`` and targets from ``<data_dir>/y``.
    The i-th file of each folder (in sorted file-name order) forms one sample.
    NOTE(review): this assumes both folders use matching file names so that
    sorting aligns them -- confirm against the code that writes the data.

    The file list is shuffled with a fixed seed and split 80/20. Because the
    shuffle is deterministic, ``chessData(train=True)`` and
    ``chessData(train=False)`` built with the same ``seed`` are disjoint and
    together cover every sample.
    """

    def __init__(self, train=True, data_dir=None, seed=0):
        """Collect the sample paths and keep either the train or the test part.

        Args:
            train: If true keep the first 80% of the shuffled samples,
                otherwise keep the remaining 20%.
            data_dir: Folder containing the ``x`` and ``y`` sub-folders.
                Defaults to the notebook-level ``DATA_DIR``.
            seed: Seed for the split shuffle. Use the same value for the train
                and the test instance so the two splits do not overlap.

        Raises:
            ValueError: If ``x`` and ``y`` contain a different number of files.
        """
        root = DATA_DIR if data_dir is None else data_dir
        x_dir = os.path.join(root, 'x')
        y_dir = os.path.join(root, 'y')

        # os.listdir gives no ordering guarantee; sort both listings so that
        # zip() pairs each input with its own target.
        x = [os.path.join(x_dir, f) for f in sorted(os.listdir(x_dir))]
        y = [os.path.join(y_dir, f) for f in sorted(os.listdir(y_dir))]

        if len(x) != len(y):
            # zip() would silently drop the surplus and could misalign pairs.
            raise ValueError(
                f'found {len(x)} input files in {x_dir} but {len(y)} target files in {y_dir}'
            )

        self.files = list(zip(x, y))
        self.train = train

        # A private, seeded RNG makes every instance see the same permutation;
        # an unseeded shuffle gave each instance a different one, so the train
        # and test splits overlapped.
        random.Random(seed).shuffle(self.files)

        num_train = int(.8 * len(self.files))

        if self.train:
            self.files = self.files[:num_train]
        else:
            self.files = self.files[num_train:]

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.files)

    def __getitem__(self, index):
        """Load one sample from disk.

        Returns:
            ``(x, y)`` float tensors. ``x`` has its last axis moved to the front
            (presumably (8, 8, 12) -> (12, 8, 8), channels first); ``y`` is
            returned in the shape it was saved in.
        """
        x_path, y_path = self.files[index]

        # NOTE(review): allow_pickle=True executes pickled objects on load; it is
        # kept for compatibility with existing files, but only load trusted data.
        x = np.load(x_path, allow_pickle=True)
        y = np.load(y_path, allow_pickle=True)

        x = torch.from_numpy(x).permute(2, 0, 1).float()
        y = torch.from_numpy(y).float()

        return x, y


## 3. Neural Net

In [None]:
class Net(nn.Module):
    """Small convolutional network mapping a 12-channel 8x8 input to one score.

    Two stride-2 convolutions shrink the 8x8 grid to 4x4 and then 2x2; the
    resulting 48 * 2 * 2 = 192 features go through two linear layers and a
    sigmoid, so the output lies in (0, 1).
    """

    def __init__(self):
        super().__init__()

        # Attribute names and creation order are kept so that state-dict keys
        # stay the same as before.
        self.conv1 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=24, out_channels=48, kernel_size=4, stride=2, padding=1)
        self.fc1 = nn.Linear(192, 50)
        self.fc2 = nn.Linear(50, 1)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run the network.

        Args:
            x: Tensor of shape ``(n, 12, 8, 8)`` (channels first).

        Returns:
            Tensor of shape ``(n, 1)`` with values in (0, 1).
        """
        feats = self.relu(self.conv1(x))      # (8 - 4 + 2) / 2 + 1 -> 4x4
        feats = self.relu(self.conv2(feats))  # (4 - 4 + 2) / 2 + 1 -> 2x2

        flat = feats.view(feats.shape[0], -1)  # (n, 192)

        hidden = self.relu(self.fc1(flat))
        return self.sigmoid(self.fc2(hidden))

    def save_model(self):
        """Serialize the whole module object to the models checkpoint folder."""
        target = f'{CHECKPOINTS_DIR}/models/Net.pt'
        torch.save(self, target)

    def save_checkpoint(self):
        """Write only the parameters (state dict) to the state checkpoint folder."""
        target = f'{CHECKPOINTS_DIR}/state/Net_state_dict.pt'
        torch.save(self.state_dict(), target)

    def load_checkpoint(self):
        """Restore parameters from the file written by ``save_checkpoint``."""
        source = f'{CHECKPOINTS_DIR}/state/Net_state_dict.pt'
        self.load_state_dict(torch.load(source))

## 4. Training

In [None]:
# Build a fresh model and store the (still untrained) module object on disk.
net = Net()
net.save_model()

epochs = 100
lr = 1e-5

lossfn = nn.L1Loss()
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

trainloader = DataLoader(chessData(train=True), batch_size=64, shuffle=True)
# Evaluation gains nothing from shuffling; a fixed order keeps the per-epoch test
# loss (a mean of batch means, with a possibly smaller last batch) comparable.
testloader = DataLoader(chessData(train=False), batch_size=64, shuffle=False)

if len(trainloader) == 0 or len(testloader) == 0:
    # Fail early with a clear message instead of a ZeroDivisionError when the
    # epoch averages are computed below.
    raise RuntimeError(f'no training/test samples found under {DATA_DIR}')

train_loss_over_time = []
test_loss_over_time = []

lowest_loss = float('inf')

# Save the loss curve under PLOTS_DIR (not a path relative to the current working
# directory) and make sure that folder exists before the first savefig call.
os.makedirs(PLOTS_DIR, exist_ok=True)
loss_plot_path = os.path.join(PLOTS_DIR, 'Net_Loss')

print('training started')

for epoch in tqdm(range(epochs)):

    train_loss_epoch = []
    test_loss_epoch = []

    net.train()

    for x, y in tqdm(trainloader):

        optimizer.zero_grad()

        p = net(x)
        # net returns (batch, 1). Give the target the same shape so a (batch,)
        # target cannot silently broadcast to (batch, batch) inside L1Loss.
        # NOTE(review): assumes one scalar target per sample -- confirm.
        loss = lossfn(p, y.reshape(p.shape))

        train_loss_epoch.append(loss.item())

        loss.backward()
        optimizer.step()

    net.eval()

    with torch.no_grad():

        for x, y in testloader:

            p = net(x)
            loss = lossfn(p, y.reshape(p.shape))

            test_loss_epoch.append(loss.item())

    train_loss_over_time.append(sum(train_loss_epoch) / len(train_loss_epoch))
    test_loss_over_time.append(sum(test_loss_epoch) / len(test_loss_epoch))

    # Keep only the parameters of the epoch with the best test loss so far.
    if test_loss_over_time[-1] < lowest_loss:
        net.save_checkpoint()
        lowest_loss = test_loss_over_time[-1]

    # Redraw and overwrite the loss curve after every epoch.
    plt.plot(train_loss_over_time, label='Train loss over time')
    plt.plot(test_loss_over_time, label='Test loss over time')
    plt.legend()
    plt.savefig(loss_plot_path)
    plt.close('all')


## 5. Testing 

In [None]:
# Inference only: switch to eval mode and turn off gradient tracking on the parameters.
net.eval()
net.requires_grad_(False)

# Start from the initial position and play the first move of the legal-move list.
board = chess.Board()
opening_move = list(board.legal_moves)[0]
board.push(opening_move)

# Score every reply: play it, encode the resulting position, evaluate, undo.
# The legal moves are materialized into a list first because the board is
# modified (push/pop) inside the loop.
for move in list(board.legal_moves):
    board.push(move)

    planes = convert_to_bb(str(board))
    # (8, 8, 12) -> (1, 12, 8, 8): channels first plus a batch dimension.
    bb = torch.from_numpy(planes).permute(2, 0, 1).unsqueeze(0).float()

    pred = net(bb).squeeze()

    print(pred)

    board.pop()