### Import Modules

In [None]:
import os
import pickle
import matplotlib.pyplot as plt
from skimage.io import imread
from keras.models import load_model
import keras_metrics
from skimage.util.montage import montage2d


from matplotlib import image
from PIL import Image
import numpy as np

from functions5 import dice_coef, dice_p_bce, batch_img_gen
# Absolute path of the current user's home directory; the model, pickle and image
# paths used throughout this notebook are built by appending to it.
HOME = os.path.expanduser("~")

### Load Network

In [None]:
model_path = HOME + "/new_project/Models/model15_full.h5"
# The model was saved with custom loss/metric objects, so Keras needs each one
# mapped by name to deserialise it. 'binary_recall' must map to the recall
# metric; mapping it to precision() silently reports precision under the
# "recall" name.
model = load_model(
    model_path,
    custom_objects={
        'dice_p_bce': dice_p_bce,
        'dice_coef': dice_coef,
        'binary_precision': keras_metrics.precision(),
        'binary_recall': keras_metrics.recall(),
    },
)

### Load Data

In [None]:
# Unpickle the image collection later handed to batch_img_gen as its image source.
# pickle.load can execute arbitrary code, so only trusted local files should be read here.
filepath = f"{HOME}/new_project/data/pickles/tif_train.pkl"
with open(filepath, mode='rb') as handle:
    tif_train = pickle.load(handle)

In [None]:
# Unpickle the mask collection paired with tif_train when building batches.
# As above, pickle.load is only safe on trusted local files.
filepath = f"{HOME}/new_project/data/pickles/mask_train.pkl"
with open(filepath, mode='rb') as handle:
    mask_train = pickle.load(handle)

### Predict on a Sample Batch (drawn from the training pickles)

In [None]:
# Generator of (images, masks) batches drawn from the training pickles; it is consumed
# with next() in the cell below. NOTE(review): the meaning of the leading 4 depends on
# batch_img_gen's signature (presumably a batch size) — confirm in functions5.
valid_gen = batch_img_gen(4, tif_train, mask_train)

In [None]:
# Pull one batch, cap it at 16 samples so the montages stay legible, report value
# ranges, and plot source / ground truth / thresholded prediction side by side.
batch_X, batch_y = next(valid_gen)
if batch_X.shape[0] > 16:
    batch_X = batch_X[:16]
    batch_y = batch_y[:16]

print('x', batch_X.shape, batch_X.dtype, batch_X.min(), batch_X.max())
print('y', batch_y.shape, batch_y.dtype, batch_y.min(), batch_y.max())
pred_y = model.predict(batch_X)


def montage_rgb(x):
    """Tile a 4-D (N, H, W, C) batch into one image by montaging each channel separately."""
    return np.stack([montage2d(x[:, :, :, i]) for i in range(x.shape[3])], -1)


fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(24, 8))
ax1.imshow(montage_rgb(batch_X))
ax1.set_title('Source Image')
ax2.imshow(montage2d(batch_y[:, :, :, 0]), cmap='Greys')
ax2.set_title('Ground Truth')
temp = montage2d(pred_y[:, :, :, 0])
# Binarise the first output channel at 0.8 (presumably sigmoid probabilities — confirm).
ax3.imshow(temp > 0.8, cmap='Greys')
ax3.set_title('Prediction')

In [None]:
# Output directory for prediction figures; the trailing slash lets a file name be appended directly.
save_path = f"{HOME}/new_project/images/predictions/"
# To export the three-panel figure (`fig`) into save_path, uncomment:
# fig.savefig(save_path + 'model15_pred2.png', dpi=300)

In [None]:
# Overlay ground truth (red channel) and thresholded prediction (green channel) in one
# RGB image. Assuming masks lie in [0, 1]: yellow = both, red = ground truth only,
# green = prediction only.
pred_y = model.predict(batch_X)

mask1 = montage2d(batch_y[:, :, :, 0])
temp = montage2d(pred_y[:, :, :, 0])
mask2 = temp > 0.8

save_path = HOME + '/new_project/images/overlays/'

# Size the RGB canvas from the montage's real (rows, cols) rather than assuming it is square.
joint = np.zeros(mask1.shape[:2] + (3,))
joint[:, :, 0] = mask1
joint[:, :, 1] = mask2
plt.figure(figsize=(20, 10))
plt.imshow(joint)
# plt.savefig(save_path + 'model15_over2.png', transparent=True)

### Load Disaster Data

In [None]:
# Pre-event disaster scene (cropped JPEG). Alternative source kept for reference:
# im2_path = HOME + '/new_project/data/disaster/worldview-2-iran-missile-facility-destroyed.jpg'
im1_path = f"{HOME}/new_project/data/disaster/processed/iran_pre_crop.jpg"

# PIL handle; converted to a NumPy array further down. Image.open is lazy, so the
# underlying file stays open until the pixel data is loaded or the image is closed.
im1 = Image.open(im1_path)
# im2 = Image.open(im2_path)

# Separate pixel-array read of the same file via matplotlib, used for a quick preview.
data = image.imread(im1_path)
# data2 = image.imread(im2_path)

# Report element type and shape (one per line), then show the image.
print(data.dtype, data.shape, sep='\n')
plt.imshow(data)
plt.show()

In [None]:
# plt.imshow(data2)
# plt.show()

In [None]:
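# NOTE(review): this disabled cell loads iran_post_zm.pkl (the post-event pickle) into `pre`; verify the intended pre-event file before re-enabling.
# filepath = HOME + '/new_project/data/pickles/iran_post_zm.pkl'
# with open(filepath, 'rb') as pkl:
#     pre = pickle.load(pkl)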

In [None]:
# Unpickle the post-event scene (array-like: its .shape is inspected below).
# pickle.load is only safe on trusted local files.
filepath = f"{HOME}/new_project/data/pickles/iran_post_zm.pkl"
with open(filepath, mode='rb') as fh:
    post = pickle.load(fh)

### Process Disaster Data

In [None]:
# Pixel array of the pre-event PIL image, so it can be grouped with `post` for batching.
a = np.asarray(im1)
# b = np.asarray(im2)

In [None]:
# A notebook cell only echoes its final expression, so a bare `a.shape` line is
# never displayed; print both shapes explicitly to compare the two scenes.
print(a.shape)
print(post.shape)

In [None]:
# Pre- and post-event scenes, used as the image list for batch_img_gen below.
disaster_list = [a, post]

In [None]:
# Generator over the two disaster scenes. mask_train is passed as the mask source even
# though it does not correspond to these images; the masks it yields are unpacked but
# never plotted downstream. NOTE(review): confirm batch_img_gen tolerates mismatched
# image/mask lists (e.g. how it pairs or crops them) and what the leading 2 controls.
disaster_gen = batch_img_gen(2, disaster_list, mask_train)

### Predict Disaster Data

In [None]:
# Run the model on one batch of disaster imagery and show source vs. thresholded prediction.
disaster_x, t_y = next(disaster_gen)
# t_y comes from mask_train (see the generator cell), so it is trimmed alongside
# disaster_x but not displayed.
if disaster_x.shape[0] > 16:
    disaster_x, t_y = disaster_x[:16], t_y[:16]

pred_y = model.predict(disaster_x)


def montage_rgb(batch):
    """Tile a 4-D (N, H, W, C) batch into one image by montaging each channel separately."""
    channels = [montage2d(batch[:, :, :, c]) for c in range(batch.shape[3])]
    return np.stack(channels, -1)


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(32, 16))
ax1.imshow(montage_rgb(disaster_x))
ax1.set_title('Source Image')
temp = montage2d(pred_y[:, :, :, 0])
# Binarise the first output channel at 0.8, matching the earlier prediction plots.
ax2.imshow(temp > 0.8, cmap='Greys')
ax2.set_title('Prediction')
# save_path here is whatever an earlier cell last assigned; set it explicitly before enabling:
# plt.savefig(str(save_path) + 'disaster.png', transparent=True)