# Calculate SPI and SPEI indicators from various input datasets
**TODO:** choose a calibration (base) period. Fitting the distributions on the full record is probably not what we want, but restricting the fit to post-1980 data is not good either, because much of the change had already happened by then.
## Datasets

- CRU TS 4.01: monthly precipitation and mean temperature

- ECMWF ERA-Interim

  - The ERA-Interim record is rather short for a base period, so CRU is preferred for now.


In [None]:
%matplotlib inline

from pathlib import Path
from datetime import date

import numba
import rasterio
import numpy as np
import scipy as sp
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs

from scipy import stats
from tqdm import tnrange, tqdm_notebook
from numba import float64, int64, jit, prange


from precipitation_indices import indices, compute
import weather_ecmwf
import population_tools

from config import DATA_SRC, WEATHER_SRC, POP_DATA_SRC

## Apply SPI calculation to CRU monthly precipitation data

In [None]:
# Open all CRU TS 4.01 (1901-2016) variable files as a single lazily loaded
# dataset. Chunking is requested only along lat/lon (30 x 30 cells), so each
# dask chunk presumably holds complete time series for its pixels (assuming
# each file holds the full record); the index calculations below operate
# along 'time'. The 'stn' and 'tmn' variables are not needed and are dropped.
dd = xr.open_mfdataset(
    str(WEATHER_SRC.joinpath('cru/cru_ts4.01.1901.2016.*.dat.nc')),
    chunks={'lat': 30, 'lon': 30},
    drop_variables=['stn', 'tmn'],
    engine='scipy',
    lock=False,
)

In [None]:
def spi3_ufunc(data):
    """Return the 3-month SPI for a single precipitation time series.

    Thin wrapper around ``indices.spi_pearson`` with the scale fixed at
    3 months, so it can be handed to ``xr.apply_ufunc`` with 'time' as the
    only core dimension.

    NOTE(review): the 6-month SPI uses ``indices.spi_gamma`` instead; confirm
    the different fitting routine for the 3-month scale is intended.
    """
    months_scale = 3
    return indices.spi_pearson(data, months_scale)


def spi6_ufunc(data):
    """Return the 6-month SPI for a single precipitation time series.

    Thin wrapper around ``indices.spi_gamma`` with the scale fixed at
    6 months, so it can be handed to ``xr.apply_ufunc`` with 'time' as the
    only core dimension.
    """
    months_scale = 6
    return indices.spi_gamma(data, months_scale)

In [None]:
# SPI-3 for every CRU grid cell. The precipitation comes from the CRU dataset
# `dd` opened above. vectorize=True makes xarray loop spi3_ufunc over each
# pixel's 1-D time series; dask='parallelized' keeps the result lazy until it
# is written out.
spi3 = xr.apply_ufunc(spi3_ufunc, dd.pre,
                      input_core_dims=[['time']],
                      output_core_dims=[['time']],
                      vectorize=True,
                      dask='parallelized',
                      # np.float was only an alias of the builtin float and was
                      # removed in NumPy 1.24; np.float64 is the same dtype.
                      output_dtypes=[np.float64])

In [None]:
# Save SPI-3 under DATA_SRC/lancet. The input dataset was opened with dask
# chunks, so spi3 is lazy and the actual computation runs during this write.
spi3.to_netcdf(DATA_SRC.joinpath('lancet', 'spi3_1901_2016_cru.nc'))

In [None]:
# SPI-6 for every CRU grid cell. The precipitation comes from the CRU dataset
# `dd` opened above. vectorize=True makes xarray loop spi6_ufunc over each
# pixel's 1-D time series; dask='parallelized' keeps the result lazy until it
# is written out.
spi6 = xr.apply_ufunc(spi6_ufunc, dd.pre,
                      input_core_dims=[['time']],
                      output_core_dims=[['time']],
                      vectorize=True,
                      dask='parallelized',
                      # np.float was only an alias of the builtin float and was
                      # removed in NumPy 1.24; np.float64 is the same dtype.
                      output_dtypes=[np.float64])

In [None]:
# Save SPI-6 under DATA_SRC/lancet. The input dataset was opened with dask
# chunks, so spi6 is lazy and the actual computation runs during this write.
spi6.to_netcdf(DATA_SRC.joinpath('lancet', 'spi6_1901_2016_cru.nc'))

## Apply SPI calculation to ECMWF monthly precipitation data

In [None]:
# Intended (per the section heading) to load ECMWF monthly precipitation for
# the SPI calculation.
# NOTE(review): PPT_FOLDER is not defined in this notebook — the config import
# above only brings in DATA_SRC, WEATHER_SRC and POP_DATA_SRC — so this cell
# raises NameError as written. Whatever weather_mfdataset returns is also only
# displayed, not assigned to a name.
# TODO: point this at the ECMWF precipitation folder and keep the result.
weather_ecmwf.weather_mfdataset(PPT_FOLDER)

# Calculate SPEI

## Apply SPEI calculation to CRU monthly precipitation and temperature data

In [None]:
# Reference only: per-pixel form of the SPEI-6 wrapper, usable with
# xr.apply_ufunc(..., vectorize=True). Superseded by the block-wise versions below.
# def spei6_ufunc(precips_mm, temperatures, lat):
#     return indices.spei_gamma(precips_mm, 6, temps_celsius=temperatures, data_start_year=1901, latitude_degrees=lat)

# -------------------------------------------------------
# Manually ufunc'ed versions of SPEI index calculations
# -------------------------------------------------------

@jit(nogil=True, parallel=True)
def _spei_gamma_grid(precips_mm, temperatures, lat, months_scale, data_start_year):
    """Compute SPEI with ``indices.spei_gamma`` for every pixel of a block.

    Parameters
    ----------
    precips_mm, temperatures : ndarray
        3-D arrays indexed ``[i, j, time]``; the series for pixel ``(i, j)``
        is ``[i, j, :]``. They are passed to ``spei_gamma`` as the
        precipitation argument and ``temps_celsius`` respectively.
    lat : ndarray
        Latitudes indexed along the first axis; ``lat[i]`` is passed as
        ``latitude_degrees`` for every pixel in row ``i``.
        NOTE(review): under xr.apply_ufunc, ``lat`` is broadcast against
        (lat, lon) and presumably arrives with shape (n_lat, 1), so ``lat[i]``
        may be a length-1 array rather than a scalar — confirm spei_gamma
        accepts that.
    months_scale : int
        Accumulation scale in months.
    data_start_year : int
        First year of the series, forwarded to ``spei_gamma``.

    Returns
    -------
    ndarray
        Same shape and dtype as ``precips_mm``, holding the SPEI values.
    """
    spei_out = np.empty_like(precips_mm)
    for i in prange(precips_mm.shape[0]):
        latitude_degrees = lat[i]
        # Numba only parallelizes the outermost prange; this inner loop is a
        # plain sequential loop inside each worker.
        for j in prange(precips_mm.shape[1]):
            spei_out[i, j, :] = indices.spei_gamma(
                precips_mm[i, j, :], months_scale,
                temps_celsius=temperatures[i, j, :],
                data_start_year=data_start_year,
                latitude_degrees=latitude_degrees)
    return spei_out


def spei6_ufunc(precips_mm, temperatures, lat, data_start_year=1901):
    """Return the 6-month SPEI for every pixel of a ``[i, j, time]`` block.

    Designed for ``xr.apply_ufunc`` without ``vectorize``: each call receives
    a whole chunk and loops over its pixels in ``_spei_gamma_grid``.
    """
    return _spei_gamma_grid(precips_mm, temperatures, lat, 6, data_start_year)


def spei3_ufunc(precips_mm, temperatures, lat, data_start_year=1901):
    """Return the 3-month SPEI for every pixel of a ``[i, j, time]`` block.

    Designed for ``xr.apply_ufunc`` without ``vectorize``: each call receives
    a whole chunk and loops over its pixels in ``_spei_gamma_grid``.
    """
    return _spei_gamma_grid(precips_mm, temperatures, lat, 3, data_start_year)



In [None]:
# SPEI-6 over the whole CRU grid from precipitation, mean temperature and
# latitude. vectorize is not set, so each dask chunk is handed to spei6_ufunc
# in a single call and the function loops over the pixels itself.
spei6 = xr.apply_ufunc(spei6_ufunc, dd.pre, dd.tmp, dd.lat,
                       input_core_dims=[['time'], ['time'], []],
                       output_core_dims=[['time']],
                       dask='parallelized',
                       # np.float was only an alias of the builtin float and
                       # was removed in NumPy 1.24; np.float64 is the same dtype.
                       output_dtypes=[np.float64])

In [None]:
# Save SPEI-6 under DATA_SRC/lancet. The input dataset was opened with dask
# chunks, so spei6 is lazy and the actual computation runs during this write.
spei6.to_netcdf(DATA_SRC.joinpath('lancet', 'spei6_1901_2016_cru.nc'))

In [None]:
# SPEI-3 over the whole CRU grid from precipitation, mean temperature and
# latitude. vectorize is not set, so each dask chunk is handed to spei3_ufunc
# in a single call and the function loops over the pixels itself.
spei3 = xr.apply_ufunc(spei3_ufunc, dd.pre, dd.tmp, dd.lat,
                       input_core_dims=[['time'], ['time'], []],
                       output_core_dims=[['time']],
                       dask='parallelized',
                       # np.float was only an alias of the builtin float and
                       # was removed in NumPy 1.24; np.float64 is the same dtype.
                       output_dtypes=[np.float64])

In [None]:
# Save SPEI-3 under DATA_SRC/lancet. The input dataset was opened with dask
# chunks, so spei3 is lazy and the actual computation runs during this write.
spei3.to_netcdf(DATA_SRC.joinpath('lancet', 'spei3_1901_2016_cru.nc'))