# This notebook can be used to generate data for UNet

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np

import pandas as pd
import matplotlib.pyplot as plt
import cv2
import os
from copy import deepcopy
from scipy.ndimage.interpolation import map_coordinates
from scipy.ndimage.filters import gaussian_filter
from scipy import ndimage
from typing import Tuple, Union, List, Any
from tqdm import tqdm
import sys
from pickle import dump, load
import logging
from time import time

from dask_jobqueue import PBSCluster
from dask.distributed import Client, get_worker

In [3]:
from biomedical_image_segmentation.elastic_deform import custom_2d_elastic_deform
from biomedical_image_segmentation.utils import insert_grid, split, create_dir, empty_dir
from biomedical_image_segmentation.data.generator import generate
from biomedical_image_segmentation.data.validator import validate

In [4]:
# ---- Filesystem layout ------------------------------------------------------
# PROJECT_PATH is derived by stripping "notebooks" from the current working
# directory (assumes the kernel is started from <project>/notebooks — TODO confirm).
PROJECT_PATH = os.getcwd().replace("notebooks","")
LOG_PATH = os.path.join(PROJECT_PATH, "logs")
DATA_PATH = os.path.join(PROJECT_PATH, "data")
SAMPLES_PATH = os.path.join(DATA_PATH, "samples")
TRAIN_SAMPLES_PATH = os.path.join(SAMPLES_PATH, "train")
MASKS_PATH = os.path.join(DATA_PATH, "masks")
TRAIN_MASKS_PATH = os.path.join(MASKS_PATH, "train")

# Output locations for the generated datasets.
AUGMENTED_DATA_PATH = os.path.join(DATA_PATH, "augmented")
VALID_DATA_PATH = os.path.join(DATA_PATH, "valid")
TEST_DATA_PATH = os.path.join(DATA_PATH, "test")

# ---- Logging ----------------------------------------------------------------
# One log file per run: the name carries the integer Unix timestamp of this cell's execution.
SCRIPT_NAME = "data-preparation"
LOG_FILE_NAME = os.path.join(LOG_PATH, f"{SCRIPT_NAME}_{int(time())}.log")

# ---- Split and augmentation parameters --------------------------------------
# Seed handed to split(). It must be a plain int: a trailing comma after the
# literal (`40,`) silently turns the value into the one-element tuple (40,).
RANDOM_STATE = 40
# The values below are forwarded unchanged to generate(); their exact semantics
# are defined in biomedical_image_segmentation.data.generator.
NUM_ELASTIC_DEFORMS = 20
ALPHA_AFFINE = (.01, .2)
SIGMA = 10.
ALPHA = 1.
ADJUSTMENT_PIXEL_RANGE = (5, 100)
ADJUSTED_PIXEL = 0
# Passed to split() as `ratio`; the call site unpacks the result as train/valid/test ids.
SPLIT_RATIO = (2/3, 1/6, 1/6)

In [5]:
# Prepare the log directory through the project helper before configuring logging.
create_dir(LOG_PATH, True)

# If a file with this run's log name is already present in LOG_PATH, delete it
# so the run starts with a fresh log.
log_file_basename = os.path.basename(LOG_FILE_NAME)
existing_log_names = os.listdir(LOG_PATH)
if log_file_basename in existing_log_names:
    os.remove(LOG_FILE_NAME)

# Route all logging.* calls of this notebook into the run's log file.
logging_options = {
    "filename": LOG_FILE_NAME,
    "format": '%(asctime)s %(message)s',
    "datefmt": '%m/%d/%Y %I:%M:%S %p',
    "encoding": 'utf-8',
    "level": logging.INFO,
}
logging.basicConfig(**logging_options)

In [6]:
# Create one output directory per dataset; each creation is announced in the log first.
output_dirs_by_label = (
    ("train", AUGMENTED_DATA_PATH),
    ("valid", VALID_DATA_PATH),
    ("test", TEST_DATA_PATH),
)
for dataset_label, output_dir in output_dirs_by_label:
    logging.info(f"Creating directory to store {dataset_label} dataset")
    create_dir(output_dir)

In [7]:
# Clear the three output directories, in the same order they were created,
# logging each removal before it happens.
for output_dir in (AUGMENTED_DATA_PATH, VALID_DATA_PATH, TEST_DATA_PATH):
    logging.info(f"Removing content of {output_dir}")
    empty_dir(output_dir, True)

## Prepare the validation and test datasets before the training set

In [8]:
# os.listdir() returns entries in arbitrary, filesystem-dependent order.
# Sorting first makes the seeded split reproducible across machines and runs.
sample_file_names = sorted(os.listdir(TRAIN_SAMPLES_PATH))

# Partition the sample file names into train / valid / test ids.
train_ids, valid_ids, test_ids = split(
    sample_file_names,
    ratio=SPLIT_RATIO,
    seed=RANDOM_STATE)

# Record the exact assignment so a run can be audited from its log file.
logging.info(f"train ids: {train_ids}")
logging.info(f"valid ids: {valid_ids}")
logging.info(f"test ids: {test_ids}")

In [9]:
def _load_image_mask_pairs(sample_ids):
    """
    Load (image, mask) pairs for the given sample file names.

    Each image is read from TRAIN_SAMPLES_PATH and its mask from
    TRAIN_MASKS_PATH under the same file name; only channel 0 of each
    array is kept.

    Raises
    ------
    FileNotFoundError
        If an image or mask cannot be read.
    """
    pairs = []
    for sample_id in sample_ids:
        img_path = os.path.join(TRAIN_SAMPLES_PATH, sample_id)
        mask_path = os.path.join(TRAIN_MASKS_PATH, sample_id)
        img = cv2.imread(img_path)
        mask = cv2.imread(mask_path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than with an opaque
        # "'NoneType' object is not subscriptable" on the next line.
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")
        pairs.append((img[:,:,0], mask[:,:,0]))
    return pairs


def _save_pickle(obj, directory, file_name):
    """Pickle `obj` into `directory/file_name`, overwriting any existing file."""
    with open(os.path.join(directory, file_name), "wb") as f:
        dump(obj, f)


logging.info("Preparing valid dataset")
dataset = _load_image_mask_pairs(valid_ids)

logging.info("Saving valid dataset")
_save_pickle(dataset, VALID_DATA_PATH, "valid.pickle")


logging.info("Preparing test dataset")
dataset = _load_image_mask_pairs(test_ids)

logging.info("Saving test dataset")
_save_pickle(dataset, TEST_DATA_PATH, "test.pickle")

## Sequential

In [10]:
# validation = True
# threshold_validation = -1.

# for i, train_id in tqdm(enumerate(train_ids), leave=False):
    
#     img =  cv2.imread(os.path.join(TRAIN_SAMPLES_PATH, train_id))
#     mask = cv2.imread(os.path.join(TRAIN_MASKS_PATH, train_id))
    
#     dataset = generate(
#         img[:,:,0],
#         mask[:,:,0],
#         num_elastic_deforms=NUM_ELASTIC_DEFORMS,
#         alpha_affine=ALPHA_AFFINE,
#         sigma=SIGMA,
#         alpha=ALPHA,
#         adjustment_pixel_range=ADJUSTMENT_PIXEL_RANGE, 
#         adjusted_pixel=ADJUSTED_PIXEL)
    
#     if (validation and np.random.uniform() > threshold_validation): validate(dataset)
    
        
#     with open(f"{os.path.join(AUGMENTED_DATA_PATH, train_id.replace('.tif',''))}.pickle", "wb") as f:
#         dump(dataset, f)
    
#     del dataset
    
# logging.info("Saved train dataset")

## Distributed: using HPC

In [11]:
def distributed_data_generation(
    train_id, 
    img_load_path,
    mask_load_path,
    save_path,
    num_elastic_deforms,
    alpha_affine,
    sigma,
    alpha,
    adjustment_pixel_range, 
    adjusted_pixel) -> bool:
    
    """
    Augment one training sample on a Dask worker and pickle the result.

    The image and mask named `train_id` are read, channel 0 of each is passed
    to generate(), the generated dataset is checked with validate(), and the
    dataset is pickled into `save_path`. Any failure along the way is reported
    as a worker "error" event and turned into a False return value, so a
    single bad sample does not surface as a raised task error.

    Parameters
    ----------
    train_id:
        File name of the sample; joined onto both load paths.
    img_load_path:
        Directory containing the image file.
    mask_load_path:
        Directory containing the mask file (same file name as the image).
    save_path:
        Directory receiving `<train_id without '.tif'>.pickle`.
    num_elastic_deforms, alpha_affine, sigma, alpha,
    adjustment_pixel_range, adjusted_pixel:
        Forwarded unchanged to generate().

    Returns
    -------
    status: boolean
        Status of job, True if success otherwise False

    Notes
    -----
    Must run inside a Dask worker: get_worker() is used to log events.
    """
    
    start = time()
    
    status = False
    
    try:
        img_path = os.path.join(img_load_path, train_id)
        mask_path = os.path.join(mask_load_path, train_id)
        img = cv2.imread(img_path)
        mask = cv2.imread(mask_path)

        # cv2.imread returns None on failure instead of raising; convert that
        # into an explicit error carrying the offending path.
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        dataset = generate(
            img[:,:,0],
            mask[:,:,0],
            num_elastic_deforms=num_elastic_deforms,
            alpha_affine=alpha_affine,
            sigma=sigma,
            alpha=alpha,
            adjustment_pixel_range=adjustment_pixel_range, 
            adjusted_pixel=adjusted_pixel)

        validate(dataset)
        
        with open(f"{os.path.join(save_path, train_id.replace('.tif',''))}.pickle", "wb") as f:
            dump(dataset, f)
        status = True
    except Exception as e:
        # Log the error as a string (not the exception object) so the event
        # payload is plain data, and include the sample id so the failing
        # input can be identified from the event log.
        get_worker().log_event(
            "error",
            {"train_id": train_id, "error": repr(e), "status": status})
        
    get_worker().log_event("runtimes", {"time elapsed": f"{time()-start: .5f} seconds"})
    
    return status

In [12]:
# Resources requested from the PBS scheduler.
# NOTE(review): per the dask-jobqueue documentation, `cores` and `memory` are
# totals per submitted job, and `n_workers` is the number of Dask workers to
# start initially; `processes` is not set here, so how a job is split into
# workers follows the library default — confirm against the installed version.
N_WORKERS = 4 # number of Dask workers to start initially
CORES = 16 # total cores per job
MEMORY = "16GB" # total memory per job

logging.info("Initializing cluster")
cluster = PBSCluster(
    n_workers=N_WORKERS,
    cores=CORES,
    memory=MEMORY)
# Keep the dashboard URL in the log file so it can be found after the session.
logging.info(f"Dashboard link: {cluster.dashboard_link}")
# Bare expression: shows the cluster object as the cell output.
cluster

In [13]:
# Connect a Dask client to the cluster; the cells below submit tasks through it.
client = Client(cluster)

In [14]:
%%time
logging.info("Warming up cluster")


def test(a, b):
    """Trivial addition task, used only to exercise the cluster."""
    return a + b


# Submit ten tiny tasks and block until all of them have returned.
results_future = {
    i: client.submit(test, a=10+i, b=10+i)
    for i in range(10)
}
results = client.gather(results_future)

CPU times: user 446 ms, sys: 98.6 ms, total: 544 ms
Wall time: 26.4 s


In [15]:
# Arguments shared by every task; only train_id differs between submissions.
generation_settings = {
    "img_load_path": TRAIN_SAMPLES_PATH,
    "mask_load_path": TRAIN_MASKS_PATH,
    "save_path": AUGMENTED_DATA_PATH,
    "num_elastic_deforms": NUM_ELASTIC_DEFORMS,
    "alpha_affine": ALPHA_AFFINE,
    "sigma": SIGMA,
    "alpha": ALPHA,
    "adjustment_pixel_range": ADJUSTMENT_PIXEL_RANGE,
    "adjusted_pixel": ADJUSTED_PIXEL,
}

# One future per training sample, keyed by the sample's file name.
results_future = {}
for train_id in tqdm(train_ids, leave=False):
    results_future[train_id] = client.submit(
        distributed_data_generation,
        train_id=train_id,
        **generation_settings)

                                      

In [16]:
# Tally the scheduler-side state of every submitted future.
# This counts future states only; the boolean returned by each task is not inspected here.
future_states = [future.status for future in results_future.values()]
pd.DataFrame(future_states).value_counts()

finished    20
dtype: int64