In [21]:
import torch
import pandas as pd
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder
import matplotlib.pyplot as plt
import seaborn as sns

from simple_einet.einet import Einet, EinetConfig
from simple_einet.layers.distributions.piecewise_linear import PiecewiseLinear
from simple_einet.dist import DataType, Domain

### Import, Preprocess and Split the Dataset 

Baseline: traditional (non-federated) training, where a single model is fitted on the full, centralized training set.

In [39]:
# Load the OpenML "adult" dataset (version 2) as a pandas DataFrame.
data = fetch_openml('adult', version=2, as_frame=True)

# Treat the literal '?' placeholder as missing, then drop every incomplete row.
df = data.frame.replace('?', np.nan)
df_clean = df.dropna()

# Separate the target column ('class') from the feature columns.
y = df_clean['class']
X = df_clean.drop(columns=['class'])


In [40]:
# Partition the feature columns by dtype: numeric vs. categorical/object.
numeric_features = list(X.select_dtypes(include=['int64', 'float64']).columns)
categorical_features = list(X.select_dtypes(include=['category', 'object']).columns)

# Report how many columns fall in each group, and which ones.
for group_label, group_columns in (
    ("Numerical", numeric_features),
    ("Categorical", categorical_features),
):
    print(f"{group_label} feat: ({len(group_columns)}): {group_columns}")

Numerical feat: (6): ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', 'hours-per-week']
Categorical feat: (8): ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country']


In [19]:
# Standardize the numeric columns.
# NOTE(review): the scaler and the encoders below are fitted on the full dataset,
# i.e. before the train/test split performed in the next cell, so statistics of
# the test rows leak into the training features. Consider fitting on the
# training split only.
X_numeric = StandardScaler().fit_transform(X[numeric_features])
X_numeric_df = pd.DataFrame(X_numeric, columns=numeric_features, index=X.index)

# Integer-encode every categorical column independently.
# LabelEncoder is a target encoder that accepts a single 1-D column, which is
# why it is applied per column here rather than as a ColumnTransformer step
# (a ColumnTransformer containing LabelEncoder cannot be fitted).
X_categorical_encoded = pd.DataFrame(index=X.index)
for col in categorical_features:
    le = LabelEncoder()
    X_categorical_encoded[col] = le.fit_transform(X[col].astype(str))

# Column order of X_processed: numeric features first, then categorical ones.
X_processed = pd.concat([X_numeric_df, X_categorical_encoded], axis=1)
y_encoded = LabelEncoder().fit_transform(y)

print(f"X shape after preprocessed: {X_processed.shape}")
print(f"Target unique: {np.unique(y_encoded)}")

X shape after preprocessed: (45222, 14)
Target unique: [0 1]


In [20]:
# Hold out 33% of the rows for testing; stratify on the encoded target and fix
# the seed so the split is reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X_processed.values,
    y_encoded,
    test_size=0.33,
    random_state=42,
    stratify=y_encoded,
)

# Torch copies of the splits: float32 features and int64 labels.
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)
y_test_tensor = torch.tensor(y_test, dtype=torch.long)

print(f"X Train shape: {X_train_tensor.shape}")
print(f"X Test shape: {X_test_tensor.shape}")

X Train shape: torch.Size([30298, 14])
X Test shape: torch.Size([14924, 14])


### Construct the feature domains for an Einet with piecewise-linear leaf distributions

In [23]:
# Build one Domain per feature, in the column order of X_processed
# (numeric features first, then categorical features).
all_features = numeric_features + categorical_features

domains = []
for feature_name in all_features:
    if feature_name not in numeric_features:
        # Categorical feature: discrete domain over the sorted encoded values
        # observed in X_processed.
        observed_values = sorted(X_processed[feature_name].unique())
        domains.append(Domain(data_type=DataType.DISCRETE, values=observed_values))
        continue
    # Numeric feature: continuous domain.
    domains.append(Domain(data_type=DataType.CONTINUOUS))

print(f"Defined {len(domains)} feature domains.")

Defined 14 feature domains.


### Configure Einet

In [32]:
# Reshape the training data to the [batch_size, channels, features] layout used
# for the EiNet input; the tabular data has a single channel.
X_train_reshaped = X_train_tensor[:, None, :]  # [batch_size, 1, features]

# Configure the EiNet with PiecewiseLinear leaf distributions.
config = EinetConfig(
    num_features=X_train_tensor.size(1),  # number of input features
    depth=2,                              # network depth
    num_sums=10,                          # number of sum nodes
    num_leaves=10,                        # number of leaf nodes
    num_repetitions=5,                    # number of repetitions
    num_classes=2,                        # target classes (<=50K, >50K)
    leaf_type=PiecewiseLinear,            # leaf distribution type
    leaf_kwargs={'alpha': 0.1},           # Laplace smoothing parameter
    dropout=0.0,
)

# Instantiate the model and report its parameter count.
model = Einet(config)
total_parameters = sum(param.numel() for param in model.parameters())
print(f"模型參數數量: {total_parameters}")

模型參數數量: 1119


In [33]:
# Initialize the model's base leaf layer from the (reshaped) training data and
# the per-feature domains defined above; this must run before training.
model.leaf.base_leaf.initialize(X_train_reshaped, domains)

Initializing PiecewiseLinear Leaf Layer: 100%|██████████| 5/5 [00:00<00:00,  8.41it/s]


### Train the model

In [194]:
# Cross-entropy over the per-class model outputs, optimized with Adam over all
# model parameters.
cross_entropy = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

def accuracy(model, X, y):
    """Return the classification accuracy of ``model`` on ``(X, y)`` in percent.

    The predicted class of each sample is the argmax over the last axis of
    ``model(X)``; the forward pass runs with gradient tracking disabled.
    The result is ``100 * correct / y.shape[0]``.
    """
    with torch.no_grad():
        predicted_classes = model(X).argmax(-1)
        num_correct = (predicted_classes == y).sum()
    num_samples = y.shape[0]
    return 100. * num_correct / num_samples


def f1(model, X, y, num_classes=None):
    """Return the macro-averaged F1 score of ``model`` on ``(X, y)`` in percent.

    Args:
        model: Callable returning per-class scores for ``X``; the predicted
            class of each sample is the argmax over the last axis.
        X: Model input.
        y: 1-D integer class labels (a tensor, or anything accepted by
            ``torch.as_tensor``).
        num_classes: Number of classes to average over. When ``None`` it is
            inferred from the largest label found in ``y`` *or* in the
            predictions, so a class that is only ever predicted (and never
            present in ``y``) still contributes an F1 of 0 to the average.

    Returns:
        The macro-averaged F1 as a float in ``[0, 100]``. Returns ``0.0`` when
        there is nothing to score (empty ``y`` or ``num_classes == 0``).
    """
    with torch.no_grad():
        predictions = model(X).argmax(-1)
    y = torch.as_tensor(y)

    if num_classes is None:
        if y.numel() == 0:
            return 0.0
        largest_label = max(torch.max(y).item(), torch.max(predictions).item())
        num_classes = int(largest_label) + 1

    f1_scores = []
    for c in range(num_classes):
        predicted_c = predictions == c
        actual_c = y == c
        tp = (predicted_c & actual_c).sum().item()
        fp = (predicted_c & ~actual_c).sum().item()
        fn = (~predicted_c & actual_c).sum().item()
        # F1 = 2PR / (P + R) = 2TP / (2TP + FP + FN); defined as 0 when the
        # class has no true positives, false positives or false negatives.
        denominator = 2 * tp + fp + fn
        f1_scores.append(2 * tp / denominator if denominator else 0.0)

    if not f1_scores:
        return 0.0
    return 100. * sum(f1_scores) / len(f1_scores)


# Give the test features the same singleton channel axis as the training data.
X_test_reshaped = X_test_tensor[:, None, :]

num_epochs = 10
print("Start training...")
for epoch in range(num_epochs):
    # One full-batch gradient step on the cross-entropy of the model outputs.
    optimizer.zero_grad()
    log_likelihoods = model(X_train_reshaped)
    loss = cross_entropy(log_likelihoods, y_train_tensor)
    loss.backward()
    optimizer.step()

    # Only report metrics on every fifth epoch.
    if (epoch + 1) % 5 != 0:
        continue
    acc_train = accuracy(model, X_train_reshaped, y_train_tensor)
    acc_test = accuracy(model, X_test_reshaped, y_test_tensor)
    print(f"Epoch: {epoch+1:2d}, Loss: {loss.item():.4f}, "
          f"Train Acc: {acc_train:.2f}%, Test Acc: {acc_test:.2f}%")

print("Finished training！")

Start training...
Epoch:  5, Loss: 0.4959, Train Acc: 64.07%, Test Acc: 62.89%
Epoch: 10, Loss: 0.4865, Train Acc: 64.86%, Test Acc: 64.05%
Finished training！


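### Going federated: constructing a data partitioner for
- Horizontal partitioning (same features, different samples per client)
- Vertical partitioning (same samples, different features per client)
- Hybrid partitioning (clients share some samples and some features)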

In [36]:
import random
from typing import Dict, List

In [84]:
# Robust federated-learning data partitioner (same as before).
class FederatedDataPartitionerRobust:
    """Split a dataset across simulated federated clients.

    Three schemes are offered: horizontal (same features, different samples),
    vertical (same samples, different features) and hybrid (clients share part
    of the samples and part of the features).

    Every method draws from a *local* random generator seeded with
    ``random_state``. The generators used (``numpy.random.RandomState`` and
    ``random.Random``) yield the same streams as seeding the corresponding
    global generators, so partitions are reproducible without clobbering the
    process-wide random state.
    """

    def __init__(self, X, y, feature_names, numeric_features, categorical_features):
        """Store the data and feature metadata.

        Args:
            X: 2-D array-like of shape ``[n_samples, n_features]`` that supports
                integer-array indexing (e.g. ``X[rows]`` and ``X[:, cols]``).
            y: 1-D array-like of labels aligned with the rows of ``X``.
            feature_names: List of column names, in the column order of ``X``.
            numeric_features: Names of the numeric columns.
            categorical_features: Names of the categorical columns.
        """
        self.X = X
        self.y = y
        self.feature_names = feature_names
        self.numeric_features = numeric_features
        self.categorical_features = categorical_features

    @staticmethod
    def _check_num_clients(num_clients: int, upper_bound=None, unit: str = "items") -> None:
        """Raise ``ValueError`` unless ``1 <= num_clients <= upper_bound`` (if given)."""
        if num_clients < 1:
            raise ValueError(f"num_clients must be >= 1, got {num_clients}")
        if upper_bound is not None and num_clients > upper_bound:
            raise ValueError(
                f"num_clients ({num_clients}) exceeds the number of available "
                f"{unit} ({upper_bound})"
            )

    @staticmethod
    def _split_base_and_overlap(shuffled_indices, base_total: int, overlap_count: int):
        """Cut shuffled indices into a base block and a following overlap pool.

        The first ``base_total`` indices (clamped to the array length) form the
        base block; up to ``overlap_count`` of the indices right after it form
        the overlap pool, which is empty when nothing is left or requested.
        """
        base_end = min(base_total, len(shuffled_indices))
        base = shuffled_indices[:base_end]
        if overlap_count > 0 and base_end < len(shuffled_indices):
            overlap = shuffled_indices[base_end:base_end + overlap_count]
        else:
            overlap = np.array([], dtype=int)
        return base, overlap

    def horizontal_partition(self, num_clients: int = 3, random_state: int = 42) -> Dict:
        """Horizontal split: every client gets all features but disjoint samples.

        Samples are shuffled and cut into ``num_clients`` contiguous chunks of
        ``n_samples // num_clients`` rows; the last client also receives the
        remainder.

        Raises:
            ValueError: If ``num_clients`` is below 1 or above the sample count.
        """
        n_samples = len(self.X)
        self._check_num_clients(num_clients, n_samples, "samples")
        print(f"🔄 執行水平分割，分成 {num_clients} 個客戶端...")

        # Local generator: same stream as np.random.seed(random_state), but the
        # global NumPy random state is left untouched.
        rng = np.random.RandomState(random_state)
        indices = np.arange(n_samples)
        rng.shuffle(indices)

        samples_per_client = n_samples // num_clients
        clients = {}

        for i in range(num_clients):
            start_idx = i * samples_per_client
            # The last client absorbs the rows left over by the integer division.
            end_idx = n_samples if i == num_clients - 1 else start_idx + samples_per_client
            client_indices = indices[start_idx:end_idx]

            clients[f'client_{i}'] = {
                'X': self.X[client_indices],
                'y': self.y[client_indices],
                'features': self.feature_names,
                'numeric_features': self.numeric_features,
                'categorical_features': self.categorical_features,
                'n_samples': len(client_indices),
                'n_features': len(self.feature_names),
                'sample_indices': client_indices,
                'feature_indices': list(range(len(self.feature_names))),
                'feature_overlap': self.feature_names  # all features are shared
            }

            print(f"  客戶端 {i}: {len(client_indices)} 樣本, {len(self.feature_names)} 特徵")

        return {
            'type': 'horizontal',
            'clients': clients,
            'total_samples': n_samples,
            'total_features': len(self.feature_names)
        }

    def vertical_partition(self, num_clients: int = 3, random_state: int = 42) -> Dict:
        """Vertical split: every client gets all samples but disjoint features.

        Feature names are shuffled and cut into ``num_clients`` contiguous
        chunks of ``n_features // num_clients`` names; the last client also
        receives the remainder.

        Raises:
            ValueError: If ``num_clients`` is below 1 or above the feature count.
        """
        self._check_num_clients(num_clients, len(self.feature_names), "features")
        print(f"🔄 執行垂直分割，分成 {num_clients} 個客戶端...")

        # Local generator: same stream as random.seed(random_state), but the
        # global `random` state is left untouched.
        rng = random.Random(random_state)
        all_features = list(self.feature_names)
        rng.shuffle(all_features)

        features_per_client = len(all_features) // num_clients
        clients = {}

        for i in range(num_clients):
            start_idx = i * features_per_client
            # The last client absorbs the features left over by the integer division.
            end_idx = len(all_features) if i == num_clients - 1 else start_idx + features_per_client

            client_features = all_features[start_idx:end_idx]
            client_numeric = [f for f in client_features if f in self.numeric_features]
            client_categorical = [f for f in client_features if f in self.categorical_features]

            # Column positions of the client's features in the original X.
            feature_indices = [self.feature_names.index(f) for f in client_features]

            clients[f'client_{i}'] = {
                'X': self.X[:, feature_indices],
                'y': self.y,
                'features': client_features,
                'numeric_features': client_numeric,
                'categorical_features': client_categorical,
                'n_samples': len(self.X),
                'n_features': len(client_features),
                'feature_indices': feature_indices,
                'sample_indices': list(range(len(self.X))),
                'feature_overlap': []  # vertical split: no shared features
            }

            print(f"  客戶端 {i}: {len(self.X)} 樣本, {len(client_features)} 特徵")

        return {
            'type': 'vertical',
            'clients': clients,
            'total_samples': len(self.X),
            'total_features': len(self.feature_names)
        }

    def hybrid_partition(self, num_clients: int = 4,
                        sample_overlap_ratio: float = 0.3,
                        feature_overlap_ratio: float = 0.2,
                        random_state: int = 42) -> Dict:
        """Hybrid split: clients own base samples/features and share an overlap pool.

        Samples: about half of the rows are divided evenly into per-client base
        blocks; a pool of up to ``n_samples * sample_overlap_ratio`` further
        rows is set aside, and each client additionally draws (without
        replacement) up to half its base-block size from that pool.

        Features: about 60% of the columns are divided evenly into per-client
        base blocks; a pool of up to ``n_features * feature_overlap_ratio``
        further columns is set aside, and every client receives the same first
        half (at least one) of that pool. A client that would end up without
        any feature falls back to column 0.

        Raises:
            ValueError: If ``num_clients`` is below 1.
        """
        self._check_num_clients(num_clients)
        print(f"🔄 執行健壯版混合分割，分成 {num_clients} 個客戶端...")
        print(f"  樣本重疊比例: {sample_overlap_ratio:.1%}")
        print(f"  特徵重疊比例: {feature_overlap_ratio:.1%}")

        # Local generator: same stream as np.random.seed(random_state), but the
        # global NumPy random state is left untouched. The order of the draws
        # below (sample shuffle, feature shuffle, per-client choice) is kept so
        # that partitions match earlier runs.
        rng = np.random.RandomState(random_state)

        n_samples = len(self.X)
        n_features = len(self.feature_names)

        # Sample allocation.
        base_samples_per_client = max(1, int(n_samples * 0.5 / num_clients))
        overlap_sample_count = int(n_samples * sample_overlap_ratio)
        all_sample_indices = np.arange(n_samples, dtype=int)
        rng.shuffle(all_sample_indices)
        base_sample_indices, overlap_sample_indices = self._split_base_and_overlap(
            all_sample_indices, base_samples_per_client * num_clients, overlap_sample_count
        )

        # Feature allocation.
        base_features_per_client = max(1, int(n_features * 0.6 / num_clients))
        overlap_feature_count = int(n_features * feature_overlap_ratio)
        all_feature_indices = np.arange(n_features, dtype=int)
        rng.shuffle(all_feature_indices)
        base_feature_indices, overlap_feature_indices = self._split_base_and_overlap(
            all_feature_indices, base_features_per_client * num_clients, overlap_feature_count
        )

        print(f"  基礎樣本: {len(base_sample_indices)} 個，重疊樣本池: {len(overlap_sample_indices)} 個")
        print(f"  基礎特徵: {len(base_feature_indices)} 個，重疊特徵池: {len(overlap_feature_indices)} 個")

        # Every client receives the same guaranteed slice of the overlap
        # features, so it is computed once outside the loop.
        if len(overlap_feature_indices) > 0:
            guaranteed_overlap_size = max(1, len(overlap_feature_indices) // 2)
            shared_overlap_features = overlap_feature_indices[:guaranteed_overlap_size]
        else:
            shared_overlap_features = np.array([], dtype=int)

        # Assign data to each client.
        clients = {}

        for i in range(num_clients):
            # Samples: the client's base block plus a random draw from the pool.
            client_base_samples = base_sample_indices[
                i * base_samples_per_client:(i + 1) * base_samples_per_client
            ]
            overlap_sample_size = min(len(overlap_sample_indices), len(client_base_samples) // 2)
            if overlap_sample_size > 0:
                client_overlap_samples = rng.choice(
                    overlap_sample_indices, size=overlap_sample_size, replace=False
                )
            else:
                client_overlap_samples = np.array([], dtype=int)
            client_sample_indices = np.unique(
                np.concatenate([client_base_samples, client_overlap_samples])
            )

            # Features: the client's base block plus the shared overlap slice.
            client_base_features = base_feature_indices[
                i * base_features_per_client:(i + 1) * base_features_per_client
            ]
            client_overlap_features = shared_overlap_features
            client_feature_indices = np.unique(
                np.concatenate([client_base_features, client_overlap_features]).astype(int)
            )

            # Guarantee at least one feature per client.
            if len(client_feature_indices) == 0:
                client_feature_indices = np.array([0], dtype=int)

            # Resolve feature names and their types.
            client_features = [self.feature_names[idx] for idx in client_feature_indices]
            client_numeric = [f for f in client_features if f in self.numeric_features]
            client_categorical = [f for f in client_features if f in self.categorical_features]

            # Names of the features this client shares with the other clients.
            overlap_features_names = [self.feature_names[idx] for idx in client_overlap_features]

            clients[f'client_{i}'] = {
                'X': self.X[np.ix_(client_sample_indices, client_feature_indices)],
                'y': self.y[client_sample_indices],
                'features': client_features,
                'numeric_features': client_numeric,
                'categorical_features': client_categorical,
                'n_samples': len(client_sample_indices),
                'n_features': len(client_features),
                'feature_indices': client_feature_indices,
                'sample_indices': client_sample_indices,
                'base_sample_count': len(client_base_samples),
                'overlap_sample_count': len(client_overlap_samples),
                'base_feature_count': len(client_base_features),
                'overlap_feature_count': len(client_overlap_features),
                'feature_overlap': overlap_features_names
            }

            print(f"  客戶端 {i}: {len(client_sample_indices)} 樣本 × {len(client_features)} 特徵")

        return {
            'type': 'hybrid_robust',
            'clients': clients,
            'total_samples': n_samples,
            'total_features': n_features,
            'sample_overlap_ratio': sample_overlap_ratio,
            'feature_overlap_ratio': feature_overlap_ratio
        }


In [208]:
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score
import time


def _accuracy(model, X, y):
    with torch.no_grad():
        outputs = model(X)
        predictions = outputs.argmax(-1) 
        correct = (predictions == y).sum()
        total = y.shape[0]
        return 100. * correct / total


def _f1_score(model, X, y, num_classes=None):
    with torch.no_grad():
        outputs = model(X)
        predictions = outputs.argmax(-1)
        # 假設 y 和 predictions 為 1D tensor
        if num_classes is None:
            num_classes = int(torch.max(y).item()) + 1
        f1_scores = []
        for c in range(num_classes):
            tp = ((predictions == c) & (y == c)).sum().item()
            fp = ((predictions == c) & (y != c)).sum().item()
            fn = ((predictions != c) & (y == c)).sum().item()
            if tp + fp + fn == 0:
                f1 = 0.0  # 防止0除
            else:
                precision = tp / (tp + fp) if (tp + fp) != 0 else 0.0
                recall = tp / (tp + fn) if (tp + fn) != 0 else 0.0
                if precision + recall == 0:
                    f1 = 0.0
                else:
                    f1 = 2 * (precision * recall) / (precision + recall)
            f1_scores.append(f1)
        # 取 macro-average F1
        return 100. * sum(f1_scores) / len(f1_scores)
    
# Federated-learning trainer built on the simple-einet API.
class FederatedEiNetTrainer:
    """Train one independent EiNet per client of a partition and ensemble them.

    ``partition_info`` is the dictionary returned by one of the partitioner's
    ``*_partition`` methods: it must provide ``'type'`` and ``'clients'``,
    where each client entry carries ``'X'``, ``'y'``, ``'features'``,
    ``'numeric_features'``, ``'categorical_features'``, ``'n_samples'`` and
    ``'n_features'``.

    No parameters are exchanged between clients; the "federated" model is the
    ensemble of the per-client models built in :meth:`evaluate_on_test`.
    """

    def __init__(self, partition_info: Dict):
        """Store the partition and prepare empty per-client registries."""
        self.partition_info = partition_info
        self.client_models = {}      # client_id -> trained model
        self.client_domains = {}     # client_id -> list of feature domains
        self.training_history = {}   # currently never populated by this class

    def create_domains(self, features: List[str], numeric_features: List[str],
                      categorical_features: List[str], X_processed: pd.DataFrame) -> List:
        """Build one ``Domain`` per entry of ``features``, in the same order.

        A feature listed in ``numeric_features`` gets a continuous range
        spanning the min/max of its column in ``X_processed``; any other
        feature gets discrete bins over the sorted unique values of its column.
        Features missing from ``X_processed`` fall back to the range
        ``[-3.0, 3.0]`` (numeric) or the bins ``[0, 1]`` (otherwise).

        ``categorical_features`` is accepted for interface symmetry but is not
        consulted: everything that is not numeric is treated as discrete.

        NOTE(review): the ranges come from whatever frame the caller passes;
        if that frame also contains test rows, the domains see test data.
        """
        domains = []
        available_columns = X_processed.columns

        for feature in features:
            has_column = feature in available_columns
            if feature in numeric_features:
                if has_column:
                    column = X_processed[feature]
                    domains.append(
                        Domain.continuous_range(float(column.min()), float(column.max()))
                    )
                else:
                    # Feature absent from X_processed: use a default range.
                    domains.append(Domain.continuous_range(-3.0, 3.0))
            elif has_column:
                values = sorted(X_processed[feature].unique().tolist())
                domains.append(Domain.discrete_bins(values))
            else:
                # Feature absent from X_processed: use default bins.
                domains.append(Domain.discrete_bins([0, 1]))

        return domains

    def train_client(self, client_id: str, client_data: Dict, X_processed: pd.DataFrame,
                    epochs: int = 100, verbose: bool = False) -> Dict:
        """Train an EiNet classifier on a single client's data.

        The model size is chosen from the client's feature count, the leaf
        layer is initialised from the client data and its domains, and the
        model is then trained with full-batch Adam on the cross-entropy loss.

        Args:
            client_id: Identifier echoed back in the result.
            client_data: One entry of ``partition_info['clients']``.
            X_processed: Frame used to derive the feature domains.
            epochs: Number of full-batch optimisation steps.
            verbose: When true, print the configuration, progress every fifth
                epoch, and the final metrics. When false, nothing is printed
                by this method.

        Returns:
            Dict with the trained ``'model'``, its ``'domains'`` and
            ``'config'``, the ``'training_time'`` in seconds, and the training
            ``'train_accuracy'`` / ``'train_f1'`` (both in percent).
        """
        X_client = client_data['X']
        y_client = client_data['y']

        # Insert the singleton channel axis: [batch, features] -> [batch, 1, features].
        X_client_reshaped = X_client.unsqueeze(1)

        # Domains for exactly the features this client holds.
        domains = self.create_domains(
            client_data['features'],
            client_data['numeric_features'],
            client_data['categorical_features'],
            X_processed,
        )

        num_features = client_data['n_features']

        # Scale the model complexity with the number of features.
        if num_features < 3:
            depth, num_sums, num_leaves = 1, 4, 4
        elif num_features < 6:
            depth, num_sums, num_leaves = 1, 8, 8
        else:
            depth, num_sums, num_leaves = 2, 12, 12

        config = EinetConfig(
            num_features=num_features,
            depth=depth,
            num_sums=num_sums,
            num_leaves=num_leaves,
            num_repetitions=3,
            num_classes=2,
            leaf_type=PiecewiseLinear,
            leaf_kwargs={'alpha': 0.1},
            dropout=0.0
        )

        model = Einet(config)
        model.leaf.base_leaf.initialize(X_client_reshaped, domains)

        cross_entropy = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

        if verbose:
            print(f"    📊 模型配置: depth={depth}, sums={num_sums}, leaves={num_leaves}")
            print(f"    🔧 特徵域: {len(domains)} 個 domain")

        start_time = time.time()

        for epoch in range(epochs):
            optimizer.zero_grad()
            ll = model(X_client_reshaped)
            loss = cross_entropy(ll, y_client)
            loss.backward()
            optimizer.step()

            # Progress report every fifth epoch; honours `verbose` so that a
            # quiet run really is quiet (and skips the extra forward passes).
            if verbose and (epoch + 1) % 5 == 0:
                acc_train = _accuracy(model, X_client_reshaped, y_client)
                f1_train = _f1_score(model, X_client_reshaped, y_client)
                print(f"Epoch: {epoch+1:2d}, Loss: {loss.item():.4f}, "
                      f"Train Acc: {acc_train:.2f}%, F1: {f1_train:.2f}%")

        training_time = time.time() - start_time

        train_accuracy = _accuracy(model, X_client_reshaped, y_client)
        train_f1 = _f1_score(model, X_client_reshaped, y_client)

        if verbose:
            print(f"    ✅ 訓練準確率: {train_accuracy:.3f}")
            print(f"    📈 訓練 F1 分數: {train_f1:.3f}")
            print(f"    ⏱️  訓練時間: {training_time:.3f} 秒")

        return {
            'client_id': client_id,
            'model': model,
            'domains': domains,
            'train_accuracy': train_accuracy,
            'train_f1': train_f1,
            'training_time': training_time,
            'config': config
        }

    def train_federated_learning(self, X_processed: pd.DataFrame, epochs: int = 100,
                               verbose: bool = True) -> Dict:
        """Train every client of the partition and aggregate training metrics.

        Each client is trained independently with :meth:`train_client`; the
        resulting models and domains are stored in ``self.client_models`` and
        ``self.client_domains``. The aggregate accuracy and F1 are averages of
        the per-client *training* metrics weighted by each client's sample
        count (samples shared between clients are counted once per client).

        Returns:
            Dict with the partition ``'type'``, per-client ``'client_results'``,
            ``'weighted_accuracy'``, ``'weighted_f1'`` (both in percent),
            ``'total_training_time'`` in seconds, ``'total_samples'`` and
            ``'num_clients'``.
        """
        print(f"\n🚀 開始 {self.partition_info['type']} 聯邦學習訓練...")
        print(f"訓練參數: epochs={epochs}")

        start_time = time.time()
        results = {}

        # Train every client in turn.
        for client_id, client_data in self.partition_info['clients'].items():
            if verbose:
                print(f"\n📍 訓練 {client_id}...")
                print(f"   資料規模: {client_data['n_samples']} 樣本 × {client_data['n_features']} 特徵")
                if client_data.get('feature_overlap'):
                    print(f"   🔗 重疊特徵: {len(client_data['feature_overlap'])} 個 {client_data['feature_overlap']}")

            client_result = self.train_client(
                client_id, client_data, X_processed, epochs, verbose=verbose
            )

            # Keep the trained model and its domains for later evaluation.
            self.client_models[client_id] = client_result['model']
            self.client_domains[client_id] = client_result['domains']

            results[client_id] = {
                'train_accuracy': client_result['train_accuracy'],
                'train_f1': client_result['train_f1'],
                'training_time': client_result['training_time'],
                'n_samples': client_data['n_samples'],
                'n_features': client_data['n_features'],
                'feature_overlap': client_data.get('feature_overlap', []),
                'config': client_result['config'],
                'domains_count': len(client_result['domains'])
            }

            if verbose:
                print(f"   🎯 {client_id} 訓練完成")

        # Sample-weighted averages of the per-client training metrics.
        total_samples = sum(r['n_samples'] for r in results.values())
        weighted_accuracy = sum(
            r['train_accuracy'] * r['n_samples'] for r in results.values()
        ) / total_samples

        weighted_f1 = sum(
            r['train_f1'] * r['n_samples'] for r in results.values()
        ) / total_samples

        total_time = time.time() - start_time

        if verbose:
            print(f"\n📊 {self.partition_info['type']} 聯邦學習完成！")
            print(f"   ⏱️  總訓練時間: {total_time:.2f} 秒")
            print(f"   🎯 加權平均訓練準確率: {weighted_accuracy:.3f}")
            print(f"   📈 加權平均 F1 分數: {weighted_f1:.3f}")
            print(f"   🏢 參與客戶端: {len(results)} 個")
            print(f"   📊 總樣本數: {total_samples}")

        return {
            'type': self.partition_info['type'],
            'client_results': results,
            'weighted_accuracy': weighted_accuracy,
            'weighted_f1': weighted_f1,
            'total_training_time': total_time,
            'total_samples': total_samples,
            'num_clients': len(results)
        }

    def evaluate_on_test(self, X_test, y_test, test_feature_names) -> Dict:
        """Evaluate every client model and their ensemble on a test set.

        Each client model only sees the test columns matching its own
        features. Two ensembles are formed from the per-client outputs:
        majority vote over predicted classes, and the argmax of the averaged
        class probabilities. Probabilities are the softmax of the model
        outputs (the same normalisation the cross-entropy loss applies during
        training), so every client contributes on the same scale.

        Args:
            X_test: 2-D test features (NumPy array or tensor); converted to a
                float32 tensor, the dtype the client models were trained on.
            y_test: 1-D integer test labels (NumPy array or tensor).
            test_feature_names: List of column names, in the column order of
                ``X_test``.

        Returns:
            Dict with ``'client_evaluations'``, ``'ensemble_accuracy'``,
            ``'ensemble_f1'``, ``'ensemble_predictions'``, ``'ensemble_method'``
            and ``'ensemble_candidates'`` (metrics of both ensemble methods).

        Notes:
            * Per-client metrics come from ``_accuracy`` / ``_f1_score`` and
              are percentages (macro F1); ensemble metrics come from
              scikit-learn and are fractions in ``[0, 1]`` (weighted F1). The
              two are therefore not directly comparable.
            * The reported ensemble is whichever method scores the higher
              accuracy *on this test set*, which makes the reported number
              optimistically biased; ``'ensemble_candidates'`` exposes both.
        """
        print(f"\n📋 在測試集上評估聯邦 EiNet 模型...")

        # Normalise the inputs once so arrays and tensors are both accepted.
        X_all = torch.as_tensor(X_test, dtype=torch.float32)
        y_tensor = torch.as_tensor(y_test).long()
        y_true = y_tensor.cpu().numpy()

        client_evaluations = {}
        predictions_ensemble = []
        probabilities_ensemble = []

        for client_id, model in self.client_models.items():
            client_data = self.partition_info['clients'][client_id]

            # Positions of the client's features among the test columns.
            client_feature_indices = [
                test_feature_names.index(feature)
                for feature in client_data['features']
                if feature in test_feature_names
            ]

            if len(client_feature_indices) == 0:
                print(f"   ⚠️  {client_id}: 沒有對應的測試特徵")
                continue

            # Select the client's columns and add the singleton channel axis.
            X_test_client_reshaped = X_all[:, client_feature_indices].unsqueeze(1)

            # Best-effort evaluation: a failing client is reported and skipped
            # so that the remaining clients can still be ensembled.
            try:
                acc = _accuracy(model, X_test_client_reshaped, y_tensor)
                fscore = _f1_score(model, X_test_client_reshaped, y_tensor)

                with torch.no_grad():
                    probs = torch.softmax(model(X_test_client_reshaped), dim=-1)
                predictions = probs.argmax(dim=-1)

                client_evaluations[client_id] = {
                    'accuracy': acc,
                    'f1_score': fscore,
                    'n_test_features': len(client_feature_indices),
                    'predictions': predictions,
                }

                predictions_ensemble.append(predictions.cpu().numpy())
                probabilities_ensemble.append(probs.cpu().numpy())

                print(f"   {client_id}: 準確率 {acc:.3f}, F1 {fscore:.3f} ({len(client_feature_indices)} 特徵)")

            except Exception as e:
                print(f"   ❌ {client_id}: 評估失敗 - {str(e)}")

        ensemble_candidates = {}

        # Ensemble predictions (majority vote and mean probability).
        if predictions_ensemble and probabilities_ensemble:
            # Majority vote across clients, per test sample.
            predictions_array = np.array(predictions_ensemble)
            ensemble_predictions_vote = np.apply_along_axis(
                lambda votes: np.bincount(votes).argmax(), axis=0, arr=predictions_array
            )

            # Argmax of the class probabilities averaged across clients.
            ensemble_probabilities = np.mean(probabilities_ensemble, axis=0)
            ensemble_predictions_prob = np.argmax(ensemble_probabilities, axis=1)

            vote_accuracy = accuracy_score(y_true, ensemble_predictions_vote)
            vote_f1 = f1_score(y_true, ensemble_predictions_vote, average='weighted')

            prob_accuracy = accuracy_score(y_true, ensemble_predictions_prob)
            prob_f1 = f1_score(y_true, ensemble_predictions_prob, average='weighted')

            ensemble_candidates = {
                'vote': {'accuracy': vote_accuracy, 'f1': vote_f1},
                'prob_mean': {'accuracy': prob_accuracy, 'f1': prob_f1},
            }

            # Report the better-scoring method (ties go to probability averaging).
            if prob_accuracy >= vote_accuracy:
                ensemble_accuracy = prob_accuracy
                ensemble_f1 = prob_f1
                ensemble_predictions = ensemble_predictions_prob
                ensemble_method = "機率平均"
            else:
                ensemble_accuracy = vote_accuracy
                ensemble_f1 = vote_f1
                ensemble_predictions = ensemble_predictions_vote
                ensemble_method = "投票"

        else:
            ensemble_accuracy = 0.0
            ensemble_f1 = 0.0
            ensemble_predictions = None
            ensemble_method = "無"

        print(f"\n🎯 集成結果 ({ensemble_method}):")
        print(f"   準確率: {ensemble_accuracy:.3f}")
        print(f"   F1 分數: {ensemble_f1:.3f}")

        return {
            'client_evaluations': client_evaluations,
            'ensemble_accuracy': ensemble_accuracy,
            'ensemble_f1': ensemble_f1,
            'ensemble_predictions': ensemble_predictions,
            'ensemble_method': ensemble_method,
            'ensemble_candidates': ensemble_candidates
        }


In [209]:
# Partition only the training split; feature names follow the column order of
# X_processed, which is also the column order of X_train_tensor.
partitioner = FederatedDataPartitionerRobust(
    X=X_train_tensor,
    y=y_train_tensor,
    feature_names=list(X_processed.columns),
    numeric_features=numeric_features,
    categorical_features=categorical_features,
)

In [210]:
# Experiment 1: horizontal federated learning, Simple-EiNet style
# (same features, different samples per client).
banner_rule = "=" * 60
print("\n" + banner_rule)
print("🔵 實驗 1: 水平聯邦學習 (使用 Simple-EiNet API)")
print("相同特徵，不同樣本")
print(banner_rule)

# Split the training data across three clients and train one model per client.
horizontal_partition = partitioner.horizontal_partition(num_clients=3, random_state=42)
horizontal_trainer = FederatedEiNetTrainer(horizontal_partition)
horizontal_results = horizontal_trainer.train_federated_learning(
    X_processed,
    epochs=5,
    verbose=True,
)

# Evaluate the client models and their ensemble on the held-out test split.
horizontal_eval = horizontal_trainer.evaluate_on_test(
    X_test,
    y_test,
    X_processed.columns.tolist(),
)


🔵 實驗 1: 水平聯邦學習 (使用 Simple-EiNet API)
相同特徵，不同樣本
🔄 執行水平分割，分成 3 個客戶端...
  客戶端 0: 10099 樣本, 14 特徵
  客戶端 1: 10099 樣本, 14 特徵
  客戶端 2: 10100 樣本, 14 特徵

🚀 開始 horizontal 聯邦學習訓練...
訓練參數: epochs=5

📍 訓練 client_0...
   資料規模: 10099 樣本 × 14 特徵
   🔗 重疊特徵: 14 個 ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', 'hours-per-week', 'workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country']


Initializing PiecewiseLinear Leaf Layer: 100%|██████████| 3/3 [00:00<00:00, 11.51it/s]


    📊 模型配置: depth=2, sums=12, leaves=12
    🔧 特徵域: 14 個 domain
Epoch:  5, Loss: 0.7386, Train Acc: 39.94%, F1: 39.77%
    ✅ 訓練準確率: 39.945
    📈 訓練 F1 分數: 39.771
    ⏱️  訓練時間: 1.039 秒
   🎯 client_0 訓練完成

📍 訓練 client_1...
   資料規模: 10099 樣本 × 14 特徵
   🔗 重疊特徵: 14 個 ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', 'hours-per-week', 'workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country']


Initializing PiecewiseLinear Leaf Layer: 100%|██████████| 3/3 [00:00<00:00, 12.29it/s]


    📊 模型配置: depth=2, sums=12, leaves=12
    🔧 特徵域: 14 個 domain
Epoch:  5, Loss: 0.7369, Train Acc: 33.70%, F1: 33.70%
    ✅ 訓練準確率: 33.696
    📈 訓練 F1 分數: 33.696
    ⏱️  訓練時間: 0.764 秒
   🎯 client_1 訓練完成

📍 訓練 client_2...
   資料規模: 10100 樣本 × 14 特徵
   🔗 重疊特徵: 14 個 ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', 'hours-per-week', 'workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country']


Initializing PiecewiseLinear Leaf Layer: 100%|██████████| 3/3 [00:00<00:00, 13.99it/s]


    📊 模型配置: depth=2, sums=12, leaves=12
    🔧 特徵域: 14 個 domain
Epoch:  5, Loss: 0.6639, Train Acc: 72.44%, F1: 56.13%
    ✅ 訓練準確率: 72.436
    📈 訓練 F1 分數: 56.126
    ⏱️  訓練時間: 0.767 秒
   🎯 client_2 訓練完成

📊 horizontal 聯邦學習完成！
   ⏱️  總訓練時間: 3.91 秒
   🎯 加權平均訓練準確率: 48.693
   📈 加權平均 F1 分數: 43.198
   🏢 參與客戶端: 3 個
   📊 總樣本數: 30298

📋 在測試集上評估聯邦 EiNet 模型...
   client_0: 準確率 46.187, F1 45.285 (14 特徵)
   client_1: 準確率 32.511, F1 32.490 (14 特徵)
   client_2: 準確率 73.935, F1 57.164 (14 特徵)

🎯 集成結果 (投票):
   準確率: 0.518
   F1 分數: 0.551
