# Processing molecular dynamics trajectories #

In this experiment we showcase how the tools from Riemannian geometry help us understand protein data sets. In particular, we want to show how 
* having geodesics under a suitable Riemannian metric can preserve important features of the data
* having a suitable notion of distance is useful in computing a natural mean of the data
* having a logarithmic mapping is useful for dimension reduction, and having an exponential mapping is useful for visualizing the subspace

In [22]:
import mdtraj as md
import numpy as np
import torch
import os
import matplotlib.pyplot as plt
import sys

# Global figure styling shared by every plot in this notebook.
plt.rcParams.update({'font.size': 18})
plt.rcParams.update({'figure.autolayout': True})

# Make the repository root importable without changing the working directory.
# The former `%cd ..` / import / `%cd experiments` sequence left the kernel in
# the repository root whenever the import raised, which silently broke every
# relative path used below (and a re-run then moved one level further up).
# NOTE(review): like the original `%cd ..`, this assumes the notebook is
# started from the `experiments` folder.
repo_root = os.path.abspath("..")
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)
from src.manifolds.pointcloud import PointCloud

/home/sc.uni-leipzig.de/nq194gori/github/Riemannian-geometry-for-efficient-analysis-of-protein-dynamics-data
/home/sc.uni-leipzig.de/nq194gori/github/Riemannian-geometry-for-efficient-analysis-of-protein-dynamics-data/experiments


In [23]:
# Root folder for every saved figure: <current working directory>/results.
results_folder = os.path.abspath("results")

## Load data ##

In [24]:
# Select the molecular-dynamics system to analyse:
#   1 -> "4ake" data set, 2 -> "covid_spike" data set.
struct = 2

wd = ".."
data_folder = os.path.join("data", "molecular_dynamics")

# Per-system inputs: trajectory and topology files, output sub-folder, figure
# file-name prefix and the half-width `bbox` of the cubic 3D plotting box.
if struct == 1:
    trajectory_path = os.path.join(wd, data_folder, "4ake", "dims0001_fit-core.dcd")
    topology_path = os.path.join(wd, data_folder, "4ake", "adk4ake.psf")
    results_folder = os.path.join(results_folder, "4ake")
    fig_prefix = "4ake"
    bbox = 20
elif struct == 2:
    trajectory_path = os.path.join(wd, data_folder, "covid_spike", "MDtraj_sarscov_2.dcd")
    topology_path = os.path.join(wd, data_folder, "covid_spike", "DESRES-Trajectory_sarscov2-12212688-5-2-no-water.pdb")
    results_folder = os.path.join(results_folder, "covid_spike")
    fig_prefix = "covid_spike"
    bbox = 40
else:
    # Fail here instead of with a NameError on `trajectory_path` further down.
    raise ValueError(f"unknown struct={struct!r}; expected 1 (4ake) or 2 (covid_spike)")

# NOTE: `results_folder` is extended in place, so re-run the cell defining it
# before re-running this one; otherwise the sub-folder is appended twice.
# Every savefig below writes into this folder, so make sure it exists.
os.makedirs(results_folder, exist_ok=True)

t = md.load(trajectory_path, top=topology_path)

# Atom indices of all C-alpha atoms; these form the point cloud.
indices = [atom.index for atom in t.topology.atoms_by_name('CA')]

# The factor 10 converts mdtraj's coordinates (nanometres) to Angstrom, the unit
# used in all plot labels. Indexing keeps all frames, so the shape is
# (num_frames, num_CA_atoms, 3).
ca_pos = 10 * torch.tensor(t.xyz[:, indices, :])
# Optional temporal subsampling of the spike trajectory:
# if struct == 2:
#     ca_pos = ca_pos[0:-1:2]

In [25]:
# Build the point-cloud manifold with the first frame as its base point.
num_proteins = ca_pos.shape[0]
print(num_proteins)
protein_len = ca_pos.shape[1]
manifold = PointCloud(3, protein_len, base=ca_pos[0], alpha=1.)

# Re-orient the base point (only affects how the figures look).
# Step 1: rotation in the x-z plane -- new x = -old z, new z = old x.
rot_xz = torch.tensor([[0., 0., -1.],
                       [0., 1., 0.],
                       [1., 0., 0.]])
manifold.base_point = torch.einsum("ba,ia->ib", rot_xz, manifold.base_point)

# Step 2: rotation by theta = -pi/3 in the x-y plane.
theta = torch.tensor([- torch.pi * 1/3])
cos_theta = torch.cos(theta).item()
sin_theta = torch.sin(theta).item()
rot_xy = torch.tensor([[cos_theta, -sin_theta, 0.],
                       [sin_theta, cos_theta, 0.],
                       [0., 0., 1.]])
manifold.base_point = torch.einsum("ba,ia->ib", rot_xy, manifold.base_point)

200


In [26]:
# Align every frame of the trajectory with the (re-oriented) base point. A
# leading batch axis is added for align_mpoint and all singleton axes are
# squeezed away afterwards -- presumably leaving
# (num_proteins, protein_len, 3); NOTE(review): confirm.
proteins = manifold.align_mpoint(ca_pos.unsqueeze(0), base=manifold.base_point).squeeze()

In [None]:
# Display the shape of the aligned trajectory tensor (the cell's output value).
proteins.shape

## Separation-geodesic interpolation ##

In [None]:
t_steps = 21
p0 = proteins[0]
p1 = proteins[-1]
T = torch.linspace(0,1,t_steps) # torch.tensor([1/4,1/2,3/4]) # torch.tensor([1/2])
pt = torch.zeros(t_steps, protein_len, 3)
mdt = torch.zeros(t_steps, protein_len, 3)
for i,t in enumerate(T):
    print(f"computing geodesic {i+1}")
    md_ind = int(i/(t_steps-1) * (num_proteins-1))
    pt[i] = manifold.s_geodesic(p0[None,None], p1[None,None], torch.tensor([t]), debug=True).squeeze()
    mdt[i] = proteins[md_ind]
    
%timeit manifold.s_geodesic(p0[None,None], p1[None,None], torch.tensor([T[int(t_steps/2)]]))

In [27]:
import torch.multiprocessing as mp

def compute_geodesic(i, t, p0, p1, pt, mdt, manifold, proteins, num_proteins, t_steps):
    """Fill row ``i`` of the output tensors; used as a per-process worker.

    Stores the separation-geodesic from ``p0`` to ``p1`` evaluated at time ``t``
    in ``pt[i]``, and the trajectory frame at the same relative position
    (``i / (t_steps - 1)`` of the way through ``proteins``) in ``mdt[i]``.
    Nothing is returned; results are delivered only via the in-place writes.
    """
    print(f"computing geodesic {i+1}")
    fraction = i / (t_steps - 1)
    frame = int(fraction * (num_proteins - 1))
    start, end = p0[None, None], p1[None, None]
    sample = manifold.s_geodesic(start, end, torch.tensor([t]), debug=True)
    pt[i] = sample.squeeze()
    mdt[i] = proteins[frame]

t_steps = 21
p0 = proteins[0]
p1 = proteins[-1]
T = torch.linspace(0,1,t_steps)

pt = mp.Array('f', t_steps * protein_len * 3)
pt = torch.from_numpy(np.frombuffer(pt.get_obj(), dtype=np.float32).reshape(t_steps, protein_len, 3))

mdt = mp.Array('f', t_steps * protein_len * 3)
mdt = torch.from_numpy(np.frombuffer(mdt.get_obj(), dtype=np.float32).reshape(t_steps, protein_len, 3))

processes = []
for i, t in enumerate(T):
    p = mp.Process(target=compute_geodesic, args=(i, t, p0, p1, pt, mdt, manifold, proteins, num_proteins, t_steps))
    p.start()
    processes.append(p)

# Wait for all processes to finish
for p in processes:
    p.join()

%timeit manifold.s_geodesic(p0[None,None], p1[None,None], torch.tensor([T[int(t_steps/2)]]))

computing geodesic 1
computing geodesic 2
computing geodesic 3
computing geodesic 4
computing geodesic 5
computing geodesic 6
computing geodesic 7
computing geodesic 8
computing geodesic 9
computing geodesic 10
computing geodesic 11
computing geodesic 12
computing geodesic 13
computing geodesic 14
computing geodesic 15
computing geodesic 16
computing geodesic 17
computing geodesic 18
computing geodesic 19
computing geodesic 21
computing geodesic 20


In [None]:
# Plot `num_figs` equally spaced snapshots of the geodesic interpolation `pt`
# (first loop) and of the matching MD frames `mdt` (second loop) as 3D
# C-alpha traces coloured by position along the chain, and save each figure.
formt = '.png'  # file extension used by the savefig calls
num_figs = 5
# Nominal size, scaled with the cube root of the chain length relative to 214.
# NOTE(review): the "cm" label is nominal only -- matplotlib's figsize is in
# inches and the figures below use fig_size/5.
fig_size = 31 * (protein_len/214) ** (1/3) # cm

for i in range(num_figs):
    # index of the geodesic sample shown in figure i (0 ... t_steps-1)
    g_ind = int(i/(num_figs-1) * (t_steps-1))
    fig = plt.figure(figsize=(fig_size/5, fig_size/5))
    ax = plt.axes(projection='3d')
    ax.plot3D(pt[g_ind,:,0],pt[g_ind,:,1],pt[g_ind,:,2], 'black')
    ax.scatter(pt[g_ind,:,0],pt[g_ind,:,1],pt[g_ind,:,2], c = torch.linspace(0,1,protein_len), cmap = 'rainbow')
    # same cube [-bbox, bbox]^3 for every snapshot so the figures are comparable
    ax.axes.set_xlim3d(left=-bbox, right=bbox)
    ax.axes.set_ylim3d(bottom=-bbox, top=bbox)
    ax.axes.set_zlim3d(bottom=-bbox, top=bbox)
    ax.set_axis_off()
    # crop the saved image to the central 80% of the 3D axes' window extent
    extent = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
    plt.savefig(os.path.join(results_folder,fig_prefix +f'_{g_ind}f{t_steps-1}_geo'+formt), bbox_inches=extent.expanded(0.8, 0.8))
plt.show()


for i in range(num_figs):
    g_ind = int(i/(num_figs-1) * (t_steps-1))
    # MD frame index used in the file name; it equals the frame stored in
    # mdt[g_ind] when (t_steps-1) is a multiple of (num_figs-1), e.g. 21 and 5
    md_ind = int(i/(num_figs-1) * (num_proteins-1))
    fig = plt.figure(figsize=(fig_size/5, fig_size/5))
    ax = plt.axes(projection='3d')
    ax.plot3D(mdt[g_ind,:,0],mdt[g_ind,:,1],mdt[g_ind,:,2], 'black')
    ax.scatter(mdt[g_ind,:,0],mdt[g_ind,:,1],mdt[g_ind,:,2], c = torch.linspace(0,1,protein_len), cmap = 'rainbow')
    ax.axes.set_xlim3d(left=-bbox, right=bbox)
    ax.axes.set_ylim3d(bottom=-bbox, top=bbox)
    ax.axes.set_zlim3d(bottom=-bbox, top=bbox)
    ax.set_axis_off()
    extent = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
    plt.savefig(os.path.join(results_folder,fig_prefix +f'_{g_ind}f{t_steps-1}_md_{md_ind}'+formt), bbox_inches=extent.expanded(0.8, 0.8))
plt.show()

In [None]:
# RMSD in Angstrom between each geodesic sample `pt` and the matching MD frame
# `mdt` (root of the mean over residues of the squared point distances).
rmsd_T = torch.sqrt(torch.sum((pt - mdt) ** 2,[1,2]) / protein_len)

# plot
fig_size = 21 # cm
plt.figure(figsize=(fig_size/4, fig_size/4))
plt.plot(T, rmsd_T, 'tab:red')
plt.xlim([0,1])
plt.ylim([0,rmsd_T.max()+1])
plt.xlabel(r'$t$')
plt.ylabel(r'RMSD from MD simulation ($\AA$)')
# save figure
plt.savefig(os.path.join(results_folder, fig_prefix + '_RMSD_progression_md' + formt))
# Call show(): the bare `plt.show` expression was a no-op.
plt.show()


In [None]:
# Per-residue deviation (Angstrom) of every geodesic sample from its MD frame,
# shape (t_steps, protein_len), and the per-residue displacement between the
# first and last aligned frames, shape (protein_len,).
d_T = torch.sqrt(torch.sum((pt - mdt) ** 2, -1))
d_total = torch.sqrt(torch.sum((proteins[0] - proteins[-1]) ** 2, -1))

# Shared axis limits (rounded up to the next integer) for the figures below.
max_error = int(torch.max(d_T)) + 1
max_displacement = int(torch.max(d_total)) + 1

num_figs = 3
fig_size = 21 # cm
# Histograms of the deviation at the interior times t = 1/4, 1/2, 3/4.
for i in range(num_figs):
    g_ind = int((i+1)/(num_figs+1) * (t_steps-1))
    plt.figure(figsize=(fig_size/4, fig_size/4))
    # Pass a flat 1-D array: plt.hist treats each column of a 2-D array as a
    # separate data set, so the former `d_T[g_ind][None]` (shape
    # (1, protein_len)) could be split into protein_len one-value data sets.
    plt.hist(d_T[g_ind].detach().numpy(), bins=50)
    plt.xlim([0, max_error])
    plt.ylabel('Frequency')
    plt.xlabel(r'Deviation from MD simulation ($\AA$)')
plt.show()

# Deviation vs. total displacement per residue, coloured by position in the chain.
for i in range(num_figs):
    g_ind = int((i+1)/(num_figs+1) * (t_steps-1))
    md_ind = int((i+1)/(num_figs+1) * (num_proteins-1))
    plt.figure(figsize=(fig_size/4, fig_size/4))
    plt.scatter(d_total,d_T[g_ind], c = torch.linspace(0,1,protein_len), cmap = 'rainbow')
    plt.ylim([0, max_error])
    plt.xlim([0, max_displacement])
    plt.xlabel(r'Total displacement ($\AA$)')
    plt.ylabel(r'Deviation from MD simulation ($\AA$)')
    # save figure
    plt.savefig(os.path.join(results_folder,fig_prefix +f'_displacement_vs_error_{g_ind}f{t_steps-1}_md_{md_ind}'+formt))
plt.show()

In [None]:
# Distances between consecutive C-alpha atoms along each geodesic sample: the
# first off-diagonal of the pairwise-distance matrices, used below as
# (t_steps, protein_len-1). NOTE(review): the sqrt applied when plotting
# suggests pairwise_distances returns squared distances -- confirm.
pt_pairwise_distances = manifold.pairwise_distances(pt[None]).squeeze()
pt_adj_pairwise_distances = torch.diagonal(pt_pairwise_distances,offset=1, dim1=1, dim2=2)

num_figs = 3
fig_size = 21 # cm

# For three interior times, compare the adjacent-residue distances with those
# of the two end points (t=0 and t=1).
for i in range(num_figs):
    plt.figure(figsize=(fig_size, fig_size/4))
    g_ind = int((i+1)/(num_figs+1) * (t_steps-1))
    plt.plot(range(protein_len-1), torch.sqrt(pt_adj_pairwise_distances[0]), label=r'$t=0$')
    # `.item()` so the legend reads "t=0.25" instead of "t=tensor(0.2500)".
    plt.plot(range(protein_len-1), torch.sqrt(pt_adj_pairwise_distances[g_ind]), label=r'$t=$' + f'{T[g_ind].item():g}')
    plt.plot(range(protein_len-1), torch.sqrt(pt_adj_pairwise_distances[-1]), label=r'$t=1$')
    plt.xlim([0, protein_len-2])
    plt.ylim([2.8,4.8])
    plt.ylabel(r'$\|\|\mathbf{x}_i - \mathbf{x}_{i+1}\|\|_2$ ($\AA$)')
    plt.xlabel(r'$i$')
    plt.legend()
    # save figure
    plt.savefig(os.path.join(results_folder,fig_prefix +f'_adjacent_residue_{g_ind}f{t_steps-1}'+formt))
plt.show()

## Separation-barycentre of the trajectory ##

In [None]:
p_barycentre = manifold.s_mean(proteins[None], x0=proteins[int(num_proteins/2)][None,None], debug=True).squeeze()
%timeit manifold.s_mean(proteins[None], x0=proteins[int(num_proteins/2)][None,None]).squeeze()

In [None]:
# Distances between consecutive C-alpha atoms of the barycentre: the first
# off-diagonal of its pairwise-distance matrix. NOTE(review): squared,
# judging by the sqrt applied when plotting -- confirm in PointCloud.
p_barycentre_pairwise_distances = manifold.pairwise_distances(
    p_barycentre.unsqueeze(0).unsqueeze(0)
).squeeze()
p_barycentre_adj_pairwise_distances = p_barycentre_pairwise_distances.diagonal(offset=1)

In [None]:
# Visualise the separation-barycentre: a 3D C-alpha trace, then the distances
# between consecutive C-alpha atoms.
fig_size = 31 * (protein_len/214) ** (1/3) # cm

fig = plt.figure(figsize=(fig_size/5, fig_size/5))
ax = plt.axes(projection='3d')
ax.plot3D(p_barycentre[:,0], p_barycentre[:,1], p_barycentre[:,2], 'black')
ax.scatter(p_barycentre[:,0], p_barycentre[:,1], p_barycentre[:,2], c = torch.linspace(0,1,protein_len), cmap = 'rainbow')
ax.axes.set_xlim3d(left=-bbox, right=bbox)
ax.axes.set_ylim3d(bottom=-bbox, top=bbox)
ax.axes.set_zlim3d(bottom=-bbox, top=bbox)
ax.set_axis_off()
# crop the saved image to the central 80% of the 3D axes' window extent
extent = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
# Use the per-system prefix: the name was hard-coded to "4ake_barycentre",
# mislabelling the figure whenever another system was selected.
plt.savefig(os.path.join(results_folder, fig_prefix + '_barycentre' + formt), bbox_inches=extent.expanded(0.8, 0.8))
plt.show()

fig_size = 21 # cm
plt.figure(figsize=(fig_size, fig_size/4))
plt.plot(range(protein_len-1), torch.sqrt(p_barycentre_adj_pairwise_distances), color='tab:orange')
plt.xlim([0, protein_len-2])
plt.ylim([2.8,4.8])
plt.ylabel(r'$\|\|\mathbf{x}_i - \mathbf{x}_{i+1}\|\|_2$ ($\AA$)')
plt.xlabel(r'$i$')
# save figure
plt.savefig(os.path.join(results_folder, fig_prefix + '_adjacent_residue_barycentre' + formt))
plt.show()

## Separation-logarithmic map for low-rank approximation ##

In [None]:
# compute logs to all points from pt
log_p_barycentre = manifold.s_log(p_barycentre[None,None], proteins[None]).squeeze()
%timeit manifold.s_log(p_barycentre[None,None], proteins[None]).squeeze()

In [None]:
# Gram matrix of the tangent vectors under the plain Euclidean inner product
# (summed over residues and coordinates), and its eigendecomposition.
Gramm_mat = torch.einsum("Nia,Mia->NM", log_p_barycentre, log_p_barycentre)
L, U = torch.linalg.eigh(Gramm_mat)
# Principal tangent directions: R_p_barycentre[M] combines the tangent vectors
# with the weights in column M of U. eigh returns eigenvalues in ascending
# order, so the last rows belong to the largest eigenvalues.
R_p_barycentre = torch.einsum("NM,Nia->Mia", U, log_p_barycentre)

# print eigenvalues
print(L)
# Smallest rank whose leading eigenvalues explain more than `th` of the total.
th = 0.85
explained = L.flip(0).cumsum(0) / L.sum()
ranks = torch.linspace(1, len(L), len(L))
rank = int(ranks[explained > th].min())
print(rank)

# Cumulative explained fraction against the number of retained directions.
plt.figure()
plt.plot(ranks, explained)
plt.xlim([1, len(L)])
plt.ylim([0, 1])
plt.show()

In [None]:
# Show the leading principal tangent directions at the barycentre: figure i
# draws R_p_barycentre[-(i+1)] (i-th largest eigenvalue) as arrows scaled by
# `s`. `rank + 2` figures are made, i.e. two directions beyond the chosen rank.
s = 0.2
fig_size = 31 * (protein_len/214) ** (1/3) # cm
num_figs = rank +2
for i in range(num_figs):
    # Bind the new figure: the crop box below needs *this* figure's transform
    # (previously `fig` silently referred to a figure from an earlier cell).
    fig = plt.figure(figsize=(fig_size/5, fig_size/5))
    ax = plt.axes(projection='3d')
    ax.plot3D(p_barycentre[:,0], p_barycentre[:,1], p_barycentre[:,2], 'black')
    ax.scatter(p_barycentre[:,0], p_barycentre[:,1], p_barycentre[:,2], c = torch.linspace(0,1,protein_len), cmap = 'rainbow')
    ax.quiver(p_barycentre[:, 0], p_barycentre[:, 1], p_barycentre[:, 2], s * R_p_barycentre[-(i+1),:, 0], s * R_p_barycentre[-(i+1),:, 1], s * R_p_barycentre[-(i+1),:, 2],color='tab:orange')
    ax.axes.set_xlim3d(left=-bbox, right=bbox)
    ax.axes.set_ylim3d(bottom=-bbox, top=bbox)
    ax.axes.set_zlim3d(bottom=-bbox, top=bbox)
    ax.set_axis_off()
    # crop the saved image to the central 80% of the 3D axes' window extent
    extent = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
    plt.savefig(os.path.join(results_folder,fig_prefix +f'_barycentre_tvector_{i}'+formt), bbox_inches=extent.expanded(0.8, 0.8))
plt.show()

In [None]:
# Rank-`rank` approximation of every tangent vector: keep the `rank` principal
# directions with the largest eigenvalues (the last rows of R_p_barycentre and
# last columns of U, since eigh sorts eigenvalues ascending) and recombine them
# with the matching weights in U:
#   log_r[N] = sum over the kept M of U[N, M] * R_p_barycentre[M]
# NOTE(review): the slices assume the Gram matrix is num_proteins x num_proteins.
log_p_barycentre_r = torch.einsum("Mia,NM->Nia", R_p_barycentre[num_proteins-rank:num_proteins],U[:,num_proteins-rank:num_proteins])

# For `t_steps` equispaced MD frames collect the exact tangent vector, its
# low-rank approximation, and the structure obtained by mapping the
# approximation back with the exponential map at the barycentre.
log_p_barycentre_t = torch.zeros(t_steps, protein_len, 3)
log_p_barycentre_t_r = torch.zeros(t_steps, protein_len, 3)
mdt_r = torch.zeros(t_steps, protein_len, 3)
for i in range(t_steps):
    md_ind = int(i/(t_steps-1) * (num_proteins-1))
    print(f"computing approximation of protein {md_ind}")
    log_p_barycentre_t[i] = log_p_barycentre[md_ind]
    log_p_barycentre_t_r[i] = log_p_barycentre_r[md_ind]
    # NOTE(review): c=1/4 and step_size=1. are s_exp settings whose meaning is
    # defined in PointCloud -- confirm there.
    mdt_r[i] = manifold.s_exp(p_barycentre[None,None], log_p_barycentre_r[md_ind][None,None], c=1/4, step_size=1., debug=True).squeeze()
# Time one exponential-map evaluation (for the last frame of the loop).
%timeit manifold.s_exp(p_barycentre[None,None], log_p_barycentre_r[md_ind][None,None], c=1/4, step_size=1.).squeeze()

In [None]:
# Low-rank reconstructions `mdt_r` at `num_figs` equally spaced times, drawn
# like the geodesic/MD snapshots, followed by their RMSD from the MD frames.
fig_size = 31 * (protein_len/214) ** (1/3) # cm
num_figs=5
for i in range(num_figs):
    g_ind = int(i/(num_figs-1) * (t_steps-1))
    md_ind = int(i/(num_figs-1) * (num_proteins-1))
    fig = plt.figure(figsize=(fig_size/5, fig_size/5))
    ax = plt.axes(projection='3d')
    ax.plot3D(mdt_r[g_ind,:,0],mdt_r[g_ind,:,1],mdt_r[g_ind,:,2], 'black')
    ax.scatter(mdt_r[g_ind,:,0],mdt_r[g_ind,:,1],mdt_r[g_ind,:,2], c = torch.linspace(0,1,protein_len), cmap = 'rainbow')
    ax.axes.set_xlim3d(left=-bbox, right=bbox)
    ax.axes.set_ylim3d(bottom=-bbox, top=bbox)
    ax.axes.set_zlim3d(bottom=-bbox, top=bbox)
    ax.set_axis_off()
    # crop the saved image to the central 80% of the 3D axes' window extent
    extent = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
    plt.savefig(os.path.join(results_folder,fig_prefix +f'_rank_{rank}_{g_ind}f{t_steps-1}_md_{md_ind}'+formt), bbox_inches=extent.expanded(0.8, 0.8))
plt.show()

# RMSD in Angstrom of the low-rank reconstructions vs. the MD trajectory
rmsd_T_r = torch.sqrt(torch.sum((mdt_r - mdt) ** 2,[1,2]) / protein_len)

fig_size = 21 # cm
plt.figure(figsize=(fig_size/4, fig_size/4))
plt.plot(T, rmsd_T_r, 'tab:red')
plt.xlim([0,1])
# The y-range was taken from the geodesic RMSD only (presumably to share the
# scale with that plot); also include this curve so it is never clipped.
plt.ylim([0, float(torch.maximum(rmsd_T.max(), rmsd_T_r.max())) + 1])
plt.xlabel(r'$t$')
plt.ylabel(r'RMSD from MD simulation ($\AA$)')
# save figure
plt.savefig(os.path.join(results_folder, fig_prefix + f'_RMSD_progression_rank_{rank}' + formt))
# Call show(): the bare `plt.show` expression was a no-op.
plt.show()

In [None]:
# Per-residue deviation (Angstrom) of the low-rank reconstructions from the MD
# frames, shape (t_steps, protein_len).
d_T_r = torch.sqrt(torch.sum((mdt_r - mdt) ** 2, -1))

# Keep the deviation axis of the geodesic figures so both sets share a scale,
# but enlarge it when the low-rank error is larger so no data is clipped.
# (Uses `max_error` from the geodesic-deviation cell.)
max_error = max(max_error, int(torch.max(d_T_r)) + 1)
max_displacement = int(torch.max(d_total)) + 1

num_figs = 5
fig_size = 21 # cm
for i in range(num_figs):
    g_ind = int(i/(num_figs-1) * (t_steps-1))
    plt.figure(figsize=(fig_size/4, fig_size/4))
    # Flat 1-D input: plt.hist treats each column of a 2-D array as its own
    # data set, which `d_T_r[g_ind][None]` (shape (1, protein_len)) could trigger.
    plt.hist(d_T_r[g_ind].detach().numpy(), bins=50)
    plt.xlim([0, max_error])
    plt.ylabel('Frequency')
    plt.xlabel(r'Deviation from MD simulation ($\AA$)')
plt.show()

# Deviation vs. total displacement per residue, coloured by position in the chain.
for i in range(num_figs):
    g_ind = int(i/(num_figs-1) * (t_steps-1))
    md_ind = int(i/(num_figs-1) * (num_proteins-1))
    plt.figure(figsize=(fig_size/4, fig_size/4))
    plt.scatter(d_total,d_T_r[g_ind], c = torch.linspace(0,1,protein_len), cmap = 'rainbow')
    plt.ylim([0, max_error])
    plt.xlim([0, max_displacement])
    plt.xlabel(r'Total displacement ($\AA$)')
    plt.ylabel(r'Deviation from MD simulation ($\AA$)')
    # save figure
    plt.savefig(os.path.join(results_folder,fig_prefix +f'_displacement_vs_error_rank_{rank}_md_{md_ind}'+formt))
plt.show()

In [None]:
# Distances between consecutive C-alpha atoms along each low-rank
# reconstruction `mdt_r`: the first off-diagonal of the pairwise-distance
# matrices. NOTE(review): the sqrt applied when plotting suggests
# pairwise_distances returns squared distances -- confirm in PointCloud.
mdt_r_pairwise_distances = manifold.pairwise_distances(mdt_r[None]).squeeze()
mdt_r_adj_pairwise_distances = torch.diagonal(mdt_r_pairwise_distances,offset=1, dim1=1, dim2=2)

num_figs = 5
fig_size = 21 # cm

# One figure per shown time sample; the fixed y-range [2.8, 4.8] matches the
# other adjacent-residue plots in this notebook.
for i in range(num_figs):
    plt.figure(figsize=(fig_size, fig_size/4))
    g_ind = int(i/(num_figs-1) * (t_steps-1))
    plt.plot(range(protein_len-1), torch.sqrt(mdt_r_adj_pairwise_distances[g_ind]), color='tab:orange')
    plt.xlim([0, protein_len-2])
    plt.ylim([2.8,4.8])
    plt.ylabel(r'$\|\|\mathbf{x}_i - \mathbf{x}_{i+1}\|\|_2$ ($\AA$)')
    plt.xlabel(r'$i$')
    # save figure
    plt.savefig(os.path.join(results_folder,fig_prefix +f'_adjacent_residue_{g_ind}f{t_steps-1}_mdt_r_rank_{rank}'+formt))
plt.show()

## Stability with respect to curvature ##

In [None]:
# Geodesic variation: move one end point at a time by `v_index` frames and
# recompute the geodesic.
#   pt_v0: from the later start frame p0_v to the original end point p1
#   pt_v1: from the original start point p0 to the earlier end frame p1_v
v_index = 3
pt_v0 = torch.zeros(t_steps, protein_len, 3)
pt_v1 = torch.zeros(t_steps, protein_len, 3)
p0_v = proteins[v_index]
p1_v = proteins[-1 - v_index]
# Loop variable is `tau`, not `t`, so the mdtraj trajectory `t` is not clobbered.
for i, tau in enumerate(T):
    print(f"computing geodesic {i+1} of variation {v_index}")
    pt_v0[i] = manifold.s_geodesic(p0_v[None,None], p1[None,None], torch.tensor([tau]), debug=True).squeeze()
    pt_v1[i] = manifold.s_geodesic(p0[None,None], p1_v[None,None], torch.tensor([tau]), debug=True).squeeze()

In [None]:
# RMSD in Angstrom between the original geodesic `pt` and each varied geodesic.
rmsd_T_v0 = torch.sqrt(torch.sum((pt - pt_v0) ** 2,[1,2]) / protein_len)
rmsd_T_v1 = torch.sqrt(torch.sum((pt - pt_v1) ** 2,[1,2]) / protein_len)

# Varied start point: pt and pt_v0 share the end point, so with straight-line
# (zero-curvature) geodesics the RMSD would be exactly (1 - t) * rmsd_T_v0[0].
# Plot the relative deviation from that baseline at the interior times.
fig_size = 21 # cm
plt.figure(figsize=(fig_size/3, fig_size/3))
plt.plot(T[1:-1], (rmsd_T_v0 - (1 - T)*rmsd_T_v0[0])[1:-1]/rmsd_T_v0[0])
plt.xlim([0,1])
plt.ylim([-0.12,0.12])
plt.xlabel(r'$t$')
plt.ylabel(r'Relative discrepancy from zero-curvature baseline')
# save figure
plt.savefig(os.path.join(results_folder, fig_prefix + '_RMSD_progression_v0' + formt))
# Call show(): the bare `plt.show` expression was a no-op.
plt.show()

# Varied end point: pt and pt_v1 share the start point, so the zero-curvature
# baseline is t * rmsd_T_v1[-1].
fig_size = 21 # cm
plt.figure(figsize=(fig_size/3, fig_size/3))
plt.plot(T[1:-1], (rmsd_T_v1 - T*rmsd_T_v1[-1])[1:-1]/rmsd_T_v1[-1])
plt.xlim([0,1])
plt.ylim([-0.12,0.12])
plt.xlabel(r'$t$')
plt.ylabel(r'Relative discrepancy from zero-curvature baseline')
# save figure
plt.savefig(os.path.join(results_folder, fig_prefix + '_RMSD_progression_v1' + formt))
plt.show()


In [None]:
# Compare the error of the low-rank approximation after the exponential map
# (rmsd_T_r, measured between structures) with the error of the truncation in
# the tangent space (trmsd_T_r), normalised by the tangent-vector RMS norm
# (trmsd_T_r_0). With exp_p(v) = p + v (zero curvature) the two would coincide.
trmsd_T_r = torch.sqrt(torch.sum((log_p_barycentre_t_r - log_p_barycentre_t) ** 2,[1,2]) / protein_len)
trmsd_T_r_0 = torch.sqrt(torch.sum(log_p_barycentre_t ** 2,[1,2]) / protein_len)

fig_size = 21 # cm
plt.figure(figsize=(fig_size/3, fig_size/3))
plt.plot(T, (rmsd_T_r - trmsd_T_r)/trmsd_T_r_0)
plt.xlim([0,1])
plt.ylim([-0.12,0.12])
plt.xlabel(r'$t$')
plt.ylabel(r'Relative discrepancy from zero-curvature baseline')
# save figure
plt.savefig(os.path.join(results_folder,fig_prefix +f'_RMSD_and_TRMSD_progression_rank_{rank}'+formt))
# Call show(): the bare `plt.show` expression was a no-op.
plt.show()