# Figure 2

DATASETS

For our methods, we overcluster with 15 clusters:

- Gaussians with 6 clusters in 8D, 16D, 32D, 64D
- Gaussians with 6 clusters with varying frequency in 8D, 16D, 32D, 64D
- Worms with 6 clusters in 8D, 16D, 32D, 64D
- funky shapes with 6 clusters in 8D, 16D, 32D, 64D
- transcriptomic dataset (one mentioned in PAGA paper)

Afterwards, decide on ~4–5 datasets for the main figure; the rest go in the Appendix.

METHODS

- K-Means
- GMM
- TMM
- Leiden
- Agglomerative + dendrogram
- HDBSCAN
- Leiden + PAGA
- GMM + NEB
- GMM + dip-statistic
- TMM + NEB
- TMM + dip-statistic

In [None]:
from corc.graph_metrics import paga, gwg, gwgmara
from corc import generation, complex_datasets

import densired

from openTSNE import TSNE
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
import numpy as np
from scipy.sparse import csr_matrix
import scanpy as sc
import anndata as ad

class Leiden():
    """Leiden community detection via scanpy, with a scikit-learn-like API.

    Builds a neighbour graph with ``sc.pp.neighbors`` and runs
    ``sc.tl.leiden`` on it. After :meth:`fit`, cluster assignments are
    available as an integer ndarray in ``labels_``.

    Parameters
    ----------
    resolution : float, default 1.0
        Forwarded unchanged to ``sc.tl.leiden(resolution=...)``.
    seed : int, default 42
        Forwarded to ``sc.tl.leiden(random_state=...)``.
    """

    def __init__(self, resolution=1.0, seed=42):
        self.resolution = resolution
        self.seed = seed

    def fit(self, data):
        """Cluster ``data`` and store the integer labels in ``self.labels_``.

        Parameters
        ----------
        data : array-like
            Data matrix; converted to a float32 CSR matrix before being
            wrapped in an ``AnnData`` object (presumably samples x features —
            confirm against callers).

        Returns
        -------
        Leiden
            ``self``, so calls can be chained as with scikit-learn estimators.
        """
        self.data = data

        counts = csr_matrix(self.data, dtype=np.float32)
        self.adata = ad.AnnData(counts)

        sc.pp.neighbors(self.adata)
        sc.tl.leiden(
            self.adata,
            flavor="igraph",
            n_iterations=2,
            resolution=self.resolution,
            random_state=self.seed,
        )

        # scanpy stores the result in obs['leiden'] (presumably a pandas
        # categorical of string ids — confirm); expose it as a plain int
        # ndarray like the scikit-learn estimators used alongside this class.
        self.labels_ = np.asarray(self.adata.obs["leiden"]).astype(int)
        return self

In [None]:
import time
import warnings
from itertools import cycle, islice

import matplotlib.pyplot as plt
import numpy as np
import studenttmixture

from sklearn import cluster, datasets, mixture
from sklearn.neighbors import kneighbors_graph
from sklearn.preprocessing import StandardScaler

# ============
# Generate datasets. n_samples is large enough to show how the algorithms
# scale, yet small enough to keep running times reasonable.
# ============
n_samples = 1000
# NOTE(review): `seed` is not passed to any generator in this cell, so the
# synthetic datasets are not explicitly seeded here — confirm this is intended.
seed = 30

dims = [8, 16, 32, 64]
std = 0.075
# Every synthetic family below uses a noise level of std * sqrt(dim).

# Gaussians with 6 clusters in 8D, 16D, 32D, 64D
blobs1_0, blobs1_1, blobs1_2, blobs1_3 = [
    complex_datasets.make_gaussians(dim=d, std=std * np.sqrt(d), n_samples=n_samples)
    for d in dims
]

# Gaussians with 6 clusters with varying frequency in 8D, 16D, 32D, 64D
blobs2_0, blobs2_1, blobs2_2, blobs2_3 = [
    complex_datasets.make_gaussians(
        dim=d,
        std=std * np.sqrt(d),
        n_samples=n_samples,
        equal_sized_clusters=False,
    )
    for d in dims
]

# Worms with 6 clusters in 8D, 16D, 32D, 64D
# TODO: not generated yet.

# Funky shapes with 6 clusters in 8D, 16D, 32D, 64D
densired0, densired1, densired2, densired3 = [
    complex_datasets.make_densired(dim=d, n_samples=n_samples, std=std * np.sqrt(d))
    for d in dims
]

# MNIST-Nd in 8D, 16D, 32D, 64D
mnist0, mnist1, mnist2, mnist3 = [complex_datasets.make_mnist_nd(dim=d) for d in dims]

# transcriptomic dataset (one mentioned in PAGA paper)
# paul15 = make_Paul15()


# ============
# Figure layout and default clustering parameters
# ============
plt.figure(figsize=(21, 26))  # i.e. (9 * 2 + 3, 13 * 2)
plt.subplots_adjust(
    left=0.02,
    right=0.98,
    bottom=0.001,
    top=0.95,
    wspace=0.05,
    hspace=0.01,
)

# 1-based index of the next subplot panel; advanced once per panel drawn.
plot_num = 1

# Parameters shared by every dataset; each entry of `datasets` overrides a
# subset of these (merged per dataset inside the plotting loop).
default_base = dict(
    dim=2,
    quantile=0.3,
    eps=0.3,
    damping=0.9,
    preference=-200,
    n_neighbors=3,
    n_clusters=6,
    n_components=15,
    min_samples=7,
    xi=0.05,
    min_cluster_size=0.1,
    allow_single_cluster=True,
    hdbscan_min_cluster_size=15,
    hdbscan_min_samples=3,
    random_state=42,
    resolution=1.0,
    resolution_leiden=1.0,
)

# One entry per figure row: (data, overrides of `default_base`).
# Rows are ordered blobs1 (8D..64D), blobs2, densired, then MNIST-Nd.
datasets = [
    (data, {"dim": dim, "n_clusters": 6, "resolution": 1.0, "resolution_leiden": 1.0})
    for family in (
        (blobs1_0, blobs1_1, blobs1_2, blobs1_3),
        (blobs2_0, blobs2_1, blobs2_2, blobs2_3),
        (densired0, densired1, densired2, densired3),
    )
    for data, dim in zip(family, dims)
]

# MNIST-Nd rows use 15 clusters and keep the default resolutions.
datasets += [
    (data, {"dim": dim, "n_clusters": 15})
    for data, dim in zip((mnist0, mnist1, mnist2, mnist3), dims)
]

# datasets.append((paul15, {"dim": 1000, "n_clusters": 12, "n_components": 20}))

# Colour cycle for scatter plots; label k is drawn with entry k % len(palette).
_POINT_PALETTE = (
    "#377eb8",
    "#ff7f00",
    "#4daf4a",
    "#f781bf",
    "#a65628",
    "#984ea3",
    "#999999",
    "#e41a1c",
    "#dede00",
    "#add8e6",
    "#006400",
)


def _label_colors(*label_arrays):
    """Return a colour lookup array indexable by every label in ``label_arrays``.

    Entries ``0..max_label`` cycle through ``_POINT_PALETTE``; one extra black
    entry is appended last, so outliers labelled ``-1`` are drawn in black.
    """
    n_colors = max(int(np.max(labels)) for labels in label_arrays) + 1
    colors = np.array(list(islice(cycle(_POINT_PALETTE), n_colors)))
    return np.append(colors, ["#000000"])


def _predicted_labels(algorithm, X):
    """Return integer cluster labels for ``X`` from an already fitted ``algorithm``.

    Uses ``labels_`` when the estimator exposes it, otherwise ``predict(X)``.
    If ``predict`` raises ``ValueError``, a warning is emitted and every point
    is assigned to cluster 0 so that the panel is still drawn.
    """
    if hasattr(algorithm, "labels_"):
        # np.asarray also handles non-ndarray label containers (e.g. Leiden's).
        return np.asarray(algorithm.labels_).astype(int)
    try:
        y_pred = algorithm.predict(X)
    except ValueError as err:
        warnings.warn(
            f"{type(algorithm).__name__}.predict failed ({err}); "
            "plotting all points as a single cluster.",
            RuntimeWarning,
        )
        return np.zeros(len(X), dtype=int)
    return np.asarray(y_pred).astype(int)


for i_dataset, (dataset, algo_params) in enumerate(datasets):
    # update parameters with dataset-specific values
    params = default_base.copy()
    params.update(algo_params)

    X, y = dataset
    y = np.array(y, dtype='int')

    # normalize dataset for easier parameter selection
    X = StandardScaler().fit_transform(X)

    # Dimensionality reduction for plotting results in 2D. A dataset can
    # override the t-SNE perplexity with a "perplexity" entry in its params
    # (e.g. 100 for the Paul15 transcriptomic data); otherwise 30 is used.
    perplexity = params.get("perplexity", 30)
    tsne = TSNE(
        perplexity=perplexity,
        metric='euclidean',
        n_jobs=8,
        random_state=42,
        verbose=False,
    )
    X2D = tsne.fit(X)

    # estimate bandwidth for mean shift
    bandwidth = cluster.estimate_bandwidth(X, quantile=params["quantile"])

    # connectivity matrix for structured Ward
    connectivity = kneighbors_graph(
        X, n_neighbors=params["n_neighbors"], include_self=False
    )
    # make connectivity symmetric
    connectivity = 0.5 * (connectivity + connectivity.T)

    # ============
    # Create cluster objects
    # ============
    ms = cluster.MeanShift(bandwidth=bandwidth, bin_seeding=True)
    two_means = cluster.MiniBatchKMeans(
        n_clusters=params["n_clusters"],
        random_state=params["random_state"],
    )
    ward = cluster.AgglomerativeClustering(
        n_clusters=params["n_clusters"], linkage="ward", connectivity=connectivity
    )
    spectral = cluster.SpectralClustering(
        n_clusters=params["n_clusters"],
        eigen_solver="arpack",
        affinity="nearest_neighbors",
        random_state=params["random_state"],
    )
    dbscan = cluster.DBSCAN(eps=params["eps"])
    hdbscan = cluster.HDBSCAN(
        min_samples=params["hdbscan_min_samples"],
        min_cluster_size=params["hdbscan_min_cluster_size"],
        allow_single_cluster=params["allow_single_cluster"],
    )
    optics = cluster.OPTICS(
        min_samples=params["min_samples"],
        xi=params["xi"],
        min_cluster_size=params["min_cluster_size"],
    )
    affinity_propagation = cluster.AffinityPropagation(
        damping=params["damping"],
        preference=params["preference"],
        random_state=params["random_state"],
    )
    average_linkage = cluster.AgglomerativeClustering(
        linkage="average",
        metric="cityblock",
        n_clusters=params["n_clusters"],
        connectivity=connectivity,
    )
    birch = cluster.Birch(n_clusters=params["n_clusters"])
    gmm = mixture.GaussianMixture(
        n_components=params["n_clusters"],
        covariance_type="full",
        random_state=params["random_state"],
    )
    tmm = studenttmixture.EMStudentMixture(
        n_components=params["n_clusters"],
        n_init=5,
        fixed_df=False,
        init_type="k++",
        random_state=params["random_state"],
    )
    leiden = Leiden(
        resolution=params["resolution_leiden"],
        seed=params["random_state"],
    )
    mpaga = paga.PAGA(
        latent_dim=params["dim"],
        resolution=params["resolution"],
        seed=params["random_state"],
    )
    mgwgmara = gwgmara.GWGMara(
        latent_dim=params["dim"],
        n_components=params["n_components"],
        n_neighbors=params["n_neighbors"],
        seed=params["random_state"],
    )
    mgwg = gwg.GWG(
        latent_dim=params["dim"],
        n_components=params["n_components"],
        n_neighbors=params["n_neighbors"],
        seed=params["random_state"],
    )

    # Uncomment entries to add more methods as columns of the figure.
    clustering_algorithms = (
        ("MiniBatch\nKMeans", two_means),
        ("Agglomerative\nClustering", average_linkage),
        ("HDBSCAN", hdbscan),
        ("Gaussian\nMixture", gmm),
        ("t-Student\nMixture", tmm),
        ("Leiden", leiden),
        ("PAGA", mpaga),
        ("GWG-dip", mgwgmara),
        ("GWG-pvalue", mgwg),
        # ("Affinity\nPropagation", affinity_propagation),
        # ("MeanShift", ms),
        # ("Spectral\nClustering", spectral),
        # ("Ward", ward),
        # ("DBSCAN", dbscan),
        # ("OPTICS", optics),
        # ("BIRCH", birch),
    )

    # One figure row per dataset: one column per method plus the ground truth.
    n_rows = len(datasets)
    n_cols = len(clustering_algorithms) + 1

    for name, algorithm in clustering_algorithms:
        t0 = time.time()

        # catch warnings related to kneighbors_graph
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="the number of connected components of the "
                + "connectivity matrix is [0-9]{1,2}"
                + " > 1. Completing it to avoid stopping the tree early.",
                category=UserWarning,
            )
            warnings.filterwarnings(
                "ignore",
                message="Graph is not fully connected, spectral embedding"
                + " may not work as expected.",
                category=UserWarning,
            )
            algorithm.fit(X)

        t1 = time.time()
        y_pred = _predicted_labels(algorithm, X)

        plt.subplot(n_rows, n_cols, plot_num)
        if i_dataset == 0:
            plt.title(name, size=18)

        colors = _label_colors(y_pred, y)
        plt.scatter(X2D[:, 0], X2D[:, 1], s=10, color=colors[y_pred])

        # These methods also draw their cluster graph on the 2D embedding.
        if name in ["GWG-dip", "GWG-pvalue", "PAGA"]:
            algorithm.plot_graph(X2D)

        plt.xticks(())
        plt.yticks(())
        # Runtime annotation (t1 - t0), disabled:
        # plt.text(0.99, 0.01, ("%.2fs" % (t1 - t0)).lstrip("0"),
        #          transform=plt.gca().transAxes, size=15,
        #          horizontalalignment="right")
        plot_num += 1

    # plot ground truth; its palette is built from `y` alone so it does not
    # depend on which (if any) clustering panels were drawn before it.
    plt.subplot(n_rows, n_cols, plot_num)
    if i_dataset == 0:
        plt.title('Ground truth', size=18)
    plt.scatter(X2D[:, 0], X2D[:, 1], s=10, color=_label_colors(y)[y])
    plt.xticks(())
    plt.yticks(())
    plot_num += 1

plt.show()