In [1]:
%pip install tensorflow
%pip install scipy

import tensorflow as tf
import numpy as np
import json
import os
import random
from scipy.stats import dirichlet

[0mNote: you may need to restart the kernel to use updated packages.
[0mNote: you may need to restart the kernel to use updated packages.


2025-08-08 10:48:54.320607: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-08-08 10:48:54.323309: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-08-08 10:48:54.331405: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-08-08 10:48:54.348290: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1754650134.373047  118057 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1754650134.38

In [None]:
def preprocess_image(image):
    # 保持numpy数组格式，只进行归一化
    return image / 255.0

def generate_data(x, y, num_clients, alpha):
    num_classes = 10
    client_data = {str(i): {'x': [], 'y': []} for i in range(num_clients)}
    
    for k in range(num_classes):
        idx_k = [i for i, label in enumerate(y) if label == k]
        random.shuffle(idx_k)
        n_k = len(idx_k)
        if n_k == 0:
            continue
            
        proportions = dirichlet.rvs(alpha * np.ones(num_clients))[0]
        proportions = proportions / proportions.sum()  # 归一化
        
        # 按比例分配样本数
        proportions = (proportions * n_k).astype(int)
        
        # 改进的修正算法：随机分配剩余样本
        remaining = n_k - proportions.sum()
        if remaining > 0:
            # 随机选择客户端分配剩余样本
            indices = np.random.choice(num_clients, remaining, replace=False)
            for idx in indices:
                proportions[idx] += 1
        elif remaining < 0:
            # 随机减少样本
            indices = np.random.choice(num_clients, abs(remaining), replace=False)
            for idx in indices:
                if proportions[idx] > 0:
                    proportions[idx] -= 1
        
        # 分配样本
        start = 0
        for i, count in enumerate(proportions):
            if count > 0:
                client_data[str(i)]['x'].append(x[idx_k[start:start+count]])
                client_data[str(i)]['y'].append(y[idx_k[start:start+count]])
            start += count
    
    # 合并每个客户端的数据
    for client_id in client_data:
        if client_data[client_id]['x']:
            client_data[client_id]['x'] = np.concatenate(client_data[client_id]['x'], axis=0)
            client_data[client_id]['y'] = np.concatenate(client_data[client_id]['y'], axis=0)
        else:
            client_data[client_id]['x'] = np.array([])
            client_data[client_id]['y'] = np.array([])
    
    # 保证每个客户端至少有1个样本
    empty_clients = [cid for cid, data in client_data.items() if len(data['x']) == 0]
    for cid in empty_clients:
        # 从样本最多的客户端借一个
        max_cid = max(client_data, key=lambda k: len(client_data[k]['x']))
        client_data[cid]['x'] = client_data[max_cid]['x'][:1]
        client_data[cid]['y'] = client_data[max_cid]['y'][:1]
        client_data[max_cid]['x'] = client_data[max_cid]['x'][1:]
        client_data[max_cid]['y'] = client_data[max_cid]['y'][1:]
    
    return client_data

def get_cluster_id(labels):
    counts = np.bincount(labels, minlength=10)
    return int(np.argmax(counts))

# 修改输出格式，确保numpy数组被正确序列化
def numpy_to_list(data):
    if isinstance(data, np.ndarray):
        return data.tolist()
    elif isinstance(data, dict):
        return {k: numpy_to_list(v) for k, v in data.items()}
    elif isinstance(data, list):
        return [numpy_to_list(item) for item in data]
    return data

In [3]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
y_train = y_train.reshape(-1)
y_test = y_test.reshape(-1)

# 使用全部数据
train_size = int(len(x_train) * 1.0)
test_size = int(len(x_test) * 1.0)

# train_size = int(len(x_train) * 0.2)
# test_size = int(len(x_test) * 0.2)

print(train_size)
print(test_size)

60000
10000


In [4]:
# 随机选择数据
train_indices = np.random.choice(len(x_train), train_size, replace=False)
test_indices = np.random.choice(len(x_test), test_size, replace=False)

x_train = x_train[train_indices]
y_train = y_train[train_indices]
x_test = x_test[test_indices]
y_test = y_test[test_indices]

# 数据预处理并保持numpy数组格式
x_train = np.array([preprocess_image(img) for img in x_train])
x_test = np.array([preprocess_image(img) for img in x_test])
y_train = np.array(y_train)
y_test = np.array(y_test)

# 生成训练和测试数据
num_clients = 100
alpha = 0.1  # Dirichlet分布的参数，越小越non-IID

train_data = generate_data(x_train, y_train, num_clients, alpha)
test_data = generate_data(x_test, y_test, num_clients, alpha)

In [5]:
# 为每个客户端生成cluster_id
cluster_ids = {client_id: get_cluster_id(data['y']) 
                for client_id, data in train_data.items()}

# 保存数据前转换numpy数组为列表
train_output = {
    'user_data': numpy_to_list(train_data),
    'cluster_ids': list(cluster_ids.values()),
    'users': list(train_data.keys())
}

test_output = {
    'user_data': numpy_to_list(test_data),
    'cluster_ids': list(cluster_ids.values()),
    'users': list(train_data.keys())
}

In [6]:
# 打印统计信息
print(f"Number of clients: {num_clients}")
print(f"Average training samples per client: {np.mean([len(data['x']) for data in train_data.values()])}")
print(f"Average test samples per client: {np.mean([len(data['x']) for data in test_data.values()])}")
print(f"Cluster distribution: {np.bincount(list(cluster_ids.values()), minlength=100)}") 

# 打印每个客户端的数据量
print("\n每个客户端的数据量:")
for client_id in sorted(train_data.keys()):
    train_samples = len(train_data[client_id]['x'])
    test_samples = len(test_data[client_id]['x'])
    print(f"客户端 {client_id}: 训练集 {train_samples} 样本, 测试集 {test_samples} 样本")

Number of clients: 100
Average training samples per client: 600.0
Average test samples per client: 100.0
Cluster distribution: [10 13 13 12  9 10 11 11  3  8  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0]

每个客户端的数据量:
客户端 0: 训练集 150 样本, 测试集 98 样本
客户端 1: 训练集 370 样本, 测试集 58 样本
客户端 10: 训练集 857 样本, 测试集 114 样本
客户端 11: 训练集 786 样本, 测试集 82 样本
客户端 12: 训练集 943 样本, 测试集 46 样本
客户端 13: 训练集 153 样本, 测试集 353 样本
客户端 14: 训练集 44 样本, 测试集 31 样本
客户端 15: 训练集 716 样本, 测试集 14 样本
客户端 16: 训练集 203 样本, 测试集 58 样本
客户端 17: 训练集 131 样本, 测试集 37 样本
客户端 18: 训练集 1820 样本, 测试集 41 样本
客户端 19: 训练集 1739 样本, 测试集 134 样本
客户端 2: 训练集 171 样本, 测试集 36 样本
客户端 20: 训练集 159 样本, 测试集 126 样本
客户端 21: 训练集 2146 样本, 测试集 6 样本
客户端 22: 训练集 22 样本, 测试集 30 样本
客户端 23: 训练集 242 样本, 测试集 5 样本
客户端 24: 训练集 460 样本, 测试集 2 样本
客户端 25: 训练集 88 样

In [7]:
# 创建输出目录
os.makedirs('/root/learning-tangle/leaf/data/mnist/data/train', exist_ok=True)
os.makedirs('/root/learning-tangle/leaf/data/mnist/data/test', exist_ok=True)

# 保存数据
with open('/root/learning-tangle/leaf/data/mnist/data/train/data.json', 'w') as file:
    json.dump(train_output, file)
with open('/root/learning-tangle/leaf/data/mnist/data/test/data.json', 'w') as file:
    json.dump(test_output, file)
