In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
import os
import glob
import torch
import numpy as np
import pandas as pd

from heqbm.utils import DataDict
from heqbm.utils.geometry import get_RMSD, get_dih_loss
from heqbm.utils.hierarchical_backmapping import HierarchicalBackmapping

# Make float32 the default dtype for floating-point tensors created from here on.
torch.set_default_dtype(torch.float32)

In [None]:
# Instantiate the backmapping object from its YAML configuration file
# (path is relative to the notebook's working directory).
backmapping = HierarchicalBackmapping(config_filename="config/backmapping/A2A-protein.yaml")

In [None]:
# Summarise the mapped dataset: array shape for numpy entries, length otherwise.
for key, value in backmapping.mapping.dataset.items():
    if isinstance(value, np.ndarray):
        size = value.shape
    else:
        size = len(value)
    print(key, size)

In [None]:
# Number of frames to backmap, starting from frame 0.
n_frames = 1

rmsd_data = []       # one dict of error metrics per successfully evaluated frame
backmapped_u = None  # handed back to to_pdb on every iteration via previous_u
for frame_index in range(n_frames):
    backmapping_dataset = backmapping.backmap(frame_index=frame_index, backmap_backbone=True)

    try:
        # Show prediction errors, if ground truth is present #
        data = {}
        if DataDict.ATOM_POSITION in backmapping_dataset:
            # Atom names are split on '_' and the second field is compared against
            # atom labels (assumes names look like '<prefix>_<atom>' — TODO confirm).
            atom_names = backmapping_dataset[DataDict.ATOM_NAMES]
            positions_pred = backmapping_dataset[DataDict.ATOM_POSITION_PRED]
            positions_true = backmapping_dataset[DataDict.ATOM_POSITION]
            # The exclusion list is empty, so this mask selects every atom.
            fltr_all = np.array([an.split('_')[1] not in [] for an in atom_names])
            # Mask selecting only the backbone atoms CA, O, C and N.
            fltr_bb = np.array([an.split('_')[1] in ["CA", "O", "C", "N"] for an in atom_names])
            data.update({
                "frame": frame_index,
                "RMSD All": get_RMSD(positions_pred, positions_true, fltr=fltr_all),
                "RMSD BB": get_RMSD(positions_pred, positions_true, fltr=fltr_bb),
            })
        if DataDict.BB_PHIPSI in backmapping_dataset:
            data.update({
                "Dih Loss": get_dih_loss(
                    backmapping_dataset[DataDict.BB_PHIPSI_PRED],
                    backmapping_dataset[DataDict.BB_PHIPSI],
                    ignore_zeroes=True,
                ),
            })
        rmsd_data.append(data)
    except Exception as exc:
        # Metrics are best-effort: a failure must not prevent the structure from
        # being written below, but it is reported instead of silently discarded.
        # Catching Exception (not a bare except) lets KeyboardInterrupt propagate.
        print(f"Frame {frame_index}: skipping error metrics ({type(exc).__name__}: {exc})")

    atom_filter = None # np.array([an in ["CA", "C", "O", "N"] for an in atomnames])
    # Write this frame to PDB; the value returned for the previous frame is passed
    # back in through previous_u.
    backmapped_u = backmapping.to_pdb(
        backmapping_dataset=backmapping_dataset,
        n_frames=n_frames,
        frame_index=frame_index,
        selection=backmapping.config.get("selection", "protein"),
        folder=backmapping.config.get("output_folder"),
        atom_filter=atom_filter,
        previous_u=backmapped_u,
    )

def join_pdb(data_dir, tag):
    """Merge the ``*{tag}_*.pdb`` files found in ``data_dir`` into one multi-model PDB.

    Every matching file becomes one ``MODEL`` / ``ENDMDL`` record of
    ``<data_dir>/multi_{tag}.pdb``. The ``END`` line of each input is dropped and
    a single ``END`` is written after the last model. Each input file is deleted
    right after it has been merged.

    Models are numbered from 1 in ascending order of the integer that follows the
    last ``_`` in the file name (so ``*_2.pdb`` precedes ``*_10.pdb``). Files
    without such a numeric suffix come last, in lexicographic order.

    Args:
        data_dir: Folder containing the input files; also receives the output.
        tag: File-name tag selecting which files are merged.

    Returns:
        None. When no file matches, nothing is written.
    """
    def _frame_order(path):
        # Sort key: numeric suffix first (as an int), everything else afterwards.
        stem = os.path.splitext(os.path.basename(path))[0]
        suffix = stem.rsplit("_", 1)[-1]
        if suffix.isdecimal():
            return (0, int(suffix), path)
        return (1, 0, path)

    # glob gives no ordering guarantee, so sort explicitly to keep frames in order.
    # NOTE(review): '**' is not a path component of its own in this pattern, so
    # sub-folders are presumably not searched despite recursive=True — confirm intent.
    filenames = sorted(glob.glob(f'{data_dir}/**{tag}_*.pdb', recursive=True), key=_frame_order)
    if not filenames:
        return
    with open(os.path.join(data_dir, f'multi_{tag}.pdb'), "w") as f_out:
        for model_number, filename in enumerate(filenames, start=1):
            f_out.write(f"MODEL     {model_number}\n")
            with open(filename, "r") as f_in:
                # Drop the per-file END record; one END closes the merged file.
                f_out.write(f_in.read().replace("\nEND\n", "\n"))
            f_out.write("ENDMDL\n")
            os.remove(filename)
        f_out.write("END\n")

# Collapse the PDB files written for each tag into a single multi-model file.
for tag in ('original_CG', 'final_CG', 'backmapped', 'true'):
    join_pdb(backmapping.config.get("output_folder"), tag)

# Tabulate the collected metrics; df stays None when nothing was recorded.
df = pd.DataFrame.from_records(rmsd_data) if rmsd_data else None

In [None]:
# Display the error-metric table (df is None when no metrics were recorded).
df

In [None]:
# Interactive 3D view of backmapped_u, i.e. the value returned by the last
# backmapping.to_pdb call.
import nglview as nv
w = nv.show_mdanalysis(backmapped_u)
# NOTE(review): private nglview method; presumably removes the default
# representation before the custom ones are added — confirm.
w._remove_representation()
# Protein drawn as a blue cartoon, with all atoms overlaid as licorice.
w.add_representation('cartoon', selection='protein', color='blue')
w.add_representation('licorice', selection='all')
w  # last expression of the cell: renders the widget

Analyse side-chain chi dihedral angles (Janin plots)

In [None]:
import MDAnalysis as mda
from MDAnalysis.analysis.dihedrals import Janin

In [None]:
# Janin plot for the merged 'true' structures file.
# NOTE(review): the path is hard-coded to CRF1R while the notebook loads the
# A2A-protein config — confirm this is the intended file.
u = mda.Universe('backmapped/CRF1R/atomistic/protein/multi_true.pdb')
# Restrict the analysis to the residue types listed in the selection string.
protein = u.select_atoms(
    'protein and resname ARG ASN ASP GLN GLU HIE HID HIS ILE LEU LYS MET TRP TYR'
)
janin = Janin(protein)
janin.run()
janin.plot(ref=True, marker='.', color='black')

In [None]:
# Janin plot for the merged 'backmapped' structures file.
# NOTE(review): the path is hard-coded to CRF1R while the notebook loads the
# A2A-protein config — confirm this is the intended file.
u = mda.Universe('backmapped/CRF1R/atomistic/protein/multi_backmapped.pdb')
# Restrict the analysis to the residue types listed in the selection string.
protein = u.select_atoms(
    'protein and resname ARG ASN ASP GLN GLU HIE HID HIS ILE LEU LYS MET TRP TYR'
)
janin = Janin(protein)
janin.run()
janin.plot(ref=True, marker='.', color='black')

In [None]:
# Janin plot for 'output1.pdb', read from the notebook's working directory.
# NOTE(review): nothing in this notebook writes this file — document its origin.
u = mda.Universe('output1.pdb')
# Restrict the analysis to the residue types listed in the selection string.
protein = u.select_atoms(
    'protein and resname ARG ASN ASP GLN GLU HIE HID HIS ILE LEU LYS MET TRP TYR'
)
janin = Janin(protein)
janin.run()
janin.plot(ref=True, marker='.', color='black')

In [None]:
# Janin plot for 'output2.pdb', read from the notebook's working directory.
# NOTE(review): nothing in this notebook writes this file — document its origin.
u = mda.Universe('output2.pdb')
# Restrict the analysis to the residue types listed in the selection string.
protein = u.select_atoms(
    'protein and resname ARG ASN ASP GLN GLU HIE HID HIS ILE LEU LYS MET TRP TYR'
)
janin = Janin(protein)
janin.run()
janin.plot(ref=True, marker='.', color='black')

In [None]:
# from matplotlib import pyplot as plt

# resnames = [
#     'ARG', 'ASN',
#     'ASP', 'GLN',
#     'GLU', 'HIE',
#     'HID', 'HIS',
#     'ILE', 'LEU',
#     'LYS', 'MET',
#     'TRP', 'TYR',
# ]

# u_pred = mda.Universe('backmapped/A2A/atomistic/protein/multi_backmapped.pdb')
# u_true = mda.Universe('backmapped/A2A/atomistic/protein/multi_true.pdb')

# fig, axs = plt.subplots(14, 2, figsize=(16, 8*14))

# for resname, (ax_pred, ax_true) in zip(resnames, axs):
#     residue_pred = u_pred.select_atoms(f'protein and resname {resname}')
#     janin_pred = Janin(residue_pred).run()
#     janin_pred.plot(ax=ax_pred, ref=True, marker='.', color='red', )
#     ax_pred.title.set_text(resname)

#     residue_true = u_true.select_atoms(f'protein and resname {resname}')
#     janin_true = Janin(residue_true).run()
#     janin_true.plot(ax=ax_true, ref=True, marker='.', color='black')
#     ax_true.title.set_text(resname)