In [None]:
import random
import uuid
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from topography.training import train, test
from topography.models import resnet
from topography.utils import random_split

In [None]:
# Experiment configuration: every value a reader might want to tune.
seed = 0                # fed to the Python, NumPy and PyTorch RNG seeders
epochs = 20             # number of train + validation passes
batch_size = 128        # shared by the train, validation and test loaders
lr = 0.0001             # Adam learning rate
weight_decay = 0.0005   # Adam weight decay coefficient
root = './cifar10'      # base directory for the dataset and per-run outputs

In [None]:
# Seed every random number generator the notebook relies on (Python, NumPy,
# PyTorch CPU and all CUDA devices) with the same value, in that order.
for _seed_fn in (random.seed, np.random.seed,
                 torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)

# Trade speed for repeatability in cuDNN: turn off the benchmark-driven
# algorithm selection and request deterministic algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
# Training pipeline: random augmentation on the raw image, then resize to
# 224x224, convert to a tensor and normalize each channel with mean 0.5 and
# std 0.5.
train_transform = transforms.Compose([
    transforms.RandAugment(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Evaluation pipeline: identical preprocessing but without augmentation, so
# validation and test metrics are computed on unaltered images.
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# The CIFAR-10 training split is loaded without a transform and divided 80/20
# into train and validation subsets.
# NOTE(review): `random_split` is a project helper; presumably it applies
# `train_transform` to the first subset and `test_transform` to the second --
# confirm against topography.utils.random_split.
dataset = torchvision.datasets.CIFAR10(root=f'{root}/data', train=True, download=True)
train_length = int(0.8*len(dataset))
train_set, val_set = random_split(
    dataset, [train_length, len(dataset)-train_length],
    [train_transform, test_transform])

# Only the training loader shuffles. Evaluation loaders iterate in a fixed
# order: shuffling does not help evaluation and needlessly consumes random
# state.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                           shuffle=True, num_workers=2, pin_memory=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size,
                                         shuffle=False, num_workers=2, pin_memory=True)

# The official CIFAR-10 test split is held out and uses the evaluation pipeline.
test_set = torchvision.datasets.CIFAR10(root=f'{root}/data', train=False,
                                        download=True, transform=test_transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size,
                                          shuffle=False, num_workers=2, pin_memory=True)

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Project-provided ResNet with a 10-way output, moved onto the chosen device.
# NOTE(review): the exact architecture and the source of the pretrained
# weights are decided inside topography.models.resnet -- confirm there.
model = resnet(
    out_features=10,
    pretrained=True,
).to(device)

# Cross-entropy objective; Adam over all model parameters using the
# configured learning rate and weight decay.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)

In [None]:
# Each run gets its own output directory under `root`, named by a fresh UUID.
save_dir = '/'.join((root, str(uuid.uuid4())))
print(f'Save directory: {save_dir}')

# Per-epoch outputs of the project `train` / `test` helpers, keyed by the
# 1-based epoch number.
# NOTE(review): what those helpers return (and what they write to `save_dir`)
# is defined in topography.training -- confirm there.
train_history, val_history = {}, {}
results = {'train': train_history, 'val': val_history}

for epoch in range(1, epochs + 1):
    train_history[epoch] = train(
        model, train_loader, optimizer, criterion, device, save_dir, epoch)
    val_history[epoch] = test(
        model, val_loader, criterion, device, save_dir, 'val', epoch)

# One final evaluation on the held-out test loader after the last epoch.
results['test'] = test(model, test_loader, criterion, device, save_dir, 'test')