In [None]:
import random
import glob
import json
from pathlib import Path
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import matplotlib.pyplot as plt
import h5py as h5
from skimage.exposure import rescale_intensity

# 1) Load images from H5 + normalize intensity

### Definitions

In [None]:
def normalize(arr, min_, max_):
    '''
    Linearly rescale ``arr`` without clipping.

    Values equal to ``min_`` map to 0 and values equal to ``max_`` map to 1;
    anything outside ``[min_, max_]`` ends up below 0 or above 1.
    '''
    value_range = max_ - min_
    shifted = arr - min_
    return shifted / value_range

def load_all_details(in_file, do_normalize=True, percentiles=(2.5, 99.5)):
    '''
    Read every 'detail' dataset of one HDF5 file.

    For each key below ``experiment`` whose name contains 'detail', the dataset
    at ``experiment/<key>/0/0`` is loaded as float32 and squeezed.

    Parameters:
        in_file: path of the HDF5 file to read.
        do_normalize: if True, each image is rescaled with ``normalize`` using
            its own percentile values.
        percentiles: (low, high) percentiles computed on the raw intensities.

    Returns:
        (imgs, percentiles_, names): three parallel lists holding the images,
        the raw-intensity percentile values of each image and the dataset names.
    '''
    images, raw_percentiles, dataset_names = [], [], []

    with h5.File(in_file, 'r') as h5_file:
        experiment_keys = h5_file['experiment'].keys()
        detail_names = [key for key in experiment_keys if 'detail' in key]

        for detail_name in detail_names:
            dataset_path = 'experiment/{}/0/0'.format(detail_name)
            raw = h5_file[dataset_path][...].astype(np.float32)

            # percentiles are always taken on the raw intensities and returned
            # as such, even if the image itself gets normalized below
            low_high = np.percentile(raw, percentiles)

            # optional per-image intensity rescaling
            img = normalize(raw, *low_high) if do_normalize else raw

            images.append(img.squeeze())
            raw_percentiles.append(low_high)
            dataset_names.append(detail_name)

    return images, raw_percentiles, dataset_names

def get_quantiles_per_replicate(file_to_imgs_map, percentiles=(2.5, 99.5)):
    '''
    Compute intensity percentiles over all images of each replicate.

    Files are grouped by their parent directory (the replicate); the pixel
    values of all images in a group are pooled before taking the percentiles.

    Parameters:
        file_to_imgs_map: dict mapping a file path to a list of image arrays.
        percentiles: percentiles to compute on the pooled pixel values.

    Returns:
        dict mapping the replicate directory (``Path``) to the percentile values.
    '''
    pooled = defaultdict(list)
    for file_path, file_imgs in file_to_imgs_map.items():
        replicate_dir = Path(file_path).parent
        pooled[replicate_dir].extend(file_imgs)

    return {
        replicate_dir: np.percentile(
            np.concatenate([im.ravel() for im in rep_imgs]), percentiles
        )
        for replicate_dir, rep_imgs in pooled.items()
    }


def normalize_loaded_per_replicate(loaded, percentiles):
    '''
    Rescale all loaded images with the percentiles of their replicate.

    Parameters:
        loaded: dict mapping a file path to an ``(imgs, raw_percentiles, names)``
            tuple as returned by ``load_all_details``.
        percentiles: (low, high) percentiles used for the per-replicate rescaling.

    Returns:
        (res, quantiles_per_rep): ``res`` has the same layout as ``loaded`` but
        with rescaled images; ``quantiles_per_rep`` maps each replicate
        directory to the percentile values that were used.
    '''
    imgs_by_file = {path: entry[0] for path, entry in loaded.items()}
    quantiles_per_rep = get_quantiles_per_replicate(imgs_by_file, percentiles)

    n_files = len(loaded)
    res = {}
    for idx, (path, (imgs, p_raw, names)) in enumerate(loaded.items(), start=1):
        # all files in the same parent directory share one (low, high) pair
        low_high = tuple(quantiles_per_rep[Path(path).parent])
        res[path] = [normalize(img, *low_high) for img in imgs], p_raw, names

        print(f'({idx}/{n_files}): {path}')

    return res, quantiles_per_rep

## get sorted list of hdf5 files
NOTE: the folder structure should be ```.../biological_replicate_id(date_condition)/technical_replicate_id/random_file_hash.h5```

In [None]:
# collect all .h5 files located exactly two directory levels below the data root
# and sort them by path; the bare name on the last line displays the list
in_files = glob.glob('/scratch/hoerl/auto_sir_dna_comp/*/*/*.h5')
in_files.sort()
in_files

In [None]:
# quick check that all files have the same directory depth:
# returns the distinct numbers of parent directories and how many files have each
np.unique([len(Path(in_file).parents) for in_file in in_files], return_counts=True)

### Normalization options

In [None]:
# (low, high) percentiles for normalization; passed to both the loading and
# the per-replicate normalization step
percentiles=(2.5, 99.8)
# whether to normalize per replicate (True) or per image (False)
# (per-image normalization is switched off during loading when this is True)
normalize_per_replicate = True

## Load images

In [None]:
# single-threaded alternative:
# loaded = {f: load_all_details(f) for f in in_files}

# submit one loading job per file, then collect the results in input order;
# per-image normalization is only done when per-replicate normalization is off
loaded = {}
with ThreadPoolExecutor() as tpe:
    futures = []
    for f in in_files:
        futures.append(tpe.submit(load_all_details, f, not normalize_per_replicate, percentiles))

    n_submitted = len(futures)
    for i, f in enumerate(in_files):
        future = futures[i]
        loaded[f] = future.result()
        print(f'({i+1}/{n_submitted}): {f}')

## Normalize per replicate
If we have not normalized per image, normalize per replicate now

In [None]:
# replaces `loaded` with the rescaled images and keeps the per-replicate
# percentile values in `quantiles_per_rep` (only defined when this branch runs)
if normalize_per_replicate:
    loaded, quantiles_per_rep = normalize_loaded_per_replicate(loaded, percentiles)

### Plot examples of images per file

In [None]:
def plot_img_grid(imgs, **kwargs):
    '''
    Show randomly picked images in a grid of matplotlib axes.

    Parameters:
        imgs: list of images; it is copied before shuffling, so the caller's
            list keeps its order.
        **kwargs: forwarded to ``plt.subplots`` (e.g. nrows, ncols, figsize).

    Only as many images as there are axes are shown; intensities are clipped
    to [0, 1] for display.
    '''
    shuffled = imgs.copy()
    random.shuffle(shuffled)

    _, axes = plt.subplots(**kwargs)

    for panel, picture in zip(axes.flat, shuffled):
        clipped = np.clip(picture.squeeze(), 0, 1)
        panel.imshow(clipped, cmap='gray')
        panel.axis('off')

In [None]:
# one row of four randomly chosen example images per input file
# (v[0] is the list of images of file k)
for k, v in sorted(loaded.items()):
    print(k)
    plot_img_grid(v[0], ncols=4, nrows=1, figsize=(12,4))
    plt.show()

# 2) Simple segmentation via threshold

In [None]:
from skimage.filters import threshold_otsu, threshold_li, gaussian
from skimage.transform import rescale
from skimage.morphology import remove_small_holes, remove_small_objects
import tqdm

def blur_and_segment(img, blur_sigma=16, max_hole_size=512, min_object_size=512):
    '''
    Segment an image by Li-thresholding a Gaussian-blurred copy.

    Parameters:
        img: input image; singleton dimensions are squeezed away.
        blur_sigma: sigma of the Gaussian blur applied before thresholding.
        max_hole_size: size passed to ``remove_small_holes``.
        min_object_size: size passed to ``remove_small_objects``.

    Returns:
        (img, mask): the squeezed (unblurred) image and the cleaned binary mask.
    '''
    img = img.squeeze()

    # blur, then clip to [0, 1] and quantize to 8-bit;
    # without this, Li thresholding ran extremely long for a few images
    blurred = gaussian(img, blur_sigma)
    blurred_8bit = (np.clip(blurred, 0, 1) * 255).astype(np.uint8)

    li_threshold = threshold_li(blurred_8bit)
    foreground = blurred_8bit > li_threshold

    # binary clean-up: first drop small objects, then fill small holes
    without_specks = remove_small_objects(foreground, min_object_size)
    cleaned = remove_small_holes(without_specks, max_hole_size)

    return img, cleaned

### Segment multithreaded

In [None]:
# segmentation parameters
seg_sigma = 12
max_hole_size = 5000
min_object_size = 5000

# segs maps each file to a list of (image, mask) pairs;
# the images of one file are segmented in parallel, files are processed in sorted order
segs = {}
with ThreadPoolExecutor() as tpe:
    for count, path in enumerate(sorted(loaded), start=1):
        file_imgs = loaded[path][0]
        jobs = [
            tpe.submit(blur_and_segment, single_img, seg_sigma, max_hole_size, min_object_size)
            for single_img in file_imgs
        ]
        segs[path] = [job.result() for job in jobs]
        print(f'({count}/{len(loaded)}): {path}')

### Visualize segmentation on example images

In [None]:
from skimage.color import gray2rgb, label2rgb

def plot_seg_grid(imgs, **kwargs):
    '''
    Show randomly picked images with their segmentation mask as an overlay.

    Parameters:
        imgs: list of ``(image, mask)`` pairs; it is copied before shuffling.
        **kwargs: forwarded to ``plt.subplots`` (e.g. nrows, ncols, figsize).
    '''
    pairs = imgs.copy()
    random.shuffle(pairs)

    _, axes = plt.subplots(**kwargs)
    for panel, (picture, seg_mask) in zip(axes.flat, pairs):

        # stretch intensities to (0, 1) for better visibility
        # (instead of clipping them to that range)
        stretched = rescale_intensity(picture, out_range=(0,1))

        # draw the mask as a label overlay on the grayscale image
        overlay = label2rgb(seg_mask*1, gray2rgb(stretched), bg_label=0)
        panel.imshow(overlay)
        panel.axis('off')

In [None]:
# one row of four randomly chosen segmentation overlays per input file
# (v is the list of (image, mask) pairs of file k)
for k, v in sorted(segs.items()):
    print(k)
    plot_seg_grid(v, ncols=4, nrows=1, figsize=(12,4))
    plt.show()

# 3) Extract texture features

In [None]:
from skimage.feature import local_binary_pattern
from skimage.feature import greycomatrix, greycoprops
from collections.abc import Iterable

def get_glcm_features(img, mask, props, distances, angles, blur_sigma=None):
    '''
    Compute GLCM texture properties restricted to the masked region.

    Parameters:
        img: image with intensities scaled so that [0, 1] is the range of interest.
        mask: boolean foreground mask of the same shape as ``img``.
        props: names of the GLCM properties to compute.
        distances, angles: pixel-pair offsets passed to ``greycomatrix``.
        blur_sigma: Gaussian sigma applied before quantization; ``None`` uses 0.5.

    Returns:
        Array stacking one ``greycoprops`` result per entry of ``props``.

    NOTE(review): newer scikit-image releases appear to spell these functions
    ``graycomatrix``/``graycoprops`` - confirm against the installed version.
    '''
    # always blur at least a little to reduce quantization effects on the GLCM
    sigma = 0.5 if blur_sigma is None else blur_sigma

    # clip to [0, 1] and quantize to 8-bit
    img_8bit = (np.clip(gaussian(img, sigma), 0, 1) * 255).astype(np.uint8)

    # input for the masked GLCM: foreground values are shifted by +1 (1..256),
    # background is set to 0 (uint16 is needed because 255 + 1 overflows uint8)
    # (NB: the shift should not be necessary if the mask comes from a threshold, but is kept anyway)
    shifted = img_8bit.astype(np.uint16)
    shifted[mask] += 1
    shifted[~ mask] = 0

    # drop first row & column, i.e. all co-occurrences with 0 := background
    glcm = greycomatrix(shifted, distances, angles, 257)[1:,1:]

    per_prop = [greycoprops(glcm, prop=name) for name in props]
    return np.stack(per_prop)

def get_lbp_histogram(img, mask, blur_sigma=None, Rs=3, P=32):
    '''
    Compute normalized histograms of uniform local binary patterns in the mask.

    Parameters:
        img: image with intensities scaled so that [0, 1] is the range of interest.
        mask: boolean foreground mask; only LBP codes inside it are counted.
        blur_sigma: Gaussian sigma applied before quantization; ``None`` uses 0.5.
        Rs: a single LBP radius or an iterable of radii.
        P: number of LBP sampling points.

    Returns:
        Array with one row per radius and ``P + 2`` histogram bins per row.
    '''
    # accept a single radius as well as several
    radii = Rs if isinstance(Rs, Iterable) else [Rs]

    # always blur at least a little to reduce quantization effects on the LBP codes
    sigma = 0.5 if blur_sigma is None else blur_sigma

    # clip to [0, 1] and quantize to 8-bit
    img_8bit = (np.clip(gaussian(img, sigma), 0, 1) * 255).astype(np.uint8)

    bin_edges = np.arange(P+3)

    def _histogram_for_radius(radius):
        # density histogram of the LBP codes found inside the mask
        codes = local_binary_pattern(img_8bit, P, radius, method='uniform')
        density, _ = np.histogram(codes[mask], bins=bin_edges, density=True)
        return density

    return np.stack([_histogram_for_radius(radius) for radius in radii])

### Feature extraction parameters

In [None]:
# Gaussian blur applied before feature extraction; code for sigma estimation below.
# None does not disable blurring completely: the feature functions then fall back
# to a small sigma of 0.5
blur_sigma = None
# blur_sigma = 5.2

# GLCM feature options: properties to compute per (distance, angle) combination
props = ['contrast', 'dissimilarity', 'homogeneity', 'energy', 'correlation', 'ASM']

# pixel-pair distances and angles (in radians) for the GLCM
# distances = [2, 4, 7, 12, 16]
distances = [2, 4, 8, 16, 32, 64]
angles = [0, np.pi/2]

# LBP feature options: on/off switch, radii and number of sampling points
add_lbps = False
Rs = [2,4,6]
P = 32

### Calculate GLCM features for all images

In [None]:
# glcms maps each file to a list with one GLCM feature array per image;
# the images of one file are processed in parallel, files in sorted order
glcms = {}
with ThreadPoolExecutor() as tpe:
    for file_idx, (k, v) in enumerate(sorted(segs.items())):
        jobs = [
            tpe.submit(get_glcm_features, img, mask, props, distances, angles, blur_sigma)
            for (img, mask) in v
        ]
        glcms[k] = [job.result() for job in jobs]
        print(f'({file_idx+1}/{len(loaded)}): {k}')

### Optional: calculate LBP features

Did not improve results much, so this step can be skipped.

In [None]:
# LBP features are optional: skip the whole computation unless requested
if add_lbps:
    # lbps maps each file to a list with one LBP histogram array per image
    lbps = {}
    with ThreadPoolExecutor() as tpe:
        for k in sorted(segs):
            jobs = [
                tpe.submit(get_lbp_histogram, img, mask, blur_sigma, Rs, P)
                for (img, mask) in segs[k]
            ]
            lbps[k] = [job.result() for job in jobs]
            print(k)

# 4) Some simple other features + save as table

In [None]:
from scipy.ndimage import maximum_filter1d

def get_simple_features(img, mask, blur_sigma=None):
    '''
    Compute basic intensity and size features of one segmented image.

    Parameters:
        img: 2D image.
        mask: boolean foreground mask of the same shape as ``img``.
        blur_sigma: Gaussian sigma applied before measuring; ``None`` uses 0.5.

    Returns:
        (mean intensity in mask, std of intensity in mask, number of mask
        pixels, image height, image width)
    '''
    # use the same (at least slight) blur as for the texture features
    sigma = 0.5 if blur_sigma is None else blur_sigma
    smoothed = gaussian(img, sigma)

    inside = smoothed[mask]
    height, width = smoothed.shape[0], smoothed.shape[1]

    return np.mean(inside), np.std(inside), np.sum(mask), height, width

def get_num_blank_lines(img, normalization_vals=(0,1), thresh=1.0, max_filter_size=5):
    '''
    Count (nearly) blank rows and columns of an image.

    Parameters:
        img: 2D image that was intensity-scaled to (0, 1) without clipping.
        normalization_vals: (low, high) raw intensities that were mapped to
            0 and 1; they are used to undo the scaling. The default (0, 1)
            uses the image as-is.
        thresh: a line counts as blank if its (max-filtered) mean raw
            intensity is below this value.
        max_filter_size: window size of the 1D maximum filter.

    Returns:
        (number of blank rows, number of blank columns)
    '''
    # undo the linear intensity scaling to get back to raw intensities
    low_val, high_val = normalization_vals
    raw = img * (high_val - low_val) + low_val

    def _count_blank(axis):
        # Mean intensity per line, max-filtered so that isolated blank lines
        # are ignored. The artifacts we look for are usually many consecutive
        # blank lines (detector shutdown or scanning outside the FOV).
        means = maximum_filter1d(raw.mean(axis=axis), max_filter_size)
        return (means < thresh).sum()

    # averaging over axis 1 gives row means, over axis 0 column means
    return _count_blank(1), _count_blank(0)
  

### Combine the multiple dicts + calculate simple features

In [None]:
from itertools import product
from collections import defaultdict

# column names for GLCM feats: '<prop>_<distance>_<angle in degrees>', generated in the
# same (prop, distance, angle) order in which the GLCM feature arrays are flattened below
props_names = list(map(lambda x: '_'.join(x), product(props, map(str, distances), map(str, map(np.rad2deg, angles)))))

# one list per output column; every image has to contribute exactly one entry per column
df_dict = defaultdict(list)

for i, (k,v) in enumerate(glcms.items()):
    # source file name + GLCM features, one entry per image
    for vi in v:
        df_dict['filename'].append(k)
        for value, prop_name in zip(np.array(vi).flat, props_names):
            df_dict[prop_name].append(value)

    # optional LBP histogram bins, one column per (radius, bin) combination
    if add_lbps:
        for lbp in lbps[k]:
            for value, (R, p) in zip(lbp.flat, product(Rs, np.arange(P+2, dtype=int))):
                # append instead of assigning: an assignment would replace the
                # column list by the scalar of the most recently processed image
                df_dict[f'LBP_R{R}_{p}'].append(value)

    # simple intensity / size features from the segmentation
    for (img, mask) in segs[k]:
        mu, sig, area, height, width = get_simple_features(img, mask, blur_sigma)
        df_dict['intensity_mu'].append(mu)
        df_dict['intensity_sigma'].append(sig)
        df_dict['mask_area'].append(area)
        df_dict['img_height'].append(height)
        df_dict['img_width'].append(width)

    imgs, percs, names = loaded[k]

    # percentiles actually used for normalization when normalizing per replicate
    if normalize_per_replicate:
        percs_rep_low, percs_rep_high = quantiles_per_rep[Path(k).parent]

    for (per_low, per_high), name, img in zip(percs, names, imgs):
        df_dict['dataset_name'].append(name)
        # perc_low / perc_high: values used for normalization;
        # perc_*_image: percentiles of the individual image
        df_dict['perc_low'].append(per_low if not normalize_per_replicate else percs_rep_low)
        df_dict['perc_high'].append(per_high if not normalize_per_replicate else percs_rep_high)
        df_dict['perc_low_image'].append(per_low)
        df_dict['perc_high_image'].append(per_high)

        # blank lines are detected on raw intensities, so undo the normalization that was applied
        blank_rows, blank_cols = get_num_blank_lines(img, (per_low, per_high) if not normalize_per_replicate else (percs_rep_low, percs_rep_high))
        df_dict['num_blank_rows'].append(blank_rows)
        df_dict['num_blank_cols'].append(blank_cols)

    print(f'({i+1}/{len(glcms)}): {k}')

# sanity check: all columns must have received the same number of entries
column_lengths = sorted({len(column) for column in df_dict.values()})
assert len(column_lengths) <= 1, f'feature columns have inconsistent lengths: {column_lengths}'

### Save to CSV

In [None]:
import pandas as pd

# build the feature table: one column per key of df_dict, one row per image
df = pd.DataFrame.from_dict(df_dict)

In [None]:
# OPTIONALLY: load existing table and append (e.g. when adding new replicates)

append_to_existing_df = False
existing_df = '/scratch/hoerl/auto_sir_dna_comp/20211111_glcm_all_lithreshold_smallblur.csv'

if append_to_existing_df:
    # DataFrame.append was removed in pandas 2.0; pd.concat gives the same result
    # (existing rows first, then the new ones, with a fresh running index)
    df = pd.concat([pd.read_csv(existing_df), df], ignore_index=True)

In [None]:
# write the feature table without the running index column
df.to_csv('/scratch/hoerl/auto_sir_dna_comp/20220829_glcm-long_all_replicatenorm.csv', index=False)

# OPTIONAL: save pngs for manual sorting, confocal blur simulations

## Save a random sample of images for manual sorting

In [None]:
from itertools import repeat, chain, zip_longest
from skimage.io import imsave
import os

sorting_out_dir = '/scratch/hoerl/auto_sir_dna_comp/sorting20210316'

# flat list of (image, dataset_name, hdf5_filename), in random order
imgsplusnames = [
    (single_img, dataset_name, k)
    for k, v in sorted(loaded.items())
    for single_img, dataset_name in zip(v[0], v[2])
]
random.shuffle(imgsplusnames)

if not os.path.exists(sorting_out_dir):
    os.makedirs(sorting_out_dir)

# prepare the first 250 images of the shuffled list as 8-bit PNGs
# named <hdf5 file name without extension>_<dataset name>.png
for img, name, file in imgsplusnames[:250]:
    file_stem = os.path.split(file)[1].replace('.h5', '')
    outname = file_stem + '_' + name + '.png'
    img = (np.clip(img, 0, 1) * 255).astype(np.uint8)
    # NOTE: uncomment this line to actually save
#     imsave(os.path.join(sorting_out_dir, outname), img)

## Check sorting

In the following cells, we load a feature table with a ```classification_auto``` column (as produced in ```subset_selection2.ipynb```), select a random sample of rows classified as 'good' and 'bad', and save the corresponding images we loaded above as .png for manual inspection.

In [None]:
# feature table with an additional `classification_auto` column (used in the next cell)
df_with_cls = pd.read_csv('/scratch/hoerl/auto_sir_dna_comp/20220816_glcm_all_imagenorm_withcls.csv')
df_with_cls.head()

In [None]:
# number of rows to draw from each class
n_to_sample = 250

# random samples of rows classified as 'good' / 'bad'
# NOTE(review): no random_state is set, so the selection differs between runs; sampling
# presumably fails if a class has fewer than n_to_sample rows - confirm
dfs_good = df_with_cls[df_with_cls.classification_auto == 'good'].sample(n_to_sample)
dfs_bad = df_with_cls[df_with_cls.classification_auto == 'bad'].sample(n_to_sample)

In [None]:
from itertools import repeat, chain, zip_longest
from skimage.io import imsave
import os

# flat list of (image, dataset_name, hdf5_filename)
imgsplusnames = list(chain(*[zip(v[0], v[2], repeat(k)) for k,v in sorted(loaded.items())]))

# index the images by (dataset_name, hdf5_filename) so that each sampled row is a
# dictionary lookup instead of a scan over all images;
# setdefault keeps the first image should a key occur more than once
img_by_name_and_file = {}
for img, name, file in imgsplusnames:
    img_by_name_and_file.setdefault((name, file), img)


def save_sampled_pngs(sampled_rows, images_by_key, out_dir):
    '''
    Save the image belonging to each row of a feature table as 8-bit PNG.

    Parameters:
        sampled_rows: DataFrame with 'dataset_name' and 'filename' columns.
        images_by_key: dict mapping (dataset_name, filename) to the image.
        out_dir: output directory, created if it does not exist.

    Raises:
        KeyError: if a row refers to an image that is not in ``images_by_key``.
    '''
    os.makedirs(out_dir, exist_ok=True)

    for _, r in sampled_rows.iterrows():
        key = (r['dataset_name'], r['filename'])
        if key not in images_by_key:
            raise KeyError(f'no loaded image for dataset {key[0]!r} in file {key[1]!r}')

        # <hdf5 file name without extension>_<dataset name>.png
        outname = (os.path.split(r['filename'])[1].replace('.h5', '') + '_' + r['dataset_name'] + '.png')
        # clip to [0, 1] and convert to 8-bit for saving
        img = (np.clip(images_by_key[key], 0, 1) * 255).astype(np.uint8)
        imsave(os.path.join(out_dir, outname), img)


good_out_dir = '/scratch/hoerl/auto_sir_dna_comp/sorting20211115_resorted_val/good'
save_sampled_pngs(dfs_good, img_by_name_and_file, good_out_dir)

bad_out_dir = '/scratch/hoerl/auto_sir_dna_comp/sorting20211115_resorted_val/bad'
save_sampled_pngs(dfs_bad, img_by_name_and_file, bad_out_dir)

# get pixel sizes, calculate blur sigma for simulated confocal

In [None]:
# print pixel sizes for first STED image in all input files

for file in in_files:
    with h5.File(file, 'r') as fd:
        # first key below 'experiment' whose name contains 'detail', if any
        first_detail = next((k for k in fd['experiment'].keys() if 'detail' in k), None)

        # skip files without detail images
        if first_detail is None:
            continue

        # the measurement metadata is stored as a JSON string attribute
        dataset = fd['experiment/{}/0'.format(first_detail)]
        meta = json.loads(dataset.attrs['measurement_meta'])
        pixel_size_x = float(meta['ExpControl']['scan']['range']['x']['psz'])
        print(file, pixel_size_x)

In [None]:
# assumed FWHMs: STED 50 nm, confocal 250 nm; assumed pixel size: 20 nm (all in meters)
sted_fwhm = 5e-8
conf_fwhm = 2.5e-7
pixel_size = 2e-8

def fwhm_to_sigma(f):
    '''
    Convert the FWHM of a Gaussian to its standard deviation,
    sigma = FWHM / (2 * sqrt(2 * ln 2)).
    '''
    fwhm_per_sigma = 2 * np.sqrt(2 * np.log(2))
    return f / fwhm_per_sigma

# additional blur needed to get from STED to confocal resolution:
# subtract the STED variance, which is already present, from the confocal variance
sigma_conf = fwhm_to_sigma(conf_fwhm)
sigma_sted = fwhm_to_sigma(sted_fwhm)
blur_sigma = np.sqrt(sigma_conf**2 - sigma_sted**2)

# convert from meters to pixels
# NOTE: this overwrites the global blur_sigma from the feature extraction parameters cell
blur_sigma = blur_sigma / pixel_size
blur_sigma

# Plot an example of STED vs. simulated confocal

In [None]:
# images of the first loaded file; only the first image is plotted
img_plot = next(iter(loaded.values()))[0]

fig, axs = plt.subplots(ncols=2, figsize=(16,8))

# left: image as loaded, right: the same image blurred with sigma 5.2
panels = [
    ('STED raw', img_plot[0].squeeze()),
    ('simulated confocal', gaussian(img_plot[0].squeeze(), 5.2)),
]
for ax, (title, panel_img) in zip(axs, panels):
    ax.imshow(panel_img, cmap='gray')
    ax.axis('off')
    ax.set_title(title)