In [4]:
import pandas as pd
import keras
from keras.models import load_model
import numpy as np
from tqdm import tqdm

In [5]:
# Evaluation table: one row per (lig_id, pro_id) candidate pair, with `dest`
# pointing at the pre-computed .npy feature grid for that pair.
df = pd.read_csv('./data/csv/test_acc10_300.csv')

# Saved Keras model under evaluation.
model_path = './models/basic_c_and_n_regression_hydrophobic_only.h5'

In [6]:
df.head(n=5)  # quick look at the first rows of the evaluation table

Unnamed: 0,lig_id,pro_id,dest,score
0,2701,2701,./processed_data/test_acc10/2701_pro_2701_lig.npy,1
1,2701,2702,./processed_data/test_acc10/2702_pro_2701_lig.npy,0
2,2701,2703,./processed_data/test_acc10/2703_pro_2701_lig.npy,0
3,2701,2704,./processed_data/test_acc10/2704_pro_2701_lig.npy,0
4,2701,2705,./processed_data/test_acc10/2705_pro_2701_lig.npy,0


In [9]:
def test_model(model_path, data=None, top_k=10, dims=(24, 24, 24), n_channels=2):
    """Evaluate a saved Keras model by per-ligand top-k retrieval accuracy.

    Rows are grouped by ``lig_id``. Within each group the model scores every
    candidate row, using the feature grid loaded from that row's ``dest`` path.
    A ligand counts as a hit when the row whose ``pro_id`` equals the ligand's
    id ranks among the ``top_k`` highest scores. The resulting accuracy is
    printed.

    Parameters
    ----------
    model_path : str
        Path passed to ``keras.models.load_model``.
    data : pandas.DataFrame, optional
        Table with ``lig_id``, ``pro_id`` and ``dest`` columns. Defaults to the
        notebook-level ``df`` so existing ``test_model(model_path)`` calls work.
    top_k : int
        Rank cut-off used for the printed accuracy (default 10).
    dims : tuple of int
        Spatial shape of each feature grid (default ``(24, 24, 24)``).
    n_channels : int
        Trailing channel count of each feature grid (default 2). Each
        ``np.load(dest)`` result must fit an array of shape ``(*dims, n_channels)``.

    Returns
    -------
    all_match_ids : list of (int or None)
        Per ligand, the 0-based position within its group of the matching
        protein row, or None when the group contains no such row.
    all_largest_ids : list of list of int
        Per ligand, the group positions sorted by predicted score, highest first.
    """
    if data is None:
        data = df  # notebook-level table loaded in an earlier cell

    model = load_model(model_path)

    all_match_ids = []
    all_largest_ids = []
    matches = 0

    for lig_id, grp in tqdm(data.groupby('lig_id')):
        # Fresh 0..n-1 index so group positions line up with the rows of X
        # and with the indices produced by argsort below.
        grp = grp.reset_index(drop=True)

        # Extract a plain int position. Comparing a pandas Index against list
        # items only works for exactly one match, so it is avoided here.
        # If the pair appears more than once, the first occurrence is used.
        positions = np.flatnonzero(grp['pro_id'].to_numpy() == lig_id)
        match_id = int(positions[0]) if len(positions) else None

        # Stack every candidate's feature grid into one batch for prediction.
        X = np.empty((len(grp), *dims, n_channels))
        for pos, path in enumerate(grp['dest']):
            X[pos] = np.load(path)

        probs = model.predict(X, batch_size=200).flatten()

        # Descending order; identical to reversing the ascending argsort.
        largest_first = np.argsort(probs)[::-1].tolist()

        if match_id is not None and match_id in largest_first[:top_k]:
            matches += 1
        all_match_ids.append(match_id)
        all_largest_ids.append(largest_first)

    n_ligands = len(all_match_ids)
    print(matches / n_ligands if n_ligands else float('nan'))
    return all_match_ids, all_largest_ids

In [10]:
# Score every ligand group (prints top-10 accuracy) and keep the per-ligand
# match positions and score rankings for the follow-up metrics below.
(all_match, all_largest) = test_model(model_path)


  0%|          | 0/300 [00:00<?, ?it/s][A
  0%|          | 1/300 [00:01<06:21,  1.27s/it][A
  1%|          | 2/300 [00:04<08:41,  1.75s/it][A
  1%|          | 3/300 [00:07<11:36,  2.35s/it][A
  1%|▏         | 4/300 [00:11<13:53,  2.82s/it][A
  2%|▏         | 5/300 [00:15<15:14,  3.10s/it][A
  2%|▏         | 6/300 [00:19<16:14,  3.31s/it][A
  2%|▏         | 7/300 [00:23<16:56,  3.47s/it][A
  3%|▎         | 8/300 [00:26<17:18,  3.56s/it][A
  3%|▎         | 9/300 [00:30<17:35,  3.63s/it][A
  3%|▎         | 10/300 [00:34<17:37,  3.65s/it][A
  4%|▎         | 11/300 [00:38<17:47,  3.69s/it][A
  4%|▍         | 12/300 [00:42<18:01,  3.75s/it][A
  4%|▍         | 13/300 [00:46<18:07,  3.79s/it][A
  5%|▍         | 14/300 [00:49<17:57,  3.77s/it][A
  5%|▌         | 15/300 [00:53<18:00,  3.79s/it][A
  5%|▌         | 16/300 [00:57<17:51,  3.77s/it][A
  6%|▌         | 17/300 [01:00<17:41,  3.75s/it][A
  6%|▌         | 18/300 [01:04<17:38,  3.75s/it][A
  6%|▋         | 19/300 [01:0

0.8866666666666667


In [12]:
def is_top_k_hit(match, ranked, k):
    """Return True when the matching position(s) appear in ``ranked[:k]``.

    ``match`` may be an int position, None (ligand had no matching row), or an
    array-like / pandas Index of positions. Any one of them in the top ``k``
    counts as a hit. This avoids relying on ``Index in list`` truthiness,
    which fails for zero or multiple matches.
    """
    if match is None:
        return False
    top = set(ranked[:k])
    return any(int(pos) in top for pos in np.atleast_1d(match))


paired = list(zip(all_match, all_largest))
total = len(paired)
acc_3_matches = sum(is_top_k_hit(match, largest, 3) for match, largest in paired)
# Guard against an empty evaluation set instead of raising ZeroDivisionError.
print(acc_3_matches / total if total else float('nan'))

0.7
