# In [None]:
import os
import json
import torch
import pandas as pd
from pathlib import Path
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig

# --------------------------
# Config
# --------------------------
# Hugging Face model id loaded by load_model_and_processor().
MODEL_NAME = "Qwen/Qwen2.5-VL-7B-Instruct"
# Folder whose image files (.png/.jpg/.jpeg/.webp) are processed by main().
IMAGE_DIR = "inference_images"
# Per-image outputs: "<stem>_hierarchy.json" when the reply parses, else "<stem>_raw.txt".
OUTPUT_JSON_DIR = "outputs_hierarchy_json_25"
# Summary table with one row per processed image.
OUTPUT_CSV = "qwen25vl_component_hierarchy.csv"
# Generation cap per image; a reply cut off at this limit will usually be invalid JSON.
MAX_NEW_TOKENS = 2500

# Instruction sent to the model together with each image.
# NOTE: the final sentence previously contained "‚Äî", a UTF-8 em dash that had
# been mis-decoded; that garbage was being sent to the model verbatim.
PROMPT_HIERARCHY = (
    "You are an expert UI layout analyzer. "
    "Analyze this wireframe and output all visible components in a hierarchical JSON structure.\n\n"
    "Each component should be represented as an object with:\n"
    "- 'type': the component name (e.g., header, nav, hero, button, image, card, footer)\n"
    "- 'attributes': a dictionary with attributes like color, position, size, alignment, and text content if visible\n"
    "- 'children': a list of nested components inside it\n\n"
    "The root node should represent the full page as 'page'.\n"
    "Follow the visual hierarchy (top to bottom, left to right). Output valid JSON only—no text outside the JSON."
)

# --------------------------
# Load Model
# --------------------------
def load_model_and_processor():
    """Load the Qwen2.5-VL model named by MODEL_NAME and its processor.

    With CUDA available, weights are quantized to 4-bit NF4 through
    bitsandbytes (double quantization, float16 compute) and placed with
    ``device_map="auto"``. Without CUDA, the model is loaded unquantized
    onto the CPU. In both cases ``torch_dtype="auto"`` is passed.

    Returns:
        A ``(model, processor)`` tuple.
    """
    cuda_ok = torch.cuda.is_available()

    # Quantization is only configured for GPU runs; CPU runs load full weights.
    quantization = (
        BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
        if cuda_ok
        else None
    )

    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_NAME,
        quantization_config=quantization,
        torch_dtype="auto",
        device_map="auto" if cuda_ok else "cpu",
        low_cpu_mem_usage=True,
    )
    return model, AutoProcessor.from_pretrained(MODEL_NAME)


# --------------------------
# Inference Function
# --------------------------
def extract_hierarchy(image_path, model, processor):
    """Run the vision-language model on one image and return its text reply.

    Args:
        image_path: Path (str) to a local image file.
        model: The generation model returned by load_model_and_processor().
        processor: The matching AutoProcessor.

    Returns:
        The decoded text of the newly generated tokens only, with surrounding
        whitespace stripped. The prompt asks for JSON, but the reply is not
        validated here.
    """
    # Local import keeps this fix self-contained; transformers is already a dependency.
    from transformers.image_utils import load_image

    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image_path},
            {"type": "text", "text": PROMPT_HIERARCHY},
        ],
    }]

    # tokenize=False only renders the prompt text (with image placeholder
    # tokens); it does not load the image, so pixels are supplied below.
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # The image processor needs decoded image data, not a file path string.
    image = load_image(image_path)
    inputs = processor(text=[text], images=[image], return_tensors="pt").to(model.device)

    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=False)

    # generate() returns prompt + continuation. Keep only the new tokens;
    # otherwise the chat-template prompt is prepended to the answer and the
    # result can never be parsed as JSON.
    prompt_len = inputs["input_ids"].shape[1]
    new_token_ids = generated_ids[:, prompt_len:]

    output_text = processor.batch_decode(
        new_token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return output_text.strip()


# --------------------------
# Main Function
# --------------------------
def _strip_code_fences(text):
    """Return *text* without a surrounding Markdown code fence, if it has one.

    Chat models often wrap JSON as ```json ... ```, which json.loads rejects.
    Text that does not start with a fence is returned stripped but otherwise
    unchanged.
    """
    cleaned = text.strip()
    if not cleaned.startswith("```"):
        return cleaned
    # Drop the opening fence line, which may carry a language tag such as ```json.
    newline_at = cleaned.find("\n")
    body = cleaned[newline_at + 1:] if newline_at != -1 else cleaned[3:]
    body = body.rstrip()
    if body.endswith("```"):
        body = body[:-3]
    return body.strip()


def main():
    """Extract a component hierarchy for every image in IMAGE_DIR and save results.

    Files ending in .png/.jpg/.jpeg/.webp (case-insensitive) are processed in
    sorted filename order. For each image:
      * if the reply parses as JSON (after removing an optional Markdown code
        fence), it is written to OUTPUT_JSON_DIR/<stem>_hierarchy.json with
        status "parsed";
      * otherwise the raw reply is written to OUTPUT_JSON_DIR/<stem>_raw.txt
        with status "invalid_json";
      * any other exception is recorded as status "error" and the loop moves
        on to the next image.
    A CSV with columns image, raw_output, status is written to OUTPUT_CSV.
    """
    image_dir = Path(IMAGE_DIR)
    # List inputs before the (slow) model load so a missing directory fails fast.
    image_names = sorted(
        fname
        for fname in os.listdir(image_dir)
        if fname.lower().endswith((".png", ".jpg", ".jpeg", ".webp"))
    )

    model, processor = load_model_and_processor()
    os.makedirs(OUTPUT_JSON_DIR, exist_ok=True)

    results = []
    for fname in image_names:
        img_path = str(image_dir / fname)
        stem = Path(fname).stem
        print(f"🔍 Extracting hierarchy for {fname}...")

        row = {"image": fname, "raw_output": "", "status": ""}

        try:
            raw_json = extract_hierarchy(img_path, model, processor)
            row["raw_output"] = raw_json

            try:
                parsed_json = json.loads(_strip_code_fences(raw_json))
            except json.JSONDecodeError:
                # Keep the unparsed reply for manual inspection.
                raw_path = os.path.join(OUTPUT_JSON_DIR, f"{stem}_raw.txt")
                with open(raw_path, "w", encoding="utf-8") as f:
                    f.write(raw_json)
                row["status"] = "invalid_json"
            else:
                json_path = os.path.join(OUTPUT_JSON_DIR, f"{stem}_hierarchy.json")
                with open(json_path, "w", encoding="utf-8") as f:
                    # ensure_ascii=False keeps non-ASCII UI text readable in the UTF-8 file.
                    json.dump(parsed_json, f, indent=2, ensure_ascii=False)
                row["status"] = "parsed"

        except Exception as e:
            # Per-image boundary: record the failure and continue with the batch.
            print(f"   Failed on {fname}: {e!r}")
            row["raw_output"] = f"ERROR: {e}"
            row["status"] = "error"

        results.append(row)

    # Explicit columns so the CSV still has a header when no images were found.
    pd.DataFrame(results, columns=["image", "raw_output", "status"]).to_csv(OUTPUT_CSV, index=False)
    print(f"\n✅ Extraction complete! Results saved to {OUTPUT_CSV}\n📂 JSONs in: {OUTPUT_JSON_DIR}")


# Script entry point. In an IPython/Jupyter kernel __name__ is normally
# "__main__" as well, so executing this cell also starts the extraction.
if __name__ == "__main__":
    main()
