In [47]:
# Combine FARS and CRSS fault data with the full census data

#   Build company-level crash summaries (overall + per-source) from the
#   census-with-FARS/CRSS merged inputs.
#   Also export the full census rows with metrics attached by USDOT.

import pandas as pd
import numpy as np
from pathlib import Path
from typing import Tuple, Optional, Union

def _read_parquet_required(path: Union[str, Path]) -> pd.DataFrame:
    """Read a Parquet file; raise a clear error if missing."""
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(f"Missing required file: {p.resolve()}")
    return pd.read_parquet(p)

def _stdcols(df: pd.DataFrame) -> pd.DataFrame:
    """Uppercase + strip column names for consistency."""
    d = df.copy()
    d.columns = pd.Index([str(c).strip().upper() for c in d.columns])
    return d

def _clean_usdot_series(s: pd.Series) -> pd.Series:
    """Normalize USDOT IDs: string, strip, remove leading zeros, blank→NA."""
    s = s.astype("string").str.strip().str.replace(r"^\s*0+(?=\d)", "", regex=True)
    return s.mask(s.isin(["", "nan", "<NA>"]))

def _force_string_identifiers(df: pd.DataFrame) -> pd.DataFrame:
    """Keep identifier-like columns as strings for Parquet stability."""
    id_tokens = ("VIN", "PLATE", "MCARR_I2", "DOT_NUMBER", "USDOT")
    d = df.copy()
    for c in d.columns:
        if any(tok in c.upper() for tok in id_tokens):
            d[c] = d[c].astype("string")
    return d

# Standardization
def standardize_census_fields(df: pd.DataFrame) -> pd.DataFrame:
    """
    Normalize a merged census+crash frame to the columns the summaries need.

    The returned frame (column labels upper-cased and trimmed) is guaranteed to
    contain DOT_NUMBER, LEGAL_NAME, DBA_NAME, SOURCE, LIKELY_AT_FAULT and
    FAULT_SCORE:
      - DOT_NUMBER falls back to MCARR_I2 and is normalized as a USDOT ID;
      - LEGAL_NAME / DBA_NAME are copied from the first known alias present,
        or created as all-NA columns;
      - LIKELY_AT_FAULT and FAULT_SCORE are coerced to numeric (bad → NaN);
      - identifier-like columns are cast to string.

    Raises
    ------
    ValueError
        If neither DOT_NUMBER nor MCARR_I2 exists, or if SOURCE,
        LIKELY_AT_FAULT or FAULT_SCORE is missing (first missing one reported).
    """
    out = _stdcols(df)

    # USDOT key: prefer DOT_NUMBER, fall back to MCARR_I2, otherwise fail.
    if "DOT_NUMBER" not in out.columns:
        if "MCARR_I2" not in out.columns:
            raise ValueError("DOT_NUMBER not found (and no MCARR_I2 to backfill).")
        out["DOT_NUMBER"] = out["MCARR_I2"]
    out["DOT_NUMBER"] = _clean_usdot_series(out["DOT_NUMBER"])

    # Carrier names: first alias found wins; otherwise an all-NA placeholder.
    name_aliases = {
        "LEGAL_NAME": ("LEGALNAME", "CARRIER_LEGAL_NAME"),
        "DBA_NAME": ("DBANAME", "CARRIER_DBA_NAME"),
    }
    for target, aliases in name_aliases.items():
        if target in out.columns:
            continue
        alias_hit = next((alias for alias in aliases if alias in out.columns), None)
        out[target] = out[alias_hit] if alias_hit is not None else pd.NA

    # Scoring fields are mandatory; report the first one that is absent.
    missing = next(
        (field for field in ("SOURCE", "LIKELY_AT_FAULT", "FAULT_SCORE") if field not in out.columns),
        None,
    )
    if missing is not None:
        raise ValueError(f"Missing required field: {missing}")

    for numeric_col in ("LIKELY_AT_FAULT", "FAULT_SCORE"):
        out[numeric_col] = pd.to_numeric(out[numeric_col], errors="coerce")

    return _force_string_identifiers(out)


def _summarize_crashes(rows: pd.DataFrame, keys: list) -> pd.DataFrame:
    """Aggregate crash counts/rates per `keys`, carrying the first non-null carrier names."""
    summary = (
        rows.groupby(keys, dropna=False)
            .agg(
                LEGAL_NAME             = ("LEGAL_NAME", "first"),
                DBA_NAME               = ("DBA_NAME", "first"),
                total_crashes          = ("MATCHED_FLAG", "sum"),
                total_at_fault_crashes = ("AT_FAULT_FLAG", "sum"),
                # AT_FAULT_FLAG is NaN on census-only rows, so the mean runs over
                # matched rows only (= at_fault / total) and is NaN with no crashes.
                pct_at_fault           = ("AT_FAULT_FLAG", "mean"),
                mean_fault_score       = ("FAULT_SCORE", "mean"),
            )
            .reset_index()
    )
    summary["total_at_fault_crashes"] = summary["total_at_fault_crashes"].fillna(0).astype(int)
    return summary


def _source_split(per_source: pd.DataFrame, source: str, prefix: str) -> pd.DataFrame:
    """Select one SOURCE's per-USDOT counts from `per_source`, renamed to `<prefix>_*` columns."""
    return (
        per_source.loc[per_source["SOURCE"] == source,
                       ["DOT_NUMBER", "total_crashes", "total_at_fault_crashes", "pct_at_fault"]]
        .rename(columns={
            "total_crashes": f"{prefix}_total",
            "total_at_fault_crashes": f"{prefix}_at_fault",
            "pct_at_fault": f"{prefix}_pct_at_fault",
        })
    )


def build_company_summary(
    fars_merged_path: Union[str, Path] = "census_with_fars.parquet",
    crss_merged_path: Union[str, Path] = "census_with_crss.parquet",
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Build company-level summaries including FARS/CRSS splits.
    Keeps all census carriers (even with 0 crashes).

    Both inputs are standardized and stacked; rows without a usable DOT_NUMBER
    are dropped. A row is a matched crash when SOURCE is non-null; census-only
    rows (SOURCE null) contribute 0 crashes and are ignored by the at-fault
    rates. A matched row whose LIKELY_AT_FAULT is not 1 (including NaN) counts
    as not at fault.

    Aggregation is keyed on USDOT (plus SOURCE for `per_source`) only; carrier
    names are carried as the first non-null value per key. Grouping on the
    names as well would split a USDOT whose LEGAL_NAME/DBA_NAME differs between
    the two inputs into several rows, and the DOT_NUMBER merges of the source
    splits would then fan out and duplicate rows.

    Parameters
    ----------
    fars_merged_path, crss_merged_path : str or Path
        Parquet files of census rows merged with FARS / CRSS crash rows.

    Returns
    -------
    per_source : one row per USDOT × SOURCE (SOURCE null = census-only rows)
    overall    : one row per USDOT with FARS/CRSS splits, sorted by at-fault
                 crashes, then total crashes, then pct at fault (descending)
    combined_rows : ALL rows (matched + census-only) for downstream exports,
                 with MATCHED_FLAG and AT_FAULT_FLAG added

    Raises
    ------
    FileNotFoundError
        If either input file is missing.
    ValueError
        If a required column is absent (raised by `standardize_census_fields`).
    """
    fars = standardize_census_fields(_read_parquet_required(fars_merged_path))
    crss = standardize_census_fields(_read_parquet_required(crss_merged_path))

    both = pd.concat([fars, crss], ignore_index=True)
    both = both[both["DOT_NUMBER"].notna()].copy()

    # Identify matched vs census-only rows
    both["MATCHED_FLAG"] = both["SOURCE"].notna().astype(int)

    # Fault flag only applies to matched rows; NaN for census-only so sums/means ignore them
    both["AT_FAULT_FLAG"] = np.where(
        both["MATCHED_FLAG"] == 1,
        (both["LIKELY_AT_FAULT"] == 1).astype("float"),
        np.nan,
    )

    # Per-source summary: exactly one row per USDOT × SOURCE
    per_source = _summarize_crashes(both, ["DOT_NUMBER", "SOURCE"])[
        ["DOT_NUMBER", "LEGAL_NAME", "DBA_NAME", "SOURCE",
         "total_crashes", "total_at_fault_crashes", "pct_at_fault", "mean_fault_score"]
    ]

    # Overall summary across sources: exactly one row per USDOT
    overall = _summarize_crashes(both, ["DOT_NUMBER"])

    # Attach FARS / CRSS splits; both sides are unique per USDOT, so the merges
    # cannot duplicate rows (validate makes that an enforced invariant).
    overall = (
        overall
        .merge(_source_split(per_source, "FARS", "fars"), on="DOT_NUMBER", how="left", validate="one_to_one")
        .merge(_source_split(per_source, "CRSS", "crss"), on="DOT_NUMBER", how="left", validate="one_to_one")
    )

    # Carriers absent from a source get 0 counts; their split pct stays NaN
    for col in ["fars_total", "fars_at_fault", "crss_total", "crss_at_fault"]:
        overall[col] = overall[col].fillna(0).astype(int)

    column_order = [
        "DOT_NUMBER", "LEGAL_NAME", "DBA_NAME",
        "total_crashes", "total_at_fault_crashes", "pct_at_fault",
        "mean_fault_score",
        "fars_total", "fars_at_fault", "fars_pct_at_fault",
        "crss_total", "crss_at_fault", "crss_pct_at_fault",
    ]
    remaining = [c for c in overall.columns if c not in column_order]
    overall = overall[column_order + remaining]

    # Sort for readability
    overall = overall.sort_values(
        ["total_at_fault_crashes", "total_crashes", "pct_at_fault"],
        ascending=[False, False, False],
    ).reset_index(drop=True)

    return per_source, overall, both

# Merge with full census data
def make_full_census_with_metrics(
    combined_rows: pd.DataFrame,
    overall: pd.DataFrame
) -> pd.DataFrame:
    """
    Create a full census-row dataset and attach per-USDOT metrics from `overall`.

    The crash-scoring columns (SOURCE, LIKELY_AT_FAULT, FAULT_SCORE,
    MATCHED_FLAG, AT_FAULT_FLAG) are dropped from `combined_rows` and fully
    identical remaining rows are collapsed to one. Every remaining row receives
    the metrics for its USDOT. USDOTs absent from `overall` get 0 for the count
    columns and NaN for the pct/mean columns.

    NOTE(review): any crash-level columns other than the five above that the
    merged inputs carry are kept, so a carrier with several matched crashes may
    still appear on several rows — confirm the input schema if exactly one row
    per census record is required.

    Raises
    ------
    pandas.errors.MergeError
        If `overall` has more than one row for a DOT_NUMBER. Without this check
        the left merge would silently multiply census rows.
    """
    crash_cols = {"SOURCE", "LIKELY_AT_FAULT", "FAULT_SCORE", "MATCHED_FLAG", "AT_FAULT_FLAG"}
    census_cols = [c for c in combined_rows.columns if c not in crash_cols]
    census_backbone = combined_rows[census_cols].drop_duplicates()

    count_cols = [
        "total_crashes", "total_at_fault_crashes",
        "fars_total", "fars_at_fault",
        "crss_total", "crss_at_fault",
    ]
    metric_cols = [
        "DOT_NUMBER",
        "total_crashes", "total_at_fault_crashes", "pct_at_fault", "mean_fault_score",
        "fars_total", "fars_at_fault", "fars_pct_at_fault",
        "crss_total", "crss_at_fault", "crss_pct_at_fault",
    ]
    metrics = overall[metric_cols]

    out = census_backbone.merge(metrics, on="DOT_NUMBER", how="left", validate="many_to_one")

    # Fill integer counts; leave pct/mean columns as NaN where there are no crashes
    for col in count_cols:
        out[col] = out[col].fillna(0).astype(int)

    return out

# Export summary
def export_company_summary(
    per_source: pd.DataFrame,
    overall: pd.DataFrame,
    combined_rows: Optional[pd.DataFrame] = None,
    out_dir: Union[str, Path] = "outputs",
    stem: str = "fars_crss_2020_2023",
    also_csv: bool = False,
    export_full_census_with_metrics: bool = False,
    full_census_path: Optional[Union[str, Path]] = None,
) -> None:
    """
    Write summary tables to Parquet (and optional CSV).

    Files written under `out_dir` (created if needed):
      <stem>_per_source.parquet, <stem>_overall.parquet and, when
      `combined_rows` is given, <stem>_combined_rows.parquet.
    If export_full_census_with_metrics=True, also writes a full census file
    with per-USDOT metrics attached, to `full_census_path` if given (its parent
    directory is created) or to <stem>_census_with_metrics.parquet in `out_dir`.
    If also_csv=True, a CSV copy of every written table is saved next to its
    Parquet file with the same name and a .csv suffix.
    Identifier-like columns are cast to string before writing.

    Raises
    ------
    ValueError
        If export_full_census_with_metrics=True but `combined_rows` is None.
        Checked before anything is written, so no partial output is left behind.
    """
    if export_full_census_with_metrics and combined_rows is None:
        raise ValueError("export_full_census_with_metrics=True requires `combined_rows`.")

    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)

    # Ensure IDs as strings
    per_source = _force_string_identifiers(per_source)
    overall    = _force_string_identifiers(overall)
    if combined_rows is not None:
        combined_rows = _force_string_identifiers(combined_rows)

    # Paths
    ps_path = out / f"{stem}_per_source.parquet"
    ov_path = out / f"{stem}_overall.parquet"
    cr_path = out / f"{stem}_combined_rows.parquet"
    fc_path = Path(full_census_path) if full_census_path else (out / f"{stem}_census_with_metrics.parquet")

    # Write Parquet
    per_source.to_parquet(ps_path, index=False)
    overall.to_parquet(ov_path, index=False)
    if combined_rows is not None:
        combined_rows.to_parquet(cr_path, index=False)

    # Full census with FARS/CRSS at-fault metrics merged
    full_census = None
    if export_full_census_with_metrics:
        full_census = _force_string_identifiers(make_full_census_with_metrics(combined_rows, overall))
        fc_path.parent.mkdir(parents=True, exist_ok=True)
        full_census.to_parquet(fc_path, index=False)

    # Optional CSV copies, written next to their Parquet counterparts
    if also_csv:
        per_source.to_csv(ps_path.with_suffix(".csv"), index=False)
        overall.to_csv(ov_path.with_suffix(".csv"), index=False)
        if combined_rows is not None:
            combined_rows.to_csv(cr_path.with_suffix(".csv"), index=False)
        if full_census is not None:
            full_census.to_csv(fc_path.with_suffix(".csv"), index=False)

    # Confirm
    print("\nSaved:")
    print(" ", ps_path)
    print(" ", ov_path)
    if combined_rows is not None:
        print(" ", cr_path)
    if export_full_census_with_metrics:
        print(" ", fc_path)
    if also_csv:
        print(" (CSV copies saved as well)")


In [51]:
# Build the per-source / overall summaries from the merged census inputs,
# write them under outputs/, then preview the 25 highest-ranked carriers.
per_source, overall, combined_rows = build_company_summary(
    fars_merged_path="fars_census_merged.parquet",
    crss_merged_path="crss_census_merged.parquet",
)
export_company_summary(
    per_source,
    overall,
    combined_rows,
    out_dir="outputs",
    stem="fars_crss_2020_2023",
    also_csv=False,
)

# Preview: `overall` is sorted by at-fault crashes, then total crashes, then pct at fault
display(overall.head(25))


Saved:
  outputs/fars_crss_2020_2023_per_source.parquet
  outputs/fars_crss_2020_2023_overall.parquet
  outputs/fars_crss_2020_2023_combined_rows.parquet


Unnamed: 0,DOT_NUMBER,LEGAL_NAME,DBA_NAME,total_crashes,total_at_fault_crashes,pct_at_fault,mean_fault_score,fars_total,fars_at_fault,fars_pct_at_fault,crss_total,crss_at_fault,crss_pct_at_fault
0,21800,UNITED PARCEL SERVICE INC,UPS,161,86,0.534161,1.52795,109,57,0.522936,52,29,0.557692
1,327574,PENSKE TRUCK LEASING CO LP,PENSKE TRUCK RENTAL,106,62,0.584906,1.745283,57,31,0.54386,49,31,0.632653
2,80806,J B HUNT TRANSPORT INC,J B HUNT,121,54,0.446281,0.975207,73,33,0.452055,48,21,0.4375
3,54283,SWIFT TRANSPORTATION CO OF ARIZONA LLC,,92,49,0.532609,1.402174,62,30,0.483871,30,19,0.633333
4,53467,WERNER ENTERPRISES INC,,62,31,0.5,1.145161,33,14,0.424242,29,17,0.586207
5,90849,OLD DOMINION FREIGHT LINE INC,,74,30,0.405405,0.797297,47,18,0.382979,27,12,0.444444
6,303024,US XPRESS INC,US XPRESS,56,29,0.517857,1.214286,29,13,0.448276,27,16,0.592593
7,63585,WAL-MART TRANSPORTATION LLC,,55,26,0.472727,0.672727,38,18,0.473684,17,8,0.470588
8,3706,NEW PRIME INC,PRIME INC,52,26,0.5,1.403846,24,9,0.375,28,17,0.607143
9,264184,SCHNEIDER NATIONAL CARRIERS INC,,44,24,0.545455,1.704545,24,12,0.5,20,12,0.6
