In [0]:
import mlflow
import mlflow.spark
from mlflow.models.signature import infer_signature
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator, BinaryClassificationEvaluator
from pyspark.sql.functions import col, when, trim, lit
import matplotlib.pyplot as plt
import pandas as pd
import os
import shutil

In [0]:
# Read the Adult census CSV from the Databricks sample datasets and let Spark
# infer column types. No header option is set; column names are assigned in a
# later cell via toDF.
df_raw = (
    spark.read
    .format("csv")
    .option("inferSchema", "true")
    .load("/databricks-datasets/adult/adult.data")
)

In [0]:
# Column names for the 15 positional fields of the raw file, in file order.
cols = [
    "age",
    "workclass",
    "fnlwgt",
    "education",
    "education_num",
    "marital_status",
    "occupation",
    "relationship",
    "race",
    "sex",
    "capital_gain",
    "capital_loss",
    "hours_per_week",
    "native_country",
    "income",
]

# Rename every column of the raw frame positionally.
df = df_raw.toDF(*cols)

In [0]:
# Binary target: 1 when the income field contains ">50K", otherwise 0.
df = df.withColumn("label", when(col("income").contains(">50K"), 1).otherwise(0))

# In every string column, replace "?" and blank values (compared after
# trimming) with null so the incomplete rows can be dropped in one pass.
for col_name, col_type in df.dtypes:
    if col_type != "string":
        continue
    is_missing = trim(col(col_name)).isin(["?", ""])
    df = df.withColumn(col_name, when(is_missing, None).otherwise(col(col_name)))

# Keep only rows that have no null in any column.
df_clean = df.na.drop()

In [0]:
# education is removed because the same information is carried by
# education_num. native_country is removed as well.
# NOTE(review): the previous comment here said native_country could be kept
# as-is for the random forest, but the code drops it — confirm which is intended.
df_clean = df_clean.drop("education", "native_country")

In [0]:
# 80/20 train/test split with a fixed seed.
splits = df_clean.randomSplit(weights=[0.8, 0.2], seed=42)
train, test = splits
print(f"Train Rows: {train.count()}")

In [0]:
def _build_census_pipeline(depth, trees):
    """Build the unfitted indexer -> assembler -> random-forest pipeline.

    Returns:
        A ``(pipeline, feature_cols)`` tuple. ``feature_cols`` is the ordered
        list of VectorAssembler inputs, so entry ``i`` of the fitted forest's
        ``featureImportances`` corresponds to ``feature_cols[i]``.
    """
    cat_cols = ["workclass", "marital_status", "occupation", "relationship", "race", "sex"]
    num_cols = ["age", "fnlwgt", "education_num", "capital_gain", "capital_loss", "hours_per_week"]

    # handleInvalid="skip" makes the indexers and the assembler drop rows they
    # cannot handle instead of failing, so metrics downstream are computed on
    # the rows that survive.
    indexers = [
        StringIndexer(inputCol=c, outputCol=c + "_idx", handleInvalid="skip")
        for c in cat_cols
    ]
    feature_cols = [c + "_idx" for c in cat_cols] + num_cols
    assembler = VectorAssembler(inputCols=feature_cols, outputCol="features", handleInvalid="skip")

    rf = RandomForestClassifier(
        featuresCol="features",
        labelCol="label",
        maxDepth=depth,
        numTrees=trees,
        seed=42,
        maxBins=42,
    )
    return Pipeline(stages=indexers + [assembler, rf]), feature_cols


def _log_spark_model_files(model, local_path="/tmp/census_rf_model"):
    """Best-effort: save the fitted pipeline locally and log the files to MLflow."""
    try:
        if os.path.exists(local_path):
            shutil.rmtree(local_path)  # Clean up old run

        # Save to local linux filesystem
        model.write().overwrite().save(f"file:{local_path}")

        # Log as generic artifact (Bypasses UC Volume check)
        mlflow.log_artifacts(local_path, artifact_path="spark_model")
        print("Model saved via Local Workaround.")
    except Exception as e:
        # Deliberately non-fatal: training and evaluation continue without
        # the saved model artifact.
        print(f"Model Saving Failed (Expected on some clusters): {e}")


def _log_feature_importance_plot(rf_model, feature_cols, depth, plot_path="/tmp/feature_importance.png"):
    """Best-effort: plot the top feature importances and log the PNG to MLflow."""
    try:
        importances = rf_model.featureImportances.toArray()
        fi_df = (
            pd.DataFrame({"Feature": feature_cols, "Importance": importances})
            .sort_values(by="Importance", ascending=False)
            .head(15)
        )

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            ax.barh(fi_df["Feature"], fi_df["Importance"], color="teal")
            ax.set_title(f"Feature Importance (Depth {depth})")
            ax.set_xlabel("Importance")
            ax.invert_yaxis()  # most important feature at the top
            fig.tight_layout()
            fig.savefig(plot_path)
        finally:
            # Always release the figure, even when drawing or saving fails.
            plt.close(fig)

        mlflow.log_artifact(plot_path)
    except Exception as e:
        print(f"Plotting failed: {e}")


def train_census_model(depth, trees):
    """Train, evaluate and log a random-forest income classifier in one MLflow run.

    Fits the pipeline on the notebook-level ``train`` DataFrame and scores it
    on the notebook-level ``test`` DataFrame; both must exist before calling.
    Logs the hyper-parameters, the F1 and AUC metrics, and (best-effort) the
    saved model files and a feature-importance plot.

    Args:
        depth: Maximum tree depth passed to the random forest (``maxDepth``).
        trees: Number of trees passed to the random forest (``numTrees``).

    Returns:
        The MLflow run id of the run created for this training.
    """
    with mlflow.start_run(run_name="census_rf_model") as run:
        mlflow.log_params({
            "model_type": "RandomForest",
            "max_depth": depth,
            "num_trees": trees,
        })

        pipeline, feature_cols = _build_census_pipeline(depth, trees)

        print("Training Pipeline...")
        model = pipeline.fit(train)

        _log_spark_model_files(model)

        # Predict
        predictions = model.transform(test)

        # Column names are spelled out so the evaluators do not silently
        # depend on library defaults.
        eval_f1 = MulticlassClassificationEvaluator(
            labelCol="label", predictionCol="prediction", metricName="f1"
        )
        eval_auc = BinaryClassificationEvaluator(
            labelCol="label", rawPredictionCol="rawPrediction", metricName="areaUnderROC"
        )

        f1 = eval_f1.evaluate(predictions)
        auc = eval_auc.evaluate(predictions)

        print(f"AUC: {auc:.4f}")
        print(f"F1:  {f1:.4f}")

        # Log the metrics in mlflow
        mlflow.log_metrics({"f1_score": f1, "auc": auc})

        # The classifier is the last pipeline stage.
        _log_feature_importance_plot(model.stages[-1], feature_cols, depth)

        return run.info.run_id

In [0]:
# Baseline run: 10 trees, each limited to depth 10.
run_id = train_census_model(depth=10, trees=10)