In [None]:
import os
os.environ["ROCM_PATH"] = "/opt/rocm"
os.environ["HSA_OVERRIDE_GFX_VERSION"] = "10.3.0"

import numpy as np
import matplotlib.pyplot as plt
import chess
import torch
import math
import warnings
import train
from utils import *
from chess_net import ChessModel
import parse_data
import time

from torch.utils.data.datapipes.iter import IterableWrapper, Shuffler # type: ignore

# Pick the GPU when one is actually usable, otherwise fall back to the CPU.
# `is_available` must be *called*: the bare function object is always truthy,
# which would select "cuda" even on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
with warnings.catch_warnings():
    # Any warning raised while building or saving the figure is silenced.
    warnings.simplefilter('ignore')

    # Piece symbol -> plane index; the insertion order is what labels the
    # first twelve panels below.
    piece_layer = {"P":0,"R":1,"N":2,"B":3,"Q":4, "K":5,"p":6,"r":7,"n":8,"b":9,"q":10, "k":11}

    # Play a short opening so the encoded planes are not just the start
    # position ("Nf6" was tried in place of "a5" at some point).
    board = chess.Board()
    for san_move in ("e4", "a5", "e5", "d5"):
        board.push_san(san_move)
    board_rep = get_board_rep(board)

    # Layout of board_rep's first axis, as plotted below:
    #   0-5   white pawn/rook/knight/bishop/queen/king
    #   6-11  the same pieces for black
    #   12    current player's color -> 1 for white
    #   13    en passant square
    #   14    half moves since last capture
    #   15    moves this game
    #   16-19 white kingside, white queenside, black kingside,
    #         black queenside castling rights

    # Seven unlabeled ticks at the half-integers put the grid lines on the
    # borders between the eight files/ranks.
    square_borders = [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5]
    no_labels = [""] * 7
    fig, axs = plt.subplots(
        4,
        6,
        subplot_kw={
            "xticklabels": no_labels,
            "yticklabels": list(no_labels),
            "xticks": square_borders,
            "yticks": list(square_borders),
        },
        figsize=(20, 10),
    )

    # White grid on every panel, including the ones that stay empty.
    for panel in axs.flat:
        panel.grid(which='both', axis='both', lw=1, color='white')

    # Rows 0 and 1: one panel per piece plane, titled with its symbol.
    for plane_index, symbol in enumerate(piece_layer):
        panel = axs[plane_index // 6][plane_index % 6]
        panel.imshow(board_rep[plane_index, :, :], vmin=0, vmax=1)
        panel.set_title(symbol)

    # Row 2: side to move, en passant, and the two counters (each counter has
    # its own colour-scale ceiling).
    state_planes = (
        (12, 'Player', 1),
        (13, 'En passant', 1),
        (14, 'Half moves', 50),
        (15, 'Move count', 150),
    )
    for column, (plane_index, title, ceiling) in enumerate(state_planes):
        panel = axs[2][column]
        panel.imshow(board_rep[plane_index, :, :], vmin=0, vmax=ceiling)
        panel.set_title(title)

    # Row 3: the four castling planes, 16 through 19.
    castling_titles = (
        "White Kingside Castle",
        "White Queenside Castle",
        "Black Kingside Castle",
        "Black Queenside Castle",
    )
    for column, title in enumerate(castling_titles):
        panel = axs[3][column]
        panel.imshow(board_rep[16 + column, :, :], vmin=0, vmax=1)
        panel.set_title(title)

    fig.savefig("img.png")

In [None]:
# Instantiate the network on the device chosen in the setup cell rather than
# a hard-coded 'cuda', so the notebook follows whatever `device` resolves to.
model = ChessModel().to(device)

In [None]:
# Sanity-check a single forward pass on the position encoded above.
# The input goes to `device` (from the setup cell) instead of a hard-coded
# 'cuda' so it follows the same device selection as the rest of the notebook.
board_rep_tensor = torch.tensor(board_rep).to(device)
print(board_rep_tensor.shape)
# unsqueeze(0) adds a leading batch axis of size 1 before calling the model.
# The model returns two outputs; presumably policy (p) and value (v) heads,
# going by the loss function names used later. Confirm against ChessModel.
p, v = model(board_rep_tensor.unsqueeze(0))
print(p.shape, v.shape)
print(p,v)

In [None]:
# Display the first model output `p` from the forward pass above.
# NOTE(review): show_move_rep is not defined in this notebook; presumably it
# comes from `from utils import *`. Confirm its expected input shape/device.
show_move_rep(p)

In [None]:
# Optimiser and the two loss terms handed to train.train_loop below:
# cross-entropy for the policy output, mean squared error for the value output.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.00025)
policy_loss_fn = torch.nn.CrossEntropyLoss()
value_loss_fn = torch.nn.MSELoss()

# Run configuration, one setting per line.
epochs = 5          # number of passes made by the training loop cell
model_path = "model.pt"
loss = 0
batch_size = 1024   # forwarded to parse_data.dataloader_from_filepaths
# NOTE(review): print_counter / save_counter are not read in the visible
# cells; presumably logging / checkpoint intervals. Confirm where they are used.
print_counter = 10
save_counter = 200

In [None]:
# Accumulates whatever train.train_loop returns, one entry per epoch, under
# the 'loss' key. Re-running this cell resets the history to empty.
training_history = dict(loss=[])

In [None]:
# checkpoint = torch.load("checkpoints/1/checkpoint_1_start.pt")
# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# start_epoch = checkpoint['epoch']
# print(start_epoch)

In [None]:
# Main training loop: for each epoch, snapshot the model/optimizer state
# *before* training on it, rebuild the data loader, run one pass, and record
# the losses returned by train.train_loop.
for epoch in range(epochs):
    # torch.save does not create missing parent directories, so make sure the
    # per-epoch checkpoint folder exists first (no-op when it already does).
    checkpoint_dir = f"checkpoints/{epoch}"
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save({
        'epoch' : epoch,
        'model_state_dict':model.state_dict(),
        'optimizer_state_dict' : optimizer.state_dict(),
    }, f"{checkpoint_dir}/checkpoint_{epoch}_start.pt")
    print(f"Epoch {epoch} checkpoint saved")
    # NOTE(review): absolute, machine-specific dataset path; consider moving
    # it to a configurable constant so the notebook runs elsewhere.
    loader = parse_data.dataloader_from_filepaths(["/home/gerard/Documents/Personal/Programming/rl/chessai/data/lichess_2015/lichess_db_standard_rated_2015-09.pgn"], batch_size=batch_size, num_workers=4)
    losses = train.train_loop(loader, model, policy_loss_fn, value_loss_fn, optimizer, epoch)
    training_history['loss'].append(losses)