In [2]:
###
# File name: Evaluate.ipynb (jupyter notebook)
# Description: evaluate generated sequences with references in test set
# Created on: 2024-05-22 (Wed.)
# Modification History
#   - 2024-05-22 (edited by Gyumin Lee): Add BLEU, GLEU, ROUGE; Implement beam search for comparisons
#   - 2024-05-23 (edited by Gyumin Lee): Add QED, SA_score, NP_score, ADMET
# Version: 0.1
###

# Load data

In [3]:
import numpy as np
import pandas as pd
pd.set_option('display.max_rows',200,'display.max_columns',50)
import csv
import time
import pickle
import os
import sys

In [4]:
# Project root. Defaults to the original location; set the DRUG_DISCOVERY_ROOT
# environment variable to run the notebook from another checkout or machine.
root = os.environ.get("DRUG_DISCOVERY_ROOT", "/home2/glee/Drug_Discovery_Research")
data_path = os.path.join(root, "data")
result_path = os.path.join(root, "results")

# Test split with the columns `T_seq` and `C_anc` used throughout this notebook.
test_set = pd.read_csv(os.path.join(data_path, "splitted", "test.csv"))
# Drop the unnamed index column when the CSV carries one; no-op otherwise.
test_set = test_set.drop(columns="Unnamed: 0", errors="ignore")

In [5]:
test_set

Unnamed: 0,T_seq,C_anc
0,MSVGAMKKGVGRAVGLGGGSGCQATEEDPLPNCGACAPGQGGRRWR...,CCCNCC1=CC=C(C=C1)CCN2C=CC(=CC2=O)OCC3=CC=C(C=...
1,MSVGAMKKGVGRAVGLGGGSGCQATEEDPLPNCGACAPGQGGRRWR...,CN(C)C1=NC(=NC2=CC=CC=C21)NC3CCC(CC3)NC(=O)C4=...
2,MDLQASLLSTGPNASNISDGQDNFTLAGPPPRTRSVSYINIIMPSV...,CC1C2CN(CCC2CC3=C1C4=C(N3)C=CC(=C4)C(F)(F)F)CC...
3,MVNLRNAVHSFLVHLIGLLVWQCDISVSPVAAIVTDIFNTSDGGRF...,C1=CC2=C(C(=C1)O)C(=CN2)C3=CN=C(N=C3N)Br
4,MVNLRNAVHSFLVHLIGLLVWQCDISVSPVAAIVTDIFNTSDGGRF...,C1CN(CCN1CCCCN2C(=O)C3C4CC(C3C2=O)C=C4)C5=C6C=...
...,...,...
8203,ATKAARKSAPATGGVKKPHRYRPGGK,CCN(C1CCC(CC1)N(C)C)C2=CC(=CC(=C2C)C(=O)NCC3C(...
8204,MNFLLSWVHWSLALLLYLHHAKWSQAAPMAEGGGQNHHEVVKFMDV...,C1CCN(C1)CCCCNC(=O)NC2=C(C(=NS2)OCC3=C(C=C(C=C...
8205,HLLDFRKMIRYTTGKEATTSYGAYGCHCGVGGRGAPKAKFLSYKFS...,CCCCCCCCCCCC(=O)C1=C(C(=C(C(=C1O)C2CC(OC3=C2C=...
8206,MSGPTMDHQEPYSVQATAAIASAITFLILFTIFGNALVILAVLTSR...,C1CN(CCN1CC2=CC3=C(C=C2)NC(=O)CO3)C4=CC=C(C=C4)Cl


In [6]:
# `C_anc` values of the test set as a Series indexed by `T_seq`.
test_set_per_T = test_set.set_index("T_seq").loc[:, "C_anc"]

# Beam evaluation

## Extract CT_adj, CTCT_adj for test set

In [8]:
sys.path.append(os.path.join(root, "src"))
from utils import save_hdf5, load_hdf5

In [None]:
## Load adjacency matrices
# `.toarray()` densifies whatever `load_hdf5` returns (presumably scipy sparse matrices).
CT_adj = load_hdf5(os.path.join(data_path, "adjacency_matrices", "sps_CT_adj.h5")).toarray()
CTC_adj = load_hdf5(os.path.join(data_path, "adjacency_matrices", "sps_CTC_adj.h5")).toarray()
CTCT_adj = load_hdf5(os.path.join(data_path, "adjacency_matrices", "sps_CTCT_adj.h5")).toarray()

## Make T-C, T-C-T-C adjacency matrix; Leave 2-hop elements only
# An entry is "2-hop only" when it is set in CTCT_adj and NOT set in CT_adj.
# A boolean mask is used instead of an int8 subtraction: the subtraction also
# flagged entries present only in CT_adj (0 - 1 = -1 -> True) and wraps around
# for values above 127.
only_CTCT = CTCT_adj.astype(bool) & ~CT_adj.astype(bool)
TC_adj = CT_adj.T
TCTC_adj = CTCT_adj.T
only_TCTC = only_CTCT.T

In [9]:
## Load dictionaries
with open(os.path.join(data_path, "dictionaries", "i_to_C_cid.pickle"), "rb") as f:
    i_to_C_cid = pickle.load(f)
with open(os.path.join(data_path, "dictionaries", "i_to_T_id.pickle"), "rb") as f:
    i_to_T_id = pickle.load(f)
with open(os.path.join(data_path, "dictionaries", "i_to_C_seq.pickle"), "rb") as f:
    i_to_C_seq = pickle.load(f)
with open(os.path.join(data_path, "dictionaries", "i_to_T_seq.pickle"), "rb") as f:
    i_to_T_seq = pickle.load(f)
with open(os.path.join(data_path, "dictionaries", "C_seq_to_i.pickle"), "rb") as f:
    C_seq_to_i = pickle.load(f)
with open(os.path.join(data_path, "dictionaries", "T_seq_to_i.pickle"), "rb") as f:
    T_seq_to_i = pickle.load(f)

## Select beams using IC50

In [10]:
# Preprocessed interaction table; the columns read here are T_seq, C_seq_can_smiles and IC50.
data = pd.read_csv(os.path.join(data_path, "preprocessed", "data_T_truncated_20240520.csv"))

# Same rows with sequences replaced by their dictionary indices (IC50 carried along as-is).
data_to_index = pd.DataFrame({
    "T_index": data["T_seq"].apply(lambda seq: T_seq_to_i[seq]),
    "C_index": data["C_seq_can_smiles"].apply(lambda smiles: C_seq_to_i[smiles]),
    "IC50": data["IC50"],
})
test_set_to_index = pd.DataFrame({
    "T_index": test_set["T_seq"].apply(lambda seq: T_seq_to_i[seq]),
    "C_index": test_set["C_anc"].apply(lambda smiles: C_seq_to_i[smiles]),
})

# "<T_index>--<C_index>" keys, one per row.
# The data rows cast each index back to int before formatting (presumably because the
# row-wise Series is upcast to float by the IC50 column — TODO confirm).
# NOTE: in the visible cells these keys are only consumed by the commented-out line below.
data_pairs = data_to_index.apply(
    lambda row: "--".join([str(row["T_index"].astype(int)), str(row["C_index"].astype(int))]),
    axis=1,
).values
test_pairs = test_set_to_index.apply(
    lambda row: "--".join([str(row["T_index"]), str(row["C_index"])]),
    axis=1,
).values
# ic50_pairs = data_to_index.loc[np.isin(data_pairs, test_pairs)]

In [11]:
def find_pair(src, by="T"):
    """Return the rows of `data_to_index` whose target or compound index equals `src`.

    Args:
        src: Index value to look up.
        by: "T" to match against the `T_index` column, "C" to match against `C_index`.

    Returns:
        The matching subset of `data_to_index` (an empty frame when nothing matches).

    Raises:
        ValueError: If `by` is neither "T" nor "C" (previously this silently returned None).
    """
    column_for = {"T": "T_index", "C": "C_index"}
    if by not in column_for:
        raise ValueError(f"`by` must be 'T' or 'C', got {by!r}")
    return data_to_index[data_to_index[column_for[by]] == src]

In [12]:
def extract_gtruth(T_input, beam_size=3):
    """Collect ground-truth anchor and positive compounds for one target.

    Compounds paired with `T_input` are visited in descending `IC50` order. For each such
    anchor compound, the other targets paired with it are scanned in descending `T_index`
    order; the first one that has at least `beam_size` compounds contributes its first
    `beam_size` compounds (descending `IC50`) as positives. If no other target has that
    many, the largest compound list seen is used instead. Anchors with no other target
    are kept aside and only used to pad the anchor list up to `beam_size`; they
    contribute no positives.

    Args:
        T_input: Target index as understood by `find_pair`, or None for an unknown target.
        beam_size: Maximum number of anchors returned and number of positives taken
            per anchor.

    Returns:
        `[C_ancs_seq, C_poss_seq]` — two lists of compound sequences looked up in
        `i_to_C_seq`. Both lists are empty when `T_input` is None.
    """
    if T_input is None:
        # Same list-of-lists shape as the regular path (this used to be a tuple of
        # empty numpy arrays, giving callers two different return types).
        return [[], []]

    # For the input T_input (T1), take anchor candidates C_anc (C1) ordered by IC50.
    # NOTE(review): `ascending=False` puts the LARGEST IC50 first; if IC50 is a raw
    # concentration (lower = more potent) this picks the weakest binders — confirm.
    C_anc_candi = find_pair(T_input, "T").sort_values("IC50", ascending=False)["C_index"].values

    C_ancs = []      # anchors that have at least one other target
    C_poss = []      # positives gathered for those anchors (flattened)
    C_ancs_bck = []  # anchors without any other target, used only as padding

    for C_anc in C_anc_candi:
        binding_pairs = find_pair(C_anc, "C")
        assert T_input in binding_pairs["T_index"].values
        # NOTE(review): the original comment said "pick the IC50 top-1 T_repur (T2)",
        # but candidates are ordered by T_index, not IC50 — confirm which is intended.
        T_repurs = (
            binding_pairs[binding_pairs["T_index"] != T_input]
            .sort_values("T_index", ascending=False)["T_index"]
            .values
        )

        if len(T_repurs) < 1:
            # No repurposed target: this compound can only serve as an anchor.
            C_ancs_bck.append(C_anc)
            continue

        C_pos_candi_bck = []
        C_pos = None
        for T_repur in T_repurs:
            # For T_repur (T2), take positive candidates C_pos (C2) ordered by IC50.
            C_pos_candi = find_pair(T_repur, "T").sort_values("IC50", ascending=False)["C_index"].values
            if len(C_pos_candi) >= beam_size:
                C_pos = C_pos_candi[:beam_size]
                break
            # Too few compounds: remember the largest list seen so far as a fallback.
            if len(C_pos_candi_bck) < len(C_pos_candi):
                C_pos_candi_bck = C_pos_candi

        if C_pos is None:
            C_pos = C_pos_candi_bck

        C_ancs.append(C_anc)
        C_poss.extend(C_pos)

        if len(C_ancs) >= beam_size:
            break

    if len(C_ancs) < beam_size:
        C_ancs.extend(C_ancs_bck[:beam_size - len(C_ancs)])

    C_ancs_seq = [i_to_C_seq[c] for c in C_ancs]
    C_poss_seq = [i_to_C_seq[c] for c in C_poss]

    return [C_ancs_seq, C_poss_seq]

## Evaluation metric

### Sequence generation

In [13]:
from nltk.translate import bleu, gleu
from rouge_score.rouge_scorer import RougeScorer

In [14]:
def seq_to_index(tseq):
    """Return the index stored for target sequence `tseq` in `T_seq_to_i`, or None if absent."""
    # Direct dictionary lookup; the previous `in list(T_seq_to_i.keys())` copied and
    # scanned every key on each call.
    return T_seq_to_i.get(tseq)

In [15]:
def cal_gen_metric(references, candidate, metric="bleu", ngram=4, smoothing_function=None):
    """Average an n-gram overlap score of `candidate` against each reference.

    Every reference is scored against the candidate on its own and the per-reference
    scores are averaged.

    Args:
        references: Sequence of reference strings.
        candidate: Generated string to score.
        metric: "bleu", "gleu" or "rouge".
        ngram: n-gram order. BLEU uses uniform weights over `ngram` orders, GLEU uses
            lengths 1..`ngram`, ROUGE reads the single `rouge{ngram}` F-measure.
        smoothing_function: Forwarded to BLEU only; ignored by the other metrics.

    Returns:
        The mean score, or `np.nan` when `references` is empty.

    Raises:
        ValueError: If `metric` is not one of the supported names (previously this
            printed a message and returned None).
    """
    if metric not in ("bleu", "gleu", "rouge"):
        raise ValueError(f"Unknown metric {metric!r}; expected 'bleu', 'gleu' or 'rouge'")
    if len(references) == 0:
        # Nothing to compare against; return NaN directly instead of letting
        # np.mean([]) emit a RuntimeWarning.
        return np.nan

    # NOTE(review): references and candidate are handed to nltk as plain strings, so
    # they are presumably iterated character by character — confirm this is intended.
    if metric == "bleu":
        bleu_weights = [tuple(np.ones(ngram) / ngram)]
        scores = [
            bleu([ref], candidate, weights=bleu_weights, smoothing_function=smoothing_function)
            for ref in references
        ]
    elif metric == "gleu":
        scores = [gleu([ref], candidate, min_len=1, max_len=ngram) for ref in references]
    else:
        # NOTE(review): RougeScorer applies its own tokenization (with stemming enabled
        # here); verify that this is suitable for the sequences being compared.
        rouge_key = f'rouge{ngram}'
        scorer = RougeScorer([rouge_key], use_stemmer=True)
        scores = [scorer.score(ref, candidate)[rouge_key].fmeasure for ref in references]
    return np.mean(scores)

In [16]:
from nltk.translate.bleu_score import SmoothingFunction
chencherry = SmoothingFunction()

In [None]:
# NOTE: `final_results` is created and filled in the "Evaluation results" section at the
# end of this notebook, so this cell only works after that section has been run; on a
# fresh top-to-bottom run it raises NameError at this position.
final_results.to_csv(os.path.join(result_path, "eval_results_20240522", f"Total_eval_result.csv"))

In [None]:
final_results

### Drug-likeness

In [21]:
from copy import copy
from rdkit.Chem import AllChem as Chem
from rdkit.Chem import QED
from rdkit.Chem.Descriptors import MolWt
from rdkit.Chem import RDConfig
sys.path.append(os.path.join(RDConfig.RDContribDir,'SA_Score'))
from rdkit.Contrib.SA_Score import sascorer
from rdkit.Contrib.NP_Score import npscorer

In [94]:
NP_fscore = npscorer.readNPModel()
def cal_drug_metric(candidate, metric="QED"):
    """Score one SMILES string with a drug-likeness metric.

    Args:
        candidate: SMILES string of the molecule to score.
        metric: "QED" (`QED.qed`), "SA" (`sascorer.calculateScore`) or "NP"
            (`npscorer.scoreMol` with the module-level `NP_fscore` model).

    Returns:
        The score, or `np.nan` when `candidate` is not a string (e.g. a missing CSV
        cell) or cannot be parsed by RDKit.

    Raises:
        ValueError: If `metric` is not one of the supported names (previously this
            printed a message and returned None).
    """
    if metric not in ("QED", "SA", "NP"):
        raise ValueError(f"Unknown metric {metric!r}; expected 'QED', 'SA' or 'NP'")
    if not isinstance(candidate, str):
        # Missing values reach here as NaN/None; RDKit would raise on them.
        return np.nan

    m = Chem.MolFromSmiles(candidate)
    if m is None:
        # RDKit returns None for SMILES it cannot parse.
        return np.nan
    if metric == "QED":
        return QED.qed(m)
    if metric == "SA":
        return sascorer.calculateScore(m)
    return npscorer.scoreMol(m, NP_fscore)

reading NP model ...
model in


In [85]:
# NOTE: `model_names` is assigned in the "Evaluation results" section further down, so
# this scratch cell raises NameError on a fresh top-to-bottom run.
model_names[5]

'sum+mlp+ifft_L2'

In [84]:
res = pd.read_csv(os.path.join(result_path, "gen_results_20240522", f"{model_names[5]}_gen_result.csv"))

In [68]:
# NOTE(review): `sample` is not assigned in any visible cell — presumably a SMILES string
# taken from `res`; define it here or remove this scratch cell.
m = Chem.MolFromSmiles(sample)

In [47]:
sample_sa = sascorer.calculateScore(m)
# The NP model is loaded as `NP_fscore`; the lower-case `np_fscore` used here before is
# not defined in any visible cell and raised NameError on a fresh run.
sample_np = npscorer.scoreMol(m, NP_fscore)

In [39]:
sample_sa

2.1221160204223253

In [48]:
sample_np

-1.1342017834469031

In [51]:
sample_qed = QED.qed(m)

In [52]:
sample_qed

0.5213178889915726

## Evaluation results

In [92]:
# Model variants to evaluate, in reporting order.
model_names = [
    "sum", "sum+fft", "sum+mlp", "sum+fft+mlp",
    "sum+mlp+ifft_L1", "sum+mlp+ifft_L2", "sum+mlp+ifft_Fro",
    "sum+fft+lpf+ifft_co=2", "sum+fft+lpf+ifft_co=4",
    "sum+fft+lpf+ifft_co=6", "sum+fft+lpf+ifft_co=8",
]

# The two n-gram orders reported for every metric.
ngram1, ngram2 = 1, 4

# Column labels "<anc|pos>/<metric>/<n>-gram", grouped by metric, then n-gram order,
# then reference set.
# NOTE: the labels spell "BLUE"/"GLUE"; they are kept verbatim because the evaluation
# loop addresses its result columns by exactly these names.
eval_cols = [
    f'{ref_set}/{metric_label}/{order}-gram'
    for metric_label in ("BLUE", "GLUE", "ROUGE")
    for order in (ngram1, ngram2)
    for ref_set in ("anc", "pos")
]

# One row per model, one column per metric; cells start empty.
final_results = pd.DataFrame(columns=eval_cols, index=model_names)

In [None]:
# (metric, n-gram order, reference column) for every entry of `eval_cols`, in the same
# order, so the labels are defined in one place instead of being repeated in a rename map.
metric_specs = [
    (metric, order, ref_col)
    for metric in ("bleu", "gleu", "rouge")
    for order in (ngram1, ngram2)
    for ref_col in ("C_ancs", "C_poss")
]
assert len(metric_specs) == len(eval_cols)

# Make sure the output directory exists before the first `to_csv`.
eval_out_dir = os.path.join(result_path, "eval_results_20240522")
os.makedirs(eval_out_dir, exist_ok=True)

# Ground truth depends only on the target sequence, so compute it once per distinct
# sequence and reuse it across rows and across models.
gtruth_cache = {}

for model_name in model_names:
    print(f"Evaluation for {model_name}")
    res = pd.read_csv(os.path.join(result_path, "gen_results_20240522", f"{model_name}_gen_result.csv"))
    # Rows without a generated compound cannot be scored.
    res = res[res["gen_compound"].notna()]

    # Target sequences with embedded newlines stripped before the dictionary lookup.
    tgt_seqs = [seq.replace("\n", "") for seq in res["tgt_protein"]]
    for seq in tgt_seqs:
        if seq not in gtruth_cache:
            gtruth_cache[seq] = extract_gtruth(seq_to_index(seq))

    res_to_compare = res.assign(
        C_ancs=pd.Series([gtruth_cache[seq][0] for seq in tgt_seqs], index=res.index, dtype=object),
        C_poss=pd.Series([gtruth_cache[seq][1] for seq in tgt_seqs], index=res.index, dtype=object),
    )

    # One score per (row, metric spec); only BLEU is smoothed (chencherry.method1).
    score_rows = []
    for ancs, poss, gen in zip(res_to_compare["C_ancs"], res_to_compare["C_poss"], res_to_compare["gen_compound"]):
        refs = {"C_ancs": ancs, "C_poss": poss}
        score_rows.append([
            cal_gen_metric(
                refs[ref_col], gen, metric=metric, ngram=order,
                smoothing_function=chencherry.method1 if metric == "bleu" else None,
            )
            for metric, order, ref_col in metric_specs
        ])
    res_eval_ngram = pd.DataFrame(score_rows, index=res_to_compare.index, columns=eval_cols)

    res_eval = pd.concat([res_to_compare, res_eval_ngram], axis=1)

    res_eval.to_csv(os.path.join(eval_out_dir, f"{model_name}_eval_result.csv"))

    # Per-model summary: mean of every metric column over all scored rows.
    final_results.loc[model_name, :] = res_eval[eval_cols].mean(axis=0)