In [None]:
%load_ext autoreload
%autoreload 2
from dask.distributed import Client
import dask.array as da
import xarray as xr
import numpy as np
import glob
import scipy.stats as stats
import os
from datetime import datetime, timedelta
import pandas as pd
import sys
import eva_dask
from eva_import return_values, model_fitting, extremes

In [None]:
# Initial inputs
# Sorted so the processing order is reproducible (glob order is arbitrary).
fn_input = sorted(glob.glob("<Paths to input files>"))

# Estimation of annual maxima
dn_maxima = '<Directory to save extremes from individual models>'
fn_maxima = '<Full path to output extremes file>'
concat_dimension = 'year'  # name of the dimension to concatenate along
# Models excluded from the concatenated ensemble (matched against file names)
models_to_omit = ['UKESM1-0-LL', 'EC-Earth3', 'EC-Earth3-Veg']

# Fitting of GEV
fn_gev = '<Path to output model parameters file>'
# Dask chunks for the maxima array: -1 keeps the first axis (presumably the
# concat dimension) in one chunk; the other two axes use tiles of 25.
chunks_for_gev = [-1, 25, 25]

# Return values
fn_rv = '<Path to output return values file>'
input_rl = [20, 25, 30, 32, 35]          # return levels -> return periods are computed
input_rp = [1, 2, 5, 10, 50, 100, 200]   # return periods -> return levels are computed

In [None]:
# Attach to the already-running Dask scheduler, then ship the local
# eva_dask.py module to every worker so tasks can import it.
scheduler_address = "tcp://127.0.0.1:<Dask Cluster Port>"
client = Client(scheduler_address)
client.upload_file('<Path to eva_dask.py>')

In [None]:
# Start from an empty per-model maxima directory: create it if it is missing
# (the annual-maxima cell writes into it) and delete files left over from a
# previous run. Sub-directories are left untouched, since os.remove cannot
# delete them.
os.makedirs(dn_maxima, exist_ok=True)
maxima_files = glob.glob(os.path.join(dn_maxima, '*'))
for ff in maxima_files:
    if os.path.isfile(ff):
        os.remove(ff)

In [None]:
# Delete the concatenated maxima file from a previous run. On a fresh run it
# does not exist yet, which is not an error.
try:
    os.remove(fn_maxima)
except FileNotFoundError:
    pass

In [None]:
# Delete the GEV parameter file from a previous run. On a fresh run it does
# not exist yet, which is not an error.
try:
    os.remove(fn_gev)
except FileNotFoundError:
    pass

In [None]:
# Delete the return-values file from a previous run. On a fresh run it does
# not exist yet, which is not an error.
try:
    os.remove(fn_rv)
except FileNotFoundError:
    pass

### 1. CALCULATE ANNUAL MAXIMA

In [None]:
%%time
n_files = len(fn_input)
dataset_list = [ xr.open_zarr(fn) for fn in fn_input ]

for dd in range(n_files):
    
    dataset = dataset_list[dd]
    
    start_year = int(fn_input[dd][-9:-5])
    start_date = datetime(start_year, 1, 1)
    end_date   = start_date + timedelta(days = dataset.dims['time'] - 1)
    time = pd.date_range(start_date, end_date, freq='1D')
    
    dataset['time'] = time
    
    annual_maxima = extremes.annual_maxima(dataset)
    
    filebase = os.path.basename(fn_input[dd])
    annual_maxima.to_netcdf( os.path.join(dn_maxima, filebase+'.nc') )

### 2. CONCATENATE ANNUAL MAXIMA

In [None]:
# Get filenames of maxima datasets to concatenate. Sorting makes the order of
# models along the concatenation dimension reproducible (glob order is not).
files_to_concat = sorted(glob.glob(os.path.join(dn_maxima, '*')))

# Open each retained dataset lazily and tag every entry along the concat
# dimension with its model name, so that information survives concatenation.
datasets = []
for fn in files_to_concat:
    # NOTE(review): the model name is the file basename minus a fixed
    # 26-character suffix; this assumes every file shares the same suffix
    # length — TODO confirm against the input naming convention.
    model_name = os.path.basename(fn)[:-26]
    if model_name in models_to_omit:
        # Skip omitted models before opening them, so no file handle is leaked.
        print('Removed model by name {0}'.format(model_name))
        continue
    ds_model = xr.open_dataset(fn, chunks={})
    n_entries = ds_model.sizes[concat_dimension]
    # Label along concat_dimension (not a hard-coded 'year') so the two agree.
    ds_model['model_name'] = ([concat_dimension], [model_name] * n_entries)
    datasets.append(ds_model)

if not datasets:
    raise ValueError(
        'No annual maxima files left to concatenate in {0}'.format(dn_maxima))

# Concatenate using xarray and rechunk so the whole concat dimension is in one chunk
datasets_concat = xr.concat(datasets, dim=concat_dimension)
datasets_concat = datasets_concat.chunk({concat_dimension: -1})

# Write to file
datasets_concat.to_netcdf(fn_maxima)

### 3. FIT GEV TO ANNUAL MAXIMA

In [None]:
# Open the concatenated annual maxima. The `Twb` variable's values are pulled
# into local memory and re-wrapped as a dask array chunked with chunks_for_gev,
# ready for the GEV fit.
dataset = xr.open_dataset(fn_maxima, chunks={'lat': 25, 'lon': 25})
data = da.from_array(dataset.Twb.load().data, chunks=chunks_for_gev)

In [None]:
%%time

mapped = model_fitting.fit_gev_model(data, 'genextreme', 5, 100)
mapped = mapped.compute()

In [None]:
# Write GEV fit to file. Rows 0-2 of `mapped` are stored as the parameters
# (shape, loc, scale) and row 3 as `ks_pvalue` (presumably a
# Kolmogorov-Smirnov goodness-of-fit p-value).
gev_params = mapped[:3]
gev_ks_pvalue = mapped[3]

gev_coords = {
    'lat': (['lat'], dataset.lat.values),
    'lon': (['lon'], dataset.lon.values),
    'param': (['param'], ['shape', 'loc', 'scale']),
}
gev_vars = {
    'parameters': (['param', 'lat', 'lon'], gev_params),
    'ks_pvalue': (['lat', 'lon'], gev_ks_pvalue),
}

ds = xr.Dataset(data_vars=gev_vars, coords=gev_coords)
ds.to_netcdf(fn_gev)

### 4. CALCULATE RETURN VALUES

In [None]:
# Re-open the fitted GEV parameters. The parameter array is wrapped as a dask
# array: the first axis ('param' when written above) is kept in one chunk and
# the spatial axes are tiled in 50s.
ds_params = xr.open_dataset(fn_gev)
params = da.from_array(ds_params['parameters'].data).rechunk([-1, 50, 50])

In [None]:
%%time
rl = return_values.return_levels_from_fit(params, stats.genextreme,
                                          input_rp).compute()
rp = return_values.return_periods_from_fit(params, stats.genextreme,
                                           input_rl).compute()

# Extract pvalues array to put into output array
pvalues = ds_params.ks_pvalue.values

In [None]:
# Collect return levels, return periods and fit p-values into one Dataset
# and write it to fn_rv.
rv_coords = {
    'lat': (['lat'], ds_params.lat.values),
    'lon': (['lon'], ds_params.lon.values),
    'return_level': (['return_level'], input_rl),
    'return_period': (['return_period'], input_rp),
}
rv_vars = {
    'calculated_rl': (['return_period', 'lat', 'lon'], rl),
    'calculated_rp': (['return_level', 'lat', 'lon'], rp),
    'pvalues': (['lat', 'lon'], pvalues),
}

ds = xr.Dataset(data_vars=rv_vars, coords=rv_coords)
ds.to_netcdf(fn_rv)

In [None]:
# Release this notebook's connection to the Dask scheduler. The client was
# created with an explicit scheduler address, so this does not shut down the
# externally started cluster.
client.close()