## Notebook showcasing the various options in Roodmus for analysis and visualisations
In this notebook, the user can load metadata from one or several jobs of a processing pipeline run in RELION or cryoSPARC. This metadata, along with the ground-truth particle parameters, is loaded into data frames, which allow for easy and convenient plotting. We also provide several helper functions for making common plots.


In [None]:
### imports
# general
import os
import mrcfile
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# roodmus
from importlib import reload
import roodmus.analysis.utils
reload(roodmus.analysis.utils)

from roodmus.analysis.utils import load_data
from roodmus.analysis.plot_ctf import plot_CTF, plot_defocus_scatter
from roodmus.analysis.plot_picking import label_micrograph_picked, label_micrograph_truth, label_micrograph_truth_and_picked, plot_precision, plot_recall, plot_boundary_investigation, plot_overlap_investigation
# from roodmus.analysis.analyse_alignment import alignment_3D
# from roodmus.analysis.plot_alignment import 


In [None]:
### data loading
# all input files live under one dataset directory; change it here to point the notebook at another dataset
data_dir = "/home/mjoosten1/projects/roodmus/data/6xm5_steered_Roodmus_1"
config_dir = f"{data_dir}/test/"

# (metadata file, job label) pairs in processing order. meta_files and jobtypes are both derived
# from this single list, so a file can never be added to one without the other.
jobs = [
    (f"{data_dir}/cryoSPARC/J293_picked_particles.cs", "blob picking"),
    (f"{data_dir}/cryoSPARC/J296_020_particles.cs", "2D classification"),
    (f"{data_dir}/cryoSPARC/J297_passthrough_particles_selected.cs", "class selection"),
    (f"{data_dir}/cryoSPARC/J298_picked_particles.cs", "template picking"),
    (f"{data_dir}/cryoSPARC/J428_040_particles.cs", "2D classification 2"),
    (f"{data_dir}/cryoSPARC/J429_passthrough_particles_selected.cs", "class selection 2"),
    (f"{data_dir}/cryoSPARC/J433_passthrough_particles.cs", "3D classification"),
]

# RELION alternative (paths relative to the working directory):
# jobs = [
#     ("data/6xm5_steered_Roodmus_1/RELION/job015_manual.star", "manual picking"),
#     ("data/6xm5_steered_Roodmus_1/RELION/job006_topaz.star", "topaz picking"),
#     ("data/6xm5_steered_Roodmus_1/RELION/job008_subset_selection.star", "2D classification"),
# ]

meta_files = [path for path, _ in jobs]
jobtypes = dict(jobs)

particle_diameter = 100  # approximate particle diameter in Angstroms
ugraph_shape = (4000, 4000)  # shape of the micrograph in pixels. Only needs to be given if the metadata file is a .star file
verbose = True

if not meta_files:
    raise ValueError("no metadata files given; add at least one entry to `jobs`")

# the first file creates the analysis object; every following file is appended to it
for i, meta_file in enumerate(meta_files):
    if i == 0:
        analysis = load_data(meta_file, config_dir, particle_diameter, ugraph_shape=ugraph_shape, verbose=verbose)
    else:
        analysis.add_data(meta_file, config_dir, verbose=verbose)


In [None]:
### turn the loaded data into pandas dataframes
df_picked = pd.DataFrame(analysis.results_picking)
df_truth = pd.DataFrame(analysis.results_truth)
# a notebook cell only renders its last expression, so the truth table has to be printed explicitly
print(df_truth.head())
df_picked.head()

In [None]:
### saving the dataframes
# it is recommended to save the dataframes after running the rest of the notebook, as they may be modified
# by downstream analysis (later cells replace df_picked and add columns to both frames)
for csv_name, frame in (("picked_particles.csv", df_picked), ("truth_particles.csv", df_truth)):
    frame.to_csv(csv_name)

### CTF estimation


In [None]:
### scatter plot of the estimated vs. the true defocus values
meta_index = 0  # position in meta_files of the job whose picks are plotted
palette = "BuGn"

# restrict the picked particles to the selected job before plotting against the ground truth
picked_by_job = df_picked.groupby("metadata_filename")
df_job = picked_by_job.get_group(meta_files[meta_index])
fig, ax = plot_defocus_scatter(df_job, df_truth, palette=palette)


In [None]:
### plot the CTF estimation for a single micrograph
meta_index = 0  # position in meta_files of the job to plot
ugraph_index = 3  # which micrograph to plot

selected_meta_file = meta_files[meta_index]
df_job = df_picked.groupby("metadata_filename").get_group(selected_meta_file)
fig, ax = plot_CTF(df_job, df_truth, config_dir, ugraph_index)

In [None]:
### plot the CTF for the micrograph with the largest defocus error (should take no more than a few seconds)
# per-micrograph mean estimated defocus (over all picked particles of all loaded jobs) and |mean true defocus|,
# computed with one groupby each instead of re-grouping inside a loop
defocus_estimated = df_picked.groupby("ugraph_filename")["defocusU"].mean()
defocus_true = df_truth.groupby("ugraph_filename")["defocus"].mean().abs()
# align on micrograph name; micrographs without ground truth become NaN (-> 0) instead of raising KeyError
defocus_error = (defocus_estimated - defocus_true.reindex(defocus_estimated.index)).abs().fillna(0)

if defocus_error.empty:
    max_error_index, max_error = 0, 0
else:
    # np.argmax returns the first maximum, i.e. the same micrograph a strict `>` scan would keep
    max_error_index = int(np.argmax(defocus_error.to_numpy()))
    max_error = defocus_error.iloc[max_error_index]
    print(f"largest defocus error {max_error} in {defocus_error.index[max_error_index]}")

# NOTE(review): assumes plot_CTF's ugraph_index is the position among df_picked's sorted ugraph_filename
# values (groupby sorts its keys by default) — confirm against plot_CTF
fig, ax = plot_CTF(df_picked, df_truth, config_dir, max_error_index)


### Particle picking

In [None]:
### plot the picked particles
ugraph_index = 0  # which micrograph to plot
metadata_index = 0  # which metadata file to plot

fig, ax = label_micrograph_picked(df_picked.groupby("metadata_filename").get_group(meta_files[metadata_index]), ugraph_index, config_dir, box_width=48, box_height=48, verbose=verbose)
ax.set_xticks([])
ax.set_yticks([])
# resize first: tight_layout computes the layout for the figure's current size
fig.set_size_inches(7, 7)
fig.tight_layout()


In [None]:
### plot the truth particles
ugraph_index = 3  # which micrograph to plot

fig, ax = label_micrograph_truth(df_truth, ugraph_index, config_dir, box_width=32, box_height=32, verbose=verbose)
ax.set_xticks([])
ax.set_yticks([])
# resize first: tight_layout computes the layout for the figure's current size
fig.set_size_inches(7, 7)
fig.tight_layout()


In [None]:
### plot the truth and picked particles
ugraph_index = 3  # which micrograph to plot
metadata_index = 0  # which metadata file to plot

fig, ax = label_micrograph_truth_and_picked(df_picked.groupby("metadata_filename").get_group(meta_files[metadata_index]),
                                             df_truth, ugraph_index, config_dir, box_width=48, box_height=48, verbose=verbose)
ax.set_xticks([])
ax.set_yticks([])
# resize first: tight_layout computes the layout for the figure's current size
fig.set_size_inches(7, 7)
fig.tight_layout()


In [None]:
### compute precision and recall (may take a few minutes)
# compute_precision returns the precision table and a new df_picked, which replaces the old one for all
# downstream cells (later cells read a "TP" column from it)
precision_output = analysis.compute_precision(df_picked, df_truth, verbose=verbose)
df_precision, df_picked = precision_output
df_precision.head()


In [None]:
### plot boxplot for precision and recall (one figure each)
for plot_metric in (plot_precision, plot_recall):
    fig, ax = plot_metric(df_precision, jobtypes)
    fig.set_size_inches([10, 10])


In [None]:
### alternatively, plot the precision and recall in the same plot
# long format: every column not listed in id_vars (presumably precision and recall) becomes one row per
# metric, with the column name in "variable" and its value in "value"
df = df_precision.melt(id_vars=["metadata_filename", "ugraph_filename", "defocus", "TP", "FP", "FN", "multiplicity", "num_particles_picked", "num_particles_truth", "class2D"])

plt.rcParams["font.size"] = 20
fig, ax = plt.subplots(figsize=(10,10))
# order=meta_files pins the x position of every job, so the labels set below are guaranteed to match the boxes
sns.boxplot(x="metadata_filename", y="value", data=df, ax=ax, fliersize=0, palette="RdYlBu", hue="variable", order=meta_files)
ax.set_ylabel("")
ax.set_xlabel("")
# replace the (long) file paths on the x axis with the job labels
ax.set_xticks(range(len(meta_files)))
ax.set_xticklabels([jobtypes[meta_file] for meta_file in meta_files])
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
# hide the in-axes legend and draw a figure legend outside the axes, to the upper right
ax.legend().set_visible(False)
handles, labels = ax.get_legend_handles_labels()
fig.legend(handles, labels, loc='lower center', ncol=1, bbox_to_anchor=(1.1, 0.85))
fig.tight_layout()



In [None]:
### plot the picked particles, now with the TP and FP marked in green and red
ugraph_index = 10  # which micrograph to plot
metadata_index = 0  # which metadata file to plot

fig, ax = label_micrograph_picked(df_picked.groupby("metadata_filename").get_group(meta_files[metadata_index]), ugraph_index, config_dir, box_width=48, box_height=48, verbose=verbose)
ax.set_xticks([])
ax.set_yticks([])
# resize first: tight_layout computes the layout for the figure's current size
fig.set_size_inches(7, 7)
fig.tight_layout()

In [None]:
### plot the distribution of the particles in the ugraphs in x, y, and z directions
metadata_index = 0  # which metadata file to plot
bin_width = [100, 100, 10]  # bin width for x, y, z
axis = ["x", "y", "z"]

metadata_filename = meta_files[metadata_index]
# one figure per direction, each with its own histogram bin width
for axis_name, axis_bin_width in zip(axis, bin_width):
    fig, ax = plot_boundary_investigation(
        df_truth,
        df_picked,
        metadata_filename,
        axis_bin_width,
        axis=axis_name,
    )

In [None]:
### compute the overlap between the picked and truth particles
df_overlap = analysis.compute_overlap(
    df_picked,
    df_truth,
    verbose=verbose,
)
df_overlap.head()

In [None]:
### plot the overlap between the picked and truth particles
metadata_index = 0 # which metadata file to plot. If None, all metadata files are plotted

# meta_files[None] would raise a TypeError, so None is passed straight through to the plotting function
# NOTE(review): assumes plot_overlap_investigation treats metadata_filename=None as "plot all files" — confirm
metadata_filename = None if metadata_index is None else meta_files[metadata_index]
fig, ax = plot_overlap_investigation(df_overlap, metadata_filename, jobtypes=jobtypes)


In [None]:
### plot the distribution of trajectory frames in a metadata file
metadata_index = 0 # which metadata file to plot


def _frame_index(pdb_filename):
    """Return the integer after the last "_" and before the following "." of a file name, e.g. "x_0042.pdb" -> 42."""
    return int(pdb_filename.split("_")[-1].split(".")[0])


df_picked["closest_pdb_index"] = df_picked["closest_pdb"].apply(_frame_index)
# set the closest_pdb_index to np.nan if the particle is not closer to a truth particle than the particle diameter
# NOTE(review): assumes closest_dist is in the same units as particle_diameter (Angstroms) — confirm
df_picked.loc[df_picked["closest_dist"] > particle_diameter, "closest_pdb_index"] = np.nan
df_truth["pdb_index"] = df_truth["pdb_filename"].apply(_frame_index)

plt.rcParams["font.size"] = 20
fig, ax = plt.subplots(figsize = (10, 10))
# label each histogram so the legend entries are bound to the right artists; passing only label strings
# to fig.legend would attach them to the first two artists drawn (two bars of the same histogram)
sns.histplot(df_picked.groupby("metadata_filename").get_group(meta_files[metadata_index])["closest_pdb_index"], ax=ax, bins=100, kde=True, label="picked")
sns.histplot(df_truth["pdb_index"], ax=ax, bins=100, kde=True, color="red", alpha=0.2, label="truth")
ax.set_xlabel("frame index")
ax.set_ylabel("count")
ax.set_title(jobtypes[meta_files[metadata_index]])
fig.tight_layout()
handles, labels = ax.get_legend_handles_labels()
fig.legend(handles, labels, loc='lower center', ncol=1, bbox_to_anchor=(1.1, 0.85))



In [None]:
### plot the precision per class
metadata_index = 4 # which metadata file to plot. Must have a class2D column

df_grouped = df_picked.groupby("metadata_filename").get_group(meta_files[metadata_index])
# a single groupby over the 2D classes: precision = true-positive picks / all picks in that class
per_class = df_grouped.groupby("class2D").agg(
    n_true_positive=("TP", "sum"),
    n_picked=("TP", "size"),
    average_defocus=("defocusU", "mean"),
)
# to_numpy() gives the result a fresh 0..n-1 index instead of aligning on the class labels
df = pd.DataFrame({
    "class2D": per_class.index.astype(int).to_numpy(),
    "precision": (per_class["n_true_positive"] / per_class["n_picked"]).to_numpy(),
    "average defocus": per_class["average_defocus"].to_numpy(),
})

plt.rcParams["font.size"] = 20
fig, ax = plt.subplots(figsize = (25, 10))
sns.barplot(x="class2D", y="precision", data=df, ax=ax, palette="YlGnBu")
ax.set_xlabel("class2D")
ax.set_ylabel("precision")
ax.set_title(jobtypes[meta_files[metadata_index]])
# hide every second xtick label so the class numbers do not overlap
for tick_label in ax.get_xticklabels()[1::2]:
    tick_label.set_visible(False)
fig.tight_layout()



### 3D alignment

In [None]:
from importlib import reload
import roodmus.analysis.analyse_alignment
reload(roodmus.analysis.analyse_alignment)
from roodmus.analysis.analyse_alignment import alignment_3D
import roodmus.analysis.utils
reload(roodmus.analysis.utils)

import pandas as pd

In [None]:
### load the estimated and true particle orientations of one job
# the job is chosen explicitly here rather than relying on the `meta_file` loop variable leaked from the
# data-loading cell (which holds whichever file was loaded last)
alignment_meta_index = -1  # position in meta_files; -1 is the last job in the list
alignment_meta_file = meta_files[alignment_meta_index]
analysis_alignment = alignment_3D(alignment_meta_file, config_dir, load_all_configs=True, verbose=verbose) # creates the class
df_alignment_estimated = pd.DataFrame(analysis_alignment.results_picking)
df_alignment_truth = pd.DataFrame(analysis_alignment.results_truth)
df_alignment_truth

In [None]:
### plot the alignment
# tick positions/labels in multiples of pi. Raw strings keep "\pi" from being parsed as the invalid
# escape sequence "\p" (a SyntaxWarning on recent Python versions).
euler1_ticks = [-np.pi, -3/4*np.pi, -np.pi/2, -np.pi/4, 0, np.pi/4, np.pi/2, 3/4*np.pi, np.pi]
euler1_ticklabels = [r"$-\pi$", r"$-3/4\pi$", r"$-\pi/2$", r"$-\pi/4$", r"$0$", r"$\pi/4$", r"$\pi/2$", r"$3/4\pi$", r"$\pi$"]
euler2_ticks = [-np.pi/2, -np.pi/4, 0, np.pi/4, np.pi/2]
euler2_ticklabels = [r"$-\pi/2$", r"$-\pi/4$", r"$0$", r"$\pi/4$", r"$\pi/2$"]


def plot_euler_angles(df_alignment):
    """Draw a hexbin joint plot of the "euler1" vs "euler2" columns (log-scaled bin counts).

    Axis ticks are placed at multiples of pi. Returns the seaborn JointGrid.
    """
    grid = sns.jointplot(x="euler1", y="euler2", data=df_alignment, kind="hex", color="k", gridsize=50, bins="log", cmap="viridis")
    grid.ax_joint.set_xlabel("Euler 1")
    grid.ax_joint.set_ylabel("Euler 2")
    grid.fig.set_size_inches(14, 7)
    grid.ax_joint.set_xticks(euler1_ticks)
    grid.ax_joint.set_xticklabels(euler1_ticklabels)
    grid.ax_joint.set_yticks(euler2_ticks)
    grid.ax_joint.set_yticklabels(euler2_ticklabels)
    return grid


# estimated orientations first, then the ground-truth orientations
grid = plot_euler_angles(df_alignment_estimated)
grid = plot_euler_angles(df_alignment_truth)



### Misc investigations