In [1]:
from tensorflow.keras import models, optimizers, losses
from utils import *
from network import *
from metrics import *
from ipywidgets import interact
from matplotlib import pyplot as plt
import cv2

In [14]:
# Load the train / validation / test arrays through the project's loader.
train_low, train_high, val_low, val_high, test_low = data_loader_v3('./interpolation/')

# Report the shape of every split; labels are kept exactly as before so the
# printed output is unchanged.
for label, split in (
    ("Train X's shape : ", train_low),
    ("Train Y's shape : ", train_high),
    ("Validation X's shape : ", val_low),
    ("Validation Y's shape : ", val_high),
    ("Test X's shape : ", test_low),
):
    print(label, split.shape)


Train X's shape :  (303, 320, 256, 3)
Train Y's shape :  (303, 320, 256, 12)
Validation X's shape :  (25, 320, 256, 3)
Validation Y's shape :  (25, 320, 256, 12)
Test X's shape :  (342, 320, 256, 3)


In [9]:
# Locate the saved model file and the checkpoint weight files next to it.
# NOTE(review): `os` is not imported explicitly in this notebook; presumably it
# arrives through one of the star imports -- confirm.
root_path = './checkpoint/3to12/dense_L1/'
model_name = 'model.h5'
model_path = root_path+model_name

# Keep only checkpoint weight files: names that END in '.h5', excluding the
# model file itself. Filtering by name (rather than popping the last sorted
# entry) stays correct when the model file is absent or when some other file
# name happens to sort after it.
weights_list = sorted(
    name for name in os.listdir(root_path)
    if name.endswith('.h5') and name != model_name
)

'model.h5'

In [10]:
# Display the sorted checkpoint file names; the last entry is the one loaded
# into the model in the next cell.
weights_list

['0100_20.99_20.34.h5',
 '0200_16.75_25.59.h5',
 '0300_15.61_19.31.h5',
 '0400_13.56_21.44.h5',
 '0500_12.86_23.15.h5',
 '0600_12.22_23.05.h5',
 '0700_11.77_19.94.h5',
 '0800_11.62_22.15.h5',
 '0900_11.25_19.48.h5',
 '1000_10.59_18.97.h5',
 '1100_10.75_18.97.h5',
 '1200_10.49_19.18.h5',
 '1300_10.09_21.77.h5',
 '1400_10.04_19.33.h5',
 '1500_9.90_19.83.h5']

In [11]:
# Rebuild the network from the saved model file (without its compile state),
# then overwrite its weights with the last checkpoint of the sorted list.
A = models.load_model(model_path, compile=False)
latest_weights = root_path+weights_list[-1]
A.load_weights(latest_weights)

# Compile with two losses weighted 50:1 (mean squared error first, binary
# cross-entropy second).
adam = optimizers.Adam(lr=0.0001, epsilon=1e-8)
A.compile(
    optimizer=adam,
    loss=[losses.mean_squared_error, losses.binary_crossentropy],
    loss_weights=[50, 1],
)

Instructions for updating:
Colocations handled automatically by placer.


In [12]:
# The model returns two outputs; only the first one is kept for evaluation.
train_, _ = A.predict(train_low)
val_, _ = A.predict(val_low)

# Show the shape of each kept prediction array.
for predicted in (train_, val_):
    print(predicted.shape)

(303, 320, 256, 12)
(25, 320, 256, 12)


$$
MSE = \frac{1}{n} \sum^n{(y-\hat{y})^2}
$$

$$
RMSE = \sqrt{MSE} = \sqrt{\frac{1}{n} \sum^n{(y-\hat{y})^2}}
$$

$$
RMSPE = \sqrt{\frac{1}{n} \sum^n({\frac{y-\hat{y}}{y}})^2}
$$

$$
PSNR = 20\log_{10}\left(\frac{MAX}{\sqrt{MSE}}\right) = 20\log_{10}{MAX} - 10\log_{10}{MSE}
$$

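$$
SSIM(x,y) = l(x,y) \cdot c(x,y) \cdot s(x,y)
$$

where $l$, $c$ and $s$ are the luminance, contrast and structure comparison terms.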

In [15]:
# Evaluate the training predictions against the training targets.
# NOTE(review): 65535 is presumably the MAX (peak) value of the PSNR formula,
# yet train_high.max() shown later in this notebook is about 1963 -- confirm
# that this dynamic range is the intended one.
train_mse = MSE(train_high, train_)
train_rmse = RMSE(train_high, train_)
train_rmspe = RMSPE(train_high, train_)
train_psnr = PSNR(train_high, train_, 65535)
train_ssim = SSIM(train_high, train_, 65535)

# One line per metric: array shape followed by the overall mean.
for metric in (train_mse, train_rmse, train_rmspe, train_psnr, train_ssim):
    print(metric.shape, metric.mean())

(303, 12) 6456.025211499268
(303, 12) 73.36827672959285
(303, 12) 6.420254516078681
(303, 12) 59.7958347332202
(303, 12) 0.9949318672606428


In [16]:
# Evaluate the validation predictions against the validation targets, using
# the same metric calls and the same 65535 third argument as the training cell.
val_mse = MSE(val_high, val_)
val_rmse = RMSE(val_high, val_)
val_rmspe = RMSPE(val_high, val_)
val_psnr = PSNR(val_high, val_, 65535)
val_ssim = SSIM(val_high, val_, 65535)

# One line per metric: array shape followed by the overall mean.
for metric in (val_mse, val_rmse, val_rmspe, val_psnr, val_ssim):
    print(metric.shape, metric.mean())

(25, 12) 4874.320013251561
(25, 12) 64.89813874813738
(25, 12) 6.201660870348653
(25, 12) 60.734977396240474
(25, 12) 0.9958427578510829


In [17]:
# Largest value in the training targets (the "Train Y" array).
train_high.max()

1963.24402359128

In [18]:
# Largest value in the model's predictions for the training inputs.
train_.max()

2403.725

In [19]:
def plot_result(y_pred, y_true, mode=12):
    """Interactively browse predicted vs. ground-truth images and their error.

    Every kept channel of every sample becomes one 2-D image. An ipywidgets
    slider selects which image is drawn; three panels are shown side by side:
    the prediction, the ground truth and the signed error (true - pred).

    Args:
        y_pred: 4-D array indexed as [Slices, Height, Width, Channels].
        y_true: array with the same layout as ``y_pred``.
        mode: 12 drops the first three and last three channels before
            plotting; 6 plots every channel as is.

    Raises:
        ValueError: if ``mode`` is neither 12 nor 6.
    """
    # Validate up front: an unknown mode used to surface later as an unrelated
    # "referenced before assignment" error.
    if mode == 12:
        channels = slice(3, -3)
    elif mode == 6:
        channels = slice(None)
    else:
        raise ValueError("mode must be 12 or 6, got %r" % (mode,))

    _, h, w, _ = y_pred.shape

    def to_image_stack(volume):
        # [S, H, W, C] -> [S, C, H, W] -> [S*C, H, W]; -1 lets numpy derive the
        # image count from the channels that were actually kept.
        channel_first = np.transpose(volume[..., channels], [0, 3, 1, 2])
        return np.reshape(channel_first, [-1, h, w])

    tmp_pred = to_image_stack(y_pred)
    tmp_true = to_image_stack(y_true)

    max_idx = len(tmp_pred)-1

    # Global colour limits are computed once so that every redraw shares the
    # same scale. The full error volume is released right away; only the
    # per-image error is recomputed when drawing.
    err = tmp_true-tmp_pred
    err_min = err.min()
    err_max = err.max()
    del err
    true_max = tmp_true.max()

    def plot(idx=0):
        plt.figure(figsize=(12, 6))
        plt.subplot(131)
        plt.title("y_pred")
        plt.imshow(tmp_pred[idx], cmap='gray', vmin=0, vmax=true_max)
        plt.xlabel("%.2f"%(tmp_pred[idx].max()))
        plt.subplot(132)
        plt.title("y_true")
        # Same [0, true_max] window as the prediction panel, so the two
        # grayscale images are directly comparable.
        plt.imshow(tmp_true[idx], cmap='gray', vmin=0, vmax=true_max)
        plt.xlabel("%.2f"%(tmp_true[idx].max()))
        plt.subplot(133)
        plt.title("Error (true - pred)")
        plt.imshow(tmp_true[idx]-tmp_pred[idx], cmap='gray', vmin=err_min, vmax=err_max)
    interact(plot, idx=(0, max_idx, 1))

In [21]:
# Browse the training predictions against the training targets, passing the
# mode explicitly by keyword.
plot_result(train_, train_high, mode=12)

interactive(children=(IntSlider(value=0, description='idx', max=1817), Output()), _dom_classes=('widget-intera…

In [12]:
# Same viewer again, this time relying on plot_result's default mode (12).
plot_result(train_, train_high)

interactive(children=(IntSlider(value=0, description='idx', max=1817), Output()), _dom_classes=('widget-intera…