# CLIP-Cues: Concept Bottleneck Model - Inference Example

This notebook demonstrates how to use pre-trained Concept Bottleneck Models to detect synthetic (AI-generated) images with interpretable predictions.

## Overview

We'll cover:
1. Loading a pre-trained concept bottleneck model
2. Running inference on sample images
3. Interpreting the results using human-readable concepts
4. Understanding which concepts contribute to predictions

Unlike the CLIP Orthogonal models, Concept Bottleneck Models route every prediction through a vocabulary of human-readable concepts, so each decision can be traced back to the concepts that drove it.

## Setup

First, install the required package if you haven't already:

```bash
pip install clip-cues
```

In [1]:
from pathlib import Path

from PIL import Image
import matplotlib.pyplot as plt
import torch
import numpy as np

from clip_cues import (
    ConceptClassifierInference,
    CLIPLargePatch14,
    ConceptSelectionHead,
)
from clip_cues.transforms import Transforms
from clip_cues.concepts import ConceptVocabulary

## Load Feature Extractor and Transforms

We use CLIP ViT-L/14 as the feature extractor, frozen to preserve its pre-trained vision-language knowledge.

In [None]:
# Load CLIP feature extractor
extractor = CLIPLargePatch14(cache_dir="../hf_cache")
extractor.freeze()

# Setup image transforms
transforms = Transforms(extractor.transforms)
inference_transforms = transforms.get_inference_transforms()

print(f"Feature extractor output dimension: {extractor.output_dim}")

## Load Concept Vocabulary

The concept vocabulary defines the human-readable concepts used for classification. We use antonym pairs (e.g., "natural" vs "artificial", "organic" vs "synthetic") to create a rich concept space.

In [None]:
# Load concept vocabulary
vocab = ConceptVocabulary.load_antonyms()

# Create concept embeddings using CLIP
concept_embeddings = vocab.create_embeddings(extractor.model)

print(f"Number of concepts: {len(vocab.concepts)}")
print(f"Concept embedding shape: {tuple(concept_embeddings.shape)}")
print("\nExample concepts:")
for i, concept in enumerate(vocab.concepts[:10], start=1):
    print(f"  {i}. {concept}")

## Load Classification Head

The Concept Selection Head learns which concepts are relevant for detecting synthetic images and uses them to make predictions.

In [None]:
# Initialize the concept selection head
head = ConceptSelectionHead(
    num_concepts=len(vocab.concepts),
    concept_embeddings=concept_embeddings,
)

print(f"Concept selection head initialized with {len(vocab.concepts)} concepts")

## Create Inference Model and Load Weights

Available pre-trained concept models:
- `cm_antonyms_cnnspot.ckpt` - Trained on CNNSpot dataset
- `cm_antonyms_synthbuster.ckpt` - Trained on SynthBuster+ dataset
- `cm_antonyms_synthclic.ckpt` - Trained on SynthCLIC dataset
- `cm_antonyms_combined.ckpt` - Trained on combined datasets

In [None]:
# Create the full inference model
model = ConceptClassifierInference(extractor.model, head)

# Load pre-trained weights
checkpoint_path = "../data/checkpoints/cm_antonyms_combined.ckpt"

# Check if checkpoint exists
if Path(checkpoint_path).exists():
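    # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only machines
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

    # Strip only the leading "model." prefix from state dict keys (Python 3.9+)
    weights = {k.removeprefix("model."): v for k, v in checkpoint["state_dict"].items()}

    # strict=False silently skips mismatched keys, so report them explicitly
    load_result = model.load_state_dict(weights, strict=False)
    print(f"Missing keys: {len(load_result.missing_keys)}, unexpected keys: {len(load_result.unexpected_keys)}")
    print("✓ Pre-trained weights loaded")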
else:
    print(f"⚠ Checkpoint not found at {checkpoint_path}")
    print("Please download pre-trained checkpoints from the repository")

# Set model to evaluation mode
model.eval()

# Move to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print(f"Model running on: {device}")
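
When renaming checkpoint keys, it helps to remember that `str.replace` rewrites every occurrence of a substring, not only a leading one. The sketch below compares it with `str.removeprefix` (Python 3.9 or later). The key names are made up for illustration and are not taken from the actual checkpoints.

```python
# Hypothetical state-dict keys, for illustration only.
keys = [
    "model.head.gates",
    "model.extractor.vision_model.embeddings.weight",
]

# str.replace removes every "model." substring, including the one
# that ends "vision_model.".
replaced = [k.replace("model.", "") for k in keys]

# str.removeprefix only strips a leading "model.".
stripped = [k.removeprefix("model.") for k in keys]

for original, r, s in zip(keys, replaced, stripped):
    print(f"{original}\n  replace:      {r}\n  removeprefix: {s}")
```

Because `load_state_dict(..., strict=False)` ignores keys that do not match, a mangled key would be skipped silently, so inspecting the renamed keys once is a cheap sanity check.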

## Inference on a Single Image with Concept Attribution

Let's test the model on a sample image and see which concepts contribute to the prediction.

In [None]:
# Load an image (replace with your own image path)
image_path = "../examples/images/synthetic2.jpg"

if Path(image_path).exists():
    # Load and display the image
    image = Image.open(image_path).convert("RGB")  # ensure 3-channel input

    plt.figure(figsize=(8, 8))
    plt.imshow(image)
    plt.axis('off')
    plt.title('Input Image')
    plt.show()

    # Prepare the image for inference
    batch = inference_transforms({"image": [image]})
    pixel_values = torch.stack(batch["pixel_values"]).to(device)

    # Run inference and get concept activations
    with torch.no_grad():
        prob, concept_scores = model(pixel_values, return_concepts=True)

    # Display results
    synthetic_prob = prob.item()
    real_prob = 1 - synthetic_prob

    print(f"\n{'='*50}")
    print("Prediction Results")
    print(f"{'='*50}")
    print(f"Synthetic probability: {synthetic_prob:.3f}")
    print(f"Real probability:      {real_prob:.3f}")
    print(f"\nPrediction: {'SYNTHETIC' if synthetic_prob > 0.5 else 'REAL'}")
    print(f"Confidence: {max(synthetic_prob, real_prob):.1%}")
    print(f"{'='*50}")
else:
    print(f"⚠ Image not found at {image_path}")
    print("Please provide a valid image path")

## Visualize Top Contributing Concepts

Let's see which concepts most strongly influenced the prediction.

In [None]:
if Path(image_path).exists():
    # Get concept scores and gates
    concept_scores_np = concept_scores.cpu().numpy()[0]
    # detach() is needed if the gates are a trainable parameter
    gates = model.head.gates.detach().cpu().numpy()

    # Calculate weighted concept scores
    weighted_scores = concept_scores_np * gates

    # Get top contributing concepts
    top_k = 15
    top_indices = np.argsort(np.abs(weighted_scores))[-top_k:][::-1]

    # Plot top concepts
    fig, ax = plt.subplots(figsize=(12, 8))

    concepts_top = [vocab.concepts[i] for i in top_indices]
    scores_top = [weighted_scores[i] for i in top_indices]

    colors = ['red' if s > 0 else 'green' for s in scores_top]

    y_pos = np.arange(len(concepts_top))
    ax.barh(y_pos, scores_top, color=colors, alpha=0.7)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(concepts_top)
    ax.set_xlabel('Weighted Concept Score', fontsize=12)
    ax.set_title(f'Top {top_k} Contributing Concepts\n(Red = Synthetic, Green = Real)', fontsize=14, fontweight='bold')
    ax.axvline(x=0, color='black', linestyle='--', linewidth=1)
    ax.grid(axis='x', alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Print detailed concept analysis
    print(f"\nTop {top_k} Contributing Concepts:")
    print(f"{'='*70}")
    print(f"{'Rank':<6} {'Concept':<30} {'Score':<12} {'Gate':<12} {'Weighted':<12}")
    print(f"{'-'*70}")
    for rank, idx in enumerate(top_indices, 1):
        concept = vocab.concepts[idx]
        score = concept_scores_np[idx]
        gate = gates[idx]
        weighted = weighted_scores[idx]
        direction = "→ Synth" if weighted > 0 else "→ Real"
        print(f"{rank:<6} {concept:<30} {score:>8.4f}     {gate:>8.4f}     {weighted:>8.4f}  {direction}")
    print(f"{'='*70}")
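
The ranking idiom `np.argsort(np.abs(weighted_scores))[-top_k:][::-1]` used above selects the `top_k` entries with the largest magnitude, strongest first. As a minimal sketch, the same selection can be written in plain Python. The scores below are invented, and tied magnitudes may come out in a different order than with NumPy.

```python
def top_k_by_magnitude(values, k):
    """Return indices of the k entries with the largest absolute value,
    ordered from largest to smallest magnitude."""
    order = sorted(range(len(values)), key=lambda i: abs(values[i]), reverse=True)
    return order[:k]

# Invented weighted scores for five concepts.
weighted = [0.02, -0.31, 0.18, -0.05, 0.27]
top = top_k_by_magnitude(weighted, k=3)
print(top)  # [1, 4, 2]
```

Sorting by absolute value keeps strongly negative contributions (evidence for "real") in the ranking alongside strongly positive ones.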

## Batch Inference on Multiple Images

Process multiple images at once for efficiency.

In [None]:
# Example: Process multiple images from a directory
image_dir = Path("../examples/images/")

if image_dir.exists():
    # Load all images
    image_paths = list(image_dir.glob("*.jpg")) + list(image_dir.glob("*.png"))

    if image_paths:
        # Convert to RGB so grayscale or RGBA files yield 3-channel inputs
        images = [Image.open(p).convert("RGB") for p in image_paths]

        # Transform all images
        batch = inference_transforms({"image": images})
        pixel_values = torch.stack(batch["pixel_values"]).to(device)

        # Run batch inference
        with torch.no_grad():
            probs, _ = model(pixel_values, return_concepts=True)

        # Display results
        fig, axes = plt.subplots(1, len(images), figsize=(5*len(images), 5))
        if len(images) == 1:
            axes = [axes]

        for idx, (img_path, img, prob) in enumerate(zip(image_paths, images, probs)):
            axes[idx].imshow(img)
            axes[idx].axis('off')

            synthetic_prob = prob.item()
            prediction = "SYNTHETIC" if synthetic_prob > 0.5 else "REAL"
            confidence = max(synthetic_prob, 1 - synthetic_prob)

            color = 'red' if prediction == "SYNTHETIC" else 'green'
            axes[idx].set_title(
                f"{img_path.name}\n{prediction} ({confidence:.1%})",
                color=color,
                fontweight='bold'
            )

        plt.tight_layout()
        plt.show()

        # Print detailed results
        print("\nDetailed Results:")
        print(f"{'='*70}")
        for img_path, prob in zip(image_paths, probs):
            synthetic_prob = prob.item()
            print(f"{img_path.name:30} | Synthetic: {synthetic_prob:.3f} | {('SYNTHETIC' if synthetic_prob > 0.5 else 'REAL'):9}")
        print(f"{'='*70}")
    else:
        print("No images found in the directory")
else:
    print(f"Directory {image_dir} not found")
    print(f"Create the folder {image_dir} and add some images to test batch inference")
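
The cell above stacks every image in the directory into a single tensor, which may exhaust memory for large folders. One possible workaround is to process the paths in fixed-size chunks. The helper name `chunked` and the chunk size are illustrative choices and not part of `clip_cues`.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")

def chunked(items: Sequence[T], size: int) -> Iterator[List[T]]:
    """Yield consecutive slices of `items` with at most `size` elements."""
    if size < 1:
        raise ValueError("size must be at least 1")
    for start in range(0, len(items), size):
        yield list(items[start:start + size])

# Stand-in file names; in the notebook these would be `image_paths`.
paths = [f"img_{i}.jpg" for i in range(5)]
batches = list(chunked(paths, size=2))
print(batches)
```

Each chunk could then be opened, transformed, and passed to the model as in the cell above, collecting the per-chunk probabilities into one list.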

## Helper Function for Easy Inference with Concepts

In [None]:
def predict_image_with_concepts(image_path_or_pil, threshold=0.5, top_k=10):
    """
    Predict whether an image is synthetic or real with concept attribution.

    Args:
        image_path_or_pil: Path to image file or PIL Image object
        threshold: Classification threshold (default: 0.5)
        top_k: Number of top concepts to return (default: 10)

    Returns:
        dict: Prediction results with probabilities, label, and top concepts
    """
    # Load image if path is provided
    if isinstance(image_path_or_pil, (str, Path)):
        image = Image.open(image_path_or_pil).convert("RGB")
    else:
        image = image_path_or_pil

    # Transform and prepare for inference
    batch = inference_transforms({"image": [image]})
    pixel_values = torch.stack(batch["pixel_values"]).to(device)

    # Run inference
    with torch.no_grad():
        prob, concept_scores = model(pixel_values, return_concepts=True)

    synthetic_prob = prob.item()
    real_prob = 1 - synthetic_prob

    # Get top concepts
    concept_scores_np = concept_scores.cpu().numpy()[0]
    gates = model.head.gates.detach().cpu().numpy()
    weighted_scores = concept_scores_np * gates

    top_indices = np.argsort(np.abs(weighted_scores))[-top_k:][::-1]
    top_concepts = [
        {
            "concept": vocab.concepts[i],
            "score": float(concept_scores_np[i]),
            "gate": float(gates[i]),
            "weighted_score": float(weighted_scores[i]),
            "direction": "synthetic" if weighted_scores[i] > 0 else "real"
        }
        for i in top_indices
    ]

    return {
        "synthetic_probability": synthetic_prob,
        "real_probability": real_prob,
        "prediction": "synthetic" if synthetic_prob > threshold else "real",
        "confidence": max(synthetic_prob, real_prob),
        "top_concepts": top_concepts
    }

# Example usage
# result = predict_image_with_concepts("path/to/image.jpg")
# print(f"Prediction: {result['prediction']} (confidence: {result['confidence']:.1%})")
# print(f"\nTop concepts:")
# for c in result['top_concepts'][:5]:
#     print(f"  - {c['concept']}: {c['weighted_score']:.4f} ({c['direction']})")
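
Because the helper returns plain dictionaries, results from many images can be aggregated with the standard library. The sketch below counts how often each concept shows up among the top concepts. The `results` list is a hand-written stand-in that mimics the helper's output keys, and the concept names and values are invented.

```python
from collections import Counter

# Hand-written stand-ins for the dictionaries returned by
# predict_image_with_concepts; concept names are invented.
results = [
    {"prediction": "synthetic", "top_concepts": [
        {"concept": "artificial", "direction": "synthetic"},
        {"concept": "smooth", "direction": "synthetic"},
    ]},
    {"prediction": "synthetic", "top_concepts": [
        {"concept": "artificial", "direction": "synthetic"},
        {"concept": "natural", "direction": "real"},
    ]},
    {"prediction": "real", "top_concepts": [
        {"concept": "natural", "direction": "real"},
    ]},
]

def count_top_concepts(results):
    """Count (concept, direction) pairs across all results."""
    counts = Counter()
    for result in results:
        for entry in result["top_concepts"]:
            counts[(entry["concept"], entry["direction"])] += 1
    return counts

counts = count_top_concepts(results)
for (concept, direction), n in counts.most_common():
    print(f"{concept:<12} {direction:<10} {n}")
```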

## Understanding the Results

### Concept Scores

- **Concept Score**: Raw similarity between the image and each concept (from CLIP)
- **Gate**: Learned weight indicating how relevant each concept is for classification
- **Weighted Score**: Concept Score × Gate, the concept's final contribution to the prediction
  - Positive values → push toward "synthetic"
  - Negative values → push toward "real"
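
To make these three quantities concrete, here is a toy calculation with three-dimensional vectors. It assumes the concept score is a cosine similarity between the image and concept embeddings, which is common with CLIP but is not stated explicitly in this notebook. All embeddings and gates are invented.

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def direction(weighted):
    """Label the sign of a weighted score."""
    if weighted > 0:
        return "synthetic"
    if weighted < 0:
        return "real"
    return "neutral"

# Invented toy data: concept name -> (embedding, gate).
image_embedding = [1.0, 0.0, 0.0]
concepts = {
    "artificial": ([1.0, 0.0, 0.0], 0.8),
    "natural": ([0.6, 0.8, 0.0], -0.5),
    "colorful": ([0.0, 0.0, 1.0], 0.9),
}

contributions = {}
for name, (embedding, gate) in concepts.items():
    score = cosine(image_embedding, embedding)  # concept score
    contributions[name] = score * gate          # weighted score

for name, weighted in contributions.items():
    print(f"{name:<12} {weighted:+.2f} -> {direction(weighted)}")
```

Note that "colorful" has the largest gate in this toy setup yet contributes nothing, because its embedding is orthogonal to the image embedding: a concept only matters when both the score and the gate are non-zero.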

### Interpreting Concepts

The model learns which concepts are most informative for distinguishing synthetic from real images:
- **Large positive weights**: Concepts strongly associated with synthetic images
- **Large negative weights**: Concepts strongly associated with real images
- **Weights near zero**: Concepts not useful for this classification task
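
As a rough sketch of this reading, learned weights can be sorted into three groups by sign and magnitude. The `tolerance` cutoff is an arbitrary value picked for illustration, and the gate values are invented.

```python
def bucket_gates(gates, tolerance=0.05):
    """Group concept names by the sign and size of their gate.

    `tolerance` is an arbitrary illustrative cutoff: gates whose
    magnitude falls below it are treated as uninformative.
    """
    buckets = {"synthetic": [], "real": [], "uninformative": []}
    for name, gate in gates.items():
        if abs(gate) < tolerance:
            buckets["uninformative"].append(name)
        elif gate > 0:
            buckets["synthetic"].append(name)
        else:
            buckets["real"].append(name)
    return buckets

# Invented gate values.
example_gates = {"artificial": 0.8, "natural": -0.5, "large": 0.01, "small": -0.02}
buckets = bucket_gates(example_gates)
print(buckets)
```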

### Model Selection Tips

- **SynthCLIC**: Best for web images and diverse generative models
- **SynthBuster+**: Good for social media images
- **CNNSpot**: Specialized for specific GAN architectures
- **Combined**: Best overall generalization across different sources

### Advantages of Concept Models

1. **Interpretability**: See which visual concepts drive predictions
2. **Debugging**: Understand when and why the model might fail
3. **Trust**: Validate that predictions are based on meaningful visual cues
4. **Insights**: Learn what visual characteristics distinguish synthetic images

### Limitations

- Concept vocabulary is predefined (though customizable)
- Performance may vary on heavily compressed or edited images
- A model trained on images from specific generators may not generalize to newer generation methods
- Always validate results in critical applications
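
One way to validate the detector on your own data is to score a labeled hold-out set. The sketch below computes accuracy, precision, and recall for the "synthetic" class at a given threshold. The probabilities and labels are invented, with label `1` meaning synthetic.

```python
def binary_metrics(probs, labels, threshold=0.5):
    """Accuracy, precision and recall for the 'synthetic' class (label 1)."""
    preds = [1 if p > threshold else 0 for p in probs]
    pairs = list(zip(preds, labels))
    tp = sum(1 for p, y in pairs if p == 1 and y == 1)
    fp = sum(1 for p, y in pairs if p == 1 and y == 0)
    fn = sum(1 for p, y in pairs if p == 0 and y == 1)
    tn = sum(1 for p, y in pairs if p == 0 and y == 0)
    return {
        "accuracy": (tp + tn) / len(labels),
        # Guard against division by zero when a class is never predicted/present.
        "precision": tp / (tp + fp) if (tp + fp) else 0.0,
        "recall": tp / (tp + fn) if (tp + fn) else 0.0,
    }

# Invented model outputs and ground-truth labels.
probs = [0.92, 0.15, 0.60, 0.40, 0.05, 0.71]
labels = [1, 0, 0, 1, 0, 1]
metrics = binary_metrics(probs, labels)
print({name: round(value, 3) for name, value in metrics.items()})
```

Re-running the function with different `threshold` values shows how the balance between missed synthetic images and false alarms shifts.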

## Comparing Predictions Across Models

You can load multiple checkpoints to compare how different training datasets affect concept selection and predictions.

In [None]:
# Example: Compare predictions from different checkpoints
# checkpoint_names = [
#     "cm_antonyms_cnnspot.ckpt",
#     "cm_antonyms_synthbuster.ckpt",
#     "cm_antonyms_synthclic.ckpt",
#     "cm_antonyms_combined.ckpt"
# ]
#
# for ckpt_name in checkpoint_names:
#     checkpoint_path = f"../data/checkpoints/{ckpt_name}"
#     if Path(checkpoint_path).exists():
#         # Load checkpoint
#         checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
#         weights = {k.removeprefix("model."): v for k, v in checkpoint["state_dict"].items()}
#         model.load_state_dict(weights, strict=False)
#         model.eval()
#
#         # Run inference
#         result = predict_image_with_concepts(image_path, top_k=5)
#
#         print(f"\n{ckpt_name}:")
#         print(f"  Prediction: {result['prediction']} ({result['confidence']:.1%})")
#         print(f"  Top concepts:")
#         for c in result['top_concepts']:
#             print(f"    - {c['concept']}: {c['weighted_score']:.4f}")
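
To put a number on how much concept selection differs between checkpoints, one option is the Jaccard overlap of their top-concept sets. This is a sketch of one possible comparison, not a method prescribed by the library, and the concept lists are invented placeholders.

```python
def jaccard(a, b):
    """Jaccard overlap |A ∩ B| / |A ∪ B| of two collections."""
    set_a, set_b = set(a), set(b)
    union = set_a | set_b
    if not union:
        return 1.0  # two empty sets are treated as identical
    return len(set_a & set_b) / len(union)

# Invented top concepts per checkpoint.
top_by_checkpoint = {
    "cm_antonyms_cnnspot.ckpt": ["artificial", "smooth", "glossy"],
    "cm_antonyms_combined.ckpt": ["artificial", "smooth", "natural"],
}
first, second = top_by_checkpoint.values()
overlap = jaccard(first, second)
print(f"Top-concept overlap: {overlap:.2f}")  # 0.50
```

An overlap near 1 suggests both checkpoints rely on the same concepts for the image, while a value near 0 suggests they reach their predictions through different cues.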

## Next Steps

- Compare concept model predictions with CLIP Orthogonal models
- Try different pre-trained checkpoints to see how concept selection varies
- Create custom concept vocabularies for your specific use case
- Fine-tune models on your own datasets
- Analyze concept patterns across large image collections
- Check the [documentation](https://github.com/marco-willi/clip-cues) for more advanced usage