In [None]:
import os
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import utilities as u
import subpixellayer as sl # Is important for the subpixel layer

from keras.models import load_model
from matplotlib.gridspec import GridSpec
from matplotlib_scalebar.scalebar import ScaleBar

# TensorFlow runtime options:
# - TF_CPP_MIN_LOG_LEVEL=2 hides TensorFlow's C++ info/warning log messages.
# - allow_growth lets the GPU allocate memory on demand instead of reserving it upfront.
# NOTE(review): the env var is assigned after `import tensorflow` (top cell), so messages
# emitted while TensorFlow was being imported are not filtered; consider setting it
# before that import.
# NOTE(review): under TF2 eager execution Keras does not automatically use this v1
# Session; confirm allow_growth takes effect (TF2 route:
# tf.config.experimental.set_memory_growth).
os.environ["TF_CPP_MIN_LOG_LEVEL"] = '2'
config = tf.compat.v1.ConfigProto(
    gpu_options=tf.compat.v1.GPUOptions(allow_growth=True)
)
session = tf.compat.v1.Session(config=config)

In [None]:
# Sub-directory of "Models/" holding the trained networks, i.e. the network
# architecture (e.g. DataAugmented, OriginalFukami, ...)
model_dir = "Subpixel"

# Sample identifier used in the data file names
sample_num = 11400

# Variables to reconstruct; all per-variable lists below follow this order
variables = ["Hs", "Tm02", "Dir"]
convert = True  # Shift the direction data before/after the prediction?
nan_thresholds = [0.01, 0.5, 0.05]  # Predicted values below these are treated as NaN

# Save the figure?
save = False

# One model, one low-resolution input and one high-resolution reference per variable
fdir = "Models/{}".format(model_dir)
models, LRs, HRs = [], [], []

for var in variables:
    fmodel = os.path.join(fdir, "Model_Inp_{}.hdf5".format(var))
    models.append(load_model(fmodel))

    fn_LR = "Data/LR/{}/BaskCoast_{}_{}.npy".format(var, var.upper(), sample_num)
    fn_HR = "Data/HR/{}/BaskCoast_{}_{}.npy".format(var, var.upper(), sample_num)
    LRs.append(np.load(fn_LR))
    HRs.append(np.load(fn_HR))

In [None]:
# Direction data are rotated by this many degrees before the prediction and rotated
# back afterwards (only when `convert` is True).
DIR_SHIFT = 105
# Nearest-neighbour upsampling factor used to display the LR input on the HR grid
UPSAMPLING_FACTOR = 16


def predict_hr(model, lr, nan_threshold, dir_shift=None):
    """Predict a high-resolution field from one low-resolution field.

    Parameters
    ----------
    model : Keras model
        Network fed with a (1, 10, 10, 1) array; its output is squeezed to drop
        singleton axes (presumably leaving a 2-D field -- confirm).
    lr : np.ndarray
        Low-resolution field with 100 values (reshaped to 10 x 10). NaNs are replaced
        by 0 for the network input.
    nan_threshold : float
        Predicted values below this threshold are set to NaN.
    dir_shift : float or None
        If given, the input is rotated by ``dir_shift`` degrees (mod 360) before the
        prediction and the prediction is rotated back afterwards (after the NaN
        threshold has been applied).

    Returns
    -------
    np.ndarray
        The squeezed prediction, NaN where values fall below ``nan_threshold``.
    """
    if dir_shift is None:
        lr_inp = np.nan_to_num(lr)
    else:
        lr_inp = np.nan_to_num((lr + dir_shift) % 360)
    # assumes a 10 x 10 LR grid (the models' input size) -- see the reshape
    prediction = np.squeeze(model.predict(np.reshape(lr_inp, (1, 10, 10, 1))))
    prediction[prediction < nan_threshold] = np.nan
    if dir_shift is not None:
        # Undo the input rotation: + (360 - shift) is equivalent to - shift (mod 360)
        prediction = (prediction + (360 - dir_shift)) % 360
    return prediction


def field_difference(reference, prediction, circular=False):
    """Return ``reference - prediction`` element-wise.

    For angles in degrees (``circular=True``) the difference is wrapped into
    [-180, 180), so 359 deg vs. 1 deg yields -2 instead of 358. NaNs propagate.
    """
    diff = reference - prediction
    if circular:
        diff = (diff + 180) % 360 - 180
    return diff


assert len(models) == len(LRs) == len(HRs) == len(variables) == len(nan_thresholds), \
    "models, LRs, HRs and nan_thresholds need exactly one entry per variable"

predictions = []
maes = []
diffs = []
# Upsampled copies for plotting; LRs itself is left untouched so this cell can be re-run
LRs_upsampled = []

for name, model, LR, HR, threshold in zip(variables, models, LRs, HRs, nan_thresholds):
    is_dir = name == "Dir"
    shift = DIR_SHIFT if (is_dir and convert) else None

    prediction = predict_hr(model, LR, threshold, dir_shift=shift)
    diff = field_difference(HR, prediction, circular=is_dir)

    predictions.append(prediction)
    diffs.append(diff)
    # Mean *absolute* error over pixels where both fields are defined
    maes.append(np.nanmean(np.abs(diff)))
    LRs_upsampled.append(u.NNUpsampling(LR, factor=UPSAMPLING_FACTOR))

# Figure rows: LR input (upsampled), NN prediction, HR reference
data = [LRs_upsampled, predictions, HRs]

# The next cell plots every variable in its own column.

In [None]:
# Geographic extent of the plotted domain in degrees. The tick labels below are written
# by hand to match it; these four values are not used by the plotting code itself.
minlon = -1.617
maxlon = -1.518

minlat = 43.46
maxlat = 43.53

# Hand-formatted tick labels (every second tick is left unlabelled)
x = ["-1.60°E", "", "-1.56°E", "", "-1.52°E"]
y = ["43.46°N", "", "43.49°N", "", "43.52°N"]

# Pixel positions of the ticks (one per label above)
x_loc = np.arange(30, 160, 30)
y_loc = np.arange(10, 160, 30)

# Font sizes and tick geometry
CLABELSIZE = 17
CTICKLABELSIZE = 15
LABELSIZE = 18
TICKLABELSIZE = 15
LENGTH = 5.0
WIDTH = 1.0

# Colour range per column: by default the range of the NN prediction ...
vmin = [np.nanmin(pred) for pred in predictions]
vmax = [np.nanmax(pred) for pred in predictions]
# ... with hand-picked lower limits for T_m02 and Dir
# (original note: "Problem with colorbar range for mean wave direction")
vmin[1] = 7.5
vmin[2] = 280

# Colorbar axes geometry in figure coordinates ([left, bottom, width, height])
CB_BOTTOM = 0.32        # colorbars below the reconstruction grid
DIFF_CB_BOTTOM = 0.09   # colorbars below the difference maps
CB_WIDTH = 0.2
CB_HEIGHT = 0.01
CB_LEFT = [0.15, 0.41, 0.67]

# Various labels
clabels = [r"$H_S$ [m]", r"$T_{m02}$ [s]", "$Dir$ [°]"]
ylabels = ["Low-Resolution Input", "Neural Network Prediction",
           "High-Resolution Reference"]
titles = [r"Significant Wave Height $H_S$", r"Mean Wave Period $T_{m02}$",
          r"Mean Wave Direction $Dir$"]


def _style_ticks(ax, label_x, label_y):
    """Put ticks at (x_loc, y_loc) on the bottom/right edges of `ax`; show the
    hand-formatted labels only where requested (outermost row / column)."""
    ax.tick_params(axis='both', left=False, right=True, bottom=True,
                   labelleft=False, labelright=label_y, labelbottom=label_x,
                   labelsize=TICKLABELSIZE, length=LENGTH, width=WIDTH)
    ax.set_xticks(x_loc, labels=x if label_x else None)
    ax.set_yticks(y_loc, labels=y if label_y else None)


def _add_colorbar(fig, mappable, ncol, bottom):
    """Add a horizontal colorbar for column `ncol` at height `bottom` (figure coords)."""
    cax = fig.add_axes([CB_LEFT[ncol], bottom, CB_WIDTH, CB_HEIGHT])
    cb = fig.colorbar(mappable, cax=cax, orientation="horizontal")
    cb.set_label(label=clabels[ncol], size=CLABELSIZE)
    cax.tick_params(labelsize=CTICKLABELSIZE)
    return cb


width = 18
fig = plt.figure(figsize=(width, width * 4 / 3))
# No horizontal gap between panels; set once, before any axes are created
fig.subplots_adjust(wspace=0)

# Outer 4 x 3 grid: rows 0-2 hold the tightly packed reconstruction sub-grid,
# row 3 the difference maps
gs = GridSpec(4, 3, hspace=0.4, height_ratios=[1, 1, 1, 1.2])
subg = gs[0:3, :].subgridspec(3, 3, hspace=0.02)

last_row = len(data) - 1
last_col = len(variables) - 1

# Reconstructions: one row per data source, one column per variable
for nrow, row_fields in enumerate(data):
    for ncol, field in enumerate(row_fields):
        ax = fig.add_subplot(subg[nrow, ncol])
        pos = ax.imshow(field, cmap="jet", origin="lower",
                        vmin=vmin[ncol], vmax=vmax[ncol])
        _style_ticks(ax, label_x=(nrow == last_row), label_y=(ncol == last_col))

        # NN prediction row: annotate with the per-variable MAE
        if nrow == 1:
            ax.text(99, 8, 'MAE: {:.1e}'.format(maes[ncol]),
                    bbox={'facecolor': 'white', 'pad': 10},
                    fontdict={'size': 15})

        # All panels of a column share vmin/vmax, so one colorbar per column suffices
        if nrow == last_row:
            _add_colorbar(fig, pos, ncol, CB_BOTTOM)

        # ylabels only in the first column, titles only in the first row
        if ncol == 0:
            ax.set_ylabel(ylabels[nrow], size=LABELSIZE)
        if nrow == 0:
            ax.set_title(titles[ncol], size=LABELSIZE)

# Difference maps (reference minus NN prediction)
for ncol, diff in enumerate(diffs):
    ax = fig.add_subplot(gs[3, ncol])

    # Scale bar: one pixel corresponds to 0.05 km; draw a fixed 2 km bar
    scalebar = ScaleBar(0.05, "km", location="lower right", fixed_value=2,
                        scale_loc="top", font_properties={"size": 15})
    ax.add_artist(scalebar)

    # Hide erroneous large positive Dir differences; masks a copy so `diffs`
    # from the previous cell is left unchanged
    if variables[ncol] == "Dir":
        diff = np.where(diff > 50, np.nan, diff)

    pos = ax.imshow(diff, cmap="jet", origin="lower")
    _add_colorbar(fig, pos, ncol, DIFF_CB_BOTTOM)
    _style_ticks(ax, label_x=True, label_y=(ncol == last_col))

    if ncol == 0:
        ax.set_ylabel("Difference of Reference & NN", size=LABELSIZE)

# Folder and filename if the plot is to be saved. The figure shows all variables,
# so all of them go into the name (not a loop variable leaked from an earlier cell).
fdir = "Plots"
fname = "MainPlot_sample_{}_var_{}_model_{}.png"
fname = fname.format(sample_num, "-".join(variables), model_dir)

if save:
    os.makedirs(fdir, exist_ok=True)
    fig.savefig(os.path.join(fdir, fname), facecolor="white", bbox_inches="tight")

plt.show()