In [14]:
from os import makedirs, scandir
from os.path import isdir, join
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd

from src.utils.utils import plot_disease_distribution

In [3]:
train_data_df_path = "../datasets/csv_splits/COVID-19_Radiography_Dataset_train.csv"
val_data_df_path = "../datasets/csv_splits/COVID-19_Radiography_Dataset_val.csv"

train_data_df = pd.read_csv(train_data_df_path)
val_data_df = pd.read_csv(val_data_df_path)

In [4]:
train_data_df.head()

Unnamed: 0,Image Index,Finding Labels,Path,COVID,Lung_Opacity,Normal,Viral Pneumonia
0,COVID-1.png,COVID,COVID-19_Radiography_Dataset/COVID,1,0,0,0
1,COVID-1000.png,COVID,COVID-19_Radiography_Dataset/COVID,1,0,0,0
2,COVID-1001.png,COVID,COVID-19_Radiography_Dataset/COVID,1,0,0,0
3,COVID-1002.png,COVID,COVID-19_Radiography_Dataset/COVID,1,0,0,0
4,COVID-1004.png,COVID,COVID-19_Radiography_Dataset/COVID,1,0,0,0


In [161]:
def get_dirichlet_split(k: int, n: int) -> np.ndarray:
    """Get a random split of size `n` into `k` parts using a Dirichlet distribution.

    Args:
    -----
        k (int): Number of parts to split the elements into (clients).
        n (int): Number of elements to split (total number of images).

    Returns:
    --------
        np.ndarray: Array containing the split sizes.
    """
    split_sizes = np.random.dirichlet(np.ones(k), size=1)[0]

    # Scale the split sizes to the number of elements adjust the last split size to make sure the sum is equal to n
    split_sizes = np.round(split_sizes * n).astype(int)
    split_sizes[-1] = n - split_sizes[:-1].sum()
    return split_sizes


def get_split(n_splits: int, unbalanced: bool, class_distribution: dict) -> np.ndarray:
    """Get the split sizes for the clients.

    Args:
    -----
        n_splits (int): Number of clients to split the data into.
        unbalanced (bool): Whether to split the data into unbalanced clients.
        class_distribution (dict): Dictionary containing the class names and their counts.

    Returns:
    --------
        np.ndarray: Array containing the split sizes.
    """
    if unbalanced:
        split_sizes = {cls: 0 for cls in class_distribution.keys()}
        for cls in class_distribution.keys():
            split_sizes[cls] = get_dirichlet_split(n_splits, class_distribution[cls])
    else:
        split_sizes = {cls: np.full(n_splits, class_distribution[cls] // n_splits) for cls in class_distribution.keys()}
        for cls in class_distribution.keys():
            split_sizes[cls][:class_distribution[cls] % n_splits] += 1

    return split_sizes


def split_targets(reamaining_clients: list, removed_images: dict, df: pd.DataFrame) -> dict:
    """Split the removed images into the remaining clients.

    Args:
    -----
        reamaining_clients (list): List of clients to assign the removed images to.
        removed_images (dict): Dictionary containing the removed images for each client.

    Returns:
    --------
        dict: Dictionary containing the split images for each client.
    """
    splits = {idx: {} for idx in removed_images.keys()}
    for idx, imgs in removed_images.items():
        classes = df[df["Image Index"].isin(imgs)]["Finding Labels"].str.split("|").explode().unique()

        # Split classes into n remaining clients
        for cls in classes:
            filtered = df[df["Image Index"].isin(imgs)]
            filtered = filtered[filtered["Finding Labels"].str.contains(cls)]
            split_sizes = get_split(len(reamaining_clients), len(filtered), False)
            split_clients = np.split(filtered["Image Index"].values, np.cumsum(split_sizes)[:-1])

            for i, client in enumerate(reamaining_clients):
                splits[idx][client] = split_clients[i]
                
    return splits


def get_class_distribution(df: pd.DataFrame) -> Dict[str, int]:
    """Get the class distribution of the dataset.

    Args:
    -----
        df (pd.DataFrame): DataFrame containing the data with "Finding Labels".

    Returns:
    --------
        Dict[str, int]: Dictionary containing the class distribution.
    """
    class_distribution = df["Finding Labels"].str.split("|").explode().value_counts().to_dict()
    return class_distribution


def random_fl_split(
    n_splits: int,
    df: pd.DataFrame,
    unbalanced: bool = False,
    target_classes: Dict[int, Dict[int, List[str]]] = None,
    seed: int = 42,
) -> Tuple[pd.DataFrame]:
    """
    Splits the dataset into `n_splits` clients using random assignment with optional unbalancing.

    Args:
        n_splits (int): Number of clients to split the data into.
        df (pd.DataFrame): DataFrame containing the data with "Image Index" and "Finding Labels".
        unbalanced (bool): If True, creates unbalanced splits.
        extreme (bool): If True, applies extreme unbalancing based on `target_clients` and `target_classes`.
        target_classes (Dict[int, Dict[int, List[str]]]): Dictionary containing the target classes to be removed from each client.
        seed (int): Seed for reproducibility.

    Returns:
        Tuple[pd.DataFrame]: Tuple containing the DataFrames for each client.
    """
    np.random.seed(seed)
    
    if isinstance(target_classes, str):
        target_classes = [target_classes]
    
    cls_dist = get_class_distribution(df)
    split_sizes = get_split(n_splits, unbalanced, class_distribution=cls_dist)
    
    clients = {idx + 1: [] for idx in range(n_splits)}
    for cls, sizes in split_sizes.items():
        images = df[df["Finding Labels"].str.contains(cls)]["Image Index"].values
        for i, size in enumerate(sizes):
            clients[i + 1].extend(np.random.choice(images, size, replace=False))
            images = np.setdiff1d(images, clients[i + 1])

    if target_classes:
        for target_client, origin_clients in target_classes.items():
            for origin_client, classes in origin_clients.items():
                for cls in classes:
                    filtered = filter(lambda x: cls in x, clients[origin_client])
                    imgs = [img for img in filtered]
                    clients[target_client].extend(imgs)
                    clients[origin_client] = np.setdiff1d(clients[origin_client], imgs).tolist()
                    
    client_dfs = [df[df["Image Index"].isin(client)].reset_index(drop=True) for client in clients.values()]

    return tuple(client_dfs)

In [None]:
n_clients = 2

target_classes = {
    1: {
        2: ["COVID", "Viral Pneumonia"],
        3: ["COVID"],
        4: ["COVID"],
    },
    2: {
        3: ["Lung_Opacity"],
        4: ["Lung_Opacity"],
    },
    3: {
        2: ["Normal"],
        4: ["Normal"],
    },
    4: {
        # 2: ["Viral Pneumonia"],
        3: ["Viral Pneumonia"],
    },
}

# Seeds: 42, 1651

client_dfs = random_fl_split(n_clients, train_data_df, unbalanced=True, target_classes=target_classes, seed=864)

Unbalanced: True


In [190]:
client_dfs[0]["Finding Labels"].value_counts()

Finding Labels
COVID              2892
Normal             1378
Lung_Opacity        241
Viral Pneumonia     116
Name: count, dtype: int64

In [191]:
client_dfs[1]["Finding Labels"].value_counts()

Finding Labels
Lung_Opacity    4568
Name: count, dtype: int64

In [192]:
for idx, cl_df in enumerate(client_dfs):
    save_dir = f"../datasets/csv_splits/{n_clients}_clients/unbalanced_missing_classes/one_per_client_except_1"
    if not isdir(save_dir):
        print(f"Creating directory: {save_dir}")
        makedirs(save_dir)

    cl_df.to_csv(f"{save_dir}/CXR_covid_train_client_{idx + 1}.csv", index=False)
    # cl_df.to_csv(f"../datasets/csv_splits/{n_clients}_clients/CXR_covid_val_client_{idx + 1}.csv", index=False)

In [176]:
# Intersecting images
# len(set(client_dfs[2]["Image Index"].unique()).intersection(set(client_dfs[1]["Image Index"].unique())))

# Check intersection between clients
for i in range(n_clients):
    for j in range(i + 1, n_clients):
        print(f"Intersection between client {i} and client {j}: {len(set(client_dfs[i]['Image Index'].unique()).intersection(set(client_dfs[j]['Image Index'].unique())))}")

Intersection between client 0 and client 1: 0
Intersection between client 0 and client 2: 0
Intersection between client 0 and client 3: 0
Intersection between client 1 and client 2: 0
Intersection between client 1 and client 3: 0
Intersection between client 2 and client 3: 0
