In [None]:
# Colab only: mount Google Drive at /mnt; the workspace used below lives under /mnt/MyDrive.
from google.colab import drive
drive.mount("/mnt")

In [None]:
# Create the workspace directory on the mounted drive and make it the working directory.
!mkdir -p /mnt/MyDrive/workspace/
%cd "/mnt/MyDrive/workspace/"

In [None]:
# Fetch the TransAxx sources into the current (workspace) directory.
!git clone https://github.com/changmg/transaxx.git

In [None]:
# Work from the repository root: later cells use relative paths (requirements.txt, ./datasets/...).
%cd "/mnt/MyDrive/workspace/transaxx"

In [None]:
# Install the dependencies listed in the repository's requirements.txt via uv's pip interface.
!uv pip install -r requirements.txt

In [None]:
# Show which PyTorch version is active in this kernel.
import torch
print(torch.__version__)

# Model evaluation and re-training with TransAxx on CIFAR10 dataset

In this notebook you can evaluate different approximate multipliers on various models.
You can also retrain the model to further improve its accuracy.

**Note**:
* Currently, the only supported quantization bit-width is 8 bits, and the supported layer types are Conv2d and Linear

* Please make sure you have run the installation steps first

* This example notebook approximates Conv2d layers

In [1]:
# Star import: later cells rely on names this brings into the namespace
# (presumably time, AdaptConv2D, cifar10_data_loader, evaluate_cifar10, replace_conv_layers,
# collect_stats, compute_amax — verify against classification/utils.py).
from classification.utils import *
# All model and data tensors in the following cells are placed on this device.
device = 'cuda'

## Load dataset

Set your path for the CIFAR10 dataset

The calibration dataset (`calib_data`) is created from a 10% sample of the training data and is used for calibration.


In [2]:
# Build the CIFAR10 loaders: `val_data` for evaluation, `calib_data` for calibration.
val_data, calib_data = cifar10_data_loader(
    data_path="./datasets/cifar10_data",
    batch_size=128,
)

Files already downloaded and verified
Files already downloaded and verified


## Select a pretrained model

In [3]:
# Example torch.hub repository with pretrained CIFAR10 models; you can substitute your own
# (ref: https://github.com/chenyaofo/pytorch-cifar-models).
model = torch.hub.load(
    "chenyaofo/pytorch-cifar-models",
    "cifar10_repvgg_a0",
    pretrained=True,
).to(device)

Using cache found in /home/chang/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


## Optional: Evaluate default model


In [4]:
# Baseline: evaluate the model on the validation data before any layer is replaced.
top1 = evaluate_cifar10(model, val_data, device=device)

100%|██████████| 78/78 [00:02<00:00, 33.76it/s]

2.481016346000615
Accuracy of the network on the 10000 test images: 94.3209 %





## Initialize model with axx layers


In [5]:
# Collect the (name, module) pairs of every convolution layer to approximate,
# skipping any module whose name contains "head" or "reduction".
conv2d_layers = [
    (name, module)
    for name, module in model.named_modules()
    if isinstance(module, (torch.nn.Conv2d, AdaptConv2D))
    and "head" not in name
    and "reduction" not in name
]

In [6]:
# Number of convolution layers selected for approximation.
len(conv2d_layers)

44

In [7]:
# Initialize the model with every approximate multiplier the axx layers will need.
# No explicit assignment is needed here; this step JIT-compiles all upcoming multipliers.

# Build one independent config dict per layer. `[cfg] * n` would place the *same* dict
# object in every slot, so an in-place change to one entry would silently change all of them.
axx_list = [
    {'axx_mult': 'mul8s_acc', 'axx_power': 1.0, 'quant_bits': 8, 'fake_quant': False}
    for _ in range(len(conv2d_layers))
]
# Put the approximate multiplier on one layer (index 3) so it is part of this initialization pass.
axx_list[3:4] = [{'axx_mult': 'mul8s_1L2H', 'axx_power': 0.7082, 'quant_bits': 8, 'fake_quant': False}]

start = time.time()
replace_conv_layers(
    model,
    AdaptConv2D,
    axx_list,
    0,
    0,
    layer_count=[0],
    returned_power=[0],
    initialize=True,
)
print('Time to compile cuda extensions: ', time.time() - start)

CUDA Compute Architecture: sm_89


If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89


If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compu

In [8]:
# Measure the model's MACs and parameter count, and compute the 'flops' of every layer.

import io
from classification.ptflops import get_model_complexity_info
from classification.ptflops.pytorch_ops import conv_flops_counter_hook, linear_flops_counter_hook

# Route the custom axx layers to the matching flop counter, i.e. AdaptConv2D -> conv_flops_counter_hook.
with torch.cuda.device(0):
    total_macs, total_params, layer_specs = get_model_complexity_info(
        model,
        (3, 32, 32),
        as_strings=False,
        print_per_layer_stat=True,
        custom_modules_hooks={AdaptConv2D: conv_flops_counter_hook},
        param_units='M',
        flops_units='MMac',
        verbose=True,
    )

print(f'Computational complexity:  {total_macs/1000000:.2f} MMacs')
print(f'Number of parameters::  {total_params/1000000:.2f} MParams')




  return F.conv2d(quant_input, quant_weight, bias, stride, padding, dilation, groups)


Computational complexity:  491.95 MMacs
Number of parameters::  7.84 MParams


## Run model calibration for quantization

This step calibrates the quantization parameters.

Re-run it each time the initial model changes.

In [9]:
# Collect activation statistics on a few calibration batches, then derive the amax values.
# Gradients are not needed for either step.
with torch.no_grad():
    stats = collect_stats(model, calib_data, num_batches=2, device=device)
    amax = compute_amax(
        model,
        method="percentile",
        percentile=99.99,
        device=device,
    )

    # Optional: other calibration methods to try instead
    #   amax = compute_amax(model, method="mse")
    #   amax = compute_amax(model, method="entropy")

100%|██████████| 2/2 [00:03<00:00,  1.60s/it]
W0202 13:15:54.835047 135905308418688 tensor_quantizer.py:173] Disable HistogramCalibrator
W0202 13:15:54.835716 135905308418688 tensor_quantizer.py:173] Disable HistogramCalibrator
W0202 13:15:54.836232 135905308418688 tensor_quantizer.py:173] Disable HistogramCalibrator
W0202 13:15:54.836864 135905308418688 tensor_quantizer.py:173] Disable HistogramCalibrator
W0202 13:15:54.837194 135905308418688 tensor_quantizer.py:173] Disable HistogramCalibrator
W0202 13:15:54.837452 135905308418688 tensor_quantizer.py:173] Disable HistogramCalibrator
W0202 13:15:54.837690 135905308418688 tensor_quantizer.py:173] Disable HistogramCalibrator
W0202 13:15:54.837910 135905308418688 tensor_quantizer.py:173] Disable HistogramCalibrator
W0202 13:15:54.838751 135905308418688 tensor_quantizer.py:173] Disable HistogramCalibrator
W0202 13:15:54.838996 135905308418688 tensor_quantizer.py:173] Disable HistogramCalibrator
W0202 13:15:54.839264 135905308418688 tensor

stage0.rbr_dense.conv.quantizer         : TensorQuantizer(8bit per-tensor amax=2.1255 calibrator=HistogramCalibrator scale=1.0 quant)
stage0.rbr_dense.conv.quantizer_w       : TensorQuantizer(8bit per-tensor amax=0.3628 calibrator=HistogramCalibrator scale=349.8851318359375 quant)
stage0.rbr_1x1.conv.quantizer           : TensorQuantizer(8bit per-tensor amax=2.1255 calibrator=HistogramCalibrator scale=1.0 quant)
stage0.rbr_1x1.conv.quantizer_w         : TensorQuantizer(8bit per-tensor amax=0.4948 calibrator=HistogramCalibrator scale=256.5318603515625 quant)
stage1.0.rbr_dense.conv.quantizer       : TensorQuantizer(8bit per-tensor amax=1.1008 calibrator=HistogramCalibrator scale=419.8192138671875 quant)
stage1.0.rbr_dense.conv.quantizer_w     : TensorQuantizer(8bit per-tensor amax=0.1891 calibrator=HistogramCalibrator scale=623.5847778320312 quant)
stage1.0.rbr_1x1.conv.quantizer         : TensorQuantizer(8bit per-tensor amax=1.1008 calibrator=HistogramCalibrator scale=419.8192138671875

## Run model evaluation


In [10]:
# Set the desired approximate multiplier for each layer.

# Start with the 8-bit accurate multiplier everywhere. Each layer gets its own dict:
# `[cfg] * n` would alias one dict across all slots, so mutating one entry would change all.
axx_list = [
    {'axx_mult': 'mul8s_acc', 'axx_power': 1.0, 'quant_bits': 8, 'fake_quant': False}
    for _ in range(len(conv2d_layers))
]

# For example, approximate the first 10 layers with a specific multiplier.
axx_list[0:10] = [
    {'axx_mult': 'mul8s_1L2H', 'axx_power': 0.7082, 'quant_bits': 8, 'fake_quant': False}
    for _ in range(10)
]

# Single-element list that replace_conv_layers receives as `returned_power`; element 0 is printed below.
returned_power = [0]
replace_conv_layers(
    model,
    AdaptConv2D,
    axx_list,
    total_macs,
    total_params,
    layer_count=[0],
    returned_power=returned_power,
    initialize=False,
)
print('Power of approximated operations: ', round(returned_power[0], 2), '%')
print('Model compiled. Running evaluation')

# Run evaluation on the validation dataset.
top1 = evaluate_cifar10(model, val_data, device=device)

CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compute Architecture: sm_89
CUDA Compu

100%|██████████| 78/78 [00:30<00:00,  2.52it/s]

31.05098044100032
Accuracy of the network on the 10000 test images: 91.7167 %





## Run model retraining


In [16]:
from classification.train import train_one_epoch

criterion = torch.nn.CrossEntropyLoss().to(device)
# Set the desired learning rate here.
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
# NOTE(review): the scheduler is created but never stepped in this cell.
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=1,
    gamma=0.1,
)

# Retrain for a single epoch on `calib_data`.
# NOTE(review): the trailing positional arguments 0 and 10 look like the epoch index and
# the print frequency — confirm against classification/train.py.
train_one_epoch(
    model,
    criterion,
    optimizer,
    calib_data,
    device,
    0,
    10,
)

Epoch: [0]  [ 0/39]  eta: 0:00:32  lr: 0.0001  img/s: 179.65246524378043  loss: 0.0849 (0.0849)  acc1: 97.6562 (97.6562)  acc5: 100.0000 (100.0000)  time: 0.8401  data: 0.1276  max mem: 756
Epoch: [0]  [10/39]  eta: 0:00:13  lr: 0.0001  img/s: 306.99301065584785  loss: 0.0969 (0.0937)  acc1: 96.8750 (96.8750)  acc5: 100.0000 (99.8580)  time: 0.4616  data: 0.0118  max mem: 779
Epoch: [0]  [20/39]  eta: 0:00:08  lr: 0.0001  img/s: 297.4863353338605  loss: 0.0829 (0.0863)  acc1: 96.8750 (97.1726)  acc5: 100.0000 (99.8140)  time: 0.4226  data: 0.0002  max mem: 779
Epoch: [0]  [30/39]  eta: 0:00:03  lr: 0.0001  img/s: 302.8545756703608  loss: 0.0817 (0.0869)  acc1: 97.6562 (97.2530)  acc5: 100.0000 (99.7732)  time: 0.4228  data: 0.0001  max mem: 779
Epoch: [0] Total time: 0:00:16


## Re-run model evaluation

In [17]:
# Evaluate the retrained model on the validation data.
top1 = evaluate_cifar10(model, val_data, device=device)

100%|██████████| 78/78 [00:29<00:00,  2.66it/s]

29.45351870100012
Accuracy of the network on the 10000 test images: 91.9271 %



