In [1]:
import xarray as xr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import glob
import os
import matplotlib.patches as mpatches
from sklearn.metrics import mean_squared_error


ERROR 1: PROJ: proj_create_from_database: Open of /nesi/project/niwa00018/queenle/ml_env_v2/share/proj failed


In [2]:
# Invariant (static) fields; the `sftlf` variable is what the error helpers
# compare against zero to separate 'land' from 'ocean' cells.
static_ds = xr.open_dataset(
    '/nesi/project/niwa00018/ML_downscaling_CCAM/training_GAN/ancil_fields/ERA5_eval_ccam_12km.198110_NZ_Invariant.nc'
)
land_mask = static_ds.sftlf

# directories (metric_dir is the 'metrics/' sub-folder of output_dir)
output_dir = '/nesi/project/niwa00018/queenle/ml_emulator_experiment_application/ml_downscaled_output/'
metric_dir = output_dir + 'metrics/'
CCAM_dir = '/nesi/project/niwa00018/ML_downscaling_CCAM/multi-variate-gan/inputs/'

# TEMPORAL TRAINING EXPERIMENTS
# Every experiment name is 'pr_ACCESS-CM2_<tag>'; the tag (last '_' field) is
# what the driver loop stores in the 'sampling_n' column.
temporal_experiments = {
    'pr': [
        f'pr_ACCESS-CM2_{tag}'
        for tag in (5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100,
                    '1961-1980', '2015-2034', '2080-2099')
    ],
}

# Metric names evaluated for each variable.
metrics = dict(
    tasmax=['annual_mean', 'TXx'],
    tasmin=['annual_mean', 'TXn'],
    sfcwind=['annual_mean'],
    pr=['annual_mean', 'rx1d', 'DJF_mean', 'JJA_mean'],
)

# [start, end] year strings; combined as '<metric>_<start>-<end>' to look up
# variables in the metric datasets.
base_period, future_period = ['1985', '2004'], ['2080', '2099']

In [3]:
'''
HELPER FUNCTIONS
'''

def get_CC_signal(ds,var,metric):
    '''
    Climate-change signal of `metric` between the base and future periods.

    Reads `ds['<metric>_<base start>-<base end>']` and
    `ds['<metric>_<future start>-<future end>']`, with the years taken from
    the module-level `base_period` and `future_period`.

    Returns the absolute difference (future - base) for every variable
    except 'pr', for which the change is returned relative to the base
    value, in percent.
    '''
    baseline = ds[f'{metric}_{base_period[0]}-{base_period[1]}']
    projected = ds[f'{metric}_{future_period[0]}-{future_period[1]}']

    delta = projected - baseline
    if var == 'pr':
        # precipitation: percentage change relative to the base period
        return((delta/baseline) * 100)

    return(delta)
    
    

def get_subregion_rmse(ccam,emulator,region):
    '''
    Root-mean-square error between `ccam` and `emulator` over a sub-region.

    Parameters
    ----------
    ccam, emulator : objects exposing `.values` (and `.where` for the
        'land'/'ocean' regions) -- in this notebook, fields taken from the
        metric datasets.
    region : 'full'  -> every cell,
             'land'  -> cells where the module-level `land_mask` is > 0,
             'ocean' -> cells where `land_mask` is == 0.

    Returns
    -------
    The square root of `mean_squared_error` over the selected cells.

    Raises
    ------
    ValueError if `region` is not one of the three names above, or if the
    two fields do not flatten to the same shape.

    A cell is used only when BOTH fields are non-NaN there. Dropping NaNs
    from each field separately would pair up values from different cells
    (or give arrays of different length) whenever the two NaN patterns
    differ.
    '''
    if region == 'full':
        ccam_sel, emulator_sel = ccam, emulator
    elif region == 'land' or region == 'ocean':
        region_mask = (land_mask>0) if region == 'land' else (land_mask==0)
        # cells outside the region become NaN and are removed below
        ccam_sel = ccam.where(region_mask)
        emulator_sel = emulator.where(region_mask)
    else:
        raise ValueError(f"unknown region {region!r}; expected 'full', 'land' or 'ocean'")

    ccam_vals = ccam_sel.values.flatten()
    emulator_vals = emulator_sel.values.flatten()
    if ccam_vals.shape != emulator_vals.shape:
        raise ValueError(
            f'ccam and emulator have different sizes: {ccam_vals.shape} vs {emulator_vals.shape}'
        )

    # joint mask keeps the two arrays aligned cell-by-cell
    valid = ~(np.isnan(ccam_vals) | np.isnan(emulator_vals))

    rmse = (mean_squared_error(ccam_vals[valid],emulator_vals[valid]))**(1/2)

    return(rmse)


def write_to_dict(result_dict,full,land,ocean,gcm,emulator,framework,period_name,var,metric,error_type,gan_flag,sampling):
    '''
    Append one result row to `result_dict` (a dict of column-name -> list).

    Each argument after `result_dict` is appended to the list stored under
    its column name; `result_dict` is modified in place and nothing is
    returned. Every column key must already exist in `result_dict`.
    '''
    row = (
        ('whole region', full),
        ('land', land),
        ('ocean', ocean),
        ('GCM', gcm),
        ('emulator', emulator),
        ('framework', framework),
        ('error period', period_name),
        ('var', var),
        ('metric', metric),
        ('error type', error_type),
        ('model_type', gan_flag),
        ('sampling_n', sampling),
    )
    for column, value in row:
        result_dict[column].append(value)


In [4]:

def _select_period_field(metrics_ds,var,metric,period):
    '''Return the field of `metric` in `metrics_ds` that corresponds to `period`.'''
    if period == 'CC_signal':
        return get_CC_signal(metrics_ds,var,metric)
    if period == 'base':
        return metrics_ds[f'{metric}_{base_period[0]}-{base_period[1]}']
    if period == 'future':
        return metrics_ds[f'{metric}_{future_period[0]}-{future_period[1]}']
    raise ValueError(f"unknown period {period!r}; expected 'CC_signal', 'base' or 'future'")


def add_errors(result_dict,gcm,error_type,emulator,CCAM_metrics,ml_metrics,var,metric,period,gan_flag,sampling):
    '''
    Compute the emulator-vs-CCAM error for one metric and append it to `result_dict`.

    One row is written (via `write_to_dict`) for each of the 'perfect' and
    'imperfect' frameworks, holding the error over the whole region, over
    land (`land_mask` > 0) and over ocean (`land_mask` == 0).

    Parameters
    ----------
    result_dict : dict of column-name -> list, modified in place.
    gcm, emulator, gan_flag, sampling : labels copied into the row unchanged.
    error_type : 'MAE' (mean absolute error), 'MAPE' (mean absolute
        percentage error relative to CCAM) or 'RMSE'.
    CCAM_metrics : dataset holding the reference fields.
    ml_metrics : mapping with keys 'perfect' and 'imperfect', each a dataset
        holding the emulator fields.
    var, metric : variable and metric names used to look the fields up.
    period : 'CC_signal' (future minus base, see `get_CC_signal`), 'base'
        or 'future'.

    Raises
    ------
    ValueError if `period` or `error_type` is not one of the values listed
    above. Both are checked before anything is computed or written, so an
    invalid call never leaves a partial row in `result_dict`.
    '''
    if period not in ('CC_signal','base','future'):
        raise ValueError(f"unknown period {period!r}; expected 'CC_signal', 'base' or 'future'")
    if error_type not in ('MAE','MAPE','RMSE'):
        raise ValueError(f"unknown error_type {error_type!r}; expected 'MAE', 'MAPE' or 'RMSE'")

    CCAM_signal = _select_period_field(CCAM_metrics,var,metric,period)

    # label stored in the 'error period' column
    if period == 'CC_signal':
        period_name = period
    elif period == 'base':
        period_name = f'{base_period[0]}_{base_period[1]}'
    else:
        period_name = f'{future_period[0]}_{future_period[1]}'

    for framework in ['perfect','imperfect']:
        ml_signal = _select_period_field(ml_metrics[framework],var,metric,period)

        if error_type == 'RMSE':
            full = get_subregion_rmse(CCAM_signal,ml_signal,'full')
            land = get_subregion_rmse(CCAM_signal,ml_signal,'land')
            ocean = get_subregion_rmse(CCAM_signal,ml_signal,'ocean')

        else:
            if error_type == 'MAE':
                error = abs(ml_signal-CCAM_signal)
            else:
                # 'MAPE'
                # NOTE(review): no guard for cells where CCAM_signal is exactly
                # zero -- confirm that cannot occur in the metric files.
                error = abs(((ml_signal-CCAM_signal)/CCAM_signal)*100)

            full = error.mean().data.tolist()
            land = error.where(land_mask>0).mean().data.tolist()
            ocean = error.where(land_mask==0).mean().data.tolist()

        write_to_dict(result_dict,full,land,ocean,gcm,emulator,framework,period_name,var,metric,error_type,gan_flag,sampling)
        
        

In [5]:
'''
COMPUTE ERRORS FOR TEMPORAL EXPERIMENTS
'''

# One list per output column; add_errors appends one entry to every list per row.
result_dict = {'GCM':[],'emulator':[],'framework':[],'error period':[],'var':[],'metric':[],'error type':[],
               'whole region':[],'land':[],'ocean':[],'sampling_n':[],'model_type':[]}

epoch = 215

for gan_flag in ['GAN','unet']:

    for gcm in ['ACCESS-CM2','NorESM2-MM','EC-Earth3']:

        for var in ['pr']:

            # The CCAM reference file depends only on (gcm, var), so it is opened
            # once per pair instead of once per emulator, and the `with` block
            # closes it when the pair is finished.
            with xr.open_dataset(f'{metric_dir}/{gcm}/CCAM/{gcm}_{var}_metrics.nc') as CCAM_metrics:

                for emulator in temporal_experiments[var]:

                    # last '_'-separated field of the experiment name
                    sampling = emulator.split('_')[-1]

                    ml_metrics = {}
                    try:
                        for framework in ['perfect','imperfect']:
                            ml_metrics[framework] = xr.open_dataset(f'{metric_dir}/{gcm}/{emulator}/{gcm}_{framework}_{var}_metrics_{gan_flag}_epoch_{epoch}.nc')

                        for metric in metrics[var]:

                            for error_type in ['MAE','MAPE','RMSE']:

                                for period in ['CC_signal','base','future']:
                                    add_errors(result_dict,gcm,error_type,emulator,CCAM_metrics,ml_metrics,var,metric,period,gan_flag,sampling)
                    finally:
                        # release the emulator files even if an error is raised above
                        for opened_ds in ml_metrics.values():
                            opened_ds.close()


error_df = pd.DataFrame.from_dict(result_dict)



In [7]:
# NOTE(review): 'exeriment' in the file name looks like a typo for 'experiment';
# it is kept unchanged because other code may read the file under this name.
# The path is relative, so the CSV lands in the current working directory.
errors_csv_name = f'temporal_exeriment_errors_epoch_{epoch}.csv'
error_df.to_csv(errors_csv_name)
