In [None]:
import os
import datetime

import torch
import torch.multiprocessing as mp
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            
from swadist.utils import spawn_fn

# Without this setting, mp.spawn may throw an error.
os.environ['MKL_THREADING_LAYER'] = 'GNU'

# Report the device in use; when CUDA is present, also release cached GPU memory.
if not torch.cuda.is_available():
    print('Using cpu')
else:
    torch.cuda.empty_cache()
    print('Using cuda')

# Seed = number of seconds between the fixed reference date and today's date.
# A difference of two dates is a whole number of days, so the floor division
# by one second is exact and yields a plain int.
seed = (datetime.date.today() - datetime.date(2022, 4, 11)) // datetime.timedelta(seconds=1)
print(f'seed: {seed}')

In [None]:
# number of model replicas (also used as the number of spawned processes)
world_size = 2

# The per-process batch size is the global batch of 256 divided across replicas.
dataloader_kwargs = dict(
    dataset='cifar10',
    batch_size=256 // world_size,
    num_workers=2,
)

model_kwargs = dict(
    n_classes=10,
    in_kernel_size=3,
    stack_sizes=[1, 1, 1],
    batch_norm=False,
)

optimizer_kwargs = dict(
    lr=2 ** -5.0,
    momentum=0.975,
    nesterov=True,
)

trainer_kwargs = dict(
    log=True,  # whether to log training to Tensorboard
    log_dir='./runs-notebooks',
)

# NOTE: the run cells below mutate these dicts in place, so values set by one
# cell remain in effect for every cell executed after it.
train_kwargs = dict(
    epochs_sgd=0,
    swadist=True,
    codist_kwargs=dict(
        sync_freq=50,
        transform='softmax',
        debug=False,
    ),
    swadist_kwargs=dict(
        max_averaged=3,
        sync_freq=50,
        transform='softmax',
        debug=True,
    ),
    save=True,
    save_dir='./_state_dicts',
    # 'stopping_acc': 0.7,
)

scheduler_kwargs = dict(
    alpha=0.25,
    decay_epochs=15,
)

# Disabled SWA scheduler settings (the run cells pass None in their place):
# swa_scheduler_kwargs = {
#     'swa_lr': optimizer_kwargs['lr'] / 10,
#     'anneal_strategy': 'cos',
#     'anneal_epochs': 3,
# }

In [None]:
%%time

trainer_kwargs['name'] = 'swadist-dp'

dataloader_kwargs['data_parallel'] = True
dataloader_kwargs['split_training'] = False

train_kwargs['epochs_sgd'] = 0 # 5 # TODO
train_kwargs['epochs_swa'] = 8

args = (world_size,
        dataloader_kwargs,
        model_kwargs,
        optimizer_kwargs,
        trainer_kwargs,
        train_kwargs,
        scheduler_kwargs,
        None, # swa_scheduler_kwargs,
        seed) # seed on rank i = seed + i

# begin training
mp.spawn(spawn_fn, args=args, nprocs=world_size, join=True)

In [None]:
%%time

trainer_kwargs['name'] = 'swadist-swa-replicas-dp'

dataloader_kwargs['data_parallel'] = True
dataloader_kwargs['split_training'] = False

train_kwargs['epochs_sgd'] = 0
train_kwargs['epochs_swa'] = 15
train_kwargs['swadist_kwargs']['swa_replicas'] = True
train_kwargs['stop_stall_n_epochs'] = 3

args = (world_size,
        dataloader_kwargs,
        model_kwargs,
        optimizer_kwargs,
        trainer_kwargs,
        train_kwargs,
        scheduler_kwargs,
        None, # swa_scheduler_kwargs,
        seed) # seed on rank i = seed + i

# begin training
mp.spawn(spawn_fn, args=args, nprocs=world_size, join=True)

In [None]:
%%time

trainer_kwargs['name'] = 'swadist-split'

dataloader_kwargs['data_parallel'] = False
dataloader_kwargs['split_training'] = True

train_kwargs['epochs_sgd'] = 5
train_kwargs['epochs_swa'] = 20

args = (world_size,
        dataloader_kwargs,
        model_kwargs,
        optimizer_kwargs,
        trainer_kwargs,
        train_kwargs,
        scheduler_kwargs,
        None, # swa_scheduler_kwargs,
        seed) # seed on rank i = seed + i

# begin training
mp.spawn(spawn_fn, args=args, nprocs=world_size, join=True)

In [None]:
%%time

trainer_kwargs['name'] = 'swadist-swa-replicas-split'

dataloader_kwargs['data_parallel'] = False
dataloader_kwargs['split_training'] = True

train_kwargs['epochs_sgd'] = 5
train_kwargs['epochs_swa'] = 20

train_kwargs['swadist_kwargs']['swa_replicas'] = True

args = (world_size,
        dataloader_kwargs,
        model_kwargs,
        optimizer_kwargs,
        trainer_kwargs,
        train_kwargs,
        scheduler_kwargs,
        None, # swa_scheduler_kwargs,
        seed) # seed on rank i = seed + i

# begin training
mp.spawn(spawn_fn, args=args, nprocs=world_size, join=True)

In [None]:
%%time

trainer_kwargs['name'] = 'swadist-dp-no-sgd'

dataloader_kwargs['data_parallel'] = True
dataloader_kwargs['split_training'] = False

train_kwargs['epochs_sgd'] = 0
train_kwargs['epochs_swa'] = 25

args = (world_size,
        dataloader_kwargs,
        model_kwargs,
        optimizer_kwargs,
        trainer_kwargs,
        train_kwargs,
        scheduler_kwargs,
        None, # swa_scheduler_kwargs,
        seed) # seed on rank i = seed + i

# begin training
mp.spawn(spawn_fn, args=args, nprocs=world_size, join=True)

In [None]:
%%time

trainer_kwargs['name'] = 'swadist-swa-replicas-dp-no-sgd'

dataloader_kwargs['data_parallel'] = True
dataloader_kwargs['split_training'] = False

train_kwargs['epochs_sgd'] = 0
train_kwargs['epochs_swa'] = 25

train_kwargs['swadist_kwargs']['swa_replicas'] = True

args = (world_size,
        dataloader_kwargs,
        model_kwargs,
        optimizer_kwargs,
        trainer_kwargs,
        train_kwargs,
        scheduler_kwargs,
        None, # swa_scheduler_kwargs,
        seed) # seed on rank i = seed + i

# begin training
mp.spawn(spawn_fn, args=args, nprocs=world_size, join=True)

In [None]:
%%time

trainer_kwargs['name'] = 'swadist-split-no-sgd'

dataloader_kwargs['data_parallel'] = False
dataloader_kwargs['split_training'] = True

train_kwargs['epochs_sgd'] = 0
train_kwargs['epochs_swa'] = 25

args = (world_size,
        dataloader_kwargs,
        model_kwargs,
        optimizer_kwargs,
        trainer_kwargs,
        train_kwargs,
        scheduler_kwargs,
        None, # swa_scheduler_kwargs,
        seed) # seed on rank i = seed + i

# begin training
mp.spawn(spawn_fn, args=args, nprocs=world_size, join=True)

In [None]:
%%time

trainer_kwargs['name'] = 'swadist-swa-replicas-split-no-sgd'

dataloader_kwargs['data_parallel'] = False
dataloader_kwargs['split_training'] = True

train_kwargs['epochs_sgd'] = 0
train_kwargs['epochs_swa'] = 25

train_kwargs['swadist_kwargs']['swa_replicas'] = True

args = (world_size,
        dataloader_kwargs,
        model_kwargs,
        optimizer_kwargs,
        trainer_kwargs,
        train_kwargs,
        scheduler_kwargs,
        None, # swa_scheduler_kwargs,
        seed) # seed on rank i = seed + i

# begin training
mp.spawn(spawn_fn, args=args, nprocs=world_size, join=True)

In [None]:
%%time

trainer_kwargs['name'] = 'swadist-ablate-codist-dp'

dataloader_kwargs['data_parallel'] = True
dataloader_kwargs['split_training'] = False

train_kwargs['epochs_sgd'] = 0
train_kwargs['epochs_swa'] = 15
train_kwargs['codist_kwargs']['sync_freq'] = 0
train_kwargs['swadist_kwargs']['swa_replicas'] = False

args = (world_size,
        dataloader_kwargs,
        model_kwargs,
        optimizer_kwargs,
        trainer_kwargs,
        train_kwargs,
        scheduler_kwargs,
        None, # swa_scheduler_kwargs,
        seed) # seed on rank i = seed + i

# begin training
mp.spawn(spawn_fn, args=args, nprocs=world_size, join=True)

In [None]:
%%time

trainer_kwargs['name'] = 'swadist-swa-replicas-ablate-codist-dp'

dataloader_kwargs['data_parallel'] = True
dataloader_kwargs['split_training'] = False

train_kwargs['epochs_sgd'] = 0
train_kwargs['epochs_swa'] = 15
train_kwargs['codist_kwargs']['sync_freq'] = 0
train_kwargs['swadist_kwargs']['swa_replicas'] = True

args = (world_size,
        dataloader_kwargs,
        model_kwargs,
        optimizer_kwargs,
        trainer_kwargs,
        train_kwargs,
        scheduler_kwargs,
        None, # swa_scheduler_kwargs,
        seed) # seed on rank i = seed + i

# begin training
mp.spawn(spawn_fn, args=args, nprocs=world_size, join=True)