# CryoDRGN-ET visualization and analysis

This Jupyter notebook provides a template for analyzing cryoDRGN-ET results, including:
* latent space visualization with PCA/UMAP
* clustering
* interactive visualization of the latent space, imaging, and pose parameters
* interactive selection of particle images from the latent space
* interactive generation of volumes from the latent space

Note that this is a simple template for data analysis, and not a polished UI. Experience with Python/Pandas is recommended.

This notebook assumes that the latent variable dimension is greater than 1 (required for the multidimensional plots, e.g. the 2D PCA and UMAP views).

In [None]:
import pandas as pd
import numpy as np
import pickle
import subprocess
import os, sys

from cryodrgn import analysis
from cryodrgn import utils
from cryodrgn import dataset
from cryodrgn import ctf
from cryodrgn import source
import cryodrgn.config
from cryodrgn.dataset import TiltSeriesData

                
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objs as go
import plotly.offline as py
from ipywidgets import interact, interactive, HBox, VBox
from scipy.spatial.transform import Rotation as RR
py.init_notebook_mode()
from IPython.display import FileLink, FileLinks

In [None]:
# Enable interactive widgets
# IPython shell escape (not plain Python): runs the `jupyter` CLI to enable the
# ipywidgets notebook extension.
# NOTE(review): `nbextension` targets the classic Jupyter Notebook; newer front
# ends may not need or support this command -- confirm for your setup.
!jupyter nbextension enable --py widgetsnbextension

### Load results

In [None]:
# Specify the workdir and the epoch number (0-based index) to analyze
WORKDIR = '..' 
EPOCH = None # change me if necessary!

# If EPOCH is left as None, fall back to the latest epoch that has both saved
# latent encodings (z.<N>.pkl) and an analysis directory (analyze.<N>).
# Otherwise the cells below would try to open files such as 'z.None.pkl'.
if EPOCH is None:
    import re
    _names = set(os.listdir(WORKDIR)) if os.path.isdir(WORKDIR) else set()
    _epochs = [
        int(m.group(1))
        for m in (re.fullmatch(r'analyze\.(\d+)', name) for name in _names)
        if m and f'z.{m.group(1)}.pkl' in _names
    ]
    if _epochs:
        EPOCH = max(_epochs)
        print(f"EPOCH not set; using the latest analyzed epoch: {EPOCH}")
    else:
        print("WARNING: EPOCH not set and no analyzed epoch (analyze.<N> with z.<N>.pkl) "
              f"found in {os.path.abspath(WORKDIR)}; set EPOCH manually.")

In [None]:
# Show the absolute path of WORKDIR to confirm which results directory is used
print(os.path.abspath(WORKDIR))

In [None]:
# Load z: the pickle holds two objects written back to back -- z (presumably
# the latent means) followed by z_logvar -- read in that order.
with open(os.path.join(WORKDIR, f'z.{EPOCH}.pkl'), 'rb') as fh:
    z, z_logvar = pickle.load(fh), pickle.load(fh)

In [None]:
# Load the UMAP embedding precomputed in analyze.<EPOCH>
umap = utils.load_pkl(os.path.join(WORKDIR, f'analyze.{EPOCH}', 'umap.pkl'))
# Alternatively, recompute it from z:
# umap = analysis.run_umap(z)

In [None]:
# Load kmeans
KMEANS = None # change me if necessary! (K of the analyze.<EPOCH>/kmeans<K> directory)

# If KMEANS is left as None, use a k-means result saved in analyze.<EPOCH>
# (directories named kmeans<K>; the smallest K when there are several) instead
# of trying to open a 'kmeansNone' directory.
if KMEANS is None:
    import re
    _analyze_dir = f'{WORKDIR}/analyze.{EPOCH}'
    _names = os.listdir(_analyze_dir) if os.path.isdir(_analyze_dir) else []
    _ks = sorted(
        int(m.group(1))
        for m in (re.fullmatch(r'kmeans(\d+)', name) for name in _names)
        if m
    )
    if not _ks:
        raise FileNotFoundError(
            f"No kmeans<K> directory found in {os.path.abspath(_analyze_dir)}; "
            "set KMEANS or re-run kmeans below")
    KMEANS = _ks[0]
    print(f"KMEANS not set; using kmeans{KMEANS} (available: {_ks})")

kmeans_labels = utils.load_pkl(f'{WORKDIR}/analyze.{EPOCH}/kmeans{KMEANS}/labels.pkl')
kmeans_centers = np.loadtxt(f'{WORKDIR}/analyze.{EPOCH}/kmeans{KMEANS}/centers.txt')
# Or re-run kmeans with the desired number of classes
#kmeans_labels, kmeans_centers = analysis.cluster_kmeans(z, 20)

# Get index for on-data cluster center
# (presumably replaces each center with its nearest point in z; centers_ind
# indexes into z)
kmeans_centers, centers_ind = analysis.get_nearest_point(z, kmeans_centers)

### Load dataset

In [None]:
# Load the training configuration saved with the results, and show it
config = cryodrgn.config.load(os.path.join(WORKDIR, 'config.yaml'))
print(config)

In [None]:
# Tilt series data -- Get a representative tilt image for each particle
# This notebook only handles models trained with tilt-series encoding. Raise an
# explicit error rather than `assert`, which is skipped under `python -O`.
if config['model_args']['encode_mode'] != 'tilt':
    raise ValueError(
        "This notebook requires a model trained with encode_mode='tilt', "
        f"got {config['model_args']['encode_mode']!r}")

# p_to_t: per particle, the indices of its tilt images
# t_to_p: one entry per tilt image
print(f"Parsing {config['dataset_args']['particles']}")
p_to_t, t_to_p = TiltSeriesData.parse_particle_tilt(config['dataset_args']['particles'])
print(f"Detected {len(t_to_p)} tilt images for {len(p_to_t)} particles")

# Selection for the first tilt of each particle
first_tilt_ind = np.array([pp[0] for pp in p_to_t])

In [None]:
# Load index filter: the particle subset (if any) that training was restricted to
ind_orig = config['dataset_args']['ind']
if ind_orig is None:
    # No filtering: one entry per tilt image
    ind_orig_tilt = list(t_to_p)
else:
    print(f"Loading particle selection from {ind_orig}")
    ind_orig = utils.load_pkl(ind_orig)  # path -> selected particle indices
    ind_orig_tilt = TiltSeriesData.particles_to_tilts(p_to_t, ind_orig)
    N_orig_particles = len(p_to_t)  # size of the unfiltered particle set
    print(f"Filtering particles from {len(first_tilt_ind)} to {len(ind_orig)}")
    print(f"Filtering tilt images from {len(t_to_p)} to {len(ind_orig_tilt)}")
    # Keep the representative tilt only for the selected particles
    first_tilt_ind = first_tilt_ind[ind_orig]

In [None]:
# Load poses -- from this run's pose.<EPOCH>.pkl when poses were optimized
# during training (do_pose_sgd), otherwise from the input pose file
if not config['dataset_args']['do_pose_sgd']:
    pose_pkl = config['dataset_args']['poses']
    rot, trans = utils.load_pkl(pose_pkl)
else:
    pose_pkl = f'{WORKDIR}/pose.{EPOCH}.pkl'
    with open(pose_pkl, 'rb') as fh:
        rot, trans = pickle.load(fh)

# Rotation matrices -> ZYZ Euler angles in degrees (one row per tilt image)
euler = RR.from_matrix(rot).as_euler('zyz', degrees=True)

# Keep only each particle's representative (first) tilt
rot0, trans0, euler0 = (arr[first_tilt_ind] for arr in (rot, trans, euler))

In [None]:
# Load input particles lazily, keeping only one tilt image per particle
# (the first-tilt selection, also applied to the poses and CTF)
particles = source.ImageSource.from_file(
    config['dataset_args']['particles'],
    indices=first_tilt_ind,
    datadir=config['dataset_args']['datadir'],
    lazy=True,
)

In [None]:
# Load CTF parameters, keeping the representative (first) tilt of each
# particle, then print the parameters of the first particle
ctf_params = utils.load_pkl(config['dataset_args']['ctf'])[first_tilt_ind]
ctf.print_ctf_params(ctf_params[0])

### Learning curve

In [None]:
# Training loss parsed from run.log, with the analyzed epoch marked
loss = analysis.parse_loss(f'{WORKDIR}/run.log')
plt.plot(loss)
plt.axvline(x=EPOCH, linestyle="--", color="black", label=f"Epoch {EPOCH}")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()

### PCA

In [None]:
# PCA of the latent encodings: `pc` holds the projected coordinates, `pca` the
# fitted model (its explained_variance_ratio_ and transform are used below)
pc, pca = analysis.run_pca(z)

In [None]:
# Scatter of the first two principal components with marginal distributions
g = sns.jointplot(x=pc[:, 0], y=pc[:, 1], s=1, alpha=0.1)
g.set_axis_labels(xlabel='PC1', ylabel='PC2')

In [None]:
# Hexbin density of the first two PCs; binning can fail on very small datasets
try:
    g = sns.jointplot(x=pc[:,0], y=pc[:,1], kind='hex')
except ZeroDivisionError:
    print("Data too small to produce hexbins!")
else:
    # Label only the plot that was just created. On failure `g` would still
    # refer to an earlier cell's plot (or be undefined).
    g.set_axis_labels('PC1', 'PC2')

In [None]:
# Fraction of variance explained by each principal component. Bar positions
# are derived from the fitted PCA itself, so their count always matches the
# number of explained-variance ratios.
pc_numbers = np.arange(1, len(pca.explained_variance_ratio_) + 1)
plt.bar(pc_numbers, pca.explained_variance_ratio_)
plt.xticks(pc_numbers)
plt.xlabel('PC')
plt.ylabel('explained variance')

### View pose distribution

In [None]:
# rotations: Euler-angle distribution over all tilt images (one column per angle)
analysis.plot_euler(*euler.T)

### View UMAP

In [None]:
# Scatter of the 2D UMAP embedding with marginal distributions
g = sns.jointplot(x=umap[:, 0], y=umap[:, 1], s=1, alpha=0.1)
g.set_axis_labels(xlabel='UMAP1', ylabel='UMAP2')

In [None]:
# Hexbin density of the UMAP embedding; binning can fail on very small datasets
try:
    g = sns.jointplot(x=umap[:,0], y=umap[:,1], kind='hex')
except ZeroDivisionError:
    print("Data too small to produce hexbins!")
else:
    # Label only the plot that was just created. On failure `g` would still
    # refer to an earlier cell's plot (or be undefined).
    g.set_axis_labels('UMAP1', 'UMAP2')

### View K-means clusters

In [None]:
# k-means clusters in PC space; the cluster centers are projected with the same PCA
K = len(np.unique(kmeans_labels))  # number of distinct cluster labels
c = pca.transform(kmeans_centers)  # transform to view with PCs
analysis.plot_by_cluster(
    pc[:, 0], pc[:, 1], K, kmeans_labels,
    centers=c,
    annotate=True,
)
plt.xlabel('PC1')
plt.ylabel('PC2')

In [None]:
# One subplot per k-means cluster, in PC space
fig, ax = analysis.plot_by_cluster_subplot(
    pc[:, 0], pc[:, 1], K, kmeans_labels
)

In [None]:
# k-means clusters on the UMAP embedding; centers are marked via their on-data indices
analysis.plot_by_cluster(
    umap[:, 0], umap[:, 1], K, kmeans_labels,
    centers_ind=centers_ind,
    annotate=True,
)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')

In [None]:
# One subplot per k-means cluster, on the UMAP embedding
fig, ax = analysis.plot_by_cluster_subplot(
    umap[:, 0], umap[:, 1], K, kmeans_labels
)

# Interactive visualization

Interactive visualization of the latent encodings for the trained model. Each point represents one particle (i.e. one tilt series) of the dataset. The hover text includes the particle's index (relative to the `--ind`-filtered dataset, if the training run used one).

### Load into pandas dataframe

In [None]:
# Load data into a pandas dataframe: one row per particle, with pose and CTF
# values taken from each particle's representative (first) tilt
df = analysis.load_dataframe(
    z=z,
    pc=pc,
    umap=umap,
    labels=kmeans_labels,
    euler=euler0,
    trans=trans0,
    # extra CTF columns, in this order: df1, df2, dfang, phase
    **{name: ctf_params[:, col]
       for name, col in (('df1', 2), ('df2', 3), ('dfang', 4), ('phase', 8))},
)
df.head()

In [None]:
# Display the current working directory; relative paths (e.g. the
# reconstruct_* volume directories below) are resolved against it.
# os.getcwd() works even when IPython automagic (bare `pwd` -> %pwd) is off.
os.getcwd()

In [None]:
# Interactive scatter with the on-data k-means cluster centers annotated
widget, fig = analysis.ipy_plot_interactive_annotate(df, centers_ind)
VBox([widget, fig])

# Interactive selection

The next two cells contain helper code to select particles using an interactive lasso tool. 

1. In the first cell, select points with the lasso tool. The table widget is dynamically updated with the most recent selection's indices. 
2. Then once you've finalized your selection, use the next cell to save the particle indices for downstream analysis/viz.

(Double click to clear selection)

In [None]:
# Lasso-select points; ind_table shows the indices of the most recent selection
widget, fig, ind_table = analysis.ipy_plot_interactive(df)
VBox([widget, fig, ind_table])

In [None]:
ind_selected = ind_table.data[0].cells.values[0] # save table values
# Force an integer dtype: an empty selection would otherwise become a float64
# array, which cannot be used to index `pc`, `umap` or `df` in later cells.
ind_selected = np.asarray(ind_selected, dtype=int)
# Complement of the selection: sorted, unique row indices not selected
ind_selected_not = np.setdiff1d(np.arange(len(df)), ind_selected)

print('Selected indices:')
print(ind_selected)
print('Number of selected points:')
print(len(ind_selected))
print('Number of unselected points:')
print(len(ind_selected_not))

### Visualize selected subset

In [None]:
# View PCA: all particles first, then the selected subset overlaid
plt.scatter(pc[:, 0], pc[:, 1], s=1, alpha=.1)
plt.scatter(pc[ind_selected, 0], pc[ind_selected, 1], s=1, alpha=.1)
plt.xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2f})')
plt.ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2f})')

In [None]:
# View umap: all particles first, then the selected subset overlaid
plt.scatter(umap[:, 0], umap[:, 1], s=1, alpha=.1)
plt.scatter(umap[ind_selected, 0], umap[ind_selected, 1], s=1, alpha=.1)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')

In [None]:
# Split the dataframe into selected and unselected particles (label-based lookup)
df_sub, df_sub_not = df.loc[ind_selected], df.loc[ind_selected_not]

In [None]:
# View pose distribution of the selection (skipped when fewer than 10 particles)
if len(df_sub) >= 10:
    analysis.plot_euler(df_sub['theta'], df_sub['phi'], df_sub['psi'])

In [None]:
# Interactive view restricted to the selected particles
widget, fig, ind_table = analysis.ipy_plot_interactive(df_sub)
VBox([widget, fig, ind_table])

### Save the index selection

The indices for the selected particles may be saved for use in downstream processing in cryoDRGN or with other tools.
Within cryoDRGN, selections are saved as a numpy array in `.pkl` file format. Then, the selected indices can be
provided to cryoDRGN with the `--ind` argument to train a new model on a subset of the particles. 

Tools are provided in the `utils` subdirectory of the cryoDRGN repo to help convert the index selection to
`.star` file format.

**NOTE:** If there are multiple rounds of index filtering performed on the same particle stack (i.e. your results
come from a training run that already uses an `--ind` subselection), the index selection must be converted into the
correct indices into the original dataset. The next cell performs this conversion automatically when the training
configuration contains an `ind` selection.

In [None]:
# Output path for the selection -- RENAME ME
indices_file = os.path.join(WORKDIR, f"analyze.{EPOCH}", "tmp_ind_selected.pkl")

### IMPORTANT: convert index selection to original particles indices if current results
# have already been filtered. `ind_selected` indexes the (possibly --ind filtered)
# particles used in training; map it back before saving.
if ind_orig is None:
    utils.save_pkl(ind_selected, indices_file)
else:
    ind_selected_orig = analysis.convert_original_indices(
        ind_selected, N_orig_particles, ind_orig)
    utils.save_pkl(ind_selected_orig, indices_file)

print(f"Saved chosen indices to {os.path.abspath(indices_file)}")

# View particles

View images of particles at selected points in latent space (one representative tilt image per particle)

In [None]:
# Particle (dataframe row) indices to view; defaults to the lasso selection
particle_ind = ind_selected # or set to custom selection

In [None]:
# choose 25 particles to view at random (all of them if there are 25 or fewer)
ind_subset25 = (
    np.random.choice(particle_ind, 25, replace=False)
    if len(particle_ind) > 25
    else particle_ind
)
print(ind_subset25)

In [None]:
# Show the loaded (representative tilt) image of each chosen particle, then
# mark those particles on the interactive latent-space plot
p = [particles[int(idx)][0] for idx in ind_subset25]
analysis.plot_projections(p)
widget, fig = analysis.ipy_plot_interactive_annotate(df, ind_subset25, opacity=.1)
VBox([widget, fig])

# Generate volumes

Generate volumes at selected points in latent space

In [None]:
# Row indices into z (one per particle) whose latent coordinates will be decoded
# into volumes below, e.g. centers_ind or a subset of ind_selected
vol_ind = [] # ADD INDICES HERE
print(vol_ind)

In [None]:
# Highlight the points chosen for volume generation
widget, fig = analysis.ipy_plot_interactive_annotate(df, vol_ind, opacity=.1)
VBox([widget, fig])

In [None]:
def get_outdir():
    '''Helper function to get a clean directory to save volumes.

    Returns the first name of the form ``reconstruct_NNNNNN`` (zero-padded to
    six digits, counting up from 0) that does not exist yet, relative to the
    current working directory. The directory itself is not created here.
    '''
    i = 0
    # No upper bound on the search, so an existing directory is never returned
    while os.path.exists(f'reconstruct_{i:06d}'):
        i += 1
    return f'reconstruct_{i:06d}'

def generate_volumes(zvalues, outdir, **kwargs):
    '''Helper function to call cryodrgn eval_vol and generate new volumes.

    Parameters
    ----------
    zvalues : array-like
        Latent coordinates to decode: a single vector, or one vector per row.
    outdir : str
        Output directory; created (including missing parents) if needed.
    **kwargs
        Forwarded unchanged to ``analysis.gen_volumes`` (e.g. Apix, flip,
        downsample, invert).

    Returns
    -------
    IPython.display.FileLinks
        Links to the files in ``outdir``.

    The model weights (weights.<EPOCH>.pkl) and config.yaml are read from the
    notebook globals WORKDIR and EPOCH.
    '''
    # makedirs (unlike mkdir) also creates missing parent directories
    os.makedirs(outdir, exist_ok=True)
    zfile = f'{outdir}/zfile.txt'
    # atleast_2d writes a single latent vector as one row instead of one value
    # per line; 2-D input is written unchanged
    np.savetxt(zfile, np.atleast_2d(zvalues))
    analysis.gen_volumes(f'{WORKDIR}/weights.{EPOCH}.pkl',
                         f'{WORKDIR}/config.yaml',
                         zfile,
                         f'{outdir}', **kwargs)
    return FileLinks(f'{outdir}/')

In [None]:
# Get a unique output directory, or define your own
# (get_outdir returns an unused reconstruct_NNNNNN name relative to the cwd)
outdir = get_outdir()
print(os.path.abspath(outdir))

In [None]:
# Modify any defaults for volume generation -- see `cryodrgn eval_vol -h` for details 
Apix = 1 # Set to volume pixel size
flip = False # Hand flip?
invert = False # Invert contrast?
downsample = None # Set to smaller box size if desired

# Only run eval_vol once indices have been added to `vol_ind`; with the default
# empty list it would be handed an empty z file.
if np.size(vol_ind) == 0:
    print("vol_ind is empty -- add indices to vol_ind to generate volumes")
else:
    # generate_volumes is no longer the cell's last expression, so display the
    # returned FileLinks explicitly
    from IPython.display import display
    display(generate_volumes(z[vol_ind], outdir, Apix=Apix, flip=flip,
                             downsample=downsample, invert=invert))