In [1]:
import sglang as sgl
from sglang.utils import wait_for_server
from sglang.srt.constrained import build_regex_from_object
from PIL import Image
from enum import Enum
from pydantic import BaseModel
import json

In [2]:
# Launch the VLM server in a separate terminal first, e.g.:
# python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-7b-ov --port=30000 --chat-template=chatml-llava --host=0.0.0.0
# Single source of truth for the endpoint; must match the --port used above.
SERVER_URL = "http://localhost:30000"

# Wait until the server at SERVER_URL is reachable, then make it the default
# backend so `sgl.function` programs run against it without an explicit backend.
wait_for_server(SERVER_URL)
sgl.set_default_backend(sgl.RuntimeEndpoint(SERVER_URL))

In [3]:
# Structured-output schema: the model's answer must be a JSON object matching `object_properties`
# Colors the classifier may report. `object_properties.color` is typed with this
# enum, so the member values below are the allowed values of that field.
# The `str` mixin makes each member a `str`, so it compares equal to its value
# (e.g. color.gray == "gray") and serializes as that plain string.
class color(str, Enum):
    brown = "brown"
    blue = "blue"
    yellow = "yellow"
    gray = "gray"
    green = "green"
    cyan = "cyan"
    red = "red"
    purple = "purple"


# Part types the classifier may choose from. The values (German part names) are
# the exact labels allowed in the `object_type` field of `object_properties`;
# the member names are lowercase Python identifiers for the same labels.
class object_type(str, Enum):
    tellerfeder = "Tellerfeder"
    spannzapfen = "Spannzapfen"
    # Non-ASCII identifiers such as "zlü" are valid in Python 3 (PEP 3131).
    sechskantschraube_m16x70_zlü = "Sechskantschraube_M16x70_ZLÜ"


# Schema of the structured answer: one part type and one color, each restricted
# to the members of the enum with the same name.
# NOTE(review): each field name equals its type's name. This is safe only while
# the fields have no default values: in `x: T = default`, Python assigns the
# default before evaluating the annotation, so `T` would then resolve to the
# default instead of the enum class. Keep fields default-free or rename the enums.
class object_properties(BaseModel):
    object_type: object_type
    color: color

In [21]:
# Regex that constrains decoding to a JSON object matching `object_properties`
# (requires pydantic >= 2.0). Built once here instead of on every program run.
OBJECT_PROPERTIES_REGEX = build_regex_from_object(object_properties)

# JSON schema shown to the model in the prompt, so it is told which fields and
# values exist instead of being steered only token-by-token by the regex.
OBJECT_PROPERTIES_SCHEMA = json.dumps(
    object_properties.model_json_schema(), ensure_ascii=False
)


@sgl.function
def object_classification(s, image: str | Image.Image):
    """SGLang program: classify the part in `image` as `object_properties` JSON.

    Args:
        s: Program state injected by ``@sgl.function``.
        image: Image file path or PIL image, passed to ``sgl.image``.

    The constrained answer is stored in the state variable ``"object description"``.
    """
    s += sgl.user(
        sgl.image(image)
        + "Klassifiziere das Bauteil im Bild gemäß dem folgenden JSON-Schema:\n"
        + OBJECT_PROPERTIES_SCHEMA
        + "\n"
    )
    # Generate inside the assistant turn: appending `gen` straight after the user
    # turn omits the assistant role header that the chat template expects.
    s += sgl.assistant(
        sgl.gen(
            "object description",
            max_tokens=128,
            temperature=0.2,
            regex=OBJECT_PROPERTIES_REGEX,
        )
    )


def object_classification_gen(image: str | Image.Image) -> dict:
    """Run `object_classification` on `image` and return the parsed JSON answer.

    Args:
        image: Image file path or PIL image.

    Returns:
        A dict with the ``object_properties`` fields, e.g.
        ``{"object_type": "Spannzapfen", "color": "gray"}``.

    Raises:
        json.JSONDecodeError: if the generated text is not valid JSON
            (e.g. it was cut off by ``max_tokens``).
    """
    state = object_classification.run(image)
    # Read only the generated variable. `state.text()` is the whole transcript
    # (prompt with the schema's braces, role markers), so scanning it for "{" is
    # unreliable.
    return json.loads(state["object description"])

In [None]:
# NOTE(review): the model misclassifies this sample: the filename indicates a
# "Sechskantschraube_M16x70_ZLÜ", but the recorded output reports a different type.
TEST_IMAGE_PATH = "/app/vlm_test_imgs/sechskantschraube_M16x70_ZLÜ_001.jpg"

# `with` closes the file handle that PIL keeps open for lazily loaded images.
with Image.open(TEST_IMAGE_PATH) as img:
    structured_output = object_classification_gen(img)
structured_output

{'object_type': 'Spannzapfen', 'color': 'gray'}
