# Fashion Search — Indexation ASOS depuis Colab (GPU)

Ce notebook indexe le dataset [ASOS e-commerce](https://huggingface.co/datasets/UniqueData/asos-e-commerce-dataset) dans Weaviate en utilisant [Fashion CLIP](https://huggingface.co/patrickjohncyh/fashion-clip) sur GPU Colab.

**Architecture :**
- **Colab (GPU)** : encode les images avec Fashion CLIP → pousse les vecteurs dans Weaviate
- **GCP (CPU)** : Weaviate + app FastAPI pour servir les recherches

**Prérequis :**
- Runtime GPU activé (Runtime → Change runtime type → T4 GPU)
- Weaviate accessible sur ta VM GCP (ports 8080 HTTP et 50051 gRPC ouverts dans le firewall)

## 1. Installation des dépendances

In [None]:
!pip install -q "weaviate-client>=4.0" transformers torch torchvision Pillow datasets requests tqdm

## 2. Configuration

Renseigne l'IP externe de ta VM GCP où tourne Weaviate.

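**Important** : les ports `8080` (HTTP) et `50051` (gRPC) doivent être ouverts dans le firewall GCP. La connexion n'étant ni chiffrée ni authentifiée, restreins l'ouverture de ces ports autant que possible et referme-les après l'indexation.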

In [None]:
# --- Configuration -----------------------------------------------------------
# External IP of the GCP VM running Weaviate, e.g. "34.56.78.90".
GCP_EXTERNAL_IP = ""

# Weaviate endpoints: HTTP for REST calls, gRPC for queries and batch inserts.
WEAVIATE_HTTP_PORT, WEAVIATE_GRPC_PORT = 8080, 50051

# Target collection and embedding model checkpoint.
COLLECTION_NAME = "FashionCollection"
MODEL_NAME = "patrickjohncyh/fashion-clip"

# Indexing limits and batch sizes.
MAX_ITEMS = 2000              # products to index (None = whole dataset)
MAX_IMAGES_PER_PRODUCT = 1    # images kept per product
BATCH_SIZE = 32               # images per GPU encoding batch
WEAVIATE_BATCH_SIZE = 100     # encoded objects per Weaviate insert
THUMBNAIL_SIZE = 150          # max side (pixels) of the stored thumbnails

## 3. Vérification GPU

In [None]:
import torch

# Use the GPU when Colab exposes one; every later cell falls back to the CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"

if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
else:
    print("Pas de GPU, l'indexation sera lente.")

print(f"Device: {device}")

## 4. Chargement du modèle Fashion CLIP

In [None]:
from transformers import CLIPModel, CLIPProcessor

print(f"Chargement de {MODEL_NAME}...")
# Weights and the matching image/text pre-processor come from the same checkpoint.
model = CLIPModel.from_pretrained(MODEL_NAME)
processor = CLIPProcessor.from_pretrained(MODEL_NAME)

# FP16 only on the GPU (halves weight memory); the CPU path stays in FP32.
model = (model.half() if device == "cuda" else model).to(device)
model.eval()  # inference only

# Dimension of the embeddings produced by get_image_features / get_text_features.
VECTOR_DIM = model.config.projection_dim
print(f"Modele charge - dimension vecteur: {VECTOR_DIM}")

## 5. Connexion à Weaviate sur GCP

In [None]:
import weaviate
from weaviate.classes.config import Configure, DataType, Property, VectorDistances

# Validate explicitly: an `assert` would be skipped entirely if Python ran with -O.
# Stripping also tolerates a pasted IP with stray whitespace.
weaviate_host = (GCP_EXTERNAL_IP or "").strip()
if not weaviate_host:
    raise ValueError("Renseigne GCP_EXTERNAL_IP dans la cellule de configuration.")

# Plain HTTP + gRPC (no TLS) and no authentication: only reasonable while the firewall
# restricts who can reach these ports.
client = weaviate.connect_to_custom(
    http_host=weaviate_host,
    http_port=WEAVIATE_HTTP_PORT,
    http_secure=False,
    grpc_host=weaviate_host,
    grpc_port=WEAVIATE_GRPC_PORT,
    grpc_secure=False,
)

print(f"Connecte a Weaviate sur {weaviate_host}" if client.is_ready() else "Echec de connexion")

## 6. Création du schéma

In [None]:
# Start from a clean slate: drop any existing collection with the same name.
if client.collections.exists(COLLECTION_NAME):
    client.collections.delete(COLLECTION_NAME)
    print(f"Collection '{COLLECTION_NAME}' supprimee.")

# (property name, Weaviate type) pairs, in declaration order.
# NOTE: "size" is declared here but the indexing cell never fills it.
_SCHEMA_FIELDS = [
    ("filename", DataType.TEXT),
    ("path", DataType.TEXT),
    ("thumbnail_base64", DataType.TEXT),
    ("width", DataType.INT),
    ("height", DataType.INT),
    ("indexed_at", DataType.TEXT),
    ("product_id", DataType.TEXT),
    ("product_name", DataType.TEXT),
    ("category", DataType.TEXT),
    ("color", DataType.TEXT),
    ("size", DataType.TEXT),
    ("price", DataType.NUMBER),
    ("brand", DataType.TEXT),
    ("product_url", DataType.TEXT),
    ("description", DataType.TEXT),
    ("image_index", DataType.INT),
    ("gender", DataType.TEXT),
]

# Vectors are computed client-side (no server vectorizer) and compared by cosine distance.
# The vector dimension is not declared here; presumably Weaviate infers it from the first insert.
client.collections.create(
    name=COLLECTION_NAME,
    properties=[Property(name=field, data_type=dtype) for field, dtype in _SCHEMA_FIELDS],
    vectorizer_config=Configure.Vectorizer.none(),
    vector_index_config=Configure.VectorIndex.hnsw(distance_metric=VectorDistances.COSINE),
)
print(f"Collection '{COLLECTION_NAME}' creee (dim={VECTOR_DIM}).")

## 7. Chargement du dataset ASOS

In [None]:
from datasets import load_dataset

print("Chargement du dataset ASOS...")
dataset = load_dataset("UniqueData/asos-e-commerce-dataset", split="train")
print(f"Dataset charge: {len(dataset)} produits")

# Keep only the first MAX_ITEMS rows; a falsy MAX_ITEMS (None or 0) keeps everything.
if MAX_ITEMS:
    _n_keep = min(MAX_ITEMS, len(dataset))
    dataset = dataset.select(range(_n_keep))
    print(f"Limite a {len(dataset)} produits")

## 8. Fonctions utilitaires

In [None]:
import base64
import io
from datetime import datetime

import requests
from PIL import Image


def extract_images(images_field):
    """Return the image URLs contained in a dataset ``images`` field.

    A list keeps only its string entries that start with ``"http"``; a single string
    starting with ``"http"`` is wrapped in a one-element list. Empty values and any
    other type give ``[]``.
    """
    def _is_url(candidate):
        return isinstance(candidate, str) and candidate.startswith("http")

    if images_field and isinstance(images_field, list):
        return [candidate for candidate in images_field if _is_url(candidate)]
    if images_field and _is_url(images_field):
        return [images_field]
    return []


def extract_description(desc_field):
    """Split a dataset ``description`` field into ``(brand, text)``.

    For a list, every dict entry is scanned: a key containing "brand" (case-insensitive)
    sets the brand (``None`` when its value is falsy; the last such key wins), and every
    other non-empty value is stripped and joined into the text with single spaces.
    ``None``/empty values are skipped instead of being rendered as ``"None"``/blanks.
    Non-dict list entries are ignored. Any other truthy value is returned as
    ``(None, str(value))``; a falsy field gives ``(None, "")``.
    """
    if not desc_field:
        return None, ""
    if not isinstance(desc_field, list):
        return None, str(desc_field)

    brand = None
    texts = []
    for entry in desc_field:
        if not isinstance(entry, dict):
            continue
        for key, val in entry.items():
            # str() so a non-string key cannot crash .lower().
            if "brand" in str(key).lower():
                brand = str(val) if val else None
            elif val is not None:
                text = str(val).strip()
                if text:
                    texts.append(text)
    return brand, " ".join(texts)


def extract_price(price_field):
    """Return ``price_field`` as a strictly positive, finite float, or ``None``.

    Anything ``float()`` accepts is converted; unconvertible values, zero, negative
    prices, NaN and infinities all yield ``None``. Infinities must be rejected because
    they cannot be serialised as JSON numbers when sent to Weaviate.
    """
    import math

    if price_field is None:
        return None
    try:
        val = float(price_field)
    except (ValueError, TypeError):
        return None
    return val if math.isfinite(val) and val > 0 else None


def detect_gender(product_name, category):
    """Guess the target gender from a product name and category.

    Returns ``"women"``, ``"men"`` or ``None``. Keywords are matched as whole words
    (case-insensitive), so "Women"/"Men" on their own are detected, while words that
    merely contain a keyword ("dimensions", "Topman") are not. Women's keywords are
    checked first. Apostrophe forms such as "women's" still match, because the
    apostrophe is a word boundary.
    """
    import re

    text = f"{product_name} {category}".lower()
    if re.search(
        r"\b(?:women|womens|womenswear|woman|female|femmes?|ladies|maternity)\b", text
    ):
        return "women"
    if re.search(r"\b(?:men|mens|menswear|man|male|hommes?)\b", text):
        return "men"
    return None


def download_image(url, timeout=10):
    """Download ``url`` and decode it as an RGB PIL image; return ``None`` on failure.

    Best-effort by design: request errors, HTTP error statuses and undecodable
    payloads are all swallowed, so one bad URL never stops the indexing loop.
    """
    try:
        # The context manager releases the connection even when raise_for_status() fails.
        with requests.get(url, timeout=timeout) as r:
            r.raise_for_status()
            payload = r.content
        # convert() forces the lazy decode here, inside the try.
        return Image.open(io.BytesIO(payload)).convert("RGB")
    except Exception:
        return None


def make_thumbnail_b64(image, size=THUMBNAIL_SIZE):
    """Return a base64-encoded JPEG thumbnail of ``image`` fitting in ``size`` x ``size``.

    Works on a copy, so the caller's image is left untouched. Modes the JPEG encoder
    cannot write are converted to RGB first. The result is an ASCII/UTF-8 string
    suitable for a ``data:image/jpeg;base64,`` URI.
    """
    thumb = image.copy()
    thumb.thumbnail((size, size))
    # Pillow's JPEG writer rejects modes such as RGBA, P, LA or I;16.
    # Normalise everything except RGB and grayscale (L) to RGB.
    if thumb.mode not in ("RGB", "L"):
        thumb = thumb.convert("RGB")
    buf = io.BytesIO()
    thumb.save(buf, format="JPEG", quality=85)
    return base64.b64encode(buf.getvalue()).decode("utf-8")


def encode_images_batch(images):
    """Encode a list of PIL images into L2-normalised Fashion CLIP embeddings.

    Returns a float32 NumPy array with one row per input image and
    ``model.config.projection_dim`` columns. An empty list yields an empty
    ``(0, dim)`` array instead of failing inside the processor.
    """
    if not images:
        return torch.empty((0, model.config.projection_dim)).numpy()
    with torch.no_grad():
        inputs = processor(images=images, return_tensors="pt")
        # The processor emits FP32 pixels while the model may be FP16 on the GPU,
        # so cast floating tensors to the model dtype. Integer tensors only change device.
        inputs = {
            k: v.to(device=device, dtype=model.dtype) if v.is_floating_point() else v.to(device)
            for k, v in inputs.items()
        }
        # Normalise in FP32 for precision; F.normalize also guards against a zero norm.
        features = model.get_image_features(**inputs).float()
        features = torch.nn.functional.normalize(features, p=2, dim=-1)
    return features.cpu().numpy()


# Marker that this cell's helper definitions ran to completion.
print("OK")

## 9. Indexation (GPU batch encoding → Weaviate sur GCP)

In [None]:
import time
from tqdm.auto import tqdm

# Handle on the target collection created by the schema cell.
collection = client.collections.get(COLLECTION_NAME)

# Pending work: images awaiting GPU encoding, their metadata,
# and encoded objects awaiting insertion into Weaviate.
image_batch, meta_batch, weaviate_queue = [], [], []

# Counters and timer reported in the summary after the loop.
indexed = skipped = 0
start_time = time.time()


def flush_gpu_batch():
    """Encode the pending images and move their metadata onto the Weaviate queue.

    Each metadata dict in ``meta_batch`` gains a ``"vector"`` key (a list of floats)
    and is appended to ``weaviate_queue``. Both pending buffers are then reset.
    Does nothing when no image is pending.
    """
    global image_batch, meta_batch
    if image_batch:
        embeddings = encode_images_batch(image_batch)
        for embedding, record in zip(embeddings, meta_batch):
            record["vector"] = embedding.tolist()
            weaviate_queue.append(record)
        image_batch, meta_batch = [], []


def flush_weaviate():
    """Insert every queued object into Weaviate, then clear the queue.

    Each queued dict carries its embedding under ``"vector"``; all other keys are
    sent as object properties. ``indexed`` grows by the number of objects Weaviate
    accepted. Per-object failures are reported rather than counted as indexed.
    """
    global weaviate_queue, indexed
    if not weaviate_queue:
        return
    with collection.batch.dynamic() as batch:
        for doc in weaviate_queue:
            # Build the property dict instead of popping, so queued docs stay intact.
            properties = {k: v for k, v in doc.items() if k != "vector"}
            batch.add_object(properties=properties, vector=doc["vector"])
    # Batch inserts do not raise on per-object errors; they are collected here instead.
    failed = collection.batch.failed_objects
    if failed:
        print(f"{len(failed)} objet(s) rejete(s) par Weaviate, ex: {failed[0].message}")
    indexed += len(weaviate_queue) - len(failed)
    weaviate_queue = []


from datetime import timezone

# Stream products: download -> thumbnail -> queue for GPU encoding -> queue for Weaviate.
# The `with` block closes the progress bar even if the loop raises.
with tqdm(total=len(dataset), desc="Indexation") as pbar:
    for item in dataset:
        # Normalise missing fields to "" once; avoids str(None) == "None" as an id.
        sku = item.get("sku")
        product_id = "" if sku is None else str(sku)
        product_name = item.get("name") or ""
        category = item.get("category") or ""
        color = item.get("color") or ""
        price = extract_price(item.get("price"))
        product_url = item.get("url") or ""

        brand, description = extract_description(item.get("description"))
        image_urls = extract_images(item.get("images"))
        gender = detect_gender(product_name, category)

        product_has_image = False

        for idx, url in enumerate(image_urls[:MAX_IMAGES_PER_PRODUCT]):
            img = download_image(url)
            if img is None:
                continue

            product_has_image = True
            w, h = img.size

            meta = {
                "filename": f"{product_id}_{idx}.jpg",
                "path": url,
                "thumbnail_base64": make_thumbnail_b64(img),
                "width": w,
                "height": h,
                # Timezone-aware UTC; datetime.utcnow() is deprecated since Python 3.12.
                "indexed_at": datetime.now(timezone.utc).isoformat(),
                "product_id": product_id,
                "product_name": product_name,
                "category": category,
                "color": color,
                "price": price,
                "brand": brand or "",
                "product_url": product_url,
                "description": description,
                "image_index": idx,
                "gender": gender,
            }

            image_batch.append(img)
            meta_batch.append(meta)

            if len(image_batch) >= BATCH_SIZE:
                flush_gpu_batch()

            if len(weaviate_queue) >= WEAVIATE_BATCH_SIZE:
                flush_weaviate()

        if not product_has_image:
            skipped += 1

        pbar.update(1)

    # Encode and insert whatever is left in the partial batches.
    flush_gpu_batch()
    flush_weaviate()

elapsed = time.time() - start_time

print(f"\n{'='*50}")
print("Indexation terminee !")
print(f"Images indexees: {indexed}")
print(f"Produits sans image: {skipped}")
print(f"Temps: {elapsed:.0f}s ({indexed / max(elapsed, 1):.1f} images/sec)")
print(f"{'='*50}")

## 10. Vérification

In [None]:
# Count stored objects to confirm the indexing run landed in Weaviate.
collection = client.collections.get(COLLECTION_NAME)
stats = collection.aggregate.over_all(total_count=True)
document_count = stats.total_count
print(f"Collection '{COLLECTION_NAME}': {document_count} documents")

## 11. Test rapide

In [None]:
from weaviate.classes.query import MetadataQuery
from IPython.display import display, HTML

from html import escape

query = "black leather jacket"

# Embed the text query the same way the images were embedded (L2-normalised).
with torch.no_grad():
    inputs = processor(text=[query], return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # Normalise in FP32: the model may run in FP16 on the GPU.
    features = model.get_text_features(**inputs).float()
    features = torch.nn.functional.normalize(features, p=2, dim=-1)
    query_vector = features.cpu().numpy().flatten().tolist()

results = collection.query.near_vector(
    near_vector=query_vector,
    limit=5,
    return_metadata=MetadataQuery(distance=True),
)

print(f"Resultats pour: '{query}'\n")

cards = []
for obj in results.objects:
    p = obj.properties
    distance = obj.metadata.distance
    # Cosine distance -> similarity. Show "n/a" rather than a fake 1.000 when missing.
    score = f"{1 - distance:.3f}" if distance is not None else "n/a"
    thumb = escape(p.get("thumbnail_base64") or "")
    # Truncate before escaping so an HTML entity is never cut in half;
    # escaping keeps dataset text from injecting markup into the notebook.
    name = escape((p.get("product_name") or "")[:40])
    price = p.get("price")
    price_str = f"\u00a3{price:.2f}" if price else ""
    cards.append(f'''
    <div style="text-align:center; width:160px;">
        <img src="data:image/jpeg;base64,{thumb}" style="max-width:150px; max-height:150px;"/>
        <div style="font-size:11px;">{name}</div>
        <div style="font-size:11px; color:gray;">{price_str} - score: {score}</div>
    </div>''')

display(HTML('<div style="display:flex; gap:10px; flex-wrap:wrap;">' + "".join(cards) + '</div>'))

In [None]:
# Release the HTTP/gRPC connections opened by weaviate.connect_to_custom.
client.close()
print("Connexion Weaviate fermee.")