# tomoDRGN interactive visualization and filtering

### Filtering functionality
This jupyter notebook provides multiple methods for filtering particles based on the latent space, including:
* clustering of the latent space (k-means or Gaussian mixture model)
* outlier detection (Z-score)
* interactive selection with a lasso tool

For each method, the selected particles are tracked in the variable `ind_selected`. Once the selection has been finalized, the selected particles are saved as an `ind_keep.pkl` file at the end of this notebook. The `ind_keep.pkl` file can be used with `tomodrgn train_vae` to train a new model on a subset of the images, or with `tomodrgn filter_star` to filter the original input `.star` file for further processing with external tools. In both cases it is passed via the optional `--ind ind_keep.pkl` argument.

Note that indices in `ind_selected` and `ind_keep.pkl` are 0-indexed, are per-particle (not per tilt image), and are numbered sequentially by unique values present in the `_rlnGroupName` column of the original input `.star` file.
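
The indexing convention above can be illustrated with a small sketch. The table below is a made-up stand-in for an image-series star file (the `_rlnImageName` values and group names are hypothetical); only the use of `_rlnGroupName` to identify particles is taken from this notebook.

```python
import pandas as pd

# Toy image-series table: one row per tilt image; `_rlnGroupName` identifies the parent particle.
df_demo = pd.DataFrame({
    '_rlnGroupName': ['tomoA_0', 'tomoA_0', 'tomoA_1', 'tomoA_1', 'tomoB_0', 'tomoB_0'],
    '_rlnImageName': [f'{i + 1}@stack.mrcs' for i in range(6)],  # hypothetical values
})

# Particles are numbered from 0 in order of first appearance of each unique group name.
ptcl_names = df_demo['_rlnGroupName'].unique()
ptcl_index = {name: i for i, name in enumerate(ptcl_names)}

# Row positions of the tilt images belonging to each particle.
imgs_per_ptcl = {ptcl_index[name]: rows.index.tolist()
                 for name, rows in df_demo.groupby('_rlnGroupName', sort=False)}

print(ptcl_index)
print(imgs_per_ptcl)
```

A selection such as `[0, 2]` therefore refers to the first and third particles, each of which may own several tilt-image rows.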

### Visualization functionality
Several static plots are generated in this notebook, including summaries of the input data (pose distribution) and of the training data (PCA, UMAP, and loss curve).
Additional interactive plots are also created, with the primary intention of aiding users in uncovering correlations and interesting particle subsets for further analyses and structural hypothesis generation:
* interactive 2D scatter plot per-particle
    * axes and colormaps selectable from all columns from input star file + all training outputs (latent z, UMAP, and filtering tools described above) + optional tomogram XYZ positions from a separate star file + user-defined additions to `df_merged`
* interactive 3D quiver plot per-particle in the source tomogram spatial context
    * axes defined by particle XYZ coordinates in each source tomogram
    * optional overlay of tomogram voxel data in voxel or z-slice view
    * particle colormaps and sub-selection tools from all columns from input star file + all training outputs (latent z, UMAP, and filtering tools described above) + optional tomogram XYZ positions from a separate star file + user-defined additions to `df_merged`

### Requirements
For full functionality, the following files are required. Additionally, locations marked by `USER INPUT` (below) require the user to specify certain properties about these files:
* tomoDRGN train_vae training input (`particles_imageseries.star`)
* tomoDRGN train_vae training outputs (`z.pkl`, `config.pkl`, `run.log`)
* tomoDRGN analyze outputs (`umap.pkl`, `labels.pkl`, `centers.txt`)
* reconstructed tomograms, preferably denoised or deconvolved (`*.mrc`)
* a star file from Warp or M's "export subtomograms as volumes" tomography task dialog (`particles_volumeseries.star`)

In [None]:
import pandas as pd
import numpy as np
import pickle
import os
import pprint

from tomodrgn import analysis
from tomodrgn import utils
from tomodrgn import dataset
from tomodrgn import ctf
from tomodrgn import starfile
                
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.offline as py
from ipywidgets import VBox
py.init_notebook_mode()

In [None]:
# Enable interactive widgets
!jupyter nbextension enable --py widgetsnbextension

### Load training results

In [None]:
# Specify the workdir and the epoch number (0-based index) to analyze
WORKDIR = '..' 
EPOCH = 49 # USER INPUT

In [None]:
print(os.path.abspath(WORKDIR))

In [None]:
# Load z
with open(f'{WORKDIR}/z.{EPOCH}.pkl','rb') as f:
    z = pickle.load(f)
    z_logvar = pickle.load(f)
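
The two consecutive `pickle.load` calls above imply that the file holds two separately pickled objects. A minimal, self-contained sketch of that pattern, using toy arrays and a temporary file (the names `z_demo` and `z_logvar_demo` are illustrative only):

```python
import os
import pickle
import tempfile

import numpy as np

# Toy stand-ins for latent means and log-variances (8 particles, 4-D latent).
z_demo = np.zeros((8, 4), dtype=np.float32)
z_logvar_demo = np.ones((8, 4), dtype=np.float32)

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, 'z.demo.pkl')

    # Two consecutive dumps write two independent pickles into the same file.
    with open(path, 'wb') as f:
        pickle.dump(z_demo, f)
        pickle.dump(z_logvar_demo, f)

    # Each load consumes exactly one pickle, so the read order must match the write order.
    with open(path, 'rb') as f:
        z_loaded = pickle.load(f)
        z_logvar_loaded = pickle.load(f)

print(z_loaded.shape, z_logvar_loaded.shape)
```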

In [None]:
# Load UMAP
umap = utils.load_pkl(f'{WORKDIR}/analyze.{EPOCH}/umap.pkl')
# or run UMAP
# umap = analysis.run_umap(z)

In [None]:
# Load kmeans
K = 100  # or user defined if re-running kmeans
kmeans_labels = utils.load_pkl(f'{WORKDIR}/analyze.{EPOCH}/kmeans{K}/labels.pkl')
kmeans_centers = np.loadtxt(f'{WORKDIR}/analyze.{EPOCH}/kmeans{K}/centers.txt')

# Or re-run kmeans with the desired number of classes
#kmeans_labels, kmeans_centers = analysis.cluster_kmeans(z, 20)

# Get index for on-data cluster center
kmeans_centers, centers_ind = analysis.get_nearest_point(z, kmeans_centers)
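
Cluster centers generally do not coincide with any real particle, which is why they are snapped to the nearest on-data point. The sketch below shows one way to do this with NumPy; `nearest_on_data_points` is a hypothetical helper written for illustration and is not necessarily how `analysis.get_nearest_point` is implemented.

```python
import numpy as np


def nearest_on_data_points(data, centers):
    """Hypothetical helper: for each center, return the closest row of `data` and its index."""
    # Pairwise Euclidean distances, shape (n_centers, n_points)
    dists = np.linalg.norm(centers[:, None, :] - data[None, :, :], axis=-1)
    ind = dists.argmin(axis=1)
    return data[ind], ind


z_demo = np.array([[0.0, 0.0], [1.0, 1.0], [4.0, 4.0], [5.0, 5.0]])
centers_demo = np.array([[0.4, 0.4], [4.6, 4.6]])

snapped, snapped_ind = nearest_on_data_points(z_demo, centers_demo)
print(snapped_ind)  # [0 3]
```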

### Load input dataset

In [None]:
# Load configuration file
config = utils.load_pkl(f'{WORKDIR}/config.pkl')
pprint.pprint(config)

In [None]:
# Load particles starfile
ptcls_star = starfile.TiltSeriesStarfile.load(config['dataset_args']['particles'])

In [None]:
# Define useful variables
ptcls_unique_list = ptcls_star.df['_rlnGroupName'].unique().astype(str)

df_grouped = ptcls_star.df.groupby('_rlnGroupName', sort=False)
ind_imgs = np.array([df_grouped.get_group(ptcl).index.to_numpy() for ptcl in df_grouped.groups], dtype=object)
ind_ptcls = np.arange(len(ptcls_unique_list))

n_ptcls = len(ptcls_unique_list)
n_imgs = len(ptcls_star.df)

# Create per-particle dataframe sampling only first image of each particle
ind_imgs_first = np.array([i[0] for i in ind_imgs])
df_ptcls = ptcls_star.df.iloc[ind_imgs_first, :].reset_index(drop=True)
if '_rlnMicrographName' in df_ptcls.columns: df_ptcls.drop('_rlnMicrographName', axis=1, inplace=True)

In [None]:
# Load ptcl index filter and apply to dataframes
ind_ptcls_pkl = config['dataset_args']['ind']
if ind_ptcls_pkl is not None:
    ind_ptcls = np.array(utils.load_pkl(ind_ptcls_pkl))
    print(f'Filtering particles from {n_ptcls} to {len(ind_ptcls)}')

    ind_imgs = ind_imgs[ind_ptcls]
    ptcls_unique_list = ptcls_unique_list[ind_ptcls]
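    # concatenate per-particle image indices into one flat integer array (also handles unequal tilt counts per particle)
    ptcls_star.df = ptcls_star.df.iloc[np.concatenate([np.asarray(i, dtype=int) for i in ind_imgs])]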
    df_ptcls = df_ptcls.iloc[ind_ptcls]
    
    n_ptcls = len(ind_ptcls)
    n_imgs = len(ptcls_star.df)  # number of tilt images after filtering (len(ind_imgs) would count particles)

else:
    print('No indices found in config.pkl; not filtering star file')

In [None]:
# Load poses (from pre-filtered dataframe)
rots_columns = ['_rlnAngleRot', '_rlnAngleTilt', '_rlnAnglePsi']
euler = ptcls_star.df[rots_columns].to_numpy(dtype=np.float32)

trans_columns = ['_rlnOriginX', '_rlnOriginY']
if np.all([trans_column in ptcls_star.headers for trans_column in trans_columns]):
    trans = ptcls_star.df[trans_columns].to_numpy(dtype=np.float32)
else: trans = np.zeros((n_imgs, 2))

In [None]:
# Load CTF (from pre-filtered dataframe)
ctf_columns = ['_rlnDetectorPixelSize','_rlnDefocusU', '_rlnDefocusV', '_rlnDefocusAngle', '_rlnVoltage', '_rlnSphericalAberration',
         '_rlnAmplitudeContrast', '_rlnPhaseShift']
if np.all([ctf_column in ptcls_star.headers for ctf_column in ctf_columns]):
    box_size = config['lattice_args']['D'] - 1
    ctf_params = ptcls_star.df[ctf_columns].to_numpy(dtype=np.float32)
    box_size = np.ones((ctf_params.shape[0],1)) * box_size
    ctf_params = np.concatenate((box_size, ctf_params), 1)
else:
    ctf_params = np.zeros((n_imgs, 9))

ctf.print_ctf_params(ctf_params[0])

In [None]:
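# Optionally load a RELION3.0 volumeseries star file from Warp/M (to get particle positions within tomograms)
# Skip this cell if no such star file is available; the merge cell below only uses `volseries_df` if it has been defined
# The star file must reference the same set of particles referenced by the star file used for tomodrgn train_vae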

### USER INPUT: metadata to find particle positions, merge with training df, and rescale to tomogram pixel size
path_to_volseries_star = ''        # absolute path to volume series star file
star_from_M = False                 # True if star file from M; False if star file from Warp subtomogram volumeseries
tomo_max_xyz_nm = (680, 680, 510)  # dimensions in nm of reconstructed tomograms for later interactive visualization
tomo_pixelsize = 10                # pixel size of reconstructed tomogram in A/px
starfile_pixelsize = 6             # pixel size in volumeseries star file, CoordinateX,Y columns, in A/px

# load and filter star file
volseries_star = starfile.GenericStarfile(path_to_volseries_star)
volseries_df = volseries_star.blocks['data_'].copy()
if ind_ptcls is not None:
    volseries_df = volseries_df[volseries_df.index.isin(ind_ptcls)]

# Assign unique, sequential ID numbers to each tomogram for later visualization
if star_from_M: tomo_id_col = '_wrpSourceHash'
else: tomo_id_col = '_rlnMicrographName'
ind_tomo = {tomo_name : tomo_index for tomo_index, tomo_name in enumerate(volseries_df[tomo_id_col].unique())}
volseries_df['ind_tomo'] = [ind_tomo[row] for row in volseries_df[tomo_id_col]]

# rescale df xyz coordinates to match reconstructed tomograms pixel size
volseries_df = analysis.rescale_df_coordinates(volseries_df,
                                               tomo_max_xyz_nm=tomo_max_xyz_nm, 
                                               tomo_pixelsize=tomo_pixelsize, 
                                               starfile_pixelsize=starfile_pixelsize)

# filter out irrelevant columns
cols_to_keep = ['_rlnCoordinateX', '_rlnCoordinateY', '_rlnCoordinateZ', '_rlnAngleRot', '_rlnAngleTilt', '_rlnAnglePsi', '_rlnMicrographName', '_wrpSourceHash', 'ind_tomo']
for col in volseries_df.columns:
    if col not in cols_to_keep:
        volseries_df.drop(col, axis=1, inplace=True)

# view the result
volseries_df
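
The unit conversion behind the rescaling step can be worked through by hand. The sketch below assumes that star file coordinates are stored in pixels at `starfile_pixelsize`; the coordinates are made up, and the actual `analysis.rescale_df_coordinates` function may handle additional cases not shown here.

```python
import numpy as np

# Values mirror the USER INPUT defaults above; the coordinates themselves are made up.
tomo_pixelsize = 10.0      # A/px of the reconstructed tomogram
starfile_pixelsize = 6.0   # A/px of the coordinates in the star file
tomo_max_xyz_nm = np.array([680.0, 680.0, 510.0])

coords_star_px = np.array([[500.0, 250.0, 100.0],
                           [1000.0, 1000.0, 800.0]])

# px (star file) -> Angstrom -> px (tomogram)
coords_tomo_px = coords_star_px * starfile_pixelsize / tomo_pixelsize

# Tomogram extent in tomogram pixels: nm -> Angstrom -> px
tomo_max_xyz_px = tomo_max_xyz_nm * 10.0 / tomo_pixelsize

print(coords_tomo_px)
print(tomo_max_xyz_px)
```

With these numbers, a coordinate of 500 px at 6 A/px lands at 300 px in a 10 A/px tomogram, and the tomogram spans 680 x 680 x 510 px. A quick sanity check is that every rescaled coordinate falls inside that extent.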

### Define helper functions

In [None]:
def list_paths_by_extension(basedir, extension):
    file_list = [os.path.join(basedir, file) for file in os.listdir(basedir) if file.endswith(extension)]
    return sorted(file_list)

In [None]:
def invert_selection(all_labels, ind_selected):
    return np.array(sorted(set(np.arange(len(all_labels))) - set(ind_selected)))

In [None]:
def combine_selection(ind_sel1, ind_sel2, kind='union'):
    assert kind in ('union','intersection')
    return analysis.combine_ind(len(z), ind_sel1, ind_sel2, kind)

In [None]:
def select_clusters(labels, cluster_ids):
    return analysis.get_ind_for_cluster(labels, cluster_ids)
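
The helpers above amount to set operations on index arrays. A minimal NumPy-only sketch of the same ideas on toy selections follows; it does not call `analysis.combine_ind`, whose internals are not shown in this notebook.

```python
import numpy as np

n_demo = 10                        # total number of particles in the toy example
sel_a = np.array([0, 2, 4, 6])     # first selection
sel_b = np.array([4, 6, 8])        # second selection

union = np.union1d(sel_a, sel_b)                     # particles in either selection
intersection = np.intersect1d(sel_a, sel_b)          # particles in both selections
inverted = np.setdiff1d(np.arange(n_demo), union)    # particles in neither selection

print(union, intersection, inverted)
```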

# Plot dataset / model properties

### View pose distribution

In [None]:
# rotations
analysis.plot_euler(euler[:,0],euler[:,1], euler[:,2])

In [None]:
# translations
# set near-zero translations to 0.0 to allow sns.jointplot to work
trans[np.isclose(trans, 0.0, atol=1e-4)] = 0.0
sns.jointplot(x=trans[:,0],
              y=trans[:,1],
              kind='hex').set_axis_labels('tx (px)','ty (px)')

### Learning curve

In [None]:
loss = analysis.parse_loss(f'{WORKDIR}/run.log')
plt.plot(loss)
plt.xlabel('epoch')
plt.ylabel('loss')

### PCA

In [None]:
pc, pca = analysis.run_pca(z)

In [None]:
g = sns.jointplot(x=pc[:,0], y=pc[:,1], alpha=.1, s=1)
g.set_axis_labels('PC1', 'PC2')

In [None]:
g = sns.jointplot(x=pc[:,0], y=pc[:,1], kind='hex')
g.set_axis_labels('PC1', 'PC2')

In [None]:
plt.bar(np.arange(z.shape[1])+1,pca.explained_variance_ratio_)
plt.xticks(np.arange(z.shape[1])+1)
plt.xlabel('PC')
plt.ylabel('explained variance')
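
For readers unfamiliar with the quantities plotted above, the following sketch computes principal components and explained variance ratios directly from an SVD of mean-centered data. It uses random toy data and is a generic illustration of PCA, not a reproduction of `analysis.run_pca`.

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy latent: 200 particles, 3 dimensions, with most variance along the first axis.
z_demo = rng.normal(size=(200, 3)) * np.array([5.0, 1.0, 0.2])

# PCA operates on mean-centered data.
z_centered = z_demo - z_demo.mean(axis=0)

# Rows of `vt` are the principal axes; singular values are sorted in descending order.
_, s, vt = np.linalg.svd(z_centered, full_matrices=False)

pc_demo = z_centered @ vt.T                       # coordinates of each particle along each PC
explained_variance_ratio = s**2 / np.sum(s**2)    # fraction of total variance per PC

print(explained_variance_ratio)
```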

### UMAP

In [None]:
g = sns.jointplot(x=umap[:,0], y=umap[:,1], alpha=.1, s=1)
g.set_axis_labels('UMAP1', 'UMAP2')

In [None]:
g = sns.jointplot(x=umap[:,0], y=umap[:,1], kind='hex')
g.set_axis_labels('UMAP1', 'UMAP2')

# Filter by clustering

Select particles based on k-means cluster labels or GMM cluster labels

### View K-means clusters

In [None]:
# Optionally, re-run kmeans with the desired number of classes
#K = 20  # USER INPUT: optionally change to desired cluster number
#kmeans_labels, kmeans_centers = analysis.cluster_kmeans(z, K)

In [None]:
K = len(set(kmeans_labels))
c = pca.transform(kmeans_centers) # transform to view with PCs
analysis.plot_by_cluster(pc[:,0], pc[:,1], K, 
                         kmeans_labels, 
                         centers=c,
                         annotate=True)
plt.xlabel('PC1')
plt.ylabel('PC2')

In [None]:
fig, ax = analysis.plot_by_cluster_subplot(pc[:,0], pc[:,1], K, 
                            kmeans_labels)

In [None]:
analysis.plot_by_cluster(umap[:,0], umap[:,1], K, 
                         kmeans_labels, 
                         centers_ind=centers_ind,
                         annotate=True)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')

In [None]:
fig, ax = analysis.plot_by_cluster_subplot(umap[:,0], umap[:,1], K, 
                            kmeans_labels)

**Select particles based on k-means clustering**

In [None]:
cluster_ids = [0,2] # USER INPUT: integer I.D. of clusters to select, 0 and 2 in this example
ind_selected = select_clusters(kmeans_labels, cluster_ids)
ind_selected_not = invert_selection(kmeans_labels, ind_selected)
print('Selected indices:')
print(ind_selected)
print('Number of selected points:')
print(len(ind_selected))
print('Number of unselected points:')
print(len(ind_selected_not))

In [None]:
# View PCA
plt.scatter(pc[:,0], pc[:,1], alpha=.1, s=1)
plt.scatter(pc[ind_selected,0], pc[ind_selected,1], alpha=.1, s=1)
plt.xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))
plt.ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))

In [None]:
# View umap
plt.scatter(umap[:,0], umap[:,1], alpha=.1, s=1)
plt.scatter(umap[ind_selected,0], umap[ind_selected,1], alpha=.1, s=1)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')

### GMM-clustering

In [None]:
G = 3 # USER INPUT: or change to desired cluster number
random_state = np.random.randint(100000) # sample random integer
print(f'Random state: {random_state}')
gmm_labels, gmm_centers = analysis.cluster_gmm(z, G, random_state=random_state)
gmm_centers, gmm_centers_ind = analysis.get_nearest_point(z, gmm_centers)

In [None]:
analysis.plot_by_cluster(pc[:,0], pc[:,1], G, 
                         gmm_labels, 
                         centers_ind=gmm_centers_ind,
                         annotate=True)
plt.xlabel('PC1')
plt.ylabel('PC2')

In [None]:
fig, ax = analysis.plot_by_cluster_subplot(pc[:,0], pc[:,1], G, gmm_labels)

In [None]:
analysis.plot_by_cluster(umap[:,0], umap[:,1], G, 
                         gmm_labels, 
                         centers_ind=gmm_centers_ind,
                         annotate=True)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')

In [None]:
fig, ax = analysis.plot_by_cluster_subplot(umap[:,0], umap[:,1], G, gmm_labels)

**Select particles based on GMM clustering**

In [None]:
cluster_ids = [0,2] # USER INPUT: integer I.D. of clusters to select, 0 and 2 in this example
ind_selected = select_clusters(gmm_labels, cluster_ids)
ind_selected_not = invert_selection(gmm_labels, ind_selected)
print('Selected indices:')
print(ind_selected)
print('Number of selected points:')
print(len(ind_selected))
print('Number of unselected points:')
print(len(ind_selected_not))

In [None]:
# View PCA
plt.scatter(pc[:,0], pc[:,1], alpha=.1, s=1)
plt.scatter(pc[ind_selected,0], pc[ind_selected,1], alpha=.1, s=1)
plt.xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))
plt.ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))

In [None]:
# View umap
plt.scatter(umap[:,0], umap[:,1], alpha=.1, s=1)
plt.scatter(umap[ind_selected,0], umap[ind_selected,1], alpha=.1, s=1)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')

# Filter by latent outliers

In [None]:
# Compute magnitude of latent encodings
znorm = np.sum(z**2,axis=1)**.5

In [None]:
# Distribution of ||z||
sns.histplot(znorm, kde=False)
plt.xlabel('||z||')

In [None]:
# By default, identify particles with ||z|| 2 std deviations above mean
zscore = 2
thresh = znorm.mean()+zscore*znorm.std()
print(f'Mean: {znorm.mean()}, Std: {znorm.std()}, Selected threshold: {thresh}')

In [None]:
ind_outliers = np.where(znorm >= thresh)[0]
ind_outliers_not = invert_selection(znorm, ind_outliers)

print('Selected indices:')
print(ind_outliers)
print('Number of selected points:')
print(len(ind_outliers))
print('Number of unselected points:')
print(len(ind_outliers_not))
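
The thresholding above is equivalent to computing a Z-score of `||z||` for each particle and keeping those at or above the cutoff. A self-contained sketch on synthetic data with a few planted outliers (all names and values here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
z_demo = rng.normal(size=(1000, 8))   # toy latent embeddings
z_demo[:5] *= 10.0                    # plant five obvious outliers at indices 0-4

# Magnitude of each latent encoding, then standardize to Z-scores.
znorm_demo = np.linalg.norm(z_demo, axis=1)
zscores = (znorm_demo - znorm_demo.mean()) / znorm_demo.std()

# Particles whose ||z|| is at least 2 standard deviations above the mean.
ind_out = np.where(zscores >= 2)[0]
print(len(ind_out))
```

Raising the cutoff flags fewer particles; lowering it flags more. Because extreme outliers inflate both the mean and the standard deviation, it is worth checking the histogram rather than relying on the default cutoff alone.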

In [None]:
g = sns.histplot(znorm, kde=False)
plt.axvline(x=thresh)
plt.xlabel('||z||')
plt.title('Magnitude of particle latent encodings')

In [None]:
# View PCA
plt.scatter(pc[:,0], pc[:,1], alpha=.1, s=1)
plt.scatter(pc[ind_outliers,0], pc[ind_outliers,1], alpha=.1, s=1)
plt.xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))
plt.ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))
plt.title(f'Particles with ||z|| >= {thresh:.2f}')

In [None]:
# View UMAP
plt.scatter(umap[:,0], umap[:,1], alpha=.1, s=1)
plt.scatter(umap[ind_outliers,0], umap[ind_outliers,1], alpha=.1, s=1)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')
plt.title(f'Particles with ||z|| >= {thresh:.2f}')

In [None]:
# Assign variables for viz/saving cells at the end of the notebook
ind_selected = ind_outliers
ind_selected_not = ind_outliers_not

# Interactive visualization

Interactive visualization of the latent encodings for the trained model. Each point represents one particle of the dataset (not an individual tilt image). The hover text includes the particle's index.

In [None]:
# Load training output data into a pandas dataframe on a per-particle basis
df_train = analysis.load_dataframe(z=z,
                                   pc=pc,
                                   umap=umap,
                                   kmeans_labels=kmeans_labels,
                                   gmm_labels=gmm_labels,
                                   znorm=znorm)
df_train.drop('index', axis=1, inplace=True)
df_train['index_all_tomos_particles'] = ind_ptcls

In [None]:
# Merge dataset and training dataframes (and volseries dataframe, if present)
df_merged = pd.DataFrame(np.column_stack([df_train, df_ptcls]),
                         columns=df_train.columns.append(df_ptcls.columns))

if 'volseries_df' in locals():
    # remove duplicate cols from volseries_df and  dataset (image series) df
    for col in df_merged.columns:
        if col in volseries_df.columns:
            df_merged.drop(col, axis=1, inplace=True)

    df_merged = pd.DataFrame(np.column_stack([df_merged, volseries_df]),
                         columns=df_merged.columns.append(volseries_df.columns))

In [None]:
# Coerce particle dataframe from starfile to appropriate dtypes
df_merged = starfile.guess_dtypes(df_merged)

pd.set_option('display.max_columns', None)
df_merged

### Interactive selection

The next two cells contain helper code to select particles using an interactive lasso tool. 

1. In the first cell, select points with the lasso tool. The table widget is dynamically updated with the most recent selection's indices. 
2. Then once you've finalized your selection, **run the next cell** to save the particle indices for downstream analysis/viz.

(Double click to clear selection)

In [None]:
widget, fig, ind_table = analysis.ipy_plot_interactive(df_merged.select_dtypes(include=np.number), opacity=0.8)
VBox((widget,fig,ind_table))

In [None]:
ind_selected = ind_table.data[0].cells.values[0] # save table values
ind_selected = np.array(ind_selected)
ind_selected_not = invert_selection(range(len(df_train)), ind_selected)

df_merged.loc[:,'ind_selected'] = 0
df_merged.loc[ind_selected, 'ind_selected'] = 1

print('Selected indices:')
print(ind_selected)
print('Number of selected points:')
print(len(ind_selected))
print('Number of unselected points:')
print(len(ind_selected_not))

In [None]:
# View PCA
plt.scatter(pc[:,0], pc[:,1], alpha=.1, s=1)
plt.scatter(pc[ind_selected,0], pc[ind_selected,1], alpha=.1, s=1)
plt.xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))
plt.ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))

In [None]:
# View umap
plt.scatter(umap[:,0], umap[:,1], alpha=.1, s=1)
plt.scatter(umap[ind_selected,0], umap[ind_selected,1], alpha=.1, s=1)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')

# View particles

## View tilt images from selected particles

In [None]:
# lazily load particle images and filter by ind.pkl, if applicable
images = dataset.load_particles(config['dataset_args']['particles'],
                                lazy=True,
                                datadir=config['dataset_args']['datadir'])

In [None]:
# choose 9 particles to view from ind_selected particles
if len(ind_selected) > 9:
    ind_subset9 = np.random.choice(ind_selected, 9, replace=False)
else: 
    ind_subset9 = ind_selected
print(ind_subset9)

In [None]:
# plot the first tilt image of each ind_subset9 particle

ind_subset9_imgs = np.array([i[0] for i in ind_imgs[ind_subset9]])

p = [images[ii].get() for ii in ind_subset9_imgs]
_ = analysis.plot_projections(p, ind_subset9)

plt.figure()
plt.scatter(umap[:,0], umap[:,1], alpha=.1, s=1)
plt.scatter(umap[ind_subset9,0], umap[ind_subset9,1], color='k')
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')

## View particle distributions in tomogram context
Experimental feature to interactively explore particle distributions in their 3D tomogram context. It benefits from the optional superposition of tomogram data in z-slice or voxel view, and from coloring or selecting particles by any numeric property in `df_merged`. It requires ipyvolume to be installed and requires this notebook to be opened in Jupyter Notebook (not JupyterLab).

In [None]:
required_cols_for_tomogram_viz = ['_rlnCoordinateX',
                                  '_rlnCoordinateY',
                                  '_rlnCoordinateZ',
                                  '_rlnAngleRot',
                                  '_rlnAngleTilt',
                                  '_rlnAnglePsi',
                                  'index_all_tomos_particles']
assert np.all([col in df_merged.columns for col in required_cols_for_tomogram_viz])

In [None]:
# Define list of absolute paths to all reconstructed tomograms used in this analysis
# USER INPUT: absolute path to folder containing (preferably deconvolved or denoised) tomograms
path_to_tomograms = ''
tomogram_extension = ''

tomo_list = list_paths_by_extension(path_to_tomograms, tomogram_extension)
tomo_list

In [None]:
# Define dictionary mapping tomogram file name on disk (as in `tomo_list` above, typically $TOMOGRAM.mrc) 
#     to the name in the input starfile under _rlnMicrographName header (typically $TOMOGRAM.tomostar)
# USER INPUT: provide tomogram.mrc : tomogram.tomostar mappings
tomo_star_mappings = {f'{i:05d}_10.00Apx.mrc' : f'{i:05d}.tomostar' for i in range(254, 320)} 
tomo_star_mappings

In [None]:
analysis.interactive_tomo_ptcls(df_merged, tomo_list, tomo_star_mappings)

# Save selection indices

Save the indices of the selected (`ind_selected`) and unselected (`ind_selected_not`) particles as `.pkl` files for downstream processing in tomoDRGN, or with other tools via `tomodrgn filter_star`.

In [None]:
# Set selection as either the kept or bad particles (for file naming purposes)
ind_keep = ind_selected # or ind_selected_not
ind_bad = ind_selected_not # or ind_selected

In [None]:
# View PCA
plt.scatter(pc[:,0], pc[:,1], alpha=.1, s=1)
plt.scatter(pc[ind_keep,0], pc[ind_keep,1], alpha=.1, s=1)
plt.xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))
plt.ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))

In [None]:
# View UMAP
plt.scatter(umap[:,0], umap[:,1], alpha=.1, s=1)
plt.scatter(umap[ind_keep,0], umap[ind_keep,1], alpha=.1, s=1)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')

In [None]:
# reindex particle indices to match (potential) prior `--ind` usage when training model
ind_keep = df_merged.iloc[ind_keep]['index_all_tomos_particles'].to_numpy(dtype=int)
ind_bad = df_merged.iloc[ind_bad]['index_all_tomos_particles'].to_numpy(dtype=int)

print('Kept particle indices:')
print(ind_keep)
print('Number of kept particles:')
print(len(ind_keep))
print('Number of bad particles:')
print(len(ind_bad))
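
The reindexing step matters when the model was itself trained on a subset of particles: selections made in this notebook are positions within that subset, whereas the saved `.pkl` must hold indices into the original star file. A toy sketch of the mapping (the arrays are made up):

```python
import numpy as np

# Suppose the model was trained on 5 of 10 particles via a prior `--ind` filter.
ind_prior = np.array([1, 3, 4, 7, 9])   # original indices of the particles the model saw

# A selection made in this notebook is positional within those 5 particles.
sel_positions = np.array([0, 2, 4])

# Map positions back to indices in the original, unfiltered star file.
sel_original = ind_prior[sel_positions]
print(sel_original)  # [1 4 9]
```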

In [None]:
# Path to save index .pkl for selected particles
SAVE_PATH = f'{WORKDIR}/ind_keep.{len(ind_keep)}_particles.pkl'
utils.save_pkl(ind_keep, SAVE_PATH)
print(f'Wrote {os.path.abspath(SAVE_PATH)}')

In [None]:
# Path to save index .pkl for non-selected particles
SAVE_PATH = f'{WORKDIR}/ind_bad.{len(ind_bad)}_particles.pkl'
utils.save_pkl(ind_bad, SAVE_PATH)  # save the non-selected indices (not ind_keep)
print(f'Wrote {os.path.abspath(SAVE_PATH)}')

# Save selection latent coordinates
Save latent coordinates for the selected (`ind_selected`) particles as `z_keep.N.pkl`, where `N` is the number of selected particles, optionally separated by tomogram. This is useful for generating a volume for each `ind_selected` particle via `tomodrgn eval_vol`.

In [None]:
# Set selection as ind_selected (subset) or all particles
ind_keep = ind_selected # or all particles: df_merged.index

In [None]:
# List all tomograms specified in star files
tomo_ids = (df_merged['_rlnGroupName'].str.split('_').str[0] + '_').unique()

print(f'Unique tomogram identifiers in _rlnGroupName column: {tomo_ids}')
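
The tomogram identifiers above are derived by assuming that each `_rlnGroupName` has the form `{tomogram}_{particle}`. The sketch below uses made-up group names to show the extraction, and also why substring matching on the identifier can over-select when one tomogram name is a suffix of another.

```python
import pandas as pd

# Made-up group names; note that '254' is a suffix of '1254'.
names = pd.Series(['254_0', '254_1', '1254_0'])

# Text before the first underscore, with the underscore re-appended as a delimiter.
tomo_ids_demo = (names.str.split('_').str[0] + '_').unique()
print(tomo_ids_demo)

# Substring matching also catches '1254_0'; prefix matching does not.
n_contains = int(names.str.contains('254_', regex=False).sum())
n_startswith = int(names.str.startswith('254_').sum())
print(n_contains, n_startswith)  # 3 2
```

If tomogram names in a dataset can overlap in this way, prefix matching may be the safer choice when splitting particles by tomogram.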

In [None]:
# save selected particles' latent coordinates
# USER INPUT: True to write one z file per tomogram, False to write one z file for the whole dataset
separate_zfiles_by_tomogram = False  

z_cols = [f'z{i}' for i in range(z.shape[1])]
if separate_zfiles_by_tomogram:
    for tomo in tomo_ids:
        SAVE_PATH = f'{WORKDIR}/z_keep.{len(ind_keep)}.{tomo}.pkl'
        df_sub = df_merged[df_merged['_rlnGroupName'].str.contains(tomo)]
        ind_keep_sub = np.array([i for i in ind_keep if i in df_sub.index])
        z_out = df_sub[z_cols].loc[ind_keep_sub]
        utils.save_pkl(z_out, SAVE_PATH)
        print(f'Wrote {os.path.abspath(SAVE_PATH)}')

else:
    SAVE_PATH = f'{WORKDIR}/z_keep.{len(ind_keep)}.pkl'
    z_out = df_merged[z_cols].iloc[ind_keep]
    utils.save_pkl(z_out, SAVE_PATH)
    print(f'Wrote {os.path.abspath(SAVE_PATH)}')