# NYC taxi data for demand forecasting

### 1. Data reading 

In [1]:
import os
import subprocess

# Directory that holds every downloaded input file for this notebook.
DATA_DIR = "data"
# exist_ok avoids the separate existence check (and the race that comes with it).
os.makedirs(DATA_DIR, exist_ok=True)


# Fetch the twelve monthly yellow-taxi parquet files for 2024, skipping months
# that are already present on disk.
for month in range(1, 13):
    target_path = f"{DATA_DIR}/{month}.parquet"
    if os.path.exists(target_path):
        continue

    # `wget -O` creates the output file even when the transfer fails, so download
    # into a temporary name and only move it into place after wget succeeded.
    # Otherwise a truncated file would satisfy the existence check on every later run.
    partial_path = f"{target_path}.part"
    try:
        subprocess.run(
            [
                "wget",
                f"https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2024-{month:02}.parquet",
                "-O",
                partial_path,
            ],
            check=True,  # raise CalledProcessError on a non-zero wget exit status
        )
        os.replace(partial_path, target_path)
    finally:
        # After a successful os.replace the partial file no longer exists;
        # this only removes leftovers from a failed or interrupted download.
        if os.path.exists(partial_path):
            os.remove(partial_path)


# Report the combined on-disk size of the twelve monthly files.
total_size = sum(
    os.path.getsize(f"{DATA_DIR}/{month}.parquet") for month in range(1, 13)
)  # bytes
total_size_mb = total_size // (1024 * 1024)
print(f"Total dataset size: {total_size_mb} MB")

Total dataset size: 660 MB


In [2]:
# Fetch the taxi zone lookup CSV once; later runs reuse the local copy.
if not os.path.exists(f"{DATA_DIR}/taxi_zone_lookup.csv"):
    # `wget -O` creates the output file even when the transfer fails, so download
    # into a temporary name and only move it into place after wget succeeded.
    zone_lookup_partial = f"{DATA_DIR}/taxi_zone_lookup.csv.part"
    try:
        subprocess.run(
            [
                "wget",
                "https://d37ci6vzurychx.cloudfront.net/misc/taxi_zone_lookup.csv",
                "-O",
                zone_lookup_partial,
            ],
            check=True,  # raise CalledProcessError on a non-zero wget exit status
        )
        os.replace(zone_lookup_partial, f"{DATA_DIR}/taxi_zone_lookup.csv")
    finally:
        # Only present here if the download failed or was interrupted.
        if os.path.exists(zone_lookup_partial):
            os.remove(zone_lookup_partial)

In [4]:
import polars as pl

In [7]:
# Build one lazy frame per month, normalise it, and stack the twelve results.
# Nothing is read from disk here: scan_parquet only records the query plan.
dfs = []
for month in range(1, 13):
    df = (
        pl.scan_parquet(f"{DATA_DIR}/{month}.parquet")
        # Bring both timestamps to millisecond resolution.
        .with_columns(
            pl.col("tpep_pickup_datetime", "tpep_dropoff_datetime").dt.cast_time_unit(
                "ms"
            )
        )
        # Keep trips picked up in 2024 or later and dropped off no later than
        # 2025-01-01 (that whole day included).
        # NOTE(review): the pickup bound has no upper limit, so a trip that both
        # starts and ends on 2025-01-01 passes this filter - confirm that is intended.
        .filter(
            pl.col("tpep_pickup_datetime").dt.year() >= 2024,
            pl.col("tpep_dropoff_datetime").dt.date() <= pl.date(2025, 1, 1),
        )
        # Store the small integer-coded columns in compact unsigned dtypes.
        .with_columns(
            pl.col("passenger_count").cast(pl.UInt8),
            pl.col("PULocationID").cast(pl.UInt16),
            pl.col("DOLocationID").cast(pl.UInt16),
            pl.col("RatecodeID").cast(pl.UInt8),
            pl.col("VendorID").cast(pl.UInt8),
            pl.col("payment_type").cast(pl.UInt8),
        )
    )
    dfs.append(df)

df_2024 = pl.concat(dfs)

In [8]:
# Execute the lazy query and display the resulting frame.
# NOTE(review): this materialises every row that survives the filters just to render
# a preview - consider collecting only a `.head()` if memory becomes a concern.
df_2024.collect()

VendorID,tpep_pickup_datetime,tpep_dropoff_datetime,passenger_count,trip_distance,RatecodeID,store_and_fwd_flag,PULocationID,DOLocationID,payment_type,fare_amount,extra,mta_tax,tip_amount,tolls_amount,improvement_surcharge,total_amount,congestion_surcharge,Airport_fee
u8,datetime[ms],datetime[ms],u8,f64,u8,str,u16,u16,u8,f64,f64,f64,f64,f64,f64,f64,f64,f64
2,2024-01-01 00:57:55,2024-01-01 01:17:43,1,1.72,1,"""N""",186,79,2,17.7,1.0,0.5,0.0,0.0,1.0,22.7,2.5,0.0
1,2024-01-01 00:03:00,2024-01-01 00:09:36,1,1.8,1,"""N""",140,236,1,10.0,3.5,0.5,3.75,0.0,1.0,18.75,2.5,0.0
1,2024-01-01 00:17:06,2024-01-01 00:35:01,1,4.7,1,"""N""",236,79,1,23.3,3.5,0.5,3.0,0.0,1.0,31.3,2.5,0.0
1,2024-01-01 00:36:38,2024-01-01 00:44:56,1,1.4,1,"""N""",79,211,1,10.0,3.5,0.5,2.0,0.0,1.0,17.0,2.5,0.0
1,2024-01-01 00:46:51,2024-01-01 00:52:57,1,0.8,1,"""N""",211,148,1,7.9,3.5,0.5,3.2,0.0,1.0,16.1,2.5,0.0
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
2,2024-12-31 23:32:00,2024-12-31 23:56:00,,10.71,,,16,7,0,-4.13,0.0,0.5,0.0,0.0,1.0,-2.63,,
2,2024-12-31 23:05:00,2024-12-31 23:18:00,,4.56,,,252,16,0,-1.68,0.0,0.5,0.0,0.0,1.0,-0.18,,
2,2024-12-31 23:03:16,2024-12-31 23:28:35,,3.94,,,181,255,0,4.46,0.0,0.5,5.19,0.0,1.0,11.15,,
1,2024-12-31 23:15:33,2024-12-31 23:36:29,,4.2,,,165,61,0,27.07,0.0,0.5,0.0,0.0,1.0,28.57,,


In [9]:
# Lazily scan the zone lookup table. The path is built from DATA_DIR so it always
# points at the file written by the download cell, even if DATA_DIR is changed.
df_taxi_zones = pl.scan_csv(f"{DATA_DIR}/taxi_zone_lookup.csv")
df_taxi_zones.collect()

LocationID,Borough,Zone,service_zone
i64,str,str,str
1,"""EWR""","""Newark Airport""","""EWR"""
2,"""Queens""","""Jamaica Bay""","""Boro Zone"""
3,"""Bronx""","""Allerton/Pelham Gardens""","""Boro Zone"""
4,"""Manhattan""","""Alphabet City""","""Yellow Zone"""
5,"""Staten Island""","""Arden Heights""","""Boro Zone"""
…,…,…,…
261,"""Manhattan""","""World Trade Center""","""Yellow Zone"""
262,"""Manhattan""","""Yorkville East""","""Yellow Zone"""
263,"""Manhattan""","""Yorkville West""","""Yellow Zone"""
264,"""Unknown""","""N/A""","""N/A"""


### 2. Data cleaning and filtering

In [12]:
# Reusable expression: drop-off timestamp minus pick-up timestamp, named "trip_time".
trip_time = (
    pl.col("tpep_dropoff_datetime")
    .sub(pl.col("tpep_pickup_datetime"))
    .alias("trip_time")
)

# Amount columns that the cleaning step processes as a group.
money_cols = [
    "fare_amount", "extra", "mta_tax",
    "tip_amount", "tolls_amount", "improvement_surcharge",
    "total_amount", "congestion_surcharge", "Airport_fee",
]

In [15]:
# Cleaning pipeline (still lazy). Steps run top to bottom:
#   1. passenger_count: nulls become 1, zero-passenger rows are dropped, values are capped at 6
#   2. rows whose trip_time exceeds 120 whole minutes are dropped
#   3. money columns: take absolute values, then keep rows where every one is below 1000
#   4. keep rows with a RatecodeID in 1..6 and a VendorID in {1, 2, 6, 7}
# NOTE(review): step 2 has no lower bound, so negative durations (drop-off before
# pick-up) pass - confirm that is intended.
# NOTE(review): rows where any money column is null are removed by step 3, because a
# null comparison does not satisfy the filter.
df_2024 = (
    df_2024
    .with_columns(pl.col("passenger_count").fill_null(1))
    .filter(pl.col("passenger_count") > 0)
    .with_columns(pl.col("passenger_count").clip(upper_bound=6))
    .filter(trip_time.dt.total_minutes() <= 120)
    .with_columns(pl.col(money_cols).abs())
    .filter(*(pl.col(col) < 1000 for col in money_cols))
    .filter(
        # The is_not_null / != 99 checks are already implied by is_in; they are kept
        # to mirror the stated cleaning rules.
        pl.col("RatecodeID").is_not_null()
        & pl.col("RatecodeID").is_in([1, 2, 3, 4, 5, 6])
        & (pl.col("RatecodeID") != 99)
        & pl.col("VendorID").is_not_null()
        & pl.col("VendorID").is_in([1, 2, 6, 7])
    )
)

### 3. Data transformation 