# CryoDRGN visualization and analysis

This Jupyter notebook provides a template for analyzing cryoDRGN results, including:
* latent space visualization with PCA/UMAP
* clustering
* interactive visualization of the latent space, imaging, and pose parameters
* interactive selection of particle images from the latent space
* interactive generation of volumes from the latent space

Note that this is a simple template for data analysis, and not a polished UI. Experience with Python/Pandas is recommended.

This notebook assumes that the latent variable dimension is > 1, since several analyses below (e.g. PCA, UMAP, and the multidimensional scatter plots) require more than one latent dimension.

In [None]:
import pandas as pd
import numpy as np
import pickle
import subprocess
import os, sys
# cryoDRGN's helper modules (analysis, utils, dataset, ctf) are imported from
# the lib-python directory of the source tree named by the CDRGN_SRC
# environment variable; this line raises KeyError if CDRGN_SRC is not set.
# It must run before the four imports below.
sys.path.insert(0,f'{os.environ["CDRGN_SRC"]}/lib-python')
import analysis
import utils
import dataset
import ctf

# Plotting and interactive widgets
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objs as go
import plotly.offline as py
from ipywidgets import interact, interactive, HBox, VBox
from scipy.spatial.transform import Rotation as RR
py.init_notebook_mode()  # set up plotly.offline for inline rendering in this notebook
from IPython.display import FileLink, FileLinks

### Load results

In [None]:
# Training output directory and the (0-based) epoch whose outputs are analyzed.
# Both are used below to build paths such as f'{WORKDIR}/z.{EPOCH}.pkl'.
WORKDIR = '.'   # CHANGE ME if the notebook is not run from the workdir
EPOCH = 49      # CHANGE ME

In [None]:
# Show the absolute path of the working directory being analyzed
print(os.path.abspath(WORKDIR))

In [None]:
# Load the latent encodings. The pickle holds two objects written back to
# back, z followed by z_logvar, so they are read in that order (tuple
# elements are evaluated left to right).
with open(f'{WORKDIR}/z.{EPOCH}.pkl', mode='rb') as f:
    z, z_logvar = pickle.load(f), pickle.load(f)

In [None]:
# Load the UMAP embedding saved for this epoch...
umap = utils.load_pkl('{}/umap.{}.pkl'.format(WORKDIR, EPOCH))
# ...or recompute it from z instead:
# umap = analysis.run_umap(z)

In [None]:
# Load the k-means labels and centers saved for this epoch...
kmeans_labels = utils.load_pkl('{}/kmeans.labels.{}.pkl'.format(WORKDIR, EPOCH))
kmeans_centers = np.loadtxt('{}/kmeans.centers.{}.txt'.format(WORKDIR, EPOCH))
# ...or re-run k-means with the desired number of classes instead:
#kmeans_labels, kmeans_centers = analysis.cluster_kmeans(z, 20)

# Replace the centers with on-data points from z; centers_ind is presumably
# the row index in z of each such point (used below to annotate plots).
kmeans_centers, centers_ind = analysis.get_nearest_point(z, kmeans_centers)

### Load dataset

In [None]:
# Load the training configuration and show it
config = utils.load_pkl('{}/config.pkl'.format(WORKDIR))
print(config)

In [None]:
# Load poses: the input poses named in the training config, or, when
# do_pose_sgd was enabled, the poses saved at this epoch (pose.{EPOCH}.pkl).
if not config['dataset_args']['do_pose_sgd']:
    pose_pkl = config['dataset_args']['poses']
    rot, trans = utils.load_pkl(pose_pkl)
else:
    pose_pkl = '{}/pose.{}.pkl'.format(WORKDIR, EPOCH)
    with open(pose_pkl, 'rb') as f:
        rot, trans = pickle.load(f)

In [None]:
# Convert rotation matrices to euler angles (ZYZ convention, in degrees).
# Rotation.from_dcm was deprecated in SciPy 1.4 and removed in SciPy 1.6 in
# favour of Rotation.from_matrix; use whichever this SciPy version provides.
_rot_from_matrix = getattr(RR, 'from_matrix', None) or RR.from_dcm
euler = _rot_from_matrix(rot).as_euler('zyz', degrees=True)

In [None]:
# Load the index filter from the training config (if any) and restrict the
# poses to that subset. The length check makes re-running this cell a no-op
# once rot/trans/euler have already been filtered.
ind = config['dataset_args']['ind']
if ind is not None:
    ind = utils.load_pkl(ind)
    if len(rot) > len(ind):
        print(f'Filtering poses from {len(rot)} to {len(ind)}')
        rot, trans, euler = rot[ind], trans[ind], euler[ind]

In [None]:
# Lazily load the input particle stack, then apply the same index filter
particles = dataset.load_particles(
    config['dataset_args']['particles'],
    lazy=True,
    datadir=config['dataset_args']['datadir'],
)
if ind is not None:
    print(f'Filtering particles from {len(particles)} to {len(ind)}')
    particles = list(map(particles.__getitem__, ind))

In [None]:
# Load the CTF parameters and apply the same index filter as the particles
ctf_params = utils.load_pkl(config['dataset_args']['ctf'])
if ind is not None:
    print('Filtering ctf parameters from {} to {}'.format(len(ctf_params), len(ind)))
    ctf_params = ctf_params[ind]
# Print the parameters of the first particle as a quick sanity check
ctf.print_ctf_params(ctf_params[0])

### PCA

In [None]:
# PCA of the latent encodings; axis labels show each PC's explained variance ratio
pc, pca = analysis.run_pca(z)
plt.scatter(pc[:, 0], pc[:, 1], s=1, alpha=0.1)
plt.xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2f})')
plt.ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2f})')

### View pose distribution

In [None]:
# rotations: pass the three euler-angle columns as separate arguments
analysis.plot_euler(*(euler[:, k] for k in range(3)))

In [None]:
# translations: hexbin joint distribution of the x/y shifts.
# x and y are passed by keyword because seaborn >= 0.12 no longer accepts
# them positionally (the only positional argument is `data`).
sns.jointplot(x=trans[:,0], y=trans[:,1],
              kind='hex').set_axis_labels('tx','ty')

### View UMAP

In [None]:
# UMAP embedding of the latent encodings
plt.scatter(umap[:, 0], umap[:, 1], s=1, alpha=0.1)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')

### View K-means clusters

In [None]:
# Number of distinct k-means labels present
K = len(set(kmeans_labels))
# Project the cluster centers into PC space so they overlay the PCA plot
c = pca.transform(kmeans_centers)
analysis.plot_by_cluster(
    pc[:, 0], pc[:, 1], K, kmeans_labels,
    annotate=True,
    centers=c,
)
plt.xlabel('PC1')
plt.ylabel('PC2')

In [None]:
# One panel per k-means cluster, in PC space
fig, ax = analysis.plot_by_cluster_subplot(
    pc[:, 0], pc[:, 1], K, kmeans_labels)

In [None]:
# k-means clusters in UMAP space; centers are marked via their on-data indices
analysis.plot_by_cluster(
    umap[:, 0], umap[:, 1], K, kmeans_labels,
    annotate=True,
    centers_ind=centers_ind,
)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')

In [None]:
# One panel per k-means cluster, in UMAP space
fig, ax = analysis.plot_by_cluster_subplot(
    umap[:, 0], umap[:, 1], K, kmeans_labels)

# Interactive visualization

Interactive visualization of the latent encodings for the trained model. Each point represents a particle image of the dataset. The hover text includes the index of the image in the particle stack (if an index filter was used, this is the index within the filtered stack).

### Load into pandas dataframe

In [None]:
# Collect the per-particle quantities into one pandas dataframe. The CTF
# columns are taken from fixed column positions of ctf_params.
df = analysis.load_dataframe(
    z=z,
    pc=pc,
    euler=euler,
    trans=trans,
    labels=kmeans_labels,
    umap=umap,
    **{name: ctf_params[:, col]
       for name, col in (('df1', 2), ('df2', 3), ('dfang', 4), ('phase', 8))},
)
df.head()

In [None]:
# Interactive scatter; the annotated points are the k-means cluster centers
widget, fig = analysis.ipy_plot_interactive_annotate(df, centers_ind)
VBox((widget, fig))

# Interactive selection

The next two cells contain helper code to select particles using an interactive lasso tool. 

1. In the first cell, select points with the lasso tool. The table widget is dynamically updated with the most recent selection's indices. 
2. Once you've finalized your selection, run the next cell to store the selected particle indices for downstream analysis and visualization (they can optionally be saved to a .pkl file in the cell after it).

(Double-click the plot to clear the selection.)

In [None]:
# Lasso-selectable scatter; ind_table shows the indices of the latest selection
widget, fig, ind_table = analysis.ipy_plot_interactive(df)
VBox((widget, fig, ind_table))

In [None]:
# Read the indices of the most recent lasso selection from the table widget.
# dtype=int keeps the array usable for fancy indexing (pc[ind_selected, 0],
# df.loc[...]) even when nothing is selected; np.array([]) defaults to float64.
ind_selected = np.array(ind_table.data[0].cells.values[0], dtype=int)
# Sorted, unique indices of all rows of df that were not selected.
ind_selected_not = np.setdiff1d(np.arange(len(df)), ind_selected)

print('Selected indices:')
print(ind_selected)
print('Number of selected points:')
print(len(ind_selected))
print('Number of unselected points:')
print(len(ind_selected_not))

In [None]:
# Optionally save the selected indices to a .pkl file: set SAVE_PATH to a
# path (e.g. f'{WORKDIR}/ind_selected.pkl') to enable. It defaults to None so
# that running all cells does not write any files.
SAVE_PATH = None # CHANGE ME
if SAVE_PATH is not None:
    with open(SAVE_PATH, 'wb') as f:
        pickle.dump(ind_selected, f)

### Visualize selected subset

In [None]:
# View PCA: all particles, with the selected subset drawn on top
plt.scatter(pc[:, 0], pc[:, 1], s=1, alpha=0.1)
plt.scatter(pc[ind_selected, 0], pc[ind_selected, 1], s=1, alpha=0.1)
plt.xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2f})')
plt.ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2f})')

In [None]:
# View UMAP: all particles, with the selected subset drawn on top
plt.scatter(umap[:, 0], umap[:, 1], s=1, alpha=0.1)
plt.scatter(umap[ind_selected, 0], umap[ind_selected, 1], s=1, alpha=0.1)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')

In [None]:
# Split the dataframe into the selected and the unselected rows
df_sub, df_sub_not = df.loc[ind_selected], df.loc[ind_selected_not]

In [None]:
# View the pose distribution of the selected particles only
analysis.plot_euler(df_sub['theta'], df_sub['phi'], df_sub['psi'])

In [None]:
# Interactive plot of the selected subset only.
# NOTE: this rebinds ind_table, so re-running the selection-reading cell above
# after this one would read this plot's selection instead.
widget, fig, ind_table = analysis.ipy_plot_interactive(df_sub)
VBox((widget, fig, ind_table))

# View particles

View images at selected points in latent space

In [None]:
# Particle indices whose images are shown below; defaults to the lasso selection
particle_ind = ind_selected # or set to custom selection

In [None]:
# Choose up to 9 particles to view: a random sample (without replacement)
# when more than 9 are available, otherwise all of them.
ind_subset9 = (np.random.choice(particle_ind, 9, replace=False)
               if len(particle_ind) > 9
               else particle_ind)
print(ind_subset9)

In [None]:
# Load and plot the chosen images, then mark them in the latent-space plot
p = [particles[idx].get() for idx in ind_subset9]
analysis.plot_projections(p, ind_subset9)
widget, fig = analysis.ipy_plot_interactive_annotate(df, ind_subset9, opacity=0.1)
VBox((widget, fig))

# Generate volumes

Generate volumes at selected points in latent space

In [None]:
# Indices into z at which to generate volumes, e.g. [0, 10, 20] or
# list(centers_ind) for the k-means centers.
vol_ind = list() # ADD INDICES HERE
print(vol_ind)

In [None]:
# Mark the points chosen for volume generation in the latent-space plot
widget, fig = analysis.ipy_plot_interactive_annotate(df, vol_ind, opacity=0.1)
VBox((widget, fig))

In [None]:
def get_outdir(workdir=None, epoch=None):
    '''Helper function to get a clean directory to save volumes

    Returns the first path of the form
    '{workdir}/reconstruct.{epoch}_{i:06d}' (i = 0, 1, ...) that does not
    exist yet. The directory itself is not created here.

    workdir, epoch: default to the notebook-level WORKDIR and EPOCH.

    Raises RuntimeError if all 100000 candidate names are taken, rather than
    returning an existing directory whose contents would be overwritten.
    '''
    if workdir is None:
        workdir = WORKDIR
    if epoch is None:
        epoch = EPOCH
    for i in range(100000):
        outdir = f'{workdir}/reconstruct.{epoch}_{i:06d}'
        if not os.path.exists(outdir):
            return outdir
    raise RuntimeError(f'No unused reconstruct.{epoch}_NNNNNN directory left in {workdir}')

def generate_volumes(zvalues, outdir):
    '''Helper function to call eval_decoder.py and generate new volumes

    zvalues: latent value(s) to decode. A single 1-D z vector is treated as
        one point (one row of zfile.txt), not as a column of scalars.
    outdir: directory that receives zfile.txt and the generated volumes;
        created (including missing parents) if it does not exist.

    Returns an IPython FileLinks object listing the contents of outdir.
    '''
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race and also
    # creates missing parent directories.
    os.makedirs(outdir, exist_ok=True)
    zfile = f'{outdir}/zfile.txt'
    # np.savetxt writes a 1-D array one value per line; promote to 2-D so a
    # single z vector is written as a single row.
    np.savetxt(zfile, np.atleast_2d(zvalues))
    analysis.gen_volumes(f'{WORKDIR}/weights.{EPOCH}.pkl',
                         f'{WORKDIR}/config.pkl',
                         zfile,
                         str(outdir))
    return FileLinks(f'{outdir}/')

In [None]:
# Pick an unused output directory for the generated volumes (get_outdir only
# returns the path; the directory is created when volumes are generated)
outdir = get_outdir()
print(os.path.abspath(outdir))

In [None]:
# Decode the latent values at vol_ind into volumes. vol_ind is empty by
# default, so skip the decoder call (and the empty output directory that
# generate_volumes would create) until indices have been added above; this
# keeps "Run All" free of side effects. np.size also accepts a single int index.
from IPython.display import display
if np.size(vol_ind) == 0:
    print('vol_ind is empty: add indices to vol_ind above to generate volumes')
else:
    display(generate_volumes(z[vol_ind], outdir))