# Fashion Search: ASOS Indexing from Colab (GPU)

This notebook indexes the [ASOS e-commerce](https://huggingface.co/datasets/UniqueData/asos-e-commerce-dataset) dataset into Weaviate with **3 named vectors**:
- **fashion_clip** (512d): `patrickjohncyh/fashion-clip`
- **marqo_clip** (512d): `Marqo/marqo-fashionCLIP`
- **siglip2** (1152d): `google/siglip2-so400m-patch14-384`

**Features:**
- **Batch GPU encoding**: each model encodes N images in one pass (instead of one at a time)
- The 3 models run in parallel via CUDA streams + ThreadPoolExecutor
- Combined image + text embeddings (0.5 img + 0.5 txt: category, color, occasion)
- Auto-generated `occasion` field (when to wear the item: sport, office, party...)
- Text descriptions for BM25 search (color, category, description, occasion)
- Checkpointing with automatic resume
- Weaviate keep-alive (ping every 2 min)

**Architecture:**
- **Colab (GPU)**: batch-encodes the images with the 3 models → pushes the vectors to Weaviate
- **GCP (CPU)**: Weaviate + FastAPI app to serve searches

**Prerequisites:**
- GPU runtime enabled (Runtime → Change runtime type → T4 GPU)
- Weaviate reachable on your GCP VM (ports 8080 and 50051 open in the firewall)
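
The combined image + text embedding listed under Features can be illustrated without loading any model. The following is a minimal, dependency-free sketch, assuming both inputs are already L2-normalized vectors of the same dimension. `l2_normalize` and `combine_embeddings` are hypothetical helper names; the notebook itself works on NumPy arrays rather than lists.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit L2 norm; return a copy unchanged if the norm is zero."""
    norm = math.sqrt(sum(x * x for x in vec))
    return list(vec) if norm == 0 else [x / norm for x in vec]


def combine_embeddings(img_vec, txt_vec, image_weight=0.5):
    """Weighted sum of an image and a text embedding, re-normalized to unit length."""
    if len(img_vec) != len(txt_vec):
        raise ValueError("image and text vectors must have the same dimension")
    mixed = [
        image_weight * i + (1.0 - image_weight) * t
        for i, t in zip(img_vec, txt_vec)
    ]
    return l2_normalize(mixed)


# Two orthogonal unit vectors mixed 50/50 give a unit vector halfway between them.
combined = combine_embeddings([1.0, 0.0], [0.0, 1.0], image_weight=0.5)
print(combined)
```

Re-normalizing after the sum matters because a weighted mix of two unit vectors is generally shorter than unit length.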

## 1. Installing dependencies

In [None]:
# Install deps (without touching the torch build pre-installed on Colab)
# The version specifier is quoted so the shell does not treat ">" as a redirect.
!pip install -q "weaviate-client>=4.0" transformers Pillow datasets requests tqdm timm ftfy regex
!pip install -q --no-deps open-clip-torch

## 2. Configuration

Enter the external IP of the GCP VM where Weaviate is running.

**Important**: ports `8080` (HTTP) and `50051` (gRPC) must be open in the GCP firewall.
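
Before connecting, it can save time to confirm that both ports are actually reachable from Colab. The sketch below uses only the standard library; `port_is_open` and `check_weaviate_ports` are hypothetical helper names, and a successful TCP connection only shows that something is listening behind the firewall, not that it is Weaviate.

```python
import socket


def port_is_open(host, port, timeout=3.0):
    """Return True if a TCP connection to host:port succeeds within `timeout` seconds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False


def check_weaviate_ports(host, ports=(8080, 50051)):
    """Map each port to whether it accepted a TCP connection."""
    return {port: port_is_open(host, port) for port in ports}
```

Once `GCP_EXTERNAL_IP` is set, calling `check_weaviate_ports(GCP_EXTERNAL_IP)` returns a dict such as `{8080: ..., 50051: ...}`; a `False` entry usually points at a missing firewall rule.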

In [None]:
# --- CONFIGURATION ---
GCP_EXTERNAL_IP = ""   # ex: "34.56.78.90"

WEAVIATE_HTTP_PORT = 8080
WEAVIATE_GRPC_PORT = 50051

COLLECTION_NAME = "FashionCollection"

# Models
MODEL_FASHION_CLIP = "patrickjohncyh/fashion-clip"
MODEL_MARQO_CLIP = "Marqo/marqo-fashionCLIP"
MODEL_SIGLIP2 = "google/siglip2-so400m-patch14-384"

# Named vector keys
VECTOR_FASHION_CLIP = "fashion_clip"
VECTOR_MARQO_CLIP = "marqo_clip"
VECTOR_SIGLIP2 = "siglip2"
ALL_VECTOR_NAMES = [VECTOR_FASHION_CLIP, VECTOR_MARQO_CLIP, VECTOR_SIGLIP2]

# --- Tuning A100 40GB + 83GB RAM ---
MAX_ITEMS = None              # whole dataset (~45k products)
MAX_IMAGES_PER_PRODUCT = 1    # images per product
WEAVIATE_BATCH_SIZE = 500     # objects per flush to Weaviate
DOWNLOAD_WORKERS = 128        # parallel downloads (phase 1)
ENCODE_BATCH_SIZE = 128       # images per GPU batch (128 for A100, 32 for T4)
THUMBNAIL_SIZE = 150
IMAGE_WEIGHT = 0.5            # 50/50 image and text (category + color + occasion)
RECREATE = False              # False = resume, True = recreate everything

# Local directory for pre-downloaded images
IMAGE_CACHE_DIR = "/content/asos_images"

# Checkpoint (automatic resume)
CHECKPOINT_FILE = "/content/index_checkpoint.json"
KEEPALIVE_INTERVAL = 120      # seconds

## 3. GPU check

In [None]:
import torch

if torch.cuda.is_available():
    device = "cuda"
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
else:
    device = "cpu"
    print("No GPU found; indexing will be slow.")

print(f"Device: {device}")

## 4. Loading the 3 models

In [None]:
from transformers import CLIPModel, CLIPProcessor, AutoModel, AutoProcessor
import open_clip

# Fashion CLIP (512d) — HuggingFace format
print(f"Loading {MODEL_FASHION_CLIP}...")
fc_model = CLIPModel.from_pretrained(MODEL_FASHION_CLIP).to(device).eval()
fc_processor = CLIPProcessor.from_pretrained(MODEL_FASHION_CLIP)

# Marqo FashionCLIP (512d) — OpenCLIP format
print(f"Loading {MODEL_MARQO_CLIP}...")
mq_model, _, mq_preprocess = open_clip.create_model_and_transforms(
    f"hf-hub:{MODEL_MARQO_CLIP}", device=device
)
mq_tokenizer = open_clip.get_tokenizer(f"hf-hub:{MODEL_MARQO_CLIP}")
mq_model = mq_model.eval()

# SigLIP2-SO400M (1152d) — HuggingFace format
print(f"Loading {MODEL_SIGLIP2}...")
sl_model = AutoModel.from_pretrained(MODEL_SIGLIP2).to(device).eval()
sl_processor = AutoProcessor.from_pretrained(MODEL_SIGLIP2)

# Unified dict: (type, model, processor/preprocess, tokenizer, max_text_tokens)
MODELS = {
    VECTOR_FASHION_CLIP: ("hf", fc_model, fc_processor, None, 77),
    VECTOR_MARQO_CLIP: ("openclip", mq_model, mq_preprocess, mq_tokenizer, 77),
    VECTOR_SIGLIP2: ("hf", sl_model, sl_processor, None, 64),
}

# A100 40GB: the 3 models take ~5GB in FP32, so FP16 is not needed
if device == "cuda":
    vram_used = torch.cuda.memory_allocated() / 1024**3
    vram_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"\n3 models loaded on {device}, VRAM: {vram_used:.1f}/{vram_total:.0f} GB")
else:
    print(f"\n3 models loaded on {device}.")

## 5. Connecting to Weaviate on GCP

In [None]:
import weaviate
import json
import os
import threading
from datetime import datetime, timezone
from weaviate.classes.config import Configure, DataType, Property, VectorDistances

assert GCP_EXTERNAL_IP, "Set GCP_EXTERNAL_IP in the configuration cell."

client = weaviate.connect_to_custom(
    http_host=GCP_EXTERNAL_IP,
    http_port=WEAVIATE_HTTP_PORT,
    http_secure=False,
    grpc_host=GCP_EXTERNAL_IP,
    grpc_port=WEAVIATE_GRPC_PORT,
    grpc_secure=False,
)

print(f"Connected to Weaviate at {GCP_EXTERNAL_IP}" if client.is_ready() else "Connection failed")

# --- Keep-alive (ping every 2 min, auto-reconnect) ---
_keepalive_stop = threading.Event()

def start_keepalive(wv_client, interval=KEEPALIVE_INTERVAL):
    def _run():
        while not _keepalive_stop.wait(interval):
            try:
                if not wv_client.is_ready():
                    raise ConnectionError("not ready")
            except Exception as e:
                print(f"\n[keep-alive] Reconnecting... ({e})")
                try:
                    # connect_to_custom() is a module-level factory, not a client
                    # method. Close and reopen the same client instead (assuming
                    # the v4 client allows connect() after close()), so existing
                    # references to it stay valid.
                    wv_client.close()
                    wv_client.connect()
                except Exception as reconnect_err:
                    # Keep the thread alive and retry on the next tick.
                    print(f"[keep-alive] Reconnect failed: {reconnect_err}")
    t = threading.Thread(target=_run, daemon=True)
    t.start()
    return t

start_keepalive(client)
print("Keep-alive started (ping every 2 min)")

# --- Checkpoint helpers ---
def load_checkpoint():
    try:
        if os.path.exists(CHECKPOINT_FILE):
            with open(CHECKPOINT_FILE) as f:
                return json.load(f)
    except Exception:
        pass
    return {"last_index": -1, "indexed_count": 0, "errors": 0}

def save_checkpoint(last_index, indexed_count, errors):
    with open(CHECKPOINT_FILE, "w") as f:
        json.dump({"last_index": last_index, "indexed_count": indexed_count,
                    "errors": errors, "timestamp": datetime.now(timezone.utc).isoformat()}, f)

def clear_checkpoint():
    if os.path.exists(CHECKPOINT_FILE):
        os.remove(CHECKPOINT_FILE)
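
If the Colab runtime is interrupted while `save_checkpoint` is writing, the JSON file can be left truncated, and `load_checkpoint` would then silently fall back to `last_index = -1`. One possible hardening, sketched here under the assumption that the checkpoint lives on a local filesystem, is to write to a temporary file and rename it into place. `save_checkpoint_atomic` and the demo path are hypothetical names, not part of the notebook.

```python
import json
import os
import tempfile


def save_checkpoint_atomic(path, state):
    """Write `state` as JSON so readers see the old or the new file, never a partial one."""
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(state, f)
            f.flush()
            os.fsync(f.fileno())
        # os.replace is atomic when source and target are on the same filesystem.
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


# Demo in a throwaway directory.
demo_dir = tempfile.mkdtemp()
demo_path = os.path.join(demo_dir, "index_checkpoint.json")
save_checkpoint_atomic(demo_path, {"last_index": 41, "indexed_count": 42, "errors": 0})
with open(demo_path) as f:
    restored = json.load(f)
print(restored)
```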

## 6. Schema creation

In [None]:
def create_schema(wv_client, collection_name, vector_names=None):
    """Create the collection with named vectors."""
    if vector_names is None:
        vector_names = ALL_VECTOR_NAMES

    named_vectors = [
        Configure.NamedVectors.none(
            name=name,
            vector_index_config=Configure.VectorIndex.hnsw(
                distance_metric=VectorDistances.COSINE
            ),
        )
        for name in vector_names
    ]

    wv_client.collections.create(
        name=collection_name,
        properties=[
            Property(name="filename", data_type=DataType.TEXT),
            Property(name="path", data_type=DataType.TEXT),
            Property(name="thumbnail_base64", data_type=DataType.TEXT),
            Property(name="width", data_type=DataType.INT),
            Property(name="height", data_type=DataType.INT),
            Property(name="indexed_at", data_type=DataType.TEXT),
            Property(name="product_id", data_type=DataType.TEXT),
            Property(name="product_name", data_type=DataType.TEXT),
            Property(name="category", data_type=DataType.TEXT),
            Property(name="color", data_type=DataType.TEXT),
            Property(name="size", data_type=DataType.TEXT),
            Property(name="price", data_type=DataType.NUMBER),
            Property(name="brand", data_type=DataType.TEXT),
            Property(name="product_url", data_type=DataType.TEXT),
            Property(name="description", data_type=DataType.TEXT),
            Property(name="image_index", data_type=DataType.INT),
            Property(name="gender", data_type=DataType.TEXT),
            Property(name="occasion", data_type=DataType.TEXT),
        ],
        vectorizer_config=named_vectors,
    )
    print(f"Collection '{collection_name}' created with vectors: {vector_names}")


if RECREATE:
    if client.collections.exists(COLLECTION_NAME):
        client.collections.delete(COLLECTION_NAME)
        print(f"Collection '{COLLECTION_NAME}' deleted.")
    create_schema(client, COLLECTION_NAME)
    clear_checkpoint()
elif not client.collections.exists(COLLECTION_NAME):
    create_schema(client, COLLECTION_NAME)
    clear_checkpoint()
else:
    print(f"Collection '{COLLECTION_NAME}' already exists; keeping existing data.")

## 7. Loading the ASOS dataset

In [None]:
from datasets import load_dataset

print("Loading the ASOS dataset...")
dataset = load_dataset("UniqueData/asos-e-commerce-dataset", split="train")
print(f"Dataset loaded: {len(dataset)} products")

if MAX_ITEMS:
    dataset = dataset.select(range(min(MAX_ITEMS, len(dataset))))
    print(f"Limited to {len(dataset)} products")

## 7b. Phase 1: Pre-downloading all images to disk

Downloads all images in parallel BEFORE GPU encoding, so the GPU never sits idle waiting on the network.
~45k images × ~100-200 KB = ~5-10 GB on disk, which fits comfortably.

In [None]:
import ast
import io
import os
import time
import requests
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from tqdm.auto import tqdm

os.makedirs(IMAGE_CACHE_DIR, exist_ok=True)


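def _download_and_save(args):
    """Download one image and save to disk. Returns (product_id, idx, success)."""
    product_id, idx, url, filepath = args
    if os.path.exists(filepath):
        return product_id, idx, True  # already downloaded
    tmp_path = filepath + ".part"
    try:
        r = requests.get(url, timeout=15, stream=True)
        r.raise_for_status()
        # Write to a temp file first so an interrupted download never leaves
        # a truncated .jpg that a later run would treat as cached.
        with open(tmp_path, "wb") as f:
            for chunk in r.iter_content(8192):
                f.write(chunk)
        os.replace(tmp_path, filepath)
        return product_id, idx, True
    except Exception:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        return product_id, idx, False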


# --- Collect all image URLs ---
download_jobs = []
no_image_count = 0

for product_idx in range(len(dataset)):
    item = dataset[product_idx]
    product_id = str(int(item.get("sku", 0))) if item.get("sku") else str(product_idx)

    raw_images = item.get("images", "")
    if not raw_images:
        no_image_count += 1
        continue
    if isinstance(raw_images, str):
        try:
            parsed = ast.literal_eval(raw_images)
            image_urls = [u for u in parsed if isinstance(u, str) and u.startswith("http")]
        except (ValueError, SyntaxError):
            image_urls = [raw_images] if raw_images.startswith("http") else []
    elif isinstance(raw_images, list):
        image_urls = [u for u in raw_images if isinstance(u, str) and u.startswith("http")]
    else:
        image_urls = []

    if not image_urls:
        no_image_count += 1
        continue

    for idx, url in enumerate(image_urls[:MAX_IMAGES_PER_PRODUCT]):
        filepath = os.path.join(IMAGE_CACHE_DIR, f"{product_id}_{idx}.jpg")
        download_jobs.append((product_id, idx, url, filepath))

# Count already cached
already_cached = sum(1 for _, _, _, fp in download_jobs if os.path.exists(fp))
to_download = len(download_jobs) - already_cached

print(f"Total images: {len(download_jobs)}")
print(f"Already on disk: {already_cached}")
print(f"To download: {to_download}")
print(f"No image: {no_image_count}")

if to_download > 0:
    t0 = time.time()
    success = 0
    failed = 0

    with ThreadPoolExecutor(max_workers=DOWNLOAD_WORKERS) as executor:
        futures = {executor.submit(_download_and_save, job): job for job in download_jobs}
        with tqdm(total=len(download_jobs), desc="Phase 1: Download") as pbar:
            for future in as_completed(futures):
                _, _, ok = future.result()
                if ok:
                    success += 1
                else:
                    failed += 1
                pbar.update(1)

    elapsed = time.time() - t0
    print(f"\nDownload finished: {success} OK, {failed} failed in {elapsed:.0f}s")
else:
    print("All images are already on disk!")

# Disk usage
total_size = sum(
    os.path.getsize(os.path.join(IMAGE_CACHE_DIR, f))
    for f in os.listdir(IMAGE_CACHE_DIR)
    if f.endswith(".jpg")
)
cached_count = len([f for f in os.listdir(IMAGE_CACHE_DIR) if f.endswith(".jpg")])
print(f"Cache: {cached_count} images, {total_size / 1024**3:.1f} GB")
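
`_download_and_save` gives up after a single failed request, so transient network errors count as permanent failures until the cell is re-run. If that turns out to be a problem, a generic retry helper with exponential backoff is one option. The sketch below is an illustration only: `retry_with_backoff` and its parameters are hypothetical, and the `sleep` argument is injectable so the demo runs without waiting.

```python
import time


def retry_with_backoff(fn, attempts=3, base_delay=0.5, sleep=time.sleep):
    """Call `fn()` up to `attempts` times, doubling the wait after each failure.

    Returns fn's result, or re-raises the last exception once attempts run out.
    """
    if attempts < 1:
        raise ValueError("attempts must be >= 1")
    for attempt in range(attempts):
        try:
            return fn()
        except Exception:
            if attempt == attempts - 1:
                raise
            sleep(base_delay * (2 ** attempt))


# Demo with a fake download that fails twice before succeeding.
calls = {"n": 0}
waits = []


def flaky_download():
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("temporary failure")
    return b"image-bytes"


result = retry_with_backoff(flaky_download, attempts=3, base_delay=0.5, sleep=waits.append)
print(result, waits)
```

In the notebook, the request-and-write step of `_download_and_save` would be the natural candidate to wrap this way.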

## 8. Utility functions

In [None]:
import ast
import base64
import io
import re
from contextlib import nullcontext
from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np
import requests
from PIL import Image


def extract_images(images_field):
    """Extract image URLs (field is a string repr of a list)."""
    if not images_field:
        return []
    if isinstance(images_field, str):
        try:
            parsed = ast.literal_eval(images_field)
            if isinstance(parsed, list):
                return [u for u in parsed if isinstance(u, str) and u.startswith("http")]
        except (ValueError, SyntaxError):
            pass
        if images_field.startswith("http"):
            return [images_field]
    if isinstance(images_field, list):
        return [u for u in images_field if isinstance(u, str) and u.startswith("http")]
    return []


def extract_description(desc_field):
    """Extract brand + text from description (string repr of list of dicts)."""
    if not desc_field:
        return None, ""
    data = desc_field
    if isinstance(data, str):
        try:
            data = ast.literal_eval(data)
        except (ValueError, SyntaxError):
            return None, str(desc_field)
    if isinstance(data, list):
        brand = None
        texts = []
        for entry in data:
            if isinstance(entry, dict):
                for key, val in entry.items():
                    if "brand" in key.lower():
                        brand = str(val) if val else None
                    else:
                        texts.append(str(val))
        return brand, " ".join(texts)
    return None, str(desc_field)


def extract_price(price_field):
    """Extract price (string in the dataset, e.g. '49.99')."""
    if price_field is None:
        return None
    try:
        val = float(price_field)
        return val if val > 0 else None
    except (ValueError, TypeError):
        return None


def detect_gender(product_name, category):
    text = f"{product_name} {category}".lower()
    for kw in ["women's", "womens", "female", "femme", " woman ", "for women", "ladies", "maternity"]:
        if kw in text:
            return "women"
    for kw in ["men's", "mens", "male", "homme", " man ", "for men"]:
        if kw in text:
            return "men"
    return None


def detect_occasion(product_name, category):
    """Detect when/where to wear based on product name and category.

    Returns a string of occasion keywords for BM25 search + vector encoding.
    Keywords match the chatbot context_map in engine.py:
      casual, office/professional, formal/ceremony, sport/athletic, evening/party.
    """
    text = f"{product_name} {category}".lower()
    occasions = []

    # Sport / Athletic
    if any(kw in text for kw in [
        "running", "trainer", "sneaker", "jogger", "legging", "sports",
        "athletic", "gym", "yoga", "swim", "activewear", "workout",
        "fitness", "track", "cycling", "football", "basketball",
    ]):
        occasions.append("sport athletic activewear gym workout outdoor")

    # Formal / Ceremony
    if any(kw in text for kw in [
        "suit", "tuxedo", "gown", "evening dress", "formal", "bow tie",
        "cufflink", "waistcoat", "wedding",
    ]):
        occasions.append("formal ceremony wedding elegant evening")

    # Office / Professional
    if any(kw in text for kw in [
        "blazer", "shirt", "blouse", "trouser", "pencil", "chino",
        "oxford", "loafer", "smart", "tailored",
    ]):
        occasions.append("office work professional business classic")

    # Evening / Party
    if any(kw in text for kw in [
        "sequin", "glitter", "satin", "silk", "club", "party", "cocktail",
        "bodycon", "metallic", "sparkle", "mini dress", "evening",
    ]):
        occasions.append("party evening nightout glamorous")

    # Outdoor / Winter
    if any(kw in text for kw in [
        "coat", "puffer", "parka", "scarf", "glove", "beanie",
        "thermal", "fleece", "waterproof", "rain", "hiking", "down jacket",
    ]):
        occasions.append("outdoor winter cold weather")

    # Beach / Summer
    if any(kw in text for kw in [
        "bikini", "swimsuit", "swimwear", "sandal", "flip flop", "linen",
        "tank top", "vest top", "beach", "tropical", "resort",
    ]):
        occasions.append("beach summer vacation holiday")

    # Casual / Everyday (default)
    if any(kw in text for kw in [
        "t-shirt", "tee", "jeans", "denim", "hoodie", "sweatshirt",
        "cardigan", "pullover", "polo", "shorts", "canvas", "jersey",
        "crop top", "cargo",
    ]):
        occasions.append("casual everyday relaxed daily")

    if not occasions:
        occasions.append("casual everyday")

    return " | ".join(occasions)


def build_description(item):
    """Build rich text from ASOS product fields for BM25 + text encoding."""
    parts = []
    name = item.get("name", "") or ""
    if name:
        parts.append(name)
    color = item.get("color", "") or ""
    if color:
        parts.append(f"Color: {color}")
    category = item.get("category", "") or ""
    if category and category != name:
        parts.append(f"Category: {category}")
    price = item.get("price", "") or ""
    if price:
        parts.append(f"Price: {price}")
    size = item.get("size", "") or ""
    if size:
        parts.append(f"Sizes: {size}")

    raw_desc = item.get("description", "")
    if raw_desc:
        try:
            if isinstance(raw_desc, str):
                desc_list = json.loads(raw_desc.replace("'", '"'))
            else:
                desc_list = raw_desc
            if isinstance(desc_list, list):
                for entry in desc_list:
                    if isinstance(entry, dict):
                        for key, val in entry.items():
                            if key == "Brand":
                                continue
                            parts.append(str(val))
        except (json.JSONDecodeError, ValueError):
            cleaned = re.sub(r"<[^>]+>", " ", str(raw_desc))
            cleaned = re.sub(r"\s+", " ", cleaned).strip()
            if cleaned:
                parts.append(cleaned)

    return " | ".join(parts)


def download_image(url, timeout=10):
    try:
        r = requests.get(url, timeout=timeout, stream=True)
        r.raise_for_status()
        return Image.open(io.BytesIO(r.content)).convert("RGB")
    except Exception:
        return None


def download_images_parallel(urls, max_workers=8, timeout=10):
    """Download multiple images in parallel. Returns list of (index, url, image) tuples."""
    results = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_info = {
            executor.submit(download_image, url, timeout): (idx, url)
            for idx, url in enumerate(urls)
        }
        for future in as_completed(future_to_info):
            idx, url = future_to_info[future]
            img = future.result()
            if img is not None:
                results.append((idx, url, img))
    results.sort(key=lambda x: x[0])
    return results


def make_thumbnail_b64(image, size=THUMBNAIL_SIZE):
    img = image.copy()
    img.thumbnail((size, size))
    if img.mode in ("RGBA", "P"):
        img = img.convert("RGB")
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=85)
    return base64.b64encode(buf.getvalue()).decode("utf-8")


# --- Single-image encode functions ---

def _to_tensor(feat):
    """Extract tensor from model output (handles both raw tensors and BaseModelOutput)."""
    if isinstance(feat, torch.Tensor):
        return feat
    if hasattr(feat, "pooler_output") and feat.pooler_output is not None:
        return feat.pooler_output
    if hasattr(feat, "last_hidden_state"):
        return feat.last_hidden_state[:, 0, :]
    raise ValueError(f"Cannot extract tensor from {type(feat)}")


def _encode_image(model_type, model, processor, img):
    """Encode image -> normalized numpy vector."""
    if model_type == "openclip":
        image_tensor = processor(img).unsqueeze(0).to(device)
        feat = model.encode_image(image_tensor)
    else:
        inputs = processor(images=img, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        feat = model.get_image_features(**inputs)
    feat = _to_tensor(feat)
    feat = feat / feat.norm(p=2, dim=-1, keepdim=True)
    return feat.cpu().float().numpy().flatten()


def _encode_text(model_type, model, processor, tokenizer, text, max_len=77):
    """Encode text -> normalized numpy vector. Truncates to max_len tokens."""
    if model_type == "openclip":
        tokens = tokenizer([text]).to(device)
        feat = model.encode_text(tokens)
    else:
        tok = getattr(processor, "tokenizer", processor)
        inputs = tok(
            [text], return_tensors="pt", padding=True,
            truncation=True, max_length=max_len,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        feat = model.get_text_features(**inputs)
    feat = _to_tensor(feat)
    feat = feat / feat.norm(p=2, dim=-1, keepdim=True)
    return feat.cpu().float().numpy().flatten()


# --- Batch encode functions (N images at once through GPU) ---

def _encode_images_batch(model_type, model, processor, images):
    """Encode a batch of images -> normalized numpy vectors (N, D)."""
    if model_type == "openclip":
        batch = torch.stack([processor(img) for img in images]).to(device)
        feat = model.encode_image(batch)
    else:
        inputs = processor(images=images, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        feat = model.get_image_features(**inputs)
    feat = _to_tensor(feat)
    feat = feat / feat.norm(p=2, dim=-1, keepdim=True)
    return feat.cpu().float().numpy()


def _encode_texts_batch(model_type, model, processor, tokenizer, texts, max_len=77):
    """Encode a batch of texts -> normalized numpy vectors (N, D)."""
    if model_type == "openclip":
        tokens = tokenizer(texts).to(device)
        feat = model.encode_text(tokens)
    else:
        tok = getattr(processor, "tokenizer", processor)
        inputs = tok(
            texts, return_tensors="pt", padding=True,
            truncation=True, max_length=max_len,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        feat = model.get_text_features(**inputs)
    feat = _to_tensor(feat)
    feat = feat / feat.norm(p=2, dim=-1, keepdim=True)
    return feat.cpu().float().numpy()


# Pre-create one CUDA stream per model for true GPU parallelism
_model_streams = {}
if device == "cuda":
    for name in MODELS:
        _model_streams[name] = torch.cuda.Stream()
    print(f"Created {len(_model_streams)} CUDA streams for parallel encoding")


def encode_combined_all(pil_img, description, image_weight=IMAGE_WEIGHT):
    """Encode 1 image+text with the 3 models in parallel (CUDA streams).
    Kept for backward compat / single-image use cases.
    """
    img = pil_img.convert("RGB")
    has_text = bool(description and description.strip())

    def encode_one(name, model_tuple):
        model_type, model, processor, tokenizer, max_len = model_tuple
        stream = _model_streams.get(name)
        stream_ctx = torch.cuda.stream(stream) if stream else nullcontext()
        with torch.no_grad(), stream_ctx:
            img_vec = _encode_image(model_type, model, processor, img)
            if has_text:
                txt_vec = _encode_text(model_type, model, processor, tokenizer, description, max_len)
                combined = image_weight * img_vec + (1 - image_weight) * txt_vec
                norm = np.linalg.norm(combined)
                if norm > 0:
                    combined = combined / norm
                result = combined
            else:
                result = img_vec
        if stream:
            stream.synchronize()
        return name, result

    with ThreadPoolExecutor(max_workers=3) as executor:
        futures = [executor.submit(encode_one, n, t) for n, t in MODELS.items()]
        return {f.result()[0]: f.result()[1] for f in futures}


def encode_combined_batch(images, texts, image_weight=IMAGE_WEIGHT):
    """Encode N images+texts with the 3 models in parallel (batch GPU + CUDA streams).

    Each model processes the whole batch of images in one pass (to saturate the GPU),
    and the 3 models run in parallel via CUDA streams + ThreadPoolExecutor.

    Args:
        images: list of PIL images
        texts: list of text strings (same length as images)
        image_weight: weight for image vectors (1-weight for text)

    Returns:
        list of dicts {model_name: numpy vector}
    """
    n = len(images)
    imgs = [img.convert("RGB") for img in images]
    has_texts = [bool(t and t.strip()) for t in texts]
    text_indices = [i for i in range(n) if has_texts[i]]
    batch_texts = [texts[i] for i in text_indices] if text_indices else []

    def encode_model(name, model_tuple):
        model_type, mdl, processor, tokenizer, max_len = model_tuple
        stream = _model_streams.get(name)
        stream_ctx = torch.cuda.stream(stream) if stream else nullcontext()
        with torch.no_grad(), stream_ctx:
            img_vecs = _encode_images_batch(model_type, mdl, processor, imgs)
            if text_indices:
                txt_vecs = _encode_texts_batch(model_type, mdl, processor, tokenizer, batch_texts, max_len)
                combined = img_vecs.copy()
                for j, idx in enumerate(text_indices):
                    c = image_weight * img_vecs[idx] + (1 - image_weight) * txt_vecs[j]
                    norm_val = np.linalg.norm(c)
                    if norm_val > 0:
                        c = c / norm_val
                    combined[idx] = c
                result = combined
            else:
                result = img_vecs
        if stream:
            stream.synchronize()
        return name, result

    with ThreadPoolExecutor(max_workers=3) as executor:
        futures = [executor.submit(encode_model, nm, t) for nm, t in MODELS.items()]
        all_vectors = {f.result()[0]: f.result()[1] for f in futures}

    return [{name: all_vectors[name][i] for name in all_vectors} for i in range(n)]


# Quick sanity check (single image)
_test_img = Image.new("RGB", (64, 64), "red")
_test_vecs = encode_combined_all(_test_img, "red test shirt")
for k, v in _test_vecs.items():
    print(f"  {k}: shape={v.shape}, dtype={v.dtype}")

# Sanity check (batch of 4)
_test_imgs = [Image.new("RGB", (64, 64), c) for c in ["red", "blue", "green", "black"]]
_test_texts = ["red shirt", "blue jeans", "green jacket", "black shoes"]
_test_batch = encode_combined_batch(_test_imgs, _test_texts)
print(f"  Batch: {len(_test_batch)} items, keys={list(_test_batch[0].keys())}")
for k in _test_batch[0]:
    print(f"    {k}: shape={_test_batch[0][k].shape}")

# Test detect_occasion
print("\nOccasion examples:")
print(f"  'running trainer' → {detect_occasion('Nike running trainer', 'shoes')}")
print(f"  'sequin dress' → {detect_occasion('ASOS sequin mini dress', 'dresses')}")
print(f"  'puffer jacket' → {detect_occasion('Puffer jacket in black', 'coats')}")
print(f"  'slim jeans' → {detect_occasion('Slim jeans in blue', 'jeans')}")
print(f"  'oxford shirt' → {detect_occasion('Oxford shirt white', 'shirts')}")
print("OK")
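
`detect_occasion` relies on plain substring checks, so a keyword can fire inside a longer word: `"rain"` is a substring of `"trainer"`, and `"suit"` of `"swimsuit"`. If those matches are unwanted, one option is to require each keyword to start at a word boundary. The following is a minimal sketch; `matches_keyword` is a hypothetical helper, not part of the notebook.

```python
import re


def matches_keyword(text, keywords):
    """True if any keyword occurs in `text` starting at a word boundary.

    Only the start is anchored, so plurals such as "trainers" still match "trainer".
    """
    pattern = r"\b(?:" + "|".join(re.escape(kw) for kw in keywords) + r")"
    return re.search(pattern, text.lower()) is not None


print("rain" in "nike running trainer")                      # plain substring check
print(matches_keyword("nike running trainer", ["rain"]))     # word-boundary check
print(matches_keyword("waterproof rain jacket", ["rain"]))
print(matches_keyword("white leather trainers", ["trainer"]))
```

Anchoring only the start of the keyword is a deliberate trade-off: it removes matches buried inside other words while still accepting simple suffixes.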

In [None]:
## Benchmark: single-image vs batch encoding
import time

N_BENCH = 32  # number of images for the benchmark
_bench_imgs = [Image.new("RGB", (384, 384), c) for c in ["blue", "red", "green", "black"] * (N_BENCH // 4)]
_bench_texts = [f"color {i} denim jacket casual style" for i in range(N_BENCH)]

# Warmup GPU
for _ in range(3):
    encode_combined_all(_bench_imgs[0], _bench_texts[0])
if device == "cuda":
    torch.cuda.synchronize()

# --- Single-image (one at a time, with CUDA streams) ---
if device == "cuda":
    torch.cuda.synchronize()
t0 = time.perf_counter()
for img, txt in zip(_bench_imgs, _bench_texts):
    encode_combined_all(img, txt)
if device == "cuda":
    torch.cuda.synchronize()
single_time = time.perf_counter() - t0

# --- Batch (all images at once) ---
if device == "cuda":
    torch.cuda.synchronize()
t0 = time.perf_counter()
encode_combined_batch(_bench_imgs, _bench_texts)
if device == "cuda":
    torch.cuda.synchronize()
batch_time = time.perf_counter() - t0

print(f"Benchmark on {N_BENCH} images ({device}):")
print(f"  Single-image : {single_time:.2f}s ({N_BENCH/single_time:.1f} img/s) — {single_time/N_BENCH*1000:.0f}ms/img")
print(f"  Batch (x{N_BENCH})  : {batch_time:.2f}s ({N_BENCH/batch_time:.1f} img/s) — {batch_time/N_BENCH*1000:.0f}ms/img")
print(f"  Speedup      : {single_time/batch_time:.1f}x")
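
The batch path benchmarked above hands the GPU many images per call, so the full task list has to be cut into chunks of `ENCODE_BATCH_SIZE`. One way to do that is sketched below; `iter_batches` is a hypothetical helper, and the notebook's own indexing loop may slice its tasks differently.

```python
def iter_batches(items, batch_size):
    """Yield consecutive slices of `items` with at most `batch_size` elements each."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


# 10 placeholder tasks in batches of 4: the last batch is smaller.
tasks = [f"img_{i}.jpg" for i in range(10)]
batch_sizes = [len(batch) for batch in iter_batches(tasks, 4)]
print(batch_sizes)
```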

## 9. Indexing (GPU batch encoding → Weaviate on GCP)

In [None]:
import time
from tqdm.auto import tqdm
from PIL import Image

collection = client.collections.get(COLLECTION_NAME)

# --- Checkpoint & resume ---
checkpoint = load_checkpoint()
start_product = 0
if not RECREATE and checkpoint["last_index"] >= 0:
    start_product = checkpoint["last_index"] + 1
    print(f"Resuming from product {start_product} ({checkpoint['indexed_count']} already indexed)")

# --- Fetch the product_ids already indexed (dedup) ---
existing_ids = set()
for obj in collection.iterator(return_properties=["product_id"]):
    pid = obj.properties.get("product_id", "")
    if pid:
        existing_ids.add(pid)
print(f"Already indexed: {len(existing_ids)} products, skipping those")

indexed = checkpoint["indexed_count"] if not RECREATE else 0
errors = checkpoint["errors"] if not RECREATE else 0
skipped = 0
already = 0
start_time = time.time()

# --- Collect encode tasks from local disk (no network!) ---
encode_tasks = []  # list of (filepath, meta)

pbar_prep = tqdm(total=len(dataset), desc="Preparation", initial=start_product)

for product_idx in range(start_product, len(dataset)):
    item = dataset[product_idx]
    product_id = str(int(item.get("sku", 0))) if item.get("sku") else str(product_idx)

    if product_id in existing_ids:
        already += 1
        pbar_prep.update(1)
        continue

    product_name = item.get("name", "")
    category = item.get("category", "")
    color = item.get("color", "")
    price = extract_price(item.get("price"))
    product_url = item.get("url", "")

    brand, desc_text = extract_description(item.get("description"))
    image_urls = extract_images(item.get("images"))
    gender = detect_gender(product_name or "", category or "")
    occasion = detect_occasion(product_name or "", category or "")
    description = build_description(item)

    if not image_urls:
        skipped += 1
        pbar_prep.update(1)
        continue

    for idx, url in enumerate(image_urls[:MAX_IMAGES_PER_PRODUCT]):
        filename = f"{product_id}_{idx}.jpg"
        filepath = os.path.join(IMAGE_CACHE_DIR, filename)

        if not os.path.exists(filepath):
            skipped += 1
            continue

        meta = {
            "filename": filename,
            "path": url,
            "product_id": product_id,
            "product_name": product_name or "",
            "category": category or "",
            "color": color or "",
            "price": price,
            "brand": brand or "",
            "product_url": product_url or "",
            "description": description,
            "image_index": idx,
            "gender": gender,
            "occasion": occasion,
            "_product_idx": product_idx,
        }
        encode_tasks.append((filepath, meta))

    pbar_prep.update(1)

pbar_prep.close()
print(f"{already} products already indexed (skipped)")
print(f"{len(encode_tasks)} images to encode from local disk")
print(f"Skipped (no image URL or file not downloaded): {skipped}")
print(f"Image/text weight: {IMAGE_WEIGHT:.1f} / {1 - IMAGE_WEIGHT:.1f}")
print("Vector = image + (category, color, occasion)")
print("BM25 = color, category, description, occasion")
print(f"Config: ENCODE_BATCH={ENCODE_BATCH_SIZE}, WEAVIATE_BATCH={WEAVIATE_BATCH_SIZE}")

# --- Phase 2: batch encode (pure GPU) + push to Weaviate ---
weaviate_queue = []
last_product_idx = start_product

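def flush_weaviate():
    """Push the queued documents to Weaviate, then save a checkpoint."""
    global weaviate_queue, indexed, errors, last_product_idx
    if not weaviate_queue:
        return
    with collection.batch.dynamic() as batch:
        for doc in weaviate_queue:
            # Keys with a leading underscore are internal, not Weaviate properties.
            vecs = doc.pop("_vectors")
            pidx = doc.pop("_product_idx", 0)
            batch.add_object(properties=doc, vector=vecs)
            last_product_idx = max(last_product_idx, pidx)
    # Objects rejected by the server are reported once the batch context exits.
    failed = len(collection.batch.failed_objects)
    indexed += len(weaviate_queue) - failed
    errors += failed
    save_checkpoint(last_product_idx, indexed, errors)
    weaviate_queue = []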


pbar = tqdm(total=len(encode_tasks), desc="Phase 2: batch encode + push (GPU)")

for batch_start in range(0, len(encode_tasks), ENCODE_BATCH_SIZE):
    batch_slice = encode_tasks[batch_start:batch_start + ENCODE_BATCH_SIZE]

    # Load images from local disk (instant, no network)
    batch_images = []
    batch_texts = []
    batch_metas = []

    for filepath, meta in batch_slice:
        try:
            img = Image.open(filepath).convert("RGB")
            w, h = img.size
            meta = meta.copy()
            meta["thumbnail_base64"] = make_thumbnail_b64(img)
            meta["width"] = w
            meta["height"] = h
            meta["indexed_at"] = datetime.now(timezone.utc).isoformat()

            # Text side of the vector: category + color + occasion (combined with the image)
            vec_text_parts = [
                meta.get("category", ""),
                meta.get("color", ""),
                meta.get("occasion", ""),
            ]
            vec_text = " ".join(p for p in vec_text_parts if p).strip()

            batch_images.append(img)
            batch_texts.append(vec_text)
            batch_metas.append(meta)
        except Exception as e:
            errors += 1
            if errors <= 10:
                tqdm.write(f"  Load error ({filepath}): {e}")

    if not batch_images:
        pbar.update(len(batch_slice))
        continue

    # Encode batch on GPU (3 models in parallel)
    try:
        batch_vectors = encode_combined_batch(batch_images, batch_texts)

        for meta, vectors in zip(batch_metas, batch_vectors):
            meta["_vectors"] = {k: v.tolist() for k, v in vectors.items()}
            weaviate_queue.append(meta)

            if len(weaviate_queue) >= WEAVIATE_BATCH_SIZE:
                flush_weaviate()
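    except Exception as e:
        # Report before incrementing: one failed batch can push the counter
        # past the 10-message cap and silence the very first failure.
        if errors <= 10:
            tqdm.write(f"  Batch encode/push error: {e}")
        errors += len(batch_images)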

    pbar.update(len(batch_slice))

# Flush remaining
flush_weaviate()

pbar.close()
elapsed = time.time() - start_time

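# `indexed` starts from the checkpoint count, so isolate this run's share.
indexed_before = checkpoint["indexed_count"] if not RECREATE else 0
indexed_now = indexed - indexed_before

print(f"\n{'='*50}")
print("Indexing complete!")
print(f"Products already present: {already}")
print(f"Images indexed this run: {indexed_now} (total: {indexed})")
print(f"Skipped (no image URL or file not downloaded): {skipped}")
print(f"Errors: {errors}")
print(f"Time: {elapsed:.0f}s ({indexed_now / max(elapsed, 1):.1f} images/sec)")
print(f"{'='*50}")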
clear_checkpoint()
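
The summary printed before Phase 2 describes each vector as a weighted mix of an image embedding and a text embedding (`IMAGE_WEIGHT` and `1 - IMAGE_WEIGHT`). `encode_combined_batch` is defined earlier in the notebook, so the following is only a minimal sketch of that idea, assuming both embeddings are L2-normalised before mixing and the result is re-normalised. The helper names `l2_normalize` and `combine_embeddings` are hypothetical.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length (returned unchanged if it is all zeros)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


def combine_embeddings(img_vec, txt_vec, image_weight=0.5):
    """Weighted sum of two unit vectors, re-normalised to unit length.

    Assumption: this mirrors the 'image + text' weighting printed above;
    the real encode_combined_batch may differ in its details.
    """
    img = l2_normalize(img_vec)
    txt = l2_normalize(txt_vec)
    mixed = [image_weight * i + (1 - image_weight) * t for i, t in zip(img, txt)]
    return l2_normalize(mixed)


# Toy 2-d example: orthogonal image and text directions, equal weights.
combined = combine_embeddings([1.0, 0.0], [0.0, 1.0], image_weight=0.5)
print(combined)
```

With equal weights the result sits halfway between the two directions. Raising `image_weight` pulls the combined vector toward the image embedding, which is the trade-off `IMAGE_WEIGHT` controls.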

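Phase 2 uses two independent batch sizes: `ENCODE_BATCH_SIZE` for the GPU and `WEAVIATE_BATCH_SIZE` for the pushes, with a queue in between and a final flush for the remainder. The sketch below isolates that pattern with plain Python lists. `chunked`, `run_pipeline`, and the `encode` / `push` callables are hypothetical stand-ins, not functions from this notebook.

```python
def chunked(items, size):
    """Yield consecutive slices of `items` holding at most `size` elements."""
    for start in range(0, len(items), size):
        yield items[start:start + size]


def run_pipeline(items, encode_batch_size, push_batch_size, encode, push):
    """Encode in chunks, buffer the results, and push whenever the buffer is full."""
    queue = []
    pushed = 0
    for chunk in chunked(items, encode_batch_size):
        for encoded in encode(chunk):
            queue.append(encoded)
            if len(queue) >= push_batch_size:
                push(queue)
                pushed += len(queue)
                queue = []
    if queue:  # final flush for the leftover items
        push(queue)
        pushed += len(queue)
    return pushed


# Toy run: 10 items, encoded 4 at a time, pushed 3 at a time.
batches = []
total = run_pipeline(
    list(range(10)),
    encode_batch_size=4,
    push_batch_size=3,
    encode=lambda chunk: [x * 2 for x in chunk],
    push=lambda queue: batches.append(list(queue)),
)
print(total, [len(b) for b in batches])
```

Keeping the two sizes separate lets the GPU batch be tuned for memory while the push batch is tuned for network round trips.
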
## 10. Verification

In [None]:
collection = client.collections.get(COLLECTION_NAME)
stats = collection.aggregate.over_all(total_count=True)
print(f"Collection '{COLLECTION_NAME}': {stats.total_count} documents")

## 11. Quick test

In [None]:
from weaviate.classes.query import MetadataQuery
from IPython.display import display, HTML

query = "black leather jacket"

collection = client.collections.get(COLLECTION_NAME)

# --- 1. Encode query with all 3 models ---
query_vectors = {}
for vec_name, (model_type, model, processor, tokenizer, max_len) in MODELS.items():
    with torch.no_grad():
        if model_type == "openclip":
            tokens = tokenizer([query]).to(device)
            features = model.encode_text(tokens)
        else:
            tok = getattr(processor, "tokenizer", processor)
            inputs = tok(
                [query], return_tensors="pt", padding=True,
                truncation=True, max_length=max_len,
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}
            features = model.get_text_features(**inputs)
        features = _to_tensor(features)
        features = features / features.norm(p=2, dim=-1, keepdim=True)
        query_vectors[vec_name] = features.cpu().float().numpy().flatten()

# --- 2. BM25 search ---
bm25_results = collection.query.bm25(
    query=query,
    query_properties=["color", "category", "description", "occasion"],
    limit=60,
    return_metadata=MetadataQuery(score=True),
)
print(f"BM25: {len(bm25_results.objects)} results")

# --- 3. near_vector per model ---
vector_results = {}
for vec_name, vec in query_vectors.items():
    results = collection.query.near_vector(
        near_vector=vec.tolist(),
        target_vector=vec_name,
        limit=60,
        return_metadata=MetadataQuery(distance=True),
    )
    vector_results[vec_name] = results
    print(f"{vec_name}: {len(results.objects)} results")

# --- 4. RRF Fusion (same as weaviate_client.py) ---
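def rrf_fusion(result_lists, limit=20, k=60):
    """Reciprocal Rank Fusion: sum 1 / (k + rank + 1) over every result list."""
    scores = {}
    result_map = {}
    for results in result_lists:
        for rank, obj in enumerate(results):
            # Keyed by product, so several images of one product pool their scores.
            key = obj.properties.get("product_id") or obj.properties.get("filename", "")
            scores[key] = scores.get(key, 0) + 1 / (k + rank + 1)
            if key not in result_map:
                result_map[key] = obj  # keep the first object seen for this key
    sorted_keys = sorted(scores, key=lambda x: scores[x], reverse=True)
    # Loop variable named `key`, not `k`, to avoid shadowing the RRF constant.
    return [(result_map[key], scores[key]) for key in sorted_keys[:limit]]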

all_lists = [bm25_results.objects]
for vec_name in query_vectors:
    all_lists.append(vector_results[vec_name].objects)

fused = rrf_fusion(all_lists, limit=20)
print(f"\nRRF Fusion: {len(all_lists)} sources → {len(fused)} results")

# --- 5. Display results ---
html = f'<h3>Hybrid Search: "{query}" (BM25 + 3 vectors + RRF)</h3>'
html += '<div style="display:flex; gap:10px; flex-wrap:wrap;">'
for obj, rrf_score in fused[:20]:
    p = obj.properties
    thumb = p.get("thumbnail_base64", "")
    name = p.get("product_name", "")[:45]
    color = p.get("color", "")
    category = p.get("category", "")
    occasion = p.get("occasion", "")[:30]
    price = p.get("price")
    price_str = f"£{price:.2f}" if price else ""
    html += f'''
    <div style="text-align:center; width:170px; border:1px solid #eee; padding:5px; border-radius:8px;">
        <img src="data:image/jpeg;base64,{thumb}" style="max-width:150px; max-height:150px;"/>
        <div style="font-size:11px; font-weight:bold;">{name}</div>
        <div style="font-size:10px; color:#666;">{color} | {category}</div>
        <div style="font-size:10px; color:#999;">{occasion}</div>
        <div style="font-size:11px; color:#333;">{price_str} — RRF: {rrf_score:.4f}</div>
    </div>'''
html += '</div>'
display(HTML(html))

# --- 6. Also show per-source results for comparison ---
for source_name, source_results in [("BM25", bm25_results.objects[:5])] + [
    (vn, vector_results[vn].objects[:5]) for vn in query_vectors
]:
    html = f'<h4>{source_name} top 5</h4><div style="display:flex; gap:8px; flex-wrap:wrap;">'
    for obj in source_results:
        p = obj.properties
        thumb = p.get("thumbnail_base64", "")
        name = p.get("product_name", "")[:35]
        color = p.get("color", "")
        html += f'''
        <div style="text-align:center; width:130px;">
            <img src="data:image/jpeg;base64,{thumb}" style="max-width:120px; max-height:120px;"/>
            <div style="font-size:10px;">{name}</div>
            <div style="font-size:10px; color:#888;">{color}</div>
        </div>'''
    html += '</div>'
    display(HTML(html))
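
To make the RRF scores shown in the result cards easier to read, here is a small worked example on plain string ids instead of Weaviate objects. It is a sketch rather than the notebook's `rrf_fusion`: the function name `rrf_scores` is hypothetical, and unlike `rrf_fusion` it counts each id at most once per list (first occurrence only), which is an assumption about the desired behaviour.

```python
def rrf_scores(ranked_lists, k=60):
    """Reciprocal Rank Fusion over lists of document ids, best match first."""
    scores = {}
    for ranked in ranked_lists:
        seen = set()
        for rank, doc_id in enumerate(ranked):
            if doc_id in seen:
                continue  # only the first (best) occurrence in a list counts
            seen.add(doc_id)
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank + 1)
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)


# Three toy rankings standing in for BM25 and two vector searches.
bm25 = ["A", "B", "C"]
clip = ["B", "A", "D"]
siglip = ["B", "D", "A"]

fused = rrf_scores([bm25, clip, siglip])
for doc_id, score in fused:
    print(f"{doc_id}: {score:.4f}")
```

`B` ranks first because it is near the top of all three lists (1/62 + 1/61 + 1/61), while `C` appears only once at rank 3 and scores 1/63. With `k=60` the score differences between adjacent ranks are small, so agreement across sources matters more than a single first place.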

In [None]:
_keepalive_stop.set()  # Stop the keep-alive thread
client.close()
print("Weaviate connection closed.")