In [None]:
import os
import time

import numpy as np
import torch
import torch.nn as nn
from torch2trt import torch2trt

from model_training import get_loaders_and_classes, get_class_weights
from models.CNN import Model4
from utils import *

In [None]:
import yaml
def read_params(config_path):
    """Load a YAML configuration file and return its parsed contents.

    Args:
        config_path: Path to the YAML settings file.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file: a dict for a
        mapping-style file such as ``settings.yaml`` (``None`` for an empty file).
    """
    # Explicit encoding: the platform default (e.g. cp1252 on Windows) can
    # mis-decode non-ASCII characters in the settings file.
    with open(config_path, encoding='utf-8') as yaml_file:
        config = yaml.safe_load(yaml_file)
    return config

# Run configuration comes from the YAML settings file next to this notebook.
config = read_params('settings.yaml')

# Hyper-parameters
batch_size = config['batch_size']
num_epochs = config['number_epochs']

# Load the compressed train / validation / test arrays.
# NOTE(review): decompress_data is not imported explicitly; presumably it comes
# from `from utils import *` — confirm.
compressed_data_path = config['compressed_data_path']
data = decompress_data(compressed_data_path)

# Build the data loaders (plus the class list) for every split.
data_loaders_and_classes = get_loaders_and_classes(
    data,
    batch_size,
)

In [None]:
# Unpack the individual splits from the decompressed data dictionary:
# inputs first, then the matching labels.
x_train, x_val, x_test = data['x_train'], data['x_val'], data['x_test']
y_train, y_val, y_test = data['y_train'], data['y_val'], data['y_test']

In [None]:
# Build the calibration set used by TensorRT's INT8 mode: a fixed number of
# randomly chosen inputs (sampled with replacement) from each split.
# NOTE(review): validation and test samples are included, and the INT8 model is
# later evaluated on the validation set — consider calibrating on training data only.
N_CALIB_PER_SPLIT = 100
CALIB_SEED = 0  # fixed seed so the calibration set (and INT8 accuracy) is reproducible
calib_rng = np.random.default_rng(CALIB_SEED)

random_idxs_train = calib_rng.integers(len(x_train), size=N_CALIB_PER_SPLIT)
random_idxs_val = calib_rng.integers(len(x_val), size=N_CALIB_PER_SPLIT)
random_idxs_test = calib_rng.integers(len(x_test), size=N_CALIB_PER_SPLIT)

random_examples_train = x_train[random_idxs_train]
random_examples_val = x_val[random_idxs_val]
# Indices were drawn from range(len(x_test)), so they must index x_test
# (previously they indexed x_train by mistake).
random_examples_test = x_test[random_idxs_test]

# NOTE(review): torch2trt's calibrator indexes this dataset per sample; confirm each
# item's shape matches a single model input.
int_8_calib_dataset = torch.from_numpy(
    np.concatenate((random_examples_train, random_examples_test, random_examples_val))
)


In [None]:
# Per-sample input shape expected by Model4 (presumably channels x height x width — confirm).
INPUT_SHAPE = (1, 128, 45)
# Random example batch; torch2trt only uses it to trace the network during conversion.
x = torch.rand((batch_size, *INPUT_SHAPE)).cuda()

# Load the trained PyTorch network. The checkpoint location can be set with
# `model_checkpoint_path` in settings.yaml; the fallback is the previous hard-coded path.
checkpoint_path = config.get(
    'model_checkpoint_path',
    '/home/nyasha/environments/masters/Masters-Nyasha/training/model4.pt',
)
network = Model4()
# torch.load unpickles the file, so only load checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path)
network.load_state_dict(checkpoint)
network.eval().cuda()

# Creating TensorRT models: FP32, FP16 and INT8 (INT8 calibrated on int_8_calib_dataset).
model_trt = torch2trt(network, [x], max_batch_size=batch_size)
model_trt_fp16 = torch2trt(network, [x], fp16_mode=True, max_batch_size=batch_size)
model_trt_int8 = torch2trt(network, [x], int8_mode=True, max_batch_size=batch_size, int8_calib_dataset=int_8_calib_dataset)


In [None]:
# Loss used when evaluating the models: cross-entropy weighted per class, with the
# weights computed from the training labels by get_class_weights (requested on 'cuda').
class_weights = get_class_weights(y_train, 'cuda')
criterion = nn.CrossEntropyLoss(
    weight=class_weights,
)

In [None]:
def evaluate_model(model_on_device, data_loader, criterion, classes, show_cm=False, device='cuda'):
    """Evaluate a classifier and report its loss, accuracy and misclassifications.

    Args:
        model_on_device: Model already placed on ``device`` (a PyTorch module or a
            torch2trt module). It must provide ``.eval()``. Its output is reduced with
            ``argmax`` over dim 1, so the class scores are expected along that dimension.
        data_loader: Iterable of ``(images, labels)`` batches that supports ``len()``.
        criterion: Loss called as ``criterion(outputs, labels)``.
        classes: Class names; used only for the confusion-matrix axis labels.
        show_cm: If True, plot a normalised confusion matrix.
        device: Device the inputs and labels are moved to (default ``'cuda'``).

    Returns:
        Tuple ``(mean_batch_loss, accuracy_percent, errors, y_pred_errors, y_true_errors)``:
        ``errors`` is a boolean mask over all evaluated samples that marks
        misclassifications. ``y_pred_errors`` / ``y_true_errors`` hold the predicted and
        true labels of those samples, as CPU tensors with the labels' dtype.

    Raises:
        ValueError: If ``data_loader`` yields no samples.
    """
    # eval() switches dropout/batch-norm to inference behaviour; gradients are
    # disabled separately by torch.no_grad() below.
    model_on_device.eval()
    num_batches = len(data_loader)
    correct = 0
    total = 0
    running_loss = 0.0
    true_chunks = []
    pred_chunks = []

    with torch.no_grad():
        t0 = time.perf_counter()
        for images, labels in data_loader:
            images = images.to(device).float()
            labels = labels.to(device)
            outputs = model_on_device(images)
            loss = criterion(outputs, labels)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            running_loss += loss.item()

            # Collect on the CPU and concatenate once after the loop; calling
            # torch.cat on a growing tensor every batch is quadratic.
            true_chunks.append(labels.cpu())
            pred_chunks.append(predicted.cpu())
        # The .item()/.cpu() calls synchronise with the device every batch, so this
        # interval covers finished work (it also includes data loading and the loss).
        t1 = time.perf_counter()

    if total == 0:
        raise ValueError("data_loader yielded no samples")

    accuracy = 100 * correct / total
    y_tot = torch.cat(true_chunks)
    y_pred_tot = torch.cat(pred_chunks)
    errors = y_pred_tot != y_tot
    y_pred_errors = y_pred_tot[errors]
    y_true_errors = y_tot[errors]

    # Plotting the Confusion Matrix
    if show_cm:
        generate_confusion_matrix(classes, y_tot, y_pred_tot)

    print(f'Time taken on inference is {(t1-t0)/num_batches}')

    return running_loss / num_batches, accuracy, errors, y_pred_errors, y_true_errors


def generate_confusion_matrix(classes, y_tot, y_pred_tot):
    """Plot a row-normalised confusion matrix of true vs. predicted labels.

    Cell (i, j) shows the fraction of samples whose true label is i that were
    predicted as j.

    Args:
        classes: Sequence of class names; entry ``i`` labels class index ``i``.
        y_tot: 1-D CPU tensor of true labels.
        y_pred_tot: 1-D CPU tensor of predicted labels, aligned with ``y_tot``.
    """
    # NOTE(review): `plt` and `confusion_matrix` are not imported in this notebook;
    # presumably they arrive via `from utils import *` — confirm.
    num_classes = len(classes)
    np.set_printoptions(precision=4)  # changes numpy's global print options

    # Pin the label set so the matrix is always num_classes x num_classes and lines up
    # with the tick labels, even if some class never appears in y_tot or y_pred_tot.
    cm = confusion_matrix(
        y_tot.numpy(),
        y_pred_tot.numpy(),
        labels=list(range(num_classes)),
        normalize="true",
    )

    # Coloured confusion matrix
    fig, ax = plt.subplots(figsize=(12, 12))
    image = ax.imshow(cm, cmap=plt.cm.Blues)

    for (i, j), value in np.ndenumerate(cm):
        ax.text(j, i, "{:0.3f}".format(value), ha="center", va="center")

    ax.set_xticks(range(num_classes))
    ax.set_yticks(range(num_classes))
    ax.set_xticklabels(classes)
    ax.set_yticklabels(classes)
    ax.set_xlabel("Prediction")
    ax.set_ylabel("True")

    ax.set_title("Normalized Confusion Matrix for the Data")
    fig.colorbar(image, ax=ax)
    plt.show()

In [None]:
# Pull the per-split loaders and the class names out of the bundle built earlier.
train_loader, val_loader, test_loader, classes = (
    data_loaders_and_classes[key]
    for key in ('train_loader', 'val_loader', 'test_loader', 'classes')
)

In [None]:
# Baseline: validation accuracy of the original (unconverted) PyTorch network
_, accuracy, errors, y_pred_errors, y_true_errors = evaluate_model(
    model_on_device=network,
    data_loader=val_loader,
    criterion=criterion,
    classes=classes,
    show_cm=False,
)
print(accuracy)

In [None]:
# TensorRT engine built at full (FP32) precision
_, accuracy, errors, y_pred_errors, y_true_errors = evaluate_model(
    model_on_device=model_trt,
    data_loader=val_loader,
    criterion=criterion,
    classes=classes,
    show_cm=False,
)
print(accuracy)

In [None]:
# TensorRT engine built with fp16_mode=True (half precision)
_, accuracy, errors, y_pred_errors, y_true_errors = evaluate_model(
    model_on_device=model_trt_fp16,
    data_loader=val_loader,
    criterion=criterion,
    classes=classes,
    show_cm=False,
)
print(accuracy)

In [None]:
# TensorRT engine built with int8_mode=True, calibrated on int_8_calib_dataset.
# NOTE(review): the calibration set includes validation samples, so this
# validation accuracy may be slightly optimistic.
_, accuracy, errors, y_pred_errors, y_true_errors = evaluate_model(
    model_on_device=model_trt_int8,
    data_loader=val_loader,
    criterion=criterion,
    classes=classes,
    show_cm=False,
)
print(accuracy)

In [None]:
# Persist the weights of the PyTorch network and all three TensorRT variants.
# torch.save does not create missing directories, so ensure the output folder exists.
os.makedirs('trt_models', exist_ok=True)
torch.save(network.state_dict(), 'trt_models/pytorch_model.pth')
torch.save(model_trt.state_dict(), 'trt_models/model_trt_fp32.pth')
torch.save(model_trt_fp16.state_dict(), 'trt_models/model_trt_fp16.pth')
torch.save(model_trt_int8.state_dict(), 'trt_models/model_trt_int8.pth')
