# Plotting functions for supplementary figure 3 of the manuscript
This figure shows the 3D classification of the covid spike monomer steered MD simulation and c3c3b steered MD simulation.

This figure shows the following:

- Panel showing distribution of particles in each 3D class over the frames of the steered MD simulation


In [None]:
# imports
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import mrcfile

# roodmus
from roodmus.analysis.utils import load_data

## panel A
plotting the unnormalised correlation matrices of the 3D classification of the spike glycoprotein

In [None]:
# Panel A: one heatmap per 3D classification job, showing the unnormalised correlation
# matrix (rows = frames, columns = classes) loaded from a pre-computed .npy file.
project_dir = "/tudelft/mjoosten1/staff-umbrella/ajlab/MJ/projects/Roodmus/data/DE-Shaw_covid_spike_protein/DESRES-Trajectory_sarscov2-11021566-11021571-mixed"
save_figure = False  # set to True to write each heatmap to <project_dir>/figures

# keyword arguments shared by the two state annotations left of the y-axis
state_label_kws = dict(ha='right', va='bottom', fontsize=12, fontweight='bold')

for classification_job in ["2_classes", "3_classes", "10_classes"]:
    print(f"loading data for {classification_job}")
    correlation_matrix_filename = os.path.join(project_dir, "aligned_ensembles", f"correlation_matrix_{classification_job}.npy")
    if not os.path.exists(correlation_matrix_filename):
        print(f"correlation matrix for {classification_job} does not exist yet. Compute it first.")
        # Skip this job: falling through would raise a NameError on the first iteration
        # or silently re-plot the matrix of the previous job.
        continue
    correlation_matrix = np.load(correlation_matrix_filename)
    n_frames, n_classes = correlation_matrix.shape[:2]
    print(f"number of classes: {n_classes}")
    print(f"number of frames: {n_frames}")

    fig, ax = plt.subplots(figsize=(3.5, 3.5))
    # aspect="auto" stretches the tall, narrow matrix so that it fills the square axes
    heatmap = ax.imshow(correlation_matrix, cmap="coolwarm", aspect="auto", origin="lower")
    ax.set_xticks(range(n_classes))
    ax.set_xticklabels([f"class {i+1}" for i in range(n_classes)], fontsize=12, rotation=45)
    # The rows are split into two halves by the horizontal line drawn below; each half
    # gets its own "0 μs" / "10 μs" tick pair, offset by two rows so that the labels
    # next to the dividing line do not overlap.
    ax.set_yticks([0, n_frames//2-2, n_frames//2+2, n_frames-1])
    ax.set_yticklabels(["0 \u03BCs", "10 \u03BCs", "0 \u03BCs", "10 \u03BCs"], fontsize=12)
    ax.axhline(y=n_frames//2, color="black", linestyle="solid", linewidth=2)
    # add colorbar
    fig.colorbar(heatmap, ax=ax, orientation="vertical")
    # label the upper and lower half of the matrix (x is in data coordinates, left of class 1)
    ax.text(-0.65, 3*n_frames//4, "Closed state", **state_label_kws)
    ax.text(-0.65, n_frames//4, "Open state", **state_label_kws)

    # save the figure (only when requested, so the message below is never misleading)
    outfilename = os.path.join(project_dir, "figures", f"correlation_matrix_{n_classes}classes_unnormalised.pdf")
    if save_figure:
        os.makedirs(os.path.dirname(outfilename), exist_ok=True)
        fig.savefig(outfilename, bbox_inches="tight")
        print(f"saved figure to: {outfilename}")
    else:
        print(f"figure not saved (save_figure is False); target would be: {outfilename}")

## panel B
plotting the distribution of particles in each 3D class

In [None]:
### data loading (6xm5 steered MD)
# Loads the cryoSPARC metadata of the four ab initio classes, matches every picked
# particle to its closest ground-truth particle and annotates df_picked with
#   - "closest_truth_index": value returned by the particle matching (NaN if unmatched)
#   - "TP":                  True when the picked particle was matched to a truth particle
#   - "closest_pdb_index":   frame index ("pdb_index") of the matched truth particle
#   - "jobtype":             name of the 3D class the particle was assigned to
project_dir = "/home/mjoosten1/projects/roodmus/data/6xm5_steered_Roodmus_2"
config_dir = os.path.join(project_dir, "mrc")
figures_dir = os.path.join(project_dir, "figures")

# one passthrough metadata file per ab initio class; the class names are derived from
# the same list so that file names and labels cannot drift apart
meta_files = [
    os.path.join(project_dir, "cryoSPARC", f"J519_passthrough_particles_class_{class_index}.cs")
    for class_index in range(4)
]
jobtypes = {
    meta_file: f"ab initio class {class_index}"
    for class_index, meta_file in enumerate(meta_files)
}

particle_diameter = 100 # approximate particle diameter in Angstroms
ugraph_shape = (4000, 4000) # shape of the micrograph in pixels. Only needs to be given if the metadata file is a .star file
verbose = True # prints out progress statements
ignore_missing_files = True # if .mrc files are missing, the analysis will still be performed
enable_tqdm = True # enables tqdm progress bars

# the first metadata file creates the analysis object, the remaining ones are added to it
analysis = load_data(
    meta_files[0],
    config_dir,
    particle_diameter,
    ugraph_shape=ugraph_shape,
    verbose=verbose,
    enable_tqdm=enable_tqdm,
    ignore_missing_files=ignore_missing_files,
)
for meta_file in meta_files[1:]:
    analysis.add_data(meta_file, config_dir, verbose=verbose)

df_picked = pd.DataFrame(analysis.results_picking)
df_truth = pd.DataFrame(analysis.results_truth)

# match picked particles to truth particles and compute precision/recall from the matches
p_match, _, p_unmatched, t_unmatched, closest_truth_index = analysis._match_particles(
    meta_files,
    df_picked,
    df_truth,
    verbose=False,
    enable_tqdm=enable_tqdm,
)
df_precision = analysis.compute_1to1_match_precision(
    p_match,
    p_unmatched,
    t_unmatched,
    df_truth,
    verbose=False,
)
print(f"mean precision: {df_precision['precision'].mean()}")
print(f"mean recall: {df_precision['recall'].mean()}")

# The frame index is the number after the last underscore of the pdb file name.
# os.path.splitext removes the extension; str.strip(".pdb") would instead strip any of
# the characters '.', 'p', 'd', 'b' from both ends of the name.
df_truth["pdb_index"] = df_truth["pdb_filename"].apply(
    lambda x: int(os.path.splitext(x)[0].split("_")[-1])
)
df_picked["closest_truth_index"] = closest_truth_index
# a picked particle is a true positive when it was matched to a truth particle
df_picked["TP"] = df_picked["closest_truth_index"].notna()
df_picked["closest_pdb_index"] = df_picked["closest_truth_index"].apply(
    lambda x: df_truth.loc[x, "pdb_index"] if pd.notna(x) else np.nan
)
df_picked["jobtype"] = df_picked["metadata_filename"].map(jobtypes)
print(f"total number of particles loaded: {len(df_picked)}")
print(f"number of particles in class 0: {len(df_picked[df_picked['jobtype'] == 'ab initio class 0'])}")
df_picked.head()

In [None]:
# Panel B: distribution of the true-positive picked particles of every 3D class over the
# frames of the steered MD trajectory, overlaid with the ground-truth frame distribution.
# Needs df_picked (columns "TP", "closest_pdb_index", "jobtype") and df_truth (column
# "pdb_index") as produced by the data-loading cell.
save_figure = False  # set to True to write the figure to <project_dir>/figures

# keep only the picked particles that were matched to a ground-truth particle
df_picked_grouped = df_picked.groupby("TP").get_group(True)
print(f"number of TP particles : {len(df_picked_grouped)}")
num_classes = len(df_picked_grouped["jobtype"].unique())

# time interval between consecutive frames, used to convert frame index to time
# NOTE(review): this value and the x-axis label are given in ps, while the heatmaps of
# panel A are labelled in μs — confirm which unit is correct.
dt = 0.01

fig, ax = plt.subplots(figsize=(7, 3.5))
# thick coloured KDE curve per class
kde_picked = sns.kdeplot(
    data=df_picked_grouped,
    x="closest_pdb_index",
    ax=ax,
    hue="jobtype",
    fill=False,
    label="jobtype",
    linewidth=10,
    alpha=1,
    palette="RdYlBu",
    legend=True,
)
# thin black curve on top of every class curve, so that light colours stay visible
kde_extra = sns.kdeplot(
    data=df_picked_grouped,
    x="closest_pdb_index",
    ax=ax,
    hue="jobtype",
    palette={f"{r}": "black" for r in df_picked_grouped["jobtype"].unique()},
    label="picked_particles",
    linewidth=3,
    alpha=1,
    fill=False,
    legend=False,
)
# dashed KDE curve of the ground-truth particles on a twin y-axis
ax_truth = ax.twinx()
kde_truth = sns.kdeplot(
    data=df_truth,
    x="pdb_index",
    ax=ax_truth,
    color="black",
    linestyle="--",
    label="GT",
    linewidth=3,
    fill=False,
    alpha=1,
    legend=True,
)

# Build the legend by hand. The coloured class curves live on `ax`; their order is
# reversed so that it lines up with the "class 1..n" labels below.
picked_handles, picked_labels = ax.get_legend_handles_labels()
legend_handles = [
    handle for handle, label in zip(picked_handles, picked_labels) if label == "jobtype"
][::-1]
# The GT curve is drawn on the twin axis, so its handle has to be collected from
# `ax_truth`; `ax` alone never returns it and the "GT" entry would be dropped.
truth_handles, truth_labels = ax_truth.get_legend_handles_labels()
legend_handles.extend(
    handle for handle, label in zip(truth_handles, truth_labels) if label == "GT"
)
legend_labels = [f"class {r+1}" for r in range(num_classes)] + ["GT"]
# add the legend
ax.legend(
    handles=legend_handles,
    labels=legend_labels,
    loc='upper right',
    bbox_to_anchor=(1.50, 1.0),
    ncol=1,
    fontsize=14,
    frameon=True,
)
ax_truth.set_ylabel("GT", fontsize=16, rotation=270, labelpad=20)
if num_classes == 10:
    # with ten classes only add 10% headroom above the class curves
    ylim = ax.get_ylim()
    ax.set_ylim(ylim[0], ylim[1]*1.1)
else:
    # otherwise put the class curves on the same y-range as the ground-truth curve
    ylim = ax_truth.get_ylim()
    ax.set_ylim(ylim[0], ylim[1])

# five evenly spaced ticks over the frame range, labelled with the corresponding time
max_frame_index = df_truth["pdb_index"].max()
ax.set_xticks(np.linspace(0, max_frame_index, 5))
ax.set_xticklabels(np.linspace(0, max_frame_index*dt, 5).round(2), fontsize=14)
ax.set_xlabel("Time (ps)", fontsize=16)
ax.set_ylabel("Density", fontsize=16)
ax.tick_params(axis='both', which='major', labelsize=14)
ax_truth.tick_params(axis='both', which='major', labelsize=14)

# save the figure (only when requested, so the message below is never misleading)
outfilename = os.path.join(project_dir, "figures", "frame_distribution.pdf")
if save_figure:
    os.makedirs(os.path.dirname(outfilename), exist_ok=True)
    fig.savefig(outfilename, bbox_inches="tight")
    print(f"saved figure to: {outfilename}")
else:
    print(f"figure not saved (save_figure is False); target would be: {outfilename}")


## panel C
plotting the precision per 3D class and the fraction of particles in each 3D class

In [None]:
# Panel C (left): bar chart of the precision per 3D class.
# Needs df_precision and jobtypes from the data-loading cell.
save_figure = False  # set to True to write the figure to <project_dir>/figures

# relabel every metadata file with its 1-based class name, e.g. "ab initio class 0" -> "class 1"
df_precision["jobtype"] = df_precision["metadata_filename"].apply(
    lambda x: f"class {int(jobtypes[x].split()[-1]) + 1}"
)
num_classes = len(df_precision["jobtype"].unique())
# one colour per class, from the same palette name as used for the KDE curves of panel B
colors = sns.color_palette("RdYlBu", n_colors=num_classes)

fig, ax = plt.subplots(figsize=(3.5, 3.5))
sns.barplot(
    data=df_precision,
    x="jobtype",
    y="precision",
    hue="jobtype",
    palette=colors,
    edgecolor='black',
    ax=ax,
    dodge=False,
    linewidth=1,
    errorbar=None,
)
ax.set_ylim((0, 1))
ax.set_ylabel("Precision", fontsize=16)
ax.set_xlabel("")
ax.tick_params(axis='both', which='major', labelsize=14)
# the class names are already on the x-axis, so drop the hue legend if seaborn added one
hue_legend = ax.get_legend()
if hue_legend is not None:
    hue_legend.remove()

# save the figure (only when requested, so the message below is never misleading)
outfilename = os.path.join(project_dir, "figures", "precision_per_class.pdf")
if save_figure:
    os.makedirs(os.path.dirname(outfilename), exist_ok=True)
    fig.savefig(outfilename, bbox_inches="tight")
    print(f"saved figure to {outfilename}")
else:
    print(f"figure not saved (save_figure is False); target would be {outfilename}")


In [None]:
# Panel C (right): fraction of all picked particles that ended up in each 3D class.
# Needs df_picked from the data-loading cell and `colors` from the precision cell, so
# that both bar charts use identical class colours.
save_figure = False  # set to True to write the figure to <project_dir>/figures

n_picked = len(df_picked)
results = {
    "class": [],
    "picked_fraction": [],
}
for jobtype in df_picked["jobtype"].unique():
    # "ab initio class 0" -> "class 1" (1-based class numbering, as in the other panels)
    results["class"].append(f"class {int(jobtype.split()[-1]) + 1}")
    results["picked_fraction"].append(len(df_picked[df_picked["jobtype"] == jobtype]) / n_picked)
df_results = pd.DataFrame(results)

fig, ax = plt.subplots(figsize=(3.5, 3.5))
sns.barplot(
    data=df_results,
    x="class",
    y="picked_fraction",
    hue="class",
    palette=colors,
    edgecolor='black',
    ax=ax,
    dodge=False,
    linewidth=1,
)
ax.set_ylim((0, 1))
ax.set_ylabel("Picked fraction", fontsize=16)
ax.set_xlabel("")
ax.tick_params(axis='both', which='major', labelsize=14)
# the class names are already on the x-axis, so drop the hue legend if seaborn added one
hue_legend = ax.get_legend()
if hue_legend is not None:
    hue_legend.remove()

# save the figure (only when requested, so the message below is never misleading)
outfilename = os.path.join(project_dir, "figures", "particle_fraction_per_class.pdf")
if save_figure:
    os.makedirs(os.path.dirname(outfilename), exist_ok=True)
    fig.savefig(outfilename, bbox_inches="tight")
    print(f"saved figure to {outfilename}")
else:
    print(f"figure not saved (save_figure is False); target would be {outfilename}")


## panel D
For each class, the density map is plotted before refinement

In [None]:
# Write a copy of each listed density map with its first array axis reversed.
# The copy is stored next to the input file, with "_flipped" inserted before ".mrc",
# and keeps the voxel size of the input.
volumes_to_flip = [
    # "/home/mjoosten1/projects/roodmus/data/6xm5_steered_Roodmus_2/cryosparc_P51_J519/cryosparc_P51_J519_class_00_final_volume.mrc",
    "/home/mjoosten1/projects/roodmus/data/6xm5_steered_Roodmus_2/cryosparc_P51_J520/cryosparc_P51_J520_003_volume_map_sharp.mrc",
]

for source_path in volumes_to_flip:
    flipped_path = source_path.replace(".mrc", "_flipped.mrc")

    # read the map and its voxel size from the input file
    with mrcfile.open(source_path, mode='r', permissive=True) as mrc_in:
        density = mrc_in.data
        voxel_size = mrc_in.voxel_size

    # write the flipped map, overwriting any earlier output
    with mrcfile.new(flipped_path, overwrite=True) as mrc_out:
        mrc_out.set_data(np.flip(density, axis=0))
        mrc_out.voxel_size = voxel_size
        


In [None]:
# Print one hex colour code per 3D class.
# Each RGB channel (a float between 0 and 1) is scaled to 0-255 and truncated to an integer.
colors = sns.color_palette("RdYlBu", n_colors=num_classes)
for red, green, blue in colors:
    print("#{:02x}{:02x}{:02x}".format(int(red*255), int(green*255), int(blue*255)))
    



## in text
print some stats about the distribution of the particles in the 3D classes

In [None]:
# Numbers quoted in the text: true/false positive counts and precision/recall per class.
# Needs df_picked (with the boolean "TP" column) from the data-loading cell and
# df_precision with the "jobtype" class labels added by the precision cell of panel C.

# count the true positives directly from the "TP" column, so this cell does not depend
# on the plotting cell of panel B
n_true_positives = int(df_picked["TP"].sum())
print(f"total number of TP particles: {n_true_positives}")
print(f"total number of FP particles: {len(df_picked) - n_true_positives}")
print(f"total number of particles: {len(df_picked)}")

# Mean precision and recall per class. The result gets its own name: overwriting
# df_precision would drop its "metadata_filename" column and make the panel C cell
# fail when it is run again.
df_precision_per_class = df_precision.groupby("jobtype").agg({"precision": "mean", "recall": "mean"})
# show every class (head() would hide all classes beyond the fifth)
df_precision_per_class