Data source (MTA Subway Hourly Ridership, data.ny.gov): https://data.ny.gov/Transportation/MTA-Subway-Hourly-Ridership-Beginning-February-202/wujg-7c2s/about_data

In [None]:
import polars as pl
import re

In [None]:
def add_dec_prec(row):
    """Return the combined count of decimal digits in a row's coordinates.

    Args:
        row: Mapping with ``"latitude"`` and ``"longitude"`` entries. The
            values are normally strings such as ``"40.7128"``; ``None`` is
            treated as having no decimal digits, and any other non-string
            value is measured through its ``str()`` form.

    Returns:
        The number of characters after the last ``"."`` in the latitude plus
        the same count for the longitude. A value without a ``"."``
        contributes 0.
    """

    def _decimal_digits(value):
        # A missing coordinate carries no precision; do not crash on it.
        if value is None:
            return 0
        _, dot, fraction = str(value).rpartition(".")
        # Without a "." rpartition puts the whole text in `fraction`, so the
        # separator has to be checked explicitly.
        return len(fraction) if dot else 0

    return _decimal_digits(row["latitude"]) + _decimal_digits(row["longitude"])

In [None]:
def clean_column_name(name):
    """Normalise a column label into a lowercase, underscore-separated name.

    Spaces, hyphens and ampersands are turned into underscores, and any run
    of consecutive separators (including underscores already present) is
    collapsed into a single underscore, whatever its length.

    Args:
        name: The original column label.

    Returns:
        The cleaned, lowercased column name.
    """
    # One pass handles runs of any length; chained str.replace calls leave
    # double underscores behind for runs of five or seven-plus separators.
    return re.sub(r"[ \-&_]+", "_", name).lower()

In [None]:
# Ridership table: one row per (timestamp, station complex) with one column
# per fare class plus Metrocard / OMNY / overall totals.
columns_to_keep = ["transit_timestamp", "station_complex_id", "fare_class_category", "ridership"]
original_ridership = pl.read_parquet("data/hist.parquet", columns=columns_to_keep, low_memory=True)

# Keys that identify one row of the wide table.
pivot_keys = ["transit_timestamp", "station_complex_id"]

# NOTE(review): Int16 tops out at 32767 -- confirm that no hourly per-class
# count in the source exceeds that before relying on this cast.
ridership_wide = (
    original_ridership.with_columns(
        [
            pl.col("transit_timestamp").str.strptime(pl.Datetime, format="%m/%d/%Y %I:%M:%S %p"),
            pl.col("ridership").cast(pl.Int16),
        ]
    )
    .pivot(
        index=pivot_keys,
        columns="fare_class_category",
        values="ridership",
        aggregate_function="sum",
        sort_columns=True,
    )
    .sort(pivot_keys, descending=[False, False])
    # A fare class that never occurred for a given hour/station counts as zero.
    .fill_null(0)
)

# Fare-class columns are recognised by the payment method in their name.
metrocard_columns = [col for col in ridership_wide.columns if "Metrocard" in col]
omny_columns = [col for col in ridership_wide.columns if "OMNY" in col]
ridership_columns = [col for col in ridership_wide.columns if "Metrocard" in col or "OMNY" in col]

ridership = ridership_wide.with_columns(
    total_metrocard_ridership=pl.sum_horizontal(metrocard_columns),
    total_omny_ridership=pl.sum_horizontal(omny_columns),
    total_ridership=pl.sum_horizontal(ridership_columns),
)

rename_mapping = {col: clean_column_name(col) for col in ridership.columns}
ridership = ridership.rename(rename_mapping)

# Since we got rid of the shuttle and TRAM lines, we filter them out here too.
# The complex id "141" is compared exactly: a substring match would also drop
# any other id that merely contains "141".
# NOTE(review): "141" is presumably the shuttle-only complex -- confirm
# against the stations table.
ridership = ridership.filter(
    ~pl.col("station_complex_id").str.contains("TRAM")
    & (pl.col("station_complex_id") != "141")
)
ridership

In [None]:
# The remaining schema tables (stations, routes, station_routes) are built
# from a prefix of the file rather than from the full history.
# NOTE(review): only the first 30,000,000 rows are read -- a station that
# appears only later in the file would be missing from those tables.
df = pl.read_parquet(
    "data/hist.parquet",
    low_memory=True,
    n_rows=30_000_000,
)

In [None]:
# Stations table: one row per station complex, keeping the coordinates that
# were recorded with the most decimal digits.
station_columns = ["station_complex_id", "station_complex", "borough", "latitude", "longitude"]
stations = df.select(station_columns).unique()

# All expressions in one with_columns call read the input frame, so the
# precision is computed from the original latitude/longitude values (which
# add_dec_prec expects as text) while the same call casts them to floats.
df_with_precision = stations.with_columns(
    [
        pl.struct(["latitude", "longitude"])
        .map_batches(lambda batch: batch.map_elements(add_dec_prec, return_dtype=pl.Int64))
        .alias("total_precision"),
        pl.col("latitude").cast(pl.Float64),
        pl.col("longitude").cast(pl.Float64),
    ]
)

# Put the most precise row first within each station, then keep exactly that
# row. keep="first" with maintain_order=True is required: the default
# keep="any" gives no guarantee about which duplicate survives, which would
# make the preceding sort pointless.
df_with_precision = df_with_precision.sort(
    ["station_complex_id", "total_precision"], descending=[False, True]
).unique(subset=["station_complex_id"], keep="first", maintain_order=True)

stations = df_with_precision.select(station_columns).sort("station_complex_id")

# Tidy the complex names and drop what the schema leaves out:
#   * remove the shuttle marker ",S" from route lists,
#   * rewrite "(110 St)" so it is not later mistaken for a route list,
#   * remove the "/Botanic Garden (S)" suffix,
#   * drop TRAM complexes and complexes whose name still contains "(S)".
# NOTE(review): ",S" only matches an S that follows a comma; a route list
# that starts with S (e.g. "(S,4,...)") would keep it -- confirm against data.
stations = (
    stations.with_columns(
        pl.col("station_complex")
        .str.replace_all(r"\,S", "")
        .str.replace_all(r"\(110 St\)", "- 110 St")
        .str.replace_all(r"\/Botanic Garden \(S\)", "")
        .str.strip_chars()
    )
    .filter(~pl.col("station_complex_id").str.contains("TRAM"))
    .filter(~pl.col("station_complex").str.contains(r"\(S\)"))
    .sort("station_complex_id")
)
stations

In [None]:
# Dump every station complex name so the naming patterns can be inspected.
print(stations.get_column("station_complex").to_list())

In [None]:
# Station names with every parenthesised segment removed and the surrounding
# whitespace trimmed.
name_without_parentheses = (
    pl.col("station_complex")
    .str.replace_all(r"\([^)]*\)", "")
    .str.strip_chars()
)
stations_clean = stations.with_columns(name_without_parentheses)
stations_clean

In [None]:
# Routes table: the distinct route names found inside the parenthesised
# route lists of the station complex names.
station_list = stations["station_complex"].to_list()
regex_pattern = r"\(([^)]+)\)"

unique_train_lines = set()

for station in station_list:
    # A complex name can carry several parenthesised lists (one per joined
    # station), so every group is read, not just the first one. This keeps
    # the table consistent with station_routes, which extracts all groups.
    for group in re.findall(regex_pattern, station):
        for line in group.split(","):
            unique_train_lines.add(line.strip())


routes = pl.DataFrame({
    "route_name": sorted(unique_train_lines)
})
routes

In [None]:
# Station_routes table: one row per (station complex, route) pair.

# Step 1: Collect every parenthesised group of the complex name, merge the
# groups into one comma-separated string, drop the parentheses and split the
# result into a list of route names. Native list/str expressions are used
# instead of a per-row Python callback, and each name is stripped so it
# matches the (stripped) names in the routes table.
# NOTE(review): a complex name without any parentheses yields a single empty
# route name -- confirm whether such rows should be kept.
station_routes = stations.with_columns(
    pl.col("station_complex")
    .str.extract_all(r"\((.*?)\)")
    .list.join(",")
    .str.replace_all(r"\(", "")
    .str.replace_all(r"\)", "")
    .str.split(",")
    .list.eval(pl.element().str.strip_chars())
    .alias("route_list")
)

# Step 2: Explode the list into separate rows, one per route.
stations_exploded = station_routes.explode("route_list")

# Step 3: Select and rename columns to fit the SQL schema, remove duplicates.
# NOTE(review): despite its name, "station_complex_unclean" holds the complex
# name with the parenthesised parts removed; the column name is kept as-is
# for compatibility.
station_routes = stations_exploded.select([
    pl.col("station_complex_id"),
    pl.col("station_complex")
    .str.replace_all(r"\([^)]*\)", "").str.strip_chars().alias("station_complex_unclean"),
    pl.col("route_list").alias("route_name")
]).unique()

station_routes = station_routes.sort("station_complex_id")
station_routes