In [None]:
%load_ext autoreload
%autoreload 1

%aimport Seasonal_Outliers
import importlib

importlib.reload(Seasonal_Outliers)

import pandas as pd
import os
import matplotlib.pyplot as plt
import numpy as np

import Ipynb_importer


from numba import jit,cuda

from statsmodels.tsa.seasonal import seasonal_decompose
from scipy import stats
import seaborn as sns

from prophet import Prophet
import logging
# Mute the 'cmdstanpy' logger: only CRITICAL records pass, nothing is
# forwarded to ancestor loggers, and a NullHandler is attached so the
# logger always has a handler of its own.
# NOTE(review): presumably this quiets the Stan backend output produced
# when Prophet models are fitted -- confirm.
logger = logging.getLogger('cmdstanpy')
logger.setLevel(logging.CRITICAL)
logger.propagate = False
logger.addHandler(logging.NullHandler())

def read(dir_,file):
    """Load the cleaned turnstile CSV and keep only July-October 2021 rows.

    The file is opened at ``dir_ + 'cleaning/' + file``; ``dir_`` is joined
    by plain string concatenation, so it must be empty or end with a path
    separator.

    Parameters
    ----------
    dir_ : str
        Base directory prefix (may be ``""``).
    file : str
        CSV file name inside the ``cleaning/`` sub-directory. It must provide
        at least the columns STATION, LINENAME, SCP and DATE_TIME.

    Returns
    -------
    df : pandas.DataFrame
        Rows whose DATE_TIME date lies in 2021-07-01 .. 2021-10-31, sorted by
        STATION, LINENAME, SCP, DATE_TIME, with two added columns: ``DATE``
        (``datetime.date``) and ``Day of week`` (0 = Monday .. 6 = Sunday).
        Unparseable timestamps become NaT and are dropped by the date filter.
    dt_non_work : numpy.ndarray
        Object array of ``datetime.date``: the weekend dates present in
        ``df`` followed by the hard-coded 2021 holiday dates.
    """
    # NOTE: this function is deliberately NOT wrapped in numba.jit. It is
    # pure pandas I/O, which numba cannot compile, and ``target_backend`` is
    # not a documented option of ``numba.jit``.
    df = pd.read_csv(dir_+'cleaning/'+file,dtype={'SCP': str,'LINENAME':str})
    df['DATE_TIME'] = pd.to_datetime(df['DATE_TIME'], errors='coerce')
    df = df.sort_values(by=['STATION','LINENAME','SCP','DATE_TIME'])

    df['DATE'] = df['DATE_TIME'].dt.date
    dt_range = pd.date_range(start='07/01/2021', end='10/31/2021').date

    # .copy() so the column added below is written to an independent frame
    # rather than to a slice of the unfiltered one.
    df = df[df['DATE'].isin(dt_range)].copy()
    print(df.head())

    df['Day of week'] = df['DATE_TIME'].dt.dayofweek
    dt_wkends = df.loc[df['Day of week']>4,'DATE'].unique()
    # Hard-coded 2021 non-working dates (2021-09-06 is Labor Day).
    # NOTE(review): the remaining dates look like other 2021 holidays --
    # confirm the intended calendar.
    holiday = pd.to_datetime(pd.Series(['2021-07-04','2021-07-05','2021-09-06','2021-09-16','2021-10-11'])).dt.date.unique()

    dt_non_work = np.append(dt_wkends.tolist(),holiday)

    return df,dt_non_work


def _residual_grid(outliers,names,col):
    """Return a (name x 4-hour slot) float matrix of outlier residuals.

    Rows are ``names``; columns are every 4-hour timestamp from
    2021-07-01 00:00 up to 2021-10-31 23:59:59. Cells default to 0.0 and are
    set to ``norm_resid`` for each row of ``outliers`` whose ``col`` value
    (compared as a string) names a row and whose ``DATE_TIME`` is one of the
    grid columns. Outliers that fall outside the grid are skipped.
    """
    start = pd.to_datetime('07/01/2021 00:00:00')
    end = pd.to_datetime('10/31/2021 23:59:59')
    # A Timedelta frequency avoids the deprecated '4H' alias.
    slots = pd.date_range(start=start, end=end, freq=pd.Timedelta(hours=4))
    grid = pd.DataFrame(0.0, index=names, columns=slots)

    row_by_label = {str(name): name for name in grid.index}
    records = zip(outliers[col].astype(str), outliers['DATE_TIME'], outliers['norm_resid'])
    for label, when, resid in records:
        if label in row_by_label and when in grid.columns:
            # One .loc call with (row, column): chained ``df.loc[r][c] = v``
            # may write to a temporary copy and leave the frame untouched.
            grid.loc[row_by_label[label], when] = resid
    return grid


def vis(idir,outliers,names,col,title):
    """Draw, save and show a heatmap of outlier residuals with storm windows.

    Parameters
    ----------
    idir : str
        Output prefix; the figure is written to
        ``idir + 'result/' + <sanitised title> + '.png'``.
    outliers : pandas.DataFrame
        Must contain the columns ``col``, ``DATE_TIME`` and ``norm_resid``.
    names : sequence
        Row labels of the heatmap (one row per station/line).
    col : str
        Column of ``outliers`` holding the row label.
    title : str
        Figure title; a title containing "Line's" selects the smaller figure
        size and smaller x tick labels.
    """
    df = _residual_grid(outliers,names,col)

    figsize = (50, 8) if "Line's" in title else (70, 30)
    fig, ax = plt.subplots(figsize=figsize)
    ax = sns.heatmap(df, cmap='bwr', cbar_kws={"pad": 0.05},center=0, xticklabels = df.columns.strftime('%Y-%m-%d,%H'),zorder=1)

    # White separators between rows.
    for i in range(df.shape[0]+1):
        ax.axhline(i, color='white', lw=1.5,zorder=2)

    # (legend label, first slot, last slot, edge colour, alpha)
    storms = (
        ('Ida', '2021-09-01', '2021-09-03', 'blue', 0.5),
        ('Henri', '2021-08-21', '2021-08-24', 'gold', 0.8),
        ('Elsa', '2021-07-08', '2021-07-10', 'green', 0.5),
        ("Nor'eastern", '2021-10-25', '2021-10-27', 'deeppink', 0.5),
    )
    n_rows = len(df)
    for storm, first, last, colour, alpha in storms:
        # Heatmap cell i spans x in [i, i + 1]; these are the same edges the
        # tick-centre -/+ 0.5 arithmetic gives, taken from the column index
        # instead of matching tick-label text.
        x_left = df.columns.get_loc(pd.Timestamp(first))
        x_right = df.columns.get_loc(pd.Timestamp(last)) + 1
        ax.fill_between([x_left,x_right],0,n_rows, alpha=alpha, color='none',edgecolor=colour, label=storm, zorder=3,linewidth=3)

    # Show every other x tick label only.
    for i, label in enumerate(ax.get_xticklabels()):
        label.set_visible(i % 2 == 0)
    if "Line's" in title:
        ax.tick_params(axis='x', labelsize=7)
    ax.set_title(title)
    ax.legend(ncol=4,loc='upper right',fontsize=40)

    # The colour bar is the last axes of the figure; reposition it by hand.
    cax = fig.axes[-1]
    plt.subplots_adjust(left=0.15, right=0.95, bottom=0.15, top=0.9, wspace=0.2, hspace=0.2)
    cax.set_position([.796, .2, .03, .6])
    fig.savefig(idir+'result/'+title.replace(':','_').replace('/',' ')+'.png',format='png',bbox_inches='tight',dpi=300)
    plt.show()
    plt.close(fig)
    plt.close('all')



    
def find_outliers(idir,df,key, title, min_mean=12):
    """Detect hourly outliers per station or per line, then plot and export.

    ``title`` selects the grouping: a title containing "Station's" groups by
    ``STATION`` + ``'_Line '`` + ``LINENAME``; one containing "Line's" splits
    ``LINENAME`` on ``'-'`` and groups by each individual line.

    For every group the values of ``key`` are summed per DATE_TIME and passed
    to ``Seasonal_Outliers.seasonal_de_hour``. Outlier timestamps are joined
    with the returned forecast's ``norm_resid`` and handed to ``vis``; the
    concatenated forecasts are written to
    ``idir + 'result/' + <sanitised title> + '.csv'``.

    Parameters
    ----------
    idir : str
        Output prefix (also forwarded to ``vis``).
    df : pandas.DataFrame
        Turnstile data with STATION, LINENAME, DATE_TIME and ``key`` columns.
        It is not modified.
    key : str
        Name of the value column to analyse.
    title : str
        Plot title / output file stem; must contain "Station's" or "Line's".
    min_mean : float, optional
        Groups whose mean summed value is below this are skipped
        (default 12).

    Raises
    ------
    ValueError
        If ``title`` contains neither "Station's" nor "Line's".
    """
    if "Station's" in title:
        # assign() returns a new frame, so the caller's df gets no extra column.
        df = df.assign(station_line=df['STATION'].astype(str) + '_Line ' + df['LINENAME'])
        col = 'station_line'
    elif "Line's" in title:
        df = df.assign(LINENAME=df['LINENAME'].str.split('-')).explode('LINENAME')
        col = "LINENAME"
    else:
        raise ValueError('title must contain "Station\'s" or "Line\'s": ' + repr(title))

    totals = df.groupby([col,"DATE_TIME"])[key].sum().reset_index()
    names = df[col].unique()

    flagged = []
    forecasts = []
    # Group by the bare column name: with a one-element list pandas >= 2
    # yields 1-tuple keys, and str(('X',)) would never match ``names``.
    for nm,grp in totals.groupby(col):
        sta_df = grp.set_index("DATE_TIME")[[key]]
        if sta_df[key].mean()<min_mean:
            continue
        tmp,forecast = Seasonal_Outliers.seasonal_de_hour(sta_df,key,plot=False)
        if len(tmp)>0:
            flagged.append(pd.DataFrame({'DATE_TIME': sta_df.loc[tmp].index, col: str(nm)}))
            forecast[col] = str(nm)
            forecasts.append(forecast)

    if not forecasts:
        # Nothing to merge, plot or export.
        print('No outliers found for '+title)
        return

    # Concatenate once instead of growing a frame inside the loop.
    outs = pd.concat(flagged,ignore_index=True)
    stas = pd.concat(forecasts,ignore_index=True)

    # Work on an explicit copy so the added columns do not touch ``stas``.
    sst = stas[['ds','norm_resid',col]].copy()
    print(sst,outs)
    sst['ds'] = pd.to_datetime(sst['ds'], errors='coerce')
    sst['DATE_TIME'] = sst['ds']
    outs = pd.merge(outs,sst,on=['DATE_TIME',col])
    print(outs)

    vis(idir,outs,names,col, title)
    stas.to_csv(idir+'result/'+title.replace(':','_').replace('/',' ')+'.csv',index=False)

    
    
def run(idir):
    """Run the outlier analysis for station/line entries and exits.

    Reads the resampled turnstile file through ``read`` and calls
    ``find_outliers`` four times, once per (grouping, value column) pair,
    with a title of the form ``'Hourly Stationarity of <label>'``.

    Parameters
    ----------
    idir : str
        Directory prefix handed to ``read`` and ``find_outliers``.
    """
    source_csv = 'diffed_resampled_uniqued_202107-10_Turnstile.csv'
    title_prefix = 'Hourly Stationarity of '

    # The non-working dates returned by read() are not used here.
    df, _non_working = read(idir, source_csv)

    # (title suffix, value column), processed in this order.
    jobs = (
        ("Station's Entries", 'ENTRIES_D'),
        ("Station's Exits", 'EXITS_D'),
        ("Line's Entries", 'ENTRIES_D'),
        ("Line's Exits", 'EXITS_D'),
    )
    for label, column in jobs:
        find_outliers(idir, df, column, title_prefix + label)
        
        

    
    
    


# Entry point: run the whole analysis when executed as a script or cell.
if __name__=='__main__':
    
    # Empty prefix: the 'cleaning/' and 'result/' paths built from it are
    # relative to the current working directory.
    input_dir = ""
    run(input_dir)
