In [1]:
import torch
from transformers import AutoModel, AutoTokenizer, AutoConfig
from PIL import Image
import requests
import math
import os

# --- 제공해주신 전처리 함수들 (수정 없이 사용) ---
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    """Build the per-tile preprocessing pipeline.

    The returned torchvision ``Compose`` performs, in order:
    conversion to RGB (only when the image mode is not already RGB),
    a bicubic resize to ``input_size`` x ``input_size``, conversion to a
    tensor, and channel-wise normalization with ``IMAGENET_MEAN`` /
    ``IMAGENET_STD``.

    Args:
        input_size: Edge length of the square output, in pixels.

    Returns:
        A ``T.Compose`` callable that maps an image to a normalized tensor.
    """
    def to_rgb(img):
        # Leave images that are already RGB untouched.
        return img if img.mode == 'RGB' else img.convert('RGB')

    steps = [
        T.Lambda(to_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the candidate ratio whose quotient is closest to ``aspect_ratio``.

    Candidates are scanned in the order given. A strictly closer candidate
    always wins. A candidate that ties exactly with the current best replaces
    it only when ``width * height`` exceeds half the area of the candidate's
    tile grid (``image_size**2 * ratio[0] * ratio[1]``).

    Args:
        aspect_ratio: Aspect ratio to match.
        target_ratios: Iterable of 2-item candidates; each is compared via
            ``ratio[0] / ratio[1]``.
        width: Source width, used only for the tie-break area.
        height: Source height, used only for the tie-break area.
        image_size: Edge length of one tile, used only for the tie-break.

    Returns:
        The selected candidate, or ``(1, 1)`` when no candidate is supplied.
    """
    source_area = width * height
    winner = (1, 1)
    smallest_gap = float('inf')
    for candidate in target_ratios:
        numer, denom = candidate[0], candidate[1]
        gap = abs(aspect_ratio - numer / denom)
        if gap < smallest_gap:
            smallest_gap = gap
            winner = candidate
        elif gap == smallest_gap and source_area > 0.5 * image_size * image_size * numer * denom:
            # Exact tie: switch only if the source is large enough to
            # justify this candidate's grid.
            winner = candidate
    return winner

def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
    """Resize an image onto a tile grid and cut it into square tiles.

    Every ``(i, j)`` grid with ``min_num <= i * j <= max_num`` is considered;
    the grid is chosen by ``find_closest_aspect_ratio``. The image is resized
    to ``(image_size * i, image_size * j)`` and cropped row by row into
    ``i * j`` tiles of ``image_size`` x ``image_size``.

    Args:
        image: Object exposing ``size``, ``resize`` and ``crop``
            (a PIL image in this file).
        min_num: Minimum number of tiles in a candidate grid.
        max_num: Maximum number of tiles in a candidate grid.
        image_size: Edge length of each tile in pixels.
        use_thumbnail: When true and more than one tile was produced, append
            the whole image resized to a single tile.

    Returns:
        List of tiles, optionally followed by the thumbnail.
    """
    src_w, src_h = image.size
    aspect_ratio = src_w / src_h

    # Candidate grids, ordered by total tile count.
    grids = {
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if i * j <= max_num and i * j >= min_num
    }
    grids = sorted(grids, key=lambda g: g[0] * g[1])

    chosen = find_closest_aspect_ratio(aspect_ratio, grids, src_w, src_h, image_size)
    canvas_w = image_size * chosen[0]
    canvas_h = image_size * chosen[1]
    tile_count = chosen[0] * chosen[1]
    canvas = image.resize((canvas_w, canvas_h))

    tiles_per_row = canvas_w // image_size
    tiles = []
    for tile_idx in range(tile_count):
        # Row-major walk over the grid.
        row, col = divmod(tile_idx, tiles_per_row)
        left = col * image_size
        top = row * image_size
        tiles.append(canvas.crop((left, top, left + image_size, top + image_size)))
    assert len(tiles) == tile_count

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles

def load_image_pixels(image_file, input_size=448, max_num=6):
    """Load an image file and return its preprocessed tiles as one tensor.

    Args:
        image_file: Path (or file object) accepted by ``Image.open``.
        input_size: Edge length of each tile in pixels.
        max_num: Maximum number of tiles passed to ``dynamic_preprocess``
            (a thumbnail tile may be appended on top of that).

    Returns:
        The transformed tiles stacked along a new leading dimension with
        ``torch.stack``.
    """
    # Open through a context manager so the underlying file handle is
    # released deterministically, including when convert() raises.
    # convert() returns an independent, fully decoded image.
    with Image.open(image_file) as raw_image:
        image = raw_image.convert('RGB')

    transform = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    return torch.stack([transform(tile) for tile in tiles])

# --- Experiment setup ---
# Measures peak GPU memory of model.batch_chat for batch sizes 1..max_batch_size,
# feeding the same preprocessed image repeated batch_size times.
if __name__ == '__main__':
    if not torch.cuda.is_available():
        print("CUDA is not available. This script requires a GPU.")
        # Signal failure through the exit status; exit() is a `site` helper
        # that is not guaranteed to exist and would report success (0).
        raise SystemExit(1)

    # Load the model
    model_path = "OpenGVLab/InternVL3-2B"
    print("Loading model...")
    model = AutoModel.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True
    ).eval().cuda()
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    print("Model loaded.")

    # Download the test image (skipped when it already exists locally)
    image_url = "https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/det.jpg"
    image_path = "test_image.jpg"
    if not os.path.exists(image_path):
        print(f"Downloading test image from {image_url}...")
        # Stream into a temporary file and rename at the end, so an
        # interrupted download never leaves a truncated image that later
        # runs would treat as complete.
        partial_path = image_path + ".part"
        # The timeout keeps a stalled connection from hanging the script;
        # the context manager releases the streamed connection.
        with requests.get(image_url, stream=True, timeout=30) as response:
            response.raise_for_status()
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        os.replace(partial_path, image_path)

    # Preprocess the image once and reuse it for every batch size
    single_image_pixels = load_image_pixels(image_path, max_num=6).to(torch.bfloat16).cuda()

    # Experiment parameters
    question = '<image>\nDescribe the image in detail.'
    generation_config = dict(max_new_tokens=50, do_sample=False)
    max_batch_size = 10

    print("\n--- Starting GPU Memory Measurement (Transformers) ---")

    # Baseline allocation after loading the model.
    # NOTE: single_image_pixels is already on the GPU here, so the baseline
    # includes that tensor as well as the model weights.
    torch.cuda.empty_cache()
    base_memory = torch.cuda.memory_allocated() / 1024**2
    print(f"Base model memory usage: {base_memory:.2f} MB")

    for batch_size in range(1, max_batch_size + 1):
        # Build the batch by repeating the same image
        pixel_values = torch.cat([single_image_pixels] * batch_size, dim=0)
        num_patches_list = [single_image_pixels.size(0)] * batch_size
        questions = [question] * batch_size

        # Reset the peak counter so it reflects only this batch
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Run inference
        with torch.no_grad():
            responses = model.batch_chat(
                tokenizer,
                pixel_values=pixel_values,
                num_patches_list=num_patches_list,
                questions=questions,
                generation_config=generation_config
            )

        # Peak allocated memory in MB
        peak_memory = torch.cuda.max_memory_allocated() / 1024**2

        print(f"Batch Size: {batch_size:2d} -> Peak GPU Memory: {peak_memory:.2f} MB")

        # Drop references and clear the cache before the next batch
        del pixel_values, num_patches_list, questions, responses
        torch.cuda.empty_cache()

  from .autonotebook import tqdm as notebook_tqdm


Loading model...


A new version of the following files was downloaded from https://huggingface.co/OpenGVLab/InternVL3-2B:
- configuration_intern_vit.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/OpenGVLab/InternVL3-2B:
- configuration_internvl_chat.py
- configuration_intern_vit.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/OpenGVLab/InternVL3-2B:
- conversation.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/OpenGVLab/InternVL3-2B:
- modeling_intern_vit.p

Model loaded.
Downloading test image from https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/det.jpg...

--- Starting GPU Memory Measurement (Transformers) ---
Base model memory usage: 4441.27 MB


Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.


Batch Size:  1 -> Peak GPU Memory: 6150.02 MB


Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.


Batch Size:  2 -> Peak GPU Memory: 7832.28 MB


Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.


Batch Size:  3 -> Peak GPU Memory: 9512.52 MB


Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.


Batch Size:  4 -> Peak GPU Memory: 11190.04 MB


Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.


Batch Size:  5 -> Peak GPU Memory: 12857.67 MB


Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.


Batch Size:  6 -> Peak GPU Memory: 14534.15 MB


Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.


Batch Size:  7 -> Peak GPU Memory: 16210.63 MB


Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.


Batch Size:  8 -> Peak GPU Memory: 17887.47 MB
Batch Size:  9 -> Peak GPU Memory: 19563.96 MB


Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.


Batch Size: 10 -> Peak GPU Memory: 21240.70 MB


In [2]:
import json
import re

# Previous version (v2) - kept for comparison
def parse_prediction_v2(pred_str: str) -> str:
    """Extract the ``category`` label from a model response (case-sensitive).

    Steps:
      1. If the text contains a Markdown code fence, keep only the content
         of the first fenced section.
      2. Try to parse the span from the first ``{`` to the last ``}`` as JSON
         and read its ``category`` key.
      3. Otherwise fall back to a regex for a quoted ``category`` key
         followed by a quoted value.

    Only the exact lowercase labels ``violence`` and ``normal`` are accepted.

    Args:
        pred_str: Raw model output.

    Returns:
        ``'violence'`` or ``'normal'`` on success, ``'no_json_found'`` when no
        label could be extracted, ``'parsing_failed'`` when the input is not a
        string or an unexpected exception occurs.
    """
    if not isinstance(pred_str, str):
        return 'parsing_failed'

    try:
        text = pred_str
        # Prefer an explicit ```json fence, then any generic fence.
        for fence in ('```json', '```'):
            if fence in text:
                text = text.split(fence)[1].split('```')[0]
                break
        text = text.strip()

        open_idx = text.find('{')
        close_idx = text.rfind('}')
        if 0 <= open_idx < close_idx:
            try:
                payload = json.loads(text[open_idx:close_idx + 1])
                label = payload.get('category')
                if label in ('violence', 'normal'):
                    return label
            except json.JSONDecodeError:
                # Malformed JSON: let the regex fallback have a go.
                pass

        found = re.search(r'["\']category["\']\s*:\s*["\'](violence|normal)["\']', text)
        if found:
            return found.group(1)
        return 'no_json_found'
    except Exception:
        return 'parsing_failed'

# Case-insensitive version (v3)
def parse_prediction_v3(pred_str: str) -> str:
    """Extract the ``category`` label from a model response, ignoring case.

    Steps:
      1. If the text contains a Markdown code fence, keep only the content
         of the first fenced section.
      2. Try to parse the span from the first ``{`` to the last ``}`` as JSON
         and read its ``category`` key.
      3. Otherwise fall back to a case-insensitive regex for a quoted
         ``category`` key followed by a quoted value.

    The label is matched case-insensitively and always returned lowercase.

    Args:
        pred_str: Raw model output.

    Returns:
        ``'violence'`` or ``'normal'`` on success, ``'no_json_found'`` when no
        label could be extracted, ``'parsing_failed'`` when the input is not a
        string or an unexpected exception occurs.
    """
    if not isinstance(pred_str, str):
        return 'parsing_failed'

    try:
        text = pred_str
        # Prefer an explicit ```json fence, then any generic fence.
        for fence in ('```json', '```'):
            if fence in text:
                text = text.split(fence)[1].split('```')[0]
                break
        text = text.strip()

        open_idx = text.find('{')
        close_idx = text.rfind('}')
        if 0 <= open_idx < close_idx:
            try:
                payload = json.loads(text[open_idx:close_idx + 1])
                label = payload.get('category')
                # Non-string values (numbers, lists, null) are never a label.
                if isinstance(label, str) and label.lower() in ('violence', 'normal'):
                    return label.lower()
            except json.JSONDecodeError:
                # Malformed JSON: let the regex fallback have a go.
                pass

        found = re.search(r'["\']category["\']\s*:\s*["\'](violence|normal)["\']', text, re.IGNORECASE)
        if found:
            return found.group(1).lower()
        return 'no_json_found'
    except Exception:
        return 'parsing_failed'


# ----------------- Test case definitions (letter-case variants added) -----------------
# Each entry is (description, raw model output fed to both parsers).
test_cases = [
    ("정상 (소문자)", '{"category": "normal", "description": "..."}'),
    ("정상 (소문자)", '{"category": "violence", "description": "..."}'),
    ("대소문자: Title Case", '{"category": "Violence", "description": "..."}'),
    ("대소문자: ALL CAPS", '{"category": "NORMAL", "description": "..."}'),
    ("대소문자: MiXeD CaSe", '{"category": "vIoLeNcE", "description": "..."}'),
    ("대소문자: Regex Fallback", "'category': 'nOrMaL', 'description': '...'"),
    ("마크다운 포함", '```json\n{"category": "normal", "description": "..."}\n```'),
    ("앞뒤 텍스트 포함", 'Answer: {"category": "violence", "description": "..."}'),
    ("JSON 형식 깨짐", '{"category": "normal" "description": "..."}'),
    ("category 키 없음", '{"action": "running", "description": "..."}'),
    ("완전한 쓰레기값", "A cat playing with a ball."),
    ("빈 문자열", ""),
]

# ----------------- Run both parsers and print the results side by side -----------------
# Header row: one left-aligned column per field, joined with " | ".
print(" | ".join([
    format('테스트 설명', '<25'),
    format('V2 결과 (기존)', '<15'),
    format('V3 결과 (수정)', '<15'),
]))
print("-" * 65)

for desc, case in test_cases:
    res2, res3 = parse_prediction_v2(case), parse_prediction_v3(case)
    print(" | ".join([format(desc, '<25'), format(res2, '<15'), format(res3, '<15')]))

테스트 설명                    | V2 결과 (기존)      | V3 결과 (수정)     
-----------------------------------------------------------------
정상 (소문자)                  | normal          | normal         
정상 (소문자)                  | violence        | violence       
대소문자: Title Case          | no_json_found   | violence       
대소문자: ALL CAPS            | no_json_found   | normal         
대소문자: MiXeD CaSe          | no_json_found   | violence       
대소문자: Regex Fallback      | no_json_found   | normal         
마크다운 포함                   | normal          | normal         
앞뒤 텍스트 포함                 | violence        | violence       
JSON 형식 깨짐                | normal          | normal         
category 키 없음             | no_json_found   | no_json_found  
완전한 쓰레기값                  | no_json_found   | no_json_found  
빈 문자열                     | no_json_found   | no_json_found  


In [None]:
import os
import torch
from transformers import AutoModel, AutoTokenizer
from PIL import Image
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from decord import VideoReader, cpu
import numpy as np
import json
import time

# Shared settings
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size: int = 448):
    """Build the per-tile preprocessing pipeline.

    The returned torchvision ``Compose`` converts the image to RGB when its
    mode is not already RGB, resizes it to ``input_size`` x ``input_size``
    with bicubic interpolation, converts it to a tensor and normalizes it
    with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.

    Args:
        input_size: Edge length of the square output, in pixels.

    Returns:
        A ``T.Compose`` callable that maps an image to a normalized tensor.
    """
    def to_rgb(img):
        # Leave images that are already RGB untouched.
        return img if img.mode == "RGB" else img.convert("RGB")

    square = (input_size, input_size)
    steps = [
        T.Lambda(to_rgb),
        T.Resize(square, interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the candidate ratio whose quotient is closest to ``aspect_ratio``.

    Candidates are scanned in the order given. A strictly closer candidate
    always wins. A candidate that ties exactly with the current best replaces
    it only when ``width * height`` exceeds half the area of the candidate's
    tile grid (``image_size**2 * ratio[0] * ratio[1]``).

    Args:
        aspect_ratio: Aspect ratio to match.
        target_ratios: Iterable of 2-item candidates; each is compared via
            ``ratio[0] / ratio[1]``.
        width: Source width, used only for the tie-break area.
        height: Source height, used only for the tie-break area.
        image_size: Edge length of one tile, used only for the tie-break.

    Returns:
        The selected candidate, or ``(1, 1)`` when no candidate is supplied.
    """
    pixel_area = width * height
    winner = (1, 1)
    smallest_gap = float("inf")
    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < smallest_gap:
            smallest_gap, winner = gap, candidate
            continue
        # Exact tie: switch only if the source is large enough to justify
        # this candidate's grid.
        if gap == smallest_gap and pixel_area > 0.5 * image_size * image_size * candidate[0] * candidate[1]:
            winner = candidate
    return winner

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    """Resize an image onto a tile grid and cut it into square tiles.

    Every ``(i, j)`` grid with ``min_num <= i * j <= max_num`` is considered;
    the grid is chosen by ``find_closest_aspect_ratio``. The image is resized
    to ``(image_size * i, image_size * j)`` and cropped row by row into
    ``i * j`` tiles of ``image_size`` x ``image_size``.

    Args:
        image: Object exposing ``size``, ``resize`` and ``crop``
            (a PIL image in this file).
        min_num: Minimum number of tiles in a candidate grid.
        max_num: Maximum number of tiles in a candidate grid.
        image_size: Edge length of each tile in pixels.
        use_thumbnail: When true and the grid has more than one tile, append
            the whole image resized to a single tile.

    Returns:
        List of tiles, optionally followed by the thumbnail.
    """
    src_w, src_h = image.size
    aspect_ratio = src_w / src_h

    # Candidate grids, ordered by total tile count.
    grids = set()
    for n in range(min_num, max_num + 1):
        for i in range(1, n + 1):
            for j in range(1, n + 1):
                if min_num <= i * j <= max_num:
                    grids.add((i, j))
    grids = sorted(grids, key=lambda g: g[0] * g[1])

    chosen = find_closest_aspect_ratio(aspect_ratio, grids, src_w, src_h, image_size)
    canvas_w = image_size * chosen[0]
    canvas_h = image_size * chosen[1]
    tile_count = chosen[0] * chosen[1]
    canvas = image.resize((canvas_w, canvas_h))

    tiles_per_row = canvas_w // image_size
    tiles = []
    for tile_idx in range(tile_count):
        # Row-major walk over the grid.
        row, col = divmod(tile_idx, tiles_per_row)
        left = col * image_size
        top = row * image_size
        tiles.append(canvas.crop((left, top, left + image_size, top + image_size)))

    if use_thumbnail and tile_count != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles

# V1 frame sampling (from the first implementation)
def get_indices_by_frame_range_v1(start_idx: int, end_idx: int, num_segments: int) -> np.ndarray:
    """Sample frame indices from the inclusive range ``[start_idx, end_idx]``.

    The range is divided into equal-width segments and the midpoint of each
    segment (truncated to an integer) is taken. The number of samples is
    ``num_segments`` capped to the number of frames in the range and never
    less than one. If ``end_idx`` is smaller than ``start_idx`` the range
    collapses to ``start_idx`` alone.

    Args:
        start_idx: First frame of the range (inclusive).
        end_idx: Last frame of the range (inclusive).
        num_segments: Requested number of samples.

    Returns:
        Integer array of sampled frame indices, each within the range.
    """
    first = int(start_idx)
    last = max(first, int(end_idx))
    span = last - first + 1
    count = max(1, min(num_segments, span))
    stride = span / float(count)

    picks = []
    for k in range(count):
        frame = first + int(stride * k + stride / 2)
        # Keep every index inside [first, last].
        picks.append(min(max(first, frame), last))
    return np.array(picks, dtype=int)

# V2 frame sampling (from the second implementation)
def get_index_v2(bound, fps, max_frame, first_idx=0, num_segments=32):
    """Sample ``num_segments`` frame indices between two frame bounds.

    The range starts at ``max(first_idx, bound[0])`` and ends at
    ``min(bound[1], max_frame)``; a falsy ``bound`` selects the whole range
    ``first_idx .. max_frame``. Each index is the range start plus half a
    segment plus the rounded segment offset, truncated to an integer.

    Args:
        bound: Optional ``(start, end)`` pair, used directly as frame indices.
        fps: Accepted for signature compatibility; not used by this function.
        max_frame: Upper limit for the end of the range.
        first_idx: Lower limit for the start of the range.
        num_segments: Number of indices to produce.

    Returns:
        NumPy array with ``num_segments`` frame indices.
    """
    lo, hi = (bound[0], bound[1]) if bound else (-100000, 100000)
    start_idx = max(first_idx, lo)
    end_idx = min(hi, max_frame)
    seg_size = float(end_idx - start_idx) / num_segments

    picks = []
    for k in range(num_segments):
        picks.append(int(start_idx + (seg_size / 2) + np.round(seg_size * k)))
    return np.array(picks)

def load_video_v1(video_path: str, start_frame: int, end_frame: int, 
                  input_size: int = 448, max_num: int = 1, num_segments: int = 12):
    """Load and preprocess video frames using the V1 sampling scheme.

    ``start_frame`` and ``end_frame`` are clamped to the valid frame range of
    the video, frame indices are picked with
    ``get_indices_by_frame_range_v1``, and every selected frame is tiled with
    ``dynamic_preprocess`` and transformed with ``build_transform``.

    Args:
        video_path: Path of the video to read.
        start_frame: First frame of the clip (clamped).
        end_frame: Last frame of the clip (clamped).
        input_size: Edge length of each tile in pixels.
        max_num: Maximum tile count passed to ``dynamic_preprocess``.
        num_segments: Requested number of sampled frames.

    Returns:
        Tuple ``(pixel_values, num_patches_list, indices)``: all tiles
        concatenated along dim 0, the tile count of each frame, and the
        sampled frame indices.
    """
    reader = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    last_valid = len(reader) - 1
    clip_start = max(0, min(start_frame, last_valid))
    clip_end = max(0, min(end_frame, last_valid))
    indices = get_indices_by_frame_range_v1(clip_start, clip_end, num_segments=num_segments)

    transform = build_transform(input_size=input_size)
    per_frame = []
    for frame_index in indices:
        frame = Image.fromarray(reader[frame_index].asnumpy()).convert('RGB')
        tiles = dynamic_preprocess(frame, image_size=input_size, use_thumbnail=True, max_num=max_num)
        per_frame.append(torch.stack([transform(tile) for tile in tiles]))

    num_patches_list = [frame_tiles.shape[0] for frame_tiles in per_frame]
    pixel_values = torch.cat(per_frame)
    return pixel_values, num_patches_list, indices

def load_video_v2(video_path: str, bound=None, input_size=448, max_num=1, num_segments=32):
    """Load and preprocess video frames using the V2 sampling scheme.

    Frame indices come from ``get_index_v2`` (called with the video's average
    fps and last frame index); every selected frame is tiled with
    ``dynamic_preprocess`` and transformed with ``build_transform``.

    Args:
        video_path: Path of the video to read.
        bound: Optional ``(start, end)`` pair forwarded to ``get_index_v2``.
        input_size: Edge length of each tile in pixels.
        max_num: Maximum tile count passed to ``dynamic_preprocess``.
        num_segments: Number of frames to sample.

    Returns:
        Tuple ``(pixel_values, num_patches_list, frame_indices)``: all tiles
        concatenated along dim 0, the tile count of each frame, and the
        sampled frame indices.
    """
    reader = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    last_frame = len(reader) - 1
    avg_fps = float(reader.get_avg_fps())

    transform = build_transform(input_size=input_size)
    frame_indices = get_index_v2(bound, avg_fps, last_frame, first_idx=0, num_segments=num_segments)

    per_frame = []
    for frame_index in frame_indices:
        frame = Image.fromarray(reader[frame_index].asnumpy()).convert('RGB')
        tiles = dynamic_preprocess(frame, image_size=input_size, use_thumbnail=True, max_num=max_num)
        per_frame.append(torch.stack([transform(tile) for tile in tiles]))

    num_patches_list = [frame_tiles.shape[0] for frame_tiles in per_frame]
    pixel_values = torch.cat(per_frame)
    return pixel_values, num_patches_list, frame_indices

class InternVL3Inferencer:
    """Thin wrapper that loads an InternVL checkpoint and runs video prompts.

    Two entry points are provided, differing only in how frames are sampled:
    ``infer_v1`` (``load_video_v1``) and ``infer_v2`` (``load_video_v2``).
    """

    def __init__(self, model_path="OpenGVLab/InternVL3-2B", device="cuda:0"):
        """Load the model (bfloat16, eval mode) and tokenizer onto ``device``.

        Args:
            model_path: Hub id or local path passed to ``from_pretrained``.
            device: Device the model and inputs are moved to.
        """
        print(f"[INFO] InternVL 모델 로딩 중... device={device}")
        self.model = AutoModel.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            use_flash_attn=False,
            trust_remote_code=True
        ).eval().to(device)

        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
        self.device = device
        self.generation_config = dict(max_new_tokens=1024, do_sample=False)
        print("[INFO] InternVL 모델 로딩 완료.")

    def _chat_on_frames(self, pixel_values, num_patches_list, prompt: str):
        """Prefix ``prompt`` with one ``<image>`` line per frame and query the model.

        Args:
            pixel_values: Tensor holding the tiles of all frames.
            num_patches_list: Tile count of each frame; its length is the
                number of ``FrameN: <image>`` lines that are generated.
            prompt: Instruction text appended after the frame lines.

        Returns:
            Whatever ``self.model.chat`` returns for the built question.
        """
        pixel_values = pixel_values.to(torch.bfloat16).to(self.device)
        video_prefix = ''.join(f'Frame{i+1}: <image>\n' for i in range(len(num_patches_list)))
        question = video_prefix + prompt
        # The question holds one <image> placeholder per frame, so the
        # per-frame tile counts must be forwarded for the placeholders to be
        # matched to frames. This follows the multi-frame example in the
        # InternVL model card — NOTE(review): confirm against the remote
        # code of the checkpoint actually loaded.
        return self.model.chat(
            self.tokenizer,
            pixel_values,
            question,
            self.generation_config,
            num_patches_list=num_patches_list,
        )

    def infer_v1(self, video_path: str, prompt: str, start_frame: int, end_frame: int, num_segments: int = 12):
        """Run inference on frames sampled with the V1 scheme.

        Args:
            video_path: Path of the video to read.
            prompt: Instruction text for the model.
            start_frame: First frame of the clip.
            end_frame: Last frame of the clip.
            num_segments: Requested number of sampled frames.

        Returns:
            Tuple ``(response, indices)`` with the model response and the
            sampled frame indices.
        """
        pixel_values, num_patches_list, indices = load_video_v1(
            video_path, start_frame, end_frame, input_size=448, max_num=1, num_segments=num_segments
        )
        response = self._chat_on_frames(pixel_values, num_patches_list, prompt)
        return response, indices

    def infer_v2(self, video_path: str, template: str, num_segments: int = 12, bound=None):
        """Run inference on frames sampled with the V2 scheme.

        Args:
            video_path: Path of the video to read.
            template: Instruction text for the model.
            num_segments: Number of frames to sample.
            bound: Optional ``(start, end)`` pair forwarded to ``load_video_v2``.

        Returns:
            Tuple ``(response, indices)`` with the model response and the
            sampled frame indices.
        """
        pixel_values, num_patches_list, indices = load_video_v2(
            video_path, bound=bound, num_segments=num_segments, max_num=1
        )
        response = self._chat_on_frames(pixel_values, num_patches_list, template)
        return response, indices

def compare_inference_results(video_path: str, model_path: str = "OpenGVLab/InternVL3-2B"):
    """Run the V1 and V2 pipelines on the same clip and print a comparison.

    The frame range (1224-1235), the number of segments (12) and the prompt
    are fixed inside the function. For each pipeline the selected frame
    indices, the elapsed time and the raw response are printed; a failure in
    one pipeline is reported and does not stop the other. Finally the frame
    indices, the raw responses and the extracted ``category`` values are
    compared.

    Args:
        video_path: Path of the video to analyse.
        model_path: Hub id or local path of the checkpoint to load.

    Returns:
        None. All results are written to stdout.
    """
    import re

    # Test configuration
    START_FRAME = 1224
    END_FRAME = 1235
    NUM_SEGMENTS = 12
    # Only one prompt is kept: earlier variants were assigned to the same
    # name and immediately overwritten, so they never reached the model.
    PROMPT = """
    Watch this short video clip and respond with exactly one JSON object.\n\n[Rules]\n- The category must be either 'violence' or 'normal'.  \n- Classify as violence if any of the following actions are present:  \n  * Punching  \n  * Kicking  \n  * Weapon Threat\n  * Weapon Attack\n  * Falling/Takedown  \n  * Pushing/Shoving  \n  * Brawling/Group Fight  \n- If none of the above are observed, classify as normal.  \n- The following cases must always be classified as normal:  \n  * Affection (hugging, holding hands, light touches)  \n  * Helping (supporting, assisting)  \n  * Accidental (unintentional bumping)  \n  * Playful (non-aggressive playful contact)  \n\n[Output Format]\n- Output exactly one JSON object.  \n- The object must contain only two keys: \"category\" and \"description\".  \n- The description should briefly and objectively describe the scene.  \n\nExample (violence):  \n{\"category\":\"violence\",\"description\":\"A man in a black jacket punches another man, who stumbles backward.\"}\n\nExample (normal):  \n{\"category\":\"normal\",\"description\":\"Two people are hugging inside an elevator"}
    """

    def extract_category(text):
        # Pull the quoted value following a "category" key, if any.
        match = re.search(r'"category"\s*:\s*"([^"]*)"', text)
        return match.group(1) if match else "파싱 실패"

    print("="*80)
    print("비디오 추론 결과 비교 테스트")
    print(f"비디오: {video_path}")
    print(f"프레임 구간: {START_FRAME} - {END_FRAME}")
    print(f"NUM_SEGMENTS: {NUM_SEGMENTS}")
    print("="*80)

    # Load the model once and reuse it for both pipelines
    inferencer = InternVL3Inferencer(model_path)

    # V1 pipeline
    print("\n[V1 방식 테스트]")
    start_time = time.time()
    try:
        result_v1, indices_v1 = inferencer.infer_v1(
            video_path, PROMPT, START_FRAME, END_FRAME, NUM_SEGMENTS
        )
        v1_time = time.time() - start_time
        print(f"선택된 프레임 인덱스: {indices_v1.tolist()}")
        print(f"추론 시간: {v1_time:.2f}초")
        print(f"결과: {result_v1}")
    except Exception as e:
        # Best-effort: report the failure and still run the other pipeline.
        print(f"V1 방식 에러: {e}")
        result_v1, indices_v1 = None, None

    print("\n" + "-"*60)

    # V2 pipeline
    print("\n[V2 방식 테스트]")
    bound = [START_FRAME, END_FRAME]
    start_time = time.time()
    try:
        result_v2, indices_v2 = inferencer.infer_v2(
            video_path, PROMPT, NUM_SEGMENTS, bound
        )
        v2_time = time.time() - start_time
        print(f"선택된 프레임 인덱스: {indices_v2.tolist()}")
        print(f"추론 시간: {v2_time:.2f}초")
        print(f"결과: {result_v2}")
    except Exception as e:
        # Best-effort: report the failure and continue to the comparison.
        print(f"V2 방식 에러: {e}")
        result_v2, indices_v2 = None, None

    print("\n" + "="*80)

    # Comparison
    print("\n[비교 분석]")

    if indices_v1 is not None and indices_v2 is not None:
        same_indices = np.array_equal(indices_v1, indices_v2)
        print(f"프레임 인덱스 동일성: {same_indices}")
        if not same_indices:
            # Convert to plain Python ints so the sets print as numbers
            # rather than NumPy scalar reprs.
            v1_list = indices_v1.tolist()
            v2_list = indices_v2.tolist()
            print(f"V1 인덱스: {v1_list}")
            print(f"V2 인덱스: {v2_list}")
            print(f"차이점: {set(v1_list) - set(v2_list)} (V1에만 있음)")
            print(f"차이점: {set(v2_list) - set(v1_list)} (V2에만 있음)")

    if result_v1 is not None and result_v2 is not None:
        print(f"\n결과 동일성: {result_v1 == result_v2}")
        if result_v1 != result_v2:
            print(f"\nV1 결과:\n{result_v1}")
            print(f"\nV2 결과:\n{result_v2}")

            # Compare the extracted category values
            try:
                cat_v1 = extract_category(result_v1)
                cat_v2 = extract_category(result_v2)
            except TypeError:
                # re.search rejects non-string responses.
                print("분류 결과 파싱 실패")
            else:
                print(f"\n분류 결과 - V1: {cat_v1}, V2: {cat_v2}")
                print(f"분류 일치: {cat_v1 == cat_v2}")

if __name__ == "__main__":
    # Example usage
    VIDEO_PATH = "sample/fight_0162.mp4"  # replace with the real video path
    MODEL_PATH = "ckpts/PIA_Violence"  # or whichever model path is in use

    # Run the comparison only when the video is actually present.
    if os.path.exists(VIDEO_PATH):
        compare_inference_results(VIDEO_PATH, MODEL_PATH)
    else:
        print(f"비디오 파일을 찾을 수 없습니다: {VIDEO_PATH}")
        print("VIDEO_PATH를 실제 파일 경로로 수정하세요.")