In [8]:
import pickle as pkl
from tape import ProteinBertForValuePredictionFragmentationProsit
%config Completer.use_jedi = False

import onnx

from onnx_tf.backend import prepare
import onnxruntime
import torch
import numpy as np

### Loading PyTorch model

In [9]:
# Restore the model weights and configuration from the checkpoint directory.
checkpoint_dir = "/sdd/prometheus/hcd_maximus_lr_1e-4_d_0.3"
pytorch_model = ProteinBertForValuePredictionFragmentationProsit.from_pretrained(checkpoint_dir)

### Load data to test converted models

In [10]:
def pad_sequences(sequences, constant_value=0, dtype=None) -> np.ndarray:
    """Stack differently-sized arrays into one batch, padding with a constant.

    The padded per-item shape is the element-wise maximum of all item shapes.
    Each item is copied into the leading corner of its slot; every other
    position holds ``constant_value``.

    Args:
        sequences: Non-empty collection of ``np.ndarray`` or ``torch.Tensor``
            items. The container type is chosen from the first item, and all
            items are assumed to have the same number of dimensions.
        constant_value: Fill value for padded positions.
        dtype: dtype of the result; defaults to the dtype of the first item.

    Returns:
        An ``np.ndarray`` when the first item is an ndarray, or a
        ``torch.Tensor`` when it is a tensor (despite the annotation), with
        shape ``[len(sequences), *max_item_shape]``.

    Raises:
        ValueError: If ``sequences`` is empty.
        TypeError: If the first item is neither an ndarray nor a tensor.
    """
    batch_size = len(sequences)
    if batch_size == 0:
        raise ValueError("pad_sequences() requires at least one sequence")

    # Element-wise maximum over all item shapes gives the padded item shape.
    shape = [batch_size] + np.max([seq.shape for seq in sequences], 0).tolist()

    first = sequences[0]
    if dtype is None:
        dtype = first.dtype

    if isinstance(first, np.ndarray):
        array = np.full(shape, constant_value, dtype=dtype)
    elif isinstance(first, torch.Tensor):
        array = torch.full(shape, constant_value, dtype=dtype)
    else:
        # Without this branch `array` would be unbound and the loop below
        # would fail with a confusing UnboundLocalError.
        raise TypeError(
            "pad_sequences() expects np.ndarray or torch.Tensor items, got "
            + type(first).__name__
        )

    for arr, seq in zip(array, sequences):
        # `arr` is a view into `array`; write the item into its leading corner.
        arrslice = tuple(slice(dim) for dim in seq.shape)
        arr[arrslice] = seq

    return array

In [11]:
# NOTE: unpickling can execute arbitrary code; only load result files that
# come from a trusted source.
# The context manager closes the file handle as soon as loading finishes.
with open("/sdd/prometheus/results.pkl", "rb") as results_file:
    results = pkl.load(results_file)

### Median spectral angle

In [12]:
# Every model input comes from the same stored record; wrapping each value in
# a one-element list gives it a leading axis of length 1.
sample = results[1][0]

sequence_arr = np.array([sample["sequence"]])
collision_energy_arr = np.array([sample["collision_energy"]])
charge_arr = np.array([sample["charge"]])

# Tensor versions for the PyTorch model; the numpy versions are kept for the
# converted models.
sequence = torch.tensor(sequence_arr)
collision_energy = torch.tensor(collision_energy_arr)
charge = torch.tensor(charge_arr)

# Mask of ones with the same shape as the sequence batch, padded with zeros.
input_mask_arr = pad_sequences(np.ones_like(sequence), 0)
input_mask = torch.from_numpy(input_mask_arr)

In [21]:
# Positional arguments for the export, in the same order as the direct
# call to pytorch_model.
INPUT = (
    sequence,
    collision_energy,
    charge,
    input_mask,
)

# NOTE(review): "charge" and "collision_energy" are listed in the opposite
# order to INPUT. torch.onnx.export assigns input names positionally, so the
# exported input labelled "charge" receives the collision-energy tensor and
# vice versa. Code that feeds the exported model by input name must use the
# same (swapped) labels, so do not reorder this list without updating those
# call sites.
INP_Name = [
    "sequence",
    "charge",
    "collision_energy",
    "input_mask",
]

In [18]:
# Reference prediction from the original PyTorch model; element 0 of the
# model's return value is kept.
# no_grad() skips autograd bookkeeping, which is not needed for inference.
with torch.no_grad():
    o = pytorch_model(sequence, collision_energy, charge, input_mask=input_mask)[0]

In [19]:
# Convert the reference prediction to numpy and keep the first entry along
# its leading axis.
reference_batch = o.detach().cpu().numpy()
real_out = reference_batch[0]

### Export PyTorch model

In [22]:
# Trace the PyTorch model and write it, weights included, to "tape.onnx".
# dynamic_axes keys must be names from input_names/output_names; keys that
# are not are ignored by the exporter. The variable-length axis is therefore
# declared on the names actually used here ("sequence" and "input_mask").
torch.onnx.export(
    pytorch_model,
    args=INPUT,
    f="tape.onnx",
    input_names=INP_Name,
    output_names=["output1"],
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    dynamic_axes={
        "sequence": {1: "sequence_length"},  # second axis is dynamic
        "input_mask": {1: "sequence_length"},
    },
)

### Load model into ONNX Runtime

In [23]:


# Open the exported graph with ONNX Runtime.
onnx_path = "tape.onnx"
ort_session = onnxruntime.InferenceSession(onnx_path)

In [24]:
# Arrays in the positional order used for the PyTorch call and the export.
ARG = [sequence_arr, collision_energy_arr, charge_arr, input_mask_arr]

# Read the input names from the loaded session, in graph order, instead of
# repeating a hand-written list. Each array is then paired with the graph
# input at the same position, whatever label that input was given at export.
INP = [graph_input.name for graph_input in ort_session.get_inputs()]

# zip() would silently drop entries on a length mismatch, so check up front.
assert len(INP) == len(ARG), (
    f"exported model expects {len(INP)} inputs, but {len(ARG)} arrays were prepared"
)


In [25]:
# Pair each input name with its array; passing None as the first argument
# asks the session for all of its outputs.
ort_inputs = dict(zip(INP, ARG))
ort_outs = ort_session.run(None, ort_inputs)

In [26]:
# Take the first returned output, then its first entry along the leading axis.
batch_output = ort_outs[0]
onnx_runtime_out = batch_output[0]

In [27]:
# Loose agreement check between the PyTorch and ONNX Runtime predictions
# (1% relative tolerance, numpy's default absolute tolerance).
loose_rtol = 1e-02
np.allclose(real_out, onnx_runtime_out, rtol=loose_rtol)

True

In [28]:
# PyTorch values at the positions where the two predictions differ under the
# tighter 1e-04 relative tolerance.
mismatch = ~np.isclose(real_out, onnx_runtime_out, rtol=1e-04)
real_out[mismatch]

array([ 1.9282460e-02,  1.7728955e-02,  9.7595900e-04,  1.2260258e-02,
       -2.0688266e-02,  1.7904788e-03,  9.9739060e-05, -1.2320578e-03,
       -8.9403093e-03,  4.0559620e-03, -1.8938631e-03,  2.4975985e-03,
        2.9895268e-04,  3.4641698e-03], dtype=float32)

In [29]:
# ONNX Runtime values at the positions where the two predictions differ under
# the tighter 1e-04 relative tolerance.
mismatch = ~np.isclose(real_out, onnx_runtime_out, rtol=1e-04)
onnx_runtime_out[mismatch]

array([ 1.92847252e-02,  1.77313983e-02,  9.76443291e-04,  1.22627765e-02,
       -2.06859037e-02,  1.79086812e-03,  9.97888856e-05, -1.23066083e-03,
       -8.93931091e-03,  4.05688956e-03, -1.89276040e-03,  2.49802880e-03,
        2.98997387e-04,  3.46497446e-03], dtype=float32)

### Load the ONNX model into TensorFlow

In [46]:
# Read the exported ONNX graph back from disk and hand it to the onnx-tf
# backend, requesting the GPU device.
onnx_path = 'tape.onnx'
model = onnx.load(onnx_path)
tf_rep = prepare(model, device='GPU')

In [47]:
# Positional inputs for the TensorFlow backend, in the same order as the
# arguments used for the export (INPUT).
ARG = [
    sequence_arr,
    collision_energy_arr,
    charge_arr,
    input_mask_arr,
]

In [48]:
# Run the TensorFlow representation; keep the first output and its first
# entry along the leading axis.
# NOTE(review): in the recorded run the displayed tf_output values (e.g. 5.37
# for the first element) are far from real_out (about 0.108), so the
# TensorFlow conversion did not reproduce the PyTorch prediction. No
# np.allclose check is made for this backend - investigate before relying
# on it.
tf_outputs = tf_rep.run(ARG)
tf_output = tf_outputs[0][0]

In [49]:
# Display the TensorFlow-backend prediction for visual comparison with real_out.
tf_output

array([ 5.37016106e+00, -4.10081930e-02,  0.00000000e+00,  9.81307983e+00,
        0.00000000e+00,  0.00000000e+00, -1.23881340e-01,  1.37594953e-01,
       -2.93145934e-03, -1.70865893e-01,  2.27926582e-01,  3.77799273e-02,
        1.59233761e+00,  3.32535028e-01, -3.03387605e-02,  5.49741793e+00,
        1.09856918e-01, -2.83962861e-02, -1.67809629e+00,  4.76672232e-01,
       -3.04797590e-02,  1.93920574e+01,  2.77297616e-01,  4.86301221e-02,
       -2.59120965e+00,  1.70746231e+00, -2.65571792e-02,  1.50965567e+01,
        5.23026466e-01,  9.99324024e-03, -2.87721515e+00,  2.38855209e+01,
        1.07080251e-01,  2.42795992e+00, -4.00280580e-02,  5.08896820e-02,
       -3.15944934e+00,  4.15128860e+01,  4.11418945e-01,  1.54867560e-01,
        5.24484098e-01,  1.53549194e-01, -4.08090973e+00,  8.10036774e+01,
        6.55273080e-01,  1.56922534e-01,  3.68687391e-01, -7.35779703e-02,
       -3.63134956e+00,  3.30631447e+01,  6.96953118e-01, -6.17278516e-02,
        9.77721289e-02,  

In [50]:
# Display the PyTorch reference prediction for comparison with the
# TensorFlow output shown above.
real_out

array([ 1.07584782e-01, -1.44136837e-03,  0.00000000e+00,  7.23717734e-02,
        0.00000000e+00,  0.00000000e+00,  2.32030600e-02,  3.12625058e-03,
       -3.04760120e-04,  4.68800664e-02,  5.37447724e-03,  9.82693629e-04,
        3.62140089e-02,  6.47953805e-03, -1.93171436e-04,  7.50481486e-02,
        2.50327517e-03, -6.24939334e-04,  2.33075917e-02,  4.38676495e-03,
       -9.81787685e-04,  5.58439255e-01,  1.82294622e-02,  1.66263001e-03,
        2.34790891e-02,  2.75802407e-02, -1.53056229e-03,  6.27640843e-01,
        2.19078138e-02, -3.40388855e-04,  1.92824602e-02,  2.38126948e-01,
        1.29890000e-03,  1.63167894e-01,  9.12479404e-03,  1.84179016e-03,
        1.77289546e-02,  1.12298667e+00,  1.16544403e-02,  9.75959003e-04,
        8.70637782e-03,  4.46990784e-03,  1.22602582e-02,  2.69511509e+00,
        7.24992678e-02,  6.67252205e-03,  1.05688376e-02, -2.82301707e-03,
       -2.06882656e-02,  1.21580076e+00,  5.49417250e-02,  1.79047883e-03,
        8.62263516e-03,  