In [1]:
%matplotlib inline

# path hack for relative import in jupyter notebook
import os
import sys
import time

# LIBRARY GLOBAL MODS
# Put the parent of the current working directory on sys.path so that project
# packages (e.g. `utils`, imported below) can be imported from this notebook.
_cwd_abs = os.path.abspath('')
CELLTYPES = os.path.dirname(_cwd_abs)
sys.path.append(CELLTYPES)

In [2]:
import pandas as pd
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import pandas as pd
import umap
import time
%matplotlib inline

# Global seaborn theme for the notebook: white style, 'notebook' context, 14x10 inch default figures.
sns.set(style='white', context='notebook', rc={'figure.figsize':(14,10)})

In [3]:
from matplotlib.colors import ListedColormap, BoundaryNorm, NoNorm, Normalize
from matplotlib.colorbar import ColorbarBase
from matplotlib.cm import get_cmap, ScalarMappable

In [4]:
# need to set cmaps to be used before importing multicell_replot 
#   as it uses proplot which modifies matplotlib

# Reversed Spectral colormap, captured here (before proplot is imported) and used for
# the energy colouring in the plotting functions.
# NOTE(review): matplotlib.cm.get_cmap is deprecated in newer matplotlib releases;
#   matplotlib.colormaps['Spectral_r'] is the replacement -- confirm against installed version.
CMAP_SPECTRALR = get_cmap('Spectral_r')
# dpi used when creating and saving the embedding panel figures
GLOBAL_DPI = 450

In [5]:
from utils.file_io import RUNS_FOLDER

Appended to sys path I:\Development\Repositories\biomodels\celltypes


In [6]:
# Output folder for this notebook: <RUNS_FOLDER>/explore/nb_alignUMAP (created if missing)
NOTEBOOK_OUTDIR = os.sep.join([RUNS_FOLDER, 'explore', 'nb_alignUMAP'])
os.makedirs(NOTEBOOK_OUTDIR, exist_ok=True)

# TODO: Import reloading while retaining variables in memory

In [7]:
#import importlib
#importlib.reload(multicell.multicell_replot)

#import inspect
#print(inspect.getsource(multicell.multicell_replot.replot_scatter_dots))

# Functions: Non-Plotting

In [8]:
def build_X_multi(manyruns_paths, subsample_ndim=None, subsample_ensemble=None, smod_last=True, use_01=True):
    """
    Stack the aggregated state ensembles of several manyruns folders into one tensor.

    Each folder is read from
        <path>/aggregate/X_aggregate{smod}.npz   (array stored under 'arr_0')
    and the multicell template of the LAST folder from
        <path>/multicell_template.pkl            (object with num_cells, num_genes)
    where smod is '_last' (old style, smod_last=True) or '' (new style).
    The stored state array is transposed on load (umap wants samples as rows).

    Args:
    - manyruns_paths: non-empty list of manyruns folder paths; the first fixes (npts, ndim)
    - subsample_ndim: if not None, keep only the first subsample_ndim columns (0 < k <= ndim)
    - subsample_ensemble: if not None, keep only the first subsample_ensemble rows (0 < n <= npts)
    - smod_last: use the '_last' file suffix (True) or no suffix (False)
    - use_01: map values via (1 + x) / 2 and cast to int (-1 -> 0, +1 -> 1)
    Returns:
    - X_multi: int array of shape (len(manyruns_paths), n, k)
    - num_cells, num_genes: attributes of the last folder's multicell template
    Raises:
    - ValueError if a folder holds fewer rows/columns than the requested (n, k)
    """
    smod = '_last' if smod_last else ''  # '_last' = oldstyle, '' = newstyle

    def state_path(manyruns_path):
        return manyruns_path + os.sep + 'aggregate' + os.sep + 'X_aggregate%s.npz' % smod

    # get data dimensionality from the first folder and set system size variables
    ref_X = np.load(state_path(manyruns_paths[0]))['arr_0'].T  # umap wants transpose
    npts, ndim = ref_X.shape
    print("npts, ndim", npts, ndim)

    if subsample_ensemble is not None:
        assert 0 < subsample_ensemble <= npts
        nn = subsample_ensemble
    else:
        nn = npts

    if subsample_ndim is not None:
        assert 0 < subsample_ndim <= ndim
        kk = subsample_ndim
    else:
        kk = ndim

    # Step 1: generate X_multi (first folder reuses the array already loaded above)
    X_multi = np.zeros((len(manyruns_paths), nn, kk), dtype=int)
    for j, manyruns_path in enumerate(manyruns_paths):
        X = ref_X if j == 0 else np.load(state_path(manyruns_path))['arr_0'].T
        if X.shape[0] < nn or X.shape[1] < kk:
            raise ValueError(
                'manyruns folder %s has data shape %s, smaller than requested (%d, %d)'
                % (manyruns_path, X.shape, nn, kk))
        X_multi[j, :, :] = X[0:nn, 0:kk]

    if use_01:
        X_multi = ((1 + X_multi) / 2.0).astype(int)

    # System size comes from the last folder's template; only that one is unpickled.
    # NOTE: pickle.load can execute arbitrary code -- only use trusted run folders.
    fpath_pickle = manyruns_paths[-1] + os.sep + 'multicell_template.pkl'
    with open(fpath_pickle, 'rb') as pickle_file:
        multicell_template = pickle.load(pickle_file)  # unpickling multicell object

    num_cells = multicell_template.num_cells
    num_genes = multicell_template.num_genes
    return X_multi, num_cells, num_genes


def gather_data_then_alignedUMAP(
    manyruns_paths, manyruns_dirnames, umap_kwargs, 
    skip_alignment=False, subsample_ndim=None, subsample_ensemble=None, 
    smod_last=True, use_01=True):
    """
    Embed several manyruns ensembles with AlignedUMAP and package each as a data_subdict.

    All of these arguments are prepared in a settings dictionary
        settings_aligned_umap_fig5_v1
        settings_aligned_umap_fig5_v2
        etc.
    and the one used is chosen later in the notebook (settings_alignment_chosen).

    Args:
    - manyruns_paths: list of manyruns folder paths (one per dataset)
    - manyruns_dirnames: label per dataset; must have the same length as manyruns_paths
    - umap_kwargs: kwargs for umap.AlignedUMAP / umap.UMAP; must contain 'random_state'
    - skip_alignment: if True, fit an independent umap.UMAP per dataset instead of aligning
    - subsample_ndim, subsample_ensemble, smod_last, use_01: forwarded to build_X_multi
    Returns:
    - dict {idx: data_subdict} with keys 'label', 'path', 'data', 'index', 'energies',
      'num_runs', 'total_spins', 'multicell_template', 'algos', 'nunique'.
      'data'/'energies'/'nunique' hold the same leading rows as the embedding.
    """
    # TODO once this fn done, need to propagate data_subdict object through nb
    #  - it should replace aligned_mapper
    #  - should also integrate it with alignment settings dict somehow
    # TODO should have class or validation function for data_subdict init
    # TODO integrate side data like num unique states (analogous to energy)

    smod = '_last' if smod_last else ''

    def build_data_subdict_from_alignedUMAP(data_subdict, embedding):
        """
        Complete a partial data_subdict (path and label keys) from its source files.

        data_subdict: partially complete data_subdict (path and label keys)
        embedding:    embedded data (num_embedded x num_lowdim)

        The raw state / energy arrays are reloaded (full dimensionality, original
        values) but truncated to the first embedding.shape[0] rows -- the same
        leading rows build_X_multi fed to UMAP -- so every per-row array lines up
        with the embedding even when subsample_ensemble was used.
        """
        manyruns_path = data_subdict['path']
        agg_dir = manyruns_path + os.sep + 'aggregate'
        fpath_state = agg_dir + os.sep + 'X_aggregate%s.npz' % smod
        fpath_energy = agg_dir + os.sep + 'X_energy%s.npz' % smod
        fpath_pickle = manyruns_path + os.sep + 'multicell_template.pkl'

        num_embedded = embedding.shape[0]
        # note: we transpose raw data, energies for umap/other functions
        X = np.load(fpath_state)['arr_0'].T[0:num_embedded]
        X_energies = np.load(fpath_energy)['arr_0'].T[0:num_embedded]
        # NOTE: pickle.load can execute arbitrary code -- only use trusted run folders
        with open(fpath_pickle, 'rb') as pickle_file:
            multicell_template = pickle.load(pickle_file)

        # Fill out the rest of data_subdict keys
        num_runs, total_spins = X.shape
        data_subdict['data'] = X
        data_subdict['index'] = list(range(num_runs))
        data_subdict['energies'] = X_energies
        data_subdict['num_runs'] = num_runs
        data_subdict['total_spins'] = total_spins
        data_subdict['multicell_template'] = multicell_template
        data_subdict['algos'] = {
            'umap':
                {'reducer': None,
                 'embedding': embedding}
        }

        # Add extra analysis data used within the notebook
        X_nunique = get_unique_states_over_ensemble(
            X,
            multicell_template.num_cells,
            multicell_template.num_genes
        )
        data_subdict['nunique'] = X_nunique.T  # note transpose
        assert (
            data_subdict['energies'].shape[0] ==
            data_subdict['nunique'].shape[0])

        return data_subdict

    # Step 0: any asserts and simple variable settings
    assert 'random_state' in umap_kwargs.keys()
    assert len(manyruns_paths) == len(manyruns_dirnames)
    num_datasets = len(manyruns_paths)

    # Step 1: construct X_multi (tensor of tissue state ensembles)
    X_multi, _, _ = build_X_multi(
        manyruns_paths,
        subsample_ndim=subsample_ndim,
        subsample_ensemble=subsample_ensemble,
        smod_last=smod_last,
        use_01=use_01
    )
    nn = X_multi.shape[1]

    # Step 2: perform alignedUMAP
    aligned_mapper = umap.AlignedUMAP(**umap_kwargs)

    if skip_alignment:
        print('Warning - skipping alignment step...')
        # independent UMAP fits; aligned_mapper only serves as a container for embeddings_
        aligned_mapper.embeddings_ = [
            umap.UMAP(**umap_kwargs).fit(X_multi[n, :, :]).embedding_
            for n in range(num_datasets)
        ]
    else:
        # UMAP aligned needs a 'relation dict' for the 'time varying' dataset
        #   our relation is that each traj maps to itself (in time) -- constant relation
        constant_dict = {i: i for i in range(nn)}
        constant_relations = [constant_dict for _ in range(num_datasets - 1)]
        print('Beginning AlignedUMAP...')
        aligned_mapper.fit(X_multi, relations=constant_relations)
        print('...done')

    # Step 3: Convert to data_subdict format
    # TODO - implement saving
    embedded_datasets = {}
    for idx in range(num_datasets):
        partial_subdict = {'label': manyruns_dirnames[idx],
                           'path': manyruns_paths[idx]}
        embedded_datasets[idx] = build_data_subdict_from_alignedUMAP(
            partial_subdict,
            aligned_mapper.embeddings_[idx]
        )

    return embedded_datasets


def tissue_to_num_unique_cells(X, num_cells, num_genes):
    """
    Count the distinct single-cell states within one tissue state.

    X is either already a (num_genes, num_cells) array, or a flat vector of length
    num_genes * num_cells that is reshaped column-major (order='F'), so each
    consecutive run of num_genes entries becomes one cell's column.
    Returns (number of unique columns, array of the unique columns).
    """
    expected_shape = (num_genes, num_cells)
    tissue = X
    if tissue.shape != expected_shape:
        assert len(tissue) == num_genes * num_cells
        tissue = tissue.reshape(expected_shape, order='F')

    unique_cells = np.unique(tissue, axis=1)
    return unique_cells.shape[1], unique_cells


def get_unique_states_over_ensemble(ensemble_array, num_cells, num_genes):
    """
    Number of unique single-cell states for each tissue state of an ensemble.

    ensemble_array: 2d array; each row is one tissue state, passed to
                    tissue_to_num_unique_cells
    Returns a 1d int array with one count per row.
    """
    # TODO track with X_energies somehow...
    counts = [
        tissue_to_num_unique_cells(row, num_cells, num_genes)[0]
        for row in ensemble_array
    ]
    return np.array(counts, dtype=int)

In [9]:
def add_PCA_to_embedded_datasets(embedded_datasets, n_components=2):
    """
    Fit a full PCA to each dataset's raw data and store it under data_subdict['algos']['pca'].

    For each data_subdict (mutated in place) this adds
    - 'embedding': the first n_components principal-component coordinates (default 2)
    - 'cumsum_x':  component numbers 1..K
    - 'cumsum_y':  cumulative explained-variance ratio of those K components
    Raises AssertionError if a data_subdict already has a 'pca' entry.
    Returns the same embedded_datasets dict.
    """
    for data_subdict in embedded_datasets.values():
        algos = data_subdict['algos']
        assert 'pca' not in algos.keys()
        # get full PCA embedding (all components) so the variance curve is complete
        pca_full = PCA()
        embedding = pca_full.fit_transform(data_subdict['data'])
        exp_var_cumul = np.cumsum(pca_full.explained_variance_ratio_)
        # assign only after a successful fit, so a failure leaves no partial entry
        algos['pca'] = {
            'embedding': embedding[:, :n_components],
            'cumsum_x': np.arange(1, exp_var_cumul.shape[0] + 1),
            'cumsum_y': exp_var_cumul,
        }
    return embedded_datasets


# Functions: Plotting

In [10]:
def axis_bounds(embedding):
    """
    Plot limits [xmin, xmax, ymin, ymax] for a 2d embedding, suitable for ax.axis(...).

    Each axis is padded on both sides by 10% of its data range; only the first two
    columns of the embedding are considered.
    """
    pad_frac = 0.1
    limits = []
    for coords in (embedding.T[0], embedding.T[1]):
        lo, hi = coords.min(), coords.max()
        pad = (hi - lo) * pad_frac
        limits.extend([lo - pad, hi + pad])
    return limits

In [11]:
def plot_fpdata_as_grid(
    embedded_datasets, settings_alignment, fpdata_type, fpath, 
    gamma_selection=None, nrow=2, ncol=4, ext='.jpg', logy=False): 
    """
    Save a grid of histograms of per-fixed-point data, one panel per selected gamma.

    Fixed point (FP) data forms:
    - 'energies': energy array (column 0 of data_subdict['energies'] is histogrammed)
    - 'nunique': number of unique single cell states (1 value per fp)
    Args:
    - embedded_datasets: dict of data_subdicts objects, keyed by gamma index
    - settings_alignment: dict of manyruns info from jupyter notebook globals
      (reads 'gamma_list' and, when gamma_selection is None, 'subplot_selection')
    - fpdata_type: 'energies' or 'nunique'
    - fpath: output path without extension ('_logy' is appended when logy=True)
    - gamma_selection: list of integer indices into settings_alignment['gamma_list'];
      spare panels (fewer selections than nrow*ncol) are hidden, and selections
      beyond nrow*ncol are not plotted
    - nrow, ncol: grid shape
    - ext: extension string (e.g. '.jpg') or list of them
    - logy: log-scale y axis
    """
    gamma_values = settings_alignment['gamma_list']
    if gamma_selection is None:
        gamma_selection = settings_alignment['subplot_selection']

    assert fpdata_type in ['energies', 'nunique']
    # TODO smarter bin settings: set after load data, observe statistics?
    if fpdata_type == 'energies':
        bins_settings = 40
        color = 'coral'
        xlabel = 'Energy'
        # energies are stored 2d; histogram column 0
        data_arrays = [embedded_datasets[k][fpdata_type][:, 0] for k in gamma_selection]
    else:
        bins_settings = np.arange(0.5, 40, 1)  # bins centred on the integers 1..39
        color = 'deepskyblue'
        xlabel = 'Number of unique cell states'
        data_arrays = [embedded_datasets[k][fpdata_type] for k in gamma_selection]

    # plot settings
    if logy:
        fpath += '_logy'

    # prints
    print('Picks for plot_fpdata_as_grid():')
    print([(i, gamma_values[i]) for i in gamma_selection])

    # Step 2: plot as grid of histograms (squeeze=False keeps axs 2d, even for 1x1)
    fig, axs = plt.subplots(nrow, ncol, figsize=(18*0.68, 12*0.68), squeeze=False)
    for idx, ax in enumerate(axs.flatten()):
        if idx >= len(gamma_selection):
            ax.axis('off')  # more panels than selections: hide the spare axes
            continue
        gamma_index = gamma_selection[idx]
        gamma_value = gamma_values[gamma_index]
        ptitle = r'%d: idx=%d, $\gamma=%.3f$' % (
            idx, gamma_index, gamma_value)
        ax.set_title(ptitle, fontsize=12)
        ax.hist(data_arrays[idx], bins=bins_settings, color=color)
        if logy:
            ax.set_yscale('log')

    # Step 3: global axis labels
    offset = 0.005
    fig.text(0.5, offset, xlabel, ha='center')
    fig.text(offset, 0.5, 'Number of fixed points',
             va='center', rotation='vertical')

    plt.tight_layout(pad=1.2, w_pad=0.5, h_pad=1.0)
    if isinstance(ext, list):
        for ext_str in ext:
            assert ext_str[0] == '.'
            plt.savefig(fpath + ext_str, dpi=300)
    else:
        plt.savefig(fpath + ext, dpi=300)

In [12]:
def annotate_a_point(ax, embedding, agg_idx, ss=10, zorder=4):
    """
    Mark one embedded point with an empty black circle and label it with its index.

    Args:
    - ax: matplotlib Axes to draw on
    - embedding: 2d array; row agg_idx holds the (x, y) coordinates to mark
    - agg_idx: integer row index into embedding (also used as the label text)
    - ss: marker size passed to ax.scatter (s=...)
    - zorder: drawing order for both the circle and the label
    """
    # Imported locally so this function does not rely on a later notebook cell
    # having already imported matplotlib.patheffects as PathEffects.
    import matplotlib.patheffects as path_effects

    x, y = embedding[agg_idx, :]
    # white halo keeps the marker and its label legible over dense scatter points
    halo = [path_effects.withStroke(linewidth=2, foreground="w")]
    ax.scatter(x, y, s=ss, facecolors='none', edgecolors='k',
               path_effects=halo,
               zorder=zorder)
    ax.annotate('%d' % agg_idx, (x, y),
                xytext=(3, 3),
                textcoords='offset points',
                path_effects=halo,
                zorder=zorder)

In [13]:
import matplotlib.patheffects as PathEffects

def plot_aligned_panels_as_grid(
    embedded_datasets, settings_alignment, fpath, 
    algo='umap',
    gamma_selection=None, ext='.jpg', 
    ss=6, annotate_indices=[], rasterized=True,
    c='energies', skip_cbar=False,
    alpha=1.0, edges=False, black_edges=True):
    """
    Save a 2x4 grid of 2d embeddings (one panel per selected gamma), with optional colorbar.

    Args: 
    - embedded_datasets: dict of data_subdicts (stores metadata, embedding), keyed by gamma index
    - settings_alignment: dict of settings for the notebook; reads 'gamma_list',
      'alignment_wrapper_kwargs' and (if gamma_selection is None) 'subplot_selection'
    - fpath: output path for plots (minus extension)
    - algo: (str) key for data_subdict['algos'][algo] (e.g. 'umap', 'pca')
    - gamma_selection: can override same key in settings_alignment; if it has fewer
      than 8 entries the remaining panels are left blank
    - ext: extension string (e.g. '.jpg') or list of them
    - ss: change marker size for scatter
    - annotate_indices: list of indices to annotate on every panel, or a dict
      {gamma index: list of indices}
    - rasterized: If True, rasterize the scatter points at given dpi
    - c: None (fixed grey), 'energies' or 'nunique' point colouring
    - skip_cbar: if True, no colorbar axes is created
    - alpha: marker face alpha
    - edges: if True, plot edges
        - markers and edges separate calls to scatter
        - alpha affects marker alpha only
        - edge color: black if black_edges, else the face color
    Settings to consider changing:
    - fontsize for title
    Good defaults:
    - ss=6, edges=True (edge linewidth is currently 0.4)
    Notes:
    - nunique cmap is based on 
        sns.color_palette("pastel") or
        sns.color_palette()
        see https://seaborn.pydata.org/generated/seaborn.color_palette.html
    - asserts settings_alignment['alignment_wrapper_kwargs']['subsample_ensemble'] is None
    """
    # this line maintains cmap colors (matplotlib and proplot fighting)
    sns.set(style='white', context='notebook')
    dpi = GLOBAL_DPI

    # parse info from dict: settings_alignment
    gamma_list = settings_alignment['gamma_list']
    nn = settings_alignment['alignment_wrapper_kwargs']['subsample_ensemble']
    assert nn is None
    if gamma_selection is None:
        gamma_selection = settings_alignment['subplot_selection']

    # parse info from dict: embedded_datasets (contains data_subdicts)
    # TODO handler for embedding dimension > 2?
    assert algo in embedded_datasets[0]['algos'].keys()
    # sized from the settings passed in (not a notebook global), indexed by gamma index
    embeddings = [
        embedded_datasets[idx]['algos'][algo]['embedding'][:, 0:2]
        for idx in range(len(gamma_list))
    ]

    # Part A: Figure initialization
    # see https://matplotlib.org/stable/tutorials/intermediate/gridspec.html
    num_rows = 2
    num_cols = 4
    cbar_horizontal = False  # colorbar as an extra row (True) or an extra column (False)
    if skip_cbar:
        fig, axs = plt.subplots(
            num_rows, num_cols, figsize=(12.24, 8.16), dpi=dpi)
    elif cbar_horizontal:
        fig, axs = plt.subplots(
            (num_rows+1), num_cols, figsize=(12.24, 8.16), dpi=dpi,
            gridspec_kw={'height_ratios': [1, 1, 0.1]}
        )
        orientation = 'horizontal'
        # merge the bottom row into one colorbar axes
        gs = axs[-1, -1].get_gridspec()
        cbar_ax = fig.add_subplot(gs[-1, :])
        for ax in axs[-1, :]:
            ax.remove()
    else:
        fig, axs = plt.subplots(
            num_rows, (num_cols+1), figsize=(12.24, 8.16), dpi=dpi,
            gridspec_kw={'width_ratios': [1, 1, 1, 1, 0.1]}
        )
        orientation = 'vertical'
        # merge the right-most column into one colorbar axes
        gs = axs[0, -1].get_gridspec()
        cbar_ax = fig.add_subplot(gs[:, -1])
        for ax in axs[:, -1]:
            ax.remove()
    if not skip_cbar:
        cbar_ax.set(xticks=[], yticks=[])

    # Part B: Plot coloring rules (static, nunique, or energy)
    if c is None:
        fixed_color = 'gainsboro'  # '#5076b7', 'slategray', 'gainsboro'
        cmap = None
    else:
        assert c in ['energies', 'nunique']
        if c == 'energies':
            cmap = CMAP_SPECTRALR

            if not skip_cbar:
                # Each panel is normalized to its own energy min/max in Part D, so the
                # colorbar uses a placeholder [-1, 1] norm and hides its tick labels.
                artificial_norm = Normalize(vmin=-1, vmax=1)
                mappable = ScalarMappable(artificial_norm, cmap=cmap)
                cbar = plt.colorbar(
                    mappable,
                    cax=cbar_ax,
                    orientation=orientation,
                )
                cbar.set_label('Energy')
                cbar.ax.tick_params(size=0)
                if cbar_horizontal:
                    cbar.ax.set_xticklabels([])
                else:
                    cbar.ax.set_yticklabels([])

        else:
            # create bounds for custom discrete cmap
            bounds_upper = 22  # np.max(nunique), or static if proportional cbar
            bounds = [1, 2, 3, 4, 8, 15, bounds_upper]
            bounds_midpts = [0.5 * (bounds[k] + bounds[k+1]) for k in range(len(bounds) - 1)]
            ncolors = len(bounds) - 1
            # tick labels: single values ('1') or inclusive ranges ('4 to 7');
            # the last bin is open-ended
            bounds_labels = [''] * ncolors
            for k in range(len(bounds) - 2):
                lower = bounds[k]
                upper = bounds[k+1] - 1
                if lower == upper:
                    bounds_labels[k] = '%d' % lower
                else:
                    bounds_labels[k] = '%d to %d' % (lower, upper)
            bounds_labels[-1] = r'$\geq%d$' % bounds[-2]

            spectral_auto = False
            if spectral_auto:
                cmap_primitive = CMAP_SPECTRALR
                norm_manual = (np.array(bounds_midpts) - bounds_midpts[0]) / (bounds_midpts[-1] - bounds_midpts[0])
                print(norm_manual)
                color_palette = [cmap_primitive(a) for a in np.linspace(0, 1, ncolors)]
            else:
                # hand-picked entries of the default seaborn palette ("pastel" also tried)
                sns_pastel_indices = [0, 9, 2, 3, 4, 8]
                color_palette = [sns.color_palette()[k]
                                 for k in sns_pastel_indices]
                assert len(color_palette) >= ncolors
                color_palette = color_palette[0:ncolors]
            cmap = ListedColormap(color_palette)

            # norm used to map each panel's nunique values to discrete colors
            norm_nunique = BoundaryNorm(bounds, cmap.N, clip=True)

            if not skip_cbar:
                cbar = ColorbarBase(
                    cbar_ax,
                    cmap=cmap,
                    norm=norm_nunique,
                    ticks=bounds_midpts,
                    spacing='proportional',  # 'proportional', 'uniform'
                    orientation=orientation)
                if cbar_horizontal:
                    cbar.ax.set_xticklabels(bounds_labels)
                else:
                    cbar.ax.set_yticklabels(bounds_labels)
                cbar.set_label('Number of unique single cell states')

    # Part C: Edges (on scatterplot points) or not?
    sc_kwargs = {
        'cmap': cmap,
        's': ss,
        'rasterized': rasterized,
        'alpha': alpha}
    sc_kwargs['edgecolors'] = 'None'
    if edges:
        # outline-only layer drawn beneath the filled markers
        sc_kwargs_edge = sc_kwargs.copy()
        sc_kwargs_edge.pop('edgecolors')
        sc_kwargs_edge['alpha'] = 1.0
        sc_kwargs_edge['facecolor'] = 'None'
        sc_kwargs_edge['linewidths'] = 0.4  # 0.1  # ss / 20.0

    # Part D: Fill in each grid box (default - 2x4 panels)
    for j in range(num_rows * num_cols):
        ax = axs[j // num_cols, j % num_cols]
        if j >= len(gamma_selection):
            # fewer selections than panels: leave the remaining panels blank
            ax.axis('off')
            continue
        i = gamma_selection[j]

        data_subdict = embedded_datasets[i]
        if c == 'energies':
            color_arr = data_subdict['energies'][:, 0]
            cnorm = Normalize(
                vmin=np.min(color_arr), vmax=np.max(color_arr))
            colors = cmap(cnorm(color_arr))
        elif c == 'nunique':
            color_arr = data_subdict['nunique']
            colors = cmap(norm_nunique(color_arr))
        else:
            colors = fixed_color

        ax.set_title(gamma_list[i], fontsize=20, y=-0.12)

        ax.scatter(*embeddings[i].T,
                   **sc_kwargs,
                   c=colors,
                   zorder=3)
        # Use two scatter calls to get edges with different alpha
        if edges:
            edgecolors = 'black' if black_edges else colors
            ax.scatter(*embeddings[i].T,
                       **sc_kwargs_edge,
                       edgecolors=edgecolors,
                       zorder=2)

        # annotate certain elements with empty circles
        if isinstance(annotate_indices, dict):
            # dict of lists for each gamma index
            annotate_indices_specific = annotate_indices[i]
        else:
            # assume its the same list for each gamma value
            annotate_indices_specific = annotate_indices
        for agg_idx in annotate_indices_specific:
            annotate_a_point(ax, embeddings[i], agg_idx, zorder=4)

        ax.set(xticks=[], yticks=[])

    # Part E: Plot Save
    plt.tight_layout()
    if isinstance(ext, list):
        for ext_str in ext:
            assert ext_str[0] == '.'
            plt.savefig(fpath + ext_str, dpi=dpi)
    else:
        plt.savefig(fpath + ext, dpi=dpi)

In [14]:
def plot_single_umap_panel_gridspec(
    data_subdict, fpath, 
    c=None, ext='.jpg', ss=10, annotate_indices=[], 
    rasterized=True, skip_cbar=False, 
    alpha=1.0, edges=False, annotate_key=None):
    """
    Save a single UMAP scatter panel for one data_subdict, optionally with a colorbar.

    Args:
    - data_subdict: dict with data_subdict['algos']['umap']['embedding'] (N x 2), plus
      'energies' / 'nunique' when c requests them
    - fpath: output path (minus extension)
    - c: None (fixed grey), 'energies' (column 0 of data_subdict['energies'], Spectral_r)
         or 'nunique' (discrete binned colormap)
    - ext: extension string (e.g. '.jpg') or list of them
    - ss: marker size
    - annotate_indices: list of embedding row indices to circle and label, or a dict of
      such lists (selected with annotate_key)
    - rasterized: rasterize scatter points at GLOBAL_DPI
    - skip_cbar: if True, draw the scatter only (no colorbar axes)
    - alpha: marker face alpha
    - edges: if True, draw black marker outlines in a second, unfilled scatter beneath
    - annotate_key: key used to pick the list when annotate_indices is a dict
    Raises:
    - ValueError if annotate_indices is a dict and annotate_key is None
    """
    # this line maintains cmap colors (matplotlib and proplot fighting)
    sns.set(style='white', context='notebook')
    dpi = GLOBAL_DPI

    embedding = data_subdict['algos']['umap']['embedding']
    assert embedding.shape[1] == 2

    # Resolve which indices to annotate before any plotting happens.
    # (The dict case previously indexed with an undefined/stale name `i`.)
    if isinstance(annotate_indices, dict):
        if annotate_key is None:
            raise ValueError(
                'annotate_indices is a dict; pass annotate_key to select a list')
        annotate_list = annotate_indices[annotate_key]
    else:
        annotate_list = annotate_indices

    cbar_horizontal = True
    orientation = 'horizontal' if cbar_horizontal else 'vertical'
    if skip_cbar:
        fig = plt.figure(figsize=(5, 5), dpi=dpi)
    elif cbar_horizontal:
        fig, ax = plt.subplots(
            2, 1, figsize=(5.5, 5), dpi=dpi,
            gridspec_kw={'height_ratios': [12, 1]}
        )
    else:
        fig, ax = plt.subplots(
            1, 2, figsize=(5, 5.5), dpi=dpi,
            gridspec_kw={'width_ratios': [12, 1]}
        )

    if c is None:
        color = 'gainsboro'  # '#5076b7', 'slategray', 'gainsboro'
        cmap = None
    else:
        assert c in ['energies', 'nunique']
        if c == 'energies':
            color = data_subdict['energies'][:, 0]  # range(num_runs)
            cmap = 'Spectral_r'
        else:
            nunique = data_subdict['nunique']

            bounds_upper = 22  # np.max(nunique), or static if proportional cbar
            bounds = [1, 2, 3, 4, 8, 15, bounds_upper]
            bounds_midpts = [0.5 * (bounds[k] + bounds[k+1]) for k in range(len(bounds) - 1)]
            ncolors = len(bounds) - 1
            # tick labels: single values ('1') or inclusive ranges ('4 to 7');
            # the last bin is open-ended
            bounds_labels = [''] * ncolors
            for k in range(len(bounds) - 2):
                lower = bounds[k]
                upper = bounds[k+1] - 1
                if lower == upper:
                    bounds_labels[k] = '%d' % lower
                else:
                    bounds_labels[k] = '%d to %d' % (lower, upper)
            bounds_labels[-1] = r'$\geq%d$' % bounds[-2]

            spectral_auto = False
            if spectral_auto:
                cmap_primitive = CMAP_SPECTRALR
                norm_manual = (np.array(bounds_midpts) - bounds_midpts[0]) / (bounds_midpts[-1] - bounds_midpts[0])
                print(norm_manual)
                color_palette = [cmap_primitive(a) for a in np.linspace(0, 1, ncolors)]
            else:
                # hand-picked entries of the default seaborn palette
                # ("Paired", "Set2" and "pastel" palettes were also tried)
                sns_pastel_indices = [0, 9, 2, 3, 4, 8]
                color_palette = [sns.color_palette()[k]
                                 for k in sns_pastel_indices]
                assert len(color_palette) >= ncolors
                color_palette = color_palette[0:ncolors]

            cmap = ListedColormap(color_palette)

            # create scatter array of colors
            norm = BoundaryNorm(bounds, cmap.N, clip=True)
            color = cmap(norm(nunique))

            # use grid spec to make two axes for this block
            if not skip_cbar:
                cbar = ColorbarBase(
                    ax[1],
                    cmap=cmap,
                    norm=norm,
                    ticks=bounds_midpts,
                    spacing='proportional',  # 'proportional', 'uniform'
                    orientation=orientation)
                if cbar_horizontal:
                    cbar.ax.set_xticklabels(bounds_labels)
                else:
                    cbar.ax.set_yticklabels(bounds_labels)
                cbar.set_label('Number of unique single cell states')

    sc_kwargs = {
        'c': color,
        'cmap': cmap,
        's': ss,
        'rasterized': rasterized,
        'alpha': alpha,
        'edgecolors': 'None',
    }
    if edges:
        # Outline-only layer: drop the face colour data ('c', 'cmap'), because
        # matplotlib gives 'c' precedence over facecolor and would fill these markers.
        sc_kwargs_edge = {
            key: val for key, val in sc_kwargs.items() if key not in ('c', 'cmap')}
        sc_kwargs_edge['alpha'] = 1.0
        sc_kwargs_edge['facecolor'] = 'None'
        sc_kwargs_edge['linewidths'] = 0.25
        sc_kwargs_edge['edgecolors'] = 'black'

    ax_sc = plt.gca() if skip_cbar else ax[0]
    sc = ax_sc.scatter(
        embedding[:, 0], embedding[:, 1], **sc_kwargs, zorder=3)
    if edges:
        ax_sc.scatter(
            embedding[:, 0], embedding[:, 1], **sc_kwargs_edge, zorder=2)
    ax_sc.set(xticks=[], yticks=[])

    # annotate certain elements with empty circles
    annotate_ss = 15
    for agg_idx in annotate_list:
        annotate_a_point(ax_sc, embedding, agg_idx, ss=annotate_ss,
                         zorder=4)

    if c == 'energies' and not skip_cbar:
        # orientation must match the colorbar axes shape chosen above
        cbar = plt.colorbar(sc, cax=ax[1], orientation=orientation)
        cbar.ax.tick_params(size=0)

    plt.tight_layout()
    if isinstance(ext, list):
        for ext_str in ext:
            assert ext_str[0] == '.'
            plt.savefig(fpath + ext_str, dpi=dpi)
    else:
        plt.savefig(fpath + ext, dpi=dpi)

# Specify settings and generate/load the AlignedUMAP embeddings

In [15]:
# Seed for UMAP; other values tried previously: 0, 100, 40, 20.
# Note RANDOM_STATE = 6 for July fig. generation
RANDOM_STATE = 6

UMAP_KWARGS_DEFAULT = dict(
    random_state=RANDOM_STATE,
    transform_seed=RANDOM_STATE,
    n_components=2,
    metric='euclidean',
    init='spectral',
    unique=False,
    n_neighbors=15,  # Default: 15
    min_dist=0.1,    # Default: 0.1
    spread=1.0,      # Default: 1.0
    n_epochs=200,    # Default: 200
)

# Alternative settings: currently an independent copy with identical values;
# change entries here to experiment without touching UMAP_KWARGS_DEFAULT.
UMAP_KWARGS_ALT = dict(UMAP_KWARGS_DEFAULT)

In [16]:
# TODO - consider also including a subset parameter for relationsdict (subset to 900 was the preJUly Fig5)


def _make_alignment_settings(gamma_list, dirname_template, umap_kwargs, subplot_selection):
    """
    Build (manyruns_dirnames, manyruns_paths, settings dict) for one AlignedUMAP setup.

    dirname_template is a %-format string applied to each gamma value; paths live
    under RUNS_FOLDER/multicell_manyruns. Every call gets a fresh
    'alignment_wrapper_kwargs' dict.
    """
    dirnames = [dirname_template % gamma for gamma in gamma_list]
    paths = [RUNS_FOLDER + os.sep + 'multicell_manyruns' + os.sep + name for name in dirnames]
    settings = {
        'gamma_list': gamma_list,
        'manyruns_paths': paths,
        'manyruns_dirnames': dirnames,
        'umap_kwargs': umap_kwargs,
        'alignment_wrapper_kwargs': {
            'subsample_ndim': None,
            'subsample_ensemble': None,
            'smod_last': True,
            'use_01': False,
            'skip_alignment': False,
        },
        'subplot_selection': subplot_selection,
    }
    return dirnames, paths, settings


# Settings used to generate Fig 4, 5, associated SI Figures
gamma_list_v1 = [0.0, 0.05, 0.06, 0.07, 0.08,
                 0.09, 0.10, 0.15, 0.20, 0.4,
                 0.6, 0.8, 0.9, 1.0, 20.0]
# note: v1 'subsample_ensemble' was 4000 originally (now None)
manyruns_dirnames_v1, manyruns_paths_v1, settings_aligned_umap_fig5_v1 = _make_alignment_settings(
    gamma_list_v1,
    'Wrandom0_gamma%.2f_10k_periodic_fixedorderV3_p3_M100',
    UMAP_KWARGS_DEFAULT,
    [0, 3, 4, 5, 7, 8, 9, 13])

gamma_list_v2 = [0.0, 0.07, 0.08, 0.09, 0.15, 0.20, 0.4, 1.0]
manyruns_dirnames_v2, manyruns_paths_v2, settings_aligned_umap_fig5_v2 = _make_alignment_settings(
    gamma_list_v2,
    'Wrandom0_gamma%.2f_10k_periodic_fixedorderV3_p3_M100',
    UMAP_KWARGS_ALT,
    np.arange(8))

# v3: same gamma values as v1 (same list object), R1 run folders
gamma_list_v3 = gamma_list_v1
manyruns_dirnames_v3, manyruns_paths_v3, settings_aligned_umap_fig5_v3 = _make_alignment_settings(
    gamma_list_v3,
    'Wrandom0_gamma%.2f_10k_periodic_R1_p3_M100',
    UMAP_KWARGS_DEFAULT,
    [0, 3, 4, 5, 7, 8, 9, 13])

**Choose settings option and embed/align UMAP for each manyruns**

In [None]:
flag_generate = True  # If False, load the previously saved settings instead of regenerating

if flag_generate:
    settings_alignment_chosen = settings_aligned_umap_fig5_v3

    # these asserts are needed for replotting the high dim states as tissues
    #  using raw data within settings_alignment_chosen
    assert not settings_alignment_chosen['alignment_wrapper_kwargs']['use_01']
    assert settings_alignment_chosen['alignment_wrapper_kwargs']['subsample_ndim'] is None
    assert settings_alignment_chosen['alignment_wrapper_kwargs']['subsample_ensemble'] is None

    start = time.time()
    embedded_datasets = gather_data_then_alignedUMAP(
        settings_alignment_chosen['manyruns_paths'],
        settings_alignment_chosen['manyruns_dirnames'],
        settings_alignment_chosen['umap_kwargs'],
        **settings_alignment_chosen['alignment_wrapper_kwargs']
    )
    end = time.time()
    print('Total time (s):', end - start)

    # TODO - also save embedded_datasets (maybe best to save each data_subdict individually).
    #  The previous approach pickled aligned_mapper.embeddings_ to 'aligned_mapper.pkl',
    #  after converting it from a numba typed List to a plain Python list.

    # save settings; `with` guarantees the file is flushed and closed
    pickled_align_misc = NOTEBOOK_OUTDIR + os.sep + 'aligned_misc.pkl'
    with open(pickled_align_misc, 'wb') as f:
        pickle.dump(settings_alignment_chosen, f)

else:
    # load settings (only unpickle files this notebook wrote itself: pickle can run arbitrary code)
    input_pickled_align_misc = NOTEBOOK_OUTDIR + os.sep + 'aligned_misc.pkl'
    with open(input_pickled_align_misc, 'rb') as pickle_file:
        settings_alignment_chosen = pickle.load(pickle_file)

    # TODO - restore embedded_datasets (load each data_subdict object).
    #  Previous approach: load 'aligned_mapper.pkl' into aligned_mapper_embeddings, then
    #  aligned_mapper = umap.AlignedUMAP(**settings_alignment_chosen['umap_kwargs']) and
    #  aligned_mapper.embeddings_ = aligned_mapper_embeddings.
    print('WARNING: embedded_datasets was not loaded from disk; '
          'set flag_generate = True to regenerate it.')


npts, ndim 10000 900




DYNAMICS_FIXED_UPDATE_ORDER: True
Beginning AlignedUMAP...


In [None]:
# Attach a PCA embedding to each data_subdict alongside the UMAP one
# (later cells read it back via data_subdict['algos']['pca']).
# NOTE(review): this reassigns its own input, so re-running the cell re-processes the
#  already augmented object - presumably harmless, confirm in add_PCA_to_embedded_datasets.
embedded_datasets = add_PCA_to_embedded_datasets(embedded_datasets)

In [None]:
#for key in embedded_datasets.keys():
#    embedded_datasets[key]['algos'].pop('pca')

### Part A: Generate an 8-panel figure analogous to Fig. 5

In [None]:
# Panels are the gamma indices picked in the settings (e.g. [0, 3, 4, 5, 7, 8, 9, 13]).
gamma_selection = settings_alignment_chosen['subplot_selection']
fpath = '%s%spicks_nobound_2x4_gammas%d' % (
    NOTEBOOK_OUTDIR, os.sep, len(settings_alignment_chosen['gamma_list']))

plot_args = (embedded_datasets, settings_alignment_chosen)
plot_kwargs = dict(
    annotate_indices=[],
    ext=['.jpg', '.pdf', '.svg'],
    rasterized=True,
    ss=6.0,
    edges=False,
    alpha=1.0,
    black_edges=True,
    skip_cbar=False,
)

# One grid per colouring scheme; the scheme name becomes the file suffix.
# (For a single fixed colour use c = None with suffix '_fixedcolor'.)
for c in ('energies', 'nunique'):
    fpath_spec = fpath + '_' + c
    plot_aligned_panels_as_grid(*plot_args, fpath_spec, **plot_kwargs, c=c)

### Part B: Plot distributions of fixed-point data (energy, number of unique single-cell states) on a similar subplot grid

In [None]:
# build_X_multi takes every wrapper option except the alignment-only flag.
X_multi_kwargs = {
    key: val
    for key, val in settings_alignment_chosen['alignment_wrapper_kwargs'].items()
    if key != 'skip_alignment'
}
X_multi, num_cells, num_genes = build_X_multi(
    settings_alignment_chosen['manyruns_paths'], **X_multi_kwargs)

In [None]:
fpath = '%s%spicks_nobound_2x4_gammas%d' % (
    NOTEBOOK_OUTDIR, os.sep, len(settings_alignment_chosen['gamma_list']))

plot_kwargs = dict(nrow=2, ncol=4, ext=['.jpg', '.pdf'], gamma_selection=None)

# For each fixed-point quantity, draw the histogram grid with a linear and then a log y-axis.
for fpdata_type in ['energies', 'nunique']:
    fpath_figure = '%s_%sHist' % (fpath, fpdata_type)
    plot_args = (embedded_datasets, settings_alignment_chosen, fpdata_type, fpath_figure)
    for logy in (False, True):
        plot_fpdata_as_grid(*plot_args, **plot_kwargs, logy=logy)

### Part C: Plotly style interactive plot of particular subpanel

In [None]:
from multicell.unsupervised_helper import \
    plotly_express_embedding, generate_control_data, plot_given_multicell, make_dimreduce_object

In [None]:
gamma_selection = settings_alignment_chosen['subplot_selection']
gamma_list = settings_alignment_chosen['gamma_list']

# One interactive (plotly) export per selected gamma, each in its own folder.
for idx in gamma_selection:
    data_subdict_selected = embedded_datasets[idx]
    outdir = '%s%splotly_idx%d_gamma%.3f' % (NOTEBOOK_OUTDIR, os.sep, idx, gamma_list[idx])
    os.makedirs(outdir, exist_ok=True)

    plotly_express_embedding(
        data_subdict_selected,
        dirpath=outdir,
        fmod='jupyter',
        color_by_index=False,
        as_landscape=False,
        surf=False,
        step=None,
        show=False,
    )

### Part D: Plot high-dimensional state visualizations of selected embedded points

Notes from the PowerPoint version of Figure 5 (starting from the bottom left):
- BL: 1395, 3717, 2325, 
- TL: 2357, 728,
- Homogeneous + Maze: 636, 3473
- Heterogeneous: 3964, 1135, 1467, 2602

In [None]:
# constants
gamma_list = settings_alignment_chosen['gamma_list']
assert gamma_list[13] == 1.0  # assumption for the manually created dicts below
subset_orig = [1395, 3717, 2325, 2357, 728, 636, 3473, 3964, 1135, 1467, 2602]

# dictionary of points to plot for each data_subdict (i.e. each gamma value UMAP)
agg_points_to_plot_v1 = {
    i: subset_orig for i in range(len(gamma_list))
}

agg_points_to_plot_v2 = {
    0: [0, 1, 2, 3],
    3: [8011, 5709, 8249, 6626, 9345, 9017, 8573, 9229],
    4: [9521, 9904, 9691, 5542, 5402, 8322, 7098, 9852, 9844, 8186],
    5: [9521, 9904, 9691, 5542, 5402, 8322, 7098, 9852, 9844, 8186],
    7: [3886, 1551, 3649, 5492, 3809, 2990, 83, 1787, 8007, 7342, 7882, 4307],
    8: [3886, 1551, 3649, 5492, 3809, 2990, 83, 1787, 8007, 7342, 7882, 4307, 9168, 775, 1512, 37],
    9: [3886, 1551, 3649, 5492, 3809, 2990, 83, 1787, 8007, 7342, 7882, 4307, 9168, 775, 1512, 37],
    13: subset_orig,  # Panel 8: original points
}

agg_points_to_plot_v3_many = {
    13: [2269, 7142,
         7350, 7694,
         6809, 2, 2947,
         8552,         
         9713, 369, 52, 2970, 8237, 6789, 
         6681, 9018, 2722, 2503, 503, 3895, 7831, 
         9385, 1347, 2458, 5189, 7057,
         9742, 5113,
         8891, 1908, 2717, 3515, 5112, 5592, 7105, 9447, 9660,
         1842, 8609, 4302, 6247, 5649, 224, 5776, 2504, 2978, 5727,
         7657,
         8134, 3657, 7849,
         646, 5607, 3132, 4865, 7076,
         5552, 8601, 1828, 2258, 3321,
         6467, 2220, 3611, 7414, 5522
        ]
}


agg_points_to_plot_v4_selected = {
    13: [2269, 7694, 2,           # maze 'A'
         6681, 5189, 5113,        # triple/homog 'B'
         2717,                    # "C"
         2978,                    # "D"
         7657,                    # B'
         6467, 5522               # A'
        ]
}

fmod = ''
smod = settings_alignment_chosen['alignment_wrapper_kwargs']['smod_last']

agg_points_to_plot = agg_points_to_plot_v4_selected

# TODO implement this block -- diff points to plot for each gamma value

In [None]:
from multicell.multicell_replot import plot_tissue_given_agg_idx

# Render every hand-picked aggregate index as a tissue, one output folder per gamma value.
for gamma_idx in agg_points_to_plot:
    gamma_val = gamma_list[gamma_idx]
    outdir = '%s%stissue_g%.3f' % (NOTEBOOK_OUTDIR, os.sep, gamma_val)
    os.makedirs(outdir, exist_ok=True)
    data_subdict = embedded_datasets[gamma_idx]

    print('Plotting points for subdict #%d' % gamma_idx)
    agg_indices = agg_points_to_plot[gamma_idx]
    for agg_idx in agg_indices:
        title = 'Point: %d' % agg_idx  # or None for an untitled plot
        plot_tissue_given_agg_idx(
            data_subdict, agg_idx, fmod, outdir,
            title=title, smod_last=smod, state_int=False)


### Part E: Plot 'annotated' version of alignedUMAP where some elements are circled

In [None]:
# Older set of annotated points (original Fig. 5 picks); kept for switching back.
subset_orig = [1395, 3717, 2325, 2357, 728, 636, 3473, 3964, 1135, 1467, 2602]
fpath_annotate = '%s_annotate' % fpath
ss = 6  # marker size

# Circle the extended candidate list for gamma index 13 (use subset_orig for the older set).
agg_points_to_annotate = agg_points_to_plot_v3_many[13]

In [None]:
plot_args = (embedded_datasets, settings_alignment_chosen, fpath_annotate)
plot_kwargs = dict(ss=ss, annotate_indices=agg_points_to_annotate)

# Same aligned grid as Part A, but with the chosen points circled.
plot_aligned_panels_as_grid(
    *plot_args, **plot_kwargs,
    rasterized=True, ext=['.pdf', '.jpg'])


### Part F: Plot single panel umap

In [None]:
gamma_idx = 13
gamma_val = settings_alignment_chosen['gamma_list'][gamma_idx]
assert gamma_val == 1.0  # the hand-picked index-13 points assume gamma = 1.0
fpath = '%s%sumap_singlepanel_g%.3f' % (NOTEBOOK_OUTDIR, os.sep, gamma_val)

annotate_indices = agg_points_to_plot_v4_selected[13]  # use [] for no annotations
color_choice = 'nunique'  # one of: None, 'energies', 'nunique'

In [None]:
# Single-panel UMAP, markers drawn with edges.
plot_single_umap_panel_gridspec(
    embedded_datasets[gamma_idx],
    fpath,
    c=color_choice,
    annotate_indices=annotate_indices,
    ss=10,
    alpha=1.0,
    edges=True,
    rasterized=True,
    skip_cbar=False,
    ext=['.jpg', '.pdf', '.svg'],
)

In [None]:
# Variant 'altA': identical to the panel above except markers are drawn without edges.
plot_single_umap_panel_gridspec(
    embedded_datasets[gamma_idx],
    fpath + '_altA',
    c=color_choice,
    annotate_indices=annotate_indices,
    ss=10,
    alpha=1.0,
    edges=False,
    rasterized=True,
    skip_cbar=False,
    ext=['.jpg', '.pdf', '.svg'],
)

### Part G: Plot PCA grid and cumsum

In [None]:
gamma_selection = settings_alignment_chosen['subplot_selection']
gamma_list = settings_alignment_chosen['gamma_list']
# Output path stems for the PCA figures (extensions are appended when saving).
fpath_pca_grid, fpath_pca_cumsum = (
    NOTEBOOK_OUTDIR + os.sep + stem for stem in ('pca_picks_2x4', 'pca_picks_cumsum'))

*PCA grid in the style of aligned UMAP grid*

In [None]:
plot_args = (embedded_datasets, settings_alignment_chosen)
plot_kwargs = dict(
    algo='pca',
    annotate_indices=[],
    ext=['.jpg', '.pdf', '.svg'],
    rasterized=True,
    ss=6.0,
    edges=False,
    alpha=1.0,
    black_edges=True,
    skip_cbar=False,
)

# Same layout as the aligned UMAP grid, using the PCA coordinates instead.
# (For a single fixed colour use c = None with suffix '_fixedcolor'.)
for c in ('energies', 'nunique'):
    fpath_spec = fpath_pca_grid + '_' + c
    plot_aligned_panels_as_grid(*plot_args, fpath_spec, **plot_kwargs, c=c)

*PCA cumsum*

In [None]:
# Cumulative eigenvalue curve for each selected gamma.
# Figure size: (5, 3.5) for the full plot; (2.5, 2.5) for an inset with xlim (0, 25).
fig, ax = plt.subplots(figsize=(5, 3.5), dpi=450)
for idx in gamma_selection:
    gamma_val = gamma_list[idx]
    data_subdict_selected = embedded_datasets[idx]
    pca_dict = data_subdict_selected['algos']['pca']
    ax.plot(pca_dict['cumsum_x'], pca_dict['cumsum_y'], '--',
            label=r'$\gamma=%.2f$' % gamma_val)

ax.axvline(300, linestyle='-.', c='k', zorder=1)

# Optional tweaks: ax.set_xscale('log'), ax.set_yscale('log'), ax.set_xlim(0, 25), ax.legend()
ax.set_xlabel('Num. components')
ax.set_ylabel('Cumulative sum of eigenvalues')

fig.subplots_adjust(bottom=0.15, left=0.25)  # alternative: bottom=0.25, left=0.25
for ext_str in ('.svg', '.pdf'):
    fig.savefig(fpath_pca_cumsum + ext_str)
plt.show()

### Part H: Correlation between various features

*Correlation: self-energy and interaction energy*
- use violin plot? 
- energies array has 5 columns:
  - 0: H_quadratic_form (should match index 1)
  - 1: H_multi (the sum of indices 2, 3 and 4)
  - 2: H_self
  - 3: H_app
  - 4: H_pairwise_scaled
- Interesting: compare terms 2 and 4 as $\gamma$ changes
  - can use dual violin plot https://seaborn.pydata.org/generated/seaborn.violinplot.html

In [None]:
# nunique vs H_self (energies column 2) for one gamma index
gamma_idx = 2
xx, yy = embedded_datasets[gamma_idx]['nunique'], embedded_datasets[gamma_idx]['energies'][:, 2]
print(xx.shape, yy.shape, sep='\n')
plt.scatter(xx, yy)
plt.show()

In [None]:
# nunique vs H_pairwise_scaled (energies column 4) for one gamma index
gamma_idx = 2
xx, yy = embedded_datasets[gamma_idx]['nunique'], embedded_datasets[gamma_idx]['energies'][:, 4]
print(xx.shape, yy.shape, sep='\n')
plt.scatter(xx, yy)
plt.show()

*Correlation: energy and num unique single cell states*

In [None]:
gamma_idx = 0
xx = embedded_datasets[gamma_idx]['nunique']
yy = embedded_datasets[gamma_idx]['energies'][:, 0]

# nunique as strings so seaborn treats it as a categorical x-axis
df = pd.DataFrame({'nunique': xx.astype('str'), 'energies_0': yy})

ax = sns.violinplot(data=df, x="nunique", y="energies_0", split=False, palette="muted")

In [None]:
# Reference violin-plot example on seaborn's bundled "tips" dataset.
# Optional styling: sns.set_theme(style="whitegrid")
tips = sns.load_dataset("tips")
ax = sns.violinplot(data=tips, x="day", y="total_bill")

In [None]:
# Inspect the "tips" demo DataFrame loaded in the previous cell.
print(tips)

*Correlation: init cond to final state*

# Bug investigation: nunique>8 for $\gamma=0$

**Notes so far**
- Look in the following directory for state 7000 (nunique=10):
I:\Development\Repositories\biomodels\celltypes\runs\multicell_manyruns\Wrandom0_gamma0.00_10k_periodic_fixedorderV3_p3_M100\s7000\states
- there are files X_0.npz, ... X_4.npz, X_last.npz, X_secondlast.npz
- Issue: X_3.npz and X_4.npz are identical and are valid fixed points. Why is X_last.npz slightly different? Is it a bit-flip corruption or a bug?
- Explanation A: the s7000 directory does not correspond to agg_index 7000 of the aggregate file. We can re-aggregate and re-index (this means the manual plot must be redone unless the mapping is found).

In [None]:
# reaggregate the data
from multicell.multicell_manyruns import aggregate_manyruns

# Disabled snippet, kept as an unused string literal so it never runs on Run All.
# When enabled, it calls aggregate_manyruns on the gamma=0.00 fixedorderV3 manyruns directory,
# writing to its 'aggregate' subdirectory (states and energies, last state only, no plots).
# NOTE(review): presumably this overwrites existing aggregate files - confirm before re-enabling.
"""
dirname = 'Wrandom0_gamma0.00_10k_periodic_fixedorderV3_p3_M100'
basedir = RUNS_FOLDER + os.sep + 'multicell_manyruns' + os.sep + dirname
aggregate_manyruns(
    basedir, 
    agg_subdir='aggregate',
    agg_states=True,
    agg_energy=True,
    agg_plot=False,
    only_last=True)"""

In [None]:
def build_data_subdict_local(data_subdict, X, embedding):
    """
    Complete a partially filled data_subdict from the rerun aggregate files on disk.

    data_subdict: partially complete data_subdict; must contain 'path' (the manyruns
                  directory). It is filled in place and also returned.
    X:            optional raw data 2d arr (num_runs x num_dimensions). If None, it is
                  loaded (and transposed) from aggregate/X_aggregate_rerun_last.npz.
                  (Previously a passed-in X was silently discarded.)
    embedding:    embedded data (num_runs x num_lowdim), stored under
                  data_subdict['algos']['umap']['embedding']; may be None.

    Keys filled in: data, index, energies, num_runs, total_spins,
    multicell_template, algos, nunique.
    Raises AssertionError if the energies and nunique arrays disagree in length.
    """
    manyruns_path = data_subdict['path']
    agg_dir = manyruns_path + os.sep + 'aggregate'
    fpath_state = agg_dir + os.sep + 'X_aggregate_rerun_last.npz'
    fpath_energy = agg_dir + os.sep + 'X_energy_last.npz'
    fpath_pickle = manyruns_path + os.sep + 'multicell_template.pkl'

    # Raw data and energies are transposed on load for umap/other functions.
    # np.load on an .npz keeps the archive open, so close it via the context manager.
    if X is None:
        with np.load(fpath_state) as npz_state:
            X = npz_state['arr_0'].T
    else:
        X = np.asarray(X)
    with np.load(fpath_energy) as npz_energy:
        X_energies = npz_energy['arr_0'].T
    with open(fpath_pickle, 'rb') as pickle_file:
        # unpickling multicell object (trusted, locally generated file only)
        multicell_template = pickle.load(pickle_file)

    # Fill out the rest of data_subdict keys
    num_runs, total_spins = X.shape
    data_subdict['data'] = X
    data_subdict['index'] = list(range(num_runs))
    data_subdict['energies'] = X_energies
    data_subdict['num_runs'] = num_runs
    data_subdict['total_spins'] = total_spins
    data_subdict['multicell_template'] = multicell_template
    data_subdict['algos'] = {
        'umap':
            {'reducer': None,
             'embedding': embedding}
    }

    # Add extra analysis data used within the notebook
    X_nunique = get_unique_states_over_ensemble(
        X,
        multicell_template.num_cells,
        multicell_template.num_genes
    )
    data_subdict['nunique'] = X_nunique.T  # note transpose
    assert (
        data_subdict['energies'].shape[0] ==
        data_subdict['nunique'].shape[0])

    return data_subdict

dirname = 'Wrandom0_gamma0.00_10k_periodic_fixedorderV3_p3_M100'
# Rebuild the gamma = 0 data from the re-aggregated files (no raw data or embedding passed in).
data_subdict_rerun0 = build_data_subdict_local(
    {'label': dirname,
     'path': RUNS_FOLDER + os.sep + 'multicell_manyruns' + os.sep + dirname},
    None,  # X
    None,  # embedding
)

In [None]:
gamma_idx = 0
# Use the re-aggregated data (alternative: data_subdict_gamma0 = embedded_datasets[gamma_idx])
data_subdict_gamma0 = data_subdict_rerun0
xx = data_subdict_gamma0['nunique']
yy = data_subdict_gamma0['energies'][:, 0]

df = pd.DataFrame({'nunique': xx.astype('str'), 'energies_0': yy})

# All runs with 10 unique cell states, then the first few with 9 and with 8
print(df[df['nunique'] == '10'])
for count in ('9', '8'):
    print(df[df['nunique'] == count].head())

# Suspect aggregate indices found above
# (before re-aggregation these were n9 = [428, 465, 508] and n10 = [7000])
agg_idx_n9 = [590, 593, 696]
agg_idx_n10 = [7299]

*Plot potential bug indices*

In [None]:
outdir = NOTEBOOK_OUTDIR + os.sep + 'tissue_bugged_rerun'
os.makedirs(outdir, exist_ok=True)

# The single nunique=10 suspect first, then each nunique=9 suspect.
suspects = [('_bug10', agg_idx_n10[0])] + [('_bug9', a) for a in agg_idx_n9]
for suffix, agg in suspects:
    plot_tissue_given_agg_idx(
        data_subdict_gamma0, agg, suffix, outdir,
        title='bug? %d' % agg, smod_last=True, state_int=True)

In [None]:
agg = 0
sidelength = 10
outpath = outdir + os.sep + 'replot_manual_%d' % agg
multicell = data_subdict_gamma0['multicell_template']
manyruns_path = data_subdict_gamma0['path']
print(manyruns_path)

agg_dir = os.sep.join([manyruns_path, 'aggregate'])
fpath_state = os.sep.join([agg_dir, 'X_aggregate_rerun_last.npz'])

# Row `agg` of the (transposed) aggregate, reshaped to (num_cells, num_genes)
X = np.load(fpath_state)['arr_0'].T
X_state = X[agg, :].reshape(multicell.num_cells, multicell.num_genes)

*Now check to see if these states are indeed fixed points/minima*

In [None]:
agg = 0
fname = 'X_last.npz'
sdir = 's%d' % agg
specific_states_dir = os.sep.join([manyruns_path, sdir, 'states'])
state_path = specific_states_dir + os.sep + fname

# NOTE: replot_scatter_dots_TEST is defined in a later cell - run that cell first.
X_specific = np.load(state_path)['arr_0'].T
print(X_specific.shape)
replot_scatter_dots_TEST(
    X_specific.T, sidelength, outpath,
    fmod='_bugfmodSpecific_%s' % fname[:-4],
    title='bug? %s %d' % (fname, agg),
    state_int=True)

In [None]:
from multicell.multicell_replot import FIXED_COLOURMAP 
import matplotlib as mpl

def state_to_label(state):
    """
    Map a spin state to a unique integer label by reading it as a binary number.

    Entries are read most-significant first: a strictly positive entry is a 1 bit,
    anything else (negative or zero) is a 0 bit. So the all -1 state maps to 0 and
    the all +1 state of length N maps to 2^N - 1; an empty state maps to 0.

    The sign is taken on the raw values (before any integer cast) so non-integer
    states, e.g. from xi corruption experiments, are classified by sign instead of
    being truncated toward zero first. The label is a Python int, so it does not
    overflow for N > 63.
    """
    bits = np.sign(np.asarray(state, dtype=float)) > 0
    label = 0
    for bit in bits:
        label = (label << 1) | int(bit)
    return label


def lattice_square_loc_to_int(loc, sidelength):
    """
    Flatten a (row, col) location on a square lattice to a unique integer (row-major).

    loc:        two-tuple (x, y); each coordinate must satisfy 0 <= coord < sidelength
    sidelength: edge length of the lattice, i.e. sqrt(num_cells)
    Returns x * sidelength + y; raises AssertionError for out-of-range coordinates.
    """
    row = loc[0]
    assert 0 <= row < sidelength
    col = loc[1]
    assert 0 <= col < sidelength
    return col + sidelength * row


def lattice_square_int_to_loc(node_idx, sidelength):
    """
    Inverse of lattice_square_loc_to_int: recover the (x, y) location of a flattened index.

    node_idx:   unique integer index of a cell on the square lattice
    sidelength: edge length of the lattice, i.e. sqrt(num_cells)
    Returns (x, y) with x = node_idx // sidelength (as int) and y = node_idx % sidelength.
    """
    row, col = divmod(node_idx, sidelength)
    return int(row), col


def replot_scatter_dots_TEST(lattice_state, sidelength, outpath,
                        fmod='', state_int=False, cmap=None, title=None,
                        ext=['.jpg', '.svg'], rasterized=True):
    """
    Full info morphology plot: each cell is a coloured square with a 3x3 grid of gene dots.

    lattice_state: 2d array indexed as lattice_state[gene, cell]; genes 0-8 are drawn and
                   cell index k sits at lattice_square_int_to_loc(k, sidelength).
    sidelength:    edge length of the square lattice; only 2, 10 and 20 are supported.
    outpath:       output path stem; files are written to outpath + fmod + ext.
    fmod:          suffix appended to outpath before the extension.
    state_int:     if True, write each cell's integer label (state_to_label) and its
                   plot coordinates on top of the cell.
    cmap:          sequence indexed by the integer label of genes 0-8; entries need at
                   least 3 components (RGB). Defaults to FIXED_COLOURMAP.
    title:         optional title; when given the figure is saved with bbox_inches='tight'.
    ext:           extension string, or list/tuple of extension strings starting with '.'.
    rasterized:    passed to every scatter call.

    A gene that is +1 in a cell is drawn as a white dot inside that cell: genes 0-2 on
    the top row, 3-5 in the middle row, 6-8 on the bottom row, left to right.
    The figure is saved and closed; nothing is returned.
    """
    if cmap is None:
        cmap = FIXED_COLOURMAP

    n = sidelength
    # per-lattice-size drawing settings: (box_lw, eps, lw, boxsize, trisize)
    #  boxsize notes (base 1800): 600, 750, 850 tried; at 990 it forms a grey grid
    size_settings = {
        2:  (4 * 1.5, 0.3,  4 * 2, 95 * 1800, 30 * 50),
        10: (2 * 1.5, 0.25, 2 * 2, 4 * 1800,  4 * 50),
        20: (1.5,     0.25, 2,     1800,      50),
    }
    # validate before any figure is created so a bad size does not leak an open figure
    assert n in size_settings
    box_lw, eps, lw, boxsize, trisize = size_settings[n]
    lw_eps = 0.05
    fontsize = 24

    def state_to_colour_and_morphology(state):
        """Return the RGB colour of a cell: cmap entry for the label of its genes 0-8."""
        # (An older scheme used 8 handpicked colours keyed on genes 2, 5, 8 only.)
        genes = [0, 1, 2, 3, 4, 5, 6, 7, 8]
        label = state_to_label(state[genes])
        unique_colour = cmap[label]
        return unique_colour[0:3]

    # cell centres and fill colours; row i is flipped so row 0 is drawn at the top
    x = np.zeros(n ** 2)
    y = np.zeros(n ** 2)
    colors = np.zeros((n ** 2, 3))
    for i in range(n):
        for j in range(n):
            grid_loc_to_idx = lattice_square_loc_to_int((i, j), sidelength)
            cellstate = lattice_state[:, grid_loc_to_idx]
            colors[grid_loc_to_idx, :] = state_to_colour_and_morphology(cellstate)
            x[grid_loc_to_idx] = j
            y[grid_loc_to_idx] = n - i

    fig, ax = plt.subplots(figsize=(12, 12), dpi=GLOBAL_DPI)
    ax.set_axis_off()

    # outer square per cell, filled with its state colour (orig 0.4 alpha)
    ax.scatter(x, y, marker='s', c=colors, alpha=1.0, s=boxsize,
               ec='k', zorder=1, lw=box_lw, rasterized=rasterized)

    # one white dot per active gene; gene g is at column g % 3 and row g // 3 of a
    # 3x3 grid centred on the cell, with spacing eps
    appendage_z = 2
    for gene_idx in range(9):
        mask = lattice_state[gene_idx, :] == 1
        dx = (gene_idx % 3 - 1) * eps
        dy = (1 - gene_idx // 3) * eps
        ax.scatter(x[mask] + dx, y[mask] + dy,
                   marker=mpl.markers.MarkerStyle(marker='o'), c='white', alpha=1.0,
                   s=trisize, ec='k', zorder=appendage_z, linewidths=lw,
                   rasterized=rasterized)

    if state_int:
        num_cells = lattice_state.shape[1]
        for k in range(num_cells):
            label = state_to_label(lattice_state[:, k])
            i, j = lattice_square_int_to_loc(k, n)
            xc, yc = j, n - i
            ax.text(xc, yc, '%s\n(%d, %d)' % (label, xc, yc),
                    color='black', ha='center', va='center')

    if title is not None:
        ax.set_title(title, fontsize=fontsize)
    plt.axis('off')  # no grid can look nice

    # invisible ticks on cell boundaries (kept for optional gridlines)
    xticks = np.arange(-.5, n, 1)
    yticks = np.arange(-.5, n, 1)
    ax.set_xticks(xticks)
    ax.set_yticks(yticks)
    ax.xaxis.set_ticklabels(['' for _ in xticks])
    ax.yaxis.set_ticklabels(['' for _ in yticks])

    # this crops border
    ax.set_xlim(-0.5 - lw_eps, n - 0.5 + lw_eps)
    ax.set_ylim(0.5 - lw_eps, n + 0.5 + lw_eps)
    fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)
    plt.tight_layout(pad=1.2, w_pad=0.5, h_pad=1.0)

    # save figure
    bbox_inches = None if title is None else 'tight'
    if isinstance(ext, (list, tuple)):
        for ext_str in ext:
            assert ext_str[0] == '.'
            fig.savefig(outpath + fmod + ext_str, bbox_inches=bbox_inches, dpi=GLOBAL_DPI)
    else:
        fig.savefig(outpath + fmod + ext, bbox_inches=bbox_inches, dpi=GLOBAL_DPI)
    plt.close(fig)
    return


*Now check to see if these states are indeed fixed points/minima*

In [None]:
# Local field J s for a single-cell state with genes 0-1 off and the rest on;
# compare its sign with sc_state_init to judge whether the state is self-consistent.
sc_state_init = np.array([-1, -1, 1, 1, 1, 1, 1, 1, 1])
hh = np.matmul(multicell.matrix_J, sc_state_init)
print(np.sign(hh))

In [None]:
# Same check for a single-cell state with only gene 2 off.
sc_state_init = np.array([1, 1, -1, 1, 1, 1, 1, 1, 1])
hh = np.matmul(multicell.matrix_J, sc_state_init)
print(np.sign(hh))

# Testing code for plotting

**Plotting num unique cell states -- special discrete cmap**

Consider also using different marker styles (triangle, square, etc.) for different nunique values.

In [None]:
# Candidate discrete palette for colouring by nunique: 8 colours from HLS space
# (the bare last expression renders as a swatch).
sns.color_palette("hls", 8)

In [None]:
# Candidate discrete palette: seaborn's 'pastel' palette.
sns.color_palette('pastel')

In [None]:
# Candidate discrete palette: the current default seaborn palette.
sns.color_palette()