In [1]:
import os
import copy
import re
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from collections import defaultdict
import kagglehub
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report

In [2]:
path = kagglehub.dataset_download("immada/casia-fasd")
print("Data synchronized at:", path)

Data synchronized at: /kaggle/input/datasets/immada/casia-fasd


In [3]:
import os

for root, dirs, files in os.walk(path):
    if 'train' in dirs:
        print(f" training path is: {os.path.join(root, 'train')}")
        actual_train_path = os.path.join(root, 'train')
        break

 training path is: /kaggle/input/datasets/immada/casia-fasd/casia-fasd/train


In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

DATA_PATH = "/kaggle/input/datasets/immada/casia-fasd/casia-fasd/train"

NUM_CLIENTS = 10
LOCAL_EPOCHS = 50
GLOBAL_ROUNDS = 3
BATCH_SIZE = 32
LR = 0.00005

In [5]:
transform = transforms.Compose([
    transforms.Resize((128, 128)),  # Smaller for lightweight CNN
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [6]:
class ClientDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert("RGB")
        if self.transform:
            img = self.transform(img)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return img, label

In [7]:
import random

def prepare_clients(balance=True, seed=42):
    """
    Prepare client data with optional balancing
    
    Args:
        balance: If True, randomly sample spoof images to match V1 count per client
        seed: Random seed for reproducibility
    """
    random.seed(seed)
    clients = defaultdict(lambda: {"paths": [], "labels": []})

    live_dir = os.path.join(DATA_PATH, "live")
    spoof_dir = os.path.join(DATA_PATH, "spoof")

    # -------- LIVE (only v1) ----------
    for file in os.listdir(live_dir):
        if "v1" in file:  # only high quality
            match = re.search(r'bs(\d+)v', file)
            if match:
                client_id = int(match.group(1))
                full_path = os.path.join(live_dir, file)
                clients[client_id]["paths"].append(full_path)
                clients[client_id]["labels"].append(0)  # live = 0

    # -------- SPOOF (all) ----------
    # First, collect all spoof images per client
    spoof_per_client = defaultdict(list)
    
    for file in os.listdir(spoof_dir):
        match = re.search(r's(\d+)v', file)
        if match:
            client_id = int(match.group(1))
            full_path = os.path.join(spoof_dir, file)
            spoof_per_client[client_id].append(full_path)
    
    # Now add spoof images (balanced or all)
    for client_id in range(1, NUM_CLIENTS + 1):
        spoof_paths = spoof_per_client[client_id]
        
        if balance and len(spoof_paths) > 0:
            # Get number of V1 images for this client
            num_v1 = len([l for l in clients[client_id]["labels"] if l == 0])
            
            # Randomly sample spoof images to match V1 count
            if len(spoof_paths) >= num_v1:
                selected_spoof = random.sample(spoof_paths, num_v1)
            else:
                # If we have fewer spoof images than V1, use all spoof images
                selected_spoof = spoof_paths
        else:
            # Use all spoof images (unbalanced)
            selected_spoof = spoof_paths
        
        # Add selected spoof images
        for path in selected_spoof:
            clients[client_id]["paths"].append(path)
            clients[client_id]["labels"].append(1)  # spoof = 1

    return clients

In [8]:
clients_data = prepare_clients()

# Count total V1 and Spoof images
total_v1 = 0
total_spoof = 0

for client_id in range(1, NUM_CLIENTS + 1):
    labels = clients_data[client_id]["labels"]
    v1_count = labels.count(0)  # live = 0
    spoof_count = labels.count(1)  # spoof = 1
    
    total_v1 += v1_count
    total_spoof += spoof_count
    
    print(f"Client {client_id:2d}: V1={v1_count:3d}, Spoof={spoof_count:3d}")

print("\n" + "="*50)
print(f"TOTAL V1 (High-Quality Live) Images: {total_v1}")
print(f"TOTAL Spoof Images: {total_spoof}")
print(f"TOTAL Training Images: {total_v1 + total_spoof}")
print("="*50)

Client  1: V1=200, Spoof=200
Client  2: V1=145, Spoof=145
Client  3: V1=167, Spoof=167
Client  4: V1=237, Spoof=237
Client  5: V1=146, Spoof=146
Client  6: V1=147, Spoof=147
Client  7: V1=145, Spoof=145
Client  8: V1=211, Spoof=211
Client  9: V1=185, Spoof=185
Client 10: V1=102, Spoof=102

TOTAL V1 (High-Quality Live) Images: 1685
TOTAL Spoof Images: 1685
TOTAL Training Images: 3370


In [9]:
client_loaders = {}

for client_id in range(1, NUM_CLIENTS + 1):
    dataset = ClientDataset(
        clients_data[client_id]["paths"],
        clients_data[client_id]["labels"],
        transform
    )
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    client_loaders[client_id] = loader

print("All clients prepared.")


All clients prepared.


In [10]:
class SmallCNN(nn.Module):
    """Lightweight CNN for face anti-spoofing"""
    def __init__(self, num_classes=2):
        super(SmallCNN, self).__init__()
        
        # Feature extraction layers
        self.features = nn.Sequential(
            # Conv Block 1: 3 -> 32
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 128->64
            
            # Conv Block 2: 32 -> 64
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 64->32
            
            # Conv Block 3: 64 -> 128
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 32->16
            
            # Conv Block 4: 128 -> 256
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 16->8
        )
        
        # Classifier
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )
        
    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

# Initialize model
global_model = SmallCNN(num_classes=2).to(device)
print("Model Architecture:")
print(global_model)

# Count parameters
total_params = sum(p.numel() for p in global_model.parameters())
trainable_params = sum(p.numel() for p in global_model.parameters() if p.requires_grad)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")


Model Architecture:
SmallCNN(
  (features): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace=True)
    (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (12): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): BatchNorm2d(256, eps=1e-0

In [11]:
def train_client(model, loader, epochs, lr):
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        total_loss = 0
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

    return model.state_dict()

In [12]:
def federated_averaging(global_model, client_updates):
    global_dict = global_model.state_dict()

    for key in global_dict.keys():
        global_dict[key] = torch.stack([client_updates[i][key].float() for i in range(len(client_updates))], 0).mean(0)

    global_model.load_state_dict(global_dict)
    return global_model

In [13]:
from tqdm import tqdm

for round_num in range(GLOBAL_ROUNDS):
    print(f"\n========== ROUND {round_num + 1}/{GLOBAL_ROUNDS} ==========")
    
    client_updates = []

    for client_id in tqdm(range(1, NUM_CLIENTS + 1), desc=f"Training clients (Round {round_num + 1})"):
        client_model = copy.deepcopy(global_model)
        
        updated_state = train_client(
            client_model,
            client_loaders[client_id],
            LOCAL_EPOCHS,
            LR
        )
        
        client_updates.append(updated_state)

    global_model = federated_averaging(global_model, client_updates)
    print(f"Round {round_num + 1} complete. Global model updated.")

print("\n✓ Federated learning complete!")




Training clients (Round 1): 100%|██████████| 10/10 [06:56<00:00, 41.65s/it]


Round 1 complete. Global model updated.



Training clients (Round 2): 100%|██████████| 10/10 [06:37<00:00, 39.73s/it]


Round 2 complete. Global model updated.



Training clients (Round 3): 100%|██████████| 10/10 [06:40<00:00, 40.08s/it]

Round 3 complete. Global model updated.

✓ Federated learning complete!





In [14]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
import numpy as np


In [15]:
class TestDataset(Dataset):
    def __init__(self, image_paths, labels, is_v2_flags, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.is_v2_flags = is_v2_flags
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert("RGB")
        if self.transform:
            img = self.transform(img)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        is_v2 = torch.tensor(self.is_v2_flags[idx], dtype=torch.long)
        return img, label, is_v2

In [16]:
def prepare_test_data():
    TEST_PATH = "/kaggle/input/datasets/immada/casia-fasd/casia-fasd/test"
    
    live_dir = os.path.join(TEST_PATH, "live")
    spoof_dir = os.path.join(TEST_PATH, "spoof")

    image_paths = []
    labels = []
    is_v2_flags = []

    # -------- LIVE ----------
    for file in os.listdir(live_dir):
        full_path = os.path.join(live_dir, file)

        image_paths.append(full_path)
        labels.append(0)  # real

        if "v2" in file:
            is_v2_flags.append(1)
        else:
            is_v2_flags.append(0)

    # -------- SPOOF ----------
    for file in os.listdir(spoof_dir):
        full_path = os.path.join(spoof_dir, file)

        image_paths.append(full_path)
        labels.append(1)  # spoof
        is_v2_flags.append(0)  # not v2 real

    return image_paths, labels, is_v2_flags

In [17]:
image_paths, labels, is_v2_flags = prepare_test_data()

test_dataset = TestDataset(
    image_paths,
    labels,
    is_v2_flags,
    transform
)

test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

print("Test dataset prepared.")

Test dataset prepared.


In [18]:
from tqdm import tqdm

def evaluate_model(model, loader):
    model.eval()

    all_preds = []
    all_labels = []
    all_v2_flags = []

    total_batches = len(loader)

    with torch.no_grad():
        for images, labels, v2_flags in tqdm(loader, total=total_batches, desc="Evaluating"):
            images = images.to(device)

            outputs = model(images)
            preds = torch.argmax(outputs, dim=1).cpu().numpy()

            all_preds.extend(preds)
            all_labels.extend(labels.numpy())
            all_v2_flags.extend(v2_flags.numpy())

    return np.array(all_preds), np.array(all_labels), np.array(all_v2_flags)

In [19]:
preds, true_labels, v2_flags = evaluate_model(global_model, test_loader)


Evaluating: 100%|██████████| 2056/2056 [08:45<00:00,  3.91it/s]


In [20]:
print("=== OVERALL TEST PERFORMANCE ===")

print("Accuracy:", accuracy_score(true_labels, preds))
print("Precision:", precision_score(true_labels, preds))
print("Recall:", recall_score(true_labels, preds))
print("F1 Score:", f1_score(true_labels, preds))

print("\nConfusion Matrix:")
print(confusion_matrix(true_labels, preds))

print("\nClassification Report:")
print(classification_report(true_labels, preds, target_names=["Real", "Spoof"]))


=== OVERALL TEST PERFORMANCE ===
Accuracy: 0.8588453470343234
Precision: 0.8673786284700216
Recall: 0.9835423479104531
F1 Score: 0.9218152732171424

Confusion Matrix:
[[ 1758  8370]
 [  916 54742]]

Classification Report:
              precision    recall  f1-score   support

        Real       0.66      0.17      0.27     10128
       Spoof       0.87      0.98      0.92     55658

    accuracy                           0.86     65786
   macro avg       0.76      0.58      0.60     65786
weighted avg       0.84      0.86      0.82     65786

