In [None]:
import os

import torch
from torch import nn
from torch.utils.data import Subset, random_split

import torchvision
from torchvision import transforms
from torchvision.models import resnet18

from train_utils import *
from train_xor import *

import numpy as np
from scipy.stats import norm, t, ttest_1samp

from sklearn.linear_model import LogisticRegression

In [None]:
# Training seeds. They were drawn once with
#     torch.randint(0, 2**12, (6,)).tolist()
# and are frozen here so that every re-run of the notebook uses the same seeds.
seeds = [
    3313, 1998, 1900,
    1608, 3585, 96,
]

In [None]:
#functions for preparing the domino dataset and the ResNet-18 model

def prepare_domino_data(
    root,
    train_transform,
    test_transform,
    val_frac=0.25,
):
    """Split one ImageFolder directory into train / validation subsets.

    The folder at ``root`` is opened twice: once with ``train_transform`` and
    once with ``test_transform``. Both views share the same sample ordering, so
    a single random partition of the indices decides which samples go where.
    The partition uses fractions ``1 - val_frac`` / ``val_frac`` and torch's
    default generator, so it depends on the global torch seed.

    Returns:
        ``(train_subset, val_subset)``: ``torch.utils.data.Subset`` objects over
        the train-transformed and test-transformed views respectively.
    """
    def load(transform):
        return torchvision.datasets.ImageFolder(root=root, transform=transform)

    augmented_view = load(train_transform)
    plain_view = load(test_transform)

    fractions = [1 - val_frac, val_frac]
    train_idx, val_idx = random_split(range(len(augmented_view)), fractions)
    return Subset(augmented_view, train_idx), Subset(plain_view, val_idx)

def prepare_resnet18(num_classes, scale=1.0):
    """Build a torchvision ResNet-18 with a small stem and rescaled parameters.

    ``conv1`` is replaced by a 3x3, stride-1, padding-1 convolution (3 -> 64
    channels, no bias). The rest of the network, including the stem max-pool,
    is left as built. Afterwards every *parameter* (not buffers such as
    BatchNorm running statistics) is multiplied by ``scale``.
    """
    net = resnet18(num_classes=num_classes)
    # Small convolution is better for CIFAR-10-sized inputs.
    net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)

    # Rescale in place; no_grad because parameters are leaf tensors.
    with torch.no_grad():
        for weight in net.parameters():
            weight.mul_(scale)
    return net

In [None]:
#normalize layer for concatenated images

class Normalize(torch.nn.Module):
    """Normalize the two vertically stacked halves of a domino image separately.

    Rows ``[:split]`` (along the height axis) are normalized with
    ``means[0]`` / ``stds[0]`` and rows ``[split:]`` with ``means[1]`` /
    ``stds[1]``, each per channel via ``transforms.Normalize``. The two halves
    are then concatenated back along the same axis.

    Args:
        means: pair of per-channel means, one per half.
        stds: pair of per-channel standard deviations, one per half.
        split: row index where the second half starts. The default of 32
            matches the previously hard-coded value.
    """

    def __init__(self, means, stds, split=32):
        super().__init__()
        self.trans1 = transforms.Normalize(means[0], stds[0])
        self.trans2 = transforms.Normalize(means[1], stds[1])
        self.split = split

    def forward(self, tensor):
        """Normalize a ``(C, H, W)`` or ``(N, C, H, W)`` tensor half by half along H."""
        # Index from the right so batched input also works; for a 3-D tensor
        # this is exactly tensor[:, :split, :] / tensor[:, split:, :] and dim 1.
        top = self.trans1(tensor[..., :self.split, :])
        bottom = self.trans2(tensor[..., self.split:, :])
        return torch.cat((top, bottom), dim=-2)

In [None]:
#function that sets up the logistic-regression OOD probe during training

def ood_correct_training_params(
    epochs,
    train_loader,
    loader_params,
    optimizer_params,
    scheduler_params,
    train_params,
    stats,
    ood_data_params,
    ood_reg_params=None,
    ood_data_seed=None,
    ood_seed=None,
    ood_data_fn=prepare_data,
    ood_reg_fn=lambda **x: LogisticRegression(**x),
    id_correct_training_params_fn=correct_training_params,
    **id_correct_training_params
):
    """Extend the in-distribution training setup with OOD probing.

    First delegates to ``id_correct_training_params_fn`` with the
    in-distribution arguments. Then it builds the OOD evaluation data and a
    fresh probe, and stores them together with ``stats`` in
    ``train_params['val_epoch_params']``, which is mutated in place. These are
    presumably read by ``train_params['val_epoch_fn']`` (TODO confirm).

    Args:
        epochs, train_loader, loader_params, optimizer_params,
        scheduler_params, train_params: forwarded unchanged to
            ``id_correct_training_params_fn``.
        stats: container stored as ``val_epoch_params['stats']``; presumably
            filled by the validation-epoch function (TODO confirm).
        ood_data_params: keyword arguments for ``ood_data_fn``.
        ood_reg_params: keyword arguments for ``ood_reg_fn``; ``None`` means
            no arguments.
        ood_data_seed: passed to ``set_deterministic_seed`` immediately before
            the OOD data is built.
        ood_seed: unused; kept for backward compatibility.
        ood_data_fn: callable returning ``(ood_val_data, ood_test_data)``.
        ood_reg_fn: factory called once with ``**ood_reg_params``. Note that the
            returned *instance* (not the factory) is what gets stored under
            ``val_epoch_params['ood_reg_fn']``.
        id_correct_training_params_fn: the in-distribution setup hook.
        **id_correct_training_params: extra keyword arguments for that hook.

    Returns:
        None. All effects are the in-place updates of ``train_params``, the
        call to ``set_deterministic_seed``, and whatever the delegated hook does.
    """
    # None rather than a shared mutable {} default.
    if ood_reg_params is None:
        ood_reg_params = {}

    # In-distribution setup first, exactly as without OOD probing.
    id_correct_training_params_fn(
        epochs,
        train_loader,
        loader_params,
        optimizer_params,
        scheduler_params,
        train_params,
        **id_correct_training_params
    )

    # Reseed right before building the OOD data so its random split is tied to
    # ood_data_seed (assuming ood_data_fn draws from the reseeded global RNG).
    set_deterministic_seed(ood_data_seed)
    ood_val_data, ood_test_data = ood_data_fn(**ood_data_params)

    val_epoch_params = train_params['val_epoch_params']
    val_epoch_params['ood_val_data'] = ood_val_data
    val_epoch_params['ood_test_data'] = ood_test_data
    val_epoch_params['stats'] = stats
    # Stores a probe instance despite the "_fn" key name.
    val_epoch_params['ood_reg_fn'] = ood_reg_fn(**ood_reg_params)

In [None]:
# Per-half normalization statistics for Normalize: index 0 is used for the top
# half of each image and index 1 for the bottom half. The values match the
# commonly used MNIST (grey value repeated over 3 channels) and CIFAR-10 stats.
means = [[0.1307] * 3, [0.491, 0.482, 0.446]]
stds = [[0.3081] * 3, [0.202, 0.199, 0.201]]

# Evaluation pipeline: PIL image -> float tensor -> per-half normalization.
test_transform = transforms.Compose([transforms.ToTensor(), Normalize(means, stds)])

# Training pipeline: a random horizontal flip in front of the evaluation one.
train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), test_transform])

In [None]:
#prepare training parameters

# Root directory of the domino datasets. Set the DOMINO_DATA_ROOT environment
# variable to use another location without editing the notebook.
root = os.environ.get('DOMINO_DATA_ROOT', '/mnt/files/data')
epochs = 1

seed = 179
stats = [[], [], [], []]

# prepare_domino_data arguments for the training set.
data_params = {
    'root': root + '/mnist-cifar10-train-100',
    'train_transform': train_transform,
    'test_transform': test_transform,
}
loader_params = {'batch_size': 2**7, 'num_workers': 4}
# prepare_resnet18 arguments; `scale` multiplies every initial parameter.
model_params = {
    'num_classes': 10,
    'scale': 2**(-5)
}
loss_params = {}
optimizer_params = {
    'momentum': 0.9, 'weight_decay': 0.0005, 'nesterov': True
}
scheduler_params = {}
# Arguments for ood_correct_training_params (see train_kwargs below);
# 'stats' and 'ood_data_seed' are overwritten per seed by make_experiment.
correction_params = {
    'lr_factor': 2**(-10),
    'warmup_factor': 2**(-3),
    'stats': stats,
    'ood_data_params': {
        'root': root + '/mnist-cifar10-test',
        # No augmentation for either OOD split.
        'train_transform': test_transform,
        'test_transform': test_transform},
    # LogisticRegression keyword arguments.
    'ood_reg_params': {
        'C': 1e3,
        'max_iter': 20000,
        'n_jobs': 1,
        'warm_start': True},
    'ood_data_seed': seed,
    'ood_data_fn': prepare_domino_data,
    'ood_reg_fn': lambda **x: LogisticRegression(**x)
}

train_params = {
    'val_epoch_fn': ood_epoch,
    'val_epoch_params': {
        'feature_index': 'avgpool',
        'loader_params': loader_params,
        'warm_start_restarts': 50},
    # max(..., 1) keeps the interval positive when epochs < 4.
    'val_interval': max(epochs//4, 1)
}

train_kwargs = {
    'data_fn': prepare_domino_data,
    'model_fn': prepare_resnet18,
    'correct_training_params_fn': ood_correct_training_params
}

In [None]:
def make_experiment(data_params, model_params):
    """Train one model per seed in ``seeds`` and collect each run's ``stats``.

    For every seed, a fresh ``stats`` container (four empty lists) and the seed
    itself are written into the global ``correction_params`` (mutated in place).
    ``get_trained_model`` is then called with that seed for all four of its
    seed arguments. The stats are presumably filled during training via
    ood_correct_training_params (TODO confirm).

    Also reads the notebook globals ``seeds``, ``epochs``, ``loader_params``,
    ``loss_params``, ``optimizer_params``, ``scheduler_params``,
    ``correction_params``, ``train_params`` and ``train_kwargs``.

    Returns:
        list with one ``stats`` list per seed, in ``seeds`` order.
    """
    experiment = []
    for seed in tqdm(seeds):
        stats = [[], [], [], []]
        correction_params['stats'] = stats
        correction_params['ood_data_seed'] = seed

        # The return value (presumably the trained model) is not needed. Not
        # binding it means it can be freed before the next seed starts training,
        # instead of staying alive through the whole next run.
        get_trained_model(
            epochs,
            data_params, loader_params,
            model_params, loss_params,
            optimizer_params, scheduler_params, correction_params,
            train_params,
            seed, seed, seed, seed,
            **train_kwargs
        )

        experiment.append(stats)
    return experiment

In [None]:
#train models in different settings

epochs = 2**8
# For a quick smoke test use epochs = 1 and seeds = [179] instead.

# Quiet long runs: a plain `range` for the epoch loop (presumably replacing a
# progress-bar range) and no-op print hooks for validation/OOD output.
train_params['range_fn'] = lambda x: range(x)
train_params['val_epoch_params']['print_fn'] = lambda *x: None
train_params['val_epoch_params']['print_ood'] = lambda *x: None
# max(..., 1) keeps the interval positive when epochs < 16 (e.g. smoke tests).
train_params['val_interval'] = max(epochs // 16, 1)

# Settings, in order (index into `experiments`): (init scale, training dir)
#   0: 2**-5, train-100    1: 1, train-100    2: 2**-5, train-95
experiments = []
for scale, train_dir in zip(
    (2**(-5), 1, 2**(-5)),
    ('/mnist-cifar10-train-100',
     '/mnist-cifar10-train-100',
     '/mnist-cifar10-train-95')
):
    data_params['root'] = root + train_dir
    model_params['scale'] = scale
    experiments.append(make_experiment(data_params, model_params))

experiments = np.array(experiments)
np.save('domino.npy', experiments)

# Small-init / train-100 setting again, this time with an unregularized
# logistic-regression probe (no C, penalty=None). The probe settings are
# swapped for a modified copy only for this run and restored afterwards. This
# way, re-running the cell neither raises KeyError on the missing 'C' nor
# silently trains the settings above without regularization.
reg_params = correction_params['ood_reg_params']
no_reg_params = {k: v for k, v in reg_params.items() if k != 'C'}
no_reg_params['penalty'] = None

data_params['root'] = root + '/mnist-cifar10-train-100'
model_params['scale'] = 2**(-5)
train_params['val_interval'] = max(epochs // 4, 1)

correction_params['ood_reg_params'] = no_reg_params
try:
    experiment_no_reg = make_experiment(data_params, model_params)
finally:
    correction_params['ood_reg_params'] = reg_params

experiment_no_reg = np.array(experiment_no_reg)
np.save('domino_no_reg.npy', experiment_no_reg)

In [None]:
# One-sided paired t-test for setting 0 (init scale 2**-5, train-100).
# For stats[1] (printed as a percentage; presumably an accuracy — TODO
# confirm), compare evaluation index 4 against index 16 across seeds.
# H1: the value at index 4 is larger.
x = experiments[0, :, 1]
diff = x[:, 4] - x[:, 16]
mean, std = diff.mean(), diff.std(ddof=1)
# Equivalent to mean * sqrt(n) / std with t.sf(., n - 1).
t_val, p_val = ttest_1samp(diff, 0.0, alternative='greater')
print(f"{mean:2.2%}, {std:2.2%}, {p_val}")

In [None]:
# Same one-sided paired t-test for setting 1 (init scale 1, train-100):
# stats[1] at evaluation index 4 vs 16 across seeds, H1: index 4 is larger.
x = experiments[1, :, 1]
diff = x[:, 4] - x[:, 16]
mean, std = diff.mean(), diff.std(ddof=1)
# Equivalent to mean * sqrt(n) / std with t.sf(., n - 1).
t_val, p_val = ttest_1samp(diff, 0.0, alternative='greater')
print(f"{mean:2.2%}, {std:2.2%}, {p_val}")

In [None]:
# Fit a normal distribution to setting 1's per-seed differences
# (stats[1] at index 4 minus index 16). Under that fit, this is the
# probability that a single run shows a negative difference.
# Recomputed here instead of reusing `mean`/`std` left over from whichever
# cell ran last.
diff_exp1 = experiments[1, :, 1, 4] - experiments[1, :, 1, 16]
norm.sf(diff_exp1.mean() / diff_exp1.std(ddof=1))