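# In [None]:  (Jupyter cell marker left over from the notebook export; kept as a comment so the file parses as Python)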
import torch
import os
import gc
import cv2
import numpy as np
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from peft import PeftModel
import sys
import re

# Assume utils is in current directory
sys.path.append('./') 
from utils.data_utils import vis_FLlabels

# ========== 1. Load Model ==========
def load_model(base_model_path: str, peft_model_path: str, device: torch.device):
    """Load the base causal-LM, attach the LoRA adapter, and prepare it for inference.

    Args:
        base_model_path: Path (or hub id) of the base model; also used to load
            the matching processor.
        peft_model_path: Path of the LoRA/PEFT adapter weights to attach.
        device: Device the combined model is moved to.

    Returns:
        A ``(model, processor)`` tuple. The model is in eval mode, loaded in
        float16, and already placed on ``device``.
    """
    # Collect unreachable Python objects first so that any tensors they still
    # hold are freed; only then can empty_cache() hand that memory back.
    gc.collect()
    torch.cuda.empty_cache()

    print("üöÄ Loading base model...")
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        attn_implementation="eager",
    )

    print("üîÑ Loading LoRA weights...")
    model = PeftModel.from_pretrained(base_model, peft_model_path)
    processor = AutoProcessor.from_pretrained(base_model_path, trust_remote_code=True)

    model.to(device)
    model.eval()
    return model, processor

# ========== 2. Inference Function ==========
def _build_prompt(task, text_input=None, bbox_str=None):
    """Return the prompt string: task token plus optional region or text suffix."""
    if bbox_str:
        # Region task (Task + BBox): drop anything before the first '<' so
        # only the location tokens are appended to the task token.
        loc_start = bbox_str.find("<")
        clean_bbox = bbox_str[loc_start:] if loc_start != -1 else bbox_str
        return task + clean_bbox
    if text_input:
        # Text grounding task (Task + Text)
        return task + text_input
    # Pure task: the task token alone is the prompt
    return task


def infer_single_image(model, processor, image_path, task, text_input=None, bbox_str=None, device=None):
    """Run one generation pass on a single image and return the decoded text.

    Args:
        model: Model exposing ``generate`` (and ``parameters`` when ``device``
            is omitted).
        processor: Processor that turns text + image into model inputs and
            decodes generated ids.
        image_path: Path of the image to open; it is converted to RGB.
        task: Task token that starts the prompt, e.g. ``"<KEYPOINT>"``.
        text_input: Text appended to the task for grounding tasks, e.g. "pig".
            Ignored when ``bbox_str`` is given.
        bbox_str: Region string appended to the task for region-input tasks
            such as REGION_TO_SEGMENTATION. Takes precedence over
            ``text_input``.
        device: Device the inputs are moved to. When ``None``, the device of
            the model's first parameter is used so inputs and weights match.

    Returns:
        The decoded generation with special tokens kept, except that ``<s>``
        and ``</s>`` are stripped.
    """
    if device is None:
        # Without this, inputs would stay on CPU while the model may be on GPU.
        device = next(model.parameters()).device

    # Image.open is lazy and keeps the file open; convert() loads the pixels
    # into a new image, so the file handle can be closed right away.
    with Image.open(image_path) as img_file:
        image = img_file.convert('RGB')

    question = _build_prompt(task, text_input=text_input, bbox_str=bbox_str)
    print(f"  üìù Prompt: {question[:100]}...")  # Print first 100 characters

    inputs = processor(text=question, images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.autocast(device_type="cuda", dtype=torch.float16), torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            num_beams=3,
            do_sample=False,
        )

    # Special tokens are kept on purpose (the callers parse <loc_*> tokens out
    # of the result); only the sequence start/end markers are removed.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    return generated_text.replace('<s>', '').replace('</s>', '')

# ========== 3. Main Program ==========
# One bounding box is written as four consecutive <loc_N> tokens.
_BOX_PATTERN = re.compile(r'(<loc_\d+>){4}')


def _detect_boxes(model, processor, img_path, device, phrase="pig", show_raw=False):
    """Ground ``phrase`` with the adapter disabled and return the box strings found."""
    # Disable LoRA so the base model performs the detection; the adapter is
    # re-enabled automatically when the context exits.
    with model.disable_adapter():
        od_result = infer_single_image(
            model, processor, img_path,
            task="<CAPTION_TO_PHRASE_GROUNDING>",
            text_input=phrase,
            device=device,
        )
    if show_raw:
        print(f"üì¶ Detection result: {od_result}")
    return [m.group(0) for m in _BOX_PATTERN.finditer(od_result)]


def main():
    """Run the configured task on one image, then visualize and save the result.

    Paths, device and task are hard-coded in the configuration section below.
    ``<REGION_TO_SEGMENTATION>`` and ``<KEYPOINT>`` first detect "pig" boxes
    with the adapter disabled and then run the task once per box; ``<OD>`` and
    ``<POINT>`` run directly on the whole image. The rendered image is written
    to ``result_<task>.png`` and, when running under IPython/Jupyter, also
    displayed inline.
    """
    # --- Configuration ---
    target_device = 'cuda:1'
    base_path = '/sdb1_hdisk/pub_data/MODELS/Florence-2-large-ft/'
    peft_path = '/sdb1_hdisk/pub_data/chenhong/Florence2/results/KEYPOINT_ONLY/epoch_299/'
    img_path = '/sdb1_hdisk/pub_data/DATAS/BamaPig2D/images/000000.png'

    try:
        device = torch.device(target_device)
        # Querying the name fails when the requested GPU does not exist.
        torch.cuda.get_device_name(device)
    except Exception as exc:
        # Not a bare except: Ctrl-C / SystemExit must still propagate.
        print(f"Requested device {target_device} unavailable ({exc}); falling back to cuda:0")
        device = torch.device('cuda:0')
    print(f"üéØ Using device: {device}")

    # Load model
    model, processor = load_model(base_path, peft_path, device)

    # ==========================================
    # Modify task type here
    # ==========================================
    current_task = "<KEYPOINT>"

    # Arguments shared by every visualization call; each branch adds the one
    # task-specific label keyword.
    kwargs = {
        'img': img_path,
        'resize': 0.5,  # Resize for easier viewing
        'show': False,
    }
    # Tasks that run in a single pass, mapped to their vis_FLlabels keyword.
    direct_task_kwarg = {"<OD>": "FLbbox", "<POINT>": "FLpoint"}

    # --- Branch 1: Tasks requiring detection boxes first (e.g., SEGMENTATION) ---
    if current_task == "<REGION_TO_SEGMENTATION>":
        print("\n" + "="*60)
        print("üîç [Step 1] Using base model to detect all 'pig' bounding boxes...")
        print("="*60)

        all_boxes = _detect_boxes(model, processor, img_path, device, show_raw=True)
        print(f"‚úÖ Detected {len(all_boxes)} pigs")

        if not all_boxes:
            print("‚ùå No pigs detected, exiting.")
            return

        print("\n" + "="*60)
        print("üîç [Step 2] Using LoRA model to segment each pig...")
        print("="*60)

        seg_results = []
        for idx, box in enumerate(all_boxes, start=1):
            print(f"\n  üëâ [{idx}/{len(all_boxes)}] Processing box: {box}")

            # Segment current box
            seg_res = infer_single_image(
                model, processor, img_path,
                task="<REGION_TO_SEGMENTATION>",
                bbox_str=box,
                device=device,
            )

            print(f"  ‚úÖ Segmentation result length: {len(seg_res)} characters")
            print(f"  üìÑ First 100 characters: {seg_res[:100]}...")
            seg_results.append(seg_res)

        # Separate different pigs with <sep>
        final_seg_str = "<sep>".join(seg_results)
        print(f"\n‚úÖ All segmentation completed! Total length: {len(final_seg_str)} characters")

        # Note: segmentation results use the FLlabel parameter
        kwargs['FLlabel'] = final_seg_str

    # --- Branch 2: Direct tasks (OD, POINT) ---
    elif current_task in direct_task_kwarg:
        print(f"\nüîç Executing task: {current_task}")
        res = infer_single_image(model, processor, img_path, current_task, device=device)
        print(f"\n‚ú® Raw output: {res}")

        kwargs[direct_task_kwarg[current_task]] = res

    # --- Branch 3: Keypoints, which also need detection boxes first ---
    elif current_task == "<KEYPOINT>":
        print(f"\nüîç Executing task: {current_task}")

        print("\nüîç [Step 1] Detecting all pigs...")
        all_boxes = _detect_boxes(model, processor, img_path, device)
        print(f"‚úÖ Detected {len(all_boxes)} pigs")

        if not all_boxes:
            print("‚ùå No pigs detected")
            return

        print("\nüîç [Step 2] Detecting keypoints...")
        kp_results = []
        for idx, box in enumerate(all_boxes, start=1):
            print(f"  üëâ [{idx}/{len(all_boxes)}] Processing box: {box}")
            kp_results.append(
                infer_single_image(
                    model, processor, img_path,
                    task="<KEYPOINT>",
                    bbox_str=box,
                    device=device,
                )
            )

        # Per-box keypoint strings are concatenated without a separator.
        kwargs['FLkeypoint'] = "".join(kp_results)

    else:
        print(f"‚ùå Unsupported task type: {current_task}")
        return

    # =====================================================
    # Unified Visualization
    # =====================================================
    print("\n" + "="*60)
    print("üé® Calling vis_FLlabels for visualization...")
    print("="*60)

    try:
        result_bgr = vis_FLlabels(**kwargs)

        if result_bgr is not None:
            result_rgb = cv2.cvtColor(result_bgr, cv2.COLOR_BGR2RGB)
            result_pil = Image.fromarray(result_rgb)

            # Save before displaying so a display problem cannot lose the result.
            save_name = f"result_{current_task.strip('<>').lower()}.png"
            result_pil.save(save_name)
            print(f"üíæ Result saved: {save_name}")

            # `display` is injected by IPython/Jupyter and does not exist when
            # this file runs as a plain script; skip inline display in that case.
            try:
                show_inline = display
            except NameError:
                show_inline = None
            if show_inline is not None:
                show_inline(result_pil)
        else:
            print("‚ùå Visualization result is empty (None)")

    except Exception as e:
        print(f"‚ùå Error during visualization: {e}")
        import traceback
        traceback.print_exc()

# Run the demo only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
