# Adjust sizing of the HiPRFISH image and overlay it with the MGE-FISH image
## Setup

In [None]:
import glob
import sys
import os
import gc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import yaml
import re
from collections import defaultdict
import aicspylibczi as aplc
from skimage.registration import phase_cross_correlation
from tqdm import tqdm
from scipy.cluster import hierarchy
from scipy.spatial.distance import squareform
import umap
import math
from sklearn.cluster import AgglomerativeClustering
from cv2 import resize, INTER_NEAREST
from sklearn.neighbors import NearestNeighbors
from scipy import stats
from scipy.spatial.distance import squareform, pdist



In [None]:
# Root of the cluster filesystem; every configured path below is appended to it.
cluster = '/fs/cbsuvlaminck2/'
workdir = '/workdir/bmg224/manuscripts/mgefish/code/bmg_plasmids_imaging/agglomerative_clustering'
# NOTE: `cluster` ends with '/' and `workdir` starts with '/', so the joined
# path contains '//' (harmless on POSIX).
os.chdir(cluster + workdir)
os.getcwd()

In [None]:
# Pipeline configuration (data/output dirs, helper-module paths).
config_fn = "config_240107.yaml"  # relative path to config file from workdir

with open(config_fn, "r") as f:
    config = yaml.safe_load(f)

In [None]:
%load_ext autoreload
%autoreload 2

# Make the project's helper modules (fgu, ip, sf, fhc, fsi) importable.
sys.path.append(cluster + config['pipeline_path'] + '/' + config['functions_path'])
import fn_general_use as fgu
import image_plots as ip
import segmentation_func as sf
import fn_hiprfish_classifier as fhc
import fn_spectral_images as fsi



In [None]:
# Get filenames from directories
raw_dir = cluster + config["data_dir"] + "/*[y0-9].czi"
# raw_dir = config["data_dir"] + "/*" + config["laser_regex"]
fns = glob.glob(raw_dir)
fns_base = [os.path.split(f)[1] for f in fns]
group_names = [re.sub("_2024\w+.czi", "", s) for s in fns_base]
# print('HERE-->', shifts_fns)
group_names = np.sort(np.unique(group_names))
m_size = group_names.shape[0]
dict_group_czifns_all = defaultdict(list)
for g in group_names:
    for s in fns:
        if g in s:
            dict_group_czifns_all[g].append(s)
dict_group_czifns_all = {g: sorted(s) for g, s in dict_group_czifns_all.items()}
dict_group_czifns_all

In [None]:
dict_group_czifns = {k: [v[0], v[1], v[2]] for k, v in dict_group_czifns_all.items() if len(v) > 1}
dict_group_czifns

In [None]:
# Sample (group) processed in this notebook.
sn = "2024_01_07_newplasmidredo2reimage_slide_7_fov_02"

## Load data

In [None]:
# NOTE(review): unlike data_dir, output_dir is not prefixed with `cluster`;
# presumably config['output_dir'] is absolute or relative to workdir — confirm.
output_dir = config['output_dir'] + '/' + sn

stack_dir = output_dir + '/stacks'
props_dir = output_dir + '/props'
segs_dir = output_dir + '/segs'
spec_dir = output_dir + '/spectra'
clust_dir = output_dir + '/clust'

In [None]:
# Per-tile filename stem; '{}' is filled with the tile label M.
bn = sn + '_M_{}'

stack_fn = stack_dir + '/' + bn + '_stack.npy'
props_fn = props_dir + '/' + bn + '_props.csv'
seg_fn = segs_dir + '/' + bn + '_seg.npy'
spec_fn = spec_dir + '/' + bn + '_spec.yaml'
clust_fn = clust_dir + '/' + bn + '_clust.yaml'

In [None]:
# All HiPRFISH stacks of this sample (glob order is arbitrary).
stack_fns = glob.glob(stack_fn.format('*'))
stack_fns

In [None]:
Ms_hipr = [re.findall('(?<=_M_)\d+', f)[0] for f in stack_fns]
print(Ms_hipr)

In [None]:
stacks = [np.load(fn) for fn in stack_fns]

In [None]:
stacks = [x for _, x in sorted(zip(Ms_hipr, stacks))]

In [None]:
imin = 10
# Collapse axis 2 of every stack (presumably the spectral channels) into one 2D image.
stacks_sum = [np.sum(s, axis=2) for s in stacks]

# stacks_sum_alt = stacks_sum[:3] + [np.zeros_like(stacks_sum[0])] + stacks_sum[3:]
# ip.subplot_square_images(stacks_sum, (3,3), im_inches=imin, clims=clims)

# ip.plot_image(stacks_sum[0], cmap='inferno', im_inches=imin)
# Contrast limits and the 2x2 grid assume four HiPRFISH tiles.
clims = [(3000,20000)]*4
ip.subplot_square_images(stacks_sum, (2,2), im_inches=imin, clims=clims)

In [None]:
# Hard-coded: the 5th file (sorted by path) of this group is taken as the
# MGE-FISH acquisition — TODO confirm for other samples.
mge_raw_fn = dict_group_czifns_all[sn][4]
mge_raw_fn

In [None]:
# mge_data_fn = '/workdir/Data/bmg224/2023/brc_imaging/2023_11_22_newplasmid/2023_11_22_newplasmid_sample_bmg_fov_05_2023_11_22__07_19_27_airy.czi'
# czi_mge = aplc.CziFile(mge_data_fn)
czi_mge = aplc.CziFile(mge_raw_fn)
czi_mge.get_dims_shape()[0]

In [None]:
# Number of MGE tiles (M positions) read below; hard-coded.
M_mge = 16
In [None]:
def reshape_aics_image(m_img):
    '''
    Turn an AICS image holding only channel + 2D spatial data into (rows, cols, C).

    All singleton axes are squeezed out, then the leading axis (assumed to be
    the channel axis) is moved to the end.
    '''
    squeezed = np.squeeze(m_img)
    return squeezed.transpose(1, 2, 0)

In [None]:
# Read every MGE tile M and reorder it to (rows, cols, channels); `sh` is unused.
raws_mge = []
for m_s in range(M_mge):
    im, sh = czi_mge.read_image(M=m_s)
    im = reshape_aics_image(im)
    raws_mge.append(im)

In [None]:
# Channel 0 of the MGE image: cell channel (used later for registration).
raws_mge_cell = [im[:,:,0] for im in raws_mge]
clims = [(100,1000)]*16
ip.subplot_square_images(raws_mge_cell, (4,4), clims=clims)

In [None]:
# Channel 1 of the MGE image: spot channel.
raws_mge_spot = [im[:,:,1] for im in raws_mge]
clims = [(50,300)]*16
ip.subplot_square_images(raws_mge_spot, (4,4), clims=(clims))

## Resize HiPRFISH image

In [None]:
# HiPRFISH acquisition = first file (sorted by path) of the group.
hipr_raw_fn = dict_group_czifns_all[sn][0]
czi_hipr = aplc.CziFile(hipr_raw_fn)


# Pixel size: value of every metadata element whose tag contains both
# 'Scaling' and 'X'; the last match wins. If nothing matches, res_mge is never
# assigned and the print below raises NameError.
for n in czi_mge.meta.iter():
    if 'Scaling' in n.tag:
        if 'X' in n.tag:
            res_mge = float(n.text)
print('MGE m/pix',res_mge)

# Same lookup for the HiPRFISH file.
for n in czi_hipr.meta.iter():
    if 'Scaling' in n.tag:
        if 'X' in n.tag:
            res_hipr = float(n.text)
print('HiPR m/pix',res_hipr)


In [None]:
def resize_hipr(im, hipr_res, mega_res, dims='none', out_fn=False, ul_corner=(0,0)):
    '''
    Rescale a HiPRFISH image to the MGE ("mega") pixel size.

    Parameters
    ----------
    im : ndarray
        2D image or (rows, cols, channels) stack.
    hipr_res, mega_res : float
        Pixel sizes of the HiPRFISH and MGE images (same units). The image is
        scaled by hipr_res / mega_res along both spatial axes.
    dims : tuple, str or None
        Target (rows, cols[, ...]) of the output canvas. Any string (the
        historical 'none' default) or None keeps the resized shape.
    out_fn : unused
        Kept only so existing calls still work (saving is disabled).
    ul_corner : (row, col)
        Upper-left position of the resized image inside `dims`.

    Returns
    -------
    ndarray
        The resized image, pasted onto a zero canvas by center_image when
        `dims` differs from the resized shape.
    '''
    factor_resize = hipr_res / mega_res
    # Nearest-neighbour copies existing pixel values, so label images
    # (segmentations) gain no new, interpolated labels.
    hipr_resize = resize(
            im,
            None,
            fx = factor_resize,
            fy = factor_resize,
            interpolation = INTER_NEAREST
            )
    # None used to reach center_image and fail in len(dims); treat it like 'none'.
    if dims is None or isinstance(dims, str):
        dims = hipr_resize.shape
    return center_image(hipr_resize, dims, ul_corner)

def center_image(im, dims, ul_corner):
    '''
    Paste `im` onto a zero canvas whose spatial size is dims[:2].

    If the first two axes of `im` already equal dims[:2], `im` is returned
    unchanged. Otherwise a float64 zero array of shape dims[:2] + im.shape[2:]
    is created and `im` is written with its upper-left corner at `ul_corner`
    (row, col). Any trailing (channel) axes are taken from `im`, so `dims` may
    be given as 2 or 3 values (tuple, list or array).
    '''
    shp = im.shape
    dims = tuple(dims)
    # Compare only the spatial axes; previously a 3-value `dims` with a 2D
    # image raised IndexError and with a 3D image produced a 4D canvas.
    if tuple(dims[:2]) != tuple(shp[:2]):
        shp_new = tuple(dims[:2]) + tuple(shp[2:])
        temp = np.zeros(shp_new)
        br_corner = np.array(ul_corner) + np.array(shp[:2])
        temp[ul_corner[0]:br_corner[0], ul_corner[1]:br_corner[1]] = im
        im = temp
    return im

In [None]:
# seg_fns = glob.glob(seg_fn.format('*'))
seg_fns = [seg_fn.format(msh) for msh in Ms_hipr]
segs = [np.load(fn) for fn in seg_fns]
# Same (string) ordering key as `stacks`, so segs[i] matches stacks[i].
segs = [x for _, x in sorted(zip(Ms_hipr, segs))]

# Max projection over axis 2 of each stack.
stacks_max = [np.max(s, axis=2) for s in stacks]

In [None]:
# Rescale every HiPRFISH product (full stack, max, sum, segmentation) to the
# MGE pixel size with nearest-neighbour resizing.
hipr_res = []
hipr_maxs_res = []
hipr_sums_res = []
hipr_segs_res = []
for full, mx, sm, seg in zip(stacks, stacks_max, stacks_sum, segs):
    hipr_res.append(resize_hipr(full, res_hipr, res_mge))
    hipr_maxs_res.append(resize_hipr(mx, res_hipr, res_mge))
    hipr_sums_res.append(resize_hipr(sm, res_hipr, res_mge))
    hipr_segs_res.append(resize_hipr(seg, res_hipr, res_mge))

## Shift MGE image

In [None]:
# Output folder for the registered (shifted) MGE images of this sample.
mge_shifts_dir = output_dir + '/mge_shifts'
if not os.path.exists(mge_shifts_dir): 
    os.makedirs(mge_shifts_dir)
    print('Made dir:',mge_shifts_dir)

In [None]:
mge_shift_fmt = mge_shifts_dir + "/" + bn + "_mge_shift.npy"

# Zero border (pixels) added on every side of the HiPRFISH frame so that MGE
# tiles shifted past its edges still land on the canvas.
edge = 500

# For each HiPRFISH tile (index m_h into hipr_sums_res), the four MGE tile
# indices registered against it, in the same order as `corners` below
# (upper-left, upper-right, lower-left, lower-right).
mge_m_list = [[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]]

# [0,1,4,5], [2,3,6,7],
# [8,9,12,13],[10,11,14,15]
# [14,15,20,21],[16,17,22,23],
# [24,25,30,31],[26,27,32,33],[28,29,34,35]

raws_mge_shift = []
# NOTE(review): m_hstr is the right label for hipr_sums_res[m_h] only if
# Ms_hipr is in the same (sorted) order as the stacks.
for m_h, (m_hstr, mge_ms) in enumerate(zip(Ms_hipr, mge_m_list)):
    # Registration reference is the *summed* HiPRFISH image (despite the
    # `max` in the variable name), pre-processed with gauss=5.
    hipr_max_res = hipr_sums_res[m_h]
    hipr_max_res_sm = sf.pre_process(hipr_max_res, gauss=5)
    # hipr_max_res = hipr_maxs_res[m_h]
    rms_shape = tuple([s + 2 * edge for s in hipr_max_res.shape])
    # print(rms_shape)
    # Canvas for both MGE channels (last axis: 0 = cell, 1 = spot).
    raw_mge_shift = np.zeros(rms_shape + (2,))

    rms = raws_mge_cell[0].shape
    hrs = hipr_max_res.shape
    # Upper-left corners of an MGE-tile-sized window at each corner of the
    # HiPRFISH image: UL, UR, LL, LR.
    corners = [
        (0, 0),
        (0, hrs[1] - rms[1]),
        (hrs[0] - rms[0], 0),
        (hrs[0] - rms[0], hrs[1] - rms[1]),
    ]

    # mge_ms = mge_m_list[m_h]

    # Iterate in reverse so that where tiles overlap, earlier entries of
    # mge_ms (written last) end up on top.
    for i, (m_m, hipr_ul) in enumerate(zip(reversed(mge_ms), reversed(corners))):
        # Get images
        raw_mge_cell = raws_mge_cell[m_m]
        raw_mge_cell_sm = sf.pre_process(raw_mge_cell, gauss=5)
        raw_mge_m = raws_mge[m_m]
        # Pick quadrant for hipr
        hipr_ul = np.array(hipr_ul)
        hipr_lr = hipr_ul + raw_mge_cell.shape
        # Get same size MGE and hipr
        hipr_max_res_quad = hipr_max_res_sm[
            hipr_ul[0] : hipr_lr[0], hipr_ul[1] : hipr_lr[1]
        ]
        image_list = [hipr_max_res_quad, raw_mge_cell_sm]
        # Register mge with hipr
        shift_vectors = fsi._get_shift_vectors(image_list)
        shifts = shift_vectors[1]
        # if m_h == 0 and i == 1:  # 2024_01_07_newplasmidredo2reimage_slide_7_fov_03
        #     print(shifts)
        #     shifts[0] = shifts[0] - 140
        #     shifts[1] = shifts[1] - 50
        # Add mge based on shifts to full size image.
        # NOTE(review): a shift larger than `edge` puts ul/lr outside the
        # canvas; the slice assignment below would then raise or misplace.
        ul = (hipr_ul + shifts + edge).astype(int)
        lr = (ul + raw_mge_cell.shape).astype(int)
        raw_mge_shift[ul[0] : lr[0], ul[1] : lr[1], :] = raw_mge_m
    raws_mge_shift.append(raw_mge_shift)
    # Saving is disabled in this cell; out_fn is only computed.
    out_fn = mge_shift_fmt.format(m_hstr)
    # np.save(out_fn, raw_mge_shift)
    # print('Wrote:', out_fn)

In [None]:
# edge_out_fn = mge_shifts_dir + '/' + bn.format('edgepixels') + '_mge_shift.txt'
# with open(edge_out_fn, 'w') as f:
#     f.write(str(edge))
# print('Wrote:',edge_out_fn)

In [None]:
def norm(im, c=('min', 'max')):
    '''
    Clip `im` to [low, high] and rescale it linearly to [0, 1].

    Parameters
    ----------
    im : array-like
    c : (low, high)
        Either entry may be the string 'min' / 'max' to use the image's own
        minimum / maximum. (Tuple default instead of a mutable list.)

    Returns
    -------
    ndarray of float
        All zeros when high == low, instead of 0/0 NaNs.
    '''
    mn = np.min(im) if c[0] == 'min' else c[0]
    mx = np.max(im) if c[1] == 'max' else c[1]
    im = np.clip(im, mn, mx)
    if mx == mn:
        return np.zeros(np.shape(im), dtype=float)
    return (im - mn) / (mx - mn)

In [None]:
imin=30

mge_shifts_plot_bn = mge_shifts_dir + '/' + bn + '_mge_shift_plot'

# Display limits: HiPRFISH sum (red), MGE cell channel (green).
clips = [(0,20000),(0,1500)]

for m_h, (m_hstr, mge_ms) in enumerate(zip(Ms_hipr, mge_m_list)):
    raw_mge_shift = raws_mge_shift[m_h]

    # Pad the HiPRFISH sum image by `edge` so it aligns with the MGE canvas.
    hipr_sum_res = hipr_sums_res[m_h]
    hipr_sum_res_edge = np.zeros(raw_mge_shift.shape[:2])
    hsr_shape = hipr_sum_res.shape
    hipr_sum_res_edge[edge:edge+hsr_shape[0],edge:edge+hsr_shape[1]] = hipr_sum_res
    

    im_r = norm(hipr_sum_res_edge, clips[0])
    im_g = norm(raw_mge_shift[:,:,0], clips[1])

    # Red = HiPRFISH, green = MGE cell channel; overlap shows as yellow.
    im_rgb = np.zeros(im_r.shape + (3,))
    im_rgb[:,:,0] = im_r
    im_rgb[:,:,1] = im_g
    ip.plot_image(im_rgb, im_inches=imin)
    
    # out_bn = mge_shifts_plot_bn.format(m_hstr)
    # ip.save_png_pdf(out_bn)
    # print('Wrote:',out_bn + '.png')

    plt.show()    
    plt.close()

### Shift scratch paper

In [None]:
raw_mge = raws_mge[0]
mega_cell = raws_mge_cell[0]
# NOTE: despite the name, this is the resized *max* projection.
hipr_sum_resize = hipr_maxs_res[0]
hipr_view = hipr_sums_res[0]

In [None]:
ip.plot_image(hipr_view, cmap='inferno',im_inches=imin)

In [None]:
ip.plot_image(mega_cell, cmap='inferno',im_inches=imin)


In [None]:
# Which is the smaller image? (decided by the row count only)
# NOTE(review): if both have the same row count, argmin and argmax are both 0
# and hipr_shift below stays the placeholder 0.
mshp = mega_cell.shape[:2]
hshp = hipr_sum_resize.shape[:2]
im_list = [mega_cell, hipr_sum_resize]
i_sml = np.argmin([mshp[0],hshp[0]])
i_lrg = np.argmax([mshp[0],hshp[0]])
sml = im_list[i_sml]
lrg = im_list[i_lrg]
# Get half the difference between sizes
shp_dff = np.abs(np.array(hshp) - np.array(mshp)) // 2
# Shift the smaller image so that it sits at the center of the larger image
sml_shift_shape = lrg.shape[:2]
if len(sml.shape) > 2:
    sml_shift_shape += (sml.shape[2],)
sml_shift = np.zeros(sml_shift_shape)
corn_ind = np.array(shp_dff) + np.array(sml.shape[:2])
sml_shift[shp_dff[0]:corn_ind[0], shp_dff[1]:corn_ind[1]] = sml
# reassign mega and hipr image var names
im_shift_list = [0,0]
im_shift_list[i_sml] = sml_shift
im_shift_list[i_lrg] = lrg
mega_shift = im_shift_list[0]
hipr_shift = im_shift_list[1]
# Get the shift vectors for the mega image (index 1 = shift of mega relative to hipr)
image_list = [hipr_shift, mega_shift]
shift_vectors = fsi._get_shift_vectors(image_list)
shift_vectors

In [None]:
def shift_mega(im):
    '''
    Put `im` on the common canvas, then apply the MGE registration shift.

    Reads notebook globals: `dims` and `ul_corner` (canvas placement, passed to
    center_image) and `mega_shift_vector`, `max_shift` (passed to
    fsi._shift_images). A 2D input first gets a trailing length-1 channel axis.
    Returns fsi._shift_images' result for the one-image list.
    '''
    img = im[..., None] if len(im.shape) == 2 else im
    canvas = center_image(img, dims, ul_corner)
    return fsi._shift_images([canvas], mega_shift_vector, max_shift=max_shift)

In [None]:
# Shift mge: the globals below are read inside shift_mega().
max_shift = 2000
mega_shift_vector = [shift_vectors[1]]
dims = lrg.shape
ul_corner = shp_dff
# run the shift function
raw_shift = shift_mega(raw_mge)[0]

Show the overlay

In [None]:
def norm(im, c=('min', 'max')):
    '''
    Clip `im` to [low, high] and rescale it linearly to [0, 1].

    Parameters
    ----------
    im : array-like
    c : (low, high)
        Either entry may be the string 'min' / 'max' to use the image's own
        minimum / maximum. (Tuple default instead of a mutable list.)

    Returns
    -------
    ndarray of float
        All zeros when high == low, instead of 0/0 NaNs.
    '''
    mn = np.min(im) if c[0] == 'min' else c[0]
    mx = np.max(im) if c[1] == 'max' else c[1]
    im = np.clip(im, mn, mx)
    if mx == mn:
        return np.zeros(np.shape(im), dtype=float)
    return (im - mn) / (mx - mn)

In [None]:
clips = [(0,20000),(0,1500)]

# NOTE(review): assumes hipr_view and raw_shift share (rows, cols), i.e. the
# HiPRFISH image was the larger one in the centering step — confirm.
im_r = norm(hipr_view, clips[0])
im_g = norm(raw_shift[:,:,0], clips[1])

# Red = HiPRFISH sum, green = shifted MGE cell channel.
im_rgb = np.zeros(im_r.shape + (3,))
im_rgb[:,:,0] = im_r
im_rgb[:,:,1] = im_g
ip.plot_image(im_rgb, im_inches=imin)
plt.show()    
plt.close()


Try the shift again using a subset (one corner window) of the HiPRFISH image

In [None]:
m_h = 0
hipr_max_res = hipr_maxs_res[m_h]
edge = 500
rms_shape = tuple([s + 2*edge for s in hipr_max_res.shape])
raw_mge_shift = np.zeros(rms_shape + (2,))
raw_mge_shift.shape

In [None]:
m_m = 0
# Get images
raw_mge_cell = raws_mge_cell[m_m]
raw_mge_m = raws_mge[m_m]
# Pick quadrant for hipr
hipr_ul = np.array([0,0])
hipr_lr = hipr_ul + raw_mge_cell.shape
# Get same size MGE and hipr
hipr_max_res_quad = hipr_max_res[hipr_ul[0]:hipr_lr[0],hipr_ul[1]:hipr_lr[1]]
image_list = [hipr_max_res_quad, raw_mge_cell]
# Register mge with hipr
shift_vectors = fsi._get_shift_vectors(image_list)
shifts = shift_vectors[1]
# Add mge based on shifts to full size image
# ul = (hipr_ul + shifts).astype(int)
# lr = (ul + raw_mge_cell.shape).astype(int)
ul = (hipr_ul + shifts + edge).astype(int)
lr = (ul + raw_mge_cell.shape).astype(int)
raw_mge_shift[ul[0]:lr[0],ul[1]:lr[1],:] = raw_mge_m

In [None]:
hipr_sum_res = hipr_sums_res[m_h]
hipr_sum_res_edge = np.zeros(rms_shape)
hipr_sum_res_edge[edge:edge+hrs[0],edge:edge+hrs[1]] = hipr_sum_res

In [None]:
clips = [(0,20000),(0,1500)]


im_r = norm(hipr_sum_res_edge, clips[0])
im_g = norm(raw_mge_shift[:,:,0], clips[1])

im_rgb = np.zeros(im_r.shape + (3,))
im_rgb[:,:,0] = im_r
im_rgb[:,:,1] = im_g
ip.plot_image(im_rgb, im_inches=imin)
plt.show()    
plt.close()

In [None]:
m_m = 0
# NOTE(review): MGE tiles 0-3 are mapped to the four corners here, whereas the
# registration cell above uses [0, 1, 4, 5] for the first HiPRFISH tile.
M_ms = [0,1,2,3]

raw_mge_shift = np.zeros(rms_shape + (2,))
rms = raws_mge_cell[0].shape
hrs = hipr_max_res.shape
# Corner windows: UL, UR, LL, LR.
corners = [(0,0),(0,hrs[1]-rms[1]),(hrs[0]-rms[0],0),(hrs[0]-rms[0],hrs[1]-rms[1])]
for m_m, hipr_ul in zip(M_ms, corners):
    # Get images
    raw_mge_cell = raws_mge_cell[m_m]
    raw_mge_m = raws_mge[m_m]
    # Pick quadrant for hipr
    hipr_ul = np.array(hipr_ul)
    hipr_lr = hipr_ul + raw_mge_cell.shape
    # Get same size MGE and hipr
    hipr_max_res_quad = hipr_max_res[hipr_ul[0]:hipr_lr[0],hipr_ul[1]:hipr_lr[1]]
    image_list = [hipr_max_res_quad, raw_mge_cell]
    # Register mge with hipr
    shift_vectors = fsi._get_shift_vectors(image_list)
    shifts = shift_vectors[1]
    # Add mge based on shifts to full size image
    ul = (hipr_ul + shifts + edge).astype(int)
    lr = (ul + raw_mge_cell.shape).astype(int)
    raw_mge_shift[ul[0]:lr[0],ul[1]:lr[1],:] = raw_mge_m

In [None]:
clips = [(0,20000),(0,1500)]

im_r = norm(hipr_sum_res_edge, clips[0])
im_g = norm(raw_mge_shift[:,:,0], clips[1])

im_rgb = np.zeros(im_r.shape + (3,))
im_rgb[:,:,0] = im_r
im_rgb[:,:,1] = im_g
ip.plot_image(im_rgb, im_inches=imin)
plt.show()    
plt.close()

In [None]:
# Canvas position of the last pasted tile (inspection only).
ul

In [None]:
mge_shifts_dir = output_dir + '/mge_shifts'
if not os.path.exists(mge_shifts_dir): 
    os.makedirs(mge_shifts_dir)
    print('Made dir:',mge_shifts_dir)

In [None]:
mge_shift_fmt = mge_shifts_dir + '/' + bn + '_mge_shift.npy'

edge = 500

mge_m_list = [
    [0,1,2,3]
]

raws_mge_shift = []
for m_h in range(M_hipr):
    hipr_max_res = hipr_maxs_res[m_h]
    rms_shape = tuple([s + 2*edge for s in hipr_max_res.shape])
    # print(rms_shape)
    raw_mge_shift = np.zeros(rms_shape + (2,))

    rms = raws_mge_cell[0].shape
    hrs = hipr_max_res.shape
    corners = [(0,0),(0,hrs[1]-rms[1]),(hrs[0]-rms[0],0),(hrs[0]-rms[0],hrs[1]-rms[1])]

    mge_ms = mge_m_list[m_h]

    for m_m, hipr_ul in zip(mge_ms, corners):
        # Get images
        raw_mge_cell = raws_mge_cell[m_m]
        raw_mge_m = raws_mge[m_m]
        # Pick quadrant for hipr
        hipr_ul = np.array(hipr_ul)
        hipr_lr = hipr_ul + raw_mge_cell.shape
        # Get same size MGE and hipr
        hipr_max_res_quad = hipr_max_res[hipr_ul[0]:hipr_lr[0],hipr_ul[1]:hipr_lr[1]]
        image_list = [hipr_max_res_quad, raw_mge_cell]
        # Register mge with hipr
        shift_vectors = fsi._get_shift_vectors(image_list)
        shifts = shift_vectors[1]
        # Add mge based on shifts to full size image
        ul = (hipr_ul + shifts + edge).astype(int)
        lr = (ul + raw_mge_cell.shape).astype(int)
        raw_mge_shift[ul[0]:lr[0],ul[1]:lr[1],:] = raw_mge_m
    raws_mge_shift.append(raw_mge_shift)
    out_fn = mge_shift_fmt.format(m_h)
    np.save(out_fn, raw_mge_shift)
    print('Wrote:', out_fn)



In [None]:
# Record the padding width so consumers of the saved shift arrays can strip it.
edge_out_fn = mge_shifts_dir + '/' + bn.format('edgepixels') + '_mge_shift.txt'
with open(edge_out_fn, 'w') as f:
    f.write(str(edge))
print('Wrote:',edge_out_fn)

In [None]:


mge_shifts_plot_bn = mge_shifts_dir + '/' + bn + '_mge_shift_plot'

clips = [(0,20000),(0,1500)]

for m_h in range(M_hipr):
    raw_mge_shift = raws_mge_shift[m_h]

    hipr_sum_res = hipr_sums_res[m_h]
    hipr_sum_res_edge = np.zeros(raw_mge_shift.shape[:2])
    hsr_shape = hipr_sum_res.shape
    hipr_sum_res_edge[edge:edge+hsr_shape[0],edge:edge+hsr_shape[1]] = hipr_sum_res
    

    im_r = norm(hipr_sum_res_edge, clips[0])
    im_g = norm(raw_mge_shift[:,:,0], clips[1])

    im_rgb = np.zeros(im_r.shape + (3,))
    im_rgb[:,:,0] = im_r
    im_rgb[:,:,1] = im_g
    ip.plot_image(im_rgb, im_inches=imin)
    out_bn = mge_shifts_plot_bn.format(m_h)
    # ip.save_png_pdf(out_bn)
    # print('Wrote:',out_bn + '.png')
    plt.show()    
    plt.close()

## Get MGE spots

In [None]:
# pick image
# NOTE(review): requires raws_mge_shift to hold at least m_h + 1 canvases.
m_h = 2
imin = 30

clim_mge = (100, 300)

# MGE spot channel (1) of the shifted canvas.
im = raws_mge_shift[m_h][:, :, 1]
ip.plot_image(im, cmap="inferno", im_inches=imin, clims=clim_mge)

In [None]:
# pre-process
gauss=2

im_pre = sf.pre_process(im, log=False, denoise=0, gauss=gauss, diff_gauss=(0,))
# check pre-processing
ip.plot_image(im_pre, cmap="inferno", im_inches=imin, clims=clim_mge)

In [None]:
# Get mask
bg_thresh = 100

im_mask = sf.get_background_mask(
    im_pre, bg_smoothing=0, n_clust_bg=3, top_n_clust_bg=1, bg_threshold=bg_thresh
)
# check mask: grey, half-transparent RGBA layer over pixels outside the mask
fig, ax, cbar = ip.plot_image(
    (im_pre * im_mask), cmap="inferno", im_inches=imin, clims=clim_mge
)
ax.imshow(np.dstack([0.5 * (~im_mask)] * 4))
# segment
# Check segmentation
# Save segmentation

Remove large objects and low intensity objects

In [None]:
# segment
# im_seg = sf.segment(
#     im_pre,
#     background_mask = im_mask
#     )

# Connected components of the mask are the candidate spots.
im_seg = sf.label(im_mask)
# im_seg = segment_no_lne(im_pre, im_mask)

In [None]:
# Check segmentation
seg_zoom_rgb = ip.seg2rgb(im_seg)
ip.plot_image(seg_zoom_rgb, im_inches=imin)
# Save segmentation 

In [None]:
# Get spot properties (intensities measured on the raw spot channel `im`)
prop = sf.measure_regionprops(im_seg, raw=im)
prop.shape

In [None]:
# Sorted max intensity of every spot, with the candidate threshold line.
int_thresh = 100

fig, ax = ip.general_plot(dims=(10,5))
x = np.arange(prop.shape[0])
y = prop.max_intensity.sort_values()
ax.scatter(x,y, s=1)
ax.set_title('Spot Max Intensity (a.u.)')
ax.plot([0,x.shape[0]], [int_thresh]*2, 'k')

In [None]:
bool_int = prop.max_intensity.values > int_thresh
print(len(bool_int))
print(sum(bool_int))

In [None]:
# Sorted spot area, with the candidate threshold line.
area_thresh = 3000
fig, ax = ip.general_plot(dims=(10,5))
x = np.arange(prop.shape[0])
y = prop.area.sort_values()
ax.scatter(x,y, s=1)
ax.set_title('Spot Area (pixels)')
ax.plot([0,x.shape[0]], [area_thresh]*2, 'k')

In [None]:
bool_area = prop.area.values < area_thresh
print(len(bool_area))
print(sum(bool_area))

In [None]:
# Sorted eccentricity (NOTE: unitless; "(pixels)" in the title is a misnomer).
ecc_thresh = 0.9
fig, ax = ip.general_plot(dims=(10,5))
x = np.arange(prop.shape[0])
y = prop.eccentricity.sort_values()
ax.scatter(x,y, s=1)
ax.set_title('Spot eccentricity (pixels)')
ax.plot([0,x.shape[0]], [ecc_thresh]*2, 'k')

In [None]:
bool_ecc = prop.eccentricity.values < ecc_thresh
print(len(bool_ecc))
print(sum(bool_ecc))

In [None]:
# Active filter: area only (intensity/eccentricity alternatives commented out).
# bools = np.array([True for _ in bool_area])
bools = bool_area
# bools = bool_area * bool_int
# bools = bool_area * bool_int * bool_ecc

In [None]:
# Filtered spots: blank out every spot that PASSES `bools`, so this image
# shows only the spots removed by the filter.
seg_zoom_rgb_filt = seg_zoom_rgb.copy()

for i, row in prop[bools].iterrows():
    l = row.label
    b = row.bbox
    # bbox may be stored as its string repr; eval() parses it (in-memory,
    # trusted data only — ast.literal_eval would be the safer parser).
    b = eval(b) if isinstance(b, str) else b
    rgb_sub = seg_zoom_rgb_filt[b[0] : b[2], b[1] : b[3]]
    seg_sub = im_seg[b[0] : b[2], b[1] : b[3]] == l
    # im_sub = im_norm[b[0]:b[2],b[1]:b[3]]
    # color = np.array([0,0,1,0.5])
    # color = np.array(dict_bc_col[cl] + (1,))
    # rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*3 + np.ones_like(seg_sub))
    # Zero this spot's pixels in every colour channel, keep the rest of the bbox.
    rgb_noncell = rgb_sub * np.dstack([~seg_sub] * rgb_sub.shape[2])
    # rgb_cell = (im_sub * seg_sub)[:,:,None] * color[None,:]
    # rgb_cell = seg_sub[:,:,None] * color[None,:]
    seg_zoom_rgb_filt[b[0] : b[2], b[1] : b[3], :] = rgb_noncell

ip.plot_image(seg_zoom_rgb_filt, im_inches=imin)

In [None]:
# Unfiltered spots: blank out every spot that FAILS `bools`, so this image
# shows the spots kept by the filter.
seg_zoom_rgb_areafilt = seg_zoom_rgb.copy()
for i, row in prop[~bools].iterrows():
    l = row.label
    b = row.bbox
    b = eval(b) if isinstance(b, str) else b
    rgb_sub = seg_zoom_rgb_areafilt[b[0]:b[2],b[1]:b[3]]
    seg_sub = im_seg[b[0]:b[2],b[1]:b[3]] == l
    # im_sub = im_norm[b[0]:b[2],b[1]:b[3]]
    # color = np.array([0,0,1,0.5])
    # color = np.array(dict_bc_col[cl] + (1,))
    # rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*3 + np.ones_like(seg_sub))
    rgb_noncell = rgb_sub * np.dstack([~seg_sub]*rgb_sub.shape[2])
    # rgb_cell = (im_sub * seg_sub)[:,:,None] * color[None,:]
    # rgb_cell = seg_sub[:,:,None] * color[None,:]
    seg_zoom_rgb_areafilt[b[0]:b[2],b[1]:b[3],:] = rgb_noncell
    
ip.plot_image(seg_zoom_rgb_areafilt, im_inches=imin)


In [None]:
# Filtered spots: mask of the spots removed by the filter, over the raw image

seg_mask_filt = np.sum(seg_zoom_rgb_filt, axis=2) > 0
# seg_mask_areafilt = np.sum(seg_zoom_rgb_areafilt, axis=2) > 0
fig, ax, cbar = ip.plot_image(im*seg_mask_filt, im_inches=imin, clims=clim_mge)
ax.imshow(np.dstack([0.5*(~seg_mask_filt)]*4))


In [None]:
# Unfiltered spots: mask of the spots kept by the filter (used downstream)

seg_mask_areafilt = np.sum(seg_zoom_rgb_areafilt, axis=2) > 0
# seg_mask_areafilt = np.sum(seg_zoom_rgb_areafilt, axis=2) > 0
fig, ax, cbar = ip.plot_image(im * seg_mask_areafilt, im_inches=imin, clims=clim_mge)
ax.imshow(np.dstack([0.5 * (~seg_mask_areafilt)] * 4))

Segment with watershed

In [None]:
# more aggressive mask: fixed intensity threshold on the pre-processed image
im_mask2 = im_pre > 125

# check mask (restricted to the spots kept by the area filter)
fig, ax, cbar = ip.plot_image(
    (im_pre * im_mask2 * seg_mask_areafilt),
    cmap="inferno",
    im_inches=imin,
    clims=clim_mge,
)
ax.imshow(np.dstack([0.5 * ~(im_mask2 * seg_mask_areafilt)] * 4))

In [None]:
def segment_no_lne(image, mask):
    '''
    Watershed-segment `image` inside `mask`, seeded at its local maxima.

    The watershed runs on -image*mask (bright peaks become basins) with
    watershed lines between touching objects; the result is relabelled.
    '''
    # `peak_local_max(..., indices=False)` was deprecated in scikit-image 0.18
    # and removed in 0.20. Build the boolean peak image from the returned
    # (row, col) coordinates instead (assumes sf.peak_local_max is
    # scikit-image's, which returns an (N, ndim) int array).
    peak_coords = sf.peak_local_max(image, min_distance=1)
    peak_mask = np.zeros(image.shape, dtype=bool)
    peak_mask[tuple(peak_coords.T)] = True
    seeds = sf.label(peak_mask)
    watershed_input = -image*mask
    seg = sf.watershed(watershed_input, seeds, mask=mask, watershed_line=True)
    return sf.label(seg)

In [None]:
# segment
# im_seg = sf.segment(
#     im_pre,
#     background_mask = im_mask
#     )

# Final mask = area-filtered spots AND the stricter intensity mask.
mask_fin = seg_mask_areafilt*im_mask2
im_seg2 = sf.label(mask_fin)
# im_seg2 = segment_no_lne(im_pre, mask_fin)

In [None]:
# Check segmentation
seg2_zoom_rgb = ip.seg2rgb(im_seg2)
ip.plot_image(seg2_zoom_rgb, im_inches=imin)

In [None]:
# Get spot properties
prop2 = sf.measure_regionprops(im_seg2, raw=im)
prop.shape

In [None]:
# Sorted spot area of the second segmentation, with the threshold line.
area_thresh2 = 600
fig, ax = ip.general_plot(dims=(10,5))
x = np.arange(prop2.shape[0])
y = prop2.area.sort_values()
ax.scatter(x,y, s=1)
ax.set_title('Spot Area (pixels)')
ax.plot([0,x.shape[0]], [area_thresh2]*2, 'k')

In [None]:
bool_area2 = prop2.area.values < area_thresh2
print(len(bool_area2))
print(sum(bool_area2))

In [None]:
# bools = np.array([True for _ in bool_area])
bools2 = bool_area2
# bools = bool_area * bool_int
# bools = bool_area * bool_int * bool_ecc

In [None]:
# Filtered spots: blank out spots that PASS bools2 (shows the removed ones).
seg_zoom_rgb_filt2 = seg2_zoom_rgb.copy()

for i, row in prop2[bools2].iterrows():
    l = row.label
    b = row.bbox
    # bbox may be stored as its string repr; eval() parses it (trusted data only).
    b = eval(b) if isinstance(b, str) else b
    rgb_sub = seg_zoom_rgb_filt2[b[0] : b[2], b[1] : b[3]]
    seg_sub = im_seg2[b[0] : b[2], b[1] : b[3]] == l
    # im_sub = im_norm[b[0]:b[2],b[1]:b[3]]
    # color = np.array([0,0,1,0.5])
    # color = np.array(dict_bc_col[cl] + (1,))
    # rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*3 + np.ones_like(seg_sub))
    rgb_noncell = rgb_sub * np.dstack([~seg_sub] * rgb_sub.shape[2])
    # rgb_cell = (im_sub * seg_sub)[:,:,None] * color[None,:]
    # rgb_cell = seg_sub[:,:,None] * color[None,:]
    seg_zoom_rgb_filt2[b[0] : b[2], b[1] : b[3], :] = rgb_noncell

ip.plot_image(seg_zoom_rgb_filt2, im_inches=imin)

In [None]:
# Unfiltered spots: blank out spots that FAIL bools2 (shows the kept ones).
seg_zoom_rgb_areafilt2 = seg2_zoom_rgb.copy()
for i, row in prop2[~bools2].iterrows():
    l = row.label
    b = row.bbox
    b = eval(b) if isinstance(b, str) else b
    rgb_sub = seg_zoom_rgb_areafilt2[b[0]:b[2],b[1]:b[3]]
    seg_sub = im_seg2[b[0]:b[2],b[1]:b[3]] == l
    # im_sub = im_norm[b[0]:b[2],b[1]:b[3]]
    # color = np.array([0,0,1,0.5])
    # color = np.array(dict_bc_col[cl] + (1,))
    # rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*3 + np.ones_like(seg_sub))
    rgb_noncell = rgb_sub * np.dstack([~seg_sub]*rgb_sub.shape[2])
    # rgb_cell = (im_sub * seg_sub)[:,:,None] * color[None,:]
    # rgb_cell = seg_sub[:,:,None] * color[None,:]
    seg_zoom_rgb_areafilt2[b[0]:b[2],b[1]:b[3],:] = rgb_noncell
    
ip.plot_image(seg_zoom_rgb_areafilt2, im_inches=imin)


In [None]:
# Filtered spots: mask of the removed spots over the raw image

seg_mask_filt2 = np.sum(seg_zoom_rgb_filt2, axis=2) > 0
# seg_mask_areafilt = np.sum(seg_zoom_rgb_areafilt, axis=2) > 0
fig, ax, cbar = ip.plot_image(im*seg_mask_filt2, im_inches=imin, clims=clim_mge)
ax.imshow(np.dstack([0.5*(~seg_mask_filt2)]*4))


In [None]:
# Unfiltered spots: mask of the kept spots over the raw image

seg_mask_areafilt2 = np.sum(seg_zoom_rgb_areafilt2, axis=2) > 0
# seg_mask_areafilt = np.sum(seg_zoom_rgb_areafilt, axis=2) > 0
fig, ax, cbar = ip.plot_image(im * seg_mask_areafilt2, im_inches=imin, clims=clim_mge)
ax.imshow(np.dstack([0.5 * (~seg_mask_areafilt2)] * 4))

In [None]:
# Final spot mask after both area filters.
mask_fin2 = mask_fin * seg_mask_areafilt2

### Pick spots that look like psfs
Visualize peak local maxima

In [None]:
# (row, col) coordinates of local maxima inside the final spot mask.
pre_masked = im_pre * mask_fin2
plm_inds = sf.peak_local_max(pre_masked)
plm_inds.shape

In [None]:
# Half-width (pixels) of the row/column intensity profiles drawn per peak.
rng = 20

color='r'
alpha=0.1
width=3

# Overlay the vertical and horizontal profile through every peak.
# NOTE(review): peaks closer than `rng` to the image border give shorter
# slices than np.arange(rng*2) and ax.plot would raise.
fig, ax = ip.general_plot(dims=(10,10), col='w')
for r, c in plm_inds:
    vals_r = im_pre[r-rng:r+rng, c]
    vals_c = im_pre[r, c-rng:c+rng]
    for v in [vals_r, vals_c]:
        ax.plot(np.arange(rng*2), v, color=color, alpha=alpha, lw=width)



Measure the difference between each spot's peak and the intensity on a ring around it

In [None]:
def create_ring_array(size, inner_radius, outer_radius):
    '''
    Return a (size, size) boolean annulus mask.

    A pixel is True when its Euclidean distance from the centre pixel
    (size // 2, size // 2) lies in [inner_radius, outer_radius], inclusive.
    Use an odd `size` so that the centre falls exactly on a pixel.
    '''
    mid = size // 2
    rows, cols = np.ogrid[:size, :size]
    dist = np.sqrt((rows - mid) ** 2 + (cols - mid) ** 2)
    return (dist >= inner_radius) & (dist <= outer_radius)


size = 25 # Must be odd, and equal 2 * outer_radius + 1 to match the windows cut below
inner_radius = 10
outer_radius = 12
radius_mask = create_ring_array(size, inner_radius, outer_radius)
plt.imshow(radius_mask)

In [None]:
plm_slope_means = []
plm_slope_mins = []
plm_int = []
for r, c in plm_inds:
    plm = im_pre[r, c]
    im_sub = im_pre[
        r - outer_radius : r + outer_radius+1, c - outer_radius : c + outer_radius+1
    ]
    vals_ring = im_sub[radius_mask]
    slopes = plm - vals_ring
    plm_slope_means.append(np.mean(slopes))
    plm_slope_mins.append(np.min(slopes))
    plm_int.append(plm)


In [None]:
# Mean ring slope vs peak intensity, with trial thresholds (overwritten below).
thesh_slope = 75
thresh_int = 205

# plm_slope_means_sort = np.sort(plm_slope_means)

fig, ax = ip.general_plot(dims=(10, 5))
ax.scatter(plm_int, plm_slope_means)
xlims = ax.get_xlim()
ylims = ax.get_ylim()
ax.plot(xlims, [thesh_slope] * 2, "k")
ax.set_ylabel('Radius mean slope')
ax.plot([thresh_int] * 2, ylims, "r")
ax.set_xlabel('Spot intensity')

In [None]:
# These values are the ones used by the filtering cell further down.
thesh_slope = 20
thresh_int = 150


fig, ax = ip.general_plot(dims=(10, 5))
ax.scatter(plm_int, plm_slope_mins)
xlims = ax.get_xlim()
ylims = ax.get_ylim()
ax.plot(xlims, [thesh_slope] * 2, "k")
ax.plot([thresh_int] * 2, ylims, "r")
ax.set_ylabel('Radius min slope')
ax.set_xlabel('Spot intensity')

In [None]:
# Sorted minimum ring slopes.
plm_slope_mins_sort = np.sort(plm_slope_mins)
fig, ax = ip.general_plot(dims=(10, 5))
ax.scatter(np.arange(len(plm_slope_mins_sort)), plm_slope_mins_sort)

In [None]:
def line_int_slope(x, p, m):
    '''Value at `x` of the line through point p = (x0, y0) with slope `m`.'''
    x0, y0 = p[0], p[1]
    run = x - x0
    return run * m + y0

# Alternative (diagonal) decision boundary: line through (150, 0) with slope 0.4.
p = (150,0)
m=20/50

fig, ax = ip.general_plot(dims=(10, 5))
ax.scatter(plm_int, plm_slope_mins)
xlims = ax.get_xlim()
ylims = ax.get_ylim()
ys = [line_int_slope(x, p, m) for x in xlims]
ax.plot(xlims, ys, 'k')
ax.set_ylabel('Radius min slope')
ax.set_xlabel('Spot intensity')

In [None]:
# True for peaks above the line; only used by a commented-out filter option.
bool_line = [y > line_int_slope(x, p, m) for x, y in zip(plm_int, plm_slope_mins)]

Filter the segmentation by PSF-likeness

In [None]:
bool_slope = np.array(plm_slope_mins) > thesh_slope
bool_int = np.array(plm_int) > thresh_int
bools_plm = bool_slope & bool_int
# bools_plm = bool_int
# bools_plm = bool_line

plm_inds_filt = plm_inds[bools_plm]

seeds_filt = np.zeros_like(im_pre)
for i, (r, c) in enumerate(plm_inds_filt):
    seeds_filt[r, c] = i

In [None]:
# Grow the filtered seeds over the inverted image, restricted to mask_fin2.
watershed_input = -im_pre
seg = sf.watershed(watershed_input, seeds_filt, mask=mask_fin2, watershed_line=True)

In [None]:

# check mask
fig, ax, cbar = ip.plot_image(
    (im_pre * (seg > 0)),
    cmap="inferno",
    im_inches=imin,
    clims=(100,300),
)
ax.imshow(np.dstack([0.5 * ~(seg > 0)] * 4))

In [None]:
# NOTE: this rebinds `size` (the ring-mask size above) to a marker size.
size=5

# check mask: kept peak positions drawn on the pre-processed image
fig, ax, cbar = ip.plot_image(
    (im_pre),
    cmap="inferno",
    im_inches=imin,
    clims=clim_mge,
)

ax.scatter(plm_inds_filt[:, 1].squeeze(), plm_inds_filt[:,0].squeeze(), color='g', s=size)

In [None]:
# Get spot properties (rebinds `prop` to the watershed spots)
prop = sf.measure_regionprops(seg, raw=im)
prop.shape

Remove spots outside of cells

In [None]:
# Get pixels for random simulation: MGE cell channel (0) of the same canvas
im_cell = raws_mge_shift[m_h][:,:,0]
ip.plot_image(im_cell, cmap='inferno', im_inches=imin)


In [None]:
# Cell mask from a fixed intensity threshold.
mask_cell = im_cell > 125
fig, ax, cbar = ip.plot_image((im_cell*mask_cell), cmap='inferno', im_inches=imin)
ax.imshow(np.dstack([0.5*(~mask_cell)]*4))

In [None]:
# True for each watershed spot whose centroid pixel lies in the cell mask.
# NOTE(review): bool_incell is not applied by the plotting cells below.
bool_incell = []
for c in prop.centroid.values:
    bool_incell.append(mask_cell[int(c[0]),int(c[1])])
print(len(bool_incell))
print(sum(bool_incell))

Show final spots on cells

In [None]:
# Raw cell
# Pixel size converted from metres to micrometres for the scale bar.
res_mge_umpix = res_mge * 10**6

fig, ax, cbar = ip.plot_image(
    im_cell,
    cmap="gray",
    im_inches=30,
    clims=(0, 2000),
    scalebar_resolution=res_mge_umpix,
)
# ax.imshow(raw_mge_shift_spot_rgb)

In [None]:
mge_clims = (100,200)

raw_mge_shift_spot = raws_mge_shift[m_h][:, :, 1]

# RGBA layer: magenta (R + B) whose alpha equals the normalised spot intensity.
raw_mge_shift_spot_norm = norm(raw_mge_shift_spot, mge_clims)
raw_mge_shift_spot_rgb = np.dstack(
    [
        raw_mge_shift_spot_norm,
        np.zeros_like(raw_mge_shift_spot_norm),
        raw_mge_shift_spot_norm,
        raw_mge_shift_spot_norm,
    ]
)

# Same magenta RGBA layer, but from the binary watershed segmentation.
s = (seg > 0) * 1
seg_mge_shift_spot_rgb = np.dstack(
    [
        s,
        np.zeros_like(s),
        s,
        s
    ]
).astype(float)

In [None]:
# Raw spots
res_mge_umpix = res_mge * 10**6

fig, ax, cbar = ip.plot_image(
    im_cell,
    cmap="gray",
    im_inches=30,
    clims=(0, 2000),
    scalebar_resolution=res_mge_umpix,
)
ax.imshow(raw_mge_shift_spot_rgb)

In [None]:
# seg spots: kept peak positions (magenta) on the cell image; optionally saved.
save = False

imin=30

res_mge_umpix = res_mge * 10**6
# dpi chosen so the saved image has roughly one pixel per data pixel.
dpi = np.max(im_cell.shape) // imin

fig, ax, cbar = ip.plot_image(
    im_cell,
    cmap="gray",
    im_inches=imin,
    clims=(0, 2000),
    scalebar_resolution=res_mge_umpix,
)
# ax.imshow(seg_mge_shift_spot_rgb)
ax.scatter(plm_inds_filt[:, 1].squeeze(), plm_inds_filt[:,0].squeeze(), color='m', s=30)
if save:
    mge_overlay_dir = output_dir + "/mge_overlay/clust_merge"
    if not os.path.exists(mge_overlay_dir):
        os.makedirs(mge_overlay_dir)
        print("Made dir:", mge_overlay_dir)
    plt.figure(fig)
    # NOTE: the file label is the list index m_h, not Ms_hipr[m_h].
    out_bn = mge_overlay_dir + "/" + bn.format(m_h) + "_overlay_mge"
    ip.save_png_pdf(out_bn, dpi=dpi)
    print("Wrote:", out_bn + ".png")

## Measure spatial association with spectra clusters

Get resized hipr properties

In [None]:
def add_edge(hipr_sum_res, edge):
    """Return a copy of an image zero-padded by ``edge`` pixels on every side.

    Only the first two (row, column) axes are padded; any trailing axes
    (e.g. colour channels) are left unchanged, so 2-D and multi-channel
    images are both supported. The result is float64, the same dtype the
    original ``np.zeros`` buffer produced, so callers that need integer
    labels must cast the result themselves.

    Parameters
    ----------
    hipr_sum_res : array-like
        Image with at least two dimensions.
    edge : int
        Non-negative border width in pixels.

    Returns
    -------
    numpy.ndarray
        Float64 array of shape ``(H + 2*edge, W + 2*edge, ...)`` with the
        input placed at ``[edge:edge+H, edge:edge+W]``.

    Raises
    ------
    ValueError
        If the input has fewer than two dimensions or ``edge`` is negative.
    """
    arr = np.asarray(hipr_sum_res, dtype=float)
    if arr.ndim < 2:
        raise ValueError("add_edge expects an array with at least 2 dimensions")
    if edge < 0:
        raise ValueError("edge must be non-negative")
    # Pad rows and columns only; trailing axes get no padding.
    pad_width = [(edge, edge)] * 2 + [(0, 0)] * (arr.ndim - 2)
    return np.pad(arr, pad_width, mode="constant", constant_values=0)

In [None]:
# Display m_h, the key used to select this image's data (e.g. raws_mge_shift[m_h]).
m_h

In [None]:
# Pad the resized HiPRFISH sum image and segmentation by `edge` pixels
# (presumably to match the MGE image frame; TODO confirm), then measure
# region properties of the padded segmentation on the padded sum image.
hipr_sum_res = hipr_sums_res[m_h]
hipr_sum_res_edge = add_edge(hipr_sum_res, edge)

hipr_seg_res = hipr_segs_res[m_h]
hipr_seg_res_edge = add_edge(hipr_seg_res, edge)
# add_edge returns floats; region labels must be integers.
hipr_seg_res_edge = hipr_seg_res_edge.astype(int)

hipr_prop_res = sf.measure_regionprops(hipr_seg_res_edge, raw=hipr_sum_res_edge)
hipr_prop_res.columns

In [None]:
# Visual registration check: padded HiPRFISH segmentation in colour with the
# MGE spot channel overlaid.
# NOTE: this overwrites raw_mge_shift_spot_norm / raw_mge_shift_spot_rgb with a
# gray (all four channels equal) version using display limits (0, 250).
hipr_seg_res_edge_rgb = ip.seg2rgb(hipr_seg_res_edge)
fig, ax, _ = ip.plot_image(hipr_seg_res_edge_rgb, im_inches=imin)
raw_mge_shift_spot = raws_mge_shift[m_h][:,:,1]
raw_mge_shift_spot_norm = norm(raw_mge_shift_spot, (0,250))
raw_mge_shift_spot_rgb = np.dstack([raw_mge_shift_spot_norm]*4)
ax.imshow(raw_mge_shift_spot_rgb)

In [None]:
# # Load cell props
# cell_props = pd.read_csv(props_fn.format(m_h))
# cell_props.columns

Load spectral clusters

In [None]:
# load high resolution clusters
# Reads the per-cell cluster assignments (cell label -> cluster id) for this image.
clust_dir_hires = clust_dir + '/hi_res'
clust_sn_fn = clust_dir_hires + '/' + sn + '_M_' + str(m_h) + '_clust_hi_res.yaml'
with open(clust_sn_fn, 'r') as f:
    # NOTE(review): yaml.unsafe_load can construct arbitrary Python objects;
    # acceptable only because this file is produced by our own pipeline.
    dict_lab_clust = yaml.unsafe_load(f)
print(len(dict_lab_clust))
clusters_unq = np.unique(list(dict_lab_clust.values()))
clusters_unq


In [None]:
# # Load spectral clusters
# with open(clust_fn.format(3), 'r') as f:
#     dict_lab_clust = yaml.unsafe_load(f)
# clusters_unq = np.unique(list(dict_lab_clust.values()))
# clusters_unq

In [None]:
# # classif svc
# classif_dir = config['output_dir'] + '/classif_svc'
# classif_fn = classif_dir + '/dict_sn_m_label_classif.yaml'

# with open(classif_fn, 'r') as f:
#     dict_sn_m_lab_cl = yaml.unsafe_load(f)

In [None]:
# dict_lab_clust = dict_sn_m_lab_cl[sn][str(m_h)]
# clusters_unq = np.unique(list(dict_lab_clust.values()))
# clusters_unq

Create a dictionary with cell coords for each cluster


In [None]:
# Size of the region-property table (presumably one row per segmented cell).
hipr_prop_res.shape

In [None]:
# Create a dictionary with cell coords for each cluster
# (cluster id -> list of [row, col] centroids); prints the cell count per cluster.
dict_clust_coords = defaultdict(list)
for l, c in hipr_prop_res[['label','centroid']].values:
    cl = dict_lab_clust[l]
    # Centroids may arrive as strings (e.g. after a CSV round-trip).
    # NOTE(review): eval() runs arbitrary code; only safe for trusted, locally
    # produced tables. ast.literal_eval would be safer if the strings are plain tuples.
    c = eval(c) if isinstance(c, str) else c
    dict_clust_coords[cl].append(list(c))
[len(v) for v in dict_clust_coords.values()]

In [None]:
# Get spot coordinates
# Keep only spots whose centroid is inside the cell mask (bool_incell); the
# commented lines are alternative filters. Prints total vs kept spot counts.
# spot_coords = prop.centroid.values[bool_area]
spot_coords = prop.centroid.values[bool_incell]
# spot_coords = prop.centroid.values[bool_area * bool_incell]
spot_coords = [list(s) for s in spot_coords]
print(prop.shape[0])
print(len(spot_coords))

Get dictionary of cluster nearest neighbor distances

In [None]:
# Get dictionary of cluster nearest neighbor distances
# For each cluster: distance (pixels) from every spot to the nearest cell
# centroid of that cluster, shape (n_spots, n_neighbors).
n_neighbors=1

dict_cl_dists = {}
for cl in clusters_unq:
    reseg_coords = dict_clust_coords[cl]
    nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(reseg_coords)
    dists, _ = nbrs.kneighbors(spot_coords)
    dict_cl_dists[cl] = dists


Simulate random spot locations inside the cell mask (null model)

In [None]:
# Invert dict_lab_clust: cluster id -> list of cell labels.
dict_clust_lab = defaultdict(list)
for lab, cl in dict_lab_clust.items():
    dict_clust_lab[cl].append(lab)

In [None]:
# Map each cell label to its centroid as a list.
# NOTE(review): unlike the dict_clust_coords cell, string centroids are not
# parsed here; this assumes the centroid values are already tuples.
l = hipr_prop_res.label.values
c = [list(coord) for coord in hipr_prop_res.centroid.values]
# c = [list(eval(coord)) for coord in hipr_prop_res.centroid.values]
dict_lab_coord = dict(zip(l, c))

In [None]:
# simulate random spots
# Null model: n times, place len(spot_coords) spots uniformly at random (with
# replacement) on pixels inside mask_cell. For every cluster, record the
# distance from each simulated spot to the nearest cell centroid of that cluster.
# NOTE: np.random is not seeded here, so results differ between runs.
n=1000

cell_coords_tup = hipr_prop_res.loc[:,'centroid'].values
# cell_coords is not used by the simulation below.
cell_coords = np.array([list(c) for c in cell_coords_tup])
pix_coords = np.argwhere(mask_cell)

# Cluster centroids do not change between simulations, so fit one
# nearest-neighbour index per cluster once instead of n times inside the loop.
dict_cl_nbrs = {}
for cl in clusters_unq:
    labels = dict_clust_lab[cl]
    tax_centroid = [dict_lab_coord[l] for l in labels]
    dict_cl_nbrs[cl] = NearestNeighbors(n_neighbors=1).fit(tax_centroid)

dict_cl_dists_sim = defaultdict(list)
for i in tqdm(range(n)):
    # Randomize spot locations
    i_sim = np.random.randint(0, pix_coords.shape[0], size=len(spot_coords))
    sim_spot_coords = pix_coords[i_sim]
    for cl in clusters_unq:
        # Nearest-neighbour cell distance for each simulated spot
        dists, _ = dict_cl_nbrs[cl].kneighbors(sim_spot_coords)
        dict_cl_dists_sim[cl].append(dists)
        


In [None]:
# Get fraction of spots associated in measured and simulation
# NOTE: despite the *_frac names inside the loop, sim_vals / meas_vals hold spot
# COUNTS within r_um of a cluster's cells; sim_frac / meas_frac divide those
# counts by the number of spots.
r_um = 0.5
res_mge_umpix = res_mge * 10**6

meas_vals = []
sim_vals = []
for cl in clusters_unq:
    # Get simulated fraction within radius of cell
    sim_dists = dict_cl_dists_sim[cl]
    sim_dists_um = np.array(sim_dists) * res_mge_umpix
    bool_sim_rad = sim_dists_um < r_um
    # Sum over the spot axis -> one count per simulation, shape (n, 1)
    sim_rad_counts = np.sum(bool_sim_rad, axis=1)
    sim_rad_frac = sim_rad_counts
    sim_vals.append(sim_rad_frac)
    # Get measured fraction
    dists_um = dict_cl_dists[cl] * res_mge_umpix
    dists_um.shape  # no-op expression
    bool_rad = dists_um < r_um
    rad_counts = np.sum(bool_rad)
    rad_frac = rad_counts
    meas_vals.append(rad_frac)

# (n_clusters, n, 1) -> (n_clusters, n)
sim_vals = np.array(sim_vals)[:,:,0]
sim_frac = sim_vals / len(spot_coords)
meas_vals = np.array(meas_vals)
meas_frac = meas_vals / len(spot_coords)



Get color dict

In [None]:
# Rearrange colors
# Reorder tab20 so the 10 darker shades come first and the 10 lighter shades
# second, dropping the 8th pair (gray in tab20) from each half. Then append pure
# green and yellow, for 20 colors in total. The bar chart previews the palette.
col_list = list(plt.get_cmap('tab20').colors)
col_1 = [col_list[i] for i in np.arange(0,20,2)]
del col_1[7]
col_2 = [col_list[i] for i in np.arange(1,20,2)]
del col_2[7]
col_list_re = col_1 + col_2 + [(0,1,0), (1,1,0)]
barlist = plt.bar(np.arange(20), np.ones(20))
for b,c in zip(barlist, col_list_re):
    b.set_color(c)

In [None]:
# # Get scinames dict
# probe_design_fn = config['probe_design_dir'] + '/' + config['probe_design_filename']
# probe_design = pd.read_csv(probe_design_fn)
# barcodes_pd = probe_design['code'].unique()
# sci_names_pd = [probe_design.loc[probe_design['code'] == bc,'sci_name'].unique()[0] 
#             for bc in barcodes_pd]
# barcodes_pdstr = [str(bc).zfill(5) for bc in barcodes_pd]
# dict_bc_sciname = dict(zip(barcodes_pdstr, sci_names_pd))
# dict_sciname_bc = dict(zip(sci_names_pd, barcodes_pdstr))

In [None]:
# # Count all barcodes
# dict_cl_counts = defaultdict(int)
# for m, dlc in dict_sn_m_lab_cl[sn].items():
#     clusts, counts = np.unique(list(dlc.values()), return_counts=True)
#     for cl, cnt in zip(clusts,counts):
#         dict_cl_counts[cl] += cnt
# dict_cl_counts

In [None]:
# # Sort barcodes
# bcs = list(dict_cl_counts.keys())
# counts = list(dict_cl_counts.values())
# barcodes_countsort = [bc for _, bc in sorted(zip(counts, bcs))]

In [None]:
# # Make dict and plot 
# dict_bc_col = dict(zip(barcodes_countsort, col_list_re))
# sciname_countsort = [dict_bc_sciname[bc] for bc in barcodes_countsort]
# ip.taxon_legend(sciname_countsort, col_list_re)

In [None]:
# Assign one color per cluster. zip() stops at the shorter input, so clusters
# beyond the 20 colors in col_list_re get no entry (KeyError when plotting).
# dict_bc_col = dict(zip(clusters_unq, plt.get_cmap('tab20').colors))
dict_bc_col = dict(zip(clusters_unq, col_list_re))

In [None]:
# Z-score each cluster's measured count against its own simulated counts
# (sim_vals has one row per cluster, one column per simulation).
# If a cluster's simulated counts never vary (sig == 0, e.g. no simulated spot
# ever lands within r_um), the z-score is undefined: use NaN rather than
# dividing by zero (which produced inf/NaN and RuntimeWarnings).
mu = np.mean(sim_vals, axis=1)
sig = np.std(sim_vals, axis=1)
sig_safe = np.where(sig > 0, sig, np.nan)
sim_z = (sim_vals - mu[:,None]) / sig_safe[:,None]
meas_z = (meas_vals - mu) / sig_safe

Plot association

In [None]:
# Plot z score number of spots associated with group
# Boxes: per-cluster z-scores of the simulated counts; dots: measured z-score,
# colored by dict_bc_col.
dims=[2.5,1]
xlab_rotation=45
pval_rotation=60
marker='.'
marker_size=10
text_dist=0.1
ft=7
ylimadj = 0.1
true_frac_llim = 0
line_col = 'k'
box_line_col = (0.5,0.5,0.5)
box_col = 'w'
yticklength=2

fig, ax = ip.general_plot(dims=dims, ft=ft, col=line_col)
# Plot simulation
boxplot = ax.boxplot(
        sim_z.T, patch_artist=True, showfliers=False,
        boxprops=dict(facecolor=box_col, color=box_line_col),
        capprops=dict(color=box_line_col),
        whiskerprops=dict(color=box_line_col),
        medianprops=dict(color=box_line_col),
      )
# Plot measured value
ys = []
xlab = []
x = 1
for i, cl in enumerate(clusters_unq):
    xlab.append(int(cl))
    color = dict_bc_col[cl]
    true_frac = meas_z[i]
    _ = ax.plot(x, true_frac, marker=marker, ms=marker_size, color=color)
    # Empirical p value from THIS cluster's simulations. Compare raw counts so
    # both sides share a scale (z-scoring is monotone per cluster, so this is
    # equivalent to comparing meas_z to sim_z), which also works when the
    # simulated counts have zero spread and the z-score is undefined.
    sim_vals_i = sim_vals[i,:]
    meas_val_i = meas_vals[i]
    if meas_val_i > np.mean(sim_vals_i):
        # number of simulations with value greater than observed
        r_ = np.sum(sim_vals_i > meas_val_i)
    else:
        # number of simulations with value less than observed
        r_ = np.sum(sim_vals_i < meas_val_i)
    # P value
    p_ = r_ / n
    # Text location: above this cluster's upper whisker limit (Q3 + 1.5 IQR),
    # computed in z units because that is what is plotted.
    sim_z_i = sim_z[i,:]
    q1,q3 = np.quantile(sim_z_i, [0.25,0.75])
    q4 = q3 + 1.5 * (q3 - q1)
    y = q4 if q4 > true_frac else true_frac
    y += text_dist
    ys.append(y)
    # Uncomment to annotate significant clusters with their p value:
    # if true_frac < true_frac_llim:
    #     t = ''
    # elif (p_ > 0.05):
    #     t = ''
    # elif (p_ > 0.001) and (p_ <= 0.05):
    #     t = str("p=" + str(p_))
    # else:
    #     t = str("p<0.001")
    # _ = ax.text(x, y, t, fontsize=ft, ha='left',va='bottom', rotation=pval_rotation, rotation_mode='anchor',
    #         color=line_col)
    x+=1
ax.set_xticklabels(xlab, rotation=xlab_rotation, ha='right', va='top', rotation_mode='anchor')
ax.tick_params(axis='y', length=yticklength)
ax.spines['top'].set_color('none')
ax.spines['bottom'].set_color('none')
ax.spines['right'].set_color('none')

# ylims = ax.get_ylim()
# ax.set_ylim(ylims[0], np.max(ys) + ylimadj)
mge_assoc_dir = output_dir + '/mge_association'
if not os.path.exists(mge_assoc_dir):
    os.makedirs(mge_assoc_dir)
    print('Made dir:',mge_assoc_dir)

# out_bn = mge_assoc_dir + '/' + bn.format(m_h) + '_seg_nn_zscore_association_0_5um'
# ip.save_png_pdf(out_bn)



In [None]:


# Bar plot: fraction of (in-cell) spots within r_um of each cluster's cells
# (meas_frac), one bar per cluster in clusters_unq order.
ft=6
line_col = 'k'
width=0.4
dims=[2.5,0.6]
yticklength=2

sci_name_order = clusters_unq
# sci_name_order = [dict_bc_sciname[bc] for bc in barcodes_int_order]
color_order = [dict_bc_col[sc] for sc in sci_name_order]

fig, ax = ip.general_plot(dims=dims, ft=ft, col=line_col)
ax.bar(
        np.arange(meas_frac.shape[0]),
        meas_frac,
        width=width,
        color=color_order,
        edgecolor=line_col
        )

ax.spines['top'].set_color('none')
ax.spines['right'].set_color('none')
ax.set_xticks([])
# ax.set_yticks(ticks=[0,0.2,0.4], labels=[])
ax.tick_params(axis='y', length=yticklength)

mge_assoc_dir = output_dir + '/mge_association'
if not os.path.exists(mge_assoc_dir): 
    os.makedirs(mge_assoc_dir)
    print('Made dir:',mge_assoc_dir)


# out_bn = mge_assoc_dir + '/' + bn.format(m_h) + '_bar_seg_nn_frac_spot_association_0_5um'
# ip.save_png_pdf(out_bn)

Get the fraction of cells in each cluster associated with spots

In [None]:
# For every cell in each cluster: distance (pixels) to its nearest MGE spot.
# The spot index is the same for every cluster, so it is fitted once.
n_neighbors=1

nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(spot_coords)
dict_cl_sdists = {}
for cl in clusters_unq:
    reseg_coords = dict_clust_coords[cl]
    dists, _ = nbrs.kneighbors(reseg_coords)
    dict_cl_sdists[cl] = dists

In [None]:
# For each cluster, count its cells that have an MGE spot within r_um of their
# centroid, and the fraction of the cluster's cells this represents.

cl_nnearspots = []
cl_fracnearspots = []
for cl in clusters_unq:
    # Get n cells in cluster with nearby spots
    dists_um = dict_cl_sdists[cl] * res_mge_umpix
    bool_rad = dists_um < r_um
    rad_counts = np.sum(bool_rad)
    rad_frac = rad_counts / len(dists_um)
    cl_nnearspots.append(rad_counts)
    cl_fracnearspots.append(rad_frac)

cl_nnearspots = np.array(cl_nnearspots)
cl_fracnearspots = np.array(cl_fracnearspots)

In [None]:
# Frac taxon assoc with spot
# Bar plot of cl_fracnearspots (fraction of each cluster's cells near a spot),
# reusing color_order from the spot-fraction bar plot cell.
dims=[2.5,0.6]
yticklength=2
ft=6
line_col = 'k'
width=0.4


fig, ax = ip.general_plot(dims=dims, ft=ft, col=line_col)
ax.bar(
        np.arange(cl_fracnearspots.shape[0]),
        cl_fracnearspots,
        width=width,
        color=color_order,
        edgecolor=line_col
        )
ax.set_xticks([])
# ax.set_yticks(ticks=[0,0.3,0.6], labels=[])
ax.spines['top'].set_color('none')
ax.spines['right'].set_color('none')
ax.tick_params(axis='y', length=yticklength)

mge_assoc_dir = output_dir + '/mge_association'
if not os.path.exists(mge_assoc_dir): 
    os.makedirs(mge_assoc_dir)
    print('Made dir:',mge_assoc_dir)

# out_bn = mge_assoc_dir + '/' + bn.format(m_h) + '_bar_seg_nn_frac_cell_association_0_5um'
# ip.save_png_pdf(out_bn)

## Evaluate strongly associated spectra

Extract spectra

In [None]:
# Display the per-cell spectra file path for this image.
spec_fn.format(m_h)

In [None]:
# Load file
# Per-cell spectra for this image (cell label -> spectrum).
# NOTE(review): yaml.unsafe_load is only acceptable for trusted pipeline output.
with open(spec_fn.format(m_h), 'r') as f:
    dict_lab_spec = yaml.unsafe_load(f)
len(dict_lab_spec)

In [None]:
# Get spectra for clusters
# Group spectra and labels by cluster. Both lists are appended in the same
# loop, so dict_cl_lab[cl][k] is the label of the spectrum dict_cl_spec[cl][k].
dict_cl_spec = defaultdict(list)
dict_cl_lab = defaultdict(list)
for l, spec in dict_lab_spec.items():
    cl = dict_lab_clust[l]
    dict_cl_spec[cl].append(spec)
    dict_cl_lab[cl].append(l)
[len(v) for v in dict_cl_spec.values()]

Plot spectra

In [None]:
# Figure size passed to ip.general_plot for spectra plots.
spec_dims = (10,5)

In [None]:
# cl = dict_sciname_bc['Selenomonas']
# Plot the per-cell spectra of a hand-picked list of clusters, one figure each.
# A cluster id with no cells yields an empty specs_arr.
cl_toplot = [1,4,8,11,12,13,14,15,17]
for cl in cl_toplot:
    print('Cluster:',cl)
    specs_arr = np.array(dict_cl_spec[cl])

    fig, ax = ip.general_plot(dims=spec_dims, col='w')
    fsi.plot_cell_spectra(ax, specs_arr, {'lw':1,'alpha':0.2,'color':dict_bc_col[cl]})
    plt.show()
    plt.close()
# ax.set_ylim(0,2**16)
# plt.plot()
# plt.show()

# mge_assoc_dir = output_dir + '/mge_association'
# if not os.path.exists(mge_assoc_dir): 
#     os.makedirs(mge_assoc_dir)
#     print('Made dir:',mge_assoc_dir)
# out_bn = mge_assoc_dir + '/' + bn.format(m_h) + '_clusts_5_spec'
# ip.save_png_pdf(out_bn)

In [None]:
# cl = dict_sciname_bc['Lautropia']
# Plot cluster 3's spectra with vertical reference lines at selected channels.
# NOTE: this sets the notebook-level `cl`, which later cells rely on.
cl=3

specs_arr = np.array(dict_cl_spec[cl])

fig, ax = ip.general_plot(dims=spec_dims)
fsi.plot_cell_spectra(ax, specs_arr, {'lw':1,'alpha':0.2,'color':dict_bc_col[cl]})

# plot a line
xs = [3, 7, 10, 24, 27, 29, 31, 33, 43, 47]
ylims = ax.get_ylim()
for x in xs:
    ax.plot([x,x], ylims,'k')
# ax.set_ylim(0,2**16)
# plt.plot()
# plt.show()
mge_assoc_dir = output_dir + '/mge_association'
if not os.path.exists(mge_assoc_dir): 
    os.makedirs(mge_assoc_dir)
    print('Made dir:',mge_assoc_dir)

out_bn = mge_assoc_dir + '/' + bn.format(m_h) + '_clusts_3_spec'
# ip.save_png_pdf(out_bn)

Split a cluster into sub-groups by its spectral peaks

In [None]:
# Select cluster 3 for splitting by spectral peaks.
# specs_mean is defined elsewhere; the clipped mean-subtracted spectra are not
# used by the split cells that follow.
cl = 3
specs_arr = np.array(dict_cl_spec[cl])
specs_meansub = specs_arr - specs_mean
specs_meansub[specs_meansub < 0] = 0


In [None]:

# Sub-group 0: cells whose channel-33 intensity exceeds channel 29.
# A direct comparison replaces `(a - b) > 0` so the test stays correct if the
# spectra are stored as unsigned integers (where subtraction wraps around).
bool_0 = specs_arr[:, 33] > specs_arr[:, 29]

In [None]:
# Plot sub-group 0 spectra with reference lines at selected channels.
specs_arr_0 = specs_arr[bool_0, :]

fig, ax = ip.general_plot(dims=spec_dims)
fsi.plot_cell_spectra(ax, specs_arr_0, {"lw": 1, "alpha": 0.2, "color": "r"})

# plot a line
xs = [3, 7, 10, 24, 27, 29, 31, 33, 43, 47]
ylims = ax.get_ylim()
for x in xs:
    ax.plot([x, x], ylims, "k")

In [None]:
# Sub-group 1: the complement of sub-group 0; plot its spectra.
bool_1 = ~bool_0

specs_arr_1 = specs_arr[bool_1, :]

fig, ax = ip.general_plot(dims=spec_dims)
fsi.plot_cell_spectra(ax, specs_arr_1, {"lw": 1, "alpha": 0.2, "color": "r"})

# plot a line
xs = [3, 7, 10, 24, 27, 29, 31, 33, 43, 47]
ylims = ax.get_ylim()
for x in xs:
    ax.plot([x, x], ylims, "k")

In [None]:
# Cells where channel 29 or channel 28 is brighter than channel 30.
# Direct comparisons avoid unsigned-integer wraparound from `(a - b) > 0`.
bool_2_ = (specs_arr[:, 29] > specs_arr[:, 30]) | (specs_arr[:, 28] > specs_arr[:, 30])

In [None]:
# Sub-group 2: sub-group 0 cells that also satisfy bool_2_ (boolean `*` is AND).
bool_2 = bool_0 * bool_2_

specs_arr_2 = specs_arr[bool_2, :]

fig, ax = ip.general_plot(dims=spec_dims)
fsi.plot_cell_spectra(ax, specs_arr_2, {"lw": 1, "alpha": 0.2, "color": "r"})

# plot a line
xs = [3, 7, 10, 24, 27, 29, 31, 33, 43, 47]
ylims = ax.get_ylim()
for x in xs:
    ax.plot([x, x], ylims, "k")

In [None]:
# Sub-group 3: sub-group 0 cells that do NOT satisfy bool_2_.
bool_3 = bool_0 * ~bool_2_

specs_arr_3 = specs_arr[bool_3, :]

fig, ax = ip.general_plot(dims=spec_dims)
fsi.plot_cell_spectra(ax, specs_arr_3, {"lw": 1, "alpha": 0.2, "color": "r"})

# plot a line
xs = [3, 7, 10, 24, 27, 29, 31, 33, 43, 47]
ylims = ax.get_ylim()
for x in xs:
    ax.plot([x, x], ylims, "k")

Add the new sub-clusters to the cluster assignments if the split looks good

In [None]:
# Assign a sub-group number (1-based, as float) to each cell, later masks winning.
# NOTE: bool_2 and bool_3 partition bool_0, so every bool_0 cell is overwritten
# with 3 or 4. Only values 2, 3 and 4 remain, and number 1 is unused.
bools = [bool_0, bool_1, bool_2, bool_3]
clnew_ = np.zeros(specs_arr.shape[0])
for i, b in enumerate(bools):
    clnew_[b] = i + 1
clnew_

In [None]:
# Offset the sub-group numbers above the current highest cluster id.
clnew = clnew_ + np.max(list(dict_lab_clust.values()))
clnew

In [None]:
# Write the new sub-cluster ids back into the label -> cluster map.
# dict_cl_lab[cl] lists labels in the same order as the rows of specs_arr
# (both built in the same loop), so it lines up with clnew. A separate loop
# variable keeps `cl` naming the cluster being split instead of being
# overwritten. Ids are stored as int because clnew is a float array.
for lab, cl_new in zip(dict_cl_lab[cl], clnew):
    dict_lab_clust[lab] = int(cl_new)

print(np.unique(list(dict_lab_clust.values())))

Next cluster

In [None]:
# Select cluster 2 for splitting by spectral peaks (same procedure as cluster 3).
cl = 2
specs_arr = np.array(dict_cl_spec[cl])
specs_meansub = specs_arr - specs_mean
specs_meansub[specs_meansub < 0] = 0


In [None]:

# Sub-group 0: cells whose channel-33 intensity exceeds channel 29.
# A direct comparison replaces `(a - b) > 0` so the test stays correct if the
# spectra are stored as unsigned integers (where subtraction wraps around).
bool_0 = specs_arr[:, 33] > specs_arr[:, 29]

In [None]:
# Plot sub-group 0 spectra with reference lines at selected channels.
specs_arr_0 = specs_arr[bool_0, :]

fig, ax = ip.general_plot(dims=spec_dims)
fsi.plot_cell_spectra(ax, specs_arr_0, {"lw": 1, "alpha": 0.2, "color": "r"})

# plot a line
xs = [3, 7, 10, 24, 27, 29, 31, 33, 43, 47]
ylims = ax.get_ylim()
for x in xs:
    ax.plot([x, x], ylims, "k")

In [None]:
# Sub-group 1: the complement of sub-group 0; plot its spectra.
bool_1 = ~bool_0

specs_arr_1 = specs_arr[bool_1, :]

fig, ax = ip.general_plot(dims=spec_dims)
fsi.plot_cell_spectra(ax, specs_arr_1, {"lw": 1, "alpha": 0.2, "color": "r"})

# plot a line
xs = [3, 7, 10, 24, 27, 29, 31, 33, 43, 47]
ylims = ax.get_ylim()
for x in xs:
    ax.plot([x, x], ylims, "k")

In [None]:
# Cells where channel 29 is brighter than channel 30 or channel 31.
# Direct comparisons avoid unsigned-integer wraparound from `(a - b) > 0`.
bool_2_ = (specs_arr[:, 29] > specs_arr[:, 30]) | (specs_arr[:, 29] > specs_arr[:, 31])

In [None]:
# Sub-group 2: sub-group 0 cells that also satisfy bool_2_ (boolean `*` is AND).
bool_2 = bool_0 * bool_2_

specs_arr_2 = specs_arr[bool_2, :]

fig, ax = ip.general_plot(dims=spec_dims)
fsi.plot_cell_spectra(ax, specs_arr_2, {"lw": 1, "alpha": 0.2, "color": "r"})

# plot a line
xs = [3, 7, 10, 24, 27, 29, 31, 33, 43, 47]
ylims = ax.get_ylim()
for x in xs:
    ax.plot([x, x], ylims, "k")

In [None]:
# Sub-group 3: sub-group 0 cells that do NOT satisfy bool_2_.
bool_3 = bool_0 * ~bool_2_

specs_arr_3 = specs_arr[bool_3, :]

fig, ax = ip.general_plot(dims=spec_dims)
fsi.plot_cell_spectra(ax, specs_arr_3, {"lw": 1, "alpha": 0.2, "color": "r"})

# plot a line
xs = [3, 7, 10, 24, 27, 29, 31, 33, 43, 47]
ylims = ax.get_ylim()
for x in xs:
    ax.plot([x, x], ylims, "k")

Add the new sub-clusters to the cluster assignments if the split looks good

In [None]:
# Keep only the two-way split for cluster 2: sub-group 1 = bool_0, 2 = bool_1
# (complementary masks, so every cell gets 1 or 2).
bools = [bool_0, bool_1]
clnew_ = np.zeros(specs_arr.shape[0])
for i, b in enumerate(bools):
    clnew_[b] = i + 1
np.unique(clnew_)

In [None]:
# Offset the sub-group numbers above the current highest cluster id.
clnew = clnew_ + np.max(list(dict_lab_clust.values()))
np.unique(clnew)

In [None]:
# Write the new sub-cluster ids back into the label -> cluster map.
# dict_cl_lab[cl] lists labels in the same order as the rows of specs_arr
# (both built in the same loop), so it lines up with clnew. A separate loop
# variable keeps `cl` naming the cluster being split instead of being
# overwritten. Ids are stored as int because clnew is a float array.
for lab, cl_new in zip(dict_cl_lab[cl], clnew):
    dict_lab_clust[lab] = int(cl_new)

print(np.unique(list(dict_lab_clust.values())))

### Redo the association analysis with the new clusters

In [None]:
# Refresh the list of cluster ids after adding the sub-clusters.
clusters_unq = np.unique(list(dict_lab_clust.values()))
clusters_unq

In [None]:
# Create a dictionary with cell coords for each cluster
# (cluster id -> list of [row, col] centroids); prints the cell count per cluster.
dict_clust_coords = defaultdict(list)
for l, c in hipr_prop_res[["label", "centroid"]].values:
    cl = dict_lab_clust[l]
    # NOTE(review): eval() runs arbitrary code; only safe for trusted, locally
    # produced tables. ast.literal_eval would be safer for plain tuple strings.
    c = eval(c) if isinstance(c, str) else c
    dict_clust_coords[cl].append(list(c))
[len(v) for v in dict_clust_coords.values()]

In [None]:
# Get dictionary of cluster nearest neighbor distances
# For each cluster: distance (pixels) from every spot to the nearest cell
# centroid of that cluster, shape (n_spots, n_neighbors).
n_neighbors = 1

dict_cl_dists = {}
for cl in clusters_unq:
    reseg_coords = dict_clust_coords[cl]
    nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(reseg_coords)
    dists, _ = nbrs.kneighbors(spot_coords)
    dict_cl_dists[cl] = dists

In [None]:
# Invert dict_lab_clust: cluster id -> list of cell labels.
dict_clust_lab = defaultdict(list)
for lab, cl in dict_lab_clust.items():
    dict_clust_lab[cl].append(lab)

In [None]:
# simulate random spots
# Null model: n times, place len(spot_coords) spots uniformly at random (with
# replacement) on pixels inside mask_cell. For every cluster, record the
# distance from each simulated spot to the nearest cell centroid of that cluster.
# NOTE: np.random is not seeded here, so results differ between runs.
n = 1000

cell_coords_tup = hipr_prop_res.loc[:, "centroid"].values
# cell_coords is not used by the simulation below.
cell_coords = np.array([list(c) for c in cell_coords_tup])
pix_coords = np.argwhere(mask_cell)

# Cluster centroids do not change between simulations, so fit one
# nearest-neighbour index per cluster once instead of n times inside the loop.
dict_cl_nbrs = {}
for cl in clusters_unq:
    labels = dict_clust_lab[cl]
    tax_centroid = [dict_lab_coord[l] for l in labels]
    dict_cl_nbrs[cl] = NearestNeighbors(n_neighbors=1).fit(tax_centroid)

dict_cl_dists_sim = defaultdict(list)
for i in tqdm(range(n)):
    # Randomize spot locations
    i_sim = np.random.randint(0, pix_coords.shape[0], size=len(spot_coords))
    sim_spot_coords = pix_coords[i_sim]
    for cl in clusters_unq:
        # Nearest-neighbour cell distance for each simulated spot
        dists, _ = dict_cl_nbrs[cl].kneighbors(sim_spot_coords)
        dict_cl_dists_sim[cl].append(dists)

In [None]:
# Count, per cluster, the spots within r_um of a cell in that cluster: once for
# the measured spots (meas_vals) and once per simulated spot set (sim_vals).
# sim_frac / meas_frac divide those counts by the number of spots.
r_um = 0.5
res_mge_umpix = res_mge * 10**6

meas_vals = []
sim_vals = []
for cl in clusters_unq:
    # Simulated counts: n arrays of NN distances, each shaped (n_spots, 1)
    sim_dists = dict_cl_dists_sim[cl]
    if len(sim_dists) == 0:
        raise ValueError(
            f"No simulated distances for cluster {cl}; re-run the simulation cell."
        )
    sim_dists_um = np.array(sim_dists) * res_mge_umpix
    bool_sim_rad = sim_dists_um < r_um
    # Summing over the spot axis gives shape (n, 1) and is all zeros when no
    # simulated spot falls within r_um. No special case is needed; a scalar 0
    # here made sim_vals ragged and broke the stacking below.
    sim_rad_counts = np.sum(bool_sim_rad, axis=1)
    sim_vals.append(sim_rad_counts)
    # Measured count
    dists_um = dict_cl_dists[cl] * res_mge_umpix
    bool_rad = dists_um < r_um
    rad_counts = np.sum(bool_rad)
    meas_vals.append(rad_counts)

# (n_clusters, n, 1) -> (n_clusters, n)
sim_vals = np.array(sim_vals)[:, :, 0]
sim_frac = sim_vals / len(spot_coords)
meas_vals = np.array(meas_vals)
meas_frac = meas_vals / len(spot_coords)




In [None]:
# Inspect the current cluster -> color mapping (before it is rebuilt below).
dict_bc_col

In [None]:
# Rebuild colors for the updated cluster list. zip() truncates at 20 colors, so
# clusters beyond the length of col_list_re get no color (KeyError when plotting).
dict_bc_col = dict(zip(clusters_unq, col_list_re))

In [None]:
# Z-score each cluster's measured count against its own simulated counts
# (sim_vals has one row per cluster, one column per simulation).
# If a cluster's simulated counts never vary (sig == 0, e.g. no simulated spot
# ever lands within r_um), the z-score is undefined: use NaN rather than
# dividing by zero (which produced inf/NaN and RuntimeWarnings).
mu = np.mean(sim_vals, axis=1)
sig = np.std(sim_vals, axis=1)
sig_safe = np.where(sig > 0, sig, np.nan)
sim_z = (sim_vals - mu[:, None]) / sig_safe[:, None]
meas_z = (meas_vals - mu) / sig_safe

In [None]:
# Plot z score number of spots associated with group
# Boxes: per-cluster z-scores of the simulated counts; dots: measured z-score,
# colored by dict_bc_col.
dims = [2.5, 1]
xlab_rotation = 45
pval_rotation = 60
marker = "."
marker_size = 10
text_dist = 0.1
ft = 7
ylimadj = 0.1
true_frac_llim = 0
line_col = "k"
box_line_col = (0.5, 0.5, 0.5)
box_col = "w"
yticklength = 2

fig, ax = ip.general_plot(dims=dims, ft=ft, col=line_col)
# Plot simulation
boxplot = ax.boxplot(
    sim_z.T,
    patch_artist=True,
    showfliers=False,
    boxprops=dict(facecolor=box_col, color=box_line_col),
    capprops=dict(color=box_line_col),
    whiskerprops=dict(color=box_line_col),
    medianprops=dict(color=box_line_col),
)
# Plot measured value
ys = []
xlab = []
x = 1
for i, cl in enumerate(clusters_unq):
    # Sub-cluster ids may be stored as floats; label ticks with integers.
    xlab.append(int(cl))
    color = dict_bc_col[cl]
    true_frac = meas_z[i]
    _ = ax.plot(x, true_frac, marker=marker, ms=marker_size, color=color)
    # Empirical p value from THIS cluster's simulations. Compare raw counts so
    # both sides share a scale (z-scoring is monotone per cluster, so this is
    # equivalent to comparing meas_z to sim_z), which also works when the
    # simulated counts have zero spread and the z-score is undefined.
    sim_vals_i = sim_vals[i, :]
    meas_val_i = meas_vals[i]
    if meas_val_i > np.mean(sim_vals_i):
        # number of simulations with value greater than observed
        r_ = np.sum(sim_vals_i > meas_val_i)
    else:
        # number of simulations with value less than observed
        r_ = np.sum(sim_vals_i < meas_val_i)
    # P value
    p_ = r_ / n
    # Text location: above this cluster's upper whisker limit (Q3 + 1.5 IQR),
    # computed in z units because that is what is plotted.
    sim_z_i = sim_z[i, :]
    q1, q3 = np.quantile(sim_z_i, [0.25, 0.75])
    q4 = q3 + 1.5 * (q3 - q1)
    y = q4 if q4 > true_frac else true_frac
    y += text_dist
    ys.append(y)
    # Uncomment to annotate significant clusters with their p value:
    # if true_frac < true_frac_llim:
    #     t = ''
    # elif (p_ > 0.05):
    #     t = ''
    # elif (p_ > 0.001) and (p_ <= 0.05):
    #     t = str("p=" + str(p_))
    # else:
    #     t = str("p<0.001")
    # _ = ax.text(x, y, t, fontsize=ft, ha='left',va='bottom', rotation=pval_rotation, rotation_mode='anchor',
    #         color=line_col)
    x += 1
ax.set_xticklabels(xlab, rotation=xlab_rotation, ha='right', va='top', rotation_mode='anchor')
# ax.tick_params(axis="y", length=yticklength)
ax.spines["top"].set_color("none")
ax.spines["bottom"].set_color("none")
ax.spines["right"].set_color("none")

# ylims = ax.get_ylim()
# ax.set_ylim(ylims[0], np.max(ys) + ylimadj)
mge_assoc_dir = output_dir + "/mge_association"
if not os.path.exists(mge_assoc_dir):
    os.makedirs(mge_assoc_dir)
    print("Made dir:", mge_assoc_dir)

# out_bn = mge_assoc_dir + '/' + bn.format(m_h) + '_seg_nn_zscore_association_0_5um'
# ip.save_png_pdf(out_bn)

In [None]:
# Placeholder cell; `a` is not read anywhere in the visible notebook.
a = 1

In [None]:
# Get spectra for clusters
# Regroup spectra and labels using the updated cluster assignments. Both lists
# are appended in the same loop, so their orders match within each cluster.
dict_cl_spec = defaultdict(list)
dict_cl_lab = defaultdict(list)
for l, spec in dict_lab_spec.items():
    cl = dict_lab_clust[l]
    dict_cl_spec[cl].append(spec)
    dict_cl_lab[cl].append(l)

In [None]:
# cl = dict_sciname_bc['Selenomonas']
# cl_toplot=[12,13,14]

# For every cluster, plot mean-subtracted spectra (negatives clipped to 0) with
# reference lines at selected channels. specs_mean is defined elsewhere.
# Clusters without an entry in dict_bc_col raise KeyError here.
for cl in clusters_unq:
    print('Cluster:',cl)
    specs_arr = np.array(dict_cl_spec[cl])

    specs_meansub = specs_arr - specs_mean
    specs_meansub[specs_meansub < 0] = 0

    fig, ax = ip.general_plot(dims=spec_dims)
    fsi.plot_cell_spectra(ax, specs_meansub, {"lw": 1, "alpha": 0.2, "color": dict_bc_col[cl]})
    # plot a line
    xs = [3, 7, 10, 24, 27, 29, 31, 33, 43, 47]
    ylims = ax.get_ylim()
    for x in xs:
        ax.plot([x, x], ylims, "k")
    plt.show()
    plt.close()


# ax.set_ylim(0,2**16)
# plt.plot()
# plt.show()

# mge_assoc_dir = output_dir + "/mge_association"
# if not os.path.exists(mge_assoc_dir):
#     os.makedirs(mge_assoc_dir)
#     print("Made dir:", mge_assoc_dir)

# out_bn = mge_assoc_dir + '/' + bn.format(m_h) + '_clusts_5_spec'
# ip.save_png_pdf(out_bn)

### Subcluster target clusters

In [None]:

def cluster_spectra_agg(spec_cl, n_clust):
    """Cluster spectra with complete-linkage agglomerative clustering.

    Pairwise distances between the rows of ``spec_cl`` are computed with
    ``fhc.channel_cosine_intensity_allonev2`` and passed to scikit-learn's
    ``AgglomerativeClustering`` as a precomputed distance matrix.

    Parameters
    ----------
    spec_cl : array-like, shape (n_spectra, n_channels)
        One spectrum per row.
    n_clust : int
        Number of clusters to form.

    Returns
    -------
    numpy.ndarray
        Cluster label (0 .. n_clust - 1) for each row of ``spec_cl``.
    """
    dist_mat = squareform(pdist(spec_cl, fhc.channel_cosine_intensity_allonev2))
    # scikit-learn renamed ``affinity`` to ``metric`` in 1.2 and removed
    # ``affinity`` in 1.4. Try the new keyword first; older releases reject it
    # with TypeError, so fall back to the old one there.
    try:
        agg = AgglomerativeClustering(n_clusters=n_clust, metric='precomputed', linkage='complete')
    except TypeError:
        agg = AgglomerativeClustering(n_clusters=n_clust, affinity='precomputed', linkage='complete')
    agg.fit(dist_mat)
    return agg.labels_


In [None]:
# Split cluster `cl` into `n_clust` subclusters. Rows of spec_cl follow the
# same order as dict_cl_lab[cl].
cl = 2
n_clust = 4
spec_cl = np.vstack(dict_cl_spec[cl])

clust_agg_ = cluster_spectra_agg(spec_cl, n_clust)

In [None]:
# n_clust = np.max(clust_agg) + 1
# One colour per subcluster, sampled evenly from the tab10 colormap.
cmap='tab10'
colors = plt.get_cmap(cmap)(np.linspace(0,1,n_clust))
dict_clust_col_ = dict(zip(np.arange(n_clust), colors))

In [None]:
spec_cl.shape

In [None]:
spec_dims = (10, 5)


# Plot the raw spectra of each subcluster in its own figure.
for c in np.unique(clust_agg_):
    print("Cluster:", c)
    bool_c = clust_agg_ == c
    spec_sub = spec_cl[bool_c, :]

    # spec_sub_meansub = spec_sub - specs_med
    # spec_sub_meansub[spec_sub_meansub < 0] = 0

    fig, ax = ip.general_plot(dims=spec_dims)
    color = dict_clust_col_[c]
    fsi.plot_cell_spectra(ax, spec_sub, {"lw": 1, "alpha": 0.1, "color": color})
    # ax.set_ylim(0,12500)
    # ax.set_ylim(0,2**16)
    plt.plot()
    plt.show()

If the subclustering looks good, add it to the cluster results.

In [None]:
# Give the subclusters of `cl` brand-new cluster ids. Agglomerative labels
# start at 0, so offset them by (largest existing id + 1); offsetting by the
# largest id alone merged subcluster 0 into that existing cluster.
cl_offset = np.max(list(dict_lab_clust.values())) + 1
clnew = clust_agg_ + cl_offset

# Use a separate loop variable so the global `cl` (the cluster being split)
# is not overwritten. Previously, re-running this cell looked up the wrong
# cluster's labels.
for l, cl_sub in zip(dict_cl_lab[cl], clnew):
    dict_lab_clust[l] = cl_sub

print(np.unique(list(dict_lab_clust.values())))

In [None]:
# dict_cl_spec values are lists of spectra. Stack them into a 2-D array so
# NumPy-style slicing works; slicing a plain list with `[:,]` raises TypeError.
spec_cl_1 = np.vstack(dict_cl_spec[1])
# TODO(review): the channel condition that should define this mask is not
# written yet; `[:,]` currently returns every row and channel unchanged.
bool_prevotella = spec_cl_1[:,]

## Plot associated clusters on image

In [None]:
# MGE image
# Normalize the MGE spot image to mge_clims and build an RGBA overlay:
# R = B = value, G = 0 (magenta), alpha = value, so dim pixels are transparent.
# Assumes `norm` (defined outside this section) rescales into [0, 1] using the
# given limits -- TODO confirm.
mge_clims = (50,400)
raw_mge_shift_spot_norm = norm(raw_mge_shift_spot, mge_clims)
raw_mge_shift_spot_rgb = np.dstack([
    raw_mge_shift_spot_norm,
    np.zeros_like(raw_mge_shift_spot_norm),
    raw_mge_shift_spot_norm,
    raw_mge_shift_spot_norm
])

In [None]:
clusters_toplot = clusters_unq

hipr_seg_res_clustrgb = np.zeros(hipr_seg_res_edge.shape + (4,))
# scinames_toplot = ['Selenomonas', 'Lautropia']
# clusters_toplot = [dict_sciname_bc[sc] for sc in scinames_toplot]
# Paint every segmented cell of each cluster in clusters_toplot, in that
# cluster's colour (opaque), into an RGBA overlay the size of the segmentation.
for cl in clusters_toplot:
    # A set gives O(1) membership tests; `l in <list>` scanned the whole label
    # list once per region row.
    labels_sub = set(dict_clust_lab[cl])
    for i, row in hipr_prop_res.iterrows():
        l = row.label
        if l in labels_sub:
            b = row.bbox
            # bbox may be a tuple or its string repr (presumably when props were
            # read back from a table). NOTE(review): eval on file contents --
            # ast.literal_eval would be safer.
            b = eval(b) if isinstance(b, str) else b
            rgb_sub = hipr_seg_res_clustrgb[b[0]:b[2],b[1]:b[3]]
            seg_sub = hipr_seg_res_edge[b[0]:b[2],b[1]:b[3]] == l
            # im_sub = im_norm[b[0]:b[2],b[1]:b[3]]
            # color = np.array([0,0,1,0.5])
            color = np.array(dict_bc_col[cl] + (1,))
            # Keep existing overlay pixels outside this cell and replace the
            # pixels inside it with the cluster colour.
            # rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*3 + np.ones_like(seg_sub))
            rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*4)
            # rgb_cell = (im_sub * seg_sub)[:,:,None] * color[None,:]
            rgb_cell = seg_sub[:,:,None] * color[None,:]
            hipr_seg_res_clustrgb[b[0]:b[2],b[1]:b[3],:] = rgb_noncell + rgb_cell

In [None]:
# Create the MGE overlay output directory if needed, then draw the grayscale
# cell image, the cluster-coloured segmentation overlay and the MGE spot overlay.
mge_overlay_dir = output_dir + '/mge_overlay'
if not os.path.exists(mge_overlay_dir): 
    os.makedirs(mge_overlay_dir)
    print('Made dir:',mge_overlay_dir)

fig, ax, cbar = ip.plot_image(im_cell, cmap='gray', im_inches=30, clims=(0,2000), scalebar_resolution=res_mge_umpix)
ax.imshow(hipr_seg_res_clustrgb)
ax.imshow(raw_mge_shift_spot_rgb)

# plt.figure(fig)
# out_bn = mge_overlay_dir + '/' + bn.format(m_h) + '_clusts_0_overlay'
# ip.save_png_pdf(out_bn)

Color the segmented cells of selected clusters by spectral intensity

In [None]:
# Max, min and mean (over cells) of each cell's peak spectral value --
# presumably used to choose `clip` / `thresh_int` in the cells below.
print(np.max([np.max(s) for s in dict_lab_spec.values()]))
print(np.min([np.max(s) for s in dict_lab_spec.values()]))
print(np.mean([np.max(s) for s in dict_lab_spec.values()]))

In [None]:
clip = 500

hipr_seg_res_clustrgb_int = np.zeros(hipr_seg_res_edge.shape + (4,))
clusters_toplot = [2,3]
# Paint cells of the selected clusters in their cluster colour, scaled (RGB and
# alpha) by the cell's peak spectral value relative to `clip`, saturating at 1.
for cl in clusters_toplot:
    # Set for O(1) membership tests against every region row.
    labels_sub = set(dict_clust_lab[cl])
    for i, row in hipr_prop_res.iterrows():
        l = row.label
        if l in labels_sub:
            b = row.bbox
            # NOTE(review): eval on file contents -- ast.literal_eval would be safer.
            b = eval(b) if isinstance(b, str) else b
            rgb_sub = hipr_seg_res_clustrgb_int[b[0]:b[2],b[1]:b[3]]
            seg_sub = hipr_seg_res_edge[b[0]:b[2],b[1]:b[3]] == l
            # im_sub = im_norm[b[0]:b[2],b[1]:b[3]]
            # color = np.array([0,0,1,0.5])
            color = np.array(dict_bc_col[cl] + (1,))
            mx_int = np.max(dict_lab_spec[l])
            # Fraction of `clip`, capped at 1.
            mx_int = min(mx_int / clip, 1)
            color = color * mx_int
            # rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*3 + np.ones_like(seg_sub))
            rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*4)
            # rgb_cell = (im_sub * seg_sub)[:,:,None] * color[None,:]
            rgb_cell = seg_sub[:,:,None] * color[None,:]
            hipr_seg_res_clustrgb_int[b[0]:b[2],b[1]:b[3],:] = rgb_noncell + rgb_cell

In [None]:
# Draw the cell image with the intensity-scaled cluster overlay and MGE spots.
mge_overlay_dir = output_dir + '/mge_overlay'
if not os.path.exists(mge_overlay_dir): 
    os.makedirs(mge_overlay_dir)
    print('Made dir:',mge_overlay_dir)

fig, ax, cbar = ip.plot_image(im_cell, cmap='gray', im_inches=imin, clims=(0,2000))
ax.imshow(hipr_seg_res_clustrgb_int)
ax.imshow(raw_mge_shift_spot_rgb)
plt.figure(fig)
# NOTE(review): out_bn is built but not saved in this cell, and its "3_4"
# suffix does not match clusters_toplot = [2,3] used to build the overlay.
out_bn = mge_overlay_dir + '/' + bn.format(m_h) + '_clusts_3_4_int_overlay'

Indicate which cells have low intensity

In [None]:
thresh_int = 250

hipr_seg_res_clustrgb_thr = np.zeros(hipr_seg_res_edge.shape + (4,))
clusters_toplot = [3,5]
# Paint cells of the selected clusters in their cluster colour, except cells
# whose peak spectral value is below thresh_int, which are painted yellow.
for cl in clusters_toplot:
    # Set for O(1) membership tests against every region row.
    labels_sub = set(dict_clust_lab[cl])
    for i, row in hipr_prop_res.iterrows():
        l = row.label
        if l in labels_sub:
            b = row.bbox
            # NOTE(review): eval on file contents -- ast.literal_eval would be safer.
            b = eval(b) if isinstance(b, str) else b
            rgb_sub = hipr_seg_res_clustrgb_thr[b[0]:b[2],b[1]:b[3]]
            seg_sub = hipr_seg_res_edge[b[0]:b[2],b[1]:b[3]] == l
            # im_sub = im_norm[b[0]:b[2],b[1]:b[3]]
            # color = np.array([0,0,1,0.5])
            # Use the same cluster -> colour map as the other overlay cells
            # (`col_dict` is not defined alongside them).
            color = np.array(dict_bc_col[cl] + (1,))
            mx_int = np.max(dict_lab_spec[l])
            if mx_int < thresh_int:
                color = np.array([1,1,0,1])
            # rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*3 + np.ones_like(seg_sub))
            rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*4)
            # rgb_cell = (im_sub * seg_sub)[:,:,None] * color[None,:]
            rgb_cell = seg_sub[:,:,None] * color[None,:]
            hipr_seg_res_clustrgb_thr[b[0]:b[2],b[1]:b[3],:] = rgb_noncell + rgb_cell

In [None]:
# Draw the cell image with the low-intensity-flagged overlay and MGE spots.
mge_overlay_dir = output_dir + '/mge_overlay'
if not os.path.exists(mge_overlay_dir): 
    os.makedirs(mge_overlay_dir)
    print('Made dir:',mge_overlay_dir)

fig, ax, cbar = ip.plot_image(im_cell, cmap='gray', im_inches=imin, clims=(0,2000))
ax.imshow(hipr_seg_res_clustrgb_thr)
ax.imshow(raw_mge_shift_spot_rgb)
plt.figure(fig)
# NOTE(review): the "3_4" suffix does not match clusters_toplot = [3,5] used
# to build the overlay; saving is currently disabled.
out_bn = mge_overlay_dir + '/' + bn.format(m_h) + '_clusts_3_4_thr_250_overlay'
# ip.save_png_pdf(out_bn)


Group clusters

In [None]:
hipr_seg_res_clustrgb_sing = np.zeros(hipr_seg_res_edge.shape + (4,))
clusters_toplot = [3,5]
# Paint all cells of the selected clusters in one shared colour (the first
# tab10 colour, opaque) so the clusters read as a single group.
# The colour does not depend on the cell, so compute it once.
color = np.array(plt.get_cmap('tab10').colors[0] + (1,))
for cl in clusters_toplot:
    # Set for O(1) membership tests against every region row.
    labels_sub = set(dict_clust_lab[cl])
    for i, row in hipr_prop_res.iterrows():
        l = row.label
        if l in labels_sub:
            b = row.bbox
            # NOTE(review): eval on file contents -- ast.literal_eval would be safer.
            b = eval(b) if isinstance(b, str) else b
            rgb_sub = hipr_seg_res_clustrgb_sing[b[0]:b[2],b[1]:b[3]]
            seg_sub = hipr_seg_res_edge[b[0]:b[2],b[1]:b[3]] == l
            # im_sub = im_norm[b[0]:b[2],b[1]:b[3]]
            # color = np.array(col_dict[cl] + (1,))
            # rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*3 + np.ones_like(seg_sub))
            rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*4)
            # rgb_cell = (im_sub * seg_sub)[:,:,None] * color[None,:]
            rgb_cell = seg_sub[:,:,None] * color[None,:]
            hipr_seg_res_clustrgb_sing[b[0]:b[2],b[1]:b[3],:] = rgb_noncell + rgb_cell

In [None]:
# Draw the cell image with the single-colour cluster overlay and MGE spots.
mge_overlay_dir = output_dir + '/mge_overlay'
if not os.path.exists(mge_overlay_dir): 
    os.makedirs(mge_overlay_dir)
    print('Made dir:',mge_overlay_dir)

fig, ax, cbar = ip.plot_image(im_cell, cmap='gray', im_inches=imin, clims=(0,2000), scalebar_resolution=res_mge_umpix)
ax.imshow(hipr_seg_res_clustrgb_sing)
ax.imshow(raw_mge_shift_spot_rgb)
plt.figure(fig)
out_bn = mge_overlay_dir + '/' + bn.format(m_h) + '_clusts_3_4_singlecolor_overlay'
# ip.save_png_pdf(out_bn)


## Look for spectra near spots

Plot the fraction of spots that have a cell within a given radius

In [None]:
# Nearest neighbor of all spots
# For each spot centroid, find the single closest segmented-cell centroid.
n_neighbors=1 
reseg_coords = [list(c) for c in hipr_prop_res.centroid.values]
nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(reseg_coords)
dists, inds = nbrs.kneighbors(spot_coords)

In [None]:
# fraction with close cells
# Fraction of spots whose nearest cell centroid lies within r_um. Dividing
# r_um (presumably micrometres) by res_mge_umpix (presumably um per pixel)
# converts the radius to pixels.
nspots = len(dists)
nclose = np.sum(dists <= (r_um/res_mge_umpix))
frac = nclose/nspots
print(nspots)
print(nclose)
print(frac)

In [None]:
# Same fraction over 20 radii from 0 to 2 (same units as r_um), plotted vs radius.
fracs = []
nspots = len(dists)
rs = np.linspace(0,2,20)
for r in rs:
    nclose = np.sum(dists <= (r/res_mge_umpix))
    fracs.append(nclose/nspots)
fig, ax = ip.general_plot(dims=(10,5))
ax.plot(rs, fracs)

Get nearest single neighbor spectra

In [None]:
# Label and spectrum of the single nearest cell for every spot.
labels = hipr_prop_res.label.values
# Take the nearest-neighbour column explicitly. `inds.squeeze()` collapses to
# a 0-d scalar when there is only one spot (breaking the iteration below) and
# keeps a 2-D array if inds has more than one neighbour column.
labels_nn = labels[inds[:, 0]]
specs_nn = np.vstack([dict_lab_spec[l] for l in labels_nn])
specs_nn.shape

Plot spectra

In [None]:
# Overlay all nearest-neighbour spectra in red.
fig, ax = ip.general_plot(dims=spec_dims)
fsi.plot_cell_spectra(ax, specs_nn, {'lw':1,'alpha':0.1,'color':'r'})

Split spectra by intensity

In [None]:
# For every sample name, collect the spectra YAML files matching
# <output_dir>/<sample>/spectra/<sample>_M_*_spec.yaml.
spec_fns = {}
bn_ = '{}_M_{}'
for sn_ in dict_group_czifns.keys():

    output_dir_ = config['output_dir'] + '/' + sn_

    spec_dir_ = output_dir_ + '/spectra'
    bn_m = bn_.format(sn_,'*')

    spec_fn_ = spec_dir_ + '/' + bn_m + '_spec.yaml'

    # glob returns files in arbitrary, filesystem-dependent order. Sort so the
    # file order, and everything indexed from it, is reproducible.
    spec_fns[sn_] = sorted(glob.glob(spec_fn_))
spec_fns

In [None]:
# For each sample, the digits following "_M_" in each spectra filename, in the
# same order as spec_fns[sn_]. The raw string keeps "\d" a regex escape rather
# than an invalid string escape (DeprecationWarning, SyntaxWarning on 3.12+).
Ms = {sn_: [re.findall(r'(?<=_M_)\d+', f)[0] for f in fns_] for sn_, fns_ in spec_fns.items()}
Ms

In [None]:

## GENERALLY DONT RERUN ##

# specs_arr = []
# dict_sn_m_idx_lab = defaultdict(dict)
# for sn_, s_fns in spec_fns.items():
#     # dict_sn_sni[sn] = sn_i
#     print(sn_)
#     ms = Ms[sn_]
#     for s_fn, m_ in zip(s_fns, ms):
#         # print(s_fn)
#         with open(s_fn, 'r') as f:
#             dict_lab_spec_ = yaml.unsafe_load(f)
#         print(len(dict_lab_spec_))
#         labels = list(dict_lab_spec_.keys())
#         idx = np.arange(len(dict_lab_spec_)) + len(specs_arr)
#         dict_sn_m_idx_lab[sn_][m_] = dict(zip(idx,labels))
#         specs_arr += [s[None,:] for s in dict_lab_spec_.values()]


In [None]:
# Keep the first 57 channels of each spectrum (presumably the channel count
# shared by all files -- TODO confirm) and stack into one 2-D array.
# Assumes specs_arr is still the list of (1, n_channels) arrays from the
# loading cell above; other cells rebind specs_arr to a 2-D array, and
# `s[:,:57]` on its 1-D rows raises IndexError.
specs_arr = [s[:,:57] for s in specs_arr]
specs_arr = np.vstack(specs_arr)
specs_arr.shape

In [None]:
spec_dims = (10,5)

n = 10000  # subsample spectra
# Sampling without replacement raises ValueError when the sample size exceeds
# the population, so cap the subsample at the number of spectra available.
n_sub = min(n, specs_arr.shape[0])
idx = np.random.choice(np.arange(specs_arr.shape[0]), size=n_sub, replace=False)
specs_arr_rnd = specs_arr[idx,:]

# Summary spectra over all cells: mean (black), median (blue) and mode
# (green), drawn over the random subsample (red).
specs_mean = np.mean(specs_arr, axis=0)
specs_med = np.median(specs_arr, axis=0)
# specs_std = np.std(specs_arr, axis=0)
specs_mode = stats.mode(specs_arr, axis=0)[0].squeeze()

fig, ax = ip.general_plot(dims=spec_dims)
fsi.plot_cell_spectra(ax, specs_arr_rnd, {'lw':1,'alpha':0.1,'color':'r'})
fsi.plot_cell_spectra(ax, specs_mean[None,:], {'lw':2,'alpha':1,'color':'k'})
fsi.plot_cell_spectra(ax, specs_med[None,:], {'lw':2,'alpha':1,'color':'b'})
fsi.plot_cell_spectra(ax, specs_mode[None,:], {'lw':2,'alpha':1,'color':'g'})
# ax.set_ylim(0,2**16)
plt.plot()
plt.show()

In [None]:
# Keep nearest-neighbour cells whose spectrum exceeds the global mean spectrum
# (specs_mean) by more than thresh_int in more than one channel.
thresh_int = 75

specs_nn_meansub = specs_nn - specs_mean
bool_int_all = specs_nn_meansub > thresh_int
bool_int_counts = np.sum(bool_int_all, axis=1)
bool_int = bool_int_counts > 1
# Number of cells before and after the filter.
print(specs_nn_meansub.shape[0])
print(sum(bool_int))

Plot pre-filtered cells

In [None]:
clusters_toplot = [2,3]

hipr_seg_res_clustrgb = np.zeros(hipr_seg_res_edge.shape + (4,))
# scinames_toplot = ['Selenomonas', 'Lautropia']
# clusters_toplot = [dict_sciname_bc[sc] for sc in scinames_toplot]
# for cl in clusters_toplot:
# Paint every nearest-neighbour cell (before intensity filtering) in one colour.
# `in` on a NumPy array compares against every element; a set is O(1).
labels_sub = set(labels_nn)
# Same colour for every cell, so compute it once.
color = np.array(col_list_re[2] + (1,))
for i, row in hipr_prop_res.iterrows():
    l = row.label
    if l in labels_sub:
        b = row.bbox
        # NOTE(review): eval on file contents -- ast.literal_eval would be safer.
        b = eval(b) if isinstance(b, str) else b
        rgb_sub = hipr_seg_res_clustrgb[b[0]:b[2],b[1]:b[3]]
        seg_sub = hipr_seg_res_edge[b[0]:b[2],b[1]:b[3]] == l
        # im_sub = im_norm[b[0]:b[2],b[1]:b[3]]
        # color = np.array([0,0,1,0.5])
        # color = np.array(dict_bc_col[cl] + (1,))
        # rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*3 + np.ones_like(seg_sub))
        rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*4)
        # rgb_cell = (im_sub * seg_sub)[:,:,None] * color[None,:]
        rgb_cell = seg_sub[:,:,None] * color[None,:]
        hipr_seg_res_clustrgb[b[0]:b[2],b[1]:b[3],:] = rgb_noncell + rgb_cell

In [None]:
# Cell image with all nearest-neighbour cells highlighted, plus MGE spots.
fig, ax, cbar = ip.plot_image(im_cell, cmap='gray', im_inches=30, clims=(0,2000), scalebar_resolution=res_mge_umpix)
ax.imshow(hipr_seg_res_clustrgb)
ax.imshow(raw_mge_shift_spot_rgb)

Plot post-filtered cells

In [None]:
clusters_toplot = [2,3]

hipr_seg_res_clustrgb = np.zeros(hipr_seg_res_edge.shape + (4,))
# scinames_toplot = ['Selenomonas', 'Lautropia']
# clusters_toplot = [dict_sciname_bc[sc] for sc in scinames_toplot]
# for cl in clusters_toplot:
# Paint only the nearest-neighbour cells that passed the intensity filter.
# A set gives O(1) membership instead of scanning the array per row.
labels_sub = set(labels_nn[bool_int])
# Same colour for every cell, so compute it once.
color = np.array(col_list_re[2] + (1,))
for i, row in hipr_prop_res.iterrows():
    l = row.label
    if l in labels_sub:
        b = row.bbox
        # NOTE(review): eval on file contents -- ast.literal_eval would be safer.
        b = eval(b) if isinstance(b, str) else b
        rgb_sub = hipr_seg_res_clustrgb[b[0]:b[2],b[1]:b[3]]
        seg_sub = hipr_seg_res_edge[b[0]:b[2],b[1]:b[3]] == l
        # im_sub = im_norm[b[0]:b[2],b[1]:b[3]]
        # color = np.array([0,0,1,0.5])
        # color = np.array(dict_bc_col[cl] + (1,))
        # rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*3 + np.ones_like(seg_sub))
        rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*4)
        # rgb_cell = (im_sub * seg_sub)[:,:,None] * color[None,:]
        rgb_cell = seg_sub[:,:,None] * color[None,:]
        hipr_seg_res_clustrgb[b[0]:b[2],b[1]:b[3],:] = rgb_noncell + rgb_cell

In [None]:
# Cell image with intensity-filtered nearest-neighbour cells highlighted, plus MGE spots.
fig, ax, cbar = ip.plot_image(im_cell, cmap='gray', im_inches=30, clims=(0,2000), scalebar_resolution=res_mge_umpix)
ax.imshow(hipr_seg_res_clustrgb)
ax.imshow(raw_mge_shift_spot_rgb)

Get nearest n neighbors

In [None]:
# Nearest neighbor of all spots
# Now find the 3 closest cell centroids for each spot.
n_neighbors=3
reseg_coords = [list(c) for c in hipr_prop_res.centroid.values]
nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(reseg_coords)
dists, inds = nbrs.kneighbors(spot_coords)

In [None]:
# Flatten the (n_spots, 3) neighbour indices into one list of cell labels and
# gather their spectra. A cell near several spots appears more than once.
labels = hipr_prop_res.label.values
labels_nn = labels[inds.ravel()]
specs_nn = np.vstack([dict_lab_spec[l] for l in labels_nn])
specs_nn.shape

In [None]:
fig, ax = ip.general_plot(dims=spec_dims)
fsi.plot_cell_spectra(ax, specs_nn, {'lw':1,'alpha':0.1,'color':'r'})

Cluster nearest neighbor spectra

In [None]:
# Condensed pairwise distances between neighbour spectra, using the project's
# channel_cosine_intensity_allonev2 metric.
dist_mat_cond = pdist(specs_nn, fhc.channel_cosine_intensity_allonev2)

In [None]:
dist_mat = squareform(dist_mat_cond)


In [None]:
linkage = hierarchy.linkage(dist_mat_cond, method='complete')


In [None]:
# Dendrogram labelled by cell label. Note: this rebinds the global `labels`
# (previously hipr_prop_res.label.values) to labels_nn.
labels = labels_nn
fig, ax = ip.general_plot(dims=(15,15))
dn = hierarchy.dendrogram(linkage, labels=labels, truncate_mode=None)
# dn = hierarchy.dendrogram(linkage, labels=inds, truncate_mode='lastp')
# ax.axhline(t[i])
# Horizontal grid every 0.005 distance units (presumably to pick a cut height).
ylims = [round(l,2) for l in ax.get_ylim()]
ax.set_yticks(np.arange(ylims[0],ylims[1],0.005))
ax.grid(axis='y')
plt.show()
plt.close()

In [None]:

# Reorder spectra into dendrogram leaf order (dn['ivl'] holds the leaf labels)
# and show them as a heatmap beside a left-oriented dendrogram.
spec_arr_order = []
for l in dn['ivl']:
    s = dict_lab_spec[l]
    spec_arr_order.append(s)
    
spec_arr_order = np.vstack(spec_arr_order)



fig = plt.figure(figsize=(15,5))
axs = {
    0: fig.add_axes([0.1,0.1,0.1,0.8]),
    1: fig.add_axes([0.2,0.1,0.7,0.8])
}
hierarchy.dendrogram(linkage, ax=axs[0], labels=labels, orientation='left', no_labels=True, truncate_mode=None)
# imshow draws row 0 at the top; flip rows to line up with the dendrogram leaves.
axs[1].imshow(np.flip(spec_arr_order, axis=0), cmap='inferno', aspect='auto')
axs[1].axis('off')
axs[0].axis('off')
plt.show()
plt.close()

In [None]:
# 2-D UMAP embedding of the precomputed distance matrix, for visual inspection.
fit = umap.UMAP(metric='precomputed', n_neighbors=100, min_dist=0.1).fit(dist_mat)
u = fit.embedding_
plt.scatter(u[:,0], u[:,1], alpha=0.5)
plt.show()
plt.close()

In [None]:
# Cut the neighbour spectra into n_clust complete-linkage clusters using the
# precomputed distance matrix.
n_clust = 6
# scikit-learn renamed ``affinity`` to ``metric`` in 1.2 and removed
# ``affinity`` in 1.4; older releases reject ``metric`` with TypeError.
try:
    agg = AgglomerativeClustering(n_clusters=n_clust, metric='precomputed', linkage='complete')
except TypeError:
    agg = AgglomerativeClustering(n_clusters=n_clust, affinity='precomputed', linkage='complete')

agg.fit(dist_mat)

clust_agg = agg.labels_

In [None]:
# UMAP embedding coloured by agglomerative cluster.
plt.scatter(u[:,0], u[:,1], c=clust_agg, alpha=0.1, cmap='tab10')
plt.show()
plt.close()

In [None]:
# One RGBA colour per cluster id 0..n_clust-1, sampled evenly from tab10.
colors = plt.get_cmap('tab10')(np.linspace(0,1,n_clust))
dict_clust_col = dict(zip(np.arange(n_clust), colors))

In [None]:
# Cell label -> neighbour cluster. A label repeated in labels_nn keeps the
# value of its last occurrence.
dict_lab_nnclust = dict(zip(labels_nn, clust_agg))

In [None]:
# Same dendrogram + heatmap as above, with an extra column showing each leaf's cluster colour.
spec_arr_order = []
cluster_bar = []
for l in dn['ivl']:
    s = dict_lab_spec[l]
    spec_arr_order.append(s)
    cl = dict_lab_nnclust[l]
    c = list(dict_clust_col[cl])
    cluster_bar.append(c)

spec_arr_order = np.vstack(spec_arr_order)

# Reshape to (n_leaves, 1, n_colour_channels): a one-pixel-wide image for imshow.
cluster_bar = np.array(cluster_bar).reshape(len(cluster_bar),1,len(cluster_bar[0]))
cluster_bar.shape

fig = plt.figure(figsize=(15,5))
axs = {
    0: fig.add_axes([0.1,0.1,0.08,0.8]),
    1: fig.add_axes([0.18,0.1,0.02,0.8]),
    2: fig.add_axes([0.2,0.1,0.7,0.8])
}
hierarchy.dendrogram(linkage, ax=axs[0], labels=labels, orientation='left', no_labels=True, truncate_mode=None)
axs[1].imshow(np.flip(cluster_bar, axis=0), aspect='auto')
axs[2].imshow(np.flip(spec_arr_order, axis=0), cmap='inferno', aspect='auto')
axs[0].axis('off')
axs[1].axis('off')
axs[2].axis('off')
plt.show()
plt.close()

In [None]:
spec_dims = (10,5)


# Raw spectra of each neighbour cluster, on a fixed 0-1000 y range.
for c in np.unique(clust_agg):
    print('Cluster:', c)
    bool_c = clust_agg == c
    spec_sub = specs_nn[bool_c,:]

    # spec_sub_meansub = spec_sub - specs_med
    # spec_sub_meansub[spec_sub_meansub < 0] = 0

    fig, ax = ip.general_plot(dims=spec_dims)
    color = dict_clust_col[c]
    fsi.plot_cell_spectra(ax, spec_sub, {'lw':1,'alpha':0.2,'color':color})
    ax.set_ylim(0,1000)
    # ax.set_ylim(0,2**16)
    plt.plot()
    plt.show()

In [None]:

hipr_seg_res_clustrgb = np.zeros(hipr_seg_res_edge.shape + (4,))
# scinames_toplot = ['Selenomonas', 'Lautropia']
# clusters_toplot = [dict_sciname_bc[sc] for sc in scinames_toplot]
# Paint each nearest-neighbour cell in the colour of its neighbour cluster.
for cl in np.unique(clust_agg):
    # Set for O(1) membership tests; `in` on a NumPy array scans every element.
    labels_sub = set(labels_nn[clust_agg == cl])
    for i, row in hipr_prop_res.iterrows():
        l = row.label
        if l in labels_sub:
            b = row.bbox
            # NOTE(review): eval on file contents -- ast.literal_eval would be safer.
            b = eval(b) if isinstance(b, str) else b
            rgb_sub = hipr_seg_res_clustrgb[b[0]:b[2],b[1]:b[3]]
            seg_sub = hipr_seg_res_edge[b[0]:b[2],b[1]:b[3]] == l
            # im_sub = im_norm[b[0]:b[2],b[1]:b[3]]
            # color = np.array([0,0,1,0.5])
            # dict_clust_col values are already RGBA arrays from the colormap.
            color = dict_clust_col[cl]
            # color = np.array(dict_clust_col[cl] + (1,))
            # rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*3 + np.ones_like(seg_sub))
            rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*4)
            # rgb_cell = (im_sub * seg_sub)[:,:,None] * color[None,:]
            rgb_cell = seg_sub[:,:,None] * color[None,:]
            hipr_seg_res_clustrgb[b[0]:b[2],b[1]:b[3],:] = rgb_noncell + rgb_cell

In [None]:
# Cell image alone (axes shown, for reading pixel coordinates).
fig, ax, cbar = ip.plot_image(im_cell, cmap='gray', im_inches=30, clims=(0,2000), scalebar_resolution=res_mge_umpix, axes_off=False)

In [None]:
# Cell image with MGE spots.
fig, ax, cbar = ip.plot_image(im_cell, cmap='gray', im_inches=30, clims=(0,2000), scalebar_resolution=res_mge_umpix, axes_off=False)
ax.imshow(raw_mge_shift_spot_rgb)

In [None]:
# Cell image with neighbour-cluster overlay and MGE spots.
fig, ax, cbar = ip.plot_image(im_cell, cmap='gray', im_inches=30, clims=(0,2000), scalebar_resolution=res_mge_umpix)
ax.imshow(hipr_seg_res_clustrgb)
ax.imshow(raw_mge_shift_spot_rgb)

Look at spectra relative to the median spectrum

In [None]:
spec_dims = (10,5)


# Each neighbour cluster's spectra after subtracting the median spectrum
# (specs_med) and clipping negatives to 0, on a fixed 0-600 y range.
for c in np.unique(clust_agg):
    print('Cluster:', c)
    bool_c = clust_agg == c
    spec_sub = specs_nn[bool_c,:]

    spec_sub_meansub = spec_sub - specs_med
    spec_sub_meansub[spec_sub_meansub < 0] = 0

    fig, ax = ip.general_plot(dims=spec_dims)
    color = dict_clust_col[c]
    fsi.plot_cell_spectra(ax, spec_sub_meansub, {'lw':1,'alpha':0.1,'color':color})
    ax.set_ylim(0,600)
    # ax.set_ylim(0,2**16)
    plt.plot()
    plt.show()

Zoom in on associations

In [None]:
# Zoom window: c is the (row, column) corner and d the (height, width) in pixels.
c=[2900,1750]
d=[1000,1000]

In [None]:
# y limits are given as (bottom, top) = (larger row, smaller row) to keep the
# image orientation (row 0 at the top).
fig, ax, cbar = ip.plot_image(im_cell, cmap='gray', im_inches=30, clims=(0,2000), scalebar_resolution=res_mge_umpix)
ax.set_ylim([c[0]+d[0], c[0]])
ax.set_xlim([c[1], c[1]+d[1]])

In [None]:
fig, ax, cbar = ip.plot_image(im_cell, cmap='gray', im_inches=30, clims=(0,2000), scalebar_resolution=res_mge_umpix)
ax.imshow(raw_mge_shift_spot_rgb)
ax.set_ylim([c[0]+d[0], c[0]])
# ax.set_ylim([c[0], c[0]+d[0]])
ax.set_xlim([c[1], c[1]+d[1]])

In [None]:
fig, ax, cbar = ip.plot_image(im_cell, cmap='gray', im_inches=30, clims=(0,2000), scalebar_resolution=res_mge_umpix)
ax.imshow(hipr_seg_res_clustrgb)
ax.imshow(raw_mge_shift_spot_rgb)
ax.set_ylim([c[0]+d[0], c[0]])
# ax.set_ylim([c[0], c[0]+d[0]])
ax.set_xlim([c[1], c[1]+d[1]])

## Get nearest clusters to spots

Get resized hipr properties

In [None]:
def add_edge(hipr_sum_res, edge):
    """Zero-pad the two leading (row, column) axes of an image by ``edge`` pixels.

    Parameters
    ----------
    hipr_sum_res : array-like
        Image with at least two dimensions. Any trailing axes (for example
        colour or spectral channels) are copied through unpadded.
    edge : int
        Number of zero pixels added on each side of both spatial axes.

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(h + 2*edge, w + 2*edge, *rest)`` with the
        input placed at offset ``(edge, edge)``.
    """
    hipr_sum_res = np.asarray(hipr_sum_res)
    h, w = hipr_sum_res.shape[:2]
    # Pad only the spatial axes. Adding 2*edge to every axis (as before) also
    # grew any channel axis, and the assignment below then failed to broadcast.
    out_shape = (h + 2 * edge, w + 2 * edge) + hipr_sum_res.shape[2:]
    hipr_sum_res_edge = np.zeros(out_shape)
    hipr_sum_res_edge[edge:edge + h, edge:edge + w] = hipr_sum_res
    return hipr_sum_res_edge

In [None]:
# Pad the resized HiPRFISH sum image and segmentation for index m_h by `edge`
# pixels per side (add_edge), then measure region properties on the padded
# segmentation using the padded sum image as intensity.
# hipr_res_m = hipr_res[m_h]
# hipr_res_m_edge = add_edge(hipr_res_m, edge)

hipr_sum_res = hipr_sums_res[m_h]
hipr_sum_res_edge = add_edge(hipr_sum_res, edge)

hipr_seg_res = hipr_segs_res[m_h]
hipr_seg_res_edge = add_edge(hipr_seg_res, edge)
# add_edge returns floats; cast back to integer labels.
hipr_seg_res_edge = hipr_seg_res_edge.astype(int)

hipr_prop_res = sf.measure_regionprops(hipr_seg_res_edge, raw=hipr_sum_res_edge)
hipr_prop_res.columns

In [None]:
# Show the padded segmentation (random label colours) with the MGE spot channel
# overlaid. Note: this rebinds raw_mge_shift_spot(_rgb) to channel 1 of
# raws_mge_shift[m_h], normalized to (0,150), with all four RGBA channels equal.
# Overlay cells run afterwards use this version instead of the magenta one
# built earlier.
hipr_seg_res_edge_rgb = ip.seg2rgb(hipr_seg_res_edge)
fig, ax, _ = ip.plot_image(hipr_seg_res_edge_rgb, im_inches=imin)
raw_mge_shift_spot = raws_mge_shift[m_h][:,:,1]
raw_mge_shift_spot_norm = norm(raw_mge_shift_spot, (0,150))
raw_mge_shift_spot_rgb = np.dstack([raw_mge_shift_spot_norm]*4)
ax.imshow(raw_mge_shift_spot_rgb)

Nearest n neighbors

In [None]:
# Get spot coordinates
# Use the centroids of all detected spots (`prop`, defined outside this
# section); the commented lines restrict to area- or in-cell-filtered spots.
# spot_coords = prop.centroid.values[bool_area]
# spot_coords = prop.centroid.values[bool_incell]
spot_coords = prop.centroid.values
# spot_coords = prop.centroid.values[bool_area * bool_incell]
spot_coords = [list(s) for s in spot_coords]
print(prop.shape[0])
print(len(spot_coords))

In [None]:
# Nearest neighbor of all spots
# Refit on the padded-segmentation centroids. The simulation cell below reuses
# this `nbrs` for the random spots.
n_neighbors=1
reseg_coords = [list(c) for c in hipr_prop_res.centroid.values]
nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(reseg_coords)
dists, inds = nbrs.kneighbors(spot_coords)

Get clusters

In [None]:
# load merged clusters
# Load the label -> cluster map for sample `sn`, index m_h, from the merged clustering.
# NOTE(review): yaml.unsafe_load can construct arbitrary Python objects; only
# use it on files this pipeline wrote.
clust_dir_hires = clust_dir + '/merge'
clust_sn_fn = clust_dir_hires + '/' + sn + '_M_' + str(m_h) + '_clust_merge.yaml'
print(clust_sn_fn)
with open(clust_sn_fn, 'r') as f:
    dict_lab_clust = yaml.unsafe_load(f)
print(len(dict_lab_clust))
clusters_unq = np.unique(list(dict_lab_clust.values()))
clusters_unq

Get the cluster ID of each neighbor, merging repeated clusters for a given spot


In [None]:
# Map each spot's neighbour indices to cell labels, then to cluster ids.
# np.unique collapses repeats, so each spot counts a neighbouring cluster once.
labels = hipr_prop_res.label.values
labels_nn = [labels[i] for i in inds]
clusts_nn = [[dict_lab_clust[l] for l in ls] for ls in labels_nn]
clusts_nn_unq = [np.unique(cl) for cl in clusts_nn]


Counts for each cluster

In [None]:
# Get counts for clusters 
# Number of spots that have each cluster among their nearest neighbours.
clusts_nn_unq_flat = [cl for cln in clusts_nn_unq for cl in cln]
cl_nn_unq, cl_nn_counts = np.unique(clusts_nn_unq_flat, return_counts=True)

print(cl_nn_counts)
print(sum(cl_nn_counts))
print(cl_nn_unq)

In [None]:
# cl_nn_counts_sort = [cl for _, cl in sorted(zip(cl_nn_unq, cl_nn_counts))]
# cl_nn_unq_sort = sorted(cl_nn_unq)
# print(cl_nn_counts_sort)
# print(cl_nn_unq_sort)

In [None]:
# Measured spot-neighbour count for every cluster in clusters_unq (same order).
# Clusters with no nearby spot are missing from dict_cl_nncounts and get 0;
# `.get` replaces the bare `except: pass`, which also hid unrelated errors.
dict_cl_nncounts = dict(zip(cl_nn_unq, cl_nn_counts))
meas_counts_arr = np.array([dict_cl_nncounts.get(cl, 0) for cl in clusters_unq], dtype=float)
print(meas_counts_arr)
print(sum(meas_counts_arr))
print(len(spot_coords))

Simulate random spots and count nearest neighbor values

In [None]:
# Null model: place the same number of spots at random pixels of mask_cell n
# times and count, per cluster, how many random spots have it as a nearest
# neighbour. Row i of sim_counts_arr is simulation i; columns follow clusters_unq.
# Assumes `labels` is still hipr_prop_res.label.values and `nbrs` was fit on
# the same hipr_prop_res centroids.
n = 1000

cell_coords_tup = hipr_prop_res.loc[:, "centroid"].values
cell_coords = np.array([list(c) for c in cell_coords_tup])
# cell_coords = np.array([list(eval(c)) for c in cell_coords_tup])
pix_coords = np.argwhere(mask_cell)

dict_cl_dists_sim = defaultdict(list)
sim_counts_arr = np.zeros((n, len(clusters_unq)))
for i in tqdm(range(n)):
    # Randomize spot locations (pixels drawn with replacement from mask_cell).
    i_sim = np.random.randint(0, pix_coords.shape[0], size=len(spot_coords))
    sim_spot_coords = pix_coords[i_sim]
    # Get nearest neighbors for each spot
    dists_sim, inds_sim = nbrs.kneighbors(sim_spot_coords)
    # Get cluster id of neighbors, merge repeated clusters ids for a given spot
    labels_nn_sim = [labels[k] for k in inds_sim]
    clusts_nn_sim = [[dict_lab_clust[l] for l in ls] for ls in labels_nn_sim]
    clusts_nn_unq_sim = [np.unique(c) for c in clusts_nn_sim]
    # Get counts for clusters
    clusts_nn_unq_flat_sim = [c for cln in clusts_nn_unq_sim for c in cln]
    cl_nn_unq_sim, cl_nn_counts_sim = np.unique(clusts_nn_unq_flat_sim, return_counts=True)
    # Add counts to array; clusters not hit in this simulation stay 0
    # (`.get` replaces the bare `except: pass`).
    dict_cl_nncounts_sim = dict(zip(cl_nn_unq_sim, cl_nn_counts_sim))
    sim_counts_arr[i] = [dict_cl_nncounts_sim.get(cl, 0) for cl in clusters_unq]


    

Get z-scores

In [None]:
# Standardize each cluster's counts against the simulated null: per-cluster
# mean (mu) and standard deviation (sig) over the n simulations.
# NOTE(review): if a cluster's simulated count never varies, sig == 0 and the
# division yields inf/nan z-scores (with a RuntimeWarning).
mu = np.mean(sim_counts_arr, axis=0).squeeze()
sig = np.std(sim_counts_arr, axis=0).squeeze()
sim_z = (sim_counts_arr - mu) / sig
meas_z = (meas_counts_arr - mu) / sig

Sort by z scores

In [None]:
# Order clusters by measured z-score, highest first (ties broken by cluster id,
# descending, as before). Compute one ordering and apply it to every array so
# the three sorted lists stay aligned. Sorting each separately could misalign
# them on tied z-scores, because counts broke ties differently from ids.
order = sorted(range(len(clusters_unq)), key=lambda k: (meas_z[k], clusters_unq[k]), reverse=True)
clusters_sort = [clusters_unq[k] for k in order]
meas_counts_arr_sort = [meas_counts_arr[k] for k in order]
meas_z_sort = [meas_z[k] for k in order]

In [None]:
# Reorder the columns of sim_z (one per cluster in clusters_unq) to follow
# clusters_sort, giving an (n_simulations, n_clusters) array for the box plot.
dict_cl_simz = dict(zip(clusters_unq, sim_z.T))
sim_z_sort = np.hstack([dict_cl_simz[cl][:,None] for cl in clusters_sort])
sim_z_sort.shape

### Classify spectra
Get spectra

In [None]:
# Get spectra
# Load the label -> spectrum map for index m_h; spec_fn is a format string
# defined outside this section.
# NOTE(review): yaml.unsafe_load can construct arbitrary Python objects; only
# use it on trusted files.
with open(spec_fn.format(m_h), 'r') as f:
    dict_lab_spec = yaml.unsafe_load(f)
len(dict_lab_spec)

Get spectra for clusters

In [None]:
# Get spectra for clusters
# Group spectra and labels by cluster id (parallel lists per cluster).
dict_cl_spec = defaultdict(list)
dict_cl_lab = defaultdict(list)
for l, spec in dict_lab_spec.items():
    cl = dict_lab_clust[l]
    dict_cl_spec[cl].append(spec)
    dict_cl_lab[cl].append(l)

Plot spectra with varying y axis

In [None]:
# Plot each cluster's raw spectra (clusters in z-score order) with an
# autoscaled y axis. Faint grey vertical lines mark fixed channel indices.
# spec_dims = [2.5, 1.25]
spec_dims = [10, 5]

ft = 6


for cl in list(clusters_sort):
    specs_arr = np.array(dict_cl_spec[cl])

    fig, ax = ip.general_plot(dims=spec_dims, col="w", ft=ft)

    fsi.plot_cell_spectra(
        ax, specs_arr, {"lw": 1, "alpha": 0.2, "color": 'r'}
    )

    # ax.set_ylim(0,2000)

    ylim = ax.get_ylim()
    xs = [3, 7, 10, 24, 27, 29, 31, 33, 43, 47]
    for x in xs:
        ax.plot([x, x], ylim, color=(0.5, 0.5, 0.5), lw=0.5, alpha=0.5)

    # ax.set_title('Cluster: ' + str(cl), color='w', ft=ft)
    # ax.set_xticks([])
    # ax.set_yticks(ticks=[0,0.3,0.6], labels=[])
    ax.spines["top"].set_color("none")
    ax.spines["right"].set_color("none")
    # ax.tick_params(axis='y', length=yticklength)

    # if save:
    #     mge_assoc_dir = output_dir + "/cluster_spectra" + "/M_" + str(m_h)
    #     if not os.path.exists(mge_assoc_dir):
    #         os.makedirs(mge_assoc_dir)
    #         print("Made dir:", mge_assoc_dir)
    #     out_bn = mge_assoc_dir + "/" + bn.format(m_h) + "_spec_cluster_" + str(cl)
    #     ip.save_png_pdf(out_bn)
    #     print("Wrote:", out_bn + ".png")

    print(cl)
    plt.show()
    plt.close()

Plot spectra with constant y axis

In [None]:
# Same per-cluster spectra plots, but on a fixed y range (ylim) so clusters
# can be compared directly.
spec_dims = [10,5]
# spec_dims = [2.5, 1.25]
ft = 6
ylim = (0,1000)


for cl in list(clusters_sort):
    print('Cluster:',cl)
    specs_arr = np.array(dict_cl_spec[cl])

    fig, ax = ip.general_plot(dims=spec_dims, col="w", ft=ft)

    fsi.plot_cell_spectra(
        ax, specs_arr, {"lw": 1, "alpha": 0.2, "color": 'r'}
    )

    ax.set_ylim(ylim[0], ylim[1])

    # ylim = ax.get_ylim()
    xs = [3, 7, 10, 24, 27, 29, 31, 33, 43, 47]
    for x in xs:
        ax.plot([x, x], ylim, color=(0.5, 0.5, 0.5), lw=0.5, alpha=0.5)

    # ax.set_title('Cluster: ' + str(cl), color='w', ft=ft)
    # ax.set_xticks([])
    # ax.set_yticks(ticks=[0,0.3,0.6], labels=[])
    ax.spines["top"].set_color("none")
    ax.spines["right"].set_color("none")
    # ax.tick_params(axis='y', length=yticklength)



    plt.show()
    plt.close()

In [None]:
# Manual cluster id -> taxon name annotation for one dataset (named in the
# trailing comment of each dict). The association plot looks up
# dict_cl_sciname[cl] for every cluster in clusters_sort, so any cluster
# missing here raises KeyError there.
# dict_cl_sciname = {
#     20: 'Streptococcus',
#     16: 'Corynebacterium',
#     11: 'Actinomyces',
#     23: 'Lautropia',
#     24: 'Fusobacterium',
#     22: 'unclassified'
# } # 2023_07_01_slide_7_fov_02 tile 1

dict_cl_sciname = {
    23: 'Streptococcus',
    20: 'Prevotella',
    11: 'Leptotrichia',
    1: 'Actinomyces',
    22: 'Gemella',
    21: 'Treponema'
} # 2024_01_10_bmgshort_slide_bmg_fov_01_M_2

Get color dict

In [None]:
# Rearrange colors
# Build col_list_re: the even-indexed tab20 colours, then the odd-indexed ones
# (optionally dropping entries listed in remove_inds_*), then pure green and
# yellow. A bar chart previews the resulting palette by index.
col_list = list(plt.get_cmap('tab20').colors)
col_1 = [col_list[i] for i in np.arange(0,20,2)]
remove_inds_1 = []
# remove_inds_1 = [6,7]
col_1 = [c for i,c in enumerate(col_1) if i not in remove_inds_1]
col_2 = [col_list[i] for i in np.arange(1,20,2)]
remove_inds_2 = []
# remove_inds_2 = [6,7]
col_2 = [c for i,c in enumerate(col_2) if i not in remove_inds_2]
# del col_2[7]
col_list_re = col_1 + col_2 + [(0,1,0), (1,1,0)]
ln = len(col_list_re)
barlist = plt.bar(np.arange(ln), np.ones(ln))
for b,c in zip(barlist, col_list_re):
    b.set_color(c)
_ = plt.xticks(np.arange(ln))

In [None]:
# Taxon -> index into col_list_re, resolved to an RGB tuple in dict_sciname_col.
# NOTE(review): Corynebacterium/Actinomyces share index 2 and
# Lautropia/Gemella share index 3; this only matters if both taxa of a pair
# appear in the same plot.
dict_sciname_cind = {
    'Streptococcus': 1,
    'Prevotella': 9,
    'Corynebacterium': 2,
    'Leptotrichia': 8,
    'Actinomyces': 2,
    'Lautropia': 3,
    'Fusobacterium': 14,
    'Gemella': 3,
    'Treponema': 5,
    'unclassified': 7,
}
# col_index_order = [0, 9, 1, 8, 2, 3, 4, 7, 7]

dict_sciname_col = {sciname:col_list_re[i] for sciname, i in dict_sciname_cind.items()}
# for sc, i in zip(sciname_order, col_index_order):
#     dict_sciname_col[sc] = col_list_re[i]

In [None]:
# n_assoc = 3
# cind_first_two = [1,0,2]

# cl_first_two = clusters_sort[:n_assoc]
# col_first_two = [col_list_re[i] for i in cind_first_two]
# dict_bc_col = dict(zip(cl_first_two, col_first_two))
# print(cl_first_two)

# cl_unq, cl_counts = np.unique(list(dict_lab_clust.values()), return_counts=True)
# cl_unq_ = [cl for cl in cl_unq if cl not in cl_first_two]
# cl_counts_ = [cnt for cl, cnt in zip(cl_unq, cl_counts) if cl not in cl_first_two]
# print(cl_unq_)
# print(cl_counts_)

# cl_next_few = [cl for _, cl in sorted(zip(cl_counts_, cl_unq_), reverse=True)]
# print(cl_next_few)
# # cl_next_few = clusters_sort[n_assoc:]

# # cind_next_few = [6, 1, 7, 5, 4, 3, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
# cind_next_few = [5, 6, 7, 3, 4, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
# col_next_few = [col_list_re[i] for i in cind_next_few]
# for cl, col in zip(cl_next_few, col_next_few):
#     dict_bc_col[cl] = col

# cl_color_order = cl_first_two + cl_next_few


# ln = len(cl_color_order)
# barlist = plt.bar(np.arange(ln), np.ones(ln))
# for b, cl in zip(barlist, cl_color_order):
#     c = dict_bc_col[cl]
#     b.set_color(c)
# _ = plt.xticks(ticks=np.arange(ln), labels=cl_color_order)

# dict_bc_col = dict(zip(clusters_sort, col_list))
# Assign colours to clusters in z-score order. zip stops at the shorter input,
# so clusters beyond len(col_list_re) (currently 22 colours) get no entry.
dict_bc_col = dict(zip(clusters_sort, col_list_re))


Plot association z-scores

In [None]:
# Plot z score number of spots associated with group
# Boxes: simulated z-scores per cluster (sim_z_sort). Dots: measured z-score
# per cluster (meas_z_sort), coloured by taxon, in clusters_sort order.




save=False  # toggles every save block below; later cells also read `save`




# dims=[5,2]
dims = [1.5, 1]
xlab_rotation = 45
pval_rotation = 60
marker = "."
marker_size = 10
text_dist = 0.1
# ft = 12
ft = 6  # font size; later cells reuse `ft`
ylimadj = 0.1
true_frac_llim = 0
line_col = "k"
box_line_col = (0.5, 0.5, 0.5)
box_col = "w" if line_col == 'k' else 'k'
yticklength = 2

fig, ax = ip.general_plot(dims=dims, ft=ft, col=line_col)
# Plot simulation
boxplot = ax.boxplot(
    sim_z_sort,
    patch_artist=True,
    showfliers=False,
    boxprops=dict(facecolor=box_col, color=box_line_col),
    capprops=dict(color=box_line_col),
    whiskerprops=dict(color=box_line_col),
    medianprops=dict(color=box_line_col),
)
# for m in boxplot['medians']:
#     m.set_color(line_col)
# for b in boxplot['boxes']:
#     b.set_edgecolor(line_col)
#     b.set_facecolor(box_col)

# Plot measured value
# ys / xlab are only filled by the commented-out p-value code below; cols is
# collected but not used in this cell.
ys = []
xlab = []
cols = []
# matplotlib's boxplot puts boxes at x = 1..N by default, so the measured
# dots start at x = 1 to line up with their box.
x = 1
for i, cl in enumerate(clusters_sort):
    # for i, bc_tax in zip(ind_order, barcodes_int_order):
    # sci_name = dict_bc_sciname[bc_tax]
    # sci_name = dict_bc_sciname[cl]
    sci_name = dict_cl_sciname[cl]
    color = dict_sciname_col[sci_name]
    # color = dict_bc_col[cl]
    # xlab.append(sci_name)

    # try:
    #     color = col_dict[sci_name]
    # except:
    #     continue

    cols.append(color)
    true_frac = meas_z_sort[i]
    print(sci_name, ': ', true_frac)
    # true_frac = true_count / n_cells
    _ = ax.plot(x, true_frac, marker=marker, ms=marker_size, color=color)
    # # Plot p value
    # sim_vals_i = sim_counts_arr[:, i]
    # # sim_vals = sim_arr[:,i,h] / n_cells
    # sim_mean = np.mean(sim_vals)
    # if true_frac > sim_mean:
    #     # number of simulations with value greater than observed
    #     r_ = sum(sim_vals_i > true_frac)
    # else:
    #     # number of simulations with value less than observed
    #     r_ = sum(sim_vals_i < true_frac)
    # # P value
    # p_ = r_ / n
    # # Get text location
    # q1, q3 = np.quantile(sim_vals, [0.25, 0.75])
    # q4 = q3 + 1.5 * (q3 - q1)
    # # y_m = np.max(sim_vals)
    # # y = y_m if y_m > true_frac else true_frac
    # y = q4 if q4 > true_frac else true_frac
    # y += text_dist
    # ys.append(y)
    # if true_frac < true_frac_llim:
    #     t = ''
    # elif (p_ > 0.05):
    #     t = ''
    # elif (p_ > 0.001) and (p_ <= 0.05):
    #     t = str("p=" + str(p_))
    # else:
    #     t = str("p<0.001")
    # _ = ax.text(x, y, t, fontsize=ft, ha='left',va='bottom', rotation=pval_rotation, rotation_mode='anchor',
    #         color=line_col)
    x += 1
# Tick labels are cleared and x ticks removed, so no cluster names are drawn
# (figure is labelled externally).
ax.set_xticklabels([], rotation=xlab_rotation, ha='right', va='top', rotation_mode='anchor')
# ax.set_xticklabels(xlab, rotation=xlab_rotation, ha='right', va='top', rotation_mode='anchor')
# ax.tick_params(axis='x',direction='out')
ax.set_xticks([])
# ax.tick_params(axis="y", length=yticklength)
ax.set_yticks(ticks=[-5,0,5,10], labels=[])
ax.spines["top"].set_color("none")
ax.spines["bottom"].set_color("none")
ax.spines["right"].set_color("none")

# ylims = ax.get_ylim()
# ax.set_ylim(ylims[0], np.max(ys) + ylimadj)

if save:
    mge_assoc_dir = output_dir + "/mge_association/nearest_neighbor"
    if not os.path.exists(mge_assoc_dir):
        os.makedirs(mge_assoc_dir)
        print("Made dir:", mge_assoc_dir)
    out_bn = mge_assoc_dir + '/' + bn.format(m_h) + '_seg_nn_zscore_association'
    ip.save_png_pdf(out_bn)
    print('Wrote:', out_bn + '.png')


Plot spot fraction 

In [None]:
# Fraction of all spots whose nearest cell belongs to each cluster, one bar
# per cluster in clusters_sort order, coloured by taxon.
meas_spotfrac_arr = np.array(meas_counts_arr_sort) / len(spot_coords)

# ft=12
# ft=6
# NOTE: `ft` and `save` are not set here; the values from the previous
# plotting cell are reused.
line_col = 'k'
width=0.4
# dims=[5,2]
dims=[1,0.6]
yticklength=1

# sci_name_order = [dict_bc_sciname[bc] for bc in barcodes_int_order]
sci_name_order = [dict_cl_sciname[cl] for cl in clusters_sort]
color_order = [dict_sciname_col[sc] for sc in sci_name_order]
# color_order = [dict_bc_col[sc] for sc in clusters_sort]

fig, ax = ip.general_plot(dims=dims, ft=ft, col=line_col)
ax.bar(
        np.arange(meas_spotfrac_arr.shape[0]),
        meas_spotfrac_arr,
        width=width,
        color=color_order,
        edgecolor=line_col
        )
print(meas_spotfrac_arr)
ax.spines['top'].set_color('none')
ax.spines['right'].set_color('none')
ax.set_xticks([])
ax.set_yticks(ticks=[0,0.2,0.4,0.6], labels=[])
ax.tick_params(axis='y', length=yticklength)

if save:
    mge_assoc_dir = output_dir + "/mge_association/nearest_neighbor"
    if not os.path.exists(mge_assoc_dir):
        os.makedirs(mge_assoc_dir)
        print("Made dir:", mge_assoc_dir)
    out_bn = mge_assoc_dir + '/' + bn.format(m_h) + '_seg_nn_fracspot'
    ip.save_png_pdf(out_bn)
    print('Wrote:', out_bn + '.png')

# mge_assoc_dir = output_dir + '/mge_association'
# if not os.path.exists(mge_assoc_dir): 
#     os.makedirs(mge_assoc_dir)
#     print('Made dir:',mge_assoc_dir)
# out_bn = mge_assoc_dir + '/' + bn.format(m_h) + '_bar_seg_nn_frac_spot_association_0_5um'
# ip.save_png_pdf(out_bn)

Plot the fraction of each cluster's cells that are associated with a spot

In [None]:
# Get cluster counts without repeating the same label: a cell that is the
# nearest neighbour of several spots is counted once for its cluster.
# NOTE(review): expects labels_nn to be a sequence of per-spot label arrays
# (as built in the nearest-neighbour association cell); the 1-D labels_nn
# built in the "Get nearest spectra" cell would not flatten this way.
labels_nn_flat = [l for ls in labels_nn for l in ls]
labels_nn_unq = np.unique(labels_nn_flat)
clusts_labels_nn_unq = [dict_lab_clust[l] for l in labels_nn_unq]
cl_nn_unq2, cl_nn_counts2 = np.unique(clusts_labels_nn_unq, return_counts=True)
dict_cl_nncounts2 = dict(zip(cl_nn_unq2, cl_nn_counts2))
# Filled in clusters_sort order (the order used for the bar colours);
# clusters with no neighbouring cell keep a count of 0.
meas_counts_arr2 = np.zeros(len(clusters_unq))
for i, cl in enumerate(clusters_sort):
    # .get replaces a bare `except: pass`, which also silently hid unrelated
    # errors such as an IndexError from a length mismatch.
    meas_counts_arr2[i] = dict_cl_nncounts2.get(cl, 0)
print(meas_counts_arr2)

In [None]:
# Get counts for each cluster: total number of segmented cells assigned to it,
# filled in clusters_sort order; clusters absent from dict_lab_clust stay 0.
cl_all, cl_counts = np.unique(list(dict_lab_clust.values()), return_counts=True)
dict_cl_counts_all = dict(zip(cl_all, cl_counts))
all_counts_arr = np.zeros(len(clusters_unq))
for i, cl in enumerate(clusters_sort):
    # .get replaces a bare `except: pass`, which also hid unrelated errors.
    all_counts_arr[i] = dict_cl_counts_all.get(cl, 0)
print(all_counts_arr)

In [None]:
# Frac taxon assoc with spot
# Fraction of each cluster's cells that are the nearest cell to at least one
# spot. NOTE: a cluster with 0 cells in all_counts_arr gives 0/0 = nan (numpy
# warns) and is drawn as an empty bar.
cl_fracnearspots = meas_counts_arr2 / all_counts_arr

# dims=[5,2]
dims=[1.5,0.6]
yticklength=2
# ft=12
# ft=6
line_col = 'k'
width=0.4


# color_order comes from the spot-fraction cell (clusters_sort order).
fig, ax = ip.general_plot(dims=dims, ft=ft, col=line_col)
ax.bar(
        np.arange(cl_fracnearspots.shape[0]),
        cl_fracnearspots,
        width=width,
        color=color_order,
        edgecolor=line_col
        )
print(cl_fracnearspots)
ax.set_xticks([])
# ax.set_yticks(ticks=[0,0.3,0.6], labels=[])
ax.spines['top'].set_color('none')
ax.spines['right'].set_color('none')
ax.tick_params(axis='y', length=yticklength)

if save:
    mge_assoc_dir = output_dir + "/mge_association/nearest_neighbor"
    if not os.path.exists(mge_assoc_dir): 
        os.makedirs(mge_assoc_dir)
        print('Made dir:',mge_assoc_dir)
    out_bn = mge_assoc_dir + '/' + bn.format(m_h) + '_seg_nn_fraccell'
    ip.save_png_pdf(out_bn)
    print('Wrote:', out_bn + '.png')


### Plot spectra
Get spectra

In [None]:
# Get spectra
# Load the per-cell spectra dictionary for sample m_h. Its keys are used as
# cell labels in the cells below.
# NOTE(review): yaml.unsafe_load can construct arbitrary Python objects; only
# use it on trusted files.
with open(spec_fn.format(m_h), 'r') as f:
    dict_lab_spec = yaml.unsafe_load(f)
len(dict_lab_spec)

Get spectra for clusters

In [None]:
# Get spectra for clusters: bucket every cell's spectrum, and its label, under
# the cell's cluster id (within a bucket, the order follows dict_lab_spec).
dict_cl_spec = defaultdict(list)
dict_cl_lab = defaultdict(list)
for l in dict_lab_spec:
    spec = dict_lab_spec[l]
    cl = dict_lab_clust[l]
    dict_cl_lab[cl].append(l)
    dict_cl_spec[cl].append(spec)

Plot spectra with varying y axis

In [None]:
# One figure per cluster: every cell spectrum in the cluster, drawn in the
# taxon colour, with grey guide lines at fixed channel indices. The y axis is
# autoscaled separately for each cluster.
spec_dims = [2.5, 1.25]
# spec_dims = [10, 5]

ft = 6


for cl in list(clusters_sort):
    specs_arr = np.array(dict_cl_spec[cl])

    sci_name = dict_cl_sciname[cl]
    color = dict_sciname_col[sci_name]
    # color = dict_bc_col[cl]

    fig, ax = ip.general_plot(dims=spec_dims, col="w", ft=ft)

    fsi.plot_cell_spectra(
        ax, specs_arr, {"lw": 1, "alpha": 0.2, "color": color}
    )

    # ax.set_ylim(0,2000)

    # Guide lines span the autoscaled y range at the moment they are drawn.
    ylim = ax.get_ylim()
    xs = [3, 7, 10, 24, 27, 29, 31, 33, 43, 47]
    for x in xs:
        ax.plot([x, x], ylim, color=(0.5, 0.5, 0.5), lw=0.5, alpha=0.5)

    # ax.set_title('Cluster: ' + str(cl), color='w', ft=ft)
    # ax.set_xticks([])
    # ax.set_yticks(ticks=[0,0.3,0.6], labels=[])
    ax.spines["top"].set_color("none")
    ax.spines["right"].set_color("none")
    # ax.tick_params(axis='y', length=yticklength)

    if save:
        mge_assoc_dir = output_dir + "/cluster_spectra" + "/M_" + str(m_h)
        if not os.path.exists(mge_assoc_dir):
            os.makedirs(mge_assoc_dir)
            print("Made dir:", mge_assoc_dir)
        out_bn = mge_assoc_dir + "/" + bn.format(m_h) + "_spec_cluster_" + str(cl)
        ip.save_png_pdf(out_bn)
        print("Wrote:", out_bn + ".png")

    print(cl)
    plt.show()
    plt.close()

Plot spectra with constant y axis

In [None]:
# Same per-cluster spectra plots, but every figure uses the same fixed y range
# so that intensities can be compared across clusters.
# spec_dims = [10,5]
spec_dims = [2.5, 1.25]
ft = 6
ylim = (0,1000)


for cl in list(clusters_sort):
    print('Cluster:',cl)
    specs_arr = np.array(dict_cl_spec[cl])
    
    sci_name = dict_cl_sciname[cl]
    color = dict_sciname_col[sci_name]
    # color = dict_bc_col[cl]

    fig, ax = ip.general_plot(dims=spec_dims, col="w", ft=ft)

    fsi.plot_cell_spectra(
        ax, specs_arr, {"lw": 1, "alpha": 0.2, "color": color}
    )

    ax.set_ylim(ylim[0], ylim[1])

    # ylim = ax.get_ylim()
    xs = [3, 7, 10, 24, 27, 29, 31, 33, 43, 47]
    for x in xs:
        ax.plot([x, x], ylim, color=(0.5, 0.5, 0.5), lw=0.5, alpha=0.5)

    # ax.set_title('Cluster: ' + str(cl), color='w', ft=ft)
    # ax.set_xticks([])
    # ax.set_yticks(ticks=[0,0.3,0.6], labels=[])
    ax.spines["top"].set_color("none")
    ax.spines["right"].set_color("none")
    # ax.tick_params(axis='y', length=yticklength)

    if save:
        mge_assoc_dir = output_dir + '/cluster_spectra' + '/M_' + str(m_h)
        if not os.path.exists(mge_assoc_dir):
            os.makedirs(mge_assoc_dir)
            print('Made dir:',mge_assoc_dir)
        out_bn = mge_assoc_dir + '/' + bn.format(m_h) + '_spec_cluster_' + str(cl) + '_std_y'
        ip.save_png_pdf(out_bn)
        print('Wrote:', out_bn + '.png')

    plt.show()
    plt.close()

Plot clusters on overlay

In [None]:
# MGE image
# Build a magenta RGBA layer from the registered MGE spot image:
# R = B = normalised intensity, G = 0, alpha = normalised intensity, so dim
# pixels are transparent. `norm` is defined elsewhere; presumably it rescales
# to [0, 1] using mge_clims. Confirm.
# mge_clims = clim_mge
mge_clims = (100, 125)

raw_mge_shift_spot_norm = norm(raw_mge_shift_spot, mge_clims)
raw_mge_shift_spot_rgb = np.dstack([
    raw_mge_shift_spot_norm,
    np.zeros_like(raw_mge_shift_spot_norm),
    raw_mge_shift_spot_norm,
    raw_mge_shift_spot_norm
])    

In [None]:
# Invert dict_lab_clust into cluster id -> list of cell labels
# (labels keep the dict's insertion order).
dict_clust_lab = defaultdict(list)
for lab in dict_lab_clust:
    cl = dict_lab_clust[lab]
    dict_clust_lab[cl].append(lab)

In [None]:
clusters_toplot = clusters_sort
clusters_notoplot = []

# RGBA overlay: each segmented cell of the selected clusters is painted
# opaque in its taxon colour; all other pixels stay transparent.
hipr_seg_res_clustrgb = np.zeros(hipr_seg_res_edge.shape + (4,))
# scinames_toplot = ['Selenomonas', 'Lautropia']
# clusters_toplot = [dict_sciname_bc[sc] for sc in scinames_toplot]
for cl in clusters_toplot:
    if cl not in clusters_notoplot:
        # Set membership is O(1). Testing `l in <list>` for every row made
        # this loop O(n_rows * n_labels) per cluster.
        labels_sub = set(dict_clust_lab[cl])
        for i, row in hipr_prop_res.iterrows():
            l = row.label
            if l in labels_sub:
                # NOTE(review): bbox may be stored as its string repr. eval()
                # is only safe if hipr_prop_res comes from a trusted file.
                b = row.bbox
                b = eval(b) if isinstance(b, str) else b
                rgb_sub = hipr_seg_res_clustrgb[b[0]:b[2],b[1]:b[3]]
                seg_sub = hipr_seg_res_edge[b[0]:b[2],b[1]:b[3]] == l
                sci_name = dict_cl_sciname[cl]
                col = dict_sciname_col[sci_name]
                # col = dict_bc_col[cl]
                color = np.array(col + (1,))  # RGB tuple + opaque alpha
                # Keep already-painted pixels outside this cell; overwrite inside it.
                rgb_noncell = rgb_sub * np.dstack([~seg_sub]*4)
                rgb_cell = seg_sub[:,:,None] * color[None,:]
                hipr_seg_res_clustrgb[b[0]:b[2],b[1]:b[3],:] = rgb_noncell + rgb_cell

In [None]:
# Output DPI for saved overlays: chosen so that an image drawn `imin` inches
# wide keeps roughly its native pixel count along its longest axis.
# NOTE(review): the overlay cells below pass im_inches=30, not imin, so saved
# files presumably end up at about 3x native resolution. Confirm this is intended.
imin=10

res_mge_umpix = res_mge * 10**6  # presumably metres -> micrometres per pixel
dpi = np.max(im_cell.shape) // imin

dpi

In [None]:
# Grayscale cell image, with every clustered cell painted in its taxon colour
# and the magenta MGE layer on top. Optionally saved at `dpi`.
fig, ax, cbar = ip.plot_image(
    im_cell,
    cmap="gray",
    im_inches=30,
    clims=(0, 2000),
    scalebar_resolution=res_mge_umpix,
)

# fig, ax, cbar = ip.plot_image(
#     np.zeros_like(im_cell),
#     cmap="gray",
#     im_inches=imin,
#     clims=(0, 2000),
#     scalebar_resolution=res_mge_umpix,
# )

ax.imshow(hipr_seg_res_clustrgb)
ax.imshow(raw_mge_shift_spot_rgb)
# ax.imshow(seg_mge_shift_spot_rgb)

if save:
    mge_overlay_dir = output_dir + "/mge_overlay/clust_merge"
    if not os.path.exists(mge_overlay_dir):
        os.makedirs(mge_overlay_dir)
        print("Made dir:", mge_overlay_dir)
    plt.figure(fig)  # make this the current figure so the save helper writes it
    out_bn = mge_overlay_dir + "/" + bn.format(m_h) + "_overlay_mergeclusts_mge"
    ip.save_png_pdf(out_bn, dpi=dpi)
    print("Wrote:", out_bn + ".png")

Plot assoc clusters on overlay

In [None]:
# Paint only the clusters at positions index_toplot of clusters_sort.
index_toplot = [0,1]
clusters_toplot = [clusters_sort[i] for i in index_toplot]

hipr_seg_res_clustrgb = np.zeros(hipr_seg_res_edge.shape + (4,))
# scinames_toplot = ['Selenomonas', 'Lautropia']
# clusters_toplot = [dict_sciname_bc[sc] for sc in scinames_toplot]
for cl in clusters_toplot:
    # Set membership is O(1); the list test was O(n_labels) for every row.
    labels_sub = set(dict_clust_lab[cl])
    for i, row in hipr_prop_res.iterrows():
        l = row.label
        if l in labels_sub:
            # NOTE(review): eval() on a stored bbox string. Only safe for trusted input.
            b = row.bbox
            b = eval(b) if isinstance(b, str) else b
            rgb_sub = hipr_seg_res_clustrgb[b[0]:b[2],b[1]:b[3]]
            seg_sub = hipr_seg_res_edge[b[0]:b[2],b[1]:b[3]] == l
            sci_name = dict_cl_sciname[cl]
            col = dict_sciname_col[sci_name]
            color = np.array(col + (1,))  # RGB tuple + opaque alpha
            # Keep already-painted pixels outside this cell; overwrite inside it.
            rgb_noncell = rgb_sub * np.dstack([~seg_sub]*4)
            rgb_cell = seg_sub[:,:,None] * color[None,:]
            hipr_seg_res_clustrgb[b[0]:b[2],b[1]:b[3],:] = rgb_noncell + rgb_cell

In [None]:
# Overlay of the selected (associated) clusters plus the MGE layer; optionally saved.
# fig, ax, cbar = ip.plot_image(
#     im_cell,
#     cmap="gray",
#     im_inches=imin,
#     clims=(0, 2000),
#     scalebar_resolution=res_mge_umpix,
# )
fig, ax, cbar = ip.plot_image(
    im_cell,
    cmap="gray",
    im_inches=30,
    clims=(0, 2000),
    scalebar_resolution=res_mge_umpix,
)
ax.imshow(hipr_seg_res_clustrgb)
ax.imshow(raw_mge_shift_spot_rgb)
# ax.imshow(seg_mge_shift_spot_rgb)
if save:
    mge_overlay_dir = output_dir + '/mge_overlay/clust_merge'
    if not os.path.exists(mge_overlay_dir):
        os.makedirs(mge_overlay_dir)
        print('Made dir:',mge_overlay_dir)
    plt.figure(fig)  # make this the current figure so the save helper writes it
    out_bn = mge_overlay_dir + '/' + bn.format(m_h) + '_overlay_mergeclusts_mge_assoc'
    ip.save_png_pdf(out_bn, dpi=dpi)
    print('Wrote:', out_bn + '.png')

In [None]:
# FOR ROHIT #
# Standalone (commented-out) reference versions of the segmentation-recolouring
# loop used throughout this notebook, plus a region-property helper.

# from skimage.measure import regionprops
# import matplotlib.pyplot as plt


# def seg_2_rgb(seg, dict_lab_col, dict_lab_bbox):
#     """
#     Convert segmentation numpy array to RGB image

#     seg - 2d array(int or str) - objects are adjacent pixels with the same value
#     dict_lab_col - dictionary(int or str -> list or tup) - map objects in seg to an RGB or RGBA color
#     dict_lab_bbox - dictionary(int or str -> list(int) or tup(int) or str(comma separated integers)) -
#         map objects in seg to a bounding box
#         bounding box has format (row min, column min, row max, column max)
#     """
#     rgb_shape = len(dict_lab_col[list(dict_lab_col.keys())[0]])  # Determine RGB or RGBA
#     rgb = np.zeros(seg.shape + (rgb_shape,))  # initiate empty image
#     for l, b in dict_lab_bbox.items():
#         b = eval(b) if isinstance(b, str) else b  # convert to list if string
#         rgb_sub = rgb[b[0] : b[2], b[1] : b[3]]  # extract current RGB bbox
#         seg_sub = seg[b[0] : b[2], b[1] : b[3]] == l  # Extract object
#         color = np.array(dict_lab_col[l])  # Get the object's new color
#         rgb_nonobj = rgb_sub * np.dstack(
#             [~seg_sub] * rgb_shape
#         )  # keep everything already existing in the rgb
#         rgb_obj = seg_sub[:, :, None] * color[None, :]  # recolor the object
#         rgb[b[0] : b[2], b[1] : b[3], :] = (
#             rgb_nonobj + rgb_obj
#         )  # write the recolored object to the rgb
#     return rgb


# def measure_regionprops(seg, raw=None):
#     '''
#     Measure the properties of segmented objects in a numpy array

#     seg - 2d array(int) - objects are adjacent pixels with the same value
#     raw - 2d array(int or float) - grayscale image used to measure intensity values in seg objects
#     '''
#     if isinstance(raw, type(None)):
#         raw = np.zeros(seg.shape)
#     sp_ = regionprops(seg, intensity_image=raw)
#     properties = [
#         "label",
#         "centroid",
#         "area",
#         "max_intensity",
#         "mean_intensity",
#         "min_intensity",
#         "bbox",
#         "major_axis_length",
#         "minor_axis_length",
#         "orientation",
#         "eccentricity",
#         "perimeter",
#     ]
#     df = pd.DataFrame([])
#     for p in properties:
#         df[p] = [s[p] for s in sp_]
#     for j in range(2):
#         df["centroid-" + str(j)] = [r["centroid"][j] for i, r in df.iterrows()]
#     for j in range(4):
#         df["bbox-" + str(j)] = [r["bbox"][j] for i, r in df.iterrows()]
#     # regions = regionprops_table(seg, intensity_image = raw,
#     #                             properties=['label','centroid','area','max_intensity',
#     #                             'mean_intensity','min_intensity', 'bbox',
#     #                             'major_axis_length', 'minor_axis_length',
#     #                             'orientation','eccentricity','perimeter'])
#     # return pd.DataFrame(regions)
#     return df



# #########
# # EXAMPLE
# #########
# seg = np.zeros((100, 100), dtype=int)
# seg[40:60, 40:60] = 1

# props = measure_regionprops(seg)
# dict_lab_bbox = dict(zip(props.label.values, props.bbox.values))

# dict_lab_col = {1: (1, 0, 1)}

# rgb = seg_2_rgb(seg, dict_lab_col, dict_lab_bbox)

# plt.imshow(seg, cmap="gray")
# plt.show()
# plt.close()
# plt.imshow(rgb)
# print(rgb.shape)

## Get nearest spectra to spots
Get spectra

In [None]:
# Label and spectrum of the nearest segmented cell for every spot.
# np.atleast_1d keeps labels_nn iterable when there is only one spot:
# squeeze() alone turns a (1, 1) index array into a 0-d scalar.
labels_nn = hipr_prop_res.label.values[np.atleast_1d(inds.squeeze())]
spec_nn = np.array([dict_lab_spec[l] for l in labels_nn])

Plot spectra together

In [None]:
# Overlay the spectra of every spot's nearest cell on a single set of axes, in red.
fig, ax = ip.general_plot(dims=spec_dims, col="w", ft=ft)
nn_spec_style = dict(lw=1, alpha=0.2, color="r")
fsi.plot_cell_spectra(ax, spec_nn, nn_spec_style)

Project nearest neighbors on image

In [None]:
# Display the first palette colour (used below for nearest-neighbour cells).
col_list_re[0]

In [None]:
# Paint every cell that is the nearest neighbour of at least one spot in the
# first palette colour.
hipr_seg_res_clustrgb = np.zeros(hipr_seg_res_edge.shape + (4,))
# scinames_toplot = ['Selenomonas', 'Lautropia']
# clusters_toplot = [dict_sciname_bc[sc] for sc in scinames_toplot]
# Set membership is O(1); `l in labels_nn` scanned the whole array for every row.
labels_nn_set = set(np.ravel(labels_nn))
for i, row in hipr_prop_res.iterrows():
    l = row.label
    if l in labels_nn_set:
        # NOTE(review): eval() on a stored bbox string. Only safe for trusted input.
        b = row.bbox
        b = eval(b) if isinstance(b, str) else b
        rgb_sub = hipr_seg_res_clustrgb[b[0] : b[2], b[1] : b[3]]
        seg_sub = hipr_seg_res_edge[b[0] : b[2], b[1] : b[3]] == l
        color = np.array(col_list_re[0] + (1,))  # RGB tuple + opaque alpha
        # Keep already-painted pixels outside this cell; overwrite inside it.
        rgb_noncell = rgb_sub * np.dstack([~seg_sub] * 4)
        rgb_cell = seg_sub[:, :, None] * color[None, :]
        hipr_seg_res_clustrgb[b[0] : b[2], b[1] : b[3], :] = rgb_noncell + rgb_cell

In [None]:
# Show the nearest-neighbour cells painted above, with the MGE spot layer on top.
fig, ax, cbar = ip.plot_image(
    im_cell,
    cmap="gray",
    im_inches=30,
    clims=(0, 2000),
    scalebar_resolution=res_mge_umpix,
)
ax.imshow(hipr_seg_res_clustrgb)
ax.imshow(raw_mge_shift_spot_rgb)

Cluster spectra

In [None]:
# Condensed pairwise distances between nearest-cell spectra, computed with the
# project's custom spectral metric.
dist_mat_cond = pdist(spec_nn, fhc.channel_cosine_intensity_allonev2)

In [None]:
# Square (n x n) form of the same distances, for the "precomputed" UMAP and
# agglomerative clustering below.
dist_mat = squareform(dist_mat_cond)

In [None]:
# Complete-linkage hierarchy. scipy's linkage() takes the condensed distance
# vector here, not the square matrix.
linkage = hierarchy.linkage(dist_mat_cond, method="complete")

In [None]:
# NOTE(review): `labels` is not passed to the dendrogram (which is built from
# spec_nn rows, not dict_lab_spec), so leaves are labelled by spec_nn row index.
labels = list(dict_lab_spec.keys())

fig, ax = ip.general_plot(dims=(15, 15))
dn = hierarchy.dendrogram(linkage, truncate_mode=None)
# dn = hierarchy.dendrogram(linkage, labels=inds, truncate_mode='lastp')
# ax.axhline(t[i])
# Horizontal grid every 0.1 distance units to help choose a cut height.
ylims = [round(l, 2) for l in ax.get_ylim()]
ax.set_yticks(np.arange(ylims[0], ylims[1], 0.1))
ax.grid(axis="y")
plt.show()
plt.close()

In [None]:
# 2-D UMAP embedding of the precomputed spectral distance matrix, shown as a
# scatter plot. `u` is reused by the cluster-coloured scatter plots below.
fit = umap.UMAP(metric="precomputed", n_neighbors=100, min_dist=0.1).fit(dist_mat)
u = fit.embedding_
plt.scatter(u[:, 0], u[:, 1], alpha=0.05)
plt.show()
plt.close()

In [None]:
# Cut the precomputed spectral distance matrix into n_clust groups using
# complete linkage.
n_clust = 4
try:
    # scikit-learn >= 1.2 calls this argument `metric`; `affinity` was
    # deprecated in 1.2 and removed in 1.4.
    agg = AgglomerativeClustering(n_clusters=n_clust, metric='precomputed', linkage='complete')
except TypeError:
    # Older scikit-learn releases only accept `affinity`.
    agg = AgglomerativeClustering(n_clusters=n_clust, affinity='precomputed', linkage='complete')

agg.fit(dist_mat)

clust_agg = agg.labels_

In [None]:
# Map each agglomerative cluster id (0..n_clust-1) to a palette colour, then
# look up the colour of every spectrum's cluster.
# n_clust = np.max(clust_agg) + 1
dict_clust_col = {k: c for k, c in zip(np.arange(n_clust), col_list_re)}
clust_cols = list(map(dict_clust_col.__getitem__, clust_agg))

In [None]:
# UMAP embedding coloured by agglomerative cluster.
plt.scatter(u[:, 0], u[:, 1], c=clust_cols, alpha=0.2)
plt.show()
plt.close()

Plot cluster spectra

In [None]:
# One figure per agglomerative cluster showing its nearest-cell spectra, with
# grey guide lines at fixed channel indices.
spec_dims = (10,5)
xs = [3, 7, 10, 24, 27, 29, 31, 33, 43, 47]  # guide-line channel indices


for c in np.unique(clust_agg):
    print('Cluster:', c)
    bool_c = clust_agg == c
    spec_sub = spec_nn[bool_c,:]

    # spec_sub_meansub = spec_sub - specs_med
    # spec_sub_meansub[spec_sub_meansub < 0] = 0

    fig, ax = ip.general_plot(dims=spec_dims, col='w')
    color = dict_clust_col[c]
    fsi.plot_cell_spectra(ax, spec_sub, {'lw':2,'alpha':0.2,'color':color})

    ylim = ax.get_ylim()
    for x in xs:
        ax.plot([x,x], ylim, color=(0.5,0.5,0.5), lw=0.5)

    # ax.set_ylim(0,1000)
    # ax.set_ylim(0,2**16)
    plt.show()
    # Close explicitly (as the other plotting cells do) so that non-inline
    # backends do not keep one open figure per cluster.
    plt.close()

Project clusters on image

In [None]:
# NOTE(review): this rebinds dict_lab_clust (until now the taxonomic cluster
# of every cell) to the agglomerative cluster of the spot-nearest cells only.
# Re-running earlier cells afterwards may raise KeyError for other labels.
dict_lab_clust = dict(zip(labels_nn, clust_agg))

In [None]:
# Invert dict_lab_clust into cluster id -> list of cell labels
# (labels keep the dict's insertion order).
dict_clust_lab = defaultdict(list)
for lab in dict_lab_clust:
    cl = dict_lab_clust[lab]
    dict_clust_lab[cl].append(lab)

In [None]:

# Paint every spot-nearest cell in the colour of its agglomerative cluster.
hipr_seg_res_clustrgb = np.zeros(hipr_seg_res_edge.shape + (4,))
# scinames_toplot = ['Selenomonas', 'Lautropia']
# clusters_toplot = [dict_sciname_bc[sc] for sc in scinames_toplot]
for cl in np.unique(clust_agg):
    # Set membership is O(1); the list test was O(n_labels) for every row.
    labels_sub = set(dict_clust_lab[cl])
    for i, row in hipr_prop_res.iterrows():
        l = row.label
        if l in labels_sub:
            # NOTE(review): eval() on a stored bbox string. Only safe for trusted input.
            b = row.bbox
            b = eval(b) if isinstance(b, str) else b
            rgb_sub = hipr_seg_res_clustrgb[b[0] : b[2], b[1] : b[3]]
            seg_sub = hipr_seg_res_edge[b[0] : b[2], b[1] : b[3]] == l
            color = np.array(dict_clust_col[cl] + (1,))  # RGB tuple + opaque alpha
            # Keep already-painted pixels outside this cell; overwrite inside it.
            rgb_noncell = rgb_sub * np.dstack([~seg_sub] * 4)
            rgb_cell = seg_sub[:, :, None] * color[None, :]
            hipr_seg_res_clustrgb[b[0] : b[2], b[1] : b[3], :] = rgb_noncell + rgb_cell

In [None]:
# Show the agglomerative-cluster colouring with the MGE spot layer on top.
fig, ax, cbar = ip.plot_image(
    im_cell,
    cmap="gray",
    im_inches=30,
    clims=(0, 2000),
    scalebar_resolution=res_mge_umpix,
)
# fig, ax, cbar = ip.plot_image(
#     im_cell,
#     cmap="gray",
#     im_inches=imin,
#     clims=(0, 2000),
#     scalebar_resolution=res_mge_umpix,
# )
ax.imshow(hipr_seg_res_clustrgb)
ax.imshow(raw_mge_shift_spot_rgb)



Manually cluster spectra

In [None]:
# Hand-tuned ratio thresholds on individual spectral channels.
frac = 0.9
frac_1 = 0.4

# Rule-based class for each nearest-cell spectrum:
#   0 -> channel 47 is below frac * channel 31
#   1 -> otherwise, channel 43 is above frac_1 * channel 47
#   2 -> everything else
clust_man = [
    0 if s[47] < frac * s[31]
    else 1 if s[43] > frac_1 * s[47]
    else 2
    for s in spec_nn
]


In [None]:
# Palette colour per manual cluster id. Keys cover 0..max(id), so every id
# present has a colour even when an intermediate class is empty. Sizing the
# keys by the number of unique ids would, for ids {0, 2}, create keys {0, 1}
# and raise KeyError for 2.
n_clust_man = max(clust_man, default=-1) + 1
dict_clust_col = dict(zip(np.arange(n_clust_man), col_list_re))
clust_cols = [dict_clust_col[cl] for cl in clust_man]

In [None]:
# UMAP embedding coloured by manual (threshold-based) cluster.
plt.scatter(u[:, 0], u[:, 1], c=clust_cols, alpha=0.2)
plt.show()
plt.close()

In [None]:
# One figure per manual cluster showing its nearest-cell spectra, with grey
# guide lines at fixed channel indices.
spec_dims = (10, 5)
# Compare against an array so the mask is explicitly element-wise. With a plain
# list, `clust_man == c` only works through numpy-scalar reflected __eq__.
clust_man_arr = np.asarray(clust_man)
xs = [3, 7, 10, 24, 27, 29, 31, 33, 43, 47]  # guide-line channel indices


for c in np.unique(clust_man_arr):
    print("Cluster:", c)
    bool_c = clust_man_arr == c
    spec_sub = spec_nn[bool_c, :]

    # spec_sub_meansub = spec_sub - specs_med
    # spec_sub_meansub[spec_sub_meansub < 0] = 0

    fig, ax = ip.general_plot(dims=spec_dims, col="w")
    color = dict_clust_col[c]
    fsi.plot_cell_spectra(ax, spec_sub, {"lw": 2, "alpha": 0.2, "color": color})

    ylim = ax.get_ylim()
    for x in xs:
        ax.plot([x, x], ylim, color=(0.5, 0.5, 0.5), lw=0.5)

    # ax.set_ylim(0,1000)
    # ax.set_ylim(0,2**16)
    plt.show()
    # Close explicitly so that non-inline backends do not accumulate figures.
    plt.close()

In [None]:
# NOTE(review): rebinds dict_lab_clust again, now to the manual cluster of
# the spot-nearest cells only. Cells that expect the taxonomic clusters (for
# example "Merge clusters" below) will see these ids instead. Confirm run order.
dict_lab_clust = dict(zip(labels_nn, clust_man))

In [None]:
# Invert dict_lab_clust into cluster id -> list of cell labels
# (labels keep the dict's insertion order).
dict_clust_lab = defaultdict(list)
for lab in dict_lab_clust:
    cl = dict_lab_clust[lab]
    dict_clust_lab[cl].append(lab)

In [None]:

# Paint every spot-nearest cell in the colour of its manual cluster.
hipr_seg_res_clustrgb = np.zeros(hipr_seg_res_edge.shape + (4,))
# scinames_toplot = ['Selenomonas', 'Lautropia']
# clusters_toplot = [dict_sciname_bc[sc] for sc in scinames_toplot]
for cl in np.unique(clust_man):
    # Set membership is O(1); the list test was O(n_labels) for every row.
    labels_sub = set(dict_clust_lab[cl])
    for i, row in hipr_prop_res.iterrows():
        l = row.label
        if l in labels_sub:
            # NOTE(review): eval() on a stored bbox string. Only safe for trusted input.
            b = row.bbox
            b = eval(b) if isinstance(b, str) else b
            rgb_sub = hipr_seg_res_clustrgb[b[0] : b[2], b[1] : b[3]]
            seg_sub = hipr_seg_res_edge[b[0] : b[2], b[1] : b[3]] == l
            color = np.array(dict_clust_col[cl] + (1,))  # RGB tuple + opaque alpha
            # Keep already-painted pixels outside this cell; overwrite inside it.
            rgb_noncell = rgb_sub * np.dstack([~seg_sub] * 4)
            rgb_cell = seg_sub[:, :, None] * color[None, :]
            hipr_seg_res_clustrgb[b[0] : b[2], b[1] : b[3], :] = rgb_noncell + rgb_cell

In [None]:
# Show the manual-cluster colouring with the MGE spot layer on top.
fig, ax, cbar = ip.plot_image(
    im_cell,
    cmap="gray",
    im_inches=30,
    clims=(0, 2000),
    scalebar_resolution=res_mge_umpix,
)
# fig, ax, cbar = ip.plot_image(
#     im_cell,
#     cmap="gray",
#     im_inches=imin,
#     clims=(0, 2000),
#     scalebar_resolution=res_mge_umpix,
# )
ax.imshow(hipr_seg_res_clustrgb)
ax.imshow(raw_mge_shift_spot_rgb)



Filter spots by shape params

In [None]:
# Placeholder cell; `a` is not used by the visible code.
a = 1

In [None]:
# Sorted spot eccentricities with the candidate threshold drawn as a
# horizontal line; spots below it are kept in the next cell.
thresh_ecc = 0.8

fig, ax = ip.general_plot(dims=(10,5))
ax.scatter(np.arange(prop.shape[0]), prop.eccentricity.sort_values().values)
xlims = ax.get_xlim()
ax.plot(xlims, [thresh_ecc]*2, color='k')

In [None]:
# Labels of spots with eccentricity below thresh_ecc (the rounder spots).
# squeeze() yields a 0-d array when exactly one spot passes.
spot_lab_sub = prop.loc[prop.eccentricity < thresh_ecc, 'label'].values.squeeze()
spot_lab_sub.shape

Project unfiltered spots onto image

In [None]:
# Paint every segmented MGE spot (no eccentricity filter) opaque magenta.
mega_seg_clustrgb = np.zeros(hipr_seg_res_edge.shape + (4,))
# scinames_toplot = ['Selenomonas', 'Lautropia']
# clusters_toplot = [dict_sciname_bc[sc] for sc in scinames_toplot]
for i, row in prop.iterrows():
    l = row.label
    # if l in spot_lab_sub:
    b = row.bbox
    if isinstance(b, str):
        b = eval(b)  # NOTE(review): only safe if prop comes from a trusted file
    r0, c0, r1, c1 = b[0], b[1], b[2], b[3]
    rgb_sub = mega_seg_clustrgb[r0:r1, c0:c1]
    seg_sub = im_seg[r0:r1, c0:c1] == l
    color = np.array((1, 0, 1, 1))
    # Keep what is already painted outside this spot; overwrite pixels inside it.
    keep_mask = np.dstack([~seg_sub] * 4)
    mega_seg_clustrgb[r0:r1, c0:c1, :] = rgb_sub * keep_mask + seg_sub[:, :, None] * color[None, :]

In [None]:
# Show all segmented spots over the grayscale cell image.
fig, ax, cbar = ip.plot_image(
    im_cell,
    cmap="gray",
    im_inches=30,
    clims=(0, 2000),
    scalebar_resolution=res_mge_umpix,
)
# fig, ax, cbar = ip.plot_image(
#     im_cell,
#     cmap="gray",
#     im_inches=imin,
#     clims=(0, 2000),
#     scalebar_resolution=res_mge_umpix,
# )
ax.imshow(mega_seg_clustrgb)
# ax.imshow(raw_mge_shift_spot_rgb)

Project filtered spots onto image

In [None]:
# Paint only the spots that passed the eccentricity filter, opaque magenta.
mega_seg_clustrgb = np.zeros(hipr_seg_res_edge.shape + (4,))
# scinames_toplot = ['Selenomonas', 'Lautropia']
# clusters_toplot = [dict_sciname_bc[sc] for sc in scinames_toplot]
# Set gives O(1) membership. np.ravel also handles the 0-d array that
# squeeze() produces when only one spot passes the filter.
spot_lab_set = set(np.ravel(spot_lab_sub))
color = np.array((1,0,1,1))  # constant magenta, opaque
for i, row in prop.iterrows():
    l = row.label
    if l in spot_lab_set:
        # NOTE(review): eval() on a stored bbox string. Only safe for trusted input.
        b = row.bbox
        b = eval(b) if isinstance(b, str) else b
        rgb_sub = mega_seg_clustrgb[b[0] : b[2], b[1] : b[3]]
        seg_sub = im_seg[b[0] : b[2], b[1] : b[3]] == l
        # Keep what is already painted outside this spot; overwrite pixels inside it.
        rgb_noncell = rgb_sub * np.dstack([~seg_sub] * 4)
        rgb_cell = seg_sub[:, :, None] * color[None, :]
        mega_seg_clustrgb[b[0] : b[2], b[1] : b[3], :] = rgb_noncell + rgb_cell

In [None]:
# Show only the eccentricity-filtered spots over the grayscale cell image.
fig, ax, cbar = ip.plot_image(
    im_cell,
    cmap="gray",
    im_inches=30,
    clims=(0, 2000),
    scalebar_resolution=res_mge_umpix,
)
# fig, ax, cbar = ip.plot_image(
#     im_cell,
#     cmap="gray",
#     im_inches=imin,
#     clims=(0, 2000),
#     scalebar_resolution=res_mge_umpix,
# )
ax.imshow(mega_seg_clustrgb)
# ax.imshow(raw_mge_shift_spot_rgb)

## Collect nearby pixels for each spot and inspect spectra

In [None]:
# (Disabled) enumerate every (row, col) pixel index of the padded image.
# shp = hipr_sum_res_edge.shape
# pix_inds = [[i,j] for i in range(shp[0]) for j in range(shp[1])]
# pix_inds[:10]

In [None]:
# Zero-pad the HiPRFISH image of sample m_h by `edge` pixels on each side of
# both spatial axes (the channel axis is not padded). The result is float64.
hipr_res_m = hipr_res[m_h]
shp = hipr_res_m.shape
shp_edge = [s + 2 * edge for s in shp[:2]] + [shp[2]]  # expected padded shape
hipr_res_m_edge = np.pad(
    hipr_res_m.astype(np.float64),
    pad_width=((edge, edge), (edge, edge), (0, 0)),
    mode="constant",
    constant_values=0,
)
hipr_res_m_edge.shape

In [None]:
# Spatially smooth every channel with a Gaussian of sigma = 3 px. A sigma of 0
# on the channel axis skips that axis, so channels are never mixed.
hipr_res_m_edge_gauss = sf.ndi.gaussian_filter(hipr_res_m_edge, sigma=(3, 3, 0))


In [None]:
# Half-width (pixels) of the square window sampled around each spot.
rng = 5

# For each spot, collect the smoothed spectra of all pixels in rows
# r-rng..r+rng-1 and cols c-rng..c+rng-1, flattened row-major to
# shape (n_pix, n_channels).
specs_near_pix = []
for s in spot_coords:
    r, c = [int(s_) for s_ in s]
    spec = hipr_res_m_edge_gauss[r - rng : r + rng, c - rng : c + rng, :]
    # A negative start wraps around and a window at the border is clipped. The
    # later pixel -> (row, col) mapping assumes a full window, so fail clearly
    # here instead of with a boolean-mask shape error or a silent misalignment.
    if r - rng < 0 or c - rng < 0 or spec.shape[:2] != (rng * 2, rng * 2):
        raise ValueError(
            f"Spot at {(r, c)} is within {rng} px of the image border; "
            "its sampling window would be clipped."
        )
    spec = spec.reshape(-1, spec.shape[-1])
    specs_near_pix.append(spec)
    


In [None]:

# Spectra of every pixel in the windows around the first 10 spots, one figure
# per spot, with grey guide lines at fixed channel indices.
for spec_sub in specs_near_pix[:10]:
    fig, ax = ip.general_plot(dims=spec_dims, col='w')
    fsi.plot_cell_spectra(ax, spec_sub, {'lw':1,'alpha':0.1,'color':'r'})

    ylim = ax.get_ylim()

    xs = [3, 7, 10, 24, 27, 29, 31, 33, 43, 47]
    for x in xs:
        ax.plot([x,x], ylim, color=(0.5,0.5,0.5), lw=0.5)

In [None]:
# Channel-summed intensity threshold for "bright" pixels.
sum_thresh = 18000

# Channel-summed intensity of every sampled pixel across all spots, sorted
# ascending and plotted against the threshold.
specs_near_pix_sums = np.sort(np.vstack(specs_near_pix).sum(axis=1))
fig, ax = ip.general_plot()
ax.scatter(np.arange(len(specs_near_pix_sums)), specs_near_pix_sums)
xlims = ax.get_xlim()
ax.plot(xlims, [sum_thresh] * 2, 'k')

In [None]:
# For each spot, keep only the pixel spectra whose channel sum exceeds sum_thresh.
specs_near_pix_high = [
    [s for s in sp if np.sum(s) > sum_thresh]
    for sp in specs_near_pix
]


In [None]:
# One figure per spot that has at least one bright pixel, showing those
# pixels' spectra.
for sp in specs_near_pix_high:
    if sp:
        fig, ax = ip.general_plot(dims=spec_dims, col='w')
        fsi.plot_cell_spectra(ax, np.array(sp), {'lw':2,'alpha':0.1,'color':'r'})

        ylim = ax.get_ylim()

        xs = [3, 7, 10, 24, 27, 29, 31, 33, 43, 47]
        for x in xs:
            ax.plot([x,x], ylim, color=(0.5,0.5,0.5), lw=0.5)
        plt.show()
        plt.close()

In [None]:
# Like the filtered version, but below-threshold pixels are replaced by an
# empty list instead of dropped. Each spot's list therefore keeps one entry per
# window pixel, so pixel position k is preserved.
specs_near_pix_high = [
    [s if np.sum(s) > sum_thresh else [] for s in sp]
    for sp in specs_near_pix
]

In [None]:
# Map each above-threshold pixel in every spot's window back to image
# coordinates. Windows are flattened row-major over rows r-rng..r+rng-1 and
# cols c-rng..c+rng-1 (with int() spot coordinates, as in the window
# extraction), so pixel k sits at (int(sc[0]) - rng + i, int(sc[1]) - rng + j).
high_pix_coords = []
for sc, sp in zip(spot_coords, specs_near_pix_high):
    coords_ = []
    k = 0
    for i in range(rng*2):
        i_ = int(sc[0]) - rng + i
        for j in range(rng*2):
            j_ = int(sc[1]) - rng + j
            # Below-threshold pixels were stored as []. Test the length, because
            # the truthiness of a multi-element ndarray is ambiguous (ValueError).
            if len(sp[k]) > 0:
                # TODO: the original cell left this branch empty. Collecting
                # the coordinates is a placeholder for the intended use.
                coords_.append((i_, j_))
            k += 1
    high_pix_coords.append(coords_)


## For each spot, look at shape and spectrum of nearest cell


In [None]:
# Debug: display `b`, the bounding box left over from the last painting loop.
b

In [None]:
# For every spot passing the area / in-cell filters, show a (2*adj)-px crop of
# the MGE image around its centroid, then the spectrum of its nearest cell.
adj = 20
clim = (50, 300)
ymax = 1000

# bool_area * bool_incell: element-wise AND of two masks defined elsewhere
# (presumably boolean Series aligned with prop).
for i, row in prop[bool_area * bool_incell].iterrows():
    # Get spot centroid (row, col) as integers
    b = row.centroid
    b = eval(b) if isinstance(b, str) else b
    b = [int(b_) for b_ in b]
    # extract image
    # NOTE: near the image border b - adj goes negative and the slice wraps.
    raw_sub = raw_mge_shift_spot[b[0] - adj : b[0] + adj, b[1] - adj : b[1] + adj]
    # plot image
    ip.plot_image(raw_sub, cmap="inferno", axes_off=False, clims=clim)
    plt.show()
    plt.close()
    # plot nearest spectrum
    # NOTE(review): `i` is prop's index label, used here as a row position in
    # spec_nn. This is only correct if spec_nn rows follow prop's index. Confirm.
    spec = spec_nn[i]
    fig, ax = ip.general_plot(dims=(10, 5), col="w", ft=12)
    ax.set_ylim(0, ymax)
    ylim = ax.get_ylim()
    xs = [3, 7, 10, 24, 27, 29, 31, 33, 43, 47]
    for x in xs:
        ax.plot([x, x], ylim, color=(0.5, 0.5, 0.5), lw=0.5)

    fsi.plot_cell_spectra(ax, spec[None, :], {"lw": 1, "alpha": 1, "color": "r"})
    plt.show()
    plt.close()

### Merge clusters

In [None]:
# Display the highest cluster id currently in dict_lab_clust.
np.max(list(dict_lab_clust.values()))

In [None]:
# Shallow copy so merging does not modify dict_lab_clust itself.
dict_lab_clust_new = dict_lab_clust.copy()

In [None]:
# Merge each group of cluster ids into one new id, placed above every id
# already in use.
# NOTE(review): these ids and dict_clust_lab must refer to the same
# clustering. dict_clust_lab is rebound in the agglomerative and manual
# clustering cells, so confirm run order.
clusters_tomerge = [[8,11,13,14,15], [2,0,9,3,1,6,12],[10,4,17], [7]]
labels_cl_new = []  # unused in the visible code
# Allocate new ids from a counter. Recomputing max()+1 for each group would give
# two groups the same id whenever a group contributed no labels.
clnew = np.max(list(dict_lab_clust_new.values()))
for clusts in clusters_tomerge:
    clnew += 1
    for cl in clusts:
        for l in dict_clust_lab[cl]:
            dict_lab_clust_new[l] = clnew


In [None]:
# Regroup the per-label spectra under their (merged) cluster ids.
dict_cl_spec_new = defaultdict(list)
for lab, clust in dict_lab_clust_new.items():
    dict_cl_spec_new[clust].append(dict_lab_spec[lab])

Plot spectra

In [None]:
# Merged cluster ids, in first-seen order.
clusters_unq_new = list(dict_cl_spec_new)

In [None]:
# Pair each merged cluster with a colour from col_list_re (zip stops at the
# shorter of the two sequences).
dict_bc_col_new = {clust: colour for clust, colour in zip(clusters_unq_new, col_list_re)}

In [None]:
# Fixed x positions at which grey vertical guide lines are drawn.
guide_xs = [3, 7, 10, 24, 27, 29, 31, 33, 43, 47]

# One figure per merged cluster showing all member spectra, semi-transparent.
for cl, member_specs in dict_cl_spec_new.items():
    print('Cluster:', cl)
    fig, ax = ip.general_plot(dims=spec_dims, col='w')
    style = {'lw': 1, 'alpha': 0.2, 'color': dict_bc_col_new[cl]}
    fsi.plot_cell_spectra(ax, np.array(member_specs), style)

    # ax.set_ylim(0,2000)
    y_lo, y_hi = ax.get_ylim()
    for gx in guide_xs:
        ax.plot([gx, gx], [y_lo, y_hi], color=(0.5, 0.5, 0.5), lw=0.5)

    plt.show()
    plt.close()

### Redo spatial association

Get measured association

In [None]:
# Nearest re-segmented cell for every detected spot.
# NOTE(review): despite the "Redo spatial association" heading, this uses the
# pre-merge dict_lab_clust / clusters_unq, not dict_lab_clust_new — confirm intended.
n_neighbors = 1
reseg_coords = [list(c) for c in hipr_prop_res.centroid.values]
nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(reseg_coords)
dists, inds = nbrs.kneighbors(spot_coords)
# Map neighbour indices -> cell labels -> cluster ids. np.unique per spot means
# that with n_neighbors > 1 a spot near several cells of one cluster counts once.
labels = hipr_prop_res.label.values
labels_nn = [labels[idx] for idx in inds]
clusts_nn = [[dict_lab_clust[l] for l in ls] for ls in labels_nn]
clusts_nn_unq = [np.unique(cl) for cl in clusts_nn]
# Number of spots associated with each cluster
clusts_nn_unq_flat = [cl for cln in clusts_nn_unq for cl in cln]
cl_nn_unq, cl_nn_counts = np.unique(clusts_nn_unq_flat, return_counts=True)
# Align counts with clusters_unq; clusters that are never a nearest neighbour get 0.
# (dict.get replaces a bare `except: pass` that also hid unrelated errors.)
dict_cl_nncounts = dict(zip(cl_nn_unq, cl_nn_counts))
meas_counts_arr = np.array(
    [dict_cl_nncounts.get(cl, 0) for cl in clusters_unq], dtype=float
)

Simulate random spots and count nearest neighbor values

In [None]:
n = 1000  # number of simulated random spot placements

cell_coords_tup = hipr_prop_res.loc[:, "centroid"].values
cell_coords = np.array([list(c) for c in cell_coords_tup])  # not used in this cell
# cell_coords = np.array([list(eval(c)) for c in cell_coords_tup])
# Candidate spot positions: every pixel where mask_cell is nonzero.
pix_coords = np.argwhere(mask_cell)

dict_cl_dists_sim = defaultdict(list)  # not used in this cell
sim_counts_arr = np.zeros((n, len(clusters_unq)))
# NOTE(review): np.random is unseeded, so results differ between runs; consider
# seeding for reproducible figures.
for i in tqdm(range(n)):
    # Draw the same number of spots as measured, uniformly (with replacement)
    # from the cell-mask pixels.
    i_sim = np.random.randint(0, pix_coords.shape[0], size=len(spot_coords))
    sim_spot_coords = pix_coords[i_sim]
    # Nearest cell for each simulated spot (same neighbour model as measured)
    dists_sim, inds_sim = nbrs.kneighbors(sim_spot_coords)
    # Cluster ids of neighbours, each cluster counted once per spot
    labels_nn_sim = [labels[idx] for idx in inds_sim]
    clusts_nn_sim = [[dict_lab_clust[l] for l in ls] for ls in labels_nn_sim]
    clusts_nn_unq_sim = [np.unique(cl) for cl in clusts_nn_sim]
    # Count spots per cluster
    clusts_nn_unq_flat_sim = [cl for cln in clusts_nn_unq_sim for cl in cln]
    cl_nn_unq_sim, cl_nn_counts_sim = np.unique(clusts_nn_unq_flat_sim, return_counts=True)
    # Store counts aligned with clusters_unq; absent clusters get 0.
    # (dict.get replaces a bare `except: pass` that also hid unrelated errors.)
    dict_cl_nncounts_sim = dict(zip(cl_nn_unq_sim, cl_nn_counts_sim))
    sim_counts_arr[i] = [dict_cl_nncounts_sim.get(cl, 0) for cl in clusters_unq]


    

Plot Z scores 

In [None]:
# Z-score the simulated and measured per-cluster spot counts against the
# simulated null distribution (mean/std over simulations, per cluster).
mu = np.mean(sim_counts_arr, axis=0)
sig = np.std(sim_counts_arr, axis=0)
# A cluster whose simulated count never varies (e.g. never nearest to any random
# spot) has sig == 0 and an undefined z-score. Leave those entries as NaN rather
# than dividing by zero, which emits RuntimeWarnings and yields inf/NaN.
# Check such clusters via their raw counts (meas_counts_arr).
valid = sig > 0
sim_z = np.full(sim_counts_arr.shape, np.nan)
np.divide(sim_counts_arr - mu, sig, out=sim_z, where=valid)
meas_z = np.full(np.shape(meas_counts_arr), np.nan)
np.divide(meas_counts_arr - mu, sig, out=meas_z, where=valid)

In [None]:
# Plot z score number of spots associated with group
# Grey boxes: distribution of simulated z-scores per cluster (outliers hidden).
# Coloured markers: measured z-score for each cluster, in clusters_unq order.
dims=[5,2]
# dims = [2.5, 1]
xlab_rotation = 45
pval_rotation = 60  # only used by the commented-out p-value annotation below
marker = "."
marker_size = 10
text_dist = 0.1  # only used by the commented-out p-value annotation below
ft = 12
# ft = 7
ylimadj = 0.1  # only used by the commented-out y-limit adjustment below
true_frac_llim = 0  # only used by the commented-out p-value annotation below
line_col = "k"
box_line_col = (0.5, 0.5, 0.5)
box_col = "w"
yticklength = 2  # only used by the commented-out tick_params call below

fig, ax = ip.general_plot(dims=dims, ft=ft, col=line_col)
# Plot simulation
# One box per column of sim_z (one column per cluster); with default positions
# matplotlib places the boxes at x = 1..N.
boxplot = ax.boxplot(
    sim_z,
    patch_artist=True,
    showfliers=False,
    boxprops=dict(facecolor=box_col, color=box_line_col),
    capprops=dict(color=box_line_col),
    whiskerprops=dict(color=box_line_col),
    medianprops=dict(color=box_line_col),
)
# for m in boxplot['medians']:
#     m.set_color(line_col)
# for b in boxplot['boxes']:
#     b.set_edgecolor(line_col)
#     b.set_facecolor(box_col)

# Plot measured value
# x starts at 1 so each marker lines up with its cluster's box.
ys = []
xlab = []
x = 1
for i, cl in enumerate(clusters_unq):
    # for i, bc_tax in zip(ind_order, barcodes_int_order):
    # sci_name = dict_bc_sciname[bc_tax]
    # sci_name = dict_bc_sciname[cl]
    xlab.append(cl)
    # try:
    #     color = col_dict[sci_name]
    # except:
    #     continue
    color = dict_bc_col[cl]
    # Despite its name, true_frac holds the measured z-score here.
    true_frac = meas_z[i]
    # true_frac = true_count / n_cells
    _ = ax.plot(x, true_frac, marker=marker, ms=marker_size, color=color)
    # NOTE(review): the disabled p-value code below is broken: `sim_vals` is
    # never defined (only `sim_vals_i`), and it compares raw simulated counts
    # against the z-score `true_frac`. Compare like with like (e.g.
    # sim_counts_arr[:, i] vs meas_counts_arr[i]) before re-enabling.
    # # Plot p value
    # sim_vals_i = sim_counts_arr[:, i]
    # # sim_vals = sim_arr[:,i,h] / n_cells
    # sim_mean = np.mean(sim_vals)
    # if true_frac > sim_mean:
    #     # number of simulations with value greater than observed
    #     r_ = sum(sim_vals_i > true_frac)
    # else:
    #     # number of simulations with value less than observed
    #     r_ = sum(sim_vals_i < true_frac)
    # # P value
    # p_ = r_ / n
    # # Get text location
    # q1, q3 = np.quantile(sim_vals, [0.25, 0.75])
    # q4 = q3 + 1.5 * (q3 - q1)
    # # y_m = np.max(sim_vals)
    # # y = y_m if y_m > true_frac else true_frac
    # y = q4 if q4 > true_frac else true_frac
    # y += text_dist
    # ys.append(y)
    # if true_frac < true_frac_llim:
    #     t = ''
    # elif (p_ > 0.05):
    #     t = ''
    # elif (p_ > 0.001) and (p_ <= 0.05):
    #     t = str("p=" + str(p_))
    # else:
    #     t = str("p<0.001")
    # _ = ax.text(x, y, t, fontsize=ft, ha='left',va='bottom', rotation=pval_rotation, rotation_mode='anchor',
    #         color=line_col)
    x += 1
# ax.set_xticklabels([], rotation=xlab_rotation, ha='right', va='top', rotation_mode='anchor')
# Label the boxplot's existing ticks (1..N) with the cluster ids.
ax.set_xticklabels(xlab, rotation=xlab_rotation, ha='right', va='top', rotation_mode='anchor')
# ax.tick_params(axis='x',direction='out')
# ax.set_xticks([])
# ax.tick_params(axis="y", length=yticklength)
# ax.set_yticks(ticks=[-10,0,10,20], labels=[])
ax.spines["top"].set_color("none")
ax.spines["bottom"].set_color("none")
ax.spines["right"].set_color("none")

# ylims = ax.get_ylim()
# ax.set_ylim(ylims[0], np.max(ys) + ylimadj)

# mge_assoc_dir = output_dir + "/mge_association"
# if not os.path.exists(mge_assoc_dir):
#     os.makedirs(mge_assoc_dir)
#     print("Made dir:", mge_assoc_dir)
# out_bn = mge_assoc_dir + '/' + bn.format(m_h) + '_seg_nn_zscore_association_0_5um'
# ip.save_png_pdf(out_bn)

In [None]:
# Cluster ids whose member spectra are plotted in the "Plot spectra" cell below.
# That cell looks them up in dict_cl_spec / dict_bc_col, so these are presumably
# the original (pre-merge) cluster ids, not the merged ids — confirm.
clusters_toplot = [4,10,11,13,14,15]

Plot spot fraction 

In [None]:
# Fraction of all detected spots whose nearest cell belongs to each cluster.
meas_spotfrac_arr = meas_counts_arr / len(spot_coords)

# Figure styling
dims = [5, 2]
# dims=[2.5,0.6]
ft = 12
# ft=6
line_col = 'k'
width = 0.4
yticklength = 2

# Bars follow clusters_unq order; color_order is reused by later bar plots.
sci_name_order = clusters_unq
# sci_name_order = [dict_bc_sciname[bc] for bc in barcodes_int_order]
color_order = list(map(dict_bc_col.__getitem__, sci_name_order))

fig, ax = ip.general_plot(dims=dims, ft=ft, col=line_col)
bar_x = np.arange(len(meas_spotfrac_arr))
ax.bar(bar_x, meas_spotfrac_arr, width=width, color=color_order, edgecolor=line_col)

for side in ('top', 'right'):
    ax.spines[side].set_color('none')
ax.set_xticks([])
# ax.set_yticks(ticks=[0,0.2,0.4], labels=[])
ax.tick_params(axis='y', length=yticklength)

# mge_assoc_dir = output_dir + '/mge_association'
# if not os.path.exists(mge_assoc_dir):
#     os.makedirs(mge_assoc_dir)
#     print('Made dir:',mge_assoc_dir)
# out_bn = mge_assoc_dir + '/' + bn.format(m_h) + '_bar_seg_nn_frac_spot_association_0_5um'
# ip.save_png_pdf(out_bn)

Plot the fraction of each cluster's cells associated with a spot

In [None]:
# Per cluster, count the distinct cells that are the nearest neighbour of at
# least one spot (a cell nearest to several spots is counted once).
labels_nn_flat = [l for ls in labels_nn for l in ls]
labels_nn_unq = np.unique(labels_nn_flat)
clusts_labels_nn_unq = [dict_lab_clust[l] for l in labels_nn_unq]
cl_nn_unq2, cl_nn_counts2 = np.unique(clusts_labels_nn_unq, return_counts=True)
dict_cl_nncounts2 = dict(zip(cl_nn_unq2, cl_nn_counts2))
# Align with clusters_unq; clusters with no spot-adjacent cell get 0.
# (dict.get replaces a bare `except: pass` that also hid unrelated errors.)
meas_counts_arr2 = np.array(
    [dict_cl_nncounts2.get(cl, 0) for cl in clusters_unq], dtype=float
)
print(meas_counts_arr2)

In [None]:
# Total number of cells (labels) in each cluster, aligned with clusters_unq.
cl_all, cl_counts = np.unique(list(dict_lab_clust.values()), return_counts=True)
dict_cl_counts_all = dict(zip(cl_all, cl_counts))
# Clusters absent from dict_lab_clust get 0.
# (dict.get replaces a bare `except: pass` that also hid unrelated errors.)
all_counts_arr = np.array(
    [dict_cl_counts_all.get(cl, 0) for cl in clusters_unq], dtype=float
)
print(all_counts_arr)

In [None]:
# Fraction of each cluster's cells that are the nearest neighbour of a spot.
cl_fracnearspots = np.divide(meas_counts_arr2, all_counts_arr)

# Figure styling
width = 0.4
line_col = 'k'
ft = 12
# ft=6
yticklength = 2
dims = [5, 2]
# dims=[2.5,0.6]

# Bars are coloured with color_order from the spot-fraction cell above
# (clusters_unq order).
fig, ax = ip.general_plot(dims=dims, ft=ft, col=line_col)
bar_x = np.arange(len(cl_fracnearspots))
ax.bar(bar_x, cl_fracnearspots, width=width, color=color_order, edgecolor=line_col)
ax.set_xticks([])
# ax.set_yticks(ticks=[0,0.3,0.6], labels=[])
for side in ('top', 'right'):
    ax.spines[side].set_color('none')
ax.tick_params(axis='y', length=yticklength)

# mge_assoc_dir = output_dir + '/mge_association'
# if not os.path.exists(mge_assoc_dir):
#     os.makedirs(mge_assoc_dir)
#     print('Made dir:',mge_assoc_dir)
# out_bn = mge_assoc_dir + '/' + bn.format(m_h) + '_bar_seg_nn_frac_cell_association_0_5um'
# ip.save_png_pdf(out_bn)

Plot spectra

In [None]:
# Fixed x positions at which grey vertical guide lines are drawn.
guide_xs = [3, 7, 10, 24, 27, 29, 31, 33, 43, 47]

# One figure per selected cluster showing all member spectra, semi-transparent.
for cl in clusters_toplot:
    print('Cluster:', cl)
    member_specs = np.array(dict_cl_spec[cl])
    fig, ax = ip.general_plot(dims=spec_dims, col='w')
    style = {'lw': 1, 'alpha': 0.2, 'color': dict_bc_col[cl]}
    fsi.plot_cell_spectra(ax, member_specs, style)

    # ax.set_ylim(0,2000)
    y_lo, y_hi = ax.get_ylim()
    for gx in guide_xs:
        ax.plot([gx, gx], [y_lo, y_hi], color=(0.5, 0.5, 0.5), lw=0.5)

    plt.show()
    plt.close()

Plot clusters on overlay

Plot assoc clusters on overlay