In [None]:
import glob
import sys
import os
import gc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import yaml
import re
import joblib
from sklearn.cluster import AgglomerativeClustering
from tqdm import tqdm
from scipy.ndimage import gaussian_filter
from sklearn.neighbors import NearestNeighbors
from sklearn.cluster import KMeans
from collections import defaultdict
import numba as nb
import javabridge
import bioformats
import aicspylibczi as aplc
import skimage.restoration as skr


In [None]:
# Run a full garbage collection to release memory held by earlier runs in this kernel
gc.collect()


Move to the working directory

In [None]:
# Absolute path to the cluster root. It can be overridden with the MGEFISH_CLUSTER
# environment variable (e.g. a different mount point); the default is the
# original location.
cluster = os.environ.get('MGEFISH_CLUSTER', '/fs/cbsuvlaminck2/')
# os.path.join avoids the double slash that plain concatenation produced
project_workdir = os.path.join(
    cluster,
    'workdir/bmg224/manuscripts/mgefish/code/harvard_plasmids_imaging/2023_09_12_sapp05_spades08',
)

os.chdir(project_workdir)
os.getcwd()  # Make sure you're in the right directory

Load config file

In [None]:
# Config file for this analysis, relative to the working directory set above
config_fn = 'config_matrix_classify.yaml'

with open(config_fn) as config_file:
    config = yaml.safe_load(config_file)

Special imports

In [None]:
# Reload the project modules automatically when their source files change
%load_ext autoreload
%autoreload 2

# Make the pipeline's function modules importable; their location is built from the config
sys.path.append(cluster + config['pipeline_path'] + '/' + config['functions_path'])
import fn_general_use as fgu
import image_plots as ip
import segmentation_func as sf
import fn_hiprfish_classifier as fhc
import fn_spectral_images as fsi



## Preprocessing
### Pick a sample

In [None]:
# Table of samples to analyse; its path comes from the config
input_table = pd.read_csv(config['input_table_fn'])
input_table

In [None]:
# Row of input_table to analyse
sn_i = 6
sample_row = input_table.iloc[sn_i]
# Unpack sample name, mask threshold and all-fluor value (they fill the shift_dir
# template below). Assumes the table has exactly these three columns in this
# order — TODO confirm if columns are added.
sn, mt, af = sample_row.values
sn

Load registered image

In [None]:
# Z-slice (stack index) to load
m = 3

shift_fmt = config['output_dir'] + '/' + config['shift_dir']
# Dots in the all-fluor value become underscores to match the directory naming.
# str.replace does the same thing as the regex without the invalid escape
# sequence in a non-raw string (a SyntaxWarning on Python >= 3.12).
af_str = str(af).replace('.', '_')
shift_dir = shift_fmt.format(sample_name=sn, maskthresh=mt, allfluor=af_str)
# One registered array per matching file for this z-slice, in sorted filename order
raw_fns = sorted(glob.glob(shift_dir + '/' + sn + '_M_' + str(m) + '_*'))
raws = [np.load(f) for f in raw_fns]
[r.shape for r in raws]


Plot RGB

In [None]:
# Per-channel contrast limits (low, high), applied after normalisation
clips = [(0.075,0.75),(0.075,0.5),(0.05,0.5)]

# Max projection over axis 2 (channels) of each laser image
rgb = [np.max(r, axis=2) for r in raws]
# Normalisation bounds are shared across all images so their relative brightness is
# kept. They do not depend on the loop variable, so compute them once.
mx = np.max(rgb)
mn = np.min(rgb)
rgb_adj_lst = []
for r, clip in zip(rgb, clips):
    r_norm = (r - mn) / (mx - mn)
    # Clip to the channel's contrast window, then stretch that window to [0, 1]
    r_adj = np.clip(r_norm, clip[0], clip[1])
    r_adj = (r_adj - clip[0]) / (clip[1] - clip[0])
    rgb_adj_lst.append(r_adj)
rgb_adj = np.dstack(rgb_adj_lst)


In [None]:
# Each adjusted channel on its own, then the three-channel composite
for channel_img in rgb_adj_lst:
    ip.plot_image(channel_img, cmap='inferno', im_inches=10)
ip.plot_image(rgb_adj, im_inches=10)

### Smoothing

In [None]:
def get_smooth(m_raws_shift, sigma):
    """Gaussian-smooth every channel of every image independently.

    Parameters
    ----------
    m_raws_shift : list of ndarray
        Images indexed as (row, col, channel).
    sigma : float or sequence
        Passed unchanged to ``scipy.ndimage.gaussian_filter`` for each 2-D channel.

    Returns
    -------
    list of ndarray
        float64 arrays with the same shapes as the inputs.
    """
    def _smooth_one(im):
        # Filter channel by channel so the channel axis is never blurred
        out = np.empty(im.shape)
        for ch in range(im.shape[2]):
            out[:, :, ch] = gaussian_filter(im[:, :, ch], sigma=sigma)
        return out

    return [_smooth_one(im) for im in m_raws_shift]

# Smooth every channel of each registered image with the configured sigma
raws_smooth = get_smooth(raws, config['sigma'])

Plot RGB

In [None]:
# Per-channel contrast limits (low, high) for the smoothed images
clips = [(0.075,0.75),(0.075,0.35),(0.05,0.35)]

# Max projection over axis 2 (channels) of each smoothed laser image
rgb = [np.max(r, axis=2) for r in raws_smooth]
# Shared normalisation bounds over all images; loop-invariant, so computed once
mx = np.max(rgb)
mn = np.min(rgb)
rgb_smooth_adj_lst = []
for r, clip in zip(rgb, clips):
    r_norm = (r - mn) / (mx - mn)
    # Clip to the channel's contrast window, then stretch that window to [0, 1]
    r_adj = np.clip(r_norm, clip[0], clip[1])
    r_adj = (r_adj - clip[0]) / (clip[1] - clip[0])
    rgb_smooth_adj_lst.append(r_adj)
rgb_smooth_adj = np.dstack(rgb_smooth_adj_lst)


In [None]:
# Smoothed channels individually, then the composite
for smoothed_channel in rgb_smooth_adj_lst:
    ip.plot_image(smoothed_channel, cmap='inferno', im_inches=10)
ip.plot_image(rgb_smooth_adj, im_inches=10)

### Masking

In [None]:
# Stack all lasers' channels and take the per-pixel maximum over them
stack = np.dstack(raws)
raw_max = stack.max(axis=2)
# Summary statistics of the max projection, printed as "name value" pairs
summary_stats = [('max', np.max), ('min', np.min), ('mean', np.mean),
                 ('std', np.std), ('med', np.median)]
print(*[token for name, stat in summary_stats for token in (name, stat(raw_max))])

In [None]:
# Foreground = pixels brighter than mean + n_std standard deviations of the max projection
n_std = 1

thresh = raw_max.mean() + n_std * raw_max.std()
mask = raw_max > thresh
ip.plot_image(mask, im_inches=10)
ip.plot_image(raw_max * mask, cmap='inferno', im_inches=10)

## Zoom in on some cells

In [None]:
# Zoom window: top-left corner (row, col) and size (rows, cols)
c = [624,650]
w = [25,25]

row_sl, col_sl = slice(c[0], c[0] + w[0]), slice(c[1], c[1] + w[1])
rgb_zoom = rgb_adj[row_sl, col_sl]
ip.plot_image(rgb_zoom)

In [None]:
dims=(10,5)

stack_smooth = np.dstack(raws_smooth)
stack_smooth_zoom = stack_smooth[c[0]:c[0]+w[0], c[1]:c[1]+w[1],:]
# Flatten the window to (n_pixels, n_channels) in row-major order. This replaces
# boolean indexing with np.bool (removed in NumPy 1.24) and also works when the
# window is clipped at the image border.
spec_zoom = stack_smooth_zoom.reshape(-1, stack_smooth_zoom.shape[2])
fig, ax = ip.general_plot(dims=dims)
fsi.plot_cell_spectra(ax, spec_zoom, {'lw':0.2,'alpha':0.1,'color':'r'})


### Classify

Get references

In [None]:
def sum_normalize_ref(ref_spec):
    """Mean reference spectrum per barcode after baseline and sum normalisation.

    Each row of each reference array is shifted so its minimum is 0, then divided
    by its sum. The normalised rows are averaged.

    Parameters
    ----------
    ref_spec : list of array-like
        One 2-D array per barcode, one spectrum per row.

    Returns
    -------
    list of ndarray
        One 1-D mean normalised spectrum per barcode.
    """
    means = []
    for spectra in ref_spec:
        shifted = spectra - np.min(spectra, axis=1, keepdims=True)
        normed = shifted / np.sum(shifted, axis=1, keepdims=True)
        means.append(np.mean(normed, axis=0))
    return means

def get_reference_spectra(barcodes, bc_len, config):
    """Load the reference spectra for each probe barcode.

    Each 5-bit barcode is expanded to a 10-bit code with its bits at positions
    0, 2, 7, 8 and 9, converted to base 10, and substituted into
    ``config['ref_files_fmt']`` to name the reference CSV. Only the configured
    channel columns are kept.

    Parameters
    ----------
    barcodes : iterable
        Barcodes from the probe design; leading zeros lost when they were read
        as integers are restored with ``zfill``.
    bc_len : int
        Barcode length. Only 5 is supported.
    config : dict
        Must contain ``hipr_ref_dir``, ``ref_files_fmt``, ``ref_chan_start``,
        ``chan_start`` and ``chan_end``.

    Returns
    -------
    list of ndarray
        One array per barcode, in ``barcodes`` order (file rows x selected columns).

    Raises
    ------
    ValueError
        If ``bc_len`` is not 5, instead of failing with an UnboundLocalError.
    """
    if bc_len != 5:
        raise ValueError(
            'get_reference_spectra only supports 5-bit barcodes, got length {}'.format(bc_len))
    ref_dir = config['hipr_ref_dir']
    fmt = config['ref_files_fmt']
    barcodes_str = [str(bc).zfill(5) for bc in barcodes]
    # Pad with zeros so the 5 used bits land at positions 0, 2, 7, 8, 9 of the 10-bit code
    barcodes_10bit = [bc[0] + '0' + bc[1] + '0000' + bc[2:] for bc in barcodes_str]
    barcodes_b10 = [int(bc, 2) for bc in barcodes_10bit]
    st = config['ref_chan_start'] + config['chan_start']
    en = config['ref_chan_start'] + config['chan_end']
    ref_avgint_cols = list(range(st, en))

    ref_spec = []
    for bc in barcodes_b10:
        # `cluster` is the notebook-level root path defined at the top of the notebook
        fn = cluster + '/' + ref_dir + '/' + fmt.format(bc)
        ref = pd.read_csv(fn, header=None)
        ref_spec.append(ref[ref_avgint_cols].values)
    return ref_spec

# Get reference spectra
# Load the probe design and the reference spectra for every barcode in it
probe_design_dir = config['probe_design_dir']
probe_design_fn = '/'.join([cluster, probe_design_dir, config['probe_design_filename']])
probe_design = pd.read_csv(probe_design_fn)
barcodes = probe_design['code'].unique()
# Barcode length = number of digits of the largest code. NOTE(review): if codes
# are read as integers, leading zeros are lost; this relies on the largest code
# having all 5 digits.
barcode_length = len(str(np.max(barcodes)))
ref_spec = get_reference_spectra(barcodes, barcode_length, config)
# Scientific name for each barcode (first match in the design), aligned with ref_spec
code_col = probe_design['code']
sci_names = [probe_design.loc[code_col == bc, 'sci_name'].unique()[0] for bc in barcodes]
weights_sum_norm = sum_normalize_ref(ref_spec)

Run classification

In [None]:

def run_matrix_multiply(spec_pix_gem, weights_sum_norm):
    """Score every pixel spectrum against every reference spectrum.

    Returns an (n_pixels, n_references) array of dot products between the rows
    of ``spec_pix_gem`` and the reference spectra in ``weights_sum_norm``.
    """
    weights = np.array(weights_sum_norm)
    return spec_pix_gem @ weights.T

def get_present_filter(probe_design, nlas):
    """Binary (n_lasers, n_design_rows) matrix from the ``laser_present`` column.

    Each value is zero-padded to ``nlas`` digits and split into digits; column j
    of the result holds the digits of design row j. Presumably digit k = 1 means
    laser k is expected for that probe. NOTE(review): the column count equals the
    number of probe-design rows, which must match the number of references used
    by the classifier (i.e. one row per code) — confirm.
    """
    padded_codes = (str(code).zfill(nlas) for code in probe_design['laser_present'].values)
    digit_rows = [list(map(int, code)) for code in padded_codes]
    return np.array(digit_rows).T

def _filter_by_laser_presence(las_max_pix, present_filter, classif_mat, las_frac_thresh):
    """Keep only the classification scores consistent with the lasers that gave signal.

    For each pixel, every laser's maximum is divided by the largest laser maximum.
    A laser counts as present when that fraction exceeds ``las_frac_thresh``.
    Each reference is scored by the number of lasers that are both present in the
    pixel and set in ``present_filter`` for that reference. Only references with
    the pixel's best score keep their ``classif_mat`` value; the rest become 0.
    """
    las_max_pix_norm = las_max_pix / np.max(las_max_pix, axis=1)[:, None]
    las_max_present = las_max_pix_norm > las_frac_thresh
    classif_adj_mat = np.matmul(las_max_present, present_filter)
    classif_adj_mat_bool = classif_adj_mat == np.max(classif_adj_mat, axis=1)[:, None]
    return classif_mat * classif_adj_mat_bool


def remove_possibilities_laserpresent(raws_smooth, mask, classif_mat, present_filter,
                                      las_frac_thresh=None):
    """Laser-presence filtering using per-laser images.

    Parameters
    ----------
    raws_smooth : list of ndarray
        One (rows, cols, channels) image per laser; the per-laser signal is the
        maximum over the channel axis.
    mask : bool ndarray (rows, cols)
        Pixels to keep, in the same order as the rows of ``classif_mat``.
    classif_mat : ndarray (n_pixels, n_references)
    present_filter : ndarray (n_lasers, n_references)
    las_frac_thresh : float or array-like, optional
        Presence threshold(s). Defaults to ``config['laser_absent_thresholds']``
        from the notebook-level config.
    """
    if las_frac_thresh is None:
        las_frac_thresh = config['laser_absent_thresholds']
    las_max_stack = np.dstack([np.max(im, axis=2) for im in raws_smooth])
    las_max_pix = las_max_stack[mask]
    return _filter_by_laser_presence(las_max_pix, present_filter, classif_mat,
                                     las_frac_thresh)


def remove_possibilities_laserpresent_v2(spec, present_filter, classif_mat,
                                         las_frac_thresh=None, las_ranges=None):
    """Laser-presence filtering using flattened spectra.

    The per-laser signal is the maximum of ``spec[:, las_ranges[i]:las_ranges[i+1]]``.
    ``las_frac_thresh`` and ``las_ranges`` default to
    ``config['laser_absent_thresholds']`` and ``config['las_ranges']``.
    """
    if las_frac_thresh is None:
        las_frac_thresh = config['laser_absent_thresholds']
    if las_ranges is None:
        las_ranges = config['las_ranges']
    maxes = np.hstack([
        np.max(spec[:, lo:hi], axis=1)[:, None]
        for lo, hi in zip(las_ranges[:-1], las_ranges[1:])
    ])
    return _filter_by_laser_presence(maxes, present_filter, classif_mat, las_frac_thresh)


def pick_maximum_weight(classif_mat_adj, sci_names):
    """Name of the highest-scoring reference for each row of ``classif_mat_adj``."""
    best = np.argmax(classif_mat_adj, axis=1)
    names = [sci_names[k] for k in best]
    return np.array(names)

def filter_dim_spectra(classifs, spectra_adj, thresh):
    """Relabel dim pixels as 'None'.

    A pixel is dim when its peak intensity divided by the brightest pixel's peak
    is below ``thresh``.

    Returns a new array and leaves ``classifs`` unchanged. ``np.where`` widens the
    string dtype when needed, so 'None' is never truncated (assigning it into a
    '<U3' array stores 'Non').
    """
    spec_max = np.max(spectra_adj, axis=1)
    spec_max_norm = spec_max / np.max(spec_max)
    is_dim = spec_max_norm < thresh
    return np.where(is_dim, 'None', classifs)





In [None]:
# Classify every pixel of the zoom window (no background masking)
present_filter = get_present_filter(probe_design, len(config['lasers']))
spec_pix_gem = spec_zoom
raw_smooth = [im[c[0]:c[0]+w[0], c[1]:c[1]+w[1], :] for im in raws_smooth]
mask_ = np.ones((w[0], w[1]), dtype=bool)

# Score against the references, drop references inconsistent with laser presence,
# keep the best-scoring one, then blank out dim pixels
classif_mat = run_matrix_multiply(spec_pix_gem, weights_sum_norm)
classif_mat_adj = remove_possibilities_laserpresent(raw_smooth, mask_, classif_mat,
                                                    present_filter)
classifs = filter_dim_spectra(pick_maximum_weight(classif_mat_adj, sci_names),
                              spec_pix_gem, thresh=config['dim_spec_filt'])

print(np.unique(classifs))

Plot classification

In [None]:
def classif_to_image(classifs, pix_ind, shape, plot_intensities, col_dict):
    """Paint per-pixel classifications into a colour image.

    Parameters
    ----------
    classifs : sequence
        Label per pixel, aligned with ``pix_ind``.
    pix_ind : tuple (rows, cols)
        Pixel coordinates, e.g. the output of ``np.where``.
    shape : tuple
        (rows, cols) of the output image.
    plot_intensities : 2-D array
        Per-pixel multiplier applied to the colour.
    col_dict : dict
        Label -> colour tuple. The length of the first colour sets the number of
        output channels.

    Returns
    -------
    ndarray of shape ``shape + (n_channels,)``; unlisted pixels stay 0.
    """
    first_colour = list(col_dict.values())[0]
    im_clust = np.zeros(shape + (len(first_colour),))
    rows, cols = pix_ind[0], pix_ind[1]
    for lab, x, y in zip(classifs, rows, cols):
        im_clust[x, y, :] = np.array(col_dict[lab]) * plot_intensities[x, y]
    return im_clust

# Every window pixel is drawn at full intensity, one tab20 colour per class
# (zip stops after 20 classes)
pix_inds = np.where(mask_ > 0)
plot_intensities = np.ones_like(mask_)
col_dict = dict(zip(np.unique(classifs), plt.get_cmap('tab20').colors))
im_classif = classif_to_image(classifs, pix_inds, tuple(w[:2]), plot_intensities, col_dict)

In [None]:
# Classes found in the zoom window
col_dict.keys()

In [None]:
# Classification map of the zoom window with its legend
ip.plot_image(im_classif)
ip.taxon_legend(taxon_names=col_dict.keys(),taxon_colors=col_dict.values())

## Zoom in on another cell

In [None]:
# Zoom window: top-left corner (row, col) and size (rows, cols)
c = [535,485]
w = [20,20]

row_sl, col_sl = slice(c[0], c[0] + w[0]), slice(c[1], c[1] + w[1])
rgb_zoom = rgb_adj[row_sl, col_sl]
ip.plot_image(rgb_zoom)

In [None]:
dims=(10,5)

stack_smooth = np.dstack(raws_smooth)
stack_smooth_zoom = stack_smooth[c[0]:c[0]+w[0], c[1]:c[1]+w[1],:]
# Window shape taken from the data, so windows clipped at the image border still work
win_shape = stack_smooth_zoom.shape[:2]
# Flatten to (n_pixels, n_channels); replaces boolean indexing with np.bool,
# which was removed in NumPy 1.24
spec_zoom = stack_smooth_zoom.reshape(-1, stack_smooth_zoom.shape[2])
fig, ax = ip.general_plot(dims=dims)
# Plot only pixels with some signal
spec_zoom_filt = spec_zoom[np.max(spec_zoom, axis=1) > 0]
fsi.plot_cell_spectra(ax, spec_zoom_filt, {'lw':0.2,'alpha':0.1,'color':'r'})

# Classify every pixel of the window
spec_pix_gem = spec_zoom
raw_smooth = [r[c[0]:c[0]+w[0], c[1]:c[1]+w[1],:] for r in raws_smooth]
mask_ = np.ones(win_shape, dtype=bool)

classif_mat = run_matrix_multiply(spec_pix_gem, weights_sum_norm)
present_filter = get_present_filter(probe_design, len(config['lasers']))
classif_mat_adj = remove_possibilities_laserpresent(raw_smooth, mask_,
        classif_mat, present_filter)
classifs = pick_maximum_weight(classif_mat_adj, sci_names)
classifs = filter_dim_spectra(classifs, spec_pix_gem,
                                thresh=config['dim_spec_filt'])

print(np.unique(classifs))

# Paint the classification map, one tab20 colour per class
pix_inds = np.where((mask_ > 0))
plot_intensities = np.ones_like(mask_)
col_dict = dict(zip(np.unique(classifs), plt.get_cmap('tab20').colors))
im_classif = classif_to_image(classifs, pix_inds, win_shape, plot_intensities, col_dict)

ip.plot_image(im_classif)
ip.taxon_legend(taxon_names=col_dict.keys(),taxon_colors=col_dict.values())

In [None]:
# Max and sum projections of the raw window, then the sum of the smoothed window
stack_zoom_raw = stack[c[0]:c[0]+w[0], c[1]:c[1]+w[1], :]
ip.plot_image(stack_zoom_raw.max(axis=2), cmap='inferno')
ip.plot_image(stack_zoom_raw.sum(axis=2), cmap='inferno')
ip.plot_image(stack_smooth_zoom.sum(axis=2), cmap='inferno')

## Zoom in on another cell

In [None]:
# Zoom window: top-left corner (row, col) and size (rows, cols)
c = [10,955]
w = [10,10]

row_sl, col_sl = slice(c[0], c[0] + w[0]), slice(c[1], c[1] + w[1])
rgb_zoom = rgb_adj[row_sl, col_sl]
ip.plot_image(rgb_zoom)

In [None]:
dims=(10,5)

stack_smooth = np.dstack(raws_smooth)
stack_smooth_zoom = stack_smooth[c[0]:c[0]+w[0], c[1]:c[1]+w[1],:]
# Window shape taken from the data, so windows clipped at the image border still work
win_shape = stack_smooth_zoom.shape[:2]
# Flatten to (n_pixels, n_channels); replaces boolean indexing with np.bool,
# which was removed in NumPy 1.24
spec_zoom = stack_smooth_zoom.reshape(-1, stack_smooth_zoom.shape[2])
fig, ax = ip.general_plot(dims=dims)
# Plot only pixels with some signal
spec_zoom_filt = spec_zoom[np.max(spec_zoom, axis=1) > 0]
fsi.plot_cell_spectra(ax, spec_zoom_filt, {'lw':0.2,'alpha':0.1,'color':'r'})

# Classify every pixel of the window
spec_pix_gem = spec_zoom
raw_smooth = [r[c[0]:c[0]+w[0], c[1]:c[1]+w[1],:] for r in raws_smooth]
mask_ = np.ones(win_shape, dtype=bool)

classif_mat = run_matrix_multiply(spec_pix_gem, weights_sum_norm)
present_filter = get_present_filter(probe_design, len(config['lasers']))
classif_mat_adj = remove_possibilities_laserpresent(raw_smooth, mask_,
        classif_mat, present_filter)
classifs = pick_maximum_weight(classif_mat_adj, sci_names)
classifs = filter_dim_spectra(classifs, spec_pix_gem,
                                thresh=config['dim_spec_filt'])

print(np.unique(classifs))

# Paint the classification map, one tab20 colour per class
pix_inds = np.where((mask_ > 0))
plot_intensities = np.ones_like(mask_)
col_dict = dict(zip(np.unique(classifs), plt.get_cmap('tab20').colors))
im_classif = classif_to_image(classifs, pix_inds, win_shape, plot_intensities, col_dict)

ip.plot_image(im_classif)
ip.taxon_legend(taxon_names=col_dict.keys(),taxon_colors=col_dict.values())

## Zoom in on a larger region

In [None]:
# Zoom window: top-left corner (row, col) and size (rows, cols)
c = [0,900]
w = [100,100]

row_sl, col_sl = slice(c[0], c[0] + w[0]), slice(c[1], c[1] + w[1])
rgb_zoom = rgb_adj[row_sl, col_sl]
ip.plot_image(rgb_zoom)

In [None]:
# Sum over all channels of the smoothed window, then a simple intensity cut-off
# (25000 appears to be chosen by eye for this field of view)
stack_smooth_zoom = stack_smooth[c[0]:c[0]+w[0], c[1]:c[1]+w[1], :]
sum_smooth_zoom = stack_smooth_zoom.sum(axis=2)
ip.plot_image(sum_smooth_zoom, cmap='inferno')
mask_zoom = sum_smooth_zoom > 25000
ip.plot_image(sum_smooth_zoom*mask_zoom, cmap='inferno')

Segment cells

In [None]:
# Run the segmentation pre-processing (gauss=2) on each raw channel and sum the results
stack_pre = [sf.pre_process(stack[:, :, ch], gauss=2) for ch in range(stack.shape[2])]
stack_pre_sum = np.sum(np.dstack(stack_pre), axis=2)
stack_pre_sum_zoom = stack_pre_sum[c[0]:c[0]+w[0], c[1]:c[1]+w[1]]


In [None]:
# Raw channel sum of the window next to the pre-processed sum
raw_sum_window = stack.sum(axis=2)[c[0]:c[0]+w[0], c[1]:c[1]+w[1]]
ip.plot_image(raw_sum_window, cmap='inferno')
ip.plot_image(stack_pre_sum_zoom, cmap='inferno')

In [None]:
# Background mask from the project's segmentation helper (replaces the threshold mask above)
mask_zoom = sf.get_background_mask(stack_pre_sum_zoom)
ip.plot_image(stack_pre_sum_zoom*mask_zoom, cmap='inferno')

In [None]:
# Segment the pre-processed window inside the background mask; show one colour per object
seg_zoom = sf.segment(stack_pre_sum_zoom, background_mask=mask_zoom)
seg_zoom_rgb = ip.seg2rgb(seg_zoom)
ip.plot_image(seg_zoom_rgb)

Get spectra

In [None]:
# Per-object properties (label, bbox and centroid are used below), with the raw
# channel sum of the window passed as the intensity image
stack_sum_zoom = np.sum(stack, axis=2)[c[0]:c[0]+w[0], c[1]:c[1]+w[1]]
seg_zoom_props = sf.measure_regionprops(seg_zoom, raw=stack_sum_zoom)

In [None]:
stack_zoom = stack[c[0]:c[0]+w[0], c[1]:c[1]+w[1],:]
seg = seg_zoom

# Collect the smoothed pixel spectra of each segmented object, keyed by label.
# Each object is cut out by its bounding box before comparing labels.
im_raw = stack_smooth_zoom
dict_lab_spec = {}
for i, row in seg_zoom_props.iterrows():
    lab = row.label
    box = row.bbox
    if isinstance(box, str):
        # NOTE(review): presumably the props table was round-tripped through text;
        # eval runs the string as code, so only use it on locally generated tables
        box = eval(box)
    rows = slice(box[0], box[2])
    cols = slice(box[1], box[3])
    in_obj = seg[rows, cols] == lab
    dict_lab_spec[lab] = im_raw[rows, cols, :][in_obj]


In [None]:
# Shape of the spectra array from the most recent classification (rows = pixels, columns = channels)
spec_pix_gem.shape

In [None]:
# Colour window of the smoothed composite. This cell needs it, but it was otherwise
# only defined by a later cell, which raised NameError on a fresh kernel.
rgb_smooth_zoom = rgb_smooth_adj[c[0]:c[0]+w[0], c[1]:c[1]+w[1],:]
# The laser-presence filter depends only on the probe design, so build it once
present_filter = get_present_filter(probe_design, len(config['lasers']))

dict_lab_classif = {}
for k, v in dict_lab_spec.items():
    b = seg_zoom_props.loc[seg_zoom_props.label == k, 'bbox'].values[0]
    b = eval(b) if isinstance(b, str) else b
    r_sub = rgb_smooth_zoom[b[0]:b[2],b[1]:b[3],:]
    m_sub = seg[b[0]:b[2],b[1]:b[3]] == k
    # Show the object in colour, then all of its pixel spectra
    ip.plot_image(r_sub*np.dstack([m_sub]*3))
    plt.show()
    plt.close()
    fig, ax = ip.general_plot(dims=dims)
    fsi.plot_cell_spectra(ax, v, {'lw':1,'alpha':0.2,'color':'r'})
    plt.show()
    plt.close()

    # Classify the object from its mean spectrum
    spec_pix_gem = np.mean(v, axis=0)[None,:]
    classif_mat = run_matrix_multiply(spec_pix_gem, weights_sum_norm)
    classif_mat_adj = remove_possibilities_laserpresent_v2(spec_pix_gem,
                                                           present_filter,
                                                           classif_mat)
    classifs = pick_maximum_weight(classif_mat_adj, sci_names)

    print(classifs)
    dict_lab_classif[k] = classifs[0]

Project onto image

In [None]:
# One tab10 colour per distinct classification
classifs_uniq = np.unique(list(dict_lab_classif.values()))
dict_classif_rgb = dict(zip(classifs_uniq, plt.get_cmap('tab10').colors))

# Paint each segmented object with the colour of its classification
im_classif_rgb = np.zeros((seg.shape[0],seg.shape[1], 3))
for i, row in seg_zoom_props.iterrows():
    obj_label = row.label
    box = eval(row.bbox) if isinstance(row.bbox, str) else row.bbox
    window = (slice(box[0], box[2]), slice(box[1], box[3]))
    in_obj = seg[window] == obj_label
    # im_classif_rgb[window] is a view, so assigning through it paints the image in place
    im_classif_rgb[window][in_obj] = np.array(dict_classif_rgb[dict_lab_classif[obj_label]])


In [None]:
rgb_smooth_zoom = rgb_smooth_adj[c[0]:c[0]+w[0], c[1]:c[1]+w[1], :]
ip.plot_image(rgb_smooth_zoom)
# Classification map with each object's label written at its centroid
fig, ax, cbar = ip.plot_image(im_classif_rgb)
for idx, row in seg_zoom_props.iterrows():
    cy, cx = row.centroid[0], row.centroid[1]
    ax.text(cx, cy, row.label, c='w')
ip.taxon_legend(dict_classif_rgb.keys(), dict_classif_rgb.values())

Look at groupings of segmented cells

In [None]:
group_labels = [10,13,14]
# One colour per object in the group; the index wraps so more than 10 objects
# do not raise IndexError
group_colors = plt.get_cmap('tab10').colors

# All pixel spectra of each object
fig, ax = ip.general_plot(dims=dims)
for i, l in enumerate(group_labels):
    col = group_colors[i % len(group_colors)]
    fsi.plot_cell_spectra(ax, dict_lab_spec[l], {'lw':0.5,'alpha':0.1,'color':col})

# Mean spectrum of each object
fig, ax = ip.general_plot(dims=dims)
for i, l in enumerate(group_labels):
    col = group_colors[i % len(group_colors)]
    spec_group = np.mean(dict_lab_spec[l], axis=0)[None,:]
    fsi.plot_cell_spectra(ax, spec_group, {'lw':1,'alpha':1,'color':col})

# Mean spectrum scaled to a peak of 1, to compare spectral shapes
fig, ax = ip.general_plot(dims=dims)
for i, l in enumerate(group_labels):
    col = group_colors[i % len(group_colors)]
    spec_group = np.mean(dict_lab_spec[l], axis=0)[None,:]
    spec_group = spec_group / np.max(spec_group, axis=1)[:,None]
    fsi.plot_cell_spectra(ax, spec_group, {'lw':1,'alpha':1,'color':col})


In [None]:
group_labels = [2,3,5]
# One colour per object in the group; the index wraps so more than 10 objects
# do not raise IndexError
group_colors = plt.get_cmap('tab10').colors

# All pixel spectra of each object
fig, ax = ip.general_plot(dims=dims)
for i, l in enumerate(group_labels):
    col = group_colors[i % len(group_colors)]
    fsi.plot_cell_spectra(ax, dict_lab_spec[l], {'lw':0.5,'alpha':0.1,'color':col})

# Mean spectrum of each object
fig, ax = ip.general_plot(dims=dims)
for i, l in enumerate(group_labels):
    col = group_colors[i % len(group_colors)]
    spec_group = np.mean(dict_lab_spec[l], axis=0)[None,:]
    fsi.plot_cell_spectra(ax, spec_group, {'lw':1,'alpha':1,'color':col})

# Mean spectrum scaled to a peak of 1, to compare spectral shapes
fig, ax = ip.general_plot(dims=dims)
for i, l in enumerate(group_labels):
    col = group_colors[i % len(group_colors)]
    spec_group = np.mean(dict_lab_spec[l], axis=0)[None,:]
    spec_group = spec_group / np.max(spec_group, axis=1)[:,None]
    fsi.plot_cell_spectra(ax, spec_group, {'lw':1,'alpha':1,'color':col})


In [None]:
group_labels = [6,17,21]
# One colour per object in the group; the index wraps so more than 10 objects
# do not raise IndexError
group_colors = plt.get_cmap('tab10').colors

# All pixel spectra of each object
fig, ax = ip.general_plot(dims=dims)
for i, l in enumerate(group_labels):
    col = group_colors[i % len(group_colors)]
    fsi.plot_cell_spectra(ax, dict_lab_spec[l], {'lw':0.5,'alpha':0.1,'color':col})

# Mean spectrum of each object
fig, ax = ip.general_plot(dims=dims)
for i, l in enumerate(group_labels):
    col = group_colors[i % len(group_colors)]
    spec_group = np.mean(dict_lab_spec[l], axis=0)[None,:]
    fsi.plot_cell_spectra(ax, spec_group, {'lw':1,'alpha':1,'color':col})

# Mean spectrum scaled to a peak of 1, to compare spectral shapes
fig, ax = ip.general_plot(dims=dims)
for i, l in enumerate(group_labels):
    col = group_colors[i % len(group_colors)]
    spec_group = np.mean(dict_lab_spec[l], axis=0)[None,:]
    spec_group = spec_group / np.max(spec_group, axis=1)[:,None]
    fsi.plot_cell_spectra(ax, spec_group, {'lw':1,'alpha':1,'color':col})


In [None]:
group_labels = [1,8,9]
# One colour per object in the group; the index wraps so more than 10 objects
# do not raise IndexError
group_colors = plt.get_cmap('tab10').colors

# All pixel spectra of each object
fig, ax = ip.general_plot(dims=dims)
for i, l in enumerate(group_labels):
    col = group_colors[i % len(group_colors)]
    fsi.plot_cell_spectra(ax, dict_lab_spec[l], {'lw':0.5,'alpha':0.1,'color':col})

# Mean spectrum of each object
fig, ax = ip.general_plot(dims=dims)
for i, l in enumerate(group_labels):
    col = group_colors[i % len(group_colors)]
    spec_group = np.mean(dict_lab_spec[l], axis=0)[None,:]
    fsi.plot_cell_spectra(ax, spec_group, {'lw':1,'alpha':1,'color':col})

# Mean spectrum scaled to a peak of 1, to compare spectral shapes
fig, ax = ip.general_plot(dims=dims)
for i, l in enumerate(group_labels):
    col = group_colors[i % len(group_colors)]
    spec_group = np.mean(dict_lab_spec[l], axis=0)[None,:]
    spec_group = spec_group / np.max(spec_group, axis=1)[:,None]
    fsi.plot_cell_spectra(ax, spec_group, {'lw':1,'alpha':1,'color':col})


In [None]:
group_labels = [19,15,7,11]
# One colour per object in the group; the index wraps so more than 10 objects
# do not raise IndexError
group_colors = plt.get_cmap('tab10').colors

# All pixel spectra of each object
fig, ax = ip.general_plot(dims=dims)
for i, l in enumerate(group_labels):
    col = group_colors[i % len(group_colors)]
    fsi.plot_cell_spectra(ax, dict_lab_spec[l], {'lw':0.5,'alpha':0.1,'color':col})

# Mean spectrum of each object
fig, ax = ip.general_plot(dims=dims)
for i, l in enumerate(group_labels):
    col = group_colors[i % len(group_colors)]
    spec_group = np.mean(dict_lab_spec[l], axis=0)[None,:]
    fsi.plot_cell_spectra(ax, spec_group, {'lw':1,'alpha':1,'color':col})

# Mean spectrum scaled to a peak of 1, to compare spectral shapes
fig, ax = ip.general_plot(dims=dims)
for i, l in enumerate(group_labels):
    col = group_colors[i % len(group_colors)]
    spec_group = np.mean(dict_lab_spec[l], axis=0)[None,:]
    spec_group = spec_group / np.max(spec_group, axis=1)[:,None]
    fsi.plot_cell_spectra(ax, spec_group, {'lw':1,'alpha':1,'color':col})


Do distance-based clustering of the segmented objects' mean spectra

In [None]:
labels = np.unique(list(dict_lab_spec.keys()))

# Mean spectrum of each object, normalised to sum to 1
dict_lab_mean = {}
for l in labels:
    spec_mean = np.mean(dict_lab_spec[l], axis=0)
    dict_lab_mean[l] = spec_mean / np.sum(spec_mean)

# Full square matrix of pairwise distances between the normalised mean spectra
n_lab = len(labels)
dist_mat = np.empty((n_lab, n_lab))
for i, j in np.ndindex(n_lab, n_lab):
    dist_mat[i, j] = fhc.euclid_dist_cumul_spec(dict_lab_mean[labels[i]],
                                                dict_lab_mean[labels[j]])
    

In [None]:
from scipy.cluster import hierarchy
from scipy.spatial.distance import squareform

# linkage() expects a condensed distance vector. A square matrix is treated as n
# observation vectors and re-distanced, which is not what dist_mat holds.
# squareform takes the upper triangle; checks=False skips the symmetry and
# zero-diagonal validation. This assumes euclid_dist_cumul_spec is symmetric —
# TODO confirm. The fcluster distance threshold used later may need re-tuning
# on this dendrogram.
linkage = hierarchy.linkage(squareform(dist_mat, checks=False), method='average')
fig, ax = ip.general_plot(dims=(5,5))
dn = hierarchy.dendrogram(linkage, labels=labels)

In [None]:
# Cut the tree at distance 0.06 and list the object labels in each cluster
clust = hierarchy.fcluster(linkage, t=0.06, criterion='distance')
dict_clust_lab = defaultdict(list)
for cluster_id, lab in zip(clust, labels):
    dict_clust_lab[cluster_id].append(lab)

dict_clust_lab

In [None]:
# One colour per object within a cluster; the index wraps so clusters with more
# than 20 objects do not raise IndexError
clust_colors = plt.get_cmap('tab20').colors

for k, v in dict_clust_lab.items():
    print(k)

    # All pixel spectra of every object in cluster k
    fig, ax = ip.general_plot(dims=dims)
    for i, l in enumerate(v):
        col = clust_colors[i % len(clust_colors)]
        fsi.plot_cell_spectra(ax, dict_lab_spec[l], {'lw':0.5,'alpha':0.1,'color':col})
    plt.show()
    plt.close()

    # Mean spectrum of each object
    fig, ax = ip.general_plot(dims=dims)
    for i, l in enumerate(v):
        col = clust_colors[i % len(clust_colors)]
        spec_group = np.mean(dict_lab_spec[l], axis=0)[None,:]
        fsi.plot_cell_spectra(ax, spec_group, {'lw':1,'alpha':1,'color':col})
    plt.show()
    plt.close()

    # Mean spectrum scaled to a peak of 1
    fig, ax = ip.general_plot(dims=dims)
    for i, l in enumerate(v):
        col = clust_colors[i % len(clust_colors)]
        spec_group = np.mean(dict_lab_spec[l], axis=0)[None,:]
        spec_group = spec_group / np.max(spec_group, axis=1)[:,None]
        fsi.plot_cell_spectra(ax, spec_group, {'lw':1,'alpha':1,'color':col})
    plt.show()
    plt.close()


In [None]:
dict_lab_clust = {l:cl for l,cl in zip(labels,clust)}
clust_uniq = np.unique(list(dict_lab_clust.values()))
# One tab10 colour per cluster, cycling after 10. A plain zip would drop the extra
# clusters and the lookup below would raise KeyError.
tab10 = plt.get_cmap('tab10').colors
dict_clust_rgb = {cl: tab10[n % len(tab10)] for n, cl in enumerate(clust_uniq)}

# Paint each segmented object with its cluster colour. The writes go into
# im_clust_rgb itself. Slicing im_classif_rgb instead overwrote the classification
# image and copied its colours into the cluster image.
im_clust_rgb = np.zeros((seg.shape[0],seg.shape[1], 3))
for i, row in seg_zoom_props.iterrows():
    b = row.bbox
    l = row.label
    b = eval(b) if isinstance(b, str) else b
    c_sub = im_clust_rgb[b[0]:b[2],b[1]:b[3]]  # a view into im_clust_rgb
    m_sub = seg[b[0]:b[2],b[1]:b[3]] == l
    cl = dict_lab_clust[l]
    color = np.array(dict_clust_rgb[cl])
    c_sub[m_sub] = color

In [None]:
rgb_smooth_zoom = rgb_smooth_adj[c[0]:c[0]+w[0], c[1]:c[1]+w[1],:]
# Colour window with each object's cluster id written at its centroid
fig, ax, cbar = ip.plot_image(rgb_smooth_zoom)
for i, row in seg_zoom_props.iterrows():
    ax.text(row.centroid[1], row.centroid[0], dict_lab_clust[row.label], c='w')

# Cluster map
ip.plot_image(im_clust_rgb)

In [None]:
# Average of the normalised mean spectra of the objects in each cluster, one line per cluster
fig, ax = ip.general_plot(dims=dims)
for cl in clust_uniq:
    # Named clust_labels so the notebook-level `labels` (all segmented objects,
    # used by the clustering cells) is not overwritten
    clust_labels = dict_clust_lab[cl]
    spec_cl = np.vstack([dict_lab_mean[l] for l in clust_labels])
    spec_pix_gem = np.mean(spec_cl, axis=0)[None,:]
    fsi.plot_cell_spectra(ax, spec_pix_gem, {'lw':1,'alpha':1,'color':dict_clust_rgb[cl]})

In [None]:
dict_sciname_spec = dict(zip(sci_names, ref_spec))
# The laser-presence filter depends only on the probe design, so build it once
present_filter = get_present_filter(probe_design, len(config['lasers']))

for cl in clust_uniq:
    # clust_labels (not `labels`) so the notebook-level list of all objects is kept
    clust_labels = dict_clust_lab[cl]
    spec_cl = np.vstack([dict_lab_mean[l] for l in clust_labels])

    # Classify the cluster from the average of its objects' normalised spectra
    spec_pix_gem = np.mean(spec_cl, axis=0)[None,:]
    classif_mat = run_matrix_multiply(spec_pix_gem, weights_sum_norm)
    classif_mat_adj = remove_possibilities_laserpresent_v2(spec_pix_gem,
                                                        present_filter,
                                                        classif_mat)
    classifs = pick_maximum_weight(classif_mat_adj, sci_names)

    print(cl)
    print(classifs[0])

    # Member spectra (black) against the sum-normalised mean reference (red)
    fig, ax = ip.general_plot(dims=dims)
    fsi.plot_cell_spectra(ax, spec_cl, {'lw':1,'alpha':1,'color':'k'})
    r_spec = np.mean(dict_sciname_spec[classifs[0]], axis=0)[None,:]
    r_spec = r_spec / np.sum(r_spec)
    fsi.plot_cell_spectra(ax, r_spec, {'lw':1,'alpha':1,'color':'r'})
    plt.show()
    plt.close()
    

Look at the green channel variance

In [None]:
# Background cut-off on the per-pixel channel sum of the first laser image
# (the "green" image per the section title)
bg_thresh = 13000

gr_raw = raws_smooth[0]
gr_sums = gr_raw.sum(axis=2)
fig, ax = ip.general_plot(dims=(10,5))
# Sorted per-pixel sums with the cut-off drawn as a horizontal line
ax.plot(np.sort(gr_sums.ravel()))
xlims = ax.get_xlim()
ax.plot(xlims, [bg_thresh, bg_thresh])

In [None]:
# Spectra of background pixels (channel sum below the cut-off) and a random subset of them
gr_spec_bg = gr_raw[gr_sums < bg_thresh]
# Seeded generator so the subset and the figure are reproducible. Never ask for
# more pixels than exist, because replace=False would raise ValueError.
rng = np.random.default_rng(0)
n_sub = min(100, gr_spec_bg.shape[0])
ind_sub = rng.choice(gr_spec_bg.shape[0], size=n_sub, replace=False)
gr_spec_bg_sub = gr_spec_bg[ind_sub,:]
fig, ax = ip.general_plot(dims=dims)
fsi.plot_cell_spectra(ax, gr_spec_bg_sub, {'lw':0.5,'alpha':0.2,'color':'r'})


In [None]:
dims=(10,5)

stack_smooth = np.dstack(raws_smooth)
stack_smooth_zoom = stack_smooth[c[0]:c[0]+w[0], c[1]:c[1]+w[1],:]
# Only pixels inside this window's background mask
spec_zoom = stack_smooth_zoom[mask_zoom]
fig, ax = ip.general_plot(dims=dims)
has_signal = spec_zoom.max(axis=1) > 0
spec_zoom_filt = spec_zoom[has_signal]
fsi.plot_cell_spectra(ax, spec_zoom_filt, {'lw':0.2,'alpha':0.1,'color':'r'})

# Per-pixel classification restricted to the masked pixels
spec_pix_gem = spec_zoom
mask_ = mask_zoom
raw_smooth = [im[c[0]:c[0]+w[0], c[1]:c[1]+w[1], :] for im in raws_smooth]
present_filter = get_present_filter(probe_design, len(config['lasers']))
classif_mat = run_matrix_multiply(spec_pix_gem, weights_sum_norm)
classif_mat_adj = remove_possibilities_laserpresent(raw_smooth, mask_, classif_mat,
                                                    present_filter)
classifs = pick_maximum_weight(classif_mat_adj, sci_names)
classifs = filter_dim_spectra(classifs, spec_pix_gem, thresh=config['dim_spec_filt'])

print(np.unique(classifs))

# Paint the classification map, one tab20 colour per class
pix_inds = np.where(mask_ > 0)
plot_intensities = np.ones_like(mask_)
col_dict = dict(zip(np.unique(classifs), plt.get_cmap('tab20').colors))
im_classif = classif_to_image(classifs, pix_inds, (w[0],w[1]), plot_intensities, col_dict)

# Masked colour window next to the classification map
ip.plot_image(rgb_zoom*np.dstack([mask_zoom]*3))
ip.plot_image(im_classif)
ip.taxon_legend(taxon_names=col_dict.keys(),taxon_colors=col_dict.values())

## Plot references

In [None]:


# One figure per reference barcode showing all of its reference spectra
for ref, name in zip(ref_spec, sci_names):
    fig, ax = ip.general_plot(dims=dims)
    fsi.plot_cell_spectra(ax, ref, {'lw':0.2,'alpha':0.1,'color':'r'})
    ax.set_title(name)



In [None]:
# Intentionally empty: section break before the whole-z-stack analysis

# Run segmentation and spectrum extraction on the whole z-stack

In [None]:
# Number of z-slices to process (stack indices m = 0 .. M-1)
M = 26

im_inches=10
gauss=2

# Registered-image directory; it does not depend on the z-slice. str.replace swaps
# dots for underscores without the invalid '\.' escape of the regex version.
shift_fmt = config['output_dir'] + '/' + config['shift_dir']
af_str = str(af).replace('.', '_')
shift_dir = shift_fmt.format(sample_name=sn, maskthresh=mt, allfluor=af_str)


def get_preprocessed(m_raws_shift, gauss):
    """Run sf.pre_process on every channel of every image.

    Uses its own name so the Gaussian ``get_smooth`` defined earlier in the
    notebook is not replaced when this cell runs.
    """
    raws_pre = []
    for im in m_raws_shift:
        im_pre = np.empty(im.shape)
        for ch in range(im.shape[2]):
            im_pre[:,:,ch] = sf.pre_process(im[:,:,ch], gauss=gauss)
        raws_pre.append(im_pre)
    return raws_pre


dict_m_seg_spectra = {}
# iterate through z
for m in range(M):
    # Load the registered images of this z-slice
    raw_fns = sorted(glob.glob(shift_dir + '/' + sn + '_M_' + str(m) + '_*'))
    raws = [np.load(f) for f in raw_fns]
    # Raw stack of *this* slice. The notebook-level `stack` holds the slice loaded
    # at the top of the notebook, so using it here would plot and measure the
    # wrong slice.
    stack_m = np.dstack(raws)

    # pre process
    raws_smooth = get_preprocessed(raws, gauss)
    stack_pre = np.dstack(raws_smooth)
    stack_pre_sum = np.sum(stack_pre, axis=2)
    stack_pre_sum_zoom = stack_pre_sum
    ip.plot_image(np.sum(stack_m, axis=2),cmap='inferno', im_inches=im_inches)
    plt.show()
    plt.close()
    ip.plot_image(stack_pre_sum_zoom, cmap='inferno', im_inches=im_inches)
    plt.show()
    plt.close()

    # Plot RGB; normalisation bounds are shared across channels and computed once per slice
    clips = [(0.075,0.75),(0.075,0.3),(0.05,0.4)]
    rgb = [np.max(r, axis=2) for r in raws_smooth]
    mx = np.max(rgb)
    mn = np.min(rgb)
    rgb_smooth_adj_lst = []
    for r, clip in zip(rgb, clips):
        r_norm = (r - mn) / (mx - mn)
        r_adj = np.clip(r_norm, clip[0], clip[1])
        r_adj = (r_adj - clip[0]) / (clip[1] - clip[0])
        rgb_smooth_adj_lst.append(r_adj)
    rgb_smooth_adj = np.dstack(rgb_smooth_adj_lst)
    for r in rgb_smooth_adj_lst:
        ip.plot_image(r, cmap='inferno', im_inches=im_inches)
    ip.plot_image(rgb_smooth_adj, im_inches=im_inches)
    plt.show()
    plt.close()

    # mask
    mask_zoom = sf.get_background_mask(stack_pre_sum_zoom)
    ip.plot_image(stack_pre_sum_zoom*mask_zoom, cmap='inferno', im_inches=im_inches)
    plt.show()
    plt.close()

    # segment
    seg_zoom = sf.segment(stack_pre_sum_zoom, background_mask=mask_zoom)
    seg_zoom_rgb = ip.seg2rgb(seg_zoom)
    ip.plot_image(seg_zoom_rgb, im_inches=im_inches)
    plt.show()
    plt.close()

    # Get spectra: per-object pixel spectra from the pre-processed stack of this slice
    stack_sum_zoom = np.sum(stack_m, axis=2)
    seg_zoom_props = sf.measure_regionprops(seg_zoom, raw=stack_sum_zoom)
    seg = seg_zoom
    dict_lab_spec = {}
    im_raw = stack_pre
    for i, row in seg_zoom_props.iterrows():
        b = row.bbox
        l = row.label
        b = eval(b) if isinstance(b, str) else b
        r_sub = im_raw[b[0]:b[2],b[1]:b[3],:]
        m_sub = seg[b[0]:b[2],b[1]:b[3]] == l
        dict_lab_spec[l] = r_sub[m_sub]
    dict_m_seg_spectra[m] = dict_lab_spec

In [None]:
# distance matrix 
dict_index_specnorm = {}
dict_index_ml = {}
i = 0
for m, dict_lab_spec in tqdm(dict_m_seg_spectra.items()):
    # labels = np.unique(list(dict_lab_spec.keys()))
    for l, spec in dict_lab_spec.items():
        # spec = dict_lab_spec[l]
        spec_mean = np.mean(spec, axis=0)
        spec_norm = spec_mean / np.sum(spec_mean)
        dict_index_specnorm[i] = spec_norm
        dict_index_ml[i] = [m,l]  # Associate the distance matrix index with a z image and seg object 
        i += 1

    

# dist_mat = np.empty((len(dict_index_specnorm),len(dict_index_specnorm)))
# for i, s0 in tqdm(dict_index_specnorm.items()):
#     for j, s1 in dict_index_specnorm.items():
#         if j > i:
#             dist_mat[i,j] = fhc.euclid_dist_cumul_spec(s0,s1)

In [None]:
# Pairwise spectral distances between all objects. Keys of dict_index_specnorm are
# 0..n-1, so they double as matrix indices. Build a symmetric, zero-diagonal matrix:
# np.empty would leave the diagonal and lower triangle as uninitialised memory, which
# corrupts linkage/UMAP. Assumes euclid_dist_cumul_spec is symmetric; only j > i is evaluated.
n_obj = len(dict_index_specnorm)
dist_mat = np.zeros((n_obj, n_obj))
for i in tqdm(range(n_obj)):
    s0 = dict_index_specnorm[i]
    for j in range(i + 1, n_obj):
        d = fhc.euclid_dist_cumul_spec(s0, dict_index_specnorm[j])
        dist_mat[i, j] = d
        dist_mat[j, i] = d

In [None]:
from scipy.cluster import hierarchy
from scipy.spatial.distance import squareform

# hierarchy.linkage treats a 2-D array as raw observation vectors, NOT as a distance
# matrix. Pass the condensed (strictly-upper-triangle) form so the precomputed spectral
# distances are clustered directly. checks=False: only entries with j > i are read.
# Cut heights (`t`) used below are picked from the dendrogram; re-check them whenever
# the linkage input changes.
linkage = hierarchy.linkage(squareform(dist_mat, checks=False), method='average')

In [None]:
# Dendrogram of the full tree, used to choose the cut height `t` in the next cell.
fig, ax = ip.general_plot(dims=(10, 5))
dn = hierarchy.dendrogram(linkage)

In [None]:
# Flat clusters: cut the tree at height t. clust[k] is the cluster id of object k,
# because the linkage rows follow the insertion order of dict_index_specnorm.
t = 8
indices = list(dict_index_specnorm)
clust = hierarchy.fcluster(linkage, t=t, criterion='distance')

# {cluster id: [object indices]}
dict_clust_lab = defaultdict(list)
for cl, l in zip(clust, indices):
    dict_clust_lab[cl] += [l]

len(dict_clust_lab)

In [None]:
# print(dict_clust_lab.keys())
# [len(v) for k,v in dict_clust_lab.items()]
# for i, (j,k) in enumerate(np.array([[1,2,3],[1,2,3],[1,2,3]])[:,:2]):
#     print(i,j,k)
# indices = dict_clust_lab[2]
# len(indices)
# linkage_02[:,2] < t

In [None]:
# cl_isolate = 2

# indices = dict_clust_lab[2]
# n = len(dict_index_specnorm)
# linkage_02 = []
# for i, (l0,l1) in enumerate(linkage[:,:2]):
#     if (l0 in indices) or (l1 in indices):
#         linkage_02.append(linkage[i,:])
#         indices.append(i + n)

# linkage_02 = np.vstack(linkage_02)
# linkage_02_trim = linkage_02[linkage_02[:,2] < t]
# print(linkage.shape)
# linkage_02_trim.shape

In [None]:
# Disabled: this was meant to plot a dendrogram of `linkage_02_trim`, which is only
# built by the commented-out cell above. Creating the figure alone just rendered empty axes.
# fig, ax = ip.general_plot(dims=(10, 5))
# dn = hierarchy.dendrogram(linkage_02_trim)

In [None]:
# Overlay the mean (sum-normalised) spectrum of every top-level cluster.
# tab20 has 20 colours; cluster k gets colour k % 20, so long cluster lists wrap around.
colors = list(plt.get_cmap('tab20').colors)
clusters = list(dict_clust_lab.keys())
dict_clust_rgb = {cl: colors[k % len(colors)] for k, cl in enumerate(clusters)}

# assumes `dims` was set in an earlier cell of the notebook — TODO confirm
fig, ax = ip.general_plot(dims=dims)
for cl, labels in dict_clust_lab.items():
    spec_cl = np.vstack([dict_index_specnorm[l] for l in labels])
    # Mean spectrum as a (1, n_channels) row, renormalised to sum to 1.
    spec_pix_gem = np.mean(spec_cl, axis=0)[None, :]
    spec_pix_gem /= np.sum(spec_pix_gem)
    fsi.plot_cell_spectra(ax, spec_pix_gem, {'lw': 1, 'alpha': 1, 'color': dict_clust_rgb[cl]})

In [None]:
# Merges in the full tree whose height (column 2) exceeds 6; helps choose a cut height.
linkage[linkage[:, 2] > 6]

In [None]:
# Distance matrix restricted to the members of top-level cluster `cl_02`.
# NOTE(review): fcluster ids are arbitrary; re-check `cl_02` whenever `t` or the linkage changes.
cl_02 = 2
indices_02 = [int(i) for i in dict_clust_lab[cl_02]]

# Symmetric, zero-diagonal matrix (np.empty would leave the diagonal and lower
# triangle uninitialised). Row/column k corresponds to object indices_02[k].
n_02 = len(indices_02)
dist_mat_02 = np.zeros((n_02, n_02))
for i in range(n_02):
    s0 = dict_index_specnorm[indices_02[i]]
    for j in range(i + 1, n_02):
        d = fhc.euclid_dist_cumul_spec(s0, dict_index_specnorm[indices_02[j]])
        dist_mat_02[i, j] = d
        dist_mat_02[j, i] = d

In [None]:
from scipy.cluster import hierarchy
from scipy.spatial.distance import squareform

# Condensed (strictly-upper-triangle) form: a square array passed to linkage would be
# clustered as observation vectors instead of as precomputed distances.
linkage_02 = hierarchy.linkage(squareform(dist_mat_02, checks=False), method='average')

In [None]:
# Dendrogram of the cluster-`cl_02` sub-tree, used to pick its cut height below.
fig, ax = ip.general_plot(dims=(10, 5))
dn = hierarchy.dendrogram(linkage_02)

In [None]:
# Sub-clusters of cluster `cl_02`, cut at height t on linkage_02.
t = 0.16
clust02 = hierarchy.fcluster(linkage_02, t=t, criterion='distance')

# fcluster output follows the row order of dist_mat_02, i.e. the order of indices_02,
# so zipping maps each sub-cluster id back to a global object index.
dict_clust02_lab = defaultdict(list)
for cl, l in zip(clust02, indices_02):
    dict_clust02_lab[cl] += [l]

len(dict_clust02_lab)

In [None]:
# Overlay the mean (sum-normalised) spectrum of every sub-cluster of cluster `cl_02`.
# tab20 has 20 colours; cluster k gets colour k % 20, so long cluster lists wrap around.
colors = list(plt.get_cmap('tab20').colors)
clusters = list(dict_clust02_lab.keys())
dict_clust_rgb = {cl: colors[k % len(colors)] for k, cl in enumerate(clusters)}

# assumes `dims` was set in an earlier cell of the notebook — TODO confirm
fig, ax = ip.general_plot(dims=dims)
for cl, labels in dict_clust02_lab.items():
    spec_cl = np.vstack([dict_index_specnorm[l] for l in labels])
    # Mean spectrum as a (1, n_channels) row, renormalised to sum to 1.
    spec_pix_gem = np.mean(spec_cl, axis=0)[None, :]
    spec_pix_gem /= np.sum(spec_pix_gem)
    fsi.plot_cell_spectra(ax, spec_pix_gem, {'lw': 1, 'alpha': 1, 'color': dict_clust_rgb[cl]})

Subcluster the large cluster

In [None]:
# Distance matrix restricted to the members of top-level cluster `cl_03`.
# NOTE(review): fcluster ids are arbitrary; re-check `cl_03` whenever `t` or the linkage changes.
cl_03 = 1
indices_03 = [int(i) for i in dict_clust_lab[cl_03]]

# Symmetric, zero-diagonal matrix (np.empty would leave the diagonal and lower
# triangle uninitialised). Row/column k corresponds to object indices_03[k].
n_03 = len(indices_03)
dist_mat_03 = np.zeros((n_03, n_03))
for i in tqdm(range(n_03)):
    s0 = dict_index_specnorm[indices_03[i]]
    for j in range(i + 1, n_03):
        d = fhc.euclid_dist_cumul_spec(s0, dict_index_specnorm[indices_03[j]])
        dist_mat_03[i, j] = d
        dist_mat_03[j, i] = d

In [None]:
from scipy.cluster import hierarchy
from scipy.spatial.distance import squareform

# Build linkage matrix from the condensed (strictly-upper-triangle) distances; a square
# array would be clustered as observation vectors instead of as precomputed distances.
linkage_03 = hierarchy.linkage(squareform(dist_mat_03, checks=False), method='average')

In [None]:
# Show the dendrogram of the cluster-`cl_03` sub-tree to pick its cut height
fig, ax = ip.general_plot(dims=(10, 5))
dn = hierarchy.dendrogram(linkage_03)

In [None]:
# Pick clustering level: cut linkage_03 at height t. Output order follows indices_03.
t = 6.5
clust03 = hierarchy.fcluster(linkage_03, t=t, criterion='distance')

dict_clust03_lab = defaultdict(list)
for cl, l in zip(clust03, indices_03):
    dict_clust03_lab[cl] += [l]

len(dict_clust03_lab)

In [None]:
# Overlay the mean (sum-normalised) spectrum of every sub-cluster of cluster `cl_03`.
# tab20 has 20 colours; cluster k gets colour k % 20, so long cluster lists wrap around.
colors = list(plt.get_cmap('tab20').colors)
clusters = list(dict_clust03_lab.keys())
dict_clust_rgb = {cl: colors[k % len(colors)] for k, cl in enumerate(clusters)}

# assumes `dims` was set in an earlier cell of the notebook — TODO confirm
fig, ax = ip.general_plot(dims=dims)
for cl, labels in dict_clust03_lab.items():
    spec_cl = np.vstack([dict_index_specnorm[l] for l in labels])
    # Mean spectrum as a (1, n_channels) row, renormalised to sum to 1.
    spec_pix_gem = np.mean(spec_cl, axis=0)[None, :]
    spec_pix_gem /= np.sum(spec_pix_gem)
    fsi.plot_cell_spectra(ax, spec_pix_gem, {'lw': 1, 'alpha': 1, 'color': dict_clust_rgb[cl]})

Alternative subsetting: re-cut the full tree at a lower height and keep only the members of the parent cluster

In [None]:
# Size of each top-level cluster, then the distinct cluster ids.
print(list(map(len, dict_clust_lab.values())))
print(np.unique(clust))

In [None]:
# Re-cut the FULL tree at a lower height t, then keep only objects that were in
# top-level cluster 1 (bool_c0 masks the object order shared by `clust` and `indices`).
t = 6.5
bool_c0 = clust == 1

clust_c0_all = hierarchy.fcluster(linkage, t=t, criterion='distance')
clust_c0 = clust_c0_all[bool_c0]
indices_c0 = np.array(indices)[bool_c0]

dict_clust0_lab = defaultdict(list)
for cl, l in zip(clust_c0, indices_c0):
    dict_clust0_lab[cl] += [l]

len(dict_clust0_lab)

In [None]:
dcl = dict_clust0_lab
clust_ = clust_c0

# Overlay the mean (sum-normalised) spectrum of every cluster in `dcl`.
# tab20 has 20 colours; cluster k gets colour k % 20, so long cluster lists wrap around.
colors = list(plt.get_cmap('tab20').colors)
clusters = list(dcl.keys())
dict_clust_rgb = {cl: colors[k % len(colors)] for k, cl in enumerate(clusters)}

# assumes `dims` was set in an earlier cell of the notebook — TODO confirm
fig, ax = ip.general_plot(dims=dims)
for cl, labels in dcl.items():
    spec_cl = np.vstack([dict_index_specnorm[l] for l in labels])
    # Mean spectrum as a (1, n_channels) row, renormalised to sum to 1.
    spec_pix_gem = np.mean(spec_cl, axis=0)[None, :]
    spec_pix_gem /= np.sum(spec_pix_gem)
    fsi.plot_cell_spectra(ax, spec_pix_gem, {'lw': 1, 'alpha': 1, 'color': dict_clust_rgb[cl]})

In [None]:
# (cluster ids, object counts) for the cut stored in clust_c0_all
ids_c0, counts_c0 = np.unique(clust_c0_all, return_counts=True)
ids_c0, counts_c0

In [None]:
# Re-cut the FULL tree again (lower t) and keep only objects with id 1 in the previous cut.
t = 5.5
bool_c1 = clust_c0_all == 1

clust_c1_all = hierarchy.fcluster(linkage, t=t, criterion='distance')
clust_c1 = clust_c1_all[bool_c1]
indices_c1 = np.array(indices)[bool_c1]

dict_clust1_lab = defaultdict(list)
for cl, l in zip(clust_c1, indices_c1):
    dict_clust1_lab[cl] += [l]

len(dict_clust1_lab)

In [None]:
dcl = dict_clust1_lab
clust_ = clust_c1

# Overlay the mean (sum-normalised) spectrum of every cluster in `dcl`.
# tab20 has 20 colours; cluster k gets colour k % 20, so long cluster lists wrap around.
colors = list(plt.get_cmap('tab20').colors)
clusters = list(dcl.keys())
dict_clust_rgb = {cl: colors[k % len(colors)] for k, cl in enumerate(clusters)}

# assumes `dims` was set in an earlier cell of the notebook — TODO confirm
fig, ax = ip.general_plot(dims=dims)
for cl, labels in dcl.items():
    spec_cl = np.vstack([dict_index_specnorm[l] for l in labels])
    # Mean spectrum as a (1, n_channels) row, renormalised to sum to 1.
    spec_pix_gem = np.mean(spec_cl, axis=0)[None, :]
    spec_pix_gem /= np.sum(spec_pix_gem)
    fsi.plot_cell_spectra(ax, spec_pix_gem, {'lw': 1, 'alpha': 1, 'color': dict_clust_rgb[cl]})

In [None]:
# (cluster ids, object counts) for the cut stored in clust_c1_all
ids_c1, counts_c1 = np.unique(clust_c1_all, return_counts=True)
ids_c1, counts_c1

In [None]:
# Re-cut the FULL tree again (lower t) and keep only objects with id 1 in the previous cut.
bool_c2 = clust_c1_all == 1
t = 4.5

indices_c2 = np.array(indices)[bool_c2]
clust_c2_all = hierarchy.fcluster(linkage, t=t, criterion='distance')
clust_c2 = clust_c2_all[bool_c2]
dict_clust2_lab = defaultdict(list)
for cl, l in zip(clust_c2, indices_c2):
    dict_clust2_lab[cl].append(l)

# Only a cell's last expression is displayed, so show both numbers together:
# (number of sub-clusters, number of objects kept from the parent cluster)
len(dict_clust2_lab), len(clust_c2)

In [None]:
dcl = dict_clust2_lab
clust_ = clust_c2

# Overlay the mean (sum-normalised) spectrum of every cluster in `dcl`.
# tab20 has 20 colours; cluster k gets colour k % 20, so long cluster lists wrap around.
colors = list(plt.get_cmap('tab20').colors)
clusters = list(dcl.keys())
dict_clust_rgb = {cl: colors[k % len(colors)] for k, cl in enumerate(clusters)}

# assumes `dims` was set in an earlier cell of the notebook — TODO confirm
fig, ax = ip.general_plot(dims=dims)
for cl, labels in dcl.items():
    spec_cl = np.vstack([dict_index_specnorm[l] for l in labels])
    # Mean spectrum as a (1, n_channels) row, renormalised to sum to 1.
    spec_pix_gem = np.mean(spec_cl, axis=0)[None, :]
    spec_pix_gem /= np.sum(spec_pix_gem)
    fsi.plot_cell_spectra(ax, spec_pix_gem, {'lw': 1, 'alpha': 1, 'color': dict_clust_rgb[cl]})

Try UMAP visualization of the precomputed spectral distance matrices

In [None]:
import umap

# UMAP(metric='precomputed') expects a full symmetric distance matrix. Rebuild it from
# the strictly-upper triangle (the distance cells only evaluate pairs with j > i), so
# the embedding does not depend on what the diagonal / lower triangle contain.
dist_upper = np.triu(dist_mat, k=1)
dist_sym = dist_upper + dist_upper.T
# Fixed seed so the embedding is reproducible on re-run.
fit = umap.UMAP(metric='precomputed', random_state=42)
u = fit.fit_transform(dist_sym)

In [None]:
# One point per segmented object; low alpha shows where objects are dense.
fig, ax = plt.subplots()
ax.scatter(u[:, 0], u[:, 1], c='k', alpha=0.1)
ax.set(xlabel='UMAP 1', ylabel='UMAP 2', title='Sum-normalised spectra (euclid_dist_cumul_spec)');

In [None]:
# Same object flattening as for dict_index_specnorm, but each mean spectrum is scaled
# so its largest channel equals 1. dict_index_ml is rebuilt with identical contents.
dict_index_maxnorm = {}
dict_index_ml = {}
i = 0
for m, dict_lab_spec in tqdm(dict_m_seg_spectra.items()):
    for l, spec in dict_lab_spec.items():
        spec_mean = spec.mean(axis=0)
        spec_norm = spec_mean / spec_mean.max()
        dict_index_maxnorm[i], dict_index_ml[i] = spec_norm, [m, l]
        i += 1

In [None]:
# Pairwise distances between max-normalised spectra. NOTE: this overwrites `dist_mat`
# from the sum-normalised analysis above.
# Symmetric, zero-diagonal matrix (np.empty would leave the diagonal and lower triangle
# uninitialised). Assumes euclid_dist_cumul_spec is symmetric; only j > i is evaluated.
n_obj = len(dict_index_maxnorm)
dist_mat = np.zeros((n_obj, n_obj))
for i in tqdm(range(n_obj)):
    s0 = dict_index_maxnorm[i]
    for j in range(i + 1, n_obj):
        d = fhc.euclid_dist_cumul_spec(s0, dict_index_maxnorm[j])
        dist_mat[i, j] = d
        dist_mat[j, i] = d

In [None]:
# UMAP(metric='precomputed') expects a full symmetric distance matrix; rebuild it from
# the strictly-upper triangle, which is the part the distance cell evaluates.
dist_upper = np.triu(dist_mat, k=1)
dist_sym = dist_upper + dist_upper.T
# Fixed seed so the embedding is reproducible on re-run.
fit = umap.UMAP(metric='precomputed', random_state=42)
u = fit.fit_transform(dist_sym)

In [None]:
# One point per segmented object; low alpha shows where objects are dense.
fig, ax = plt.subplots()
ax.scatter(u[:, 0], u[:, 1], c='k', alpha=0.1)
ax.set(xlabel='UMAP 1', ylabel='UMAP 2', title='Max-normalised spectra (euclid_dist_cumul_spec)');

In [None]:
# Same object flattening as above, but keep each object's raw mean spectrum (no
# normalisation). dict_index_ml is rebuilt with identical contents.
dict_index_spec = {}
dict_index_ml = {}
i = 0
for m, dict_lab_spec in tqdm(dict_m_seg_spectra.items()):
    for l, spec in dict_lab_spec.items():
        spec_mean = spec.mean(axis=0)
        spec_norm = spec_mean
        dict_index_spec[i], dict_index_ml[i] = spec_norm, [m, l]
        i += 1

In [None]:
# Pairwise distances between un-normalised mean spectra.
# Symmetric, zero-diagonal matrix (np.empty would leave the diagonal and lower triangle
# uninitialised). Assumes euclid_dist_cumul_spec is symmetric; only j > i is evaluated.
n_obj = len(dict_index_spec)
dist_mat_spec = np.zeros((n_obj, n_obj))
for i in tqdm(range(n_obj)):
    s0 = dict_index_spec[i]
    for j in range(i + 1, n_obj):
        d = fhc.euclid_dist_cumul_spec(s0, dict_index_spec[j])
        dist_mat_spec[i, j] = d
        dist_mat_spec[j, i] = d

In [None]:
# UMAP(metric='precomputed') expects a full symmetric distance matrix; rebuild it from
# the strictly-upper triangle, which is the part the distance cell evaluates.
dist_upper = np.triu(dist_mat_spec, k=1)
dist_sym = dist_upper + dist_upper.T
# Fixed seed so the embedding is reproducible on re-run.
fit = umap.UMAP(metric='precomputed', random_state=42)
u = fit.fit_transform(dist_sym)

In [None]:
# One point per segmented object; low alpha shows where objects are dense.
fig, ax = plt.subplots()
ax.scatter(u[:, 0], u[:, 1], c='k', alpha=0.1)
ax.set(xlabel='UMAP 1', ylabel='UMAP 2', title='Un-normalised mean spectra (euclid_dist_cumul_spec)');

In [None]:
# Pairwise channel_cosine_intensity_5b_v2 values between max-normalised spectra.
# NOTE: this overwrites `dist_mat` again. NOTE(review): confirm that this helper returns a
# distance (not a similarity) before feeding it to UMAP(metric='precomputed').
# Symmetric, zero-diagonal matrix (np.empty would leave the diagonal and lower triangle
# uninitialised); assumes the helper is symmetric, as only j > i is evaluated.
n_obj = len(dict_index_maxnorm)
dist_mat = np.zeros((n_obj, n_obj))
for i in tqdm(range(n_obj)):
    s0 = dict_index_maxnorm[i]
    for j in range(i + 1, n_obj):
        d = fhc.channel_cosine_intensity_5b_v2(s0, dict_index_maxnorm[j])
        dist_mat[i, j] = d
        dist_mat[j, i] = d

In [None]:
# UMAP(metric='precomputed') expects a full symmetric distance matrix; rebuild it from
# the strictly-upper triangle, which is the part the distance cell evaluates.
dist_upper = np.triu(dist_mat, k=1)
dist_sym = dist_upper + dist_upper.T
# Fixed seed so the embedding is reproducible on re-run.
fit = umap.UMAP(metric='precomputed', random_state=42)
u = fit.fit_transform(dist_sym)

In [None]:
# One point per segmented object; low alpha shows where objects are dense.
fig, ax = plt.subplots()
ax.scatter(u[:, 0], u[:, 1], c='k', alpha=0.1)
ax.set(xlabel='UMAP 1', ylabel='UMAP 2', title='Max-normalised spectra (channel_cosine_intensity_5b_v2)');