# HEqBM Build Training Dataset #

In [None]:
# IPython autoreload in mode 2: imported modules are reloaded before each cell
# executes, so edits to local modules are picked up without a kernel restart.
%reload_ext autoreload
%autoreload 2

In [None]:
import os
import numpy as np

from os.path import dirname
from build_dataset import build_npz_dataset

In [None]:
# Root folder that every input/output path below is joined onto.
YOUR_PATH_TO_DATA_FOLDER = "/storage_common/angiod/"


def _in_data_folder(relative_path):
    """Return *relative_path* joined onto YOUR_PATH_TO_DATA_FOLDER."""
    return os.path.join(YOUR_PATH_TO_DATA_FOLDER, relative_path)


# Locations shared by more than one entry, resolved once.
_zma_source = _in_data_folder("A2A/Omar")
_chc_source = _in_data_folder("paolo/chc/CHC")
_pdb6k_structures = _in_data_folder("PDB6K/pdb.6k/augment")
_pdb6k_train_filter = _in_data_folder("PDB6K/set/targets.train.pdb.2.9k")
_pdb6k_valid_filter = _in_data_folder("PDB6K/set/targets.valid.pdb.72")

# One entry per dataset that can be built; the entry selected through `system`
# is the one handed to build_npz_dataset.
# NOTE(review): the trajectory-format key is spelled "traj_format" in some
# entries and "trajformat" in "A2A.martini3.train" -- confirm which spelling
# build_npz_dataset actually reads.
config_dict = {
    # --- small molecules ---
    "zma.train": {
        "mapping": "zma",
        "input": _zma_source,
        "inputformat": "gro",
        "inputtraj": _zma_source,
        "traj_format": "xtc",
        "selection": "resname ZMA",
        "trajslice": ":500",
        "output": _in_data_folder("A2A/Omar/backmapping/npz/train/zma.npz"),
        "isatomistic": True,
    },
    "chc.train": {
        "mapping": "chc",
        "input": _chc_source,
        "inputformat": "pdb",
        "inputtraj": _chc_source,
        "traj_format": "xtc",
        "selection": "all",
        "output": _in_data_folder("paolo/chc/backmapping/npz/train/"),
        "isatomistic": True,
    },
    # --- POPC membrane, martini3 mapping ---
    "membrane.martini3.train": {
        "mapping": "martini3.membrane",
        "input": _in_data_folder("membrane/charmmgui/train/"),
        "inputformat": "pdb",
        # "inputtraj": os.path.join(YOUR_PATH_TO_DATA_FOLDER, "membrane/ali/"),
        # "traj_format": "xtc",
        "selection": "resname POPC",
        #"trajslice": ":",
        "output": _in_data_folder("membrane/backmapping/martini3/train/"),
    },
    "membrane.martini3.valid": {
        "mapping": "martini3.membrane",
        "input": _in_data_folder("membrane/charmmgui/valid/"),
        "inputformat": "pdb",
        # "inputtraj": os.path.join(YOUR_PATH_TO_DATA_FOLDER, "membrane/ali/"),
        # "traj_format": "xtc",
        "selection": "resname POPC",
        #"trajslice": ":",
        "output": _in_data_folder("membrane/backmapping/martini3/valid/"),
    },
    # --- A2A protein, martini3 mapping ---
    "A2A.martini3.train": {
        "mapping": "martini3",
        "input": _in_data_folder("A2A/tpr/a2a.tpr"),
        "inputtraj": _in_data_folder("A2A/trr/"),
        "trajformat": "trr",
        "selection": "protein",
        "trajslice": ":50",
        "output": _in_data_folder("A2A/backmapping/npz/martini3/train/"),
    },
    # --- PDB6K proteins, martini3 mapping ---
    "PDB6K.martini3.train": {
        "mapping": "martini3",
        "input": _pdb6k_structures,
        "filter": _pdb6k_train_filter,
        "inputformat": "pdb",
        "selection": "protein",
        "output": _in_data_folder("PDB6K/backmapping/npz/martini3.2.9k/train"),
    },
    "PDB6K.martini3.valid": {
        "mapping": "martini3",
        "input": _pdb6k_structures,
        "filter": _pdb6k_valid_filter,
        "inputformat": "pdb",
        "selection": "protein",
        "output": _in_data_folder("PDB6K/backmapping/npz/martini3.2.9k/valid"),
    },
    # --- PDB6K proteins, ca mapping ---
    "PDB6K.CA.train": {
        "mapping": "ca",
        "input": _pdb6k_structures,
        "filter": _pdb6k_train_filter,
        "inputformat": "pdb",
        "selection": "protein",
        "output": _in_data_folder("PDB6K/backmapping/npz/ca.2.9k/trainnew"),
    },
    "PDB6K.CA.valid": {
        "mapping": "ca",
        "input": _pdb6k_structures,
        "filter": _pdb6k_valid_filter,
        "inputformat": "pdb",
        "selection": "protein",
        "output": _in_data_folder("PDB6K/backmapping/npz/ca.2.9k/valid"),
    },
    # --- PED proteins, ca mapping ---
    "PED.CA.train": {
        "mapping": "ca",
        "input": _in_data_folder("PED/pdb/train"),
        "inputformat": "pdb",
        "selection": "protein",
        "trajslice": ":1000",
        "output": _in_data_folder("PED/backmapping/npz/ca.train"),
    },
    "PED.CA.valid": {
        "mapping": "ca",
        "input": _in_data_folder("PED/pdb/valid"),
        "inputformat": "pdb",
        "selection": "protein",
        "output": _in_data_folder("PED/backmapping/npz/ca.valid"),
    },
}

In [None]:
# Key of the config_dict entry to build; it is looked up in config_dict below.
system = "zma.train"

In [None]:
# Fail fast on a mistyped system name instead of handing None to the builder.
if system not in config_dict:
    raise KeyError(
        f"Unknown system {system!r}; available systems: {sorted(config_dict)}"
    )

# build_npz_dataset is iterated as (dataset, config) pairs; each non-None
# dataset is written to the "output" path carried by its config.
for npz_dataset, config in build_npz_dataset(config_dict[system], skip_if_existent=False):
    if npz_dataset is None:
        continue
    output_dir = dirname(config["output"])
    if output_dir:
        # os.makedirs("") raises, so only create a directory when the output
        # path actually has a directory component.
        os.makedirs(output_dir, exist_ok=True)
    # Each key of the dataset mapping becomes a named array in the archive.
    np.savez(config["output"], **npz_dataset)

In [None]:
# Print the name and shape of every entry in the dataset left over from the
# last iteration of the build loop. That loop treats None as a possible
# dataset, so guard before calling .items() on it.
if npz_dataset is None:
    print("No dataset to inspect: the last build iteration produced None.")
else:
    for name, array in npz_dataset.items():
        print(name, array.shape)