# Compound Events
- https://github.com/e-baumer/standard_precip
- https://github.com/e-baumer/standard_precip/blob/master/standard_precip/base_sp.py
- Handle missing data to calculate spi (last paragraph) https://www.droughtmanagement.info/literature/WMO_standardized_precipitation_index_user_guide_en_2012.pdf#page=9


### SPI:
- Add options
    - freq = ['D', 'M']
    - scale = #
- If freq == 'M': get monthly pr files
- If freq == 'D': get daily pr files
- Concat 3 dfs of pr data:
    - [historical, ssp126]
    - [historical, ssp245]
    - [historical, ssp370]
- Calculate SPI for each model using "from standard_precip.spi import SPI"
- Extract results JJA for 2015-2100 for spi and tasmax
    - Filter by common model/column names
    - Concat spi and tasmax axis=0
    - Concat ssp's axis=1
    - Determine if compound
    - Output 1 df

### Problem:
- KACE-1-0-G is currently skipped because its SPI calculation fails with a division-by-zero error (1057 missing pr values)
- The SPI calculation does not handle NaNs well: 1 row of missing pr results in at least 30 rows of missing SPI

- Why are there missing values? Outputs below
    - KACE-1-0-G: has no value for the 31st day of any month (1057 missing values)
    - All other models listed below are missing 37 values each: no value on 29 February (leap days)
        - [INM-CM4-8, INM-CM5-0, NorESM2-MM, NorESM2-LM, GFDL-ESM4, GISS-E2-1-G, FGOALS-g3, BCC-CSM2-MR, CMCC-ESM2, CESM2]
     
- How to handle missing values? 
    - Check original code
    - Options: interpolation by time, multivariate interpolation (slow), back/front fill

In [1]:
%%time

from process import *
import os
import calendar
import pandas as pd
import warnings
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from standard_precip.spi import SPI

# Suppress every warning for the rest of the session.
warnings.filterwarnings('ignore')

# Matplotlib defaults shared by all figures drawn after this point.
plt.rcParams.update({
    'figure.figsize': (15, 4),
    'figure.dpi': 300,  # 600 is the alternative the author noted
    'font.size': 10,
    'figure.titlesize': 15,
    'axes.linewidth': 0.1,
    'patch.linewidth': 0,
    'grid.linewidth': 0.1,
})

# Run configuration handed to `initialize` (star-imported from `process`).
center = 'LARC'
event = ('CWHE', 'CDHE')[1]  # change the index to switch event type; 1 -> 'CDHE'
months = [6, 7, 8]  # June, July, August (JJA)

# SPI settings: 'D' selects daily pr files, `scale` is the window passed to SPI.
freq = 'D'
scale = 30

initialize(center, event, months, freq, scale)

def read_data(filename):
    """Load a CSV file into a DataFrame with a parsed ``date`` column.

    The unnamed first column that pandas labels ``'Unnamed: 0'`` is renamed
    to ``'date'`` and converted to datetimes.

    Args:
        filename: Path (or anything ``pd.read_csv`` accepts) of the CSV file.

    Returns:
        The loaded DataFrame, with ``date`` holding datetime values.
    """
    frame = pd.read_csv(filename)
    frame = frame.rename(columns={'Unnamed: 0': 'date'})
    frame['date'] = pd.to_datetime(frame['date'])
    return frame
    
# Get filenames: keep only the precipitation ('pr_') files.
pr_files = [file for file in get_files() if 'pr_' in file]

# Exactly one historical file is expected. Collect every match so that both
# "none found" and "several found" produce a clear error instead of a bare
# StopIteration or a silently chosen first match.
historical_files = [file for file in pr_files if 'historical' in file]
if len(historical_files) != 1:
    raise ValueError(
        f'Expected exactly one historical pr file, '
        f'found {len(historical_files)}: {historical_files}'
    )
historical_file = historical_files[0]
ssp_files = sorted(file for file in pr_files if 'ssp' in file)

# Combine historical and ssp dataframes: one frame per scenario, each made of
# the historical rows followed by that scenario's rows.
historical_df = read_data(historical_file)
spi_dfs = {}
for ssp_file in ssp_files:
    # The scenario name is the third '_'-separated token of the file name,
    # e.g. 'LARC_pr_ssp126_daily.csv' -> 'ssp126'.
    ssp_name = os.path.basename(ssp_file).split('_')[2].split('.')[0]
    spi_dfs[ssp_name] = pd.concat([historical_df, read_data(ssp_file)], ignore_index=True)

# Calculate SPI

for ssp_name, df in spi_dfs.items():
    print(f'{ssp_name}:')
    # Result frame starts with the dates; one '<model>_spi' column is added per model.
    spi_df = pd.DataFrame({'date': df['date']})
    spi_options = dict(freq=freq, scale=scale, fit_type='lmom', dist_type='gam')
    # Every column after 'date' is treated as one model's precipitation series.
    # NOTE(review): `calculate` appears to add '<model>_scale_<scale>' columns to
    # `df` in place (they show up in the merged result) -- confirm before passing a copy.
    for col in df.columns[1:]:
        try:
            fitted = SPI().calculate(df, 'date', col, **spi_options)
            spi_df[f'{col}_spi'] = fitted.filter(regex='_calculated_index$')
        except Exception as err:
            # Best effort: report the failing model and carry on with the rest.
            print(f'Error calculating SPI for {col}: {err}\n')
    # `df` is the same object as spi_dfs[ssp_name]; replace it with the merged frame.
    spi_dfs[ssp_name] = df.merge(spi_df, on='date')

ssp126:
Error calculating SPI for KACE-1-0-G: division by zero

ssp245:
Error calculating SPI for KACE-1-0-G: division by zero

ssp370:
Error calculating SPI for KACE-1-0-G: division by zero

CPU times: user 35.7 s, sys: 1.6 s, total: 37.3 s
Wall time: 37.6 s


# Missing values

In [2]:
# Count missing values per column. `ssp_name` still holds whatever value the
# preceding loop left in it, so only that one scenario is inspected.
nan_count = spi_dfs[ssp_name].isna().sum()
# 37 and 1057 are the per-model missing pr counts shown in the output below;
# presumably multiplied by the 30-row window as a rough expected NaN total for SPI.
print(37 * 30, 1057 * 30, '\n')
# Show only the columns with at least `scale` missing values.
display(nan_count.loc[nan_count >= scale])

1110 31710 



INM-CM4-8                  37
INM-CM5-0                  37
NorESM2-MM                 37
NorESM2-LM                 37
GFDL-ESM4                  37
GISS-E2-1-G                37
FGOALS-g3                  37
BCC-CSM2-MR                37
CMCC-ESM2                  37
KACE-1-0-G               1057
CESM2                      37
INM-CM4-8_scale_30       1139
INM-CM5-0_scale_30       1139
NorESM2-MM_scale_30      1139
NorESM2-LM_scale_30      1139
GFDL-ESM4_scale_30       1139
GISS-E2-1-G_scale_30     1139
FGOALS-g3_scale_30       1139
BCC-CSM2-MR_scale_30     1139
CMCC-ESM2_scale_30       1139
KACE-1-0-G_scale_30     31710
CESM2_scale_30           1139
INM-CM4-8_spi            1139
INM-CM5-0_spi            1139
NorESM2-MM_spi           1139
NorESM2-LM_spi           1139
GFDL-ESM4_spi            1139
GISS-E2-1-G_spi          1139
FGOALS-g3_spi            1139
BCC-CSM2-MR_spi          1139
CMCC-ESM2_spi            1139
CESM2_spi                1139
dtype: int64

In [3]:
# Dates on which CMCC-ESM2 has no pr value. The author notes that every model
# missing 37 values yields the same rows.
scenario_df = spi_dfs[ssp_name]
aa = scenario_df.loc[scenario_df['CMCC-ESM2'].isna(), 'date']
print(aa.dt.day.unique(), '\n')
display(aa)

[29] 



789     1952-02-29
2250    1956-02-29
3711    1960-02-29
5172    1964-02-29
6633    1968-02-29
8094    1972-02-29
9555    1976-02-29
11016   1980-02-29
12477   1984-02-29
13938   1988-02-29
15399   1992-02-29
16860   1996-02-29
18321   2000-02-29
19782   2004-02-29
21243   2008-02-29
22704   2012-02-29
24165   2016-02-29
25626   2020-02-29
27087   2024-02-29
28548   2028-02-29
30009   2032-02-29
31470   2036-02-29
32931   2040-02-29
34392   2044-02-29
35853   2048-02-29
37314   2052-02-29
38775   2056-02-29
40236   2060-02-29
41697   2064-02-29
43158   2068-02-29
44619   2072-02-29
46080   2076-02-29
47541   2080-02-29
49002   2084-02-29
50463   2088-02-29
51924   2092-02-29
53385   2096-02-29
Name: date, dtype: datetime64[ns]

In [4]:
# Dates on which KACE-1-0-G has no pr value (the model missing 1057 values).
scenario_df = spi_dfs[ssp_name]
aa = scenario_df.loc[scenario_df['KACE-1-0-G'].isna(), 'date']
print(aa.dt.day.unique(), '\n')
display(aa)

[31] 



30      1950-01-31
89      1950-03-31
150     1950-05-31
211     1950-07-31
242     1950-08-31
           ...    
54937   2100-05-31
54998   2100-07-31
55029   2100-08-31
55090   2100-10-31
55151   2100-12-31
Name: date, Length: 1057, dtype: datetime64[ns]

In [None]:
# for column in spi_dfs[ssp_name].columns[1:]:
#     spi_dfs[ssp_name][column].plot.kde()
#     plt.title(f'Precipitation')
# plt.xlabel('Models')
# plt.ylabel('Density')
# plt.show()

In [None]:
# import pandas as pd
# import numpy as np


# daily = pd.read_csv('../compound/LARC_pr_ssp126_daily.csv').rename(columns={'Unnamed: 0': 'date'})
# monthly = pd.read_csv('../compound/LARC_pr_ssp126_monthly_avg.csv').rename(columns={'Unnamed: 0': 'date'})

# col = 'INM-CM4-8'

# daily_ = SPI().calculate(daily, 'date', col, freq='D', scale=30, 
#                        fit_type='lmom', dist_type='gam').dropna()
# monthly_ = SPI().calculate(monthly, 'date', col, freq='M', scale=1, 
#                        fit_type='lmom', dist_type='gam')#.dropna()
# print('daily spi')
# display(daily_.head(10).iloc[:, [0,2]])
# print('\nmonthly spi')
# display(monthly_.head(10).iloc[:, [0,2]])

In [None]:
# daily_.iloc[2:30]#.mean()

In [None]:
# # 1 month SPI using Gamma and L-moments
# for col in df.columns[1:]:
#     print(col)
#     spi = SPI().calculate(df, 'date', col, freq='D', scale=12, 
#                                  fit_type='lmom', dist_type='gam')

In [None]:
# from standard_precip.utils import plot_index
# plot_index(spi, 'date', 'INM-CM4-8_calculated_index')


In [None]:
# %%time
# from process_base import *
# # import plot as plot
# import pandas as pd

# plotly = False

# CENTERS = ['AMES', 'GSFC', 'JPL', 'KSC', 'MSFC', 'MAF', 'GISS',
#            'LARC', 'SSC', 'GRC', 'WFF', 'JSC', 'WSTF', 'AFRC']

# event = ['CWHE','CDHE'][1]
# months = [6, 7, 8]
# center='LARC'

# initialize(center, event, months)
# threshold = setup_thresholds(event)[0]
# files = {key[0]: [f for f in sorted(get_files()) if any(k in f for k in key)] 
#              for key in [['historical', 'ssp245'], ['ssp']]}

In [None]:
# from process import *

# event = ['CWHE','CDHE'][1]
# months = [6, 7, 8]
# center='LARC'

# ssp_dfs, hist, comp, results, pr, tm = main(center, event, months)

In [None]:
# %%time
# results = defaultdict(lambda: pd.DataFrame())
# for f in files['historical']:
#     name, df = preprocess_file(f, True, threshold)
#     results[name] = pd.concat([results[name], df], axis=0)

# common_columns = set(results[next(iter(results))].columns)
# for name, df in results.items():
#     common_columns.intersection_update(df.columns)
# common_columns = list(common_columns)

In [None]:
# results['pr'][common_columns]

In [None]:
# pd.concat(results.values())

In [None]:
# p = 0.5
# aa = results['pr'].rolling(window=30).mean().dropna(how='all')
# aa.groupby(aa.index.strftime('%m-%d')).quantile(p)

In [None]:
# # Thresholds
# for t, df in hist.items():
#     print(t, '\n', df.mean(), '\n')