In [81]:
import insightface
import onnxruntime
import glob
import os.path as osp
import numpy as np
from skimage import transform as trans
import cv2

In [80]:
import time
prev_time = None
def timepoint(text):
    global prev_time
    if prev_time:
        elapsed_time = time.time() - prev_time
        print(f"{text}: {int(elapsed_time*1000)}ms")
    prev_time = time.time()

In [82]:
# face_align
arcface_dst = np.array(
    [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
     [41.5493, 92.3655], [70.7299, 92.2041]],
    dtype=np.float32)

def estimate_norm(lmk, image_size=112, mode='arcface'):
    assert lmk.shape == (5, 2)
    assert image_size%112==0 or image_size%128==0
    if image_size%112==0:
        ratio = float(image_size)/112.0
        diff_x = 0
    else:
        ratio = float(image_size)/128.0
        diff_x = 8.0*ratio
    dst = arcface_dst * ratio
    dst[:,0] += diff_x
    tform = trans.SimilarityTransform()
    tform.estimate(lmk, dst)
    M = tform.params[0:2, :]
    return M

def norm_crop2(img, landmark, image_size=112, mode='arcface'):
    M = estimate_norm(landmark, image_size, mode)
    warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0)
    return warped, M

def norm_crop(img, landmark, image_size=112, mode='arcface'):
    warped, _ = norm_crop2(img, landmark, image_size, mode)
    return warped

In [98]:
from numpy.linalg import norm

class ArcFaceONNX:
    def __init__(self):
        self.session = onnxruntime.InferenceSession("buffalo_l/w600k_r50.onnx", providers=['CPUExecutionProvider'])
        self.input_mean = 127.5
        self.input_std = 127.5
        self.input_names = [input.name for input in self.session.get_inputs()]
        self.output_names = [output.name for output in self.session.get_outputs()]

        input_cfg = self.session.get_inputs()[0]
        input_shape = input_cfg.shape
        self.input_size = tuple(input_shape[2:4][::-1])
        self.input_name = self.input_names[0]

    def get_feat(self, imgs):
        if not isinstance(imgs, list):
            imgs = [imgs]
        blob = cv2.dnn.blobFromImages(imgs, 1.0 / self.input_std, self.input_size,
                                        (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
        return self.session.run(self.output_names, {self.input_name: blob})[0]    
    
    def get_embedding(self, img, face_kps):
        aimg = norm_crop(img, landmark=face_kps, image_size=self.input_size[0])
        return self.get_feat(aimg).flatten()

def embedding_norm(embedding):
    if embedding is None:
        return None
    return norm(embedding)

def normed_embedding(embedding):
    if embedding is None:
        return None
    return embedding / embedding_norm(embedding)

-1.337291

In [65]:
def distance2kps(points, distance, max_shape=None):
    """Decode distance prediction to bounding box.

    Args:
        points (Tensor): Shape (n, 2), [x, y].
        distance (Tensor): Distance from the given point to 4
            boundaries (left, top, right, bottom).
        max_shape (tuple): Shape of the image.

    Returns:
        Tensor: Decoded bboxes.
    """
    preds = []
    for i in range(0, distance.shape[1], 2):
        px = points[:, i%2] + distance[:, i]
        py = points[:, i%2+1] + distance[:, i+1]
        if max_shape is not None:
            px = px.clamp(min=0, max=max_shape[1])
            py = py.clamp(min=0, max=max_shape[0])
        preds.append(px)
        preds.append(py)
    return np.stack(preds, axis=-1)

def distance2bbox(points, distance, max_shape=None):
    x1 = points[:, 0] - distance[:, 0]
    y1 = points[:, 1] - distance[:, 1]
    x2 = points[:, 0] + distance[:, 2]
    y2 = points[:, 1] + distance[:, 3]
    if max_shape is not None:
        x1 = x1.clamp(min=0, max=max_shape[1])
        y1 = y1.clamp(min=0, max=max_shape[0])
        x2 = x2.clamp(min=0, max=max_shape[1])
        y2 = y2.clamp(min=0, max=max_shape[0])
    return np.stack([x1, y1, x2, y2], axis=-1)

class RetinaFace:
    def __init__(self):
        self.session = onnxruntime.InferenceSession("buffalo_l/det_10g.onnx", providers=['CPUExecutionProvider'])
        self.input_mean = 127.5
        self.input_std = 127.5
        self.input_names = [input.name for input in self.session.get_inputs()]
        self.output_names = [output.name for output in self.session.get_outputs()]
        self.nms_thresh = 0.4
        self.det_thresh = 0.5
        self.input_name = self.input_names[0]
        self.center_cache = {}

        if len(self.output_names)==6:
            self.fmc = 3
            self._feat_stride_fpn = [8, 16, 32]
            self._num_anchors = 2
        elif len(self.output_names)==9:
            self.fmc = 3
            self._feat_stride_fpn = [8, 16, 32]
            self._num_anchors = 2
            self.use_kps = True
        elif len(self.output_names)==10:
            self.fmc = 5
            self._feat_stride_fpn = [8, 16, 32, 64, 128]
            self._num_anchors = 1
        elif len(self.output_names)==15:
            self.fmc = 5
            self._feat_stride_fpn = [8, 16, 32, 64, 128]
            self._num_anchors = 1
            self.use_kps = True

    def nms(self, dets):
        thresh = self.nms_thresh
        x1 = dets[:, 0]
        y1 = dets[:, 1]
        x2 = dets[:, 2]
        y2 = dets[:, 3]
        scores = dets[:, 4]

        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
        order = scores.argsort()[::-1]

        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(i)
            xx1 = np.maximum(x1[i], x1[order[1:]])
            yy1 = np.maximum(y1[i], y1[order[1:]])
            xx2 = np.minimum(x2[i], x2[order[1:]])
            yy2 = np.minimum(y2[i], y2[order[1:]])

            w = np.maximum(0.0, xx2 - xx1 + 1)
            h = np.maximum(0.0, yy2 - yy1 + 1)
            inter = w * h
            ovr = inter / (areas[i] + areas[order[1:]] - inter)

            inds = np.where(ovr <= thresh)[0]
            order = order[inds + 1]

        return keep
    
    def forward(self, img, threshold):
        scores_list = []
        bboxes_list = []
        kpss_list = []
        input_size = tuple(img.shape[0:2][::-1])
        blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
        net_outs = self.session.run(self.output_names, {self.input_name : blob})

        input_height = blob.shape[2]
        input_width = blob.shape[3]
        fmc = self.fmc
        for idx, stride in enumerate(self._feat_stride_fpn):
            scores = net_outs[idx]
            bbox_preds = net_outs[idx+fmc]
            bbox_preds = bbox_preds * stride
            kps_preds = net_outs[idx+fmc*2] * stride
            height = input_height // stride
            width = input_width // stride
            K = height * width
            key = (height, width, stride)
            if key in self.center_cache:
                anchor_centers = self.center_cache[key]
            else:
                anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32)

                anchor_centers = (anchor_centers * stride).reshape( (-1, 2) )
                if self._num_anchors>1:
                    anchor_centers = np.stack([anchor_centers]*self._num_anchors, axis=1).reshape( (-1,2) )
                if len(self.center_cache)<100:
                    self.center_cache[key] = anchor_centers

            pos_inds = np.where(scores>=threshold)[0]
            bboxes = distance2bbox(anchor_centers, bbox_preds)
            pos_scores = scores[pos_inds]
            pos_bboxes = bboxes[pos_inds]
            scores_list.append(pos_scores)
            bboxes_list.append(pos_bboxes)
            if self.use_kps:
                kpss = distance2kps(anchor_centers, kps_preds)
                kpss = kpss.reshape( (kpss.shape[0], -1, 2) )
                pos_kpss = kpss[pos_inds]
                kpss_list.append(pos_kpss)
        return scores_list, bboxes_list, kpss_list

    def detect(self, img, input_size = None, max_num=0, metric='default'):
        assert input_size is not None or self.input_size is not None
        input_size = self.input_size if input_size is None else input_size

        im_ratio = float(img.shape[0]) / img.shape[1]
        model_ratio = float(input_size[1]) / input_size[0]
        if im_ratio>model_ratio:
            new_height = input_size[1]
            new_width = int(new_height / im_ratio)
        else:
            new_width = input_size[0]
            new_height = int(new_width * im_ratio)
        det_scale = float(new_height) / img.shape[0]
        resized_img = cv2.resize(img, (new_width, new_height))
        det_img = np.zeros((input_size[1], input_size[0], 3), dtype=np.uint8)
        det_img[:new_height, :new_width, :] = resized_img

        scores_list, bboxes_list, kpss_list = self.forward(det_img, self.det_thresh)

        scores = np.vstack(scores_list)
        scores_ravel = scores.ravel()
        order = scores_ravel.argsort()[::-1]
        bboxes = np.vstack(bboxes_list) / det_scale
        kpss = np.vstack(kpss_list) / det_scale
        pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False)
        pre_det = pre_det[order, :]
        keep = self.nms(pre_det)
        det = pre_det[keep, :]

        kpss = kpss[order,:,:]
        kpss = kpss[keep,:,:]

        if max_num > 0 and det.shape[0] > max_num:
            area = (det[:, 2] - det[:, 0]) * (det[:, 3] -
                                                    det[:, 1])
            img_center = img.shape[0] // 2, img.shape[1] // 2

            offsets = np.vstack([
                (det[:, 0] + det[:, 2]) / 2 - img_center[1],
                (det[:, 1] + det[:, 3]) / 2 - img_center[0]
            ])
            offset_dist_squared = np.sum(np.power(offsets, 2.0), 0)

            if metric == 'max':
                values = area
            else:
                values = area - offset_dist_squared * 2.0  # some extra weight on the centering
            bindex = np.argsort(values)[::-1]  # some extra weight on the centering
            bindex = bindex[0:max_num]
            kpss = kpss[bindex, :]
        return kpss

retinaface = RetinaFace()
print(len(retinaface.session.get_outputs()))
kpss = retinaface.detect(cv2.imread("suzu.jpg"), input_size=(640, 640))
kpss[0]

array([[202.74078, 244.44377],
       [313.35687, 236.37285],
       [268.4746 , 308.40735],
       [210.28693, 345.5235 ],
       [317.6929 , 337.33698]], dtype=float32)

In [99]:
import onnx
import onnxruntime
from onnx import numpy_helper

import time
import numpy as np
import onnxruntime
import cv2
import onnx
from onnx import numpy_helper

class INSwapper:
    def __init__(self):
        model_file = 'models/inswapper_128.onnx'
        model = onnx.load(model_file)
        self.emap = numpy_helper.to_array(model.graph.initializer[-1])
        self.session = onnxruntime.InferenceSession(model_file, providers=['CPUExecutionProvider'])

        self.input_mean = 0.0
        self.input_std = 255.0
        self.input_names = [input.name for input in self.session.get_inputs()]
        self.output_names = [output.name for output in self.session.get_outputs()]

        self.output_shape = self.session.get_outputs()[0].shape
        self.input_shape = self.session.get_inputs()[0].shape
        self.input_size = tuple(self.input_shape[2:4][::-1])

    def swap(self, img, source_face_normed_embedding, target_face_kps):

        timepoint("start")
        aimg, M = norm_crop2(img, target_face_kps, self.input_size[0])
        blob = cv2.dnn.blobFromImage(aimg, 1.0 / self.input_std, self.input_size,
                                    (self.input_mean, self.input_mean, self.input_mean), swapRB=True)

        timepoint("norm_crop2")

        latent = source_face_normed_embedding.reshape((1,-1))
        latent = np.dot(latent, self.emap)
        latent /= np.linalg.norm(latent)
        timepoint("latent")

        pred = self.session.run(self.output_names, {self.input_names[0]: blob, self.input_names[1]: latent})[0]
        timepoint("session.run")

        #print(latent.shape, latent.dtype, pred.shape)
        img_fake = pred.transpose((0,2,3,1))[0]
        bgr_fake = np.clip(255 * img_fake, 0, 255).astype(np.uint8)[:,:,::-1]
        timepoint("clip")

        target_img = img
        fake_diff = bgr_fake.astype(np.float32) - aimg.astype(np.float32)
        fake_diff = np.abs(fake_diff).mean(axis=2)
        timepoint("np.abs")

        fake_diff[:2,:] = 0
        fake_diff[-2:,:] = 0
        fake_diff[:,:2] = 0
        fake_diff[:,-2:] = 0
        IM = cv2.invertAffineTransform(M)
        timepoint("invertAffineTransform")

        img_white = np.full((aimg.shape[0],aimg.shape[1]), 255, dtype=np.float32)
        bgr_fake = cv2.warpAffine(bgr_fake, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0)
        img_white = cv2.warpAffine(img_white, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0)
        fake_diff = cv2.warpAffine(fake_diff, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0)
        timepoint("warpAffine")

        img_white[img_white>20] = 255
        fthresh = 10
        fake_diff[fake_diff<fthresh] = 0
        fake_diff[fake_diff>=fthresh] = 255
        img_mask = img_white
        mask_h_inds, mask_w_inds = np.where(img_mask==255)
        mask_h = np.max(mask_h_inds) - np.min(mask_h_inds)
        mask_w = np.max(mask_w_inds) - np.min(mask_w_inds)
        mask_size = int(np.sqrt(mask_h*mask_w))
        k = max(mask_size//10, 10)
        #k = max(mask_size//20, 6)
        #k = 6
        timepoint("maxmax")

        kernel = np.ones((k,k),np.uint8)
        img_mask = cv2.erode(img_mask, kernel, iterations = 1)
        kernel = np.ones((2,2),np.uint8)
        fake_diff = cv2.dilate(fake_diff, kernel, iterations = 1)
        k = max(mask_size//20, 5)
        timepoint("cv2")
        #k = 3
        #k = 3
        kernel_size = (k, k)
        blur_size = tuple(2*i+1 for i in kernel_size)
        img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)
        k = 5
        kernel_size = (k, k)
        blur_size = tuple(2*i+1 for i in kernel_size)
        fake_diff = cv2.GaussianBlur(fake_diff, blur_size, 0)
        img_mask /= 255
        fake_diff /= 255
        #img_mask = fake_diff

        timepoint("GaussianBlur")

        img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1])
        fake_merged = img_mask * bgr_fake + (1-img_mask) * target_img.astype(np.float32)
        fake_merged = fake_merged.astype(np.uint8)
        cv2.imwrite("result.jpg", fake_merged)

In [100]:
source_img = cv2.imread("suzu.jpg")
target_img = cv2.imread("org.jpg")
retinaface = RetinaFace()
arcface = ArcFaceONNX()
swapper = INSwapper()

2023-06-24 13:13:49.747330 [W:onnxruntime:, graph.cc:3543 CleanUnusedInitializersAndNodeArgs] Removing initializer 'buff2fs'. It is not used by any node and should be removed from the model.


In [None]:
face_analyzer = insightface.app.FaceAnalysis(name='buffalo_l')
face_analyzer.prepare(ctx_id=0, det_size=(640, 640))

source_face = face_analyzer.get(source_img)
target_face = face_analyzer.get(target_img)

swapper.swap(target_img, target_normed_embedding, target_kps)

In [107]:
source_kps = retinaface.detect(source_img, input_size=(640, 640))[0]
target_kps = retinaface.detect(target_img, input_size=(640, 640))[0]

source_embedding = arcface.get_embedding(source_img, source_kps)
source_normed_embedding = normed_embedding(source_embedding)

swapper.swap(target_img, source_normed_embedding, target_kps)

start: 174073ms
norm_crop2: 0ms
latent: 4ms
session.run: 1065ms
clip: 0ms
np.abs: 0ms
invertAffineTransform: 0ms
warpAffine: 1ms
maxmax: 1ms
cv2: 0ms
GaussianBlur: 1ms
