# CapCheck AI Detection - ViT Fine-Tuning

Fine-tune the `dima806/ai_vs_real_image_detection` model on newer AI-generated images.

**Goals:**
- Train on images generated by Flux, Midjourney v6, DALL-E 3, and Stable Diffusion 3 (SD3)
- Reduce false positives (real flagged as AI)
- Reduce false negatives (AI slipping through)

**Environment:** Google Colab (T4 GPU) or Local (MPS/CUDA)

## 1. Setup & Installation

In [None]:
# Install dependencies (run in Colab)
# !pip install -q transformers datasets evaluate accelerate huggingface_hub pillow scikit-learn matplotlib seaborn

In [None]:
import os
import torch
import numpy as np
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns

from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    TrainingArguments,
    Trainer,
)
from datasets import load_dataset, Dataset, DatasetDict
import evaluate
from sklearn.metrics import confusion_matrix, classification_report

def _detect_device():
    """Return ``(device, message)`` for the best available backend.

    Priority order is CUDA, then Apple MPS, then CPU; the MPS probe only runs
    when CUDA is unavailable.
    """
    if torch.cuda.is_available():
        return "cuda", f"Using CUDA: {torch.cuda.get_device_name(0)}"
    if torch.backends.mps.is_available():
        return "mps", "Using Apple MPS"
    return "cpu", "Using CPU (this will be slow)"


# Check available device
device, _device_message = _detect_device()
print(_device_message)

print(f"PyTorch version: {torch.__version__}")

## 2. Configuration

In [None]:
# === CONFIGURATION ===

# --- Models ---
# Pretrained checkpoint that fine-tuning starts from
BASE_MODEL = "dima806/ai_vs_real_image_detection"
# HuggingFace repo the fine-tuned model is pushed to
HF_MODEL_NAME = "capcheck/ai-image-detection"

# --- Data ---
# Root folder containing the train/ val/ test/ splits (see section 5).
# Colab: mount Google Drive and point this at your dataset there.
# Local: a relative or absolute path.
DATA_DIR = "./data"  # Change this!

# --- Training hyperparameters ---
BATCH_SIZE = 16  # Reduce to 8 if you get OOM errors (used as the per-device batch size)
LEARNING_RATE = 2e-5
NUM_EPOCHS = 3
WARMUP_RATIO = 0.1

# --- Output ---
OUTPUT_DIR = "./checkpoints"

# --- Labels ---
# On-disk folder names, indexed by numeric label: 0 = real, 1 = ai
LABELS = ["real", "ai"]
ID2LABEL = {0: "Real", 1: "Fake"}
# Inverse mapping of ID2LABEL
LABEL2ID = {name: idx for idx, name in ID2LABEL.items()}
# NOTE(review): confirm the base model's own id2label uses the same order
# (index 0 = real). With two labels on both sides the pretrained head is loaded
# as-is, so a reversed order would start fine-tuning from inverted predictions.

## 3. Mount Google Drive (Colab Only)

In [None]:
# Uncomment if using Google Colab and your data is in Drive
# from google.colab import drive
# drive.mount('/content/drive')
# DATA_DIR = "/content/drive/MyDrive/capcheck-training-data"

## 4. Load Base Model

In [None]:
print(f"Loading base model: {BASE_MODEL}")

# Processor holding the resize/normalization settings the base model expects
image_processor = AutoImageProcessor.from_pretrained(BASE_MODEL)

# Classifier with our label names attached to its config.
# ignore_mismatched_sizes=True lets loading proceed (re-initializing the head)
# if the checkpoint's classification head has a different number of labels.
_label_config = {"id2label": ID2LABEL, "label2id": LABEL2ID}
model = AutoModelForImageClassification.from_pretrained(
    BASE_MODEL,
    ignore_mismatched_sizes=True,
    **_label_config,
)

_num_classes = model.config.num_labels
print(f"Model loaded: {_num_classes} classes")
print(f"Parameters: {model.num_parameters():,}")

## 5. Load Dataset

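Expected structure (images may also sit one level deeper inside `ai/`, e.g. `ai/flux/` or `ai/midjourney/`; supported extensions: .jpg, .jpeg, .png, .webp):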
```
data/
├── train/
│   ├── ai/      # AI-generated images
│   └── real/    # Real photographs
├── val/
│   ├── ai/
│   └── real/
└── test/
    ├── ai/
    └── real/
```

In [None]:
# File extensions (lower-case) accepted as images.
_IMAGE_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png", ".webp"})


def _list_images(directory, include_subdirs=False):
    """Return sorted path strings of the image files in *directory*.

    Args:
        directory: ``Path`` to scan; a missing directory yields ``[]``.
        include_subdirs: Also scan the immediate (non-hidden) subdirectories.

    Hidden entries (names starting with ".") are skipped. This includes macOS
    "._name.jpg" metadata stubs, which carry an image suffix but are not images.
    Sorting makes the order independent of filesystem enumeration order.
    """
    if not directory.is_dir():
        return []
    entries = list(directory.iterdir())
    if include_subdirs:
        nested = [
            child
            for sub in entries
            if sub.is_dir() and not sub.name.startswith(".")
            for child in sub.iterdir()
        ]
        entries.extend(nested)
    return sorted(
        str(p)
        for p in entries
        if p.is_file()
        and not p.name.startswith(".")
        and p.suffix.lower() in _IMAGE_EXTENSIONS
    )


def load_image_folder_dataset(data_dir):
    """Build a DatasetDict of image paths and labels from an ImageFolder-style tree.

    For each of the ``train``, ``val`` and ``test`` splits under *data_dir*:

    - image files directly inside ``<split>/real/`` get label 0;
    - image files directly inside ``<split>/ai/`` and inside its immediate
      subdirectories (e.g. ``ai/flux/``, ``ai/midjourney/``) get label 1.

    Only .jpg/.jpeg/.png/.webp files (case-insensitive) are kept. Paths are
    sorted, so the dataset order is reproducible across machines.

    Args:
        data_dir: Root directory (str or Path) containing the split folders.

    Returns:
        DatasetDict with one ``Dataset`` (columns ``image_path`` and ``label``)
        per split that exists and contains at least one image. Missing or empty
        splits are reported and skipped.
    """
    data_dir = Path(data_dir)
    splits = {}

    for split in ("train", "val", "test"):
        split_dir = data_dir / split
        if not split_dir.exists():
            print(f"Warning: {split_dir} not found, skipping")
            continue

        real_paths = _list_images(split_dir / "real")
        ai_paths = _list_images(split_dir / "ai", include_subdirs=True)

        if not real_paths and not ai_paths:
            print(f"Warning: {split_dir} contains no images, skipping")
            continue

        images = real_paths + ai_paths
        labels = [0] * len(real_paths) + [1] * len(ai_paths)  # Real = 0, AI = 1

        splits[split] = Dataset.from_dict({
            "image_path": images,
            "label": labels,
        })
        print(f"{split}: {len(images)} images (Real: {len(real_paths)}, AI: {len(ai_paths)})")

    return DatasetDict(splits)

In [None]:
# Index image paths + labels under DATA_DIR (pixels are not read yet)
print(f"Loading dataset from: {DATA_DIR}")
dataset = load_image_folder_dataset(DATA_DIR)
_split_names = list(dataset)
print(f"\nLoaded splits: {_split_names}")

## 6. Preprocessing

In [None]:
def preprocess_function(examples):
    """Eagerly load and preprocess a batch of images (for ``Dataset.map(batched=True)``).

    Args:
        examples: Batch dict with parallel lists under ``"image_path"`` and ``"label"``.

    Returns:
        Dict with ``"pixel_values"`` (``image_processor`` output for the images
        that loaded) and ``"label"`` (their labels, in the same order). Images
        that fail to open are reported and dropped, so the output can be shorter
        than the input batch. If every image fails, both entries are empty lists.

    NOTE: because rows can be dropped, ``Dataset.map`` must be called with
    ``remove_columns=`` for the input columns. Kept columns must match the
    new row count.
    """
    images = []
    labels = []

    for path, label in zip(examples["image_path"], examples["label"]):
        try:
            # `with` guarantees the file handle is closed; convert() returns an
            # in-memory copy that stays valid after the file is closed.
            with Image.open(path) as img:
                images.append(img.convert("RGB"))
        except Exception as e:  # best-effort: skip any unreadable/corrupt file
            print(f"Error loading {path}: {e}")
            continue
        labels.append(label)

    # Guard the all-failed batch: don't hand image_processor an empty list.
    if not images:
        return {"pixel_values": [], "label": []}

    inputs = image_processor(images, return_tensors="pt")
    return {
        "pixel_values": inputs["pixel_values"],
        "label": labels,
    }


def transform_for_training(example):
    """Lazily load and preprocess images; meant for ``Dataset.set_transform``.

    ``set_transform`` always calls its function with a *batch*: a dict of lists
    such as ``{"image_path": [...], "label": [...]}``. This holds even when a
    single row is indexed, since ``ds[i]`` becomes a batch of one and is
    unbatched afterwards. Despite its name, *example* is therefore that batch
    dict. The previous single-example implementation received lists here,
    failed on every image and returned ``None``, which datasets cannot unbatch.

    Returns:
        Batch dict with ``"pixel_values"`` (``image_processor`` tensor output,
        one leading row per loaded image) and ``"label"`` (matching list of
        labels). Unreadable images are reported and dropped, so a batch may
        come back with fewer rows than requested. If none load, both entries
        are empty lists.
    """
    images = []
    labels = []

    for path, label in zip(example["image_path"], example["label"]):
        try:
            # `with` closes the file; convert() returns an in-memory copy.
            with Image.open(path) as img:
                images.append(img.convert("RGB"))
        except Exception as e:  # best-effort: skip corrupt/unreadable files
            print(f"Error loading {path}: {e}")
            continue
        labels.append(label)

    # Every image in the batch failed: return an empty (but well-formed) batch
    # instead of calling image_processor with an empty list.
    if not images:
        return {"pixel_values": [], "label": []}

    inputs = image_processor(images, return_tensors="pt")
    return {
        "pixel_values": inputs["pixel_values"],
        "label": labels,
    }

In [None]:
# Attach the lazy loader to every split (set_transform works in place):
# pixels are decoded only when rows are accessed, not all up front.
for split_ds in dataset.values():
    split_ds.set_transform(transform_for_training)

print("Transforms applied. Images will be loaded on-demand.")

## 7. Define Metrics

In [None]:
accuracy_metric = evaluate.load("accuracy")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
f1_metric = evaluate.load("f1")


def compute_metrics(eval_pred):
    """Turn Trainer eval output into accuracy, precision, recall and F1.

    Args:
        eval_pred: ``(logits, labels)`` pair supplied by the Trainer.

    Returns:
        Dict with keys ``accuracy``, ``precision``, ``recall``, ``f1``.
        The last three use ``average="binary"``, i.e. they score the positive
        class (label 1 = AI-generated).
    """
    logits, references = eval_pred
    shared = {"predictions": np.argmax(logits, axis=-1), "references": references}

    scores = {"accuracy": accuracy_metric.compute(**shared)["accuracy"]}
    binary_metrics = (
        ("precision", precision_metric),
        ("recall", recall_metric),
        ("f1", f1_metric),
    )
    for metric_name, metric in binary_metrics:
        scores[metric_name] = metric.compute(average="binary", **shared)[metric_name]
    return scores

## 8. Training Setup

In [None]:
# Settings grouped by purpose; values come from the configuration cell.
_schedule = dict(
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    warmup_ratio=WARMUP_RATIO,
)
_optimization = dict(
    fp16=(device == "cuda"),  # Mixed precision only on CUDA
    gradient_accumulation_steps=2,  # Effective batch per device = 2 x BATCH_SIZE
)
_eval_and_checkpointing = dict(
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
)
_logging = dict(
    logging_dir=f"{OUTPUT_DIR}/logs",
    logging_steps=50,
    report_to="none",  # Set to "wandb" for W&B tracking
)

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    **_schedule,
    **_optimization,
    **_eval_and_checkpointing,
    **_logging,
    # Must stay False: otherwise the Trainer drops "image_path", which the
    # lazy set_transform loader needs.
    remove_unused_columns=False,
    push_to_hub=False,  # The model is pushed manually after training
)

print("Training arguments configured.")

In [None]:
# Custom data collator to handle our format
def collate_fn(batch):
    """Stack per-example dicts into model-ready tensors.

    Args:
        batch: List of examples, each a dict with ``"pixel_values"`` (tensors of
            identical shape across the batch) and an integer ``"label"``.
            ``None`` entries, and entries whose ``"pixel_values"`` is ``None``
            (images that failed to load), are skipped.

    Returns:
        Dict with ``"pixel_values"`` (stacked tensor) and ``"labels"`` (int64
        tensor). The key is renamed to ``labels`` because the model's forward
        takes its targets under that name.

    Raises:
        ValueError: if no usable example remains after filtering. This replaces
            an opaque ``torch.stack`` error on an empty list.
    """
    usable = [
        b for b in batch
        if b is not None and b["pixel_values"] is not None
    ]
    if not usable:
        raise ValueError("collate_fn received a batch with no loadable images")

    pixel_values = torch.stack([b["pixel_values"] for b in usable])
    labels = torch.tensor([b["label"] for b in usable], dtype=torch.long)

    return {
        "pixel_values": pixel_values,
        "labels": labels,
    }

In [None]:
# Create trainer.
# Both splits are required: "train" for fitting, and "val" because the
# training arguments evaluate every epoch and reload the best checkpoint by F1.
# Fail early with a clear message instead of handing the Trainer None.
_missing_splits = [name for name in ("train", "val") if name not in dataset]
if _missing_splits:
    raise ValueError(
        f"Dataset under {DATA_DIR} is missing required split(s) {_missing_splits}; "
        f"found {list(dataset.keys())}"
    )

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["val"],
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
)

print("Trainer initialized.")

## 9. Train!

In [None]:
# Start training.
# Report the effective batch size: gradient accumulation multiplies the
# per-device batch before each optimizer step.
_accum_steps = training_args.gradient_accumulation_steps
print("Starting training...")
print(f"  Epochs: {NUM_EPOCHS}")
print(
    f"  Batch size: {BATCH_SIZE} per device x {_accum_steps} accumulation steps "
    f"= {BATCH_SIZE * _accum_steps} effective per device"
)
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Device: {device}")
print("-" * 50)

train_result = trainer.train()

print("-" * 50)
print("Training complete!")
# train() writes per-epoch checkpoints under OUTPUT_DIR and (load_best_model_at_end)
# reloads the best one into trainer.model; the final model is saved in section 11.
print(f"Checkpoints saved under: {OUTPUT_DIR}")
print(f"Best checkpoint: {trainer.state.best_model_checkpoint}")

## 10. Evaluation

In [None]:
def _print_eval_metrics(results):
    """Print every ``eval_*`` entry of a ``Trainer.evaluate()`` result, prefix removed."""
    for key, value in results.items():
        if not key.startswith("eval_"):
            continue
        print(f"{key.replace('eval_', ''):15}: {value:.4f}")


# Prefer the held-out test split; fall back to the validation split otherwise.
if "test" not in dataset:
    print("No test set found. Showing validation results.")
    val_results = trainer.evaluate()
    _print_eval_metrics(val_results)
else:
    print("Evaluating on test set...")
    test_results = trainer.evaluate(dataset["test"])

    _banner = "=" * 50
    print("\n" + _banner)
    print("TEST RESULTS")
    print(_banner)
    _print_eval_metrics(test_results)

In [None]:
# Generate confusion matrix
def plot_confusion_matrix(trainer, dataset_split):
    """Plot a Real/AI confusion matrix and print a classification report.

    Args:
        trainer: Trainer whose ``predict`` returns logits and label ids.
        dataset_split: Dataset split to run predictions on.

    Both the matrix and the report use the fixed label set ``[0, 1]``
    (Real, AI). A split that contains, or a model that predicts, only one
    class therefore still yields a 2x2 matrix matching the tick labels. It no
    longer produces a mis-labelled 1x1 matrix or a classification_report
    error about the number of target names.
    """
    class_ids = [0, 1]
    class_names = ["Real", "AI"]

    predictions = trainer.predict(dataset_split)
    preds = np.argmax(predictions.predictions, axis=-1)
    labels = predictions.label_ids

    cm = confusion_matrix(labels, preds, labels=class_ids)

    plt.figure(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
    )
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.title("Confusion Matrix")
    plt.tight_layout()
    plt.show()

    print("\nClassification Report:")
    # zero_division=0: report 0 (without a warning) for a class with no predictions
    print(classification_report(
        labels,
        preds,
        labels=class_ids,
        target_names=class_names,
        zero_division=0,
    ))

# Plot for test or validation set
eval_split = "test" if "test" in dataset else "val"
if eval_split in dataset:
    plot_confusion_matrix(trainer, dataset[eval_split])

## 11. Save Model Locally

In [None]:
# Folder receiving the final artifacts; loadable later via from_pretrained(SAVE_PATH).
SAVE_PATH = f"{OUTPUT_DIR}/capcheck-ai-detection-final"

# Model weights + config: the best model, reloaded at the end of training.
trainer.save_model(SAVE_PATH)
# Preprocessing config: saved separately because no processor was given to the Trainer.
image_processor.save_pretrained(SAVE_PATH)

_saved_files = os.listdir(SAVE_PATH)
print(f"Model saved to: {SAVE_PATH}")
print(f"Contents: {_saved_files}")

## 12. Push to HuggingFace Hub

In [None]:
# Login to HuggingFace
from huggingface_hub import login, HfApi

# Option 1: provide a token via the HF_TOKEN environment variable (e.g. a Colab
#           secret, or os.environ["HF_TOKEN"] = "..." — never commit real tokens).
#           When it is set, it is used directly and no prompt appears.
# Option 2: leave HF_TOKEN unset to be prompted for a token interactively.
# `or None` treats an empty HF_TOKEN as unset instead of trying to log in with "".
login(token=os.environ.get("HF_TOKEN") or None)

In [None]:
# Push model to HuggingFace Hub
print(f"Pushing model to: {HF_MODEL_NAME}")

# Push the model object directly rather than calling trainer.push_to_hub().
# Trainer.push_to_hub has no repo_id parameter: extra kwargs go to its
# model-card builder, which rejects repo_id. It also targets a repo named
# after OUTPUT_DIR, because TrainingArguments.hub_model_id is unset.
# After load_best_model_at_end, trainer.model holds the best checkpoint.
trainer.model.push_to_hub(
    HF_MODEL_NAME,
    commit_message="Fine-tuned ViT for AI image detection",
)

# Same repo for the image processor, so pipeline() can load model + preprocessing together.
image_processor.push_to_hub(
    HF_MODEL_NAME,
    commit_message="Add image processor",
)

print(f"\nModel published to: https://huggingface.co/{HF_MODEL_NAME}")

In [None]:
# Create and push model card.
# The YAML metadata block must begin on the very first line of README.md.
# A leading blank line makes the Hub ignore it and render it as plain text,
# so the string starts directly with "---".
MODEL_CARD = f"""---
license: apache-2.0
base_model: {BASE_MODEL}
tags:
- image-classification
- vision
- ai-detection
- deepfake-detection
datasets:
- custom
metrics:
- accuracy
- f1
- precision
- recall
---

# CapCheck AI Image Detection

Fine-tuned Vision Transformer (ViT) for detecting AI-generated images.

## Model Description

This model is fine-tuned from `{BASE_MODEL}` on a custom dataset of modern AI-generated images including:
- Flux
- Midjourney v6
- DALL-E 3
- Stable Diffusion 3

## Training Details

- **Base Model**: {BASE_MODEL}
- **Epochs**: {NUM_EPOCHS}
- **Batch Size**: {BATCH_SIZE} per device (gradient accumulation steps: {training_args.gradient_accumulation_steps})
- **Learning Rate**: {LEARNING_RATE}

## Usage

```python
from transformers import pipeline

detector = pipeline("image-classification", model="{HF_MODEL_NAME}")
result = detector("path/to/image.jpg")
print(result)
```

## Labels

- `Real`: Authentic photograph
- `Fake`: AI-generated image

## Limitations

- Performance may vary on image types not seen during training
- Heavily compressed images may reduce accuracy
- New AI generators not in training data may evade detection

## License

Apache 2.0
"""

# Save model card next to the locally saved model files
with open(f"{SAVE_PATH}/README.md", "w", encoding="utf-8") as f:
    f.write(MODEL_CARD)

# Push model card to hub (the repo already exists from the model push above)
api = HfApi()
api.upload_file(
    path_or_fileobj=f"{SAVE_PATH}/README.md",
    path_in_repo="README.md",
    repo_id=HF_MODEL_NAME,
    commit_message="Add model card",
)

print("Model card uploaded!")

## 13. Test the Published Model

In [None]:
# Smoke-test the published model by loading it back from the Hub
from transformers import pipeline

print(f"Testing model from: {HF_MODEL_NAME}")

detector = pipeline(task="image-classification", model=HF_MODEL_NAME)

# To classify a sample image, point this at a local file and uncomment:
# test_image = "path/to/test/image.jpg"
# result = detector(test_image)
# print(f"Result: {result}")

## Next Steps

After publishing to HuggingFace:

1. Update `ml/services/ai-image-detection/predict.py`:
   ```python
   MODEL_REGISTRY = {
       "v1.0.0": "dima806/ai_vs_real_image_detection",
       "v1.1.0": "capcheck/ai-image-detection",  # New!
   }
   ```

2. Update `ml/services/ai-image-detection/cog.yaml` to pre-download the new model

3. Push to Replicate:
   ```bash
   cd ml/services/ai-image-detection
   cog push r8.im/your-username/ai-image-detection
   ```

4. Update backend with new Replicate version hash