### Calculate the correlation between the six NZ regions' time series and some GCM-derived fields

In [1]:
%matplotlib inline 
from matplotlib import pyplot as plt

In [2]:
import sys
import pathlib
from datetime import datetime
from dateutil.relativedelta import relativedelta
from calendar import month_abbr, month_name

In [3]:
import numpy as np 
import pandas as pd
import xarray as xr
import cartopy.crs as ccrs 

### parameters 

In [4]:
# Data source / model, then the predictor (name used in paths, name inside the
# netCDF files), then the regional predictand.
provider, GCM = "CDS", "ECMWF"
varname, varin = "SST", "sst"
target_var = "RAIN"

In [5]:
# user home directory and the notebook's working directory
HOME, CWD = pathlib.Path.home(), pathlib.Path.cwd()

In [6]:
# figures go in a sibling 'figures/<target_var>' directory of the notebook folder
fig_path = CWD.parent / f'figures/{target_var}'

In [7]:
# Create the figure root (and parents) if needed; exist_ok makes this a no-op
# when it already exists, without a separate exists() check.
fig_path.mkdir(parents=True, exist_ok=True)

In [8]:
# Make the project code directory importable (for `ml4seas` below). Only append
# once so re-running this cell does not keep growing sys.path with duplicates.
_ml4seas_code_dir = str(HOME.joinpath("research/Smart_Ideas/code"))
if _ml4seas_code_dir not in sys.path:
    sys.path.append(_ml4seas_code_dir)

In [10]:
from ml4seas.utils import * 
from ml4seas.GCM import shift_dset_time

In [11]:
# directory holding the processed GCM anomaly files for this provider / model / variable
dpath_gcm = pathlib.Path(f"/media/nicolasf/END19101/data/GCMs/processed/{provider}/{GCM}/{varname}")

In [12]:
# All seasonal-anomaly files (relative to the 1981-2010 climatology) for this
# provider / GCM / variable.
lfiles_gcm = list(dpath_gcm.glob(f"{provider}_{GCM}_{varname}_*_seasonal_anomalies_1981_2010_clim.nc"))

# An unmounted drive or a wrong parameter makes glob() silently return nothing;
# fail here with the directory named instead of later in open_mfdataset.
if not lfiles_gcm:
    raise FileNotFoundError(f"no GCM anomaly files found in {dpath_gcm}")

In [13]:
# nested combine concatenates in list order, so put the files in filename order
lfiles_gcm = sorted(lfiles_gcm)

In [14]:
# number of GCM anomaly files matched by the glob
len(lfiles_gcm)

468

In [15]:
# Open all files as one dataset, stacking them along 'time' in list order.
dset_gcm = xr.open_mfdataset(lfiles_gcm, combine='nested', concat_dim='time')

In [16]:
# inspect the combined multi-file dataset
dset_gcm

Unnamed: 0,Array,Chunk
Bytes,0 B,0 B
Shape,"(468, 0, 25, 181, 360)","(1, 0, 25, 181, 360)"
Count,1872 Tasks,468 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 0 B 0 B Shape (468, 0, 25, 181, 360) (1, 0, 25, 181, 360) Count 1872 Tasks 468 Chunks Type float32 numpy.ndarray",,

Unnamed: 0,Array,Chunk
Bytes,0 B,0 B
Shape,"(468, 0, 25, 181, 360)","(1, 0, 25, 181, 360)"
Count,1872 Tasks,468 Chunks
Type,float32,numpy.ndarray


### Calculate the ensemble mean (for now)

In [None]:
# Average over ensemble members.
# NOTE(review): the repr shown above reports shape (468, 0, 25, 181, 360) and 0 B;
# if the zero-length dimension is 'member', this mean is empty/NaN -- check the inputs.
dset_gcm = dset_gcm.mean(dim='member')

In [None]:
# inspect the ensemble-mean dataset
dset_gcm

In [None]:
# Relabel 'time' with each forecast's init_time rolled forward to its month end
# (MonthEnd(0) leaves dates already on a month end unchanged).
dset_gcm['time'] = ('time', dset_gcm.init_time.to_index() + pd.offsets.MonthEnd(0))

In [None]:
# check the new month-end 'time' coordinate
dset_gcm.time

In [None]:
# 'init_time' and 'month' are no longer needed once 'time' carries the dates.
# drop_vars replaces the deprecated Dataset.drop for removing variables.
dset_gcm = dset_gcm.drop_vars(["init_time", "month"])

In [None]:
# inspect the dataset after dropping 'init_time' and 'month'
dset_gcm

### Read the regional time series

In [None]:
# root directory of the per-region target time series for target_var
dpath_regions = HOME.joinpath(f"research/Smart_Ideas/outputs/targets/NZ_regions/NZ_6_regions/{target_var}")

In [None]:
# region codes; each has its own sub-directory under dpath_regions
# (presumably North/West/East North Island and North/West/East South Island -- confirm)
list_regions = [
    "NNI", "WNI", "ENI",
    "NSI", "WSI", "ESI",
]

In [None]:
# Load the 'anomalies' column of every region's CSV, renamed to the region code,
# into a list of single-column DataFrames (concatenated in the next cell).
df = []
for region in list_regions:
    csv_file = dpath_regions.joinpath(f"{region}/TS_NZ_region_{region}_{target_var}_3_quantiles_anoms.csv")
    anoms = pd.read_csv(csv_file, index_col=0, parse_dates=True)[['anomalies']]
    df.append(anoms.rename(columns={'anomalies': region}))

In [None]:
# one column per region, aligned on the date index
df = pd.concat(df, axis="columns")

In [None]:
# z-score each region column (pandas std uses ddof=1 by default)
df = df.sub(df.mean(axis=0), axis="columns").div(df.std(axis=0), axis="columns")

### Cast the DataFrame into an xarray Dataset

In [None]:
# one data variable per region, indexed by the DataFrame's date index
df_xr = xr.Dataset.from_dataframe(df)

**The steps are 2, 3, 4 and 5:**

+ 2 = lead times 0,1,2  
+ 3 = lead times 1,2,3 ***    
+ 4 = lead times 2,3,4   
+ 5 = lead times 3,4,5  

In [None]:
# reminder of the GCM dataset structure before the correlation loops
dset_gcm

In [None]:
# Correlation maps between each region's standardised series and the GCM field:
# one figure per (region, step), saved under fig_path/<region>/<varname>/.
steps = [5, 4, 3, 2]

# The slice and time shift depend only on `step`, so do them once per step
# instead of once per (region, step) pair.
# NOTE(review): the markdown above maps step 2 -> lead times 0,1,2 and step 3 ->
# lead times 1,2,3, while the time axis is shifted by `step - 1` months here;
# confirm which lead-time convention the `step` coordinate follows.
dset_steps = {}
for step in steps:
    dset_step = dset_gcm.sel(step=step)
    dset_step['time'] = dset_step['time'].to_index().shift(periods=(step - 1), freq='M')
    dset_steps[step] = dset_step

for region in list_regions:

    # output directory for this region / predictor variable
    fpath = fig_path.joinpath(f"{region}/{varname}")
    fpath.mkdir(parents=True, exist_ok=True)

    for step in steps:

        # Inner join on 'time' (xr.align default): keep only dates present in both.
        # Named `target`, not `df`, so the pandas DataFrame `df` from earlier cells
        # is not overwritten (rebinding it would break re-running those cells).
        dset, target = xr.align(dset_steps[step], df_xr)

        # dset_gcm comes from open_mfdataset (dask-backed), so xr.corr is lazy;
        # compute once instead of once per plotting call below.
        R = xr.corr(target[region], dset[varin], dim='time').compute()

        f, ax = plt.subplots(figsize=(10, 12), subplot_kw={'projection': ccrs.PlateCarree(central_longitude=180)})

        R.plot.contourf(transform=ccrs.PlateCarree(), levels=np.arange(-1, 1.1, 0.1),
                        cbar_kwargs={'orientation': 'horizontal', 'pad': 0.05, 'label': 'R'})

        # outline the R = +0.5 (red) and R = -0.5 (blue) contours
        R.plot.contour(transform=ccrs.PlateCarree(), levels=[0.5], linewidths=2, colors='r', linestyles='-')
        R.plot.contour(transform=ccrs.PlateCarree(), levels=[-0.5], linewidths=2, colors='b', linestyles='-')

        ax.coastlines(resolution='50m')

        gl = ax.gridlines(draw_labels=True, linestyle=':', xlocs=np.arange(-180, 180, 40), crs=ccrs.PlateCarree())
        gl.top_labels = False
        gl.right_labels = False
        gl.xlabel_style = {'size': 15, 'color': 'gray'}
        gl.ylabel_style = {'size': 15, 'color': 'gray'}

        ax.set_title(f"correlation field between {region} {target_var} and {GCM} {varname}\nleadtime (Months) = {step}", fontsize=15, color='gray')

        f.savefig(fpath.joinpath(f"R_{region}_{target_var}_{varname}_step_{step}_{GCM}.png"), dpi=200, bbox_inches='tight', facecolor='w')

        # release the figure so memory does not grow across the loop
        plt.close(f)

### Now loop over the regions and the season-ending months

In [None]:
# Correlation maps restricted to seasons ending in a given calendar month:
# one figure per (region, month), with one panel per step.
steps = [5, 4, 3, 2]

# The slice and time shift depend only on `step`; do them once per step rather
# than for every (region, month, step).
# NOTE(review): the markdown above maps step 2 -> lead times 0,1,2 and step 3 ->
# lead times 1,2,3, while the time axis is shifted by `step - 1` months here;
# confirm which lead-time convention the `step` coordinate follows.
dset_steps = {}
for step in steps:
    dset_step = dset_gcm.sel(step=step)
    dset_step['time'] = dset_step['time'].to_index().shift(periods=(step - 1), freq='M')
    dset_steps[step] = dset_step

for region in list_regions:

    # output directory for this region / predictor variable
    fpath = fig_path.joinpath(f"{region}/{varname}")
    fpath.mkdir(parents=True, exist_ok=True)

    for month in range(1, 13):

        # regional series restricted to seasons ending in `month` (same for every step)
        target_month = df_xr.sel(time=(df_xr.time.dt.month == month))

        f, axes = plt.subplots(nrows=len(steps), figsize=(8, 15), subplot_kw={'projection': ccrs.PlateCarree(central_longitude=180)})

        axes = axes.ravel()

        for ax, step in zip(axes, steps):

            dset = dset_steps[step]

            # select month
            dset = dset.sel(time=(dset.time.dt.month == month))

            # Align with the month-filtered series. The original aligned with the
            # full df_xr and was only correct because the default join is 'inner'.
            # Named `target`, not `df`, so the pandas DataFrame `df` is not overwritten.
            dset, target = xr.align(dset, target_month)

            # dset_gcm is dask-backed (open_mfdataset), so xr.corr is lazy;
            # compute once instead of once per plotting call below.
            R = xr.corr(target[region], dset[varin], dim='time').compute()

            R.plot.contourf(transform=ccrs.PlateCarree(), levels=np.arange(-1, 1.1, 0.1),
                            cbar_kwargs={'orientation': 'vertical', 'pad': 0.01, 'label': 'R'}, ax=ax)

            # outline the R = +0.5 (red) and R = -0.5 (blue) contours
            R.plot.contour(transform=ccrs.PlateCarree(), levels=[0.5], linewidths=2, colors='r', linestyles='-', ax=ax)
            R.plot.contour(transform=ccrs.PlateCarree(), levels=[-0.5], linewidths=2, colors='b', linestyles='-', ax=ax)

            ax.set_title(f"{region} {target_var} and {GCM} {varname}\nSeason ending: {month_name[month]}, leadtime (Months): {step}", fontsize=8, color='k')

            ax.coastlines(resolution='50m')

        f.savefig(fpath.joinpath(f"R_{region}_{target_var}_{varname}_month_{str(month).zfill(2)}_{GCM}.png"), dpi=200, bbox_inches='tight', facecolor='w')

        # release the figure so memory does not grow across the loop
        plt.close(f)

In [None]:
# root directory under which the figures above were saved
fig_path