In [36]:
import xarray as xr
import pandas as pd
import numpy as np
from sklearn.metrics import mean_squared_error
import sys

In [37]:
def clip_to_common_dates(emulator_da,ccam_da):
    
    # Convert time coords to 'YYYY-MM-DD' strings
    dates1_str = pd.to_datetime(emulator_da.time.values).strftime('%Y-%m-%d')
    dates2_str = [t.isoformat()[:10] for t in ccam_da.time.values]

    # Find common dates
    common_dates = set(dates1_str).intersection(dates2_str)

    # Create masks
    mask1 = [d.strftime('%Y-%m-%d') in common_dates for d in pd.to_datetime(emulator_da.time.values)]
    mask2 = [t.isoformat()[:10] in common_dates for t in ccam_da.time.values]

    # Apply masks
    da1_clipped = emulator_da.sel(time=emulator_da.time.values[mask1])
    da2_clipped = ccam_da.sel(time=ccam_da.time.values[mask2])

    return da1_clipped, da2_clipped


def get_CCAM_ds(rmse_years):
    
    print('getting CCAM_ds')
    
    CCAM_downscaled_ds = xr.open_dataset(CCAM_dir + 'target_fields/target_fields_hist_ssp370_concat.nc')
    CCAM_downscaled_ds = CCAM_downscaled_ds.where(CCAM_downscaled_ds.time.dt.year.isin(rmse_years),drop=True)

    return(CCAM_downscaled_ds)


def calc_rmse(ccam,emulator,region):
    
    if region == 'full':
        rmse = (mean_squared_error(ccam.values.flatten(),emulator.values.flatten()))**(1/2)
        
    elif region == 'land':
        ccam = ccam.where(land_mask>0).values.flatten()
        ccam = ccam[~np.isnan(ccam)]
        emulator = emulator.where(land_mask>0).values.flatten()
        emulator = emulator[~np.isnan(emulator)]
        
        rmse = (mean_squared_error(ccam,emulator))**(1/2)
        
    elif region == 'ocean':
        ccam = ccam.where(land_mask==0).values.flatten()
        ccam = ccam[~np.isnan(ccam)]
        emulator = emulator.where(land_mask==0).values.flatten()
        emulator = emulator[~np.isnan(emulator)]
        
        rmse = (mean_squared_error(ccam,emulator))**(1/2)
        
    return(rmse)

def get_rmse(GCM,var,emulator,gan_flag,framework,CCAM_downscaled_da,years):
        
    # Emulator - imperfect
    if framework == 'imperfect':
        imperfect_hist = xr.open_dataset(f'{output_dir}/{GCM}/{emulator}/{GCM}_{var}_historical_imperfect_framework_{gan_flag}.nc')[var]
        imperfect_future = xr.open_dataset(f'{output_dir}/{GCM}/{emulator}/{GCM}_{var}_ssp370_imperfect_framework_{gan_flag}.nc')[var]
        emulator_da = xr.concat([imperfect_hist,imperfect_future],dim='time')

    # Emulator - perfect
    elif framework == 'perfect':
        emulator_da = xr.open_dataset(f'{output_dir}/{GCM}/{emulator}/{GCM}_{var}_ssp370_perfect_framework_{gan_flag}.nc')[var]

    if var == 'tasmin' or var == 'tasmax':
        emulator_da = emulator_da-272.15
        
    print(type(emulator_da.time.values[0]))
        
    emulator_da = emulator_da.where(emulator_da.time.dt.year.isin(years),drop=True)

    #common_emulator,common_ccam = clip_to_common_dates(emulator_da,CCAM_downscaled_da)
    
    print(emulator_da.shape,CCAM_downscaled_da.shape)
    
    full,land,ocean = calc_rmse(CCAM_downscaled_da,emulator_da,'full'),calc_rmse(CCAM_downscaled_da,emulator_da,'land'),calc_rmse(CCAM_downscaled_da,emulator_da,'ocean')
    
    return(full,land,ocean)


In [21]:
'''
MAIN
'''

static_ds = xr.open_dataset('/nesi/project/niwa00018/ML_downscaling_CCAM/training_GAN/ancil_fields/ERA5_eval_ccam_12km.198110_NZ_Invariant.nc')
land_mask = static_ds.sftlf

# directories
output_dir = '/nesi/project/niwa00018/queenle/ml_emulator_experiment_application/ml_downscaled_output/'
CCAM_dir = '/nesi/project/niwa00018/ML_downscaling_CCAM/multi-variate-gan/inputs/'

# TEMPORAL TRAINING EXPERIMENTS
temporal_experiments = {'pr':['pr_ACCESS-CM2_5','pr_ACCESS-CM2_10','pr_ACCESS-CM2_20','pr_ACCESS-CM2_30','pr_ACCESS-CM2_40','pr_ACCESS-CM2_50','pr_ACCESS-CM2_60','pr_ACCESS-CM2_70','pr_ACCESS-CM2_80','pr_ACCESS-CM2_90']}#,'pr_ACCESS-CM2_100']}
        
rmse_years = [1965,1998,2034,2065,2097]
years_str = '/'.join([str(val) for val in rmse_years])

#CCAM_downscaled_ds = get_CCAM_ds(rmse_years)
with open('rmse.csv', 'w') as file:
    
    file.write('GCM,emulator,framework,error period,var,metric,whole region,land,ocean,model_type\n')
    
    for i,gan_flag in enumerate(['GAN','unet']):
        print(f'-{i+1}/2')
        for j,gcm in enumerate(['ACCESS-CM2','NorESM2-MM','EC-Earth3']):
            print(f'\t-{j+1}/3')

            result_dict = {}


            for var in ['pr']:#'tasmax','tasmin','sfcwind']:

                # CCAM dynamical downscaled output
                CCAM_downscaled_da = CCAM_downscaled_ds.sel(GCM=gcm)[var]
                if var == 'pr':
                    CCAM_downscaled_da = CCAM_downscaled_da*86400 # convert from flux to mm/day
                if var == 'tasmin' or var == 'tasmax':
                    CCAM_downscaled_da = CCAM_downscaled_da-272.15

                for k,emulator in enumerate(temporal_experiments[var]):
                    print(f'\t\t-{k+1}/{len(temporal_experiments[var])}')

                    for framework in ['perfect','imperfect']:

                        full,land,ocean = get_rmse(gcm,var,emulator,gan_flag,framework,CCAM_downscaled_da,rmse_years)
                        
                        print('writing line')
                        file.write(f'{gcm},{emulator},{framework},{years_str},{var},rmse,{str(full)},{str(land)},{str(ocean)},{gan_flag}\n')
                                   

-1/2
	-1/3
		-1/10
<class 'numpy.datetime64'>


TypeError: cannot unpack non-iterable NoneType object