In [None]:
import random
import numpy as np
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import random_split
from topography.training import train, test
from topography.models import resnet

In [None]:
# Experiment configuration: every tunable value lives in this cell.
root: str = './cifar10'   # base directory; data and logs are stored beneath it
seed: int = 0             # value handed to every random number generator
epochs: int = 20          # number of iterations of the training loop
batch_size: int = 32      # batch size shared by all data loaders
lr: float = 3e-4          # learning rate given to the optimizer

In [None]:
# Seed the Python, NumPy and PyTorch (CPU, then all CUDA devices) generators
# with the same value, in that order.
for seed_fn in (random.seed, np.random.seed,
                torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(seed)

# Disable cuDNN autotuning and request deterministic cuDNN kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
# Convert images to tensors, then subtract 0.5 and divide by 0.5 on each of the
# three channels.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Official CIFAR-10 training split (downloaded into `{root}/data` if needed),
# divided 80 % / 20 % into training and validation subsets.
dataset = torchvision.datasets.CIFAR10(root=f'{root}/data', train=True,
                                       download=True, transform=transform)
train_length = int(0.8*len(dataset))
train_set, val_set = random_split(
    dataset, [train_length, len(dataset)-train_length])

# Only the training loader shuffles. The validation loader is evaluation-only,
# so it keeps a fixed order like the test loader below.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                           shuffle=True, num_workers=2)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

# Official CIFAR-10 test split, preprocessed with the same transform.
test_set = torchvision.datasets.CIFAR10(root=f'{root}/data', train=False,
                                        download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size,
                                          shuffle=False, num_workers=2)

In [None]:
# Use the GPU when PyTorch reports one as available, otherwise the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Network with 10 output features (CIFAR-10 has 10 classes).
model = resnet(out_features=10)
# NOTE(review): assumes resnet() returns an nn.Module. Moving it here, before
# the optimizer is built, is a no-op if train()/test() also move the model.
model = model.to(device)

criterion = nn.CrossEntropyLoss()
# Adam needs the parameters to optimise as its first argument; the previous
# call passed only `lr` and raised a TypeError.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [None]:
# Directory handed to train()/test() as their `save_dir` argument.
save_dir = f'{root}/logs'

# Values returned by train()/test(): per-epoch entries (keyed by 1-based epoch
# number) for 'train' and 'val', plus a single 'test' entry added at the end.
results = {'train': {}, 'val': {}}
for epoch in range(1, epochs+1):
    results['train'][epoch] = train(model, train_loader, optimizer,
                                    criterion, device, save_dir, epoch)
    results['val'][epoch] = test(model, val_loader, criterion, device,
                                 save_dir, 'val', epoch)

# Final evaluation on the test loader. `device` is passed in the same position
# as in the validation call above; the previous call omitted it, which shifted
# `save_dir` into the device slot and 'test' into the save_dir slot.
# NOTE(review): the epoch argument is still omitted, as before — presumably
# test() provides a default for it; confirm against topography.training.
results['test'] = test(model, test_loader, criterion, device, save_dir,
                       'test')