In [None]:
import os
import math
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import scipy
import numpy as np
import ast
import statistics
from sklearn.cluster import KMeans
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from rdkit.Chem.Scaffolds import MurckoScaffold

from sbdd_bench.sbdd_analysis.constants import BLIND_SET_POCKET_IDS, BLIND_SET_PDBS_TO_POCKET_IDS, TASK2_FUNC_TO_PDBS, TASK3_FUNC_TO_PDBS, ALL_PDBS, DPOCKET_TYPES
from sbdd_bench.sbdd_analysis import combine_outputs
from sbdd_bench.sbdd_analysis.task_metrics import task2Metrics, task3Metrics, task4Metrics
from sbdd_bench.sbdd_analysis.eval_metrics import recreatedPLIP

import warnings
import json
import random
from rdkit import RDLogger                        

# Turn off every RDKit log channel matching 'rdApp.*'.
RDLogger.DisableLog('rdApp.*') 
# Hide pandas PerformanceWarning messages.
warnings.simplefilter(action='ignore', category=pd.errors.PerformanceWarning)
# NOTE(review): this filter has no category, so it ignores *all* warnings and makes the
# PerformanceWarning filter above redundant — consider narrowing it so real problems stay visible.
warnings.filterwarnings('ignore')

## STEP 1: SUMMARY

In [2]:
# Get output task DataFrame - assess failure rate per-task
# Same as in run_analysis.py
def get_output_dfs_per_model(inference_out_dirs, inp_task_files, model_name):
    """
    Build the combined output DataFrame of every task for one model.

    Parameters
    ----------
        inference_out_dirs: Dictionary
            Inference output directory per task, keyed by task number string,
            e.g. {"1": Path, "2": Path, "3": Path}
        inp_task_files: Dictionary
            Input task file per task, keyed the same way as inference_out_dirs;
            its keys decide which tasks are processed
        model_name: string
            Pocket2Mol, LigBuilderV3, DiffSBDD, or AutoGrow4

    Returns
    -------
        List of output DataFrames, ordered by ascending numeric task number
    """
    # Keys are numeric strings, so order them by their integer value
    ordered_tasks = sorted(inp_task_files, key=int)

    return [
        combine_outputs.getOutputDf(
            output_dir=inference_out_dirs[task],
            inp_task_file=inp_task_files[task],
            model_name=model_name,
        ).output_df
        for task in ordered_tasks
    ]
    
def get_task_summ_tables(task_dict, task_output_df, num_mols_to_sample=10000):
    """
    Summarise, per task entry, which PDBs ran and how many molecules were generated.

    A PDB counts as failed when at least one of its "mol_pred" values contains the
    text "Errno"; every "mol_pred" value without "Errno" counts as a generated molecule.

    Parameters
    ----------
        task_dict: Dictionary
            Maps a pocket/function identifier to a comma-separated string of PDB IDs
            (optionally single-quoted). The special value "not 7gax" is replaced by a
            fixed list of PDB IDs.
        task_output_df: Pandas DataFrame
            Output DataFrame for a particular task run; needs "PDB ID" and "mol_pred" columns
        num_mols_to_sample: int
            Number of samples asked for from generator, default 10000

    Returns
    -------
        summ_table: Pandas DataFrame with one row per identifier. The last column holds
        generated / (num_mols_to_sample * number of successful PDBs) as a ratio (not
        multiplied by 100), or the raw generated count when no PDB was successful.
    """
    # PDB IDs substituted for the "not 7gax" placeholder entry
    not_7gax_pdbs = [
        '4zh8', '7lay', '4yq3', '4l7s', '3s92', '4yjp', '4yq0', '4yqs', '4yqn', '4yqr',
        '4yqo', '4yju', '4y79', '4yq1', '6ddi', '4ypy', '4yjv', '4yqi', '4y76', '4yq5',
        '4y6d', '4yjo', '4y7b', '4yqg', '1fbz', '1zyu', '4yqt', '7mrd', '4yqp', '4yqc',
        '4y71', '3lxn', '4yqj', '4yq9', '4zha', '5wo4', '4yqa', '4y7a', '4yjr', '4bq3',
        '4yqb', '7usg', '4yqk', '4af3', '4yq7', '4yq8', '7q6h', '5ig6', '4ypz', '4yql',
        '4yq2', '4yjt', '4yq4', '4yqd', '7l99', '4yq6', '4yjq', '4yjs', '4ypx', '7q7k',
        '5d9f', '4ypw', '4yqq',
    ]
    summary_columns = [
        "ID",
        "Successful PDBs",
        "Failed PDBs",
        "Number of generated molecules",
        "Average percentage of generated samples per PDB",
    ]

    summary_rows = []

    for entry_id, pdb_field in task_dict.items():
        pdb_ids = [token.strip().strip("'") for token in pdb_field.split(",")]
        if pdb_ids[0] == "not 7gax":
            pdb_ids = not_7gax_pdbs

        ok_pdbs, bad_pdbs = [], []
        total_generated = 0

        for pdb_id in pdb_ids:
            pdb_rows = task_output_df[task_output_df["PDB ID"] == pdb_id]
            is_error = ["Errno" in str(pred) for pred in pdb_rows["mol_pred"]]

            # One errored entry is enough to mark the whole PDB as failed
            if any(is_error):
                bad_pdbs.append(pdb_id)
            else:
                ok_pdbs.append(pdb_id)

            total_generated += is_error.count(False)

        if ok_pdbs:
            gen_ratio = total_generated / (num_mols_to_sample * len(ok_pdbs))
        else:
            gen_ratio = total_generated

        summary_rows.append([entry_id, ",".join(ok_pdbs), ",".join(bad_pdbs), total_generated, gen_ratio])

    return pd.DataFrame(summary_rows, columns=summary_columns)
    
def perc_passed_posebusters_moses(task_dir):
    """
    Get summary of tasks that passed PoseBusters and MOSES metrics

    Reads every "*.csv" file in task_dir whose name contains "aggregated", tags its
    rows with a task_id taken from the file name (the text between
    "aggregated_scores_" and ".csv"), and sorts the file into task 1, 2 or 3
    depending on whether its name contains "task1", "task2" or "task3".

    An empty file (nothing for pandas to parse) is kept as a single placeholder row
    with mol_cond and mol_true set to None, so the task_id still shows up. Any other
    read error is raised instead of being silently swallowed.

    Parameters
    ----------
        task_dir: str
            Path to output directory containing task assessments for a given method

    Returns
    -------
        task1_df, task1_summ_df, task2_df, task2_summ_df, task3_df, task3_summ_df:
        for each task the concatenated rows and the per-task_id mean of all numeric
        columns. Both are empty DataFrames when the task has no files.
    """

    def _combine(frames):
        # Concatenate once at the end rather than growing a frame inside the loop
        return pd.concat(frames) if frames else pd.DataFrame()

    def _summarise(task_df):
        # Mean of every numeric column per task_id
        if len(task_df) > 0:
            return task_df.groupby("task_id").mean(numeric_only=True).reset_index()
        return pd.DataFrame()

    # Sorted so the row order does not depend on the directory listing order
    task_output_files = sorted(
        os.path.join(task_dir, file_name)
        for file_name in os.listdir(task_dir)
        if file_name.endswith(".csv") and "aggregated" in file_name
    )

    # Insertion order doubles as match priority: task1, then task2, then task3
    frames_by_task = {"task1": [], "task2": [], "task3": []}

    for task_output_file in task_output_files:
        file_name = os.path.basename(task_output_file)
        task_id = file_name.split("aggregated_scores_")[-1].split(".csv")[0]

        try:
            task_df = pd.read_csv(task_output_file)
        except pd.errors.EmptyDataError:
            # Some files may be empty if that protein failed to run
            task_df = pd.DataFrame([[None, None]], columns=["mol_cond", "mol_true"])
        else:
            # Keep only the base names of mol_cond and mol_true
            task_df["mol_cond"] = task_df["mol_cond"].apply(os.path.basename)
            task_df["mol_true"] = task_df["mol_true"].apply(os.path.basename)

        task_df["task_id"] = task_id

        for task_key, frames in frames_by_task.items():
            if task_key in file_name:
                frames.append(task_df)
                break

    task1_df = _combine(frames_by_task["task1"])
    task2_df = _combine(frames_by_task["task2"])
    task3_df = _combine(frames_by_task["task3"])

    return (
        task1_df, _summarise(task1_df),
        task2_df, _summarise(task2_df),
        task3_df, _summarise(task3_df),
    )
    

In [None]:
# Run configuration: replace every "<...>" placeholder before executing this cell
inference_out_dirs = {"1": "<output_dir_to_task1>", "2": "<output_dir_to_task2>", "3": "<output_dir_to_task3>"}
inp_task_files = {"1": "<input_files_to_csv_task1>", "2": "<input_files_to_csv_task2>", "3": "<input_files_to_csv_task3>"}
model_name = "<model_name>"  # Pocket2Mol, DiffSBDD, AutoGrow4, LigBuilderV3
task_dir = "<dir_of_summary_aggregated_and_spec_scores>"

# Generation summary: one table per task, then the per-task mean of the numeric columns
task1_output_df, task2_output_df, task3_output_df = get_output_dfs_per_model(
    inference_out_dirs, inp_task_files, model_name
)

task1_summ_table = get_task_summ_tables(task_dict=BLIND_SET_POCKET_IDS, task_output_df=task1_output_df)
task2_summ_table = get_task_summ_tables(task_dict=TASK2_FUNC_TO_PDBS, task_output_df=task2_output_df)
task3_summ_table = get_task_summ_tables(task_dict=TASK3_FUNC_TO_PDBS, task_output_df=task3_output_df)

for _task_label, _summ_table in (("1", task1_summ_table), ("2", task2_summ_table), ("3", task3_summ_table)):
    _summ_table["Task"] = _task_label

task_table = pd.concat([task1_summ_table, task2_summ_table, task3_summ_table])

task_table_grouped = (
    task_table.groupby("Task")
    .mean(numeric_only=True)
    .reset_index()
    .rename(columns={'Number of generated molecules': 'Average number of generated molecules'})
)

# PoseBusters / MOSES summary: per-task_id means, averaged again per task
task1_df, task1_summ_df, task2_df, task2_summ_df, task3_df, task3_summ_df = perc_passed_posebusters_moses(
    task_dir=task_dir
)

for _task_label, _summ_df in (("1", task1_summ_df), ("2", task2_summ_df), ("3", task3_summ_df)):
    _summ_df["Task"] = _task_label

task_table_moses_pb = pd.concat([task1_summ_df, task2_summ_df, task3_summ_df])

task_table_moses_pb_grouped = task_table_moses_pb.groupby("Task").mean(numeric_only=True).reset_index()

# Columns containing "Task" go to both lists; of the rest, those containing "moses"
# go to the MOSES list and everything else to the PoseBusters list
all_cols = task_table_moses_pb_grouped.columns.tolist()
moses_cols = [name for name in all_cols if "Task" in name or "moses" in name]
pb_cols = [name for name in all_cols if "Task" in name or "moses" not in name]

In [None]:
# Split the grouped summary into its MOSES columns and its remaining (PoseBusters) columns
task_moses = task_table_moses_pb_grouped.loc[:, moses_cols]
task_pb = task_table_moses_pb_grouped.loc[:, pb_cols]


## STEP 2: FILTER AND STATS OF FILTER

In [None]:
# PAINS SMARTS table, read from the working directory; the file has no header row,
# so the two columns are named here
pains = pd.read_csv("wehi_pains.csv", names=["smarts", "names"])

def molgen_filters(cmpd_df, return_original=False):
    """
    Uses molgen filters (without hard and soft SMARTS) to get reasonable compounds

    For every row the first molecule RDKit can read from the "mol_pred" SDF file is
    checked against one allene pattern and three four-fused-ring patterns. Three
    columns are added:

    * "Molgen_extra_check": "Pass" when none of the patterns match, otherwise "Fail"
    * "SMILES": SMILES of the molecule
    * "num_aromatic_rings": number of aromatic rings

    When the file cannot be read, holds no readable molecule, or any RDKit call
    fails, the row gets "Fail" and None for both descriptors. The descriptors are
    written together, so the three columns always have one entry per row.

    Parameters
    ----------
        cmpd_df: Pandas DataFrame
            Contains a mol_pred column of generated sdf files (and a moses_filters
            column when return_original is False)
        return_original: Boolean
            If True, return every row with the extra columns; otherwise return only
            rows with moses_filters == True and Molgen_extra_check == "Pass"

    Returns
    -------
        cleaned_cmpd_df: Pandas DataFrame with the extra columns described above
    """

    # SMARTS patterns
    allene_smarts = Chem.MolFromSmarts("[$([C;R]=[C]=[C]),$([C]=[C;R]=[C])]=[C]=[C]")

    # Fused rings patterns from Advanced Drug Delivery Reviews, 54, 2002, 255-271
    systems_with_4_fused_rings_1 = Chem.MolFromSmarts("*~&@*(~&@*~&@*(~&@*)~&@*~&@*)~&@*~&@*(~&@*)~&@*~&@*")
    systems_with_4_fused_rings_2 = Chem.MolFromSmarts("*(~&@*)~&@*(~&@*(~&@*~&@*)~&@*(~&@*)~&@*(~&@*~&@*)~&@*(~&@*~&@*)~&@*)~&@*~&@*")
    systems_with_4_fused_rings_3 = Chem.MolFromSmarts("*(~&@*)(~&@*(~&@*(~&@*(~&@*)~&@*)~&@*(~&@*)~&@*)~&@*(-&@*)~&@*)~&@*")

    unwanted_patterns = (
        systems_with_4_fused_rings_1,
        systems_with_4_fused_rings_2,
        systems_with_4_fused_rings_3,
        allene_smarts,
    )

    molgen_passed_failed = []
    smiles = []
    num_aromatic_rings = []

    for sdf_file in cmpd_df["mol_pred"]:
        molgen_filter = "Fail"
        mol_smiles = None
        mol_aromatic_rings = None

        try:
            # This will be an empty list if RDKit cannot convert it to a mol object
            mols = [mol for mol in Chem.SDMolSupplier(sdf_file) if mol is not None]

            if mols:
                mol = mols[0]
                has_unwanted = any(mol.HasSubstructMatch(pattern) for pattern in unwanted_patterns)
                parsed_smiles = Chem.MolToSmiles(mol)
                parsed_aromatic_rings = Chem.rdMolDescriptors.CalcNumAromaticRings(mol)

                # Commit only once every RDKit call has succeeded, so a late failure
                # cannot leave a half-filled record behind
                mol_smiles = parsed_smiles
                mol_aromatic_rings = parsed_aromatic_rings
                if not has_unwanted:
                    molgen_filter = "Pass"
        except Exception:
            # Unreadable file or RDKit failure: keep the row, flagged as failed
            molgen_filter = "Fail"
            mol_smiles = None
            mol_aromatic_rings = None

        molgen_passed_failed.append(molgen_filter)
        smiles.append(mol_smiles)
        num_aromatic_rings.append(mol_aromatic_rings)

    cleaned_cmpd_df = cmpd_df.copy()
    cleaned_cmpd_df["Molgen_extra_check"] = molgen_passed_failed
    cleaned_cmpd_df["SMILES"] = smiles
    cleaned_cmpd_df["num_aromatic_rings"] = num_aromatic_rings

    if return_original:
        return cleaned_cmpd_df
    return cleaned_cmpd_df[(cleaned_cmpd_df["moses_filters"] == True) & (cleaned_cmpd_df["Molgen_extra_check"] == "Pass")]

# SMARTS patterns keyed by filter number (1-22). get_MOSES_MCF_flags compiles each
# pattern and records, per compound, whether the compound is free of that substructure.
# NOTE(review): the name suggests these are the MOSES medicinal-chemistry filters (MCFs);
# confirm the list and its numbering against the MOSES source.
MOSES_MCFs = {
            1:"[#6]=&!@[#6]-[#6]#[#7]",
            2:"[#6]=&!@[#6]-[#16](=[#8])=[#8]",
            3:"[#6]=&!@[#6&!H0]-&!@[#6](=[#8])-&!@[#7]",
            4:"[H]C([H])([#6])[F,Cl,Br,I]",
            5:"[#6]1-[#8]-[#6]-1",
            6:"[#6]-[#7]=[#6]=[#8]",
            7:"[#6&!H0]=[#8]",
            8:"[#6](=&!@[#7&!H0])-&!@[#6,#7,#8,#16]",
            9:"[#6]1-[#7]-[#6]-1",
            10:"[#6]~&!@[#7]~&!@[#7]~&!@[#6]",
            11:"[#7]=&!@[#7]",
            12:"[H][#6]-1=[#6]([H])-[#6]=[#6](-*)-[#8]-1",
            13:"[H][#6]-1=[#6]([H])-[#6]=[#6](-*)-[#16]-1",
            14:"[#17,#35,#53]-c(:*):[!#1!#6]:*",
            15:"[H][#7]([H])-[#6]-1=[#6]-[#6]=[#6]-[#6]=[#6]-1",
            16:"[#16]~[#16]",
            17:"[#7]~&!@[#7]~&!@[#7]", # Azide filter
            18:"[#7]-&!@[#6&!H0&!H1]-&!@[#7]",
            19:"[#6&!H0](-&!@[#8])-&!@[#8]",
            20:"[#35].[#35].[#35]",
            21:"[#17].[#17].[#17].[#17]",
            22:"[#9].[#9].[#9].[#9].[#9].[#9].[#9]"
             }

def get_MOSES_MCF_flags(cmpd_df):
    """
    Flag, per compound, which of the MOSES_MCFs SMARTS patterns it is free of

    Each SDF file in "mol_pred" is read once and its first readable molecule is
    matched against every pattern in MOSES_MCFs.

    Parameters
    ----------
        cmpd_df: Pandas DataFrame
            Contains a mol_pred column of generated sdf files

    Returns
    -------
        cleaned_cmpd_df: copy of cmpd_df with one boolean column
        "MOSES_filter_num_<n>" per pattern. True means the pattern does not match;
        False means it matches, the molecule could not be read, or matching failed.
    """
    cleaned_cmpd_df = cmpd_df.copy()

    # Compile every SMARTS once, keyed by its filter number
    mcf_patterns = {filter_num: Chem.MolFromSmarts(smarts) for filter_num, smarts in MOSES_MCFs.items()}
    flags_by_filter = {filter_num: [] for filter_num in mcf_patterns}

    for sdf_file in cmpd_df["mol_pred"]:
        try:
            # This will be an empty list if RDKit cannot convert it to a mol object
            mols = [mol for mol in Chem.SDMolSupplier(sdf_file) if mol is not None]
            mol = mols[0] if mols else None
        except Exception:
            # Unreadable file: every filter is flagged False for this row
            mol = None

        for filter_num, pattern in mcf_patterns.items():
            passed = False
            if mol is not None:
                try:
                    passed = len(mol.GetSubstructMatches(pattern)) == 0
                except Exception:
                    passed = False
            flags_by_filter[filter_num].append(passed)

    for filter_num, flag_list in flags_by_filter.items():
        cleaned_cmpd_df["MOSES_filter_num_" + str(filter_num)] = flag_list

    return cleaned_cmpd_df

def get_MOSES_PAINS_flags(cmpd_df):
    """
    Flag, per compound, which of the PAINS SMARTS patterns it is free of

    The patterns come from the "smarts" column of the module-level `pains` table.
    Each SDF file in "mol_pred" is read once and its first readable molecule is
    matched against every pattern.

    Parameters
    ----------
        cmpd_df: Pandas DataFrame
            Contains a mol_pred column of generated sdf files

    Returns
    -------
        cleaned_cmpd_df: copy of cmpd_df with one boolean column
        "MOSES_filter_PAINS_num_<i>" per row of `pains`, where <i> is that row's
        index label. True means the pattern does not match; False means it matches,
        the molecule could not be read, or matching failed.
    """
    cleaned_cmpd_df = cmpd_df.copy()

    # Compile every PAINS SMARTS once, keyed by the index label of its row
    pains_patterns = {
        pains_idx: Chem.MolFromSmarts(pains_row["smarts"])
        for pains_idx, pains_row in pains.iterrows()
    }
    flags_by_pattern = {pains_idx: [] for pains_idx in pains_patterns}

    for sdf_file in cmpd_df["mol_pred"]:
        try:
            # This will be an empty list if RDKit cannot convert it to a mol object
            mols = [mol for mol in Chem.SDMolSupplier(sdf_file) if mol is not None]
            mol = mols[0] if mols else None
        except Exception:
            # Unreadable file: every pattern is flagged False for this row
            mol = None

        for pains_idx, pattern in pains_patterns.items():
            passed = False
            if mol is not None:
                try:
                    passed = len(mol.GetSubstructMatches(pattern)) == 0
                except Exception:
                    passed = False
            flags_by_pattern[pains_idx].append(passed)

    for pains_idx, flag_list in flags_by_pattern.items():
        cleaned_cmpd_df["MOSES_filter_PAINS_num_" + str(pains_idx)] = flag_list

    return cleaned_cmpd_df

def check_posebusters(cmpd_df, threshold=90):
    """
    Flag, per row, which PoseBusters pass-percentages exceed a threshold

    Parameters
    ----------
        cmpd_df: Pandas DataFrame
            Must contain the seven "perc_pass*" PoseBusters columns listed below
        threshold: float
            A value passes when it is strictly greater than this, default 90

    Returns
    -------
        Dictionary mapping each PoseBusters column name to a list of booleans, one
        per row of cmpd_df. Missing (NaN) values are flagged False.
    """
    pb_check_cols = [
        'perc_pass_bond_lengths',
        'perc_pass_bond_angles',
        'perc_pass_internal_steric_clash',
        'perc_passed_internal_energy',
        'perc_pass_dist_to_prot',
        'perc_pass_dist_to_water',
        'perc_pass_vol_overlap_with_prot',
    ]

    # Comparing a whole column keeps one flag per row; NaN > threshold evaluates to False
    return {pb_check: (cmpd_df[pb_check] > threshold).tolist() for pb_check in pb_check_cols}

def check_charge_flag(cmpd_df):
    """
    Flag compounds whose atoms all carry a formal charge of zero

    Parameters
    ----------
        cmpd_df: Pandas DataFrame
            Contains a mol_pred column of generated sdf files

    Returns
    -------
        cleaned_cmpd_df: copy of cmpd_df with a boolean "MOSES_filter_charge" column.
        True means the first readable molecule in the file has no atom with a
        non-zero formal charge; False means it has one, or the file could not be
        read or holds no readable molecule.
    """
    cleaned_cmpd_df = cmpd_df.copy()
    MOSES_charge_flag = []

    for sdf_file in cmpd_df["mol_pred"]:
        is_uncharged = False

        try:
            # This will be an empty list if RDKit cannot convert it to a mol object
            mols = [mol for mol in Chem.SDMolSupplier(sdf_file) if mol is not None]

            if mols:
                is_uncharged = all(atom.GetFormalCharge() == 0 for atom in mols[0].GetAtoms())
        except Exception:
            # Unreadable file or RDKit failure: keep the row, flagged as failed
            is_uncharged = False

        MOSES_charge_flag.append(is_uncharged)

    cleaned_cmpd_df["MOSES_filter_charge"] = MOSES_charge_flag

    return cleaned_cmpd_df

def filter_cmpds_get_stats(cmpd_df, run_moses_breakdown=False, return_stats=False):
    """
    Runs MolGen filters and performs additional checks on which MOSES filters are failing

    Prints the number and percentage of compounds passing each PoseBusters check, the
    allene/fused-ring check and the MOSES check. When run_moses_breakdown is True, or
    fewer than 50% of compounds pass MOSES, it also prints the MCF, PAINS and charge
    pass rates.

    Parameters
    ----------
        cmpd_df: Pandas DataFrame
            Scored compounds; needs mol_pred, moses_filters and the seven PoseBusters
            columns named below
        run_moses_breakdown: Boolean
            Force the MCF / PAINS / charge breakdown
        return_stats: Boolean
            Also return the computed percentages

    Returns
    -------
        The rows with moses_filters == True and Molgen_extra_check == "Pass". With
        return_stats=True, a tuple of that DataFrame followed by the percentages for:
        molgen extra check, MOSES, MOSES MCF, MOSES PAINS, MOSES charge (the last three
        are None when the breakdown did not run), bond lengths, bond angles, internal
        steric clash, internal energy, distance to protein, distance to water, volume
        overlap with protein.
    """

    def _count_and_perc(values):
        # Number of entries equal to True, and that number as a percentage of all entries
        n_true = values.tolist().count(True)
        return n_true, round(n_true / len(values) * 100, 2)

    # (column, wording used in the printed message), in the order the stats are returned
    posebusters_checks = [
        ("bond_lengths", "bond length"),
        ("bond_angles", "bond angle"),
        ("internal_steric_clash", "internal steric clash"),
        ("internal_energy", "internal energy"),
        ("minimum_distance_to_protein", "distance to protein"),
        ("minimum_distance_to_waters", "distance to water"),
        ("volume_overlap_with_protein", "overlap with protein"),
    ]

    df_molgen = molgen_filters(cmpd_df, return_original=True)

    # Work out every PoseBusters statistic first, then report them
    pb_results = [_count_and_perc(df_molgen[column]) for column, _ in posebusters_checks]
    for (_, wording), (n_pass, perc_pass) in zip(posebusters_checks, pb_results):
        print("Number and percentage of compounds passing " + wording + " PoseBusters check:", str(n_pass) + " " + str(perc_pass) + "%")

    # MolGen checks (allene and fused ring checks)
    passes_extra = df_molgen["Molgen_extra_check"] == "Pass"
    n_extra_pass = len(df_molgen[passes_extra])
    perc_molgen_extra_pass = round((n_extra_pass / len(df_molgen)) * 100, 2)
    print("Number and percentage of compounds passing allene and fused ring checks:", str(n_extra_pass) + " " + str(perc_molgen_extra_pass) + "%")

    # MOSES checks
    passes_moses = df_molgen["moses_filters"] == True
    n_moses_pass = len(df_molgen[passes_moses])
    perc_molgen_MOSES_pass = round((n_moses_pass / len(df_molgen)) * 100, 2)
    print("Number and percentage of compounds passing MOSES checks:", str(n_moses_pass) + " " + str(perc_molgen_MOSES_pass) + "%")

    perc_MOSES_MCF_pass = perc_MOSES_PAINS_pass = perc_MOSES_charge_pass = None

    if run_moses_breakdown or perc_molgen_MOSES_pass < 50:
        # MCF filters: a compound passes when every per-filter flag is True
        mcf_all_passed = get_MOSES_MCF_flags(cmpd_df).filter(regex="MOSES_filter_num_").all(axis=1)
        n_mcf_pass, perc_MOSES_MCF_pass = _count_and_perc(mcf_all_passed)
        print("Number and percentage of compounds passing MOSES MCF checks:", str(n_mcf_pass) + " " + str(perc_MOSES_MCF_pass) + "%")

        # PAINS filters: same rule
        pains_all_passed = get_MOSES_PAINS_flags(cmpd_df).filter(regex="MOSES_filter_PAINS_num_").all(axis=1)
        n_pains_pass, perc_MOSES_PAINS_pass = _count_and_perc(pains_all_passed)
        print("Number and percentage of compounds passing MOSES PAINS checks:", str(n_pains_pass) + " " + str(perc_MOSES_PAINS_pass) + "%")

        # Charges
        charge_df = check_charge_flag(cmpd_df)
        n_charge_pass = len(charge_df[charge_df["MOSES_filter_charge"] == True])
        perc_MOSES_charge_pass = round((n_charge_pass / len(charge_df)) * 100, 2)
        print("Number and percentage of compounds passing MOSES charge check:", str(n_charge_pass) + " " + str(perc_MOSES_charge_pass) + "%")

    kept_df = df_molgen[passes_moses & passes_extra]

    if not return_stats:
        return kept_df

    return (
        kept_df,
        perc_molgen_extra_pass,
        perc_molgen_MOSES_pass,
        perc_MOSES_MCF_pass,
        perc_MOSES_PAINS_pass,
        perc_MOSES_charge_pass,
        *[perc_pass for _, perc_pass in pb_results],
    )


In [None]:
# EXAMPLE: load one task-2 spec-score table (replace the "<...>" placeholder with the
# directory for a method) and keep only the compounds that pass the filters
df_test_task2_ITK_selective = pd.read_csv(
    "<dir_rof_summary_aggregated_and_spec_scores_for_a_method>/task2_spec_scores_ITK_selectivity.csv"
)
df_test_task2_ITK_selective_clean = filter_cmpds_get_stats(cmpd_df=df_test_task2_ITK_selective)


## STEP 3: TASK ANALYSIS

## TASK 1: STATS AND PLOTS

In [None]:
def not_na(inp_list):
    """Return the items of inp_list that are not None, in their original order."""
    kept = []
    for item in inp_list:
        if item is None:
            continue
        kept.append(item)
    return kept

class task1StatsPlots():
    def __init__(self, task_outputs_dir, output_plots_dir=None):
        self.output_dir = task_outputs_dir
        self.output_plots_dir = output_plots_dir
        
        if not os.path.exists(self.output_plots_dir):
            os.makedirs(self.output_plots_dir)
    
    def prot1(self, df):
        """
        Task 1, protein 1: compare generated-ligand distances to the wild-type and mutated residue

        Compounds are filtered with filter_cmpds_get_stats and labelled "Mutated" when
        their mol_cond contains "W170S", otherwise "Wild type". For each input type a
        Welch t-test compares WT_diff (generated minus true distance to the wild-type
        residue) with mut_diff (the same for the mutated residue). The assessment
        passes when both p-values are below 0.01 and each input type has more than 10
        compounds. Two KDE plots are written to self.output_plots_dir.

        Returns
        -------
            Tuple of (comments, passed) followed by the twelve percentages returned by
            filter_cmpds_get_stats, in the same order
        """
        print("Running task 1 protein 1 assessment")

        # Filter out bad compounds and get statistics
        filtered_df, *filter_stats = filter_cmpds_get_stats(df, return_stats=True)

        scored_df = filtered_df.assign(**{
            "Input structure": filtered_df["mol_cond"].apply(lambda path: "Mutated" if "W170S" in path else "Wild type"),
            "WT_diff": filtered_df["gen_min_dist_to_WT_res"] - filtered_df["true_min_dist_to_WT_res"],
            "mut_diff": filtered_df["gen_min_dist_to_mut_res"] - filtered_df["true_min_dist_to_mut_res"],
        })

        wt_input_df = scored_df[scored_df["Input structure"] == "Wild type"]
        mut_input_df = scored_df[scored_df["Input structure"] == "Mutated"]

        def _wt_vs_mut_pvalue(subset):
            # Welch t-test between the two distance differences of one input type
            return scipy.stats.ttest_ind(
                subset["WT_diff"].dropna(),
                subset["mut_diff"].dropna(),
                equal_var=False,
                nan_policy='omit',
            ).pvalue

        wt_check_ok = _wt_vs_mut_pvalue(wt_input_df) < 0.01
        mut_check_ok = _wt_vs_mut_pvalue(mut_input_df) < 0.01
        enough_data = len(wt_input_df) > 10 and len(mut_input_df) > 10

        comment_parts = []
        if not wt_check_ok:
            comment_parts.append("Step1/3: Did not pass WT distance t-test;")
        if not mut_check_ok:
            comment_parts.append("Step2/3: Did not pass mut distance t-test;")
        if not enough_data:
            comment_parts.append("Step3/3: Did not have sufficient data points;")
        comments = "".join(comment_parts)

        passed = bool(enough_data and wt_check_ok and mut_check_ok)

        def _plot_distances(subset, true_distance, title, file_name):
            # KDE of both residue distances for one input type, with the true distance marked
            fig = plt.figure()
            sns.kdeplot(subset, x="gen_min_dist_to_mut_res", label="Distance to mutated Ser")
            sns.kdeplot(subset, x="gen_min_dist_to_WT_res", label="Distance to wild-type Trp")
            plt.axvline(true_distance, color='r', label="True minimum distance")
            plt.title(title)
            plt.xlabel("Minimum distance to residue")
            plt.legend()
            fig.savefig(os.path.join(self.output_plots_dir, file_name))

        # Plots
        _plot_distances(wt_input_df, 4.425715, "Given wild-type structure as input", "task1_prot1_given_WT.png")
        _plot_distances(mut_input_df, 5.784789, "Given mutated structure as input", "task1_prot1_given_mut.png")

        return (comments, passed, *filter_stats)
    
    def prot2(self, df):
        """
        Task 1, protein 2: compare generated-ligand residue distances between 1w6k and 1w6j inputs

        Compounds are filtered with filter_cmpds_get_stats and split by "PDB ID". For
        residues 455, 237, 233 and 524 a Welch t-test compares the generated distance
        to the 1w6k residue (1w6k input) with the generated distance to the 1w6j
        residue (1w6j input). The assessment passes when both inputs have more than 10
        compounds and all four p-values are >= 0.01. Otherwise only the first failing
        step is recorded in comments. One KDE plot per input structure is written to
        self.output_plots_dir.

        Returns
        -------
            Tuple of (comments, passed) followed by the twelve percentages returned by
            filter_cmpds_get_stats, in the same order
        """
        print("Running task 1 protein 2 assessment")

        comments = ""

        # Filter out bad compounds and get statistics
        df, *filter_stats = filter_cmpds_get_stats(df, return_stats=True)

        df_1w6k_inp = df[df["PDB ID"] == "1w6k"]
        df_1w6j_inp = df[df["PDB ID"] == "1w6j"]

        # (residue, comment recorded when its t-test fails), in the order the checks run
        residue_checks = (
            ("455", "Step2/5: Did not pass residue 455 t-test distance check;"),
            ("237", "Step3/5: Did not pass residue 237 t-test distance check;"),
            ("233", "Step4/5: Did not pass residue 233 t-test distance check;"),
            ("524", "Step5/5: Did not pass residue 524 t-test distance check;"),
        )

        # T-tests on the generated distances between the two input structures [this
        # assesses whether the input structure has an effect on generated interactions]
        p_values = {
            residue: scipy.stats.ttest_ind(
                df_1w6k_inp["gen_min_dist_to_1w6k_res_" + residue],
                df_1w6j_inp["gen_min_dist_to_1w6j_res_" + residue],
                equal_var=False,
                nan_policy='omit',
            ).pvalue
            for residue, _ in residue_checks
        }

        passed = False
        if len(df_1w6k_inp) > 10 and len(df_1w6j_inp) > 10:
            for residue, failure_comment in residue_checks:
                # A NaN p-value also counts as a failure
                if not p_values[residue] >= 0.01:
                    comments += failure_comment
                    break
            else:
                passed = True
        else:
            comments += "Step1/5: Did not have sufficient data points;"

        def _plot_given(pdb_id):
            # KDEs of the distances to residues 237/233/524 of both structures for one
            # input structure, with that structure's true distances marked
            subset = df[df["PDB ID"] == pdb_id]
            fig = plt.figure()
            for residue in ("237", "233", "524"):
                sns.kdeplot(subset, x="gen_min_dist_to_1w6k_res_" + residue, label="Residue " + residue + " 1w6k")
                sns.kdeplot(subset, x="gen_min_dist_to_1w6j_res_" + residue, label="Residue " + residue + " 1w6j", linestyle="--")
            # No reference row exists when every compound was filtered out
            if len(df) > 0:
                for residue, colour in (("237", 'r'), ("233", 'b'), ("524", 'g')):
                    plt.axvline(
                        df["true_min_dist_to_" + pdb_id + "_res_" + residue].iloc[0],
                        color=colour,
                        label="True distance " + residue,
                    )
            plt.title("Given " + pdb_id + " as input")
            plt.xlabel("Minimum distance to residue")
            plt.legend()
            fig.savefig(os.path.join(self.output_plots_dir, "task1_prot2_given_" + pdb_id + ".png"))

        # Plots
        _plot_given("1w6k")
        _plot_given("1w6j")

        return (comments, passed, *filter_stats)
        
    def prot3(self, df):
        """
        Assess task 1 protein 3 (holo 5cqh vs. apo 5cqi input structures).

        The check passes when, in order: (1) both input groups contain more than 10
        compounds with a similarity to the 5cqh crystal ligand, (2) a Welch t-test on
        that similarity between the two groups gives p >= 0.01 (i.e. NOT significantly
        different), (3) the mean re-created hydrophobic interaction fraction is > 0.5
        and (4) the mean re-created hbond interaction fraction is > 0.5.

        Parameters
        ----------
            df: DataFrame
                Output rows for this protein; reads the "PDB ID", "gen_true_sim_5cqh"
                and "PLIP" columns.

        Returns
        -------
            Tuple of (comments, passed) followed by the twelve pass-rate statistics
            returned by filter_cmpds_get_stats, in the order that function returns them.
            A KDE plot is also written to self.output_plots_dir.
        """
        print("Running task 1 protein 3 assessment")

        comments = ""

        # Drop bad compounds; everything after the filtered frame is a pass-rate statistic
        df, *filter_stats = filter_cmpds_get_stats(df, return_stats=True)

        pdb_ids = df["PDB ID"]
        has_similarity = df["gen_true_sim_5cqh"].notna()
        df_5cqh_inp = df[(pdb_ids == "5cqh") & has_similarity]
        df_5cqi_inp = df[(pdb_ids == "5cqi") & has_similarity]
        t_test_gen = scipy.stats.ttest_ind(
            df_5cqh_inp["gen_true_sim_5cqh"],
            df_5cqi_inp["gen_true_sim_5cqh"],
            equal_var=False,
            nan_policy='omit',
        )

        # Fraction of hydrophobic / hbond interactions re-created per compound.
        # The PLIP cell is a dict-like string that is rewritten into valid JSON first.
        hydrophobic_perc, hbond_perc = [], []
        for plip_text in df["PLIP"]:
            plip = json.loads(plip_text.replace("'", '"').replace("None", "null").replace("nan", "null"))
            hydrophobic_perc.append(plip["hydrophobic"]["interaction_frac"])
            hbond_perc.append(plip["hbond"]["interaction_frac"])

        hydrophobic_perc = not_na(hydrophobic_perc)
        hbond_perc = not_na(hbond_perc)

        # Checks run in order and stop at the first failure.
        # `not (x >= t)` is used instead of `x < t` so a NaN value still counts as a failure.
        passed = False
        if not (len(df_5cqh_inp) > 10 and len(df_5cqi_inp) > 10):
            comments += "Step1/4: Did not have sufficient data points;"
        elif not (t_test_gen.pvalue >= 0.01):
            # Similarity to the crystal ligand should NOT depend on the input structure
            comments += "Step2/4: Did not pass similarity to crystal check;"
        elif not (statistics.mean(hydrophobic_perc) > 0.5):
            comments += "Step3/4: Did not pass hydrophobic re-creation check;"
        elif not (statistics.mean(hbond_perc) > 0.5):
            comments += "Step4/4: Did not pass hbond re-creation check;"
        else:
            passed = True

        fig = plt.figure()
        sns.kdeplot(data=df, x="gen_true_sim_5cqh", hue="PDB ID")
        plt.title("Given apo 5cqi and holo 5cqh")
        plt.xlabel("Similarity to crystal ligand")
        fig.savefig(os.path.join(self.output_plots_dir, "task1_prot3_sims.png"))

        return (comments, passed, *filter_stats)
    
    def prot4(self, df):
        """
        Assess task 1 protein 4 (input structure given with vs. without Fe).

        The check passes when, in order: (1) more than 10 compounds have a distance to
        Fe, (2) a Welch t-test on the minimum distance to Fe between the "with Fe" and
        "without Fe" inputs gives p < 0.01 (significantly different), (3) the mean
        re-created hydrophobic fraction is > 0.5 and (4) the mean re-created hbond
        fraction is > 0.5. Interaction fractions are taken from "with Fe" rows only.

        Parameters
        ----------
            df: DataFrame
                Output rows for this protein; reads the "mol_cond",
                "gen_min_dist_to_Fe" and "PLIP" columns.

        Returns
        -------
            Tuple of (comments, passed) followed by the twelve pass-rate statistics
            returned by filter_cmpds_get_stats, in the order that function returns them.
            A KDE plot is also written to self.output_plots_dir.
        """
        print("Running task 1 protein 4 assessment")

        comments = ""

        # Drop bad compounds; everything after the filtered frame is a pass-rate statistic
        df, *filter_stats = filter_cmpds_get_stats(df, return_stats=True)

        df["Input structure"] = df["mol_cond"].apply(lambda cond: "with Fe" if "Fe" in cond else "without Fe")
        df = df[df["gen_min_dist_to_Fe"].notna()]

        given_fe = df[df["Input structure"] == "with Fe"]
        given_no_fe = df[df["Input structure"] == "without Fe"]

        # The two distance distributions should be statistically significantly different
        t_test_gen = scipy.stats.ttest_ind(
            given_fe["gen_min_dist_to_Fe"],
            given_no_fe["gen_min_dist_to_Fe"],
            equal_var=False,
            nan_policy='omit',
        )

        # Fraction of hydrophobic / hbond interactions re-created (Fe-containing inputs only)
        hydrophobic_perc, hbond_perc = [], []
        for plip_text in given_fe["PLIP"]:
            plip = json.loads(plip_text.replace("'", '"').replace("None", "null").replace("nan", "null"))
            hydrophobic_perc.append(plip["hydrophobic"]["interaction_frac"])
            hbond_perc.append(plip["hbond"]["interaction_frac"])

        hydrophobic_perc = not_na(hydrophobic_perc)
        hbond_perc = not_na(hbond_perc)

        # Checks run in order and stop at the first failure; `not (...)` keeps NaN a failure
        passed = False
        if not (len(df) > 10):
            comments += "Step1/4: Did not have sufficient data points;"
        elif not (t_test_gen.pvalue < 0.01):
            comments += "Step2/4: Did not pass t-test with Fe and without Fe;"
        elif not (statistics.mean(hydrophobic_perc) > 0.5):
            comments += "Step3/4: Did not pass hydrophobic re-creation check;"
        elif not (statistics.mean(hbond_perc) > 0.5):
            comments += "Step4/4: Did not pass hbond re-creation check;"
        else:
            passed = True

        fig = plt.figure()
        sns.kdeplot(data=df, hue="Input structure", x="gen_min_dist_to_Fe")
        plt.axvline(4.391704, color='r', label="True minimum distance")
        plt.xlabel("Minimum distance to Fe")
        fig.savefig(os.path.join(self.output_plots_dir, "task1_prot4_Fe_distance.png"))

        return (comments, passed, *filter_stats)
    
    def prot5(self, df):
        """
        Assess task 1 protein 5 (similarity to the 3hig and 3hii crystal ligands).

        The check passes when, in order: (1) more than 10 compounds survive filtering,
        (2) the mean re-created hydrophobic interaction fraction is > 0.5 and (3) the
        mean re-created hbond interaction fraction is > 0.5.

        Parameters
        ----------
            df: DataFrame
                Output rows for this protein; reads the "PDB ID", "PLIP",
                "gen_true_sim_3hig" and "gen_true_sim_3hii" columns.

        Returns
        -------
            Tuple of (comments, passed) followed by the twelve pass-rate statistics
            returned by filter_cmpds_get_stats, in the order that function returns them.
            Two KDE plots (one per crystal ligand) are written to self.output_plots_dir.
        """
        print("Running task 1 protein 5 assessment")

        comments = ""

        # Drop bad compounds; everything after the filtered frame is a pass-rate statistic
        df, *filter_stats = filter_cmpds_get_stats(df, return_stats=True)

        # Fraction of hydrophobic / hbond interactions re-created per compound.
        # The PLIP cell is a dict-like string that is rewritten into valid JSON first.
        hydrophobic_perc = []
        hbond_perc = []

        for idx, row in df.iterrows():
            interaction_dict_literal = json.loads(row["PLIP"].replace("'", '"').replace("None", "null").replace("nan", "null"))
            hydrophobic_perc.append(interaction_dict_literal["hydrophobic"]["interaction_frac"])
            hbond_perc.append(interaction_dict_literal["hbond"]["interaction_frac"])

        hydrophobic_perc = not_na(hydrophobic_perc)
        hbond_perc = not_na(hbond_perc)

        # Checks run in order and stop at the first failure
        passed = False
        if len(df) > 10:
            if statistics.mean(hydrophobic_perc) > 0.5:
                if statistics.mean(hbond_perc) > 0.5:
                    passed = True
                else:
                    comments += "Step3/3: Did not pass hbond re-creation check;"
            else:
                comments += "Step2/3: Did not pass hydrophobic re-creation check;"
        else:
            comments += "Step1/3: Did not have sufficient data points;"

        # Plots
        fig = plt.figure()
        sns.kdeplot(df, hue="PDB ID", x="gen_true_sim_3hig")
        plt.title("Compared to 3hig crystal ligand")
        plt.xlabel("Similarity")
        plt.xlim((0, 1.0))
        fig.savefig(os.path.join(self.output_plots_dir, "task1_prot5_3hig_sim.png"))

        fig = plt.figure()
        sns.kdeplot(df, hue="PDB ID", x="gen_true_sim_3hii")
        plt.title("Compared to 3hii crystal ligand")
        plt.xlabel("Similarity")
        plt.xlim((0, 1.0))
        # Save under the 3hii name: the original reused the 3hig file name and
        # silently overwrote the first plot.
        fig.savefig(os.path.join(self.output_plots_dir, "task1_prot5_3hii_sim.png"))

        return (comments, passed, *filter_stats)
        
    def prot6(self, df):
        """
        Assess task 1 protein 6 (6ryo vs. 6ryp input structures).

        For each input structure, a Welch t-test compares the generated molecules'
        "6ryo_gen_mol_min_dist" against their "6ryp_gen_mol_min_dist". The check passes
        when, in order: (1) both input groups contain more than 10 compounds, (2) the
        6ryo-input test gives p >= 0.01 and (3) the 6ryp-input test gives p >= 0.01,
        i.e. the distances are NOT significantly different in either case.

        Parameters
        ----------
            df: DataFrame
                Output rows for this protein; reads the "mol_cond",
                "6ryo_gen_mol_min_dist" and "6ryp_gen_mol_min_dist" columns.

        Returns
        -------
            Tuple of (comments, passed) followed by the twelve pass-rate statistics
            returned by filter_cmpds_get_stats, in the order that function returns them.
            Two KDE plots (one per input structure) are written to self.output_plots_dir.
        """
        print("Running task 1 protein 6 assessment")

        comments = ""

        # Drop bad compounds; everything after the filtered frame is a pass-rate statistic
        df, *filter_stats = filter_cmpds_get_stats(df, return_stats=True)

        df["Input structure"] = df["mol_cond"].apply(lambda cond: "6ryo" if "6ryo" in cond else "6ryp")

        df_6ryo_inp = df[df["Input structure"] == "6ryo"]
        df_6ryp_inp = df[df["Input structure"] == "6ryp"]

        # The same binding pocket should be found regardless of input structure, so the
        # two distance distributions are expected NOT to differ significantly.
        t_test_gen_6ryo = scipy.stats.ttest_ind(
            df_6ryo_inp["6ryo_gen_mol_min_dist"].dropna(),
            df_6ryo_inp["6ryp_gen_mol_min_dist"].dropna(),
            equal_var=False,
            nan_policy='omit',
        )
        t_test_gen_6ryp = scipy.stats.ttest_ind(
            df_6ryp_inp["6ryo_gen_mol_min_dist"].dropna(),
            df_6ryp_inp["6ryp_gen_mol_min_dist"].dropna(),
            equal_var=False,
            nan_policy='omit',
        )

        # Checks run in order and stop at the first failure; `not (...)` keeps NaN a failure
        passed = False
        if not (len(df_6ryo_inp) > 10 and len(df_6ryp_inp) > 10):
            comments += "Step1/3: Did not have sufficient data points;"
        elif not (t_test_gen_6ryo.pvalue >= 0.01):
            comments += "Step2/3: Did not pass 6ryo distance t-test;"
        elif not (t_test_gen_6ryp.pvalue >= 0.01):
            comments += "Step3/3: Did not pass 6ryp distance t-test;"
        else:
            passed = True

        # Plot: molecules generated from the 6ryo input
        fig = plt.figure()
        sns.kdeplot(data=df_6ryo_inp, x="6ryo_gen_mol_min_dist", label="Distance to 6ryo")
        sns.kdeplot(data=df_6ryo_inp, x="6ryp_gen_mol_min_dist", label="Distance to 6ryp")
        plt.title("Given 6ryo input structure")
        plt.xlabel("Minimum distance of generated molecule to crystal ligand COM")
        plt.legend()
        fig.savefig(os.path.join(self.output_plots_dir, "task1_prot6_6ryo_dist.png"))

        # Plot: molecules generated from the 6ryp input
        fig = plt.figure()
        sns.kdeplot(data=df_6ryp_inp, x="6ryp_gen_mol_min_dist", label="Distance to 6ryp")
        sns.kdeplot(data=df_6ryp_inp, x="6ryo_gen_mol_min_dist", label="Distance to 6ryo")
        plt.title("Given 6ryp input structure")
        plt.xlabel("Minimum distance of generated molecule to crystal ligand COM")
        plt.legend()
        fig.savefig(os.path.join(self.output_plots_dir, "task1_prot6_6ryp_dist.png"))

        return (comments, passed, *filter_stats)
    
    def prot7(self, df):
        """
        Assess task 1 protein 7 (4l9i and 4mk0 input structures).

        The check passes when, in order: (1) both the 4l9i and 4mk0 groups contain more
        than 10 compounds, (2) the mean re-created hbond interaction fraction is > 0.5
        and (3) the mean of "number_clashes" is below 0.5 for BOTH 4l9i and 4mk0.

        Parameters
        ----------
            df: DataFrame
                Output rows for this protein; reads the "PDB ID", "PLIP",
                "number_clashes" and "sim_gen_4l9i" columns.

        Returns
        -------
            Tuple of (comments, passed) followed by the twelve pass-rate statistics
            returned by filter_cmpds_get_stats, in the order that function returns them.
            A KDE plot is also written to self.output_plots_dir.
        """
        print("Running task 1 protein 7 assessment")

        comments = ""

        # Drop bad compounds; everything after the filtered frame is a pass-rate statistic
        df, *filter_stats = filter_cmpds_get_stats(df, return_stats=True)

        # Fraction of hbond interactions re-created; rows with a missing PLIP entry
        # contribute None, which not_na() is then expected to drop.
        hbond_perc = []

        for idx, row in df.iterrows():
            if str(row["PLIP"]) != "nan":
                interaction_dict_literal = json.loads(row["PLIP"].replace("'", '"').replace("None", "null").replace("nan", "null"))
                hbond_perc.append(interaction_dict_literal["hbond"]["interaction_frac"])
            else:
                hbond_perc.append(None)

        hbond_perc = not_na(hbond_perc)

        df_4l9i = df[df["PDB ID"] == "4l9i"]
        df_4mk0 = df[df["PDB ID"] == "4mk0"]

        # Clash metric: mean of the non-missing "number_clashes" values per structure.
        # NOTE(review): the variable names say "perc", but this is a plain mean of the
        # column — confirm whether the column is a 0/1 flag or a count.
        perc_of_clashes_4l9i = df_4l9i['number_clashes'].dropna().mean()
        perc_of_clashes_4mk0 = df_4mk0['number_clashes'].dropna().mean()
        # Both structures must pass; the original tested 4l9i twice and never 4mk0.
        passed_clash_metric = perc_of_clashes_4l9i < 0.5 and perc_of_clashes_4mk0 < 0.5

        # Checks run in order and stop at the first failure
        passed = False
        if len(df_4l9i) > 10 and len(df_4mk0) > 10:
            if statistics.mean(hbond_perc) > 0.5:
                if passed_clash_metric:
                    passed = True
                else:
                    comments += "Step3/3: Did not pass clash metrics;"
            else:
                comments += "Step2/3: Did not pass hbond re-creation check;"
        else:
            comments += "Step1/3: Did not have sufficient data points;"

        # Plot
        fig = plt.figure()
        sns.kdeplot(df, x="sim_gen_4l9i", hue="PDB ID")
        plt.xlim((0, 1))
        # No plt.legend() here: seaborn already draws the hue legend, and a bare
        # plt.legend() call would replace it with an empty one.
        fig.savefig(os.path.join(self.output_plots_dir, "task1_prot7_4l9i_sim.png"))

        return (comments, passed, *filter_stats)
    
    def prot8(self, df):
        """
        Assess task 1 protein 8 (similarity to the 4q6r crystal ligand).

        The check passes when, in order: (1) more than 10 compounds survive filtering,
        (2) the mean re-created hbond interaction fraction is > 0.5 and (3) the mean
        re-created hydrophobic interaction fraction is > 0.5.

        Parameters
        ----------
            df: DataFrame
                Output rows for this protein; reads the "PLIP" and "sim_gen_4q6r" columns.

        Returns
        -------
            Tuple of (comments, passed) followed by the twelve pass-rate statistics
            returned by filter_cmpds_get_stats, in the order that function returns them.
            A KDE plot is also written to self.output_plots_dir.
        """
        print("Running task 1 protein 8 assessment")

        comments = ""

        # Drop bad compounds; everything after the filtered frame is a pass-rate statistic
        df, *filter_stats = filter_cmpds_get_stats(df, return_stats=True)

        # Fraction of hydrophobic / hbond interactions re-created per compound.
        # The PLIP cell is a dict-like string that is rewritten into valid JSON first.
        hydrophobic_perc, hbond_perc = [], []
        for plip_text in df["PLIP"]:
            plip = json.loads(plip_text.replace("'", '"').replace("None", "null").replace("nan", "null"))
            hydrophobic_perc.append(plip["hydrophobic"]["interaction_frac"])
            hbond_perc.append(plip["hbond"]["interaction_frac"])

        hydrophobic_perc = not_na(hydrophobic_perc)
        hbond_perc = not_na(hbond_perc)

        # Checks run in order and stop at the first failure; `not (...)` keeps NaN a failure
        passed = False
        if not (len(df) > 10):
            comments += "Step1/3: Did not contain sufficient data points;"
        elif not (statistics.mean(hbond_perc) > 0.5):
            comments += "Step2/3: Did not pass hbond re-creation check;"
        elif not (statistics.mean(hydrophobic_perc) > 0.5):
            comments += "Step3/3: Did not pass hydrophobic re-creation check;"
        else:
            passed = True

        # Plot
        fig = plt.figure()
        sns.kdeplot(data=df, x="sim_gen_4q6r")
        plt.xlim((0, 1))
        plt.legend()
        fig.savefig(os.path.join(self.output_plots_dir, "task1_prot8_4q6r_sim.png"))

        return (comments, passed, *filter_stats)
    
    def prot9(self, df):
        """
        Assess task 1 protein 9 (1pl6 and 1pl7 inputs with/without water, NAD and Zn).

        Rows are split per PDB ID and, for each cofactor, by whether "mol_cond" contains
        "water", "NAD" or "Zn". For 1pl6 both the mean hbond and the mean hydrophobic
        re-created fraction must be > 0.5 in the with- and without-cofactor subsets; for
        1pl7 only the hbond fraction is checked. Each structure also needs more than 10
        rows in every with-cofactor subset. The overall check passes only when both
        1pl6 and 1pl7 pass.

        Parameters
        ----------
            df: DataFrame
                Output rows for this protein; reads the "PDB ID", "mol_cond", "PLIP",
                "gen_min_dist_to_H2O_1pl6", "gen_min_dist_to_H2O_1pl7",
                "gen_min_dist_to_Zn_1pl6" and "gen_min_dist_to_Zn_1pl7" columns.

        Returns
        -------
            Tuple of (comments, passed) followed by the twelve pass-rate statistics
            returned by filter_cmpds_get_stats, in the order that function returns them.
            Two KDE plots (distance to H2O, distance to Zn) are written to
            self.output_plots_dir.
        """
        print("Running task 1 protein 9 assessment")

        comments = ""

        # Drop bad compounds; everything after the filtered frame is a pass-rate statistic
        df, *filter_stats = filter_cmpds_get_stats(df, return_stats=True)

        # Fraction of hydrophobic / hbond interactions re-created per compound.
        # The PLIP cell is a dict-like string that is rewritten into valid JSON first.
        hydrophobic_perc = []
        hbond_perc = []

        for idx, row in df.iterrows():
            interaction_dict_literal = json.loads(row["PLIP"].replace("'", '"').replace("None", "null").replace("nan", "null"))
            hydrophobic_perc.append(interaction_dict_literal["hydrophobic"]["interaction_frac"])
            hbond_perc.append(interaction_dict_literal["hbond"]["interaction_frac"])

        df["hydrophobic_frac_recreated"] = hydrophobic_perc
        df["hbond_frac_recreated"] = hbond_perc

        df = df[(df["hydrophobic_frac_recreated"].notna()) & (df["hbond_frac_recreated"].notna())]

        def split_by_cofactor(frame, token):
            # Rows whose mol_cond mentions `token`, and the complementary rows
            mentions = frame["mol_cond"].str.contains(token)
            return frame[mentions], frame[~mentions]

        def recreated(subset, *columns):
            # True when the mean of every listed column exceeds 0.5 (stops at the first failure)
            return all(statistics.mean(subset[col].tolist()) > 0.5 for col in columns)

        hbond_only = ("hbond_frac_recreated",)
        hbond_and_hydrophobic = ("hbond_frac_recreated", "hydrophobic_frac_recreated")

        df_1pl6 = df[df["PDB ID"] == "1pl6"]
        df_1pl6_h2o, df_1pl6_no_h2o = split_by_cofactor(df_1pl6, "water")
        df_1pl6_NAD, df_1pl6_no_NAD = split_by_cofactor(df_1pl6, "NAD")
        df_1pl6_Zn, df_1pl6_no_Zn = split_by_cofactor(df_1pl6, "Zn")

        df_1pl7 = df[df["PDB ID"] == "1pl7"]
        df_1pl7_h2o, df_1pl7_no_h2o = split_by_cofactor(df_1pl7, "water")
        df_1pl7_NAD, df_1pl7_no_NAD = split_by_cofactor(df_1pl7, "NAD")
        df_1pl7_Zn, df_1pl7_no_Zn = split_by_cofactor(df_1pl7, "Zn")

        # NOTE(review): only the with-cofactor subsets are size-checked; an empty
        # without-cofactor subset would make statistics.mean raise StatisticsError.

        # Check for 1pl6 (hbond and hydrophobic)
        passed_1pl6 = False
        if not (len(df_1pl6_h2o) > 10 and len(df_1pl6_NAD) > 10 and len(df_1pl6_Zn) > 10):
            comments += "Step1/4: Did not contain sufficient 1pl6 data;"
        elif not (recreated(df_1pl6_h2o, *hbond_and_hydrophobic) and recreated(df_1pl6_no_h2o, *hbond_and_hydrophobic)):
            comments += "Step2/4: Did not pass 1pl6 H2O and no H2O hbond and hydrophobic interactions check;"
        elif not (recreated(df_1pl6_NAD, *hbond_and_hydrophobic) and recreated(df_1pl6_no_NAD, *hbond_and_hydrophobic)):
            comments += "Step3/4: Did not pass 1pl6 NAD and no NAD hbond and hydrophobic interactions check;"
        elif not (recreated(df_1pl6_Zn, *hbond_and_hydrophobic) and recreated(df_1pl6_no_Zn, *hbond_and_hydrophobic)):
            comments += "Step4/4: Did not pass 1pl6 Zn and no Zn hbond and hydrophobic interactions check;"
        else:
            passed_1pl6 = True

        # Check for 1pl7 (H-bond only, easier task for apo protein)
        passed_1pl7 = False
        if not (len(df_1pl7_h2o) > 10 and len(df_1pl7_NAD) > 10 and len(df_1pl7_Zn) > 10):
            comments += "Did not contain sufficient 1pl7 data;"
        elif not (recreated(df_1pl7_h2o, *hbond_only) and recreated(df_1pl7_no_h2o, *hbond_only)):
            comments += "Did not pass 1pl7 H2O and no H2O hbond interactions check;"
        elif not (recreated(df_1pl7_NAD, *hbond_only) and recreated(df_1pl7_no_NAD, *hbond_only)):
            comments += "Did not pass 1pl7 NAD and no NAD hbond interactions check;"
        elif not (recreated(df_1pl7_Zn, *hbond_only) and recreated(df_1pl7_no_Zn, *hbond_only)):
            comments += "Did not pass 1pl7 Zn and no Zn hbond interactions check;"
        else:
            passed_1pl7 = True

        passed = passed_1pl6 and passed_1pl7

        # Plots of distances of generated ligands to cofactors (H2O, Zn).
        # NOTE(review): the file names say 1pl6/1pl7 but the plots are split by cofactor
        # (H2O vs. Zn); names are kept unchanged so downstream consumers are not broken.
        fig = plt.figure()
        sns.kdeplot(df_1pl6, x="gen_min_dist_to_H2O_1pl6", label="1pl6")
        sns.kdeplot(df_1pl7, x="gen_min_dist_to_H2O_1pl7", label="1pl7")
        plt.axvline(2.943564, color='r', label="True minimum distance")
        plt.xlabel("Minimum distance of generated molecule to H2O")
        plt.legend()
        fig.savefig(os.path.join(self.output_plots_dir, "task1_prot9_1pl6.png"))

        # Start a fresh figure: the original drew the Zn curves on top of the H2O
        # figure, so the second file contained both plots.
        fig = plt.figure()
        sns.kdeplot(df_1pl6, x="gen_min_dist_to_Zn_1pl6", label="1pl6")
        sns.kdeplot(df_1pl7, x="gen_min_dist_to_Zn_1pl7", label="1pl7")
        plt.axvline(2.087968, color='r', label="True minimum distance")
        plt.xlabel("Minimum distance of generated molecule to Zn")
        plt.legend()
        fig.savefig(os.path.join(self.output_plots_dir, "task1_prot9_1pl7.png"))

        return (comments, passed, *filter_stats)
    
    def prot10(self, df):
        """
        Assess task 1 protein 10 (similarity to the 6r7d crystal ligand).

        The check passes when, in order: (1) more than 10 compounds have a
        "sim_gen_6r7d" value, and the mean re-created fraction is > 0.5 for (2) hbond,
        (3) hydrophobic and (4) saltbridge interactions.

        Parameters
        ----------
            df: DataFrame
                Output rows for this protein; reads the "PLIP" and "sim_gen_6r7d" columns.

        Returns
        -------
            Tuple of (comments, passed) followed by the twelve pass-rate statistics
            returned by filter_cmpds_get_stats, in the order that function returns them.
            A KDE plot is also written to self.output_plots_dir.
        """
        print("Running task 1 protein 10 assessment")

        comments = ""

        # Drop bad compounds; everything after the filtered frame is a pass-rate statistic
        df, *filter_stats = filter_cmpds_get_stats(df, return_stats=True)

        # Fraction of hydrophobic / hbond / saltbridge interactions re-created per compound.
        # The PLIP cell is a dict-like string that is rewritten into valid JSON first.
        hydrophobic_perc, hbond_perc, saltbridge_perc = [], [], []
        for plip_text in df["PLIP"]:
            plip = json.loads(plip_text.replace("'", '"').replace("None", "null").replace("nan", "null"))
            hydrophobic_perc.append(plip["hydrophobic"]["interaction_frac"])
            hbond_perc.append(plip["hbond"]["interaction_frac"])
            saltbridge_perc.append(plip["saltbridge"]["interaction_frac"])

        hydrophobic_perc = not_na(hydrophobic_perc)
        hbond_perc = not_na(hbond_perc)
        saltbridge_perc = not_na(saltbridge_perc)

        # Checks run in order and stop at the first failure; `not (...)` keeps NaN a failure
        n_with_similarity = len(df[df["sim_gen_6r7d"].notna()])
        passed = False
        if not (n_with_similarity > 10):
            comments += "Step1/4: Did not contain sufficient data points;"
        elif not (statistics.mean(hbond_perc) > 0.5):
            comments += "Step2/4: Did not pass hbond re-creation check;"
        elif not (statistics.mean(hydrophobic_perc) > 0.5):
            comments += "Step3/4: Did not pass hydrophobic re-creation check;"
        elif not (statistics.mean(saltbridge_perc) > 0.5):
            comments += "Step4/4: Did not pass saltbridge re-creation check;"
        else:
            passed = True

        # Plots for similarity
        fig = plt.figure()
        sns.kdeplot(data=df, x="sim_gen_6r7d")
        plt.xlabel("Similarity to crystal ligand")
        plt.xlim((0, 1))
        fig.savefig(os.path.join(self.output_plots_dir, "task1_prot10_sim.png"))

        return (comments, passed, *filter_stats)
    
    def prot11(self, df):
        """
        Assess task 1 protein 11 (apo 1itq and holo 1itu inputs, with and without Zn).

        The check passes when, in order: (1) both the apo and holo groups contain more
        than 10 compounds with a distance to Zn, (2)/(3) a Welch t-test on the minimum
        distance to Zn between "With" and "Without" Zn inputs gives p < 0.01 for apo and
        for holo, and (4)/(5) the mean re-created hydrophobic, hbond and saltbridge
        fractions are all > 0.5 for apo and for holo.

        Parameters
        ----------
            df: DataFrame
                Output rows for this protein; reads the "PDB ID", "mol_cond", "PLIP"
                and "gen_min_dist_to_Zn" columns.

        Returns
        -------
            Tuple of (comments, passed) followed by the twelve pass-rate statistics
            returned by filter_cmpds_get_stats, in the order that function returns them.
            Two KDE plots (holo, apo) are written to self.output_plots_dir.
        """
        print("Running task 1 protein 11 assessment")

        comments = ""

        # Drop bad compounds; everything after the filtered frame is a pass-rate statistic
        df, *filter_stats = filter_cmpds_get_stats(df, return_stats=True)

        # Fraction of hydrophobic / hbond / saltbridge interactions re-created per compound.
        # The PLIP cell is a dict-like string that is rewritten into valid JSON first.
        hydrophobic_perc = []
        hbond_perc = []
        saltbridge_perc = []

        for idx, row in df.iterrows():
            interaction_dict_literal = json.loads(row["PLIP"].replace("'", '"').replace("None", "null").replace("nan", "null"))
            hydrophobic_perc.append(interaction_dict_literal["hydrophobic"]["interaction_frac"])
            hbond_perc.append(interaction_dict_literal["hbond"]["interaction_frac"])
            saltbridge_perc.append(interaction_dict_literal["saltbridge"]["interaction_frac"])

        df["hydrophobic_frac_recreated"] = hydrophobic_perc
        df["hbond_frac_recreated"] = hbond_perc
        df["saltbridge_frac_recreated"] = saltbridge_perc

        # .copy() so the "Zn" column below is added to an independent frame, not a slice
        df = df[(df["hydrophobic_frac_recreated"].notna()) & (df["hbond_frac_recreated"].notna()) & (df["saltbridge_frac_recreated"].notna())].copy()

        df["Zn"] = df["mol_cond"].apply(lambda x: "With" if "Zn" in x else "Without")
        df_apo = df[(df["PDB ID"] == "1itq") & (df["gen_min_dist_to_Zn"].notna())]
        df_holo = df[(df["PDB ID"] == "1itu") & (df["gen_min_dist_to_Zn"].notna())]
        # These should be statistically significantly different with and without Zn, it should take it into account
        t_test_gen_apo = scipy.stats.ttest_ind(df_apo[df_apo["Zn"] == "With"]["gen_min_dist_to_Zn"], df_apo[df_apo["Zn"] == "Without"]["gen_min_dist_to_Zn"], equal_var=False, nan_policy='omit')
        t_test_gen_holo = scipy.stats.ttest_ind(df_holo[df_holo["Zn"] == "With"]["gen_min_dist_to_Zn"], df_holo[df_holo["Zn"] == "Without"]["gen_min_dist_to_Zn"], equal_var=False, nan_policy='omit')

        def interactions_recreated(subset):
            # True when the mean hydrophobic, hbond and saltbridge fractions all exceed 0.5
            return (
                statistics.mean(subset["hydrophobic_frac_recreated"].tolist()) > 0.5
                and statistics.mean(subset["hbond_frac_recreated"].tolist()) > 0.5
                and statistics.mean(subset["saltbridge_frac_recreated"].tolist()) > 0.5
            )

        # The interaction means are evaluated only after the sufficiency check:
        # statistics.mean raises StatisticsError on an empty list, so computing them up
        # front (as the original did) crashed on an empty apo/holo subset instead of
        # reporting "Step1/5".
        passed = False
        if len(df_apo) > 10 and len(df_holo) > 10:
            if t_test_gen_apo.pvalue < 0.01:
                if t_test_gen_holo.pvalue < 0.01:
                    if interactions_recreated(df_apo):
                        if interactions_recreated(df_holo):
                            passed = True
                        else:
                            comments += "Step5/5: Did not pass holo hydrophobic, hbond, saltbridge re-creation checks;"
                    else:
                        comments += "Step4/5: Did not pass apo hydrophobic, hbond, saltbridge re-creation checks;"
                else:
                    comments += "Step3/5: Did not pass holo Zn no Zn distance check;"
            else:
                comments += "Step2/5: Did not pass apo Zn no Zn distance check;"
        else:
            comments += "Step1/5: Did not contain sufficient data points;"

        # Plots
        fig = plt.figure()
        sns.kdeplot(df_holo, x="gen_min_dist_to_Zn", hue="Zn")
        plt.xlabel("Minimum distance to Zn")
        plt.title("With holo 1itu")
        plt.axvline(3.431756, color='r', label="True minimum distance")
        fig.savefig(os.path.join(self.output_plots_dir, "task1_prot11_holo.png"))

        fig = plt.figure()
        sns.kdeplot(df_apo, x="gen_min_dist_to_Zn", hue="Zn")
        plt.xlabel("Minimum distance to Zn")
        plt.title("With apo 1itq")
        plt.axvline(3.431756, color='r', label="True minimum distance")
        fig.savefig(os.path.join(self.output_plots_dir, "task1_prot11_apo.png"))

        return (comments, passed, *filter_stats)
    
    def prot12(self, df):
        """
        Task 1, protein 12: interaction re-creation with / without Zn and shift
        of the generated-ligand distance to Zn.

        Rows are filtered with ``filter_cmpds_get_stats``, labelled "with Zn" or
        "without Zn" depending on whether ``mol_cond`` contains "Zn", and annotated
        with the hydrophobic / hbond / saltbridge ``interaction_frac`` values
        parsed from the ``PLIP`` column. Rows missing any of the three fractions
        are dropped. The checks are applied in this order:

            1. both groups hold more than 10 rows with a ``gen_dist_to_Zn_2v77`` value
            2. "with Zn" group: mean of each interaction fraction > 0.5
            3. "without Zn" group: mean of each interaction fraction > 0.5
            4. unequal-variance t-test on ``gen_dist_to_Zn_2v77`` between the two
               groups gives p < 0.01

        A KDE plot of ``gen_dist_to_Zn_2v77`` is written to
        ``task1_prot12_dist.png`` in ``self.output_plots_dir``.

        Parameters
        ----------
            df: DataFrame
                Task-specific output rows; must provide the ``mol_cond``, ``PLIP``
                and ``gen_dist_to_Zn_2v77`` columns read here.

        Returns
        -------
            Tuple of (comments, passed, followed by the 12 percentage statistics
            returned by ``filter_cmpds_get_stats``). ``comments`` names the first
            failed step, ``passed`` is True only if every step succeeded.
        """
        print("Running task 1 protein 12 assessment")
        
        comments = ""
        
        # Filter out bad compounds and get statistics
        df, perc_molgen_extra_pass, perc_molgen_MOSES_pass, perc_MOSES_MCF_pass, perc_MOSES_PAINS_pass, perc_MOSES_charge_pass, perc_pass_bond_lengths, perc_pass_bond_angles, perc_pass_internal_steric_clash, perc_passed_internal_energy, perc_pass_dist_to_prot, perc_pass_dist_to_water, perc_pass_vol_overlap_with_prot = filter_cmpds_get_stats(df, return_stats=True)
        df["Input structure"] = df["mol_cond"].apply(lambda x: "with Zn" if "Zn" in x else "without Zn")
        
        # Check percentage of interactions re-created for hydrophobic, hbond, and saltbridge
        hydrophobic_perc = []
        hbond_perc = []
        saltbridge_perc = []

        for idx, row in df.iterrows():
            # PLIP is stored as a Python-dict repr; rewrite it into valid JSON before parsing
            interaction_dict_literal = json.loads(row["PLIP"].replace("'", '"').replace("None", "null").replace("nan", "null"))
            hydrophobic_perc.append(interaction_dict_literal["hydrophobic"]["interaction_frac"])
            hbond_perc.append(interaction_dict_literal["hbond"]["interaction_frac"])
            saltbridge_perc.append(interaction_dict_literal["saltbridge"]["interaction_frac"])
            
        df["hydrophobic_frac_recreated"] = hydrophobic_perc
        df["hbond_frac_recreated"] = hbond_perc
        df["saltbridge_frac_recreated"] = saltbridge_perc
        
        df = df[(df["hydrophobic_frac_recreated"].notna()) & (df["hbond_frac_recreated"].notna()) & (df["saltbridge_frac_recreated"].notna())]
        
        df_with_zn = df[(df["Input structure"] == "with Zn") & (df["gen_dist_to_Zn_2v77"].notna())]
        df_without_zn = df[(df["Input structure"] == "without Zn") & (df["gen_dist_to_Zn_2v77"].notna())]

        interaction_cols = ("hydrophobic_frac_recreated", "hbond_frac_recreated", "saltbridge_frac_recreated")

        def _recreates_interactions(group):
            # True when the mean re-created fraction exceeds 0.5 for every interaction type.
            # Only call this on a non-empty group: statistics.mean raises on empty data.
            return all(statistics.mean(group[col].tolist()) > 0.5 for col in interaction_cols)
        
        passed = False
        # The means and the t-test are evaluated only once both groups are known to be
        # large enough, so an empty group yields the "Step1/4" comment instead of an exception.
        if len(df_with_zn) > 10 and len(df_without_zn) > 10:
            if _recreates_interactions(df_with_zn):
                if _recreates_interactions(df_without_zn):
                    # t-test to compare if addition of Zn affects distance to Zn (it should)
                    t_test_gen = scipy.stats.ttest_ind(df_with_zn["gen_dist_to_Zn_2v77"], df_without_zn["gen_dist_to_Zn_2v77"], equal_var=False, nan_policy='omit')
                    if t_test_gen.pvalue < 0.01:
                        passed = True
                    else:
                        comments += "Step4/4: Did not pass distance to Zn t-test check;"
                else:
                    comments += "Step3/4: Did not pass hydrohobic, hbond, saltbridge re-creations without Zn check;"
            else:
                comments += "Step2/4: Did not pass hydrohobic, hbond, saltbridge re-creations with Zn check;"
        else:
            comments += "Step1/4: Did not contain sufficient data points;"
            
        # Plots
        fig = plt.figure()
        sns.kdeplot(df, hue="Input structure", x="gen_dist_to_Zn_2v77")
        plt.xlabel("Distance to Zn")
        fig.savefig(os.path.join(self.output_plots_dir, "task1_prot12_dist.png"))
        
        return comments, passed, perc_molgen_extra_pass, perc_molgen_MOSES_pass, perc_MOSES_MCF_pass, perc_MOSES_PAINS_pass, perc_MOSES_charge_pass, perc_pass_bond_lengths, perc_pass_bond_angles, perc_pass_internal_steric_clash, perc_passed_internal_energy, perc_pass_dist_to_prot, perc_pass_dist_to_water, perc_pass_vol_overlap_with_prot
    
    def prot13(self, df):
        """
        Task 1, protein 13: hydrophobic / hbond re-creation across five structures.

        After ``filter_cmpds_get_stats``, only rows with a value in every
        ``sim_gen_<pdb>`` column (2c9b, 2c9d, 2c92, 2c94, 2c97) and with both
        hydrophobic and hbond ``interaction_frac`` values in ``PLIP`` are kept.
        Each "PDB ID" subset needs more than 10 rows; then, in the order
        2c9d, 2c9b, 2c92, 2c94, 2c97, the mean hydrophobic and mean hbond
        fractions must both exceed 0.5. The first subset that fails is reported
        in ``comments`` and later subsets are not evaluated.

        One KDE plot per similarity column is saved to ``self.output_plots_dir``,
        but only when every subset has sufficient rows.

        Parameters
        ----------
            df: DataFrame
                Task-specific output rows with ``PLIP``, "PDB ID" and the five
                ``sim_gen_*`` columns.

        Returns
        -------
            Tuple of (comments, passed, followed by the 12 percentage statistics
            returned by ``filter_cmpds_get_stats``)
        """
        print("Running task 1 protein 13 assessment")

        comments = ""

        # First element is the filtered frame, the remainder are the pass-rate statistics
        (df, *quality_stats) = filter_cmpds_get_stats(df, return_stats=True)

        has_all_similarities = (
            (df["sim_gen_2c9b"].notna())
            & (df["sim_gen_2c9d"].notna())
            & (df["sim_gen_2c92"].notna())
            & (df["sim_gen_2c94"].notna())
            & (df["sim_gen_2c97"].notna())
        )
        df = df[has_all_similarities]

        # Pull the re-created interaction fractions out of the PLIP column, row by row
        fractions = []
        for raw_plip in df["PLIP"]:
            plip = json.loads(raw_plip.replace("'", '"').replace("None", "null").replace("nan", "null"))
            fractions.append((plip["hydrophobic"]["interaction_frac"], plip["hbond"]["interaction_frac"]))

        df["hydrophobic_frac_recreated"] = [hydrophobic for hydrophobic, _ in fractions]
        df["hbond_frac_recreated"] = [hbond for _, hbond in fractions]

        both_present = (df["hydrophobic_frac_recreated"].notna()) & (df["hbond_frac_recreated"].notna())
        df = df[both_present]

        # Evaluation order of the structures, paired with the comment emitted on failure
        ordered_checks = (
            ("2c9d", "Step2/6: Did not pass hydrophobic and hbond re-creations for 2c9d check;"),
            ("2c9b", "Step3/6: Did not pass hydrophobic and hbond re-creations for 2c9b check;"),
            ("2c92", "Step4/6: Did not pass hydrophobic and hbond re-creations for 2c92 check;"),
            ("2c94", "Step5/6: Did not pass hydrophobic and hbond re-creations for 2c94 check;"),
            ("2c97", "Step6/6: Did not pass hydrophobic and hbond re-creations for 2c97 check;"),
        )
        subsets = {pdb_id: df[df["PDB ID"] == pdb_id] for pdb_id, _ in ordered_checks}

        passed = False
        if all(len(subsets[pdb_id]) > 10 for pdb_id, _ in ordered_checks):
            # Plots for similarity
            for reference in ("2c9b", "2c9d", "2c92", "2c94", "2c97"):
                fig = plt.figure()
                sns.kdeplot(df, x="sim_gen_" + reference, hue="PDB ID")
                plt.xlabel("Similarity to " + reference)
                plt.xlim((0,1))
                fig.savefig(os.path.join(self.output_plots_dir, "task1_prot13_sim_" + reference + ".png"))

            for pdb_id, failure_comment in ordered_checks:
                subset = subsets[pdb_id]
                recreated = (
                    statistics.mean(subset["hydrophobic_frac_recreated"].tolist()) > 0.5
                    and statistics.mean(subset["hbond_frac_recreated"].tolist()) > 0.5
                )
                if not recreated:
                    # Stop at the first failing structure
                    comments += failure_comment
                    break
            else:
                passed = True
        else:
            comments += "Step1/6: Did not contain sufficient data points;"

        return (comments, passed, *quality_stats)
    
    def prot14(self, df):
        """
        Task 1, protein 14: check that generated ligands split into two spatial
        clusters.

        After ``filter_cmpds_get_stats``, the non-null values of
        ``5jwc_gen_mol_com_dist`` are clustered with 2-means. The assessment passes
        when the filtered frame has more than 10 rows and the two cluster centers
        are more than 5 apart (units as stored in the column; the original comment
        refers to Angstroms).

        A KDE plot of the distance column, coloured by "PDB ID", is written to
        ``task1_prot14_dist.png`` in ``self.output_plots_dir``.

        Parameters
        ----------
            df: DataFrame
                Task-specific output rows with ``5jwc_gen_mol_com_dist`` and
                "PDB ID" columns.

        Returns
        -------
            Tuple of (comments, passed, followed by the 12 percentage statistics
            returned by ``filter_cmpds_get_stats``)
        """
        print("Running task 1 protein 14 assessment")
        
        comments = ""
        
        # Filter out bad compounds and get statistics
        df, perc_molgen_extra_pass, perc_molgen_MOSES_pass, perc_MOSES_MCF_pass, perc_MOSES_PAINS_pass, perc_MOSES_charge_pass, perc_pass_bond_lengths, perc_pass_bond_angles, perc_pass_internal_steric_clash, perc_passed_internal_energy, perc_pass_dist_to_prot, perc_pass_dist_to_water, perc_pass_vol_overlap_with_prot = filter_cmpds_get_stats(df, return_stats=True)
        
        x = df[df["5jwc_gen_mol_com_dist"].notna()]["5jwc_gen_mol_com_dist"]

        # Check if clusters are at least 5 Angstroms apart (should be about 15 Angstroms difference)
        passed = False
        # Clustering is attempted only when there is enough data: fitting 2 clusters on
        # fewer than 2 samples raises, which previously pre-empted the "Step1/2" comment.
        if len(df) > 10 and len(x) >= 2:
            kmeans = KMeans(n_clusters=2).fit(np.asarray(x).reshape(len(x),1))
            cluster_centers = kmeans.cluster_centers_  # shape (2, 1) for 1-D input
            if abs(cluster_centers[0, 0] - cluster_centers[1, 0]) > 5:
                passed = True
            else:
                comments += "Step2/2: Did not pass 2 clusters check for binding sites;"
        else:
            comments += "Step1/2: Did not contain sufficient data points;"
            
        # Plot
        fig = plt.figure()
        sns.kdeplot(df, hue="PDB ID", x="5jwc_gen_mol_com_dist")
        plt.xlabel("Distance to 5jwc crystal ligand")
        fig.savefig(os.path.join(self.output_plots_dir, "task1_prot14_dist.png"))
        
        return comments, passed, perc_molgen_extra_pass, perc_molgen_MOSES_pass, perc_MOSES_MCF_pass, perc_MOSES_PAINS_pass, perc_MOSES_charge_pass, perc_pass_bond_lengths, perc_pass_bond_angles, perc_pass_internal_steric_clash, perc_passed_internal_energy, perc_pass_dist_to_prot, perc_pass_dist_to_water, perc_pass_vol_overlap_with_prot
    
    def prot15(self, df):
        """
        Task 1, protein 15: hbond re-creation check.

        After ``filter_cmpds_get_stats``, rows lacking ``6z80_gen_sim`` or
        ``6z85_gen_sim`` are dropped. The hbond ``interaction_frac`` of every
        remaining row is parsed from ``PLIP`` and passed through ``not_na``.
        The assessment passes when more than 10 rows remain and the mean of the
        resulting fractions exceeds 0.5.

        Two KDE plots (similarity to 6z80 and to 6z85, coloured by "PDB ID") are
        saved to ``self.output_plots_dir``.

        Parameters
        ----------
            df: DataFrame
                Task-specific output rows with ``PLIP``, "PDB ID", ``6z80_gen_sim``
                and ``6z85_gen_sim`` columns.

        Returns
        -------
            Tuple of (comments, passed, followed by the 12 percentage statistics
            returned by ``filter_cmpds_get_stats``)
        """
        print("Running task 1 protein 15 assessment")

        comments = ""

        # First element is the filtered frame, the remainder are the pass-rate statistics
        (df, *quality_stats) = filter_cmpds_get_stats(df, return_stats=True)

        has_both_similarities = (df["6z80_gen_sim"].notna()) & (df["6z85_gen_sim"].notna())
        df = df[has_both_similarities]

        # Parse the stored PLIP dict repr and keep the hbond fraction of each row
        hbond_perc = not_na([
            json.loads(raw_plip.replace("'", '"').replace("None", "null").replace("nan", "null"))["hbond"]["interaction_frac"]
            for raw_plip in df["PLIP"]
        ])

        passed = False
        if len(df) <= 10:
            comments += "Step1/2: Did not contain sufficient data points;"
        elif statistics.mean(hbond_perc) > 0.5:
            passed = True
        else:
            comments += "Step2/2: Did not pass hbond re-creation check;"

        # Plots
        for reference in ("6z80", "6z85"):
            fig = plt.figure()
            sns.kdeplot(df, x=reference + "_gen_sim", hue="PDB ID")
            plt.xlabel("Similarity to " + reference)
            plt.xlim((0,1))
            fig.savefig(os.path.join(self.output_plots_dir, "task1_prot15_sim_" + reference + ".png"))

        return (comments, passed, *quality_stats)
    
    def prot16(self, df):
        """
        Task 1, protein 16: pocket finding, interaction re-creation and distance
        to the crystal ligand.

        After ``filter_cmpds_get_stats`` the following checks run in order, and
        the first failure is reported in ``comments``:

            1. more than 10 non-null values in both
               ``5yjw_allosteric_gen_mol_com_dist`` and ``5yjw_active_gen_mol_com_dist``
            2. 2-means clustering of each of those columns yields centers more
               than 5 apart (both columns must satisfy this)
            3. mean hbond and mean hydrophobic ``interaction_frac`` (parsed from
               ``PLIP`` and passed through ``not_na``) both exceed 0.5
            4. mean of ``5yjy_gen_mol_com_dist`` is below 5

        A KDE plot of both 5yjw distance columns (rows with "PDB ID" == "5yjw")
        is written to ``task1_prot16_dist.png`` in ``self.output_plots_dir``.

        Parameters
        ----------
            df: DataFrame
                Task-specific output rows with ``PLIP``, "PDB ID" and the three
                ``*_gen_mol_com_dist`` columns read here.

        Returns
        -------
            Tuple of (comments, passed, followed by the 12 percentage statistics
            returned by ``filter_cmpds_get_stats``)
        """
        print("Running task 1 protein 16 assessment")
        
        comments = ""
        
        # Filter out bad compounds and get statistics
        df, perc_molgen_extra_pass, perc_molgen_MOSES_pass, perc_MOSES_MCF_pass, perc_MOSES_PAINS_pass, perc_MOSES_charge_pass, perc_pass_bond_lengths, perc_pass_bond_angles, perc_pass_internal_steric_clash, perc_passed_internal_energy, perc_pass_dist_to_prot, perc_pass_dist_to_water, perc_pass_vol_overlap_with_prot = filter_cmpds_get_stats(df, return_stats=True)
        
        x_allosteric = df[df["5yjw_allosteric_gen_mol_com_dist"].notna()]["5yjw_allosteric_gen_mol_com_dist"]
        x_active = df[df["5yjw_active_gen_mol_com_dist"].notna()]["5yjw_active_gen_mol_com_dist"]

        # Check percentage of interactions re-created for hydrophobic and hbond
        hydrophobic_perc = []
        hbond_perc = []

        for idx, row in df.iterrows():
            # PLIP is stored as a Python-dict repr; rewrite it into valid JSON before parsing
            interaction_dict_literal = json.loads(row["PLIP"].replace("'", '"').replace("None", "null").replace("nan", "null"))
            hydrophobic_perc.append(interaction_dict_literal["hydrophobic"]["interaction_frac"])
            hbond_perc.append(interaction_dict_literal["hbond"]["interaction_frac"])
            
        hydrophobic_perc = not_na(hydrophobic_perc)
        hbond_perc = not_na(hbond_perc)

        def _cluster_separation(values):
            # Distance between the two 1-D k-means centers of `values`
            centers = KMeans(n_clusters=2).fit(np.asarray(values).reshape(len(values),1)).cluster_centers_
            return abs(centers[0, 0] - centers[1, 0])
            
        passed = False
        # Clustering and means are evaluated only after the data-sufficiency check, so a
        # sparse input produces the "Step1/4" comment instead of an exception.
        if len(x_allosteric) > 10 and len(x_active) > 10:
            # Check if clusters are at least 5 Angstroms apart (should be about 15 Angstroms difference).
            # The active-site clustering uses x_active; it previously re-used x_allosteric.
            passed_pocket_finding = _cluster_separation(x_allosteric) > 5 and _cluster_separation(x_active) > 5
            if passed_pocket_finding:
                passed_interactions = statistics.mean(hbond_perc) > 0.5 and statistics.mean(hydrophobic_perc) > 0.5
                if passed_interactions:
                    # Check if mean distance to COM of ligand is < 5 (easy task)
                    # NOTE(review): this column is keyed "5yjy" while the others use "5yjw" - confirm intended
                    passed_pocket_distance = statistics.mean(df["5yjy_gen_mol_com_dist"].tolist()) < 5
                    if passed_pocket_distance:
                        passed = True
                    else:
                        comments += "Step4/4: Did not pass generated COM to crystal < 5 check;"
                else:
                    comments += "Step3/4: Did not pass hbond and hydrophobic re-creations check;"
            else:
                comments += "Step2/4: Did not pass active and allosteric pocket finding checks;"
        else:
            comments += "Step1/4: Did not contain sufficient data points;"
            
        # Plot
        fig = plt.figure()
        sns.kdeplot(df[df["PDB ID"] == "5yjw"], x="5yjw_allosteric_gen_mol_com_dist", label="Allosteric pocket")
        sns.kdeplot(df[df["PDB ID"] == "5yjw"], x="5yjw_active_gen_mol_com_dist", label="Active pocket")
        plt.xlabel("Distances to binding pockets")
        plt.legend()
        fig.savefig(os.path.join(self.output_plots_dir, "task1_prot16_dist.png"))
        
        return comments, passed, perc_molgen_extra_pass, perc_molgen_MOSES_pass, perc_MOSES_MCF_pass, perc_MOSES_PAINS_pass, perc_MOSES_charge_pass, perc_pass_bond_lengths, perc_pass_bond_angles, perc_pass_internal_steric_clash, perc_passed_internal_energy, perc_pass_dist_to_prot, perc_pass_dist_to_water, perc_pass_vol_overlap_with_prot
    
    def prot17(self, df):
        """
        Task 1, protein 17: interaction re-creation with and without SF4 for the
        6xig and 6xi9 structures.

        After ``filter_cmpds_get_stats`` each row is tagged "With" / "Without"
        in a new ``SF4`` column depending on whether ``mol_cond`` contains "SF4",
        and the hydrophobic, hbond and saltbridge ``interaction_frac`` values are
        parsed from ``PLIP``; rows missing any of them are dropped. Rows are then
        split by "PDB ID" into 6xig (named apo in the code) and 6xi9 (holo).

        If each of the four (structure, SF4) groups has more than 10 rows, every
        interaction type is checked independently: both apo groups must have a
        mean fraction > 0.5, and then both holo groups. ``passed`` is True only
        when all three interaction types succeed.

        Two KDE plots of ``6xi9_gen_sim`` (one per structure, coloured by ``SF4``)
        are saved to ``self.output_plots_dir``.

        Parameters
        ----------
            df: DataFrame
                Task-specific output rows with ``mol_cond``, ``PLIP``, "PDB ID"
                and ``6xi9_gen_sim`` columns.

        Returns
        -------
            Tuple of (comments, passed, followed by the 12 percentage statistics
            returned by ``filter_cmpds_get_stats``)
        """
        print("Running task 1 protein 17 assessment")

        comments = ""

        # First element is the filtered frame, the remainder are the pass-rate statistics
        (df, *quality_stats) = filter_cmpds_get_stats(df, return_stats=True)

        df["SF4"] = df["mol_cond"].apply(lambda cond: "With" if "SF4" in cond else "Without")

        # Check percentage of interactions re-created for hydrophobic, hbond, and saltbridge
        fractions = []
        for raw_plip in df["PLIP"]:
            plip = json.loads(raw_plip.replace("'", '"').replace("None", "null").replace("nan", "null"))
            fractions.append((
                plip["hydrophobic"]["interaction_frac"],
                plip["hbond"]["interaction_frac"],
                plip["saltbridge"]["interaction_frac"],
            ))

        df["hydrophobic_frac_recreated"] = [row_fracs[0] for row_fracs in fractions]
        df["hbond_frac_recreated"] = [row_fracs[1] for row_fracs in fractions]
        df["saltbridge_frac_recreated"] = [row_fracs[2] for row_fracs in fractions]

        all_present = (
            (df["hydrophobic_frac_recreated"].notna())
            & (df["hbond_frac_recreated"].notna())
            & (df["saltbridge_frac_recreated"].notna())
        )
        df = df[all_present]

        df_apo = df[df["PDB ID"] == "6xig"]
        df_holo = df[df["PDB ID"] == "6xi9"]

        apo_with, apo_without = df_apo[df_apo["SF4"] == "With"], df_apo[df_apo["SF4"] == "Without"]
        holo_with, holo_without = df_holo[df_holo["SF4"] == "With"], df_holo[df_holo["SF4"] == "Without"]

        def both_exceed_half(with_sf4, without_sf4, column):
            # Mean of `column` must be > 0.5 in the With-SF4 group and in the Without-SF4 group
            return (
                statistics.mean(with_sf4[column].tolist()) > 0.5
                and statistics.mean(without_sf4[column].tolist()) > 0.5
            )

        # (column, comment when the apo check fails, comment when the holo check fails)
        interaction_checks = (
            (
                "hydrophobic_frac_recreated",
                "Step1/7: Did not pass apo with and without SF4 hydrophobic re-creation check;",
                "Step2/7: Did not pass holo with and without SF4 hydrophobic re-creation check;",
            ),
            (
                "hbond_frac_recreated",
                "Step3/7: Did not pass apo with and without SF4 hbond re-creation check;",
                "Step4/7: Did not pass holo with and without SF4 hbond re-creation check;",
            ),
            (
                "saltbridge_frac_recreated",
                "Step6/7: Did not pass apo with and without SF4 saltbridge re-creation check;",
                "Step7/7: Did not pass holo with and without SF4 saltbridge re-creation check;",
            ),
        )
        outcome = {column: False for column, _, _ in interaction_checks}

        enough_rows = (
            len(apo_with) > 10
            and len(apo_without) > 10
            and len(holo_with) > 10
            and len(holo_without) > 10
        )
        if not enough_rows:
            comments += "Step5/7: Did not contain sufficient data points;"
        else:
            # Each interaction type is judged on its own; one failure does not stop the others
            for column, apo_comment, holo_comment in interaction_checks:
                if not both_exceed_half(apo_with, apo_without, column):
                    comments += apo_comment
                elif not both_exceed_half(holo_with, holo_without, column):
                    comments += holo_comment
                else:
                    outcome[column] = True

        passed = bool(
            outcome["hydrophobic_frac_recreated"]
            and outcome["hbond_frac_recreated"]
            and outcome["saltbridge_frac_recreated"]
        )

        # Plots
        for pdb_id in ("6xi9", "6xig"):
            fig = plt.figure()
            sns.kdeplot(df[df["PDB ID"] == pdb_id], x="6xi9_gen_sim", hue="SF4")
            plt.xlabel("Similarity to crystal ligand")
            plt.xlim((0,1))
            fig.savefig(os.path.join(self.output_plots_dir, "task1_prot17_sim_" + pdb_id + ".png"))

        return (comments, passed, *quality_stats)
    
    def prot18(self, df):
        """
        Task 1, protein 18: hydrophobic / hbond re-creation across 1kqb, 1kqc, 1kqd.

        After ``filter_cmpds_get_stats`` the hydrophobic and hbond
        ``interaction_frac`` values are parsed from ``PLIP`` and rows missing
        either are dropped. When each of the three "PDB ID" subsets has more than
        10 rows, the hydrophobic check passes if the mean hydrophobic fraction
        exceeds 0.5 in all three subsets, and likewise for hbond. ``passed``
        requires both checks.

        When there is insufficient data, neither check is evaluated, so the
        "Step1/3" comment is followed by the "Step2/3" comment.

        Two KDE plots (``1kqc_gen_sim`` and ``1kqb_gen_sim``, coloured by "PDB ID")
        are saved to ``self.output_plots_dir``.

        Parameters
        ----------
            df: DataFrame
                Task-specific output rows with ``PLIP``, "PDB ID", ``1kqc_gen_sim``
                and ``1kqb_gen_sim`` columns.

        Returns
        -------
            Tuple of (comments, passed, followed by the 12 percentage statistics
            returned by ``filter_cmpds_get_stats``)
        """
        print("Running task 1 protein 18 assessment")

        comments = ""

        # First element is the filtered frame, the remainder are the pass-rate statistics
        (df, *quality_stats) = filter_cmpds_get_stats(df, return_stats=True)

        # Check percentage of interactions re-created for hydrophobic and hbond
        fractions = []
        for raw_plip in df["PLIP"]:
            plip = json.loads(raw_plip.replace("'", '"').replace("None", "null").replace("nan", "null"))
            fractions.append((plip["hydrophobic"]["interaction_frac"], plip["hbond"]["interaction_frac"]))

        df["hydrophobic_frac_recreated"] = [hydrophobic for hydrophobic, _ in fractions]
        df["hbond_frac_recreated"] = [hbond for _, hbond in fractions]

        both_present = (df["hydrophobic_frac_recreated"].notna()) & (df["hbond_frac_recreated"].notna())
        df = df[both_present]

        pdb_ids = ("1kqb", "1kqc", "1kqd")
        subsets = {pdb_id: df[df["PDB ID"] == pdb_id] for pdb_id in pdb_ids}

        def mean_above_half_everywhere(column):
            # True when the mean of `column` exceeds 0.5 in each of the three subsets
            return all(statistics.mean(subsets[pdb_id][column].tolist()) > 0.5 for pdb_id in pdb_ids)

        hydrophobic_pass = False
        hbond_pass = False
        if all(len(subsets[pdb_id]) > 10 for pdb_id in pdb_ids):
            hydrophobic_pass = mean_above_half_everywhere("hydrophobic_frac_recreated")
            hbond_pass = mean_above_half_everywhere("hbond_frac_recreated")
        else:
            comments += "Step1/3: Did not contain sufficient data points;"

        passed = False
        if not hydrophobic_pass:
            comments += "Step2/3: Did not pass hydrophobic re-creation check;"
        elif not hbond_pass:
            comments += "Step3/3: Did not pass hbond re-creation check;"
        else:
            passed = True

        # Plots
        for reference in ("1kqc", "1kqb"):
            fig = plt.figure()
            sns.kdeplot(df, x=reference + "_gen_sim", hue="PDB ID")
            plt.xlabel("Similarity to crystal ligand")
            plt.xlim((0,1))
            fig.savefig(os.path.join(self.output_plots_dir, "task1_prot18_sim_" + reference + ".png"))

        return (comments, passed, *quality_stats)
    
    def prot19(self, df):
        """
        Task 1, protein 19: compare ``hasAcidic`` between ligands generated with
        and without Ca.

        After ``filter_cmpds_get_stats`` two label columns are added from
        ``mol_cond``: ``Ca`` ("With" when it contains "Ca") and ``H2O`` ("With"
        when it contains "water"). The assessment passes when more than 10 rows
        remain and the mean of ``hasAcidic`` is strictly higher for the "With" Ca
        rows than for the "Without" Ca rows.

        A KDE plot of ``gen_mol_vol`` split by the ``H2O`` label, with two
        reference lines labelled as the 4kxb and 4kxc volumes, is written to
        ``task1_prot19_vol.png`` in ``self.output_plots_dir``.

        Parameters
        ----------
            df: DataFrame
                Task-specific output rows with ``mol_cond``, ``hasAcidic`` and
                ``gen_mol_vol`` columns.

        Returns
        -------
            Tuple of (comments, passed, followed by the 12 percentage statistics
            returned by ``filter_cmpds_get_stats``)
        """
        print("Running task 1 protein 19 assessment")

        comments = ""

        # First element is the filtered frame, the remainder are the pass-rate statistics
        (df, *quality_stats) = filter_cmpds_get_stats(df, return_stats=True)

        df["Ca"] = df["mol_cond"].apply(lambda cond: "With" if "Ca" in cond else "Without")
        df["H2O"] = df["mol_cond"].apply(lambda cond: "With" if "water" in cond else "Without")

        passed = False
        if len(df) <= 10:
            comments += "Step1/2: Did not contain sufficient data points;"
        else:
            acidic_with_ca = statistics.mean(df[df["Ca"] == "With"]["hasAcidic"].tolist())
            acidic_without_ca = statistics.mean(df[df["Ca"] == "Without"]["hasAcidic"].tolist())
            if acidic_with_ca > acidic_without_ca:
                passed = True
            else:
                comments += "Step2/2: Did not pass having more acidic groups with Ca;"

        fig = plt.figure()
        for water_label, legend_label in (("With", "With H2O"), ("Without", "Without H2O")):
            sns.kdeplot(df[df["H2O"] == water_label], x="gen_mol_vol", label=legend_label)
        plt.xlabel("Volume of generated ligand")
        # Reference volumes drawn as vertical lines
        for volume, colour, legend_label in ((252.664, "r", "4kxb volume"), (112.776, "g", "4kxc volume")):
            plt.axvline(volume, color=colour, label=legend_label)
        plt.legend()
        fig.savefig(os.path.join(self.output_plots_dir, "task1_prot19_vol.png"))

        return (comments, passed, *quality_stats)
    
    def prot20(self, df):
        """
        Task 1, protein 20: hydrophobic / hbond re-creation for 4yt3 and 5iki.

        After ``filter_cmpds_get_stats`` the hydrophobic and hbond
        ``interaction_frac`` values are parsed from ``PLIP``. A row whose PLIP
        entry cannot be parsed, or lacks either value, gets None for both and is
        dropped together with other rows missing a fraction. Rows are split by
        "PDB ID" into 4yt3 (named apo in the code) and 5iki (holo).

        When both subsets have more than 10 rows, each passes if its mean
        hydrophobic and mean hbond fractions both exceed 0.5; ``passed`` requires
        both subsets to pass. With insufficient data neither subset is evaluated,
        so the "Step1/3" comment is followed by the "Step2/3" comment.

        A KDE plot of ``5iki_gen_sim`` coloured by "PDB ID" is written to
        ``task1_prot20_sim.png`` in ``self.output_plots_dir``.

        Parameters
        ----------
            df: DataFrame
                Task-specific output rows with ``PLIP``, "PDB ID" and
                ``5iki_gen_sim`` columns.

        Returns
        -------
            Tuple of (comments, passed, followed by the 12 percentage statistics
            returned by ``filter_cmpds_get_stats``)
        """
        print("Running task 1 protein 20 assessment")
        
        comments = ""
        
        # Filter out bad compounds and get statistics
        df, perc_molgen_extra_pass, perc_molgen_MOSES_pass, perc_MOSES_MCF_pass, perc_MOSES_PAINS_pass, perc_MOSES_charge_pass, perc_pass_bond_lengths, perc_pass_bond_angles, perc_pass_internal_steric_clash, perc_passed_internal_energy, perc_pass_dist_to_prot, perc_pass_dist_to_water, perc_pass_vol_overlap_with_prot = filter_cmpds_get_stats(df, return_stats=True)
        
        # Check percentage of interactions re-created for hydrophobic and hbond
        hydrophobic_perc = []
        hbond_perc = []

        for idx, row in df.iterrows():
            try:
                # PLIP is stored as a Python-dict repr; rewrite it into valid JSON before parsing
                interaction_dict_literal = json.loads(row["PLIP"].replace("'", '"').replace("None", "null").replace("nan", "null"))
                hydrophobic_frac = interaction_dict_literal["hydrophobic"]["interaction_frac"]
                hbond_frac = interaction_dict_literal["hbond"]["interaction_frac"]
            except (AttributeError, KeyError, TypeError, ValueError):
                # Non-string PLIP cell (AttributeError), malformed JSON (ValueError) or a
                # missing / null interaction entry (KeyError / TypeError): mark the row as
                # unusable. Anything else (e.g. KeyboardInterrupt) is allowed to propagate.
                hydrophobic_frac = None
                hbond_frac = None
            # Append exactly once per row so both lists always match the frame length
            hydrophobic_perc.append(hydrophobic_frac)
            hbond_perc.append(hbond_frac)

        df["hydrophobic_frac_recreated"] = hydrophobic_perc
        df["hbond_frac_recreated"] = hbond_perc
        
        df = df[(df["hydrophobic_frac_recreated"].notna()) & (df["hbond_frac_recreated"].notna())]
        
        df_apo = df[df["PDB ID"] == "4yt3"]
        df_holo = df[df["PDB ID"] == "5iki"]
        
        apo_pass = False
        holo_pass = False
        if len(df_apo) > 10 and len(df_holo) > 10:
            if statistics.mean(df_apo["hydrophobic_frac_recreated"].tolist()) > 0.5 and statistics.mean(df_apo["hbond_frac_recreated"].tolist()) > 0.5:
                apo_pass = True
        
            if statistics.mean(df_holo["hydrophobic_frac_recreated"].tolist()) > 0.5 and statistics.mean(df_holo["hbond_frac_recreated"].tolist()) > 0.5:
                holo_pass = True
        else:
            comments += "Step1/3: Did not contain sufficient data points;"
                
        passed = False
        if apo_pass:
            if holo_pass:
                passed = True
            else:
                comments += "Step3/3: Did not pass holo hydrophobic and hbond re-creation check;"
        else:
            comments += "Step2/3: Did not pass apo hydrophobic and hbond re-creation check;"
            
        # Plot
        fig = plt.figure()
        sns.kdeplot(df, x="5iki_gen_sim", hue="PDB ID")
        plt.xlabel("Similartiy to 5iki crystal ligand")
        plt.xlim((0,1))
        fig.savefig(os.path.join(self.output_plots_dir, "task1_prot20_sim.png"))
        
        return comments, passed, perc_molgen_extra_pass, perc_molgen_MOSES_pass, perc_MOSES_MCF_pass, perc_MOSES_PAINS_pass, perc_MOSES_charge_pass, perc_pass_bond_lengths, perc_pass_bond_angles, perc_pass_internal_steric_clash, perc_passed_internal_energy, perc_pass_dist_to_prot, perc_pass_dist_to_water, perc_pass_vol_overlap_with_prot
        
def run_task1_analysis(task_outputs_dir, output_plots_dir):
    """
    Run every task 1 per-protein assessment and collect the outcomes in one summary table.

    Parameters
    ----------
        task_outputs_dir: string
            Directory searched for task 1 output CSVs (file names containing both "spec" and "task1").
            The text between the last "_" and ".csv" of each path is compared with the protein ID.
        output_plots_dir: string
            Directory handed to task1StatsPlots for its plots

    Returns
    -------
        DataFrame with one row per protein ID 1-20 and the columns prot_id, task, comments,
        passed_task and twelve perc_* statistics. When the file is missing, the analysis
        precondition is not met, or the assessment raises, passed_task and all statistics
        are None and the reason is stored in comments.
    """
    # Output column for each statistic, in the order the prot<N> methods return them
    stat_columns = [
        "perc_molgen_filter_pass",
        "perc_MOSES_pass",
        "perc_MOSES_MCF_pass",
        "perc_MOSES_PAINS_pass",
        "perc_MOSES_charges_pass",
        "perc_pass_bond_lengths",
        "perc_pass_bond_angles",
        "perc_pass_internal_steric_clash",
        "perc_passed_internal_energy",
        "perc_pass_dist_to_prot",
        "perc_pass_dist_to_water",
        "perc_pass_vol_overlap_with_prot",
    ]
    no_stats = [None] * len(stat_columns)

    # sorted() makes the choice deterministic if several files match one protein ID
    all_task_spec_output_files = [
        os.path.join(task_outputs_dir, i)
        for i in sorted(os.listdir(task_outputs_dir))
        if "spec" in i and "task1" in i
    ]
    task1_stats_init = task1StatsPlots(task_outputs_dir=task_outputs_dir, output_plots_dir=output_plots_dir)

    # One list per output column; every protein appends exactly one value to each list
    summary = {column: [] for column in ["prot_id", "task", "comments", "passed_task"] + stat_columns}

    for prot_id in range(1, 21):
        matches = [
            path for path in all_task_spec_output_files
            if path.split(".csv")[0].split("_")[-1] == str(prot_id)
        ]

        if not matches:
            comments, passed, stats = "Generation and/or analysis not run;", None, no_stats
        else:
            print("Run", prot_id)
            try:
                # Reading inside the try block: an unreadable file is reported for this
                # protein only instead of aborting the whole summary
                df = pd.read_csv(matches[0])

                if "moses_filters" in df.columns and len(df) >= 100:
                    func = getattr(task1_stats_init, "prot"+str(prot_id))
                    comments, passed, *stats = func(df)
                    if len(stats) != len(stat_columns):
                        raise ValueError("expected {0} statistics, got {1}".format(len(stat_columns), len(stats)))
                else:
                    comments, passed, stats = "Analysis not run or less than 100 compounds generated;", None, no_stats
            except Exception as e:
                comments, passed, stats = "Could not run task test, {0};".format(e), None, no_stats

        for column, value in zip(summary, [prot_id, 1, comments, passed, *stats]):
            summary[column].append(value)

    return pd.DataFrame(summary)
    
    

In [None]:
# EXAMPLE
# Replace both "<...>" placeholders with real directories before running this cell.
# The returned DataFrame has one row per task 1 protein ID (1-20).
final_df_diffsbdd_task1 = run_task1_analysis(task_outputs_dir="<dir_rof_summary_aggregated_and_spec_scores_for_DiffSBDD>", output_plots_dir="<dir_to_write_output_plots>")
# Uncomment (and set the path) to write the summary table to CSV:
#final_df_diffsbdd_task1.to_csv("<path_to_desired_csv>", index=False)


## TASK 2: STATS AND PLOTS

In [None]:
class task2StatsPlots():
    """
    Pass/fail assessments and diagnostic plots for the task 2 protein sets.

    Each assessment method (ITK_LCK_AurB, pan_JAK, pan_BET, shik) takes the task output
    DataFrame, passes it through filter_cmpds_get_stats, applies a staged pass/fail check,
    draws its plots and returns a tuple (comments, passed, *statistics), where the statistics
    are the values filter_cmpds_get_stats returned after the filtered DataFrame.
    """

    def __init__(self, task_outputs_dir, output_plots_dir=None):
        """
        Parameters
        ----------
            task_outputs_dir: string
                Directory with the task outputs (only stored on the instance here)
            output_plots_dir: string or None
                Directory the plots are written to; created if it does not exist.
                When None, plots are still drawn but not saved.
        """
        self.output_dir = task_outputs_dir
        self.output_plots_dir = output_plots_dir

        # os.path.exists(None) raises TypeError, so the default must be guarded
        if self.output_plots_dir is not None:
            os.makedirs(self.output_plots_dir, exist_ok=True)

    def _save_fig(self, fig, filename):
        """Save fig under the plots directory; does nothing when no directory was configured."""
        if self.output_plots_dir is not None:
            fig.savefig(os.path.join(self.output_plots_dir, filename))

    @staticmethod
    def _flag_fraction(flags, flag_value=1):
        """Return the fraction of entries of the Series flags equal to flag_value (0.0 if it is empty)."""
        values = flags.tolist()
        if not values:
            # An empty subset used to raise ZeroDivisionError before the
            # "sufficient data points" check could report it
            return 0.0
        return values.count(flag_value) / len(values)

    def _plot_active_inactive_sim(self, df, filename, set_white_style=True):
        """
        Draw the KDEs of the 2d_active_sim and 2d_inactive_sim columns per protein and save them.

        Active similarities are drawn solid and inactive ones dashed, which is what the
        plot title states. set_white_style mirrors the original per-method behaviour:
        when True, the global seaborn style is switched to "white" (affecting later plots).
        """
        fig = plt.figure()
        sns.kdeplot(df, x="2d_active_sim", hue="Protein", common_norm=False)
        sns.kdeplot(df, x="2d_inactive_sim", hue="Protein", linestyle="--", common_norm=False)
        plt.title("Similarities of generated compounds to active and inactive sets where dashed lines are inactive similarities, bold are active")
        if set_white_style:
            sns.set(style="white",rc={'legend.frameon':False})
        plt.xlabel("Similarity to ChEMBL sets")
        plt.xlim((0, 1.0))
        self._save_fig(fig, filename)

    def _plot_recreated_fractions(self, flags_by_protein, xlabel, filename):
        """
        Draw a grouped bar chart of the fraction of 0 and 1 flags per protein and save it.

        flags_by_protein maps the legend label of each protein to its Series of 0/1 flags.
        """
        df_bar = pd.DataFrame(
            {
                label: [self._flag_fraction(flags, 0), self._flag_fraction(flags, 1)]
                for label, flags in flags_by_protein.items()
            },
            index=[False, True],
        )
        # Passing ax explicitly: DataFrame.plot otherwise opens its own figure,
        # which left an additional empty figure behind
        fig, ax = plt.subplots()
        df_bar.plot(kind='bar', stacked=False, rot=0, ax=ax)
        ax.set_ylabel("Fraction")
        ax.set_xlabel(xlabel)
        ax.legend(frameon=False)
        self._save_fig(fig, filename)

    def get_hydrophobic_hbond_perc_recreations(self, df):
        """
        Add the recreated hydrophobic and hbond interaction fractions parsed from the "PLIP" column.

        Each "PLIP" cell is expected to hold a dict rendered as a string; its
        ["hydrophobic"]["interaction_frac"] and ["hbond"]["interaction_frac"] entries are
        written to the new columns hydrophobic_frac_recreated and hbond_frac_recreated.
        Rows that cannot be parsed, or miss either entry, get None in both columns.

        Returns
        -------
            The same DataFrame object, modified in place
        """
        hydrophobic_perc = []
        hbond_perc = []

        for idx, row in df.iterrows():
            try:
                interaction_dict_literal = json.loads(row["PLIP"].replace("'", '"').replace("None", "null").replace("nan", "null"))
                # Look up both values before appending anything: appending the first and then
                # failing on the second left the two lists with different lengths
                hydrophobic_frac = interaction_dict_literal["hydrophobic"]["interaction_frac"]
                hbond_frac = interaction_dict_literal["hbond"]["interaction_frac"]
            except Exception:
                hydrophobic_frac = None
                hbond_frac = None

            hydrophobic_perc.append(hydrophobic_frac)
            hbond_perc.append(hbond_frac)

        df["hydrophobic_frac_recreated"] = hydrophobic_perc
        df["hbond_frac_recreated"] = hbond_perc

        return df

    def ITK_LCK_AurB(self, df):
        """
        Assess the ITK (4l7s) / LCK (1fbz) / AurB (4af3) set.

        Passes when (1) each protein has more than 10 filtered compounds, (2) more than half
        of the compounds of each protein have their recreated_H_bond_* flag equal to 1, and
        (3) the unequal-variance t-test between the ITK and AurB Val419 distance columns
        gives p < 0.01.

        Returns
        -------
            (comments, passed, *statistics from filter_cmpds_get_stats)
        """
        print("Running task 2 ITK selectivity assessment")

        comments = ""

        # Filter out bad compounds and get statistics
        df, *filter_stats = filter_cmpds_get_stats(df, return_stats=True)
        ITK_selectivity_PDBs = {"4l7s": "ITK", "1fbz": "LCK", "4af3": "AurB"}
        df["Protein"] = df["PDB ID"].apply(lambda x: ITK_selectivity_PDBs[x])

        ITK_df = df[df["PDB ID"].isin(["4l7s"])]
        LCK_df = df[df["PDB ID"].isin(["1fbz"])]
        AurB_df = df[df["PDB ID"].isin(["4af3"])]

        # Check percentage of interactions re-created for hbond
        hbond_flags = {
            "ITK": ITK_df["recreated_H_bond_ITK_4l7s"],
            "LCK": LCK_df["recreated_H_bond_LCK_1fbz"],
            "AurB": AurB_df["recreated_H_bond_AurB_4af3"],
        }
        hbond_fracs = [self._flag_fraction(flags) for flags in hbond_flags.values()]

        passed = False
        if len(ITK_df) > 10 and len(LCK_df) > 10 and len(AurB_df) > 10:
            if all(frac > 0.5 for frac in hbond_fracs):
                # NOTE(review): this compares a "min_dist" column (ITK) with a "com_dist"
                # column (AurB) - confirm that mixing the two distance types is intended
                ITK_AurB_val419_t_test = scipy.stats.ttest_ind(ITK_df["min_dist_gen_to_ITK_Val419"], AurB_df["com_dist_gen_to_ITK_Val419"], equal_var=False, nan_policy='omit')
                # Compare the p-value; comparing the result object itself with a float raises TypeError
                if ITK_AurB_val419_t_test.pvalue < 0.01:
                    passed = True
                else:
                    comments += "Step3/3: Did not pass Val419 ITK AurB residue distance t-test check;"
            else:
                comments += "Step2/3: Did not pass hbond ITK, LCK, AurB re-creation check;"
        else:
            comments += "Step1/3: Did not contain sufficient data points;"

        # Plots
        self._plot_active_inactive_sim(df, "task2_ITK_active_inactive_sim.png")
        self._plot_recreated_fractions(hbond_flags, "H-bond interaction recreated", "task2_ITK_interactions.png")

        fig = plt.figure()
        sns.kdeplot(ITK_df, x="min_dist_gen_to_ITK_Val419", label="ITK Val419 ", common_norm=True)
        sns.kdeplot(AurB_df, x="com_dist_gen_to_ITK_Val419", label="AurB Val419", common_norm=True)
        sns.kdeplot(AurB_df, x="min_dist_gen_to_AurB_Leu138", label="AurB Leu138", common_norm=True)
        plt.xlabel("Minimum distance to residue")
        plt.legend()
        sns.set(style="white",rc={'legend.frameon':False})
        self._save_fig(fig, "task2_ITK_dist_to_residue_Val419.png")

        return (comments, passed, *filter_stats)

    def pan_JAK(self, df):
        """
        Assess the JAK1 (5wo4) / JAK2 (7q7k) / JAK3 (7q6h) / TYK2 (3lxn) set.

        Passes when (1) each protein has more than 10 filtered compounds and (2) more than
        half of the compounds of each protein have their recreated_H_bond_* flag equal to 1.

        Returns
        -------
            (comments, passed, *statistics from filter_cmpds_get_stats)
        """
        print("Running task 2 pan JAK assessment")

        comments = ""

        # Filter out bad compounds and get statistics
        df, *filter_stats = filter_cmpds_get_stats(df, return_stats=True)

        JAK_PDBs = {"5wo4": "JAK1", "7q7k": "JAK2", "7q6h": "JAK3", "3lxn": "TYK2"}
        df["Protein"] = df["PDB ID"].apply(lambda x: JAK_PDBs[x])

        JAK1_df = df[df["PDB ID"].isin(["5wo4"])]
        JAK2_df = df[df["PDB ID"].isin(["7q7k"])]
        JAK3_df = df[df["PDB ID"].isin(["7q6h"])]
        TYK2_df = df[df["PDB ID"].isin(["3lxn"])]
        jak_dfs = (JAK1_df, JAK2_df, JAK3_df, TYK2_df)

        # Check percentage of interactions re-created for hbond
        hbond_flags = {
            "JAK1": JAK1_df["recreated_H_bond_JAK1_5wo4"],
            "JAK2": JAK2_df["recreated_H_bond_JAK2_7q7k"],
            "JAK3": JAK3_df["recreated_H_bond_JAK3_7q6h"],
            "TYK2": TYK2_df["recreated_H_bond_TYK2_3lxn"],
        }
        hbond_fracs = [self._flag_fraction(flags) for flags in hbond_flags.values()]

        passed = False
        if all(len(sub_df) > 10 for sub_df in jak_dfs):
            if all(frac > 0.5 for frac in hbond_fracs):
                passed = True
            else:
                # Second of two steps (the message previously read "Step3/2")
                comments += "Step2/2: Did not pass hbond JAK1, JAK2, JAK3, TYK2 re-creation check;"
        else:
            comments += "Step1/2: Did not contain sufficient data points;"

        # Plots (this method never switched the global seaborn style, so it still does not)
        self._plot_active_inactive_sim(df, "task2_JAK_active_inactive_sim.png", set_white_style=False)

        fig = plt.figure()
        residue_columns = [
            (JAK1_df, "com_dist_to_JAK1_Leu959", "JAK1 Leu959"),
            (JAK2_df, "com_dist_to_JAK2_Leu932", "JAK2 Leu932"),
            (JAK3_df, "com_dist_to_JAK3_Leu905", "JAK3 Leu905"),
            (TYK2_df, "com_dist_to_TYK2_Val981", "TYK2 Val981"),
        ]
        for sub_df, column, label in residue_columns:
            sns.kdeplot(sub_df, x=column, label=label, common_norm=True)
        plt.xlabel("Minimum distance to residue")
        plt.legend()
        self._save_fig(fig, "task2_JAK_COM_dist_to_residue.png")

        self._plot_recreated_fractions(hbond_flags, "H-bond interaction recreated", "task2_JAK_interactions.png")

        return (comments, passed, *filter_stats)

    def pan_BET(self, df):
        """
        Assess the BRD2 / BRD3 / BRD4 / BRDT set (two PDB IDs per protein).

        Passes when (1) each protein has more than 10 filtered compounds, (2) more than half
        of the compounds of each protein have recreated_Asn_bond equal to 1, (3) the mean
        min_WPF_dist of every protein is at most 5 and (4) the mean min_ZA_dist of every
        protein is at most 5.

        Returns
        -------
            (comments, passed, *statistics from filter_cmpds_get_stats)
        """
        print("Running task 2 pan BET assessment")

        comments = ""

        # Filter out bad compounds and get statistics
        df, *filter_stats = filter_cmpds_get_stats(df, return_stats=True)

        BET_PDBs = {"6ddi": "BRD2", "5ig6": "BRD2", "7lay": "BRD3", "3s92": "BRD3", "4qb3": "BRD4", "7usg": "BRD4", "7mrd": "BRDT", "7l99": "BRDT"}
        df["Protein"] = df["PDB ID"].apply(lambda x: BET_PDBs[x])

        bet_dfs = {name: df[df["Protein"] == name] for name in ("BRD2", "BRD3", "BRD4", "BRDT")}

        # Check percentage of interactions re-created for hbond
        hbond_flags = {name: sub_df["recreated_Asn_bond"] for name, sub_df in bet_dfs.items()}
        hbond_fracs = [self._flag_fraction(flags) for flags in hbond_flags.values()]

        passed = False
        if all(len(sub_df) > 10 for sub_df in bet_dfs.values()):
            if all(frac > 0.5 for frac in hbond_fracs):
                if all(statistics.mean(sub_df["min_WPF_dist"].tolist()) <= 5 for sub_df in bet_dfs.values()):
                    if all(statistics.mean(sub_df["min_ZA_dist"].tolist()) <= 5 for sub_df in bet_dfs.values()):
                        passed = True
                    else:
                        comments += "Step4/4: Did not created compounds < 5 A from ZA channel"
                else:
                    comments += "Step3/4: Did not create compounds < 5 A from WPF motif"
            else:
                comments += "Step2/4: Did not pass hbond BRD2, BRD3, BRD4, BRDT re-creation check;"
        else:
            comments += "Step1/4: Did not contain sufficient data points;"

        # Plots
        self._plot_active_inactive_sim(df, "task2_BET_active_inactive_sim.png")
        self._plot_recreated_fractions(hbond_flags, "Asn bond interaction recreated", "task2_BET_interactions.png")

        fig = plt.figure()
        sns.kdeplot(df, x="min_WPF_dist", hue="Protein", common_norm=False)
        plt.xlabel("Distance to WPF motif")
        self._save_fig(fig, "task2_BET_min_WPF_dist.png")

        fig = plt.figure()
        sns.kdeplot(df, x="min_ZA_dist", hue="Protein", common_norm=False)
        plt.xlabel("Distance to ZA channel")
        self._save_fig(fig, "task2_BET_min_ZA_dist.png")

        return (comments, passed, *filter_stats)

    def shik(self, df):
        """
        Assess the SHIK (1zyu) set.

        Passes when (1) there are more than 10 filtered compounds, (2) the mean of
        recreated_H_bond_Asp34 is above 0.5 and (3) the mean of recreated_Arg58 is above 0.5.

        Returns
        -------
            (comments, passed, *statistics from filter_cmpds_get_stats)
        """
        print("Running task 2 SHIK assessment")

        comments = ""

        # Filter out bad compounds and get statistics
        df, *filter_stats = filter_cmpds_get_stats(df, return_stats=True)

        shik_PDBs = {"1zyu": "SHIK"}
        df["Protein"] = df["PDB ID"].apply(lambda x: shik_PDBs[x])

        passed = False
        if len(df) > 10:
            if statistics.mean(df["recreated_H_bond_Asp34"].tolist()) > 0.5:
                if statistics.mean(df["recreated_Arg58"].tolist()) > 0.5:
                    passed = True
                else:
                    comments += "Step3/3: Did not pass Arg58 re-creation check;"
            else:
                # The checked column is recreated_H_bond_Asp34 (the message previously said Asp32)
                comments += "Step2/3: Did not pass Asp34 hbond re-creation check;"
        else:
            comments += "Step1/3: Did not contain sufficient data points;"

        # Plots
        self._plot_active_inactive_sim(df, "task2_SHIK_active_inactive_sim.png")

        return (comments, passed, *filter_stats)
    
    
def run_task2_analysis(task_outputs_dir, output_plots_dir):
    """
    Run the four task 2 assessments (ITK, JAK, BET, shik) and collect the outcomes in one summary table.

    Parameters
    ----------
        task_outputs_dir: string
            Directory searched for task 2 output CSVs (file names containing both "spec" and "task2").
            The second-to-last "_"-separated part of each path (with ".csv" stripped) is compared
            with the protein set ID.
        output_plots_dir: string
            Directory handed to task2StatsPlots for its plots

    Returns
    -------
        DataFrame with one row per protein set and the columns prot_id, task, comments,
        passed_task and twelve perc_* statistics. When the file is missing, the analysis
        precondition is not met, or the assessment raises, passed_task and all statistics
        are None and the reason is stored in comments.
    """
    # Output column for each statistic, in the order the assessment methods return them
    stat_columns = [
        "perc_molgen_filter_pass",
        "perc_MOSES_pass",
        "perc_MOSES_MCF_pass",
        "perc_MOSES_PAINS_pass",
        "perc_MOSES_charges_pass",
        "perc_pass_bond_lengths",
        "perc_pass_bond_angles",
        "perc_pass_internal_steric_clash",
        "perc_passed_internal_energy",
        "perc_pass_dist_to_prot",
        "perc_pass_dist_to_water",
        "perc_pass_vol_overlap_with_prot",
    ]
    no_stats = [None] * len(stat_columns)

    # Protein set ID -> name of the task2StatsPlots method assessing it
    # (resolved with getattr inside the try block so a lookup failure is recorded, not raised)
    assessment_methods = {
        "ITK": "ITK_LCK_AurB",
        "JAK": "pan_JAK",
        "BET": "pan_BET",
        "shik": "shik",
    }

    def matches_prot_id(path, prot_id):
        parts = path.split(".csv")[0].split("_")
        # Paths without an underscore used to raise IndexError on parts[-2]
        return len(parts) >= 2 and parts[-2] == str(prot_id)

    # sorted() makes the choice deterministic if several files match one protein set
    all_task_spec_output_files = [
        os.path.join(task_outputs_dir, i)
        for i in sorted(os.listdir(task_outputs_dir))
        if "spec" in i and "task2" in i
    ]
    task2_stats_init = task2StatsPlots(task_outputs_dir=task_outputs_dir, output_plots_dir=output_plots_dir)

    # One list per output column; every protein set appends exactly one value to each list
    summary = {column: [] for column in ["prot_id", "task", "comments", "passed_task"] + stat_columns}

    for prot_id, method_name in assessment_methods.items():
        matches = [path for path in all_task_spec_output_files if matches_prot_id(path, prot_id)]

        if not matches:
            comments, passed, stats = "Generation and/or analysis not run;", None, no_stats
        else:
            print("Run", prot_id)
            try:
                # Reading inside the try block: an unreadable file is reported for this
                # protein set only instead of aborting the whole summary
                df = pd.read_csv(matches[0])

                if "moses_filters" in df.columns and len(df) >= 100:
                    comments, passed, *stats = getattr(task2_stats_init, method_name)(df)
                    if len(stats) != len(stat_columns):
                        raise ValueError("expected {0} statistics, got {1}".format(len(stat_columns), len(stats)))
                else:
                    comments, passed, stats = "Analysis not run or less than 100 compounds generated;", None, no_stats
            except Exception as e:
                comments, passed, stats = "Could not run task test, {0};".format(e), None, no_stats

        for column, value in zip(summary, [prot_id, 2, comments, passed, *stats]):
            summary[column].append(value)

    return pd.DataFrame(summary)

In [None]:
# EXAMPLE
# Replace both "<...>" placeholders with real directories before running this cell.
# The returned DataFrame has one row per task 2 protein set (ITK, JAK, BET, shik).
final_df_diffsbdd_task2 = run_task2_analysis(task_outputs_dir="<dir_rof_summary_aggregated_and_spec_scores_for_DiffSBDD>", output_plots_dir="<dir_to_write_output_plots>")
# Uncomment (and set the path) to write the summary table to CSV:
#final_df_diffsbdd_task2.to_csv("<path_to_desired_csv>", index=False)


## TASK 3: STATS AND PLOTS

In [None]:
class task3StatsPlots():
    """
    Pass/fail assessment and similarity plots for the two task 3 test sets
    (COVID-19 Moonshot and CSAR 2014).

    Parameters
    ----------
        task_outputs_dir: string
            Directory with the task output files (stored as self.output_dir)
        output_plots_dir: string or None
            Directory the plots are saved to; created if it does not exist.
            With None (the default) the plots are drawn but not written to disk.
    """

    # CSAR 2014 PDB IDs grouped by protein
    CSAR_PDBs_BY_PROT = {'FXa': ['4zh8','4y6d','4zha','4y71','4y76','4y79','4y7a','4y7b'],
                'SYK': ['4yjo','4yjp','4yjq','4yjr','4yjs','4yjt','4yju','4yjv'],
                'TrMD': ['4ypw','4ypx','4ypy','4ypz','4yq0','4yq1','4yq2','4yq3',
                         '4yq4','4yq5','4yq6','4yq7','4yq8','4yq9','4yqa','4yqb',
                         '4yqc','4yqd','4yqg','4yqi','4yqj','4yqk','4yql','5d9f',
                         '4yqn','4yqo','4yqp','4yqq','4yqr','4yqs','4yqt']}

    def __init__(self, task_outputs_dir, output_plots_dir=None):
        self.output_dir = task_outputs_dir
        self.output_plots_dir = output_plots_dir

        # The default None used to crash in os.path.exists(None); only touch the
        # filesystem when a directory was actually requested.
        if self.output_plots_dir is not None and not os.path.exists(self.output_plots_dir):
            os.makedirs(self.output_plots_dir)

    @staticmethod
    def _parse_hbond_frac(plip_entry):
        """
        Extract ["hbond"]["interaction_frac"] from one stringified PLIP dictionary.

        Returns None when the entry is not a string, cannot be parsed as JSON after
        the quote/None/nan substitutions, or lacks the expected keys.
        """
        try:
            interaction_dict_literal = json.loads(plip_entry.replace("'", '"').replace("None", "null").replace("nan", "null"))
            return interaction_dict_literal["hbond"]["interaction_frac"]
        except Exception:
            return None

    @staticmethod
    def _keep_rows_with_hbond_frac(df):
        """
        Return a copy of df with a "hbond_frac_recreated" column, restricted to the
        rows where that fraction could be read from the "PLIP" column.
        """
        # A missing "PLIP" column behaves like unparsable entries (every row is dropped)
        plip_entries = df["PLIP"].tolist() if "PLIP" in df.columns else [None] * len(df)

        # Work on a copy so the new column is never written onto a slice of another frame
        df = df.copy()
        df["hbond_frac_recreated"] = [task3StatsPlots._parse_hbond_frac(entry) for entry in plip_entries]

        return df[df["hbond_frac_recreated"].notna()]

    @staticmethod
    def _assess_pass(df):
        """
        Apply the three-step task check and return (comments, passed).

        Step 1: more than 10 rows; step 2: mean recreated H-bond fraction > 0.5;
        step 3: mean "2d_hits_ave_sim" > 0.5. The first failing step is reported
        in the comments string.
        """
        comments = ""
        passed = False

        if len(df) > 10:
            if statistics.mean(df["hbond_frac_recreated"].tolist()) > 0.5:
                if statistics.mean(df["2d_hits_ave_sim"].tolist()) > 0.5:
                    passed = True
                else:
                    comments += "Step3/3: Did not pass 2d Tanimoto sim > 0.5 check;"
            else:
                comments += "Step2/3: Did not pass hbond re-creation check;"
        else:
            comments += "Step1/3: Did not contain sufficient data points;"

        return comments, passed

    def _save_plot(self, fig, filename):
        """Write fig into self.output_plots_dir; does nothing when no directory was given."""
        if self.output_plots_dir is not None:
            fig.savefig(os.path.join(self.output_plots_dir, filename))

    def COVIDMoonshot(self, df):
        """
        Assess the COVID-19 Moonshot output DataFrame and plot its 2D/3D similarities.

        Returns
        -------
            Tuple of (comments, passed) followed by the statistics returned by
            filter_cmpds_get_stats(df, return_stats=True), in the same order
        """
        print("Running task 3 COVID-19 Moonshot VS assessment")

        # Filter out bad compounds and get statistics
        df, *filter_stats = filter_cmpds_get_stats(df, return_stats=True)

        # Re-created H-bond interactions
        df = self._keep_rows_with_hbond_frac(df)

        comments, passed = self._assess_pass(df)

        # Plots
        # COVID-19 Moonshot
        fig = plt.figure()
        sns.kdeplot(data=df, x="2d_hits_ave_sim", label="Similarity to 2D hits", common_norm=False)
        # Dashed, so the curve matches what the plot title announces
        sns.kdeplot(data=df, x="3d_crystal_sim", label="Similarity to 3D crystal", common_norm=False, linestyle='--')
        plt.xlabel("Similarities of generated compounds")
        plt.xlim((0, 1.0))
        plt.title("2D similarities in bold, 3D in dashed")
        plt.legend()
        self._save_plot(fig, "task3_COVID19_Moonshot_2d_3d_sim.png")

        return (comments, passed, *filter_stats)

    def CSAR2014(self, df):
        """
        Assess the CSAR 2014 output DataFrame and plot its 2D/3D similarities per protein.

        Returns
        -------
            Tuple of (comments, passed) followed by the statistics returned by
            filter_cmpds_get_stats(df, return_stats=True), in the same order
        """
        print("Running task 3 CSAR 2014 VS assessment")

        # Filter out bad compounds and get statistics
        df, *filter_stats = filter_cmpds_get_stats(df, return_stats=True)

        # Label each row with its protein; PDB IDs outside the table become NaN
        # instead of producing a label list shorter than the DataFrame.
        pdb_to_prot = {pdb: prot for prot, pdbs in self.CSAR_PDBs_BY_PROT.items() for pdb in pdbs}
        df = df.copy()
        df['Protein'] = df["PDB ID"].map(pdb_to_prot)

        # Re-created H-bond interactions
        df = self._keep_rows_with_hbond_frac(df)

        comments, passed = self._assess_pass(df)

        # Plots
        # 2D and 3D similarity to hits and crystal ligand
        # CSAR 2014
        fig = plt.figure()
        sns.kdeplot(data=df, x="3d_crystal_sim", hue="Protein", common_norm=False, linestyle='--')
        sns.kdeplot(data=df, x="2d_hits_ave_sim", hue="Protein", common_norm=False)
        plt.xlabel("Similarities of generated compounds")
        plt.xlim((0, 1.0))
        plt.title("2D similarities in bold, 3D in dashed")
        self._save_plot(fig, "task3_CSAR_2d_3d_sim.png")

        return (comments, passed, *filter_stats)
    
def run_task3_analysis(task_outputs_dir, output_plots_dir):
    """
    Run the task 3 assessments (COVID-19 Moonshot and CSAR 2014) for one model.

    Parameters
    ----------
        task_outputs_dir: string
            Directory searched for files whose name contains both "spec" and "task3"
        output_plots_dir: string
            Directory handed to task3StatsPlots for the output plots

    Returns
    -------
        DataFrame with one row per test set ("moonshot", "2014"). Rows for test sets
        that are missing, too small or that failed to run carry an explanatory
        comment and None in every other result column.
    """
    all_task_spec_output_files = [os.path.join(task_outputs_dir, i) for i in os.listdir(task_outputs_dir) if "spec" in i and "task3" in i]
    task3_stats_init = task3StatsPlots(task_outputs_dir=task_outputs_dir, output_plots_dir=output_plots_dir)

    # Output columns that hold the filter statistics, in the order the assessors return them
    stat_columns = [
        "perc_molgen_filter_pass",
        "perc_MOSES_pass",
        "perc_MOSES_MCF_pass",
        "perc_MOSES_PAINS_pass",
        "perc_MOSES_charges_pass",
        "perc_pass_bond_lengths",
        "perc_pass_bond_angles",
        "perc_pass_internal_steric_clash",
        "perc_passed_internal_energy",
        "perc_pass_dist_to_prot",
        "perc_pass_dist_to_water",
        "perc_pass_vol_overlap_with_prot",
    ]
    no_stats = [None] * len(stat_columns)

    # Test-set file suffix -> assessment method
    assessors = [
        ("moonshot", task3_stats_init.COVIDMoonshot),
        ("2014", task3_stats_init.CSAR2014),
    ]

    rows = []
    for prot_id, assess in assessors:
        print(prot_id)

        # Defaults describe a test set with no output file at all
        comments, passed, stats = "Generation and/or analysis not run;", None, no_stats

        # The test-set ID is the last "_"-separated token before ".csv"
        filtered_list = [path for path in all_task_spec_output_files if path.split(".csv")[0].split("_")[-1] == str(prot_id)]

        if len(filtered_list) != 0:
            print("Run", prot_id)

            try:
                # Reading happens inside the try so that an unreadable file is reported
                # in the comments column instead of aborting the whole summary
                df = pd.read_csv(filtered_list[0])

                if "moses_filters" in df.columns and len(df) >= 100:
                    comments, passed, *stats = assess(df)
                    if len(stats) != len(stat_columns):
                        raise ValueError("expected {0} statistics, got {1}".format(len(stat_columns), len(stats)))
                else:
                    comments = "Analysis not run or less than 100 compounds generated;"
            except Exception as e:
                comments, passed, stats = "Could not run task test, {0};".format(e), None, no_stats

        rows.append([prot_id, 3, comments, passed] + list(stats))

    # Join into a DataFrame, one column per collected value
    all_columns = ["prot_id", "task", "comments", "passed_task"] + stat_columns
    final_df = pd.DataFrame({col: [row[i] for row in rows] for i, col in enumerate(all_columns)})

    return final_df
        

In [None]:
# EXAMPLE
# Task 3 summary for a single model; replace both placeholder paths before running.
final_df_diffsbdd_task3 = run_task3_analysis(
    task_outputs_dir="<dir_rof_summary_aggregated_and_spec_scores_for_DiffSBDD>",
    output_plots_dir="<dir_to_write_output_plots>",
)
#final_df_diffsbdd_task3.to_csv("<path_to_desired_csv>", index=False)


## TASK 4: POST-HOC ASSESSMENT

In [None]:
# TASK 1 OUTPUTS
# One combined output DataFrame per model for task 1, built with calc_all=False.
# Every "<...>" string is a placeholder that must be replaced with a real path
# before this cell is run. The resulting task1_output_* frames are consumed by
# sample_tasks() further down.
task1_output_pocket2mol = combine_outputs.getOutputDf(
        output_dir="<output_dir_to_task1_Pocket2Mol>",
        inp_task_file="<input_files_to_csv_task1_Pocket2Mol>",
        model_name="Pocket2Mol", calc_all=False
    ).output_df

task1_output_diffsbdd = combine_outputs.getOutputDf(
        output_dir="<output_dir_to_task1_DiffSBDD>",
        inp_task_file="<input_files_to_csv_task1_DiffSBDD>",
        model_name="DiffSBDD", calc_all=False
    ).output_df

task1_output_ligbuilder = combine_outputs.getOutputDf(
        output_dir="<output_dir_to_task1_LigBuilder>",
        inp_task_file="<input_files_to_csv_task1_LigBuilder>",
        model_name="LigBuilderV3", calc_all=False
    ).output_df


task1_output_ag4 = combine_outputs.getOutputDf(
        output_dir="<output_dir_to_task1_AutoGrow4>",
        inp_task_file="<input_files_to_csv_task1_AutoGrow4",
        model_name="AutoGrow4", calc_all=False
    ).output_df

In [None]:
# TASK 2 OUTPUTS
# One combined output DataFrame per model for task 2, built with calc_all=False.
# Every "<...>" string is a placeholder that must be replaced with a real path
# before this cell is run. The resulting task2_output_* frames are consumed by
# sample_tasks() further down.
task2_output_pocket2mol = combine_outputs.getOutputDf(
        output_dir="<output_dir_to_task2_Pocket2Mol>",
        inp_task_file="<input_files_to_csv_task2_Pocket2Mol>",
        model_name="Pocket2Mol", calc_all=False
    ).output_df

task2_output_diffsbdd = combine_outputs.getOutputDf(
        output_dir="<output_dir_to_task2_DiffSBDD>",
        inp_task_file="<input_files_to_csv_task2_DiffSBDD>",
        model_name="DiffSBDD", calc_all=False
    ).output_df

task2_output_ligbuilder = combine_outputs.getOutputDf(
        output_dir="<output_dir_to_task2_LigBuilder>",
        inp_task_file="<input_files_to_csv_task2_LigBuilder>",
        model_name="LigBuilderV3", calc_all=False
    ).output_df


task2_output_ag4 = combine_outputs.getOutputDf(
        output_dir="<output_dir_to_task2_AutoGrow4>",
        inp_task_file="<input_files_to_csv_task2_AutoGrow4",
        model_name="AutoGrow4", calc_all=False
    ).output_df

In [None]:
# TASK 3 OUTPUTS
# One combined output DataFrame per model for task 3, built with calc_all=False.
# Every "<...>" string is a placeholder that must be replaced with a real path
# before this cell is run. The resulting task3_output_* frames are consumed by
# sample_tasks() further down.
task3_output_pocket2mol = combine_outputs.getOutputDf(
        output_dir="<output_dir_to_task3_Pocket2Mol>",
        inp_task_file="<input_files_to_csv_task3_Pocket2Mol>",
        model_name="Pocket2Mol", calc_all=False
    ).output_df

task3_output_diffsbdd = combine_outputs.getOutputDf(
        output_dir="<output_dir_to_task3_DiffSBDD>",
        inp_task_file="<input_files_to_csv_task3_DiffSBDD>",
        model_name="DiffSBDD", calc_all=False
    ).output_df

task3_output_ligbuilder = combine_outputs.getOutputDf(
        output_dir="<output_dir_to_task3_LigBuilder>",
        inp_task_file="<input_files_to_csv_task3_LigBuilder>",
        model_name="LigBuilderV3", calc_all=False
    ).output_df


task3_output_ag4 = combine_outputs.getOutputDf(
        output_dir="<output_dir_to_task3_AutoGrow4>",
        inp_task_file="<input_files_to_csv_task3_AutoGrow4",
        model_name="AutoGrow4", calc_all=False
    ).output_df

In [None]:
def sample_df_per_prot_group(df, n_samples):
    """
    Sample up to n_samples rows for every 'PDB ID' group of df.

    Parameters
    ----------
        df: Pandas DataFrame
            Must contain a 'PDB ID' column; rows with a missing 'PDB ID' are dropped
        n_samples: integer
            Rows to draw per 'PDB ID'. Groups with fewer rows are kept whole.

    Returns
    -------
        DataFrame of the per-group samples, ordered by 'PDB ID' (original index
        labels are kept). An empty DataFrame when df has no groups.
    """
    sampled_groups = []

    # A single groupby pass replaces the count-then-refilter loop and the
    # repeated concat onto a growing DataFrame.
    for _, pdb_df in df.groupby('PDB ID'):
        # For each PDB group, check if enough samples, otherwise get max possible
        if len(pdb_df) < n_samples:
            sampled_groups.append(pdb_df)
        else:
            # Fixed seed keeps the sampled subset reproducible between runs
            sampled_groups.append(pdb_df.sample(n=n_samples, random_state=42))

    # pd.concat refuses an empty list, so return an empty frame explicitly
    if not sampled_groups:
        return pd.DataFrame()

    return pd.concat(sampled_groups)
    
def sample_tasks(task1_df, task2_df, task3_df, n_samples=50):
    """
    Build one sampled DataFrame covering the outputs of tasks 1, 2 and 3.

    Parameters
    ----------
        task1_df, task2_df, task3_df: Pandas DataFrames
            Output DataFrame of each task. A "Task" column (1, 2 or 3) is
            written onto each of them in place.
        n_samples: integer
            Maximum number of rows sampled per 'PDB ID' within each task

    Returns
    -------
        Concatenation of the per-task samples, task 1 first
    """
    task_frames = [task1_df, task2_df, task3_df]

    # Tag every input with its task number before anything is filtered
    for task_number, frame in enumerate(task_frames, start=1):
        frame["Task"] = task_number

    per_task_samples = []
    for frame in task_frames:
        # Rows whose mol_pred text contains "Errno" are excluded from sampling
        has_errno = frame["mol_pred"].astype(str).str.contains("Errno")
        per_task_samples.append(
            sample_df_per_prot_group(frame[~has_errno], n_samples=n_samples)
        )

    return pd.concat(per_task_samples)


In [None]:
# Per-model sample across tasks 1-3 (default n_samples per PDB ID).
# sample_tasks() also writes a "Task" column onto the frames passed in.
p2m_sampled_task_df = sample_tasks(
    task1_df=task1_output_pocket2mol,
    task2_df=task2_output_pocket2mol,
    task3_df=task3_output_pocket2mol,
)
diffsbdd_sampled_task_df = sample_tasks(
    task1_df=task1_output_diffsbdd,
    task2_df=task2_output_diffsbdd,
    task3_df=task3_output_diffsbdd,
)
ag4_sampled_task_df = sample_tasks(
    task1_df=task1_output_ag4,
    task2_df=task2_output_ag4,
    task3_df=task3_output_ag4,
)
ligbuilder_sampled_task_df = sample_tasks(
    task1_df=task1_output_ligbuilder,
    task2_df=task2_output_ligbuilder,
    task3_df=task3_output_ligbuilder,
)


In [None]:
from rdkit.Chem import AllChem, ChemicalFeatures, Crippen, EState, rdMolDescriptors
from rdkit import RDConfig
from pathlib import Path
import numpy as np
import collections
from collections import Counter

# Shared feature factory built from the "BaseFeatures.fdef" file found under
# RDConfig.RDDataDir; get_pharmacophores() reads it as a module-level global.
feature_factory = AllChem.BuildFeatureFactory(str(Path(RDConfig.RDDataDir) / "BaseFeatures.fdef"))

def get_pharmacophores(mol):
    """
    Count how often each feature family occurs in a molecule.

    Returns a collections.Counter keyed by the family name of every feature
    that the module-level feature_factory reports for mol.
    """
    families = (feat.GetFamily() for feat in feature_factory.GetFeaturesForMol(mol))

    return collections.Counter(families)

def get_tpsa_pharmacophores(input_df, model_name):
    """
    Compute the mean per-atom tPSA contribution and the pharmacophore feature
    counts for every molecule file listed in input_df["mol_pred"].

    Parameters
    ----------
        input_df: Pandas DataFrame
            Must contain a "mol_pred" column of SD file paths. The "tPSA" and
            "Model" columns are written onto this DataFrame in place.
        model_name: string
            Label stored in the "Model" column of both returned DataFrames

    Returns
    -------
        input_df: the input DataFrame, with tPSA set to None for molecules that
            could not be processed
        feature_frequencies_df: one row of integer feature-family counts per
            successfully processed molecule (failed molecules have no row)
    """
    # tPSA, H-bond donors, H-bond acceptors
    tpsa_list = []
    molecule_feature_frequencies = []

    for mol_pred in input_df["mol_pred"]:
        try:
            # First readable molecule of the SD file
            mol = [mol for mol in Chem.SDMolSupplier(mol_pred) if mol is not None][0]
            tpsa_contribs = np.mean(rdMolDescriptors._CalcTPSAContribs(mol)) # Take an average across all atoms
            feature_frequency = get_pharmacophores(mol)
        except Exception as e:
            tpsa_list.append(None)
            print(e)
        else:
            # Record both results only once both succeeded, so tpsa_list always
            # ends up with exactly one entry per input row
            tpsa_list.append(tpsa_contribs)
            molecule_feature_frequencies.append(feature_frequency)

    input_df["tPSA"] = tpsa_list
    input_df["Model"] = model_name

    feature_frequencies_df = pd.DataFrame(molecule_feature_frequencies).fillna(0).astype(int)
    feature_frequencies_df["Model"] = model_name

    return input_df, feature_frequencies_df


In [None]:
# Reference set: BindingMOAD ligands, read from a CSV with a "SMILES" column.
# Replace the "<dir_to_repo>" placeholder with the local repository path.
binding_moad_data = pd.read_csv("<dir_to_repo>/SBDD-benchmarking/Benchmarking_Tasks/Task0/BindingMOAD_filtered_set_for_retraining_with_split_4_Pocket2Mol.csv")

# Per-molecule ring descriptors; None is stored when the SMILES cannot be parsed
ring_info = []
ring_sizes = []
num_aromatic_rings = []

for idx, row in binding_moad_data.iterrows():
    mol = Chem.MolFromSmiles(row["SMILES"])
    
    if mol is not None:
        # Ring info
        ringobj = mol.GetRingInfo()
        ring_info.append(ringobj.BondRings())
        # NOTE: despite the name, this stores NumRings(), i.e. the number of rings
        ring_sizes.append(ringobj.NumRings())
        num_aromatic_rings.append(Chem.rdMolDescriptors.CalcNumAromaticRings(mol))
    else:
        ring_info.append(None)
        ring_sizes.append(None)
        num_aromatic_rings.append(None)
        
binding_moad_data["ring_info"] = ring_info
binding_moad_data["ring_size"] = ring_sizes
binding_moad_data["num_aromatic_rings"] = num_aromatic_rings

# Pharmacophore feature counts of the reference ligands
molecular_feat_freq_bindingmoad = []

for idx, row in binding_moad_data.iterrows():
    smiles = row["SMILES"]
    try:
        mol = Chem.MolFromSmiles(smiles)

        molecular_feat_freq_bindingmoad.append(get_pharmacophores(mol))
    except Exception as e:
        # Molecules that raise here are only printed and get no row, so the
        # resulting table can be shorter than binding_moad_data
        print(e)
    
# Missing feature families are filled with a count of 0
bindingmoad_pharmacophores = pd.DataFrame(molecular_feat_freq_bindingmoad).fillna(0).astype(int)
bindingmoad_pharmacophores["Model"] = "BindingMOAD"

In [None]:
# tPSA and pharmacophore counts of the sampled molecules, one call per model.
# The first returned frame is the sampled frame itself, extended in place.
p2m_sampled_task_df_tpsa, p2m_pharmacophores = get_tpsa_pharmacophores(
    input_df=p2m_sampled_task_df, model_name="Pocket2Mol"
)
diffsbdd_sampled_task_df_tpsa, diffsbdd_pharmacophores = get_tpsa_pharmacophores(
    input_df=diffsbdd_sampled_task_df, model_name="DiffSBDD"
)
ag4_sampled_task_df_tpsa, ag4_pharmacophores = get_tpsa_pharmacophores(
    input_df=ag4_sampled_task_df, model_name="AutoGrow4"
)
ligbuilder_sampled_task_df_tpsa, ligbuilder_pharmacophores = get_tpsa_pharmacophores(
    input_df=ligbuilder_sampled_task_df, model_name="LigBuilderV3"
)

In [None]:
# Label every pharmacophore table with the model it belongs to
for _pharm_table, _model_label in (
    (p2m_pharmacophores, "Pocket2Mol"),
    (diffsbdd_pharmacophores, "DiffSBDD"),
    (ag4_pharmacophores, "AutoGrow4"),
    (ligbuilder_pharmacophores, "LigBuilderV3"),
):
    _pharm_table["Model"] = _model_label


In [None]:
def get_counts(inp_list, minrange, maxrange):
    """
    Fraction of entries of inp_list equal to each integer in [minrange, maxrange).

    Parameters
    ----------
        inp_list: list
            Values to count
        minrange, maxrange: integers
            Half-open range of integers to report (maxrange is excluded)

    Returns
    -------
        Dictionary {integer: fraction of inp_list equal to it}. Every fraction is
        0.0 when inp_list is empty.
    """
    total = len(inp_list)
    dict_counts = {}

    for i in range(minrange, maxrange):
        # An empty list has no occurrences of anything; avoid dividing by zero
        dict_counts[i] = inp_list.count(i)/total if total else 0.0

    return dict_counts

def _pharmacophore_fraction_table(feature, maxrange):
    """
    Table of get_counts() fractions for one pharmacophore feature column:
    one row per count in [0, maxrange), one column per data source.
    """
    tables_by_source = {
        'BindingMOAD': bindingmoad_pharmacophores,
        'Pocket2Mol': p2m_pharmacophores,
        'DiffSBDD': diffsbdd_pharmacophores,
        'AutoGrow4': ag4_pharmacophores,
        'LigBuilderV3': ligbuilder_pharmacophores,
    }
    fractions_by_source = {
        source: get_counts(table[feature].tolist(), 0, maxrange)
        for source, table in tables_by_source.items()
    }

    return pd.DataFrame(
        {source: fractions.values() for source, fractions in fractions_by_source.items()},
        index=fractions_by_source['BindingMOAD'].keys(),
    )

# Counts 0-9 for donors, acceptors and hydrophobes; 0-5 for aromatic features
donor_pharm_df = _pharmacophore_fraction_table("Donor", 10)

acceptor_pharm_df = _pharmacophore_fraction_table("Acceptor", 10)

hydrophobe_pharm_df = _pharmacophore_fraction_table("Hydrophobe", 10)

aromatic_pharm_df = _pharmacophore_fraction_table("Aromatic", 6)



In [None]:
# 2x2 grid of grouped bar charts: fraction of molecules per feature count
fig, axes = plt.subplots(nrows=2, ncols=2, sharex=False, sharey=False, figsize=(15,8))
ax1, ax2, ax3, ax4 = axes.flatten()

# (axes, fraction table, x label, y label, draw legend)
_pharm_panels = [
    (ax1, donor_pharm_df, 'Donor', 'Fraction', False),
    (ax2, acceptor_pharm_df, 'Acceptor', '', True),
    (ax3, hydrophobe_pharm_df, 'Hydrophobe', 'Fraction', False),
    (ax4, aromatic_pharm_df, 'Aromatic', '', False),
]

for _panel_ax, _panel_table, _panel_xlabel, _panel_ylabel, _panel_legend in _pharm_panels:
    _panel_table.plot(kind="bar", stacked=False, rot=0, ax=_panel_ax, legend=_panel_legend)
    _panel_ax.set_xlabel(_panel_xlabel)
    # Only the left limit is fixed; the right limit stays as plotted
    _panel_ax.set_xlim(left=-0.5)
    _panel_ax.set_ylabel(_panel_ylabel)

# Single shared legend, placed outside the top-right panel
sns.move_legend(ax2, "upper right", bbox_to_anchor=(1.4, 1), frameon=False)

plt.tight_layout()
plt.show()

In [None]:
# Stack the tPSA-annotated samples of all four models; reset_index() keeps the
# previous index as a column.
combined_props_models = pd.concat(
    [
        p2m_sampled_task_df_tpsa,
        diffsbdd_sampled_task_df_tpsa,
        ag4_sampled_task_df_tpsa,
        ligbuilder_sampled_task_df_tpsa,
    ]
).reset_index()


In [None]:
# tPSA density per model
sns.set(rc={'figure.figsize':(10,8), 'legend.frameon':False})
sns.set_style("whitegrid", {"axes.edgecolor": ".0", "axes.facecolor":"none", 'axes.grid' : False})

_tpsa_ax = sns.kdeplot(data=combined_props_models, x="tPSA", common_norm=True, hue="Model", legend=True)
# Only the left limit is fixed at 0
_tpsa_ax.set_xlim(left=0)
plt.tight_layout()
plt.show()

In [None]:
class task4Properties:
    """
    Collect the aggregated score files of tasks 1-3 for one model.

    Parameters
    ----------
        task123_outdir: string
            Directory whose files with "aggregated" in their name are loaded
        model_name: string
            Label stored in the "Model" column of the combined DataFrame
    """
    def __init__(self, task123_outdir, model_name):
        self.task123_outdir = task123_outdir
        self.model_name = model_name
        self.aggregated_files = [os.path.join(self.task123_outdir, i) for i in os.listdir(self.task123_outdir) if "aggregated" in i]

    def run(self):
        """
        Read every aggregated file and stack the results.

        Returns
        -------
            DataFrame with the rows of all readable files plus the columns
            "Task" (file name up to the first "_"), "Task test ID" (file name
            part after "aggregated_scores_") and "Model". The previous index is
            kept as a column by reset_index(). Unreadable files are reported
            and skipped.
        """
        frames = []

        for file in self.aggregated_files:
            try:
                df = pd.read_csv(file)
                # basename instead of split("/") so the file name is also found
                # when the path uses a different separator
                stem = os.path.basename(file).split(".csv")[0]
                df["Task"] = stem.split("_")[0]
                df["Task test ID"] = stem.split("aggregated_scores_")[-1]

                frames.append(df)
            except Exception as e:
                # Best effort: report why the file was skipped and carry on
                print("Not run for file", file, "-", e)

        # pd.concat refuses an empty list; fall back to an empty frame
        final_df = pd.concat(frames) if frames else pd.DataFrame()
        final_df = final_df.reset_index()
        final_df["Model"] = self.model_name

        return final_df

In [None]:
# Aggregated property scores per model; replace each placeholder directory
# with that model's task 1-3 output directory before running.
p2m_props = task4Properties(
    task123_outdir="<output_dir_to_task1_2_3>",
    model_name="Pocket2Mol",
).run()
diffsbdd_props = task4Properties(
    task123_outdir="<output_dir_to_task1_2_3>",
    model_name="DiffSBDD",
).run()
ag4_props = task4Properties(
    task123_outdir="<output_dir_to_task1_2_3>",
    model_name="AutoGrow4",
).run()
ligbuilder_props = task4Properties(
    task123_outdir="<output_dir_to_task1_2_3>",
    model_name="LigBuilderV3",
).run()

# Note: this rebinds combined_props_models, replacing the tPSA table built earlier
combined_props_models = pd.concat(
    [p2m_props, diffsbdd_props, ag4_props, ligbuilder_props]
).reset_index()

In [None]:
# 2x2 grid of per-model densities of the MOSES property distance columns
fig, axes = plt.subplots(nrows=2, ncols=2, sharex=False, sharey=False, figsize=(10,8))
ax1, ax2, ax3, ax4 = axes.flatten()

sns.set(rc={'figure.figsize':(10,8), 'legend.frameon':False})
sns.set_style("whitegrid", {"axes.edgecolor": ".0", "axes.facecolor":"none", 'axes.grid' : False})

# (axes, column, x label, y label, draw legend)
_distance_panels = [
    (ax1, "moses_logp_div", 'logP Wasserstein distance', 'Density', False),
    (ax2, "moses_sa_div", 'SAScore Wasserstein distance', '', True),
    (ax3, "moses_qed_div", 'QED Wasserstein distance', 'Density', False),
    (ax4, "moses_mw_div", 'MW Wasserstein distance', '', False),
]

for _panel_ax, _panel_column, _panel_xlabel, _panel_ylabel, _panel_legend in _distance_panels:
    sns.kdeplot(data=combined_props_models, x=_panel_column, common_norm=True, hue="Model", ax=_panel_ax, legend=_panel_legend)
    _panel_ax.set_xlabel(_panel_xlabel)
    # Only the left limit is fixed at 0
    _panel_ax.set_xlim(left=0)
    _panel_ax.set_ylabel(_panel_ylabel)

# Single shared legend, placed outside the top-right panel
sns.move_legend(ax2, "upper right", bbox_to_anchor=(1.6, 1))

# Applies to the current axes only
plt.xlim(0)
plt.tight_layout()
plt.show()

In [None]:
# Density of the "ave_nSPS" column per model
fig = plt.figure(figsize=(8,6))
sns.set(rc={'figure.figsize':(8,6), 'legend.frameon':False})
sns.set_style("whitegrid", {"axes.edgecolor": ".0", "axes.facecolor":"none", 'axes.grid' : False})

_nsps_ax = sns.kdeplot(data=combined_props_models, x="ave_nSPS", common_norm=True, hue="Model")
_nsps_ax.set_xlabel("Average nSPS")
plt.show()
