# CryoDRGN interactive filtering

This Jupyter notebook provides multiple methods for filtering particles based on the latent space, including:
* clustering of the latent space (k-means or Gaussian mixture model)
* outlier detection (Z-score)
* interactive selection with a lasso tool

For each method, the selected particles are tracked in the variable `ind_selected`.

Once the selection has been finalized, the selected particles are saved as an `index.pkl` file at the end of this notebook. The selected indices can be provided to cryoDRGN with the `--ind` argument to train a new model on a subset of the images, or they can be converted to `.star` file format.

For more information, see the tutorial on [Notion page].

In [None]:
import pandas as pd
import numpy as np
import pickle
import subprocess
import os, sys

from cryodrgn import analysis
from cryodrgn import utils
from cryodrgn import dataset
from cryodrgn import ctf
import cryodrgn.config
                
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objs as go
import plotly.offline as py
from ipywidgets import interact, interactive, HBox, VBox
from scipy.spatial.transform import Rotation as RR
py.init_notebook_mode()
from IPython.display import FileLink, FileLinks

In [None]:
# Enable interactive widgets
!jupyter nbextension enable --py widgetsnbextension

### Load results

In [None]:
# Specify the workdir and the epoch number (0-based index) to analyze
# Both values are interpolated into the paths read by the cells below,
# e.g. {WORKDIR}/z.{EPOCH}.pkl and {WORKDIR}/analyze.{EPOCH}/...
WORKDIR = '..' 
EPOCH = 49 # CHANGE ME

In [None]:
print(os.path.abspath(WORKDIR))

In [None]:
# Load z
# The file is read as two consecutive pickle records: the first is bound to
# `z`, the second to `z_logvar` (presumably the per-particle latent encoding
# and its log-variance -- confirm against the code that writes this file).
with open(f'{WORKDIR}/z.{EPOCH}.pkl','rb') as f:
    z = pickle.load(f)
    z_logvar = pickle.load(f)

In [None]:
# Load UMAP
umap = utils.load_pkl(f'{WORKDIR}/analyze.{EPOCH}/umap.pkl')
# or run UMAP
# umap = analysis.run_umap(z)

In [None]:
# Load kmeans
# K is interpolated into both paths below, so a matching
# analyze.{EPOCH}/kmeans{K}/ directory must already exist.
K = 20
kmeans_labels = utils.load_pkl(f'{WORKDIR}/analyze.{EPOCH}/kmeans{K}/labels.pkl')
kmeans_centers = np.loadtxt(f'{WORKDIR}/analyze.{EPOCH}/kmeans{K}/centers.txt')
# Or re-run kmeans with the desired number of classes
#kmeans_labels, kmeans_centers = analysis.cluster_kmeans(z, K)

# Get index for on-data cluster center
# NOTE: this rebinds kmeans_centers to the first value returned by
# get_nearest_point (presumably the points of z closest to each center) and
# keeps the matching indices in centers_ind for the plotting cells.
kmeans_centers, centers_ind = analysis.get_nearest_point(z, kmeans_centers)

### Define helper functions

In [None]:
def invert_selection(ind_selected):
    """Return the particle indices that are NOT in `ind_selected`.

    The full index range is ``0 .. len(z) - 1``, where `z` is the
    notebook-level latent array.

    Args:
        ind_selected: array-like of integer particle indices (may be empty).

    Returns:
        1-D ``np.ndarray`` of the complementary indices, unique and sorted in
        ascending order. The dtype is always that of ``np.arange`` (integer),
        even when the result is empty, so it can be used directly to index
        other arrays.
    """
    # setdiff1d keeps the dtype of its first argument; building the result
    # from a Python set instead would yield a float64 array when it is empty.
    return np.setdiff1d(np.arange(len(z)), ind_selected)

In [None]:
def combine_selection(ind_sel1, ind_sel2, kind='union'):
    """Combine two index selections through `analysis.combine_ind`.

    Args:
        ind_sel1: first selection of particle indices.
        ind_sel2: second selection of particle indices.
        kind: how to combine them; must be 'union' or 'intersection'.

    Returns:
        Whatever `analysis.combine_ind` returns for the given arguments.

    Raises:
        AssertionError: if `kind` is not one of the two accepted values.
    """
    allowed_kinds = ('union', 'intersection')
    assert kind in allowed_kinds
    # Total number of particles comes from the notebook-level latent array.
    n_particles = len(z)
    return analysis.combine_ind(n_particles, ind_sel1, ind_sel2, kind)

In [None]:
def select_clusters(labels, cluster_ids):
    """Look up the particle indices belonging to the given clusters.

    Thin wrapper around `analysis.get_ind_for_cluster`.

    Args:
        labels: per-particle cluster labels (e.g. `kmeans_labels`).
        cluster_ids: the cluster ids to select.

    Returns:
        The value returned by `analysis.get_ind_for_cluster`.
    """
    ind_in_clusters = analysis.get_ind_for_cluster(labels, cluster_ids)
    return ind_in_clusters

In [None]:
def save_selection(path, index):
    """Write a particle index selection to `path` as a pickle file.

    Reads the notebook-level globals `ind_orig` and `N_orig`, which are set in
    the "Load dataset" cells. If `ind_orig` is not None, the indices are first
    passed through `analysis.convert_original_indices` so that they refer to
    the original particle stack.

    Args:
        path: output file path.
        index: selected particle indices, relative to the current dataset.
    """
    if ind_orig is None:
        to_write = index
    else:
        print('Converting to original .mrcs indicies')
        to_write = analysis.convert_original_indices(index, N_orig, ind_orig)
    utils.save_pkl(to_write, path)
    print(f'Wrote {os.path.abspath(path)}')

### Load dataset

In [None]:
# Load configuration file
config = cryodrgn.config.load(f'{WORKDIR}/config.yaml')
print(config)

In [None]:
# Load poses
# If `do_pose_sgd` is set in the config (presumably: poses were optimized
# during training), read the epoch-specific pose file from WORKDIR; otherwise
# read the pose file whose path is recorded in the config.
# Both branches unpack the loaded object as a (rot, trans) pair.
if config['dataset_args']['do_pose_sgd']:
    pose_pkl = f'{WORKDIR}/pose.{EPOCH}.pkl'
    with open(pose_pkl,'rb') as f:
        rot, trans = pickle.load(f)
else:
    pose_pkl = config['dataset_args']['poses']
    rot, trans = utils.load_pkl(pose_pkl)

In [None]:
# Convert rotation matrices to euler angles
euler = RR.from_matrix(rot).as_euler('zyz', degrees=True)

In [None]:
# Load index filter
# The config entry is either None (no filtering was used) or a path that is
# loaded with utils.load_pkl; ind_orig stays None in the former case.
ind_orig = config['dataset_args']['ind']
if ind_orig is not None:
    ind_orig = utils.load_pkl(ind_orig)
    # Only subset the poses when they have more entries than the index; when
    # the lengths already match they are left untouched (presumably because
    # they were saved post-filtering -- confirm).
    if len(rot) > len(ind_orig):
        print(f'Filtering poses from {len(rot)} to {len(ind_orig)}')
        rot = rot[ind_orig]
        trans = trans[ind_orig]
        euler = euler[ind_orig]

In [None]:
# Load input particles
# lazy=True presumably defers reading image data until an image's .get() is
# called (as done in the "View particles" section) -- confirm.
particles = dataset.load_particles(config['dataset_args']['particles'],
                            lazy=True,
                            datadir=config['dataset_args']['datadir'])
# Remember the unfiltered stack size; save_selection() passes N_orig to
# analysis.convert_original_indices when mapping indices back.
N_orig = len(particles)
if ind_orig is not None:
    print(f'Filtering particles from {len(particles)} to {len(ind_orig)}')
    # After this line `particles` is a plain Python list of the kept entries.
    particles = [particles[i] for i in ind_orig]

In [None]:
# Load CTF
# NOTE(review): this assumes config['dataset_args']['ctf'] is a loadable path;
# a run without CTF parameters (entry set to None) is not handled here.
ctf_params = utils.load_pkl(config['dataset_args']['ctf'])
if ind_orig is not None:
    print(f'Filtering ctf parameters from {len(ctf_params)} to {len(ind_orig)}')
    ctf_params = ctf_params[ind_orig]
# Print the first entry's parameters as a quick sanity check.
ctf.print_ctf_params(ctf_params[0])

### View pose distribution

In [None]:
# rotations
analysis.plot_euler(euler[:,0],euler[:,1], euler[:,2])

In [None]:
# translations
sns.jointplot(x=trans[:,0],
              y=trans[:,1],
              kind='hex').set_axis_labels('tx (fraction)','ty (fraction)')

### Learning curve

In [None]:
loss = analysis.parse_loss(f'{WORKDIR}/run.log')
plt.plot(loss)
plt.xlabel('epoch')
plt.ylabel('loss')

### PCA

In [None]:
pc, pca = analysis.run_pca(z)

In [None]:
g = sns.jointplot(x=pc[:,0], y=pc[:,1], alpha=.1, s=1)
g.set_axis_labels('PC1', 'PC2')

In [None]:
g = sns.jointplot(x=pc[:,0], y=pc[:,1], kind='hex')
g.set_axis_labels('PC1', 'PC2')

In [None]:
plt.bar(np.arange(z.shape[1])+1,pca.explained_variance_ratio_)
plt.xticks(np.arange(z.shape[1])+1)
plt.xlabel('PC')
plt.ylabel('explained variance')

### UMAP

In [None]:
g = sns.jointplot(x=umap[:,0], y=umap[:,1], alpha=.1, s=1)
g.set_axis_labels('UMAP1', 'UMAP2')

In [None]:
g = sns.jointplot(x=umap[:,0], y=umap[:,1], kind='hex')
g.set_axis_labels('UMAP1', 'UMAP2')

# Filter by clustering

Select particles based on k-means cluster labels or GMM cluster labels

### View K-means clusters

In [None]:
# Optionally, re-run kmeans with the desired number of classes
#K = 20
#kmeans_labels, kmeans_centers = analysis.cluster_kmeans(z, K)

In [None]:
# Recompute K from the labels in memory so the plots use the number of
# distinct labels actually present (e.g. after re-running k-means above).
K = len(set(kmeans_labels))
c = pca.transform(kmeans_centers) # transform to view with PCs
analysis.plot_by_cluster(pc[:,0], pc[:,1], K, 
                         kmeans_labels, 
                         centers=c,
                         annotate=True)
plt.xlabel('PC1')
plt.ylabel('PC2')

In [None]:
fig, ax = analysis.plot_by_cluster_subplot(pc[:,0], pc[:,1], K, 
                            kmeans_labels)

In [None]:
analysis.plot_by_cluster(umap[:,0], umap[:,1], K, 
                         kmeans_labels, 
                         centers_ind=centers_ind,
                         annotate=True)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')

In [None]:
fig, ax = analysis.plot_by_cluster_subplot(umap[:,0], umap[:,1], K, 
                            kmeans_labels)

**Select particles based on k-means clustering**

In [None]:
cluster_ids = [0,2] # set clusters to select, 0 and 2 in this example
ind_selected = select_clusters(kmeans_labels, cluster_ids)
ind_selected_not = invert_selection(ind_selected)
print('Selected indices:')
print(ind_selected)
print('Number of selected points:')
print(len(ind_selected))
print('Number of unselected points:')
print(len(ind_selected_not))

In [None]:
# View PCA
plt.scatter(pc[:,0], pc[:,1], alpha=.1, s=1)
plt.scatter(pc[ind_selected,0], pc[ind_selected,1], alpha=.1, s=1)
plt.xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))
plt.ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))

In [None]:
# View umap
plt.scatter(umap[:,0], umap[:,1], alpha=.1, s=1)
plt.scatter(umap[ind_selected,0], umap[ind_selected,1], alpha=.1, s=1)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')

### GMM-clustering

In [None]:
G = 3 # or change to desired cluster number
# A fresh seed is drawn and printed on every run; to reproduce a particular
# clustering, replace the next line with the printed value.
random_state = np.random.randint(100000) # sample random integer
print(f'Random state: {random_state}')
gmm_labels, gmm_centers = analysis.cluster_gmm(z, G, random_state=random_state)
# Rebind gmm_centers to the first value returned by get_nearest_point and keep
# the matching indices, which the plotting cells pass as centers_ind.
gmm_centers, gmm_centers_ind = analysis.get_nearest_point(z, gmm_centers)

In [None]:
analysis.plot_by_cluster(pc[:,0], pc[:,1], G, 
                         gmm_labels, 
                         centers_ind=gmm_centers_ind,
                         annotate=True)
plt.xlabel('PC1')
plt.ylabel('PC2')

In [None]:
fig, ax = analysis.plot_by_cluster_subplot(pc[:,0], pc[:,1], G, gmm_labels)

In [None]:
analysis.plot_by_cluster(umap[:,0], umap[:,1], G, 
                         gmm_labels, 
                         centers_ind=gmm_centers_ind,
                         annotate=True)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')

In [None]:
fig, ax = analysis.plot_by_cluster_subplot(umap[:,0], umap[:,1], G, gmm_labels)

**Select particles based on GMM clustering**

In [None]:
cluster_ids = [0,2] # set clusters to select, 0 and 2 in this example
ind_selected = select_clusters(gmm_labels, cluster_ids)
ind_selected_not = invert_selection(ind_selected)
print('Selected indices:')
print(ind_selected)
print('Number of selected points:')
print(len(ind_selected))
print('Number of unselected points:')
print(len(ind_selected_not))

In [None]:
# View PCA
plt.scatter(pc[:,0], pc[:,1], alpha=.1, s=1)
plt.scatter(pc[ind_selected,0], pc[ind_selected,1], alpha=.1, s=1)
plt.xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))
plt.ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))

In [None]:
# View umap
plt.scatter(umap[:,0], umap[:,1], alpha=.1, s=1)
plt.scatter(umap[ind_selected,0], umap[ind_selected,1], alpha=.1, s=1)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')

# Filter outlier particles

In [None]:
# Compute magnitude of latent encodings
# (Euclidean norm of each row of z: square root of the sum of squares
# taken along axis 1)
znorm = np.sum(z**2,axis=1)**.5

In [None]:
# Distribution of ||z||
# sns.distplot is deprecated and scheduled for removal from seaborn;
# sns.histplot is its replacement for a plain histogram (no KDE overlay).
sns.histplot(znorm, kde=False)
plt.xlabel('||z||')

In [None]:
# By default, identify particles with ||z|| 2 std deviations above mean
# zscore is the number of standard deviations above the mean of ||z|| at which
# the cutoff is placed; increase it to flag fewer particles as outliers.
zscore = 2
thresh = znorm.mean()+zscore*znorm.std()
print(f'Mean: {znorm.mean()}, Std: {znorm.std()}, Selected threshold: {thresh}')

In [None]:
# Particles whose ||z|| is at or above the threshold, and the complement
ind_outliers = np.flatnonzero(znorm >= thresh)
ind_outliers_not = invert_selection(ind_outliers)

# Each label is followed by its value on the next line
print('Selected indices:', ind_outliers, sep='\n')
print('Number of selected points:', len(ind_outliers), sep='\n')
print('Number of unselected points:', len(ind_outliers_not), sep='\n')

In [None]:
# Histogram of ||z|| with the outlier cutoff drawn as a vertical line.
# sns.histplot replaces the deprecated sns.distplot (scheduled for removal
# from seaborn); like distplot it returns the matplotlib Axes.
g = sns.histplot(znorm, kde=False)
plt.axvline(x=thresh)
plt.xlabel('||z||')
plt.title('Magnitude of particle latent encodings')

In [None]:
# View PCA
plt.scatter(pc[:,0], pc[:,1], alpha=.1, s=1)
plt.scatter(pc[ind_outliers,0], pc[ind_outliers,1], alpha=.1, s=1)
plt.xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))
plt.ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))
plt.title(f'Particles with ||z|| > {thresh}')

In [None]:
# View UMAP
plt.scatter(umap[:,0], umap[:,1], alpha=.1, s=1)
plt.scatter(umap[ind_outliers,0], umap[ind_outliers,1], alpha=.1, s=1)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')
plt.title(f'Particles with ||z|| > {thresh}')

In [None]:
# Assign variables for viz/saving cells at the end of the notebook
ind_selected = ind_outliers
ind_selected_not = ind_outliers_not

# Interactive visualization

Interactive visualization of the latent encodings for the trained model. Each point represents a particle image of the dataset. The hover text includes the index of the image in the particle stack. 

In [None]:
# Load data into a pandas dataframe
# NOTE(review): the ctf_params column indices used below (2, 3, 4, 8) are
# assumed to hold the quantities named by their keywords (df1, df2, dfang,
# phase) -- confirm against the CTF parameter column layout.
# `znorm` is computed in the "Filter outlier particles" section, so that
# section must have been run before this cell.
df = analysis.load_dataframe(z=z, 
                             pc=pc, 
                             euler=euler, 
                             trans=trans, 
                             labels=kmeans_labels, 
                             umap=umap,
                             df1=ctf_params[:,2],
                             df2=ctf_params[:,3],
                             dfang=ctf_params[:,4],
                             phase=ctf_params[:,8],
                             znorm=znorm)
df.head()

### Interactive selection

The next two cells contain helper code to select particles using an interactive lasso tool. 

1. In the first cell, select points with the lasso tool. The table widget is dynamically updated with the most recent selection's indices. 
2. Then once you've finalized your selection, **run the next cell** to save the particle indices for downstream analysis/viz.

(Double click to clear selection)

In [None]:
widget, fig, ind_table = analysis.ipy_plot_interactive(df)
VBox((widget,fig,ind_table))

In [None]:
# Copy the indices currently shown in the table widget (the latest lasso
# selection) into ind_selected for the downstream viz/saving cells.
ind_selected = ind_table.data[0].cells.values[0] # save table values
# Force an integer dtype: an empty selection can otherwise turn into a float64
# array, which numpy refuses to use as an index in the plotting cells below.
ind_selected = np.asarray(ind_selected, dtype=int)
ind_selected_not = invert_selection(ind_selected)

print('Selected indices:')
print(ind_selected)
print('Number of selected points:')
print(len(ind_selected))
print('Number of unselected points:')
print(len(ind_selected_not))

In [None]:
# View PCA
plt.scatter(pc[:,0], pc[:,1], alpha=.1, s=1)
plt.scatter(pc[ind_selected,0], pc[ind_selected,1], alpha=.1, s=1)
plt.xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))
plt.ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))

In [None]:
# View umap
plt.scatter(umap[:,0], umap[:,1], alpha=.1, s=1)
plt.scatter(umap[ind_selected,0], umap[ind_selected,1], alpha=.1, s=1)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')

# View particles

View images from selected particles

In [None]:
particle_ind = ind_selected # or set to custom selection

In [None]:
# Draw 9 particles at random (without replacement) for viewing; selections of
# 9 or fewer particles are shown in full.
ind_subset9 = (
    np.random.choice(particle_ind, 9, replace=False)
    if len(particle_ind) > 9
    else particle_ind
)
print(ind_subset9)

In [None]:
p = [particles[ii].get() for ii in ind_subset9]
_ = analysis.plot_projections(p, ind_subset9)

plt.figure()
plt.scatter(umap[:,0], umap[:,1], alpha=.1, s=1)
plt.scatter(umap[ind_subset9,0], umap[ind_subset9,1], color='k')
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')

# Save selection

Save the particle indices for the selected (`ind_selected`) and unselected particles (`ind_selected_not`) as a .pkl file for downstream processing in cryoDRGN or with other tools.

Rename the paths as desired. Note that the indices will be automatically converted if the current cryoDRGN training run has already been filtered (`ind_orig` loaded in an earlier cell).

In [None]:
# Set selection as either the kept or bad particles (for file naming purposes)
ind_keep = ind_selected # or ind_selected_not
ind_bad = ind_selected_not # or ind_selected

In [None]:
# View PCA
plt.scatter(pc[:,0], pc[:,1], alpha=.1, s=1)
plt.scatter(pc[ind_keep,0], pc[ind_keep,1], alpha=.1, s=1)
plt.xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))
plt.ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))

In [None]:
# View UMAP
plt.scatter(umap[:,0], umap[:,1], alpha=.1, s=1)
plt.scatter(umap[ind_keep,0], umap[ind_keep,1], alpha=.1, s=1)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')

In [None]:
print('Kept particle indices:')
print(ind_keep)
print('Number of kept particles:')
print(len(ind_keep))
print('Number of bad particles:')
print(len(ind_bad))

In [None]:
# Path to save index .pkl for selected particles
SAVE_PATH = f'{WORKDIR}/ind_keep.{len(ind_keep)}_particles.pkl'
save_selection(SAVE_PATH, ind_keep)

In [None]:
# Path to save index .pkl for non-selected particles
SAVE_PATH = f'{WORKDIR}/ind_bad.{len(ind_bad)}_particles.pkl'
save_selection(SAVE_PATH, ind_bad)