# plotting functions of figure 6 in the manuscript
This figure shows the results of training cryoDRGN on the RTC model. For now it uses the same layout as figure 5; this is subject to change.

In [None]:
# imports
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.decomposition import PCA

# roodmus
from roodmus.analysis.utils import load_data
from roodmus.heterogeneity.hetRec import HetRec
from roodmus.heterogeneity.plot_heterogeneous_reconstruction import plot_latent_space_scatter

In [None]:
### data loading covid RTC DE-Shaw data set
# Loads the cryoDRGN particle metadata and the ground-truth particles, links each
# picked particle to its closest truth particle, and attaches the latent-space
# coordinates and their PCA projection to df_picked.
# NOTE(review): the absolute paths below are machine-specific; adjust before running elsewhere.
project_dir = "/home/mjoosten1/projects/roodmus/data/DE-Shaw_covid_RTC/20240124_DESRES-Trajectory_sarscov2-13795965-no-water-movies"
config_dir = "/home/mjoosten1/projects/roodmus/data/DE-Shaw_covid_RTC/DESRES-Trajectory_sarscov2-13795965-no-water-movies/Movies"
# config_dir = os.path.join(project_dir, "Movies")
figures_dir = os.path.join(project_dir, "figures")
meta_file = os.path.join(project_dir, "cryoDRGN", "run_data.star")
jobtypes = {
    os.path.join(project_dir, "cryoDRGN", "run_data.star"): "cryoDRGN",
}
latent_file = os.path.join(project_dir, "cryoDRGN", "train_320", "z.24.pkl")

particle_diameter = 100 # approximate particle diameter in Angstroms
ugraph_shape = (4000, 4000) # shape of the micrograph in pixels. Only needs to be given if the metadata file is a .star file
verbose = True
ignore_missing_files = True
enable_tqdm = True

analysis = load_data(meta_file, config_dir, particle_diameter, ugraph_shape=ugraph_shape, verbose=verbose, enable_tqdm=enable_tqdm, ignore_missing_files=ignore_missing_files) # creates the class
df_picked = pd.DataFrame(analysis.results_picking)
df_truth = pd.DataFrame(analysis.results_truth)
p_match, _, p_unmatched, t_unmatched, closest_truth_index = analysis._match_particles(
    meta_file,
    df_picked,
    df_truth,
    verbose=False,
    enable_tqdm=True,
)

# The frame index is the last "_"-separated token of the file name without its
# extension. os.path.splitext is used because str.strip(".pdb") strips any of the
# characters ".", "p", "d", "b" from both ends rather than removing the suffix.
df_truth["pdb_index"] = df_truth["pdb_filename"].apply(
    lambda x: int(os.path.splitext(x)[0].split("_")[-1])
)
df_picked["closest_truth_index"] = closest_truth_index
# a picked particle counts as a true positive when it has a (non-NaN) closest truth particle
df_picked["TP"] = df_picked["closest_truth_index"].notna()
# frame index of the matched truth particle; NaN for particles without a match
df_picked["closest_pdb_index"] = df_picked["closest_truth_index"].apply(
    lambda x: df_truth.loc[x, "pdb_index"] if pd.notna(x) else np.nan
)


df_picked, ndim = HetRec.add_latent_space_coordinates(
    latent_file=latent_file,
    df_picked=df_picked,
)
df_picked, pca = HetRec.compute_PCA(
    df_picked=df_picked,
    ndim=ndim,
)
print(f"latent space dimensionality: {ndim}")
df_picked.tail()


## panel B
plotting the latent space of the RTC model


In [None]:
# Panel B: PCA projection of the latent space, with the false-positive picks overlaid in red
dim1=0
dim2=1

grid = plot_latent_space_scatter(df_picked, dim_1=dim1, dim_2=dim2, pca=True)

# false positives are the picked particles whose "TP" flag is falsy
df_FP = df_picked[df_picked["TP"]==0]
print(f"number of FP: {len(df_FP)}")
print(f"number of TP: {len(df_picked)-len(df_FP)}")

# draw the overlay and apply the axis styling on the first axes of the returned figure
ax = grid.figure.get_axes()[0]
sns.scatterplot(data=df_FP, x=f"PCA_{dim1}", y=f"PCA_{dim2}", color="red", s=10, ax=ax, label="FP")
grid.set_axis_labels(f"PCA{dim1}", f"PCA{dim2}", fontsize=16)
ax.tick_params(labelsize=14)
ax.set_xlim((-12, 12))
ax.set_ylim((-12, 12))

# same file stem for both output formats
out_stem = os.path.join(figures_dir, f"{os.path.basename(latent_file)}_pca_{dim1}_{dim2}_FP")
grid.savefig(f"{out_stem}.pdf", bbox_inches="tight")
grid.savefig(f"{out_stem}.png", bbox_inches="tight", dpi=600)

## panel C
plot of the latent space with each point coloured by its corresponding frame from the MD trajectory

In [None]:
# latent space scatter plot, coloured by ground truth frames
# A mean trajectory through latent space (binned over ground-truth frames) is
# drawn on top of the scatter plot and optionally saved to disk.
dim1=0
dim2=1
dt = 48e-3 # time between frames in microseconds
save_trajectory = True

fig, ax = plot_latent_space_scatter(
    df_picked,
    dim_1=dim1,
    dim_2=dim2,
    color_by="closest_pdb_index",
    palette="RdYlBu",
    pca=True,
)
# replace the legend with a colorbar for the closest_pdb_index
legend = ax.get_legend()
if legend is not None:
    legend.remove()
S_m = plt.cm.ScalarMappable(cmap="RdYlBu")
S_m.set_array(df_picked["closest_pdb_index"])
# S_m is not attached to any Axes, so the Axes to take space from is passed explicitly
cbar = fig.colorbar(S_m, ax=ax)
cbar.set_label("Time [\u03BCs]", rotation=270, labelpad=15, fontsize=16) # time in microseconds
# ticks are placed at frame indices; their labels are the frame index multiplied by dt
tick_positions = np.linspace(1, df_picked["closest_pdb_index"].max(), 10)
cbar.set_ticks(tick_positions)
xticklabels = [int(r) for r in tick_positions*dt]
cbar.set_ticklabels(xticklabels, fontsize=14)
ax.set_xlabel(f"PCA{dim1}", fontsize=16)
ax.set_ylabel(f"PCA{dim2}", fontsize=16)
ax.tick_params(labelsize=14)
ax.set_xlim((-12, 12))
ax.set_ylim((-12, 12))

# add trajectory to the plot
N_volumes = 50
# Particles without a matched truth particle carry NaN here. NaN is dropped before
# binning because pandas isin() matches NaN to NaN, which would otherwise pull
# unmatched particles into a trajectory group.
pdb_indices = np.unique(df_picked["closest_pdb_index"].dropna())
d_pdbs = len(pdb_indices) // N_volumes

latent_columns = [f"latent_{k}" for k in range(ndim)]
pca_columns = [f"PCA_{k}" for k in range(ndim)]
trajectory = np.zeros((N_volumes, ndim))
trajectory_pca = np.zeros((N_volumes, ndim))
for i in range(N_volumes):
    # consecutive block of d_pdbs frame indices; frames beyond N_volumes*d_pdbs are not used
    pdb_group = pdb_indices[i*d_pdbs:(i+1)*d_pdbs]
    in_group = df_picked[df_picked["closest_pdb_index"].isin(pdb_group)]
    trajectory[i] = in_group[latent_columns].mean().values
    trajectory_pca[i] = in_group[pca_columns].mean().values

if save_trajectory:
    # save trajectory to a .txt file in the cryoDRGN directory
    np.savetxt(
        os.path.join(
            os.path.dirname(os.path.dirname(latent_file)),
            "trajectory.txt"
        ),
        trajectory
    )

# plot the same two PCA components that are shown on the axes
ax.scatter(trajectory_pca[:, dim1], trajectory_pca[:, dim2], s=5, c="black", zorder=10)
ax.plot(trajectory_pca[:, dim1], trajectory_pca[:, dim2], c="black", zorder=10, linewidth=0.5)
ax.set_aspect("equal")


fig.savefig(os.path.join(figures_dir, f"{os.path.basename(latent_file)}_latent_space_scatter_colored_by_closest_pdb_index_pca.png"), dpi=600, bbox_inches="tight")
fig.savefig(os.path.join(figures_dir, f"{os.path.basename(latent_file)}_latent_space_scatter_colored_by_closest_pdb_index_pca.pdf"), bbox_inches="tight")


## panel D
plotting the correlation matrix between the sampled volumes and frames from the MD trajectory

In [None]:
# Panel D: heat map of the correlation matrix loaded from the cryoDRGN analysis directory
project_dir = "/home/mjoosten1/projects/roodmus/data/DE-Shaw_covid_RTC/20240124_DESRES-Trajectory_sarscov2-13795965-no-water-movies"
figures_dir = os.path.join(project_dir, "figures")
correlation_matrix_file = os.path.join(project_dir, "cryoDRGN", "analyze_320", "correlation_matrix.npy")
correlation_matrix = np.load(correlation_matrix_file)

# number of rows of the matrix (plotted along the y axis)
frames = correlation_matrix.shape[0]

fig, ax = plt.subplots(figsize=(3.5, 3.5))
heatmap = ax.imshow(correlation_matrix, cmap="coolwarm")

# ten evenly spaced y ticks over the rows, relabelled with integers from 0 to 500
yticks = np.linspace(0, frames, 10)
yticklabels = np.linspace(0, 500, 10, dtype=int)
ax.set_yticks(yticks)
ax.set_yticklabels(yticklabels)
ax.set_xlabel("Sampled volume", fontsize=16)
ax.set_ylabel("Time (\u03BCs)", fontsize=16)

cbar = fig.colorbar(heatmap, ax=ax, orientation="vertical", pad=0.01, shrink=0.9)
cbar.ax.tick_params(labelsize=14)
cbar.ax.set_ylabel("Correlation", fontsize=16, rotation=270, labelpad=20)
ax.tick_params(axis="both", which="major", labelsize=14)

correlation_figure_path = os.path.join(figures_dir, f"{os.path.basename(latent_file)}_correlation_matrix.pdf")
fig.savefig(correlation_figure_path, bbox_inches="tight")
print(f"saved figure to {correlation_figure_path}")


## panel E
plotting a zoomed-in region of the map with the atomic model, coloured by Q-scores


In [None]:
# Load the Q-score table of one conformation into a DataFrame.
project_dir = "/home/mjoosten1/projects/roodmus/data/DE-Shaw_covid_RTC/20240124_DESRES-Trajectory_sarscov2-13795965-no-water-movies"
# conformation "007600" is the best fit; "001000" (selected here) is the worst fit
conformation = "001000"
qscores_file = os.path.join(
    project_dir,
    f"chimeraX_qscores_table_conformation_{conformation}.txt",
)
# tab-separated file: the first line is the header and the line directly below it is skipped
df_qscores = pd.read_csv(qscores_file, skiprows=[1], header=0, sep="\t")
df_qscores.tail()


In [None]:
# create an attribute file for ChimeraX
# The file assigns the "Qbb" value of every row in df_qscores to the residue
# given by that row's "Chain" and "Number" columns, as attribute "qscore".
attribute_file = os.path.join(project_dir, f"chimeraX_attributes_conformation_{conformation}.defattr")

recipient = "residues"
matchmode = "1-to-1"
modelnr = "2"

# comment lines explaining the contents, followed by the control lines
header_lines = [
    "# ChimeraX attributes file\n",
    "# This file contains attributes for the residues of the SARS-CoV-2 RTC model\n",
    "# The attributes are the q-scores of the residues\n",
    f"recipient: {recipient}\n",
    f"match mode: {matchmode}\n",
    "attribute: qscore\n",
]

# the whole file is written in one pass instead of reopening it in append mode
with open(attribute_file, "w") as f:
    f.writelines(header_lines)
    # one tab-indented line per residue: atom specifier, then the q-score.
    # Iterating the columns directly keeps each column's own dtype
    # (iterrows does not preserve dtypes across a row).
    for chain, number, qbb in zip(df_qscores["Chain"], df_qscores["Number"], df_qscores["Qbb"]):
        atom_spec = f"#{modelnr}/{chain}:{number}"
        f.write(f"\t{atom_spec}\t{qbb}\n")
        
