# In [None]:
import os
import re
from typing import Optional

import gradio as gr
import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from transformers.image_utils import load_image
from qwen_vl_utils import process_vision_info

# ---------- CONFIG -----------------------------------------------------------
# Preferred torch device string. Not referenced elsewhere in the visible code:
# _lazy_load_model() places the model with device_map="auto" instead.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Checkpoint id passed to both the model and the processor loaders.
# NOTE(review): _lazy_load_model() always instantiates
# Qwen2_5_VLForConditionalGeneration, so the Idefics3 alternative presumably
# needs a different model class — confirm before switching.
MODEL_NAME = "Qwen/Qwen2.5-VL-7B-Instruct" # or "HuggingFaceM4/Idefics3-8B-Llama3"
# Directory scanned for the sample images offered in the UI.
INPUT_DIR = "output_images"
# Adapter location handed to peft's PeftModel.from_pretrained; None or a blank
# string skips adapter loading.
ADAPTER_PATH: Optional[str] = 'qwen' # or idefics
# -----------------------------------------------------------------------------

# Lazily populated singletons: both stay None until _lazy_load_model() runs.
model = None  # will be lazily initialised once
processor = None

# ---------- PREDEFINED QUESTIONS ------------------------------------------------
# Suggested questions per task. Key order is the order of the task names in
# TASKS; list order is the order the questions are offered in.
PREDEF_QS = {
    "Object Detection": [
        "How is this part identified despite the strong metallic background?",
        "Which details enable the model to identify this part?",
        "Do the other metallic parts in the scene affect the prediction?",
        "Which grooves, ridges, or holes does the model highlight most prominently?",
    ],
    "Classification": [
        "Which texture cues guide the model's prediction?",
        "What part shape influences the class label most?",
        "Does background affect the prediction?",
        "How robust is the classification to lighting?",
    ],
}
# Task names, taken from the question table so the two can never drift apart.
TASKS = list(PREDEF_QS)
# Extra dropdown entry that switches the UI to the free-text question box.
CUSTOM_OPT = "Custom question..."

# ---------- MODEL LOADING HELPERS ------------------------------------------------
def _lazy_load_model():
    """Return the shared ``(model, processor)`` pair, loading it on first use.

    The base checkpoint ``MODEL_NAME`` is loaded with
    ``Qwen2_5_VLForConditionalGeneration``. When ``ADAPTER_PATH`` is a
    non-blank string the base model is wrapped with
    ``peft.PeftModel.from_pretrained``. The processor is loaded from
    ``MODEL_NAME`` as well. Both objects are cached in the module globals
    ``model`` and ``processor``.

    Returns:
        tuple: ``(model, processor)``.

    Raises:
        Any exception raised by the underlying loaders. Nothing is cached in
        that case, so the next call starts the load again from scratch.
    """
    global model, processor
    if model is None:
        # Load into locals first: publishing the global only after every step
        # succeeded avoids caching a base model without its adapter (or with
        # ``processor`` still None) when a later step fails.
        loaded_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            MODEL_NAME, torch_dtype="auto", device_map="auto"
        )
        if ADAPTER_PATH and ADAPTER_PATH.strip():
            # Imported here so peft is only required when an adapter is configured.
            from peft import PeftModel
            loaded_model = PeftModel.from_pretrained(loaded_model, ADAPTER_PATH)
        loaded_processor = AutoProcessor.from_pretrained(MODEL_NAME)
        model, processor = loaded_model, loaded_processor
    return model, processor

def _build_prompt(pil_img, img_name, class_id, question):
    """Build the two-message chat payload for one saliency-map question.

    The first message carries the instruction text (with ``img_name`` and
    ``class_id`` interpolated) and the second carries the image followed by
    the question. Both messages use the ``"user"`` role. Only the outer
    whitespace of the instruction text is stripped, so the indentation of its
    inner lines is part of the prompt.

    Returns:
        list[dict]: ``[instruction_message, question_message]``.
    """
    instructions = f"""
    Context:
    We are analyzing a single saliency map visualization for the image named '{img_name}', focusing on the component called '{class_id}'. 
    This image helps us understand how the model perceives and identifies the object. The visualization includes:
    • Bounding Box: It encases the object, showing its location and providing insight into the model spatial awareness
    • Saliency Map: It highlights the most influential pixels in the model decision-making, with brighter areas indicating higher impact on detecting '{class_id}'.

    Your Task:
    Analyze step by step how the saliency map visualization contributes to understanding the detection of '{class_id}' and provide a detailed answer to the question below:
    Instructions for the Response:
    Provide a clear and simple explanation suitable for non-technical workers.
    Use straightforward language to describe how the saliency map (with its bounding box) shows the detection of '{class_id}'.
    Limit your final response to 10 to 20 words.
    Avoid technical jargon and complex terminology.
    Example: “The model focuses on the straight edges, sharp corners, and the unique rectangular shape of the object.”
    """
    instruction_message = {
        "role": "user",
        "content": [{"type": "text", "text": instructions.strip()}],
    }
    question_message = {
        "role": "user",
        "content": [
            {"type": "image", "image": pil_img},
            {"type": "text", "text": question},
        ],
    }
    return [instruction_message, question_message]


@torch.inference_mode()
def ask_model_eval(pil_img, img_name, class_id, question, *, max_new_tokens=300):
    """Ask the vision-language model one question about one image.

    Args:
        pil_img: Image placed in the chat message (forwarded to
            ``_build_prompt``).
        img_name: Image name interpolated into the instruction text.
        class_id: Component name interpolated into the instruction text.
        question: Question text sent alongside the image.
        max_new_tokens: Upper bound on generated tokens, forwarded to
            ``model.generate``. Defaults to 300, the previously hard-coded value.

    Returns:
        str: The decoded answer for the single prompt, with special tokens
        removed and surrounding whitespace stripped.
    """
    model, processor = _lazy_load_model()
    msgs = _build_prompt(pil_img, img_name, class_id, question)
    text = processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(msgs)
    inputs = processor(
        text=[text], images=image_inputs, videos=video_inputs,
        padding=True, return_tensors="pt"
    ).to(model.device)
    gen_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Each generated sequence starts with its prompt tokens; slice them off so
    # only the newly generated answer is decoded.
    trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, gen_ids)]
    decoded = processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return decoded[0].strip()

# ---------- UI HELPERS ------------------------------------------------------------
def _sample_images(input_dir: Optional[str] = None) -> list:
    """Return the image filenames in a directory, ordered by embedded numbers.

    Args:
        input_dir: Directory to scan. Defaults to the module-level
            ``INPUT_DIR`` when omitted.

    Returns:
        Filenames ending in ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive).
        Names containing digits come first, ordered by the sequence of
        integers found in them (``"9_rod.png"`` before ``"11_box.png"``).
        Names without digits follow in alphabetical order. Ties are broken by
        the lower-cased name. An empty list is returned when the directory
        does not exist.
    """
    directory = INPUT_DIR if input_dir is None else input_dir
    try:
        entries = os.listdir(directory)
    except FileNotFoundError:
        # A missing sample folder yields an empty choice list instead of an
        # exception at the point where the UI is being built.
        return []

    def sort_key(fn: str):
        numbers = [int(n) for n in re.findall(r"\d+", fn)]
        # The leading flag puts digit-free names in their own group, so a list
        # of ints is never compared with a string (TypeError on Python 3).
        return (not numbers, numbers, fn.lower())

    images = [f for f in entries if f.lower().endswith((".png", ".jpg", ".jpeg"))]
    return sorted(images, key=sort_key)

@torch.inference_mode()
def _preview_image(filename):
    """Load the chosen sample image from INPUT_DIR for display.

    Returns ``None`` when ``filename`` is empty or ``None``.
    """
    return load_image(os.path.join(INPUT_DIR, filename)) if filename else None

@torch.inference_mode()
def _infer(task, mode, filename, webcam_img, q_choice, custom_q):
    """Validate the form inputs, query the model and format the UI outputs.

    Args:
        task: Selected task name; forwarded to the model prompt.
        mode: ``"Select from list"`` or ``"Live Inference"``.
        filename: Sample image name inside ``INPUT_DIR`` (list mode).
        webcam_img: Captured or uploaded image (live mode).
        q_choice: Selected dropdown entry; ``CUSTOM_OPT`` means the free-text
            question should be used instead.
        custom_q: Free-text question; may be ``None`` or blank.

    Returns:
        tuple: ``(image, answer_markdown, message_markdown)``. On a validation
        failure the image is ``None``, the answer is empty and the message
        explains what is missing. On success the message is empty.
    """
    # determine final question; guard against a None textbox value, which
    # would otherwise raise AttributeError on .strip()
    if q_choice == CUSTOM_OPT:
        question = (custom_q or "").strip()
    else:
        question = q_choice
    from_list = mode == "Select from list"
    # validate inputs
    if from_list and not filename:
        return None, "", "**Please select an image.**"
    if mode == "Live Inference" and webcam_img is None:
        return None, "", "**Please capture or upload an image.**"
    if not question:
        return None, "", "**Please enter a question.**"
    # load image
    if from_list:
        pil_img = load_image(os.path.join(INPUT_DIR, filename))
        img_name = filename
    else:
        pil_img = webcam_img
        img_name = "live_image"
    # infer
    # NOTE(review): the task name is passed as class_id, so the prompt refers
    # to the component called e.g. 'Object Detection' — confirm this is intended.
    answer = ask_model_eval(pil_img, img_name, task, question)
    return pil_img, f"**Answer:** {answer}", ""

# ---------- BUILD DEMO -----------------------------------------------------------
# Two-column layout: the inputs and the submit button on the left, the selected
# image and the model's answer on the right.
with gr.Blocks(title="Factory Visual Inspection Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Visual Inspection Assistant")
    gr.Markdown("*Choose task, image, and question before submitting.*")

    with gr.Row():
        with gr.Column(scale=1):
            # Task and image-source selectors, both preset to their first option.
            task_dd = gr.Dropdown(TASKS, label="Task", value=TASKS[0])
            mode_rg = gr.Radio(["Select from list", "Live Inference"], label="Image Source", value="Select from list")

            # Only one of these two image inputs is visible at a time; the
            # mode_rg.change handler below toggles them.
            file_dd = gr.Dropdown(_sample_images(), label="Sample Images")
            webcam_in = gr.Image(label="Live / Upload Image", type="pil", visible=False)

            # question selector with custom option; the textbox stays hidden
            # until CUSTOM_OPT is chosen (see q_dd.change below)
            q_dd = gr.Dropdown(PREDEF_QS[TASKS[0]] + [CUSTOM_OPT], label="Question")
            q_txt = gr.Textbox(label="Custom Question", placeholder="Type your question here...", visible=False)

            run_btn = gr.Button("Submit")
            # Receives the validation message returned by _infer.
            msg_out = gr.Markdown("")

        with gr.Column(scale=1):
            img_out = gr.Image(label="Selected Image", height=300)
            answer_md = gr.Markdown("")

    # ---------- EVENTS ---------------------------------------------------------
    # Changing the task swaps in that task's questions, selects the first one,
    # and hides and clears the custom-question box.
    task_dd.change(
        fn=lambda t: (gr.update(choices=PREDEF_QS[t] + [CUSTOM_OPT], value=PREDEF_QS[t][0]), gr.update(visible=False, value="")),
        inputs=task_dd,
        outputs=[q_dd, q_txt]
    )
    # Show the sample dropdown in list mode and the image input in live mode.
    mode_rg.change(
        fn=lambda m: (gr.update(visible=m == "Select from list"), gr.update(visible=m == "Live Inference")),
        inputs=mode_rg,
        outputs=[file_dd, webcam_in]
    )
    # Preview the chosen sample image as soon as it is selected.
    file_dd.change(fn=_preview_image, inputs=file_dd, outputs=img_out)
    # Reveal the free-text box only while the custom option is selected.
    q_dd.change(
        fn=lambda q: gr.update(visible=q == CUSTOM_OPT),
        inputs=q_dd,
        outputs=q_txt
    )

    # Submit: _infer returns the values for (img_out, answer_md, msg_out).
    run_btn.click(
        fn=_infer,
        inputs=[task_dd, mode_rg, file_dd, webcam_in, q_dd, q_txt],
        outputs=[img_out, answer_md, msg_out]
    )
# NOTE(review): "11_box" and "9_rod" look like example class or image ids left
# over from testing — confirm what they refer to or remove.
if __name__ == "__main__":
    demo.launch(inline=True, share=False)
