In [1]:
import os
import datetime

import torch
import torch.multiprocessing as mp
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            
from swadist.utils import spawn_fn

# Select the GNU threading layer for MKL; mp.spawn may throw an error without this
os.environ['MKL_THREADING_LAYER'] = 'GNU'

# Report which device is available; on GPU, first release PyTorch's cached memory.
if not torch.cuda.is_available():
    print('Using cpu')
else:
    torch.cuda.empty_cache()
    print('Using cuda')

# Base seed: number of seconds between 2022-04-11 and today, so it changes daily.
# Subtracting two dates always yields a whole number of days, hence this equals
# int(timedelta.total_seconds()).
seed = (datetime.date.today() - datetime.date(2022, 4, 11)).days * 86_400
print(f'seed: {seed}')

Using cuda
seed: 864000


### Contents

- **Common arguments**
- **Codistillation followed by SWA**
- **SWADist**

### Common arguments

In [2]:
# number of model replicas
world_size = 2

# Data loading options; the total batch of 256 is divided evenly across replicas.
dataloader_kwargs = dict(
    dataset='cifar10',
    batch_size=256 // world_size,
    num_workers=2,
    split_training=True,
)

# Model architecture options.
model_kwargs = dict(
    n_classes=10,
    in_kernel_size=3,
    stack_sizes=[1, 1, 1],
    batch_norm=False,
)

# Optimizer hyperparameters (lr = 2^-5 = 0.03125).
optimizer_kwargs = dict(
    lr=2 ** -5.0,
    momentum=0.975,
    nesterov=True,
)

# Trainer options: 'log' toggles logging of training to Tensorboard.
trainer_kwargs = dict(
    log=True,
    log_dir='./_runs',
)

# Options for codistillation, nested inside the training-loop options below.
codist_kwargs = dict(
    sync_freq=50,
    transform='softmax',
    debug=False,
)

# Training-loop options: epochs per phase, validation, checkpointing, stopping.
train_kwargs = dict(
    epochs_sgd=5,
    epochs_codist=5,
    epochs_swa=5,
    codist_kwargs=codist_kwargs,
    validations_per_epoch=4,
    save=True,
    save_dir='./_state_dicts',
    stopping_acc=0.7,
)

# Learning-rate scheduler options.
scheduler_kwargs = dict(
    alpha=0.25,
    decay_epochs=15,
)

# SWA scheduler options; the SWA learning rate is one tenth of the optimizer lr.
swa_scheduler_kwargs = dict(
    swa_lr=optimizer_kwargs['lr'] / 10,
    anneal_strategy='cos',
    anneal_epochs=3,
)

### Codistillation followed by SWA

Train with plain SGD for 5 epochs, run codistillation for the next 5 epochs, then stop codistillation and switch to SWA for the final 5.

In [3]:
%%time

trainer_kwargs['name'] = 'codist-swa'

args = (world_size,
        dataloader_kwargs,
        model_kwargs,
        optimizer_kwargs,
        trainer_kwargs,
        train_kwargs,
        scheduler_kwargs,
        swa_scheduler_kwargs,
        seed) # seed on rank i = seed + i

# begin training
mp.spawn(spawn_fn, args=args, nprocs=world_size, join=True)

Rank 1: joined process group on device cuda with backend nccl
Rank 1: torch.manual_seed(864001)
Using SubsetRandomSampler with samples 22500 to 44999
Number of training samples: 45000
Number of training batches: 176

Rank 1: torch.cuda.manual_seed(864003)

Worker 2/2 starting 15-epoch training loop...
Rank 0: joined process group on device cuda with backend nccl
Rank 0: torch.manual_seed(864000)
Using SubsetRandomSampler with samples 0 to 22499
Number of training samples: 45000
Number of training batches: 176

Rank 0: torch.cuda.manual_seed(864002)

Worker 1/2 starting 15-epoch training loop...
SGD epochs: 5 | Codistillation epochs: 5 | SWA epochs: 5
DistributedDataParallel: False
Stopping accuracy: 0.7

Train epoch: 1 | Metrics (epoch mean): cross_entropy=1.991733 <> acc=0.257278 | Batch (size 100): 176/176 (100%) | Total steps: 176
Rank 0 | Validation mean |  cross_entropy=1.765704 <> accuracy=0.340234 | Batch: 40/40 (100%)
Rank 1 | Validation mean |  cross_entropy=1.704424 <> accura

### SWADist

Continue codistillation during the SWA phase, but use the averaged model replicas to generate the mean teacher output. During intra-epoch synchronization, the mean over the averaged model replicas includes the current model as training progresses; at the end of each epoch, the current model is cemented into the mean parameters on each rank.

In [3]:
%%time

trainer_kwargs['name'] = 'swadist'

train_kwargs['swadist'] = True

args = (world_size,
        dataloader_kwargs,
        model_kwargs,
        optimizer_kwargs,
        trainer_kwargs,
        train_kwargs,
        scheduler_kwargs,
        swa_scheduler_kwargs,
        seed) # seed on rank i = seed + i

# begin training
mp.spawn(spawn_fn, args=args, nprocs=world_size, join=True)

Rank 0: joined process group on device cuda with backend nccl
Rank 0: torch.manual_seed(864000)
Using SubsetRandomSampler with samples 0 to 22499
Number of training samples: 45000
Number of training batches: 176

Rank 0: torch.cuda.manual_seed(864002)

Worker 1/2 starting 15-epoch training loop...
SGD epochs: 5 | Codistillation epochs: 5 | SWADist epochs: 5
DistributedDataParallel: False
Stopping accuracy: 0.7

Rank 1: joined process group on device cuda with backend nccl
Rank 1: torch.manual_seed(864001)
Using SubsetRandomSampler with samples 22500 to 44999
Number of training samples: 45000
Number of training batches: 176

Rank 1: torch.cuda.manual_seed(864003)

Worker 2/2 starting 15-epoch training loop...
Train epoch: 1 | Metrics (epoch mean): cross_entropy=1.991169 <> acc=0.257690 | Batch (size 100): 176/176 (100%) | Total steps: 176
Rank 0 | Validation mean |  cross_entropy=1.763006 <> accuracy=0.341992 | Batch: 40/40 (100%)
Rank 1 | Validation mean |  cross_entropy=1.660956 <> ac