In [1]:
import os
import datetime

import torch
import torch.multiprocessing as mp
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            
from swadist.utils import spawn_fn

# mp.spawn may throw an error without this
os.environ['MKL_THREADING_LAYER'] = 'GNU'

if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print('Using cuda')
else:
    print('Using cpu')

seed = int((datetime.date.today() - datetime.date(2022, 4, 11)).total_seconds())
print(f'seed: {seed}')

Using cuda
seed: 2764800


### Common arguments

In [2]:
# number of model replicas
world_size = 2

dataloader_kwargs = {
    'dataset': 'cifar10',
    # size of training batches on each rank
    'batch_size': 256 // world_size,
    # cpu processes per rank for data loading
    'num_workers': 1,
    'pin_memory': True,
}

model_kwargs = {
    'n_classes': 10,
    'in_kernel_size': 3,
    'stack_sizes': [1, 1, 1],
    'batch_norm': False,
}

optimizer_kwargs = {
    'lr': 2**-5.,
    'momentum': 0.975,
    'nesterov': True,
}

trainer_kwargs = {
    # whether to log training to Tensorboard
    'log': False,
    'log_dir': './results/tensorboard/03-codist/2022-05-12-commit-5c8bd29'
}

train_kwargs = {
    'codist_kwargs': {
        'transform': 'softmax',
        'debug': False,
    },
}

scheduler_kwargs = {
    'alpha': 0.25,
    'decay_epochs': 15,
}

swa_scheduler_kwargs = {
    'swa_lr':  optimizer_kwargs['lr'] / 10,
}

## Contents

- **Training variations**
    - Asynchronous SGD baseline
      - Data-parallel sampling
      - Training set split sampling
    - Codistillation
      - Data-parallel sampling
      - Training set split sampling
- **Hyperparam exploration** (along each dim independently) 
    - `epochs_sgd`: 1, 2, 4, 7, 9, 10, 11, 12 (on best training variations)
    - `epochs_swa`: 2, 3, 4, 5 (`epochs_sgd` of 0 and best non-zero)
    - `sync_freq`: 25, 50, 75, 100, 150, 200, 300, 400 (`epochs_sgd` and `epochs_swa` of 0, and best non-zero)

## Training variations

### Asynchronous SGD baseline

In [3]:
%%time

train_kwargs['epochs_sgd'] = 15
train_kwargs['epochs_codist'] = 0
train_kwargs['epochs_swa'] = 0

for split_training in [False, True]:

    dataloader_kwargs['split_training'] = split_training
    dataloader_kwargs['data_parallel'] = not split_training
    
    trainer_kwargs['name'] = f'async-sgd-split_training={split_training}'
    
    args = (world_size,
            dataloader_kwargs,
            model_kwargs,
            optimizer_kwargs,
            trainer_kwargs,
            train_kwargs,
            scheduler_kwargs,
            None, # swa_scheduler_kwargs,
            seed, # seed on rank i = seed + i
            False) # ddp=False, don't use DistributedDataParallel

    # begin training
    mp.spawn(spawn_fn, args=args, nprocs=world_size, join=True)

Rank 1: joined process group on device cuda with backend nccl
Rank 1: torch.manual_seed(2764800)
Files already downloaded and verified
Using DistributedSampler
Number of training samples: 45000
Number of training batches: 176


Worker 2/2 starting 15-epoch training loop...
Rank 0: joined process group on device cuda with backend nccl
Rank 0: torch.manual_seed(2764800)
Files already downloaded and verified
Using DistributedSampler
Number of training samples: 45000
Number of training batches: 176


Starting training from spawn_fn...
Worker 1/2 starting 15-epoch training loop...
SGD epochs: 15 | Codistillation epochs: 0 | SWA epochs: 0
DistributedDataParallel: False
Stopping accuracy: None
Global batch size: 256

Train epoch: 1 | Metrics (epoch mean): cross_entropy=1.925734 <> acc=0.278153 | Batch: 176/176 (100%) | Total steps: 176
Validation (batch mean) |  cross_entropy=1.681352 <> accuracy=0.363867 | Batch: 40/40 (100%)

Train epoch: 2 | Metrics (epoch mean): cross_entropy=1.515177 <> 

### Codistillation

In [4]:
%%time

if torch.cuda.is_available():
    torch.cuda.empty_cache()
    
train_kwargs['epochs_sgd'] = 0
train_kwargs['epochs_codist'] = 15
train_kwargs['epochs_swa'] = 0

train_kwargs['codist_kwargs']['sync_freq'] = 50

for split_training in [False, True]:
    
    dataloader_kwargs['split_training'] = split_training
    dataloader_kwargs['data_parallel'] = not split_training

    trainer_kwargs['name'] = f'codist-epochs_sgd=0_sync_freq=50-split_training={split_training}'

    args = (world_size,
            dataloader_kwargs,
            model_kwargs,
            optimizer_kwargs,
            trainer_kwargs,
            train_kwargs,
            scheduler_kwargs,
            None, # swa_scheduler_kwargs,
            seed) # seed on rank i = seed + i

    # begin training
    mp.spawn(spawn_fn, args=args, nprocs=world_size, join=True)

Starting training from spawn_fn...
Rank 0: joined process group on device cuda with backend nccl
Rank 0: torch.manual_seed(2678400)
Files already downloaded and verified
Using DistributedSampler
Number of training samples: 45000
Number of training batches: 176


Worker 1/2 starting 15-epoch training loop...
SGD epochs: 0 | Codistillation epochs: 15 | SWA epochs: 0
DistributedDataParallel: False
Stopping accuracy: None
Global batch size: 256

Starting codistillation phase...
loss_fn: cross_entropy
sync_freq: 50
transform: softmax

Rank 1: joined process group on device cuda with backend nccl
Rank 1: torch.manual_seed(2678400)
Files already downloaded and verified
Using DistributedSampler
Number of training samples: 45000
Number of training batches: 176


Worker 2/2 starting 15-epoch training loop...
Train epoch: 1 | Metrics (epoch mean): cross_entropy=2.028972 <> acc=0.257924 <> codist_loss=2.238378 | Batch: 176/176 (100%) | Total steps: 176
Validation (batch mean) |  cross_entropy=1.78

## Hyperparam exploration

### `epochs_sgd`

In [5]:
%%time

if torch.cuda.is_available():
    torch.cuda.empty_cache()

dataloader_kwargs['data_parallel'] = True
dataloader_kwargs['split_training'] = False

train_kwargs['epochs_swa'] = 0

train_kwargs['codist_kwargs']['sync_freq'] = 50

for epochs_sgd in [1, 2, 4, 7]:
    
    train_kwargs['epochs_sgd'] = epochs_sgd
    train_kwargs['epochs_codist'] = 15 - epochs_sgd

    trainer_kwargs['name'] = f'codist-epochs_sgd={epochs_sgd}_sync_freq=50-split_training=False'

    args = (world_size,
            dataloader_kwargs,
            model_kwargs,
            optimizer_kwargs,
            trainer_kwargs,
            train_kwargs,
            scheduler_kwargs,
            None, # swa_scheduler_kwargs,
            seed) # seed on rank i = seed + i

    # begin training
    mp.spawn(spawn_fn, args=args, nprocs=world_size, join=True)

Starting training from spawn_fn...
Rank 0: joined process group on device cuda with backend nccl
Rank 0: torch.manual_seed(2678400)
Files already downloaded and verified
Using DistributedSampler
Number of training samples: 45000
Number of training batches: 176


Worker 1/2 starting 15-epoch training loop...
SGD epochs: 1 | Codistillation epochs: 14 | SWA epochs: 0
DistributedDataParallel: False
Stopping accuracy: None
Global batch size: 256

Rank 1: joined process group on device cuda with backend nccl
Rank 1: torch.manual_seed(2678400)
Files already downloaded and verified
Using DistributedSampler
Number of training samples: 45000
Number of training batches: 176


Worker 2/2 starting 15-epoch training loop...
Train epoch: 1 | Metrics (epoch mean): cross_entropy=1.943864 <> acc=0.263690 | Batch: 176/176 (100%) | Total steps: 176
Validation (batch mean) |  cross_entropy=1.674705 <> accuracy=0.379297 | Batch: 40/40 (100%)

Starting codistillation phase...
loss_fn: cross_entropy
sync_freq

In [6]:
%%time

if torch.cuda.is_available():
    torch.cuda.empty_cache()

dataloader_kwargs['data_parallel'] = True
dataloader_kwargs['split_training'] = False

train_kwargs['epochs_swa'] = 0

train_kwargs['codist_kwargs']['sync_freq'] = 50

for epochs_sgd in [9, 10, 11, 12]:
    
    train_kwargs['epochs_sgd'] = epochs_sgd
    train_kwargs['epochs_codist'] = 15 - epochs_sgd

    trainer_kwargs['name'] = f'codist-epochs_sgd={epochs_sgd}_sync_freq=50-split_training=False'

    args = (world_size,
            dataloader_kwargs,
            model_kwargs,
            optimizer_kwargs,
            trainer_kwargs,
            train_kwargs,
            scheduler_kwargs,
            None, # swa_scheduler_kwargs,
            seed) # seed on rank i = seed + i

    # begin training
    mp.spawn(spawn_fn, args=args, nprocs=world_size, join=True)

Starting training from spawn_fn...
Rank 0: joined process group on device cuda with backend nccl
Rank 0: torch.manual_seed(2678400)
Files already downloaded and verified
Using DistributedSampler
Number of training samples: 45000
Number of training batches: 176


Worker 1/2 starting 15-epoch training loop...
SGD epochs: 9 | Codistillation epochs: 6 | SWA epochs: 0
DistributedDataParallel: False
Stopping accuracy: None
Global batch size: 256

Rank 1: joined process group on device cuda with backend nccl
Rank 1: torch.manual_seed(2678400)
Files already downloaded and verified
Using DistributedSampler
Number of training samples: 45000
Number of training batches: 176


Worker 2/2 starting 15-epoch training loop...
Train epoch: 1 | Metrics (epoch mean): cross_entropy=1.944476 <> acc=0.262960 | Batch: 176/176 (100%) | Total steps: 176
Validation (batch mean) |  cross_entropy=1.681966 <> accuracy=0.372070 | Batch: 40/40 (100%)

Train epoch: 2 | Metrics (epoch mean): cross_entropy=1.527985 <> a

### `epochs_swa`

In [7]:
%%time

if torch.cuda.is_available():
    torch.cuda.empty_cache()

dataloader_kwargs['data_parallel'] = True
dataloader_kwargs['split_training'] = False

train_kwargs['codist_kwargs']['sync_freq'] = 50

for epochs_sgd, epochs_swa in [(0, 2), (0, 3), (0, 4), (0, 5),
                               (11, 2), (11, 3)]:

    train_kwargs['epochs_sgd'] = epochs_sgd
    train_kwargs['epochs_codist'] = 15 - epochs_swa - epochs_sgd
    train_kwargs['epochs_swa'] = epochs_swa

    trainer_kwargs['name'] = f'codist-epochs_sgd={epochs_sgd}_epochs_swa={epochs_swa}-sync_freq=50-split_training=False'

    args = (world_size,
            dataloader_kwargs,
            model_kwargs,
            optimizer_kwargs,
            trainer_kwargs,
            train_kwargs,
            scheduler_kwargs,
            swa_scheduler_kwargs,
            seed) # seed on rank i = seed + i

    # begin training
    mp.spawn(spawn_fn, args=args, nprocs=world_size, join=True)

Starting training from spawn_fn...
Rank 0: joined process group on device cuda with backend nccl
Rank 0: torch.manual_seed(2678400)
Files already downloaded and verified
Using DistributedSampler
Number of training samples: 45000
Number of training batches: 176


Worker 1/2 starting 15-epoch training loop...
SGD epochs: 0 | Codistillation epochs: 13 | SWA epochs: 2
DistributedDataParallel: False
Stopping accuracy: None
Global batch size: 256

Starting codistillation phase...
loss_fn: cross_entropy
sync_freq: 50
transform: softmax

Rank 1: joined process group on device cuda with backend nccl
Rank 1: torch.manual_seed(2678400)
Files already downloaded and verified
Using DistributedSampler
Number of training samples: 45000
Number of training batches: 176


Worker 2/2 starting 15-epoch training loop...
Train epoch: 1 | Metrics (epoch mean): cross_entropy=2.028878 <> acc=0.258400 <> codist_loss=2.238627 | Batch: 176/176 (100%) | Total steps: 176
Validation (batch mean) |  cross_entropy=1.77

### `sync_freq`

In [8]:
%%time

if torch.cuda.is_available():
    torch.cuda.empty_cache()

dataloader_kwargs['data_parallel'] = True
dataloader_kwargs['split_training'] = False

for epochs_sgd, epochs_swa in [(0, 0), (11, 0), (11, 3)]:
    
    for sync_freq in [25, 75, 100, 150, 200]:
        
        epochs_codist = 15 - epochs_swa - epochs_sgd
        
        # skip redundant sync_freq where only initial sync occurs
        if epochs_codist == 1 and sync_freq in [300, 400]:
            continue

        train_kwargs['epochs_sgd'] = epochs_sgd
        train_kwargs['epochs_codist'] = epochs_codist
        train_kwargs['epochs_swa'] = epochs_swa

        train_kwargs['codist_kwargs']['sync_freq'] = sync_freq

        trainer_kwargs['name'] = f'codist-sync-epochs_sgd={epochs_sgd}_epochs_swa={epochs_swa}-sync_freq={sync_freq}-split_training=False'

        args = (world_size,
                dataloader_kwargs,
                model_kwargs,
                optimizer_kwargs,
                trainer_kwargs,
                train_kwargs,
                scheduler_kwargs,
                None if epochs_swa == 0 else swa_scheduler_kwargs,
                seed) # seed on rank i = seed + i

        # begin training
        mp.spawn(spawn_fn, args=args, nprocs=world_size, join=True)

Starting training from spawn_fn...
Rank 0: joined process group on device cuda with backend nccl
Rank 0: torch.manual_seed(2678400)
Files already downloaded and verified
Using DistributedSampler
Number of training samples: 45000
Number of training batches: 176


Worker 1/2 starting 15-epoch training loop...
SGD epochs: 0 | Codistillation epochs: 15 | SWA epochs: 0
DistributedDataParallel: False
Stopping accuracy: None
Global batch size: 256

Starting codistillation phase...
loss_fn: cross_entropy
sync_freq: 25
transform: softmax

Rank 1: joined process group on device cuda with backend nccl
Rank 1: torch.manual_seed(2678400)
Files already downloaded and verified
Using DistributedSampler
Number of training samples: 45000
Number of training batches: 176


Worker 2/2 starting 15-epoch training loop...
Train epoch: 1 | Metrics (epoch mean): cross_entropy=2.023923 <> acc=0.253398 <> codist_loss=2.165242 | Batch: 176/176 (100%) | Total steps: 176
Validation (batch mean) |  cross_entropy=1.78