In [1]:
# Automatically re-import edited local modules (e.g. onnx_inference, trt_inference)
# before each cell runs, so changes to them are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [16]:
from pathlib import Path

import torchvision
from torchvision import transforms as T
import torch
# torchvision and cv2 seem to clash (on Jetson)
# import cv2

import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image
from tqdm import tqdm

In [3]:
from onnx_inference import Model
from trt_inference import TRTModel, TRTInferenceBackend

In [4]:
# Model under evaluation as a (name, ONNX path) pair; only the path is used below.
# Alternative export: Path('assets') / 'ssrnet_dynamic_simplified.onnx'
ssrnet_label = 'SSRNet_32'
ssrnet_onnx = Path('assets') / 'ssrnet_bs1.onnx'
ssrnet = (ssrnet_label, ssrnet_onnx)

In [5]:
# Image preprocessing: resize to 64x64, convert to a float tensor, then
# normalize each channel with these mean/std values (the standard ImageNet
# statistics).
CHANNEL_MEAN = [0.485, 0.456, 0.406]
CHANNEL_STD = [0.229, 0.224, 0.225]
transform = T.Compose([
    T.Resize((64, 64)),
    T.ToTensor(),
    T.Normalize(CHANNEL_MEAN, CHANNEL_STD),
])

In [6]:
# Test-image directory; the label lists are read from its sibling 'list' directory.
dir_path = Path('assets', 'megaage', 'test')

# (Original) ONNX Model Evaluation

In [7]:
# Load the ONNX export through the project's Model wrapper, then call use_gpu()
# (presumably selects a GPU execution provider for `model.sess` -- confirm in onnx_inference).
onnx_path = ssrnet[1]
model = Model(onnx_path)
model.use_gpu()

In [8]:
# Run the ONNX model over the test set, one image per batch. Images are visited
# in lexicographically sorted filename order (e.g. '10.jpg' before '2.jpg').
img_paths = sorted(dir_path.glob('*.jpg'))
# Fail here with a clear message instead of a confusing np.stack error later.
assert img_paths, f'no .jpg images found in {dir_path}'

predictions = []
for img_path in tqdm(img_paths):
    # `with` closes the file handle; convert('RGB') guarantees the 3 channels
    # that the 3-element Normalize statistics in `transform` require.
    with Image.open(img_path) as img:
        im = transform(img.convert('RGB'))
    imgs_np = im.unsqueeze(0).numpy()  # prepend a batch axis of size 1
    preds = model.sess.run([model.output_name], {model.input_name: imgs_np})
    # np.array copies, so each stored prediction owns its memory.
    predictions.append(np.array(preds))

100%|██████████| 8530/8530 [01:18<00:00, 108.83it/s]


In [9]:
# Stack the per-image ONNX outputs and flatten them into a single 1-D vector.
y_onnx = np.ravel(np.stack(predictions))

In [10]:
# Ground-truth ages, one per line of test_age.txt.
# The prediction loop visits images in lexicographic filename order
# (sorted(dir_path.glob('*.jpg'))). If the label file follows a different order
# (e.g. numeric: '1.jpg', '2.jpg', ..., '10.jpg'), labels and predictions are
# silently misaligned. When list/test_name.txt exists (assumed to be
# line-aligned with test_age.txt, as in the MegaAge layout -- TODO confirm),
# re-index the labels by filename so gt[i] belongs to the i-th processed image.
list_dir = dir_path.parent / 'list'
gt = np.loadtxt(list_dir / 'test_age.txt')
names_file = list_dir / 'test_name.txt'
if names_file.exists():
    names = [Path(line.strip()).name
             for line in names_file.read_text().splitlines() if line.strip()]
    assert len(names) == len(gt), (
        f'{names_file} lists {len(names)} images but test_age.txt has {len(gt)} ages')
    age_by_name = dict(zip(names, gt))
    gt = np.array([age_by_name[p.name] for p in sorted(dir_path.glob('*.jpg'))])

In [11]:
# Predictions and labels must pair up one-to-one; a single NaN/inf prediction
# would silently turn the MAE below into NaN/inf.
assert gt.shape == y_onnx.shape, f'shape mismatch: gt {gt.shape} vs ONNX predictions {y_onnx.shape}'
assert np.isfinite(y_onnx).all(), 'ONNX predictions contain NaN or inf'

In [12]:
# Mean absolute error of the ONNX predictions against the ground-truth ages.
MAE_onnx = np.mean(np.abs(y_onnx - gt))

In [13]:
# Report the ONNX error.
print(f'Mean Absolute Error: {MAE_onnx}')

Mean Absolute Error: 12.799150404868623


# (Optimized) TensorRT Model Evaluation

In [14]:
# The TensorRT engine must see exactly the same inputs as the ONNX model, so
# reuse the `transform` defined for the ONNX run instead of re-declaring a
# copy-pasted pipeline that could silently drift. Displayed for reference.
transform

In [18]:
# TensorRT setup: TRTModel receives the same ONNX graph, an engine file path
# (presumably built there or loaded from there -- confirm in trt_inference) and
# the ONNX model's input shape without its first (presumably batch) dimension.
batch_size = 1  # paired with the batch-size-1 export (ssrnet_bs1.*) -- TODO confirm
engine_path = Path('assets') / 'ssrnet_bs1.trt'
input_shape = model.input_shape[1:]
trt_model = TRTModel(ssrnet[1], engine_path, input_shape)
backend = TRTInferenceBackend(trt_model, batch_size)

In [19]:
# Run the TensorRT engine over the test set in the same lexicographically
# sorted filename order as the ONNX run, so the two prediction vectors line up
# element by element.
img_paths = sorted(dir_path.glob('*.jpg'))
# Fail here with a clear message instead of a confusing np.stack error later.
assert img_paths, f'no .jpg images found in {dir_path}'

predictions = []
for img_path in tqdm(img_paths):
    # `with` closes the file handle; convert('RGB') guarantees the 3 channels
    # that the 3-element Normalize statistics in `transform` require.
    with Image.open(img_path) as img:
        im = transform(img.convert('RGB'))
    imgs_np = im.unsqueeze(0).numpy()  # prepend a batch axis of size 1
    preds = backend.run(imgs_np)
    # Copy the result: if backend.run returns a view of a host output buffer
    # that it reuses between calls (not visible from here -- check
    # trt_inference), storing the view would let later calls overwrite
    # earlier predictions.
    predictions.append(np.array(preds))

100%|██████████| 8530/8530 [01:44<00:00, 81.79it/s] 


In [20]:
# Stack the per-image TensorRT outputs and flatten them into a single 1-D vector.
y_trt = np.ravel(np.stack(predictions))

In [21]:
# Ground-truth ages, one per line of test_age.txt.
# The prediction loop visits images in lexicographic filename order
# (sorted(dir_path.glob('*.jpg'))). If the label file follows a different order
# (e.g. numeric: '1.jpg', '2.jpg', ..., '10.jpg'), labels and predictions are
# silently misaligned. When list/test_name.txt exists (assumed to be
# line-aligned with test_age.txt, as in the MegaAge layout -- TODO confirm),
# re-index the labels by filename so gt[i] belongs to the i-th processed image.
list_dir = dir_path.parent / 'list'
gt = np.loadtxt(list_dir / 'test_age.txt')
names_file = list_dir / 'test_name.txt'
if names_file.exists():
    names = [Path(line.strip()).name
             for line in names_file.read_text().splitlines() if line.strip()]
    assert len(names) == len(gt), (
        f'{names_file} lists {len(names)} images but test_age.txt has {len(gt)} ages')
    age_by_name = dict(zip(names, gt))
    gt = np.array([age_by_name[p.name] for p in sorted(dir_path.glob('*.jpg'))])

In [22]:
# Predictions and labels must pair up one-to-one; a single NaN/inf prediction
# would silently turn the MAE below into NaN/inf.
assert gt.shape == y_trt.shape, f'shape mismatch: gt {gt.shape} vs TensorRT predictions {y_trt.shape}'
assert np.isfinite(y_trt).all(), 'TensorRT predictions contain NaN or inf'

In [23]:
# Mean absolute error of the TensorRT predictions against the ground-truth ages.
MAE_trt = np.mean(np.abs(y_trt - gt))

In [24]:
# Report the TensorRT error.
print(f'Mean Absolute Error: {MAE_trt}')

Mean Absolute Error: 14.412206657159231


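# Results

Note: in the recorded run, the mean absolute difference between the TensorRT and ONNX predictions (`MAE_between` ≈ 23.9) is larger than either backend's MAE against the ground truth (≈ 12.8 for ONNX, ≈ 14.4 for TensorRT). A faithful engine should differ from the ONNX model only by small numerical noise. The TensorRT outputs and the engine build therefore need checking before the reported "error increase" is taken at face value.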

In [28]:
# Side-by-side summary. The relative change is how much larger the TensorRT MAE
# is than the ONNX MAE, in percent (negative would mean TensorRT is more accurate).
error_increase_pct = MAE_trt / MAE_onnx * 100 - 100
print('MAE ONNX:', MAE_onnx)
print('MAE TensorRT:', MAE_trt)
print(f'Error increase: {error_increase_pct:.1f}%')


MAE ONNX: 12.799150404868623
MAE_TensorRT: 14.412206657159231
Error increase: 12.6%


In [25]:
# Mean absolute disagreement between the TensorRT and ONNX predictions for the
# same images. It is averaged over the predictions themselves, so it does not
# depend on the labels; the shape check stops broadcasting from hiding a mismatch.
assert y_trt.shape == y_onnx.shape, f'shape mismatch: TensorRT {y_trt.shape} vs ONNX {y_onnx.shape}'
MAE_between = np.mean(np.abs(y_trt - y_onnx))

In [29]:
# Display the mean absolute ONNX-vs-TensorRT prediction difference computed above.
MAE_between

23.883264947245017