In [1]:
from PIL import Image
import os
from torch.utils.data import DataLoader, ConcatDataset, random_split, Dataset
from torchvision.transforms import Compose, Normalize, ToTensor, Resize
import torch
from collections import Counter

In [2]:
class CustomDataset(Dataset):
    """Single-class image folder dataset.

    Every regular file found directly inside ``data_dir`` is treated as an
    image, and every sample carries the same ``label``.

    Args:
        data_dir: Directory whose files are the images of one class.
        label: Label attached to every sample from this directory.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned.
    """

    def __init__(self, data_dir, label, transform=None):
        self.data_dir = data_dir
        self.label = label
        self.transform = transform
        # os.listdir() yields entries in arbitrary order and also lists
        # sub-directories. Sorting keeps the index -> file mapping stable
        # between runs (index-based Subsets built on top of this dataset rely
        # on that), and the isfile() filter keeps __getitem__ from trying to
        # open a directory as an image.
        candidate_paths = (
            os.path.join(data_dir, fname) for fname in sorted(os.listdir(data_dir))
        )
        self.image_paths = [path for path in candidate_paths if os.path.isfile(path)]

    def __len__(self):
        """Return the number of image files in this dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``{"img": image, "label": label}``."""
        img_path = self.image_paths[idx]
        # convert() produces an independent, fully loaded image, so the
        # underlying file handle can be released as soon as the block exits
        # instead of lingering until garbage collection.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return {"img": image, "label": self.label}

In [3]:
num_partitions = 4


def _build_split(data_root, split, transform):
    """Concatenate the three per-class folders found under ``data_root/split``."""
    class_labels = (("Normal", 0), ("Tuberculosis", 1), ("Pneumonia", 2))
    return ConcatDataset(
        [
            CustomDataset(
                os.path.join(data_root, split, class_name),
                label=class_label,
                transform=transform,
            )
            for class_name, class_label in class_labels
        ]
    )


def load_data(
    batch_size: int,
    data_root: str = "C:/Users/14871/Downloads/data",
    seed=42,
):
    """Build the client partitions of the training set and the shared test set.

    Args:
        batch_size: Accepted for backward compatibility; this function builds
            no DataLoader, so the value is not used.
        data_root: Directory containing ``train/`` and ``test/`` folders, each
            with ``Normal``, ``Tuberculosis`` and ``Pneumonia`` sub-folders.
        seed: Seed for the partitioning RNG so that repeated calls produce the
            same partitions. Pass ``None`` to draw from the global torch RNG.

    Returns:
        A ``(partitions, test_dataset)`` tuple. ``partitions`` holds
        ``num_partitions`` subsets of the training set; ``test_dataset`` is
        the concatenation of the three test folders.
    """
    pytorch_transforms = Compose(
        [Resize((256, 256)), ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    full_train_dataset = _build_split(data_root, "train", pytorch_transforms)
    test_dataset = _build_split(data_root, "test", pytorch_transforms)

    # Equal-sized partitions; the last one absorbs whatever does not divide evenly.
    partition_size, remainder = divmod(len(full_train_dataset), num_partitions)
    partition_sizes = [partition_size] * num_partitions
    partition_sizes[-1] += remainder

    if seed is None:
        partitions = random_split(full_train_dataset, partition_sizes)
    else:
        # A dedicated generator makes the split reproducible without touching
        # the global RNG state.
        partitions = random_split(
            full_train_dataset,
            partition_sizes,
            generator=torch.Generator().manual_seed(seed),
        )

    return partitions, test_dataset

In [4]:
# Build the partitions once, then persist a train/validation split per partition.
SPLIT_SEED = 42
partitions, test_dataset = load_data(32)

# The test set is shared by all partitions, so it is written once rather than
# being rewritten on every loop iteration.
torch.save(test_dataset, "testset.pt")

for i, partition_train_data in enumerate(partitions):
    # 80/20 train/validation split, seeded per partition for reproducibility.
    train_size = int(0.8 * len(partition_train_data))
    val_size = len(partition_train_data) - train_size
    partition_train, partition_val = random_split(
        partition_train_data,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(SPLIT_SEED + i),
    )

    # torch.save pickles the subset objects (indices plus the underlying
    # dataset with its file paths), not the pixel data: loading them back
    # needs CustomDataset to be defined and the image files to still exist.
    train_dataset_name = f"train{i}dataset.pt"
    val_dataset_name = f"val{i}dataset.pt"
    torch.save(partition_train, train_dataset_name)
    torch.save(partition_val, val_dataset_name)

In [5]:
def count_labels(dataset):
    """Return a Counter of the ``"label"`` field over every sample of ``dataset``."""
    # Iterating the dataset runs __getitem__ for each sample, so every image
    # is opened and transformed just to read its label; expect this to be slow.
    return Counter(sample["label"] for sample in dataset)


for i in range(num_partitions):
    # Load the saved datasets back from their pickle files.
    train_dataset_name = f"train{i}dataset.pt"
    val_dataset_name = f"val{i}dataset.pt"

    # These files hold pickled dataset objects rather than tensors, so the
    # restricted unpickler that newer PyTorch uses by default rejects them.
    # weights_only=False runs full pickle: only use it on files you created.
    # (The keyword requires PyTorch >= 1.13.)
    partition_train = torch.load(train_dataset_name, weights_only=False)
    partition_val = torch.load(val_dataset_name, weights_only=False)

    trainloader = DataLoader(partition_train, batch_size=32, shuffle=True)
    valloader = DataLoader(partition_val, batch_size=32)

    # Report the class distribution of this partition.
    print(i)
    train_counts = count_labels(trainloader.dataset)
    print(f"train_counts {train_counts}")
    val_counts = count_labels(valloader.dataset)
    print(f"val_counts {val_counts}")
    

0
train_counts Counter({0: 1853, 2: 826, 1: 345})
val_counts Counter({0: 493, 2: 178, 1: 85})
1
train_counts Counter({0: 1856, 2: 808, 1: 360})
val_counts Counter({0: 447, 2: 211, 1: 98})
2
train_counts Counter({0: 1798, 2: 856, 1: 370})
val_counts Counter({0: 450, 2: 218, 1: 88})
3
train_counts Counter({0: 1825, 2: 835, 1: 364})
val_counts Counter({0: 466, 2: 213, 1: 78})
