# Semantic Segmentation Contract

Goal: demonstrate a clean, explicit input/output contract for semantic segmentation: an RGB image goes in, and a per-pixel map of class ids comes out.

- Load a pretrained model from torchvision.
- Run inference on a demo image.
- Show both lazy and explicit model loading.
- Visualize the class map, plus an overlay and a binary mask for each predicted class.

## 1. Workspace information

In [None]:
# Print basic host information (kernel, memory, first lines of CPU info)
# so the environment the notebook ran in is recorded alongside the results.
!uname -a

# `|| true` forces a zero exit status when the command is unavailable
# (e.g. no `free` or /proc on non-Linux hosts).
!free -h || true

!cat /proc/cpuinfo | head -n 20 || true

## 2. Workspace setup (Colab-friendly)

In [None]:
try:
    from google.colab import drive  # type: ignore
    drive.mount('/content/drive')
except Exception:
    pass

import os
import glob

my_name = '04_semantic_segmentation_contract.ipynb'
my_path = glob.glob(os.getcwd() + '/**/' + my_name, recursive=True)

if my_path:
    nb_dir = os.path.dirname(my_path[0])
    os.chdir(os.path.abspath(os.path.join(nb_dir, '..')))

print("Current dir:", os.getcwd())
!ls


## 3. Load the demo image

In [None]:
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt

img_dir = Path("data/input/images")
img_path = img_dir / "image3.png"
img1 = Image.open(img_path).convert("RGB")


def show(img, title: str) -> None:
    """Display ``img`` in its own figure, titled ``title``, with axes hidden."""
    _, ax = plt.subplots()
    ax.imshow(img)
    ax.axis("off")
    ax.set_title(title)
    plt.show()


show(img1, img_path.name)


## 4. Install dependencies and run tests

In [None]:
# Quietly install the listed requirements, then run the test suite from the
# current working directory (set by the workspace-setup cell).
# NOTE(review): this is a vision notebook but it installs requirements-nlp.txt;
# confirm this is the intended requirements file.
# NOTE(review): PIL and matplotlib are already imported in section 3; if they
# come from this file, this cell needs to run before that one.
!pip -q install -r requirements/requirements-nlp.txt
!pytest -q

## 5. Run semantic segmentation

In [None]:
from src.vision.segmentation import segment_semantic, load_pretrained_segmentation_model

# Explicit model loading: the caller loads the model once and passes the handle
# in. `loaded` is reused by the next cell for its category names.
loaded = load_pretrained_segmentation_model()
class_map = segment_semantic(img1, loaded=loaded)

# Lazy model loading: segment_semantic obtains the model itself
# (presumably loading or caching it internally; confirm in src.vision.segmentation).
class_map_lazy = segment_semantic(img1)

print(f"Class map shape: {class_map.shape}")
print(f"Lazy class map shape: {class_map_lazy.shape}")

# Both loading paths should honour the same contract: same input, same class map.
same_result = (
    class_map.shape == class_map_lazy.shape
    and bool((class_map == class_map_lazy).all())
)
print(f"Explicit and lazy results identical: {same_result}")

# Display the result. The map holds categorical class ids, so use a qualitative
# colormap and nearest-neighbour interpolation. The default smoothing would blend
# neighbouring ids into colours that correspond to no class.
fig, (ax_img, ax_map) = plt.subplots(1, 2, figsize=(12, 6))
ax_img.imshow(img1)
ax_img.set_title("Original Image")
ax_img.axis("off")

ax_map.imshow(class_map, cmap="tab20", interpolation="nearest")
ax_map.set_title("Segmentation Map")
ax_map.axis("off")

plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Ensure the image and the class map share one pixel grid before plotting.
# PIL's Image.size is (width, height); the array shape is (height, width).
h, w = class_map.shape
img_plot = img1 if img1.size == (w, h) else img1.resize((w, h))

categories = loaded.categories
n_categories = len(categories)


def label_for(cid: int) -> str:
    """Return the category name for class id ``cid``, or ``class_<id>`` if unknown."""
    return categories[cid] if cid < n_categories else f"class_{cid}"


unique_ids = np.unique(class_map)
# Class id 0 is skipped as background (presumably "__background__" in the
# model's category list; confirm for the model in use).
fg_ids = [int(i) for i in unique_ids if int(i) != 0]

print("unique class ids:", unique_ids)
print("labels:", [categories[i] for i in unique_ids if i < n_categories])

if not fg_ids:
    plt.figure(figsize=(6, 6))
    plt.imshow(img_plot)
    plt.title("Original (only background predicted)")
    plt.axis("off")
    plt.show()
else:
    n = len(fg_ids)
    # squeeze=False keeps `axes` 2-D even when only one class is present.
    fig, axes = plt.subplots(n, 3, figsize=(18, 4 * n), squeeze=False)

    for (ax_orig, ax_overlay, ax_mask), cid in zip(axes, fg_ids):
        label = label_for(cid)
        mask = (class_map == cid).astype(np.uint8)
        # Hide pixels outside the class; otherwise the Reds colormap paints them
        # near-white and the overlay washes out the whole image.
        mask_overlay = np.ma.masked_where(mask == 0, mask)

        # 1) Original (full image, on the same grid as the map)
        ax_orig.imshow(img_plot)
        ax_orig.set_title("Original")
        ax_orig.axis("off")

        # 2) Overlay
        ax_overlay.imshow(img_plot)
        ax_overlay.imshow(mask_overlay, cmap="Reds", alpha=0.5,
                          vmin=0, vmax=1, interpolation="nearest")
        ax_overlay.set_title(f"Overlay: {label} (id={cid})")
        ax_overlay.axis("off")

        # 3) Mask. A fixed vmin/vmax means a mask covering the whole image
        # renders white rather than auto-scaling to black.
        ax_mask.imshow(mask, cmap="gray", vmin=0, vmax=1, interpolation="nearest")
        ax_mask.set_title(f"Mask: {label} (id={cid})")
        ax_mask.axis("off")

    # Let tight_layout own the spacing. A near-zero hspace applied afterwards
    # would push each row's titles onto the images of the row above.
    plt.tight_layout(pad=0.2, w_pad=0.1, h_pad=0.1)
    plt.show()