In [1]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml import Pipeline
from sklearn.datasets import load_iris
import pandas as pd

spark = SparkSession.builder.appName("IrisXGBoost").getOrCreate()

# Load iris data
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['label'] = pd.Series(iris.target)

# Assemble features into a single vector column
assembler = VectorAssembler(
    inputCols=[col for col in iris.feature_names],
    outputCol="features"
)

# Convert pandas DataFrame to Spark DataFrame
spark_df = spark.createDataFrame(df)

# Split the data into training and testing sets
train_df, test_df = spark_df.randomSplit([0.8, 0.2], seed=42)
train_df.printSchema()
# train_df.select("label").distinct().show()

your 131072x1 screen size is bogus. expect trouble
25/03/15 00:11:37 WARN Utils: Your hostname, jerryasus resolves to a loopback address: 127.0.1.1; using 10.255.255.254 instead (on interface lo)
25/03/15 00:11:37 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/03/15 00:11:37 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/03/15 00:11:38 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


root
 |-- sepal length (cm): double (nullable = true)
 |-- sepal width (cm): double (nullable = true)
 |-- petal length (cm): double (nullable = true)
 |-- petal width (cm): double (nullable = true)
 |-- label: long (nullable = true)



In [2]:
# Import the Spark XGBoost classifier
from xgboost.spark import SparkXGBClassifier

# Define a subclass to add the missing __sklearn_tags__ attribute
class FixedSparkXGBClassifier(SparkXGBClassifier):
    __sklearn_tags__ = {}

# Instantiate the XGBoost Spark classifier for multi-class classification
xgb_classifier = SparkXGBClassifier(
    features_col="features",
    label_col="label",
    num_class=3,
    pred_contrib_col="predict_contrib",
    kwargs={"objective": "multi:softprob"},
)

train_dataset = assembler.transform(train_df)
test_dataset = assembler.transform(test_df)

# Train the model
model = xgb_classifier.fit(train_dataset)

# Make predictions on the test set
predictions = model.transform(test_dataset)
predictions.show(5, False)
# predictions.select("features", "label", "prediction", "probabilities").show()


2025-03-15 00:12:00,271 INFO XGBoost-PySpark: _fit Running xgboost-2.0.3 on 1 workers with
	booster params: {'objective': 'multi:softprob', 'device': 'cpu', 'num_class': 3, 'kwargs': {'objective': 'multi:softprob'}, 'nthread': 1}
	train_call_kwargs_params: {'verbose_eval': True, 'num_boost_round': 100}
	dmatrix_kwargs: {'nthread': 1, 'missing': nan}
[00:12:05] task 0 got new rank 0                                    (0 + 1) / 1]
Parameters: { "kwargs" } are not used.

2025-03-15 00:12:06,334 INFO XGBoost-PySpark: _fit Finished xgboost training!   
INFO:XGBoost-PySpark:Do the inference on the CPUs
INFO:XGBoost-PySpark:Do the inference on the CPUs
2025-03-15 00:12:09,038 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2025-03-15 00:12:09,038 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2025-03-15 00:12:09,038 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
INFO:XGBoost-PySpark:Do the inference on the CPUs                               
2025

+-----------------+----------------+-----------------+----------------+-----+-----------------+-----------------------------------------------------------+----------+----------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|sepal length (cm)|sepal width (cm)|petal length (cm)|petal width (cm)|label|features         |rawPrediction                                              |prediction|probability                                                     |predict_contrib                                                                                                                                                                                                                                                           

In [None]:
row = predictions.first()
predict_contrib = row["predict_contrib"]
print("predict_contrib array:", predict_contrib)

# # Get the number of features from one row in test_dataset
# features_vector = test_dataset.select("features").first()["features"]
# num_features = len(features_vector)
# print("Number of features:", num_features)

# # Each block consists of (number of features + bias term)
# block_length = num_features + 1
# print("Per-class block length (features + bias):", block_length)

# # We set the number of classes to 3
# num_classes = 3

# # Validate the total length of predict_contrib array
# expected_length = num_classes * block_length
# print("Expected total length:", expected_length)
# print("Actual total length:", len(predict_contrib))

[Stage 13:>                                                         (0 + 1) / 1]

predict_contrib array: [[0.0, 0.0, 3.1064515113830566, 0.0, -0.11870837211608887], [-0.9541416764259338, -0.44930654764175415, -1.9644542932510376, 0.21524153649806976, 0.6315736174583435], [-0.2419472187757492, -0.7378804683685303, -1.897713303565979, -1.629811406135559, 0.6291914582252502]]


2025-03-15 00:17:55,622 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
                                                                                

Copilot advised that:

```
The idea behind using the margin here comes from how SHAP values work. In XGBoost, when you use the "pred_contrib" output, the contributions are additive. This means that for each class, the sum of the feature contributions plus the bias equals the margin, which is the raw score (logit) before applying any normalization (like softmax). In multi-class classification:

Each class gets its own margin.
The final prediction is the class with the highest margin.
By summing the contributions in each block, you get the margin for that class.
If the first array (block) corresponds to class 0, then for samples predicted as class 0, the sum (margin) of the first block should be the highest among the three.
Thus, checking the summation (margin) is a way to validate that the block order in predict_contrib aligns with the class predictions.
```

TODO: Learn more about margin and SHAP value math theroy.

In [9]:
def get_class_contributions(predict_contrib, num_classes):
    # Divide the flat predict_contrib list into num_classes blocks
    block_size = len(predict_contrib) // num_classes
    return [predict_contrib[i * block_size : (i + 1) * block_size] for i in range(num_classes)]

def sum_block(block):
    total = 0
    for item in block:
        if isinstance(item, list):
            total += sum(item)
        else:
            total += item
    return total

num_classes = 3  # as set in your classifier

# Extract and print contribution block sums for a few predictions
few_rows = predictions.filter(predictions.prediction==1).limit(5).collect()
for idx, row in enumerate(few_rows):
    pc = row["predict_contrib"]
    blocks = get_class_contributions(pc, num_classes)
    # Calculate margin (sum of contributions including bias) for each class using the helper function
    margins = [sum_block(block) for block in blocks]
    predicted = row["prediction"]
    print(f"Row {idx}:")
    print(f"  Class 0 margin: {margins[0]}")
    print(f"  Class 1 margin: {margins[1]}")
    print(f"  Class 2 margin: {margins[2]}")
    print(f"  Predicted class: {predicted}")
    if predicted == 0:
        if margins[0] >= margins[1] and margins[0] >= margins[2]:
            print("  -> First array (class 0) has the highest margin. Validation OK.")
        else:
            print("  -> Unexpected: First array margin is not highest!")
    else:
        print("  -> For non-zero predictions, inspect margins accordingly.")
    print("---------------------------------------------------")

2025-03-15 00:24:34,385 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2025-03-15 00:24:34,401 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2025-03-15 00:24:34,483 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2025-03-15 00:24:34,507 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2025-03-15 00:24:34,514 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2025-03-15 00:24:34,515 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2025-03-15 00:24:37,232 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2025-03-15 00:24:37,233 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2025-03-15 00:24:37,234 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2025-03-15 00:24:37,237 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2025-03-15 00:24:37,400 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2025-03-15 00:24:37,401 INFO XGBoost-PySpar

Row 0:
  Class 0 margin: -2.5659544467926025
  Class 1 margin: 3.484293833374977
  Class 2 margin: -1.4969568401575089
  Predicted class: 1.0
  -> For non-zero predictions, inspect margins accordingly.
---------------------------------------------------
Row 1:
  Class 0 margin: -2.5659544467926025
  Class 1 margin: 3.4083408564329147
  Class 2 margin: -1.4969568401575089
  Predicted class: 1.0
  -> For non-zero predictions, inspect margins accordingly.
---------------------------------------------------
Row 2:
  Class 0 margin: -2.5659544467926025
  Class 1 margin: 3.4083408564329147
  Class 2 margin: -1.4969568401575089
  Predicted class: 1.0
  -> For non-zero predictions, inspect margins accordingly.
---------------------------------------------------
Row 3:
  Class 0 margin: -2.5659544467926025
  Class 1 margin: 1.7605621814727783
  Class 2 margin: -3.482727751135826
  Predicted class: 1.0
  -> For non-zero predictions, inspect margins accordingly.
----------------------------------

2025-03-15 00:24:41,002 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2025-03-15 00:24:41,027 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
                                                                                