In [2]:
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.patches as mpatches
import matplotlib.colors as mcolors
import matplotlib.gridspec as gridspec

ERROR 1: PROJ: proj_create_from_database: Open of /nesi/project/niwa00018/queenle/ml_env_v2/share/proj failed


In [3]:
# Root of the temporal-sampling experiment tree; every directory below hangs off it.
experiment_root = '/nesi/project/niwa00018/queenle/ML_emulator_temporal_sampling_experiments'

# Inference products read by the plotting cells (emulator output and its metrics).
output_dir = f'{experiment_root}/inference/output'
metric_dir = f'{output_dir}/metrics'

# Destinations for the figures written by the plotting cells.
plot_dir = f'{experiment_root}/plotting/plots/maps'
final_figure_dir = f'{experiment_root}/plotting/final_figures'

# CCAM target fields, opened once here and indexed by GCM/time in the EVENTS cell.
ccam_ds = xr.open_dataset(
    '/nesi/project/niwa00018/ML_downscaling_CCAM/multi-variate-gan/inputs/'
    'target_fields/target_fields_hist_ssp370_concat.nc'
)

In [4]:
def create_figure(fig_width=14, fig_height=16):
    """Build the 11-panel map layout shared by the climatology and CC-signal cells.

    The figure uses a 5-row x 4-column grid: grid columns 0 and 1 hold a
    5 x 2 block of small map panels, and grid columns 2-3 hold one larger
    panel spanning rows 1-3 (0-based).

    Parameters
    ----------
    fig_width, fig_height : figure size in inches, passed to ``plt.figure``.

    Returns
    -------
    fig : the created figure.
    axes : list of six items -- five ``[left, right]`` pairs (one per grid
        row, top to bottom) followed by the single large right-hand axis.
    """
    fig = plt.figure(figsize=(fig_width, fig_height))

    # 5 rows x 4 columns; the two right-most columns are merged for the big panel.
    grid = gridspec.GridSpec(5, 4, wspace=0.1, hspace=0.05)

    def _add_map_panel(cell):
        # Every panel gets the same projection, map extent and 10m coastlines.
        panel = fig.add_subplot(cell, projection=ccrs.PlateCarree(central_longitude=180))
        panel.set_extent([164.86, 183.9646, -51.21329, -32.86], crs=ccrs.PlateCarree())
        panel.coastlines(resolution='10m')
        return panel

    # Small panels, created row by row (left panel first within each row).
    paired_rows = [[_add_map_panel(grid[row, col]) for col in range(2)]
                   for row in range(5)]

    # Large panel: rows 1-3 of the two right-most grid columns.
    right_panel = _add_map_panel(grid[1:4, 2:])

    return fig, paired_rows + [right_panel]

#fig, axs = create_figure()
#plt.show()


In [114]:
'''
CLIMOS
'''

def plot_map(ax,gcm,sampling_n,framework,ml_type,epoch,metric,label,cmap):
    """Draw one ML-emulator precipitation-metric map on ``ax``.

    The variable ``f'{metric}_{period}'`` is read from the metrics file selected
    by ``gcm``, ``sampling_n``, ``framework``, ``ml_type`` and ``epoch``, plotted
    with a colour range of ``0..v_val`` and annotated with ``label`` and the
    field maximum.

    Besides its arguments, this function reads the module-level names
    ``metric_dir``, ``period``, ``v_val`` and ``extent``; the calling loop must
    have set them before the call.

    Parameters
    ----------
    ax : cartopy map axis to draw on.
    gcm, sampling_n, framework, ml_type, epoch : components of the metrics file path.
    metric : metric name; combined with the global ``period`` to form the variable name.
    label : panel label printed in front of the maximum value.
    cmap : colormap passed to the plot call.

    Returns
    -------
    The artist returned by ``DataArray.plot`` (used by the caller for the colorbar).
    """
    metrics_path = (f'{metric_dir}/{gcm}/pr_ACCESS-CM2_{sampling_n}/'
                    f'{gcm}_{framework}_pr_metrics_{ml_type}_epoch_{epoch}.nc')

    # Read the field into memory and release the file straight away: this
    # function is called once per panel inside nested loops, so leaving every
    # dataset open accumulates file handles.
    with xr.open_dataset(metrics_path) as ds:
        da = ds[f'{metric}_{period}'].load()

    # Draw on the axis that was passed in (previously the global axs[i][j] was
    # used, which silently ignored the ``ax`` argument).
    im = da.plot(ax=ax,vmin=0,vmax=v_val,add_colorbar=False,cmap=cmap,transform=ccrs.PlateCarree())

    max_val = da.max().values.tolist()
    ax.text(0.5,0.1, f'{label} max = {max_val:.0f}', va='center', ha='center',rotation='horizontal', transform=ax.transAxes,fontsize=20)

    ax.set_extent(extent, crs=ccrs.PlateCarree())
    ax.coastlines(resolution='10m')
    # xarray sets an automatic title; the figure labels its rows/columns itself.
    ax.set_title('')

    return im

def plot_CCAM(ax,gcm,metric,period,label,cmap):
    """Draw the CCAM map of one precipitation metric on ``ax``.

    Plots the variable ``f'{metric}_{period}'`` from the CCAM metrics file of
    ``gcm`` with a colour range of ``0..v_val``, prints ``label`` together with
    the field maximum, and titles the panel 'CCAM'.

    Reads the module-level names ``metric_dir``, ``v_val`` and ``extent`` in
    addition to its arguments. Returns nothing.
    """
    lonlat = ccrs.PlateCarree()

    reference_ds = xr.open_dataset(f'{metric_dir}/{gcm}/CCAM/{gcm}_pr_metrics.nc')
    field = reference_ds[f'{metric}_{period}']

    field.plot(
        ax=ax,
        vmin=0,
        vmax=v_val,
        add_colorbar=False,
        cmap=cmap,
        transform=lonlat,
    )

    # Annotate the panel with its label and the largest value in the field.
    peak = field.max().values.tolist()
    ax.text(
        0.5, 0.1, f'{label} max = {peak:.0f}',
        ha='center', va='center', rotation='horizontal',
        transform=ax.transAxes, fontsize=20, color='black',
    )

    ax.set_extent(extent, crs=lonlat)
    ax.coastlines(resolution='10m')
    ax.set_title('CCAM', fontsize=26)

# Three-stop ramp for the extreme metrics: white -> blue -> yellow.
# (A four-stop variant with a trailing purple, "#6a0dad", was tried and dropped;
# the colormap's registered name below still reflects that older palette.)
colors = [
    "#ffffff",  # white
    "#1f78b4",  # blue
    "#ffff00",  # yellow
]
cmap = mcolors.LinearSegmentedColormap.from_list("white_blue_orange_purple", colors)

# Display names for model types, metrics and frameworks.
capitalized = {
    'GAN': 'GAN',
    'unet': 'U-Net',
    'rx1d': 'Rx1d',
    'annual_mean': 'Annual Mean',
    'DJF_mean': 'DJF Mean',
    'JJA_mean': 'JJA Mean',
    'total_max': 'Total Max',
    'perfect': 'Perfect',
    'imperfect': 'Imperfect',
}

# Per-metric colormap and upper colour limit (the lower limit is always 0).
cmaps = {
    'rx1d': cmap,
    'total_max': cmap,
    'annual_mean': 'BrBG',
    'DJF_mean': 'BrBG',
    'JJA_mean': 'BrBG',
}
v_val_dict = {
    'rx1d': 150,
    'total_max': 600,
    'annual_mean': 20,
    'DJF_mean': 20,
    'JJA_mean': 20,
}

epoch = 230
# Map extent passed to set_extent: [lon_min, lon_max, lat_min, lat_max].
extent = [164.86, 183.9646, -51.21329, -32.86]
# Initial value only; the plotting loop in this cell rebinds ``framework``.
framework = 'imperfect'

# Panel labels: five [left, right] pairs for the ML panels, then the CCAM panel.
labels = [
    ['(a)', '(b)'],
    ['(c)', '(d)'],
    ['(e)', '(f)'],
    ['(g)', '(h)'],
    ['(i)', '(j)'],
    '(k)',
]

# One PNG per (metric, GCM, period, framework): ten ML panels (sampling size x
# model type) next to the CCAM panel, sharing a single horizontal colorbar.
# NOTE: the plotting helpers in this cell read some module-level names
# (e.g. ``period``, ``v_val``, ``extent``) instead of taking them as arguments,
# so the loop variable names below must stay as they are.
centred_text = dict(va='center', ha='center', rotation='horizontal')

for metric in ['total_max','rx1d','annual_mean','DJF_mean','JJA_mean']:

    v_val, cmap = v_val_dict[metric], cmaps[metric]

    for gcm in ['EC-Earth3','NorESM2-MM']:
        for period in ['1985-2004','2080-2099']:
            for framework in ['perfect','imperfect']:

                fig, axs = create_figure()

                # Last entry of the layout is the large right-hand (CCAM) panel.
                plot_CCAM(axs[-1], gcm, metric, period, labels[-1], cmap)

                # Remaining entries form the 5 x 2 block of ML panels.
                axs = np.array(axs[:-1])
                label_subset = labels[:-1]
                for i, sampling_n in enumerate(['5','20','50','100','140']):
                    for j, ml_type in enumerate(['unet','GAN']):
                        panel = axs[i][j]

                        im = plot_map(panel, gcm, sampling_n, framework, ml_type,
                                      epoch, metric, label_subset[i][j], cmap)

                        # Row heading (sampling size) left of the first column.
                        if j == 0:
                            panel.text(-0.15, 0.5, sampling_n,
                                       transform=panel.transAxes, fontsize=26, **centred_text)

                        # Column heading (model type) above the first row.
                        if i == 0:
                            panel.text(0.5, 1.1, capitalized[ml_type],
                                       transform=panel.transAxes, fontsize=24, **centred_text)

                fig.text(0.05, 0.5, 'Stratified Random Sampling (years)',
                         fontsize=26, va='center', ha='center', rotation='vertical')

                # Shared colorbar, driven by the last ML panel that was drawn.
                cbar_ax = fig.add_axes([0.2, 0.06, 0.6, 0.015])  # [left, bottom, width, height]
                cb = fig.colorbar(im, cax=cbar_ax, orientation='horizontal')
                cb.ax.tick_params(labelsize=24)
                cb.set_label('mm/day', fontsize=26)

                fig.savefig(f'{plot_dir}/{gcm}_{framework}_{metric}_{period}_epoch_{epoch}.png', bbox_inches='tight')
                # Publication variant (disabled):
                #plt.savefig(f'{final_figure_dir}/Figure3.png',dpi=600, bbox_inches='tight')#, bbox_extra_artists=(legend,title_artist))
                plt.close(fig)
                

In [13]:
'''
CC Signals
'''

def plot_map(ax,gcm,sampling_n,framework,ml_type,epoch,metric,label,cmap):
    """Draw one ML-emulator climate-change-signal map on ``ax``.

    Opens the metrics file selected by ``gcm``, ``sampling_n``, ``framework``,
    ``ml_type`` and ``epoch``, converts ``metric`` to a percentage change with
    ``get_CC_signal`` and plots it with a symmetric colour range
    ``-v_val..v_val``. The panel is annotated with ``label``.

    Besides its arguments, this function reads the module-level names
    ``metric_dir``, ``v_val`` and ``extent``; the calling loop must have set
    them before the call.

    Returns
    -------
    The artist returned by ``DataArray.plot`` (used by the caller for the colorbar).
    """
    metrics_path = (f'{metric_dir}/{gcm}/pr_ACCESS-CM2_{sampling_n}/'
                    f'{gcm}_{framework}_pr_metrics_{ml_type}_epoch_{epoch}.nc')

    # Compute the signal while the file is open, pull it into memory, then
    # release the file: this function runs once per panel inside nested loops,
    # so leaving every dataset open accumulates file handles.
    with xr.open_dataset(metrics_path) as ds:
        da = get_CC_signal(ds,metric).load()

    # Draw on the axis that was passed in (previously the global axs[i][j] was
    # used, which silently ignored the ``ax`` argument).
    im = da.plot(ax=ax,vmin=-v_val,vmax=v_val,add_colorbar=False,cmap=cmap,transform=ccrs.PlateCarree())

    ax.text(0.5,0.1, f'{label}', va='center', ha='center',rotation='horizontal', transform=ax.transAxes,fontsize=20)

    ax.set_extent(extent, crs=ccrs.PlateCarree())
    ax.coastlines(resolution='10m')
    # xarray sets an automatic title; the figure labels its rows/columns itself.
    ax.set_title('')

    return im

def get_CC_signal(ds,metric):
    """Return the relative change of ``metric`` between the two periods, in percent.

    ``ds`` is indexed with ``f'{metric}_1985-2004'`` (baseline) and
    ``f'{metric}_2080-2099'`` (future); the result is
    ``(future - baseline) / baseline * 100``, computed element-wise.
    """
    baseline = ds[f'{metric}_1985-2004']
    future = ds[f'{metric}_2080-2099']

    relative_change = (future - baseline) / baseline
    return relative_change * 100
    
    
def plot_CCAM(ax,gcm,metric,label,cmap):
    """Draw the CCAM climate-change-signal map of ``metric`` on ``ax``.

    The percentage change is computed by ``get_CC_signal`` from the CCAM
    metrics file of ``gcm`` and plotted with a symmetric colour range
    ``-v_val..v_val``. The panel is annotated with ``label`` and titled 'CCAM'.

    Reads the module-level names ``metric_dir``, ``v_val`` and ``extent`` in
    addition to its arguments. Returns nothing.
    """
    lonlat = ccrs.PlateCarree()

    reference_ds = xr.open_dataset(f'{metric_dir}/{gcm}/CCAM/{gcm}_pr_metrics.nc')
    signal = get_CC_signal(reference_ds, metric)

    signal.plot(
        ax=ax,
        vmin=-v_val,
        vmax=v_val,
        add_colorbar=False,
        cmap=cmap,
        transform=lonlat,
    )

    ax.text(
        0.5, 0.1, f'{label}',
        ha='center', va='center', rotation='horizontal',
        transform=ax.transAxes, fontsize=20, color='black',
    )

    ax.set_extent(extent, crs=lonlat)
    ax.coastlines(resolution='10m')
    ax.set_title('CCAM', fontsize=26)

# Display names for model types, metrics and frameworks.
capitalized = {
    'GAN': 'GAN',
    'unet': 'U-Net',
    'rx1d': 'Rx1d',
    'annual_mean': 'Annual Mean',
    'DJF_mean': 'DJF Mean',
    'JJA_mean': 'JJA Mean',
    'total_max': 'Total Max',
    'perfect': 'Perfect',
    'imperfect': 'Imperfect',
}

# Half-width of the symmetric colour range (in %) used for each metric.
v_val_dict = {
    'annual_mean': 50,
    'DJF_mean': 50,
    'JJA_mean': 40,
    'rx1d': 70,
    'total_max': 150,
}

epoch = 230
# Map extent passed to set_extent: [lon_min, lon_max, lat_min, lat_max].
extent = [164.86, 183.9646, -51.21329, -32.86]

# Panel labels: five [left, right] pairs for the ML panels, then the CCAM panel.
labels = [
    ['(a)', '(b)'],
    ['(c)', '(d)'],
    ['(e)', '(f)'],
    ['(g)', '(h)'],
    ['(i)', '(j)'],
    '(k)',
]

# One PNG per (metric, GCM, framework): ten ML change-signal panels (sampling
# size x model type) next to the CCAM panel, sharing one horizontal colorbar.
# NOTE: the plotting helpers in this cell read some module-level names
# (e.g. ``v_val``, ``extent``) instead of taking them as arguments, so the
# variable names below must stay as they are.
centred_text = dict(va='center', ha='center', rotation='horizontal')

for metric in ['annual_mean','DJF_mean','JJA_mean','total_max','rx1d',]:

    v_val = v_val_dict[metric]

    for gcm in ['EC-Earth3','NorESM2-MM']:
        for framework in ['perfect','imperfect']:

            fig, axs = create_figure()

            # Last entry of the layout is the large right-hand (CCAM) panel.
            plot_CCAM(axs[-1], gcm, metric, labels[-1], 'BrBG')

            # Remaining entries form the 5 x 2 block of ML panels.
            axs = np.array(axs[:-1])
            label_subset = labels[:-1]
            for i, sampling_n in enumerate(['5','20','50','100','140']):
                for j, ml_type in enumerate(['unet','GAN']):
                    panel = axs[i][j]

                    im = plot_map(panel, gcm, sampling_n, framework, ml_type,
                                  epoch, metric, label_subset[i][j], 'BrBG')

                    # Row heading (sampling size) left of the first column.
                    if j == 0:
                        panel.text(-0.15, 0.5, sampling_n,
                                   transform=panel.transAxes, fontsize=26, **centred_text)

                    # Column heading (model type) above the first row.
                    if i == 0:
                        panel.text(0.5, 1.1, capitalized[ml_type],
                                   transform=panel.transAxes, fontsize=24, **centred_text)

            fig.text(0.05, 0.5, 'Stratified Random Sampling (years)',
                     fontsize=26, va='center', ha='center', rotation='vertical')

            # Shared colorbar, driven by the last ML panel that was drawn.
            cbar_ax = fig.add_axes([0.2, 0.06, 0.6, 0.015])  # [left, bottom, width, height]
            cb = fig.colorbar(im, cax=cbar_ax, orientation='horizontal')
            cb.ax.tick_params(labelsize=24)
            cb.set_label('%', fontsize=26)

            fig.savefig(f'{plot_dir}/{gcm}_{framework}_{metric}_CC-Signal_epoch_{epoch}.png', bbox_inches='tight')
            plt.close(fig)
                
                

In [3]:
'''
EVENTS
'''

# Contour boundaries for the event maps (same units as the plotted field).
boundaries2 = [
    0, 5, 12.5, 15, 20, 25,
    30, 35, 40, 50, 60, 70,
    80, 100, 125, 150, 200, 250,
]

# One RGBA entry per line; the first entry is fully transparent.
# (abbreviated for clarity)
# NOTE(review): 18 colours are listed but the 18 boundaries above delimit only
# 17 intervals -- confirm the colour-to-interval mapping is the intended one.
colors2 = [
    [0.000, 0.000, 0.000, 0.000],
    [0.875, 0.875, 0.875, 0.784],
    [0.761, 0.761, 0.761, 1.000],
    [0.639, 0.886, 0.871, 1.000],
    [0.388, 0.773, 0.616, 1.000],
    [0.000, 0.392, 0.392, 0.588],
    [0.000, 0.576, 0.576, 0.667],
    [0.000, 0.792, 0.792, 0.745],
    [0.000, 0.855, 0.855, 0.863],
    [0.212, 1.000, 1.000, 1.000],
    [0.953, 0.855, 0.992, 1.000],
    [0.918, 0.765, 0.992, 1.000],
    [0.918, 0.612, 1.000, 1.000],
    [0.878, 0.431, 1.000, 1.000],
    [0.886, 0.349, 1.000, 1.000],
    [0.651, 0.004, 0.788, 1.000],
    [0.357, 0.008, 0.431, 1.000],
    [0.180, 0.000, 0.224, 1.000],
]

# Discrete colormap plus the norm that assigns values to it via boundaries2.
cmap = mcolors.ListedColormap(colors2)
norm = mcolors.BoundaryNorm(boundaries2, cmap.N)


def plot_CCAM(gcm,event):
    """Plot the CCAM precipitation field for one event and save it as a PNG.

    Selects ``pr`` for ``gcm`` at time ``event`` from the module-level
    ``ccam_ds``, multiplies by 86400 (flux to mm/day), draws filled contours
    with the module-level ``cmap``/``norm`` and writes
    ``maps/CCAM_{gcm}_{event}.png`` relative to the working directory.

    Also reads the module-level ``extent``. Returns nothing.
    """
    lonlat = ccrs.PlateCarree()
    fig, ax = plt.subplots(
        figsize=(8, 6),
        subplot_kw={'projection': ccrs.PlateCarree(central_longitude=175)},
    )

    # Flux -> mm/day (86400 seconds per day).
    pr_mm_day = ccam_ds.sel(GCM=gcm, time=event)['pr'] * 86400

    im = pr_mm_day.squeeze().plot.contourf(
        ax=ax,
        add_colorbar=False,
        cmap=cmap,
        norm=norm,
        extend='max',
        transform=lonlat,
    )

    # Print the field maximum inside the panel.
    peak = pr_mm_day.max().values.tolist()
    ax.text(
        0.8, 0.2, f'MAX:\n{peak:.0f}',
        ha='center', va='center', rotation='horizontal',
        transform=ax.transAxes, fontsize=16,
    )

    ax.set_extent(extent, crs=lonlat)
    ax.coastlines(resolution='10m')
    ax.set_title('CCAM', fontsize=20)

    cb = fig.colorbar(im, shrink=0.8)
    cb.ax.tick_params(labelsize=16)
    cb.set_label('pr (mm/day)', fontsize=20)

    fig.savefig(f'maps/CCAM_{gcm}_{event}.png')
    plt.close(fig)
    
def plot_NIWA_REMS(gcm,event,framework,ml_type):
    """Plot the NIWA-REMS precipitation field for one event and save it as a PNG.

    The input file is chosen from ``framework`` and ``event``:

    * ``framework == 'perfect'``: always the ``ssp370`` file;
    * any other framework: the ``historical`` file for event ``'1987-03-26'``
      and the ``ssp370`` file for event ``'2095-01-20'``.

    The field at ``event`` is clipped to ``[0, 1500]``, drawn as filled contours
    with the module-level ``cmap``/``norm`` and saved to
    ``maps/NIWA-REMS_{gcm}_{framework}_{event}_{ml_type}.png`` relative to the
    working directory. Also reads the module-level ``extent``.

    Raises
    ------
    ValueError
        If ``framework`` is not ``'perfect'`` and ``event`` is not one of the
        two known dates (previously this surfaced as an ``UnboundLocalError``).
    """
    # Resolve the scenario before creating a figure, so an unsupported event
    # fails fast with a clear message and does not leave an open figure behind.
    scenario_by_event = {'1987-03-26': 'historical', '2095-01-20': 'ssp370'}
    if framework == 'perfect':
        scenario = 'ssp370'
    elif event in scenario_by_event:
        scenario = scenario_by_event[event]
    else:
        raise ValueError(
            f'No NIWA-REMS file is configured for event {event!r} in the '
            f'{framework!r} framework; known events: {sorted(scenario_by_event)}'
        )

    path = ('/nesi/project/niwa00018/queenle/NIWA-REMS_inference/output/'
            f'{gcm}/NIWA-REMS_v110425_pr/{gcm}_pr_{scenario}_{framework}_framework_{ml_type}.nc')

    fig,ax = plt.subplots(figsize=(8,6), subplot_kw={'projection': ccrs.PlateCarree(central_longitude=175)})

    # Load only the requested time step and release the file handle right away.
    with xr.open_dataarray(path) as full_da:
        da = full_da.sel(time=event).clip(0,1500).load()

    im = da.squeeze().plot.contourf(ax=ax,add_colorbar=False,cmap=cmap,norm=norm,extend='max',transform=ccrs.PlateCarree())

    max_val = da.max().values.tolist()
    ax.text(0.8,0.2, f'MAX:\n{max_val:.0f}', va='center', ha='center',rotation='horizontal', transform=ax.transAxes,fontsize=16)

    ax.set_extent(extent, crs=ccrs.PlateCarree())
    ax.coastlines(resolution='10m')
    ax.set_title(f'NIWA-REMS ({ml_type})',fontsize=20)

    fig.savefig(f'maps/NIWA-REMS_{gcm}_{framework}_{event}_{ml_type}.png')
    plt.close(fig)
    
def plot_ML(ax,gcm,sampling_n,framework,ml_type,epoch,event):
    """Draw the ML-emulator precipitation field for one event on ``ax``.

    The input file under the module-level ``output_dir`` is chosen from
    ``framework`` and ``event``:

    * ``framework == 'perfect'``: always the ``ssp370`` file;
    * any other framework: the ``historical`` file for event ``'1987-03-26'``
      and the ``ssp370`` file for event ``'2095-01-20'``.

    The field at ``event`` is clipped to ``[0, 1500]`` and drawn as filled
    contours with the module-level ``cmap``/``norm``; the field maximum is
    printed inside the panel. Also reads the module-level ``extent``.

    Returns
    -------
    The contour set returned by ``plot.contourf`` (used by the caller for the colorbar).

    Raises
    ------
    ValueError
        If ``framework`` is not ``'perfect'`` and ``event`` is not one of the
        two known dates (previously this surfaced as an ``UnboundLocalError``).
    """
    # Resolve the scenario first so an unsupported event fails with a clear message.
    scenario_by_event = {'1987-03-26': 'historical', '2095-01-20': 'ssp370'}
    if framework == 'perfect':
        scenario = 'ssp370'
    elif event in scenario_by_event:
        scenario = scenario_by_event[event]
    else:
        raise ValueError(
            f'No emulator output file is configured for event {event!r} in the '
            f'{framework!r} framework; known events: {sorted(scenario_by_event)}'
        )

    path = (f'{output_dir}/{gcm}/pr_ACCESS-CM2_{sampling_n}/'
            f'{gcm}_pr_{scenario}_{framework}_framework_{ml_type}_epoch_{epoch}.nc')

    # Load only the requested time step and release the file handle right away;
    # this function is called once per panel inside nested loops.
    with xr.open_dataarray(path) as full_da:
        da = full_da.sel(time=event).clip(0,1500).load()

    # Draw on the axis that was passed in (previously the global axs[i][j] was
    # used, which silently ignored the ``ax`` argument).
    im = da.squeeze().plot.contourf(ax=ax,add_colorbar=False,cmap=cmap,norm=norm,extend='max',transform=ccrs.PlateCarree())

    max_val = da.max().values.tolist()
    ax.text(0.8,0.2, f'MAX:\n{max_val:.0f}', va='center', ha='center',rotation='horizontal', transform=ax.transAxes,fontsize=16)

    ax.set_extent(extent, crs=ccrs.PlateCarree())
    ax.coastlines(resolution='10m')
    # xarray sets an automatic title; the caller adds its own headings.
    ax.set_title('')

    return im

# Settings for the single-event maps.
gcm = 'NorESM2-MM'
epoch = 230
# Map extent passed to set_extent: [lon_min, lon_max, lat_min, lat_max].
extent = [164, 184, -52, -32]
event1, event2 = '1987-03-26', '2095-01-20'

# One PNG per (event, framework): a 2 x 5 grid with model type down the rows and
# sampling size across the columns, sharing a single colorbar.
# NOTE: the plotting helpers in this cell read some module-level names
# (e.g. ``extent``, ``cmap``, ``norm``), so the variable names here must stay as they are.
for event in [event1, event2]:
    # CCAM reference plot (disabled):
    #plot_CCAM(gcm,event)

    for framework in ['perfect','imperfect']:
        fig, axs = plt.subplots(
            2, 5,
            figsize=(12, 5),
            subplot_kw={'projection': ccrs.PlateCarree(central_longitude=175)},
            layout='constrained',
        )

        for i, ml_type in enumerate(['GAN','unet']):
            # NIWA-REMS reference plot (disabled):
            #plot_NIWA_REMS(gcm,event,framework,ml_type)

            for j, sampling_n in enumerate(['5','20','50','100','140']):
                panel = axs[i][j]

                im = plot_ML(panel, gcm, sampling_n, framework, ml_type, epoch, event)

                # Column heading (sampling size) above the first row.
                if i == 0:
                    panel.set_title(sampling_n, fontsize=20)

                # Row heading (model type) left of the first column.
                if j == 0:
                    panel.text(-0.1, 0.5, ml_type, va='center', ha='center',
                               rotation='vertical', transform=panel.transAxes, fontsize=20)

        # Shared colorbar, driven by the last panel that was drawn.
        cb = fig.colorbar(im, ax=axs[:, :], shrink=0.8)
        cb.ax.tick_params(labelsize=16)
        cb.set_label('pr (mm/day)', fontsize=20)

        fig.savefig(f'{plot_dir}/{gcm}_{framework}_{event}_epoch_{epoch}.png')
        plt.close(fig)