In [0]:
from pyspark.sql.types import FloatType, StringType, StructField, StructType

In [0]:
%restart_python

In [0]:
"""
This script performs model inference for the Student Success Tool (SST) pipeline.

It loads a pre-trained ML model from MLflow Model Registry,
reads a processed dataset from Delta Lake, performs inference, calculates SHAP values,
and writes the predictions back to Delta Lake.

The script is designed to run within a Databricks environment as a job task, leveraging
Databricks utilities for job task values and Spark session management.

This is a proof-of-concept (POC) script; review and test it before using it in production.
"""

# Import necessary libraries
import logging
import os
import argparse
from joblib import Parallel, delayed

import typing as t
import functools as ft
import matplotlib.pyplot as plt
import mlflow
import numpy as np
import pandas as pd
import shap
from databricks.connect import DatabricksSession


# Import project-specific modules
import student_success_tool.dataio as dataio
from student_success_tool.modeling import inference
import student_success_tool.modeling as modeling
from student_success_tool.schemas import pdp as schemas
# from student_success_tool.pipeline_utils.plot import plot_shap_beeswarm
# from pipelines.tasks.utils import emails


# Disable mlflow autologging (prevents conflicts in Databricks environments)
mlflow.autolog(disable=True)

# Configure logging: root logger at INFO, py4j limited to WARNING and above.
logging.basicConfig(level=logging.INFO)
logging.getLogger("py4j").setLevel(logging.WARNING)  # Suppress py4j logging


class ModelInferenceTask:
    """Encapsulates the model inference logic for the SST pipeline.

    Configuration is read from notebook-level globals rather than constructor
    arguments. These names must be defined before the task is created and run:
    ``toml_file_path``, ``model_type``, ``processed_dataset_path``,
    ``DB_workspace``, ``databricks_institution_name``, ``db_run_id`` and
    ``job_root_dir``.
    """

    # Number of training rows sampled as the SHAP background/reference set.
    SHAP_REF_DATA_SIZE = 200
    # Approximate number of rows handed to each joblib task when explaining.
    SHAP_ROWS_PER_CHUNK = 4

    def __init__(self):
        """Creates the Spark session and loads the project configuration.

        Reads the global ``toml_file_path``.
        """
        self.spark_session = self.get_spark_session()
        self.cfg = self.read_config(toml_file_path)

    def get_spark_session(self) -> DatabricksSession:
        """Creates (or reuses) a Databricks Spark session.

        Returns:
            The session returned by ``DatabricksSession.builder.getOrCreate()``.

        Raises:
            Exception: Whatever session creation raised, re-raised after logging.
        """
        try:
            spark_session = DatabricksSession.builder.getOrCreate()
            logging.info("Spark session created successfully.")
            return spark_session
        except Exception:
            logging.error("Unable to create Spark session.")
            raise

    def read_config(self, toml_file_path: str) -> schemas.PDPProjectConfig:
        """Reads the institution's model configuration file.

        Args:
            toml_file_path: Path of the TOML configuration file.

        Returns:
            The configuration parsed against ``schemas.PDPProjectConfig``.

        Raises:
            FileNotFoundError: If the file does not exist.
            Exception: Any other read/validation error, re-raised after logging.
        """
        try:
            return dataio.read_config(toml_file_path, schema=schemas.PDPProjectConfig)
        except FileNotFoundError:
            logging.error("Configuration file not found at %s", toml_file_path)
            raise
        except Exception as e:
            # "%s", not "%e": "%e" expects a number and would make the logging
            # call itself fail to format when given an exception.
            logging.error("Error reading configuration file: %s", e)
            raise

    @staticmethod
    def _get_model_feature_names(model) -> list[str]:
        """Returns the input feature names expected by ``model``.

        Tries the ``column_selector`` step of a pipeline-like model first and
        falls back to the MLflow input schema when ``named_steps`` is missing.
        """
        try:
            return model.named_steps["column_selector"].get_params()["cols"]
        except AttributeError:
            return model.metadata.get_input_schema().input_names()

    def load_mlflow_model(self) -> mlflow.pyfunc.PyFuncModel:
        """Loads the "graduation" model logged under the configured MLflow run.

        The loader is chosen from the global ``model_type``; unrecognised types
        fall back to the generic pyfunc loader.

        Returns:
            The loaded model object (its concrete type depends on the loader).

        Raises:
            Exception: Any loading error, re-raised after logging.
        """
        model_uri = f"runs:/{self.cfg.models['graduation'].run_id}/model"
        print(model_uri)

        try:
            load_model_func = {
                "sklearn": mlflow.sklearn.load_model,
                "xgboost": mlflow.xgboost.load_model,
                "lightgbm": mlflow.lightgbm.load_model,
                "pyfunc": mlflow.pyfunc.load_model,  # Default
            }.get(model_type, mlflow.pyfunc.load_model)
            model = load_model_func(model_uri)
            logging.info("MLflow '%s' model loaded from '%s'", model_type, model_uri)
            return model
        except Exception as e:
            logging.error("Error loading MLflow model: %s", e)
            raise  # Critical error; re-raise to halt execution

    def predict(self, model: mlflow.pyfunc.PyFuncModel, df: pd.DataFrame) -> pd.DataFrame:
        """Performs inference and returns the features plus prediction columns.

        Args:
            model: Model exposing ``predict`` and ``predict_proba``.
            df: Data containing at least the model's feature columns.

        Returns:
            A copy of the feature columns with ``predicted_label`` and
            ``predicted_prob`` added.

        Raises:
            AttributeError: If the model has no ``predict_proba`` method.
        """
        model_feature_names = self._get_model_feature_names(model)
        df_serving = df[model_feature_names]

        df_predicted = df_serving.copy()
        df_predicted["predicted_label"] = model.predict(df_serving)
        try:
            # NOTE(review): column 1 is taken as the positive class; confirm it
            # matches cfg.pos_label, which the SHAP path uses instead.
            df_predicted["predicted_prob"] = model.predict_proba(df_serving)[:, 1]
        except AttributeError:
            logging.error("Model does not have predict_proba method; cannot compute probabilities.")
            raise
        return df_predicted

    def write_data_to_delta(self, df: pd.DataFrame, table_name_suffix: str):
        """Writes a DataFrame to a Delta Lake table in the institution's silver schema.

        The table name is built from the globals ``DB_workspace``,
        ``databricks_institution_name`` and ``db_run_id`` plus ``table_name_suffix``.

        Raises:
            Exception: Any write error, re-raised after logging.
        """
        write_schema = f"{databricks_institution_name}_silver"
        table_path = f"{DB_workspace}.{write_schema}.{db_run_id}_{table_name_suffix}"

        try:
            dataio.to_delta_table(df, table_path, spark_session=self.spark_session)
            logging.info("%s data written to: %s", table_name_suffix.capitalize(), table_path)
        except Exception as e:
            logging.error("Error writing %s data to Delta Lake: %s", table_name_suffix, e)
            raise  # Critical, prevent further execution.

    @staticmethod
    def predict_proba(
        X: pd.DataFrame | np.ndarray,
        model: mlflow.pyfunc.PyFuncModel,
        feature_names: t.Optional[list[str]] = None,
        pos_label: t.Optional[bool | str] = None,
    ) -> np.ndarray:
        """Predicts class probabilities, wrapping array input in a DataFrame first.

        Args:
            X: Feature data; anything that is not a DataFrame is converted to
                one with ``feature_names`` as columns.
            model: Model exposing ``predict_proba`` (and ``classes_`` when
                ``pos_label`` is given).
            feature_names: Column names for ``X``; read from the model's
                ``column_selector`` step when omitted.
            pos_label: If given, only that class's probability column is returned.

        Returns:
            All class probabilities, or the ``pos_label`` column only.
        """
        if feature_names is None:
            feature_names = model.named_steps["column_selector"].get_params()["cols"]
        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(data=X, columns=feature_names)
        pred_probs = model.predict_proba(X)
        if pos_label is not None:
            return pred_probs[:, model.classes_.tolist().index(pos_label)]
        return pred_probs

    def _explain_in_parallel(self, X, explainer_object, feature_names, n_jobs=-1):
        """Runs ``explainer_object`` over row chunks of ``X`` with joblib and merges the results."""

        def explain_chunk(explainer, chunk):
            return explainer(chunk)

        # np.array_split rejects 0 sections, which len(X) // k yields for
        # inputs shorter than one chunk; always use at least one chunk.
        n_chunks = max(1, len(X) // self.SHAP_ROWS_PER_CHUNK)
        chunks = np.array_split(X, n_chunks)
        results = Parallel(n_jobs=n_jobs)(
            delayed(explain_chunk)(explainer_object, chunk) for chunk in chunks
        )
        return shap.Explanation(
            values=np.concatenate([r.values for r in results], axis=0),
            data=np.concatenate([r.data for r in results], axis=0),
            feature_names=feature_names,
        )

    def calculate_shap_values(
        self,
        model: mlflow.pyfunc.PyFuncModel,
        df_processed: pd.DataFrame,
        model_feature_names: list[str],
    ) -> shap.Explanation:
        """Calculates SHAP values for every row of ``df_processed``.

        A reference sample of the training data (missing values filled with the
        per-column mode) backs a ``KernelExplainer`` over the positive-class
        probability; rows are then explained in parallel chunks.

        Returns:
            A combined ``shap.Explanation`` with values, data and feature names.

        Raises:
            Exception: Any error during the calculation, re-raised after logging.
        """
        try:
            # TODO: Consider saving the explainer during training.
            experiment_id = self.cfg.models["graduation"].experiment_id
            df_train = modeling.evaluation.extract_training_data_from_model(experiment_id)
            train_mode = df_train.mode().iloc[0]  # first row of the per-column modes
            df_ref = (
                df_train.sample(
                    n=min(self.SHAP_REF_DATA_SIZE, df_train.shape[0]),
                    random_state=self.cfg.random_state,
                )
                .fillna(train_mode)
                .loc[:, model_feature_names]
            )

            explainer_object = shap.explainers.KernelExplainer(
                ft.partial(
                    self.predict_proba,
                    model=model,
                    feature_names=model_feature_names,
                    pos_label=self.cfg.pos_label,
                ),
                df_ref,
                link="identity",
            )

            return self._explain_in_parallel(
                df_processed[model_feature_names], explainer_object, model_feature_names
            )
        except Exception as e:
            logging.error("Error during SHAP value calculation: %s", e)
            raise

    def get_top_features_for_display(self, df_serving, unique_ids, df_predicted, shap_values, model_feature_names):
        """Selects the top SHAP features per student for display and storage.

        ``model_feature_names`` is accepted but not used here.

        Returns:
            The result of ``inference.select_top_features_for_display``, or
            ``None`` when there is no Spark session or the selection fails.
        """
        if not self.spark_session:
            logging.error("Spark session not initialized. Cannot post process shap values.")
            return None

        # --- Load features table ---
        features_table = dataio.read_features_table("assets/pdp/features_table.toml")

        # --- Inference Parameters ---
        inference_params = {
            "num_top_features": 5,
            "min_prob_pos_label": 0.5,
        }

        pred_probs = df_predicted["predicted_prob"]
        # --- Feature Selection for Display ---
        try:
            result = inference.select_top_features_for_display(
                df_serving,
                unique_ids,
                pred_probs,
                shap_values.values,
                n_features=inference_params["num_top_features"],
                features_table=features_table,
                needs_support_threshold_prob=inference_params["min_prob_pos_label"],
            )
            return result

        except Exception as e:
            logging.error("Error top features to display: %s", e)
            return None

    def run(self):
        """Executes the model inference pipeline.

        Reads the processed dataset, predicts, computes SHAP values, selects the
        top features, and writes the results to Delta and to a CSV under
        ``{job_root_dir}/ext/``.

        Raises:
            Exception: If the top-feature selection produced no result, or any
                step re-raises its own error.
        """
        print(processed_dataset_path)
        df_processed = dataio.from_delta_table(processed_dataset_path, spark_session=self.spark_session)
        unique_ids = df_processed[self.cfg.student_id_col]

        model = self.load_mlflow_model()
        # Same lookup (with fallback) that predict() uses, so both agree.
        model_feature_names = self._get_model_feature_names(model)

        # TODO: send the inference kickoff email once notifications are enabled.

        df_predicted = self.predict(model, df_processed)

        # --- SHAP Values Calculation ---
        shap_values = self.calculate_shap_values(model, df_processed, model_feature_names)

        if shap_values is not None:  # Proceed only if SHAP values were calculated
            # TODO: re-enable the SHAP beeswarm chart once the plotting helper is available.
            shap_top_features_results = self.get_top_features_for_display(
                df_processed, unique_ids, df_predicted, shap_values, model_feature_names
            )
            # --- Save Results to ext/ folder in Gold volume. ---
            if shap_top_features_results is not None:
                # Specify the folder for the output files to be stored.
                result_path = f"{job_root_dir}/ext/"
                os.makedirs(result_path, exist_ok=True)
                print('result_path:', result_path)

                # Write the DataFrame to Unity Catalog table
                self.write_data_to_delta(shap_top_features_results, "shap_results_dataset")

                # Write the DataFrame to CSV in the specified volume
                spark_df = self.spark_session.createDataFrame(shap_top_features_results)
                spark_df.coalesce(1).write.format("csv").option("header", "true").mode("overwrite").save(result_path + "inference_output")
            else:
                logging.error("Empty Shap results, cannot create the SHAP chart and table")
                raise Exception("Empty Shap results, cannot create the SHAP chart and table")

        # --- Write Inference Dataset ---
        self.write_data_to_delta(df_processed[model_feature_names], "inference_dataset")

# def parse_arguments() -> argparse.Namespace:
#     """Parses command line arguments."""
#     parser = argparse.ArgumentParser(
#         description="Perform model inference for the SST pipeline.",
#         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
#     )
#     parser.add_argument("--DB_workspace", type=str, required=True, help="Databricks workspace identifier")
#     parser.add_argument("--databricks_institution_name", type=str, required=True, help="Databricks institution name")
#     parser.add_argument("--db_run_id", type=str, required=True, help="Databricks run ID")
#     parser.add_argument("--model_name", type=str, required=True, help="Model name")
#     parser.add_argument("--model_type", type=str, required=True, help="Model type")
#     parser.add_argument("--job_root_dir", required=True, type=str, help="Folder path to store job output files")
#     parser.add_argument("--toml_file_path", type=str, required=True, help="Path to configuration file")
#     parser.add_argument("--processed_dataset_path", type=str, required=True, help="Path to processed dataset table")
#     return parser.parse_args()

# if __name__ == "__main__":


In [0]:
!pip install git+https://github.com/datakind/student-success-tool.git@pdp-inference-pipeline-refactor

In [0]:
# Reuse (or start) the Databricks session, then load the validated course-level table.
spark_session = DatabricksSession.builder.getOrCreate()
course_table_path = "dev_sst_02.uni_of_crystal_testing_bronze.1080441261747080_course_dataset_validated"
course_dataset = dataio.from_delta_table(
    course_table_path,
    spark_session=spark_session,
)

In [0]:
# Show the course_name entries that are exactly "ENGL101".
course_dataset.loc[course_dataset["course_name"] == "ENGL101", "course_name"]

In [0]:
# Load the processed (model-ready) dataset for the scratch experiments below.
spark_session = DatabricksSession.builder.getOrCreate()
processed_dataset_path = "dev_sst_02.uni_of_crystal_testing_silver.1080441261747080_processed_dataset"
df_processed = dataio.from_delta_table(
    processed_dataset_path,
    spark_session=spark_session,
)

In [0]:
# Pick the MLflow loader for the "sklearn" flavour (pyfunc is the fallback) and load the model.
flavor_loaders = {
    "sklearn": mlflow.sklearn.load_model,
    "xgboost": mlflow.xgboost.load_model,
    "lightgbm": mlflow.lightgbm.load_model,
    "pyfunc": mlflow.pyfunc.load_model,  # Default
}
load_model_func = flavor_loaders.get("sklearn", mlflow.pyfunc.load_model)
model = load_model_func("runs:/0b12e0d2eda648a88031636cc21749b6/model")
model_feature_names = model.named_steps["column_selector"].get_params()["cols"]

# Build the SHAP reference set: a sample of the training data with missing
# values replaced by each column's mode, restricted to the model's features.
shap_ref_data_size = 200
experiment_id = "461684477982665"  # Consider refactoring this
df_train = modeling.evaluation.extract_training_data_from_model(experiment_id)
train_mode = df_train.mode().iloc[0]  # first row of the per-column modes
n_ref_rows = min(shap_ref_data_size, df_train.shape[0])
df_ref_sample = df_train.sample(n=n_ref_rows, random_state=1234)
df_ref = df_ref_sample.fillna(train_mode).loc[:, model_feature_names]


def predict_proba(
    X: pd.DataFrame,
    model: mlflow.pyfunc.PyFuncModel,
    feature_names: t.Optional[list[str]] = None,
    pos_label: t.Optional[bool | str] = None,
) -> np.ndarray:
    """Predicts probabilities using the provided model.  Handles data prep."""

    if feature_names is None:
        feature_names = model.named_steps["column_selector"].get_params()["cols"]
    if not isinstance(X, pd.DataFrame):
        X = pd.DataFrame(data=X, columns=feature_names)
    # else: # This check seems unnecessary and potentially incorrect.
    #     assert X.shape[1] == len(feature_names)  # Check *number* of columns.
    pred_probs = model.predict_proba(X)
    if pos_label is not None:
        return pred_probs[:, model.classes_.tolist().index(pos_label)]
    else:
        return pred_probs
    
# Bind the model, feature names and positive label to the module-level
# predict_proba so the explainer gets a callable that takes only the data.
positive_class_proba = ft.partial(
    predict_proba,
    model=model,
    feature_names=model_feature_names,
    pos_label=True,
)
explainer_object = shap.explainers.KernelExplainer(positive_class_proba, df_ref, link="identity")

In [0]:
def create_explanation(model, data_chunk, explainer_object):
    """Applies ``explainer_object`` to ``data_chunk`` and returns the result.

    ``model`` is accepted for signature compatibility but is not used.
    """
    return explainer_object(data_chunk)

# NOTE(review): `os` is already imported in the notebook's main import cell, and
# `num_cpus` is not referenced anywhere else in the visible notebook.
import os

num_cpus = os.cpu_count()

def parallel_explanations_joblib(model, X, explainer_object, n_jobs=-1):
    """Explains ``X`` in parallel row chunks with joblib and merges the results.

    Args:
        model: Unused; kept for signature compatibility.
        X: Rows to explain; split along axis 0 into roughly 4-row chunks.
        explainer_object: Callable explainer applied to each chunk.
        n_jobs: Passed to ``joblib.Parallel`` (``-1`` uses all cores).

    Returns:
        A ``shap.Explanation`` holding the concatenated values and data, with
        feature names taken from the first chunk's result.
    """

    def explain_chunk(explainer, chunk):
        return explainer(chunk)

    # np.array_split rejects 0 sections, which len(X) // 4 yields for fewer
    # than 4 rows; always use at least one chunk.
    n_chunks = max(1, len(X) // 4)
    chunks = np.array_split(X, n_chunks)
    results = Parallel(n_jobs=n_jobs)(
        delayed(explain_chunk)(explainer_object, chunk) for chunk in chunks
    )

    return shap.Explanation(
        values=np.concatenate([r.values for r in results], axis=0),
        data=np.concatenate([r.data for r in results], axis=0),
        feature_names=results[0].feature_names,
    )

def create_explanation(model, data_chunk, explainer_object):
    """Runs ``explainer_object`` on ``data_chunk``; ``model`` is not used.

    NOTE(review): this repeats the definition of the same name earlier in this
    cell; one of the two can be removed.
    """
    return explainer_object(data_chunk)





In [0]:
# Explain every processed row (model features only) using all available cores.
shap_values = parallel_explanations_joblib(
    model,
    df_processed[model_feature_names],
    explainer_object,
    n_jobs=-1,
)

In [0]:
def parallel_shap_spark(spark, model, data, explainer_object, feature_names):
    """Parallelizes SHAP calculations using Spark with a broadcast explainer.

    Args:
        spark: Session exposing ``sparkContext`` and ``createDataFrame``.
            NOTE(review): a Databricks Connect session may not expose
            ``sparkContext``/RDD APIs — confirm before relying on this.
        model: Unused; kept for signature compatibility.
        data: Rows to explain, as a pandas DataFrame or array-like.
        explainer_object: Callable explainer, broadcast to the workers.
        feature_names: Feature names attached to the combined explanation.

    Returns:
        A ``shap.Explanation`` with one row of values/data per input row.
    """
    broadcast_explainer = spark.sparkContext.broadcast(explainer_object)

    def calculate_shap_partition(iterator):
        explainer = broadcast_explainer.value  # the broadcast explainer
        for row in iterator:
            # Keep each sample 2-D (1, n_features) so the per-row results
            # concatenate into (n_rows, n_features) instead of a flat vector.
            sample = np.array(row.features).reshape(1, -1)
            explanation = explainer(sample)
            yield (explanation.values, explanation.data)

    # A DataFrame has no .tolist(); convert through NumPy so both DataFrames
    # and arrays are accepted.
    if isinstance(data, pd.DataFrame):
        rows = data.to_numpy().tolist()
    else:
        rows = np.asarray(data).tolist()

    spark_df = spark.createDataFrame([(row,) for row in rows], ["features"])
    results_rdd = spark_df.rdd.mapPartitions(calculate_shap_partition)
    collected_results = results_rdd.collect()

    combined_values = np.concatenate([r[0] for r in collected_results], axis=0)
    combined_data = np.concatenate([r[1] for r in collected_results], axis=0)
    return shap.Explanation(values=combined_values, data=combined_data, feature_names=feature_names)

# Example Usage:
# spark = SparkSession.builder.appName("ShapSpark").getOrCreate()

#Example Model and Data
# def example_model(X):
#     return np.sum(X, axis=1)
# data = np.random.rand(1000, 10)
# background_data = shap.sample(data, 100)
# explainer = shap.KernelExplainer(model, background_data)
# feature_names = ["feature_" + str(i) for i in range(data.shape[1])]

# Try the Spark-based variant on the processed rows (model features only).
result = parallel_shap_spark(
    spark_session,
    model,
    df_processed[model_feature_names],
    explainer_object,
    model_feature_names,
)
print(result.shape)

# NOTE(review): `spark` is not assigned anywhere in this notebook (only
# `spark_session` is) — presumably the Databricks-provided global. Stopping it
# may break later cells that still need a session; confirm this is intended.
spark.stop()

In [0]:
# Scratch cell: display the explainer's repr for inspection.
explainer_object

In [0]:
# Scratch cell: check whether the explainer can be broadcast to Spark workers.
# NOTE(review): uses `spark`, which this notebook never assigns — confirm it is the ambient session.
spark.sparkContext.broadcast(explainer_object)

In [0]:
# Single-call baseline: explain all rows in one explainer call.
# Rebinds `shap_values`, replacing the result of the joblib cell above.
shap_values = explainer_object(df_processed[model_feature_names])

In [0]:

# NOTE(review): ModelInferenceTask reads globals such as `toml_file_path`,
# `model_type` and `job_root_dir`, which are assigned in the NEXT cell. On a
# fresh kernel this cell fails unless that configuration cell has run first.
task = ModelInferenceTask()
task.run()

In [0]:
# Job configuration read as globals by ModelInferenceTask; run this cell
# before instantiating the task.
DB_workspace = "dev_sst_02"
databricks_institution_name = "uni_of_crystal_testing"
db_run_id = "206958269704514"
model_name = "latest_enrollment_model"
model_type = "sklearn"

# Root of the institution's gold volume, shared by the paths below.
gold_volume_root = f"/Volumes/{DB_workspace}/{databricks_institution_name}_gold/gold_volume"
job_root_dir = f"{gold_volume_root}/inference_jobs/{db_run_id}"
processed_dataset_path = f"{DB_workspace}.{databricks_institution_name}_silver.{db_run_id}_processed_dataset"
# The config file name is derived from model_name rather than repeating the literal.
toml_file_path = (
    f"{gold_volume_root}/configuration_files/"
    f"{databricks_institution_name}_{model_name}_configuration_file.toml"
)

In [0]:
import multiprocessing

In [0]:
import multiprocessing

def my_function(data):
    """Placeholder worker function from a multiprocessing template.

    NOTE(review): `data` is never used and `result` is not defined locally, so
    this returns whatever global `result` exists (or raises NameError if none).
    """
    # Process the data
    return result

# NOTE(review): this cell is an unfinished template and cannot run as written:
# `self` does not exist at notebook level, and `data1`..`data4` and `dataframe`
# are not defined anywhere in the visible notebook.
df_predicted = self.predict(model, df_processed)

if __name__ == '__main__':
    # Prepare the data for processing
    data_list = [data1, data2, data3, data4]

    # Create a pool of processes, using all available cores
    with multiprocessing.Pool(multiprocessing.cpu_count()) as pool:
        # Distribute the data to the processes and collect the results
        # NOTE(review): pool.map expects (func, iterable[, chunksize]); here the
        # first argument is the result of calling predict() with no arguments.
        results = pool.map(self.predict(), model, dataframe)

    # Process the results
    print(results)

In [0]:
# Scratch cell: list the attributes available on shap.Explanation.
dir(shap.Explanation)

In [0]:
from joblib import Parallel, delayed

In [0]:
import shap
import numpy as np
from joblib import Parallel, delayed

def model(X):
    """Toy model for the joblib example: returns the sum of each row of ``X``.

    NOTE(review): defining this rebinds the notebook-level name ``model``,
    which an earlier cell uses for the loaded MLflow model.
    """
    row_totals = np.sum(X, axis=1)
    return row_totals

# Example data: 1000 rows x 10 features, drawn from a seeded generator so the
# cell produces the same values on every run.
X = np.random.default_rng(0).random((1000, 10))

def create_explanation(model, data_chunk, background_data):
    """Builds a KernelExplainer for ``model`` over ``background_data`` and explains ``data_chunk``.

    A new explainer is constructed on every call.
    """
    return shap.KernelExplainer(model, background_data)(data_chunk)

def parallel_explanations_joblib(model, X, background_data, n_jobs=-1):
    """Explains ``X`` in parallel row chunks with joblib and merges the results.

    Args:
        model: Prediction callable handed to ``create_explanation``.
        X: Rows to explain; split along axis 0 into roughly 100-row chunks.
        background_data: Background data handed to ``create_explanation``.
        n_jobs: Passed to ``joblib.Parallel`` (``-1`` uses all cores).

    Returns:
        A ``shap.Explanation`` holding the concatenated values and data, with
        feature names taken from the first chunk's result.
    """
    # np.array_split rejects 0 sections, which len(X) // 100 yields for fewer
    # than 100 rows; always use at least one chunk.
    n_chunks = max(1, len(X) // 100)
    chunks = np.array_split(X, n_chunks)
    results = Parallel(n_jobs=n_jobs)(
        delayed(create_explanation)(model, chunk, background_data) for chunk in chunks
    )

    return shap.Explanation(
        values=np.concatenate([r.values for r in results], axis=0),
        data=np.concatenate([r.data for r in results], axis=0),
        feature_names=results[0].feature_names,
    )

# Sample 100 rows of X as SHAP background data, then explain all of X in parallel.
# `model` here is the toy row-sum function defined in this cell.
background_data = shap.sample(X, 100)
combined_explanation = parallel_explanations_joblib(model, X, background_data)

# Print the shape of the merged explanation.
print(combined_explanation.shape)

In [0]:
import shap
import numpy as np
import multiprocessing
from functools import partial

# Example model
def model(X):
    """Example model for the multiprocessing demo: sums ``X`` along axis 1.

    NOTE(review): this rebinds the notebook-level name ``model`` used by
    earlier cells.
    """
    per_row_sum = np.sum(X, axis=1)
    return per_row_sum

# Example data
# Seeded generator keeps the example data identical across runs.
X = np.random.default_rng(0).random((1000, 10))

def create_explanation(model, data_chunk, background_data):
    """Explains ``data_chunk`` with a KernelExplainer built from ``model`` and ``background_data``."""
    chunk_explainer = shap.KernelExplainer(model, background_data)
    return chunk_explainer(data_chunk)

def parallel_explanations(model, X, background_data, n_processes=None):
    """Explains ``X`` across a multiprocessing pool and merges the results.

    Args:
        model: Prediction callable handed to ``create_explanation``.
        X: Rows to explain; split along axis 0 into one chunk per process.
        background_data: Background data handed to ``create_explanation``.
        n_processes: Number of worker processes; defaults to the CPU count and
            is capped at the number of rows.

    Returns:
        A ``shap.Explanation`` holding the concatenated values and data, with
        feature names taken from the first chunk's result.
    """
    if n_processes is None:
        n_processes = multiprocessing.cpu_count()
    # At least one process, and never more processes than rows (which would
    # hand empty chunks to the explainer).
    n_processes = max(1, min(n_processes, len(X)))

    chunks = np.array_split(X, n_processes)
    partial_create = partial(create_explanation, model, background_data=background_data)
    # The context manager releases the pool even if a worker raises.
    with multiprocessing.Pool(processes=n_processes) as pool:
        results = pool.map(partial_create, chunks)

    # Merge by concatenating the per-chunk arrays, as the joblib variant does.
    # NOTE(review): this replaces the original `shap.Explanation.combine(results)`;
    # no such method is known in the shap API — confirm against the installed version.
    return shap.Explanation(
        values=np.concatenate([r.values for r in results], axis=0),
        data=np.concatenate([r.data for r in results], axis=0),
        feature_names=results[0].feature_names,
    )

# Example usage
# Sample 100 rows of X as SHAP background data, then explain all of X across processes.
background_data = shap.sample(X, 100)
combined_explanation = parallel_explanations(model, X, background_data)

# Print the shape of the merged explanation.
print(combined_explanation.shape)