In [None]:
import torch
from torch import nn
import torch.nn.functional as F

import torch.onnx
import onnx
from onnx import helper
from onnx_tf.backend import prepare

import timm

import warnings
warnings.filterwarnings("ignore")

In [35]:
class ConvNeXtArcFace(nn.Module):
    """Headless timm backbone that turns an image batch into pooled feature vectors.

    The wrapper builds whichever architecture ``model_name`` selects through
    ``timm.create_model`` and strips its classifier. Note that, despite the
    class name, nothing here is specific to ConvNeXt and no ArcFace head is
    defined in this class.

    Args:
        model_name: Architecture name forwarded to ``timm.create_model``.
        embedding_size: Accepted for interface compatibility but not used by
            this class; the output width is whatever channel count
            ``forward_features`` produces.
        pretrained: Forwarded to ``timm.create_model``.
    """

    def __init__(self, model_name, embedding_size, pretrained=True):
        super().__init__()
        self.convnext = timm.create_model(model_name, pretrained=pretrained)
        # Drop the classification layer; only the feature extractor is used.
        self.convnext.reset_classifier(num_classes=0, global_pool='avg')

    def forward(self, x):
        """Return one feature vector per sample.

        Args:
            x: Image batch accepted by the backbone's ``forward_features``.

        Returns:
            Tensor of shape ``(batch, channels)``.
        """
        # Assumes forward_features yields an (N, C, H, W) feature map.
        x = self.convnext.forward_features(x)
        # Average over the whole spatial map. The previous fixed 7x7 window was
        # only a global pool for inputs that produce a 7x7 map; adaptive pooling
        # gives the same result there and stays global for other resolutions.
        return F.adaptive_avg_pool2d(x, 1).flatten(1)

In [36]:
# Build the headless backbone wrapper and put it in inference mode in one step
# (nn.Module.eval() returns the module itself, so the call can be chained).
#
# Loading fine-tuned weights is currently disabled. To export a checkpoint
# instead of the timm weights, run this after constructing the model:
#   ckpt = torch.load("../checkpoints/model.pth")
#   model.load_state_dict(ckpt['model_state_dict'])
model = ConvNeXtArcFace(
    model_name="mobilenetv4_conv_small",
    embedding_size=960,
).eval()
# Emits a single empty line.
print()

ConvNeXtArcFace(
  (convnext): MobileNetV3(
    (conv_stem): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNormAct2d(
      32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
      (drop): Identity()
      (act): ReLU(inplace=True)
    )
    (blocks): Sequential(
      (0): Sequential(
        (0): ConvBnAct(
          (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNormAct2d(
            32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): ReLU(inplace=True)
          )
          (aa): Identity()
          (drop_path): Identity()
        )
        (1): ConvBnAct(
          (conv): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNormAct2d(
            32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): 

In [37]:
# Fixed-size random batch: it drives the export trace and doubles as the
# reference input for the ONNX Runtime comparison further down.
dummy_input = torch.randn(1, 3, 224, 224)

# Reference output from PyTorch. No gradients are needed for a comparison
# value, so skip building the autograd graph.
with torch.no_grad():
    dummy_output = model(dummy_input)

onnx_model_path = "model.onnx"

# Name the graph input explicitly rather than relying on the exporter's
# auto-generated name; "input_1" is the name the renaming cell below targets.
torch.onnx.export(model, dummy_input, onnx_model_path, input_names=["input_1"])

In [38]:
# Reload the exported file from disk and run ONNX's validator on it;
# check_model raises if the model is not well-formed.
onnx_model = onnx.load(onnx_model_path)
onnx.checker.check_model(onnx_model)

In [39]:
import onnxruntime
import numpy as np
# Providers are tried in order: use CUDA when it is available, and list the CPU
# provider explicitly so the fallback on machines without a usable GPU is
# stated here rather than left implicit.
ort_session = onnxruntime.InferenceSession(
    onnx_model_path,
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)

def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array on the CPU.

    Tensors that track gradients are detached first, since ``.numpy()`` cannot
    be called on them directly.
    """
    source = tensor.detach() if tensor.requires_grad else tensor
    return source.cpu().numpy()

# Feed the dummy batch under whatever name the graph's first input carries.
first_graph_input = ort_session.get_inputs()[0]
ort_inputs = {first_graph_input.name: to_numpy(dummy_input)}
# Passing None as the output list requests every graph output.
ort_outs = ort_session.run(None, ort_inputs)

# The first ONNX Runtime output must match the PyTorch reference within
# tolerance; assert_allclose raises AssertionError otherwise.
np.testing.assert_allclose(
    to_numpy(dummy_output),
    ort_outs[0],
    rtol=1e-03,
    atol=1e-05,
)
print("Exported model looks good!")

Exported model has been tested with ONNXRuntime, and the result looks good!


In [40]:
onnx_model = onnx.load(onnx_model_path)

# Graph-input names to replace (old name -> new name). Names that are not
# present in the graph are simply left alone.
name_map = {"input.1": "input_1"}

# Rename the graph inputs in place. Mutating the existing entry keeps its
# element type and full shape description untouched; rebuilding it from
# dim_value alone would turn symbolic or unknown dimensions into 0.
for graph_input in onnx_model.graph.input:
    graph_input.name = name_map.get(graph_input.name, graph_input.name)

# Point every node that consumed an old name at the new one.
for node in onnx_model.graph.node:
    for i, input_name in enumerate(node.input):
        if input_name in name_map:
            node.input[i] = name_map[input_name]

# Make sure the edited graph is still valid before overwriting the file.
onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, onnx_model_path)

In [42]:
# Convert the in-memory ONNX model to a TensorFlow representation with onnx-tf.
tf_rep = prepare(onnx_model)
# NOTE: despite the ".pb" suffix, the export log below shows this path being
# written as a directory (it gains an assets/ folder and a fingerprint.pb), and
# the converter cell reads it with --input_format=tf_saved_model.
tf_model_path = "model.pb"
tf_rep.export_graph(tf_model_path)

INFO:absl:Function `__call__` contains input name(s) x, y with unsupported characters which will be renamed to transpose_132_x, add_44_y in the SavedModel.
INFO:absl:Found untraced functions such as gen_tensor_dict while saving (showing 1 of 1). These functions will not be directly callable after loading.


INFO:tensorflow:Assets written to: model.pb\assets


INFO:tensorflow:Assets written to: model.pb\assets
INFO:absl:Writing fingerprint to model.pb\fingerprint.pb


In [1]:
# Shell escape: runs tensorflowjs_converter through `wsl`, reading the
# SavedModel at ./model.pb and writing the converted output to ./model.
# NOTE(review): the `wsl` prefix presumably means this runs on a Windows host
# with the converter installed inside WSL; confirm before running elsewhere.
!wsl tensorflowjs_converter --input_format=tf_saved_model model.pb  model

2024-07-02 02:33:10.507924: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-07-02 02:33:10.716988: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-02 02:33:10.816641: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-02 02:33:10.817251: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-07-02 02:33:10.966320: I tensorflow/core/platform/cpu_feature_gua

In [None]:
# Persist the model weights under the 'model_state_dict' key, the same key the
# commented-out checkpoint loader in the model-construction cell reads.
# NOTE(review): this writes ./model.pth, whereas that loader points at
# ../checkpoints/model.pth; confirm which location is intended.
torch.save({'model_state_dict': model.state_dict()}, "model.pth")