In [None]:
import os
import glob
import numpy as np
import MDAnalysis as mda
from MDAnalysis.analysis.rms import RMSD
from MDAnalysis.analysis.rdf import InterRDF
from MDAnalysis.transformations.boxdimensions import set_dimensions
from os.path import basename
from MDAnalysis.analysis.dihedrals import Ramachandran, Janin

from scipy.stats import norm
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd

import warnings
warnings.filterwarnings('ignore')

In [None]:
# Directories holding the backmapped (target) and reference PDB files.
# Either can be overridden through an environment variable instead of editing
# the notebook, e.g. BACKMAP_ROOT_TARGET=backmapped/CG-all/protein-whole/, or
# both pointing at a GenZProt output directory such as
# <GenZProt>/scripts/result_model_seed_12345_PED00218_CA_trace
ROOT_TARGET = os.environ.get("BACKMAP_ROOT_TARGET", "comparison/cg2at/")
ROOT_REF = os.environ.get("BACKMAP_ROOT_REF", "backmapped/A2A/atomistic/protein/minimised/")

# Compute Steric Clashes #

In [None]:
TAG_TARGET = "all-whole"
CLASH_CUTOFF = 1.2    # Å: heavy-atom pairs closer than this count as clashes
CONTACT_CUTOFF = 5.0  # Å: clashes are normalised by the number of pairs closer than this


def compute_clash_ratio(positions, clash_cutoff=CLASH_CUTOFF, contact_cutoff=CONTACT_CUTOFF):
    """Return the fraction of close pairs (< contact_cutoff) that are clashes (< clash_cutoff).

    Parameters
    ----------
    positions : numpy.ndarray
        (n_atoms, 3) coordinates of one frame.
    clash_cutoff, contact_cutoff : float
        Distance thresholds (same units as ``positions``).

    Returns
    -------
    float
        The clash ratio, or NaN when no pair lies within ``contact_cutoff``
        (e.g. fewer than two atoms), instead of dividing by zero.
    """
    # Condensed i<j distances without a periodic box: N*(N-1)/2 values instead
    # of the N*N*3 temporary produced by broadcasting positions against itself.
    dist = mda.lib.distances.self_distance_array(positions)
    n_contacts = np.count_nonzero(dist < contact_cutoff)
    if n_contacts == 0:
        return np.nan
    return np.count_nonzero(dist < clash_cutoff) / n_contacts


# sorted() keeps the printed order reproducible across filesystems.
for trg_filename in sorted(glob.glob(os.path.join(ROOT_TARGET, f"{TAG_TARGET}*.pdb"))):
    u = mda.Universe(trg_filename)
    protein = u.select_atoms('not (type H)')
    # protein.positions is re-read on every iteration as the trajectory advances.
    clash_ratios = np.array([compute_clash_ratio(protein.positions) for _ts in u.trajectory])
    print(f"{basename(trg_filename)}: mean={clash_ratios.mean()}, var={clash_ratios.var()}")

# Compute RMSD distribution #

In [None]:
TAG_TARGET = "backmapped_fixed_"
TAG_REF = "true_fixed_"

# Atoms left out of every RMSD selection: ACE/NME caps, OXT and hydrogens.
_RMSD_EXCLUDED = "not (resname ACE NME or name OXT or type H)"

# Backbone atoms: used for superposition and reported as the first RMSD row.
SELECT = f"protein and backbone and {_RMSD_EXCLUDED}"
# Extra RMSDs computed after the backbone fit (here: side-chain heavy atoms).
GROUP_SELECTIONS = [f"protein and (not backbone) and {_RMSD_EXCLUDED}"]

# To compare all heavy atoms in a single selection instead:
# SELECT = 'not type H'
# GROUP_SELECTIONS = []

def get_rmsd(ref_frame=0):
    """Compute the RMSD of every backmapped file against its reference counterpart.

    Each ``ROOT_TARGET/TAG_TARGET*.pdb`` file is paired with
    ``ROOT_REF/<same name, TAG_TARGET prefix replaced by TAG_REF>``. RMSD is
    computed on ``SELECT`` (which is also used for superposition) and on each
    selection in ``GROUP_SELECTIONS``.

    Parameters
    ----------
    ref_frame : int
        Frame of the reference universe used as RMSD reference.

    Returns
    -------
    numpy.ndarray
        Shape ``(1 + len(GROUP_SELECTIONS), total_frames)``. Row 0 is the
        ``SELECT`` RMSD and the following rows are the group selections. Frames
        of all files are concatenated in sorted-filename order.

    Raises
    ------
    FileNotFoundError
        If no target file matches the pattern.
    """
    pattern = os.path.join(ROOT_TARGET, f"{TAG_TARGET}*.pdb")
    # Sorting makes the concatenated frame order, and therefore frame IDs
    # such as the argmax printed later, reproducible.
    trg_filenames = sorted(glob.glob(pattern))
    if not trg_filenames:
        raise FileNotFoundError(f"No target files match {pattern!r}")

    rmsd = []
    for trg_filename in trg_filenames:
        # Swap only the leading tag (guaranteed by the glob pattern);
        # str.replace would also rewrite later occurrences inside the name.
        src_basename = TAG_REF + basename(trg_filename)[len(TAG_TARGET):]
        src_filename = os.path.join(ROOT_REF, src_basename)
        u = mda.Universe(trg_filename)
        ref = mda.Universe(src_filename)

        R = RMSD(u, ref,
                 select=SELECT,
                 groupselections=GROUP_SELECTIONS,
                 ref_frame=ref_frame,
        )
        R.run()
        # Columns 0 and 1 are frame index and time; the rest are RMSD values.
        rmsd.append(R.results.rmsd[:, 2:])
    return np.concatenate(rmsd, axis=0).T

def plot_rmsd_distribution(rmsd, bins=20, show_sidechains=False, title=None):
    """Plot RMSD histogram(s) with a fitted normal curve overlaid.

    Parameters
    ----------
    rmsd : numpy.ndarray
        Array as returned by ``get_rmsd``. Row 0 is the backbone (``SELECT``)
        RMSD per frame. Row 1, if present, is the first group selection
        (e.g. side chains).
    bins : int
        Number of equal-width histogram bins.
    show_sidechains : bool
        Also plot row 1. Ignored when ``rmsd`` has a single row.
    title : str, optional
        Axes title.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig, ax = plt.subplots(figsize=(8, 8))

    # (values, histogram colour, fit colour, legend label)
    series = [(rmsd[0], 'orange', 'k', "Backbone")]
    if show_sidechains and len(rmsd) > 1:
        series.append((rmsd[1], 'c', 'c', "Side-Chains"))

    fits = []
    for values, hist_color, fit_color, label in series:
        _, edges, _ = ax.hist(values, density=False, bins=bins, histtype='step',
                              linewidth=1, edgecolor=hist_color, fill=False, label=label)
        fits.append((values, edges[1] - edges[0], fit_color))

    xmin, xmax = ax.get_xlim()
    x = np.linspace(xmin, xmax, 100)
    for values, bin_width, fit_color in fits:
        mu, std = norm.fit(values)
        if std == 0:
            continue  # all values identical: no Gaussian to draw
        # The histogram shows counts, not a density, so the pdf is scaled by
        # N * bin width to put the fitted curve on the same axis as the bars.
        ax.plot(x, norm.pdf(x, mu, std) * len(values) * bin_width, fit_color, linewidth=.5)

    ax.legend(loc="best", prop={'size': 24})
    ax.set_xlabel(r"RMSD ($\AA$)", size=24)
    ax.set_ylabel("Count", size=24)
    if title is not None:
        ax.set_title(title)
    plt.show()
    return fig

# Sensitivity check: minimum RMSD as a function of the chosen reference frame.
# for i in range(20):
#     print(get_rmsd(ref_frame=i).min(axis=1))

rmsd = get_rmsd(ref_frame=0)
fig = plot_rmsd_distribution(rmsd)
# fig.savefig("RMSD_dist_PED00055.svg")

# Per-row statistics (row 0 = backbone selection, further rows = group selections).
rmsd_mean = rmsd.mean(axis=1)
rmsd_std = rmsd.std(axis=1)
max_rmsd_frame = np.argmax(rmsd[0])
print(f"Mean RMSD: {rmsd_mean}")
print(f"STD: {rmsd_std}")
print(f"ID of frame with Max RMSD: {max_rmsd_frame}")

# Plot Dihedrals #

In [None]:
def plot_rama_janin(dih_list, dih_list_true, ref_back: bool = True):
    """Draw Ramachandran (phi/psi) and Janin (chi1/chi2) scatter plots.

    Backmapped dihedrals are drawn with the analyses' own ``plot`` method as
    orange diamonds. Reference dihedrals are overlaid as translucent crosses.

    Parameters
    ----------
    dih_list : list of (Ramachandran, Janin)
        Finished analyses of the backmapped structures.
    dih_list_true : list of (Ramachandran, Janin)
        Finished analyses of the reference structures. May be empty.
    ref_back : bool
        Passed as ``ref`` to ``plot`` to draw MDAnalysis' reference background.

    Returns
    -------
    (matplotlib.figure.Figure, matplotlib.figure.Figure)
        The Ramachandran figure and the Janin figure.
    """
    fig_rama, ax_rama = plt.subplots(figsize=(12, 12), facecolor='white')
    fig_janin, ax_janin = plt.subplots(figsize=(12, 12), facecolor='white')

    scatter_style = dict(marker='d', color='orange', edgecolors='black', s=25, linewidth=.5)
    for i, (rama, janin) in enumerate(dih_list):
        # The reference background is identical for every structure and is
        # drawn before the points. Redrawing it for each structure painted it
        # over the points of all previous structures, so draw it once.
        draw_ref = ref_back and i == 0
        rama.plot(ax=ax_rama, ref=draw_ref, **scatter_style)
        janin.plot(ax=ax_janin, ref=draw_ref, **scatter_style)

    for rama_true, janin_true in dih_list_true:
        for analysis, ax in ((rama_true, ax_rama), (janin_true, ax_janin)):
            angles = analysis.results.angles.reshape(-1, 2)
            ax.scatter(angles[:, 0], angles[:, 1], marker='x', s=20, alpha=0.2,
                       color='black', edgecolors='red', linewidth=2)

    for ax in (ax_rama, ax_janin):
        ax.set_xlabel(r'')
        ax.set_ylabel(r'')

    # Ramachandran angles span [-180, 180] degrees.
    rama_ticks = [-180, -90, 0, 90, 180]
    rama_labels = [r'-$\pi$', r'-$\pi/2$', '0', r'$\pi/2$', r'$\pi$']
    ax_rama.set_xticks(rama_ticks)
    ax_rama.set_xticklabels(rama_labels)
    ax_rama.set_yticks(rama_ticks)
    ax_rama.set_yticklabels(rama_labels)

    # Janin angles span [0, 360] degrees, so these ticks read 0..2*pi
    # (previously mislabelled as -pi..pi).
    janin_ticks = [0, 90, 180, 270, 360]
    janin_labels = ['0', r'$\pi/2$', r'$\pi$', r'$3\pi/2$', r'$2\pi$']
    ax_janin.set_xticks(janin_ticks)
    ax_janin.set_xticklabels(janin_labels)
    ax_janin.set_yticks(janin_ticks)
    ax_janin.set_yticklabels(janin_labels)

    plt.show()
    return fig_rama, fig_janin

# Residue names passed to the Janin (chi1/chi2) analysis. HIE/HID are
# presumably Amber histidine tautomer names.
# NOTE(review): PHE and HIP also carry a chi2 and are not excluded by
# MDAnalysis' default Janin selection; confirm their omission is intentional.
JANIN_RESNAMES = "ARG ASN ASP GLN GLU HIE HID HIS ILE LEU LYS MET TRP TYR"


def compute_dihedrals(filename: str, chi_resnames=JANIN_RESNAMES):
    """Run Ramachandran and Janin dihedral analyses on one structure file.

    Parameters
    ----------
    filename : str
        File loadable by ``mda.Universe``.
    chi_resnames : str or iterable of str
        Residue names included in the Janin analysis. A string is used as a
        space-separated list.

    Returns
    -------
    (Ramachandran, Janin)
        Both analyses, already ``run()``.
    """
    if not isinstance(chi_resnames, str):
        chi_resnames = " ".join(chi_resnames)
    u = mda.Universe(filename)
    rama = Ramachandran(u.select_atoms('protein')).run()
    janin = Janin(u.select_atoms(f'protein and resname {chi_resnames}')).run()
    return rama, janin

def normalize(x):
    """Convert angles from degrees to radians and wrap them into [-pi, pi].

    Inputs in [-180, 180] (e.g. phi/psi) are only converted. Inputs in
    (180, 360] (e.g. Janin chi angles) have a full turn (2*pi) removed, so a
    gauche- chi at 300 degrees becomes -60 degrees.

    Parameters
    ----------
    x : array-like
        Angles in degrees, any shape, expected within [-180, 360].

    Returns
    -------
    numpy.ndarray
        Angles in radians, same shape as ``x``.
    """
    x = np.asarray(x, dtype=float) / 180 * np.pi
    # Wrap element-wise: only angles beyond pi move. Shifting the whole array
    # by pi whenever any value exceeded pi also moved angles that were
    # already in range, and made the result depend on the rest of the array.
    return np.where(x > np.pi, x - 2 * np.pi, x)

# Tick labels for an axis spanning [-pi, pi] in steps of pi/2.
_PI_TICKLABELS = [r'-$\pi$', r'-$\pi/2$', '0', r'$\pi/2$', r'$\pi$']


def _stack_angles(dih_list, index):
    """Stack the angle pairs of analysis ``index`` (0 = Ramachandran, 1 = Janin) into an (n, 2) radian array."""
    if not dih_list:
        # An explicit (0, 2) array keeps [:, k] and len() working; the former
        # np.array([[], []]) had shape (2, 0) and made [:, 0] raise IndexError.
        return np.empty((0, 2))
    angles = np.concatenate([dih[index].results.angles.reshape(-1, 2) for dih in dih_list], axis=0)
    return normalize(angles)


def _plot_angle_hist(ax, backmapped, original, column, bins, ticklabels):
    """Overlay probability histograms of backmapped vs. original angles on ``ax``."""
    df = pd.DataFrame(
        {
            column: np.concatenate([backmapped, original]),
            "": np.array(["Backmapped"] * len(backmapped) + ["Original"] * len(original)),
        }
    )
    sns.histplot(data=df,
                 x=column,
                 bins=bins,
                 hue="",
                 stat="probability",
                 element="poly",
                 ax=ax,
    ).set(xlabel=None)
    ax.set_xlim(xmin=-np.pi, xmax=np.pi)
    ax.set_yticks([])
    ax.set_ylabel('')
    ax.set_xticks([-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi])
    ax.set_xticklabels(ticklabels)


def plot_distribution(dih_list, dih_list_true, title: str = None, thresh=0.05, ref_thresh=0.05, bins=60, show_chi=True):
    """Plot a phi/psi density and, optionally, chi1/chi2 histograms.

    The left panel shows a filled KDE of the backmapped phi/psi angles, with
    reference contours overlaid when reference data is given. With
    ``show_chi`` the right column holds chi1 (top) and chi2 (bottom)
    histograms of backmapped vs. original structures.

    Parameters
    ----------
    dih_list : list of (Ramachandran, Janin)
        Finished analyses of the backmapped structures.
    dih_list_true : list of (Ramachandran, Janin) or None
        Finished analyses of the reference structures. May be empty or None.
    title : str, optional
        Title of the phi/psi panel.
    thresh, ref_thresh : float
        KDE ``thresh`` for the backmapped and reference densities.
    bins : int
        Number of histogram bins for the chi panels.
    show_chi : bool
        Whether to draw the chi1/chi2 panels.

    Returns
    -------
    matplotlib.figure.Figure
    """
    csfont = {'fontname': 'Comic Sans MS'}
    dih_list_true = dih_list_true or []

    fig = plt.figure(figsize=(12, 6), facecolor='white')
    ax1 = plt.subplot(1, 2, 1)
    if show_chi:
        ax2 = plt.subplot(2, 2, 2)
        ax3 = plt.subplot(2, 2, 4)
    plt.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=0.9, wspace=0.02, hspace=0.02)

    # --- phi/psi panel ---
    phi_psi = _stack_angles(dih_list, 0)
    if title is not None:
        ax1.set_title(title, **csfont)
    # bw_method replaces seaborn's deprecated `bw` argument (same meaning).
    sns.kdeplot(
        x=phi_psi[:, 0],
        y=phi_psi[:, 1],
        cmap=sns.color_palette(f"blend:#EEE,{sns.color_palette().as_hex()[0]}", as_cmap=True),
        fill=True, thresh=thresh, ax=ax1, levels=10, bw_method=0.18,
    )
    sns.kdeplot(
        x=phi_psi[:, 0],
        y=phi_psi[:, 1],
        color=sns.color_palette()[0],
        fill=False, thresh=thresh, ax=ax1, levels=10, linewidths=0.1, bw_method=0.18,
    )
    if dih_list_true:
        phi_psi_true = _stack_angles(dih_list_true, 0)
        sns.kdeplot(
            x=phi_psi_true[:, 0],
            y=phi_psi_true[:, 1],
            color=sns.color_palette()[1],
            fill=False, thresh=ref_thresh, ax=ax1, levels=10, linewidths=0.5, bw_method=0.18,
        )
    ticks = [-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi]
    ax1.set_xlim(xmin=-np.pi, xmax=np.pi)
    ax1.set_ylim(ymin=-np.pi, ymax=np.pi)
    ax1.set_xticks(ticks)
    ax1.set_xticklabels(_PI_TICKLABELS)
    ax1.set_yticks(ticks)
    ax1.set_yticklabels(_PI_TICKLABELS)

    # --- chi1 (top) / chi2 (bottom) panels ---
    if show_chi:
        chi1_chi2 = _stack_angles(dih_list, 1)
        chi1_chi2_true = _stack_angles(dih_list_true, 1)
        # The chi1 panel sits above the chi2 panel with the same x range, so
        # only the bottom panel carries tick labels.
        _plot_angle_hist(ax2, chi1_chi2[:, 0], chi1_chi2_true[:, 0], "Angle [rad]", bins, [''] * 5)
        _plot_angle_hist(ax3, chi1_chi2[:, 1], chi1_chi2_true[:, 1], "Torsion [rad]", bins, _PI_TICKLABELS)

    plt.show()
    return fig

In [None]:
TAG_TARGET = "4_"
# NOTE(review): "xxx" looks like a placeholder. If it matches no file,
# dih_list_true stays empty and only backmapped dihedrals are available.
TAG_REF = "xxx"
PLOT_DISTRIBUTION = False


def _dihedrals_for(root, tag, verbose=False):
    """Run compute_dihedrals on every ``root/tag*.pdb`` file, in sorted (reproducible) order."""
    results = []
    for filename in sorted(glob.glob(os.path.join(root, f"{tag}*.pdb"))):
        if verbose:
            print(filename)
        results.append(compute_dihedrals(filename))
    return results


dih_list = _dihedrals_for(ROOT_TARGET, TAG_TARGET, verbose=True)
dih_list_true = _dihedrals_for(ROOT_REF, TAG_REF)

In [None]:
if PLOT_DISTRIBUTION:
    fig = plot_distribution(dih_list, dih_list_true, bins=90, thresh=0.005, ref_thresh=0.005, show_chi=False)
else:
    fig_rama, fig_janin = plot_rama_janin(dih_list, dih_list_true, ref_back=True)
    # Only this branch creates fig_rama/fig_janin. Saving them outside the
    # branch raised NameError on a fresh kernel (or saved stale figures)
    # when PLOT_DISTRIBUTION was True.
    fig_rama.savefig("cg2at-minwithcg2at-rama.svg")
    fig_janin.savefig("cg2at-minwithcg2at-janin.svg")

# Plot RDF #

In [None]:
def compute_rdf(filename: str, sel1, sel2, range=None, nbins=100, exclude_same="residue"):
    """Compute the RDF between two selections of one structure file.

    Every frame is given an artificial cubic box of side ``2 * d`` (replacing
    any box in the file). ``d`` is the spread of all coordinate values (max
    over x/y/z minus min over x/y/z) in the first frame. Because each per-axis
    extent is at most ``d``, minimum-image distances equal direct distances
    for that frame.
    NOTE(review): later frames are assumed not to spread further - confirm
    for multi-frame files.

    Parameters
    ----------
    filename : str
        File loadable by ``mda.Universe``.
    sel1, sel2 : str
        MDAnalysis selection strings of the two groups.
    range : tuple of float, optional
        ``(rmin, rmax)`` of the histogram. Defaults to ``(0, d)``, which
        differs per file. (The name shadows the built-in; it is kept for
        keyword callers.)
    nbins : int
        Number of RDF bins.
    exclude_same : str or None
        Forwarded to ``InterRDF``. By default, pairs within the same residue
        are excluded.

    Returns
    -------
    (InterRDF, float)
        The finished analysis, and the number of atoms divided by the
        artificial box volume.
    """
    u = mda.Universe(filename)
    d = np.max(u.atoms.positions) - np.min(u.atoms.positions)
    if range is None:
        range = (0, d)
    u.trajectory.add_transformations(set_dimensions([2 * d, 2 * d, 2 * d, 90, 90, 90]))
    s1 = u.select_atoms(sel1)
    s2 = u.select_atoms(sel2)
    rdf = InterRDF(s1, s2, range=range, nbins=nbins, norm="rdf", exclude_same=exclude_same)
    rdf.run()
    density = u.atoms.n_atoms / mda.lib.mdamath.box_volume(u.dimensions)
    return rdf, density

TAG_TARGET = "backmapped_fixed_100"
TAG_REF = "true_fixed_100"
RANGE = None  # e.g. (0, 20); required when several files are averaged (see _mean_rdf)
RADIAL_S1 = 'resname PA'
RADIAL_S2 = 'resname PA'


def _mean_rdf(rdfs, what):
    """Average g(r) over finished InterRDF runs; return (bin centres, mean g(r))."""
    if not rdfs:
        raise FileNotFoundError(f"No {what} files found for the RDF comparison")
    bins = rdfs[0].results.bins
    for other in rdfs[1:]:
        # With RANGE=None each file gets its own (0, d) range, so bins differ
        # and averaging g(r) element-wise would mix different distances.
        if not np.allclose(other.results.bins, bins):
            raise ValueError("RDF bins differ between files; set RANGE to a fixed (rmin, rmax)")
    # np.mean builds a new array. The former in-place `+=` on
    # rdfs[0].rdf silently modified the first analysis' stored result, and
    # a sum made curves incomparable when file counts differ.
    return bins, np.mean([r.results.rdf for r in rdfs], axis=0)


rdf_list = []
for trg_filename in sorted(glob.glob(os.path.join(ROOT_TARGET, f"{TAG_TARGET}*.pdb"))):
    _rdf, _ = compute_rdf(trg_filename, RADIAL_S1, RADIAL_S2, RANGE)
    rdf_list.append(_rdf)

rdf_list_true = []
for src_filename in sorted(glob.glob(os.path.join(ROOT_REF, f"{TAG_REF}*.pdb"))):
    # `density` keeps the value of the last reference file; it scales both curves.
    _rdf_true, density = compute_rdf(src_filename, RADIAL_S1, RADIAL_S2, RANGE)
    rdf_list_true.append(_rdf_true)

bins, rdf = _mean_rdf(rdf_list, "target")
bins_true, rdf_true = _mean_rdf(rdf_list_true, "reference")

In [None]:
# `density` comes from compute_rdf on the last reference file and scales both
# curves identically, so their shapes stay directly comparable.
fig, ax = plt.subplots()
ax.plot(bins_true, density * rdf_true, label="Atomistic", c='k')
ax.plot(bins, density * rdf, label="HEqBM", c='r', linestyle='--')
ax.set_xlabel(r"r ($\AA$)")
ax.set_ylabel(r"$\rho \cdot g(r)$")
ax.legend()
fig.savefig("PA.svg")

# RMSD per residue #

In [None]:
TAG_TARGET = "backmapped_fixed_"
TAG_REF = "true_fixed_"
# Output PDB; the per-residue RMSD is written into its occupancy and B-factor columns.
PER_RESIDUE_OUTPUT = "asd.pdb"


def _per_residue_rmsd(u, ref):
    """Return one value per atom of ``u``: the heavy-atom RMSD of that atom's residue.

    Each residue is superimposed and compared on its own (RMSD with a
    per-residue selection and ``ref_frame=0``). The value is then assigned to
    every atom of the residue, hydrogens included.

    Raises
    ------
    ValueError
        If the structures have more than one frame (one value per atom is
        expected).
    """
    values = np.zeros(u.atoms.n_atoms)
    for residue in u.residues:
        R = RMSD(u, ref,
                 select=f"resindex {residue.resindex} and not (type H)",
                 groupselections=[],
                 ref_frame=0,
        )
        R.run()
        per_frame = R.results.rmsd[:, 2]
        if per_frame.size != 1:
            raise ValueError("Per-residue RMSD export expects single-frame structures")
        # Index by atom indices so values line up with u.atoms even if a
        # residue's atoms are not stored contiguously.
        values[residue.atoms.indices] = per_frame[0]
    return values


trg_filenames = sorted(glob.glob(os.path.join(ROOT_TARGET, f"{TAG_TARGET}*.pdb")))
if not trg_filenames:
    raise FileNotFoundError(f"No target files match {TAG_TARGET}*.pdb in {ROOT_TARGET}")

# Only the first (sorted) file is exported. Computing inside a helper avoids
# overwriting the global `rmsd` array of the RMSD-distribution section.
trg_filename = trg_filenames[0]
src_filename = os.path.join(ROOT_REF, TAG_REF + basename(trg_filename)[len(TAG_TARGET):])
u = mda.Universe(trg_filename)
ref = mda.Universe(src_filename)

rmsd_array = _per_residue_rmsd(u, ref)
u.add_TopologyAttr('occupancy', rmsd_array)
u.add_TopologyAttr('tempfactors', rmsd_array)
u.select_atoms('all').write(PER_RESIDUE_OUTPUT)