# CryoDRGN visualization and figures

This Jupyter notebook provides a template for regenerating and customizing cryoDRGN visualizations and figures.

In [None]:
from cryodrgn import analysis
from cryodrgn import utils

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

### Load results

In [None]:
# Specify the workdir and the epoch number (0-based index) to analyze
# WORKDIR is the directory the cells below read z.{EPOCH}.pkl, run.log and
# analyze.{EPOCH}/ from.
WORKDIR = '..' 
# EPOCH is interpolated into those paths and marked on the learning curve, so it
# must be an integer before the cells below run (None would yield 'z.None.pkl').
# NOTE(review): None looks like a placeholder filled in when the notebook is
# generated -- confirm before changing the form of this line
EPOCH = None # change me if necessary!

In [None]:
# Load the latent variables z for the selected epoch, and the precomputed
# UMAP embedding stored in that epoch's analysis directory
z_path = f'{WORKDIR}/z.{EPOCH}.pkl'
z = utils.load_pkl(z_path)

umap_path = f'{WORKDIR}/analyze.{EPOCH}/umap.pkl'
umap = utils.load_pkl(umap_path)

# Plot learning curve

In [None]:
# Loss values parsed from the training log
loss = analysis.parse_loss(f'{WORKDIR}/run.log')

fig, ax = plt.subplots(figsize=(4, 4))
ax.plot(loss)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
# Dashed vertical line marking the epoch selected for analysis
ax.axvline(x=EPOCH, linestyle="--", color="black", label=f"Epoch {EPOCH}")
ax.legend()
fig.tight_layout()
#plt.savefig(f"{WORKDIR}/analyze.{EPOCH}/learning_curve_epoch{EPOCH}.png")

# Plot PCA

Visualize the latent space by principal component analysis (PCA).

In [None]:
# Run PCA on the latent variables z.
# pc is indexed as pc[:, 0] / pc[:, 1] (plotted as PC1 / PC2) in the cells below;
# pca supplies explained_variance_ratio_ for the axis labels.
pc, pca = analysis.run_pca(z)

In [None]:
# Style 1 -- Scatter

fig, ax = plt.subplots(figsize=(4, 4))
# Tiny, translucent, rasterized markers for the full point cloud
ax.scatter(pc[:, 0], pc[:, 1], s=1, alpha=.1, rasterized=True)
# Axis labels carry the explained variance ratio of each component
ax.set_xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))
ax.set_ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))
#plt.savefig('pca_style1.pdf')

In [None]:
# Style 2 -- Scatter with marginals

# Marker settings shared with the plain scatter style
scatter_kws = dict(alpha=.1, s=1, rasterized=True)
g = sns.jointplot(x=pc[:, 0], y=pc[:, 1], height=4, **scatter_kws)

# Label the central (joint) axes with the explained variance ratios
joint_ax = g.ax_joint
joint_ax.set_xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))
joint_ax.set_ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))
#plt.savefig('pca_style2.pdf')

In [None]:
# Style 3 -- Hexbin/heatmap

try:
    g = sns.jointplot(x=pc[:,0], y=pc[:,1], height=4, kind='hex')
except ZeroDivisionError:
    # No hexbin figure was produced, so there is nothing to label
    print("Data too small to produce hexbins!")
else:
    # Label the joint axes of the grid just created (same approach as the
    # other jointplot cells) instead of whichever axes pyplot considers current
    g.ax_joint.set_xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))
    g.ax_joint.set_ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))
#plt.savefig('pca_style3.pdf')

# Plot UMAP

Visualize the latent space by Uniform Manifold Approximation and Projection (UMAP). 

In [None]:
# Style 1 -- Scatter

fig, ax = plt.subplots(figsize=(4, 4))
ax.scatter(umap[:, 0], umap[:, 1], s=1, alpha=.1, rasterized=True)
# Remove the tick marks on both axes; only the axis names are shown
ax.set_xticks([])
ax.set_yticks([])
ax.set_xlabel('UMAP1')
ax.set_ylabel('UMAP2')
#plt.savefig('umap_style1.pdf')

In [None]:
# Style 2 -- Scatter with marginal distributions

# Marker settings shared with the plain scatter style
scatter_kws = dict(alpha=.1, s=1, rasterized=True)
g = sns.jointplot(x=umap[:, 0], y=umap[:, 1], height=4, **scatter_kws)

# Name the axes of the central (joint) panel
joint_ax = g.ax_joint
joint_ax.set_xlabel('UMAP1')
joint_ax.set_ylabel('UMAP2')
#plt.savefig('umap_style2.pdf')

In [None]:
# Style 3 -- Hexbin / heatmap

try:
    g = sns.jointplot(x=umap[:,0], y=umap[:,1], kind='hex',height=4)
except ZeroDivisionError:
    # No hexbin figure was produced, so there is nothing to label
    print("Data too small to produce hexbins!")
else:
    # Only label when the plot succeeded; otherwise `g` would be undefined or
    # still point at a grid created by an earlier cell
    g.ax_joint.set_xlabel('UMAP1')
    g.ax_joint.set_ylabel('UMAP2')
#plt.savefig('umap_style3.pdf')

# Plot kmeans samples

In [None]:
# Load points
# KMEANS selects the kmeans{KMEANS} subdirectory of analyze.{EPOCH}; it must be
# set before the path below resolves to a real file.
# NOTE(review): presumably the number of k-means clusters, with None as a
# placeholder filled in when the notebook is generated -- confirm
KMEANS = None
# Integer indices read from centers_ind.txt; used below to index pc and umap
kmeans_ind = np.loadtxt(
    f'{WORKDIR}/analyze.{EPOCH}/kmeans{KMEANS}/centers_ind.txt', dtype=int
)


In [None]:
# Default chimerax color map
# `colors` is passed as c= when scattering the kmeans_ind points in the cells below.
# NOTE(review): _get_chimerax_colors is an underscore-prefixed helper of
# cryodrgn.analysis -- confirm it is intended for use from notebooks
colors = analysis._get_chimerax_colors(KMEANS)

In [None]:
# Plot kmeans on PCA

fig, ax = plt.subplots(figsize=(4, 4))
# Faint background of all points, with the kmeans_ind points drawn on top
ax.scatter(pc[:, 0], pc[:, 1], s=1, alpha=.05, rasterized=True)
centers = pc[kmeans_ind]
ax.scatter(centers[:, 0], centers[:, 1], c=colors, edgecolor='black')
# Number each selected point by its position in kmeans_ind, shifted by
# (0.1, 0.1) from the point itself
offset = np.array([0.1, 0.1])
for i, center in enumerate(centers):
    ax.annotate(str(i), center[0:2] + offset)
ax.set_xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))
ax.set_ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))
#plt.savefig('pca_w_kmeans.pdf')

In [None]:
# Plot kmeans on UMAP

fig, ax = plt.subplots(figsize=(4, 4))
# Faint background of all points, with the kmeans_ind points drawn on top
ax.scatter(umap[:, 0], umap[:, 1], s=1, alpha=.05, rasterized=True)
centers = umap[kmeans_ind]
ax.scatter(centers[:, 0], centers[:, 1], c=colors, edgecolor='black')
# Number each selected point by its position in kmeans_ind, shifted by
# (0.1, 0.1) from the point itself
offset = np.array([0.1, 0.1])
for i, center in enumerate(centers):
    ax.annotate(str(i), center[0:2] + offset)
ax.set_xticks([])
ax.set_yticks([])
ax.set_xlabel('UMAP1')
ax.set_ylabel('UMAP2')
#plt.savefig('umap_w_kmeans.pdf')

### Plot PC traversals

Visualize the PC axes traversals. By default, plot the first two PCs.

In [None]:
fig, ax = plt.subplots(figsize=(4, 4))
ax.scatter(pc[:, 0], pc[:, 1], s=1, alpha=.1, rasterized=True)

# 10 evenly spaced PC1 values between the 5th and 95th percentile,
# drawn along the line PC2 = 0
pc1_lo, pc1_hi = np.percentile(pc[:, 0], [5, 95])
t = np.linspace(pc1_lo, pc1_hi, 10)
ax.scatter(t, np.zeros_like(t), c='cornflowerblue', edgecolor='white')

ax.set_xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))
ax.set_ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))

In [None]:
fig, ax = plt.subplots(figsize=(4, 4))
ax.scatter(pc[:, 0], pc[:, 1], s=1, alpha=.1, rasterized=True)

# 10 evenly spaced PC2 values between the 5th and 95th percentile,
# drawn along the line PC1 = 0
pc2_lo, pc2_hi = np.percentile(pc[:, 1], [5, 95])
t = np.linspace(pc2_lo, pc2_hi, 10)
ax.scatter(np.zeros_like(t), t, c='cornflowerblue', edgecolor='white')

ax.set_xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))
ax.set_ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))

In [None]:
# PCA scatter with marginals, overlaid with the PC1 traversal points
scatter_kws = dict(alpha=.1, s=1, rasterized=True)
g = sns.jointplot(x=pc[:, 0], y=pc[:, 1], height=4, **scatter_kws)
joint_ax = g.ax_joint

# 10 evenly spaced PC1 values between the 5th and 95th percentile, at PC2 = 0
pc1_lo, pc1_hi = np.percentile(pc[:, 0], [5, 95])
t = np.linspace(pc1_lo, pc1_hi, 10)
joint_ax.scatter(x=t, y=np.zeros_like(t), c='cornflowerblue', edgecolor='white')

joint_ax.set_xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))
joint_ax.set_ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))
#plt.savefig('pca_pc1_traversal.pdf')

In [None]:
# PCA scatter with marginals, overlaid with the PC2 traversal points
scatter_kws = dict(alpha=.1, s=1, rasterized=True)
g = sns.jointplot(x=pc[:, 0], y=pc[:, 1], height=4, **scatter_kws)
joint_ax = g.ax_joint

# 10 evenly spaced PC2 values between the 5th and 95th percentile, at PC1 = 0
pc2_lo, pc2_hi = np.percentile(pc[:, 1], [5, 95])
t = np.linspace(pc2_lo, pc2_hi, 10)
joint_ax.scatter(x=np.zeros_like(t), y=t, c='cornflowerblue', edgecolor='white')

joint_ax.set_xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))
joint_ax.set_ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))
#plt.savefig('pca_pc2_traversal.pdf')

### Plot PC traversals on UMAP

Plot the PC axes traversal paths in the UMAP visualization of the latent space.

In [None]:
# Load the z values of the PC1 traversal. The path is anchored on WORKDIR and
# EPOCH like the other analysis outputs, so it does not depend on the
# directory the notebook is run from.
z_pc1 = np.loadtxt(f'{WORKDIR}/analyze.{EPOCH}/pc1/z_values.txt')

In [None]:
# Match each PC1 traversal point to a point of z via analysis.get_nearest_point;
# pc1_ind is used below to index the UMAP embedding
z_pc1_on_data, pc1_ind = analysis.get_nearest_point(z, z_pc1)
# Euclidean distance between each traversal point and its match
# (last expression, so it is shown as the cell output)
np.linalg.norm(z_pc1_on_data - z_pc1, axis=1)

In [None]:
fig, ax = plt.subplots(figsize=(4, 4))
# Faint background of all points, with the PC1 traversal matches on top
ax.scatter(umap[:, 0], umap[:, 1], s=1, alpha=.05, rasterized=True)
pc1_umap = umap[pc1_ind]
ax.scatter(pc1_umap[:, 0], pc1_umap[:, 1], c='cornflowerblue', edgecolor='black')

ax.set_xticks([])
ax.set_yticks([])
ax.set_xlabel('UMAP1')
ax.set_ylabel('UMAP2')
#plt.savefig('umap_pc1_traversal.pdf')

In [None]:
fig, ax = plt.subplots(figsize=(4, 4))
ax.scatter(umap[:, 0], umap[:, 1], s=1, alpha=.05, rasterized=True)
# Connect the PC1 traversal matches in order with a dashed black line,
# then draw the points themselves on top
pc1_umap = umap[pc1_ind]
ax.plot(pc1_umap[:, 0], pc1_umap[:, 1], '--', c='k')
ax.scatter(pc1_umap[:, 0], pc1_umap[:, 1], c='cornflowerblue', edgecolor='black')

ax.set_xticks([])
ax.set_yticks([])
ax.set_xlabel('UMAP1')
ax.set_ylabel('UMAP2')
#plt.savefig('umap_pc1_traversal_v2.pdf')

In [None]:
# Load the z values of the PC2 traversal. The path is anchored on WORKDIR and
# EPOCH like the other analysis outputs, so it does not depend on the
# directory the notebook is run from.
z_pc2 = np.loadtxt(f'{WORKDIR}/analyze.{EPOCH}/pc2/z_values.txt')

In [None]:
# Match each PC2 traversal point to a point of z via analysis.get_nearest_point;
# pc2_ind is used below to index the UMAP embedding
z_pc2_on_data, pc2_ind = analysis.get_nearest_point(z, z_pc2)
# Euclidean distance between each traversal point and its match
# (last expression, so it is shown as the cell output)
np.linalg.norm(z_pc2_on_data - z_pc2, axis=1)

In [None]:
fig, ax = plt.subplots(figsize=(4, 4))
# Faint background of all points, with the PC2 traversal matches on top
ax.scatter(umap[:, 0], umap[:, 1], s=1, alpha=.05, rasterized=True)
pc2_umap = umap[pc2_ind]
ax.scatter(pc2_umap[:, 0], pc2_umap[:, 1], c='cornflowerblue', edgecolor='black')

ax.set_xticks([])
ax.set_yticks([])
ax.set_xlabel('UMAP1')
ax.set_ylabel('UMAP2')
#plt.savefig('umap_pc2_traversal.pdf')

In [None]:
fig, ax = plt.subplots(figsize=(4, 4))
ax.scatter(umap[:, 0], umap[:, 1], s=1, alpha=.05, rasterized=True)
# Connect the PC2 traversal matches in order with a dashed black line,
# then draw the points themselves on top
pc2_umap = umap[pc2_ind]
ax.plot(pc2_umap[:, 0], pc2_umap[:, 1], '--', c='k')
ax.scatter(pc2_umap[:, 0], pc2_umap[:, 1], c='cornflowerblue', edgecolor='black')

ax.set_xticks([])
ax.set_yticks([])
ax.set_xlabel('UMAP1')
ax.set_ylabel('UMAP2')
#plt.savefig('umap_pc2_traversal_v2.pdf')