# Scripts to set up the development environment

Choose a GPU runtime (L4/T4/A100/H100). Do not choose a CPU or TPU runtime.

In [None]:
# Mount Google Drive at /content/drive.
# NOTE(review): the next cell mounts Drive again at /mnt, and the rest of this
# notebook works under /mnt/MyDrive — confirm whether this mount is still needed.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Mount Google Drive at "/mnt" in the Colab virtual machine; the following
# cells work with paths under /mnt/MyDrive.
from google.colab import drive
drive.mount("/mnt")

Mounted at /mnt


In [None]:
# Make sure the Drive folder exists (mkdir -p does not fail if it already does),
# then make it the notebook's working directory.
!mkdir -p /mnt/MyDrive/
%cd "/mnt/MyDrive/"

/mnt/MyDrive


In [None]:
# clone Gihub repo
!git clone https://github.com/changmg/transaxx.git

Cloning into 'transaxx'...
remote: Enumerating objects: 199, done.[K
remote: Counting objects: 100% (121/121), done.[K
remote: Compressing objects: 100% (81/81), done.[K
remote: Total 199 (delta 36), reused 114 (delta 36), pack-reused 78 (from 1)[K
Receiving objects: 100% (199/199), 407.41 KiB | 901.00 KiB/s, done.
Resolving deltas: 100% (51/51), done.


In [None]:
# Change into the cloned repo folder. Presumably this is what lets the later
# `classification.*` imports and the relative "./datasets/..." path resolve —
# confirm if the repo is placed elsewhere.
%cd "/mnt/MyDrive/transaxx"

/mnt/MyDrive/transaxx


In [None]:
# install python package(s)
!pip install ninja

Collecting ninja
  Downloading ninja-1.13.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (5.1 kB)
Downloading ninja-1.13.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (180 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/180.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m180.7/180.7 kB[0m [31m15.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ninja
Successfully installed ninja-1.13.0


# Model evaluation and re-training with TransAxx on CIFAR10 dataset

In this notebook, you can evaluate different approximate multipliers on various models.
You can also retrain the model to further improve its accuracy.

**Note**:
* Currently, the only supported quantization bitwidth is 8 bits, and the supported layers are Conv2d and Linear

* Please make sure you have run the installation steps first

* This example notebook approximates Conv2d layers

In [None]:
# Wildcard import: the cells below use names that are not imported anywhere else
# in the visible part of this notebook (e.g. torch, time, AdaptConv2D,
# cifar10_data_loader, evaluate_cifar10, replace_conv_layers, collect_stats,
# compute_amax), so they are expected to come from here.
from classification.utils import *
# Device string passed to the model, evaluation and calibration calls below.
device = 'cuda'

CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89


## Load dataset

Set your path for the CIFAR10 dataset

The 'calib dataset' is created from a 10% sample of the training data and is used for calibration.


In [None]:
# Build the two CIFAR10 loaders used below: the validation set and the
# calibration subset.
val_data, calib_data = cifar10_data_loader(
    data_path="./datasets/cifar10_data",
    batch_size=128,
)

100%|██████████| 170M/170M [00:13<00:00, 12.8MB/s]


## Select a pretrained model

In [None]:
# Example hub repo with pretrained CIFAR10 models; you can substitute your own
# (ref: https://github.com/chenyaofo/pytorch-cifar-models)
model = torch.hub.load(
    "chenyaofo/pytorch-cifar-models",
    'cifar10_repvgg_a0',
    pretrained=True,
)
# Move the loaded model to the selected device.
model = model.to(device)



Downloading: "https://github.com/chenyaofo/pytorch-cifar-models/zipball/master" to /root/.cache/torch/hub/master.zip
Downloading: "https://github.com/chenyaofo/pytorch-cifar-models/releases/download/repvgg/cifar10_repvgg_a0-ef08a50e.pt" to /root/.cache/torch/hub/checkpoints/cifar10_repvgg_a0-ef08a50e.pt


100%|██████████| 30.1M/30.1M [00:00<00:00, 37.2MB/s]


## Optional: Evaluate default model


In [None]:
# Baseline: evaluate the unmodified pretrained model on the validation loader.
top1 = evaluate_cifar10(
    model,
    val_data,
    device=device,
)

100%|██████████| 78/78 [00:02<00:00, 32.79it/s]

2.4874628289999805
Accuracy of the network on the 10000 test images: 94.3209 %





## Initialize model with axx layers


In [None]:
# Collect the (name, module) pairs of the conv layers to approximate: every
# Conv2d or AdaptConv2D module whose name contains neither "head" nor "reduction".
conv2d_layers = [
    (layer_name, layer)
    for layer_name, layer in model.named_modules()
    if isinstance(layer, (torch.nn.Conv2d, AdaptConv2D))
    and not ("head" in layer_name or "reduction" in layer_name)
]

In [None]:
# Display how many conv layers were selected for approximation.
len(conv2d_layers)

44

In [None]:
# Initialize the model with every approximate multiplier the axx layers will use.
# No explicit assignment is needed here; this step JIT-compiles all upcoming multipliers.

# One config entry per selected conv layer, starting with the accurate 8-bit
# multiplier everywhere. Note: list repetition makes every entry refer to the
# same dict object.
axx_list = len(conv2d_layers) * [
    {'axx_mult': 'mul8s_acc', 'axx_power': 1.0, 'quant_bits': 8, 'fake_quant': False}
]
# Swap the approximate multiplier into a single position (index 3) so that it
# is part of this initialization pass as well.
axx_list[3:4] = [
    {'axx_mult': 'mul8s_1L2H', 'axx_power': 0.7082, 'quant_bits': 8, 'fake_quant': False}
]

# Time the layer replacement, which triggers the CUDA extension build.
start = time.time()
replace_conv_layers(
    model,
    AdaptConv2D,
    axx_list,
    0,
    0,
    layer_count=[0],
    returned_power=[0],
    initialize=True,
)
print('Time to compile cuda extensions: ', time.time()-start)

Time to compile cuda extensions:  142.7679636478424


In [None]:
# Measure the model's MACs and parameter count, including per-layer figures.

import io
from classification.ptflops import get_model_complexity_info
from classification.ptflops.pytorch_ops import (
    conv_flops_counter_hook,
    linear_flops_counter_hook,
)

# Map our custom axx layer onto the matching counter hook
# (AdaptConv2D -> conv_flops_counter_hook) so its operations are counted.
with torch.cuda.device(0):
    total_macs, total_params, layer_specs = get_model_complexity_info(
        model,
        (3, 32, 32),
        as_strings=False,
        print_per_layer_stat=True,
        custom_modules_hooks={AdaptConv2D: conv_flops_counter_hook},
        param_units='M',
        flops_units='MMac',
        verbose=True,
    )

# Report both totals in millions.
print(f'Computational complexity:  {total_macs/1000000:.2f} MMacs')
print(f'Number of parameters::  {total_params/1000000:.2f} MParams')


Computational complexity:  491.95 MMacs
Number of parameters::  7.84 MParams


## Run model calibration for quantization

Calibrates the quantization parameters

Re-run this step each time the initial model changes.

In [None]:
# Calibration runs with gradient tracking disabled.
with torch.no_grad():
    # Collect statistics from 2 batches of the calibration loader.
    stats = collect_stats(
        model,
        calib_data,
        num_batches=2,
        device=device,
    )
    # Compute amax with the 99.99th-percentile method.
    amax = compute_amax(
        model,
        method="percentile",
        percentile=99.99,
        device=device,
    )

    # optional - test different calibration methods
    #amax = compute_amax(model, method="mse")
    #amax = compute_amax(model, method="entropy")

100%|██████████| 2/2 [00:02<00:00,  1.45s/it]
W0202 13:00:28.505932 138988219970176 tensor_quantizer.py:173] Disable HistogramCalibrator
W0202 13:00:28.506758 138988219970176 tensor_quantizer.py:173] Disable HistogramCalibrator
W0202 13:00:28.507256 138988219970176 tensor_quantizer.py:173] Disable HistogramCalibrator
W0202 13:00:28.507720 138988219970176 tensor_quantizer.py:173] Disable HistogramCalibrator
W0202 13:00:28.508349 138988219970176 tensor_quantizer.py:173] Disable HistogramCalibrator
W0202 13:00:28.508900 138988219970176 tensor_quantizer.py:173] Disable HistogramCalibrator
W0202 13:00:28.509424 138988219970176 tensor_quantizer.py:173] Disable HistogramCalibrator
W0202 13:00:28.509933 138988219970176 tensor_quantizer.py:173] Disable HistogramCalibrator
W0202 13:00:28.510401 138988219970176 tensor_quantizer.py:173] Disable HistogramCalibrator
W0202 13:00:28.510946 138988219970176 tensor_quantizer.py:173] Disable HistogramCalibrator
W0202 13:00:28.511501 138988219970176 tensor

stage0.rbr_dense.conv.quantizer         : TensorQuantizer(8bit per-tensor amax=2.1255 calibrator=HistogramCalibrator scale=1.0 quant)
stage0.rbr_dense.conv.quantizer_w       : TensorQuantizer(8bit per-tensor amax=0.3628 calibrator=HistogramCalibrator scale=349.8851318359375 quant)
stage0.rbr_1x1.conv.quantizer           : TensorQuantizer(8bit per-tensor amax=2.1255 calibrator=HistogramCalibrator scale=1.0 quant)
stage0.rbr_1x1.conv.quantizer_w         : TensorQuantizer(8bit per-tensor amax=0.4948 calibrator=HistogramCalibrator scale=256.5318603515625 quant)
stage1.0.rbr_dense.conv.quantizer       : TensorQuantizer(8bit per-tensor amax=1.1008 calibrator=HistogramCalibrator scale=419.8192138671875 quant)
stage1.0.rbr_dense.conv.quantizer_w     : TensorQuantizer(8bit per-tensor amax=0.1891 calibrator=HistogramCalibrator scale=623.5847778320312 quant)
stage1.0.rbr_1x1.conv.quantizer         : TensorQuantizer(8bit per-tensor amax=1.1008 calibrator=HistogramCalibrator scale=419.8192138671875

## Run model evaluation


In [None]:
# Choose the approximate multiplier used by each layer.

# Start with the accurate 8-bit multiplier in every layer.
axx_list = len(conv2d_layers) * [
    {'axx_mult': 'mul8s_acc', 'axx_power': 1.0, 'quant_bits': 8, 'fake_quant': False}
]

# For example, approximate the first 10 layers with a specific multiplier.
axx_list[:10] = 10 * [
    {'axx_mult': 'mul8s_1L2H', 'axx_power': 0.7082, 'quant_bits': 8, 'fake_quant': False}
]

# One-element list that replace_conv_layers receives as returned_power and
# that is read back right after the call.
returned_power = [0]
replace_conv_layers(
    model,
    AdaptConv2D,
    axx_list,
    total_macs,
    total_params,
    layer_count=[0],
    returned_power=returned_power,
    initialize=False,
)
print('Power of approximated operations: ', round(returned_power[0], 2), '%')
print('Model compiled. Running evaluation')

# Evaluate the approximated model on the validation dataset.
top1 = evaluate_cifar10(
    model,
    val_data,
    device=device,
)

Power of approximated operations:  94.43 %
Model compiled. Running evaluation


100%|██████████| 78/78 [00:20<00:00,  3.76it/s]

20.900561805999985
Accuracy of the network on the 10000 test images: 91.7167 %





## Run model retraining


In [None]:
from classification.train import train_one_epoch

# Loss, placed on the same device as the model.
criterion = torch.nn.CrossEntropyLoss()
criterion = criterion.to(device)
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.0001,  # set desired learning rate
)
# NOTE(review): the scheduler is created but never stepped in this cell.
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=1,
    gamma=0.1,
)

# One-epoch retrain on the calibration loader.
# The trailing 0 and 10 are presumably the epoch index and the logging
# interval — confirm against train_one_epoch's signature.
train_one_epoch(model, criterion, optimizer, calib_data, device, 0, 10)

Epoch: [0]  [ 0/39]  eta: 0:00:40  lr: 0.0001  img/s: 151.10814288754062  loss: 0.0849 (0.0849)  acc1: 97.6562 (97.6562)  acc5: 100.0000 (100.0000)  time: 1.0469  data: 0.1998  max mem: 748
Epoch: [0]  [10/39]  eta: 0:00:10  lr: 0.0001  img/s: 423.8262731265848  loss: 0.0978 (0.0935)  acc1: 96.8750 (96.8750)  acc5: 100.0000 (99.8580)  time: 0.3711  data: 0.0185  max mem: 781
Epoch: [0]  [20/39]  eta: 0:00:06  lr: 0.0001  img/s: 423.4575474414096  loss: 0.0825 (0.0863)  acc1: 96.8750 (97.1726)  acc5: 100.0000 (99.8140)  time: 0.3035  data: 0.0003  max mem: 781
Epoch: [0]  [30/39]  eta: 0:00:02  lr: 0.0001  img/s: 419.30198876594045  loss: 0.0814 (0.0869)  acc1: 97.6562 (97.2530)  acc5: 100.0000 (99.7732)  time: 0.3036  data: 0.0002  max mem: 781
Epoch: [0] Total time: 0:00:12


## Re-run model evaluation

In [None]:
# Evaluate the model again on the validation loader after retraining.
top1 = evaluate_cifar10(
    model,
    val_data,
    device=device,
)

100%|██████████| 78/78 [00:20<00:00,  3.75it/s]

20.946440060999976
Accuracy of the network on the 10000 test images: 91.9671 %



