In [2]:
from glob import glob
import json
import os
import pickle
import numpy as np
import torch
from tqdm import tqdm
from model import NeuralNet
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from pprint import pprint
import config

In [3]:
class MatchDataSet(Dataset):
    """Map-style dataset over a dict of recorded samples.

    Every value of ``data`` is read positionally: element 0 is converted to
    the board tensor, element 1 to the answer tensor (both ``torch.int8``).
    If a record carries a third element it is forwarded unchanged as the
    value label, which is the ``(board, answer, value)`` triple that
    ``collate_fn`` unpacks.
    """

    def __init__(self, data: dict):
        self.data = data
        # Freeze the key order once so an integer index always maps to the same record.
        self.keys = list(data.keys())

    def __len__(self):
        # Count the frozen keys: they are what __getitem__ actually indexes.
        return len(self.keys)

    def __getitem__(self, index):
        """Return ``(board, answer)`` or, when stored, ``(board, answer, value)``."""
        record = self.data[self.keys[index]]
        board = torch.tensor(record[0], dtype=torch.int8)
        answer = torch.tensor(record[1], dtype=torch.int8)
        if len(record) > 2:
            # Forward the stored value label untouched; records without one
            # keep the original two-element return shape.
            return (board, answer, record[2])
        return (board, answer)

In [4]:
def collate_fn(batch):
    """Expand every batch item into three training samples and stack them.

    Each item of ``batch`` is unpacked as ``(board, answer, value)``.
    ``answer`` is read at indices 0..5 as three coordinate pairs; the cell
    ``board[answer[1]][answer[0]]`` identifies the moving player, which must
    be 1 or 2.

    For each item three inputs and three (policy, value) targets are built
    with ``NeuralNet.process_input`` / ``NeuralNet.process_output`` for
    stages 0, 1 and 2. Stage 2 works on a copy of the board in which the
    cell at (answer[1], answer[0]) is cleared and the cell at
    (answer[3], answer[2]) is set to the player.

    Returns:
        ``(inputs, (policies, values))`` where each part is a ``torch.stack``
        of the per-sample tensors, three consecutive entries per batch item.

    Raises:
        Exception: if the cell selected by ``answer`` holds neither 1 nor 2.
    """
    inputs, policies, values = [], [], []

    for board, answer, value in batch:
        src_col, src_row = answer[0], answer[1]
        player = board[src_row][src_col]
        if not (player == 1 or player == 2):
            raise Exception("Invalid player: {}".format(player))

        # Stages 0 and 1 both look at the untouched board.
        stage0_in = NeuralNet.process_input(board, player, 0)
        stage1_in = NeuralNet.process_input(board, player, 1, src_col, src_row)

        # Stage 2 sees the board after the piece has been relocated.
        dst_col, dst_row = answer[2], answer[3]
        moved_board = board.clone()
        moved_board[src_row, src_col] = 0
        moved_board[dst_row, dst_col] = player
        stage2_in = NeuralNet.process_input(moved_board, player, 2, dst_col, dst_row)

        inputs.extend((stage0_in, stage1_in, stage2_in))

        stage_targets = (
            NeuralNet.process_output(value, src_col, src_row, board, player, 0),
            NeuralNet.process_output(
                value, dst_col, dst_row, board, player, 1, src_col, src_row
            ),
            NeuralNet.process_output(
                value, answer[4], answer[5], moved_board, player, 2, dst_col, dst_row
            ),
        )
        for policy_target, value_target in stage_targets:
            policies.append(policy_target)
            values.append(value_target)

    # Turn the per-sample lists into batch tensors.
    return torch.stack(inputs), (torch.stack(policies), torch.stack(values))

In [5]:
def load_dataset(path):
    """Load every pickle matching ``path`` and merge them into one MatchDataSet.

    Args:
        path: glob pattern (matched non-recursively) selecting the pickle files.
            Each file is expected to unpickle to a dict.

    Returns:
        A ``MatchDataSet`` over the merged dict. When the same key appears in
        several files, the file visited later wins.
    """
    input_paths = glob(path, recursive=False)

    data = {}
    for file_path in tqdm(input_paths, desc="Loading data", unit="file"):
        # NOTE: unpickling can execute arbitrary code, so only point this at
        # files produced by a trusted pipeline.
        with open(file_path, "rb") as handle:
            # Merge in place instead of rebuilding the whole dict for every file.
            data.update(pickle.load(handle))

    return MatchDataSet(data)

In [6]:
# Checkpoint file: loaded later if it already exists, and passed to net.save during training.
model_name = "data/resnet_01_40x128.pt"
# Glob pattern handed to load_dataset to select the pickled training samples.
data_path = "data/augment/train/*.pickle"

# Project-level setup. Presumably this populates the settings read later
# (config.batch_size, config.lr, config.epochs, config.use_gpu, config.run) — TODO confirm.
config.init()



https://app.neptune.ai/conql/Amazons/e/AM-76


In [8]:
# Build the training set from every pickle matched by data_path, then wrap it
# in a shuffling loader whose batches are assembled by collate_fn.
train_dataset = load_dataset(data_path)
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=8,
)

Loading data: 100%|██████████| 181/181 [00:19<00:00,  9.18file/s]


In [9]:
# load model
# Construct the network and, when a file already exists at model_name, call
# net.load on it (presumably restoring saved weights so training resumes — TODO confirm).
net = NeuralNet()
if os.path.exists(model_name):
    net.load(path=model_name)

In [None]:
# train
# Joint training of the policy and value outputs. collate_fn emits three
# consecutive samples per batch item, which the "actual accuracy" metric
# below relies on.
optimizer = torch.optim.Adam(net.model.parameters(), lr=config.lr)
loss_pi = nn.CrossEntropyLoss()
loss_v = nn.MSELoss()
net.model.train()

for epoch in range(config.epochs):
    # Running sums since the last log flush (reset every 10 batches below).
    train_correct = 0
    train_actual_correct = 0
    train_ploss_sum = 0
    train_vloss_sum = 0
    train_count = 0

    for i, (X, Y) in tqdm(
        enumerate(train_loader), total=len(train_loader), desc="Training model"
    ):
        input_data = X
        target_pi, target_vs = Y
        sample_size = input_data.shape[0]

        if config.use_gpu:
            input_data, target_pi, target_vs = (
                input_data.contiguous().cuda(),
                target_pi.contiguous().cuda(),
                target_vs.contiguous().cuda(),
            )

        # predict
        out_pi, out_v = net.model(input_data)

        # Some samples have no value (target stored as 0): substitute the
        # model's own prediction so they contribute zero value loss. The
        # prediction is detached so the target stays a constant and no
        # autograd graph gets attached to it.
        missing_value = target_vs == 0
        target_vs[missing_value] = out_v.detach()[missing_value]

        # NOTE(review): CrossEntropyLoss treats dim 1 as the class dimension.
        # The .view(sample_size, -1) used for the accuracy below suggests
        # out_pi may have more than two dims — confirm the loss normalises
        # over the intended axis.
        p_loss = loss_pi(out_pi, target_pi)
        v_loss = loss_v(out_v, target_vs)

        total_loss = p_loss + v_loss

        # update loss (weighted by batch size so the logged value is a per-sample mean)
        train_ploss_sum += p_loss.item() * sample_size
        train_vloss_sum += v_loss.item() * sample_size

        # A sample is correct when the arg-max of the flattened prediction
        # matches the arg-max of the flattened target.
        corrects = out_pi.view(sample_size, -1).argmax(1) == target_pi.view(
            sample_size, -1
        ).argmax(1)

        # update correct
        train_correct += corrects.sum().item()

        # calculate actual correct: a batch item only counts when all 3 of
        # its consecutive samples are correct. Done as one tensor op instead
        # of indexing the tensor element by element from Python.
        train_actual_correct += corrects.reshape(-1, 3).all(dim=1).sum().item()

        train_count += sample_size

        # backprop
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # update log: flush the running means every 10 batches, then restart the window
        if i % 10 == 0:
            config.run["policy loss"].append(train_ploss_sum / train_count)
            config.run["value loss"].append(train_vloss_sum / train_count)
            config.run["accuracy"].append(train_correct / train_count)
            config.run["actual accuracy"].append(
                train_actual_correct / (train_count / 3)
            )

            train_ploss_sum = 0
            train_vloss_sum = 0
            train_correct = 0
            train_actual_correct = 0
            train_count = 0

        # periodic checkpoint
        if i % 50 == 0:
            net.save(model_name)

    # Checkpoint at the end of every epoch as well; otherwise the batches
    # trained after the last periodic save (up to 49) would never be written.
    net.save(model_name)