In [1]:
import time
import csv
import onnxruntime as rt
import numpy as np
from pathlib import Path
import json
from meta_deepFRI.DeepFRI.deepfrier.Predictor import Predictor
from meta_deepFRI.DeepFRI.deepfrier.utils import seq2onehot
from meta_deepFRI.utils.bio_utils import protein_letters_1to3
from tqdm import tqdm

# write a wrapper to record time of function execution 
def timeit(func):
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        end = time.time() - start
        return result, end
    return wrapper

@timeit
def predict_tf(pred: Predictor, seq=None, cmap=None):
    if cmap is None:
        result_tf = pred.model(seq, training=False)
    else:
        result_tf = pred.model([cmap, seq], training=False)
    return result_tf

@timeit
def predict_onnx(session, seq, cmap=None):
    inputDetails = session.get_inputs()
    if cmap is None:
        result_onnx = session.run(None, {inputDetails[0].name: seq.astype(np.float32)})
    else:
        result_onnx = session.run(None, {inputDetails[0].name: cmap.astype(np.float32), inputDetails[1].name: seq.astype(np.float32)})
    
    return result_onnx

def generate_random_protein(prot_length):
    aa_array = np.random.choice(list(protein_letters_1to3.keys()), size=(1, prot_length))
    return "".join(list(*aa_array))

2023-04-28 10:40:46.260998: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /nfs/nas22/fs2201/biol_micro_unix_modules/modules/software/binutils/2.37/lib:/nfs/nas22/fs2201/biol_micro_unix_modules/modules/software/Perl/5.32.0-GCCcore-10.2.0/lib:/nfs/nas22/fs2201/biol_micro_unix_modules/modules/software/DB/18.1.40-GCCcore-10.2.0/lib:/nfs/nas22/fs2201/biol_micro_unix_modules/modules/software/libreadline/8.0-GCCcore-10.2.0/lib:/nfs/nas22/fs2201/biol_micro_unix_modules/modules/software/gettext/0.21-GCCcore-10.2.0/lib:/nfs/nas22/fs2201/biol_micro_unix_modules/modules/software/ncurses/6.2-GCCcore-10.2.0/lib:/nfs/nas22/fs2201/biol_micro_unix_modules/modules/software/libxml2/2.9.10-GCCcore-10.2.0/lib:/nfs/nas22/fs2201/biol_micro_unix_modules/modules/software/XZ/5.2.5-GCCcore-10.2.0/lib:/nfs/nas22/fs2201/biol_micro_unix_modules

In [88]:
model_path =  "/nfs/nas22/fs2202/biol_micro_sunagawa/Projects/EAN/PROPHAGE_REFSEQ_EAN/scratch/databases/meta_deepfri_data/newest_models/trained_models"
with open(model_path + "/model_config.json") as f:
    config = json.loads(f.read())

inference_times = []

for net_type in ["cnn", "gcn"]:
        for mode in ["ec", "mf", "bp", "cc"]:
           
            tf_model_prefix = str(Path(model_path) / config[net_type]["models"][mode].split("/")[-1])
            model_name = Path(tf_model_prefix).name
            print(model_name)
            pred = Predictor(model_prefix = tf_model_prefix, gcn=config[net_type]["gcn"])
            session = rt.InferenceSession(f'../onnx_deepfri_models/{model_name}.onnx', 
                                          providers=['CPUExecutionProvider'])

            # Testing GCNs
            for seed in tqdm(range(1000), total=100):
                np.random.seed(seed)
                # zero out cmap
                cmap = None
                
                # generate random protein and contact map
                prot_len = np.random.randint(60, 1000, size=1)[0]
                seq = generate_random_protein(prot_length=prot_len)
                one_hot = seq2onehot(seq)
                one_hot = one_hot.reshape(1, *one_hot.shape)
                
                if net_type == "gcn":
                    cmap = np.random.randint(0, 2, size=(1, prot_len, prot_len), dtype=int)
                    
                tf_pred, tf_time = predict_tf(pred, seq=one_hot, cmap=cmap)
                onnx_pred, onnx_time = predict_onnx(session, seq=one_hot, cmap=cmap)

                # Skip if predictions all nan
                if np.isnan(onnx_pred[0]).all():
                    continue

                # Compare predictions from TF and ONNX
                assert np.isclose(tf_pred, onnx_pred, atol=10e-5).all()

                # Record inference time
                inference_times.append([prot_len, net_type, mode, tf_time, onnx_time])

            print(f"{model_name} passed the test!")

# write inference to csv
with open("inference_times.csv", "w") as f:
    writer = csv.writer(f)
    writer.writerow(["prot_len", "net_type", "mode", "tf_time", "onnx_time"])
    writer.writerows(inference_times)

DeepCNN-MERGED_enzyme_commission


100%|██████████| 100/100 [00:07<00:00, 13.08it/s]


DeepCNN-MERGED_enzyme_commission passed the test!
DeepCNN-MERGED_molecular_function


100%|██████████| 100/100 [00:07<00:00, 14.10it/s]


DeepCNN-MERGED_molecular_function passed the test!
DeepCNN-MERGED_biological_process


100%|██████████| 100/100 [00:08<00:00, 11.25it/s]


DeepCNN-MERGED_biological_process passed the test!
DeepCNN-MERGED_cellular_component


100%|██████████| 100/100 [00:07<00:00, 14.25it/s]


DeepCNN-MERGED_cellular_component passed the test!
DeepFRI-MERGED_GraphConv_gcd_512-512-512_fcd_1024_ca_10.0_ec


100%|██████████| 100/100 [01:56<00:00,  1.17s/it]


DeepFRI-MERGED_GraphConv_gcd_512-512-512_fcd_1024_ca_10.0_ec passed the test!
DeepFRI-MERGED_GraphConv_gcd_512-512-512_fcd_1024_ca_10.0_mf


100%|██████████| 100/100 [01:59<00:00,  1.19s/it]


DeepFRI-MERGED_GraphConv_gcd_512-512-512_fcd_1024_ca_10.0_mf passed the test!
DeepFRI-MERGED_GraphConv_gcd_512-512-512_fcd_1024_ca_10.0_bp


100%|██████████| 100/100 [02:00<00:00,  1.20s/it]


DeepFRI-MERGED_GraphConv_gcd_512-512-512_fcd_1024_ca_10.0_bp passed the test!
DeepFRI-MERGED_GraphConv_gcd_512-512-512_fcd_1024_ca_10.0_cc


100%|██████████| 100/100 [01:58<00:00,  1.18s/it]

DeepFRI-MERGED_GraphConv_gcd_512-512-512_fcd_1024_ca_10.0_cc passed the test!



