In [None]:
import os
import glob
import numpy as np
import MDAnalysis as mda
from MDAnalysis.analysis.rms import RMSD
from os.path import basename
from MDAnalysis.analysis.dihedrals import Ramachandran, Janin

from scipy.stats import norm
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd

import warnings
# Silences every warning for the rest of the session, including deprecation
# notices raised by the analysis and plotting libraries used below.
warnings.filterwarnings('ignore')

In [None]:
# Directory that is globbed for the target (backmapped) PDB files.
ROOT_TARGET = "backmapped/A2A/CG/protein/minimized/"
# Directory holding the reference PDB files the targets are compared against.
ROOT_REF = "backmapped/A2A/atomistic/protein/"

# Alternative dataset locations (disabled):
# ROOT_TARGET = "/home/angiod@usi.ch/GenZProt/scripts/result_model_seed_12345_PED00218_CA_trace"
# ROOT_REF = "/home/angiod@usi.ch/GenZProt/scripts/result_model_seed_12345_PED00218_CA_trace"

# Compute Steric Clashes #

In [None]:
TAG_TARGET = "backmapped"

# Distance thresholds, in the units of `protein.positions` (Angstrom in MDAnalysis).
# A pair closer than CLASH_CUTOFF counts as a clash; the count is normalised by
# the number of pairs closer than NEIGHBOR_CUTOFF.
CLASH_CUTOFF = 1.2
NEIGHBOR_CUTOFF = 5.

for trg_filename in glob.glob(os.path.join(ROOT_TARGET, f"{TAG_TARGET}*.pdb")):
    u = mda.Universe(trg_filename)
    # Hydrogens are excluded so only heavy-atom pairs are considered.
    protein = u.select_atoms('protein and not (type H)')
    clash_ratios = []
    # Iterating the trajectory updates `protein.positions` for each frame.
    for ts in u.trajectory:
        pos = protein.positions
        # Full pairwise distance matrix, reduced to the unique pairs (i < j).
        dist = np.linalg.norm(pos[None, ...] - pos[:, None, ...], axis=-1)
        dist = dist[np.triu_indices(len(dist), k=1)]
        n_clashes = np.count_nonzero(dist < CLASH_CUTOFF)
        n_neighbors = np.count_nonzero(dist < NEIGHBOR_CUTOFF)
        # With no pair inside the neighbour cutoff there is nothing to clash;
        # report 0 instead of dividing by zero.
        clash_ratio = n_clashes / n_neighbors if n_neighbors else 0.0
        clash_ratios.append(clash_ratio)
    clash_ratios = np.array(clash_ratios)
    # Mean and variance of the per-frame clash ratio for this file.
    print(clash_ratios.mean(), clash_ratios.var())

# Compute RMSD distribution #

In [None]:
# Filename prefixes used by get_rmsd: targets are globbed as "<TAG_TARGET>*.pdb"
# and the matching reference name is obtained by replacing TAG_TARGET with
# TAG_REF in the target's basename.
# NOTE: this rebinds the TAG_TARGET assigned in the steric-clash cell.
TAG_TARGET = "backmapped_fixed_"
TAG_REF = "true_fixed_"

# Alternative prefixes (disabled):
# TAG_TARGET = "sample_traj"
# TAG_REF = "true"

def get_rmsd(ref_frame=0):
    """Compute backbone and side-chain RMSD of every target PDB against its reference.

    Target files are the ones matching ``<ROOT_TARGET>/<TAG_TARGET>*.pdb``. For
    each of them the reference file is looked up in ``ROOT_REF`` under the same
    basename with ``TAG_TARGET`` replaced by ``TAG_REF``.

    Args:
        ref_frame: forwarded to ``RMSD`` as ``ref_frame``.

    Returns:
        Array with one column per frame (all files concatenated in sorted
        filename order) and one row per RMSD series: row 0 is the ``select``
        (backbone) RMSD, row 1 the ``groupselections`` (non-backbone) RMSD.

    Raises:
        ValueError: if no target file matches the glob pattern.
    """
    pattern = os.path.join(ROOT_TARGET, f"{TAG_TARGET}*.pdb")
    # glob returns files in arbitrary order; sorting keeps the concatenated
    # frame indices reproducible across runs and machines.
    trg_filenames = sorted(glob.glob(pattern))
    if not trg_filenames:
        raise ValueError(f"No target PDB files match {pattern!r}")

    rmsd = []
    for trg_filename in trg_filenames:
        src_filename = os.path.join(ROOT_REF, basename(trg_filename).replace(TAG_TARGET, TAG_REF))
        u = mda.Universe(trg_filename)
        ref = mda.Universe(src_filename)

        R = RMSD(u, ref,
                select="protein and backbone and not (resname ACE NME or name OXT or type H)",
                groupselections=[
                    "protein and (not backbone) and not (resname ACE NME or name OXT or type H)",
                ],
                ref_frame=ref_frame,
        )
        R.run()
        # The first two columns are not RMSD series (frame index and time per
        # the MDAnalysis docs); keep only the RMSD columns.
        frame_rmsd = R.results.rmsd[:, 2:]
        rmsd.append(frame_rmsd)
    return np.concatenate(rmsd, axis=0).T

def plot_rmsd_distribution(rmsd):
    """Plot RMSD histograms with fitted normal curves for backbone and side chains.

    Args:
        rmsd: array whose row 0 holds the backbone RMSD values and row 1 the
            side-chain RMSD values, one entry per frame.

    Returns:
        The matplotlib figure that was drawn (it is also shown via ``plt.show()``).
    """
    fig = plt.figure(figsize=(8,8))
    ax = fig.add_subplot(111)

    # (values, legend label, histogram edge colour, fitted-curve colour)
    series = (
        (rmsd[0], "BB", 'orange', 'k'),
        (rmsd[1], "Side-Chains", 'c', 'c'),
    )

    # Draw both histograms first so the x-limits cover both data sets before
    # the fitted curves are sampled.
    bin_edges = []
    for values, label, edgecolor, _ in series:
        _, edges, _ = ax.hist(values, density=False, bins=20, histtype='step', linewidth=1, edgecolor=edgecolor, fill=False, label=label)
        bin_edges.append(edges)

    for (values, _, _, curve_color), edges in zip(series, bin_edges):
        mu, std = norm.fit(values)
        xmin, xmax = ax.get_xlim()
        x = np.linspace(xmin, xmax, 100)
        # The histogram shows counts, while norm.pdf integrates to 1. Multiply
        # by (number of samples * bin width) to put the curve on the count scale.
        counts_scale = len(values) * (edges[1] - edges[0])
        ax.plot(x, norm.pdf(x, mu, std) * counts_scale, curve_color, linewidth=.5)

    ax.legend(loc="best")
    ax.set_xlabel(r"RMSD ($\AA$)")
    ax.set_ylabel("# of Frames")
    ax.set_title(f"RMSD Distribution for {os.path.join(ROOT_TARGET, TAG_TARGET)}")
    plt.show()
    return fig

# Optional scan over the first 20 reference frames (disabled):
# for i in range(20):
#     print(get_rmsd(ref_frame=i).min(axis=1))

# Compute the RMSD series with ref_frame=0 and plot their distributions.
rmsd = get_rmsd(ref_frame=0)
fig = plot_rmsd_distribution(rmsd)
# fig.savefig("rmsd_all_CORE_LID_NMP_ref1AKE.pdf")

# Per-row summary statistics (axis=1 reduces over the frame axis).
print(f"Mean RMSD: {rmsd.mean(axis=1)}")
print(f"STD: {rmsd.std(axis=1)}")
# Position, along the concatenated frame axis, of the largest row-0 value
# (row 0 is the series labelled "BB" in the plot).
print(f"ID of frame with Max RMSD: {np.argmax(rmsd[0])}")

# Plot Dihedrals #

In [None]:
def plot_rama_janin(dih_list, dih_list_true, title: str = None, ref_back: bool = True):
    """Show one two-panel figure per (target, reference) pair of dihedral analyses.

    The top panel holds the Ramachandran plot and the bottom panel the
    chi1-chi2 plot. Each panel draws the target analysis through its own
    ``plot`` method and, when a reference analysis is given, overlays the
    reference angles as a scatter.

    Args:
        dih_list: iterable of ``(rama, janin)`` analyses for the targets.
        dih_list_true: iterable of ``(rama, janin)`` analyses for the
            references; either element may be ``None`` to skip its overlay.
        title: optional figure-level title.
        ref_back: forwarded as ``ref`` to the analyses' ``plot`` methods.
    """
    def overlay_reference(axis, analysis):
        # Collapse the two leading axes so that every row is one angle pair.
        angles = analysis.results.angles
        n_points = np.prod(angles.shape[:2])
        flat = angles.reshape(n_points, 2)
        axis.scatter(flat[:, 0], flat[:, 1], marker='X', s=20, alpha=0.2, color='black', edgecolors='red', linewidth=2)

    for target_pair, reference_pair in zip(dih_list, dih_list_true):
        rama, janin = target_pair
        rama_true, janin_true = reference_pair

        fig, axes = plt.subplots(2, 1, figsize=(10, 20), facecolor='white')
        ax_top, ax_bottom = axes
        if title is not None:
            fig.suptitle(title)

        panels = (
            (ax_top, r'Ramachandran Plot', rama, rama_true),
            (ax_bottom, r'$\chi_1$-$\chi_2$ Distribution', janin, janin_true),
        )
        for axis, panel_title, analysis, analysis_true in panels:
            axis.set_title(panel_title, loc='center', fontstyle='oblique', fontsize='medium')
            analysis.plot(ax=axis, ref=ref_back, marker='x', color='black', s=10)
            if analysis_true is not None:
                overlay_reference(axis, analysis_true)
        plt.show()

def compute_dihedrals(filename: str):
    """Run the Ramachandran and Janin analyses on the structure in ``filename``.

    The Ramachandran analysis uses the whole ``protein`` selection; the Janin
    analysis is restricted to the residue names listed in the selection below.

    Returns:
        Tuple ``(rama, janin)`` of the two analysis objects after ``run()``.
    """
    universe = mda.Universe(filename)

    all_protein = universe.select_atoms('protein')
    rama_analysis = Ramachandran(all_protein).run()

    chi_selection = 'protein and resname ARG ASN ASP GLN GLU HIE HID HIS ILE LEU LYS MET TRP TYR'
    chi_residues = universe.select_atoms(chi_selection)
    janin_analysis = Janin(chi_residues).run()

    return rama_analysis, janin_analysis

def normalize(x):
    """Convert angles from degrees to radians, wrapped into [-pi, pi].

    Values that already fall inside [-pi, pi] after conversion are returned
    unchanged. Values outside that interval (for example angles given on a
    0-360 degree scale) are wrapped individually to the equivalent angle, so
    300 degrees becomes -pi/3. The input is not modified.

    Args:
        x: angle(s) in degrees (array-like).

    Returns:
        Array of the same shape with the angles in radians.
    """
    radians = np.asarray(x) / 180 * np.pi
    outside = (radians > np.pi) | (radians < -np.pi)
    # Wrap element-wise: shifting the whole array whenever a single value is
    # out of range would move in-range angles and map the others to the wrong
    # angle.
    wrapped = (radians + np.pi) % (2 * np.pi) - np.pi
    return np.where(outside, wrapped, radians)

def plot_distribution(dih_list, dih_list_true, dataset: str, figsize=(40, 5), bins=60):
    """Draw a Ramachandran density plot comparing target and reference structures.

    The phi/psi angles of all entries are pooled, converted with ``normalize``
    and drawn as kernel density estimates on a single axis: a filled density
    plus thin contour lines for ``dih_list`` and contour lines only for
    ``dih_list_true``. The figure is shown with ``plt.show()``; nothing is
    returned.

    Args:
        dih_list: sequence of ``(rama, janin)`` analyses for the target
            structures; only element 0 of each entry is read.
        dih_list_true: same, for the reference structures.
        dataset: label placed at the start of the axis title.
        figsize: accepted for interface compatibility but not used; the figure
            is always created with ``figsize=(12, 6)``.
        bins: accepted for interface compatibility but not used, since no
            histogram panel is drawn.
    """
    csfont = {'fontname':'Comic Sans MS'}
    kde_bandwidth = 0.18

    def pooled_phi_psi(dihedrals):
        # Collapse the leading axes so every row is one (phi, psi) pair, then
        # convert the pooled angles with normalize().
        stacked = np.concatenate(
            [dih[0].results.angles.reshape(-1, 2) for dih in dihedrals], axis=0
        )
        return normalize(stacked)

    fig = plt.figure(figsize=(12, 6), facecolor='white')
    ax1 = fig.add_subplot(1, 2, 1)
    fig.subplots_adjust(
        left=0.1,
        bottom=0.1,
        right=0.9,
        top=0.9,
        wspace=0.02,
        hspace=0.02,
    )

    phi_psi = pooled_phi_psi(dih_list)
    phi_psi_true = pooled_phi_psi(dih_list_true)

    ax1.set_title(f'{dataset} - Ramachandran Plot', **csfont)
    # Target structures: filled density ...
    sns.kdeplot(
        x=phi_psi[:, 0],
        y=phi_psi[:, 1],
        cmap=sns.color_palette(f"blend:#EEE,{sns.color_palette().as_hex()[0]}", as_cmap=True),
        fill=True, thresh=0.05, ax=ax1, levels=10, bw_method=kde_bandwidth
    )
    # ... outlined with thin contour lines in the same palette colour.
    sns.kdeplot(
        x=phi_psi[:, 0],
        y=phi_psi[:, 1],
        color=sns.color_palette()[0],
        fill=False, thresh=0.05, ax=ax1, levels=10, linewidths=0.1, bw_method=kde_bandwidth
    )
    # Reference structures: contour lines only, in the second palette colour.
    sns.kdeplot(
        x=phi_psi_true[:, 0],
        y=phi_psi_true[:, 1],
        color=sns.color_palette()[1],
        fill=False, thresh=0.005, ax=ax1, levels=10, linewidths=0.5, bw_method=kde_bandwidth
    )

    quarter_turn_ticks = [-np.pi, -np.pi/2, 0, np.pi/2, np.pi]
    ax1.set_xlim(-np.pi, np.pi)
    ax1.set_ylim(-np.pi, np.pi)
    ax1.set_xlabel('Phi [rad]')
    ax1.set_ylabel('Psi [rad]')
    ax1.set_xticks(quarter_turn_ticks)
    ax1.set_xticklabels([r'-$\pi$', r'-$\pi/2$','0', r'$\pi$/2', r'$\pi$'])
    ax1.set_yticks(quarter_turn_ticks)
    ax1.set_yticklabels([r'-$\pi$', r'-$\pi/2$','0', r'$\pi/2$', r'$\pi$'])

    plt.show()

In [None]:
# Filename prefixes for the dihedral analysis: every "<prefix>*.pdb" file in
# ROOT_TARGET / ROOT_REF is analysed.
TAG_TARGET = "backmapped_fixed"
TAG_REF = "true_fixed"
# True: pooled density plot (plot_distribution); False: one figure per file pair.
PLOT_DISTRIBUTION = True

# glob returns files in arbitrary order. Both lists are sorted so that entry i
# of dih_list and entry i of dih_list_true refer to corresponding files when
# they are later paired with zip (assumes the two directories hold files with
# matching suffixes after the prefix).
dih_list = []
for trg_filename in sorted(glob.glob(os.path.join(ROOT_TARGET, f"{TAG_TARGET}*.pdb"))):
    rama, janin = compute_dihedrals(trg_filename)
    dih_list.append((rama, janin))

dih_list_true = []
for src_filename in sorted(glob.glob(os.path.join(ROOT_REF, f"{TAG_REF}*.pdb"))):
    rama_true, janin_true = compute_dihedrals(src_filename)
    dih_list_true.append((rama_true, janin_true))

In [None]:
# Choose between the pooled density plot and the per-file scatter figures.
if PLOT_DISTRIBUTION:
    # NOTE(review): plot_distribution accepts figsize and bins but does not use
    # them (the figure size is fixed inside the function), so these two
    # arguments currently have no effect.
    plot_distribution(dih_list, dih_list_true, dataset="HEqBM - CG-A2A", figsize=(80, 10), bins=90)
else:
    plot_rama_janin(dih_list, dih_list_true, title="Back-mapped", ref_back=True)