In [166]:
%reload_ext autoreload
%autoreload 2
import chess.pgn
import numpy as np
from chess_dataset import ChessDataset, ChessPairDataset
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F

In [167]:
chess_pair_dataset = ChessPairDataset("dataset/npy/")
train_chess_dataset = ChessDataset("dataset/npy/")
test_chess_dataset = ChessDataset("dataset/npy/", train=False)

In [168]:
chess_trainloader = DataLoader(train_chess_dataset, batch_size=64, shuffle=True)
chess_testloader = DataLoader(test_chess_dataset, batch_size=64, shuffle=False)

In [307]:
class Pos2Vec(nn.Module):
  def __init__(self):
    super(Pos2Vec, self).__init__()
    self.en1 = nn.Linear(773, 600)
    self.en2 = nn.Linear(600, 400)
    self.en3 = nn.Linear(400, 200)
    self.en4 = nn.Linear(200, 100)

    self.de1 = nn.Linear(600, 773)
    self.de2 = nn.Linear(400, 600)
    self.de3 = nn.Linear(200, 400)
    self.de4 = nn.Linear(100, 200)

    self.g1 = nn.Sequential(self.en1, self.de1, nn.Sigmoid())
    self.g2 = nn.Sequential(self.en1, nn.ReLU(), self.en2, 
                            self.de2, nn.ReLU(), self.de1, nn.Sigmoid())
    self.g3 = nn.Sequential(self.en1, nn.ReLU(), self.en2, nn.ReLU(), self.en3, 
                            self.de3, nn.ReLU(), self.de2, nn.ReLU(), self.de1, nn.Sigmoid())
    self.g4 = nn.Sequential(self.en1, nn.ReLU(), self.en2, nn.ReLU(), self.en3, nn.ReLU(), self.en4, 
                            self.de4, nn.ReLU(), self.de3, nn.ReLU(), self.de2, nn.ReLU(), self.de1, nn.Sigmoid())
    
  def forward(self, x):
    return self.curr(x)
  
  def train_all(self, train_data, epochs=10):
    # 773 - 600 -773
    self.curr = self.g1
    self.train(train_data, epochs)
    
    # * means freeze the layer
    # 773 -* 600 - 400 - 600 -* 773
    AutoEncoder.freeze(self.en1)
    AutoEncoder.freeze(self.de1)
    self.curr = self.g2
    self.train(train_data, epochs)
    
    # 773 -* 600 -* 400 - 200 - 400 -* 600 -* 773
    AutoEncoder.freeze(self.en2)
    AutoEncoder.freeze(self.de2)
    self.curr = self.g3
    self.train(train_data, epochs)
    
    # 773 -* 600 -* 400 -* 200 - 100 - 200 -* 400 -* 600 -* 773
    AutoEncoder.freeze(self.en3)
    AutoEncoder.freeze(self.de3)
    self.curr = self.g4
    self.train(train_data, epochs)
    
  def train(self, train_data, epochs=10):
    loss_f = nn.MSELoss()
    optim = torch.optim.Adam(self.parameters(), lr=0.005)
    for epoch in range(epochs):
      for states in (t := tqdm(chess_trainloader)):
        states = states.type(torch.float)
        reconstructed = self(states)
        loss = loss_f(reconstructed, states)

        optim.zero_grad()
        loss.backward()
        optim.step()

        t.set_description("Epoch: {} | Loss: {}".format(epoch, loss))
        
  def build(self):
    self.curr = self.g4
    
  @staticmethod
  def freeze(layer):
    layer.weight.requires_grad = False
    layer.bias.requires_grad = False

In [296]:
p2v = Pos2Vec()

In [297]:
p2v.train_all(train_data=train_chess_dataset)

Epoch: 0 | Loss: 0.024175286293029785: 100%|████████████████████████████████████████████████████| 12/12 [00:00<00:00, 36.06it/s]
Epoch: 1 | Loss: 0.02206273190677166: 100%|█████████████████████████████████████████████████████| 12/12 [00:00<00:00, 35.38it/s]
Epoch: 2 | Loss: 0.02060483768582344: 100%|█████████████████████████████████████████████████████| 12/12 [00:00<00:00, 34.07it/s]
Epoch: 3 | Loss: 0.01942518725991249: 100%|█████████████████████████████████████████████████████| 12/12 [00:00<00:00, 24.82it/s]
Epoch: 4 | Loss: 0.020248031243681908: 100%|████████████████████████████████████████████████████| 12/12 [00:00<00:00, 23.66it/s]
Epoch: 5 | Loss: 0.018199291080236435: 100%|████████████████████████████████████████████████████| 12/12 [00:00<00:00, 24.98it/s]
Epoch: 6 | Loss: 0.01832696422934532: 100%|█████████████████████████████████████████████████████| 12/12 [00:00<00:00, 25.50it/s]
Epoch: 7 | Loss: 0.017923761159181595: 100%|████████████████████████████████████████████████████|

In [301]:
torch.save(p2v.state_dict(), "model/save/pos2vec.pth")

In [308]:
saved_model = Pos2Vec()
ss = torch.load("model/save/pos2vec.pth")

In [309]:
saved_model.build()
saved_model.load_state_dict(ss)

<All keys matched successfully>

In [314]:
state = next(iter(chess_trainloader))
state = state.type(torch.float)
saved_model(state)

tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]], grad_fn=<SigmoidBackward0>)