In [39]:
import numpy as np
import pandas as pd

%matplotlib inline
import matplotlib.pyplot as plt
import time
import torch
from tqdm import tqdm
from sklearn import metrics

In [40]:
class ChessData(torch.utils.data.Dataset):
    def __init__(self, path):
        self.df = pd.read_csv(path, index_col=0)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        inputs = row[:321].values.astype(np.float32)
        outputs = row[321:].values.astype(np.float32)

        return {'board': inputs, 'themes': outputs}


In [41]:
from torch.utils.data import DataLoader

batch_size = 256

chess_data_set = ChessData('cleaned_data/cleaned_train_puzzles.csv')

train_dataloader = DataLoader(chess_data_set, batch_size=batch_size, num_workers=4, persistent_workers=True, shuffle=True)

In [42]:
chess_validation_set = ChessData('cleaned_data/cleaned_validation_puzzles.csv')

validation_dataloader = DataLoader(chess_validation_set, batch_size=batch_size, num_workers=4, persistent_workers=True)

In [43]:
import torch.nn as nn

net = nn.Sequential(
    nn.Linear(321, 128),
    nn.BatchNorm1d(128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.BatchNorm1d(64),
    nn.ReLU(),
    nn.Linear(64, 28),
)

In [44]:
net.load_state_dict(torch.load("models/model-epoch-100.pth"))

<All keys matched successfully>

In [45]:
net = net.cuda()

In [50]:
final_acc = 0
correct = 0

for batch in tqdm(validation_dataloader):
    X = batch['board'].cuda()

    pred = net(X)
    boolean_pred = (pred>0).float()
    correct += (boolean_pred == batch['themes'].cuda()).float().sum()


100%|██████████| 607/607 [00:11<00:00, 54.16it/s]


In [51]:
final_acc = 0
train_corrects = 0

for batch in tqdm(train_dataloader):
    X = batch['board'].cuda()

    pred = net(X)
    boolean_pred = (pred>0).float()
    train_corrects += (boolean_pred == batch['themes'].cuda()).float().sum()


100%|██████████| 2428/2428 [00:46<00:00, 51.79it/s]


In [52]:
print(correct.item())
print(train_corrects.item())

3949700.0
15801546.0


In [53]:
print(len(validation_dataloader.dataset))
print(len(train_dataloader.dataset))

155383
621532


In [54]:
acc = correct.item()/(len(validation_dataloader.dataset) * 28)
print(acc)

0.907825915870554


In [55]:
acc = train_corrects.item()/(len(train_dataloader.dataset) * 28)
print(acc)

0.9079837057004765
