In [1]:
import shutil
import tempfile
import sys
sys.path.append('..')

import util
import datasets
from local_util import *

import torch
import torchvision
import torch.utils.tensorboard as tb

# CIFAR-10 loaders for both splits, batched 384 images at a time.
# The generator is consumed in order, so "train" is built before "test".
train, test = (
    datasets.load_data(datasets.Cifar10(split), batch_size=384)
    for split in ("train", "test")
)

In [2]:
# Wipe previous TensorBoard runs so only this session's curves are shown.
shutil.rmtree(
    "log_dir",
    ignore_errors=True,  # a missing directory (fresh checkout) is not an error
)

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

log_dir = "log_dir"
epochs = 10000  # effectively "train until interrupted"; stop the cell manually
in_features = 3  # presumably input channels (RGB); not used below
out_features = 10  # number of output classes (CIFAR-10)
image_size = (227, 227)  # batches are upscaled to this size before the network
learning_rate = 0.0001
seed = 0

# Seed before building the models so their random weight init is reproducible.
torch.manual_seed(seed)

loss = torch.nn.CrossEntropyLoss()


def _alexnet() -> torch.nn.Module:
    """Return a fresh, randomly initialised hub AlexNet (vision v0.10.0) on `device`."""
    return torch.hub.load('pytorch/vision:v0.10.0', 'alexnet', pretrained=False).to(device)


models: list[tuple[str, torch.nn.Module]] = [
    ("alexnet", _alexnet()),
    ("conv_net_pcn", _alexnet()),
]
# Baseline: keep AlexNet's classifier, only swap its last layer for an `out_features`-way head.
models[0][1].classifier[6] = torch.nn.Linear(4096, out_features).to(device)
# PCN variant: replace the whole classifier. 9216 is the flattened feature size fed
# to the classifier (presumably 256 * 6 * 6 for AlexNet; TODO confirm).
models[1][1].classifier = PCN([9216, 4096, 4096, out_features], dimensions=16).to(device)

# One (TensorBoard writer, optimizers) pair per model, aligned with `models` by index.
model_utils = [
    (
        tb.SummaryWriter(log_dir=f"{log_dir}/{models[0][0]}"),
        [
            torch.optim.SGD(models[0][1].parameters(), lr=learning_rate),
        ],
    ),
    (
        tb.SummaryWriter(log_dir=f"{log_dir}/{models[1][0]}"),
        [
            # The conv backbone uses plain SGD; the PCN head has its own optimizer.
            torch.optim.SGD(models[1][1].features.parameters(), lr=learning_rate),
            pcnSGD(models[1][1].classifier, lr=learning_rate, opp="log"),
        ],
    ),
]

# Built once and reused for every batch (it is loop-invariant).
resize = torchvision.transforms.Resize(image_size)


def _prepare_batch(x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Resize a batch to `image_size`, then move inputs (as float) and labels to `device`."""
    return resize(x).to(device).float(), y.to(device)


def _train_one_epoch(
    model: torch.nn.Module,
    optimizers: list,
    tb_logger: tb.SummaryWriter,
    epoch: int,
) -> None:
    """Run one pass over `train`, logging per-step loss and the epoch's training accuracy.

    Args:
        model: network to train (switched to train mode here).
        optimizers: every optimizer that owns part of `model`'s parameters.
        tb_logger: writer that receives the "loss" and "accuracy" scalars.
        epoch: epoch index, used to derive a global step for the loss curve.
    """
    model.train()
    conf_mat = util.ConfusionMatrix(size=out_features)
    for i, (x, y) in enumerate(train):
        step = epoch * len(train) + i
        x, y = _prepare_batch(x, y)

        pred = model(x)
        l = loss(pred, y)
        l.backward()

        # One backward pass feeds all optimizers; each then clears its own grads.
        for optimizer in optimizers:
            optimizer.step()
            optimizer.zero_grad()

        conf_mat.add(pred.detach().cpu().argmax(1), y.cpu())
        # Log a plain float rather than the graph-attached loss tensor.
        tb_logger.add_scalar("loss", l.item(), step)
    tb_logger.add_scalar("accuracy", conf_mat.accuracy, epoch)


def _evaluate(model: torch.nn.Module, tb_logger: tb.SummaryWriter, epoch: int) -> None:
    """Compute accuracy on `test` without gradients and log it as "test_accuracy"."""
    model.eval()
    conf_mat = util.ConfusionMatrix(size=out_features)
    with torch.no_grad():
        for x, y in test:
            x, y = _prepare_batch(x, y)
            conf_mat.add(model(x).cpu().argmax(1), y.cpu())
    tb_logger.add_scalar("test_accuracy", conf_mat.accuracy, epoch)


try:
    for epoch in range(epochs):
        for (_, model), (tb_logger, optimizers) in zip(models, model_utils):
            _train_one_epoch(model, optimizers, tb_logger, epoch)
            _evaluate(model, tb_logger, epoch)
            # Write events to disk every epoch so TensorBoard (--reload_interval 1)
            # shows progress without waiting for the writer's periodic flush.
            tb_logger.flush()
finally:
    # The loop is normally stopped by interrupting the cell: close the writers so
    # pending events are persisted and file handles are not leaked on re-run.
    for writer, _optimizers in model_utils:
        writer.close()

Using cache found in C:\Users\hette/.cache\torch\hub\pytorch_vision_v0.10.0
Using cache found in C:\Users\hette/.cache\torch\hub\pytorch_vision_v0.10.0


In [3]:

# Start tensorboard

shutil.rmtree(tempfile.gettempdir() + "/.tensorboard-info", ignore_errors=True) # sort of 'force reload' for tensorboard
%load_ext tensorboard 
%tensorboard --logdir log_dir --reload_interval 1 --port 6005