# Benchmarking docking and scoring methods with PLEX
## Overview
In this notebook we are running two docking models against the PDBBind benchmark. 
* equibind
* diffdock

We compare the performance of each method using the commonly used RSMD metric for 3D ligand position. 
Taking these models a step further, we combine their pose prediction capability with existing, ML-based, scoring functions, such as ODDT.

## Requirements
In order to run this notebook, you will need: 
* PLEX installed on your device
* PDBBind benchmark data downloaded from [Stärk et al.](https://zenodo.org/record/6408497)
* PDBBind affinity data downloaded from the official [website](https://pdbbind.oss-cn-hangzhou.aliyuncs.com/download/PDBbind_v2020_plain_text_index.tar.gz)

## Learn more 
Head to our [docs](docs.labdao.xyz) to learn more about how to install, use, and contribute to PLEX.

## PLEX setup

In [None]:
import os
import sys
import importlib

# this can disapear once plex is a pip package
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
import plex.sdk
importlib.reload(plex.sdk)

os.environ["PLEX_ACCESS_TOKEN"] = "mellon"
os.environ["PLEX_ENV"] = "stage"

## Running equibind
### Generating IO objects for the PDBBind benchmark

In [None]:
import csv
import os
import json

def create_pdbind_io_dict(csv_path):
    io_data = []
    
    with open(csv_path, 'r') as csvfile:
        csvreader = csv.DictReader(csvfile)
        
        for row in csvreader:
            protein_path = os.path.join("/home/ubuntu/", row['protein_path'])
            ligand_path = os.path.join("/home/ubuntu/", row['ligand_description'])
            
            if not os.path.exists(protein_path) or not os.path.exists(ligand_path):
                print(f"Skipping row {row['complex_name']} due to missing file(s).")
                continue
            
            entry = {
                "tool": "tools/equibind.json",
                "inputs": {
                    "protein": {
                        "class": "File",
                        "filepath": protein_path
                    },
                    "small_molecule": {
                        "class": "File",
                        "filepath": ligand_path
                    }
                },
                "outputs": {
                    "best_docked_small_molecule": {
                        "class": "File",
                        "filepath": ""
                    },
                    "protein": {
                        "class": "File",
                        "filepath": ""
                    }
                },
                "state": "created",
                "errMsg": ""
            }
            
            io_data.append(entry)
    
    return io_data

# Example usage
csv_path = '/home/ubuntu/datasets/diffdock_testdata.csv'
io_sig = create_pdbind_io_dict(csv_path)


In [None]:
from plex.sdk import run_plex

run_plex(io_sig, concurrency=2); # remove semicolon to display outputs

### Run statistics

In [None]:
# generating statistics on the success rate of the runs
import json
import pandas as pd

def get_state_counts(json_filepath):
    # Load the JSON data from the file
    with open(json_filepath, 'r') as f:
        data = json.load(f)
    
    # Extract the "state" and "errMsg" values from each JSON object
    state_errMsg_list = [{'state': item['state'], 'errMsg': item['errMsg']} for item in data]
    
    # Convert the list of dictionaries to a Pandas DataFrame
    df = pd.DataFrame(state_errMsg_list)
    
    # Count the occurrences of each unique "state" and "errMsg" combination
    counts_df = df.groupby(['state', 'errMsg']).size().reset_index(name='count')
    
    return counts_df, df

# Example usage
json_filepath = '/home/ubuntu/plex/0e1b24c5-870e-4a58-9b61-a302cecbbcd0/io.json'
state_counts_df, complete_df = get_state_counts(json_filepath)
print(state_counts_df)

In [None]:
complete_df[complete_df['state'] == 'failed']

### Resubmitting failed tasks

In [None]:
def resubmit_failed_states(json_filepath):
    # Load the JSON data from the file
    with open(json_filepath, 'r') as f:
        data = json.load(f)
    
    # Filter the JSON list to include only entries with a failed state
    failed_entries = [entry for entry in data if entry['state'] == 'failed']
    
    # Create the io_sig object for each failed entry
    io_sig = []
    for entry in failed_entries:
        # Extract the relevant information from the JSON entry
        tool = entry['tool']
        inputs = entry['inputs']
        outputs = entry['outputs']
        state = 'created'  # Set the state to 'created' for resubmission
        errMsg = ''
        
        # Create a new entry for the io_sig object
        new_entry = {
            'tool': tool,
            'inputs': inputs,
            'outputs': outputs,
            'state': state,
            'errMsg': errMsg
        }
        
        # Append the new entry to the io_sig object
        io_sig.append(new_entry)
    
    return io_sig

# Example usage
json_filepath = '/home/ubuntu/plex/0e1b24c5-870e-4a58-9b61-a302cecbbcd0/io.json'
io_sig = resubmit_failed_states(json_filepath)


In [None]:
from plex.sdk import run_plex

run_plex(io_sig, concurrency=6)

In [None]:
print(complete_df)

In [None]:
run_plex(io_sig, concurrency=6)

## Benchmarking predicted Binding Pose

## Benchmarking predicted Binding Affinity

### Preparing PDBBind Affinity data

In [41]:
# 
!wget https://bafybeicl4suczrx7ql2aayegfz6fjg4fs5kplkl2eaj4n4ieheimscrzoi.ipfs.dweb.link/INDEX_general_PL_data.2020

--2023-05-04 04:16:35--  https://bafybeicl4suczrx7ql2aayegfz6fjg4fs5kplkl2eaj4n4ieheimscrzoi.ipfs.dweb.link/INDEX_general_PL_data.2020
Resolving bafybeicl4suczrx7ql2aayegfz6fjg4fs5kplkl2eaj4n4ieheimscrzoi.ipfs.dweb.link (bafybeicl4suczrx7ql2aayegfz6fjg4fs5kplkl2eaj4n4ieheimscrzoi.ipfs.dweb.link)... 209.94.90.1, 2602:fea2:2::1
Connecting to bafybeicl4suczrx7ql2aayegfz6fjg4fs5kplkl2eaj4n4ieheimscrzoi.ipfs.dweb.link (bafybeicl4suczrx7ql2aayegfz6fjg4fs5kplkl2eaj4n4ieheimscrzoi.ipfs.dweb.link)|209.94.90.1|:443... ^C
     PDB_code Kd/Ki
3zzf     2.20    //
3gww     2.46    //
1w8l     1.80    //
3fqa     2.35    //
1zsb     2.00    //




  df = pd.read_csv(plain_text_file, delim_whitespace=True, skiprows=6, header=None, names=column_names, error_bad_lines=False)
b'Skipping line 2430: expected 8 fields, saw 9\nSkipping line 3421: expected 8 fields, saw 9\nSkipping line 3508: expected 8 fields, saw 9\nSkipping line 8591: expected 8 fields, saw 9\nSkipping line 9201: expected 8 fields, saw 9\nSkipping line 17133: expected 8 fields, saw 9\nSkipping line 17383: expected 8 fields, saw 9\nSkipping line 17434: expected 8 fields, saw 9\nSkipping line 17850: expected 8 fields, saw 9\nSkipping line 18069: expected 8 fields, saw 9\nSkipping line 18293: expected 8 fields, saw 9\nSkipping line 18306: expected 8 fields, saw 9\nSkipping line 18368: expected 8 fields, saw 9\nSkipping line 18749: expected 8 fields, saw 9\nSkipping line 19393: expected 8 fields, saw 9\n'


In [42]:
plain_text_file = 'INDEX_general_PL_data.2020'

# Read the plain text file into a pandas DataFrame
column_names = ["PDB_code", "resolution", "release_year", "-logKd/Ki", "Kd/Ki", "reference", "ligand_name"]
df = pd.read_csv(plain_text_file, delim_whitespace=True, skiprows=6, header=None, names=column_names, error_bad_lines=False)

# Select the first and fifth columns
selected_columns = df[["PDB_code", "Kd/Ki"]]

# Display the selected data
print(df.head())

     PDB_code  resolution  release_year   -logKd/Ki Kd/Ki reference  \
3zzf     2.20        2012          0.40    Ki=400mM    //  3zzf.pdf   
3gww     2.46        2009          0.45  IC50=355mM    //  3gwu.pdf   
1w8l     1.80        2004          0.49    Ki=320mM    //  1w8l.pdf   
3fqa     2.35        2009          0.49  IC50=320mM    //  3fq7.pdf   
1zsb     2.00        1996          0.60    Kd=250mM    //  1zsb.pdf   

     ligand_name  
3zzf       (NLG)  
3gww       (SFX)  
1w8l       (1P3)  
3fqa   (GAB&PMP)  
1zsb       (AZM)  




  df = pd.read_csv(plain_text_file, delim_whitespace=True, skiprows=6, header=None, names=column_names, error_bad_lines=False)
b'Skipping line 2430: expected 8 fields, saw 9\nSkipping line 3421: expected 8 fields, saw 9\nSkipping line 3508: expected 8 fields, saw 9\nSkipping line 8591: expected 8 fields, saw 9\nSkipping line 9201: expected 8 fields, saw 9\nSkipping line 17133: expected 8 fields, saw 9\nSkipping line 17383: expected 8 fields, saw 9\nSkipping line 17434: expected 8 fields, saw 9\nSkipping line 17850: expected 8 fields, saw 9\nSkipping line 18069: expected 8 fields, saw 9\nSkipping line 18293: expected 8 fields, saw 9\nSkipping line 18306: expected 8 fields, saw 9\nSkipping line 18368: expected 8 fields, saw 9\nSkipping line 18749: expected 8 fields, saw 9\nSkipping line 19393: expected 8 fields, saw 9\n'
