# plotting functions of figure 4 in the manuscript


In [None]:
# imports
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import mrcfile

from emmer.pdb.convert.convert_pdb_to_map import convert_pdb_to_map
from emmer.ndimage.filter.low_pass_filter import low_pass_filter
from emmer.ndimage.filter.smoothen_mask import smoothen_mask

# roodmus
from roodmus.analysis.utils import load_data


In [None]:
# functions
def get_precision_per_class(df_truth, df_picked, classification_job):
    """Summarise the picking precision of every class of one classification job.

    A short report is printed for the job and for each class.

    Parameters
    ----------
    df_truth : pandas.DataFrame
        Ground-truth particles. Only its number of rows is used (printed).
    df_picked : pandas.DataFrame
        Picked particles. Must contain the columns ``jobtype``, ``class2D``
        and ``TP`` (True for picks that are true positives).
    classification_job : str
        Value of the ``jobtype`` column that selects the job to analyse.

    Returns
    -------
    pandas.DataFrame
        One row per class with the columns ``class``, ``precision``,
        ``number_of_particles``, ``picked_fraction`` and
        ``fraction_of_true_positives``. The last column is NaN when the job
        contains no true positives at all.

    Raises
    ------
    KeyError
        If ``classification_job`` does not occur in ``df_picked["jobtype"]``.
    """
    result = {
        "class": [],
        "precision": [],
        "number_of_particles": [],
        "picked_fraction": [],
        "fraction_of_true_positives": [],
    }

    # select the rows of this job once and reuse them for all counts
    df_job = df_picked.groupby("jobtype").get_group(classification_job)
    num_gt_particles = len(df_truth)
    num_picked_particles = len(df_job)
    num_gt_particles_in_picked = len(df_job.query("TP == True"))
    print(f"number of particles in ground truth: {num_gt_particles}")
    print(f"number of particles in picked: {num_gt_particles_in_picked}")

    for class2d, df_class in df_job.groupby("class2D"):
        particles_in_class = len(df_class)
        TP_in_class = len(df_class.query("TP == True"))
        picked_fraction = particles_in_class / num_picked_particles
        precision_in_class = TP_in_class / particles_in_class
        # the share of true positives is undefined when the job has none;
        # report NaN instead of failing with a ZeroDivisionError
        if num_gt_particles_in_picked:
            tp_fraction = TP_in_class / num_gt_particles_in_picked
        else:
            tp_fraction = float("nan")

        print(f"number of picked particles in class {class2d}: {particles_in_class}. Fraction of picked particles: {np.round(picked_fraction, 2)}")
        print(f"precision in class {class2d}: {np.round(precision_in_class, 2)}")
        print(f"TP fraction in class: {np.round(tp_fraction, 2)}")
        print(f"test: {np.round(tp_fraction / picked_fraction, 2)}")

        result["class"].append(int(class2d))
        result["precision"].append(precision_in_class)
        result["number_of_particles"].append(particles_in_class)
        result["picked_fraction"].append(picked_fraction)
        result["fraction_of_true_positives"].append(tp_fraction)

    return pd.DataFrame(result)


In [None]:
### data loading
# project_dir = "/home/mjoosten1/projects/roodmus/data/DE-Shaw_covid_spike_protein/DESRES-Trajectory_sarscov2-11021566-11021571-mixed"
project_dir = "/tudelft/mjoosten1/staff-umbrella/ajlab/MJ/projects/Roodmus/data/DE-Shaw_covid_spike_protein/DESRES-Trajectory_sarscov2-11021566-11021571-mixed"
config_dir = os.path.join(project_dir, "Micrographs")

# metadata files that are analysed: the final iteration of the three Class3D jobs.
# The Extract/job004, Extract/job009 and Select/job012 particles.star files are
# labelled in `jobtypes` below but are currently not loaded.
meta_files = [
    os.path.join(project_dir, "Class3D", job_id, "run_it025_data.star")
    for job_id in ("job014", "job037", "job038")
]

# metadata file -> label stored in the "jobtype" column
jobtypes = {
    os.path.join(project_dir, job_dir, job_id, star_file): label
    for job_dir, job_id, star_file, label in (
        ("Extract", "job004", "particles.star", "LoG"),
        ("Extract", "job009", "particles.star", "topaz"),
        ("Select", "job012", "particles.star", "class selection"),
        ("Class3D", "job014", "run_it025_data.star", "3_classes"),
        ("Class3D", "job037", "run_it025_data.star", "2_classes"),
        ("Class3D", "job038", "run_it025_data.star", "10_classes"),
    )
}

# analysis settings
particle_diameter = 100  # approximate particle diameter in Angstroms
ugraph_shape = (4000, 4000)  # micrograph shape in pixels; only needed when the metadata file is a .star file
verbose = True  # print progress statements
ignore_missing_files = True  # run the analysis even if .mrc files are missing
enable_tqdm = True  # show tqdm progress bars

# The first metadata file creates the analysis object through load_data;
# every following file is added to that same object.
analysis = None
for meta_file in meta_files:
    if analysis is not None:
        analysis.add_data(meta_file, config_dir, verbose=verbose)
        continue
    analysis = load_data(
        meta_file,
        config_dir,
        particle_diameter,
        ugraph_shape=ugraph_shape,
        verbose=verbose,
        enable_tqdm=enable_tqdm,
        ignore_missing_files=ignore_missing_files,
    )

# picked and ground-truth particles as data frames
df_picked = pd.DataFrame(analysis.results_picking)
df_truth = pd.DataFrame(analysis.results_truth)
# df_precision, df_picked = analysis.compute_precision(df_picked, df_truth, verbose=verbose)

# label every picked particle with the job its metadata file belongs to
df_picked["jobtype"] = df_picked["metadata_filename"].map(jobtypes)
df_picked_grouped = df_picked.groupby("jobtype")
for jobtype_name, n_particles in df_picked_grouped.size().items():
    print(f"jobtype: {jobtype_name}, number of particles: {n_particles}")

# Match the picked particles to the ground-truth particles.
# NOTE(review): _match_particles is a private method of the analysis object;
# its interface may change between roodmus versions.
p_match, _, p_unmatched, t_unmatched, closest_truth_index = analysis._match_particles(
    meta_files,
    df_picked,
    df_truth,
    verbose=False,
    enable_tqdm=True,
)
df_precision = analysis.compute_1to1_match_precision(
    p_match,
    p_unmatched,
    t_unmatched,
    df_truth,
    verbose=False,
)

df_precision["jobtype"] = df_precision["metadata_filename"].map(jobtypes)

# The pdb index is the last "_"-separated field of the file name without its
# extension. os.path.splitext removes the extension as a suffix, whereas
# str.strip(".pdb") strips any of the characters ".", "p", "d", "b" from both
# ends of the string.
df_truth["pdb_index"] = df_truth["pdb_filename"].apply(
    lambda x: int(os.path.splitext(x)[0].split("_")[-1])
)

# a picked particle counts as a true positive when its closest_truth_index is not NaN
df_picked["closest_truth_index"] = closest_truth_index
df_picked["TP"] = df_picked["closest_truth_index"].notna()
# pdb index of the matched ground-truth particle (NaN for unmatched picks)
df_picked["closest_pdb_index"] = df_picked["closest_truth_index"].apply(
    lambda x: df_truth.loc[x, "pdb_index"] if pd.notna(x) else np.nan
)
df_precision.head()

## panel A
distributions of the particles in each class over the MD trajectories of the open and closed states


In [None]:
# Panel A: for each classification job, plot the distribution of the true-positive
# particles of every class over the pdb indices of the MD trajectories, together
# with the distribution of the ground-truth particles.
# df_truth["pdb_index"] is already computed in the data loading cell, so it is not
# recomputed on every iteration here.
for classification_job in ["2_classes", "3_classes", "10_classes"]:
    # .copy() yields an independent frame, so the column assignment and the
    # in-place sort below do not operate on a slice of df_picked
    df_picked_grouped = (
        df_picked.query("TP == True")
        .groupby("jobtype")
        .get_group(classification_job)
        .copy()
    )

    # if closest_pdb_index <= 8334 the particle is labelled as originating from the
    # closed state, otherwise from the open state
    df_picked_grouped["state"] = (df_picked_grouped["closest_pdb_index"] <= 8334).map(
        {True: "closed", False: "open"}
    )

    df_picked_grouped.sort_values(by="class2D", inplace=True)
    num_classes = df_picked_grouped.groupby("class2D").ngroups
    print(f"found {num_classes} classes")
    colors = sns.color_palette("RdYlBu", n_colors=num_classes)

    fig, ax = plt.subplots(figsize=(7, 3.5))
    # thick coloured kde line per class
    kde_picked = sns.kdeplot(
        data=df_picked_grouped,
        x="closest_pdb_index",
        ax=ax,
        hue="class2D",
        fill=False,
        label="class2D",
        linewidth=5,
        alpha=1,
        palette=colors,
        legend=True,
    )
    # thin black line on top of every class curve; the palette is keyed on the
    # class labels actually present instead of assuming they are 1..num_classes
    kde_extra = sns.kdeplot(
        data=df_picked_grouped,
        x="closest_pdb_index",
        ax=ax,
        hue="class2D",
        palette={label: "black" for label in df_picked_grouped["class2D"].unique()},
        label="picked_particles",
        linewidth=1,
        alpha=1,
        fill=False,
        legend=False,
    )
    # kde of the ground-truth particles on a secondary y-axis
    ax_truth = ax.twinx()
    kde_truth = sns.kdeplot(
        data=df_truth,
        x="pdb_index",
        ax=ax_truth,
        color="black",
        linestyle="--",
        label="GT",
        linewidth=2,
        fill=False,
        alpha=1,
        legend=True,
    )

    # Collect the legend handles. The class curves live on `ax`; the GT curve is
    # drawn on the twin axis, so its handle has to be taken from `ax_truth`.
    handles, labels = ax.get_legend_handles_labels()
    legend_handles = [handle for handle, label in zip(handles, labels) if label == "class2D"]
    legend_handles = legend_handles[::-1]
    truth_handles, truth_labels = ax_truth.get_legend_handles_labels()
    legend_handles.extend(
        handle for handle, label in zip(truth_handles, truth_labels) if label == "GT"
    )
    legend_labels = [f"class {r+1}" for r in range(num_classes)] + ["GT"]
    # add the legend
    ax.legend(
        handles=legend_handles,
        labels=legend_labels,
        loc='upper right',
        bbox_to_anchor=(1.40, 1.0),
        ncol=1,
        fontsize=14,
        frameon=True,
    )
    ax_truth.set_ylabel("GT", fontsize=16, rotation=270, labelpad=20)
    if num_classes == 10:
        # leave some headroom above the class curves
        ylim = ax.get_ylim()
        ax.set_ylim(ylim[0], ylim[1]*1.1)
    else:
        # use the same y-range as the ground-truth axis
        ylim = ax_truth.get_ylim()
        ax.set_ylim(ylim[0], ylim[1])

    # change the xticks to the time in us
    ax.set_xticks([0, 8334-1000, 8334+1000, 16668])
    ax.set_xticklabels(["0 \u03BCs", "10 \u03BCs", "0 \u03BCs", "10 \u03BCs"], fontsize=14)
    ax.axvline(x=8334, color="red", linestyle="-.", linewidth=2)
    ax.set_xlabel("")
    ax.set_ylabel("Density", fontsize=16)
    # label the left half of the plot with 'Closed state' and the right half with
    # 'Open state' underneath the x-axis
    ax.text(0.25, -0.2, "Closed state", ha='center', va='bottom', fontsize=16, fontweight='bold', transform=ax.transAxes)
    ax.text(0.75, -0.2, "Open state", ha='center', va='bottom', fontsize=16, fontweight='bold', transform=ax.transAxes)
    ax.tick_params(axis='both', which='major', labelsize=14)
    ax_truth.tick_params(axis='both', which='major', labelsize=14)

    outfilename = os.path.join(project_dir, "figures", f"frame_distribution_{classification_job}.pdf")
    # fig.savefig(outfilename, bbox_inches="tight")
    print(f"saved figure to: {outfilename}")


In [None]:
# for the 10-class case plot all distributions separately
# The job is named explicitly so that this cell does not depend on a loop
# variable left behind by another cell.
classification_job = "10_classes"
df_picked_grouped = (
    df_picked.query("TP == True")
    .groupby("jobtype")
    .get_group(classification_job)
    .copy()  # independent frame: the assignments below do not act on a slice of df_picked
)
# state label per particle (closest_pdb_index <= 8334 -> closed); not used by the plots in this cell
df_picked_grouped["state"] = (df_picked_grouped["closest_pdb_index"] <= 8334).map(
    {True: "closed", False: "open"}
)
df_picked_grouped.sort_values(by="class2D", inplace=True)
df_per_class = df_picked_grouped.groupby("class2D")
num_classes = df_per_class.ngroups
print(f"found {num_classes} classes")

for class2d, df_class in df_per_class:
    fig, ax = plt.subplots(figsize=(7, 3.5))
    # kde of the pdb indices of the particles in this class
    kde_picked = sns.kdeplot(
        data=df_class,
        x="closest_pdb_index",
        ax=ax,
        fill=False,
        label=f"class {class2d}",
        linewidth=5,
        alpha=1,
        legend=True,
    )
    # Legend for the single curve drawn in this figure. No ground-truth curve or
    # twin axis is created in this cell, so nothing else is listed or styled.
    ax.legend(
        loc='upper right',
        bbox_to_anchor=(1.40, 1.0),
        ncol=1,
        fontsize=14,
        frameon=True,
    )
    if num_classes == 10:
        # leave some headroom above the curve
        ylim = ax.get_ylim()
        ax.set_ylim(ylim[0], ylim[1]*1.1)

    # change the xticks to the time in us
    ax.set_xticks([0, 8334-1000, 8334+1000, 16668])
    ax.set_xticklabels(["0 \u03BCs", "10 \u03BCs", "0 \u03BCs", "10 \u03BCs"], fontsize=14)
    ax.axvline(x=8334, color="red", linestyle="-.", linewidth=2)
    ax.set_xlabel("")
    ax.set_ylabel("Density", fontsize=16)
    # label the left half of the plot with 'closed state' and the right half with
    # 'open state' underneath the x-axis
    ax.text(0.25, -0.2, "closed state", ha='center', va='bottom', fontsize=16, fontweight='bold', transform=ax.transAxes)
    ax.text(0.75, -0.2, "open state", ha='center', va='bottom', fontsize=16, fontweight='bold', transform=ax.transAxes)
    ax.tick_params(axis='both', which='major', labelsize=14)
    ax.set_title(f"class {class2d}")

    outfilename = os.path.join(project_dir, "figures", f"frame_distribution_{classification_job}_class{class2d}.png")
    # fig.savefig(outfilename, dpi=300, bbox_inches="tight")
    print(f"saved figure to: {outfilename}")

    # close the plot
    plt.close(fig)


In [None]:
# for each class print if there are more particles from the open or closed state
# closest_pdb_index <= 8334 is labelled "closed", everything else "open".
# Picks without a match have closest_pdb_index NaN and therefore also end up as
# "open" in this column; they are excluded below by selecting TP == True.
df_picked["state"] = (df_picked["closest_pdb_index"] <= 8334).map({True: "closed", False: "open"})
df_picked["class2D"] = df_picked["class2D"].astype(int)

for classification_job in ["2_classes", "3_classes", "10_classes"]:
    print(f"jobtype: {classification_job}")
    df_picked_grouped = df_picked.groupby(["jobtype", "TP"]).get_group((classification_job, True))

    # Iterate over the classes that actually contain true positives. Looping over
    # range(1, num_classes + 1) raises a KeyError as soon as one class id is
    # missing from the true positives.
    for class2D, df_class in df_picked_grouped.groupby("class2D"):
        num_open = int((df_class["state"] == "open").sum())
        num_closed = int((df_class["state"] == "closed").sum())
        percentage_open = num_open / (num_open + num_closed)*100
        print(f"class {class2D}: {num_open} ({percentage_open:.1f}%) open, {num_closed} ({100-percentage_open:.1f}%) closed")

## panel B
plotting the precision of each 3D class

In [None]:
# plot the precision per 3D class: one bar chart per classification job
for classification_job in ("2_classes", "3_classes", "10_classes"):
    df_result = get_precision_per_class(df_truth, df_picked, classification_job)

    # one colour per class
    num_classes = len(df_result["class"].unique())
    colors = sns.color_palette("RdYlBu", n_colors=num_classes)

    fig, ax = plt.subplots(figsize=(3.5, 3.5))
    sns.barplot(
        data=df_result, x="class", y="precision", hue="class",
        palette=colors, dodge=0,
        edgecolor='black', linewidth=1,
        ax=ax,
    )
    # the hue legend duplicates the x-axis, so drop it
    ax.legend().remove()

    # axis cosmetics
    ax.tick_params(axis='both', which='major', labelsize=14)
    ax.set_xlabel("Class", fontsize=16)
    ax.set_ylabel("Precision", fontsize=16)
    ax.set_ylim((0, 1))

    outfilename = os.path.join(project_dir, "figures", f"precision_{classification_job}.pdf")
    # fig.savefig(outfilename, bbox_inches="tight")
    print(f"saved figure to: {outfilename}")


## panel C
plotting the fraction of particles in each class

In [None]:
# plot the fraction of picked particles that ends up in each class, one bar
# chart per classification job
for classification_job in ["2_classes", "3_classes", "10_classes"]:
    df_result = get_precision_per_class(df_truth, df_picked, classification_job)
    # setup colors: one per class
    num_classes = len(df_result["class"].unique())
    colors = sns.color_palette("RdYlBu", n_colors=num_classes)

    fig, ax = plt.subplots(figsize=(3.5, 3.5))
    sns.barplot(
        data=df_result,
        x="class",
        y="picked_fraction",
        hue="class",
        palette=colors,
        edgecolor='black',
        ax=ax,
        dodge=0,
        linewidth=1,
    )
    # remove legend (the hue legend duplicates the x-axis)
    ax.legend().remove()
    ax.set_ylim((0, 1))
    # the y-label is set only once; an earlier "particle fraction" label was
    # overwritten immediately and never shown
    ax.set_ylabel("Picked fraction", fontsize=16)
    ax.set_xlabel("Class", fontsize=16)
    ax.tick_params(axis='both', which='major', labelsize=14)

    outfilename = os.path.join(project_dir, "figures", f"picked_fraction_{classification_job}.pdf")
    # fig.savefig(outfilename, bbox_inches="tight")
    print(f"saved figure to: {outfilename}")


## panel D
plotting the correlation matrix between each class and a sampling of the frames in the MD trajectory

In [None]:
# Panel D: heatmap of the correlation matrix (rows: frames, columns: classes)
# of each classification job, normalised per class.
project_dir = "/tudelft/mjoosten1/staff-umbrella/ajlab/MJ/projects/Roodmus/data/DE-Shaw_covid_spike_protein/DESRES-Trajectory_sarscov2-11021566-11021571-mixed"
for classification_job in ["2_classes", "3_classes", "10_classes"]:
    print(f"loading data for {classification_job}")
    correlation_matrix_filename = os.path.join(project_dir, "aligned_ensembles", f"correlation_matrix_{classification_job}.npy")
    if not os.path.exists(correlation_matrix_filename):
        print(f"correlation matrix for {classification_job} does not exist yet. Compute it first.")
        # Skip this job. Falling through would plot the matrix of a previous
        # job under this job's name, or fail with a NameError on the first job.
        continue
    correlation_matrix = np.load(correlation_matrix_filename)

    # min-max normalise every column (class) to the range [0, 1]
    column_min = correlation_matrix.min(axis=0)
    column_max = correlation_matrix.max(axis=0)
    correlation_matrix_normalised = (correlation_matrix - column_min) / (column_max - column_min)
    n_classes = correlation_matrix.shape[1]
    n_frames = correlation_matrix.shape[0]
    print(f"number of classes: {n_classes}")
    print(f"number of frames: {n_frames}")

    fig, ax = plt.subplots(figsize=(3.5, 3.5))
    # aspect="auto" stretches the tall, narrow matrix to fill the square axes;
    # origin="lower" puts row 0 at the bottom
    ax.imshow(correlation_matrix_normalised, cmap="coolwarm", aspect="auto", origin="lower")
    ticklabels = [f"class {i+1}" for i in range(n_classes)]
    ax.set_xticks(range(n_classes))
    ax.set_xticklabels(ticklabels, fontsize=12, rotation=45)
    ax.set_yticks([0, n_frames//2-2, n_frames//2+2, n_frames-1])
    ax.set_yticklabels(["0 \u03BCs", "10 \u03BCs", "0 \u03BCs", "10 \u03BCs"], fontsize=12)
    # separator between the two halves of the frame axis
    ax.axhline(y=n_frames//2, color="black", linestyle="solid", linewidth=2)
    # add colorbar
    cbar = ax.figure.colorbar(ax.get_images()[0], ax=ax, orientation="vertical")
    # NOTE(review): the first half of the rows is labelled "Open state" here, while
    # panel A labels pdb indices <= 8334 as the closed state. Confirm the row order
    # of the stored correlation matrices.
    ax.text(
        -0.65,
        3*n_frames//4,
        "Closed state",
        ha='right',
        va='bottom',
        fontsize=12,
        fontweight='bold',
    )
    ax.text(
        -0.65,
        n_frames//4,
        "Open state",
        ha='right',
        va='bottom',
        fontsize=12,
        fontweight='bold',
    )

    # save the figure
    outfilename = os.path.join(project_dir, "figures", f"correlation_matrix_{n_classes}classes.pdf")
    # fig.savefig(outfilename, bbox_inches="tight")
    print(f"saved figure to: {outfilename}")

In [None]:
# plot the average correlation with the open and closed states for each class
for classification_job in ["2_classes", "3_classes", "10_classes"]:
    print(f"loading data for {classification_job}")
    correlation_matrix_filename = os.path.join(project_dir, "aligned_ensembles", f"correlation_matrix_{classification_job}.npy")
    if not os.path.exists(correlation_matrix_filename):
        print(f"correlation matrix for {classification_job} does not exist yet. Compute it first.")
        # Skip this job. Falling through would reuse the matrix of a previous
        # job, or fail with a NameError on the first job.
        continue
    correlation_matrix = np.load(correlation_matrix_filename)
    n_classes = correlation_matrix.shape[1]
    n_frames = correlation_matrix.shape[0]
    print(f"number of classes: {n_classes}")
    print(f"number of frames: {n_frames}")

    # Average every column (class) over the first and the second half of the frames.
    # NOTE(review): the first half of the rows is treated as the open state here,
    # while panel A labels pdb indices <= 8334 as closed. Confirm the row order of
    # the stored correlation matrices.
    avg_correlation_open = correlation_matrix[:n_frames//2].mean(axis=0)
    avg_correlation_closed = correlation_matrix[n_frames//2:].mean(axis=0)

    # plot the average correlation with the open and closed states for each class
    fig, ax = plt.subplots(figsize=(3.5, 3.5))
    class_numbers = range(1, n_classes+1)
    ax.plot(class_numbers, avg_correlation_open, label="open state", color="blue")
    ax.plot(class_numbers, avg_correlation_closed, label="closed state", color="red")
    ax.set_xlabel("class", fontsize=16)
    ax.set_ylabel("average correlation", fontsize=16)
    ax.legend()
    ax.tick_params(axis='both', which='major', labelsize=14)

    outfilename = os.path.join(project_dir, "figures", f"average_correlation_{classification_job}.pdf")
    # savefig does not create missing directories, so make sure the target exists
    os.makedirs(os.path.dirname(outfilename), exist_ok=True)
    fig.savefig(outfilename, bbox_inches="tight")
    print(f"saved figure to: {outfilename}")

## in text
To report the resolution of the refined maps for the 3-class dataset in the text, I need to load the generated confidence maps, threshold them, and smooth the resulting masks.

In [None]:
# project_dir = "/home/mjoosten1/projects/roodmus/data/DE-Shaw_covid_spike_protein/DESRES-Trajectory_sarscov2-11021566-11021571-mixed"
# project_dir = "/tudelft/mjoosten1/staff-umbrella/ajlab/MJ/projects/Roodmus/data/DE-Shaw_covid_spike_protein/DESRES-Trajectory_sarscov2-11021566-11021571-mixed"
# data = {
#     0:{
#         "class": 1,
#         "ConfidenceMap": "job103",
#     },
#     1: {
#         "class": 2,
#         "ConfidenceMap": "job104",
#     },
#     2: {
#         "class": 3,
#         "ConfidenceMap": "job105",
#     },
# }

# for key, item in data.items():
#     confidence_map_file = os.path.join(project_dir, "ConfidenceMap", item["ConfidenceMap"], "run_class001_confidenceMap.mrc")
#     confidence_map = mrcfile.open(confidence_map_file)
#     confidence_mask = confidence_map.data > 0.99

#     confidence_mask_smooth = smoothen_mask(
#         confidence_mask,
#         cosine_falloff_length=5,
#     )

#     with mrcfile.new(os.path.join(project_dir, "ConfidenceMap", item["ConfidenceMap"], "run_class001_confidenceMap_filtered.mrc"), overwrite=True) as mrc:
#         mrc.set_data(confidence_mask_smooth.astype(np.float32))
#         mrc.voxel_size = confidence_map.voxel_size
#         mrc.header.origin.x = confidence_map.header.origin.x
#         mrc.header.origin.y = confidence_map.header.origin.y
#         mrc.header.origin.z = confidence_map.header.origin.z