<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/Holo_1_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers torch Pillow requests pydantic -q

In [2]:
# Show the attached GPU, its driver/CUDA versions and current memory usage.
!nvidia-smi

Wed Jun  4 12:37:43 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA L4                      Off |   00000000:00:03.0 Off |                    0 |
| N/A   41C    P8             11W /   72W |       0MiB /  23034MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
import json
import requests
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict
import re # Import regex for more robust parsing

# --- 1. Load the Holo-1 model and processor ---
model_name = "Hcompany/Holo1-7B" # Or "Hcompany/Holo1-3B"

# The processor and the model are both fetched from the same checkpoint name.
processor = AutoProcessor.from_pretrained(model_name)

# "auto" is passed for both the dtype and the device map, leaving the concrete
# choices to transformers.
model = AutoModelForImageTextToText.from_pretrained(model_name, torch_dtype="auto", device_map="auto")

# --- 2. Define a helper function for inference ---
def run_inference(messages: list[dict[str, Any]], max_new_tokens: int = 100) -> str:
    """
    Runs inference on the Holo-1 model with a given set of messages.

    Args:
        messages: Chat messages in the format accepted by
            `processor.apply_chat_template`. Images are taken from every
            content item shaped like {"type": "image", "image": <image>}.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        Only the newly generated (assistant) text, with surrounding
        whitespace stripped.
    """
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Gather images from all messages instead of assuming the image is the
    # first content item of the first message.
    images = [
        item["image"]
        for message in messages
        if isinstance(message.get("content"), list)
        for item in message["content"]
        if isinstance(item, dict) and item.get("type") == "image" and "image" in item
    ]

    # An empty list is passed as None so the call is a plain text-only request.
    inputs = processor(text=text, images=images or None, return_tensors="pt").to(model.device)

    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # The generated sequence starts with the prompt tokens. Slicing them off by
    # length isolates the assistant's answer without searching the decoded text
    # for an "assistant\n" marker, which could also occur inside the prompt.
    prompt_length = inputs["input_ids"].shape[1]
    new_token_ids = generated_ids[:, prompt_length:]

    return processor.batch_decode(new_token_ids, skip_special_tokens=True)[0].strip()

# --- 3. Prepare the image and instruction for UI localization ---
image_url = "https://huggingface.co/Hcompany/Holo1-7B/resolve/main/calendar_example.jpg"

# requests has no default timeout, so bound the wait explicitly (seconds), and
# fail fast on an HTTP error instead of handing an error page to PIL.
with requests.get(image_url, stream=True, timeout=30) as response:
    response.raise_for_status()
    # Let urllib3 undo any Content-Encoding (e.g. gzip) before PIL reads the bytes.
    response.raw.decode_content = True
    image = Image.open(response.raw)
    # Image.open is lazy: read the pixel data now, while the connection is still open.
    image.load()

instruction = "Click on the '3' on the calendar."
guidelines = "Localize an element on the GUI image according to my instructions and output a click position as Click(x, y) with x num pixels from the left edge and y num pixels from the top edge."

# Single user turn: the screenshot first, then the guidelines and the instruction.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": f"{guidelines}\n{instruction}"},
        ],
    }
]

In [4]:
# --- 4. Run inference and parse the output ---
# Send the prepared `messages` to the model; the returned string is parsed below
# and is expected to look like "Click(x, y)" per the guidelines in the prompt.
coordinates_str = run_inference(messages)
print(f"Holo-1 output (extracted): {coordinates_str}")

# Optional: Parse the structured output using Pydantic (or regex/string parsing)
class ClickAction(BaseModel):
    # Model-level pydantic settings: extra fields are forbidden; the remaining
    # options concern JSON-schema generation and attribute docstrings.
    model_config = ConfigDict(
        extra="forbid",
        json_schema_serialization_defaults_required=True,
        json_schema_mode_override="serialization",
        use_attribute_docstrings=True,
    )

    # Kind of action; "click" is the only permitted value and also the default.
    action: Literal["click"] = "click"

    # Click position: x is measured from the left edge, y from the top edge
    # (see the guidelines text sent to the model).
    x: int
    y: int

try:
    # re.search (not re.match) so the action is still found when the model
    # puts other text before it; whitespace inside the parentheses is tolerated.
    match = re.search(r"Click\(\s*(\d+)\s*,\s*(\d+)\s*\)", coordinates_str)
    if match:
        x = int(match.group(1))
        y = int(match.group(2))
        click_action = ClickAction(action="click", x=x, y=y)
        print(f"Parsed Click Action: x={click_action.x}, y={click_action.y}")
    else:
        print("Output not in expected 'Click(x, y)' format.")

# TypeError: coordinates_str is not a string. ValueError: int() conversion or
# pydantic validation failure (pydantic's ValidationError subclasses ValueError).
except (TypeError, ValueError) as e:
    print(f"Error parsing output: {e}")

Holo-1 output (extracted): Click(426, 278)
Parsed Click Action: x=426, y=278
