In [169]:
import glob
import sys
import os
import gc
import pandas as pd
import numpy as np
from tqdm import tqdm
from time import time
import matplotlib.pyplot as plt
import yaml
import re
import aicspylibczi as aplc
from sklearn.neighbors import NearestNeighbors
import javabridge
import bioformats
import xml.etree.ElementTree as ET
from collections import defaultdict
import aicspylibczi as aplc
from scipy import stats
from cv2 import resize, INTER_CUBIC, INTER_NEAREST
from scipy.spatial.distance import squareform, pdist, cdist
from sklearn.cluster import AgglomerativeClustering
import umap

ModuleNotFoundError: No module named 'm2stitch'

In [None]:
# Make the manuscript code folder the working directory so that the relative
# paths used further down resolve; `cluster` is an optional mount prefix.
cluster = ""
workdir = "/workdir/bmg224/harvard_dental/manuscript/code"
os.chdir("".join([cluster, workdir]))
os.getcwd()

In [None]:
# Enable IPython's autoreload extension in mode 2 for this kernel.
%load_ext autoreload
%autoreload 2
# Folder that is appended to sys.path so the project-local modules below import.
functions_path = '/workdir/bmg224/manuscripts/mgefish/code/functions'

sys.path.append(cluster + functions_path)

# Project-local helper modules (not installed packages).
import fn_general_use as fgu
import image_plots as ip
import segmentation_func as sf
import fn_hiprfish_classifier as fhc
import fn_spectral_images as fsi

In [None]:
# Folders that are searched for .czi acquisitions, one per imaging session.
# The first three are relative to the working directory, the last is absolute.
relative_session_dirs = [
    "../../imaging/2022_12_16_harvardwelch/data",
    "../../imaging/2023_02_08_hsdm/data",
    "../../imaging/2023_02_18_hsdm/data",
]
absolute_session_dirs = [
    "/fs/cbsuvlaminck2/workdir/Data/bmg224/2023/brc_imaging/2023_10_16_hsdm",
]
dirs = relative_session_dirs + absolute_session_dirs

In [None]:
# Group the .czi files of every session folder by field of view.
#
# The group (sample) name is the file's base name with everything after
# "fov_NN" removed, so all acquisitions of one field of view (e.g. one file per
# laser) end up in the same group.
# Result: dict_dir_group_czi[directory][group_name] -> sorted list of .czi paths.

# Raw string: "\d" in a normal string literal is an invalid escape sequence.
# The dot before "czi" is escaped so it only matches a literal ".".
group_suffix_re = re.compile(r"(?<=fov_\d{2})[a-zA-Z0-9\-_]+\.czi")

dict_dir_group_czi = {}
for d in dirs:
    fns = glob.glob(d + "/*.czi")
    fns_base = [os.path.split(f)[1] for f in fns]
    group_names = np.sort(np.unique([group_suffix_re.sub("", s) for s in fns_base]))
    dict_group_czifns_all = defaultdict(list)
    for g in group_names:
        for s in fns:
            # A file belongs to a group when the group name occurs in its path
            if g in s:
                dict_group_czifns_all[g].append(s)
    dict_dir_group_czi[d] = {g: sorted(s) for g, s in dict_group_czifns_all.items()}
    print(dict_dir_group_czi[d])

In [None]:
# Total number of fields of view (groups) across all session folders.
n_images = sum(len(dict_group_czi) for dict_group_czi in dict_dir_group_czi.values())
n_images

In [None]:
# Same nested mapping as dict_dir_group_czi, but without any file whose path
# contains "stitch" or "Stitch".
dict_dir_group_czi_filt = defaultdict(dict)
for d, dict_group_czi in dict_dir_group_czi.items():
    for gr, czi_fns in dict_group_czi.items():
        dict_dir_group_czi_filt[d][gr] = [
            fn for fn in czi_fns if not (("stitch" in fn) or ("Stitch" in fn))
        ]

In [54]:
def replace_outlier_shifts(sh_i):
    """
    Replace implausible shift vectors with the median of the plausible ones.

    A row of `sh_i` is treated as implausible when any of its components lies
    outside the Tukey fences (quartiles -/+ 1.5 * IQR, computed per column over
    all rows) or when the whole row is zero.

    Parameters
    ----------
    sh_i : array-like, shape (n, k)
        One shift vector per row. When an ndarray (or a view of one) is passed
        it is modified in place.

    Returns
    -------
    numpy.ndarray
        `sh_i` with the flagged rows overwritten by the integer-truncated
        median of the remaining rows. The input is returned unchanged when it
        has two rows or fewer, when nothing is flagged, or when every row is
        flagged (there is then no reliable value to substitute).
    """
    # No copy for ndarray input, so in-place semantics are kept for callers
    sh_i = np.asarray(sh_i)
    if len(sh_i) <= 2:
        # Too few vectors to estimate a spread
        return sh_i

    q_lo, q_hi = np.quantile(sh_i, [0.25, 0.75], axis=0)
    iqr = q_hi - q_lo
    ol_plus = q_hi + 1.5 * iqr
    ol_minus = q_lo - 1.5 * iqr

    shifts_red = []
    inds_replace = []
    for k, s in enumerate(sh_i):
        bool_std = np.any(s > ol_plus) or np.any(s < ol_minus)
        bool_z = not np.any(s)
        if bool_std or bool_z:
            inds_replace.append(k)
        else:
            shifts_red.append(s)

    # Without any plausible row the median would be NaN, and casting NaN to int
    # writes a meaningless huge value into every row; leave the input alone.
    if inds_replace and shifts_red:
        sh_med = np.median(shifts_red, axis=0).astype(int)
        for k in inds_replace:
            sh_i[k, :] = sh_med
    return sh_i

In [44]:
# Map every session folder listed in `dirs` to its acquisition date, which is
# used as the top-level folder name of the outputs.
dict_dir_date = {
    "../../imaging/2022_12_16_harvardwelch/data": "2022_12_16",
    "../../imaging/2023_02_08_hsdm/data": "2023_02_08",
    "../../imaging/2023_02_18_hsdm/data": "2023_02_18",
    # Must match the entry in `dirs` exactly, otherwise dict_dir_date[d] raises KeyError
    "/fs/cbsuvlaminck2/workdir/Data/bmg224/2023/brc_imaging/2023_10_16_hsdm": "2023_10_16",
    # Former spelling of the same folder, kept so older references still resolve
    "/workdir/Data/bmg224/2023/brc_imaging/2023_10_16_hsdm": "2023_10_16",
}

In [63]:
# Segment every tile of every field of view and write, per tile, the label
# image (.npy), the region properties with per-cell spectra (.csv), a plot of
# the segmentation and an RGB plot of the registered image.

# Parameters forwarded to the project helpers below:
#   max_shift                              -> fsi._shift_images
#   gauss, diff_gauss                      -> sf.pre_process
#   bg_smoothing, n_clust_bg, top_n_clust_bg -> sf.get_background_mask
#   imin (im_inches), dpi                  -> ip.plot_image / ip.save_fig
max_shift = 500
gauss = 3
diff_gauss = (2, 3)
bg_smoothing = 5
n_clust_bg = 10
top_n_clust_bg = 9
imin=2
dpi=500

# Output filename templates; {d} is the date, {sn} the sample name, {m} the tile index
out_dir = "../outputs/segmentation_2024_03_07"
out_fmt_seg = out_dir + "/{d}/{sn}/segs/{sn}_M_{m}"
out_fmt_plot = out_dir + "/{d}/{sn}/plots/{sn}_M_{m}"
seg_fmt = out_fmt_seg + "_seg.npy"
props_fmt = out_fmt_seg + "_props.csv"
plot_fmt = out_fmt_plot + "_seg_plot.png"
rgb_fmt = out_fmt_plot + "_rgb_plot.png"
# Running total of segmented objects over all tiles processed in this run
ncells = 0

for i, (d, dict_group_czi) in enumerate(dict_dir_group_czi_filt.items()):
    date = dict_dir_date[d]
    for j, (sn, czi_fns) in enumerate(dict_group_czi.items()):
        # if (i == 0) and (j == 0):
        # if j == 0:
        # if sn == '2022_12_16_harvardwelch_patient_1_tooth_31_aspect_ML_depth_supra_fov_02':
        print(sn)
        # Get number of tiles or scenes
        M, mtype = fsi.get_ntiles(czi_fns[0])
        # Check if the segs have already been produced
        # (the RGB plot is the last file written per tile, so it marks a finished tile)
        segs_done = []
        for m in range(M):
            rgb_fn = rgb_fmt.format(d=date, sn=sn, m=m)
            sd = True if os.path.exists(rgb_fn) else False
            segs_done.append(sd)
        if not all(segs_done):
            # Get the resolutions
            resolutions = [fsi.get_resolution(fn) for fn in czi_fns]
            # Get the lasers
            lasers = [fsi.get_laser(fn) for fn in czi_fns]
            # Get shifts
            shifts = []
            # Drop every file (and its resolution) acquired with the 405 laser
            czi_fns = [fn for fn, l in zip(czi_fns, lasers) if l != 405]
            resolutions = [r for r, l in zip(resolutions, lasers) if l != 405]
            lasers = [l for l in lasers if l != 405]
            for m in range(M):
                raws = [fsi.load_raw(fn, m, mtype) for fn in czi_fns]
                raws = [fsi.reshape_aics_image(r) for r in raws]
                # some images have different pixel resolution, correct that
                raws = fsi.match_resolutions_and_size(raws, resolutions)
                image_max_norm = [fsi.max_norm(r) for r in raws]
                sh = fsi._get_shift_vectors(image_max_norm)
                # print(sh)
                shifts.append(sh)
            # Some of the shifts are clearly wrong, fix those
            # sh_arr is indexed [tile, laser, shift component]
            sh_arr = np.array(shifts)
            # Laser index 0 is skipped — presumably the reference the others are shifted to; TODO confirm
            for k in range(1, len(lasers)):
                sh_i = sh_arr[:, k, :]
                # print("Shifts", lasers[k], ":")
                # print(sh_i)
                # address large deviations from typical
                sh_arr[:, k, :] = replace_outlier_shifts(sh_i)
            # Now shift the raw images
            for m in range(M):
                rgb_fn = rgb_fmt.format(d=date, sn=sn, m=m)
                # Only tiles without an RGB plot are (re)processed
                if not os.path.exists(rgb_fn):
                    raws = [fsi.load_raw(fn, m, mtype) for fn in czi_fns]
                    raws = [fsi.reshape_aics_image(r) for r in raws]
                    # Skip tiles for which any laser image is empty
                    if all([r.shape[0] > 0 for r in raws]):
                        # some images have different pixel resolution, correct that
                        raws = fsi.match_resolutions_and_size(raws, resolutions)
                        raws_shift = fsi._shift_images(
                            raws, sh_arr[m, :, :], max_shift=max_shift
                        )
                        # Stack the shifted laser images along axis 2 and sum over that axis
                        stack = np.dstack(raws_shift)
                        stack_sum = np.sum(stack, axis=2)
                        pre = sf.pre_process(stack_sum, gauss=gauss, diff_gauss=diff_gauss)
                        mask = sf.get_background_mask(
                            stack_sum,
                            bg_smoothing=bg_smoothing,
                            n_clust_bg=n_clust_bg,
                            top_n_clust_bg=top_n_clust_bg,
                        )
                        seg = sf.segment(pre, mask)
                        props = sf.measure_regionprops(seg, stack_sum)
                        spec = fsi.get_cell_average_spectra(seg, stack)
                        # Append the spectra as extra columns, joined on the row index
                        props = props.merge(
                            pd.DataFrame(spec), left_index=True, right_index=True
                        )
                        ncells += props.shape[0]

                        seg_fn = seg_fmt.format(d=date, sn=sn, m=m)
                        props_fn = props_fmt.format(d=date, sn=sn, m=m)
                        plot_fn = plot_fmt.format(d=date, sn=sn, m=m)
                        rgb_fn = rgb_fmt.format(d=date, sn=sn, m=m)

                        # Create the output folders on first use
                        for f in [seg_fn, props_fn, plot_fn, rgb_fn]:
                            odir = os.path.split(f)[0]
                            if not os.path.exists(odir):
                                os.makedirs(odir)
                                print("Made dir:", odir)

                        np.save(seg_fn, seg)
                        print("Wrote:", seg_fn)
                        props.to_csv(props_fn, index=False)
                        print("Wrote:", props_fn)

                        # ip.plot_image(stack_sum, cmap="inferno", im_inches=10)
                        fig, ax, _ = ip.plot_image(ip.seg2rgb(seg), im_inches=imin)
                        plt.figure(fig)
                        ip.save_fig(plot_fn, dpi=dpi, bbox_inches=0)
                        plt.close()
                        print("Wrote:", plot_fn)

                        # RGB preview built from the first three shifted laser images
                        rgb = np.dstack([fsi.max_norm(r, type='sum') for r in raws_shift])
                        rgb = rgb[:,:,:3]
                        fig, ax, _ = ip.plot_image(rgb, im_inches=imin)
                        plt.figure(fig)
                        ip.save_fig(rgb_fn, dpi=dpi, bbox_inches=0)
                        plt.close()
                        print("Wrote:", rgb_fn)

                        # ip.plot_image(
                        #     stack_sum[900:1100, 900:1100], cmap="inferno", im_inches=10
                        # )
                        # ip.plot_image(ip.seg2rgb(seg)[900:1100, 900:1100, :], im_inches=10)

2022_12_16_harvardwelch_patient_10_tooth_8_aspect_MB_depth_supra_fov_01
2022_12_16_harvardwelch_patient_10_tooth_8_aspect_MB_depth_supra_fov_02
2022_12_16_harvardwelch_patient_10_tooth_8_aspect_MB_depth_supra_fov_03
2022_12_16_harvardwelch_patient_14_tooth_14_aspect_MB_depth_sub_fov_01
2022_12_16_harvardwelch_patient_14_tooth_14_aspect_MB_depth_sub_fov_02
2022_12_16_harvardwelch_patient_14_tooth_14_aspect_MB_depth_sub_fov_03
2022_12_16_harvardwelch_patient_18_tooth_2_aspect_MB_depth__fov_01
2022_12_16_harvardwelch_patient_19_tooth_15_aspect_MF_depth_sub_fov_01
2022_12_16_harvardwelch_patient_19_tooth_30_aspect_MB_depth_sub_fov_01
2022_12_16_harvardwelch_patient_1_tooth_31_aspect_ML_depth_supra_fov_01
2022_12_16_harvardwelch_patient_1_tooth_31_aspect_ML_depth_supra_fov_02
2022_12_16_harvardwelch_patient_1_tooth_31_aspect_ML_depth_supra_fov_03
2022_12_16_harvardwelch_patient_9_tooth_15_aspect_MB_depth_supra_fov_01
2022_12_16_harvardwelch_patient_9_tooth_15_aspect_MB_depth_supra_fov_02
20

: 

In [93]:
# Number of segmented objects accumulated by the segmentation cell during its last run
ncells

984005

## Cluster spectra by image

In [92]:
# Cluster the per-cell spectra of one image with agglomerative clustering on a
# precomputed pairwise spectral distance, then save the cluster labels, a UMAP
# of the same distance matrix and one spectra plot per cluster.

n_clust = 25
spec_dims = (10,5)
umap_dims=(5,5)
s_u=1
alpha_u=0.25
cmap_u='Spectral'
colors = plt.get_cmap(cmap_u)(np.linspace(0,1,n_clust))
dict_clust_col = dict(zip(np.arange(n_clust), colors))

out_fmt_clust = out_dir + "/{d}/{sn}/cluster/{sn}"
umap_fmt = out_fmt_clust + "_umap.png"
spec_fmt = out_fmt_clust + "_spec_clust_{cl}.png"
clust_fmt = out_fmt_clust + "_clusters.npy"
mlab_fmt = out_fmt_clust + "_dict_index_mlab.yaml"

# Names of the spectral columns in the props tables ("0" ... "56")
cols_spec = np.arange(57).astype(str).tolist()

for i, (d, dict_group_czi) in enumerate(dict_dir_group_czi_filt.items()):
    date = dict_dir_date[d]
    for j, (sn, _) in enumerate(dict_group_czi.items()):
        # NOTE(review): debug guard — only the first image of the first folder is clustered
        if (i == 0) and (j == 0):
            # Setup dirs
            umap_fn = umap_fmt.format(d=date, sn=sn)
            clust_fn = clust_fmt.format(d=date, sn=sn)
            mlab_fn = mlab_fmt.format(d=date, sn=sn)
            spec_fmt2 = spec_fmt.format(d=date, sn=sn, cl='{cl}')
            for fn in [umap_fn, clust_fn, mlab_fn, spec_fmt2]:
                # Separate name so the outer loop variable `d` is not overwritten
                out_subdir = os.path.split(fn)[0]
                if not os.path.exists(out_subdir):
                    os.makedirs(out_subdir)
                    print('Made dir:', out_subdir)

            # Get props filenames
            props_glob = props_fmt.format(d=date, sn=sn, m='*')
            props_fns = sorted(glob.glob(props_glob))
            # Load props
            spec_arr = []
            ind_plus = 0
            # Plain dict so the YAML file contains no Python-specific object tag
            dict_ind_mlab = {}
            # NOTE(review): only the first two props files are used here, whereas the
            # classification cell loads every tile; the saved labels then do not line up
            # with its spectra array. Confirm whether the [:2] subset is still intended.
            for fn in props_fns[:2]:
                # Tile index parsed from the filename (raw string: "\d" is a regex escape)
                m = int(re.search(r'(?<=_M_)\d+', fn)[0])
                p = pd.read_csv(fn)
                # construct spectra array
                spec_arr.append(p[cols_spec].values)
                # make a dictionary to map spec array index to tile and cell label
                for ind, mlab in enumerate(p.label):
                    dict_ind_mlab[ind+ind_plus] = {'m':m, 'mlab':mlab}
                ind_plus += p.shape[0]
            with open(mlab_fn, 'w') as f:
                yaml.dump(dict_ind_mlab, f, default_flow_style=False)

            # Get spectra distances
            spec_arr = np.vstack(spec_arr)
            dist_mat_cond = pdist(spec_arr, fhc.channel_cosine_intensity_allonev2)
            # Make squareform
            dist_mat = squareform(dist_mat_cond)
            # Agglomerative cluster
            # NOTE(review): newer scikit-learn releases call this keyword `metric`
            # instead of `affinity`; check against the installed version.
            agg = AgglomerativeClustering(n_clusters=n_clust, affinity='precomputed', linkage='complete')
            agg.fit(dist_mat)
            clust_agg = agg.labels_
            np.save(clust_fn, clust_agg)

            # Plot umap
            fig, ax = ip.general_plot(dims=umap_dims, col='w')
            fit = umap.UMAP(metric='precomputed', n_neighbors=100, min_dist=0.1).fit(dist_mat)
            u = fit.embedding_
            ax.scatter(u[:,0], u[:,1], c=clust_agg, s=s_u, alpha=alpha_u, cmap=cmap_u)
            ax.set_aspect('equal')
            ip.save_fig(umap_fn)
            plt.close()

            # Plot spectra
            for c in np.unique(clust_agg):
                bool_c = clust_agg == c
                spec_sub = spec_arr[bool_c,:]

                fig, ax = ip.general_plot(dims=spec_dims, col='w')
                color = dict_clust_col[c]
                fsi.plot_cell_spectra(ax, spec_sub, {'lw':2,'alpha':0.2,'color':color})

                ylim = ax.get_ylim()

                # Vertical guide lines at fixed channel indices
                xs = [3, 7, 10, 24, 27, 29, 31, 33, 43, 47]
                for x in xs:
                    ax.plot([x,x], ylim, color=(0.5,0.5,0.5), lw=0.5)

                ip.save_fig(spec_fmt2.format(cl=c))
                plt.close()



Made dir: ../outputs/segmentation_2024_03_07/2022_12_16/2022_12_16_harvardwelch_patient_10_tooth_8_aspect_MB_depth_supra_fov_01/cluster


  warn("using precomputed metric; inverse_transform will be unavailable")


Cluster: 0
Cluster: 1
Cluster: 2
Cluster: 3
Cluster: 4
Cluster: 5
Cluster: 6
Cluster: 7
Cluster: 8
Cluster: 9
Cluster: 10
Cluster: 11
Cluster: 12
Cluster: 13
Cluster: 14
Cluster: 15
Cluster: 16
Cluster: 17
Cluster: 18
Cluster: 19
Cluster: 20
Cluster: 21
Cluster: 22
Cluster: 23
Cluster: 24


## Classify clusters

In [95]:
bmg224_dir = '../../..'

# Probe-design CSVs; each acquisition date maps to [csv path, barcode type].
probe_design_5bit_fn = (
    bmg224_dir
    + '/manuscripts/mgefish/data/HiPRFISH_probe_design'
    + '/welch2016_5b_no_633_channel.csv'
)
probe_design_7bit_fn = (
    bmg224_dir
    + '/harvard_dental/pick_distant_barcodes/2023_08_01_order'
    + '/welch2016_7b_distant_v2.csv'
)

dict_date_pdfn = {}
# Every date gets its own list object, as in a literal dictionary
for session_date in ('2022_12_16', '2023_02_08', '2023_02_18'):
    dict_date_pdfn[session_date] = [probe_design_5bit_fn, '5bit_no633']
for session_date in ('2023_10_16', '2023_10_18'):
    dict_date_pdfn[session_date] = [probe_design_7bit_fn, '7bit_no405']

In [96]:
# Folder that is passed to fhc.get_reference_spectra when classifying clusters
ref_dir = bmg224_dir + "/manuscripts/mgefish/data/unused/fig_5/HiPRFISH_reference_spectra"

In [142]:
# Assign a barcode to every spectral cluster of an image: the median spectrum
# of the cluster is compared with the median reference spectrum of every
# barcode and the closest reference wins. Writes one comparison plot per
# cluster and a YAML file mapping cluster label -> barcode.

out_fmt_classif = out_dir + "/{d}/{sn}/classif"
spec_classif_fmt = out_fmt_classif + "/spectra_plots/{sn}_spec_clust_{cl}.png"
classif_fmt = out_fmt_classif + "/{sn}_dict_cluster_barcode.yaml"

# Names of the spectral columns in the props tables ("0" ... "56")
cols_spec = np.arange(57).astype(str).tolist()

for i, (d, dict_group_czi) in enumerate(dict_dir_group_czi_filt.items()):
    date = dict_dir_date[d]

    # Load reference spectra
    pdfn, bc_type = dict_date_pdfn[date]
    probe_design = pd.read_csv(pdfn)
    barcodes = probe_design['code']
    sci_names = probe_design['sci_name']
    dict_bc_sciname = dict(zip(barcodes, sci_names))
    barcodes = np.unique(barcodes)
    sci_names = [dict_bc_sciname[bc] for bc in barcodes]
    ref_spec = fhc.get_reference_spectra(barcodes, bc_type, ref_dir)

    # Get classifier: one median spectrum per barcode, indexed like `barcodes`
    ref_spec_med = [np.median(r, axis=0) for r in ref_spec]
    ref_spec_arr = np.vstack(ref_spec_med)
    ref_spec_dist_mat_cond = pdist(
        ref_spec_arr, fhc.channel_cosine_intensity_allonev2
    )
    ref_spec_dist_mat = squareform(ref_spec_dist_mat_cond)
    nbrs = NearestNeighbors(n_neighbors=1, metric='precomputed')
    nn_ref = nbrs.fit(ref_spec_dist_mat)

    # iterate through images
    for j, (sn, _) in enumerate(dict_group_czi.items()):
        # NOTE(review): debug guard — only one image is classified
        if (i == 0) and (j == 3):
            print(sn)
            # Setup dirs
            classif_fn = classif_fmt.format(d=date, sn=sn)
            spec_classif_fmt2 = spec_classif_fmt.format(d=date, sn=sn, cl='{cl}')
            for fn in [classif_fn, spec_classif_fmt2]:
                # Separate name so the outer loop variable `d` is not overwritten
                out_subdir = os.path.split(fn)[0]
                if not os.path.exists(out_subdir):
                    os.makedirs(out_subdir)
                    print('Made dir:', out_subdir)
            clust_fn = clust_fmt.format(d=date, sn=sn)

            # Get props filenames
            props_glob = props_fmt.format(d=date, sn=sn, m='*')
            props_fns = sorted(glob.glob(props_glob))
            # Load props and construct the spectra array
            spec_arr = []
            for fn in props_fns:
                p = pd.read_csv(fn)
                spec_arr.append(p[cols_spec].values)
            spec_arr = np.vstack(spec_arr)

            # Load clustering
            clust_agg = np.load(clust_fn)
            cl_unq = np.unique(clust_agg)
            # The labels must describe exactly the rows loaded above; fail with a
            # clear message instead of an opaque boolean-index error further down.
            if clust_agg.shape[0] != spec_arr.shape[0]:
                raise ValueError(
                    'Cluster labels in {} cover {} cells but {} spectra were loaded; '
                    'clustering and classification must use the same tiles.'.format(
                        clust_fn, clust_agg.shape[0], spec_arr.shape[0]
                    )
                )

            # Iterate through clusters and classify
            dict_cl_bc = {}
            for cl in cl_unq:
                spec = spec_arr[clust_agg == cl]
                # Median spectrum of the cluster, kept 2-D for cdist
                spec_mean = np.median(spec, axis=0)[None, :]
                # Get distances
                cl_dist_matrix = cdist(
                    spec_mean,
                    ref_spec_arr,
                    metric=fhc.channel_cosine_intensity_allonev2
                )
                # Classify each spectrum
                nn_dists, nn_inds = nn_ref.kneighbors(cl_dist_matrix)
                rs_ind = nn_inds[0][0]
                bc = barcodes[rs_ind]
                sciname = sci_names[rs_ind]
                # Store builtin types so the YAML holds plain scalars rather than
                # serialized numpy objects
                bc_plain = bc.item() if isinstance(bc, np.generic) else bc
                dict_cl_bc[int(cl)] = bc_plain

                # Plot classif vs mean cluster spec (both normalized to unit sum)
                fig, ax = ip.general_plot(dims=spec_dims, col='w')
                rs = ref_spec_med[rs_ind][None,:]
                rs = rs / np.sum(rs)
                sm = (spec_mean / np.sum(spec_mean))
                fsi.plot_cell_spectra(ax, rs, {"lw": 1, "alpha": 1, "color": "w"})
                fsi.plot_cell_spectra(ax, sm, {"lw": 1, "alpha": 1, "color": "r"})
                ylim = ax.get_ylim()
                # Vertical guide lines at fixed channel indices
                xs = [3, 7, 10, 24, 27, 29, 31, 33, 43, 47]
                for x in xs:
                    ax.plot([x, x], ylim, color=(0.5, 0.5, 0.5), lw=0.5)
                ax.set_title(sciname + ': ' + str(bc), color='w')
                spec_classif_fn = spec_classif_fmt2.format(cl=cl)
                ip.save_fig(spec_classif_fn)
                plt.close()

            # Write classif dict
            with open(classif_fn, 'w') as f:
                yaml.dump(dict_cl_bc, f, default_flow_style=False)



2022_12_16_harvardwelch_patient_14_tooth_14_aspect_MB_depth_sub_fov_01
Made dir: ../outputs/segmentation_2024_03_07/2022_12_16/2022_12_16_harvardwelch_patient_14_tooth_14_aspect_MB_depth_sub_fov_01/classif
Made dir: ../outputs/segmentation_2024_03_07/2022_12_16/2022_12_16_harvardwelch_patient_14_tooth_14_aspect_MB_depth_sub_fov_01/classif/spectra_plots


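## Tile stitching (failed)

The `m2stitch` stitching attempt in this section does not run because the module is not available (see the `NameError` below). The cell before it still writes the registered per-tile stacks and the tile grid size to disk.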

In [172]:
# Register the laser images of every tile of one image and save the registered
# stack per tile (.npy) together with the tile grid size (.csv).

out_fmt_register = out_dir + "/{date}/{sn}/registration"
reg_fmt = out_fmt_register + '/{sn}_registered_{m}.npy'
tile_info_fmt = out_fmt_register + '/{sn}_tile_info.csv'

for i, (d, dict_group_czi) in enumerate(dict_dir_group_czi_filt.items()):
    date = dict_dir_date[d]

    for j, (sn, czi_fns) in enumerate(dict_group_czi.items()):
        # NOTE(review): debug guard — only one image is processed
        if (i == 0) and (j == 3):
            print(sn)
            M, mtype = fsi.get_ntiles(czi_fns[0])
            # Get the resolutions
            resolutions = [fsi.get_resolution(fn) for fn in czi_fns]
            # Get the lasers
            lasers = [fsi.get_laser(fn) for fn in czi_fns]
            # Get shifts
            shifts = []
            # Drop every file (and its resolution) acquired with the 405 laser
            czi_fns = [fn for fn, l in zip(czi_fns, lasers) if l != 405]
            resolutions = [r for r, l in zip(resolutions, lasers) if l != 405]
            lasers = [l for l in lasers if l != 405]
            for m in range(M):
                raws = [fsi.load_raw(fn, m, mtype) for fn in czi_fns]
                raws = [fsi.reshape_aics_image(r) for r in raws]
                # some images have different pixel resolution, correct that
                raws = fsi.match_resolutions_and_size(raws, resolutions)
                image_max_norm = [fsi.max_norm(r) for r in raws]
                sh = fsi._get_shift_vectors(image_max_norm)
                shifts.append(sh)
            # Some of the shifts are clearly wrong, fix those
            sh_arr = np.array(shifts)
            for k in range(1, len(lasers)):
                sh_i = sh_arr[:, k, :]
                # address large deviations from typical
                sh_arr[:, k, :] = replace_outlier_shifts(sh_i)

            # The registered stacks and the tile info share one output folder; create
            # it once, under a name that does not overwrite the loop variable `d`.
            reg_dir = out_fmt_register.format(date=date, sn=sn)
            if not os.path.exists(reg_dir):
                os.makedirs(reg_dir)
                print('Made dir:', reg_dir)

            # Now shift the raw images
            # NOTE(review): `stack_sums` is never filled here although the m2stitch cell
            # reads it as an array of tiles; kept as-is pending a decision on that cell.
            stack_sums = []
            for m in range(M):
                raws = [fsi.load_raw(fn, m, mtype) for fn in czi_fns]
                raws = [fsi.reshape_aics_image(r) for r in raws]
                # some images have different pixel resolution, correct that
                raws = fsi.match_resolutions_and_size(raws, resolutions)
                raws_shift = fsi._shift_images(
                    raws, sh_arr[m, :, :], max_shift=max_shift
                )
                stack = np.dstack(raws_shift)
                reg_fn = reg_fmt.format(date=date, sn=sn, m=m)
                np.save(reg_fn, stack)

            # Save tile info
            ncols = int(fsi.get_metadata_value(czi_fns[0], 'TilesX')[0])
            nrows = int(fsi.get_metadata_value(czi_fns[0], 'TilesY')[0])
            tile_info_fn = tile_info_fmt.format(date=date, sn=sn)
            pd.DataFrame({'rows':[nrows],'cols':[ncols]}).to_csv(tile_info_fn, index=False)



            



2022_12_16_harvardwelch_patient_14_tooth_14_aspect_MB_depth_sub_fov_01


In [167]:
# Attempt to stitch the tiles with m2stitch and paste them into one image.
# NOTE(review): this cell cannot run as written — `m2stitch` is not imported
# anywhere in the notebook (its last run ended in a NameError), and the
# preceding cell leaves `stack_sums` as an empty list although the code below
# reads `.shape` from it. The loop variable `i` also overwrites the name the
# other cells use for the folder index.
ncols = int(fsi.get_metadata_value(czi_fns[0], 'TilesX')[0])
nrows = int(fsi.get_metadata_value(czi_fns[0], 'TilesY')[0])
# Column / row index of every tile, enumerated row by row
cols = np.tile(np.arange(ncols), nrows)
rows = np.repeat(np.arange(nrows), ncols)
result_df, _ = m2stitch.stitch_images(stack_sums, rows, cols, row_col_transpose=False)

print(result_df["y_pos"])
# the absolute y (second last dim.) positions of the tiles
print(result_df["x_pos"])
# the absolute x (last dim.) positions of the tiles

# stitching example
# Shift the positions so that the smallest one is at 0
result_df["y_pos2"] = result_df["y_pos"] - result_df["y_pos"].min()
result_df["x_pos2"] = result_df["x_pos"] - result_df["x_pos"].min()

# assumes stack_sums is an array of shape (tile, y, x) — TODO confirm
size_y = stack_sums.shape[1]
size_x = stack_sums.shape[2]

stitched_image_size = (
    result_df["y_pos2"].max() + size_y,
    result_df["x_pos2"].max() + size_x,
)
stitched_image = np.zeros_like(stack_sums, shape=stitched_image_size)
# Paste every tile at its position; later tiles overwrite earlier ones where they overlap
for i, row in result_df.iterrows():
    stitched_image[
        row["y_pos2"] : row["y_pos2"] + size_y,
        row["x_pos2"] : row["x_pos2"] + size_x,
    ] = stack_sums[i]


NameError: name 'm2stitch' is not defined

In [168]:
# Show the tiling metadata stored in the first CZI file of the current image
for metadata_key in ('TileAcquisitionOverlap', 'TilesX', 'TilesY'):
    print(fsi.get_metadata_value(czi_fns[0], metadata_key))

['0.050000000000000003']
['4']
['4']


## Adjacency matrix

Plot colors

In [230]:
# Taxa in plotting order; each one is paired with the tab20 colour at the same position.
sciname_list = [
    'Corynebacterium', 'Actinomyces', 'Rothia',
    'Capnocytophaga', 'Prevotella', 'Porphyromonas',
    'Streptococcus', 'Gemella', 'Veillonella',
    'Selenomonas', 'Lautropia', 'Neisseriaceae',
    'Pasteurellaceae', 'Campylobacter', 'Fusobacterium',
    'Leptotrichia', 'Treponema', 'TM7',
]

colors = plt.get_cmap('tab20').colors

dict_sciname_color = {name: col for name, col in zip(sciname_list, colors)}
# Alternative names that share the colour of an existing entry
for alias, listed_name in (('Neisseria', 'Neisseriaceae'), ('Saccharibacteria', 'TM7')):
    dict_sciname_color[alias] = dict_sciname_color[listed_name]


Cell coordinates and identity

In [241]:
# Colour every segmented cell by its assigned taxon, save one plot per tile and
# one mosaic of all tiles, and write a table with the mosaic coordinates and
# taxon of every cell that is kept after removing the tile overlaps.

plot_dir = out_dir + "/{d}/{sn}/plots"
classif_plot_fmt = out_fmt_plot + '_classif.png'
classif_plot_all_fmt = plot_dir + '/{sn}_classif_alltiles.png'
centroid_sciname_fmt = out_fmt_classif + '/{sn}_centroid_sciname.csv'

dpi=500
im_inches=2
radius_um = 5


def _sorted_by_tile(fns, pattern):
    """Return `fns` ordered by the integer tile index that `pattern` finds in each name."""
    return [x for _, x in sorted((int(re.search(pattern, fn)[0]), fn) for fn in fns)]


def _ensure_parent_dir(fn):
    """Create the folder that will hold `fn` if it does not exist yet."""
    out_subdir = os.path.split(fn)[0]
    if not os.path.exists(out_subdir):
        os.makedirs(out_subdir)
        print('Made dir:', out_subdir)


for i, (d, dict_group_czi) in enumerate(dict_dir_group_czi_filt.items()):
    date = dict_dir_date[d]
    # Get bc sciname dict
    pdfn, bc_type = dict_date_pdfn[date]
    probe_design = pd.read_csv(pdfn)
    barcodes = probe_design['code']
    sci_names = probe_design['sci_name']
    dict_bc_sciname = dict(zip(barcodes, sci_names))
    for j, (sn, czi_fns) in enumerate(dict_group_czi.items()):
        # NOTE(review): debug guard — only one image is processed
        if (i == 0) and (j == 3):
            print(sn)
            # filenames
            mlab_fn = mlab_fmt.format(d=date, sn=sn)
            clust_fn = clust_fmt.format(d=date, sn=sn)
            classif_fn = classif_fmt.format(d=date, sn=sn)

            # props, ordered by tile index. The index is parsed from the files globbed
            # here, not from a list left over from another cell.
            prop_glob = props_fmt.format(d=date, sn=sn, m='*')
            prop_fns = _sorted_by_tile(glob.glob(prop_glob), r'(?<=_M_)\d+')

            # registered images
            reg_glob = reg_fmt.format(date=date, sn=sn, m='*')
            reg_fns = _sorted_by_tile(glob.glob(reg_glob), r'(?<=_registered_)\d+')

            # segs
            seg_glob = seg_fmt.format(d=date, sn=sn, m='*')
            seg_fns = _sorted_by_tile(glob.glob(seg_glob), r'(?<=_M_)\d+')

            # Load mlab dict
            # NOTE(review): unsafe_load can execute arbitrary objects; acceptable only
            # because these files are produced by this notebook.
            with open(mlab_fn, 'r') as f:
                mlab_loaded = yaml.unsafe_load(f)
            # The clustering cell writes index -> {'m': tile, 'mlab': label}; the lookup
            # below needs tile -> label -> index. Files already in the nested layout
            # are passed through unchanged.
            dict_m_lab_ind = defaultdict(dict)
            for key, val in mlab_loaded.items():
                if isinstance(val, dict) and set(val) == {'m', 'mlab'}:
                    dict_m_lab_ind[val['m']][val['mlab']] = key
                else:
                    dict_m_lab_ind[key] = val
            # Load clusters
            clust_agg = np.load(clust_fn)
            # Load classif
            with open(classif_fn, 'r') as f:
                dict_cl_bc = yaml.unsafe_load(f)

            # Get upper left corners for each tile
            ncols = int(fsi.get_metadata_value(czi_fns[0], 'TilesX')[0])
            nrows = int(fsi.get_metadata_value(czi_fns[0], 'TilesY')[0])
            overl = float(fsi.get_metadata_value(
                czi_fns[0], 'TileAcquisitionOverlap'
            )[0])
            # Column / row index of every tile, enumerated row by row
            cols = np.tile(np.arange(ncols), nrows)
            rows = np.repeat(np.arange(nrows), ncols)
            M, mtype = fsi.get_ntiles(czi_fns[0])
            seg = np.load(seg_fns[0])
            shp = np.array(seg.shape[:2])
            # Step between neighbouring tiles: tile size minus the overlapping fraction
            im_r, im_c = shp - shp * overl
            ul_corners = []
            for m in range(M):
                c, r = cols[m], rows[m]
                ulc = np.array([r * im_r, c * im_c])
                ul_corners.append(ulc)

            # Plot classified tiled image
            res_umpix = fsi.get_resolution(czi_fns[0])
            tile_shp = (
                (nrows - 1)*int(im_r) + shp[0],
                (ncols - 1)*int(im_c) + shp[1]
            )
            classif_tile = np.zeros((tile_shp[0], tile_shp[1],len(colors[0])))
            sum_tile = np.zeros((tile_shp[0], tile_shp[1]))
            # Get cell absolute locations
            dict_ind_centroid_sciname = {}
            ul_corners = np.array(ul_corners)
            r_lims = np.unique(ul_corners[:,0])
            c_lims = np.unique(ul_corners[:,1])
            for m in range(M):
                # Get image corner values
                c, r = cols[m], rows[m]
                ulc = ul_corners[m]
                # Get limits to remove cells from overlap: a cell is kept only if its
                # centroid lies before the start of the next tile in each direction
                r_lim, c_lim = [1e15]*2
                if r < np.max(rows):
                    r_lim = r_lims[r+1]
                if c < np.max(cols):
                    c_lim = c_lims[c+1]

                # Adjust cell locations
                prop = pd.read_csv(prop_fns[m])
                # Centroid
                # NOTE(review): eval() runs whatever text is stored in the CSV; fine for
                # files written by this notebook, not for files from elsewhere.
                centroids = np.array([eval(c) for c in prop['centroid']])
                centroids += ulc
                bool_clim = centroids[:,1] < c_lim
                bool_rlim = centroids[:,0] < r_lim
                prop['centroid_adj'] = centroids.tolist()
                # Bbox
                bboxes = np.array([eval(b) for b in prop['bbox']])
                bboxes += np.tile(ulc, 2).astype(int)
                prop['bbox_adj'] = bboxes.tolist()

                # Filter based on location
                prop_filt = prop[bool_rlim*bool_clim]

                # Plot classif
                seg = np.load(seg_fns[m])
                dict_lab_col = {}
                dict_lab_bbox = {}
                # The row index gets its own name so the folder index `i` tested by the
                # guard above is not overwritten.
                for row_ind, row in prop_filt.iterrows():
                    lab = row['label']
                    bbox = row['bbox']
                    centroid = row['centroid_adj']
                    ind = dict_m_lab_ind[m][lab]
                    cl = clust_agg[ind]
                    bc = dict_cl_bc[cl]
                    sciname = dict_bc_sciname[bc]
                    color = dict_sciname_color[sciname]
                    dict_lab_col[lab] = color
                    dict_ind_centroid_sciname[ind] = [centroid, sciname]
                    dict_lab_bbox[lab] = bbox
                classif_rgb = sf.seg_2_rgb(seg, dict_lab_col, dict_lab_bbox)
                # Save plot
                fig, ax, cbar = ip.plot_image(
                    classif_rgb,
                    im_inches=im_inches,
                    scalebar_resolution=res_umpix
                )
                classif_plot_fn = classif_plot_fmt.format(d=date, sn=sn, m=m)
                _ensure_parent_dir(classif_plot_fn)
                plt.figure(fig)
                ip.save_fig(classif_plot_fn, dpi=dpi, bbox_inches=0)
                plt.close()

                # Write to all tiles image
                cr_shp = classif_rgb.shape[:2]
                ulc = [int(c) for c in ulc]
                print(ulc, cr_shp)
                classif_tile[ulc[0]:ulc[0] + cr_shp[0], ulc[1]:ulc[1] + cr_shp[1], :] = classif_rgb

                # # Plot sum
                # reg = np.load(reg_fns[m])
                # reg_sum = np.sum(reg, axis=2)
                # sum_tile[ulc[0]:ulc[0] + cr_shp[0], ulc[1]:ulc[1] + cr_shp[1]] = reg_sum
            # Plot all tiles together
            # NOTE(review): np.max(rows) equals nrows - 1, so the size factor is 0 for
            # a single row of tiles; `nrows` may have been intended — confirm.
            fig, ax, cbar = ip.plot_image(
                classif_tile,
                im_inches=im_inches*np.max(rows),
                scalebar_resolution=res_umpix
            )
            classif_plot_all_fn = classif_plot_all_fmt.format(d=date, sn=sn)
            _ensure_parent_dir(classif_plot_all_fn)
            plt.figure(fig)
            ip.save_fig(classif_plot_all_fn, dpi=dpi, bbox_inches=0)
            plt.close()

            # Save overall coords
            cent_sci = list(dict_ind_centroid_sciname.values())
            clust_inds = list(dict_ind_centroid_sciname.keys())
            coords = [cs[0] for cs in cent_sci]
            scinames = [cs[1] for cs in cent_sci]
            # radius_pix = radius_um / res_umpix
            # neigh = NearestNeighbors(radius=radius_pix)
            # nbrs = neigh.fit(coords)
            # nn_inds, nn_dists = nbrs.radius_neighbors(coords)
            centroid_sciname_fn = centroid_sciname_fmt.format(d=date, sn=sn)
            _ensure_parent_dir(centroid_sciname_fn)
            pd.DataFrame({
                'clust_ind': clust_inds,
                'coord': coords,
                'sciname': scinames,
            }).to_csv(centroid_sciname_fn)

                    




            
# Construct adjacency matrix hao method
# Construct adjacency matrix voronoi method

2022_12_16_harvardwelch_patient_14_tooth_14_aspect_MB_depth_sub_fov_01
[0, 0] (2000, 2000)
[0, 1900] (2000, 2000)
[0, 3800] (2000, 2000)
[1900, 0] (2000, 2000)
[1900, 1900] (2000, 2000)
[1900, 3800] (2000, 2000)
[3800, 0] (2000, 2000)
[3800, 1900] (2000, 2000)
[3800, 3800] (2000, 2000)


Adjacency matrix

In [None]:
# Load the saved centroid / taxon table of one image.
date = "2022_12_16"
sn = "2022_12_16_harvardwelch_patient_14_tooth_14_aspect_MB_depth_sub_fov_01"

centroid_sciname_fn = centroid_sciname_fmt.format(d=date, sn=sn)
# The table is written by to_csv with its row index as the first column; read that
# column back as the index instead of as a spurious "Unnamed: 0" data column.
centroid_sciname = pd.read_csv(centroid_sciname_fn, index_col=0)

