# HEqBM Build Training Dataset #

In [None]:
import os
from os.path import basename
import glob
import torch
import numpy as np
from typing import Optional
from heqbm.mapper import HierarchicalMapper
from heqbm.utils import DataDict
from heqbm.utils.plotting import plot_cg

# Explicitly set float32 as the default floating-point dtype for tensors torch creates from here on.
torch.set_default_dtype(torch.float32)

In [None]:
# Root folder that holds every input structure/trajectory and receives the generated npz files.
# Edit this to match your machine.
YOUR_PATH_TO_DATA_FOLDER = "/storage_common/angiod/"


def _data_path(relative_path):
    """Return `relative_path` joined onto YOUR_PATH_TO_DATA_FOLDER."""
    return os.path.join(YOUR_PATH_TO_DATA_FOLDER, relative_path)


# One entry per system that can be processed by this notebook. Keys of each entry:
#   mapping_root        - folder of the mapping definitions handed to HierarchicalMapper
#   structure_folder_in - folder scanned for structure files
#   structure_format    - extension of the structure files to pick up
#   traj_folder_in      - (optional) folder scanned for trajectory files
#   traj_format         - (optional) extension of the trajectory files
#   extra_kwargs        - (optional) extra options forwarded inside the mapper configuration
#   selection           - atom selection string handed to the mapper
#   frames              - (optional) frame limit handed to the mapper
#   npz_folder_out      - folder where the resulting .npz files are written
config_dict = {
    # --- PED ensembles ---
    "PED_TRAIN": {
        "mapping_root": "./heqbm/data/mappings_hierarchical_protein",
        "structure_folder_in": _data_path("PED/pdb/train"),
        "structure_format": "pdb",
        "selection": "protein",
        "npz_folder_out": _data_path("PED/backmapping/npz/train"),
    },
    "PED_VALID": {
        "mapping_root": "./heqbm/data/mappings_hierarchical_protein",
        "structure_folder_in": _data_path("PED/pdb/valid"),
        "structure_format": "pdb",
        "selection": "protein",
        "npz_folder_out": _data_path("PED/backmapping/npz/valid"),
    },
    # --- PED ensembles, "protein_ca" mapping ---
    "PED_CA_TRAIN": {
        "mapping_root": "./heqbm/data/mappings_hierarchical_protein_ca",
        "structure_folder_in": _data_path("PED/pdb/train"),
        "structure_format": "pdb",
        "selection": "protein",
        "frames": 100,
        "npz_folder_out": _data_path("PED/backmapping/npz/ca_train"),
    },
    "PED_CA_VALID": {
        "mapping_root": "./heqbm/data/mappings_hierarchical_protein_ca",
        "structure_folder_in": _data_path("PED/pdb/valid"),
        "structure_format": "pdb",
        "selection": "protein",
        "frames": 100,
        "npz_folder_out": _data_path("PED/backmapping/npz/ca_valid_20_frames"),
    },
    # --- A2A: protein and membrane selections of the same structure folder ---
    "A2A": {
        "mapping_root": "./heqbm/data/mappings_hierarchical_protein",
        "structure_folder_in": _data_path("A2A/tpr/"),
        "structure_format": "tpr",
        "traj_folder_in": _data_path("A2A/trr"),
        "traj_format": "trr",
        "selection": "protein",
        "frames": 250,
        "npz_folder_out": _data_path("A2A/backmapping/npz/protein/train/"),
    },
    "MEMBRANE": {
        "mapping_root": "./heqbm/data/mappings_hierarchical_membrane",
        "structure_folder_in": _data_path("A2A/tpr/"),
        "structure_format": "tpr",
        "traj_folder_in": _data_path("A2A/trr/membrane"),
        "traj_format": "trr",
        "selection": "resname CHL PC PA OL",
        "frames": 15,
        "npz_folder_out": _data_path("A2A/backmapping/npz/membrane/train/"),
    },
    # --- POPC ---
    "POPC": {
        "mapping_root": "./heqbm/data/mappings_hierarchical_membrane_standard",
        "structure_folder_in": _data_path("POPC/"),
        "structure_format": "gro",
        "traj_folder_in": _data_path("POPC/"),
        "traj_format": "trr",
        "selection": "resname POPC",
        "frames": 100,
        "npz_folder_out": _data_path("membrane/train/"),
    },
    # --- ZMA ligand ---
    "ZMA": {
        "mapping_root": "./heqbm/data/mappings_hierarchical_zma",
        "structure_folder_in": _data_path("LIGANDS/ZMA/atomistic/"),
        "structure_format": "gro",
        # The trajectory folder may be left out (None) when the structure file
        # already carries all frames, e.g. a multi-model pdb.
        "traj_folder_in": _data_path("LIGANDS/ZMA/atomistic/"),
        "traj_format": "xtc",
        "selection": "resname ZMA",
        "npz_folder_out": _data_path("LIGANDS/ZMA/backmapping/npz/train/"),
    },
    # --- MiniG ---
    "MiniG": {
        "mapping_root": "./heqbm/data/mappings_hierarchical_protein",
        "structure_folder_in": _data_path("miniG/prmtop/"),
        "structure_format": "prmtop",
        "traj_folder_in": _data_path("miniG/crd/"),
        "traj_format": "crd",
        "extra_kwargs": {"format": "TRJ"},
        "selection": "protein",
        "frames": 200,
        "npz_folder_out": _data_path("miniG/backmapping/npz/train/"),
    },
}

In [None]:
# Name of the system to process; must be one of the keys of `config_dict`.
system = "POPC"
if system not in config_dict:
    # Fail here with a readable message rather than with an AttributeError on
    # `None` a few lines further down.
    raise KeyError(f"Unknown system {system!r}; available systems: {sorted(config_dict)}")
conf = config_dict[system]

# Entries that every system configuration defines; a missing one is reported
# immediately as a KeyError instead of silently becoming None.
mapping_root = conf["mapping_root"]
structure_folder_in = conf["structure_folder_in"]
structure_format = conf["structure_format"]
npz_folder_out = conf["npz_folder_out"]

# Optional entries: absent (None) for systems configured without separate
# trajectory files.
traj_folder_in = conf.get("traj_folder_in")
traj_format = conf.get("traj_format")

In [None]:
def get_ds(
        filename: str,
        traj_folder_in: Optional[str] = None,
        traj_format: str = "trr",
        selection: str = 'protein',
        keep_backbone: bool = False,
        frame_limit: Optional[float] = np.inf,
        extra_kwargs: Optional[dict] = None,
    ):
    """Map one structure file with a HierarchicalMapper and collect the fields to export.

    The mapper is created from the notebook-level ``mapping_root`` variable.

    Args:
        filename: path of the structure file to map.
        traj_folder_in: folder searched for trajectory files whose name starts
            with the structure file's base name (text before the first dot).
            ``None`` skips the trajectory lookup.
        traj_format: extension of the trajectory files looked up in
            ``traj_folder_in``.
        selection: selection string forwarded to ``HierarchicalMapper.map``.
        keep_backbone: when False, the bead-to-atom relative vectors indexed by
            ``DataDict.CA_BEAD_IDCS`` (second axis) are set to zero.
        frame_limit: frame limit forwarded to ``HierarchicalMapper.map``.
            ``None`` is treated like the default ``np.inf``.
        extra_kwargs: optional extra options stored under ``"extra_kwargs"`` in
            the configuration handed to the mapper.

    Returns:
        Tuple ``(mapping, npz_ds)``: the mapper instance and a dict with the
        subset of ``mapping.dataset`` that is meant to be saved.
    """
    # Callers read the limit with `conf.get("frames", None)`, so a missing
    # "frames" entry arrives here as None; fall back to the declared default.
    if frame_limit is None:
        frame_limit = np.inf

    conf = {
        "keep_hydrogens": False,
        "structure_filename": filename,
    }
    if extra_kwargs is not None:
        conf["extra_kwargs"] = extra_kwargs

    if traj_folder_in is not None:
        structure_stem = basename(filename).split('.')[0]
        traj_filenames = glob.glob(os.path.join(traj_folder_in, f"{structure_stem}*.{traj_format}"))
        # Only set the key when at least one trajectory file was found.
        if traj_filenames:
            conf["traj_filenames"] = traj_filenames

    mapping = HierarchicalMapper(root=mapping_root)
    mapping.map(conf, selection=selection, frame_limit=frame_limit)
    dataset = mapping.dataset

    if not keep_backbone:
        dataset[DataDict.BEAD2ATOM_RELATIVE_VECTORS][:, dataset[DataDict.CA_BEAD_IDCS]] = 0.

    # Dataset fields that end up in the npz file; everything else is dropped.
    exported_keys = [
        DataDict.ATOM_POSITION, DataDict.BEAD_POSITION, DataDict.ATOM_NAMES,
        DataDict.BEAD_NAMES, DataDict.ATOM_TYPES, DataDict.BEAD_TYPES,
        DataDict.BEAD2ATOM_RELATIVE_VECTORS, DataDict.BB_PHIPSI,
        DataDict.LEVEL_IDCS_MASK, DataDict.LEVEL_IDCS_ANCHOR_MASK,
        DataDict.BEAD2ATOM_IDCS, DataDict.CA_NEXT_DIRECTION,
        DataDict.BOND_IDCS, DataDict.ANGLE_IDCS, DataDict.CELL, DataDict.PBC,
    ]
    npz_ds = {k: v for k, v in dataset.items() if k in exported_keys}

    return mapping, npz_ds

In [None]:
# Convert every structure file of the selected system into an .npz dataset file
# and print the `type_names` snippet to paste into the training configuration.
os.makedirs(npz_folder_out, exist_ok=True)
for filename in glob.glob(os.path.join(structure_folder_in, f"*.{structure_format}")):
    try:
        filename_out = os.path.join(npz_folder_out, f"{basename(filename).split('.')[0]}.npz")
        # Uncomment to skip structures that were already converted:
        # if os.path.isfile(filename_out):
        #     continue

        mapping, npz_ds = get_ds(
            filename=filename,
            traj_folder_in=traj_folder_in,
            traj_format=traj_format,
            selection=conf.get("selection"),
            # No "frames" entry means "no limit"; np.inf is get_ds' own default.
            frame_limit=conf.get("frames", np.inf),
            extra_kwargs=conf.get("extra_kwargs", None),
            keep_backbone=True,
        )

        # Check for a missing dataset before indexing into it.
        if npz_ds is None:
            print(f"Skipping file {filename}. No dataset was produced")
            continue

        print(filename_out, npz_ds[DataDict.ATOM_POSITION].shape)
        np.savez(filename_out, **npz_ds)

        # Bead type names ordered by the value they map to in `mapping.bead_types`.
        bead_type_names = [
            name for name, _ in sorted(mapping.bead_types.items(), key=lambda item: item[1])
        ]
        config_update_text = (
            'Update the training configuration file with the following snippet (excluding quotation marks):\n'
            '            \n"\ntype_names:\n'
            + ''.join(f'- {bt}\n' for bt in bead_type_names)
            + '"'
        )
        print(config_update_text)
    except TypeError as err:
        # Best-effort: keep going with the remaining files, but show the actual
        # error so that causes other than a broken resid can be diagnosed.
        print(f"Skipping file {filename}. Most probably the resid is messed up ({err})")

### Show an example mapping ###

In [None]:
# Map the first structure file found for the selected system and plot its first frame.
sample_filenames = glob.glob(os.path.join(structure_folder_in, f"*.{structure_format}"))
if not sample_filenames:
    # Without this check an empty folder would leave `mapping` undefined, or
    # silently reuse the one left over from an earlier cell.
    raise FileNotFoundError(
        f"No '*.{structure_format}' structure file found in {structure_folder_in}"
    )
sample_filename = sample_filenames[0]

mapping, npz_ds = get_ds(
    filename=sample_filename,
    traj_folder_in=traj_folder_in,
    traj_format=traj_format,
    selection=conf.get("selection"),
    # No "frames" entry means "no limit"; np.inf is get_ds' own default.
    frame_limit=conf.get("frames", np.inf),
    extra_kwargs=conf.get("extra_kwargs", None),
    keep_backbone=True,
)

plot_cg(mapping.dataset, frame_index=0)

In [None]:
# Interactive 3D view of `mapping.u` (the `mapping` produced by a previous cell).
# nglview is imported only in this cell, so the rest of the notebook runs without it.
import nglview as nv
# default_representation=False: presumably starts with an empty view, so that only
# the representations added below are drawn — TODO confirm against nglview docs.
w = nv.show_mdanalysis(mapping.u, default_representation=False)
# NOTE(review): the selection strings look like NGL selection language
# (`_H` = element, `.NAME` = atom name, bare number / `a-b` = residue number or range) — confirm.
# The residue number 12 and the range 1-20 are hard-coded for the example system.
w.add_ball_and_stick(selection='(not _H) and (12) and not (.CB .CG .OE1 .OE2)')
# Same residue, CA atom only, drawn in orange.
w.add_ball_and_stick(selection='(.CA) and (12)', color='orange')
# Semi-transparent cartoon for the 1-20 selection.
w.add_representation('cartoon', selection='1-20', color='aqua', opacity=0.2)
w.center()
# The widget is the last expression of the cell so that the notebook renders it.
w