## Setup

In [1]:
# check GPU
!nvidia-smi

Thu Jan 12 01:40:58 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  A100-SXM4-40GB      Off  | 00000000:00:04.0 Off |                    0 |
| N/A   31C    P0    41W / 400W |      0MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces
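
The same GPU check can be done programmatically, so later cells can branch on the result. This is a minimal sketch using only the standard library and the `nvidia-smi` query flags; the helper names `parse_gpu_csv` and `query_gpus` are made up for illustration and are not part of the notebook or the repo.

```python
import shutil
import subprocess


def parse_gpu_csv(text):
    """Parse 'name, memory' CSV rows into a list of (name, memory) tuples."""
    gpus = []
    for line in text.strip().splitlines():
        if "," not in line:
            continue  # skip anything that is not a 'name, memory' row
        name, memory = (field.strip() for field in line.split(",", 1))
        gpus.append((name, memory))
    return gpus


def query_gpus():
    """Return the GPUs reported by nvidia-smi, or [] if none can be queried."""
    if shutil.which("nvidia-smi") is None:
        return []  # tool not installed, so assume no usable GPU
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader"],
            capture_output=True,
            text=True,
            check=False,
            timeout=30,
        )
    except (OSError, subprocess.TimeoutExpired):
        return []
    if result.returncode != 0:
        return []
    return parse_gpu_csv(result.stdout)


gpus = query_gpus()
print(gpus if gpus else "no GPU visible to nvidia-smi")
```

Keeping the parsing separate from the subprocess call makes the parser easy to check on a machine without a GPU.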

In [2]:
# clone apex
!git clone https://github.com/NVIDIA/apex

Cloning into 'apex'...
remote: Enumerating objects: 10686, done.
remote: Counting objects: 100% (208/208), done.
remote: Compressing objects: 100% (146/146), done.
remote: Total 10686 (delta 120), reused 119 (delta 62), pack-reused 10478
Receiving objects: 100% (10686/10686), 15.22 MiB | 32.67 MiB/s, done.
Resolving deltas: 100% (7348/7348), done.


In [3]:
# install apex
!cd apex && pip install -v --disable-pip-version-check --no-cache-dir --global-option="--permutation_search" ./

Using pip 22.0.4 from /usr/local/lib/python3.8/dist-packages/pip (python 3.8)
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Processing /content/apex
  Running command python setup.py egg_info


  torch.__version__  = 1.13.1+cu116


  running egg_info
  creating /tmp/pip-pip-egg-info-09sxrs62/apex.egg-info
  writing /tmp/pip-pip-egg-info-09sxrs62/apex.egg-info/PKG-INFO
  writing dependency_links to /tmp/pip-pip-egg-info-09sxrs62/apex.egg-info/dependency_links.txt
  writing requirements to /tmp/pip-pip-egg-info-09sxrs62/apex.egg-info/requires.txt
  writing top-level names to /tmp/pip-pip-egg-info-09sxrs62/apex.egg-info/top_level.txt
  writing manifest file '/tmp/pip-pip-egg-info-09sxrs62/apex.egg-info/SOURCES.txt'
  adding license file 'LICENSE'
  writing manifest file '/tmp/pip-pip-egg-info-09sxrs62/apex.egg-info/SOURCES.txt'
  Preparing metadata (setup.py) ... done
Skipping wheel build for apex, due to binaries being 

In [4]:
# reload modules in .py files
%load_ext autoreload
%autoreload 2

In [5]:
# pull repo
!git clone https://github.com/char-tan/sparsity

Cloning into 'sparsity'...
remote: Enumerating objects: 158, done.
remote: Counting objects: 100% (158/158), done.
remote: Compressing objects: 100% (87/87), done.
remote: Total 158 (delta 74), reused 138 (delta 57), pack-reused 0
Receiving objects: 100% (158/158), 59.31 KiB | 5.93 MiB/s, done.
Resolving deltas: 100% (74/74), done.


In [6]:
# change working directory, make dir for models
import os

os.chdir("sparsity")
os.makedirs("models", exist_ok=True)

In [7]:
# checkout branch
!git checkout ct_dev

Branch 'ct_dev' set up to track remote branch 'ct_dev' from 'origin'.
Switched to a new branch 'ct_dev'


## Training config

In [8]:
import torch

from training.training import *
from training.utils import *

from apex.contrib.sparsity import ASP

Found permutation search CUDA kernels
[ASP][Info] permutation_search_kernels can be imported.


In [9]:
config = Config(num_epochs=20)

torch.manual_seed(config.seed)

model = resnet18_small_input().to(config.device)

torch.save(model.state_dict(), "models/init.pt")

optimizer = torch.optim.SGD(
    model.parameters(),
    lr=config.lr,
    momentum=config.momentum,
    weight_decay=config.weight_decay,
)

train_loader, test_loader = cifar10_dataloaders(config)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting data/cifar-10-python.tar.gz to data


## Phase 1 training

In [10]:
train_phase(model, optimizer, train_loader, test_loader, config)

torch.save(model.state_dict(), "models/phase1.pt")

epoch: 0 | time: 19.26 | train loss: 1.949 | train acc: 26.81 | test loss: 1.836 | test acc: 32.52 | 
epoch: 1 | time: 10.08 | train loss: 1.502 | train acc: 44.2 | test loss: 1.398 | test acc: 47.91 | 
epoch: 2 | time: 10.09 | train loss: 1.176 | train acc: 57.27 | test loss: 1.319 | test acc: 54.02 | 
epoch: 3 | time: 10.17 | train loss: 0.9565 | train acc: 65.59 | test loss: 1.335 | test acc: 58.1 | 
epoch: 4 | time: 10.55 | train loss: 0.8337 | train acc: 70.08 | test loss: 0.8892 | test acc: 69.19 | 
epoch: 5 | time: 10.28 | train loss: 0.7232 | train acc: 74.39 | test loss: 0.7767 | test acc: 72.95 | 
epoch: 6 | time: 10.14 | train loss: 0.6366 | train acc: 77.7 | test loss: 0.7102 | test acc: 74.39 | 
epoch: 7 | time: 10.18 | train loss: 0.5655 | train acc: 80.35 | test loss: 0.6108 | test acc: 79.52 | 
epoch: 8 | time: 10.36 | train loss: 0.5133 | train acc: 82.26 | test loss: 0.7268 | test acc: 76.55 | 
epoch: 9 | time: 10.28 | train loss: 0.4602 | train acc: 83.94 | test loss
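
The per-epoch lines printed by `train_phase` follow a `key: value | key: value | ` layout, so they can be parsed back into numbers to compare the runs in this notebook. The sketch below assumes only that layout, as seen in the output above. `parse_summary` and `best_epoch` are hypothetical helpers, not functions from the repo.

```python
def parse_summary(line):
    """Parse 'key: value | key: value | ' into a dict of floats.

    Fields that are cut off or non-numeric (e.g. a truncated last field)
    are skipped rather than raising.
    """
    metrics = {}
    for field in line.split("|"):
        field = field.strip()
        if not field:
            continue
        key, _, value = field.rpartition(":")
        try:
            metrics[key.strip()] = float(value)
        except ValueError:
            continue  # incomplete field such as 'test loss' with no value
    return metrics


def best_epoch(lines, metric="test acc"):
    """Return the parsed row with the highest value of `metric`."""
    rows = [parse_summary(line) for line in lines]
    rows = [row for row in rows if metric in row]
    return max(rows, key=lambda row: row[metric])


log = [
    "epoch: 0 | time: 19.26 | train loss: 1.949 | train acc: 26.81 | test loss: 1.836 | test acc: 32.52 | ",
    "epoch: 1 | time: 10.08 | train loss: 1.502 | train acc: 44.2 | test loss: 1.398 | test acc: 47.91 | ",
    "epoch: 2 | time: 10.09 | train loss: 1.176 | train acc: 57.27 | test loss: 1.319 | test acc: 54.02 | ",
]
print(best_epoch(log))
```

Skipping incomplete fields means a line that was cut off mid-output still yields the metrics that were printed in full.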

## Prune model, evaluate after pruning

In [11]:
# prune the model and attach the masks so that pruned params stay zeroed in later training
ASP.prune_trained_model(model, optimizer)

torch.save(model.state_dict(), "models/phase1_pruned.pt")

[ASP] torchvision is imported, can work with the MaskRCNN/KeypointRCNN from torchvision.

[set_permutation_params_from_asp] Set permutation needed parameters

[set_identical_seed] Set the identical seed: 1 for all GPUs to make sure the same results generated in permutation search
[ASP] Auto skipping pruning conv1::weight of size=torch.Size([64, 3, 3, 3]) and type=torch.float32 for sparsity
[ASP] Auto skipping pruning fc::weight of size=torch.Size([10, 512]) and type=torch.float32 for sparsity

[build_offline_permutation_graph] Further refine the model graph built by Torch.FX for offline permutation
[build_fx_graph] The torch version is: 1.13.1+cu116, version major is: 1, version minor is: 13, version minimum is: 1+cu116
[build_fx_graph] The Torch.FX is supported.

[build_fx_graph] Print the model structure with pure PyTorch function
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=Tr
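
To make the pruning step concrete: to my understanding, `ASP.prune_trained_model` defaults to a 2:4 pattern, keeping the two largest-magnitude weights in every group of four. The sketch below shows that rule on a flat list in pure Python. It is an illustration only, not ASP's implementation: it ignores the permutation search, the way ASP groups convolution weights, and the layers that ASP skips (such as `conv1` and `fc` in the log above). `two_four_mask` is a hypothetical name.

```python
def two_four_mask(row):
    """Binary mask keeping the 2 largest-magnitude values in each group of 4."""
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    mask = []
    for start in range(0, len(row), 4):
        group = row[start:start + 4]
        # indices of the two entries with the largest absolute value
        keep = sorted(range(4), key=lambda i: abs(group[i]), reverse=True)[:2]
        mask.extend(1 if i in keep else 0 for i in range(4))
    return mask


weights = [0.9, -0.1, 0.05, -0.7, 0.2, 0.3, -0.4, 0.0]
mask = two_four_mask(weights)
pruned = [w * m for w, m in zip(weights, mask)]
print(mask)  # [1, 0, 0, 1, 0, 1, 1, 0]
```

Every group of four ends up with exactly two surviving weights, so the pruned layers are 50% sparse by construction.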

In [12]:
# evaluate on train + test data
train_loss, train_acc = test_epoch(model, train_loader, config.device)
test_loss, test_acc = test_epoch(model, test_loader, config.device)

epoch_summary(
    {
        "train loss": train_loss,
        "train acc": train_acc,
        "test loss": test_loss,
        "test acc": test_acc,
    }
)

train loss: 0.5173 | train acc: 81.64 | test loss: 0.6731 | test acc: 78.19 | 
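
Alongside the accuracy check, it can be worth confirming that the pruned weights really follow the expected pattern. The sketch below works on plain lists. In the notebook one would presumably pass in flattened weights, for example `param.flatten().tolist()` for each pruned parameter. Both function names are hypothetical, and the check assumes the groups of four are contiguous in the flattened order, which may not match ASP's layout for convolution weights.

```python
def sparsity(values):
    """Fraction of entries that are exactly zero."""
    values = list(values)
    if not values:
        raise ValueError("cannot compute sparsity of an empty sequence")
    return sum(1 for v in values if v == 0) / len(values)


def satisfies_two_four(row):
    """True if every contiguous group of 4 contains at least 2 zeros."""
    if len(row) % 4 != 0:
        return False
    return all(
        sum(1 for v in row[i:i + 4] if v == 0) >= 2
        for i in range(0, len(row), 4)
    )


pruned_row = [0.9, 0.0, 0.0, -0.7, 0.0, 0.3, -0.4, 0.0]
dense_row = [0.9, 0.1, 0.0, 0.7]
print(sparsity(pruned_row), satisfies_two_four(pruned_row))
print(sparsity(dense_row), satisfies_two_four(dense_row))
```

Running the same check again after phase 2 training would show whether the masks kept the pruned weights at zero.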


## Phase 2 training

In [13]:
train_phase(model, optimizer, train_loader, test_loader, config)

torch.save(model.state_dict(), "models/phase2.pt")

epoch: 0 | time: 10.3 | train loss: 0.2686 | train acc: 90.54 | test loss: 0.6183 | test acc: 80.48 | 
epoch: 1 | time: 10.37 | train loss: 0.3285 | train acc: 88.59 | test loss: 0.4762 | test acc: 84.07 | 
epoch: 2 | time: 10.37 | train loss: 0.3081 | train acc: 89.18 | test loss: 0.5399 | test acc: 83.15 | 
epoch: 3 | time: 10.59 | train loss: 0.2935 | train acc: 89.86 | test loss: 0.5265 | test acc: 83.6 | 
epoch: 4 | time: 10.34 | train loss: 0.2807 | train acc: 90.17 | test loss: 0.5158 | test acc: 83.87 | 
epoch: 5 | time: 10.44 | train loss: 0.2684 | train acc: 90.64 | test loss: 0.4854 | test acc: 84.56 | 
epoch: 6 | time: 10.27 | train loss: 0.2559 | train acc: 91.16 | test loss: 0.4198 | test acc: 86.19 | 
epoch: 7 | time: 10.13 | train loss: 0.2502 | train acc: 91.3 | test loss: 0.4047 | test acc: 87.23 | 
epoch: 8 | time: 10.18 | train loss: 0.2317 | train acc: 91.93 | test loss: 0.4608 | test acc: 85.73 | 
epoch: 9 | time: 10.47 | train loss: 0.2242 | train acc: 92.14 | te

## Train from original init with mask (LTH)

In [15]:
# apply mask to init params then load into model
model.load_state_dict(
    mask_checkpoint(torch.load("models/init.pt"), model), strict=False
)

torch.save(model.state_dict(), "models/init_pruned.pt")

train_phase(model, optimizer, train_loader, test_loader, config)

torch.save(model.state_dict(), "models/lottery_ticket.pt")

epoch: 0 | time: 10.25 | train loss: 1.938 | train acc: 27.15 | test loss: 1.865 | test acc: 33.6 | 
epoch: 1 | time: 10.18 | train loss: 1.497 | train acc: 44.13 | test loss: 1.411 | test acc: 49.3 | 
epoch: 2 | time: 10.34 | train loss: 1.192 | train acc: 56.6 | test loss: 1.15 | test acc: 59.26 | 
epoch: 3 | time: 10.3 | train loss: 0.9747 | train acc: 64.91 | test loss: 0.965 | test acc: 66.34 | 
epoch: 4 | time: 10.41 | train loss: 0.8397 | train acc: 70.29 | test loss: 0.8474 | test acc: 70.34 | 
epoch: 5 | time: 10.16 | train loss: 0.7176 | train acc: 74.74 | test loss: 0.8993 | test acc: 69.88 | 
epoch: 6 | time: 10.26 | train loss: 0.6375 | train acc: 77.76 | test loss: 0.8328 | test acc: 73.17 | 
epoch: 7 | time: 10.44 | train loss: 0.5684 | train acc: 80.35 | test loss: 0.7879 | test acc: 73.92 | 
epoch: 8 | time: 10.29 | train loss: 0.5142 | train acc: 82.09 | test loss: 0.5979 | test acc: 79.67 | 
epoch: 9 | time: 10.38 | train loss: 0.4717 | train acc: 83.78 | test loss: 
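
`mask_checkpoint` comes from the repo and its implementation is not shown in this notebook. Conceptually, the lottery-ticket step resets the surviving weights to their initial values and zeroes the pruned ones. A minimal sketch of that idea, using dicts of lists as a stand-in for a `state_dict` (the name `apply_masks` and the layer names are made up for illustration):

```python
def apply_masks(checkpoint, masks):
    """Return a copy of `checkpoint` with masked-out entries set to zero.

    checkpoint: dict mapping parameter name -> list of floats
    masks:      dict mapping parameter name -> list of 0/1, pruned layers only
    """
    masked = {}
    for name, values in checkpoint.items():
        mask = masks.get(name)
        if mask is None:
            masked[name] = list(values)  # layer was not pruned, copy as-is
            continue
        if len(mask) != len(values):
            raise ValueError(f"mask and parameter sizes differ for {name}")
        masked[name] = [v * m for v, m in zip(values, mask)]
    return masked


init = {"layer1.weight": [0.5, -0.2, 0.1, 0.8], "fc.weight": [0.3, 0.4]}
masks = {"layer1.weight": [1, 0, 0, 1]}
ticket = apply_masks(init, masks)
print(ticket)
```

Returning a new dict leaves the saved initialisation untouched, so the same checkpoint can be reused with a different mask.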

## Train from random init with mask

In [16]:
torch.manual_seed(config.seed + 1)

# produce new initialisation
new_init_params = resnet18_small_input().to(config.device).state_dict()

torch.save(new_init_params, "models/new_init.pt")

# apply mask to params then load into model
model.load_state_dict(mask_checkpoint(new_init_params, model), strict=False)

torch.save(model.state_dict(), "models/new_init_pruned.pt")

train_phase(model, optimizer, train_loader, test_loader, config)

torch.save(model.state_dict(), "models/random_lottery_ticket.pt")

epoch: 0 | time: 10.45 | train loss: 1.941 | train acc: 26.95 | test loss: 1.808 | test acc: 35 | 
epoch: 1 | time: 10.49 | train loss: 1.484 | train acc: 44.94 | test loss: 1.528 | test acc: 48.21 | 
epoch: 2 | time: 10.36 | train loss: 1.15 | train acc: 58.39 | test loss: 1.243 | test acc: 56.54 | 
epoch: 3 | time: 10.43 | train loss: 0.9589 | train acc: 65.73 | test loss: 1.002 | test acc: 65.65 | 
epoch: 4 | time: 10.36 | train loss: 0.8356 | train acc: 70.68 | test loss: 0.9978 | test acc: 65.29 | 
epoch: 5 | time: 10.31 | train loss: 0.7177 | train acc: 74.69 | test loss: 0.8093 | test acc: 73.48 | 
epoch: 6 | time: 10.3 | train loss: 0.6273 | train acc: 78.32 | test loss: 0.7156 | test acc: 76.09 | 
epoch: 7 | time: 10.34 | train loss: 0.5643 | train acc: 80.48 | test loss: 0.845 | test acc: 73.15 | 
epoch: 8 | time: 10.22 | train loss: 0.5104 | train acc: 82.42 | test loss: 0.6858 | test acc: 77.65 | 
epoch: 9 | time: 10.23 | train loss: 0.4635 | train acc: 84.13 | test loss: 0
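
This run is the control for the lottery-ticket experiment: the mask is the same, but the weights come from a different seed. The sketch below illustrates that design with the standard library's `random` module standing in for the model initialiser. `init_weights` and the example mask are hypothetical and do not reflect how `resnet18_small_input` initialises its layers.

```python
import random


def init_weights(seed, n):
    """Draw n Gaussian values from a generator seeded with `seed`."""
    rng = random.Random(seed)
    return [rng.gauss(0.0, 1.0) for _ in range(n)]


seed = 1
mask = [1, 0, 0, 1, 0, 1, 1, 0]

original_init = init_weights(seed, len(mask))    # same seed as the first run
random_init = init_weights(seed + 1, len(mask))  # fresh seed for the control

# The same mask is applied to both, so only the surviving values differ.
original_ticket = [w * m for w, m in zip(original_init, mask)]
random_ticket = [w * m for w, m in zip(random_init, mask)]

print(original_init == init_weights(seed, len(mask)))  # reseeding reproduces the init
print(original_ticket != random_ticket)
```

Because the zero positions are identical in both tickets, any gap in final accuracy between the two runs can be attributed to the initial values rather than to the sparsity pattern.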