In [9]:
import xarray as xr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import glob
import os
import matplotlib.patches as mpatches


In [13]:
# Load the error table; its unnamed first column ('Unnamed: 0', presumably a saved index) is discarded.
error_df = (
    pd.read_csv('temporal_experiment_errors_all_epochs.csv')
    .drop(columns=['Unnamed: 0'])
)

In [14]:
# Distinct epoch values present in the error table (in order of first appearance).
pd.unique(error_df['epoch'])

array([120, 125, 130])

In [66]:
# RMSE, 2 metrics (annual_mean, rx1d).
# For each epoch (plus a pooled all-epoch pseudo-epoch) and each error period, plot the
# EC-Earth3 / NorESM2-MM average RMSE of the 'GAN' and 'unet' models against sampling_n.
# One figure per (epoch, period) is saved under RMSE_plots/.

climo_metric_range = {'annual_mean':[0,2],'DJF_mean':[0,2],'JJA_mean':[0,2.3],'rx1d':[10,40]}
cc_metric_range = {'annual_mean':[3,12],'DJF_mean':[5,25],'JJA_mean':[5,15],'rx1d':[15,50]}
markers = ['*','s','o']  # only referenced by the commented-out scatter marker option

region = 'whole region'  # error column that is averaged and plotted

# x positions of the fixed 20-year periods, drawn to the right of the divider at x=150
x_ticks = {'1961-1980':160,'2015-2034':170,'2080-2099':180}

subplot_labels = [['(a)','(b)'],
                  ['(c)','(d)'],
                  ['(e)','(f)'],
                  ['(g)','(h)']]

metrics = ['annual_mean', 'rx1d']
frameworks = ['perfect', 'imperfect']
gcms = ['EC-Earth3', 'NorESM2-MM']
ml_colors = {'GAN': 'orange', 'unet': 'blue'}  # iteration order also fixes legend order
fixed_periods = list(x_ticks)

# Shared x-axis ticks: random sample sizes first, then the fixed periods at their x_ticks positions.
tick_positions = [5, 10, 20, 40, 60, 100, 140, 160, 170, 180]
tick_labels = ['5', '10', '20', '40', '60', '100', '140', '1961-1980', '2015-2034', '2080-2099']

# savefig does not create missing directories.
os.makedirs('RMSE_plots', exist_ok=True)

for epoch in [130,125,120,'120-125-130-average']:

    # The pseudo-epoch '120-125-130-average' pools the rows of every epoch.
    if epoch == '120-125-130-average':
        epoch_df = error_df
    else:
        epoch_df = error_df[error_df['epoch']==epoch]

    for period in ['1985_2004','2080_2099','CC_signal']:

        # squeeze=False keeps axs 2-D even if the metric list shrinks to one row.
        fig, axs = plt.subplots(len(metrics), 2, figsize=(15, 7), sharex=True,
                                squeeze=False, layout='constrained')

        df = epoch_df[(epoch_df['GCM'] != 'ACCESS-CM2')
                      & (epoch_df['error type'] == 'RMSE')
                      & (epoch_df['error period'] == period)]

        # Climate-change-signal errors are labelled in %, climatology errors in mm/day.
        y_range = cc_metric_range if period == 'CC_signal' else climo_metric_range
        y_label = 'RMSE (%)' if period == 'CC_signal' else 'RMSE (mm/day)'

        for i, metric in enumerate(metrics):

            metric_df = df[df['metric'] == metric]

            for j, framework in enumerate(frameworks):

                ax = axs[i][j]
                framework_df = metric_df[metric_df['framework'] == framework]

                # --- data: one curve + fixed-period markers per ML model type ---
                for ml_type, color in ml_colors.items():

                    ml_type_df = framework_df[framework_df['model_type'] == ml_type]

                    # Mean per sampling_n within each GCM, then the unweighted mean across GCMs
                    # (index alignment gives NaN where a GCM lacks a sampling_n value).
                    gcm_dfs = [
                        ml_type_df[ml_type_df['GCM'] == gcm]
                        .groupby('sampling_n').mean(numeric_only=True)[region]
                        for gcm in gcms
                    ]
                    gcm_mean = sum(gcm_dfs) / len(gcm_dfs)

                    is_fixed = gcm_mean.index.isin(fixed_periods)

                    # Random n-year samples: labels are numeric strings, convert and sort by n.
                    results_n = gcm_mean[~is_fixed].copy()
                    results_n.index = results_n.index.astype(int)
                    results_n = results_n.sort_index()

                    results_20 = gcm_mean[is_fixed]

                    # Dashed reference at the random 20-year-sample error, drawn both at n=20
                    # and across the fixed-period region so the two can be compared.
                    random_20_val = results_n.loc[20]
                    ax.hlines(random_20_val, 15, 25, color=color, linestyle='--', alpha=0.6)
                    ax.hlines(random_20_val, 150, 190, color=color, linestyle='--', alpha=0.6)

                    ax.plot(list(results_n.index), results_n, label=ml_type, color=color)

                    for fixed_period, value in results_20.items():
                        ax.scatter(x_ticks[fixed_period], value, s=50, color=color)

                # --- panel decoration: done once per panel, not once per ML model ---
                ax.set_ylim(y_range[metric][0], y_range[metric][1])

                if j == 0:
                    ax.set_ylabel(y_label, fontsize=20)

                panel_name = metric.replace('_', ' ')
                ax.text(0.1, 0.85, subplot_labels[i][j] + ' ' + panel_name, va='center', ha='left',
                        rotation='horizontal', transform=ax.transAxes, fontsize=18)
                ax.axvline(150, color='black')  # divider between random and fixed-period samples
                ax.tick_params(axis='y', labelsize=16)
                ax.set_xlim(0, 190)

                if j == 1:
                    ax.set_yticks([])

                if i == 0:
                    ax.set_title(framework, fontsize=20)

                if i == len(metrics) - 1:
                    # x axis label, placed manually below the tick labels
                    ax.text(0.2, -0.25, 'sample size (years)', va='center', ha='left',
                            rotation='horizontal', transform=ax.transAxes, fontsize=20)

                    ax.set_xticks(tick_positions)
                    ax.set_xticklabels(tick_labels, fontsize=16)

                    # Rotate only the fixed-period labels, which are long.
                    for pos, tick_label in zip(tick_positions, ax.get_xticklabels()):
                        if pos >= 160:
                            tick_label.set_rotation(45)
                            tick_label.set_ha('right')

        axs[0][0].legend(loc='upper right', fontsize=16)

        fig.savefig(f'RMSE_plots/two_GCM_average_RMSE_{period}_epoch_{epoch}_2_metrics.png', dpi=300)
        plt.close(fig)

In [65]:
# RMSE, 4 metrics (annual_mean, DJF_mean, JJA_mean, rx1d).
# For each epoch (plus a pooled all-epoch pseudo-epoch) and each error period, plot the
# EC-Earth3 / NorESM2-MM average RMSE of the 'GAN' and 'unet' models against sampling_n.
# One figure per (epoch, period) is saved under RMSE_plots/.

climo_metric_range = {'annual_mean':[0,2],'DJF_mean':[0,2],'JJA_mean':[0,2.3],'rx1d':[10,40]}
cc_metric_range = {'annual_mean':[3,12],'DJF_mean':[5,25],'JJA_mean':[5,15],'rx1d':[15,50]}
markers = ['*','s','o']  # only referenced by the commented-out scatter marker option

region = 'whole region'  # error column that is averaged and plotted

# x positions of the fixed 20-year periods, drawn to the right of the divider at x=150
x_ticks = {'1961-1980':160,'2015-2034':170,'2080-2099':180}

subplot_labels = [['(a)','(b)'],
                  ['(c)','(d)'],
                  ['(e)','(f)'],
                  ['(g)','(h)']]

metrics = ['annual_mean', 'DJF_mean', 'JJA_mean', 'rx1d']
frameworks = ['perfect', 'imperfect']
gcms = ['EC-Earth3', 'NorESM2-MM']
ml_colors = {'GAN': 'orange', 'unet': 'blue'}  # iteration order also fixes legend order
fixed_periods = list(x_ticks)

# Shared x-axis ticks: random sample sizes first, then the fixed periods at their x_ticks positions.
tick_positions = [5, 10, 20, 40, 60, 100, 140, 160, 170, 180]
tick_labels = ['5', '10', '20', '40', '60', '100', '140', '1961-1980', '2015-2034', '2080-2099']

# savefig does not create missing directories.
os.makedirs('RMSE_plots', exist_ok=True)

for epoch in [130,125,120,'120-125-130-average']:

    # The pseudo-epoch '120-125-130-average' pools the rows of every epoch.
    if epoch == '120-125-130-average':
        epoch_df = error_df
    else:
        epoch_df = error_df[error_df['epoch']==epoch]

    for period in ['1985_2004','2080_2099','CC_signal']:

        # squeeze=False keeps axs 2-D even if the metric list shrinks to one row.
        fig, axs = plt.subplots(len(metrics), 2, figsize=(15, 14), sharex=True,
                                squeeze=False, layout='constrained')

        df = epoch_df[(epoch_df['GCM'] != 'ACCESS-CM2')
                      & (epoch_df['error type'] == 'RMSE')
                      & (epoch_df['error period'] == period)]

        # Climate-change-signal errors are labelled in %, climatology errors in mm/day.
        y_range = cc_metric_range if period == 'CC_signal' else climo_metric_range
        y_label = 'RMSE (%)' if period == 'CC_signal' else 'RMSE (mm/day)'

        for i, metric in enumerate(metrics):

            metric_df = df[df['metric'] == metric]

            for j, framework in enumerate(frameworks):

                ax = axs[i][j]
                framework_df = metric_df[metric_df['framework'] == framework]

                # --- data: one curve + fixed-period markers per ML model type ---
                for ml_type, color in ml_colors.items():

                    ml_type_df = framework_df[framework_df['model_type'] == ml_type]

                    # Mean per sampling_n within each GCM, then the unweighted mean across GCMs
                    # (index alignment gives NaN where a GCM lacks a sampling_n value).
                    gcm_dfs = [
                        ml_type_df[ml_type_df['GCM'] == gcm]
                        .groupby('sampling_n').mean(numeric_only=True)[region]
                        for gcm in gcms
                    ]
                    gcm_mean = sum(gcm_dfs) / len(gcm_dfs)

                    is_fixed = gcm_mean.index.isin(fixed_periods)

                    # Random n-year samples: labels are numeric strings, convert and sort by n.
                    results_n = gcm_mean[~is_fixed].copy()
                    results_n.index = results_n.index.astype(int)
                    results_n = results_n.sort_index()

                    results_20 = gcm_mean[is_fixed]

                    # Dashed reference at the random 20-year-sample error, drawn both at n=20
                    # and across the fixed-period region so the two can be compared.
                    random_20_val = results_n.loc[20]
                    ax.hlines(random_20_val, 15, 25, color=color, linestyle='--', alpha=0.6)
                    ax.hlines(random_20_val, 150, 190, color=color, linestyle='--', alpha=0.6)

                    ax.plot(list(results_n.index), results_n, label=ml_type, color=color)

                    for fixed_period, value in results_20.items():
                        ax.scatter(x_ticks[fixed_period], value, s=50, color=color)

                # --- panel decoration: done once per panel, not once per ML model ---
                ax.set_ylim(y_range[metric][0], y_range[metric][1])

                if j == 0:
                    ax.set_ylabel(y_label, fontsize=20)

                panel_name = metric.replace('_', ' ')
                ax.text(0.1, 0.85, subplot_labels[i][j] + ' ' + panel_name, va='center', ha='left',
                        rotation='horizontal', transform=ax.transAxes, fontsize=18)
                ax.axvline(150, color='black')  # divider between random and fixed-period samples
                ax.tick_params(axis='y', labelsize=16)
                ax.set_xlim(0, 190)

                if j == 1:
                    ax.set_yticks([])

                if i == 0:
                    ax.set_title(framework, fontsize=20)

                if i == len(metrics) - 1:
                    # x axis label, placed manually below the tick labels
                    ax.text(0.2, -0.25, 'sample size (years)', va='center', ha='left',
                            rotation='horizontal', transform=ax.transAxes, fontsize=20)

                    ax.set_xticks(tick_positions)
                    ax.set_xticklabels(tick_labels, fontsize=16)

                    # Rotate only the fixed-period labels, which are long.
                    for pos, tick_label in zip(tick_positions, ax.get_xticklabels()):
                        if pos >= 160:
                            tick_label.set_rotation(45)
                            tick_label.set_ha('right')

        axs[0][0].legend(loc='upper right', fontsize=16)

        fig.savefig(f'RMSE_plots/two_GCM_average_RMSE_{period}_epoch_{epoch}_4_metrics.png', dpi=300)
        plt.close(fig)