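## **Leave-One-Generator-Out Approach**

For each generator folder (`DALLE_dataset`, `IMAGEN_dataset`, `SD_dataset`), an ImageNet-pretrained ResNet-50 real/fake classifier is trained on the images of the *other* generators and then evaluated on the held-out one. This measures how well detection generalises to a generator never seen during training.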

#### **Imports**

In [32]:
import os 
import random

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim

from torchvision import (
    models,
    transforms,
)
from torch.utils.data import (
    Dataset,
    DataLoader,
)
from sklearn.metrics import (
    classification_report,
    accuracy_score, 
    confusion_matrix,
)

In [33]:
import torch

# Environment check. Also defines the notebook-level `device` that the training
# helpers move batches to, so the notebook survives Restart & Run All.
print("Torch Version:", torch.__version__)
print("CUDA Version:", torch.version.cuda)
print("CUDA available:", torch.cuda.is_available())

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("Device name:", torch.cuda.get_device_name(device))
else:
    # get_device_name() raises when no CUDA device is present, so do not call it here.
    print("Device name: CPU (no CUDA device found)")

Torch Version: 2.5.1+cu121
CUDA Version: 12.1
CUDA available: True
Device name: NVIDIA GeForce RTX 3060 Laptop GPU


***

#### **Dataset Classes and Functionality**

In [34]:
def crop_bottom(image, px=40):
    """Return ``image`` with its bottom ``px`` pixel rows cut off.

    Presumably strips a caption/watermark band along the lower edge of the
    dataset images -- TODO confirm against the data.

    Note: inside this function ``px`` shadows the notebook-level
    ``plotly.express`` alias of the same name.

    Args:
        image: PIL-style image exposing ``.size`` -> (width, height) and ``.crop(box)``.
        px: number of rows to remove from the bottom edge.

    Returns:
        The result of ``image.crop`` for the box (0, 0, width, height - px).
    """
    w, h = image.size
    keep_box = (0, 0, w, h - px)
    return image.crop(keep_box)

In [35]:
class GeneratorDataset(Dataset):
    """Binary real/fake image dataset pooled from several generator folders.

    Each directory in ``root_dirs`` must contain ``real/`` and ``fake/``
    sub-folders. Images found there are labelled 0 (real) and 1 (fake). The
    pooled sample list is shuffled once at construction time.

    Args:
        root_dirs: iterable of generator directories to pool.
        excluded_generator: folder name (basename) of a root to skip, or None.
        transform: optional callable applied to each bottom-cropped PIL image.
        seed: optional seed for the construction-time shuffle. ``None`` (default)
            shuffles with the global ``random`` state, as before.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.webp')

    def __init__(self, root_dirs, excluded_generator=None, transform=None, seed=None):
        self.transform = transform
        samples = []

        for root in root_dirs:
            # normpath drops a trailing separator; otherwise basename("x/SD_dataset/")
            # is "" and the exclusion silently never matches.
            if os.path.basename(os.path.normpath(root)) == excluded_generator:
                continue
            for label_type, label in (('real', 0), ('fake', 1)):
                folder = os.path.join(root, label_type)
                # os.listdir order is filesystem-dependent; sort it so a seeded
                # shuffle is reproducible across machines.
                for fname in sorted(os.listdir(folder)):
                    if fname.lower().endswith(self.IMAGE_EXTENSIONS):
                        samples.append((os.path.join(folder, fname), label))

        rng = random.Random(seed) if seed is not None else random
        rng.shuffle(samples)

        # Stored as tuples (as before). Building them this way also copes with an
        # empty dataset, where the old `zip(*combined)` unpacking raised ValueError.
        self.image_paths = tuple(path for path, _ in samples)
        self.labels = tuple(label for _, label in samples)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # `with` closes the file handle deterministically; convert() loads the
        # pixels into a new image that does not depend on the open file.
        with Image.open(self.image_paths[idx]) as src:
            img = src.convert("RGB")
        img = crop_bottom(img)
        if self.transform:
            img = self.transform(img)
        return img, self.labels[idx]

In [36]:
class _TestSet(Dataset):
    """Fixed-order (unshuffled) real/fake dataset over explicit path/label lists.

    Defined at module level rather than inside ``load_test_generator``. Classes
    local to a function cannot be pickled, and DataLoader must pickle the dataset
    when ``num_workers > 0`` under the "spawn" start method.
    """

    def __init__(self, image_paths, labels, transform):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # `with` closes the file handle deterministically; convert() loads pixels
        # into a new image independent of the open file.
        with Image.open(self.image_paths[idx]) as src:
            img = src.convert("RGB")
        img = crop_bottom(img)
        if self.transform:
            img = self.transform(img)
        return img, self.labels[idx]


def load_test_generator(path, transform):
    """Build an evaluation dataset for a single generator folder.

    Args:
        path: generator directory containing ``real/`` and ``fake/`` sub-folders.
        transform: optional callable applied to each bottom-cropped PIL image.

    Returns:
        A Dataset yielding ``(image, label)`` with label 0 = real and 1 = fake.
        Samples are ordered real-then-fake, each group sorted by file name.
    """
    image_paths, labels = [], []
    for label_type, label in (('real', 0), ('fake', 1)):
        folder = os.path.join(path, label_type)
        # Sorted so the evaluation order does not depend on the filesystem.
        for fname in sorted(os.listdir(folder)):
            if fname.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
                image_paths.append(os.path.join(folder, fname))
                labels.append(label)

    return _TestSet(image_paths, labels, transform)

***

#### **Training Function**

In [45]:
def train(model, loader, optimizer, criterion, device=None):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: torch module; it is switched to train mode.
        loader: iterable of ``(images, labels)`` tensor batches.
        optimizer: optimizer over the model's trainable parameters.
        criterion: loss function called as ``criterion(model(images), labels)``.
        device: device each batch is moved to. ``None`` (default) uses the
            device of the model's first parameter, or CPU if it has none. This
            avoids relying on a notebook-global ``device`` variable.

    Returns:
        Mean loss over the epoch's batches as a float, or NaN if the loader
        yielded no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    total_loss = 0.0
    n_batches = 0
    progress = tqdm(loader, desc="Training", leave=False)

    for images, labels in progress:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()  # one device->host sync per batch instead of two
        total_loss += batch_loss
        n_batches += 1
        progress.set_postfix(loss=batch_loss)

    return total_loss / n_batches if n_batches else float("nan")


***

#### **Eval Function**

In [40]:
def evaluate_and_report(model, loader, generator_name, device, thresholds=(0.90, 0.80, 0.70, 0.60), high_conf_threshold=0.90):
    """Evaluate a binary real/fake classifier, print a report and return summary metrics.

    Prints the sklearn classification report, the confusion matrix, accuracy,
    confidence statistics, and the accuracy/coverage of predictions whose
    max-softmax confidence reaches each value in ``thresholds``.

    Args:
        model: torch module producing 2-class logits.
        loader: iterable of ``(images, labels)`` batches (label 0 = real, 1 = fake).
        generator_name: name used in the printed header and the returned dict.
        device: device the images are moved to.
        thresholds: confidence thresholds for the printed breakdown.
        high_conf_threshold: threshold used for the returned ``high_conf_*`` metrics.

    Returns:
        dict with generator, accuracy, macro precision/recall/f1, mean confidence,
        and coverage/accuracy at ``high_conf_threshold``. ``high_conf_accuracy`` is
        0.0 when no prediction reaches the threshold.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.eval()
    predictions, targets, confs = [], [], []

    with torch.no_grad():
        for images, labels in loader:
            probs = torch.softmax(model(images.to(device)), dim=1)
            max_probs, predicted = torch.max(probs, 1)

            predictions.extend(predicted.cpu().numpy())
            targets.extend(labels.cpu().numpy())
            confs.extend(max_probs.cpu().numpy())

    predictions, targets, confs = np.array(predictions), np.array(targets), np.array(confs)
    if targets.size == 0:
        raise ValueError(f"No samples to evaluate for {generator_name}")

    def _mean_or_nan(values):
        # Avoids numpy's "mean of empty slice" warning (e.g. no incorrect predictions).
        return float(values.mean()) if values.size else float("nan")

    class_ids, class_names = [0, 1], ['Real', 'Fake']
    correct = predictions == targets
    accuracy = accuracy_score(targets, predictions)
    # labels=[0, 1] keeps both classes in the report and confusion matrix (and keeps
    # target_names aligned) even if a fold's targets/predictions contain one class only.
    report = classification_report(targets, predictions, labels=class_ids, output_dict=True, zero_division=0)

    print(f"\n📊 Results for {generator_name}")
    print(classification_report(targets, predictions, labels=class_ids, target_names=class_names, zero_division=0))
    print("🔁 Confusion Matrix:")
    print(confusion_matrix(targets, predictions, labels=class_ids))
    print(f"✅ Accuracy: {accuracy:.4f}")

    print("\n📈 Confidence Stats:")
    print(f"Mean (All):       {confs.mean():.4f}")
    print(f"Mean (Correct):   {_mean_or_nan(confs[correct]):.4f}")
    print(f"Mean (Incorrect): {_mean_or_nan(confs[~correct]):.4f}")
    print(f"Max:              {confs.max():.4f}")
    print(f"Min:              {confs.min():.4f}")

    for t in thresholds:
        mask = confs >= t
        if mask.sum() == 0:
            continue
        acc = np.mean(correct[mask])
        print(f"\n🔎 Threshold ≥ {t:.2f}:")
        print(f"  Samples: {mask.sum()} ({mask.mean()*100:.2f}%)")
        print(f"  Accuracy: {acc:.4f}")

    high_mask = confs >= high_conf_threshold
    return {
        "generator": generator_name,
        "accuracy": accuracy,
        "precision": report["macro avg"]["precision"],
        "recall": report["macro avg"]["recall"],
        "f1": report["macro avg"]["f1-score"],
        "mean_confidence": confs.mean(),
        "high_conf_coverage": high_mask.mean(),
        "high_conf_accuracy": np.mean(correct[high_mask]) if high_mask.any() else 0.0
    }

***

#### **Training & Prediction Function**

In [None]:
def train_and_predict( generators, dataset_base_path, transform, unfreeze_layers=False, num_epochs=5, batch_size=32, learning_rate=1e-4, weight_decay=0.0, thresholds=(0.90, 0.80, 0.70, 0.60), verbose=True, eval_transform=None, seed=None):
    """Run a leave-one-generator-out experiment and collect per-fold metrics.

    For every name in ``generators``, a fresh ImageNet-pretrained ResNet-50 with a
    2-way ``fc`` head is trained on all *other* generators. It is then evaluated
    on the held-out generator with ``evaluate_and_report``.

    Args:
        generators: sequence of generator folder names under ``dataset_base_path``.
            At least two are required: one held out, the rest for training.
        dataset_base_path: directory containing one folder per generator.
        transform: transform applied to training images.
        unfreeze_layers: if True, ``layer4`` and ``fc`` are trained; otherwise only ``fc``.
        num_epochs, batch_size, learning_rate, weight_decay: Adam training hyper-parameters.
        thresholds: confidence thresholds reported by ``evaluate_and_report``.
        verbose: print fold/epoch-loss lines (the evaluation report is always printed).
        eval_transform: transform for held-out images. ``None`` (default) reuses
            ``transform``. Pass a deterministic pipeline when ``transform`` contains
            random augmentation, so that evaluation itself is not augmented.
        seed: if given, the Python, NumPy and torch RNGs are re-seeded at the start
            of every fold, so each fold is reproducible on its own. cuDNN kernels may
            still be non-deterministic.

    Returns:
        pandas DataFrame with one row of metrics per held-out generator.

    Raises:
        ValueError: if fewer than two generators are given.
    """
    generators = list(generators)
    if len(generators) < 2:
        raise ValueError(
            f"Leave-one-generator-out needs at least two generators, got {generators!r}"
        )
    if eval_transform is None:
        eval_transform = transform

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    metrics_all = []

    for holdout in generators:
        if verbose:
            print(f"\n🚫 Holding out: {holdout}")

        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

        # Train on every other generator; evaluate on the held-out one.
        train_dirs = [os.path.join(dataset_base_path, g) for g in generators if g != holdout]
        train_dataset = GeneratorDataset(train_dirs, transform=transform)
        val_dataset = load_test_generator(os.path.join(dataset_base_path, holdout), transform=eval_transform)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size)

        # Model setup: pretrained backbone, new 2-class head.
        model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        model.fc = nn.Linear(model.fc.in_features, 2)

        # Freeze everything except fc (and layer4 when unfreeze_layers is set).
        for name, param in model.named_parameters():
            param.requires_grad = ("layer4" in name or "fc" in name) if unfreeze_layers else False
        for param in model.fc.parameters():
            param.requires_grad = True

        model = model.to(device)
        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=learning_rate,
            weight_decay=weight_decay
        )
        criterion = nn.CrossEntropyLoss()

        # Training
        for epoch in range(num_epochs):
            loss = train(model, train_loader, optimizer, criterion)
            if verbose:
                print(f"🧪 Epoch {epoch+1}/{num_epochs}, Loss: {loss:.4f}")

        # Evaluation on the unseen generator
        metrics = evaluate_and_report(model, val_loader, holdout, device, thresholds)
        metrics_all.append(metrics)

    return pd.DataFrame(metrics_all)

***

In [46]:
# ImageNet normalisation statistics matching the pretrained ResNet-50 weights.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
GENERATORS = ['DALLE_dataset', 'IMAGEN_dataset', 'SD_dataset']
DATA_ROOT = '../../data'

# Deterministic preprocessing only: resize -> tensor -> normalise.
transform_basic = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# Baseline run: train_and_predict defaults (only the new fc head is trainable, 5 epochs).
metrics_basic = train_and_predict(
    generators=GENERATORS,
    dataset_base_path=DATA_ROOT,
    transform=transform_basic,
)


🚫 Holding out: DALLE_dataset


                                                                       

🧪 Epoch 1/5, Loss: 0.6222


                                                                       

🧪 Epoch 2/5, Loss: 0.5368


                                                                       

🧪 Epoch 3/5, Loss: 0.4933


                                                                       

🧪 Epoch 4/5, Loss: 0.4643


                                                                       

🧪 Epoch 5/5, Loss: 0.4432

📊 Results for DALLE_dataset
              precision    recall  f1-score   support

        Real       0.78      0.77      0.77      2150
        Fake       0.77      0.78      0.78      2150

    accuracy                           0.77      4300
   macro avg       0.77      0.77      0.77      4300
weighted avg       0.77      0.77      0.77      4300

🔁 Confusion Matrix:
[[1651  499]
 [ 470 1680]]
✅ Accuracy: 0.7747

📈 Confidence Stats:
Mean (All):       0.7001
Mean (Correct):   0.7229
Mean (Incorrect): 0.6218
Max:              0.9964
Min:              0.5000

🔎 Threshold ≥ 0.90:
  Samples: 343 (7.98%)
  Accuracy: 0.9650

🔎 Threshold ≥ 0.80:
  Samples: 1100 (25.58%)
  Accuracy: 0.9445

🔎 Threshold ≥ 0.70:
  Samples: 2026 (47.12%)
  Accuracy: 0.9003

🔎 Threshold ≥ 0.60:
  Samples: 3073 (71.47%)
  Accuracy: 0.8461

🚫 Holding out: IMAGEN_dataset


                                                                       

🧪 Epoch 1/5, Loss: 0.6237


                                                                       

🧪 Epoch 2/5, Loss: 0.5335


                                                                       

🧪 Epoch 3/5, Loss: 0.4880


                                                                       

🧪 Epoch 4/5, Loss: 0.4568


                                                                       

🧪 Epoch 5/5, Loss: 0.4357

📊 Results for IMAGEN_dataset
              precision    recall  f1-score   support

        Real       0.81      0.80      0.80      1175
        Fake       0.80      0.82      0.81      1175

    accuracy                           0.81      2350
   macro avg       0.81      0.81      0.81      2350
weighted avg       0.81      0.81      0.81      2350

🔁 Confusion Matrix:
[[935 240]
 [216 959]]
✅ Accuracy: 0.8060

📈 Confidence Stats:
Mean (All):       0.7076
Mean (Correct):   0.7298
Mean (Incorrect): 0.6157
Max:              0.9941
Min:              0.5002

🔎 Threshold ≥ 0.90:
  Samples: 232 (9.87%)
  Accuracy: 0.9914

🔎 Threshold ≥ 0.80:
  Samples: 645 (27.45%)
  Accuracy: 0.9690

🔎 Threshold ≥ 0.70:
  Samples: 1146 (48.77%)
  Accuracy: 0.9267

🔎 Threshold ≥ 0.60:
  Samples: 1735 (73.83%)
  Accuracy: 0.8720

🚫 Holding out: SD_dataset


                                                                       

🧪 Epoch 1/5, Loss: 0.6265


                                                                       

🧪 Epoch 2/5, Loss: 0.5407


                                                                       

🧪 Epoch 3/5, Loss: 0.4942


                                                                       

🧪 Epoch 4/5, Loss: 0.4592


                                                                       

🧪 Epoch 5/5, Loss: 0.4357

📊 Results for SD_dataset
              precision    recall  f1-score   support

        Real       0.72      0.79      0.75      2675
        Fake       0.77      0.70      0.73      2675

    accuracy                           0.74      5350
   macro avg       0.75      0.74      0.74      5350
weighted avg       0.75      0.74      0.74      5350

🔁 Confusion Matrix:
[[2105  570]
 [ 804 1871]]
✅ Accuracy: 0.7432

📈 Confidence Stats:
Mean (All):       0.6912
Mean (Correct):   0.7120
Mean (Incorrect): 0.6312
Max:              0.9945
Min:              0.5001

🔎 Threshold ≥ 0.90:
  Samples: 373 (6.97%)
  Accuracy: 0.9464

🔎 Threshold ≥ 0.80:
  Samples: 1204 (22.50%)
  Accuracy: 0.8987

🔎 Threshold ≥ 0.70:
  Samples: 2355 (44.02%)
  Accuracy: 0.8611

🔎 Threshold ≥ 0.60:
  Samples: 3740 (69.91%)
  Accuracy: 0.8094


In [53]:
# train_and_predict already returns a DataFrame; re-wrapping it is harmless.
metrics_df = pd.DataFrame(data=metrics_basic)
metrics_df_basic_rounded = metrics_df.round(decimals=4)
display(metrics_df_basic_rounded)

Unnamed: 0,generator,accuracy,precision,recall,f1,mean_confidence,high_conf_coverage,high_conf_accuracy
0,DALLE_dataset,0.7747,0.7747,0.7747,0.7746,0.7001,0.0798,0.965
1,IMAGEN_dataset,0.806,0.8061,0.806,0.8059,0.7076,0.0987,0.9914
2,SD_dataset,0.7432,0.7451,0.7432,0.7427,0.6912,0.0697,0.9464


#### **Advanced Iteration: data augmentation, unfrozen `layer4`, 15 epochs**

In [52]:
# Augmented run: light flips/rotation/colour jitter, layer4 + fc trainable,
# 15 epochs and a small weight decay.
# NOTE(review): train_and_predict applies this same `transform` to the held-out
# evaluation images as well. The reported metrics are therefore computed on randomly
# flipped/rotated/jittered images and can vary between runs.
transform_advanced = transforms.Compose([
    # random augmentation on the PIL image...
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(5),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    # ...followed by the same resize / tensor / ImageNet-normalise tail as the basic run
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

advanced_overrides = dict(unfreeze_layers=True, num_epochs=15, weight_decay=1e-5)

metrics_advanced = train_and_predict(
    generators=['DALLE_dataset', 'IMAGEN_dataset', 'SD_dataset'],
    dataset_base_path='../../data',
    transform=transform_advanced,
    **advanced_overrides,
)


🚫 Holding out: DALLE_dataset


                                                                       

🧪 Epoch 1/15, Loss: 0.3799


                                                                        

🧪 Epoch 2/15, Loss: 0.1524


                                                                        

🧪 Epoch 3/15, Loss: 0.0824


                                                                         

🧪 Epoch 4/15, Loss: 0.0523


                                                                         

🧪 Epoch 5/15, Loss: 0.0430


                                                                         

🧪 Epoch 6/15, Loss: 0.0275


                                                                          

🧪 Epoch 7/15, Loss: 0.0290


                                                                          

🧪 Epoch 8/15, Loss: 0.0261


                                                                          

🧪 Epoch 9/15, Loss: 0.0215


                                                                          

🧪 Epoch 10/15, Loss: 0.0148


                                                                          

🧪 Epoch 11/15, Loss: 0.0139


                                                                          

🧪 Epoch 12/15, Loss: 0.0166


                                                                          

🧪 Epoch 13/15, Loss: 0.0158


                                                                          

🧪 Epoch 14/15, Loss: 0.0108


                                                                          

🧪 Epoch 15/15, Loss: 0.0118

📊 Results for DALLE_dataset
              precision    recall  f1-score   support

        Real       0.93      0.91      0.92      2150
        Fake       0.91      0.93      0.92      2150

    accuracy                           0.92      4300
   macro avg       0.92      0.92      0.92      4300
weighted avg       0.92      0.92      0.92      4300

🔁 Confusion Matrix:
[[1947  203]
 [ 140 2010]]
✅ Accuracy: 0.9202

📈 Confidence Stats:
Mean (All):       0.9636
Mean (Correct):   0.9738
Mean (Incorrect): 0.8470
Max:              1.0000
Min:              0.5023

🔎 Threshold ≥ 0.90:
  Samples: 3799 (88.35%)
  Accuracy: 0.9526

🔎 Threshold ≥ 0.80:
  Samples: 3992 (92.84%)
  Accuracy: 0.9441

🔎 Threshold ≥ 0.70:
  Samples: 4106 (95.49%)
  Accuracy: 0.9357

🔎 Threshold ≥ 0.60:
  Samples: 4212 (97.95%)
  Accuracy: 0.9259

🚫 Holding out: IMAGEN_dataset


                                                                        

🧪 Epoch 1/15, Loss: 0.3420


                                                                        

🧪 Epoch 2/15, Loss: 0.1372


                                                                         

🧪 Epoch 3/15, Loss: 0.0720


                                                                         

🧪 Epoch 4/15, Loss: 0.0483


                                                                         

🧪 Epoch 5/15, Loss: 0.0384


                                                                          

🧪 Epoch 6/15, Loss: 0.0282


                                                                                

🧪 Epoch 7/15, Loss: 0.0248


                                                                          

🧪 Epoch 8/15, Loss: 0.0203


                                                                          

🧪 Epoch 9/15, Loss: 0.0203


                                                                          

🧪 Epoch 10/15, Loss: 0.0143


                                                                          

🧪 Epoch 11/15, Loss: 0.0122


                                                                          

🧪 Epoch 12/15, Loss: 0.0137


                                                                          

🧪 Epoch 13/15, Loss: 0.0165


                                                                          

🧪 Epoch 14/15, Loss: 0.0162


                                                                          

🧪 Epoch 15/15, Loss: 0.0108

📊 Results for IMAGEN_dataset
              precision    recall  f1-score   support

        Real       0.71      0.97      0.82      1175
        Fake       0.95      0.60      0.73      1175

    accuracy                           0.78      2350
   macro avg       0.83      0.78      0.77      2350
weighted avg       0.83      0.78      0.77      2350

🔁 Confusion Matrix:
[[1139   36]
 [ 475  700]]
✅ Accuracy: 0.7826

📈 Confidence Stats:
Mean (All):       0.9466
Mean (Correct):   0.9625
Mean (Incorrect): 0.8892
Max:              1.0000
Min:              0.5011

🔎 Threshold ≥ 0.90:
  Samples: 1947 (82.85%)
  Accuracy: 0.8285

🔎 Threshold ≥ 0.80:
  Samples: 2083 (88.64%)
  Accuracy: 0.8176

🔎 Threshold ≥ 0.70:
  Samples: 2193 (93.32%)
  Accuracy: 0.8012

🔎 Threshold ≥ 0.60:
  Samples: 2280 (97.02%)
  Accuracy: 0.7908

🚫 Holding out: SD_dataset


                                                                       

🧪 Epoch 1/15, Loss: 0.3739


                                                                        

🧪 Epoch 2/15, Loss: 0.1273


                                                                        

🧪 Epoch 3/15, Loss: 0.0700


                                                                         

🧪 Epoch 4/15, Loss: 0.0452


                                                                         

🧪 Epoch 5/15, Loss: 0.0327


                                                                         

🧪 Epoch 6/15, Loss: 0.0254


                                                                          

🧪 Epoch 7/15, Loss: 0.0226


                                                                          

🧪 Epoch 8/15, Loss: 0.0193


                                                                          

🧪 Epoch 9/15, Loss: 0.0198


                                                                          

🧪 Epoch 10/15, Loss: 0.0163


                                                                          

🧪 Epoch 11/15, Loss: 0.0079


                                                                          

🧪 Epoch 12/15, Loss: 0.0137


                                                                          

🧪 Epoch 13/15, Loss: 0.0091


                                                                          

🧪 Epoch 14/15, Loss: 0.0077


                                                                          

🧪 Epoch 15/15, Loss: 0.0062

📊 Results for SD_dataset
              precision    recall  f1-score   support

        Real       0.73      0.95      0.83      2675
        Fake       0.93      0.65      0.77      2675

    accuracy                           0.80      5350
   macro avg       0.83      0.80      0.80      5350
weighted avg       0.83      0.80      0.80      5350

🔁 Confusion Matrix:
[[2554  121]
 [ 936 1739]]
✅ Accuracy: 0.8024

📈 Confidence Stats:
Mean (All):       0.9507
Mean (Correct):   0.9651
Mean (Incorrect): 0.8924
Max:              1.0000
Min:              0.5005

🔎 Threshold ≥ 0.90:
  Samples: 4504 (84.19%)
  Accuracy: 0.8468

🔎 Threshold ≥ 0.80:
  Samples: 4808 (89.87%)
  Accuracy: 0.8311

🔎 Threshold ≥ 0.70:
  Samples: 5017 (93.78%)
  Accuracy: 0.8198

🔎 Threshold ≥ 0.60:
  Samples: 5197 (97.14%)
  Accuracy: 0.8110


In [61]:
# train_and_predict already returns a DataFrame; re-wrapping it is harmless.
metrics_df = pd.DataFrame(data=metrics_advanced)
metrics_df_advanced_rounded = metrics_df.round(decimals=4)
display(metrics_df_advanced_rounded)

Unnamed: 0,generator,accuracy,precision,recall,f1,mean_confidence,high_conf_coverage,high_conf_accuracy
0,DALLE_dataset,0.9202,0.9206,0.9202,0.9202,0.9636,0.8835,0.9526
1,IMAGEN_dataset,0.7826,0.8284,0.7826,0.7747,0.9466,0.8285,0.8285
2,SD_dataset,0.8024,0.8334,0.8024,0.7977,0.9507,0.8419,0.8468


***

#### **Visual Comparison of Runs**

In [63]:
# Tag each run with its setup and stack them for side-by-side plots.
# `.assign` returns new frames, so the tables displayed above are not mutated
# (previously a "setup" column was added to them in place).
combined_df = pd.concat(
    [
        metrics_df_basic_rounded.assign(setup="Basic"),
        metrics_df_advanced_rounded.assign(setup="Advanced"),
    ],
    ignore_index=True,
)

#### **Accuracy by Fold**

In [68]:
# Held-out accuracy per generator, basic vs advanced setup side by side.
fig = px.bar(
    combined_df,
    x="generator", y="accuracy",
    color="setup", barmode="group",
    title="Accuracy per Generator: Basic vs Advanced",
)
fig.update_layout(yaxis=dict(title=dict(text="Accuracy"), range=[0, 1]))
fig.show()

#### **Mean Confidence**

In [69]:
# Mean max-softmax confidence on the held-out generator, per setup.
fig = px.bar(
    combined_df,
    x="generator", y="mean_confidence",
    color="setup", barmode="group",
    title="Mean Confidence per Generator",
)
fig.update_layout(yaxis=dict(title=dict(text="Mean Confidence"), range=[0, 1]))
fig.show()

#### **High-Confidence Accuracy**

In [70]:
# Accuracy restricted to predictions with confidence >= 0.90.
# evaluate_and_report records 0.0 here when no prediction reaches the threshold.
fig = px.bar(
    combined_df,
    x="generator", y="high_conf_accuracy",
    color="setup", barmode="group",
    title="High-Confidence Accuracy (≥ 0.90)",
)
fig.update_layout(yaxis=dict(title=dict(text="Accuracy"), range=[0, 1]))
fig.show()

#### **High-Confidence Coverage**

In [71]:
# Share of held-out predictions with confidence >= 0.90.
# high_conf_coverage is a fraction in [0, 1], not a percentage, so the axis is
# titled "Coverage" and plotly formats the tick labels as percentages instead.
fig = px.bar(
    combined_df,
    x="generator",
    y="high_conf_coverage",
    color="setup",
    barmode="group",
    title="High-Confidence Coverage (≥ 0.90)"
)
fig.update_layout(yaxis_range=[0, 1], yaxis_title="Coverage", yaxis_tickformat=".0%")
fig.show()