In [1]:
# === Imports & basic setup ===

from pathlib import Path
import polars as pl
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score, mean_absolute_error

plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['axes.grid'] = True

sns.set(style="whitegrid")


In [2]:
# === Data paths ===
# Nếu chạy trên môi trường khác, chỉnh lại DATA_DIR cho phù hợp.

PRICING_PATH   = Path(r"X:\Programming\Python\Projects\Data processing\TLC NYC datasets\HVFHV subsets 2019-2025 - Aggregates\Aggregates_Processed\agg_pricing_distribution.parquet")
TIMELINE_PATH  = Path(r"X:\Programming\Python\Projects\Data processing\TLC NYC datasets\HVFHV subsets 2019-2025 - Aggregates\Aggregates_Processed\agg_timeline_hourly.parquet")
DAILY_PATH     = Path(r"X:\Programming\Python\Projects\Data processing\TLC NYC datasets\HVFHV subsets 2019-2025 - Aggregates\Aggregates_Processed\agg_executive_daily.csv")
NETWORK_PATH   = Path(r"X:\Programming\Python\Projects\Data processing\TLC NYC datasets\HVFHV subsets 2019-2025 - Aggregates\Aggregates_Processed\agg_network_monthly.parquet")
path_2019 = Path(r"X:\Programming\Python\Projects\Data processing\TLC NYC datasets\HVFHV subsets 2019-2025 - Samples\tlc_sample_2019_processed.parquet")
path_2020 = Path(r"X:\Programming\Python\Projects\Data processing\TLC NYC datasets\HVFHV subsets 2019-2025 - Samples\tlc_sample_2020_processed.parquet")
path_2021 = Path(r"X:\Programming\Python\Projects\Data processing\TLC NYC datasets\HVFHV subsets 2019-2025 - Samples\tlc_sample_2021_processed.parquet")
path_2022 = Path(r"X:\Programming\Python\Projects\Data processing\TLC NYC datasets\HVFHV subsets 2019-2025 - Samples\tlc_sample_2022_processed.parquet")
path_2023 = Path(r"X:\Programming\Python\Projects\Data processing\TLC NYC datasets\HVFHV subsets 2019-2025 - Samples\tlc_sample_2023_processed.parquet")
path_2024 = Path(r"X:\Programming\Python\Projects\Data processing\TLC NYC datasets\HVFHV subsets 2019-2025 - Samples\tlc_sample_2024_processed.parquet")
path_2025 = Path(r"X:\Programming\Python\Projects\Data processing\TLC NYC datasets\HVFHV subsets 2019-2025 - Samples\tlc_sample_2025_processed.parquet")

for p in [PRICING_PATH, TIMELINE_PATH, DAILY_PATH, NETWORK_PATH, path_2019, path_2020, path_2021, path_2022, path_2023, path_2024, path_2025]:
    print(p.name, 'exists:', p.exists())

agg_pricing_distribution.parquet exists: True
agg_timeline_hourly.parquet exists: True
agg_executive_daily.csv exists: True
agg_network_monthly.parquet exists: True
tlc_sample_2019_processed.parquet exists: True
tlc_sample_2020_processed.parquet exists: True
tlc_sample_2021_processed.parquet exists: True
tlc_sample_2022_processed.parquet exists: True
tlc_sample_2023_processed.parquet exists: True
tlc_sample_2024_processed.parquet exists: True
tlc_sample_2025_processed.parquet exists: True


In [9]:
# ===== Plotly figure caching helper =====
import plotly.io as pio
from pathlib import Path

PLOTLY_SAVE_DIR = Path(r"X:\Programming\Python\[Y3S1] Year 3, Autumn semester\[Y3S1] Data preparation and Visualisation\Projects\Final term (hck)\TLC NYC filtered\V4 - Finalize\vu")
PLOTLY_SAVE_DIR.mkdir(exist_ok=True)

def save_or_load_plotly(fig_name, fig_builder, scale=8, height=600, width=1000):
    """
    fig_name: file base name (no extension)
    fig_builder: function that builds and returns the figure
    """
    json_path = PLOTLY_SAVE_DIR / f"{fig_name}.json"
    html_path = PLOTLY_SAVE_DIR / f"{fig_name}.html"
    image_path = PLOTLY_SAVE_DIR / f"{fig_name}.png"

    # If cached → load JSON (do not show)
    if json_path.exists():
        fig = pio.read_json(str(json_path))
        return fig
    
    # Else → build, save JSON + HTML
    fig = fig_builder()
    pio.write_json(fig, str(json_path))
    pio.write_html(fig, str(html_path))
    pio.write_image(fig, str(image_path), scale=scale, height=height, width=width)
    return fig


In [4]:
sample_2019 = pl.read_parquet(path_2019)
sample_2020 = pl.read_parquet(path_2020)
sample_2021 = pl.read_parquet(path_2021)
sample_2022 = pl.read_parquet(path_2022)
sample_2023 = pl.read_parquet(path_2023)
sample_2024 = pl.read_parquet(path_2024)
sample_2025 = pl.read_parquet(path_2025)

samples_raw = {
    2019: sample_2019,
    2020: sample_2020,
    2021: sample_2021,
    2022: sample_2022,
    2023: sample_2023,
    2024: sample_2024,
    2025: sample_2025,
}

print("=== Sample shapes by year ===")
for year, df in samples_raw.items():
    print(f"sample_{year}: {df.shape}")

print("\n=== Schema: sample_2024 (reference schema – all years share this) ===")
print(sample_2024.schema)

# --- Aggregate marts (if present) ---

pricing = timeline = daily = network = None

try:
    pricing = pl.read_parquet(PRICING_PATH)
    print("\n=== Schema: agg_pricing_distribution (Mart 3) ===")
    print(pricing.schema)
except Exception as e:
    print("\n[WARN] Could not load agg_pricing_distribution:", e)

try:
    timeline = pl.read_parquet(TIMELINE_PATH)
    print("\n=== Schema: agg_timeline_hourly (Mart 1) ===")
    print(timeline.schema)
except Exception as e:
    print("\n[WARN] Could not load agg_timeline_hourly:", e)

try:
    daily = pl.read_csv(DAILY_PATH)
    print("\n=== Schema: agg_executive_daily (Mart 4) ===")
    print(daily.schema)
except Exception as e:
    print("\n[WARN] Could not load agg_executive_daily:", e)

try:
    network = pl.read_parquet(NETWORK_PATH)
    print("\n=== Schema: agg_network_monthly (Mart 2) ===")
    print(network.schema)
except Exception as e:
    print("\n[WARN] Could not load agg_network_monthly:", e)

=== Sample shapes by year ===
sample_2019: (1566175, 70)
sample_2020: (999861, 70)
sample_2021: (1214300, 70)
sample_2022: (1473050, 70)
sample_2023: (1608613, 70)
sample_2024: (1725560, 70)
sample_2025: (1242682, 70)

=== Schema: sample_2024 (reference schema – all years share this) ===
Schema({'pickup_datetime': Datetime(time_unit='us', time_zone=None), 'dropoff_datetime': Datetime(time_unit='us', time_zone=None), 'PULocationID': Int32, 'DOLocationID': Int32, 'base_passenger_fare': Float32, 'tolls': Float32, 'bcf': Float32, 'sales_tax': Float32, 'congestion_surcharge': Float32, 'airport_fee': Float32, 'tips': Float32, 'driver_pay': Float32, 'shared_request_flag': UInt8, 'shared_match_flag': UInt8, 'access_a_ride_flag': UInt8, 'wav_request_flag': UInt8, 'wav_match_flag': UInt8, 'cbd_congestion_fee': Float32, 'pickup_borough': Categorical, 'pickup_zone': String, 'dropoff_borough': Categorical, 'dropoff_zone': String, 'trip_km': Float32, 'duration_seconds': Int64, 'pickup_hour': Int8, '

In [5]:
# === Plotly imports ===
import plotly.express as px
import plotly.graph_objects as go

# === Convert aggregate marts to pandas ===

pricing_dist = pricing.to_pandas() if pricing is not None else None
timeline_hourly = timeline.to_pandas() if timeline is not None else None
exec_daily = daily.to_pandas() if daily is not None else None
network_monthly = network.to_pandas() if network is not None else None

print("\n=== Converted to pandas (for Plotly) ===")
for name, df in [
    ("pricing_dist", pricing_dist),
    ("timeline_hourly", timeline_hourly),
    ("exec_daily", exec_daily),
    ("network_monthly", network_monthly),
]:
    print(f"{name}: {None if df is None else df.shape}")

# --- Fix date / time columns where needed ---

# agg_pricing_distribution: pickup_date is Date → to datetime, then derive year/month
if pricing_dist is not None:
    pricing_dist["pickup_date"] = pd.to_datetime(pricing_dist["pickup_date"])
    pricing_dist["pickup_year"] = pricing_dist["pickup_date"].dt.year
    pricing_dist["pickup_month"] = pricing_dist["pickup_date"].dt.month

# agg_executive_daily: pickup_date is String → to datetime
if exec_daily is not None:
    exec_daily["pickup_date"] = pd.to_datetime(exec_daily["pickup_date"])

# agg_timeline_hourly already has pickup_year, pickup_month, pickup_day, pickup_hour

# === Trip-level samples combined (tlc_all) ===

# samples_raw is your dict {year: polars_df}
tlc_all = pl.concat(list(samples_raw.values())).to_pandas()
print(f"\nCombined tlc_all shape: {tlc_all.shape}")

# tlc_all already has:
# pickup_datetime, pickup_year, base_passenger_fare, cost_per_km, driver_revenue_share,
# uber_take_rate_proxy, pay_per_hour, tipping_pct, weather_state, time_of_day_bin, trip_archetype, etc.

# Ensure pickup_datetime is datetime (pandas)
tlc_all["pickup_datetime"] = pd.to_datetime(tlc_all["pickup_datetime"])
tlc_all["pickup_date"] = tlc_all["pickup_datetime"].dt.date
tlc_all["pickup_date"] = pd.to_datetime(tlc_all["pickup_date"])
tlc_all["pickup_month_dt"] = tlc_all["pickup_datetime"].dt.to_period("M").dt.to_timestamp()



=== Converted to pandas (for Plotly) ===
pricing_dist: (52862, 13)
timeline_hourly: (408173, 17)
exec_daily: (2433, 9)
network_monthly: (4567992, 12)

Combined tlc_all shape: (9830241, 70)


In [17]:
# === PART 1 – Time-of-day Pricing Structure (Uber-Styled Heatmap with WCAG Contrast, 2021–2025 Only) ===

def _build_fig_t1():

    import plotly.graph_objects as go
    import plotly.io as pio
    import uber_style as ub
    import numpy as np
    from matplotlib.colors import LinearSegmentedColormap, to_hex

    # -------------------------------------------------------------------
    # Activate Uber template
    # -------------------------------------------------------------------
    pio.templates["uber"] = ub.uber_style_template
    pio.templates.default = "uber"


    # -------------------------------------------------------------------
    # WCAG Contrast Functions
    # -------------------------------------------------------------------
    def compute_luminance(hex_color):
        hex_color = hex_color.lstrip('#')
        r = int(hex_color[0:2], 16) / 255
        g = int(hex_color[2:4], 16) / 255
        b = int(hex_color[4:6], 16) / 255
        return 0.299*r + 0.587*g + 0.114*b

    def build_color_map(colorscale):
        stops = [stop for stop, _ in colorscale]
        colors = [color for _, color in colorscale]
        return LinearSegmentedColormap.from_list("uber_custom", list(zip(stops, colors)))

    def get_text_contrast_colors(z_values, colorscale, threshold=0.55):
        cmap = build_color_map(colorscale)
        z_min, z_max = np.nanmin(z_values), np.nanmax(z_values)
        z_norm = (z_values - z_min) / (z_max - z_min + 1e-9)

        text_colors = []
        for row in z_norm:
            row_colors = []
            for v in row:
                rgb = cmap(v)
                hex_color = to_hex(rgb)
                lum = compute_luminance(hex_color)
                row_colors.append("black" if lum > threshold else "white")
            text_colors.append(row_colors)

        return text_colors


    # -------------------------------------------------------------------
    # DATA PROCESSING (FILTER 2021–2025)
    # -------------------------------------------------------------------
    filtered = pricing_dist[pricing_dist["pickup_year"].between(2021, 2025)]

    time_hier = (
        filtered
        .groupby(["pickup_year", "time_of_day_bin"], as_index=False)
        .agg(median_fare=("median_fare", "median"))
    )

    # Enforce timeline order for time_of_day_bin
    time_order = ["late_night", "morning_rush", "midday", "evening", "evening_rush"]

    time_hier["time_of_day_bin"] = pd.Categorical(
        time_hier["time_of_day_bin"],
        categories=[t for t in time_order if t in time_hier["time_of_day_bin"].unique()],
        ordered=True
    )
    time_hier = time_hier.sort_values(["time_of_day_bin", "pickup_year"])

    # Pivot table
    heat_df = time_hier.pivot(
        index="time_of_day_bin",
        columns="pickup_year",
        values="median_fare"
    )

    z = heat_df.to_numpy(dtype=float)
    z_text = np.round(z, 2)

    colorscale = ub.uber_style_template["data"]["heatmap"][0]["colorscale"]
    text_colors = get_text_contrast_colors(z, colorscale)


    # -------------------------------------------------------------------
    # HEATMAP
    # -------------------------------------------------------------------
    fig_t1 = go.Figure(
        data=go.Heatmap(
            z=z,
            x=heat_df.columns.astype(str),
            y=heat_df.index,
            colorscale=colorscale,
            xgap=1,
            ygap=1,
            hovertemplate="<b>%{y} × %{x}</b><br>Median Fare: %{z}<extra></extra>",
            colorbar=dict(
                thickness=14,
                tickfont=dict(size=12, color="#333333"),
                title="Median Base Fare ($)"
            )
        )
    )


    # -------------------------------------------------------------------
    # Add text annotations
    # -------------------------------------------------------------------
    annotations = []
    for i, y_val in enumerate(heat_df.index):
        for j, x_val in enumerate(heat_df.columns.astype(str)):
            annotations.append(
                dict(
                    x=x_val,
                    y=y_val,
                    text=str(z_text[i, j]),
                    showarrow=False,
                    font=dict(size=12, color=text_colors[i][j]),
                    xanchor="center",
                    yanchor="middle"
                )
            )

    fig_t1.update_layout(annotations=annotations)


    # -------------------------------------------------------------------
    # Layout
    # -------------------------------------------------------------------
    fig_t1.update_layout(
        xaxis=dict(
            title=dict(
                    text="Year",
                    font=dict(size=16, color="#141414")
            ),
            side="bottom",
            ticklabelposition="outside"
        ),
        yaxis=dict(
            title=None,
            ticklabelposition="outside",
            autorange="reversed"
        ),
        height=520,
        margin=dict(t=110, l=160, r=40, b=90)
    )

    # Manual y-label
    fig_t1.add_annotation(
        text="Time-of-Day Segment",
        xref="paper",
        yref="paper",
        x=-0.18,
        y=0.5,
        showarrow=False,
        textangle=270,
        font=dict(size=16, color="#141414"),
        xanchor="center",
        yanchor="middle"
    )


    # -------------------------------------------------------------------
    # Apply Uber Branding
    # -------------------------------------------------------------------
    fig_t1 = ub.apply_uber_branding(
        fig_t1,
        title="Median Base Fare by Time-of-Day Segment × Year (2021–2025)",
        subtitle="<b>Key Insight: </b>Rush-hour segments remain the most expensive across all post-pandemic years",
        source="NYC TLC High Volume FHV (2021–2025)",
        footer_y=-0.22,
        logo_y=-0.23
    )

    return fig_t1


# ============================================================
# === NEW PART: SAVE OR LOAD FIGURE (JSON + HTML)
# ============================================================

fig_t1 = save_or_load_plotly(
    fig_name="t1_time_of_day_pricing_structure",
    fig_builder=_build_fig_t1,
    height=600,
    width=1000,
)

fig_t1.show()   # uncomment to display



In [11]:
# === A. DISTRIBUTION OF TRIP LENGTH BY TIME-OF-DAY ===

distance_summary = (
    tlc_all
    .groupby("time_of_day_bin", as_index=False)
    .agg(
        avg_trip_km    = ("trip_km", "mean"),
        median_trip_km = ("trip_km", "median"),
        p90_trip_km    = ("trip_km", lambda s: s.quantile(0.90)),
        trips          = ("trip_km", "size")
    )
    .sort_values("avg_trip_km")
)

distance_summary


Unnamed: 0,time_of_day_bin,avg_trip_km,median_trip_km,p90_trip_km,trips
1,evening_rush,6.546682,4.16819,14.966862,2281512
0,evening,6.952157,4.570525,15.578411,1576055
3,midday,7.040067,4.345218,16.527922,2720772
4,morning_rush,7.416492,4.618805,17.396965,1585048
2,late_night,8.063046,5.391289,18.282101,1666854


In [12]:
# === B. DISTRIBUTION OF COST PER KM BY TIME-OF-DAY ===

cpkm_summary = (
    tlc_all
    .groupby("time_of_day_bin", as_index=False)
    .agg(
        avg_cost_per_km    = ("cost_per_km", "mean"),
        median_cost_per_km = ("cost_per_km", "median"),
        p90_cost_per_km    = ("cost_per_km", lambda s: s.quantile(0.90)),
        trips              = ("cost_per_km", "size")
    )
    .sort_values("median_cost_per_km", ascending=False)
)

cpkm_summary


Unnamed: 0,time_of_day_bin,avg_cost_per_km,median_cost_per_km,p90_cost_per_km,trips
1,evening_rush,5.541411,4.49203,9.604839,2281512
3,midday,5.302975,4.371397,9.102817,2720772
4,morning_rush,4.94559,4.111065,8.44268,1585048
0,evening,5.022946,4.049812,8.769184,1576055
2,late_night,4.596572,3.654094,8.021787,1666854


In [13]:
# === C. EXPECTED FARE BASED ON GLOBAL AVERAGE COST/KM ===

# Cost/km trung bình toàn thành phố
global_cpkm = tlc_all["cost_per_km"].mean()

# Nếu giá 100% theo distance:
tlc_all["expected_fare_from_km"] = tlc_all["trip_km"] * global_cpkm

compare_expected = (
    tlc_all
    .groupby("time_of_day_bin", as_index=False)
    .agg(
        actual_fare   = ("base_passenger_fare", "median"),
        expected_fare = ("expected_fare_from_km", "median")
    )
)

compare_expected["pricing_gap"] = compare_expected["actual_fare"] - compare_expected["expected_fare"]

compare_expected


Unnamed: 0,time_of_day_bin,actual_fare,expected_fare,pricing_gap
0,evening,16.059999,23.47427,-7.41427
1,evening_rush,16.440001,21.407875,-4.967875
2,late_night,16.940001,27.689722,-10.749722
3,midday,16.42,22.317091,-5.897091
4,morning_rush,16.43,23.722239,-7.292238


In [14]:
time_order = ["late_night", "morning_rush", "midday", "evening", "evening_rush"]

# Create weather_seg (if not already created)
weather_seg = (
    pricing_dist
    .groupby(["weather_state", "time_of_day_bin"], as_index=False)
    .agg(
        median_fare=("median_fare", "median"),
        p90_fare=("p90_fare_surge_proxy", "median")
    )
)
weather_seg["time_of_day_bin"] = pd.Categorical(
    weather_seg["time_of_day_bin"],
    categories=[t for t in time_order if t in weather_seg["time_of_day_bin"].unique()],
    ordered=True
)
weather_seg = weather_seg.sort_values(["weather_state", "time_of_day_bin"])
# ←←←←←←←←←←←←←←←←←←←←←←←←←←←←←←←←←←←←←←←←←←←←←←←←←←←←←←←←←←←

In [15]:
# === PART 2 – Heatmap 1: Median Fare by Weather × Time-of-Day (Uber HD BI Style, Improved) ===

def _build_fig_w1():

    import plotly.graph_objects as go
    import plotly.io as pio
    import uber_style as ub
    import numpy as np
    from matplotlib.colors import LinearSegmentedColormap, to_hex

    pio.templates["uber"] = ub.uber_style_template
    pio.templates.default = "uber"


    # -------------------------------------------------------------------
    # WCAG Text Contrast Functions
    # -------------------------------------------------------------------
    def compute_luminance(hex_color):
        hex_color = hex_color.lstrip('#')
        r = int(hex_color[0:2], 16) / 255
        g = int(hex_color[2:4], 16) / 255
        b = int(hex_color[4:6], 16) / 255
        return 0.299*r + 0.587*g + 0.114*b

    def build_color_map(colorscale):
        stops = [stop for stop, _ in colorscale]
        colors = [color for _, color in colorscale]
        return LinearSegmentedColormap.from_list("uber_custom", list(zip(stops, colors)))

    def get_text_contrast_colors(z_values, colorscale, threshold=0.55):
        cmap = build_color_map(colorscale)
        z_min, z_max = np.nanmin(z_values), np.nanmax(z_values)
        z_norm = (z_values - z_min) / (z_max - z_min + 1e-9)

        text_colors = []
        for row in z_norm:
            row_colors = []
            for v in row:
                rgb = cmap(v)
                hex_color = to_hex(rgb)
                lum = compute_luminance(hex_color)
                row_colors.append("black" if lum > threshold else "white")
            text_colors.append(row_colors)

        return text_colors


    # -------------------------------------------------------------------
    # Set desired weather state order (UX-friendly)
    # -------------------------------------------------------------------
    weather_order = ["clear_cloudy", "snowing", "snow_on_ground", "raining"]

    weather_seg["weather_state"] = pd.Categorical(
        weather_seg["weather_state"],
        categories=[w for w in weather_order if w in weather_seg["weather_state"].unique()],
        ordered=True
    )
    weather_seg_sorted = weather_seg.sort_values(["weather_state", "time_of_day_bin"])  


    # -------------------------------------------------------------------
    # Pivot table for median fare
    # -------------------------------------------------------------------
    heat_med = weather_seg_sorted.pivot(
        index="weather_state",
        columns="time_of_day_bin",
        values="median_fare"
    )

    z = heat_med.to_numpy(dtype=float)
    z_text = np.round(z, 2)

    colorscale = ub.uber_style_template["data"]["heatmap"][0]["colorscale"]
    text_colors = get_text_contrast_colors(z, colorscale)


    # -------------------------------------------------------------------
    # Base Heatmap
    # -------------------------------------------------------------------
    fig_w1 = go.Figure(
        data=go.Heatmap(
            z=z,
            x=heat_med.columns.astype(str),
            y=heat_med.index,
            colorscale=colorscale,
            xgap=1,
            ygap=1,
            hovertemplate="<b>%{y} × %{x}</b><br>Median Fare: %{z}<extra></extra>",
            colorbar=dict(
                thickness=14,
                tickfont=dict(size=12, color="#333333"),
                title="Median Base Fare ($)"
            ),
        )
    )


    # -------------------------------------------------------------------
    # Add text labels
    # -------------------------------------------------------------------
    annotations = []
    for i, y_val in enumerate(heat_med.index):
        for j, x_val in enumerate(heat_med.columns.astype(str)):
            annotations.append(
                dict(
                    x=x_val,
                    y=y_val,
                    text=str(z_text[i, j]),
                    showarrow=False,
                    font=dict(size=12, color=text_colors[i][j]),
                    xanchor="center",
                    yanchor="middle"
                )
            )
    fig_w1.update_layout(annotations=annotations)


    # -------------------------------------------------------------------
    # Layout fixes (y-title no truncation)
    # -------------------------------------------------------------------
    fig_w1.update_layout(
        xaxis=dict(
            title=dict(
                text="Time-of-Day Segment",
                font=dict(size=16, color="#141414")
            ),
            ticklabelposition="outside"
        ),
        yaxis=dict(
            title=None,
            ticklabelposition="outside",
            autorange="reversed",
            tickangle=0
        ),
        height=520,
        margin=dict(t=110, l=210, r=40, b=90)
    )

    # Manual y-axis title (move further left)
    fig_w1.add_annotation(
        text="Weather State",
        xref="paper",
        yref="paper",
        x=-0.22,
        y=0.5,
        showarrow=False,
        font=dict(size=16, color="#141414"),
        textangle=270,
        xanchor="center",
        yanchor="middle"
    )


    # -------------------------------------------------------------------
    # Uber Branding
    # -------------------------------------------------------------------
    fig_w1 = ub.apply_uber_branding(
        fig_w1,
        title="Median Base Fare by Weather State × Time-of-Day",
        subtitle="<b>Key Insight: </b>Weather premiums emerge mainly during rush hours (morning_rush & evening_rush) segments",
        source="NYC TLC High Volume FHV (2019–2025)",
        footer_y=-0.22,
        logo_y=-0.23
    )

    return fig_w1



# ============================================================
# === SAVE OR LOAD FIGURE (NO SHOW BY DEFAULT)
# ============================================================

fig_w1 = save_or_load_plotly(
    fig_name="w1_weather_tod_median_fare",
    fig_builder=_build_fig_w1,
    height=600,
    width=1000,
)

fig_w1.show()   # uncomment when needed


In [16]:
# === PART 2 – Heatmap 2: P90 Fare by Weather × Time-of-Day (Uber HD BI Style) ===

def _build_fig_w2():

    import plotly.graph_objects as go
    import plotly.io as pio
    import uber_style as ub
    import numpy as np
    from matplotlib.colors import LinearSegmentedColormap, to_hex

    # -------------------------------------------------------------------
    # WCAG Contrast Functions
    # -------------------------------------------------------------------
    def compute_luminance(hex_color):
        hex_color = hex_color.lstrip('#')
        r = int(hex_color[0:2], 16) / 255
        g = int(hex_color[2:4], 16) / 255
        b = int(hex_color[4:6], 16) / 255
        return 0.299*r + 0.587*g + 0.114*b

    def build_color_map(colorscale):
        stops = [stop for stop, _ in colorscale]
        colors = [color for _, color in colorscale]
        return LinearSegmentedColormap.from_list("uber_custom", list(zip(stops, colors)))

    def get_text_contrast_colors(z_values, colorscale, threshold=0.55):
        cmap = build_color_map(colorscale)
        z_min, z_max = np.nanmin(z_values), np.nanmax(z_values)
        z_norm = (z_values - z_min) / (z_max - z_min + 1e-9)
        text_colors = []
        for row in z_norm:
            row_colors = []
            for v in row:
                rgb = cmap(v)
                hex_color = to_hex(rgb)
                lum = compute_luminance(hex_color)
                row_colors.append("black" if lum > threshold else "white")
            text_colors.append(row_colors)
        return text_colors


    # -------------------------------------------------------------------
    # Heatmap Data (P90 Fare)
    # -------------------------------------------------------------------
    heat_p90 = weather_seg.pivot(
        index="weather_state",
        columns="time_of_day_bin",
        values="p90_fare"
    )

    z = heat_p90.to_numpy(dtype=float)
    z_text = np.round(z, 2)

    colorscale = ub.uber_style_template["data"]["heatmap"][0]["colorscale"]
    text_colors = get_text_contrast_colors(z, colorscale)


    # -------------------------------------------------------------------
    # Base Heatmap
    # -------------------------------------------------------------------
    fig_w2 = go.Figure(
        data=go.Heatmap(
            z=z,
            x=heat_p90.columns.astype(str),
            y=heat_p90.index,
            colorscale=colorscale,
            xgap=1,
            ygap=1,
            hovertemplate="<b>%{y} × %{x}</b><br>P90 Fare: %{z}<extra></extra>",
            colorbar=dict(
                thickness=14,
                tickfont=dict(size=12, color="#333333"),
                title="P90 Fare ($)"
            )
        )
    )


    # -------------------------------------------------------------------
    # WCAG-safe annotations
    # -------------------------------------------------------------------
    annotations = []
    for i, y_val in enumerate(heat_p90.index):
        for j, x_val in enumerate(heat_p90.columns.astype(str)):
            annotations.append(
                dict(
                    x=x_val,
                    y=y_val,
                    text=str(z_text[i, j]),
                    showarrow=False,
                    font=dict(size=12, color=text_colors[i][j]),
                    xanchor="center",
                    yanchor="middle"
                )
            )
    fig_w2.update_layout(annotations=annotations)


    # -------------------------------------------------------------------
    # Layout
    # -------------------------------------------------------------------
    fig_w2.update_layout(
        xaxis=dict(
            title=dict(
                text="Time-of-Day Segment",
                font=dict(size=16, color="#141414")
            ),
            ticklabelposition="outside"
        ),
        yaxis=dict(
            title=None,
            ticklabelposition="outside",
            autorange="reversed",
            tickangle=0
        ),
        height=520,
        margin=dict(t=110, l=180, r=40, b=90)
    )

    # Manual vertical y-title
    fig_w2.add_annotation(
        text="Weather State",
        xref="paper",
        yref="paper",
        x=-0.18,
        y=0.5,
        showarrow=False,
        textangle=270,
        font=dict(size=16, color="#141414"),
        xanchor="center",
        yanchor="middle"
    )


    # -------------------------------------------------------------------
    # Uber Branding
    # -------------------------------------------------------------------
    fig_w2 = ub.apply_uber_branding(
        fig_w2,
        title="P90 Base Fare (Surge Intensity) by Weather × Time-of-Day",
        subtitle="<b>Key Insight: </b>Evening rush × Rain has the strongest surge behaviour",
        source="NYC TLC High Volume FHV (2019–2025)",
        footer_y=-0.22,
        logo_y=-0.23
    )

    return fig_w2



# ============================================================
# === SAVE OR LOAD CACHED VERSION (JSON + HTML)
# ============================================================

fig_w2 = save_or_load_plotly(
    fig_name="w2_weather_tod_p90_fare",
    fig_builder=_build_fig_w2,
    height=600,
    width=1000,
)

fig_w2.show()   # uncomment to display
