In [12]:
import pandas as pd
import os
import numpy as np

from tableone import TableOne
import plotly.express as px
import plotly.graph_objects as go


from data.constants import DATA_FOLDER

In [13]:
# Locations of the input parquet panels stored under DATA_FOLDER/final.
point_panel_in, line_panel_in, tract_panel_in, comm_panel_in = (
    os.path.join(DATA_FOLDER, "final", panel_file)
    for panel_file in (
        "point_panel.parquet",
        "line_panel.parquet",
        "tract_panel.parquet",
        "comm_panel.parquet",
    )
)

In [14]:
# Read each parquet panel into its own DataFrame (same order as the paths above).
point_panel, line_panel, tract_panel, comm_panel = map(
    pd.read_parquet,
    (point_panel_in, line_panel_in, tract_panel_in, comm_panel_in),
)

In [15]:
# Keep only rows dated strictly before 2024-08-19 in the panels used below.
# NOTE(review): point_panel is not filtered here -- confirm that is intentional.
comm_panel = comm_panel.loc[comm_panel["date"] < "2024-08-19"]
tract_panel = tract_panel.loc[tract_panel["date"] < "2024-08-19"]
line_panel = line_panel.loc[line_panel["date"] < "2024-08-19"]

# Baseline descriptive stats:

## Space-like

In [16]:
def spacelike_stats(df):
    """Summarise the time-invariant attributes of a panel, per transit mode.

    The per-observation columns (`date`, `DNC`, `is_weekend`, `dotw`, `rides`)
    are dropped and the remaining rows de-duplicated before aggregating.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel containing at least the columns `transit`, `id`, `UCMP`, `lat`,
        `long` plus the five dropped columns listed above.

    Returns
    -------
    pandas.DataFrame
        One column per `transit` value and eight labelled rows: number of
        distinct ids, the sum of `UCMP` (labelled 'serves DNC'), and the
        min / max / mean of `lat` and `long`. Values are rounded to 2 decimals.
    """
    per_observation_cols = ['date', 'DNC', 'is_weekend', 'dotw', 'rides']
    units = df.drop(columns=per_observation_cols).drop_duplicates()

    aggregations = {
        "id": "nunique",
        "UCMP": "sum",
        'lat': ["min", "max", "mean"],
        'long': ["min", "max", "mean"],
    }
    # Transpose so that statistics become rows and transit modes become columns.
    summary = units.groupby(['transit']).agg(aggregations).T

    # Replace the (column, statistic) MultiIndex with flat display labels.
    summary.index = [
        'n units.', 'serves DNC',
        'lat: min', 'lat: max', 'lat: mean',
        'lon: min', 'lon: max', 'lon: mean',
    ]
    return summary.round(2)

In [17]:
# Render the time-invariant summary of the tract panel as a Plotly table.
plot_data = spacelike_stats(tract_panel)

fig = go.Figure(data=[
    go.Table(
        header={
            "values": ['transit', *plot_data.columns],
            "align": 'right',
        },
        cells={
            # First column holds the row labels, then one column per transit mode.
            "values": [
                plot_data.index,
                *(plot_data[mode] for mode in ('bike', 'train', 'uber')),
            ],
            "align": 'right',
        },
    )
])
fig.update_layout(width=600)
# Export / display intentionally left disabled for this intermediate table:
# fig.write_json("../../../reports/replication/baseline-table.json")
# fig.show()

In [18]:
def timelike_stats(df):
    """Summarise weekday ridership of a panel, per transit mode.

    Rows flagged `is_weekend` are excluded; the `DNC`, `is_weekend` and `dotw`
    columns are then dropped and the remaining rows de-duplicated.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel containing at least `transit`, `date`, `rides`, `DNC`,
        `is_weekend` (boolean) and `dotw`.

    Returns
    -------
    pandas.DataFrame
        One column per `transit` value and four labelled rows: the number of
        distinct dates and the min / max / mean of `rides`, rounded to
        2 decimals.
    """
    weekdays = df.loc[~df['is_weekend']]
    observations = weekdays.drop(columns=['DNC', 'is_weekend', 'dotw']).drop_duplicates()

    aggregations = {
        'date': 'nunique',
        'rides': ['min', 'max', 'mean'],
    }
    # Transpose so that statistics become rows and transit modes become columns.
    summary = observations.groupby(['transit']).agg(aggregations).T
    summary.index = ["n. weekdays", "rides: min", "rides: max", "rides: mean"]
    return summary.round(2)

In [19]:
plot_data = pd.concat([spacelike_stats(tract_panel), timelike_stats(tract_panel)])
plot_data
fig = go.Figure(data=[go.Table(
    header=dict(values=['transit'] + list(plot_data.columns),
                align='right'),
    cells=dict(values=[plot_data.index, plot_data.bike, plot_data.train, plot_data.uber],
               align='right'))
])
fig.update_layout(width=700)
fig.write_json("../../../reports/replication/baseline-table.json")
!cp ../../../reports/replication/baseline-table.json ../../../../eric-mc2-cv/static/json
fig.show()

# Balance

In [20]:
def fix_count_rows(tone, count_row):
    """Strip the ', mean (SD)' suffix from count rows of a TableOne table.

    Rows whose first-level index label contains `count_row` have the literal
    text ', mean (SD)' removed from that label; all other labels are left
    untouched. The table's index is replaced in place.

    Parameters
    ----------
    tone : object
        Object exposing a `tableone` DataFrame with a two-level row index
        (e.g. a `TableOne` instance).
    count_row : str
        Pattern identifying the count row(s). It is passed to
        `str.contains` unchanged, so it follows that method's default
        pattern semantics.

    Returns
    -------
    object
        The same `tone` object, with `tone.tableone.index` rebuilt.
    """
    row_labels = tone.tableone.index.get_level_values(0)
    sub_labels = tone.tableone.index.get_level_values(1)

    is_count = row_labels.str.contains(count_row)
    # regex=False: the suffix contains parentheses, which a regex would read as
    # a capture group and therefore never match. The default for `regex`
    # differs between pandas versions, so be explicit.
    stripped = row_labels.str.replace(', mean (SD)', '', regex=False)
    new_labels = np.where(is_count, stripped, row_labels)

    tone.tableone.index = pd.MultiIndex.from_tuples(
        list(zip(new_labels, sub_labels.values))
    )
    return tone

def balance_table(df, unit_name, unit_abbr):
    """Build a combined balance table comparing units grouped by `UCMP`.

    Two TableOne summaries are computed and stacked:

    * a unit-level table (one row per distinct `transit`/`id` pair) covering
      `transit`, `lat` and `long`;
    * a unit-day table (rows de-duplicated on `transit`, `id`, `UCMP`, `date`,
      `rides`) covering `rides`.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel containing at least `transit`, `id`, `UCMP`, `lat`, `long`,
        `date` and `rides`.
    unit_name : str
        Label for the unit-count row (e.g. 'tracts').
    unit_abbr : str
        Short label used to build the '<abbr>-days' count row (e.g. 'tract').

    Returns
    -------
    TableOne
        The unit-level TableOne object whose `tableone` frame has been replaced
        by the combined, re-ordered table.
    """
    space_data = df.drop_duplicates(['transit', 'id']).copy()
    space_table = TableOne(space_data,
                           columns=['transit', 'lat', 'long'],
                           groupby='UCMP',
                           rename={'UCMP': 'Near DNC', 'n': unit_name},
                           pval=True,
                           missing=False,
                           overall=False)
    space_table = fix_count_rows(space_table, unit_name)

    time_data = df.drop_duplicates(['transit', 'id', 'UCMP', 'date', 'rides'])
    time_table = TableOne(time_data,
                          columns=['rides'],
                          groupby='UCMP',
                          rename={'UCMP': 'Near DNC',
                                  'n': f'{unit_abbr}-days',
                                  'rides': 'daily rides'},
                          pval=True,
                          missing=False,
                          overall=False)
    time_table = fix_count_rows(time_table, unit_abbr)

    # Desired top-to-bottom order of the rows, identified by substring.
    key_order = [unit_name, unit_abbr, 'rides', 'transit', 'lat', 'lon']

    def row_rank(labels):
        """Rank each label by the FIRST key in `key_order` that it contains.

        Labels matching no key get rank len(key_order) and therefore sort last.
        Using the first match (rather than summing the positions of all
        matches) keeps e.g. 'tracts' (rank 0) strictly ahead of 'tract-days'
        (rank 1) even though 'tracts' also contains 'tract'.
        """
        ranks = np.full(len(labels), len(key_order), dtype=int)
        # Walk the keys from last to first so earlier keys overwrite later ones.
        for position in range(len(key_order) - 1, -1, -1):
            hit = labels.str.contains(key_order[position], regex=False, na=False)
            ranks[np.asarray(hit, dtype=bool)] = position
        return ranks

    bal_table = pd.concat([space_table.tableone, time_table.tableone])
    # The key is applied to each index level separately; second-level labels
    # match no key, so they all tie.
    bal_table = bal_table.sort_index(key=row_rank)

    # Hack to concat TableOne tables: reuse the unit-level object as the carrier.
    space_table.tableone = bal_table
    return space_table

In [21]:
# Build one balance table per panel, then print them in plain-text form.
tract_balance = balance_table(tract_panel, 'tracts', 'tract')
comm_balance = balance_table(comm_panel, 'community areas', 'CA')
line_balance = balance_table(line_panel, 'routes', 'route')

print(
    *(balance.tabulate(tablefmt='simple')
      for balance in (tract_balance, comm_balance, line_balance)),
    sep="\n",
)

                               0               1               P-Value
----------------------  -----  --------------  --------------  ---------
tracts                         1898            91
tract-days                     113663          6950
daily rides, mean (SD)         358.3 (1548.9)  756.3 (1699.2)  <0.001
transit, n (%)          bike   600 (31.6)      37 (40.7)       <0.001
                        train  96 (5.1)        14 (15.4)
                        uber   1202 (63.3)     40 (44.0)
lat, mean (SD)                 -0.1 (1.2)      -0.1 (0.2)      0.208
long, mean (SD)                -0.3 (1.4)      0.5 (0.3)       <0.001
                               0                1                P-Value
----------------------  -----  ---------------  ---------------  ---------
community areas                173              23
CA-days                        12683            1749
daily rides, mean (SD)         3779.0 (9628.5)  6389.9 (9083.1)  <0.001
transit, n (%)          bike   69 (39

In [22]:
with open("../../../reports/replication/tract-balance.md","w") as f:
    f.write(tract_balance.tabulate(headers=['Not Near DNC', 'Near DNC', 'P-Value'],tablefmt="github"))
with open("../../../reports/replication/comm-balance.md","w") as f:
    f.write(comm_balance.tabulate(headers=['Not Near DNC', 'Near DNC', 'P-Value'],tablefmt="github"))
with open("../../../reports/replication/line-balance.md","w") as f:
    f.write(line_balance.tabulate(headers=['Not Near DNC', 'Near DNC', 'P-Value'],tablefmt="github"))
!cp ../../../reports/replication/tract-balance.md ../../../../eric-mc2-cv/static/uploads/
!cp ../../../reports/replication/comm-balance.md ../../../../eric-mc2-cv/static/uploads/
!cp ../../../reports/replication/line-balance.md ../../../../eric-mc2-cv/static/uploads/