In [None]:
import time
import sys
import os
import numpy as np

%load_ext autoreload
%autoreload 2

import xfmkit.utils as utils
import xfmkit.argops as argops
import xfmkit.clustering as clustering
import xfmkit.visualisations as vis
import xfmkit.processops as processops
import xfmkit.structures as structures
import xfmkit.entry_processed as entry_processed

#add the parent folder to sys path so we can import from the notebook subfolder
# NOTE(review): this runs after the xfmkit imports above, so it cannot affect
# how those imports are resolved; it only influences imports made later on.
sys.path.insert(0,'..')

# Command-line style arguments forwarded to entry_processed.read_processed.
# NOTE(review): the dataset path is an absolute, machine-specific location;
# the notebook will not run elsewhere without editing it.
#args = ' -d /home/lachlan/CODEBASE/ReadoutXFM/data/processed_maps/nf_demo_short -n 2'
args = ["-d", "/home/lachlan/CODEBASE/ReadoutXFM/data/processed_maps/carlos_full", "-n", "2",]

# Load the previously processed results. The names bound here (pxs, embedding,
# categories, palette) are used by the cells below; classavg is not used again
# in this notebook.
pxs, embedding, categories, classavg, palette = entry_processed.read_processed(args)




In [None]:
import seaborn as sns
import colorcet as cc
from sklearn.preprocessing import normalize

def cluster_colourmap(embedding, categories):
    """
    create an oversampled colour palette and reduce its colours to 2D

    The palette holds three colours for every integer step between the
    smallest and largest value in `categories`, drawn from colorcet's
    glasbey_light list. The RGB values are then passed to
    clustering.reduce with the "UMAP" option and two target components.

    Parameters
    ----------
    embedding :
        Unused; kept so the call signature stays unchanged.
    categories : array-like of numbers
        Category labels; only their minimum and maximum are read.

    Returns
    -------
    palette :
        The seaborn palette containing all generated colours.
    colour_embedding :
        Second value returned by clustering.reduce for the colour array
        (presumably one 2D coordinate per palette colour -- confirm
        against clustering.reduce).
    """
    # Oversample the palette so that later steps can pick a distinct colour
    # per category out of a larger pool.
    colours_per_category = 3

    cat_min = np.min(categories)
    cat_max = np.max(categories)
    # Plain int so the palette size is a valid colour count whatever dtype
    # `categories` carries.
    num_cats = int(cat_max - cat_min + 1)
    num_colours = num_cats * colours_per_category

    palette = sns.color_palette(cc.glasbey_light, num_colours)
    colours = np.array(palette, dtype=np.float32)

    # produce 2D embedding for visualisation; the first return value of
    # clustering.reduce is not needed here
    _unused, colour_embedding = clustering.reduce(colours, "UMAP", target_components=2)

    return palette, colour_embedding

# Build the oversampled palette and the 2D layout of its colours.
newpalette, colour_embedding = cluster_colourmap(embedding, categories) 


In [None]:
def norm_onto_2d(colour_embedding, embedding):
    """
    normalise one 2D array of values onto the other, by axis

    For every column index of `embedding`, the matching column of
    `colour_embedding` is min-max rescaled so that it spans the same
    [min, max] interval as the `embedding` column. Neither input is
    modified. Columns of `colour_embedding` beyond the column count of
    `embedding` are returned unchanged.

    Parameters
    ----------
    colour_embedding : 2D array-like
        Values to rescale.
    embedding : 2D array-like
        Values whose per-column range is the target.

    Returns
    -------
    numpy.ndarray
        Rescaled copy of `colour_embedding`. Floating-point input keeps its
        dtype; any other dtype is promoted to float64 so fractional results
        are not truncated.

    Raises
    ------
    ValueError
        If either array is not 2D.
    """
    colour_embedding = np.asarray(colour_embedding)
    embedding = np.asarray(embedding)

    if not (colour_embedding.ndim == 2 and embedding.ndim == 2):
        raise ValueError("both arrays must be 2D")

    # Work in floating point: assigning scaled values into an integer array
    # would silently truncate them.
    if np.issubdtype(colour_embedding.dtype, np.floating):
        out_dtype = colour_embedding.dtype
    else:
        out_dtype = np.float64
    rescaled = colour_embedding.astype(out_dtype, copy=True)

    for i in range(embedding.shape[1]):
        target_min = np.min(embedding[:, i])
        target_span = np.max(embedding[:, i]) - target_min

        column = rescaled[:, i] - np.min(rescaled[:, i])
        source_span = np.max(column)

        # A constant source column has zero span; dividing would give 0/0 and
        # fill the column with NaN. Leave it at zero so it maps onto the
        # target minimum instead.
        if source_span > 0:
            column = column / source_span

        rescaled[:, i] = column * target_span + target_min

    return rescaled


# Rescale the colour layout so each axis spans the same range as the data embedding.
normed = norm_onto_2d(colour_embedding, embedding)

# One integer label per row of `normed` (0 .. n-1); passed below as the
# category/hue values when plotting the colour layout.
palcat = np.arange(0,normed.shape[0])



In [None]:
def new_embedplot(embedding, categories, palette):
    """
    scatter plot of a 2D embedding with marginal axes, coloured by category

    Parameters
    ----------
    embedding : 2D array
        Point coordinates; the first column is plotted on x and the second
        on y.
    categories : array-like
        Hue value for each point.
    palette :
        Palette forwarded to seaborn for the hue values.

    Returns
    -------
    The figure object of the seaborn joint plot, so the caller can save or
    further adjust it.
    """
    x = embedding.T[0]
    y = embedding.T[1]

    ### scatter plot with marginal axes
    sns.set_style('white')

    embed_plot = sns.jointplot(x=x, y=y,
                hue=categories, palette=palette,
                lw=0,
                joint_kws = dict(alpha=1.0),
                height=12, ratio=6
                )

    embed_plot.set_axis_labels('x', 'y', fontsize=16)

    # Drop the legend if one was drawn; guard against the case where the
    # joint axes carry no legend at all.
    legend = embed_plot.ax_joint.get_legend()
    if legend is not None:
        legend.remove()

    sns.despine(ax=None, left=True, bottom=True)

    return embed_plot.fig

# Plot the rescaled colour layout, giving every palette entry its own colour.
new_embedplot(normed, palcat, newpalette)

# Plot the data embedding with the palette that was loaded with the data.
fig = vis.seaborn_embedplot(embedding, categories, palette=palette)

In [None]:
def get_closest_points(normed, centroids):
    """
    for each centroid, find the index of the nearest not-yet-used point

    Centroids are processed in order. Each one is assigned the row of
    `normed` with the smallest squared Euclidean distance among the rows
    that no earlier centroid has claimed, so every returned index is
    unique. Ties are resolved in favour of the lowest row index.

    Parameters
    ----------
    normed : 2D array, shape (n_points, n_dims)
        Candidate points.
    centroids : 2D array, shape (n_centroids, n_dims)
        Query points.

    Returns
    -------
    numpy.ndarray of int32, shape (n_centroids,)
        Row index into `normed` chosen for each centroid.

    Raises
    ------
    ValueError
        If there are more centroids than points, so unique assignment is
        impossible.
    """
    normed = np.asarray(normed)
    centroids = np.asarray(centroids)

    n_centroids = centroids.shape[0]
    if n_centroids > normed.shape[0]:
        raise ValueError(
            f"cannot assign {n_centroids} centroids to unique points: "
            f"only {normed.shape[0]} points available"
        )

    closest = np.zeros(n_centroids, dtype=np.int32)

    # Track claimed indices separately. Testing membership in the
    # zero-initialised result array would treat index 0 as taken before it
    # has ever been assigned.
    taken = set()

    for i in range(n_centroids):
        offsets = normed - centroids[i, :]
        dist = np.sum(offsets ** 2, axis=1)

        # Walk candidates from nearest to farthest; the stable sort keeps
        # tied distances in ascending index order.
        order = np.argsort(dist, kind="stable")
        result = next(int(j) for j in order if int(j) not in taken)

        taken.add(result)
        closest[i] = result

    return closest

# Per-category centroids of the data embedding (computed by utils.compile_centroids).
centroids = utils.compile_centroids(embedding, categories)

# For each centroid, the index of the nearest unused entry in the colour layout.
closest = get_closest_points(normed, centroids)

# Sanity check shown as cell output: number of picked indices, number of
# centroids, and the minimum, maximum and range of the category labels.
closest.shape, centroids.shape[0], np.min(categories), np.max(categories), np.max(categories)-np.min(categories)






In [None]:
import copy


def extract_by_index(newpalette, closest):
    """
    pick palette entries by index and force the first one to grey

    A container of the same type as `newpalette` is filled with
    `newpalette[i]` for every `i` in `closest`, in order. The first entry
    of the result is then replaced by mid grey (0.5, 0.5, 0.5). The input
    palette is left untouched. The container type is printed twice along
    the way.

    Raises IndexError if `closest` is empty, as there is no first entry to
    replace.
    """
    mid_grey = (0.5, 0.5, 0.5)

    # Deep-copy and then empty the copy: this yields an empty container that
    # still has the same type as the palette passed in.
    picked = copy.deepcopy(newpalette)
    print(type(picked))
    picked.clear()
    print(type(picked))

    picked.extend(newpalette[idx] for idx in closest)

    # Overwrite the leading colour with grey.
    picked[0] = mid_grey

    return picked

# Assemble the final palette: one picked colour per centroid, first entry grey.
final_palette = extract_by_index(newpalette, closest)

In [None]:
# Show the assembled palette as the cell output.
final_palette

In [None]:
# Colour layout again, for side-by-side comparison with the recoloured embedding.
new_embedplot(normed, palcat, newpalette)

# Data embedding drawn with the newly assembled palette.
fig = vis.seaborn_embedplot(embedding, categories, palette=final_palette)

In [None]:
# Draw the category map with the new palette.
# NOTE(review): presumably renders `categories` as an image of size
# pxs.dimensions -- confirm against vis.category_map_direct.
fig_cat_map = vis.category_map_direct(categories, pxs.dimensions, palette=final_palette)

