In [None]:
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Évaluation de la segmentation sémantique des vêtements\n",
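    "Ce notebook automatise l'appel à l'API d'inférence Hugging Face du modèle SegFormer (`sayeed99/segformer_b3_clothes`), affiche les masques colorisés, modifie l'image selon les classes détectées et propose de sauvegarder les résultats.\n",
    "\n",
    "**Prérequis :** un jeton `HUGGINGFACE_API_TOKEN` défini dans un fichier `.env`, et des images PNG placées dans `assets/images`."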
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import io\n",
    "import base64\n",
    "import requests\n",
    "from dotenv import load_dotenv\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import cv2\n",
    "\n",
    "# Load environment variables (notably HUGGINGFACE_API_TOKEN) from a local .env file\n",
    "load_dotenv()\n",
    "API_TOKEN = os.getenv(\"HUGGINGFACE_API_TOKEN\")\n",
    "MODEL_ID = \"sayeed99/segformer_b3_clothes\"\n",
    "\n",
    "# Directories and constants\n",
    "IMAGES_DIR = \"assets/images\"\n",
    "MAX_IMAGE_SIZE = 512  # max width and height (px) of the images sent to the API\n",
    "\n",
    "# Class names (label-map value = position in this list) and one RGB colour per class.\n",
    "# NOTE(review): API labels are matched case-insensitively against these names and\n",
    "# labels missing from this list are ignored; confirm the list against the model's\n",
    "# id2label mapping (config.json on the Hugging Face Hub).\n",
    "CLASSES = [\n",
    "    \"background\", \"hat\", \"hair\", \"sunglasses\", \"upper-clothes\", \"dress\",\n",
    "    \"coat\", \"socks\", \"pants\", \"gloves\", \"scarf\", \"skirt\", \"face\",\n",
    "    \"left-arm\", \"right-arm\", \"left-leg\", \"right-leg\", \"left-shoe\", \"right-shoe\"\n",
    "]\n",
    "\n",
    "PALETTE = [\n",
    "    [0, 0, 0], [128, 0, 0], [255, 0, 0], [0, 85, 0], [170, 0, 51],\n",
    "    [255, 85, 0], [0, 0, 85], [0, 119, 221], [85, 85, 0], [0, 85, 85],\n",
    "    [85, 51, 0], [52, 86, 128], [0, 128, 0], [0, 0, 255], [85, 255, 170],\n",
    "    [170, 255, 85], [255, 255, 0], [255, 170, 0], [255, 0, 255]\n",
    "]\n",
    "\n",
    "# Each class index needs a colour, otherwise apply_palette would leave it black.\n",
    "assert len(PALETTE) == len(CLASSES), \"PALETTE must provide one RGB colour per class\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_and_resize_image(path):\n",
    "    \"\"\"Load an image as RGB and shrink it to fit inside MAX_IMAGE_SIZE x MAX_IMAGE_SIZE.\n",
    "\n",
    "    The aspect ratio is preserved and smaller images are not upscaled\n",
    "    (PIL ``thumbnail`` only ever shrinks).\n",
    "\n",
    "    Args:\n",
    "        path: Path to the image file.\n",
    "\n",
    "    Returns:\n",
    "        A new RGB ``PIL.Image.Image`` whose width and height are <= MAX_IMAGE_SIZE.\n",
    "    \"\"\"\n",
    "    # The context manager closes the file handle opened lazily by Image.open;\n",
    "    # convert() loads the pixels first, so the returned image stays usable.\n",
    "    with Image.open(path) as src:\n",
    "        img = src.convert(\"RGB\")\n",
    "    img.thumbnail((MAX_IMAGE_SIZE, MAX_IMAGE_SIZE))\n",
    "    return img\n",
    "\n",
    "def call_hf_api(image: Image.Image, timeout: float = 60.0):\n",
    "    \"\"\"POST an image to the Hugging Face Inference API and return the decoded JSON.\n",
    "\n",
    "    The image is PNG-encoded and sent as raw bytes to the endpoint of\n",
    "    MODEL_ID, authenticated with API_TOKEN.\n",
    "\n",
    "    Args:\n",
    "        image: PIL image to segment.\n",
    "        timeout: Seconds to wait for the server before giving up. Without it,\n",
    "            ``requests`` may block forever on an unresponsive endpoint.\n",
    "\n",
    "    Returns:\n",
    "        The parsed JSON body (``main`` expects a list of dicts carrying\n",
    "        ``label`` and ``mask`` keys).\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: if API_TOKEN is not set, or the API answers with a\n",
    "            non-200 status.\n",
    "        requests.RequestException: on network errors or when the timeout expires.\n",
    "    \"\"\"\n",
    "    # Fail fast with an actionable message instead of sending \"Bearer None\".\n",
    "    if not API_TOKEN:\n",
    "        raise RuntimeError(\"HUGGINGFACE_API_TOKEN manquant : définissez-le dans le fichier .env\")\n",
    "\n",
    "    buffered = io.BytesIO()\n",
    "    image.save(buffered, format=\"PNG\")\n",
    "\n",
    "    headers = {\n",
    "        \"Authorization\": f\"Bearer {API_TOKEN}\",\n",
    "        \"Content-Type\": \"application/octet-stream\",\n",
    "    }\n",
    "    url = f\"https://api-inference.huggingface.co/models/{MODEL_ID}\"\n",
    "    response = requests.post(url, headers=headers, data=buffered.getvalue(), timeout=timeout)\n",
    "\n",
    "    if response.status_code != 200:\n",
    "        raise RuntimeError(f\"Erreur API {response.status_code}: {response.text}\")\n",
    "    return response.json()\n",
    "\n",
    "def decode_mask(mask_base64):\n",
    "    \"\"\"Decode a base64-encoded mask image into a 0/1 ``np.uint8`` array.\n",
    "\n",
    "    Args:\n",
    "        mask_base64: Base64 text of an image file (the ``mask`` field of a\n",
    "            segment returned by the API).\n",
    "\n",
    "    Returns:\n",
    "        2-D array holding 1 where the grayscale mask is non-zero and 0 elsewhere.\n",
    "    \"\"\"\n",
    "    mask_stream = io.BytesIO(base64.b64decode(mask_base64))\n",
    "    grayscale = np.array(Image.open(mask_stream).convert(\"L\"))\n",
    "    # L-mode pixels are unsigned 8-bit values, so != 0 is the same test as > 0.\n",
    "    return (grayscale != 0).astype(np.uint8)\n",
    "\n",
    "def build_color_mask(api_result, shape):\n",
    "    \"\"\"Merge the per-segment binary masks returned by the API into one label map.\n",
    "\n",
    "    Each segment's ``label`` is matched case-insensitively against CLASSES and\n",
    "    its class index is written wherever its mask is set; later segments\n",
    "    overwrite earlier ones where they overlap. Segments with an unknown or\n",
    "    missing label, or without a ``mask`` field, are skipped.\n",
    "\n",
    "    Args:\n",
    "        api_result: Iterable of segment dicts, as returned by ``call_hf_api``.\n",
    "        shape: ``(height, width)`` of the label map to build.\n",
    "\n",
    "    Returns:\n",
    "        ``np.uint8`` array of ``shape``; 0 (background) where no segment applies.\n",
    "    \"\"\"\n",
    "    # Build the case-insensitive lookup once instead of twice per segment.\n",
    "    class_to_index = {name.lower(): idx for idx, name in enumerate(CLASSES)}\n",
    "    height, width = shape[0], shape[1]\n",
    "    final_mask = np.zeros(shape, dtype=np.uint8)\n",
    "    for obj in api_result:\n",
    "        # `or` also covers an explicit None label, which .lower() would reject.\n",
    "        label_name = (obj.get(\"label\") or \"\").lower()\n",
    "        label_index = class_to_index.get(label_name)\n",
    "        if label_index is None or \"mask\" not in obj:\n",
    "            continue\n",
    "        mask_bin = decode_mask(obj[\"mask\"])\n",
    "        if mask_bin.shape != (height, width):\n",
    "            # Boolean indexing needs a pixel-aligned mask; nearest-neighbour\n",
    "            # resizing keeps it strictly 0/1.\n",
    "            mask_bin = cv2.resize(mask_bin, (width, height), interpolation=cv2.INTER_NEAREST)\n",
    "        final_mask[mask_bin == 1] = label_index\n",
    "    return final_mask\n",
    "\n",
    "def apply_palette(mask_array):\n",
    "    \"\"\"Colourise a label map using PALETTE.\n",
    "\n",
    "    Args:\n",
    "        mask_array: Label map of shape (H, W), e.g. from ``build_color_mask``.\n",
    "\n",
    "    Returns:\n",
    "        ``(H, W, 3)`` ``np.uint8`` RGB image; labels with no PALETTE entry stay black.\n",
    "    \"\"\"\n",
    "    height, width = mask_array.shape[0], mask_array.shape[1]\n",
    "    rgb = np.zeros((height, width, 3), dtype=np.uint8)\n",
    "    present = np.unique(mask_array)\n",
    "    for idx in present[present < len(PALETTE)]:\n",
    "        rgb[mask_array == idx] = PALETTE[idx]\n",
    "    return rgb\n",
    "\n",
    "def apply_mask_effect_with_selection(img, mask_array, selected_classes):\n",
    "    \"\"\"Apply a per-class visual effect to the pixels of the selected classes.\n",
    "\n",
    "    Effects, applied only to classes listed in ``selected_classes``:\n",
    "      * ``hair``: converted to grayscale;\n",
    "      * ``coat``: filled with pure red;\n",
    "      * ``pants``: filled with pure green;\n",
    "      * ``upper-clothes``: RGB values multiplied by 1.5 (brightened), clipped to 255.\n",
    "    Background (index 0) and selected classes without an effect are left as is.\n",
    "\n",
    "    Args:\n",
    "        img: RGB image (PIL image or array) with the same height/width as ``mask_array``.\n",
    "        mask_array: Label map of CLASSES indices, e.g. from ``build_color_mask``.\n",
    "        selected_classes: Class names, each of which must appear in CLASSES.\n",
    "\n",
    "    Returns:\n",
    "        A new RGB ``np.ndarray``; ``img`` itself is not modified.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``selected_classes`` contains a name that is not in CLASSES.\n",
    "    \"\"\"\n",
    "    selected_classes = list(selected_classes)\n",
    "    unknown = [name for name in selected_classes if name not in CLASSES]\n",
    "    if unknown:\n",
    "        raise ValueError(f\"Classes inconnues : {unknown}. Classes valides : {CLASSES}\")\n",
    "    selected_indices = {CLASSES.index(name) for name in selected_classes}\n",
    "\n",
    "    # np.array copies its input, so `source` is a read-only reference for the\n",
    "    # effects while `modified` receives the edits.\n",
    "    source = np.array(img)\n",
    "    modified = source.copy()\n",
    "\n",
    "    for label_index in np.unique(mask_array):\n",
    "        if label_index == 0 or label_index not in selected_indices:\n",
    "            continue\n",
    "        class_name = CLASSES[int(label_index)]\n",
    "        class_mask = mask_array == label_index\n",
    "        if class_name == \"hair\":\n",
    "            gray = cv2.cvtColor(source, cv2.COLOR_RGB2GRAY)\n",
    "            # Broadcast the single gray channel across R, G and B.\n",
    "            modified[class_mask] = gray[class_mask][:, np.newaxis]\n",
    "        elif class_name == \"coat\":\n",
    "            modified[class_mask] = [255, 0, 0]\n",
    "        elif class_name == \"pants\":\n",
    "            modified[class_mask] = [0, 255, 0]\n",
    "        elif class_name == \"upper-clothes\":\n",
    "            # Scaling every channel brightens the region; it is not an HSV\n",
    "            # saturation change.\n",
    "            modified[class_mask] = np.clip(source[class_mask] * 1.5, 0, 255).astype(np.uint8)\n",
    "    return modified\n",
    "\n",
    "def show_and_save_panel(img, color_mask, modified_img, detected_labels, image_name, save=None):\n",
    "    \"\"\"Show the original image, colour mask and modified image side by side, then optionally save.\n",
    "\n",
    "    Args:\n",
    "        img: Original image (PIL image or array).\n",
    "        color_mask: RGB colourised label map, e.g. from ``apply_palette``.\n",
    "        modified_img: Image produced by ``apply_mask_effect_with_selection``.\n",
    "        detected_labels: Label names written as a text legend right of the panels.\n",
    "        image_name: Source file name; the panel is saved as ``outputs/panel_<image_name>``.\n",
    "        save: ``True``/``False`` to save or skip without asking; ``None`` (default)\n",
    "            prompts the user interactively, as before.\n",
    "    \"\"\"\n",
    "    fig, axs = plt.subplots(1, 3, figsize=(18, 6))\n",
    "\n",
    "    panels = [\n",
    "        (img, \"Image originale\"),\n",
    "        (color_mask, \"Masque colorisé\"),\n",
    "        # The effects are per class (grayscale, solid fill, brightening), not a desaturation.\n",
    "        (modified_img, \"Image modifiée (effets par classe)\"),\n",
    "    ]\n",
    "    for ax, (content, title) in zip(axs, panels):\n",
    "        ax.imshow(content)\n",
    "        ax.set_title(title)\n",
    "        ax.axis(\"off\")\n",
    "\n",
    "    fig.text(0.92, 0.5, \"\\n\".join(detected_labels), fontsize=10, va=\"center\")\n",
    "    # tight_layout only arranges axes and ignores figure-level text, so reserve\n",
    "    # the right margin explicitly; otherwise the third image runs under the legend.\n",
    "    fig.tight_layout(rect=(0, 0, 0.9, 1))\n",
    "    plt.show()\n",
    "\n",
    "    if save is None:\n",
    "        user_input = input(\"💾 Souhaitez-vous sauvegarder ce panneau ? (o/n) : \").strip().lower()\n",
    "        save = user_input in ('o', 'y', 'oui', 'yes')\n",
    "\n",
    "    if save:\n",
    "        os.makedirs(\"outputs\", exist_ok=True)\n",
    "        save_path = os.path.join(\"outputs\", f\"panel_{image_name}\")\n",
    "        fig.savefig(save_path)\n",
    "        print(f\"✅ Panneau sauvegardé dans {save_path}\")\n",
    "    else:\n",
    "        print(\"❌ Panneau non sauvegardé.\")\n",
    "\n",
    "    # Close this exact figure (plt.close() only targets the current one) to free\n",
    "    # memory when many panels are generated in a loop.\n",
    "    plt.close(fig)\n",
    "\n",
    "def main():\n",
    "    \"\"\"Segment every PNG in IMAGES_DIR and display (and optionally save) a panel for each.\n",
    "\n",
    "    Images are processed in sorted order so runs are reproducible. A failure on\n",
    "    one image (unreadable file, API error, unexpected response) is reported and\n",
    "    the loop continues with the next image.\n",
    "    \"\"\"\n",
    "    if not os.path.isdir(IMAGES_DIR):\n",
    "        print(\"❌ Dossier d'images introuvable :\", IMAGES_DIR)\n",
    "        return\n",
    "\n",
    "    # os.listdir returns entries in arbitrary order; sort for a deterministic run.\n",
    "    images = sorted(f for f in os.listdir(IMAGES_DIR) if f.lower().endswith(\".png\"))\n",
    "    if not images:\n",
    "        print(\"❌ Aucune image PNG trouvée dans\", IMAGES_DIR)\n",
    "        return\n",
    "\n",
    "    selected_classes = [\"hair\", \"coat\", \"pants\", \"upper-clothes\"]\n",
    "\n",
    "    for image_name in images:\n",
    "        print(f\"\\n📸 Traitement de : {image_name}\")\n",
    "        img_path = os.path.join(IMAGES_DIR, image_name)\n",
    "\n",
    "        try:\n",
    "            # Loading is inside the try so one corrupt file does not abort the batch.\n",
    "            img = load_and_resize_image(img_path)\n",
    "            result = call_hf_api(img)\n",
    "\n",
    "            if isinstance(result, list) and len(result) > 0 and \"mask\" in result[0]:\n",
    "                final_mask = build_color_mask(result, (img.height, img.width))\n",
    "                color_mask = apply_palette(final_mask)\n",
    "\n",
    "                detected_labels = [str(obj.get(\"label\", \"?\")) for obj in result]\n",
    "                modified_img = apply_mask_effect_with_selection(img, final_mask, selected_classes)\n",
    "\n",
    "                show_and_save_panel(img, color_mask, modified_img, detected_labels, image_name)\n",
    "            else:\n",
    "                print(\"Réponse API inattendue :\", result)\n",
    "\n",
    "        except Exception as e:\n",
    "            # Best-effort batch processing: report the error and move on.\n",
    "            print(\"❌ Erreur:\", e)\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
