# UMAP feature embedding

In this notebook we use UMAP to visualize embeddings of the features of the different datasets, using features extracted from different layers of the model.
We visualize the distributions of the three feature classes - background, foreground and border - in the embedding space.

# Imports

In [None]:
import umap
import umap.plot

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt

import pickle
import pathlib as pl
import os

# Utils

In [None]:
def load_pckl(file_name, path=None):
    """Read a pickled object from disk and return it.

    Args:
        file_name: Name (or full path) of the pickle file.
        path: Optional directory that is joined in front of ``file_name``.

    Returns:
        The unpickled object.

    Note:
        Unpickling runs code embedded in the file, so only load trusted files.
    """
    full_path = file_name if path is None else os.path.join(path, file_name)
    with open(full_path, 'rb') as handle:
        return pickle.load(handle)


def save_pckl(d, file_name, pr=None, path=None):
    """Pickle an object to disk.

    Args:
        d: Object to serialize.
        file_name: Name (or full path) of the output file.
        pr: Pickle protocol; ``None`` selects ``pickle.DEFAULT_PROTOCOL``.
        path: Optional directory that is joined in front of ``file_name``.
    """
    target = file_name if path is None else os.path.join(path, file_name)
    protocol = pickle.DEFAULT_PROTOCOL if pr is None else pr
    with open(target, 'wb') as handle:
        pickle.dump(d, handle, protocol=protocol)

In [None]:
# Root directory of the data; the per-dataset feature files are read from
# data_root/'features'/<dataset>.
# NOTE(review): absolute, machine-specific Windows path - adjust before running elsewhere.
data_root = pl.Path(r'D:\Projects_MV\conv_paint\data')

# Load data

In [None]:
# Dataset names; each one is used as a sub-directory name under data_root/'features'.
datasets = [ 'actin', 'fret', 'nuclei', 'spindle', 'worms', 'worms2']
# Default selection (first dataset); the embedding loop rebinds `ds` for every dataset it processes.
ds = datasets[0]

We define the sizes of the feature blocks, as well as their cumulative sum, which gives the column boundaries of each block in the feature matrix.

In [None]:
# Number of feature columns in each block: 22 blocks, the same 11-entry pattern repeated twice.
blocks = [1, 64, 64, 256, 512, 512, 64, 64, 256, 512, 512, 1, 64, 64, 256, 512, 512, 64, 64, 256, 512, 512]

In [None]:
# Total number of feature columns covered by all blocks (displayed as the cell output).
sum(blocks)

In [None]:
# Column boundaries of the feature blocks: block i spans columns csb[i]:csb[i + 1].
# The prepended 0 makes the first block start at column 0; the result is kept as a
# plain list so it can be sliced and concatenated with `+`.
csb = list(np.cumsum([0] + blocks))

This code block does everything in one go. 
1. Load the features and targets, concatenate them
2. Randomly sample `n_sample` samples. Here we try two options - 200k and 50k samples.
3. For each block of features as well as for all features together, perform 2D UMAP embedding
4. Visualize the embedding:
    a. points visualization with color coding of the classes by UMAP plot
    b. local dim visualization by UMAP plot
    c. normal scatter plot of the embedding with color coding of the classes
5. Save the plots
6. Save the embedded dataset `d_emb` as a pickle file which includes
    a. subsampled features
    b. subsampled labels
    c. blocks sizes
    d. embeddings for each block and for all features together

In [None]:
# For every (n_sample, n_neighbors) configuration and every processed dataset:
#   1. load and concatenate all feature/target pickles of the dataset,
#   2. draw a random subsample of at most n_sample rows,
#   3. fit a 2D UMAP embedding per feature block and one for all columns together,
#   4. save a three-panel figure (PNG + PDF) per embedding,
#   5. pickle the subsample together with all embeddings next to the feature files.
n_sample_opt = [200000, 50000]
n_neighbors_opt = [300, 300]

# Plot settings do not depend on the loop variables, so they are defined once.
plot_scale = 7
cmap = plt.get_cmap('jet', lut=3)  # colormap resampled to 3 entries

for n_sample, n_neighbors in zip(n_sample_opt, n_neighbors_opt):
    # NOTE(review): only datasets[3:] are processed, not the full list - confirm this is intended.
    for ds in datasets[3:]:
        p = data_root/'features'/ds
        all_features_all_samples = []
        all_targets_all_samples = []

        for pi in p.glob('features_*.pckl'):
            # Test the file name only, so that a '_sample_' substring somewhere in
            # the directory path cannot exclude every file.
            if '_sample_' in pi.name:
                continue
            # NOTE: load_pckl unpickles the file; only trusted files should be placed here.
            d = load_pckl(pi)
            # feature_info is rebound per file, so the entry of the last loaded file
            # is the one used for the plot titles below.
            features_all_samples, targets_all_samples, feature_info = d['features'][0], d['targets'][0], d['feature_info']
            all_features_all_samples.append(np.asarray(features_all_samples))
            all_targets_all_samples.append(np.asarray(targets_all_samples))

        # np.concatenate raises on an empty list; skip the dataset with a clear message instead.
        if not all_features_all_samples:
            print(f'{ds}: no feature files found in {p}, skipping')
            continue

        # Rebinding the same names releases the per-file lists once the arrays are built.
        all_features_all_samples = np.concatenate(all_features_all_samples, axis=0)
        all_targets_all_samples = np.concatenate(all_targets_all_samples, axis=0)
        print(f'all_features_all_samples.shape={all_features_all_samples.shape}; all_targets_all_samples.shape={all_targets_all_samples.shape}')

        # Random subsample without replacement, capped at the number of available rows.
        n = len(all_features_all_samples)
        n_ss = min(n_sample, n)
        idx = np.random.permutation(n)[:n_ss]

        f = all_features_all_samples[idx]
        l = all_targets_all_samples[idx]

        d_emb = {
            'f': f,
            'l': l,
            'blocks': blocks,
        }

        # One (start, end) column range per block, followed by the range spanning all columns.
        block_ranges = list(zip(csb[:-1], csb[1:]))
        all_range = (csb[0], csb[-1])
        for b, e in block_ranges + [all_range]:
            subset = f[:, b:e]
            print(ds, b, e, subset.shape)
            mapper_1 = umap.UMAP(n_neighbors=n_neighbors, verbose=True, low_memory=True).fit(subset)
            emb = mapper_1.embedding_.copy()
            d_emb[f'emb_{b}_{e}'] = emb

            fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(3*plot_scale, plot_scale))

            # The two umap.plot panels are best-effort: a failure is printed and the
            # remaining panels are still drawn and saved.
            try:
                umap.plot.points(mapper_1, labels=l, theme='fire', ax=ax[0])
            except Exception as ex:
                print(ex)

            try:
                umap.plot.diagnostic(mapper_1, diagnostic_type='local_dim', ax=ax[1])
            except Exception as ex:
                print(ex)

            # Plain scatter of the embedding coloured by (label - 1).
            # NOTE(review): presumably labels are 1, 2, 3 so that they map onto the
            # 3-entry colormap - confirm against the target encoding.
            sc = ax[2].scatter(*emb.T, c=(l-1), cmap=cmap, s=3)  # s controls the dot size
            cbar = fig.colorbar(sc, ax=ax[2])
            cbar.set_label("Class")

            ax[0].set_title("Embedding Points")
            ax[1].set_title("Local dim")
            ax[2].set_title("Embedding Scatter")

            if (b, e) == all_range:
                # The combined run starts at column 0 like the first block; deriving
                # the title from feature_info[b] would give both figures the same title.
                ttl = 'all features'
            else:
                # Title from the info entry of the block's first column, with the
                # '_<fib[0]>' suffix removed from fib[1].
                fib = feature_info[b]
                ttl = fib[1].replace(f'_{fib[0]}', '')
            fig.suptitle(ttl)
            fig.tight_layout(h_pad=0, w_pad=1)

            # NOTE: the file name encodes the requested n_sample, which can exceed the
            # number of rows actually used (n_ss) for small datasets.
            stem = f"plots_emb_{n_sample//1000}k_{n_neighbors}nn_b{b:05d}_e{e:05d}"
            for ext in ('png', 'pdf'):
                fig.savefig(p/f"{stem}.{ext}", dpi=300, bbox_inches='tight')
            plt.show()
            plt.close(fig)

            # Drop the fitted model before the next fit to keep peak memory down.
            del mapper_1

        save_pckl(d_emb, f'embedding_sample_{n_sample//1000}k_{n_neighbors}nn.pckl', path=p.as_posix())
        del d_emb