In [None]:
import numpy as np
import h5py
import cv2
import os

import matplotlib.pyplot as plt

In [None]:
from scipy.ndimage import label


def repair_masks(masks, keep_threshold=10):
    """Erase small stray blobs from every layer of a mask stack, in place.

    Each layer ``masks[:, :, k]`` is split into connected components with
    ``scipy.ndimage.label`` (default connectivity). When a layer holds more
    than one component, every component covering fewer than
    ``keep_threshold`` pixels is set to 0. A layer with a single component is
    left untouched, whatever its size.

    Parameters
    ----------
    masks : numpy.ndarray
        Array indexed as (row, column, layer). It is modified in place.
    keep_threshold : int, optional
        Components with strictly fewer pixels than this are removed.

    Returns
    -------
    numpy.ndarray
        The same ``masks`` object that was passed in.
    """
    for layer in range(masks.shape[2]):
        img = masks[:, :, layer]
        labeled, n = label(img)
        if n > 1:
            # One pass over the label image gives every component's area;
            # entry 0 is the background, so it is dropped.
            areas = np.bincount(labeled.ravel(), minlength=n + 1)[1:]
            too_small = np.flatnonzero(areas < keep_threshold) + 1
            img[np.isin(labeled, too_small)] = 0
        masks[:, :, layer] = img
    return masks

In [None]:
def generate_tiles(img, size=512, overlap=100, shifts=1):
    """Cut an array into square tiles of side ``size``.

    Tile origins advance by ``size // shifts - overlap`` pixels along both of
    the first two axes. A tile that would run past the bottom or right edge
    is moved back so that it ends exactly on that edge; every tile is
    therefore ``size`` x ``size`` and edge tiles may repeat content of their
    neighbours.

    Parameters
    ----------
    img : numpy.ndarray
        Array whose first two axes are rows and columns; any further axes are
        carried along unchanged.
    size : int, optional
        Side length of each tile.
    overlap : int, optional
        Amount subtracted from ``size // shifts`` to obtain the stride.
    shifts : int, optional
        Divisor applied to ``size`` before ``overlap`` is subtracted.

    Returns
    -------
    numpy.ndarray
        Tiles stacked along a new axis 2, i.e. ``(size, size, n_tiles)``
        followed by any trailing axes of ``img``. Tiles are ordered with the
        row origin varying fastest, the same order ``stitch_tiles`` iterates.

    Raises
    ------
    ValueError
        If the stride is not positive, or if ``img`` is smaller than ``size``
        along either of its first two axes (full-size tiles cannot be cut).
    """
    height, width = img.shape[:2]
    stride = size // shifts - overlap
    if stride <= 0:
        raise ValueError(
            'size // shifts - overlap must be positive, got {}'.format(stride))
    if size > height or size > width:
        raise ValueError(
            'cannot cut {0}x{0} tiles from an input of {1}x{2}'.format(size, height, width))

    # Clamp origins so tiles that would overhang end on the edge instead.
    tops = np.minimum(np.arange(0, height, stride), height - size)
    lefts = np.minimum(np.arange(0, width, stride), width - size)

    tiles = [img[top:top + size, left:left + size] for left in lefts for top in tops]
    return np.stack(tiles, axis=2)

def stitch_tiles(tiles, target_shape, size=512, overlap=100, shifts=1, flatten=False, dtype=np.uint8):
    """Place tiles back at their origins on a canvas of the target size.

    Origins are computed exactly as in ``generate_tiles`` (stride
    ``size // shifts - overlap``, origins clamped so tiles end on the bottom
    and right edges). Tile ``k`` is written into layer ``k`` of the canvas,
    so overlapping tiles never overwrite each other.

    Parameters
    ----------
    tiles : numpy.ndarray
        Array indexed as (row, column, tile, ...), with ``size`` rows and
        columns per tile.
    target_shape : sequence of int
        Shape of the array the tiles were cut from; only the first two
        entries (height, width) are used.
    size, overlap, shifts : int, optional
        Must match the values used to cut the tiles.
    flatten : bool, optional
        If true, sum over the tile axis and return a boolean array that is
        true wherever that sum is positive.
    dtype : numpy dtype, optional
        Data type of the canvas. Defaults to ``numpy.uint8``; tiles of
        another type are cast to it on assignment, so pass the tiles' own
        dtype to avoid truncation.

    Returns
    -------
    numpy.ndarray
        Canvas of shape ``(height, width, *tiles.shape[2:])``, or its boolean
        reduction over axis 2 when ``flatten`` is true.

    Raises
    ------
    ValueError
        If the stride is not positive, or if the target is smaller than
        ``size`` along either axis.
    """
    height, width = target_shape[:2]
    stride = size // shifts - overlap
    if stride <= 0:
        raise ValueError(
            'size // shifts - overlap must be positive, got {}'.format(stride))
    if size > height or size > width:
        raise ValueError(
            'cannot place {0}x{0} tiles on a target of {1}x{2}'.format(size, height, width))

    # Same clamped origins, in the same order, as generate_tiles produces.
    tops = np.minimum(np.arange(0, height, stride), height - size)
    lefts = np.minimum(np.arange(0, width, stride), width - size)
    origins = [(top, left) for left in lefts for top in tops]

    img = np.zeros((height, width, *tiles.shape[2:]), dtype=dtype)
    for idx, (top, left) in enumerate(origins):
        img[top:top + size, left:left + size, idx] = tiles[:, :, idx]
    if flatten:
        img = img.sum(axis=2) > 0
    return img

In [None]:
def get_annotations(img, masks, class_names=None, size=512, overlap=100, shifts=1, min_area=900):
    """Tile an image and its mask stack and keep the tiles worth annotating.

    Image and masks are cut with ``generate_tiles`` using the same
    ``size``, ``overlap`` and ``shifts``. A tile is kept when

    * at least one mask layer is non-zero inside it,
    * every non-zero layer inside it sums to more than ``min_area``, and
    * no previously kept tile has the same total image pixel sum.

    The last rule is how repeated tiles are dropped; it compares sums only,
    so two different tiles with identical totals are treated as duplicates.

    Parameters
    ----------
    img : numpy.ndarray
        Image indexed as (row, column[, channel]).
    masks : numpy.ndarray
        Mask stack indexed as (row, column, layer), same height and width as
        ``img``.
    class_names : array-like or None, optional
        One entry per mask layer. When ``None``, the class-name slot of every
        returned annotation is ``None``.
    size, overlap, shifts : int, optional
        Passed through to ``generate_tiles``.
    min_area : number, optional
        A non-empty layer must sum to strictly more than this inside a tile
        for the tile to be kept.

    Returns
    -------
    list of tuple
        ``(img_tile, layers, layer_class_names)`` per kept tile, where
        ``layers`` holds only the layers that are non-zero in that tile and
        ``layer_class_names`` the matching entries of ``class_names``.
    """
    img_tiles = generate_tiles(img, size, overlap, shifts)
    mask_tiles = generate_tiles(masks, size, overlap, shifts)

    # (n_tiles, n_layers): summed mask values of each layer inside each tile.
    layer_sums = mask_tiles.sum(axis=(0, 1))
    has_mask = (layer_sums > 0).any(axis=1)

    img_tiles = img_tiles[:, :, has_mask]
    mask_tiles = mask_tiles[:, :, has_mask]
    layer_sums = layer_sums[has_mask]

    if class_names is not None:
        # Boolean indexing below needs an array, not a plain list.
        class_names = np.asarray(class_names)

    annotations = []
    seen_pixel_sums = set()
    for idx in range(img_tiles.shape[2]):
        present = layer_sums[idx] > 0
        if not np.all(layer_sums[idx][present] > min_area):
            continue
        img_tile = img_tiles[:, :, idx]
        pixel_sum = img_tile.sum().item()
        if pixel_sum in seen_pixel_sums:
            continue
        seen_pixel_sums.add(pixel_sum)
        layers = mask_tiles[:, :, idx][:, :, present]
        layer_class_names = None if class_names is None else class_names[present]
        annotations.append((img_tile, layers, layer_class_names))

    return annotations

def plot_annotations(annotations):
    """Display every annotation as an image / summed-mask pair.

    For each ``(image tile, mask layers, class names)`` triple a figure with
    two side-by-side axes is created: the tile on the left, the sum of the
    mask layers over axis 2 on the right. The class names are printed before
    the figure is shown.
    """
    for tile, layers, names in annotations:
        _, (image_ax, mask_ax) = plt.subplots(1, 2, figsize=(5, 5))
        image_ax.imshow(tile)
        mask_ax.imshow(layers.sum(axis=2))
        print(*names)
        plt.show()

In [None]:
# Cut every annotated image of 'multi_brevis.h5' into 1024-pixel tiles and
# write the kept tiles to 'multi_brevis_512.h5', one group per tile, plus an
# 'annotations' dataset listing the path of every group written.
#
# Source layout read here ('version huy' in the original notes): each key
# listed in 'image_tiles' names a group holding 'image', 'mask' and 'classes'.
# An older layout (images read from disk with cv2, one dataset per mask inside
# each group, the single class 'multifasciatus' for every layer) is no longer
# handled by this cell.
directory = '/home/jordanlab/Documents/'
with h5py.File(os.path.join(directory, 'multi_brevis_512.h5'), 'a') as new_annotation:
    # NOTE(review): mode 'a' keeps an existing output file, and creating
    # 'annotations' a second time looks like it would fail -- confirm whether
    # re-runs are expected to start from a fresh file.
    index = new_annotation.create_dataset('annotations', shape=(0, ), maxshape=(None,), dtype='|S400')
    with h5py.File(os.path.join(directory, 'multi_brevis.h5'), 'r') as annotation:
        # `np.str` was removed in NumPy 1.24; it was an alias of the builtin `str`.
        for image in annotation['image_tiles'][:].astype(str):

            img = annotation[image]['image'][:]
            masks = repair_masks(annotation[image]['mask'][:])
            class_names = annotation[image]['classes'][:].astype(str)

            # Backslashes in the source key become '/' so they nest as HDF5 groups.
            image = image.replace('\\', '/')

            annotations = get_annotations(img, masks, class_names=class_names, size=1024, overlap=100, shifts=3)
            for idx, (img_tile, tile_layers, tile_class_names) in enumerate(annotations):
                path = os.path.join(image, str(idx))
                ann = new_annotation.create_group(path)
                ann.create_dataset('image', data=img_tile)
                ann.create_dataset('mask', data=tile_layers, compression="gzip", compression_opts=9)
                ann.create_dataset('class_names', data=tile_class_names.astype(np.bytes_))
                # Grow the index by one slot and record this tile's group path in it.
                n_indexed = index.shape[0]
                index.resize((n_indexed + 1,))
                index[n_indexed] = np.bytes_(path)
            print(image)

In [None]:
# Export the full image of every entry in 'multi_brevis.h5' as an image file.
# An earlier variant of this cell copied image, mask and class names into a
# new HDF5 file instead; only the image-file export remains.
directory = '/home/jordanlab/Documents/'
with h5py.File(os.path.join(directory, 'multi_brevis.h5'), 'r') as annotation:
    # `np.str` was removed in NumPy 1.24; it was an alias of the builtin `str`.
    for image in annotation['image_tiles'][:].astype(str):

        img = annotation[image]['image'][:]
        # Not needed for the export itself; `masks` from the last iteration is
        # what the tiling round-trip cell further down operates on.
        masks = repair_masks(annotation[image]['mask'][:])
        class_names = annotation[image]['classes'][:].astype(str)

        # Flatten the HDF5 key into a single file name.
        # NOTE(review): the trailing [:-2] drops the last two characters of
        # the key -- confirm what suffix that is meant to remove.
        path = image.replace('\\', '/')
        path = path.replace('/', '-').replace(':', '').replace(' ', '_')[:-2]
        out_path = os.path.join('/media/jordanlab/S12/huy', path)

        # NOTE(review): cv2.imwrite stores channels as BGR -- confirm the
        # channel order of the stored images.
        if cv2.imwrite(out_path, img):
            print(out_path)
        else:
            # imwrite signals failure through its return value; report it
            # instead of printing the path as if the file existed.
            print('FAILED to write', out_path)

In [None]:
# Round trip with the default tiling parameters: cut the mask stack into tiles
# and place them back. `masks` is whatever the last iteration of the most
# recently run loop cell above left behind.
mask_tiles = generate_tiles(img=masks)
stitched_masks = stitch_tiles(tiles=mask_tiles, target_shape=masks.shape)

In [None]:
# Collapse the tile and layer axes into one boolean footprint. The reduction
# only runs while those axes are still present, so re-running this cell is a
# no-op instead of an axis error on the already-collapsed array.
if stitched_masks.ndim > 2:
    stitched_masks = stitched_masks.sum(axis=(2, 3)) > 0
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(stitched_masks)