In [2]:
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, classification_report

# 为了简化，这里我们直接使用原始代码中的ConjunctionSet类
from flextrees.utils.ConjunctionSet import ConjunctionSet

In [3]:
# Experiment configuration
N_CLIENTS = 5  # number of simulated clients
DATA_DISTRIBUTION = 'iid'  # 'iid' or 'non-iid'
MODEL_TYPE = 'cart'  # model type; this simplified version only implements 'cart'
MAX_DEPTH = 5  # maximum depth of each local decision tree
    
# Tree-filtering parameters
FILTERING_METHOD = 'mean'  # how the selection thresholds are derived
ACC_THRESHOLD = 0.6  # accuracy threshold
F1_THRESHOLD = 0.5  # F1-score threshold

# Hyper-parameters handed to every client's local model
local_model_params = {
    'max_depth': MAX_DEPTH,
    'criterion': 'gini',
    'splitter': 'best',
    'model_type': MODEL_TYPE,
}

In [4]:
def load_dataset(dataset_name='adult', categorical=False):
    """Load the training/test data used by the demo.

    The function always tries the ``adult`` loader from ``flextrees.datasets``;
    ``dataset_name`` is accepted but not consulted by the body. When that
    attempt raises ``ImportError`` a synthetic binary-classification problem
    (1000 samples, 10 features, fixed seed) is generated instead and split
    800/200 in row order.

    Args:
        dataset_name: Currently unused.
        categorical: Forwarded to the ``adult`` loader.

    Returns:
        Whatever ``adult(ret_feature_names=True, ...)`` returns, or, in the
        fallback path, a ``(train_data, test_data, feature_names)`` tuple whose
        data objects expose ``X_data``, ``y_data`` and ``to_numpy()``.
    """
    try:
        # Simplified: the adult dataset is loaded directly.
        from flextrees.datasets import adult
        return adult(ret_feature_names=True, categorical=categorical)
    except ImportError:
        # Loader unavailable: fall back to synthetic data.
        print("未找到指定数据集，使用模拟数据")
        from sklearn.datasets import make_classification

        n_train = 800
        features, labels = make_classification(
            n_samples=1000, n_features=10, n_classes=2, random_state=42
        )
        n_features = features.shape[1]
        feature_names = [f'x{col}' for col in range(n_features)]
        features_df = pd.DataFrame(features, columns=feature_names)
        labels_sr = pd.Series(labels)

        class Dataset:
            """Tiny stand-in exposing the attributes the rest of the notebook reads."""

            def __init__(self, X, y):
                self.X_data = X
                self.y_data = y

            def to_numpy(self):
                return self.X_data.to_numpy(), self.y_data.to_numpy()

        train_data = Dataset(features_df[:n_train], labels_sr[:n_train])
        test_data = Dataset(features_df[n_train:], labels_sr[n_train:])
        return train_data, test_data, feature_names

# 1. Load the data (numeric, non-categorical encoding)
print(f"\n加载数据集...")
train_data, test_data, feature_names = load_dataset(categorical=False)
print(f"数据集加载完成，特征数量: {len(feature_names)}")


加载数据集...
数据集加载完成，特征数量: 109


In [5]:
def split_data_to_clients(data, n_clients=5, iid=True, random_state=None):
    """Partition a dataset among ``n_clients`` simulated clients.

    Args:
        data: Object exposing ``X_data`` and ``y_data`` with a ``to_numpy()``
            method each (pandas DataFrame / Series in this notebook).
        n_clients: Number of partitions to create (must be >= 1).
        iid: If True, rows are shuffled and cut into equally sized chunks.
            If False, a label-skewed split is produced: when there are at
            least as many classes as clients, 60% of each class goes to client
            ``class % n_clients`` (labels must therefore be numeric) and the
            rest is spread over all clients; otherwise each class is simply
            spread over all clients.
        random_state: Optional integer seed. When given, the split is
            reproducible; when None the global NumPy random state is used,
            exactly as before this parameter existed.

    Returns:
        list: ``n_clients`` objects with ``X_data`` (DataFrame) and ``y_data``
        (Series) attributes. In every chunked split the last client absorbs
        the remainder rows.

    Raises:
        ValueError: If ``n_clients`` is smaller than 1.
    """
    if n_clients < 1:
        raise ValueError("n_clients must be at least 1")

    class ClientDataset:
        """Per-client container mirroring the X_data / y_data attributes."""

        def __init__(self, X, y):
            self.X_data = pd.DataFrame(X)
            self.y_data = pd.Series(y)

    # A dedicated seeded generator when requested, the legacy global one otherwise.
    rng = np.random if random_state is None else np.random.RandomState(random_state)

    def chunk(indices):
        """Cut ``indices`` into n_clients contiguous pieces; the last takes the remainder."""
        size = len(indices) // n_clients
        pieces = []
        for k in range(n_clients):
            stop = (k + 1) * size if k < n_clients - 1 else len(indices)
            pieces.append(indices[k * size:stop])
        return pieces

    X, y = data.X_data.to_numpy(), data.y_data.to_numpy()

    if iid:
        # IID split: uniform random assignment.
        per_client = chunk(rng.permutation(len(X)))
    else:
        # Non-IID split: assignment biased by class label.
        classes = np.unique(y)
        per_client = [[] for _ in range(n_clients)]
        for c in classes:
            idx = np.where(y == c)[0]
            rng.shuffle(idx)

            if len(classes) >= n_clients:
                # Enough classes: one client receives the bulk (60%) of this class...
                cut = int(len(idx) * 0.6)
                per_client[int(c % n_clients)].extend(idx[:cut])
                # ...and the remaining 40% is shared by everybody.
                remaining = idx[cut:]
                rng.shuffle(remaining)
                shares = chunk(remaining)
            else:
                # Fewer classes than clients: spread each class evenly.
                shares = chunk(idx)

            for bucket, share in zip(per_client, shares):
                bucket.extend(share)

    return [ClientDataset(X[rows], y[rows]) for rows in per_client]

    
# 2. Distribute the training data among the clients
print(f"\n将数据分发到 {N_CLIENTS} 个客户端 ({DATA_DISTRIBUTION} 分布)...")
# Any DATA_DISTRIBUTION value other than 'iid' selects the non-IID split.
client_data = split_data_to_clients(train_data, N_CLIENTS, DATA_DISTRIBUTION == 'iid')


将数据分发到 5 个客户端 (iid 分布)...


In [6]:
# 2. 树训练和规则提取函数
def train_local_model(client_data, model_params):
    """Train a local CART tree on one client's data and extract its rules.

    The client's data is split 80/20 (fixed seed); the tree is fitted on the
    larger part and scored on the smaller one. The fitted tree is then handed
    to ``ConjunctionSet`` to obtain its branches.

    Args:
        client_data: Object with ``X_data`` and ``y_data`` attributes that
            support ``to_numpy()``.
        model_params (dict): Optional keys ``max_depth``, ``criterion`` and
            ``splitter``. ``model_type`` may be present but only CART is
            implemented, so it is not consulted.

    Returns:
        dict: ``local_tree``, ``local_cs``, ``local_branches``,
        ``local_branches_df``, ``local_classes``, the held-out ``X_test`` /
        ``y_test`` and the hold-out scores ``local_acc`` / ``local_f1``.
    """
    n_samples = len(client_data.X_data)

    # Require 2% of the client's rows to split a node. The floor must be the
    # integer 2: a float 1.0 would be read by scikit-learn as "100% of the
    # samples" and collapse the tree to a single split on small clients.
    clf = DecisionTreeClassifier(
        random_state=42,
        min_samples_split=max(2, int(0.02 * n_samples)),
        max_depth=model_params.get('max_depth', 5),
        criterion=model_params.get('criterion', 'gini'),
        splitter=model_params.get('splitter', 'best')
    )

    # Local train / hold-out split.
    X_data, y_data = client_data.X_data.to_numpy(), client_data.y_data.to_numpy()
    X_train, X_test, y_train, y_test = train_test_split(X_data, y_data, test_size=0.2, random_state=42)

    clf.fit(X_train, y_train)

    # Hold-out evaluation of the local tree.
    y_pred = clf.predict(X_test)
    acc = accuracy_score(y_test, y_pred)
    f1 = f1_score(y_test, y_pred, average='macro')
    print(f"客户端本地模型 - 准确率: {acc:.4f}, F1分数: {f1:.4f}, 训练样本数: {len(X_train)}, 测试样本数: {len(X_test)}")

    # Positional feature names (x0, x1, ...); every feature is declared 'int'.
    # NOTE(review): confirm that 'int' is the type ConjunctionSet expects for
    # continuous columns as well.
    feature_names = [f'x{i}' for i in range(X_data.shape[1])]
    feature_types = ['int'] * len(feature_names)

    # Extract the rule set of the fitted tree.
    local_cs = ConjunctionSet(
        feature_names=feature_names,
        original_data=X_train,
        pruning_x=X_train,
        pruning_y=y_train,
        model=[clf],
        feature_types=feature_types,
        amount_of_branches_threshold=3000,
        minimal_forest_size=1,
        estimators=clf,
        filter_approach='probability',
        personalized=False
    )

    return {
        'local_tree': clf,
        'local_cs': local_cs,
        'local_branches': local_cs.get_branches_list(),
        'local_branches_df': local_cs.get_conjunction_set_df().round(decimals=5),
        'local_classes': clf.classes_,
        'X_test': X_test,
        'y_test': y_test,
        'local_acc': acc,
        'local_f1': f1
    }


In [7]:
# 4. Train a local model on every client and extract its rules
print("\n第1步: 训练本地模型并提取规则...")
client_models = []
for i, data in enumerate(client_data):
    print(f"\n客户端{i+1}训练中...")
    model = train_local_model(data, local_model_params)
    client_models.append(model)
print("第1步: 本地模型训练完成，规则已提取")



第1步: 训练本地模型并提取规则...

客户端1训练中...
客户端本地模型 - 准确率: 0.8564, F1分数: 0.7827, 训练样本数: 3646, 测试样本数: 912
Estimators: DecisionTreeClassifier(max_depth=5, min_samples_split=91, random_state=42)


AQUÍr

客户端2训练中...
客户端本地模型 - 准确率: 0.8410, F1分数: 0.7131, 训练样本数: 3646, 测试样本数: 912
Estimators: DecisionTreeClassifier(max_depth=5, min_samples_split=91, random_state=42)
AQUÍr

客户端3训练中...
客户端本地模型 - 准确率: 0.8531, F1分数: 0.7621, 训练样本数: 3646, 测试样本数: 912
Estimators: DecisionTreeClassifier(max_depth=5, min_samples_split=91, random_state=42)
AQUÍr

客户端4训练中...
客户端本地模型 - 准确率: 0.8509, F1分数: 0.7722, 训练样本数: 3646, 测试样本数: 912
Estimators: DecisionTreeClassifier(max_depth=5, min_samples_split=91, random_state=42)
AQUÍr

客户端5训练中...
客户端本地模型 - 准确率: 0.8432, F1分数: 0.7868, 训练样本数: 3648, 测试样本数: 912
Estimators: DecisionTreeClassifier(max_depth=5, min_samples_split=91, random_state=42)
AQUÍr
第1步: 本地模型训练完成，规则已提取


In [8]:
# 3. 规则评估和筛选函数
def evaluate_trees_on_client(client_model, all_models):
    """Score every client's tree on this client's held-out data.

    Args:
        client_model (dict): Entry providing the ``X_test`` / ``y_test``
            hold-out set to evaluate on.
        all_models (list[dict]): Entries whose ``local_tree`` is scored.

    Returns:
        list[tuple]: One ``(accuracy, macro_f1)`` pair per entry of
        ``all_models``, in the same order.
    """
    features = client_model['X_test']
    labels = client_model['y_test']

    def _score(entry):
        predicted = entry['local_tree'].predict(features)
        accuracy = accuracy_score(labels, predicted)
        macro_f1 = f1_score(labels, predicted, average='macro')
        return (accuracy, macro_f1)

    return [_score(entry) for entry in all_models]

def filter_trees(evaluation_results, filter_params):
    """Select the trees whose cross-client average scores pass both thresholds.

    Args:
        evaluation_results: Nested sequence shaped
            ``(n_clients, n_trees, 2)`` holding ``(accuracy, f1)`` pairs, one
            row of pairs per evaluating client.
        filter_params (dict): ``filter_method`` selects how thresholds are
            derived: ``'mean'`` (default) uses the mean of the per-tree
            averages, ``'percentile'`` uses the ``filter_value`` percentile
            (default 75), anything else uses the fixed ``acc_threshold`` /
            ``f1_threshold`` values (defaults 0.6 / 0.5).

    Returns:
        list[int]: Indices of the selected trees. If no tree passes both
        thresholds, the single tree with the highest average accuracy.
    """
    # Average over the evaluating clients -> one (acc, f1) row per tree.
    avg_results = np.mean(evaluation_results, axis=0)
    avg_acc, avg_f1 = avg_results[:, 0], avg_results[:, 1]

    filter_method = filter_params.get('filter_method', 'mean')
    if filter_method == 'mean':
        acc_threshold = np.mean(avg_acc)
        f1_threshold = np.mean(avg_f1)
    elif filter_method == 'percentile':
        percentile_value = filter_params.get('filter_value', 75)
        acc_threshold = np.percentile(avg_acc, percentile_value)
        f1_threshold = np.percentile(avg_f1, percentile_value)
    else:
        # Any other method name: fixed thresholds.
        acc_threshold = filter_params.get('acc_threshold', 0.6)
        f1_threshold = filter_params.get('f1_threshold', 0.5)

    selected_indices = [
        i
        for i, (acc, f1) in enumerate(zip(avg_acc, avg_f1))
        if acc >= acc_threshold and f1 >= f1_threshold
    ]

    # Nothing passed: keep the tree with the best average accuracy. The argmax
    # must run over the accuracy column of all trees, not over the first
    # tree's (acc, f1) pair.
    if not selected_indices:
        selected_indices = [int(np.argmax(avg_acc))]

    print(f"筛选后选择了 {len(selected_indices)} 棵树，索引: {selected_indices}")
    return selected_indices


# 5. Evaluate every tree on every client's held-out data
print("\n第2-5步: 筛选弱决策树...")
all_evaluations = []
for i, client_model in enumerate(client_models):
    print(f"客户端 {i+1} 正在评估所有树...")
    eval_results = evaluate_trees_on_client(client_model, client_models)
    all_evaluations.append(eval_results)
    
# 6. Keep only the well-performing trees
# NOTE(review): the configured thresholds are halved here; filter_trees only
# reads them when the filter method is neither 'mean' nor 'percentile' --
# confirm the halving is intentional.
filter_params = {
    'filter_method': FILTERING_METHOD,
    'acc_threshold': ACC_THRESHOLD / 2, 
    'f1_threshold': F1_THRESHOLD / 2
}
    
selected_indices = filter_trees(all_evaluations, filter_params)


第2-5步: 筛选弱决策树...
客户端 1 正在评估所有树...
客户端 2 正在评估所有树...
客户端 3 正在评估所有树...
客户端 4 正在评估所有树...
客户端 5 正在评估所有树...
筛选后选择了 3 棵树，索引: [2, 3, 4]


In [9]:
def generate_cs_dt_branches_from_list(client_cs, classes_, tree_model, threshold=3000):
    """Build the global conjunction set, its branch table and the global tree.

    Args:
        client_cs: Per-client branch collections handed to
            ``ConjunctionSet.aggregate_branches``.
        classes_: Class labels of the problem; also passed to ``tree_model``.
        tree_model: Callable ``tree_model(mask, classes_)`` returning a node
            object with a ``split(dict_of_columns)`` method (``TreeBranch`` in
            this notebook).
        threshold (int): Forwarded as ``amount_of_branches_threshold``.

    Returns:
        list: ``[conjunction_set, global_tree, branches_dataframe]``. The
        ``branch_probability`` column of the dataframe is normalised to sum
        to one and any missing cell is replaced by ``np.inf``.

    Side effects:
        Prints progress information and, when the branch table contains no
        missing values, writes it to ``branches_df_aggregator.csv`` in the
        current working directory.
    """
    cs = ConjunctionSet(
        filter_approach="entropy",
        amount_of_branches_threshold=threshold,
        feature_names=[],
        personalized=False,
    )
    cs.aggregate_branches(client_cs, classes_)
    cs.buildConjunctionSet()
    print(f"Conjunction set length: {len(cs.conjunctionSet)}")
    cs.conjunctionSet = delete_duplicated_rules(cs.conjunctionSet)
    print(f"Conjunction set length after removing duplicates: {len(cs.conjunctionSet)}")
    branches_df_aggregator = cs.get_conjunction_set_df().round(decimals=5)

    # Re-normalise the branch probabilities so they sum to one.
    total_probas = sum(branches_df_aggregator["branch_probability"].to_list())
    branches_df_aggregator["branch_probability"] = branches_df_aggregator[
        "branch_probability"
    ].map(lambda x: x / total_probas)

    if pd.isna(branches_df_aggregator).any().any():
        # Missing cells are replaced by +inf. The artificial pauses that used
        # to surround this diagnostic output were debugging leftovers and
        # only slowed the pipeline down.
        print("Before fillna")
        print(branches_df_aggregator)
        branches_df_aggregator = branches_df_aggregator.fillna(np.inf)
        print("After fillna")
        print(branches_df_aggregator)
    else:
        print(f"branches df aggreagator is not null: {branches_df_aggregator}")
        branches_df_aggregator.to_csv("branches_df_aggregator.csv")

    # The tree is built from a plain dict of column arrays, one row per branch.
    new_df_dict = {
        col: branches_df_aggregator[col].values
        for col in branches_df_aggregator.columns
    }
    new_dt_model = tree_model([True] * len(branches_df_aggregator), classes_)
    new_dt_model.split(new_df_dict)
    return [cs, new_dt_model, branches_df_aggregator]


def delete_duplicated_rules(rules_dataset):
    """Drop rules whose string representation has already been seen.

    Rules are keyed by ``str(rule)``. The position of a key is that of its
    first occurrence, while the object kept for it is the last rule that
    produced that string.

    Args:
        rules_dataset: Iterable of rule objects.

    Returns:
        list: The de-duplicated rules.
    """
    unique_by_text = {}
    for rule in rules_dataset:
        unique_by_text[str(rule)] = rule
    return [*unique_by_text.values()]


In [10]:

import numpy as np
from scipy.stats import entropy

EPSILON = 0.000001

#################################### NEW ####################################
from sklearn.metrics import auc, cohen_kappa_score, roc_curve

#############################################################################


class TreeBranch:
    """Node of a binary decision tree grown over a table of rule branches.

    Each row of the branch table describes one rule through per-feature
    ``"<feature>_lower"`` / ``"<feature>_upper"`` bounds and a ``"probas"``
    entry holding its class-probability vector. A node owns the rows selected
    by ``mask`` and, when split, sends to each child every row whose interval
    is compatible with that side of the split value.

    Attributes:
        mask: Boolean sequence (one entry per branch row) selecting the rows
            that belong to this node.
        classes_: Class labels used to translate argmax positions to labels.
        left, right: Child nodes, or None for a leaf.
        split_feature, split_value: Feature index and threshold of the split.
        depth: Depth of the node (root is 0). ``detph`` is the historical,
            misspelled name of the same value and is kept for compatibility.
    """

    def __init__(self, mask, classes=None, depth=0):
        self.mask = mask
        self.classes_ = classes
        self.left = None
        self.right = None
        self.split_feature = None
        self.split_value = None
        self.depth = depth
        # Legacy misspelled alias; existing code may still read it.
        self.detph = depth

    def split(self, df):
        """Recursively split this node into two children.

        A node holding a single branch row, or one for which the chosen split
        would leave a child empty or identical to the parent, becomes a leaf.

        Args:
            df: Mapping from column name to per-row values (the notebook
                passes a dict of NumPy arrays built from the branch
                DataFrame).
        """
        if np.sum(self.mask) == 1:
            self.left = None
            self.right = None
            return
        # Feature indices are recovered from the "<index>_upper" column names.
        self.features = [int(i.split("_")[0]) for i in df.keys() if "upper" in str(i)]
        self.split_feature, self.split_value = self.select_split_feature(df)
        self.create_mask(df)
        if not self.is_splitable():
            self.left = None
            self.right = None
            return
        # Rows whose interval straddles the split value go to both children.
        self.left = TreeBranch(
            list(
                np.logical_and(self.mask, np.logical_or(self.left_mask, self.both_mask))
            ),
            self.classes_,
            depth=self.detph + 1,
        )
        self.right = TreeBranch(
            list(
                np.logical_and(
                    self.mask, np.logical_or(self.right_mask, self.both_mask)
                )
            ),
            self.classes_,
            depth=self.detph + 1,
        )
        self.left.split(df)
        self.right.split(df)

    def is_splitable(self):
        """Check whether the current split separates the node's rows.

        Requires ``create_mask`` to have been called for the current split.

        Returns:
            bool: False when a child would be empty or would contain every
            row of this node, True otherwise.
        """
        left_count = np.sum(
            np.logical_and(self.mask, np.logical_or(self.left_mask, self.both_mask))
        )
        right_count = np.sum(
            np.logical_and(self.mask, np.logical_or(self.right_mask, self.both_mask))
        )
        total = np.sum(self.mask)
        if left_count == 0 or right_count == 0:
            return False
        if left_count == total or right_count == total:
            return False
        return True

    def create_mask(self, df):
        """Compute the row masks induced by the selected split.

        Sets ``left_mask`` (upper bound <= split value), ``right_mask``
        (lower bound >= split value) and ``both_mask`` (split value strictly
        inside the interval).

        Args:
            df: Mapping from column name to per-row values.
        """
        self.left_mask = df[str(self.split_feature) + "_upper"] <= self.split_value
        self.right_mask = df[str(self.split_feature) + "_lower"] >= self.split_value
        self.both_mask = (df[str(self.split_feature) + "_lower"] < self.split_value) & (
            df[str(self.split_feature) + "_upper"] > self.split_value
        )

    def select_split_feature(self, df):
        """Pick the feature (and value) with the lowest split metric.

        Args:
            df: Mapping from column name to per-row values.

        Returns:
            tuple: ``(feature, value)`` minimising the metric over
            ``self.features``.
        """
        feature_to_value = {}
        feature_to_metric = {}
        for feature in self.features:
            value, metric = self.check_feature_split_value(df, feature)
            feature_to_value[feature] = value
            feature_to_metric[feature] = metric
        feature = min(feature_to_metric, key=feature_to_metric.get)
        return feature, feature_to_value[feature]

    def check_feature_split_value(self, df, feature):
        """Find the best split value for one feature.

        Candidate values are the distinct lower/upper bounds of the rows in
        this node. They are shuffled first, so ties between equally good
        values are broken at random.

        Args:
            df: Mapping from column name to per-row values.
            feature (int): Feature index.

        Returns:
            tuple: ``(value, metric)`` for the value with the lowest metric.
        """
        value_to_metric = {}
        values = list(
            set(
                list(df[str(feature) + "_upper"][self.mask])
                + list(df[str(feature) + "_lower"][self.mask])
            )
        )
        np.random.shuffle(values)
        # Column lookups are loop-invariant; the masks are computed with
        # vectorised comparisons instead of per-row Python loops.
        upper = np.asarray(df[str(feature) + "_upper"])
        lower = np.asarray(df[str(feature) + "_lower"])
        for value in values:
            left_mask = upper <= value
            right_mask = lower >= value
            both_mask = (value < upper) & (value > lower)
            value_to_metric[value] = self.get_value_metric(
                df, left_mask, right_mask, both_mask
            )
        val = min(value_to_metric, key=value_to_metric.get)
        return val, value_to_metric[val]

    def get_value_metric(self, df, left_mask, right_mask, both_mask):
        """Weighted entropy of the two children a candidate split would create.

        Args:
            df: Mapping from column name to per-row values.
            left_mask: Rows lying entirely on the left of the candidate value.
            right_mask: Rows lying entirely on the right of it.
            both_mask: Rows whose interval contains the candidate value.

        Returns:
            float: ``np.inf`` when a child would be empty; otherwise each
            child's entropy weighted by its row count divided by the total
            number of rows in the table.
        """
        l_df_mask = np.logical_and(np.logical_or(left_mask, both_mask), self.mask)
        r_df_mask = np.logical_and(np.logical_or(right_mask, both_mask), self.mask)
        if np.sum(l_df_mask) == 0 or np.sum(r_df_mask) == 0:
            return np.inf
        l_entropy = self.calculate_entropy(df, l_df_mask)
        r_entropy = self.calculate_entropy(df, r_df_mask)
        l_prop = np.sum(l_df_mask) / len(l_df_mask)
        r_prop = np.sum(r_df_mask) / len(l_df_mask)
        return l_entropy * l_prop + r_entropy * r_prop

    def predict_probas_and_depth(self, inst, training_df, explanation=None):
        """Route one instance down the tree.

        Args:
            inst: Instance indexable by feature index.
            training_df: Branch table providing the ``"probas"`` column.
            explanation (dict, optional): Dict to which this node's condition
                is added; a new one is created when omitted.

        Returns:
            tuple: ``(probabilities, depth, explanation)`` where ``depth``
            counts the nodes visited (a leaf counts 1) and ``explanation``
            maps ``"x<feature>"`` to ``"<=value"`` or ``">value"``. When a
            feature is tested more than once on the path, the condition of
            the node closest to the root is the one kept.
        """
        if explanation is None:
            explanation = {}
        if self.left is None and self.right is None:
            return self.node_probas(training_df), 1, {}
        if inst[self.split_feature] <= self.split_value:
            prediction, depth, aux_explanation = self.left.predict_probas_and_depth(
                inst, training_df
            )
            explanation.update(aux_explanation)
            explanation[f"x{self.split_feature}"] = "<=" + str(self.split_value)
            return prediction, depth + 1, explanation
        else:
            prediction, depth, aux_explanation = self.right.predict_probas_and_depth(
                inst, training_df
            )
            explanation.update(aux_explanation)
            explanation[f"x{self.split_feature}"] = ">" + str(self.split_value)
            return prediction, depth + 1, explanation

    def node_probas(self, df):
        """Class probabilities of this node.

        Args:
            df: Branch table providing the ``"probas"`` column.

        Returns:
            array: Mean of the probability vectors of the node's rows,
            normalised to sum to one.
        """
        x = df["probas"][self.mask].mean()
        return x / x.sum()

    def calculate_entropy(self, test_df, test_df_mask):
        """Entropy of the mean class distribution of the selected rows.

        Args:
            test_df: Branch table providing the ``"probas"`` column.
            test_df_mask: Boolean mask selecting the rows to consider.

        Returns:
            float: Entropy of the normalised mean probability vector.
        """
        x = test_df["probas"][test_df_mask].mean()
        return entropy(x / x.sum())

    def count_depth(self):
        """Height of the subtree rooted at this node (a leaf counts 1).

        Returns:
            int: Number of nodes on the longest root-to-leaf path.
        """
        if self.right is None:
            return 1
        return max(self.left.count_depth(), self.right.count_depth()) + 1

    def number_of_children(self):
        """Number of nodes in the subtree rooted at this node, itself included.

        Returns:
            int: Node count.
        """
        if self.right is None:
            return 1
        return 1 + self.right.number_of_children() + self.left.number_of_children()

    def new_model_measures(self, X, Y, branches_df, classes_p=None):
        """Compute evaluation measures of the tree on ``X`` / ``Y``.

        Deprecated: kept for backward compatibility, not used by the notebook.

        Args:
            X: Iterable of instances.
            Y: True labels.
            branches_df: Branch table providing the ``"probas"`` column.
            classes_p: Optional class labels that replace ``self.classes_``.

        Returns:
            dict: Depth statistics, accuracy, AUC, Cohen's kappa, node count,
            raw probabilities and predictions.
        """
        result_dict = {}
        probas, depths = [], []
        self.classes_ = classes_p if classes_p is not None else self.classes_
        for inst in X:
            # predict_probas_and_depth returns (probas, depth, explanation).
            prob, depth, _ = self.predict_probas_and_depth(inst, branches_df)
            probas.append(prob)
            depths.append(depth)
        print(self.classes_)
        # TODO: predict over the most probable class among those the client
        # actually has, and make this work correctly for multiclass problems.
        predictions = [
            self.classes_[i] for i in np.array([np.argmax(prob) for prob in probas])
        ]
        result_dict["new_model_average_depth"] = np.mean(depths)
        result_dict["new_model_min_depth"] = np.min(depths)
        result_dict["new_model_max_depth"] = np.max(depths)
        # Compare as arrays so that a plain-list Y is handled element-wise too.
        result_dict["new_model_accuracy"] = np.sum(
            np.asarray(predictions) == np.asarray(Y)
        ) / len(Y)
        result_dict["new_model_auc"] = self.get_auc(Y, np.array(probas), self.classes_)
        result_dict["new_model_kappa"] = cohen_kappa_score(Y, predictions)
        result_dict["new_model_number_of_nodes"] = self.number_of_children()
        result_dict["new_model_probas"] = probas
        result_dict["predictions"] = predictions
        return result_dict

    def predict(self, X, classes_, branches_df):
        """Predict a label and a textual explanation for every instance.

        Args:
            X: Iterable of instances.
            classes_: Class labels, used only when the node was created
                without ``classes``.
            branches_df: Branch table providing the ``"probas"`` column.

        Returns:
            tuple: ``(predictions, explanations)``, two lists aligned with
            ``X``.
        """
        classes_ = self.classes_ if self.classes_ is not None else classes_
        probas = []
        explanations = []
        for inst in X:
            prob, _, explanation = self.predict_probas_and_depth(inst, branches_df)
            probas.append(prob)
            explanations.append(explanation)
        predictions = [
            classes_[i] for i in np.array([np.argmax(prob) for prob in probas])
        ]
        explanations = [
            self.generate_explanation(pred, expl)
            for pred, expl in zip(predictions, explanations)
        ]
        return predictions, explanations

    def predict_probas(self, X, classes_, branches_df):
        """Predict class probabilities for every instance.

        Args:
            X: Iterable of instances.
            classes_: Unused; kept for signature compatibility with
                ``predict``.
            branches_df: Branch table providing the ``"probas"`` column.

        Returns:
            tuple: ``(probabilities, explanations)`` where ``probabilities``
            is an array with one row per instance and ``explanations`` is the
            list of raw explanation dicts.
        """
        probas = []
        explanations = []
        for inst in X:
            prob, _, explanation = self.predict_probas_and_depth(inst, branches_df)
            probas.append(prob)
            explanations.append(explanation)
        return np.array(probas), explanations

    def get_auc(self, Y, y_score, classes):
        """Area under the ROC curve over the flattened one-vs-rest indicators.

        Args:
            Y: True labels.
            y_score (array): Predicted probabilities, one row per instance.
            classes: Class labels, in the column order of ``y_score``.

        Returns:
            float: AUC of the ROC curve built from the flattened indicator
            matrix and the flattened scores.
        """
        y_test_binarize = np.array([[1 if i == c else 0 for c in classes] for i in Y])
        fpr, tpr, _ = roc_curve(y_test_binarize.ravel(), y_score.ravel())
        return auc(fpr, tpr)

    def generate_explanation(self, target, explanation):
        """Render a prediction and its decision path as a sentence.

        Args:
            target: Label predicted for the instance.
            explanation (dict): ``feature -> '<=value'`` / ``'>value'``
                conditions met along the path.

        Returns:
            str: ``"The instance was classified as T. Because: f1<=v1, ..."``;
            only the first sentence when there are no conditions (the tree is
            a single leaf).
        """
        headline = f"The instance was classified as {target}."
        if not explanation:
            return headline
        conditions = ",".join(
            f" {feature}{value}" for feature, value in explanation.items()
        )
        return f"{headline} Because:{conditions}."


In [11]:
# 4. 规则聚合函数
def aggregate_rules(client_models, selected_indices):
    """Aggregate the rules of the selected clients into a global model.

    Args:
        client_models (list[dict]): Per-client results providing at least
            ``local_branches`` and ``local_classes``.
        selected_indices: Indices (into ``client_models``) of the trees that
            passed the filtering step.

    Returns:
        list: The ``[conjunction_set, global_tree, branches_dataframe]``
        triple returned by ``generate_cs_dt_branches_from_list``.
    """
    # Keep only the rules of the selected trees.
    selected_models = [client_models[i] for i in selected_indices]
    client_branches = [model['local_branches'] for model in selected_models]

    # Union of the class labels seen by the selected clients. It is sorted
    # because plain set iteration order is not guaranteed, while the labels
    # are later looked up by argmax position.
    classes_ = sorted(
        set().union(*(model['local_classes'] for model in selected_models))
    )

    # Aggregate everything into the global model.
    return generate_cs_dt_branches_from_list(client_branches, classes_, TreeBranch)

# 7. Aggregate the selected rules and build the global model
print("\n第6-9步: 聚合规则并构建全局模型...")
global_model = aggregate_rules(client_models, selected_indices)
print("全局模型构建完成")


第6-9步: 聚合规则并构建全局模型...
Estimators: None
Iteration 1: 15 conjunctions

Las reglas actuales son: 
Iteration 2: 41 conjunctions

Las reglas actuales son: 
Conjunction set length: 117
Conjunction set length after removing duplicates: 117
branches df aggreagator is not null:      0_upper  0_lower  1_upper  1_lower  2_upper  2_lower  3_upper  3_lower  \
0       30.5     -inf   7139.5     -inf   2218.0     -inf     40.5     -inf   
1       30.5     -inf   7139.5     -inf   2218.0     -inf     40.5     -inf   
2        inf     30.5   7139.5     -inf   2218.0     -inf     40.5     -inf   
3        inf     30.5   7139.5     -inf   2218.0     -inf     40.5     -inf   
4       30.5     -inf   7139.5     -inf   2391.5   2218.0     40.5     -inf   
..       ...      ...      ...      ...      ...      ...      ...      ...   
112      inf     -inf   5095.5     -inf   1794.0     -inf      inf     39.0   
113      inf     -inf   7032.5   5095.5   1794.0     -inf      inf     -inf   
114      inf     -

In [12]:
# 5. 全局模型评估函数
def evaluate_global_model(global_model, test_data):
    """Score the aggregated global model on the test set and print a report.

    Args:
        global_model: Sequence whose item 1 is the global tree and item 2 the
            branch DataFrame.
        test_data: Object whose ``to_numpy()`` returns ``(X, y)``.

    Returns:
        tuple: ``(accuracy, macro_f1)`` on the test set.
    """
    features, labels = test_data.to_numpy()

    # Class indices are derived from the branch table of the global model.
    branches_df = global_model[2]
    class_labels = get_classes_branches(branches_df)

    # Predict with the global tree; the explanations are not needed here.
    global_tree = global_model[1]
    predictions, _explanations = global_tree.predict(features, class_labels, branches_df)

    # Performance metrics.
    accuracy = accuracy_score(labels, predictions)
    macro_f1 = f1_score(labels, predictions, average='macro')
    report = classification_report(labels, predictions)

    for line in (
        "\n全局模型在测试集上的性能:",
        f"准确率: {accuracy:.4f}",
        f"宏平均F1: {macro_f1:.4f}",
        f"分类报告: \n{report}",
    ):
        print(line)

    return accuracy, macro_f1

def get_classes_branches(branches):
    """Return the class indices implied by the branch table.

    Args:
        branches: DataFrame with a ``probas`` column; must not be None.

    Returns:
        list[int]: ``[0, 1, ..., k-1]`` where ``k`` is the length of the
        first row's probability vector.
    """
    assert branches is not None
    first_probas = branches['probas'].iloc[0]
    n_classes = len(first_probas)
    return [class_idx for class_idx in range(n_classes)]


In [13]:

# 8. Evaluate the global model on the held-out test set
print("\n第10步: 评估全局模型...")
eval_results = evaluate_global_model(global_model, test_data)
    
print("\n--- ICDTA4FL 简化Demo完成 ---")


第10步: 评估全局模型...

全局模型在测试集上的性能:
准确率: 0.8543
宏平均F1: 0.7769
分类报告: 
              precision    recall  f1-score   support

           0       0.87      0.95      0.91      7392
           1       0.79      0.54      0.65      2377

    accuracy                           0.85      9769
   macro avg       0.83      0.75      0.78      9769
weighted avg       0.85      0.85      0.84      9769


--- ICDTA4FL 简化Demo完成 ---
