# Fashion Search — Indexation ASOS depuis Colab (GPU)

Ce notebook indexe le dataset [ASOS e-commerce](https://huggingface.co/datasets/UniqueData/asos-e-commerce-dataset) dans Weaviate avec **3 named vectors** :
- **fashion_clip** (512d) — `patrickjohncyh/fashion-clip`
- **marqo_clip** (512d) — `Marqo/marqo-fashionCLIP`
- **siglip2** (1152d) — `google/siglip2-so400m-patch14-384`

**Features :**
- Combined image + text embeddings (0.7 img + 0.3 txt) pour des vecteurs plus riches
- Descriptions texte pour BM25F search
- Checkpoint / reprise automatique
- Keep-alive Weaviate (ping toutes les 2 min)

**Architecture :**
- **Colab (GPU)** : encode chaque image et sa description texte avec les 3 modèles → pousse les vecteurs combinés dans Weaviate
- **GCP (CPU)** : Weaviate + app FastAPI pour servir les recherches

**Prérequis :**
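- Runtime GPU activé (Runtime → Change runtime type → T4 GPU ou mieux ; les réglages de la section 2 sont calibrés pour un A100 40 GB, les 3 modèles occupent ~5 GB de VRAM en FP32)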
- Weaviate accessible sur ta VM GCP (ports 8080 HTTP et 50051 gRPC ouverts dans le firewall)

## 1. Installation des dépendances

In [None]:
# Install deps (sans toucher au torch pre-installe sur Colab)
!pip install -q weaviate-client>=4.0 transformers Pillow datasets requests tqdm timm ftfy regex
!pip install -q --no-deps open-clip-torch

## 2. Configuration

Renseigne l'IP externe de ta VM GCP où tourne Weaviate.

**Important** : le port `8080` (HTTP) et `50051` (gRPC) doivent être ouverts dans le firewall GCP.

In [None]:
# ===========================================================================
# CONFIGURATION
# ===========================================================================

# Weaviate server running on the GCP VM -- put the VM's external IP here,
# e.g. "34.56.78.90".
GCP_EXTERNAL_IP = ""
WEAVIATE_HTTP_PORT, WEAVIATE_GRPC_PORT = 8080, 50051
COLLECTION_NAME = "FashionCollection"

# Model checkpoints, one per named vector.
MODEL_FASHION_CLIP = "patrickjohncyh/fashion-clip"
MODEL_MARQO_CLIP = "Marqo/marqo-fashionCLIP"
MODEL_SIGLIP2 = "google/siglip2-so400m-patch14-384"

# Named-vector keys in the Weaviate collection.
VECTOR_FASHION_CLIP, VECTOR_MARQO_CLIP, VECTOR_SIGLIP2 = "fashion_clip", "marqo_clip", "siglip2"
ALL_VECTOR_NAMES = [VECTOR_FASHION_CLIP, VECTOR_MARQO_CLIP, VECTOR_SIGLIP2]

# Throughput knobs (sized for an A100 40GB + 83GB RAM runtime).
MAX_ITEMS = None              # None -> whole dataset (~45k products)
MAX_IMAGES_PER_PRODUCT = 1    # images kept per product
WEAVIATE_BATCH_SIZE = 200     # small flushes -> frequent checkpoints
DOWNLOAD_WORKERS = 64         # parallel HTTP downloads
PREFETCH_SIZE = 200           # small chunks -> progress bar moves in real time
THUMBNAIL_SIZE = 150          # max thumbnail side, in pixels
IMAGE_WEIGHT = 0.7            # image share of the combined image+text vector
RECREATE = False              # False = resume, True = drop and rebuild everything

# Resume support.
CHECKPOINT_FILE = "/content/index_checkpoint.json"
KEEPALIVE_INTERVAL = 120      # seconds between keep-alive pings

## 3. Vérification GPU

In [None]:
import torch

# Pick the compute device once; later cells read the global `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"

if device == "cpu":
    print("Pas de GPU, l'indexation sera lente.")
else:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

print(f"Device: {device}")

## 4. Chargement des 3 modèles

In [None]:
from transformers import CLIPModel, CLIPProcessor, AutoModel, AutoProcessor
import open_clip

# 1) Fashion CLIP (512d) -- Hugging Face CLIP checkpoint.
print(f"Chargement de {MODEL_FASHION_CLIP}...")
fc_model = CLIPModel.from_pretrained(MODEL_FASHION_CLIP).to(device).eval()
fc_processor = CLIPProcessor.from_pretrained(MODEL_FASHION_CLIP)

# 2) Marqo FashionCLIP (512d) -- OpenCLIP checkpoint pulled from the HF hub.
print(f"Chargement de {MODEL_MARQO_CLIP}...")
_marqo_ref = f"hf-hub:{MODEL_MARQO_CLIP}"
mq_model, _, mq_preprocess = open_clip.create_model_and_transforms(_marqo_ref, device=device)
mq_tokenizer = open_clip.get_tokenizer(_marqo_ref)
mq_model.eval()  # switches the module to inference mode in place

# 3) SigLIP2-SO400M (1152d) -- loaded through the Auto* classes.
print(f"Chargement de {MODEL_SIGLIP2}...")
sl_model = AutoModel.from_pretrained(MODEL_SIGLIP2).to(device).eval()
sl_processor = AutoProcessor.from_pretrained(MODEL_SIGLIP2)

# Registry read by the encoding helpers. Each value is
# (api kind, model, processor-or-preprocess, openclip tokenizer, max text tokens);
# api kind "hf" = Hugging Face API, "openclip" = OpenCLIP API.
MODELS = {
    VECTOR_FASHION_CLIP: ("hf", fc_model, fc_processor, None, 77),
    VECTOR_MARQO_CLIP: ("openclip", mq_model, mq_preprocess, mq_tokenizer, 77),
    VECTOR_SIGLIP2: ("hf", sl_model, sl_processor, None, 64),
}

# Author's note: on an A100 40GB the three models take ~5GB in FP32, so FP16 is not needed.
if device == "cuda":
    vram_used = torch.cuda.memory_allocated() / 1024**3
    vram_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
    vram_note = f" — VRAM: {vram_used:.1f}/{vram_total:.0f} GB"
else:
    vram_note = "."
print(f"\n3 modeles charges sur {device}{vram_note}")

## 5. Connexion a Weaviate sur GCP

In [None]:
import weaviate
import json
import os
import threading
from datetime import datetime, timezone
from weaviate.classes.config import Configure, DataType, Property, VectorDistances

# Stop early if the configuration cell was not filled in.
assert GCP_EXTERNAL_IP, "Renseigne GCP_EXTERNAL_IP dans la cellule de configuration."

# Plain (non-TLS) HTTP + gRPC connection to the self-hosted Weaviate.
client = weaviate.connect_to_custom(
    http_host=GCP_EXTERNAL_IP, http_port=WEAVIATE_HTTP_PORT, http_secure=False,
    grpc_host=GCP_EXTERNAL_IP, grpc_port=WEAVIATE_GRPC_PORT, grpc_secure=False,
)

if client.is_ready():
    print(f"Connecte a Weaviate sur {GCP_EXTERNAL_IP}")
else:
    print("Echec de connexion")

# --- Keep-alive (ping every 2 min, auto-reconnect) ---
# Setting this event tells the keep-alive thread to exit.
_keepalive_stop = threading.Event()

def start_keepalive(wv_client, interval=KEEPALIVE_INTERVAL):
    """Start a daemon thread that pings Weaviate and reconnects when needed.

    Every ``interval`` seconds the thread calls ``wv_client.is_ready()``. If the
    call raises or returns False, the thread closes the client and re-opens
    that same client object with ``wv_client.connect()``. Re-opening in place
    means code holding ``client`` keeps using it. A failed reconnect is logged
    and retried on the next tick instead of killing the thread. The loop ends
    once the module-level ``_keepalive_stop`` event is set.

    Args:
        wv_client: the Weaviate client created with ``weaviate.connect_to_custom``.
        interval: seconds between two pings.

    Returns:
        The started daemon ``threading.Thread``.
    """
    def _run():
        # Event.wait returns True as soon as the stop event is set -> exit.
        while not _keepalive_stop.wait(interval):
            try:
                if wv_client.is_ready():
                    continue
                reason = "not ready"
            except Exception as e:
                reason = e
            print(f"\n[keep-alive] Reconnexion... ({reason})")
            try:
                wv_client.close()
            except Exception:
                pass  # best effort: the connection may already be dead
            try:
                # connect_to_custom() is a module-level factory, not a client
                # method; connect() re-opens this client in place.
                wv_client.connect()
            except Exception as e:
                # Keep the thread alive; the next tick will try again.
                print(f"[keep-alive] Echec de reconnexion ({e}), nouvel essai dans {interval}s")

    t = threading.Thread(target=_run, daemon=True)
    t.start()
    return t


start_keepalive(client)
print("Keep-alive demarre (ping toutes les 2 min)")

# --- Checkpoint helpers ---
def load_checkpoint():
    """Return the saved indexing checkpoint, or a fresh one.

    The returned dict always contains ``last_index`` (-1 = nothing indexed yet),
    ``indexed_count`` and ``errors``. Keys missing from a partial or older
    checkpoint file fall back to these defaults, so callers can index the dict
    directly. A missing file silently yields the defaults. An unreadable or
    corrupt file (or one that does not hold a JSON object) is reported and then
    treated as absent.
    """
    state = {"last_index": -1, "indexed_count": 0, "errors": 0}
    try:
        with open(CHECKPOINT_FILE) as f:
            saved = json.load(f)
    except FileNotFoundError:
        return state
    except (OSError, ValueError) as e:  # ValueError covers JSON and UTF-8 decode errors
        print(f"Checkpoint illisible ({e}), on repart de zero.")
        return state
    if isinstance(saved, dict):
        state.update(saved)
    return state

def save_checkpoint(last_index, indexed_count, errors):
    """Persist indexing progress (plus a UTC timestamp) to CHECKPOINT_FILE as JSON.

    The JSON is first written to a sibling ``.tmp`` file and then moved over the
    checkpoint with ``os.replace``. An interruption mid-write (Colab disconnect,
    Ctrl-C) therefore leaves the previous checkpoint intact. Writing in place
    could leave a truncated file, which would silently reset the resume point.

    Args:
        last_index: highest dataset row index whose objects were flushed.
        indexed_count: number of objects indexed so far.
        errors: number of errors so far.
    """
    state = {
        "last_index": last_index,
        "indexed_count": indexed_count,
        "errors": errors,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    tmp_path = f"{CHECKPOINT_FILE}.tmp"
    with open(tmp_path, "w") as f:
        json.dump(state, f)
    os.replace(tmp_path, CHECKPOINT_FILE)

def clear_checkpoint():
    """Delete the checkpoint file; do nothing when it does not exist."""
    try:
        os.remove(CHECKPOINT_FILE)
    except FileNotFoundError:
        pass

## 6. Creation du schema

In [None]:
def create_schema(wv_client, collection_name, vector_names=None):
    """Create ``collection_name`` with the product properties and named vectors.

    Every named vector is self-provided (``NamedVectors.none``: Weaviate does
    not vectorize; the notebook pushes the vectors) and indexed with HNSW
    using cosine distance. ``vector_names`` defaults to ALL_VECTOR_NAMES.
    """
    names = ALL_VECTOR_NAMES if vector_names is None else vector_names

    def _self_provided(vec_name):
        return Configure.NamedVectors.none(
            name=vec_name,
            vector_index_config=Configure.VectorIndex.hnsw(distance_metric=VectorDistances.COSINE),
        )

    # (property name, type), in schema order.
    schema = [
        ("filename", DataType.TEXT),
        ("path", DataType.TEXT),
        ("thumbnail_base64", DataType.TEXT),
        ("width", DataType.INT),
        ("height", DataType.INT),
        ("indexed_at", DataType.TEXT),
        ("product_id", DataType.TEXT),
        ("product_name", DataType.TEXT),
        ("category", DataType.TEXT),
        ("color", DataType.TEXT),
        ("size", DataType.TEXT),
        ("price", DataType.NUMBER),
        ("brand", DataType.TEXT),
        ("product_url", DataType.TEXT),
        ("description", DataType.TEXT),
        ("image_index", DataType.INT),
        ("gender", DataType.TEXT),
    ]

    wv_client.collections.create(
        name=collection_name,
        properties=[Property(name=prop, data_type=dtype) for prop, dtype in schema],
        vectorizer_config=[_self_provided(n) for n in names],
    )
    print(f"Collection '{collection_name}' creee avec vecteurs: {names}")


# (Re)build the collection when asked to, or when it does not exist yet; a new
# collection also discards any stale checkpoint. Otherwise keep the data.
_collection_exists = client.collections.exists(COLLECTION_NAME)
if RECREATE and _collection_exists:
    client.collections.delete(COLLECTION_NAME)
    print(f"Collection '{COLLECTION_NAME}' supprimee.")
if RECREATE or not _collection_exists:
    create_schema(client, COLLECTION_NAME)
    clear_checkpoint()
else:
    print(f"Collection '{COLLECTION_NAME}' existe deja, on garde les donnees.")

## 7. Chargement du dataset ASOS

In [None]:
from datasets import load_dataset

print("Chargement du dataset ASOS...")
dataset = load_dataset("UniqueData/asos-e-commerce-dataset", split="train")
print(f"Dataset charge: {len(dataset)} produits")

# Optional cap for quick runs (None / 0 keeps every product).
if MAX_ITEMS:
    n_keep = min(MAX_ITEMS, len(dataset))
    dataset = dataset.select(range(n_keep))
    print(f"Limite a {len(dataset)} produits")

## 8. Fonctions utilitaires

In [None]:
import ast
import base64
import io
import re
from contextlib import nullcontext
from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np
import requests
from PIL import Image


def extract_images(images_field):
    """Return the http(s) image URLs held in an ASOS ``images`` field.

    The field is usually the string repr of a Python list of URLs. A real list,
    or a single bare URL string, is accepted too. Entries that are not strings
    starting with "http" are dropped; anything else yields ``[]``.
    """
    def _urls_only(seq):
        return [u for u in seq if isinstance(u, str) and u.startswith("http")]

    if not images_field:
        return []
    if isinstance(images_field, list):
        return _urls_only(images_field)
    if not isinstance(images_field, str):
        return []

    try:
        parsed = ast.literal_eval(images_field)
    except (ValueError, SyntaxError):
        parsed = None
    if isinstance(parsed, list):
        return _urls_only(parsed)
    return [images_field] if images_field.startswith("http") else []


def extract_description(desc_field):
    """Split an ASOS ``description`` field into ``(brand, text)``.

    The field is usually the string repr of a list of dicts. A value whose key
    contains "brand" (case-insensitive) becomes the brand; the last such key
    wins, and a falsy value gives None. Every other dict value is joined with
    spaces into the text. Non-dict entries are ignored. An empty field gives
    ``(None, "")``; unparseable or non-list input gives ``(None, str(desc_field))``.
    """
    if not desc_field:
        return None, ""

    data = desc_field
    if isinstance(data, str):
        try:
            data = ast.literal_eval(data)
        except (ValueError, SyntaxError):
            return None, str(desc_field)
    if not isinstance(data, list):
        return None, str(desc_field)

    brand = None
    pieces = []
    for entry in filter(lambda e: isinstance(e, dict), data):
        for key, val in entry.items():
            if "brand" not in key.lower():
                pieces.append(str(val))
                continue
            brand = str(val) if val else None
    return brand, " ".join(pieces)


def extract_price(price_field):
    """Parse a price (a numeric string like '49.99' in the dataset) into a float.

    Returns None when the value is missing, not convertible to float, or not
    strictly positive.
    """
    if price_field is None:
        return None
    try:
        value = float(price_field)
    except (ValueError, TypeError):
        return None
    return value if value > 0 else None


def detect_gender(product_name, category):
    """Guess "women" / "men" from a product's name and category, else None.

    This is lower-cased substring matching. Women keywords are tested first
    because several of them contain a men keyword ("women's" contains "men's",
    "womens" contains "mens", "female" contains "male"). The text is padded with
    a space on each side, so the space-delimited keywords " woman " / " man "
    also match a word at the very start or end (e.g. "Man slim jeans"). Without
    the padding those words were never detected.
    """
    text = f" {product_name} {category} ".lower()
    women_keywords = ("women's", "womens", "female", "femme", " woman ", "for women", "ladies", "maternity")
    men_keywords = ("men's", "mens", "male", "homme", " man ", "for men")
    if any(kw in text for kw in women_keywords):
        return "women"
    if any(kw in text for kw in men_keywords):
        return "men"
    return None


def _parse_description_field(raw):
    """Parse a string-encoded description; return None when it cannot be parsed."""
    # The dataset stores the Python repr of a list of dicts (single quotes), so
    # try a Python literal first, then plain JSON.
    try:
        return ast.literal_eval(raw)
    except (ValueError, SyntaxError, TypeError, RecursionError):
        pass
    try:
        return json.loads(raw)
    except ValueError:
        return None


def build_description(item):
    """Build the product text used for BM25 search and text encoding.

    Joins with " | ": the name, "Color: ...", "Category: ..." (only when it
    differs from the name), "Price: ...", "Sizes: ...", then every non-empty,
    non-brand value of the ``description`` field. As in ``extract_description``,
    a key is treated as brand when it contains "brand" (case-insensitive).

    The ``description`` string is parsed as a Python literal (falling back to
    JSON). A naive ``'`` -> ``"`` swap followed by ``json.loads`` is avoided on
    purpose: it breaks on any apostrophe in the text ("men's", "it's"). When
    the field cannot be parsed at all, HTML tags are stripped from the raw text
    and the cleaned text is appended instead.
    """
    parts = []
    name = item.get("name", "") or ""
    if name:
        parts.append(name)
    color = item.get("color", "") or ""
    if color:
        parts.append(f"Color: {color}")
    category = item.get("category", "") or ""
    if category and category != name:
        parts.append(f"Category: {category}")
    price = item.get("price", "") or ""
    if price:
        parts.append(f"Price: {price}")
    size = item.get("size", "") or ""
    if size:
        parts.append(f"Sizes: {size}")

    raw_desc = item.get("description", "")
    if raw_desc:
        desc_list = _parse_description_field(raw_desc) if isinstance(raw_desc, str) else raw_desc
        if desc_list is None:
            cleaned = re.sub(r"<[^>]+>", " ", str(raw_desc))
            cleaned = re.sub(r"\s+", " ", cleaned).strip()
            if cleaned:
                parts.append(cleaned)
        elif isinstance(desc_list, list):
            for entry in desc_list:
                if not isinstance(entry, dict):
                    continue
                for key, val in entry.items():
                    # Brand is stored as its own property; skip empty values
                    # so they don't produce "None" or blank " | " segments.
                    if "brand" in str(key).lower() or val is None or val == "":
                        continue
                    parts.append(str(val))

    return " | ".join(parts)


def download_image(url, timeout=10):
    """Download ``url`` and decode it as an RGB PIL image; return None on failure.

    Any failure (network error, HTTP error status, undecodable bytes) is
    swallowed on purpose: callers simply skip images that cannot be fetched.
    The response is used as a context manager, so its connection is released
    even when ``raise_for_status`` raises. With ``stream=True`` the body is not
    read in that case, and the connection would otherwise stay open until the
    response is garbage-collected.
    """
    try:
        with requests.get(url, timeout=timeout, stream=True) as r:
            r.raise_for_status()
            payload = r.content
        return Image.open(io.BytesIO(payload)).convert("RGB")
    except Exception:
        return None


def download_images_parallel(urls, max_workers=8, timeout=10):
    """Fetch ``urls`` concurrently via ``download_image``.

    Returns ``(index, url, image)`` tuples for the successful downloads only,
    ordered by each URL's position in ``urls``.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = {pool.submit(download_image, u, timeout): (i, u) for i, u in enumerate(urls)}
        fetched = []
        for fut in as_completed(pending):
            image = fut.result()
            if image is None:
                continue
            i, u = pending[fut]
            fetched.append((i, u, image))
    return sorted(fetched, key=lambda rec: rec[0])


def make_thumbnail_b64(image, size=THUMBNAIL_SIZE):
    """Return a base64 JPEG thumbnail of ``image``, as a UTF-8 ``str``.

    The thumbnail fits in ``size`` x ``size``. It is built from a copy, so
    ``image`` is not modified, and ``thumbnail`` preserves the aspect ratio.
    JPEG only stores a few pixel modes, so any mode other than RGB or greyscale
    ("L") -- e.g. RGBA, P, LA, CMYK -- is converted to RGB before encoding.
    Otherwise ``save`` could raise for modes such as LA.
    """
    img = image.copy()
    img.thumbnail((size, size))
    if img.mode not in ("RGB", "L"):
        img = img.convert("RGB")
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=85)
    return base64.b64encode(buf.getvalue()).decode("utf-8")


# --- Encode functions (HuggingFace + OpenCLIP) ---

def _to_tensor(feat):
    """Extract tensor from model output (handles both raw tensors and BaseModelOutput)."""
    if isinstance(feat, torch.Tensor):
        return feat
    if hasattr(feat, "pooler_output") and feat.pooler_output is not None:
        return feat.pooler_output
    if hasattr(feat, "last_hidden_state"):
        return feat.last_hidden_state[:, 0, :]
    raise ValueError(f"Cannot extract tensor from {type(feat)}")


def _encode_image(model_type, model, processor, img):
    """Embed a PIL image with one model and return an L2-normalised 1-D numpy vector.

    With ``model_type == "openclip"``, ``processor`` is the OpenCLIP preprocess
    transform and ``model.encode_image`` is used. Any other type goes through
    the Hugging Face processor and ``model.get_image_features``.
    """
    if model_type != "openclip":
        hf_inputs = processor(images=img, return_tensors="pt")
        on_device = {name: tensor.to(device) for name, tensor in hf_inputs.items()}
        raw = model.get_image_features(**on_device)
    else:
        raw = model.encode_image(processor(img).unsqueeze(0).to(device))
    emb = _to_tensor(raw)
    emb = emb / emb.norm(p=2, dim=-1, keepdim=True)
    return emb.cpu().float().numpy().flatten()


def _encode_text(model_type, model, processor, tokenizer, text, max_len=77):
    """Embed ``text`` with one model and return an L2-normalised 1-D numpy vector.

    Args:
        model_type: "openclip" for OpenCLIP models; any other value is handled
            as a Hugging Face model exposing ``get_text_features``.
        model: the loaded model.
        processor: HF processor (its ``.tokenizer`` is used when present);
            unused for OpenCLIP.
        tokenizer: OpenCLIP tokenizer (None for HF models).
        text: text to encode.
        max_len: HF token budget. Longer text is truncated; shorter text is
            padded up to it.

    Note: whatever encodes search queries (e.g. the FastAPI app) must tokenize
    the same way (same padding / max_len), or query and stored vectors will
    not be comparable.
    """
    if model_type == "openclip":
        # open_clip tokenizer auto-truncates to 77
        tokens = tokenizer([text]).to(device)
        feat = model.encode_text(tokens)
    else:
        # Call the raw tokenizer so max_length is honoured (AutoProcessor does
        # not always forward it).
        tok = getattr(processor, "tokenizer", processor)
        # padding="max_length" rather than padding=True: for a single text,
        # padding=True adds no padding at all. SigLIP / SigLIP2 are documented
        # to expect padding="max_length" (64), because that is how they were
        # trained. For CLIP the padding sits after the EOS token, where its
        # text embedding is pooled, so the result should not change.
        inputs = tok(
            [text], return_tensors="pt", padding="max_length",
            truncation=True, max_length=max_len,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        feat = model.get_text_features(**inputs)
    feat = _to_tensor(feat)
    feat = feat / feat.norm(p=2, dim=-1, keepdim=True)
    return feat.cpu().float().numpy().flatten()


# One dedicated CUDA stream per model, so the worker threads in
# encode_combined_all can run their kernels side by side instead of all
# queuing on the default stream. Empty on CPU.
_model_streams = {}
if device == "cuda":
    _model_streams = {model_name: torch.cuda.Stream() for model_name in MODELS}
    print(f"Created {len(_model_streams)} CUDA streams for parallel encoding")


def encode_combined_all(pil_img, description, image_weight=IMAGE_WEIGHT):
    """Encode one image (+ optional text) with every model in MODELS, in parallel.

    For each model the vector is ``normalize(w * img_vec + (1 - w) * txt_vec)``
    with ``w = image_weight``. When ``description`` is empty or only whitespace,
    the normalised image vector is used alone.

    Each model runs in its own worker thread. On CUDA, the thread issues its
    kernels on that model's stream from ``_model_streams`` (otherwise the
    default stream would serialize them). It then synchronizes that stream
    before returning, so the numpy result is complete.

    Returns:
        dict mapping each MODELS key (named-vector name) to a 1-D numpy vector,
        in MODELS order.
    """
    img = pil_img.convert("RGB")
    has_text = bool(description and description.strip())

    def encode_one(name, model_tuple):
        model_type, model, processor, tokenizer, max_len = model_tuple
        stream = _model_streams.get(name)
        stream_ctx = torch.cuda.stream(stream) if stream is not None else nullcontext()
        # no_grad is entered here because grad mode is per-thread in PyTorch.
        with torch.no_grad(), stream_ctx:
            img_vec = _encode_image(model_type, model, processor, img)
            if not has_text:
                result = img_vec
            else:
                txt_vec = _encode_text(model_type, model, processor, tokenizer, description, max_len)
                combined = image_weight * img_vec + (1 - image_weight) * txt_vec
                norm = np.linalg.norm(combined)
                result = combined / norm if norm > 0 else combined
        # Wait for this stream's GPU work before handing the result back.
        if stream is not None:
            stream.synchronize()
        return result

    # One worker per registered model, so adding a model needs no change here.
    with ThreadPoolExecutor(max_workers=max(1, len(MODELS))) as executor:
        futures = {name: executor.submit(encode_one, name, spec) for name, spec in MODELS.items()}
        return {name: fut.result() for name, fut in futures.items()}


# Smoke test: a small solid-colour image plus a short caption should give one
# vector per model; print each shape and dtype.
_test_img = Image.new("RGB", (64, 64), "red")
_test_vecs = encode_combined_all(_test_img, "red test shirt")
for _vec_name, _vec in _test_vecs.items():
    print(f"  {_vec_name}: shape={_vec.shape}, dtype={_vec.dtype}")
print("OK")

In [None]:
## Benchmark: sequential vs parallel (CUDA streams)
import time

_bench_img = Image.new("RGB", (384, 384), "blue")
_bench_desc = "blue denim jacket with silver buttons, casual style"
N_RUNS = 20

# Warm-up calls so one-off start-up costs are excluded from the timings below.
for _ in range(3):
    encode_combined_all(_bench_img, _bench_desc)
# A plain `if` instead of a conditional expression evaluated only for its side effect.
if device == "cuda":
    torch.cuda.synchronize()
# --- Sequentiel (boucle for, pas de threads) ---
def encode_sequential(pil_img, description, image_weight=IMAGE_WEIGHT):
    """Single-threaded baseline for encode_combined_all, used by the benchmark.

    It applies the same image/text combination formula, one model after another
    on the current (default) stream, and returns a dict keyed by MODELS names.
    """
    rgb = pil_img.convert("RGB")
    use_text = bool(description and description.strip())
    out = {}
    with torch.no_grad():
        for name, spec in MODELS.items():
            model_type, model, processor, tokenizer, max_len = spec
            vec = _encode_image(model_type, model, processor, rgb)
            if use_text:
                txt = _encode_text(model_type, model, processor, tokenizer, description, max_len)
                mixed = image_weight * vec + (1 - image_weight) * txt
                length = np.linalg.norm(mixed)
                vec = mixed / length if length > 0 else mixed
            out[name] = vec
    return out

def _time_runs(encode_fn):
    """Return wall-clock seconds for N_RUNS calls of ``encode_fn`` on the benchmark input.

    On CUDA the device is synchronized before the first call and after the
    last one, so queued GPU work counts toward the measurement.
    """
    if device == "cuda":
        torch.cuda.synchronize()
    started = time.perf_counter()
    for _ in range(N_RUNS):
        encode_fn(_bench_img, _bench_desc)
    if device == "cuda":
        torch.cuda.synchronize()
    return time.perf_counter() - started


# Sequential baseline (plain loop) first, then threads + CUDA streams.
seq_time = _time_runs(encode_sequential)
par_time = _time_runs(encode_combined_all)

print(f"Benchmark sur {N_RUNS} images ({device}):")
print(f"  Sequentiel : {seq_time:.2f}s ({N_RUNS/seq_time:.1f} img/s) — {seq_time/N_RUNS*1000:.0f}ms/img")
print(f"  Parallele  : {par_time:.2f}s ({N_RUNS/par_time:.1f} img/s) — {par_time/N_RUNS*1000:.0f}ms/img")
print(f"  Speedup    : {seq_time/par_time:.2f}x")

## 9. Indexation (téléchargement parallèle, encodage GPU image par image → Weaviate sur GCP)

In [None]:
import math
import time
from tqdm.auto import tqdm

collection = client.collections.get(COLLECTION_NAME)

# --- Checkpoint & resume ---
# Restart right after the last product index saved by flush_weaviate, unless
# the collection is being rebuilt from scratch.
checkpoint = load_checkpoint()
start_product = 0
if not RECREATE and checkpoint["last_index"] >= 0:
    start_product = checkpoint["last_index"] + 1
    print(f"Reprise depuis le produit {start_product} ({checkpoint['indexed_count']} deja indexes)")

# --- product_ids already stored in Weaviate (dedup across runs) ---
existing_ids = set()
for obj in collection.iterator(return_properties=["product_id"]):
    pid = obj.properties.get("product_id", "")
    if pid:
        existing_ids.add(pid)
print(f"Deja indexes: {len(existing_ids)} produits — on skip ceux-la")

indexed = checkpoint["indexed_count"] if not RECREATE else 0
errors = checkpoint["errors"] if not RECREATE else 0
skipped = 0
already = 0
start_time = time.time()


def _product_id(sku, product_idx):
    """Return the string id of a dataset row.

    A numeric SKU (int, float such as 123456.0, or numeric string) is converted
    with ``int``, so ids match the objects already stored in Weaviate. A
    missing, zero or NaN SKU falls back to the row index. A non-numeric SKU is
    kept verbatim, so one odd row no longer crashes the whole preparation loop.
    """
    if not sku or (isinstance(sku, float) and math.isnan(sku)):
        return str(product_idx)
    for convert in (int, lambda s: int(float(s))):
        try:
            return str(convert(sku))
        except (TypeError, ValueError, OverflowError):
            continue
    return str(sku)


# --- Collect download tasks (skip already indexed) ---
download_tasks = []

pbar_prep = tqdm(total=len(dataset), desc="Preparation", initial=start_product)

for product_idx in range(start_product, len(dataset)):
    item = dataset[product_idx]
    product_id = _product_id(item.get("sku"), product_idx)

    if product_id in existing_ids:
        already += 1
        pbar_prep.update(1)
        continue

    # Skip image-less products before spending time on their text fields.
    image_urls = extract_images(item.get("images"))
    if not image_urls:
        skipped += 1
        pbar_prep.update(1)
        continue

    product_name = item.get("name", "") or ""
    category = item.get("category", "") or ""
    brand, _ = extract_description(item.get("description"))

    # Properties shared by every image of this product. "size" is declared in
    # the schema and was never filled before.
    shared = {
        "product_id": product_id,
        "product_name": product_name,
        "category": category,
        "color": item.get("color", "") or "",
        "size": str(item.get("size") or ""),
        "price": extract_price(item.get("price")),
        "brand": brand or "",
        "product_url": item.get("url", "") or "",
        "description": build_description(item),
        "gender": detect_gender(product_name, category),
        "_product_idx": product_idx,
    }
    for idx, url in enumerate(image_urls[:MAX_IMAGES_PER_PRODUCT]):
        meta = {**shared, "filename": f"{product_id}_{idx}.jpg", "path": url, "image_index": idx}
        download_tasks.append((url, meta))

    pbar_prep.update(1)

pbar_prep.close()
print(f"{already} produits deja indexes (skip)")
print(f"{len(download_tasks)} images restantes a telecharger")
print(f"Poids image/texte: {IMAGE_WEIGHT:.1f} / {1 - IMAGE_WEIGHT:.1f}")
print(f"Config: WEAVIATE_BATCH={WEAVIATE_BATCH_SIZE}, WORKERS={DOWNLOAD_WORKERS}, PREFETCH={PREFETCH_SIZE}")

# --- Download + encode (3 models) + push ---
weaviate_queue = []
last_product_idx = start_product

def flush_weaviate():
    """Send the queued objects to Weaviate, then update counters and checkpoint.

    Each queued dict carries its vectors under "_vectors" and its dataset row
    under "_product_idx". Both keys are removed and the rest is stored as
    properties. The queue is detached before sending. That way an exception
    during the batch cannot leave half-processed dicts (vectors already popped)
    behind, which would make every later flush fail.

    Objects rejected by Weaviate are read from ``collection.batch.failed_objects``
    after the batch context exits. They are added to ``errors`` instead of
    ``indexed``. The checkpoint's last index only advances once the batch has
    completed.
    """
    global weaviate_queue, indexed, last_product_idx, errors
    if not weaviate_queue:
        return
    queue, weaviate_queue = weaviate_queue, []

    batch_last_idx = last_product_idx
    with collection.batch.dynamic() as batch:
        for doc in queue:
            vectors = doc.pop("_vectors")
            batch_last_idx = max(batch_last_idx, doc.pop("_product_idx", 0))
            batch.add_object(properties=doc, vector=vectors)

    failed = collection.batch.failed_objects or []
    if failed:
        errors += len(failed)
        tqdm.write(f"  Weaviate: {len(failed)} objets rejetes (ex: {failed[0].message})")
    indexed += len(queue) - len(failed)
    last_product_idx = batch_last_idx
    save_checkpoint(last_product_idx, indexed, errors)


pbar = tqdm(total=len(download_tasks), desc="Download + Encode (3 modeles)")

# `indexed` may already include objects from earlier runs (restored from the
# checkpoint); remember it so the report shows only this run's work.
indexed_before = indexed
download_failed = 0

for batch_start in range(0, len(download_tasks), PREFETCH_SIZE):
    batch_slice = download_tasks[batch_start:batch_start + PREFETCH_SIZE]
    urls = [url for url, _ in batch_slice]

    downloaded = download_images_parallel(urls, max_workers=DOWNLOAD_WORKERS)
    # download_images_parallel silently drops failed URLs; count them here.
    download_failed += len(batch_slice) - len(downloaded)

    for local_idx, url, img in downloaded:
        try:
            meta = batch_slice[local_idx][1].copy()
            w, h = img.size
            meta["thumbnail_base64"] = make_thumbnail_b64(img)
            meta["width"] = w
            meta["height"] = h
            meta["indexed_at"] = datetime.now(timezone.utc).isoformat()

            # Encode with the 3 models (combined image + text)
            vectors = encode_combined_all(img, meta.get("description", ""))

            meta["_vectors"] = {k: v.tolist() for k, v in vectors.items()}
            weaviate_queue.append(meta)

            if len(weaviate_queue) >= WEAVIATE_BATCH_SIZE:
                flush_weaviate()

        except Exception as e:
            errors += 1
            if errors <= 10:
                tqdm.write(f"  Erreur: {e}")

    pbar.update(len(batch_slice))

# Flush whatever is left in the queue.
flush_weaviate()

pbar.close()
elapsed = time.time() - start_time
new_indexed = indexed - indexed_before

print(f"\n{'='*50}")
print("Indexation terminee !")
print(f"Deja presents: {already}")
print(f"Nouvelles images indexees: {new_indexed}")
print(f"Sans image: {skipped}")
print(f"Echecs de telechargement: {download_failed}")
print(f"Erreurs: {errors}")
print(f"Temps: {elapsed:.0f}s ({new_indexed / max(elapsed, 1):.1f} images/sec)")
print(f"{'='*50}")
clear_checkpoint()

## 10. Verification

In [None]:
# Count how many objects the collection now holds.
collection = client.collections.get(COLLECTION_NAME)
total_docs = collection.aggregate.over_all(total_count=True).total_count
print(f"Collection '{COLLECTION_NAME}': {total_docs} documents")

## 11. Test rapide

In [None]:
from html import escape
from weaviate.classes.query import MetadataQuery
from IPython.display import display, HTML

query = "black leather jacket"

# Query each named vector with a text-only embedding of `query`. _encode_text
# is the helper the indexing pipeline uses, so tokenization, padding and
# normalisation stay identical to the text side of the stored vectors.
for vec_name, (model_type, model, processor, tokenizer, max_len) in MODELS.items():
    with torch.no_grad():
        query_vector = _encode_text(model_type, model, processor, tokenizer, query, max_len).tolist()

    results = collection.query.near_vector(
        near_vector=query_vector,
        target_vector=vec_name,
        limit=5,
        return_metadata=MetadataQuery(distance=True),
    )

    print(f"\n--- {vec_name} : '{query}' ---")
    cards = []
    for obj in results.objects:
        p = obj.properties
        score = f"{1 - (obj.metadata.distance or 0):.3f}"
        thumb = p.get("thumbnail_base64", "") or ""
        # Truncate, then HTML-escape: product names come from scraped data and
        # may contain "<" or "&". `or ""` guards against null properties.
        name = escape((p.get("product_name", "") or "")[:40])
        price = p.get("price")
        price_str = f"\u00a3{price:.2f}" if price else ""
        cards.append(f'''
        <div style="text-align:center; width:160px;">
            <img src="data:image/jpeg;base64,{thumb}" style="max-width:150px; max-height:150px;"/>
            <div style="font-size:11px;">{name}</div>
            <div style="font-size:11px; color:gray;">{price_str} - score: {score}</div>
        </div>''')
    html = '<div style="display:flex; gap:10px; flex-wrap:wrap;">' + "".join(cards) + '</div>'
    display(HTML(html))

In [None]:
# Tear-down. Order matters: set the stop event first so the keep-alive loop
# exits on its next wait instead of trying to reconnect the client that is
# closed on the next line.
_keepalive_stop.set()  # Stop the keep-alive thread
client.close()         # Release the HTTP/gRPC connections to Weaviate
print("Connexion Weaviate fermee.")