In [1]:
import glob
import chess_SL_E6_lib as lib
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import chess
import numpy as np

In [2]:
# Identity of the model produced by this notebook; used to build output file names.
MODEL_NUM = 6
MODEL_VERSION = 3

path = "../Data/DataTrain"

# Files are split by the letter that follows "Chess_Jan_" in their name.
TRAIN_PREFIXES = "ghijk"
VAL_PREFIXES = "ef"


def _collect_csv_files(directory, prefixes):
    """Return files in `directory` matching ``Chess_Jan_<p>*`` for each prefix p.

    Files are grouped in the order of `prefixes` and sorted within each group:
    glob.glob's own order depends on the filesystem, so sorting keeps the order
    in which the datasets stream files reproducible across machines.
    """
    files = []
    for prefix in prefixes:
        files.extend(sorted(glob.glob(f'{directory}/Chess_Jan_{prefix}*')))
    return files


csv_files1 = _collect_csv_files(path, TRAIN_PREFIXES)  # training files
csv_files2 = _collect_csv_files(path, VAL_PREFIXES)    # validation files

# Fail fast: an empty file list would otherwise yield an empty dataset and a no-op training run.
if not csv_files1 or not csv_files2:
    raise FileNotFoundError(
        f"No Chess_Jan_* CSV files found under {path!r} "
        f"(train={len(csv_files1)}, val={len(csv_files2)})"
    )

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Streaming datasets over the training (csv_files1) and validation (csv_files2) CSV files.
# CHUNK_SIZE is forwarded as `chunksize` to lib.ChessIterableDataset — presumably the number
# of CSV rows read per chunk; confirm against the library.
CHUNK_SIZE = 50000
# Number of samples the DataLoader collates into each mini-batch.
BATCH_SIZE = 25000

dataset1, dataset2 = (
    lib.ChessIterableDataset(file_list, chunksize=CHUNK_SIZE)
    for file_list in (csv_files1, csv_files2)
)

train_data_loader, val_data_loader = (
    DataLoader(ds, batch_size=BATCH_SIZE) for ds in (dataset1, dataset2)
)

In [4]:
# Resume from the previous version of this model (E{MODEL_NUM}-{MODEL_VERSION - 1}) and fine-tune it.
# The checkpoint holds a whole pickled nn.Module, so weights_only=False is required:
# torch.load defaults to weights_only=True from PyTorch 2.6 onward and would refuse it.
# Unpickling can execute arbitrary code — only load checkpoints you produced/trust.
# NOTE(review): the `weights_only` keyword needs PyTorch >= 1.13.
prev_model_path = f'models_EL/model_E{MODEL_NUM}-{MODEL_VERSION - 1}.pth'
model = torch.load(prev_model_path, map_location=device, weights_only=False)
model = model.to(device)

LEARNING_RATE = 0.006
NUM_EPOCHS = 50

# L1 (mean absolute error) loss with Adam.
# Previously commented-out alternatives: nn.MSELoss(), and
# optim.SGD(model.parameters(), lr=0.035, momentum=0.9).
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Train the model; lib.train returns the training and validation loss histories
# (presumably one entry per epoch — they are plotted against "Epochs" below).
training_loss_history, validation_loss_history = lib.train(
    model, train_data_loader, val_data_loader, criterion, optimizer, num_epochs=NUM_EPOCHS
)

Begin Training!


In [None]:
# Save the trained model. torch.save(model, ...) pickles the whole nn.Module, so loading it
# back requires the same class definitions to be importable.
# NOTE(review): this writes to the working directory, whereas the input checkpoint above is read
# from models_EL/ — move the file there (or change this path) before training the next version.
torch.save(model, f'model_E{MODEL_NUM}-{MODEL_VERSION}.pth')

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(training_loss_history, label='Training Loss')
ax.plot(validation_loss_history, label='Validation Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
# ax.set_ylim(0, None)  # optional: anchor the loss axis at zero
ax.set_title(f'Figure 1: Loss for E{MODEL_NUM}-{MODEL_VERSION} Model')
ax.legend()

# Save BEFORE showing: plt.show() may close the figure (e.g. with the inline backend),
# after which plt.savefig would write a new, blank figure.
fig.savefig(f'Loss_E{MODEL_NUM}-{MODEL_VERSION}.png')
plt.show()

In [None]:
import os
import pickle

# Persist the loss histories so the curves can be re-plotted without retraining.
# NOTE: pickle.load executes code from the file — only reload histories you wrote yourself.
os.makedirs('pickle', exist_ok=True)  # open(..., 'wb') raises if the directory is missing

for history_name, history in (
    ('training_loss_history', training_loss_history),
    ('validation_loss_history', validation_loss_history),
):
    # `with` guarantees the file is flushed and closed even if pickling fails.
    with open(f'pickle/{history_name}_E{MODEL_NUM}-{MODEL_VERSION}.pkl', 'wb') as fh:
        pickle.dump(history, fh)