# Run Pocket2Mol on a PDB file

colab by [@btnaughton](https://twitter.com/btnaughton)

In [None]:
#@title PDB + ligand input

# PDB_ID: an RCSB PDB ID, an ID starting with "AF" (AlphaFold DB), an http(s) URL,
#   or a path ending in ".pdb". Leave empty to run the built-in example below.
# LIG_ID: residue name of the reference ligand to extract (e.g. "82L").
# CHAIN: chain ID holding that ligand (e.g. "R").
PDB_ID = '' #@param {type:"string"}
LIG_ID = '' #@param {type:"string"}
CHAIN = '' #@param {type:"string"}

# NOTE(review): disabled form option (presumably once "#@markdown"); download_results
# is not referenced in the visible cells.
# markdown Download a tar file containing all results?
# download_results = True #@param {type:"boolean"}

## Install prerequisites

In [None]:
# Time every following cell (ipython-autotime), then install the packages used by
# the helper functions below (RDKit, ProDy, gdown, ...) plus ones presumably needed
# by the Pocket2Mol scripts.
!pip install ipython-autotime --quiet
%load_ext autotime
!pip install rdkit biopython pyyaml easydict tensorboard lmdb gdown prody pypdb --quiet #  replaced python-lmdb with lmdb

In [None]:
# torch is imported here, not installed, so this assumes the runtime (e.g. Colab)
# already provides it. The PyG extension wheels are taken from the index page that
# matches the exact installed torch build string (torch.__version__).
import torch
print(f"torch version {torch.__version__}")
!pip install torch-cluster -f https://data.pyg.org/whl/torch-{torch.__version__}.html --quiet
!pip install torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html --quiet
# torch_geometric itself is installed from the GitHub source tree, not a release.
!pip install git+https://github.com/pyg-team/pytorch_geometric.git  --quiet


## Install Pocket2Mol

In [None]:
# Fetch the Pocket2Mol source; later cells run its scripts from /content/Pocket2Mol.
!git clone https://github.com/pengxingang/Pocket2Mol --quiet

In [None]:
%cd /content/Pocket2Mol/data
# Fetch the CrossDocked pocket archive and split_by_name.pt (presumably the dataset
# split) with gdown, skipping files that already exist.
# NOTE: the shell groups `a || b && c` as `(a || b) && c`, so tar re-extracts the
# archive on every run even when the download itself is skipped.
!test -f crossdocked_pocket10.tar.gz || gdown 10KGuj15mxOJ2FBsduun2Lggzx0yPreEU && tar -xzf crossdocked_pocket10.tar.gz
!test -f split_by_name.pt || gdown 1mycOKpphVBQjxEbpn1AwdpQs8tNVbxKY
!ls -l

In [None]:
%cd /content/Pocket2Mol/ckpt
# Fetch the pretrained Pocket2Mol weights unless they are already present.
!test -f pretrained_Pocket2Mol.pt || gdown 1WaoEj9RDG4VEcyHEmgsjbh958txm1W6x
!ls -l

## Utility PDB functions

In [None]:
import os
import requests
import time
from random import random

def download_pdb_file(pdb_id: str) -> str:
    """Return a local path to a PDB file, downloading and caching it if needed.

    :param pdb_id: one of
        - an http(s) URL: downloaded and cached under its last path component;
        - a path ending in ".pdb": returned unchanged (assumed to exist locally);
        - an ID starting with "AF": fetched from the AlphaFold DB (model_v3 file);
        - anything else: treated as an RCSB PDB ID.
    :return: path to the PDB file on disk (downloads are cached in /tmp/pdb/)
    :raises requests.HTTPError: if a download returns an error status
    """
    PDB_DIR = "/tmp/pdb/"
    os.makedirs(PDB_DIR, exist_ok=True)

    # url or pdb_id
    if pdb_id.startswith('http'):
        url = pdb_id
        filename = url.split('/')[-1]
    elif pdb_id.endswith(".pdb"):
        return pdb_id
    else:
        if pdb_id.startswith("AF"):
            # NOTE(review): AlphaFold DB file versions change over time; v3 is hard-coded.
            url = f"https://alphafold.ebi.ac.uk/files/{pdb_id}-model_v3.pdb"
        else:
            url = f"https://files.rcsb.org/view/{pdb_id}.pdb"
        filename = f'{pdb_id}.pdb'

    cache_path = os.path.join(PDB_DIR, filename)
    if os.path.exists(cache_path):
        return cache_path

    pdb_req = requests.get(url)
    # Only a successful response is written, so errors never end up in the cache.
    pdb_req.raise_for_status()
    with open(cache_path, 'w') as fh:
        fh.write(pdb_req.text)
    return cache_path

In [None]:
from io import StringIO
import os
import sys
from typing import Iterable

import pandas as pd
from prody import parsePDB, writePDB, writePDBStream
from rdkit import Chem
from rdkit.Chem import AllChem
import requests


LIGAND_EXPO_FILENAME = "Components-smiles-stereo-oe.smi"
LIGAND_EXPO_URL = f"http://ligand-expo.rcsb.org/dictionaries/{LIGAND_EXPO_FILENAME}"

def _read_ligand_expo():
    """
    Read Ligand Expo data, try to find a file called
    Components-smiles-stereo-oe.smi in the current directory.
    If you can't find the file, grab it from the RCSB (archived in gs://hx-brian 2023-06-11)
    :return: Ligand Expo as a dictionary with ligand id as the key
    """
    if not os.path.exists(LIGAND_EXPO_FILENAME):
        with open(LIGAND_EXPO_FILENAME, 'wb') as out:
            r = requests.get(LIGAND_EXPO_URL, allow_redirects=True)
            out.write(r.content)

    df = pd.read_csv(LIGAND_EXPO_FILENAME, sep="\t",
                     header=None,
                     names=["SMILES", "ID", "Name"])

    df.set_index("ID", inplace=True)

    return df.to_dict()


def _get_pdb_components(pdb_id):
    """
    Parse a structure with prody and split it into protein and non-protein atoms.
    :param pdb_id: anything prody's parsePDB accepts (a PDB ID or a file path)
    :return: (protein selection, selection of atoms that are neither protein nor
             water); prody gives None for a selection that matches no atoms
    """
    structure = parsePDB(pdb_id)
    return (structure.select('protein'),
            structure.select('not protein and not water'))


def _process_ligand(ligand, res_name, expo_dict,
                    chain=None):
    """
    Add bond orders to a pdb ligand
    1. Select the ligand component with name "res_name"
    2. Get the corresponding SMILES from the Ligand Expo dictionary
    3. Create a template molecule from the SMILES in step 2
    4. Write the PDB file to a stream
    5. Read the stream into an RDKit molecule
    6. Assign the bond orders from the template from step 3
    :param ligand: ligand as generated by prody
    :param res_name: residue name of ligand to extract
    :param expo_dict: dictionary with LigandExpo
    :return: molecule with bond orders assigned
    """

    # If you include all chains then the SDF includes multiple molecules
    # and it looks messed up
    if chain is None:
        print("No chain given, defaulting to chain A. "
              "Not specifying a chain can result in multiple molecules combined into one SDF file", file=sys.stderr)
        chain = "A"

    output = StringIO()
    sub_mol = ligand.select(f"resname {res_name} and chain {chain}")
    if sub_mol is None:
        print(f"sub_mol is None for {res_name}")
        return None

    sub_smiles = expo_dict['SMILES'][res_name]
    print("smiles:", sub_smiles, file=sys.stderr)

    template = AllChem.MolFromSmiles(sub_smiles)
    if template is None:
        print(f"template is None for {sub_smiles}. Returning None.", file=sys.stderr)
        return None

    writePDBStream(output, sub_mol)
    pdb_string = output.getvalue()
    rd_mol = AllChem.MolFromPDBBlock(pdb_string)
    new_mol = AllChem.AssignBondOrdersFromTemplate(template, rd_mol)

    return new_mol, sub_smiles


def _write_pdb(protein, pdb_name,
               output_pdb_name=None):
    """
    Save a prody selection (here, the protein-only atoms) to a PDB file.
    :param protein: prody atom selection to write
    :param pdb_name: stem for the default file name "<pdb_name>_protein.pdb"
    :param output_pdb_name: explicit output path; used instead of the default when truthy
    :return: the output path that was passed to writePDB
    """
    target = output_pdb_name if output_pdb_name else f"{pdb_name}_protein.pdb"
    writePDB(f"{target}", protein)
    print(f"wrote pdb: {target}")
    return target


def _write_sdf(new_mol, pdb_name:str, res_name:str,
               output_sdf_name:str|None=None) -> str:
    """
    Write an RDKit molecule to an SD file
    :param new_mol: RDKit molecule to write
    :param pdb_name: stem used for the default file name
    :param res_name: ligand residue name, also used in the default file name
    :param output_sdf_name: explicit output path; overrides the default
                            "<pdb_name>_<res_name>_ligand.sdf" when truthy
    :return: path of the written SD file
    """
    output_sdf_name = output_sdf_name or f"{pdb_name}_{res_name}_ligand.sdf"
    writer = Chem.SDWriter(f"{output_sdf_name}")
    try:
        writer.write(new_mol)
    finally:
        # Close explicitly so the file is fully flushed before any caller re-opens
        # it (e.g. to rewrite the title line), instead of relying on GC.
        writer.close()
    print(f"wrote ligand sdf: {output_sdf_name}\n")
    return output_sdf_name


def extract_ligands(pdb_name:str,
                    ligand_names:Iterable[str]|None=None,
                    chains:Iterable[str]|None=None,
                    output_pdb_name:str|None=None,
                    output_sdf_name:str|None=None) -> tuple[str, list[str], list[str]]:
    """
    Read Ligand Expo data, split pdb into protein and ligands,
    write protein pdb, write ligand sdf files
    :param pdb_name: id from the pdb, doesn't need to have an extension
    :return:
    """
    if chains is not None:
        assert ligand_names is not None, "chains requires ligand_names"
        assert len(chains) == len(ligand_names), "chains and ligand_names must be the same length"

    # ----------------------------
    # First write out protein part
    #
    df_dict = _read_ligand_expo()
    protein_sel, ligand_sel = _get_pdb_components(pdb_name)
    # write out the pdb with no ligands
    out_pdb_file = _write_pdb(protein_sel, pdb_name, output_pdb_name=output_pdb_name)

    # ----------------------------
    # Then write out ligands
    #
    res_name_list = list(set(ligand_sel.getResnames()))
    out_sdf_files = []
    out_sdf_smiles = []

    for res_name in res_name_list:
        if ligand_names is not None and res_name not in ligand_names:
            continue

        if chains is not None:
            chain = chains[ligand_names.index(res_name)]
        else:
            chain = None

        new_mol, new_mol_smiles = _process_ligand(ligand_sel, res_name, df_dict, chain)
        if new_mol is None:
            print(f"_process_ligand failed for {res_name}. Skipping")
            continue

        out_sdf_files.append(_write_sdf(new_mol, pdb_name, res_name, output_sdf_name=output_sdf_name))
        out_sdf_smiles.append(new_mol_smiles)

    return out_pdb_file, out_sdf_files, out_sdf_smiles


def extract_ligand(pdb_name:str, ligand_name:str,
                   chain=None,
                   output_pdb_name:str|None=None,
                   output_sdf_name:str|None=None) -> tuple[str, str, str]:
    """
    extract_ligands wrapper for a single ligand.

    The first (title) line of each written SDF is replaced by the ligand name and
    its template SMILES, separated by a tab. This lets later code recover the SMILES
    from the file.

    :return: (protein pdb path, ligand sdf path, ligand SMILES)
    :raises ValueError: if no ligand named ligand_name could be extracted
    """
    out_pdb_file, out_sdf_files, out_sdf_smileses = extract_ligands(pdb_name, [ligand_name],
                                                                    [chain] if chain is not None else None,
                                                                    output_pdb_name,
                                                                    output_sdf_name)

    if not out_sdf_files:
        raise ValueError(f"could not extract ligand {ligand_name!r} (chain {chain!r}) from {pdb_name}")

    # add a title
    for out_sdf_file, out_sdf_smiles in zip(out_sdf_files, out_sdf_smileses):
        with open(out_sdf_file) as fh:
            lines = fh.readlines()
        lines[0] = f"{ligand_name}\t{out_sdf_smiles}\n"
        with open(out_sdf_file, 'w') as fh:
            fh.write(''.join(lines))

    return out_pdb_file, out_sdf_files[0], out_sdf_smileses[0]


def extract_all_ligands(pdb_name, lig_id=None):
    """
    Read Ligand Expo data, split pdb into protein and ligands,
    write protein pdb, write ligand sdf files
    :param pdb_name: id from the pdb, doesn't need to have an extension
    :param lig_id: if given, only the residue with this name is extracted
    :return: (protein pdb path, list of ligand sdf paths); ligands that fail to
             process are skipped with a message
    """
    protein, ligand = _get_pdb_components(pdb_name)
    output_pdb_name = _write_pdb(protein, pdb_name)

    output_sdf_names = []
    if ligand is None:
        # prody's select() returns None when no atom matches: there are no ligands.
        print(f"no non-protein, non-water residues found in {pdb_name}", file=sys.stderr)
        return output_pdb_name, output_sdf_names

    # sorted() keeps the output order stable (set order varies between runs).
    res_name_list = sorted(set(ligand.getResnames()))
    df_dict = _read_ligand_expo()

    for res in res_name_list:
        if lig_id is not None and res != lig_id:
            continue

        # No chain is passed, so _process_ligand applies its own default chain.
        result = _process_ligand(ligand, res, df_dict)
        # _process_ligand may signal failure with a bare None; normalize before unpacking.
        new_mol, new_mol_smiles = result if result is not None else (None, None)
        if new_mol is None:
            print(f"_process_ligand failed for {res}. Skipping", file=sys.stderr)
            continue

        output_sdf_name = _write_sdf(new_mol, pdb_name, res)
        # add a title: residue name and SMILES, tab-separated
        with open(output_sdf_name) as fh:
            lines = fh.readlines()
        lines[0] = f"{res}\t{new_mol_smiles}\n"
        with open(output_sdf_name, 'w') as fh:
            fh.write(''.join(lines))
        output_sdf_names.append(output_sdf_name)

    return output_pdb_name, output_sdf_names

In [None]:
%cd /content/Pocket2Mol

if PDB_ID == '':
  print("## Running example: ")
  PDB_ID = "7S15"
  LIG_ID = "82L"
  CHAIN = "R"

pdb_file = download_pdb_file(PDB_ID)
pdb_nohet_file, pdb_het_file, pdb_het_smiles = extract_ligand(pdb_file, LIG_ID, chain=CHAIN)

## Extract the reference ligand's centroid from the SDF file

In [None]:
import numpy as np


def _sdf_atom_coords(sdf_path, element=None):
    """
    Return an (n, 3) array of atom coordinates from the first molecule of a V2000
    SD file. If element is given, only atoms with exactly that symbol are kept.

    V2000 atom lines are fixed-width: x/y/z occupy columns 1-30 (10 chars each) and
    the element symbol columns 32-34. They are therefore sliced by column rather than
    split on whitespace, which breaks when wide coordinates run together
    (e.g. "-100.1234-100.1234"). Only the atom block is read; its length comes from
    the counts line (4th line). This ensures bond and property lines are never
    mistaken for atoms.

    :raises ValueError: for V3000 files, which use a different layout
    """
    with open(sdf_path) as fh:
        lines = fh.readlines()
    counts_line = lines[3]
    if "V3000" in counts_line:
        raise ValueError(f"{sdf_path} is a V3000 SD file; only V2000 is supported")
    num_atoms = int(counts_line[:3])
    coords = []
    for atom_line in lines[4:4 + num_atoms]:
        symbol = atom_line[31:34].strip()
        if element is None or symbol == element:
            coords.append([float(atom_line[0:10]),
                           float(atom_line[10:20]),
                           float(atom_line[20:30])])
    return np.array(coords, dtype=float).reshape(-1, 3)


# Centre the pocket on the reference ligand's carbon atoms. A ligand with no carbon
# falls back to all atoms, because an empty mean would be NaN.
ligand_coords = _sdf_atom_coords(pdb_het_file, element="C")
if len(ligand_coords) == 0:
    ligand_coords = _sdf_atom_coords(pdb_het_file)

sdf_centroid = ligand_coords.mean(axis=0)
print(sdf_centroid)

In [None]:
%cd /content/Pocket2Mol

# example
#!python sample_for_pdb.py --pdb_path "4yhj.pdb" --center " 32.0,28.0,36.0" # e.g., 4yhj

# Patch configs/sample_for_pdb.yml in place (sed -i), which sample_for_pdb.py
# presumably reads. The seed gets a fresh random value on every run, so repeated
# runs sample different molecules; num_samples sets how many to generate.
from random import randint
num_samples = 100
!sed -i -E s/"seed: [0-9]+"/"seed: {randint(1, 10000)}"/ configs/sample_for_pdb.yml
!sed -i -E s/"num_samples: [0-9]+"/"num_samples: {num_samples}"/ configs/sample_for_pdb.yml

# Pass the reference-ligand centroid as the pocket centre. The leading space inside
# the quotes mirrors the example above. Presumably it stops argparse from treating a
# negative first coordinate (e.g. "-12.3,...") as an option flag.
# NOTE(review): pdb_nohet_file is unquoted, so paths containing spaces would break.
centroid_str = '" ' + ','.join(str(x) for x in sdf_centroid) + '"'
!python sample_for_pdb.py --pdb_path {pdb_nohet_file} --center {centroid_str}

## Install smina and gnina, then score binding affinities with gnina

In [None]:
%cd /content/Pocket2Mol
# Download the static smina binary and the gnina v1.0.3 binary, then make both executable.
# NOTE(review): only gnina is invoked in the visible cells; smina appears unused.
!wget https://sourceforge.net/projects/smina/files/smina.static/download --quiet -O smina && chmod +x smina
!wget https://github.com/gnina/gnina/releases/download/v1.0.3/gnina --quiet -O gnina && chmod +x gnina

In [None]:
import glob
from itertools import chain as ichain
import re

import pandas as pd
from tqdm.auto import tqdm

from rdkit import Chem
from rdkit.Chem import Descriptors

rows = []
output_dir = sorted(glob.glob("outputs/sample_for_pdb*"))[-1]
smileses = [l.strip() for l in open(f"{output_dir}/SMILES.txt")]

for sdf_file in tqdm(list(ichain([pdb_het_file], glob.glob(f"{output_dir}/SDF/*.sdf")))):

  if "_ligand" in sdf_file:
    print(sdf_file)
    _, smiles = open(sdf_file).readlines()[0].strip().split('\t')
  else:
    smiles_num = int(re.findall(fr"{output_dir}/SDF/(.+)\.sdf", sdf_file)[0])
    smiles = smileses[smiles_num]

  mol_wt = Descriptors.ExactMolWt(Chem.MolFromSmiles(smiles))

  scored_stdout = !/content/Pocket2Mol/gnina --score_only -r "{pdb_nohet_file}" -l "{sdf_file}"
  scored_affinity = re.findall(r"Affinity:\s*([\-\.\d+]+)", '\n'.join(scored_stdout))[0]
  minimized_stdout = !/content/Pocket2Mol/gnina --local_only --minimize -r "{pdb_nohet_file}" -l "{sdf_file}" --autobox_ligand "{sdf_file}" --autobox_add 2
  minimized_affinity = re.findall(r"Affinity:\s*([\-\.\d+]+)", '\n'.join(minimized_stdout))[0]
  rows.append((pdb_file.split('/')[-1], sdf_file.split('/')[-1], smiles, float(scored_affinity), float(minimized_affinity), mol_wt))


df_aff = (pd.DataFrame(rows, columns=["pdb", "sdf", "smiles", "scored_affinity", "minimized_affinity", "mol_wt"])
            .assign(scored_lig_eff = lambda df: df.scored_affinity / df.mol_wt)
            .assign(minimized_lig_eff = lambda df: df.scored_affinity / df.mol_wt)
            .sort_values("minimized_lig_eff")
)

with pd.option_context('display.max_colwidth', None, 'display.max_rows', None, 'display.max_columns', None):
  display(df_aff.head(10))

## Visualize the top hit with py3Dmol

In [None]:
# py3Dmol renders 3Dmol.js molecular views inline in the notebook.
!pip install py3dmol --upgrade

In [None]:
# Rank by minimized affinity (most negative first), then drop the reference ligand,
# whose SDF name contains "_ligand", and keep the best generated molecule.
ranked = df_aff.sort_values("minimized_affinity")
is_generated = ~ranked.sdf.str.contains("_ligand")
top_hit = ranked[is_generated].iloc[0]
display(pd.DataFrame(top_hit))
top_sdf_file = f'{output_dir}/SDF/{top_hit.sdf}'

In [None]:
import py3Dmol

# 3Dmol.js callbacks (JavaScript source): on hover, label the atom with
# "<chain> <residue name> <residue number>"; on unhover, remove that label.
resid_hover = """
function(atom,viewer) {
    if(!atom.label) {
        atom.label = viewer.addLabel(atom.chain+" "+atom.resn+" "+atom.resi,
            {position: atom, backgroundColor: 'mintcream', fontColor:'black', fontSize:12});
    }
}"""
unhover_func = """
function(atom,viewer) {
    if(atom.label) {
        viewer.removeLabel(atom.label);
        delete atom.label;
    }
}"""

view = py3Dmol.view(width=800, height=800)
view.setCameraParameters({'fov': 35, 'z': 100});

# Model 0: the top-ranked generated molecule, drawn as red sticks.
with open(top_sdf_file) as fh:
    view.addModel(fh.read(), "sdf")
view.setStyle({"model": 0}, {'stick':{"color":"#ff0000"}})
# setViewStyle takes one viewer-wide options object and no atom selection. A leading
# {"model": 0} argument would be read as the options, and the outline never applied.
view.setViewStyle({'style':'outline','color':'black','width':0.1})
view.zoomTo();

# Model 1: the full original PDB entry. The protein is drawn as a cartoon and hetero
# atoms as sticks, so the native ligand is visible for comparison.
with open(pdb_file) as fh:
    view.addModel(fh.read(), "pdb");
view.setStyle({"model": 1}, {"cartoon":{"color":"spectrum"}})
view.setStyle({"model": 1, "hetflag":True}, {'stick':{"color":"spectrum"}})

# Residue hover labels. getModel() with no id presumably returns the most recently
# added model (the PDB) — TODO confirm.
model = view.getModel()
model.setHoverable({}, True, resid_hover, unhover_func)

view