# TMbed Model
Exporting the TMbed Model to ONNX

Previous issue: The PyTorch function `unfold` in `Predictor.forward` is not supported by ONNX and had to be implemented manually.

In [12]:
import os
import torch
from pathlib import Path

# Project root is taken to be the parent of the current working directory.
root_dir = Path.cwd().parent
model_dir = root_dir / "checkpoints"


def export_TMbed_to_onnx(tmbed_model, model_index=0, onnx_file_path=f'{root_dir}/checkpoints/tmbed_onnx'):
    """Export every sub-model of a TMbed ensemble to its own ONNX file.

    Each element of ``tmbed_model.model`` is switched to eval mode and traced
    with a dummy ``(x, mask)`` input pair, then written to
    ``<onnx_file_path>/cv_<index>.onnx``.

    Args:
        tmbed_model: Object exposing a ``model`` attribute that can be iterated
            to obtain the individual modules to export.
        model_index: Currently unused; kept so existing callers that pass it
            keep working.
        onnx_file_path: Directory that receives the exported files. It is
            created (including missing parent directories) if it does not
            exist yet. The default is computed once, when the function is
            defined, from the module-level ``root_dir``.

    Returns:
        None. Progress is reported via ``print``.
    """
    # makedirs(exist_ok=True) also creates missing parents and does not race
    # with a concurrent creation, unlike an exists() check followed by mkdir().
    os.makedirs(onnx_file_path, exist_ok=True)

    # Dummy tracing inputs are identical for every sub-model, so build them once.
    batch_size = 2        # B
    sequence_length = 505  # N
    embedding_dim = 1024   # C: number of input channels/features

    x = torch.randn(batch_size, sequence_length, embedding_dim)
    # Mask tensor with shape (B, N). All ones means no masking.
    mask = torch.ones(batch_size, sequence_length)

    for index, single_tmbed_model in enumerate(tmbed_model.model):
        single_tmbed_model.eval()

        specific_onnx_file_path = f'{onnx_file_path}/cv_{index}.onnx'
        # Export the model
        torch.onnx.export(
            single_tmbed_model,                  # model being run
            (x, mask),                           # model input (or a tuple for multiple inputs)
            specific_onnx_file_path,             # where to save the model
            export_params=True,                  # store the trained parameter weights inside the model file
            opset_version=12,                    # the ONNX version to export the model to
            do_constant_folding=True,            # whether to execute constant folding for optimization
            input_names=['input', 'mask'],       # the model's input names
            output_names=['output'],             # the model's output names
            # Variable-length axes.
            # NOTE(review): only the batch axis of 'output' is declared dynamic;
            # if the output also carries a sequence-length axis it should
            # probably be listed here as well -- confirm against the model.
            dynamic_axes={'input': {0: 'batch_size', 1: 'sequence_length', 2: 'embedding_dim'},
                          'mask': {0: 'batch_size', 1: 'sequence_length'},
                          'output': {0: 'batch_size'}}
        )
        print(f"Model has been successfully exported to {specific_onnx_file_path}")

In [13]:
from prott5_batch_predictor import TMbed

# Instantiate TMbed with the checkpoints directory, then export each of its
# sub-models to ONNX using the exporter's default output directory.
tmbed = TMbed(model_dir=model_dir)
export_TMbed_to_onnx(tmbed_model=tmbed)




Model has been successfully exported to /Users/pschloetermann/IdeaProjects/Biocentral_ohne_original/pgp/checkpoints/tmbed_onnx/cv_0.onnx




Model has been successfully exported to /Users/pschloetermann/IdeaProjects/Biocentral_ohne_original/pgp/checkpoints/tmbed_onnx/cv_1.onnx




Model has been successfully exported to /Users/pschloetermann/IdeaProjects/Biocentral_ohne_original/pgp/checkpoints/tmbed_onnx/cv_2.onnx




Model has been successfully exported to /Users/pschloetermann/IdeaProjects/Biocentral_ohne_original/pgp/checkpoints/tmbed_onnx/cv_3.onnx




Model has been successfully exported to /Users/pschloetermann/IdeaProjects/Biocentral_ohne_original/pgp/checkpoints/tmbed_onnx/cv_4.onnx


# Test: TMbed ONNX Model
This is where we test whether the normally loaded model and the ONNX model produce the same results.

In [14]:
from pathlib import Path

# Compare the text files written by the original TMbed run with those written
# by the ONNX run. Both output directories must already exist and be populated.
root_dir = Path.cwd().parent
output_dir_org = f'{root_dir}/output_tmbed_original'
output_dir_onnx = f'{root_dir}/output_tmbed_onnx'


def _read_output(output_dir, file_name):
    """Return the complete text content of ``file_name`` inside ``output_dir``."""
    with open(f'{output_dir}/{file_name}', 'r') as f:
        return f.read()


ids_onnx = _read_output(output_dir_onnx, 'ids.txt')
membrane_tmbed_onnx = _read_output(output_dir_onnx, 'membrane_tmbed.txt')

ids_org = _read_output(output_dir_org, 'ids.txt')
membrane_tmbed_org = _read_output(output_dir_org, 'membrane_tmbed.txt')

# The files are compared as whole strings, so any difference fails the check.
assert ids_onnx == ids_org, "IDs of onnx tmbed model and original tmbed model output are NOT identical!"
assert membrane_tmbed_onnx == membrane_tmbed_org, "Membrane output of onnx tmbed model and original tmbed model are NOT identical!"