In [20]:
import numpy as np
import pickle
import matplotlib.pylab as plt
import seaborn as sns
import colorcet as cc
from pathlib import Path
from openTSNE import TSNE
from tqdm import tqdm
import sys

sys.path.append('../')

from visualization_utils import remap_cluster_ids, color_dict

In [21]:
# Input/output locations used by the plotting cell below (all relative paths).
# Directory holding the synthetic samples (`gm_c{n_c}_var{var}_samples.npy`).
data_path = '../data/synthetic_data/'
# Directory holding the per-dataset predictions and means
# (`best_preds_nc{n_c}_var{var}.npy`, `best_means_nc{n_c}_var{var}.npy`).
prediction_path = '../data/ari_synthetic/'
# Directory holding the reference means (`best_means_nc{n_c}.npy`) that the
# predicted cluster IDs are remapped against.
gt_path = '../data/ari_neuronal/'
# Directory the rendered PNG figures are written to.
figure_path = 'figures/synthetic/'

In [22]:
# NOTE(review): `n` is not referenced in the visible cells; presumably the number
# of samples per synthetic dataset — TODO confirm.
n = 32_571
# Variance values of the synthetic datasets; each value appears verbatim in the
# file names that are loaded and saved (e.g. `..._var0.05_...`).
variances = np.array([0.05, 0.1, 0.3, 0.5, 1.0])
# Numbers of clusters for which datasets/predictions are loaded and plotted.
n_clusters = [10, 20, 40]

#### Plot t-SNE embeddings of synthetic datasets

In [None]:
# Render one t-SNE figure per (number of clusters, variance) combination and
# save it as a PNG under `figure_path`.

# `savefig` does not create missing directories; make sure the target exists up
# front so a fresh checkout does not fail only after the first slow t-SNE fit.
Path(figure_path).mkdir(parents=True, exist_ok=True)

for n_c in tqdm(n_clusters, desc='n_clusters'):
    # The 20-cluster case uses the shared `color_dict` from visualization_utils;
    # every other case gets a glasbey palette with one color per cluster.
    palette = sns.color_palette(cc.glasbey, n_colors=n_c)
    color_palette = color_dict if n_c == 20 else palette

    # leave=False clears each finished inner bar so the output is not flooded.
    for var in tqdm(variances, desc=f'variance (n_c={n_c})', leave=False):
        # Load synthetic data
        sample = np.load(Path(data_path, f'gm_c{n_c}_var{var}_samples.npy'))
        predictions = np.load(Path(prediction_path, f'best_preds_nc{n_c}_var{var}.npy'))
        means = np.load(Path(prediction_path, f'best_means_nc{n_c}_var{var}.npy'))

        # Remap cluster IDs such that colors in plots align.
        # `lsa_dict` is indexed by predicted cluster ID and yields the remapped ID.
        # NOTE(review): presumably an assignment between `gt_means` and `means`;
        # confirm in visualization_utils.remap_cluster_ids.
        gt_means = np.load(Path(gt_path, f'best_means_nc{n_c}.npy'))
        lsa_dict, _ = remap_cluster_ids(gt_means, means)
        predictions_remapped = np.array([lsa_dict[p] for p in predictions])

        # Run t-SNE (fixed random_state so repeated runs use the same seed).
        tsne = TSNE(
            perplexity=300,
            metric='cosine',
            n_jobs=8,
            random_state=42,
            verbose=False,
        )
        z = tsne.fit(sample)

        # Plot t-SNE embeddings colored by GMM prediction, one scatter call per
        # cluster ID in range(n_c); points with any other remapped ID are not drawn.
        fig, ax = plt.subplots(1, 1)
        for i in range(n_c):
            ax.scatter(
                *z[predictions_remapped == i].T,
                s=3,
                color=color_palette[i],
                alpha=0.4,
                rasterized=True,
            )
        ax.axis('off')
        fig.savefig(
            Path(figure_path, f'gm_c{n_c}_var{var}_preds.png'),
            dpi=300,
            transparent=True,
            bbox_inches='tight',
        )
        # Close the figure explicitly: many figures are created in this loop.
        plt.close(fig)