In [None]:
import glob
import sys
import os
import gc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import yaml
import re
import joblib
from sklearn.cluster import AgglomerativeClustering
from tqdm import tqdm
from scipy.ndimage import gaussian_filter
from sklearn.neighbors import NearestNeighbors
from sklearn.cluster import KMeans
from collections import defaultdict
import numba as nb
import javabridge
import bioformats
import aicspylibczi as aplc
import skimage.restoration as skr


In [None]:
# Force a full garbage-collection pass (frees memory left over from earlier
# runs in this kernel); the number of unreachable objects found is the cell output.
n_collected = gc.collect()
n_collected


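Move to the project working directory (relative paths below, such as the config file and output folders, are resolved from here)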

In [None]:
# Absolute path of the cluster filesystem root. Override it with the
# MGEFISH_CLUSTER environment variable when running from another mount point.
# os.path.join(..., '') guarantees the trailing '/' that later string
# concatenations with `cluster` rely on.
cluster = os.path.join(os.environ.get('MGEFISH_CLUSTER', '/fs/cbsuvlaminck2/'), '')
# os.path.join avoids the doubled '/' that plain concatenation produced.
project_workdir = os.path.join(
    cluster,
    'workdir/bmg224/manuscripts/mgefish/code/harvard_plasmids_imaging/2023_09_12_sapp05_spades08',
)

os.chdir(project_workdir)
os.getcwd()  # Make sure you're in the right directory

Load the YAML config file

In [None]:
# Config file path, relative to the project working directory set above.
config_fn = 'config_matrix_classify.yaml'

# safe_load only builds plain Python objects (dicts, lists, scalars).
with open(config_fn, mode='r') as config_file:
    config = yaml.safe_load(config_file)

Project-specific imports (pipeline function modules added to `sys.path`)

In [None]:
%load_ext autoreload
%autoreload 2

# Directory holding the pipeline's helper modules (fn_general_use, image_plots, ...).
functions_dir = cluster + config['pipeline_path'] + '/' + config['functions_path']
# Append only once so re-running this cell does not keep growing sys.path.
if functions_dir not in sys.path:
    sys.path.append(functions_dir)
import fn_general_use as fgu
import image_plots as ip
import segmentation_func as sf
import fn_hiprfish_classifier as fhc
import fn_spectral_images as fsi



## Preprocessing
### Pick a sample

In [None]:
# One row per sample; the path to this CSV comes from the config file.
table_path = config['input_table_fn']
input_table = pd.read_csv(table_path)
input_table

In [None]:
# Row index of the sample inspected by the single-sample cells below.
sn_i = 6
sample_row = input_table.iloc[sn_i]
# Columns are unpacked positionally into the sample name, mask threshold and
# all-fluor value (roles taken from how mt/af feed the shift_dir template
# later) — assumes exactly three columns in this order; TODO confirm.
sn, mt, af = sample_row.values
sn

## Plot RGB with MGE overlay

In [None]:
# HiPRFISH pixel size (um/pixel) and z-stack size for the selected sample.
# Raw data filename comes from the config's format template.
raw_fmt = config['data_dir'] + '/' + config['raw_fmt']
raw_fn = raw_fmt.format(sample_name=sn, laser='488')
czi = aplc.CziFile(raw_fn)

# Pixel size: the LAST metadata element whose tag contains both 'Scaling'
# and 'X'. It is scaled by 1e6 to um/pixel, i.e. presumably stored in
# metres — TODO confirm against the CZI metadata.
resolution = None  # reset so a stale value from another cell can never be reused
for n in czi.meta.iter():
    if 'Scaling' in n.tag and 'X' in n.tag:
        resolution = float(n.text)
if resolution is None:
    raise ValueError('No X scaling entry found in CZI metadata of ' + raw_fn)
hipr_res_um_pix = resolution * 10**6

# Number of z slices: end of the Z range reported for the first scene.
dims_shape = czi.get_dims_shape()[0]
z_size = dims_shape['Z'][1]


In [None]:
# Airyscan (MGE) pixel size (um/pixel) and z-stack size for the selected sample.
# The file name is built from the sample name directly; there is no format
# template here, so no .format() call (which would choke on '{' in a name),
# and the HiPRFISH `raw_fmt` template is left untouched.
raw_mge_fn = config['data_dir'] + '/' + sn + '_mode_airy_Airyscan Processing.czi'
czi_mge = aplc.CziFile(raw_mge_fn)

# Same rule as for the HiPRFISH file: last element whose tag contains 'Scaling'
# and 'X', scaled by 1e6 (presumably metres -> um; TODO confirm).
resolution = None  # reset so the HiPRFISH value can't be silently reused
for n in czi_mge.meta.iter():
    if 'Scaling' in n.tag and 'X' in n.tag:
        resolution = float(n.text)
if resolution is None:
    raise ValueError('No X scaling entry found in CZI metadata of ' + raw_mge_fn)
mge_res_um_pix = resolution * 10**6

# Number of z slices: end of the Z range reported for the first scene.
dims_shape = czi_mge.get_dims_shape()[0]
z_size = dims_shape['Z'][1]


In [None]:
from cv2 import resize, INTER_CUBIC, INTER_NEAREST

def center_image(im, dims, ul_corner):
    """Place ``im`` on a zero canvas of spatial size ``dims`` at ``ul_corner``.

    Parameters
    ----------
    im : np.ndarray
        2-D (rows, cols) or 3-D (rows, cols, channels) image.
    dims : sequence of int
        Target canvas shape. Only the first two entries (rows, cols) are used;
        any channel axis is taken from ``im``, so passing a full ``.shape``
        (with a trailing channel count) also works.
    ul_corner : sequence of int
        (row, col) of ``im``'s top-left pixel inside the canvas.

    Returns
    -------
    np.ndarray
        ``im`` itself, unchanged, when its spatial shape already equals
        ``dims[:2]``; otherwise a new float64 zero array of spatial shape
        ``dims[:2]`` (plus ``im``'s trailing axes) with ``im`` pasted in.

    Raises
    ------
    ValueError
        If ``im`` placed at ``ul_corner`` does not fit inside the canvas.
    """
    shp = im.shape
    canvas_yx = tuple(int(d) for d in dims[:2])
    if tuple(shp[:2]) == canvas_yx:
        return im
    # Only spatial axes come from dims; channel (and any further) axes from im.
    canvas = np.zeros(canvas_yx + tuple(shp[2:]))
    r0, c0 = int(ul_corner[0]), int(ul_corner[1])
    canvas[r0:r0 + shp[0], c0:c0 + shp[1]] = im
    return canvas

def resize_hipr(im, hipr_res, mega_res, dims='none', out_fn=False, ul_corner=(0,0)):
    """Rescale a HiPRFISH image to the MGE pixel size, then pad it via center_image.

    Parameters
    ----------
    im : np.ndarray
        Image passed straight to cv2.resize.
    hipr_res, mega_res : float
        Pixel sizes of the HiPRFISH and MGE images (same units); both axes are
        scaled by ``hipr_res / mega_res`` with nearest-neighbour interpolation.
    dims : str or sequence of int
        Canvas shape for center_image; any string (default ``'none'``) means
        "use the resized image's own shape", i.e. no padding.
    out_fn : unused
        Kept for backward compatibility; nothing is saved.
    ul_corner : sequence of int
        Top-left placement forwarded to center_image.

    Returns
    -------
    np.ndarray
        The resized (and possibly padded) image.
    """
    scale = hipr_res / mega_res
    resized = resize(im, None, fx=scale, fy=scale, interpolation=INTER_NEAREST)
    target_dims = resized.shape if isinstance(dims, str) else dims
    return center_image(resized, target_dims, ul_corner)

def shift_mega(im):
    '''Pad an MGE image to the common canvas and apply the registration shift.

    Reads the module-level globals ``dims``, ``ul_corner``,
    ``mega_shift_vector`` and ``max_shift``; they must be set before calling.
    A 2-D image gets a trailing singleton channel axis first.

    Returns whatever fsi._shift_images returns for a one-element image list.
    '''
    im_3d = im if len(im.shape) != 2 else im[..., None]
    padded = center_image(im_3d, dims, ul_corner)
    return fsi._shift_images([padded], mega_shift_vector, max_shift=max_shift)

def get_mega_props(seg, raw_shift, ch_):
    """Measure region properties of ``seg`` against one channel of ``raw_shift``.

    Prints ``raw_shift.shape`` (debug output). ``raw_shift`` is indexed as
    ``[:, :, ch_]`` and ``seg`` as ``[:, :, 0]``, so both are presumably
    (rows, cols, channels) arrays — TODO confirm. The label image is cast to
    int64 before sf.measure_regionprops is called.
    """
    print(raw_shift.shape)
    channel_img = raw_shift[:, :, ch_]
    labels = seg.astype(np.int64)[:, :, 0]
    return sf.measure_regionprops(labels, raw=channel_img)

def reshape_aics_image(m_img):
    '''
    Convert an AICS image array holding only channel, Y and X data into (Y, X, C).

    All singleton axes are squeezed away first, so the remaining array must be
    exactly 3-D in (C, Y, X) order. Note that a single-channel image (or a
    singleton spatial axis) is squeezed too and will then fail the transpose.
    '''
    return np.squeeze(m_img).transpose(1, 2, 0)

In [None]:
def _czi_xy_resolution_um(czi_file, fn):
    """Return the X pixel size of a CZI file in um/pixel.

    Uses the LAST metadata element whose tag contains both 'Scaling' and 'X',
    scaled by 1e6 (value presumably stored in metres — TODO confirm).
    Raises ValueError if no such element exists, instead of silently reusing
    a previously computed resolution.
    """
    resolution = None
    for node in czi_file.meta.iter():
        if 'Scaling' in node.tag and 'X' in node.tag:
            resolution = float(node.text)
    if resolution is None:
        raise ValueError('No X scaling entry found in CZI metadata of ' + fn)
    return resolution * 10**6


def _smooth_images(images, gauss):
    """Run sf.pre_process(..., gauss=gauss) on every channel (axis 2) of every image."""
    smoothed = []
    for im in images:
        im_smooth = np.empty(im.shape)
        for i in range(im.shape[2]):
            im_smooth[:, :, i] = sf.pre_process(im[:, :, i], gauss=gauss)
        smoothed.append(im_smooth)
    return smoothed


def _contrast_stretch_rgb(raws_smooth, clips):
    """Max-project each image over axis 2, normalise by the min/max over ALL
    projections, clip each to its (lo, hi) window, rescale to [0, 1], and stack."""
    projections = [np.max(r, axis=2) for r in raws_smooth]
    # Global min/max over all projections (loop-invariant, computed once).
    mx = np.max(projections)
    mn = np.min(projections)
    channels = []
    for r, (lo, hi) in zip(projections, clips):
        r_norm = (r - mn) / (mx - mn)
        channels.append((np.clip(r_norm, lo, hi) - lo) / (hi - lo))
    return np.dstack(channels)


def _pad_smaller_image(mega_cell, hipr_img):
    """Zero-pad whichever image has fewer rows so it sits centred on a canvas
    the size of the other one.

    Returns (mega_padded, hipr_padded, offset, canvas_shape) where ``offset`` is
    the (row, col) top-left position of the padded image on the canvas.
    """
    mshp = mega_cell.shape[:2]
    hshp = hipr_img.shape[:2]
    im_list = [mega_cell, hipr_img]
    # 1 - i_sml (not argmax) so equal row counts don't pick the same image twice.
    i_sml = int(np.argmin([mshp[0], hshp[0]]))
    i_lrg = 1 - i_sml
    sml = im_list[i_sml]
    lrg = im_list[i_lrg]
    # Half the size difference = offset that centres the smaller image.
    shp_dff = np.abs(np.array(hshp) - np.array(mshp)) // 2
    sml_shift_shape = lrg.shape[:2]
    if sml.ndim > 2:
        sml_shift_shape += (sml.shape[2],)
    sml_shift = np.zeros(sml_shift_shape)
    corn_ind = shp_dff + np.array(sml.shape[:2])
    sml_shift[shp_dff[0]:corn_ind[0], shp_dff[1]:corn_ind[1]] = sml
    padded = [None, None]
    padded[i_sml] = sml_shift
    padded[i_lrg] = lrg
    return padded[0], padded[1], shp_dff, lrg.shape


def _spot_overlay(spot_raw, spot_clims):
    """Magenta RGBA overlay whose red/blue and alpha equal the spot intensity
    clipped to ``spot_clims`` and rescaled to [0, 1]."""
    lo, hi = spot_clims
    spot = (np.clip(spot_raw, lo, hi) - lo) / (hi - lo)
    overlay = np.zeros(spot.shape + (4,))
    overlay[:, :, 0] = spot
    overlay[:, :, 2] = spot
    overlay[:, :, 3] = spot
    return overlay


# Parameters shared by every sample.
mge_cell_chan = 0   # MGE channel used as the cell image for registration
mge_spot_chan = 1   # MGE channel drawn as the spot overlay
im_inches = 10
gauss = 2  # NOTE(review): smoothing uses this, not config['sigma'] — confirm intended
spot_clims = (50, 150)
clips = [(0.05, 0.4), (0.05, 0.4), (0.05, 0.4)]
max_shift = 500  # module-level global read by shift_mega()
out_root = '../../../outputs/harvard_plasmids_imaging/2023_09_12_sapp05_spades08/rgb_mge_overlays/'
shift_fmt = config['output_dir'] + '/' + config['shift_dir']

for sn_i in range(7):  # NOTE(review): hard-coded sample count; presumably len(input_table)
    sn, mt, af = input_table.iloc[sn_i, :].values
    print(sn)

    # HiPRFISH pixel size
    raw_fmt = config['data_dir'] + '/' + config['raw_fmt']
    raw_fn = raw_fmt.format(sample_name=sn, laser='488')
    czi = aplc.CziFile(raw_fn)
    hipr_res_um_pix = _czi_xy_resolution_um(czi, raw_fn)

    # Airyscan (MGE) pixel size and number of z slices (the loop runs over MGE z)
    raw_mge_fn = config['data_dir'] + '/' + sn + '_mode_airy_Airyscan Processing.czi'
    czi_mge = aplc.CziFile(raw_mge_fn)
    mge_res_um_pix = _czi_xy_resolution_um(czi_mge, raw_mge_fn)
    z_size = czi_mge.get_dims_shape()[0]['Z'][1]

    out_dir = out_root + sn
    os.makedirs(out_dir, exist_ok=True)

    # Registered-image directory for this sample (independent of z).
    af_str = str(af).replace('.', '_')
    shift_dir = shift_fmt.format(sample_name=sn, maskthresh=mt, allfluor=af_str)

    for m in range(z_size):
        print(m)

        # Load registered HiPRFISH images for this z slice
        raw_fns = sorted(glob.glob(shift_dir + '/' + sn + '_M_' + str(m) + '_*'))
        if not raw_fns:
            raise FileNotFoundError(
                'No registered images matching ' + shift_dir + '/' + sn + '_M_' + str(m) + '_*'
            )
        raws = [np.load(f) for f in raw_fns]

        # Smoothed, contrast-stretched RGB rescaled to the MGE pixel size
        raws_smooth = _smooth_images(raws, gauss)
        rgb_smooth_adj = _contrast_stretch_rgb(raws_smooth, clips)
        hipr_rgb_resize = resize_hipr(rgb_smooth_adj, hipr_res_um_pix, mge_res_um_pix)
        hipr_sum = np.sum(np.dstack(raws), axis=2)
        hipr_sum_resize = resize_hipr(hipr_sum, hipr_res_um_pix, mge_res_um_pix)

        # Load the MGE image for this z slice
        raw_mge, _ = czi_mge.read_image(Z=m)
        raw_mge = reshape_aics_image(raw_mge)
        mega_cell = raw_mge[:, :, mge_cell_chan]

        # Put both images on a common canvas and register MGE onto HiPRFISH
        mega_shift, hipr_shift, shp_dff, lrg_shape = _pad_smaller_image(mega_cell, hipr_sum_resize)
        shift_vectors = fsi._get_shift_vectors([hipr_shift, mega_shift])

        # shift_mega() reads these module-level globals
        mega_shift_vector = [shift_vectors[1]]
        dims = lrg_shape
        ul_corner = shp_dff
        raw_shift = shift_mega(raw_mge)

        # Plot the HiPRFISH RGB with the raw spot channel overlaid and save
        spot_raw_overlay = _spot_overlay(raw_shift[0][:, :, mge_spot_chan], spot_clims)
        fig, ax, cbar = ip.plot_image(
                hipr_rgb_resize, scalebar_resolution=mge_res_um_pix, im_inches=im_inches
                )
        ax.imshow(spot_raw_overlay)
        out_bn = out_dir + '/' + sn + '_M_' + str(m)
        ip.save_png_pdf(out_bn)
        # Close each figure: this loop makes one per sample per z slice.
        plt.close(fig)
