In [1]:
import os
import datetime

import torch
import torch.multiprocessing as mp
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            
from swadist.utils import spawn_fn

# mp.spawn may throw an error without this (presumably an MKL vs. GNU OpenMP
# threading-layer conflict in the spawned worker processes — confirm)
os.environ['MKL_THREADING_LAYER'] = 'GNU'

if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print('Using cuda')
else:
    print('Using cpu')

# Fixed base seed so every rerun of this notebook uses the same seed.
# It used to be derived from the run date (seconds elapsed since 2022-04-11),
# which changed every day; 2246400 is the value that produced the recorded outputs.
seed = 2246400
print(f'seed: {seed}')

Using cuda
seed: 2246400


### Contents

- **Common arguments**
- **Asynchronous data-parallel SGD**
- **Asynchronous SGD + Codistillation**
- **Asynchronous SGD + Codistillation w/ data partitioning**
- **Data-parallel codistillation only**
- **Codistillation followed by SWA**

### Common arguments

In [2]:
# Number of model replicas (one spawned process per replica).
world_size = 2

# Data loading: a total training batch of 256 is divided evenly across ranks.
dataloader_kwargs = dict(
    dataset='cifar10',
    batch_size=256 // world_size,  # training batch size on each rank
    num_workers=2,                 # CPU data-loading processes per rank
)

# Model architecture.
model_kwargs = dict(
    n_classes=10,
    in_kernel_size=3,
    stack_sizes=[1, 1, 1],
    batch_norm=False,
)

# Optimizer hyperparameters.
optimizer_kwargs = dict(
    lr=2**-5.,
    momentum=0.975,
    nesterov=True,
)

# Set log=True to log training to TensorBoard.
trainer_kwargs = dict(log=False)

# Epoch counts and codistillation settings are filled in per run below.
train_kwargs = dict()

scheduler_kwargs = dict(
    alpha=0.25,
    decay_epochs=15,
)

# The SWA learning rate is one tenth of the base learning rate.
swa_scheduler_kwargs = dict(
    swa_lr=optimizer_kwargs['lr'] / 10,
    anneal_strategy='cos',
    anneal_epochs=3,
)

### Asynchronous data-parallel SGD

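Run asynchronous SGD for 15 epochs: each of the `world_size` replicas trains in its own process, without `DistributedDataParallel`, on its own shard of the training set drawn by `DistributedSampler`.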

In [3]:
%%time

trainer_kwargs['name'] = 'async-sgd'

# when true, uses DistributedSampler
dataloader_kwargs['data_parallel'] = True

train_kwargs['epochs_sgd'] = 15
train_kwargs['epochs_codist'] = 0
train_kwargs['epochs_swa'] = 0

args = (world_size,
        dataloader_kwargs,
        model_kwargs,
        optimizer_kwargs,
        trainer_kwargs,
        train_kwargs,
        scheduler_kwargs,
        None, # swa_scheduler_kwargs
        seed, # seed on rank i = seed + i
        False) # ddp=False, don't use DistributedDataParallel

# begin training
mp.spawn(spawn_fn, args=args, nprocs=world_size, join=True)

Rank 0: joined process group on device cuda with backend nccl
Rank 0: torch.manual_seed(2246400)
Using DistributedSampler
Number of training samples: 45000
Number of training batches: 176


Worker 1/2 starting 15-epoch training loop...
SGD epochs: 15 | Codistillation epochs: 0 | SWA epochs: 0
DistributedDataParallel: False
Stopping accuracy: None

Rank 1: joined process group on device cuda with backend nccl
Rank 1: torch.manual_seed(2246400)
Using DistributedSampler
Number of training samples: 45000
Number of training batches: 176


Worker 2/2 starting 15-epoch training loop...
Train epoch: 1 | Metrics (epoch mean): cross_entropy=1.923787 <> acc=0.274013 | Batch (size 100): 176/176 (100%) | Total steps: 176
Rank 0 | Validation mean |  cross_entropy=1.690852 <> accuracy=0.379297 | Batch: 40/40 (100%)
Rank 1 | Validation mean |  cross_entropy=1.748831 <> accuracy=0.336133 | Batch: 40/40 (100%)

Train epoch: 2 | Metrics (epoch mean): cross_entropy=1.535766 <> acc=0.440740 | Batch (size 1

### Asynchronous SGD + Codistillation

Switch to codistillation after 5 epochs of asynchronous SGD and train for 15 epochs total.

In [4]:
%%time

if torch.cuda.is_available():
    torch.cuda.empty_cache()

trainer_kwargs['name'] = 'codist'

dataloader_kwargs['data_parallel'] = True
    
train_kwargs['epochs_sgd'] = 5
train_kwargs['epochs_codist'] = 10
train_kwargs['codist_kwargs'] = {
    # how many steps before re-syncing stale replicas
    'sync_freq': 50,
    # transform to apply to mean replica output
    'transform': 'softmax',
    # when True, prints gather and update param steps
    'debug': False,
}

args = (world_size,
        dataloader_kwargs,
        model_kwargs,
        optimizer_kwargs,
        trainer_kwargs,
        train_kwargs,
        scheduler_kwargs,
        None, # swa_scheduler_kwargs
        seed) # seed on rank i = seed + i

# begin training
mp.spawn(spawn_fn, args=args, nprocs=world_size, join=True)

Rank 1: joined process group on device cuda with backend nccl
Rank 1: torch.manual_seed(2246400)
Using DistributedSampler
Number of training samples: 45000
Number of training batches: 176


Worker 2/2 starting 15-epoch training loop...
Rank 0: joined process group on device cuda with backend nccl
Rank 0: torch.manual_seed(2246400)
Using DistributedSampler
Number of training samples: 45000
Number of training batches: 176


Worker 1/2 starting 15-epoch training loop...
SGD epochs: 5 | Codistillation epochs: 10 | SWA epochs: 0
DistributedDataParallel: False
Stopping accuracy: None

Train epoch: 1 | Metrics (epoch mean): cross_entropy=1.924133 <> acc=0.271895 | Batch (size 100): 176/176 (100%) | Total steps: 176
Rank 0 | Validation mean |  cross_entropy=1.719567 <> accuracy=0.368359 | Batch: 40/40 (100%)
Rank 1 | Validation mean |  cross_entropy=1.836185 <> accuracy=0.301953 | Batch: 40/40 (100%)

Train epoch: 2 | Metrics (epoch mean): cross_entropy=1.537889 <> acc=0.437019 | Batch (size 1

### Asynchronous SGD + Codistillation w/ data partitioning

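As before (5 epochs of asynchronous SGD, then 10 epochs of codistillation), but this time each model replica trains on its own fixed, disjoint half of the training set (rank 0: samples 0–22499, rank 1: samples 22500–44999) instead of on shards drawn by `DistributedSampler`.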

In [5]:
%%time

if torch.cuda.is_available():
    torch.cuda.empty_cache()

trainer_kwargs['name'] = 'codist-split'

# when True, restrict each rank to half of the data
dataloader_kwargs['split_training'] = True
dataloader_kwargs['data_parallel'] = False

train_kwargs['codist_kwargs'] = {
    # how many steps before re-syncing stale replicas
    'sync_freq': 50,
    # transform to apply to mean replica output
    'transform': 'softmax',
    # when True, prints gather and update param steps
    'debug': False,
}

args = (world_size,
        dataloader_kwargs,
        model_kwargs,
        optimizer_kwargs,
        trainer_kwargs,
        train_kwargs,
        scheduler_kwargs,
        None, # swa_scheduler_kwargs
        seed) # seed on rank i = seed + i

# begin training
mp.spawn(spawn_fn, args=args, nprocs=world_size, join=True)

Rank 1: joined process group on device cuda with backend nccl
Rank 1: torch.manual_seed(2246400)
Using SubsetRandomSampler with samples 22500 to 44999
Number of training samples: 45000
Number of training batches: 176


Worker 2/2 starting 15-epoch training loop...
Rank 0: joined process group on device cuda with backend nccl
Rank 0: torch.manual_seed(2246400)
Using SubsetRandomSampler with samples 0 to 22499
Number of training samples: 45000
Number of training batches: 176


Worker 1/2 starting 15-epoch training loop...
SGD epochs: 5 | Codistillation epochs: 10 | SWA epochs: 0
DistributedDataParallel: False
Stopping accuracy: None

Train epoch: 1 | Metrics (epoch mean): cross_entropy=1.958013 <> acc=0.261422 | Batch (size 100): 176/176 (100%) | Total steps: 176
Rank 0 | Validation mean |  cross_entropy=1.710087 <> accuracy=0.352539 | Batch: 40/40 (100%)
Rank 1 | Validation mean |  cross_entropy=1.672574 <> accuracy=0.388281 | Batch: 40/40 (100%)

Train epoch: 2 | Metrics (epoch mean): 

### Data-parallel codistillation only

Now skip the SGD phase and run data-parallel codistillation for all 15 epochs.

In [6]:
%%time

if torch.cuda.is_available():
    torch.cuda.empty_cache()

trainer_kwargs['name'] = 'codist-only'

dataloader_kwargs['split_training'] = False
dataloader_kwargs['data_parallel'] = True

train_kwargs['epochs_sgd'] = 0
train_kwargs['epochs_codist'] = 15
train_kwargs['codist_kwargs'] = {
    # how many steps before re-syncing stale replicas
    'sync_freq': 50,
    # transform to apply to mean replica output
    'transform': 'softmax',
    # when True, prints gather and update param steps
    'debug': False,
}

args = (world_size,
        dataloader_kwargs,
        model_kwargs,
        optimizer_kwargs,
        trainer_kwargs,
        train_kwargs,
        scheduler_kwargs,
        None, # swa_scheduler_kwargs
        seed) # seed on rank i = seed + i

# begin training
mp.spawn(spawn_fn, args=args, nprocs=world_size, join=True)

Rank 1: joined process group on device cuda with backend nccl
Rank 1: torch.manual_seed(2246400)
Using DistributedSampler
Number of training samples: 45000
Number of training batches: 176


Worker 2/2 starting 15-epoch training loop...
Rank 0: joined process group on device cuda with backend nccl
Rank 0: torch.manual_seed(2246400)
Using DistributedSampler
Number of training samples: 45000
Number of training batches: 176


Worker 1/2 starting 15-epoch training loop...
SGD epochs: 0 | Codistillation epochs: 15 | SWA epochs: 0
DistributedDataParallel: False
Stopping accuracy: None

Starting codistillation phase...

Train epoch: 1 | Metrics (epoch mean): cross_entropy=1.994283 <> acc=0.270767 <> codist_loss=2.231593 | Batch (size 100): 176/176 (100%) | Total steps: 176
Rank 0 | Validation mean |  cross_entropy=1.779046 <> accuracy=0.362109 | Batch: 40/40 (100%)
Rank 1 | Validation mean |  cross_entropy=1.739466 <> accuracy=0.365039 | Batch: 40/40 (100%)

Train epoch: 2 | Metrics (epoch mea

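### Codistillation followed by SWA

Run up to 40 epochs of data-parallel codistillation, which is intended to hand over to SWA early if codistillation stalls. Then run 15 epochs of stochastic weight averaging (SWA).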

In [3]:
%%time

if torch.cuda.is_available():
    torch.cuda.empty_cache()

trainer_kwargs['name'] = 'codistswa'

dataloader_kwargs['split_training'] = False
dataloader_kwargs['data_parallel'] = True

train_kwargs['epochs_sgd'] = 0
train_kwargs['epochs_codist'] = 40 # codist should stop and pass to swa if codist stalls
train_kwargs['epochs_swa'] = 15
train_kwargs['codist_kwargs'] = {
    # how many steps before re-syncing stale replicas
    'sync_freq': 50,
    # transform to apply to mean replica output
    'transform': 'softmax',
    # when True, prints gather and update param steps
    'debug': False,
}

args = (world_size,
        dataloader_kwargs,
        model_kwargs,
        optimizer_kwargs,
        trainer_kwargs,
        train_kwargs,
        scheduler_kwargs,
        swa_scheduler_kwargs,
        seed) # seed on rank i = seed + i

# begin training
mp.spawn(spawn_fn, args=args, nprocs=world_size, join=True)

Rank 1: joined process group on device cuda with backend nccl
Rank 1: torch.manual_seed(2246400)
Using DistributedSampler
Number of training samples: 45000
Number of training batches: 176


Worker 2/2 starting 55-epoch training loop...
Rank 0: joined process group on device cuda with backend nccl
Rank 0: torch.manual_seed(2246400)
Using DistributedSampler
Number of training samples: 45000
Number of training batches: 176


Worker 1/2 starting 55-epoch training loop...
SGD epochs: 0 | Codistillation epochs: 40 | SWA epochs: 15
DistributedDataParallel: False
Stopping accuracy: None

Starting codistillation phase...

Train epoch: 1 | Metrics (epoch mean): cross_entropy=1.992147 <> acc=0.271135 <> codist_loss=2.230487 | Batch (size 100): 176/176 (100%) | Total steps: 176
Rank 0 | Validation mean |  cross_entropy=1.763798 <> accuracy=0.366016 | Batch: 40/40 (100%)
Rank 1 | Validation mean |  cross_entropy=1.739883 <> accuracy=0.364844 | Batch: 40/40 (100%)

Train epoch: 2 | Metrics (epoch me