# CryoDRGN landscape analysis

This Jupyter notebook contains additional code cells for visualizing the results of `cryodrgn analyze_landscape` and `cryodrgn analyze_landscape_full`.

In [None]:
import numpy as np
import pickle
import subprocess
import os, sys

from cryodrgn.mrcfile import parse_mrc
from cryodrgn import analysis
from cryodrgn import utils
from cryodrgn import dataset
from cryodrgn import ctf
 
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objs as go
import plotly.offline as py
from ipywidgets import interact, interactive, HBox, VBox
from scipy.spatial.transform import Rotation as RR
py.init_notebook_mode()
from IPython.display import FileLink, FileLinks

from sklearn.decomposition import PCA
from sklearn.cluster import AgglomerativeClustering
from scipy.spatial.distance import cdist

### Set parameters

In [None]:
# User-supplied settings. Every value is a None placeholder; the cells below
# interpolate them into file paths, so fill them in before running anything else.
EPOCH = None  # Epoch to analyze (selects z.{EPOCH}.pkl and landscape.{EPOCH}); change me if necessary!
WORKDIR = None  # Directory with cryoDRGN outputs

K = None  # Number of sketched volumes (selects kmeans{K}/ and vol_pca_{K}.pkl)
M = None  # Number of clusters (part of the clustering directory name)
linkage = None  # Linkage method used for clustering (part of the clustering directory name)

In [None]:
# landscape analysis directory for the chosen epoch, located under WORKDIR
# NOTE: an unset (None) parameter is rendered as the text 'None' by the f-string,
# so a forgotten setting only surfaces later as a missing-file error.
landscape_dir = f'{WORKDIR}/landscape.{EPOCH}'

# subdirectories with clustering analysis and volume mapping
# (the clustering directory name encodes the linkage method and the cluster count M)
clustering_dir = f'{landscape_dir}/clustering_L2_{linkage}_{M}'
landscape_full_dir = f'{landscape_dir}/landscape_full'

### Load data

In [None]:
# z pickle for the chosen epoch (presumably the per-particle latent embeddings -- confirm)
z = utils.load_pkl(f'{WORKDIR}/z.{EPOCH}.pkl')
# PCA of z keeping as many components as z has columns (a rotation, not a reduction)
z_pc = PCA(z.shape[1]).fit_transform(z)
# UMAP embedding stored in the landscape directory; indexed by row like z_pc below
umap = utils.load_pkl(f'{landscape_dir}/umap.pkl')
# Integer indices from the k-means output; used below to pick rows of z_pc and umap
centers_ind = np.loadtxt(f'{landscape_dir}/kmeans{K}/centers_ind.txt').astype(int)

In [None]:
# Read the mask volume written to the landscape directory.
mask = parse_mrc(f'{landscape_dir}/mask.mrc')
# Keep only the first element returned by parse_mrc and turn it into a boolean array.
mask = mask[0].astype(bool)
# Report how many voxels are set in the mask versus the total voxel count.
print(f'{mask.sum()} out of {np.prod(mask.shape)} voxels included in mask')

In [None]:
# Volume PCA pickle for the K sketched volumes (rows are colored by `labels` in the plots below)
vol_pc = utils.load_pkl(f'{landscape_dir}/vol_pca_{K}.pkl')
# Volume PCA pickle from the landscape_full directory (plotted below as the background landscape)
vol_pc_all = utils.load_pkl(f'{landscape_full_dir}/vol_pca_all.pkl')

In [None]:
# k-means labels from the kmeans{K} directory (tallied per label in the "Cluster counts" section)
kmeans_labels = utils.load_pkl(f'{landscape_dir}/kmeans{K}/labels.pkl')
# State labels from the clustering directory; used below to color the sketched volumes
labels = utils.load_pkl(f'{clustering_dir}/state_labels.pkl')

### Optionally reanalyze volume data




In [None]:
# Load volumes
# Disabled by default: the code is wrapped in a string literal, so none of it is
# executed. Remove the surrounding ''' lines to enable it.
# NOTE(review): the paths inside are relative (kmeans{K}/...), whereas the loading
# cells prefix their paths with landscape_dir -- confirm the intended working
# directory before enabling.
'''
volm, _ = parse_mrc(f'kmeans{K}/vol_mean.mrc')
vols = np.array([parse_mrc(f'kmeans{K}/vol_{i:03d}.mrc')[0][mask] for i in range(K)])
vols.shape
vols[vols<0]=0
'''

In [None]:
# Rerun volume PCA
# Disabled by default (string literal). When enabled it fits a 50-component PCA
# on `vols` (created by the disabled "Load volumes" cell) and overwrites the
# `vol_pc` loaded from disk with the new projection.
'''
pca = PCA(50)
pca.fit(vols)
vol_pc = pca.transform(vols)
'''

In [None]:
# Save out volume PCs
# Disabled by default (string literal).
# NOTE(review): the bare `mkdir` has no `!` prefix, unlike the shell loop in the
# next cell -- presumably it relies on IPython's automagic; confirm before enabling.
'''
mkdir volume_pcs
'''

In [None]:
# Disabled by default (string literal). When enabled, the `!` line is a shell
# loop that creates volume_pcs/pc1 ... volume_pcs/pc5.
# NOTE(review): the `{1..5}` brace expansion depends on the shell in use -- confirm.
'''
!for i in {1..5}; do mkdir volume_pcs/pc$i; done
'''

In [None]:
# Save first 5 volume PCs
# Disabled by default (string literal); remove the surrounding ''' lines to enable.
# For each of the first five components, 10 evenly spaced amplitudes between the
# observed min and max of that component are added to the mean volume (inside the
# mask only) and each result is written as volume_pcs/pc{i+1}/{j}.mrc.
# Requires `volm`, `pca` and `vol_pc` from the disabled "Load volumes" and
# "Rerun volume PCA" cells, plus the volume_pcs/pc1..pc5 directories.
# The snippet previously referenced `pc` and `mrc`, neither of which is defined
# in this notebook; it now uses `vol_pc` and cryodrgn.mrcfile.write_mrc.
'''
from cryodrgn.mrcfile import write_mrc

for i in range(5):
    min_, max_ = vol_pc[:,i].min(), vol_pc[:,i].max()
    print(min_, max_)
    for j, a in enumerate(np.linspace(min_,max_,10,endpoint=True)):
        v = volm.copy()
        v[mask] += pca.components_[i]*a
        write_mrc(f'volume_pcs/pc{i+1}/{j}.mrc', v)
'''

In [None]:
# Rerun clustering
# Disabled by default (string literal). When enabled it overwrites `labels` with
# a fresh 10-cluster, average-linkage clustering of `vols`; the cluster count and
# linkage are hard-coded here rather than taken from the M / linkage settings.
# NOTE(review): newer scikit-learn releases replace the `affinity` argument with
# `metric` -- confirm against the installed version before enabling.
'''
cluster = AgglomerativeClustering(n_clusters=10, affinity='euclidean', linkage='average')
labels = cluster.fit_predict(vols)
'''

# Plotting

In [1]:
# Set to True to make each plotting cell below also write its figure to a PDF
# (relative file names, so the files land in the current working directory).
save_pdf = False

### Plot landscape

In [None]:
# Columns of vol_pc shown on the x and y axes.
i = 0
j = 1
# Joint plot (seaborn default kind) of the sketched volumes along the chosen pair.
sns.jointplot(
    x=vol_pc[:, i],
    y=vol_pc[:, j],
)
if save_pdf:
    plt.savefig('volpca_sketch.pdf')

In [None]:
# Columns of vol_pc_all shown on the x and y axes.
i = 0
j = 1
# Hexbin density of the full volume-PC landscape.
g = sns.jointplot(
    x=vol_pc_all[:, i],
    y=vol_pc_all[:, j],
    kind='hex',
    height=8,
)
# Shrink the plot area so a colorbar fits to its right.
plt.subplots_adjust(left=0.2, right=0.8, top=0.8, bottom=0.2)
# Dedicated axes for the colorbar, given as [x, y, width, height].
cbar_ax = g.fig.add_axes([0.85, 0.25, 0.03, 0.4])
plt.colorbar(cax=cbar_ax)
if save_pdf:
    plt.savefig('volpca_landscape.pdf')

In [None]:
# Plot landscape -- energy scale
# Same hexbin view as the previous cell, but with log-scaled bins, a reversed
# 'jet' colormap, and only bins holding at least one point drawn.
i = 0
j = 1
g = sns.jointplot(
    x=vol_pc_all[:, i],
    y=vol_pc_all[:, j],
    kind='hex',
    height=8,
    cmap='jet_r',
    bins='log',
    mincnt=1,
)

# Shrink the plot area so a colorbar fits to its right.
plt.subplots_adjust(left=0.2, right=0.8, top=0.8, bottom=0.2)
# Dedicated axes for the colorbar, given as [x, y, width, height].
cbar_ax = g.fig.add_axes([0.85, 0.25, 0.03, 0.4])
plt.colorbar(cax=cbar_ax)
if save_pdf:
    plt.savefig('volpca_landscape_energy.pdf')

### Plot clusters

In [None]:
# Columns of vol_pc shown on the x and y axes.
i = 0
j = 1
# Sketched volumes colored by state label; state k gets matplotlib color 'Ck'.
g = sns.jointplot(
    x=vol_pc[:, i],
    y=vol_pc[:, j],
    hue=labels,
    palette={state: f'C{state}' for state in range(M)},
)
if save_pdf:
    plt.savefig('volpca_clusters.pdf')

In [None]:
# Columns of the volume-PC arrays shown on the x and y axes.
i = 0
j = 1
# Background: every point of the full landscape as faint grey dots.
g = sns.jointplot(
    x=vol_pc_all[:, i],
    y=vol_pc_all[:, j],
    kind='scatter',
    color='lightgrey',
    s=1,
    alpha=0.1,
    rasterized=True,
)
# Foreground: the sketched volumes, colored by state label.
g.ax_joint.scatter(
    x=vol_pc[:, i],
    y=vol_pc[:, j],
    c=labels,
    cmap='tab10',
    s=25,
    edgecolor='white',
    linewidths=0.25,
)
if save_pdf:
    plt.savefig('volpca_clusters_all.pdf')

### Plot latent space

In [None]:
# Columns of z_pc shown on the x and y axes.
i = 0
j = 1
cmap = 'tab10'  # change if M != 10
# Background: all rows of the latent-space PCA as faint grey dots.
g = sns.jointplot(
    x=z_pc[:, i],
    y=z_pc[:, j],
    kind='scatter',
    color='lightgrey',
    s=1,
    alpha=0.2,
    rasterized=True,
)
# Foreground: the rows selected by centers_ind, colored by state label.
g.ax_joint.scatter(
    x=z_pc[centers_ind, i],
    y=z_pc[centers_ind, j],
    c=labels,
    cmap=cmap,
    s=25,
    edgecolor='white',
    linewidths=0.25,
)
if save_pdf:
    plt.savefig('zpca_clusters.pdf')

In [None]:
# Columns of the UMAP embedding shown on the x and y axes.
i = 0
j = 1
cmap = 'tab10'  # change if M != 10
# Background: all rows of the UMAP embedding as faint grey dots.
g = sns.jointplot(
    x=umap[:, i],
    y=umap[:, j],
    kind='scatter',
    color='lightgrey',
    s=1,
    alpha=0.2,
    rasterized=True,
)
# Foreground: the rows selected by centers_ind, colored by state label.
g.ax_joint.scatter(
    x=umap[centers_ind, i],
    y=umap[centers_ind, j],
    c=labels,
    cmap=cmap,
    s=25,
    edgecolor='white',
    linewidths=0.25,
)
if save_pdf:
    plt.savefig('umap_clusters.pdf')

### Cluster counts

In [None]:
from collections import Counter
# Number of entries of `labels` carrying each state label
counts = Counter(labels)
# Number of entries of `kmeans_labels` carrying each k-means label
kmeans_counts = Counter(kmeans_labels)
# Overwrites the user-set M with the number of distinct state labels actually present
M = len(counts)

In [None]:
# For each state, add up the k-means label counts of every position in `labels`
# that is assigned to that state.
particle_counts = [
    np.sum([kmeans_counts[member] for member in np.where(labels == state)[0]])
    for state in range(M)
]

In [None]:
# Bar chart of how many entries of `labels` fall into each state.
plt.subplots(figsize=(6,5))
x = np.arange(M)
y = [counts[i] for i in range(M)]
# x and y must be passed by keyword: seaborn >= 0.12 raises a TypeError when
# they are given positionally.
g = sns.barplot(x=x, y=y)
# Write each bar's value just above it; iterating over the plotted data keeps
# the annotations in step with the bars.
for i, n_volumes in enumerate(y):
    g.text(i-.2, n_volumes+3, n_volumes)
plt.xlabel('State')
plt.ylabel('Volume count')
plt.tight_layout()
if save_pdf:
    plt.savefig('volume_counts.pdf')

In [None]:
# Bar chart of the per-state totals held in `particle_counts`.
plt.subplots(figsize=(6,5))
x = np.arange(M)
y = particle_counts
# x and y must be passed by keyword: seaborn >= 0.12 raises a TypeError when
# they are given positionally.
g = sns.barplot(x=x, y=y)
# Write each bar's value just above it; iterating over the plotted data keeps
# the annotations in step with the bars.
for i, n_particles in enumerate(y):
    g.text(i-.45, n_particles+1000, n_particles)
plt.xlabel('State')
plt.ylabel('Particle count')
plt.tight_layout()
if save_pdf:
    plt.savefig('particle_counts.pdf')