In [None]:
import os
import glob
import numpy as np
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import torch
import h5py
import random

from totalface_cpu.model_zoo.model_common import load_onnx
from totalface_cpu.model_zoo.get_models import get_detection_model
from totalface_cpu.face.get_result import get_detection

In [None]:
def normalization(rgb_img,mean_list=(0.485, 0.456, 0.406),std_list=(0.229, 0.224, 0.225)):
    """Normalize an HWC image per channel and return it in CHW layout.

    Args:
        rgb_img: array-like of shape (H, W, C). Pixel values are assumed to be on
            a 0-255 scale, because the statistics below are multiplied by 255.
        mean_list: per-channel means on a 0-1 scale (default: the commonly used
            ImageNet RGB means).
        std_list: per-channel standard deviations on a 0-1 scale (default: the
            commonly used ImageNet RGB stds).

    Returns:
        Float array of shape (C, H, W) equal to
        (pixel - 255 * mean) / (255 * std), channel by channel.

    Raises:
        ValueError: if ``rgb_img`` is not 3-dimensional.
    """
    rgb_img = np.asarray(rgb_img)
    if rgb_img.ndim != 3:
        raise ValueError("expected an (H, W, C) image, got shape {}".format(rgb_img.shape))

    # Statistics are given on a 0-1 scale; lift them to the 0-255 pixel scale.
    mean = 255 * np.asarray(mean_list, dtype=np.float64)
    std = 255 * np.asarray(std_list, dtype=np.float64)

    # HWC -> CHW so the per-channel stats broadcast over the spatial axes.
    chw = np.transpose(rgb_img, (2, 0, 1))
    norm_img = (chw - mean[:, None, None]) / std[:, None, None]

    return norm_img

In [None]:
def read_image(path,bbox,mean,std,resize=128,prefix='',ori_return=False):
    """Read an image, crop a bounding box, resize the crop and normalize it.

    Args:
        path: image path; joined to ``prefix`` by plain string concatenation.
        bbox: box as (x1, y1, x2, y2) in pixel coordinates; only the first four
            values are used. Coordinates outside the image are clamped to it.
        mean: per-channel means (0-1 scale) passed to ``normalization``.
        std: per-channel standard deviations (0-1 scale) passed to ``normalization``.
        resize: side length of the square crop.
        prefix: string prepended to ``path``.
        ori_return: if True, also return the full RGB image.

    Returns:
        The normalized crop as a (resize, resize, C) float array (HWC), or
        ``(crop, full_rgb_image)`` when ``ori_return`` is True.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
        ValueError: if the box does not overlap the image.
    """
    full_path = prefix + path
    img = cv2.imread(full_path)
    if img is None:
        # cv2.imread returns None instead of raising when a file can't be read.
        raise FileNotFoundError("cannot read image: {}".format(full_path))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Detector boxes may extend past the image border; clamp them so negative
    # coordinates are not treated as Python's count-from-the-end indices.
    h, w = img.shape[:2]
    x1, y1 = max(int(bbox[0]), 0), max(int(bbox[1]), 0)
    x2, y2 = min(int(bbox[2]), w), min(int(bbox[3]), h)
    if x2 <= x1 or y2 <= y1:
        raise ValueError("bbox {} does not overlap the {}x{} image".format(list(bbox[:4]), w, h))
    cropped = img[y1:y2, x1:x2]

    input_img = cv2.resize(cropped,(resize,resize),interpolation=cv2.INTER_CUBIC)
    # normalization returns CHW; transpose back to HWC.
    input_img = normalization(input_img,mean,std)
    input_img = np.transpose(input_img,(1,2,0))

    if ori_return:
        return input_img, img
    return input_img

In [None]:
# Face detector: SCRFD weights loaded from an ONNX file.
detection_name = "scrfd"
detection_path = "scrfd_10g_bnkps.onnx"

# Detection parameters (forwarded to get_detection as `thresh` / `height_min`).
detection_thresh = 0.5
detection_height_min = 0

# load_multi defaults to False; set it to True to load multiple trt models.
detection_model = get_detection_model(detection_name, detection_path, load_multi=False)

In [None]:
# Anti-spoofing classifier loaded from ONNX. Per the original notes:
#   input  (1, 3, 128, 128)
#   output (1, 2)
# input_mean=0.0 / input_std=1.0 presumably leave the input untouched because
# read_image already normalizes it — TODO confirm against Onnx_session.
# NOTE(review): read_image returns an HWC (128, 128, 3) array; presumably
# Onnx_session adds the batch axis and reorders to CHW — confirm.
model_path = "./pretrained/anti-spoof-mn3.onnx"
model = load_onnx.Onnx_session(
    model_path,
    input_mean=0.0,
    input_std=1.0,
    output_sort=True,
)

# Class index -> label name.
pred_dict = {
    0: 'real',
    1: 'fake',
}

In [None]:
# Preprocessing parameters passed to read_image: per-channel mean/std on a 0-1
# scale (used instead of normalization's defaults) and the square crop size.
mean = [0.5931, 0.4690, 0.4229]
std = [0.2471, 0.2214, 0.2157]
resize = 128  # crop side in pixels; matches the 128x128 model input

In [None]:
# TODO: replace with the path to an image file. cv2.imread does not expand "~",
# and a directory cannot be read as an image, so the placeholder will not load.
img_path = "~"

In [None]:
# Load the image to analyse as RGB.
img_bgr = cv2.imread(img_path)
if img_bgr is None:
    # cv2.imread returns None (rather than raising) when a file can't be read,
    # which would otherwise show up as an opaque cvtColor error.
    raise FileNotFoundError("cannot read image: {}".format(img_path))
img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

In [None]:
# detection
faces = get_detection(
    detection_name,
    detection_model,
    img,
    thresh=detection_thresh,
    height_min=detection_height_min,
    input_size=(640, 640),
)

if len(faces) == 0:
    # Stop here rather than continue with an undefined (or stale, from an
    # earlier run) `face` on the next line.
    raise RuntimeError("not detected..")

if len(faces) == 1:
    face = faces[0]
else:
    # Use the face marked with `max_flag`; if none is marked, fall back to the
    # last detection (the value the original search loop ended on).
    face = next((f for f in faces if f.max_flag), faces[-1])

bbox = face['bbox'].astype(np.int32)

In [None]:
# input: crop the detected face from the same image file and preprocess it.
# (`path` / `prefix` were never defined in this notebook; the image path is
# `img_path` and read_image's prefix already defaults to '').
input_img = read_image(img_path, bbox, mean, std, resize=resize)

In [None]:
# Run the anti-spoofing model. `out` holds the two class scores in pred_dict
# order (real, fake); the original note calls them probabilities — confirm.
model_outputs = model(input_img)
out = model_outputs[0][0]
pred_idx = np.argmax(out)

In [None]:
# Optional ground truth for this image ('real' / 'fake'); `label` was never
# defined elsewhere, so it defaults to None (unknown).
label = None

print(out)  # scores: real, fake
print("gt: {} / pred: {}".format(label, pred_dict[int(pred_idx)]))

In [None]:
# Show the input image with the detected face box and the predicted label.
fig, ax = plt.subplots()
ax.imshow(img)
x1, y1, x2, y2 = bbox[:4]
ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='lime', linewidth=2))
ax.set_title("pred: {}".format(pred_dict[int(pred_idx)]))
ax.axis('off');