## Tutorial 1. ResNet18 on CIFAR10. 


In this tutorial, we will show:

- How to train and compress a ResNet18 end to end, from scratch, on CIFAR10 to obtain a compressed ResNet18.
- That the compressed ResNet18 achieves both **high performance** and **significant FLOPs and parameter reductions** relative to the full model.
- That the compressed ResNet18 **removes about 92% of the parameters** while reaching **92.91% accuracy**, only **0.11%** lower than the baseline.
- A more detailed setup of the new HESSO optimizer. (A technical report on HESSO will be released in early 2024.)

### Step 0. Clone the repo and set up the environment

In [1]:
# Clone GETA repositories
!rm -rf /kaggle/working/geta
!git clone --branch pytorch-2.6-compatibility https://github.com/eli-bigman/geta.git


Cloning into 'geta'...
remote: Enumerating objects: 297, done.
remote: Counting objects: 100% (297/297), done.
remote: Compressing objects: 100% (211/211), done.
remote: Total 297 (delta 100), reused 265 (delta 76), pack-reused 0 (from 0)
Receiving objects: 100% (297/297), 542.17 KiB | 12.91 MiB/s, done.
Resolving deltas: 100% (100/100), done.


In [2]:
!ls /kaggle/working/geta/sanity_check/backends


carn				   diffusion_transformer_sr
convnext.py			   hf_llama
demo_group_conv_case1.py	   hf_phi2
demonet_batchnorm_pruning.py	   hf_sam
demonet_concat_case1.py		   hf_vit
demonet_concat_case2.py		   __init__.py
demonet_convtranspose_in_case1.py  mamba
demonet_convtranspose_in_case2.py  mlp.py
demonet_groupnorm_case1.py	   resnet20_cifar10.py
demonet_groupnorm_case2.py	   resnet_cifar10.py
demonet_groupnorm_case3.py	   resnet_DuBIN.py
demonet_groupnorm_case4.py	   resnet_DuBN.py
demonet_in_case3.py		   simple_vit.py
demonet_weightshare_case1.py	   tnlg
demonet_weightshare_case2.py	   vgg7.py
densenet.py			   vision_transformer
diffusion


In [3]:
!pip install "torch==2.0.1+cu117" \
             "torchvision==0.15.2+cu117" \
             "torchaudio==2.0.2" \
             --extra-index-url https://download.pytorch.org/whl/cu117


Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu117
Collecting torch==2.0.1+cu117
  Downloading https://download.pytorch.org/whl/cu117/torch-2.0.1%2Bcu117-cp311-cp311-linux_x86_64.whl (1843.9 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.8/1.8 GB 480.8 kB/s eta 0:00:00
Collecting torchvision==0.15.2+cu117
  Downloading https://download.pytorch.org/whl/cu117/torchvision-0.15.2%2Bcu117-cp311-cp311-linux_x86_64.whl (6.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.1/6.1 MB 92.3 MB/s eta 0:00:00
Collecting torchaudio==2.0.2
  Downloading https://download.pytorch.org/whl/cu117/torchaudio-2.0.2%2Bcu117-cp311-cp311-linux_x86_64.whl (4.4 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.4/4.4 MB 85.1 MB/s eta 0:00:00
Collecting triton==2.0.0 (from torch==2.0.1+cu117)
  Downloading

In [4]:
!python --version


Python 3.11.13


### Step 1. Create OTO instance

In [5]:
import sys
# Make the cloned GETA repository importable.
sys.path.append('/kaggle/working/geta')

import torch
from sanity_check.backends.resnet_cifar10 import resnet18_cifar10
from only_train_once import OTO

model = resnet18_cifar10()
# OTO traces the model with a dummy input of CIFAR10 shape (N, C, H, W)
# to build its pruning dependency graph.
dummy_input = torch.rand(1, 3, 32, 32)
oto = OTO(model=model.cuda(), dummy_input=dummy_input.cuda())

OTO graph constructor
graph build
NodePattern mul None
NodePattern transpose None
NodePattern matmul None
Post-processing of graph completed.
Graph has 70 nodes and 77 edges.


#### (Optional) Visualize the pruning dependency graph of the DNN

In [6]:
# Generates ResNet_zig.gv.pdf in out_dir to display the dependency graph.
oto.visualize(view=False, out_dir='../cache')



### Step 2. Dataset Preparation

In [7]:
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms

trainset = CIFAR10(root='cifar10', train=True, download=True, transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]))
testset = CIFAR10(root='cifar10', train=False, download=True, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]))

trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=4)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=4)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to cifar10/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:04<00:00, 35115684.99it/s]


Extracting cifar10/cifar-10-python.tar.gz to cifar10
Files already downloaded and verified


### Step 3. Setup HESSO optimizer

The following main hyperparameters need to be set with care:

- `variant`: The optimizer used for training the baseline full model. Currently supports `sgd`, `adam` and `adamw`.
- `lr`: The initial learning rate.
- `weight_decay`: Weight decay, as in standard DNN optimization.
- `target_group_sparsity`: The target group sparsity, i.e., the fraction of groups to prune. A higher group sparsity typically yields larger FLOPs and model-size reductions, but may degrade model performance more.
- `start_pruning_step`: The step at which pruning **starts**.
- `pruning_steps`: The number of steps, counted from `start_pruning_step`, after which pruning **finishes** (reaches `target_group_sparsity`).
- `pruning_periods`: The number of periods into which `pruning_steps` is divided; group sparsity is increased in equal increments across these periods.

We empirically suggest setting `start_pruning_step` to about 1/10 of the total number of training steps, and choosing `pruning_steps` so that pruning finishes at about 1/4 or 1/5 of the total number of training steps.
An advantage of HESSO over DHSPG is its explicit control over group sparsity exploration, which is typically more convenient.
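The rule of thumb above can be checked against this tutorial's own settings. The sketch below is plain arithmetic and does not call GETA. It assumes the standard 50,000-image CIFAR10 training split, `batch_size=64`, the DataLoader default `drop_last=False`, and `max_epoch = 100` from Step 4.

```python
import math

# Assumed settings from this tutorial: 50,000 CIFAR10 training images,
# batch_size=64 (drop_last=False keeps the final partial batch), 100 epochs.
NUM_TRAIN_IMAGES = 50_000
BATCH_SIZE = 64
MAX_EPOCH = 100

steps_per_epoch = math.ceil(NUM_TRAIN_IMAGES / BATCH_SIZE)  # len(trainloader)
total_steps = steps_per_epoch * MAX_EPOCH

# Same expressions as in the oto.hesso(...) call of this tutorial.
start_pruning_step = 10 * steps_per_epoch
pruning_steps = 10 * steps_per_epoch
pruning_periods = 10

end_pruning_step = start_pruning_step + pruning_steps
steps_per_period = pruning_steps // pruning_periods

print("steps per epoch:", steps_per_epoch, "| total steps:", total_steps)
print("pruning starts at", start_pruning_step / total_steps, "of training")
print("pruning ends at", end_pruning_step / total_steps, "of training")
print("steps per pruning period:", steps_per_period)
```

With these values, pruning starts at 1/10 of training and ends at 1/5 of training. Each pruning period lasts exactly one epoch.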

In [8]:
optimizer = oto.hesso(
    variant='sgd', 
    lr=0.1, 
    weight_decay=1e-4,
    target_group_sparsity=0.7,
    start_pruning_step=10 * len(trainloader), 
    pruning_periods=10,
    pruning_steps=10 * len(trainloader)
)

Setup HESSO
Target redundant groups per period:  [201, 201, 201, 201, 201, 201, 201, 201, 201, 206]
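The printed schedule is consistent with a simple split, sketched below under our own assumptions (not taken from GETA's source):

- The total number of redundant groups is `int(num_groups * target_group_sparsity)`.
- That total is divided evenly across `pruning_periods`.
- Any remainder is added to the last period.

The value 2880 is the total group count reported as `num_grp_import` in the Step 4 training log. `redundant_groups_per_period` is a hypothetical helper.

```python
def redundant_groups_per_period(num_groups, target_group_sparsity, pruning_periods):
    """Hypothetical helper: split the target number of redundant groups evenly
    across pruning periods, giving any remainder to the last period."""
    # int() truncates; in floating point, 2880 * 0.7 lands just below 2016.
    total_redundant = int(num_groups * target_group_sparsity)
    per_period = total_redundant // pruning_periods
    schedule = [per_period] * pruning_periods
    schedule[-1] += total_redundant - per_period * pruning_periods
    return schedule


schedule = redundant_groups_per_period(2880, 0.7, 10)
print(schedule)
print(sum(schedule), "of 2880 groups marked redundant")
```

Under these assumptions, the final group sparsity is 2015 / 2880 ≈ 0.6997. This is just under the 0.7 target because `int()` truncates.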


### Step 4. Train ResNet18 as normal.
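Before running the loop, it helps to know which learning rate and which pruning phase each epoch falls into. The plain-Python sketch below rests on three assumptions:

- `scheduler.step()` is called once at the end of every epoch. If it is called at the start of the epoch instead, each decay happens one epoch earlier.
- One epoch is `len(trainloader)` steps, so the Step 3 values of `start_pruning_step` and `pruning_steps` each correspond to 10 epochs.
- `step_lr` mirrors StepLR's closed form `base_lr * gamma ** (epoch // step_size)`.

Both `step_lr` and `hesso_phase` are hypothetical helpers.

```python
def step_lr(epoch, base_lr=0.1, step_size=50, gamma=0.1):
    """Learning rate in effect during `epoch` (0-indexed) under StepLR,
    assuming scheduler.step() runs once at the end of each epoch."""
    return base_lr * gamma ** (epoch // step_size)


def hesso_phase(epoch, start_epoch=10, pruning_epochs=10):
    """Hypothetical helper: coarse pruning phase of an epoch, with HESSO's
    start_pruning_step and pruning_steps expressed in whole epochs."""
    if epoch < start_epoch:
        return "before pruning"
    if epoch < start_epoch + pruning_epochs:
        return "pruning"
    return "after pruning"


for epoch in (0, 9, 10, 19, 20, 49, 50, 99):
    print(f"epoch {epoch:3d}: lr={step_lr(epoch):.3g}, phase={hesso_phase(epoch)}")
```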

In [9]:
from tutorials.utils.utils import check_accuracy

max_epoch = 100
model.cuda()
criterion = torch.nn.CrossEntropyLoss()
# Decay the learning rate by 10x every 50 epochs.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)

for epoch in range(max_epoch):
    f_avg_val = 0.0
    model.train()
    for X, y in trainloader:
        X = X.cuda()
        y = y.cuda()
        y_pred = model(X)
        f = criterion(y_pred, y)
        optimizer.zero_grad()
        f.backward()
        # Accumulate a Python float so the autograd graph is not kept alive.
        f_avg_val += f.item()
        optimizer.step()
    # Step the scheduler once per epoch, after the optimizer updates.
    lr_scheduler.step()

    # Group sparsity and norms of the important/redundant groups tracked by HESSO.
    opt_metrics = optimizer.compute_metrics()
    accuracy1, accuracy5 = check_accuracy(model, testloader)
    f_avg_val = f_avg_val / len(trainloader)

    print("Ep: {ep}, loss: {f:.2f}, norm_all:{param_norm:.2f}, grp_sparsity: {gs:.2f}, acc1: {acc1:.4f}, norm_import: {norm_import:.2f}, norm_redund: {norm_redund:.2f}, num_grp_import: {num_grps_import}, num_grp_redund: {num_grps_redund}"
          .format(ep=epoch, f=f_avg_val, param_norm=opt_metrics.norm_params, gs=opt_metrics.group_sparsity, acc1=accuracy1,
                  norm_import=opt_metrics.norm_important_groups, norm_redund=opt_metrics.norm_redundant_groups,
                  num_grps_import=opt_metrics.num_important_groups, num_grps_redund=opt_metrics.num_redundant_groups))



Ep: 0, loss: 1.59, norm_all:4122.36, grp_sparsity: 0.00, acc1: 39.6700, norm_import: 4122.36, norm_redund: 0.00, num_grp_import: 2880, num_grp_redund: 0
Ep: 1, loss: 1.05, norm_all:4114.05, grp_sparsity: 0.00, acc1: 52.3600, norm_import: 4114.05, norm_redund: 0.00, num_grp_import: 2880, num_grp_redund: 0
Ep: 2, loss: 0.80, norm_all:4105.77, grp_sparsity: 0.00, acc1: 67.2000, norm_import: 4105.77, norm_redund: 0.00, num_grp_import: 2880, num_grp_redund: 0
Ep: 3, loss: 0.65, norm_all:4096.03, grp_sparsity: 0.00, acc1: 72.0100, norm_import: 4096.03, norm_redund: 0.00, num_grp_import: 2880, num_grp_redund: 0
Ep: 4, loss: 0.56, norm_all:4085.63, grp_sparsity: 0.00, acc1: 65.4200, norm_import: 4085.63, norm_redund: 0.00, num_grp_import: 2880, num_grp_redund: 0
Ep: 5, loss: 0.50, norm_all:4074.52, grp_sparsity: 0.00, acc1: 70.7400, norm_import: 4074.52, norm_redund: 0.00, num_grp_import: 2880, num_grp_redund: 0
Ep: 6, loss: 0.45, norm_all:4062.94, grp_sparsity: 0.00, acc1: 78.8300, norm_impor
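The columns in this log summarize HESSO's group-level state. As a conceptual sketch only (not GETA's `compute_metrics` implementation), they can be read as follows:

- `grp_sparsity` is the fraction of groups whose parameters are all exactly zero.
- The norms are taken over the important groups and the redundant groups separately. We assume Euclidean norms over all entries of the selected groups.

`group_metrics` is a hypothetical helper that works on plain Python lists. Its keys mirror the `opt_metrics` fields printed in the loop.

```python
import math


def group_metrics(groups, redundant_ids):
    """Conceptual sketch. `groups` maps a group id to a flat list of parameter
    values; `redundant_ids` is the set of group ids marked for removal."""
    def norm(ids):
        # Euclidean norm over every entry of the selected groups (assumption).
        return math.sqrt(sum(v * v for i in ids for v in groups[i]))

    important_ids = [i for i in groups if i not in redundant_ids]
    zero_groups = sum(1 for vals in groups.values() if all(v == 0.0 for v in vals))
    return {
        "group_sparsity": zero_groups / len(groups),
        "norm_params": norm(groups),
        "norm_important_groups": norm(important_ids),
        "norm_redundant_groups": norm(redundant_ids),
        "num_important_groups": len(important_ids),
        "num_redundant_groups": len(redundant_ids),
    }


# After pruning: both redundant groups (1 and 3) are exactly zero.
done = group_metrics({0: [3.0, 4.0], 1: [0.0, 0.0], 2: [1.0], 3: [0.0]}, {1, 3})
# Mid-pruning: group 1 is marked redundant but has not reached zero yet.
mid = group_metrics({0: [3.0, 4.0], 1: [0.5, 0.0]}, {1})
print(done)
print(mid)
```

In the second example the redundant group has been selected but not yet zeroed. It therefore counts toward `num_redundant_groups` without raising the group sparsity. As pruning proceeds, the redundant groups are expected to be driven to zero, so `norm_redund` should return to 0 and `grp_sparsity` should approach the target.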

### Step 5. Get compressed model in torch format

In [1]:
# By default, OTO constructs the subnetwork from the last model state.
# If an intermediate checkpoint reached the best performance, reinitialize the OTO instance from it first:
# oto = OTO(model=torch.load(ckpt_path), dummy_input=dummy_input.cuda())
# Then construct the subnetwork.
oto.construct_subnet(out_dir='./cache')

NameError: name 'oto' is not defined

### (Optional) Check the compressed model size

In [None]:
import os

# Checkpoint sizes on disk. A ResNet18 checkpoint is tens of MB, so report MB rather than GB.
full_model_size = os.stat(oto.full_group_sparse_model_path)
compressed_model_size = os.stat(oto.compressed_model_path)
print("Size of full model       : ", full_model_size.st_size / (1024 ** 2), "MB")
print("Size of compressed model : ", compressed_model_size.st_size / (1024 ** 2), "MB")
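Raw byte counts are easier to compare as a ratio. The hypothetical helper below (plain Python, not part of GETA) reports both checkpoint sizes in MiB and the fraction of bytes saved. These checkpoints consist mostly of float tensors, so the byte reduction should roughly track the parameter reduction. Serialization overhead means the two will not match exactly.

```python
import os

MIB = 1024 ** 2


def size_report_from_bytes(full_bytes, compressed_bytes):
    """Sizes in MiB plus the fraction of bytes saved by the compressed file."""
    return {
        "full_mib": full_bytes / MIB,
        "compressed_mib": compressed_bytes / MIB,
        "reduction": 1.0 - compressed_bytes / full_bytes,
    }


def size_report(full_path, compressed_path):
    """Hypothetical helper: same report, read from two files on disk."""
    return size_report_from_bytes(os.path.getsize(full_path),
                                  os.path.getsize(compressed_path))


# Usage in the notebook, with the paths written by oto.construct_subnet:
# report = size_report(oto.full_group_sparse_model_path, oto.compressed_model_path)
# print("Size reduction: {:.1%}".format(report["reduction"]))
```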

### (Optional) Check the compressed model accuracy
Both the full and compressed models should return exactly the same accuracy.

In [12]:
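# These files store whole pickled nn.Module objects. PyTorch >= 2.6 defaults torch.load
# to weights_only=True, so request full unpickling explicitly.
full_model = torch.load(oto.full_group_sparse_model_path, weights_only=False)
compressed_model = torch.load(oto.compressed_model_path, weights_only=False)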

acc1_full, acc5_full = check_accuracy(full_model, testloader)
print("Full model: Acc 1: {acc1}, Acc 5: {acc5}".format(acc1=acc1_full, acc5=acc5_full))

acc1_compressed, acc5_compressed = check_accuracy(compressed_model, testloader)
print("Compressed model: Acc 1: {acc1}, Acc 5: {acc5}".format(acc1=acc1_compressed, acc5=acc5_compressed))

Full model: Acc 1: 92.6, Acc 5: 99.74
Compressed model: Acc 1: 92.6, Acc 5: 99.74
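Why should the two accuracies match exactly? Removing a group whose parameters are all zero cannot change the network's output when that group's contribution is identically zero. This is presumably the idea behind the `zig` (zero-invariant groups) in `ResNet_zig.gv.pdf`.

The toy two-layer ReLU network below illustrates this. It is written in plain Python as an analogue, not as GETA's `construct_subnet`. Each hidden unit's incoming weights plus its bias form one group. Deleting the all-zero units, together with the matching columns of the next layer, leaves every output unchanged, because ReLU(0) = 0 and adding exact zeros does not alter a floating-point sum.

```python
def relu(x):
    return [max(0.0, v) for v in x]


def linear(W, b, x):
    # W holds one row per output unit; b holds the matching biases.
    return [sum(w * v for w, v in zip(row, x)) + bi for row, bi in zip(W, b)]


def forward(W1, b1, W2, b2, x):
    return linear(W2, b2, relu(linear(W1, b1, x)))


def remove_zero_hidden_units(W1, b1, W2):
    """Drop hidden units whose incoming weights and bias are all zero, along
    with the matching column (the unit's outgoing weights) of W2."""
    keep = [j for j, (row, bj) in enumerate(zip(W1, b1))
            if bj != 0.0 or any(w != 0.0 for w in row)]
    return ([W1[j] for j in keep],
            [b1[j] for j in keep],
            [[row[j] for j in keep] for row in W2])


# Hidden unit 1 plays the role of a zeroed-out (redundant) group.
W1 = [[1.0, -2.0], [0.0, 0.0], [0.5, 0.5]]
b1 = [0.1, 0.0, -0.2]
W2 = [[1.0, 3.0, -1.0], [2.0, -4.0, 0.5]]
b2 = [0.0, 1.0]

W1_c, b1_c, W2_c = remove_zero_hidden_units(W1, b1, W2)
for x in ([0.3, 0.7], [-1.0, 2.5], [4.0, -0.5]):
    assert forward(W1, b1, W2, b2, x) == forward(W1_c, b1_c, W2_c, b2, x)
print(len(W1), "->", len(W1_c), "hidden units; outputs unchanged")
```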
