In [1]:
import os
from datetime import datetime
import torch
from torch.utils.data import DataLoader

from data import MyDataset
from models.swin import SwinTransformer
from models.resnet18 import ResNet18

In [2]:
# Action units predicted by the model, in output order (the comparison cell maps
# prediction i to aus[i]). The label count is derived from this list so the two
# can never drift apart.
aus = [1,2,4,5,6,9,12,15,17,20,25,26]
num_labels = len(aus)
batch_size = 512
num_workers = 1
train = False  # evaluation mode: the loader uses shuffle=train and drop_last=train
device = "cpu"
data_root = "../data"
data = "DISFA"
onnx_name = "FaceAU"  # base name; the export cell appends a timestamp and ".onnx"
test_csv = os.path.join(data_root, data, "labels_intensity_5", "all", "test.csv")
dropout = 0.1
fm_distillation = False

class SwinConfig:
	"""Hyper-parameters handed to ``SwinTransformer``.

	Values are copied from the module-level settings ``device``, ``dropout``
	and ``num_labels``.
	"""

	def __init__(self):
		self.device = device
		# Read the shared ``dropout`` setting (previously a hard-coded 0.1, which
		# silently ignored edits to the config cell, unlike ResNet18Config).
		self.dropout = dropout
		self.num_labels = num_labels

class AU2HeatmapConfig:
	"""Heatmap-related settings: dataset name, a fixed ``sigma`` of 10.0 and the label count."""

	def __init__(self):
		# Assigned left to right, so attribute order matches data, sigma, num_labels.
		self.data, self.sigma, self.num_labels = data, 10.0, num_labels

class DatasetConfig(AU2HeatmapConfig):
	"""Configuration object passed to ``MyDataset``.

	Extends ``AU2HeatmapConfig`` (data, sigma, num_labels) with the data root
	and the ``image_size`` / ``crop_size`` values.
	"""

	def __init__(self):
		super().__init__()  # sets self.data = data, plus sigma and num_labels
		self.data_root = data_root
		# How these two sizes are applied is defined inside MyDataset.
		self.image_size, self.crop_size = 256, 224

class ResNet18Config:
	"""Hyper-parameters handed to ``ResNet18``, copied from the module-level settings."""

	def __init__(self):
		# Assigned left to right: fm_distillation, dropout, num_labels.
		self.fm_distillation, self.dropout, self.num_labels = (
			fm_distillation,
			dropout,
			num_labels,
		)


In [3]:
dataset_config = DatasetConfig()
dataset = MyDataset(test_csv, train, dataset_config)

# shuffle / drop_last follow `train`: with train=False the loader walks the
# dataset in index order and keeps the final partial batch.
loader_kwargs = dict(
	batch_size=batch_size,
	num_workers=num_workers,
	shuffle=train,
	collate_fn=dataset.collate_fn,
	drop_last=train,
)
loader = DataLoader(dataset=dataset, **loader_kwargs)

In [4]:
# Choose which architecture (and matching checkpoint) to export.
# Replaces the previous commented-out-in-a-string Swin block with an explicit switch.
arch = "resnet18"  # or "swin"

if arch == "swin":
	model_config = SwinConfig()
	model = SwinTransformer(model_config)
	ckpt_name = os.path.join("swin_checkpoint", data, "0", "swin.pt")
elif arch == "resnet18":
	model_config = ResNet18Config()
	model = ResNet18(model_config)
	ckpt_name = os.path.join("resnet_disfa_all", data, "all", "resnet.pt")
else:
	raise ValueError(f"unknown arch: {arch!r} (expected 'swin' or 'resnet18')")



In [5]:
# Load trained weights onto `device`; strict=True turns any missing or
# unexpected key into an error instead of silently skipping it.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
checkpoints = torch.load(ckpt_name, map_location=torch.device(device))["model"]
model.load_state_dict(checkpoints, strict=True)
# Disable autograd for the rest of the notebook. The previous bare
# `torch.no_grad()` call only built a context manager and discarded it,
# so it had no effect.
torch.set_grad_enabled(False)
model.eval()

ResNet18(
  (encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True

In [6]:

# Export the loaded PyTorch model to ONNX with a fixed 1x3x224x224 input.
# NOTE: an earlier evaluation snippet for this model clamped predictions with
# torch.clamp(labels_pred, min=0.0, max=5.0). That clamp is NOT part of the
# exported graph, so ONNX consumers must apply it themselves if they need it.
dummy_input = torch.rand((1, 3, 224, 224), device=device)
input_names = ["image"]
output_names = ["AUs"]

# Append the timestamp only once. Without this guard, every re-run of the cell
# appended another timestamp to the already-final file name.
if not onnx_name.endswith(".onnx"):
	onnx_name = onnx_name + datetime.now().strftime("%Y%m%d%H%M%S") + ".onnx"

torch.onnx.export(
	model,
	dummy_input,
	onnx_name,
	verbose=True,
	input_names=input_names,
	output_names=output_names,
)

verbose: False, log level: Level.ERROR



In [7]:
import onnx

# Load into its own name: rebinding `model` here replaced the PyTorch model with
# an ONNX ModelProto, which broke any later re-run of the export cell.
onnx_model = onnx.load(onnx_name)
onnx.checker.check_model(onnx_model)
print(onnx.helper.printable_graph(onnx_model.graph))

graph torch_jit (
  %image[FLOAT, 1x3x224x224]
) initializers (
  %classifier.0.weight[FLOAT, 128x512]
  %classifier.0.bias[FLOAT, 128]
  %classifier.2.weight[FLOAT, 128]
  %classifier.2.bias[FLOAT, 128]
  %classifier.2.running_mean[FLOAT, 128]
  %classifier.2.running_var[FLOAT, 128]
  %classifier.4.weight[FLOAT, 12x128]
  %classifier.4.bias[FLOAT, 12]
  %onnx::Conv_211[FLOAT, 64x3x7x7]
  %onnx::Conv_212[FLOAT, 64]
  %onnx::Conv_214[FLOAT, 64x64x3x3]
  %onnx::Conv_215[FLOAT, 64]
  %onnx::Conv_217[FLOAT, 64x64x3x3]
  %onnx::Conv_218[FLOAT, 64]
  %onnx::Conv_220[FLOAT, 64x64x3x3]
  %onnx::Conv_221[FLOAT, 64]
  %onnx::Conv_223[FLOAT, 64x64x3x3]
  %onnx::Conv_224[FLOAT, 64]
  %onnx::Conv_226[FLOAT, 128x64x3x3]
  %onnx::Conv_227[FLOAT, 128]
  %onnx::Conv_229[FLOAT, 128x128x3x3]
  %onnx::Conv_230[FLOAT, 128]
  %onnx::Conv_232[FLOAT, 128x64x1x1]
  %onnx::Conv_233[FLOAT, 128]
  %onnx::Conv_235[FLOAT, 128x128x3x3]
  %onnx::Conv_236[FLOAT, 128]
  %onnx::Conv_238[FLOAT, 128x128x3x3]
  %onnx::Conv

In [8]:
import onnxruntime as ort
import numpy as np

# Pin the execution provider explicitly (device is "cpu"). Since onnxruntime 1.9,
# builds that ship several providers (e.g. onnxruntime-gpu) raise if `providers`
# is omitted.
ort_session = ort.InferenceSession(onnx_name, providers=["CPUExecutionProvider"])

# First test sample, with a batch axis added to match the exported 1x3x224x224 input.
image, label = next(iter(dataset))
image = image.unsqueeze(dim=0).numpy()
# The exported input is FLOAT; ORT rejects other dtypes, so cast defensively
# (no copy when the array is already float32).
image = image.astype(np.float32, copy=False)

label_pred = ort_session.run(
    None,
    {"image": image},
)[0]
label_pred = np.squeeze(label_pred, axis=0)


In [9]:
# Per-AU comparison of the ONNX prediction against the ground-truth label.
for idx in range(num_labels):
    au_id, gt, pred = aus[idx], label[idx], label_pred[idx]
    err = abs(pred - gt)
    print(f"AU{au_id}:\tdiff={err:.4f}\tpred={pred:.4f}\tgt={gt:.4f}")


AU1:	diff=0.0046	pred=0.0046	gt=0.0000
AU2:	diff=0.0028	pred=0.0028	gt=0.0000
AU4:	diff=0.0008	pred=0.0008	gt=0.0000
AU5:	diff=0.0007	pred=0.0007	gt=0.0000
AU6:	diff=0.0154	pred=0.0154	gt=0.0000
AU9:	diff=0.0015	pred=0.0015	gt=0.0000
AU12:	diff=1.6542	pred=0.3458	gt=2.0000
AU15:	diff=0.0011	pred=0.0011	gt=0.0000
AU17:	diff=0.0018	pred=0.0018	gt=0.0000
AU20:	diff=0.0070	pred=0.0070	gt=0.0000
AU25:	diff=0.0149	pred=0.0149	gt=0.0000
AU26:	diff=0.0100	pred=0.0100	gt=0.0000
