In [1]:
import os
import sys
import tqdm
import pandas as pd

In [2]:
REBADD_LIB_PATH = os.path.abspath(os.pardir)
if REBADD_LIB_PATH not in sys.path:
    sys.path = [REBADD_LIB_PATH] + sys.path

from rebadd.evaluate import evaluate_sr_nov_div

In [3]:
filepath_ref = os.path.join(os.pardir, 'data', 'chembl', 'chembl_test_full.csv')
df_ref = pd.read_csv(filepath_ref)
referece_smiles_iter = df_ref.loc[:,'smiles'].values.tolist()

print(len(referece_smiles_iter))

315


In [4]:
input_dir = 'outputs_4_calculate_properties'

filenames = [f'smi_after.csv.{num:04d}' for num in range(50,550,50)]

## gsk3b_jnk3_qed_sa

In [5]:
frames = []

for modelname in ['gsk3_jnk3_qed_sa']:
    
    for filename in filenames:
        
        filepath = os.path.join(input_dir, modelname, filename)
        
        df = pd.read_csv(filepath)
        
        df = df.loc[:,('smiles', 'gsk3b', 'jnk3', 'qed', 'sa')]
        df.loc[:,'model'] = modelname
        df.loc[:,'checkpoint'] = filename.split('.')[-1]
        
        frames.append(df)
        
        
data = []

for df in tqdm.tqdm(frames):

    s_sr, s_nov, s_div = evaluate_sr_nov_div(df, referece_smiles_iter, 'gsk3b_jnk3_qed_sa')
    
    data.append({'SR':s_sr, 'Nov':s_nov, 'Div':s_div, 'Model':df.loc[0,'model'], 'Ckpt':df.loc[0,'checkpoint']})
    

df_records = pd.DataFrame(data)
df_records.loc[:,'HMean'] = (df_records.loc[:,'SR'] * df_records.loc[:,'Nov'] * df_records.loc[:,'Div']) ** 0.333

df_records

100%|██████████| 10/10 [00:08<00:00,  1.16it/s]


Unnamed: 0,SR,Nov,Div,Model,Ckpt,HMean
0,0.0012,1.0,0.687687,gsk3_jnk3_qed_sa,50,0.09402
1,0.485,0.882812,0.695527,gsk3_jnk3_qed_sa,100,0.668062
2,0.84,0.690909,0.674231,gsk3_jnk3_qed_sa,150,0.731653
3,0.941,0.557292,0.662038,gsk3_jnk3_qed_sa,200,0.70308
4,0.966,0.492228,0.647412,gsk3_jnk3_qed_sa,250,0.675479
5,0.9794,0.427466,0.639243,gsk3_jnk3_qed_sa,300,0.644713
6,0.987,0.403785,0.632903,gsk3_jnk3_qed_sa,350,0.632122
7,0.9778,0.350972,0.628829,gsk3_jnk3_qed_sa,400,0.600123
8,0.9802,0.320435,0.61851,gsk3_jnk3_qed_sa,450,0.579479
9,0.944,0.324324,0.624371,gsk3_jnk3_qed_sa,500,0.576374


## gsk3b_jnk3

In [6]:
frames = []

for modelname in ['gsk3_jnk3']:
    
    for filename in filenames:
        
        filepath = os.path.join(input_dir, modelname, filename)
        
        df = pd.read_csv(filepath)
        
        df = df.loc[:,('smiles', 'gsk3b', 'jnk3', 'qed', 'sa')]
        df.loc[:,'model'] = modelname
        df.loc[:,'checkpoint'] = filename.split('.')[-1]
        
        frames.append(df)
        
        
data = []

for df in tqdm.tqdm(frames):

    s_sr, s_nov, s_div = evaluate_sr_nov_div(df, referece_smiles_iter, 'gsk3b_jnk3')
    
    data.append({'SR':s_sr, 'Nov':s_nov, 'Div':s_div, 'Model':df.loc[0,'model'], 'Ckpt':df.loc[0,'checkpoint']})
    

df_records = pd.DataFrame(data)
df_records.loc[:,'HMean'] = (df_records.loc[:,'SR'] * df_records.loc[:,'Nov'] * df_records.loc[:,'Div']) ** 0.333

df_records

100%|██████████| 10/10 [00:31<00:00,  3.20s/it]


Unnamed: 0,SR,Nov,Div,Model,Ckpt,HMean
0,0.0028,1.0,0.817049,gsk3_jnk3,50,0.132033
1,0.573,0.870923,0.701732,gsk3_jnk3,100,0.705105
2,0.9166,0.655214,0.679703,gsk3_jnk3,150,0.742034
3,0.9756,0.338715,0.66854,gsk3_jnk3,200,0.604825
4,0.9884,0.225213,0.657288,gsk3_jnk3,250,0.527277
5,0.99,0.224155,0.653336,gsk3_jnk3,300,0.525677
6,0.9798,0.381492,0.632205,gsk3_jnk3,350,0.618542
7,0.9776,0.519804,0.623299,gsk3_jnk3,400,0.681918
8,0.9512,0.70166,0.683954,gsk3_jnk3,450,0.770173
9,0.9326,0.678707,0.689564,gsk3_jnk3,500,0.75876


## gsk3b

In [7]:
frames = []

for modelname in ['gsk3']:
    
    for filename in filenames:
        
        filepath = os.path.join(input_dir, modelname, filename)
        
        df = pd.read_csv(filepath)
        
        df = df.loc[:,('smiles', 'gsk3b', 'jnk3', 'qed', 'sa')]
        df.loc[:,'model'] = modelname
        df.loc[:,'checkpoint'] = filename.split('.')[-1]
        
        frames.append(df)
        
        
data = []

for df in tqdm.tqdm(frames):

    s_sr, s_nov, s_div = evaluate_sr_nov_div(df, referece_smiles_iter, 'gsk3b')
    
    data.append({'SR':s_sr, 'Nov':s_nov, 'Div':s_div, 'Model':df.loc[0,'model'], 'Ckpt':df.loc[0,'checkpoint']})
    

df_records = pd.DataFrame(data)
df_records.loc[:,'HMean'] = (df_records.loc[:,'SR'] * df_records.loc[:,'Nov'] * df_records.loc[:,'Div']) ** 0.333

df_records

100%|██████████| 10/10 [00:02<00:00,  3.55it/s]


Unnamed: 0,SR,Nov,Div,Model,Ckpt,HMean
0,0.059,1.0,0.836114,gsk3,50,0.36712
1,0.912,0.989967,0.765699,gsk3,100,0.884324
2,0.997,0.992084,0.670578,gsk3,150,0.872216
3,0.998,0.991848,0.654743,gsk3,200,0.865522
4,1.0,0.995,0.644544,gsk3,250,0.862494
5,1.0,0.997512,0.636851,gsk3,300,0.859774
6,1.0,0.995025,0.63274,gsk3,350,0.857209
7,0.992,0.987539,0.657482,gsk3,400,0.863734
8,0.995,1.0,0.676967,gsk3,450,0.876706
9,0.974,1.0,0.715845,gsk3,500,0.886839


## jnk3

In [8]:
frames = []

for modelname in ['jnk3']:
    
    for filename in filenames:
        
        filepath = os.path.join(input_dir, modelname, filename)
        
        df = pd.read_csv(filepath)
        
        df = df.loc[:,('smiles', 'gsk3b', 'jnk3', 'qed', 'sa')]
        df.loc[:,'model'] = modelname
        df.loc[:,'checkpoint'] = filename.split('.')[-1]
        
        frames.append(df)
        
        
data = []

for df in tqdm.tqdm(frames):

    s_sr, s_nov, s_div = evaluate_sr_nov_div(df, referece_smiles_iter, 'jnk3')
    
    data.append({'SR':s_sr, 'Nov':s_nov, 'Div':s_div, 'Model':df.loc[0,'model'], 'Ckpt':df.loc[0,'checkpoint']})
    

df_records = pd.DataFrame(data)
df_records.loc[:,'HMean'] = (df_records.loc[:,'SR'] * df_records.loc[:,'Nov'] * df_records.loc[:,'Div']) ** 0.333

df_records

100%|██████████| 10/10 [00:08<00:00,  1.14it/s]


Unnamed: 0,SR,Nov,Div,Model,Ckpt,HMean
0,0.029,0.931034,0.79585,jnk3,50,0.278369
1,0.836,0.544156,0.661279,jnk3,100,0.670314
2,0.979,0.395377,0.614073,jnk3,150,0.619745
3,0.995,0.396005,0.603135,jnk3,200,0.619708
4,0.995,0.652838,0.64898,jnk3,250,0.75003
5,0.997,0.756072,0.66799,jnk3,300,0.795748
6,0.99,0.890481,0.722655,jnk3,350,0.860588
7,0.975,0.845511,0.719256,jnk3,400,0.840256
8,0.958,0.896008,0.757476,jnk3,450,0.866452
9,0.932,0.873224,0.754817,jnk3,500,0.850221
