# Prerequisites

First of all, we import all of the modules, functions, and classes that we are
going to use.


In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import sys
import torch
import torchvision
from tqdm import tqdm
import torch.nn as nn
import torchvision.models as models
from contextlib import nullcontext
from functools import partialmethod
from pprint import pp
from sklearn.metrics import ConfusionMatrixDisplay
from torch import optim
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data.dataloader import DataLoader
import os
import random
import torchvision.transforms as T
from torch.utils.data import Dataset
from PIL import Image, ImageFilter
import numpy as np
import cv2
import io
from typing import Optional
from torch.optim import AdamW
# Render all matplotlib text with a serif font family.
plt.rcParams["font.family"] = "serif"


Here, we set the seed for PyTorch's RNG to get (not entirely) reproducible
results.


In [None]:
# One seed shared by every RNG the notebook draws from. The augmentation
# classes and the dataset sub-sampling use Python's `random` module, so seeding
# PyTorch alone would leave them different on every run.
SEED = 512

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print(f"Set {SEED} as the seed of Torch's RNG!")

As the final step in this section, we select the CUDA device if one is
available and fall back to the CPU otherwise.


In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using {device} device!")

# Data


## Loading

First, we need to load the data from the directory in which the dataset is
stored.

By default, torchvision's `ImageFolder` sorts the class names alphabetically
and assigns the target indices in that order. In this case, 0 is assigned to
the label FAKE and 1 to the label REAL. We prefer to have 1 for FAKE and 0 for
REAL, since it is more natural for the "error signal", 1, to indicate a fake
image. Thus, we define the following class to handle this transformation.


In [None]:
class TargetMapper:
    """Translate targets indexed by sorted label order into a chosen order.

    Args:
        labels: Label names listed in the desired order; the position of a
            name in this sequence is the value it is mapped to.
    """

    def __init__(self, labels):
        # Desired target value for every label name (its position in `labels`).
        self.labels_dict = {}
        for position, label in enumerate(labels):
            self.labels_dict[label] = position
        # Label names in sorted order, i.e. the order incoming targets index.
        self.labels_list = sorted(labels)

    def __call__(self, target):
        """Return the remapped value for a target given in sorted-label order."""
        label_name = self.labels_list[target]
        return self.labels_dict[label_name]

# Path

In [None]:
# Paths
# Root directories of the three image collections combined below. They are
# absolute paths under /kaggle/input, so they presumably only resolve inside a
# Kaggle session with these datasets attached — adjust when running elsewhere.
artifact_path = "/kaggle/input/artifacts-cleaned/ArtiFact"
genimage_path = "/kaggle/input/tiny-genimage"
cifake_path = "/kaggle/input/cifake-real-and-ai-generated-synthetic-images"


# Augmentations

In [None]:

# Augmentation Classes
class JPEGCompression:
    """Randomly round-trip a PIL image through JPEG encoding.

    Args:
        quality_range: ``(low, high)`` bounds, both inclusive, from which the
            JPEG quality is drawn for each image that gets compressed.
        p: Probability of compressing a given image.
        use_pil_jpeg: Encode with Pillow when true, with OpenCV otherwise.
    """

    def __init__(self, quality_range=(30, 100), p=0.5, use_pil_jpeg=True):
        self.quality_range = quality_range
        self.p = p
        self.use_pil_jpeg = use_pil_jpeg

    def __call__(self, img):
        """Return `img` unchanged, or a JPEG re-encoded version of it."""
        if random.random() > self.p:
            return img

        lowest, highest = self.quality_range[0], self.quality_range[1]
        quality = random.randint(lowest, highest)

        encode = self._roundtrip_pil if self.use_pil_jpeg else self._roundtrip_cv2
        return encode(img, quality)

    def _roundtrip_pil(self, img, quality):
        """Encode and decode through an in-memory JPEG using Pillow."""
        stream = io.BytesIO()
        img.save(stream, format='JPEG', quality=quality)
        stream.seek(0)
        return Image.open(stream)

    def _roundtrip_cv2(self, img, quality):
        """Encode and decode through an in-memory JPEG using OpenCV."""
        pixels = np.array(img)
        params = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
        # OpenCV works in BGR channel order, so convert on the way in and out.
        bgr_pixels = cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)
        _, jpeg_bytes = cv2.imencode('.jpg', bgr_pixels, params)
        bgr_decoded = cv2.imdecode(jpeg_bytes, cv2.IMREAD_COLOR)
        return Image.fromarray(cv2.cvtColor(bgr_decoded, cv2.COLOR_BGR2RGB))


class RandomGaussianBlur:
    """Randomly apply a Gaussian blur to a PIL image.

    Args:
        sigma_range: ``(low, high)`` bounds from which the blur radius is
            drawn uniformly for each image that gets blurred.
        p: Probability of blurring a given image.
    """

    def __init__(self, sigma_range=(0, 3), p=0.5):
        self.sigma_range = sigma_range
        self.p = p

    def __call__(self, img):
        """Return `img` unchanged, or a blurred copy of it."""
        if random.random() > self.p:
            return img

        low, high = self.sigma_range[0], self.sigma_range[1]
        blur = ImageFilter.GaussianBlur(radius=random.uniform(low, high))
        return img.filter(blur)


# Dataset Class

In [None]:


# Dataset Class
class RandomizedDataset(Dataset):
    """Image dataset assembled from several ``<root>/<label folder>`` trees.

    Args:
        data_paths: Iterable of ``(root_path, label_map)`` pairs, where
            ``label_map`` maps a sub-folder name of ``root_path`` to the
            integer label given to every image inside it. Sub-folders that do
            not exist are skipped silently.
        transform: Optional callable applied to each RGB PIL image on access.
        sample_fraction: When below 1.0, only a random subset of this relative
            size (rounded down) is kept, drawn with Python's ``random`` module.

    Attributes:
        data: Image file paths.
        labels: Integer labels, aligned index-for-index with ``data``.
    """

    # Compared against the lower-cased file name, so "x.JPEG" and "x.PNG"
    # are picked up as well as their lower-case spellings.
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, data_paths, transform=None, sample_fraction=1.0):
        self.data = []
        self.labels = []
        self.transform = transform

        # Load data
        for root_path, label_map in data_paths:
            for label_name, label_value in label_map.items():
                folder_path = os.path.join(root_path, label_name)
                if not os.path.isdir(folder_path):
                    continue
                # os.listdir returns entries in arbitrary order; sorting makes
                # the file list (and any seeded sampling below) reproducible.
                images = [
                    os.path.join(folder_path, filename)
                    for filename in sorted(os.listdir(folder_path))
                    if filename.lower().endswith(self.IMAGE_EXTENSIONS)
                ]
                self.data.extend(images)
                self.labels.extend([label_value] * len(images))

        # Apply random sampling if sample_fraction < 1.0
        if sample_fraction < 1.0:
            sample_size = int(sample_fraction * len(self.data))
            sampled_indices = random.sample(range(len(self.data)), sample_size)
            self.data = [self.data[i] for i in sampled_indices]
            self.labels = [self.labels[i] for i in sampled_indices]

    def __len__(self):
        """Return the number of images kept after optional sub-sampling."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position `idx`."""
        image_path = self.data[idx]
        label = self.labels[idx]

        # convert() yields an independent in-memory image, so the file handle
        # can be closed right away instead of lingering until garbage
        # collection.
        with Image.open(image_path) as source:
            img = source.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, label



# Helper function to prepare label mappings
def prepare_genimage_paths(root_path, splits=("train", "val")):
    """Collect the split folders found under each sub-directory of `root_path`.

    Args:
        root_path: Directory whose immediate sub-directories are scanned.
        splits: Names of the split folders to look for inside each
            sub-directory. The default returns both "train" and "val"; pass
            e.g. ``("train",)`` to keep the splits apart.

    Returns:
        A list of ``(split_folder_path, {"ai": 1, "nature": 0})`` pairs,
        ordered by sub-directory name and then by the order of `splits`.
        Split folders that do not exist are left out.
    """
    data_paths = []

    # Sorted so that the result does not depend on the arbitrary order in
    # which os.listdir reports directory entries.
    for entry in sorted(os.listdir(root_path)):
        subdir = os.path.join(root_path, entry)
        if not os.path.isdir(subdir):
            continue

        for split in splits:
            split_path = os.path.join(subdir, split)
            if os.path.isdir(split_path):
                data_paths.append((split_path, {"ai": 1, "nature": 0}))

    return data_paths

# Transformation Class

In [None]:


# Prepare transformations
def get_transforms(augmentation_type: str = "blur+jpeg_0.5") -> tuple:
    """Build the training and evaluation transform pipelines.

    Both pipelines resize to 32x32, convert to a tensor and normalise with the
    same per-channel mean/std. The training pipeline additionally flips
    horizontally at random and may be prefixed with blur/JPEG augmentations.

    Args:
        augmentation_type: One of "no_aug", "gaussian_blur", "jpeg",
            "blur+jpeg_0.5" or "blur+jpeg_0.1".

    Returns:
        ``(train_transform, test_transform)``, both ``T.Compose`` objects.

    Raises:
        ValueError: If `augmentation_type` is not one of the names above.
    """

    def shared_tail(flip):
        # Steps common to both pipelines; only training adds the random flip.
        steps = [T.Resize((32,32))]
        if flip:
            steps.append(T.RandomHorizontalFlip())
        steps.append(T.ToTensor())
        steps.append(T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
        return steps

    def random_backend_jpeg(p):
        # NOTE(review): the Pillow/OpenCV backend is drawn once, when the
        # pipeline is built, not once per image — confirm this is intended.
        return JPEGCompression(p=p, use_pil_jpeg=random.choice([True, False]))

    # NOTE(review): "blur+jpeg_0.5" applies each augmentation with p=0.2, not
    # the 0.5 its name suggests — confirm which value is intended.
    recipes = (
        ("no_aug", lambda: []),
        ("gaussian_blur", lambda: [RandomGaussianBlur(p=0.5)]),
        ("jpeg", lambda: [random_backend_jpeg(0.5)]),
        ("blur+jpeg_0.5", lambda: [RandomGaussianBlur(p=0.2), random_backend_jpeg(0.2)]),
        ("blur+jpeg_0.1", lambda: [RandomGaussianBlur(p=0.1), random_backend_jpeg(0.1)]),
    )

    base_transform = shared_tail(flip=True)
    test_transform = T.Compose(shared_tail(flip=False))

    build_prefix = next(
        (builder for name, builder in recipes if augmentation_type == name),
        None,
    )
    if build_prefix is None:
        raise ValueError(f"Unknown augmentation type: {augmentation_type}")

    train_transform = build_prefix() + base_transform
    return T.Compose(train_transform), test_transform


# Train and Val Loader

In [None]:
# Prepare datasets
# prepare_genimage_paths() can return both the "train" and the "val" folder of
# every sub-directory, so its result is filtered per split here. Adding the
# unfiltered list to both sides would place every GenImage validation image in
# the training set (and vice versa), which leaks data and inflates the
# validation metrics.
genimage_data_paths = prepare_genimage_paths(genimage_path)

train_data_paths = [
    (os.path.join(artifact_path, "train"), {"FAKE": 1, "REAL": 0}),
    (os.path.join(cifake_path, "train"), {"FAKE": 1, "REAL": 0}),
]
train_data_paths.extend(
    entry
    for entry in genimage_data_paths
    if os.path.basename(os.path.normpath(entry[0])) == "train"
)

val_data_paths = [
    (os.path.join(artifact_path, "val"), {"FAKE": 1, "REAL": 0}),
    (os.path.join(cifake_path, "test"), {"FAKE": 1, "REAL": 0}),
]
val_data_paths.extend(
    entry
    for entry in genimage_data_paths
    if os.path.basename(os.path.normpath(entry[0])) == "val"
)

# Get transforms
train_transform, test_transform = get_transforms(augmentation_type="no_aug")

# Create datasets
# Only the training set is sub-sampled; validation uses every image found.
train_data = RandomizedDataset(train_data_paths, transform=train_transform, sample_fraction=0.8)
val_data = RandomizedDataset(val_data_paths, transform=test_transform)

# DataLoaders
def collate_fn(batch):
    """Merge ``(image, label)`` samples into batched tensors on `device`.

    Args:
        batch: Sequence of ``(image_tensor, label)`` pairs from the dataset.

    Returns:
        ``(images, labels)``: the stacked image tensor and a tensor of the
        labels, both moved to the notebook-level `device`.
    """
    image_column, label_column = zip(*batch)
    stacked_images = torch.stack(image_column)
    label_tensor = torch.tensor(label_column)
    return stacked_images.to(device), label_tensor.to(device)

# Training batches are reshuffled every epoch; validation keeps dataset order.
train_loader = DataLoader(
    train_data,
    batch_size=128,
    shuffle=True,
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    val_data,
    batch_size=128,
    shuffle=False,
    collate_fn=collate_fn,
)

## Utilities

Here, we define the following class to stop the training process if no
improvements are gained. It also handles saving and reloading the best model
parameters.


In [None]:
class EarlyStopping():
    def __init__(self, model, patience=5, metric_name="loss", mode="min", device="cuda"):
        """
        Early stopping utility for PyTorch models.

        Args:
        - model: The PyTorch model to monitor and save.
        - patience: Number of epochs with no improvement after which training will stop.
        - metric_name: The name of the metric to monitor.
        - mode: "min" for minimizing the metric, "max" for maximizing the metric.
        - device: The device on which the operations should run ("cuda" or "cpu").
          A CUDA device is replaced by the CPU when CUDA is not available.
        """
        self.model = model
        self.patience = patience
        self.metric_name = metric_name
        self.mode = mode
        self.device = device
        self.counter = 0
        self.best_metric_value = float("inf") if mode == "min" else -float("inf")
        self.checkpoint_path = "_checkpoint.pth"
        # True once this instance has written a checkpoint. Guards against
        # reloading a missing file, or a stale one left behind by an earlier
        # run, when the metric never improves (e.g. it is NaN from the start).
        self._checkpoint_saved = False

    def __call__(self, metrics, last_epoch=False):
        """
        Checks whether training should stop based on the monitored metric.

        The model state is saved whenever the metric improves, and the best
        saved state is restored when stopping or when `last_epoch` is true.

        Args:
        - metrics: A dictionary containing the monitored metrics.
        - last_epoch: Boolean indicating if this is the final epoch.

        Returns:
        - should_stop: Boolean indicating whether training should stop.
        """
        metric_value = metrics[self.metric_name]
        if self.mode == "max":
            improvement = metric_value > self.best_metric_value
        else:
            improvement = metric_value < self.best_metric_value

        if improvement:
            self.counter = 0
            self.best_metric_value = metric_value
            self._save_checkpoint()
        else:
            self.counter += 1

        should_stop = self.counter >= self.patience

        if should_stop or last_epoch:
            self._load_checkpoint()

        return should_stop

    def _map_location(self):
        """Returns the device checkpoints are loaded onto, with a CPU fallback."""
        if self.device is None:
            return None
        target = torch.device(self.device)
        if target.type == "cuda" and not torch.cuda.is_available():
            return torch.device("cpu")
        return target

    def _save_checkpoint(self):
        """Saves the model state to a checkpoint file."""
        torch.save(self.model.state_dict(), self.checkpoint_path)
        self._checkpoint_saved = True

    def _load_checkpoint(self):
        """Loads the model state from the checkpoint file, if one was saved."""
        if not self._checkpoint_saved:
            # Nothing better than the current weights was ever recorded.
            return
        state_dict = torch.load(self.checkpoint_path, map_location=self._map_location())
        self.model.load_state_dict(state_dict)

# Training Loop

In [None]:

# Upper bound on the number of training epochs run by train_model().
EPOCHS = 10
# When True, get_metrics() wraps its loader in a tqdm progress bar and
# train_model() prints detailed per-epoch results.
VERBOSE = False

def file_name(model, suffix=""):
    """Return the ``.pt`` path under OUTPUT_DIR for `model` and `suffix`.

    NOTE(review): OUTPUT_DIR is not defined in the visible part of the
    notebook; it has to exist before this function is called.
    """
    stem = f"{model.name}{suffix}"
    return f"{OUTPUT_DIR}/{stem}.pt"
def update_history(history, metrics):
    """
    Appends the current epoch's metrics to the running history, in place.

    Args:
    - history (dict): Maps each split to a dict of per-metric value lists.
      An empty dict is populated from the layout of `metrics` on first use.
    - metrics (dict): The current epoch's metrics, keyed by split
      ("train" and "val"), each mapping metric names to values.
    """
    # First call: create one empty list per (split, metric) pair.
    if not history:
        history.update(
            {split: {name: [] for name in values} for split, values in metrics.items()}
        )

    # Record this epoch's value at the end of every list.
    for split, values in metrics.items():
        series = history[split]
        for name, value in values.items():
            series[name].append(value)

def save_model(model, optimizer, suffix=""):
    """Write the model and optimizer state dicts to ``file_name(model, suffix)``.

    The file holds a dict with the keys "model" and "optimizer".
    """
    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    destination = file_name(model, suffix)
    torch.save(checkpoint, destination)

def load_model(model, optimizer, suffix=""):
    """Restore model and optimizer state from ``file_name(model, suffix)``.

    Expects a dict with "model" and "optimizer" entries, as save_model()
    writes. Tensors are mapped onto the notebook-level `device` while loading.
    """
    source = file_name(model, suffix)
    state = torch.load(source, map_location=device)
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])

def get_metrics(model, optimizer, loader, split, desc_prefix=""):
    """Run one pass over `loader` and return aggregate binary metrics.

    With ``split == "train"`` the model is put in training mode and updated
    after every batch. For any other split the model is put in eval mode and
    gradient tracking is switched off.

    Args:
        model: Module exposing ``loss_function`` and returning one score per
            image; scores are thresholded at 0.5, so they are assumed to be
            probabilities.
        optimizer: Optimizer stepped during training; unused otherwise.
        loader: Iterable yielding ``(images, labels)`` batches.
        split: "train" to train; any other value (e.g. "val") to evaluate.
        desc_prefix: Prefix for the progress-bar label shown when VERBOSE.

    Returns:
        Dict with "loss" (mean per item), "accuracy", "precision", "recall",
        "f1_score" and "confusion_matrix" (2x2 numpy array laid out as
        ``[[tn, fp], [fn, tp]]``). For an empty loader the loss is NaN and
        every ratio is 0.0.
    """
    is_training = split == "train"

    # Autograd is only needed while training. Disabling it for every other
    # split (not just "val") avoids building graphs during evaluation.
    with nullcontext() if is_training else torch.no_grad():
        if is_training:
            model.train()
        else:
            model.eval()

        desc = f"{desc_prefix}{split.title()}ing"
        total_loss, total_items = 0.0, 0
        tp, tn, fp, fn = 0, 0, 0, 0

        for images, labels in (
            tqdm(loader, desc=desc, file=sys.stdout) if VERBOSE else loader
        ):
            # Move data to the target device; labels become a column so they
            # line up element-wise with the model's per-image outputs.
            images = images.to(device)
            labels = labels.to(device).unsqueeze(1)
            items = images.shape[0]

            if is_training:
                optimizer.zero_grad()

            # Forward pass
            outputs = model(images)
            predictions = (outputs >= 0.5).float()
            loss = model.loss_function(outputs, labels.float())

            if is_training:
                # Backward pass and optimizer step
                loss.backward()
                optimizer.step()

            # Weight the batch loss by batch size so that the final average
            # is per item even when the last batch is smaller.
            total_loss += loss.item() * items
            total_items += items

            tp += ((predictions == 1) & (labels == 1)).sum().item()
            tn += ((predictions == 0) & (labels == 0)).sum().item()
            fp += ((predictions == 1) & (labels == 0)).sum().item()
            fn += ((predictions == 0) & (labels == 1)).sum().item()

        # Calculate overall metrics; every ratio is guarded against an empty
        # denominator (e.g. an empty loader).
        total_predictions = tp + tn + fp + fn
        loss = total_loss / total_items if total_items > 0 else float("nan")
        accuracy = (tp + tn) / total_predictions if total_predictions > 0 else 0.0
        precision = tp / (tp + fp) if tp + fp > 0 else 0.0
        recall = tp / (tp + fn) if tp + fn > 0 else 0.0
        f1_score = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0.0

        # Create confusion matrix (rows: actual 0/1, columns: predicted 0/1)
        confusion_matrix = torch.tensor(
            [
                [tn, fp],
                [fn, tp],
            ]
        ).cpu().numpy()  # Move to CPU for compatibility with numpy

        return {
            "loss": loss,
            "accuracy": accuracy,
            "precision": precision,
            "recall": recall,
            "f1_score": f1_score,
            "confusion_matrix": confusion_matrix,
        }


def train_model(model, optimizer, train_loader, val_loader):
    """Train `model` for up to EPOCHS epochs with early stopping.

    Each epoch runs one training pass and one validation pass, records both
    in the history and stops early when the validation loss stops improving.

    Args:
        model: Model to train; moved to the notebook-level `device`.
        optimizer: Optimizer updating the model's parameters.
        train_loader: Loader yielding training batches.
        val_loader: Loader yielding validation batches.

    Returns:
        ``(final_val_metrics, history)``: the metrics of a final validation
        pass, run after the early-stopping helper has had the chance to
        restore the best weights, and the per-epoch history dict filled by
        update_history().
    """
    # Move model to GPU
    model.to(device)

    # Pass the notebook's device explicitly: EarlyStopping defaults to "cuda",
    # which makes reloading the best checkpoint fail on a CPU-only machine.
    early_stopping = EarlyStopping(model, device=device)
    history = {}

    for epoch in range(EPOCHS):
        print(f"Epoch {epoch + 1}/{EPOCHS}{':' if VERBOSE else '...'}")

        # Get metrics for training
        train_loader_tqdm = tqdm(train_loader, desc="Training", leave=False)
        train_metrics = get_metrics(model, optimizer, train_loader_tqdm, split="train", desc_prefix="Epoch ")

        # Get metrics for validation. The bar is created only now so that an
        # idle "Validation" bar is not displayed throughout the training pass.
        val_loader_tqdm = tqdm(val_loader, desc="Validation", leave=False)
        val_metrics = get_metrics(model, optimizer, val_loader_tqdm, split="val", desc_prefix="Epoch ")

        # Combine metrics
        metrics = {"train": train_metrics, "val": val_metrics}

        # Update training history
        update_history(history, metrics)

        # Print metrics for the epoch
        print(f"Epoch {epoch + 1} Results:")
        print(f"  Train Loss: {train_metrics['loss']:.4f}, Accuracy: {train_metrics['accuracy']:.4f}, "
              f"Precision: {train_metrics['precision']:.4f}, Recall: {train_metrics['recall']:.4f}, "
              f"F1-Score: {train_metrics['f1_score']:.4f}")
        print(f"  Val Loss:   {val_metrics['loss']:.4f}, Accuracy: {val_metrics['accuracy']:.4f}, "
              f"Precision: {val_metrics['precision']:.4f}, Recall: {val_metrics['recall']:.4f}, "
              f"F1-Score: {val_metrics['f1_score']:.4f}")

        if VERBOSE:
            # NOTE(review): print_results is not defined in the visible part
            # of the notebook — confirm it exists before enabling VERBOSE.
            print_results(metrics)

        # Early stopping based on validation metrics
        if early_stopping(metrics["val"], last_epoch=epoch == (EPOCHS - 1)):
            break

    # Final evaluation on validation set
    return get_metrics(model, optimizer, val_loader, split="val"), history


This is the base for our CNN models.


## Original Model

The CIFAKE paper evaluates several models. This model is one of them and was
recommended by the authors.


In [None]:
class CifakeNet(nn.Module):
    """DenseNet-121 feature extractor with a small sigmoid MLP head.

    The forward pass returns one probability per image, shape ``(batch, 1)``,
    matching the bundled ``nn.BCELoss`` in ``loss_function``.
    """

    def __init__(self):
        super().__init__()
        self.name = "cifakenet"

        # Load the DenseNet-121 model
        densenet = models.densenet121(pretrained=True)
        self.features = densenet.features

        # Global average pooling collapses whatever spatial size the feature
        # map has into one vector per image. It is the identity for a 1x1 map
        # and adds no parameters, so existing checkpoints still load; without
        # it the head only accepts inputs whose feature map is exactly 1x1.
        self.pool = nn.AdaptiveAvgPool2d(1)

        # Define the MLP head for classification
        self.mlp_head = nn.Sequential(
            nn.Linear(densenet.classifier.in_features, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

        # Loss function
        self.loss_function = nn.BCELoss()

    def forward(self, x):
        """Return the per-image probability for a batch of image tensors."""
        # Forward pass through feature extractor
        # NOTE(review): torchvision's own DenseNet forward applies a ReLU to
        # the feature map before pooling; none is applied here — confirm that
        # this is intentional.
        x = self.features(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)  # Flatten the output

        # Forward pass through the MLP head
        x = self.mlp_head(x)
        return x

In the following cells, we initialize the model and run the routine on it. The
routine includes training, evaluating, and drawing the evaluation results plus
the activation maps.


In [None]:
# Build the classifier; its constructor requests pretrained DenseNet-121 weights.
model = CifakeNet()

In [None]:
# Place the model's parameters on the selected device.
model = model.to(device)

In [None]:
# AdamW over all model parameters (the DenseNet features as well as the head).
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)

In [None]:
# map_location lets a checkpoint written on a GPU be read on a CPU-only host.
# NOTE(review): the loaded object is not applied to `model` anywhere in the
# visible cells; if training should resume from it, a load_state_dict() call
# is still needed — confirm the checkpoint's layout first.
saved_model = torch.load(
    "/kaggle/input/densenet_cifake/pytorch/default/1/cifakenet.pt",
    map_location=device,
)

In [None]:
# Train with early stopping. `test_metrics` holds the metrics of the final
# pass over the validation loader, and `history` the per-epoch metrics.
test_metrics, history = train_model(model, optimizer, train_loader, val_loader)

In [None]:
# Destination for the trained weights (note: the file name has no extension).
model_save_path = "/kaggle/working/densenet121"

# Save the model
# Only the state dict is written, so loading it later requires constructing
# a CifakeNet first and calling load_state_dict() on it.
torch.save(model.state_dict(), model_save_path)

print(f"Model saved to {model_save_path}")
