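# Export BindEmbed21DL to ONNX

Exports each model of the BindEmbed21DL ensemble to its own ONNX file, then compares the prediction files written by the original PyTorch run with those written by the ONNX run.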

In [1]:
import os
import torch
from pathlib import Path

# Paths are derived from the notebook's working directory: the project root is its
# parent directory, and the model checkpoints live in <project root>/checkpoints.
root_dir = Path.cwd().parent
model_dir = root_dir.joinpath("checkpoints")


def export_bindembeddl_to_onnx(bind_model, onnx_file_path=f'{root_dir}/checkpoints/bindpredict_onnx', opset_version=12):
    """Export every model in ``bind_model`` to its own ONNX file.

    Each model is switched to eval mode and traced with a random dummy input, then
    written to ``<onnx_file_path>/cv_<index>.onnx`` (index = position in ``bind_model``).

    Args:
        bind_model: Iterable of ``torch.nn.Module`` instances to export (presumably the
            cross-validation models of the BindEmbed21DL ensemble -- TODO confirm
            against ``BindPredict.model``).
        onnx_file_path: Output directory. Created, including missing parent
            directories, if it does not exist yet.
        opset_version: ONNX opset to target. Defaults to 12, the value previously
            hard-coded here.

    Returns:
        None. Prints the path of each exported file.
    """
    # makedirs(exist_ok=True) also creates missing parents and avoids the
    # check-then-create race of os.path.exists + os.mkdir.
    os.makedirs(onnx_file_path, exist_ok=True)

    # Dummy tracing input. It does not depend on the model, so it is built once
    # and reused for every export.
    B = 2  # batch size
    N = 5  # sequence length
    C = 1024  # number of input channels/features
    x = torch.randn(B, N, C)
    # The traced graph receives channels-first input: (batch, features, sequence).
    x_transposed = torch.permute(x, (0, 2, 1))

    for index, single_bind_model in enumerate(bind_model):
        single_bind_model.eval()

        specific_onnx_file_path = f'{onnx_file_path}/cv_{index}.onnx'
        torch.onnx.export(
            single_bind_model,                   # model being run
            x_transposed,                        # the single model input
            specific_onnx_file_path,             # where to save the model
            export_params=True,                  # store the trained parameter weights inside the model file
            opset_version=opset_version,         # the ONNX opset to export the model to
            do_constant_folding=True,            # whether to execute constant folding for optimization
            # Exactly one tensor is passed to the model, so exactly one input name is
            # declared; a second ('mask') name had no corresponding model input.
            input_names=['input'],
            output_names=['output'],
            # Axis labels follow the permuted layout above: axis 1 holds the C features,
            # axis 2 holds the N sequence positions.
            # NOTE(review): if 'output' also carries a sequence axis, it should be declared
            # dynamic as well -- confirm the model's output layout.
            dynamic_axes={'input': {0: 'batch_size', 1: 'embedding_dim', 2: 'sequence_length'},
                          'output': {0: 'batch_size'}},
        )
        print(f"Model has been successfully exported to {specific_onnx_file_path}")

In [2]:
from prott5_batch_predictor import BindPredict
# Load the PyTorch models exposed by the project's BindPredict wrapper and export each to ONNX.
bind_model = BindPredict(model_dir=model_dir).model
export_bindembeddl_to_onnx(bind_model)

Model has been successfully exported to /Users/pschloetermann/IdeaProjects/Biocentral_ohne_original/pgp/checkpoints/bindembeddl_onnx/cv_0.onnx
Model has been successfully exported to /Users/pschloetermann/IdeaProjects/Biocentral_ohne_original/pgp/checkpoints/bindembeddl_onnx/cv_1.onnx
Model has been successfully exported to /Users/pschloetermann/IdeaProjects/Biocentral_ohne_original/pgp/checkpoints/bindembeddl_onnx/cv_2.onnx
Model has been successfully exported to /Users/pschloetermann/IdeaProjects/Biocentral_ohne_original/pgp/checkpoints/bindembeddl_onnx/cv_3.onnx
Model has been successfully exported to /Users/pschloetermann/IdeaProjects/Biocentral_ohne_original/pgp/checkpoints/bindembeddl_onnx/cv_4.onnx


In [6]:
from pathlib import Path


# Compare the text outputs written by the original (PyTorch) run with those written by
# the ONNX run. Both runs are expected to have produced identically named files in the
# sibling directories output_bind_org/ and output_bind_onnx/ under the project root.
root_dir = Path.cwd().parent
output_dir_org = f'{root_dir}/output_bind_org'
output_dir_onnx = f'{root_dir}/output_bind_onnx'


def read_text(path):
    """Return the entire content of the text file at ``path``."""
    with open(path, 'r') as f:
        return f.read()


metal_org = read_text(f'{output_dir_org}/binding_bindEmbed_metal_pred.txt')
nucleic_org = read_text(f'{output_dir_org}/binding_bindEmbed_nucleic_pred.txt')
small_org = read_text(f'{output_dir_org}/binding_bindEmbed_small_pred.txt')

metal_onnx = read_text(f'{output_dir_onnx}/binding_bindEmbed_metal_pred.txt')
nucleic_onnx = read_text(f'{output_dir_onnx}/binding_bindEmbed_nucleic_pred.txt')
small_onnx = read_text(f'{output_dir_onnx}/binding_bindEmbed_small_pred.txt')

ids_org = read_text(f'{output_dir_org}/ids.txt')
ids_onnx = read_text(f'{output_dir_onnx}/ids.txt')

# Every comparison reports both outcomes, so a mismatch is never silently skipped.
if ids_onnx == ids_org:
    print("IDs of ONNX bind model and original bind model output identical!")
else:
    print("IDs of ONNX bind model and original bind model output are NOT identical!")

for label, original, exported in (
    ("Metal", metal_org, metal_onnx),
    ("Nucleic", nucleic_org, nucleic_onnx),
    ("Small", small_org, small_onnx),
):
    if original == exported:
        print(f"{label} predictions are identical")
    else:
        print(f"{label} predictions are NOT identical")

IDs of nnx conservation model and original conservation model output identical!
Metal predictions are identical
Nucleic predictions are identical
Small predictions are identical
