# **Histogram-Based Method for Distributed IMM**

## 1. Histogram Formation

In [2]:
import math
import random
from collections import defaultdict, namedtuple
from pyspark.rdd import RDD
from pyspark.sql import SparkSession

# Base tolerance for floating-point count comparisons; the continuous-split
# code scales it by the number of aggregated points.
EPSILON = 1.0e-9

# Record types shared by the cells below.
# Instance: one input row -> its feature values, label and weight.
Instance = namedtuple("Instance", "features label weight")
# Split: one candidate split. Continuous splits carry `threshold`;
# categorical splits carry a set of category values in `categories`.
Split = namedtuple("Split", "feature_index threshold categories is_continuous")

In [3]:
def samples_fraction_for_find_splits(max_bins: int, num_examples: int) -> float:
    """
    Return the fraction of rows to sample when searching for split candidates.

    The target sample size is ``max(max_bins * max_bins, 10000)`` rows. If the
    dataset is larger than that, the target expressed as a fraction of
    ``num_examples`` is returned; otherwise every row is used (``1.0``).

    :param max_bins: Maximum number of bins used for splitting.
    :param num_examples: Number of examples (rows) in the dataset.
    :return: Sampling fraction; 1.0 when no subsampling is needed.
    """
    wanted = max(max_bins * max_bins, 10000)
    return float(wanted) / num_examples if num_examples > wanted else 1.0

In [4]:
def find_splits_for_continuous_feature_values(
    feature_values,
    is_continuous: bool,
    num_splits: int,
    max_bins: int,
    total_num_examples: int
) -> list:
    """
    Build a value histogram for one continuous feature and compute its splits.

    Every element of ``feature_values`` adds 1 to the count of its value. The
    resulting ``{value: count}`` dict and the number of elements seen are then
    passed to ``find_splits_for_continuous_feature_weights``, whose result is
    returned unchanged.

    :param feature_values: Iterable of numeric (float) feature values.
    :param is_continuous: Must be true; only continuous features are supported.
    :param num_splits: The maximum number of splits allowed for this feature.
    :param max_bins: Used for binning / quantile calculation logic.
    :param total_num_examples: Total number of examples, forwarded as-is.
    :return: A list of Split objects representing split thresholds.
    :raises ValueError: If ``is_continuous`` is false.
    """
    if not is_continuous:
        raise ValueError("find_splits_for_continuous_feature_values can only be used with a continuous feature.")

    # Unit-weight histogram: value -> number of occurrences.
    histogram = defaultdict(int)
    seen = 0
    for seen, value in enumerate(feature_values, start=1):
        histogram[value] += 1

    return find_splits_for_continuous_feature_weights(
        part_value_weights=dict(histogram),
        count=seen,
        is_continuous=is_continuous,
        num_splits=num_splits,
        max_bins=max_bins,
        total_num_examples=total_num_examples,
    )

In [5]:
def find_splits_for_continuous_feature_weights(
    part_value_weights: dict,
    count: int,
    is_continuous: bool,
    num_splits: int,
    max_bins: int,
    total_num_examples: float
) -> list:
    """
    Computes split thresholds for a single continuous feature.

    Follows Spark MLlib's ``findSplitsForContinuousFeature``. Callers may omit
    zero feature values when aggregating (``find_splits_for_continuous_features``
    filters them out). The mass missing from ``part_value_weights``, relative to
    the expected sample size ``fraction * total_num_examples``, is credited back
    to the value ``0.0`` before thresholds are chosen. As in Spark, sampling
    variance can make this add a small spurious zero count.

    :param part_value_weights: Dict of { feature_value -> count }; zeros may be omitted.
    :param count: Number of data points aggregated into ``part_value_weights``.
                  Only used to scale the floating-point tolerance.
    :param is_continuous: Boolean indicating if the feature is continuous.
    :param num_splits: The maximum number of splits allowed for this feature.
    :param max_bins: Used for the sampling-fraction calculation.
    :param total_num_examples: Total (weighted) number of examples in the full,
        unsampled dataset. The sampling fraction is recomputed from it. A value
        that is non-positive or smaller than the observed mass falls back to
        the observed mass, in which case no zero count is added.
        NOTE: ``find_splits`` derives its sampling fraction from the unweighted
        row count; the two agree when all instance weights are 1.
    :return: A list of Split objects (``feature_index=-1``, to be set by the caller).
    :raises ValueError: If ``is_continuous`` is false.
    """
    if not is_continuous:
        raise ValueError("find_splits_for_continuous_feature_weights can only be used with a continuous feature.")

    # If no values exist, return empty.
    if not part_value_weights:
        return []

    # Observed mass for this feature (zeros possibly excluded).
    part_num_samples = sum(part_value_weights.values())

    # Expected mass of the sample for this feature, zeros included. The
    # fraction must come from the full dataset size. Deriving it from `count`
    # (the old behaviour) made the expectation equal the observed mass, so
    # omitted zeros were never restored. The max() keeps an under-reported
    # total from shrinking the stride below the data actually present.
    fraction = samples_fraction_for_find_splits(
        max_bins=max_bins,
        num_examples=total_num_examples
    )
    weighted_num_samples = max(fraction * float(total_num_examples), float(part_num_samples))

    # Tolerance scaled by the number of aggregated points (Spark uses factor 100).
    tolerance = EPSILON * count * 100

    # Credit the missing mass to 0.0 (the values filtered out upstream).
    if weighted_num_samples - part_num_samples > tolerance:
        part_value_weights = dict(part_value_weights)  # Copy to avoid mutating the caller's dict
        additional_count = weighted_num_samples - part_num_samples
        part_value_weights[0.0] = part_value_weights.get(0.0, 0.0) + additional_count

    # Distinct values in ascending order: [(value, count), ...]
    sorted_pairs = sorted(part_value_weights.items(), key=lambda x: x[0])

    # Number of possible splits is the number of gaps between sorted values.
    possible_splits = len(sorted_pairs) - 1

    if possible_splits == 0:
        # All feature values are the same => no splits
        return []

    if possible_splits <= num_splits:
        # Budget covers every gap: return all midpoints.
        splits = []
        for i in range(1, len(sorted_pairs)):
            left_val = sorted_pairs[i - 1][0]
            right_val = sorted_pairs[i][0]
            midpoint = (left_val + right_val) / 2.0
            splits.append(Split(feature_index=-1, threshold=midpoint, categories=None, is_continuous=True))  # feature_index set by caller
        return splits

    # Otherwise place thresholds near equal-mass quantiles.
    stride = weighted_num_samples / (num_splits + 1)

    splits_builder = []
    index = 1
    current_count = sorted_pairs[0][1]  # mass of values visited so far
    target_count = stride  # next cumulative mass a threshold should sit at

    while index < len(sorted_pairs):
        previous_count = current_count
        current_count += sorted_pairs[index][1]
        previous_gap = abs(previous_count - target_count)
        current_gap = abs(current_count - target_count)

        # If including the current value moves the cumulative mass further
        # from the target, the boundary before it is the best threshold.
        if previous_gap < current_gap:
            left_val = sorted_pairs[index - 1][0]
            right_val = sorted_pairs[index][0]
            midpoint = (left_val + right_val) / 2.0
            splits_builder.append(Split(feature_index=-1, threshold=midpoint, categories=None, is_continuous=True))
            target_count += stride

        index += 1

    return splits_builder

In [6]:
def find_splits_for_categorical_feature(
    categories: list,
    counts: list,
    is_unordered: bool,
    num_splits: int
) -> list:
    """
    Computes candidate splits for a single categorical feature.

    Each returned Split holds, in ``categories``, the set of category values on
    one side of the split; ``feature_index`` is left at -1 for the caller.

    * Unordered features get one-vs-rest candidates ``{cat}``, at most
      ``num_splits`` of them. They are taken in sorted category order when the
      categories are mutually comparable (otherwise in the given order), so the
      result does not depend on the order ``categories`` arrives in. With
      exactly two categories only one candidate is produced, because ``{b}``
      is the mirror image of ``{a}``.
    * Ordered features are sorted by category value and cut at boundaries
      between consecutive categories, so every split set is a non-empty,
      proper prefix of the sorted categories. If there are at most
      ``num_splits`` boundaries, all are returned. Otherwise a boundary is
      chosen whenever stopping there keeps the cumulative count closest to the
      next multiple of ``sum(counts) / (num_splits + 1)``, which is the same
      rule used for continuous features.

    No split is returned when there are fewer than two categories or when
    ``num_splits`` is not positive, because no useful partition exists.

    :param categories: List of distinct category names or indices.
    :param counts: List of counts corresponding to each category.
    :param is_unordered: Boolean indicating if the categorical feature is unordered.
    :param num_splits: The maximum number of splits allowed for this feature.
    :return: A list of Split objects representing candidate splits.
    """
    # A single category cannot be partitioned, and a non-positive budget
    # allows no splits at all.
    if num_splits <= 0 or len(categories) < 2:
        return []

    if is_unordered:
        try:
            candidates = sorted(categories)
        except TypeError:
            # Mixed, mutually non-comparable category types: keep input order.
            candidates = list(categories)
        if len(candidates) == 2:
            # {a} and {b} describe the same partition with the sides swapped.
            candidates = candidates[:1]
        return [
            Split(feature_index=-1, threshold=None, categories={cat}, is_continuous=False)
            for cat in candidates[:num_splits]
        ]

    # Ordered features: sort by category value.
    sorted_pairs = sorted(zip(categories, counts), key=lambda x: x[0])
    sorted_categories = [cat for cat, _ in sorted_pairs]
    sorted_counts = [cnt for _, cnt in sorted_pairs]

    # One candidate boundary between each pair of consecutive categories.
    possible_splits = len(sorted_categories) - 1

    if possible_splits <= num_splits:
        return [
            Split(feature_index=-1, threshold=None, categories=set(sorted_categories[:i]), is_continuous=False)
            for i in range(1, len(sorted_categories))
        ]

    # Too many boundaries: pick those closest to equal-count quantiles.
    stride = sum(sorted_counts) / (num_splits + 1)
    splits = []
    target = stride
    current_sum = sorted_counts[0]
    for i in range(1, len(sorted_categories)):
        previous_sum = current_sum
        current_sum += sorted_counts[i]
        # Cut before category i when including it would move the cumulative
        # count further from the target than stopping here.
        if abs(previous_sum - target) < abs(current_sum - target):
            splits.append(Split(
                feature_index=-1,
                threshold=None,
                categories=set(sorted_categories[:i]),
                is_continuous=False
            ))
            target += stride
            if len(splits) >= num_splits:
                break

    return splits

In [7]:
def find_splits_by_sorting(
    sampled_input_rdd: RDD,
    num_features: int,
    is_continuous: list,
    is_unordered: list,
    max_splits_per_feature: list,
    max_bins: int,
    total_weighted_examples: float
) -> list:
    """
    Compute candidate splits for every feature of a (possibly sampled) RDD.

    Feature indices are partitioned by ``is_continuous``. Continuous features
    are processed first by ``find_splits_for_continuous_features``, then
    categorical ones by ``find_splits_for_categorical_features``. The two
    per-feature dicts are merged into one list indexed by feature; a feature
    missing from its dict gets an empty list.

    :param sampled_input_rdd: RDD of Instance (features, label, weight).
    :param num_features: Total number of features (length of the result).
    :param is_continuous: Per-feature flags marking continuous features.
    :param is_unordered: Per-feature flags marking unordered categorical features.
    :param max_splits_per_feature: Per-feature maximum number of splits.
    :param max_bins: Binning parameter used for fraction calculation logic.
    :param total_weighted_examples: Total weighted number of examples in the dataset.
    :return: List (one entry per feature) of lists of Split objects.
    """
    continuous_features, categorical_features = [], []
    for feature_idx, cont in enumerate(is_continuous):
        (continuous_features if cont else categorical_features).append(feature_idx)

    by_feature_continuous = find_splits_for_continuous_features(
        sampled_input_rdd=sampled_input_rdd,
        continuous_features=continuous_features,
        max_splits_per_feature=max_splits_per_feature,
        max_bins=max_bins,
        total_weighted_examples=total_weighted_examples,
    )
    by_feature_categorical = find_splits_for_categorical_features(
        sampled_input_rdd=sampled_input_rdd,
        categorical_features=categorical_features,
        is_unordered=is_unordered,
        max_splits_per_feature=max_splits_per_feature,
    )

    return [
        (by_feature_continuous if is_continuous[feature_idx] else by_feature_categorical).get(feature_idx, [])
        for feature_idx in range(num_features)
    ]

In [17]:
def find_splits_for_continuous_features(
    sampled_input_rdd: RDD,
    continuous_features: list,
    max_splits_per_feature: list,
    max_bins: int,
    total_weighted_examples: float
) -> dict:
    """
    Compute split thresholds for each continuous feature of an RDD.

    Non-zero values are counted per (feature, value) on the cluster, grouped
    per feature, and collected to the driver. Each feature's
    ``{value: count}`` histogram is then passed to
    ``find_splits_for_continuous_feature_weights``. Zero values are skipped
    during counting.

    :param sampled_input_rdd: RDD of Instance (features, label, weight).
    :param continuous_features: Indices of the continuous features.
    :param max_splits_per_feature: Per-feature maximum number of splits (global index).
    :param max_bins: Binning parameter used for fraction calculation logic.
    :param total_weighted_examples: Total weighted number of examples in the dataset.
    :return: Dict mapping feature index to its list of Split objects.
    """
    if not continuous_features:
        return {}

    # ((feature_index, value), 1) for every non-zero continuous value,
    # summed per key and re-keyed by feature: (feature_index, (value, count)).
    value_totals = (
        sampled_input_rdd
        .flatMap(lambda inst: [((i, inst.features[i]), 1) for i in continuous_features])
        .filter(lambda kv: kv[0][1] != 0.0)
        .reduceByKey(lambda a, b: a + b)
        .map(lambda kv: (kv[0][0], (kv[0][1], kv[1])))
    )
    # Driver-side map: { feature_index -> [(value, count), ...] }
    histograms = value_totals.groupByKey().mapValues(list).collectAsMap()

    result = {}
    for feature_idx in continuous_features:
        histogram = dict(histograms.get(feature_idx, []))
        raw_splits = find_splits_for_continuous_feature_weights(
            part_value_weights=histogram,
            count=sum(histogram.values()),
            is_continuous=True,
            num_splits=max_splits_per_feature[feature_idx],
            max_bins=max_bins,
            total_num_examples=total_weighted_examples,
        )
        # Stamp the real feature index onto each threshold.
        result[feature_idx] = [
            Split(feature_index=feature_idx, threshold=s.threshold, categories=None, is_continuous=True)
            for s in raw_splits
        ]
    return result

In [18]:
def find_splits_for_categorical_features(
    sampled_input_rdd: RDD,
    categorical_features: list,
    is_unordered: list,
    max_splits_per_feature: list
) -> dict:
    """
    Compute candidate splits for each categorical feature of an RDD.

    Category occurrences are counted per (feature, category) on the cluster,
    grouped per feature, and collected to the driver. Each feature's
    categories and counts are then passed to
    ``find_splits_for_categorical_feature``.

    :param sampled_input_rdd: RDD of Instance (features, label, weight).
    :param categorical_features: Indices of the categorical features.
    :param is_unordered: Per-feature unordered flags, indexed by global feature index.
    :param max_splits_per_feature: Per-feature maximum number of splits (global index).
    :return: Dict mapping feature index to its list of Split objects.
    """
    if not categorical_features:
        return {}

    # ((feature_index, category), 1) per categorical cell, summed per key and
    # re-keyed by feature: (feature_index, (category, count)).
    category_totals = (
        sampled_input_rdd
        .flatMap(lambda inst: [((i, inst.features[i]), 1) for i in categorical_features])
        .reduceByKey(lambda a, b: a + b)
        .map(lambda kv: (kv[0][0], (kv[0][1], kv[1])))
    )
    # Driver-side map: { feature_index -> [(category, count), ...] }
    grouped = category_totals.groupByKey().mapValues(list).collectAsMap()

    result = {}
    for feature_idx in categorical_features:
        pairs = grouped.get(feature_idx, [])
        raw_splits = find_splits_for_categorical_feature(
            categories=[cat for cat, _ in pairs],
            counts=[cnt for _, cnt in pairs],
            is_unordered=is_unordered[feature_idx],
            num_splits=max_splits_per_feature[feature_idx],
        )
        # Stamp the real feature index onto each split.
        result[feature_idx] = [
            Split(feature_index=feature_idx, threshold=None, categories=s.categories, is_continuous=False)
            for s in raw_splits
        ]
    return result


In [19]:
def find_splits(
    input_rdd: RDD,
    num_features: int,
    is_continuous: list,
    is_unordered: list,
    max_splits_per_feature: list,
    max_bins: int,
    total_weighted_examples: float,
    seed: int
) -> list:
    """
    Finds candidate splits for decision-tree training, for both continuous and categorical features.

    The input is subsampled (without replacement) when it has more rows than
    ``max(max_bins * max_bins, 10000)`` (see ``samples_fraction_for_find_splits``).
    Split candidates are then computed on the (possibly sampled) RDD by
    ``find_splits_by_sorting``.

    :param input_rdd: RDD of Instance-like records exposing ``features``.
    :param num_features: Number of total features (length of the result).
    :param is_continuous: Per-feature flags, indexed by global feature index.
    :param is_unordered: Per-feature flags, indexed by the same global feature
                         index as ``is_continuous`` (not by position within the
                         categorical features); entries for continuous features
                         are not read.
    :param max_splits_per_feature: Per-feature maximum number of splits (global index).
    :param max_bins: Binning parameter used in the fraction calculation.
    :param total_weighted_examples: Weighted number of examples in the entire (unsampled) dataset.
    :param seed: Random seed for sampling.
    :return: A 2D list (list of lists) of Split objects [featureIndex -> list of Splits].
    """
    # Identify continuous and categorical features
    continuous_features = [i for i, cont in enumerate(is_continuous) if cont]
    categorical_features = [i for i, cont in enumerate(is_continuous) if not cont]

    if not continuous_features and not categorical_features:
        # No features to split on
        return [[] for _ in range(num_features)]

    # Sampling fraction is based on the unweighted row count.
    num_examples = input_rdd.count()
    fraction = samples_fraction_for_find_splits(
        max_bins=max_bins,
        num_examples=num_examples
    )

    print(f"Fraction of data used for calculating splits = {fraction}")

    if fraction < 1.0:
        # Derive a 32-bit sampling seed deterministically from `seed`.
        rnd_seed = random.Random(seed).randint(0, 2**32 - 1)
        sampled_input = input_rdd.sample(withReplacement=False, fraction=fraction, seed=rnd_seed)
    else:
        sampled_input = input_rdd

    return find_splits_by_sorting(
        sampled_input_rdd=sampled_input,
        num_features=num_features,
        is_continuous=is_continuous,
        is_unordered=is_unordered,
        max_splits_per_feature=max_splits_per_feature,
        max_bins=max_bins,
        total_weighted_examples=total_weighted_examples
    )

In [20]:
def find_splits_for_categorical_feature(
    categories: list,
    counts: list,
    is_unordered: bool,
    num_splits: int
) -> list:
    """
    Computes candidate splits for a single categorical feature.

    Each returned Split holds, in ``categories``, the set of category values on
    one side of the split; ``feature_index`` is left at -1 for the caller.

    * Unordered features get one-vs-rest candidates ``{cat}``, at most
      ``num_splits`` of them. They are taken in sorted category order when the
      categories are mutually comparable (otherwise in the given order), so the
      result does not depend on the order ``categories`` arrives in. With
      exactly two categories only one candidate is produced, because ``{b}``
      is the mirror image of ``{a}``.
    * Ordered features are sorted by category value and cut at boundaries
      between consecutive categories, so every split set is a non-empty,
      proper prefix of the sorted categories. If there are at most
      ``num_splits`` boundaries, all are returned. Otherwise a boundary is
      chosen whenever stopping there keeps the cumulative count closest to the
      next multiple of ``sum(counts) / (num_splits + 1)``, which is the same
      rule used for continuous features.

    No split is returned when there are fewer than two categories or when
    ``num_splits`` is not positive, because no useful partition exists.

    :param categories: List of distinct category names or indices.
    :param counts: List of counts corresponding to each category.
    :param is_unordered: Boolean indicating if the categorical feature is unordered.
    :param num_splits: The maximum number of splits allowed for this feature.
    :return: A list of Split objects representing candidate splits.
    """
    # A single category cannot be partitioned, and a non-positive budget
    # allows no splits at all.
    if num_splits <= 0 or len(categories) < 2:
        return []

    if is_unordered:
        try:
            candidates = sorted(categories)
        except TypeError:
            # Mixed, mutually non-comparable category types: keep input order.
            candidates = list(categories)
        if len(candidates) == 2:
            # {a} and {b} describe the same partition with the sides swapped.
            candidates = candidates[:1]
        return [
            Split(feature_index=-1, threshold=None, categories={cat}, is_continuous=False)
            for cat in candidates[:num_splits]
        ]

    # Ordered features: sort by category value.
    sorted_pairs = sorted(zip(categories, counts), key=lambda x: x[0])
    sorted_categories = [cat for cat, _ in sorted_pairs]
    sorted_counts = [cnt for _, cnt in sorted_pairs]

    # One candidate boundary between each pair of consecutive categories.
    possible_splits = len(sorted_categories) - 1

    if possible_splits <= num_splits:
        return [
            Split(feature_index=-1, threshold=None, categories=set(sorted_categories[:i]), is_continuous=False)
            for i in range(1, len(sorted_categories))
        ]

    # Too many boundaries: pick those closest to equal-count quantiles.
    stride = sum(sorted_counts) / (num_splits + 1)
    splits = []
    target = stride
    current_sum = sorted_counts[0]
    for i in range(1, len(sorted_categories)):
        previous_sum = current_sum
        current_sum += sorted_counts[i]
        # Cut before category i when including it would move the cumulative
        # count further from the target than stopping here.
        if abs(previous_sum - target) < abs(current_sum - target):
            splits.append(Split(
                feature_index=-1,
                threshold=None,
                categories=set(sorted_categories[:i]),
                is_continuous=False
            ))
            target += stride
            if len(splits) >= num_splits:
                break

    return splits

In [21]:
if __name__ == "__main__":
    # Start (or reuse) a local Spark session for the toy example below.
    spark = (
        SparkSession.builder
        .appName("FindSplitsExample")
        .master("local[*]")
        .getOrCreate()
    )
    sc = spark.sparkContext

    # Toy data, one Instance(features, label, weight) per row:
    # feature 0 is continuous, feature 1 is an unordered categorical.
    data = [
        Instance(features=[1.0, 'A'], label=0.0, weight=1.0),
        Instance(features=[1.0, 'A'], label=1.0, weight=1.0),
        Instance(features=[2.0, 'B'], label=1.0, weight=1.0),
        Instance(features=[2.0, 'B'], label=0.0, weight=1.0),
        Instance(features=[2.0, 'C'], label=1.0, weight=1.0),
        Instance(features=[3.5, 'C'], label=1.0, weight=1.0),
        Instance(features=[3.5, 'D'], label=0.0, weight=1.0),
        Instance(features=[3.5, 'D'], label=1.0, weight=1.0),
    ]
    input_rdd = sc.parallelize(data)

    # Split-finding configuration, one list entry per feature.
    num_features = 2
    is_continuous = [True, False]
    is_unordered = [False, True]     # only read for categorical feature 1
    max_splits_per_feature = [2, 2]
    max_bins = 32
    total_weighted_examples = float(len(data))  # every weight above is 1.0
    seed = 42

    splits = find_splits(
        input_rdd=input_rdd,
        num_features=num_features,
        is_continuous=is_continuous,
        is_unordered=is_unordered,
        max_splits_per_feature=max_splits_per_feature,
        max_bins=max_bins,
        total_weighted_examples=total_weighted_examples,
        seed=seed,
    )

    print(splits)
    # Human-readable listing, one section per feature.
    for fidx, feature_splits in enumerate(splits):
        if not is_continuous[fidx]:
            print(f"Feature {fidx} (Categorical) splits:")
            for s in feature_splits:
                print(f"  Categories = {s.categories}")
            if not feature_splits:
                print("  No splits found.")
            continue
        print(f"Feature {fidx} (Continuous) splits:")
        for s in feature_splits:
            print(f"  Threshold = {s.threshold}")

    # The Spark session is intentionally left running: later cells reuse
    # `spark` and `input_rdd`. Call spark.stop() when finished.


Fraction of data used for calculating splits = 1.0
[(0, 1.0), (0, 1.0), (0, 2.0), (0, 2.0), (0, 2.0)]
[[Split(feature_index=0, threshold=1.5, categories=None, is_continuous=True), Split(feature_index=0, threshold=2.75, categories=None, is_continuous=True)], [Split(feature_index=1, threshold=None, categories={'A'}, is_continuous=False), Split(feature_index=1, threshold=None, categories={'B'}, is_continuous=False)]]
Feature 0 (Continuous) splits:
  Threshold = 1.5
  Threshold = 2.75
Feature 1 (Categorical) splits:
  Categories = {'A'}
  Categories = {'B'}


In [22]:
# Preview the first five Instances of the toy RDD, one per line.
first_rows = input_rdd.take(5)
for row in first_rows:
    print(row)

Instance(features=[1.0, 'A'], label=0.0, weight=1.0)
Instance(features=[1.0, 'A'], label=1.0, weight=1.0)
Instance(features=[2.0, 'B'], label=1.0, weight=1.0)
Instance(features=[2.0, 'B'], label=0.0, weight=1.0)
Instance(features=[2.0, 'C'], label=1.0, weight=1.0)


In [24]:
# Build a larger demo dataset: the Iris data set duplicated ten times, loaded
# into a PySpark DataFrame with an extra `features` vector column, and exposed
# as an RDD of Rows.

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors

IRIS_URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data'

# Reuse the active Spark session (created by the example cell above); a new
# one is started only if this cell runs on its own.
spark = SparkSession.builder.getOrCreate()

# Load the iris dataset (replace the URL with a local path if needed).
try:
    iris_df = pd.read_csv(IRIS_URL, header=None)
except Exception as e:
    # Re-raise instead of print + exit(1): exit() stops the whole
    # interpreter/kernel and discards the original traceback.
    raise RuntimeError(f"Error loading Iris dataset: {e}") from e

# Rename columns
iris_df.columns = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species']

# Duplicate the DataFrame ten times
iris_duplicated = pd.concat([iris_df] * 10, ignore_index=True)

# Convert pandas DataFrame to PySpark DataFrame
spark_df = spark.createDataFrame(iris_duplicated)

# Create the 'features' column
from pyspark.sql.functions import array, col, udf
from pyspark.sql.types import ArrayType, DoubleType, FloatType
from pyspark.ml.linalg import Vectors, VectorUDT

feature_cols = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']


@udf(returnType=VectorUDT())
def create_vector(sepal_length, sepal_width, petal_length, petal_width):
    """Pack the four Iris measurements of one row into a dense vector."""
    return Vectors.dense([sepal_length, sepal_width, petal_length, petal_width])


spark_df = spark_df.withColumn('features', create_vector(*[col(c) for c in feature_cols]))

# Show the first 5 rows
spark_df.show(5)

# Convert the PySpark DataFrame to an RDD of Rows (each has a `features` field)
iris_rdd = spark_df.rdd

# Show the first 5 elements of the RDD
iris_rdd.take(5)

+------------+-----------+------------+-----------+-----------+-----------------+
|sepal_length|sepal_width|petal_length|petal_width|    species|         features|
+------------+-----------+------------+-----------+-----------+-----------------+
|         5.1|        3.5|         1.4|        0.2|Iris-setosa|[5.1,3.5,1.4,0.2]|
|         4.9|        3.0|         1.4|        0.2|Iris-setosa|[4.9,3.0,1.4,0.2]|
|         4.7|        3.2|         1.3|        0.2|Iris-setosa|[4.7,3.2,1.3,0.2]|
|         4.6|        3.1|         1.5|        0.2|Iris-setosa|[4.6,3.1,1.5,0.2]|
|         5.0|        3.6|         1.4|        0.2|Iris-setosa|[5.0,3.6,1.4,0.2]|
+------------+-----------+------------+-----------+-----------+-----------------+
only showing top 5 rows



[Row(sepal_length=5.1, sepal_width=3.5, petal_length=1.4, petal_width=0.2, species='Iris-setosa', features=DenseVector([5.1, 3.5, 1.4, 0.2])),
 Row(sepal_length=4.9, sepal_width=3.0, petal_length=1.4, petal_width=0.2, species='Iris-setosa', features=DenseVector([4.9, 3.0, 1.4, 0.2])),
 Row(sepal_length=4.7, sepal_width=3.2, petal_length=1.3, petal_width=0.2, species='Iris-setosa', features=DenseVector([4.7, 3.2, 1.3, 0.2])),
 Row(sepal_length=4.6, sepal_width=3.1, petal_length=1.5, petal_width=0.2, species='Iris-setosa', features=DenseVector([4.6, 3.1, 1.5, 0.2])),
 Row(sepal_length=5.0, sepal_width=3.6, petal_length=1.4, petal_width=0.2, species='Iris-setosa', features=DenseVector([5.0, 3.6, 1.4, 0.2]))]

In [26]:
# Find splits on the Iris RDD: all four features are continuous, with up to
# 10 splits each.
num_features = 4
is_continuous = [True, True, True, True]      # all Iris measurements are continuous
is_unordered = [False, False, False, False]   # no categorical features
max_splits_per_feature = [10, 10, 10, 10]     # up to 10 splits per feature
max_bins = 32
# Every Iris row counts once (there is no weight column), so the weighted total
# is the Iris row count, not len(data), which is the 8-row toy dataset.
total_weighted_examples = float(iris_rdd.count())
seed = 42

# Get splits
splits = find_splits(
    input_rdd=iris_rdd,
    num_features=num_features,
    is_continuous=is_continuous,
    is_unordered=is_unordered,
    max_splits_per_feature=max_splits_per_feature,
    max_bins=max_bins,
    total_weighted_examples=total_weighted_examples,
    seed=seed
)

print(splits)
# Print the splits
for fidx, feature_splits in enumerate(splits):
    if is_continuous[fidx]:
        print(f"Feature {fidx} (Continuous) splits:")
        for s in feature_splits:
            print(f"  Threshold = {s.threshold}")
    else:
        print(f"Feature {fidx} (Categorical) splits:")
        for s in feature_splits:
            print(f"  Categories = {s.categories}")
        if not feature_splits:
            print("  No splits found.")


Fraction of data used for calculating splits = 1.0
[(0, 5.1), (1, 3.5), (2, 1.4), (3, 0.2), (0, 4.9)]
[[Split(feature_index=0, threshold=4.85, categories=None, is_continuous=True), Split(feature_index=0, threshold=5.05, categories=None, is_continuous=True), Split(feature_index=0, threshold=5.15, categories=None, is_continuous=True), Split(feature_index=0, threshold=5.45, categories=None, is_continuous=True), Split(feature_index=0, threshold=5.65, categories=None, is_continuous=True), Split(feature_index=0, threshold=5.95, categories=None, is_continuous=True), Split(feature_index=0, threshold=6.15, categories=None, is_continuous=True), Split(feature_index=0, threshold=6.35, categories=None, is_continuous=True), Split(feature_index=0, threshold=6.65, categories=None, is_continuous=True), Split(feature_index=0, threshold=6.95, categories=None, is_continuous=True)], [Split(feature_index=1, threshold=2.45, categories=None, is_continuous=True), Split(feature_index=1, threshold=2.650000000000