# plotting functions of figure 5 in the manuscript
This figure shows the results of applying cryoDRGN to the covid spike trimer dataset (DESRES-Trajectory_sarscov2-11021571)



In [None]:
# imports
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# roodmus
from roodmus.analysis.utils import load_data
from roodmus.heterogeneity.hetRec import HetRec
from roodmus.heterogeneity.plot_heterogeneous_reconstruction import plot_latent_space_scatter

In [None]:
# Data loading for the DE-Shaw covid spike (partially open) data set.
# Loads the picked and ground-truth particles, matches them 1-to-1, attaches the
# cryoDRGN latent coordinates (plus their PCA projection) to every picked
# particle and annotates each pick with the MD frame of its matched truth particle.
project_dir = "/home/mjoosten1/projects/roodmus/data/DE-Shaw_covid_spike_protein/20231116_DESRES-Trajectory_sarscov2-11021571-all-glueCA"
config_dir = "/home/mjoosten1/projects/roodmus/data/DE-Shaw_covid_spike_protein/DESRES-Trajectory_sarscov2-11021571-all-glueCA/Micrographs"
figures_dir = os.path.join(project_dir, "figures")
meta_file = os.path.join(project_dir, "cryoDRGN", "run_data.star")
jobtypes = {
    meta_file: "cryoDRGN",
}
latent_file = os.path.join(project_dir, "cryoDRGN", "train_320", "z.19.pkl")

particle_diameter = 100  # approximate particle diameter in Angstroms
ugraph_shape = (4000, 4000)  # shape of the micrograph in pixels. Only needs to be given if the metadata file is a .star file
verbose = True
ignore_missing_files = True
enable_tqdm = True


def _pdb_index(pdb_filename):
    """Return the integer after the last underscore of a ``*_<index>.pdb`` name.

    The extension is removed with ``os.path.splitext``; ``str.strip(".pdb")``
    would instead strip any of the characters '.', 'p', 'd', 'b' from both ends.
    """
    stem = os.path.splitext(os.path.basename(pdb_filename))[0]
    return int(stem.split("_")[-1])


analysis = load_data(
    meta_file,
    config_dir,
    particle_diameter,
    ugraph_shape=ugraph_shape,
    verbose=verbose,
    enable_tqdm=enable_tqdm,
    ignore_missing_files=ignore_missing_files,
)  # creates the class
df_picked = pd.DataFrame(analysis.results_picking)
df_truth = pd.DataFrame(analysis.results_truth)
# df_precision, df_picked = analysis.compute_precision(df_picked, df_truth, verbose=verbose)
# NOTE(review): _match_particles is a private method of the analysis object;
# its return order is taken as-is from this call site.
p_match, _, p_unmatched, t_unmatched, closest_truth_index = analysis._match_particles(
    meta_file,
    df_picked,
    df_truth,
    verbose=False,
    enable_tqdm=True,
)
df_precision = analysis.compute_1to1_match_precision(
    p_match,
    p_unmatched,
    t_unmatched,
    df_truth,
    verbose=False,
)

df_picked, ndim = HetRec.add_latent_space_coordinates(
    latent_file=latent_file,
    df_picked=df_picked,
)
df_picked, pca = HetRec.compute_PCA(
    df_picked=df_picked,
    ndim=ndim,
)

df_truth["pdb_index"] = df_truth["pdb_filename"].apply(_pdb_index)
df_picked["closest_truth_index"] = closest_truth_index
# a pick counts as true positive when it has a matched truth particle (non-NaN index)
df_picked["TP"] = df_picked["closest_truth_index"].notna()
# look up the MD frame index of the matched truth particle; unmatched picks stay NaN
df_picked["closest_pdb_index"] = df_picked["closest_truth_index"].apply(
    lambda idx: df_truth.loc[idx, "pdb_index"] if pd.notna(idx) else np.nan
)
df_picked.tail()

## panel A
plotting latent space of epoch 20 of cryoDRGN with the FP particle picks plotted in red

In [None]:
# Panel A: scatter of the PCA-projected latent space, with the picks flagged
# as false positives (TP == 0) drawn on top in red.
dim1 = 0
dim2 = 1

# color="#4eb3d3" is an alternative that is currently not passed
grid = plot_latent_space_scatter(df_picked, dim_1=dim1, dim_2=dim2, pca=True)

# picks without a matched truth particle
df_FP = df_picked.loc[df_picked["TP"] == 0]
print(f"number of FP: {len(df_FP)}")
print(f"number of TP: {len(df_picked)-len(df_FP)}")

# first axes of the returned grid's figure; the overlay and all axis styling go here
ax = grid.figure.get_axes()[0]
sns.scatterplot(
    data=df_FP,
    x=f"PCA_{dim1}",
    y=f"PCA_{dim2}",
    ax=ax,
    color="red",
    s=10,
    label="FP",
)
grid.set_axis_labels(f"PCA{dim1}", f"PCA{dim2}", fontsize=14)
ax.tick_params(labelsize=12)
ax.set_xlim((-12, 12))
ax.set_ylim((-12, 12))

# only the PNG is written; the PDF export is kept here for reference
# grid.savefig(outfilename+".pdf", bbox_inches="tight")
outfilename = os.path.join(
    figures_dir,
    f"{os.path.basename(latent_file)}_pca_{dim1}_{dim2}_FP",
)
grid.savefig(outfilename+".png", bbox_inches="tight", dpi=600)
print(f"saved figure to: {outfilename}")

## panel B
plot of the latent space with each point coloured by its corresponding frame from the MD trajectory

In [None]:
# Panel B: latent space scatter plot, coloured by the ground-truth MD frame of
# each pick, with the mean latent-space trajectory over time drawn on top.
dim1 = 0
dim2 = 1
dt = 1.2e-3  # time between frames in microseconds

fig, ax = plot_latent_space_scatter(
    df_picked,
    dim_1=dim1,
    dim_2=dim2,
    color_by="closest_pdb_index",
    palette="RdYlBu",
    pca=True,
)
# replace the legend (if one was drawn) by a colorbar for closest_pdb_index
legend = ax.get_legend()
if legend is not None:
    legend.remove()
S_m = plt.cm.ScalarMappable(cmap="RdYlBu")
S_m.set_array(df_picked["closest_pdb_index"])
# The mappable is not attached to any Axes, so the target Axes is passed
# explicitly; Matplotlib cannot infer where to place the colorbar otherwise.
cbar = fig.colorbar(S_m, ax=ax)
cbar.set_label("Time (\u03BCs)", rotation=270, labelpad=15, fontsize=16)  # time in microseconds
# ticks sit at frame indices; their labels show the time frame_index * dt
cbar_ticks = np.linspace(1, df_picked["closest_pdb_index"].max(), 10)
cbar.set_ticks(cbar_ticks)
xticklabels = [np.round(r, 1) for r in cbar_ticks * dt]
cbar.set_ticklabels(xticklabels, fontsize=12)
ax.set_xlabel(f"PCA{dim1}", fontsize=14)
ax.set_ylabel(f"PCA{dim2}", fontsize=14)
ax.tick_params(labelsize=12)
ax.set_xlim((-12, 12))
ax.set_ylim((-12, 12))

# add trajectory to the plot: split the sorted frame indices into N_volumes
# consecutive groups and average the coordinates of the picks in each group
N_volumes = 50
# NaN marks picks without a matched truth particle; drop it so it neither counts
# towards the group size nor pulls unmatched picks into a group through isin()
pdb_indices = np.unique(df_picked["closest_pdb_index"].dropna())
d_pdbs = len(pdb_indices) // N_volumes  # frames per group; leftover frames at the end are not used

latent_columns = [f"latent_{k}" for k in range(ndim)]
pca_columns = [f"PCA_{k}" for k in range(ndim)]
trajectory = np.zeros((N_volumes, ndim))
trajectory_pca = np.zeros((N_volumes, ndim))
for i in range(N_volumes):
    pdb_group = pdb_indices[i * d_pdbs:(i + 1) * d_pdbs]
    group_picks = df_picked[df_picked["closest_pdb_index"].isin(pdb_group)]
    trajectory[i] = group_picks[latent_columns].mean().to_numpy()
    trajectory_pca[i] = group_picks[pca_columns].mean().to_numpy()

ax.scatter(trajectory_pca[:, 0], trajectory_pca[:, 1], s=5, c="black", zorder=10)
ax.plot(trajectory_pca[:, 0], trajectory_pca[:, 1], c="black", zorder=10, linewidth=0.5)
ax.set_aspect("equal")

outfilename = os.path.join(figures_dir, f"{os.path.basename(latent_file)}_latent_space_scatter_colored_by_closest_pdb_index_pca")
fig.savefig(outfilename+".png", dpi=600, bbox_inches="tight")
fig.savefig(outfilename+".pdf", bbox_inches="tight")
print(f"saved figure to: {outfilename}")


In [None]:
# for the figure, repeat previous plot with only the Nth point in the trajectory marked
# (relies on dim1, dim2, dt and trajectory_pca from the previous cell)
N = 31
fig, ax = plot_latent_space_scatter(
    df_picked,
    dim_1=dim1,
    dim_2=dim2,
    color_by="closest_pdb_index",
    palette="RdYlBu",
    pca=True,
)
# replace the legend (if one was drawn) by a colorbar for closest_pdb_index
legend = ax.get_legend()
if legend is not None:
    legend.remove()
S_m = plt.cm.ScalarMappable(cmap="RdYlBu")
S_m.set_array(df_picked["closest_pdb_index"])
# The mappable is not attached to any Axes, so the target Axes is passed
# explicitly; Matplotlib cannot infer where to place the colorbar otherwise.
cbar = fig.colorbar(S_m, ax=ax)
cbar.set_label("time (\u03BCs)", rotation=270, labelpad=15, fontsize=16)  # time in microseconds
# ticks sit at frame indices; their labels show the time frame_index * dt
cbar_ticks = np.linspace(1, df_picked["closest_pdb_index"].max(), 10)
cbar.set_ticks(cbar_ticks)
xticklabels = [np.round(r, 1) for r in cbar_ticks * dt]
cbar.set_ticklabels(xticklabels, fontsize=12)
ax.set_xlabel(f"PCA{dim1}", fontsize=14)
ax.set_ylabel(f"PCA{dim2}", fontsize=14)
ax.tick_params(labelsize=12)
ax.set_xlim((-12, 12))
ax.set_ylim((-12, 12))

# mark only the N-th point of the trajectory
ax.scatter(trajectory_pca[N, 0], trajectory_pca[N, 1], s=5, c="black", zorder=10)
ax.set_aspect("equal")


## panel C
correlation matrix between the MD trajectory and the sampled volumes from the latent space

In [None]:
# Panel C: show the stored correlation matrix as an image. Rows run along the
# y-axis (labelled as time), columns along the x-axis (labelled as sampled volume).
project_dir = "/home/mjoosten1/projects/roodmus/data/DE-Shaw_covid_spike_protein/20231116_DESRES-Trajectory_sarscov2-11021571-all-glueCA"
figures_dir = os.path.join(project_dir, "figures")
correlation_matrix_file = os.path.join(
    project_dir, "cryoDRGN", "analyze_320", "correlation_matrix.npy"
)
correlation_matrix = np.load(correlation_matrix_file)

# number of rows of the matrix
frames = len(correlation_matrix)

fig, ax = plt.subplots(figsize=(3.5, 3.5))
image = ax.imshow(correlation_matrix, cmap="coolwarm")

# ten evenly spaced y ticks, relabelled from 0 to 10
yticks = np.linspace(0, frames, 10)
yticklabels = np.round(np.linspace(0, 10, 10, dtype=float), 1)
ax.set_yticks(yticks)
ax.set_yticklabels(yticklabels, fontsize=12)
ax.set_ylabel("Time (\u03BCs)", fontsize=16)
ax.set_xlabel("Sampled volume", fontsize=16)

cbar = fig.colorbar(image, ax=ax, orientation="vertical", pad=0.01, shrink=0.9)
cbar.ax.tick_params(labelsize=14)
cbar.ax.set_ylabel("Correlation", fontsize=16, rotation=270, labelpad=20)
# applied last, so this label size is the one that ends up on both axes
ax.tick_params(axis="both", which="major", labelsize=14)

# file name is derived from latent_file, which is set in the data loading cell
outfilename = os.path.join(
    figures_dir,
    f"{os.path.basename(latent_file)}_correlation_matrix.pdf",
)
fig.savefig(outfilename, bbox_inches="tight")
print(f"saved figure to: {outfilename}")


In [None]:
# print the best and worst fit for every sampled volume to the MD trajectory
# NOTE(review): this cell takes argmax/argmin over the *rows* of
# correlation_matrix, whereas the panel D cells take the maximum along axis 0
# (one value per column) for the sampled volumes, and the stride used here
# (frames + 1) differs from the one used there — confirm which is intended.
pdb_dir = "/home/mjoosten1/projects/roodmus/data/DE-Shaw_covid_spike_protein/mnt/parakeet_storage4/ConformationSampling/DESRES-Trajectory_sarscov2-11021571-all-glueCA/even_sampling_8334"
pdb_files = sorted(r for r in os.listdir(pdb_dir) if r.endswith(".pdb"))
# keep every stride-th file of the sorted listing
stride = int(len(pdb_files)/(frames+1))
pdb_files = pdb_files[::stride]

# position of the largest / smallest entry in each row of the matrix
best_indices = [correlation_matrix[i].argmax() for i in range(frames)]
worst_indices = [correlation_matrix[i].argmin() for i in range(frames)]

results = {
    "sampled_vol": list(range(frames)),
    "best_fit": [pdb_files[k] for k in best_indices],
    "best_fit_index": best_indices,
    "worst_fit": [pdb_files[k] for k in worst_indices],
    "worst_fit_index": worst_indices,
}

df_results = pd.DataFrame(results)
# show the complete table rather than only the first rows
df_results.head(n=len(df_results))



## panel D
For this panel I want to write the real-space cross-correlation between each volume and the best matching frame from the MD trajectory

In [None]:
# print for each sampled volume (column) the correlation with the best matching frame of the trajectory (row), and also the conformation index

pdb_dir = "/home/mjoosten1/projects/roodmus/data/DE-Shaw_covid_spike_protein/mnt/parakeet_storage4/ConformationSampling/DESRES-Trajectory_sarscov2-11021571-all-glueCA/even_sampling_8334"
pdb_files = sorted(
    os.path.join(pdb_dir, f) for f in os.listdir(pdb_dir) if f.endswith(".pdb")
)
N = 50
# keep every (len // N)-th file; max(1, ...) avoids a zero slice step when
# the directory holds fewer than N pdb files
pdb_files = pdb_files[::max(1, len(pdb_files) // N)]

# per column: row index of the highest correlation, and that correlation value
best_matching_frame = np.argmax(correlation_matrix, axis=0)
best_matching_frame_correlation = np.max(correlation_matrix, axis=0)
best_matching_conformation = [os.path.basename(pdb_files[r]) for r in best_matching_frame]

# iterate over the columns actually present in the matrix instead of assuming
# there are exactly N of them
for i, (frame, correlation, conformation_name) in enumerate(
    zip(best_matching_frame, best_matching_frame_correlation, best_matching_conformation)
):
    print(f"{i} {frame} {correlation} {conformation_name}")

In [None]:
# List, for the first column of the correlation matrix (the first sampled
# volume), its correlation with every row together with the matching pdb file name.
correlations = correlation_matrix[:, 0]
for i, value in enumerate(correlations[:frames]):
    pdb_name = os.path.basename(pdb_files[i])
    print(f"{i} {value} {pdb_name}")

## panel E
Zoom-in of a selected atomic model and map from D. For this, an attribute file needs to be made with the Q-scores from ChimeraX

In [None]:
# Load the ChimeraX Q-score table of one conformation into a DataFrame.
project_dir = "/home/mjoosten1/projects/roodmus/data/DE-Shaw_covid_spike_protein/20231116_DESRES-Trajectory_sarscov2-11021571-all-glueCA"

# conformation to process; other conformations used before: "007968", "008150"
conformation = "000163"

qscores_name = f"chimeraX_qscores_table_conformation_{conformation}.txt"
qscores_file = os.path.join(project_dir, qscores_name)
# tab-separated table: first line is the header, the line right after it is skipped
df_qscores = pd.read_csv(
    qscores_file,
    sep="\t",
    header=0,
    skiprows=[1],
)
df_qscores.tail()


In [None]:
# Write a ChimeraX attribute (.defattr) file that assigns the "Qbb" value of
# every row in df_qscores to the corresponding residue as attribute "qscore".
attribute_file = os.path.join(project_dir, f"chimeraX_attributes_conformation_{conformation}.defattr")

recipient = "residues"
matchmode = "1-to-1"
modelnr = "4"

# explanatory comment lines followed by the recipient / match mode / attribute settings
header_lines = [
    "# ChimeraX attributes file\n",
    "# This file contains attributes for the residues of the SARS-CoV-2 spike protein\n",
    "# The attributes are the q-scores of the residues\n",
    f"recipient: {recipient}\n",
    f"match mode: {matchmode}\n",
    "attribute: qscore\n",
]

with open(attribute_file, "w") as f:
    f.writelines(header_lines)
    # one tab-indented line per table row: "<tab>#model/chain:number<tab>value"
    for _, row in df_qscores.iterrows():
        atom_spec = f"#{modelnr}/{row['Chain']}:{row['Number']}"
        attribute = row["Qbb"]
        f.write(f"\t{atom_spec}\t{attribute}\n")
        
