In [None]:
import time
from functools import partial
from pathlib import Path
from zipfile import ZipFile

import dask.array as da
from dask.array.core import normalize_chunks
import numpy as np
from shapely.geometry import GeometryCollection, Point
from skimage.transform import AffineTransform

from fuse.fuse import fuse_func
from utils.download_sample import download_from_dropbox
from utils.metadata import extract_coordinates, normalize_coords_to_pixel
from utils.imutils import crop_black_border, load_image, transpose, remove_background
from utils.shapely_and_napari_utils import get_transformed_array_corners, numpy_shape_to_shapely
from utils.chunks import get_chunk_coordinates, get_rect_from_chunk_boundary, find_chunk_tile_intersections

In [3]:
import pandas as pd
from tqdm.auto import tqdm
import re 
import os
from natsort import natsorted
import enum
import napari
import warnings
warnings.filterwarnings("ignore")

# Loading own functions to sort files and extract metadata

In [4]:
@enum.unique
class Channels(enum.Enum):
    BRIGHTFIELD = 0
    GFP = 1
    RFP = 2
    IRFP = 3
    PHASE = 4
    WEIGHTS = 50
    MASK_IRFP = 96
    MASK_RFP = 97
    MASK_GFP = 98
    MASK = 99
    
def parse_filename(filename: os.PathLike) -> dict:
    """Parse an OctopusLite filename and retreive metadata from the file.

    Parameters
    ----------
    filename : PathLike
        The full path to a file to parse.

    Returns
    -------
    metadata : dict
        A dictionary containing the parsed metadata.
    """
    
    OCTOPUSLITE_FILEPATTERN =(
        "r(?P<row>[0-9]+)c(?P<column>[0-9]+)f(?P<fov>[0-9]+)p(?P<plane>[0-9]+)-ch(?P<channel>[0-9]+)"
        "sk(?P<time>[0-9]+)fk(?P<fk>[0-9]+)fl(?P<fl>[0-9]+)"
        )
    
    pth, filename = os.path.split(filename)
    params = re.match(OCTOPUSLITE_FILEPATTERN, filename)

    metadata = {
        "filename": filename,
        "channel": Channels(int(params.group("channel"))),
        "time": params.group("time"),
        "row": params.group("row"), 
        "column": params.group("column"), 
        "fov": params.group("fov"), 
        "plane": params.group("plane"), 
        "time": params.group("time"), 
        "fk": params.group("fk"), 
        "fl": params.group("fl")

    }

    return metadata

# Define parameters of mosaic compilation

In [None]:
### array that shows the location of each fov in the scan pattern
fov_scan_pattern = np.array(([2,3,4],
                             [7,6,5],
                             [8,1,9],))
fov_scan_pattern

In [None]:
overlap = 0.1
chunk_size = (108,108)

In [5]:
image_dir = Path("/mnt/DATA/sandbox/pierre_live_cell_data/outputs/Replication_IPSDM_GFP/Images/")

In [6]:
relevant_channels = set([parse_filename(fn)['channel'] for fn in os.listdir(image_dir)])

In [8]:
fns = os.listdir(image_dir)

In [9]:
fns[0]

'r03c06f08p02-ch1sk36fk1fl1.tiff'

In [10]:
parse_filename(fns[0])

{'filename': 'r03c06f08p02-ch1sk36fk1fl1.tiff',
 'channel': <Channels.GFP: 1>,
 'time': '36',
 'row': '03',
 'column': '06',
 'fov': '08',
 'plane': '02',
 'fk': '1',
 'fl': '1'}

In [None]:
%%time
#iterate over channels
for ch in tqdm(Channels, total = 2):#Channels:
    if ch in relevant_channels:
        print('Starting channel', ch.name)
        ### define empty z stack
        da_zt_stack = []
        ### iterate over time 
        for t in (range(1,76)):
            ### define empty z stack
            da_z_stack = []
            ### iterate over frames
            for p in range(1,4):
                ### get all files at that time point
                files = list(Path("/mnt/DATA/sandbox/pierre_live_cell_data/outputs/Replication_IPSDM_GFP/Images/").glob(f"r03c03f*p0{p}-ch{ch.value}sk{t}fk*"))
                ### sort using the frame
                files.sort(key=lambda f: int(parse_filename(f)["time"]))
                ### extract tile coordinates from fn 'f' (ie. fov number)
                coords = pd.DataFrame()
                for fn in files:
                    entry = pd.DataFrame([parse_filename(fn)])
                    coords = pd.concat([coords, entry], ignore_index=True)
                ## lazy hack to make the raster scan like the provided fov_scan_pattern
                for i, row in coords.iterrows():
                    (X, Y) = tuple(map(int, np.where(fov_scan_pattern == int(row['fov']))))
                    coords.at[i, 'X'], coords.at[i, 'Y'] = X, Y
                ### lazy hack to register the row/col number as x/y location shifted by the amount of pixels in each image
                coords['um/px'] = 1/(2160*(1-overlap))
                ### turn coords into a np array of transformation amount using the um/px ratio
                normalized_coords = normalize_coords_to_pixel(coords).to_numpy()
                ### apply transforms if required -- could this be a background removal???
                input_transforms = None#[remove_background] #None #[crop_black_border, ]#transpose]
                ### define a function to load a test image and get tile shape from it
                _load_image = partial(load_image, transforms=input_transforms)
                tile_shape=_load_image(str(files[0])).shape
                ### apply the transformation to each tile to correctly mosaic them
                transforms = [AffineTransform(translation=stage_coord).params for stage_coord in normalized_coords]
                tiles = [get_transformed_array_corners(tile_shape, transform) for transform in transforms]
                ### define the bounding boxes of the tiles and overall FOV to determine the dask output shape
                all_bboxes = np.vstack(tiles)
                all_min = all_bboxes.min(axis=0)
                all_max = all_bboxes.max(axis=0)
                stitched_shape=tuple(np.ceil(all_max-all_min).astype(int))
                ### if there is a discrepancy between the top left tile and the origin then shift
                shift_to_origin = AffineTransform(translation=-all_min)
                transforms_with_shift = [t @ shift_to_origin.params for t in transforms]
                shifted_tiles = [get_transformed_array_corners(tile_shape, t) for t in transforms_with_shift]
                tiles_shifted_shapely = [numpy_shape_to_shapely(s) for s in shifted_tiles]
                ### split data into pre-defined chunks             
                chunks = normalize_chunks(chunk_size,shape=tuple(stitched_shape))
                computed_shape = np.array(list(map(sum, chunks)))
                ### check that tiles shape is correct
                assert np.all(np.array(stitched_shape) == computed_shape)
                ## get chunk details and plot as shapes
                chunk_boundaries = list(get_chunk_coordinates(stitched_shape, chunk_size))
                chunk_shapes = list(map(get_rect_from_chunk_boundary, chunk_boundaries))
                chunks_shapely = [numpy_shape_to_shapely(c) for c in chunk_shapes]
                ### iterate over files in an individual frame and attach tile info and transform
                for tile_shifted_shapely, file, transform in zip(tiles_shifted_shapely, 
                                                                 files, 
                                                                 transforms_with_shift):
                    tile_shifted_shapely.fuse_info = {'file':file, 'transform':transform}
                ### iterate over chunks for a single image and attach info
                for chunk_shapely, chunk_boundary  in zip(chunks_shapely, chunk_boundaries):
                    chunk_shapely.fuse_info = {'chunk_boundary': chunk_boundary}
                ### find intersection of tiles and chunks
                chunk_tiles = find_chunk_tile_intersections(tiles_shifted_shapely, chunks_shapely)
                ### define a fuse function to load all tiles for particular chunk
                _fuse_func=partial(fuse_func, 
                                   imload_fn=_load_image,
                                   dtype=np.uint16) 
                ### use map_blocks to calculate the fused image chunk by chunk
                target_array = da.map_blocks(func=_fuse_func,
                                             chunks=chunks, 
                                             input_tile_info=chunk_tiles,
                                             dtype=np.uint16)
                ### append the mosaic for that particular frame to a list of mosaics
                da_z_stack.append(target_array)
            ### stack that mosaic in a time series
            da_z_stack = da.stack(da_z_stack, axis = 0)
            ### append the z series for one plane to the t-stack
            da_zt_stack.append(da_z_stack)
        ### stack the z planes together
        da_zt_stack = da.stack(da_zt_stack, axis = 0)
        ### save out as zarr
        da_zt_stack.to_zarr(f"data/tzxy_stack_ch{ch.value}_tile108.zarr", overwrite=True)
    else:
        print(f'Channel {ch.name} not found in image directory')

In [None]:
da_zt_stack

In [None]:
v = napari.Viewer()

In [None]:
da_stack_gfp = da.from_zarr(f"data/tzxy_stack_ch{1}.zarr")
da_stack_rfp = da.from_zarr(f"data/tzxy_stack_ch{2}.zarr")
v = napari.Viewer()
v.add_image(da_stack_gfp, name="gfp", contrast_limits = [0,2352], blending = 'additive', colormap= 'green')# colormap = 'g')
v.add_image(da_stack_rfp, name="rfp", contrast_limits = [103,164], blending = 'additive', colormap = 'magenta')

### Testing GPU capacity

In [None]:
da_stack_gfp = da.from_zarr(f"data/zt_stack_ch{1}.zarr")


In [None]:
da_stack_gfp_gpu = da_stack_gfp.map_blocks(cp.asarray)
da_stack_gfp_gpu

In [None]:
da_stack_gfp

#### Would then need a reason to run an image manipulation on this ``dask_stack_gfp_gpu`` 

# Conclusions from this

    1. 108 px squared is too small for this tiling exercise
    2. Needs to be on the scale of tens of tiles per image, not 100s. 
    3. TZYX is the correct OME order of axes
    4. TCZXY is the proper but I want separate stacks of images for each channel