# GANFingerprint Deepfake Detection

This notebook provides a complete walkthrough of training, evaluating, and using the GANFingerprint deepfake detection model. The model is designed to detect GAN-generated fake images by analyzing subtle fingerprint patterns in both spatial and frequency domains.

## What are GAN Fingerprints?
GAN fingerprints are distinctive patterns or traces that are unintentionally embedded in images generated by Generative Adversarial Networks (GANs). The name comes from an analogy with human fingerprints: just as people unintentionally leave fingerprints on the items they touch, which can be used to trace their identity, GANs leave traces in the images they produce. And like human fingerprints, these traces are specific to the GAN that generated the image, for two reasons:

1. Each GAN architecture has its own unique way of generating images based on its specific design, loss functions, and optimization methods.

2. Even GANs with identical architectures but different training datasets, random initializations, or hyperparameters will produce images with subtly different characteristics.
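
To make the frequency-domain idea concrete, here is a toy illustration in NumPy. It is not FingerprintNet's actual feature extractor. It only shows why frequency analysis helps. Upsampling layers in GANs, such as strided transposed convolutions, can leave faint checkerboard patterns that are nearly invisible pixel by pixel but concentrate into a single sharp peak at the highest frequency of the 2D Fourier transform. The blurred-noise stand-in for a photo and the 0.05 artifact amplitude are arbitrary choices for the demo.

```python
import numpy as np


def nyquist_magnitude(img: np.ndarray) -> float:
    """Magnitude of the 2D DFT coefficient at the highest frequency on both axes."""
    h, w = img.shape
    return float(np.abs(np.fft.fft2(img))[h // 2, w // 2])


def box_blur(img: np.ndarray, k: int = 5) -> np.ndarray:
    """Separable box blur: turns white noise into a smooth, photo-like signal."""
    kernel = np.ones(k) / k
    img = np.apply_along_axis(lambda row: np.convolve(row, kernel, mode="same"), 1, img)
    return np.apply_along_axis(lambda col: np.convolve(col, kernel, mode="same"), 0, img)


rng = np.random.default_rng(0)
n = 128
clean = box_blur(rng.standard_normal((n, n)))

# Add a faint checkerboard, the kind of periodic trace that strided
# transposed-convolution upsampling can leave in generated images
yy, xx = np.indices((n, n))
fingerprinted = clean + 0.05 * (-1.0) ** (yy + xx)

print(f"max pixel change: {np.abs(fingerprinted - clean).max():.3f}")
print(f"Nyquist magnitude, clean:         {nyquist_magnitude(clean):.1f}")
print(f"Nyquist magnitude, fingerprinted: {nyquist_magnitude(fingerprinted):.1f}")
```

The fingerprinted image differs from the clean one by at most 0.05 per pixel, yet its magnitude at the highest frequency is well over an order of magnitude larger.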

## Objective of the project

As GAN-generated images become more advanced, existing detection methods, such as looking for distortions in facial features and image details, may struggle to identify them. Through this project, we aim to build a deepfake detection model that identifies deepfake images reliably, no matter how realistic they look to the human eye. By training a model to distinguish deepfake images from real ones through their GAN fingerprint profiles, we hope to capture details that are invisible to the human eye.

## 1. Setup and Dependencies

#### First, let's install all necessary dependencies so that the model runs properly.

In [None]:
# Create Virtual Environment for model
!python3 -m venv myenv

#### Next: Activate the Virtual Environment (manually in terminal)

**Navigate to the deepfake-detection-GANFingerprint folder, making it the root folder in your IDE.**

Open a terminal in the notebook's folder, then run:

- **On macOS/Linux/WSL**:
  ```bash
  source myenv/bin/activate
  ```

- **On Windows**:
  ```cmd
  myenv\Scripts\activate
  ```

Once activated, your terminal prompt will show `(myenv)`, meaning you're using the virtual environment.
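
Activating the environment in a terminal does not change which interpreter the notebook kernel uses, and ```%pip``` installs packages into whatever environment the kernel is running in. A quick check, relying only on the standard ```sys.prefix``` and ```sys.base_prefix``` convention for virtual environments (PEP 405); ```in_virtualenv``` is a helper name introduced here, not part of the project:

```python
import sys


def in_virtualenv() -> bool:
    """True when the running interpreter belongs to a venv (PEP 405)."""
    # Inside a venv, sys.prefix points at the venv while sys.base_prefix
    # still points at the base Python installation.
    return sys.prefix != sys.base_prefix


print(f"Kernel interpreter: {sys.executable}")
print(f"Running inside a virtual environment: {in_virtualenv()}")
```

If this prints ```False```, select the ```myenv``` interpreter as the notebook kernel in your IDE before installing the dependencies.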


#### Next, install all necessary dependencies using requirements.txt

In [None]:
# Install core dependencies
%pip install -r requirements.txt

# GPU build of PyTorch for CUDA 11.8 (choose a CUDA build that your NVIDIA driver supports)
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# For CPU-only installation (if no GPU is available)
# %pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

Then, let's verify that we're running from the correct directory and all required files are present.

In [None]:
import os
import sys

# Add the current directory to the Python path
sys.path.append('.')

# Check if all required files exist
files_needed = [
    "config.py",
    "data_loader.py",
    "models/__init__.py",
    "models/fingerprint_net.py", 
    "models/layers.py",
    "train.py",
    "evaluate.py",
    "inference.py",
    "utils/metrics.py",
    "utils/experiment.py",
    "utils/visualization.py",
    "utils/reproducibility.py",
    "utils/augmentations.py",
    "utils/gradcam.py"
]

missing = [f for f in files_needed if not os.path.exists(f)]
if missing:
    print("❌ Missing required files:")
    for f in missing:
        print(f"  - {f}")
    print("\nPlease run this notebook from the project root directory")
else:
    print("✅ All required files found!")

### Project directory layout

If set up correctly, the project folder should have the following layout:

```
deepfake_detector/
├── config.py                 # Configuration parameters
├── data_loader.py            # Dataset and dataloader implementation
├── models/
│   ├── __init__.py           # Module initialization
│   ├── fingerprint_net.py    # GANFingerprint model architecture
│   ├── layers.py             # Custom layers and blocks
├── train.py                  # Training script
├── evaluate.py               # Evaluation script
├── inference.py              # Inference on new images
├── utils/
│   ├── __init__.py           # Utilities module initialization
│   ├── reproducibility.py    # Random seed and reproducibility utilities
│   ├── visualization.py      # Plotting and visualization tools
│   ├── metrics.py            # Performance metrics calculation
│   ├── augmentations.py      # Advanced augmentation techniques
│   ├── experiment.py         # Logging of information when training model
│   ├── gradcam.py            # Grad-CAM visualization of inference results
├── checkpoints/              # Directory for saved model checkpoints
├── logs/                     # TensorBoard logs and training records
```

## 2. Check PyTorch and CUDA status

In [None]:
# Check PyTorch version and CUDA availability
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 3. Import the necessary packages

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms

# Import custom modules
import config
from data_loader import get_dataset_stats
from models import FingerprintNet



## 4. Display Current Configuration

The model's hyperparameters are configured in ```config.py```. The values printed below are the ones that gave the best results.

In [None]:
# Display current configuration
print("Current Configuration:")
print(f"DATA_ROOT: {config.DATA_ROOT}")
print(f"INPUT_SIZE: {config.INPUT_SIZE}")
print(f"BACKBONE: {config.BACKBONE}")
print(f"BATCH_SIZE: {config.BATCH_SIZE}")
print(f"EARLY_STOPPING_PATIENCE: {config.EARLY_STOPPING_PATIENCE}")
print(f"LEARNING_RATE: {config.LEARNING_RATE}")
print(f"WEIGHT_DECAY: {config.WEIGHT_DECAY}")
print(f"NUM_EPOCHS: {config.NUM_EPOCHS}")
print(f"NUM_WORKERS: {config.NUM_WORKERS}")
print(f"DROPOUT_RATE: {config.DROPOUT_RATE}")
print(f"DEVICE: {config.DEVICE}")
print(f"USE_AMP: {config.USE_AMP}")
print(f"CHECKPOINT_DIR: {config.CHECKPOINT_DIR}")
print(f"LOG_DIR: {config.LOG_DIR}")
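
The cell above lists each setting by hand, so new entries added to ```config.py``` are easy to miss. Assuming ```config.py``` follows the common convention of UPPER_CASE names for settings (as every name printed above does), a small hypothetical helper can collect them automatically. In this sketch a ```SimpleNamespace``` stands in for the real ```config``` module so that it runs on its own:

```python
from types import SimpleNamespace


def config_settings(cfg) -> dict:
    """Collect every UPPER_CASE attribute of a config module or object, sorted by name."""
    return {name: getattr(cfg, name) for name in sorted(vars(cfg)) if name.isupper()}


# Stand-in for `import config`; these values are made up for the demo
demo_config = SimpleNamespace(BATCH_SIZE=32, LEARNING_RATE=1e-4, _private=1, helper="x")

for name, value in config_settings(demo_config).items():
    print(f"{name}: {value}")
```

With the real module, call ```config_settings(config)```; lowercase names such as imported modules and helper functions are skipped.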

## 5. Check Dataset Structure

#### The dataset we will be using is the 'deepfake and real images' dataset by Manjil Karki.

Link: https://www.kaggle.com/datasets/manjilkarki/deepfake-and-real-images

Download the dataset and place it in the root directory in a folder named 'data'. The cell below will help you check if your directory is configured correctly.

In [None]:
# Check dataset structure
def check_dataset_structure():
    paths = [
        config.TRAIN_REAL_DIR,
        config.TRAIN_FAKE_DIR,
        config.VAL_REAL_DIR,
        config.VAL_FAKE_DIR,
        config.TEST_REAL_DIR,
        config.TEST_FAKE_DIR
    ]
    
    for path in paths:
        if not os.path.exists(path):
            print(f"❌ {path} does not exist!")
        else:
            print(f"✅ {path} exists with {len(os.listdir(path))} images")

check_dataset_structure()
get_dataset_stats()

## 6. Display Sample Images

In [None]:
# Display some sample images from the dataset
def show_samples(real_dir, fake_dir, n=5):
    transform = transforms.Compose([
        transforms.Resize(config.INPUT_SIZE),
        transforms.CenterCrop(config.INPUT_SIZE),
        transforms.ToTensor()
    ])
    
    # Check if directories exist
    if not os.path.exists(real_dir) or not os.path.exists(fake_dir):
        print(f"Error: One or more directories do not exist:\n{real_dir}\n{fake_dir}")
        return
    
    # Get image lists
    real_files = os.listdir(real_dir)
    fake_files = os.listdir(fake_dir)
    
    if not real_files or not fake_files:
        print("Error: One or more directories are empty")
        return
    
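    # Sort for a deterministic selection, then take the first n of each
    real_images = [os.path.join(real_dir, f) for f in sorted(real_files)[:n]]
    fake_images = [os.path.join(fake_dir, f) for f in sorted(fake_files)[:n]]
    
    plt.figure(figsize=(15, 6))
    # Top row: real images; bottom row: fake images (labels stay correct even if
    # a folder holds fewer than n images)
    for row, (label, paths) in enumerate([("Real", real_images), ("Fake", fake_images)]):
        for col, img_path in enumerate(paths):
            img = Image.open(img_path).convert('RGB')
            img_tensor = transform(img)
            
            plt.subplot(2, n, row * n + col + 1)
            plt.imshow(img_tensor.permute(1, 2, 0))
            plt.title(label)
            plt.axis('off')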
    plt.tight_layout()
    plt.show()

# Show samples from training set
try:
    show_samples(config.TRAIN_REAL_DIR, config.TRAIN_FAKE_DIR)
except Exception as e:
    print(f"Error displaying samples: {e}")

## 7. Initialize the model and review its trainable parameters

#### Running the cell below will show the architecture of the model and the number of trainable parameters

In [None]:
# Initialize the model
model = FingerprintNet(backbone=config.BACKBONE)
model = model.to(config.DEVICE)

# Count model parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Print model architecture summary
print(model)
print(f"Total trainable parameters: {count_parameters(model):,}")

## 8. Training the model

#### Run this cell to train the model.

This cell can also resume training from a checkpoint saved in the checkpoints directory. To do so:

Comment out ```best_checkpoint = train_model()```

and uncomment ```best_checkpoint = train_model(resume_checkpoint="checkpoints/ganfingerprint_.pth")```, replacing the path with the relative path of your last saved checkpoint file.
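
If you are unsure which file to pass as ```resume_checkpoint```, a small hypothetical helper can pick the most recently modified ```.pth``` file in the checkpoints folder. This sketch assumes checkpoints are saved as ```.pth``` files directly inside that folder, as the paths above suggest; it is not part of ```train.py```:

```python
import os
import tempfile
from pathlib import Path
from typing import Optional


def latest_checkpoint(checkpoint_dir="checkpoints", pattern="*.pth") -> Optional[str]:
    """Return the most recently modified file matching `pattern`, or None."""
    directory = Path(checkpoint_dir)
    if not directory.is_dir():
        return None
    candidates = [p for p in directory.glob(pattern) if p.is_file()]
    if not candidates:
        return None
    # as_posix() gives forward slashes, which work on Windows, macOS and Linux
    return max(candidates, key=lambda p: p.stat().st_mtime).as_posix()


# Demo on a throwaway folder with two hypothetical checkpoint files
with tempfile.TemporaryDirectory() as tmp:
    for mtime, name in [(1_000_000, "old.pth"), (2_000_000, "new.pth")]:
        path = Path(tmp, name)
        path.touch()
        os.utime(path, (mtime, mtime))  # set access and modification times
    demo_latest = Path(latest_checkpoint(tmp)).name

print(demo_latest)  # new.pth
```

For example, ```train_model(resume_checkpoint=latest_checkpoint())```. Depending on how ```train.py``` saves files, the newest file may be ```ganfingerprint_best.pth``` rather than the last epoch's checkpoint, so check the returned path before resuming.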

In [None]:
from train import train as train_function

# Wrapper for the training function
class Args:
    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

def train_model(data_root=config.DATA_ROOT, 
                batch_size=config.BATCH_SIZE,
                num_workers=config.NUM_WORKERS, 
                lr=config.LEARNING_RATE, 
                weight_decay=config.WEIGHT_DECAY,
                dropout_rate=config.DROPOUT_RATE,
                epochs=config.NUM_EPOCHS, 
                early_stopping_patience=config.EARLY_STOPPING_PATIENCE,
                backbone=config.BACKBONE,
                no_amp=not config.USE_AMP, 
                resume_checkpoint=None):
    
    # Override config values if needed
    config.DATA_ROOT = data_root
    config.BATCH_SIZE = batch_size
    config.NUM_WORKERS= num_workers
    config.LEARNING_RATE = lr
    config.WEIGHT_DECAY = weight_decay
    config.DROPOUT_RATE = dropout_rate
    config.NUM_EPOCHS = epochs
    config.EARLY_STOPPING_PATIENCE = early_stopping_patience
    config.BACKBONE = backbone
    config.USE_AMP = not no_amp
    
    # Create args object
    args = Args(
        data_root=data_root,
        batch_size=batch_size,
        lr=lr,
        epochs=epochs,
        backbone=backbone,
        no_amp=no_amp,
        resume_checkpoint=resume_checkpoint
    )
    
    # Call the training function
    train_function(args)
    
    # Return the path to the best checkpoint
    return os.path.join(config.CHECKPOINT_DIR, "ganfingerprint_best.pth")

# Train the model 
best_checkpoint = train_model()

# To resume training from a checkpoint (uncomment to run):
# best_checkpoint = train_model(resume_checkpoint="checkpoints/ganfingerprint_.pth")

## 9. Evaluating the trained model

#### This cell lets you evaluate the model you just trained, showing you various statistics and visualizations.

To do so, replace the argument of ```evaluate_model("checkpoints/ganfingerprint_best.pth")``` with the relative path of your checkpoint file. Forward slashes work on Windows, macOS and Linux alike.

In [None]:
from evaluate import evaluate as evaluate_function

# Wrapper for the evaluation function
def evaluate_model(checkpoint_path, output_dir="eval_results"):
    """
    Evaluate the model on the test set.
    
    Args:
        checkpoint_path: Path to the model checkpoint
        output_dir: Directory to save evaluation results
    """
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    
    # Call the evaluation function
    evaluate_function(checkpoint_path, output_dir)
        
    # Display the generated plots (skipping any that were not produced)
    image_paths = [
        os.path.join(output_dir, "confusion_matrix.png"),
        os.path.join(output_dir, "roc_curve.png"),
        os.path.join(output_dir, "precision_recall_curve.png")
    ]
    for path in image_paths:
        if os.path.exists(path):
            plt.figure(figsize=(8, 6))
            plt.imshow(Image.open(path))
            plt.title(os.path.basename(path))
            plt.axis('off')
            plt.show()
        else:
            print(f"Not found: {path}")

# Evaluate the model (insert the relative path to your trained checkpoint in the checkpoints folder)
evaluate_model("checkpoints/ganfingerprint_best.pth")
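
For reference, the headline statistics in an evaluation report follow directly from the confusion-matrix counts. The sketch below treats "fake" as the positive class, which is an assumption: ```evaluate.py``` may use the opposite convention, in which case precision and recall describe the real class instead. ```binary_metrics``` is a hypothetical helper, and the counts in the example are made up.

```python
def binary_metrics(tp: int, fp: int, fn: int, tn: int) -> dict:
    """Accuracy, precision, recall and F1 from confusion-matrix counts.

    A "positive" here is an image labelled fake (an assumption for this sketch).
    Ratios with a zero denominator are reported as 0.0.
    """
    total = tp + fp + fn + tn
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {
        "accuracy": (tp + tn) / total if total else 0.0,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


# Made-up counts: 90 fakes caught, 10 fakes missed, 5 real images flagged as fake
print(binary_metrics(tp=90, fp=5, fn=10, tn=95))
```

Precision answers "of the images flagged as fake, how many really were?", while recall answers "of the fake images, how many did the model catch?".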

## 10. Testing the model with actual images

After training and evaluating the model, we can test its capabilities by having it classify images from outside the dataset. Let's see if it classifies them correctly!

If the image filenames include their ground-truth labels (e.g. ```true_image```), the inference functions also report statistics on model performance, such as precision and recall for batch inference.

Batch inference saves its prediction outputs to the output directory you pass in, creating it if needed (```inference_results``` by default).

The inference functions can also produce ```Grad-CAM``` heatmaps, which highlight the parts of an image the model relies on to make its prediction.
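
Grad-CAM follows a simple recipe: weight each channel of a convolutional feature map by the spatial average of the gradient of the model's score with respect to that channel, sum the weighted channels, and keep only the positive part. Below is a NumPy sketch of that core step, assuming the activations and gradients for one image have already been captured (for example with forward and backward hooks on the target layer). It is the standard Grad-CAM formulation, not a copy of ```utils/gradcam.py```, which may differ in details such as upsampling the map to the image size and how it is normalized. ```gradcam_map``` is a hypothetical name.

```python
import numpy as np


def gradcam_map(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Grad-CAM heatmap from one image's feature maps.

    activations, gradients: arrays of shape (channels, height, width), where
    gradients = d(score) / d(activations) for the score being explained.
    Returns a (height, width) map scaled to [0, 1].
    """
    # One weight per channel: global-average-pool the gradients
    weights = gradients.mean(axis=(1, 2))
    # Weighted sum over channels, then ReLU to keep regions that raise the score
    cam = np.maximum(np.tensordot(weights, activations, axes=1), 0.0)
    # Scale to [0, 1] for display; an all-zero map stays all zero
    peak = cam.max()
    return cam / peak if peak > 0 else cam


rng = np.random.default_rng(0)
acts = rng.random((8, 7, 7))  # e.g. 8 channels on a 7x7 spatial grid
grads = rng.standard_normal((8, 7, 7))
heatmap = gradcam_map(acts, grads)
print(heatmap.shape, float(heatmap.min()), float(heatmap.max()))
```

To overlay the result on an image, the low-resolution map is typically resized to the input size and blended with the image using a colormap.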

### Single image inference

#### To use this cell:

Replace the placeholders for ```model_checkpoint``` and ```test_image_path``` with the relevant relative paths and then run the cell.

In [None]:
import os
import numpy as np
import torch
import torch.serialization
from inference import run_inference

# Allow-list the NumPy scalar type stored in some checkpoints. torch.load defaults to
# weights_only=True from PyTorch 2.6, and add_safe_globals expects the object itself,
# not its name as a string.
try:
    from numpy._core.multiarray import scalar as np_scalar  # NumPy 2.x
except ImportError:
    from numpy.core.multiarray import scalar as np_scalar  # NumPy 1.x
if hasattr(torch.serialization, "add_safe_globals"):
    torch.serialization.add_safe_globals([np_scalar])

# Function to run single image inference
def run_single_inference(checkpoint_path, image_path, output_dir=None, use_gradcam=False):
    """
    Run inference on a single image using functions from inference.py
    """
    # Normalize Windows-style separators to forward slashes
    checkpoint_path = checkpoint_path.replace('\\', '/')
    image_path = image_path.replace('\\', '/')
    if output_dir:
        output_dir = output_dir.replace('\\', '/')
        os.makedirs(output_dir, exist_ok=True)
    
    # Run inference
    print(f"Running inference on: {image_path}")
    print(f"Using checkpoint: {checkpoint_path}")
    run_inference(checkpoint_path, image_path, output_dir, batch_mode=False, use_gradcam=use_gradcam)

model_checkpoint = "checkpoints/ganfingerprint_best.pth"
test_image_path = "path_to_image.jpg"

run_single_inference(model_checkpoint, test_image_path, use_gradcam=True)

### Batch Inference

#### To use this cell:

Replace the placeholders for ```model_checkpoint``` and ```test_dir``` with the relevant relative paths and then run the cell.

Note that instead of a single image, ```test_dir``` should point to a folder of images you want to test.

In [None]:
from inference import run_inference
import os

def run_batch_inference(checkpoint_path, image_dir, output_dir="inference_results", use_gradcam=False):
    """
    Run inference on a directory of images using functions from inference.py
    
    Args:
        checkpoint_path: Path to the model checkpoint
        image_dir: Directory containing images to process
        output_dir: Directory to save results
        use_gradcam: Whether to generate Grad-CAM visualizations
    """
    # Fix path separators if needed
    checkpoint_path = checkpoint_path.replace('\\', '/')
    image_dir = image_dir.replace('\\', '/')
    output_dir = output_dir.replace('\\', '/')
    os.makedirs(output_dir, exist_ok=True)
    
    # Run inference
    print(f"Running batch inference on directory: {image_dir}")
    print(f"Using checkpoint: {checkpoint_path}")
    print(f"Grad-CAM visualization: {'Enabled' if use_gradcam else 'Disabled'}")
    run_inference(checkpoint_path, image_dir, output_dir, batch_mode=True, use_gradcam=use_gradcam)

model_checkpoint = "checkpoints/ganfingerprint_best.pth"
test_dir = "relative/path/to/folder"

# Run inference on a directory of images WITH Grad-CAM visualization
run_batch_inference(model_checkpoint, test_dir, "inference_results_with_gradcam", use_gradcam=True)

## 11. Interactive Demo

### Does this mean that the GANFingerprint model is the best?

No. In our case, this GANFingerprint detection model is trained on a dataset of images generated by older GAN models.

As a result, this model will fail to correctly classify images from modern generation models, like **ChatGPT**, whose generation methods differ from those of the older GANs in its training data.

Use the interactive widget below with Grad-CAM enabled to see where the model looks when making its predictions. Is it focusing on the background rather than on facial features?

The cell below generates a basic interface that lets you upload any image; images of 256x256 pixels work best. Feel free to pit images from modern generators (or any image with a face) against the model you trained!

In [None]:
# Enhanced interface for GANFingerprint detector with Grad-CAM visualization
import ipywidgets as widgets
from IPython.display import display, HTML, clear_output
import io
import os
from PIL import Image
import matplotlib.pyplot as plt
import base64
import torch
import numpy as np

# Import from existing inference.py
from inference import predict_image_calibrated, run_inference
import config
from models import FingerprintNet

# For Grad-CAM
from utils.gradcam import get_gradcam_layer, GradCAM, generate_gradcam

# Set the checkpoint path - user only needs to modify this line
MODEL_CHECKPOINT = "checkpoints/ganfingerprint_best.pth"
MODEL_CHECKPOINT = MODEL_CHECKPOINT.replace('\\', '/')

# Load the model once using existing code
print("Loading model from", MODEL_CHECKPOINT)
model = FingerprintNet(backbone=config.BACKBONE)
# weights_only=False unpickles arbitrary objects, so only load checkpoints you trust
checkpoint = torch.load(MODEL_CHECKPOINT, map_location=config.DEVICE, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(config.DEVICE)
model.eval()
print("Model loaded successfully!")

# Define the image transformation (same as in inference.py)
from torchvision import transforms
transform = transforms.Compose([
    transforms.Resize(config.INPUT_SIZE),
    transforms.CenterCrop(config.INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Function to extract image data from uploaded file (handles different widget versions)
def extract_image_data(uploaded_file):
    """Extract raw image bytes from the uploaded file object"""
    # ipywidgets 7 gives a dict with a 'content' key (bytes); ipywidgets 8 gives a
    # dict-like object whose 'content' is a memoryview. bytes() handles both.
    if isinstance(uploaded_file, dict) and 'content' in uploaded_file:
        return bytes(uploaded_file['content'])
    elif hasattr(uploaded_file, 'content'):
        return bytes(uploaded_file.content)
    elif hasattr(uploaded_file, 'value') and isinstance(uploaded_file.value, bytes):
        return uploaded_file.value
    elif callable(getattr(uploaded_file, 'getvalue', None)):
        return uploaded_file.getvalue()
    elif hasattr(uploaded_file, 'data'):
        return uploaded_file.data
    elif isinstance(uploaded_file, bytes):
        return uploaded_file
    else:
        # Print debug info to help identify the unsupported format
        print(f"File object type: {type(uploaded_file)}")
        print(f"Available attributes: {dir(uploaded_file)}")
        raise ValueError("Could not extract image data from the uploaded file")

# Function to generate Grad-CAM visualization
def generate_gradcam_visualization(image_path):
    """Generate Grad-CAM visualization for the given image path"""
    # Load and preprocess the image
    orig_image = Image.open(image_path).convert('RGB')
    image_tensor = transform(orig_image).to(config.DEVICE)
    
    # Get target layer for Grad-CAM
    target_layer = get_gradcam_layer(model)
    
    # Generate Grad-CAM
    raw_logit, heatmap, superimposed = generate_gradcam(model, image_tensor, orig_image)
    
    # Calculate the prediction
    orig_prob = torch.sigmoid(torch.tensor(raw_logit)).item()
    
    # Apply calibration for fake images
    if orig_prob < 0.5:  # Predicted as fake
        calibrated_prob = 1.0 - (2.0 * orig_prob)
    else:  # Predicted as real
        calibrated_prob = orig_prob
    
    pred_class = "Real" if orig_prob >= 0.5 else "Fake"
    
    return orig_image, superimposed, pred_class, calibrated_prob

# Function to process uploaded image and show result
def process_image(uploaded_file, use_gradcam=False):
    try:
        # Extract image data
        img_data = extract_image_data(uploaded_file)
        
        # Save to temp file (needed for predict_image_calibrated)
        temp_path = "temp_uploaded_image.jpg"
        with open(temp_path, "wb") as f:
            f.write(img_data)
        
        if use_gradcam:
            # Generate Grad-CAM visualization
            orig_image, superimposed, pred_class, prob = generate_gradcam_visualization(temp_path)
            
            # Create visualization
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
            fig.subplots_adjust(top=0.85)  # Make more room for title
            
            # Original image
            ax1.imshow(orig_image)
            ax1.set_title("Original Image", fontsize=14, pad=10)
            ax1.axis('off')
            
            # Grad-CAM visualization
            ax2.imshow(superimposed)
            
            # Set title color based on prediction
            title_color = 'green' if pred_class == 'Real' else 'red'
            ax2.set_title(f"Grad-CAM: {pred_class} ({prob:.4f})", 
                        fontsize=14, color=title_color, pad=10)
            ax2.axis('off')
            
            # Add information about prediction above the figures
            plt.suptitle(f"Prediction: {pred_class} ({prob:.4f})", 
                       fontsize=16, y=0.98, color=title_color)
            
            # Add extra space between plots
            plt.tight_layout(pad=3.0)
        else:
            # Use existing predict_image_calibrated from inference.py
            prob, pred_class = predict_image_calibrated(model, temp_path, transform)
            
            # Create visualization
            image = Image.open(temp_path).convert('RGB')
            plt.figure(figsize=(8, 8))
            plt.imshow(image)
            plt.axis('off')
            
            color = 'green' if pred_class == 'Real' else 'red'
            plt.title(f"Prediction: {pred_class} (Confidence: {prob:.4f})", 
                    color=color, fontsize=16)
        
        # Save to buffer for display
        buf = io.BytesIO()
        plt.savefig(buf, format='png', bbox_inches='tight')
        plt.close()
        buf.seek(0)
        
        # Remove temp file
        os.remove(temp_path)
        
        # Convert to base64 for display
        img_str = base64.b64encode(buf.read()).decode('utf-8')
        
        # Create HTML result
        if use_gradcam:
            result_html = f"""
            <div style="text-align: center;">
                <h2>GANFingerprint Analysis Result (with Grad-CAM)</h2>
                <img src="data:image/png;base64,{img_str}" style="max-width: 800px">
                <h3 style="color: {title_color}">Prediction: {pred_class}</h3>
                <p>Confidence: {prob:.4f}</p>
                <p><small>Grad-CAM highlights the regions that influenced the model's decision (red/yellow areas = more influence)</small></p>
            </div>
            """
        else:
            result_html = f"""
            <div style="text-align: center;">
                <h2>GANFingerprint Analysis Result</h2>
                <img src="data:image/png;base64,{img_str}" style="max-width: 500px">
                <h3 style="color: {'green' if pred_class == 'Real' else 'red'}">Prediction: {pred_class}</h3>
                <p>Confidence: {prob:.4f}</p>
            </div>
            """
        return result_html
    
    except Exception as e:
        import traceback
        # Clean up temp file if it exists
        if 'temp_path' in locals() and os.path.exists(temp_path):
            os.remove(temp_path)
        
        # Return error message
        return f"""
        <div style="text-align: center; color: red; padding: 20px; background-color: #fff3f3; border-radius: 10px;">
            <h2>Error Processing Image</h2>
            <p>{str(e)}</p>
            <pre style="text-align: left; background-color: #f8f8f8; padding: 10px; max-height: 300px; overflow: auto;">
{traceback.format_exc()}
            </pre>
        </div>
        """

# Create the interface
header_html = """
<div style="text-align: center; margin-bottom: 20px; background-color: #f0f0f0; padding: 20px; border-radius: 10px;">
    <h1 style="color: #333333;">GANFingerprint Detector</h1>
    <p style="color: #333333;">Upload an image to determine if it's a real photo or AI-generated</p>
    <p style="font-size: 0.8em; color: #666;">For best results, use images that are 256x256 pixels</p>
</div>
"""
display(HTML(header_html))

# Create upload widget
upload = widgets.FileUpload(
    accept='image/*',
    multiple=False,
    description='Select Image',
    layout=widgets.Layout(width='300px')
)

# Create analyze button
button = widgets.Button(
    description='Analyze Image',
    button_style='primary',
    disabled=True,
    layout=widgets.Layout(width='300px')
)

# Create Grad-CAM toggle
gradcam_toggle = widgets.Checkbox(
    value=False,
    description='Show Grad-CAM Visualization',
    indent=False,
    layout=widgets.Layout(width='300px')
)

# Create tooltip for Grad-CAM toggle
gradcam_info = widgets.HTML(
    value="""
    <div style="font-size: 0.85em; color: #666; margin-top: 5px;">
        Grad-CAM shows which parts of the image influenced the model's decision
    </div>
    """
)

# Create output area
output = widgets.Output()

# Enable button when file is uploaded
def on_upload_change(change):
    button.disabled = len(upload.value) == 0
    
upload.observe(on_upload_change, names='value')

# Handle button click
def on_button_click(b):
    with output:
        clear_output()
        if len(upload.value) == 0:
            print("Please upload an image first")
            return
        
        display(HTML("<p>Analyzing image...</p>"))
        
        # Get the uploaded file
        if isinstance(upload.value, dict):
            # Older ipywidgets format
            uploaded_file = list(upload.value.values())[0]
        else:
            # Newer ipywidgets format
            uploaded_file = upload.value[0]
        
        # Process the image and display result
        result_html = process_image(uploaded_file, use_gradcam=gradcam_toggle.value)
        clear_output()
        display(HTML(result_html))

button.on_click(on_button_click)

# Layout the interface
gradcam_controls = widgets.VBox([gradcam_toggle, gradcam_info], 
                              layout=widgets.Layout(align_items='center'))

controls = widgets.HBox([upload, button, gradcam_controls], 
                      layout=widgets.Layout(justify_content='center', 
                                          display='flex',
                                          margin='20px 0'))
display(controls)
display(output)

print(f"Interface loaded. Using model from {MODEL_CHECKPOINT}")
print("Ready to analyze images!")
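
The confidence shown in the Grad-CAM view comes from the calibration step in ```generate_gradcam_visualization```: with ```p``` the sigmoid of the model's logit (the probability of "Real"), an image predicted real reports ```p```, while an image predicted fake reports ```1 - 2p```. The hypothetical helper below mirrors only that function (not ```predict_image_calibrated``` in inference.py) and makes the mapping easy to inspect:

```python
def calibrated_confidence(p: float) -> tuple:
    """Mirror of the Grad-CAM path's calibration; p is the sigmoid output, P(real)."""
    if not 0.0 <= p <= 1.0:
        raise ValueError("p must be a probability in [0, 1]")
    if p < 0.5:  # predicted fake
        return "Fake", 1.0 - 2.0 * p
    return "Real", p  # predicted real


for p in (0.02, 0.25, 0.49, 0.5, 0.75, 0.98):
    label, conf = calibrated_confidence(p)
    print(f"p={p:.2f} -> {label} ({conf:.2f})")
```

The mapping is asymmetric: a borderline fake (```p``` just under 0.5) reports a confidence near 0, whereas a borderline real image reports about 0.5, so confidence values are not directly comparable across the two classes.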