Python notebook to test CLIP with the CIFAR-10 dataset.

## Project Plan: CLIP with CIFAR-10 and FAISS

### 1. Setup and Initialization
- Install dependencies (`transformers`, `datasets`, `faiss-cpu` or `faiss-gpu`, `torch`, `torchvision`, `scikit-learn`, `matplotlib`, `tqdm`).
- Import necessary libraries.
- Check and configure device (CPU/GPU).
- **Update:** Ensure reproducible results by setting random seeds.

### 2. Data Loading (CIFAR-10)
- Load the CIFAR-10 dataset from Hugging Face `datasets`.
- Use the dataset's predefined train/test splits, focusing on the test split for evaluation.
- Visualize a few random samples to verify data.
- **Optimization:** Use a `DataLoader` for efficient batch processing during embedding generation.

### 3. Model Loading (CLIP)
- Load pretrained CLIP model and processor from `transformers` (e.g., `openai/clip-vit-base-patch32`).
- **Prompt Engineering:** Instead of raw labels ("dog"), use prompt templates (e.g., "a photo of a dog", "a low resolution photo of a dog") to improve zero-shot accuracy on CIFAR-10's low-res images.
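
As a minimal sketch of the prompt-template idea above, the snippet below expands each class name into one prompt per template. The `TEMPLATES` list and the `build_prompts` helper are hypothetical names introduced for illustration; they are not part of CLIP or `transformers`. The article handling ("a" vs. "an") is an added assumption so that classes such as "airplane" read naturally.

```python
# Hypothetical templates; the two phrasings come from the plan above.
TEMPLATES = [
    "a photo of {article} {name}",
    "a low resolution photo of {article} {name}",
]

def build_prompts(class_names, templates=TEMPLATES):
    """Return {class_name: [prompt, ...]} with one prompt per template."""
    prompts = {}
    for name in class_names:
        # Simple heuristic for the indefinite article (assumption, not from CLIP).
        article = "an" if name[0].lower() in "aeiou" else "a"
        prompts[name] = [t.format(article=article, name=name) for t in templates]
    return prompts

prompts = build_prompts(["dog", "airplane"])
print(prompts["dog"])
print(prompts["airplane"])
```

When several templates are used per class, one common option is to average the resulting text embeddings per class and re-normalize the average before comparing it with image embeddings.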

### 4. Embedding Generation (Batched)
- **Image Embeddings:** Pass CIFAR-10 images through the CLIP vision model in batches to get feature vectors.
- **Text Embeddings:** Encode the class prompts (the templates from step 3) using the CLIP text model.
- **Normalization:** Normalize all vectors to unit length (L2 norm) so that Inner Product (IP) search in FAISS equals Cosine Similarity.
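
The normalization step rests on the identity cos(a, b) = (a · b) / (‖a‖ ‖b‖): once both vectors have unit length, the denominator is 1 and the inner product is the cosine similarity. The sketch below checks this numerically with NumPy on random stand-in vectors; `l2_normalize` is a hypothetical helper and the 512-dimensional shape is only an assumption for the example.

```python
import numpy as np

def l2_normalize(vectors, eps=1e-12):
    """Scale each row to unit L2 norm; eps guards against division by zero."""
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    return vectors / np.maximum(norms, eps)

rng = np.random.default_rng(0)
a = rng.normal(size=(4, 512)).astype(np.float32)  # stand-in "image" vectors
b = rng.normal(size=(3, 512)).astype(np.float32)  # stand-in "text" vectors

# Cosine similarity computed from its definition on the raw vectors.
norm_a = np.linalg.norm(a, axis=1)[:, None]
norm_b = np.linalg.norm(b, axis=1)[None, :]
cosine = (a @ b.T) / (norm_a * norm_b)

# Plain inner product on the normalized vectors.
inner = l2_normalize(a) @ l2_normalize(b).T

print(np.allclose(cosine, inner, atol=1e-5))
```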

### 5. Indexing with FAISS
- Initialize a FAISS index (e.g., `IndexFlatIP` for inner product/cosine similarity).
- **Optimization:** If dataset grows, consider `IndexIVFFlat` for faster approximate search.
- Add image embeddings to the FAISS index.
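
A possible shape for the indexing step is sketched below. `build_flat_ip_index` uses `faiss.IndexFlatIP`, which performs an exhaustive inner-product search; `brute_force_ip_search` is a NumPy reference that computes the same ranking and can serve as a sanity check on small inputs. Both function names are hypothetical, and the FAISS import is done lazily so that the reference function runs even where FAISS is not installed.

```python
import numpy as np

def build_flat_ip_index(embeddings):
    """Build an exact inner-product FAISS index over (n, d) embeddings."""
    import faiss  # lazy import: only needed when this function is called

    # FAISS expects C-contiguous float32 arrays.
    vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
    index = faiss.IndexFlatIP(vectors.shape[1])
    index.add(vectors)
    return index

def brute_force_ip_search(embeddings, queries, k):
    """NumPy reference: top-k inner products per query, best match first."""
    scores = queries @ embeddings.T               # (n_queries, n_vectors)
    top = np.argsort(-scores, axis=1)[:, :k]      # indices by descending score
    return np.take_along_axis(scores, top, axis=1), top

# With FAISS installed, usage mirrors the NumPy reference:
#   index = build_flat_ip_index(image_embeddings)
#   scores, ids = index.search(query_embeddings, 5)  # float32 queries
```

If `IndexIVFFlat` is used instead, the index has to be trained on representative vectors with `train` before `add` is called, and the results become approximate.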

### 6. Retrieval Pipeline
- **Zero-Shot Classification:** Query the image-index with text vectors to find the best matching image class.
- **Visual Search (Image-to-Image):** Use a query image to find its nearest neighbors in the dataset (simulating a "product search" use case).
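
The two retrieval modes can be illustrated on toy vectors with plain NumPy. This is a sketch under assumptions rather than the notebook's pipeline: `zero_shot_predict` assigns each image the class whose text embedding is most similar, which is one common way to turn text-image similarity into a per-image prediction, and `nearest_images` does image-to-image search while excluding the query itself. Both names and the 2-dimensional toy embeddings are hypothetical.

```python
import numpy as np

def zero_shot_predict(image_embeddings, text_embeddings):
    """For each image, pick the class with the most similar text embedding."""
    similarity = image_embeddings @ text_embeddings.T   # (n_images, n_classes)
    return similarity.argmax(axis=1), similarity

def nearest_images(image_embeddings, query_index, k):
    """Return indices of the k images most similar to the query image."""
    scores = image_embeddings @ image_embeddings[query_index]
    scores[query_index] = -np.inf                       # drop the trivial self-match
    return np.argsort(-scores)[:k]

# Toy data: two "class" embeddings and three "image" embeddings, unit-normalized.
text = np.array([[1.0, 0.0], [0.0, 1.0]])
images = np.array([[0.9, 0.1], [0.2, 0.8], [0.8, 0.3]])
images = images / np.linalg.norm(images, axis=1, keepdims=True)

predictions, similarity = zero_shot_predict(images, text)
print(predictions)
print(nearest_images(images, query_index=0, k=1))
```

Querying an image index with a text vector returns the images that best match that class (text-to-image retrieval); taking the argmax across classes for each image, as above, is what yields a per-image label for accuracy computation.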

### 7. Evaluation & Metrics
- **Quantitative:**
    - Zero-Shot Accuracy: Percentage of correctly predicted class labels.
    - Top-k Accuracy: (e.g., Top-1, Top-5).
- **Qualitative:**
    - Visualize Top-K retrieval results for specific queries (Show query image next to retrieved results).
    - Generate a **Confusion Matrix** to visualize misclassifications.
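
To make the quantitative metrics concrete, here is a small NumPy sketch that computes top-k accuracy and a confusion matrix from a similarity matrix of shape `(n_images, n_classes)`. The helper names and the toy scores are hypothetical; in the notebook itself, `sklearn.metrics.confusion_matrix` (already imported) covers the confusion-matrix part.

```python
import numpy as np

def top_k_accuracy(similarity, labels, k):
    """Fraction of rows whose true label is among the k highest-scoring classes."""
    top_k = np.argsort(-similarity, axis=1)[:, :k]
    hits = (top_k == np.asarray(labels)[:, None]).any(axis=1)
    return float(hits.mean())

def confusion_counts(labels, predictions, n_classes):
    """Count matrix: rows are true classes, columns are predicted classes."""
    matrix = np.zeros((n_classes, n_classes), dtype=np.int64)
    np.add.at(matrix, (np.asarray(labels), np.asarray(predictions)), 1)
    return matrix

# Toy scores for four images over three classes.
similarity = np.array([
    [0.9, 0.1, 0.0],
    [0.2, 0.3, 0.5],
    [0.4, 0.5, 0.1],
    [0.1, 0.2, 0.7],
])
labels = [0, 1, 1, 2]

predictions = similarity.argmax(axis=1)
print(top_k_accuracy(similarity, labels, k=1))
print(top_k_accuracy(similarity, labels, k=2))
print(confusion_counts(labels, predictions, n_classes=3))
```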


In [None]:
# 1. Setup and Initialization
# Install required libraries if not already installed
# %pip install transformers datasets faiss-cpu torch torchvision scikit-learn matplotlib tqdm

import torch
import numpy as np
import faiss
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

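# Check device availability: prefer CUDA, then Apple Silicon (MPS), then CPU
if torch.cuda.is_available():
    device = "cuda"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = "mps"  # Apple Silicon (M1/M2) GPUs
else:
    device = "cpu"
print(f"Using device: {device}")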

# Set random seed for reproducibility
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

In [None]:
# 2. Data Loading (CIFAR-10)
# Load CIFAR-10 dataset
dataset = load_dataset("cifar10")

# We'll use the test split for evaluation to save time, 
# but you can use 'train' for building a larger index if needed.
test_dataset = dataset["test"]

# Define the class names for CIFAR-10
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Visualize a few random samples
fig, axs = plt.subplots(1, 5, figsize=(15, 3))
for i in range(5):
    idx = int(np.random.randint(0, len(test_dataset)))  # plain Python int index
    sample = test_dataset[idx]  # fetch (and decode) the example once
    axs[i].imshow(sample['img'])
    axs[i].set_title(class_names[sample['label']])
    axs[i].axis('off')
plt.show()

In [None]:
# 3. Model Loading (CLIP)
model_id = "openai/clip-vit-base-patch32"

# Load model and processor
model = CLIPModel.from_pretrained(model_id).to(device)
processor = CLIPProcessor.from_pretrained(model_id)

print(f"Model {model_id} loaded successfully on {device}")