In [None]:
import os
from dataclasses import dataclass,field
from typing import List,Any

from insightface.app import FaceAnalysis
import tensorflow as tf
from tensorflow.keras import Input, Model
import onnx
import onnxruntime
import sclblonnx as so
import numpy as np

In [None]:
@dataclass
class Config:
    """Detector settings consumed by the pre/post-processing graph builders.

    Instances are normally created with ``Config.from_insightface``, which
    copies the values from an insightface detection model.
    """

    # Outputs of one detector run on an all-zero image; only used as shape
    # templates for the post-processing inputs.
    example_output: Any
    # Normalisation constants: pixels become (pixel - input_mean) / input_std.
    input_std: float
    input_mean: float
    # Offset between the score, bbox and keypoint output groups.
    fmc: int
    # Strides iterated over when decoding the detector outputs.
    feat_stride_fpn: List[int]
    # Whether keypoint outputs are decoded.
    use_kps: bool
    # Minimum score for a candidate to be kept before NMS.
    threshold: float
    # Side length of the square detector input.
    size: int
    # IoU threshold handed to non-max suppression.
    iou: float
    # Number of anchors per grid location.
    num_anchors: int
    # Upper bound on the number of boxes returned by non-max suppression.
    max_output_size: int = field(default=100)

    @staticmethod
    def from_insightface(size, app):
        """Create a ``Config`` from an insightface ``FaceAnalysis``-like app.

        The app is prepared for a ``size`` x ``size`` detection input and its
        detection model is run once on a zero-filled ``(1, 3, size, size)``
        float32 array to capture sample outputs.

        Args:
            size: side length of the square detector input.
            app: object exposing ``prepare(ctx_id=..., det_size=...)`` and a
                ``det_model`` attribute.

        Returns:
            A ``Config`` populated from ``app.det_model``.
        """
        app.prepare(ctx_id=0, det_size=(size, size))
        detector = app.det_model

        # One dry run on a blank image gives the output shapes for this size.
        blank_image = np.zeros((1, 3, size, size), dtype="float32")
        sample_outputs = detector.session.run(
            detector.output_names, {detector.input_name: blank_image}
        )
        print(size)

        settings = dict(
            example_output=sample_outputs,
            input_std=detector.input_std,
            input_mean=detector.input_mean,
            fmc=detector.fmc,
            feat_stride_fpn=detector._feat_stride_fpn,
            use_kps=detector.use_kps,
            threshold=detector.det_thresh,
            size=size,
            iou=detector.nms_thresh,
            num_anchors=detector._num_anchors,
        )
        return Config(**settings)


In [None]:
def distance2bbox(points, distance):
    """Convert per-point edge distances into corner-format boxes.

    Args:
        points: 2-D tensor; column 0 is used as x and column 1 as y.
        distance: 2-D tensor whose first four columns are the distances
            subtracted from x, subtracted from y, added to x and added to y.

    Returns:
        Tensor with ``[x1, y1, x2, y2]`` stacked along a new last axis.
    """
    center_x = points[:, 0]
    center_y = points[:, 1]

    left = center_x - distance[:, 0]
    top = center_y - distance[:, 1]
    right = center_x + distance[:, 2]
    bottom = center_y + distance[:, 3]

    corners = [left, top, right, bottom]
    return tf.stack(corners, axis=-1)


def distance2kps(points, distance):
    """Convert per-point keypoint offsets into absolute keypoint coordinates.

    Even columns of ``distance`` are added to the point's column 0 (x) and
    odd columns to its column 1 (y). The point coordinates are repeated five
    times, so ``distance`` is expected to carry five (x, y) offset pairs.

    Args:
        points: 2-D tensor; column 0 is used as x and column 1 as y.
        distance: 2-D tensor of interleaved x/y offsets.

    Returns:
        Tensor with the x and y coordinates stacked along a new last axis.
    """

    def _tile(coord):
        # Repeat each coordinate five times -> one row of 5 per point.
        return tf.reshape(tf.repeat(coord, 5), (-1, 5))

    # Column indices of the x offsets; the y offsets sit one column later.
    x_cols = tf.range(0, tf.shape(distance)[1], 2)

    kps_x = _tile(points[:, 0]) + tf.gather(distance, indices=x_cols, axis=-1)
    kps_y = _tile(points[:, 1]) + tf.gather(distance, indices=x_cols + 1, axis=-1)
    return tf.stack([kps_x, kps_y], axis=-1)


def build_postprocessing(config, size_adjust=True):
    """Build the Keras model that turns raw detector outputs into detections.

    For every stride in ``config.feat_stride_fpn`` the scores, box distances
    and (when ``config.use_kps``) keypoint distances are decoded against a
    grid of anchor centres, filtered by ``config.threshold`` and concatenated.
    Non-max suppression then picks the rows that are returned.

    Args:
        config: ``Config`` describing the detector (sample outputs, strides,
            thresholds, input size, ...).
        size_adjust: when true the model takes three extra ``(1, 1)`` inputs
            named ``ratio``, ``horizontal_padding`` and ``vertical_padding``
            and rescales boxes and keypoints as
            ``ratio * coords - padding * ratio``.

    Returns:
        A ``Model`` whose outputs are ``[scores, bboxes, key_points]``.

    Note:
        The export cell renames outputs by Keras' auto-generated layer names
        (e.g. ``tf.compat.v1.gather_6``, ``tf.math.subtract_6``). Those names
        depend on how many ops of each kind are created here and in which
        order, so do not add, remove or reorder op calls without updating
        the export cell.
    """
    # One symbolic input per raw detector output, shaped like the sample
    # outputs captured in the config.
    inputs = list(
        map(
            lambda x: Input(x[1].shape, batch_size=1, name=f"input_{x[0]}"),
            enumerate(config.example_output),
        )
    )

    horizontal_padding_input = Input((1,), batch_size=1, name="horizontal_padding")
    vertical_padding_input = Input((1,), batch_size=1, name="vertical_padding")
    ratio_input = Input((1,), batch_size=1, name="ratio")

    # Drop the batch axis added by the Input layers.
    outputs = list(map(lambda x: tf.squeeze(x, 0), inputs))

    # Empty accumulators that the per-stride candidates are concatenated onto.
    scores_list = tf.reshape((), (0, 1))
    bboxes_list = tf.reshape((), (0, 4))
    kps_list = tf.reshape((), (0, 5, 2))

    for idx, stride in enumerate(config.feat_stride_fpn):
        scores = outputs[idx]
        bbox_preds = outputs[idx + config.fmc]
        bbox_preds = bbox_preds * stride
        height = config.size // stride
        width = config.size // stride
        # Grid of anchor centres in (x, y) order, scaled to input pixels.
        anchor_centers = tf.stack(
            tf.meshgrid(tf.range(height), tf.range(width), indexing="ij")[::-1], axis=-1
        )
        anchor_centers = tf.cast(anchor_centers, tf.float32)
        anchor_centers = tf.reshape(anchor_centers * stride, (-1, 2))

        if config.num_anchors > 1:
            # Repeat every centre once per anchor at that location.
            anchor_centers = tf.reshape(
                tf.stack([anchor_centers] * config.num_anchors, axis=1), (-1, 2)
            )

        # Squeeze only the trailing score axis so the mask stays 1-D even when
        # a level holds a single anchor.
        pos_inds = tf.squeeze(scores >= config.threshold, axis=-1)
        bboxes = distance2bbox(anchor_centers, bbox_preds)
        pos_scores = scores[pos_inds]
        pos_bboxes = bboxes[pos_inds]
        scores_list = tf.concat([scores_list, pos_scores], 0)
        bboxes_list = tf.concat([bboxes_list, pos_bboxes], 0)

        if config.use_kps:
            kps_preds = outputs[idx + config.fmc * 2] * stride
            kpss = distance2kps(anchor_centers, kps_preds)
            kpss = kpss[pos_inds]
            kps_list = tf.concat([kps_list, kpss], 0)

    # NMS needs 1-D scores. An axis-less squeeze would also drop the row axis
    # when exactly one candidate passed the threshold, yielding a scalar, so
    # only the trailing axis is removed.
    good = tf.image.non_max_suppression(
        bboxes_list,
        tf.squeeze(scores_list, axis=-1),
        max_output_size=config.max_output_size,
        iou_threshold=config.iou,
    )

    scores = tf.gather(scores_list, indices=good, name="scores")
    bboxes = tf.gather(bboxes_list, indices=good, name="bboxes")
    # NOTE(review): when config.use_kps is false kps_list is still empty here,
    # so this gather indexes into an empty tensor - confirm that case is
    # never exported.
    key_points = tf.gather(kps_list, indices=good, name="keypoints")
    if size_adjust:
        vertical_padding = tf.squeeze(vertical_padding_input, axis=0)
        horizontal_padding = tf.squeeze(horizontal_padding_input, axis=0)
        ratio = tf.squeeze(ratio_input, axis=0)
        # (x, y) padding for keypoints; repeated twice for (x1, y1, x2, y2).
        padding = tf.concat([horizontal_padding, vertical_padding], 0)
        key_points = ratio * key_points - padding * ratio
        bbox_padding = tf.concat([padding, padding], 0)
        bboxes = ratio * bboxes - bbox_padding * ratio

        return Model(
            [*inputs, ratio_input, horizontal_padding_input, vertical_padding_input],
            [scores, bboxes, key_points],
        )
    else:
        return Model(inputs, [scores, bboxes, key_points])


In [None]:
def build_preprocessing(config, size_adjust=True):
    """Build the Keras model that prepares a uint8 image for the detector.

    The image is cast to float32, normalised with ``config.input_mean`` and
    ``config.input_std``, optionally resized with padding to
    ``config.size`` x ``config.size``, moved to channels-first layout and
    given a leading batch axis.

    Args:
        config: ``Config`` providing ``input_mean``, ``input_std`` and ``size``.
        size_adjust: when true the image is resized with padding and the
            model additionally outputs the values needed to undo that resize.

    Returns:
        A ``Model`` with outputs
        ``[img, ratio, horizontal_padding, vertical_padding]`` when
        ``size_adjust`` is true (the last three shaped ``(1, 1)``), otherwise
        a ``Model`` with the single output ``img``.

    Note:
        The export cell looks up this model's outputs by Keras' auto-generated
        layer names (``tf.expand_dims``, ``tf.reshape``, ``tf.reshape_1``,
        ``tf.reshape_2``) and finds the squeeze node by its ``to_remove``
        name, so keep the reshape calls in this order.
    """
    input_img = Input((None, None, 3), batch_size=1, dtype=tf.uint8)
    # Named so the export cell can find this node and turn it into Identity.
    img = tf.squeeze(input_img, name="to_remove")
    img = tf.cast(img, tf.float32)

    img = (img - config.input_mean) * (1 / config.input_std)

    if size_adjust:
        height = tf.shape(img)[0]
        width = tf.shape(img)[1]

        # Padding added above/below when the image is wider than tall;
        # clamped to zero otherwise.
        vertical_padding = tf.math.maximum(
            (config.size - tf.cast((height / width) * config.size, tf.int32)) / 2,
            0,
        )

        vertical_padding = tf.cast(vertical_padding, tf.float32)
        vertical_padding = tf.reshape(vertical_padding, (1, 1))

        # Padding added left/right when the image is taller than wide;
        # clamped to zero otherwise.
        horizontal_padding = tf.math.maximum(
            (config.size - tf.cast((width / height) * config.size, tf.int32)) / 2,
            0,
        )

        horizontal_padding = tf.cast(horizontal_padding, tf.float32)
        horizontal_padding = tf.reshape(horizontal_padding, (1, 1))

        # Scale factor from detector-input pixels back to source pixels.
        ratio = tf.math.maximum(height, width) / config.size
        ratio = tf.cast(ratio, tf.float32, name="ratio")
        ratio = tf.reshape(ratio, (1, 1))

        img = tf.image.resize_with_pad(img, config.size, config.size)

    # Channels-last -> channels-first, then add the batch axis.
    img = tf.transpose(img, [2, 0, 1])
    img = tf.expand_dims(img, 0, name="img")

    if size_adjust:
        return Model(input_img, [img, ratio, horizontal_padding, vertical_padding])
    return Model(input_img, img)


In [None]:
import tf2onnx

In [None]:
# Export settings.
name = "buffalo_m"
size = 320
app = FaceAnalysis(name=name, allowed_modules=["detection"])
config = Config.from_insightface(size, app)
det = app.det_model

size_adjust = False
save_intermidiate = True

# The renames below address Keras' auto-generated layer names
# ("tf.compat.v1.gather_6", "input_1", ...). Their counters are global to the
# Keras session, so reset them; otherwise this cell only produces the expected
# names the first time it runs in a kernel.
tf.keras.backend.clear_session()

# --- Post-processing graph ---------------------------------------------------
post_processing = build_postprocessing(config, size_adjust=size_adjust)
# The tf2onnx keyword is `opset`.
onnx.save(
    tf2onnx.convert.from_keras(post_processing, opset=11)[0], "post_processing.onnx"
)
post_processing = so.graph_from_file("post_processing.onnx")
so.rename_output(post_processing, "tf.compat.v1.gather_6", "scores")

if size_adjust:
    so.rename_output(post_processing, "tf.math.subtract_6", "key_points")
    so.rename_output(post_processing, "tf.math.subtract_7", "bboxes")
else:
    so.rename_output(post_processing, "tf.compat.v1.gather_7", "bboxes")
    so.rename_output(post_processing, "tf.compat.v1.gather_8", "key_points")
if save_intermidiate:
    so.graph_to_file(post_processing, "post_processing.onnx")

# --- Pre-processing graph ----------------------------------------------------
pre_processing = build_preprocessing(config, size_adjust=size_adjust)
onnx.save(
    tf2onnx.convert.from_keras(pre_processing, opset=11)[0], "pre_processing.onnx"
)

pre_processing = so.graph_from_file("pre_processing.onnx")
so.rename_output(pre_processing, "tf.expand_dims", "img")

if size_adjust:
    so.rename_output(pre_processing, "tf.reshape_2", "ratio")
    so.rename_output(pre_processing, "tf.reshape_1", "horizontal_padding")
    so.rename_output(pre_processing, "tf.reshape", "vertical_padding")
so.replace_output(pre_processing, "img", "FLOAT", [1, 3, size, size])

if save_intermidiate:
    so.graph_to_file(pre_processing, "pre_processing.onnx")

# --- Detector graph ----------------------------------------------------------
model = so.graph_from_file(det.model_file)
so.replace_input(model, "input.1", "FLOAT", [1, 3, size, size])

# Give every detector output a leading axis of size 1 so it matches the
# batch_size=1 inputs of the post-processing graph.
for output in det.session.get_outputs():
    sq = so.node("Unsqueeze", [output.name], ["unsquezed" + output.name], axes=[0])
    so.add_node(model, sq)
    so.delete_output(model, output.name)
    so.add_output(model, "unsquezed" + output.name, "FLOAT", [1, *output.shape])

if save_intermidiate:
    so.graph_to_file(model, "model.onnx")

# --- Merge pre-processing -> detector -> post-processing ---------------------
with_preprocessing = so.merge(
    pre_processing, model, io_match=[("img", "input.1")], complete=False
)

if save_intermidiate:
    so.graph_to_file(with_preprocessing, "with_preprocessing.onnx")

# Take the output names from the in-memory graph. Loading
# "with_preprocessing.onnx" instead would fail (or read a stale file) whenever
# save_intermidiate is False, because the file is only written above when the
# flag is set.
merged_outputs = list(with_preprocessing.output)

if size_adjust:
    # The first three outputs are ratio / horizontal_padding / vertical_padding
    # from the pre-processing graph; the rest are the detector outputs.
    match = [
        ("horizontal_padding", "horizontal_padding"),
        ("vertical_padding", "vertical_padding"),
        ("ratio", "ratio"),
        *[(x.name, f"input_{i}") for i, x in enumerate(merged_outputs[3:])],
    ]
else:
    match = [(x.name, f"input_{i}") for i, x in enumerate(merged_outputs)]

final = so.merge(with_preprocessing, post_processing, io_match=match, complete=False)
if size_adjust:
    so.replace_input(final, "input_1", "UINT8", ["?", "?", 3])
else:
    so.replace_input(final, "input_1", "UINT8", [size, size, 3])
so.rename_input(final, "input_1", "image")

# The image input declared above has no batch axis, so the squeeze emitted by
# build_preprocessing (named "to_remove") is turned into a pass-through.
for elem in final.node:
    if "to_remove" in elem.name:
        elem.op_type = "Identity"
        break
else:
    raise ValueError("Failed to remove squeze")


# Make sure the target directory exists before writing the final model.
os.makedirs("models", exist_ok=True)
so.graph_to_file(
    final,
    "models/"
    + name
    + ("_with_size_adjust_" if size_adjust else "_")
    + str(size)
    + ".onnx",
)
