# plotting functions for supplementary figure 6
This figure shows the results of two 3DFlex models trained on the SARS-CoV-2 spike monomer steered-MD dataset. The panels largely mirror those of Figure 5 in the manuscript.


In [None]:
# imports
import os
import gemmi
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# roodmus
from roodmus.analysis.utils import load_data
from roodmus.analysis.utils import IO
from roodmus.heterogeneity.plot_heterogeneous_reconstruction import plot_latent_space_scatter

# 3DFlex model with zdim = 1
This model's latent space must be visualised differently than in the usual case where zdim >= 2, because a single latent coordinate cannot be shown as a 2D scatter plot.


In [None]:
# data loading
# Paths and settings for the 3DFlex zdim=1 job (J526).
project_dir = "/home/mjoosten1/projects/roodmus/data/6xm5_steered_Roodmus_2"
config_dir = os.path.join(project_dir, "mrc")
figures_dir = os.path.join(project_dir, "figures")
meta_file = os.path.join(project_dir, "cryoSPARC", "J526_passthrough_particles.cs")
# Label for each metadata file; keyed on meta_file itself so the key can never
# drift out of sync with the file that is actually loaded below.
jobtypes = {meta_file: "3DFlex zdim=1"}
latent_file = os.path.join(project_dir, "cryoSPARC", "J526_latents_019446.cs")
dt = 0.01  # ps interval between frames

particle_diameter = 100  # approximate particle diameter in Angstroms
ugraph_shape = (4000, 4000)  # shape of the micrograph in pixels. Only needs to be given if the metadata file is a .star file
verbose = True
ignore_missing_files = True
enable_tqdm = True

analysis = load_data(
    meta_file,
    config_dir,
    particle_diameter,
    ugraph_shape=ugraph_shape,
    verbose=verbose,
    enable_tqdm=enable_tqdm,
    ignore_missing_files=ignore_missing_files,
)  # creates the class
df_picked = pd.DataFrame(analysis.results_picking)
df_truth = pd.DataFrame(analysis.results_truth)
df_precision, df_picked = analysis.compute_precision(df_picked, df_truth, verbose=verbose)
print(f"mean precision: {df_precision['precision'].mean()}")
print(f"mean recall: {df_precision['recall'].mean()}")

latent_space, ndim = IO.get_latents_cs(latent_file)
print(f"latent space dimensionality: {ndim}")
print(latent_space.shape)
# Latent vectors are attached to particles purely by row position, so both
# tables must have the same number of rows. (They presumably also share the
# same particle order; that is not checked here.)
if len(latent_space) != len(df_picked):
    raise ValueError(
        f"{latent_file} contains {len(latent_space)} latent vectors but "
        f"{len(df_picked)} picked particles were loaded from {meta_file}"
    )
for i in range(ndim):
    df_picked[f"latent_{i}"] = latent_space[:, i]
df_picked.tail()

## panel A
plotting the 1-dimensional latent space as a histogram

In [None]:
# Set to True to actually write the figure to figures_dir.
save_figure = False

fig, ax = plt.subplots(figsize=(3.5, 3.5))
# Histogram of the single latent coordinate, with a KDE curve overlaid.
sns.histplot(
    data=df_picked,
    x="latent_0",
    color="black",
    kde=True,
    bins=20,
    ax=ax,
)
ax.set_xlabel("Z0", fontsize=16)
ax.set_ylabel("Count", fontsize=16)
ax.tick_params(axis="both", which="major", labelsize=14)

outfilename = os.path.join(figures_dir, f"histogram_z0_{os.path.basename(latent_file)}.pdf")
if save_figure:
    fig.savefig(outfilename, bbox_inches="tight")
    print(f"saved figure to: {outfilename}")
else:
    # Only claim a save when one actually happened.
    print(f"figure not saved (save_figure=False); output path would be: {outfilename}")


## panel B
2D histogram of the 1-dimensional latent coordinate versus the MD-trajectory time of each particle's closest ground-truth frame

In [None]:
# Set to True to actually write the figure to figures_dir.
save_figure = False

# MD time (ps) of the ground-truth frame closest to each picked particle;
# vectorised multiply instead of a per-row lambda.
df_picked["time"] = df_picked["closest_pdb_index"] * dt
print(df_picked["time"].max())
print(df_picked["latent_0"].max())

fig, ax = plt.subplots(figsize=(3.5, 3.5))
# Bivariate histogram (not a KDE) of the latent coordinate versus MD time.
sns.histplot(
    data=df_picked,
    x="time",
    y="latent_0",
    cbar=True,
    cbar_kws={"label": ""},
    bins=50,
    cmap="Oranges",
    ax=ax,
)
ax.set_xlabel("Time (ps)", fontsize=16)
ax.set_ylabel("Z0", fontsize=16)
ax.tick_params(axis="both", which="major", labelsize=14)
# change the tick labels of the colourbar to fontsize 14
cbar = ax.collections[0].colorbar
cbar.ax.tick_params(labelsize=14)

outfilename = os.path.join(figures_dir, f"z0_vs_time_{os.path.basename(latent_file)}.pdf")
if save_figure:
    fig.savefig(outfilename, bbox_inches="tight")
    print(f"saved figure to: {outfilename}")
else:
    # Only claim a save when one actually happened.
    print(f"figure not saved (save_figure=False); output path would be: {outfilename}")


# 3DFlex model with zdim = 2

In [None]:
# data loading
# Paths and settings for the 3DFlex zdim=2 job (J577).
project_dir = "/home/mjoosten1/projects/roodmus/data/6xm5_steered_Roodmus_2"
config_dir = os.path.join(project_dir, "mrc")
figures_dir = os.path.join(project_dir, "figures")
meta_file = os.path.join(project_dir, "cryoSPARC", "J577_passthrough_particles.cs")
# Label for each metadata file; keyed on meta_file itself so the key can never
# drift out of sync with the file that is actually loaded below.
jobtypes = {meta_file: "3DFlex zdim=2"}
latent_file = os.path.join(project_dir, "cryoSPARC", "J577_latents_022224.cs")
dt = 0.01  # ps interval between frames

particle_diameter = 100  # approximate particle diameter in Angstroms
ugraph_shape = (4000, 4000)  # shape of the micrograph in pixels. Only needs to be given if the metadata file is a .star file
verbose = True
ignore_missing_files = True
enable_tqdm = True

analysis = load_data(
    meta_file,
    config_dir,
    particle_diameter,
    ugraph_shape=ugraph_shape,
    verbose=verbose,
    enable_tqdm=enable_tqdm,
    ignore_missing_files=ignore_missing_files,
)  # creates the class
df_picked = pd.DataFrame(analysis.results_picking)
df_truth = pd.DataFrame(analysis.results_truth)
df_precision, df_picked = analysis.compute_precision(df_picked, df_truth, verbose=verbose)
# Alternative (disabled): explicit 1-to-1 particle matching instead of
# compute_precision, kept for reference.
# p_match, _, p_unmatched, t_unmatched, closest_truth_index = analysis._match_particles(
#     meta_file,
#     df_picked,
#     df_truth,
#     verbose=False,
#     enable_tqdm=True,
# )
# df_precision = analysis.compute_1to1_match_precision(
#     p_match,
#     p_unmatched,
#     t_unmatched,
#     df_truth,
#     verbose=False,
# )
# df_picked["TP"] = ~np.isnan(closest_truth_index)
print(f"mean precision: {df_precision['precision'].mean()}")
print(f"mean recall: {df_precision['recall'].mean()}")

latent_space, ndim = IO.get_latents_cs(latent_file)
print(f"latent space dimensionality: {ndim}")
print(latent_space.shape)
# Latent vectors are attached to particles purely by row position, so both
# tables must have the same number of rows. (They presumably also share the
# same particle order; that is not checked here.)
if len(latent_space) != len(df_picked):
    raise ValueError(
        f"{latent_file} contains {len(latent_space)} latent vectors but "
        f"{len(df_picked)} picked particles were loaded from {meta_file}"
    )
for i in range(ndim):
    df_picked[f"latent_{i}"] = latent_space[:, i]
df_picked.tail()

## panel C
scatter plot of the 2-dimensional latent space

In [None]:
# latent space scatter plot
dim1 = 0
dim2 = 1
# Set to True to actually write the figure to figures_dir.
save_figure = False
# Set to True to overlay the false-positive picks (TP == 0) in red.
overlay_false_positives = False

grid = plot_latent_space_scatter(
    df_picked,
    dim_1=dim1,
    dim_2=dim2,
    pca=False,
)

# Picks whose "TP" value equals 0 are counted as false positives; every other
# row (including any non-0 value) counts towards the TP total.
df_FP = df_picked[df_picked["TP"] == 0]
n_fp = len(df_FP)
print(f"number of FP: {n_fp}")
print(f"number of TP: {len(df_picked) - n_fp}")

# The first Axes of the returned grid's figure is treated as the main scatter panel.
ax = grid.figure.get_axes()[0]
if overlay_false_positives:
    sns.scatterplot(
        data=df_FP,
        x=f"latent_{dim1}",
        y=f"latent_{dim2}",
        color="red",
        s=10,
        ax=ax,
        label="FP",
    )
grid.set_axis_labels(f"Z{dim1}", f"Z{dim2}", fontsize=16)
ax.tick_params(axis="both", which="major", labelsize=14)

outfilename = os.path.join(figures_dir, f"{os.path.basename(latent_file)}_{dim1}_{dim2}.png")
if save_figure:
    grid.figure.savefig(outfilename, dpi=600, bbox_inches="tight")
    print(f"saved figure to: {outfilename}")
else:
    # Only claim a save when one actually happened.
    print(f"figure not saved (save_figure=False); output path would be: {outfilename}")

## panel D
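Scatter plot of the 2-dimensional latent space, coloured by the MD-trajectory time of each particle's closest ground-truth frame. The black line overlays the latent-space path loaded from `J577_custom_path_2.csv`.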

In [None]:
# latent space scatter plot, coloured by ground truth frames
dim1 = 0
dim2 = 1
# Set to True to actually write the figure to figures_dir.
save_figure = False

fig, ax = plot_latent_space_scatter(
    df_picked,
    dim_1=dim1,
    dim_2=dim2,
    color_by="closest_pdb_index",
    palette="RdYlBu",
    pca=False,
)
# Replace the hue legend with a continuous colourbar for closest_pdb_index.
# Guard: legend_ is None when the Axes has no legend.
if ax.legend_ is not None:
    ax.legend_.remove()
frame_index = df_picked["closest_pdb_index"]
S_m = plt.cm.ScalarMappable(cmap="RdYlBu")
S_m.set_array(frame_index)
# S_m is not attached to any Axes, so pass ax explicitly: falling back to the
# "current" Axes for such a mappable is deprecated in Matplotlib and may not be
# the Axes that holds the scatter plot.
cbar = ax.figure.colorbar(S_m, ax=ax)
cbar.set_label("Time (ps)", rotation=270, labelpad=15, fontsize=16)  # time in ps
# Colourbar ticks sit at frame indices; relabel them as MD time (index * dt, in ps).
tick_indices = np.linspace(1, frame_index.max(), 10)
cbar.set_ticks(tick_indices)
tick_labels = [np.round(t, 1) for t in tick_indices * dt]
cbar.set_ticklabels(tick_labels, fontsize=14)
ax.set_xlabel(f"Z{dim1}", fontsize=16)
ax.set_ylabel(f"Z{dim2}", fontsize=16)
ax.tick_params(labelsize=14)

# Overlay a precomputed latent-space path (first two CSV columns, header row skipped).
trajectory_file = os.path.join(project_dir, "cryoSPARC", "J577_custom_path_2.csv")
trajectory = np.loadtxt(trajectory_file, delimiter=",", skiprows=1)

ax.scatter(trajectory[:, 0], trajectory[:, 1], s=5, c="black", zorder=10)
ax.plot(trajectory[:, 0], trajectory[:, 1], c="black", zorder=10, linewidth=0.5)
ax.set_aspect("equal")

# Alternative (disabled): build the path from the mean latent position of
# consecutive groups of ground-truth frames instead of loading it from a file.
# N_volumes = 50
# pdb_indices = np.unique(df_picked["closest_pdb_index"])
# d_pdbs = len(pdb_indices) // N_volumes

# trajectory = np.zeros((N_volumes, ndim))
# for i in range(N_volumes):
#     pdb_group = pdb_indices[i*d_pdbs:(i+1)*d_pdbs]
#     mean_latent = df_picked[df_picked["closest_pdb_index"].isin(pdb_group)].agg(
#         {f"latent_{i}": "mean" for i in range(ndim)}
#     )
#     trajectory[i] = mean_latent.values

# ax.scatter(trajectory[:, 0], trajectory[:, 1], s=5, c="black", zorder=10)
# ax.plot(trajectory[:, 0], trajectory[:, 1], c="black", zorder=10, linewidth=0.5)
# ax.set_aspect("equal")

# NOTE(review): the filename ends in "_pca" although pca=False above — confirm intended name.
outfilename = os.path.join(figures_dir, f"{os.path.basename(latent_file)}_latent_space_scatter_colored_by_closest_pdb_index_pca.png")
if save_figure:
    fig.savefig(outfilename, bbox_inches="tight", dpi=600)
    print(f"saved figure to: {outfilename}")
else:
    # Only claim a save when one actually happened.
    print(f"figure not saved (save_figure=False); output path would be: {outfilename}")


## panel E
Correlation matrix between the MD trajectory (rows, labelled by time in ps) and the volumes sampled along the 1-dimensional latent space (columns)


In [None]:
# plot the correlation matrix (3DFlex zdim=1)
project_dir = "/home/mjoosten1/projects/roodmus/data/6xm5_steered_Roodmus_2"
figures_dir = os.path.join(project_dir, "figures")
correlation_matrix_file = os.path.join(project_dir, "cryosparc_P51_J527_series_000", "correlation_matrix.npy")
correlation_matrix = np.load(correlation_matrix_file)
# Only used to name the output file.
latent_file = os.path.join(project_dir, "cryoSPARC", "J526_latents_019446.cs")

# Total MD time spanned by the matrix rows, used for the y tick labels.
# NOTE(review): hard-coded; presumably n_frames * dt of the trajectory — confirm.
total_time_ps = 10
n_ticks = 10

frames = correlation_matrix.shape[0]

fig, ax = plt.subplots(figsize=(3.5, 3.5))
image = ax.imshow(correlation_matrix, cmap="coolwarm")
# imshow centres row i on y=i, so the last row sits at frames - 1. A tick at
# y=frames would lie outside the image, and set_yticks would then expand the
# view limits and leave an empty strip under the matrix.
yticks = np.linspace(0, frames - 1, n_ticks)
yticklabels = np.round(np.linspace(0, total_time_ps, n_ticks, dtype=float), 1)
ax.set_yticks(yticks)
ax.set_yticklabels(yticklabels)
ax.set_ylabel("Time (ps)", fontsize=16)
ax.set_xlabel("Sampled volume", fontsize=16)
cbar = fig.colorbar(image, ax=ax, orientation="vertical", pad=0.01, shrink=0.9)
cbar.ax.tick_params(labelsize=14)
cbar.ax.set_ylabel("Correlation", fontsize=16, rotation=270, labelpad=20)
ax.tick_params(axis="both", which="major", labelsize=14)

outfilename = os.path.join(figures_dir, f"{os.path.basename(latent_file)}_correlation_matrix.pdf")
# savefig does not create missing directories.
os.makedirs(figures_dir, exist_ok=True)
fig.savefig(outfilename, bbox_inches="tight")
print(f"saved figure to: {outfilename}")


## panel F
Correlation matrix between the MD trajectory (rows, labelled by time in ps) and the volumes sampled along the 2-dimensional latent space (columns)

In [None]:
# plot the correlation matrix (3DFlex zdim=2)
project_dir = "/home/mjoosten1/projects/roodmus/data/6xm5_steered_Roodmus_2"
figures_dir = os.path.join(project_dir, "figures")
correlation_matrix_file = os.path.join(project_dir, "cryosparc_P51_J641_series_000", "correlation_matrix.npy")
correlation_matrix = np.load(correlation_matrix_file)
# Only used to name the output file.
latent_file = os.path.join(project_dir, "cryoSPARC", "J577_latents_022224.cs")

# Total MD time spanned by the matrix rows, used for the y tick labels.
# NOTE(review): hard-coded; presumably n_frames * dt of the trajectory — confirm.
total_time_ps = 10
n_ticks = 10

frames = correlation_matrix.shape[0]

fig, ax = plt.subplots(figsize=(3.5, 3.5))
image = ax.imshow(correlation_matrix, cmap="coolwarm")
# imshow centres row i on y=i, so the last row sits at frames - 1. A tick at
# y=frames would lie outside the image, and set_yticks would then expand the
# view limits and leave an empty strip under the matrix.
yticks = np.linspace(0, frames - 1, n_ticks)
yticklabels = np.round(np.linspace(0, total_time_ps, n_ticks, dtype=float), 1)
ax.set_yticks(yticks)
ax.set_yticklabels(yticklabels)
ax.set_ylabel("Time (ps)", fontsize=16)
ax.set_xlabel("Sampled volume", fontsize=16)
cbar = fig.colorbar(image, ax=ax, orientation="vertical", pad=0.01, shrink=0.9)
cbar.ax.tick_params(labelsize=14)
cbar.ax.set_ylabel("Correlation", fontsize=16, rotation=270, labelpad=20)
ax.tick_params(axis="both", which="major", labelsize=14)

outfilename = os.path.join(figures_dir, f"{os.path.basename(latent_file)}_correlation_matrix.pdf")
# savefig does not create missing directories.
os.makedirs(figures_dir, exist_ok=True)
fig.savefig(outfilename, bbox_inches="tight")
print(f"saved figure to: {outfilename}")


## in-text
compute and print the total mass of the first conformation


In [None]:
# Total mass of the first conformation, computed as the sum of element weights
# over every atom in the first model of the structure.
pdb_file = "/home/mjoosten1/projects/roodmus/data/6xm5_steered_Roodmus_1/pdb/conformation_000000.pdb"
pdb_gemmi = gemmi.read_structure(pdb_file)
first_model = pdb_gemmi[0]
mass = sum(
    atom.element.weight
    for chain in first_model
    for residue in chain
    for atom in residue
)

# The "entire protein" figure is this structure's mass multiplied by three.
n_copies = 3
print(f"total mass: {mass/1000} kDa")
print(f"mass of entire protein: {mass/1000*n_copies} kDa")
