Hello, when building machine learning models, your goal is to provide insights to your business team or consumers about specific outcomes. The model offers an estimation or prediction to support the business outcome, but it doesn’t explain how to change that outcome. For instance, if you’ve worked with customer data, you might start by predicting which customers are likely to churn. However, the next logical question from the business would be: “What can we do with this information?”

This leads to the second part of the problem: using predictions to inform actions. For example, if a customer is predicted to be high-risk, how can we reduce their risk? Conversely, if a customer is low-risk, how can we ensure they remain so? These scenarios highlight two types of interventions: reducing high-risk cases and proactively managing low-risk cases.
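
A counterfactual makes the "how can we reduce their risk?" question concrete: it is a changed version of a customer's record for which the model would predict a different outcome. The sketch below is a toy illustration only. The scoring function, its weights, and the feature names `tenure_months` and `support_tickets` are all made up, and the search simply moves one feature in fixed steps until the prediction flips.

```python
import math

# Hypothetical churn model: a hand-written logistic score over two features.
# The weights are invented for illustration, not fitted to any data.
def churn_probability(customer):
    z = -1.0 - 0.05 * customer["tenure_months"] + 0.9 * customer["support_tickets"]
    return 1.0 / (1.0 + math.exp(-z))

def find_counterfactual(customer, feature, step, max_steps=100, threshold=0.5):
    """Move one feature in fixed steps until the prediction drops below threshold."""
    candidate = dict(customer)  # copy so the original record is left untouched
    for _ in range(max_steps):
        if churn_probability(candidate) < threshold:
            return candidate
        candidate[feature] += step
    return None  # no counterfactual found within the search budget

customer = {"tenure_months": 6, "support_tickets": 3}
cf = find_counterfactual(customer, feature="support_tickets", step=-1)
print("original risk:", round(churn_probability(customer), 3))
print("counterfactual:", cf)
```

Here the suggested action would be to bring the customer's open support tickets down. Real counterfactual libraries search over many features at once and respect constraints such as which features can actually be changed.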

Let’s delve into practical tools for this. I have been using scikit-learn as my go-to library for building regression and classification models, and DiCE (Diverse Counterfactual Explanations) to generate counterfactuals. DiCE works well for datasets with a few thousand or tens of thousands of rows, but challenges arise with millions of rows. For instance, generating a counterfactual for a single row might take five seconds. Scaling this to 500 million rows, one row at a time, would take approximately 2.5 billion seconds (roughly 79 years), which is impractical.
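
The back-of-the-envelope estimate is easy to check in a few lines. This sketch uses the five seconds per row and 500 million rows quoted in the text; the worker counts are hypothetical, and the formula assumes perfect scaling with no scheduling or data-transfer overhead, so real runs will be slower.

```python
SECONDS_PER_ROW = 5                 # per-row cost quoted in the text
ROWS = 500_000_000                  # dataset size quoted in the text
SECONDS_PER_DAY = 24 * 3600
SECONDS_PER_YEAR = 365 * SECONDS_PER_DAY

def wall_clock_seconds(rows, seconds_per_row, parallel_workers=1):
    """Ideal wall-clock time, assuming perfect scaling and zero overhead."""
    return rows * seconds_per_row / parallel_workers

serial = wall_clock_seconds(ROWS, SECONDS_PER_ROW)
print(f"1 worker: {serial / SECONDS_PER_YEAR:.1f} years")

for workers in (40, 1_000, 10_000):  # hypothetical cluster sizes
    days = wall_clock_seconds(ROWS, SECONDS_PER_ROW, workers) / SECONDS_PER_DAY
    print(f"{workers} workers: {days:.1f} days")
```

Even under these ideal assumptions, the per-row cost dominates, which is why both parallelism and a cheaper per-row computation matter at this scale.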

To address this, I’ve used Ray to distribute the computation of counterfactuals across a Spark cluster. Ray on Spark (`ray.util.spark`) starts Ray worker nodes on top of the Spark cluster, so you can use the Spark nodes for parallel processing and control how many CPU cores each Ray worker node uses. This setup significantly reduces computation time.

Here are examples to illustrate the solution:

## Example 1: Generating a Counterfactual for a Single Row

In [None]:
#collapse
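import dice_ml
import pandas as pd

# Assumes `clf` is a trained scikit-learn classifier and `df` is the training
# DataFrame (features plus outcome). The column names below are placeholders.
data = dice_ml.Data(dataframe=df, continuous_features=["tenure", "monthly_charges"], outcome_name="churn")
model = dice_ml.Model(model=clf, backend="sklearn")
dice = dice_ml.Dice(data, model, method="random")

# The query instance is a one-row DataFrame of features (no outcome column)
instance = df.drop(columns="churn").iloc[[0]]

# Generate one counterfactual that flips the predicted class
counterfactual = dice.generate_counterfactuals(instance, total_CFs=1, desired_class="opposite")
print(counterfactual.cf_examples_list[0].final_cfs_df)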

## Example 2: Distributing Counterfactual Generation Using Ray

In [None]:
#collapse
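import ray
from ray.util.spark import setup_ray_cluster

def setup_ray_cluster_with_config(num_nodes, cpus_per_node):
    # Start Ray worker nodes on the Spark cluster. These keyword names follow
    # recent Ray releases; older releases used `num_worker_nodes` and
    # `num_cpus_per_node`, so check the signature for your installed version.
    setup_ray_cluster(max_worker_nodes=num_nodes, num_cpus_worker_node=cpus_per_node)
    # Connect this driver process to the cluster that was just started
    ray.init()

# Example: 10 Ray worker nodes with 4 CPU cores each
setup_ray_cluster_with_config(num_nodes=10, cpus_per_node=4)

# `dice` is the explainer from Example 1. `queries` (an assumed name) is a
# pandas DataFrame holding the feature rows to explain.
# Store the explainer once in the object store instead of shipping it with every task.
dice_ref = ray.put(dice)

@ray.remote
def generate_cf(explainer, row):
    # `row` is a one-row DataFrame; return only the counterfactual DataFrame
    result = explainer.generate_counterfactuals(row, total_CFs=1, desired_class="opposite")
    return result.cf_examples_list[0].final_cfs_df

# Parallel processing: one task per row. Iterating over a DataFrame directly
# yields column names, so rows are selected by position instead.
results = ray.get([generate_cf.remote(dice_ref, queries.iloc[[i]]) for i in range(len(queries))])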
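
Submitting one task per row can strain the driver and scheduler once the input reaches hundreds of millions of rows, so a common pattern is to give each task a batch of rows instead. The helper below is a plain-Python sketch of the batching step only. `chunk_indices` and the batch size are hypothetical names and values; in a Ray version, each `(start, stop)` pair would become one remote task that loops over that slice of the input DataFrame.

```python
def chunk_indices(n_rows, batch_size):
    """Yield (start, stop) index pairs that cover range(n_rows) in order."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, n_rows, batch_size):
        # The last batch may be shorter than batch_size
        yield start, min(start + batch_size, n_rows)

batches = list(chunk_indices(n_rows=10, batch_size=4))
print(batches)  # [(0, 4), (4, 8), (8, 10)]
```

The right batch size is a trade-off: larger batches mean fewer tasks to schedule, while smaller batches spread work more evenly and lose less progress when a task fails.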

## Example 3: Using Broadcast Variables in Ray

In [None]:
#collapse
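import ray
import dice_ml

# Ray's counterpart to a Spark broadcast variable is the object store:
# ray.put() stores an object once and returns a reference all tasks can share.
# A Spark broadcast handle (sc.broadcast) is meant to be read inside Spark
# executors, so it is not used from Ray tasks here.
# `model` and `data` are assumed to be the dice_ml.Model and dice_ml.Data objects.
model_ref = ray.put(model)
data_ref = ray.put(data)

# Example function using the shared objects
@ray.remote
def process_row(model_obj, data_obj, row):
    # Ray resolves the references to the stored objects before the task runs
    dice_local = dice_ml.Dice(data_obj, model_obj, method="random")
    result = dice_local.generate_counterfactuals(row, total_CFs=1, desired_class="opposite")
    return result.cf_examples_list[0].final_cfs_df

# `queries` (an assumed name) is a pandas DataFrame of feature rows to explain
results = ray.get([process_row.remote(model_ref, data_ref, queries.iloc[[i]]) for i in range(len(queries))])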

Additionally, Ray supports logging mechanisms to collect logs from Spark nodes and propagate them to centralized logging systems, such as Log4j or custom log collectors. This ensures robust monitoring in production.
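
Whichever collector the logs end up in, the task code has to emit useful records first. The following is a minimal sketch using only Python's standard `logging` module. The wrapper `explain_with_logging` and the logger name are hypothetical, and `explain_fn` stands in for the per-row DiCE call; how these records reach a central system depends on your cluster configuration.

```python
import logging
import time

logger = logging.getLogger("counterfactuals")

def explain_with_logging(row_id, explain_fn):
    """Run explain_fn, log its duration or failure, and return (row_id, result_or_None)."""
    start = time.perf_counter()
    try:
        result = explain_fn()
    except Exception:
        # Record the traceback and continue so one bad row does not stop the batch
        logger.exception("row %s failed", row_id)
        return row_id, None
    logger.info("row %s explained in %.2fs", row_id, time.perf_counter() - start)
    return row_id, result

logging.basicConfig(level=logging.INFO)
ok = explain_with_logging(1, lambda: {"support_tickets": 1})
failed = explain_with_logging(2, lambda: 1 / 0)  # deliberately failing call
```

Returning the row identifier alongside the result makes it possible to retry only the failed rows afterwards.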

In summary, we’ve covered:

- Setting up counterfactual analysis for classification problems.
- Distributing counterfactual generation using Ray on Spark clusters.
- Leveraging Ray’s broadcast features for efficient computation.
- Logging and monitoring distributed tasks in production.

With these techniques, you can efficiently scale counterfactual analysis and integrate it into production systems. Thank you!

