In [None]:
import glob
import sys
import os
import gc
import pandas as pd
import numpy as np
from tqdm import tqdm
from time import time
import matplotlib.pyplot as plt
import yaml
import re
import aicspylibczi as aplc
from sklearn.neighbors import NearestNeighbors
import javabridge
import bioformats
import xml.etree.ElementTree as ET
from collections import defaultdict
import aicspylibczi as aplc
from scipy import stats
from cv2 import resize, INTER_CUBIC, INTER_NEAREST

In [None]:
# Run from the project's code directory so the relative data/output paths used
# below resolve. `cluster` is prepended as a path prefix (empty when running
# locally; presumably a mount prefix on other hosts).
cluster = ""
workdir = "/workdir/bmg224/harvard_dental/manuscript/code"
project_dir = cluster + workdir
os.chdir(project_dir)
os.getcwd()

In [None]:
%load_ext autoreload
%autoreload 2
functions_path = '/workdir/bmg224/manuscripts/mgefish/code/functions'

sys.path.append(cluster + functions_path)

import fn_general_use as fgu
import image_plots as ip
import segmentation_func as sf
import fn_hiprfish_classifier as fhc
import fn_spectral_images as fsi

In [None]:
# Imaging-session directories that are scanned for .czi files.
# Paths relative to the working directory set above:
_local_dirs = [
    "../../imaging/2022_12_16_harvardwelch/data",
    "../../imaging/2023_02_08_hsdm/data",
    "../../imaging/2023_02_18_hsdm/data",
]
# Session stored outside the project tree (absolute path):
_external_dirs = [
    "/fs/cbsuvlaminck2/workdir/Data/bmg224/2023/brc_imaging/2023_10_16_hsdm",
]
dirs = _local_dirs + _external_dirs

In [None]:
# Group the .czi files of each directory by field of view (FOV).
# - The group key is the file's basename with everything after its "fov_NN"
#   token removed (any laser/channel suffix plus the ".czi" extension).
#   All per-laser files of one FOV therefore share a key.
# - A file with no suffix after "fov_NN" ends up in the same group.
# - Files are bucketed by their own key, not by substring matching against
#   full paths, so each file lands in exactly one group.
# A raw string is used because "\d" and "\-" are regex escapes, not Python
# string escapes.
fov_suffix_re = re.compile(r"(?<=fov_\d{2})[a-zA-Z0-9\-_]*\.czi")

dict_dir_group_czi = {}
for d in dirs:
    dict_group_czifns_all = defaultdict(list)
    for fn in glob.glob(d + "/*.czi"):
        group = fov_suffix_re.sub("", os.path.basename(fn))
        dict_group_czifns_all[group].append(fn)
    # Groups in sorted order, each with its file list sorted.
    dict_dir_group_czi[d] = {
        g: sorted(dict_group_czifns_all[g]) for g in sorted(dict_group_czifns_all)
    }
    print(dict_dir_group_czi[d])

In [None]:
# Number of field-of-view groups across all input directories
# (counted before stitched files are filtered out).
n_images = sum(len(group_map) for group_map in dict_dir_group_czi.values())
n_images

In [None]:
def _is_stitched(path):
    """Return True for stitched composites, i.e. paths containing "stitch" or "Stitch"."""
    return "stitch" in path or "Stitch" in path


# Keep only the unstitched .czi files of every group.
# A group whose files are all stitched is kept with an empty list.
dict_dir_group_czi_filt = defaultdict(dict)
for d, dict_group_czi in dict_dir_group_czi.items():
    for gr, czi_fns in dict_group_czi.items():
        dict_dir_group_czi_filt[d][gr] = [fn for fn in czi_fns if not _is_stitched(fn)]

In [None]:
def replace_outlier_shifts(sh_i):
    """Replace implausible per-tile shift vectors with the median of the plausible ones.

    A row is treated as implausible when either of these holds:
    - any component falls outside Tukey's fences, i.e. below
      Q1 - 1.5*IQR or above Q3 + 1.5*IQR (computed per column over all rows);
    - every component is exactly 0 (presumably a failed shift estimate).

    Implausible rows are overwritten with the integer-truncated column-wise
    median of the remaining rows.

    Parameters
    ----------
    sh_i : numpy.ndarray
        2-D array with one shift vector per row (rows = tiles). Modified in
        place; if it is a view of a larger array, the writes go through.

    Returns
    -------
    numpy.ndarray
        ``sh_i`` itself. It is left unchanged when it has no rows, when no row
        is implausible, or when every row is implausible. In the last case no
        reference remains, and the median of zero rows would be NaN.
    """
    shifts = np.asarray(sh_i)
    if len(shifts) == 0:
        return sh_i
    q1, q3 = np.quantile(shifts, [0.25, 0.75], axis=0)
    iqr = q3 - q1
    upper = q3 + 1.5 * iqr
    lower = q1 - 1.5 * iqr
    is_outlier = np.any(shifts > upper, axis=1) | np.any(shifts < lower, axis=1)
    is_zero = np.all(shifts == 0, axis=1)
    to_replace = is_outlier | is_zero
    if not to_replace.any() or to_replace.all():
        return sh_i
    replacement = np.median(shifts[~to_replace], axis=0).astype(int)
    sh_i[to_replace] = replacement
    return sh_i

In [36]:
# Date label used in output paths for each input directory.
# Keys must match the entries of `dirs` exactly, because the segmentation loop
# looks the date up with `dict_dir_date[d]`.
dict_dir_date = {
    "../../imaging/2022_12_16_harvardwelch/data": "2022_12_16",
    "../../imaging/2023_02_08_hsdm/data": "2023_02_08",
    "../../imaging/2023_02_18_hsdm/data": "2023_02_18",
    "/fs/cbsuvlaminck2/workdir/Data/bmg224/2023/brc_imaging/2023_10_16_hsdm": "2023_10_16",
}
# Fail fast here rather than partway through the long segmentation loop.
missing_dates = [d for d in dirs if d not in dict_dir_date]
assert not missing_dates, f"No date label in dict_dir_date for: {missing_dates}"

In [39]:
# Registration / segmentation parameters (passed to the fsi and sf helpers).
max_shift = 500          # fsi._shift_images(max_shift=...)
gauss = 3                # sf.pre_process(gauss=...)
diff_gauss = (2, 3)      # sf.pre_process(diff_gauss=...)
bg_smoothing = 5         # sf.get_background_mask(bg_smoothing=...)
n_clust_bg = 4           # sf.get_background_mask(n_clust_bg=...)
top_n_clust_bg = 3       # sf.get_background_mask(top_n_clust_bg=...)
# Saved-figure size (inches, passed as im_inches) and resolution.
imin = 2
dpi = 500

# Output path templates; fill with d=<date label>, sn=<group name>, m=<tile index>.
out_dir = "../outputs/segmentation_2024_03_07"
out_fmt_seg, out_fmt_plot = (
    out_dir + "/{d}/{sn}/segs/{sn}_M_{m}",
    out_dir + "/{d}/{sn}/plots/{sn}_M_{m}",
)
seg_fmt, props_fmt = (out_fmt_seg + suffix for suffix in ("_seg.npy", "_props.csv"))
plot_fmt, rgb_fmt = (out_fmt_plot + suffix for suffix in ("_seg_plot.png", "_rgb_plot.png"))

def _load_matched_raws(czi_fns, m, mtype, resolutions):
    """Load tile/scene ``m`` from every file in ``czi_fns`` and bring them to a
    common pixel resolution and size (some acquisitions differ in resolution)."""
    raws = [fsi.load_raw(fn, m, mtype) for fn in czi_fns]
    raws = [fsi.reshape_aics_image(r) for r in raws]
    return fsi.match_resolutions_and_size(raws, resolutions)


def _save_image_plot(image, out_fn):
    """Render ``image`` with ip.plot_image (``imin`` inches, ``dpi``), save it to
    ``out_fn`` and close the figure."""
    fig, _, _ = ip.plot_image(image, im_inches=imin)
    plt.figure(fig)  # make fig current before ip.save_fig
    ip.save_fig(out_fn, dpi=dpi, bbox_inches=0)
    plt.close(fig)  # close this exact figure so a long run doesn't accumulate figures
    print("Wrote:", out_fn)


for d, dict_group_czi in dict_dir_group_czi_filt.items():
    if d not in dict_dir_date:
        raise KeyError(f"No date label in dict_dir_date for directory {d!r}")
    date = dict_dir_date[d]
    for sn, czi_fns in dict_group_czi.items():
        print(sn)
        if not czi_fns:
            # Every file of this group was filtered out as stitched.
            print("Skipped (no unstitched .czi files):", sn)
            continue
        # Number of tiles or scenes, and how they are indexed
        M, mtype = fsi.get_ntiles(czi_fns[0])
        resolutions = [fsi.get_resolution(fn) for fn in czi_fns]
        lasers = [fsi.get_laser(fn) for fn in czi_fns]

        # Pass 1: estimate shift vectors between the laser images of every tile.
        shifts = []
        for m in range(M):
            raws = _load_matched_raws(czi_fns, m, mtype, resolutions)
            image_max_norm = [fsi.max_norm(r) for r in raws]
            shifts.append(fsi._get_shift_vectors(image_max_norm))
        # Some of the shifts are clearly wrong; replace outliers per laser.
        # Laser index 0 is skipped (presumably the registration reference).
        sh_arr = np.array(shifts)
        for k in range(1, len(lasers)):
            sh_i = sh_arr[:, k, :]
            print("Shifts", lasers[k], ":")
            print(sh_i)
            sh_arr[:, k, :] = replace_outlier_shifts(sh_i)

        # Pass 2: register, segment, measure and save every tile.
        for m in range(M):
            raws = _load_matched_raws(czi_fns, m, mtype, resolutions)
            raws_shift = fsi._shift_images(raws, sh_arr[m, :, :], max_shift=max_shift)
            stack = np.dstack(raws_shift)
            stack_sum = np.sum(stack, axis=2)
            pre = sf.pre_process(stack_sum, gauss=gauss, diff_gauss=diff_gauss)
            mask = sf.get_background_mask(
                stack_sum,
                bg_smoothing=bg_smoothing,
                n_clust_bg=n_clust_bg,
                top_n_clust_bg=top_n_clust_bg,
            )
            seg = sf.segment(pre, mask)
            props = sf.measure_regionprops(seg, stack_sum)
            spec = fsi.get_cell_average_spectra(seg, stack)
            props = props.merge(pd.DataFrame(spec), left_index=True, right_index=True)

            seg_fn = seg_fmt.format(d=date, sn=sn, m=m)
            props_fn = props_fmt.format(d=date, sn=sn, m=m)
            plot_fn = plot_fmt.format(d=date, sn=sn, m=m)
            rgb_fn = rgb_fmt.format(d=date, sn=sn, m=m)

            for f in (seg_fn, props_fn, plot_fn, rgb_fn):
                odir = os.path.dirname(f)
                if not os.path.exists(odir):
                    os.makedirs(odir, exist_ok=True)
                    print("Made dir:", odir)

            np.save(seg_fn, seg)
            print("Wrote:", seg_fn)
            props.to_csv(props_fn, index=False)
            print("Wrote:", props_fn)

            _save_image_plot(ip.seg2rgb(seg), plot_fn)

            # Overlay of the first three shifted channels, each normalized
            # with fsi.max_norm(type='sum').
            rgb = np.dstack([fsi.max_norm(r, type='sum') for r in raws_shift])[:, :, :3]
            _save_image_plot(rgb, rgb_fn)

2022_12_16_harvardwelch_patient_10_tooth_8_aspect_MB_depth_supra_fov_01
0 0 0
1 4 -7
2 4 -7
Wrote: ../outputs/segmentation_2024_03_07/2022_12_16/2022_12_16_harvardwelch_patient_10_tooth_8_aspect_MB_depth_supra_fov_01/segs/2022_12_16_harvardwelch_patient_10_tooth_8_aspect_MB_depth_supra_fov_01_M_1_seg.npy
Wrote: ../outputs/segmentation_2024_03_07/2022_12_16/2022_12_16_harvardwelch_patient_10_tooth_8_aspect_MB_depth_supra_fov_01/segs/2022_12_16_harvardwelch_patient_10_tooth_8_aspect_MB_depth_supra_fov_01_M_1_props.csv
Wrote: ../outputs/segmentation_2024_03_07/2022_12_16/2022_12_16_harvardwelch_patient_10_tooth_8_aspect_MB_depth_supra_fov_01/plots/2022_12_16_harvardwelch_patient_10_tooth_8_aspect_MB_depth_supra_fov_01_M_1_seg_plot.png
Wrote: ../outputs/segmentation_2024_03_07/2022_12_16/2022_12_16_harvardwelch_patient_10_tooth_8_aspect_MB_depth_supra_fov_01/plots/2022_12_16_harvardwelch_patient_10_tooth_8_aspect_MB_depth_supra_fov_01_M_1_rgb_plot.png
