In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import os
import random
import shutil
import glob

import matplotlib.pyplot as plt
import scipy.io
import numpy as np

from tqdm import tqdm

import skimage.io
import skimage.segmentation
import skimage.morphology

import sys
# Pretend this code lives in a file called 'segm.ipynb' (presumably because the
# notebook kernel does not provide __file__); abspath() below resolves the
# name against the current working directory.
__file__ = 'segm.ipynb'
# Append the parent of the notebook's directory to the module search path
# before the project imports that follow.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import utils.dirtools  # the utils package must contain an __init__.py
import utils.augmentation
import utils.model_builder
import utils.data_provider
import utils.metrics
import utils.objectives
import utils.evaluation

from config import config_vars

In [None]:
# Global settings for this run: the experiment name and the data root.
experiment_name = 'cellSegm'
config_vars["root_directory"] = 'testExp/'

# Pass config_vars through the two project setup helpers (working directories
# first, then the experiment); config_vars is rebound to the final result.
config_vars = utils.dirtools.setup_experiment(
    utils.dirtools.setup_working_directories(config_vars),
    experiment_name,
)


In [None]:
import warnings


def warn(*args, **kwargs):
    """Do nothing: a stand-in for ``warnings.warn`` that accepts and discards any arguments."""
    return None


# Rebind warnings.warn to the no-op above so that calls made through the
# module attribute are discarded.
warnings.warn = warn

def empty_dir(folder):
    """Delete everything inside ``folder`` while keeping ``folder`` itself.

    Regular files and symbolic links are unlinked; real sub-directories are
    removed recursively. Removal is best-effort: an entry that cannot be
    deleted is reported with ``print`` and the remaining entries are still
    processed.

    Args:
        folder: Path of an existing directory.

    Raises:
        OSError: If ``folder`` itself cannot be listed (e.g. it does not exist).
    """
    print('empty directory: ', folder)
    for entry_name in os.listdir(folder):
        entry_path = os.path.join(folder, entry_name)
        try:
            # Test for links first: isfile()/isdir() follow symlinks, so a link
            # to a directory would reach shutil.rmtree (which refuses symlinks)
            # and a broken link would match neither test and be left behind.
            if os.path.islink(entry_path) or os.path.isfile(entry_path):
                os.unlink(entry_path)
            elif os.path.isdir(entry_path):
                shutil.rmtree(entry_path)
        except OSError as e:
            print(e)
            
            
def remove(pred_label, cb):
    """Return a copy of ``pred_label`` in which every label listed in ``cb`` is set to 0.

    The input array is not modified. The lookup is vectorised with
    ``np.isin`` instead of comparing every pixel against every label in a
    Python loop, and the unused ``np.bincount`` call (which raised for
    float or negative labels) is gone.

    Args:
        pred_label: NumPy array of labels (0 is used as background).
        cb: Iterable (e.g. a set) of label values to erase.

    Returns:
        A new array with the same shape and dtype as ``pred_label``.
    """
    # TODO (from the original author): add a size filter so that big nuclei
    # are not removed.
    out = pred_label.copy()
    # np.isin needs a sequence; a set would be treated as a single object.
    unwanted = list(cb)
    if unwanted:
        out[np.isin(out, unwanted)] = 0
    return out

def remove_border(pred_label):
    """Erase every labelled object that touches the image border.

    The labels found on the four edges of ``pred_label`` are collected and
    passed to ``remove``, which sets them to 0 (background). The edges are
    taken from the array's own shape, so the image does not have to be
    512 x 512.

    Args:
        pred_label: 2-D NumPy label array (0 is background).

    Returns:
        A new array with the border-touching labels set to 0.
    """
    if pred_label.size == 0:
        # Nothing to scan; indexing row/column 0 would fail on an empty array.
        return pred_label.copy()

    # Gather the labels on the first/last row and first/last column.
    border_values = np.concatenate([
        pred_label[0, :],
        pred_label[-1, :],
        pred_label[:, 0],
        pred_label[:, -1],
    ])
    cb = set(np.unique(border_values).tolist())
    cb.discard(0)  # background is not an object

    # change these classes to 0 as background
    # TODO (from the original author): maybe add a constraint so that big
    # cells are not filtered out.
    return remove(pred_label, cb)

# Module-level counter; preprocess_folder2 uses its own local counter and
# does not modify this one.
img_count = 0
def preprocess_folder2(folder_name):
    """Collect the images of all sub-folders into ``<folder_name>/raw_images/``.

    Every regular file found directly inside a sub-folder of ``folder_name``
    (except the ``raw_images`` and ``experiments`` folders) is copied to the
    destination as ``0000.tif``, ``0001.tif``, ... Sub-folders and their
    files are visited in sorted order. Entries of ``folder_name`` that are
    not directories, and entries of a sub-folder that are not regular files,
    are skipped instead of raising.

    Args:
        folder_name: Root directory that contains one sub-folder per image set.
    """
    img_count = 0

    dest_folder = os.path.join(folder_name, "raw_images")
    os.makedirs(dest_folder, exist_ok=True)

    for fd in sorted(os.listdir(folder_name)):
        src_folder = os.path.join(folder_name, fd)
        # Skip our own output folders, and stray files sitting in the root
        # (listing them as a directory would raise NotADirectoryError).
        if fd in ('raw_images', 'experiments') or not os.path.isdir(src_folder):
            continue
        print(fd)

        for img in sorted(os.listdir(src_folder)):
            src_path = os.path.join(src_folder, img)
            # Only regular files can be copied with copy2; ignore nested folders.
            if not os.path.isfile(src_path):
                continue
            # img is simply copied to the raw_images dir under a running number
            shutil.copy2(src_path, os.path.join(dest_folder, "{:04d}".format(img_count) + ".tif"))
            img_count += 1

    print("Total image: ", img_count)

In [None]:
# Copy the images found in the sub-folders of the root directory into
# <root>/raw_images/ under sequential numeric names.
preprocess_folder2(config_vars["root_directory"])

In [None]:

experiment_name = 'cellSegm'

# Derive the .mat output folder from experiment_name instead of repeating the
# name as a literal, so the two cannot drift apart. The trailing '' makes
# join() end the path with a separator, which later code relies on when it
# appends file names to this entry.
config_vars["mat_out_dir"] = os.path.join(
    config_vars["root_directory"], 'experiments', experiment_name, 'out', 'mat', ''
)

config_vars = utils.dirtools.setup_working_directories(config_vars)
config_vars = utils.dirtools.setup_experiment(config_vars, experiment_name)

### PREPROCESS

In [None]:
# Preview one raw image next to its intensity histogram.
# The listing is sorted so that the previewed file is the same on every run
# (os.listdir returns entries in arbitrary order).
file_list = sorted(os.listdir(config_vars["raw_images_dir"]))
img = skimage.io.imread(config_vars["raw_images_dir"] + file_list[-1])

figure, ax = plt.subplots(1, 2, figsize=(10, 5))

ax[0].imshow(img)
ax[0].set_title(file_list[-1])
ax[1].hist(img.ravel(), bins=100)
ax[1].set_title("pixel intensity histogram")
ax[1].set_xlabel("intensity")
ax[1].set_ylabel("pixel count")
# Render the figure explicitly so the cell does not end on a bare repr.
plt.show()

In [None]:
# Make sure the folders for the normalized images and the .mat output exist.
for dir_key in ("normalized_images_dir", "mat_out_dir"):
    os.makedirs(config_vars[dir_key], exist_ok=True)
# Currently disabled:
#os.makedirs(config_vars["boundary_labels_dir"], exist_ok=True)

In [None]:
# normalize images
# Convert every raw image to an 8-bit PNG after percentile clipping, or, when
# the conversion is switched off, point the pipeline at the raw images.

if config_vars["transform_images_to_PNG"]:
    # Intensities above the `percentile`-th or below the (100 - percentile)-th
    # percentile are clipped before rescaling.
    percentile = 99.9
    filelist = sorted(os.listdir(config_vars["raw_images_dir"]))
    # run over all raw images
    for filename in tqdm(filelist):
        orig_img = skimage.io.imread(config_vars["raw_images_dir"] + filename)

        high = np.percentile(orig_img, percentile)
        low = np.percentile(orig_img, 100 - percentile)

        if high > low:
            # clip to [low, high], then rescale to [0, 1] (gives float64)
            img = np.minimum(high, orig_img)
            img = np.maximum(low, img)
            img = (img - low) / (high - low)
        else:
            # Constant image: the intensity range is zero, so rescaling would
            # divide by zero and fill the image with NaNs. Use all zeros instead.
            img = np.zeros(orig_img.shape, dtype=float)

        # cast to 8 bit
        img = skimage.img_as_ubyte(img)

        # Swap the extension with splitext so that names such as "x.tiff"
        # become "x.png" (cutting off three characters would give "x.tpng").
        png_name = os.path.splitext(filename)[0] + '.png'
        skimage.io.imsave(config_vars["normalized_images_dir"] + png_name, img)
else:
    config_vars["normalized_images_dir"] = config_vars["raw_images_dir"]

### PREDICT

In [None]:
# Count the PNG files available in the normalized-images folder.
file_list = os.listdir(config_vars["normalized_images_dir"])
image_list = list(filter(lambda name: name.endswith("png"), file_list))
len(image_list)

In [None]:
def create_image_lists(dir_raw_images):
    """Build the data partitions for a prediction-only run.

    Every file in ``dir_raw_images`` whose name ends with ``"png"`` goes,
    in sorted order, into the validation partition; the training, test and
    augmented-training partitions are left empty.

    Args:
        dir_raw_images: Directory to scan for PNG files.

    Returns:
        Tuple ``(train, test, validation, train_aug)`` of lists of file names.
    """
    png_names = sorted(
        name for name in os.listdir(dir_raw_images) if name.endswith("png")
    )
    return [], [], png_names, []

In [None]:
# create_image_lists only populates the validation list; the other
# partitions come back empty.
list_training, list_test, list_validation, list_training_aug = create_image_lists(
    config_vars["normalized_images_dir"]
)

# Hand each partition to the project's write_path_files helper together with
# its configured path file.
for path_key, partition in (
    ("path_files_training", list_training),
    ("path_files_test", list_test),
    ("path_files_validation", list_validation),
):
    utils.dirtools.write_path_files(config_vars[path_key], partition)

In [None]:
# Path of the weights file that is later passed to model.load_weights().
config_vars['model_file'] = 'model_PAO1.hdf5'

In [None]:
# Load the data partitions through the project helper, without augmented data.
# The "validation" entry is the list of file names used for prediction below.
# NOTE(review): presumably read back from the path files written earlier — confirm.
data_partitions = utils.dirtools.read_data_partitions(config_vars, load_augmented=False)

In [None]:
# Load every image of the validation partition and concatenate them into one array.
image_names = [os.path.join(config_vars["normalized_images_dir"], f) for f in data_partitions["validation"]]
imagebuffer = skimage.io.imread_collection(image_names)
images = imagebuffer.concatenate()

# Take the image size from axes 1 and 2 and add a trailing channel axis of size 1.
# NOTE(review): assumes single-channel images, i.e. images is (n, dim1, dim2) — confirm.
dim1, dim2 = images.shape[1], images.shape[2]
images = images.reshape((-1, dim1, dim2, 1))
# preprocess (assuming images are encoded as 8-bits in the preprocessing step)
images = images / 255

### build model and load weights
# Alternative single-channel sigmoid configuration, kept for reference:
#model = utils.model_builder.get_model(dim1, dim2, output_channel=1, activation="sigmoid")
model = utils.model_builder.get_model(dim1, dim2, output_channel=3, activation=None)

model.load_weights(config_vars["model_file"])

# Predict with one image per batch.
predictions = model.predict(images, batch_size=1)

In [None]:
# Clear the results of earlier runs from both output folders before writing new ones.
for out_key in ("probmap_out_dir", "labels_out_dir"):
    empty_dir(config_vars[out_key])

In [None]:
# boundary to segmentation v-1.0
# For every predicted map: save it, turn it into a label image, drop small and
# border-touching objects, and save the labels as an image and as a .mat file.

# The boost factor is the same for every image, so it is set once here
# instead of on every loop iteration.
config_vars["boundary_boost_factor"] = 0.4

for i in range(len(images)):
    filename = os.path.basename(imagebuffer.files[i])
    filenamewoext = os.path.splitext(filename)[0]
    print(filename)

    probmap = predictions[i].squeeze()
    skimage.io.imsave(config_vars["probmap_out_dir"] + filename, probmap)

    # probability map -> prediction -> label image (project helpers)
    pred = utils.metrics.probmap_to_pred(probmap, config_vars["boundary_boost_factor"])
    label = utils.metrics.pred_to_label(pred, config_vars["cell_min_size"])
    # min threshold of cell being recognized
    label = skimage.morphology.remove_small_objects(label, min_size=10)
    # drop the objects that touch the image border
    label = remove_border(label)
    skimage.io.imsave(config_vars["labels_out_dir"] + filename, label)

    scipy.io.savemat(config_vars["mat_out_dir"] + filenamewoext + '.mat', mdict={'label': label}, appendmat=True)

    # show the intermediate results for the first ten images only
    if i < 10:
        for stage in (probmap, pred, label):
            plt.imshow(stage)
            plt.show()

In [None]:
# segmentation to boundary
# Count the objects in every saved label image and report the total.
config_vars["raw_annotations_dir"] = config_vars["labels_out_dir"]
# Alternative input folder, kept for reference:
#config_vars["raw_annotations_dir"] = '0809_PAK_cellCount/1K5/experiments/cellRecog/out/segm/'
filelist = sorted(
    name for name in os.listdir(config_vars["raw_annotations_dir"]) if name.endswith('png')
)
total_objects = 0

# run over all label images
for filename in tqdm(filelist):
    # GET ANNOTATION
    annot = skimage.io.imread(config_vars["raw_annotations_dir"] + filename)

    # Relabel the connected components and let label() report how many it
    # found. Counting the unique values minus one would undercount whenever
    # an image contains no background pixel.
    annot, each_objects = skimage.morphology.label(annot, return_num=True)
    total_objects += each_objects
    print("#objects: ", each_objects)

print("Total objects: ", total_objects)