In [None]:
import sys

# Put the sibling ../src directory first on the import path so that the
# `mygo` package can be imported by the cells that follow.
sys.path[:0] = ["../src"]

In [None]:
from mygo.datasets import *
from mygo.model import *
from mygo.encoder import *
from torch.utils.data import DataLoader
import torch
from torch import nn
import numpy as np
from pathlib import Path

In [None]:
# Initialization: pin PyTorch's global RNG seed for reproducibility.
SEED = 25565
torch.manual_seed(SEED)

In [None]:
# Hyperparameters

# Run on the GPU whenever CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Dataset sizing (passed to the dataset constructors) and board geometry.
game_count, test_game_count = 1000, 100
board_size = 19

# Input encoding and the network that consumes it.
encoder = SevenPlaneEncoder(board_size)
model_name = "small"
model = SmallModel(board_size, encoder.plane_count).to(device)

# Optimisation settings.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(model.parameters())
batch_size = 128
epochs = 100

# Checkpoint path pattern. Everything is substituted now except the epoch
# number, which is left as a literal "{i}" for a later `.format(i=...)`.
weight_template = (
    f"../models/{model_name}-{board_size}x-{encoder.plane_count}p"
    f"-{game_count}g-{{i}}e_weights.pt"
)

# Title reused by the result figure and the textual metrics dump.
fig_title = (
    f"{model_name.capitalize()}Model {board_size}x{board_size} "
    f"{encoder.plane_count} plane(s) {game_count} games"
)

In [None]:
# Count every element of every parameter tensor registered on the model.
total_params = sum(map(torch.numel, model.parameters()))

# Report the compute device, the architecture and the parameter total.
print(repr(device))
print(model)
print(f"parameters: {total_params:,}")

In [None]:
def transform(data):
    """Convert a sample or label into a tensor that lives on ``device``.

    Used as both ``transform`` and ``target_transform`` for the datasets, so
    it has to cope with array inputs as well as scalar labels.

    Args:
        data: a ``torch.Tensor``, a NumPy array, a NumPy scalar, or a plain
            Python ``int`` / ``float`` / ``tuple`` / ``list``.

    Returns:
        A ``torch.Tensor`` on the notebook-level ``device``.

    Raises:
        TypeError: if ``data`` is of an unsupported type. Previously such
            values fell through and ``None`` was returned silently, which
            only surfaced later as an obscure failure during batching.
    """
    if isinstance(data, torch.Tensor):
        # Already a tensor: only make sure it sits on the right device.
        return data.to(device)
    if isinstance(data, np.ndarray):
        return torch.from_numpy(data).to(device)
    # np.generic covers NumPy scalars such as np.int64, which are not
    # instances of Python's int and were previously dropped.
    if isinstance(data, (int, float, tuple, list, np.generic)):
        return torch.tensor(data, device=device)
    raise TypeError(f"transform() cannot convert {type(data).__name__} to a tensor")

In [None]:
# Arguments common to both splits: same data root, same encoder, and the
# same tensor conversion for inputs and targets.
shared_dataset_args = dict(
    root="../data/kgs_sgfs",
    encoder=encoder,
    transform=transform,
    target_transform=transform,
)

train_data = KGSIterableDataset(
    train=True,
    game_count=game_count,
    **shared_dataset_args,
)
test_data = KGSIterableDataset(
    train=False,
    game_count=test_game_count,
    **shared_dataset_args,
)

# The evaluation loader uses batches twice as large as the training loader.
train_loader = DataLoader(dataset=train_data, batch_size=batch_size)
test_loader = DataLoader(dataset=test_data, batch_size=batch_size * 2)

In [None]:
def train_loop(dataloader, model, loss_fn, optimizer):
    """Run one optimisation pass over ``dataloader`` and report its metrics.

    Args:
        dataloader: iterable yielding ``(inputs, labels)`` batches.
        model: network to optimise; switched to train mode here.
        loss_fn: callable ``loss_fn(predictions, labels)`` returning a
            scalar tensor.
        optimizer: optimizer whose gradients are cleared and which takes
            one step per batch.

    Returns:
        Tuple ``(mean loss per batch, accuracy in percent)`` over every
        batch seen in this pass.
    """
    # Rough total used only in the progress line. It reads the notebook
    # globals `game_count` and `batch_size`; 187 is presumably the average
    # number of positions per game -- TODO confirm.
    expected_batches = game_count * 187 // batch_size

    seen_samples = 0
    seen_batches = 0
    loss_sum = 0.0
    hits = 0.0
    model.train()

    for step, (inputs, labels) in enumerate(dataloader, start=1):
        optimizer.zero_grad()

        logits = model(inputs)
        batch_loss = loss_fn(logits, labels)

        # Book-keeping happens on the pre-update predictions.
        seen_samples += len(inputs)
        seen_batches = step
        loss_sum += batch_loss.item()
        hits += (logits.argmax(1) == labels).float().sum().item()

        batch_loss.backward()
        optimizer.step()

        # Progress line on the first batch and every 500 batches after it.
        if step % 500 == 1:
            print(f"loss: {batch_loss.item():>7f} [{step:>4d}/~{expected_batches:>4d}]")

    mean_loss = loss_sum / seen_batches
    accuracy = hits / seen_samples * 100
    print(f"\nTrain: accuracy: {accuracy:>.3f}%, loss: {mean_loss:>7f}")

    return mean_loss, accuracy


def test_loop(dataloader, model, loss_fn):
    size = 0
    batches = 0
    test_loss, correct = 0.0, 0.0
    model.eval()

    with torch.no_grad():
        for xs, ys in dataloader:
            size += len(xs)
            batches += 1

            pred = model(xs)
            test_loss += loss_fn(pred, ys).item()
            correct += (pred.argmax(1) == ys).type(torch.float).sum().item()

    test_loss /= batches
    accuracy = correct / size * 100
    print(f"Test: accuracy: {accuracy:>.3f}%, loss: {test_loss:>7f}")

    return test_loss, accuracy

In [None]:
# Per-epoch history, recorded for plotting.
xs = []  # epoch numbers
train_losses = []
test_losses = []
train_accs = []
test_accs = []

In [None]:
%%time

# Resume from the newest checkpoint on disk, if there is one. The scan goes
# all the way down to epoch 1: stopping at 2 would ignore a checkpoint left
# by a run interrupted after its first epoch and restart from scratch.
# NOTE: only the model weights are checkpointed, so the optimizer starts
# with fresh internal state after a resume.
i_start = 1
for i in range(epochs, 0, -1):
    weight_file = Path(weight_template.format(i=i))
    if weight_file.is_file():
        print(f"Loading {weight_file}")
        # map_location lets a checkpoint written on a GPU machine load on a
        # CPU-only one (and vice versa).
        model.load_state_dict(torch.load(weight_file, map_location=device))
        i_start = i + 1
        break

for i in range(i_start, epochs + 1):
    print(f"Epoch {i}\n{'-' * 25}")

    train_loss, train_accuracy = train_loop(train_loader, model, loss_fn, optimizer)
    test_loss, test_accuracy = test_loop(test_loader, model, loss_fn)

    # Make sure the output directory exists before the first save.
    weight_file = Path(weight_template.format(i=i))
    weight_file.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), weight_file)

    # only save the last model file
    prev_weight_file = Path(weight_template.format(i=i - 1))
    if prev_weight_file.is_file():
        print(f"Removing previous model file: {prev_weight_file}")
        prev_weight_file.unlink()

    # Keep the per-epoch metrics for the summary and plots.
    xs.append(i)
    train_losses.append(train_loss)
    train_accs.append(train_accuracy)
    test_losses.append(test_loss)
    test_accs.append(test_accuracy)
    print()

In [None]:
# Dump the figure title and the recorded per-epoch metrics as plain text,
# one `name=value` entry per line.
print(
    "\n".join(
        (
            f"{fig_title=}",
            f"{xs=}",
            f"{train_losses=}",
            f"{train_accs=}",
            f"{test_losses=}",
            f"{test_accs=}",
        )
    )
)

In [None]:
import matplotlib.pyplot as plt

# Apply matplotlib's "dark_background" style sheet to the figures drawn below.
plt.style.use("dark_background")

In [None]:
# One figure, two side-by-side panels: loss on the left, accuracy on the right.
fig, axs = plt.subplots(1, 2, figsize=(8, 4))
fig.suptitle(fig_title)

# (axis, y-axis label, training series, test series) for each panel.
panels = (
    (axs[0], "Loss", train_losses, test_losses),
    (axs[1], "Accuracy (%)", train_accs, test_accs),
)
for ax, ylabel, train_series, test_series in panels:
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.plot(xs, train_series, label="train")
    ax.plot(xs, test_series, label="test")
    ax.legend()

plt.show()