# DS ML Project ⟡ Flight Delay Prediction Challenge

## Synopsis

**TODO** Write this paragraph

This notebook presents our analysis of the [flight delay dataset for Tunisair](https://zindi.africa/competitions/flight-delay-prediction-challenge) from [Zindi](https://zindi.africa).

Here, "our" refers to:
- [greseberisha](https://github.com/greseberisha)
- [MoSeBaur](https://github.com/MoSeBaur)
- [kvn-dtrx](https://github.com/kvn-dtrx)

## Requirements

In [None]:
# Data Science
import pandas as pd

# Scientific Computation
import numpy as np

# Scikit-Learn Tools
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import cross_val_score
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
# from sklearn.metrics import root_mean_squared_error
from sklearn.metrics import r2_score
from sklearn.neighbors import KNeighborsRegressor
from sklearn.ensemble import RandomForestRegressor

# Plotting
import matplotlib.pyplot as plt
import seaborn as sns

# Predictive power score
import ppscore as pps

import warnings

## Configs

Next, let us specify all required configurations to have them in one place.

In [None]:
# Level of the warnings module
WARNINGS_LEVEL = "ignore"
warnings.filterwarnings(WARNINGS_LEVEL)

# Path to train data
PATH_DATA_TRAIN = "./data/train.csv"

# Random seed
RSEED = 42

# Resolution when storing plots in files
DPI = 600

# Matplotlib style
PLT_STYLE = "seaborn"
try:
    plt.style.use(PLT_STYLE)
except OSError:  # matplotlib raises OSError for unknown style names
    warnings.warn(f"Could not load matplotlib style {PLT_STYLE}", UserWarning)

# Whether to run long computations
RUN_LONG_COMPUTATIONS = False

## First Confrontation with the Data

We read the data from file into a pandas data frame and create a copy that will incorporate our manipulations.

In [None]:
df_0 = pd.read_csv(PATH_DATA_TRAIN)

df = df_0.copy()

The usual initial inspection commands:

In [None]:
print(df.head())
print(df.shape)
print(df.isnull().sum())
print(df.dtypes)

df.columns

Meaning of column names (according to <https://zindi.africa/competitions/flight-delay-prediction-challenge/data>):

Present in the data:

| Column | Description |
| --- | --- |
| ID | Unique identifier for the flight |
| DATOP | Date of flight |
| FLTID | Flight number |
| DEPSTN | Departure point (station/airport) |
| ARRSTN | Arrival point (station/airport) |
| STD | Scheduled Time of Departure |
| STA | Scheduled Time of Arrival |
| STATUS | Flight status (e.g., delayed, canceled) |
| AC | Aircraft code |
| target | Flight delay (in minutes) |


Not present in the data (although listed on the referenced web page):

| Column | Description |
| --- | --- |
| ETD | Expected Time of Departure |
| ETA | Expected Time of Arrival |
| ATD | Actual Time of Departure |
| ATA | Actual Time of Arrival |
| DELAY1 | Delay code 1 |
| DUR1 | Delay time 1 |
| DELAY2 | Delay code 2 |
| DUR2 | Delay time 2 |
| DELAY3 | Delay code 3 |
| DUR3 | Delay time 3 |
| DELAY4 | Delay code 4 |
| DUR4 | Delay time 4 |

### Status Column

In [None]:
# Sorry for "statuses" ...
statuses = df["STATUS"].unique()

print("All Statuses:")
for status in statuses:
    print(f"  Number of entries of {status}: {df[df['STATUS'] == status].shape[0]}")
    print(f"  Mean: {df[df['STATUS'] == status]['target'].mean()}")
    print(f"  Median: {df[df['STATUS'] == status]['target'].median()}")

for status in statuses:
    df[df["STATUS"] == status]["target"].hist(
        bins=50,
        log=False,
    )
    plt.title(status)
    plt.xlabel("Delay")
    plt.ylabel("Frequency")

    plt.savefig(f"./img/delay-to-sum-flight-on-status-eq-{status}_hist.png", dpi=DPI, bbox_inches="tight")
    plt.show()


| Code | Name | Description |
| --- | --- | --- |
| ATA | Actual Time of Arrival | Flights that successfully landed at their destination. |
| DEP | Departed | Flights that departed but may not have completed their journey. |
| RTR | Returned | Flights that took off but returned to the departure airport due to issues. |
| SCH | Scheduled | Flights listed in the schedule, no delay data applicable. |
| DEL | Canceled | Flights that were canceled, treated as permanent delays. |

The interpretation of DEP remains somewhat obscure, so as a first approximation we drop these rows. Further, it is hard to measure the delay of a DEL flight: one possibility for regular flights would be to take the duration between the DEL flight and the next flight that actually arrives, plus the delay of that flight. For simplicity, we drop these rows as well.

In [None]:
df = df[~df["STATUS"].isin(["DEP", "DEL"])]
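
The proxy delay for DEL flights mentioned above is not used in this notebook. Purely as an illustration, it could be sketched as follows; the helper `proxy_delay_minutes` and the toy timestamps are our assumptions and do not come from the data set.

```python
from datetime import datetime


def proxy_delay_minutes(cancelled_sta, later_flights):
    """Hypothetical helper: proxy delay (in minutes) of a cancelled flight.

    later_flights is a list of (scheduled_arrival, delay_minutes) tuples
    for flights on the same route that actually arrived.
    Returns None if no such flight is scheduled after the cancelled one.
    """
    candidates = [
        (sta, delay) for sta, delay in later_flights if sta > cancelled_sta
    ]
    if not candidates:
        return None
    # Next flight that actually arrives after the cancelled one
    next_sta, next_delay = min(candidates, key=lambda item: item[0])
    wait = (next_sta - cancelled_sta).total_seconds() / 60
    return wait + next_delay


# Made-up example: cancelled at 10:00, next arriving flight at 16:00 with 30 min delay
cancelled = datetime(2016, 1, 3, 10, 0)
later = [
    (datetime(2016, 1, 3, 16, 0), 30.0),
    (datetime(2016, 1, 4, 10, 0), 0.0),
]
proxy = proxy_delay_minutes(cancelled, later)
print(proxy)
```

Such proxy values would be very large compared with ordinary delays, which is one more reason why we simply drop the DEL rows.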

### Airport Columns

We introduce columns that reduce the departure and destination airports to their respective countries.

In [None]:
airports = pd.read_csv("airports.csv")
airports = airports[["iata_code", "iso_country"]]
airports = airports.dropna()

In [None]:
# TODO Make this cell idempotent

df = df.merge(
    airports[["iata_code", "iso_country"]],
    left_on="DEPSTN",
    right_on="iata_code",
    how="left",
)
df.drop(columns="iata_code", inplace=True)
df.rename(columns={"iso_country": "country_dep"}, inplace=True)

df = df.merge(
    airports[["iata_code", "iso_country"]],
    left_on="ARRSTN",
    right_on="iata_code",
    how="left",
)
df.drop(columns="iata_code", inplace=True)
df.rename(columns={"iso_country": "country_arr"}, inplace=True)

df["country_arr"].shape

df.loc[df["DEPSTN"] == "SXF", "country_dep"] = "DE"
df.loc[df["ARRSTN"] == "SXF", "country_arr"] = "DE"
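
Regarding the TODO above: one possible way to make the country assignment idempotent is to replace the merges with a lookup via `Series.map`, which overwrites the target columns instead of appending new ones. The following is only a sketch on made-up toy frames; `airports_toy`, `flights_toy` and `add_countries` are hypothetical names.

```python
import pandas as pd

# Toy stand-ins for the notebook's `airports` and `df` frames (made-up rows)
airports_toy = pd.DataFrame(
    {"iata_code": ["TUN", "ORY"], "iso_country": ["TN", "FR"]}
)
flights_toy = pd.DataFrame(
    {"DEPSTN": ["TUN", "ORY", "SXF"], "ARRSTN": ["ORY", "TUN", "TUN"]}
)

# Lookup table from IATA code to ISO country code (index must be unique)
iata_to_country = (
    airports_toy.drop_duplicates("iata_code")
    .set_index("iata_code")["iso_country"]
)


def add_countries(frame):
    """Assign the country columns; running this twice gives the same result."""
    frame = frame.copy()
    frame["country_dep"] = frame["DEPSTN"].map(iata_to_country)
    frame["country_arr"] = frame["ARRSTN"].map(iata_to_country)
    return frame


once = add_countries(flights_toy)
twice = add_countries(once)
print(once.equals(twice))
```

Codes that are missing from the lookup table stay empty and can still be patched by hand afterwards, as we do for SXF.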

To convert the ISO country codes to continent names, we use the `pycountry_convert` module.

In [None]:
import pycountry_convert as pc


def iso_to_continent(iso):
    try:
        continent_code = pc.country_alpha2_to_continent_code(iso)
        return pc.convert_continent_code_to_continent_name(continent_code)
    except Exception:  # unknown or missing country code
        return None


df["continent_dep"] = df["country_dep"].apply(iso_to_continent)
df["continent_arr"] = df["country_arr"].apply(iso_to_continent)

### Date Issues

The data set contains several columns with date semantics. Let us convert them to the appropriate dtype.

In [None]:
df["DATOP"] = pd.to_datetime(df["DATOP"], format="%Y-%m-%d")
df["STD"] = pd.to_datetime(df["STD"], format="%Y-%m-%d %H:%M:%S")
df["STA"] = pd.to_datetime(df["STA"], format="%Y-%m-%d %H.%M.%S")

Now, we can introduce a bunch of further useful date and time related columns:

In [None]:
df["DATOP_year"] = df["DATOP"].dt.year
df["DATOP_month"] = df["DATOP"].dt.month
df["DATOP_day"] = df["DATOP"].dt.dayofweek + 1

def map_hour_to_period(hour):
    if 6 <= hour < 12:
        return "morning"
    elif 12 <= hour < 18:
        return "day"
    elif 18 <= hour < 24:
        return "evening"
    else:
        return "night"


df["STD_hour"] = df["STD"].dt.hour
df["STD_period"] = df["STD_hour"].apply(map_hour_to_period)

df["flight_time"] = (df["STA"] - df["STD"]).dt.total_seconds() / 60
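
As a side note, the hour-to-period mapping can also be written in vectorized form with `pd.cut`. This is only a sketch on a made-up series of hours; the bin edges mirror the bounds used in `map_hour_to_period`.

```python
import pandas as pd

# Made-up departure hours covering every period boundary
hours = pd.Series([0, 5, 6, 11, 12, 17, 18, 23])

# Left-closed bins: [0, 6) night, [6, 12) morning, [12, 18) day, [18, 24) evening
periods = pd.cut(
    hours,
    bins=[0, 6, 12, 18, 24],
    labels=["night", "morning", "day", "evening"],
    right=False,
)
print(periods.astype(str).tolist())
```

The result is a categorical series, which makes the ordering of the periods explicit.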

Which years are actually present?

In [None]:
DATOP_years = df["DATOP_year"].unique()
DATOP_years

So the data are from the years 2016, 2017, 2018. 

In [None]:
fig, axes = plt.subplots(nrows=1, ncols=len(DATOP_years), figsize=(16, 5), sharey=True)

for idx, year in enumerate(DATOP_years):
    # Filter the DataFrame for the specific year
    df_year = df[df["DATOP_year"] == year]
    
    # Plot the histogram on the respective subplot
    axes[idx].hist(df_year["DATOP_month"], bins=range(1, 14), alpha=0.8, color="blue")
    axes[idx].set_title(f"Flight Distribution for {year}")
    axes[idx].set_xlabel("Month")
    axes[idx].set_xticks(range(1, 13))  # Set x-axis ticks for months
    axes[idx].set_ylabel("Number of Flights")

# Adjust layout
plt.tight_layout()

# Save the combined plot
plt.savefig("./img/month-to-sum-flight-by-year_hist.png", dpi=DPI, bbox_inches="tight")

# Display the plot
plt.show()

In each year, we find one suspicious month in which the number of flights is significantly lower than in the others. Looking at the provided test data set, one sees that the majority of flights for the affected months can be found there (sic!). We drop these months completely:

In [None]:
df = df[~((df["DATOP_month"] == 5) & (df["DATOP_year"] == 2016))]
df = df[~((df["DATOP_month"] == 2) & (df["DATOP_year"] == 2017))]
df = df[~((df["DATOP_month"] == 9) & (df["DATOP_year"] == 2018))]


# Also drop flights whose departure and arrival airports coincide
df = df[df["DEPSTN"] != df["ARRSTN"]]
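
We identified the under-represented months by eye from the histograms. A programmatic check could look like the following sketch; the counts are made up, and the cutoff of half the median monthly count is an assumption rather than a rule we applied.

```python
import pandas as pd

# Made-up monthly flight counts for one year (month 5 is under-represented)
counts = pd.Series(
    [900, 880, 910, 905, 120, 930, 950, 940, 915, 890, 870, 900],
    index=range(1, 13),
)

# Flag months with fewer than half the median monthly count (assumed cutoff)
threshold = 0.5 * counts.median()
suspicious = counts[counts < threshold].index.tolist()
print(suspicious)
```

On the real data, `counts` could be obtained per year from `df.groupby("DATOP_month").size()`.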


In [None]:
plt.figure(figsize=(8, 6))
plt.scatter(df["flight_time"], df["target"], color="blue")
plt.xlabel("Flight Time")
plt.ylabel("Delay")
plt.xlim(1,3000)
plt.ylim(1,3000)
plt.title("Delay vs. Flight Time")
plt.show()

In [None]:
for year in DATOP_years:
    plt.figure(figsize=(8, 4))
    df_year = df[df["DATOP_year"] == year]
    df_year.groupby("DATOP_month")["target"].mean().plot(
        kind="line",
        title=f"Monthly Average of Delay for {year}",
        xlabel="Month",
        ylabel="Average of Delay",
    )
    # Use the loop variable `year` in the file name (not the stale `status`)
    plt.savefig(f"./img/month-to-avg-delay-on-year-eq-{year}_line.png", dpi=DPI, bbox_inches="tight")
    plt.show()

## Final Look at the Brushed Data


In [None]:
df.head()


A pair plot gives a feel for the single and pairwise distributions:

In [None]:
if RUN_LONG_COMPUTATIONS:
    sns.pairplot(df)

    plt.savefig("./img/each-vs-each-wrt-distribution_scatterplot.png", dpi=DPI, bbox_inches="tight")
    plt.show()

Inspecting the correlation matrix is never a bad idea:

In [None]:
# Restrict to numeric columns; recent pandas versions raise on non-numeric ones
correlation_matrix = df.corr(numeric_only=True)

# Plot the correlation matrix as a heatmap
plt.figure(figsize=(10, 8))
sns.heatmap(
    correlation_matrix,
    annot=True,
    # cmap="Reds",
    fmt=".2f",
)
plt.title("Correlation Matrix")

plt.savefig("./img/each-vs-each-wrt-correlation_heatmap.png", dpi=DPI, bbox_inches="tight")
plt.show()

But as we have so many categorical features, it makes sense to also compute the predictive power score (PPS) matrix:

In [None]:
if RUN_LONG_COMPUTATIONS:

    cols = [
        col for col in df.columns
        if not col.startswith("DATOP_") and col not in ["ID"]
    ]
    
    df_tmp = df[cols]

    pp_scores = pps.matrix(df_tmp)[["x", "y", "ppscore"]].pivot(
        columns="x", index="y", values="ppscore"
    )

    pp_scores = pp_scores.round(2)

    plt.figure(figsize=(12, 8))

    sns.heatmap(
        pp_scores,
        vmin=0,
        vmax=1,
        # cmap="Reds",
        linewidths=0.5,
        annot=True,
    )

    plt.savefig("./img/each-vs-each-wrt-pp-score_heatmap.png", dpi=DPI, bbox_inches="tight")
    
    plt.show()

## Some Considerations About Flight Delay

In [None]:
# Distribution of delays across departures and arrivals
# Average delay per departure airport
dep_delay = df.groupby("DEPSTN")["target"].mean().sort_values(ascending=False)

plt.figure(figsize=(12, 6))
dep_delay.plot(kind="bar", color="skyblue")
plt.title("Average Delay per Departure Airport")
plt.xlabel("Departure Airport")
plt.ylabel("Average Delay (Minutes)")
plt.xticks(rotation=45)
plt.tight_layout()
plt.savefig(f"./img/dep-to-avg-delay_hist.png", dpi=DPI, bbox_inches="tight")
plt.show()

In [None]:
# Average delay per arrival airport
arr_delay = df.groupby("ARRSTN")["target"].mean().sort_values(ascending=False)

plt.figure(figsize=(12, 6))
arr_delay.plot(kind="bar", color="salmon")
plt.title("Average Delay per Arrival Airport")
plt.xlabel("Arrival Airport")
plt.ylabel("Average Delay (Minutes)")
plt.xticks(rotation=45)
plt.tight_layout()
plt.savefig(f"./img/dest-to-avg-delay_hist.png", dpi=DPI, bbox_inches="tight")
plt.show()

In [None]:
plt.figure(figsize=(14, 6))
sns.boxplot(x="DEPSTN", y="target", data=df)
plt.title("Distribution of Delays per Departure Airport")
plt.xticks(rotation=45)
plt.savefig(f"./img/airport-to-delay_boxplot.png", dpi=DPI, bbox_inches="tight")
plt.show()

In [None]:
# Pivot table for the heatmap
route_delay = df.pivot_table(
    index="ARRSTN", columns="DEPSTN", values="target", aggfunc="mean"
)

plt.figure(figsize=(12, 8))
sns.heatmap(route_delay, cmap="Reds", linewidths=0.5, annot=False)
plt.title("Average Delay Between Departure and Arrival Airports")
# Columns (x-axis) hold DEPSTN, rows (y-axis) hold ARRSTN
plt.xlabel("Departure Airport")
plt.ylabel("Arrival Airport")
plt.tight_layout()
plt.savefig(f"./img/dep-vs-dest-wrt-avg-delay_heatmap.png", dpi=DPI, bbox_inches="tight")
plt.show()

In [None]:
df[df["DEPSTN"] == "TUN"]["target"].hist(bins=30)
plt.title("Delay Distribution for Departure Airport TUN")
plt.xlabel("Delay (Minutes)")
plt.ylabel("Number of Flights")
plt.savefig(f"./img/delay-to-sum-flight-on-dep-eq-TUN_hist.png", dpi=DPI, bbox_inches="tight")
plt.show()

In [None]:
pivot = df.pivot_table(
    index="DEPSTN", columns="ARRSTN", values="target", aggfunc="mean"
)
plt.figure(figsize=(12, 8))
sns.heatmap(pivot, annot=False, cmap="Reds")
plt.title("Average Delay per Route (Minutes)")
plt.xlabel("Arrival Airport")
plt.ylabel("Departure Airport")
plt.savefig(f"./img/dep-vs-dest-wrt-delay_heatmap.png", dpi=DPI, bbox_inches="tight")
plt.show()

In [None]:
top_dep = df["DEPSTN"].value_counts().head(10)
avg_delay_dep = df.groupby("DEPSTN")["target"].mean().loc[top_dep.index]

summary = pd.DataFrame({"Number of Flights": top_dep, "Avg. Delay (Min.)": avg_delay_dep})
print(summary)

### Without Assignment

In [None]:
df["target"].hist(
    bins=100, 
    log=False,
)
plt.xlabel("Delay")
plt.ylabel("Frequency")
plt.xlim(1, 1000)

plt.savefig(f"./img/delay-to-sum-flight_hist.png", dpi=DPI, bbox_inches="tight")
plt.show()

## Baseline Model (a.k.a. A Feeble Try), Version I

**Hypothesis**: The flight delay can be predicted from the Aircraft Code.

In [None]:
# One-hot encode the aircraft code (the hypothesis is about AC, not DATOP_day)
df_encoded = pd.get_dummies(df, columns=["AC"], prefix="AC")

y = df.target
X = df_encoded.drop("target", axis=1)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=RSEED
)

cols = [col for col in df_encoded.columns if col.startswith("AC_")]

X_0 = X_train[cols]
y_0 = y_train
X_1 = X_test[cols]
y_1 = y_test

model = LinearRegression()

model.fit(X_0, y_0)

print("Coefficients:", model.coef_)
print("Intercept:", model.intercept_)
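
A remark on what this baseline can achieve at best: a least-squares fit on the one-hot columns of a single categorical feature can only reproduce the mean delay of each category. The sketch below illustrates this with made-up numbers and plain NumPy. It omits the intercept, so the coefficients themselves are the group means; with an intercept, as in `LinearRegression`, the coefficients differ but the fitted values are the same.

```python
import numpy as np

# Made-up delays for three aircraft codes (labels are illustrative)
codes = np.array(["A", "A", "B", "B", "B", "C"])
delays = np.array([10.0, 30.0, 0.0, 60.0, 90.0, 15.0])

# Full one-hot encoding without an intercept column
levels = np.unique(codes)
onehot = (codes[:, None] == levels[None, :]).astype(float)

# Least-squares coefficients, one per aircraft code
coef, *_ = np.linalg.lstsq(onehot, delays, rcond=None)

# Mean delay per aircraft code for comparison
group_means = np.array([delays[codes == level].mean() for level in levels])
print(coef, group_means)
```

Hence the baseline is essentially a lookup table of average delays per aircraft code.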

To analyse the prediction errors, we plot the predicted values against the actual values for the test set:

In [None]:
y_pred = model.predict(X_1)

mse = mean_squared_error(y_1, y_pred)
rmse = np.sqrt(mse)
r2 = r2_score(y_1, y_pred)

print("Root Mean Squared Error:", rmse)
print("R-squared Score:", r2)

t = 1000

# Plot predicted against actual values; the dashed line marks perfect predictions
plt.figure(figsize=(8, 6))
sns.scatterplot(x=y_1, y=y_pred)
plt.plot([0, t], [0, t], color="red", linestyle="--")
plt.xlabel("Actual Values")
plt.ylabel("Predicted Values")
plt.xlim(0, 3000)
plt.title("Actual vs. Predicted Values")
plt.savefig(f"./img/actual-vs-predicted_dist.png", dpi=DPI, bbox_inches="tight")
plt.show()
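
For reference, the two metrics reported above can be computed by hand: the RMSE is the square root of the mean squared residual, and R² is one minus the ratio of the residual sum of squares to the total sum of squares. A minimal sketch with made-up numbers:

```python
import math

# Made-up actual and predicted delays
y_true = [0.0, 10.0, 20.0, 30.0]
y_hat = [5.0, 10.0, 15.0, 30.0]

n = len(y_true)
mean_true = sum(y_true) / n

# Residual and total sums of squares
ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_hat))
ss_tot = sum((t - mean_true) ** 2 for t in y_true)

rmse_manual = math.sqrt(ss_res / n)
r2_manual = 1 - ss_res / ss_tot
print(rmse_manual, r2_manual)
```

An R² close to zero therefore means that the model is hardly better than always predicting the mean delay.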

## A Model Using CatBoost

As among the features there are many categorical ones, we choose a method that claims to be suited for such situations ...

In [None]:
from catboost import CatBoostRegressor

target_col = "target"
feature_cols = [
    "STATUS",
    "FLTID",
    "AC",
    "flight_time",
    "DEPSTN", 
    "ARRSTN",
    "DATOP_year", 
    "DATOP_month", 
    "DATOP_day",
    "STD",
]

cat_cols = [
    "STATUS",
    "FLTID",
    "AC",
    # "flight_time",
    "DEPSTN", 
    "ARRSTN",
    # "DATOP_year", 
    # "DATOP_month", 
    # "DATOP_day",
    # "STD",
]

y = df[target_col]
X = df[feature_cols]

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=RSEED
)

model = CatBoostRegressor(verbose=0)

# Parameter grid for GridSearchCV
param_grid = {
    # Number of boosting rounds (trees). 
    # Higher values allow the model to capture complex patterns but may increase risk of overfitting.
    "iterations": [300, 500],

    # Learning rate controls the step size in gradient descent.
    # Smaller values slow down learning but reduce the risk of overshooting the optimal solution.
    "learning_rate": [0.01, 0.05, 0.1],
    # "learning_rate": [0.05, 0.1],

    # Maximum depth of each decision tree.
    # Larger values allow the model to capture more intricate feature interactions but may increase overfitting.
    "depth": [8, 10, 12],

    # Specifies the categorical columns in the dataset.
    # CatBoost handles these features differently, such as using target statistics or one-hot encoding.
    "cat_features": [cat_cols],
}

grid_search = GridSearchCV(
    estimator=model,
    param_grid=param_grid,
    scoring="neg_mean_squared_error",
    cv=5, 
    verbose=0,
)

# Fit on the training split only, so that the test split stays unseen
grid_search.fit(X_train, y_train)

best_params = grid_search.best_params_
print("Best parameters:", grid_search.best_params_)

best_model = grid_search.best_estimator_
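
To get a feeling for the cost of this grid search, one can count the model fits it triggers. The sketch below repeats the grid values from the cell above (without the constant `cat_features` entry) and multiplies the number of combinations by the number of folds.

```python
from itertools import product

# Grid values as in the cell above (cat_features omitted, it has one value)
grid = {
    "iterations": [300, 500],
    "learning_rate": [0.01, 0.05, 0.1],
    "depth": [8, 10, 12],
}
n_folds = 5

# Every combination is fitted once per cross-validation fold
combinations = list(product(*grid.values()))
n_fits = len(combinations) * n_folds
print(len(combinations), n_fits)
```

On top of these fits, `GridSearchCV` refits the best combination once on the full data passed to `fit` (its default behaviour).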


In [None]:
y_pred = best_model.predict(X_test)

rmse = np.sqrt(mean_squared_error(y_test, y_pred))
r2 = r2_score(y_test, y_pred)
print(f"RMSE: {rmse}")
print(f"R_2: {r2}")

t = 1000

# Plot predicted against actual values; the dashed line marks perfect predictions
plt.figure(figsize=(8, 6))
sns.scatterplot(x=y_test, y=y_pred)
plt.plot([0, t], [0, t], color="red", linestyle="--")
plt.xlabel("Actual Values")
plt.ylabel("Predicted Values")
plt.xlim(0, 3000)
plt.title("Actual vs. Predicted Values")
plt.savefig(f"./img/actual-vs-predicted-catboost_dist.png", dpi=DPI, bbox_inches="tight")
plt.show()

## A Two-Level Estimator

We try to combine a classifier and a regressor in the following way:
- If the classifier predicts that the delay is below a certain threshold, return zero.
- Otherwise, use the prediction of the regressor that is trained on a curated data set.
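
The combination rule described by the two bullet points can be sketched in isolation as follows. The function name `combine_two_level` and the numbers are made up for illustration; a flag of 1 stands for "delay predicted to reach the threshold".

```python
import numpy as np


def combine_two_level(is_delayed, regressor_pred):
    """Hypothetical helper: zero where the classifier says "below threshold",
    otherwise the regressor's prediction."""
    is_delayed = np.asarray(is_delayed).astype(bool)
    regressor_pred = np.asarray(regressor_pred, dtype=float)
    return np.where(is_delayed, regressor_pred, 0.0)


# Made-up classifier flags and regressor outputs for four flights
combined = combine_two_level([0, 1, 1, 0], [12.0, 240.0, 95.5, 40.0])
print(combined)
```

Multiplying the 0/1 flags by the regressor output gives the same result, as long as the regressor output is finite.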

In [None]:
from catboost import CatBoostRegressor
from catboost import CatBoostClassifier

# Thresholds (in minutes) to try
THRESHS = [75, 100, 125, 150, 600]

for THRESH in THRESHS:
    X_train, X_test, y_train, y_test = train_test_split(
        X, 
        y, 
        test_size=0.2, 
        random_state=RSEED,
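        # Stratify on the binary label "delay reaches the threshold"
        stratify=(y >= THRESH).astype(int),
    )

    # Label 1: delay of at least THRESH minutes; label 0: below the threshold
    p_train = (y_train >= THRESH).astype(int)

    classifier = CatBoostClassifier(
        verbose=0,
        iterations=500,
        learning_rate=0.05,
        depth=1,
        cat_features=cat_cols,
    )

    classifier.fit(X_train, p_train)

    # Caveat: This is a prediction on the training set!
    p_predict_0_ = classifier.predict(X_train)
    p_predict_ = p_predict_0_.astype(bool)

    # The regressor needs at least some flights flagged as delayed
    if not p_predict_.any():
        print(f"THRESH: {THRESH}: no flight flagged as delayed, skipping")
        continue

    # Train the regressor only on flights the classifier flags as delayed
    X_0 = X_train[p_predict_]
    y_0 = y_train[p_predict_]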

    regressor = CatBoostRegressor(
        verbose=0,
        iterations=500,
        learning_rate=0.05,
        depth=4,
        cat_features=cat_cols,
    )

    regressor.fit(X_0, y_0)

    p_pred = classifier.predict(X_test)
    y_pred = p_pred * regressor.predict(X_test)

    rmse = np.sqrt(mean_squared_error(y_test, y_pred))
    r2 = r2_score(y_test, y_pred)
    print(f"THRESH: {THRESH}")
    print(f"pmean (train): {np.mean(p_predict_0_)}")
    print(f"pmean (test): {np.mean(p_pred)}")
    print(f"RMSE: {rmse}")
    print(f"R_2: {r2}")

## MoSe Modelling

In [None]:
# STA_period is not created above; derive it here analogously to STD_period
df2 = df.assign(STA_period=df["STA"].dt.hour.apply(map_hour_to_period))

# One-hot encode each categorical column, dropping the first level
dummy_prefixes = {
    "DATOP_day": "day",
    "DATOP_year": "yr",
    "DATOP_month": "mon",
    "DEPSTN": "dep",
    "ARRSTN": "arr",
    "AC": "ac",
    "STD_period": "std",
    "STA_period": "sta",
}
for column, prefix in dummy_prefixes.items():
    df2 = pd.get_dummies(
        df2, columns=[column], prefix=prefix, drop_first=True, dtype=int
    )

In [None]:
y = df2.target
X = df2.drop("target", axis=1)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=RSEED
)

In [None]:
# The month dummies carry the prefix "mon_" (see the encoding cell above)
prefixes = ["day_", "yr_", "mon_", "ac_", "dep_", "arr_", "std_", "sta_"]

# Collect columns that match those prefixes
feature_cols = [
    col for col in df2.columns if any(col.startswith(p) for p in prefixes)
] + ["flight_time"]

x0 = X_train[feature_cols]
x1 = X_test[feature_cols]


model = LinearRegression()
# model = KNeighborsRegressor(n_neighbors=5)

model.fit(x0, y_train)
y_pred_test = model.predict(x1)

print(np.sqrt(mean_squared_error(y_test, y_pred_test)))
print(r2_score(y_test, y_pred_test))

In [None]:
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score

# Define the model
model = RandomForestRegressor(random_state=42)

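# Define the hyperparameter grid
param_grid = {
    "n_estimators": [100],
    "max_depth": [20, 30],
    "min_samples_split": [5, 10],
    "min_samples_leaf": [2, 5],
    # "auto" is no longer accepted by recent scikit-learn versions;
    # 1.0 (all features) is what it meant for regressors
    "max_features": [1.0, "sqrt"],
}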

# Create GridSearchCV
grid_search = GridSearchCV(
    estimator=model,
    param_grid=param_grid,
    cv=5,  # 5-fold cross-validation
    scoring="neg_mean_squared_error",  # or use 'r2' if you prefer
    n_jobs=-1,  # Use all available cores
    verbose=2,
)

# Fit the grid search to the data
grid_search.fit(x0, y_train)

# Best model
best_model = grid_search.best_estimator_

# Predict on test data
y_pred_test = best_model.predict(x1)

# Evaluate
print("Best Parameters:", grid_search.best_params_)
print("RMSE:", np.sqrt(mean_squared_error(y_test, y_pred_test)))
print("R²:", r2_score(y_test, y_pred_test))

In [None]:
from xgboost import XGBRegressor

# Define the XGBoost model
xgb_model = XGBRegressor(random_state=42, verbosity=0)

# Define the hyperparameter grid
param_grid = {
    'n_estimators': [100],
    'max_depth': [10, 30],
    'learning_rate': [0.01, 0.1],
    'subsample': [0.8, 1.0],
    'colsample_bytree': [0.8, 1.0]
}

# Set up GridSearchCV
grid_search = GridSearchCV(
    estimator=xgb_model,
    param_grid=param_grid,
    cv=5,
    scoring='neg_mean_squared_error',
    n_jobs=-1,
    verbose=1
)

# Fit on training data
grid_search.fit(x0, y_train)

# Best model
best_model = grid_search.best_estimator_

# Predict on test data
y_pred_test = best_model.predict(x1)

# Evaluate
print("Best Parameters:", grid_search.best_params_)
print("RMSE:", np.sqrt(mean_squared_error(y_test, y_pred_test)))
print("R²:", r2_score(y_test, y_pred_test))