# Model evaluation and re-training with TransAxx on ImageNet dataset

In this notebook you can evaluate different approximate multipliers on various models.
You can also retrain the model to further improve its accuracy.

**Note**:
* Currently, the only supported quantization bitwidth is 8-bit, and the supported layers are Conv2d and Linear.

* Please make sure you have run the installation steps first

* This example notebook approximates Linear layers 

In [None]:
# NOTE(review): later cells use `torch`, `time`, `AdaPT_Linear` and the loader/evaluation
# helpers without importing them, so they presumably all come from this star import —
# confirm against classification/utils.py.
from classification.utils import *
# Device string used for `.to(device)` and for the helper calls that take `device=`.
device = 'cuda'

## Load dataset

Provide an ImageNet dataset with the structure below. The folder must be named 'imagenet_data'; the loading cell below reads it from the 'datasets/imagenet_data' path.

imagenet_data/

├── val/

└── train_tiny/

**Note**: 'val' is the validation dataset and 'train_tiny' should be a small train dataset used for calibration purposes. Change batch size if needed.


In [None]:
# Build the validation and calibration data from the dataset root.
# Adjust batch_size if needed.
val_data, calib_data = imagenet_data_loader(
    'datasets/imagenet_data',
    batch_size=128,
)

## Select a pretrained model

In [None]:
import timm

# Identifier of the pretrained timm model to work with.
model_name = 'vit_small_patch16_224'

# Instantiate the model with its pretrained weights, then move it onto `device`.
model = timm.create_model(model_name, pretrained=True)
model = model.to(device)

In [None]:
# to avoid frequent downloading of the weights the following commands might be handy
# (save only the state_dict, since that is what load_state_dict expects when reloading)

#torch.save(model.state_dict(), 'models/' + model_name + '.pth')
#model = timm.create_model(model_name).to(device)
#model.load_state_dict(torch.load('models/' + model_name + '.pth'))


## Optional: Evaluate default model


In [None]:
# Baseline accuracy of the unmodified model on the ImageNet validation data.
# Uses the ImageNet evaluation helper with the same call signature as the later
# evaluation cells (instead of the CIFAR-10 helper), so the baseline top-1/top-5
# figures are directly comparable with the approximate-model results.
criterion = torch.nn.CrossEntropyLoss().to(device)
top1, top5 = evaluate_imagenet(model, val_data, criterion, print_freq=1000, device = device)

## Initialize model with axx layers


In [None]:
# Gather the (name, module) pairs to approximate: modules that are torch.nn.Linear or
# AdaPT_Linear, excluding any whose name contains "head" or "reduction".
linear_layers = [
    (name, module)
    for name, module in model.named_modules()
    if isinstance(module, (torch.nn.Linear, AdaPT_Linear))
    and not any(tag in name for tag in ("head", "reduction"))
]

In [None]:
# Number of layers selected above; the per-layer axx_list in the next cells is built with this length.
len(linear_layers)

In [None]:
# Initialization pass: list every approximate multiplier that will be used later so that
# it gets JIT compiled now. The per-layer assignment here is not meaningful yet.

# One configuration entry per selected linear layer, defaulting to the accurate multiplier.
axx_list = [
    {
        'axx_mult': 'mul8s_acc',
        'axx_power': 1.0,
        'quant_bits': 8,
        'fake_quant': False,
    }
] * len(linear_layers)

# Replace the entry at index 1 with the approximate multiplier so it is included as well.
axx_list[1:2] = [
    {
        'axx_mult': 'mul8s_1L2H',
        'axx_power': 0.7082,
        'quant_bits': 8,
        'fake_quant': False,
    }
]

# Time the replacement call, which triggers the compilation.
start = time.time()
replace_linear_layers(
    model,
    AdaPT_Linear,
    axx_list,
    0,
    0,
    layer_count=[0],
    returned_power=[0],
    initialize=True,
)
print('Time to compile cuda extensions: ', time.time()-start)

In [None]:
# measure flops of model and compute 'flops' in every layer

import io
from classification.ptflops import get_model_complexity_info
from classification.ptflops.pytorch_ops import linear_flops_counter_hook
from classification.ptflops.pytorch_ops import conv_flops_counter_hook

# Map the custom AdaPT_Linear layer to the Linear flop-counter hook so that the
# complexity report accounts for the axx layers.
with torch.cuda.device(0):
    total_macs, total_params, layer_specs = get_model_complexity_info(
        model,
        (3, 224, 224),
        as_strings=False,
        print_per_layer_stat=True,
        custom_modules_hooks={AdaPT_Linear: linear_flops_counter_hook},
        param_units='M',
        flops_units='MMac',
        verbose=True,
    )

# Totals are returned as plain numbers (as_strings=False); scale them to millions for display.
print(f'Computational complexity:  {total_macs / 1_000_000:.2f} MMacs')
print(f'Number of parameters::  {total_params / 1_000_000:.2f} MParams')


## Run model calibration for quantization

Calibrates the quantization parameters 

This step needs to be re-run each time the initial model changes.

In [None]:
# Gradient tracking is disabled for the whole calibration block.
with torch.no_grad():
    # Collect statistics from 2 batches of the calibration data.
    stats = collect_stats(
        model,
        calib_data,
        num_batches=2,
        device=device,
    )
    # Compute amax using the percentile method at the 99.99th percentile.
    amax = compute_amax(
        model,
        method="percentile",
        percentile=99.99,
        device=device,
    )

    # optional - test different calibration methods
    #amax = compute_amax(model, method="mse")
    #amax = compute_amax(model, method="entropy")

## Run model evaluation


In [None]:
# Choose the approximate multiplier used by each selected layer.

# Start with the 8-bit accurate multiplier in every layer ...
axx_list = [
    {
        'axx_mult': 'mul8s_acc',
        'axx_power': 1.0,
        'quant_bits': 8,
        'fake_quant': False,
    }
] * len(linear_layers)

# ... then, as an example, approximate the first 5 layers with a specific multiplier.
axx_list[0:5] = [
    {
        'axx_mult': 'mul8s_1L2H',
        'axx_power': 0.7082,
        'quant_bits': 8,
        'fake_quant': False,
    }
] * 5

# One-element list handed to replace_linear_layers and read back after the call.
returned_power = [0]
replace_linear_layers(
    model,
    AdaPT_Linear,
    axx_list,
    total_macs,
    total_params,
    layer_count=[0],
    returned_power=returned_power,
    initialize=False,
)
print('Power of approximated operations: ', round(returned_power[0], 2), '%')
print('Model compiled.')

criterion = torch.nn.CrossEntropyLoss().to(device)
# Evaluate the approximated model on the validation data.
top1, top5 = evaluate_imagenet(
    model,
    val_data,
    criterion,
    print_freq=1000,
    device=device,
)

## Run model retraining


In [None]:
from classification.train import train_one_epoch

# Loss and optimizer used for the retraining pass.
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5) # set desired learning rate

#one epoch retrain
# NOTE(review): this trains on `calib_data` (the small calibration loader), not on a
# separate training loader; the trailing positional arguments 0 and 10 are presumably
# the epoch index and the print frequency — confirm against classification/train.py.
train_one_epoch(model, criterion, optimizer, calib_data, device, 0, 10)

## Re-run model evaluation

In [None]:
# Evaluate again after retraining; `criterion` is the loss object defined in an earlier cell.
top1, top5 = evaluate_imagenet(model, val_data, criterion, print_freq=1000, device = device)