In [None]:
import pandas as pd
import numpy as np
import pickle
import subprocess
import os, sys
sys.path.insert(0,f'{os.environ["CDRGN_SRC"]}/lib-python')
import analysis
import utils
import dataset

import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objs as go
import plotly.offline as py
from ipywidgets import interact, interactive, HBox, VBox
from scipy.spatial.transform import Rotation as RR
py.init_notebook_mode()
from IPython.display import FileLink, FileLinks

### Load results

In [None]:
# Specify the workdir and the epoch number to analyze
# WORKDIR is the directory the z/umap/kmeans/pose/config files are read from;
# EPOCH selects which epoch's files are used (it is part of every filename).
WORKDIR = '.' 
EPOCH = 49 # CHANGE ME

In [None]:
# Load z
# The pickle file stores two consecutive objects: z first, then z_logvar.
# NOTE: unpickling can execute arbitrary code -- only open files you trust.
with open(f'{WORKDIR}/z.{EPOCH}.pkl', 'rb') as f:
    z, z_logvar = [pickle.load(f) for _ in range(2)]

In [None]:
# Load umap
# Precomputed embedding for this epoch; later cells plot its first two columns.
umap = utils.load_pkl(f'{WORKDIR}/umap.{EPOCH}.pkl')

In [None]:
# Load kmeans
# Labels (one per data point) and cluster centers saved for this epoch.
kmeans_labels = utils.load_pkl(f'{WORKDIR}/kmeans.{EPOCH}.labels.pkl')
kmeans_centers = np.loadtxt(f'{WORKDIR}/kmeans.{EPOCH}.centers.txt')
# Or run kmeans
#kmeans_labels, kmeans_centers = analysis.cluster_kmeans(z, 20)

# Index of the data point nearest to each cluster center. Later cells pass
# centers_i to the plotting helpers, so it must be defined whether the
# clustering was loaded from disk or recomputed above.
# (Indexing the result avoids rebinding IPython's special `_` variable.)
centers_i = analysis.get_nearest_point(z, kmeans_centers)[1]


In [None]:
# Load poses and convert rotation matrices to euler angles
pose_pkl = f'{WORKDIR}/pose.{EPOCH}.pkl'
if os.path.exists(pose_pkl):
    # Poses saved for this epoch: the file holds rotations, then translations.
    with open(pose_pkl,'rb') as f:
        rot = pickle.load(f)
        trans = pickle.load(f)
else:
    # Fall back to the pose file(s) recorded in the run configuration.
    config = utils.load_pkl(f'{WORKDIR}/config.pkl')
    pose_pkl = config['dataset_args']['poses']
    if len(pose_pkl) == 2:
        # Two entries: separate rotation and translation files.
        rot, trans = utils.load_pkl(pose_pkl[0]), utils.load_pkl(pose_pkl[1])
    else:
        # Single file holding a (rot, trans) pair.
        (rot,trans) = utils.load_pkl(pose_pkl)

# Rotation.from_dcm was renamed to from_matrix in SciPy 1.4 and removed in
# SciPy 1.6; prefer the new name and fall back only on old installations.
_rot_from_matrix = getattr(RR, 'from_matrix', None) or RR.from_dcm
# One (X, Y, X) euler triple per rotation matrix, converted from radians to degrees.
euler = np.array([_rot_from_matrix(rr).as_euler('XYX') for rr in rot])*180/np.pi

In [None]:
# Load input particles
# Read the run configuration here explicitly: the pose-loading cell only
# binds `config` on its fallback branch, so relying on it would raise a
# NameError whenever a per-epoch pose file exists.
config = utils.load_pkl(f'{WORKDIR}/config.pkl')
particles = dataset.load_particles(config['dataset_args']['particles'],
                            lazy=True,
                            datadir=config['dataset_args']['datadir'])

### PCA

In [None]:
# Run PCA on the latent encodings and scatter the first two components.
# Each axis label carries that component's explained-variance ratio.
pc, pca = analysis.run_pca(z)
plt.scatter(pc[:, 0], pc[:, 1], s=1, alpha=.1)
plt.xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2f})')
plt.ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2f})')

### Pose distribution

In [None]:
# rotations
# Plot the three euler-angle columns (in degrees) computed from the poses.
analysis.plot_euler(euler[:,0],euler[:,1], euler[:,2])

In [None]:
# translations
# Hexbin joint distribution of the first two translation columns.
# The data is passed as x=/y= keywords because seaborn >= 0.12 no longer
# accepts the vectors positionally; keywords work on older versions too.
sns.jointplot(x=trans[:,0], y=trans[:,1],
              kind='hex').set_axis_labels('tx','ty')

### View UMAP

In [None]:
# Scatter of every point in the first two UMAP dimensions; tiny, mostly
# transparent markers keep dense regions readable.
plt.scatter(umap[:,0], umap[:,1], alpha=.1, s=1)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')

### View K-means clusters

In [None]:
# Number of clusters = number of distinct labels
K = len(set(kmeans_labels))
c = pca.transform(kmeans_centers) # transform with PC basis
# Plot PC1 vs PC2 grouped by cluster label, with the projected centers annotated
analysis.plot_by_cluster(pc[:,0], pc[:,1], K, 
                         kmeans_labels, 
                         centers=c,
                         annotate=True)
plt.xlabel('PC1')
plt.ylabel('PC2')

In [None]:
# Same PC1/PC2 data passed to the subplot variant of the cluster plot
# (presumably one panel per cluster -- confirm in analysis.plot_by_cluster_subplot).
fig, ax = analysis.plot_by_cluster_subplot(pc[:,0], pc[:,1], K, 
                            kmeans_labels)

In [None]:
# use nearest point as center so we don't have to rerun umap
# NOTE(review): centers_i is expected to hold, for each k-means center, the
# index of the nearest data point; it must already exist in the kernel
# before this cell runs.
analysis.plot_by_cluster(umap[:,0], umap[:,1], K, 
                         kmeans_labels, 
                         centers_i=centers_i,
                         annotate=True)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')

In [None]:
# UMAP coordinates passed to the subplot variant of the cluster plot.
# Note: this rebinds the fig/ax names created by the PCA subplot cell.
fig, ax = analysis.plot_by_cluster_subplot(umap[:,0], umap[:,1], K, 
                            kmeans_labels)

# Interactive visualization

In [None]:
# Load data into a pandas dataframe
# Combines the latent encodings, principal components, euler angles,
# translations, k-means labels and UMAP coordinates loaded above.
df = analysis.load_dataframe(z=z, 
                             pc=pc, 
                             euler=euler, 
                             trans=trans, 
                             labels=kmeans_labels, 
                             umap=umap)
# Preview the first rows
df.head()

In [None]:
# Interactive plot of the dataframe with the points listed in centers_i passed
# for annotation (presumably the particles nearest each k-means center -- confirm).
widget, fig = analysis.ipy_plot_interactive_annotate(df,centers_i)
# Stack the controls above the figure
VBox((widget,fig))

## Interactive selection
1. Select points with the lasso tool. The table is dynamically updated with the most recent selection's indices. 
2. Save the selected indices for downstream analysis and visualization (see the next cell).

(Double click to clear selection)

In [None]:
# Interactive selection plot over the full dataframe. ind_table is kept
# because the next cell reads the selected indices from it.
widget, fig, ind_table = analysis.ipy_plot_interactive(df)
VBox((widget,fig,ind_table))

In [None]:
# Indices currently listed in the selection table (first column of its cells).
ind_selected = ind_table.data[0].cells.values[0]
# Force an integer dtype so that an empty selection still yields a valid index array.
ind_selected = np.asarray(ind_selected, dtype=int)
# Complement of the selection, taken over the dataframe's own index instead
# of a hard-coded particle count so it is correct for any dataset size.
# np.setdiff1d returns the sorted unique values of the first array that are
# not in the second.
ind_selected_not = np.setdiff1d(df.index.to_numpy(), ind_selected)

print('Selected indices:')
print(ind_selected)
print('Number of selected points:')
print(len(ind_selected))
print('Number of unselected points:')
print(len(ind_selected_not))

### Visualize selected subset

In [None]:
# Subset of dataframe
# .loc is label-based: rows are picked by index label, for the selected
# points and for their complement respectively.
df_sub = df.loc[ind_selected]
df_sub_not = df.loc[ind_selected_not]

In [None]:
# View pose distribution
# Uses the theta/phi/psi columns of the selected subset (presumably the
# euler angles given to load_dataframe -- confirm the column mapping).
analysis.plot_euler(df_sub.theta, df_sub.phi, df_sub.psi)

In [None]:
# Interactive selection plot restricted to the selected subset.
# Note: this rebinds widget, fig and ind_table, so re-running the
# index-extraction cell afterwards reads selections made in THIS plot.
widget, fig, ind_table = analysis.ipy_plot_interactive(df_sub)
VBox((widget,fig,ind_table))

# View particles

In [None]:
# Indices of the particles to preview in the next cell.
particle_ind = ind_selected # or set to new selection

In [None]:
# Preview at most nine particles: draw nine at random (without replacement)
# when more are available, otherwise use every index in particle_ind.
ind_subset9 = (np.random.choice(particle_ind, 9, replace=False)
               if len(particle_ind) > 9
               else particle_ind)
# Fetch the image of each chosen particle and plot them
p = [particles[idx].get() for idx in ind_subset9]
analysis.plot_projections(p, ind_subset9)
# Show the same particles annotated in the interactive plot
widget, fig = analysis.ipy_plot_interactive_annotate(df, ind_subset9)
VBox((widget, fig))

# Generate volumes

In [None]:
# choose z indices
vol_ind = [] 

In [None]:
# Interactive plot with the points chosen in vol_ind passed for annotation,
# to check the selection before generating volumes.
widget, fig = analysis.ipy_plot_interactive_annotate(df,vol_ind)
VBox((widget,fig))

In [None]:
def get_outdir():
    """Return the first unused reconstruction directory path for this epoch.

    Candidates are ``{WORKDIR}/reconstruct.{EPOCH}_{i:06d}`` for i = 0, 1, ...
    and the first one that does not exist on disk is returned. The directory
    itself is not created here.

    Raises:
        RuntimeError: if all 100000 candidate names already exist.
    """
    for i in range(100000):
        outdir = f'{WORKDIR}/reconstruct.{EPOCH}_{i:06d}'
        # os.path.exists has to be called: the bare function object is always
        # truthy, which would make every candidate look "taken".
        if not os.path.exists(outdir):
            return outdir
    raise RuntimeError(
        f'No free reconstruct.{EPOCH}_* output directory left in {WORKDIR}')

def generate_volumes(zvalues, outdir):
    """Save ``zvalues`` to ``outdir`` and run volume generation on them.

    Args:
        zvalues: latent values to generate volumes for; written with
            ``np.savetxt`` to ``{outdir}/zfile.txt``.
        outdir: output directory; created if it does not exist yet.

    Returns:
        The ``FileLinks`` object for ``outdir``. Returning it (instead of
        discarding it) lets a notebook cell that ends with this call render
        the links to the generated files.
    """
    os.makedirs(outdir, exist_ok=True)
    zfile = f'{outdir}/zfile.txt'
    np.savetxt(zfile, zvalues)
    # Use the weights of the epoch being analyzed (EPOCH); the previous
    # f-string referenced an undefined loop variable.
    analysis.gen_volumes(f'{WORKDIR}/weights.{EPOCH}.pkl',
                         f'{WORKDIR}/config.pkl',
                         zfile,
                         f'{outdir}')
    return FileLinks(f'{outdir}/')

In [None]:
# Pick the output directory for this reconstruction run and show it.
outdir = get_outdir()
print(outdir)

In [None]:
# Generate volumes for the latent values selected by vol_ind, writing into outdir.
generate_volumes(z[vol_ind], outdir)