In [1]:
import os
import chess
import torch
import pytorch_lightning as pl
import numpy as np
import pandas as pd
from psqt import fill_features, PSQT, FEATURES
import utils
# Directory (relative to this notebook) holding the batch CSVs, cached feature
# matrices and model checkpoints.
DATA_FOLDER = os.path.join(os.pardir, "data", "batches")

In [None]:
# Slow path: extract PSQT features for every row of batch 29 and cache the
# resulting matrix as feature_matrix_batch_29.npy (np.save appends ".npy").
batch = pd.read_csv(os.path.join(DATA_FOLDER, "batch_29.csv"))
n_rows = len(batch)
feature_matrix = np.zeros((n_rows, FEATURES), dtype=np.int8)
# The last CSV column holds the training target.
labels = np.array(batch.iloc[:, -1], dtype=np.float32)
# fill_features presumably writes row `row` of feature_matrix in place from the
# position stored in column 0 (its return value is unused) — TODO confirm.
for row, position in enumerate(batch.iloc[:, 0]):
    fill_features(feature_matrix, row, position)
np.save(os.path.join(DATA_FOLDER, "feature_matrix_batch_29"), feature_matrix)

In [2]:
# Fast path: reload batch 29's labels from the CSV and its feature matrix from
# the .npy cache written by the feature-extraction cell, skipping re-extraction.
batch = pd.read_csv(os.path.join(DATA_FOLDER, "batch_29.csv"))
# The last CSV column holds the regression target (trained with MSELoss).
labels = np.array(batch.iloc[:, -1], dtype=np.float32)
feature_matrix = np.load(os.path.join(DATA_FOLDER, "feature_matrix_batch_29.npy"))
# Guard against a stale cache: features and labels are paired by row index
# in the split cell, so a row-count mismatch would either raise there or
# silently pair features with the wrong labels.
assert feature_matrix.shape[0] == len(labels), (
    f"cached feature matrix has {feature_matrix.shape[0]} rows but the batch CSV "
    f"has {len(labels)}; re-run the feature-extraction cell to rebuild the cache"
)

In [3]:
# Shuffle the data with a fixed seed and make a train/validation split.
# - Seeding makes the split reproducible, and therefore also val_loss and
#   checkpoint comparisons across runs.
# - torch.manual_seed also makes the DataLoader shuffle and the model's weight
#   initialisation deterministic when the cells run in order.
# - Rows are selected through permuted indices instead of overwriting
#   feature_matrix/labels, so re-running this cell yields the same split.
from torch.utils.data import TensorDataset, DataLoader

SPLIT_SEED = 0
TRAIN_FRACTION = 0.9
torch.manual_seed(SPLIT_SEED)

perm = np.random.default_rng(SPLIT_SEED).permutation(len(feature_matrix))
N = int(TRAIN_FRACTION * len(feature_matrix))
train_idx, val_idx = perm[:N], perm[N:]
x_train, x_val = feature_matrix[train_idx], feature_matrix[val_idx]
# Targets become column vectors of shape (n, 1).
y_train, y_val = labels[train_idx][:, None], labels[val_idx][:, None]
x_train, y_train, x_val, y_val = map(torch.tensor, (x_train, y_train, x_val, y_val))
train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=256, shuffle=True)
val_ds = TensorDataset(x_val, y_val)
val_dl = DataLoader(val_ds, batch_size=64)

In [4]:
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

# Train a fresh PSQT model, keeping the checkpoint with the lowest val_loss
# in DATA_FOLDER and stopping early once val_loss stops improving.
model = PSQT()
checkpoint = ModelCheckpoint(
    monitor='val_loss',
    dirpath=DATA_FOLDER,
    filename="{epoch}-{val_loss:.6f}-batch29",
)
stopping = EarlyStopping(monitor='val_loss')
# NOTE(review): `progress_bar_refresh_rate` is a Trainer argument of older
# PyTorch Lightning releases (deprecated in 1.5, removed in 1.7); confirm the
# pinned version before upgrading.
trainer = pl.Trainer(
    callbacks=[checkpoint, stopping, utils.MetricsCallback()],
    max_epochs=50,
    progress_bar_refresh_rate=100,
)

trainer.fit(model, train_dl, val_dl)

GPU available: False, used: False
TPU available: None, using: 0 TPU cores

  | Name | Type       | Params
------------------------------------
0 | psqt | Sequential | 770   
1 | loss | MSELoss    | 0     
------------------------------------
770       Trainable params
0         Non-trainable params
770       Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 0 - Loss 0.09672620892524719 - Val Loss: 0.09511090815067291


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 1 - Loss 0.09481775760650635 - Val Loss: 0.09465120732784271


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 2 - Loss 0.09448208659887314 - Val Loss: 0.094448983669281


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 3 - Loss 0.09426166862249374 - Val Loss: 0.09424746036529541


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 4 - Loss 0.0940709188580513 - Val Loss: 0.09402737021446228


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 5 - Loss 0.09390495717525482 - Val Loss: 0.09384410083293915


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 6 - Loss 0.09375409036874771 - Val Loss: 0.09373324364423752


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 7 - Loss 0.09361317753791809 - Val Loss: 0.09355642646551132


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 8 - Loss 0.0934872254729271 - Val Loss: 0.09352049976587296


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 9 - Loss 0.09335985779762268 - Val Loss: 0.09333701431751251


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 10 - Loss 0.0932445153594017 - Val Loss: 0.09317933768033981


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 11 - Loss 0.09313325583934784 - Val Loss: 0.09314656257629395


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 12 - Loss 0.09303347021341324 - Val Loss: 0.09300246834754944


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 13 - Loss 0.09293533861637115 - Val Loss: 0.0929693728685379


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 14 - Loss 0.09283485263586044 - Val Loss: 0.09282835572957993


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 15 - Loss 0.09275024384260178 - Val Loss: 0.0927499532699585


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 16 - Loss 0.09265825152397156 - Val Loss: 0.092678964138031


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 17 - Loss 0.09257146716117859 - Val Loss: 0.0925554558634758


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 18 - Loss 0.09248575568199158 - Val Loss: 0.09259121865034103


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 19 - Loss 0.0924048125743866 - Val Loss: 0.09243284910917282


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 20 - Loss 0.09233590215444565 - Val Loss: 0.0923844650387764


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 21 - Loss 0.09225483238697052 - Val Loss: 0.09223061800003052


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 22 - Loss 0.09217993915081024 - Val Loss: 0.09220313280820847


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 23 - Loss 0.09211848676204681 - Val Loss: 0.09209341555833817


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 24 - Loss 0.09204521030187607 - Val Loss: 0.09208234399557114


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 25 - Loss 0.09197206795215607 - Val Loss: 0.09197035431861877


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 26 - Loss 0.09191593527793884 - Val Loss: 0.09184875339269638


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 27 - Loss 0.09185239672660828 - Val Loss: 0.09191165864467621


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 28 - Loss 0.09178755432367325 - Val Loss: 0.09177280962467194


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 29 - Loss 0.09173312783241272 - Val Loss: 0.09170465916395187


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 30 - Loss 0.09167326241731644 - Val Loss: 0.09170613437891006


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 31 - Loss 0.09162241220474243 - Val Loss: 0.0915960967540741


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 32 - Loss 0.09156763553619385 - Val Loss: 0.09154821932315826


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 33 - Loss 0.09151550382375717 - Val Loss: 0.09151628613471985


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 34 - Loss 0.09146501868963242 - Val Loss: 0.09145080298185349


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 35 - Loss 0.09141635149717331 - Val Loss: 0.0914396122097969


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 36 - Loss 0.0913657397031784 - Val Loss: 0.09146349877119064


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 37 - Loss 0.0913139209151268 - Val Loss: 0.09129992127418518


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 38 - Loss 0.09127651154994965 - Val Loss: 0.09124661237001419


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 39 - Loss 0.09122747927904129 - Val Loss: 0.09119740128517151


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 40 - Loss 0.09118269383907318 - Val Loss: 0.091151662170887


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 41 - Loss 0.09113966673612595 - Val Loss: 0.09114959836006165


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 42 - Loss 0.09110599011182785 - Val Loss: 0.09130457043647766


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 43 - Loss 0.09105652570724487 - Val Loss: 0.09110767394304276


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 44 - Loss 0.09102559834718704 - Val Loss: 0.09103292971849442


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 45 - Loss 0.09098879992961884 - Val Loss: 0.09097617119550705


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 46 - Loss 0.0909503698348999 - Val Loss: 0.09098181873559952


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 47 - Loss 0.09091253578662872 - Val Loss: 0.09094251692295074


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 48 - Loss 0.09087944775819778 - Val Loss: 0.09088476747274399


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 49 - Loss 0.09084104746580124 - Val Loss: 0.09093712270259857



1

In [None]:
## Restoring from a checkpoint.
# `checkpoint` is the ModelCheckpoint callback from the training cell; its
# best_model_path is the saved file with the lowest val_loss. Replace it with an
# explicit .ckpt path (e.g. inside DATA_FOLDER) to use a file from an earlier session.
CHECKPOINT_PATH = checkpoint.best_model_path
assert CHECKPOINT_PATH, "no checkpoint has been saved yet; run the training cell first"

## To load a model and its weights from a checkpoint:
restored_model = PSQT.load_from_checkpoint(CHECKPOINT_PATH)

## To resume the training:
# Only `pl` is imported in this notebook, so the Trainer must be referenced as
# pl.Trainer. The training cell passes its dataloaders to fit() explicitly
# (PSQT presumably defines none of its own), so they are passed again here.
# Callbacks and max_epochs from the training cell are not reused; add them
# if the resumed run should checkpoint or stop early.
model = PSQT()
trainer = pl.Trainer(resume_from_checkpoint=CHECKPOINT_PATH)
trainer.fit(model, train_dl, val_dl)