In [1]:
#----------------------
# Import Statements
#----------------------
import os
import json
import hashlib
import random
from pathlib import Path
from typing import List, Dict

import numpy as np
import pandas as pd
import imageio
from skimage import restoration, img_as_float
from skimage.color import rgba2rgb
from PIL import Image

# Torch stuff (for model part)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.metrics import accuracy_score, roc_auc_score

In [2]:
#------------------------
# Variables Declarations
#------------------------
medmnist_dir = "../Dataset_Derma/data_raw"                     # Directory for raw DermaMNSIT Data
preprocess_dir = "../Dataset_Derma/data_preprocessed"          # Base Directory for preprocessed Data
clean_output_dir = os.path.join(preprocess_dir, "dermanist_clean") # Directory to Store Clean Clients data and labels
poisoned_output_dir = os.path.join(preprocess_dir, "dermanist_poison_flip")  # Directory to Store Poisoned Clients data and labels

# To check if the directories exist in the system
os.makedirs(medmnist_dir, exist_ok=True)
os.makedirs(preprocess_dir, exist_ok=True)

# Processing / client partition options
NUM_CLIENTS = 20          # No.of Federated Clients
RANDOM_SEED = 42           
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

# Poisoning config
POISONED_CLIENT_NUM = 4           # No.of Clients to Corrupt as per our Group no:9       
FLIPPING_RATES = [0.1, 0.25, 0.5, 0.75]  # mapped to poisoned clients

# -------------------------
# Training Config
# -------------------------
NUM_CLASSES = 7   # No.of CLasses as per MEDMNIST Site
BATCH_SIZE = 64   # Batch Size for Dataloading
IMAGE_SIZE = 28   # Image Dimensions (28x28 pixels)
EPOCHS = 10       # No.of Epochs
LR = 0.001             # Learning Rate (Fine-tune)
DEVICE = "cpu"


#File to save Pretrained initial model
INITIAL_MODEL_PRETRAINED = "initial_resnet18_pretrained.pth"

In [3]:
# Uncomment the below line if this is the first time and didn't install medmnist dataset
# %pip install medmnist

In [4]:
# -------------------------
# Section B: Load medmnist DermaMNIST Dataset
# -------------------------

# Attempt to import the MedMNIST (DermaMNIST in Particular)
try:
    import medmnist
    from medmnist import DermaMNIST
except Exception as e:
    raise RuntimeError("Please install medmnist: pip install medmnist by uncommenting the cell before") from e

# ensure raw data dir exists
os.makedirs(medmnist_dir, exist_ok=True)

# Load Train, Validation and test Splits from DermaMNIST
derma_train = DermaMNIST(root=medmnist_dir, split="train", download=True, size=IMAGE_SIZE)
derma_val   = DermaMNIST(root=medmnist_dir, split="val", download=True, size=IMAGE_SIZE)
derma_test  = DermaMNIST(root=medmnist_dir, split="test", download=True, size=IMAGE_SIZE)

# Printing those samples count/size
print("Train samples:", len(derma_train))
print("Val samples:  ", len(derma_val))
print("Test samples: ", len(derma_test))

100%|██████████| 19.7M/19.7M [00:06<00:00, 3.17MB/s]


Train samples: 7007
Val samples:   1003
Test samples:  2005


In [5]:
# -------------------------
# 1) Data Cleaning:
#    Preprocess: deduplicate + denoise + channel normalize
# -------------------------

# Compute SHA-1 hash of image to detect duplicates
def sha1_of_image_uint8(img_uint8: np.ndarray) -> str:
    return hashlib.sha1(img_uint8.tobytes()).hexdigest()

seen_hash = set()
processed: List[Dict] = []
count=0

# Loop through train dataset
for i in range(len(derma_train)):
    img_pil, label = derma_train[i]
    label = int(label.item())
    img = np.array(img_pil)  # PIL -> numpy

    # ensure 3 channels (uint8)
    if img.ndim == 2:
        img = np.stack([img]*3, axis=-1)
    elif img.shape[2] == 4:
        # rgba -> rgb (returns float in [0,1] for rgba2rgb if input normalized; but it handles uint8 as well)
        img = (rgba2rgb(img / 255.0) * 255).astype(np.uint8)

    img_float = img_as_float(img)  # Normalize to float [0,1]
    img_uint8 = (np.clip(img_float, 0, 1) * 255).astype(np.uint8)

    # check for Duplicates
    img_hash = sha1_of_image_uint8(img_uint8)
    if img_hash in seen_hash:
        count+=1
        print("Duplicates Removed: ",count)   #To check how many Duplicates are removed
        continue
    seen_hash.add(img_hash)

    # denoise (fallback to original if failed)
    try:
        denoise = restoration.denoise_bilateral(img_float, channel_axis=-1)
    except Exception:
        denoise = img_float

    denoise_uint8 = (np.clip(denoise, 0, 1) * 255).astype(np.uint8)

    processed.append({
        "image": denoise_uint8,
        "label": label,
        "orig_index": i
    })

print(f"Kept {len(processed)} unique & denoised images from train split")

Duplicates Removed:  1
Kept 7006 unique & denoised images from train split


In [None]:
# -------------------------
# 2 & 3) Data Normalization and Augmentation
#     Define transforms
#    - save_transform for saving augmented PNGs
#    - model transforms for training/eval
# -------------------------
from torchvision.transforms import functional as TF

# Augmentation transforms for saving images
save_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
    transforms.RandomRotation(20),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([transforms.ColorJitter(0.2,0.2,0.2)], p=0.5),
    transforms.ToTensor()
])

# Training transform (resize, normalize)
train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

# Validation or Test transform (resize, normalize)
eval_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

# -------- For pretrained ResNet-18 (ImageNet normalization) --------
pretrain_train_transform = transforms.Compose([
    transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

pretrain_eval_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])



In [7]:
# -------------------------
# 4) Data Partitioning: Partition into clients (non-IID via Dirichlet)
# -------------------------
labels = np.array([r["label"] for r in processed])
num_classes_in_data = labels.max() + 1
assert num_classes_in_data == NUM_CLASSES or NUM_CLASSES == 7, "Verify NUM_CLASSES"

# create index list by label
idx_by_label = [np.where(labels == class_id)[0].tolist() for class_id in range(num_classes_in_data)]

client_index = [[] for _ in range(NUM_CLIENTS)]

# Distribute samples per class across clients using Dirichlet distribution
for class_id, idxs in enumerate(idx_by_label):
    if len(idxs) == 0:
        continue
    np.random.shuffle(idxs)
    proportions = np.random.dirichlet([0.5] * NUM_CLIENTS)
    splits = (np.cumsum(proportions) * len(idxs)).astype(int)
    prev = 0
    for cid, cut in enumerate(splits):
        client_index[cid].extend(idxs[prev:cut])
        prev = cut
    client_index[-1].extend(idxs[prev:])

# Build per-client dataset dictionary
client_data = {cid: [processed[i] for i in client_index[cid]] for cid in range(NUM_CLIENTS)}
print("Client sample counts:", [len(v) for v in client_data.values()])
print("Total samples:", sum(len(v) for v in client_data.values()))


Client sample counts: [38, 163, 865, 207, 140, 810, 305, 1351, 102, 357, 194, 86, 168, 101, 1166, 177, 85, 137, 193, 361]
Total samples: 7006


In [8]:
# -------------------------
# 5) Data Labelling: Save CLEAN per-client datasets (augmented images saved)
# -------------------------
os.makedirs(clean_output_dir, exist_ok=True)
total_images = 0

for cid, recs in client_data.items():
    client_dir = os.path.join(clean_output_dir, f"Client_{cid}")
    img_dir = os.path.join(client_dir, "images")
    os.makedirs(img_dir, exist_ok=True)
    records = []

    for idx, rec in enumerate(recs):
        img = rec["image"]  # uint8 HxWx3
        label = rec["label"]

        # Apply augmentation pipeline before saving (use save_transform which expects HxWxC uint8)
        # It returns tensor; convert back to uint8 HWC for PNG.
        img_tensor = save_transform(img)  # C x H x W, float [0,1]
        img_aug = (img_tensor.numpy().transpose(1, 2, 0) * 255).astype(np.uint8)

        filename = f"{idx}.png"
        filepath = os.path.join(img_dir, filename)
        imageio.imwrite(filepath, img_aug)
        records.append({"filename": filename, "label": label})
        total_images += 1

    # Save CSV with Labels
    df = pd.DataFrame(records)
    df.to_csv(os.path.join(client_dir, "labels.csv"), index=False)

print(f"Saved clean client datasets as PNG images. Total images: {total_images}")

Saved clean client datasets as PNG images. Total images: 7006


In [9]:
# -------------------------
# 6) Data Corrpution: Create POISONED datasets (label flipping)
# -------------------------
os.makedirs(poisoned_output_dir, exist_ok=True)
rng = np.random.RandomState(RANDOM_SEED + 9)  

# Select CLients for poisioning
poisoned_client_ids = rng.choice(NUM_CLIENTS, POISONED_CLIENT_NUM, replace=False).tolist()
client_flip_map = dict(zip(poisoned_client_ids, FLIPPING_RATES[:POISONED_CLIENT_NUM]))

client_data_corrupted = {}
corruption_summary = {}

# Apply Label Flipping
for cid, recs in client_data.items():
    recs_copy = [dict(r) for r in recs]  # Copy Records
    flipped_count = 0
    total = len(recs_copy)
    if cid in client_flip_map:
        flip = client_flip_map[cid]
        n = len(recs_copy)
        num_to_flip = int(round(flip * n))
        if num_to_flip > 0:
            flip_idx = rng.choice(n, num_to_flip, replace=False)
            for idx in flip_idx:
                orig = recs_copy[idx]["label"]
                choices = [l for l in range(NUM_CLASSES) if l != orig]
                new_label = int(rng.choice(choices))   # Random Wrong Lbel
                recs_copy[idx]["label"] = new_label
                flipped_count += 1

    client_data_corrupted[cid] = recs_copy
    corruption_summary[cid] = {
        "total": total,
        "flipped": flipped_count,
        "percent": 100 * flipped_count / total if total > 0 else 0
    }

# Save poisoned clients images and labels
total_poisoned_images = 0
for cid, recs in client_data_corrupted.items():
    if not recs:
        continue
    client_dir = os.path.join(poisoned_output_dir, f"Client_{cid}")
    img_dir = os.path.join(client_dir, "images")
    os.makedirs(img_dir, exist_ok=True)
    records = []
    for idx, rec in enumerate(recs):
        img = rec["image"]
        label = rec["label"]
        filename = f"{idx}.png"
        filepath = os.path.join(img_dir, filename)
        imageio.imwrite(filepath, img)
        records.append({"filename": filename, "label": label})
        total_poisoned_images += 1
    pd.DataFrame(records).to_csv(os.path.join(client_dir, "labels.csv"), index=False)

meta = {
    "poisoned_client_ids": poisoned_client_ids,
    "client_flip_map": client_flip_map,
    "seed": RANDOM_SEED,
    "corruption_summary": corruption_summary
}
with open(os.path.join(poisoned_output_dir, "meta.json"), "w") as f:
    json.dump(meta, f, indent=2)

print(f"Saved poisoned datasets. Total poisoned images (all clients): {total_poisoned_images}")
print("Poisoned clients and flip map:", client_flip_map)

Saved poisoned datasets. Total poisoned images (all clients): 7006
Poisoned clients and flip map: {1: 0.1, 12: 0.25, 15: 0.5, 2: 0.75}


In [10]:
# =========================================================
# Section C: Using a Pre-trained Model (ResNet-18)
# =========================================================

# -------------------------
# Dataset class
# -------------------------
class SimpleImageDataset(Dataset):
    '''Reads Images and Labels from CSV'''
    def __init__(self, csv_file_or_df, root_dir=None, transform=None):
        if isinstance(csv_file_or_df, pd.DataFrame):
            self.df = csv_file_or_df.reset_index(drop=True)
        else:
            self.df = pd.read_csv(csv_file_or_df)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        filename = str(row["filename"])
        if os.path.isabs(filename):
            img_path = filename
        else:
            img_path = os.path.join(self.root_dir, filename) if self.root_dir else filename
        img = Image.open(img_path).convert("RGB")
        label = int(row["label"])
        if self.transform:
            img = self.transform(img)
        return img, label

# -------------------------
# Step 1: Select Pre-trained Model
# -------------------------
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# -------------------------
# Step 2: Adapting the Model
# -------------------------
# Adjusted conv1 and remove maxpool to fit small 28x28 inputs
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = nn.Identity()
# Replaced final FC layer
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
model = model.to(DEVICE)

# -------------------------
# Step 3: Fine-tuning
# -------------------------
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=1e-4)

# Build datasets using PRETRAIN transforms
# Collect training samples from all clients (On Centralised Dataset)
train_rows = []
for client_name in os.listdir(clean_output_dir):
    client_dir = os.path.join(clean_output_dir, client_name)
    labels_csv = os.path.join(client_dir, "labels.csv")
    img_dir = os.path.join(client_dir, "images")
    if os.path.exists(labels_csv):
        df = pd.read_csv(labels_csv)
        for _, r in df.iterrows():
            abs_path = os.path.abspath(os.path.join(img_dir, r["filename"]))
            train_rows.append({"filename": abs_path, "label": int(r["label"])})

train_df = pd.DataFrame(train_rows)
print("Total centralized training samples:", len(train_df))
if len(train_df) == 0:
    raise RuntimeError("No centralized training samples found in clean_output_dir")

# Use centralized dataset + MedMNIST val/test
train_dataset = SimpleImageDataset(train_df, transform=pretrain_train_transform)
val_dataset   = DermaMNIST(root=medmnist_dir, split="val",  download=True, transform=pretrain_eval_transform)
test_dataset  = DermaMNIST(root=medmnist_dir, split="test", download=True, transform=pretrain_eval_transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader  = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Train: {len(train_dataset)} | Val: {len(val_dataset)} | Test: {len(test_dataset)}")

# -------------------------
# Training Loop
# -------------------------
best_val_acc = 0.0
for epoch in range(EPOCHS):
    model.train()
    running_loss, preds_train, labels_train = 0.0, [], []

    for imgs, labels in train_loader:
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * imgs.size(0)
        preds_train.extend(outputs.argmax(1).cpu().numpy())
        labels_train.extend(labels.cpu().numpy())

    train_loss = running_loss / len(train_loader.dataset)
    train_acc = accuracy_score(labels_train, preds_train)

    # Validation
    model.eval()
    y_true, y_probs = [], []
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            outputs = model(imgs)
            probs = torch.softmax(outputs, dim=1).cpu().numpy()
            y_probs.extend(probs)
            y_true.extend(labels.cpu().numpy())

    y_true = np.array(y_true)
    y_probs = np.array(y_probs)
    preds = y_probs.argmax(axis=1)
    val_acc = accuracy_score(y_true, preds)
    try:
        val_auc = roc_auc_score(y_true, y_probs, multi_class="ovr")
    except:
        val_auc = float("nan")

    print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {train_loss:.4f} "
          f"| Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f} | Val AUC: {val_auc:.4f}")

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), INITIAL_MODEL_PRETRAINED)
        print(f"Saved best model at epoch {epoch+1} (Val Acc: {val_acc:.4f})")

# -------------------------
# Final Test Evaluation
# -------------------------
model.load_state_dict(torch.load(INITIAL_MODEL_PRETRAINED))
model.eval()
y_true, y_probs = [], []
with torch.no_grad():
    for imgs, labels in test_loader:
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
        outputs = model(imgs)
        probs = torch.softmax(outputs, dim=1).cpu().numpy()
        y_probs.extend(probs)
        y_true.extend(labels.cpu().numpy())

y_true = np.array(y_true)
y_probs = np.array(y_probs)
preds = y_probs.argmax(axis=1)
test_acc = accuracy_score(y_true, preds)
try:
    test_auc = roc_auc_score(y_true, y_probs, multi_class="ovr")
except:
    test_auc = float("nan")

print(f"\n Final Test Results (Pretrained) → ACC: {test_acc:.4f}, AUC: {test_auc:.4f}")


Total centralized training samples: 7006
Train: 7006 | Val: 1003 | Test: 2005
Epoch 1/10 | Loss: 0.9933 | Train Acc: 0.6643 | Val Acc: 0.7009 | Val AUC: 0.8443
Saved best model at epoch 1 (Val Acc: 0.7009)
Epoch 2/10 | Loss: 0.8467 | Train Acc: 0.6984 | Val Acc: 0.7089 | Val AUC: 0.8699
Saved best model at epoch 2 (Val Acc: 0.7089)
Epoch 3/10 | Loss: 0.7936 | Train Acc: 0.7140 | Val Acc: 0.7408 | Val AUC: 0.8955
Saved best model at epoch 3 (Val Acc: 0.7408)
Epoch 4/10 | Loss: 0.7526 | Train Acc: 0.7235 | Val Acc: 0.7109 | Val AUC: 0.8956
Epoch 5/10 | Loss: 0.7177 | Train Acc: 0.7368 | Val Acc: 0.6959 | Val AUC: 0.8995
Epoch 6/10 | Loss: 0.7005 | Train Acc: 0.7398 | Val Acc: 0.6810 | Val AUC: 0.9020
Epoch 7/10 | Loss: 0.6786 | Train Acc: 0.7439 | Val Acc: 0.7168 | Val AUC: 0.9005
Epoch 8/10 | Loss: 0.6631 | Train Acc: 0.7564 | Val Acc: 0.6650 | Val AUC: 0.9081
Epoch 9/10 | Loss: 0.6440 | Train Acc: 0.7581 | Val Acc: 0.6550 | Val AUC: 0.8963
Epoch 10/10 | Loss: 0.6358 | Train Acc: 0.7608