In [1]:
from tensorflow.keras import models, optimizers, losses
from utils import *
from network import *
from metrics import *
from ipywidgets import interact
from matplotlib import pyplot as plt
import cv2

In [2]:
# Load every split produced by the project data loader from the dataset folder.
train_low, train_high, val_low, val_high, test_low = data_loader_v2('./interpolation/')

# Report the array shape of each split, one line per split.
for _label, _split in (
    ("Train X's shape : ", train_low),
    ("Train Y's shape : ", train_high),
    ("Validation X's shape : ", val_low),
    ("Validation Y's shape : ", val_high),
    ("Test X's shape : ", test_low),
):
    print(_label, _split.shape)

Train X's shape :  (303, 320, 256, 3)
Train Y's shape :  (303, 320, 256, 6)
Validation X's shape :  (25, 320, 256, 3)
Validation Y's shape :  (25, 320, 256, 6)
Test X's shape :  (328, 320, 256, 3)


In [3]:
# Directory holding the saved checkpoints for the 3-to-6 dense model.
root_path = './checkpoint/3to6/dense/'
model_path = os.path.join(root_path, 'model.h5')

# Directory entries in lexicographic order; the last entry is later used as
# the checkpoint to load.
# NOTE(review): if 'model.h5' is ever stored in this same directory it would
# sort after the digit-prefixed checkpoint names and become the last entry.
weights_list = os.listdir(root_path)
weights_list.sort()

['0100_128559.47.h5', '0200_74224.55.h5', '0300_57165.20.h5', '0400_49975.56.h5', '0500_42044.17.h5', '0600_35648.90.h5', '0700_33564.13.h5', '0800_26772.70.h5', '0900_30249.87.h5', '1000_22447.03.h5', '1100_20460.09.h5', '1200_21774.93.h5', '1300_21052.38.h5', '1400_18297.91.h5', '1500_18631.46.h5']


In [4]:
# Load the last checkpoint in the sorted listing without its saved training
# configuration; the model is compiled explicitly below.
# Alternative kept for reference: load the model from `model_path` and then
# call `A.load_weights(root_path + weights_list[5])` to pick a specific checkpoint.
A = models.load_model(root_path + weights_list[-1], compile=False)

# Two losses, one per model output: MSE (weighted 50) and binary
# cross-entropy (weighted 1).
A.compile(
    loss=[losses.mean_squared_error, losses.binary_crossentropy],
    loss_weights=[50, 1],
    optimizer=optimizers.Adam(lr=1e-4, epsilon=1e-8),
)

Instructions for updating:
Colocations handled automatically by placer.


In [5]:
# The model returns two outputs; only the first one is kept for evaluation.
train_, _ = A.predict(train_low)
val_, _ = A.predict(val_low)

# Show the shapes of both prediction arrays, one per line.
print(train_.shape, val_.shape, sep="\n")

(303, 320, 256, 6)
(25, 320, 256, 6)


$$
MSE = \frac{1}{n} \sum^n{(y-\hat{y})^2}
$$

$$
RMSE = \sqrt{MSE} = \sqrt{\frac{1}{n} \sum^n{(y-\hat{y})^2}}
$$

$$
RMSPE = \sqrt{\frac{1}{n} \sum^n{\left(\frac{y-\hat{y}}{y}\right)^2}}
$$

$$
PSNR = 10\log_{10}\left(\frac{MAX^2}{MSE}\right) = 20\log_{10}{MAX} - 10\log_{10}{MSE}
$$

$$
SSIM(x,y) = l(x,y) \cdot c(x,y) \cdot s(x,y)
$$

In [6]:
# Quality metrics on the training split, comparing ground truth with predictions.
# Metrics that only need (truth, prediction):
train_mse, train_rmse, train_rmspe = [
    _metric(train_high, train_) for _metric in (MSE, RMSE, RMSPE)
]
# Metrics that also receive 65535 (= 2**16 - 1) as third argument —
# presumably the peak intensity of the data; confirm against `metrics`.
train_psnr, train_ssim = [
    _metric(train_high, train_, 65535) for _metric in (PSNR, SSIM)
]

# For each metric: shape of the score array and its overall mean.
for _scores in (train_mse, train_rmse, train_rmspe, train_psnr, train_ssim):
    print(_scores.shape, _scores.mean())

(303, 6) 1264.6795439560594
(303, 6) 32.636808728096646
(303, 6) 1.474638255093552
(303, 6) 66.77439599780611
(303, 6) 0.9990183627100298


In [7]:
# Quality metrics on the validation split, comparing ground truth with predictions.
# Metrics that only need (truth, prediction):
val_mse, val_rmse, val_rmspe = [
    _metric(val_high, val_) for _metric in (MSE, RMSE, RMSPE)
]
# Metrics that also receive 65535 (= 2**16 - 1) as third argument —
# presumably the peak intensity of the data; confirm against `metrics`.
val_psnr, val_ssim = [
    _metric(val_high, val_, 65535) for _metric in (PSNR, SSIM)
]

# For each metric: shape of the score array and its overall mean.
for _scores in (val_mse, val_rmse, val_rmspe, val_psnr, val_ssim):
    print(_scores.shape, _scores.mean())

(25, 6) 1742.5762055426292
(25, 6) 39.38365392583606
(25, 6) 1.8855316345371715
(25, 6) 64.94743179710387
(25, 6) 0.9985600257896747


In [8]:
def plot_result(y_pred, y_true, mode=12):
    """Show an interactive browser comparing predictions with ground truth.

    Both inputs are 4-D arrays laid out as [Slices, Height, Width, Channels].
    The channel axis is moved in front of the spatial axes and folded into
    the leading axis, so the slider steps through every (slice, channel)
    image in order. Three panels are drawn for the selected image: the
    prediction, the ground truth, and their difference (true - pred).

    Parameters
    ----------
    y_pred : ndarray
        Predicted volume; its height and width define the image size.
    y_true : ndarray
        Ground-truth volume, reshaped with the same height and width.
    mode : {12, 6}, optional
        12 -- drop the first three and the last three channels before
        plotting; 6 -- plot all channels as they are.

    Raises
    ------
    ValueError
        If ``mode`` is not 6 or 12, or if no image is left to display after
        the channel selection.
    """
    # Reject unknown modes up front instead of failing later on an
    # unassigned local variable.
    if mode not in (6, 12):
        raise ValueError("mode must be 6 or 12, got %r" % (mode,))

    _, h, w, _ = y_pred.shape

    def _to_image_stack(volume):
        # [S, H, W, C] -> [S*C, H, W]; the channel count is inferred (-1)
        # rather than hard-coded.
        if mode == 12:
            volume = volume[..., 3:-3]
        return np.reshape(np.transpose(volume, [0, 3, 1, 2]), [-1, h, w])

    tmp_pred = _to_image_stack(y_pred)
    tmp_true = _to_image_stack(y_true)

    if len(tmp_pred) == 0:
        raise ValueError("nothing to plot: no images left after channel selection")

    max_idx = len(tmp_pred) - 1

    # Global colour limits, computed once so that every slider position uses
    # the same scale and nothing is recomputed on each redraw.
    true_max = tmp_true.max()
    err = tmp_true - tmp_pred
    err_min, err_max = err.min(), err.max()
    del err  # only the limits are needed; per-image errors are computed on demand

    def plot(idx=0):
        plt.figure(figsize=(12, 6))

        plt.subplot(131)
        plt.title("y_pred")
        plt.imshow(tmp_pred[idx], cmap='gray', vmin=0, vmax=true_max)
        plt.xlabel("%.2f"%(tmp_pred[idx].max()))

        plt.subplot(132)
        plt.title("y_true")
        # Same intensity window as the prediction panel so the two images
        # are directly comparable.
        plt.imshow(tmp_true[idx], cmap='gray', vmin=0, vmax=true_max)
        plt.xlabel("%.2f"%(tmp_true[idx].max()))

        plt.subplot(133)
        plt.title("Error (true - pred)")
        plt.imshow(tmp_true[idx]-tmp_pred[idx], cmap='gray', vmin=err_min, vmax=err_max)

    interact(plot, idx=(0, max_idx, 1))

In [9]:
# Browse validation predictions against ground truth, using all channels.
plot_result(y_pred=val_, y_true=val_high, mode=6)

interactive(children=(IntSlider(value=0, description='idx', max=149), Output()), _dom_classes=('widget-interac…

In [11]:
# Browse training predictions against ground truth, using all channels.
plot_result(y_pred=train_, y_true=train_high, mode=6)

interactive(children=(IntSlider(value=0, description='idx', max=1817), Output()), _dom_classes=('widget-intera…