# Embedding Extraction and k-NN Classification with FAISS and CLIP

In this notebook, we extract image embeddings with the CLIP model and store them in a FAISS index for fast nearest-neighbor search. Finally, we implement a k-NN classifier that predicts the class of an image from its nearest neighbors in the index.

## Contents:
- Image loading and preprocessing
- Embedding extraction with CLIP
- Storing the embeddings in FAISS
- k-NN classification based on the embeddings
- Example usage of the k-NN classifier

In [1]:
# Import the required libraries
import os
import torch
import clip
from PIL import Image
import faiss
import csv
import numpy as np
from tqdm import tqdm
from sklearn.manifold import TSNE
import plotly.express as px

### 1. Image preprocessing and the CLIP model

In this section, we load and preprocess the images with PIL and extract their embeddings with the pretrained CLIP model. CLIP is designed to learn joint representations of images and text.
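
CLIP is trained to compare embeddings by cosine similarity, so one common option is to L2-normalize each vector before indexing it. The sketch below uses NumPy only, with random vectors standing in for CLIP embeddings; `l2_normalize` is a hypothetical helper, not part of CLIP or FAISS. It illustrates that, for unit-length vectors, the squared L2 distance equals `2 - 2 * cosine_similarity`, so an L2 index would rank neighbors in the same order as cosine similarity. Note that this notebook indexes the raw, unnormalized embeddings.

```python
import numpy as np

def l2_normalize(vectors: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale each row to unit L2 norm (hypothetical helper)."""
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    return (vectors / np.maximum(norms, eps)).astype(np.float32)

# Random vectors stand in for 512-dimensional CLIP embeddings (assumption)
rng = np.random.default_rng(0)
fake_embeddings = rng.normal(size=(4, 512)).astype(np.float32)
unit = l2_normalize(fake_embeddings)

# Compare two unit vectors both ways
a, b = unit[0], unit[1]
cosine = float(a @ b)
squared_l2 = float(np.sum((a - b) ** 2))
print(f"cosine={cosine:.4f}, squared_l2={squared_l2:.4f}, 2-2cos={2 - 2 * cosine:.4f}")
```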

We move up to the parent directory so that the images can be addressed with paths relative to the `data` folder.

In [2]:
os.chdir("../")

In [3]:
# Load the CLIP model and preprocess function
device = 'cpu'  # or 'cuda' if available
clip_model, preprocess = clip.load("ViT-B/32", device=device)

# Initialize FAISS index (assuming 512-dimensional embeddings for ViT-B/32)
embedding_dim = 512
faiss_index = faiss.IndexFlatL2(embedding_dim)

# Folder containing the dataset
dataset_folder = "./data/"

# CSV to store mapping between embedding ID and class
csv_file = "./faiss/embeddings_mapping.csv"
csv_data = []

#### Building the FAISS index

In this part, we take the extracted embeddings and store them in a FAISS index. FAISS is a library optimized for similarity search over large collections of high-dimensional vectors.
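
To make the role of the index concrete, here is a minimal NumPy sketch of what an exhaustive L2 index such as `IndexFlatL2` computes: the squared Euclidean distance from each query to every stored vector, followed by a top-k selection. The function `brute_force_l2_search` and the toy vectors are hypothetical and only for illustration; FAISS does the same work with optimized native code.

```python
import numpy as np

def brute_force_l2_search(database: np.ndarray, queries: np.ndarray, k: int):
    """Return (squared distances, indices) of the k nearest rows for each query."""
    # Pairwise squared L2 distances, shape (n_queries, n_database)
    diffs = queries[:, None, :] - database[None, :, :]
    sq_dists = np.sum(diffs ** 2, axis=2)
    # Sort each row by distance and keep the k smallest
    order = np.argsort(sq_dists, axis=1)[:, :k]
    return np.take_along_axis(sq_dists, order, axis=1), order

# Toy 2-D vectors instead of 512-D embeddings (assumption for readability)
toy_database = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 3.0]], dtype=np.float32)
toy_queries = np.array([[0.9, 0.0]], dtype=np.float32)
toy_distances, toy_indices = brute_force_l2_search(toy_database, toy_queries, k=2)
print(toy_indices[0], toy_distances[0])
```

The values reported are squared distances, which is also why the distances printed later in this notebook are not square-rooted.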

In [4]:
# Iterate through the dataset folder
embedding_id = 0
class_folders = [f for f in os.listdir(dataset_folder) if os.path.isdir(os.path.join(dataset_folder, f))]

# Use tqdm to show progress over class folders
for class_folder in tqdm(class_folders, desc="Processing classes"):
    class_path = os.path.join(dataset_folder, class_folder)

    img_files = os.listdir(class_path)

    # Use tqdm to show progress over images in each class folder
    for img_file in tqdm(img_files, desc=f"Processing images in {class_folder}", leave=False):
        img_path = os.path.join(class_path, img_file)

        try:
            # Load and preprocess image
            img = Image.open(img_path)
            img_preprocessed = preprocess(img).unsqueeze(0).to(device)

            # Calculate the embedding (FAISS expects float32 vectors)
            with torch.no_grad():
                img_embedding = clip_model.encode_image(img_preprocessed).cpu().numpy().astype(np.float32)

            # Add embedding to FAISS index
            faiss_index.add(img_embedding)

            # Add mapping of embedding ID to class in the CSV
            csv_data.append([embedding_id, class_folder])
            embedding_id += 1
        except Exception as e:
            print(f"Error processing {img_path}: {e}")


# Save the FAISS index (create the output folder first if it does not exist)
os.makedirs("./faiss", exist_ok=True)
faiss.write_index(faiss_index, "./faiss/image_embeddings.index")

# Write CSV file with mappings
with open(csv_file, mode='w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['EmbeddingID', 'Class'])  # Header
    writer.writerows(csv_data)

print(f"Finished processing. FAISS index saved as 'image_embeddings.index' and CSV saved as '{csv_file}'.")


Processing classes: 100%|██████████| 5/5 [00:31<00:00,  6.38s/it]

Finished processing. FAISS index saved as 'image_embeddings.index' and CSV saved as './faiss/embeddings_mapping.csv'.
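
Because the flat index assigns IDs in insertion order, the CSV mapping is only valid if its `EmbeddingID` column runs 0, 1, 2, ... without gaps. The sketch below shows one way such a sanity check could look. `load_mapping` is a hypothetical helper, and the in-memory CSV with the made-up label `other_class` stands in for the real file.

```python
import csv
import io

def load_mapping(csv_text_stream):
    """Read EmbeddingID -> Class rows and check that IDs are 0..n-1 in order."""
    reader = csv.reader(csv_text_stream)
    next(reader)  # Skip header
    mapping = {}
    for expected_id, row in enumerate(reader):
        embedding_id, class_name = int(row[0]), row[1]
        if embedding_id != expected_id:
            raise ValueError(f"Expected ID {expected_id}, found {embedding_id}")
        mapping[embedding_id] = class_name
    return mapping

# In-memory sample in the same layout as embeddings_mapping.csv (made-up rows)
sample_csv = io.StringIO("EmbeddingID,Class\n0,pikachu\n1,pikachu\n2,other_class\n")
sample_mapping = load_mapping(sample_csv)
print(sample_mapping)
```

In the notebook, it would also make sense to compare the number of rows against `faiss_index.ntotal`.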





In [5]:
# retrieve the embeddings
faiss_index = faiss.read_index("./faiss/image_embeddings.index")

# Retrieve all embeddings from the FAISS index in a single call
all_embeddings = faiss_index.reconstruct_n(0, faiss_index.ntotal)

# Reduce the dimensionality of the embeddings for plotting
# (a fixed random_state makes the 2-D layout reproducible between runs)
tsne = TSNE(n_components=2, random_state=42)
embeddings_2d = tsne.fit_transform(all_embeddings)

# Plot the embeddings based on the class
csv_file = "./faiss/embeddings_mapping.csv"
class_labels = {}
with open(csv_file, mode='r') as f:
    reader = csv.reader(f)
    next(reader)  # Skip header
    for row in reader:
        embedding_id, class_label = row
        class_labels[int(embedding_id)] = class_label

# Create a DataFrame for plotting
import pandas as pd
df = pd.DataFrame(embeddings_2d, columns=["x", "y"])
df['Class'] = [class_labels[i] for i in range(len(all_embeddings))]
fig = px.scatter(df, x="x", y="y", color="Class", title="Image Embeddings")


In [6]:
fig.show()

In [8]:
# Load the FAISS index
faiss_index = faiss.read_index("./faiss/image_embeddings.index")

# Load the CSV mapping (EmbeddingID -> Class)
csv_file = "./faiss/embeddings_mapping.csv"
id_to_class = {}

with open(csv_file, mode='r') as f:
    reader = csv.reader(f)
    next(reader)  # Skip header
    for row in reader:
        embedding_id = int(row[0])
        class_name = row[1]
        id_to_class[embedding_id] = class_name

# Function to perform similarity search
def search_similar_images(input_image_path, k=5):
    # Load and preprocess input image
    img = Image.open(input_image_path)
    img_preprocessed = preprocess(img).unsqueeze(0).to(device)

    # Calculate the embedding for the input image (FAISS expects float32 vectors)
    with torch.no_grad():
        input_embedding = clip_model.encode_image(img_preprocessed).cpu().numpy().astype(np.float32)

    # Perform similarity search in FAISS index (IndexFlatL2 returns squared L2 distances)
    distances, indices = faiss_index.search(input_embedding, k)

    # Retrieve the top-k most similar classes and distances
    results = []
    for i, index in enumerate(indices[0]):
        class_name = id_to_class[index]
        distance = distances[0][i]
        results.append((class_name, distance))

    return results

# Example usage:
input_image_path = "test_images/00000005.png"  # Path to the input image
top_k_results = search_similar_images(input_image_path, k=5)

# Display the results
for i, (class_name, distance) in enumerate(top_k_results):
    print(f"Rank {i+1}: Class: {class_name}, Distance: {distance}")

Rank 1: Class: pikachu, Distance: 5.546635150909424
Rank 2: Class: pikachu, Distance: 7.80823278427124
Rank 3: Class: pikachu, Distance: 8.26430606842041
Rank 4: Class: pikachu, Distance: 8.26492691040039
Rank 5: Class: pikachu, Distance: 8.570350646972656
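
FAISS pads the result with index `-1` when it finds fewer than `k` neighbors, for example when `k` exceeds the number of indexed vectors. In that case a direct `id_to_class[index]` lookup would raise a `KeyError`. The sketch below shows one way to skip the padding. `labels_for_neighbors` is a hypothetical helper, and the toy arrays stand in for the output of `faiss_index.search`.

```python
import numpy as np

def labels_for_neighbors(indices_row, distances_row, mapping):
    """Pair each valid neighbor's class with its distance, skipping -1 padding."""
    results = []
    for idx, dist in zip(indices_row, distances_row):
        if idx < 0:  # FAISS uses -1 for "no neighbor found"
            continue
        results.append((mapping[int(idx)], float(dist)))
    return results

# Toy search output: two real neighbors and one padded slot (made-up values)
toy_mapping = {0: "pikachu", 1: "pikachu"}
toy_indices = np.array([[1, 0, -1]], dtype=np.int64)
toy_distances = np.array([[5.5, 7.8, 3.4e38]], dtype=np.float32)
toy_neighbors = labels_for_neighbors(toy_indices[0], toy_distances[0], toy_mapping)
print(toy_neighbors)
```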


In [9]:
from collections import Counter

id_to_class = {}

with open(csv_file, mode='r') as f:
    reader = csv.reader(f)
    next(reader)  # Skip header
    for row in reader:
        embedding_id = int(row[0])
        class_name = row[1]
        id_to_class[embedding_id] = class_name

# Function to perform k-NN classification
def knn_classify(input_image_path, k=5):
    # Load and preprocess input image
    img = Image.open(input_image_path)
    img_preprocessed = preprocess(img).unsqueeze(0).to(device)

    # Calculate the embedding for the input image (FAISS expects float32 vectors)
    with torch.no_grad():
        input_embedding = clip_model.encode_image(img_preprocessed).cpu().numpy().astype(np.float32)

    # Perform similarity search in FAISS index (IndexFlatL2 returns squared L2 distances)
    distances, indices = faiss_index.search(input_embedding, k)

    # Retrieve the classes of the k nearest neighbors
    nearest_classes = [id_to_class[idx] for idx in indices[0]]

    # Perform majority voting among the top-k classes
    class_counter = Counter(nearest_classes)
    predicted_class = class_counter.most_common(1)[0][0]  # Get the most common class

    return predicted_class, nearest_classes, distances[0]
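
With plain majority voting, `Counter.most_common` resolves ties in first-encountered order, so an even split between two classes is decided by which class appears first in the neighbor list. One possible alternative, sketched below, weights each vote by the inverse of its distance. This is not what `knn_classify` does; `weighted_vote` and the toy inputs are hypothetical.

```python
from collections import defaultdict

def weighted_vote(neighbor_classes, neighbor_distances, eps=1e-8):
    """Return the class with the largest sum of inverse-distance weights."""
    scores = defaultdict(float)
    for class_name, distance in zip(neighbor_classes, neighbor_distances):
        # eps avoids division by zero when a neighbor is an exact match
        scores[class_name] += 1.0 / (float(distance) + eps)
    return max(scores, key=scores.get), dict(scores)

# Two classes with two votes each: a plain count would be a tie (made-up values)
toy_classes = ["a", "b", "b", "a"]
toy_dists = [1.0, 4.0, 5.0, 2.0]
winner, vote_scores = weighted_vote(toy_classes, toy_dists)
print(winner, vote_scores)
```

Here class `a` wins because its two neighbors are closer, even though both classes receive the same number of votes.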


In [10]:
# Example usage:
input_image_path = "test_images/00000005.png"  # Path to the input image
predicted_class, nearest_classes, distances = knn_classify(input_image_path, k=5)

# Display the results
print(f"Predicted Class: {predicted_class}")
print(f"Nearest Neighbors Classes: {nearest_classes}")
print(f"Distances: {distances}")

Predicted Class: pikachu
Nearest Neighbors Classes: ['pikachu', 'pikachu', 'pikachu', 'pikachu', 'pikachu']
Distances: [5.546635 7.808233 8.264306 8.264927 8.570351]
