In [1]:
import torch
from pathlib import Path


from prott5_batch_predictor import SecStructPredictor


# The project root is the parent of the notebook's working directory;
# model checkpoints are read from its "checkpoints" sub-directory.
root_dir = Path.cwd().parent
model_dir = root_dir.joinpath("checkpoints")

# NOTE(review): model_dir is passed both to the constructor and to load_model — confirm both are required.
sec_struct_model = (
    SecStructPredictor(model_dir=model_dir)
    .load_model(model_dir=model_dir)
)

In [8]:
import os

def export_sec_struct_to_onnx(sec_struct_model, onnx_file_path=f'{root_dir}/checkpoints/sec_struct_onnx',
                              batch_size=2, sequence_length=505, embedding_dim=1024, opset_version=12):
    """Export ``sec_struct_model`` to ONNX as ``<onnx_file_path>/secstruct.onnx``.

    The model is traced with a single random dummy tensor of shape
    ``(batch_size, sequence_length, embedding_dim)``.

    Args:
        sec_struct_model: The torch module to export.
        onnx_file_path: Output *directory* (not a file path). It is created, including any
            missing parent directories, if it does not exist. The default is built from the
            notebook-level ``root_dir`` at the time this cell is executed.
        batch_size: Batch dimension of the dummy trace input (default 2).
        sequence_length: Sequence dimension of the dummy trace input (default 505).
        embedding_dim: Feature dimension of the dummy trace input (default 1024).
        opset_version: ONNX opset to export to (default 12).

    Returns:
        str: Path of the written ``secstruct.onnx`` file.
    """
    # exists() + mkdir() raised when a parent directory was missing. makedirs creates the
    # whole chain, and exist_ok makes re-running this cell a no-op for the directory.
    os.makedirs(onnx_file_path, exist_ok=True)

    # Dummy input used only for tracing. NOTE(review): embedding_dim presumably has to match
    # the model's expected input width (1024 by default) — confirm before changing it.
    x = torch.randn(batch_size, sequence_length, embedding_dim)

    specific_onnx_file_path = f'{onnx_file_path}/secstruct.onnx'
    torch.onnx.export(
        sec_struct_model,                    # model being run
        x,                                   # the single model input
        specific_onnx_file_path,             # where to save the model
        export_params=True,                  # store the trained parameter weights inside the model file
        opset_version=opset_version,         # the ONNX opset to export the model to
        do_constant_folding=True,            # whether to execute constant folding for optimization
        # List exactly one name per tensor passed in `x`. Names beyond the real inputs are
        # not ignored: depending on the torch version they are either rejected or applied to
        # the next graph input (possibly a weight initializer).
        input_names=['input'],
        output_names=['d3_Yhat', 'd8_Yhat'],  # the model's output names
        # NOTE(review): the outputs only declare a dynamic batch axis, while the input's
        # sequence axis is dynamic. If the outputs are per-residue, their sequence axis
        # presumably needs to be declared dynamic as well — confirm the output layout.
        dynamic_axes={'input': {0: 'batch_size', 1: 'sequence_length', 2: 'embedding_dimension'},
                      'd3_Yhat': {0: 'batch_size'}, 'd8_Yhat': {0: 'batch_size'}},
    )
    print(f"Model has been successfully exported to {specific_onnx_file_path}")
    return specific_onnx_file_path

In [9]:
# Export using the function's default output directory under root_dir/checkpoints.
export_sec_struct_to_onnx(sec_struct_model);

Model has been successfully exported to /Users/pschloetermann/IdeaProjects/Biocentral_ohne_original/pgp/checkpoints/sec_struct_onnx/secstruct.onnx


# Compare ONNX predictions with the original model's predictions

In [3]:
from pathlib import Path


# Both pipelines write the same set of files into sibling output directories.
root_dir = Path.cwd().parent
output_dir_org = f'{root_dir}/output_sequence_org'
output_dir_onnx = f'{root_dir}/output_sequence_onnx'


def _read_output(output_dir, file_name):
    """Return the full text content of ``<output_dir>/<file_name>``."""
    with open(f'{output_dir}/{file_name}', 'r') as f:
        return f.read()


def _describe_mismatch(text_onnx, text_org):
    """Return a short hint about where two texts first differ (used in assertion messages)."""
    lines_onnx, lines_org = text_onnx.splitlines(), text_org.splitlines()
    for line_no, (line_onnx, line_org) in enumerate(zip(lines_onnx, lines_org), start=1):
        if line_onnx != line_org:
            return f'first difference at line {line_no}: onnx={line_onnx[:60]!r} org={line_org[:60]!r}'
    if len(lines_onnx) != len(lines_org):
        return f'line counts differ: onnx={len(lines_onnx)} org={len(lines_org)}'
    return 'texts differ only in line endings or trailing newline'


sequence_pred3_onnx = _read_output(output_dir_onnx, 'dssp3_pred.txt')
sequence_pred8_onnx = _read_output(output_dir_onnx, 'dssp8_pred.txt')
sequence_pred3_org = _read_output(output_dir_org, 'dssp3_pred.txt')
sequence_pred8_org = _read_output(output_dir_org, 'dssp8_pred.txt')
ids_org = _read_output(output_dir_org, 'ids.txt')
ids_onnx = _read_output(output_dir_onnx, 'ids.txt')

# The prediction files are compared as whole texts. A match is only meaningful if both
# runs processed the same sequences in the same order, so check the ids first.
assert ids_onnx == ids_org, \
    f"Sequence IDs are NOT identical ({_describe_mismatch(ids_onnx, ids_org)})"
assert sequence_pred3_onnx == sequence_pred3_org, \
    f"DSSP3 predictions are NOT identical ({_describe_mismatch(sequence_pred3_onnx, sequence_pred3_org)})"
assert sequence_pred8_onnx == sequence_pred8_org, \
    f"DSSP8 predictions are NOT identical ({_describe_mismatch(sequence_pred8_onnx, sequence_pred8_org)})"