# Fashion Search: Indexing DeepFashion In-Shop from Colab (GPU)

This notebook indexes the [DeepFashion In-Shop](https://huggingface.co/datasets/Marqo/deepfashion-inshop) dataset into Weaviate with **2 named vectors**:
- **fashion_clip** (512d): `patrickjohncyh/fashion-clip`
- **marqo_clip** (512d): `Marqo/marqo-fashionCLIP`

**Features:**
- **Zero-shot tagging**: Fashion CLIP tags each image (garment type, material and cut, occasion, vibe, gender) to enrich the BM25 descriptions
- **Batch GPU encoding**: each model encodes N images per forward pass (instead of one at a time)
- The 2 models run in parallel via CUDA streams + ThreadPoolExecutor
- Combined image + text embeddings (0.5 × image + 0.5 × tag-enriched text, then L2-renormalized)
- BM25 search with field boosts (product_name^3, category^2, color^2, description, occasion)
- Checkpointing with automatic resume
- Weaviate keep-alive (ping every 2 min)
- **No download phase**: images are embedded in the Hugging Face dataset and loaded directly as PIL images
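For each model, the stored vector blends the normalized image embedding $\mathbf{i}$ with the normalized embedding $\mathbf{t}$ of the tag-enriched text, using the weight $w$ = `IMAGE_WEIGHT` (0.5), then L2-renormalizes the result so that cosine distance in the HNSW index stays well defined. An item with no text keeps $\mathbf{i}$ alone:

$$
\mathbf{v} = \frac{w\,\mathbf{i} + (1-w)\,\mathbf{t}}{\left\lVert w\,\mathbf{i} + (1-w)\,\mathbf{t} \right\rVert_2}, \qquad \lVert \mathbf{i} \rVert_2 = \lVert \mathbf{t} \rVert_2 = 1
$$

With $w = 0.5$ and unit inputs, $\mathbf{v}$ points along the bisector of $\mathbf{i}$ and $\mathbf{t}$.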

**Architecture:**
- **Colab (GPU)**: encodes the images in batches with both models → pushes the vectors to Weaviate
- **GCP (CPU)**: Weaviate + a FastAPI app that serves the searches
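The search side is not part of this notebook. The following is a minimal sketch of the two kinds of queries the FastAPI app could send to the collection with the weaviate-client v4 query API: BM25 with the field boosts listed above, and nearest-neighbour search on one named vector. The helper names (`boosted_properties`, `keyword_search`, `vector_search`) are hypothetical, not the app's actual code.

```python
# Hypothetical helpers sketching the queries a FastAPI search endpoint could run
# against this collection (weaviate-client v4 style). Not the app's actual code.

BM25_BOOSTS = {
    "product_name": 3,
    "category": 2,
    "color": 2,
    "description": 1,
    "occasion": 1,
}


def boosted_properties(boosts):
    """Turn {"field": weight} into Weaviate's "field^weight" syntax (weight 1 = no suffix)."""
    return [name if w == 1 else f"{name}^{w}" for name, w in boosts.items()]


def keyword_search(collection, query, limit=20):
    # BM25 over the text properties, with per-field boosts
    return collection.query.bm25(
        query=query,
        query_properties=boosted_properties(BM25_BOOSTS),
        limit=limit,
    )


def vector_search(collection, query_vector, model="fashion_clip", limit=20):
    # Nearest-neighbour search on one named vector ("fashion_clip" or "marqo_clip").
    # query_vector must come from the same model's text or image encoder.
    return collection.query.near_vector(
        near_vector=list(query_vector),
        target_vector=model,
        limit=limit,
    )
```

With a live client this would be called as `keyword_search(client.collections.get("FashionCollection"), "black evening dress")`. Results from the two queries could then be merged by the app (or Weaviate's hybrid search used instead), which is a separate design decision.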

**Dataset:** 52.6k studio product images (224×224), ~7-8k unique products after deduplication.

**Prerequisites:**
- GPU runtime enabled (Runtime → Change runtime type → T4 GPU)
- Weaviate reachable on your GCP VM (ports 8080 and 50051 open in the firewall)

## 1. Install dependencies

In [None]:
# Install deps (without touching Colab's pre-installed torch)
# The version spec is quoted: unquoted, the shell treats ">=4.0" as an output redirect
!pip install -q "weaviate-client>=4.0" transformers Pillow datasets requests tqdm timm ftfy regex
!pip install -q --no-deps open-clip-torch

## 2. Configuration

Set the external IP of the GCP VM where Weaviate runs.

**Important**: ports `8080` (HTTP) and `50051` (gRPC) must both be open in the GCP firewall.
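Before connecting the client, it can help to check from Colab that both ports actually accept TCP connections, so a firewall problem shows up as a clear message rather than a client timeout. This is a minimal sketch using only the standard library; `port_is_open` is a hypothetical helper, not part of the notebook.

```python
# Pre-flight check (hypothetical helper): confirm that the Weaviate HTTP and gRPC
# ports on the VM accept TCP connections from Colab.
import socket


def port_is_open(host, port, timeout=3.0):
    """Return True if a TCP connection to host:port succeeds within `timeout` seconds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:  # refused, unreachable, or timed out (socket.timeout is an OSError)
        return False


# Usage, once GCP_EXTERNAL_IP is set in the configuration cell:
# for port in (8080, 50051):
#     print(port, "open" if port_is_open(GCP_EXTERNAL_IP, port) else "BLOCKED (check the firewall rule)")
```

A successful TCP connect only proves the port is reachable; `client.is_ready()` in section 5 remains the real health check.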

In [None]:
# --- CONFIGURATION ---
GCP_EXTERNAL_IP = ""   # ex: "34.56.78.90"

WEAVIATE_HTTP_PORT = 8080
WEAVIATE_GRPC_PORT = 50051

COLLECTION_NAME = "FashionCollection"

# Dataset
DATASET_NAME = "Marqo/deepfashion-inshop"
MAX_VIEWS_PER_PRODUCT = 1          # 1 view per product (front preferred); the dedup cell always keeps exactly 1

# Models (2 fashion-specialized models)
MODEL_FASHION_CLIP = "patrickjohncyh/fashion-clip"
MODEL_MARQO_CLIP = "Marqo/marqo-fashionCLIP"

# Named vector keys
VECTOR_FASHION_CLIP = "fashion_clip"
VECTOR_MARQO_CLIP = "marqo_clip"
ALL_VECTOR_NAMES = [VECTOR_FASHION_CLIP, VECTOR_MARQO_CLIP]

# --- Tuning ---
MAX_ITEMS = None              # None = whole dataset
WEAVIATE_BATCH_SIZE = 500     # objects per flush to Weaviate
ENCODE_BATCH_SIZE = 128       # images per GPU batch (128 for A100, 32 for T4)
THUMBNAIL_SIZE = 150          # max thumbnail side, in pixels
IMAGE_WEIGHT = 0.5            # 50/50 image vs text
RECREATE = False              # False = resume, True = drop and recreate everything

# Fictitious brands (real brand names, not taken from the dataset)
BRANDS = [
    # Sportswear & Streetwear
    "Nike", "Adidas", "Puma", "New Balance", "The North Face",
    # Premium & Designers
    "Calvin Klein", "Tommy Hilfiger", "Polo Ralph Lauren", "BOSS", "Armani Exchange",
    # Casual & Denim
    "Levi's", "Topshop", "Dr Martens", "Carhartt WIP", "Pull&Bear",
    # Accessible & Trendy
    "Mango", "Stradivarius", "Bershka", "Superdry", "AllSaints",
]

# Random prices per category (min, max in £)
PRICE_RANGES = {
    "denim": (30, 90),
    "jackets": (50, 200),
    "tops": (15, 60),
    "sweaters": (30, 100),
    "shorts": (20, 60),
    "pants": (30, 90),
    "skirts": (25, 80),
    "dresses": (35, 150),
    "rompers": (30, 80),
    "swimwear": (20, 70),
    "shoes": (40, 180),
    "bags": (30, 250),
    "accessories": (10, 80),
    "intimates": (10, 50),
    "jumpsuits": (40, 120),
    "suiting": (80, 250),
}

# Checkpoint (automatic resume)
CHECKPOINT_FILE = "/content/index_checkpoint.json"
KEEPALIVE_INTERVAL = 120      # seconds
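The cell above only declares `BRANDS` and `PRICE_RANGES`; the code that consumes them is not shown here. One possible way to turn them into fake catalogue metadata is sketched below, with a per-product seeded RNG so that a resumed run assigns the same brand and price to the same product. The helper name, the case-insensitive category lookup, and the `(20, 100)` fallback range are all assumptions.

```python
import random
import zlib


def fake_brand_and_price(product_id, category, brands, price_ranges, default_range=(20, 100)):
    """Return a (brand, price) pair that is random-looking but stable for a given product_id."""
    # crc32 is stable across runs, unlike Python's salted hash(), so a resumed run
    # assigns the same brand and price to the same product.
    rng = random.Random(zlib.crc32(product_id.encode("utf-8")))
    brand = rng.choice(brands)
    # Assumption: category names are matched case-insensitively against the
    # PRICE_RANGES keys; anything unmapped falls back to default_range.
    low, high = price_ranges.get(category.lower(), default_range)
    return brand, round(rng.uniform(low, high), 2)
```

Usage: `brand, price = fake_brand_and_price(product_id, category, BRANDS, PRICE_RANGES)`.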

## 3. GPU check

In [None]:
import torch

if torch.cuda.is_available():
    device = "cuda"
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
else:
    device = "cpu"
    print("Pas de GPU, l'indexation sera lente.")

print(f"Device: {device}")

## 4. Load the 2 models

In [None]:
from transformers import CLIPModel, CLIPProcessor
import open_clip

# Fashion CLIP (512d): Hugging Face transformers format
print(f"Loading {MODEL_FASHION_CLIP}...")
fc_model = CLIPModel.from_pretrained(MODEL_FASHION_CLIP).to(device).eval()
fc_processor = CLIPProcessor.from_pretrained(MODEL_FASHION_CLIP)

# Marqo FashionCLIP (512d): OpenCLIP format
print(f"Loading {MODEL_MARQO_CLIP}...")
mq_model, _, mq_preprocess = open_clip.create_model_and_transforms(
    f"hf-hub:{MODEL_MARQO_CLIP}", device=device
)
mq_tokenizer = open_clip.get_tokenizer(f"hf-hub:{MODEL_MARQO_CLIP}")
mq_model = mq_model.eval()

# Unified dict: (type, model, processor/preprocess, tokenizer, max_text_tokens)
MODELS = {
    VECTOR_FASHION_CLIP: ("hf", fc_model, fc_processor, None, 77),
    VECTOR_MARQO_CLIP: ("openclip", mq_model, mq_preprocess, mq_tokenizer, 77),
}

if device == "cuda":
    vram_used = torch.cuda.memory_allocated() / 1024**3
    vram_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"\n2 modeles charges sur {device} — VRAM: {vram_used:.1f}/{vram_total:.0f} GB")
else:
    print(f"\n2 modeles charges sur {device}.")

## 5. Connect to Weaviate on GCP

In [None]:
import weaviate
import json
import os
import threading
from datetime import datetime, timezone
from weaviate.classes.config import Configure, DataType, Property, VectorDistances

assert GCP_EXTERNAL_IP, "Renseigne GCP_EXTERNAL_IP dans la cellule de configuration."

client = weaviate.connect_to_custom(
    http_host=GCP_EXTERNAL_IP,
    http_port=WEAVIATE_HTTP_PORT,
    http_secure=False,
    grpc_host=GCP_EXTERNAL_IP,
    grpc_port=WEAVIATE_GRPC_PORT,
    grpc_secure=False,
)

print(f"Connecte a Weaviate sur {GCP_EXTERNAL_IP}" if client.is_ready() else "Echec de connexion")

# --- Keep-alive (ping toutes les 2 min, auto-reconnect) ---
_keepalive_stop = threading.Event()

def start_keepalive(wv_client, interval=KEEPALIVE_INTERVAL):
    def _run():
        while not _keepalive_stop.wait(interval):
            try:
                if not wv_client.is_ready():
                    raise ConnectionError("not ready")
            except Exception as e:
                print(f"\n[keep-alive] Reconnexion... ({e})")
                try:
                    wv_client.close()
                except Exception:
                    pass
                wv_client.connect_to_custom(
                    http_host=GCP_EXTERNAL_IP,
                    http_port=WEAVIATE_HTTP_PORT,
                    http_secure=False,
                    grpc_host=GCP_EXTERNAL_IP,
                    grpc_port=WEAVIATE_GRPC_PORT,
                    grpc_secure=False,
                )
    t = threading.Thread(target=_run, daemon=True)
    t.start()
    return t

start_keepalive(client)
print("Keep-alive demarre (ping toutes les 2 min)")

# --- Checkpoint helpers ---
def load_checkpoint():
    try:
        if os.path.exists(CHECKPOINT_FILE):
            with open(CHECKPOINT_FILE) as f:
                return json.load(f)
    except Exception:
        pass
    return {"last_index": -1, "indexed_count": 0, "errors": 0}

def save_checkpoint(last_index, indexed_count, errors):
    with open(CHECKPOINT_FILE, "w") as f:
        json.dump({"last_index": last_index, "indexed_count": indexed_count,
                    "errors": errors, "timestamp": datetime.now(timezone.utc).isoformat()}, f)

def clear_checkpoint():
    if os.path.exists(CHECKPOINT_FILE):
        os.remove(CHECKPOINT_FILE)
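`save_checkpoint` writes directly to the target file, so a Colab disconnect in the middle of `json.dump` can leave a truncated file, which `load_checkpoint` then silently treats as "start from scratch". A common remedy, sketched here as an assumption rather than the notebook's helper, is to write to a temporary file in the same directory and rename it over the target with `os.replace`.

```python
import json
import os
import tempfile
from datetime import datetime, timezone


def save_checkpoint_atomic(path, last_index, indexed_count, errors):
    state = {
        "last_index": last_index,
        "indexed_count": indexed_count,
        "errors": errors,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    directory = os.path.dirname(os.path.abspath(path))
    # Write to a temp file in the same directory, then rename it over the target:
    # os.replace is atomic when both paths are on the same filesystem.
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(state, f)
        os.replace(tmp_path, path)
    except BaseException:
        os.unlink(tmp_path)  # never leave a half-written temp file behind
        raise
```

It takes the same arguments as `save_checkpoint` plus the path, so `load_checkpoint` can read its output unchanged.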

## 6. Create the schema

In [None]:
def create_schema(wv_client, collection_name, vector_names=None):
    """Crée la collection avec named vectors."""
    if vector_names is None:
        vector_names = ALL_VECTOR_NAMES

    named_vectors = [
        Configure.NamedVectors.none(
            name=name,
            vector_index_config=Configure.VectorIndex.hnsw(
                distance_metric=VectorDistances.COSINE
            ),
        )
        for name in vector_names
    ]

    wv_client.collections.create(
        name=collection_name,
        properties=[
            Property(name="filename", data_type=DataType.TEXT),
            Property(name="path", data_type=DataType.TEXT),
            Property(name="thumbnail_base64", data_type=DataType.TEXT),
            Property(name="width", data_type=DataType.INT),
            Property(name="height", data_type=DataType.INT),
            Property(name="indexed_at", data_type=DataType.TEXT),
            Property(name="product_id", data_type=DataType.TEXT),
            Property(name="product_name", data_type=DataType.TEXT),
            Property(name="category", data_type=DataType.TEXT),
            Property(name="color", data_type=DataType.TEXT),
            Property(name="size", data_type=DataType.TEXT),
            Property(name="price", data_type=DataType.NUMBER),
            Property(name="brand", data_type=DataType.TEXT),
            Property(name="product_url", data_type=DataType.TEXT),
            Property(name="description", data_type=DataType.TEXT),
            Property(name="image_index", data_type=DataType.INT),
            Property(name="gender", data_type=DataType.TEXT),
            Property(name="occasion", data_type=DataType.TEXT),
        ],
        vectorizer_config=named_vectors,
    )
    print(f"Collection '{collection_name}' created with vectors: {vector_names}")


if RECREATE:
    if client.collections.exists(COLLECTION_NAME):
        client.collections.delete(COLLECTION_NAME)
        print(f"Collection '{COLLECTION_NAME}' deleted.")
    create_schema(client, COLLECTION_NAME)
    clear_checkpoint()
elif not client.collections.exists(COLLECTION_NAME):
    create_schema(client, COLLECTION_NAME)
    clear_checkpoint()
else:
    print(f"Collection '{COLLECTION_NAME}' already exists, keeping its data.")
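Because the collection has no vectorizer, every object must carry both named vectors when it is inserted. The sketch below (hypothetical helpers `product_uuid` and `push_objects`) shows one way to do that with the weaviate-client v4 batch API. Deriving the UUID from `product_id` is an assumption chosen so that re-pushing a batch after a resume overwrites the same objects instead of duplicating them; `weaviate.util.generate_uuid5` implements the same idea.

```python
import uuid


def product_uuid(product_id):
    # Deterministic UUIDv5 derived from the product id
    return str(uuid.uuid5(uuid.NAMESPACE_URL, f"deepfashion-inshop/{product_id}"))


def push_objects(collection, rows, batch_size=500):
    """rows: iterable of (properties: dict, vectors: {"fashion_clip": [...], "marqo_clip": [...]})."""
    with collection.batch.fixed_size(batch_size=batch_size) as batch:
        for props, vectors in rows:
            batch.add_object(
                properties=props,
                # One entry per named vector; plain floats work for numpy arrays too
                vector={name: [float(x) for x in vec] for name, vec in vectors.items()},
                uuid=product_uuid(props["product_id"]),
            )
    # Objects rejected by the server are collected on the batch manager
    return collection.batch.failed_objects
```

With a live client: `failed = push_objects(client.collections.get(COLLECTION_NAME), rows, WEAVIATE_BATCH_SIZE)`.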

## 7. Load the DeepFashion In-Shop dataset

In [None]:
from datasets import load_dataset
from collections import Counter
import random

print(f"Chargement du dataset {DATASET_NAME}...")
ds = load_dataset(DATASET_NAME, split="data")
print(f"Dataset chargé: {len(ds)} images totales")

# --- Deduplicate per product (keep 1 view, front first) ---
# item_ID format: "MEN_Denim_id_00000080_01_1_front"
# product_id = everything except the last 2 segments (view_number + view_type)
VIEW_PRIORITY = {"front": 0, "full": 1, "additional": 2, "side": 3, "back": 4}

product_views = {}  # product_id -> (priority, index_in_dataset)

# Read the item_ID column directly: ds[i]["item_ID"] would also decode row i's image
item_ids = list(ds["item_ID"])
for i, item_id in enumerate(item_ids):
    parts = item_id.rsplit("_", 2)  # split from right, max 2 splits
    product_id = parts[0]           # everything before last 2 segments
    view_type = parts[-1] if len(parts) > 1 else "unknown"
    priority = VIEW_PRIORITY.get(view_type, 99)

    if product_id not in product_views or priority < product_views[product_id][0]:
        product_views[product_id] = (priority, i)

# Build deduplicated list of dataset indices
selected_indices = [idx for _, idx in product_views.values()]
selected_indices.sort()

if MAX_ITEMS:
    selected_indices = selected_indices[:MAX_ITEMS]

print(f"Produits uniques: {len(selected_indices)} (sur {len(ds)} images)")
print(f"Vue sélectionnée par priorité: front > full > additional > side > back")

# Stats par catégorie
cat_counts = Counter(ds[i]["category2"] for i in selected_indices)
print(f"\nCatégories ({len(cat_counts)}):")
for cat, count in cat_counts.most_common():
    print(f"  {cat}: {count}")

## 8. Utility functions

In [None]:
import base64
import io
from contextlib import nullcontext
from concurrent.futures import ThreadPoolExecutor

import numpy as np
from PIL import Image


# =====================================================================
# Encode functions (single image + batch)
# =====================================================================

def _to_tensor(feat):
    """Extract tensor from model output (handles both raw tensors and BaseModelOutput)."""
    if isinstance(feat, torch.Tensor):
        return feat
    if hasattr(feat, "pooler_output") and feat.pooler_output is not None:
        return feat.pooler_output
    if hasattr(feat, "last_hidden_state"):
        return feat.last_hidden_state[:, 0, :]
    raise ValueError(f"Cannot extract tensor from {type(feat)}")


def _encode_image(model_type, model, processor, img):
    """Encode image -> normalized numpy vector."""
    if model_type == "openclip":
        image_tensor = processor(img).unsqueeze(0).to(device)
        feat = model.encode_image(image_tensor)
    else:
        inputs = processor(images=img, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        feat = model.get_image_features(**inputs)
    feat = _to_tensor(feat)
    feat = feat / feat.norm(p=2, dim=-1, keepdim=True)
    return feat.detach().cpu().float().numpy().flatten()


def _encode_text(model_type, model, processor, tokenizer, text, max_len=77):
    """Encode text -> normalized numpy vector."""
    if model_type == "openclip":
        tokens = tokenizer([text]).to(device)
        feat = model.encode_text(tokens)
    else:
        tok = getattr(processor, "tokenizer", processor)
        inputs = tok(
            [text], return_tensors="pt", padding=True,
            truncation=True, max_length=max_len,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        feat = model.get_text_features(**inputs)
    feat = _to_tensor(feat)
    feat = feat / feat.norm(p=2, dim=-1, keepdim=True)
    return feat.detach().cpu().float().numpy().flatten()


def _encode_images_batch(model_type, model, processor, images):
    """Encode a batch of images -> normalized numpy vectors (N, D)."""
    if model_type == "openclip":
        batch = torch.stack([processor(img) for img in images]).to(device)
        feat = model.encode_image(batch)
    else:
        inputs = processor(images=images, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        feat = model.get_image_features(**inputs)
    feat = _to_tensor(feat)
    feat = feat / feat.norm(p=2, dim=-1, keepdim=True)
    return feat.detach().cpu().float().numpy()


def _encode_texts_batch(model_type, model, processor, tokenizer, texts, max_len=77):
    """Encode a batch of texts -> normalized numpy vectors (N, D)."""
    if model_type == "openclip":
        tokens = tokenizer(texts).to(device)
        feat = model.encode_text(tokens)
    else:
        tok = getattr(processor, "tokenizer", processor)
        inputs = tok(
            texts, return_tensors="pt", padding=True,
            truncation=True, max_length=max_len,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        feat = model.get_text_features(**inputs)
    feat = _to_tensor(feat)
    feat = feat / feat.norm(p=2, dim=-1, keepdim=True)
    return feat.detach().cpu().float().numpy()
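Every encoder above divides by the L2 norm before returning. Once vectors have unit norm, the plain dot product used later in `tag_images_batch` (`fc_image_vecs @ lab_vecs.T`) equals cosine similarity, which is also the metric the HNSW indexes are configured with. The short numpy check below illustrates this; `l2_normalize` and `cosine_matrix` are illustrative helpers, not part of the notebook.

```python
import numpy as np


def l2_normalize(x, eps=1e-12):
    # Same operation as feat / feat.norm(p=2, dim=-1, keepdim=True), in numpy
    x = np.asarray(x, dtype=np.float32)
    return x / np.maximum(np.linalg.norm(x, axis=-1, keepdims=True), eps)


def cosine_matrix(a, b):
    # Reference cosine similarity, computed without assuming normalized inputs
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    return (a @ b.T) / (np.linalg.norm(a, axis=1)[:, None] * np.linalg.norm(b, axis=1)[None, :])
```

`l2_normalize(a) @ l2_normalize(b).T` and `cosine_matrix(a, b)` agree up to float32 rounding.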


# =====================================================================
# Zero-shot tagging with Fashion CLIP
# =====================================================================

LABEL_TAXONOMY = {
    "garment": [
        "t-shirt", "blouse", "shirt", "crop top", "tank top", "camisole",
        "polo shirt", "henley", "tunic", "bodysuit",
        "jeans", "trousers", "shorts", "leggings", "chinos", "joggers",
        "wide leg pants", "cargo pants", "culottes",
        "mini dress", "midi dress", "maxi dress", "shirt dress", "wrap dress",
        "cocktail dress", "evening gown", "sundress",
        "jumpsuit", "romper", "playsuit",
        "blazer", "leather jacket", "denim jacket", "bomber jacket",
        "puffer jacket", "trench coat", "overcoat", "parka",
        "cardigan", "hoodie", "sweater", "pullover", "vest",
        "mini skirt", "midi skirt", "maxi skirt", "pleated skirt", "pencil skirt",
        "bikini", "swimsuit", "bra", "lingerie",
        "handbag", "backpack", "tote bag", "clutch bag", "crossbody bag",
        "sneakers", "boots", "ankle boots", "heels", "sandals", "loafers",
        "suit", "tuxedo",
    ],
    "description": [
        # Materials
        "leather", "faux leather", "suede", "denim", "cotton", "linen",
        "silk", "satin", "chiffon", "velvet", "wool", "cashmere",
        "knit", "mesh", "lace", "tulle", "sequin", "metallic",
        "corduroy", "tweed", "jersey", "fleece",
        # Cuts & details
        "oversized", "slim fit", "fitted", "relaxed fit", "cropped",
        "high waisted", "flared", "skinny", "pleated", "ruffled",
        "embroidered", "belted", "hooded", "collared",
        "v-neck", "crew neck", "turtleneck", "off-shoulder", "strapless",
        "sleeveless", "long sleeve", "short sleeve",
        # Patterns
        "floral print", "striped", "plaid", "polka dot", "animal print",
        # Style
        "minimalist", "vintage", "bohemian", "elegant", "edgy", "sporty",
    ],
    "occasion": [
        # Formal / evening
        "evening party", "cocktail party", "formal dinner", "gala event",
        "black tie event", "wedding guest", "prom night", "awards ceremony",
        "elegant night out", "fancy restaurant",
        # Office / professional
        "office work", "business meeting", "job interview", "conference",
        "professional setting",
        # Casual / everyday
        "casual everyday", "weekend outing", "brunch with friends",
        "shopping trip", "coffee date", "running errands",
        # Sport / outdoor
        "gym workout", "outdoor hiking", "yoga session", "running jogging",
        "athleisure streetwear",
        # Vacation / leisure
        "beach vacation", "summer festival", "travel", "resort wear",
        "garden party", "picnic outdoor",
        # Date / romantic
        "date night", "romantic dinner",
    ],
    "vibe": [
        "elegant sophisticated", "edgy bold", "romantic feminine",
        "sporty athletic", "classic timeless", "modern minimalist",
        "bohemian free-spirited", "glamorous luxurious",
        "casual relaxed", "professional polished", "playful fun",
        "streetwear cool", "chic refined", "grunge rebellious",
    ],
    "gender": [
        "male model",
        "female model",
    ],
}

TAG_TOP_K = {
    "garment": 3,
    "description": 4,
    "occasion": 5,
    "vibe": 2,
    "gender": 1,
}
TAG_MIN_SIMILARITY = 0.15

LABEL_PROMPTS = {
    "garment": "a photo of {}",
    "description": "a photo of {} clothing",
    "occasion": "outfit suitable for {}",
    "vibe": "a {} fashion style",
    "gender": "a photo of a {}",
}

GARMENT_OCCASION_MAP = {
    # Dresses
    "cocktail dress": "evening party, cocktail party, formal dinner, date night, elegant night out",
    "evening gown": "gala event, black tie event, awards ceremony, prom night, elegant night out",
    "mini dress": "evening party, date night, brunch with friends, cocktail party, weekend outing",
    "midi dress": "office work, wedding guest, brunch with friends, date night, garden party",
    "maxi dress": "beach vacation, garden party, wedding guest, summer festival, resort wear",
    "shirt dress": "office work, brunch with friends, casual everyday, shopping trip, conference",
    "wrap dress": "office work, date night, wedding guest, brunch with friends, professional setting",
    "sundress": "beach vacation, summer festival, garden party, picnic outdoor, weekend outing",
    # Tops
    "t-shirt": "casual everyday, weekend outing, running errands, shopping trip, travel",
    "blouse": "office work, business meeting, brunch with friends, date night, professional setting",
    "shirt": "office work, business meeting, casual everyday, conference, job interview",
    "crop top": "summer festival, beach vacation, casual everyday, weekend outing, date night",
    "tank top": "gym workout, beach vacation, casual everyday, summer festival, running jogging",
    "camisole": "date night, evening party, casual everyday, romantic dinner, elegant night out",
    "polo shirt": "casual everyday, weekend outing, office work, brunch with friends, travel",
    "henley": "casual everyday, weekend outing, coffee date, running errands, travel",
    "tunic": "casual everyday, beach vacation, brunch with friends, weekend outing, resort wear",
    "bodysuit": "evening party, date night, cocktail party, casual everyday, elegant night out",
    # Trousers
    "jeans": "casual everyday, weekend outing, shopping trip, brunch with friends, coffee date",
    "trousers": "office work, business meeting, professional setting, conference, job interview",
    "shorts": "beach vacation, casual everyday, summer festival, weekend outing, gym workout",
    "leggings": "gym workout, yoga session, running jogging, athleisure streetwear, casual everyday",
    "chinos": "office work, casual everyday, brunch with friends, weekend outing, business meeting",
    "joggers": "casual everyday, athleisure streetwear, running errands, gym workout, travel",
    "wide leg pants": "office work, brunch with friends, casual everyday, weekend outing, travel",
    "cargo pants": "casual everyday, outdoor hiking, travel, weekend outing, running errands",
    "culottes": "office work, brunch with friends, casual everyday, weekend outing, shopping trip",
    # Jumpsuits & rompers
    "jumpsuit": "evening party, date night, wedding guest, brunch with friends, cocktail party",
    "romper": "beach vacation, summer festival, casual everyday, weekend outing, garden party",
    "playsuit": "beach vacation, summer festival, weekend outing, garden party, casual everyday",
    # Jackets & coats
    "blazer": "office work, business meeting, job interview, conference, professional setting",
    "leather jacket": "date night, evening party, casual everyday, weekend outing, concert",
    "denim jacket": "casual everyday, weekend outing, shopping trip, brunch with friends, travel",
    "bomber jacket": "casual everyday, athleisure streetwear, weekend outing, travel, shopping trip",
    "puffer jacket": "outdoor hiking, casual everyday, travel, running errands, weekend outing",
    "trench coat": "office work, business meeting, casual everyday, professional setting, travel",
    "overcoat": "office work, formal dinner, business meeting, professional setting, elegant night out",
    "parka": "outdoor hiking, casual everyday, travel, running errands, weekend outing",
    # Knitwear
    "cardigan": "office work, casual everyday, weekend outing, coffee date, brunch with friends",
    "hoodie": "casual everyday, weekend outing, athleisure streetwear, running errands, gym workout",
    "sweater": "casual everyday, weekend outing, office work, coffee date, travel",
    "pullover": "casual everyday, weekend outing, office work, coffee date, travel",
    "vest": "casual everyday, outdoor hiking, office work, weekend outing, travel",
    # Skirts
    "mini skirt": "evening party, date night, weekend outing, brunch with friends, casual everyday",
    "midi skirt": "office work, brunch with friends, wedding guest, professional setting, date night",
    "maxi skirt": "beach vacation, casual everyday, wedding guest, garden party, resort wear",
    "pleated skirt": "office work, brunch with friends, wedding guest, casual everyday, date night",
    "pencil skirt": "office work, business meeting, professional setting, conference, job interview",
    # Swimwear & lingerie
    "bikini": "beach vacation, resort wear, summer festival, picnic outdoor",
    "swimsuit": "beach vacation, resort wear, summer festival",
    "bra": "casual everyday",
    "lingerie": "romantic dinner, date night",
    # Bags
    "handbag": "office work, shopping trip, casual everyday, brunch with friends, date night",
    "backpack": "travel, casual everyday, outdoor hiking, running errands, weekend outing",
    "tote bag": "office work, shopping trip, beach vacation, casual everyday, brunch with friends",
    "clutch bag": "evening party, cocktail party, formal dinner, date night, gala event",
    "crossbody bag": "casual everyday, travel, shopping trip, weekend outing, running errands",
    # Shoes
    "sneakers": "casual everyday, athleisure streetwear, weekend outing, travel, running errands",
    "boots": "casual everyday, weekend outing, outdoor hiking, date night, travel",
    "ankle boots": "casual everyday, date night, weekend outing, office work, brunch with friends",
    "heels": "evening party, cocktail party, formal dinner, date night, office work",
    "sandals": "beach vacation, casual everyday, summer festival, garden party, resort wear",
    "loafers": "office work, casual everyday, brunch with friends, professional setting, weekend outing",
    # Suits
    "suit": "business meeting, job interview, formal dinner, conference, professional setting",
    "tuxedo": "gala event, black tie event, awards ceremony, formal dinner, prom night",
}


def occasions_from_garment_tags(garment_tags):
    """Derive occasions from garment tags using domain mapping."""
    scores = {}
    for tag in garment_tags:
        occs = GARMENT_OCCASION_MAP.get(tag, "")
        if occs:
            for i, occ in enumerate(occs.split(", ")):
                scores[occ] = scores.get(occ, 0) + 1.0 / (i + 1)
    return sorted(scores, key=scores.get, reverse=True)[:5]
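`occasions_from_garment_tags` scores occasions by reciprocal rank: the occasion at 0-based position r in a garment's list earns 1 / (r + 1), and scores add up across all detected garments. The worked example below uses a two-entry map copied from `GARMENT_OCCASION_MAP`. `rank_occasions` is an illustrative restatement of the same scoring that also returns the raw scores.

```python
EXAMPLE_MAP = {
    "blazer": "office work, business meeting, job interview, conference, professional setting",
    "trousers": "office work, business meeting, professional setting, conference, job interview",
}


def rank_occasions(garment_tags, occasion_map, top_n=5):
    scores = {}
    for tag in garment_tags:
        # filter(None, ...) drops the empty string produced by "".split(", ")
        for rank, occ in enumerate(filter(None, occasion_map.get(tag, "").split(", "))):
            scores[occ] = scores.get(occ, 0.0) + 1.0 / (rank + 1)
    # sorted() is stable even with reverse=True, so ties keep first-seen order
    return sorted(scores, key=scores.get, reverse=True)[:top_n], scores
```

For `["blazer", "trousers"]`, "office work" scores 1 + 1 = 2 and "business meeting" 1/2 + 1/2 = 1. "job interview" and "professional setting" tie at 1/3 + 1/5 ≈ 0.53, just ahead of "conference" at 1/4 + 1/4 = 0.5.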


def precompute_label_embeddings():
    """Pre-encode all taxonomy labels with Fashion CLIP (once)."""
    label_embeddings = {}
    fc_type, fc_mdl, fc_proc, fc_tok, fc_max = MODELS[VECTOR_FASHION_CLIP]

    for category, labels in LABEL_TAXONOMY.items():
        prompt_tpl = LABEL_PROMPTS.get(category, "a photo of {}")
        prompts = [prompt_tpl.format(label) for label in labels]
        vecs = _encode_texts_batch(fc_type, fc_mdl, fc_proc, fc_tok, prompts, fc_max)
        label_embeddings[category] = {
            "labels": labels,
            "vectors": vecs,  # (M, 512) normalized
        }
        print(f"  {category}: {len(labels)} labels encoded")

    return label_embeddings


def tag_images_batch(fc_image_vecs, label_embeddings):
    """Zero-shot tag a batch of images using pre-computed label embeddings.

    Args:
        fc_image_vecs: (N, 512) normalized Fashion CLIP image vectors
        label_embeddings: dict from precompute_label_embeddings()

    Returns:
        list of N dicts, each: {category: [top-K label strings]}
    """
    n = fc_image_vecs.shape[0]
    all_tags = [{} for _ in range(n)]

    for category, data in label_embeddings.items():
        lab_vecs = data["vectors"]  # (M, 512)
        labels = data["labels"]
        top_k = TAG_TOP_K.get(category, 3)

        # Cosine similarity: (N, 512) @ (512, M) -> (N, M)
        sims = fc_image_vecs @ lab_vecs.T

        for i in range(n):
            row = sims[i]
            top_indices = np.argsort(row)[::-1][:top_k]
            tags = [labels[j] for j in top_indices if row[j] >= TAG_MIN_SIMILARITY]
            all_tags[i][category] = tags

    return all_tags
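`tag_images_batch` loops over images in Python and fully sorts each similarity row with `np.argsort`. One possible speed-up, sketched here and not what the notebook runs, is to select the top-k for all rows at once with `np.argpartition` and then sort only those k candidates. Up to ties, it returns the same labels in the same descending order. `top_k_labels` is a hypothetical helper.

```python
import numpy as np


def top_k_labels(sims, labels, k, min_sim):
    """sims: (N, M) similarities. Returns N lists of at most k labels with sim >= min_sim,
    ordered by decreasing similarity."""
    sims = np.asarray(sims)
    n, m = sims.shape
    k = min(k, m)
    if k <= 0:
        return [[] for _ in range(n)]
    # Indices of the k largest values in each row (unordered)
    part = np.argpartition(-sims, k - 1, axis=1)[:, :k]
    part_sims = np.take_along_axis(sims, part, axis=1)
    # Sort only those k candidates, highest similarity first
    order = np.argsort(-part_sims, axis=1)
    top_idx = np.take_along_axis(part, order, axis=1)
    top_sims = np.take_along_axis(part_sims, order, axis=1)
    return [
        [labels[j] for j, s in zip(row_idx, row_sims) if s >= min_sim]
        for row_idx, row_sims in zip(top_idx, top_sims)
    ]
```

Inside `tag_images_batch`, the per-row loop could then become `for i, tags in enumerate(top_k_labels(sims, labels, top_k, TAG_MIN_SIMILARITY)): all_tags[i][category] = tags`.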


def build_enriched_description(tags, original_desc=""):
    """Build enriched description from zero-shot tags + original description.

    Returns:
        (description, occasion, gender) tuple
    """
    # Garment mapping occasions (primary, precise)
    garment_occasions = occasions_from_garment_tags(tags.get("garment", []))
    # Zero-shot occasions (secondary, visual cues)
    clip_occasions = tags.get("occasion", [])
    # Merge: garment mapping first, then unique CLIP additions
    merged = list(garment_occasions)
    for occ in clip_occasions:
        if occ not in merged:
            merged.append(occ)
    final_occasions = merged[:5]

    # Build description
    parts = []
    for cat in ["garment", "description"]:
        if tags.get(cat):
            parts.append(", ".join(tags[cat]))
    if final_occasions:
        parts.append(", ".join(final_occasions))
    if tags.get("vibe"):
        parts.append(", ".join(tags["vibe"]))

    # Map gender tag to BM25-friendly text
    gender_tags = tags.get("gender", [])
    gender = None
    if gender_tags:
        g = gender_tags[0]
        if "male" in g and "female" not in g:
            gender = "men"
            parts.append("men's clothing")
        elif "female" in g:
            gender = "women"
            parts.append("women's clothing")

    tag_text = " | ".join(parts)
    description = f"{tag_text} | {original_desc.strip()}" if original_desc and original_desc.strip() else tag_text
    occasion = ", ".join(final_occasions)

    return description, occasion, gender


# Pre-compute label embeddings
print("Pre-computing label embeddings with Fashion CLIP...")
_label_embeddings = precompute_label_embeddings()
total_labels = sum(len(v["labels"]) for v in _label_embeddings.values())
print(f"Done: {total_labels} labels across {len(_label_embeddings)} categories")


# =====================================================================
# Thumbnail helper
# =====================================================================

def make_thumbnail_b64(image, size=THUMBNAIL_SIZE):
    img = image.copy()
    img.thumbnail((size, size))
    if img.mode in ("RGBA", "P"):
        img = img.convert("RGB")
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=85)
    return base64.b64encode(buf.getvalue()).decode("utf-8")


# =====================================================================
# Combined encoding (image + text, multi-model)
# =====================================================================

# Pre-create one CUDA stream per model for true GPU parallelism
_model_streams = {}
if device == "cuda":
    for name in MODELS:
        _model_streams[name] = torch.cuda.Stream()
    print(f"Created {len(_model_streams)} CUDA streams for parallel encoding")


def encode_combined_batch(images, texts, image_weight=IMAGE_WEIGHT,
                          fc_image_vecs_precomputed=None):
    """Encode N images+texts avec les 2 modeles en parallele (batch GPU + CUDA streams).

    Args:
        images: list of PIL images
        texts: list of text descriptions
        image_weight: weight for image vs text (default 0.5)
        fc_image_vecs_precomputed: optional (N, 512) Fashion CLIP image vecs
            already computed (reused from tagging step to avoid double encoding)
    """
    n = len(images)
    imgs = [img.convert("RGB") for img in images]
    has_texts = [bool(t and t.strip()) for t in texts]
    text_indices = [i for i in range(n) if has_texts[i]]
    batch_texts = [texts[i] for i in text_indices] if text_indices else []

    def encode_model(name, model_tuple):
        model_type, mdl, processor, tokenizer, max_len = model_tuple
        stream = _model_streams.get(name)
        stream_ctx = torch.cuda.stream(stream) if stream else nullcontext()
        with torch.no_grad(), stream_ctx:
            # Reuse precomputed Fashion CLIP image vecs if available
            if name == VECTOR_FASHION_CLIP and fc_image_vecs_precomputed is not None:
                img_vecs = fc_image_vecs_precomputed
            else:
                img_vecs = _encode_images_batch(model_type, mdl, processor, imgs)
            if text_indices:
                txt_vecs = _encode_texts_batch(model_type, mdl, processor, tokenizer, batch_texts, max_len)
                combined = img_vecs.copy()
                for j, idx in enumerate(text_indices):
                    c = image_weight * img_vecs[idx] + (1 - image_weight) * txt_vecs[j]
                    norm_val = np.linalg.norm(c)
                    if norm_val > 0:
                        c = c / norm_val
                    combined[idx] = c
                result = combined
            else:
                result = img_vecs
        if stream:
            stream.synchronize()
        return name, result

    with ThreadPoolExecutor(max_workers=2) as executor:
        futures = [executor.submit(encode_model, nm, t) for nm, t in MODELS.items()]
        all_vectors = dict(f.result() for f in futures)  # {model name: (N, D) array}

    return [{name: all_vectors[name][i] for name in all_vectors} for i in range(n)]


# Quick sanity check (single image)
_test_img = Image.new("RGB", (64, 64), "red")
with torch.no_grad():
    # MODELS[...][:3] is (model_type, model, processor)
    _test_vecs = _encode_images_batch(*MODELS[VECTOR_FASHION_CLIP][:3], [_test_img])
print(f"  Fashion CLIP image vec: shape={_test_vecs.shape}")

# Sanity check: zero-shot tagging
_test_tags = tag_images_batch(_test_vecs, _label_embeddings)
print(f"  Tags for red 64x64 image: {_test_tags[0]}")

# Sanity check (batch of 4)
_test_imgs = [Image.new("RGB", (64, 64), c) for c in ["red", "blue", "green", "black"]]
_test_texts = ["red shirt", "blue jeans", "green jacket", "black shoes"]
_test_batch = encode_combined_batch(_test_imgs, _test_texts)
print(f"  Batch: {len(_test_batch)} items, keys={list(_test_batch[0].keys())}")
for k in _test_batch[0]:
    print(f"    {k}: shape={_test_batch[0][k].shape}")
print("OK")
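
In `encode_combined_batch`, each item that has a text description gets `c = image_weight * img + (1 - image_weight) * txt`, which is then L2-normalized. Items without text keep the image vector alone. The standalone NumPy sketch below reproduces that arithmetic on made-up 4-d unit vectors (the real ones are 512-d) to show why the re-normalization step matters: the raw weighted sum of two different unit vectors is shorter than a unit vector, which would distort cosine and dot-product scores in Weaviate.

```python
import numpy as np


def fuse(img_vec: np.ndarray, txt_vec: np.ndarray, image_weight: float = 0.5) -> np.ndarray:
    """Weighted sum of two L2-normalized vectors, re-normalized to unit length."""
    c = image_weight * img_vec + (1 - image_weight) * txt_vec
    norm = np.linalg.norm(c)
    # Guard against the degenerate case where the two vectors cancel out.
    return c / norm if norm > 0 else c


# Toy unit vectors (hypothetical values, 4-d instead of 512-d).
img = np.array([1.0, 0.0, 0.0, 0.0])
txt = np.array([0.0, 1.0, 0.0, 0.0])

raw = 0.5 * img + 0.5 * txt
fused = fuse(img, txt)
print(np.linalg.norm(raw))       # about 0.707 (sqrt(0.5)): shorter than a unit vector
print(np.linalg.norm(fused))     # about 1.0
print(fused @ img, fused @ txt)  # equal cosine similarity to both sources
```

Raising `image_weight` above 0.5 tilts the fused vector toward the image, so visual similarity dominates the ranking; lowering it favors the (tag-enriched) text.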

In [None]:
## Benchmark: batch encoding speed
import time

N_BENCH = 32  # number of images used for the benchmark
_bench_imgs = [Image.new("RGB", (384, 384), c) for c in ["blue", "red", "green", "black"] * (N_BENCH // 4)]
_bench_texts = [f"color {i} denim jacket casual style" for i in range(N_BENCH)]

# Warmup GPU
for _ in range(3):
    encode_combined_batch([_bench_imgs[0]], [_bench_texts[0]])
if device == "cuda":
    torch.cuda.synchronize()

# --- Batch (all images at once) ---
if device == "cuda":
    torch.cuda.synchronize()
t0 = time.perf_counter()
encode_combined_batch(_bench_imgs, _bench_texts)
if device == "cuda":
    torch.cuda.synchronize()
batch_time = time.perf_counter() - t0

print(f"Benchmark on {N_BENCH} images ({device}):")
print(f"  Batch (x{N_BENCH}): {batch_time:.2f}s ({N_BENCH/batch_time:.1f} img/s), {batch_time/N_BENCH*1000:.0f} ms/img")

## 9. Indexing (GPU batch encoding + zero-shot tagging → Weaviate on GCP)

Pipeline: raw PIL images → Fashion CLIP image encoding → zero-shot tagging → enriched descriptions → batch encoding (2 models) → push to Weaviate.

**Optimization**: Fashion CLIP image vectors are computed only once per batch and reused for both the tagging step AND the final vectorization.
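
`tag_images_batch` and `_label_embeddings` are defined earlier in the notebook and are not shown in this section. As a hedged illustration only, CLIP-style zero-shot tagging is commonly done by comparing each normalized image vector with normalized text embeddings of candidate labels and keeping the best-scoring label per attribute group (type, material, occasion, vibe). The sketch below assumes that design; the `zero_shot_tag` helper, the dict layout of the label embeddings, and the one-hot toy vectors are all hypothetical and may not match the notebook's actual implementation.

```python
import numpy as np


def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale rows to unit length so dot products become cosine similarities."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


def zero_shot_tag(image_vecs, label_embeddings):
    """Return, per image, the best-matching label in each attribute group.

    image_vecs: (N, D) L2-normalized image embeddings.
    label_embeddings: {group: (labels, (L, D) L2-normalized label embeddings)}.
    This dict layout is hypothetical; the notebook's _label_embeddings may differ.
    """
    tags = [{} for _ in range(len(image_vecs))]
    for group, (labels, emb) in label_embeddings.items():
        sims = image_vecs @ emb.T       # (N, L) cosine similarities
        best = sims.argmax(axis=1)      # best label index for each image
        for i, j in enumerate(best):
            tags[i][group] = labels[j]
    return tags


# Toy setup: orthogonal one-hot "label embeddings" instead of real CLIP text vectors.
eye = np.eye(5)
label_emb = {
    "type": (["jacket", "dress", "jeans"], eye[0:3]),
    "material": (["leather", "denim"], eye[3:5]),
}
# Two fake image vectors, each halfway between one type label and one material label.
imgs = l2_normalize(np.stack([eye[0] + eye[3], eye[2] + eye[4]]))
print(zero_shot_tag(imgs, label_emb))
# [{'type': 'jacket', 'material': 'leather'}, {'type': 'jeans', 'material': 'denim'}]
```

Because the image vectors are already computed for indexing, this kind of tagging costs only one matrix product per attribute group, which is what makes reusing the Fashion CLIP image vectors worthwhile.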

In [None]:
import io  # needed by the raw-bytes image fallback below
import time
import random
from datetime import datetime, timezone
from tqdm.auto import tqdm
from PIL import Image

collection = client.collections.get(COLLECTION_NAME)

# --- Checkpoint & resume ---
checkpoint = load_checkpoint()
start_idx = 0
if not RECREATE and checkpoint["last_index"] >= 0:
    start_idx = checkpoint["last_index"] + 1
    print(f"Resuming from index {start_idx} ({checkpoint['indexed_count']} already indexed)")

# --- Fetch the product_ids already in Weaviate (dedup) ---
existing_ids = set()
for obj in collection.iterator(return_properties=["product_id"]):
    pid = obj.properties.get("product_id", "")
    if pid:
        existing_ids.add(pid)
print(f"Already indexed: {len(existing_ids)} products, skipping them")

indexed = checkpoint["indexed_count"] if not RECREATE else 0
errors = checkpoint["errors"] if not RECREATE else 0
already = 0
start_time = time.time()

# --- Build encode tasks from selected indices ---
random.seed(42)
encode_tasks = []  # list of (ds_idx, product_id, meta)

for task_idx, ds_idx in enumerate(selected_indices):
    if task_idx < start_idx:
        continue

    row = ds[ds_idx]
    item_id = row["item_ID"]
    product_id = item_id.rsplit("_", 2)[0]

    if product_id in existing_ids:
        already += 1
        continue

    category2 = row.get("category2", "") or ""
    color = row.get("color", "") or ""
    category1 = row.get("category1", "") or ""
    description = row.get("description", "") or ""

    # product_name = "Color Category", title-cased
    name_parts = [p for p in [color, category2] if p]
    product_name = " ".join(name_parts).title() if name_parts else item_id

    # Synthetic random price and brand
    price_range = PRICE_RANGES.get(category2, (20, 100))
    price = round(random.uniform(*price_range), 2)
    brand = random.choice(BRANDS)

    # Gender from category1
    gender_raw = (category1 or "").strip().upper()
    if gender_raw == "WOMEN":
        gender = "women"
    elif gender_raw == "MEN":
        gender = "men"
    else:
        gender = None

    meta = {
        "product_id": product_id,
        "product_name": product_name,
        "category": category2,
        "color": color,
        "price": price,
        "brand": brand,
        "product_url": "",
        "description": description,  # will be enriched by tagging
        "gender": gender,
        "occasion": "",  # will be set by tagging
        "size": "",
        "image_index": 0,
        "_task_idx": task_idx,
        "_ds_idx": ds_idx,
    }
    encode_tasks.append((ds_idx, product_id, meta))

print(f"{already} products already indexed (skipped)")
print(f"{len(encode_tasks)} products to encode")
print(f"Image/text weight: {IMAGE_WEIGHT:.1f} / {1 - IMAGE_WEIGHT:.1f}")
print("Pipeline: FC image encode → zero-shot tag → enriched desc → 2-model encode → Weaviate")
print(f"Config: ENCODE_BATCH={ENCODE_BATCH_SIZE}, WEAVIATE_BATCH={WEAVIATE_BATCH_SIZE}")

# --- Batch encode (GPU) + tagging + push Weaviate ---
weaviate_queue = []
last_task_idx = start_idx
fc_type, fc_mdl, fc_proc, _, _ = MODELS[VECTOR_FASHION_CLIP]


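def flush_weaviate():
    """Push queued docs to Weaviate, count failed objects, and save a checkpoint."""
    global weaviate_queue, indexed, errors, last_task_idx
    if not weaviate_queue:
        return
    with collection.batch.dynamic() as batch:
        for doc in weaviate_queue:
            vecs = doc.pop("_vectors")  # named vectors: {vector_name: list[float]}
            tidx = doc.pop("_task_idx", 0)
            doc.pop("_ds_idx", None)
            batch.add_object(properties=doc, vector=vecs)
            last_task_idx = max(last_task_idx, tidx)
    # Server-side rejections are only reported after the batch context exits;
    # without this check, failed objects would still be counted as indexed.
    n_failed = len(collection.batch.failed_objects)
    indexed += len(weaviate_queue) - n_failed
    errors += n_failed
    save_checkpoint(last_task_idx, indexed, errors)
    weaviate_queue = []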


pbar = tqdm(total=len(encode_tasks), desc="Tag + Encode + Push (GPU)")

for batch_start in range(0, len(encode_tasks), ENCODE_BATCH_SIZE):
    batch_slice = encode_tasks[batch_start:batch_start + ENCODE_BATCH_SIZE]

    batch_images = []
    batch_metas = []

    for ds_idx, product_id, meta in batch_slice:
        try:
            row = ds[ds_idx]
            img = row["image"]
            if not isinstance(img, Image.Image):
                img = Image.open(io.BytesIO(img)).convert("RGB")
            else:
                img = img.convert("RGB")

            w, h = img.size
            meta = meta.copy()
            meta["thumbnail_base64"] = make_thumbnail_b64(img)
            meta["width"] = w
            meta["height"] = h
            meta["filename"] = f"{product_id}.jpg"
            meta["path"] = ""
            meta["indexed_at"] = datetime.now(timezone.utc).isoformat()

            batch_images.append(img)
            batch_metas.append(meta)
        except Exception as e:
            errors += 1
            if errors <= 10:
                tqdm.write(f"  Load error (idx={ds_idx}): {e}")

    if not batch_images:
        pbar.update(len(batch_slice))
        continue

    try:
        # Step 1: Encode images with Fashion CLIP (once, reused for tagging + vectorization)
        with torch.no_grad():
            fc_image_vecs = _encode_images_batch(fc_type, fc_mdl, fc_proc, batch_images)

        # Step 2: Zero-shot tagging
        batch_tags = tag_images_batch(fc_image_vecs, _label_embeddings)

        # Step 3: Build enriched descriptions + occasion + gender from tags
        batch_texts = []
        for meta, tags in zip(batch_metas, batch_tags):
            original_desc = meta.get("description", "")
            gender = meta.get("gender")  # from category1
            enriched_desc, occasion, tagged_gender = build_enriched_description(tags, original_desc)
            meta["description"] = enriched_desc
            meta["occasion"] = occasion
            meta["gender"] = gender or tagged_gender  # category1 takes precedence, model tag as fallback

            # Text for vector encoding = enriched description + product_name + category + color
            vec_text_parts = [
                meta.get("product_name", ""),
                meta.get("category", ""),
                meta.get("color", ""),
                enriched_desc,
            ]
            vec_text = " ".join(p for p in vec_text_parts if p).strip()
            batch_texts.append(vec_text)

        # Step 4: Encode with 2 models (FC image vecs reused from step 1)
        batch_vectors = encode_combined_batch(
            batch_images, batch_texts,
            fc_image_vecs_precomputed=fc_image_vecs,
        )

        for meta, vectors in zip(batch_metas, batch_vectors):
            meta["_vectors"] = {k: v.tolist() for k, v in vectors.items()}
            weaviate_queue.append(meta)

            if len(weaviate_queue) >= WEAVIATE_BATCH_SIZE:
                flush_weaviate()

        # Print first item of first batch as example
        if batch_start == 0 and batch_metas:
            tqdm.write(f"\n--- Example enriched item ---")
            tqdm.write(f"  product_name: {batch_metas[0].get('product_name', '')}")
            tqdm.write(f"  category: {batch_metas[0].get('category', '')}")
            tqdm.write(f"  gender: {batch_metas[0].get('gender', '')}")
            tqdm.write(f"  description: {batch_metas[0].get('description', '')[:200]}...")
            tqdm.write(f"  occasion: {batch_metas[0].get('occasion', '')}")
            tqdm.write(f"  tags: {batch_tags[0]}")
            tqdm.write(f"---\n")

    except Exception as e:
        errors += len(batch_images)
        if errors <= 10:
            tqdm.write(f"  Batch encode error: {e}")

    pbar.update(len(batch_slice))

# Flush remaining
flush_weaviate()

pbar.close()
elapsed = time.time() - start_time

print(f"\n{'='*50}")
print("Indexing complete!")
print(f"Already present: {already}")
print(f"Products indexed (incl. earlier checkpointed runs): {indexed}")
print(f"Errors: {errors}")
print(f"Time: {elapsed:.0f}s ({indexed / max(elapsed, 1):.1f} products/sec)")
print(f"{'='*50}")
clear_checkpoint()
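
The indexing cell seeds the global RNG once with `random.seed(42)` and then draws a price and a brand for each product in loop order. Products before `start_idx`, or already present in Weaviate, are skipped *before* those draws, so a resumed run assigns different synthetic prices and brands than a fresh run would. If reproducibility matters, one option (an assumption, not what the notebook currently does) is to derive a dedicated RNG from each `product_id`. In the sketch below, `synthetic_price_and_brand` is a hypothetical helper, and the `BRANDS` / `PRICE_RANGES` values are placeholders rather than the notebook's own.

```python
import random

# Placeholder values; the notebook defines its own BRANDS and PRICE_RANGES.
BRANDS = ["BrandA", "BrandB", "BrandC"]
PRICE_RANGES = {"jackets": (60, 200)}


def synthetic_price_and_brand(product_id, category, seed=42):
    """Draw a price and brand that depend only on (seed, product_id).

    random.Random accepts a str seed; with the default version=2 it is
    converted to an int deterministically, so the result does not depend on
    PYTHONHASHSEED or on how many products were skipped before this one.
    """
    rng = random.Random(f"{seed}:{product_id}")
    low, high = PRICE_RANGES.get(category, (20, 100))
    price = round(rng.uniform(low, high), 2)
    brand = rng.choice(BRANDS)
    return price, brand


# Same product, same values, regardless of what was drawn in between.
a = synthetic_price_and_brand("product_a", "jackets")
synthetic_price_and_brand("product_b", "dresses")
b = synthetic_price_and_brand("product_a", "jackets")
print(a == b)  # True
```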

## 10. Verification

In [None]:
collection = client.collections.get(COLLECTION_NAME)
stats = collection.aggregate.over_all(total_count=True)
print(f"Collection '{COLLECTION_NAME}': {stats.total_count} documents")

## 11. Quick test

In [None]:
from weaviate.classes.query import MetadataQuery
from IPython.display import display, HTML

query = "black leather jacket"

collection = client.collections.get(COLLECTION_NAME)

# --- 1. Encode the query with both models ---
query_vectors = {}
for vec_name, (model_type, model, processor, tokenizer, max_len) in MODELS.items():
    with torch.no_grad():
        if model_type == "openclip":
            tokens = tokenizer([query]).to(device)
            features = model.encode_text(tokens)
        else:
            tok = getattr(processor, "tokenizer", processor)
            inputs = tok(
                [query], return_tensors="pt", padding=True,
                truncation=True, max_length=max_len,
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}
            features = model.get_text_features(**inputs)
        features = _to_tensor(features)
        features = features / features.norm(p=2, dim=-1, keepdim=True)
        query_vectors[vec_name] = features.cpu().float().numpy().flatten()

# --- 2. BM25 search (with boosts) ---
bm25_results = collection.query.bm25(
    query=query,
    query_properties=["product_name^3", "category^2", "color^2", "description", "occasion"],
    limit=60,
    return_metadata=MetadataQuery(score=True),
)
print(f"BM25: {len(bm25_results.objects)} results")

# Show BM25 top hits to debug
for i, obj in enumerate(bm25_results.objects[:5]):
    p = obj.properties
    print(f"  BM25 #{i+1}: score={obj.metadata.score:.2f} | "
          f"name={p.get('product_name') or ''} | "
          f"color={p.get('color') or ''} | "
          f"cat={p.get('category') or ''} | "
          f"desc={str(p.get('description') or '')[:80]}")

# --- 3. near_vector per model ---
vector_results = {}
for vec_name, vec in query_vectors.items():
    results = collection.query.near_vector(
        near_vector=vec.tolist(),
        target_vector=vec_name,
        limit=60,
        return_metadata=MetadataQuery(distance=True),
    )
    vector_results[vec_name] = results
    print(f"{vec_name}: {len(results.objects)} results")

# --- 4. RRF Fusion (BM25 counted 2x for more keyword weight) ---
def rrf_fusion(result_lists, limit=20, rrf_k=60):
    scores = {}
    result_map = {}
    for results in result_lists:
        for rank, obj in enumerate(results):
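            # Dedup key: product_id, else filename, else the Weaviate object UUID
            # (a rank-based fallback would merge unrelated objects across lists).
            key = (obj.properties.get("product_id") or obj.properties.get("filename") or str(obj.uuid))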
            scores[key] = scores.get(key, 0) + 1 / (rrf_k + rank + 1)
            if key not in result_map:
                result_map[key] = obj
    sorted_keys = sorted(scores, key=lambda x: scores[x], reverse=True)
    return [(result_map[key], scores[key]) for key in sorted_keys[:limit]]

# BM25 counted twice (2/4 weight) vs 2 vectors (2/4 weight)
all_lists = [bm25_results.objects, bm25_results.objects]
for vec_name in query_vectors:
    all_lists.append(vector_results[vec_name].objects)

fused = rrf_fusion(all_lists, limit=20)
print(f"\nRRF Fusion: {len(all_lists)} sources (BM25 x2 + 2 vectors) -> {len(fused)} results")

# --- Helper to safely get string properties ---
def _p(props, key, max_len=0):
    val = props.get(key) or ""
    return val[:max_len] if max_len else val

# --- 5. Display hybrid results ---
html = f'<h3>Hybrid Search: "{query}" (BM25 x2 + 2 vectors + RRF)</h3>'
html += '<div style="display:flex; gap:10px; flex-wrap:wrap;">'
for obj, rrf_score in fused[:20]:
    p = obj.properties
    thumb = _p(p, "thumbnail_base64")
    name = _p(p, "product_name", 45)
    color = _p(p, "color")
    category = _p(p, "category")
    occasion = _p(p, "occasion", 50)
    desc = _p(p, "description", 80)
    price = p.get("price")
    price_str = f"£{price:.2f}" if price else ""
    html += f'''
    <div style="text-align:center; width:170px; border:1px solid #eee; padding:5px; border-radius:8px;">
        <img src="data:image/jpeg;base64,{thumb}" style="max-width:150px; max-height:150px;"/>
        <div style="font-size:11px; font-weight:bold;">{name}</div>
        <div style="font-size:10px; color:#666;">{color} | {category}</div>
        <div style="font-size:10px; color:#999;">{occasion}</div>
        <div style="font-size:9px; color:#aaa;">{desc}</div>
        <div style="font-size:11px; color:#333;">{price_str} — RRF: {rrf_score:.4f}</div>
    </div>'''
html += '</div>'
display(HTML(html))

# --- 6. Per-source comparison ---
sources = [("BM25", bm25_results.objects[:5])]
for vn in query_vectors:
    sources.append((vn, vector_results[vn].objects[:5]))

for source_name, source_results in sources:
    html = f'<h4>{source_name} top 5</h4><div style="display:flex; gap:8px; flex-wrap:wrap;">'
    for obj in source_results:
        p = obj.properties
        thumb = _p(p, "thumbnail_base64")
        name = _p(p, "product_name", 35)
        color = _p(p, "color")
        html += f'''
        <div style="text-align:center; width:130px;">
            <img src="data:image/jpeg;base64,{thumb}" style="max-width:120px; max-height:120px;"/>
            <div style="font-size:10px;">{name}</div>
            <div style="font-size:10px; color:#888;">{color}</div>
        </div>'''
    html += '</div>'
    display(HTML(html))
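
The fusion step scores each product as the sum, over all result lists, of 1 / (k + rank), where ranks start at 1 and k = 60 (`rrf_k`). Passing `bm25_results.objects` twice is therefore equivalent to a weighted RRF in which BM25 has weight 2 and each vector search has weight 1. The standalone sketch below works on plain lists of IDs instead of Weaviate objects (a simplification for self-containment); `weighted_rrf` and the toy rankings are hypothetical, and the sketch checks that the two formulations give the same ranking.

```python
from collections import defaultdict


def weighted_rrf(ranked_lists, weights=None, k=60, limit=20):
    """Weighted Reciprocal Rank Fusion over lists of document IDs.

    score(d) = sum_i w_i / (k + rank_i(d)), with 1-based ranks; a document
    absent from a list contributes nothing for that list.
    """
    if weights is None:
        weights = [1.0] * len(ranked_lists)
    scores = defaultdict(float)
    for ids, w in zip(ranked_lists, weights):
        for rank, doc_id in enumerate(ids, start=1):
            scores[doc_id] += w / (k + rank)
    # Highest score first; sorted() is stable, so ties keep first-seen order.
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:limit]


# Toy rankings (hypothetical product IDs).
bm25 = ["p1", "p2", "p3"]
vec_a = ["p3", "p1", "p4"]
vec_b = ["p3", "p4", "p1"]

dup = weighted_rrf([bm25, bm25, vec_a, vec_b])                  # BM25 listed twice
weighted = weighted_rrf([bm25, vec_a, vec_b], weights=[2, 1, 1])  # explicit weight
print([d for d, _ in weighted])  # ['p1', 'p3', 'p2', 'p4']
```

An explicit weight makes the BM25 emphasis a tunable parameter (for example 1.5) instead of an integer number of duplicated lists.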

In [None]:
_keepalive_stop.set()  # Stop the keep-alive thread
client.close()
print("Weaviate connection closed.")