In [None]:
import os
from datetime import datetime
import torch
from torch.utils.data import DataLoader

from data import MyDataset
from models.resnet18 import ResNet18
from models.swin import SwinTransformer
from models.mae import MaskedAutoEncoder

In [None]:
# --- Notebook configuration -------------------------------------------------
# Action-unit ids reported by the notebook; entry i labels model output i
# in the final comparison cell.
aus = [1, 2, 4, 5, 6, 9, 12, 15, 17, 20, 25, 26]
# Derived from `aus` so the label count can never drift from the AU list.
num_labels = len(aus)

# DataLoader settings.
batch_size = 512
num_workers = 1
# Passed to MyDataset and reused as the loader's shuffle/drop_last flags.
train = False
# Device used for checkpoint loading and for the dummy export input.
device = "cpu"

# Location of the test split.
data_root = "../../LibreFace_TestData"
data = "DISFA"
test_csv = os.path.join(data_root, data, "labels_intensity_5", "all", "test.csv")

# Prefix of the exported ONNX file name (the export cell appends model name and timestamp).
onnx_name = "LibreFace"

# Model hyper-parameters shared by the config classes below.
dropout = 0.1
fm_distillation = False
hidden_dim = 128

# Which architecture to export: "resnet", "swin" or "emotionnet_mae".
model_name = "emotionnet_mae"

class SwinConfig:
    """Settings object handed to SwinTransformer.

    All values are read from the notebook-level configuration at
    construction time.
    """

    def __init__(self) -> None:
        self.device = device
        # Use the shared `dropout` constant (as the other model configs do)
        # instead of a separately hard-coded literal.
        self.dropout = dropout
        self.num_labels = num_labels

class AU2HeatmapConfig:
    """Base settings object; DatasetConfig extends it.

    Holds the dataset name, a `sigma` value and the number of labels.
    """

    def __init__(self) -> None:
        self.data = data
        # NOTE(review): presumably the heatmap spread, going by the class name — confirm.
        self.sigma = 10.0
        self.num_labels = num_labels

class DatasetConfig(AU2HeatmapConfig):
    """Settings object passed to MyDataset.

    Extends AU2HeatmapConfig with the data location and image sizes.
    """

    def __init__(self) -> None:
        super().__init__()
        # Already assigned by the parent initializer; repeated here unchanged.
        self.data = data
        self.data_root = data_root
        self.image_size = 256
        # 224 matches the spatial size of the dummy input used for ONNX export.
        self.crop_size = 224

class ResNet18Config:
    """Settings object handed to ResNet18.

    All values are read from the notebook-level configuration at
    construction time.
    """

    def __init__(self) -> None:
        self.fm_distillation = fm_distillation
        self.dropout = dropout
        self.num_labels = num_labels
	
class MaskedAutoEncoderConfig:
    """Settings object handed to MaskedAutoEncoder.

    All values are read from the notebook-level configuration at
    construction time.
    """

    def __init__(self) -> None:
        self.fm_distillation = fm_distillation
        self.dropout = dropout
        self.num_labels = num_labels
        self.hidden_dim = hidden_dim


In [None]:
# Build the test dataset from the CSV listing and wrap it in a DataLoader.
dataset_config = DatasetConfig()
dataset = MyDataset(test_csv, train, dataset_config)

# `train` is False here, so batches are neither shuffled nor dropped.
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=train,
    num_workers=num_workers,
    collate_fn=dataset.collate_fn,
    drop_last=train,
)

In [None]:
# Instantiate the architecture chosen by `model_name` and pick the matching
# checkpoint path (relative to the notebook's working directory).
if model_name == "resnet":
    model_config = ResNet18Config()
    model = ResNet18(model_config)
    ckpt_name = os.path.join("resnet_disfa_all", data, "all", "resnet.pt")
elif model_name == "swin":
    model_config = SwinConfig()
    model = SwinTransformer(model_config)
    ckpt_name = os.path.join("swin_checkpoint", data, "0", "swin.pt")
elif model_name == "emotionnet_mae":
    model_config = MaskedAutoEncoderConfig()
    model = MaskedAutoEncoder(model_config)
    ckpt_name = os.path.join("mae_checkpoint", data, "0", "emotionnet_mae.pt")
else:
    # Raise explicitly rather than `assert False`: assertions are stripped
    # under `python -O`, which would let an unknown name fall through and
    # fail later with an unrelated NameError on `model`.
    raise ValueError(
        f"Unknown model_name {model_name!r}; expected 'resnet', 'swin' or 'emotionnet_mae'"
    )

In [None]:
# Load the trained weights onto `device`; the checkpoint keeps the state dict
# under its "model" key.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
checkpoints = torch.load(ckpt_name, map_location=torch.device(device))["model"]
# strict=True: fail loudly if the checkpoint keys do not match the model.
model.load_state_dict(checkpoints, strict=True)
# A bare `torch.no_grad()` call only constructs a context manager and throws
# it away, so it never disabled autograd. set_grad_enabled(False) used as a
# plain function call actually switches gradient tracking off for the
# following cells.
torch.set_grad_enabled(False)
# Switch to inference behaviour; the returned module is shown as the cell output.
model.eval()

In [None]:
# A disabled evaluation snippet previously kept here ran `model(images)` over
# `loader` and clamped the predictions with torch.clamp(min=0.0, max=5.0).
# That clamp is applied outside `model`, so it is not part of the exported
# graph; consumers of the ONNX file presumably need to apply it themselves
# — TODO confirm.

# Single random RGB image; only its shape and dtype matter for tracing.
dummy_input = torch.rand((1, 3, 224, 224), device=device)
input_names = ["image"]
output_names = ["AUs"]

# `onnx_name` starts out as a bare prefix and is overwritten below with the
# exported file name, because the following cells read it. Remember the prefix
# so that re-running this cell does not stack suffixes onto a previous file
# name ("..._<ts>.onnx_<model>_<ts>.onnx").
if not onnx_name.endswith(".onnx") or "onnx_prefix" not in globals():
    onnx_prefix = onnx_name
onnx_name = "{0}_{1}_{2}.onnx".format(
    onnx_prefix, model_name, datetime.now().strftime("%Y%m%d%H%M%S")
)

torch.onnx.export(
    model,
    dummy_input,
    onnx_name,
    verbose=True,
    input_names=input_names,
    output_names=output_names,
)

In [None]:
import onnx

# Load the exported file under its own name: binding it to `model` would
# replace the PyTorch module and break any re-run of the export cell.
onnx_model = onnx.load(onnx_name)
# Raises if the exported graph is not a well-formed ONNX model.
onnx.checker.check_model(onnx_model)
print(onnx.helper.printable_graph(onnx_model.graph))

In [None]:
import onnxruntime as ort
import numpy as np

# Run the exported model through onnxruntime on the first dataset sample.
ort_session = ort.InferenceSession(onnx_name)

image, label = next(iter(dataset))
# Add a leading batch axis and convert to a NumPy array for onnxruntime.
image = image.unsqueeze(dim=0).numpy()

# Passing None requests every output; take the first and drop the batch axis again.
label_pred = np.squeeze(ort_session.run(None, {"image": image})[0], axis=0)


In [None]:
# Per-AU comparison of the onnxruntime prediction against the ground truth
# of the sample evaluated above.
for i in range(num_labels):
    gt, pred = label[i], label_pred[i]
    print(f"AU{aus[i]}:\tdiff={abs(pred - gt):.4f}\tpred={pred:.4f}\tgt={gt:.4f}")
