# Project B - Text-to-Image Search with CLIP

**Module**: Neural Networks, Advanced  
**Duration**: 2h  
**Goal**: Discover multimodal learning with CLIP (OpenAI)

---

## Project Goals

In this project, you will:
1. Understand how CLIP links images and text
2. Perform zero-shot classification (no training required)
3. Build a text-to-image search engine
4. Explore the model's capabilities and limitations

## 0. Installation

In [None]:
!pip install torch torchvision ftfy regex matplotlib numpy Pillow requests tqdm seaborn scikit-learn transformers sentencepiece -q
!pip install git+https://github.com/openai/CLIP.git -q

In [None]:
import torch
import clip
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import requests
from io import BytesIO
from tqdm.auto import tqdm
import os

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# Load CLIP (ViT-B/32 produces 512-dimensional embeddings)
model, preprocess = clip.load("ViT-B/32", device=device)
print("CLIP loaded!")

---

## 1. How Does CLIP Work?

CLIP (Contrastive Language-Image Pre-training) learns to align images and texts in a shared vector space.

```
Image → Image Encoder → Image embedding (512 dim)
                                    ↓
                              Cosine similarity
                                    ↑
Text  → Text Encoder  → Text embedding (512 dim)
```

### Principle
- Matching images and texts → nearby embeddings (high cosine similarity)
- Unrelated images and texts → distant embeddings (low cosine similarity)
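Because each embedding is divided by its norm before comparison (the `.norm(dim=-1, keepdim=True)` lines in the code below), a plain dot product between two embeddings equals their cosine similarity. A minimal NumPy sketch with made-up 4-dimensional vectors (real ViT-B/32 embeddings have 512 dimensions):

```python
import numpy as np

def l2_normalize(x):
    """Scale vectors to unit length along the last axis."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

def cosine_similarity(a, b):
    """Cosine of the angle between two vectors: dot product divided by both norms."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Made-up toy embeddings, for illustration only
image_emb = np.array([0.9, 0.1, 0.0, 0.2])
text_cat = np.array([1.0, 0.0, 0.1, 0.1])
text_car = np.array([0.0, 1.0, 0.8, 0.0])

sim_cat = cosine_similarity(image_emb, text_cat)
sim_car = cosine_similarity(image_emb, text_car)

# After normalization, the plain dot product gives the same value
dot_after_norm = float(l2_normalize(image_emb) @ l2_normalize(text_cat))
print(f"cat: {sim_cat:.3f}, car: {sim_car:.3f}, dot after normalization: {dot_after_norm:.3f}")
```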

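### Training
CLIP was trained on 400 million (image, text) pairs collected from the web, with a contrastive objective: within each batch, matching image/caption pairs are pulled together and mismatched pairs are pushed apart.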
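The CLIP paper trains both encoders with a symmetric contrastive loss: for a batch of N pairs, it builds the N × N similarity matrix and applies a cross-entropy in both directions, the correct pairs being on the diagonal. A minimal NumPy sketch of that loss on random toy data; the logit scale is fixed at 100 here, whereas in CLIP it is a learned parameter:

```python
import numpy as np

def log_softmax(logits, axis):
    """Numerically stable log-softmax along the given axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

def clip_contrastive_loss(image_emb, text_emb, logit_scale=100.0):
    """Symmetric cross-entropy over a batch of N (image, text) pairs.

    Row i of image_emb and row i of text_emb are assumed to be a matching pair;
    every other image/text combination in the batch acts as a negative.
    """
    img = image_emb / np.linalg.norm(image_emb, axis=1, keepdims=True)
    txt = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)
    logits = logit_scale * img @ txt.T   # (N, N) scaled cosine similarities
    diag = np.arange(len(logits))        # correct pairs sit on the diagonal
    loss_image_to_text = -log_softmax(logits, axis=1)[diag, diag].mean()  # each image picks its caption
    loss_text_to_image = -log_softmax(logits, axis=0)[diag, diag].mean()  # each caption picks its image
    return float((loss_image_to_text + loss_text_to_image) / 2)

# Toy batch: 8 random "image" vectors and captions that closely match them
rng = np.random.default_rng(0)
images = rng.normal(size=(8, 16))
texts = images + 0.05 * rng.normal(size=(8, 16))

loss_matched = clip_contrastive_loss(images, texts)          # pairs aligned: low loss
loss_shuffled = clip_contrastive_loss(images, texts[::-1])   # every pair mismatched: high loss
print(f"matched: {loss_matched:.4f}, shuffled: {loss_shuffled:.4f}")
```

Reversing the caption order pairs every image with a wrong caption, so the loss rises sharply; training pushes the encoders toward the low-loss, aligned configuration.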

---

## 2. First Example: Zero-Shot Classification

In [None]:
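# Load an image from a URL
def load_image_from_url(url):
    # Wikimedia's User-Agent policy asks for a descriptive User-Agent;
    # requests without one may be rejected (example value below)
    headers = {"User-Agent": "clip-course-notebook/1.0 (educational use)"}
    response = requests.get(url, headers=headers, timeout=30)
    response.raise_for_status()  # fail loudly instead of passing an HTML error page to PIL
    return Image.open(BytesIO(response.content)).convert("RGB")

# Example image: a cat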
url = "https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg"
image = load_image_from_url(url)

plt.figure(figsize=(6, 6))
plt.imshow(image)
plt.axis('off')
plt.title("Image to classify")
plt.show()

In [None]:
# Zero-shot classification
# We compare the image against several text descriptions

# Preprocess the image (resize, center crop, normalize) and add a batch dimension
image_input = preprocess(image).unsqueeze(0).to(device)

# Candidate labels
labels = ["a photo of a cat", "a photo of a dog", "a photo of a bird", "a photo of a car"]
text_inputs = clip.tokenize(labels).to(device)

# Encode
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Normalize to unit length so that the dot product is the cosine similarity
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

# Similarity: scale by 100 (as in the official CLIP README) before the softmax;
# raw cosine similarities are too close to each other and give near-uniform probabilities
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

print("Zero-shot classification:")
for label, prob in zip(labels, similarity[0]):
    print(f"  {label}: {prob.item():.1%}")

**Observation**: without any task-specific training, CLIP correctly recognizes the cat!
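Why scale the similarities before the softmax? For one image, the cosine similarities to different prompts are often close to one another, so a softmax over the raw values is almost flat. The official CLIP README multiplies them by 100.0 first (roughly the model's learned `logit_scale.exp()`). A small NumPy illustration with hypothetical similarity values:

```python
import numpy as np

def softmax(x):
    """Softmax over a 1-D array (shifted by the max for numerical stability)."""
    e = np.exp(x - np.max(x))
    return e / e.sum()

# Hypothetical cosine similarities between one image and four prompts
cos_sims = np.array([0.30, 0.22, 0.20, 0.18])

probs_raw = softmax(cos_sims)             # almost uniform
probs_scaled = softmax(100.0 * cos_sims)  # concentrated on the best prompt
print(probs_raw.round(3))
print(probs_scaled.round(3))
```

The ranking of the prompts is the same in both cases; only the confidence shown by the probabilities changes.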

---

## 3. Exercise: Custom Classification

In [None]:
# ============================================
# EXERCISE 1: Build your own classifier
# ============================================

def zero_shot_classify(image, labels, model, preprocess, device):
    """
    Zero-shot classification with CLIP.
    
    Args:
        image: PIL image
        labels: list of text descriptions
        model: CLIP model
        preprocess: preprocessing function returned by clip.load
        device: device ("cuda" or "cpu")
    
    Returns:
        dict mapping each label to its probability
    """
    # TODO: Implement
    
    # 1. Preprocess the image
    image_input = preprocess(image).unsqueeze(0).to(device)
    
    # 2. Tokenize the labels
    text_inputs = clip.tokenize(labels).to(device)
    
    # 3. Encode image and text
    with torch.no_grad():
        image_features = model.encode_image(image_input)
        text_features = model.encode_text(text_inputs)
    
    # 4. Normalize to unit length
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    
    # 5. Compute similarity + softmax (scaled by 100, as in the CLIP README)
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    
    # 6. Return the results
    results = {}
    for label, prob in zip(labels, similarity[0]):
        results[label] = prob.item()
    
    return results
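The result depends on the wording of the prompts. The CLIP paper reports better zero-shot accuracy when the text embeddings of several prompt templates are averaged per class ("prompt ensembling"). A sketch of that averaging in NumPy; the templates and the `encode_text` callable are assumptions, and a toy letter-count encoder stands in for CLIP so that the example runs on its own:

```python
import numpy as np

# Hypothetical templates; the CLIP paper uses a much longer list for ImageNet
TEMPLATES = ["a photo of a {}.", "a blurry photo of a {}.", "a close-up photo of a {}."]

def ensembled_class_embeddings(class_names, encode_text, templates=TEMPLATES):
    """Return one unit-norm embedding per class, averaged over prompt templates.

    encode_text: any callable mapping a list of strings to an (n, d) NumPy array.
    """
    rows = []
    for name in class_names:
        emb = encode_text([t.format(name) for t in templates])
        emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)  # normalize each prompt embedding
        mean = emb.mean(axis=0)
        rows.append(mean / np.linalg.norm(mean))                 # re-normalize the average
    return np.stack(rows)

def toy_encode_text(texts):
    """Stand-in encoder for illustration only: counts each letter a-z."""
    letters = "abcdefghijklmnopqrstuvwxyz"
    return np.array([[s.count(c) for c in letters] for s in texts], dtype=float)

class_emb = ensembled_class_embeddings(["cat", "dog", "insect"], toy_encode_text)
print(class_emb.shape)
```

With the notebook's CLIP model, one possible (untested) `encode_text` is `lambda texts: model.encode_text(clip.tokenize(texts).to(device)).detach().float().cpu().numpy()`; the resulting class embeddings then replace the single-prompt `text_features`.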

In [None]:
# Test with different images
urls = [
    "https://upload.wikimedia.org/wikipedia/commons/thumb/2/26/YellowLabradorLooking_new.jpg/1200px-YellowLabradorLooking_new.jpg",  # Dog
    "https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Camponotus_flavomarginatus_ant.jpg/1200px-Camponotus_flavomarginatus_ant.jpg",  # Ant
]

labels_test = [
    "a photo of a dog",
    "a photo of a cat", 
    "a photo of an insect",
    "a photo of a person"
]

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

for idx, url in enumerate(urls):
    img = load_image_from_url(url)
    results = zero_shot_classify(img, labels_test, model, preprocess, device)
    
    axes[idx].imshow(img)
    axes[idx].axis('off')
    best_label = max(results, key=results.get)
    axes[idx].set_title(f"{best_label}\n({results[best_label]:.1%})")

plt.tight_layout()
plt.show()

---

## 4. Text-to-Image Search Engine

In [None]:
# Build an image collection
image_urls = {
    "chat_roux": "https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg",
    "chien_labrador": "https://upload.wikimedia.org/wikipedia/commons/thumb/2/26/YellowLabradorLooking_new.jpg/1200px-YellowLabradorLooking_new.jpg",
    "tour_eiffel": "https://upload.wikimedia.org/wikipedia/commons/thumb/8/85/Tour_Eiffel_Wikimedia_Commons_%28cropped%29.jpg/800px-Tour_Eiffel_Wikimedia_Commons_%28cropped%29.jpg",
    "plage_tropicale": "https://upload.wikimedia.org/wikipedia/commons/thumb/7/75/Maldives_beach.JPG/1200px-Maldives_beach.JPG",
    "montagne_neige": "https://upload.wikimedia.org/wikipedia/commons/thumb/e/e7/Everest_North_Face_toward_Base_Camp_Tibet_Luca_Galuzzi_2006.jpg/1200px-Everest_North_Face_toward_Base_Camp_Tibet_Luca_Galuzzi_2006.jpg",
    "voiture_sport": "https://upload.wikimedia.org/wikipedia/commons/thumb/1/1c/Ferrari_F40_in_Monterey.jpg/1200px-Ferrari_F40_in_Monterey.jpg",
}

# Load and display
images = {}
fig, axes = plt.subplots(2, 3, figsize=(12, 8))
axes = axes.flatten()

for idx, (name, url) in enumerate(image_urls.items()):
    axes[idx].axis('off')
    try:
        img = load_image_from_url(url)
    except Exception as e:  # network error, HTTP error, or unreadable image
        print(f"Failed to load {name}: {e}")
        continue
    images[name] = img
    axes[idx].imshow(img)
    axes[idx].set_title(name.replace('_', ' '))

plt.tight_layout()
plt.show()
print(f"\n{len(images)} images loaded")

In [None]:
# ============================================
# EXERCISE 2: Build the embedding index
# ============================================

class ImageSearchEngine:
    """
    Text-to-image search engine based on CLIP.
    """
    
    def __init__(self, model, preprocess, device):
        self.model = model
        self.preprocess = preprocess
        self.device = device
        self.image_embeddings = None
        self.image_names = []
        self.images = {}
    
    def index_images(self, images_dict):
        """
        Index a collection of images.
        
        Args:
            images_dict: dict {name: PIL_image}
        """
        self.images = images_dict
        self.image_names = list(images_dict.keys())
        
        embeddings = []
        
        print("Indexing images...")
        for name, img in tqdm(images_dict.items()):
            # TODO: Encode each image
            img_input = self.preprocess(img).unsqueeze(0).to(self.device)
            
            with torch.no_grad():
                img_embedding = self.model.encode_image(img_input)
            
            # Normalize
            img_embedding = img_embedding / img_embedding.norm(dim=-1, keepdim=True)
            embeddings.append(img_embedding)
        
        # Stack all embeddings into an (n_images, 512) tensor
        self.image_embeddings = torch.cat(embeddings, dim=0)
        print(f"Index built: {tuple(self.image_embeddings.shape)}")
    
    def search(self, query, top_k=3):
        """
        Find the images most similar to a text query.
        
        Args:
            query: search text
            top_k: number of results
        
        Returns:
            list of (name, score, image), best match first
        """
        # TODO: Implement the search
        
        # 1. Encode the query
        text_input = clip.tokenize([query]).to(self.device)
        
        with torch.no_grad():
            text_embedding = self.model.encode_text(text_input)
        
        text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)
        
        # 2. Cosine similarity with every image (image embeddings are already normalized)
        similarities = (text_embedding @ self.image_embeddings.T).squeeze(0)
        
        # 3. Sort and keep the top_k
        top_indices = similarities.argsort(descending=True)[:top_k]
        
        results = []
        for idx in top_indices.tolist():
            name = self.image_names[idx]
            score = similarities[idx].item()
            results.append((name, score, self.images[name]))
        
        return results
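`search` encodes and scores one query at a time. When several queries are known in advance, a single matrix product scores them all at once. A NumPy sketch, assuming the normalized embeddings have already been copied to CPU arrays (e.g. with `.float().cpu().numpy()`); `top_k_search` is a hypothetical helper, not part of the class:

```python
import numpy as np

def top_k_search(query_emb, index_emb, names, k=3):
    """Return, for each query, the k best (name, score) pairs, best first.

    query_emb: (q, d) array, index_emb: (n, d) array; rows assumed L2-normalized,
    so the dot products are cosine similarities.
    """
    scores = query_emb @ index_emb.T             # (q, n) similarity of every query to every image
    k = min(k, index_emb.shape[0])               # cannot return more images than indexed
    order = np.argsort(-scores, axis=1)[:, :k]   # indices sorted by decreasing score
    return [[(names[j], float(scores[i, j])) for j in row] for i, row in enumerate(order)]

# Toy index of three orthogonal "image" embeddings
names = ["chat_roux", "tour_eiffel", "voiture_sport"]
index_emb = np.eye(3)
queries = np.array([[0.8, 0.6, 0.0],    # mostly the first image, somewhat the second
                    [0.0, 0.0, 1.0]])   # exactly the third image
results = top_k_search(queries, index_emb, names, k=2)
print(results)
```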

In [None]:
# Create the search engine
search_engine = ImageSearchEngine(model, preprocess, device)
search_engine.index_images(images)

In [None]:
# Search test
queries = [
    "a cute animal",
    "a famous monument in Paris",
    "a tropical vacation destination",
    "a fast red sports car",
]

for query in queries:
    print(f"\n🔍 Search: '{query}'")
    results = search_engine.search(query, top_k=2)
    
    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    for idx, (name, score, img) in enumerate(results):
        axes[idx].imshow(img)
        axes[idx].set_title(f"{name}\n(score: {score:.3f})")
        axes[idx].axis('off')
    plt.suptitle(f"Query: {query}")
    plt.tight_layout()
    plt.show()

---

## 5. Searching in French

In [None]:
# CLIP's training data is predominantly English, so it handles French only partially
# (some French text appears in its web data): expect lower scores and some errors
queries_fr = [
    "un chat mignon",
    "la tour eiffel à paris",
    "une plage paradisiaque",
    "une voiture de course",
]

for query in queries_fr:
    results = search_engine.search(query, top_k=1)
    name, score, _ = results[0]
    print(f"'{query}' → {name} (score: {score:.3f})")

In [None]:
# For better results in French, we can translate the query first,
# as in Project A

from transformers import pipeline

translator = pipeline("translation_fr_to_en", model="Helsinki-NLP/opus-mt-fr-en")

def search_french(engine, query_fr, top_k=3):
    """Search with FR -> EN translation."""
    query_en = translator(query_fr, max_length=100)[0]['translation_text']
    print(f"FR: {query_fr}")
    print(f"EN: {query_en}")
    return engine.search(query_en, top_k)

# Test
results = search_french(search_engine, "un animal domestique adorable")
print(f"\nResult: {results[0][0]} (score: {results[0][1]:.3f})")

---

## 6. Embedding Analysis

In [None]:
# Visualize the similarity between every pair of images
import seaborn as sns

# Embeddings are unit-norm, so this matrix product gives cosine similarities
similarity_matrix = (search_engine.image_embeddings @ search_engine.image_embeddings.T).float().cpu().numpy()

plt.figure(figsize=(8, 6))
sns.heatmap(similarity_matrix, 
            xticklabels=[n.replace('_', '\n') for n in search_engine.image_names],
            yticklabels=[n.replace('_', '\n') for n in search_engine.image_names],
            annot=True, fmt='.2f', cmap='Blues')
plt.title("Image-to-image similarity (CLIP embeddings)")
plt.tight_layout()
plt.show()
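This matrix is symmetric and its diagonal is 1 (each image compared with itself), so all the information sits in the upper triangle. A possible NumPy sketch that lists the most similar pairs of distinct images, e.g. to spot near-duplicates; `most_similar_pairs` is a hypothetical helper and the toy embeddings are made up:

```python
import numpy as np

def most_similar_pairs(emb, names, top=3):
    """Return the `top` most similar pairs of distinct items as (name_a, name_b, score).

    emb: (n, d) array of L2-normalized embeddings.
    """
    sim = emb @ emb.T
    rows, cols = np.triu_indices(len(names), k=1)   # each unordered pair once, diagonal excluded
    pair_scores = sim[rows, cols]
    order = np.argsort(-pair_scores)[:top]
    return [(names[rows[o]], names[cols[o]], float(pair_scores[o])) for o in order]

# Made-up embeddings: the two animals point in similar directions
emb = np.array([[1.0, 0.0, 0.0],
                [0.8, 0.6, 0.0],
                [0.0, 0.0, 1.0]])
pairs = most_similar_pairs(emb, ["chat_roux", "chien_labrador", "tour_eiffel"], top=2)
print(pairs)
```

With the notebook's data, `emb` could be `search_engine.image_embeddings.float().cpu().numpy()` and `names` could be `search_engine.image_names`.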

In [None]:
# t-SNE visualization of the embeddings
from sklearn.manifold import TSNE

# Add a few text embeddings
texts = ["a cat", "a dog", "a building", "a beach", "a mountain", "a car"]
text_inputs = clip.tokenize(texts).to(device)

with torch.no_grad():
    text_embeddings = model.encode_text(text_inputs)
    text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)

# Combine image and text embeddings (cast to float32 for scikit-learn)
all_embeddings = torch.cat([search_engine.image_embeddings, text_embeddings], dim=0).float().cpu().numpy()
all_labels = search_engine.image_names + texts
n_images = len(search_engine.image_names)

# t-SNE (perplexity must be smaller than the number of points, about 12 here)
tsne = TSNE(n_components=2, perplexity=3, random_state=42)
embeddings_2d = tsne.fit_transform(all_embeddings)

# Plot: one scatter call per modality so that the legend labels are correct
plt.figure(figsize=(10, 8))
plt.scatter(embeddings_2d[:n_images, 0], embeddings_2d[:n_images, 1], c='blue', marker='o', s=100, label='Images')
plt.scatter(embeddings_2d[n_images:, 0], embeddings_2d[n_images:, 1], c='red', marker='^', s=100, label='Texts')
for (x, y), label in zip(embeddings_2d, all_labels):
    plt.annotate(label.replace('_', ' '), (x, y), fontsize=9)

plt.title("CLIP embedding space (t-SNE)")
plt.xlabel("Dimension 1")
plt.ylabel("Dimension 2")
plt.legend()
plt.show()

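**Observation**: texts tend to lie close to the images they describe in the embedding space. With only about a dozen points, t-SNE gives a rough picture, and CLIP image and text embeddings often form two separate groups (the so-called "modality gap"), so look at relative proximity rather than exact overlap.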

---

## 7. Exercises

### Exercise 3: Extend the image collection
Add more images and try more complex queries.

### Exercise 4: Build an interactive demo
Let the user type in a query.

### Exercise 5: Analyze the limitations
Find cases where CLIP gets it wrong.
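For Exercise 3, `index_images` rebuilds the whole index on every call. One option is to encode only the new images and append their rows to the existing matrix. A NumPy sketch of that bookkeeping; `add_to_index` is a hypothetical helper, and with the notebook's torch tensors the equivalent of `np.vstack` would be `torch.cat(..., dim=0)`:

```python
import numpy as np

def add_to_index(index_emb, index_names, new_emb, new_names):
    """Append normalized embeddings for new images; return the updated (matrix, names).

    index_emb may be None for an empty index. Names must stay unique, because
    search results are looked up by name.
    """
    duplicates = set(index_names) & set(new_names)
    if duplicates:
        raise ValueError(f"Names already indexed: {sorted(duplicates)}")
    new_emb = new_emb / np.linalg.norm(new_emb, axis=1, keepdims=True)
    if index_emb is None:
        return new_emb, list(new_names)
    return np.vstack([index_emb, new_emb]), list(index_names) + list(new_names)

# Start from a one-image index, then add two (toy) embeddings
index_emb, names = add_to_index(None, [], np.array([[1.0, 0.0]]), ["chat_roux"])
index_emb, names = add_to_index(index_emb, names, np.array([[0.0, 2.0], [3.0, 4.0]]),
                                ["tour_eiffel", "plage_tropicale"])
print(index_emb.shape, names)
```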

In [None]:
# Space for your experiments

# Exercise 3: add images
# nouvelles_images = {
#     "name": load_image_from_url("url"),
# }

# Exercise 4: interactive demo
def demo_recherche():
    print("\n" + "="*50)
    print("IMAGE SEARCH - Interactive Demo")
    print("="*50)
    print("Enter a text description")
    print("Type 'quit' to exit")
    print("="*50)
    
    while True:
        query = input("\nSearch: ")
        if query.strip().lower() == 'quit':
            break
        
        results = search_engine.search(query, top_k=3)
        
        print(f"\nResults for '{query}':")
        for name, score, _ in results:
            print(f"  {name}: {score:.3f}")

# Uncomment to try it
# demo_recherche()

In [None]:
# Exercise 5: edge cases
# Try these ambiguous or difficult queries

difficult_queries = [
    "something blue",  # Ambiguous
    "a happy scene",   # Subjective
    "danger",          # Abstract
    "the number 5",    # Conceptual
]

print("Difficult query tests:")
for query in difficult_queries:
    results = search_engine.search(query, top_k=1)
    name, score, _ = results[0]
    print(f"  '{query}' → {name} ({score:.3f})")
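`search` always returns the closest image, even when nothing in the collection matches the query (as may happen with "the number 5"). A simple safeguard is a minimum similarity below which the engine answers "no result". A sketch; the 0.25 threshold is an arbitrary placeholder that would need to be calibrated on your own queries, and `best_match_or_none` is a hypothetical helper:

```python
import numpy as np

def best_match_or_none(scores, names, min_score=0.25):
    """Return (name, score) for the best image, or None if its score is below min_score.

    min_score = 0.25 is a placeholder, not a recommended value for CLIP.
    """
    scores = np.asarray(scores, dtype=float)
    best = int(np.argmax(scores))
    if scores[best] < min_score:
        return None
    return names[best], float(scores[best])

names = ["chat_roux", "tour_eiffel", "voiture_sport"]
print(best_match_or_none([0.31, 0.18, 0.20], names))  # best score is above the threshold
print(best_match_or_none([0.19, 0.21, 0.17], names))  # nothing reaches the threshold
```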

---

## 8. Conclusion

### What You Learned

1. **CLIP** aligns images and texts in a shared vector space
2. **Zero-shot**: classification without task-specific training
3. **Semantic search**: finding images from a text description
4. **Multimodal**: combining vision and language

### Connection to Current Developments

CLIP played a key role in text-to-image generation:
- **DALL-E**: text-to-image generation (the first version used CLIP to rerank its samples; DALL-E 2 generates images from CLIP embeddings)
- **Midjourney**: generative art (a closed system whose architecture has not been published)
- **Stable Diffusion**: open-source image generation (version 1 conditions generation on CLIP's text encoder)

### Limitations

- Dataset bias (largely English-language web data)
- Difficulty with abstract concepts and systematic tasks such as counting objects
- Poor grasp of context and humor

### Going Further

- Explore other CLIP models (ViT-L/14, etc.; `clip.available_models()` lists them)
- Use CLIP to filter images
- Combine CLIP with image generation