# 🦾 Training Toolkit: Multi-adapter inference

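Now we load the base model together with both of our trained adapters (segmentation and JSON extraction) and chain them into a single inference pipeline: first segment the objects in the image, then pass the outlined image to the JSON adapter to extract a label and disposal directions for each object.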

## 1. Load adapters

In [None]:
from peft import PeftModel
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
import PIL
import numpy as np
import cv2
import itertools

from training_toolkit.common.tokenization_utils.segmentation import (
    SegmentationTokenizer,
)

from training_toolkit.common.tokenization_utils.json import (
    JSONTokenizer,
)

In [None]:
MODEL_ID = "google/paligemma-3b-mix-224"
SEG_CHECKPOINT_PATH = "paligemma_trash_segm_adapter"
JSON_CHECKPOINT_PATH = "paligemma_trash_json_adapter"

# The processor (image preprocessing + tokenizer) is shared by both tasks.
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Base weights come straight from the hub; the adapters are layered on top.
base_model = PaliGemmaForConditionalGeneration.from_pretrained(MODEL_ID)

# Wrap the base model with the segmentation adapter, then register the JSON
# adapter on the same wrapper. Later cells switch between the two with
# model.set_adapter("segmentation") / model.set_adapter("json").
model = PeftModel.from_pretrained(
    base_model,
    SEG_CHECKPOINT_PATH,
    adapter_name="segmentation",
)
model.load_adapter(JSON_CHECKPOINT_PATH, adapter_name="json")

# Helpers that turn generated text back into masks / JSON.
json_tokenizer = JSONTokenizer(processor)
segmentation_tokenizer = SegmentationTokenizer()

In [None]:
# `import PIL` on its own does not load the PIL.Image submodule; import it
# explicitly instead of relying on another library having imported it first.
import PIL.Image

# Normalise to 3-channel RGB so the contour-drawing code (which uses 3-tuple
# colors) and the processor get the same layout, whatever mode the file uses.
image = PIL.Image.open("../assets/trash1.jpg").convert("RGB")
# image  # uncomment to preview the input

In [None]:
# Coming straight from the segmentation notebook.
# A list, not a set: these names are joined into the segmentation prompt, and
# the iteration order of a set of strings can change between interpreter runs
# (string hashing is randomized per process), which would change the prompt.
# NOTE(review): this order should match the prompt used when training the
# segmentation adapter — TODO confirm against the segmentation notebook.
class_names = [
    "carton",
    "foam",
    "food",
    "general",
    "glass",
    "metal",
    "paper",
    "plastic",
    "special",
]

## 2. Do segmentation

In [None]:
# Prepare segmentation inputs. return_tensors="pt" makes the processor emit
# PyTorch tensors for generate(), matching the JSON extraction step.
PROMPT = "segment " + " ; ".join(class_names)
inputs = processor(images=image, text=PROMPT, return_tensors="pt")

# Route generation through the segmentation adapter. Greedy decoding keeps the
# structured segmentation output reproducible from run to run.
model.set_adapter("segmentation")
generated_ids = model.generate(**inputs, max_new_tokens=256, do_sample=False)

# generate() returns the prompt tokens followed by the completion. Cut the
# prompt off using the exact length of the tokenized input instead of
# re-deriving it from image/text token counts plus a hand-tuned offset.
num_prompt_tokens = inputs["input_ids"].shape[-1]

generated_text = processor.batch_decode(
    generated_ids[:, num_prompt_tokens:],
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
)[0]

# Reconstruct the segmentation masks, passing the original image size.
w, h = image.size
generated_segmentation = segmentation_tokenizer.decode(generated_text, w, h)

In [None]:
# Inspect the decoded segmentation output; later cells read each entry's
# "name" and "mask" fields.
generated_segmentation

### Post-process the segmentation masks to create inputs for JSON extraction

In [None]:
# Palette used to outline masks, one color per mask (cycled). The drawing cell
# paints onto an array converted from a PIL image, so these act as (R, G, B);
# OpenCV itself would interpret the same tuples as BGR.
# NOTE(review): some entries repeat or nearly repeat — (0, 255, 255) and
# (255, 0, 255) each appear twice, and (255, 165, 0) is close to
# (255, 128, 0) — so different masks can end up with identical colors.
COLORS = [
    (0, 255, 255),  # Cyan
    (255, 128, 128),  # Salmon
    (255, 0, 255),  # Magenta
    (255, 128, 0),  # Orange
    (128, 255, 0),  # Chartreuse
    (0, 255, 128),  # Spring Green
    (255, 0, 128),  # Rose
    (128, 0, 255),  # Violet
    (0, 128, 255),  # Azure
    (128, 255, 128),  # Light green
    (128, 128, 255),  # Cornflower Blue (approximate)
    (255, 255, 128),  # Light Yellow
    (255, 128, 255),  # Orchid
    (128, 255, 255),  # Light Cyan
    (255, 165, 0),  # Also orange
    (0, 255, 255),  # Aqua (same value as Cyan above)
    (255, 0, 255),  # Fuchsia (same value as Magenta above)
    (128, 0, 0),  # Maroon
    (128, 128, 0),  # Olive
    (0, 128, 128),  # Teal
    (128, 0, 128),  # Purple
]

# Shared module-level iterator: each next(colors) call advances it, so
# re-running the drawing code continues from the color where it last stopped.
colors = itertools.cycle(COLORS)

In [None]:
def draw_contours(image, masks, color_cycle=None):
    """Outline each mask on ``image`` and label it with its index.

    Args:
        image: Image array to draw on. OpenCV drawing functions modify the
            array in place, so pass a copy (e.g. ``np.array(pil_image)``) if
            the original must stay untouched.
        masks: Sequence of 2-D masks; nonzero pixels belong to the object.
            Each is cast to ``uint8`` before contour extraction.
        color_cycle: Iterator yielding one color tuple per mask. Defaults to
            the module-level ``colors`` cycle, which is shared across calls;
            pass ``itertools.cycle(COLORS)`` to get the same colors every call.

    Returns:
        The annotated image array. Mask ``j`` is labelled with the text ``j``.
    """
    if color_cycle is None:
        color_cycle = colors

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1
    text_thickness = 2
    padding = 5

    for j, mask in enumerate(masks):
        # Take the color before the empty-mask check so mask j always gets the
        # j-th color of the cycle, even when an earlier mask is empty.
        color = next(color_cycle)

        contours, _ = cv2.findContours(
            np.asarray(mask).astype(np.uint8),
            cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE,
        )
        if not contours:
            continue

        image = cv2.drawContours(image, contours, -1, color, 3)

        # Anchor the label at the centroid of the largest outline. The order of
        # findContours' output is not by size, so contours[0] may be a small
        # fragment rather than the main region.
        moments = cv2.moments(max(contours, key=cv2.contourArea))
        if moments["m00"] == 0:
            continue
        cX = int(moments["m10"] / moments["m00"])
        cY = int(moments["m01"] / moments["m00"])

        label = str(j)
        (text_w, text_h), baseline = cv2.getTextSize(
            label, font, font_scale, text_thickness
        )

        # Black backdrop sized to the rendered text so multi-digit labels
        # stay fully covered and readable.
        cv2.rectangle(
            image,
            (cX - padding, cY - text_h - padding),
            (cX + text_w + padding, cY + baseline + padding),
            (0, 0, 0),
            -1,  # Fill the rectangle
        )

        # White label text with the mask index.
        cv2.putText(
            image,
            label,
            (cX, cY),
            font,
            font_scale,
            (255, 255, 255),
            text_thickness,
            cv2.LINE_AA,
        )
    return image

In [None]:
# Binarise every predicted mask at 0.5 into a 0/255 uint8 image, then outline
# the masks (labelled by index) on a NumPy copy of the input image.
masks = []
for seg in generated_segmentation:
    binary = seg["mask"] > 0.5
    masks.append(binary.astype(np.uint8) * 255)

annotated = draw_contours(np.array(image), masks)
image_with_masks = PIL.Image.fromarray(annotated)
image_with_masks

## 3. Do JSON extraction

In [None]:
# Class name predicted for each segmented object, in mask-index order
[seg["name"] for seg in generated_segmentation]

In [None]:
from string import Template

# Local disposal rules, pasted verbatim into the prompt. Read with an explicit
# encoding so the text does not depend on the platform's default encoding.
with open("korea_summary.txt", "r", encoding="utf-8") as f:
    rules = f.read()

# One entry per segmented object with a non-blank name. `item_id` is the index
# into generated_segmentation, i.e. the same number drawn on that object's
# outline, presumably so the model can match entries to labelled regions.
items = [
    {"item_id": i, "class": seg["name"].strip()}
    for i, seg in enumerate(generated_segmentation)
    if seg["name"] and seg["name"].strip()
]

PREFIX_TEMPLATE = Template(
    "For every object outlined in the image, here're their detected classes: $items. "
    "For every outlined item, extract JSON with a more accurate label, "
    "as well as disposal directions based on local rules. "
    "The local rules are as follows: $rules."
)

prompt = PREFIX_TEMPLATE.substitute(items=items, rules=rules)
prompt

In [None]:
# Enable the JSON adapter
model.set_adapter("json")

inputs = processor(images=image_with_masks, text=prompt, return_tensors="pt")
# Greedy decoding: sampling over a long structured answer can easily produce
# malformed or non-reproducible JSON.
generated_ids = model.generate(**inputs, max_new_tokens=1024, do_sample=False)

# Decode only the newly generated tokens. Decoding the full sequence would also
# return the prompt text, which contains the brace-delimited `items` list;
# json_tokenizer.decode (not visible here) could pick that up instead of the
# model's answer.
num_prompt_tokens = inputs["input_ids"].shape[-1]
generated_text = processor.batch_decode(
    generated_ids[:, num_prompt_tokens:],
    skip_special_tokens=True,
)[0]

generated_json = json_tokenizer.decode(generated_text)

In [None]:
# Inspect the result returned by json_tokenizer.decode()
generated_json

In [None]:
# Development leftover: prints the git working-tree state of the repository;
# not part of the inference pipeline.
!git status