## ETL on `OpenPowerlifting` Data

In [2]:
from IPython.display import display, Markdown
import polars as pl
from datetime import datetime as dt

# read configs
import sys
from pathlib import Path

sys.path.append(str(Path().resolve().parent))
from steps import conf

### Loading Data

In [20]:
landing_df = pl.read_parquet(conf.landing_s3_http)

In [22]:
df = pl.read_parquet(conf.raw_s3_http)
print(df.shape)
df.head(5)

(1645568, 44)


name,sex,event,equipment,age,age_class,birth_year_class,division,bodyweight_kg,weight_class_kg,squat1_kg,squat2_kg,squat3_kg,squat4_kg,best3_squat_kg,bench1_kg,bench2_kg,bench3_kg,bench4_kg,best3_bench_kg,deadlift1_kg,deadlift2_kg,deadlift3_kg,deadlift4_kg,best3_deadlift_kg,total_kg,place,dots,wilks,glossbrenner,goodlift,tested,country,state,federation,parent_federation,date,meet_country,meet_state,meet_town,meet_name,year_of_birth,primary_key,origin_country
str,str,str,str,f64,str,str,str,f64,str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,i64,f64,f64,f64,f64,str,str,str,str,str,date,str,str,str,str,i32,str,str
"""Alona Vladi""","""F""","""SBD""","""Raw""",33.0,"""24-34""","""24-39""","""O""",58.3,"""60""",75.0,80.0,-90.0,,80.0,50.0,55.0,60.0,,60.0,95.0,105.0,107.5,,107.5,247.5,1,279.44,282.18,249.42,57.1,"""Yes""","""Russia""",,"""GFP""",,2019-05-11,"""Russia""",,"""Bryansk""","""Open Tournamen…",1986,"""alona-vladi-F-…","""Russia"""
"""Galina Solovya…","""F""","""SBD""","""Raw""",43.0,"""40-44""","""40-49""","""M1""",73.1,"""75""",95.0,100.0,105.0,,105.0,62.5,67.5,-72.5,,67.5,100.0,110.0,-120.0,,110.0,282.5,1,278.95,272.99,240.35,56.76,"""Yes""","""Russia""",,"""GFP""",,2019-05-11,"""Russia""",,"""Bryansk""","""Open Tournamen…",1976,"""galina-solovya…","""Russia"""
"""Daniil Voronin…","""M""","""SBD""","""Raw""",15.5,"""16-17""","""14-18""","""T""",67.4,"""75""",85.0,90.0,100.0,,100.0,55.0,62.5,-65.0,,62.5,90.0,100.0,105.0,,105.0,267.5,1,206.4,206.49,200.45,41.24,"""Yes""","""Russia""",,"""GFP""",,2019-05-11,"""Russia""",,"""Bryansk""","""Open Tournamen…",2004,"""daniil-voronin…","""Russia"""
"""Aleksey Krasov…","""M""","""SBD""","""Raw""",35.0,"""35-39""","""24-39""","""O""",66.65,"""75""",125.0,132.0,137.5,,137.5,115.0,122.5,-127.5,,122.5,150.0,165.0,170.0,,170.0,430.0,1,334.49,334.94,325.32,66.68,"""Yes""","""Russia""",,"""GFP""",,2019-05-11,"""Russia""",,"""Bryansk""","""Open Tournamen…",1984,"""aleksey-krasov…","""Russia"""
"""Margarita Ples…","""M""","""SBD""","""Raw""",26.5,"""24-34""","""24-39""","""O""",72.45,"""75""",80.0,85.0,90.0,,90.0,40.0,50.0,-60.0,,50.0,112.5,120.0,125.0,,125.0,265.0,1,194.46,193.55,187.29,39.34,"""Yes""","""Russia""",,"""GFP""",,2019-05-11,"""Russia""",,"""Bryansk""","""Open Tournamen…",1993,"""margarita-ples…","""Russia"""


In [23]:
renamed_df = df.select(conf.base_columns).rename(conf.base_renamed_columns)
print(renamed_df.shape)
renamed_df.head(5)

(1645568, 24)


date,name,sex,place,age,age_class,bodyweight,event,meet_country,equipment,squat,bench,deadlift,total,wilks,tested,federation,meet_name,country,state,parent_federation,origin_country,primary_key,year_of_birth
date,str,str,i64,f64,str,f64,str,str,str,f64,f64,f64,f64,f64,str,str,str,str,str,str,str,str,i32
2019-05-11,"""Alona Vladi""","""F""",1,33.0,"""24-34""",58.3,"""SBD""","""Russia""","""Raw""",80.0,60.0,107.5,247.5,282.18,"""Yes""","""GFP""","""Open Tournamen…","""Russia""",,,"""Russia""","""alona-vladi-F-…",1986
2019-05-11,"""Galina Solovya…","""F""",1,43.0,"""40-44""",73.1,"""SBD""","""Russia""","""Raw""",105.0,67.5,110.0,282.5,272.99,"""Yes""","""GFP""","""Open Tournamen…","""Russia""",,,"""Russia""","""galina-solovya…",1976
2019-05-11,"""Daniil Voronin…","""M""",1,15.5,"""16-17""",67.4,"""SBD""","""Russia""","""Raw""",100.0,62.5,105.0,267.5,206.49,"""Yes""","""GFP""","""Open Tournamen…","""Russia""",,,"""Russia""","""daniil-voronin…",2004
2019-05-11,"""Aleksey Krasov…","""M""",1,35.0,"""35-39""",66.65,"""SBD""","""Russia""","""Raw""",137.5,122.5,170.0,430.0,334.94,"""Yes""","""GFP""","""Open Tournamen…","""Russia""",,,"""Russia""","""aleksey-krasov…",1984
2019-05-11,"""Margarita Ples…","""M""",1,26.5,"""24-34""",72.45,"""SBD""","""Russia""","""Raw""",90.0,50.0,125.0,265.0,193.55,"""Yes""","""GFP""","""Open Tournamen…","""Russia""",,,"""Russia""","""margarita-ples…",1993


In [24]:
renamed_df.filter(pl.col("name") == "Joshua Luu")

date,name,sex,place,age,age_class,bodyweight,event,meet_country,equipment,squat,bench,deadlift,total,wilks,tested,federation,meet_name,country,state,parent_federation,origin_country,primary_key,year_of_birth
date,str,str,i64,f64,str,f64,str,str,str,f64,f64,f64,f64,f64,str,str,str,str,str,str,str,str,i32
2017-12-09,"""Joshua Luu""","""M""",1,20.0,"""20-23""",58.8,"""SBD""","""Singapore""","""Raw""",183.0,107.5,241.0,531.5,461.83,"""Yes""","""OceaniaPF""","""Oceania Champi…","""Australia""",,"""IPF""","""Singapore""","""joshua-luu-M-1…",1997
2022-07-16,"""Joshua Luu""","""M""",1,24.5,"""24-34""",74.5,"""SBD""","""Australia""","""Raw""",205.0,127.5,262.5,595.0,425.96,"""Yes""","""USAPL""","""Gold Coast Cla…","""Australia""",,,"""Australia""","""joshua-luu-M-1…",1998
2023-09-23,"""Joshua Luu""","""M""",2,25.5,"""24-34""",74.1,"""SBD""","""Australia""","""Raw""",225.0,127.5,265.0,617.5,443.75,"""Yes""","""USAPL""","""Australian Jun…","""Australia""",,,"""Australia""","""joshua-luu-M-1…",1998
2023-04-22,"""Joshua Luu""","""M""",1,25.5,"""24-34""",74.15,"""SBD""","""Australia""","""Raw""",220.0,130.0,270.0,620.0,445.34,"""Yes""","""USAPL""","""Bens Army Serv…","""Australia""",,,"""Australia""","""joshua-luu-M-1…",1998
2019-10-03,"""Joshua Luu""","""M""",1,22.0,"""20-23""",68.75,"""SBD""","""Canada""","""Raw""",200.0,122.0,245.0,567.0,430.87,"""Yes""","""WP""","""World Champion…","""Australia""",,"""WP""","""Singapore""","""joshua-luu-M-1…",1997
2018-10-13,"""Joshua Luu""","""M""",4,21.0,"""20-23""",68.25,"""SBD""","""Australia""","""Raw""",192.5,117.5,232.5,542.5,414.62,"""Yes""","""PA""","""Australian Nat…","""Australia""",,"""WP""","""Singapore""","""joshua-luu-M-1…",1997
2019-04-26,"""Joshua Luu""","""M""",1,21.5,"""20-23""",68.65,"""SBD""","""Australia""","""Raw""",193.0,116.0,226.0,535.0,407.01,"""Yes""","""PA""","""Australian Jun…","""Australia""",,"""WP""","""Australia""","""joshua-luu-M-1…",1998
2017-05-21,"""Joshua Luu""","""M""",2,19.5,"""20-23""",58.5,"""SBD""","""Australia""","""Raw""",180.0,100.0,210.0,490.0,427.81,"""Yes""","""PA""","""Obsidian Showd…","""Australia""",,"""IPF""","""Australia""","""joshua-luu-M-1…",1998
2017-03-19,"""Joshua Luu""","""M""",2,19.5,"""20-23""",58.7,"""SBD""","""Australia""","""Raw""",172.5,85.0,220.0,477.5,415.57,"""Yes""","""PA""","""Obsidian Showd…","""Australia""",,"""IPF""","""Australia""","""joshua-luu-M-1…",1998
2018-12-08,"""Joshua Luu""","""M""",1,21.0,"""20-23""",68.75,"""SBD""","""Australia""","""Raw""",185.0,115.0,235.0,535.0,406.55,"""Yes""","""PA""","""THJE Strength …","""Australia""",,"""WP""","""Singapore""","""joshua-luu-M-1…",1997


In [15]:
# IPF Weight class only for PA after 2018-01-01
DAYS_IN_YEAR = 365.25
# USAPL is it's own federation

cleansed_df = renamed_df.filter((pl.col("event") == "SBD") & (pl.col("tested") == "Yes") & (pl.col("equipment") == "Raw")).with_columns(pl.col("federation").str.to_lowercase())
print(cleansed_df.shape)

# Create feature engineered columns
time_since_last_comp_df = cleansed_df.with_columns((pl.col("date") - pl.col("date").shift(-1)).over("primary_key").alias("time_since_last_comp")).with_columns(pl.col("time_since_last_comp").dt.days())

## Feature engineering
# have to filter out the time_since_last_comp since there might be data entry handling error
fe_df = time_since_last_comp_df.with_columns(
    (pl.col("time_since_last_comp") / DAYS_IN_YEAR).alias("years_since_last_comp"),
    (pl.col("meet_country") == pl.col("origin_country")).alias("is_origin_country"),
    pl.col("date").apply(lambda x: x.toordinal()).alias("date_as_ordinal"),
    pl.col("name").cumcount().over("primary_key").alias("cumulative_comps"),
    pl.when(pl.col("meet_name").str.contains("national"))
    .then("national")
    .otherwise(pl.when(pl.col("meet_name").str.contains("International|World|Commonwealth")).then("international").otherwise("local"))
    .alias("meet_type"),
)

fe_df = fe_df.with_columns(
    ((pl.col("squat") - pl.col("squat").shift(1)) / pl.col("years_since_last_comp")).over("primary_key").alias("squat_progress"),
    ((pl.col("bench") - pl.col("bench").shift(1)) / pl.col("years_since_last_comp")).over("primary_key").alias("bench_progress"),
    ((pl.col("deadlift") - pl.col("deadlift").shift(1)) / pl.col("years_since_last_comp")).over("primary_key").alias("deadlift_progress"),
    ((pl.col("total") - pl.col("total").shift(1)) / pl.col("years_since_last_comp")).over("primary_key").alias("total_progress"),
    ((pl.col("wilks") - pl.col("wilks").shift(1)) / pl.col("years_since_last_comp")).over("primary_key").alias("wilks_progress"),
)

# have to filter out the time_since_last_comp since there might be data entry handling error
fe_df = fe_df.with_columns(
    (pl.col("squat").shift(1)).over("primary_key").alias("previous_squat"),
    (pl.col("bench").shift(1)).over("primary_key").alias("previous_bench"),
    (pl.col("deadlift").shift(1)).over("primary_key").alias("previous_deadlift"),
    (pl.col("total").shift(1)).over("primary_key").alias("previous_total"),
)

(149431, 24)


In [17]:
fe_df.filter(pl.col("name") == "Joshua Luu")

date,name,sex,place,age,age_class,bodyweight,event,meet_country,equipment,squat,bench,deadlift,total,wilks,tested,federation,meet_name,country,state,parent_federation,origin_country,primary_key,year_of_birth,time_since_last_comp,years_since_last_comp,is_origin_country,date_as_ordinal,cumulative_comps,meet_type,squat_progress,bench_progress,deadlift_progress,total_progress,wilks_progress,previous_squat,previous_bench,previous_deadlift,previous_total
date,str,str,i64,f64,str,f64,str,str,str,f64,f64,f64,f64,f64,str,str,str,str,str,str,str,str,i32,i64,f64,bool,i64,u32,str,f64,f64,f64,f64,f64,f64,f64,f64,f64


## Visualisation

In [16]:
import altair as alt
import seaborn as sns

numerical_cols = [
    "age",
    "bodyweight",
    "bodyweight_change",
    "time_since_last_comp",
    "cumulative_comps",
    "total",
    "wilks",
]

# fe_df_numerical = fe_df.select(pl.col(numerical_cols)).sample(5000).to_pandas()

# correlation_df = fe_df_numerical.corr()

In [31]:
### plot jl_df using altair

jl_df = fe_df.filter(pl.col("name") == "Joshua Luu").select(numerical_cols + ["date"]).to_pandas()

jl_df

Unnamed: 0,age,bodyweight,bodyweight_change,time_since_last_comp,cumulative_comps,total,wilks,date
0,19.5,58.7,,,0,477.5,415.57,2017-03-19
1,19.5,58.5,-0.2,63.0,1,490.0,427.81,2017-05-21
2,19.5,59.0,0.5,75.0,2,514.5,445.65,2017-08-04
3,20.0,58.9,-0.1,70.0,3,520.0,451.12,2017-10-13
4,20.0,58.8,-0.1,57.0,4,531.5,461.83,2017-12-09
5,20.5,65.65,6.85,132.0,5,522.5,412.07,2018-04-20
6,20.5,66.05,0.4,105.0,6,547.5,429.63,2018-08-03
7,21.0,68.25,2.2,71.0,7,542.5,414.62,2018-10-13
8,21.0,68.75,0.5,56.0,8,535.0,406.55,2018-12-08
9,21.5,68.65,-0.1,139.0,9,535.0,407.01,2019-04-26


In [41]:
alt.Chart(jl_df).mark_line().encode(x=alt.X("date:T", title="Date"), y=alt.Y("wilks:Q", title="Wilks Score")).properties(title="Wilks Score Over Time")

In [178]:
# Create a scatter plot for each feature against 'total'
plots = [
    alt.Chart(fe_df_numerical)
    .mark_circle()
    .encode(
        x=alt.X(f"{feature}:Q", title=feature),
        y=alt.Y("total:Q", title="Total"),
        tooltip=[feature, "total"],
    )
    .properties(width=200, height=200, title=f"Total vs {feature}")
    for feature in numerical_cols
]

In [179]:
# alt.hconcat(*plots)

## Modelling

In [17]:
modelling_cols = [
    "name",
    "date",
    "bodyweight",
    "age_class",
    "sex",
    "total",
    "time_since_last_comp",
    "bodyweight_change",
    "cumulative_comps",
    "meet_type",
    "previous_squat",
    "previous_bench",
    "previous_deadlift",
    "previous_total",
]

In [18]:
modelling_df = fe_df.select(modelling_cols)

modelling_df.filter(pl.col("name") == "Joshua Luu").sort("date", descending=True)

name,date,bodyweight,age_class,sex,total,time_since_last_comp,bodyweight_change,cumulative_comps,meet_type,previous_squat,previous_bench,previous_deadlift,previous_total
str,date,f64,str,str,f64,i64,f64,u32,str,f64,f64,f64,f64
"""Joshua Luu""",2023-09-23,74.1,"""24-34""","""M""",617.5,154.0,-0.05,14,"""local""",220.0,130.0,270.0,620.0
"""Joshua Luu""",2023-04-22,74.15,"""24-34""","""M""",620.0,280.0,-0.35,13,"""local""",205.0,127.5,262.5,595.0
"""Joshua Luu""",2022-07-16,74.5,"""24-34""","""M""",595.0,392.0,-1.9,12,"""local""",205.0,120.0,255.0,580.0
"""Joshua Luu""",2021-06-19,76.4,"""24-34""","""M""",580.0,625.0,7.65,11,"""local""",200.0,122.0,245.0,567.0
"""Joshua Luu""",2019-10-03,68.75,"""20-23""","""M""",567.0,160.0,0.1,10,"""international""",193.0,116.0,226.0,535.0
"""Joshua Luu""",2019-04-26,68.65,"""20-23""","""M""",535.0,139.0,-0.1,9,"""local""",185.0,115.0,235.0,535.0
"""Joshua Luu""",2018-12-08,68.75,"""20-23""","""M""",535.0,56.0,0.5,8,"""local""",192.5,117.5,232.5,542.5
"""Joshua Luu""",2018-10-13,68.25,"""20-23""","""M""",542.5,71.0,2.2,7,"""local""",190.0,115.0,242.5,547.5
"""Joshua Luu""",2018-08-03,66.05,"""20-23""","""M""",547.5,105.0,0.4,6,"""local""",185.0,112.5,225.0,522.5
"""Joshua Luu""",2018-04-20,65.65,"""20-23""","""M""",522.5,132.0,6.85,5,"""local""",183.0,107.5,241.0,531.5


In [None]:
### Use of MLFlow to track experiments
### XG Boost - Tree based modelling
## Low bias - high variance -> overfitting
## High bias - low variance -> underfitting
## High bias - high variance -> underfitting
## Low bias - low variance -> good fit

## k fold validation - split data into k folds, train on k-1 folds and test on the remaining fold
## k fold validation is used to tune hyperparameters
## can adjust XG boost parameters -> ask ChatGPT

In [19]:
import optuna
import mlflow

mlflow.set_tracking_uri("http://localhost:8080")


def get_or_create_experiment(experiment_name):
    """
    Retrieve the ID of an existing MLflow experiment or create a new one if it doesn't exist.

    This function checks if an experiment with the given name exists within MLflow.
    If it does, the function returns its ID. If not, it creates a new experiment
    with the provided name and returns its ID.

    Parameters:
    - experiment_name (str): Name of the MLflow experiment.

    Returns:
    - str: ID of the existing or newly created MLflow experiment.
    """

    if experiment := mlflow.get_experiment_by_name(experiment_name):
        return experiment.experiment_id
    else:
        return mlflow.create_experiment(experiment_name)


run_name = "first_attempt"
today_date = dt.today().strftime("%Y-%m-%d")
experiment_id = get_or_create_experiment(today_date)

KeyboardInterrupt: 

In [None]:
## Perform XG Boost on data
import xgboost as xgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

# split data into X and y
columns_to_exclude = ["name", "total", "date"]
pre_X = modelling_df.select(pl.exclude(columns_to_exclude)).to_pandas()

# need to convert object columns to categorical
X = pre_X
for col in X.select_dtypes(include="object").columns:
    X[col] = X[col].astype("category")

y = modelling_df.select(["total"]).to_pandas()

In [130]:
# split data into train and test sets
RANDOM_SEED = 7
test_size = 0.33
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=RANDOM_SEED)

In [131]:
# standardise data for XG Boost
# Set the current active MLflow experiment
mlflow.set_experiment(experiment_id=experiment_id)
dtrain = xgb.DMatrix(X_train, label=y_train, enable_categorical=True)
dtest = xgb.DMatrix(X_test, y_test, enable_categorical=True)

In [119]:
# override Optuna's default logging to ERROR only
optuna.logging.set_verbosity(optuna.logging.ERROR)

# define a logging callback that will report on only new challenger parameter configurations if a
# trial has usurped the state of 'best conditions'


def champion_callback(study, frozen_trial):
    """
    Logging callback that will report when a new trial iteration improves upon existing
    best trial values.

    Note: This callback is not intended for use in distributed computing systems such as Spark
    or Ray due to the micro-batch iterative implementation for distributing trials to a cluster's
    workers or agents.
    The race conditions with file system state management for distributed trials will render
    inconsistent values with this callback.
    """

    winner = study.user_attrs.get("winner", None)

    if study.best_value and winner != study.best_value:
        study.set_user_attr("winner", study.best_value)
        if winner:
            improvement_percent = (abs(winner - study.best_value) / study.best_value) * 100
            print(f"Trial {frozen_trial.number} achieved value: {frozen_trial.value} with " f"{improvement_percent: .4f}% improvement")
        else:
            print(f"Initial trial {frozen_trial.number} achieved value: {frozen_trial.value}")

In [120]:
import math


def objective(trial):
    with mlflow.start_run(nested=True):
        # Define hyperparameters
        params = {
            "objective": "reg:squarederror",
            "eval_metric": "rmse",
            "booster": trial.suggest_categorical("booster", ["gbtree", "gblinear", "dart"]),
            "lambda": trial.suggest_float("lambda", 1e-8, 1.0, log=True),
            "alpha": trial.suggest_float("alpha", 1e-8, 1.0, log=True),
        }

        if params["booster"] == "gbtree" or params["booster"] == "dart":
            params["max_depth"] = trial.suggest_int("max_depth", 1, 9)
            params["eta"] = trial.suggest_float("eta", 1e-8, 1.0, log=True)
            params["gamma"] = trial.suggest_float("gamma", 1e-8, 1.0, log=True)
            params["grow_policy"] = trial.suggest_categorical("grow_policy", ["depthwise", "lossguide"])

        # Train XGBoost model
        bst = xgb.train(params, dtrain)
        preds = bst.predict(dtest)
        error = mean_squared_error(y_test, preds)

        # Log to MLflow
        mlflow.log_params(params)
        mlflow.log_metric("mse", error)
        mlflow.log_metric("rmse", math.sqrt(error))

    return error

In [132]:
import matplotlib.pyplot as plt


def plot_feature_importance(model, booster):
    """|
    Plots feature importance for an XGBoost model.

    Args:
    - model: A trained XGBoost model

    Returns:
    - fig: The matplotlib figure object
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    importance_type = "weight" if booster == "gblinear" else "gain"
    xgb.plot_importance(
        model,
        importance_type=importance_type,
        ax=ax,
        title=f"Feature Importance based on {importance_type}",
    )
    plt.tight_layout()
    plt.close(fig)

    return fig


def plot_residuals(model, dvalid, valid_y, save_path=None):
    """
    Plots the residuals of the model predictions against the true values.

    Args:
    - model: The trained XGBoost model.
    - dvalid (xgb.DMatrix): The validation data in XGBoost DMatrix format.
    - valid_y (pd.Series): The true values for the validation set.
    - save_path (str, optional): Path to save the generated plot. If not specified, plot won't be saved.

    Returns:
    - None (Displays the residuals plot on a Jupyter window)
    """

    # Predict using the model
    preds = model.predict(dvalid)

    # Calculate residuals
    residuals = valid_y - preds

    # Set Seaborn style
    sns.set_style("whitegrid", {"axes.facecolor": "#c2c4c2", "grid.linewidth": 1.5})

    # Create scatter plot
    fig = plt.figure(figsize=(12, 8))
    plt.scatter(valid_y, residuals, color="blue", alpha=0.5)
    plt.axhline(y=0, color="r", linestyle="-")

    # Set labels, title and other plot properties
    plt.title("Residuals vs True Values", fontsize=18)
    plt.xlabel("True Values", fontsize=16)
    plt.ylabel("Residuals", fontsize=16)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.grid(axis="y")

    plt.tight_layout()

    # Save the plot if save_path is specified
    if save_path:
        plt.savefig(save_path, format="png", dpi=600)

    # Show the plot
    plt.close(fig)

    return fig

In [134]:
# Initiate the parent run and call the hyperparameter tuning child run logic
with mlflow.start_run(experiment_id=experiment_id, run_name=run_name, nested=True):
    # Initialize the Optuna study
    study = optuna.create_study(direction="minimize")

    # Execute the hyperparameter optimization trials.
    # Note the addition of the `champion_callback` inclusion to control our logging
    study.optimize(objective, n_trials=500, callbacks=[champion_callback])

    mlflow.log_params(study.best_params)
    mlflow.log_metric("best_mse", study.best_value)
    mlflow.log_metric("best_rmse", math.sqrt(study.best_value))

    # Log tags
    mlflow.set_tags(
        tags={
            "project": "powerlifting-ml-progress",
            "optimizer_engine": "optuna",
            "model_family": "xgboost",
            "feature_set_version": 1,
        }
    )

    # Log a fit model instance
    model = xgb.train(study.best_params, dtrain)

    # Log the correlation plot
    # mlflow.log_figure(figure=correlation_plot, artifact_file="correlation_plot.png")

    # Log the feature importances plot
    importances = plot_feature_importance(model, booster=study.best_params.get("booster"))
    mlflow.log_figure(figure=importances, artifact_file="feature_importances.png")

    # Log the residuals plot
    # residuals = plot_residuals(model, dtest, y_test)
    # mlflow.log_figure(figure=residuals, artifact_file="residuals.png")

    artifact_path = "model"

    mlflow.xgboost.log_model(
        xgb_model=model,
        artifact_path=artifact_path,
        input_example=X_train.iloc[[0]],
        model_format="ubj",
        metadata={"model_data_version": 1},
    )

    # Get the logged model uri so that we can load it from the artifact store
    model_uri = mlflow.get_artifact_uri(artifact_path)

Initial trial 0 achieved value: 23594.561003033745
Trial 1 achieved value: 4870.922856412906 with  384.3961% improvement
Trial 10 achieved value: 2875.442942837973 with  69.3973% improvement
Trial 45 achieved value: 2874.3035010567723 with  0.0396% improvement
Trial 48 achieved value: 2869.191642884516 with  0.1782% improvement
Trial 73 achieved value: 2862.4562070955644 with  0.2353% improvement
Trial 157 achieved value: 2859.713236563413 with  0.0959% improvement
Trial 224 achieved value: 2857.9547364292534 with  0.0615% improvement


  input_schema = _infer_schema(input_example)
