In [1]:
import os

import numpy as np
import pandas as pd

import math
from natsort import natsorted

from matplotlib import colors
import matplotlib.pyplot as plt
from matplotlib import animation
from matplotlib.lines import Line2D
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import ListedColormap

from utils.utility_functions import categorical_cmap

from IPython.display import display, clear_output
from ipywidgets import interactive, widgets

In [2]:
def interactive_plot(Clustering, elev, azim):
    """Show the 3D embedding of one clustering, colored by cluster label.

    Written as the callback for ``ipywidgets.interactive``: each call builds
    a fresh figure from the module-level ``data_dict`` and displays it.

    Parameters
    ----------
    Clustering : str
        Key into ``data_dict``. Recognized values are 'VAE9', 'VAE20',
        'Leiden' and 'HDBSCAN'; for any other value nothing is plotted and
        an empty set of 3D axes is shown.
    elev, azim : number
        Elevation and azimuth forwarded to ``ax.view_init``.
    """
    # The only thing that differs between clusterings is the name of the
    # three embedding columns in the plotted dataframe.
    embedding_columns = {
        'VAE9': (
            'VAE9_ROT_VIG18_emb_3d_1',
            'VAE9_ROT_VIG18_emb_3d_2',
            'VAE9_ROT_VIG18_emb_3d_3',
        ),
        'VAE20': ('VAE20_emb_3d_1', 'VAE20_emb_3d_2', 'VAE20_emb_3d_3'),
        'Leiden': ('emb1', 'emb2', 'emb3'),
        'HDBSCAN': ('emb1', 'emb2', 'emb3'),
    }

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    if Clustering in embedding_columns:

        print (f'{Clustering} clustering')

        # data_dict values are (dataframe, ListedColormap, per-row color index)
        df_labeled, cmap, c = data_dict[Clustering]
        x_col, y_col, z_col = embedding_columns[Clustering]

        ax.scatter(
            df_labeled[x_col],
            df_labeled[y_col],
            df_labeled[z_col],
            cmap=cmap,
            c=c,
            # marker area shrinks as the number of plotted cells grows
            s=150000 / len(df_labeled),
            ec='k',
            linewidth=0.0
        )

        # One legend entry per cluster; natural sort order matches the order
        # in which color indices are assigned to clusters in data_dict.
        legend_elements = [
            Line2D([0], [0], marker='o',
                   color='none',
                   label=f'{Clustering} cluster {label}',
                   markerfacecolor=cmap.colors[idx],
                   markeredgecolor='none',
                   lw=0.001, markersize=8)
            for idx, label in enumerate(
                natsorted(df_labeled['cluster'].unique())
            )
        ]

        ax.legend(
            handles=legend_elements, prop={'size': 10}, bbox_to_anchor=[1.3, 0.95]
        )

    ax.axis('auto')
    ax.tick_params(labelsize=10)
    ax.grid(True)
    ax.view_init(elev=elev, azim=azim)

    plt.show()

In [3]:
def rotation_movie(clustering, df_labeled, cmap, c):
    """Render a 360-degree rotation of a 3D embedding and save it as an MP4.

    The movie is written to ``<out>/<clustering>_400dpi.mp4`` where ``out``
    is the module-level output directory.

    Parameters
    ----------
    clustering : str
        Name of the clustering. 'VAE9' and 'VAE20' select their own
        embedding columns; any other value uses 'emb1', 'emb2', 'emb3'.
    df_labeled : pandas.DataFrame
        Must contain the three embedding columns and a 'cluster' column.
    cmap : matplotlib.colors.ListedColormap
        Colormap whose ``colors`` are indexed by cluster rank for the legend.
    c : sequence
        Per-row color values passed to ``ax.scatter``.
    """
    # Embedding column names per clustering; everything else shares the
    # emb1/emb2/emb3 columns.
    embedding_columns = {
        'VAE9': (
            'VAE9_ROT_VIG18_emb_3d_1',
            'VAE9_ROT_VIG18_emb_3d_2',
            'VAE9_ROT_VIG18_emb_3d_3',
        ),
        'VAE20': ('VAE20_emb_3d_1', 'VAE20_emb_3d_2', 'VAE20_emb_3d_3'),
    }
    # Fix: the fallback previously plotted 'emb1' on all three axes, which
    # collapsed the Leiden/HDBSCAN point cloud onto a diagonal line.
    x_col, y_col, z_col = embedding_columns.get(
        clustering, ('emb1', 'emb2', 'emb3')
    )

    fig = plt.figure()
    ax = plt.axes(projection='3d', computed_zorder=False)

    def init():
        """Draw the static scatter and legend that every frame reuses."""
        ax.scatter(
            df_labeled[x_col],
            df_labeled[y_col],
            df_labeled[z_col],
            c=c,
            cmap=cmap,
            # marker area shrinks as the number of plotted cells grows
            s=150000 / len(df_labeled),
            ec='k',
            linewidth=0.0
        )

        ax.axis('auto')
        ax.tick_params(labelsize=10)
        ax.grid(True)

        # One legend entry per cluster, in natural sort order
        legend_elements = [
            Line2D([0], [0], marker='o',
                   color='none',
                   label=f'{clustering} cluster {label}',
                   markerfacecolor=cmap.colors[idx],
                   markeredgecolor='none',
                   lw=0.001, markersize=8)
            for idx, label in enumerate(
                natsorted(df_labeled['cluster'].unique())
            )
        ]

        ax.legend(
            handles=legend_elements, prop={'size': 10}, bbox_to_anchor=[1.2, 0.95]
        )
        plt.tight_layout()
        return fig,

    def animate(i):
        """Rotate the camera to azimuth ``i`` degrees at a fixed elevation."""
        ax.view_init(elev=10., azim=i)
        return fig,

    # 360 frames, one per degree of azimuth
    anim = animation.FuncAnimation(
        fig, animate, init_func=init,
        frames=360, interval=20, blit=True)

    anim.save(
        os.path.join(out, f'{clustering}_400dpi.mp4'),
        dpi=400, fps=30, extra_args=['-vcodec', 'libx264'])
    

In [4]:
# Paths and inputs
# NOTE(review): `root` is an absolute path to an external volume; the notebook
# only runs where that volume is mounted.
root = '/Volumes/T7 Shield/cylinter_input/clean_quant/output_3d_v2/'
df = pd.read_parquet(os.path.join(os.getcwd(), 'input/CRC-097_clean_cylinter_pca.parquet'))

main = pd.read_csv(os.path.join(os.getcwd(), 'input/main.csv'))
vae9_clustering = 'VAE9_ROT_VIG18'
vae20_clustering = 'VAE20'

# Output directory for the rendered movies. exist_ok=True replaces the
# separate existence check, so re-running the cell is safe and there is no
# window between checking and creating.
out = os.path.join(os.getcwd(), 'output/3d_umap_animation')
os.makedirs(out, exist_ok=True)

In [5]:
# Reproduce index at point of embedding, then merge saved embedding with dataframe
# Per-row weight is 1 / (rows in that sample * number of samples), so the rows
# of every sample sum to the same total weight.
groups = df.groupby('Sample')
sample_weights = pd.DataFrame({'weights': 1 / (groups.size() * len(groups))})
# Broadcast each sample's weight onto its rows by matching the 'Sample' value
# against the index of sample_weights.
weights = pd.merge(df[['Sample']], sample_weights,left_on='Sample', right_index=True)
# frac=1.0 without replacement is a weighted shuffle of every row. The fixed
# random_state is presumably what makes the row order match the order used
# when embedding.npy was produced — TODO confirm against the embedding code.
# NOTE(review): this rebinds `df`. Re-running this cell without first
# re-running the cell that loads `df` shuffles the already-shuffled frame and
# would likely misalign rows and embedding coordinates.
df = df.sample(frac=1.0, replace=False, weights=weights['weights'], random_state=5, axis=0)
embedding = np.load(os.path.join(root, 'clustering/final/embedding.npy'))
# Coordinates are assigned positionally: row k of the shuffled frame receives
# row k of the array; its first three columns become emb1..emb3.
df['emb1'] = embedding[:, 0]
df['emb2'] = embedding[:, 1]
df['emb3'] = embedding[:, 2]

In [6]:
# GATING CODE
# read soft gates
# gate = pd.read_csv('/Volumes/T7 Shield/cylinter_input/' +
#                    'clean_quant/gating.csv')

# work on a copy of needed columns
# gate = gate[['CellID', 'Label']].copy()

# isolate + and - populations
# yes_gate = df[df['CellID'].isin(gate['CellID'])]
# no_gate = df[~df['CellID'].isin(gate['CellID'])]

# sample the negative population
# no_gate = no_gate.sample(n=len(gate), random_state=1)

# append to positive population
# df = no_gate.append(yes_gate)

# append CyLinter dataframe with gate calls, row-wise
# df = df.merge(gate, how='left', on='CellID')

# fill NANs
# df['Label'].fillna(value='Pop0', inplace=True)

# sort by gate label column
# df['Label'] = pd.Categorical(
#     df['Label'], ordered=True,
#     categories=natsorted(df['Label'].unique())
#     )

# df = df.sort_values('Label')

In [7]:
# Locations of the cluster label tables, one CSV per clustering
leiden_csv = os.path.join(os.getcwd(), 'input/CRC-097_leiden.csv')
hdbscan_csv = os.path.join(os.getcwd(), 'input/CRC-097_hdbscan.csv')
vae9_csv = (
    '/Users/greg/projects/vae-paper/src/input/VAE9_ROT_VIG18/6_latent_space_LD184/clustering_full/'
    'UMAP/neighbors30_repulsion5_random1/Leiden_cluster_labels_50PC_32C.csv'
)
vae20_csv = (
    '/Users/greg/projects/vae-paper/src/input/VAE20/6_latent_space_LD850/clustering_full/UMAP/'
    'neighbors30_repulsion3_random3/Leiden_cluster_labels_50PC_36C.csv'
)

# Load the Leiden, HDBSCAN, VAE9_ROT_VIG18 and VAE20 cluster labels
leiden_labels = pd.read_csv(leiden_csv)
hdbscan_labels = pd.read_csv(hdbscan_csv)
vae9_labels = pd.read_csv(vae9_csv)
vae20_labels = pd.read_csv(vae20_csv)

In [8]:
# generate plot data for all four clusterings (VAE9, VAE20, Leiden, HDBSCAN)
# Each data_dict entry is a tuple: (dataframe to plot, ListedColormap,
# list of per-row color indices).

# Upper bound on the number of cells drawn for the Leiden/HDBSCAN plots
MAX_PLOT_CELLS = 479909

data_dict = {}
for clustering, labels in zip(
    ['VAE9', 'VAE20', 'Leiden', 'HDBSCAN'],
    [vae9_labels, vae20_labels, leiden_labels, hdbscan_labels]
):

    labels.columns = ['CellID', 'cluster']

    # Ensure CellID and cluster columns are int dtype
    labels['CellID'] = labels['CellID'].astype(int)
    labels['cluster'] = labels['cluster'].astype(int)

    if 'VAE' in clustering:
        # Labels are attached by index alignment, not by CellID.
        # NOTE(review): assumes `main` and the label table share row order
        # and index — confirm.
        main['cluster'] = labels['cluster']
        df_labeled = main

        # Keep every cell; shuffle rows for homogenous z-order
        df_labeled = df_labeled.sample(frac=1.0, random_state=1)
        df_labeled = df_labeled.reset_index(drop=True)

        # Repeat the tab20 palette as often as needed to cover all clusters,
        # then keep exactly one color per cluster.
        n_clusters = len(df_labeled['cluster'].unique())
        tab20_colors = np.array(plt.get_cmap('tab20').colors)
        palette_multiplier = math.ceil(n_clusters / len(tab20_colors))
        ccolors = np.tile(tab20_colors, (palette_multiplier, 1))
        cmap = colors.ListedColormap(ccolors[:n_clusters])

    else:
        # Append cluster labels to CyLinter dataframe
        df_labeled = df.merge(labels, how='inner', on='CellID')

        # Use a subset of cells for plotting; shuffle index for homogenous
        # z-order. The sample size is clamped so a merged frame smaller than
        # MAX_PLOT_CELLS is used in full instead of raising ValueError.
        df_labeled = df_labeled.sample(
            n=min(MAX_PLOT_CELLS, len(df_labeled)), random_state=1
        )

        # Remove unclustered cells in the case of HDBSCAN
        df_labeled = df_labeled[df_labeled['cluster'] != -1]
        df_labeled = df_labeled.reset_index(drop=True)

        # build cmap
        cmap = categorical_cmap(
            numUniqueSamples=len(df_labeled['cluster'].unique()),
            numCatagories=20, cmap='tab20', continuous=False
        )

    if df_labeled['cluster'].unique().min() == -1:
        # make black the first color to specify
        # cluster outliers (i.e. cluster -1 cells)
        cmap = ListedColormap(
            np.insert(
                arr=cmap.colors, obj=0, values=[0.0, 0.0, 0.0], axis=0)
        )

        # Drop the last color so the palette length is unchanged
        cmap = ListedColormap(cmap.colors[:-1])

    # Map each cluster label to its rank in natural sort order; that rank is
    # the index of the cluster's color in cmap.
    sample_dict = {
        label: idx
        for idx, label in enumerate(natsorted(df_labeled['cluster'].unique()))
    }

    c = [sample_dict[i] for i in df_labeled['cluster']]

    data_dict[clustering] = (df_labeled, cmap, c)

In [9]:
# Interactive 3D viewer: a clustering dropdown plus camera-angle sliders
plt.rcParams['figure.figsize'] = (13, 9)

# Camera elevation, 0-90 degrees, starting at 10
elev_slider = widgets.IntSlider(
    description='Elevation:', value=10, min=0, max=90, step=1
)
# Camera azimuth, 0-360 degrees, starting at 230
azim_slider = widgets.IntSlider(
    description='Azimuth:', value=230, min=0, max=360, step=1
)

# interactive_plot is re-run whenever a control changes
plot = interactive(
    interactive_plot,
    elev=elev_slider,
    azim=azim_slider,
    Clustering=['VAE9', 'VAE20', 'Leiden', 'HDBSCAN'],
)

display(plot)

interactive(children=(Dropdown(description='Clustering', options=('VAE9', 'VAE20', 'Leiden', 'HDBSCAN'), value…

In [None]:
# Render one rotation movie per clustering, skipping any that already exist
for clustering, plot_inputs in data_dict.items():
    movie_path = os.path.join(out, f'{clustering}_400dpi.mp4')
    if os.path.exists(movie_path):
        # already rendered on a previous run
        continue
    print(f'Saving rotating plot for {clustering} clustering...')
    rotation_movie(
        clustering=clustering,
        df_labeled=plot_inputs[0],
        cmap=plot_inputs[1],
        c=plot_inputs[2],
    )
print('Complete!')

Saving rotating plot for HDBSCAN clustering...
