This notebook reuses the code from *generate_RMSD_graphs.py* to build the top-N RMSD success tables (saved as CSV files) used by the graph-making code.

In [40]:
import argparse
from datetime import datetime
from glob import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import matplotlib
import os.path
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
import re

matplotlib.rcParams.update({"figure.facecolor": "white"})

In [19]:
def _top_n_success(df_tag, sys_col, max_N):
    """Return a Series indexed 1..max_N with the % of systems having a good pose among their first N poses."""
    ranked = df_tag.dropna(subset=[sys_col])
    # cumcount numbers each system's poses 0, 1, 2, ... in file order, i.e. by pose rank.
    # (Unlike GroupBy.nth, this does not depend on pandas-version-specific index behaviour.)
    ranked = ranked.assign(_pose_rank=ranked.groupby(sys_col).cumcount())
    ranked = ranked[ranked['_pose_rank'] < max_N]
    # One row per system, one column per pose rank; systems with fewer than max_N poses
    # are padded with False (a missing pose cannot be good).
    good_by_rank = (
        ranked.set_index([sys_col, '_pose_rank'])['good']
        .unstack(fill_value=False)
        .reindex(columns=range(max_N), fill_value=False)
        .astype(bool)
    )
    # A system counts as a top-N success once any of its first N poses is good.
    hit_within_n = good_by_rank.cumsum(axis=1) > 0
    success = hit_within_n.mean(axis=0) * 100
    success.index = success.index + 1
    return success


def getPlottingDataFrame(path, col_names, header, delim, usecols, key, unique, exclusive_tag=None, rmsd_good=2, max_N=9):
    """Build the top-N pose-success table used by the graph-making code.

    Row ``N`` (1..max_N) of the result holds, for each tag (column), the percentage of
    receptor-ligand systems for which at least one of the first ``N`` poses (in file
    order) has ``rmsd < rmsd_good``.

    Parameters
    ----------
    path : str or file-like
        Pose CSV, passed to ``pd.read_csv``.
    col_names : list of str
        Names given to the columns selected by ``usecols``; must contain ``'rmsd'`` and
        both entries of ``key``.
    header, delim, usecols :
        Forwarded to ``pd.read_csv`` as ``header``, ``sep`` and ``usecols``.
    key : sequence of two str
        ``key[0]`` identifies a receptor-ligand system, ``key[1]`` the experiment tag.
    unique : int
        Number of distinct systems every processed tag must contain (checked by assert).
    exclusive_tag : str, optional
        If given, only this tag is processed and returned.
    rmsd_good : float
        A pose is good when its RMSD is strictly below this value.
    max_N : int
        Largest number of top poses considered.

    Returns
    -------
    pd.DataFrame
        Index ``1..max_N``; one column per processed tag, in sorted tag order.
    """
    sys_col, tag_col = key[0], key[1]
    poses = pd.read_csv(path, header=header, sep=delim, usecols=usecols)
    poses.columns = col_names
    poses['good'] = poses['rmsd'] < rmsd_good

    tags = sorted(poses[tag_col].dropna().unique())
    if exclusive_tag is not None:
        # Only the requested tag becomes a column (no all-NaN columns for the others).
        tags = [tag for tag in tags if tag == exclusive_tag]

    success_by_tag = {}
    for tag in tags:
        print(f"{tag}")
        df_tag = poses[poses[tag_col] == tag]
        n_systems = df_tag[sys_col].nunique()
        assert n_systems == unique, f"{tag}: Doesn't have the right number of systems, should have {unique}, but has {n_systems}"
        success_by_tag[tag] = _top_n_success(df_tag, sys_col, max_N)

    return pd.DataFrame(success_by_tag, index=list(range(1, max_N + 1)))

In [9]:
def filter_csv(subset_file, remove_files=('2017_general.INDEX', 'Crossdock2020_Lig.txt', 'Crossdock2020_Prot.txt'), new_suffix="no2017_noCD2020"):
    """Write a copy of ``subset_file`` without rows listed in any exclusion file.

    A row is dropped when the last path component of its ``rec`` value, or its ``lig``
    value, matches a line of one of ``remove_files`` (case-insensitive, surrounding
    whitespace ignored). The result is written next to the input as
    ``<input stem>_<new_suffix>.csv``.

    Parameters
    ----------
    subset_file : str
        Comma-separated pose CSV with at least ``rec`` and ``lig`` columns.
    remove_files : iterable of str
        Text files with one identifier per line.
    new_suffix : str
        Suffix appended to the input file stem for the output file.

    Returns
    -------
    str
        Path of the filtered CSV that was written.
    """
    subset_csv = pd.read_csv(subset_file, sep=',')

    # Collect every excluded id once instead of re-filtering the frame per file.
    remove_ids = set()
    for filename in remove_files:
        with open(filename) as remove_file:
            remove_ids.update(line.strip().upper() for line in remove_file)
    remove_ids.discard('')  # blank lines must not match anything

    # The exclusion lists are upper-cased, so compare against upper-cased values too.
    pdbid = subset_csv['rec'].str.split('/').str[-1].str.upper()
    lig = subset_csv['lig'].astype(str).str.upper()
    filtered = subset_csv[~(pdbid.isin(remove_ids) | lig.isin(remove_ids))]

    # splitext (not split('.')) so dots in directory names do not truncate the path.
    subset_name = f"{os.path.splitext(subset_file)[0]}_{new_suffix}.csv"
    filtered.to_csv(subset_name, sep=',', index=False)

    return subset_name

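# Make DataFrames for graphs
Each DataFrame has one row per number of top-ranked poses considered (1–9) and one column per experiment (tag). Each entry is the percentage of receptor–ligand systems with at least one pose of RMSD < 2 (the `rmsd_good` default) among those top poses.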

In [10]:
# Column layout shared by every pose CSV parsed below: the selected columns are renamed to
# col_names, and key = [system column, experiment-tag column] for getPlottingDataFrame.
col_names = ['tag', 'rmsd', 'rec']
key = [col_names[2], col_names[0]]
header, delimiter = 'infer', ','
num_unique = 4260  # receptor-ligand systems every tag is expected to cover

### Paths and names

In [5]:
# List the pipeline's output CSVs (oldest first, human-readable sizes) to see which result files exist.
%ls -lthr /home/anm329/Docking/cnn_gnina/gnina_out/new_pipeline/

total 1.3G
-rw-r--r-- 1 anm329 anm329  16M Oct  1 13:11 cd_othmodels.csv
-rw-r--r-- 1 anm329 anm329  16M Oct  1 13:11 dense_othmodels.csv
-rw-r--r-- 1 anm329 anm329  16M Oct  1 13:11 gend_othmodels.csv
-rw-rw-r-- 1 anm329 anm329  21M Oct  1 18:04 cd_allmodels.csv
-rw-rw-r-- 1 anm329 anm329  21M Oct  5 10:26 dense_allmodels.csv
-rw-r--r-- 1 anm329 anm329 3.9M Oct  5 10:37 redock_ensemble.csv
-rw-r--r-- 1 anm329 anm329  20M Oct  5 10:37 redock_models.csv
-rw-rw-r-- 1 anm329 anm329  20M Oct  5 12:01 gend_allmodels.csv
-rw-r--r-- 1 anm329 anm329  24M Oct  6 11:01 dense4_cdmodels.csv
-rw-r--r-- 1 anm329 anm329  17M Oct  6 11:01 dense4_densemodels.csv
-rw-r--r-- 1 anm329 anm329  24M Oct  6 11:01 dense4_gendmodels.csv
-rw-r--r-- 1 anm329 anm329  24M Oct  6 11:01 dense4_redockmodels.csv
-rw-r--r-- 1 anm329 anm329 4.4M Oct  6 11:02 new_allens.csv
-rw-rw-r-- 1 anm329 anm329 3.9M Oct  6 12:18 redock_ensemble_comb.csv
-rw-rw-r-- 1 anm329 anm329  17M Oct  6 12:22 all_ensembles.csv
-

In [11]:
# Input directory (pipeline CSVs) and output directory (summary tables).
# Each can be overridden with an environment variable instead of editing this cell.
# os.path.join(..., '') guarantees a trailing separator, because later cells build
# file paths by plain f-string concatenation (f'{basepath}name.csv').
basepath = os.path.join(os.environ.get('GNINA_PIPELINE_DIR', '/home/anm329/Docking/cnn_gnina/gnina_out/new_pipeline/'), '')
datapath = os.path.join(os.environ.get('GNINA_REDOCKING_DATA_DIR', '/home/anm329/GNINA1.0/data/redocking/'), '')

# Rescore vs Refine (with Vina)

In [24]:
figure_name = 'rescore_vs_refine'
# (csv path, tag to keep or None for every tag) for each curve in the figure.
files = [
    (f'{basepath}final_def_ensemble_refine_defaults.csv', 'default_ensemble_refinement_defaults'),
    (f'{basepath}nocnn_11_2.csv', None),  # Vina file
    (f'{basepath}final_def_ensemble_rescore_defaults.csv', 'default_ensemble_rescore_defaults'),
]
# Column labels in the same order as `files`; each file is expected to contribute one column.
names = ['Default Ensemble Refine', 'Vina', 'Default Ensemble Rescore']
big_df = pd.DataFrame(index=list(range(1, 10)))
for file, exclusive in files:
    # Vina-only ('nocnn') CSVs have a different column layout from the CNN-scored ones.
    use_cols = [0, 2, 4] if 'nocnn' in file else [0, 2, 7]
    plot_df = getPlottingDataFrame(file, col_names, header, delimiter, use_cols, key, num_unique, exclusive_tag=exclusive)
    big_df = big_df.join(plot_df)
big_df.columns = names

default_ensemble_refinement_defaults
none_defaults
default_ensemble_rescore_defaults


In [26]:
# Keep the three reference curves for the comparison tables built below.
# To also save this table: big_df.to_csv(f'{datapath}{figure_name}.csv')
final_vina, final_defens, final_defens_ref = (
    big_df[column]
    for column in ('Vina', 'Default Ensemble Rescore', 'Default Ensemble Refine')
)

# Single Model CSV

In [29]:
figure_name = 'rescore_single_models'
files = [(f'{basepath}single_models_11_2.csv', None)]
big_df = pd.DataFrame(index=list(range(1, 10)))
for file, exclusive in files:
    # Vina-only ('nocnn') CSVs are read from columns 0, 2, 4; CNN-scored ones from 0, 2, 7.
    use_cols = [0, 2, 4] if 'nocnn' in file else [0, 2, 7]
    plot_df = getPlottingDataFrame(
        file, col_names, header, delimiter, use_cols, key, num_unique,
        exclusive_tag=exclusive,
    )
    big_df = big_df.join(plot_df)

crossdock_default2018_0_rescore_defaults
default2017_rescore_defaults
dense_0_rescore_defaults
general_default2018_0_rescore_defaults
redock_default2018_0_rescore_defaults


In [30]:
# Readable labels for the single-model tags, in the sorted tag order printed above.
names = [
    'Crossdock Default2018',
    'Default2017',
    'Crossdock Dense',
    'General Default2018',
    'Redock Default2018',
]
big_df.columns = names
# Append the Vina and Default Ensemble reference curves for comparison.
for label, curve in (('Vina', final_vina), ('Default Ensemble', final_defens)):
    big_df[label] = curve

In [9]:
# Save the single-model comparison table.
big_df.to_csv(datapath + figure_name + '.csv')

# Ensemble CSV

In [32]:
files = [(f'{basepath}ensembles_11_18.csv', None)]
figure_name = 'rescore_ensembles'
big_df = pd.DataFrame(index=list(range(1, 10)))
for file, exclusive in files:
    # Vina-only ('nocnn') CSVs are read from columns 0, 2, 4; CNN-scored ones from 0, 2, 7.
    use_cols = [0, 2, 4] if 'nocnn' in file else [0, 2, 7]
    plot_df = getPlottingDataFrame(
        file, col_names, header, delimiter, use_cols, key, num_unique,
        exclusive_tag=exclusive,
    )
    big_df = big_df.join(plot_df)

all_ensemble_rescore_defaults
crossdock_default2018_01234_rescore_defaults
default_ensemble_rescore_defaults
dense_01234_rescore_defaults
general_default2018_01234_rescore_defaults
redock_default2018_01234_rescore_defaults


In [33]:
# Readable labels for the ensemble tags, in the sorted tag order printed above.
names = [
    'All Ensemble',
    'Crossdock Default2018 Ensemble',
    'Default Ensemble',
    'Crossdock Dense Ensemble',
    'General Default2018 Ensemble',
    'Redock Default2018 Ensemble',
]
big_df.columns = names
big_df['Vina'] = final_vina  # Vina reference curve

In [9]:
# Save the ensemble comparison table.
big_df.to_csv(datapath + figure_name + '.csv')

### Excluding PDBbind2017 (general set) and CrossDock2020 systems

In [35]:
figure_name = 'ensemble_models_no2017_nocd2020'
# Number of receptor-ligand systems left after filter_csv removes the PDBbind2017-general
# and CrossDock2020 entries (replaces the bare 441 previously passed inline).
num_unique_filtered = 441
files = [
    (f'{basepath}ensembles_11_18.csv', None),
    (f'{basepath}nocnn_11_2.csv', None),  # Vina file
]
big_df = pd.DataFrame(index=list(range(1, 10)))
for file, exclusive in files:
    # Keep the loop variable intact; the filtered copy gets its own name.
    filtered_file = filter_csv(file)
    # Vina-only ('nocnn') CSVs have a different column layout from the CNN-scored ones.
    use_cols = [0, 2, 4] if 'nocnn' in filtered_file else [0, 2, 7]
    plot_df = getPlottingDataFrame(filtered_file, col_names, header, delimiter, use_cols, key, num_unique_filtered, exclusive_tag=exclusive)
    big_df = big_df.join(plot_df)

all_ensemble_rescore_defaults
crossdock_default2018_01234_rescore_defaults
default_ensemble_rescore_defaults
dense_01234_rescore_defaults
general_default2018_01234_rescore_defaults
redock_default2018_01234_rescore_defaults
none_defaults


In [36]:
# Readable labels: the six ensemble tags (sorted order printed above), then the Vina file.
names = [
    'All Ensemble',
    'Crossdock Default2018 Ensemble',
    'Default Ensemble',
    'Crossdock Dense Ensemble',
    'General Default2018 Ensemble',
    'Redock Default2018 Ensemble',
    'Vina',
]
big_df.columns = names

In [16]:
# Save the filtered-subset comparison table.
big_df.to_csv(datapath + figure_name + '.csv')

Unnamed: 0,All Ensemble,Crossdock Default2018 Ensemble,Default Ensemble,Crossdock Dense Ensemble,General Default2018 Ensemble,Redock Default2018 Ensemble,Vina
1,70.068,60.9977,68.0272,66.6667,63.4921,68.4807,57.1429
2,81.8594,73.6961,81.8594,79.3651,74.1497,79.1383,67.8005
3,84.5805,80.0454,85.2608,82.9932,81.6327,82.5397,73.0159
4,86.6213,83.9002,88.2086,84.8073,85.2608,84.8073,76.8707
5,87.0748,86.3946,89.3424,86.6213,86.8481,85.4875,79.5918
6,87.5283,87.9819,90.0227,87.5283,88.6621,86.6213,81.4059
7,88.4354,89.3424,90.4762,88.2086,89.5692,87.0748,83.4467
8,88.8889,89.3424,90.7029,88.4354,90.0227,87.7551,83.9002
9,89.1156,89.7959,90.9297,88.6621,90.0227,87.9819,84.8073


# Parameter sweeps
Defined-pocket sweeps only; the exhaustiveness and whole-protein sweeps are handled in their own sections below.

In [46]:
# (substring in the file name, value used for tags without a trailing number, output suffix).
# Checked in order and the first match wins; a suffix of None means "skip this file"
# (exhaustiveness sweeps have their own section).
SWEEP_TYPES = [
    ('cnnrot', 0, 'cnnrot'),
    ('exh', None, None),
    ('rmsdf', 1.0, 'rmsdf'),
    ('aadd', 4, 'autobox_add'),
    ('nmodes', 9, 'num_modes'),
    ('mcsaved', 50, 'mcsaved'),
]
SAVE_SWEEPS = False  # set True to write each sweep table to datapath
sweep_tables = {}  # outname -> relabelled top-N table, kept after the loop
for file in glob(f'{basepath}*sweep*.csv'):
    match = next(((default, suffix) for pattern, default, suffix in SWEEP_TYPES if pattern in file), None)
    if match is None:
        # Previously such files silently reused the last iteration's default/outname.
        print(f'{file}: unrecognised sweep type, skipping')
        continue
    default, suffix = match
    if suffix is None:
        continue
    outname = f'sweep_{suffix}'
    use_cols = [0, 2, 7]
    print(f'{file}:{outname}')
    plot_df = getPlottingDataFrame(file, col_names, header, delimiter, use_cols, key, num_unique)
    # Tags end with the swept value (e.g. 'min_rmsd_filter0.5'); tags without one used the default.
    new_columns = []
    for column in plot_df.columns:
        trailing = re.search(r'[0-9\.]+$', column)
        new_columns.append(trailing.group(0) if trailing else default)
    plot_df.columns = [f'DefE_{val}' for val in new_columns]
    sweep_tables[outname] = plot_df
    if SAVE_SWEEPS:
        plot_df.to_csv(f'{datapath}{outname}.csv')
    print()

/home/anm329/Docking/cnn_gnina/gnina_out/new_pipeline/cnnrotsweep_11_2.csv:sweep_cnnrot
default_ensemble_rescore_cnn_rotation1
default_ensemble_rescore_cnn_rotation10
default_ensemble_rescore_cnn_rotation20
default_ensemble_rescore_cnn_rotation5
default_ensemble_rescore_defaults

/home/anm329/Docking/cnn_gnina/gnina_out/new_pipeline/rmsdfsweep_11_2.csv:sweep_rmsdf
defaults
min_rmsd_filter0.5
min_rmsd_filter1.5

/home/anm329/Docking/cnn_gnina/gnina_out/new_pipeline/aaddsweep_11_2.csv:sweep_autobox_add
default_ensemble_rescore_autobox_add2
default_ensemble_rescore_autobox_add6
default_ensemble_rescore_autobox_add8
default_ensemble_rescore_defaults

/home/anm329/Docking/cnn_gnina/gnina_out/new_pipeline/nmodessweep_11_2.csv:sweep_num_modes
default_ensemble_rescore_defaults
default_ensemble_rescore_num_modes100

/home/anm329/Docking/cnn_gnina/gnina_out/new_pipeline/mcsavedsweep_11_2.csv:sweep_mcsaved
default_ensemble_rescore_defaults
default_ensemble_rescore_num_mc_saved100
default_ensemble

### Exhaustiveness sweep (defined pocket)

In [53]:
big_df = pd.DataFrame(index=list(range(1, 10)))
for file in glob(f'{basepath}*exh*sweep*.csv'):
    # Files containing 'wp' are skipped (presumably whole-protein runs).
    if 'wp' in file:
        continue
    outname = 'sweep_exhaustiveness'
    default = 8  # exhaustiveness assumed for tags without a trailing number (e.g. '_defaults')
    use_cols, base_col = ([0, 2, 4], 'Vina_') if 'vina' in file else ([0, 2, 7], 'DefE_')
    print(f'{file}:{base_col}')
    plot_df = getPlottingDataFrame(file, col_names, header, delimiter, use_cols, key, num_unique)
    # Relabel each tag by its trailing exhaustiveness value.
    new_columns = [(re.findall(r'[0-9\.]+$', column) or [default])[0] for column in plot_df.columns]
    plot_df.columns = [f'{base_col}{val}' for val in new_columns]
    big_df = big_df.join(plot_df)

/home/anm329/Docking/cnn_gnina/gnina_out/new_pipeline/final_def_ensemble_exhaustivess_sweep.csv:DefE_
_defaults
_exhaustiveness16
_exhaustiveness4
/home/anm329/Docking/cnn_gnina/gnina_out/new_pipeline/vina_exh_sweep.csv:Vina_
vina_scoring_exhaustiveness16
vina_scoring_exhaustiveness4


In [54]:
# Exhaustiveness 8 is the default Vina run. If the sweep CSV did not include it, reuse the
# top-N curve of the default Vina run computed earlier; never overwrite a swept column.
if 'Vina_8' not in big_df.columns:
    big_df['Vina_8'] = final_vina

In [55]:
# `out_name` was never defined (NameError on a fresh kernel); the table name is fixed for this section.
big_df.to_csv(f'{datapath}sweep_exhaustiveness.csv')

Unnamed: 0,DefE_8,DefE_16,DefE_4,Vina_16,Vina_4,Vina_8
1,72.6291,73.4977,70.493,58.0516,57.6056,57.6291
2,81.7371,81.9249,79.1784,67.6056,68.3333,68.1455
3,85.1878,85.4225,82.6526,73.2864,73.6385,73.7089
4,86.9014,87.4413,84.2723,76.9249,77.4178,77.4413
5,87.9343,88.8498,85.0939,79.3897,79.6009,79.8357
6,88.4977,89.6479,85.6338,81.4554,81.1268,81.9249
7,89.1315,90.2817,86.0094,82.6526,82.3005,83.3568
8,89.507,90.7746,86.4085,83.9202,83.2864,84.4601
9,89.9061,91.2207,86.6197,85.1174,83.9671,85.446


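## Thresholding on the top pose's CNNscore
For each CNNscore cutoff: "Good" is the percentage of remaining top-ranked poses with RMSD < 2, and "Left" is the percentage of systems whose top pose passes the cutoff.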

In [69]:
file = f'{basepath}ensembles_11_18.csv'
gnina_vals = pd.read_csv(file, sep=',')
grouped = gnina_vals.groupby('tag')
thresholds = np.linspace(0, 0.99, 100)
rmsd_cutoff = 2
final_dataframe = pd.DataFrame(index=list(thresholds))
# Tag substring -> column label, checked in order (first match wins, so 'all' comes first).
TAG_LABELS = [
    ('all', 'All Ensemble'),
    ('crossdock_default2018', 'Crossdock Default2018 Ensemble'),
    ('general_default2018', 'General Default2018 Ensemble'),
    ('redock_default2018', 'Redock Default2018 Ensemble'),
    ('dense', 'Crossdock Dense Ensemble'),
    ('default_ensemble', 'Default Ensemble'),
]
for name, group in grouped:
    col = next((label for pattern, label in TAG_LABELS if pattern in name), None)
    if col is None:
        # Previously an unmatched tag silently overwrote the previous tag's columns.
        print(f'{name}: no column label for this tag, skipping')
        continue
    # Top-ranked pose of every system; computed once rather than twice per threshold.
    top_poses = group.groupby('rec').nth(0)
    perc_good = []
    frac_left = []
    for thresh in thresholds:
        kept = top_poses[top_poses['cnnscore'] >= thresh]
        # mean() yields NaN (not ZeroDivisionError) when no top pose passes the threshold.
        perc_good.append((kept['rmsd'] < rmsd_cutoff).mean() * 100)
        frac_left.append(len(kept) / num_unique * 100)
    final_dataframe[f"{col} Good"] = perc_good
    final_dataframe[f"{col} Left"] = frac_left

Index(['tag', 'molids', 'rmsd', 'cnnscore', 'cnnaffinity', 'minimizedAffinity',
       'pocket', 'rec', 'lig'],
      dtype='object')


In [70]:
# Only need to run this if the default ensemble is not included in the ensemble csv
# assumes that the file only contains the default run of the Default Ensemble
file = f'{basepath}final_def_ensemble_rescore_defaults.csv'
gnina_vals = pd.read_csv(file, sep=',', header=None)
gnina_vals.columns = ['tag', 'molids', 'rmsd', 'cnnscore', 'cnnaffinity', 'minimizedAffinity',
                      'pocket', 'rec', 'lig']
col = 'Default Ensemble'
rmsd_cutoff = 2
# Top-ranked pose of every system; computed once rather than twice per threshold.
top_poses = gnina_vals.groupby('rec').nth(0)
perc_good = []
frac_left = []
for thresh in np.linspace(0, 0.99, 100):
    kept = top_poses[top_poses['cnnscore'] >= thresh]
    # mean() yields NaN (not ZeroDivisionError) when no top pose passes the threshold.
    perc_good.append((kept['rmsd'] < rmsd_cutoff).mean() * 100)
    frac_left.append(len(kept) / num_unique * 100)
final_dataframe[f"{col} Good"] = perc_good
final_dataframe[f"{col} Left"] = frac_left

In [71]:
# Save the CNNscore-threshold table (Good / Left columns per ensemble).
final_dataframe.to_csv(datapath + 'thresh_cnnscore_ensembles.csv')

Unnamed: 0,All Ensemble Good,All Ensemble Left,Crossdock Default2018 Ensemble Good,Crossdock Default2018 Ensemble Left,Default Ensemble Good,Default Ensemble Left,Crossdock Dense Ensemble Good,Crossdock Dense Ensemble Left,General Default2018 Ensemble Good,General Default2018 Ensemble Left,Redock Default2018 Ensemble Good,Redock Default2018 Ensemble Left
0.00,72.488263,100.000000,66.173709,100.000000,72.629108,100.000000,68.826291,100.000000,66.408451,100.000000,70.046948,100.000000
0.01,72.488263,100.000000,66.173709,100.000000,72.629108,100.000000,68.826291,100.000000,66.408451,100.000000,70.046948,100.000000
0.02,72.488263,100.000000,66.173709,100.000000,72.629108,100.000000,68.826291,100.000000,66.408451,100.000000,70.046948,100.000000
0.03,72.488263,100.000000,66.173709,100.000000,72.629108,100.000000,68.826291,100.000000,66.408451,100.000000,70.046948,100.000000
0.04,72.488263,100.000000,66.173709,100.000000,72.629108,100.000000,68.826291,100.000000,66.408451,100.000000,70.046948,100.000000
...,...,...,...,...,...,...,...,...,...,...,...,...
0.95,88.603531,43.873239,84.006462,43.591549,88.140417,49.483568,78.803681,76.525822,81.397638,47.699531,85.283019,49.765258
0.96,89.628681,36.666667,85.069009,37.417840,89.900662,42.535211,80.386379,71.690141,82.375906,42.089202,85.942492,44.084507
0.97,90.724166,28.849765,86.960000,29.342723,90.958532,34.530516,82.285714,65.727700,84.475806,34.929577,87.365761,37.159624
0.98,92.441140,18.943662,90.621336,20.023474,92.263610,24.577465,85.038363,55.070423,86.486486,26.056338,89.058524,27.676056


# Whole protein
Whole-protein docking with Vina and the Default Ensemble at exhaustiveness 8, 16, 32 and 64 (one CSV per method).

In [84]:
big_df = pd.DataFrame(index=list(range(1, 10)))
for file in glob(f'{basepath}*fullprot*.csv'):
    is_vina = 'nocnn' in file
    # Only the Vina ('nocnn') and Default Ensemble exhaustiveness CSVs belong in this table.
    if not (is_vina or 'exhaustiveness' in file):
        continue
    use_cols, base_col = ([0, 2, 4], 'Vina_') if is_vina else ([0, 2, 7], 'DefE_')
    default = 8  # exhaustiveness assumed for tags without a trailing number (e.g. 'defaults')
    print(f'{file}:{base_col}')
    plot_df = getPlottingDataFrame(file, col_names, header, delimiter, use_cols, key, num_unique)
    new_columns = []
    for column in plot_df.columns:
        trailing = re.findall(r'[0-9\.]+$', column)
        new_columns.append(trailing[0] if trailing else default)
    plot_df.columns = [f'{base_col}{val}' for val in new_columns]
    big_df = big_df.join(plot_df)

/home/anm329/Docking/cnn_gnina/gnina_out/new_pipeline/fullprot_nocnn_11_2.csv:Vina_
none_exhaustiveness16
none_exhaustiveness32
none_exhaustiveness64
none_exhaustiveness8
/home/anm329/Docking/cnn_gnina/gnina_out/new_pipeline/fullprot_exhaustiveness_11_17.csv:DefE_
defaults
exhaustiveness16
exhaustiveness32
exhaustiveness64


In [85]:
# Save the whole-protein exhaustiveness table built in the previous cell.
# (This previously wrote `final_dataframe`, i.e. the CNNscore-threshold table, under this name.)
big_df.to_csv(f'{datapath}whole_ptn_sweep_exhaustiveness.csv')

Unnamed: 0,Vina_16,Vina_32,Vina_64,Vina_8,DefE_8,DefE_16,DefE_32,DefE_64
1,35.1174,37.3474,38.5446,31.0329,38.0282,43.4038,47.4883,50.1408
2,43.2864,47.6761,50.3052,37.9577,44.554,51.6432,57.5822,61.385
3,47.77,52.8169,55.5869,41.4085,46.4789,54.6009,61.7371,65.7746
4,50.3521,56.1502,59.4131,43.0047,47.2066,55.9155,63.615,67.8169
5,51.9953,58.216,61.5962,43.9906,47.6995,56.6667,64.7653,68.7324
6,52.7465,59.3897,63.4038,44.6009,47.8873,57.0423,65.6103,69.4366
7,53.6854,60.3991,64.4366,45.0939,48.0282,57.2066,66.0329,69.7887
8,54.2958,61.0563,65.5634,45.7042,48.1925,57.5587,66.4085,70.1174
9,54.8357,61.8545,66.2911,46.1737,48.2629,57.6995,66.5962,70.3521
