# In [3]:
from rpart.DecisionTreeClassifier import DecisionTreeClassifier
import pandas as pd 
import numpy as np
from collections import Counter
from functools import reduce
import multiprocessing as mp


class Node:
    """One node of a binary decision tree.

    A node either describes a split (``feature``/``threshold`` together with
    the ``left`` and ``right`` children) or stores a ``value``.  ``value`` can
    only be supplied by keyword.
    """

    def __init__(
        self,
        feature=None,
        threshold=None,
        left=None,
        right=None,
        depth=None,
        *,
        value=None,
    ):
        """Store the given attributes unchanged; every one defaults to None."""
        # Split rule.
        self.feature, self.threshold = feature, threshold
        # Child links.
        self.left, self.right = left, right
        # Stored value and the depth supplied by the caller.
        self.value, self.depth = value, depth

    def is_leaf_node(self):
        """Return True when the node stores a value, False otherwise."""
        if self.value is None:
            return False
        return True

class MapReduceDecisionTreeClassifier(DecisionTreeClassifier):

    def __init__(self, max_depth=None, metric='gini', n_workers=8):
        """Configure the tree and remember the number of workers.

        ``max_depth`` and ``metric`` are forwarded unchanged to the parent
        initializer.  ``n_workers`` is only stored here; ``fit`` passes it on
        as the number of partitions the training data is split into.
        """
        parent_options = dict(max_depth=max_depth, metric=metric)
        super().__init__(**parent_options)
        self.n_workers = n_workers

    def fit(self, X, y):
        """Build the tree from the training data.

        Args:
            X: Feature matrix, either a NumPy array or a pandas DataFrame.
            y: Labels, one per row of ``X``.  Anything that is not already a
                pandas Series is flattened and paired with the rows of ``X``
                by position.

        Sets ``feature_names``, ``n_classes``, ``n_features`` and ``root``.
        """
        if isinstance(X, np.ndarray):
            X = pd.DataFrame(X)
            # Arrays carry no column names, so generate placeholder names.
            self.feature_names = [f"Feature {i}" for i in range(X.shape[1])]
        else:
            self.feature_names = X.columns

        # The partitioning step concatenates X and y with pandas, which only
        # accepts pandas objects and aligns rows on the index.  Plain arrays
        # or lists are therefore wrapped in a Series that shares X's index.
        if not isinstance(y, pd.Series):
            y = pd.Series(np.asarray(y).ravel(), index=X.index)

        self.n_classes = len(set(y))
        self.n_features = X.shape[1]

        # Partition the dataset into smaller subsets.
        data_partitions = self._partition_data(X, y, self.n_workers)

        # Build the tree using MapReduce.
        self.root = self._grow_tree_mapreduce(data_partitions)

    def _partition_data(self, X, y, n_partitions):
        data = pd.concat([X, y], axis=1)
        return np.array_split(data, n_partitions)

    def _grow_tree_mapreduce(self, data_partitions, depth=0):
        """Recursively build the subtree for ``data_partitions``.

        A leaf holding the most common label is returned when the depth limit
        is reached, when ``_is_leaf_node`` says so, or when the split search
        yields no feature.  Otherwise the partitions are split on the best
        (feature, threshold) pair and both sides are grown one level deeper.
        """
        def make_leaf():
            # Leaves carry the majority label of the data that reached them.
            return Node(value=self._most_common_label(data_partitions))

        # Stopping conditions.
        depth_exhausted = self.max_depth is not None and depth >= self.max_depth
        if depth_exhausted or self._is_leaf_node(data_partitions):
            return make_leaf()

        # Best split across all partitions; no feature means nothing to split on.
        feature, threshold = self._best_split_mapreduce(data_partitions)
        if feature is None:
            return make_leaf()

        left_partitions, right_partitions = self._split_partitions(
            data_partitions, feature, threshold
        )

        # Grow the left side first, then the right side.
        child_depth = depth + 1
        left_child = self._grow_tree_mapreduce(left_partitions, child_depth)
        right_child = self._grow_tree_mapreduce(right_partitions, child_depth)
        return Node(
            feature=feature,
            threshold=threshold,
            left=left_child,
            right=right_child,
        )


    def _most_common_label(self, data_partitions):
        all_labels = np.concatenate([partition.iloc[:, -1].values for partition in data_partitions])
        return Counter(all_labels).most_common(1)[0][0]

def _best_split_mapreduce(self, data_partitions):
        # Define a function to be parallelized
        def find_best_splits(partition):
            return self._best_splits(partition.iloc[:, :-1], partition.iloc[:, -1])

        # Create a multiprocessing pool and map the find_best_splits function to the data partitions
        with mp.Pool(mp.cpu_count()) as pool:
            local_best_splits = pool.map(find_best_splits, data_partitions)

        # Flatten the list of local_best_splits
        all_splits = [split for partition_splits in local_best_splits for split in partition_splits]

        # Find the global best split by comparing scores
        best_split = min(all_splits, key=lambda x: x[2])
        return best_split[:2]

    def _reduce_best_splits(self, split1, split2):
        feature1, threshold1, score1 = split1
        feature2, threshold2, score2 = split2

        return (feature1, threshold1) if score1 < score2 else (feature2, threshold2)

    def _split_partitions(self, data_partitions, feature, threshold):
        left_partitions = []