In [23]:
%load_ext autoreload
%autoreload 2

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

from data_loading import load_games, load_data, board_to_tensor
from moves_encoder import MovesEncoder
from typing import List
from chess.pgn import Game
from chess import Board, Move
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Data processing

In [13]:
# Read the PGN dump and set up the sample stream.
# NOTE(review): `limit_moves` presumably caps the number of generated samples
# (the saved splits add up to 8,000,000) - confirm in data_loading.load_data.
games = load_games('lichess_elite_2023-05.pgn')
encoder = MovesEncoder()
data_generator = load_data(games, encoder, limit_moves=8_000_000)


def _stack_samples(pairs):
    """Consume an iterable of (input, target) pairs and return both as NumPy arrays."""
    inputs, targets = [], []
    for sample_input, sample_target in pairs:
        inputs.append(sample_input)
        targets.append(sample_target)
    stacked_inputs = np.array(inputs)
    # Release the Python list before stacking the targets to keep peak memory down.
    del inputs
    return stacked_inputs, np.array(targets)


# Drains `data_generator`; it cannot be iterated a second time afterwards.
X, y = _stack_samples(data_generator)

In [14]:
# Split the samples into train / validation / test subsets (70% / 10% / 20%).
from sklearn.model_selection import train_test_split


def _seeded_split(features, targets, holdout_fraction):
    """Split features/targets with a fixed seed, holding out `holdout_fraction` of the rows.

    Returns (features_kept, features_held_out, targets_kept, targets_held_out).
    """
    return train_test_split(
        features,
        targets,
        test_size=holdout_fraction,
        random_state=42,
    )


# Stage 1: set aside 20% of all samples as the test set.
X_temp, X_test, y_temp, y_test = _seeded_split(X, y, 0.2)

# Stage 2: take 12.5% of the remaining 80% for validation (0.125 * 0.8 = 0.1 of the total).
X_train, X_val, y_train, y_val = _seeded_split(X_temp, y_temp, 0.125)

In [17]:
# Turn each split into torch tensors and wrap it in a TensorDataset.
# NOTE: this cell rebinds X_*/y_* from NumPy arrays to tensors, so it is meant
# to be run once right after the split cell.


def _to_feature_tensor(array):
    """Build a boolean tensor from `array`."""
    return torch.tensor(array, dtype=torch.bool)


def _to_target_tensor(array):
    """Build a float32 tensor from `array`."""
    return torch.tensor(array, dtype=torch.float32)


# Training split
X_train = _to_feature_tensor(X_train)
y_train = _to_target_tensor(y_train)
train_dataset = TensorDataset(X_train, y_train)

# Test split
X_test = _to_feature_tensor(X_test)
y_test = _to_target_tensor(y_test)
test_dataset = TensorDataset(X_test, y_test)

# Validation split
X_val = _to_feature_tensor(X_val)
y_val = _to_target_tensor(y_val)
val_dataset = TensorDataset(X_val, y_val)


In [20]:
# Persist the three TensorDataset objects so the preprocessing above
# does not have to be repeated in later sessions.
for _save_path, _dataset_to_save in (
    ('train_dataset.pt', train_dataset),
    ('test_dataset.pt', test_dataset),
    ('val_dataset.pt', val_dataset),
):
    torch.save(_dataset_to_save, _save_path)

## Loading the processed datasets

In [26]:
# Reload the datasets written by the preprocessing section.
#
# The files contain pickled TensorDataset objects rather than bare tensors or
# state dicts. Since PyTorch 2.6, torch.load defaults to weights_only=True,
# which refuses to unpickle such objects and raises an UnpicklingError, so full
# unpickling is requested explicitly.
#
# SECURITY: weights_only=False runs the regular pickle machinery, which can
# execute arbitrary code. Only load files produced by this notebook / a trusted source.
train_dataset: TensorDataset = torch.load('train_dataset.pt', weights_only=False)
test_dataset: TensorDataset = torch.load('test_dataset.pt', weights_only=False)
val_dataset: TensorDataset = torch.load('val_dataset.pt', weights_only=False)

In [27]:
# Number of samples in the (train, test, validation) datasets.
tuple(len(split) for split in (train_dataset, test_dataset, val_dataset))

(5600000, 1600000, 800000)

In [28]:
# Mini-batch loaders (64 samples per batch). Only the training data is
# reshuffled; test and validation are iterated in their stored order.
train_loader, test_loader, validation_loader = (
    DataLoader(split, batch_size=64, shuffle=reshuffle)
    for split, reshuffle in (
        (train_dataset, True),
        (test_dataset, False),
        (val_dataset, False),
    )
)