In [3]:
import pyarrow
import numpy as np
import pandas as pd
import glob
import pyarrow.dataset as ds
import pyarrow.compute as pc


In [None]:
from pathlib import Path
import re


BASE_DIR = Path("/d/hpc/projects/FRI/bigdata/data/Taxi")

RE = re.compile(r"fhvhv_tripdata[_-]?(\d{4})(?:[_-]?(\d{2}))?", re.IGNORECASE)
min_year, min_month = 2019, 2
def find_parquet_files(base_dir: Path) -> list:
    # files = list(base_dir.rglob("green_tripdata*.parquet"))
    files = list(base_dir.rglob("fhvhv_tripdata*.parquet"))
    selected = []
    for p in files:
        m = RE.search(p.name)
        if not m:
            continue
        year = int(m.group(1))
        month = int(m.group(2)) if m.group(2) else 1  # treat missing month as January
        if (year, month) >= (min_year, min_month):
            selected.append((year, month, str(p)))
    # sort by year, month
    selected.sort(key=lambda x: (x[0], x[1]))
    # return just paths (strings)
    return [p for _, __, p in selected]

# usage
files = find_parquet_files(BASE_DIR)
print(f"Found {len(files)} files (from {min_year}-{min_month} onward).")
for f in files:
    print(f)

Found 72 files (from 2019-2 onward).
/d/hpc/projects/FRI/bigdata/data/Taxi/fhvhv_tripdata_2019-02.parquet
/d/hpc/projects/FRI/bigdata/data/Taxi/fhvhv_tripdata_2019-03.parquet
/d/hpc/projects/FRI/bigdata/data/Taxi/fhvhv_tripdata_2019-04.parquet
/d/hpc/projects/FRI/bigdata/data/Taxi/fhvhv_tripdata_2019-05.parquet
/d/hpc/projects/FRI/bigdata/data/Taxi/fhvhv_tripdata_2019-06.parquet
/d/hpc/projects/FRI/bigdata/data/Taxi/fhvhv_tripdata_2019-07.parquet
/d/hpc/projects/FRI/bigdata/data/Taxi/fhvhv_tripdata_2019-08.parquet
/d/hpc/projects/FRI/bigdata/data/Taxi/fhvhv_tripdata_2019-09.parquet
/d/hpc/projects/FRI/bigdata/data/Taxi/fhvhv_tripdata_2019-10.parquet
/d/hpc/projects/FRI/bigdata/data/Taxi/fhvhv_tripdata_2019-11.parquet
/d/hpc/projects/FRI/bigdata/data/Taxi/fhvhv_tripdata_2019-12.parquet
/d/hpc/projects/FRI/bigdata/data/Taxi/fhvhv_tripdata_2020-01.parquet
/d/hpc/projects/FRI/bigdata/data/Taxi/fhvhv_tripdata_2020-02.parquet
/d/hpc/projects/FRI/bigdata/data/Taxi/fhvhv_tripdata_2020-03.parqu

In [None]:
#!/usr/bin/env python3
"""
PyArrow streaming processor that supports:
 - green/yellow taxi schemas (tpep_ / lpep_ names)
 - hvfhs / fhvhv schemas (pickup_datetime, request_datetime, on_scene_datetime, etc.)
 - keeps airport_fee as string, computes year from canonical pickup_datetime,
   writes parquet partitioned by year=YYYY

Behavior:
 - Lowercases incoming column names
 - Maps a wide set of pickup/dropoff candidate names into a canonical 'pickup_datetime'
   and 'dropoff_datetime' used for parsing and year computation
 - Builds a unified schema (union of expected fields) and emits that schema for outputs,
   filling missing columns with nulls when necessary
"""
import os
import glob
import uuid
import logging
from collections import defaultdict
from typing import List, Dict, Optional

import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.dataset as ds
import pyarrow.parquet as pq

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

subset_files = files

OUTPUT_DIR = "/d/hpc/projects/FRI/bigdata/students/in7357/optimized_parquet_pyarrow_fhvhv"

# Union of fields across green/yellow + hvfhs/fhvhv. 
EXPECTED_COLUMNS = [
    # green/yellow taxi typical fields
    'vendorid', 'tpep_pickup_datetime', 'tpep_dropoff_datetime',
    'passenger_count', 'trip_distance', 'ratecodeid', 'store_and_fwd_flag',
    'pulocationid', 'dolocationid', 'payment_type', 'fare_amount', 'extra',
    'mta_tax', 'tip_amount', 'tolls_amount', 'improvement_surcharge',
    'total_amount', 'congestion_surcharge', 'airport_fee',
    # hvfhs / fhvhv / fhv-style fields

    'hvfhs_license_num', 'dispatching_base_num', 'originating_base_num',
    'request_datetime', 'on_scene_datetime', 'pickup_datetime', 'dropoff_datetime',
    'trip_miles', 'trip_time', 'base_passenger_fare', 'tolls', 'bcf',
    'sales_tax', 'tips', 'driver_pay',
    'shared_request_flag', 'shared_match_flag', 'access_a_ride_flag',
    'wav_request_flag', 'wav_match_flag'
]


PICKUP_CANDIDATES = [
    'tpep_pickup_datetime', 'lpep_pickup_datetime', 'pickup_datetime',
    'pickup_time', 'pickup_ts', 'pickup', 'request_datetime', 'on_scene_datetime'
]

DROPOFF_CANDIDATES = [
    'tpep_dropoff_datetime', 'lpep_dropoff_datetime', 'dropoff_datetime',
    'dropoff_time', 'dropoff_ts', 'dropoff'
]


TIMESTAMP_STR_FORMATS = [
    "%Y-%m-%d %H:%M:%S",
    "%Y-%m-%d %H:%M:%S.%f",
    "%Y-%m-%dT%H:%M:%S",
    "%Y-%m-%dT%H:%M:%S.%f",
]

BATCH_SIZE = 100_000
ROW_GROUP_ROWS = 2_000_000

# Canonical names used internally
TARGET_PICKUP = 'pickup_datetime'
TARGET_DROPOFF = 'dropoff_datetime'
# --------------------------------------------------------------------------


def discover_files(paths_or_globs: List[str]) -> List[str]:
    files = []
    for p in paths_or_globs:
        if any(ch in p for ch in ["*", "?"]):
            files.extend(sorted(glob.glob(p)))
        else:
            if os.path.isdir(p):
                files.extend(sorted(glob.glob(os.path.join(p, "*.parquet"))))
            elif os.path.isfile(p):
                files.append(p)
            else:
                logger.warning("Path not found (skipping): %s", p)
    if not files:
        raise FileNotFoundError("No parquet files found for the given patterns/paths.")
    return sorted(files)


def read_file_schema_lowered(path: str) -> pa.Schema:
    s = pq.read_schema(path)
    fields = [pa.field(f.name.lower(), f.type, nullable=f.nullable) for f in s]
    return pa.schema(fields)


def _normalize_name_to_canonical(name: str) -> str:
    n = name.lower()
    if n in PICKUP_CANDIDATES or ('pickup' in n and ('date' in n or 'time' in n)):
        return TARGET_PICKUP
    if n in DROPOFF_CANDIDATES or ('dropoff' in n and ('date' in n or 'time' in n)):
        return TARGET_DROPOFF
    return n


def infer_unified_schema(paths: List[str], expected_columns: List[str]) -> pa.Schema:
    observed_types: Dict[str, pa.DataType] = {}
    for p in paths:
        try:
            sch = read_file_schema_lowered(p)
        except Exception as e:
            logger.warning("Could not read schema for %s: %s", p, e)
            continue
        for f in sch:
            canonical = _normalize_name_to_canonical(f.name)
            if canonical not in observed_types:
                observed_types[canonical] = f.type

    # ensure expected columns are present in observed_types (fallback to string)
    for col in expected_columns:
        nc = _normalize_name_to_canonical(col.lower())
        if nc not in observed_types:
            observed_types[nc] = pa.string()

    
    seen = set()
    ordered_expected = []
    for c in expected_columns:
        nc = _normalize_name_to_canonical(c.lower())
        if nc not in seen:
            ordered_expected.append(nc)
            seen.add(nc)
    remaining = [k for k in sorted(observed_types.keys()) if k not in seen]
    final_order = ordered_expected + remaining

    fields = []
    for col in final_order:
        t = observed_types.get(col, pa.string())
        if col == TARGET_PICKUP or col == TARGET_DROPOFF:
            t = pa.timestamp('us')
        if col == 'airport_fee':
            t = pa.string()
        fields.append(pa.field(col, t, nullable=True))
    fields.append(pa.field('year', pa.int32(), nullable=True))
    return pa.schema(fields)


def detect_pickup_name_from_schema(schema: pa.Schema) -> Optional[str]:
    names = {n.lower(): n for n in schema.names}
    for cand in PICKUP_CANDIDATES:
        if cand in names:
            return names[cand]
    for n in names:
        if 'pickup' in n and ('date' in n or 'time' in n):
            return names[n]
    return None


def parse_timestamp_array(arr: pa.Array) -> pa.Array:
    if pa.types.is_timestamp(arr.type):
        return pc.cast(arr, pa.timestamp("us"))
    if pa.types.is_integer(arr.type):
        try:
            maxv = int(pc.max(arr).as_py() or 0)
        except Exception:
            maxv = 0
        if maxv > 10 ** 14:
            return pc.cast(arr, pa.timestamp("us"))
        if maxv > 10 ** 11:
            return pc.multiply(pc.cast(arr, pa.int64()), pa.scalar(1000, pa.int64())).cast(pa.timestamp("us"))
        return pc.multiply(pc.cast(arr, pa.int64()), pa.scalar(1_000_000, pa.int64())).cast(pa.timestamp("us"))
    if pa.types.is_string(arr.type):
        last_err = None
        for fmt in TIMESTAMP_STR_FORMATS:
            try:
                parsed = pc.strptime(arr, format=fmt, unit="us")
                return parsed
            except Exception as e:
                last_err = e
        try:
            return pc.cast(arr, pa.timestamp("us"))
        except Exception:
            raise ValueError(f"Failed to parse timestamp string column. Last error: {last_err}")
    raise TypeError(f"Unsupported timestamp column type: {arr.type}")


def ensure_table_has_expected(table: pa.Table, target_schema: pa.Schema) -> pa.Table:
    existing = {name.lower(): table.column(i) for i, name in enumerate(table.schema.names)}
    rows = table.num_rows

    # map pickup candidates to canonical
    if TARGET_PICKUP not in existing:
        for cand in PICKUP_CANDIDATES:
            if cand in existing:
                existing[TARGET_PICKUP] = existing[cand]
                logger.info(f"Mapped {cand} -> {TARGET_PICKUP}")
                break
        if TARGET_PICKUP not in existing:
            for name in list(existing.keys()):
                if ('pickup' in name) and ('date' in name or 'time' in name):
                    existing[TARGET_PICKUP] = existing[name]
                    logger.info(f"Mapped {name} -> {TARGET_PICKUP} (heuristic)")
                    break

    if TARGET_DROPOFF not in existing:
        for cand in DROPOFF_CANDIDATES:
            if cand in existing:
                existing[TARGET_DROPOFF] = existing[cand]
                logger.info(f"Mapped {cand} -> {TARGET_DROPOFF}")
                break
        if TARGET_DROPOFF not in existing:
            for name in list(existing.keys()):
                if ('dropoff' in name) and ('date' in name or 'time' in name):
                    existing[TARGET_DROPOFF] = existing[name]
                    logger.info(f"Mapped {name} -> {TARGET_DROPOFF} (heuristic)")
                    break

    arrays = []
    names = []
    for f in target_schema:
        if f.name == 'year':
            continue
        if f.name in existing:
            arr = existing[f.name]
            try:
                if not arr.type.equals(f.type):
                    arr = pc.cast(arr, f.type)
            except Exception:
                arr = pa.array([None] * rows, type=f.type)
        else:
            arr = pa.array([None] * rows, type=f.type)
        arrays.append(arr)
        names.append(f.name)
    return pa.Table.from_arrays(arrays, names=names)


def write_partition_table(table: pa.Table, out_dir: str, year: int, row_group_rows: int) -> str:
    part_dir = os.path.join(out_dir, f"year={year}")
    os.makedirs(part_dir, exist_ok=True)
    fname = f"part-{uuid.uuid4().hex}.parquet"
    out_path = os.path.join(part_dir, fname)
    if 'airport_fee' in table.schema.names:
        idx = table.schema.get_field_index('airport_fee')
        if idx >= 0 and not pa.types.is_string(table.column(idx).type):
            try:
                table = table.set_column(idx, 'airport_fee', pc.cast(table.column(idx), pa.string()))
            except Exception:
                table = table.set_column(idx, 'airport_fee',
                                        pa.array([str(x) if x is not None else None for x in table.column(idx)],
                                                 type=pa.string()))

    if 'year' in table.schema.names:
        yidx = table.schema.get_field_index('year')
        ycol = table.column(yidx)
        if not pa.types.is_int32(ycol.type):
            try:
                ycol = pc.cast(ycol, pa.int32())
            except Exception:
                ycol = pa.array([None if v.is_null else int(v.as_py()) for v in ycol], type=pa.int32())
            table = table.set_column(yidx, 'year', ycol)

    new_cols = []
    for i, f in enumerate(table.schema):
        col = table.column(i)
        if pa.types.is_dictionary(col.type):
            vals = [None if v.is_null else v.as_py() for v in col]
            target_type = col.type.value_type
            col = pa.array(vals, type=target_type)
        new_cols.append(col)
    if any(not table.column(i).equals(new_cols[i]) for i in range(len(new_cols))):
        table = pa.Table.from_arrays(new_cols, names=table.schema.names)

    pq.write_table(table, out_path, row_group_size=row_group_rows, use_dictionary=False)

    
    # pq.write_table(table, out_path, row_group_size=row_group_rows)
    logger.info("Wrote %d rows -> %s (year=%s)", table.num_rows, out_path, year)
    return out_path


def process_files_pyarrow(file_paths: List[str], expected_cols: List[str], out_dir: str,
                          batch_size: int = BATCH_SIZE, row_group_rows: int = ROW_GROUP_ROWS):
    files = discover_files(file_paths)
    logger.info("Found %d files", len(files))

    unified_schema = infer_unified_schema(files, expected_cols)
    logger.info("Unified schema prepared with %d fields (incl. year)", len(unified_schema))
    print("Unified schema:")
    print(unified_schema)

    # find pickup candidate name for diagnostics (optional)
    pickup_field_original = None
    for f in files:
        try:
            s = read_file_schema_lowered(f)
            val = detect_pickup_name_from_schema(s)
            if val:
                pickup_field_original = val
                break
        except Exception:
            continue
    if pickup_field_original:
        logger.info("Detected pickup field (lowercased): %s", pickup_field_original)
    else:
        logger.warning("No pickup datetime column found among files (will create null pickup column).")

    buffers = defaultdict(list)
    buffered_rows = defaultdict(int)

    for file in files:
        logger.info("Reading file: %s", file)
        ds_obj = ds.dataset(file, format="parquet")
        scanner = ds_obj.scanner(columns=None, batch_size=batch_size)

        for rb in scanner.to_batches():
            tbl = pa.Table.from_batches([rb])
            tbl = pa.Table.from_arrays([tbl.column(i) for i in range(tbl.num_columns)],
                                       names=[n.lower() for n in tbl.schema.names])

            tbl = ensure_table_has_expected(tbl, unified_schema)

            pickup_idx = tbl.schema.get_field_index(TARGET_PICKUP)
            if pickup_idx >= 0:
                pickup_arr = tbl.column(pickup_idx)
                try:
                    ts = parse_timestamp_array(pickup_arr)
                    tbl = tbl.set_column(pickup_idx, TARGET_PICKUP, ts)
                except Exception as e:
                    logger.exception("Failed to parse pickup timestamp in file=%s : %s", file, e)
                    raise
            else:
                ts = pa.array([None] * tbl.num_rows, type=pa.timestamp('us'))
                tbl = tbl.add_column(tbl.num_columns, TARGET_PICKUP, ts)

            try:
                year_arr = pc.year(tbl.column(tbl.schema.get_field_index(TARGET_PICKUP)))
            except Exception:
                year_arr = pa.array([None] * tbl.num_rows, type=pa.int32())

            if 'year' in tbl.schema.names:
                yidx = tbl.schema.get_field_index('year')
                tbl = tbl.set_column(yidx, 'year', year_arr)
            else:
                tbl = tbl.add_column(tbl.num_columns, 'year', year_arr)

            # coerce airport_fee to string
            if 'airport_fee' in tbl.schema.names:
                af_idx = tbl.schema.get_field_index('airport_fee')
                try:
                    if not pa.types.is_string(tbl.column(af_idx).type):
                        tbl = tbl.set_column(af_idx, 'airport_fee', pc.cast(tbl.column(af_idx), pa.string()))
                except Exception:
                    tbl = tbl.set_column(af_idx, 'airport_fee',
                                         pa.array([str(x) if x is not None else None for x in tbl.column(af_idx)],
                                                  type=pa.string()))

            try:
                yrs = pc.unique(tbl.column(tbl.schema.get_field_index('year'))).to_pylist()
            except Exception:
                yrs = [None]

            for y in yrs:
                if y is None:
                    logger.warning("Found NULL year values - skipping these rows")
                    continue

                mask = pc.equal(tbl.column(tbl.schema.get_field_index('year')), pa.scalar(y, pa.int32()))
                part = tbl.filter(mask)
                if part.num_rows == 0:
                    continue

                # final_cols = []
                # final_names = []
                # for f in unified_schema:
                #     if f.name == 'year':
                #         continue
                #     idx = part.schema.get_field_index(f.name)
                #     if idx >= 0:
                #         final_cols.append(part.column(idx))
                #         final_names.append(f.name)
                #     else:
                #         final_cols.append(pa.array([None] * part.num_rows, type=f.type))
                #         final_names.append(f.name)
                # final_cols.append(part.column(part.schema.get_field_index('year')))
                # final_names.append('year')
                # final_tbl = pa.Table.from_arrays(final_cols, names=final_names)

                final_cols = []
                final_names = []
                # Build arrays cast to the unified schema types
                for f in unified_schema:
                    if f.name == 'year':
                        continue
                    # get array or create null array
                    idx = part.schema.get_field_index(f.name)
                    if idx >= 0:
                        arr = part.column(idx)
                    else:
                        arr = pa.array([None] * part.num_rows, type=f.type)

                    try:
                        if not arr.type.equals(f.type):
                            arr = pc.cast(arr, f.type)
                    except Exception:
                        # fallback: materialize python values and rebuild as desired type
                        vals = [None if v.is_null else v.as_py() for v in arr]
                        arr = pa.array(vals, type=f.type)
                    final_cols.append(arr)
                    final_names.append(f.name)


                yidx = part.schema.get_field_index('year')
                year_arr = part.column(yidx)
                try:
                    if not pa.types.is_int32(year_arr.type):
                        year_arr = pc.cast(year_arr, pa.int32())
                except Exception:
                    vals = [None if v.is_null else int(v.as_py()) for v in year_arr]
                    year_arr = pa.array(vals, type=pa.int32())

                final_cols.append(year_arr)
                final_names.append('year')

                final_tbl = pa.Table.from_arrays(final_cols, names=final_names)

                

                buffers[y].append(final_tbl)
                buffered_rows[y] += final_tbl.num_rows

                if buffered_rows[y] >= row_group_rows:
                    merged = pa.concat_tables(buffers[y], promote=True)
                    write_partition_table(merged, out_dir, y, row_group_rows)
                    buffers[y].clear()
                    buffered_rows[y] = 0

    for y, tbls in buffers.items():
        if not tbls:
            continue
        merged = pa.concat_tables(tbls, promote=True)
        write_partition_table(merged, out_dir, y, row_group_rows)

    logger.info("All done. Output written under: %s", out_dir)


if __name__ == "__main__":
    process_files_pyarrow(subset_files, EXPECTED_COLUMNS, OUTPUT_DIR,
                          batch_size=BATCH_SIZE, row_group_rows=ROW_GROUP_ROWS)


INFO:__main__:Found 72 files
INFO:__main__:Unified schema prepared with 36 fields (incl. year)
INFO:__main__:Detected pickup field (lowercased): pickup_datetime
INFO:__main__:Reading file: /d/hpc/projects/FRI/bigdata/data/Taxi/fhvhv_tripdata_2019-02.parquet


Unified schema:
vendorid: string
pickup_datetime: timestamp[us]
dropoff_datetime: timestamp[us]
passenger_count: string
trip_distance: string
ratecodeid: string
store_and_fwd_flag: string
pulocationid: int64
dolocationid: int64
payment_type: string
fare_amount: string
extra: string
mta_tax: string
tip_amount: string
tolls_amount: string
improvement_surcharge: string
total_amount: string
congestion_surcharge: double
airport_fee: string
hvfhs_license_num: string
dispatching_base_num: string
originating_base_num: string
trip_miles: double
trip_time: int64
base_passenger_fare: double
tolls: double
bcf: double
sales_tax: double
tips: double
driver_pay: double
shared_request_flag: string
shared_match_flag: string
access_a_ride_flag: string
wav_request_flag: string
wav_match_flag: null
year: int32


  process_files_pyarrow(subset_files, EXPECTED_COLUMNS, OUTPUT_DIR,
INFO:__main__:Wrote 2000000 rows -> /d/hpc/projects/FRI/bigdata/students/in7357/optimized_parquet_pyarrow_fhvhv/year=2019/part-ab3c6af86d2a42f8a160a6dd9a305f24.parquet (year=2019)
INFO:__main__:Wrote 2000000 rows -> /d/hpc/projects/FRI/bigdata/students/in7357/optimized_parquet_pyarrow_fhvhv/year=2019/part-8dd0744e875e41e28668c09809004eeb.parquet (year=2019)
INFO:__main__:Wrote 2000000 rows -> /d/hpc/projects/FRI/bigdata/students/in7357/optimized_parquet_pyarrow_fhvhv/year=2019/part-64cd93f29bc84e3ba8271d1379287527.parquet (year=2019)
INFO:__main__:Wrote 2000000 rows -> /d/hpc/projects/FRI/bigdata/students/in7357/optimized_parquet_pyarrow_fhvhv/year=2019/part-3ca54e6766a74393bc5ee5a41b7c0cc1.parquet (year=2019)
INFO:__main__:Wrote 2000000 rows -> /d/hpc/projects/FRI/bigdata/students/in7357/optimized_parquet_pyarrow_fhvhv/year=2019/part-8660236a71484e5d8279b969adb4a35f.parquet (year=2019)
INFO:__main__:Wrote 2000000 rows

In [None]:
import pyarrow.parquet as pq


p = "/d/hpc/projects/FRI/bigdata/data/Taxi/fhvhv_tripdata_2019-11.parquet"  # pick a file
print("Parquet schema (original):")
print(pq.read_schema(p))

tbl = pq.read_table(p, columns=None, use_pandas_metadata=False)
print("Lowercased columns:")
n_rows = tbl.num_rows
print(f"Rows: {n_rows}")
print([n.lower() for n in tbl.schema.names])

for cand in PICKUP_CANDIDATES + ['lpep_pickup_datetime','pickup_datetime']:
    if cand in [n.lower() for n in tbl.schema.names]:
        arr = tbl.column([n.lower() for n in tbl.schema.names].index(cand))
        print(f"Found candidate column: {cand} — type: {arr.type}")
        # print first 5 non-null values (safe)
        nonnulls = [x for x in arr.to_pylist() if x is not None][:5]
        print(" first 5 non-null values:", nonnulls)
        print(" first 5 non-null values:", nonnulls)

In [1]:
import dask.dataframe as dd

In [None]:
df_test = 