# Comparing Training of Quantized and Unquantized Models
In this notebook we compare the training of different models.
All models are autoencoders; they differ in how they are quantized:
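- The 'normal' model (floating-point PyTorch reference, labelled FP32 in the plots)
- The 'creator' model (fixed-point, quantization-aware training, labelled QAT in the plots)
- The 'brevitas' model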

In [None]:
import os
import numpy as np
from torch import Tensor
import matplotlib.pyplot as plt
from elasticai.creator.nn.fixed_point.quantization import quantize

from denspp.offline.yaml_handler import YamlConfigHandler
from denspp.offline.dnn.dataset.autoencoder import prepare_training
from denspp.offline.dnn.dnn_handler import ConfigMLPipeline, DefaultSettings_MLPipe
from denspp.offline.dnn.pytorch_config_data import ConfigDataset, DefaultSettingsDataset
from denspp.offline.dnn.pytorch_config_model import ConfigPytorch, DefaultSettingsTrainMSE
from denspp.offline.dnn.plots.plot_dnn import results_training

In [None]:
# --- Load configs
# Pipeline settings: the handler is created from the pipeline defaults and
# then converted into a ConfigMLPipeline object.
yaml_handler = YamlConfigHandler(
    DefaultSettings_MLPipe,
    'config',
    'Config_DNN',
)
dnn_handler = yaml_handler.get_class(ConfigMLPipeline)

# Dataset settings: same procedure, using the dataset defaults.
yaml_data = YamlConfigHandler(
    DefaultSettingsDataset,
    'config',
    'ConfigAE_Dataset',
)
config_data = yaml_data.get_class(ConfigDataset)

In [None]:
# --- Get dataset
# Build the autoencoder training dataset from the loaded dataset settings.
dataset = prepare_training(
    settings=config_data,
    do_classification=False,
    mode_train_ae=0,
    noise_std=0.01,
)
# Keep one fixed sample (index 4) for the inference test at the end.
data_inference_test = dataset[4]

## Comparing normal/quantized Model

### Training

In [None]:
from copy import deepcopy

# Work on a private copy of the imported defaults: assigning
# `DefaultSettingsTrainMSE` directly would only alias it, so setting
# `model_name` would silently modify the shared default object for every
# other user of it in this session.
default_train = deepcopy(DefaultSettingsTrainMSE)
default_train.model_name = 'CompareDNN_Autoencoder_v1_Torch'
yaml_nn = YamlConfigHandler(default_train, 'config', 'ConfigAE_Training')
config_train = yaml_nn.get_class(ConfigPytorch)

# Switch off the pipeline's own plotting flag for the training runs below.
dnn_handler.do_plot = False

In [None]:
from datetime import datetime, date
from denspp.offline.dnn.pytorch_pipeline import do_train_autoencoder

# Output folder for generated VHDL, stamped as YYYY-MM-DD_HHMM.
# The clock is read exactly once: combining date.today() with a separate
# datetime.now() could mix two different days when run around midnight.
path4vhdl = f'vhdl/run_{datetime.now().strftime("%Y-%m-%d_%H%M")}'

# Model pairs to compare, each listed as [torch variant, creator variant];
# one pair with batch normalisation ("wBN") and one without ("woBN").
model_name_compare_wBN = [
    'CompareDNN_Autoencoder_v1_Torch',
    'CompareDNN_Autoencoder_v1_Creator',
]
model_name_compare_woBN = [
    'CompareDNN_Autoencoder_woBN_v1_Torch',
    'CompareDNN_Autoencoder_woBN_v1_Creator',
]

In [None]:
# Choose which model pair is trained: with or without batch normalisation.
train_model_with_batchnorm = True
used_models = model_name_compare_wBN if train_model_with_batchnorm else model_name_compare_woBN

# First entry of the pair is the torch variant.
config_train.model_name = used_models[0]
used_model_torch = config_train.get_model()

# The training call returns three values, stored under fixed keys.
# NOTE(review): the ptq_* arguments presumably enable a post-training
# quantization validation pass - confirm against do_train_autoencoder.
model_stats_torch = {}
(
    model_stats_torch['metrics'],
    model_stats_torch['data_result'],
    model_stats_torch['path2folder'],
) = do_train_autoencoder(
    config_ml=dnn_handler,
    config_data=config_data,
    config_train=config_train,
    used_dataset=dataset,
    used_model=used_model_torch,
    calc_custom_metrics=['dsnr_all', 'ptq_loss'],
    print_results=False,
    ptq_validation_do=True,
    ptq_quant_lvl=[12, 8],
)

In [None]:
# Second entry of the selected pair is the creator variant.
config_train.model_name = used_models[1]
used_model_creator = config_train.get_model()

# Same training routine as for the torch model, without the PTQ options.
model_stats_creator = {}
(
    model_stats_creator['metrics'],
    model_stats_creator['data_result'],
    model_stats_creator['path2folder'],
) = do_train_autoencoder(
    config_ml=dnn_handler,
    config_data=config_data,
    config_train=config_train,
    used_dataset=dataset,
    used_model=used_model_creator,
    calc_custom_metrics=['dsnr_all'],
    print_results=False,
)

## Plotting Results

In [None]:
# Create the result plots for both trained models (torch first, then
# creator). The call is identical for both, so it is issued from one loop
# instead of two copy-pasted blocks.
for model_stats in (model_stats_torch, model_stats_creator):
    # First fold recorded in the metrics dict (no temporary key list needed).
    used_first_fold = next(iter(model_stats["metrics"]))
    data_result = model_stats["data_result"]
    results_training(
        path=model_stats["path2folder"],
        cl_dict=data_result['cl_dict'],
        feat=data_result['feat'],
        yin=data_result['input'],
        ypred=data_result['pred'],
        ymean=dataset.get_mean_waveforms,
        yclus=data_result['valid_clus'],
        snr=model_stats["metrics"][used_first_fold]['dsnr_all'],
        show_plot=dnn_handler.do_block,
    )

### Model Comparison

In [None]:
# --- Extract metrics
# Use the first recorded fold of each run rather than a hard-coded
# 'fold_000' key, matching how the result plots pick their fold.
fold_torch = next(iter(model_stats_torch['metrics']))
fold_creator = next(iter(model_stats_creator['metrics']))
metrics_torch = model_stats_torch['metrics'][fold_torch]
metrics_creator = model_stats_creator['metrics'][fold_creator]

model_loss_train_torch = metrics_torch['loss_train']
model_loss_valid_torch = metrics_torch['loss_valid']
model_loss_valid_ptq = metrics_torch['ptq_loss']

model_loss_train_creator = metrics_creator['loss_train']
model_loss_valid_creator = metrics_creator['loss_valid']

# --- Plotting
fig, ax = plt.subplots()
# 1-based epoch axis of the FP32 run (kept under its original name).
epochs_ite = np.arange(1, len(model_loss_train_torch) + 1)

# (values, legend label, line style, marker, colour) for every curve.
loss_curves = [
    (model_loss_train_torch, 'FP32, Training', 'solid', '.', 'blue'),
    (model_loss_valid_torch, 'FP32, Validation', 'dotted', '.', 'blue'),
    (model_loss_train_creator, 'QAT, Training', 'solid', 'v', 'red'),
    (model_loss_valid_creator, 'QAT, Validation', 'dotted', 'v', 'red'),
    (model_loss_valid_ptq, 'PTQ, Validation', 'dotted', 's', 'green'),
]
for losses, label, linestyle, marker, color in loss_curves:
    # Every curve gets its own epoch axis: the two runs are trained
    # independently, and a shared axis makes ax.plot raise as soon as one
    # run recorded a different number of epochs than the other.
    ax.plot(np.arange(1, len(losses) + 1), losses, label=label,
            linestyle=linestyle, marker=marker, color=color)

font = {'size': 15}

ax.grid()
ax.legend(fontsize=font['size'])
ax.margins(0)
ax.set_yscale('log')
ax.set_xlabel('Epoch', fontdict=font)
ax.set_ylabel('Loss', fontdict=font)
ax.set_title(label='Performance Comparison (FP vs. QAT (FxP) vs. PTQ (FxP))', fontdict=font)

# Save the plot in the runs folder under a timestamped file name.
folder_name = '../runs/comparisons'
os.makedirs(folder_name, exist_ok=True)
fig.savefig(f'{folder_name}/{datetime.now().strftime("%Y%m%d_%H%M%S")}_train_valid_loss.svg', format='svg')

## Testing Inference

In [None]:
# --- Switch the trained creator model to evaluation mode
used_model_creator.eval()
# bit_config[0] is used as total bit width, bit_config[1] as fractional bits.
bit_config = used_model_creator.bit_config

# Quantize the stored test sample to the model's fixed-point format.
model_test_input = Tensor(data_inference_test['in'])
model_test_input_quant = quantize(model_test_input, total_bits=bit_config[0], frac_bits=bit_config[1])
# Print the quantized tensor (the unquantized input was printed here before,
# contradicting the label).
print(f"Quantized Input = {model_test_input_quant}")

model_test_output = used_model_creator(model_test_input_quant)
print(f"Output = {model_test_output[0]}")

In [None]:
from fxpmath import Fxp

def value_to_binary(x : Tensor, total_bits: int, frac_bits: int) -> str:
    """Return the binary representation of a single value in signed fixed-point.

    Args:
        x: Single value; it is converted with ``float(x)`` before encoding.
        total_bits: Word length of the fixed-point number.
        frac_bits: Number of fractional bits.

    Returns:
        Binary representation as produced by ``Fxp.bin()``.
    """
    return Fxp(
        float(x), signed=True, n_word=total_bits, n_frac=frac_bits
    ).bin()
    
def tensor_to_vhdl_vector(X: Tensor, total_bits: int, frac_bits: int) -> str:
    """Format the values of ``X`` as a parenthesised list of quoted bit strings.

    Each element is converted with ``value_to_binary`` and wrapped in double
    quotes; the entries are comma-separated, e.g. ``("0101","1100")``.

    Args:
        X: Iterable of single values (iterated element by element).
        total_bits: Word length passed on to ``value_to_binary``.
        frac_bits: Number of fractional bits passed on to ``value_to_binary``.

    Returns:
        The formatted string. An empty ``X`` yields ``"()"``; the previous
        implementation stripped the opening parenthesis and returned ``")"``.
    """
    quoted_words = [
        f'"{value_to_binary(val, total_bits, frac_bits)}"' for val in X
    ]
    # join() adds separators only between entries, so no trailing comma has
    # to be cut off afterwards.
    return f"({','.join(quoted_words)})"

# Print the quantized input and the model output as VHDL-style vectors.
print(f"Input={tensor_to_vhdl_vector(model_test_input_quant, bit_config[0], bit_config[1])}")

# Flatten the first model output once and reuse it for both prints.
flat_output = model_test_output[0].flatten()
print(flat_output)
output_vector = tensor_to_vhdl_vector(flat_output, bit_config[0], bit_config[1])
print(f"Output={output_vector}={flat_output}")

### Just first layer

In [None]:
# Quantize the test sample with the model's bit configuration.
first_layer_bits = {'total_bits': bit_config[0], 'frac_bits': bit_config[1]}
q_input_layer_0 = quantize(model_test_input, **first_layer_bits)
print(f"Quantized Input = {q_input_layer_0}")

# Run only the first layer of the creator model on the quantized sample.
output_layer_0 = used_model_creator.forward_first_layer(q_input_layer_0)
print(f"Output = {output_layer_0}")

In [None]:
# Print first-layer input and output as VHDL-style vectors, in that order.
vhdl_format = dict(total_bits=bit_config[0], frac_bits=bit_config[1])
for label, tensor in (("Input", q_input_layer_0), ("Output", output_layer_0)):
    print(f"{label}={tensor_to_vhdl_vector(tensor, **vhdl_format)}")