# Data Science Project: Fuel Price Prediction Model

## Data Preparation

In [None]:
import duckdb
from pathlib import Path

# Root of the local tankerkoenig-data checkout; every input and output path
# used by the data-preparation cells is derived from it.
BASE_DIR = Path(
    r"C:\Users\websi\OneDrive - UT Cloud\Semester\3. WS2025_26\DS500 Data Science Project (12 ECTS)\tankerkoenig_repo\tankerkoenig-data"
)

# DuckDB database file (created on first connect if it does not exist yet).
DB_PATH = BASE_DIR.joinpath("fuel_price_preparation.duckdb")

con = duckdb.connect(DB_PATH.as_posix())

# Session setup, applied in this order:
#   1. worker thread count (use multiple cores if available)
#   2. fixed RNG seed for reproducible random sampling
for _setup_stmt in (
    "PRAGMA threads=8;",
    "SELECT setseed(0.42)",
):
    con.execute(_setup_stmt)

print(f"Connected to DuckDB at: {DB_PATH}")

In [None]:
# Glob patterns for the daily CSV dumps, one per dataset and year.
# Layout: <BASE_DIR>/<dataset>/<year>/<any sub-folder>/<...>-<dataset>.csv
_prices_root = BASE_DIR / "prices"
_stations_root = BASE_DIR / "stations"

prices_2023_glob = _prices_root.joinpath("2023", "*", "*-prices.csv").as_posix()
prices_2024_glob = _prices_root.joinpath("2024", "*", "*-prices.csv").as_posix()

stations_2023_glob = _stations_root.joinpath("2023", "*", "*-stations.csv").as_posix()
stations_2024_glob = _stations_root.joinpath("2024", "*", "*-stations.csv").as_posix()

# Echo the patterns so a wrong BASE_DIR is spotted before the heavy loads run.
for _pattern in (
    prices_2023_glob,
    prices_2024_glob,
    stations_2023_glob,
    stations_2024_glob,
):
    print(_pattern)

In [None]:
# 2.1 Prices: read 2023 + 2024 into one table
# The table is dropped first so this cell can be re-run from scratch.
con.execute("DROP TABLE IF EXISTS prices_raw;")
# The two `?` placeholders are bound, in order, to the 2023 and 2024 glob
# patterns given in the parameter list below.
# NOTE(review): `SELECT * ... UNION ALL SELECT *` lines columns up by position;
# this presumably relies on both years' CSVs having the same column order — confirm.
con.execute(
    """
    CREATE TABLE prices_raw AS
    SELECT * FROM read_csv_auto(?, SAMPLE_SIZE=-1)
    UNION ALL
    SELECT * FROM read_csv_auto(?, SAMPLE_SIZE=-1);
    """,
    [prices_2023_glob, prices_2024_glob],
)

# Quick sanity check: total row count plus the smallest and largest `date` value.
con.execute("SELECT COUNT(*) AS n_rows, MIN(date), MAX(date) FROM prices_raw;").df()

In [None]:
# 2.2 Stations: read 2023 + 2024 into one table, also keep filename
# The table is dropped first so this cell can be re-run from scratch.
con.execute("DROP TABLE IF EXISTS stations_raw;")
# Both UNION ALL branches select the same explicit column list in the same
# order; the two `?` placeholders are bound to the 2023 and 2024 station globs.
# The `filename` column is kept because the snapshot step (3.1) parses the
# snapshot date out of it.
# NOTE(review): union_by_name=true presumably tolerates daily files whose
# column sets differ — confirm against the DuckDB CSV reader documentation.
con.execute(
    """
    CREATE TABLE stations_raw AS
    SELECT
        uuid,
        name,
        brand,
        street,
        house_number,
        city,
        latitude,
        longitude,
        first_active,
        openingtimes_json,
        filename
    FROM read_csv_auto(
            ?, 
            SAMPLE_SIZE=-1, 
            filename=true, 
            union_by_name=true
         )
    UNION ALL
    SELECT
        uuid,
        name,
        brand,
        street,
        house_number,
        city,
        latitude,
        longitude,
        first_active,
        openingtimes_json,
        filename
    FROM read_csv_auto(
            ?, 
            SAMPLE_SIZE=-1, 
            filename=true, 
            union_by_name=true
         );
    """,
    [stations_2023_glob, stations_2024_glob],
)

# Sanity check: row count and the lexicographic range of station uuids.
con.execute(
    "SELECT COUNT(*) AS n_rows, MIN(uuid) AS min_uuid, MAX(uuid) AS max_uuid FROM stations_raw;"
).df()

In [None]:
# 3.1 Pick latest snapshot per station within 2023–2024
# The table is dropped first so this cell can be re-run from scratch.
con.execute("DROP TABLE IF EXISTS stations_snapshot;")
# Steps inside the query:
#   parsed - derives snapshot_date from the first YYYY-MM-DD match in `filename`
#   ranked - numbers each uuid's rows from newest to oldest snapshot_date
#   final  - keeps rn = 1 (one row per uuid) and casts latitude/longitude to DOUBLE;
#            `filename` and `snapshot_date` are not carried into the output table.
# NOTE(review): the regex runs on the whole path, so it assumes no directory
# name earlier in the path matches YYYY-MM-DD — confirm for other checkouts.
con.execute(
    """
    CREATE TABLE stations_snapshot AS
    WITH parsed AS (
        SELECT
            *,
            -- extract 'YYYY-MM-DD' from filename, e.g. '.../2023-01-01-stations.csv'
            CAST(regexp_extract(filename, '([0-9]{4}-[0-9]{2}-[0-9]{2})', 1) AS DATE) AS snapshot_date
        FROM stations_raw
    ),
    ranked AS (
        SELECT
            *,
            ROW_NUMBER() OVER (PARTITION BY uuid ORDER BY snapshot_date DESC) AS rn
        FROM parsed
    )
    SELECT
        uuid,
        name,
        brand,
        street,
        house_number,
        city,
        CAST(latitude AS DOUBLE)  AS latitude,
        CAST(longitude AS DOUBLE) AS longitude,
        first_active,
        openingtimes_json
    FROM ranked
    WHERE rn = 1;
    """
)

# Sanity check: since only rn = 1 is kept, n_rows and n_uuids should be equal.
con.execute(
    "SELECT COUNT(*) AS n_rows, COUNT(DISTINCT uuid) AS n_uuids FROM stations_snapshot;"
).df()

In [None]:
# 3.2 Sample 500 random stations and compute brand groups (rare brands -> 'other')

# Re-seed right before the only random() call in the pipeline. The seed set at
# connection time is consumed by the first sampling run, so without this line
# re-running this cell on its own would draw a different set of 500 stations.
# On a fresh top-to-bottom run the RNG state here is the same as before.
# NOTE(review): with several worker threads the pairing of rows and random
# values may still not be fully deterministic — confirm if exact
# reproducibility across machines matters.
con.execute("SELECT setseed(0.42)")

con.execute("DROP TABLE IF EXISTS stations_sampled;")
# Every station gets a random rank; the 500 lowest ranks form the sample.
con.execute(
    """
    CREATE TABLE stations_sampled AS
    WITH ranked AS (
        SELECT
            uuid,
            brand,
            city,
            row_number() OVER (ORDER BY random()) AS rn
        FROM stations_snapshot
    )
    SELECT
        uuid,
        brand,
        city
    FROM ranked
    WHERE rn <= 500;
    """
)

# Brand grouping: brands with <= 5 stations in the 500-sample are mapped to 'other'.
# Rows without a match in brand_counts leave bc.n NULL in the LEFT JOIN, so the
# CASE falls through to 'other' for them as well.
con.execute("DROP TABLE IF EXISTS stations_sampled_grouped;")
con.execute(
    """
    CREATE TABLE stations_sampled_grouped AS
    WITH brand_counts AS (
        SELECT brand, COUNT(*) AS n
        FROM stations_sampled
        GROUP BY brand
    ),
    extended AS (
        SELECT
            s.uuid,
            s.city,
            CASE WHEN bc.n > 5 THEN s.brand ELSE 'other' END AS brand_group
        FROM stations_sampled s
        LEFT JOIN brand_counts bc USING (brand)
    )
    SELECT * FROM extended;
    """
)

# Overview: number of sampled stations per resulting brand group.
con.execute(
    """
    SELECT brand_group, COUNT(*) AS n_stations
    FROM stations_sampled_grouped
    GROUP BY brand_group
    ORDER BY n_stations DESC;
    """
).df()

In [None]:
# 4.1 E5 price changes for the 500 sampled stations
# The table is dropped first so this cell can be re-run from scratch.
con.execute("DROP TABLE IF EXISTS prices_sampled_e5;")
# The inner join against stations_sampled_grouped restricts the price rows to
# the sampled stations; rows with a NULL or non-positive e5 price are dropped.
# NOTE(review): the time-zone meaning of `ts` depends on how `date` was typed
# by the CSV reader. If it is a TIMESTAMP WITH TIME ZONE, the CAST to TIMESTAMP
# may depend on the session TimeZone setting — confirm, because the later
# local-time conversion assumes `ts` is UTC.
con.execute(
    """
    CREATE TABLE prices_sampled_e5 AS
    SELECT
        CAST(p.date AS TIMESTAMP) AS ts,
        p.station_uuid,
        CAST(p.e5 AS DOUBLE) AS price_e5
    FROM prices_raw p
    JOIN stations_sampled_grouped s
        ON s.uuid = p.station_uuid
    WHERE p.e5 IS NOT NULL
      AND p.e5 > 0;
    """
)

# Sanity check: row count and covered timestamp range.
con.execute(
    """
    SELECT COUNT(*) AS n_rows, MIN(ts) AS min_ts, MAX(ts) AS max_ts
    FROM prices_sampled_e5;
    """
).df()

In [None]:
# 4.2 Round timestamps down to the nearest 30-minute grid cell
# The table is dropped first so this cell can be re-run from scratch.
con.execute("DROP TABLE IF EXISTS prices_sampled_e5_rounded;")
# ts_30 is the start of the hour plus 0 or 30 minutes, depending on whether the
# minute of `ts` is below 30. Several price changes of one station that fall
# into the same cell are collapsed into their average (GROUP BY + AVG).
con.execute(
    """
    CREATE TABLE prices_sampled_e5_rounded AS
    SELECT
        station_uuid,
        -- floor to 30-minute grid: 00 or 30
        date_trunc('hour', ts)
            + INTERVAL (CASE WHEN EXTRACT(MINUTE FROM ts) < 30 THEN 0 ELSE 30 END) MINUTE AS ts_30,
        AVG(price_e5) AS price_e5
    FROM prices_sampled_e5
    GROUP BY station_uuid, ts_30;
    """
)

# Sanity check: row count and covered range of grid cells.
con.execute(
    """
    SELECT COUNT(*) AS n_rows,
           MIN(ts_30) AS min_ts_30,
           MAX(ts_30) AS max_ts_30
    FROM prices_sampled_e5_rounded;
    """
).df()

In [None]:
# Build a dense 30-minute price grid per station and forward-fill prices.
# Steps inside the query:
#   station_span - first and last observed grid cell per station
#   grid         - one row per 30-minute step between those two cells
#   base         - left-joins the observed (averaged) price onto the grid;
#                  cells without a price change get NULL
#   numbered     - per-station running index k, ordered by ts_30
#   with_last_k  - for each row, the largest k up to and including this row
#                  that carries a non-NULL price
#   ff           - joins back on last_k to fetch that price (forward fill)
# The final WHERE keeps only rows for which a filled price exists.
con.execute("DROP TABLE IF EXISTS grid_sampled_e5;")
con.execute(
    """
    CREATE TABLE grid_sampled_e5 AS
    WITH station_span AS (
        SELECT
            station_uuid,
            MIN(ts_30) AS min_ts,
            MAX(ts_30) AS max_ts
        FROM prices_sampled_e5_rounded
        GROUP BY station_uuid
    ),
    grid AS (
        -- 30-minute grid per station from first to last observed change
        SELECT
            s.station_uuid,
            gs.ts_30
        FROM station_span s,
        generate_series(
            s.min_ts,
            s.max_ts,
            INTERVAL 30 MINUTE
        ) AS gs(ts_30)
    ),
    base AS (
        -- attach price events (may be NULL if no change in this grid cell)
        SELECT
            g.station_uuid,
            g.ts_30,
            pr.price_e5 AS price_event
        FROM grid g
        LEFT JOIN prices_sampled_e5_rounded pr
          ON pr.station_uuid = g.station_uuid
         AND pr.ts_30       = g.ts_30
    ),
    numbered AS (
        -- sequential index per station for forward-fill logic
        SELECT
            station_uuid,
            ts_30,
            price_event,
            ROW_NUMBER() OVER (
                PARTITION BY station_uuid
                ORDER BY ts_30
            ) AS k
        FROM base
    ),
    with_last_k AS (
        -- for each row, compute index of last row with a non-null price_event
        SELECT
            station_uuid,
            ts_30,
            k,
            MAX(
                CASE
                    WHEN price_event IS NOT NULL THEN k
                    ELSE NULL
                END
            ) OVER (
                PARTITION BY station_uuid
                ORDER BY k
                ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
            ) AS last_k
        FROM numbered
    ),
    ff AS (
        -- join back to get the forward-filled price
        SELECT
            n.station_uuid,
            n.ts_30,
            e.price_event AS price_e5
        FROM with_last_k n
        LEFT JOIN numbered e
          ON e.station_uuid = n.station_uuid
         AND e.k           = n.last_k
    )
    SELECT *
    FROM ff
    WHERE price_e5 IS NOT NULL;
    """
)

# Sanity check: row count and covered range of the filled grid.
con.execute(
    """
    SELECT COUNT(*) AS n_rows,
           MIN(ts_30) AS min_ts_30,
           MAX(ts_30) AS max_ts_30
    FROM grid_sampled_e5;
    """
).df()

In [None]:
# Derive the calendar date `d` and the half-hour index `time_cell` for every
# grid row; the result is the table queried by the feature step and by
# predict_price().
# NOTE(review): ts_30 comes from CAST(... AS TIMESTAMP) in step 4.1. If it is a
# naive TIMESTAMP, DuckDB's timezone(tz, TIMESTAMP) presumably interprets it as
# wall-clock time in `tz` and returns a TIMESTAMP WITH TIME ZONE, rather than
# converting from UTC as the SQL comments say. The following CAST AS DATE and
# EXTRACT would then depend on the session TimeZone — confirm against the
# DuckDB timestamp function docs before running this on another machine.
# NOTE(review): time_cell shows up as a float (e.g. 46.0) in the exported
# sample, so `/ 30` is evidently not integer division here; the training code
# casts it to int16 later.
con.execute("DROP TABLE IF EXISTS grid_sampled_e5_prepared;")
con.execute(
    """
    CREATE TABLE grid_sampled_e5_prepared AS
    SELECT
        station_uuid,
        ts_30 AS ts_utc,  -- Keep the original for reference if needed
        
        -- 1. Convert UTC timestamp to Berlin Local Time (Wall Time)
        timezone('Europe/Berlin', ts_30) AS ts_local,
        
        -- 2. Use Local Time for the Date (so 01:00 AM stays on the correct day)
        CAST(timezone('Europe/Berlin', ts_30) AS DATE) AS d,
        
        -- 3. Calculate the Time Cell (0..47) using Local Hour
        (EXTRACT(HOUR FROM timezone('Europe/Berlin', ts_30)) * 2 
         + EXTRACT(MINUTE FROM timezone('Europe/Berlin', ts_30)) / 30) AS time_cell,
         
        price_e5 AS price
    FROM grid_sampled_e5;
    """
)

# Validation: Check if Rush Hour is stable
# (lists the five time cells with the highest mean price)
con.execute("""
    SELECT 
        time_cell, 
        AVG(price) as mean_price 
    FROM grid_sampled_e5_prepared 
    GROUP BY time_cell 
    ORDER BY mean_price DESC 
    LIMIT 5;
""").df()

# Range check: covered dates and the min/max time cell (expected within 0..47).
con.execute(
    """
    SELECT
        MIN(d) AS min_date,
        MAX(d) AS max_date,
        MIN(time_cell) AS min_cell,
        MAX(time_cell) AS max_cell
    FROM grid_sampled_e5_prepared;
    """
).df()

In [18]:
# --- 4.3 Feature Engineering: History + Seasonality + Momentum ---
# Builds one feature row per (station, date, time_cell):
#   - lagged prices of the same station and time cell (1, 2, 3, 7, 14, 21 rows back)
#   - day of week and a weekend flag derived from the date
#   - two momentum differences between lagged prices
#   - the station's brand group (inner join, so stations without a row in
#     stations_sampled_grouped are dropped)
# Rows are kept only when both the 1-step and the 21-step lag exist.
# NOTE(review): LAG offsets count rows inside the (station_uuid, time_cell)
# partition ordered by d, not calendar days. They equal "N days ago" only if
# every station has exactly one row per time cell on every consecutive day —
# confirm, since predict_price() looks history up by calendar date instead.

con.execute("DROP TABLE IF EXISTS features_sampled_e5;")
con.execute(
    """
    CREATE TABLE features_sampled_e5 AS
    WITH windows AS (
        SELECT
            g.station_uuid,
            g.d,
            g.time_cell,
            g.price AS price_today,
            
            -- RAW HISTORY
            LAG(g.price, 1) OVER (PARTITION BY g.station_uuid, g.time_cell ORDER BY g.d) AS price_lag_1d,
            LAG(g.price, 2) OVER (PARTITION BY g.station_uuid, g.time_cell ORDER BY g.d) AS price_lag_2d,
            LAG(g.price, 3) OVER (PARTITION BY g.station_uuid, g.time_cell ORDER BY g.d) AS price_lag_3d,
            
            -- WEEKLY HISTORY
            LAG(g.price, 7)  OVER (PARTITION BY g.station_uuid, g.time_cell ORDER BY g.d) AS price_lag_7d,
            LAG(g.price, 14) OVER (PARTITION BY g.station_uuid, g.time_cell ORDER BY g.d) AS price_lag_14d,
            LAG(g.price, 21) OVER (PARTITION BY g.station_uuid, g.time_cell ORDER BY g.d) AS price_lag_21d
            
        FROM grid_sampled_e5_prepared g
    )
    SELECT
        w.station_uuid,
        w.d AS date,
        w.time_cell,
        w.price_today AS price,
        
        -- 1. NEW: Explicit Seasonality
        -- dayofweek returns 0 for Sunday, 6 for Saturday in DuckDB
        dayofweek(w.d) AS dow, 
        CASE WHEN dayofweek(w.d) IN (0, 6) THEN 1 ELSE 0 END AS is_weekend,
        
        -- 2. NEW: Price Momentum (Velocity)
        -- "Are we currently crashing or spiking compared to 2 days ago?"
        (w.price_lag_1d - w.price_lag_2d) AS mom_1d_2d,
        (w.price_lag_1d - w.price_lag_7d) AS mom_1d_7d,

        -- Standard Lags
        w.price_lag_1d,
        w.price_lag_2d,
        w.price_lag_3d,
        w.price_lag_7d,
        w.price_lag_14d,
        w.price_lag_21d,
        
        -- Brand
        sb.brand_group
        
    FROM windows w
    JOIN stations_sampled_grouped sb ON sb.uuid = w.station_uuid
    WHERE w.price_lag_1d IS NOT NULL
      AND w.price_lag_21d IS NOT NULL;
    """
)

# Check output: number of feature rows produced.
con.execute("SELECT COUNT(*) FROM features_sampled_e5").df()

Unnamed: 0,count_star()
0,14737608


In [19]:
# Peek at the first 20 feature rows in a stable (date, cell, station) order.
_preview_sql = """
    SELECT *
    FROM features_sampled_e5
    ORDER BY date, time_cell, station_uuid
    LIMIT 20;
    """
sample_df = con.execute(_preview_sql).df()

sample_df

Unnamed: 0,station_uuid,date,time_cell,price,dow,is_weekend,mom_1d_2d,mom_1d_7d,price_lag_1d,price_lag_2d,price_lag_3d,price_lag_7d,price_lag_14d,price_lag_21d,brand_group
0,0627cc2b-d880-4c21-89bd-37e983c02624,2023-01-21,46.0,2.249,6,1,0.0,-0.04,2.249,2.249,2.249,2.289,2.289,2.289,ARAL
1,7915e506-168d-4509-a012-7cfba23fbc5c,2023-01-21,46.0,2.249,6,1,0.0,0.01,2.249,2.249,2.249,2.239,2.239,2.239,ARAL
2,9289d2c9-8d32-44ea-bf74-432d36c02425,2023-01-21,46.0,1.829,6,1,-0.05,-0.03,1.769,1.819,1.819,1.799,1.799,1.759,other
3,0627cc2b-d880-4c21-89bd-37e983c02624,2023-01-21,47.0,2.249,6,1,0.0,-0.04,2.249,2.249,2.249,2.289,2.289,2.289,ARAL
4,1cf7a642-26bb-4dd4-a011-914025d0b7de,2023-01-21,47.0,1.869,6,1,0.0,0.0,1.849,1.849,1.829,1.849,1.789,1.799,TotalEnergies
5,2822fc54-4d03-4a97-a3ab-1b00e04c5f46,2023-01-21,47.0,1.779,6,1,0.02,0.01,1.769,1.749,1.749,1.759,1.729,1.729,AVIA
6,431170c0-24bf-47f3-85ca-46a2b307eefd,2023-01-21,47.0,1.779,6,1,0.01,0.03,1.779,1.769,1.769,1.749,1.759,1.759,TotalEnergies
7,4b7e7978-b9ac-4013-9c28-c6afbfd4cb2d,2023-01-21,47.0,1.749,6,1,0.0,0.0,1.749,1.749,1.769,1.749,1.719,1.749,other
8,59ca7896-2c19-4354-97a2-a0f1de18be4f,2023-01-21,47.0,1.809,6,1,0.0,-0.02,1.749,1.749,1.749,1.769,1.759,1.809,other
9,70726c80-fd30-45f8-87a6-8acf0c0000db,2023-01-21,47.0,1.809,6,1,0.01,0.04,1.829,1.819,1.779,1.789,1.819,1.799,TotalEnergies


In [20]:
# --- Export features to Parquet ---
# The feature table is written to <BASE_DIR>/derived as a ZSTD-compressed
# Parquet file; the folder is created on first use.
output_dir = BASE_DIR.joinpath("derived")
output_dir.mkdir(exist_ok=True)

output_parquet = output_dir.joinpath("features_sampled_e5_2023_2024.parquet")

# Forward-slash form of the target path for use inside the SQL literal.
_parquet_target = output_parquet.as_posix()
sql = f"""
COPY features_sampled_e5
TO '{_parquet_target}'
(FORMAT PARQUET, COMPRESSION ZSTD);
"""

con.execute(sql)

print(f"Exported features to: {output_parquet}")

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

Exported features to: C:\Users\websi\OneDrive - UT Cloud\Semester\3. WS2025_26\DS500 Data Science Project (12 ECTS)\tankerkoenig_repo\tankerkoenig-data\derived\features_sampled_e5_2023_2024.parquet


## Model training

In [None]:
import pandas as pd
import numpy as np
from pathlib import Path

import lightgbm as lgb
from sklearn.metrics import mean_absolute_error, mean_squared_error

# Location of the exported feature table (no market / markup, no post_code).
BASE_DIR = Path(
    r"C:\Users\websi\OneDrive - UT Cloud\Semester\3. WS2025_26\DS500 Data Science Project (12 ECTS)\tankerkoenig_repo\tankerkoenig-data"
)

features_path = BASE_DIR.joinpath("derived", "features_sampled_e5_2023_2024.parquet")

# Load everything into memory and show size plus the first rows.
df = pd.read_parquet(features_path)
print(df.shape)
df.head()

In [None]:
# --- Data Type Optimization ---

# 1. Calendar column as Timestamp, half-hour index as a compact integer.
df["date"] = pd.to_datetime(df["date"])
df["time_cell"] = df["time_cell"].astype("int16")

# 2. All lagged-price columns are stored as float32.
cols_to_fix = [f"price_lag_{n_days}d" for n_days in (1, 2, 3, 7, 14, 21)]

for lag_col in cols_to_fix:
    df[lag_col] = df[lag_col].astype("float32")

# 3. Deterministic row order with a fresh 0..n-1 index.
df.sort_values(
    ["station_uuid", "date", "time_cell"],
    inplace=True,
    ignore_index=True,
)

df.dtypes

In [None]:
# Cut dates of the chronological split:
#   train: date <= train_end, valid: train_end < date <= valid_end, test: date > valid_end
train_end = pd.Timestamp("2024-06-30")
valid_end = pd.Timestamp("2024-09-30")

_row_dates = df["date"]
train_mask = _row_dates.le(train_end)
valid_mask = _row_dates.gt(train_end) & _row_dates.le(valid_end)
test_mask = _row_dates.gt(valid_end)

for _label, _mask in (
    ("Train rows:", train_mask),
    ("Valid rows:", valid_mask),
    ("Test rows :", test_mask),
):
    print(_label, _mask.sum())

In [None]:
# --- Target and model inputs ---

target_col = "price"

feature_cols = [
    # when / who
    "time_cell",
    "brand_group",
    # short-term history: same cell on the previous three days
    "price_lag_1d",
    "price_lag_2d",
    "price_lag_3d",
    # weekly history: same cell one, two and three weeks back
    "price_lag_7d",
    "price_lag_14d",
    "price_lag_21d",
]

cat_cols = ["brand_group"]

# --- Chronological train / validation / test split ---

train_end = pd.Timestamp("2024-06-30")
valid_end = pd.Timestamp("2024-09-30")

_row_dates = df["date"]
train_mask = _row_dates.le(train_end)
valid_mask = _row_dates.gt(train_end) & _row_dates.le(valid_end)
test_mask = _row_dates.gt(valid_end)

for _label, _mask in (
    ("Train rows:", train_mask),
    ("Valid rows:", valid_mask),
    ("Test rows :", test_mask),
):
    print(_label, _mask.sum())

# Feature frames are independent copies with the categorical columns cast to
# pandas 'category' (each split is cast on its own); targets are plain arrays.
_category_dtypes = dict.fromkeys(cat_cols, "category")

X_train = df.loc[train_mask, feature_cols].astype(_category_dtypes)
y_train = df.loc[train_mask, target_col].values

X_valid = df.loc[valid_mask, feature_cols].astype(_category_dtypes)
y_valid = df.loc[valid_mask, target_col].values

X_test = df.loc[test_mask, feature_cols].astype(_category_dtypes)
y_test = df.loc[test_mask, target_col].values

X_train.head()

In [None]:
from sklearn.metrics import mean_absolute_error, mean_squared_error
import numpy as np

def mae_rmse(y_true, y_pred):
    """Return ``(MAE, RMSE)`` of ``y_pred`` against ``y_true``.

    RMSE is the square root of scikit-learn's mean squared error.
    """
    return (
        mean_absolute_error(y_true, y_pred),
        np.sqrt(mean_squared_error(y_true, y_pred)),
    )

# Lagged prices of the test rows, pulled once and reused by all baselines.
_test_lags = df.loc[
    test_mask,
    ["price_lag_1d", "price_lag_2d", "price_lag_3d", "price_lag_7d"],
]
_lag_1d = _test_lags["price_lag_1d"].values
_lag_2d = _test_lags["price_lag_2d"].values
_lag_3d = _test_lags["price_lag_3d"].values

# Baseline 1: same time cell yesterday
b1_pred = _lag_1d
b1_mae, b1_rmse = mae_rmse(y_test, b1_pred)

# Baseline 2: same time cell one week ago
b2_pred = _test_lags["price_lag_7d"].values
b2_mae, b2_rmse = mae_rmse(y_test, b2_pred)

# Baseline 3: plain mean of the last three days (smoothing baseline)
b3_pred = (_lag_1d + _lag_2d + _lag_3d) / 3.0
b3_mae, b3_rmse = mae_rmse(y_test, b3_pred)

print("Baseline metrics on TEST set")
for _label, _mae, _rmse in (
    ("Yesterday:", b1_mae, b1_rmse),
    ("Week-ago:", b2_mae, b2_rmse),
    ("Avg(3days):", b3_mae, b3_rmse),
):
    # Labels are left-padded to 12 characters so the MAE columns line up.
    print(f"{_label:<12} MAE={_mae:.4f}, RMSE={_rmse:.4f}")

In [None]:
import itertools
import lightgbm as lgb

# --- 1. LightGBM datasets ---
# Both datasets share the categorical columns and keep their raw data
# (free_raw_data=False); the validation set is bound to the training set
# through `reference`.
_shared_dataset_kwargs = dict(categorical_feature=cat_cols, free_raw_data=False)

train_data = lgb.Dataset(X_train, label=y_train, **_shared_dataset_kwargs)
valid_data = lgb.Dataset(
    X_valid,
    label=y_valid,
    reference=train_data,
    **_shared_dataset_kwargs,
)

# --- 2. Search space ---
# Tree complexity (num_leaves) against leaf robustness (min_data_in_leaf);
# the learning rate is held at a single value.
param_grid = {
    "num_leaves": [31, 63, 127],
    "min_data_in_leaf": [100, 200, 500],
    "learning_rate": [0.05]
}

# Cartesian product of all grid values, one dict per candidate.
keys = tuple(param_grid)
values = tuple(param_grid.values())
combinations = [dict(zip(keys, combo)) for combo in itertools.product(*values)]

print(f"Starting Grid Search with {len(combinations)} candidates...\n")

# Settings shared by every candidate; grid values are layered on top.
_fixed_params = {
    "objective": "mae",
    "metric": "mae",
    "feature_fraction": 0.8,
    "bagging_fraction": 0.8,
    "bagging_freq": 1,
    "max_bin": 255,
    "verbosity": -1,
    "seed": 42,
}

best_score, best_params, best_model = float("inf"), {}, None
_n_candidates = len(combinations)

# --- 3. Train every candidate and keep the one with the lowest validation MAE ---
for _position, config in enumerate(combinations, start=1):
    current_params = {**_fixed_params, **config}

    gbm_candidate = lgb.train(
        current_params,
        train_set=train_data,
        num_boost_round=1000,
        valid_sets=[train_data, valid_data],
        valid_names=["train", "valid"],
        callbacks=[
            # short patience (20 rounds) keeps the search fast
            lgb.early_stopping(stopping_rounds=20, verbose=False),
            # period=0 silences the per-iteration log
            lgb.log_evaluation(period=0),
        ],
    )

    # The "mae" metric is reported under the key "l1".
    score = gbm_candidate.best_score["valid"]["l1"]

    print(f"[{_position}/{_n_candidates}] Params: {config} -> Val MAE: {score:.5f}")

    if score < best_score:
        best_score, best_params, best_model = score, current_params, gbm_candidate

_banner = "=" * 40
print("\n" + _banner)
print(f"WINNER: MAE = {best_score:.5f}")
print("Best Params:", best_params)
print(_banner)

# --- 4. Expose the winner ---
# Later cells read the chosen model from `gbm`.
gbm = best_model

In [None]:
# Score the selected model on the held-out test rows, using the iteration
# picked by early stopping, and print it next to the three baselines.
y_pred_test = gbm.predict(X_test, num_iteration=gbm.best_iteration)

lgb_mae, lgb_rmse = mae_rmse(y_test, y_pred_test)

print("LightGBM metrics on TEST set")
print(f"LightGBM:    MAE={lgb_mae:.4f}, RMSE={lgb_rmse:.4f}")

print("\nComparison:")
print(f"Yesterday:   MAE={b1_mae:.4f}, RMSE={b1_rmse:.4f}")
print(f"Week-ago:    MAE={b2_mae:.4f}, RMSE={b2_rmse:.4f}")
# Baseline 3 is the mean of the 1d/2d/3d lags, so it is labelled the same way
# as in the baseline cell (it was previously mislabelled "Avg(y,7d)").
print(f"Avg(3days):  MAE={b3_mae:.4f}, RMSE={b3_rmse:.4f}")

In [None]:
# Feature importances
import matplotlib.pyplot as plt

# Gain-based importance of every model input, printed from most to least important.
importances = gbm.feature_importance(importance_type="gain")
_ranked = sorted(
    zip(feature_cols, importances),
    key=lambda pair: pair[1],
    reverse=True,
)
for _name, _gain in _ranked:
    print(f"{_name:15s} {_gain:.1f}")

# Horizontal bar chart, smallest gain at the bottom.
order = np.argsort(importances)
_fig, _ax = plt.subplots(figsize=(6, 4))
_ax.barh(np.array(feature_cols)[order], importances[order])
_ax.set_xlabel("Importance (gain)")
_fig.tight_layout()
plt.show()

In [None]:
# Persist the selected booster as a LightGBM text model next to the features.
model_path = BASE_DIR.joinpath("derived", "lightgbm_sampled_e5_mae.txt")
gbm.save_model(model_path.as_posix())
print(f"Saved LightGBM model to: {model_path}")

In [None]:
# A handful of station ids that have feature rows, usable for prediction tests.
_station_probe_sql = "SELECT DISTINCT station_uuid FROM features_sampled_e5 LIMIT 5"
valid_stations = con.execute(_station_probe_sql).df()
print(valid_stations)

In [None]:
def predict_price(station_uuid, query_time_str, con, gbm):
    """
    Predict the fuel price of one station for the half-hour cell of a query time.

    The feature row contains the 30-minute cell index (0..47), the station's
    brand group and the prices stored for the same cell 1, 2, 3, 7, 14 and 21
    calendar days before the query date. A lag without a stored row is passed
    to the model as NaN.

    All values are sent to DuckDB as bound parameters instead of being pasted
    into the SQL text, so a station id containing quotes can neither break
    nor alter the query.

    Parameters
    ----------
    station_uuid : str
        Station id, matched against ``grid_sampled_e5_prepared.station_uuid``
        and ``stations_sampled_grouped.uuid``.
    query_time_str : str
        Any value ``pd.Timestamp`` accepts. Its date, hour and minute are used
        as given; no time-zone conversion happens here.
    con :
        DuckDB connection that holds the two tables named above.
    gbm :
        Trained model exposing ``feature_name()`` and ``predict()``.

    Returns
    -------
    float or None
        The predicted price, or ``None`` (after printing a warning) when the
        lookup returns no rows, i.e. the station has no brand entry or no
        stored price for any of the lag dates.
    """
    import pandas as pd
    import numpy as np

    lag_days = (1, 2, 3, 7, 14, 21)
    lag_cols = [f"price_lag_{n}d" for n in lag_days]

    # 1. Query time -> half-hour cell and the calendar dates of all lags
    query_ts = pd.Timestamp(query_time_str)
    time_cell = (query_ts.hour * 2) + (1 if query_ts.minute >= 30 else 0)

    target_date = query_ts.date()
    # ISO strings are used on both sides (SQL and pandas) so dates compare as
    # plain text and no date-vs-Timestamp mismatch can occur.
    date_strs = [(target_date - pd.Timedelta(days=n)).isoformat() for n in lag_days]

    # 2. Fetch history: one bound placeholder per lag date
    date_placeholders = ", ".join("CAST(? AS DATE)" for _ in date_strs)
    query = f"""
    WITH history AS (
        SELECT
            CAST(d AS VARCHAR) AS d_str,
            price
        FROM grid_sampled_e5_prepared
        WHERE station_uuid = ?
          AND time_cell = ?
          AND d IN ({date_placeholders})
    ),
    brand_info AS (
        SELECT brand_group
        FROM stations_sampled_grouped
        WHERE uuid = ?
    )
    SELECT
        h.d_str,
        h.price,
        b.brand_group
    FROM history h, brand_info b
    """
    # Parameter order follows the order of the placeholders in the statement.
    params = [station_uuid, time_cell, *date_strs, station_uuid]

    history_df = con.execute(query, params).df()

    if history_df.empty:
        print(f"Warning: No history found for St: {station_uuid}, Cell: {time_cell}")
        return None

    # 3. Assemble the feature row; the first row per date wins if a date
    #    appears more than once.
    price_by_date = {}
    for d_str, price in zip(history_df["d_str"], history_df["price"]):
        price_by_date.setdefault(d_str, float(price))

    features = {
        "time_cell": time_cell,
        "brand_group": history_df["brand_group"].iloc[0],
    }
    for col, d_str in zip(lag_cols, date_strs):
        features[col] = price_by_date.get(d_str, np.nan)

    # Debug print: how many of the six lags were actually found
    found_lags = sum(1 for col in lag_cols if not np.isnan(features[col]))
    print(f"Debug: Found {found_lags}/6 historical prices for {query_time_str}")

    # 4. Single-row frame with the dtypes used in training
    dtype_plan = dict.fromkeys(lag_cols, "float32")
    dtype_plan["brand_group"] = "category"
    X_input = pd.DataFrame([features]).astype(dtype_plan)

    # Column order must match the order the model was trained with.
    X_input = X_input[gbm.feature_name()]

    # 5. Predict
    prediction = gbm.predict(X_input)[0]
    return prediction

# --- TEST AGAIN ---
# (Make sure to use a valid UUID from your training set)
test_station = "78f22673-eee8-4577-81da-b6062309dda1"
test_time_1 = "2024-10-01 17:00:00"
test_time_2 = "2024-10-01 18:00:00"

p1 = predict_price(test_station, test_time_1, con, gbm)
p2 = predict_price(test_station, test_time_2, con, gbm)

# predict_price returns None when no history is found; formatting None with
# ':.3f' would raise a TypeError, so that case is reported explicitly.
for _label, _price in (("17:00", p1), ("18:00", p2)):
    if _price is None:
        print(f"{_label} Price: n/a (no history found)")
    else:
        print(f"{_label} Price: {_price:.3f}")