# Classify a hiprfish image with hierarchical clustering
## Setup

In [None]:
import glob
import sys
import os
import gc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import yaml
import re
from collections import defaultdict
import aicspylibczi as aplc
from skimage.registration import phase_cross_correlation
from tqdm import tqdm
from scipy.cluster import hierarchy
from scipy.spatial.distance import squareform
import umap
import math
from sklearn.cluster import AgglomerativeClustering


In [None]:
cluster = ''
workdir = '/workdir/bmg224/manuscripts/mgefish/code/bmg_plasmids_imaging/agglomerative_clustering'
os.chdir(cluster + workdir)
os.getcwd()

In [None]:
config_fn = 'config_240107.yaml' # relative path to config file from workdir

with open(config_fn, 'r') as f:
    config = yaml.safe_load(f)

In [None]:
%load_ext autoreload
%autoreload 2

sys.path.append(config['pipeline_path'] + '/' + config['functions_path'])
import fn_general_use as fgu
import image_plots as ip
import segmentation_func as sf
import fn_hiprfish_classifier as fhc
import fn_spectral_images as fsi



In [None]:
# Collect every raw .czi file in the data directory and group the files by
# sample. A file's group name is its basename with the "_2024<suffix>.czi"
# tail removed.
raw_dir = config["data_dir"] + "/*.czi"
# raw_dir = config["data_dir"] + "/*" + config["laser_regex"]
fns = glob.glob(raw_dir)
fns_base = [os.path.split(f)[1] for f in fns]
# Raw string + escaped dot: "\w" in a normal string literal is an invalid
# escape sequence, and a bare "." would match any character.
suffix_re = re.compile(r"_2024\w+\.czi")
file_groups = [suffix_re.sub("", s) for s in fns_base]
group_names = np.sort(np.unique(file_groups))
m_size = group_names.shape[0]
# Assign each file to the group derived from its own name. Substring matching
# ("g in s") would also attach a file to any group whose name happens to be a
# prefix of the file's real group.
dict_group_czifns_all = defaultdict(list)
for g, fn in zip(file_groups, fns):
    dict_group_czifns_all[g].append(fn)
# Rebuild in sorted group order with sorted file lists.
dict_group_czifns_all = {g: sorted(dict_group_czifns_all[g]) for g in group_names}
dict_group_czifns_all

In [None]:
# Keep the first three files of every group. Groups with fewer than three
# files are dropped: the previous `len(v) > 1` filter let two-file groups
# through and then raised IndexError on v[2].
dict_group_czifns = {k: v[:3] for k, v in dict_group_czifns_all.items() if len(v) >= 3}
dict_group_czifns

In [None]:
sn = "2024_01_07_newplasmidredo2reimage_slide_7_fov_03"

## Shift images

In [None]:
from cv2 import resize, INTER_CUBIC, INTER_NEAREST

def center_image(im, dims, ul_corner):
    """
    Paste ``im`` onto a zero-filled canvas of size ``dims``.

    im: 2-D or 3-D array.
    dims: tuple of target sizes, compared against the leading axes of ``im``.
    ul_corner: (row, col) of the canvas where the upper-left pixel of ``im`` goes.

    Returns ``im`` itself when its leading axes already equal ``dims``;
    otherwise a new float array with ``im`` copied into the requested position.
    """
    src_shape = im.shape
    matches = [dims[axis] == src_shape[axis] for axis in range(len(dims))]
    if all(matches):
        return im
    # Carry the third axis over unchanged when the image has one.
    if len(src_shape) == 2:
        canvas_shape = dims
    else:
        canvas_shape = dims + (src_shape[2],)
    canvas = np.zeros(canvas_shape)
    stop = np.array(ul_corner) + np.array(src_shape[:2])
    row_start, col_start = ul_corner[0], ul_corner[1]
    canvas[row_start:stop[0], col_start:stop[1]] = im
    return canvas

def resize_hipr(im, hipr_res, mega_res, dims='none', out_fn=False, ul_corner=(0,0)):
    """
    Rescale ``im`` by ``hipr_res / mega_res`` with nearest-neighbour
    interpolation, then place it on a canvas via ``center_image``.

    dims: target canvas size; any string (the default) means "use the resized
        image's own shape", in which case no padding happens.
    out_fn: accepted for backward compatibility; not used by this function.
    ul_corner: forwarded to ``center_image``.
    """
    scale = hipr_res / mega_res
    resized = resize(im, None, fx=scale, fy=scale, interpolation=INTER_NEAREST)
    target_dims = resized.shape if isinstance(dims, str) else dims
    return center_image(resized, target_dims, ul_corner)

def reshape_aics_image(m_img):
    '''
    Given an AICS image with just XY and channel axes,
    drop the singleton axes and move the first remaining axis to the end,
    i.e. reshape from (C, X, Y) to (X, Y, C).
    '''
    return np.transpose(np.squeeze(m_img), axes=(1, 2, 0))

def correct_633_sizing(raws, czifns):
    """
    Rescale the fourth image to the pixel size of the first and pad the
    first three images to the rescaled image's size.

    raws: sequence of at least four image arrays; the first three must be
        3-D (their ``shape[2]`` is read).
    czifns: CZI filenames, one per image, at least four; the pixel size is
        read from each file's metadata.

    Returns ``(raws_resize, c)``: the three zero-padded images followed by the
    resized fourth image, and the (row, col) padding offset computed for the
    third image.

    Raises ValueError if a file's metadata has no tag containing both
    'Scaling' and 'X'.
    """
    # Get resolution
    res_um_pix = []
    for fn in czifns:
        czi = aplc.CziFile(fn)
        # Reset per file so that a file without a scaling tag cannot silently
        # reuse the value read from the previous file.
        resolution = None
        for n in czi.meta.iter():
            if 'Scaling' in n.tag and 'X' in n.tag:
                resolution = float(n.text)
        if resolution is None:
            raise ValueError('No X scaling found in metadata of ' + str(fn))
        # presumably metres/pixel -> micrometres/pixel; confirm against the CZI metadata
        res_um_pix.append(resolution * 10**6)
    res_633 = res_um_pix[3]
    res_other = res_um_pix[0]
    print(res_633, '-->', res_other)
    # Resize the fourth image to the pixel size of the first
    resize_633 = resize_hipr(raws[3], res_633, res_other)
    resize_shape = resize_633.shape[:2]
    raws_resize = []
    for i in range(3):
        ri = raws[i]
        sp = ri.shape[2]
        # Offset that centres this image inside the resized frame.
        c = (np.array(resize_shape) - np.array(ri.shape[:2])) // 2
        r = np.zeros(resize_shape + (sp,))
        r[c[0]:c[0] + ri.shape[0], c[1]:c[1] + ri.shape[1],:] = ri
        raws_resize.append(r)
    raws_resize.append(resize_633)
    return (raws_resize, c)


def _get_shift_vectors(image_sum):
    # Find shift vectors
    shift_vectors = [
            phase_cross_correlation(
                    np.log(image_sum[0]+1), np.log(image_sum[i]+1)
                    )[0]
            for i in range(1,len(image_sum))
            ]
    shift_vectors.insert(0, np.asarray([0.0,0.0]))
    return shift_vectors


def max_norm(raw, c=('min', 'max')):
    """
    Max-project ``raw`` along axis 2 and rescale the projection to [0, 1].

    raw: 3-D array; the maximum is taken over its last axis.
    c: two-element (low, high) clip range. The string 'min' / 'max' in either
        slot means "use the projection's own minimum / maximum".

    Returns the clipped projection mapped linearly so low -> 0 and high -> 1.
    When low == high the result is all zeros instead of NaN from 0/0.
    """
    im = np.max(raw, axis=2)
    mn = np.min(im) if c[0] == 'min' else c[0]
    mx = np.max(im) if c[1] == 'max' else c[1]
    im = np.clip(im, mn, mx)
    span = mx - mn
    if span == 0:
        # Constant image (or empty clip range): nothing to stretch.
        return np.zeros_like(im, dtype=float)
    return (im - mn) / span

In [None]:
czifns = dict_group_czifns[sn]
dimshape = aplc.CziFile(czifns[0]).get_dims_shape()[0]
dimshape

In [None]:
# Register the per-laser images of each tile against the first laser, preview
# the overlays, and store the registered, cropped spectral stack per tile.

# Number of tiles to process (passed to czi.read_image as M=m below).
M = 4
# M = dimshape['S'][1]

# RGB weights used to colour each laser image in the composite preview.
colors = [[0,0.5,0],[0.5,0,0],[0,0,0.5]]
# [low, high] clip range handed to max_norm for each laser image.
clips_rgb = [[0,5000],[0,10000],[0,5000],[0,2500]]
stacks = []
im_inches = 10
# for m in [2]:
for m in range(M):
    print("\n\nTile:",m)
    # Load this tile from every laser file that is not excluded in the config.
    raws = []
    for fn, las in zip(czifns, config['lasers']):
        if las not in config['rgb']['exclude_lasers']:
            czi = aplc.CziFile(fn)
            # if M:
            im, sh = czi.read_image(M=m)

            im = reshape_aics_image(im)
            print(im.shape)
            raws.append(im)
    
    # Estimate shifts on clipped, normalised max projections; the first image
    # is the reference and gets a zero shift.
    raws_max_norm = [max_norm(r, c) for r, c in zip(raws, clips_rgb)]
    shift_vectors = _get_shift_vectors(raws_max_norm)
    print(shift_vectors)
    # shift_vectors_1_3 = _get_shift_vectors(raws_max_norm[:3])
    
    # raws_resize, c = correct_633_sizing(raws, czifns)
    # raws_shift_max_norm = [max_norm(r, c) for r, c in zip(raws_resize, clips_rgb)]
    # shp = raws_max_norm[0].shape
    # raws_max_norm[3] = raws_shift_max_norm[3][:shp[0],:shp[1]]
    # shift_vectors = _get_shift_vectors(raws_max_norm)
    # shift_vectors[3] = shift_vectors[3] + c
    # max_shift = config['max_shift']
    # raws_shift = fsi._shift_images(raws_resize, shift_vectors, max_shift=max_shift)
    max_shift = config['max_shift']
    raws_shift = fsi._shift_images(raws, shift_vectors, max_shift=max_shift)
        
    # shift_vectors_4 = _get_shift_vectors(raws_shift_max_norm[2:4])
    # shift_vectors = shift_vectors_1_3 + [shift_vectors_4[1]]

    # raws_shift = fsi._shift_images(raws_resize, shift_vectors, max_shift=max_shift)
    # Bounding box of the nonzero pixels of the shifted reference image ...
    im_r = max_norm(raws_shift[0], clips_rgb[0])
    stack_max = [im_r]
    true_points = np.argwhere(im_r)
    top_left = true_points.min(axis=0)
    bottom_right = true_points.max(axis=0)
    for i in range(1,len(raws_shift)):
        # Preview: reference image in the red channel, image i in the green channel.
        im_rgb = np.zeros(im_r.shape + (3,))
        im_rgb[:,:,0] = im_r
        im_g = max_norm(raws_shift[i], clips_rgb[i])
        stack_max.append(im_g)
        im_rgb[:,:,1] = im_g
        print(i)
        ip.plot_image(im_rgb, im_inches=im_inches)
        plt.show()    
        plt.close()
        # ... shrunk to the intersection with the nonzero bounding box of image i.
        true_points = np.argwhere(im_g)
        tl = true_points.min(axis=0)
        top_left = np.max(np.vstack([top_left, tl]), axis=0)
        br = true_points.max(axis=0)
        bottom_right = np.min(np.vstack([bottom_right, br]), axis=0)
    # Colour each normalised projection and sum them into one composite;
    # zip stops at the shorter of colors / stack_max.
    raws_chan_col = []
    for c, im in zip(colors, stack_max):
        raws_chan_col.append(im[...,None] * np.array(c)[None,None,:])    
    rgb = np.zeros_like(raws_chan_col[0])
    for im in raws_chan_col:
        rgb += im
    rgb_trim = rgb[
        top_left[0]:bottom_right[0]+1,
        top_left[1]:bottom_right[1]+1
        ]
    ip.plot_image(rgb_trim, im_inches=im_inches)
    plt.show()
    plt.close()
    # Concatenate the shifted laser images along the channel axis and keep
    # only the common bounding box.
    stacks.append(np.dstack(raws_shift)[
        top_left[0]:bottom_right[0]+1,
        top_left[1]:bottom_right[1]+1
        ])




In [None]:
for n in czi.meta.iter():
    if 'Scaling' in n.tag:
        if 'X' in n.tag:
            resolution = float(n.text)
resolution

### Segment

In [None]:
# For each laser, plot a sorted random sample of per-pixel maxima (taken over
# that laser's channel range) to help choose a mask threshold.
lasers = ['488','514','561']
# Channel boundaries in the stacked image: laser i covers
# chan_inds[i]:chan_inds[i+1].
chan_inds = [0,23,43,57]
for i in range(len(lasers)):
    print('Laser',lasers[i])
    # Start from an empty list for every laser; otherwise each plot would also
    # contain the pixel values of all previously processed lasers.
    mxs = []
    for stack in stacks:
        mx = np.max(stack[:,:,chan_inds[i]:chan_inds[i+1]],axis=2).ravel()
        mxs += mx.tolist()

    # Subsample (with replacement) so the scatter plot stays light.
    mxs_sub = np.random.choice(mxs,100000)
    mxs_sub = np.sort(mxs_sub)


    # Round the y-axis limit up to a multiple of rnd.
    rnd = 10000

    fig, ax = ip.general_plot(dims=(10,5))
    ax.scatter(np.arange(len(mxs_sub)), mxs_sub, s=1)
    ylim = int(math.ceil(ax.get_ylim()[1] / rnd)) * rnd
    ax.set_ylim(0,ylim)
    ax.set_yticks(np.arange(0, ylim, rnd//16))
    ax.grid(axis='y')
    plt.show()

In [None]:
# Foreground mask per tile: a pixel is kept when its maximum within at least
# one laser's channel range exceeds that laser's threshold.
mask_threshs = [300,300,300]

masks = []
for stack in stacks:
    masks_chan = [
        np.max(stack[:,:,chan_inds[i]:chan_inds[i+1]], axis=2) > mask_threshs[i]
        for i in range(len(lasers))
    ]
    # Element-wise OR across the per-laser boolean masks.
    mask = np.logical_or.reduce(masks_chan)
    masks.append(mask)


In [None]:
# Set param
# Interactive tuning of the segmentation parameters on one tile; the values
# chosen here (gauss, diff_gauss) are reused by the batch segmentation cell.

# Zoom window used for the previews: rows z[0]:z[1], columns z[2]:z[3].
z = [500, 1000, 500, 1000]
# Figure size passed as im_inches to the plotting helper.
imin = 10
# Parameters forwarded to sf.pre_process.
gauss = 4
diff_gauss = (0,)
# Threshold applied to the pre-processed max projection to build the mask.
mask_thresh = 700

# Tile to inspect.
i = 0
stack = stacks[i]


# mask = masks[i]

# Max and sum projections over the channel axis.
im_max = np.max(stack, axis=2)
im_sum = np.sum(stack, axis=2)
# mask = im_max > mask_thresh

ip.plot_image(im_max[z[0] : z[1], z[2] : z[3]], cmap="inferno", im_inches=imin)

pre_max = sf.pre_process(im_max, gauss=gauss, diff_gauss=diff_gauss)
ip.plot_image(pre_max[z[0] : z[1], z[2] : z[3]], cmap="inferno", im_inches=imin)

pre = sf.pre_process(im_sum, gauss=gauss, diff_gauss=diff_gauss)
ip.plot_image(
    pre[z[0] : z[1], z[2] : z[3]], cmap="inferno", im_inches=imin, clims=(3000, 10000)
)

# Mask from the pre-processed max projection; masked-out pixels are overlaid
# with a uniform half-valued RGBA layer.
mask = pre_max > mask_thresh
fig, ax, cbar = ip.plot_image(
    (pre * mask)[z[0] : z[1], z[2] : z[3]],
    cmap="inferno",
    im_inches=imin,
    clims=(3000, 10000),
)
ax.imshow(np.dstack([0.5 * (~mask)] * 4)[z[0] : z[1], z[2] : z[3]])

# Same two views for the full tile.
ip.plot_image(pre, cmap="inferno", im_inches=imin)

fig, ax, cbar = ip.plot_image(
    (pre * mask), cmap="inferno", im_inches=imin, clims=(3000, 10000)
)
ax.imshow(np.dstack([0.5 * (~mask)] * 4))

# Segment only the zoom window, so seg is already in zoom coordinates.
seg = sf.segment(
    pre_max[z[0] : z[1], z[2] : z[3]], background_mask=mask[z[0] : z[1], z[2] : z[3]]
)
seg_zoom_rgb = ip.seg2rgb(seg)
ip.plot_image(seg_zoom_rgb, im_inches=imin)
# ip.plot_image(seg_zoom_rgb[z[0]:z[1],z[2]:z[3]], im_inches=imin)

# Show the pre-processed max projection inside the segmented objects only.
fig, ax, cbar = ip.plot_image(
    (pre_max[z[0] : z[1], z[2] : z[3]] * (seg > 0)),
    cmap="inferno",
    im_inches=imin,
)
ax.imshow(np.dstack([0.5 * (~(seg > 0))] * 4))

In [None]:
# Segment every tile and compute the mean spectrum of each segmented object.
# Outputs, one entry per tile:
#   segs  - label image returned by sf.segment
#   props - region-properties table returned by sf.measure_regionprops
#   specs - dict mapping object label -> mean spectrum over the object's pixels
segs = []
specs = []
props = []
for stack, mask in tqdm(zip(stacks, masks)):
    im_max = np.max(stack, axis=2)
    im_sum = np.sum(stack, axis=2)
    # ip.plot_image(im_max[z[0]:z[1],z[2]:z[3]], cmap='inferno', im_inches=imin)
    pre = sf.pre_process(im_max, gauss=gauss, diff_gauss=diff_gauss)
    # ip.plot_image(pre[z[0]:z[1],z[2]:z[3]], cmap='inferno', im_inches=imin)
    # mask = im_max > mask_thresh
    # ip.plot_image((pre), cmap='inferno', im_inches=imin)
    # ip.plot_image((pre*mask), cmap='inferno', im_inches=imin)
    # The mask here is the per-laser threshold mask from `masks`.
    seg = sf.segment(pre, background_mask=mask)
    segs.append(seg)
    prop = sf.measure_regionprops(seg, raw=im_sum)
    props.append(prop)
    dict_lab_spec = {}
    for i, row in prop.iterrows():
        b = row.bbox
        l = row.label
        # NOTE(review): eval() runs arbitrary code; if bbox can arrive as a
        # string (e.g. from a CSV), ast.literal_eval would be the safe parser.
        b = eval(b) if isinstance(b, str) else b
        # Crop to the bounding box, then select only this object's pixels;
        # the result has one row per pixel and one column per channel.
        r_sub = stack[b[0]:b[2],b[1]:b[3],:]
        m_sub = seg[b[0]:b[2],b[1]:b[3]] == l
        dict_lab_spec[l] = r_sub[m_sub]
    # Average over the pixels of each object.
    dict_lab_specmean = {l:np.mean(s, axis=0) for l, s in dict_lab_spec.items()}
    specs.append(dict_lab_specmean)


### Filter intensity

In [None]:
for i in range(len(lasers)):
    mxs = []
    print('Laser',lasers[i])
    for dict_lab_specmean in specs:
        s_ch = np.array(list(dict_lab_specmean.values()))[:,chan_inds[i]:chan_inds[i+1]]
        mx = np.max(s_ch, axis=1)
        # mx = np.max(stack[:,:,chan_inds[i]:chan_inds[i+1]],axis=2).ravel()
        mxs += mx.tolist()
    # mxs_sub = np.random.choice(mxs,100000)
    mxs = np.sort(mxs)


    rnd = 1000

    fig, ax = ip.general_plot(dims=(10,5))
    ax.scatter(np.arange(len(mxs)), mxs, s=1)
    ylim = int(math.ceil(ax.get_ylim()[1] / rnd)) * rnd
    ax.set_ylim(0,ylim)
    ax.set_yticks(np.arange(0, ylim, rnd//4))
    ax.grid(axis='y')
    plt.show()

In [None]:
# Keep an object when, for at least one laser, the peak of its mean spectrum
# within that laser's channel range exceeds the laser's threshold.
seg_threshs = [0,0,0]
laser_idx = range(len(lasers))
specs_filt = []
for dict_lab_specmean in specs:
    dict_lab_specfilt = {}
    for l, s in dict_lab_specmean.items():
        mxs_ch = [np.max(s[chan_inds[k]:chan_inds[k+1]]) for k in laser_idx]
        mxs_ch_bool = [mxs_ch[k] > seg_threshs[k] for k in laser_idx]
        if not any(mxs_ch_bool):
            continue
        dict_lab_specfilt[l] = s
    specs_filt.append(dict_lab_specfilt)

# Object counts per tile before and after filtering.
for spec_list in (specs, specs_filt):
    print([len(d) for d in spec_list])

In [None]:
# Build a grayscale RGBA preview of each tile from its channel-summed image.
ims_rgb = []
for m in range(M):
    im_sum = np.sum(stacks[m], axis=2)

    # Saturate at 35% of the maximum, then rescale to [0, 1].
    mx = np.max(im_sum)*0.35
    # mx = 60000
    mn = np.min(im_sum)
    im_norm = np.clip(im_sum,mn,mx)
    im_norm = (im_norm - mn) / (mx - mn)
    # Replicate the gray value into three colour channels and append an
    # all-ones fourth channel.
    im_rgb = im_norm[:,:,None] * np.array([1,1,1])[None,:]
    im_rgb = np.dstack([im_rgb, np.ones(im_rgb.shape[:2])])
    ims_rgb.append(im_rgb)
    fig, ax, cbar = ip.plot_image(im_rgb, im_inches=10)

In [None]:
# Paint every segmented object of each tile in blue (fourth channel 0.5) on
# top of the grayscale preview of that tile.
for m in range(M):
    dict_lab_spec = specs[m]
    labels = list(dict_lab_spec.keys())
    # Set lookup instead of scanning the label list for every table row.
    label_set = set(labels)
    seg = segs[m]
    prop = props[m]
    im_rgb = ims_rgb[m].copy()
    color = np.array([0,0,1,0.5])
    for i, row in prop.iterrows():
        l = row.label
        if l not in label_set:
            continue
        b = row.bbox
        # NOTE(review): eval() runs arbitrary code; ast.literal_eval would be
        # the safe way to parse a bbox that arrives as a string.
        b = eval(b) if isinstance(b, str) else b
        rgb_sub = im_rgb[b[0]:b[2],b[1]:b[3]]
        seg_sub = seg[b[0]:b[2],b[1]:b[3]] == l
        # Zero the object's pixels in the crop, then add the flat colour there.
        rgb_noncell = rgb_sub * np.dstack([~seg_sub]*4)
        rgb_cell = seg_sub[:,:,None] * color[None,:]
        im_rgb[b[0]:b[2],b[1]:b[3],:] = rgb_noncell + rgb_cell

    fig, ax, cbar = ip.plot_image(im_rgb, im_inches=10)

In [None]:

# Paint only the objects that passed the intensity filter (specs_filt) in blue
# (fourth channel 0.5) on top of the grayscale preview of each tile.
for m in range(M):
    dict_lab_specfilt = specs_filt[m]
    labels = list(dict_lab_specfilt.keys())
    # Set lookup instead of scanning the label list for every table row.
    label_set = set(labels)
    seg = segs[m]
    prop = props[m]
    im_rgb = ims_rgb[m].copy()
    color = np.array([0,0,1,0.5])
    for i, row in prop.iterrows():
        l = row.label
        if l not in label_set:
            continue
        b = row.bbox
        # NOTE(review): eval() runs arbitrary code; ast.literal_eval would be
        # the safe way to parse a bbox that arrives as a string.
        b = eval(b) if isinstance(b, str) else b
        rgb_sub = im_rgb[b[0]:b[2],b[1]:b[3]]
        seg_sub = seg[b[0]:b[2],b[1]:b[3]] == l
        # Zero the object's pixels in the crop, then add the flat colour there.
        rgb_noncell = rgb_sub * np.dstack([~seg_sub]*4)
        rgb_cell = seg_sub[:,:,None] * color[None,:]
        im_rgb[b[0]:b[2],b[1]:b[3],:] = rgb_noncell + rgb_cell

    fig, ax, cbar = ip.plot_image(im_rgb, im_inches=10)

## Save

In [None]:
output_dir = config['output_dir'] + '/' + sn
# output_dir = output_dir_fmt.format(sample_name=sn)
if not os.path.exists(output_dir): 
    os.makedirs(output_dir)
    print("made dir:",output_dir)


In [None]:
stack_dir = output_dir + '/stacks'
props_dir = output_dir + '/props'
segs_dir = output_dir + '/segs'
spec_dir = output_dir + '/spectra'
# clust_dir = output_dir + '/clust'

for dir_ in [spec_dir, segs_dir, props_dir, stack_dir]:
    if not os.path.exists(dir_):
        os.makedirs(dir_)
        print("made dir:",dir_)


In [None]:
# Write the per-tile results to disk: registered stack (.npy), region
# properties (.csv), label image (.npy) and mean spectra (.yaml).
for m in range(M):
    # Base name shared by all files of this tile.
    bn = sn + '_M_' + str(m)
    stack_fn = stack_dir + '/' + bn + '_stack.npy'
    props_fn = props_dir + '/' + bn + '_props.csv'
    seg_fn = segs_dir + '/' + bn + '_seg.npy'
    spec_fn = spec_dir + '/' + bn + '_spec.yaml'
    # clust_fn = clust_dir + '/' + bn + '_clust.yaml'

    np.save(stack_fn, stacks[m])
    print('Wrote:',stack_fn)
    props[m].to_csv(props_fn, index=False)
    print('Wrote:',props_fn)
    np.save(seg_fn, segs[m])
    print('Wrote:',seg_fn)
    # Saves the unfiltered spectra (specs, not specs_filt).
    # NOTE(review): the values are numpy arrays, which yaml.dump presumably
    # serialises with Python-specific tags that yaml.safe_load cannot read
    # back - confirm how downstream code loads these files.
    with open(spec_fn, 'w') as f:
        yaml.dump(specs[m], f)
    print('Wrote:',spec_fn)
    # with open(clust_fn, 'w') as f:
    #     yaml.dump(dict_lab_clust, f)
    # print('Wrote:',clust_fn)


In [None]:
# rnd = 10000

# fig, ax = ip.general_plot(dims=(10,5))
# ax.scatter(np.arange(len(mxs)), mxs, s=1)
# ylim = int(math.ceil(ax.get_ylim()[1] / rnd)) * rnd
# ax.set_ylim(0,ylim)
# ax.set_yticks(np.arange(0, ylim, rnd//4))
# ax.grid(axis='y')


In [None]:
# spec_thresh = 2000
# specs_filt = []
# for dict_lab_specmean in specs:
#     dict_lab_specfilt = {}
#     for l, s in dict_lab_specmean.items():
#         if np.max(s) > spec_thresh:
#             dict_lab_specfilt[l] = s
#     specs_filt.append(dict_lab_specfilt)

## Average spectrum

In [None]:
spec_dims = (10, 5)

specs_arr = np.vstack([np.vstack([v for v in s.values()]) for s in specs_filt])
specs_arr.shape


In [None]:
fig, ax = ip.general_plot(dims=spec_dims)
fsi.plot_cell_spectra(ax, specs_arr, {'lw':0.1,'alpha':0.1,'color':'r'})
# ax.set_ylim(0,2**16)
plt.plot()
plt.show()

In [None]:
specs_arr_med = np.median(specs_arr, axis=0)

fig, ax = ip.general_plot(dims=spec_dims)
fsi.plot_cell_spectra(ax, specs_arr_med[None,:], {'lw':1,'alpha':1,'color':'r'})
# ax.set_ylim(0,2**16)
plt.plot()
plt.show()

In [None]:
specs_arr_mean = np.mean(specs_arr, axis=0)

fig, ax = ip.general_plot(dims=spec_dims)
fsi.plot_cell_spectra(ax, specs_arr_mean[None,:], {'lw':1,'alpha':1,'color':'r'})
# ax.set_ylim(0,2**16)
plt.plot()
plt.show()

In [None]:
specs_arr_std = np.std(specs_arr, axis=0)

fig, ax = ip.general_plot(dims=spec_dims)
fsi.plot_cell_spectra(ax, specs_arr_std[None,:], {'lw':1,'alpha':1,'color':'r'})
# ax.set_ylim(0,2**16)
plt.plot()
plt.show()

In [None]:
specs_arr_sub = specs_arr_med - specs_arr_std

fig, ax = ip.general_plot(dims=spec_dims)
fsi.plot_cell_spectra(ax, specs_arr_sub[None,:], {'lw':1,'alpha':1,'color':'r'})
# ax.set_ylim(0,2**16)
plt.plot()
plt.show()

In [None]:
specs_arr_bgfilt = specs_arr - specs_arr_med
specs_arr_bgfilt[specs_arr_bgfilt < 0] = 0

fig, ax = ip.general_plot(dims=spec_dims)
fsi.plot_cell_spectra(ax, specs_arr_bgfilt, {'lw':0.1,'alpha':0.51,'color':'r'})
# ax.set_ylim(0,2**16)
plt.plot()
plt.show()

In [None]:
specs_bgfilt_mx = np.max(specs_arr_bgfilt, axis=1)
specs_bgfilt_mx = np.sort(specs_bgfilt_mx)
plt.scatter(np.arange(specs_bgfilt_mx.shape[0]), specs_bgfilt_mx)

In [None]:
len(specs_filt[m])

In [None]:
# Background-correct every filtered spectrum: subtract the median spectrum
# and floor negative values at zero.
specs_corr = []
for s in specs_filt:
    s_ = {}
    for k, v in s.items():
        v_ = v - specs_arr_med
        s_[k] = np.where(v_ < 0, 0, v_)
    specs_corr.append(s_)


In [None]:
len(specs_corr[m])

## Clustering auto

In [None]:
# Cluster the filtered object spectra of every tile into n_clust groups with
# complete-linkage agglomerative clustering, and write all per-tile outputs.
n_clust = 20

# Setup output dirs
output_dir = '../outputs/agglomerative_cluster/' + sn
# output_dir = output_dir_fmt.format(sample_name=sn)
if not os.path.exists(output_dir): 
    os.makedirs(output_dir)
    print("made dir:",output_dir)

stack_dir = output_dir + '/stacks'
props_dir = output_dir + '/props'
segs_dir = output_dir + '/segs'
spec_dir = output_dir + '/spectra'
clust_dir = output_dir + '/clust'

for dir_ in [clust_dir, spec_dir, segs_dir, props_dir, stack_dir]:
    if not os.path.exists(dir_):
        os.makedirs(dir_)
        print("made dir:",dir_)


for m in range(M):
    # Get derivative of spectra (first difference between adjacent channels)
    dict_lab_specmean = specs_filt[m]
    dict_lab_slope = {l:np.diff(s) for l, s in dict_lab_specmean.items()}


    # Build distance matrix between all cell spectra
    n_cells = len(dict_lab_slope)

    dist_mat = np.zeros((n_cells,n_cells))
    for i, s_i in tqdm(enumerate(dict_lab_slope.values())):
        for j, s_j in enumerate(dict_lab_slope.values()):
            dist_mat[i,j] = fhc.channel_cosine_intensity_allonev2(s_i, s_j)

    # Cluster
    # squareform also checks that dist_mat is symmetric with a zero diagonal.
    condensed_dist_mat = squareform(dist_mat)
    linkage = hierarchy.linkage(condensed_dist_mat, method='complete')
    labels = list(dict_lab_slope.keys())
    # scikit-learn renamed `affinity` to `metric` (deprecated in 1.2, removed
    # in 1.4); try the new keyword first and fall back for older releases.
    try:
        agg = AgglomerativeClustering(n_clusters=n_clust, metric='precomputed', linkage='complete')
    except TypeError:
        agg = AgglomerativeClustering(n_clusters=n_clust, affinity='precomputed', linkage='complete')

    agg.fit(dist_mat)

    clust_all = agg.labels_

    # Object label -> cluster id, in the same order as the distance matrix rows.
    dict_lab_clust = dict(zip(labels,clust_all))

    # Save files
    bn = sn + '_M_' + str(m)
    stack_fn = stack_dir + '/' + bn + '_stack.npy'
    props_fn = props_dir + '/' + bn + '_props.csv'
    seg_fn = segs_dir + '/' + bn + '_seg.npy'
    spec_fn = spec_dir + '/' + bn + '_spec.yaml'
    clust_fn = clust_dir + '/' + bn + '_clust.yaml'

    np.save(stack_fn, stacks[m])
    print('Wrote:',stack_fn)
    props[m].to_csv(props_fn, index=False)
    print('Wrote:',props_fn)
    np.save(seg_fn, segs[m])
    print('Wrote:',seg_fn)
    # NOTE(review): this writes the unfiltered spectra (specs) while the
    # clusters were computed on specs_filt - confirm that is intended.
    with open(spec_fn, 'w') as f:
        yaml.dump(specs[m], f)
    print('Wrote:',spec_fn)
    with open(clust_fn, 'w') as f:
        yaml.dump(dict_lab_clust, f)
    print('Wrote:',clust_fn)



### Clustering Manual

In [None]:
m = 0

In [None]:
im_sum = np.sum(stacks[m], axis=2)

mx = np.max(im_sum)*0.35
# mx = 60000
mn = np.min(im_sum)
im_norm = np.clip(im_sum,mn,mx)
im_norm = (im_norm - mn) / (mx - mn)
im_rgb = im_norm[:,:,None] * np.array([1,1,1])[None,:]
fig, ax, cbar = ip.plot_image(im_rgb, im_inches=10)

In [None]:
np.max(list(specs[m].values()))

In [None]:
# n_cells = 0
# slopes = []
# dict_ind_lab = {}
# ind = 0
# for m, dict_lab_specmean in enumerate(specs):
#     s_ = []
#     for l,s in dict_lab_specmean.items():
#         s_.append(np.diff(s))
#         dict_ind_lab[ind] = [m,l]
#         ind += 1   
#     slopes += s_ 
#     # n_cells += len(dict_lab_specmean)
# n_cells = len(dict_ind_lab)
# print(n_cells)
dict_lab_specmean = specs_corr[m]
# dict_lab_specmean = specs_filt[m]
dict_lab_slope = {l:np.diff(s) for l, s in dict_lab_specmean.items()}



In [None]:
len(specs_corr[m])

In [None]:
# dist_mat = np.zeros((n_cells,n_cells))

# for i, s_i in tqdm(enumerate(slopes)):
#     s_i_sub = s_i
#     for j, s_j in enumerate(slopes):
#         if j > i:
#             s_j_sub = s_j
#             dist_mat[i,j] = fhc.channel_cosine_intensity_allonev2(s_i_sub, s_j_sub)
# dist_mat_t = dist_mat.copy().T
# for i in range(dist_mat.shape[0]):
#     dist_mat_t[i,i] = 0
# dist_mat += dist_mat_t
# ip.plot_image(dist_mat, cmap='inferno',im_inches=imin)

n_cells = len(dict_lab_slope)

dist_mat = np.zeros((n_cells,n_cells))
for i, s_i in tqdm(enumerate(dict_lab_slope.values())):
    s_i_sub = s_i
    for j, s_j in enumerate(dict_lab_slope.values()):
        # if j > i:
        s_j_sub = s_j
        dist_mat[i,j] = fhc.channel_cosine_intensity_allonev2(s_i_sub, s_j_sub)
ip.plot_image(dist_mat, cmap='inferno',im_inches=imin)

In [None]:
condensed_dist_mat = squareform(dist_mat)
linkage = hierarchy.linkage(condensed_dist_mat, method='complete')

In [None]:
labels = list(dict_lab_slope.keys())
len(labels)

In [None]:

fig, ax = ip.general_plot(dims=(15,15))
dn = hierarchy.dendrogram(linkage, labels=labels, truncate_mode=None)
# dn = hierarchy.dendrogram(linkage, labels=inds, truncate_mode='lastp')
# ax.axhline(t[i])
ylims = [round(l,2) for l in ax.get_ylim()]
ax.set_yticks(np.arange(ylims[0],ylims[1],0.005))
ax.grid(axis='y')
plt.show()
plt.close()

In [None]:
spec_arr_order = []
for l in dn['ivl']:
    s = dict_lab_specmean[l]
    spec_arr_order.append(s)
spec_arr_order = np.vstack(spec_arr_order)



fig = plt.figure(figsize=(15,5))
axs = {
    0: fig.add_axes([0.1,0.1,0.1,0.8]),
    1: fig.add_axes([0.2,0.1,0.7,0.8])
}
hierarchy.dendrogram(linkage, ax=axs[0], labels=labels, orientation='left', no_labels=True, truncate_mode=None)
axs[1].imshow(np.flip(spec_arr_order, axis=0), cmap='inferno', aspect='auto')
axs[1].axis('off')
axs[0].axis('off')
plt.show()
plt.close()

In [None]:
fit = umap.UMAP(metric='precomputed')
u = fit.fit_transform(dist_mat)
plt.scatter(u[:,0], u[:,1], alpha=0.25)
plt.show()
plt.close()

In [None]:
from sklearn.cluster import AgglomerativeClustering

# Cut the precomputed distance matrix into n_clust complete-linkage clusters.
n_clust = 4
# scikit-learn renamed `affinity` to `metric` (deprecated in 1.2, removed in
# 1.4); try the new keyword first and fall back for older releases.
try:
    agg = AgglomerativeClustering(n_clusters=n_clust, metric='precomputed', linkage='complete')
except TypeError:
    agg = AgglomerativeClustering(n_clusters=n_clust, affinity='precomputed', linkage='complete')

agg.fit(dist_mat)

# One cluster id per row of dist_mat.
clust_agg = agg.labels_



In [None]:

plt.scatter(u[:,0], u[:,1], c=clust_agg, alpha=0.25, cmap='tab10')
plt.show()
plt.close()

In [None]:
clust_all = clust_agg


In [None]:
dict_lab_clust = dict(zip(labels,clust_all))
colors = plt.get_cmap('tab10')(np.linspace(0,1,n_clust))
dict_clust_col = dict(zip(np.unique(clust_all), colors))

In [None]:
# im_rgb = im_norm[:,:,None] * np.array([1,1,1])[None,:]
# im_rgb = np.dstack([im_rgb, np.ones(im_rgb.shape[:2])])
im_rgb = ims_rgb[m]

seg = segs[m]
prop = props[m]
for i, row in prop.iterrows():
    l = row.label
    if l in labels:
        b = row.bbox
        b = eval(b) if isinstance(b, str) else b
        rgb_sub = im_rgb[b[0]:b[2],b[1]:b[3]]
        seg_sub = seg[b[0]:b[2],b[1]:b[3]] == l
        im_sub = im_norm[b[0]:b[2],b[1]:b[3]]
        cl = dict_lab_clust[l]
        color = np.array(dict_clust_col[cl])
        # rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*3 + np.ones_like(seg_sub))
        rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*4)
        # rgb_cell = (im_sub * seg_sub)[:,:,None] * color[None,:]
        rgb_cell = seg_sub[:,:,None] * color[None,:]
        im_rgb[b[0]:b[2],b[1]:b[3],:] = rgb_noncell + rgb_cell

In [None]:
fig, ax, cbar = ip.plot_image(im_rgb, im_inches=10)

In [None]:
spec_dims = (10,5)

for c in np.unique(clust_all):
    print('Cluster:', c)
    bool_c = clust_all == c
    labels_sub = np.array(labels)[bool_c]
    spec_sub = []
    for l in labels_sub:
        s = dict_lab_specmean[l]
        spec_sub.append(s)
    spec_sub = np.vstack(spec_sub)
    fig, ax = ip.general_plot(dims=spec_dims)
    color = dict_clust_col[c]
    fsi.plot_cell_spectra(ax, spec_sub, {'lw':1,'alpha':0.2,'color':color})
    # ax.set_ylim(0,12500)
    # ax.set_ylim(0,2**16)
    plt.plot()
    plt.show()

### Subset clusters

In [None]:
t = {}
clust = {}
i = 0
t[i] = 0.065

clust[i] = hierarchy.fcluster(linkage, t=t[i], criterion='distance')
clust_unique = np.unique(clust[i])
n_clust = len(clust_unique)

clust_all = clust[i]
labels = list(dict_lab_slope.keys())

dict_lab_clust = dict(zip(labels,clust_all))
colors = plt.get_cmap('tab20')(np.linspace(0,1,n_clust))
dict_clust_col = dict(zip(np.unique(clust_all), colors))

im_rgb = im_norm[:,:,None] * np.array([1,1,1])[None,:]
im_rgb = np.dstack([im_rgb, np.ones(im_rgb.shape[:2])])
seg = segs[m]
prop = props[m]
for _, row in prop.iterrows():
    l = row.label
    if l in labels:
        b = row.bbox
        b = eval(b) if isinstance(b, str) else b
        rgb_sub = im_rgb[b[0]:b[2],b[1]:b[3]]
        seg_sub = seg[b[0]:b[2],b[1]:b[3]] == l
        im_sub = im_norm[b[0]:b[2],b[1]:b[3]]
        cl = dict_lab_clust[l]
        color = np.array(dict_clust_col[cl])
        rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*4)
        # rgb_cell = (im_sub * seg_sub)[:,:,None] * color[None,:]
        rgb_cell = seg_sub[:,:,None] * color[None,:]
        im_rgb[b[0]:b[2],b[1]:b[3],:] = rgb_noncell + rgb_cell

fig, ax, cbar = ip.plot_image(im_rgb, im_inches=10)

spec_dims = (10,5)
for c in clust_unique:
    labels_sub = np.array(labels)[clust[i] == c]
    # if len(labels_sub) > 100:
    print("Cluster:",c) 
    spec_sub = []
    for l in labels_sub:
        s = dict_lab_specmean[l]
        spec_sub.append(s)
    spec_sub = np.vstack(spec_sub)
    fig, ax = ip.general_plot(dims=spec_dims)
    color = np.array(dict_clust_col[c])
    fsi.plot_cell_spectra(ax, spec_sub, {'lw':0.5,'alpha':0.5,'color':color})
    ax.set_ylim(0,2**16)
    plt.plot()
    plt.show()


In [None]:
# Refinement step 1: split cluster sub_cl[i] of cut i_old using the clusters
# obtained by re-cutting the dendrogram at distance t[i].
sub_cl = {}
i = 1
i_old = 0
sub_cl[i] = 39
t[i] = 0.050

clust[i] = hierarchy.fcluster(linkage, t=t[i], criterion='distance')
# Objects that belong to the cluster being split, and the ids they receive in
# the new cut.
bool_subclust = clust[i_old] == sub_cl[i]
clust_change = np.unique(clust[i][bool_subclust])

# Give every new sub-cluster a fresh id above the current maximum, then copy
# the renamed ids over only for the objects of the cluster being split.
clust_all_temp = clust_all.copy()
clmx = np.max(clust_all_temp)
clust_i_rename = clust_all_temp.copy()
# dict_clust_change[i] maps new id -> id in the fcluster cut it came from.
dict_clust_change = defaultdict(dict)
for c in clust_change:
    bool_c = clust[i] == c
    clmx += 1
    dict_clust_change[i][clmx] = c
    clust_i_rename[bool_c] = clmx
clust_all_temp[bool_subclust] = clust_i_rename[bool_subclust]


clust_unique = np.unique(clust_all_temp)
n_clust = len(clust_unique)

labels = list(dict_lab_slope.keys())
# Set lookup instead of scanning the label list for every table row.
label_set = set(labels)

dict_lab_clust = dict(zip(labels,clust_all_temp))
colors = plt.get_cmap('tab10')(np.linspace(0,1,n_clust))
dict_clust_col = dict(zip(np.unique(clust_all_temp), colors))

# Paint every clustered object with its cluster colour on the grayscale image.
im_rgb = im_norm[:,:,None] * np.array([1,1,1])[None,:]
im_rgb = np.dstack([im_rgb, np.ones(im_rgb.shape[:2])])
seg = segs[m]
prop = props[m]
for _, row in prop.iterrows():
    l = row.label
    if l not in label_set:
        continue
    b = row.bbox
    # NOTE(review): eval() runs arbitrary code; ast.literal_eval would be the
    # safe way to parse a bbox that arrives as a string.
    b = eval(b) if isinstance(b, str) else b
    rgb_sub = im_rgb[b[0]:b[2],b[1]:b[3]]
    seg_sub = seg[b[0]:b[2],b[1]:b[3]] == l
    cl = dict_lab_clust[l]
    color = np.array(dict_clust_col[cl])
    rgb_noncell = rgb_sub * np.dstack([~seg_sub]*4)
    rgb_cell = seg_sub[:,:,None] * color[None,:]
    im_rgb[b[0]:b[2],b[1]:b[3],:] = rgb_noncell + rgb_cell

fig, ax, cbar = ip.plot_image(im_rgb, im_inches=10)
plt.show()
plt.close()

# Plot the member spectra of every cluster in the cluster's colour.
spec_dims = (10,5)
for c in clust_unique:
    labels_sub = np.array(labels)[clust_all_temp == c]
    # if len(labels_sub) > 100:
    print("Cluster:",c) 
    # Only clusters created in this step have an entry; an explicit test
    # replaces the bare except, which also hid unrelated errors.
    if c in dict_clust_change[i]:
        print("Changed from new fcluster:",dict_clust_change[i][c]) 
    spec_sub = []
    for l in labels_sub:
        s = dict_lab_specmean[l]
        spec_sub.append(s)
    spec_sub = np.vstack(spec_sub)
    fig, ax = ip.general_plot(dims=spec_dims)
    color = np.array(dict_clust_col[c])
    fsi.plot_cell_spectra(ax, spec_sub, {'lw':0.5,'alpha':0.5,'color':color})
    ax.set_ylim(0,2**16)
    plt.plot()
    plt.show()


In [None]:
clust_all = clust_all_temp

In [None]:
# Refinement step 2: split cluster sub_cl[i] of cut i_old using the clusters
# obtained by re-cutting the dendrogram at distance t[i].
i = 2
i_old = 0
sub_cl[i] = 3
t[i] = 0.21

clust[i] = hierarchy.fcluster(linkage, t=t[i], criterion='distance')
# Objects that belong to the cluster being split, and the ids they receive in
# the new cut.
bool_subclust = clust[i_old] == sub_cl[i]
clust_change = np.unique(clust[i][bool_subclust])

# Give every new sub-cluster a fresh id above the current maximum, then copy
# the renamed ids over only for the objects of the cluster being split.
clust_all_temp = clust_all.copy()
clmx = np.max(clust_all_temp)
clust_i_rename = clust_all_temp.copy()
# dict_clust_change[i] maps new id -> id in the fcluster cut it came from.
dict_clust_change = defaultdict(dict)
for c in clust_change:
    bool_c = clust[i] == c
    clmx += 1
    dict_clust_change[i][clmx] = c
    clust_i_rename[bool_c] = clmx
clust_all_temp[bool_subclust] = clust_i_rename[bool_subclust]


clust_unique = np.unique(clust_all_temp)
n_clust = len(clust_unique)

labels = list(dict_lab_slope.keys())
# Set lookup instead of scanning the label list for every table row.
label_set = set(labels)

dict_lab_clust = dict(zip(labels,clust_all_temp))
colors = plt.get_cmap('tab10')(np.linspace(0,1,n_clust))
dict_clust_col = dict(zip(np.unique(clust_all_temp), colors))

# Paint every clustered object with its cluster colour on the grayscale image.
im_rgb = im_norm[:,:,None] * np.array([1,1,1])[None,:]
im_rgb = np.dstack([im_rgb, np.ones(im_rgb.shape[:2])])
seg = segs[m]
prop = props[m]
for _, row in prop.iterrows():
    l = row.label
    if l not in label_set:
        continue
    b = row.bbox
    # NOTE(review): eval() runs arbitrary code; ast.literal_eval would be the
    # safe way to parse a bbox that arrives as a string.
    b = eval(b) if isinstance(b, str) else b
    rgb_sub = im_rgb[b[0]:b[2],b[1]:b[3]]
    seg_sub = seg[b[0]:b[2],b[1]:b[3]] == l
    cl = dict_lab_clust[l]
    color = np.array(dict_clust_col[cl])
    rgb_noncell = rgb_sub * np.dstack([~seg_sub]*4)
    rgb_cell = seg_sub[:,:,None] * color[None,:]
    im_rgb[b[0]:b[2],b[1]:b[3],:] = rgb_noncell + rgb_cell

fig, ax, cbar = ip.plot_image(im_rgb, im_inches=10)
plt.show()
plt.close()

# Plot the member spectra of every cluster in the cluster's colour.
spec_dims = (10,5)
for c in clust_unique:
    labels_sub = np.array(labels)[clust_all_temp == c]
    # if len(labels_sub) > 100:
    print("Cluster:",c) 
    # Only clusters created in this step have an entry; an explicit test
    # replaces the bare except, which also hid unrelated errors.
    if c in dict_clust_change[i]:
        print("Changed from new fcluster:",dict_clust_change[i][c]) 
    spec_sub = []
    for l in labels_sub:
        s = dict_lab_specmean[l]
        spec_sub.append(s)
    spec_sub = np.vstack(spec_sub)
    fig, ax = ip.general_plot(dims=spec_dims)
    color = np.array(dict_clust_col[c])
    fsi.plot_cell_spectra(ax, spec_sub, {'lw':0.5,'alpha':0.1,'color':color})
    ax.set_ylim(0,2**16)
    plt.plot()
    plt.show()


In [None]:
# Accept the split built in the previous cell: the working labels become the
# current clustering (same array object, no copy).
clust_all = clust_all_temp

In [None]:
# Refinement step: re-cut the linkage at a new distance threshold and use the
# finer labels to split one cluster of an earlier cut.
i = 3  # key under which this step's threshold and labels are stored
i_old = 1  # key of the earlier cut that defines the cluster being split
sub_cl[i] = 6  # label, within clust[i_old], of the cluster to split
t[i] = 0.14  # distance threshold handed to fcluster for this step

clust[i] = hierarchy.fcluster(linkage, t=t[i], criterion='distance')
# Mask of observations that carry label sub_cl[i] in the earlier cut
bool_subclust = clust[i_old] == sub_cl[i]
# Labels of the new cut that occur inside that mask
clust_change = np.unique(clust[i][bool_subclust])

# Build the split on a copy so clust_all itself is not modified in this cell
clust_all_temp = clust_all.copy()
clmx = np.max(clust_all_temp)
clust_i_rename = clust_all_temp.copy()
dict_clust_change = defaultdict(dict)
for c in clust_change:
    bool_c = clust[i] == c
    clmx += 1  # one past the largest label in use so far
    dict_clust_change[i][clmx] = c  # new label id -> fcluster label it came from
    clust_i_rename[bool_c] = clmx
# Only entries inside the mask take their renamed labels
clust_all_temp[bool_subclust] = clust_i_rename[bool_subclust]


# Color every labelled cell by its tentative cluster and show the overlay.
clust_unique = np.unique(clust_all_temp)
n_clust = len(clust_unique)

labels = list(dict_lab_slope.keys())

# Pairs labels with cluster ids by position.
# assumes clust_all_temp is ordered like dict_lab_slope's keys — TODO confirm
dict_lab_clust = dict(zip(labels,clust_all_temp))
# Calling the colormap on n_clust evenly spaced positions in [0, 1] yields one
# RGBA row per cluster.
colors = plt.get_cmap('tab10')(np.linspace(0,1,n_clust))
dict_clust_col = dict(zip(np.unique(clust_all_temp), colors))

# Grayscale image replicated over three channels, plus a fourth channel of
# ones, so it has the same channel count as the RGBA cluster colors.
im_rgb = im_norm[:,:,None] * np.array([1,1,1])[None,:]
im_rgb = np.dstack([im_rgb, np.ones(im_rgb.shape[:2])])
seg = segs[m]
prop = props[m]
for _, row in prop.iterrows():
    l = row.label
    if l in labels:
        b = row.bbox
        # NOTE(review): eval() executes arbitrary text; acceptable only if the
        # props table is trusted (bbox is a string when read back from file).
        b = eval(b) if isinstance(b, str) else b
        rgb_sub = im_rgb[b[0]:b[2],b[1]:b[3]]
        seg_sub = seg[b[0]:b[2],b[1]:b[3]] == l  # pixels of this cell inside its bbox
        im_sub = im_norm[b[0]:b[2],b[1]:b[3]]  # only used by the commented-out variant below
        cl = dict_lab_clust[l]
        color = np.array(dict_clust_col[cl])
        # Keep the existing pixels outside the cell mask ...
        rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*4)
        # rgb_cell = (im_sub * seg_sub)[:,:,None] * color[None,:]
        # ... and fill the cell mask with the flat cluster color.
        rgb_cell = seg_sub[:,:,None] * color[None,:]
        im_rgb[b[0]:b[2],b[1]:b[3],:] = rgb_noncell + rgb_cell

fig, ax, cbar = ip.plot_image(im_rgb, im_inches=10)
plt.show()
plt.close()

# Plot the mean spectra of every cell in each tentative cluster, one figure
# per cluster, drawn in that cluster's overlay color.
spec_dims = (10,5)
for c in clust_unique:
    labels_sub = np.array(labels)[clust_all_temp == c]
    print("Cluster:",c)
    # Only clusters created in this step have an entry. An explicit membership
    # test replaces the former bare `except: pass`, which also silenced
    # unrelated errors (e.g. an undefined name).
    if c in dict_clust_change[i]:
        print("Changed from new fcluster:",dict_clust_change[i][c])
    spec_sub = []
    for l in labels_sub:
        s = dict_lab_specmean[l]
        spec_sub.append(s)
    spec_sub = np.vstack(spec_sub)
    fig, ax = ip.general_plot(dims=spec_dims)
    color = np.array(dict_clust_col[c])
    fsi.plot_cell_spectra(ax, spec_sub, {'lw':0.5,'alpha':0.25,'color':color})
    ax.set_ylim(0,2**16)
    # The argument-less plt.plot() that used to precede show() drew nothing.
    plt.show()


In [None]:
# No sub-clusters are selected here: clust_change is empty, so the loop body
# never runs and clust_i_rename stays an unmodified copy of clust_all.
clust_change = []

clmx = np.max(clust_all)
clust_i_rename = clust_all.copy()
for c in clust_change:
    bool_c = clust[i] == c
    clmx += 1
    clust_i_rename[bool_c] = clmx

np.unique(clust_i_rename[bool_subclust])

# With an unmodified copy this writes back the values clust_all already holds
clust_all[bool_subclust] = clust_i_rename[bool_subclust]

### Remove clusters

In [None]:
# Cluster ids to exclude from the final result
ignore = [9]
clust_trim = clust_all.copy()
for c in ignore:
    bool_c = clust_all == c
    clust_trim[bool_c] = 0  # 0 marks "ignored"; the following cells skip falsy ids

In [None]:
# Map each label to its trimmed cluster id (0 = ignored), paired by position
dict_lab_clust = dict(zip(labels,clust_trim))
# dict_clust_col is not rebuilt here, so the colors assigned earlier are
# reused; the alternative palette below is disabled.
# colors = plt.get_cmap('tab10').colors
# dict_clust_col = dict(zip(np.unique(clust_trim)[1:], colors))

In [None]:
# Paint every kept cell onto a fresh grayscale copy of the normalized image.
# Three identical channels plus a fourth channel of ones, matching the
# four-fold mask stacking below.
im_rgb = im_norm[:,:,None] * np.array([1,1,1])[None,:]
im_rgb = np.dstack([im_rgb, np.ones(im_rgb.shape[:2])])

# The row index is unused; `_` (as in the earlier overlay cells) keeps this
# loop from overwriting the notebook-level step index `i`.
for _, row in prop.iterrows():
    l = row.label
    if l in labels:
        b = row.bbox
        cl = dict_lab_clust[l]
        if cl:  # cluster id 0 marks ignored cells; they stay grayscale
            color = np.array(dict_clust_col[cl])
            # NOTE(review): eval() executes arbitrary text; acceptable only if
            # the props table is trusted (bbox is a string when read from file).
            b = eval(b) if isinstance(b, str) else b
            rgb_sub = im_rgb[b[0]:b[2],b[1]:b[3]]
            seg_sub = seg[b[0]:b[2],b[1]:b[3]] == l  # pixels of this cell inside its bbox
            # Keep existing pixels outside the cell, fill the cell with its color
            rgb_noncell = rgb_sub * np.dstack([~seg_sub]*4)
            rgb_cell = seg_sub[:,:,None] * color[None,:]
            im_rgb[b[0]:b[2],b[1]:b[3],:] = rgb_noncell + rgb_cell

In [None]:
ip.plot_image(im_rgb, im_inches=10)

In [None]:
# One spectra figure per kept cluster; id 0 (ignored cells) is skipped.
for c in np.unique(clust_trim):
    if c:
        print('Cluster:', c)
        bool_c = clust_trim == c
        labels_sub = np.array(labels)[bool_c]
        # Stack the per-cell entries of dict_lab_specmean row-wise
        spec_sub = []
        for l in labels_sub:
            s = dict_lab_specmean[l]
            spec_sub.append(s)
        spec_sub = np.vstack(spec_sub)
        fig, ax = ip.general_plot(dims=spec_dims)
        color = dict_clust_col[c]
        fsi.plot_cell_spectra(ax, spec_sub, {'lw':0.5,'alpha':1,'color':color})
        ax.set_ylim(0,2**16)  # fixed y-range so figures are comparable
        plt.plot()  # called without arguments
        plt.show()

## Save

In [None]:
# Per-sample output folder: <config output_dir>/<sample name>.
output_dir = '/'.join((config['output_dir'], sn))
# output_dir = output_dir_fmt.format(sample_name=sn)
output_dir_missing = not os.path.exists(output_dir)
if output_dir_missing:
    # Create it (with parents) and report, only when it was not there yet.
    os.makedirs(output_dir)
    print("made dir:",output_dir)


In [None]:
# Only the clustering results get a sub-folder in this notebook; the
# stacks / props / segs / spectra folders are not created here.
clust_dir = '/'.join((output_dir, 'clust'))

for dir_ in (clust_dir,):
    if os.path.exists(dir_):
        continue
    os.makedirs(dir_)
    print("made dir:",dir_)


In [None]:
# Write the label -> cluster assignment for image `m` of sample `sn` to YAML.
bn = sn + '_M_' + str(m)
# stack_fn = stack_dir + '/' + bn + '_stack.npy'
# props_fn = props_dir + '/' + bn + '_props.csv'
# seg_fn = segs_dir + '/' + bn + '_seg.npy'
# spec_fn = spec_dir + '/' + bn + '_spec.yaml'
clust_fn = clust_dir + '/' + bn + '_clust.yaml'

# np.save(stack_fn, stacks[m])
# print('Wrote:',stack_fn)
# props[m].to_csv(props_fn, index=False)
# print('Wrote:',props_fn)
# np.save(seg_fn, segs[m])
# print('Wrote:',seg_fn)
# with open(spec_fn, 'w') as f:
#     yaml.dump(specs[m], f)
# print('Wrote:',spec_fn)


def _to_builtin(x):
    """Return the built-in Python scalar for a NumPy scalar; pass anything else through."""
    return x.item() if isinstance(x, np.generic) else x


# The cluster ids come from NumPy arrays. yaml.dump serialises NumPy scalars
# as python-specific object tags, which yaml.safe_load (used for the config in
# this notebook) refuses to read back, so convert to plain scalars first.
dict_lab_clust_out = {_to_builtin(k): _to_builtin(v) for k, v in dict_lab_clust.items()}
with open(clust_fn, 'w') as f:
    yaml.dump(dict_lab_clust_out, f)
print('Wrote:',clust_fn)


## Failed: HDBSCAN clustering

In [None]:
# HDBSCAN attempt 1: cluster from dist_mat, passed with metric='precomputed'
# (i.e. treated as pairwise distances rather than as feature vectors).
import hdbscan
hdb = hdbscan.HDBSCAN(min_cluster_size=5, metric='precomputed')

hdb.fit(dist_mat)

clust_hdb = hdb.labels_

# Scatter of the 2-D points in `u`, colored by HDBSCAN label
# (`u` is presumably the UMAP embedding — TODO confirm)
plt.scatter(u[:,0], u[:,1], c=clust_hdb, alpha=0.25)
plt.show()
plt.close()

In [None]:
# HDBSCAN attempt 2: cluster the rows of `u` directly with the default metric.
# This overwrites clust_hdb from the previous attempt.
import hdbscan
hdb = hdbscan.HDBSCAN(min_cluster_size=10)

hdb.fit(u)

clust_hdb = hdb.labels_



In [None]:
# Split the points by label: -1 versus everything else
# (-1 is presumably HDBSCAN's noise label — confirm against the hdbscan docs)
u_ = u[clust_hdb > -1,:]
u_i = u[clust_hdb == -1,:]


# Label -1 points in gray underneath, the remaining points colored by label
plt.scatter(u_i[:,0], u_i[:,1], color=[0.5,0.5,0.5], alpha=0.25)
plt.scatter(u_[:,0], u_[:,1], c=clust_hdb[clust_hdb > -1], alpha=0.25, cmap='Spectral')
plt.show()
plt.close()

In [None]:
np.unique(clust_hdb).shape

In [None]:
# Use the HDBSCAN labels as the current clustering (same array, not a copy)
clust_all = clust_hdb
dict_lab_clust = dict(zip(labels,clust_all))
colors = plt.get_cmap('tab20').colors
# Sorted unique labels are paired with palette entries in order; -1, if
# present, sorts first and gets the first color.
# NOTE(review): zip stops at the shorter input, so labels beyond the palette
# length get no color and would raise KeyError on lookup.
dict_clust_col = dict(zip(np.unique(clust_all), colors))

In [None]:
# Paint each labelled cell with its cluster color onto the existing im_rgb.
# NOTE(review): masks are stacked 3-fold here, which only broadcasts if im_rgb
# currently has 3 channels and dict_clust_col holds 3-component colors; other
# cells in this notebook build a 4-channel im_rgb — confirm which one is live.
# NOTE(review): the loop variable `i` overwrites the notebook-level index `i`.
seg = segs[m]
prop = props[m]
for i, row in prop.iterrows():
    l = row.label
    if l in labels:
        b = row.bbox
        # NOTE(review): eval() executes arbitrary text; only safe on trusted props
        b = eval(b) if isinstance(b, str) else b
        rgb_sub = im_rgb[b[0]:b[2],b[1]:b[3]]
        seg_sub = seg[b[0]:b[2],b[1]:b[3]] == l  # pixels of this cell inside its bbox
        im_sub = im_norm[b[0]:b[2],b[1]:b[3]]  # only used by the commented-out variant below
        cl = dict_lab_clust[l]
        color = np.array(dict_clust_col[cl])
        rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*3)
        # rgb_cell = (im_sub * seg_sub)[:,:,None] * color[None,:]
        rgb_cell = seg_sub[:,:,None] * color[None,:]
        im_rgb[b[0]:b[2],b[1]:b[3],:] = rgb_noncell + rgb_cell

In [None]:
fig, ax, cbar = ip.plot_image(im_rgb, im_inches=10)

In [None]:
# One spectra figure per cluster id in clust_all, in that cluster's color.
for c in np.unique(clust_all):
    print('Cluster:', c)
    bool_c = clust_all == c
    labels_sub = np.array(labels)[bool_c]
    # Stack the per-cell entries of dict_lab_specmean row-wise
    spec_sub = []
    for l in labels_sub:
        s = dict_lab_specmean[l]
        spec_sub.append(s)
    spec_sub = np.vstack(spec_sub)
    fig, ax = ip.general_plot(dims=spec_dims)
    color = dict_clust_col[c]
    fsi.plot_cell_spectra(ax, spec_sub, {'lw':0.5,'alpha':1,'color':color})
    ax.set_ylim(0,2**16)  # fixed y-range so figures are comparable
    plt.plot()  # called without arguments
    plt.show()

## Moved to classify_clusters_02

In [None]:
# Collect, per cluster id, the stacked dict_lab_specmean entries of its cells
# (one row per cell).
dict_cl_spec = {}
for c in np.unique(clust_all):
    # print('Cluster:', c)
    bool_c = clust_all == c
    labels_sub = np.array(labels)[bool_c]
    spec_sub = []
    for l in labels_sub:
        s = dict_lab_specmean[l]
        spec_sub.append(s)
    spec_sub = np.vstack(spec_sub)
    dict_cl_spec[c] = spec_sub

In [None]:
# Mean over the rows (cells) of each cluster's stacked spectra
dict_cl_specmean = {cl: np.mean(s, axis=0) for cl, s in dict_cl_spec.items()}
# Each mean spectrum divided by its own total, so every entry sums to 1
dict_cl_specnorm = {cl: s / np.sum(s) for cl, s in dict_cl_specmean.items()}

In [None]:
# Pairwise differences between the sum-normalized cluster spectra:
# diff_mat[i, j, :] = spectrum_i - spectrum_j, in dict iteration order.
n_clust = len(dict_cl_specnorm)
# Take the channel count from the data itself. The previous lookup,
# dict_cl_spec[c], depended on `c` still holding a cluster id left over from
# a loop in a different cell.
n_chan = next(iter(dict_cl_spec.values())).shape[1]
diff_mat = np.zeros((n_clust,n_clust,n_chan))
# NOTE(review): these loops rebind the notebook-level names `i` and `j`.
for i, s_i in enumerate(dict_cl_specnorm.values()):
    for j, s_j in enumerate(dict_cl_specnorm.values()):
        diff_mat[i,j,:] = s_i - s_j

In [None]:
def sum_normalize_ref(ref_spec):
    """Return one baseline-subtracted, sum-normalized mean spectrum per reference.

    Each element of ``ref_spec`` is handled as a 2-D array: the row minimum is
    subtracted from every row, each row is divided by its own sum, and the
    rows are then averaged. The result is a list with one 1-D array per
    element of ``ref_spec``, in the same order.
    """
    mean_spectra = []
    for spectra in ref_spec:
        row_min = np.min(spectra, axis=1)
        shifted = spectra - row_min[:,None]
        row_total = np.sum(shifted, axis=1)
        normalized = shifted / row_total[:,None]
        mean_spectra.append(np.mean(normalized, axis=0))
    return mean_spectra

def get_reference_spectra(barcodes, bc_len):
    """Load the reference spectra table for every barcode.

    Each barcode is zero-padded to 7 characters and widened to 10 characters
    by inserting ``'0'`` after the first character and ``'00'`` after the
    fourth. That string is read as a base-2 number, which selects the file
    ``08_18_2018_enc_<number>_avgint.csv`` in the hard-coded reference
    directory (prefixed with the module-level ``cluster`` string).

    Args:
        barcodes: iterable of barcode values (anything ``str()`` turns into
            binary digits).
        bc_len: accepted for interface compatibility; not used.

    Returns:
        A list with one 2-D array per barcode, holding columns 32 to 94
        (63 columns) of the header-less CSV file.
    """
    ref_dir = '/fs/cbsuvlaminck2/workdir/bmg224/manuscripts/mgefish/data/unused/fig_5/HiPRFISH_reference_spectra'
    fmt = '08_18_2018_enc_{}_avgint.csv'

    def _enc_number(bc):
        # 7-character code -> 10-character code -> integer (base 2)
        bits7 = str(bc).zfill(7)
        bits10 = bits7[0] + '0' + bits7[1:4] + '00' + bits7[4:]
        return int(bits10, 2)

    # All file numbers are resolved before the first file is opened.
    enc_numbers = [_enc_number(bc) for bc in barcodes]

    first_col = 32
    ref_avgint_cols = list(range(first_col, first_col + 63))

    ref_spec = []
    for number in enc_numbers:
        fn = cluster + '/' + ref_dir + '/'+ fmt.format(number)
        table = pd.read_csv(fn, header=None)
        ref_spec.append(table[ref_avgint_cols].values)
    return ref_spec

# Get reference spectra for every barcode in the probe design table
probe_design_dir = '/fs/cbsuvlaminck2/workdir/bmg224/manuscripts/mgefish/data/HiPRFISH_probe_design'
probe_design_fn = probe_design_dir + '/welch2016_7b_distant.csv'
probe_design = pd.read_csv(probe_design_fn)
barcodes = probe_design['code'].unique()
# Digit count of the largest barcode; passed as bc_len, which
# get_reference_spectra accepts but does not use.
barcode_length = len(str(np.max(barcodes)))
ref_spec = get_reference_spectra(barcodes, barcode_length)
# First distinct sci_name listed for each barcode, in the order of `barcodes`
sci_names = [probe_design.loc[probe_design['code'] == bc,'sci_name'].unique()[0] 
            for bc in barcodes]
# One sum-normalized mean spectrum per barcode
weights_sum_norm = sum_normalize_ref(ref_spec)

In [None]:
# OPTICS with default parameters on the 2-D points in `u`
from sklearn.cluster import OPTICS
opt = OPTICS()

opt.fit(u)

clust_opt = opt.labels_

# Scatter of `u` colored by OPTICS label
plt.scatter(u[:,0], u[:,1], c=clust_opt, alpha=0.25)
plt.show()
plt.close()

## Manual cluster selection

In [None]:
t = {}  # step index -> distance threshold used for that cut
clust = {}  # step index -> flat cluster labels returned by fcluster for that cut


In [None]:
i = 0

In [None]:
# pick a threshold for clustering
t[i] = 0.042

In [None]:
# Cut the linkage at distance t[i]; how many flat clusters result depends on
# the threshold, it is not fixed in advance.
clust[i] = hierarchy.fcluster(linkage, t=t[i], criterion='distance')
clust_unique = np.unique(clust[i])
clust_unique

In [None]:
# Truncated dendrogram with the current threshold drawn as a horizontal line
fig, ax = ip.general_plot(dims=(15,15))
dn = hierarchy.dendrogram(linkage, labels=labels, truncate_mode='lastp')
ax.axhline(t[i])
# Fine y-grid (step 0.005) between the rounded axis limits, for reading off thresholds
ylims = [round(l,2) for l in ax.get_ylim()]
ax.set_yticks(np.arange(ylims[0],ylims[1],0.005))
ax.grid(axis='y')
plt.show()
plt.close()

In [None]:
# Plot spectra from each cluster of the current cut, one figure per cluster.
#
# Disabled earlier variant, kept for reference (it indexed cells through
# `inds` / `dict_ind_lab` and prepared a random subset of size `choice`):
# choice = 200

# spec_dims = (10,5)
# for c in clust_unique:
#     print("Cluster:",c)
#     inds_sub = np.array(inds)[clust[i] == c]
#     spec_sub = []
#     inds_sub_choice = np.random.choice(inds_sub, size=choice, replace=False)
#     for ind in inds_sub:
#         m,l = dict_ind_lab[ind]
#         s = specs[m][l]
#         spec_sub.append(s)
#     spec_sub = np.vstack(spec_sub)
#     fig, ax = ip.general_plot(dims=spec_dims)
#     fsi.plot_cell_spectra(ax, spec_sub, {'lw':0.5,'alpha':0.1,'color':'r'})
#     ax.set_ylim(0,2**16)
#     plt.plot()
#     plt.show()

spec_dims = (10,5)
for c in clust_unique:
    print("Cluster:",c)
    # Labels whose flat-cluster id in this cut equals c
    labels_sub = np.array(labels)[clust[i] == c]
    spec_sub = []
    for l in labels_sub:
        s = dict_lab_specmean[l]
        spec_sub.append(s)
    spec_sub = np.vstack(spec_sub)
    fig, ax = ip.general_plot(dims=spec_dims)
    fsi.plot_cell_spectra(ax, spec_sub, {'lw':0.5,'alpha':0.1,'color':'r'})
    ax.set_ylim(0,2**16)  # fixed y-range so figures are comparable
    plt.plot()  # called without arguments
    plt.show()

In [None]:
# Start the running clustering from the first cut. Take a copy: later cells
# assign into clust_all in place, and without the copy those writes would
# also change clust[i], which later steps read through clust[i_old].
clust_all = clust[i].copy()
np.unique(clust_all)

In [None]:
# Scatter of the 2-D points in `u`, colored by current cluster id. The
# disabled lines below show `u` being produced by UMAP from dist_mat.
# fit = umap.UMAP(metric='precomputed')
# u = fit.fit_transform(dist_mat)
plt.scatter(u[:,0], u[:,1], c=clust_all, alpha=0.25, cmap='Spectral')
plt.show()
plt.close()

In [None]:
# Label -> cluster id, paired by position
dict_lab_clust = dict(zip(labels,clust_all))
colors = plt.get_cmap('tab10').colors
# Sorted unique cluster ids are paired with palette entries in order.
# NOTE(review): zip stops at the shorter input, so with more clusters than
# palette entries the extra ids get no color (KeyError on lookup).
dict_clust_col = dict(zip(np.unique(clust_all), colors))

In [None]:
# Paint each labelled cell with its cluster color onto the existing im_rgb
# (3-channel here: masks are stacked 3-fold and the palette colors are RGB).
# NOTE(review): the loop variable `i` overwrites the notebook-level step
# index `i`; it must be set again before the next refinement step.
seg = segs[m]
prop = props[m]
for i, row in prop.iterrows():
    l = row.label
    if l in labels:
        b = row.bbox
        # NOTE(review): eval() executes arbitrary text; only safe on trusted props
        b = eval(b) if isinstance(b, str) else b
        rgb_sub = im_rgb[b[0]:b[2],b[1]:b[3]]
        seg_sub = seg[b[0]:b[2],b[1]:b[3]] == l  # pixels of this cell inside its bbox
        im_sub = im_norm[b[0]:b[2],b[1]:b[3]]  # only used by the commented-out variant below
        cl = dict_lab_clust[l]
        color = np.array(dict_clust_col[cl])
        # Keep the existing pixels outside the cell mask ...
        rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*3)
        # rgb_cell = (im_sub * seg_sub)[:,:,None] * color[None,:]
        # ... and fill the cell mask with the flat cluster color.
        rgb_cell = seg_sub[:,:,None] * color[None,:]
        im_rgb[b[0]:b[2],b[1]:b[3],:] = rgb_noncell + rgb_cell

In [None]:
fig, ax, cbar = ip.plot_image(im_rgb, im_inches=10)

In [None]:
sub_cl = {}  # step index -> label (within clust[i_old]) of the cluster split in that step

In [None]:
i = 1  # key for this refinement step
i_old = 0  # key of the earlier cut whose cluster is being split
sub_cl[i] = 1  # label, within clust[i_old], of the cluster to split

In [None]:
# pick a threshold for clustering
t[i] = 0.04


In [None]:
# Re-cut the linkage at distance t[i], then list the new-cut labels that
# occur inside the cluster being split (label sub_cl[i] of clust[i_old]).
clust[i] = hierarchy.fcluster(linkage, t=t[i], criterion='distance')
bool_subclust = clust[i_old] == sub_cl[i]
clust_unique = np.unique(clust[i][bool_subclust])
clust_unique

In [None]:
fig, ax = ip.general_plot(dims=(15,15))
dn = hierarchy.dendrogram(linkage, labels=labels, truncate_mode='lastp')
ax.axhline(t[i])
ylims = [round(l,2) for l in ax.get_ylim()]
ax.set_yticks(np.arange(ylims[0],ylims[1],0.005))
ax.grid(axis='y')
plt.show()
plt.close()

In [None]:
# Plot the spectra of the sub-clusters found inside the cluster being split,
# one figure per sub-cluster.
spec_dims = (10,5)
for c in clust_unique:
    labels_sub = np.array(labels)[clust[i] == c]
    # Only sub-clusters with more than 100 cells are shown. The body of this
    # `if` was previously not indented, which is a syntax error; it is now
    # nested the same way as the corresponding cells of the later steps.
    if len(labels_sub) > 100:
        print("Cluster:",c)
        spec_sub = []
        for l in labels_sub:
            s = dict_lab_specmean[l]
            spec_sub.append(s)
        spec_sub = np.vstack(spec_sub)
        fig, ax = ip.general_plot(dims=spec_dims)
        fsi.plot_cell_spectra(ax, spec_sub, {'lw':0.5,'alpha':0.1,'color':'r'})
        ax.set_ylim(0,2**16)
        plt.plot()
        plt.show()

In [None]:
# New-cut labels (ids within clust[i]) chosen to become separate clusters
clust_change = [1,2]

clmx = np.max(clust_all)
# Work on a copy of the running clustering; clust_all is updated in the next cell
clust_i_rename = clust_all.copy()
for c in clust_change:
    bool_c = clust[i] == c
    clmx += 1  # one past the largest label in use so far
    clust_i_rename[bool_c] = clmx

# Preview: labels the split cluster's members would carry after renaming
np.unique(clust_i_rename[bool_subclust])

In [None]:
# Apply the renamed labels in place, but only to members of the split cluster
clust_all[bool_subclust] = clust_i_rename[bool_subclust]
np.unique(clust_all)


In [None]:
# fit = umap.UMAP(metric='precomputed')
# u = fit.fit_transform(dist_mat)
plt.scatter(u[:,0], u[:,1], c=clust_all, alpha=0.25, cmap='tab10')
plt.show()
plt.close()

In [None]:
dict_lab_clust = dict(zip(labels,clust_all))
colors = plt.get_cmap('tab10').colors
dict_clust_col = dict(zip(np.unique(clust_all), colors))

In [None]:
seg = segs[m]
prop = props[m]
for i, row in prop.iterrows():
    l = row.label
    if l in labels:
        b = row.bbox
        b = eval(b) if isinstance(b, str) else b
        rgb_sub = im_rgb[b[0]:b[2],b[1]:b[3]]
        seg_sub = seg[b[0]:b[2],b[1]:b[3]] == l
        im_sub = im_norm[b[0]:b[2],b[1]:b[3]]
        cl = dict_lab_clust[l]
        color = np.array(dict_clust_col[cl])
        rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*3)
        # rgb_cell = (im_sub * seg_sub)[:,:,None] * color[None,:]
        rgb_cell = seg_sub[:,:,None] * color[None,:]
        im_rgb[b[0]:b[2],b[1]:b[3],:] = rgb_noncell + rgb_cell

In [None]:
fig, ax, cbar = ip.plot_image(im_rgb, im_inches=10)

In [None]:
for c in np.unique(clust_all):
    print('Cluster:', c)
    bool_c = clust_all == c
    labels_sub = np.array(labels)[bool_c]
    spec_sub = []
    for l in labels_sub:
        s = dict_lab_specmean[l]
        spec_sub.append(s)
    spec_sub = np.vstack(spec_sub)
    fig, ax = ip.general_plot(dims=spec_dims)
    color = dict_clust_col[c]
    fsi.plot_cell_spectra(ax, spec_sub, {'lw':0.5,'alpha':1,'color':color})
    ax.set_ylim(0,2**16)
    plt.plot()
    plt.show()

In [None]:
i = 2
i_old = 1
sub_cl[i] = 2

In [None]:
# pick a threshold for clustering
t[i] = 0.0225

In [None]:
fig, ax = ip.general_plot(dims=(15,15))
dn = hierarchy.dendrogram(linkage, labels=labels, truncate_mode='lastp')
ax.axhline(t[i])
ylims = [round(l,2) for l in ax.get_ylim()]
ax.set_yticks(np.arange(ylims[0],ylims[1],0.005))
ax.grid(axis='y')
plt.show()
plt.close()

In [None]:
# Re-cut the linkage at distance t[i], then list the new-cut labels that
# occur inside the cluster being split (label sub_cl[i] of clust[i_old]).
clust[i] = hierarchy.fcluster(linkage, t=t[i], criterion='distance')
bool_subclust = clust[i_old] == sub_cl[i]
clust_unique = np.unique(clust[i][bool_subclust])
clust_unique

In [None]:
spec_dims = (10,5)
for c in clust_unique:
    labels_sub = np.array(labels)[clust[i] == c]
    # if len(labels_sub) > 10:
    print("Cluster:",c) 
    spec_sub = []
    for l in labels_sub:
        s = dict_lab_specmean[l]
        spec_sub.append(s)
    spec_sub = np.vstack(spec_sub)
    fig, ax = ip.general_plot(dims=spec_dims)
    fsi.plot_cell_spectra(ax, spec_sub, {'lw':0.5,'alpha':0.1,'color':'r'})
    ax.set_ylim(0,2**16)
    plt.plot()
    plt.show()

In [None]:
clust_change = [58,59]

clmx = np.max(clust_all)
clust_i_rename = clust_all.copy()
for c in clust_change:
    bool_c = clust[i] == c
    clmx += 1
    clust_i_rename[bool_c] = clmx

np.unique(clust_i_rename[bool_subclust])

In [None]:
clust_all[bool_subclust] = clust_i_rename[bool_subclust]
np.unique(clust_all)


In [None]:
# fit = umap.UMAP(metric='precomputed')
# u = fit.fit_transform(dist_mat)
plt.scatter(u[:,0], u[:,1], c=clust_all, alpha=0.25, cmap='tab20')
plt.show()
plt.close()

In [None]:
dict_lab_clust = dict(zip(labels,clust_all))
colors = plt.get_cmap('tab10').colors
dict_clust_col = dict(zip(np.unique(clust_all), colors))

In [None]:
seg = segs[m]
prop = props[m]
for i, row in prop.iterrows():
    l = row.label
    if l in labels:
        b = row.bbox
        b = eval(b) if isinstance(b, str) else b
        rgb_sub = im_rgb[b[0]:b[2],b[1]:b[3]]
        seg_sub = seg[b[0]:b[2],b[1]:b[3]] == l
        im_sub = im_norm[b[0]:b[2],b[1]:b[3]]
        cl = dict_lab_clust[l]
        color = np.array(dict_clust_col[cl])
        rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*3)
        # rgb_cell = (im_sub * seg_sub)[:,:,None] * color[None,:]
        rgb_cell = seg_sub[:,:,None] * color[None,:]
        im_rgb[b[0]:b[2],b[1]:b[3],:] = rgb_noncell + rgb_cell

In [None]:
fig, ax, cbar = ip.plot_image(im_rgb, im_inches=10)

In [None]:
for c in np.unique(clust_all):
    print('Cluster:', c)
    bool_c = clust_all == c
    labels_sub = np.array(labels)[bool_c]
    spec_sub = []
    for l in labels_sub:
        s = dict_lab_specmean[l]
        spec_sub.append(s)
    spec_sub = np.vstack(spec_sub)
    fig, ax = ip.general_plot(dims=spec_dims)
    color = dict_clust_col[c]
    fsi.plot_cell_spectra(ax, spec_sub, {'lw':0.5,'alpha':1,'color':color})
    ax.set_ylim(0,2**16)
    plt.plot()
    plt.show()

In [None]:
i = 3
i_old = 1
sub_cl[i] = 1

In [None]:
# pick a threshold for clustering
t[i] = 0.02525

In [None]:
# Re-cut the linkage at distance t[i], then list the new-cut labels that
# occur inside the cluster being split (label sub_cl[i] of clust[i_old]).
clust[i] = hierarchy.fcluster(linkage, t=t[i], criterion='distance')
bool_subclust = clust[i_old] == sub_cl[i]
clust_unique = np.unique(clust[i][bool_subclust])
clust_unique

In [None]:
fig, ax = ip.general_plot(dims=(15,15))
dn = hierarchy.dendrogram(linkage, labels=labels, truncate_mode='lastp')
ax.axhline(t[i])
ylims = [round(l,2) for l in ax.get_ylim()]
ax.set_yticks(np.arange(ylims[0],ylims[1],0.005))
ax.grid(axis='y')
plt.show()
plt.close()

In [None]:
spec_dims = (10,5)
for c in clust_unique:
    labels_sub = np.array(labels)[clust[i] == c]
    if len(labels_sub) > 10:
        print("Cluster:",c) 
        spec_sub = []
        for l in labels_sub:
            s = dict_lab_specmean[l]
            spec_sub.append(s)
        spec_sub = np.vstack(spec_sub)
        fig, ax = ip.general_plot(dims=spec_dims)
        fsi.plot_cell_spectra(ax, spec_sub, {'lw':0.5,'alpha':0.1,'color':'r'})
        ax.set_ylim(0,2**16)
        plt.plot()
        plt.show()

In [None]:
clust_change = [4,7,8,29,30]

clmx = np.max(clust_all)
clust_i_rename = clust_all.copy()
for c in clust_change:
    bool_c = clust[i] == c
    clmx += 1
    clust_i_rename[bool_c] = clmx

np.unique(clust_i_rename[bool_subclust])

In [None]:
clust_all[bool_subclust] = clust_i_rename[bool_subclust]
np.unique(clust_all)


In [None]:
# fit = umap.UMAP(metric='precomputed')
# u = fit.fit_transform(dist_mat)
plt.scatter(u[:,0], u[:,1], c=clust_all, alpha=0.25, cmap='Spectral')
plt.show()
plt.close()

In [None]:
dict_lab_clust = dict(zip(labels,clust_all))
colors = plt.get_cmap('tab20').colors
dict_clust_col = dict(zip(np.unique(clust_all), colors))

In [None]:
seg = segs[m]
prop = props[m]
for i, row in prop.iterrows():
    l = row.label
    if l in labels:
        b = row.bbox
        b = eval(b) if isinstance(b, str) else b
        rgb_sub = im_rgb[b[0]:b[2],b[1]:b[3]]
        seg_sub = seg[b[0]:b[2],b[1]:b[3]] == l
        im_sub = im_norm[b[0]:b[2],b[1]:b[3]]
        cl = dict_lab_clust[l]
        color = np.array(dict_clust_col[cl])
        rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*3)
        # rgb_cell = (im_sub * seg_sub)[:,:,None] * color[None,:]
        rgb_cell = seg_sub[:,:,None] * color[None,:]
        im_rgb[b[0]:b[2],b[1]:b[3],:] = rgb_noncell + rgb_cell

In [None]:
fig, ax, cbar = ip.plot_image(im_rgb, im_inches=10)

In [None]:
for c in np.unique(clust_all):
    print('Cluster:', c)
    bool_c = clust_all == c
    labels_sub = np.array(labels)[bool_c]
    spec_sub = []
    for l in labels_sub:
        s = dict_lab_specmean[l]
        spec_sub.append(s)
    spec_sub = np.vstack(spec_sub)
    fig, ax = ip.general_plot(dims=spec_dims)
    color = dict_clust_col[c]
    fsi.plot_cell_spectra(ax, spec_sub, {'lw':0.5,'alpha':1,'color':color})
    ax.set_ylim(0,2**16)
    plt.plot()
    plt.show()

In [None]:
i = 4
i_old = 3
sub_cl[i] = 30

In [None]:
# pick a threshold for clustering
t[i] = 0.02125

In [None]:
# Re-cut the linkage at distance t[i], then list the new-cut labels that
# occur inside the cluster being split (label sub_cl[i] of clust[i_old]).
clust[i] = hierarchy.fcluster(linkage, t=t[i], criterion='distance')
bool_subclust = clust[i_old] == sub_cl[i]
clust_unique = np.unique(clust[i][bool_subclust])
clust_unique

In [None]:
fig, ax = ip.general_plot(dims=(15,15))
dn = hierarchy.dendrogram(linkage, labels=labels, truncate_mode='lastp')
ax.axhline(t[i])
ylims = [round(l,2) for l in ax.get_ylim()]
ax.set_yticks(np.arange(ylims[0],ylims[1],0.005))
ax.grid(axis='y')
plt.show()
plt.close()

In [None]:
spec_dims = (10,5)
for c in clust_unique:
    labels_sub = np.array(labels)[clust[i] == c]
    if len(labels_sub) > 10:
        print("Cluster:",c) 
        spec_sub = []
        for l in labels_sub:
            s = dict_lab_specmean[l]
            spec_sub.append(s)
        spec_sub = np.vstack(spec_sub)
        fig, ax = ip.general_plot(dims=spec_dims)
        fsi.plot_cell_spectra(ax, spec_sub, {'lw':0.5,'alpha':0.1,'color':'r'})
        ax.set_ylim(0,2**16)
        plt.plot()
        plt.show()

In [None]:
# New-cut labels (ids within clust[i]) chosen to become separate clusters
clust_change = [59]

clmx = np.max(clust_all)
# Start from the running clustering, as the earlier refinement steps do.
# Copying clust[i_old] instead would hand the raw fcluster id sub_cl[i] to
# every member of the split cluster that is not renamed below, and that raw
# id can coincide with an unrelated id already used in clust_all.
clust_i_rename = clust_all.copy()
for c in clust_change:
    bool_c = clust[i] == c
    clmx += 1  # one past the largest label in use so far
    clust_i_rename[bool_c] = clmx

# Preview: labels the split cluster's members would carry after renaming
np.unique(clust_i_rename[bool_subclust])

In [None]:
clust_all[bool_subclust] = clust_i_rename[bool_subclust]
np.unique(clust_all)


In [None]:
# fit = umap.UMAP(metric='precomputed')
# u = fit.fit_transform(dist_mat)
plt.scatter(u[:,0], u[:,1], c=clust_all, alpha=0.25, cmap='tab20')
plt.show()
plt.close()

In [None]:
dict_lab_clust = dict(zip(labels,clust_all))
colors = plt.get_cmap('tab20').colors
dict_clust_col = dict(zip(np.unique(clust_all), colors))

In [None]:
seg = segs[m]
prop = props[m]
for i, row in prop.iterrows():
    l = row.label
    if l in labels:
        b = row.bbox
        b = eval(b) if isinstance(b, str) else b
        rgb_sub = im_rgb[b[0]:b[2],b[1]:b[3]]
        seg_sub = seg[b[0]:b[2],b[1]:b[3]] == l
        im_sub = im_norm[b[0]:b[2],b[1]:b[3]]
        cl = dict_lab_clust[l]
        color = np.array(dict_clust_col[cl])
        rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*3)
        # rgb_cell = (im_sub * seg_sub)[:,:,None] * color[None,:]
        rgb_cell = seg_sub[:,:,None] * color[None,:]
        im_rgb[b[0]:b[2],b[1]:b[3],:] = rgb_noncell + rgb_cell

In [None]:
fig, ax, cbar = ip.plot_image(im_rgb, im_inches=10)

In [None]:
for c in np.unique(clust_all):
    print('Cluster:', c)
    bool_c = clust_all == c
    labels_sub = np.array(labels)[bool_c]
    spec_sub = []
    for l in labels_sub:
        s = dict_lab_specmean[l]
        spec_sub.append(s)
    spec_sub = np.vstack(spec_sub)
    fig, ax = ip.general_plot(dims=spec_dims)
    color = dict_clust_col[c]
    fsi.plot_cell_spectra(ax, spec_sub, {'lw':0.5,'alpha':1,'color':color})
    ax.set_ylim(0,2**16)
    plt.plot()
    plt.show()

In [None]:
i = 5
i_old = 3
sub_cl[i] = 29

In [None]:
# pick a threshold for clustering
t[i] = 0.019

In [None]:
# Re-cut the linkage at distance t[i], then list the new-cut labels that
# occur inside the cluster being split (label sub_cl[i] of clust[i_old]).
clust[i] = hierarchy.fcluster(linkage, t=t[i], criterion='distance')
bool_subclust = clust[i_old] == sub_cl[i]
clust_unique = np.unique(clust[i][bool_subclust])
clust_unique

In [None]:
fig, ax = ip.general_plot(dims=(15,15))
dn = hierarchy.dendrogram(linkage, labels=labels, truncate_mode='lastp')
ax.axhline(t[i])
ylims = [round(l,2) for l in ax.get_ylim()]
ax.set_yticks(np.arange(ylims[0],ylims[1],0.005))
ax.grid(axis='y')
plt.show()
plt.close()

In [None]:
spec_dims = (10,5)
for c in clust_unique:
    labels_sub = np.array(labels)[clust[i] == c]
    if len(labels_sub) > 10:
        print("Cluster:",c) 
        spec_sub = []
        for l in labels_sub:
            s = dict_lab_specmean[l]
            spec_sub.append(s)
        spec_sub = np.vstack(spec_sub)
        fig, ax = ip.general_plot(dims=spec_dims)
        fsi.plot_cell_spectra(ax, spec_sub, {'lw':0.5,'alpha':0.1,'color':'r'})
        ax.set_ylim(0,2**16)
        plt.plot()
        plt.show()

In [None]:
# New-cut labels (ids within clust[i]) chosen to become separate clusters
clust_change = [10,11]

clmx = np.max(clust_all)
# Start from the running clustering, as the earlier refinement steps do.
# Copying clust[i_old] instead would hand the raw fcluster id sub_cl[i] to
# every member of the split cluster that is not renamed below, and that raw
# id can coincide with an unrelated id already used in clust_all.
clust_i_rename = clust_all.copy()
for c in clust_change:
    bool_c = clust[i] == c
    clmx += 1  # one past the largest label in use so far
    clust_i_rename[bool_c] = clmx

# Preview: labels the split cluster's members would carry after renaming
np.unique(clust_i_rename[bool_subclust])

In [None]:
clust_all[bool_subclust] = clust_i_rename[bool_subclust]
np.unique(clust_all)


In [None]:
# fit = umap.UMAP(metric='precomputed')
# u = fit.fit_transform(dist_mat)
plt.scatter(u[:,0], u[:,1], c=clust_all, alpha=0.25, cmap='Spectral')
plt.show()
plt.close()

In [None]:
dict_lab_clust = dict(zip(labels,clust_all))
colors = plt.get_cmap('tab20').colors
dict_clust_col = dict(zip(np.unique(clust_all), colors))

In [None]:
seg = segs[m]
prop = props[m]
for i, row in prop.iterrows():
    l = row.label
    if l in labels:
        b = row.bbox
        b = eval(b) if isinstance(b, str) else b
        rgb_sub = im_rgb[b[0]:b[2],b[1]:b[3]]
        seg_sub = seg[b[0]:b[2],b[1]:b[3]] == l
        im_sub = im_norm[b[0]:b[2],b[1]:b[3]]
        cl = dict_lab_clust[l]
        color = np.array(dict_clust_col[cl])
        rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*3)
        # rgb_cell = (im_sub * seg_sub)[:,:,None] * color[None,:]
        rgb_cell = seg_sub[:,:,None] * color[None,:]
        im_rgb[b[0]:b[2],b[1]:b[3],:] = rgb_noncell + rgb_cell

In [None]:
fig, ax, cbar = ip.plot_image(im_rgb, im_inches=10)

In [None]:
for c in np.unique(clust_all):
    print('Cluster:', c)
    bool_c = clust_all == c
    labels_sub = np.array(labels)[bool_c]
    spec_sub = []
    for l in labels_sub:
        s = dict_lab_specmean[l]
        spec_sub.append(s)
    spec_sub = np.vstack(spec_sub)
    fig, ax = ip.general_plot(dims=spec_dims)
    color = dict_clust_col[c]
    fsi.plot_cell_spectra(ax, spec_sub, {'lw':0.5,'alpha':1,'color':color})
    ax.set_ylim(0,2**16)
    plt.plot()
    plt.show()

In [None]:
i = 6
i_old = 5
sub_cl[i] = 10

In [None]:
# pick a threshold for clustering
t[i] = 0.0195

In [None]:
fig, ax = ip.general_plot(dims=(15,15))
dn = hierarchy.dendrogram(linkage, labels=labels, truncate_mode='lastp')
ax.axhline(t[i])
ylims = [round(l,2) for l in ax.get_ylim()]
ax.set_yticks(np.arange(ylims[0],ylims[1],0.005))
ax.grid(axis='y')
plt.show()
plt.close()

In [None]:
# Re-cut the linkage at distance t[i], then list the new-cut labels that
# occur inside the cluster being split (label sub_cl[i] of clust[i_old]).
clust[i] = hierarchy.fcluster(linkage, t=t[i], criterion='distance')
bool_subclust = clust[i_old] == sub_cl[i]
clust_unique = np.unique(clust[i][bool_subclust])
clust_unique

In [None]:
spec_dims = (10,5)
for c in clust_unique:
    labels_sub = np.array(labels)[clust[i] == c]
    if len(labels_sub) > 10:
        print("Cluster:",c) 
        spec_sub = []
        for l in labels_sub:
            s = dict_lab_specmean[l]
            spec_sub.append(s)
        spec_sub = np.vstack(spec_sub)
        fig, ax = ip.general_plot(dims=spec_dims)
        fsi.plot_cell_spectra(ax, spec_sub, {'lw':0.5,'alpha':0.1,'color':'r'})
        ax.set_ylim(0,2**16)
        plt.plot()
        plt.show()

In [None]:
# New-cut labels (ids within clust[i]) chosen to become separate clusters
clust_change = [12]

clmx = np.max(clust_all)
# Start from the running clustering, as the earlier refinement steps do.
# Copying clust[i_old] instead would hand the raw fcluster id sub_cl[i] to
# every member of the split cluster that is not renamed below, and that raw
# id can coincide with an unrelated id already used in clust_all.
clust_i_rename = clust_all.copy()
for c in clust_change:
    bool_c = clust[i] == c
    clmx += 1  # one past the largest label in use so far
    clust_i_rename[bool_c] = clmx

# Preview: labels the split cluster's members would carry after renaming
np.unique(clust_i_rename[bool_subclust])

In [None]:
clust_all[bool_subclust] = clust_i_rename[bool_subclust]
np.unique(clust_all)


In [None]:
# fit = umap.UMAP(metric='precomputed')
# u = fit.fit_transform(dist_mat)
plt.scatter(u[:,0], u[:,1], c=clust_all, alpha=0.25, cmap='Spectral')
plt.show()
plt.close()

In [None]:
dict_lab_clust = dict(zip(labels,clust_all))
colors = plt.get_cmap('tab20').colors
dict_clust_col = dict(zip(np.unique(clust_all), colors))

In [None]:
seg = segs[m]
prop = props[m]
for i, row in prop.iterrows():
    l = row.label
    if l in labels:
        b = row.bbox
        b = eval(b) if isinstance(b, str) else b
        rgb_sub = im_rgb[b[0]:b[2],b[1]:b[3]]
        seg_sub = seg[b[0]:b[2],b[1]:b[3]] == l
        im_sub = im_norm[b[0]:b[2],b[1]:b[3]]
        cl = dict_lab_clust[l]
        color = np.array(dict_clust_col[cl])
        rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*3)
        # rgb_cell = (im_sub * seg_sub)[:,:,None] * color[None,:]
        rgb_cell = seg_sub[:,:,None] * color[None,:]
        im_rgb[b[0]:b[2],b[1]:b[3],:] = rgb_noncell + rgb_cell

In [None]:
fig, ax, cbar = ip.plot_image(im_rgb, im_inches=10)

In [None]:
for c in np.unique(clust_all):
    print('Cluster:', c)
    bool_c = clust_all == c
    labels_sub = np.array(labels)[bool_c]
    spec_sub = []
    for l in labels_sub:
        s = dict_lab_specmean[l]
        spec_sub.append(s)
    spec_sub = np.vstack(spec_sub)
    fig, ax = ip.general_plot(dims=spec_dims)
    color = dict_clust_col[c]
    fsi.plot_cell_spectra(ax, spec_sub, {'lw':0.5,'alpha':1,'color':color})
    ax.set_ylim(0,2**16)
    plt.plot()
    plt.show()

Merge clusters

In [None]:
for c in np.unique(clust_all):
    print('Cluster:', c)
    bool_c = clust_all == c
    labels_sub = np.array(labels)[bool_c]
    spec_sub = []
    for l in labels_sub:
        s = dict_lab_specmean[l]
        spec_sub.append(s)
    spec_sub = np.vstack(spec_sub)
    fig, ax = ip.general_plot(dims=spec_dims)
    fsi.plot_cell_spectra(ax, spec_sub, {'lw':0.5,'alpha':0.25,'color':'r'})
    ax.set_ylim(0,2**16)
    plt.plot()
    plt.show()

In [None]:
# Clusters to merge: target cluster id -> list of cluster ids folded into it.
# Empty, so no clusters are merged.
merge = {}

In [None]:
# Apply the merge map on a copy; masks are computed on the unmodified
# clust_all, so one merge cannot affect another.
clust_merge = clust_all.copy()
for c, c_lst in merge.items():
    for c_ in c_lst:
        bool_c = clust_all == c_
        clust_merge[bool_c] = c  # members of c_ take the target id c

In [None]:
# One figure per merged cluster: overlay the spectra (looked up in
# dict_lab_specmean) of every member label as semi-transparent red lines.
labels_arr = np.array(labels)
line_kwargs = {'lw':0.5,'alpha':0.25,'color':'r'}
for c in np.unique(clust_merge):
    print('Cluster:', c)
    bool_c = clust_merge == c
    labels_sub = labels_arr[bool_c]
    # One row per label belonging to this cluster
    spec_sub = np.vstack([dict_lab_specmean[l] for l in labels_sub])
    fig, ax = ip.general_plot(dims=spec_dims)
    fsi.plot_cell_spectra(ax, spec_sub, line_kwargs)
    # Same fixed y-range on every figure so clusters can be compared directly
    ax.set_ylim(0,2**16)
    plt.plot()
    plt.show()

## Project to the image

In [None]:
# Build a grayscale RGB background: sum the stack over axis 2, min-max scale
# the result to [0, 1], then copy it into three identical color channels.
im_sum = np.sum(stacks[m], axis=2)
mn, mx = np.min(im_sum), np.max(im_sum)
im_norm = (im_sum - mn) / (mx - mn)

# Broadcasting against a length-3 vector of ones replicates the plane per channel
gray_to_rgb = np.array([1, 1, 1])
im_rgb = im_norm[:, :, np.newaxis] * gray_to_rgb
fig, ax, cbar = ip.plot_image(im_rgb, im_inches=10)

In [None]:
# Map each cell label to its (merged) cluster id.
dict_lab_clust = dict(zip(labels,clust_merge))

# Map each cluster id to an RGB color from the tab10 palette.
# The palette is indexed modulo its length: zipping clusters with the 10-color
# palette would silently leave clusters beyond the tenth without a color and
# make the later dict_clust_col[...] lookups raise KeyError. With more clusters
# than palette entries, colors repeat instead.
colors = plt.get_cmap('tab10').colors
dict_clust_col = {
    c: colors[i % len(colors)]
    for i, c in enumerate(np.unique(clust_merge))
}

In [None]:
# Paint every clustered cell in im_rgb with the flat color of its cluster.
seg = segs[m]
prop = props[m]
for _, row in prop.iterrows():
    l = row.label
    # dict_lab_clust is keyed by the clustered labels, so this membership test
    # is a hash lookup instead of a linear scan of `labels` for every row.
    if l not in dict_lab_clust:
        continue
    b = row.bbox
    # NOTE(review): eval() executes arbitrary code. bbox strings presumably come
    # from a props table written by this pipeline; if they are plain tuples of
    # ints, ast.literal_eval would be the safer parser - confirm before switching.
    b = eval(b) if isinstance(b, str) else b
    # Pixels inside the bounding box that belong to this cell
    seg_sub = seg[b[0]:b[2],b[1]:b[3]] == l
    cl = dict_lab_clust[l]
    color = np.array(dict_clust_col[cl])
    # The basic slice is a view into im_rgb, so the masked assignment recolors
    # the cell's pixels in place and leaves the other pixels in the box untouched.
    # Alternative (intensity-weighted) coloring:
    # rgb_cell = (im_norm[b[0]:b[2],b[1]:b[3]] * seg_sub)[:,:,None] * color[None,:]
    im_rgb[b[0]:b[2],b[1]:b[3]][seg_sub] = color

In [None]:
# Plot im_rgb after the previous cell painted each clustered cell with the
# color of its merged cluster.
fig, ax, cbar = ip.plot_image(im_rgb, im_inches=10)

In [None]:
# One figure per merged cluster: overlay the spectra (looked up in
# dict_lab_specmean) of every member label, drawn in the cluster's own color so
# the figures match the projected image.
labels_arr = np.array(labels)
for c in np.unique(clust_merge):
    print('Cluster:', c)
    bool_c = clust_merge == c
    labels_sub = labels_arr[bool_c]
    # One row per label belonging to this cluster
    spec_sub = np.vstack([dict_lab_specmean[l] for l in labels_sub])
    fig, ax = ip.general_plot(dims=spec_dims)
    line_kwargs = {'lw':0.5,'alpha':1,'color':dict_clust_col[c]}
    fsi.plot_cell_spectra(ax, spec_sub, line_kwargs)
    # Same fixed y-range on every figure so clusters can be compared directly
    ax.set_ylim(0,2**16)
    plt.plot()
    plt.show()

In [None]:
# Clusters listed in `ignore` are relabeled 0 in a copy of the merged
# clustering; the cells below skip cluster 0 when coloring and plotting.
ignore = [2]
clust_trim = clust_merge.copy()
clust_trim[np.isin(clust_merge, ignore)] = 0

In [None]:
# Map each cell label to its trimmed cluster id (0 = ignored).
dict_lab_clust = dict(zip(labels,clust_trim))

# Assign a tab10 color to every cluster except the "ignored" id 0.
# 0 is excluded explicitly rather than by dropping the first unique value:
# when no cell was relabeled 0 (e.g. `ignore` is empty), slicing off the first
# entry would leave a real cluster without a color and break later lookups.
# The palette is indexed modulo its length so more than 10 clusters reuse
# colors instead of being silently dropped.
colors = plt.get_cmap('tab10').colors
kept_clusters = [c for c in np.unique(clust_trim) if c != 0]
dict_clust_col = {
    c: colors[i % len(colors)]
    for i, c in enumerate(kept_clusters)
}

In [None]:
# Rebuild the grayscale background, then paint every cell that belongs to a
# non-ignored cluster with its cluster color. Ignored cells (cluster 0) keep
# their grayscale intensity.
im_rgb = im_norm[:,:,None] * np.array([1,1,1])[None,:]
for _, row in prop.iterrows():
    l = row.label
    # dict_lab_clust is keyed by the clustered labels, so this membership test
    # is a hash lookup instead of a linear scan of `labels` for every row.
    if l not in dict_lab_clust:
        continue
    cl = dict_lab_clust[l]
    if not cl:
        # Cluster 0 marks ignored cells: leave them uncolored
        continue
    color = np.array(dict_clust_col[cl])
    b = row.bbox
    # NOTE(review): eval() executes arbitrary code. bbox strings presumably come
    # from a props table written by this pipeline; if they are plain tuples of
    # ints, ast.literal_eval would be the safer parser - confirm before switching.
    b = eval(b) if isinstance(b, str) else b
    # Pixels inside the bounding box that belong to this cell
    seg_sub = seg[b[0]:b[2],b[1]:b[3]] == l
    # The basic slice is a view into im_rgb, so the masked assignment recolors
    # the cell's pixels in place and leaves the other pixels in the box untouched.
    # Alternative (intensity-weighted) coloring:
    # rgb_cell = (im_norm[b[0]:b[2],b[1]:b[3]] * seg_sub)[:,:,None] * color[None,:]
    im_rgb[b[0]:b[2],b[1]:b[3]][seg_sub] = color

In [None]:
# Plot im_rgb with only the non-ignored clusters colored (see previous cell).
ip.plot_image(im_rgb, im_inches=10)

In [None]:
# One figure per remaining cluster (id 0 = ignored, skipped): overlay the
# spectra (looked up in dict_lab_specmean) of every member label in the
# cluster's own color.
labels_arr = np.array(labels)
for c in np.unique(clust_trim):
    if not c:
        continue
    print('Cluster:', c)
    bool_c = clust_trim == c
    labels_sub = labels_arr[bool_c]
    # One row per label belonging to this cluster
    spec_sub = np.vstack([dict_lab_specmean[l] for l in labels_sub])
    fig, ax = ip.general_plot(dims=spec_dims)
    line_kwargs = {'lw':0.5,'alpha':1,'color':dict_clust_col[c]}
    fsi.plot_cell_spectra(ax, spec_sub, line_kwargs)
    # Same fixed y-range on every figure so clusters can be compared directly
    ax.set_ylim(0,2**16)
    plt.plot()
    plt.show()

## Save

In [None]:
# Per-sample output directory, relative to the working directory.
output_dir = '../../outputs/agglomerative_cluster/' + sn
# output_dir = output_dir_fmt.format(sample_name=sn)

# Create the directory by attempting it (EAFP) instead of checking first: this
# avoids the check-then-create race and still reports only real creations.
try:
    os.makedirs(output_dir)
except FileExistsError:
    pass
else:
    print("made dir:",output_dir)


In [None]:
# One subdirectory of output_dir per kind of saved artifact.
stack_dir, props_dir, segs_dir, spec_dir, clust_dir = [
    output_dir + suffix
    for suffix in ('/stacks', '/props', '/segs', '/spectra', '/clust')
]

# Create whichever subdirectories are missing and report each one made.
for dir_ in (clust_dir, spec_dir, segs_dir, props_dir, stack_dir):
    if os.path.exists(dir_):
        continue
    os.makedirs(dir_)
    print("made dir:",dir_)


In [None]:
def _to_builtin(value):
    """Return the plain Python scalar for a NumPy scalar; pass anything else through."""
    return value.item() if isinstance(value, np.generic) else value


# Shared basename for every output file: sample name plus the index m.
bn = sn + '_M_' + str(m)
stack_fn = stack_dir + '/' + bn + '_stack.npy'
props_fn = props_dir + '/' + bn + '_props.csv'
seg_fn = segs_dir + '/' + bn + '_seg.npy'
spec_fn = spec_dir + '/' + bn + '_spec.yaml'
clust_fn = clust_dir + '/' + bn + '_clust.yaml'

np.save(stack_fn, stacks[m])
print('Wrote:',stack_fn)
props[m].to_csv(props_fn, index=False)
print('Wrote:',props_fn)
np.save(seg_fn, segs[m])
print('Wrote:',seg_fn)
# NOTE(review): specs[m] is dumped as-is; if it contains NumPy objects they will
# be written with python-specific YAML tags - confirm its contents.
with open(spec_fn, 'w') as f:
    yaml.dump(specs[m], f)
print('Wrote:',spec_fn)
# dict_lab_clust is built by zipping over a NumPy array, so it holds NumPy
# scalars. yaml.dump would serialize those as opaque
# "!!python/object/apply:numpy..." entries that yaml.safe_load cannot read, so
# convert them to builtin Python scalars to get a plain {label: cluster} mapping.
clust_out = {_to_builtin(k): _to_builtin(v) for k, v in dict_lab_clust.items()}
with open(clust_fn, 'w') as f:
    yaml.dump(clust_out, f)
print('Wrote:',clust_fn)


In [None]:
# NOTE(review): looks like a leftover scratch cell - `a` is not read anywhere in
# the visible code; confirm before removing.
a = 1