In [None]:
# Automatically reload all imported modules before each cell runs (autoreload mode 2),
# so edits to imported source files (e.g. heqbm, plotting) are picked up without a kernel restart.
%reload_ext autoreload
%autoreload 2

In [None]:
import torch
import numpy as np
import pandas as pd

from heqbm.utils import DataDict
from heqbm.utils.geometry import get_RMSD, get_dih_loss
from heqbm.utils.pdbFixer import joinPDBs
from heqbm.backmapping.hierarchical_backmapping import HierarchicalBackmapping

from plotting import plot_cg_impl

# Make float32 the default floating-point dtype for tensors torch creates in this session.
torch.set_default_dtype(torch.float32)

In [None]:
# YAML file configuring this hierarchical backmapping run; change it to backmap another system.
BACKMAPPING_CONFIG = "config/backmapping/A2A/del.yaml"
backmapping = HierarchicalBackmapping(config_filename=BACKMAPPING_CONFIG)

In [None]:
# Plot the coarse-grained mapping of the first frame. Each *_filter may be set to an
# index array to restrict the plot, e.g. bead_filter=np.arange(0, 12) or atom_filter=np.arange(52).
cg_plot_kwargs = dict(
    frame_index=0,
    bead_filter=None,
    atom_filter=None,
    residue_filter=None,
)
plot_cg_impl(dataset=backmapping.mapping.dataset, **cg_plot_kwargs)

In [None]:
# Summarise the mapped dataset: numpy arrays report their shape, any other entry its length.
for key, value in backmapping.mapping.dataset.items():
    if isinstance(value, np.ndarray):
        print(key, value.shape)
    else:
        print(key, len(value))

In [None]:
# Backmap every available frame (capped at MAX_FRAMES). For each frame, record error
# metrics against the ground truth when the backmapped dataset carries it, then append
# the backmapped structure to the output PDB files through `to_pdb`.
MAX_FRAMES = 9999  # cap on the number of frames processed; lower it for quick iterations
BACKBONE_ATOM_NAMES = ["CA", "O", "C", "N"]  # atom names counted in "RMSD BB"
EXCLUDED_ATOM_NAMES = []  # atom names left out of "RMSD All" (none by default)

n_available_frames = len(backmapping.mapping.dataset[DataDict.BEAD_POSITION])
frame_idcs = range(0, min(MAX_FRAMES, n_available_frames))
if len(frame_idcs) == 0:
    # max() below would otherwise fail with an opaque "empty sequence" error.
    raise ValueError("The mapped dataset contains no frames (DataDict.BEAD_POSITION is empty); nothing to backmap.")
n_frames = max(frame_idcs) + 1

# Loop-invariant settings passed to `to_pdb` for every frame.
atom_filter = None  # optionally a boolean mask, e.g. to write backbone atoms only
selection = backmapping.config.get("selection", "protein")
output_folder = backmapping.config.get("output_folder")

rmsd_data = []
backmapped_u = None
for frame_index in frame_idcs:
    backmapping_dataset = backmapping.backmap(frame_index=frame_index, optimize_backbone=True)

    # Prediction errors can only be computed when the ground truth is present.
    data = {}
    try:
        if DataDict.ATOM_POSITION in backmapping_dataset:
            # The second '_'-separated field of each atom name is what gets matched
            # against the name lists (assumes names look like "<prefix>_<atom>...").
            atom_names = [an.split('_')[1] for an in backmapping_dataset[DataDict.ATOM_NAMES]]
            fltr_all = np.array([an not in EXCLUDED_ATOM_NAMES for an in atom_names])
            fltr_bb = np.array([an in BACKBONE_ATOM_NAMES for an in atom_names])
            pos_pred = backmapping_dataset[DataDict.ATOM_POSITION_PRED]
            pos_true = backmapping_dataset[DataDict.ATOM_POSITION]
            data["RMSD All"] = get_RMSD(pos_pred, pos_true, fltr=fltr_all)
            data["RMSD BB"] = get_RMSD(pos_pred, pos_true, fltr=fltr_bb)
        if DataDict.BB_PHIPSI in backmapping_dataset:
            data["Dih Loss"] = get_dih_loss(
                backmapping_dataset[DataDict.BB_PHIPSI_PRED],
                backmapping_dataset[DataDict.BB_PHIPSI],
                ignore_zeroes=True,
            )
    except Exception as exc:
        # Metrics are best-effort: report the failure instead of swallowing it silently,
        # discard this frame's partial metrics and keep backmapping the remaining frames.
        print(f"Frame {frame_index}: could not compute metrics ({type(exc).__name__}: {exc})")
        data = {}
    if data:
        # Tag every row with its frame, also when only the dihedral loss is available;
        # frames without any ground truth add no (empty) row.
        rmsd_data.append({"frame": frame_index, **data})

    backmapped_u = backmapping.to_pdb(
        backmapping_dataset=backmapping_dataset,
        n_frames=n_frames,
        frame_index=frame_index,
        selection=selection,
        folder=output_folder,
        atom_filter=atom_filter,
        previous_u=backmapped_u,
    )

# Merge the per-frame PDB files written by `to_pdb` into one file per output tag.
for tag in ('original_CG', 'final_CG', 'backmapped', 'true'):
    joinPDBs(backmapping.config.get("output_folder"), tag)

# Per-frame metrics table; stays None when no metrics were collected.
df = pd.DataFrame.from_records(rmsd_data) if rmsd_data else None

In [None]:
# Report where the OpenMM shared library is loaded from (useful to check which
# OpenMM build the kernel's environment actually picks up).
try:
    import openmm  # OpenMM >= 7.6 exposes the top-level `openmm` package
except ImportError:
    from simtk import openmm  # deprecated namespace, only needed for older OpenMM releases
import os
print(os.path.dirname(openmm.version.openmm_library_path), openmm.version.openmm_library_path)

In [None]:
# Interactive 3D view of the backmapped structure: blue cartoon for the protein
# plus licorice sticks for every atom.
import nglview as nv

# `backmapped_u` is only set once the backmapping loop has processed at least one frame.
assert backmapped_u is not None, "Run the backmapping cell first: `backmapped_u` is still None."
w = nv.show_mdanalysis(backmapped_u)
w.clear_representations()  # public API: drop all default representations before adding our own
w.add_representation('cartoon', selection='protein', color='blue')
w.add_representation('licorice', selection='all')
w

In [5]:
# Scratch check: build the directed edge list of a fully connected graph over
# `positions`, i.e. every unordered pair (i, j) with i < j emitted in both directions.
import numpy as np
positions = np.array([[1.,0.,0], [0.,2.,0.], [1.,1.,0.], [2., 2., 0.]])

# distance_vectors[i, j] = positions[j] - positions[i]; dist holds their norms.
distance_vectors = positions[None, ...] - positions[:, None]
dist = np.linalg.norm(distance_vectors, axis=-1)

# Upper-triangle pairs (i < j) in row-major order, interleaved as (i->j, j->i).
# This reproduces the nested-loop ordering: i = [0, 1, 0, 2, ...], j = [1, 0, 2, 0, ...].
_upper_i, _upper_j = np.triu_indices(len(dist), k=1)
_idx_i = np.stack((_upper_i, _upper_j), axis=1).ravel()
_idx_j = np.stack((_upper_j, _upper_i), axis=1).ravel()

# One zero 3-vector per position. The shape must be passed as a single tuple;
# `len(positions, 3)` raised TypeError because len() takes exactly one argument.
offset = np.zeros((len(positions), 3), dtype=np.float32)

In [18]:
# Source index of each directed pair built in the pair-construction cell (each i<j pair appears as i->j and j->i).
_idx_i

array([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3])

In [19]:
# Target index of each directed pair; element-wise partner of `_idx_i`.
_idx_j

array([1, 0, 2, 0, 3, 0, 2, 1, 3, 1, 3, 2])