## Imports and Helpers

In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
import warnings
import os
import sys
from pathlib import Path
# Put the parent of the current working directory on the import path so the
# `ncamol` imports below can resolve.
# NOTE(review): presumably the notebook is run from a folder one level below
# the repository root - confirm.
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))
# Both variables are set before `torch` is imported further down: order the
# GPUs by PCI bus id and expose only the device with index "1" to CUDA.
# NOTE(review): the device index is specific to the machine this was run on.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

from ncamol.models.trainer.trainer import lightning_train_loop
from ncamol.data_prep.voxel_tools.voxel import Voxel_Ligand, Voxel_PDB
from ncamol.data_prep.pdb_tools.prepare_pdb import PDB
from ncamol.visulization import plot_voxel

from Bio import BiopythonWarning
from Bio.PDB import PDBParser

import torch
from torch import nn

# Suppress every warning of category BiopythonWarning for the whole session.
warnings.simplefilter("ignore", BiopythonWarning)
# Restrict cuDNN to deterministic algorithms for reproducible runs.
torch.backends.cudnn.deterministic = True

In [3]:
# Root folder for all protein-pocket data, relative to the working directory.
DATA_DIR = "../data/protein_pockets/"

# Sub-folders below DATA_DIR. Every value keeps its trailing slash, so file
# names can be appended directly.
PDB_PATH = f"{DATA_DIR}complex_structures/"
structure_path = f"{DATA_DIR}ligand_and_pockets/"

STORAGE_DIR_POCKETS = f"{DATA_DIR}pockets/"
# Same folder as `structure_path`.
STORAGE_DIR_LIGANDS_POCKETS = f"{DATA_DIR}ligand_and_pockets/"
STORAGE_DIR_LIGANDS = f"{DATA_DIR}ligands/"
voxel_storage_path = "../data/protein_pockets/voxelized_pockets/"

# Create every folder up front; folders that already exist are left alone.
for path in (
    DATA_DIR,
    PDB_PATH,
    structure_path,
    STORAGE_DIR_POCKETS,
    STORAGE_DIR_LIGANDS_POCKETS,
    voxel_storage_path,
    STORAGE_DIR_LIGANDS,
):
    Path(path).mkdir(parents=True, exist_ok=True)


def clear_all():
    """Run a garbage-collection pass, then empty PyTorch's CUDA cache."""
    from gc import collect

    # Collect first, then hand cached CUDA memory back.
    collect()
    torch.cuda.empty_cache()


def prepare_inputs(
    pockets: list[torch.Tensor], n_hidden: int = 20
) -> torch.Tensor:
    """Assemble the batched model input from voxelised pockets.

    For each pocket of shape ``(C, D, H, W)`` the following channels are
    concatenated along dim 0, in this order:

    * one ``is_air`` channel: 1 where the maximum over the pocket's
      channels is 0, otherwise 0;
    * the ``C`` pocket channels, unchanged;
    * one ``is_alive`` channel: the complement of ``is_air``;
    * ``n_hidden`` hidden channels, each a copy of the ``is_air`` mask.

    Args:
        pockets: Voxelised pockets, each a 4-D tensor ``(C, D, H, W)``.
            All pockets must have the same shape to be batched.
        n_hidden: Number of hidden channels appended to every sample.

    Returns:
        A floating point tensor of shape
        ``(len(pockets), C + n_hidden + 2, D, H, W)``. For an empty
        ``pockets`` list an empty 1-D tensor is returned.
    """
    samples = []

    for pocket in pockets:
        # 1 where no channel is occupied, 0 elsewhere; shape (1, D, H, W).
        is_air = torch.where(
            torch.max(pocket, dim=0)[0].unsqueeze(0) == 0, 1, 0
        )
        # With kernel_size=1 and stride=1 the pooling leaves the mask as it
        # is, so `is_alive` is exactly the complement of `is_air`.
        is_alive = nn.functional.max_pool3d(
            (is_air == 0).float(),
            kernel_size=1,
            stride=1,
        ).int()
        hidden_channels = (is_air == 1).repeat(n_hidden, 1, 1, 1)

        samples.append(
            torch.cat([is_air, pocket, is_alive, hidden_channels], dim=0)
            .unsqueeze(0)
            .to(torch.float)
        )

    # Concatenate once instead of re-copying the growing batch on every
    # iteration. The empty seed tensor is kept so that the result (dtype, and
    # the empty 1-D tensor for an empty input) matches the previous
    # incremental concatenation.
    return torch.cat([torch.tensor([]), *samples], dim=0)


def prepare_outputs(
    pockets: list[torch.Tensor], ligands: list[torch.Tensor]
) -> torch.Tensor:
    """Build the batched training targets from pocket/ligand pairs.

    Each target is the element-wise sum of a pocket and its ligand, clipped
    to the range ``[0, 1]``.

    Args:
        pockets: Voxelised pockets.
        ligands: Voxelised ligands, paired with ``pockets`` by position.
            Pairing uses ``zip``, so surplus entries in the longer list are
            ignored.

    Returns:
        A tensor with one leading batch dimension holding one target per
        pair. For empty input an empty 1-D tensor is returned.
    """
    targets = [
        torch.clip(pocket + ligand, 0, 1).unsqueeze(0)
        for pocket, ligand in zip(pockets, ligands)
    ]
    # Concatenate once instead of re-copying the growing batch on every
    # iteration. The empty seed tensor is kept so that the result (dtype, and
    # the empty 1-D tensor for an empty input) matches the previous
    # incremental concatenation.
    return torch.cat([torch.tensor([]), *targets], dim=0)

## Data Preparation
Prepare the protein/binding pocket and the ligand in an atom-channel representation.


### Protein/Ligand Atom Channel

In [None]:
# Set up the PDB helper for entry "1f8c" and extract its binding pocket.
# NOTE(review): presumably the results are written to the three storage
# folders and `cutoff` is a distance in Angstrom - confirm against
# ncamol.data_prep.pdb_tools.prepare_pdb.
p = PDB(pdb_path=PDB_PATH, pdb_id="1f8c")
p.retrieve_binding_pocket(
    cutoff=12,
    storage_dir_lig=STORAGE_DIR_LIGANDS,
    storage_dir_pocket=STORAGE_DIR_POCKETS,
    storage_dir_lig_and_pocket=STORAGE_DIR_LIGANDS_POCKETS,
)

In [14]:
grid_size = 36 # in Angstrom
step = 1 # in Angstrom
# resolution of the voxel grid is grid_size/step

# Voxelise the pocket and binarise it: every positive entry becomes 1.
# NOTE(review): `_voxelize` is underscore-prefixed, i.e. private by
# convention; the pocket keeps only element [0] of its result, whereas the
# ligand below keeps the whole result - confirm this is intended.
pdb_structure = PDBParser().get_structure("1f8c", f"{STORAGE_DIR_POCKETS}1f8c_05.pdb")
v = Voxel_PDB(pdb_structure, grid_size=grid_size, voxel_size=step)
pocket = v._voxelize()[0]
pocket[pocket > 0] = 1

# STORAGE_DIR_LIGANDS already ends with "/", so the file name is appended
# directly (as for the pocket above) rather than after a second slash.
ligand_struct = PDBParser().get_structure("1f8c_05", f"{STORAGE_DIR_LIGANDS}1f8c_05.pdb")
v_l = Voxel_Ligand(ligand_struct, grid_size=grid_size, voxel_size=step)
# Reuse the pocket's center vector for the ligand before voxelising it.
# NOTE(review): presumably this places both on the same grid - confirm.
v_l.center_vector = v.center_vector
ligand = v_l._voxelize()
ligand[ligand > 0] = 1

In [None]:
# Plot the binarised ligand grid; the value returned by plot_voxel is kept in `vox`.
vox = plot_voxel(ligand)

In [None]:
# Plot the binarised pocket grid; this rebinds `vox`, replacing the value from the ligand plot.
vox = plot_voxel(pocket)

In [None]:
# Turn the voxel grids into tensors.
pocket, ligand = torch.tensor(pocket), torch.tensor(ligand)

# Build a batch holding the single pocket and its pocket+ligand target.
inputs = prepare_inputs(pockets=[pocket])
targets = prepare_outputs(pockets=[pocket], ligands=[ligand])

# Last expression of the cell, so both shapes are displayed.
inputs.shape, targets.shape

In [None]:
# # ids = ["184l_00", "1azm_00", "1fkn_00"]

# # ligands = [
# #     torch.load(f"{VOXEL_STORAGE_PATH}/{id}_ligand.pt").to_dense()[
# #         :-1, 3:33, 3:33, 3:33
# #     ]
# #     for id in ids
# # ]
# # pockets = [
# #     torch.load(f"{VOXEL_STORAGE_PATH}/{id}_protein.pt").to_dense()[
# #         :-1, 3:33, 3:33, 3:33
# #     ]
# #     for id in ids
# # ]

# inputs = prepare_inputs(pockets)
# targets = prepare_outputs(pockets, ligands)

# inputs.shape, targets.shape

## Train

In [23]:
# dtype of the inputs and targets
# A bare expression is only displayed when it is the last statement of a
# cell, so the dtypes are printed explicitly before the conversion.
print(inputs.dtype, targets.dtype)

# make float32
inputs = inputs.to(torch.float32)
targets = targets.to(torch.float32)

In [None]:
# NOTE(review): absolute path into one user's home directory - it will not
# exist on other machines; the rest of the notebook uses relative paths.
logging_path = Path("/home/sebastian/code/molnca/ncamol/models/logs/grow_in_pocket/atom_channels/")

# Keyword arguments forwarded to lightning_train_loop.
# NOTE(review): num_hidden_channels equals the default n_hidden=20 of
# prepare_inputs; presumably the two have to agree - verify when changing
# either one.
multiple_pocket_to_ligand_config_new_loss = dict(
    normal_std=0.04,
    learning_rate=0.001,
    alive_threshold=0.1,
    cell_fire_rate=0.5,
    steps=[48, 64],
    num_categories=8,
    num_hidden_channels=20,
    num_epochs=40_000,
    channel_dims=[56, 56],
    from_pocket=True,
    report_interval=100,
    logging_path=logging_path,
)

# Train on the prepared batch; the returned object is kept in `model`.
model = lightning_train_loop(
    **multiple_pocket_to_ligand_config_new_loss,
    x0=inputs,
    target=targets,
)
