In [1]:
from mygo.datasets import MCTSDataset
from mygo.model import TinyModel
from torch.utils.data import DataLoader
import torch
from torch import nn

In [2]:
torch.manual_seed(25565)  # for reproducibility

<torch._C.Generator at 0x7f0cf68eebd0>

In [3]:
train_data = MCTSDataset(
    root="../data",
    train=True,
    transform=torch.from_numpy,
)

test_data = MCTSDataset(
    root="../data",
    train=False,
    transform=torch.from_numpy,
)

In [4]:
def train_loop(dataloader, model, loss_fn, optimizer):
    batches = len(dataloader)
    model.train()
    for i, (xs, ys) in enumerate(dataloader):
        optimizer.zero_grad()

        pred = model(xs)
        loss = loss_fn(pred, ys)

        loss.backward()
        optimizer.step()

        if i % 10 == 9:
            print(f"loss: {loss.item():>7f} [{i + 1:>2d}/{batches:>2d}]")


def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    batches = len(dataloader)
    model.eval()
    test_loss, correct = 0.0, 0.0

    with torch.no_grad():
        for xs, ys in dataloader:
            pred = model(xs)
            test_loss += loss_fn(pred, ys).item()
            correct += (pred.argmax(1) == ys.argmax(1)).type(torch.float).sum().item()

    test_loss /= batches
    accuracy = correct / size * 100

    print(f"\nTest Accuracy: {accuracy:>.1f}%\nAvg Loss: {test_loss:>7f}\n")

In [5]:
model = TinyModel(9)
print(model)

total_params = sum(p.numel() for p in model.parameters())
print(f"parameters: {total_params:,}")

TinyModel(
  (conv_stack): Sequential(
    (0): Conv2d(1, 48, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (1): ReLU()
    (2): Dropout(p=0.5, inplace=False)
    (3): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_stack): Sequential(
    (0): Linear(in_features=768, out_features=512, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=512, out_features=81, bias=True)
    (4): Softmax(dim=1)
  )
)
parameters: 456,545


In [6]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
batch_size = 128
epoches = 100

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=2 * batch_size)

for i in range(epoches):
    print(f"Epoch {i + 1}\n{'-' * 25}")
    train_loop(train_loader, model, loss_fn, optimizer)
    test_loop(test_loader, model, loss_fn)

Epoch 1
-------------------------
loss: 4.394397 [10/292]
loss: 4.394379 [20/292]
loss: 4.394395 [30/292]
loss: 4.394384 [40/292]
loss: 4.394180 [50/292]
loss: 4.394292 [60/292]
loss: 4.394122 [70/292]
loss: 4.393923 [80/292]
loss: 4.394064 [90/292]
loss: 4.393973 [100/292]
loss: 4.394046 [110/292]
loss: 4.393663 [120/292]
loss: 4.393711 [130/292]
loss: 4.393630 [140/292]
loss: 4.393457 [150/292]
loss: 4.393137 [160/292]
loss: 4.393414 [170/292]
loss: 4.392933 [180/292]
loss: 4.392430 [190/292]
loss: 4.392473 [200/292]
loss: 4.392848 [210/292]
loss: 4.391551 [220/292]
loss: 4.391968 [230/292]
loss: 4.389781 [240/292]
loss: 4.389716 [250/292]
loss: 4.388471 [260/292]
loss: 4.383140 [270/292]
loss: 4.386102 [280/292]
loss: 4.391413 [290/292]

Test Accuracy: 2.4%
Avg Loss: 4.386696

Epoch 2
-------------------------
loss: 4.392581 [10/292]
loss: 4.389024 [20/292]
loss: 4.381825 [30/292]
loss: 4.387025 [40/292]
loss: 4.393531 [50/292]
loss: 4.379637 [60/292]
loss: 4.389768 [70/292]
loss: 4

In [7]:
torch.save(model.state_dict(), "../models/tiny_weights.pt")