In [4]:
import os
import h5py
import cv2
import imageio
import random
import numpy as np
import torch
import skimage
from libtiff import TIFF
from pathlib import Path
from torch import Tensor
from PIL import Image
from skimage import io, morphology, measure
%matplotlib inline
import matplotlib.pyplot as plt
from typing import Callable, Iterable, List, Set, Tuple
from skimage import segmentation
from skimage import transform
from skimage.color import rgba2rgb


def setup_seed(seed: int) -> None:
    '''
        Set random seed to make experiments repeatable
    '''
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True  # implement same config in cpu and gpu
    torch.backends.cudnn.benchmark = True


def count_mean_and_std(img_dir: Path, suffix='png') -> Tuple: # (mean, std, num) (0.495770570343616, 0.16637932545720774, 22) for ISBI 2012 online
    '''
        Calculate mean and std of data 
        :param image_path:  Address for images, noted that the value of images should be [0, 1]
        :return: mean, std and num of data
    '''
    assert img_dir.is_dir(), "The input is not a dir"
    mean, std, num = 0, 0, 0
    mean_rgb, std_rgb = np.array([0, 0, 0], dtype=np.float), np.array([0, 0, 0], dtype=np.float)
    is_rgb = False
    imgs_path = img_dir.glob("*."+suffix)
    
    for img_path in imgs_path:
        num += 1
        img = io.imread(str(img_path)) * 1.0 / 255
        assert np.max(np.unique(img)) <= 1, "The img value should lower than 1 when calculate mean and std"
        if len(img.shape) == 2:
            mean += np.mean(img)
            std += np.std(img)
        elif len(img.shape) == 3:
            is_rgb = True
            mean_rgb[0] += np.mean(img[:,:,0])
            mean_rgb[1] += np.mean(img[:,:,1])
            mean_rgb[2] += np.mean(img[:,:,2])
            std_rgb[0] += np.std(img[:,:,0])
            std_rgb[1] += np.std(img[:,:,1])
            std_rgb[2] += np.std(img[:,:,2])
    mean /= num
    std /= num
    mean_rgb = mean_rgb / num
    std_rgb = std_rgb / num
    if is_rgb:
        return mean_rgb, std_rgb, num
    else:
        return mean, std, num
    

def load_img(img_path: Path) -> np.ndarray or Image:
    '''
        Load images or npy files
        :param img_path:  Address for images or npy file
        :return: PIL image or numpy array
    '''
    if img_path.suffix == '.npy':
        img = np.load(img_path)
    else:
        img = io.imread(str(img_path))
        if np.amax(img) == 255 and len(np.unique(img)) == 2:
            img = img * 1.0 / 255
    return img


def prepare_data_iron(in_path: Path, out_path: Path, *args) -> None:
    '''
        Prepare iron data for training
        The data can be download from https://github.com/Keep-Passion/pure_iron_grain_data_sets.
        :param in_path:  Address for origin input folder
        :param out_path: Address for origin output folder
    '''
    
    file_h5 = h5py.File(str(Path(in_path, 'pure_iron_grain_data_sets.hdf5')), 'r')
    
    real_image = file_h5['real']['image']
    real_label = file_h5["real"]["label"]
    real_boundary = file_h5["real"]["boundary"]

    print(real_image.shape, ' ', real_label.shape, ' ', real_boundary.shape)

    for item in range(0, 296):
        name = str(item).zfill(3)
        image = real_image[:,:,item]
        label = real_boundary[:,:,item]
        label = morphology.dilation(label, morphology.square(3))

        #print(name, ' ', image.shape, ' ', label.shape, ' ', np.unique(label))

        cv2.imwrite(str(Path(out_path, 'images', name + '.png')), image)
        cv2.imwrite(str(Path(out_path, 'labels', name + '.png')), 255 - label)



def prepare_data_snemi3d(in_path: Path, out_path: Path, *args) -> None:
    '''
        Prepare snemi3d data for training
        The data can be download from https://zenodo.org/record/7142003.
        :param in_path:  Address for origin input folder
        :param out_path: Address for origin output folder
    '''
    imgs_inpath = Path(in_path, 'train-input.tif')
    masks_inpath = Path(in_path, 'train-labels.tif')
    imgs_tif = TIFF.open(imgs_inpath, mode='r')
    masks_tif = TIFF.open(masks_inpath, mode='r')
    
    imgs_outpath = Path(out_path, 'images')
    masks_outpath = Path(out_path, 'labels')
    idx = -1
    for img, label in zip( list(imgs_tif.iter_images()),  list(masks_tif.iter_images()) ):
        idx += 1
        mask = np.ones_like(img) * 255
        mask[segmentation.find_boundaries(label, mode='thick')] = 0
        mask[label == 0] = 0
        name = str(idx).zfill(3)
        cv2.imwrite(str(Path(out_path, 'images', name + '.png')), img)
        cv2.imwrite(str(Path(out_path, 'labels', name + '.png')), mask)
    

if __name__ == '__main__':
    dataset_name = "mass_road"
    is_online_challange = False
    cwd = os.getcwd()
    random.seed(2020)
    in_path  = Path(cwd, 'data', dataset_name, 'data_backup')
    out_path = Path(cwd, 'data', dataset_name, 'data_experiment')

    if dataset_name == "iron":
        prepare_data_iron(in_path, out_path)
    elif dataset_name == "snemi3d":
        prepare_data_snemi3d(in_path, out_path)
        
    img_dir = Path(out_path, "images")

    if dataset_name == 'mass_road':  # can be download from https://www.kaggle.com/datasets/balraj98/massachusetts-roads-dataset.
        img_dir = Path(out_path, 'train')
    print('img_dir is ', img_dir)
    
    # count
    # for Iron, 0.9410404628082503 0.12481161024777744 296
    # for snemi3d 0.5053152359607174 0.16954360899089577 100
    # for mass_road  [0.42946925 0.43247649 0.3961301 ] [0.22669363 0.21916084 0.22397161] 1108
    mean, std, num = count_mean_and_std(img_dir, suffix='tiff')
    print(mean, std, num)

img_dir is  f:\chuniliu\skeaw-experiments\skeaw_code\data\mass_road\data_experiment\train


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  mean_rgb, std_rgb = np.array([0, 0, 0], dtype=np.float), np.array([0, 0, 0], dtype=np.float)


[0.42946925 0.43247649 0.3961301 ] [0.22669363 0.21916084 0.22397161] 1108
