# CryoDRGN landscape analysis

This Jupyter notebook contains additional cells for visualizing the results of `cryodrgn analyze_landscape` and `cryodrgn analyze_landscape_full`.

In [None]:
import numpy as np
import pickle
import subprocess
import os, sys

from cryodrgn.mrcfile import parse_mrc, write_mrc
from cryodrgn import analysis
from cryodrgn import utils
from cryodrgn import dataset
from cryodrgn import ctf
 
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objs as go
import plotly.offline as py
from ipywidgets import interact, interactive, HBox, VBox
from scipy.spatial.transform import Rotation as RR
py.init_notebook_mode()
from IPython.display import FileLink, FileLinks

from sklearn.decomposition import PCA
from sklearn.cluster import AgglomerativeClustering
from scipy.spatial.distance import cdist

import matplotlib.ticker as ticker
import matplotlib.patches as mpatches
import matplotlib.gridspec as gridspec

### Set parameters

In [None]:
# Analysis parameters -- fill these in before running the rest of the notebook.
# NOTE(review): each parameter is a literal `NAME = None` line; if a tool fills
# these in by text substitution, keep exactly this form -- confirm.
EPOCH = None  # change me if necessary! (suffix used in `z.{EPOCH}.pkl` and `landscape.{EPOCH}`)
WORKDIR = None  # Directory with cryoDRGN outputs

K = None  # Number of sketched volumes (selects `kmeans{K}` and `vol_pca_{K}.pkl`)
M = None  # Number of clusters (part of the `sketch_clustering_{linkage}_{M}` dir name)
linkage = None  # Linkage method used for clustering (also part of that dir name)

In [None]:
# Root directory of the landscape analysis for this epoch
landscape_dir = '{}/landscape.{}'.format(WORKDIR, EPOCH)

# Subdirectories with the sketch clustering results and the
# full-dataset volume mapping, respectively
clustering_dir = os.path.join(
    landscape_dir,
    "sketch_clustering_{}_{}".format(linkage, M),
)
landscape_full_dir = os.path.join(landscape_dir, "landscape_full")

### Load data

In [None]:
# Latent embeddings for this epoch, projected onto all of their principal axes
z = utils.load_pkl('{}/z.{}.pkl'.format(WORKDIR, EPOCH))
z_pc = PCA(n_components=z.shape[1]).fit_transform(z)

# UMAP embedding, plus the row indices (into z_pc / umap) of the k-means centers
umap = utils.load_pkl('{}/umap.pkl'.format(landscape_dir))
centers_ind = np.loadtxt('{}/kmeans{}/centers_ind.txt'.format(landscape_dir, K)).astype(int)

In [None]:
# parse_mrc returns (data, header); keep only the data, as a boolean voxel mask
mask = parse_mrc(f'{landscape_dir}/mask.mrc')[0].astype(bool)
n_in_mask, n_voxels = mask.sum(), mask.size
print(f'{n_in_mask} out of {n_voxels} voxels included in mask')

In [None]:
# Volume-PCA coordinates of the sketched volumes, and the fitted PCA object
# (its explained_variance_ratio_ is used by the plots below)
vol_pc, vol_pca = (
    utils.load_pkl(f'{landscape_dir}/vol_pca_{K}.pkl'),
    utils.load_pkl(f'{landscape_dir}/vol_pca_obj.pkl'),
)
# Volume-PCA coordinates produced by the landscape_full analysis
vol_pc_all = utils.load_pkl(f'{landscape_full_dir}/vol_pca_all.pkl')

In [None]:
# k-means assignments (tallied per center in the cluster-count cells) and the
# state label of each sketched volume (one per row of vol_pc)
kmeans_labels, labels = (
    utils.load_pkl(f'{landscape_dir}/kmeans{K}/labels.pkl'),
    utils.load_pkl(f'{clustering_dir}/state_labels.pkl'),
)

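### Optionally reanalyze volume data

The code cells in this section are disabled: each one is wrapped in a triple-quoted string. Remove the quotes from a cell to run it.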




In [None]:
# Load volumes
# (Disabled -- remove the quotes to run.) Paths are resolved under
# landscape_dir, like the live loading cells, instead of the current
# working directory.
'''
volm, _ = parse_mrc(f'{landscape_dir}/kmeans{K}/vol_mean.mrc')
vols = np.array([parse_mrc(f'{landscape_dir}/kmeans{K}/vol_{i:03d}.mrc')[0][mask] for i in range(K)])
print(vols.shape)
vols[vols<0]=0
'''

In [None]:
# Rerun volume PCA
# (Disabled -- remove the quotes to run; needs `vols`.)
# n_components may not exceed min(n_samples, n_features), so cap it.
# The fitted object is also stored as `vol_pca`, which the scree plot and the
# EV axis labels read; otherwise they would keep describing the old PCA.
'''
pca = PCA(min(50, *vols.shape))
pca.fit(vols)
vol_pc = pca.transform(vols)
vol_pca = pca
'''

In [None]:
# Save out volume PCs
# (Disabled -- remove the quotes to run.) A bare `mkdir` is not valid Python
# in a code cell, so create the directory with os.makedirs.
'''
os.makedirs('volume_pcs', exist_ok=True)
'''

In [None]:
# (Disabled -- remove the quotes to run.) One output subdirectory per saved PC,
# created in Python so the cell does not depend on bash brace expansion.
'''
for i in range(1, 6):
    os.makedirs(f'volume_pcs/pc{i}', exist_ok=True)
'''

In [None]:
# Save first 5 volume PCs
# (Disabled -- remove the quotes to run; needs `volm`, `mask` and `pca` from
# the cells above.) For each PC, writes 10 volumes that step the mean volume
# along that component between the min and max PC coordinate of the sketched
# volumes. Uses `vol_pc` (the original referenced an undefined name `pc`).
'''
for i in range(5):
    min_, max_ = vol_pc[:,i].min(), vol_pc[:,i].max()
    print(min_, max_)
    for j, a in enumerate(np.linspace(min_,max_,10,endpoint=True)):
        v = volm.copy()
        v[mask] += pca.components_[i]*a
        write_mrc(f'volume_pcs/pc{i+1}/{j}.mrc', v)
'''

In [None]:
# Rerun clustering
# (Disabled -- remove the quotes to run; needs `vols`.) Uses the notebook's
# M and linkage parameters instead of hard-coded values. scikit-learn >= 1.2
# names the distance argument `metric` (`affinity` was removed in 1.4).
# fit_predict returns labels starting at 0; `+ 1` makes them start at 1,
# presumably matching the stored state_labels.pkl -- confirm.
'''
cluster = AgglomerativeClustering(n_clusters=M, metric='euclidean', linkage=linkage)
labels = cluster.fit_predict(vols) + 1
'''

# Plotting

In [1]:
# Set to True to also save each figure below as a PDF in the current directory
save_pdf = False

### Scree plot for volume PCA

In [None]:
# Scree plot: per-component explained variance (dashed) with the cumulative
# explained variance on a secondary y axis (grey)
ev_ratio = vol_pca.explained_variance_ratio_
pc_numbers = range(1, len(ev_ratio) + 1)

ev_percent = 100 * ev_ratio
cum_percent = 100 * np.cumsum(ev_ratio)

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(pc_numbers, ev_percent, marker='o', linestyle='--')
ax.set_xlabel('Principal Component')
ax.set_ylabel('Explained Variance Ratio (%)')
ax.set_title('Scree Plot for Volume PCA')
ax.set_xticks(pc_numbers)

ax_cum = ax.twinx()
ax_cum.plot(pc_numbers, cum_percent, marker='o', color='gray', linestyle='-')
ax_cum.set_ylabel('Cumulative Explained Variance (%)')

if save_pdf:
    plt.savefig('volpca_scree.pdf')

### Plot landscape

In [None]:
# Sketched volumes projected onto volume PCs i and j
i, j = 0, 1
sketch_x = vol_pc[:, i]
sketch_y = vol_pc[:, j]
sns.jointplot(x=sketch_x, y=sketch_y)

if save_pdf:
    plt.savefig('volpca_sketch.pdf')

In [None]:
# Hexbin density of the landscape_full volume-PCA coordinates on PCs i and j
i, j = 0,1
g = sns.jointplot(x=vol_pc_all[:,i], y=vol_pc_all[:,j], kind='hex', height=8)
plt.subplots_adjust(left=0.2, right=0.8, top=0.8, bottom=0.2)  # shrink fig so cbar is visible
# make new ax object for the cbar
cbar_ax = g.fig.add_axes([.85, .25, .03, .4])  # x, y, width, height
# Pass the hexbin collection explicitly instead of relying on pyplot's
# implicit current-image lookup, which depends on which axes are current.
plt.colorbar(g.ax_joint.collections[0], cax=cbar_ax)
if save_pdf:
    plt.savefig('volpca_landscape.pdf')

In [None]:
# Plot landscape -- energy scale
# Same hexbin density as above, with a reversed jet colormap and log-scaled
# bins; hexagons with no points are hidden (mincnt=1).
i, j = 0,1
g = sns.jointplot(x=vol_pc_all[:,i], y=vol_pc_all[:,j], kind='hex', height=8,
                  cmap='jet_r', bins='log', mincnt=1)

plt.subplots_adjust(left=0.2, right=0.8, top=0.8, bottom=0.2)  # shrink fig so cbar is visible
# make new ax object for the cbar
cbar_ax = g.fig.add_axes([.85, .25, .03, .4])  # x, y, width, height
# Pass the hexbin collection explicitly instead of relying on pyplot's
# implicit current-image lookup, which depends on which axes are current.
plt.colorbar(g.ax_joint.collections[0], cax=cbar_ax)
if save_pdf:
    plt.savefig('volpca_landscape_energy.pdf')

In [None]:
# Grid plot landscape -- energy scale
# Lower-triangular grid of pairwise hexbin densities over the first n_pcs
# volume PCs, all sharing one log-scaled colorbar.

# Set up the triangular grid layout
n_pcs = 5 # CHANGE ME IF NEEDED
fig = plt.figure(figsize=(15, 15))
gs = gridspec.GridSpec(n_pcs-1, n_pcs-1, wspace=0, hspace=0)

# Define the color map and color bar axis
cmap = 'jet_r'
cbar_ax = fig.add_axes([0.92, 0.25, 0.02, 0.5])  # Adjust position as needed

hexbins = []
# Panel in row i-1, column j shows PC j+1 (x) against PC i+1 (y)
for i in range(1, n_pcs):
    for j in range(i):
        ax = fig.add_subplot(gs[i-1, j])

        # Plot hexbin with color map and bins as log scale; mincnt=1 hides empty hexagons
        hb = ax.hexbin(vol_pc_all[:, j], vol_pc_all[:, i], gridsize=50, cmap=cmap, bins='log', mincnt=1)
        hexbins.append(hb)

        # Only set labels for leftmost and bottom plots
        if j == 0:
            ax.set_ylabel(f'Volume PC{i+1} (EV: {vol_pca.explained_variance_ratio_[i]:.0%})',
                          fontsize=14, fontweight='bold')

        if i == n_pcs-1:
            ax.set_xlabel(f'Volume PC{j+1} (EV: {vol_pca.explained_variance_ratio_[j]:.0%})',
                          fontsize=14, fontweight='bold')

        # Exact values are not needed
        ax.set_yticks([])
        ax.set_xticks([])

# Each hexbin autoscales its own color limits, so a single colorbar would only
# describe the last panel. Put every panel on the common range of their values.
vmin = min(h.get_array().min() for h in hexbins)
vmax = max(h.get_array().max() for h in hexbins)
for h in hexbins:
    h.set_clim(vmin, vmax)

plt.colorbar(hb, cax=cbar_ax, label='Log Density')

if save_pdf:
    plt.savefig(f'volpca_grid{n_pcs}_landscape_energy.pdf')

### Plot clusters

In [None]:
# Sketched volumes on volume PCs i and j, colored by state label.
# Colors are assigned by each state's rank among the sorted distinct labels
# (tab10, or tab20 when there are more than 10 states; they repeat past 20),
# not by the raw label value, so the mapping does not depend on whether the
# labels start at 0 or 1.
i, j = 0,1
states = np.unique(labels)
state_cmap = plt.get_cmap('tab10' if len(states) <= 10 else 'tab20')
g = sns.jointplot(
    x=vol_pc[:,i],
    y=vol_pc[:,j],
    hue=labels,
    palette={s: state_cmap(k % state_cmap.N) for k, s in enumerate(states)}
)
if save_pdf:
    plt.savefig('volpca_clusters.pdf')


In [None]:
# All landscape_full points (grey) with the sketched volumes overlaid and
# colored by state. Each state's color is indexed by its rank among the sorted
# labels (tab10, or tab20 for more than 10 states) rather than by normalizing
# the label values, which spreads states over non-adjacent colors when M != 10.
i, j = 0,1
states = np.unique(labels)
state_cmap = plt.get_cmap('tab10' if len(states) <= 10 else 'tab20')
state_colors = state_cmap(np.searchsorted(states, labels) % state_cmap.N)
g = sns.jointplot(x=vol_pc_all[:,i], y=vol_pc_all[:,j], kind='scatter', color='lightgrey', s=1, alpha=.1, rasterized=True)
g.ax_joint.scatter(x=vol_pc[:,i], y=vol_pc[:,j], c=state_colors, s=25, edgecolor='white', linewidths=.25)
if save_pdf:
    plt.savefig('volpca_clusters_all.pdf')

In [None]:
# Lower-triangular grid over the first n_pcs volume PCs: landscape_full points
# in grey with the sketched volumes overlaid and colored by state.

# Set up the triangular grid layout
n_pcs = 5  # CHANGE ME if needed
fig = plt.figure(figsize=(15, 15))
gs = gridspec.GridSpec(n_pcs-1, n_pcs-1, wspace=0, hspace=0)

# One color per state, indexed by the state's rank among the sorted labels
# (tab10, or tab20 for more than 10 states). Normalizing raw label values onto
# tab20 instead mixes its dark and light halves.
states = np.unique(labels)
state_cmap = plt.get_cmap('tab10' if len(states) <= 10 else 'tab20')
state_colors = state_cmap(np.searchsorted(states, labels) % state_cmap.N)

# Panel in row i-1, column j shows PC j+1 (x) against PC i+1 (y)
for i in range(1, n_pcs):
    for j in range(i):
        ax = fig.add_subplot(gs[i-1, j])

        # Plot background scatter with light gray points
        ax.scatter(vol_pc_all[:, j], vol_pc_all[:, i], color='lightgrey', s=1, alpha=0.1, rasterized=True)

        # Overlay the sketched volumes, colored by state
        ax.scatter(vol_pc[:, j], vol_pc[:, i], c=state_colors, s=25, edgecolor='white', linewidths=0.25)

        # Only set labels for leftmost and bottom plots
        if j == 0:
            ax.set_ylabel(f'Volume PC{i+1} (EV: {vol_pca.explained_variance_ratio_[i]:.0%})',
                          fontsize=14, fontweight='bold')
        if i == n_pcs-1:
            ax.set_xlabel(f'Volume PC{j+1} (EV: {vol_pca.explained_variance_ratio_[j]:.0%})',
                          fontsize=14, fontweight='bold')

        # Remove ticks for cleaner look
        ax.xaxis.set_major_locator(ticker.NullLocator())
        ax.yaxis.set_major_locator(ticker.NullLocator())

# Legend built from the same state -> color mapping used in the panels
patches = [mpatches.Patch(color=state_cmap(k % state_cmap.N), label=f'Cluster {s}')
           for k, s in enumerate(states)]
fig.legend(handles=patches, fontsize=20)

if save_pdf:
    plt.savefig(f'volpca_grid{n_pcs}_clusters_all.pdf')

### Plot latent space

In [None]:
# Latent-space PCA of all embeddings (grey) with the k-means center rows
# overlaid and colored by state. Colors are indexed by each state's rank among
# the sorted labels (tab10, or tab20 for more than 10 states), so no manual
# colormap change is needed when M != 10.
i, j = 0,1
states = np.unique(labels)
state_cmap = plt.get_cmap('tab10' if len(states) <= 10 else 'tab20')
state_colors = state_cmap(np.searchsorted(states, labels) % state_cmap.N)
g = sns.jointplot(x=z_pc[:,i], y=z_pc[:,j], kind='scatter', color='lightgrey', s=1, alpha=.2, rasterized=True)
g.ax_joint.scatter(x=z_pc[centers_ind,i], y=z_pc[centers_ind,j], c=state_colors, s=25, edgecolor='white', linewidths=.25)
if save_pdf:
    plt.savefig('zpca_clusters.pdf')

In [None]:
# UMAP embedding (grey) with the k-means center rows overlaid and colored by
# state. Colors are indexed by each state's rank among the sorted labels
# (tab10, or tab20 for more than 10 states), so no manual colormap change is
# needed when M != 10.
i, j = 0,1
states = np.unique(labels)
state_cmap = plt.get_cmap('tab10' if len(states) <= 10 else 'tab20')
state_colors = state_cmap(np.searchsorted(states, labels) % state_cmap.N)
g = sns.jointplot(x=umap[:,i], y=umap[:,j], kind='scatter', color='lightgrey', s=1, alpha=.2, rasterized=True)
g.ax_joint.scatter(x=umap[centers_ind,i], y=umap[centers_ind,j], c=state_colors, s=25, edgecolor='white', linewidths=.25)
if save_pdf:
    plt.savefig('umap_clusters.pdf')

### Cluster counts

In [None]:
from collections import Counter

# Sketched volumes per state label, and k-means assignments per center index
counts, kmeans_counts = Counter(labels), Counter(kmeans_labels)

# Number of states actually present in the labels (replaces the M parameter)
M = len(counts)

In [None]:
# Particles per state: for each state label (in sorted order), sum the
# k-means assignment counts of the sketched volumes carrying that label.
# Iterates over the labels that actually occur instead of assuming 1..M.
particle_counts = [
    sum(kmeans_counts[ii] for ii in np.flatnonzero(np.asarray(labels) == state))
    for state in np.unique(labels)
]

In [None]:
# Bar chart of sketched volumes per state. Bars are labeled with the actual
# state labels (not 0..M-1), and the count annotations are centered on each
# bar with an offset proportional to the tallest bar instead of a fixed +3.
plt.subplots(figsize=(6,5))
x = sorted(counts)
y = [counts[state] for state in x]
g = sns.barplot(x=x,y=y)

offset = 0.01 * max(y, default=0)
for pos, val in enumerate(y):
    g.text(pos, val + offset, val, ha='center', va='bottom')
plt.xlabel('State')
plt.ylabel('Volume count')
plt.tight_layout()

if save_pdf:
    plt.savefig('volume_counts.pdf')


In [None]:
# Bar chart of particles per state (particle_counts is ordered by state label).
# Bars are labeled with the actual state labels (not 0..M-1), and the count
# annotations are centered on each bar with an offset proportional to the
# tallest bar instead of a fixed +1000.
plt.subplots(figsize=(6,5))
x = sorted(counts)
y = particle_counts
g = sns.barplot(x=x,y=y)

offset = 0.01 * max(y, default=0)
for pos, val in enumerate(y):
    g.text(pos, val + offset, val, ha='center', va='bottom')
plt.xlabel('State')
plt.ylabel('Particle count')
plt.tight_layout()

if save_pdf:
    plt.savefig('particle_counts.pdf')
