# Single-Cell Latent Space Clustering

This notebook clusters and visualizes latent representations of single cells using various clustering methods and t-SNE projection.

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from collections import OrderedDict

from sklearn import metrics
from sklearn.preprocessing import LabelEncoder
from sklearn.manifold import TSNE

from framework import clustering

# Seed NumPy's global random number generator so that code drawing from it
# produces the same numbers on every run of the notebook
np.random.seed(1013)

## Setup

In [None]:
# Name of the dataset being analysed; every path below is derived from it
dataset = "GBM"

# Directory holding the trained model's outputs for this dataset
model_dir = "results/{}/Train1000g{}VAE/{}VAE_Final".format(
    dataset.lower(), dataset, dataset
)

# Input file (latent representations) and output file (cluster assignments),
# both located inside the model directory
latent_space_file = model_dir + "/latent_representations.txt"
cluster_output_file = model_dir + "/clusters.txt"

# Clustering methods to run, and the cluster counts (2 through 11) to evaluate
cluster_methods = ["gmm", "km", "hc"]
n_clusters = list(range(2, 12))

# Name of the column holding the labels (None if there are no labels)
label_col = "tumor_ids"

# Set to True to score the clusterings against the labels in `label_col`
calculate_metrics = False

# Column index marking where the features end and the labels/other
# information begin
features_end_col_idx = -1

## Load the Latent Space

In [None]:
# Load the latent space
def load_latent_space():
    """Read the latent-space table and split off its feature matrix.

    The table is read from the module-level ``latent_space_file`` as a
    tab-separated file whose first row is the header and whose first
    column is the index. The feature matrix consists of the columns before
    the module-level ``features_end_col_idx``.

    Returns:
        A ``(table, features)`` tuple: the full DataFrame, and the feature
        columns as a float64 NumPy array.
    """
    table = pd.read_csv(latent_space_file, sep="\t", header=0, index_col=0)
    feature_columns = table.iloc[:, :features_end_col_idx]
    features = feature_columns.values.astype(np.float64)
    return table, features
    
# Load the full table together with its numeric feature matrix
df, latent_space = load_latent_space()

if label_col:
    # Map each distinct label value to an integer so the labels can be used
    # for colouring plots and for computing clustering metrics
    labels = df[label_col].values
    label_encoder = LabelEncoder()
    int_labels = label_encoder.fit(labels).transform(labels)
    
# Fit a t-SNE model to the features for visualization purposes
def fit_tsne(features):
    """Return a two-component t-SNE embedding of ``features``.

    The model uses PCA initialisation, a perplexity of 30 and a fixed
    ``random_state`` of 0.

    Args:
        features: The data to embed, passed unchanged to
            ``TSNE.fit_transform``.

    Returns:
        The output of ``TSNE.fit_transform`` for ``features``.
    """
    projector = TSNE(n_components=2, init='pca', random_state=0, perplexity=30)
    embedding = projector.fit_transform(features)
    return embedding

# Compute the 2-component embedding once; the scatter plots in the cells
# below all reuse it
tsne_output = fit_tsne(latent_space)

In [None]:
# Plot t-SNE projection
if label_col:
    # One discrete colour per distinct label. plt.get_cmap is used instead of
    # plt.cm.get_cmap, which was deprecated in Matplotlib 3.7 and removed in
    # 3.9; plt.get_cmap accepts the same (name, lut) arguments.
    label_cmap = plt.get_cmap('rainbow', len(np.unique(int_labels)))
    plt.scatter(tsne_output[:, 0], tsne_output[:, 1],
                c=int_labels, cmap=label_cmap)
    plt.colorbar()
    # Pad the colour limits by half a step on each side so that every
    # integer label falls in the middle of its colour band
    plt.clim(-0.5, np.max(int_labels) + 0.5)
else:
    # No labels available: plot the projection in a single colour
    plt.scatter(tsne_output[:, 0], tsne_output[:, 1])
plt.show()

## Clustering

In [None]:
# Clustering
# Run every configured clustering method on the latent space. The function
# `clustering.cluster_<method>` is looked up with getattr rather than eval,
# so a method name can only ever resolve to an attribute of `clustering`
# (an unknown name raises AttributeError).
clustering_results = OrderedDict()
for method in cluster_methods:
    cluster_fn = getattr(clustering, "cluster_" + method)
    clustering_results[method] = cluster_fn(latent_space, n_clusters)

# Calculate clustering metrics (if specified)
if label_col and calculate_metrics:
    # Label-based scores plus the silhouette score, one row per method
    print('name\thomo\tcompl\tv-meas\tARI\tAMI\tsilhouette')
    for name, results in clustering_results.items():
        # Each result holds its cluster assignments under the "clusters" key
        clusters = results["clusters"]
        print('%s\t%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%.3f'
              % (name, metrics.homogeneity_score(int_labels, clusters),
                 metrics.completeness_score(int_labels, clusters),
                 metrics.v_measure_score(int_labels, clusters),
                 metrics.adjusted_rand_score(int_labels, clusters),
                 metrics.adjusted_mutual_info_score(int_labels, clusters),
                 metrics.silhouette_score(latent_space, clusters,
                                          metric='euclidean')))
else:
    # Without labels (or with metrics disabled) only the silhouette score,
    # which needs no ground truth, is reported
    for name, results in clustering_results.items():
        print("{} - Avg Silhouette Score: {}".format(
            name, metrics.silhouette_score(
                latent_space, results["clusters"], metric='euclidean')))

In [None]:
# Show plots of clusters
# Draw one figure per clustering method, colouring the t-SNE projection by
# the cluster assigned to each point. Figures are numbered from 1.
for fig_num, (name, results) in enumerate(clustering_results.items(), start=1):
    clusters = results["clusters"]
    cluster_ids = np.unique(clusters)

    plt.figure(fig_num)
    # plt.get_cmap replaces plt.cm.get_cmap, which was deprecated in
    # Matplotlib 3.7 and removed in 3.9; the (name, lut) arguments are the
    # same and yield one discrete colour per cluster.
    plt.scatter(tsne_output[:, 0], tsne_output[:, 1],
                c=clusters, cmap=plt.get_cmap('Paired', len(cluster_ids)))
    plt.colorbar(ticks=cluster_ids)
    # Pad the colour limits by half a step beyond the smallest and largest
    # cluster ids so the end values are not pinned to the colour scale's edges
    plt.clim(np.min(clusters) - 0.5, np.max(clusters) + 0.5)
    plt.title(name.upper() + " Clustering")

plt.show()

In [None]:
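# Save clusters (disabled: uncomment the lines below to write the cell ids and
# one column of cluster assignments per method to cluster_output_file)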
# clustering_cols = [(name, results["clusters"]) for name, results in clustering_results.items()]
# clustering_cols.insert(0, ("cell_id", df.index.values))
# output_df = pd.DataFrame(OrderedDict(clustering_cols))
# output_df.to_csv(cluster_output_file, index=False, sep="\t")