In [1]:
from tensorflow.keras import models, optimizers, losses
from utils import *
from loss import *
from network import *
from metrics import *
from ipywidgets import interact
from matplotlib import pyplot as plt
import cv2

In [None]:
def plot_result(y_pred, y_true, mode=12):
    # img : [Slices, Height, Width, 12]
    b, h, w, _ = y_pred.shape
    
    if mode == 12:
        tmp_pred = y_pred[...,3:-3]
        tmp_pred = np.transpose(tmp_pred, [0, 3, 1, 2])
        tmp_pred = np.reshape(tmp_pred, [b*6, h, w])

        tmp_true = y_true[...,3:-3]
        tmp_true = np.transpose(tmp_true, [0, 3, 1, 2])
        tmp_true = np.reshape(tmp_true, [b*6, h, w])
    
    elif mode ==6:
        tmp_pred = np.transpose(y_pred, [0, 3, 1, 2])
        tmp_pred = np.reshape(tmp_pred, [b*6, h, w])
        
        tmp_true = np.transpose(y_true, [0, 3, 1, 2])
        tmp_true = np.reshape(tmp_true, [b*6, h, w])
        
        
    max_idx = len(tmp_pred)-1
    err = tmp_true-tmp_pred
    err_min = (tmp_true-tmp_pred).min()
    err_max = (tmp_true-tmp_pred).max()
    
    def plot(idx=0):
        plt.figure(figsize=(12, 6))
        plt.subplot(131)
        plt.title("y_pred")
        plt.imshow(tmp_pred[idx], cmap='gray', vmin=0, vmax=tmp_true.max())
        plt.xlabel("%.2f"%(tmp_pred[idx].max()))
        plt.subplot(132)
        plt.title("y_true")
        plt.imshow(tmp_true[idx], cmap='gray', vmax=tmp_true.max())
        plt.xlabel("%.2f"%(tmp_true[idx].max()))
        plt.subplot(133)
        plt.title("Error (true - pred)")
        plt.imshow(err[idx], cmap='gray', vmin = err_min, vmax=err_max)
        #plt.xlabel("%.2f ~ %.2f"%(err[idx].min(), err[idx].max()))
    interact(plot, idx=(0, max_idx, 1))

In [2]:
train_low, train_high, val_low, val_high, test_low = data_loader_v3('./interpolation/')
print("Train X's shape : ", train_low.shape)
print("Train Y's shape : ", train_high.shape)
print("Validation X's shape : ", val_low.shape)
print("Validation Y's shape : ", val_high.shape)
print("Test X's shape : ", test_low.shape)


Train X's shape :  (303, 320, 256, 3)
Train Y's shape :  (303, 320, 256, 12)
Validation X's shape :  (25, 320, 256, 3)
Validation Y's shape :  (25, 320, 256, 12)
Test X's shape :  (320, 320, 256, 3)


In [15]:
root_path = './checkpoint/3to12/dense_L1/'
model_path = root_path+'model.h5'

weights_list = sorted(os.listdir(root_path))
weights_list = [i for i in weights_list if '.h5' in i]
weights_list.pop(-1)
#print(weights_list)

'model.h5'

In [16]:
A = models.load_model(model_path, compile=False)
A.load_weights(root_path+weights_list[-1])
A.compile(optimizer=optimizers.Adam(lr=0.0001, epsilon=1e-8),
          loss=[Custom_MSE, losses.binary_crossentropy], 
          loss_weights=[50, 1])

In [17]:
train_, _ = A.predict(train_low)
val_, _ = A.predict(val_low)
print(train_.shape)
print(val_.shape)

(303, 320, 256, 12)
(25, 320, 256, 12)


$$
MSE = \frac{1}{n} \sum^n{(y-\hat{y})^2}
$$

$$
RMSE = \sqrt{MSE} = \sqrt{\frac{1}{n} \sum^n{(y-\hat{y})^2}}
$$

$$
RMSPE = \sqrt{\frac{1}{n} \sum^n({\frac{y-\hat{y}}{y}})^2}
$$

$$
PSNR = 10\log_{10}(\frac{MAX}{\sqrt{MSE}}) = 20\log_{10}{MAX} - 10\log_{10}\sqrt{MSE}
$$

$$
SSIM = I(x,y)C(x,y)S(x,y)
$$

In [18]:
train_mae = MAE(train_high[...,3:-3], train_[...,3:-3])
train_mse = MSE(train_high[...,3:-3], train_[...,3:-3])
train_rmse = RMSE(train_high[...,3:-3], train_[...,3:-3])
train_rmspe = RMSPE(train_high[...,3:-3], train_[...,3:-3])
train_psnr = PSNR(train_high[...,3:-3], train_[...,3:-3], 65535)
train_ssim = SSIM(train_high[...,3:-3], train_[...,3:-3], 65535)

print(train_mae.shape, train_mae.mean())
print(train_mse.shape, train_mse.mean())
print(train_rmse.shape, train_rmse.mean())
print(train_rmspe.shape, train_rmspe.mean())
print(train_psnr.shape, train_psnr.mean())
print(train_ssim.shape, train_ssim.mean())

(303, 6) 33.98999570365423
(303, 6) 5320.14926935723
(303, 6) 69.56844863046261
(303, 6) 3.1230199652967565
(303, 6) 59.911784310505936
(303, 6) 0.9964254193247921


In [19]:
val_mae = MAE(val_high[...,3:-3], val_[...,3:-3])
val_mse = MSE(val_high[...,3:-3], val_[...,3:-3])
val_rmse = RMSE(val_high[...,3:-3], val_[...,3:-3])
val_rmspe = RMSPE(val_high[...,3:-3], val_[...,3:-3])
val_psnr = PSNR(val_high[...,3:-3], val_[...,3:-3], 65535)
val_ssim = SSIM(val_high[...,3:-3], val_[...,3:-3], 65535)

print(val_mae.shape, val_mae.mean())
print(val_mse.shape, val_mse.mean())
print(val_rmse.shape, val_rmse.mean())
print(val_rmspe.shape, val_rmspe.mean())
print(val_psnr.shape, val_psnr.mean())
print(val_ssim.shape, val_ssim.mean())

(25, 6) 31.278478686788404
(25, 6) 4156.903864137075
(25, 6) 62.17177322807019
(25, 6) 3.4606491075809718
(25, 6) 60.7756308484864
(25, 6) 0.9969909700011051


In [28]:
print(train_mae.argmin(), train_mae.min())
print(train_mse.argmin(), train_mse.min())
print(train_rmse.argmin(), train_rmse.min())

print(val_mae.argmin(), val_mae.min())
print(val_mse.argmin(), val_mse.min())
print(val_rmse.argmin(), val_rmse.min())

1190 12.855120694154085
1190 704.5527041233876
1190 26.543411689596113
93 18.806085026464462
92 1304.9834736414628
92 36.124554995756874


In [27]:
train_mae.shape

(303, 6)

In [26]:
val_mae.shape

(25, 6)

In [21]:
plot_result(train_, train_high, 12)

interactive(children=(IntSlider(value=0, description='idx', max=1817), Output()), _dom_classes=('widget-intera…

In [22]:
plot_result(val_, val_high, 12)

interactive(children=(IntSlider(value=0, description='idx', max=149), Output()), _dom_classes=('widget-interac…