In [1]:
import itertools
from string import Template

import cv2
import numpy as np
import PIL
import PIL.Image
from peft import AutoPeftModelForCausalLM, PeftModel
from transformers import AutoProcessor, AutoModelForCausalLM, PaliGemmaForConditionalGeneration

from training_toolkit.common.tokenization_utils.segmentation import (
    SegmentationTokenizer,
)

from training_toolkit.common.tokenization_utils.json import (
    JSONTokenizer,
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Base model identifier and the two fine-tuned adapter checkpoints.
MODEL_ID = "google/paligemma-3b-mix-224"
SEG_CHECKPOINT_PATH = "paligemma_2024-08-06_09-05-06"
JSON_CHECKPOINT_PATH = "paligemma_2024-08-13_07-21-14/checkpoint-240"

# A single processor and a single copy of the base weights serve both adapters.
processor = AutoProcessor.from_pretrained(MODEL_ID)
base_model = PaliGemmaForConditionalGeneration.from_pretrained(MODEL_ID)

# Wrap the base model with the first adapter, registered as "segmentation" ...
model = PeftModel.from_pretrained(
    base_model,
    SEG_CHECKPOINT_PATH,
    adapter_name="segmentation",
)

# ... then attach the second adapter under the name "json".
# Later cells pick the active adapter with model.set_adapter(<name>).
model.load_adapter(JSON_CHECKPOINT_PATH, adapter_name="json")

# Helpers that decode the generated text of each stage.
json_tokenizer = JSONTokenizer(processor)
segmentation_tokenizer = SegmentationTokenizer()

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards: 100%|██████████| 3/3 [00:05<00:00,  1.74s/it]


In [3]:
# Open the sample photo; the path is resolved against the kernel's working directory.
image = PIL.Image.open(fp="../assets/trash1.jpg")

In [4]:
# Classes the segmentation adapter is prompted with.
# This is a list rather than a set on purpose: iteration order of a set of
# strings changes between kernel restarts (string hash randomisation), which
# would make the prompt built from these names differ from run to run.
class_names = [
    "carton",
    "foam",
    "food",
    "general",
    "glass",
    "metal",
    "paper",
    "plastic",
    "special",
]

In [5]:
# Text prompt for the segmentation adapter: "segment <class> ; <class> ; ...".
PROMPT = "segment " + " ; ".join(class_names)

# Request PyTorch tensors explicitly, the same way the JSON-extraction cell does.
inputs = processor(images=image, text=PROMPT, return_tensors="pt")

model.set_adapter("segmentation")
# NOTE: do_sample=True means the generated text differs from run to run.
generated_ids = model.generate(**inputs, max_new_tokens=256, do_sample=True)

# The generated sequence starts with the prompt (image tokens followed by the
# text prompt). Chop it off using the real length of the encoded input rather
# than re-deriving it from an image-token count, a separate tokenisation of
# PROMPT and a hard-coded offset, which silently breaks if the processor
# changes how it assembles the prompt.
num_prompt_tokens = inputs["input_ids"].shape[-1]
generated_text = processor.batch_decode(
    generated_ids[:, num_prompt_tokens:],
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
)[0]

# PIL reports size as (width, height); both are handed to the decoder,
# presumably so boxes and masks come back in image pixel coordinates.
w, h = image.size

generated_segmentation = segmentation_tokenizer.decode(generated_text, w, h)

In [6]:
# Inspect the decoded result: a list with one dict per detected object, holding
# 'content', 'xyxy', 'mask' and 'name'. The repr includes every mask array.
generated_segmentation

[{'content': '<loc0207><loc0536><loc0792><loc0898> <seg015><seg066><seg066><seg088><seg022><seg091><seg044><seg022><seg104><seg048><seg078><seg026><seg072><seg095><seg075><seg026> paper ; ',
  'xyxy': (1708, 495, 2862, 1893),
  'mask': array([[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]]),
  'name': 'paper '},
 {'content': '<loc0210><loc0364><loc0706><loc0529> <seg055><seg011><seg059><seg048><seg030><seg012><seg119><seg082><seg026><seg042><seg061><seg030><seg007><seg075><seg068><seg030> special ; ',
  'xyxy': (1160, 502, 1686, 1688),
  'mask': array([[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]]

In [7]:
# Palette for mask outlines. The colour names in the inline comments read the
# tuples as (R, G, B).
# NOTE(review): two values appear twice - (0, 255, 255) as Cyan/Aqua and
# (255, 0, 255) as Magenta/Fuchsia - so the 21 entries hold 19 distinct colours.
COLORS = [
    (0, 255, 255),  # Cyan
    (255, 128, 128),  # Salmon
    (255, 0, 255),  # Magenta
    (255, 128, 0),  # Orange
    (128, 255, 0),  # Lime
    (0, 255, 128),  # Spring Green
    (255, 0, 128),  # Rose
    (128, 0, 255),  # Violet
    (0, 128, 255),  # Azure
    (128, 255, 128),  # Chartreuse
    (128, 128, 255),  # Cornflower Blue
    (255, 255, 128),  # Light Yellow
    (255, 128, 255),  # Orchid
    (128, 255, 255),  # Light Cyan
    (255, 165, 0),  # Also orange
    (0, 255, 255),  # Aqua
    (255, 0, 255),  # Fuchsia
    (128, 0, 0),  # Maroon
    (128, 128, 0),  # Olive
    (0, 128, 128),  # Teal
    (128, 0, 128),  # Purple
]

# Endless iterator over the palette. It is module-level state: every next()
# call advances it, so the colour a consumer receives depends on how many
# colours were taken before (including in earlier runs of other cells).
colors = itertools.cycle(COLORS)

In [8]:
def draw_contours(image, masks):
    """Outline every mask on ``image`` and label it with its index.

    The colour of mask ``j`` is ``COLORS[j % len(COLORS)]``, so calling the
    function again yields the same picture instead of continuing a shared
    colour iterator. Masks without any contour are skipped entirely.

    Args:
        image: Array to draw on, in a form accepted by the OpenCV drawing
            functions. It is modified in place.
        masks: Iterable of single-channel masks; non-zero pixels are
            foreground.

    Returns:
        The annotated image.
    """
    font_scale, thickness, pad = 1, 2, 5

    for j, mask in enumerate(masks):
        contours, _ = cv2.findContours(
            np.asarray(mask).astype(np.uint8),
            cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE,
        )

        # Nothing to outline or label for an empty mask.
        if len(contours) == 0:
            continue

        color = COLORS[j % len(COLORS)]
        image = cv2.drawContours(image, contours, -1, color, 3)

        # A mask can split into several blobs; anchor the label on the
        # largest one rather than on whichever contour happens to come first.
        largest = max(contours, key=cv2.contourArea)
        M = cv2.moments(largest)

        # Zero area (e.g. a single pixel or a line): no centroid to label.
        if M["m00"] == 0:
            continue

        cX = int(M["m10"] / M["m00"])
        cY = int(M["m01"] / M["m00"])

        # Size the black backdrop from the rendered text so that multi-digit
        # indices stay readable too.
        label = str(j)
        (text_w, text_h), _baseline = cv2.getTextSize(
            label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness
        )
        cv2.rectangle(
            image,
            (cX - pad, cY - text_h - pad),
            (cX + text_w + pad, cY + pad),
            (0, 0, 0),  # Black backdrop
            -1,  # Filled
        )

        # putText anchors the text at its bottom-left corner.
        cv2.putText(
            image,
            label,
            (cX, cY),
            cv2.FONT_HERSHEY_SIMPLEX,
            font_scale,
            (255, 255, 255),  # White text
            thickness,
            cv2.LINE_AA,
        )
    return image

In [9]:
# Sanity check: print the shape of every decoded mask, one per line.
for mask_shape in (seg["mask"].shape for seg in generated_segmentation):
    print(mask_shape)

(2448, 3264)
(2448, 3264)
(2448, 3264)
(2448, 3264)


In [None]:
# Binarise each mask at 0.5 and scale it to 0/255 uint8.
masks = []
for seg in generated_segmentation:
    binary = seg["mask"] > 0.5
    masks.append(binary.astype(np.uint8) * 255)

# np.array(image) makes a fresh array, so the drawing does not touch `image`.
image_with_masks = PIL.Image.fromarray(draw_contours(np.array(image), masks))
image_with_masks

In [13]:
# Local disposal rules, inserted verbatim into the prompt. The file contains
# Korean text, so decode it explicitly as UTF-8 instead of relying on the
# platform's default encoding.
with open("korea_summary.txt", "r", encoding="utf-8") as f:
    rules = f.read()

# One entry per detected segment; item_id is the segment's position in
# generated_segmentation.
items = [{"item_id": i, "class": seg["name"].strip()} for i, seg in enumerate(generated_segmentation)]

# Alternative input source, left disabled: take the image, rules and items
# from the first record of an on-disk dataset.
# dataset = load_from_disk("taco_trash_json")
# image = dataset[0]["image"]
# image

# rules = dataset[0]["prefix"]["rules"]
# items = dataset[0]["prefix"]["items"]

# $items and $rules are filled in with str() of the values passed to substitute().
PREFIX_TEMPLATE = Template("For every object outlined in the image, here're their detected classes: $items. "
                           "For every outlined item, extract JSON with a more accurate label, "
                           "as well as disposal directions based on local rules. "
                           "The local rules are as follows: $rules.")

prompt = PREFIX_TEMPLATE.substitute(items=items, rules=rules)
prompt

'For every object outlined in the image, here\'re their detected classes: [{\'item_id\': 0, \'class\': \'paper\'}, {\'item_id\': 1, \'class\': \'special\'}, {\'item_id\': 2, \'class\': \'general\'}, {\'item_id\': 3, \'class\': \'special\'}]. For every outlined item, extract JSON with a more accurate label, as well as disposal directions based on local rules. The local rules are as follows: In this community, garbage sorting rules are as follows:\n\n### General Waste Bag (일반 쓰레기 봉투)\n- **Color:** Typically white or green (varies by district).\n- **Contents:** Everything not classified as recyclable or food waste. Examples include used tissues, used toilet paper (when not flushed), sanitary pads, old shoes, and clothes.\n\n### Food Waste Bag (음식물 쓰레기 봉투)\n- **Contents:** Edible waste such as fruit peels, vegetable peels, uneaten meat, and raw eggs (without the shell).\n- **Exceptions:** Egg shells, crustacean shells, clam shells, onion and garlic skin, animal bones, and tea bags/leaves a

In [14]:
# Switch the shared base model over to the JSON-extraction adapter.
model.set_adapter("json")

# The model gets the contour-annotated image together with the items/rules prompt.
inputs = processor(images=image_with_masks, text=prompt, return_tensors="pt")
generated_ids = model.generate(**inputs, max_new_tokens=1024, do_sample=True)

# Decode the whole of generated_ids: nothing is sliced off here, so whatever
# part of the prompt generate() returns is handed to the JSON decoder as well.
# NOTE(review): confirm that json_tokenizer.decode is meant to receive it.
decoded_batch = processor.batch_decode(generated_ids, skip_special_tokens=True)
generated_text = decoded_batch[0]

generated_json = json_tokenizer.decode(generated_text)

In [15]:
generated_json

[{'sorted_items': [{'refined_label': 'Silk',
    'item_id': '0',
    'item_class': 'Food Waste Bag',
    'direction': 'Recyclable Waste (재활용 쓰레기) (No special bags required)'},
   {'refined_label': 'Straws',
    'item_id': '2',
    'item_class': 'Food Waste Bag',
    'direction': 'Recyclable Waste (재활용 쓰레기) (No special bags required)',
    'additional_label': 'Recyclable Waste (재활용 쓰레기) (No special bags required)',
    'label': 'Special Gifts'},
   {'refined_label': 'Silk',
    'item_id': '4',
    'item_class': 'Food Waste Bag',
    'direction': 'Recyclable Waste (재활용 쓰레기) (No special bags required)',
    'additional_label': 'Recyclable Waste (재활용 쓰레기) (No special bags required)',
    'label': 'Special Gifts'},
   {'refined_label': 'Silk',
    'item_id': '5',
    'item_class': 'Food Waste Bag',
    'direction': 'Recyclable Waste (재활용 쓰레기) (No special bags required)'},
   {'refined_label': 'Silk',
    'item_id': '6',
    'item_class': 'Food Waste Bag',
    'direction': 'Recyclable Waste 