In [12]:
import os
import json
import cv2
import torch
import numpy as np
from PIL import Image
import chromadb
import clip
import shutil
import base64
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt

In [13]:
# Configuration
PROJECT_ROOT = "C:/Users/SANGLI Kenneth/Documents/project_root"
VIDEO_DIR = os.path.join(PROJECT_ROOT, "videos")
ANNOTATION_FILE = os.path.join(PROJECT_ROOT, "annotations.json")
SELECTED_SUBSET_FILE = os.path.join(PROJECT_ROOT, "annotations/selected_subset.json")
OUTPUT_DIR = os.path.join(PROJECT_ROOT, "output")
FPS = 1  # Extraire 1 frame par vidéo (frame centrale)
COLLECTION_NAME = "kinetics_700_subset"

In [14]:
# Créer le dossier de sortie
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [15]:
# Créer le client ChromaDB
client = chromadb.PersistentClient(path=os.path.join(PROJECT_ROOT, "chroma_db"))

In [16]:
# Charger le modèle CLIP
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device)

In [17]:
# Définir la fonction d'embedding pour ChromaDB
class ClipEmbeddingFunction:
    def __call__(self, input):
        if not isinstance(input, torch.Tensor):
            raise ValueError("Input to embedding function must be a torch.Tensor")
        with torch.no_grad():
            return clip_model.encode_image(input).cpu().numpy()

In [18]:
# Créer ou récupérer la collection ChromaDB
embedding_function = ClipEmbeddingFunction()
collection = client.get_or_create_collection(
    name=COLLECTION_NAME,
    embedding_function=embedding_function
)

def extract_single_frame(video_path):
    """Extrait la frame centrale d'une vidéo."""
    print(f"Extraction de la frame centrale de : {video_path}")
    cap = cv2.VideoCapture(video_path)
    
    if not cap.isOpened():
        print(f"Erreur : Impossible d'ouvrir la vidéo {video_path}")
        return None
    
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total_frames <= 0:
        print(f"Erreur : Aucune frame dans la vidéo {video_path}")
        cap.release()
        return None
    
    # Sélectionner la frame centrale
    central_frame_idx = total_frames // 2
    cap.set(cv2.CAP_PROP_POS_FRAMES, central_frame_idx)
    success, frame = cap.read()
    if not success:
        print(f"Erreur : Impossible de lire la frame centrale de {video_path}")
        cap.release()
        return None
    
    # Convertir en image PIL et sauvegarder
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(frame_rgb)
    frame_filename = os.path.join(OUTPUT_DIR, f"frame_{os.path.basename(video_path).split('.')[0]}.png")
    pil_image.save(frame_filename)
    print(f"Frame extraite : {frame_filename}")
    
    cap.release()
    return frame_filename

In [19]:
def load_kinetics_data(annotation_file, video_dir):
    """Charge les annotations et extrait une frame par vidéo."""
    with open(annotation_file, 'r') as f:
        annotations = json.load(f)
    
    frame_paths = []
    labels = []
    video_ids = []
    
    for video_id, info in annotations.items():
        video_path = os.path.join(video_dir, info['path'])
        label = info['label']
        frame_path = extract_single_frame(video_path)
        if frame_path is None:
            print(f"Aucune frame extraite pour {video_path}. Passage à la suivante.")
            continue
        frame_paths.append(frame_path)
        labels.append(label)
        video_ids.append(video_id)
    
    return frame_paths, labels, video_ids

def encode_image(image_path):
    """Encode une image avec CLIP et normalise l'embedding."""
    image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = clip_model.encode_image(image)
    # Normalisation explicite
    image_embedding = image_features / image_features.norm(dim=-1, keepdim=True)
    return image_embedding.cpu().numpy()

def encode_text(text):
    """Encode un texte avec CLIP et normalise l'embedding."""
    text_inputs = clip.tokenize([text]).to(device)
    with torch.no_grad():
        text_features = clip_model.encode_text(text_inputs)
    # Normalisation explicite
    text_embedding = text_features / text_features.norm(dim=-1, keepdim=True)
    return text_embedding.cpu().numpy()

def calculate_metrics(similarity_matrix, labels, text_labels):
    """Calcule les métriques de performance."""
    num_samples = similarity_matrix.shape[0]
    
    precision_at_1 = 0
    precision_at_3 = 0
    correct_distances = []
    incorrect_distances = []
    
    for i in range(num_samples):
        true_label = labels[i]
        true_idx = text_labels.index(true_label)
        
        similarities = similarity_matrix[i]
        sorted_indices = np.argsort(similarities)[::-1]
        
        if sorted_indices[0] == true_idx:
            precision_at_1 += 1
        if true_idx in sorted_indices[:3]:
            precision_at_3 += 1
        
        correct_distances.append(similarities[true_idx])
        incorrect_indices = [idx for idx in range(len(text_labels)) if idx != true_idx]
        incorrect_distances.extend(similarities[incorrect_indices])
    
    precision_at_1 = precision_at_1 / num_samples
    precision_at_3 = precision_at_3 / num_samples
    avg_correct_distance = np.mean(correct_distances)
    avg_incorrect_distance = np.mean(incorrect_distances)
    
    return {
        "precision@1": precision_at_1,
        "precision@3": precision_at_3,
        "avg_correct_distance": avg_correct_distance,
        "avg_incorrect_distance": avg_incorrect_distance
    }

In [20]:
def generate_results_html(frame_paths, text_queries, labels, video_ids, similarity_matrix, metrics):
    """Génère un fichier HTML avec les résultats."""
    def image_to_data_uri(img_path):
        try:
            with open(img_path, "rb") as img_file:
                encoded = base64.b64encode(img_file.read()).decode('utf-8')
                return encoded
        except Exception as e:
            print(f"Erreur lors de l'encodage de {img_path} : {e}")
            return None

    similarity_matrix_np = similarity_matrix
    
    html = """
    <html>
    <head>
        <title>CLIP Results with Kinetics-700</title>
        <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/css/bootstrap.min.css" rel="stylesheet">
        <style>
            body {{ padding: 20px; }}
            table {{ width: 100%; }}
            th, td {{ vertical-align: middle; text-align: center; }}
        </style>
    </head>
    <body>
    <div class="container">
        <h1 class="my-4">CLIP Results with Kinetics-700</h1>
        <h2>Métriques</h2>
        <table class="table table-bordered">
            <tr><th>Précision@1</th><td>{0:.2f}</td></tr>
            <tr><th>Précision@3</th><td>{1:.2f}</td></tr>
            <tr><th>Distance moyenne (correcte)</th><td>{2:.2f}</td></tr>
            <tr><th>Distance moyenne (incorrecte)</th><td>{3:.2f}</td></tr>
        </table>
        <h2>Résultats détaillés</h2>
        <table class="table table-bordered table-striped">
        <thead class="table-light">
        <tr>
            <th>Video ID</th>
            <th>Frame</th>
            <th>Ground Truth Label</th>
            <th>Best Matched Text</th>
    """.format(
        metrics['precision@1'],
        metrics['precision@3'],
        metrics['avg_correct_distance'],
        metrics['avg_incorrect_distance']
    )

    for text in text_queries:
        html += f'<th>{text}</th>'
    html += "</tr></thead><tbody>"
    
    for i, (img_path, label, video_id) in enumerate(zip(frame_paths, labels, video_ids)):
        img_data_uri = image_to_data_uri(img_path)
        img_display = f"<img src='data:image/png;base64,{img_data_uri}' width='100'>" if img_data_uri else "Image indisponible"
        
        similarities = similarity_matrix_np[i, :]
        best_match_index = np.argmax(similarities)
        best_matched_text = text_queries[best_match_index]
        
        html += f"""
        <tr>
            <td>{video_id}</td>
            <td>{img_display}</td>
            <td>{label}</td>
            <td>{best_matched_text}</td>
        """
        for j, score in enumerate(similarities):
            html += f"<td class='bg-info'><b>{score:.2f}</b></td>" if j == best_match_index else f"<td>{score:.2f}</td>"
        html += "</tr>"
    
    html += """
        </tbody>
        </table>
    </div>
    </body>
    </html>
    """
    
    with open(os.path.join(OUTPUT_DIR, "results_kinetics.html"), "w") as f:
        f.write(html)
    print(f"Résultats sauvegardés dans {os.path.join(OUTPUT_DIR, 'results_kinetics.html')}")

def plot_similarity_distribution(correct_distances, incorrect_distances):
    """Affiche la distribution des similarités."""
    plt.figure(figsize=(10, 6))
    plt.hist(correct_distances, bins=30, alpha=0.5, label='Correct Pairs', color='green')
    plt.hist(incorrect_distances, bins=30, alpha=0.5, label='Incorrect Pairs', color='red')
    plt.title('Distribution des similarités cosinus')
    plt.xlabel('Similarité cosinus')
    plt.ylabel('Fréquence')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(OUTPUT_DIR, 'similarity_distribution.png'))
    plt.close()
    print(f"Graphique sauvegardé dans {os.path.join(OUTPUT_DIR, 'similarity_distribution.png')}")

def search_similar_images(query, top_k=5):
    """Recherche les images les plus similaires à une requête textuelle dans ChromaDB."""
    # Encoder et normaliser la requête
    query_embedding = encode_text(query)
    
    # Effectuer la recherche dans ChromaDB
    results = collection.query(
        query_embeddings=[query_embedding.tolist()[0]],
        n_results=top_k,
        include=["metadatas", "embeddings"]  # Récupérer les embeddings pour recalculer la similarité
    )
    
    # Extraire les résultats
    image_paths = []
    similarities = []
    labels = []
    video_ids = []
    
    # Recalculer la similarité manuellement
    for id_, embedding, metadata in zip(results['ids'][0], results['embeddings'][0], results['metadatas'][0]):
        # Convertir les embeddings en numpy arrays
        stored_embedding = np.array(embedding).reshape(1, -1)
        # Recalculer la similarité cosinus
        similarity = cosine_similarity(query_embedding, stored_embedding)[0][0]
        # S'assurer que la similarité est dans la plage [0, 1]
        similarity = max(0, min(1, similarity))
        
        image_paths.append(metadata['frame'])
        similarities.append(similarity)
        labels.append(metadata['label'])
        video_ids.append(id_)
    
    # Trier les résultats par similarité décroissante
    sorted_results = sorted(zip(image_paths, similarities, labels, video_ids), key=lambda x: x[1], reverse=True)
    image_paths, similarities, labels, video_ids = zip(*sorted_results) if sorted_results else ([], [], [], [])
    
    return list(image_paths), list(similarities), list(labels), list(video_ids)

def display_search_results(query, image_paths, similarities, labels, video_ids):
    """Génère un fichier HTML pour afficher les résultats de la recherche."""
    def image_to_data_uri(img_path):
        try:
            with open(img_path, "rb") as img_file:
                encoded = base64.b64encode(img_file.read()).decode('utf-8')
                return encoded
        except Exception as e:
            print(f"Erreur lors de l'encodage de {img_path} : {e}")
            return None

    html = """
    <html>
    <head>
        <title>Recherche d'images similaires - Kinetics-700</title>
        <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/css/bootstrap.min.css" rel="stylesheet">
        <style>
            body {{ padding: 20px; }}
            table {{ width: 100%; }}
            th, td {{ vertical-align: middle; text-align: center; }}
        </style>
    </head>
    <body>
    <div class="container">
        <h1 class="my-4">Résultats de la recherche pour : "{}"</h1>
        <table class="table table-bordered table-striped">
        <thead class="table-light">
        <tr>
            <th>Video ID</th>
            <th>Frame</th>
            <th>Label</th>
            <th>Similarité</th>
        </tr>
        </thead>
        <tbody>
    """.format(query)
    
    for img_path, similarity, label, video_id in zip(image_paths, similarities, labels, video_ids):
        img_data_uri = image_to_data_uri(img_path)
        img_display = f"<img src='data:image/png;base64,{img_data_uri}' width='100'>" if img_data_uri else "Image indisponible"
        
        html += f"""
        <tr>
            <td>{video_id}</td>
            <td>{img_display}</td>
            <td>{label}</td>
            <td>{similarity:.4f}</td>
        </tr>
        """
    
    html += """
        </tbody>
        </table>
    </div>
    </body>
    </html>
    """
    
    output_file = os.path.join(OUTPUT_DIR, f"search_results_{query.replace(' ', '_')}.html")
    with open(output_file, "w") as f:
        f.write(html)
    print(f"Résultats de la recherche sauvegardés dans {output_file}")

In [None]:
def main():
    # Étape 1 : Charger les données Kinetics-700
    print("Étape 1 : Chargement des données Kinetics-700...")
    frame_paths, labels, video_ids = load_kinetics_data(ANNOTATION_FILE, VIDEO_DIR)
    if not frame_paths:
        print("Erreur : Aucune frame n'a été extraite.")
        return
    
    # Étape 2 : Définir les queries descriptives
    print("Étape 2 : Génération des queries descriptives...")
    concepts = sorted(list(set(labels)))
    query_dict = {
        "playing basketball": "A group of people playing basketball on an outdoor court with a hoop",
        "playing soccer": "A group of people kicking a soccer ball on a grassy field with goalposts",
        "cooking": "A person in a kitchen chopping vegetables on a cutting board",
        "painting": "A person painting a canvas with colorful paints in a studio",
        "swimming": "A person swimming in a blue pool under the sun",
        "reading a book": "A person sitting on a couch reading a book in a cozy room",
        "playing guitar": "A person playing an acoustic guitar in a park",
        "riding a bike": "A person riding a bike on a sunny trail with trees",
        "doing yoga": "A person doing yoga on a mat in a calm studio",
        "writing on a whiteboard": "A person writing on a whiteboard in a classroom",
        # Complétez avec les autres classes...
    }
    text_queries = [query_dict.get(c, f"A person performing the action of {c} in a typical setting") for c in concepts]
    print(f"Concepts : {concepts}")
    print(f"Queries : {text_queries}")

    # Étape 3 : Calculer les embeddings et stocker dans ChromaDB
    print("Étape 3 : Calcul des embeddings et stockage dans ChromaDB...")
    image_embeddings = []
    for i, frame_path in enumerate(frame_paths):
        embedding = encode_image(frame_path)
        image_embeddings.append(embedding[0])
        collection.add(
            ids=[video_ids[i]],
            embeddings=[embedding[0].tolist()],
            metadatas=[{"frame": frame_path, "label": labels[i]}]
        )
        print(f"Embedding ajouté pour {video_ids[i]}")
    
    # Étape 4 : Calculer la matrice de similarité
    print("Étape 4 : Calcul de la matrice de similarité...")
    text_embeddings = [encode_text(query)[0] for query in text_queries]
    similarity_matrix = cosine_similarity(image_embeddings, text_embeddings)
    
    # Étape 5 : Calculer les métriques
    print("Étape 5 : Calcul des métriques...")
    metrics = calculate_metrics(similarity_matrix, labels, concepts)
    print("Métriques :")
    print(f"Précision@1 : {metrics['precision@1']:.2f}")
    print(f"Précision@3 : {metrics['precision@3']:.2f}")
    print(f"Distance moyenne (correcte) : {metrics['avg_correct_distance']:.2f}")
    print(f"Distance moyenne (incorrecte) : {metrics['avg_incorrect_distance']:.2f}")
    
    # Étape 6 : Visualiser les résultats
    print("Étape 6 : Génération des visualisations...")
    generate_results_html(frame_paths, text_queries, labels, video_ids, similarity_matrix, metrics)
    
    # Étape 7 : Afficher la distribution des similarités
    correct_distances = [similarity_matrix[i, concepts.index(labels[i])] for i in range(len(labels))]
    incorrect_distances = [similarity_matrix[i, j] for i in range(len(labels)) for j in range(len(concepts)) if j != concepts.index(labels[i])]
    plot_similarity_distribution(correct_distances, incorrect_distances)
    
    # Étape 8 : Interface de recherche interactive
    print("Étape 8 : Recherche interactive...")
    while True:
        query = input("Entrez une requête textuelle (ou 'quitter' pour arrêter) : ")
        if query.lower() == 'quitter':
            break
        if not query.strip():
            print("Erreur : La requête ne peut pas être vide.")
            continue
        
        top_k = input("Combien de résultats voulez-vous (par défaut 5) ? ")
        try:
            top_k = int(top_k) if top_k.strip() else 5
            if top_k <= 0:
                raise ValueError
        except ValueError:
            print("Erreur : Entrez un nombre positif. Utilisation de la valeur par défaut (5).")
            top_k = 5
        
        image_paths, similarities, labels, video_ids = search_similar_images(query, top_k)
        if not image_paths:
            print("Aucun résultat trouvé. Vérifiez la collection ChromaDB.")
            continue
        
        display_search_results(query, image_paths, similarities, labels, video_ids)
        print(f"Ouvrez {os.path.join(OUTPUT_DIR, f'search_results_{query.replace(' ', '_')}.html')} pour voir les résultats.")

if __name__ == "__main__":
    main()

Étape 1 : Chargement des données Kinetics-700...
Extraction de la frame centrale de : C:/Users/SANGLI Kenneth/Documents/project_root\videos\video_1.mp4
Frame extraite : C:/Users/SANGLI Kenneth/Documents/project_root\output\frame_video_1.png
Extraction de la frame centrale de : C:/Users/SANGLI Kenneth/Documents/project_root\videos\video_2.mp4
Frame extraite : C:/Users/SANGLI Kenneth/Documents/project_root\output\frame_video_2.png
Extraction de la frame centrale de : C:/Users/SANGLI Kenneth/Documents/project_root\videos\video_3.mp4
Frame extraite : C:/Users/SANGLI Kenneth/Documents/project_root\output\frame_video_3.png
Extraction de la frame centrale de : C:/Users/SANGLI Kenneth/Documents/project_root\videos\video_4.mp4
Frame extraite : C:/Users/SANGLI Kenneth/Documents/project_root\output\frame_video_4.png
Extraction de la frame centrale de : C:/Users/SANGLI Kenneth/Documents/project_root\videos\video_5.mp4
Frame extraite : C:/Users/SANGLI Kenneth/Documents/project_root\output\frame_vid

KeyError: ' padding'