# Plotting functions for supplementary figure 1 in the manuscript
This figure shows the following:

- Panel showing CTF estimation for the covid spike trimer
- Panel showing ground-truth and estimated poses for covid spike trimer

- Panel showing consensus reconstruction of covid spike monomer
- Panel showing particle picking for covid spike monomer


In [None]:
# imports
import os
import pandas as pd
import numpy as np

# roodmus
from roodmus.analysis.utils import load_data
from roodmus.analysis.plot_ctf import plot_defocus_scatter
from roodmus.analysis.plot_alignment import true_pose_distribution_plot, picked_pose_distribution_plot
from roodmus.analysis.plot_picking import plot_recall, plot_precision

In [None]:
# covid spike trimer: input/output locations and analysis settings
project_dir = "/tudelft/mjoosten1/staff-umbrella/ajlab/MJ/projects/Roodmus/data/DE-Shaw_covid_spike_protein/20231116_DESRES-Trajectory_sarscov2-11021571-all-glueCA"
# NOTE: config_dir is hard-coded and does not live under project_dir (its parent
# folder name has no "20231116_" prefix). The project-relative alternative
# would be os.path.join(project_dir, "Micrographs").
config_dir = "/tudelft/mjoosten1/staff-umbrella/ajlab/MJ/projects/Roodmus/data/DE-Shaw_covid_spike_protein/DESRES-Trajectory_sarscov2-11021571-all-glueCA/Micrographs"
figures_dir = os.path.join(project_dir, "figures")

# One row per Refine3D job: (job directory, particle STAR file, job label).
refine3d_jobs = (
    ("job008", "run_it011_data.star", "Refine3D_8000"),
    ("job039", "run_it014_data.star", "Refine3D_236079"),
)

# Maps the full path of every metadata file to its job label.
jobtypes = {
    os.path.join(project_dir, "Refine3D", job_dir, star_name): label
    for job_dir, star_name, label in refine3d_jobs
}
# Metadata files in the order they are loaded (dicts keep insertion order).
meta_files = list(jobtypes)

# approximate particle diameter in Angstroms
particle_diameter = 100
# micrograph shape in pixels; only required when the metadata file is a .star file
ugraph_shape = (4000, 4000)
# print progress statements
verbose = True
# keep going when .mrc files are missing
ignore_missing_files = True
# show tqdm progress bars
enable_tqdm = True

# Create the analysis object from the first metadata file ...
analysis = load_data(
    meta_files[0],
    config_dir,
    particle_diameter,
    ugraph_shape=ugraph_shape,
    verbose=verbose,
    enable_tqdm=enable_tqdm,
    ignore_missing_files=ignore_missing_files,
)
# ... then add every remaining metadata file to that same object.
for meta_file in meta_files[1:]:
    analysis.add_data(meta_file, config_dir, verbose=verbose)
    
# Per-particle tables for the picked (estimated) and ground-truth particles.
df_picked = pd.DataFrame(analysis.results_picking)
df_truth = pd.DataFrame(analysis.results_truth)

# The 1-to-1 matching route is used here instead of
# analysis.compute_precision(df_picked, df_truth, verbose=verbose); it also
# returns closest_truth_index, which the derived columns below rely on.
p_match, _, p_unmatched, t_unmatched, closest_truth_index = analysis._match_particles(
    meta_files,
    df_picked,
    df_truth,
    verbose=False,
    enable_tqdm=True,
)
df_precision = analysis.compute_1to1_match_precision(
    p_match,
    p_unmatched,
    t_unmatched,
    df_truth,
    verbose=False,
)
print(f"mean precision: {df_precision['precision'].mean()}")
print(f"mean recall: {df_precision['recall'].mean()}")

# Label every precision row with the job label of its metadata file.
df_precision["jobtype"] = df_precision["metadata_filename"].map(jobtypes)

# Integer after the last underscore of the pdb filename, e.g. "..._000012.pdb" -> 12.
# os.path.splitext removes the extension as a suffix; str.strip(".pdb") would
# instead remove any of the characters ".", "p", "d", "b" from both ends.
df_truth["pdb_index"] = df_truth["pdb_filename"].apply(
    lambda name: int(os.path.splitext(name)[0].split("_")[-1])
)

# closest_truth_index holds a df_truth index label per picked particle, or NaN
# when the particle has no match; a particle with a match is a true positive.
df_picked["closest_truth_index"] = closest_truth_index
df_picked["TP"] = df_picked["closest_truth_index"].notna()
df_picked["closest_pdb_index"] = df_picked["closest_truth_index"].apply(
    lambda idx: df_truth.loc[idx, "pdb_index"] if pd.notna(idx) else np.nan
)
df_picked.head(n=10)


## panel A
plotting the CTF estimation for the covid spike trimer

In [None]:
# Reload the plotting module so that edits to roodmus are picked up without
# restarting the kernel.
from importlib import reload
import roodmus.analysis.plot_ctf
reload(roodmus.analysis.plot_ctf)
from roodmus.analysis.plot_ctf import plot_defocus_scatter


def _defocus_tick_label(value, _pos=None):
    """Return the tick label for *value*: the negated value divided by 1000, one decimal.

    NOTE(review): the intent is an Angstrom -> um conversion, but 1 um is
    10,000 Angstrom; confirm the units of the plotted values against
    plot_defocus_scatter before relying on the "um" reading.
    """
    return f"{-value / 1000:.1f}"


fig, (ax_l, ax_r) = plot_defocus_scatter(
    df_picked=df_picked,
    metadata_filename=meta_files[0],
    df_truth=df_truth,
    palette="BuGn",
)
fig.set_size_inches(4.5*2, 4.5)

# Rescale the tick labels with a formatter instead of set_*ticklabels(): fixed
# label lists are only valid for fixed tick positions, and the ticks can still
# move when the aspect ratio and layout are changed below.
for axis in (ax_l.xaxis, ax_r.xaxis, ax_l.yaxis):
    axis.set_major_formatter(_defocus_tick_label)

ax_r.set_aspect("equal")
ax_l.set_aspect("equal")
ax_l.set_xlabel("Ground-truth")
ax_l.set_ylabel("Estimated")
ax_r.set_xlabel("Ground-Truth")
fig.tight_layout()

fig.savefig(os.path.join(figures_dir, "defocus_scatter.pdf"), bbox_inches="tight")
print(f"saved figure to: {os.path.join(figures_dir, 'defocus_scatter.pdf')}")

# Correlation between estimated and true defocus, per matched particle.
# df_picked and df_truth have unrelated row orders, so each picked particle is
# paired with its matched ground-truth row through closest_truth_index
# (df_truth index labels, NaN when unmatched) rather than by row number.
# All picked particles in df_picked are used, i.e. every loaded metadata file.
matched = df_picked.dropna(subset=["closest_truth_index", "defocusU"])
truth_defocus = pd.Series(
    df_truth.loc[matched["closest_truth_index"].astype(int), "defocus"].to_numpy(),
    index=matched.index,
)
print(f"correlation coefficient: {matched['defocusU'].corr(truth_defocus)}")

In [None]:
# Correlation between the estimated and ground-truth defocus, per micrograph.
# Both tables are reduced to one mean defocus value per micrograph first.
df_picked_defocus = df_picked.groupby("ugraph_filename").agg({"defocusU": "mean"}).reset_index()
print(len(df_picked_defocus))
df_truth_defocus = df_truth.groupby("ugraph_filename").agg({"defocus": "mean"}).reset_index()
print(len(df_truth_defocus))

# Pair the two tables on the micrograph name rather than on row number: if a
# micrograph is missing from one table (the two lengths printed above differ),
# positional pairing would compare different micrographs.
df_defocus = df_picked_defocus.merge(df_truth_defocus, on="ugraph_filename", how="inner")
print(f"correlation coefficient: {df_defocus['defocusU'].corr(df_defocus['defocus'])}")

## panel B
plotting distribution of estimated poses for covid spike trimer

In [None]:
# Reload roodmus.analysis.plot_alignment so local edits to it take effect,
# then re-import the plotting function from the reloaded module.
from importlib import reload
import roodmus.analysis.plot_alignment
reload(roodmus.analysis.plot_alignment)
from roodmus.analysis.plot_alignment import picked_pose_distribution_plot

# Estimated poses are plotted for the first metadata file only. The returned
# vmin/vmax are reused by the ground-truth pose plot (panel C).
picked_meta_file = meta_files[0]
grid, vmin, vmax = picked_pose_distribution_plot(
    df_picked=df_picked,
    metadata_filename=picked_meta_file,
)

# The output name carries the basename of the metadata file that was plotted.
outfilename = os.path.join(
    figures_dir,
    f"picked_pose_distribution_{os.path.basename(picked_meta_file)}.pdf",
)
grid.savefig(outfilename, bbox_inches="tight")
print(f"saved figure to: {outfilename}")

## panel C
plotting ground-truth poses for covid spike trimer

In [None]:
# Reload roodmus.analysis.plot_alignment so local edits to it take effect,
# then re-import the plotting function from the reloaded module.
from importlib import reload
import roodmus.analysis.plot_alignment
reload(roodmus.analysis.plot_alignment)
from roodmus.analysis.plot_alignment import true_pose_distribution_plot

# vmin and vmax must already exist in the session; they are produced by the
# picked-pose plot (panel B) and passed through unchanged here.
pose_plot_kwargs = {"df_truth": df_truth, "vmin": vmin, "vmax": vmax}
grid, _, _ = true_pose_distribution_plot(**pose_plot_kwargs)

outfilename = os.path.join(figures_dir, "true_pose_distribution.pdf")
grid.savefig(outfilename, bbox_inches="tight")
print(f"saved figure to: {outfilename}")


## panel E
plotting particle picking for covid spike monomer

In [None]:
# covid spike monomer (6xm5 steered MD): input/output locations and analysis settings
project_dir = "/home/mjoosten1/projects/roodmus/data/6xm5_steered_Roodmus_2"
config_dir = os.path.join(project_dir, "mrc")
figures_dir = os.path.join(project_dir, "figures")

# One row per cryoSPARC particle file, in the order the files are loaded:
# (file name inside <project_dir>/cryoSPARC, job label).
cryosparc_jobs = (
    ("J508_picked_particles.cs", "blob_picker"),
    ("J509_extracted_particles.cs", "filtering"),
    ("J513_passthrough_particles_selected.cs", "2D_class_selection"),
    ("J515_topaz_picked_particles.cs", "topaz_picker"),
    ("J518_050_particles.cs", "2D_classification_2"),
    ("J519_passthrough_particles_class_0.cs", "ab_initio_class_0"),
    ("J519_passthrough_particles_class_1.cs", "ab_initio_class_1"),
    ("J519_passthrough_particles_class_2.cs", "ab_initio_class_2"),
    ("J519_passthrough_particles_class_3.cs", "ab_initio_class_3"),
    ("J523_passthrough_particles.cs", "homogeneous_refinement"),
)

# Maps the full path of every metadata file to its job label.
jobtypes = {
    os.path.join(project_dir, "cryoSPARC", cs_name): label
    for cs_name, label in cryosparc_jobs
}
# Metadata files in the order they are loaded (dicts keep insertion order).
meta_files = list(jobtypes)

# approximate particle diameter in Angstroms
particle_diameter = 100
# micrograph shape in pixels; only required when the metadata file is a .star file
ugraph_shape = (4000, 4000)
# print progress statements
verbose = True
# keep going when .mrc files are missing
ignore_missing_files = True
# show tqdm progress bars
enable_tqdm = True

# Create the analysis object from the first metadata file ...
analysis = load_data(
    meta_files[0],
    config_dir,
    particle_diameter,
    ugraph_shape=ugraph_shape,
    verbose=verbose,
    enable_tqdm=enable_tqdm,
    ignore_missing_files=ignore_missing_files,
)
# ... then add every remaining metadata file to that same object.
for meta_file in meta_files[1:]:
    analysis.add_data(meta_file, config_dir, verbose=verbose)
# Per-particle tables for the picked (estimated) and ground-truth particles.
df_picked = pd.DataFrame(analysis.results_picking)
df_truth = pd.DataFrame(analysis.results_truth)

# compute_precision returns the precision table first; df_picked is rebound to
# the second returned value.
df_precision, df_picked = analysis.compute_precision(df_picked, df_truth, verbose=verbose)

# Disabled alternative for this data set: 1-to-1 matching through
# analysis._match_particles(meta_files, df_picked, df_truth, ...) followed by
# analysis.compute_1to1_match_precision(p_match, p_unmatched, t_unmatched, df_truth, ...).
# That route additionally yields closest_truth_index, from which the
# "pdb_index", "TP" and "closest_pdb_index" columns can be derived.

for metric in ("precision", "recall"):
    print(f"mean {metric}: {df_precision[metric].mean()}")
df_picked.head(n=10)


In [None]:
### box plots of picking precision and recall per job
# Reload the plotting module so that edits to roodmus are picked up without
# restarting the kernel.
from importlib import reload
import roodmus.analysis.plot_picking
reload(roodmus.analysis.plot_picking)
from roodmus.analysis.plot_picking import plot_recall, plot_precision

# Order passed to the plotting functions: one entry per metadata file. An entry
# that is not a plain string contributes its first element.
order = [entry if isinstance(entry, str) else entry[0] for entry in meta_files]

# --- precision ---
fig, ax = plot_precision(df_precision, jobtypes, order)
# Resize the tick labels through tick_params: re-applying get_*ticklabels()
# via set_*ticklabels() is only valid when the tick positions are fixed, which
# is not guaranteed for the numeric y axis.
ax.tick_params(axis="x", labelsize=18)
ax.tick_params(axis="y", labelsize=16)
ax.set_title("")
ax.set_ylabel("Precision", fontsize=21)
fig.set_size_inches(7, 7)

fig.savefig(os.path.join(figures_dir, "precision.pdf"), bbox_inches="tight")
print(f"saved figure to: {os.path.join(figures_dir, 'precision.pdf')}")

# --- recall ---
fig, ax = plot_recall(df_precision, jobtypes, order)
ax.tick_params(axis="x", labelsize=16)
# The y ticks are pinned explicitly first, so fixed labels are safe here.
ax.set_yticks([0.2, 0.4, 0.6, 0.8, 1.0])
ax.set_yticklabels(
    [f"{tick:.1f}" for tick in ax.get_yticks()],
    fontsize=16,
)
ax.set_title("")
ax.set_ylabel("Recall", fontsize=21)
fig.set_size_inches(7, 7)

fig.savefig(os.path.join(figures_dir, "recall.pdf"), bbox_inches="tight")
print(f"saved figure to: {os.path.join(figures_dir, 'recall.pdf')}")
