# 🦾 Training Toolkit: Segmentation

## 1. Preparations

### Download data and install packages

In [None]:
!wget -O taco.zip https://zenodo.org/records/3587843/files/TACO.zip?download=1
!unzip -o taco.zip -d taco_raw && rm taco.zip

In [None]:
!cd .. && pip3 install -r requirements.txt
!cd .. && git clone https://github.com/tensorsense/training_toolkit.git
!cd ../training_toolkit && pip3 install --upgrade -e .

### Convert the dataset into HF 🤗 format

In [None]:
from datasets import Dataset, Image
from collections import defaultdict
import PIL

from tqdm import tqdm
from pycocotools.coco import COCO
import numpy as np
import cv2
import albumentations as A

from pathlib import Path

In [None]:
# Point at the extracted TACO data and index its COCO-style annotation file
dataset_path = Path("taco_raw/TACO/data/")
coco = COCO((dataset_path / "annotations.json").as_posix())

# Category names, listed in the same order as the category ids reported by pycocotools
categories = [category["name"] for category in coco.loadCats(coco.getCatIds())]
image_ids = coco.getImgIds()

In [None]:
# Let's map some of the classes to a more general class to make segmentation easier.
# Each list below holds original category names (used verbatim as lookup keys, so
# their spelling must not be changed); the list name is the coarse class they map to.

plastic = [
    "Other plastic bottle",
    "Clear plastic bottle",
    "Plastic bottle cap",
    "Disposable plastic cup",
    "Other plastic cup",
    "Plastic lid",
    "Other plastic",
    "Plastic film",
    "Other plastic wrapper",
    "Other plastic container",
    "Plastic glooves",
    "Plastic utensils",
    "Plastic straw",
    "Disposable food container",
    "Polypropylene bag",
    "Single-use carrier bag",
    "Carded blister pack",
    "Crisp packet",
    "Garbage bag",
    "Six pack rings",
    "Spread tub",
    "Squeezable tube",
    "Tupperware",
]

glass = ["Glass bottle", "Broken glass", "Glass cup", "Glass jar"]

paper = [
    "Paper cup",
    "Magazine paper",
    "Wrapping paper",
    "Normal paper",
    "Paper bag",
    "Plastified paper bag",
    "Paper straw",
]

carton = [
    "Other carton",
    "Egg carton",
    "Drink carton",
    "Corrugated carton",
    "Meal carton",
    "Pizza box",
    "Toilet tube",
]

metal = [
    "Aluminium foil",
    "Aluminium blister pack",
    "Metal bottle cap",
    "Food Can",
    "Drink can",
    "Metal lid",
    "Scrap metal",
    "Pop tab",
]

foam = [
    "Foam cup",
    "Foam food container",
    "Styrofoam piece",
]

special = [
    "Aerosol",
    "Battery",
    "Rope & strings",
    "Shoe",
    "Cigarette",
]

food = [
    "Food waste",
]

general = [
    "Tissues",
    "Unlabeled litter",
]

# original category name -> coarse class name
class_map = (
    {item: "plastic" for item in plastic}
    | {item: "glass" for item in glass}
    | {item: "paper" for item in paper}
    | {item: "carton" for item in carton}
    | {item: "metal" for item in metal}
    | {item: "foam" for item in foam}
    | {item: "special" for item in special}
    | {item: "food" for item in food}
    | {item: "general" for item in general}
)

# Sorted list instead of a bare set: the iteration order of a set of strings can
# change from one interpreter run to the next, and the prompts are built by joining
# these names, so the order has to be stable across kernels.
class_names = sorted(set(class_map.values()))

In [None]:
# Display the coarse class names derived from class_map
class_names

In [None]:
IMAGE_SIZE = 512
# Optional cap on the number of stored samples. None keeps every usable image
# (which is what this cell effectively did before, since the limit was never applied);
# set it to an integer for a quick, smaller run.
LIMIT_SAMPLES = None

# HF Datasets may choke on full images, so we'll resize them to a smaller size.
# "ann_indices" is an extra label field whose only job is to report which
# annotations survive the bbox filtering, so masks and classes can be kept
# aligned with the boxes that remain.
transform = A.Compose(
    [
        A.SmallestMaxSize(max_size=IMAGE_SIZE, p=1.0),
        A.CenterCrop(height=IMAGE_SIZE, width=IMAGE_SIZE, p=1.0),
    ],
    bbox_params=A.BboxParams(
        format="pascal_voc",
        label_fields=["class_labels", "ann_indices"],
        clip=True,
        min_area=1,
    ),
)

dataset_dict = defaultdict(list)
prefix = "segment " + " ; ".join(class_names)


for image_id in tqdm(image_ids):
    # Stop early once the optional sample cap is reached
    if LIMIT_SAMPLES is not None and len(dataset_dict["image"]) >= LIMIT_SAMPLES:
        break

    # 1. Parse COCO annotations
    image_path = dataset_path.joinpath(coco.loadImgs(image_id)[0]["file_name"])
    annotations = coco.loadAnns(coco.getAnnIds(image_id))
    if not annotations:
        # Nothing to segment in this image
        continue

    xyxy_bboxes = [
        [x, y, x + w, y + h] for x, y, w, h in (ann["bbox"] for ann in annotations)
    ]
    # Look names up by category id instead of treating the id as a list position
    original_classes = [coco.cats[ann["category_id"]]["name"] for ann in annotations]
    classes = [class_map[original_class] for original_class in original_classes]

    # 2. Load and resize the image and its annotations
    image = cv2.imread(image_path.as_posix())
    if image is None:
        # cv2.imread signals failure by returning None; fail with a readable message
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    masks = [coco.annToMask(ann) for ann in annotations]

    transformed = transform(
        image=image,
        masks=masks,
        bboxes=xyxy_bboxes,
        class_labels=classes,
        ann_indices=list(range(len(annotations))),
    )

    # 3. Prepare the sample for storage
    image = PIL.Image.fromarray(transformed["image"])

    # The transform drops boxes (and their labels) but leaves the mask list untouched,
    # so select masks/classes through the surviving annotation indices. Boxes that
    # collapse to zero width or height after integer truncation are dropped as well.
    kept_indices = []
    kept_bboxes = []
    for bbox, ann_index in zip(transformed["bboxes"], transformed["ann_indices"]):
        x1, y1, x2, y2 = (int(value) for value in bbox[:4])
        if x2 - x1 > 0 and y2 - y1 > 0:
            kept_indices.append(int(ann_index))
            kept_bboxes.append([x1, y1, x2, y2])

    if not kept_indices:
        continue

    xyxy_bboxes = np.array(kept_bboxes)
    masks = np.array([transformed["masks"][i] for i in kept_indices], dtype=bool)
    classes = [classes[i] for i in kept_indices]
    original_classes = [original_classes[i] for i in kept_indices]

    assert len(masks.shape) == 3
    assert (
        len(xyxy_bboxes.shape) == 2 and xyxy_bboxes.shape[1] == 4
    ), f"{xyxy_bboxes.shape}, {len(masks)}, {len(xyxy_bboxes)}"
    # Every stored instance must have exactly one box, one mask and one class
    assert len(masks) == len(xyxy_bboxes) == len(classes) == len(original_classes)

    # 4. Store the sample
    dataset_dict["image"].append(image)
    dataset_dict["prompt"].append(prefix)
    dataset_dict["xyxy_bboxes"].append(xyxy_bboxes)
    dataset_dict["masks"].append(masks)
    dataset_dict["classes"].append(classes)
    dataset_dict["original_classes"].append(original_classes)


# Convert the dataset to HF format and save it to disk
dataset = Dataset.from_dict(dataset_dict)
dataset = dataset.cast_column("image", Image())

dataset.info.dataset_name = "taco_trash"
dataset.info.description = f"class_names: {' ; '.join(class_names)}"

dataset.save_to_disk("taco_trash")

### Test segmentation tokenizer

`SegmentationTokenizer` is a utility that transforms segmentation masks into sequences of 20 tokens and back.

In [None]:
from datasets import Dataset
import PIL
import numpy as np
from training_toolkit.common.tokenization_utils.segmentation import (
    SegmentationTokenizer,
)

In [None]:
# Tokenizer used below to turn masks into token sequences and back
segmentation_tokenizer = SegmentationTokenizer()

# Reload the converted dataset from disk and have it return torch tensors
dataset = Dataset.load_from_disk("taco_trash").with_format("torch")

In [None]:
# 1. Let's take a look at the original image
example = dataset[0]
# permute(1, 2, 0) moves the first axis to the end before handing the array to PIL
# (assumes the stored image tensor is channel-first, i.e. (C, H, W) — TODO confirm)
PIL.Image.fromarray(example["image"].permute(1, 2, 0).numpy())

In [None]:
# 2. ...and its mask (only the first entry of the sample's masks is shown)
PIL.Image.fromarray(example["masks"][0].numpy())

In [None]:
# 3. Encode the sample's annotations and inspect the resulting token string.
# The tokenizer receives, in order: image, boxes, masks and class names.
suffix = segmentation_tokenizer.encode(
    *(example[column] for column in ("image", "xyxy_bboxes", "masks", "classes"))
)

suffix

In [None]:
# 4. Finally, turn the token sequence back into a pixel-level mask
decoded = segmentation_tokenizer.decode(suffix, 512, 512)

# Binarise the first decoded mask at 0.5 and render it as a black/white image
PIL.Image.fromarray(np.where(decoded[0]["mask"] > 0.5, 255, 0).astype(np.uint8))

## 2. Train the model

In [None]:
# PaliGemma lives in a gated repo, so we need to load the HF API token

from dotenv import load_dotenv

# The trailing semicolon hides the return value without binding `_`, which IPython
# uses for the last cell output and stops updating once it is assigned by hand.
load_dotenv();

In [None]:
# Load necessary bits from the toolkit
from training_toolkit import paligemma_image_preset, image_segmentation_preset, build_trainer

In [None]:
# The default setup runs out of memory, so shrink the batch sizes (and set the epoch count)
paligemma_image_preset.training_args["per_device_train_batch_size"] = 12
paligemma_image_preset.training_args["per_device_eval_batch_size"] = 12
paligemma_image_preset.training_args["num_train_epochs"] = 8

# Collect the keyword arguments from both presets (most values are pre-made there):
# one for the model, one for the segmentation data stored in "taco_trash"
model_kwargs = paligemma_image_preset.as_kwargs()
data_kwargs = image_segmentation_preset.with_path("taco_trash").as_kwargs()

trainer = build_trainer(**model_kwargs, **data_kwargs)

In [None]:
# Train the model using the trainer built in the previous cell
trainer.train()

## 3. Load and run the model

In [None]:
from peft import AutoPeftModelForCausalLM
from transformers import AutoProcessor
import PIL
import numpy as np
import cv2
import supervision as sv

from training_toolkit.common.tokenization_utils.segmentation import (
    SegmentationTokenizer,
)

In [None]:
# Directory of the fine-tuned checkpoint to load (adjust to your own training run)
CHECKPOINT_PATH = "paligemma_2024-08-06_09-05-06"

# Helper that converts generated text back into masks
segmentation_tokenizer = SegmentationTokenizer()

# Processor and PEFT model are both restored from the same checkpoint directory
processor = AutoProcessor.from_pretrained(CHECKPOINT_PATH)
model = AutoPeftModelForCausalLM.from_pretrained(CHECKPOINT_PATH)

In [None]:
# Open the sample picture used for inference and display it
image = PIL.Image.open("../assets/trash1.jpg")
image

In [None]:
# 1. To stay in line with fine-tuning, the class names are passed as part of the prompt
PROMPT = "segment " + " ; ".join(class_names)

# 2. Generate with the standard HF API. Tensors are requested explicitly and moved
#    to the model's device; greedy decoding keeps the output reproducible.
inputs = processor(images=image, text=PROMPT, return_tensors="pt").to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=256, do_sample=False)

# 3. The returned sequence starts with the prompt (image + text tokens), so the
#    prompt length is read directly from the processed inputs rather than re-derived.
num_prompt_tokens = inputs["input_ids"].shape[-1]

# 4. Decode only the newly generated tokens and reconstruct the masks
generated_text = processor.batch_decode(
    generated_ids[:, num_prompt_tokens:],
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
)[0]

# PIL reports size as (width, height)
w, h = image.size

generated_segmentation = segmentation_tokenizer.decode(generated_text, w, h)

In [None]:
# Binarise the first generated mask at 0.5 and show it as a black/white image
PIL.Image.fromarray(
    np.where(generated_segmentation[0]["mask"] > 0.5, 255, 0).astype(np.uint8)
)

In [None]:
# Inspect the raw decoded segmentation output
generated_segmentation

## 4. Visualize the result

In [None]:
# Per-detection fields collected from the decoded segmentation output
xyxy = []
mask = []
class_id = []
class_name = []

# Fixed ordering so class ids (and therefore annotation colours) are stable between runs
known_classes = sorted(class_names)

for r in generated_segmentation:
    # Skip entries without a box or without a usable mask
    if "xyxy" not in r or "mask" not in r or r["mask"] is None:
        continue

    # Resolve the label once, falling back to "trash" when no name was decoded
    name = r["name"].strip() if r.get("name") is not None else "trash"

    xyxy.append(r["xyxy"])
    # Binarise the mask at 0.5; cv2.threshold returns (threshold, image)
    mask.append(cv2.threshold(r["mask"], 0.5, 1.0, cv2.THRESH_BINARY)[1])
    # Names outside the known classes share one extra, non-negative id
    class_id.append(
        known_classes.index(name) if name in known_classes else len(known_classes)
    )
    class_name.append(name)

if xyxy:
    detections = sv.Detections(
        xyxy=np.array(xyxy).astype(int),
        mask=np.array(mask).astype(bool),
        class_id=np.array(class_id).astype(int),
    )
else:
    # Nothing usable was generated: build a well-formed empty container instead of
    # passing zero-length 1-D arrays to the constructor
    detections = sv.Detections.empty()

detections["class_name"] = class_name

In [None]:
# Draw on a copy under a separate name so `image` keeps the untouched input and
# re-running this cell does not stack annotations on an already annotated picture
annotated_image = image.copy()

annotated_image = sv.BoxAnnotator().annotate(annotated_image, detections)
annotated_image = sv.MaskAnnotator().annotate(annotated_image, detections)
annotated_image = sv.LabelAnnotator(
    text_scale=2,
    text_thickness=4,
    text_position=sv.Position.CENTER_OF_MASS,
    text_color=sv.Color.BLACK,
).annotate(annotated_image, detections)

annotated_image