# Goals

AI image generation has made big strides in recent times. It is at a point where generated images can pass as real ones unless enough attention is paid to discerning them. This can be a problem for many reasons, ranging from spam to disinformation campaigns.

The goal of this project is to train a model to detect these images. Since it's quite a novel task, we don't have a good idea of what the expected performance should be. This will require establishing a baseline and determining what kind of errors are acceptable. In addition, this project will also serve as practice for working with parquet files and using an experiment tracking tool.

# Imports

This notebook was used mostly in a Kaggle environment, so either the data needs to be uploaded there (train files) or the references to Kaggle need to be changed (imports, paths and W&B credentials).

In [1]:
LOG_CHECKPOINTS = False

In [2]:
INTERACTIVE_MODE = False
DEBUG_MODE = False
REDUCED_DATASET_MODE = False

SWEEP_MODE = False
CREATE_NEW_SWEEP = False

TRAIN_CANDIDATES_MODE = False

TRAIN_FINAL_MODE = True
TEST_FINAL_MODE = True


# sweep_id = "qxp2wgy1"  # full (512, ~18000)
# sweep_id = "fxz1z4nx"  # reduced (256, 5000)
# sweep_id = "kr39572n" # reduced with transforms
# sweep_id = "9y8abom7" # reduced with transforms bayes
# sweep_id = "ukwbdg8y" # reduced final
sweep_id = None  # Add sweep_id to use, otherwise will create a new sweep with default config
if SWEEP_MODE and not CREATE_NEW_SWEEP:
    if sweep_id is None:
        raise ValueError("Specify the sweep id to use or allow to create a new sweep")

In [3]:
import gc
import io
import json
from pathlib import Path


import h5py
from kaggle_secrets import UserSecretsClient
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import pyarrow
import pyarrow.dataset as ds
import pyarrow.parquet as pq
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import seaborn as sns
from sklearn.metrics import accuracy_score, recall_score, precision_score
from sklearn.model_selection import StratifiedKFold
import timm
import torch
import torch.nn as nn
import torchmetrics
import torchvision
from torchvision.transforms import (
    RandomResizedCrop, RandomRotation, RandomHorizontalFlip, RandomVerticalFlip,
    ColorJitter, RandomAdjustSharpness,
    ToTensor, Compose, Normalize
)
from torchvision.transforms.functional import to_tensor
from tqdm.notebook import tqdm
import wandb

In [None]:
user_secrets = UserSecretsClient()
wandb_apikey = user_secrets.get_secret("wandb_apikey")
wandb.login(key=wandb_apikey)

use_gpu = torch.cuda.is_available()

# Data

The data for this project comes from the dataset on Hugging Face called [aiornot](https://huggingface.co/datasets/competitions/aiornot). It was originally used in a competition with the same name. The data itself consists of train images, labels for those images and test images without labels.

I am going to be using only the train images and I am going to split them into my own train, validation and test sets.
 

## Data Loading and EDA

### Function and Class Definitions

In [None]:
DATA_DIR = Path("/kaggle/input/ai-or-not")

RAW_IMAGE_SHAPE = (512, 512, 3)
if REDUCED_DATASET_MODE:
    RAW_IMAGE_SHAPE = (256, 256, 3)
PREPROCESSING_BATCH_SIZE = 1024

In [5]:
def read_data():
    # Read all train parquet files once, as a single Arrow table
    parquets_filenames = list(DATA_DIR.glob("train*"))
    data_table = ds.dataset(parquets_filenames, format="parquet").to_table()
    return data_table

def show_target_distribution():
    label_counts = labels["label"].value_counts()
    plt.title("Distribution of target")
    sns.barplot(y=label_counts.values, x=label_counts.index)
    plt.show()

def decode_image(bytes_):
    return np.array(
        Image.open(io.BytesIO(bytes_)).resize(RAW_IMAGE_SHAPE[:2]),
        dtype=np.uint8
    )

def decode_batch(image_batch, shape=RAW_IMAGE_SHAPE):
    images = image_batch.to_pydict()["image"]
    batch_size = len(images)
    images_array = np.empty((batch_size, *shape), dtype=np.uint8)
    for idx, image in enumerate(images):
        images_array[idx] = decode_image(image["bytes"])
    return images_array

### Basic Dataset Info

Since the data is stored as a Hugging Face dataset, there are convenience functions in their library for downloading it, but I wanted to practice working with parquet files myself.

The images are split into train set containing labels and test set without labels. Since we need to evaluate the model ourselves, we are going to use only the train set and manually create the splits.

The train set consists of two parquet files.

In [6]:
data_table = read_data()

In [7]:
if DEBUG_MODE:
    data_table = data_table.slice(length=1000)
elif REDUCED_DATASET_MODE:
    data_table = data_table.slice(length=5000)

In [None]:
print(f"Columns: {data_table.column_names}")
print(f"Number of examples: {len(data_table)}")

As can be seen, the dataset contains 18618 examples and 3 columns `id`, `image` and `label`.

First, we will take a look at the distribution of labels to check how unbalanced the dataset is.

In [9]:
labels = data_table.select(["id", "label"]).to_pandas()

In [10]:
if INTERACTIVE_MODE:
    show_target_distribution()

The bar plot shows that the dataset is fairly balanced and we can use this distribution as is when splitting.

### Assessing the Task

Now, we need to take a look at the images themselves to assess the difficulty of the task and to decide which transformations are needed or can be used.

In [12]:
images_table = data_table.select(["image"])
if INTERACTIVE_MODE:
    decoded_batch = decode_batch(next(iter(images_table.to_batches(PREPROCESSING_BATCH_SIZE))))
    print(decoded_batch.shape)

In [13]:
if INTERACTIVE_MODE:
    np.random.seed(0)
    samples = np.random.randint(len(decoded_batch), size=(16,))

    fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(8, 8))
    fig.suptitle('Sample Images')

    for i in range(4):
        for j in range(4):
            idx = samples[i * 4 + j]
            axis = axes[i, j]
            axis.imshow(decoded_batch[idx])
            axis.axis('off')
            axis.set_title(labels.loc[idx].label)
    plt.show()

The images themselves seem to be fine and there is no need for any special preprocessing. However, the labels are just `0` and `1`, so we need to check which label is which. Label `1` seems to correspond to generated images, but to be certain we are going to take a closer look at them.

In [14]:
if INTERACTIVE_MODE:
    potential_ai_labels = labels[labels.label == 1]

    fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(12, 12))
    fig.suptitle('Sample Images')

    for i in range(3):
        for j in range(3):
            label = potential_ai_labels.iloc[i * 3 + j]
            idx = label.name
            axis = axes[i, j]
            axis.imshow(decoded_batch[idx])
            axis.axis('off')
            axis.set_title(label.label)
    plt.show()

The first image has weirdly placed eyes and a neck at a wrong angle. That confirms that this class is generated images. Additionally, other images have this weird `smoothness` and there is also a man with 6 (and a half?) fingers.

### Manual Benchmark

The task is pretty novel, so I have no idea what a good level of performance would be. One way to check is to research what other people were able to achieve. The other way is to do a manual benchmark: we can sample some number of images and try to classify them ourselves. The results will indicate human-level performance, which is generally a pretty good benchmark.

In [15]:
if INTERACTIVE_MODE and not DEBUG_MODE:
    np.random.seed(1)
    NUMBER_OF_IMAGES_TO_GUESS = 50
    samples = np.random.randint(len(decoded_batch), size=(NUMBER_OF_IMAGES_TO_GUESS,))
    # samples = [ # The samples I've guessed on
    #     37, 235, 908, 72, 767, 905, 715, 645, 847, 960, 144, 129, 972, 583, 749, 508, 390,
    #     281, 178, 276, 254, 357, 914, 468, 907, 252, 490, 668, 925, 398, 562, 580, 215, 983,
    #     753, 503, 478, 864, 86, 141, 393, 7, 319, 829, 534, 313, 513, 896, 316, 209
    # ]


    guess = []
    true = []

    for i in range(NUMBER_OF_IMAGES_TO_GUESS):
        idx = samples[i]
        label = labels.loc[idx].label
        true.append(label)
        plt.imshow(decoded_batch[idx])
        plt.axis('off')
        plt.show()
        guess.append(int(input("AI=1, Human=0:")))
        answer = "AI" if label else "Human"
        print("It was", answer)

    guess = np.array(guess)
    true = np.array(true)

In [None]:
# My Score
# Accuracy = 0.8400
# AI Accuracy = 0.9545
# Human Accuracy = 0.7500
# Recall = 0.9545
# Precision = 0.7500
#
#    \/      \/        \/      \/      \/  \/      \/          \/
# [0 0 1 1 0 0 1 1 0 1 0 1 0 0 0 0 1 0 0 1 1 0 0 1 0 1 1 0 0 1 0 1 1 1 0 1 1 1 0 0 0 0 1 0 0 1 1 0 0 0]: true
# [0 1 1 1 0 1 1 1 0 1 1 1 0 0 1 0 1 0 1 1 0 0 0 1 1 1 1 0 0 1 1 1 1 1 0 1 1 1 0 0 0 0 1 0 0 1 1 0 0 0]: guess

if INTERACTIVE_MODE and not DEBUG_MODE:
    print(true)
    print(guess)
    ai_mask = (true == 1)
    human_mask = (true == 0)
    print((
        f"Accuracy = {accuracy_score(true, guess):.4f}\n"
        f"Accuracy on AI images = {accuracy_score(true[ai_mask], guess[ai_mask]):.4f}\n"
        f"Accuracy on `human` images = {accuracy_score(true[human_mask], guess[human_mask]):.4f}\n"
        f"Recall = {recall_score(true, guess):.4f}\n"
        f"Precision = {precision_score(true, guess):.4f}"
    ))
    # NOTE: Accuracy on AI images is always equal to recall. Accuracy on `human`
    # images is specificity; it matches precision here only by coincidence.
else:
    print("Accuracy = 0.8400")
    print("Accuracy on AI images = 0.9545")
    print("Accuracy on `human` images = 0.7500")
    print("Recall = 0.9545")
    print("Precision = 0.7500")

My results surprised me in 2 ways:

- How intuitive it is to tell that a generated image is generated
- How quickly I was able to get better at it

At first, it was pretty difficult to tell the difference, but towards the end I got a sense of the overall generated style of `strangeness` and `smoothness`. I wouldn't be surprised if, with a bit more practice, I could reach a much higher accuracy (technically, the last 19 images come after my 'training', and my accuracy on them is 100%).

Overall, I think the task should be pretty easy for a deep learning model, and we should expect really high recall with precision lagging a bit behind. This error tradeoff seems reasonable for some applications, but in a concrete case it would depend on the relative cost of false positives and false negatives. Shifting the tradeoff could require more data work, such as oversampling, collecting more data targeted at the type of errors we want to reduce, and data augmentation for the more difficult cases.
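
The scores above can be recomputed from the recorded labels and guesses with plain confusion-matrix counts. The sketch below uses only the standard library instead of the notebook's `sklearn` calls, and the variable names are illustrative. It also makes explicit that accuracy on the AI images is the same quantity as recall, while accuracy on the `human` images is specificity; in this sample specificity matches precision only because the number of images guessed as AI (28) happens to equal the number of real images (28).

```python
# Labels and guesses recorded in the manual benchmark (1 = AI, 0 = human).
true = [int(v) for v in (
    "0 0 1 1 0 0 1 1 0 1 0 1 0 0 0 0 1 0 0 1 1 0 0 1 0 "
    "1 1 0 0 1 0 1 1 1 0 1 1 1 0 0 0 0 1 0 0 1 1 0 0 0"
).split()]
guess = [int(v) for v in (
    "0 1 1 1 0 1 1 1 0 1 1 1 0 0 1 0 1 0 1 1 0 0 0 1 1 "
    "1 1 0 0 1 1 1 1 1 0 1 1 1 0 0 0 0 1 0 0 1 1 0 0 0"
).split()]

# Confusion-matrix counts, with "AI" as the positive class.
tp = sum(t == 1 and g == 1 for t, g in zip(true, guess))
fp = sum(t == 0 and g == 1 for t, g in zip(true, guess))
fn = sum(t == 1 and g == 0 for t, g in zip(true, guess))
tn = sum(t == 0 and g == 0 for t, g in zip(true, guess))

accuracy = (tp + tn) / len(true)
recall = tp / (tp + fn)        # identical to accuracy on the AI images
specificity = tn / (tn + fp)   # identical to accuracy on the human images
precision = tp / (tp + fp)

print(f"TP={tp} FP={fp} FN={fn} TN={tn}")
print(f"Accuracy = {accuracy:.4f}")
print(f"Recall = {recall:.4f}")
print(f"Specificity = {specificity:.4f}")
print(f"Precision = {precision:.4f}")
```

All eight mistakes but one are false positives (real images guessed as AI), which is why recall stays high while precision drops.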

# Modeling

## Function and Class Definitions

In [None]:
SCRATCH = Path("/kaggle/scratch")
SCRATCH.mkdir(parents=True, exist_ok=True)
PREPROCESSED = SCRATCH / "preprocessed.hdf5"

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
PREPROCESS_TRANSFORM = Compose([
    ToTensor(),
    Normalize(IMAGENET_MEAN, IMAGENET_STD, inplace=False)
])

In [17]:
def split_data(labels):
    sgkf = StratifiedKFold(shuffle=True, n_splits=10, random_state=0)
    folds = sgkf.split(X=labels[["id"]], y=labels["label"])
    for i, (train_index, test_index) in enumerate(folds):
        labels.loc[test_index, ["fold"]] = i

    labels["split"] = "train"
    labels.loc[labels["fold"] == 0, ["split"]] = "valid" 
    labels.loc[(labels["fold"] == 1) | (labels["fold"] == 2), ["split"]] = "test" 

    if INTERACTIVE_MODE:
        to_plot = {"split": [], "target": [], "count": []}
        for split in ["train", "test", "valid"]:
            split_counts = labels.loc[labels["split"] == split, ["label"]].value_counts()
            for idx in range(2):
                to_plot["split"].append(split)
                to_plot["count"].append(split_counts.iloc[idx])
                to_plot["target"].append(split_counts.index.get_level_values("label")[idx])
        sns.barplot(data=to_plot, x="target", y="count", hue="split")
        plt.title("Label distribution in splits")
        plt.show()
    
    return labels


def preprocess_splits():
    # (                       size, disk-to-gpu speed)
    # Storing as float:       55GB, ~0.63it/s
    # Storing as uint8:       14GB, ~2.00it/s
    # Storing as uint8 + lzf: 12GB, ~0.02it/s
    with h5py.File(PREPROCESSED, "w") as file:
        for split in ["train", "valid", "test"]:
            split_labels = labels[labels["split"] == split]
            split_images_table = images_table.take(split_labels.index.values)
            total = (len(split_images_table) -  1) // PREPROCESSING_BATCH_SIZE + 1
            split_group = file.create_group(split)
            image_dataset = split_group.create_dataset(
                "image",
                shape=(len(split_images_table), *RAW_IMAGE_SHAPE),
                dtype=np.uint8,
            )
            batches = tqdm(
                enumerate(split_images_table.to_batches(PREPROCESSING_BATCH_SIZE)),
                desc=f"Preprocessing {split} images",
                mininterval=1, total=total
            )
            for idx, batch in batches:
                start = idx * PREPROCESSING_BATCH_SIZE
                end = start + PREPROCESSING_BATCH_SIZE
    #             image_dataset[start:end] = preprocess_batch(decode_batch(batch))
                image_dataset[start:end] = decode_batch(batch)

            split_group.create_dataset(
                "target",
                dtype=np.int64,
                data=split_labels.label.values.astype(np.int64)
            )


def preprocess_image(image):
    return PREPROCESS_TRANSFORM(image)


def get_train_transforms(max_rotation, max_brightness, max_contrast, sharpness_p, flip_p, use_crop):
    transforms = nn.Sequential(
        RandomResizedCrop(RAW_IMAGE_SHAPE[:2], antialias=True),
        RandomRotation(max_rotation),
        RandomHorizontalFlip(p=flip_p),
        RandomVerticalFlip(p=flip_p),
        ColorJitter(brightness=max_brightness, contrast=max_contrast),
        RandomAdjustSharpness(sharpness_factor=1.5, p=sharpness_p),
        Normalize(IMAGENET_MEAN, IMAGENET_STD, inplace=False)
    )
    if not use_crop:
        transforms = transforms[1:]
    return torch.jit.script(transforms)

valid_transforms = torch.jit.script(nn.Sequential(
    Normalize(IMAGENET_MEAN, IMAGENET_STD, inplace=False)
))


def train(config=None):
    with wandb.init(config=config, project="idl-aiornot"):
        config = wandb.config
        print(config)

        train_transforms = get_train_transforms(
            config.get("max_rotation", 0), config.get("max_brightness", 0.0),
            config.get("max_contrast", 0.0), config.get("sharpness_p", 0.0),
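            config.get("flip_p", 0.0),
            # The sweep passes "True"/"False" as strings and a non-empty string is
            # always truthy, so compare explicitly instead of relying on truthiness.
            str(config.get("use_crop", False)) == "True"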
        )
        model_name = config["model_name"]
        pooling = config.get("pooling", "avg")
        # ConvNeXt doesn't work with concat pooling
        if "convnext_" in model_name:
            pooling = "avg"
        module = AIOrNotModule(
            model_name=model_name, global_pool=pooling, lr=config["lr"],
            label_smoothing=config.get("label_smoothing", 0.03),
            weight_decay=config.get("weight_decay", 0.01),
            
            train_transforms=train_transforms, valid_transforms=valid_transforms
        )
        datamodule = AIOrNotDataModule(PREPROCESSED, batch_size=32)

        
        wandb_logger = WandbLogger(log_model=LOG_CHECKPOINTS)
        callbacks = []
        if LOG_CHECKPOINTS:
            callbacks.append(pl.callbacks.ModelCheckpoint(monitor="val_BinaryAUROC", mode="max"))

        trainer = pl.Trainer(
            gradient_clip_val=config["gradient_clip_value"],
            precision=16, benchmark=True,
            accelerator="gpu" if use_gpu else "tpu",
            auto_scale_batch_size="power",
            auto_lr_find=False,
            detect_anomaly=False,
            devices=-1,
            max_epochs=config["epochs"],
            log_every_n_steps=10,
            logger=wandb_logger,
            callbacks=callbacks
        )

        trainer.tune(model=module, datamodule=datamodule)
        trainer.fit(model=module, datamodule=datamodule)

        # In case we are using runtime hyperparameter tuning, we might need to update them for logging
        wandb.config.update(module.hparams)
        wandb.config.update(datamodule.hparams)
        print(wandb.config)
        
    # Force memory clearing
    del trainer
    del module
    del datamodule
    del wandb_logger
    del train_transforms
    gc.collect()

#### Dataset and Lightning Data Module

In [None]:
class AIOrNotDataset(torch.utils.data.Dataset):
    def __init__(self, images, targets, normalize=True):
        self.images = images
        self.targets = targets
        self.normalize = normalize
    
    def __len__(self):
        return len(self.targets)
    
    def __getitem__(self, idx):
        image = self.images[idx]
        if self.normalize:
            image = preprocess_image(image)
        else:
            image = to_tensor(image)
        return (image, torch.tensor((self.targets[idx],), dtype=torch.float32))


class AIOrNotDataModule(pl.LightningDataModule):
    def __init__(self, path, batch_size):
        super().__init__()
        self.path = path
        self.batch_size = batch_size
        self.save_hyperparameters(ignore=["path"])
        
    def setup(self, stage):
        self.fd = h5py.File(self.path)
        if stage == "fit":
            train_group = self.fd["train"]
            valid_group = self.fd["valid"]
            train_images = train_group["image"]
            valid_images = valid_group["image"]
            train_targets = train_group["target"]
            valid_targets = valid_group["target"]
            self.train_dataset = AIOrNotDataset(train_images, train_targets)
            self.valid_dataset = AIOrNotDataset(valid_images, valid_targets)
        elif stage == "test":
            test_group = self.fd["test"]
            test_images = test_group["image"]
            test_targets = test_group["target"]
            self.test_dataset = AIOrNotDataset(test_images, test_targets)
            
    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=1,
            pin_memory=True,
            drop_last=True,
        )
    
    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.valid_dataset,
            batch_size=2*self.batch_size,
            shuffle=False,
            num_workers=1,
            pin_memory=True,
            drop_last=False,
        )
    
    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=2*self.batch_size,
            shuffle=False,
            num_workers=1,
            pin_memory=True,
            drop_last=False,
        )
            

    def teardown(self, stage):
        self.fd.close()

#### Lightning Module

In [None]:
class AIOrNotModule(pl.LightningModule):
    def __init__(
        self, model_name, global_pool, lr, label_smoothing, weight_decay, num_classes=1,
        train_transforms=None, valid_transforms=None
    ):
        super().__init__()
        self.model = timm.create_model(
            model_name,
            pretrained=True,
            num_classes=num_classes,
            global_pool=global_pool
        )
#         self.loss_func = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
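        # BCEWithLogitsLoss has no label-smoothing option, so `label_smoothing` is
        # stored and logged as a hyperparameter but does not affect this loss.
        self.loss_func = nn.BCEWithLogitsLoss()
        self.label_smoothing = label_smoothing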
        self.lr = lr
        self.weight_decay = weight_decay
        self.save_hyperparameters(ignore=[
            "model", "num_classes", "train_transforms", "valid_transforms"
        ])
        task = "binary"
        metrics = torchmetrics.MetricCollection([
            torchmetrics.Accuracy(num_classes=num_classes, task=task),
            torchmetrics.Precision(num_classes=num_classes, task=task),
            torchmetrics.Recall(num_classes=num_classes, task=task),
            torchmetrics.AUROC(num_classes=num_classes, task=task),
        ])
        self.train_metrics = metrics.clone(prefix="train_")
        self.val_metrics = metrics.clone(prefix="val_")
        self.test_metrics = metrics.clone(prefix="test_")
        self.train_transforms = train_transforms
        self.valid_transforms = valid_transforms
        
    def forward(self, x, transform=None):
        if transform is not None:
            x = transform(x)
        x = self.model(x)
        return x
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x, self.train_transforms)
        loss = self.loss_func(y_hat, y)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        self.log_dict(self.train_metrics(y_hat, y))
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x, self.valid_transforms)
        loss = self.loss_func(y_hat, y)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        self.log_dict(self.val_metrics(y_hat, y))
        return loss
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x, self.valid_transforms)
        loss = self.loss_func(y_hat, y)
        self.log("test_loss", loss, prog_bar=True, logger=True)
        self.log_dict(self.test_metrics(y_hat, y))
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=self.lr,
            total_steps=self.trainer.estimated_stepping_batches,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
                "frequency": 1,
            }
        }

## Data Processing

With EDA done, we can now create the splits. Since the dataset is fairly balanced, we can just randomly split it into train (70%), validation (10%) and test (20%) sets, while keeping the same distribution of labels. One thing to note is that the dataset is too large to fit into memory, so, we need to use a streaming approach. We are going to process the data in batches and save the splits on disk using hdf5 format. During training, we are going to load the data from these files on demand.
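
The 70/10/20 split comes from assigning every example to one of 10 stratified folds and then mapping folds to splits. A minimal sketch of that mapping on toy data is shown here; `assign_folds` is a hypothetical, unshuffled stand-in for `StratifiedKFold` (the notebook itself uses the `sklearn` class with shuffling), and the toy labels are made up for illustration.

```python
from collections import Counter, defaultdict


def assign_folds(labels, n_folds=10):
    """Deal the examples of each class round-robin into folds.

    A simplified, unshuffled stand-in for stratified K-fold assignment.
    """
    seen_per_class = defaultdict(int)
    folds = []
    for label in labels:
        folds.append(seen_per_class[label] % n_folds)
        seen_per_class[label] += 1
    return folds


def fold_to_split(fold):
    """Fold 0 is validation, folds 1 and 2 are test, the other seven are train."""
    if fold == 0:
        return "valid"
    if fold in (1, 2):
        return "test"
    return "train"


toy_labels = [i % 2 for i in range(1000)]  # toy, perfectly balanced labels
splits = [fold_to_split(fold) for fold in assign_folds(toy_labels)]

split_sizes = Counter(splits)                 # examples per split
per_class = Counter(zip(splits, toy_labels))  # examples per (split, label)
print(split_sizes)
print(per_class[("valid", 0)], per_class[("valid", 1)])
```

Because each fold holds roughly a tenth of every class, any union of folds keeps the original label distribution.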

In [18]:
labels = split_data(labels)

The only preprocessing we are going to do before saving the data into hdf5 files is to decode the images into numpy arrays and, optionally for the reduced mode, resize them to 256x256.
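
A back-of-the-envelope estimate shows why the decoded images are kept as `uint8` on disk rather than held in memory as floats. The sketch below uses the example counts and image shapes quoted in this notebook; the helper name is hypothetical, and HDF5 file overhead is ignored.

```python
def raw_size_gib(num_images, height, width, channels=3, bytes_per_value=1):
    """Size of a dense image array in GiB (no compression, no file overhead)."""
    return num_images * height * width * channels * bytes_per_value / 2**30


full_uint8 = raw_size_gib(18618, 512, 512)                       # full dataset, uint8
full_float32 = raw_size_gib(18618, 512, 512, bytes_per_value=4)  # same data as float32
reduced_uint8 = raw_size_gib(5000, 256, 256)                     # reduced mode, uint8

print(f"full, uint8:    {full_uint8:.1f} GiB")     # 13.6 GiB
print(f"full, float32:  {full_float32:.1f} GiB")   # 54.5 GiB
print(f"reduced, uint8: {reduced_uint8:.1f} GiB")  # 0.9 GiB
```

These figures are in line with the sizes noted in the comments of `preprocess_splits` (about 14GB as uint8 versus about 55GB as float).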

In [None]:
preprocess_splits()

In [None]:
!ls -lh /kaggle/scratch

In [22]:
if not INTERACTIVE_MODE and not DEBUG_MODE:
    del labels
    del images_table
    del data_table
    gc.collect()

## Baseline Model

Now, we are going to try a baseline model. For training, I am using PyTorch Lightning with torchvision for basic preprocessing. Logging is done using Weights & Biases.

The baseline model is a simple pretrained ResNet-18 trained with a bit of weight decay. A `label_smoothing` value is also set, but note that the `BCEWithLogitsLoss` used in the module does not apply it, so it is only logged. The batch size and learning rate are automatically tuned (batch size by maximizing the GPU memory and learning rate by using [the learning rate finder](https://lightning.ai/docs/pytorch/stable/advanced/training_tricks.html#learning-rate-finder)).
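
Label smoothing for a binary target is commonly implemented by pulling the hard 0/1 targets toward 0.5 before computing the loss. A minimal, framework-free sketch of that idea follows; the helper names are hypothetical, and this is not the module's code, which uses plain `nn.BCEWithLogitsLoss` on the hard targets.

```python
import math


def smooth_binary_targets(targets, smoothing):
    """Pull hard 0/1 targets toward 0.5; smoothing=0 leaves them unchanged."""
    return [t * (1.0 - smoothing) + 0.5 * smoothing for t in targets]


def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy computed from a raw logit."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


# With smoothing=0.05 the targets become approximately 0.025 and 0.975.
smoothed_negative, smoothed_positive = smooth_binary_targets([0.0, 1.0], smoothing=0.05)

# A very confident, correct prediction is penalized more under the smoothed target.
hard_loss = bce_with_logits(10.0, 1.0)
smoothed_loss = bce_with_logits(10.0, smoothed_positive)
print(hard_loss, smoothed_loss)
```

The smoothed loss no longer rewards pushing logits toward infinity, which is the regularizing effect label smoothing is meant to provide.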

In [26]:
if INTERACTIVE_MODE:
    print(timm.list_models(pretrained=True))

In [27]:
if INTERACTIVE_MODE and not SWEEP_MODE:
    hyperparams = {
        "pooling": "avg",
        "lr": 0.001,
        "label_smoothing": 0.05,
        "weight_decay": 0.03,
        "epochs": 5,
        "batch_size": 32,
        "gradient_clip_value": None,
        "model_name": "resnet18",
    }


    module = AIOrNotModule(
        model_name=hyperparams["model_name"], global_pool=hyperparams["pooling"], lr=hyperparams["lr"],
        label_smoothing=hyperparams["label_smoothing"],
        weight_decay=hyperparams["weight_decay"]
    )
    datamodule = AIOrNotDataModule(PREPROCESSED, batch_size=hyperparams["batch_size"])

    wandb_logger = WandbLogger(project="idl-aiornot")
    # Add debug tag if running in DEBUG_MODE to filter those runs out
    if DEBUG_MODE:
        wandb_logger.experiment.tags += ("debug",)
    trainer = pl.Trainer(
        gradient_clip_val=hyperparams["gradient_clip_value"],
        precision=16, benchmark=True,
        accelerator="gpu" if use_gpu else "cpu",
        auto_scale_batch_size="power",
        auto_lr_find=True,
        detect_anomaly=True,
        devices=-1 if use_gpu else None,
        max_epochs=hyperparams["epochs"],
        log_every_n_steps=1 if DEBUG_MODE else 10,
        logger=wandb_logger
    )

    trainer.tune(model=module, datamodule=datamodule)
    trainer.fit(model=module, datamodule=datamodule)

    # In case we are using runtime hyperparameter tuning, we might need to update them for logging
    wandb.config.update(module.hparams)
    wandb.config.update(datamodule.hparams)

    wandb.finish()

|Split|AUROC|Accuracy|Precision|Recall|Loss|
|---|---|---|---|---|---|
|Train|0.9999|1.0000|1.0000|1.0000|0.0157|
|Val.|0.9982|0.9753|0.9683|0.9872|0.0558|

The baseline model achieves really good results and as expected recall is higher than precision. However, the model is definitely overfitting too much since it has perfect accuracy on the train set. Precision suffers more from this overfitting, which might mean that compared to generated images, real images are more diverse and thus harder to classify.

One thing to note is that training on the full dataset takes a lot of time. In order to speed up the iteration process, I am going to be using a reduced dataset. The reduced dataset is created by resizing the images to 256x256 and keeping only the first 5,000 examples (about 27% of the data). After doing a hyperparameter search on the reduced dataset, I am going to select a few candidates and train them on the full dataset to do a final evaluation.

## Hyperparameter Optimization

The baseline works fairly well and there don't seem to be any problems with training. However, there is still a lot of room for improvement. We are going to take the easy route and do hyperparameter optimization using Weights & Biases sweeps. They offer an easy way to set up search spaces, which we are going to use to run several rounds of sweeps with a more refined search space each time.

An example of a sweep config used is shown below. For the HP search, I did a few sweeps with different search spaces:

1. A random search without data augmentation
2. A random search with tighter bounds on the previous search space, plus data augmentation
3. A Bayesian search over the same refined search space as the previous one
4. A random search with only a few parameters to tune, plus data augmentation

In [34]:
sweep_config = {
    "method": "random",
    "name": "hp_tuning_reduced_with_transforms",
    "metric": {
        "goal": "maximize",
        "name": "val_BinaryAUROC"
    },
    "parameters": {
        "pooling": {
            "values": ["avg", "catavgmax"]
        },
        "lr": {
            "distribution": "log_uniform_values",
            "min": 1e-5,
            "max": 1e-2
        },
        "label_smoothing": {
            "distribution": "uniform",
            "min": 0.0,
            "max": 0.1
        },
        "weight_decay": {
            "distribution": "uniform",
            "min": 0.01,
            "max": 0.2
        },
        "epochs": {
            "distribution": "int_uniform",
            "min": 2,
            "max": 20
        },
        "gradient_clip_value": {
            "values": [None, 0.5, 2.0, 5.0, 10.0]
        },
        "model_name": {
            "values": [
                "resnet18", "resnet50",
                "convnext_tiny_in22k", "convnext_small_in22k",
                "efficientnet_b1",
            ]
        },
        "max_rotation": {
            "distribution": "uniform",
            "min": 0,
            "max": 45
        },
        "max_brightness": {
            "distribution": "uniform",
            "min": 0.0,
            "max": 0.4
        },
        "max_contrast": {
            "distribution": "uniform",
            "min": 0.0,
            "max": 0.4
        },
        "sharpness_p": {
            "distribution": "uniform",
            "min": 0.0,
            "max": 0.6
        },
        "flip_p": {
            "distribution": "uniform",
            "min": 0.0,
            "max": 0.5
        },
        "use_crop": {
            "values": ["False", "True"]
        }
    }
}

if SWEEP_MODE:
    if sweep_id is None:
        sweep_id = wandb.sweep(sweep_config, project="idl-aiornot")

In [35]:
if SWEEP_MODE:
    wandb.agent(sweep_id, train, count=5, project="idl-aiornot")
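
The learning rate in the config above uses the `log_uniform_values` distribution, which, as I understand the W&B documentation, draws values whose logarithm is uniform between `log(min)` and `log(max)`. The standard-library sketch below illustrates that behaviour; `sample_log_uniform` is a hypothetical helper, not part of the W&B API.

```python
import math
import random


def sample_log_uniform(low, high, rng):
    """Draw a value whose logarithm is uniform between log(low) and log(high)."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


rng = random.Random(0)  # fixed seed so the sketch is repeatable
samples = [sample_log_uniform(1e-5, 1e-2, rng) for _ in range(10_000)]

# The range 1e-5..1e-2 spans three decades, so each decade should receive
# roughly a third of the samples.
share_below_1e_4 = sum(s < 1e-4 for s in samples) / len(samples)
share_above_1e_3 = sum(s >= 1e-3 for s in samples) / len(samples)
print(share_below_1e_4, share_above_1e_3)
```

A plain uniform distribution over the same range would put about 90% of the trials above 1e-3, leaving small learning rates almost unexplored.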

After doing these 4 sweeps (totaling 128 models trained over ~13 hours), I narrowed it down to 6 possible candidates. They are not that different, but that is to be expected considering the performance of the baseline model and the fact that the reduced dataset was used.

## Final Model Training

With the candidate models selected, we can now train them on the full dataset. We are going to use the same hyperparameters by filtering the runs with `reduced_candidate` tag, getting the training configs and using them to train the models on the full dataset.

After that we can decide on the final model to use and retrain it again to check if the results are reproducible and to save the final model artifact.

In [37]:
# manually chosen runs with best val metrics and lowest overfitting
if TRAIN_CANDIDATES_MODE:
    wandb_api = wandb.Api()
    candidate_runs = wandb_api.runs(filters={"tags": "reduced_candidate"}, path="daniilgaltsev/idl-aiornot")
    candidate_configs = [run.config for run in candidate_runs]

In [38]:
if TRAIN_CANDIDATES_MODE:
    for config in candidate_configs:
        train(config=config)

Training those 6 models took about 4.5 hours, which underlines how much of a speedup the reduced dataset provides. At that rate, running the hyperparameter search on the full dataset would have taken roughly 128/6*4.5h = ~96 hours instead of ~13.

In any case, all these candidate models perform very well, with validation AUROC around 0.9985. I have decided to choose the candidate with the lowest amount of overfitting, which is a convnext_tiny model trained for only 2 epochs.

In [None]:
if TRAIN_FINAL_MODE:
    candidate_run_id = "u4de23v3"
    wandb_api = wandb.Api()
    run = wandb_api.run(path=f"daniilgaltsev/idl-aiornot/{candidate_run_id}")
    LOG_CHECKPOINTS = True
    train(config=run.config)

This training run didn't go as well as the previous ones because of some instability towards the end of training. I decided to use it anyway, since this stochasticity is inherent to the training process and it is important to take it into account and address it properly in the future.

Since we have saved the best model checkpoint, we can load it and use it for inference and to get the test set metrics.

In [None]:
if TEST_FINAL_MODE:
    final_id = "zu9wrq3b"
    wandb_api = wandb.Api()
    run = wandb_api.run(path=f"daniilgaltsev/idl-aiornot/{final_id}")
    artifact = wandb_api.artifact(f"daniilgaltsev/idl-aiornot/model-{final_id}:latest")
    model_path = artifact.checkout()

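    # The train/valid transforms are excluded from the saved hyperparameters,
    # so the restored module runs without any in-module transform.
    model = AIOrNotModule.load_from_checkpoint(str(Path(model_path) / "model.ckpt"))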
    datamodule = AIOrNotDataModule(PREPROCESSED, batch_size=run.config["batch_size"])

    with wandb.init(resume=True, id=final_id, project="idl-aiornot"):
        wandb_logger = WandbLogger()
        trainer = pl.Trainer(
            precision=16, benchmark=True,
            accelerator="gpu" if use_gpu else "tpu",
            devices=-1,
            logger=wandb_logger,
            log_every_n_steps=1
        )
        trainer.test(model=model, datamodule=datamodule)

|Metric|Baseline Model| Final Model|
|---|---|---|
|Train AUROC|**0.9999**|0.9960|
|Val. AUROC|**0.9982**|0.9977|
|Train Accuracy|**1.0000**|0.9688|
|Val. Accuracy|0.9753|**0.9780**|
|Train Precision|**1.0000**|0.9500|
|Val. Precision|0.9683|**0.9761**|
|Train Recall|**1.0000**|**1.0000**|
|Val. Recall|**0.9872**|0.9829|
|Train Loss|**0.0157**|0.0813|
|Val. Loss|**0.0558**|0.0584|

| Test Metric | Manual Benchmark | Final Model |
|--------|-----------|-------------|
|Accuracy|0.8400|**0.9296**|
| Recall |0.9545|**0.9990**|
|Precision|0.7500|**0.8882**|
|  AUROC |N/A|0.9962|
|  Loss  |N/A|0.1678|

The tables above show the performance of the final model compared to the baseline model and the manual benchmark.

Compared to the baseline model, the final model is better only in validation accuracy and precision. However, it does seem to overfit less, which is a good sign. Since we know that the baseline model is overfitting, we can expect its test set performance to be worse than its validation metrics suggest, although I didn't check that. Another thing to note is that the final model had some instability during training, so it may be possible to improve it further by adding a bit of regularization, using better weight initialization for the final layer, or using gradual unfreezing.
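The overfitting comparison can be made concrete by computing the gap between the training and validation metrics from the first table. The snippet below is only a rough sketch: the numbers are copied from that table, while the helper name `generalization_gaps` and the sign convention (positive means "better on the training set") are choices made here for illustration.

```python
# Train/validation pairs copied from the comparison table above.
metrics = {
    "baseline": {
        "auroc": (0.9999, 0.9982),
        "accuracy": (1.0000, 0.9753),
        "loss": (0.0157, 0.0558),
    },
    "final": {
        "auroc": (0.9960, 0.9977),
        "accuracy": (0.9688, 0.9780),
        "loss": (0.0813, 0.0584),
    },
}


def generalization_gaps(model_metrics):
    """Return per-metric gaps where a positive value favors the training set.

    For scores (higher is better) the gap is train - val;
    for the loss (lower is better) it is val - train.
    """
    gaps = {}
    for name, (train, val) in model_metrics.items():
        gap = val - train if name == "loss" else train - val
        gaps[name] = round(gap, 4)
    return gaps


gaps = {model: generalization_gaps(m) for model, m in metrics.items()}
for model, model_gaps in gaps.items():
    print(model, model_gaps)
```

With these inputs every baseline gap is positive and every final-model gap is negative, which is consistent with the observation that the final model overfits less.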

Compared to the manual benchmark, the final model performs much better. Both follow the same pattern: recall is high, meaning that generated images are reliably recognized as generated, but precision suffers because real images are often labeled as generated. This once again suggests that real images are harder to generalize on, and the next step would be to try to improve the model by using more real images or coming up with a better way to augment them.
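The shared error pattern is easier to see when precision and recall are turned into the two corresponding error rates. This is a small sketch that assumes "generated" is the positive class (as the discussion of precision implies); the function name `error_profile` is made up for this example, and the inputs come from the test table above.

```python
def error_profile(precision, recall):
    """Convert precision/recall for the 'generated' class into error rates."""
    return {
        # Share of images flagged as generated that were actually real.
        "false_discovery_rate": round(1 - precision, 4),
        # Share of truly generated images that were not flagged.
        "miss_rate": round(1 - recall, 4),
    }


# Test-set values from the table above.
manual = error_profile(precision=0.7500, recall=0.9545)
final = error_profile(precision=0.8882, recall=0.9990)

print("manual benchmark:", manual)
print("final model:", final)
```

In both cases the false discovery rate is much larger than the miss rate, so most mistakes come from real images being flagged as generated.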

# Discussions

In this project we tried to detect whether images are generated or real. From the experiments it can be concluded that this task is not difficult for either humans or deep learning models, especially when the priority is to detect that generated images are generated. The final model's performance is shown below.

|Test Metric|Value |
|-----------|------|
| Accuracy  |0.9296|
|  Recall   |0.9990|
| Precision |0.8882|
|   AUROC   |0.9962|

The table shows that the main type of error the model is making is labeling real images as generated, which also mirrors the results of the manual benchmark. This means that further work should be focused on improving the model's precision. It's also important to note that the final model experienced some instability during training, which might have negatively affected the results.

The main problem with the experiment procedure was that the reduced dataset might have been too small compared to the batch sizes used. This reduced the possible number of training steps, which might have skewed the results of hyperparameter tuning. This could have been fixed by manually selecting a batch size instead of using the maximum possible one.
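To illustrate how quickly the number of training steps shrinks with large batches, here is a minimal sketch. The sample count of 5000 and the batch sizes are illustrative assumptions rather than the exact values used in the sweeps, and the helper `optimizer_steps` assumes one optimizer update per batch (no gradient accumulation) and that the last partial batch is kept.

```python
import math


def optimizer_steps(num_samples, batch_size, epochs):
    """Number of optimizer updates when every batch triggers one update."""
    return math.ceil(num_samples / batch_size) * epochs


# Illustrative values only: the dataset size and batch sizes are assumptions.
steps_by_batch_size = {
    batch_size: optimizer_steps(num_samples=5000, batch_size=batch_size, epochs=2)
    for batch_size in (32, 128, 512)
}
print(steps_by_batch_size)
```

Under these assumptions, going from a batch size of 32 to 512 leaves only a small fraction of the updates, which gives hyperparameters such as the learning rate very little room to show their effect.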

In order to further improve performance, several steps are possible. First, more real images can be collected. This shouldn't be too difficult since there are a lot of publicly available images. The only concern is to make sure that they are not generated, which can be done by collecting only images from before a certain date (for example, images from 2019 and earlier). Secondly, training stability can be improved by using better weight initialization for the classification layer and adding gradual unfreezing while finetuning the model. Also, the root cause of the instability should be investigated by inspecting gradients and activation histograms throughout the training process. Finally, the errors that the model is making should be analyzed to see if there are any patterns in those mistakes and if they can be incorporated into the data augmentation process.
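As a rough idea of what gradual unfreezing could look like, the sketch below only computes which layer groups would be trainable at each epoch. It was not part of the experiments; the function `trainable_groups`, the group names and the one-group-per-epoch schedule are all hypothetical.

```python
def trainable_groups(layer_groups, epoch, epochs_per_stage=1):
    """Return the layer groups that should be trainable at a given epoch.

    `layer_groups` is ordered from input to output. The last group (the
    classification head) is trainable from the start, and one more group
    is unfrozen, moving towards the input, every `epochs_per_stage` epochs.
    """
    num_unfrozen = min(len(layer_groups), 1 + epoch // epochs_per_stage)
    return layer_groups[-num_unfrozen:]


# Hypothetical group names; a real model would map these to parameter groups.
groups = ["stem", "stage1", "stage2", "stage3", "head"]
schedule = {epoch: trainable_groups(groups, epoch) for epoch in range(6)}
for epoch, names in schedule.items():
    print(epoch, names)
```

In a PyTorch model this schedule would typically be applied at the start of each epoch by setting `requires_grad` on the parameters of each group, so that the pretrained layers near the input are the last ones to be updated.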

# Contributions and Takeaways

The project's goal was to train a model capable of differentiating between real and generated images. For that, a dataset of real and AI-generated images was used. A manual benchmark was performed to establish what kind of performance can be expected from an untrained person. The benchmark showed that the task is not difficult and that modern deep learning techniques should be capable of solving the problem at a decent level. The modelling itself was done using PyTorch Lightning and Weights & Biases. The data was transformed from parquet files into hdf5 files with preprocessed images and targets in NumPy arrays. Over the course of 4 hyperparameter sweeps, several model candidates were established. The final model was then trained and evaluated on the test set. It achieved almost perfect recall, but its precision still has a lot of room for improvement.