In [None]:
# Automatically reload imported modules (including the project package `heqbm`)
# before each cell executes, so edits to the package are picked up without
# restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [None]:
import torch
from heqbm.utils.pdbFixer import joinPDBs
from heqbm.backmapping.hierarchical_backmapping import HierarchicalBackmapping

# Newly created floating-point tensors default to single precision
# (`torch.float` is the same dtype object as `torch.float32`).
torch.set_default_dtype(torch.float)

In [None]:
# Backmapping run configuration.
# Exactly ONE `args_dict = {...}` below should be active. The commented-out dicts
# are alternative systems (ligands, protein, membrane) kept for quick switching:
# to change system, comment out the active dict and uncomment another one.
# Any later `args_dict = {...}` replaces (does not merge with) an earlier one.

# Alternative: drive the whole run from a single YAML config file.
# args_dict = {
#     "config": "config/backmapping/PDB6K-martini3-bbcommon-geqmodel.yaml",
# }
# args_dict = {
#     "mapping": "zma",
#     "input": "/storage_common/angiod/A2A/Omar/prot_zma.gro",
#     "inputtraj": "/storage_common/angiod/A2A/Omar/md_all_prot_mol.xtc",
#     "trajslice": "4005:4006",
#     "selection": "resname ZMA",
#     "output": "backmapped/ZMA/atomistic/",
#     "model": "config/training/ZMA-geqmodel.yaml",
#     "modelweights": "best_model.pth",
#     "isatomistic": True,
# }
# args_dict = {
#     "mapping": "zma",
#     "input": "/storage_common/angiod/A2A/Vince/11-daniele.pdb",
#     "selection": "resname ZMA",
#     "output": "backmapped/A2A-ZMA-POPC/ZMA/",
#     "model": "config/training/ZMA-geqmodel.yaml",
#     "modelweights": "best_model.pth",
#     "isatomistic": False,
# }
# args_dict = {
#     "mapping": "martini3",
#     "input": "/storage_common/angiod/A2A/Vince/11-daniele.pdb",
#     "selection": "protein",
#     "output": "backmapped/A2A-ZMA-POPC/A2A/",
#     "model": "config/training/A2A-martini3-bbcommon-geqmodel.yaml",
#     "modelweights": "best_model.pth",
#     "isatomistic": False,
# }

# Active configuration: POPC lipids of the A2A-ZMA-POPC system.
args_dict = {
    "mapping": "martini3.membrane",
    "input": "/storage_common/angiod/A2A/Vince/11-daniele.pdb",
    "selection": "resname POPC",
    # NOTE(review): presumably a "start:stop" frame slice, i.e. a single frame here — confirm.
    "trajslice": "69:70",
    "output": "backmapped/A2A-ZMA-POPC/POPC/",
    "model": "config/training/MEM-martini3-geqmodel.yaml",
    "modelweights": "best_model.pth",
    "isatomistic": False,
}
# args_dict = {
#     "mapping": "chc",
#     "input": "/storage_common/angiod/paolo/chc/CHC/CHC_complete.pdb",
#     "inputtraj": "/storage_common/angiod/paolo/chc/CHC/md10_npt_10ps.xtc",
#     "trajslice": ":1",
#     "selection": "resname CHC",
#     "output": "backmapped/CHC/atomistic/",
#     "model": "config/training/CHC-geqmodel.yaml",
#     "modelweights": "best_model.pth",
#     "isatomistic": True,
# }
# args_dict = {
#     "mapping": "chc",
#     "input": "/storage_common/angiod/LIGANDS/CHC/CG/chc_only.gro",
#     "inputtraj": "/storage_common/angiod/LIGANDS/CHC/CG/chc_only_nopbc.xtc",
#     "trajslice": ":1",
#     "selection": "resname CHC",
#     "output": "backmapped/CHC/atomistic/",
#     "model": "config/training/CHC-geqmodel.yaml",
#     "modelweights": "best_model.pth",
#     "isatomistic": False,
# }

# Runtime options shared by every configuration above.
# Use the first CUDA device when one is present; otherwise fall back to CPU so
# the notebook does not fail outright on hosts without a GPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
# NOTE(review): presumably caps the number of atoms per model batch (lower it if
# memory runs out) — confirm against HierarchicalBackmapping.
BATCH_MAX_ATOMS = 3000

args_dict.update({
    "device": DEVICE,
    "batch_max_atoms": BATCH_MAX_ATOMS,
})
backmapping = HierarchicalBackmapping(args_dict=args_dict)

In [None]:
# Backmap every mapping yielded by `backmapping.map()` and collect the returned
# file paths into four flat lists: backmapped, backmapped+minimised, reference
# ("true") and coarse-grained input.
backmapped_filenames = []
backmapped_minimised_filenames = []
true_filenames = []
cg_filenames = []

for cg_mapping in backmapping.map():
    # Unpacking (rather than zipping) keeps the check that exactly four
    # collections are returned.
    (
        step_backmapped,
        step_backmapped_minimised,
        step_true,
        step_cg,
    ) = backmapping.backmap(cg_mapping, optimise_backbone=False)
    backmapped_filenames += step_backmapped
    backmapped_minimised_filenames += step_backmapped_minimised
    true_filenames += step_true
    cg_filenames += step_cg

# Join the PDB files tagged "backmapped" in the configured output directory.
joinPDBs(backmapping.config.get("output"), "backmapped")

In [None]:
import MDAnalysis as mda

def show(backmapped_filenames, cg_filenames, true_filenames, licorice_selection="protein"):
    """Display the CG input and the backmapped structures in an NGL widget.

    If reference ("true") files are given, the backmapped universe is first
    aligned onto the reference and the post-alignment RMSD is printed.

    Parameters
    ----------
    backmapped_filenames : list of str
        Unpacked into ``mda.Universe(*...)``: the first entry is the topology,
        any further entries are read as coordinate/trajectory files.
    cg_filenames : list of str
        Same convention, for the coarse-grained input.
    true_filenames : list of str
        Same convention, for the reference structure. May be empty, in which
        case no alignment or RMSD is computed.
    licorice_selection : str, optional
        NGL selection string drawn with a licorice representation
        (default ``"protein"``). Pass e.g. ``"all"`` for lipid/ligand systems.

    Returns
    -------
    nglview.NGLWidget
        The widget; as the last expression of a cell it is rendered inline.
    """
    from MDAnalysis.analysis import align, rms
    from nglview import NGLWidget

    backmapped_u = mda.Universe(*backmapped_filenames)
    cg_u = mda.Universe(*cg_filenames)

    if len(true_filenames) > 0:
        ref_u = mda.Universe(*true_filenames)
        # Fit the backmapped coordinates onto the reference in memory,
        # matching atoms by selection.
        align.AlignTraj(
            backmapped_u,  # mobile
            ref_u,         # reference
            select="all",
            in_memory=True,
            match_atoms=True,
        ).run()
        # NOTE(review): `.positions` are those of whichever frame is current after
        # AlignTraj.run(); for multi-frame inputs the per-frame RMSD from the
        # aligner's results may be more appropriate — confirm.
        aligned_rmsd = rms.rmsd(
            backmapped_u.atoms.positions, ref_u.atoms.positions, superposition=False
        )
        print(aligned_rmsd)

    widget = NGLWidget(representations=None)
    widget.add_trajectory(cg_u)
    widget.add_trajectory(backmapped_u)
    widget.add_representation("licorice", selection=licorice_selection)
    return widget

# Render the CG input and backmapped output (aligned to the reference when one exists).
show(
    backmapped_filenames=backmapped_filenames,
    cg_filenames=cg_filenames,
    true_filenames=true_filenames,
)