## Imports and helpers


In [1]:
# autoreload
%load_ext autoreload
%autoreload 2

In [2]:
import warnings
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

from ncamol.models.trainer.trainer import lightning_train_loop
from ncamol.data_prep.voxel_tools.voxel import Voxel_Ligand
from ncamol.visulization import plot_voxel
from ncamol.models.utils import flush_cuda_cache
from ncamol.models.model import LitModel

from Bio import BiopythonWarning
from Bio.PDB import PDBParser
import numpy as np
import torch

warnings.simplefilter("ignore", BiopythonWarning)
torch.backends.cudnn.deterministic = True

In [None]:
small_protein_pdbs = "../data/small_proteins/pdb/"
save_path = "../data/small_proteins/voxelized/"

# Make sure the PDB source directory and the voxel output directory exist.
# exist_ok=True replaces the exists()-then-makedirs() pattern, which could
# race if the directory appeared between the check and the creation.
for path in [save_path, small_protein_pdbs]:
    os.makedirs(path, exist_ok=True)

def clear_all():
    """Free memory between experiments.

    First forces a Python garbage-collection pass so unreachable objects
    (including tensors) are released, then asks PyTorch's CUDA caching
    allocator to return its unused cached blocks.
    """
    import gc as _gc

    for release in (_gc.collect, torch.cuda.empty_cache):
        release()

def prepare_inputs(
    pockets: list[torch.Tensor], n_hidden: int = 12
) -> torch.Tensor:
    """Turn a list of voxel grids into one batched NCA state tensor.

    For every grid, the channel-wise maximum ``m`` is computed and the
    channels are laid out as::

        [m == 0, *pocket channels, m == 1, (m == 1) repeated n_hidden times]

    Args:
        pockets: Voxel grids with channels on dim 0 (e.g. ``(C, X, Y, Z)``).
            All grids must share the same shape so they can be stacked.
            Values are presumably binary (0/1) occupancies -- TODO confirm.
        n_hidden: Number of hidden channels to append. Each one is
            initialised to the ``m == 1`` mask.

    Returns:
        A float32 tensor of shape ``(len(pockets), C + 2 + n_hidden, ...)``.
        When ``pockets`` is empty, an empty 1-D tensor is returned, which
        matches the previous behaviour.
    """
    if not pockets:
        return torch.tensor([])

    batch = []
    for pocket in pockets:
        # Channel-wise maximum, keeping a leading channel dim of size 1.
        max_per_voxel = torch.max(pocket, dim=0)[0].unsqueeze(0)
        occupied = max_per_voxel == 1
        empty = max_per_voxel == 0
        # Repeat along the channel dim only; works for any spatial rank.
        hidden = occupied.repeat(n_hidden, *([1] * (occupied.dim() - 1)))
        state = torch.cat([empty, pocket, occupied, hidden], dim=0)
        batch.append(state.to(torch.float))

    # Stack once instead of re-concatenating the growing batch every iteration.
    return torch.stack(batch, dim=0)


def load_inputs(path="../data/small_proteins/voxelized/", n_channels=3):
    """Load every ``*input.npy`` voxel grid found in ``path``.

    Files are read in sorted filename order, so the result is deterministic.
    Presumably this also keeps it aligned with the matching ``*target.npy``
    files -- verify the naming scheme.

    Args:
        path: Directory containing the ``.npy`` grids. A trailing path
            separator is optional.
        n_channels: Number of leading channels (dim 0) to keep from each grid.

    Returns:
        A list of tensors, one per matching file.
    """
    files = sorted(f for f in os.listdir(path) if f.endswith("input.npy"))
    # os.path.join works whether or not ``path`` ends with a separator.
    return [
        torch.tensor(np.load(os.path.join(path, f))[:n_channels, ...])
        for f in files
    ]


def load_targets(path="../data/small_proteins/voxelized/", n_channels=3):
    """Load every ``*target.npy`` voxel grid found in ``path``.

    Files are read in sorted filename order, so the result is deterministic.
    Presumably this also keeps it aligned with the matching ``*input.npy``
    files -- verify the naming scheme.

    Args:
        path: Directory containing the ``.npy`` grids. A trailing path
            separator is optional.
        n_channels: Number of leading channels (dim 0) to keep from each grid.

    Returns:
        A list of tensors, one per matching file.
    """
    files = sorted(f for f in os.listdir(path) if f.endswith("target.npy"))
    # os.path.join works whether or not ``path`` ends with a separator.
    return [
        torch.tensor(np.load(os.path.join(path, f))[:n_channels, ...])
        for f in files
    ]

## Data Preparation


### Prepare the protein backbone representation


PDBs were preprocessed in PyMOL, e.g.:
 * fetch 1aho
 * show sticks
 * remove solvent | not polymer | hydrogens | sidechain
 * (input structures only: remove resi 33-37, saved as e.g. `1ahodel33_37`)
 * save PATH/TO/SAVE


In [15]:
# For each protein, voxelize the complete structure (saved as the target) and
# its residue-deleted variant (saved as the input) on the same grid.
proteins = ["1aho", "1sp1", "3nir"]
protein_inputs = ["1ahodel33_37", "1sp1del18_22", "3nirdel35_38"]

GRID_SIZE = 40
AGGREGATION = "surround"

# A single parser instance is reused for every structure; it is loop-invariant.
parser = PDBParser()

for protein, protein_input in zip(proteins, protein_inputs):
    # Target: the complete structure (first model only).
    file = f"{small_protein_pdbs}{protein}.pdb"
    structure = parser.get_structure(protein, file)[0]

    v_t = Voxel_Ligand(
        structure,
        grid_size=GRID_SIZE,
        aggregation=AGGREGATION,
    )

    voxels = v_t._voxelize()
    np.save(f"{save_path}{protein}_target.npy", voxels)

    # Input: the truncated structure, voxelized around the target's
    # center_vector, presumably so both grids are spatially aligned.
    file = f"{small_protein_pdbs}{protein_input}.pdb"
    structure = parser.get_structure(protein_input, file)[0]

    v = Voxel_Ligand(
        structure,
        grid_size=GRID_SIZE,
        aggregation=AGGREGATION,
        center_vector=v_t.center_vector,
    )

    voxels = v._voxelize()
    np.save(f"{save_path}{protein}_input.npy", voxels)


In [None]:
# Plot the most recently voxelized grid. After the voxelization loop this is
# the input (residue-deleted) structure of the last protein processed.
plot_voxel(voxels)

## Train

In [19]:
# Load the voxelized inputs and targets, then convert both to batched state
# tensors with the same channel layout (prepare_inputs' default n_hidden).
inputs = load_inputs()
targets = load_targets()

# Inputs and targets are paired by position, so the counts must match.
if len(inputs) != len(targets):
    raise ValueError(
        f"Found {len(inputs)} input grids but {len(targets)} target grids; "
        "every *_input.npy needs a matching *_target.npy."
    )

prepared_inputs = prepare_inputs(inputs)
prepared_targets = prepare_inputs(targets)

In [None]:
# `Path` is not imported anywhere else in this notebook; import it here so the
# cell does not raise NameError.
from pathlib import Path

logging_path = Path("../models/logs/backbone_recon/atom_channels/")
backbone_config = {
    "normal_std": 0.01,
    "learning_rate": 0.001,
    "alive_threshold": 0.1,
    "cell_fire_rate": 0.5,
    "steps": [48, 64],
    "num_categories": 4,
    # NOTE(review): should match the n_hidden used by prepare_inputs above
    # (its default is 12) -- keep the two in sync.
    "num_hidden_channels": 12,
    "num_epochs": 20000,
    "channel_dims": [34, 34],
    "from_pocket": False,
    "report_interval": 50,
    "logging_path": logging_path,
}

states, losses, model = lightning_train_loop(
    x0=prepared_inputs, target=prepared_targets, **backbone_config
)
