1. Imports and Setup

In [2]:
# -------------------------------------------------------------
# 1. Imports and Setup
# -------------------------------------------------------------
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision import models, datasets
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import os, random, copy

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------------------------------------------------
# 2. ResNet50 Model Definition
# -------------------------------------------------------------
class ResNetModel(nn.Module):
    def __init__(self, num_classes=2, pretrained=True):
        super(ResNetModel, self).__init__()
        if pretrained:
            weights = models.ResNet50_Weights.DEFAULT
        else:
            weights = None

        self.resnet = models.resnet50(weights=weights)

        # Optional: freeze low-level feature extractors
        for param in self.resnet.layer1.parameters():
            param.requires_grad = False
        for param in self.resnet.layer2.parameters():
            param.requires_grad = False

        in_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.resnet(x)

# -------------------------------------------------------------
# 3. Load and Split Dataset
# -------------------------------------------------------------
source_au = r"C:\Users\Admin\OneDrive\Desktop\ResearchTrack\CASIA_FL_Project\CASIA 1.0 dataset\CASIA 1.0 dataset\Au\Au"
source_tp = r"C:\Users\Admin\OneDrive\Desktop\ResearchTrack\CASIA_FL_Project\CASIA 1.0 dataset\CASIA 1.0 dataset\Modified Tp\Tp"

DATA_ROOT = r"C:\Users\Admin\OneDrive\Desktop\ResearchTrack\CASIA_FL_Project\data"
os.makedirs(os.path.join(DATA_ROOT, "Authentic"), exist_ok=True)
os.makedirs(os.path.join(DATA_ROOT, "Tampered"), exist_ok=True)

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # ResNet input size
    transforms.ToTensor(),
])

full_dataset = datasets.ImageFolder(DATA_ROOT, transform=transform)

train_size = int(0.75 * len(full_dataset))
test_size = int(0.20 * len(full_dataset))
eval_size = len(full_dataset) - train_size - test_size
train_set, test_set, eval_set = random_split(full_dataset, [train_size, test_size, eval_size])

def split_data(dataset, num_clients=10):
    data_per_client = len(dataset) // num_clients
    client_data = []
    for i in range(num_clients):
        indices = list(range(i * data_per_client, (i + 1) * data_per_client))
        client_data.append(torch.utils.data.Subset(dataset, indices))
    return client_data

client_datasets = split_data(train_set)
testloader = DataLoader(test_set, batch_size=32, shuffle=False)

print(f"‚úÖ Total images: {len(full_dataset)}")
print(f"Train: {len(train_set)} | Test: {len(test_set)} | Eval: {len(eval_set)}")
print(f"Clients: {len(client_datasets)}")

# -------------------------------------------------------------
# 4. Client Update Function
# -------------------------------------------------------------
def client_update(model, dataset, epochs, batch_size, lr, testloader=None):
    model = copy.deepcopy(model)
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    for _ in range(epochs):
        for images, labels in loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    if testloader:
        model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for images, labels in testloader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = model(images)
                _, preds = torch.max(outputs, 1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        acc = accuracy_score(all_labels, all_preds)
        precision = precision_score(all_labels, all_preds, average='macro', zero_division=0)
        recall = recall_score(all_labels, all_preds, average='macro', zero_division=0)
        f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)

        print(f"üìç Client Evaluation ‚Üí Acc: {acc:.4f}, Prec: {precision:.4f}, Rec: {recall:.4f}, F1: {f1:.4f}")

    return model.state_dict()

# -------------------------------------------------------------
# ‚úÖ FedYogi Server Optimizer
# -------------------------------------------------------------
def fed_yogi_update(global_model, client_models, delta_prev, v_prev,
                    beta1=0.9, beta2=0.99, eta=0.01, tau=1e-3):
    global_state = global_model.state_dict()
    delta = {}

    for key in global_state.keys():
        # Ensure consistent dtype/device
        base = global_state[key].to(torch.float32).to(DEVICE)
        delta[key] = torch.zeros_like(base)

        for client_model in client_models:
            delta[key] += client_model[key].to(torch.float32).to(DEVICE) - base

        delta[key] /= len(client_models)

    for key in delta:
        delta_prev[key] = beta1 * delta_prev[key].to(torch.float32).to(DEVICE) + (1 - beta1) * delta[key]

    for key in delta:
        v_prev[key] = v_prev[key].to(torch.float32).to(DEVICE)
        v_prev[key] -= (1 - beta2) * torch.sign(v_prev[key] - delta[key]**2) * delta[key]**2

    new_state = {}
    for key in global_state.keys():
        new_state[key] = global_state[key].to(torch.float32).to(DEVICE) + \
                         eta * delta_prev[key] / (torch.sqrt(v_prev[key]) + tau)

    global_model.load_state_dict(new_state)
    return global_model, delta_prev, v_prev

# -------------------------------------------------------------
# 6. Evaluation Function
# -------------------------------------------------------------
def evaluate_model(model, dataloader):
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    acc = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='macro', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='macro', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)

    print(f"üîç Evaluation Metrics:")
    print(f"Accuracy  : {acc:.4f}")
    print(f"Precision : {precision:.4f}")
    print(f"Recall    : {recall:.4f}")
    print(f"F1 Score  : {f1:.4f}")

# -------------------------------------------------------------
# 7. Federated Training Loop (FedYogi)
# -------------------------------------------------------------
num_clients = 10
clients_per_round = 2
rounds = 15
epochs = 3
batch_size = 32
learning_rate = 0.001

print(f"üìå Hyperparameters:")
print(f"Clients           : {num_clients}")
print(f"Clients per round : {clients_per_round}")
print(f"Rounds            : {rounds}")
print(f"Epochs per client : {epochs}")
print(f"Batch size        : {batch_size}")
print(f"Learning rate     : {learning_rate}")

global_model = ResNetModel(num_classes=2, pretrained=True).to(DEVICE)
delta_prev = {k: torch.zeros_like(v) for k, v in global_model.state_dict().items()}
v_prev = {k: torch.ones_like(v) * 1e-3 for k, v in global_model.state_dict().items()}

for t in range(rounds):
    selected_clients = random.sample(client_datasets, clients_per_round)
    client_models = []

    for client_data in selected_clients:
        updated_weights = client_update(global_model, client_data, epochs, batch_size, learning_rate)
        client_models.append(updated_weights)

    global_model, delta_prev, v_prev = fed_yogi_update(global_model, client_models, delta_prev, v_prev)
    print(f"\n‚úÖ Round {t+1} complete")
    evaluate_model(global_model, testloader)


‚úÖ Total images: 1721
Train: 1290 | Test: 344 | Eval: 87
Clients: 10
üìå Hyperparameters:
Clients           : 10
Clients per round : 2
Rounds            : 15
Epochs per client : 3
Batch size        : 32
Learning rate     : 0.001

‚úÖ Round 1 complete
üîç Evaluation Metrics:
Accuracy  : 0.5116
Precision : 0.5061
Recall    : 0.5060
F1 Score  : 0.5056

‚úÖ Round 2 complete
üîç Evaluation Metrics:
Accuracy  : 0.5087
Precision : 0.5029
Recall    : 0.5028
F1 Score  : 0.5023

‚úÖ Round 3 complete
üîç Evaluation Metrics:
Accuracy  : 0.5058
Precision : 0.4987
Recall    : 0.4987
F1 Score  : 0.4976

‚úÖ Round 4 complete
üîç Evaluation Metrics:
Accuracy  : 0.5087
Precision : 0.5005
Recall    : 0.5004
F1 Score  : 0.4985

‚úÖ Round 5 complete
üîç Evaluation Metrics:
Accuracy  : 0.5087
Precision : 0.4979
Recall    : 0.4981
F1 Score  : 0.4938

‚úÖ Round 6 complete
üîç Evaluation Metrics:
Accuracy  : 0.5116
Precision : 0.5008
Recall    : 0.5007
F1 Score  : 0.4963

‚úÖ Round 7 complete
üîç Eval