# Fashion Search — Indexation ASOS depuis Colab (GPU)

Ce notebook indexe le dataset [ASOS e-commerce](https://huggingface.co/datasets/UniqueData/asos-e-commerce-dataset) dans Weaviate en utilisant [Fashion CLIP](https://huggingface.co/patrickjohncyh/fashion-clip) sur GPU.

**Prérequis :**
- Runtime GPU activé (Runtime → Change runtime type → T4 GPU)
- Une instance Weaviate accessible (Weaviate Cloud ou tunnel ngrok)

## 1. Installation des dépendances

In [None]:
!pip install -q weaviate-client>=4.0 transformers torch torchvision Pillow datasets requests tqdm

## 2. Configuration

Renseigne l'URL et la clé API de ton instance Weaviate.

- **Weaviate Cloud** : URL de type `https://xxx.weaviate.network`, clé API à récupérer depuis la console
- **Weaviate local + ngrok** : `docker compose up -d weaviate` puis `ngrok http 8080` (le client se connecte aussi en gRPC sur le port 50051 du même hôte : ce port doit être joignable)

In [None]:
# --- CONFIGURATION ---
# Connection settings read by the Weaviate connection cell.
WEAVIATE_URL = ""        # e.g. "https://fashion-abc123.weaviate.network"
WEAVIATE_API_KEY = ""    # leave empty for a local / ngrok instance

# Target collection and the Hugging Face model used to embed images.
COLLECTION_NAME = "FashionCollection"
MODEL_NAME = "patrickjohncyh/fashion-clip"

MAX_ITEMS = 2000              # number of products to index (None = all)
MAX_IMAGES_PER_PRODUCT = 1    # images kept per product
BATCH_SIZE = 32               # images encoded per model forward pass
WEAVIATE_BATCH_SIZE = 100     # queued documents that trigger a Weaviate insert
THUMBNAIL_SIZE = 150          # thumbnail bounding-box side passed to make_thumbnail_b64

## 3. Vérification GPU

In [None]:
import torch

# Select the compute device once; later cells read the module-level `device`.
if torch.cuda.is_available():
    device = "cuda"
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # The device-properties attribute is `total_memory` (in bytes);
    # `total_mem` does not exist and raised AttributeError.
    print(f"VRAM: {gpu_props.total_memory / 1024**3:.1f} GB")
else:
    device = "cpu"
    print("⚠ Pas de GPU détecté, l'indexation sera lente.")

print(f"Device: {device}")

## 4. Chargement du modèle Fashion CLIP

In [None]:
from transformers import CLIPModel, CLIPProcessor

# Load the CLIP processor and weights named by MODEL_NAME.
print(f"Chargement de {MODEL_NAME}...")
processor = CLIPProcessor.from_pretrained(MODEL_NAME)
model = CLIPModel.from_pretrained(MODEL_NAME)

# Half precision is only applied on CUDA, to reduce VRAM usage.
use_fp16 = device == "cuda"
if use_fp16:
    model = model.half()

# Move to the selected device and switch to inference mode.
model = model.to(device)
model.eval()

# Embedding size, taken from the model's projection head configuration.
VECTOR_DIM = model.config.projection_dim
print(f"Modèle chargé — dimension vecteur: {VECTOR_DIM}")

## 5. Connexion à Weaviate

In [None]:
import weaviate
from weaviate.classes.config import Configure, DataType, Property, VectorDistances

from urllib.parse import urlparse

# Explicit check instead of `assert`, which is stripped when Python runs with -O.
if not WEAVIATE_URL:
    raise ValueError("Renseigne WEAVIATE_URL dans la cellule de configuration.")

# Cloud mode requires both an https URL and an API key.
is_cloud = WEAVIATE_URL.startswith("https://") and bool(WEAVIATE_API_KEY)

if is_cloud:
    client = weaviate.connect_to_weaviate_cloud(
        cluster_url=WEAVIATE_URL,
        auth_credentials=weaviate.auth.AuthApiKey(WEAVIATE_API_KEY),
    )
else:
    # Direct connection (ngrok, public IP, ...).
    # urlparse copes with trailing slashes and paths, which the previous
    # split(":")-based parsing turned into an invalid host or port.
    raw_url = WEAVIATE_URL.strip()
    # A bare "host:port" has no scheme; default to plain http as before.
    parsed_url = urlparse(raw_url if "://" in raw_url else f"http://{raw_url}")
    use_tls = parsed_url.scheme == "https"
    host = parsed_url.hostname
    # No explicit port: fall back to the scheme's default.
    port = parsed_url.port or (443 if use_tls else 80)

    client = weaviate.connect_to_custom(
        http_host=host,
        http_port=port,
        http_secure=use_tls,
        grpc_host=host,
        grpc_port=50051,
        grpc_secure=use_tls,
    )

print("Connecté !" if client.is_ready() else "Échec de connexion")

## 6. Création du schéma

In [None]:
# Drop and recreate the collection (warning: this deletes existing data).
collection_already_exists = client.collections.exists(COLLECTION_NAME)
if collection_already_exists:
    client.collections.delete(COLLECTION_NAME)
    print(f"Collection '{COLLECTION_NAME}' supprimée.")

# (property name, data type) pairs, in the order they are declared to Weaviate.
schema_fields = [
    ("filename", DataType.TEXT),
    ("path", DataType.TEXT),
    ("thumbnail_base64", DataType.TEXT),
    ("width", DataType.INT),
    ("height", DataType.INT),
    ("indexed_at", DataType.TEXT),
    ("product_id", DataType.TEXT),
    ("product_name", DataType.TEXT),
    ("category", DataType.TEXT),
    ("color", DataType.TEXT),
    ("size", DataType.TEXT),
    ("price", DataType.NUMBER),
    ("brand", DataType.TEXT),
    ("product_url", DataType.TEXT),
    ("description", DataType.TEXT),
    ("image_index", DataType.INT),
    ("gender", DataType.TEXT),
]

# Vectors are supplied by this notebook, so no server-side vectorizer is set.
client.collections.create(
    name=COLLECTION_NAME,
    properties=[
        Property(name=field_name, data_type=field_type)
        for field_name, field_type in schema_fields
    ],
    vectorizer_config=Configure.Vectorizer.none(),
    vector_index_config=Configure.VectorIndex.hnsw(
        distance_metric=VectorDistances.COSINE
    ),
)
print(f"Collection '{COLLECTION_NAME}' créée (dim={VECTOR_DIM}).")

## 7. Chargement du dataset ASOS

In [None]:
from datasets import load_dataset

# Fetch the "train" split of the ASOS dataset from the Hugging Face hub.
print("Chargement du dataset ASOS...")
dataset = load_dataset("UniqueData/asos-e-commerce-dataset", split="train")
print(f"Dataset chargé: {len(dataset)} produits")

# A falsy MAX_ITEMS (None or 0) keeps the whole split.
if MAX_ITEMS:
    rows_to_keep = min(MAX_ITEMS, len(dataset))
    dataset = dataset.select(range(rows_to_keep))
    print(f"Limité à {len(dataset)} produits")

## 8. Fonctions utilitaires

In [None]:
import base64
import io
import json
from datetime import datetime

import requests
from PIL import Image


def parse_price(price_str):
    """Convert a price such as '£45.00' or '$29.99' to a float.

    Pound signs, dollar signs and thousands separators (commas) are removed
    before conversion. Returns None for any falsy input (None, "", 0) and
    for values that cannot be converted.
    """
    if not price_str:
        return None
    try:
        cleaned = str(price_str).strip()
        for symbol in ("£", "$", ","):
            cleaned = cleaned.replace(symbol, "")
        return float(cleaned)
    except (ValueError, AttributeError):
        return None


def parse_description(desc_str):
    """Extract ``(brand, text)`` from a product description.

    The input is tried as JSON first:
    - a JSON object yields its "brand"/"Brand" value and its
      "description"/"Description" value (or the stringified object when
      neither description key is set);
    - any other JSON value yields ``(None, str(value))``;
    - input that is not valid JSON yields ``(None, str(desc_str))``.

    Falsy input returns ``(None, None)``.
    """
    if not desc_str:
        return None, None

    try:
        payload = json.loads(desc_str)
    except (json.JSONDecodeError, TypeError):
        # Not JSON: treat the raw value as free text.
        return None, str(desc_str)

    if not isinstance(payload, dict):
        return None, str(payload)

    brand = payload.get("brand") or payload.get("Brand")
    text = payload.get("description") or payload.get("Description") or str(payload)
    return brand, text


def parse_images(images_str):
    """Extract the list of image URLs from a dataset field.

    Accepted inputs:
    - a list/tuple of values (non-string and non-http entries are dropped);
    - a JSON array or JSON string;
    - a Python literal such as "['https://a', 'https://b']", which is what
      ``str(some_list)`` produces and which ``json.loads`` rejects because
      of the single quotes;
    - a bare URL string.

    Returns a list of strings starting with "http"; an empty list when
    nothing usable is found.
    """
    import ast  # local import: only needed for the Python-literal fallback

    def _keep_urls(value):
        """Return the http(s) URLs contained in a decoded value."""
        if isinstance(value, (list, tuple)):
            return [u for u in value if isinstance(u, str) and u.startswith("http")]
        if isinstance(value, str) and value.startswith("http"):
            return [value]
        return []

    if not images_str:
        return []
    if isinstance(images_str, (list, tuple)):
        return _keep_urls(images_str)

    try:
        return _keep_urls(json.loads(images_str))
    except (json.JSONDecodeError, TypeError):
        pass

    if not isinstance(images_str, str):
        return []

    try:
        # literal_eval only evaluates literals, so it is safe on dataset text.
        return _keep_urls(ast.literal_eval(images_str))
    except (ValueError, SyntaxError):
        pass

    # Last resort: the raw string is itself a URL.
    return [images_str] if images_str.startswith("http") else []


def detect_gender(product_name, category):
    """Guess the target gender from the product name and category.

    Both values are joined and lower-cased, then searched for keyword
    substrings. Women keywords are tested first, which matters because
    several men keywords ("mens", "male", "men's") are substrings of women
    keywords. Returns "women", "men", or None when no keyword matches.
    """
    haystack = f"{product_name} {category}".lower()

    women_markers = (
        "women's", "womens", "female", "femme",
        " woman ", "for women", "ladies", "maternity",
    )
    men_markers = ("men's", "mens", "male", "homme", " man ", "for men")

    if any(marker in haystack for marker in women_markers):
        return "women"
    if any(marker in haystack for marker in men_markers):
        return "men"
    return None


def download_image(url, timeout=10):
    """Download ``url`` and return it as an RGB PIL image, or None on failure.

    Args:
        url: Address of the image to fetch.
        timeout: Timeout in seconds passed to ``requests.get``.

    Returns:
        A PIL ``Image`` converted to RGB, or None if the request fails, the
        HTTP status is an error, or the payload cannot be decoded.
    """
    try:
        # The context manager closes the streamed response on every path,
        # including when raise_for_status() fires before the body is read.
        with requests.get(url, timeout=timeout, stream=True) as response:
            response.raise_for_status()
            payload = response.content
        return Image.open(io.BytesIO(payload)).convert("RGB")
    except Exception:
        # Deliberately best-effort: a single broken image must not abort the
        # indexing run; the caller skips None results.
        return None


def make_thumbnail_b64(image, size=THUMBNAIL_SIZE):
    """Return a base64-encoded JPEG thumbnail of ``image``.

    The input is copied first, so the caller's image is left untouched. The
    copy is shrunk with ``Image.thumbnail`` using a ``(size, size)`` bound,
    converted to RGB when its mode is RGBA or P, and saved as JPEG at
    quality 85.
    """
    thumb = image.copy()
    thumb.thumbnail((size, size))

    needs_rgb = thumb.mode == "RGBA" or thumb.mode == "P"
    if needs_rgb:
        thumb = thumb.convert("RGB")

    with io.BytesIO() as jpeg_buffer:
        thumb.save(jpeg_buffer, format="JPEG", quality=85)
        jpeg_bytes = jpeg_buffer.getvalue()

    return base64.b64encode(jpeg_bytes).decode("utf-8")


def encode_images_batch(images):
    """Encode a batch of PIL images into L2-normalised embedding vectors.

    Uses the module-level ``processor``, ``model`` and ``device``.

    Args:
        images: Non-empty list of PIL images.

    Returns:
        A float32 NumPy array with one row per input image.
    """
    with torch.no_grad():
        inputs = processor(images=images, return_tensors="pt")
        # The model may have been cast to FP16; floating inputs must use the
        # same dtype as the weights or the forward pass raises a dtype
        # mismatch error. Integer tensors are only moved, never cast.
        model_dtype = next(model.parameters()).dtype
        inputs = {
            key: (
                tensor.to(device=device, dtype=model_dtype)
                if tensor.is_floating_point()
                else tensor.to(device)
            )
            for key, tensor in inputs.items()
        }
        # Normalise in float32 so the norm is not computed in half precision.
        features = model.get_image_features(**inputs).float()
        features = features / features.norm(p=2, dim=-1, keepdim=True)
    return features.cpu().numpy()


print("Fonctions utilitaires chargées.")

## 9. Indexation

Le processus :
1. Télécharge les images par batch
2. Encode le batch entier sur GPU en une passe
3. Insère dans Weaviate

In [None]:
import time
from tqdm.auto import tqdm

# Handle on the target collection, used by flush_weaviate().
collection = client.collections.get(COLLECTION_NAME)

# Batching buffers (module-level; the flush helpers rebind them via `global`)
image_batch = []    # PIL images waiting to be encoded
meta_batch = []     # metadata dicts, index-aligned with image_batch
weaviate_queue = [] # encoded documents ready to be inserted into Weaviate

# Run counters and start timestamp used by the final summary.
indexed = 0
skipped = 0
start_time = time.time()


def flush_gpu_batch():
    """Encode the pending images and move their documents to the Weaviate queue.

    Each metadata dict in ``meta_batch`` receives a "vector" entry holding
    the embedding of the image at the same position in ``image_batch``, and
    is appended to ``weaviate_queue``. Both pending buffers are then
    replaced with empty lists. Does nothing when no image is pending.
    """
    global image_batch, meta_batch
    if image_batch:
        embeddings = encode_images_batch(image_batch)
        for record, embedding in zip(meta_batch, embeddings):
            record["vector"] = embedding.tolist()
            weaviate_queue.append(record)
        image_batch, meta_batch = [], []

def flush_weaviate():
    """Insert the queued documents into Weaviate and update ``indexed``.

    Each queued dict has its "vector" entry popped and sent as the object's
    vector; the remaining keys are sent as properties. Only objects that
    Weaviate did not report as failed are added to ``indexed``. The queue is
    emptied afterwards. Does nothing when the queue is empty.
    """
    global weaviate_queue, indexed
    if not weaviate_queue:
        return

    submitted = len(weaviate_queue)
    with collection.batch.dynamic() as batch:
        for doc in weaviate_queue:
            vec = doc.pop("vector")
            batch.add_object(properties=doc, vector=vec)

    # Batch inserts do not raise on per-object errors; the client exposes
    # them after the context manager exits. Read them before the next batch.
    failed = collection.batch.failed_objects
    if failed:
        print(f"⚠ {len(failed)} objet(s) rejeté(s) par Weaviate (ex: {failed[0].message})")

    indexed += max(submitted - len(failed), 0)
    weaviate_queue = []


pbar = tqdm(total=len(dataset), desc="Indexation")

for item in dataset:
    # Product-level fields; each lookup falls back to an alternate column name.
    product_id = str(item.get("id", item.get("Unnamed: 0", "")))
    product_name = item.get("product_name", item.get("name", ""))
    category = item.get("category", item.get("product_type", ""))
    color = item.get("colour", item.get("color", ""))
    price = parse_price(item.get("price", item.get("current_price", "")))
    product_url = item.get("url", item.get("product_url", ""))

    # A brand found inside the description takes precedence over the column.
    brand_parsed, desc_text = parse_description(item.get("description", ""))
    brand = brand_parsed or item.get("brand", "")
    description = desc_text or ""

    raw_images = item.get("images", item.get("image", ""))
    image_urls = parse_images(str(raw_images))
    gender = detect_gender(product_name or "", category or "")

    # Number of images of this product that were successfully downloaded.
    usable_images = 0

    for position, image_url in enumerate(image_urls[:MAX_IMAGES_PER_PRODUCT]):
        picture = download_image(image_url)
        if picture is not None:
            usable_images += 1
            width, height = picture.size

            entry = {
                "filename": f"{product_id}_{position}.jpg",
                "path": image_url,
                "thumbnail_base64": make_thumbnail_b64(picture),
                "width": width,
                "height": height,
                "indexed_at": datetime.utcnow().isoformat(),
                "product_id": product_id,
                "product_name": product_name or "",
                "category": category or "",
                "color": color or "",
                "price": price,
                "brand": brand or "",
                "product_url": product_url or "",
                "description": description,
                "image_index": position,
                "gender": gender,
            }

            image_batch.append(picture)
            meta_batch.append(entry)

            # Encode once enough images are pending.
            if len(image_batch) >= BATCH_SIZE:
                flush_gpu_batch()

            # Insert once enough encoded documents are queued.
            if len(weaviate_queue) >= WEAVIATE_BATCH_SIZE:
                flush_weaviate()

    if usable_images == 0:
        skipped += 1

    pbar.update(1)

# Flush whatever is left in both buffers.
flush_gpu_batch()
flush_weaviate()

pbar.close()
elapsed = time.time() - start_time

# Summary; the rate uses at least one second to avoid dividing by ~0.
separator = "=" * 50
rate = indexed / max(elapsed, 1)

print(f"\n{separator}")
print("Indexation terminée !")
print(f"Images indexées: {indexed}")
print(f"Produits sans image: {skipped}")
print(f"Temps: {elapsed:.0f}s ({rate:.1f} images/sec)")
print(f"{separator}")

## 10. Vérification

In [None]:
# Ask Weaviate for the total number of objects stored in the collection.
collection = client.collections.get(COLLECTION_NAME)
stats = collection.aggregate.over_all(total_count=True)
document_count = stats.total_count
print(f"Collection '{COLLECTION_NAME}': {document_count} documents")

## 11. Test rapide — recherche texte

In [None]:
from weaviate.classes.query import MetadataQuery
from IPython.display import display, HTML

from html import escape

query = "black leather jacket"

# Embed the text query with the same model and L2-normalise it.
with torch.no_grad():
    inputs = processor(text=[query], return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    features = model.get_text_features(**inputs)
    features = features / features.norm(p=2, dim=-1, keepdim=True)
    query_vector = features.cpu().float().numpy().flatten().tolist()

results = collection.query.near_vector(
    near_vector=query_vector,
    limit=5,
    return_metadata=MetadataQuery(distance=True),
)

print(f"Résultats pour: '{query}'\n")

# One HTML card per hit. Values come from the dataset, so they are escaped
# before being interpolated into markup.
cards = []
for obj in results.objects:
    p = obj.properties
    # Displayed score is 1 - distance; a missing distance is treated as 0.
    score = f"{1 - (obj.metadata.distance or 0):.3f}"
    # `or ""` guards against properties stored as null.
    thumb = escape(p.get("thumbnail_base64") or "", quote=True)
    # Truncate before escaping so an HTML entity is never cut in half.
    name = escape((p.get("product_name") or "")[:40])
    price = p.get("price")
    # `is not None` so that a legitimate price of 0 is still displayed.
    price_str = f"£{price:.2f}" if price is not None else ""
    cards.append(f'''
    <div style="text-align:center; width:160px;">
        <img src="data:image/jpeg;base64,{thumb}" style="max-width:150px; max-height:150px;"/>
        <div style="font-size:11px;">{name}</div>
        <div style="font-size:11px; color:gray;">{price_str} — score: {score}</div>
    </div>''')

gallery_html = (
    '<div style="display:flex; gap:10px; flex-wrap:wrap;">'
    + "".join(cards)
    + '</div>'
)
display(HTML(gallery_html))

In [None]:
# Close the Weaviate client opened in the connection cell.
client.close()
print("Connexion Weaviate fermée.")