# Extraction of confocal overview tiles

With the code in this notebook, we can extract the individual overview tiles and their location metadata from the HDF5 files generated by the autoSTED pipeline.

We can do two things with them:

* Save individual tiles + locations ("Tile Configuration") for stitching in Fiji/BigStitcher

* Stitch maximum projections based only on stage coordinates and save the result
    * can be saved as raw intensities in TIFF
    * or as PNG, optionally with bounding boxes around the detail acquisitions

## Imports / Defs

In [None]:
import json
import glob
import os
from pathlib import Path
from itertools import count

import seaborn as sns
import numpy as np
from tifffile import imwrite
import h5py as h5
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

In [None]:
def zeropad(n: int, out_len=3):
    """
    Return ``str(n)`` left-padded with '0' characters to a width of ``out_len``.

    Strings that are already at least ``out_len`` characters long are returned
    unchanged. The padding is prepended to the whole string (including a
    minus sign, if any), e.g. zeropad(7) -> '007', zeropad(1234) -> '1234'.
    """
    return str(n).rjust(out_len, '0')


def _read_experiment(h5file):
    """
    Read all datasets below the 'experiment' group of an autoSTED HDF5 file.

    Returns three parallel lists: the dataset names, the parsed JSON of the
    'measurement_meta' attribute of '<name>/0', and the squeezed image array
    stored at '<name>/0/0'.
    """
    names, attrs, imgs = [], [], []
    with h5.File(h5file, 'r') as fd:
        for k in fd['experiment'].keys():
            names.append(k)
            attrs.append(json.loads(fd[f'experiment/{k}/0'].attrs['measurement_meta']))
            imgs.append(fd[f'experiment/{k}/0/0'][...].squeeze())
    return names, attrs, imgs


def _scan_geometry(attrs):
    """
    Extract stage positions and field-of-view lengths from measurement metadata.

    Returns two float arrays of shape (n, 2) in (x, y) order: the stage position
    (g_off + off of the coarse axis + g_off + off of the fine axis; should be
    okay for x, y) and the scan range 'len', in the units of the metadata.
    """
    positions, fovs = [], []
    for attr in attrs:
        rng = attr['ExpControl']['scan']['range']
        positions.append([
            rng['coarse_x']['g_off'] + rng['coarse_x']['off'] + rng['x']['g_off'] + rng['x']['off'],
            rng['coarse_y']['g_off'] + rng['coarse_y']['off'] + rng['y']['g_off'] + rng['y']['off'],
        ])
        fovs.append([rng['x']['len'], rng['y']['len']])
    return (np.array(positions, dtype=float).reshape(-1, 2),
            np.array(fovs, dtype=float).reshape(-1, 2))


def save_stitched_projections(h5file, out_file, show_boxes=True, save_raw_intensities=False,
                              pixel_size=2.5e-7):
    """
    Fuse the maximum projections of all overview images of an autoSTED HDF5
    file, placed by stage coordinates only, and save the result.

    Datasets whose name contains 'detail' are not fused; in PNG mode they can
    be outlined with green boxes. Where overviews overlap, their projections
    are averaged.

    Parameters
    ----------
    h5file : str or path-like
        HDF5 file written by the autoSTED pipeline.
    out_file : str or path-like
        Output path. In PNG mode the figure is saved here; in raw mode a
        '.png' extension is replaced by '.tiff'.
    show_boxes : bool
        PNG mode only: draw boxes around the detail acquisitions.
    save_raw_intensities : bool
        If True, save the fused intensities as a uint16 TIFF instead of a PNG plot.
    pixel_size : float
        Pixel size of the fused canvas, in the same units as the stage
        coordinates in the metadata (default: the previously hard-coded 2.5e-7).

    Raises
    ------
    ValueError
        If the file contains no overview, or an overview image does not match
        its field of view at ``pixel_size`` (more than 1 px off).
    """
    names, attrs, imgs = _read_experiment(h5file)
    centers, fovs = _scan_geometry(attrs)

    is_overview = np.array(['detail' not in n for n in names], dtype=bool)
    if not is_overview.any():
        raise ValueError(f'no overview images found in {h5file}')

    # positions are treated as FOV centers: min corner = position - len / 2, in (x, y)
    corners = centers - fovs / 2
    # canvas origin: min corner over all overviews
    origin = corners[is_overview].min(axis=0)

    # integer (row, col) placement of every overview max projection
    # NOTE: image arrays are indexed (y, x), stage coordinates are (x, y)
    placed = []
    for img, name, corner, fov in zip(imgs, names, corners, fovs):
        if 'detail' in name:
            continue
        proj = img.max(axis=0) if img.ndim == 3 else img
        expected = np.round(fov[::-1] / pixel_size).astype(int)
        if proj.ndim != 2 or np.any(np.abs(np.subtract(proj.shape, expected)) > 1):
            raise ValueError(
                f'{name} in {h5file}: image of shape {img.shape} does not match a field of '
                f'view of {tuple(expected)} px at pixel size {pixel_size}')
        # round only the start; the extent comes from the image itself, so the
        # slice can never be one pixel longer/shorter than the tile
        row, col = np.round((corner - origin)[::-1] / pixel_size).astype(int)
        placed.append((row, col, proj))

    # canvas just large enough to hold every placed tile
    n_rows = max(row + proj.shape[0] for row, _, proj in placed)
    n_cols = max(col + proj.shape[1] for _, col, proj in placed)
    fused = np.zeros((n_rows, n_cols))
    weights = np.zeros((n_rows, n_cols))
    for row, col, proj in placed:
        height, width = proj.shape
        fused[row:row + height, col:col + width] += proj
        weights[row:row + height, col:col + width] += 1

    # average where tiles overlap
    covered = weights != 0
    fused[covered] /= weights[covered]

    # Option B: save raw intensities as TIFF
    if save_raw_intensities:
        out_path = str(out_file)
        root, ext = os.path.splitext(out_path)
        if ext == '.png':
            out_path = root + '.tiff'
        # round and clip to uint16 (as used for the tiles); int16 would overflow above 32767
        raw = np.clip(np.round(fused), 0, np.iinfo(np.uint16).max).astype(np.uint16)
        imwrite(out_path, raw)
        return

    # Option A: do mpl plot, save as PNG
    fig = plt.figure(figsize=(20, 20))
    try:
        plt.imshow(fused, cmap='hot')

        # boxes around details
        if show_boxes:
            ax = plt.gca()
            for name, corner, fov in zip(names, corners, fovs):
                if 'detail' not in name:
                    continue
                ax.add_artist(Rectangle((corner - origin) / pixel_size, *(fov / pixel_size),
                                        color='limegreen', fill=None))

        plt.axis('off')
        # 20 inch figure * dpi -> about one output pixel per canvas pixel on the longer axis
        plt.savefig(out_file, dpi=max(fused.shape) / 20)
    finally:
        # close the figure so processing many files does not accumulate open figures
        plt.close(fig)


def save_ov_tiles(h5file, outdir):
    """
    Save the overview stacks of an autoSTED HDF5 file as ImageJ TIFFs plus a
    tile configuration for Fiji Grid/Collection stitching.

    Writes 'tile_000.tif', 'tile_001.tif', ... and 'tile_config.txt' into
    ``outdir``. Datasets whose name contains 'detail' and images that are not
    3D after squeezing are skipped. Each tile's position in the configuration
    is its stage coordinate (coarse + fine offsets) divided by its own pixel
    size, i.e. in pixels.

    Parameters
    ----------
    h5file : str or path-like
        HDF5 file written by the autoSTED pipeline.
    outdir : str or path-like
        Existing output directory.

    If a tile cannot be written (ValueError from ``imwrite``), a message is
    printed, the tile is left out of the configuration and processing continues.
    """
    outdir = Path(outdir)

    # 1. collect images, attribute dicts and dataset names from h5
    imgs = []
    attrs = []
    names = []
    with h5.File(h5file, 'r') as fd:
        for k in fd['experiment'].keys():
            names.append(k)
            attrs.append(json.loads(fd[f'experiment/{k}/0'].attrs['measurement_meta']))
            imgs.append(fd[f'experiment/{k}/0/0'][...].squeeze())

    # 2. save tiles + tile configuration (location) text file
    tile_config_lines = ['dim=3']
    n_tiles = 0
    for img, attr, name in zip(imgs, attrs, names):

        # skip detail images and anything that is not a 3D stack
        if 'detail' in name or img.ndim != 3:
            continue

        scan_range = attr['ExpControl']['scan']['range']
        # coords: g_off + off (coarse) + g_off + off (fine); should be okay for x, y
        coord = np.array([
            scan_range['coarse_x']['g_off'] + scan_range['coarse_x']['off']
            + scan_range['x']['g_off'] + scan_range['x']['off'],
            scan_range['coarse_y']['g_off'] + scan_range['coarse_y']['off']
            + scan_range['y']['g_off'] + scan_range['y']['off'],
        ])
        psz = np.array([scan_range['x']['psz'], scan_range['y']['psz']])

        # save as ImageJ-compatible stack; numbering counts only tiles actually written
        outfile = outdir / f'tile_{zeropad(n_tiles)}.tif'
        try:
            imwrite(outfile, img.astype(np.uint16), imagej=True, metadata={'axes': 'ZYX'})
        except ValueError as err:
            # report and skip instead of aborting, so tile_config.txt is still written
            print(f'skipping {name} in {h5file} (shape {img.shape}): {err}')
            continue
        n_tiles += 1

        # make tile config for Fiji Grid/Collection Stitching (position in pixels)
        tile_config_lines.append('{};;({},{},0.0)'.format(outfile.name, *(coord / psz)))

    with (outdir / 'tile_config.txt').open('w') as tc_file:
        tc_file.write('\n'.join(tile_config_lines))

## Get files to process

In [8]:
# root folder of the autoSTED data; every *.h5 file below it (any depth) is collected
basepath = '/data/cooperation_data/ArgyrisPapantonis-nuclear_architecture/Simona_Nasiscionyte/STED'
h5_pattern = os.path.join(basepath, '**', '*.h5')
h5s = glob.glob(h5_pattern, recursive=True)

h5s

['/data/cooperation_data/ArgyrisPapantonis-nuclear_architecture/Simona_Nasiscionyte/STED/2020705_IMR90_young_untreated/2020-07-08_rep3/7e2945e314196b8c7998afbb09526fd1.h5',
 '/data/cooperation_data/ArgyrisPapantonis-nuclear_architecture/Simona_Nasiscionyte/STED/2020705_IMR90_young_untreated/2020-07-05_rep1/57a7d020522f784dca80e052b20eab36.h5',
 '/data/cooperation_data/ArgyrisPapantonis-nuclear_architecture/Simona_Nasiscionyte/STED/2020705_IMR90_young_untreated/2020-07-07_rep2/0ccf6936085070c4a9f2301009af738c.h5',
 '/data/cooperation_data/ArgyrisPapantonis-nuclear_architecture/Simona_Nasiscionyte/STED/2020705_IMR90_young_untreated/2020-07-07_rep2/109f896964d0b8763f0e4b9c9f1a235c.h5',
 '/data/cooperation_data/ArgyrisPapantonis-nuclear_architecture/Simona_Nasiscionyte/STED/2020705_IMR90_young_untreated/2020-07-07_rep2/ef9b8dbc30c16f3fdd0613c6966814b7.h5',
 '/data/cooperation_data/ArgyrisPapantonis-nuclear_architecture/Simona_Nasiscionyte/STED/20210826_IMR90_6d_ICM_young/26.08.21_rep2/c128

## A: Save individual tiles    

In [None]:
out_base_path = '/scratch/hoerl/20230507_imr90_stitching'

# one output folder per HDF5 file: same relative path below out_base_path as
# the file below basepath, named after the file without its extension
outdirs = []

for h5file in h5s:
    outdir = Path(out_base_path) / os.path.relpath(h5file, basepath)
    outdir = outdir.parent / outdir.stem
    outdir.mkdir(parents=True, exist_ok=True)
    outdirs.append(outdir)

# apply to all h5 files
import tqdm  # progress bar, only needed in this cell
# zip() has no len(), so pass total for a proper progress bar / ETA
for h5file, outdir in tqdm.tqdm(zip(h5s, outdirs), total=len(h5s)):
    save_ov_tiles(h5file, outdir)

## B: Save stitched max projections

In [None]:
out_root = '/scratch/hoerl/auto_sir_dna_comp/stitched_ov_raw'

for h5_file in h5s:

    # mirror the folder structure below basepath (as in part A), independent of the
    # glob depth, so files from different sub-folders cannot collide
    outfile = os.path.join(out_root, os.path.relpath(h5_file, basepath) + '.png')
    os.makedirs(os.path.dirname(outfile), exist_ok=True)

    try:
        save_stitched_projections(h5_file, outfile, show_boxes=False, save_raw_intensities=True)
    except Exception as e:
        # keep going with the remaining files, but report which one failed
        print(f'error on file {h5_file}: {e}')
    print(h5_file)

print('done.')

# Testing the code on an individual file

In [9]:
# get a single h5 file path
# (the first file found by the glob above; change the index to inspect another one)

h5file = h5s[0]
h5file

'/data/cooperation_data/ArgyrisPapantonis-nuclear_architecture/Simona_Nasiscionyte/STED/2020705_IMR90_young_untreated/2020-07-08_rep3/7e2945e314196b8c7998afbb09526fd1.h5'

In [None]:
# For every dataset in the 'experiment' group, collect its name, the parsed
# 'measurement_meta' JSON of '<name>/0' and the squeezed image at '<name>/0/0'.

imgs = []
attrs = []
names = []

with h5.File(h5file, 'r') as fd:
    for k, group in fd['experiment'].items():
        measurement = group['0']
        names.append(k)
        attrs.append(json.loads(measurement.attrs['measurement_meta']))
        imgs.append(measurement['0'][...].squeeze())

In [None]:
# Per dataset, from the scan range metadata:
#   coords: stage position (x, y) = g_off + off (coarse) + g_off + off (fine),
#           should be okay for x, y
#   lens:   field-of-view length (x, y)
#   pszs:   pixel size (x, y)

coords = []
lens = []
pszs = []

for attr in attrs:
    scan_range = attr['ExpControl']['scan']['range']
    position = []
    for coarse_axis, fine_axis in (('coarse_x', 'x'), ('coarse_y', 'y')):
        position.append(
            scan_range[coarse_axis]['g_off'] + scan_range[coarse_axis]['off']
            + scan_range[fine_axis]['g_off'] + scan_range[fine_axis]['off']
        )
    coords.append(position)
    lens.append([scan_range[axis]['len'] for axis in ('x', 'y')])
    pszs.append([scan_range[axis]['psz'] for axis in ('x', 'y')])

In [None]:
# scatter plot of all stage positions, colored by acquisition type
coords = np.array(coords)
kind_labels = ['detail' if 'detail' in n else 'overview' for n in names]
sns.scatterplot(x=coords[:, 0], y=coords[:, 1], hue=kind_labels)

### Fused plot with rectangles around details

In [None]:
# Fuse the max projections of all overviews on one canvas (by stage coordinates only)
# and draw boxes around the detail acquisitions.

# pixel size of the canvas, same units as the stage coordinates
# (TODO: take from metadata; check with np.unique(np.array(pszs), axis=0))
pixel_size = 2.5e-7

is_overview = np.array(['detail' not in n for n in names], dtype=bool)
# coords are treated as FOV centers: min corner = coord - len / 2, in (x, y)
corners = coords - np.array(lens) / 2

# min of all overviews = canvas origin
min_ = np.min(corners[is_overview], axis=0)

# integer (row, col) offset of every tile; NB: arrays are (y, x), coords are (x, y)
offsets = np.round((corners - min_)[:, ::-1] / pixel_size).astype(int)
projections = {i: imgs[i].max(axis=0) for i in np.flatnonzero(is_overview)}

# canvas just large enough for every tile; extents come from the images, so
# rounding can never make a slice one pixel longer/shorter than its tile
shape = tuple(np.max([offsets[i] + proj.shape for i, proj in projections.items()], axis=0))
fused = np.zeros(shape)
weights = np.zeros(shape)

# fuse overviews
for i, proj in projections.items():
    r, c = offsets[i]
    h, w = proj.shape
    fused[r:r + h, c:c + w] += proj
    weights[r:r + h, c:c + w] += 1

# average where tiles overlap
fused[weights != 0] /= weights[weights != 0]

plt.figure(figsize=(20, 20))
plt.imshow(fused, cmap='hot')

# boxes around details
for n, corner, l in zip(names, corners, lens):
    if 'detail' not in n:
        continue

    rec = Rectangle((corner - min_) / pixel_size, *(np.array(l) / pixel_size),
                    color='limegreen', fill=None)
    plt.gca().add_artist(rec)

plt.axis('off')

# plt.savefig('/scratch/hoerl/auto_sir_dna_comp/fig.png', dpi=np.max(shape)/20)

### Individual tiles

In [None]:

out_base_path = '/scratch/hoerl/20230507_imr90_stitching'

# one output folder per HDF5 file, mirroring the folder structure below basepath
outdir = Path(out_base_path) / os.path.relpath(h5file, basepath)
outdir = outdir.parent / outdir.stem
outdir.mkdir(parents=True, exist_ok=True)

ctr = count()
tile_config_lines = ['dim=3']
for img, psz, coord, name in zip(imgs, pszs, coords, names):

    # skip detail images and anything that is not a 3D stack
    # (same rule as save_ov_tiles; a 2D image cannot be written with axes 'ZYX')
    if 'detail' in name or img.ndim != 3:
        continue

    # save as imageJ-compatible stack
    outfile = outdir / f'tile_{zeropad(next(ctr))}.tif'
    imwrite(outfile, img.astype(np.uint16), imagej=True, metadata={'axes': 'ZYX'})

    # make tile config for Grid/Collection Stitching (position in pixels)
    tile_config_lines.append('{};;({},{},0.0)'.format(outfile.name, *(coord / psz)))

with (outdir / 'tile_config.txt').open('w') as tc_file:
    tc_file.write('\n'.join(tile_config_lines))
