In [98]:
# All imports go here
import polars as pl
from time import time

In [99]:
import adbc_driver_postgresql.dbapi

In [100]:
user="postgres"
password="postgres"
db_name = "ny_taxi"
port=5433
hostname="localhost"
DATABASE_URL = f"postgresql://{user}:{password}@{hostname}:{port}/{db_name}"

conn = adbc_driver_postgresql.dbapi.connect(DATABASE_URL)

# Create a cursor instance
with conn.cursor() as cur:
    cur.execute("SELECT count(1) FROM yellow_taxi_data_polars;")
    print(cur.fetch_arrow_table())

pyarrow.Table
count: int64
----
count: [[0]]


In [101]:
# endpoint = "https://github.com/DataTalksClub/nyc-tlc-data/releases/download/yellow/yellow_tripdata_2019-01.csv.gz"
endpoint = "hw1/data/yellow_tripdata_2019-01.csv.gz"

In [102]:
def yellow_taxi_polars_schema() -> dict:
  return {
    "VendorID": pl.Int32,
    "tpep_pickup_datetime": pl.Datetime,
    "tpep_dropoff_datetime": pl.Datetime,
    "passenger_count": pl.Int8,
    "trip_distance": pl.Float64,
    "PULocationID": pl.Int32,
    "DOLocationID": pl.Int32,
    "RatecodeID": pl.Int8,
    "store_and_fwd_flag": pl.String,
    "payment_type": pl.Int8,
    "fare_amount": pl.Float64,
    "extra": pl.Float64,
    "mta_tax": pl.Float64,
    "improvement_surcharge": pl.Float64,
    "tip_amount": pl.Float64,
    "tolls_amount": pl.Float64,
    "total_amount": pl.Float64,
    "congestion_surcharge": pl.Float64,
  }

In [103]:
# df = pl.read_csv(endpoint, dtypes=yellow_taxi_polars_schema(), n_rows=100)

In [104]:
# df

In [105]:
def recreate_yellow_taxi_table(connection_string: str, table_name: str = "yellow_taxi_data_polars") -> None:
    """
    Creates a PostgreSQL table for yellow taxi data using ADBC engine.
    
    Args:
        connection_string: PostgreSQL connection string
        table_name: Name of the table to create
    """
    # Map Polars types to PostgreSQL types
    pg_type_mapping = {
        pl.Int32: "INTEGER",
        pl.Int8: "SMALLINT",
        pl.Float64: "DOUBLE PRECISION",
        pl.Datetime: "TIMESTAMP",
        pl.String: "VARCHAR"
    }
    
    # Get the schema definition
    schema = yellow_taxi_polars_schema()
    
    # Convert schema to PostgreSQL column definitions
    columns = []
    for column_name, polars_type in schema.items():
        quoted_column_name = f'"{column_name}"'
        pg_type = pg_type_mapping[polars_type]
        columns.append(f"{quoted_column_name} {pg_type}")
    
    # Drop the table if it already exists
    drop_table_sql = f"DROP TABLE IF EXISTS {table_name};"
    print(f"Drop table sql ==> {drop_table_sql}")
    # Create a cursor instance
    with conn.cursor() as cur:
        cur.execute(drop_table_sql)
        conn.commit()
    
    # Create the table definition SQL
    nl = ",\n        "
    create_table_sql = f"""
        CREATE TABLE IF NOT EXISTS {table_name} (
            {nl.join(columns)}
        );
        """
    
    print(f"Create table sql ==> {create_table_sql}")
    
    with conn.cursor() as cur:
        cur.execute(create_table_sql)
        conn.commit()

In [106]:
recreate_yellow_taxi_table(DATABASE_URL)

Drop table sql ==> DROP TABLE IF EXISTS yellow_taxi_data_polars;
Create table sql ==> 
        CREATE TABLE IF NOT EXISTS yellow_taxi_data_polars (
            "VendorID" INTEGER,
        "tpep_pickup_datetime" TIMESTAMP,
        "tpep_dropoff_datetime" TIMESTAMP,
        "passenger_count" SMALLINT,
        "trip_distance" DOUBLE PRECISION,
        "PULocationID" INTEGER,
        "DOLocationID" INTEGER,
        "RatecodeID" SMALLINT,
        "store_and_fwd_flag" VARCHAR,
        "payment_type" SMALLINT,
        "fare_amount" DOUBLE PRECISION,
        "extra" DOUBLE PRECISION,
        "mta_tax" DOUBLE PRECISION,
        "improvement_surcharge" DOUBLE PRECISION,
        "tip_amount" DOUBLE PRECISION,
        "tolls_amount" DOUBLE PRECISION,
        "total_amount" DOUBLE PRECISION,
        "congestion_surcharge" DOUBLE PRECISION
        );
        


In [107]:
batch_reader = pl.read_csv_batched(
    source = endpoint,
    batch_size=100000,
    schema_overrides=yellow_taxi_polars_schema(),
    has_header=True,
    separator=",",
    try_parse_dates=True
)

batches = batch_reader.next_batches(n=100)

start_time = time()
for df in batches:
  t_start = time()
  df.write_database(
    table_name="yellow_taxi_data_polars",
    connection=DATABASE_URL,
    engine="adbc",
    if_table_exists="append",
  )
  t_end = time()
  print('inserted another chunk, took %.3f second' % (t_end - t_start))

end_time = time()
print(f"Final insertion took {end_time - start_time} seconds")

inserted another chunk, took 0.237 second
inserted another chunk, took 0.205 second
inserted another chunk, took 0.202 second
inserted another chunk, took 0.201 second
inserted another chunk, took 0.199 second
inserted another chunk, took 0.206 second
inserted another chunk, took 0.200 second
inserted another chunk, took 0.207 second
inserted another chunk, took 0.211 second
inserted another chunk, took 0.200 second
inserted another chunk, took 0.198 second
inserted another chunk, took 0.197 second
inserted another chunk, took 0.201 second
inserted another chunk, took 0.203 second
inserted another chunk, took 0.195 second
inserted another chunk, took 0.195 second
inserted another chunk, took 0.201 second
inserted another chunk, took 0.200 second
inserted another chunk, took 0.196 second
inserted another chunk, took 0.195 second
inserted another chunk, took 0.212 second
inserted another chunk, took 0.193 second
inserted another chunk, took 0.194 second
inserted another chunk, took 0.203

In [108]:
with conn.cursor() as cur:
    cur.execute("SELECT count(1) FROM yellow_taxi_data_polars;")
    print(cur.fetch_arrow_table())

pyarrow.Table
count: int64
----
count: [[4012011]]
