In [None]:
from pathlib import Path

import hydra
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from hydra import compose, initialize
from omegaconf import DictConfig, OmegaConf

from federatedlearning.datasets.common import get_dataset

In [None]:
# Identifiers of the experiment / run directory under ../mlruns whose saved
# config should be inspected. Fill both in before running the next cell:
# its Hydra config path is built from them.
EXPERIMENT_ID, RUN_ID = "", ""

In [None]:
# Fail fast with a readable message: with the empty placeholders from the
# previous cell the path below collapses to "../mlruns///artifacts" and the
# resulting Hydra error does not point at the real cause.
if not EXPERIMENT_ID or not RUN_ID:
    raise ValueError(
        "EXPERIMENT_ID and RUN_ID must be set; they select "
        "../mlruns/<EXPERIMENT_ID>/<RUN_ID>/artifacts"
    )

# hydra global initialization: clear any previous initialization first so
# this cell can be re-executed in the same kernel.
hydra.core.global_hydra.GlobalHydra.instance().clear()
# config_path in initialize() must be relative
initialize(
    version_base="1.1",
    config_path=f"../mlruns/{EXPERIMENT_ID}/{RUN_ID}/artifacts",
    job_name="jupyterlab",
)
# compose() returns a DictConfig (not an OmegaConf instance).
cfg: DictConfig = compose(config_name="config")
print(OmegaConf.to_yaml(cfg))

In [None]:
# Get Dataset
train_dataset, _, client_groups = get_dataset(cfg)

# Get number of clients and classes.
num_clients = cfg.federatedlearning.num_clients
# Convert the labels once; np.asarray accepts lists, arrays and CPU tensors.
targets = np.asarray(train_dataset.targets)
num_classes = len(np.unique(targets))
# Labels are used directly as column indices below, so they must be 0..num_classes-1.
assert targets.min() >= 0 and targets.max() < num_classes, (
    f"expected integer labels in [0, {num_classes}), "
    f"got range [{targets.min()}, {targets.max()}]"
)

# Count how many samples of each class every client holds.
label_counts = np.zeros((num_clients, num_classes), dtype=int)
for client_id in range(num_clients):
    # int() keeps the original cast, in case client_groups stores indices as
    # non-int numbers (e.g. floats); works for sets, lists and arrays alike.
    client_indices = np.array(
        [int(idx) for idx in client_groups[client_id]], dtype=np.int64
    )
    label_counts[client_id] = np.bincount(
        targets[client_indices], minlength=num_classes
    )

# Fraction of each client's data belonging to each class. A client with no
# samples gets an all-zero row instead of NaN (0/0) and a RuntimeWarning.
client_totals = label_counts.sum(axis=1, keepdims=True)
label_ratios = np.divide(
    label_counts,
    client_totals,
    out=np.zeros(label_counts.shape, dtype=float),
    where=client_totals > 0,
)

# Visualize: heatmap of per-client class proportions (rows = clients, columns = classes).
fig, ax = plt.subplots(figsize=(15, 10))
sns.heatmap(
    label_ratios,
    ax=ax,
    annot=True,
    fmt=".2f",
    cmap="YlGnBu",
    xticklabels=[f"Class {i}" for i in range(num_classes)],
    yticklabels=[f"Client {i}" for i in range(num_clients)],
)

ax.set_title("Data Distribution per Client (Proportion)")
ax.set_xlabel("Class")
ax.set_ylabel("Client ID")
fig.tight_layout()

# Save this figure. Create the output directory first so savefig does not
# raise FileNotFoundError on a fresh environment.
output_dir = Path("/workspace/outputs")
output_dir.mkdir(parents=True, exist_ok=True)
save_path = output_dir / "data_distribution_per_client.png"
fig.savefig(save_path)

plt.show()