In [None]:
from Ajedrez import Ajedrez,train
from SplitImageDataset import SplitImageDataset
from WholeImageDataset import WholeImageDataset

import chess

import numpy as np
import torch
from torch import nn
from torch import optim
import matplotlib.pyplot as plt
import math
import time
import random

import torch.optim
from tqdm.notebook import tqdm

import DataUtils

import sys
sys.path.append("../engines")

from ChessEngine import ChessEngine

from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader, random_split

%load_ext autoreload
%autoreload 2

## Visualizing Data

## Training AJ

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Instantiate the model with continue_training=True, then move it to `device`.
# NOTE(review): the meaning of the positional argument 3 is not visible in this
# notebook — confirm against the Ajedrez constructor.
AJ = Ajedrez(3, continue_training=True)
AJ = AJ.to(device)

In [None]:
# Colour preprocessing: resize to 224x224, convert to a tensor, then normalise
# each channel. The mean/std values are the ImageNet channel statistics listed
# in the torchvision documentation.
# NOTE(review): presumably Ajedrez wraps an ImageNet-pretrained backbone — confirm.
color_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
# Depth preprocessing: same resize/tensor steps; mean=0, std=1 leaves the values
# unchanged. It is handed to the dataset even though use_depth=False below.
depth_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=0, std=1)
])

# Sizes for a full-dataset run (TRAIN_SIZE + TEST_SIZE == FULL_SIZE). This cell
# does not use them; it trains on the small subset defined just below.
FULL_SIZE = 242_000
TRAIN_SIZE = 169_400
TEST_SIZE = 72_600

# Subset actually used here. The dataset size is derived from the two split
# sizes so the numbers passed to random_split can never drift out of sync.
SUBSET_TRAIN_SIZE = 1_000
SUBSET_TEST_SIZE = 100
# Fixed seed so the train/test partition is identical on every Run All.
SPLIT_SEED = 0

dset = SplitImageDataset('./image_dataset/metadata.csv',
    dataset_size=SUBSET_TRAIN_SIZE + SUBSET_TEST_SIZE,
    use_depth=False,
    color_transform=color_transforms,
    depth_transform=depth_transforms,
)

# A dedicated generator makes the split reproducible without touching the
# global RNG state used elsewhere (e.g. DataLoader shuffling).
train_data, test_data = random_split(
    dset,
    [SUBSET_TRAIN_SIZE, SUBSET_TEST_SIZE],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_data, batch_size=20,
                          shuffle=True, num_workers=2)
test_loader = DataLoader(test_data, batch_size=20,
                         shuffle=False, num_workers=2)

# SGD with momentum; StepLR multiplies the learning rate by 0.1 every 2
# scheduler steps.
sgd = optim.SGD(AJ.parameters(), lr=0.001, momentum=0.9)
scheduler = optim.lr_scheduler.StepLR(sgd, step_size=2, gamma=0.1)

# NOTE(review): the trailing 4 is presumably the number of epochs — confirm
# against Ajedrez.train.
train(AJ, train_loader, test_loader, sgd, scheduler, device, 4)

# Persist the trained weights; the visual-test cell reloads this file.
torch.save(AJ.state_dict(), './aj_model.pt')

# Hand cached GPU memory back to the driver now that training is finished.
torch.cuda.empty_cache()

## Visual Test of AJ

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Fresh model instance for evaluation, built with continue_training=False and
# then moved to `device`.
# NOTE(review): the meaning of the positional argument 3 is not visible in this
# notebook — confirm against the Ajedrez constructor.
AJ = Ajedrez(3, continue_training=False)
AJ = AJ.to(device)

In [None]:
# Colour preprocessing for inference: resize to 224x224, convert to a tensor,
# then normalise each channel with the same mean/std used in the training cell.
color_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

print('Retrieving One Item')

dset_full = WholeImageDataset('./image_dataset/metadata.csv',
    color_transform=color_transforms,
    dataset_size=10
)

# Each item is a pair: an image for display and the tensor fed to the model.
# NOTE(review): concat_img presumably batches the per-square crops of one board
# (the loop below reads 64 predictions) — confirm against WholeImageDataset.
color_img, concat_img = dset_full[0]

plt.figure(figsize=(15,15))
plt.axis('off')
plt.imshow(color_img)

# map_location remaps the checkpoint's tensors onto the current device, so
# weights saved on a GPU machine still load when only a CPU is available.
AJ.load_state_dict(torch.load('./aj_model.pt', map_location=device))
AJ.eval()

concat_img = concat_img.to(device)
# Inference only: skip autograd bookkeeping, and call the module itself rather
# than .forward() so any registered hooks run.
with torch.no_grad():
    out_c = AJ(concat_img)

# Index of the highest-scoring class along dim 1 for every entry.
classes = out_c.argmax(1)

print(classes)
print(classes.shape)

# Start from an empty board and place the predicted pieces on it.
board = chess.Board(None)

nrow = 8
ncol = 8

# Copy the predictions to a plain Python list once instead of issuing a
# separate .item() device-to-host transfer for each of the 64 squares.
labels = classes.flatten().tolist()

for i in range(nrow):
    for j in range(ncol):
        # Predictions are read row-major with row 0 mapped to rank 8, so the
        # rank index is flipped; column j maps directly to files a..h.
        square = chess.square(j, nrow - 1 - i)

        piece = ChessEngine.numberToPiece(labels[i * ncol + j])

        # numberToPiece returning None leaves the square empty.
        if piece is not None:
            board.set_piece_at(square, piece)

board