In [None]:
import json
import os
import pickle

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from ray import tune

from fllib.datasets import DatasetCatalog

# Configure the seaborn theme in a single call. Calling `sns.set_style(...)`
# after `sns.set(font=...)` re-applies the style's own "font.family" default
# and silently discards the requested font, so the style, the font and the
# context override are all passed together here.
sns.set_theme(
    context="notebook",
    style="white",
    font="Times New Roman",
    font_scale=1,
    rc={"lines.linewidth": 4},
)
# Point size used for titles, axis labels and tick labels in the figures below.
font_size = 16

In [None]:
def partition_dataset(dataset_config):
    """Count how many training samples of each class every client receives.

    The dataset is built with ``DatasetCatalog.get_dataset(dataset_config)``.
    For each entry of ``dataset.client_datasets`` the labels selected by
    ``train_set.indices`` are looked up in ``train_set.dataset.targets`` and
    counted per class.

    Args:
        dataset_config: Dataset configuration, forwarded unchanged to
            ``DatasetCatalog.get_dataset``.

    Returns:
        dict: ``{"config": dataset_config, "result": (x, y, sizes)}``. The
        three lists are aligned element-wise: ``x[k]`` is the client index,
        ``y[k]`` a class id in ``range(10)`` and ``sizes[k]`` the number of
        that client's training samples with that label (0 if there are none).
    """
    # Only class ids in range(num_classes) are reported; other labels are ignored.
    num_classes = 10
    dataset = DatasetCatalog.get_dataset(dataset_config)

    x, y, sizes = [], [], []
    for client_id, client_dataset in enumerate(dataset.client_datasets):
        train_set = client_dataset.train_set
        # NOTE(review): relies on train_set exposing `.dataset.targets` (all
        # labels) and `.indices` (this client's selection) - confirm for
        # datasets other than the ones used in this notebook.
        # `as_tensor` avoids the copy (and warning) when targets is already a tensor.
        all_targets = torch.as_tensor(train_set.dataset.targets)

        # Force an integer dtype: an empty index list would otherwise become a
        # float tensor, which cannot be used for indexing and raises IndexError.
        subset_indices = torch.as_tensor(train_set.indices, dtype=torch.long)
        targets = all_targets[subset_indices]

        # Occurrences of each label that is present for this client.
        unique, counts = torch.unique(targets, return_counts=True)
        class_counts = dict(zip(unique.tolist(), counts.tolist()))

        # Emit one (client, class, count) triple per class, including zeros,
        # so every client contributes the same number of rows.
        for class_id in range(num_classes):
            x.append(client_id)
            y.append(class_id)
            sizes.append(class_counts.get(class_id, 0))

    return {"config": dataset_config, "result": (x, y, sizes)}

In [None]:
# Alpha values to sweep, in descending order; the plotting cell reuses this
# list as the facet column order.
alphas = [0.01, 0.1, 1.0, 10.0][::-1]
dataset_config = {
    "type": "CIFAR10",
    "splitter_config": {
        "type": "DirichletSplitter",
        "random_seed": 12345,
        "num_clients": 15,
        "alpha": tune.grid_search(alphas)
    },
}

# Run `partition_dataset` once per alpha value in the grid.
tuner = tune.Tuner(
    partition_dataset,
    param_space=dataset_config,
)
results = tuner.fit()

# Collect one long-format frame per trial. The resolved config and the
# reported metrics are read from the Result objects directly instead of
# unpickling params.pkl and parsing result.json from the trial directory.
dfs = []
for result in results:
    if result.error is not None:
        raise RuntimeError(f"Trial in {result.path} failed") from result.error

    client_ids, class_ids, counts = result.metrics["result"]
    dfs.append(pd.DataFrame({
        'x': client_ids,
        'y': class_ids,
        'sizes': counts,
        # Scalar is broadcast to every row of this trial.
        'alpha': result.config['splitter_config']['alpha'],
    }))
df_combined = pd.concat(dfs, ignore_index=True)

In [None]:

# Create the FacetGrid: one panel per alpha value, ordered as in `alphas`.
g = sns.FacetGrid(df_combined, col="alpha", col_wrap=4, col_order=alphas, height=4., sharey=True, sharex=False)

# FacetGrid calls the plotting function once per facet with that facet's rows
# only, so without an explicit `size_norm` every panel would rescale its
# markers to its own maximum and bubble areas could not be compared across
# panels. Normalise against the global maximum instead.
max_count = int(df_combined['sizes'].max())

# Draw a bubble scatter plot (client vs. label, area = sample count) in each panel.
g.map_dataframe(sns.scatterplot, 'x', 'y', size='sizes', hue='x', sizes=(0, 500),
    size_norm=(0, max_count), palette="muted",
    edgecolor="black", linewidth=1,
    legend=False,)

# Sorted tick positions; a plain set() gives no ordering guarantee.
client_ticks = sorted(df_combined['x'].unique().tolist())
label_ticks = sorted(df_combined['y'].unique().tolist())

for ax in g.axes.flat:
    # Draw the full frame around every panel.
    for spine in ax.spines.values():
        spine.set_visible(True)
    # The default facet title is "alpha = <value>"; keep the value and
    # render the name as a Greek letter.
    alpha_text = ax.get_title().split("=")[-1].strip()
    ax.set_title(r"$\alpha$ = " + alpha_text, fontdict={'weight': 'bold', 'size': font_size})
    ax.set_ylabel("Label", fontdict={'weight': 'bold', 'size': font_size})
    ax.set_xlabel("Client ID", fontdict={'weight': 'bold', 'size': font_size})
    ax.set_xticks(client_ticks, client_ticks, size=font_size, weight='bold')
    ax.set_yticks(label_ticks, label_ticks, size=font_size, weight='bold')

# Write the file before showing, so the saved image does not depend on what
# the active backend does with the figure once it has been displayed.
g.savefig('./dirichlet_partition.png', bbox_inches = "tight", pad_inches=0.01)
plt.show()

In [None]:
# Shard counts to sweep, in descending order; the plotting cell reuses this
# list as the facet column order.
num_shards = [50, 100, 500, 1000][::-1]
# NOTE(review): unlike the Dirichlet configuration, no "random_seed" is set
# here - confirm whether ShardSplitter needs one for reproducible splits.
dataset_config = {
    "type": "CIFAR10",
    "splitter_config": {
        "type": "ShardSplitter",
        "num_clients": 15,
        "num_shards": tune.grid_search(num_shards)
    },
}

# Run `partition_dataset` once per shard count in the grid.
tuner = tune.Tuner(
    partition_dataset,
    param_space=dataset_config,
)
results = tuner.fit()

# Collect one long-format frame per trial. The resolved config and the
# reported metrics are read from the Result objects directly instead of
# unpickling params.pkl and parsing result.json from the trial directory.
dfs = []
for result in results:
    if result.error is not None:
        raise RuntimeError(f"Trial in {result.path} failed") from result.error

    client_ids, class_ids, counts = result.metrics["result"]
    dfs.append(pd.DataFrame({
        'x': client_ids,
        'y': class_ids,
        'sizes': counts,
        # Scalar is broadcast to every row of this trial.
        'num_shards': result.config['splitter_config']['num_shards'],
    }))
df_combined = pd.concat(dfs, ignore_index=True)

In [None]:

# Create the FacetGrid: one panel per shard count, ordered as in `num_shards`.
g = sns.FacetGrid(df_combined, col="num_shards", col_wrap=4, col_order=num_shards, height=4., sharey=True, sharex=False)

# FacetGrid calls the plotting function once per facet with that facet's rows
# only, so without an explicit `size_norm` every panel would rescale its
# markers to its own maximum and bubble areas could not be compared across
# panels. Normalise against the global maximum instead.
max_count = int(df_combined['sizes'].max())

# Draw a bubble scatter plot (client vs. label, area = sample count) in each panel.
g.map_dataframe(sns.scatterplot, 'x', 'y', size='sizes', hue='x', sizes=(0, 500),
    size_norm=(0, max_count), palette="muted",
    edgecolor="black", linewidth=1,
    legend=False,)

# Sorted tick positions; a plain set() gives no ordering guarantee.
client_ticks = sorted(df_combined['x'].unique().tolist())
label_ticks = sorted(df_combined['y'].unique().tolist())

for ax in g.axes.flat:
    # Draw the full frame around every panel.
    for spine in ax.spines.values():
        spine.set_visible(True)
    # The default facet title is "num_shards = <value>"; keep only the value.
    shard_text = ax.get_title().split("=")[-1].strip()
    ax.set_title(shard_text + " shards", fontdict={'weight': 'bold', 'size': font_size})
    ax.set_ylabel("Label", fontdict={'weight': 'bold', 'size': font_size})
    ax.set_xlabel("Client ID", fontdict={'weight': 'bold', 'size': font_size})
    ax.set_xticks(client_ticks, client_ticks, size=font_size, weight='bold')
    ax.set_yticks(label_ticks, label_ticks, size=font_size, weight='bold')

# Write the file before showing, so the saved image does not depend on what
# the active backend does with the figure once it has been displayed.
# The extension is spelled out (matching the Dirichlet figure) so the output
# format does not depend on the `savefig.format` rcParam.
g.savefig('./shard_partition.png', bbox_inches = "tight", pad_inches=0.01)
plt.show()