In [1]:
%cd ../src

/Users/allen/Documents/GitHub/chesshacks-training/src


In [2]:
import os
import torch

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from train import train

# Hyperparameters
batch_size = 128
epochs = 10
learning_rate = 0.01
momentum = 0.9
weight_decay = 5e-4
num_workers = 2  # adjust for your environment

# Data transforms and download
base_transforms = torchvision.transforms.Compose(
    [torchvision.transforms.Resize((32, 32)), torchvision.transforms.ToTensor()]
)
train_dataset = torchvision.datasets.MNIST(
    "data", train=True, transform=base_transforms, download=True
)
test_dataset = torchvision.datasets.MNIST(
    "data", train=False, transform=base_transforms, download=True
)


# Model architecture
conv = [
    (1, 64),
    (64, 128),
    (128, 256),
    (256, 256),
    (256, 512),
    (512, 512),
    (512, 512),
    (512, 512),
]
pool = [
    (2, 2),
    (2, 2),
    (-1, -1),
    (2, 2),
    (-1, -1),
    (2, 2),
    (-1, -1),
    (2, 2),
]


class VGG11(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.ModuleList(
            [
                torch.nn.Conv2d(in_channels, out_channels, 3, 1, 1)
                for in_channels, out_channels in conv
            ]
        )
        self.conv_batchnorm = torch.nn.ModuleList(
            [torch.nn.BatchNorm2d(out_channels) for _, out_channels in conv]
        )
        self.conv_relu = torch.nn.ModuleList([torch.nn.ReLU() for _, _ in conv])

        pools = []
        for size, stride in pool:
            if size != -1:
                pools.append(torch.nn.MaxPool2d(size, stride))
        self.conv_pool = torch.nn.ModuleList(pools)
        self.flatten = torch.nn.Flatten(1, -1)
        self.fc1 = torch.nn.Linear(512, 4096)
        self.fc1_relu = torch.nn.ReLU()
        self.fc1_dropout = torch.nn.Dropout(0.5)
        self.fc2 = torch.nn.Linear(4096, 4096)
        self.fc2_relu = torch.nn.ReLU()
        self.fc2_dropout = torch.nn.Dropout(0.5)
        self.fc3 = torch.nn.Linear(4096, 10)

    def forward(self, x):
        pool_index = 0
        for conv, batchnorm, relu, (pool_size, _) in zip(
            self.conv, self.conv_batchnorm, self.conv_relu, pool
        ):
            x = conv(x)
            x = batchnorm(x)
            x = relu(x)
            if pool_size != -1:
                x = self.conv_pool[pool_index](x)
                pool_index += 1

        x = self.flatten(x)

        x = self.fc1(x)
        x = self.fc1_relu(x)
        x = self.fc1_dropout(x)

        x = self.fc2(x)
        x = self.fc2_relu(x)
        x = self.fc2_dropout(x)

        x = self.fc3(x)
        return x


# Instantiate model, loss, optimizer, scheduler
model = VGG11()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay
)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.7)

In [3]:
train(
    "MNIST",
    model,
    criterion,
    optimizer,
    scheduler,
    2,
    batch_size,
    train_dataset,
    test_dataset,
)

  return FileStore(store_uri, store_uri)


Epoch 1/2: train_loss=0.1450 val_loss=0.0402 val_acc=0.9873
Epoch 2/2: train_loss=0.0256 val_loss=0.0245 val_acc=0.9917
Training complete. Artifacts and metrics logged to MLflow.
