# Convert our BiRefNet weights to ONNX format.

> This Colab notebook is adapted from [Kazuhito00](https://github.com/Kazuhito00)'s excellent work.

> Repo: https://github.com/Kazuhito00/BiRefNet-ONNX-Sample  
> Original Colab: https://colab.research.google.com/github/Kazuhito00/BiRefNet-ONNX-Sample/blob/main/Convert2ONNX.ipynb

+ Currently, a Colab instance with 12.7 GB RAM / 15 GB GPU memory cannot handle the conversion of BiRefNet with the default settings, so BiRefNet with the `swin_v1_tiny` backbone is used as the example here.

### Online Colab version: https://colab.research.google.com/drive/1z6OruR52LOvDDpnp516F-N4EyPGrp5om

In [None]:
import torch


# Released checkpoint: https://github.com/ZhengPeng7/BiRefNet/releases/download/v1/BiRefNet-general-bb_swin_v1_tiny-epoch_232.pth
weights_file = 'BiRefNet-general-bb_swin_v1_tiny-epoch_232.pth'

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
import warnings


def _backbone_selector(index):
    """Return the config.py snippet that closes the backbone list and picks entry ``index``.

    The returned text is used with ``str.replace``, so it must match config.py
    byte-for-byte, including the leading indentation of each line.
    """
    return '''
            'pvt_v2_b2', 'pvt_v2_b5',               # 9-bs10, 10-bs5
        ][''' + str(index) + ''']
        '''


# Index 3 of config.py's backbone list is swin_v1_tiny. Index 6 is what the
# non-swin_v1_tiny branch restores (presumably the default backbone).
use_swin_tiny = 'swin_v1_tiny' in weights_file
old_index, new_index = (6, 3) if use_swin_tiny else (3, 6)

with open('config.py', encoding='utf-8') as fp:
    original_lines = fp.read()
if use_swin_tiny:
    print('Set `swin_v1_tiny` as the backbone.')

file_lines = original_lines.replace(_backbone_selector(old_index), _backbone_selector(new_index))

# str.replace is a silent no-op when the pattern is absent (e.g. config.py was
# reformatted). Surface that here rather than via a confusing load_state_dict
# mismatch later on.
if _backbone_selector(new_index) not in file_lines:
    warnings.warn(
        'Could not find the expected backbone selector in config.py; '
        'the backbone was NOT changed. Please edit config.py by hand.'
    )

# Only touch the file when something actually changed.
if file_lines != original_lines:
    with open('config.py', mode='w', encoding='utf-8') as fp:
        fp.write(file_lines)

In [None]:
from utils import check_state_dict
from models.birefnet import BiRefNet


# bb_pretrained=False: presumably skips loading separate pretrained backbone
# weights, since every parameter is restored from the checkpoint below.
birefnet = BiRefNet(bb_pretrained=False)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints obtained from a trusted source (e.g. the official release).
checkpoint_path = f'./{weights_file}'
# check_state_dict is a project helper; presumably it normalises key names
# before loading. TODO confirm.
state_dict = check_state_dict(torch.load(checkpoint_path, map_location=device))
birefnet.load_state_dict(state_dict)

# 'high' lets float32 matmuls use a faster, lower internal precision
# (e.g. TF32); 'highest' would keep full float32.
torch.set_float32_matmul_precision('high')

birefnet.to(device)
_ = birefnet.eval()  # bind to `_` so the notebook does not echo the module repr

# Handle `deform_conv2d` during the conversion to ONNX

In [None]:
from torchvision.ops.deform_conv import DeformConv2d
import deform_conv2d_onnx_exporter

# Register an ONNX symbolic for torchvision's `deform_conv2d` op, provided by the
# third-party `deform_conv2d_onnx_exporter` package. This lets torch.onnx.export
# handle deformable convolutions, which the model presumably contains. It needs
# to run before convert_to_onnx() is called below.
deform_conv2d_onnx_exporter.register_deform_conv2d_onnx_op()

def convert_to_onnx(net, file_name='output.onnx', input_shape=(1024, 1024), device=device,
                    opset_version=17):
    """Export ``net`` to an ONNX file by tracing it on a random dummy image.

    Args:
        net: The ``torch.nn.Module`` to export. It should already live on ``device``,
            because the dummy input is created there.
        file_name: Destination path of the ``.onnx`` file.
        input_shape: ``(height, width)`` of the dummy image. No ``dynamic_axes``
            are passed, so the exported graph has a fixed input shape of
            ``(1, 3, height, width)`` and inference inputs must match it.
        device: Device on which the dummy input tensor is created.
        opset_version: ONNX opset to target (17 by default, as before).
    """
    # Only the shape of the dummy input matters for tracing; its values do not.
    dummy_input = torch.randn(1, 3, input_shape[0], input_shape[1], device=device)

    torch.onnx.export(
        net,
        dummy_input,
        file_name,
        verbose=False,
        opset_version=opset_version,
        input_names=['input_image'],
        output_names=['output_image'],
    )
# Write the ONNX file next to the checkpoint, swapping the .pth extension for .onnx.
onnx_file = weights_file.replace('.pth', '.onnx')
convert_to_onnx(birefnet, file_name=onnx_file, input_shape=(1024, 1024), device=device)

# Load the ONNX model and run inference.

In [None]:
from PIL import Image
from torchvision import transforms


# Resize to the fixed 1024x1024 size the model was exported with, convert to a
# tensor, and normalise with the ImageNet channel mean/std.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

imagepath = './Helicopter-HR.jpg'
# Force 3-channel RGB. Grayscale, RGBA or palette images would otherwise give
# 1 or 4 channels, which fails in Normalize or yields a wrongly shaped model
# input. The context manager closes the file once the pixels are decoded.
with Image.open(imagepath) as img:
    image = img.convert('RGB')
input_images = transform_image(image).unsqueeze(0).to(device)
# ONNX Runtime is fed a host-side NumPy array.
input_images_numpy = input_images.cpu().numpy()

In [None]:
import onnxruntime
import matplotlib.pyplot as plt


# Use the same hardware as the PyTorch model so the two backends are comparable.
if device == 'cpu':
    providers = ['CPUExecutionProvider']
else:
    providers = ['CUDAExecutionProvider']

onnx_path = weights_file.replace('.pth', '.onnx')
onnx_session = onnxruntime.InferenceSession(onnx_path, providers=providers)
# Name of the graph's single input ('input_image' as set at export time).
input_name = onnx_session.get_inputs()[0].name
print(onnxruntime.get_device(), onnx_session.get_providers())

In [None]:
from time import time
import matplotlib.pyplot as plt

from time import perf_counter

# Time one ONNX Runtime forward pass plus sigmoid. perf_counter is a
# monotonic, high-resolution clock intended for measuring intervals. This
# first run may include one-off warm-up cost.
time_st = perf_counter()
pred_onnx = torch.tensor(
    # The feed is the same NumPy array for CPU and CUDA sessions. Take the
    # last output, mirroring `birefnet(...)[-1]` on the PyTorch side.
    onnx_session.run(None, {input_name: input_images_numpy})[-1]
).squeeze(0).sigmoid()  # torch.tensor(ndarray) is already on the CPU
print(perf_counter() - time_st)

plt.imshow(pred_onnx.squeeze(), cmap='gray')
plt.show()

In [None]:
# Reference prediction from the original PyTorch model, with autograd disabled.
with torch.no_grad():
    last_output = birefnet(input_images)[-1]
    preds = last_output.sigmoid().cpu()
    del last_output  # don't keep the device tensor alive in the notebook namespace
plt.imshow(preds.squeeze(), cmap='gray')
plt.show()

In [None]:
# Per-pixel absolute difference between the PyTorch and ONNX probability maps.
diff = (preds - pred_onnx).abs()
print('sum(diff):', diff.sum())
plt.imshow(diff.squeeze(), cmap='gray')
plt.show()

# Efficiency Comparison between .pth and .onnx

In [None]:
%%timeit
# Benchmark the PyTorch model: forward pass and sigmoid. The device-to-host
# `.cpu()` copy is included; it also waits for pending GPU work, so the timing
# covers the whole computation.
with torch.no_grad():
    preds = birefnet(input_images)[-1].sigmoid().cpu()

In [None]:
%%timeit
# Benchmark the ONNX Runtime session on the same preprocessed input. The
# NumPy->tensor conversion and sigmoid are included so the measured work
# mirrors the PyTorch cell above.
pred_onnx = torch.tensor(
    onnx_session.run(None, {input_name: input_images_numpy})[-1]
).squeeze(0).sigmoid().cpu()