In [111]:
import argparse
import copy
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.datasets import MNIST
# from models.Nets import *

class DynamicCNN(nn.Module):
    """CNN whose number of convolutional layers is chosen at construction time.

    The feature extractor stacks ``num_layers`` 3x3 convolutions (padding 1),
    each followed by an in-place ReLU. Conv layer ``i`` (0-based) has
    ``base_channels * 2 ** (i // 2)`` output channels, and a 2x2 max-pool is
    inserted after every second convolution. The classifier head is global
    average pooling, flatten, and one linear layer to ``num_classes`` logits.

    Args:
        num_layers: Number of convolutional layers; must be at least 1.
        input_channels: Channel count of the input images.
        base_channels: Output channels of the first convolution.
        num_classes: Number of output logits.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self, num_layers, input_channels=3, base_channels=32, num_classes=10):
        # Without this guard, num_layers == 0 yields a fractional in_features
        # for the linear head and fails with an unrelated-looking error.
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        super().__init__()
        self.features = self._make_conv_layers(num_layers, input_channels, base_channels)
        # Width of the last conv layer (index num_layers - 1).
        final_channels = base_channels * (2 ** ((num_layers - 1) // 2))
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(final_channels, num_classes),
        )

    def _make_conv_layers(self, num_layers, in_channels, base_channels):
        """Build the conv/ReLU/(max-pool) stack as an ``nn.Sequential``."""
        layers = []
        for i in range(num_layers):
            # Channel width doubles after every pair of conv layers.
            out_channels = base_channels * (2 ** (i // 2))
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
            if i % 2 == 1:
                # Downsample once per completed pair of conv layers.
                layers.append(nn.MaxPool2d(2))
            in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the class logits for the image batch ``x``."""
        x = self.features(x)
        x = self.classifier(x)
        return x

    def state_dict(self, exclude_classifier=True, **kwargs):
        """Return the module state, by default without the classifier head.

        Args:
            exclude_classifier: When True (default), entries whose key starts
                with ``classifier`` are dropped and a plain ``dict`` is
                returned. When False, the unfiltered result of
                ``nn.Module.state_dict`` is returned.
            **kwargs: Forwarded to ``nn.Module.state_dict`` (``destination``,
                ``prefix``, ``keep_vars``). Accepting them keeps this module
                usable as a child of another module, because a parent's
                ``state_dict`` calls its children with these keywords.

        Note:
            The filtered result lacks the classifier entries, so loading it
            back needs ``load_state_dict(..., strict=False)``. When a parent
            supplies ``destination``, that mapping is filled with the full
            state by the base implementation; only the returned dict is
            filtered.
        """
        full_state = super().state_dict(**kwargs)
        if not exclude_classifier:
            return full_state
        # Keys carry the caller-supplied prefix when invoked from a parent.
        head_prefix = kwargs.get('prefix', '') + 'classifier'
        return {k: v for k, v in full_state.items() if not k.startswith(head_prefix)}

    def get_conv_layers(self):
        """Return the ``nn.Conv2d`` modules of ``self.features`` in order."""
        return [m for m in self.features if isinstance(m, nn.Conv2d)]

def generate_client_layer_config(num_level, k, beta=1.0, seed=None):
    """Randomly assign one of the depths in ``num_level`` to each of ``k`` clients.

    A proportion vector is drawn from a symmetric Dirichlet distribution with
    concentration ``beta``. Each depth receives ``floor(proportion * k)``
    clients; the clients left over by the flooring are handed out to depths
    chosen uniformly at random (with replacement). The resulting assignment
    is shuffled in place.

    Args:
        num_level: Sequence of candidate layer counts.
        k: Total number of clients.
        beta: Dirichlet concentration used for every depth.
        seed: If not None, seeds NumPy's global random generator first.

    Returns:
        A tuple ``(client_layers, proportion)``: the shuffled list of per-client
        layer counts and the sampled proportions as a plain list.
    """
    if seed is not None:
        np.random.seed(seed)

    n_depths = len(num_level)
    proportion = np.random.dirichlet([beta for _ in range(n_depths)])

    per_depth = np.floor(proportion * k).astype(int)
    leftover = k - np.sum(per_depth)

    # Unbuffered add so a depth picked several times is incremented each time.
    lucky = np.random.choice(n_depths, leftover, replace=True)
    np.add.at(per_depth, lucky, 1)

    client_layers = [
        depth
        for how_many, depth in zip(per_depth, num_level)
        for _ in range(how_many)
    ]
    np.random.shuffle(client_layers)

    return client_layers, proportion.tolist()
    
def compare_models_aggregatable_layers(m1, m2):
    """Count the leading conv layers of ``m1`` and ``m2`` with equal weight shapes.

    The conv layers returned by each model's ``get_conv_layers()`` are walked
    in lockstep; counting stops at the first pair whose ``weight.shape``
    differs, or when the shorter list is exhausted.

    Returns:
        The length of the matching prefix as an int.
    """
    matching = 0
    for conv_a, conv_b in zip(m1.get_conv_layers(), m2.get_conv_layers()):
        if conv_a.weight.shape != conv_b.weight.shape:
            break
        matching += 1
    return matching

def get_shared_keys(models):
    """Return the non-classifier state-dict keys present in every model.

    Args:
        models: Iterable of objects exposing ``state_dict()``.

    Returns:
        A list of the keys found in all models' state dicts, excluding any key
        containing ``'classifier'``. Keys are listed in the order they appear
        in the first model's state dict, so the result is reproducible across
        runs. An empty ``models`` yields an empty list.
    """
    key_lists = [list(m.state_dict().keys()) for m in models]
    if not key_lists:
        # set.intersection() with no arguments would raise a TypeError.
        return []
    shared = set(key_lists[0]).intersection(*key_lists[1:])
    # Walk the first model's keys so the output order does not depend on
    # set iteration order.
    return [k for k in key_lists[0] if k in shared and 'classifier' not in k]

class CNNMnist(nn.Module):
    """Two-convolution, two-linear network that also exposes intermediate tensors.

    ``forward`` returns a dict rather than a single tensor so callers can read
    the pooled conv activation, its flattened form, the hidden representation
    and the final logits.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, kernel_size=5)
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5)
        self.fc1 = nn.Linear(800, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Run the network and return its intermediate and final tensors.

        Returns:
            A dict with keys ``'activation'`` (output of the second
            conv/pool/ReLU stage), ``'hint'`` (that activation flattened per
            sample), ``'representation'`` (ReLU of ``fc1``) and ``'output'``
            (result of ``fc2``).
        """
        stage1 = F.relu(F.max_pool2d(self.conv1(x), 2))
        activation = F.relu(F.max_pool2d(self.conv2(stage1), 2))

        # Collapse channel and spatial dimensions into one feature axis.
        flat_features = activation.shape[1] * activation.shape[2] * activation.shape[3]
        hint = activation.view(-1, flat_features)

        representation = F.relu(self.fc1(hint))
        logits = self.fc2(representation)

        return {
            'activation': activation,
            'hint': hint,
            'representation': representation,
            'output': logits,
        }

if __name__ == "__main__":
    # Demo: sample heterogeneous client depths, build one DynamicCNN per
    # client, and copy the deepest client as the starting global model.
    num_level = [2, 4, 6]
    k = 10
    beta = 0.7

    num_layers_list, prop = generate_client_layer_config(num_level, k, beta, seed=42)

    models = [DynamicCNN(n) for n in num_layers_list]
    print(models[2].state_dict().keys())

    # Group client indices by the number of entries in their state dict
    # (DynamicCNN.state_dict() leaves out the classifier head by default).
    layer_groups = defaultdict(list)
    for i, model in enumerate(models):
        layer_groups[len(model.state_dict())].append(i)
    layer_types = sorted(layer_groups)

    # The first client in the largest group is the first client whose state
    # dict has the maximum number of entries, so no second scan is needed.
    max_layer = max(layer_types)
    w_avg = copy.deepcopy(models[layer_groups[max_layer][0]])
    print(len(w_avg.state_dict()))

    max_layers = max(num_layers_list)

    # Placeholder for per-client counts of aggregatable conv layers; it is
    # not populated in this script.
    aggregatable_layers_per_model = []

客户端层数分配: [2, 4, 6, 4, 6, 4, 4, 4, 4, 4]
分布比例: [0.19514094537368495, 0.7499062374706622, 0.05495281715565283]
dict_keys(['features.0.weight', 'features.0.bias', 'features.2.weight', 'features.2.bias', 'features.5.weight', 'features.5.bias', 'features.7.weight', 'features.7.bias', 'features.10.weight', 'features.10.bias', 'features.12.weight', 'features.12.bias'])
12


In [94]:
def _client_state_dict(client):
    """Return ``client.state_dict()`` for module-like objects, else ``client`` itself."""
    getter = getattr(client, "state_dict", None)
    return getter() if callable(getter) else client


def hierarchical_aggregation(w, lens=None, shared_keys=None):
    """Average client parameters key by key across clients of different depth.

    Each parameter is averaged only over the clients that actually hold it,
    so shallow clients contribute to the early layers while the deeper layers
    are averaged over the deeper clients alone.

    Args:
        w: Non-empty sequence of client updates. Each entry is either a
            state-dict-like mapping (key -> tensor) or an object exposing
            ``state_dict()``.
        lens: Optional per-client sample counts, one per entry of ``w``. When
            given, each key is averaged with weights proportional to the
            counts of the clients holding that key; otherwise those clients
            are weighted equally.
        shared_keys: Optional collection of keys to aggregate. Keys outside
            it keep the value from the template client.

    Returns:
        A deep copy of the state dict of the first client with the most
        entries, with every aggregated key replaced by the weighted average.
        Inputs are not modified. Non-floating tensors are averaged in float32
        and cast back to their original dtype.

    Raises:
        ValueError: If ``w`` is empty, ``lens`` does not match ``w`` in
            length, or the sample counts of the clients holding a key do not
            sum to a positive number.
    """
    if not w:
        raise ValueError("hierarchical_aggregation requires at least one client")
    if lens is not None and len(lens) != len(w):
        raise ValueError("lens must contain one sample count per client")

    states = [_client_state_dict(client) for client in w]

    # max() returns the first maximal element, i.e. the first deepest client.
    template = max(states, key=len)
    w_avg = copy.deepcopy(template)

    allowed = None if shared_keys is None else set(shared_keys)

    for key, reference in template.items():
        if allowed is not None and key not in allowed:
            continue

        # Only clients that own this key with a compatible shape take part;
        # the template client always qualifies, so the list is never empty.
        holders = [
            i for i, state in enumerate(states)
            if key in state and state[key].shape == reference.shape
        ]

        # Normalise the weights over exactly the participating clients so
        # they always sum to one for this key.
        if lens is None:
            weights = [1.0 / len(holders)] * len(holders)
        else:
            total = sum(lens[i] for i in holders)
            if total <= 0:
                raise ValueError(f"sample counts for key {key!r} must sum to a positive number")
            weights = [lens[i] / total for i in holders]

        # Integer tensors cannot accumulate fractional weights in place, so
        # they are summed in float32 and cast back afterwards.
        if reference.is_floating_point() or reference.is_complex():
            acc_dtype = reference.dtype
        else:
            acc_dtype = torch.float32
        accumulator = torch.zeros_like(reference, dtype=acc_dtype)
        for i, weight in zip(holders, weights):
            contribution = states[i][key].to(device=accumulator.device, dtype=acc_dtype)
            accumulator += contribution * weight

        w_avg[key] = accumulator.to(reference.dtype)

    return w_avg