## Imports and Helpers

In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import warnings
import os
import sys
# Put the parent of the current working directory on the import path.
# Presumably this is what makes the local `ncamol` package importable below — TODO confirm.
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))
# Enumerate CUDA devices in PCI bus order and expose only device "1" to this
# process. These are set before `torch` is imported further down.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

from ncamol.models.trainer.trainer import train_loop_ext_force
import pickle
from Bio import BiopythonWarning
import numpy as np
import torch
import os
from pathlib import Path
from ncamol.data_prep.electron_density.electron_density import Edens_Dataset
from rdkit import Chem
from rdkit.Chem import AllChem
from ncamol.visulization import plot_voxel
from ncamol.data_prep.utils import prepare_targets
from ncamol.models.utils import flush_cuda_cache
from ncamol.models.model import LitModel, LitModel_external_force
from ncamol.models.dataset.multiconf import MultiConfDataset

# Silence every BiopythonWarning for the rest of the session.
warnings.simplefilter("ignore", BiopythonWarning)
# Request deterministic cuDNN kernels.
# NOTE(review): no random seed is set in the visible cells, so conformer
# embedding and training are still not reproducible run-to-run.
torch.backends.cudnn.deterministic = True

In [None]:
# Directory the isomer structure files are written to / read from.
isomer_path = "../data/cis_trans/mol_files/"
# Directory passed to Edens_Dataset as `storage_path`; the ED/ESP pickles are
# later loaded from its "cis/" and "trans/" subfolders.
isommer_edens_storage = "../data/cis_trans/electron_density/"

# exist_ok=True makes the cell idempotent and avoids the check-then-create race
# of `if not os.path.exists(...): os.makedirs(...)`. It also fails loudly if one
# of the paths already exists as a regular file.
for path in [isomer_path, isommer_edens_storage]:
    os.makedirs(path, exist_ok=True)

def preprocess_esp_ed(edens, esp):
    """Box-sum the ED and ESP grids and mask the ESP by the summed density.

    Both grids go through the same all-ones 3x3x3 convolution (stride 1,
    zero padding 1, no bias), i.e. each voxel becomes the sum of its
    neighbourhood. With this stride/padding the spatial size is unchanged.

    Args:
        edens: Electron-density grid. A batch axis and a channel axis are
            prepended before the convolution.
        esp: Electrostatic-potential grid. A single leading axis is prepended
            before the convolution.

    Returns:
        tuple: ``(decorated_ed, dense_esp, dense_ed)`` as numpy arrays, where
        ``decorated_ed`` is ``dense_esp`` multiplied by a 0/1 mask that is 0
        wherever ``dense_ed < 0.0001``. All size-1 axes are squeezed away.
    """
    box_filter = torch.nn.Conv3d(1, 1, 3, stride=1, padding=1, bias=False)
    torch.nn.init.ones_(box_filter.weight)

    def _box_sum(grid):
        # No autograd graph is needed; the result is handed straight to numpy.
        with torch.no_grad():
            return box_filter(grid).squeeze().numpy()

    dense_esp = _box_sum(torch.tensor(esp, dtype=torch.float).unsqueeze(0))
    dense_ed = _box_sum(torch.tensor(edens, dtype=torch.float)[None, None])

    # 1 where the summed density reaches the threshold, 0 elsewhere.
    occupancy = np.where(dense_ed < 0.0001, 0, 1)
    return dense_esp * occupancy, dense_esp, dense_ed

def load_ed_esp(file_path):
    """Load the pickled electron density and electrostatic potential.

    Args:
        file_path: Directory containing ``ed.pkl`` and ``esp.pkl``. A trailing
            path separator is optional (str or ``os.PathLike``).

    Returns:
        tuple: ``(edens, esp)`` — the objects unpickled from ``ed.pkl`` and
        ``esp.pkl``, in that order.

    Raises:
        FileNotFoundError: If either pickle file is missing.
    """
    # SECURITY: pickle.load can execute arbitrary code; only point this at
    # files produced by this project's own pipeline.
    loaded = []
    for file_name in ("ed.pkl", "esp.pkl"):
        # os.path.join (not string concatenation) so that "dir" and "dir/"
        # both resolve to "dir/<file_name>".
        with open(os.path.join(file_path, file_name), "rb") as f:
            loaded.append(pickle.load(f))
    edens, esp = loaded
    return edens, esp

## Data Preparation
Compute the electron density (ED) and the electrostatic potential (ESP),
and generate the conformers.

### Ligand ED/ESP
and conformer generation


In [3]:
def save_molecule_conf(molecule, path):
    """Write a molecule, with its hydrogens removed, to an SD file.

    Args:
        molecule: RDKit molecule to save. It is passed through
            ``Chem.RemoveHs`` and the result is what gets written.
        path: Destination handed to ``Chem.SDWriter``.
    """
    heavy_atom_mol = Chem.RemoveHs(molecule)
    writer = Chem.SDWriter(path)
    try:
        writer.write(heavy_atom_mol)
    finally:
        # Close even if write() raises, so the file handle is released and
        # whatever was buffered is flushed.
        writer.close()


In [None]:
# SMILES of the two isomers; they differ only in the bond-direction markers
# around the C=C double bond. Raw strings are required: in a plain literal
# "\c" and "\C" are invalid escape sequences (DeprecationWarning, and
# SyntaxWarning from Python 3.12). The string values are unchanged.
mol_cis = r"Oc1cc(O)cc(/C=C\c2cc(O)cc(O)c2)c1"
mol_trans = r"Oc1cc(O)cc(\C=C\c2cc(O)cc(O)c2)c1"

# Explicit hydrogens are added before 3D embedding.
molH_cis = Chem.AddHs(Chem.MolFromSmiles(mol_cis))
molH_trans = Chem.AddHs(Chem.MolFromSmiles(mol_trans))

# cis: embed a 3D conformer, then relax it with the MMFF force field.
AllChem.EmbedMolecule(molH_cis)
AllChem.MMFFOptimizeMolecule(molH_cis)

# trans: embedded only.
# NOTE(review): unlike cis, the trans conformer is not MMFF-optimized and is
# not exported as XYZ here (a later cell saves it as SDF) — confirm this
# asymmetry is intended. The embedding is also unseeded, so coordinates differ
# between runs.
AllChem.EmbedMolecule(molH_trans)
save_path = "../data/cis_trans/mol_files/"
AllChem.rdmolfiles.MolToXYZFile(molH_cis, save_path + "cis.xyz")


In [None]:
# Write the embedded trans structure as an SD file into the same directory as
# the cis XYZ file. Uses `molH_trans` and `save_path` from the previous cell;
# `save_molecule_conf` strips hydrogens before writing.
save_molecule_conf(molH_trans, save_path + "trans_conf.sdf")


In [4]:
# Value passed to Edens_Dataset as `n_points`
# (presumably grid points per axis — TODO confirm against Edens_Dataset).
dim = 30
# Presumably reads the structure files under `isomer_path` and stores its
# results under `isommer_edens_storage` — inferred from the keyword names only;
# verify against Edens_Dataset. Units of `step_size` are not visible here.
edens_prep = Edens_Dataset(
    file_path=isomer_path,
    storage_path=isommer_edens_storage,
    n_points=dim,
    step_size=1.5,
)
# NOTE(review): this calls an underscore-prefixed (private) method directly.
# Likely the expensive step of the notebook — consider guarding it with a cache check.
edens_prep._compute_electron_density()

In [5]:
# Load the pickled ED and ESP grids of each isomer from the "cis/" and "trans/"
# subfolders of the electron-density storage directory.
edens_cis, esp_cis = load_ed_esp(isommer_edens_storage + "cis/")
edens_trans, esp_trans = load_ed_esp(isommer_edens_storage + "trans/")

In [None]:
# Density threshold below which a voxel is treated as empty.
ED_THRESHOLD = 0.0001
# Crop window applied after a leading channel axis has been added:
# (channel, x, y, z) -> all channels, 5:25, 5:25, 9:22.
CROP_WINDOW = (slice(None), slice(5, 25), slice(5, 25), slice(9, 22))

# All-ones 3x3x3 kernel, stride 1, zero padding 1: each output voxel is the sum
# of its neighbourhood and the grid size is unchanged.
conv3d = torch.nn.Conv3d(1, 1, 3, stride=1, padding=1, bias=False)
conv3d.weight.data.fill_(1)


def _box_sum_esp(esp):
    """Return the 3x3x3 box-summed ESP grid as a numpy array (size-1 axes squeezed)."""
    with torch.no_grad():
        return (
            conv3d(torch.tensor(esp, dtype=torch.float).unsqueeze(0))
            .squeeze()
            .numpy()
        )


def _decorate(edens, dense_esp):
    """Zero the summed ESP where the raw density is below ED_THRESHOLD; add a leading axis."""
    # NOTE(review): the mask uses the raw density here, whereas
    # `preprocess_esp_ed` masks with the box-summed density — confirm which is intended.
    return np.expand_dims(np.where(edens < ED_THRESHOLD, 0, 1) * dense_esp, axis=0)


dense_esp_cis = _box_sum_esp(esp_cis)
dense_esp_trans = _box_sum_esp(esp_trans)

full_decorated_cis = _decorate(edens_cis, dense_esp_cis)
full_decorated_trans = _decorate(edens_trans, dense_esp_trans)

decorated_ed_cis = full_decorated_cis[CROP_WINDOW]
decorated_ed_trans = full_decorated_trans[CROP_WINDOW]

# Sanity check: the crop must not cut away any non-zero voxel.
# Non-zero counts are compared instead of float sums, because `==` on sums
# depends on summation order and positive/negative values outside the window
# could cancel. Both results are shown (only a cell's last expression is
# displayed, so a bare first comparison would be silently dropped).
crop_keeps_all_cis = np.count_nonzero(decorated_ed_cis) == np.count_nonzero(
    full_decorated_cis
)
crop_keeps_all_trans = np.count_nonzero(decorated_ed_trans) == np.count_nonzero(
    full_decorated_trans
)

crop_keeps_all_cis, crop_keeps_all_trans

In [8]:
def _sign_channels(decorated):
    """Concatenate a 0/1 mask of negative voxels and one of positive voxels along axis 0."""
    negative_mask = np.where(decorated < 0, 1, 0)
    positive_mask = np.where(decorated > 0, 1, 0)
    return np.concatenate((negative_mask, positive_mask), axis=0)


# Channel 0 marks voxels with negative decorated ESP, channel 1 positive ones;
# voxels that are exactly zero are 0 in both channels.
channeled_ed_esp_cis = _sign_channels(decorated_ed_cis)
channeled_ed_esp_trans = _sign_channels(decorated_ed_trans)

plot_voxel(channeled_ed_esp_trans)

## Train

### Pretrain

In [18]:
# Turn the two-channel voxel grids into the inputs used for training
# (presumably adds a batch axis of size `bs` and `num_hidden` extra channels —
# TODO confirm against prepare_targets). `num_hidden=24` matches
# "num_hidden_channels" in the training configs below.
cis_input = prepare_targets(channeled_ed_esp_cis, bs=1, num_hidden=24)
trans_input = prepare_targets(channeled_ed_esp_trans, bs=1, num_hidden=24)


In [None]:
# Dataset built from the cis and trans inputs, in pre-training mode.
ds = MultiConfDataset(cis_input, trans_input, pretrain=True)
# Log directory; the fine-tuning cell also points its checkpoint at a file in it.
logging_path = Path("../models/logs/cis_trans/edens/")
# Keyword arguments forwarded verbatim to train_loop_ext_force.
isomer_config = {
    "normal_std": 0.01,
    "learning_rate": 0.001,
    "alive_threshold": 0.1,
    "cell_fire_rate": 0.5,
    "steps": [48, 64],
    "num_categories": 3,
    # Same value as `num_hidden` passed to prepare_targets above.
    "num_hidden_channels": 24,
    "num_epochs": 2500,
    "channel_dims": [56, 56],
    "report_interval": 100,
    "logging_path": logging_path,
    "pretrain": True
}

# NOTE(review): the result is bound to a single name here, while the
# fine-tuning cell unpacks three values (states, losses, model) from the same
# function — confirm the return shape really depends on `pretrain`.
model = train_loop_ext_force(data=ds, **isomer_config)

### Finetune

In [None]:
# Same inputs as the pre-training dataset, with pretrain=False.
ds_ft = MultiConfDataset(cis_input, trans_input, pretrain=False)


# Same values as `isomer_config`, except that "pretrain" is dropped and a
# "checkpoint" path inside the logging directory is added.
isomer_config_finetune = {
    "normal_std": 0.01,
    "learning_rate": 0.001,
    "alive_threshold": 0.1,
    "cell_fire_rate": 0.5,
    "steps": [48, 64],
    "num_categories": 3,
    "num_hidden_channels": 24,
    "num_epochs": 2500,
    "channel_dims": [56, 56],
    "report_interval": 100,
    "logging_path": logging_path,
    "checkpoint": logging_path / "ckpt.pt"
}

# Continues from the `model` produced by the pre-training cell.
# NOTE(review): `model` is rebound here, so re-running this cell starts from
# the already fine-tuned model, not the pre-trained one. `at_epoch=2500` equals
# "num_epochs" of the pre-training config — presumably the epoch to resume
# from; confirm against train_loop_ext_force.
states, losses, model = train_loop_ext_force(
    data=ds_ft,
    **isomer_config_finetune,
    model=model,
    at_epoch=2500
)
