## Modelling

In this notebook we start modelling with data from our local database.

- To do this, we connect to our local DB using the `duckdb` library.
- Once the connection has been made, we can start retrieving data from the DB.


### Setup


In [3]:
import duckdb
import polars as pl
import pandas as pd
import xgboost as xgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import mlflow
import mlflow.xgboost

In [4]:
# Load the SQL extension so that cells can run SQL through the %sql magic.
%load_ext sql
# Open the project's DuckDB database file (relative path).
conn = duckdb.connect(database="../dsp-dagster/data_systems_project.duckdb")
# Register the open connection with the %sql magic under the alias "duckdb".
%sql conn --alias duckdb
%sql SHOW ALL TABLES; # shows all available tables

The sql extension is already loaded. To reload it, use:
  %reload_ext sql


Unnamed: 0,database,schema,name,column_names,column_types,temporary
0,data_systems_project,joined,incidents_buurten,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[BIGINT, DATE, TIME, TIME, TIME, DOUBLE, VARCH...",False
1,data_systems_project,joined,incidents_wijken,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[BIGINT, DATE, TIME, TIME, TIME, DOUBLE, VARCH...",False
2,data_systems_project,public,bag_panden,"[geometry, identificatie, rdf_seealso, bouwjaa...","[BLOB, VARCHAR, VARCHAR, BIGINT, VARCHAR, VARC...",False
3,data_systems_project,public,cbs_buurten,"[geometry, buurtcode, buurtnaam, wijkcode, gem...","[BLOB, VARCHAR, VARCHAR, VARCHAR, VARCHAR, VAR...",False
4,data_systems_project,public,cbs_wijken,"[geometry, wijkcode, wijknaam, gemeentecode, g...","[BLOB, VARCHAR, VARCHAR, VARCHAR, VARCHAR, BIG...",False
5,data_systems_project,public,fire_stations_and_vehicles,"[Fire_Station, Vehicle, Vehicle_Type]","[VARCHAR, VARCHAR, VARCHAR]",False
6,data_systems_project,public,service_areas,"[H_Verzorgingsgebied_ID, Verzorgingsgebied, LA...","[BIGINT, VARCHAR, DOUBLE, DOUBLE, VARCHAR]",False
7,data_systems_project,public,storm_deployments,"[Deployment_ID, Incident_ID, Vehicle_Type, Veh...","[BIGINT, BIGINT, VARCHAR, VARCHAR, VARCHAR, VA...",False
8,data_systems_project,public,storm_incidents,"[Incident_ID, Date, Incident_Starttime, Incide...","[BIGINT, DATE, TIME, TIME, TIME, BIGINT, VARCH...",False


In [5]:
## Alternatively, the SQL magic can load a query result into a variable like so:
# %sql res << SELECT * FROM joined.incidents_buurten
# res

In [9]:
# Or the more Pythonic way:

# Retrieve the `joined.incidents_buurten` table as a Polars DataFrame.
# NOTE(review): judging by its name and columns this table holds incidents
# joined with neighbourhood (buurten) attributes rather than weather data —
# confirm that this is the table the modelling below should use.
query = """
    SELECT * FROM joined.incidents_buurten """

try:
    df = conn.execute(query).pl()
finally:
    # Close the database connection even when the query fails, so the
    # DuckDB file is never left open by this notebook.
    conn.close()

In [10]:
# Preview the first rows of the loaded table to inspect column names and dtypes.
# NOTE(review): the first 17 columns show up as "0".."16" — presumably the
# incident column names were lost upstream; confirm before relying on them.
df.head()

0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,geometry,index_right,buurtcode,buurtnaam,wijkcode,gemeentecode,gemeentenaam,indelingswijzigingWijkenEnBuurten,water,meestVoorkomendePostcode,dekkingspercentage,omgevingsadressendichtheid,stedelijkheidAdressenPerKm2,bevolkingsdichtheidInwonersPerKm2,aantalInwoners,mannen,vrouwen,percentagePersonen0Tot15Jaar,percentagePersonen15Tot25Jaar,percentagePersonen25Tot45Jaar,percentagePersonen45Tot65Jaar,percentagePersonen65JaarEnOuder,percentageOngehuwd,percentageGehuwd,percentageGescheid,percentageVerweduwd,aantalHuishoudens,percentageEenpersoonshuishoudens,percentageHuishoudensZonderKinderen,percentageHuishoudensMetKinderen,gemiddeldeHuishoudsgrootte,percentageWesterseMigratieachtergrond,percentageNietWesterseMigratieachtergrond,percentageUitMarokko,percentageUitNederlandseAntillenEnAruba,percentageUitSuriname,percentageUitTurkije,percentageOverigeNietwestersemigratieachtergrond,oppervlakteTotaalInHa,oppervlakteLandInHa,oppervlakteWaterInHa,jrstatcode,jaar
i64,date,time,time,time,f64,str,str,str,f64,f64,i64,i64,i64,i64,i64,i64,str,f64,str,str,str,str,str,f64,str,str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,str,f64
511,2018-09-07,08:25:18,10:00:56,01:35:38,2.0,"""Amstelveen""","""Amstelveen""","""Tree""",4.838685,52.281552,8,10,1,25,0,35,"""POINT (4.83868…",497.0,"""BU03621104""","""Legmeer""","""WK036211""","""GM0362""","""Amstelveen""",1.0,"""NEE""","""1185""",1.0,2041.0,2.0,505.0,220.0,115.0,100.0,4.0,16.0,69.0,8.0,3.0,82.0,13.0,5.0,1.0,170.0,80.0,14.0,6.0,1.3,10.0,28.0,3.0,2.0,5.0,1.0,17.0,44.0,43.0,1.0,"""2022BU03621104…",2022.0
738,2018-09-10,16:46:38,18:00:55,01:14:17,2.0,"""Victor""","""Amsterdam""","""Tree""",4.930968,52.359724,16,18,1,46,0,14,"""POINT (4.93096…",365.0,"""BU03632902""","""Oostpoort""","""WK036329""","""GM0363""","""Amsterdam""",1.0,"""NEE""","""1093""",1.0,8727.0,1.0,9084.0,2020.0,995.0,1020.0,17.0,9.0,33.0,25.0,17.0,61.0,26.0,9.0,4.0,950.0,47.0,24.0,29.0,1.9,17.0,31.0,6.0,1.0,8.0,3.0,12.0,23.0,22.0,1.0,"""2022BU03632902…",2022.0
739,2018-09-10,17:11:20,18:01:00,00:49:40,2.0,"""Dirk""","""Amsterdam""","""Building""",4.898818,52.357071,17,18,0,11,1,49,"""POINT (4.89881…",548.0,"""BU03632404""","""Sarphatiparkbu…","""WK036324""","""GM0363""","""Amsterdam""",1.0,"""NEE""","""1073""",4.0,11144.0,1.0,21998.0,4990.0,2470.0,2520.0,6.0,15.0,49.0,20.0,10.0,77.0,15.0,6.0,1.0,3375.0,65.0,25.0,10.0,1.5,30.0,20.0,2.0,1.0,3.0,2.0,12.0,23.0,23.0,0.0,"""2022BU03632404…",2022.0
1335,2018-09-18,19:17:33,20:41:29,01:23:56,2.0,"""Zebra""","""Amsterdam""","""Fence, Road si…",4.960809,52.392155,19,20,1,17,41,23,"""POINT (4.96080…",259.0,"""BU03636804""","""Markengouw Zui…","""WK036368""","""GM0363""","""Amsterdam""",1.0,"""NEE""","""1024""",1.0,1755.0,2.0,10854.0,2315.0,1165.0,1150.0,28.0,14.0,30.0,21.0,7.0,63.0,28.0,7.0,2.0,845.0,35.0,14.0,50.0,2.6,10.0,65.0,16.0,2.0,9.0,16.0,23.0,23.0,21.0,1.0,"""2022BU03636804…",2022.0
1493,2018-09-21,06:59:05,08:21:25,01:22:20,2.0,"""Amstelveen""","""Amstelveen""","""Tree""",4.879741,52.301365,6,8,1,59,21,22,"""POINT (4.87974…",598.0,"""BU03620603""","""Boekenbuurt""","""WK036206""","""GM0362""","""Amstelveen""",1.0,"""NEE""","""1183""",1.0,2796.0,1.0,5864.0,2940.0,1375.0,1565.0,12.0,8.0,28.0,22.0,30.0,46.0,37.0,9.0,7.0,1560.0,49.0,28.0,23.0,1.8,20.0,28.0,1.0,1.0,3.0,2.0,20.0,52.0,50.0,2.0,"""2022BU03620603…",2022.0


In [None]:
def plot_feature_importances(model, feature_names, top_n=20, title="Feature Importances"):
    """
    Draw a horizontal bar chart of the model's most important features.

    The importances are read from ``model.feature_importances_``, paired with
    ``feature_names``, ranked from highest to lowest and cut off after
    ``top_n`` entries before plotting.

    :param model: Trained model exposing a ``feature_importances_`` attribute
    :param feature_names: Feature names, in the same order as the importances
    :param top_n: Maximum number of features to show
    :param title: Title of the plot
    :return: The ``matplotlib.pyplot`` module holding the current figure
    """
    # Rank the (feature, importance) pairs and keep only the strongest ones
    ranked = (
        pd.DataFrame(
            {"Feature": feature_names, "Importance": model.feature_importances_}
        )
        .sort_values(by="Importance", ascending=False)
        .head(top_n)
    )

    # One bar per feature, most important at the top
    plt.figure(figsize=(10, 6))
    sns.barplot(data=ranked, y="Feature", x="Importance", palette="viridis")

    # Labelling and layout
    plt.title(title)
    plt.xlabel("Relative Importance")
    plt.ylabel("Feature")
    plt.tight_layout()

    return plt

### XGBoost


In [None]:
# Build an incident-count dataset from `df`, train a Poisson XGBoost regressor
# on it, evaluate it on a hold-out split and log the run to MLflow.

# Weather feature columns used both as grouping keys and as model features.
# NOTE(review): these look like KNMI hourly variable codes — confirm against
# the table that provides them.
weather_cols = [
    "Dd",
    "Fh",
    "Ff",
    "Fx",
    "T",
    "T10n",
    "Td",
    "Sq",
    "Q",
    "Dr",
    "Rh",
    "P",
    "Vv",
    "N",
    "U",
    "Ww",
    "Ix",
    "M",
    "R",
    "S",
    "O",
    "Y",
]
group_cols = ["Date", "Hour", "Service_Area", "Damage_Type"] + weather_cols

# Fail early with a readable message when the loaded table does not provide
# the columns this cell relies on, instead of a deep Polars traceback.
required_cols = group_cols + ["Incident_ID"]
missing_cols = [col for col in required_cols if col not in df.columns]
if missing_cols:
    raise KeyError(
        f"`df` is missing columns required for modelling: {missing_cols}. "
        "Load a table that contains the incident and weather columns first."
    )

# Newer Polars releases name this method `group_by`; older ones only have
# `groupby`. Support both spellings so the cell runs on either.
_group_by = df.group_by if hasattr(df, "group_by") else df.groupby

# Aggregate data: one row per unique combination of the grouping columns,
# with the number of incidents in that group as the target.
# Sorting on *all* grouping keys (not only Date/Hour) gives a fully
# deterministic row order, which the seeded train/test split below needs in
# order to be reproducible; group-by output order itself is not guaranteed.
agg_df = (
    _group_by(group_cols)
    .agg(pl.col("Incident_ID").count().alias("Incident_Count"))
    .sort(group_cols)
)


# Date and Hour were only needed for grouping/ordering; they are not features
agg_df = agg_df.drop(["Date", "Hour"])

# Encode categorical variables using one-hot encoding
agg_df = agg_df.to_dummies(columns=["Service_Area", "Damage_Type"])

# Splitting the features and target variable
y = agg_df["Incident_Count"]
X = agg_df.drop("Incident_Count")

# Convert to Pandas for compatibility with scikit-learn
X_pd = X.to_pandas()
y_pd = y.to_pandas()

# Hold out 20% of the rows for evaluation (fixed seed for reproducibility)
X_train, X_test, y_train, y_test = train_test_split(
    X_pd, y_pd, test_size=0.2, random_state=42
)

# Train XGBoost model
model = xgb.XGBRegressor(
    objective="count:poisson"
)  # Using Poisson regression for count data

model.fit(X_train, y_train)

# Make predictions on the hold-out set and calculate error metrics
y_pred = model.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
rmse = np.sqrt(mse)
mae = mean_absolute_error(y_test, y_pred)

# Show the raw importances and plot the 20 most important features
feature_importances = model.feature_importances_


print(feature_importances)
plot_feature_importances(model, X_train.columns, top_n=20)


# Set the MLflow tracking URI
mlflow.set_tracking_uri("http://dsp-mlflow:5001")

# Start an MLflow run and record the model, its parameters and the metrics
with mlflow.start_run(run_name="Incident Prediction Model"):
    # Log model
    mlflow.xgboost.log_model(model, "xgboost-model")

    # Log parameters
    mlflow.log_params(model.get_params())

    # Log metrics
    mlflow.log_metric("MSE", mse)
    mlflow.log_metric("RMSE", rmse)
    mlflow.log_metric("MAE", mae)