In [None]:
%matplotlib inline

%matplotlib inline
%load_ext autoreload
%autoreload 2

import ERFutils
import dask
import matplotlib.pyplot as plt

import xarray as xr
import pandas as pd
import numpy as np
import cftime
import dask
import xarrayutils
import cartopy.crs as ccrs
from xmip.preprocessing import combined_preprocessing
from xmip.preprocessing import replace_x_y_nominal_lat_lon
from xmip.drift_removal import replace_time
from xmip.postprocessing import concat_experiments
import xmip.drift_removal as xm_dr
import xmip as xm
import xesmf as xe
import datetime
from datetime import timedelta
from dateutil.relativedelta import relativedelta
import cf_xarray as cfxr

import seaborn as sns
import matplotlib as mpl
import cmocean
import cmocean.cm as cmo
from matplotlib.gridspec import GridSpec

from sklearn.linear_model import LinearRegression

import copy
import os

# Let dask split the oversized chunks that array slicing can create, rather than only warning about them.
dask.config.set({'array.slicing.split_large_chunks': True})

# Load data and diagnose patterns

In [None]:
model_set = ERFutils.model_set
A = ERFutils.A
ds_out = ERFutils.ds_out

plot = True
save = False

output_path = ERFutils.path_to_ERF_outputs

# NOTE(review): for SSP training runs only s = 165..249 enters the fit; presumably these
# index the scenario (post-historical) steps — TODO confirm and document what `s` counts.
SSP_S_INDICES = range(165, 250)

# For each training scenario, fit one spatial pattern per model: the zero-intercept
# least-squares slope of every grid point's tas against the A-weighted global-mean tas.
train_id = ['ssp585']
for train in train_id:
    print(f'Loading {train} data.')
    pattern = {}
    tas_CMIP_path = f'{output_path}tas/tas_CMIP_{train}_all_ds.nc4'
    temp_response = xr.open_dataset(tas_CMIP_path)

    # The time selection depends only on the scenario, so do it once here
    # instead of repeating it for every model.
    if 'ssp' in train:
        train_response = temp_response.sel(s = SSP_S_INDICES)
    else:
        train_response = temp_response

    for m in model_set:
        print(f'\t Creating pattern for {m}.')

        model_response = train_response.sel(model = m)
        # Global-mean tas series; A is presumably grid-cell area weights — TODO confirm.
        global_temp = model_response.weighted(A).mean(dim = ['lat','lon']).tas.values
        # Collapse lat/lon into one 'allpoints' dimension so each grid point is a column.
        stacked_response = model_response.stack(allpoints=['lat','lon'])

        N_latlong = stacked_response.sizes['allpoints']
        # Assumes tas is ordered (s, allpoints) after stacking, as the fit below requires.
        stacked_response_np = stacked_response.tas.values

        # A single multi-target fit solves every grid-point column independently, giving the
        # same slopes as one fit per point, without the Python loop and without assigning a
        # (1, 1) coef_ array into a scalar slot (deprecated since NumPy 1.25).
        reg = LinearRegression(fit_intercept=False).fit(global_temp.reshape(-1,1), stacked_response_np)
        pattern_stacked = np.zeros((1,N_latlong))
        pattern_stacked[0, :] = reg.coef_[:, 0]

        # Stacking the same lat/lon coordinates in the same ['lat','lon'] order reproduces
        # the 'allpoints' ordering used for the fit, so the slopes map back to the grid.
        pattern[m] = xr.Dataset(coords={'lon': ('lon', temp_response.lon.values),
                                        'lat': ('lat', temp_response.lat.values)})
        pattern[m] = pattern[m].stack(allpoints=['lat','lon'])
        pattern[m]['pattern'] = ('allpoints',pattern_stacked[0])
        pattern[m] = pattern[m].unstack('allpoints')

    pattern_ds = ERFutils.concat_multirun(pattern, 'model')

    # NOTE(review): per-model plotting is disabled, so `plot` has no effect in this cell.
    # The ensemble-mean cell calls plot_pattern with extra positional arguments; confirm
    # the signature before re-enabling.
    #if plot:
        #ERFutils.plot_pattern(pattern_ds, save_fig = False)

    if save:
        pattern_ds.to_netcdf(f'{output_path}pattern2_{train}_all_ds.nc4')

In [None]:
model_set = ERFutils.model_set
A = ERFutils.A
ds_out = ERFutils.ds_out

plot = True
save = False

output_path = ERFutils.path_to_ERF_outputs

# NOTE(review): for SSP training runs only s = 165..249 enters the fit; presumably these
# index the scenario (post-historical) steps — TODO confirm and document what `s` counts.
SSP_S_INDICES = range(165, 250)

# For each scenario, fit one spatial pattern to the multi-model mean response: the
# zero-intercept least-squares slope of every grid point's tas against the A-weighted
# global-mean tas.
train_id = ['ssp126','ssp245','ssp370','ssp585']
for train in train_id:
    print(f'Loading {train} data.')
    tas_CMIP_path = f'{output_path}tas/tas_CMIP_{train}_all_ds.nc4'
    temp_response = xr.open_dataset(tas_CMIP_path)

    if 'ssp' in train:
        train_response = temp_response.sel(s = SSP_S_INDICES)
    else:
        train_response = temp_response
    # Average over models once; both the global mean and the stacked field derive from it.
    ensemble_response = train_response.mean(dim = 'model')

    # Global-mean tas series; A is presumably grid-cell area weights — TODO confirm.
    global_temp = ensemble_response.weighted(A).mean(dim = ['lat','lon']).tas.values
    # Collapse lat/lon into one 'allpoints' dimension so each grid point is a column.
    stacked_response = ensemble_response.stack(allpoints=['lat','lon'])

    N_latlong = stacked_response.sizes['allpoints']
    # Assumes tas is ordered (s, allpoints) after stacking, as the fit below requires.
    stacked_response_np = stacked_response.tas.values

    # A single multi-target fit solves every grid-point column independently, giving the
    # same slopes as one fit per point, without the Python loop and without assigning a
    # (1, 1) coef_ array into a scalar slot (deprecated since NumPy 1.25).
    reg = LinearRegression(fit_intercept=False).fit(global_temp.reshape(-1,1), stacked_response_np)
    pattern_stacked = np.zeros((1,N_latlong))
    pattern_stacked[0, :] = reg.coef_[:, 0]

    # Stacking the same lat/lon coordinates in the same ['lat','lon'] order reproduces
    # the 'allpoints' ordering used for the fit, so the slopes map back to the grid.
    pattern = xr.Dataset(coords={'lon': ('lon', temp_response.lon.values),
                                 'lat': ('lat', temp_response.lat.values)})
    pattern = pattern.stack(allpoints=['lat','lon'])
    pattern['pattern'] = ('allpoints',pattern_stacked[0])
    pattern = pattern.unstack('allpoints')

    if plot:
        ERFutils.plot_pattern(pattern, 'check','test',save_fig = False)

    # NOTE(review): for train == 'ssp585' this is the same path the per-model ssp585 pattern
    # cell writes to, so enabling `save` in both cells overwrites one result with the other.
    # TODO give the ensemble-mean output a distinct filename once downstream readers are checked.
    if save:
        pattern.to_netcdf(f'{output_path}pattern2_{train}_all_ds.nc4')