## Setup

In [None]:
# check GPU: apex is built from source below and later cells move models/tensors
# onto the GPU (e.g. `.cuda()`), so confirm a GPU is attached before continuing
!nvidia-smi

In [None]:
# clone apex (provides apex.contrib.sparsity.ASP, imported in the training-config section)
# NOTE(review): this clones the default branch unpinned; pin a commit for reproducible installs
!git clone https://github.com/NVIDIA/apex

In [None]:
# install apex from the local clone, passing the --permutation_search build option
# (you need to comment out the version check in setup.py first)
# NOTE(review): presumably this is apex's CUDA-vs-PyTorch version check; confirm which check and why
# NOTE(review): recent pip releases deprecate --global-option; if this fails, check apex's README for the current install command
!cd apex && pip install -v --disable-pip-version-check --no-cache-dir --global-option="--permutation_search" ./

In [None]:
# reload modules in .py files: with autoreload mode 2, edits to project modules
# (e.g. the training package imported below) are picked up before each cell runs
%load_ext autoreload
%autoreload 2

In [None]:
# pull repo containing the project code (the `training` package imported below)
!git clone https://github.com/char-tan/sparsity

In [None]:
# change working directory, make dir for models
import os
os.chdir('sparsity')
os.makedirs('models', exist_ok=True)

In [None]:
# checkout branch: runs inside the cloned repo (after the chdir above) and must
# happen before the `training` modules are imported in the next section
!git checkout ct_dev

## Training config

In [None]:
import torch

from training.training import *
from training.utils import *

from apex.contrib.sparsity import ASP

In [None]:
# Run configuration; num_epochs is presumably the per-phase epoch count used by
# train_phase (NOTE(review): confirm in training.utils / training.training).
config = Config(num_epochs=2)

# Seed before constructing the network so the initial weights are reproducible.
torch.manual_seed(config.seed)

# Build the network on the configured device and snapshot its untrained weights
# to models/init.pt.
model = resnet18_small_input().to(config.device)
torch.save(model.state_dict(), 'models/init.pt')

# SGD with momentum and weight decay; every hyperparameter comes from config.
optimizer = torch.optim.SGD(model.parameters(), lr=config.lr, momentum=config.momentum, weight_decay=config.weight_decay)

train_loader, test_loader = cifar10_dataloaders(config)

## Phase 1 training

In [None]:
# Dense training from the seeded initialisation, then checkpoint the trained weights.
train_phase(model, optimizer, train_loader, test_loader, config)
torch.save(obj=model.state_dict(), f='models/phase1.pt')

## Prune model, evaluate after pruning

In [None]:
# Prune the trained model with apex ASP and apply the masks so that pruned
# parameters stay zero during any further training with this optimizer.
# NOTE(review): presumably ASP does this by registering mask buffers and patching
# optimizer.step, and keeps class-level state so the cell cannot be re-run in the
# same kernel; confirm against the installed apex version.
ASP.prune_trained_model(model, optimizer)

torch.save(obj=model.state_dict(), f='models/phase1_pruned.pt')

In [None]:
# Measure the pruned model before any fine-tuning, on both the train and the
# test split, and report all four numbers as a single summary.
train_loss, train_acc = test_epoch(model, train_loader, config.device)
test_loss, test_acc = test_epoch(model, test_loader, config.device)

epoch_summary(dict([
    ("train loss", train_loss),
    ("train acc", train_acc),
    ("test loss", test_loss),
    ("test acc", test_acc),
]))

## Phase 2 training

In [None]:
# Fine-tune the pruned model with the same optimizer that was handed to ASP,
# then checkpoint the result.
train_phase(model, optimizer, train_loader, test_loader, config)
torch.save(obj=model.state_dict(), f='models/phase2.pt')

## Train from original init with mask (lottery ticket hypothesis, LTH)

In [None]:
# Rewind to the original initialisation: load the weights saved in models/init.pt
# (the path the config cell writes to), apply the current pruning mask to them,
# and load them into the model.
# strict=False tolerates keys present in the model but absent from the init
# checkpoint. NOTE(review): presumably ASP's mask buffers; confirm nothing else is
# silently skipped.
init_state = torch.load('models/init.pt', map_location=config.device)
model.load_state_dict(mask_checkpoint(init_state, model), strict=False)

torch.save(model.state_dict(), 'models/init_pruned.pt')

# Retraining from init should not inherit the momentum buffers (or any lr changes)
# left over from phase 2. Reset the optimizer in place instead of building a new one,
# so that the optimizer object previously passed to ASP is kept.
optimizer.state.clear()
for param_group in optimizer.param_groups:
    param_group['lr'] = config.lr

train_phase(model, optimizer, train_loader, test_loader, config)

torch.save(model.state_dict(), 'models/lottery_ticket.pt')

## Train from a new random init with the same mask (random-reinit baseline)

In [None]:
# Use a different seed so the new initialisation differs from models/init.pt.
torch.manual_seed(config.seed + 1)

# produce new initialisation on the configured device (consistent with the rest
# of the notebook, rather than hard-coding .cuda())
new_init_params = resnet18_small_input().to(config.device).state_dict()

torch.save(new_init_params, 'models/new_init.pt')

# apply the current mask to the new params, then load them into the model
# (strict=False: see NOTE in the lottery-ticket cell about tolerated key mismatches)
model.load_state_dict(mask_checkpoint(new_init_params, model), strict=False)

torch.save(model.state_dict(), 'models/new_init_pruned.pt')

# Start from a fresh optimizer state (no momentum or lr carried over from the
# previous run), keeping the same optimizer object that was passed to ASP.
optimizer.state.clear()
for param_group in optimizer.param_groups:
    param_group['lr'] = config.lr

train_phase(model, optimizer, train_loader, test_loader, config)

torch.save(model.state_dict(), 'models/random_lottery_ticket.pt')