# In [None]:
import pandas as pd
import geopandas as gpd
import os
import contextily as cx
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Patch
import gma.extend.mapplottools as mpt
import matplotlib as mpl

def join_dict(fi, dir, jdict):
    """Fill *jdict* with the join keys and the origin/destination flag of a file.

    The dataset is recognised from substrings of the directory path *dir*,
    the direction from substrings of the file name *fi*:

    * ``'citibike'`` in *dir*: ``k1='name'``; ``'ends'`` in *fi* selects
      ``end_station_name`` / flag 1, otherwise ``start_station_name`` / flag 0.
    * ``'TLC'`` in *dir*: ``k1='location_i'``; ``'Dropoffs'`` in *fi* selects
      ``DOLocationID`` / flag 1, otherwise ``PULocationID`` / flag 0.
    * ``'Turnstile'`` in *dir*: ``k1='sta_line'``, ``k2='station_line'``;
      ``'Entries'`` in *fi* gives flag 0, anything else flag 1.

    As used by ``att_join``, ``k1`` names the key column of the shapefile and
    ``k2`` the key column of the CSV. The rules are checked in the order
    listed and each match overwrites the previous one, so the last matching
    rule wins. If no rule matches, *jdict* is returned unchanged.

    Returns the same (mutated) *jdict* object.
    """
    matches = []
    if 'citibike' in dir:
        is_end = 'ends' in fi
        matches.append(('name', 'end_station_name' if is_end else 'start_station_name', int(is_end)))
    if 'TLC' in dir:
        is_drop = 'Dropoffs' in fi
        matches.append(('location_i', 'DOLocationID' if is_drop else 'PULocationID', int(is_drop)))
    if 'Turnstile' in dir:
        matches.append(('sta_line', 'station_line', 0 if 'Entries' in fi else 1))

    # Apply in order so a later match overrides an earlier one.
    for shp_key, csv_key, direction in matches:
        jdict.update(k1=shp_key, k2=csv_key, flag=direction)
    return jdict

    
def att_join(files, dir, shp, ts, gdfs):
    """Attach the geometries of *shp* to every result CSV in *files*.

    For each file name, ``join_dict`` supplies the join keys and the
    origin/destination flag. The CSV at ``dir + name`` is merged (pandas
    default, i.e. inner join) with *shp* on ``csv[k2] == shp[k1]``, and the
    result is wrapped in a GeoDataFrame.

    Parameters
    ----------
    files : file names (not paths) inside *dir*. Characters 4-5 of each name
        are used as the time-bucket label (the same slice ``__main__``
        matches against ``ts``).
    dir : directory path, concatenated directly with each file name, so it
        must end with a path separator.
    shp : GeoDataFrame with the geometries to join.
    ts : unused; kept for signature compatibility.
    gdfs : dict receiving the results; it is mutated in place.

    Returns *gdfs*, keyed by ``(time_label, flag)``.
    """
    for name in files:
        keys = join_dict(name, dir, {})
        shp_key, csv_key, direction = keys['k1'], keys['k2'], keys['flag']
        label = name[4:6]

        table = pd.read_csv(dir + name)
        joined = table.merge(shp, left_on=csv_key, right_on=shp_key)
        gdfs[label, direction] = gpd.GeoDataFrame(joined)
    return gdfs

def polygon_centroid(gdf, ci):
    """Return a point layer holding the centroid of every geometry in *gdf*.

    Parameters
    ----------
    gdf : GeoDataFrame that must contain the column ``'r' + ci``.
    ci : column suffix, e.g. ``'_210821'``.

    Returns
    -------
    GeoDataFrame with the same index as *gdf*. It has one attribute column,
    ``'r' + ci`` (copied from *gdf*), plus a ``geometry`` column of points
    in *gdf*'s CRS placed at the centroids.
    """
    rcol = 'r' + ci
    # Evaluate the centroids once. The previous version called gdf.centroid
    # three times and round-tripped the coordinates through temporary columns
    # (misleadingly named lat/lng although they held x/y).
    centres = gdf.centroid
    return gpd.GeoDataFrame(
        {rcol: gdf[rcol]},
        geometry=gpd.points_from_xy(centres.x, centres.y),
        crs=gdf.crs,
    )
    
# def plot_pts(pts,ax,col):
#     return

 
def geo_plot_(files, dir, shp, ts, evt):
    """Alternative styling of ``geo_plot`` (not called from ``run``).

    For every day column of event *evt*, draws one figure of 2 x len(ts)
    panels: row 0 holds origin data (flag 0), row 1 destination data
    (flag 1), with one column per entry of *ts*. Polygon layers are filled
    with the colour of category ``'c' + day`` and overlaid with centroid dots
    coloured and sized by ``'r' + day``. Point layers are drawn as dots
    coloured by category and sized by ``|r|``. Each figure is saved as a JPEG
    in *dir* (named after the last panel's label) and then shown.

    Parameters
    ----------
    files : result CSV names in *dir* for one dataset/event (see att_join).
    dir : directory path ending with a separator; also selects the dataset.
    shp : GeoDataFrame with the geometries to join.
    ts : time-bucket labels, one subplot column each.
    evt : one of 'Aug', 'Sep', 'Jul', 'Oct'.

    Raises
    ------
    ValueError
        If *evt* is not a known event (previously an UnboundLocalError raised
        only after all CSVs had been read).
    """
    title = ''
    h = 10
    if 'citibike' in dir:
        title = 'citibike'
        h = 8
    elif 'of Taxi' in files[0]:
        title = 'Taxi'
    elif 'HVFHV' in files[0]:
        title = 'HVFHV'
    elif "Turnstile" in dir:
        title = 'Metro'

    event_days = {
        'Aug': ['_210821', '_210822', '_210823'],
        'Sep': ['_210901', '_210902'],
        'Jul': ['_210708', '_210709'],
        'Oct': ['_211025', '_211026'],
    }
    if evt not in event_days:
        raise ValueError(f"unknown event {evt!r}; expected one of {sorted(event_days)}")
    c = event_days[evt]

    # The panels are drawn in EPSG:3857 (the data is reprojected below), so
    # take the map extent in that CRS rather than in shp's native one.
    xmin, ymin, xmax, ymax = shp.to_crs('EPSG:3857').geometry.total_bounds
    pad = 5000

    gdfs = att_join(files, dir, shp, ts, {})

    color_dic = {'Regular': '#FFDAB9',
                 'Potential Irregular (+)': '#FA8072',
                 'Potential Irregular (-)': '#B2DFEE',
                 'Irregular (+)': '#8B0000',
                 'Irregular (-)': '#3A5FCD',
                 'Null': '#E5E5E5'
                 }
    pmarks = [Patch(facecolor=v, label=k) for k, v in color_dic.items()]

    for ci in c:
        col = 'c' + ci
        # squeeze=False keeps `axes` 2-D even when len(ts) == 1.
        fig, axes = plt.subplots(nrows=2, ncols=len(ts),
                                 figsize=((xmax - xmin + 10000) / (ymax - ymin + 10000) * 10 * 3, h),
                                 sharex=True, sharey=True, squeeze=False)

        for i, t in enumerate(ts):  # i: subplot column, independent of the label value
            for f in [0, 1]:  # f=0 origin data, f=1 destination data
                ax = axes[f, i]
                gdfi = gdfs[t, f].to_crs('EPSG:3857')

                gdfi[col] = gdfi[col].fillna('Null')
                gdfi['colors'] = gdfi[col].map(color_dic)

                # The first row decides the layer type; MultiPolygon rows
                # (common in zone layers) are polygons too.
                geom = gdfi.geom_type.iloc[0] if len(gdfi) else None
                # Note: when `color=` is given, geopandas ignores `column`/`legend`,
                # so they are not passed; the category legend is drawn per figure.
                if geom in ('Polygon', 'MultiPolygon'):
                    gdfi.plot(ax=ax, color=gdfi['colors'], alpha=0.8)
                    pts = polygon_centroid(gdfi, ci)
                    norm = mpl.colors.CenteredNorm()
                    pts.plot(ax=ax, column='r' + ci, legend=True,
                             markersize=abs(pts['r' + ci]) * 50,
                             cmap='bwr', norm=norm, alpha=0.6)
                elif geom == 'Point':
                    gdfi.plot(ax=ax, color=gdfi['colors'],
                              markersize=abs(gdfi['r' + ci]) * 50, alpha=0.8)

                ax.set_xlim(xmin - pad, xmax + pad)
                ax.set_ylim(ymin - pad, ymax + pad)

                direction = 'Origin' if f == 0 else 'Destination'
                text = title + '_' + direction + ci + '_T' + t

                cx.add_basemap(ax, source=cx.providers.CartoDB.Positron)

                mpt.AddScaleBar(ax=ax, LOC=(0.8, 0.03), FontSize=8, BarWidth=1.2)
                mpt.AddCompass(ax=ax, LOC=(0.1, 0.9), SCA=0.03, FontSize=8)

        fig.legend(handles=pmarks, loc='lower center', ncol=8)
        plt.savefig(dir + text + ".jpg", bbox_inches='tight', dpi=100)
        plt.show()
        plt.close('all')

        
def geo_plot(files, dir, shp, ts, evt):
    """Draw origin/destination anomaly maps for one dataset and event.

    For every day column of event *evt*, draws one figure of 2 x len(ts)
    panels: row 0 holds origin data (flag 0), row 1 destination data
    (flag 1), with one column per entry of *ts*. Polygon layers get a plain
    fill with dots at their centroids; point layers are drawn directly. The
    dots are coloured by category ``'c' + day`` and sized by
    ``|'r' + day|**2 + 10``. Each figure is saved as a 300-dpi JPEG in *dir*
    (named after the last panel's label) and then shown.

    Parameters
    ----------
    files : result CSV names in *dir* for one dataset/event (see att_join).
    dir : directory path ending with a separator; also selects the dataset.
    shp : GeoDataFrame with the geometries to join.
    ts : time-bucket labels, one subplot column each.
    evt : one of 'Aug', 'Sep', 'Jul', 'Oct'.

    Raises
    ------
    ValueError
        If *evt* is not a known event (previously an UnboundLocalError raised
        only after all CSVs had been read).
    """
    title = ''
    if 'citibike' in dir:
        title = 'citibike'
    elif 'of Taxi' in files[0]:
        title = 'Taxi'
    elif 'HVFHV' in files[0]:
        title = 'FHV'
    elif "Turnstile" in dir:
        title = 'Subway'

    event_days = {
        'Aug': ['_210821', '_210822', '_210823'],
        'Sep': ['_210901', '_210902'],
        'Jul': ['_210708', '_210709'],
        'Oct': ['_211025', '_211026'],
    }
    if evt not in event_days:
        raise ValueError(f"unknown event {evt!r}; expected one of {sorted(event_days)}")
    c = event_days[evt]

    # The panels are drawn in EPSG:3857 (the data is reprojected below), so
    # take the map extent in that CRS rather than in shp's native one.
    xmin, ymin, xmax, ymax = shp.to_crs('EPSG:3857').geometry.total_bounds
    pad = 6000

    gdfs = att_join(files, dir, shp, ts, {})

    color_dic = {'Regular': '#FFEFD5',
                 'Potential Irregular (+)': '#F08080',
                 'Potential Irregular (-)': '#87CEFA',
                 'Irregular (+)': '#FF0000',
                 'Irregular (-)': '#3A5FCD',
                 'Null': '#E5E5E5'
                 }
    pmarks = [Patch(facecolor=v, label=k) for k, v in color_dic.items()]

    for ci in c:
        col = 'c' + ci
        # squeeze=False keeps `axes` 2-D even when len(ts) == 1.
        fig, axes = plt.subplots(nrows=2, ncols=len(ts),
                                 figsize=((xmax - xmin + 10000) / (ymax - ymin + 10000) * 10 * 3, 10),
                                 sharex=True, sharey=True, squeeze=False)

        for i, t in enumerate(ts):  # i: subplot column, independent of the label value
            for f in [0, 1]:  # f=0 origin data, f=1 destination data
                ax = axes[f, i]
                gdfi = gdfs[t, f].to_crs('EPSG:3857')

                gdfi[col] = gdfi[col].fillna('Null')
                gdfi['colors'] = gdfi[col].map(color_dic)

                # The first row decides the layer type; MultiPolygon rows
                # (common in zone layers) are polygons too.
                geom = gdfi.geom_type.iloc[0] if len(gdfi) else None
                # Note: when `color=` is given, geopandas ignores `column`/`legend`,
                # so they are not passed; the category legend is drawn per figure.
                if geom in ('Polygon', 'MultiPolygon'):
                    gdfi.plot(ax=ax, color='#FFEFD5', legend=False, alpha=0.7)
                    pts = polygon_centroid(gdfi, ci)
                    pts.plot(ax=ax, color=gdfi['colors'],
                             markersize=abs(pts['r' + ci]) ** 2 + 10, alpha=0.7)
                elif geom == 'Point':
                    gdfi.plot(ax=ax, color=gdfi['colors'],
                              markersize=abs(gdfi['r' + ci]) ** 2 + 10, alpha=0.7)

                ax.set_xlim([xmin - pad, xmax + pad])
                ax.set_ylim([ymin - pad, ymax + pad])

                direction = 'Origin' if f == 0 else 'Destination'
                text = title + '_' + direction + ci + '_T' + t
                if 'citibike' in text:
                    ax.set_title(text)
                else:
                    ax.set_title(text, fontsize=15)

                cx.add_basemap(ax, source=cx.providers.CartoDB.Positron)
                if 'citibike' in dir:
                    mpt.AddScaleBar(ax=ax, LOC=(0.55, 0.1), FontSize=10, BarWidth=1.2)
                else:
                    mpt.AddScaleBar(ax=ax, LOC=(0.78, 0.05), FontSize=10, BarWidth=1.2)

                mpt.AddCompass(ax=ax, LOC=(0.1, 0.9), SCA=0.03, FontSize=10)

        fig.legend(handles=pmarks, loc='lower center', ncol=8, fontsize=16)
        plt.savefig(dir + text + ".jpg", bbox_inches='tight', dpi=300)
        plt.show()
        plt.close('all')
    
def run(dir, files, G_dir, events, ts):
    """Draw the maps for every event in *events* from the CSVs of one directory.

    Parameters
    ----------
    dir : result directory (ending with a separator). The substrings 'TLC',
        'Turnstile' and 'citibike' select the dataset and its shapefile.
    files : CSV file names inside *dir*.
    G_dir : directory holding the shapefile sub-folders.
    events : event tags ('Aug', 'Sep', ...). A file belongs to an event when
        the tag occurs in its name.
    ts : time-bucket labels forwarded to geo_plot.

    Subsets with no files for an event are skipped, because geo_plot cannot
    handle an empty file list.
    """
    shapes = {}

    def load_shapes(*parts):
        # The shapefile does not depend on the event, so read and reproject
        # it at most once per call rather than once per event. os.path.join
        # yields the same path as the old backslash concatenation on Windows
        # and a valid one elsewhere.
        path = os.path.join(G_dir, *parts)
        if path not in shapes:
            shapes[path] = gpd.read_file(path).to_crs(3857)
        return shapes[path]

    for evt in events:
        efiles = [i for i in files if evt in i]

        if 'TLC' in dir:
            hvfhv = [i for i in efiles if 'HVFHV' in i]
            taxi = [i for i in efiles if 'of Taxi' in i]
            for fis in (taxi, hvfhv):
                if fis:
                    geo_plot(fis, dir, load_shapes('NYC Taxi Zones', 'NYC_taxi_zones.shp'), ts, evt)

        if 'Turnstile' in dir:
            stations = [i for i in efiles if "Station's" in i]
            if stations:
                geo_plot(stations, dir, load_shapes('Subway', 'Subway_and_Path_stations.shp'), ts, evt)

        if 'citibike' in dir:
            if efiles:
                geo_plot(efiles, dir, load_shapes('Bicycle', 'citibike_stations.shp'), ts, evt)

    

if __name__ == '__main__':

    main_dir = r""
    # Result directories, appended to main_dir by plain concatenation.
    dirs = [r'\TLC\result\without holiday/',
            r'\Turnstile\result\without holiday/',
            r'\citibike\result\without holiday/']
    G_dir = r'\Geo_shps/'

    ts = ['00', '04', '08', '12', '16', '20']
    events = ['Aug', 'Sep', 'Jul', 'Oct']
    for sub_dir in dirs:
        # Use a dedicated name: rebinding `dir` here would shadow the builtin
        # dir() at module level after the script has run.
        result_dir = main_dir + sub_dir
        # Sorted so files are processed in a deterministic order (os.listdir
        # order is arbitrary); keep only CSVs whose chars 4-5 match a label in ts.
        files = [name for name in sorted(os.listdir(result_dir))
                 if os.path.splitext(name)[1] == '.csv' and name[4:6] in ts]
        run(result_dir, files, G_dir, events, ts)
            
