In [0]:
"""
This script performs model inference for the Student Success Tool (SST) pipeline.

It loads a pre-trained ML model from MLflow Model Registry, 
reads a processed dataset from Delta Lake, performs inference, calculates SHAP values, 
and writes the predictions back to Delta Lake.  

The notebook is designed to run within a Databricks environment as a job task, leveraging Databricks 
utilities for widget input, job task values, and Spark session management.

This is a POC notebook; it is advised to refactor it into a .py module and add tests before using it in production.
"""

# Import necessary libraries
import logging
import os

import typing as t
import functools as ft
import matplotlib.pyplot as plt
import mlflow
import numpy as np
import pandas as pd
import shap
from databricks.connect import DatabricksSession
from databricks.sdk.runtime import dbutils  # noqa: F401
from pyspark.sql.types import FloatType, StringType, StructField, StructType
from databricks.sdk import WorkspaceClient
from email.headerregistry import Address

# Import project-specific modules
import student_success_tool.dataio as dataio
from student_success_tool.modeling import inference
import student_success_tool.modeling as modeling
from student_success_tool.schemas import pdp as schemas
from student_success_tool import emails

# Disable mlflow autologging (prevents conflicts in Databricks environments)
mlflow.autolog(disable=True)

# Configure logging
logging.basicConfig(level=logging.INFO)
logging.getLogger("py4j").setLevel(logging.WARNING)  # Suppress py4j logging

# --- Spark Session Initialization ---
# Best effort: when no session can be created, `spark_session` is None and the
# main inference section below logs an error instead of running.
try:
    spark_session = DatabricksSession.builder.getOrCreate()
except Exception:
    logging.warning("Unable to create Spark session; are you in a Databricks runtime?")
    spark_session = None

# --- Configuration ---
# Input parameters (from Databricks widgets)
DB_workspace = dbutils.widgets.get("DB_workspace")
catalog = DB_workspace  # the workspace name doubles as the Unity Catalog catalog name
institution_name = dbutils.widgets.get("databricks_institution_name")
sst_job_id = dbutils.widgets.get("db_run_id")
model_name = dbutils.widgets.get("model_name")
model_version = dbutils.widgets.get("version_id")
model_type = dbutils.widgets.get("model_type")
notif_email = dbutils.widgets.get("notification_email")

# Secrets from Databricks (never hard-code credentials in the notebook)
w = WorkspaceClient()
MANDRILL_USERNAME = w.dbutils.secrets.get(scope="sst", key="MANDRILL_USERNAME")
MANDRILL_PASSWORD = w.dbutils.secrets.get(scope="sst", key="MANDRILL_PASSWORD")
SENDER_EMAIL = Address("Datakind Info", "help", "datakind.org")
DK_CC_EMAIL = 'education@datakind.org'

# --- Unity Catalog schemas ---
# Inputs and outputs both live in the institution's silver schema; the
# registered model lives in its gold schema.
read_schema = f"{institution_name}_silver"
write_schema = f"{institution_name}_silver"
model_schema = f"{institution_name}_gold"


# --- Institution Configuration ---
cfg = dataio.read_config(
    f"/Volumes/{DB_workspace}/{institution_name}_gold/gold_volume/configuration_files/{institution_name}_{model_name}_configuration_file.toml",
    schema=schemas.PDPProjectConfig,
)

# --- Model Configuration ---
# NOTE(review): the "graduation" key is hard-coded even though `model_name` is
# a job parameter -- confirm this is intended for every model this job serves.
experiment_id = cfg.models["graduation"].experiment_id
model_run_id = cfg.models["graduation"].run_id
# Use the registry version requested through the `version_id` widget. The
# version used to be hard-coded to 1, which silently ignored the parameter;
# "1" is kept as the fallback when the widget is left blank.
registered_model_version = (model_version or "").strip() or "1"
model_uri = (
    f"models:/{catalog}.{model_schema}.{model_name}/{registered_model_version}"
)


# --- Load features table ---
features_table = dataio.read_features_table("assets/pdp/features_table.toml")


# --- Helper Functions ---
def mlflow_load_model(model_uri: str, model_type: str):
    """Load a model from MLflow with the loader that matches ``model_type``.

    Args:
        model_uri: MLflow URI of the model to load.
        model_type: One of "sklearn", "xgboost", "lightgbm" or "pyfunc".
            Any other value falls back to the generic pyfunc loader.

    Returns:
        Whatever the selected ``load_model`` function returns for ``model_uri``.
    """
    # Flavor name -> MLflow loading function.
    loaders_by_flavor = {
        "sklearn": mlflow.sklearn.load_model,
        "xgboost": mlflow.xgboost.load_model,
        "lightgbm": mlflow.lightgbm.load_model,
        "pyfunc": mlflow.pyfunc.load_model,
    }
    # Unrecognized flavors are loaded through the generic pyfunc interface.
    fallback_loader = mlflow.pyfunc.load_model
    loader = loaders_by_flavor.get(model_type, fallback_loader)

    loaded = loader(model_uri)
    logging.info("MLflow '%s' model loaded from '%s'", model_type, model_uri)
    return loaded


def predict_proba(
    X,
    model,
    *,
    feature_names: t.Optional[list[str]] = None,
    pos_label: t.Optional[bool | str] = None,
) -> np.ndarray:
    """Predicts probabilities using the provided model."""

    if feature_names is None:
        feature_names = model.named_steps["column_selector"].get_params()["cols"]
    if not isinstance(X, pd.DataFrame):
        X = pd.DataFrame(data=X, columns=feature_names)
    else:
        assert X.shape == len(feature_names)
    pred_probs = model.predict_proba(X)
    if pos_label is not None:
        return pred_probs[:, model.classes_.tolist().index(pos_label)]
    else:
        return pred_probs


# --- Main Inference Logic ---
if spark_session:
    # --- Data Loading ---
    # Processed dataset for this job run, read from the institution's silver schema.
    df_processed_dataset = dataio.from_delta_table(
        f"{catalog}.{read_schema}.{sst_job_id}_processed_dataset",
        spark_session=spark_session,
    )
    unique_ids = df_processed_dataset[cfg.student_id_col]
    # Training data is used further down as the SHAP reference (background) sample.
    df_train = modeling.evaluation.extract_training_data_from_model(experiment_id)

    # --- Model Loading ---
    loaded_model = mlflow_load_model(model_uri, model_type)

    # --- Inference Parameters ---
    inference_params = {
        "num_top_features": 5,
        "min_prob_pos_label": 0.5,
    }

    # --- Feature Selection ---
    # Models with a "column_selector" step expose their input columns there;
    # otherwise fall back to the input schema stored in the MLflow model metadata.
    try:
        model_feature_names = loaded_model.named_steps["column_selector"].get_params()[
            "cols"
        ]
    except AttributeError:
        model_feature_names = loaded_model.metadata.get_input_schema().input_names()
    df_serving_dataset = df_processed_dataset[model_feature_names]

    # --- Write Inference Dataset ---
    inference_dataset_path = f"{catalog}.{write_schema}.{sst_job_id}_inference_dataset"
    dataio.to_delta_table(
        df_serving_dataset, inference_dataset_path, spark_session=spark_session
    )

    # --- Prediction ---
    df_predicted_dataset = df_serving_dataset.copy()
    df_predicted_dataset["predicted_label"] = loaded_model.predict(df_serving_dataset)
    try:
        # NOTE(review): column 1 is assumed to be the positive class here, while
        # the SHAP explainer below resolves the column from cfg.pos_label -- confirm.
        df_predicted_dataset["predicted_prob"] = loaded_model.predict_proba(
            df_serving_dataset
        )[:, 1]
    except AttributeError:
        logging.error(
            "Model does not have predict_proba method. Skipping probability prediction."
        )

    # --- Write Predicted Dataset ---
    predicted_dataset_path = f"{catalog}.{write_schema}.{sst_job_id}_predicted_dataset"
    dataio.to_delta_table(
        df_predicted_dataset,
        predicted_dataset_path,
        spark_session=spark_session,
    )
    logging.info("Predictions saved to: %s", predicted_dataset_path)

    # --- Email notify users ---
    # Uncomment below once we want to enable CC'ing to DK's email.
    # emails.send_inference_kickoff_email(SENDER_EMAIL, [notif_email], [DK_CC_EMAIL], MANDRILL_USERNAME, MANDRILL_PASSWORD)
    emails.send_inference_kickoff_email(
        SENDER_EMAIL, [notif_email], [], MANDRILL_USERNAME, MANDRILL_PASSWORD
    )

    # --- SHAP Values Calculation ---
    # NOTE(review): raises KeyError if the model had no predict_proba above.
    pred_probs = df_predicted_dataset["predicted_prob"]

    # TODO: Consider saving the explainer during training
    # TODO: Pedro's note: Consider getting shap_ref_data_size from a workflow parameter or from the toml config file
    shap_ref_data_size = 200
    # First row of DataFrame.mode() = the most frequent value of each column.
    # The `[0]` was missing, so the bare `.iloc` indexer object (not a Series)
    # was passed to fillna() and to the SHAP UDF.
    train_mode = df_train.mode().iloc[0]
    # Reference sample for the explainer: at most `shap_ref_data_size` training
    # rows, nulls filled with the per-column mode, restricted to model features.
    df_ref = (
        df_train.sample(
            n=min(shap_ref_data_size, df_train.shape[0]),
            random_state=cfg.random_state,
        )
        .fillna(train_mode)
        .loc[:, model_feature_names]
    )

    explainer = shap.explainers.KernelExplainer(
        ft.partial(
            predict_proba,
            model=loaded_model,
            feature_names=model_feature_names,
            pos_label=cfg.pos_label,
        ),
        df_ref,
        link="identity",
    )

    # Output schema of the SHAP UDF: the student id plus one float per feature.
    shap_schema = StructType(
        [StructField(cfg.student_id_col, StringType(), nullable=False)]
        + [StructField(col, FloatType(), nullable=False) for col in model_feature_names]
    )

    # Compute SHAP values in parallel on Spark, then collect to pandas and
    # restore the row order of the processed dataset via the student id.
    df_shap_values = (
        spark_session.createDataFrame(
            df_processed_dataset.reindex(
                columns=model_feature_names + [cfg.student_id_col]
            )
        )
        .repartition(spark_session.sparkContext.defaultParallelism)
        .mapInPandas(
            ft.partial(
                inference.calculate_shap_values_spark_udf,
                student_id_col=cfg.student_id_col,
                model_features=model_feature_names,
                explainer=explainer,
                mode=train_mode,
            ),
            schema=shap_schema,
        )
        .toPandas()
        .set_index(cfg.student_id_col)
        .reindex(df_processed_dataset[cfg.student_id_col])
        .reset_index(drop=False)
    )

    # --- SHAP Summary Plot ---
    # NOTE: to change colors https://stackoverflow.com/questions/60153036/changing-the-gradient-color-of-shap-summary-plot-to-specific-2-or-3-rgb-grad
    shap.summary_plot(
        df_shap_values.loc[:, model_feature_names].to_numpy(),
        df_serving_dataset.loc[:, model_feature_names],
        class_names=loaded_model.classes_,
        max_display=20,
        show=False,
    )
    # show=False leaves the figure open so it can be grabbed and saved below.
    shap_fig = plt.gcf()

    # # --- Log SHAP Plot to MLflow ---
    # with mlflow.start_run(run_id=model_run_id) as run:
    #     mlflow.log_figure(
    #         shap_fig,
    #         f"shap_summary_dataset_name_dataset_"
    #         f"{df_ref.shape}_ref_rows.png",
    #     )

    # --- Feature Selection for Display ---
    result = inference.select_top_features_for_display(
        df_serving_dataset,
        unique_ids,
        pred_probs,
        df_shap_values[model_feature_names].to_numpy(),
        n_features=inference_params["num_top_features"],
        features_table=features_table,
        needs_support_threshold_prob=inference_params["min_prob_pos_label"],
    )
    # Write the top-features result to Delta Lake.
    shap_results_path = f"{catalog}.{write_schema}.{sst_job_id}_shap_results_dataset"
    dataio.to_delta_table(result, shap_results_path, spark_session=spark_session)

    # --- Save Results to ext/ folder in Gold volume. ---
    # Specify where in the gold volume these output files should be stored.
    result_path = f"/Volumes/{DB_workspace}/{institution_name}_gold/gold_volume/inference_jobs/{sst_job_id}/ext/"
    os.makedirs(result_path, exist_ok=True)

    # Write the DataFrame to CSV in the specified volume.
    # Use the session created in this notebook: the bare name `spark` is never
    # defined in this file and only exists where the runtime injects it.
    spark_df = spark_session.createDataFrame(result)
    # Note that this writes multiple files under the save() parameter as a directory.
    (
        spark_df.coalesce(1)
        .write.format("csv")
        .option("header", "true")
        .mode("overwrite")
        .save(f"{result_path}inference_output")
    )

    # Write the SHAP chart png to the volume
    shap_fig.savefig(f"{result_path}shap_chart.png")

else:
    logging.error("Spark session not initialized.")

: 

In [0]:
# TODO there are model dependencies that need to be installed at runtime
# This was the error received (although it still worked)
"""
- mlflow (current: 2.20.0, required: mlflow==2.19.0)
To fix the mismatches, call `mlflow.pyfunc.get_model_dependencies(model_uri)` to fetch the model's environment and install dependencies using the resulting environment file.
"""

In [0]:
# TODO save the charts
# # # --- Log SHAP Plot to MLflow ---
# with mlflow.start_run(run_id=model_run_id) as run:
#     mlflow.log_figure(
#         shap_fig,
#         f"shap_summary_dataset_name_dataset_"
#         f"{df_ref.shape}_{sst_job_id}_shap_results.png",
#     )