## Growing Molecules from a single Seed
The code below contains everything needed to train a 3D NCA to generate molecular representations from a single seed.

In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2


In [2]:
import warnings
import os
import sys
from pathlib import Path
import pickle
# Put the parent of the current working directory on sys.path — presumably so
# that the `ncamol` imports below resolve when the notebook runs from a sub-folder.
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))
# Select the GPU: order devices by PCI bus id and expose only device "1".
# These variables are set here, ahead of the `import torch` further down.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

from ncamol.models.trainer.trainer import lightning_train_loop
from ncamol.data_prep.voxel_tools.voxel import Voxel_Ligand, Voxel_PDB
from ncamol.data_prep.pdb_tools.prepare_pdb import PDB
from ncamol.visulization import plot_voxel
from ncamol.data_prep.utils import prepare_targets, prepare_inputs
from ncamol.models.utils import flush_cuda_cache
from ncamol.models.model import LitModel
from ncamol.data_prep.utils.conformer_utils import (
    generate_xyz_file_from_conformer,
)
from ncamol.data_prep.electron_density.electron_density import Edens_Dataset

from Bio import BiopythonWarning
from Bio.PDB import PDBParser

import numpy as np
import torch
from rdkit import Chem

# Silence Biopython warnings.
warnings.simplefilter("ignore", BiopythonWarning)
# Ask cuDNN for deterministic algorithms.
# NOTE(review): no random seed is set in the visible cells — confirm whether
# reproducible training runs are expected.
torch.backends.cudnn.deterministic = True

In [3]:
# Root of the protein-pocket data; every directory below is derived from it.
# All values are plain strings WITH a trailing slash because later cells build
# file names as f"{DIR}file.pdb".
DATA_DIR = "../data/protein_pockets/"
PDB_PATH = f"{DATA_DIR}complex_structures/"
structure_path = f"{DATA_DIR}ligand_and_pockets/"

STORAGE_DIR_POCKETS = f"{DATA_DIR}pockets/"
STORAGE_DIR_LIGANDS_POCKETS = f"{DATA_DIR}ligand_and_pockets/"
STORAGE_DIR_LIGANDS = f"{DATA_DIR}ligands/"
# Derived from DATA_DIR like its siblings (same value as the former literal).
voxel_storage_path = f"{DATA_DIR}voxelized_pockets/"

# Create all directories up front; safe to re-run. dict.fromkeys drops the
# repeated entry (structure_path equals STORAGE_DIR_LIGANDS_POCKETS) while
# keeping the original order.
for path in dict.fromkeys(
    [
        DATA_DIR,
        PDB_PATH,
        structure_path,
        STORAGE_DIR_POCKETS,
        STORAGE_DIR_LIGANDS_POCKETS,
        voxel_storage_path,
        STORAGE_DIR_LIGANDS,
    ]
):
    Path(path).mkdir(parents=True, exist_ok=True)

## Data Preparation
Extracting binding pockets from PDB, voxelizing PDBs, and calculating Electron Densities/ Electrostatic potentials


###  Atom Channels

ONLY EXECUTE ONCE PER PDB<br>
Retrieves all binding pockets from a PDB file.<br>
 * the PDB file will be stored in PDB_PATH
 * the pocket PDB (cutoff radius around each ligand) will be stored in STORAGE_DIR_POCKETS
 * pocket + ligand will be stored in STORAGE_DIR_LIGANDS_POCKETS
 * the ligand PDB will be stored in STORAGE_DIR_LIGANDS


In [None]:
# Extract the binding pocket(s) of PDB entry 1afk; results go to the three
# storage directories passed below.
p = PDB(pdb_id="1afk", pdb_path=PDB_PATH)
p.retrieve_binding_pocket(
    storage_dir_pocket=STORAGE_DIR_POCKETS,
    storage_dir_lig_and_pocket=STORAGE_DIR_LIGANDS_POCKETS,
    storage_dir_lig=STORAGE_DIR_LIGANDS,
    cutoff=12,  # pocket cutoff radius; unit not visible here — presumably Å, confirm in PDB
)

Voxelize the pocket/ligands from their respective PDB files.<br>
Each representation will be a numpy array of shape num_atom_channels × n × n × n, with n = grid_size / voxel_size.

the atom channels are:
{
        "C": 0,
        "N": 1,
        "O": 2,
        "S": 3,
        "P": 4,
        "Cl": 5,
        "other": 6,
    }

In [24]:
# --- Pocket -----------------------------------------------------------------
# Parse the ligand+pocket structure written by the extraction cell and voxelize it.
pdb_structure = PDBParser().get_structure("1afk", f"{STORAGE_DIR_LIGANDS_POCKETS}1afk_00.pdb")
v = Voxel_PDB(pdb_structure, grid_size=38, voxel_size=1)
# Only the first element of the voxelizer's return value is used.
# NOTE(review): `_voxelize` is a private method — confirm there is no public entry point.
pocket = v._voxelize()[0]
pocket[pocket > 0] = 1  # binarize: every positive entry becomes 1

# --- Ligand -----------------------------------------------------------------
# STORAGE_DIR_LIGANDS already ends with "/", so no extra separator is added
# (the original produced ".../ligands//1afk_00.pdb").
ligand_struct = PDBParser().get_structure("1afk_00", f"{STORAGE_DIR_LIGANDS}1afk_00.pdb")
v_l = Voxel_Ligand(ligand_struct, grid_size=18, voxel_size=0.5)
# The ligand grid keeps its own centre. Presumably it can be aligned with the
# pocket grid by setting `v_l.center_vector = v.center_vector` before voxelizing
# (this was left disabled in the original notebook).
ligand = v_l._voxelize()
ligand[ligand > 0] = 1  # binarize, as for the pocket

In [None]:
# Visual check of the binarized ligand grid.
vox = plot_voxel(ligand)

In [None]:
# Visual check of the binarized pocket grid.
vox = plot_voxel(pocket)

###  Ligand ED/ESP
Calculate the electron density and electrostatic potential of a molecule


In [4]:
# Directories for the ligand-only workflow (PDB input, xyz files, ED/ESP output).
# Kept as strings with a trailing slash: later cells concatenate file names onto them.
inhibitor_design_path = f"{os.pardir}/data/inhibitor_design/pdbs/ligand/"
inhibitor_design_xyz_path = f"{os.pardir}/data/inhibitor_design/xyz/ligand/"
inhibitor_design_storage = (
    f"{os.pardir}/data/inhibitor_design/electron_densities/ligand/"
)

# Same idempotent creation as for the pocket directories. exist_ok avoids the
# check-then-create race of `if not exists: makedirs`, and a path that exists
# as a regular file now raises instead of being silently skipped.
for path in [inhibitor_design_path, inhibitor_design_xyz_path, inhibitor_design_storage]:
    Path(path).mkdir(parents=True, exist_ok=True)

Same as above, but only save ligand instead of pocket and pocket+ligand from a pdb

In [None]:
# Save only the ligand(s) of PDB entry 1afk into the inhibitor-design PDB directory.
p = PDB(pdb_id="1afk", pdb_path=PDB_PATH)
p.retrieve_ligand(
    storage_dir_lig=inhibitor_design_path
)

Generate xyz files. This is needed for the electron density calculation

In [5]:
# Collect the ligand files in sorted order: os.listdir returns entries in
# arbitrary order, and sorting keeps names/conformers aligned and runs repeatable.
# Sub-directories (e.g. notebook checkpoint folders) are skipped.
ligand_files = sorted(
    entry
    for entry in os.listdir(inhibitor_design_path)
    if os.path.isfile(os.path.join(inhibitor_design_path, entry))
)
# File name without its final extension ("1afk_00.pdb" -> "1afk_00").
ligand_names = [Path(file_name).stem for file_name in ligand_files]

# Chem.MolFromPDBFile returns None when a file cannot be parsed; fail with a
# message naming the file instead of an AttributeError on None.
ligands = []
for file_name in ligand_files:
    mol = Chem.MolFromPDBFile(f"{inhibitor_design_path}{file_name}")
    if mol is None:
        raise ValueError(
            f"RDKit could not parse '{inhibitor_design_path}{file_name}' as a PDB file"
        )
    ligands.append(mol.GetConformer())

# Write one xyz file per conformer into inhibitor_design_xyz_path.
generate_xyz_file_from_conformer(
    ligands,
    names=ligand_names,
    save_path=inhibitor_design_xyz_path,
    complex_path=inhibitor_design_path,
)

Calc Edens and ESP. Result will be stored in a subdir of inhibitor_design_storage for each molecule in inhibitor_design_xyz_path

In [None]:
# Value passed as `n_points` — presumably the number of grid points per axis; confirm in Edens_Dataset.
dim = 70
# Reads the xyz files from inhibitor_design_xyz_path and stores results under inhibitor_design_storage.
edens_prep = Edens_Dataset(
    file_path=inhibitor_design_xyz_path,
    storage_path=inhibitor_design_storage,
    n_points=dim,
    step_size=0.5,
)
# NOTE(review): this calls a private method — confirm there is no public entry point.
edens_prep._compute_electron_density()

Decorate the electron density with the ESP and dilate

In [5]:
LIGAND_id = "1afk_00"  # the single ligand whose stored volumes are loaded

# Both pickles sit in the per-ligand sub-directory of the electron-density storage.
file_path_ed = f"{inhibitor_design_storage}{LIGAND_id}/ed.pkl"
file_path_esp = f"{inhibitor_design_storage}{LIGAND_id}/esp.pkl"

# NOTE: unpickling can execute arbitrary code — only load files you produced yourself.
with open(file_path_esp, "rb") as esp_file, open(file_path_ed, "rb") as ed_file:
    esp = pickle.load(esp_file)
    edens = pickle.load(ed_file)
def preprocess_esp_ed(edens, esp):
    """Down-sample the ESP and electron-density volumes and mask the ESP by density.

    Both volumes pass through the same 3x3x3 all-ones convolution with stride 2
    and padding 1 (a strided box sum), which roughly halves each spatial axis.

    Args:
        edens: electron-density volume; two leading singleton axes are added
            before the convolution.
        esp: electrostatic-potential volume; ONE leading singleton axis is
            added before the convolution.

    Returns:
        Tuple ``(decorated_ed, dense_esp, dense_ed)`` of numpy arrays:
        the down-sampled ESP zeroed wherever the down-sampled density is
        below 1e-4, the down-sampled ESP, and the down-sampled density.
    """
    # Fixed all-ones kernel: each output voxel is the sum over its 3x3x3 neighbourhood.
    box_sum = torch.nn.Conv3d(1, 1, 3, stride=2, padding=1, bias=False)
    box_sum.weight.data.fill_(1)

    esp_tensor = torch.tensor(esp, dtype=torch.float)[None]
    ed_tensor = torch.tensor(edens, dtype=torch.float)[None, None]

    # No gradients are needed; this also makes the results convertible to numpy.
    with torch.no_grad():
        dense_esp = box_sum(esp_tensor).squeeze().numpy()
        dense_ed = box_sum(ed_tensor).squeeze().numpy()

    # 1 where the summed density reaches the threshold, 0 elsewhere.
    density_mask = np.where(dense_ed < 1e-4, 0, 1)
    decorated_ed = dense_esp * density_mask
    return decorated_ed, dense_esp, dense_ed

In [7]:
# Density-masked ESP plus the two down-sampled volumes.
ed_esp, dense_esp, dense_ed = preprocess_esp_ed(edens, esp)

# Two binary float channels: [0] marks positive ESP, [1] marks negative ESP.
dialated = np.stack([ed_esp > 0, ed_esp < 0]).astype(float)

In [None]:
# Show the two binary ESP-sign channels.
plot_voxel(dialated)

## Reconstruction

### Grow ligand without pocket [ED]

In [8]:
# Log/checkpoint directory, relative to the notebook's working directory rather
# than an absolute home-directory path, so the notebook runs on other machines.
# Assumes the kernel's working directory is a direct sub-folder of the
# repository root (the same assumption as the "../data" paths and the
# sys.path setup at the top of the notebook).
logging_path = Path(os.pardir) / "ncamol" / "models" / "logs" / "grow_molecule" / "edesp"

# Keyword arguments forwarded to lightning_train_loop via ** in the training cell.
seed_to_ligand_config = {
    "normal_std": 0.01,
    "learning_rate": 0.001,
    "alive_threshold": 0.1,
    "cell_fire_rate": 0.5,
    "steps": [48, 64],  # presumably [min, max] NCA steps per rollout — confirm in lightning_train_loop
    "num_categories": 3,  # presumably the 2 ESP-sign channels + 1 background channel — confirm
    "num_hidden_channels": 12,
    "num_epochs": 40_000,
    "channel_dims": [42, 42],
    "from_pocket": False,
    "report_interval": 100,
    "logging_path": logging_path,
}

NOTE: It is best to crop away most of the empty space. Since the ED/ESP representation is a surface-only representation, the resulting voxel grid is extremely sparse, which can lead to training instabilities.

In [None]:
# Crop the sparse ESP grid to a hand-picked window before building the batch.
cropped_volume = dialated[:, 9:27, 6:32, 4:]
targets = prepare_targets(cropped_volume, bs=2)

# Seed coordinates: centre of the (cropped) spatial grid.
seed_position = np.array(targets.shape[-3:]) // 2
inputs = prepare_inputs(
    targets,
    seed_coords=seed_position,
    num_categories=seed_to_ligand_config["num_categories"],
)

inputs.shape, targets.shape

In [None]:
# Plot channels 1-2 of the first target in the batch.
plot_voxel(targets[0, 1:3].detach().numpy())

In [None]:
# Train from the seed inputs towards the ED/ESP targets; all hyper-parameters
# come from seed_to_ligand_config.
model = lightning_train_loop(
    x0=inputs, target=targets, **seed_to_ligand_config
)

Inference

In [42]:
# NOTE(review): the import cell at the top uses `from ncamol.models.model import LitModel`;
# this import re-binds the name — confirm which of the two paths is the intended one.
from ncamol.models.model.LitModel import LitModel

# Checkpoint location relative to the working directory instead of an absolute
# home-directory path (assumes the kernel runs from a direct sub-folder of the
# repository root, like the "../data" paths used elsewhere in this notebook).
checkpoint_path = (
    Path(os.pardir) / "ncamol" / "models" / "logs" / "grow_molecule" / "edesp"
    / "atom_channel_reconstruction-v2.ckpt"
)
model = LitModel.load_from_checkpoint(str(checkpoint_path))

# Use the GPU when one is available, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
inputs = inputs.to(device)

In [None]:
# Average the prediction over rollouts of 48..64 steps. Inference only, so
# autograd is disabled: otherwise the graphs of all 17 rollouts stay alive
# until the stacked result is detached.
with torch.no_grad():
    out = (
        torch.stack([model.forward(inputs, steps=step) for step in range(48, 65)])
        .detach()
        .cpu()
        .mean(0)
    )
# Channels 1-2 of the first sample, thresholded at 0.1.
plot_voxel(out[0, 1:3].numpy() > 0.1)

In [None]:
# Ground truth for comparison: channels 1-2 of the first target.
plot_voxel(targets[0, 1:3].detach().numpy())

### Grow ligand without pocket [Atom Channel] 

In [5]:
# NOTE(review): this directory ("../models/logs/...") is not the one the
# inference cell further down loads its checkpoint from (".../ncamol/models/logs/...")
# — confirm which location is intended.
logging_path = Path("../models/logs/grow_molecule/atomchannel/")

# Keyword arguments forwarded to lightning_train_loop via ** in the training cell.
seed_to_ligand_config = {
    "normal_std": 0.01,
    "learning_rate": 0.001,
    "alive_threshold": 0.1,
    "cell_fire_rate": 0.5,
    "steps": [48, 64],  # presumably [min, max] NCA steps per rollout — confirm in lightning_train_loop
    "num_categories": 8, # 7 atom channels + solvent/air channel
    "num_hidden_channels": 12,
    "num_epochs": 20_000,
    "channel_dims": [42, 42],
    "from_pocket": False,
    "report_interval": 100,
    "logging_path": logging_path,
    "loss": ["mse", "iou"],
    "batch_size": 4,
}

In [25]:
# Build a batch of 4 targets from the voxelized ligand, with the configured
# number of hidden channels.
hidden_channels = seed_to_ligand_config["num_hidden_channels"]
targets = prepare_targets(ligand, bs=4, num_hidden=hidden_channels)

# Seed coordinates: centre of the spatial grid.
# NOTE(review): unlike the ED/ESP cell, `num_categories` is not passed here —
# confirm that the default of prepare_inputs matches the 8 categories configured above.
grid_centre = np.array(targets.shape[-3:]) // 2
inputs = prepare_inputs(targets, seed_coords=grid_centre)

In [None]:
# Train from the seed inputs towards the atom-channel targets; all
# hyper-parameters come from seed_to_ligand_config.
model = lightning_train_loop(
    x0=inputs, target=targets, **seed_to_ligand_config
)

Inference

In [28]:
# NOTE(review): the import cell at the top uses `from ncamol.models.model import LitModel`;
# this import re-binds the name — confirm which of the two paths is the intended one.
from ncamol.models.model.LitModel import LitModel

# Checkpoint location relative to the working directory instead of an absolute
# home-directory path (assumes the kernel runs from a direct sub-folder of the
# repository root, like the "../data" paths used elsewhere in this notebook).
checkpoint_path = (
    Path(os.pardir) / "ncamol" / "models" / "logs" / "grow_molecule" / "atomchannel"
    / "atom_channel_reconstruction-v4.ckpt"
)
model = LitModel.load_from_checkpoint(str(checkpoint_path))

# Use the GPU when one is available, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
inputs = inputs.to(device)

In [None]:
# Average the prediction over rollouts of 40..71 steps (a wider range than the
# [48, 64] used in the training config). Inference only, so autograd is
# disabled and no graph is built for any of the 32 rollouts.
with torch.no_grad():
    out = torch.stack(
        [model.forward(inputs, steps=step).detach().cpu() for step in range(40, 72)]
    ).mean(0)

In [None]:
# Ground truth: channels 1-7 of the first target in the batch.
plot_voxel(targets[0, 1:8].detach().cpu().numpy())

In [None]:
# Averaged prediction: channels 1-7 of the first sample, thresholded at 0.1.
plot_voxel(out[0, 1:8].detach().cpu().numpy() > 0.1)