In [None]:
import os, h5py, itertools
import imageio as io
import numpy as np
from matplotlib import pyplot as plt
from skimage import transform as sp

from tensorflow.keras import models

In [None]:
# Saved Keras segmentation model, relative to the notebook's working directory.
SEG_MODEL_PATH = 'resources/nsdf/workflow/seg_msd_50_2_ep100'
# Loaded once; process_image() and the section checks below all use this global.
trained_model = models.load_model(SEG_MODEL_PATH)

In [None]:
def cut_image_into_tiles(img, main_size=380, border_offset=10, step=None):
    """Split a 2-D image into square tiles with a context margin.

    The image is edge-padded by ``border_offset`` on every side. Each tile is a
    ``main_size`` x ``main_size`` core plus ``border_offset`` pixels of context
    on each side, so every tile is ``main_size + 2 * border_offset`` pixels
    square.

    Args:
        img: 2-D array.
        main_size: Side length of the core region kept when reassembling.
        border_offset: Width of the context margin (and of the edge padding).
        step: Stride between tile cores. ``None`` (default) uses ``main_size``,
            i.e. adjacent, non-overlapping cores.

    Returns:
        Array of shape ``(n_tiles, main_size + 2*border_offset, main_size + 2*border_offset)``.
        Tiles are ordered row band by row band: all tiles of the first band of
        rows, left to right, then the next band. Only tiles that fit entirely
        inside the padded image are produced. If ``side - main_size`` is not a
        multiple of the stride, a trailing strip is left uncovered.
    """
    true_step = main_size if step is None else step
    padded = np.pad(img, ((border_offset, border_offset), (border_offset, border_offset)), 'edge')
    # A core starting at padded index s spans s - border_offset .. s + main_size + border_offset,
    # so the last valid start is padded.shape - main_size - border_offset (inclusive).
    x_left = np.arange(border_offset, padded.shape[0] - main_size - border_offset + 1, true_step)
    y_top = np.arange(border_offset, padded.shape[1] - main_size - border_offset + 1, true_step)
    tiles = [
        padded[x - border_offset:x + main_size + border_offset,
               y - border_offset:y + main_size + border_offset]
        for x, y in itertools.product(x_left, y_top)
    ]
    return np.array(tiles)


def reassemble_tiled_image_noavg(img_stack, shape, main_size, border_offset):
    """Stitch tiles back into one 2-D image without blending.

    Only the central ``main_size`` x ``main_size`` core of each tile is kept.
    The ``border_offset`` margin is discarded. Cores are laid out on a
    ``main_size`` grid starting at (0, 0). Tiles are consumed in row-band
    order: all tiles of the first band of rows, left to right, then the next.

    Args:
        img_stack: Array indexed as ``img_stack[i, row, col]``, one tile per ``i``.
        shape: Target ``(rows, cols)`` of the output image.
        main_size: Core side length used when cutting.
        border_offset: Margin width used when cutting.

    Returns:
        float64 array of shape ``(shape[0], shape[1])``. Pixels not covered by
        any core stay 0; this happens in a trailing strip when the side is not
        a multiple of ``main_size``.
    """
    result = np.zeros([shape[0], shape[1]])
    core = slice(border_offset, border_offset + main_size)
    # A core starting at x must end inside the image: x + main_size <= side.
    x_left = np.arange(0, result.shape[0] - main_size + 1, main_size)
    y_top = np.arange(0, result.shape[1] - main_size + 1, main_size)
    # Copy each tile's core (its margin is dropped) into its grid cell.
    for count, (x, y) in enumerate(itertools.product(x_left, y_top)):
        result[x:x + main_size, y:y + main_size] = img_stack[count, core, core]
    return result

def normalize_crop_reshape_image(fpath, crop=(slice(250, 2250), slice(250, 2250)),
                                 percentiles=(0.01, 99.9)):
    """Read an image, crop it, and rescale its intensities to [0, 1].

    The black and white points are the given low/high percentiles, computed on
    the cropped image. Values outside that range are clipped, and the range is
    mapped linearly onto [0, 1]. Despite the name, no reshaping is done.

    Args:
        fpath: Path passed to ``io.imread``.
        crop: Index applied to the image after reading. The default is
            ``[250:2250, 250:2250]``, which is 2000 x 2000 when the image is at
            least 2250 px on each side.
        percentiles: ``(low, high)`` percentiles used as black/white points.
            NOTE(review): the default low value 0.01 is asymmetric with the
            high value 99.9; confirm this is intended.

    Returns:
        float64 array of the cropped image with values in [0, 1]. A crop in
        which every pixel is equal returns all zeros instead of NaN.
    """
    img = io.imread(fpath)[crop].astype(float)
    vmin, vmax = np.percentile(img, percentiles)
    span = vmax - vmin
    if span <= 0:
        # Flat image: dividing by a zero span would fill the result with NaN.
        return np.zeros_like(img)
    return np.clip(img - vmin, 0, span) / span

def process_image(src, dst, fname, main_size=200, border_offset=100):
    """Segment one reconstructed slice with the CNN and save the result.

    Steps:
      1. Read ``os.path.join(src, fname)`` with normalize_crop_reshape_image.
      2. Cut the image into tiles and run ``trained_model`` on all tiles.
      3. Stitch the tile cores back together.
      4. Write the result to ``os.path.join(dst, fname)``, creating ``dst`` if
         needed.

    Args:
        src: Directory that contains the input slice.
        dst: Output directory (created if missing).
        fname: File name, used for both the input and the output.
        main_size: Tile core size passed to the tiling helpers.
        border_offset: Tile margin. ``main_size + 2 * border_offset`` must
            equal the CNN input size, which is 400 x 400 for the trained model.
            The defaults (200, 100) are the values used in this notebook.

    Returns:
        The stitched segmentation as a uint8 array: model output clipped to
        [0, 1], times 255, truncated.
    """
    img = normalize_crop_reshape_image(os.path.join(src, fname))
    img_tiles = cut_image_into_tiles(img, main_size, border_offset)
    if len(img_tiles) == 0:
        # Image smaller than one tile: nothing to predict; stitching yields zeros.
        img_segment_tile = img_tiles
    else:
        # One batched predict call (Keras splits it into batches internally)
        # replaces one predict() call per tile. The trailing axis is the single
        # input channel, as in the per-tile expand_dims(..., (0, 3)).
        predictions = trained_model.predict(img_tiles[..., np.newaxis])
        # Squeeze each tile separately so a one-tile batch keeps its leading axis.
        img_segment_tile = np.array([np.squeeze(p) for p in predictions])
    img_segment_stitch = reassemble_tiled_image_noavg(
        img_segment_tile, [img.shape[0], img.shape[1]], main_size, border_offset)
    # Clip first: out-of-range model outputs would otherwise wrap around in the uint8 cast.
    img_norm = (np.clip(img_segment_stitch, 0, 1) * 255).astype('uint8')
    if dst:
        os.makedirs(dst, exist_ok=True)
    io.imsave(os.path.join(dst, fname), img_norm)
    return img_norm

def process_all_images_in_folder(src, dst, main_size=200, border_offset=100):
    """Run process_image on every regular file directly inside ``src``.

    Files are processed in sorted name order, and each name is printed before
    it is processed. Sub-directories are skipped. This matters when ``dst``
    lives inside ``src`` (for example ``src/Segment_Data``), because the
    output directory would otherwise be passed to the image reader.

    Args:
        src: Directory of reconstructed slices.
        dst: Output directory passed to process_image.
        main_size: Tile core size passed to process_image. Defaults to the 200
            used elsewhere in this notebook.
        border_offset: Tile margin passed to process_image. Defaults to the
            100 used elsewhere in this notebook.
    """
    for file in sorted(os.listdir(src)):
        if not os.path.isfile(os.path.join(src, file)):
            continue
        print(file)
        process_image(src, dst, file, main_size, border_offset)

### Sample ID: fly_scan_id_112517

### Running the CNN model on a single slice

In [None]:
# Root folder of the CNN test data.
CNN_TEST_ROOT = '/home/kancr/ondemand/CNN_Model_Test/'
# Where the reconstructed slices are saved:
src = CNN_TEST_ROOT + 'Reconstructed_Data/'
# Where to save the processed images (the root folder itself):
dst = CNN_TEST_ROOT

In [None]:
%%time
fname = 'recon_cgls_tv_01000.tiff'
input_img = normalize_crop_reshape_image(os.path.join(src, fname ))
img_segment = process_image(src, dst, fname, 200, 100)

In [None]:
%matplotlib inline
f, ((ax1, ax2)) = plt.subplots(2, 1, figsize=([20, 20]))

mappable = ax1.imshow(np.squeeze(input_img), cmap = 'gray')
f.colorbar(mappable, ax=ax1)

mappable = ax2.imshow(np.squeeze(img_segment), cmap = 'gray')
f.colorbar(mappable, ax=ax2)

### Checking small sections of a slice to assess the CNN model's performance

In [None]:
# Start rows of the 200 px tile cores in a 2000 px side. A core starting at x must
# satisfy x + 200 <= 2000, so the 100 px margin does not enter this bound.
x_left = np.arange(0, 2000 - 200 + 1, 200)
x_left

In [None]:
# Where the reconstructed slices are saved:
src = '/home/kancr/ondemand/CNN_Model_Test/Reconstructed_Data/'
# Where to save the processed images (alternative used before: 'TiledImages_SegmentedData/'):
dst = '/home/kancr/ondemand/CNN_Model_Test/'
fname = 'recon_cgls_tv_00801.tiff'
input_img = normalize_crop_reshape_image(os.path.join(src, fname))
# A single 400 x 400 window is exactly the CNN input size, so no tiling is needed.
# Add a batch axis in front and a channel axis at the end.
img_section = input_img[1000:1400, 1200:1600][np.newaxis, :, :, np.newaxis]
img_segment_section = trained_model.predict(img_section)

In [None]:
%matplotlib inline
f, ((ax1)) = plt.subplots(1, 1, figsize=([20, 20]))

mappable = ax1.imshow(np.squeeze(input_img), cmap = 'gray')
f.colorbar(mappable, ax=ax1)

In [None]:
%matplotlib inline
f, ((ax1, ax2)) = plt.subplots(2, 1, figsize=([20, 20]))

mappable = ax1.imshow(np.squeeze(img_section), cmap = 'gray')
f.colorbar(mappable, ax=ax1)

mappable = ax2.imshow(np.squeeze(img_segment_section), cmap = 'gray')
f.colorbar(mappable, ax=ax2)

### Running the CNN model on all reconstructed slices

In [None]:
# Root folder of the CNN test data.
CNN_TEST_ROOT = '/home/kancr/ondemand/CNN_Model_Test/'
# Where the reconstructed slices are saved:
src = CNN_TEST_ROOT + 'Reconstructed_Data/'
# Where to save the processed images:
dst = CNN_TEST_ROOT + 'TiledImages_SegmentedData/'

In [None]:
# Double-check that dst is the intended destination before uncommenting:
# this calls process_image once for every file listed in src.
# process_all_images_in_folder(src, dst, 200, 100)

### Sample ID: fly_scan_id_112509

In [None]:
# Where the reconstructed slices are saved:
src = '/home/kancr/ondemand/CNN_Model_Test/fly_scan_id_112509/'
# Where to save the processed images (a sub-directory of src):
dst = '/home/kancr/ondemand/CNN_Model_Test/fly_scan_id_112509/Segment_Data'
fname = 'recon_cgls_tv_01202.tiff'
input_img = normalize_crop_reshape_image(os.path.join(src, fname))
# Pass the tile geometry explicitly: a 200 px core plus two 100 px margins gives
# the 400 px CNN input. This is the same geometry as the fly_scan_id_112517 run.
img_segment = process_image(src, dst, fname, 200, 100)

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(20, 20))
mappable = ax1.imshow(np.squeeze(input_img), cmap='gray')
ax1.set_title('Normalized input slice')
fig.colorbar(mappable, ax=ax1)

mappable = ax2.imshow(np.squeeze(img_segment), cmap='gray')
ax2.set_title('CNN segmentation (stitched tiles)')
fig.colorbar(mappable, ax=ax2)
plt.show()

### Checking performance on a small section of this sample

In [None]:
# Re-read the slice and run the CNN directly on one 400 x 400 window (no tiling).
input_img = normalize_crop_reshape_image(os.path.join(src, fname))
# Add a batch axis in front and a channel axis at the end.
img_section = input_img[600:1000, 900:1300][np.newaxis, :, :, np.newaxis]
img_segment_section = trained_model.predict(img_section)

In [None]:
# 400 x 400 input window (top) and the CNN output for that window (bottom).
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(20, 20))
mappable = ax1.imshow(np.squeeze(img_section), cmap='gray')
ax1.set_title('Input section')
fig.colorbar(mappable, ax=ax1)
mappable = ax2.imshow(np.squeeze(img_segment_section), cmap='gray')
ax2.set_title('CNN output for section')
fig.colorbar(mappable, ax=ax2)
plt.show()

# Scratch

In [None]:
%matplotlib inline
f, ((ax1, ax2)) = plt.subplots(1, 2, figsize=([20, 20]))

mappable = ax1.imshow(np.squeeze(input_img))
f.colorbar(mappable, ax=ax1)

mappable = ax2.imshow(img_segment)
f.colorbar(mappable, ax=ax2)

In [None]:
%matplotlib inline
f, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 20))
mappable = ax1.imshow(np.squeeze(input_img))
f.colorbar(mappable, ax=ax1)
# tmp = 255*np.squeeze(img_segment)
# tmp01 = tmp<=200
# tmp02 = tmp>25
# tmp03 = np.bitwise_xor(tmp01,tmp02)
mappable = ax2.imshow(dst1)
f.colorbar(mappable, ax=ax2)
plt.show()

In [None]:
# Averaging filter: a 20 x 20 box kernel whose 400 weights (1/400 each) sum to 1.
import cv2

av_filter = np.full((20, 20), 1 / 400, dtype='float32')
# ddepth=-1: the destination image has the same depth as the input image.
dst1 = cv2.filter2D(img_segment, -1, av_filter)

In [None]:
# Start indices on a 390 px stride over the stitched segmentation.
# NOTE(review): `tmp` is not defined by any cell in this notebook. It is presumably
# the 2-D segmentation image (TODO confirm), so img_segment's shape is used here.
seg_shape = np.squeeze(img_segment).shape
Row = np.arange(0, seg_shape[0] - 390 + 1, 390)
# Columns run along axis 1.
Col = np.arange(0, seg_shape[1] - 390 + 1, 390)
Row.shape

In [None]:
# Design a mask to correct the overlap region; it starts out as all ones.
Mask = np.ones(shape=(2000, 2000))
# Rows 390 px apart, each with at least 380 px of the mask below it.
Row_Ind = np.arange(start=0, stop=len(Mask) - 380 + 1, step=390)
print(Row_Ind)


In [None]:

# Scratch: mark a 9-row band at the 5th start row of a 390 px stride on a copy of the
# segmentation, and show it next to the overlap-correction Mask.
# NOTE(review): the original drew on an undefined `tmp`, mutated it in place, and
# marked it only after imshow, so the band was not visible in the same run. Here a
# copy is marked before plotting. Confirm this is the intended view.
seg_marked = np.squeeze(img_segment).copy()
row_starts = np.arange(0, seg_marked.shape[0] - 390 + 1, 390)
seg_marked[row_starts[4]:row_starts[4] + 9, :] = 255

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 20))
mappable = ax1.imshow(seg_marked)
fig.colorbar(mappable, ax=ax1)
mappable = ax2.imshow(Mask)
fig.colorbar(mappable, ax=ax2)
plt.show()