In [1]:
import shutil
import tempfile

import torch
from torch import nn

import sys
from matplotlib import pyplot as plt

sys.path.append("..")
from datasets import load_data, MNIST

from local_util import PCN, FCN, pcnSGD

In [2]:

# Load data: train on a deliberately tiny slice of MNIST (the first
# `subset_size` training samples), but evaluate on the full test split.

subset_size = 100

full_train = MNIST("train")
subset_indices = list(range(subset_size))
subset_train = torch.utils.data.Subset(full_train, subset_indices)

train = load_data(subset_train)
test = load_data(MNIST("test"))

In [8]:

# Wipe TensorBoard event files from previous runs so each run starts from an
# empty log directory; a missing directory is silently ignored.
stale_log_dir = "log_dir"
shutil.rmtree(stale_log_dir, ignore_errors=True)

In [9]:
import torch.utils.tensorboard as tb

# Train every configured model on `train`, evaluate it on `test` after each
# epoch, and log to TensorBoard under log_dir/<model name>.

log_dir = "log_dir"
learn_rate = 0.01  # NOTE(review): unused in this cell — each optimizer below sets its own lr.
epochs = 10000
seed = 0  # seeds torch's RNG before the models are built so weight init is reproducible

cpu = torch.device("cpu")
gpu = torch.device("cuda:0")
device = gpu if torch.cuda.is_available() else cpu

torch.manual_seed(seed)

# NOTE(review): run names do not describe their configs (e.g. "pcn256_16D" is
# built with dimensions=128). The name is also the TensorBoard run directory,
# so rename deliberately — TODO.
models: list[tuple[str, nn.Module]] = [
    # ("pcn2048_1D", PCN([28 * 28, 3200, 3200, 3200, 3200, 10], dimensions=1).to(device)),
    ("pcn2048_4D", PCN([28 * 28, 512, 1024, 4096, 4096 * 4, 10], dimensions=4).to(device)),
    # ("pcn2048_4D", PCN([28 * 28, 400, 400, 400, 400, 10], dimensions=4).to(device)),
    ("pcn256_16D", PCN([28 * 28, 32, 10], dimensions=128).to(device)),
    # ("fcn", FCN(28 * 28, 10, [128, 256, 512, 1024, 2048]).to(device)),
]

# Optimizers are paired with models purely by list position.
optimizers = [
    # pcnSGD(models[0][1], lr=0.00000001, opp="log"),
    pcnSGD(models[0][1], lr=0.0002, opp="log"),
    pcnSGD(models[1][1], lr=0.0002, opp="log"),
    # torch.optim.SGD(models[2][1].parameters(), lr=0.01),
]
# zip() below would silently drop or mis-pair entries if one list were edited
# without the other, so fail loudly instead (before any log dirs are created).
if len(optimizers) != len(models):
    raise ValueError(
        f"expected one optimizer per model, got {len(models)} models "
        f"and {len(optimizers)} optimizers"
    )

loggers = [tb.SummaryWriter(log_dir + "/" + name) for name, _ in models]

loss = nn.CrossEntropyLoss()


def _train_epoch(model: nn.Module, loader, optimizer, criterion, logger, epoch: int) -> float:
    """Run one optimisation pass of `model` over `loader` (inputs moved to `device`).

    Steps `optimizer` after every batch and logs that batch's loss under the
    "loss" tag at global step ``epoch * len(loader) + batch_index``.

    Returns:
        Fraction of samples classified correctly, measured on each batch's
        predictions *before* its optimizer step; NaN if `loader` is empty.
    """
    correct = 0
    total = 0
    for i, (x, y) in enumerate(loader):
        step = epoch * len(loader) + i
        x = x.float().to(device)
        y = y.to(device)
        pred = model(x).float()

        correct += (pred.argmax(dim=1) == y).sum().item()
        total += y.shape[0]

        batch_loss = criterion(pred, y)
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # .item() hands TensorBoard a plain float rather than a graph-attached tensor.
        logger.add_scalar("loss", batch_loss.item(), step)
    return correct / total if total else float("nan")


@torch.no_grad()
def _evaluate(model: nn.Module, loader) -> float:
    """Return the accuracy of `model` on `loader` (NaN if empty), without autograd."""
    # NOTE(review): models are never switched to eval() mode; if PCN/FCN use
    # dropout or batch-norm, evaluation should call model.eval() — TODO confirm.
    correct = 0
    total = 0
    for x, y in loader:
        x = x.float().to(device)
        y = y.to(device)
        pred = model(x).float()
        correct += (pred.argmax(dim=1) == y).sum().item()
        total += y.shape[0]
    return correct / total if total else float("nan")


try:
    for epoch in range(epochs):
        for (_, model), optimizer, logger in zip(models, optimizers, loggers):
            train_accuracy = _train_epoch(model, train, optimizer, loss, logger, epoch)
            logger.add_scalar("train_accuracy", train_accuracy, epoch)

        for (_, model), logger in zip(models, loggers):
            logger.add_scalar("valid_accuracy", _evaluate(model, test), epoch)
            # Flush once per epoch so a running TensorBoard sees fresh points.
            logger.flush()
finally:
    # Runs even when training is interrupted (e.g. KeyboardInterrupt), so
    # buffered events are written out instead of being lost.
    for logger in loggers:
        logger.close()

KeyboardInterrupt: 

In [4]:

# Start tensorboard

shutil.rmtree(tempfile.gettempdir() + "/.tensorboard-info", ignore_errors=True) # sort of 'force reload' for tensorboard
%load_ext tensorboard 
%tensorboard --logdir log_dir --reload_interval 1 --port 6005

Launching TensorBoard...