In [None]:
import os, sys
import cv2
import numpy as np
import uuid
import tensorflow as tf
from skimage.io import imread, imsave, imshow
from PIL import Image, ImageTk
import matplotlib.pyplot as plt
from imutils import paths
import itertools
import json
from pprint import pprint

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from core.imageprep import dir_checker, random_crop, crop_generator, random_crop_batch
from core.models import UNet

from tensorflow.keras.models import load_model

from core.imageprep import create_crop_idx, crop_to_patch, construct_from_patch
from core.train_predict import stack_predict
from tqdm.notebook import trange


# Enable IPython's autoreload extension in mode 2, which reloads modules before each
# cell is executed, so edits to the project package (core.*) are picked up without a
# kernel restart.
get_ipython().run_line_magic('load_ext', 'autoreload')
get_ipython().run_line_magic('autoreload', '2')

 ## Load Training Dataset

In [None]:
# load image
print("Load Images...")
# on mac
# path = "/Volumes/LaCie_DataStorage/PerlmutterData/"

# on Windows PC
# os.path.join('D:', 'PerlmutterData') would give the drive-relative path
# 'D:PerlmutterData' (relative to the current directory on drive D:), so join
# from the drive root instead.
path = os.path.join('D:' + os.sep, 'PerlmutterData')

# experiment
exp_name = 'dl_seg_project_raw'
# training timestamps: select the cropped input data and the model/pars files to load
imginput_timestamp = '2019_12_06_17_06'
model_training_timestamp = '2019_12_06_17_14'
print('Training timestamp: {}'.format(model_training_timestamp))

# input img path
imginput = os.path.join(exp_name, 'data_crop', imginput_timestamp)
imgpath = os.path.join(path, imginput)
print('Input Images Path: {}'.format(imgpath))

# model path
modelfd = 'model'
modelfn = 'model_' + model_training_timestamp + '.h5'
path_model = os.path.join(path, modelfd, modelfn)
print('Model Path: {}'.format(path_model))

# raw path
rawfd = 'raw'
path_raw = os.path.join(path, rawfd)
print('Raw Path: {}'.format(path_raw))

# prediction path
pred_path = os.path.join(path, exp_name)
# dir_checker comes from core.imageprep; presumably creates pred_path/pred_img if missing
dir_checker('pred_img', pred_path)

## Parameters

In [None]:
# load parameter
parsfd = 'pars'
parsfn = 'pars_' + model_training_timestamp + '.json'
path_pars = os.path.join(path, parsfd, parsfn)

with open(path_pars) as json_file:
    pars = json.load(json_file)

In [None]:
# display the training hyper-parameters loaded from the pars JSON file
pprint(pars, indent=1, width=80)

In [None]:
# unpack the parameters used below: the label class and the patch size used for tiling
label, IMG_HEIGHT, IMG_WIDTH = (pars[key] for key in ('inputclass', 'IMG_HEIGHT', 'IMG_WIDTH'))

## Predict from Testing Dataset

In [None]:
# get dataset
rawfdlist = os.listdir(path_raw)
print(rawfdlist)

In [None]:
# Collect the aligned image paths of every stack. The paths are sorted so slice order
# is explicit rather than whatever order the directory walk yields.
# NOTE(review): assumes slice file names sort in acquisition order (e.g. zero-padded
# indices) -- TODO confirm.
rawimglist = {}
for folder in rawfdlist:
    print(folder)
    rawimglist[folder] = sorted(paths.list_images(os.path.join(path_raw, folder, 'Aligned')))

In [None]:
# The full path lists are long (stacks are indexed up to slice 500 below). Show the
# slice count and the first file of each stack instead of dumping every path.
pprint({folder: (len(imgs), imgs[:1]) for folder, imgs in rawimglist.items()})

## Define Prediction Set 

In [None]:
# keep only the first `samplesize` slices of each stack for a quick trial prediction
samplesize = 10
rawimglist_small = {fd: imgs[:samplesize] for fd, imgs in rawimglist.items()}
pprint(rawimglist_small)

In [None]:
# preview one slice of the first stack as a sanity check on the raw data
PREVIEW_SLICE = 500
preview_stack = rawimglist[rawfdlist[0]]
assert preview_stack, 'no images found for {}'.format(rawfdlist[0])
# clamp the index so stacks with fewer than PREVIEW_SLICE + 1 slices do not raise IndexError
preview_idx = min(PREVIEW_SLICE, len(preview_stack) - 1)
img = imread(preview_stack[preview_idx])

fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(img, cmap='gray')
ax.set_title('{} - slice {}'.format(rawfdlist[0], preview_idx))
ax.axis('off')
plt.show()

In [None]:
# load the trained Keras model saved under path_model (.h5)
model = load_model(path_model)
# print(model) only shows the object repr; summary() lists the layers, their output
# shapes and the parameter counts
model.summary()

In [None]:
# Disabled (kept as a string literal, never executed): rebuilds a U-Net-style
# encoder/decoder with a size-agnostic input Input((None, None, 1)) as `newmodel`.
# It is meant for the disabled `newmodel.load_weights` / `stack_predict_v2` cells.
# NOTE(review): `output_shape` is computed but never used.
# NOTE(review): compile() is given no optimizer and an `lr` keyword. In Keras the
# learning rate is set on the optimizer (e.g. Adam), so verify before re-enabling.
# NOTE(review): as the cell's last expression, this string is echoed as cell output when run.
'''
from tensorflow.keras import Sequential
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, concatenate, Conv2D, Conv2DTranspose, Dropout, Flatten, Dense, Activation, Layer, Reshape, Permute, Lambda
from tensorflow.keras.layers import Conv3D, MaxPool3D, ZeroPadding3D
from tensorflow.keras.layers import Conv2D, MaxPool2D, UpSampling2D, ZeroPadding2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.optimizers import Adam, Adadelta
from tensorflow.keras import backend as K

inputs = Input((None, None, 1))
conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs)
print ("conv1 shape:",conv1.shape)
conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)
print ("conv1 shape:",conv1.shape)

pool1 = MaxPool2D(pool_size=(2, 2))(conv1)
print ("pool1 shape:",pool1.shape)

conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)
print ("conv2 shape:",conv2.shape)
conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)
print ("conv2 shape:",conv2.shape)
pool2 = MaxPool2D(pool_size=(2, 2))(conv2)
print ("pool2 shape:",pool2.shape)

conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)
print ("conv3 shape:",conv3.shape)
conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)
print ("conv3 shape:",conv3.shape)
pool3 = MaxPool2D(pool_size=(2, 2))(conv3)
print ("pool3 shape:",pool3.shape)

conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)
conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)
drop4 = Dropout(0.5)(conv4)
pool4 = MaxPool2D(pool_size=(2, 2))(drop4)

conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)
conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)
drop5 = Dropout(0.5)(conv5)

up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5))
merge6 = concatenate([drop4,up6])

conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)
conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)

up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))
merge7 = concatenate([conv3, up7])

conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)
conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7)

up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))
merge8 = concatenate([conv2, up8])

conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)
conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)

up9 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8))
merge9 = concatenate([conv1, up9])

conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge9)
conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)

output = Conv2D(1, 1, activation = 'sigmoid')(conv9)

output_shape = Model(inputs , conv9).output_shape

newmodel = Model(inputs, output)

newmodel.compile(loss="binary_crossentropy", lr=1e-5, metrics=['accuracy'])
newmodel.summary()
'''

In [None]:
# Disabled (string literal, never executed): would load the trained weights from
# path_model into the hand-built `newmodel` from the disabled U-Net cell.
'''
newmodel.load_weights(path_model)
'''

In [None]:
# Disabled (string literal, never executed): earlier copy of the batch_01 output-folder
# setup. The same code runs for real in the "Create Folder" cell further down.
'''
# folder name
fdnm_small = 'batch_01'
dir_checker(fdnm_small, os.path.join(pred_path, 'pred_img'))
img_path_small = os.path.join(pred_path, 'pred_img', fdnm_small)

# create folder list
for folder in rawfdlist:
    dir_checker(folder, img_path_small)
'''

In [None]:
# Disabled (string literal, never executed): variant of the small-sample prediction
# that calls core.train_predict.stack_predict_v2 with the hand-built `newmodel`.
# It passes no crop index or patch size (those arguments are commented out).
# NOTE(review): stack_predict_v2's signature is not visible here -- confirm it handles
# whole images without tiling before re-enabling.
'''
from core.imageprep import create_crop_idx, crop_to_patch
from core.train_predict import stack_predict_v2

for idx in trange(len(rawfdlist)):
    
    folder = rawfdlist[idx]
    
    pred_input_imgs =  rawimglist_small[folder]
    pred_output_path = os.path.join(pred_path, 'pred_img', fdnm_small, folder)
    
    img = imread(rawimglist_small[folder][0])
    # cropidx = create_crop_idx(img.shape, (IMG_HEIGHT, IMG_WIDTH), overlap_fac = 0.1)
    # print(cropidx)
    
    stack_predict_v2(
                input_imgpath = pred_input_imgs, 
                output_imgpath = pred_output_path, 
                # cropidx = cropidx, 
                model = newmodel, 
                rescale = 1./255.,
                # patch_size = (IMG_HEIGHT, IMG_WIDTH), 
                predict_threshold = 0.5)
'''

## Prediction on Small Dataset with Tiles
### Create Folder

In [None]:

# Output folder for this prediction batch: pred_img/batch_01 with one sub-folder per
# raw stack. dir_checker comes from core.imageprep and presumably creates missing folders.
fdnm_small = 'batch_01'
img_path_small = os.path.join(pred_path, 'pred_img', fdnm_small)
dir_checker(fdnm_small, os.path.join(pred_path, 'pred_img'))

# one sub-folder per stack, named after the raw stack folder
for folder in rawfdlist:
    dir_checker(folder, img_path_small)


### Tiling Prediction with Stack Input
- Crop each image into patches using a given overlap factor
- Export a cropping index
- Reconstruct the patches back into an image

In [None]:

# For every stack in the small sample, build a tiling index from the first slice and run
# core.train_predict.stack_predict, which writes predictions into img_path_small/<folder>.
# create_crop_idx is already imported in the first cell.
OVERLAP_FAC = 0.1        # fractional overlap between neighbouring tiles (create_crop_idx)
PREDICT_THRESHOLD = 0.5  # threshold passed through to stack_predict

for idx in trange(len(rawfdlist)):

    folder = rawfdlist[idx]

    pred_input_imgs = rawimglist_small[folder]
    if not pred_input_imgs:
        # an empty stack would make imread(pred_input_imgs[0]) below raise IndexError
        print('Skip {}: no images found'.format(folder))
        continue
    pred_output_path = os.path.join(img_path_small, folder)

    # The crop grid is derived from the first slice only.
    # NOTE(review): assumes every slice in a stack has the same shape -- TODO confirm.
    img = imread(pred_input_imgs[0])
    cropidx = create_crop_idx(img.shape, (IMG_HEIGHT, IMG_WIDTH), overlap_fac=OVERLAP_FAC)

    stack_predict(
        input_imgpath=pred_input_imgs,
        output_imgpath=pred_output_path,
        cropidx=cropidx,
        model=model,
        rescale=1. / 255.,
        patch_size=(IMG_HEIGHT, IMG_WIDTH),
        predict_threshold=PREDICT_THRESHOLD)


## Prediction on Whole Dataset with Tiles (whole stack)
### Create Folder

In [None]:
# Disabled (string literal, never executed): creates the pred_img/batch_02 output folder,
# with one sub-folder per raw stack, for the whole-stack prediction.
'''
# folder name
fdnm_whole = 'batch_02'
dir_checker(fdnm_whole, os.path.join(pred_path, 'pred_img'))
img_path_whole = os.path.join(pred_path, 'pred_img', fdnm_whole)

# create folder list
for folder in rawfdlist:
    dir_checker(folder, img_path_whole)
'''

### Tiling Prediction with Stack Input
- Crop each image into patches using a given overlap factor
- Export a cropping index
- Reconstruct the patches back into an image

In [None]:
# Disabled (string literal, never executed): same tiled prediction as the small-sample
# cell, but over every slice of each stack (rawimglist) and written into batch_02.
# The crop grid uses rawimglist_small[folder][0]. That is the same file as
# rawimglist[folder][0], because the small list is a prefix slice of the full list.
'''
from core.imageprep import create_crop_idx, crop_to_patch

for idx in trange(len(rawfdlist)):
    
    folder = rawfdlist[idx]
    
    pred_input_imgs =  rawimglist[folder]
    pred_output_path = os.path.join(pred_path, 'pred_img', fdnm_whole, folder)
    
    img = imread(rawimglist_small[folder][0])
    cropidx = create_crop_idx(img.shape, (IMG_HEIGHT, IMG_WIDTH), overlap_fac = 0.1)
    # print(cropidx)
    
    stack_predict(
                input_imgpath = pred_input_imgs, 
                output_imgpath = pred_output_path, 
                cropidx = cropidx, 
                model = model, 
                rescale = 1./255.,
                patch_size = (IMG_HEIGHT, IMG_WIDTH), 
                predict_threshold = 0.5)
'''