# Uni-Mol Pocket Representation

**Licenses**

Copyright (c) DP Technology.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.

**Citations**

Please cite the following paper if you use this notebook:

- Gengmo Zhou, Zhifeng Gao, Qiankun Ding, Hang Zheng, Hongteng Xu, Zhewei Wei, Linfeng Zhang, Guolin Ke. "[Uni-Mol: A Universal 3D Molecular Representation Learning Framework.](https://chemrxiv.org/engage/chemrxiv/article-details/6318b529bada388485bc8361)"
ChemRxiv (2022)

### Download the pretrained pocket weights and the CASF-2016 data

In [None]:
%%bash
# Stop the cell at the first failing command (e.g. a failed download) instead
# of continuing with a missing archive or checkpoint; -q hides wget's output.
set -euo pipefail
pocket_data_url='https://github.com/deepmodeling/Uni-Mol/releases/download/v0.1/CASF-2016.tar.gz'
pocket_weight_url='https://github.com/deepmodeling/Uni-Mol/releases/download/v0.1/pocket_pre_220816.pt'
# Skip files that are already present: by default wget saves a second copy as
# "<name>.1" instead of overwriting, so re-running the cell would pile up copies.
[ -f "CASF-2016.tar.gz" ] || wget -q "${pocket_data_url}"
tar -xzf "CASF-2016.tar.gz"
[ -f "pocket_pre_220816.pt" ] || wget -q "${pocket_weight_url}"

### Read pocket information from CASF-2016 and save it to a .lmdb file

In [None]:
import os
import pandas as pd
import lmdb
from biopandas.pdb import PandasPdb
from tqdm import tqdm
import pickle
import re
import json
import glob

In [None]:
# Directory the CASF-2016 archive is expected to unpack into; all CASF file
# paths below are built relative to it.
CASF_PATH = "CASF-2016"
# Atom names (after digits are stripped) that parser() labels side=0;
# every other atom name is labelled side=1.
main_atoms = [
    "N",
    "CA",
    "C",
    "O",
    "H",
]

def load_from_CASF(pdb_id):
    """Load the protein structure and pocket residue IDs of one CASF-2016 entry.

    Args:
        pdb_id: CASF-2016 entry ID. ``<pdb_id>_protein.pdb`` is read from
            ``CASF_PATH/casf2016`` and the pocket residues are looked up under
            this key in ``CASF_PATH/casf2016.pocket.json``.

    Returns:
        ``(pmol, pocket_residues)``: the ``PandasPdb`` object and the JSON value
        stored under ``pdb_id``. Presumably a list of "<chain><resnum>" strings,
        matching the ID column ``parser`` builds; confirm against the JSON.

    Raises:
        RuntimeError: if either file cannot be read or parsed, or ``pdb_id`` is
            missing from the pocket JSON. The underlying error is chained.
    """
    pdb_path = os.path.join(CASF_PATH, "casf2016", pdb_id + "_protein.pdb")
    pocket_path = os.path.join(CASF_PATH, "casf2016.pocket.json")
    try:
        pmol = PandasPdb().read_pdb(pdb_path)
        with open(pocket_path, encoding="utf-8") as f:
            pocket_residues = json.load(f)[pdb_id]
    except (OSError, KeyError, ValueError) as err:
        # Raise rather than returning None: callers unpack the result, and a
        # None would surface as an unrelated "cannot unpack" TypeError.
        raise RuntimeError(
            f"Could not load PDB/pocket info for {pdb_id!r} from {CASF_PATH!r}; "
            "only entries shipped with CASF-2016 are currently supported."
        ) from err
    return pmol, pocket_residues

def normalize_atoms(atom):
    """Return ``atom`` with every run of digits removed.

    For example ``"HD21"`` becomes ``"HD"``. A name made only of digits
    becomes ``""``, which ``parser`` uses to drop such atoms.
    """
    # Raw string: "\d" in a normal literal is an invalid escape sequence
    # (DeprecationWarning, SyntaxWarning since Python 3.12).
    return re.sub(r"\d+", "", atom)

def parser(pdb_id):
    """Build the pickled pocket record for one CASF-2016 entry.

    ATOM and HETATM rows whose "<chain_id><residue_number>" ID is in the
    entry's pocket residue list are kept. Atoms whose name is empty after
    digit stripping are dropped.

    Returns:
        ``pickle.dumps`` (highest protocol) of a dict with keys:
        ``atoms`` (digit-stripped atom names), ``coordinates`` (a one-element
        list holding the (n_atoms, 3) x/y/z array), ``side`` (0 for names in
        ``main_atoms``, else 1), ``residue`` (per-atom residue IDs) and
        ``pdbid``.
    """
    pmol, pocket_residues = load_from_CASF(pdb_id)

    def _pocket_rows(frame):
        # Work on a filtered copy with an "ID" column instead of mutating
        # pmol's frames in place.
        ids = frame["chain_id"].astype(str) + frame["residue_number"].astype(str)
        return frame.assign(ID=ids)[ids.isin(pocket_residues)]

    pocket = pd.concat(
        [_pocket_rows(pmol.df["ATOM"]), _pocket_rows(pmol.df["HETATM"])],
        axis=0,
        ignore_index=True,
    )

    pocket["normalize_atom"] = pocket["atom_name"].map(normalize_atoms)
    pocket = pocket[pocket["normalize_atom"] != ""]
    patoms = pocket["normalize_atom"].tolist()
    pcoords = [pocket[["x_coord", "y_coord", "z_coord"]].to_numpy()]
    side = [0 if a in main_atoms else 1 for a in patoms]
    # Reuse the string IDs built before concat: concat may coerce
    # residue_number's dtype (e.g. next to an empty HETATM frame).
    residues = pocket["ID"].tolist()

    return pickle.dumps(
        {
            "atoms": patoms,
            "coordinates": pcoords,
            "side": side,
            "residue": residues,
            "pdbid": pdb_id,
        },
        protocol=-1,
    )

def write_lmdb(pdb_id_list, job_name, outpath="./results"):
    """Write one pickled pocket record per PDB ID into ``<outpath>/<job_name>.lmdb``.

    Any existing file at that path is removed first. Records are stored under
    the ASCII keys ``"0"``, ``"1"``, ... in the iteration order of
    ``pdb_id_list``. All puts happen in a single write transaction: it is
    committed only if every entry parses, and aborted otherwise.
    """
    os.makedirs(outpath, exist_ok=True)
    outputfilename = os.path.join(outpath, job_name + ".lmdb")
    # Start from an empty database; a missing file is the normal first-run case.
    # Any other error (permissions, a directory at that path) is surfaced.
    try:
        os.remove(outputfilename)
    except FileNotFoundError:
        pass
    env_new = lmdb.open(
        outputfilename,
        subdir=False,
        readonly=False,
        lock=False,
        readahead=False,
        meminit=False,
        max_readers=1,
        map_size=int(10e9),
    )
    try:
        # The transaction context manager commits on success and aborts on error.
        with env_new.begin(write=True) as txn_write:
            # tqdm wraps the sized input so the progress bar knows the total.
            for i, pdb_id in enumerate(tqdm(pdb_id_list)):
                txn_write.put(f"{i}".encode("ascii"), parser(pdb_id))
    finally:
        env_new.close()

In [None]:
job_name = 'get_pocket_repr'   # replace with your custom name
data_path = './results'  # replace with your data path
weight_path = 'pocket_pre_220816.pt'  # replace with your checkpoint path
only_polar = 0  # no h; NOTE(review): not passed to the inference command below
dict_name = 'dict_coarse.txt'
batch_size = 16
results_path = data_path   # replace with your save path
# One entry per distinct 4-character PDB ID prefix of the files in casf2016/.
# Sorted so that LMDB key i maps to the same entry on every run: iteration
# order of a set of str changes between sessions due to hash randomization.
# NOTE(review): 3qgy is deliberately excluded; the reason is not recorded here.
casf_collect = sorted(
    {item[:4] for item in os.listdir(os.path.join(CASF_PATH, "casf2016"))}
    - {'3qgy'}
)
write_lmdb(casf_collect, job_name=job_name, outpath=data_path)

### Run inference with the pretrained pocket checkpoint

In [None]:
# NOTE: Inference currently runs on a single GPU only. The command below pins it
# to GPU 0 via CUDA_VISIBLE_DEVICES="0"; change the index to use another device.
# Copy the token dictionary ($dict_name) into the data directory next to the LMDB.
# Presumably infer.py resolves --dict-name relative to $data_path; confirm.
!cp ../example_data/pocket/$dict_name $data_path
!CUDA_VISIBLE_DEVICES="0" python ../unimol/infer.py --user-dir ../unimol $data_path --valid-subset $job_name \
       --results-path $results_path \
       --num-workers 8 --ddp-backend=c10d --batch-size $batch_size \
       --task unimol_pocket --loss unimol_infer --arch unimol_base \
       --path $weight_path \
       --dict-name $dict_name \
       --log-interval 50 --log-format simple --random-token-prob 0 --leave-unmasked-prob 1.0 --mode infer

### Read the .pkl output and save the results to a .csv file

In [None]:
def get_csv_results(predict_path, results_path):
    """Flatten batched inference output into one row per pocket and save it as CSV.

    Args:
        predict_path: pickle file holding an iterable of batch dicts. Each batch
            has ``"bsz"`` plus the per-item sequences ``"data_name"``,
            ``"mol_repr_cls"``, ``"atom_repr"`` and ``"pair_repr"``.
        results_path: directory in which ``mol_repr.csv`` is written.

    Returns:
        The DataFrame written to CSV, with columns ``pdb_id``, ``mol_repr``,
        ``atom_repr`` and ``pair_repr``.
    """
    predict = pd.read_pickle(predict_path)
    pdb_id_list, mol_repr_list, atom_repr_list, pair_repr_list = [], [], [], []
    for batch in predict:
        for i in range(batch["bsz"]):
            pdb_id_list.append(batch["data_name"][i])
            mol_repr_list.append(batch["mol_repr_cls"][i])
            atom_repr_list.append(batch["atom_repr"][i])
            pair_repr_list.append(batch["pair_repr"][i])
    predict_df = pd.DataFrame(
        {
            "pdb_id": pdb_id_list,
            "mol_repr": mol_repr_list,
            "atom_repr": atom_repr_list,
            "pair_repr": pair_repr_list,
        }
    )
    print(predict_df.head(1))
    # DataFrame.info() prints its summary itself and returns None, so it must
    # not be passed to print().
    predict_df.info()
    # NOTE(review): if the repr entries are numpy arrays, to_csv stores their
    # str() form, which numpy abbreviates with "..." for large arrays. Consider
    # pickle/npz if the representations must be recovered losslessly.
    predict_df.to_csv(os.path.join(results_path, "mol_repr.csv"), index=False)
    return predict_df

# The inference step writes "<something>_<job_name>.out.pkl" into results_path.
# Sort the matches so the choice is deterministic if several exist.
pkl_matches = sorted(glob.glob(f'{results_path}/*_{job_name}.out.pkl'))
if not pkl_matches:
    raise FileNotFoundError(
        f"No '*_{job_name}.out.pkl' file found in {results_path!r}; "
        "check that the inference step above completed successfully."
    )
pkl_path = pkl_matches[0]
get_csv_results(pkl_path, results_path)