In [None]:
%load_ext autoreload
%autoreload 1

%aimport Seasonal_Outliers
import importlib

importlib.reload(Seasonal_Outliers)

import pandas as pd
import os
import matplotlib.pyplot as plt
import numpy as np

import Ipynb_importer


from numba import jit,cuda

from statsmodels.tsa.seasonal import seasonal_decompose
from scipy import stats
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable

from prophet import Prophet
import logging
# Silence the 'cmdstanpy' logger: attach a no-op handler, stop records from
# propagating to ancestor loggers, and drop everything below CRITICAL.
# NOTE(review): presumably this quiets the Stan backend used by Prophet - confirm.
logger = logging.getLogger('cmdstanpy')
logger.addHandler(logging.NullHandler())
logger.propagate = False
logger.setLevel(logging.CRITICAL)

# NOTE: the former `@jit(target_backend='cuda')` decorator was removed.  The body
# is pure pandas I/O and filtering, which numba cannot compile; recent numba
# releases no longer fall back to object mode, so the decorated call fails
# instead of merely running unaccelerated.
def read(dir_,file):
    """Load the turnstile CSV and derive the list of non-working dates.

    Parameters
    ----------
    dir_ : str
        Directory prefix.  It is concatenated as-is with ``'cleaning/'``, so it
        must be empty or end with a path separator.
    file : str
        Name of the CSV file inside ``<dir_>cleaning/``.

    Returns
    -------
    df : pandas.DataFrame
        Rows dated 2021-07-01 .. 2021-10-31, passed through
        ``Data_Cleaning_metro.clean`` and extended with the ``DATE`` and
        ``Day of week`` columns.
    dt_non_work : numpy.ndarray
        Weekend dates found in ``df`` followed by the fixed holiday list.  A
        holiday that falls on a weekend appears twice.
    """
    df = pd.read_csv(dir_+'cleaning/'+file,dtype={'SCP': str,'LINENAME':str})
    # Unparseable timestamps become NaT rather than raising.
    df['DATE_TIME'] = pd.to_datetime(df['DATE_TIME'], errors='coerce')

    df.sort_values(by=['STATION','LINENAME','SCP','DATE_TIME'],inplace=True)

    df['DATE'] = df['DATE_TIME'].dt.date
    dt_range = pd.to_datetime(pd.date_range(start='07/01/2021', end='10/31/2021')).date

    # Copy so later column assignments never target a slice of the unfiltered frame.
    df = df[df["DATE"].isin(dt_range)].copy()
    print(df)
    # NOTE(review): Data_Cleaning_metro is not imported in the visible cells;
    # it must already be importable in the notebook session - confirm.
    df = Data_Cleaning_metro.clean(df)

    df['Day of week'] = df['DATE_TIME'].dt.dayofweek
    # dayofweek > 4 selects Saturday (5) and Sunday (6).
    dt_wkends = df[df['Day of week']>4]['DATE'].unique()
    # Fixed 2021 dates treated as holidays (2021-09-06 was labelled Labor Day
    # in the original comment).
    holiday = pd.to_datetime(pd.Series(['2021-07-04','2021-07-05','2021-09-06','2021-09-16','2021-10-11'])).dt.date.unique()

    dt_non_work = np.append(dt_wkends.tolist(),holiday)

    return df,dt_non_work


# (legend label, (first date label, last date label), edge colour, alpha) for
# every date window outlined on the heatmap, in legend order.
_STORM_WINDOWS = (
    ('Ida', ("2021-09-01", "2021-09-02"), 'blue', 0.5),
    ('Henri', ("2021-08-21", "2021-08-23"), 'gold', 0.8),
    ('Elsa', ("2021-07-08", "2021-07-09"), 'green', 0.5),
    ("Nor'eastern", ('2021-10-25', '2021-10-26'), 'deeppink', 0.5),
)


def _outlier_matrix(outliers, names, col, dt_range):
    """Return a date x name float frame with each outlier's norm_resid, 0 elsewhere."""
    matrix = pd.DataFrame(0.0, index=dt_range, columns=names)
    known_dates = set(matrix.index)
    # Outlier names may have been stringified upstream; match on the text form
    # but write to the real column label.
    column_for = {str(label): label for label in matrix.columns}
    for date, name, resid in zip(outliers['DATE'], outliers[col], outliers['norm_resid']):
        text = str(name)
        if date in known_dates and text in column_for:
            # Direct scalar write; chained `df.loc[d][n] = ...` may hit a copy.
            matrix.at[date, column_for[text]] = resid
    return matrix


def vis(idir,outliers,names,col,title):
    """Draw and save a date x name heatmap of outlier residuals.

    Parameters
    ----------
    idir : str
        Output prefix; the PNG is written to ``<idir>result/<title>.png`` with
        ``:`` replaced by ``_`` and ``/`` by a space in the title.
    outliers : pandas.DataFrame
        Must provide the columns ``DATE``, ``col`` and ``norm_resid``.
    names : sequence
        Column labels of the heatmap (one per station/line).
    col : str
        Name of the column in ``outliers`` that holds the station/line label.
    title : str
        Used for the file name; a title containing ``"Line's"`` selects the
        smaller figure and smaller y tick labels.

    Rows whose date is outside 2021-07-01 .. 2021-10-31 or whose name is not
    in ``names`` are ignored.  Highlight windows whose date labels are absent
    from the axis are skipped.
    """
    dt_range = pd.to_datetime(pd.date_range(start='07/01/2021', end='10/31/2021')).date
    df = _outlier_matrix(outliers, names, col, dt_range)

    figsize = (16, 8) if "Line's" in title else (36, 12)

    fig, ax = plt.subplots(figsize=figsize)
    ax = sns.heatmap(df, yticklabels = df.index,cbar_kws={"pad": 0.05},cmap='bwr',center=0,zorder=1)

    # White separators between the columns.
    for i in range(df.shape[1]+1):
        ax.axvline(i, color='white', lw=1.5,zorder=2)

    x1 = 0
    x2 = len(df.columns)

    ylabels = ax.get_yticklabels()
    y_by_text = {str(lbl.get_text()): lbl.get_position()[1] for lbl in ylabels}

    for label, window, edge, alpha in _STORM_WINDOWS:
        ys = [y_by_text[day] for day in window if day in y_by_text]
        print(ys)
        if not ys:
            continue
        # Tick positions sit at row centres, so widen by half a row each way.
        ax.fill_between([x1,x2],ys[0]-0.5,ys[-1]+0.5, alpha=alpha, color='none',edgecolor=edge,label=label,zorder=3)

    # Show every other date label to keep the axis readable.
    for i, label in enumerate(ylabels):
        label.set_visible(i % 2 == 0)

    if "Line's" in title:
        ax.tick_params(axis='y', labelsize=7)

    ax.legend(ncol=4,loc='upper right',fontsize=18)

    # The last axes on the figure is expected to be the colour bar added by seaborn.
    cax = fig.axes[-1]
    plt.subplots_adjust(left=0.15, right=0.95, bottom=0.15, top=0.9, wspace=0.2, hspace=0.2)
    cax.set_position([.796, .2, .03, .6])
    plt.savefig(idir+'result/'+title.replace(':','_').replace('/',' ')+'.png',format='png',bbox_inches='tight',dpi=300)
    plt.show()
    plt.close('all')

    

    
def find_outliers(idir,df,key, title, min_daily_mean=72):
    """Detect daily outliers per station or per line, plot them and save the forecasts.

    Parameters
    ----------
    idir : str
        Output prefix; results go to ``<idir>result/``.
    df : pandas.DataFrame
        Must provide ``STATION``, ``LINENAME``, ``DATE`` and ``key``.  It is
        not modified.
    key : str
        Column that is summed per day and analysed (e.g. ``'ENTRIES_D'``).
    title : str
        Must contain ``"Station's"`` (group by station + line name) or
        ``"Line's"`` (split ``LINENAME`` on ``'-'`` and group by single line).
        Also used for the output file names.
    min_daily_mean : float, optional
        Series whose mean daily total is below this value are skipped.

    Raises
    ------
    ValueError
        If ``title`` names neither grouping.
    """
    if "Station's" in title:
        # assign() keeps the caller's frame free of the helper column.
        df = df.assign(station_line=df['STATION'].astype(str) + '_Line ' + df['LINENAME'])
        col = 'station_line'
    elif "Line's" in title:
        df = df.assign(LINENAME=df['LINENAME'].str.split('-')).explode('LINENAME')
        col = "LINENAME"
    else:
        raise ValueError("title must contain \"Station's\" or \"Line's\": %r" % (title,))

    names = df[col].unique()
    daily = df.groupby([col,"DATE"])[key].sum().reset_index()

    outlier_frames = []
    forecast_frames = []
    # Group by the bare column name: grouping by a one-element list yields
    # 1-tuple keys on pandas >= 2, and str(('A',)) would never match `names`.
    for nm,group in daily.groupby(col):
        sta_df = group.set_index(["DATE"])[[key]]
        if sta_df[key].mean()<min_daily_mean:
            continue
        tmp,forecast = Seasonal_Outliers.seasonal_de_day(sta_df,key,plot=False)
        if len(tmp)>0:
            out = pd.DataFrame(sta_df.loc[tmp][key])
            out['DATE'] = out.index
            out.reset_index(drop=True,inplace=True)
            out[col] = str(nm)

            forecast[col] = str(nm)
            outlier_frames.append(out[["DATE",col]])
            forecast_frames.append(forecast)

    if not forecast_frames:
        # Nothing to plot or save; the column selection below would fail on an empty frame.
        print('No outliers found for ' + title)
        return

    # One concatenation instead of growing the frames inside the loop.
    outs = pd.concat(outlier_frames,ignore_index=True)
    stas = pd.concat(forecast_frames,ignore_index=True)

    # Copy before adding columns so the assignments do not target a slice of `stas`.
    sst = stas[['ds','norm_resid',col]].copy()
    print(sst,outs)
    sst['ds'] = pd.to_datetime(sst['ds'], errors='coerce')
    sst['DATE'] = sst['ds'].dt.date
    # Attach each outlier's normalised residual by date and name.
    outs = pd.merge(outs,sst,on=['DATE',col])
    print(outs)

    vis(idir,outs,names,col, title)
    stas.to_csv(idir+'result/'+title.replace(':','_').replace('/',' ')+'.csv',index=False)

    
    
def run(idir):
    """Run the outlier search for each station/line x entries/exits combination.

    Reads the notebook-level globals ``df`` (the loaded turnstile frame) and
    ``_title`` (the title prefix); ``idir`` is forwarded as the output prefix.
    """
    analyses = (
        ("Station's Entries", 'ENTRIES_D'),
        ("Station's Exits", 'EXITS_D'),
        ("Line's Entries", 'ENTRIES_D'),
        ("Line's Exits", 'EXITS_D'),
    )

    for label, metric in analyses:
        find_outliers(idir, df, metric, _title + label)
        
        

    
    



In [None]:
# Notebook-level configuration.  `df` and `_title` are read as globals by run().
input_dir = ""  # prefix placed before 'cleaning/' (input) and 'result/' (output)
file = 'diffed_resampled_uniqued_202107-10_Turnstile.csv'
_title = 'Daily Stationarity of '  # prefix of every plot/CSV title
# Load the CSV once; read() returns the cleaned frame and the non-working dates.
df,dt_non_work = read(input_dir,file)

In [None]:
# Re-import Seasonal_Outliers so edits to that module are picked up without
# restarting the kernel, then run the full analysis.
%load_ext autoreload
%autoreload 1
%aimport Seasonal_Outliers
import importlib
importlib.reload(Seasonal_Outliers)

# Writes the PNG heatmaps and CSV forecasts under <input_dir>result/.
run(input_dir)