In [44]:
import warnings
from os import scandir
from os.path import join
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

from src.utils.utils import plot_disease_distribution

In [45]:
# Training split of the COVID-19 Radiography dataset.
train_data_df_path = "../datasets/csv_splits/COVID-19_Radiography_Dataset_train.csv"
train_data_df = pd.read_csv(train_data_df_path)

# Validation split of the same dataset.
val_data_df_path = "../datasets/csv_splits/COVID-19_Radiography_Dataset_val.csv"
val_data_df = pd.read_csv(val_data_df_path)

In [46]:
# Preview the first five rows of the training split.
train_data_df.head(n=5)

Unnamed: 0,Image Index,Finding Labels,Path,COVID,Lung_Opacity,Normal,Viral Pneumonia
0,COVID-1.png,COVID,COVID-19_Radiography_Dataset/COVID,1,0,0,0
1,COVID-1000.png,COVID,COVID-19_Radiography_Dataset/COVID,1,0,0,0
2,COVID-1001.png,COVID,COVID-19_Radiography_Dataset/COVID,1,0,0,0
3,COVID-1002.png,COVID,COVID-19_Radiography_Dataset/COVID,1,0,0,0
4,COVID-1004.png,COVID,COVID-19_Radiography_Dataset/COVID,1,0,0,0


In [47]:
def get_dirichlet_split(k: int, n: int) -> np.ndarray:
    """Get a random split of size `n` into `k` parts using a Dirichlet distribution.

    The proportions are drawn from a flat Dirichlet(1, ..., 1) with NumPy's global
    random state, so call `np.random.seed` beforehand to get a reproducible split.

    Args:
    -----
        k (int): Number of parts to split the elements into (clients).
        n (int): Number of elements to split (total number of images).

    Returns:
    --------
        np.ndarray: Integer array of length `k` with non-negative sizes that sum to `n`.

    Raises:
    -------
        ValueError: If `k` is smaller than 1 or `n` is negative.
    """
    if k < 1:
        raise ValueError("k must be at least 1.")
    if n < 0:
        raise ValueError("n must be non-negative.")

    proportions = np.random.dirichlet(np.ones(k), size=1)[0]

    # Scale the proportions to the number of elements and let the last part absorb the rounding error
    # so that the sizes sum to exactly n.
    split_sizes = np.round(proportions * n).astype(int)
    split_sizes[-1] = n - split_sizes[:-1].sum()

    # Rounding several parts up can make the first k - 1 sizes exceed n, which leaves a negative
    # remainder. Hand the excess back one element at a time from the currently largest part; splits
    # that were already valid are left untouched, so previously seeded results do not change.
    while split_sizes[-1] < 0:
        largest = np.argmax(split_sizes[:-1])
        split_sizes[largest] -= 1
        split_sizes[-1] += 1

    return split_sizes


def get_split(n_splits: int, unbalanced: bool, class_distribution: dict) -> Dict[str, np.ndarray]:
    """Get the per-class split sizes for the clients.

    Args:
    -----
        n_splits (int): Number of clients to split the data into.
        unbalanced (bool): Whether to split the data into unbalanced clients.
        class_distribution (dict): Dictionary containing the class names and their counts.

    Returns:
    --------
        Dict[str, np.ndarray]: For every class, an array of length `n_splits` holding how many
        elements of that class each client receives. Each array sums to the class count.
    """
    if unbalanced:
        # One independent Dirichlet draw per class, in the iteration order of `class_distribution`.
        return {cls: get_dirichlet_split(n_splits, count) for cls, count in class_distribution.items()}

    split_sizes = {}
    for cls, count in class_distribution.items():
        # Every client gets the same base share; the first `leftover` clients get one extra element.
        base, leftover = divmod(count, n_splits)
        sizes = np.full(n_splits, base)
        sizes[:leftover] += 1
        split_sizes[cls] = sizes

    return split_sizes


def split_targets(reamaining_clients: list, removed_images: dict, df: pd.DataFrame) -> dict:
    """Split the removed images into the remaining clients.

    For every source client, the removed images are grouped by class and each class is divided as
    evenly as possible among the remaining clients (the first clients receive the extra images when
    the count is not divisible). The parts of all classes are accumulated per receiving client.

    Args:
    -----
        reamaining_clients (list): List of clients to assign the removed images to.
        removed_images (dict): Dictionary containing the removed images for each client.
        df (pd.DataFrame): DataFrame containing the data with "Image Index" and "Finding Labels".

    Returns:
    --------
        dict: Dictionary mapping each source client to a dictionary that maps every remaining
        client to the array of images it receives from that source client.

    Raises:
    -------
        ValueError: If `reamaining_clients` is empty.
    """
    n_remaining = len(reamaining_clients)
    if n_remaining == 0:
        raise ValueError("reamaining_clients must contain at least one client.")

    splits = {idx: {} for idx in removed_images.keys()}
    for idx, imgs in removed_images.items():
        removed = df[df["Image Index"].isin(imgs)]
        # One entry per (image, label) pair, indexed by image, so classes are matched exactly
        # instead of by regex/substring search.
        labels = removed.set_index("Image Index")["Finding Labels"].str.split("|").explode()

        # Images already handed out for an earlier class (multi-label rows) are skipped so that an
        # image never ends up on two clients.
        assigned = set()
        for cls in labels.dropna().unique():
            cls_images = labels.index[(labels == cls).to_numpy()].unique()
            cls_images = cls_images[~cls_images.isin(list(assigned))].to_numpy()
            assigned.update(cls_images)

            # np.array_split yields `n_remaining` parts whose sizes differ by at most one.
            parts = np.array_split(cls_images, n_remaining)
            for client, part in zip(reamaining_clients, parts):
                if client in splits[idx]:
                    splits[idx][client] = np.concatenate([splits[idx][client], part])
                else:
                    splits[idx][client] = part

    return splits


def get_class_distribution(df: pd.DataFrame) -> Dict[str, int]:
    """Count how often each class occurs in the dataset.

    Args:
    -----
        df (pd.DataFrame): DataFrame containing the data with "Finding Labels".

    Returns:
    --------
        Dict[str, int]: Dictionary mapping each class name to its number of occurrences.
    """
    # A row labelled "A|B" contributes one occurrence to "A" and one to "B".
    per_label = df["Finding Labels"].str.split("|").explode()
    counts = per_label.value_counts()
    return counts.to_dict()


def random_fl_split(
    n_splits: int,
    df: pd.DataFrame,
    unbalanced: bool = False,
    extreme: bool = False,
    target_classes: Optional[Dict[int, List[str]]] = None,
    seed: int = 42,
) -> Tuple[pd.DataFrame, ...]:
    """
    Splits the dataset into `n_splits` clients using random assignment with optional unbalancing.

    The split is done class by class: `get_split` decides how many images of each class every
    client receives, and the images are then sampled without replacement from that class.
    NumPy's global random state is seeded with `seed`, and "Unbalanced: <flag>" is printed.

    Args:
    -----
        n_splits (int): Number of clients to split the data into.
        df (pd.DataFrame): DataFrame containing the data with "Image Index" and "Finding Labels".
        unbalanced (bool): If True, creates unbalanced splits.
        extreme (bool): Intended to remove the `target_classes` from the given clients. The removal
            step is not implemented yet, so this flag currently only triggers a warning.
        target_classes (Dict[int, List[str]]): Dictionary containing the target classes to be
            removed from each client. May be None. Only validated at the moment.
        seed (int): Seed for reproducibility.

    Returns:
    --------
        Tuple[pd.DataFrame, ...]: Tuple containing one DataFrame per client, in client order.

    Raises:
    -------
        ValueError: If `n_splits` is smaller than 1.
        AssertionError: If a client is asked to drop as many classes as exist in `df` (or more).

    Notes:
    ------
        Rows with several labels ("A|B") belong to every class they list, so such an image can be
        sampled for more than one client. With single-label data the client sets are disjoint.
    """
    np.random.seed(seed)

    if n_splits < 1:
        raise ValueError("n_splits must be at least 1.")

    cls_dist = get_class_distribution(df)

    # Normalise `target_classes` into a list of per-client class groups. None means "no targets";
    # a bare class name or flat list (older call style) is treated as one single group.
    if target_classes is None:
        target_groups = []
    elif isinstance(target_classes, dict):
        target_groups = list(target_classes.values())
    else:
        target_groups = [target_classes]

    for group in target_groups:
        n_targets = 1 if isinstance(group, str) else len(group)
        assert n_targets < len(cls_dist), "Number of target classes must be less than the available number of classes."

    print(f"Unbalanced: {unbalanced}")
    split_sizes = get_split(n_splits, unbalanced, class_distribution=cls_dist)

    # Exact label membership, consistent with how `get_class_distribution` counts the classes
    # (a regex/substring search would also match e.g. "Pneumonia" inside "Viral Pneumonia").
    labels = df["Finding Labels"].str.split("|")

    clients = {idx: [] for idx in range(n_splits)}
    for cls, sizes in split_sizes.items():
        has_cls = labels.apply(lambda found, cls=cls: isinstance(found, list) and cls in found)
        images = df[has_cls.to_numpy()]["Image Index"].values
        for i, size in enumerate(sizes):
            clients[i].extend(np.random.choice(images, size, replace=False))
            # Drop everything this client already holds so later clients cannot draw it again.
            images = np.setdiff1d(images, clients[i])

    if extreme and target_groups:
        # TODO: move the images of the target classes from the target clients to the remaining
        # clients (see `split_targets`). Until then make the missing step visible to the caller.
        warnings.warn(
            "extreme=True is not implemented yet: target_classes are validated but no images are moved.",
            stacklevel=2,
        )

    client_dfs = [df[df["Image Index"].isin(client)].reset_index(drop=True) for client in clients.values()]

    return tuple(client_dfs)

In [32]:
# Sanity check: draw one random 8-way split of the whole training set.
# No seed is set in this cell, so the sizes depend on the current global NumPy RNG state.
_split = get_dirichlet_split(k=8, n=len(train_data_df))
print(_split)

[2336 5093  856 7397  120  332  141  655]


In [33]:
# Sanity check: balanced per-class split sizes for 8 clients.
_cls_dist = get_class_distribution(df=train_data_df)
_split_sizes = get_split(n_splits=8, unbalanced=False, class_distribution=_cls_dist)

print(_split_sizes)

{'Normal': array([1020, 1019, 1019, 1019, 1019, 1019, 1019, 1019]), 'Lung_Opacity': array([602, 601, 601, 601, 601, 601, 601, 601]), 'COVID': array([362, 362, 362, 362, 361, 361, 361, 361]), 'Viral Pneumonia': array([135, 135, 135, 135, 134, 134, 134, 134])}


In [None]:
n_clients = 8

# Classes to remove per client index. Alternative configuration that was tried:
#   {1: ["COVID", "Lung_Opacity"], 2: ["Normal"]}
target_classes = {1: ["COVID"]}

# Seeds noted so far (presumably keyed by number of clients - TODO confirm):
#   2 -> 1651
#   4 -> not recorded
#   8 -> not recorded

# Unbalanced split of the training data.
client_dfs = random_fl_split(
    n_splits=n_clients,
    df=train_data_df,
    unbalanced=True,
    extreme=False,
    target_classes=target_classes,
    seed=1651,
)
# Balanced split of the validation data (swap in when exporting the validation CSVs):
# client_dfs = random_fl_split(n_clients, val_data_df, unbalanced=False, extreme=False, target_classes=target_classes, seed=1651)

Unbalanced: True


In [38]:
# Write one CSV per client; file names are 1-based while `idx` is 0-based.
for idx, cl_df in enumerate(client_dfs):
    cl_df.to_csv(
        f"../datasets/csv_splits/{n_clients}_clients/unbalanced_all_classes/CXR_covid_train_client_{idx + 1}.csv",
        index=False,
    )
    # Validation variant of the export:
    # cl_df.to_csv(f"../datasets/csv_splits/{n_clients}_clients/CXR_covid_val_client_{idx + 1}.csv", index=False)

In [43]:
# Check intersection between clients: every pair should share 0 images.
# Build each client's image set once instead of rebuilding it for every pair.
_client_images = [set(cl_df["Image Index"].unique()) for cl_df in client_dfs]

for i in range(n_clients):
    for j in range(i + 1, n_clients):
        print(f"Intersection between client {i} and client {j}: {len(_client_images[i].intersection(_client_images[j]))}")

Intersection between client 0 and client 1: 0
Intersection between client 0 and client 2: 0
Intersection between client 0 and client 3: 0
Intersection between client 0 and client 4: 0
Intersection between client 0 and client 5: 0
Intersection between client 0 and client 6: 0
Intersection between client 0 and client 7: 0
Intersection between client 1 and client 2: 0
Intersection between client 1 and client 3: 0
Intersection between client 1 and client 4: 0
Intersection between client 1 and client 5: 0
Intersection between client 1 and client 6: 0
Intersection between client 1 and client 7: 0
Intersection between client 2 and client 3: 0
Intersection between client 2 and client 4: 0
Intersection between client 2 and client 5: 0
Intersection between client 2 and client 6: 0
Intersection between client 2 and client 7: 0
Intersection between client 3 and client 4: 0
Intersection between client 3 and client 5: 0
Intersection between client 3 and client 6: 0
Intersection between client 3 and 