In [1]:
import xarray as xr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import glob
import os
import matplotlib.patches as mpatches


ERROR 1: PROJ: proj_create_from_database: Open of /nesi/project/niwa00018/queenle/ml_env_v2/share/proj failed


In [3]:
# Font preferences for figures, keyed like matplotlib's ``font`` rc group.
# NOTE(review): nothing in the visible cells passes this dict to matplotlib
# (e.g. via plt.rc) — confirm whether it is still meant to be applied.
font = dict(
    [
        ('family', 'sans-serif'),                               # generic family to request
        ('sans-serif', ['DejaVu Sans', 'Arial', 'Helvetica']),  # candidates, in order of preference
        ('weight', 'normal'),                                   # could also be 'bold'
        ('size', 16),                                           # point size
    ]
)

In [4]:
# Invariant fields file; its `sftlf` variable is used further down as the
# condition passed to `.where()` when masking the metric datasets.
static_ds = xr.open_dataset('/nesi/project/niwa00018/ML_downscaling_CCAM/training_GAN/ancil_fields/ERA5_eval_ccam_12km.198110_NZ_Invariant.nc')
land_mask = static_ds['sftlf']

# Root folder the per-GCM metric NetCDF files are read from
metric_dir = '/nesi/project/niwa00018/queenle/NIWA-REMS_inference/output/metrics/'

# Root folder the generated map figures are written under
output_dir = '/nesi/project/niwa00018/queenle/NIWA-REMS_inference/plotting/maps/'

# Emulator run name to evaluate for each variable
emulators = dict(pr='NIWA-REMS_v110425_pr', tasmax='NIWA-REMS_tasmax_v050425')

In [5]:
def get_CC_signal(ds,var,metric,base_period,future_period):
    """Return the change between two periods of one metric.

    ``ds`` is indexed with keys of the form ``'{metric}_{start}-{end}'`` built
    from the two-element ``base_period`` and ``future_period`` sequences.

    For ``var == 'pr'`` the change is relative to the base period and scaled
    by 100; for every other variable it is the plain difference
    ``future - base``.
    """
    def _period_key(period):
        return f'{metric}_{period[0]}-{period[1]}'

    baseline = ds[_period_key(base_period)]
    future = ds[_period_key(future_period)]
    delta = future - baseline

    if var == 'pr':
        return (delta / baseline) * 100
    return delta
    


In [6]:
def create_figure():
    """Build a 7x9 inch figure holding a 3x3 grid of frameless PlateCarree map axes.

    Axes are placed by hand with ``fig.add_axes`` so that the horizontal step
    between columns is smaller than the axes width, i.e. neighbouring panels
    overlap horizontally.

    Returns
    -------
    (fig, axs)
        The figure and a (3, 3) object array of axes indexed ``axs[row][col]``,
        with row 0 at the top.
    """
    n_rows, n_cols = 3, 3
    panel_w, panel_h = 0.38, 0.30   # size of every axes, in figure fractions
    step_x, step_y = 0.33, 0.31     # spacing between panel origins; step_x < panel_w gives the overlap

    fig = plt.figure(figsize=(7, 9))
    grid = np.empty((n_rows, n_cols), dtype=object)

    # np.ndindex walks the grid row by row, so axes are created in reading order
    for row, col in np.ndindex(n_rows, n_cols):
        rect = [col * step_x, 1.0 - (row + 1) * step_y, panel_w, panel_h]
        ax = fig.add_axes(rect, projection=ccrs.PlateCarree())

        # Fix the map extent, draw coastlines, then strip ticks, background and frame
        ax.set_extent([165, 180, -48, -33], crs=ccrs.PlateCarree())
        ax.coastlines(resolution='10m')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_facecolor('none')
        ax.set_frame_on(False)
        for spine in ax.spines.values():
            spine.set_visible(False)

        grid[row, col] = ax

    return fig, grid


In [7]:
# Plot settings table, read from a path relative to the working directory.
# plot() selects a row by its 'variable' and 'metric' columns and unpacks the
# row's values from the fourth column onward (colour limits, colormap name,
# error type, unit).
config = pd.read_csv('plotting_config.txt')

def _annotate_score(ax, label, value):
    """Write ``label`` and ``value`` (two decimals) onto ``ax`` in axes coordinates."""
    # float() collapses a 0-d array to a plain number so the ':.2f' format spec
    # does not depend on the array type implementing __format__.
    ax.text(0.16, 0.58, f'{label}\n {float(value):.2f}', transform=ax.transAxes,
            ha='left', va='bottom', fontsize=20)


def plot(gcm,metric,var,metric_das,base_period,future_period,emulator,printed_error='default'):
    """Draw and save a 3x3 map panel comparing emulator output against CCAM.

    Layout (rows x columns):
      * row 0: CCAM values for the base period, the future period and the
        change signal between them;
      * rows 1-2: error of the ``{gcm}_perfect`` and ``{gcm}_imperfect``
        datasets relative to CCAM for the same three columns.

    Parameters
    ----------
    gcm : str
        Used to build the ``metric_das`` keys and the output path.
    metric : str
        Metric name; data variables are looked up as ``'{metric}_{start}-{end}'``.
    var : str
        Variable name; selects the config row, and ``'pr'`` switches the
        change-signal colormap to ``'BrBG'``.
    metric_das : mapping
        Must contain ``'CCAM'``, ``f'{gcm}_perfect'`` and ``f'{gcm}_imperfect'``.
    base_period, future_period : sequence of two items
        Start and end labels of each period.
    emulator : str
        Only used in the output path.
    printed_error : str, default 'default'
        ``'default'`` prints the mean absolute value of the plotted error on
        each error panel, ``'rmse'`` prints the root-mean-square difference;
        any other value prints nothing.

    Raises
    ------
    IndexError
        If the module-level ``config`` has no row for ``(var, metric)``.
    ValueError
        If the configured error type is neither ``'mape'`` nor ``'mae'``.

    Reads the module-level ``config`` and ``output_dir`` and writes a PNG to
    ``{output_dir}/{gcm}/{emulator}/``.
    """
    # Validate the configuration before any figure is created so a bad lookup
    # does not leave an orphaned figure behind.
    matches = config[(config['variable'] == var) & (config['metric'] == metric)]
    if len(matches) == 0:
        raise IndexError(f'no plotting config row for variable={var!r}, metric={metric!r}')

    # NOTE(review): clim_cmap is read from the config but never passed to a
    # plot call — confirm whether the CCAM panels were meant to use it.
    (clim_vmin, clim_vmax, cc_vmin, cc_vmax,
     clim_err_vmin, clim_err_vmax, cc_err_vmin, cc_err_vmax,
     clim_cmap, err_type, unit) = matches.values[0][3:]

    if err_type not in ('mape', 'mae'):
        raise ValueError(f"unsupported error type {err_type!r} for {var}/{metric}; expected 'mape' or 'mae'")

    period_labels = [f'{period[0]}-{period[1]}' for period in (base_period, future_period)]
    bias_label = f'Bias ({unit})' if err_type == 'mae' else 'Bias (%)'

    ccam = metric_das['CCAM']
    # The CCAM change signal is the reference for every row, so compute it once.
    ccam_cc = get_CC_signal(ccam, var, metric, base_period, future_period)

    fig, axs = create_figure()
    try:
        # Row 0: CCAM climatologies and change signal
        for j, label in enumerate(period_labels):
            ccam_clim_im = ccam[f'{metric}_{label}'].plot(
                ax=axs[0][j], vmin=clim_vmin, vmax=clim_vmax,
                add_colorbar=False, transform=ccrs.PlateCarree())
        ccam_cc_im = ccam_cc.plot(
            ax=axs[0][2], vmin=cc_vmin, vmax=cc_vmax,
            cmap='BrBG' if var == 'pr' else 'YlOrRd',
            add_colorbar=False, transform=ccrs.PlateCarree())

        # Rows 1-2: emulator errors relative to CCAM
        for i, framework in enumerate([f'{gcm}_perfect', f'{gcm}_imperfect'], start=1):
            ml = metric_das[framework]

            for j, label in enumerate(period_labels):
                ccam_da = ccam[f'{metric}_{label}']
                diff = ml[f'{metric}_{label}'] - ccam_da
                error = (diff / ccam_da) * 100 if err_type == 'mape' else diff

                if printed_error == 'default':
                    _annotate_score(axs[i][j], err_type, abs(error).mean(['lat', 'lon']))
                elif printed_error == 'rmse':
                    # RMSE is taken on the raw difference even when the map shows a percentage error
                    _annotate_score(axs[i][j], 'RMSE', np.sqrt((diff ** 2).mean(['lat', 'lon'])))

                ml_error_im = error.plot(
                    ax=axs[i][j], vmin=clim_err_vmin, vmax=clim_err_vmax, cmap='RdBu_r',
                    add_colorbar=False, transform=ccrs.PlateCarree())

            # Change-signal error is always a plain difference of the two signals.
            # NOTE(review): the 'default' annotation still uses the err_type label
            # here — confirm that is the intended label for this column.
            cc_error = get_CC_signal(ml, var, metric, base_period, future_period) - ccam_cc

            if printed_error == 'default':
                _annotate_score(axs[i][2], err_type, abs(cc_error).mean(['lat', 'lon']))
            elif printed_error == 'rmse':
                _annotate_score(axs[i][2], 'RMSE', np.sqrt((cc_error ** 2).mean(['lat', 'lon'])))

            cc_error_im = cc_error.plot(
                ax=axs[i][2], vmin=cc_err_vmin, vmax=cc_err_vmax, cmap='RdBu_r',
                add_colorbar=False, transform=ccrs.PlateCarree())

        for ax in axs.flatten():
            # create_figure() already added coastlines; this second call is kept
            # so the rendered output is unchanged.
            ax.coastlines('10m')
            ax.set_extent([166,179,-48,-34])
            ax.set_axis_off()
            ax.set_title('')  # drop the title xarray adds automatically

        # Row labels on the left edge of the first column
        for i, source in enumerate(['CCAM','PERFECT','IMPERFECT']):
            axs[i][0].text(-0.1, 0.5, source, transform=axs[i][0].transAxes,
                           rotation='vertical', va='center', ha='right', fontsize=22)

        # Column headers come from the period arguments rather than fixed years
        for j, header in enumerate(period_labels + ['CC Signal']):
            axs[0][j].text(0.5, 1.05, header, transform=axs[0][j].transAxes,
                           ha='center', va='bottom', fontsize=22)

        # Shift every panel left by the same amount before the colorbars are attached
        for ax in axs.flatten():
            pos = ax.get_position()
            ax.set_position([pos.x0 - 0.5, pos.y0, pos.width, pos.height])

        cb1 = fig.colorbar(ccam_clim_im, ax=axs[0,0:2], shrink=0.85, label=unit)
        cb2 = fig.colorbar(ml_error_im, ax=axs[1:,0:2], shrink=0.45, label=bias_label)
        cb2.ax.yaxis.set_label_position('left')
        cb3 = fig.colorbar(ccam_cc_im, ax=axs[:1,2], shrink=0.85, label=unit if err_type == 'mae' else '(%)')
        cb4 = fig.colorbar(cc_error_im, ax=axs[1:,2], shrink=0.45, label=bias_label)

        for cb in [cb1,cb2,cb3,cb4]:
            cb.ax.yaxis.label.set_size(20)
            cb.ax.tick_params(labelsize=16)

        out_dir = os.path.join(output_dir, gcm, emulator)
        # exist_ok avoids the check-then-create race and the separate exists() call
        os.makedirs(out_dir, exist_ok=True)

        fig.savefig(os.path.join(out_dir, f'{gcm}_{var}_{metric}_{printed_error}_errors.png'),
                    dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, including when plotting or saving fails
        plt.close(fig)


In [8]:
base_period = ['1985','2004']
future_period = ['2080','2099']

# Metrics to plot for each variable. Only the variables iterated in the loop
# below are used.
# NOTE(review): 'TXn' is listed under tasmin — confirm the metric name; that
# entry is not exercised by the loop below.
metrics = {'tasmax':['annual_mean','TXx','DJF_mean'],
           'tasmin':['annual_mean','TXn'],
           'sfcwind':['annual_mean'],
           'pr':['annual_mean','rx1d','DJF_mean','JJA_mean']}

# Spatial subset applied to every metrics file before masking
LAT_BOUNDS = slice(-48, -34)
LON_BOUNDS = slice(166, 179)


def load_masked_metrics(path):
    """Open the metrics file at ``path``, crop and mask it, and return it in memory.

    The dataset is cropped to ``LAT_BOUNDS``/``LON_BOUNDS`` and masked with
    ``.where(land_mask)``. It is loaded before the file is closed, so the
    returned dataset no longer depends on the open file handle.
    """
    with xr.open_dataset(path) as ds:
        return ds.sel(lat=LAT_BOUNDS, lon=LON_BOUNDS).where(land_mask).load()


for gcm in ['ACCESS-CM2','NorESM2-MM','EC-Earth3']:
    print(gcm)
    for var in ['pr','tasmax']:
        print(f'\t-{var}')

        metric_das = {}

        CCAM_metrics = load_masked_metrics(f'{metric_dir}/{gcm}/CCAM/{gcm}_{var}_metrics.nc')
        metric_das['CCAM'] = CCAM_metrics

        emulator = emulators[var]

        print(f'\t\t-{emulator}')

        # Both evaluation frameworks are read from the same emulator directory
        for framework in ['perfect','imperfect']:
            metric_das[f'{gcm}_{framework}'] = load_masked_metrics(
                f'{metric_dir}/{gcm}/{emulator}/{gcm}_{framework}_{var}_metrics_GAN.nc')

        for metric in metrics[var]:
            print(f'\t\t\t-{metric}')

            # One figure per annotation style
            for printed_err in ['default','rmse']:
                plot(gcm,metric,var,metric_das,base_period,future_period,emulator,printed_err)


ACCESS-CM2
	-pr
		-NIWA-REMS_v110425_pr
			-annual_mean
			-rx1d
			-DJF_mean
			-JJA_mean
	-tasmax
		-NIWA-REMS_tasmax_v050425
			-annual_mean
			-TXx
			-DJF_mean
NorESM2-MM
	-pr
		-NIWA-REMS_v110425_pr
			-annual_mean
			-rx1d
			-DJF_mean
			-JJA_mean
	-tasmax
		-NIWA-REMS_tasmax_v050425
			-annual_mean
			-TXx
			-DJF_mean
EC-Earth3
	-pr
		-NIWA-REMS_v110425_pr
			-annual_mean
			-rx1d
			-DJF_mean
			-JJA_mean
	-tasmax
		-NIWA-REMS_tasmax_v050425
			-annual_mean
			-TXx
			-DJF_mean


In [None]:
'''
OLD
'''


# Reads the same file name as the active plotting cell. The legacy plot()
# defined below additionally filters this table on a 'GCM' column.
config = pd.read_csv('plotting_config.txt')

def plot(gcm,metric,var,metric_das,base_period,future_period,emulator):
    """Legacy 3x3 comparison figure, shown on screen instead of being saved.

    Layout: rows 0 and 1 are the base and future periods, row 2 is the change
    signal. Column 0 shows CCAM; columns 1 and 2 show the error of the
    ``{gcm}_perfect`` and ``{gcm}_imperfect`` datasets relative to CCAM.

    The config row is selected on 'GCM', 'variable' and 'metric', and its
    values are unpacked from the fifth column onward.

    WARNING: this definition has the same name as the earlier ``plot`` but
    takes no ``printed_error`` argument. Running this cell rebinds ``plot``,
    after which the driver loop's eight-argument call raises TypeError.
    """

    fig,axs = plt.subplots(3,3,figsize=(12,8),sharex=True,sharey=True,layout='constrained',subplot_kw={'projection': ccrs.PlateCarree(central_longitude=171.77)})

    # clim_cmap is unpacked here but not used anywhere in this function
    clim_vmin,clim_vmax,cc_vmin,cc_vmax,clim_err_vmin,clim_err_vmax,cc_err_vmin,cc_err_vmax,clim_cmap,err_type,unit = config[(config['GCM'] == gcm) & (config['variable'] == var) & (config['metric']==metric)].values[0][4:]


    # Rows 0-1: one row per period
    for i,period in enumerate([base_period,future_period]):

        ccam_da = metric_das['CCAM'][f'{metric}_{period[0]}-{period[1]}']
        ccam_clim_im = ccam_da.plot(ax=axs[i][0],vmin=clim_vmin,vmax=clim_vmax,add_colorbar=False,transform = ccrs.PlateCarree())
        if i == 0:
            axs[i][0].set_title(f'CCAM\n{period[0]}-{period[1]}')
        else:
            axs[i][0].set_title(f'{period[0]}-{period[1]}')

        for j,da in enumerate([f'{gcm}_perfect',f'{gcm}_imperfect']):#,'imperfect_imperfect']):

            ml_da = metric_das[da][f'{metric}_{period[0]}-{period[1]}']

            # Percentage error for 'pr', plain difference for every other variable
            # (keyed on var here, not on the config's err_type)
            if var == 'pr':
                error = ((ml_da - ccam_da)/ccam_da) * 100
            else:
                error = ml_da - ccam_da

            ml_error_im = error.plot(ax=axs[i][j+1],vmin=clim_err_vmin,vmax=clim_err_vmax,cmap='RdBu_r',add_colorbar=False,transform = ccrs.PlateCarree())

            # Only the top row carries the dataset name as a title
            if i == 0:
                axs[i][j+1].set_title(da)
            else:
                axs[i][j+1].set_title('')

    # Row 2: CCAM change signal, then the emulators' change-signal errors
    ccam_cc = get_CC_signal(metric_das['CCAM'],var,metric,base_period,future_period)
    ccam_cc_im = ccam_cc.plot(ax=axs[2][0],vmin=cc_vmin,vmax=cc_vmax,cmap='RdBu_r',add_colorbar=False,transform = ccrs.PlateCarree())
    axs[2][0].set_title('CC signal')

    for i,da in enumerate([f'{gcm}_perfect',f'{gcm}_imperfect']):#,'imperfect_imperfect']):

        ml_cc = get_CC_signal(metric_das[da],var,metric,base_period,future_period)

        cc_error = ml_cc - ccam_cc
        cc_error_im = cc_error.plot(ax=axs[2][i+1],vmin=cc_err_vmin,vmax=cc_err_vmax,cmap='RdBu_r',add_colorbar=False,transform = ccrs.PlateCarree())

        axs[2][i+1].set_title('')

    for ax in axs.flatten():
        ax.coastlines('10m')

    # Each colorbar uses the last image assigned to its variable in the loops above.
    # The 1:4 column slices ask for more columns than the 3-wide grid has.
    plt.colorbar(ccam_clim_im,ax=axs[0:2,0:1], shrink=0.85, label=var + ' ' + metric + ' ' + unit)
    plt.colorbar(ml_error_im,ax=axs[0:2,1:4],shrink=0.85,label=err_type)
    plt.colorbar(ccam_cc_im,ax=axs[2,0:1], shrink=0.85, label=err_type)
    plt.colorbar(cc_error_im,ax=axs[2,1:4],shrink=0.85,label='mae')
    
    plt.suptitle(f'Emulator: {emulator}')
    
    # The output directory is still created even though saving is commented out below
    if not os.path.exists(f'{output_dir}/{gcm}/{emulator}/'):
        os.makedirs(f'{output_dir}/{gcm}/{emulator}/')
            
    #plt.savefig(f'{output_dir}/{gcm}/{emulator}/{var}_{metric}_errors.png',dpi=300)
    
    plt.show()
    plt.close()
