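# Export Conservation Model

Export the PyTorch conservation predictor to ONNX, then check that the ONNX model reproduces the original model's IDs and predictions.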

In [2]:
import torch
from pathlib import Path


# The notebook is assumed to run from a sub-directory of the project, so the
# project root is taken to be the parent of the current working directory.
root_dir = Path.cwd().parent
# Checkpoint directory handed to the predictor and used as the export target root.
model_dir = root_dir.joinpath("checkpoints")


In [2]:
import os

def export_ConservationModel_to_onnx(conservation_model, onnx_file_path=None, batch_size=1,
                                     protein_length=10, embedding_dim=1024, opset_version=12):
    """Export a PyTorch conservation model to ``<onnx_file_path>/conservation.onnx``.

    The model is traced with a random dummy input of shape
    ``(batch_size, protein_length, embedding_dim)``. The batch and length axes
    (and, as in the original export, the feature axis) are declared dynamic in
    the exported graph.

    Args:
        conservation_model: ``torch.nn.Module`` to export (in this notebook,
            ``ConservationPredictor(...).model``).
        onnx_file_path: Output *directory*. It is created, including any missing
            parents, if it does not exist. Defaults to
            ``<root_dir>/checkpoints/conservation_onnx``, resolved at call time.
        batch_size: Batch size of the dummy tracing input.
        protein_length: Sequence length of the dummy tracing input.
        embedding_dim: Feature size of the dummy tracing input; 1024 is the value
            this function previously hard-coded.
        opset_version: ONNX opset version to target.
    """
    if onnx_file_path is None:
        # Resolved here instead of in the signature so the default follows the
        # current notebook-level ``root_dir`` rather than its value at def time.
        onnx_file_path = f'{root_dir}/checkpoints/conservation_onnx'

    # makedirs(exist_ok=True) also creates missing parents and avoids the
    # check-then-create race of os.path.exists + os.mkdir.
    os.makedirs(onnx_file_path, exist_ok=True)

    # Put the dummy input on the same device as the model's weights so tracing
    # does not fail with a device mismatch when the model is not on the CPU.
    first_param = next(conservation_model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    x = torch.randn(batch_size, protein_length, embedding_dim, device=device)

    torch.onnx.export(
        conservation_model,                    # model being exported
        x,                                     # dummy input used for tracing
        f"{onnx_file_path}/conservation.onnx", # where to save the model
        export_params=True,                    # store the trained weights inside the file
        opset_version=opset_version,           # ONNX opset to export to
        do_constant_folding=True,              # pre-compute constant sub-graphs
        input_names=['input'],
        output_names=['output'],
        # NOTE(review): input axis 2 (features) is declared dynamic although the
        # model presumably only accepts ``embedding_dim`` features, and output
        # axis 1 is assumed to be the protein length; confirm against forward().
        dynamic_axes={'input': {0: 'batch_size', 1: 'protein_length', 2: 'embedding_dimension'},
                      'output': {0: 'batch_size', 1: 'protein_length'}},
    )
    print(f"Model has been successfully exported to {onnx_file_path}")

In [3]:
from prott5_batch_predictor import ConservationPredictor

# use_onnx=False presumably makes the predictor load the PyTorch checkpoint,
# whose underlying torch module is what gets exported below.
conservation_model = ConservationPredictor(use_onnx=False, model_dir=model_dir)
torch_conservation_module = conservation_model.model
export_ConservationModel_to_onnx(conservation_model=torch_conservation_module)

Model has been successfully exported to /Users/pschloetermann/IdeaProjects/Biocentral_ohne_original/pgp/checkpoints/conservation_onnx


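# Test ONNX conservation model output

Compare `ids.txt` and `conservation_pred.txt` in `output_conservation_onnx` against the same files in `output_conservation_original`. Both directories must already exist under the project root.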

In [3]:
# Precondition: both directories already contain ids.txt and conservation_pred.txt
# (presumably produced by running the predictor with and without ONNX; TODO confirm).
output_dir_org = f'{root_dir}/output_conservation_original'
output_dir_onnx = f'{root_dir}/output_conservation_onnx'


def _read_output_file(output_dir, file_name):
    """Return the complete text content of ``<output_dir>/<file_name>``."""
    with open(f'{output_dir}/{file_name}', 'r') as f:
        return f.read()


ids_onnx = _read_output_file(output_dir_onnx, 'ids.txt')
conservation_pred_onnx = _read_output_file(output_dir_onnx, 'conservation_pred.txt')
ids_org = _read_output_file(output_dir_org, 'ids.txt')
conservation_pred_org = _read_output_file(output_dir_org, 'conservation_pred.txt')

# Exact text equality: the ONNX export is expected to reproduce the original
# model's written output byte-for-byte.
assert ids_onnx == ids_org, "IDs of onnx conservation model and original conservation model output are NOT identical!"
assert conservation_pred_onnx == conservation_pred_org, "Conservation output of onnx conservation model and original conservation model are NOT identical!"