# NYC taxi data for demand forecasting

### 1. Data reading 

In [1]:
import os
import subprocess

DATA_DIR = "data"
if not os.path.exists(DATA_DIR):
    os.mkdir(DATA_DIR)


for month in range(1, 13):
    if not os.path.exists(f"{DATA_DIR}/{month}.parquet"):
        subprocess.run(
            [
                "wget",
                f"https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2024-{month:02}.parquet",
                "-O",
                f"{DATA_DIR}/{month}.parquet",
            ]
        )


total_size = sum(
    os.path.getsize(f"{DATA_DIR}/{month}.parquet") for month in range(1, 13)
)  # bytes
total_size_mb = total_size // (1024 * 1024)
print(f"Total dataset size: {total_size_mb} MB")

Total dataset size: 660 MB


In [2]:
if not os.path.exists(f"{DATA_DIR}/taxi_zone_lookup.csv"):
    subprocess.run(
        [
            "wget",
            "https://d37ci6vzurychx.cloudfront.net/misc/taxi_zone_lookup.csv",
            "-O",
            f"{DATA_DIR}/taxi_zone_lookup.csv",
        ]
    )

In [3]:
import polars as pl

In [4]:
dfs = []
for month in range(1, 13):
    df = pl.scan_parquet(f"{DATA_DIR}/{month}.parquet")
    df = df.with_columns(
        pl.col("tpep_pickup_datetime").dt.cast_time_unit("ms"),
        pl.col("tpep_dropoff_datetime").dt.cast_time_unit("ms"),
    )

    df = df.filter(
        (pl.col("tpep_pickup_datetime").dt.year() >= 2024)
        & (pl.col("tpep_dropoff_datetime").dt.date() <= pl.date(2025, 1, 1))
    )

    df = df.with_columns(pl.col("passenger_count").cast(pl.UInt8))
    df = df.with_columns(pl.col("PULocationID").cast(pl.UInt16))
    df = df.with_columns(pl.col("DOLocationID").cast(pl.UInt16))
    df = df.with_columns(pl.col("RatecodeID").cast(pl.UInt8))
    df = df.with_columns(pl.col("VendorID").cast(pl.UInt8))
    df = df.with_columns(pl.col("payment_type").cast(pl.UInt8))

    dfs.append(df)

df_2024 = pl.concat(dfs)

In [5]:
# df_2024.collect()

In [6]:
df_taxi_zones = pl.scan_csv("data/taxi_zone_lookup.csv")

### 2. Data cleaning and filtering

In [7]:
trip_time = (pl.col("tpep_dropoff_datetime") - pl.col("tpep_pickup_datetime")).alias(
    "trip_time"
)

money_cols = [
    "fare_amount",
    "extra",
    "mta_tax",
    "tip_amount",
    "tolls_amount",
    "improvement_surcharge",
    "total_amount",
    "congestion_surcharge",
    "Airport_fee",
]

In [8]:
df_2024 = (
    df_2024.with_columns(pl.col("passenger_count").fill_null(1))
    .filter(pl.col("passenger_count") > 0)
    .with_columns(pl.col("passenger_count").clip(upper_bound=6))
    .filter(trip_time.dt.total_minutes() <= 120)
    .with_columns([pl.col(col).abs() for col in money_cols])
    .filter([pl.col(col) < 1000 for col in money_cols])
    .filter(
        pl.col("RatecodeID").is_not_null(),
        pl.col("RatecodeID").is_in([1, 2, 3, 4, 5, 6]),
        pl.col("RatecodeID") != 99,
        pl.col("VendorID").is_not_null(),
        pl.col("VendorID").is_in([1, 2, 6, 7]),
    )
)

### 3. Data transformation 

In [9]:
other_payment_type = [0, 3, 4, 5, 6]  # 1 and 2 are card and cash respectively
weekdays = [1, 2, 3, 4, 5]
taxi_zones = df_taxi_zones.select(["LocationID", "Borough"])


df_2024 = (
    df_2024.with_columns(
        pl.when(pl.col("payment_type").is_in(other_payment_type))
        .then(pl.lit(0))
        .otherwise(pl.col("payment_type"))
        .alias("payment_type")
    )
    .join(taxi_zones, left_on="PULocationID", right_on="LocationID")
    .rename({"Borough": "PU_Borough"})
    .join(taxi_zones, left_on="DOLocationID", right_on="LocationID")
    .rename({"Borough": "DO_Borough"})
    .drop(["PULocationID", "DOLocationID"])
    .with_columns((pl.col("Airport_fee").fill_null(0) > 0).alias("is_airport_ride"))
    .with_columns(
        pl.when(
            pl.col("tpep_pickup_datetime").dt.weekday().is_in(weekdays)
            & (
                (
                    (
                        pl.col("tpep_pickup_datetime").dt.time()
                        >= pl.time(hour=6, minute=30)
                    )
                    & (
                        pl.col("tpep_pickup_datetime").dt.time()
                        <= pl.time(hour=9, minute=30)
                    )
                )
                | (
                    (
                        pl.col("tpep_pickup_datetime").dt.time()
                        >= pl.time(hour=15, minute=30)
                    )
                    & (
                        pl.col("tpep_pickup_datetime").dt.time()
                        <= pl.time(hour=20, minute=0)
                    )
                )
            )
        )
        .then(pl.lit(True))
        .otherwise(pl.lit(False))
        .alias("is_rush_hour")
    )
)

### 4. Feature extraction 

In [10]:
df_2024_eager = df_2024.collect()

In [11]:
print(f"Type: {type(df_2024_eager)}")
df_2024_eager.head()

Type: <class 'polars.dataframe.frame.DataFrame'>


VendorID,tpep_pickup_datetime,tpep_dropoff_datetime,passenger_count,trip_distance,RatecodeID,store_and_fwd_flag,payment_type,fare_amount,extra,mta_tax,tip_amount,tolls_amount,improvement_surcharge,total_amount,congestion_surcharge,Airport_fee,PU_Borough,DO_Borough,is_airport_ride,is_rush_hour
u8,datetime[ms],datetime[ms],u8,f64,u8,str,u8,f64,f64,f64,f64,f64,f64,f64,f64,f64,str,str,bool,bool
2,2024-01-01 00:57:55,2024-01-01 01:17:43,1,1.72,1,"""N""",2,17.7,1.0,0.5,0.0,0.0,1.0,22.7,2.5,0.0,"""Manhattan""","""Manhattan""",False,False
1,2024-01-01 00:03:00,2024-01-01 00:09:36,1,1.8,1,"""N""",1,10.0,3.5,0.5,3.75,0.0,1.0,18.75,2.5,0.0,"""Manhattan""","""Manhattan""",False,False
1,2024-01-01 00:17:06,2024-01-01 00:35:01,1,4.7,1,"""N""",1,23.3,3.5,0.5,3.0,0.0,1.0,31.3,2.5,0.0,"""Manhattan""","""Manhattan""",False,False
1,2024-01-01 00:36:38,2024-01-01 00:44:56,1,1.4,1,"""N""",1,10.0,3.5,0.5,2.0,0.0,1.0,17.0,2.5,0.0,"""Manhattan""","""Manhattan""",False,False
1,2024-01-01 00:46:51,2024-01-01 00:52:57,1,0.8,1,"""N""",1,7.9,3.5,0.5,3.2,0.0,1.0,16.1,2.5,0.0,"""Manhattan""","""Manhattan""",False,False


In [12]:
final_dataset = (
    df_2024_eager.to_dummies(
        ["PULocation_Borough", "DOLocationID_Borough"]  # payment_type is okay
    )
    .with_columns(
        # add integer variables for counting daily events
        pl.col("passenger_count")
        .count()
        .over(pl.col("tpep_pickup_datetime").dt.date())
        .alias("total_number_of_rides"),
        pl.col("is_airport_ride")
        .sum()
        .over(pl.col("tpep_pickup_datetime").dt.date())
        .alias("number_of_airport_rides"),
        pl.col("is_rush_hour")
        .sum()
        .over(pl.col("tpep_pickup_datetime").dt.date())
        .alias("number_of_rush_hour_rides"),
        # add features aggregating daily rides information
        pl.col("fare_amount")
        .mean()
        .over(pl.col("tpep_pickup_datetime").dt.date())
        .alias("avg_fare_amount"),
        pl.col("trip_distance")
        .median()
        .over(pl.col("tpep_pickup_datetime").dt.date())
        .alias("median_distance"),
        pl.col("total_amount")
        .sum()
        .over(pl.col("tpep_pickup_datetime").dt.date())
        .alias("sum_of_total_amounts"),
        pl.col("total_amount")
        .filter(pl.col("payment_type") == 1)
        .sum()
        .over(pl.col("tpep_pickup_datetime").dt.date())
        .alias("sum_of_total_amounts_card"),
        pl.col("total_amount")
        .filter(pl.col("payment_type") == 2)
        .sum()
        .over(pl.col("tpep_pickup_datetime").dt.date())
        .alias("sum_of_total_amounts_cash"),
        pl.col("total_amount")
        .filter(pl.col("payment_type") == 0)
        .sum()
        .over(pl.col("tpep_pickup_datetime").dt.date())
        .alias("sum_of_total_amounts_other"),
        pl.col("congestion_surcharge")
        .sum()
        .over(pl.col("tpep_pickup_datetime").dt.date())
        .alias("total_congestion_surcharge"),
        pl.col("passenger_count")
        .sum()
        .over(pl.col("tpep_pickup_datetime").dt.date())
        .alias("total_passenger_count"),
        # add time features
        pl.col("tpep_pickup_datetime").dt.quarter().alias("quarter"),
        pl.col("tpep_pickup_datetime").dt.month().alias("month"),
        pl.col("tpep_pickup_datetime").dt.day().alias("day_of_month"),
        pl.col("tpep_pickup_datetime").dt.weekday().alias("day_of_week"),
        pl.col("tpep_pickup_datetime").dt.weekday().is_in([6, 7]).alias("is_weekend"),
        pl.col("tpep_pickup_datetime").dt.date().alias("date"),
    )
    .drop(
        [
            "VendorID",
            "RatecodeID",
            "tpep_pickup_datetime",
            "tpep_dropoff_datetime",
            "store_and_fwd_flag",
            "extra",
            "mta_tax",
            "tip_amount",
            "tolls_amount",
            "improvement_surcharge",
            "congestion_surcharge",
            "Airport_fee",
        ]
    )
)

In [13]:
final_dataset

passenger_count,trip_distance,payment_type,fare_amount,total_amount,PU_Borough,DO_Borough,is_airport_ride,is_rush_hour,total_number_of_rides,number_of_airport_rides,number_of_rush_hour_rides,avg_fare_amount,median_distance,sum_of_total_amounts,sum_of_total_amounts_card,sum_of_total_amounts_cash,sum_of_total_amounts_other,total_congestion_surcharge,total_passenger_count,quarter,month,day_of_month,day_of_week,is_weekend,date
u8,f64,u8,f64,f64,str,str,bool,bool,u32,u32,u32,f64,f64,f64,f64,f64,f64,f64,i64,i8,i8,i8,i8,bool,date
1,1.72,2,17.7,22.7,"""Manhattan""","""Manhattan""",false,false,69541,8352,19360,22.184388,2.09,2.1793e6,1.7389e6,384662.67,55778.86,156123.25,109013,1,1,1,1,false,2024-01-01
1,1.8,1,10.0,18.75,"""Manhattan""","""Manhattan""",false,false,69541,8352,19360,22.184388,2.09,2.1793e6,1.7389e6,384662.67,55778.86,156123.25,109013,1,1,1,1,false,2024-01-01
1,4.7,1,23.3,31.3,"""Manhattan""","""Manhattan""",false,false,69541,8352,19360,22.184388,2.09,2.1793e6,1.7389e6,384662.67,55778.86,156123.25,109013,1,1,1,1,false,2024-01-01
1,1.4,1,10.0,17.0,"""Manhattan""","""Manhattan""",false,false,69541,8352,19360,22.184388,2.09,2.1793e6,1.7389e6,384662.67,55778.86,156123.25,109013,1,1,1,1,false,2024-01-01
1,0.8,1,7.9,16.1,"""Manhattan""","""Manhattan""",false,false,69541,8352,19360,22.184388,2.09,2.1793e6,1.7389e6,384662.67,55778.86,156123.25,109013,1,1,1,1,false,2024-01-01
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
1,1.9,1,12.1,18.98,"""Manhattan""","""Manhattan""",false,false,76959,5463,27117,18.422413,1.61,2.0855e6,1.6223e6,359093.87,104116.0,176115.0,113355,4,12,31,2,false,2024-12-31
1,3.88,2,19.1,24.1,"""Manhattan""","""Queens""",false,false,76959,5463,27117,18.422413,1.61,2.0855e6,1.6223e6,359093.87,104116.0,176115.0,113355,4,12,31,2,false,2024-12-31
1,5.53,1,29.6,36.6,"""Manhattan""","""Brooklyn""",false,false,76959,5463,27117,18.422413,1.61,2.0855e6,1.6223e6,359093.87,104116.0,176115.0,113355,4,12,31,2,false,2024-12-31
1,0.89,1,9.3,16.44,"""Manhattan""","""Manhattan""",false,false,76959,5463,27117,18.422413,1.61,2.0855e6,1.6223e6,359093.87,104116.0,176115.0,113355,4,12,31,2,false,2024-12-31


In [14]:
final_dataset.write_parquet("data/dataset.parquet")

### 5. Data analysis 