In [None]:
# Automatically re-import edited project modules (e.g. heqbm) before each cell runs.
# %reload_ext (instead of %load_ext) keeps re-running this cell harmless.
%reload_ext autoreload
%autoreload 2

In [None]:
import torch
from heqbm.utils.pdbFixer import joinPDBs
from heqbm.backmapping.hierarchical_backmapping import HierarchicalBackmapping

# Pin PyTorch's default floating-point dtype: tensors created without an
# explicit dtype will be single precision.
DEFAULT_TORCH_DTYPE = torch.float32
torch.set_default_dtype(DEFAULT_TORCH_DTYPE)

In [None]:
# Backmapping run configurations.
# Exactly one preset is active (ACTIVE_PRESET); the others are kept as
# ready-to-use alternatives instead of commented-out / overwritten dicts.
# NOTE: previously the "a2a_config" dict was assigned to args_dict and then
# immediately overwritten, so its "device": "cuda:1" setting never took effect.
A2A_DATA_DIR = "/storage_common/angiod/A2A"  # site-specific absolute path; adjust per machine

BACKMAPPING_PRESETS = {
    # Config-file driven A2A run.
    # NOTE(review): uses the older "config"-file schema; confirm
    # HierarchicalBackmapping still accepts it before activating.
    "a2a_config": {
        # "config": "config/backmapping/PDB6K-martini3-bbcommon-geqmodel.yaml",
        # "config": "config/backmapping/PDB6K-CA-geqmodel.yaml",
        # "config": "config/backmapping/MEM-martini3-geqmodel.yaml",
        "config": "config/backmapping/A2A-martini3-geqmodel.yaml",
        "device": "cuda:1",
        "isatomistic": True,
    },
    # ZMA ligand, atomistic input structure + trajectory.
    "zma_atomistic": {
        "mapping": "zma",
        "input": f"{A2A_DATA_DIR}/Omar/prot_zma.gro",
        "inputtraj": f"{A2A_DATA_DIR}/Omar/md_all_prot_mol.xtc",
        "trajslice": "3000:3001",
        "selection": "resname ZMA",
        "output": "backmapped/ZMA/atomistic/",
        "model": "config/training/ZMA-geqmodel.yaml",
        "modelweights": "best_model.pth",
        "isatomistic": True,
    },
    # ZMA ligand, CG (non-atomistic) input.
    "zma_cg": {
        "mapping": "zma",
        "input": f"{A2A_DATA_DIR}/Vince/11-blyr.pdb",
        # "inputtraj": f"{A2A_DATA_DIR}/Vince/11-blyr.gro",
        "trajslice": "480:481",
        "selection": "resname ZMA",
        "output": "backmapped/ZMA/CG/",
        "model": "config/training/ZMA-geqmodel.yaml",
        "modelweights": "best_model.pth",
        "isatomistic": False,
    },
    # A2A protein, Martini3 CG input.
    "a2a_cg": {
        "mapping": "martini3",
        "input": f"{A2A_DATA_DIR}/Vince/11-blyr.pdb",
        # "inputtraj": f"{A2A_DATA_DIR}/Vince/11-blyr.gro",
        "trajslice": ":1",
        "selection": "protein",
        "output": "backmapped/A2A/CG/",
        "model": "config/training/A2A-martini3-bbcommon-geqmodel.yaml",
        "modelweights": "best_model.pth",
        "isatomistic": False,
    },
}

ACTIVE_PRESET = "zma_atomistic"

# Shallow copy so any edits to args_dict (here or downstream) never mutate the
# preset table, keeping this cell idempotent on re-run.
args_dict = dict(BACKMAPPING_PRESETS[ACTIVE_PRESET])
# Build the backmapping pipeline from the run configuration held in args_dict.
backmapping = HierarchicalBackmapping(args_dict=args_dict)

In [None]:
# Run backmapping on every mapping produced by `backmapping.map()` and gather
# the four filename lists returned by `backmap` into notebook-level lists.
backmapped_filenames = []
backmapped_minimised_filenames = []
true_filenames = []
cg_filenames = []

for current_mapping in backmapping.map():
    (
        new_backmapped,
        new_backmapped_minimised,
        new_true,
        new_cg,
    ) = backmapping.backmap(current_mapping, optimise_backbone=False)
    backmapped_filenames += new_backmapped
    backmapped_minimised_filenames += new_backmapped_minimised
    true_filenames += new_true
    cg_filenames += new_cg

# Presumably merges the 'backmapped'-tagged PDB outputs in the configured
# output directory - confirm in heqbm.utils.pdbFixer.joinPDBs.
for pdb_tag in ('backmapped',):
    joinPDBs(backmapping.config.get("output"), pdb_tag)

In [None]:
import MDAnalysis as mda
from MDAnalysis.analysis import align, rms

def show(backmapped_filenames, cg_filenames, true_filenames, show_reference=False):
    """Build an NGL viewer overlaying the CG input and the backmapped structure.

    Parameters
    ----------
    backmapped_filenames : list of str
        Files unpacked into ``mda.Universe(*...)`` for the backmapped structure.
    cg_filenames : list of str
        Files unpacked into ``mda.Universe(*...)`` for the coarse-grained input.
    true_filenames : list of str
        Files for the reference structure; may be empty. Only loaded when
        ``show_reference`` is True.
    show_reference : bool, default False
        If True and ``true_filenames`` is non-empty, also add the reference
        structure to the viewer.

    Returns
    -------
    nglview.NGLWidget
        The viewer; make it the last expression of a cell to display it.

    Raises
    ------
    ValueError
        If ``backmapped_filenames`` or ``cg_filenames`` is empty.
    """
    # Fail early with a clear message instead of building an empty Universe
    # that only errors later inside nglview.
    if not backmapped_filenames:
        raise ValueError("show(): backmapped_filenames is empty - nothing was backmapped")
    if not cg_filenames:
        raise ValueError("show(): cg_filenames is empty - no CG structure to display")

    # Imported lazily so the rest of the notebook works without nglview installed.
    from nglview import NGLWidget

    backmapped_u = mda.Universe(*backmapped_filenames)
    cg_u = mda.Universe(*cg_filenames)

    w = NGLWidget(representations=None)
    w._remove_representation()  # private nglview API, kept from the original setup
    w.add_trajectory(cg_u)
    w.add_trajectory(backmapped_u)
    # The reference is only read from disk when it will actually be shown;
    # previously it was always loaded and then left unused.
    if show_reference and true_filenames:
        w.add_trajectory(mda.Universe(*true_filenames))

    # Particles whose element is not H/C/N/O (presumably the CG beads) as
    # spheres; H/C/N/O atoms as sticks.
    w.add_representation('spacefill', radius=.5, selection='not (_H _C _N _O)', color='pink')
    w.add_representation('licorice', selection='_H _C _N _O')
    return w

# Display the CG input overlaid with the backmapped structure.
show(
    backmapped_filenames=backmapped_filenames,
    cg_filenames=cg_filenames,
    true_filenames=true_filenames,
)

In [None]:
# NOTE(review): identical to the call at the end of the previous cell; it only
# renders a second viewer. Consider removing this cell.
show(backmapped_filenames, cg_filenames, true_filenames)