In [1]:
import tensorflow as tf
import numpy as np
import json
import os
import random
from scipy.stats import dirichlet

2025-05-21 06:00:24.637084: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-21 06:00:24.641380: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2025-05-21 06:00:24.697069: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-05-21 06:00:24.697101: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-05-21 06:00:24.697159: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to regi

In [2]:
def preprocess_image(image):
    return (image / 255.0).tolist()  # 转换为列表

def generate_data(x, y, num_clients, alpha):
    num_classes = 100
    client_data = {str(i): {'x': [], 'y': []} for i in range(num_clients)}
    
    # 为每个类别生成分布
    for k in range(num_classes):
        # 获取当前类别的所有样本
        idx_k = [i for i, label in enumerate(y) if label == k]
        random.shuffle(idx_k)
        total_samples = len(idx_k)
        
        # 使用Dirichlet分布生成每个客户端的样本数量
        proportions = dirichlet.rvs(alpha * np.ones(num_clients))[0]
        proportions = proportions * total_samples
        
        # 计算每个客户端应该获得的样本数量
        proportions = [int(p) for p in proportions]
        
        # 分配样本给各个客户端
        start_idx = 0
        for i in range(num_clients):
            end_idx = start_idx + proportions[i]
            # 直接添加预处理后的图像（已经是列表格式）
            client_data[str(i)]['x'].extend([x[idx] for idx in idx_k[start_idx:end_idx]])
            client_data[str(i)]['y'].extend([int(y[idx]) for idx in idx_k[start_idx:end_idx]])
            start_idx = end_idx
    
    return client_data

def get_cluster_id(labels):
    counts = np.bincount(labels, minlength=100)
    return int(np.argmax(counts))

In [3]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data()
y_train = y_train.reshape(-1)
y_test = y_test.reshape(-1)

# 使用全部数据
train_size = int(len(x_train) * 1.0)
test_size = int(len(x_test) * 1.0)

# train_size = int(len(x_train) * 0.2)
# test_size = int(len(x_test) * 0.2)

print(train_size)
print(test_size)

50000
10000


In [4]:
# 随机选择数据
train_indices = np.random.choice(len(x_train), train_size, replace=False)
test_indices = np.random.choice(len(x_test), test_size, replace=False)

x_train = x_train[train_indices]
y_train = y_train[train_indices]
x_test = x_test[test_indices]
y_test = y_test[test_indices]

# 数据预处理并转换为列表
x_train = [preprocess_image(img) for img in x_train]
x_test = [preprocess_image(img) for img in x_test]
y_train = y_train.tolist()
y_test = y_test.tolist()

# 生成训练和测试数据
num_clients = 100
alpha = 100  # Dirichlet分布的参数，越小越non-IID

train_data = generate_data(x_train, y_train, num_clients, alpha)
test_data = generate_data(x_test, y_test, num_clients, alpha)

In [5]:
# 为每个客户端生成cluster_id
cluster_ids = {client_id: get_cluster_id(data['y']) 
                for client_id, data in train_data.items()}

# 准备输出数据
train_output = {
    'user_data': train_data,
    'cluster_ids': list(cluster_ids.values()),
    'users': list(train_data.keys())
}

test_output = {
    'user_data': test_data,
    'cluster_ids': list(cluster_ids.values()),
    'users': list(train_data.keys())
}

In [6]:
# 打印统计信息
print(f"Number of clients: {num_clients}")
print(f"Average training samples per client: {np.mean([len(data['x']) for data in train_data.values()])}")
print(f"Average test samples per client: {np.mean([len(data['x']) for data in test_data.values()])}")
print(f"Cluster distribution: {np.bincount(list(cluster_ids.values()), minlength=100)}") 

# 打印每个客户端的数据量
print("\n每个客户端的数据量:")
for client_id in sorted(train_data.keys()):
    train_samples = len(train_data[client_id]['x'])
    test_samples = len(test_data[client_id]['x'])
    print(f"客户端 {client_id}: 训练集 {train_samples} 样本, 测试集 {test_samples} 样本")

Number of clients: 100
Average training samples per client: 449.82
Average test samples per client: 48.67
Cluster distribution: [6 4 3 4 2 2 2 3 4 4 1 3 1 1 1 2 4 3 3 1 0 0 1 1 2 1 2 0 1 2 1 2 1 2 0 1 1
 3 0 1 0 0 2 0 1 1 1 0 0 1 0 2 0 0 0 1 1 0 1 0 0 0 1 1 0 1 0 0 0 0 0 0 0 1
 1 0 0 1 0 0 1 1 0 0 0 1 0 1 0 0 0 0 1 1 0 0 0 0 1 0]

每个客户端的数据量:
客户端 0: 训练集 448 样本, 测试集 42 样本
客户端 1: 训练集 458 样本, 测试集 43 样本
客户端 10: 训练集 435 样本, 测试集 42 样本
客户端 11: 训练集 449 样本, 测试集 52 样本
客户端 12: 训练集 447 样本, 测试集 41 样本
客户端 13: 训练集 452 样本, 测试集 51 样本
客户端 14: 训练集 445 样本, 测试集 53 样本
客户端 15: 训练集 455 样本, 测试集 44 样本
客户端 16: 训练集 451 样本, 测试集 56 样本
客户端 17: 训练集 450 样本, 测试集 47 样本
客户端 18: 训练集 452 样本, 测试集 47 样本
客户端 19: 训练集 452 样本, 测试集 49 样本
客户端 2: 训练集 450 样本, 测试集 59 样本
客户端 20: 训练集 450 样本, 测试集 40 样本
客户端 21: 训练集 455 样本, 测试集 47 样本
客户端 22: 训练集 443 样本, 测试集 45 样本
客户端 23: 训练集 445 样本, 测试集 38 样本
客户端 24: 训练集 453 样本, 测试集 48 样本
客户端 25: 训练集 448 样本, 测试集 42 样本
客户端 26: 训练集 439 样本, 测试集 51 样本
客户端 27: 训练集 446 样本, 测试集 47 样本
客户端 28: 训练集 458 样本, 测试集 53 样本

In [7]:
# 创建输出目录
os.makedirs('/root/learning-tangle/leaf/data/cifar100/data/train', exist_ok=True)
os.makedirs('/root/learning-tangle/leaf/data/cifar100/data/test', exist_ok=True)

# 保存数据
with open('/root/learning-tangle/leaf/data/cifar100/data/train/data.json', 'w') as file:
    json.dump(train_output, file)
with open('/root/learning-tangle/leaf/data/cifar100/data/test/data.json', 'w') as file:
    json.dump(test_output, file)