In [46]:
import matplotlib
import os
import glob
import cmocean
import numpy as np
import matplotlib.pyplot as plt
from netCDF4 import Dataset
from matplotlib import ticker
from matplotlib.colors import BoundaryNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable
import cartopy
import cartopy.crs as ccrs
import cartopy.feature as cpf
from cartopy.feature import NaturalEarthFeature, LAND, COASTLINE

In [47]:
experiment = "SIC_Attention_Res_UNet"
lead_times = [0, 2, 4, 6, 8]
#
path_output = "/lustre/storeB/users/cyrilp/COSI/Figures/Article/"
#
sizefont = 28
#
map_proj = cartopy.crs.Stereographic(central_latitude = 0.0, central_longitude = 0.0)
LAND_highres = cpf.NaturalEarthFeature("physical", "land", "50m", edgecolor = "face", facecolor = "dimgrey", linewidth = 0.1)
map_extent = (-180, 180, 60, 90)
#
levels_RMSE = np.linspace(0, 25, 26)
norm_RMSE = BoundaryNorm(levels_RMSE, 256)
levels_RMSE_improvement = np.linspace(-40, 40, 41)
norm_RMSE_improvement = BoundaryNorm(levels_RMSE_improvement, 256)
colormap_RMSE = "gnuplot"
colormap_RMSE_improvement = "RdYlBu"
#
Nobs_min = 50
difference_threshold = 0
#
dataset = "ML"
list_references = ["Calib", "TOPAZ", "TOPAZ_bias_corrected", "Persistence", "Anomaly_Persistence"]

In [48]:
def get_stats_RMSE():
    Stats = {}
    Sum_squared_errors = {}
    N = np.zeros((len(lead_times), 544, 544))
    list_var = ["ML", "TOPAZ", "TOPAZ_bias_corrected", "Persistence", "Anomaly_Persistence"]
    #
    for var in list_var:
        Sum_squared_errors[var] = np.zeros((len(lead_times), 544, 544)) 
        Stats["RMSE_" + var] = np.zeros((len(lead_times), 544, 544))
    #
    for lt, leadtime in enumerate(lead_times):
        path_data = "/lustre/storeB/project/copernicus/cosi/WP3/Data/Predictions/" + experiment + "/lead_time_" + str(leadtime) + "_days/netCDF/"
        dataset = sorted(glob.glob(path_data + "Predictions_*.nc"))
        for fi, filename in enumerate(dataset):
            nc = Dataset(filename, "r")
            #
            if lt == 0 and fi == 0:
                Stats["lat"] = nc.variables["lat"][:,:]
                Stats["lon"] = nc.variables["lon"][:,:]
            #
            TARGET_SIC = nc.variables["TARGET_AMSR2_SIC"][:,:]
            SICobs = nc.variables["SICobs_AMSR2_SIC"][:,:]
            TOPAZ_SIC = nc.variables["TOPAZ_SIC"][:,:]
            Predicted_SIC = nc.variables["Predicted_SIC"][:,:]
            TOPAZbiascor = nc.variables["TOPAZ_bias_corrected"][:,:]
            Anomaly_persistence = nc.variables["Anomaly_persistence_SIC"][:,:]
            nc.close()
            #
            idx = np.logical_and(np.isnan(TARGET_SIC) == False, TARGET_SIC > 0)
            N[lt,:,:][idx == True] = N[lt,:,:][idx == True] + 1
            #
            Sum_squared_errors["ML"][lt,:,:] = Sum_squared_errors["ML"][lt,:,:] + (Predicted_SIC - TARGET_SIC)**2
            Sum_squared_errors["TOPAZ"][lt,:,:] = Sum_squared_errors["TOPAZ"][lt,:,:] + (TOPAZ_SIC - TARGET_SIC)**2
            Sum_squared_errors["TOPAZ_bias_corrected"][lt,:,:] = Sum_squared_errors["TOPAZ_bias_corrected"][lt,:,:] + (TOPAZbiascor - TARGET_SIC)**2
            Sum_squared_errors["Persistence"][lt,:,:] = Sum_squared_errors["Persistence"][lt,:,:] + (SICobs - TARGET_SIC)**2
            Sum_squared_errors["Anomaly_Persistence"][lt,:,:] = Sum_squared_errors["Anomaly_Persistence"][lt,:,:] + (Anomaly_persistence - TARGET_SIC)**2
        #
        for var in list_var:
            Stats["RMSE_" + var][lt,:,:] = np.sqrt(Sum_squared_errors[var][lt,:,:] / N[lt,:,:])
            Stats["RMSE_" + var][lt,:,:][N[lt,:,:] < Nobs_min] = np.nan
    #
    return(Stats)

In [49]:
def make_maps_RMSE_and_RMSE_improvement(Stats, dataset, references, saving = True):
    #
    caption_ticks = ["a)", "b)", "c)", "d)", "e)", "f)", "g)", "h)", "i)", "j)"]
    lab_dataset = dataset.replace("_", " ").replace("ML", "Calibrated forecasts").replace("TOPAZ", "TOPAZ4")
    #
    fig, big_axes = plt.subplots(figsize = (30, 33) , nrows = len(references), ncols = 1, sharey = True) 
    plt.rc('xtick', labelsize = sizefont)
    plt.rc('ytick', labelsize = sizefont)
    #
    for row, big_ax in enumerate(big_axes):
        big_ax.axis('off')
        if row == 0:
            big_ax.set_title(("RMSE Calibrated forecasts").replace("TOPAZ", "TOPAZ4").replace("_", " "), fontsize = sizefont * 1.2, pad = 30, fontweight = "bold")
            for lt, leadtime in enumerate(lead_times):
                axs = fig.add_subplot(len(references), len(lead_times), lt + 1, projection = map_proj)
                axs.set_extent(map_extent, crs = cartopy.crs.PlateCarree())
                axs.add_feature(LAND_highres, zorder = 1)
                cs = axs.pcolormesh(Stats["lon"], Stats["lat"], Stats["RMSE_" + dataset][lt, 0:543, 0:543], transform = ccrs.PlateCarree(), norm = norm_RMSE, cmap = colormap_RMSE, zorder = 0, shading = "flat")
                if leadtime == 0:
                    axs.set_title("Lead time: " + str(leadtime + 1) + " day", fontsize = sizefont)
                else:
                    axs.set_title("Lead time: " + str(leadtime + 1) + " days", fontsize = sizefont) 
            #
            cbar_ax = fig.add_axes([0.91, 0.74, 0.02, 0.15])
            cbar = fig.colorbar(cs, cax = cbar_ax, ticks = levels_RMSE[1:-1], extend = "both")
            cbar.set_label("RMSE (%)", fontsize = sizefont)
            tick_locator = ticker.MaxNLocator(nbins = 8)
            cbar.locator = tick_locator
            cbar.update_ticks()
        else:
            big_ax.set_title(("RMSE improvement compared to " + references[row]).replace("TOPAZ", "TOPAZ4").replace("_", " "), fontsize = sizefont * 1.2, pad = 30, fontweight = "bold")
            for lt, leadtime in enumerate(lead_times):
                RMSE_improvement = 100 * (1 - Stats["RMSE_" + dataset][lt, 0:543, 0:543] / Stats["RMSE_" + references[row]][lt, 0:543, 0:543])
                axs = fig.add_subplot(len(references), len(lead_times), row * len(lead_times) + lt + 1, projection = map_proj)
                axs.set_extent(map_extent, crs = cartopy.crs.PlateCarree())
                axs.add_feature(LAND_highres, zorder = 1)
                cs = axs.pcolormesh(Stats["lon"], Stats["lat"], RMSE_improvement, transform = ccrs.PlateCarree(), norm = norm_RMSE_improvement, cmap = colormap_RMSE_improvement, zorder = 0, shading = "flat")
                if leadtime == 0:
                    axs.set_title("Lead time: " + str(leadtime + 1) + " day", fontsize = sizefont)
                else:
                    axs.set_title("Lead time: " + str(leadtime + 1) + " days", fontsize = sizefont) 
            #
            cbar_ax = fig.add_axes([0.91, 0.74 - row * 0.16, 0.02, 0.15])
            cbar = fig.colorbar(cs, cax = cbar_ax, ticks = levels_RMSE_improvement[1:-1], extend = "both")
            cbar.set_label("RMSE improvement (%)", fontsize = sizefont)
            tick_locator = ticker.MaxNLocator(nbins = 8)
            cbar.locator = tick_locator
            cbar.update_ticks()
    #
    if saving == True:
        plt.savefig(path_output + "Maps_RMSE_and_RMSE_improvement_2022.png", dpi = 400, bbox_inches = "tight", transparent = False)
        plt.close()
    else:
        plt.show()

In [50]:
Stats_RMSE = get_stats_RMSE()
make_maps_RMSE_and_RMSE_improvement(Stats_RMSE, dataset, list_references, saving = True)

  Stats["RMSE_" + var][lt,:,:] = np.sqrt(Sum_squared_errors[var][lt,:,:] / N[lt,:,:])
  Stats["RMSE_" + var][lt,:,:] = np.sqrt(Sum_squared_errors[var][lt,:,:] / N[lt,:,:])
