In [1]:
# Cell 1: Imports and Setup
import os
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import cv2
import copy
import re
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, 
    confusion_matrix, classification_report
)

print("✓ Imports successful")

✓ Imports successful


In [2]:
# Cell 2: Configuration
DATA_PATH = "/kaggle/input/datasets/immada/casia-fasd/casia-fasd/train"
TEST_PATH = "/kaggle/input/datasets/immada/casia-fasd/casia-fasd/test"

NUM_CLIENTS = 10
GLOBAL_ROUNDS = 3

print(f"Configuration:")
print(f"  Clients: {NUM_CLIENTS}")
print(f"  Global Rounds: {GLOBAL_ROUNDS}")
print(f"  Train Path: {DATA_PATH}")
print(f"  Test Path: {TEST_PATH}")

Configuration:
  Clients: 10
  Global Rounds: 3
  Train Path: /kaggle/input/datasets/immada/casia-fasd/casia-fasd/train
  Test Path: /kaggle/input/datasets/immada/casia-fasd/casia-fasd/test


In [None]:
# Cell 3: Enhanced FFT Feature Extraction with KAN Insights
def compute_radial_profile(magnitude_spectrum):
    """Compute radial average of frequency spectrum"""
    h, w = magnitude_spectrum.shape
    center = (h // 2, w // 2)
    
    y, x = np.ogrid[:h, :w]
    r = np.sqrt((x - center[1])**2 + (y - center[0])**2).astype(int)
    
    max_r = min(center)
    radial_prof = np.zeros(max_r)
    
    for radius in range(max_r):
        mask = (r == radius)
        radial_prof[radius] = magnitude_spectrum[mask].mean() if mask.any() else 0
    
    return radial_prof


def extract_fft_features_with_kan(image_path, img_size=256):
    """
    Extract FFT features + KAN-discovered patterns
    Based on KAN findings from interpretability analysis
    """
    try:
        img = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE)
        if img is None:
            return None
        
        img = cv2.resize(img, (img_size, img_size))
        
        # Compute FFT
        f_transform = np.fft.fft2(img)
        f_shift = np.fft.fftshift(f_transform)
        magnitude = np.abs(f_shift)
        
        # Radial profile
        radial_prof = compute_radial_profile(magnitude)
        
        # Frequency band analysis
        h, w = magnitude.shape
        center = (h // 2, w // 2)
        
        y, x = np.ogrid[:h, :w]
        r = np.sqrt((x - center[1])**2 + (y - center[0])**2)
        
        outer_radius = int(min(center) * 0.8)
        inner_radius = int(min(center) * 0.5)
        low_radius = int(min(center) * 0.3)
        
        high_freq_mask = (r >= inner_radius) & (r <= outer_radius)
        low_freq_mask = r < low_radius
        
        high_freq_energy = magnitude[high_freq_mask].sum()
        low_freq_energy = magnitude[low_freq_mask].sum()
        total_energy = magnitude.sum()
        
        # ===== ORIGINAL 8 FEATURES =====
        high_low_ratio = high_freq_energy / (low_freq_energy + 1e-10)
        high_freq_percentage = (high_freq_energy / total_energy) * 100
        radial_mean = radial_prof.mean()
        radial_std = radial_prof.std()
        radial_slope = np.polyfit(range(len(radial_prof)), radial_prof, 1)[0]
        
        
        # 1. Optimal Ratio Zone Score (KAN found optimal zone at ~0.18)
        optimal_ratio_score = np.exp(-((high_low_ratio - 0.18) ** 2) / (2 * 0.02 ** 2))
        
        # 2. Critical Slope Threshold (KAN found -2σ threshold)
        slope_threshold_violation = 1.0 if radial_slope < -3500 else 0.0
        
        # 3. Variance-based Spoof Score (KAN found high variance = spoof)
        variance_spoof_score = radial_std / (radial_mean + 1e-10)
        
        # 4. Quadratic Decay (steeper = spoof)
        x_prof = np.arange(len(radial_prof))
        poly_coeffs = np.polyfit(x_prof, np.log1p(radial_prof), 2)
        quadratic_decay = poly_coeffs[0]
        
        # 5. High-Frequency Deficit
        high_freq_deficit = max(0, 8e6 - high_freq_energy)
        
        # 6. Energy Distribution Uniformity
        mid_radius = int(min(center) * 0.5)
        mid_freq_mask = (r >= low_radius) & (r < mid_radius)
        mid_freq_energy = magnitude[mid_freq_mask].sum()
        energy_uniformity = np.std([low_freq_energy, mid_freq_energy, high_freq_energy]) / (total_energy + 1e-10)
        
        features = {
            # Original 8
            'high_freq_energy': high_freq_energy,
            'low_freq_energy': low_freq_energy,
            'high_low_ratio': high_low_ratio,
            'high_freq_percentage': high_freq_percentage,
            'radial_profile_mean': radial_mean,
            'radial_profile_std': radial_std,
            'radial_profile_slope': radial_slope,
            'total_energy': total_energy,
            
            # KAN-discovered 6
            'optimal_ratio_score': optimal_ratio_score,
            'slope_threshold_violation': slope_threshold_violation,
            'variance_spoof_score': variance_spoof_score,
            'quadratic_decay': quadratic_decay,
            'high_freq_deficit': high_freq_deficit,
            'energy_uniformity': energy_uniformity,
        }
        
        return features
    
    except Exception as e:
        return None

print("✓ FFT+KAN feature extraction defined (14 features total)")

✓ FFT+KAN feature extraction defined (14 features total)


In [4]:
# Cell 4: Prepare Federated Training Data (V1 + V2 - Both Quality Levels)
def prepare_federated_clients():
    """
    Prepare client data with BOTH quality levels for training
    - Live: V1 (high quality) + V2 (low quality)  ← CHANGED
    - Spoof: All versions
    """
    clients = defaultdict(lambda: {"features": [], "labels": []})
    
    live_dir = os.path.join(DATA_PATH, "live")
    spoof_dir = os.path.join(DATA_PATH, "spoof")
    
    # -------- LIVE (V1 + V2 - both quality levels) ----------
    print("Processing LIVE images (V1 + V2 - all quality)...")
    live_files = [f for f in os.listdir(live_dir) if "v1" in f or "v2" in f]
    
    for file in tqdm(live_files, desc="Live V1+V2"):
        match = re.search(r'bs(\d+)v', file)
        if match:
            client_id = int(match.group(1))
            full_path = os.path.join(live_dir, file)
            
            features = extract_fft_features_with_kan(full_path)
            if features is not None:
                clients[client_id]["features"].append(features)
                clients[client_id]["labels"].append(0)  # live = 0
    
    # -------- SPOOF (all) ----------
    print("\nProcessing SPOOF (all quality) images...")
    spoof_per_client = defaultdict(list)
    
    spoof_files = os.listdir(spoof_dir)
    for file in tqdm(spoof_files, desc="Spoof All"):
        match = re.search(r's(\d+)v', file)
        if match:
            client_id = int(match.group(1))
            full_path = os.path.join(spoof_dir, file)
            
            features = extract_fft_features_with_kan(full_path)
            if features is not None:
                spoof_per_client[client_id].append((features, 1))  # spoof = 1
    
    # Balance: match spoof count to live count for each client
    for client_id in range(1, NUM_CLIENTS + 1):
        spoof_samples = spoof_per_client[client_id]
        num_live = len([l for l in clients[client_id]["labels"] if l == 0])
        
        if len(spoof_samples) >= num_live:
            # Randomly sample to match live count
            import random
            random.seed(42)
            selected = random.sample(spoof_samples, num_live)
        else:
            selected = spoof_samples
        
        for features, label in selected:
            clients[client_id]["features"].append(features)
            clients[client_id]["labels"].append(label)
    
    return clients

# Prepare clients
clients_data = prepare_federated_clients()

# Display statistics
print("\n" + "="*60)
print("CLIENT DATA DISTRIBUTION")
print("="*60)

total_live = 0
total_spoof = 0

for client_id in range(1, NUM_CLIENTS + 1):
    labels = clients_data[client_id]["labels"]
    live_count = labels.count(0)
    spoof_count = labels.count(1)
    
    total_live += live_count
    total_spoof += spoof_count
    
    print(f"Client {client_id:2d}: Live={live_count:3d}, Spoof={spoof_count:3d}, Total={len(labels):3d}")

print("\n" + "="*60)
print(f"TOTAL Live (V1 + V2 - All Quality): {total_live}")
print(f"TOTAL Spoof:                        {total_spoof}")
print(f"TOTAL Training Samples:             {total_live + total_spoof}")
print("="*60)

Processing LIVE V1 (high quality) images...


Live V1: 100%|██████████| 9591/9591 [01:06<00:00, 144.19it/s]



Processing SPOOF (all quality) images...


Spoof All: 100%|██████████| 38736/38736 [14:00<00:00, 46.07it/s]



CLIENT DATA DISTRIBUTION
Client  1: V1=200, Spoof=200, Total=400
Client  2: V1=145, Spoof=145, Total=290
Client  3: V1=167, Spoof=167, Total=334
Client  4: V1=237, Spoof=237, Total=474
Client  5: V1=146, Spoof=146, Total=292
Client  6: V1=147, Spoof=147, Total=294
Client  7: V1=145, Spoof=145, Total=290
Client  8: V1=211, Spoof=211, Total=422
Client  9: V1=185, Spoof=185, Total=370
Client 10: V1=102, Spoof=102, Total=204

TOTAL V1 (High-Quality Live):  1685
TOTAL Spoof:                   1685
TOTAL Training Samples:        3370


In [5]:
# Cell 5: Prepare Test Data (Mixed Quality - V1 + V2)
def prepare_test_data():
    """
    Prepare test data with quality labels
    - V1 live = high quality real
    - V2 live = low quality real  
    - All spoof
    """
    live_dir = os.path.join(TEST_PATH, "live")
    spoof_dir = os.path.join(TEST_PATH, "spoof")
    
    features_list = []
    labels = []
    is_v2_flags = []
    
    # -------- LIVE (V1 and V2) ----------
    print("Processing TEST LIVE images...")
    live_files = os.listdir(live_dir)
    
    for file in tqdm(live_files, desc="Test Live"):
        full_path = os.path.join(live_dir, file)
        features = extract_fft_features_with_kan(full_path)
        
        if features is not None:
            features_list.append(features)
            labels.append(0)  # real = 0
            
            # Check if V2 (low quality)
            if "v2" in file:
                is_v2_flags.append(1)
            else:
                is_v2_flags.append(0)
    
    # -------- SPOOF ----------
    print("\nProcessing TEST SPOOF images...")
    spoof_files = os.listdir(spoof_dir)
    
    for file in tqdm(spoof_files, desc="Test Spoof"):
        full_path = os.path.join(spoof_dir, file)
        features = extract_fft_features_with_kan(full_path)
        
        if features is not None:
            features_list.append(features)
            labels.append(1)  # spoof = 1
            is_v2_flags.append(0)  # not v2 real
    
    return features_list, labels, is_v2_flags

# Prepare test data
test_features, test_labels, test_v2_flags = prepare_test_data()

print("\n" + "="*60)
print("TEST DATA SUMMARY")
print("="*60)
print(f"Total test samples: {len(test_labels)}")
print(f"Real (all):  {test_labels.count(0)}")
print(f"  - V1 (high quality): {sum([1 for l, v2 in zip(test_labels, test_v2_flags) if l == 0 and v2 == 0])}")
print(f"  - V2 (low quality):  {sum([1 for l, v2 in zip(test_labels, test_v2_flags) if l == 0 and v2 == 1])}")
print(f"Spoof (all): {test_labels.count(1)}")
print("="*60)

Processing TEST LIVE images...


Test Live: 100%|██████████| 10128/10128 [03:07<00:00, 54.01it/s]



Processing TEST SPOOF images...


Test Spoof: 100%|██████████| 55658/55658 [16:59<00:00, 54.57it/s]


TEST DATA SUMMARY
Total test samples: 65786
Real (all):  10128
  - V1 (high quality): 4830
  - V2 (low quality):  5298
Spoof (all): 55658





In [6]:
# Cell 6: Federated SVM Client
class FederatedSVMClient:
    """
    FL Client with SVM and FFT+KAN features
    """
    def __init__(self, client_id, features, labels):
        self.client_id = client_id
        
        # Convert features to DataFrame then to numpy
        self.X = pd.DataFrame(features).values
        self.y = np.array(labels)
        
        # Local scaler and model
        self.scaler = StandardScaler()
        self.model = SVC(
            kernel='rbf',
            C=10.0,
            gamma='scale',
            class_weight='balanced',
            random_state=42,
            probability=True  # For soft predictions
        )
        
    def train(self):
        """Train local SVM model"""
        # Standardize
        X_scaled = self.scaler.fit_transform(self.X)
        
        # Train SVM
        self.model.fit(X_scaled, self.y)
        
        return {
            'scaler_mean': self.scaler.mean_,
            'scaler_scale': self.scaler.scale_,
            'n_samples': len(self.X),
            'model': copy.deepcopy(self.model),
        }

print("✓ Federated SVM Client defined")

✓ Federated SVM Client defined


In [7]:
# Cell 7: Federated Server with Model Aggregation
class FederatedSVMServer:
    """
    FL Server - Aggregates SVM models
    """
    def __init__(self):
        self.global_scaler = StandardScaler()
        self.global_model = SVC(
            kernel='rbf',
            C=10.0,
            gamma='scale',
            class_weight='balanced',
            random_state=42,
            probability=True
        )
        
    def aggregate(self, client_updates):
        """Aggregate client scalers (weighted average)"""
        total_samples = sum(update['n_samples'] for update in client_updates)
        
        # Weighted average of scaler parameters
        weighted_mean = np.zeros_like(client_updates[0]['scaler_mean'])
        weighted_scale = np.zeros_like(client_updates[0]['scaler_scale'])
        
        for update in client_updates:
            weight = update['n_samples'] / total_samples
            weighted_mean += weight * update['scaler_mean']
            weighted_scale += weight * update['scaler_scale']
        
        self.global_scaler.mean_ = weighted_mean
        self.global_scaler.scale_ = weighted_scale
        
        print(f"  Aggregated {len(client_updates)} client models")
        print(f"  Total training samples: {total_samples}")
        
    def train_global(self, all_features, all_labels):
        """Train global model on aggregated data"""
        X = pd.DataFrame(all_features).values
        y = np.array(all_labels)
        
        # Use aggregated scaler
        X_scaled = self.global_scaler.transform(X)
        
        # Train global SVM
        self.global_model.fit(X_scaled, y)
        
        print(f"✓ Global model trained on {len(X)} samples")
        
    def evaluate(self, test_features, test_labels):
        """Evaluate global model"""
        X_test = pd.DataFrame(test_features).values
        X_test_scaled = self.global_scaler.transform(X_test)
        
        y_pred = self.global_model.predict(X_test_scaled)
        
        return y_pred

print("✓ Federated SVM Server defined")

✓ Federated SVM Server defined


In [8]:
# Cell 8: Federated Learning Training Loop
print("\n" + "="*70)
print("FEDERATED LEARNING TRAINING - FFT+KAN Features")
print("="*70)

# Initialize server
server = FederatedSVMServer()

# Global rounds
for round_num in range(GLOBAL_ROUNDS):
    print(f"\n========== ROUND {round_num + 1}/{GLOBAL_ROUNDS} ==========")
    
    client_updates = []
    all_round_features = []
    all_round_labels = []
    
    # Train each client
    for client_id in tqdm(range(1, NUM_CLIENTS + 1), desc=f"Training clients (Round {round_num + 1})"):
        client_features = clients_data[client_id]["features"]
        client_labels = clients_data[client_id]["labels"]
        
        # Create client
        client = FederatedSVMClient(client_id, client_features, client_labels)
        
        # Train locally
        update = client.train()
        client_updates.append(update)
        
        # Collect for global training
        all_round_features.extend(client_features)
        all_round_labels.extend(client_labels)
    
    # Aggregate scalers
    server.aggregate(client_updates)
    
    # Train global model
    server.train_global(all_round_features, all_round_labels)
    
    print(f"Round {round_num + 1} complete. Global model updated.")

print("\n✓ Federated learning complete!")


FEDERATED LEARNING TRAINING - FFT+KAN Features



Training clients (Round 1): 100%|██████████| 10/10 [00:00<00:00, 150.04it/s]


  Aggregated 10 client models
  Total training samples: 3370
✓ Global model trained on 3370 samples
Round 1 complete. Global model updated.



Training clients (Round 2): 100%|██████████| 10/10 [00:00<00:00, 164.40it/s]


  Aggregated 10 client models
  Total training samples: 3370
✓ Global model trained on 3370 samples
Round 2 complete. Global model updated.



Training clients (Round 3): 100%|██████████| 10/10 [00:00<00:00, 167.13it/s]


  Aggregated 10 client models
  Total training samples: 3370
✓ Global model trained on 3370 samples
Round 3 complete. Global model updated.

✓ Federated learning complete!


In [9]:
# Cell 9: Evaluation - Overall Performance
print("\n" + "="*70)
print("EVALUATION - FFT+KAN FEDERATED MODEL")
print("="*70)

# Get predictions
preds = server.evaluate(test_features, test_labels)
true_labels = np.array(test_labels)

# Overall metrics
print("\n=== OVERALL TEST PERFORMANCE ===")
print(f"Accuracy:  {accuracy_score(true_labels, preds):.4f}")
print(f"Precision: {precision_score(true_labels, preds):.4f}")
print(f"Recall:    {recall_score(true_labels, preds):.4f}")
print(f"F1 Score:  {f1_score(true_labels, preds):.4f}")

print("\nConfusion Matrix:")
cm = confusion_matrix(true_labels, preds)
print(cm)

print("\nClassification Report:")
print(classification_report(true_labels, preds, target_names=["Real", "Spoof"]))


EVALUATION - FFT+KAN FEDERATED MODEL

=== OVERALL TEST PERFORMANCE ===
Accuracy:  0.8668
Precision: 0.8664
Recall:    0.9962
F1 Score:  0.9268

Confusion Matrix:
[[ 1574  8554]
 [  209 55449]]

Classification Report:
              precision    recall  f1-score   support

        Real       0.88      0.16      0.26     10128
       Spoof       0.87      1.00      0.93     55658

    accuracy                           0.87     65786
   macro avg       0.87      0.58      0.60     65786
weighted avg       0.87      0.87      0.82     65786

