In [None]:
import h5py
import os
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from imageio import imread
from skimage.color import rgb2hed, hed2rgb, rgb2gray
from skimage.util import img_as_ubyte
from random import Random, shuffle

## Create HDF5

In [None]:
# Root data directory. Paths in this notebook are built by plain string
# concatenation (DATA_DIR + ...), so the trailing slash is required.
DATA_DIR = 'data/'
# Name of the HDF5 file that is written to, and later re-read from, DATA_DIR.
FILENAME = 'data_tf.hdf5'

In [None]:
# Directory of training images and the names of all entries found in it.
TRAIN_DIR = DATA_DIR + 'train/'
TRAIN_FILES = os.listdir(TRAIN_DIR)
len(TRAIN_FILES)  # cell output: number of entries in the train directory

In [None]:
# Directory of test images and the names of all entries found in it.
TEST_DIR = DATA_DIR + 'test/'
TEST_FILES = os.listdir(TEST_DIR)
len(TEST_FILES)  # cell output: number of entries in the test directory

In [None]:
# Shape of one stored sample. rgb2ghed() asserts that its input has shape
# IMG_SIZE[:2] + (3,), and every image dataset is created as (m,) + IMG_SIZE.
IMG_SIZE = (96, 96, 4)
# Seed handed to random.Random for each shuffle, so orderings are repeatable.
RAND_SEED = 333

In [None]:
# Read the label table and split the image ids by class:
# TRAIN_IDS[0] holds the ids labelled 0, TRAIN_IDS[1] the ids labelled 1.
df = pd.read_csv(DATA_DIR + 'train_labels.csv')
TRAIN_IDS = [df.loc[df['label'] == lbl, 'id'].values.tolist() for lbl in (0, 1)]

In [None]:
# Test image ids are the file names without their extension. os.listdir()
# returns every directory entry, so keep only '.tif' files: any other stray
# entry would become a bogus id and fail later when it is loaded as
# TEST_DIR + id + '.tif'.
TEST_IDS = [
    stem
    for stem, ext in (os.path.splitext(p) for p in TEST_FILES)
    if ext == '.tif'
]

In [None]:
# The HED color deconv, rgb2hed(), produces values outside of 0 and 1 for some images.
# Workaround - use hed2rgb() with each HED channel separately and then run rgb2gray()
# on that conversion result
def rgb2ghed(img_rgb: np.ndarray) -> np.ndarray:
    """Convert an RGB image into a 4-plane (G, H, E, D) image.

    Plane 0 is the OpenCV grayscale of the input. Planes 1-3 are obtained by
    taking one channel of ``rgb2hed(img_rgb)`` at a time, zeroing the other
    two, converting back with ``hed2rgb``, reducing to gray with ``rgb2gray``
    and converting to 8-bit with ``img_as_ubyte``.

    Args:
        img_rgb: uint8 array of shape ``IMG_SIZE[:2] + (3,)``.

    Returns:
        The four planes stacked along the last axis, in G, H, E, D order.

    Raises:
        AssertionError: if the dtype or the shape of ``img_rgb`` is not as
            described above.
    """
    assert img_rgb.dtype == 'uint8'
    assert img_rgb.shape == IMG_SIZE[:2] + (3,)

    hed = rgb2hed(img_rgb)
    blank = np.zeros(img_rgb.shape[:2])

    # Grayscale first, then one plane per HED channel.
    planes = [cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)]
    for ch in range(3):
        # Keep only channel `ch`; the other two are replaced by zeros.
        single = [blank, blank, blank]
        single[ch] = hed[:, :, ch]
        rgb_single = hed2rgb(np.stack(single, axis=-1))
        planes.append(img_as_ubyte(rgb2gray(rgb_single)))

    return np.stack(planes, axis=-1)

In [None]:
def create_train_dset_2(f: h5py.File, name: str, m: int):
    """Write a shuffled, labelled training set covering both classes.

    Creates two datasets in ``f``: ``'x_' + name`` with shape
    ``(m,) + IMG_SIZE`` and ``'y_' + name`` with shape ``(m,)``, both uint8.
    All ids from ``TRAIN_IDS[0]`` and ``TRAIN_IDS[1]`` are paired with their
    label, shuffled with ``Random(RAND_SEED)``, and the first ``m`` pairs are
    loaded from ``TRAIN_DIR``, converted with ``rgb2ghed`` and stored.

    Args:
        f: open, writable HDF5 file.
        name: suffix used for the two dataset names.
        m: number of samples to store.

    Returns:
        Tuple ``(dset_x, dset_y)`` of the two created datasets.
    """
    dset_x = f.create_dataset('x_' + name, (m,) + IMG_SIZE, np.uint8)
    dset_y = f.create_dataset('y_' + name, (m,), np.uint8)

    # Class-0 pairs first, then class-1 pairs, before shuffling.
    labelled = [(img_id, 0) for img_id in TRAIN_IDS[0]]
    labelled += [(img_id, 1) for img_id in TRAIN_IDS[1]]
    Random(RAND_SEED).shuffle(labelled)

    n_channels = IMG_SIZE[2]
    for row, (img_id, lbl) in enumerate(labelled[:m]):
        ghed = rgb2ghed(imread(TRAIN_DIR + img_id + '.tif'))
        dset_x[row, ...] = ghed[:, :, :n_channels]
        dset_y[row] = lbl
    return dset_x, dset_y

In [None]:
def create_train_dset_01(f: h5py.File, name: str, label: int, m: int):
    """Write a shuffled training set containing images of a single class.

    Creates the uint8 dataset ``'x_' + name`` with shape ``(m,) + IMG_SIZE``
    in ``f``. The ids in ``TRAIN_IDS[label]`` are shuffled with
    ``Random(RAND_SEED)`` and the first ``m`` of them are loaded from
    ``TRAIN_DIR``, converted with ``rgb2ghed`` and stored.

    Args:
        f: open, writable HDF5 file.
        name: suffix used for the dataset name.
        label: index into ``TRAIN_IDS`` selecting the class (0 or 1).
        m: number of samples to store.

    Returns:
        The created dataset.
    """
    dset = f.create_dataset('x_' + name, (m,) + IMG_SIZE, np.uint8)
    # Shuffle a copy: random.shuffle works in place, and shuffling
    # TRAIN_IDS[label] itself would reorder the module-level list, so a second
    # call (or any later use of TRAIN_IDS) would see a different order.
    ids = list(TRAIN_IDS[label])
    Random(RAND_SEED).shuffle(ids)
    for i, img_id in enumerate(ids[:m]):
        img_rgb = imread(TRAIN_DIR + img_id + '.tif')
        img_ghed = rgb2ghed(img_rgb)
        dset[i, ...] = img_ghed[:, :, :IMG_SIZE[2]]
    return dset

In [None]:
def create_test_dsets(f: h5py.File, name: str, m: int):
    """Write a shuffled set of test images.

    Creates the uint8 dataset ``'x_' + name`` with shape ``(m,) + IMG_SIZE``
    in ``f``. A copy of ``TEST_IDS`` is shuffled with ``Random(RAND_SEED)``
    and the first ``m`` ids are loaded from ``TEST_DIR``, converted with
    ``rgb2ghed`` and stored.

    Args:
        f: open, writable HDF5 file.
        name: suffix used for the dataset name.
        m: number of samples to store.

    Returns:
        The created dataset.
    """
    dset = f.create_dataset('x_' + name, (m,) + IMG_SIZE, np.uint8)

    shuffled = list(TEST_IDS)  # copy, so TEST_IDS keeps its order
    Random(RAND_SEED).shuffle(shuffled)

    n_channels = IMG_SIZE[2]
    for row, img_id in enumerate(shuffled[:m]):
        ghed = rgb2ghed(imread(TEST_DIR + img_id + '.tif'))
        dset[row, ...] = ghed[:, :, :n_channels]
    return dset

In [None]:
# Open the output file in 'w' mode (create, truncating any existing file).
# The handle stays open across the following cells until f.close() is run.
f = h5py.File(DATA_DIR + FILENAME, 'w')

In [None]:
# Build the combined training set from every id of both classes.
# The commented-out call is the same step limited to 10 samples.
# dset_x, dset_y = create_train_dset_2(f, 'train', 10)
dset_x, dset_y = create_train_dset_2(f, 'train', len(TRAIN_IDS[0]) + len(TRAIN_IDS[1]))
dset_x.shape, dset_y.shape

In [None]:
# Build the dataset holding only label-0 images (all of them).
# The commented-out call is the same step limited to 10 samples.
# dset = create_train_dset_01(f, 'train0', 0, 10)
dset = create_train_dset_01(f, 'train0', 0, len(TRAIN_IDS[0]))
dset.shape

In [None]:
# Build the dataset holding only label-1 images (all of them).
# The commented-out call is the same step limited to 10 samples.
# dset = create_train_dset_01(f, 'train1', 1, 10)
dset = create_train_dset_01(f, 'train1', 1, len(TRAIN_IDS[1]))
dset.shape

In [None]:
# Close the file opened for writing.
f.close()

In [None]:
# Always raises AssertionError, so a "Run All" stops at this cell.
# NOTE(review): presumably a deliberate barrier between the creation cells
# above and the inspection cells below - confirm before removing.
assert False

## Check / Inspect HDF5

In [None]:
# Re-open the same file read-only for inspection.
f = h5py.File(DATA_DIR + FILENAME, 'r')

In [None]:
# Names of the top-level objects stored in the file.
f.keys()

In [None]:
# Shapes of the three image datasets.
f['x_train'].shape, f['x_train0'].shape, f['x_train1'].shape

In [None]:
# Dtypes of the three image datasets (each was created as np.uint8).
f['x_train'].dtype, f['x_train0'].dtype, f['x_train1'].dtype

In [None]:
# Shape and dtype of the label dataset.
f['y_train'].shape, f['y_train'].dtype

In [None]:
# First ten stored labels.
f['y_train'][:10]

In [None]:
def show_ghed(img):
    """Show the four planes of a G/H/E/D image in a 2x2 grayscale grid.

    Plane ``k`` of ``img`` (``img[:, :, k]``) is drawn in row-major grid
    position ``k`` under the titles 'G', 'H', 'E' and 'D'.

    Args:
        img: array indexable as ``img[:, :, k]`` for ``k`` in 0..3.
    """
    fig, axes = plt.subplots(2, 2, figsize=(6,6))
    # axes.flat walks the grid row by row: top-left, top-right,
    # bottom-left, bottom-right.
    for plane, (ax, title) in enumerate(zip(axes.flat, ('G', 'H', 'E', 'D'))):
        ax.set_title(title)
        ax.imshow(img[:, :, plane], cmap='gray')
    fig.tight_layout()

In [None]:
# Display the first sample of the combined training set, plane by plane.
dset = f['x_train']
img = dset[0]
show_ghed(img)
img.shape, img.dtype

In [None]:
# Display the first sample of the label-0 dataset, plane by plane.
dset0 = f['x_train0']
img = dset0[0]
show_ghed(img)
img.shape, img.dtype

In [None]:
# Display the first sample of the label-1 dataset, plane by plane.
dset1 = f['x_train1']
img = dset1[0]
show_ghed(img)
img.shape, img.dtype

In [None]:
# assert False

In [None]:
# Close the file opened for inspection.
f.close()