In [5]:
import os
import datetime

import numpy as np
import pandas as pd

import torch
import torch.multiprocessing as mp
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            
from swadist.utils import spawn_fn

# mp.spawn may throw an error without this
os.environ['MKL_THREADING_LAYER'] = 'GNU'

# free PyTorch's cached CUDA memory before launching the worker processes
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Base random seed: seconds elapsed between 2022-04-11 and today; rank i uses
# seed + i (see the `args` tuple in the config cell). Cast to int so the
# workers receive an integer seed rather than a float such as 172800.0.
# NOTE(review): this value changes every day, so re-running the notebook on a
# different date gives different results; pin it to a constant if exact
# reproducibility is required.
seed = int((datetime.date.today() - datetime.date(2022, 4, 11)).total_seconds())
print(f'seed: {seed}')

seed: 172800


### Two-phase training: SGD + Codistillation

In phase 1, we train ResNet-8 asynchronously over the whole training set, using the optimal hyperparameters given in [Shallue et al. 2019](http://arxiv.org/abs/1811.03600) for SGD with Nesterov momentum.

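In phase 2, we run codistillation ([Anil et al. 2018](http://arxiv.org/abs/1804.03235)), an online variant of knowledge distillation ([Hinton et al. 2015](http://arxiv.org/abs/1503.02531)) in which the replicas are trained to also match each other's predictions.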

In [2]:
# whether to log training to TensorBoard
log = False

# number of model replicas (one worker process is spawned per replica)
world_size = 2

# overall size of training minibatches, aka effective batch size;
# it is split evenly across the replicas in `dataloader_kwargs` below
eff_batch_size = 256
# integer division would otherwise silently shrink the effective batch size
assert eff_batch_size % world_size == 0, (
    f'eff_batch_size ({eff_batch_size}) must be divisible by world_size ({world_size})'
)

# optimizer: initial learning rate and (Nesterov) momentum
lr0, momentum = 2**-5., 0.975

# scheduler parameters (passed through as `scheduler_kwargs`)
alpha, decay_epochs = 0.25, 5

# number of training epochs in each phase
epochs_sgd, epochs_codist = 5, 10

dataloader_kwargs = {
    'dataset': 'cifar10',
    # per-replica batch size
    'batch_size': eff_batch_size // world_size,
    'num_workers': 4,
    'data_parallel': True,
}
model_kwargs = {
    'in_kernel_size': 3,
    'stack_sizes': [1, 1, 1],
    'n_classes': 10,
    'batch_norm': False,
}
optimizer_kwargs = {
    'lr': lr0,
    'momentum': momentum,
    'nesterov': True,
}
trainer_kwargs = {
    'log': log,
    'name': 'codist',
}
train_kwargs = {
    'epochs_sgd': epochs_sgd,
    'epochs_codist': epochs_codist,
}
scheduler_kwargs = {
    'alpha': alpha,
    'decay_epochs': decay_epochs,
}

# positional arguments forwarded to `spawn_fn` by `mp.spawn`
args = (world_size,
        dataloader_kwargs,
        model_kwargs,
        optimizer_kwargs,
        trainer_kwargs,
        train_kwargs,
        scheduler_kwargs,
        None, # swa_scheduler_kwargs
        seed) # seed on rank i = seed + i

In [3]:
%%time

# begin training
mp.spawn(spawn_fn, args=args, nprocs=world_size, join=True)

Rank 0 joined process group on device cuda
Rank 0 using seed 172800.0
Files already downloaded and verified
Using DistributedSampler
Number of training samples: 45000
Number of training batches: 176

Param preview:
tensor([[[-0.0474,  0.0416,  0.0171],
         [ 0.0755,  0.0522, -0.0476],
         [ 0.0778,  0.2243,  0.0656]],

        [[ 0.0514, -0.0160,  0.1759],
         [ 0.1962, -0.1677,  0.1313],
         [ 0.1330,  0.0804, -0.0392]],

        [[ 0.1299,  0.0665,  0.1465],
         [ 0.0598,  0.0056, -0.0290],
         [ 0.0085, -0.2033,  0.0930]]], device='cuda:0',
       grad_fn=<SelectBackward0>) 

Worker 1/2 starting 15-epoch training loop...
Random seed on rank 0: 1046294152460790038

SGD epochs: 5 | Codistillation epochs: 10 | SWA epochs: 0
DistributedDataParallel: False
Stopping accuracy: None

Rank 1 joined process group on device cuda
Rank 1 using seed 172801.0
Files already downloaded and verified
Using DistributedSampler
Number of training samples: 45000
Number of tra

### Two-phase training: SGD + Codistillation w/ data partitioning

As before, but this time each model replica is trained on its own disjoint partition of the training set in both phase 1 and phase 2.

In [4]:
# Train each replica on its own disjoint partition of the training set.
# NOTE: this mutates the dicts defined in the config cell, so re-running the
# previous (non-partitioned) training cell afterwards would also pick up these
# settings. Re-run the config cell first to restore the original setup.
dataloader_kwargs['data_parallel'] = False
dataloader_kwargs['split_training'] = True
trainer_kwargs['name'] = 'codist-partitioned'

# Rebuild `args` explicitly rather than relying on the tuple from the config
# cell happening to alias the dicts mutated above.
args = (world_size,
        dataloader_kwargs,
        model_kwargs,
        optimizer_kwargs,
        trainer_kwargs,
        train_kwargs,
        scheduler_kwargs,
        None, # swa_scheduler_kwargs
        seed) # seed on rank i = seed + i

# begin training
mp.spawn(spawn_fn, args=args, nprocs=world_size, join=True)

Rank 0 joined process group on device cuda
Rank 0 using seed 172800.0
Files already downloaded and verified
Using SubsetRandomSampler with samples 0 to 22499
Number of training samples: 45000
Number of training batches: 176

Param preview:
tensor([[[-0.0474,  0.0416,  0.0171],
         [ 0.0755,  0.0522, -0.0476],
         [ 0.0778,  0.2243,  0.0656]],

        [[ 0.0514, -0.0160,  0.1759],
         [ 0.1962, -0.1677,  0.1313],
         [ 0.1330,  0.0804, -0.0392]],

        [[ 0.1299,  0.0665,  0.1465],
         [ 0.0598,  0.0056, -0.0290],
         [ 0.0085, -0.2033,  0.0930]]], device='cuda:0',
       grad_fn=<SelectBackward0>) 

Worker 1/2 starting 15-epoch training loop...
Random seed on rank 0: 16542267418990546344

SGD epochs: 5 | Codistillation epochs: 10 | SWA epochs: 0
DistributedDataParallel: False
Stopping accuracy: None

Rank 1 joined process group on device cuda
Rank 1 using seed 172801.0
Files already downloaded and verified
Using SubsetRandomSampler with samples 22500 t