# Measure association between MGE spots and spectral groups 
## Setup

In [None]:
import gc
import glob
import math
import os
import re
import sys
from collections import defaultdict

import aicspylibczi as aplc
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import umap
import yaml
from cv2 import resize, INTER_NEAREST
from scipy.cluster import hierarchy
from scipy.spatial.distance import squareform
from skimage.registration import phase_cross_correlation
from skimage.restoration import richardson_lucy
from sklearn.cluster import AgglomerativeClustering
from sklearn.neighbors import NearestNeighbors
from tqdm import tqdm



In [None]:
# Optional path prefix prepended to workdir (empty string means no prefix)
cluster = ''
workdir = '/workdir/bmg224/manuscripts/mgefish/code/bmg_plasmids_imaging/agglomerative_clustering'
# Change directory so the relative paths used below resolve against workdir
os.chdir(cluster + workdir)
os.getcwd()

In [None]:
config_fn = 'config_matrix_classify.yaml' # relative path to config file from workdir

# Parse the YAML configuration; `config` is indexed by key throughout the notebook
with open(config_fn, 'r') as f:
    config = yaml.safe_load(f)

In [None]:
%load_ext autoreload
%autoreload 2

sys.path.append(config['pipeline_path'] + '/' + config['functions_path'])
import fn_general_use as fgu
import image_plots as ip
import segmentation_func as sf
import fn_hiprfish_classifier as fhc
import fn_spectral_images as fsi



In [None]:
# Get filenames from directories
raw_dir = config["data_dir"] + "/*[y0-9].czi"
# raw_dir = config["data_dir"] + "/*" + config["laser_regex"]
fns = glob.glob(raw_dir)
fns_base = [os.path.split(f)[1] for f in fns]
# Raw string: "\w" inside a plain literal is an invalid escape sequence.
# The dot is escaped so it only matches the literal extension separator.
group_regex = re.compile(r"_2023\w+\.czi")
# Group name of each file = its basename with the "_2023....czi" tail removed
fns_group = [group_regex.sub("", s) for s in fns_base]
group_names = np.sort(np.unique(fns_group))
m_size = group_names.shape[0]
# Assign every file to the group derived from its own basename. A substring
# test against the full path could also match a group whose name happens to
# be contained in another file's path.
dict_group_czifns_all = defaultdict(list)
for g, s in zip(fns_group, fns):
    dict_group_czifns_all[g].append(s)
# Keys follow the sorted group order; file lists are sorted alphabetically
dict_group_czifns_all = {g: sorted(dict_group_czifns_all[g]) for g in group_names}
dict_group_czifns_all

In [None]:
# Keep only the first three files (in sorted order) of every group
dict_group_czifns = {
    group: [czi_fns[i] for i in range(3)]
    for group, czi_fns in dict_group_czifns_all.items()
}
dict_group_czifns

In [None]:
# Glob pattern for files in data_dir whose name ends in "ing.czi"
ecoli_raw_dir = config["data_dir"] + "/*ing.czi"

ecoli_fns = glob.glob(ecoli_raw_dir)
ecoli_fns


## Get PSF
Load the E. coli image

In [None]:
# Open the second matched file (index 1 is hard-coded; glob order is not guaranteed)
czi_ecoli = aplc.CziFile(ecoli_fns[1])
czi_ecoli.get_dims_shape()

In [None]:
def reshape_aics_image(m_img):
    '''
    Drop all length-1 axes from an AICS image array and move the leading
    axis of the remaining 3-D array to the end.

    The squeezed array must be 3-D; its axes (0, 1, 2) are returned in the
    order (1, 2, 0), i.e. a channel-first array becomes channel-last.
    '''
    return np.squeeze(m_img).transpose(1, 2, 0)

In [None]:
# Number of tiles (M index) read from the E. coli file
nMs_e = 4
raws_ec = []
for m_s in range(nMs_e):
    # read_image returns the pixel data plus a second value (`sh`) that is unused here
    im, sh = czi_ecoli.read_image(M=m_s)
    im = reshape_aics_image(im)
    raws_ec.append(im)
[r.shape for r in raws_ec]

Threshold

In [None]:
# Plot spots
subpl=(2,2)  # subplot grid handed to ip.subplot_square_images
clims=[(50,500)]*nMs_e  # identical color limits for every tile
imin=20  # forwarded as im_inches

# Collect channel index 1 of every tile; `ims` is reused by the following cells
ims = []
for m_h in range(nMs_e):
    ims.append(raws_ec[m_h][:,:,1])
ip.subplot_square_images(ims, subpl, clims=clims, im_inches=imin)

In [None]:
# Get mask
# One mask per tile from sf.get_background_mask with the same settings for all tiles
ims_mask = [sf.get_background_mask(
    im,
    bg_smoothing=0,
    n_clust_bg=4,
    top_n_clust_bg=1,
    bg_threshold=200
    ) for im in ims]

# check mask
# Show the masked images, then draw a half-intensity 4-channel overlay of the
# inverted mask on each axis
plot_ims = [im_pre*im_mask for im_pre, im_mask in zip(ims, ims_mask)]
fig, axs, cbars = ip.subplot_square_images(plot_ims, subpl, clims=clims, im_inches=imin)
for ax, im_mask in zip(fig.axes, ims_mask):
    ax.imshow(np.dstack([0.5*(~im_mask)]*4))

Get peaks

In [None]:
# Peak coordinates of the masked image, one array per tile.
# NOTE(review): presumably sf.peak_local_max is scikit-image's peak_local_max;
# its `indices` keyword was removed in scikit-image 0.20 — confirm against the
# installed version.
plms = []
for m in range(nMs_e):
    im = ims[m]
    mask = ims_mask[m]
    plms.append(sf.peak_local_max(im*mask, min_distance=1, indices=True))
[len(plm) for plm in plms]

In [None]:
# Image intensity at every peak, pooled over all tiles, plotted in ascending order
plm_int = []
for m in range(nMs_e):
    im = ims[m]
    plm = plms[m]
    plm_int += [im[p[0],p[1]] for p in plm]
plm_int_sort = np.sort(plm_int)
fig, ax = ip.general_plot(dims=(10,5))
ax.scatter(np.arange(len(plm_int_sort)), plm_int_sort)

In [None]:
# Half-width (pixels) of the window cut out around every peak
rng = 18
plm_lines = []
plm_areas = []
for m in range(nMs_e):
    im = ims[m]
    plm = plms[m]
    for p in plm:
        # The window spans [p - rng, p + rng] on both axes, so it fits as long as
        # p - rng >= 0 and p + rng < shape. A strict `> 0` on the lower bound
        # would discard windows that start exactly at index 0.
        fits_rows = (p[0]-rng >= 0) and (p[0]+rng < im.shape[0])
        fits_cols = (p[1]-rng >= 0) and (p[1]+rng < im.shape[1])
        if fits_rows and fits_cols:
            # Column profile through the peak (length 2*rng+1)
            plm_lines.append(im[p[0]-rng:p[0]+rng+1, p[1]])
            # Square (2*rng+1, 2*rng+1) window centred on the peak
            plm_areas.append(im[p[0]-rng:p[0]+rng+1, p[1]-rng:p[1]+rng+1])
# Rows = peaks for the profiles; third axis = peaks for the windows
plm_lines = np.vstack(plm_lines)
plm_areas = np.dstack(plm_areas)
plm_lines.shape

Filter peaks

In [None]:
# Radius (pixels) used below to separate the peak from its surroundings
filt_rad = 14
# Intensity limit applied below to the pixels outside that radius
filt_edgeint = 110

fig, ax = ip.general_plot(dims=(10,5))
fsi.plot_cell_spectra(ax, plm_lines, {'lw':1,'alpha':0.1,'color':'r'})
# Vertical guides at rng+1 -/+ filt_rad and a horizontal guide at filt_edgeint.
# NOTE(review): the window centre is at index rng (0-based); the +1 presumably
# matches a 1-based x axis in fsi.plot_cell_spectra — confirm.
ax.plot([rng+1-filt_rad]*2,[0,500],'k')
ax.plot([rng+1+filt_rad]*2,[0,500],'k')
ax.plot([0,rng*2+1],[filt_edgeint]*2,'k')

plt.show()
plt.close()

In [None]:
# Central row (index rng) of every window, transposed so each row is one peak
plm_ar_line = plm_areas[rng,:].T
plm_ar_line.shape

In [None]:

# Same plot as for plm_lines, but using the central row of each window
fig, ax = ip.general_plot(dims=(10,5))
fsi.plot_cell_spectra(ax, plm_ar_line, {'lw':1,'alpha':0.1,'color':'r'})

# Guides: filter radius (vertical) and edge intensity limit (horizontal)
ax.plot([rng+1-filt_rad]*2,[0,500],'k')
ax.plot([rng+1+filt_rad]*2,[0,500],'k')
ax.plot([0,rng*2+1],[filt_edgeint]*2,'k')

In [None]:
def create_circle_array(size, radius):
    """Return a (size, size) boolean array that is True inside a centred disc.

    An element is True when its Euclidean distance from the array centre,
    (size - 1) / 2 on both axes, is less than or equal to `radius`.
    """
    disc = np.zeros((size, size), dtype=bool)

    # Offset of every row/column index from the centre of the array
    offsets = np.arange(size) - (size - 1) / 2

    # Distance of each element from the centre, built by broadcasting the
    # row offsets against the column offsets
    dist = np.sqrt(offsets[:, None]**2 + offsets[None, :]**2)

    disc[dist <= radius] = True
    return disc

# Example usage: create a 5x5 array with a circle of radius 2
result = create_circle_array(5, 2)
print(result)

In [None]:
# Get values in a ring around the peaks
# bool_circ is True for window pixels farther than filt_rad from the centre
bool_circ = ~create_circle_array(2*rng+1, filt_rad)
plt.imshow(bool_circ)
# Boolean indexing on the first two axes: rows = selected pixels, columns = peaks
plm_circ_vals = plm_areas[bool_circ,:]
plm_circ_vals.shape

In [None]:
# Keep a peak only if every pixel outside the filter radius is below filt_edgeint
bool_plm = np.max(plm_circ_vals, axis=0) < filt_edgeint
plm_lines_filt = plm_lines[bool_plm,:]
plm_areas_filt = plm_areas[:,:,bool_plm]
print(plm_lines_filt.shape)
print(plm_areas_filt.shape)

In [None]:
# Profiles of the peaks that passed the filter, with the same guide lines as before
fig, ax = ip.general_plot(dims=(10,5))
fsi.plot_cell_spectra(ax, plm_lines_filt, {'lw':1,'alpha':0.1,'color':'r'})

ax.plot([rng+1-filt_rad]*2,[0,500],'k')
ax.plot([rng+1+filt_rad]*2,[0,500],'k')
ax.plot([0,rng*2+1],[filt_edgeint]*2,'k')

Get PSF

In [None]:
# Mean profile over the filtered peaks, drawn in black on top of the individual profiles
plm_lines_mean = np.mean(plm_lines_filt, axis=0)
fig, ax = ip.general_plot(dims=(10,5))
fsi.plot_cell_spectra(ax, plm_lines_filt, {'lw':1,'alpha':0.1,'color':'r'})
fsi.plot_cell_spectra(ax, plm_lines_mean[None,:], {'lw':1,'alpha':1,'color':'k'})

In [None]:
# PSF estimate: pixel-wise mean of the filtered peak windows
psf = np.mean(plm_areas_filt, axis=2)
mx = np.max(psf)
mn = np.min(psf)
# Min-max rescaling of the PSF to the range [0, 1]
psf_norm = (psf - mn) / (mx - mn) 

# Note: the un-normalized psf is what gets plotted here
ip.plot_image(psf, cmap='inferno', im_inches=10)

Test PSF


In [None]:
# Test crop: the top-left 500x500 pixels of the first tile
im_test = ims[0][:500,:500]
ip.plot_image(im_test, cmap='inferno', im_inches=10)

In [None]:
# ims_norm = []
# for im in ims:
#     mx = np.max(im)
#     mn = np.min(im)
#     ims_norm.append((im - mn) / (mx - mn))
# ims_deconv = []
# for im in tqdm(ims):
#     ims_deconv.append(richardson_lucy(im, psf))

# Min-max rescale the test crop to [0, 1] before deconvolution
mx = np.max(im_test)
mn = np.min(im_test)
im_test_norm = (im_test - mn) / (mx - mn)

# Richardson-Lucy deconvolution of the rescaled crop with the rescaled PSF
im_test_deconv = richardson_lucy(im_test_norm, psf_norm)


In [None]:
# ims_deconv_edge = [im[100:-100,100:-100] for im in ims_deconv]

# [[np.min(im),np.max(im), np.mean(im), np.std(im)] for im in ims_deconv_edge]
# Summary statistics (min, max, mean, std) of the deconvolved crop
np.min(im_test_deconv),np.max(im_test_deconv), np.mean(im_test_deconv), np.std(im_test_deconv)

In [None]:
# Color limits given as the strings 'min'/'max' (interpreted by ip.plot_image)
clims=['min','max']


ip.plot_image(im_test_deconv, clims=clims, cmap='inferno', im_inches=10)
# ip.plot_image(ims_deconv_edge[0], clims=clims, cmap='inferno',im_inches=imin)
# fig, ax, cbar = ip.subplot_square_images(ims_deconv_edge, subpl, clims=clims, im_inches=imin)

## Load data

Pick an image

In [None]:
# Sample name; used below to build the output paths and to index the per-sample dictionaries
sn = "2023_11_22_newplasmid_sample_bmg_fov_04"

In [None]:
# Per-sample output directory and its sub-directories
output_dir = config['output_dir'] + '/' + sn

stack_dir = output_dir + '/stacks'
props_dir = output_dir + '/props'
segs_dir = output_dir + '/segs'
spec_dir = output_dir + '/spectra'
clust_dir = output_dir + '/clust'
mge_shifts_dir = output_dir + '/mge_shifts'


In [None]:
# Filename templates; the '{}' placeholder is filled with a tile index (or '*' for globbing)
bn = sn + '_M_{}'
stack_fn = stack_dir + '/' + bn + '_stack.npy'
props_fn = props_dir + '/' + bn + '_props.csv'
seg_fn = segs_dir + '/' + bn + '_seg.npy'
spec_fn = spec_dir + '/' + bn + '_spec.yaml'
clust_fn = clust_dir + '/' + bn + '_clust.yaml'
mge_shift_fn = mge_shifts_dir + '/' + bn + '_mge_shift.npy'


In [None]:
# All shifted-MGE files of this sample, any tile index (glob order is not guaranteed)
mge_shift_fns = glob.glob(mge_shift_fn.format('*'))
mge_shift_fns

In [None]:
Ms_hipr = [re.findall('(?<=_M_)\d+', f)[0] for f in mge_shift_fns]
print(Ms_hipr)
nMs = len(Ms_hipr)

In [None]:
raws_mge_shift = [np.load(fn) for fn in mge_shift_fns]

In [None]:
raws_mge_shift = [x for _, x in sorted(zip(Ms_hipr, raws_mge_shift))]

## Get MGE spots

In [None]:
# Plot spots
subpl=(2,2)
clims=(50,500)
# clims=[(50,500)]*nMs
imin=30

# Channel index 1 of every shifted tile; `ims` replaces the E. coli list built earlier
ims = []
for m_h in range(nMs):
    im = raws_mge_shift[m_h][:,:,1]
    ims.append(im)
    ip.plot_image(im, clims=clims, im_inches=imin)
# ip.subplot_square_images(ims, subpl, clims=clims, im_inches=imin)
# ip.plot_image(im, cmap='inferno', im_inches=imin, clims=(0,750))

In [None]:
# # pre-process
# ims_pre = [sf.pre_process(
#     im,
#     log=False,
#     denoise=0,
#     gauss=0,
#     diff_gauss=(0,)
#     ) for im in ims]
# # check pre-processing
# ip.subplot_square_images(ims_pre, subpl, clims=clims, im_inches=imin)

# Run sf.pre_process on every tile with log off and all other settings at 0, then plot the result
ims_pre = []
for im in ims:
    im_pre = sf.pre_process(
        im,
        log=False,
        denoise=0,
        gauss=0,
        diff_gauss=(0,)
    )
    ims_pre.append(im_pre)
    ip.plot_image(im_pre, clims=clims, im_inches=imin)

# ip.plot_image(im_pre, cmap='inferno', im_inches=imin, clims=(0,750))


In [None]:
# Mask for every tile, computed on the raw image (not on ims_pre)
ims_mask = []
for im in ims:
    im_mask = sf.get_background_mask(
        im,
        bg_smoothing=0,
        n_clust_bg=4,
        top_n_clust_bg=1,
        bg_threshold=200
        )
    ims_mask.append(im_mask)
    # Masked image with a half-intensity overlay of the inverted mask
    fig, ax, _ = ip.plot_image(im*im_mask, clims=clims, im_inches=imin)
    ax.imshow(np.dstack([0.5*(~im_mask)]*4))


# # Get mask
# ims_mask = [sf.get_background_mask(
#     im,
#     bg_smoothing=0,
#     n_clust_bg=4,
#     top_n_clust_bg=1,
#     bg_threshold=200
#     ) for im in ims]

# # check mask
# plot_ims = [im_pre*im_mask for im_pre, im_mask in zip(ims_pre, ims_mask)]
# fig, axs, cbars = ip.subplot_square_images(plot_ims, subpl, clims=clims, im_inches=imin)
# for ax, im_mask in zip(fig.axes, ims_mask):
#     ax.imshow(np.dstack([0.5*(~im_mask)]*4))

# # fig, ax, cbar = ip.plot_image((im_pre*im_mask), cmap='inferno', im_inches=imin, clims=(0,750))
# # ax.imshow(np.dstack([0.5*(~im_mask)]*4))
# # segment
# # Check segmentation
# # Save segmentation 

In [None]:
def segment(image, mask):
    """Segment `image` inside `mask` by a seeded watershed.

    Seeds are the labelled output of sf.peak_local_max on `image`; the
    watershed runs on the negated, masked image and is restricted to `mask`.
    The watershed result is passed through sf.label before being returned.
    """
    peak_mask = sf.peak_local_max(image, min_distance=1, indices=False)
    seeds = sf.label(peak_mask)
    seg = sf.watershed(-image*mask, seeds, mask=mask, watershed_line=True)
    return sf.label(seg)

# segment
# One segmentation per tile from its pre-processed image and mask
ims_seg = [segment(im_pre, im_mask) for im_pre, im_mask in zip(ims_pre, ims_mask)]

In [None]:
# Check segmentation
# Convert every label image with ip.seg2rgb and plot it
segs_rgb = []
for im_seg in ims_seg:
    seg_rgb = ip.seg2rgb(im_seg)
    segs_rgb.append(seg_rgb)
    ip.plot_image(seg_rgb, im_inches=imin)
# segs_zoom_rgb = [ip.seg2rgb(im_seg) for im_seg in ims_seg]
# fig, axs, cbars = ip.subplot_square_images(segs_zoom_rgb, subpl, im_inches=imin)


# ip.plot_image(seg_zoom_rgb, im_inches=imin)
# Save segmentation 

In [None]:
# Get spot properties
# One table per tile from sf.measure_regionprops, using the raw image as intensity source
props = [sf.measure_regionprops(im_seg, raw=im) for im_seg, im in zip(ims_seg, ims)]
props[0].columns

In [None]:
# Max intensity of every spot, pooled over tiles and sorted ascending; x is the rank
y = [prop.max_intensity.values for prop in props]
y = np.sort(np.hstack(y))
x = np.arange(y.shape[0])

In [None]:
# Max-intensity threshold, drawn as a horizontal line on the rank plot
int_thresh = 100

fig, ax = ip.general_plot(dims=(10,5))
ax.scatter(x,y, s=1)
ax.set_title('Spot Max Intensity (a.u.)')
ax.plot([0,x.shape[0]], [int_thresh]*2, 'k')

In [None]:
# Area of every spot, pooled over tiles and sorted ascending; x is the rank
y = [prop.area.values for prop in props]
y = np.sort(np.hstack(y))
x = np.arange(y.shape[0])

In [None]:
# Area threshold, drawn as a horizontal line on the rank plot
area_thresh = 200
fig, ax = ip.general_plot(dims=(10,5))
ax.scatter(x,y, s=1)
ax.set_title('Spot Area (pixels)')
ax.plot([0,x.shape[0]], [area_thresh]*2, 'k')

In [None]:
# Per-tile boolean filters: area below area_thresh, max intensity above int_thresh
bools_area = [prop.area.values < area_thresh for prop in props]
bools_int = [prop.max_intensity.values > int_thresh for prop in props]
for ba, bi in zip(bools_area, bools_int):
    # Total number of spots, then the number passing both filters
    print(len(ba))
    print(sum(ba*bi))
    print('--')

In [None]:
# Get pixels for random simulation
# Channel index 0 of every shifted tile
ims_cell = [raws_mge_shift[m_h][:,:,0] for m_h in range(nMs)]

fig, axs, cbars = ip.subplot_square_images(ims_cell, subpl, clims=[clims]*nMs, im_inches=imin)
# ip.plot_image(im, cmap='inferno', im_inches=imin, clims=(0,2000))


In [None]:
# Pixel mask per tile: True where channel 0 exceeds the hard-coded threshold of 235
masks_cell = [im_cell > 235 for im_cell in ims_cell]

plot_ims = [im_cell*mask_cell for im_cell, mask_cell in zip(ims_cell, masks_cell)]

# Masked images with a half-intensity overlay of the inverted mask
fig, axs, cbars = ip.subplot_square_images(plot_ims, subpl, clims=[clims]*nMs, im_inches=imin)
for ax, im_mask in zip(fig.axes, masks_cell):
    ax.imshow(np.dstack([0.5*(~im_mask)]*4))

# fig, ax, cbar = ip.plot_image((im_cell*mask_cell), cmap='inferno', im_inches=imin, clims=(0,2000))
# ax.imshow(np.dstack([0.5*(~mask_cell)]*4))

In [None]:
# For every spot, look up the cell mask at the spot centroid truncated to integer pixel indices
bools_incell = []
for prop, mask_cell in zip(props, masks_cell):
    bool_incell = []
    for c in prop.centroid.values:
        bool_incell.append(mask_cell[int(c[0]),int(c[1])])
    # Total number of spots, then the number whose centroid lies inside the mask
    print(len(bool_incell))
    print(sum(bool_incell))
    print('--')
    bools_incell.append(np.array(bool_incell))

## Measure spatial association with spectra clusters

Get resized hipr properties

In [None]:
def add_edge(hipr_sum_res, edge):
    """Return a float copy of a 2-D array surrounded by `edge` zeros on every side.

    The output has shape (rows + 2*edge, cols + 2*edge); the input occupies
    rows/columns edge .. edge + size and everything else is 0.
    """
    in_shape = hipr_sum_res.shape
    out_shape = tuple(size + 2*edge for size in in_shape)
    padded = np.zeros(out_shape)
    row_end = edge + in_shape[0]
    col_end = edge + in_shape[1]
    padded[edge:row_end, edge:col_end] = hipr_sum_res
    return padded

def resize_hipr(im, hipr_res, mega_res, dims='none', out_fn=False, ul_corner=(0,0)):
    """Rescale `im` by hipr_res / mega_res with nearest-neighbour interpolation.

    The rescaled image is then passed to center_image together with `dims`
    and `ul_corner`. When `dims` is a string (the default 'none'), the shape
    of the rescaled image is used as the target, so no canvas is added.
    `out_fn` is accepted but not used (saving is commented out in the
    original notebook).
    """
    scale = hipr_res / mega_res
    resized = resize(im, None, fx=scale, fy=scale, interpolation=INTER_NEAREST)
    target_dims = resized.shape if isinstance(dims, str) else dims
    return center_image(resized, target_dims, ul_corner)

def center_image(im, dims, ul_corner):
    """Place `im` on a zero-filled float canvas of size `dims`.

    If the leading len(dims) entries of im.shape already equal `dims`, `im`
    is returned unchanged. Otherwise a new array of zeros is created (with
    the third axis of `im` appended to `dims` when `im` is not 2-D) and `im`
    is copied in with its upper-left corner at `ul_corner`.
    """
    shp = im.shape
    already_sized = all([dims[i] == shp[i] for i in range(len(dims))])
    if already_sized:
        return im
    if len(shp) == 2:
        canvas_shape = dims
    else:
        canvas_shape = dims + (shp[2],)
    canvas = np.zeros(canvas_shape)
    # Bottom-right corner = upper-left corner + image height/width
    br_corner = np.array(ul_corner) + np.array(shp[:2])
    canvas[ul_corner[0]:br_corner[0], ul_corner[1]:br_corner[1]] = im
    return canvas

def norm(im, c=('min', 'max')):
    """Clip `im` to the limits in `c` and rescale it linearly to [0, 1].

    Parameters
    ----------
    im : array-like
        Image to normalise.
    c : sequence of length 2
        Lower and upper limit. The strings 'min' / 'max' select the minimum /
        maximum of `im`; any other value is used as the limit itself.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as `im`. When both limits are equal
        (e.g. a constant image) an all-zero array is returned instead of
        dividing by zero.
    """
    # Only compare against the keyword when the entry is a string, so numeric
    # limits never go through a string comparison
    mn = np.min(im) if isinstance(c[0], str) and c[0] == 'min' else c[0]
    mx = np.max(im) if isinstance(c[1], str) and c[1] == 'max' else c[1]
    im = np.clip(im, mn, mx)
    span = mx - mn
    if span == 0:
        return np.zeros(np.shape(im), dtype=float)
    return (im - mn) / span

In [None]:
def _czi_x_scaling(czi):
    """Return the X scaling value stored in the metadata of an open CZI file.

    Every metadata node is scanned; the text of the last node whose tag
    contains both 'Scaling' and 'X' is converted to float. Raises ValueError
    when no such node exists, rather than leaving the result unbound.
    """
    scaling = None
    for node in czi.meta.iter():
        if 'Scaling' in node.tag and 'X' in node.tag:
            scaling = float(node.text)
    if scaling is None:
        raise ValueError('No X scaling entry found in the CZI metadata')
    return scaling


# The first file of the sample's sorted list is used as the HiPR image and the
# fifth as the MGE image (both indices are hard-coded)
hipr_raw_fn = dict_group_czifns_all[sn][0]
czi_hipr = aplc.CziFile(hipr_raw_fn)

mge_raw_fn = dict_group_czifns_all[sn][4]
czi_mge = aplc.CziFile(mge_raw_fn)


res_mge = _czi_x_scaling(czi_mge)
print('MGE m/pix',res_mge)

res_hipr = _czi_x_scaling(czi_hipr)
print('HiPR m/pix',res_hipr)

In [None]:
# Text file (tile placeholder filled with 'edgepixels') holding the edge width read below
edge_out_fn = mge_shifts_dir + '/' + bn.format('edgepixels') + '_mge_shift.txt'
edge_out_fn

In [None]:
# Read the edge width (an integer stored as text) from edge_out_fn
with open(edge_out_fn, 'r') as f:
    edge = int(f.read())
edge

In [None]:
def _tile_index(fn):
    """Return the integer tile index (digits following '_M_') of a filename."""
    return int(re.findall(r'(?<=_M_)\d+', fn)[0])


# Order each file list by the tile index parsed from its own filenames.
# Pairing the stacks with indices taken from a different glob, or sorting the
# indices as strings, can silently mis-order the tiles.
stack_fns = sorted(glob.glob(stack_fn.format('*')), key=_tile_index)
stacks = [np.load(fn) for fn in stack_fns]
# Collapse axis 2 of every stack by maximum and by sum
stacks_max = [np.max(s, axis=2) for s in stacks]
stacks_sum = [np.sum(s, axis=2) for s in stacks]

seg_fns = sorted(glob.glob(seg_fn.format('*')), key=_tile_index)
segs = [np.load(fn) for fn in seg_fns]

# Rescale every HiPR image from the HiPR pixel size to the MGE pixel size
hipr_maxs_res = []
hipr_sums_res = []
hipr_segs_res = []
for stack_max, stack_sum, seg in zip(stacks_max, stacks_sum, segs):
    hipr_maxs_res.append(resize_hipr(stack_max, res_hipr, res_mge))
    hipr_sums_res.append(resize_hipr(stack_sum, res_hipr, res_mge))
    hipr_segs_res.append(resize_hipr(seg, res_hipr, res_mge))

In [None]:
len(hipr_sums_res)

In [None]:
# Pad the resized sum and segmentation images with `edge` zeros on every side,
# then measure region properties on the padded images
hipr_sums_res_edge = []
hipr_segs_res_edge = []
hipr_props_res = []

for m_h in range(nMs):
    hipr_sum_res = hipr_sums_res[m_h]
    hipr_sum_res_edge = add_edge(hipr_sum_res, edge)
    hipr_sums_res_edge.append(hipr_sum_res_edge)

    hipr_seg_res = hipr_segs_res[m_h]
    hipr_seg_res_edge = add_edge(hipr_seg_res, edge)
    # add_edge returns floats; labels are cast back to integers
    hipr_seg_res_edge = hipr_seg_res_edge.astype(int)
    hipr_segs_res_edge.append(hipr_seg_res_edge)

    hipr_prop_res = sf.measure_regionprops(hipr_seg_res_edge, raw=hipr_sum_res_edge)
    hipr_props_res.append(hipr_prop_res)
hipr_props_res[0].columns

In [None]:
# Per tile: plot the colored HiPR segmentation, then draw the MGE spot channel on top
for m_h in range(nMs):
    hipr_seg_res_edge = hipr_segs_res_edge[m_h]
    hipr_seg_res_edge_rgb = ip.seg2rgb(hipr_seg_res_edge)
    fig, ax, _ = ip.plot_image(hipr_seg_res_edge_rgb, im_inches=imin)
    raw_mge_shift_spot = raws_mge_shift[m_h][:,:,1]
    # Clip to [0, 250] and rescale to [0, 1]
    raw_mge_shift_spot_norm = norm(raw_mge_shift_spot, (0,250))
    # Same values in all four channels, so the spot intensity also sets the alpha
    raw_mge_shift_spot_rgb = np.dstack([raw_mge_shift_spot_norm]*4)
    ax.imshow(raw_mge_shift_spot_rgb)

Load spectral clusters

In [None]:
# dict_clusters_unq = {}
# dicts_lab_clust = {}
# for m_h in range(nMs):
#     # Load spectral clusters
#     with open(clust_fn.format(m_h), 'r') as f:
#         dicts_lab_clust[m_h] = yaml.unsafe_load(f)
#     dict_clusters_unq[m_h] = np.unique(list(dicts_lab_clust[m_h].values()))

In [None]:
# classif svc
classif_dir = config['output_dir'] + '/classif_svc'
classif_fn = classif_dir + '/dict_sn_m_label_classif.yaml'

# NOTE(review): yaml.unsafe_load can construct arbitrary Python objects; only
# use it on files written by this pipeline. Switch to yaml.safe_load if the
# file contains plain YAML types only — confirm.
with open(classif_fn, 'r') as f:
    dict_sn_m_lab_cl = yaml.unsafe_load(f)

In [None]:
# Classification of the selected sample; indexed by tile in the following cells
dicts_lab_clust = dict_sn_m_lab_cl[sn]

Create a dictionary with cell coords for each cluster


In [None]:
# For every tile, group the cell centroids by the cluster assigned to each cell label
dicts_clust_coords = []
for m_h in range(nMs):
    # The notebook indexes tiles both as int and as str; accept either key
    # type so this cell and the later ones agree
    tile_key = m_h if m_h in dicts_lab_clust else str(m_h)
    dict_lab_clust = dicts_lab_clust[tile_key]
    hipr_prop_res = hipr_props_res[m_h]
    # Create a dictionary with cell coords for each cluster
    dict_clust_coords = defaultdict(list)
    for l, c in hipr_prop_res[['label','centroid']].values:
        cl = dict_lab_clust[l]
        # Centroids arrive as strings when the table was read back from text.
        # NOTE(review): eval() executes arbitrary text; only safe for tables
        # written by this pipeline.
        c = eval(c) if isinstance(c, str) else c
        dict_clust_coords[cl].append(list(c))
    dicts_clust_coords.append(dict_clust_coords)

In [None]:
spots_coords = []
for m_h in range(nMs):
    prop = props[m_h]
    bool_area = bools_area[m_h]
    bool_incell = bools_incell[m_h]
    # bool_int = bools_int[m_h]
    # Get spot coordinates
    # spot_coords = prop.centroid.values[bool_area]

    
    spot_coords = prop.centroid.values
    # spot_coords = prop.centroid.values[bool_area * bool_incell]
    # spot_coords = prop.centroid.values[bool_area * bool_incell * bool_int]
    print(spot_coords.shape)
    spot_coords = [list(s) for s in spot_coords]
    spots_coords.append(spots_coords)

Get dictionary of cluster nearest neighbor distances

In [None]:
# Get dictionary of cluster nearest neighbor distances
n_neighbors=1
dicts_cl_dists = defaultdict(dict)
for m_h in range(nMs):
    dict_clust_coords = dicts_clust_coords[m_h]
    for cl in dict_clusters_unq[m_h]:
        reseg_coords = dict_clust_coords[cl]
        nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(reseg_coords)
        dists, _ = nbrs.kneighbors(spot_coords)
        dicts_cl_dists[m_h][cl] = dists

Simulate random spots

In [None]:
# Invert the label -> cluster mapping of every tile into cluster -> list of labels
dicts_clust_lab = defaultdict(lambda: defaultdict(list))
for m_h in range(nMs):
    # The notebook indexes tiles both as int and as str; accept either key type
    tile_key = m_h if m_h in dicts_lab_clust else str(m_h)
    dict_lab_clust = dicts_lab_clust[tile_key]
    for lab, cl in dict_lab_clust.items():
        dicts_clust_lab[m_h][cl].append(lab)

In [None]:
# Per tile: map every cell label to its centroid (as a list)
dicts_lab_coord = {}
for m_h in range(nMs):
    hipr_prop_res = hipr_props_res[m_h]
    l = hipr_prop_res.label.values
    c = [list(coord) for coord in hipr_prop_res.centroid.values]
    # c = [list(eval(coord)) for coord in hipr_prop_res.centroid.values]
    dicts_lab_coord[m_h] = dict(zip(l, c))

In [None]:
# simulate random spots
# Number of simulations per tile
n=1000

# dicts_cl_dists_sim[tile][cluster] -> list with one distance array per simulation
dicts_cl_dists_sim = defaultdict(lambda: defaultdict(list))
for m_h in range(nMs):
    mask_cell = masks_cell[m_h]
    dict_clust_lab = dicts_clust_lab[m_h]
    dict_lab_coord = dicts_lab_coord[m_h]

    # Simulate as many spots as were measured in this tile (all centroids in
    # props[m_h]); this must stay in step with the spot set used for the
    # measured distances
    n_spots = len(props[m_h].centroid.values)
    # Candidate positions: every pixel inside the cell mask
    pix_coords = np.argwhere(mask_cell)

    # The cell centroids of a cluster do not change between simulations, so
    # fit one nearest-neighbour index per cluster up front
    dict_cl_nbrs = {}
    for cl in dict_clusters_unq[m_h]:
        tax_centroid = [dict_lab_coord[l] for l in dict_clust_lab[cl]]
        dict_cl_nbrs[cl] = NearestNeighbors(n_neighbors=1).fit(tax_centroid)

    for i in tqdm(range(n)):
        # Randomize spot locations
        i_sim = np.random.randint(0, pix_coords.shape[0], size=n_spots)
        sim_spot_coords = pix_coords[i_sim]
        for cl, nbrs in dict_cl_nbrs.items():
            # Get nearest neighbor cell distance for each simulated spot
            dists, _ = nbrs.kneighbors(sim_spot_coords)
            dicts_cl_dists_sim[m_h][cl].append(dists)
            


Measure associations

In [None]:
# Get fraction of spots associated in measured and simulation
# Association radius in micrometres
r_um = 0.5
# Pixel size in micrometres (res_mge is printed above as m/pix)
res_mge_umpix = res_mge * 10**6

# Per tile: counts (vals) and fractions (frac), one row per cluster
dict_meas_vals = defaultdict(list)
dict_sim_vals = defaultdict(list)
dict_meas_frac = {}
dict_sim_frac = {}

for m_h in range(nMs):
    # Spot counts are read from the distance arrays of this tile so the
    # fractions never depend on a variable left over from another cell
    n_spots_sim = 0
    n_spots_meas = 0
    for cl in dict_clusters_unq[m_h]:
        # Simulated: spots within r_um of a cell of this cluster, per simulation
        sim_dists_um = np.array(dicts_cl_dists_sim[m_h][cl]) * res_mge_umpix
        dict_sim_vals[m_h].append(np.sum(sim_dists_um < r_um, axis=1))
        n_spots_sim = sim_dists_um.shape[1]
        # Measured: same count for the real spots
        dists_um = dicts_cl_dists[m_h][cl] * res_mge_umpix
        dict_meas_vals[m_h].append(np.sum(dists_um < r_um))
        n_spots_meas = dists_um.shape[0]

    # Drop the trailing length-1 neighbour axis: rows = clusters, columns = simulations
    dict_sim_vals[m_h] = np.array(dict_sim_vals[m_h])[:,:,0]
    dict_sim_frac[m_h] = dict_sim_vals[m_h] / n_spots_sim
    dict_meas_vals[m_h] = np.array(dict_meas_vals[m_h])
    dict_meas_frac[m_h] = dict_meas_vals[m_h] / n_spots_meas



Get color dict

In [None]:
# # Count all barcodes
# dict_cl_counts = defaultdict(int)
# for sn, dmlc in dict_sn_m_lab_cl.items():
#     for m, dlc in dmlc.items():
#         clusts, counts = np.unique(list(dlc.values()), return_counts=True)
#         for cl, cnt in zip(clusts,counts):
#             dict_cl_counts[cl] += cnt
# dict_cl_counts

In [None]:
# # Sort barcodes
# bcs = list(dict_cl_counts.keys())
# counts = list(dict_cl_counts.values())
# barcodes_countsort = [bc for _, bc in sorted(zip(counts, bcs))]

In [None]:
# Rearrange colors
col_list = list(plt.get_cmap('tab20').colors)
col_1 = [col_list[i] for i in np.arange(0,20,2)]
del col_1[7]
col_2 = [col_list[i] for i in np.arange(1,20,2)]
del col_2[7]
col_list_re = col_1 + col_2 + [(1,1,0), (0,1,0)]
barlist = plt.bar(np.arange(20), np.ones(20))
for b,c in zip(barlist, col_list_re):
    b.set_color(c)



In [None]:
# # Get scinames dict
# probe_design_fn = config['probe_design_dir'] + '/' + config['probe_design_filename']
# probe_design = pd.read_csv(probe_design_fn)
# barcodes_pd = probe_design['code'].unique()
# sci_names_pd = [probe_design.loc[probe_design['code'] == bc,'sci_name'].unique()[0] 
#             for bc in barcodes_pd]
# barcodes_pdstr = [str(bc).zfill(5) for bc in barcodes_pd]
# dict_bc_sciname = dict(zip(barcodes_pdstr, sci_names_pd))


In [None]:
# # Make dict and plot 
# dict_bc_col = dict(zip(barcodes_countsort, col_list_re))
# sciname_countsort = [dict_bc_sciname[bc] for bc in barcodes_countsort]
# ip.taxon_legend(sciname_countsort, col_list_re)

Plot z-score values

In [None]:
# Get z-scores
dict_sim_z = {}
dict_meas_z = {}
for m_h in range(nMs):
    mu = np.mean(dict_sim_vals[m_h], axis=1)
    sig = np.std(dict_sim_vals[m_h], axis=1)
    dict_sim_z[m_h] = (dict_sim_vals[m_h] - mu[:,None]) / sig[:,None]
    dict_meas_z[m_h] = (dict_meas_vals[m_h] - mu) / sig

In [None]:
# Plot z score number of spots associated with group
dims=[2.5,1]
xlab_rotation=45
pval_rotation=60
marker='.'
marker_size=10
text_dist=0.1
ft=7
ylimadj = 0.1
true_frac_llim = 0
line_col = 'k'
box_line_col = (0.5,0.5,0.5)
box_col = 'w'
yticklength=2

for m_h in range(nMs):
    fig, ax = ip.general_plot(dims=dims, ft=ft, col=line_col)
    # Plot simulation
    boxplot = ax.boxplot(
            dict_sim_z[m_h].T, patch_artist=True, showfliers=False,
            boxprops=dict(facecolor=box_col, color=box_line_col),
            capprops=dict(color=box_line_col),
            whiskerprops=dict(color=box_line_col),
            medianprops=dict(color=box_line_col),
        )
    # for m in boxplot['medians']:
    #     m.set_color(line_col)
    # for b in boxplot['boxes']:
    #     b.set_edgecolor(line_col)
    #     b.set_facecolor(box_col)

    # Plot measured value
    ys = []
    xlab = []
    x = 1
    # clusters_unq = np.unique(list(dicts_clust_coords[m_h].keys()))
    for i, cl in enumerate(dict_clusters_unq[m_h]):
    # for i, bc_tax in zip(ind_order, barcodes_int_order):
        # sci_name = dict_bc_sciname[cl]
        xlab.append(sci_name)
        # try:
        #     color = col_dict[sci_name]
        # except:
        #     continue
        # color = dict_bc_col[cl]
        color = col_list_re[i]
        true_frac = dict_meas_z[m_h][i]
        # true_frac = true_count / n_cells
        _ = ax.plot(x, true_frac, marker=marker, ms=marker_size, color=color)
        # Plot p value
        sim_vals_i = dict_sim_vals[m_h][i,:]
        # sim_vals = sim_arr[:,i,h] / n_cells
        sim_mean = np.mean(dict_sim_vals[m_h])
        if true_frac > sim_mean:
            # number of simulations with value greater than observed
            r_ = sum(sim_vals_i > true_frac)
        else:
            # number of simulations with value less than observed
            r_ = sum(sim_vals_i < true_frac)
        # P value
        p_ = r_ / n
        # Get text location
        q1,q3 = np.quantile(dict_sim_vals[m_h], [0.25,0.75])
        q4 = q3 + 1.5 * (q3 - q1)
        # y_m = np.max(sim_vals)
        # y = y_m if y_m > true_frac else true_frac
        y = q4 if q4 > true_frac else true_frac
        y += text_dist
        ys.append(y)
        # if true_frac < true_frac_llim:
        #     t = ''
        # elif (p_ > 0.05):
        #     t = ''
        # elif (p_ > 0.001) and (p_ <= 0.05):
        #     t = str("p=" + str(p_))
        # else:
        #     t = str("p<0.001")
        # _ = ax.text(x, y, t, fontsize=ft, ha='left',va='bottom', rotation=pval_rotation, rotation_mode='anchor',
        #         color=line_col)
        x+=1
    # ax.set_xticklabels([], rotation=xlab_rotation, ha='right', va='top', rotation_mode='anchor')
    # ax.set_xticklabels(xlab, rotation=xlab_rotation, ha='right', va='top', rotation_mode='anchor')
    # ax.tick_params(axis='x',direction='out')
    ax.set_xticks([])
    ax.tick_params(axis='y', length=yticklength)
    # ax.set_yticks(ticks=[-10,0,10,20], labels=[])
    ax.spines['top'].set_color('none')
    ax.spines['bottom'].set_color('none')
    ax.spines['right'].set_color('none')

    # ylims = ax.get_ylim()
    # ax.set_ylim(ylims[0], np.max(ys) + ylimadj)
    # mge_assoc_dir = output_dir + '/mge_association'
    # if not os.path.exists(mge_assoc_dir): 
    #     os.makedirs(mge_assoc_dir)
    #     print('Made dir:',mge_assoc_dir)

    # out_bn = mge_assoc_dir + '/' + bn.format(m_h) + '_seg_nn_zscore_association_0_5um'
    # ip.save_png_pdf(out_bn)



## Overlay spots on classification

In [None]:
# Per tile: 4-channel overlay of the MGE spot channel (index 1)
raws_mge_shift_spot_rgb = {}
for m_h in range(nMs):
    raw_mge_shift_spot = raws_mge_shift[m_h][:,:,1]

    # Clip to [50, 250] and rescale to [0, 1]
    raw_mge_shift_spot_norm = norm(raw_mge_shift_spot, (50,250))
    # Intensity goes into channels 0, 2 and 3; channel 1 stays zero, so spots
    # show as magenta with intensity-dependent alpha
    raw_mge_shift_spot_rgb = np.dstack([
        raw_mge_shift_spot_norm,
        np.zeros_like(raw_mge_shift_spot_norm),
        raw_mge_shift_spot_norm,
        raw_mge_shift_spot_norm
    ])
    raws_mge_shift_spot_rgb[m_h] = raw_mge_shift_spot_rgb

In [None]:
# Per tile: RGBA image in which every cell is filled with the color of its cluster
hipr_segs_res_clustrgb = {}
for m_h in range(nMs):
    hipr_seg_res_edge = hipr_segs_res_edge[m_h]

    # Lookup table: row = cell label, value = RGBA color. Row 0 (label 0,
    # treated here as background) and labels without a cluster stay
    # transparent black.
    n_labels = int(hipr_seg_res_edge.max()) + 1
    label_rgba = np.zeros((n_labels, 4))
    for i, cl in enumerate(dict_clusters_unq[m_h]):
        # Cluster color by position in the palette (the same rule as in the
        # z-score plot), with full opacity appended
        color = np.array(tuple(col_list_re[i % len(col_list_re)]) + (1,))
        for lab in dicts_clust_lab[m_h][cl]:
            # Ignore labels that do not occur in this segmentation
            if 0 < lab < n_labels:
                label_rgba[int(lab)] = color

    # Indexing the table with the label image colors every pixel in one step
    hipr_segs_res_clustrgb[m_h] = label_rgba[hipr_seg_res_edge]

In [None]:
# Make sure the overlay output directory exists (saving itself is commented out below)
mge_overlay_dir = output_dir + '/mge_overlay'
if not os.path.exists(mge_overlay_dir): 
    os.makedirs(mge_overlay_dir)
    print('Made dir:',mge_overlay_dir)

# Per tile: cell channel in gray, then the cluster colors, then the MGE spots on top
for m_h in range(nMs):
    fig, ax, cbar = ip.plot_image(ims_cell[m_h], cmap='gray', im_inches=imin, clims=(0,2000))
    ax.imshow(hipr_segs_res_clustrgb[m_h])
    ax.imshow(raws_mge_shift_spot_rgb[m_h])

# plt.figure(fig)
# out_bn = mge_overlay_dir + '/' + bn.format(m_h) + '_clusts_3_4_overlay'
# ip.save_png_pdf(out_bn)

Show only Prevotella

In [None]:
# Invert the barcode -> scientific name mapping.
# NOTE(review): dict_bc_sciname is only assigned in commented-out code in this
# notebook; it has to be defined before this cell can run.
dict_sciname_bc = dict(zip(list(dict_bc_sciname.values()),list(dict_bc_sciname.keys())))
dict_sciname_bc

In [None]:
# Names whose clusters should be shown; each is looked up in dict_sciname_bc
target_genera = ['Prevotella']

# Rebuild the per-tile RGBA cluster image, this time coloring only the target clusters.
# NOTE(review): dict_bc_col is only assigned in commented-out code in this
# notebook; it has to be defined before this cell can run.
hipr_segs_res_clustrgb = {}
for m_h in range(nMs):
    hipr_seg_res_edge = hipr_segs_res_edge[m_h]
    hipr_prop_res = hipr_props_res[m_h]
    
    target_clusters = [dict_sciname_bc[sc] for sc in target_genera]
    # clusters_unq = np.unique(list(dicts_clust_coords[m_h].keys()))

    hipr_seg_res_clustrgb = np.zeros(hipr_seg_res_edge.shape + (4,))
    # Not used anywhere in this cell
    clusters_toplot = [3,5]
    for cl in target_clusters:
        labels_sub = dicts_clust_lab[m_h][cl]
        for i, row in hipr_prop_res.iterrows():
            l = row.label
            if l in labels_sub:
                b = row.bbox
                # Bounding boxes arrive as strings when the table was read back from text
                b = eval(b) if isinstance(b, str) else b
                rgb_sub = hipr_seg_res_clustrgb[b[0]:b[2],b[1]:b[3]]
                # Pixels of this cell inside its bounding box
                seg_sub = hipr_seg_res_edge[b[0]:b[2],b[1]:b[3]] == l
                # im_sub = im_norm[b[0]:b[2],b[1]:b[3]]
                # color = np.array([0,0,1,0.5])
                color = np.array(dict_bc_col[cl] + (1,))
                # rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*3 + np.ones_like(seg_sub))
                # Keep what is already drawn outside the cell, fill the cell with its color
                rgb_noncell =  rgb_sub * np.dstack([~seg_sub]*4)
                # rgb_cell = (im_sub * seg_sub)[:,:,None] * color[None,:]
                rgb_cell = seg_sub[:,:,None] * color[None,:]
                hipr_seg_res_clustrgb[b[0]:b[2],b[1]:b[3],:] = rgb_noncell + rgb_cell
    hipr_segs_res_clustrgb[m_h] = hipr_seg_res_clustrgb

In [None]:
# Make sure the overlay output directory exists (saving itself is commented out below)
mge_overlay_dir = output_dir + '/mge_overlay'
if not os.path.exists(mge_overlay_dir): 
    os.makedirs(mge_overlay_dir)
    print('Made dir:',mge_overlay_dir)

# Per tile: cell channel in gray, then the target-cluster colors, then the MGE spots on top
for m_h in range(nMs):
    fig, ax, cbar = ip.plot_image(ims_cell[m_h], cmap='gray', im_inches=imin, clims=(0,2000))
    ax.imshow(hipr_segs_res_clustrgb[m_h])
    ax.imshow(raws_mge_shift_spot_rgb[m_h])

# plt.figure(fig)
# out_bn = mge_overlay_dir + '/' + bn.format(m_h) + '_clusts_3_4_overlay'
# ip.save_png_pdf(out_bn)