In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Subset, random_split

import torchvision
from torchvision import transforms
from torchvision.models import resnet18
from torchvision.models.feature_extraction import create_feature_extractor

from train_utils import *
from train_xor import *

import numpy as np

# This notebook shows that ResNet-18 reaches over 90% accuracy on CIFAR-10 with our training setup.

In [2]:
def prepare_resnet18(num_classes, scale=1.0):
    """Build a ResNet-18 with a 3x3 stem convolution and rescaled parameters.

    Args:
        num_classes: Number of output classes, forwarded to ``resnet18``.
        scale: Factor by which every parameter of the model is multiplied
            after construction. With the default of 1.0 the initial values
            are left numerically unchanged.

    Returns:
        The constructed model, with the replaced first convolution and all
        parameters (including those of the new convolution) multiplied by
        ``scale``.
    """
    net = resnet18(num_classes=num_classes)

    # Small convolution is better for CIFAR-10: replace the stock stem with
    # a 3x3, stride-1 convolution that has no bias term.
    net.conv1 = nn.Conv2d(
        in_channels=3,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )

    # Rescale in place; no_grad keeps autograd from recording the update.
    with torch.no_grad():
        for weight in net.parameters():
            weight.mul_(scale)

    return net

In [3]:
# Per-channel constants handed to transforms.Normalize below
# (presumably CIFAR-10 statistics -- confirm against the dataset).
mean = [0.491, 0.482, 0.446]
std = [0.202, 0.199, 0.201]

# Evaluation pipeline: tensor conversion followed by normalization.
test_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean, std)]
)

# Training pipeline: random crop (with 4 pixels of padding) and random
# horizontal flip, then the same steps as the evaluation pipeline.
train_transform = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        test_transform,
    ]
)

In [4]:
# Training configuration: every object defined in this cell is passed to
# get_trained_model in the training cell.

epochs = 2 ** 8

# Single seed value; the training call passes it four times.
seed = 179

# Dataset settings: storage root, dataset constructor, and the transforms
# applied to the training and test splits.
data_params = dict(
    root='/mnt/files/data',
    data_fn=torchvision.datasets.CIFAR10,
    train_transform=train_transform,
    test_transform=test_transform,
)

# Data-loader settings.
loader_params = dict(batch_size=2 ** 7, num_workers=4)

# Keyword arguments for the model factory (see train_kwargs['model_fn']).
model_params = dict(num_classes=10, scale=2 ** -5)

# No extra loss options.
loss_params = dict()

# Optimizer options; the optimizer class itself is not selected here.
optimizer_params = dict(momentum=0.9, weight_decay=0.0005, nesterov=True)

# No extra scheduler options.
scheduler_params = dict()

# Options interpreted by the training helper -- NOTE(review): exact meaning
# of lr_factor / warmup_factor is defined in the imported training utilities.
correction_params = dict(lr_factor=2 ** -10, warmup_factor=2 ** -3)

# Validation interval: a quarter of the total number of epochs.
train_params = dict(val_interval=epochs // 4)

# Extra keyword arguments for get_trained_model: the model factory to use.
train_kwargs = dict(model_fn=prepare_resnet18)

In [5]:
# Train the model. Arguments are positional and order-sensitive: the epoch
# count, the eight parameter dicts, then the same seed four times
# (NOTE(review): which component each seed controls is defined by
# get_trained_model -- confirm there), plus the model factory via keywords.
model = get_trained_model(
    epochs,
    *(
        data_params,
        loader_params,
        model_params,
        loss_params,
        optimizer_params,
        scheduler_params,
        correction_params,
        train_params,
    ),
    *([seed] * 4),
    **train_kwargs,
)

  0%|          | 0/256 [00:00<?, ?it/s]

Epoch:   0 Accuracy: 10.2%
Epoch:  64 Accuracy: 78.4%
Epoch: 128 Accuracy: 85.2%
Epoch: 192 Accuracy: 88.1%
Epoch: 256 Accuracy: 92.6%
