In [28]:
import pickle

import argparse
import gzip
import multiprocessing as mp
import os
import pickle
import random

import lmdb
import numpy as np
import pandas as pd
import rdkit
import rdkit.Chem.AllChem as AllChem
import torch
from tqdm import tqdm
from biopandas.mol2 import PandasMol2
from biopandas.pdb import PandasPdb
from rdkit import Chem, RDLogger
from rdkit.Chem.MolStandardize import rdMolStandardize


def write_lmdb(data, lmdb_path, map_size=1099511627776):
    """Pickle every item of ``data`` into a single-file LMDB at ``lmdb_path``.

    Items are stored in iteration order under the ASCII keys "0", "1", "2", ...
    inside one write transaction, which commits when the ``with`` block exits.

    Args:
        data: iterable of picklable objects.
        lmdb_path: output file path (opened with ``subdir=False``, i.e. a single
            file rather than a directory).
        map_size: maximum database size in bytes handed to ``lmdb.open``.
            Defaults to 1 TiB (2**40), the value that used to be hard-coded.
    """
    env = lmdb.open(
        lmdb_path,
        subdir=False,
        readonly=False,
        lock=False,
        readahead=False,
        meminit=False,
        map_size=map_size,
    )
    try:
        with env.begin(write=True) as txn:
            for num, item in enumerate(tqdm(data)):
                txn.put(str(num).encode("ascii"), pickle.dumps(item))
    finally:
        # Always release the environment; previously it was never closed.
        env.close()

def read_lmdb(lmdb_path):
    """Load and unpickle every record stored in a single-file LMDB.

    Args:
        lmdb_path: path to the LMDB file (opened read-only with ``subdir=False``).

    Returns:
        ``(records, keys)`` where ``records[i]`` is the unpickled value stored
        under ``keys[i]`` (raw ``bytes`` keys). Entries come back in LMDB cursor
        order, i.e. sorted by the raw key bytes. For keys like "0", "1", ..., "10"
        that order is lexicographic (b"0", b"1", b"10", b"2", ...), not numeric.

    Note:
        ``pickle.loads`` can execute arbitrary code; only read trusted files.
    """
    env = lmdb.open(
        lmdb_path,
        subdir=False,
        readonly=True,
        lock=False,
        readahead=False,
        meminit=False,
        max_readers=256,
    )
    try:
        with env.begin() as txn:
            keys = list(txn.cursor().iternext(values=False))
            out_list = [pickle.loads(txn.get(key)) for key in tqdm(keys)]
    finally:
        # Close the environment even if a read or unpickle fails part-way.
        env.close()
    return out_list, keys

In [10]:
from rdkit.Chem import GetPeriodicTable

# Module-level RDKit periodic table; get_element_name reads it as the global `pt`.
pt = GetPeriodicTable()

def get_element_name(element_list):
    """Translate atomic numbers into element symbols.

    Uses the module-level RDKit periodic table ``pt``.

    Args:
        element_list: iterable of atomic numbers. Plain ints, numpy integer
            scalars and arrays, or 1-D integer torch tensors are all accepted.

    Returns:
        list of element symbols in input order, e.g. ``[1, 6] -> ['H', 'C']``.
    """
    # Cast each element to a plain int before the RDKit lookup. Items from
    # numpy/torch arrays are not Python ints, and the later conversion cell
    # likewise casts with int() before calling GetElementSymbol.
    return [pt.GetElementSymbol(int(atomic_number)) for atomic_number in element_list]

# Quick sanity check: H, C, O, N repeated twice.
print(get_element_name([1, 6, 8, 7] * 2))

['H', 'C', 'O', 'N', 'H', 'C', 'O', 'N']


In [None]:
# Split the processed CrossDocked pocket10 LMDB into the TargetDiff pose train/test sets.
# NOTE(review): assumes the split indices are positions in read_lmdb's key-sorted
# output order -- TODO confirm against how the split file was generated.
path = "/nfs/data/targetdiff_data/crossdocked_v1.1_rmsd1.0_pocket10_processed_final-001.lmdb"
data, keys = read_lmdb(path)

split_path = "/nfs/data/targetdiff_data/crossdocked_pocket10_pose_split.pt"
split = torch.load(split_path)

# Plain ints so they index a Python list (split entries may be tensor scalars).
train_idx = [int(i) for i in split["train"]]
test_idx = [int(i) for i in split["test"]]

# Fail fast on bad indices. A negative index would silently wrap around in list
# indexing, and an out-of-range one would only raise part-way through the copy.
for split_name, split_idx in (("train", train_idx), ("test", test_idx)):
    assert all(0 <= i < len(data) for i in split_idx), (
        f"{split_name} split has indices outside [0, {len(data)})"
    )
# Train and test must not share entries, or the test set leaks into training.
assert not set(train_idx) & set(test_idx), "train and test splits overlap"

train_data = [data[i] for i in train_idx]
test_data = [data[i] for i in test_idx]

print(len(train_data))
print(len(test_data))

write_lmdb(train_data, "/nfs/data/targetdiff_data/train.lmdb")
write_lmdb(test_data, "/nfs/data/targetdiff_data/test.lmdb")
    

In [34]:
from rdkit.Chem import GetPeriodicTable
pt = GetPeriodicTable()

TARGETDIFF_DIR = "/nfs/data/targetdiff_data"


def convert_targetdiff_entry(entry):
    """Re-key one TargetDiff pocket/ligand record into the flat layout written below.

    Args:
        entry: dict read from a TargetDiff LMDB. It must contain these keys:
            ``protein_atom_name``, ``protein_pos``, ``ligand_element``,
            ``ligand_pos``, ``ligand_smiles``, ``protein_filename`` and
            ``ligand_filename``. ``ligand_element`` holds atomic numbers and is
            presumably a torch tensor, given the ``.detach()/.cpu()`` calls.

    Returns:
        dict with these keys: ``pocket_atoms``, ``pocket_coordinates``,
        ``ligand_atoms`` (element symbols), ``ligand_coordinates``, ``smiles``,
        ``protein_name`` and ``ligand_name``.
    """
    # Plain Python ints before the RDKit symbol lookup.
    atomic_numbers = [int(z) for z in entry["ligand_element"].detach().cpu().numpy()]
    return {
        "pocket_atoms": entry["protein_atom_name"],
        # NOTE(review): coordinates are stored exactly as read (not converted to
        # numpy) -- confirm downstream readers expect that type.
        "pocket_coordinates": entry["protein_pos"],
        "ligand_atoms": [pt.GetElementSymbol(z) for z in atomic_numbers],
        "ligand_coordinates": entry["ligand_pos"],
        "smiles": entry["ligand_smiles"],
        "protein_name": entry["protein_filename"],
        "ligand_name": entry["ligand_filename"],
    }


def convert_split(src_path, dst_path):
    """Read the LMDB at ``src_path``, convert every record and write them to ``dst_path``.

    Returns:
        ``(raw_entries, raw_keys, converted)``: the records as read, their LMDB
        keys, and the converted dicts that were written.
    """
    raw_entries, raw_keys = read_lmdb(src_path)
    converted = [convert_targetdiff_entry(entry) for entry in raw_entries]
    write_lmdb(converted, dst_path)
    return raw_entries, raw_keys, converted


train_path = f"{TARGETDIFF_DIR}/train.lmdb"
train_data, keys, new_train = convert_split(train_path, f"{TARGETDIFF_DIR}/train_processed.lmdb")

test_path = f"{TARGETDIFF_DIR}/test.lmdb"
test_data, keys, new_test = convert_split(test_path, f"{TARGETDIFF_DIR}/test_processed.lmdb")


  0%|          | 0/99990 [00:00<?, ?it/s]

422
torch.Size([422, 3])
26
torch.Size([26, 3])
COc1cc(C(C)(C)C#Cc2c(C)nc(N)nc2N)cc(OC)c1OC
DYR_STAAU_2_158_0/4xe6_X_rec_3fqc_55v_lig_tt_docked_4_pocket10.pdb
DYR_STAAU_2_158_0/4xe6_X_rec_3fqc_55v_lig_tt_docked_4.sdf



