# CryoDRGN visualization and analysis

This Jupyter notebook provides a template for analyzing cryoDRGN results, including:
* latent space visualization with PCA/UMAP
* clustering of the latent space (k-means or Gaussian mixture model)
* interactive visualization of the latent space, imaging, and pose parameters
* outlier detection (Z-score)
* interactive generation of volumes from the latent space

Note that this is a simple template for data analysis, and not a polished UI. Experience with Python/Pandas is recommended.

This notebook assumes that the latent variable dimension is > 1 (e.g. multidimensional plotting).

For each filtering method, the selected particles are tracked in the variable `ind_selected`.

Once the selection has been finalized, the selected particles are saved as an `index.pkl` file at the end of
this notebook. The selected indices can be provided to cryoDRGN with the `--ind` argument to train a new
model on a subset of the images, or converted to `.star` file format.

For more information, see the tutorial on [Notion page].

In [None]:
import numpy as np
import pickle
import os

from cryodrgn import analysis
from cryodrgn import utils
from cryodrgn import dataset
from cryodrgn import ctf
import cryodrgn.config

import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial.transform import Rotation as RR
from ipywidgets import VBox

### Load results

In [None]:
# Specify the workdir and the epoch number (0-based index) to analyze
WORKDIR = '..' 
EPOCH = None

In [None]:
print(os.path.abspath(WORKDIR))

In [None]:
# Load z
with open(os.path.join(WORKDIR, f"z.{EPOCH}.pkl"), 'rb') as f:
    z = pickle.load(f)

In [None]:
# Load UMAP
umap = utils.load_pkl(os.path.join(WORKDIR, f"analyze.{EPOCH}", "umap.pkl"))
# or run UMAP
# umap = analysis.run_umap(z)

In [None]:
# Load kmeans
KMEANS = None
kmeans_labels = utils.load_pkl(os.path.join(WORKDIR, f"analyze.{EPOCH}",
                                            f"kmeans{KMEANS}", "labels.pkl"))
kmeans_centers = np.loadtxt(os.path.join(WORKDIR, f"analyze.{EPOCH}",
                                         f"kmeans{KMEANS}", "centers.txt"))
# Or re-run kmeans with the desired number of classes
#kmeans_labels, kmeans_centers = analysis.cluster_kmeans(z, 20)

# Get index for on-data cluster center
kmeans_centers, centers_ind = analysis.get_nearest_point(z, kmeans_centers)

### Define helper functions

In [None]:
def invert_selection(ind_selected):
    """Return the sorted indices of all particles that are NOT in `ind_selected`.

    The full set of particle indices is 0..len(z)-1, where `z` is the
    notebook-level array of latent encodings.

    Args:
        ind_selected: iterable of integer particle indices to exclude.

    Returns:
        1-D integer array of the remaining indices, in ascending order.
    """
    # np.setdiff1d returns the sorted, unique complement directly, avoiding a
    # round trip through Python sets; it also yields an integer (rather than
    # float64) array when every particle has been selected.
    selected = np.asarray(list(ind_selected), dtype=int)
    return np.setdiff1d(np.arange(len(z)), selected)

In [None]:
def combine_selection(ind_sel1, ind_sel2, kind='union'):
    """Merge two index selections by union or intersection.

    Delegates to `analysis.combine_ind`, passing the number of particles
    (`len(z)`) together with the two selections and the chosen `kind`.

    Args:
        ind_sel1: first selection of particle indices.
        ind_sel2: second selection of particle indices.
        kind: either 'union' or 'intersection'.

    Returns:
        The result of `analysis.combine_ind` for these arguments.
    """
    valid_kinds = ('union', 'intersection')
    assert kind in valid_kinds
    n_particles = len(z)
    return analysis.combine_ind(n_particles, ind_sel1, ind_sel2, kind)

In [None]:
def select_clusters(labels, cluster_ids):
    """Look up the particle indices belonging to the given clusters.

    Thin wrapper around `analysis.get_ind_for_cluster`.

    Args:
        labels: per-particle cluster labels (e.g. `kmeans_labels`).
        cluster_ids: cluster ids to select.

    Returns:
        The result of `analysis.get_ind_for_cluster(labels, cluster_ids)`.
    """
    cluster_members = analysis.get_ind_for_cluster(labels, cluster_ids)
    return cluster_members

In [None]:
def save_selection(path, index):
    """Write a selection of particle indices to `path` as a .pkl file.

    When the notebook-level `ind_orig` is set (the training run used an
    index filter), the selection is first passed through
    `analysis.convert_original_indices` together with `N_orig` and
    `ind_orig` before being saved.

    Args:
        path: output path of the .pkl file.
        index: particle indices to save.
    """
    run_was_filtered = ind_orig is not None
    if run_was_filtered:
        print('Converting to original .mrcs indices')
        print(f"{index} -- {N_orig} -- {ind_orig}")
        index = analysis.convert_original_indices(index, N_orig, ind_orig)

    utils.save_pkl(index, path)
    written_to = os.path.abspath(path)
    print(f'Wrote {written_to}')

### Load dataset

In [None]:
# load configuration file
config = cryodrgn.config.load(os.path.join(WORKDIR, "config.yaml"))
print(config)

In [None]:
# load poses
# With "fixed" or "refine" pose estimation the poses are read from the file
# named in the config; otherwise they are read from this epoch's pose .pkl
# in the workdir.
if config["model_args"]["pose_estimation"] not in {"fixed", "refine"}:
    pose_pkl = os.path.join(WORKDIR, f"pose.{EPOCH}.pkl")
    with open(pose_pkl, 'rb') as f:
        rot, trans = pickle.load(f)
else:
    pose_pkl = config["dataset_args"]["poses"]
    rot, trans = utils.load_pkl(pose_pkl)

In [None]:
# convert rotation matrices to euler angles
euler = RR.from_matrix(rot).as_euler('zyz', degrees=True)

In [None]:
# load index filter
# `ind` in the config is either None (no filtering), a path to a .pkl file of
# particle indices, or a count N that is expanded to the indices 0..N-1.
ind_orig = config["dataset_args"]["ind"]
if ind_orig is not None:
    # Test the suffix on str(ind_orig) so that a count stored in the config
    # as an int (rather than a string) does not fail on `.endswith`.
    if str(ind_orig).endswith('.pkl'):
        ind_orig = utils.load_pkl(ind_orig)
    else:
        ind_orig = np.arange(int(ind_orig))

    # Only filter when there are more poses than selected indices; poses of
    # the same length as the index are left as they are.
    if len(rot) > len(ind_orig):
        print(f'Filtering poses from {len(rot)} to {len(ind_orig)}')
        rot = rot[ind_orig]
        trans = trans[ind_orig]
        euler = euler[ind_orig]

In [None]:
# load input particles; individual particle images are viewed later in the
# notebook (see the "View particles" section)
particles = dataset.ImageDataset(
    config['dataset_args']['particles'], lazy=True, ind=ind_orig,
    datadir=config['dataset_args']['datadir']
)
# N_orig is passed to analysis.convert_original_indices when saving a selection
# (presumably the particle count of the unfiltered stack -- TODO confirm)
N_orig = particles.src.orig_n

# particles object can be filtered manually as well
# (e.g. to retrieve individual particles)

# if ind_orig is not None:
#    print(f'Filtering particles from {len(particles)} to {len(ind_orig)}')
#    particles = [particles[int(i)][0, ...] for i in ind_orig]

In [None]:
# load CTF
ctf_params = utils.load_pkl(config["dataset_args"]["ctf"])
if ind_orig is not None:
    print(f'Filtering ctf parameters from {len(ctf_params)} to {len(ind_orig)}')
    ctf_params = ctf_params[ind_orig]

ctf.print_ctf_params(ctf_params[0])

### View pose distribution

In [None]:
# rotations
try:
    analysis.plot_euler(euler[:,0],euler[:,1], euler[:,2])
except ZeroDivisionError:
    print("Data too small to produce plot of rotations!")

In [None]:
# translations
try:
    sns.jointplot(x=trans[:,0], y=trans[:,1], kind='hex').set_axis_labels(
        'tx (fraction)','ty (fraction)')
except ZeroDivisionError:
    print("Data too small to produce plot of translations!")


### Learning curve

In [None]:
loss = analysis.parse_loss(os.path.join(WORKDIR, 'run.log'))
plt.plot(loss)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.axvline(x=EPOCH, linestyle="--", color="black", label=f"Epoch {EPOCH}")
plt.legend()

### PCA

In [None]:
pc, pca = analysis.run_pca(z)

In [None]:
g = sns.jointplot(x=pc[:,0], y=pc[:,1], alpha=.1, s=1)
g.set_axis_labels('PC1', 'PC2')

In [None]:
g = sns.jointplot(x=pc[:,0], y=pc[:,1], kind='hex')
g.set_axis_labels('PC1', 'PC2')

In [None]:
plt.bar(np.arange(z.shape[1]) + 1, pca.explained_variance_ratio_)
plt.xticks(np.arange(z.shape[1])+1)
plt.xlabel('PC')
plt.ylabel('explained variance')

### UMAP

In [None]:
g = sns.jointplot(x=umap[:, 0], y=umap[:, 1], alpha=.1, s=1)
g.set_axis_labels('UMAP1', 'UMAP2')

In [None]:
g = sns.jointplot(x=umap[:, 0], y=umap[:, 1], kind='hex')
g.set_axis_labels('UMAP1', 'UMAP2')

# Filter by clustering

Select particles based on k-means cluster labels or GMM cluster labels

### View K-means clusters

In [None]:
# Optionally, re-run kmeans with the desired number of classes
#K = 20
#kmeans_labels, kmeans_centers = analysis.cluster_kmeans(z, K)

In [None]:
# Derive the number of clusters from the labels actually in use, so the plot
# stays correct if k-means was re-run with a different number of classes than
# the KMEANS value used to load the precomputed results.
K = len(set(kmeans_labels))
c = pca.transform(kmeans_centers) # transform to view with PCs
analysis.plot_by_cluster(pc[:,0], pc[:,1], K,
                         kmeans_labels, 
                         centers=c,
                         annotate=True)
plt.xlabel('PC1')
plt.ylabel('PC2')

In [None]:
# One PCA subplot per k-means cluster. Use K (computed from kmeans_labels in
# the previous cell) rather than KMEANS, which is stale if k-means was re-run.
fig, ax = analysis.plot_by_cluster_subplot(pc[:,0], pc[:,1], K,
                                           kmeans_labels)

In [None]:
# UMAP view of the k-means clusters, annotated at the on-data cluster centers.
# Use K (computed from kmeans_labels earlier in this section) rather than
# KMEANS, which is stale if k-means was re-run.
analysis.plot_by_cluster(umap[:,0], umap[:,1], K,
                         kmeans_labels, 
                         centers_ind=centers_ind,
                         annotate=True)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')

In [None]:
# One UMAP subplot per k-means cluster. Use K (computed from kmeans_labels
# earlier in this section) rather than KMEANS, which is stale if k-means was
# re-run.
fig, ax = analysis.plot_by_cluster_subplot(umap[:,0], umap[:,1], K,
                                           kmeans_labels)

**Select particles based on k-means clustering**

In [None]:
cluster_ids = [0,2] # set clusters to select, 0 and 2 in this example
ind_selected = select_clusters(kmeans_labels, cluster_ids)
ind_selected_not = invert_selection(ind_selected)
print('Selected indices:')
print(ind_selected)
print('Number of selected points:')
print(len(ind_selected))
print('Number of unselected points:')
print(len(ind_selected_not))

In [None]:
# View PCA
plt.scatter(pc[:, 0], pc[:, 1], alpha=.1, s=1)
plt.scatter(pc[ind_selected,0], pc[ind_selected,1], alpha=.1, s=1)
plt.xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))
plt.ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))

In [None]:
# View umap
plt.scatter(umap[:, 0], umap[:, 1], alpha=.1, s=1)
plt.scatter(umap[ind_selected, 0], umap[ind_selected, 1], alpha=.1, s=1)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')

### GMM-clustering

In [None]:
G = 3 # or change to desired cluster number
random_state = np.random.randint(100000) # sample random integer
print(f'Random state: {random_state}')
gmm_labels, gmm_centers = analysis.cluster_gmm(z, G, random_state=random_state)
gmm_centers, gmm_centers_ind = analysis.get_nearest_point(z, gmm_centers)

In [None]:
analysis.plot_by_cluster(pc[:, 0], pc[:, 1], G, 
                         gmm_labels, 
                         centers_ind=gmm_centers_ind,
                         annotate=True)
plt.xlabel('PC1')
plt.ylabel('PC2')

In [None]:
fig, ax = analysis.plot_by_cluster_subplot(pc[:, 0], pc[:, 1], G, gmm_labels)

In [None]:
analysis.plot_by_cluster(umap[:,0], umap[:,1], G, 
                         gmm_labels, 
                         centers_ind=gmm_centers_ind,
                         annotate=True)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')

In [None]:
fig, ax = analysis.plot_by_cluster_subplot(umap[:,0], umap[:,1], G, gmm_labels)

**Select particles based on GMM clustering**

In [None]:
cluster_ids = [0,2] # set clusters to select, 0 and 2 in this example
ind_selected = select_clusters(gmm_labels, cluster_ids)
ind_selected_not = invert_selection(ind_selected)
print('Selected indices:')
print(ind_selected)
print('Number of selected points:')
print(len(ind_selected))
print('Number of unselected points:')
print(len(ind_selected_not))

In [None]:
# View PCA
plt.scatter(pc[:,0], pc[:,1], alpha=.1, s=1)
plt.scatter(pc[ind_selected,0], pc[ind_selected,1], alpha=.1, s=1)
plt.xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))
plt.ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))

In [None]:
# View umap
plt.scatter(umap[:,0], umap[:,1], alpha=.1, s=1)
plt.scatter(umap[ind_selected,0], umap[ind_selected,1], alpha=.1, s=1)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')

# Filter outlier particles

In [None]:
# Euclidean (L2) norm of each particle's latent encoding, one value per row of z
znorm = (z ** 2).sum(axis=1) ** 0.5

In [None]:
# Distribution of ||z||
# sns.histplot is the replacement for the deprecated sns.distplot(..., kde=False)
sns.histplot(znorm)
plt.xlabel('||z||')

In [None]:
# By default, identify particles with ||z|| 2 std deviations above mean
zscore = 2
thresh = znorm.mean()+zscore*znorm.std()
print(f'Mean: {znorm.mean()}, Std: {znorm.std()}, Selected threshold: {thresh}')

In [None]:
ind_outliers = np.where(znorm >= thresh)[0]
ind_outliers_not = invert_selection(ind_outliers)

print('Selected indices:')
print(ind_outliers)
print('Number of selected points:')
print(len(ind_outliers))
print('Number of unselected points:')
print(len(ind_outliers_not))

In [None]:
# Histogram of ||z|| with the outlier threshold marked as a vertical line.
# sns.histplot is the replacement for the deprecated sns.distplot(..., kde=False)
g = sns.histplot(znorm)
plt.axvline(x=thresh)
plt.xlabel('||z||')
plt.title('Magnitude of particle latent encodings')

In [None]:
# View PCA
plt.scatter(pc[:,0], pc[:,1], alpha=.1, s=1)
plt.scatter(pc[ind_outliers,0], pc[ind_outliers,1], alpha=.1, s=1)
plt.xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))
plt.ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))
plt.title(f'Particles with ||z|| > {thresh}')

In [None]:
# View UMAP
plt.scatter(umap[:,0], umap[:,1], alpha=.1, s=1)
plt.scatter(umap[ind_outliers,0], umap[ind_outliers,1], alpha=.1, s=1)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')
plt.title(f'Particles with ||z|| > {thresh}')

In [None]:
# Assign variables for viz/saving cells at the end of the notebook
ind_selected = ind_outliers
ind_selected_not = ind_outliers_not

# Interactive visualization

Interactive visualization of the latent encodings for the trained model. Each
point represents a particle image of the dataset. The hover text includes the
index of the image in the particle stack. 

In [None]:
# Load data into a pandas dataframe
df = analysis.load_dataframe(z=z, 
                             pc=pc, 
                             euler=euler, 
                             trans=trans, 
                             labels=kmeans_labels, 
                             umap=umap,
                             df1=ctf_params[:,2],
                             df2=ctf_params[:,3],
                             dfang=ctf_params[:,4],
                             phase=ctf_params[:,8],
                             znorm=znorm)
df.head()

### Interactive selection

The next two cells contain helper code to select particles using an interactive
lasso tool. 

1. In the first cell, select points with the lasso tool. The table widget is
dynamically updated with the most recent selection's indices. 
2. Then once you've finalized your selection, **run the next cell** to save the
particle indices for downstream analysis/viz.

(Double click to clear selection)

You can also use our interactive command line tool `cryodrgn filter $WORKDIR`
for selecting particles.

In [None]:
widget, fig, ind_table = analysis.ipy_plot_interactive(df)
VBox((widget,fig,ind_table))

In [None]:
ind_selected = ind_table.data[0].cells.values[0] # save table values
ind_selected = np.array(ind_selected)
ind_selected_not = invert_selection(ind_selected)

print('Selected indices:')
print(ind_selected)
print('Number of selected points:')
print(len(ind_selected))
print('Number of unselected points:')
print(len(ind_selected_not))

In [None]:
# View PCA
plt.scatter(pc[:,0], pc[:,1], alpha=.1, s=1)
plt.scatter(pc[ind_selected,0], pc[ind_selected,1], alpha=.1, s=1)
plt.xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))
plt.ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))

In [None]:
# View umap
plt.scatter(umap[:,0], umap[:,1], alpha=.1, s=1)
plt.scatter(umap[ind_selected,0], umap[ind_selected,1], alpha=.1, s=1)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')

# UMAP/PC selection

In [None]:
# Load data into a pandas dataframe
df = analysis.load_dataframe(z=z, 
                             pc=pc, 
                             euler=euler, 
                             trans=trans, 
                             labels=kmeans_labels, 
                             umap=umap,
                             df1=ctf_params[:,2],
                             df2=ctf_params[:,3],
                             dfang=ctf_params[:,4],
                             phase=ctf_params[:,8])
df.head()

### Selection by UMAP/PC values

In the next cell, you can select particles by their UMAP or PC values.
Change the bounds in the selection, and add more selections if necessary. The
default selects on UMAP1 and UMAP2; to select on a different field, replace
'UMAP1' (or 'UMAP2') with the desired column name (e.g. PC1).

In [None]:
# 1 selection
ind_selected1 = df.index[(df['UMAP1'] >= -5) & (df['UMAP1'] <= 5) & (df['UMAP2'] >= -5) & (df['UMAP2'] <= 5)]
ind_selected1 = np.array(ind_selected1)
ind_selected = ind_selected1
# 2 selections
#ind_selected2 = df.index[(df['UMAP1'] >= -5) & (df['UMAP1'] <= 5) & (df['UMAP2'] >= -5) & (df['UMAP2'] <= 5)]
#ind_selected2 = np.array(ind_selected2)
#ind_selected = np.append(ind_selected1, ind_selected2)
#ind_selected = np.unique(ind_selected)
# 3 selections
#ind_selected3 = df.index[(df['UMAP1'] >= -5) & (df['UMAP1'] <= 5) & (df['UMAP2'] >= -5) & (df['UMAP2'] <= 5)]
#ind_selected3 = np.array(ind_selected3)
#ind_selected = np.append(ind_selected, ind_selected3)
#ind_selected = np.unique(ind_selected)

ind_selected_not = invert_selection(ind_selected)

print('Selected indices:')
print(ind_selected)
print('Number of selected points:')
print(len(ind_selected))
print('Number of unselected points:')
print(len(ind_selected_not))

In [None]:
# View PCA
plt.scatter(pc[:,0], pc[:,1], alpha=.1, s=1)
plt.scatter(pc[ind_selected,0], pc[ind_selected,1], alpha=.1, s=1)
plt.xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))
plt.ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))

In [None]:
# View umap
plt.scatter(umap[:,0], umap[:,1], alpha=.1, s=1)
plt.scatter(umap[ind_selected,0], umap[ind_selected,1], alpha=.1, s=1)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')

# View particles

View images from selected particles

In [None]:
# Use the current selection (`ind_selected`) if it is non-empty; otherwise fall
# back to up to 9 particle indices drawn at random without replacement.
# `particle_ind` can also be set to a custom selection of indices here.
n = len(particles)
particle_ind = list(ind_selected) or np.random.choice(range(n), min(n, 9), replace=False)


In [None]:
# Choose 9 particles to view at random
if len(particle_ind) > 9:
    ind_subset9 = np.random.choice(particle_ind, 9, replace=False)
else: 
    ind_subset9 = particle_ind

print(ind_subset9)

In [None]:
p = [particles[int(ii)]["y"][0, ...] for ii in ind_subset9]
_ = analysis.plot_projections(p, ind_subset9)

plt.figure()
plt.scatter(umap[:,0], umap[:,1], alpha=.1, s=1)
plt.scatter(umap[ind_subset9,0], umap[ind_subset9,1], color='k')
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')

# Save selection

Save the particle indices for the selected (`ind_selected`) and unselected particles (`ind_selected_not`) as a .pkl file for downstream processing in cryoDRGN or with other tools.

Rename the paths as desired. Note that the indices will be automatically converted if the current cryoDRGN training run has already been filtered (`ind_orig` loaded in an earlier cell).

In [None]:
# Set selection as either the kept or bad particles (for file naming purposes)
ind_keep = ind_selected # or ind_selected_not
ind_bad = ind_selected_not # or ind_selected

In [None]:
# View PCA
plt.scatter(pc[:,0], pc[:,1], alpha=.1, s=1)
plt.scatter(pc[ind_keep,0], pc[ind_keep,1], alpha=.1, s=1)
plt.xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))
plt.ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))

In [None]:
# View UMAP
plt.scatter(umap[:,0], umap[:,1], alpha=.1, s=1)
plt.scatter(umap[ind_keep,0], umap[ind_keep,1], alpha=.1, s=1)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')

In [None]:
print('Kept particle indices:')
print(ind_keep)
print('Number of kept particles:')
print(len(ind_keep))
print('Number of bad particles:')
print(len(ind_bad))

In [None]:
if len(ind_keep):
    # Path to save index .pkl for selected particles
    SAVE_PATH = f'{WORKDIR}/ind_keep.{len(ind_keep)}_particles.pkl'
    save_selection(SAVE_PATH, ind_keep)

In [None]:
if len(ind_bad):
    # Path to save index .pkl for non-selected particles
    SAVE_PATH = f'{WORKDIR}/ind_bad.{len(ind_bad)}_particles.pkl'
    save_selection(SAVE_PATH, ind_bad)