In [63]:
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, classification_report

# 为了简化，这里我们直接使用原始代码中的ConjunctionSet类
from flextrees.utils.ConjunctionSet import ConjunctionSet
# 使用utils_function_aggregator来聚合规则
from flextrees.utils.utils_function_aggregator import generate_cs_dt_branches_from_list
from flextrees.utils.branch_tree import TreeBranch

In [20]:
# Experiment configuration.
# Number of simulated clients.
N_CLIENTS = 5
# Data partitioning scheme: 'iid' or 'non-iid'.
DATA_DISTRIBUTION = 'iid'
# Local model type; this simplified version only implements 'cart'.
MODEL_TYPE = 'cart'
# Maximum depth of each decision tree.
MAX_DEPTH = 5

# Tree-filtering settings.
# Filtering method.
FILTERING_METHOD = 'mean'
# Accuracy threshold.
ACC_THRESHOLD = 0.6
# F1-score threshold.
F1_THRESHOLD = 0.5

# 3. Parameters handed to each client's local model.
local_model_params = dict(
    max_depth=MAX_DEPTH,
    criterion='gini',
    splitter='best',
    model_type=MODEL_TYPE,
)

In [21]:
def load_dataset(dataset_name='adult', categorical=False):
    """Load the dataset used by the demo.

    The Adult dataset is requested from ``flextrees.datasets``. If an
    ``ImportError`` is raised while doing so, a synthetic binary
    classification problem is generated instead.

    Args:
        dataset_name: Not read by the function body; the Adult dataset is
            always the one attempted.
        categorical: Forwarded unchanged to ``flextrees.datasets.adult``.

    Returns:
        Whatever ``adult(ret_feature_names=True, ...)`` returns, or, in the
        fallback path, a ``(train_data, test_data, feature_names)`` tuple
        whose two dataset objects expose ``X_data``, ``y_data`` and
        ``to_numpy()``.
    """
    try:
        # Simplified: the Adult dataset is loaded directly.
        from flextrees.datasets import adult
        return adult(ret_feature_names=True, categorical=categorical)
    except ImportError:
        # The dataset could not be imported; fall back to simulated data.
        print("未找到指定数据集，使用模拟数据")

        class Dataset:
            """Minimal stand-in holding a feature frame and a label series."""

            def __init__(self, X, y):
                self.X_data = X
                self.y_data = y

            def to_numpy(self):
                return self.X_data.to_numpy(), self.y_data.to_numpy()

        from sklearn.datasets import make_classification
        features, labels = make_classification(
            n_samples=1000, n_features=10, n_classes=2, random_state=42
        )
        feature_names = [f'x{col}' for col in range(features.shape[1])]
        features_df = pd.DataFrame(features, columns=feature_names)
        labels_sr = pd.Series(labels)

        # First 800 rows train, the remaining rows test.
        split_at = 800
        train_data = Dataset(features_df.iloc[:split_at], labels_sr.iloc[:split_at])
        test_data = Dataset(features_df.iloc[split_at:], labels_sr.iloc[split_at:])
        return train_data, test_data, feature_names

# 1. Load the data
# Plain string literal: the message has no placeholders, so no f-prefix is needed.
print("\n加载数据集...")
train_data, test_data, feature_names = load_dataset(categorical=False)
print(f"数据集加载完成，特征数量: {len(feature_names)}")


加载数据集...
数据集加载完成，特征数量: 109


In [22]:
# Show the repr of the loaded training split as this cell's output.
train_data

Dataset(X_data=len:22792
is_generator:False
iterable:[[4.800e+01 0.000e+00 0.000e+00 ... 1.000e+00 0.000e+00 1.000e+00]
 [2.500e+01 0.000e+00 0.000e+00 ... 0.000e+00 0.000e+00 1.000e+00]
 [4.100e+01 7.430e+03 0.000e+00 ... 1.000e+00 1.000e+00 0.000e+00]
 ...
 [4.200e+01 0.000e+00 0.000e+00 ... 1.000e+00 0.000e+00 1.000e+00]
 [3.300e+01 0.000e+00 0.000e+00 ... 1.000e+00 1.000e+00 0.000e+00]
 [3.100e+01 5.178e+03 0.000e+00 ... 1.000e+00 0.000e+00 1.000e+00]]
iterable_indexes:[    0     1     2 ... 22789 22790 22791]
storage:{}, y_data=len:22792
is_generator:False
iterable:[0 0 1 ... 1 0 1]
iterable_indexes:[    0     1     2 ... 22789 22790 22791]
storage:{})

In [23]:
def split_data_to_clients(data, n_clients=5, iid=True, random_state=None):
    """Partition a dataset into one subset per client.

    Args:
        data: Object exposing ``X_data`` and ``y_data`` attributes, each of
            which provides ``to_numpy()``.
        n_clients: Number of partitions to create; must be at least 1.
        iid: If True, rows are shuffled and cut into ``n_clients``
            consecutive chunks (the last client absorbs the remainder).
            If False, rows are distributed class by class: when there are at
            least as many classes as clients, 60% of each class goes to
            client ``class % n_clients`` and the rest is spread over all
            clients; otherwise every class is simply cut into ``n_clients``
            chunks.
        random_state: Optional seed. ``None`` (the default) draws from the
            global ``np.random`` state; any other value seeds a dedicated
            ``np.random.RandomState`` so the split is reproducible.

    Returns:
        A list of ``n_clients`` objects, each with ``X_data``
        (``pd.DataFrame``) and ``y_data`` (``pd.Series``).

    Raises:
        ValueError: If ``n_clients`` is smaller than 1.
    """
    if n_clients < 1:
        raise ValueError(f"n_clients must be >= 1, got {n_clients}")

    class ClientDataset:
        """Container for one client's features and labels."""

        def __init__(self, X, y):
            self.X_data = pd.DataFrame(X)
            self.y_data = pd.Series(y)

    def _chunks(idx):
        """Yield n_clients consecutive slices of idx; the last takes the remainder."""
        chunk_size = len(idx) // n_clients
        for i in range(n_clients):
            start_idx = i * chunk_size
            end_idx = start_idx + chunk_size if i < n_clients - 1 else len(idx)
            yield idx[start_idx:end_idx]

    # The np.random module and RandomState share the permutation/shuffle API,
    # so the unseeded path keeps consuming the global random state.
    rng = np.random if random_state is None else np.random.RandomState(random_state)

    X, y = data.X_data.to_numpy(), data.y_data.to_numpy()

    if iid:
        # IID split: random, (almost) equally sized chunks.
        per_client = list(_chunks(rng.permutation(len(X))))
    else:
        # Non-IID split: distribute class by class.
        classes = np.unique(y)
        buckets = [[] for _ in range(n_clients)]

        for c in classes:
            idx = np.where(y == c)[0]
            rng.shuffle(idx)

            if len(classes) >= n_clients:
                # At least as many classes as clients: each class is biased
                # towards one client. The modulo requires numeric labels.
                client_id = int(c % n_clients)
                cut = int(len(idx) * 0.6)
                buckets[client_id].extend(idx[:cut])

                # The remaining 40% is spread across all clients.
                remaining_idx = idx[cut:]
                rng.shuffle(remaining_idx)
                for i, chunk in enumerate(_chunks(remaining_idx)):
                    buckets[i].extend(chunk)
            else:
                # Fewer classes than clients: split each class evenly.
                for i, chunk in enumerate(_chunks(idx)):
                    buckets[i].extend(chunk)

        # Explicit integer dtype keeps an empty bucket usable as an index array.
        per_client = [np.asarray(bucket, dtype=int) for bucket in buckets]

    return [ClientDataset(X[idx], y[idx]) for idx in per_client]

    
# 2. Distribute the training data across the clients.
# The third argument is True only when DATA_DISTRIBUTION is exactly 'iid'.
print(f"\n将数据分发到 {N_CLIENTS} 个客户端 ({DATA_DISTRIBUTION} 分布)...")
client_data = split_data_to_clients(train_data, N_CLIENTS, DATA_DISTRIBUTION == 'iid')


将数据分发到 5 个客户端 (iid 分布)...


In [24]:
# 2. 树训练和规则提取函数
def train_local_model(client_data, model_params):
    """Train a local decision tree for one client and extract its rules.

    The client's data is split 80/20 (fixed seed) into a training part and a
    held-out part. A ``DecisionTreeClassifier`` is fitted on the training
    part, scored on the held-out part, and wrapped in a ``ConjunctionSet``
    from which the branches are read.

    Args:
        client_data: Object whose ``X_data`` and ``y_data`` attributes
            provide ``to_numpy()``.
        model_params: Mapping that may contain ``'max_depth'`` (default 5),
            ``'criterion'`` (default ``'gini'``) and ``'splitter'`` (default
            ``'best'``). Only a CART tree is built here; a ``'model_type'``
            entry is not consulted.

    Returns:
        dict with the keys ``'local_tree'``, ``'local_cs'``,
        ``'local_branches'``, ``'local_branches_df'``, ``'local_classes'``,
        ``'X_test'``, ``'y_test'``, ``'local_acc'`` and ``'local_f1'``.
    """
    n_samples = len(client_data.X_data)

    # Simplified: only a CART decision tree is created.
    clf = DecisionTreeClassifier(
        random_state=42,
        # Must be an int >= 2. The previous max(1.0, ...) produced the float
        # 1.0 for small clients, which scikit-learn reads as "a fraction of
        # 100% of the samples" and which therefore prevents any split.
        min_samples_split=max(2, int(0.02 * n_samples)),
        max_depth=model_params.get('max_depth', 5),
        criterion=model_params.get('criterion', 'gini'),
        splitter=model_params.get('splitter', 'best')
    )

    # Prepare the training data.
    X_data, y_data = client_data.X_data.to_numpy(), client_data.y_data.to_numpy()
    X_train, X_test, y_train, y_test = train_test_split(X_data, y_data, test_size=0.2, random_state=42)

    # Fit the model.
    clf.fit(X_train, y_train)

    # Score the model on the held-out part.
    y_pred = clf.predict(X_test)
    acc = accuracy_score(y_test, y_pred)
    f1 = f1_score(y_test, y_pred, average='macro')
    print(f"客户端本地模型 - 准确率: {acc:.4f}, F1分数: {f1:.4f}, 训练样本数: {len(X_train)}, 测试样本数: {len(X_test)}")

    # Positional feature names (x0, x1, ...); every feature is declared 'int'.
    # NOTE(review): how ConjunctionSet interprets feature_types is not visible
    # here - confirm 'int' is appropriate for continuous features.
    feature_names = [f'x{i}' for i in range(X_data.shape[1])]
    feature_types = ['int'] * len(feature_names)

    # Build the rule set from the fitted tree.
    local_cs = ConjunctionSet(
        feature_names=feature_names,
        original_data=X_train,
        pruning_x=X_train,
        pruning_y=y_train,
        model=[clf],  # list of models
        feature_types=feature_types,  # feature types
        amount_of_branches_threshold=3000,  # branch-count threshold
        minimal_forest_size=1,  # minimal forest size
        estimators=clf,  # estimator
        filter_approach='probability',  # filtering approach
        personalized=False  # whether to personalize
    )

    # Assemble the result.
    result = {
        'local_tree': clf,
        'local_cs': local_cs,
        'local_branches': local_cs.get_branches_list(),
        'local_branches_df': local_cs.get_conjunction_set_df().round(decimals=5),
        'local_classes': clf.classes_,
        'X_test': X_test,
        'y_test': y_test,
        'local_acc': acc,
        'local_f1': f1
    }

    return result


In [25]:
# 4. Train a local model on every client and extract its rules.
print("\n第1步: 训练本地模型并提取规则...")
# One result dict per client, in the same order as client_data.
client_models = []
for i, data in enumerate(client_data):
    print(f"\n客户端{i+1}训练中...")
    model = train_local_model(data, local_model_params)
    client_models.append(model)
print("第1步: 本地模型训练完成，规则已提取")



第1步: 训练本地模型并提取规则...

客户端1训练中...
客户端本地模型 - 准确率: 0.8542, F1分数: 0.7612, 训练样本数: 3646, 测试样本数: 912
Estimators: DecisionTreeClassifier(max_depth=5, min_samples_split=91, random_state=42)
AQUÍr

客户端2训练中...
客户端本地模型 - 准确率: 0.8564, F1分数: 0.7801, 训练样本数: 3646, 测试样本数: 912
Estimators: DecisionTreeClassifier(max_depth=5, min_samples_split=91, random_state=42)
AQUÍr

客户端3训练中...
客户端本地模型 - 准确率: 0.8410, F1分数: 0.7699, 训练样本数: 3646, 测试样本数: 912
Estimators: DecisionTreeClassifier(max_depth=5, min_samples_split=91, random_state=42)
AQUÍr

客户端4训练中...
客户端本地模型 - 准确率: 0.8432, F1分数: 0.7520, 训练样本数: 3646, 测试样本数: 912
Estimators: DecisionTreeClassifier(max_depth=5, min_samples_split=91, random_state=42)
AQUÍr

客户端5训练中...
客户端本地模型 - 准确率: 0.8487, F1分数: 0.7483, 训练样本数: 3648, 测试样本数: 912
Estimators: DecisionTreeClassifier(max_depth=5, min_samples_split=91, random_state=42)
AQUÍr
第1步: 本地模型训练完成，规则已提取


In [60]:
# 3. 规则评估和筛选函数
def evaluate_trees_on_client(client_model, all_models):
    """Score every model's tree on this client's held-out data.

    Args:
        client_model: Mapping providing the ``'X_test'`` and ``'y_test'``
            entries used as the evaluation set.
        all_models: Iterable of mappings, each providing a ``'local_tree'``
            with a ``predict`` method.

    Returns:
        A list with one ``(accuracy, macro_f1)`` tuple per entry of
        ``all_models``, in the same order.
    """
    X_holdout = client_model['X_test']
    y_holdout = client_model['y_test']

    def _score(tree):
        # Predict once and derive both metrics from the same predictions.
        predictions = tree.predict(X_holdout)
        accuracy = accuracy_score(y_holdout, predictions)
        macro_f1 = f1_score(y_holdout, predictions, average='macro')
        return (accuracy, macro_f1)

    return [_score(entry['local_tree']) for entry in all_models]

def filter_trees(evaluation_results, filter_params):
    """Select trees based on their evaluation results.

    Args:
        evaluation_results: Nested sequence indexed as
            ``[evaluating client][tree] -> (accuracy, f1)``.
        filter_params: Mapping with an optional ``'filter_method'``:
            ``'mean'`` (default) uses the mean of the per-tree averages as
            thresholds; ``'percentile'`` uses the ``'filter_value'``
            percentile (default 75); any other value uses the fixed
            ``'acc_threshold'`` (default 0.6) and ``'f1_threshold'``
            (default 0.5).

    Returns:
        List of int indices of the trees whose average accuracy AND average
        F1 reach their thresholds. If none qualifies, a single-element list
        holding the index of the tree with the highest average accuracy.

    Raises:
        ValueError: If ``evaluation_results`` is empty or not of shape
            ``(n_clients, n_trees, 2)``.
    """
    results = np.asarray(evaluation_results, dtype=float)
    if results.ndim != 3 or results.shape[2] != 2 or results.size == 0:
        raise ValueError(
            "evaluation_results must be a non-empty array-like of shape "
            f"(n_clients, n_trees, 2), got shape {results.shape}"
        )

    # Average over the evaluating clients -> one (accuracy, f1) row per tree.
    avg_results = np.mean(results, axis=0)
    avg_acc, avg_f1 = avg_results[:, 0], avg_results[:, 1]

    # Derive the thresholds from the chosen filtering method.
    filter_method = filter_params.get('filter_method', 'mean')
    if filter_method == 'mean':
        acc_threshold = np.mean(avg_acc)
        f1_threshold = np.mean(avg_f1)
    elif filter_method == 'percentile':
        percentile_value = filter_params.get('filter_value', 75)
        acc_threshold = np.percentile(avg_acc, percentile_value)
        f1_threshold = np.percentile(avg_f1, percentile_value)
    else:
        # Fall back to fixed thresholds.
        acc_threshold = filter_params.get('acc_threshold', 0.6)
        f1_threshold = filter_params.get('f1_threshold', 0.5)

    # Keep the trees that satisfy both thresholds.
    keep = (avg_acc >= acc_threshold) & (avg_f1 >= f1_threshold)
    selected_indices = [int(i) for i in np.flatnonzero(keep)]

    # If nothing qualifies, keep the single most accurate tree. The argmax
    # must run over the accuracy column of all trees, not over the
    # (accuracy, f1) pair of tree 0.
    if not selected_indices:
        selected_indices = [int(np.argmax(avg_acc))]

    print(f"筛选后选择了 {len(selected_indices)} 棵树，索引: {selected_indices}")
    return selected_indices


# 5. Evaluate every tree on every client's held-out data.
print("\n第2-5步: 筛选弱决策树...")
# all_evaluations[c][t] holds the result of tree t evaluated on client c.
all_evaluations = []
for i, client_model in enumerate(client_models):
    print(f"客户端 {i+1} 正在评估所有树...")
    eval_results = evaluate_trees_on_client(client_model, client_models)
    all_evaluations.append(eval_results)
    
# 6. Select the well-performing trees.
# NOTE(review): the fixed thresholds are passed at half their configured value, and
# filter_trees only reads them when filter_method is neither 'mean' nor 'percentile',
# so they are unused with the current FILTERING_METHOD - confirm the halving is intended.
filter_params = {
    'filter_method': FILTERING_METHOD,
    'acc_threshold': ACC_THRESHOLD / 2, 
    'f1_threshold': F1_THRESHOLD / 2
}
    
selected_indices = filter_trees(all_evaluations, filter_params)


第2-5步: 筛选弱决策树...
客户端 1 正在评估所有树...
客户端 2 正在评估所有树...
客户端 3 正在评估所有树...
客户端 4 正在评估所有树...
客户端 5 正在评估所有树...
筛选后选择了 2 棵树，索引: [1, 2]


In [61]:
# 4. 规则聚合函数
def aggregate_rules(client_models, selected_indices):
    """Aggregate the rules of the selected clients into a global model.

    Args:
        client_models: Sequence of mappings, each providing
            ``'local_branches'`` and ``'local_classes'``.
        selected_indices: Indices into ``client_models`` of the trees whose
            rules take part in the aggregation; must not be empty.

    Returns:
        The value returned by ``generate_cs_dt_branches_from_list`` for the
        selected branch lists, the union of their classes and ``TreeBranch``.

    Raises:
        ValueError: If ``selected_indices`` is empty.
    """
    if len(selected_indices) == 0:
        raise ValueError("selected_indices is empty; at least one tree is required for aggregation")

    # Keep only the rules of the selected trees.
    selected_models = [client_models[i] for i in selected_indices]
    client_branches = [model['local_branches'] for model in selected_models]

    # Union of all class labels; sorted so that the class order is
    # deterministic instead of depending on set iteration order.
    classes_ = set()
    for model in selected_models:
        classes_ |= set(model['local_classes'])
    classes_ = sorted(classes_)

    # Aggregate into the global model.
    return generate_cs_dt_branches_from_list(client_branches, classes_, TreeBranch)

# 7. Aggregate the selected rules and build the global model.
print("\n第6-9步: 聚合规则并构建全局模型...")
global_model = aggregate_rules(client_models, selected_indices)
print("全局模型构建完成")


第6-9步: 聚合规则并构建全局模型...
Estimators: None
Iteration 1: 16 conjunctions

Las reglas actuales son: 
Conjunction set length: 56
Conjunction set length after removing duplicates: 56
branches df aggreagator is not null:     0_upper  0_lower  1_upper  1_lower  2_upper  2_lower  3_upper  3_lower  \
0       inf     -inf   7139.5     -inf   2218.5     -inf     42.5     -inf   
1       inf     -inf   7139.5     -inf   2248.0   2218.5     42.5     -inf   
2       inf     -inf   7139.5     -inf   2248.0     -inf     42.5     -inf   
3       inf     -inf   7139.5     -inf   2218.5     -inf     42.5     -inf   
4       inf     -inf   7139.5     -inf   2248.0   2218.5     42.5     -inf   
5       inf     -inf   7139.5     -inf   2248.0     -inf     42.5     -inf   
6       inf     -inf   7139.5     -inf   2218.5     -inf     43.5     42.5   
7       inf     -inf   7139.5     -inf   2218.5     -inf      inf     43.5   
8       inf     -inf   7139.5     -inf   2248.0   2218.5      inf     42.5   
9      

In [62]:
# 5. 全局模型评估函数
def evaluate_global_model(global_model, test_data):
    """Report how the aggregated global model performs on the test set.

    Args:
        global_model: Indexable object; element 1 must provide
            ``predict(X, classes, branches_df)`` and element 2 is used as
            the branches table.
        test_data: Object whose ``to_numpy()`` returns ``(X, y)``.

    Returns:
        Tuple ``(accuracy, macro_f1)``. The same figures and a
        classification report are also printed.
    """
    features, labels = test_data.to_numpy()

    # Class indices are derived from the branches table.
    branches_df = global_model[2]
    branch_classes = get_classes_branches(branches_df)

    # Predict with the global model; the second return value is not used.
    predictor = global_model[1]
    predictions, _ = predictor.predict(features, branch_classes, branches_df)

    # Compute the performance metrics.
    accuracy = accuracy_score(labels, predictions)
    macro_f1 = f1_score(labels, predictions, average='macro')
    report = classification_report(labels, predictions)

    summary_lines = (
        "\n全局模型在测试集上的性能:",
        f"准确率: {accuracy:.4f}",
        f"宏平均F1: {macro_f1:.4f}",
        f"分类报告: \n{report}",
    )
    for line in summary_lines:
        print(line)

    return accuracy, macro_f1

def get_classes_branches(branches):
    """Return the class indices implied by a branches DataFrame.

    The number of classes is the length of the ``'probas'`` entry in the
    first row; the result is ``[0, 1, ..., n_classes - 1]``.
    """
    assert branches is not None
    first_probas = branches['probas'].iloc[0]
    n_classes = len(first_probas)
    return [*range(n_classes)]


In [64]:

# 8. Evaluate the global model on the test split.
print("\n第10步: 评估全局模型...")
# eval_results is the (accuracy, macro F1) tuple returned by evaluate_global_model.
eval_results = evaluate_global_model(global_model, test_data)
    
print("\n--- ICDTA4FL 简化Demo完成 ---")


第10步: 评估全局模型...

全局模型在测试集上的性能:
准确率: 0.8509
宏平均F1: 0.7757
分类报告: 
              precision    recall  f1-score   support

           0       0.87      0.94      0.91      7412
           1       0.76      0.56      0.65      2357

    accuracy                           0.85      9769
   macro avg       0.81      0.75      0.78      9769
weighted avg       0.84      0.85      0.84      9769


--- ICDTA4FL 简化Demo完成 ---
