In [None]:
import onnxruntime
from config import load_config
from data_prep import create_data_loader
from model import load_model_for_inference
from quantize import convert_to_onnx, dynamic_quantize, evaluate_onnx_model
from utils import plot_confusion_matrix, plot_roc_curve, plot_precision_recall

In [None]:
# ---------------------------------------------------------------------------
# Quantization pipeline: configs -> validation data -> model -> ONNX export
# -> dynamic quantization -> evaluation of the quantized model on CPU.
# ---------------------------------------------------------------------------

# Settings come from two YAML files: the general run config and the
# quantization-specific config.
config = load_config('configs/config.yaml')
quantize_config = load_config('configs/quantization_config.yaml')

# Validation loader. The config values are handed to create_data_loader
# positionally, in exactly this order; shuffling is switched off.
loader_args = [
    config[key]
    for key in (
        "valpath",
        "val_label_col",
        "tokenizer_model",
        "max_length",
        "batch_size",
    )
]
val_loader = create_data_loader(*loader_args, shuffle=False)

# Model used as the source for the ONNX export.
model = load_model_for_inference(config)

# Export, then quantize. The order matters: the session below opens the file at
# quantize_config['quantized_onnx_path'].
# NOTE(review): presumably convert_to_onnx writes the exported model and
# dynamic_quantize writes the quantized one to paths taken from
# quantize_config -- confirm in quantize.py.
convert_to_onnx(model, config, quantize_config)
dynamic_quantize(quantize_config)

# Run the quantized model through ONNX Runtime, restricted to the CPU provider.
quantized_model_path = quantize_config['quantized_onnx_path']
session = onnxruntime.InferenceSession(
    quantized_model_path,
    providers=["CPUExecutionProvider"],
)

# evaluate_onnx_model returns four values; the labels and predictions are
# consumed by the plotting cell.
# NOTE(review): accuracy and f1 are not displayed anywhere in the visible cells.
accuracy, f1, all_val_labels, all_val_preds = evaluate_onnx_model(session, val_loader)

In [None]:
# Visualise the quantized model's validation results.

# Tick labels for the confusion matrix: the strings "0" through "7".
class_names = list(map(str, range(8)))

# Confusion matrix of true labels vs. predictions.
plot_confusion_matrix(all_val_labels, all_val_preds, classes=class_names)

# Precision-recall and ROC curves, both drawn from the same label/prediction pair.
# NOTE(review): both helpers receive all_val_preds; if those are hard class
# labels rather than scores/probabilities the curves may be degenerate --
# confirm what evaluate_onnx_model returns and what these helpers expect.
plot_precision_recall(all_val_labels, all_val_preds)
plot_roc_curve(all_val_labels, all_val_preds)