In [None]:
import os
import gzip
import gc
import pprint as pp
import numpy as np
import pickle
import matplotlib.pyplot as plt

from mpl_toolkits.mplot3d import Axes3D
from matplotlib.pyplot import imshow, figure, bar, hist, show

from skimage import io, exposure
from skimage.measure import compare_ssim as ssim
from skimage.transform import resize
from skimage.util import img_as_ubyte
from skimage.filters import sobel
from sklearn.cluster import KMeans

from PIL import Image
import numpy as np
from shutil import copyfile
from multiprocessing import Pool, TimeoutError

In [None]:
# Side length (in pixels) of the square patch cut around every pixel, and half
# of it, which is also the zero padding added on each image border.
kernel_size = 24
half_kernel = kernel_size // 2

In [None]:
# Input: one sub-directory per sample, each holding "images/" and "masks/".
PATH_TO_TRAIN = "../../data/input/train"
# Output: normalised and rotated PNGs, one sub-directory per cluster.
PATH_TO_SAVE_NORMALIZED_IMAGES = "../../data/segmentation.normalization"
# Output: gzip-compressed pickles of per-pixel patches, one sub-directory per cluster.
PATH_TO_INTERMEDIATE = "../../data/segmentation.pickle"

# When True, only a small crop of every image is processed and intermediate
# steps are printed and plotted.
DebugImage = False

In [None]:
# Build one record per training sample: its directory name as "id" plus the
# full paths of every file in its "images/" and "masks/" sub-directories.
images_to_process = []

for sample_id in os.listdir(PATH_TO_TRAIN):
    sample_dir = os.path.join(PATH_TO_TRAIN, sample_id)
    images_dir = os.path.join(sample_dir, "images")
    masks_dir = os.path.join(sample_dir, "masks")
    images_to_process.append({
        "id": sample_id,
        "images": [os.path.join(images_dir, name) for name in os.listdir(images_dir)],
        "masks": [os.path.join(masks_dir, name) for name in os.listdir(masks_dir)],
    })


In [None]:
# Find the largest dimension over all training images; it determines the common
# square size that every image and mask is resized to.
# NOTE: shape[0] is the number of rows and shape[1] the number of columns; the
# names max_width/max_height are historical.
max_width = 0
max_height = 0

for current_image in images_to_process:
    image_file = current_image['images'][0]
    file_png = io.imread(image_file, as_grey=True)
    current_shape = file_png.shape
    max_width = max(max_width, current_shape[0])
    max_height = max(max_height, current_shape[1])

print('max_width:' , max_width, ', max_height:', max_height)

size = max(max_width, max_height)
# Round down to a whole number of 256-pixel tiles and keep a quarter of that.
# Keep at least one tile: images smaller than 256 px (or an empty training set)
# would otherwise give size 0 and make every later resize fail.
size = max(1, size // 256) * 256 // 4
print('image final size is ', size)

In [None]:
def calculate_avg_mask(image):
    """Return the mean foreground "size" of the instance masks of one sample.

    Each file in ``image['masks']`` is loaded as grey, resized to
    ``(size, size)`` and binarised to {0, 255}. Its size is the pixel sum,
    i.e. 255 * number of foreground pixels.

    Args:
        image: sample record with a ``'masks'`` list of mask file paths.

    Returns:
        The mean size over all masks, or ``nan`` when the sample has no masks.
    """
    image_sizes = []
    # Iterate the argument's masks. The earlier version read the global
    # ``current_image`` and only worked because the caller's loop used that name.
    for mask in image['masks']:
        cur_img = io.imread(mask, as_grey=True)
        cur_img = resize(cur_img, (size, size))

        cur_img = img_as_ubyte(cur_img)
        cur_img[cur_img > 0] = 255

        image_sizes.append(cur_img.sum())
    if not image_sizes:
        # Avoid numpy's "mean of empty slice" RuntimeWarning.
        return np.nan
    return np.array(image_sizes).mean()

In [None]:
# Attach each sample's average mask size to its record and collect the values
# for clustering.
mask_avg = []
for current_image in images_to_process:
    avg_size = calculate_avg_mask(current_image)
    current_image["mask_avg_size"] = avg_size
    mask_avg.append(avg_size)
    


In [None]:
# Cluster the samples into 5 groups by average mask size. KMeans expects a 2-D
# (n_samples, n_features) array, so each size becomes a one-feature row.
mask_avg_numpy = np.array(mask_avg)
mask_avg_numpy = mask_avg_numpy.reshape(-1, 1)
kmeans = KMeans(n_clusters=5, random_state=0).fit(mask_avg_numpy)

# Predict all samples in one call with correctly shaped input. The previous
# per-sample call passed a bare scalar, which scikit-learn rejects
# ("Expected 2D array, got scalar array instead").
sample_sizes = np.array([[entry["mask_avg_size"]] for entry in images_to_process])
for current_image, cluster_label in zip(images_to_process, kmeans.predict(sample_sizes)):
    current_image["cluster"] = str(cluster_label)

In [None]:
def is_part_of_nuclei(current_x, current_y, masked_img):
    """Return 1 when ``masked_img`` is positive at (current_x, current_y), else 0."""
    pixel_value = masked_img[current_x, current_y]
    if pixel_value > 0:
        return 1
    return 0

def get_cluster(image_id):
    """Return the "cluster" label of the first record in ``images_to_process`` whose id matches.

    Raises:
        ValueError: when no record has ``image_id``.
    """
    matching_clusters = (
        record["cluster"] for record in images_to_process if record["id"] == image_id
    )
    for cluster in matching_clusters:
        return cluster
    raise ValueError("Cluster not found:{}".format(image_id))

In [None]:
def create_pickling(file_name, current_rotation, original_img, masked_img):
    """Cut a patch around every pixel of an image and pickle them (gzip).

    The image and mask are zero-padded by ``half_kernel`` on every side, so each
    pixel gets a full ``kernel_size`` x ``kernel_size`` window with the pixel at
    offset ``half_kernel`` inside it. Patches are converted to float16 and
    rescaled as ``x / 128 - 1`` (roughly [-1, 1] assuming uint8 input). Each
    patch is stored with the pixel's label from ``is_part_of_nuclei`` on the
    mask.

    Args:
        file_name: image file name; the text before its first "." must be an id
            known to ``get_cluster``.
        current_rotation: rotation index, used as the prefix of the output id.
        original_img: 2-D image array.
        masked_img: 2-D mask array with the same shape as ``original_img``.

    Writes ``<PATH_TO_INTERMEDIATE>/<cluster>/<rotation>.<file_name>.pickle``
    (gzip-compressed) and returns nothing.
    """
    image_id = '{}.{}'.format(current_rotation, file_name)
    # Resolve the cluster up front so an unknown id fails before the heavy work.
    cluster = get_cluster(image_id.split(".")[1])
    original_shape = original_img.shape
    padding = ((half_kernel, half_kernel), (half_kernel, half_kernel))
    original_img = np.pad(original_img, padding, 'constant')
    masked_img = np.pad(masked_img, padding, 'constant')

    generated_slices = []
    for current_y in range(half_kernel, original_img.shape[1] - half_kernel):
        for current_x in range(half_kernel, original_img.shape[0] - half_kernel):
            # Coordinates of this pixel in the unpadded image.
            correct_current_x = current_x - half_kernel
            correct_current_y = current_y - half_kernel
            is_nuclei = is_part_of_nuclei(current_x, current_y, masked_img)

            # Centre the window on the labelled pixel. In padded coordinates
            # the window starts half_kernel before it, which equals the
            # unpadded coordinate. The padding guarantees a full-size window.
            # (Previously the window started AT the pixel, so the label sat in
            # the patch's corner and edge windows were silently dropped.)
            current_slice = original_img[
                correct_current_x:correct_current_x + kernel_size,
                correct_current_y:correct_current_y + kernel_size,
            ]
            # astype makes a copy, so the padded image is never modified.
            current_slice = current_slice.astype('float16')
            current_slice = (current_slice / 128) - 1

            generated_slices.append({
                "current_x": correct_current_x,
                "current_y": correct_current_y,
                "current_x_hex": format(correct_current_x, '03x'),
                "current_y_hex": format(correct_current_y, '03x'),
                "sum": current_slice.sum(),
                "is_nuclei": is_nuclei,
                "slice": current_slice,
                "augmented": 0,
            })

    image_to_pickle = {
        'shape': original_img.shape,  # padded shape
        'original_shape': original_shape,
        'slices': generated_slices,
        'id': image_id,
    }
    output_path = os.path.join(PATH_TO_INTERMEDIATE, cluster, image_id) + '.pickle'
    with gzip.open(output_path, 'wb') as handle:
        pickle.dump(image_to_pickle, handle, protocol=pickle.HIGHEST_PROTOCOL)
    # Drop every reference to the (large) slice list before collecting, so the
    # memory can actually be reclaimed.
    del image_to_pickle, generated_slices
    gc.collect()

In [None]:
def generate_augmentation(file_name, image, mask):
    """Save and pickle an image/mask pair at four rotations (k = 0..3 quarter turns).

    For each rotation k (``np.rot90`` applied k times), this function:
      * writes ``<k>.<file_name>`` (rotated image) under
        ``<PATH_TO_SAVE_NORMALIZED_IMAGES>/<cluster>/``;
      * writes the same name with ".png" replaced by ".mask.png" (rotated mask);
      * passes the rotated pair to ``create_pickling``;
      * writes the same name with ".png" replaced by ".mask.values.png", holding
        ``bitwise_and(rotated image, rotated mask)``.

    Args:
        file_name: image file name; the text before its first "." must be an id
            known to ``get_cluster``.
        image: 2-D integer image array.
        mask: 2-D integer mask array with the same shape as ``image``.
    """
    image_id = file_name.split(".")[0]
    output_dir = os.path.join(PATH_TO_SAVE_NORMALIZED_IMAGES, get_cluster(image_id))
    for current_rotation in range(4):
        current_image = np.rot90(image, k=current_rotation)
        current_mask = np.rot90(mask, k=current_rotation)
        output_path = os.path.join(output_dir, '{}.{}'.format(current_rotation, file_name))
        io.imsave(output_path, current_image)
        io.imsave(output_path.replace('.png', '.mask.png'), current_mask)
        # Pickle the ROTATED pair. The previous version passed the unrotated
        # inputs, so all four pickles held identical data.
        create_pickling(file_name, current_rotation, current_image, current_mask)
        masked_values = np.bitwise_and(current_image, current_mask)
        io.imsave(output_path.replace('.png', '.mask.values.png'), masked_values)


In [None]:
%matplotlib inline

total_files = len(images_to_process)
counter = 0

def _show_debug_image(image):
    """Print the shape and intensity range of ``image`` and display it in grey scale."""
    min_value = image.min()
    max_value = image.max()
    print(image.shape, min_value, max_value, max_value - min_value)
    figure(figsize=(8, 6), dpi=80)
    imshow(image, cmap='gray', vmin=0, vmax=255)


def _show_debug_surface(image):
    """Draw ``image`` as a 3-D surface with intensity as height."""
    fig = figure(figsize=(8, 6), dpi=80)
    # fig.gca(projection='3d') is deprecated in recent Matplotlib releases;
    # add_subplot with a projection works across versions.
    ax = fig.add_subplot(111, projection='3d')
    X, Y = np.meshgrid(np.arange(0, image.shape[0]), np.arange(0, image.shape[1]))
    Z = np.array(image[X, Y])
    ax.plot_surface(X, Y, Z)
    show()


def _load_normalized_image(image_file):
    """Load an image as grey, resize it, stretch its contrast and zero dark pixels.

    Returns a uint8 array of shape (size, size), or the [1:20, 1:20] crop of it
    when DebugImage is set.
    """
    file_png = io.imread(image_file, as_grey=True)
    file_png = resize(file_png, (size, size))
    if DebugImage:
        file_png = file_png[1:20, 1:20]
    file_png = img_as_ubyte(file_png)
    # Invert predominantly bright images (mean > 64) so inputs share one polarity.
    if file_png.mean() > 64:
        file_png = np.invert(file_png)
    if DebugImage:
        _show_debug_image(file_png)

    # Stretch intensities to span [0, 255]. A uniform image has range 0, where
    # the division would produce NaNs; such an image is left at all zeros.
    min_value = file_png.min()
    current_range = file_png.max() - min_value
    normalized = (file_png - min_value).astype(np.float64)
    if current_range > 0:
        normalized *= 255 / current_range
    if DebugImage:
        _show_debug_image(normalized)
        _show_debug_surface(normalized)

    array1 = normalized.astype('uint8')
    # Zero every pixel darker than half the mean intensity.
    array1[normalized < (array1.mean() / 2)] = 0
    return array1


def _build_edge_mask(mask_files):
    """OR together the binarised Sobel edge maps of all mask files.

    Each mask is loaded as grey, resized to (size, size), binarised to {0, 255},
    passed through ``sobel`` and binarised again. Returns a (size, size) uint8
    array, which is all zeros when ``mask_files`` is empty.
    """
    masked_img = np.zeros((size, size), dtype=np.uint8)
    for mask in mask_files:
        cur_img = io.imread(mask, as_grey=True)
        cur_img = resize(cur_img, (size, size))

        cur_img = img_as_ubyte(cur_img)
        cur_img[cur_img > 0] = 255

        cur_img = sobel(cur_img)

        cur_img = img_as_ubyte(cur_img)
        cur_img[cur_img > 0] = 255

        masked_img = np.bitwise_or(masked_img, cur_img)
    return masked_img


def process_image(current_image):
    """Normalise one training sample and generate its rotated outputs and pickles.

    Args:
        current_image: sample record with ``'images'`` (list of image paths; only
            the first is used) and ``'masks'`` (list of mask paths).

    Note: under ``multiprocessing.Pool`` every worker process has its own copy
    of the global ``counter``, so the printed percentage reflects only the
    images handled by that worker.
    """
    global counter
    image_file = current_image['images'][0]
    image_name = os.path.basename(image_file)
    print("Processing:{}% - {}".format(int((counter / total_files) * 100), image_name))

    array1 = _load_normalized_image(image_file)
    masked_img = _build_edge_mask(current_image['masks'])
    if DebugImage:
        # Crop the mask like the image so the pair keeps matching shapes
        # (bitwise_and in generate_augmentation needs them to agree).
        masked_img = masked_img[1:20, 1:20]

    print("Valores:{} img={}-{} mask={}-{} ".format(array1.shape, array1.min(), array1.max(), masked_img.min(), masked_img.max()))
    generate_augmentation(image_name, array1, masked_img)
    counter += 1

# Process all samples in parallel. The context manager shuts the worker
# processes down once map() has returned, and the pool size is capped at the
# machine's CPU count (at most 32, as before).
with Pool(processes=min(32, os.cpu_count() or 1)) as pool:
    pool.map(process_image, images_to_process)
    
