In [None]:
import torch
import os
import json
import nibabel as nib
import numpy as np
from tqdm import tqdm
from PIL import Image
from experts.expert_monai_brats import ExpertBrats
from experts.expert_monai_vista3d import ExpertVista3D
from experts.expert_torchxrayvision import ExpertTXRV
from llava.model.builder import load_pretrained_model
from llava.mm_utils import KeywordsStoppingCriteria, process_images, tokenizer_image_token
from llava.conversation import conv_templates, SeparatorStyle

In [2]:
# Model id and I/O locations used by the cells below. Each value can be
# overridden through an environment variable, so the notebook runs on another
# machine without editing code; the defaults are the original hard-coded values.
MODEL_PATH = os.environ.get("M3_MODEL_PATH", "MONAI/Llama3-VILA-M3-8B")
OUTPUT_PATH = os.environ.get("M3_OUTPUT_PATH", "/home/hufsaim/VLM/VLM/m3/demo/result_new")
JSON_FILE = os.environ.get(
    "M3_JSON_FILE",
    "/home/hufsaim/VLM/VLM/m3/demo/result_new/dataset_betcheck.json",
)


# Description of the available expert models. M3Inference.inference prepends it
# to every user prompt (as f"{MODEL_CARDS}\n<image>\n{prompt}"), and the model is
# asked to answer with <NAME(args)> to pick an expert. The text is sent verbatim
# to the VLM, so any edit here (even whitespace/punctuation) can change outputs.
# NOTE(review): the expert dispatch loop in M3Inference.inference only tries
# ExpertBrats, ExpertVista3D and ExpertTXRV; presumably a <VISTA2D(...)> or
# <HD-Glio(...)> selection therefore runs no expert -- confirm this is intended.
# NOTE(review): the HD-Glio card, unlike the others, has no trailing commas on
# its Overview/Accuracy lines; left as-is to avoid changing the prompt.
MODEL_CARDS = """Here is a list of available expert models:\n
<BRATS(args)> 
Modality: MRI, 
Task: segmentation, 
Overview: A pre-trained model for volumetric (3D) segmentation of brain tumor subregions from multimodal MRIs based on BraTS 2018 data, 
Accuracy: Tumor core (TC): 0.8559 - Whole tumor (WT): 0.9026 - Enhancing tumor (ET): 0.7905 - Average: 0.8518, 
Valid args are: None\n
<VISTA3D(args)> 
Modality: CT, 
Task: segmentation, 
Overview: domain-specialized interactive foundation model developed for segmenting and annotating human anatomies with precision, 
Accuracy: 127 organs: 0.792 Dice on average, 
Valid args are: 'everything', 'hepatic tumor', 'pancreatic tumor', 'lung tumor', 'bone lesion', 'organs', 'cardiovascular', 'gastrointestinal', 'skeleton', or 'muscles'\n
<VISTA2D(args)> 
Modality: cell imaging, 
Task: segmentation, 
Overview: model for cell segmentation, which was trained on a variety of cell imaging outputs, including brightfield, phase-contrast, fluorescence, confocal, or electron microscopy, 
Accuracy: Good accuracy across several cell imaging datasets, 
Valid args are: None\n
<CXR(args)> 
Modality: chest x-ray (CXR), 
Task: classification, 
Overview: pre-trained model which are trained on large cohorts of data, 
Accuracy: Good accuracy across several diverse chest x-rays datasets, 
Valid args are: None\n
<HD-Glio(args)>
Modality: MRI, 
Task: segmentation, 
Overview: A deep learning-based model designed for high-grade glioma segmentation in brain MR images. HD-Glio leverages ensemble 3D U-Net architectures and robust preprocessing including brain extraction, intensity normalization, and co-registration. It is tailored to identify and delineate tumor subregions (enhancing tumor, tumor core, and whole tumor) with high accuracy.
Accuracy: Tumor core (TC): 0.860 - Whole tumor (WT): 0.910 - Enhancing tumor (ET): 0.800 - Average: 0.857
Valid args are: None\n
Give the model <NAME(args)> when selecting a suitable expert model.\n"""

In [3]:
def load_nifti_image(nifti_path, slice_axis=2, slice_idx=None):
    """
    Load a NIfTI volume and return one 2D slice as an RGB PIL image for VILA-M3.

    Args:
        nifti_path (str): Path to a ``.nii`` / ``.nii.gz`` file.
        slice_axis (int): Axis of the loaded array to slice along (default 2).
            Any axis valid for the array is accepted, including negative ones.
        slice_idx (int | None): Index along ``slice_axis``. ``None`` silently
            selects the middle slice ``(size - 1) // 2``; an out-of-range value
            prints a warning and falls back to the same middle slice.

    Returns:
        PIL.Image.Image: The slice, min-max normalised to 0-255, in RGB mode.

    Raises:
        RuntimeError: If nibabel fails to load the file.
        ValueError: If ``slice_axis`` is not a valid axis of the volume.
    """
    try:
        volume = nib.load(nifti_path).get_fdata()
    except Exception as e:
        raise RuntimeError(f"Failed to load NIfTI file: {nifti_path}, error: {e}") from e

    if not (-volume.ndim <= slice_axis < volume.ndim):
        raise ValueError(f"slice_axis {slice_axis} is invalid for a volume of shape {volume.shape}")

    max_idx = volume.shape[slice_axis] - 1
    middle_idx = max_idx // 2

    # Validate the slice index; None is the normal "use the middle" request and
    # is not worth a warning, an explicit out-of-range index is.
    if slice_idx is None:
        slice_idx = middle_idx
    elif not (0 <= slice_idx <= max_idx):
        print(f"[WARNING] slice_idx {slice_idx} is out of bounds (max: {max_idx}). Using middle slice instead.")
        slice_idx = middle_idx

    # np.take slices along exactly the axis max_idx was computed from, for any
    # valid (including negative) slice_axis.
    slice_img = np.take(volume, slice_idx, axis=slice_axis)

    # Drop trailing singleton dimensions (e.g. a 4D volume with one channel),
    # which PIL cannot turn into an image.
    while slice_img.ndim > 2 and slice_img.shape[-1] == 1:
        slice_img = slice_img[..., 0]

    # Min-max normalise to [0, 1]; the epsilon avoids division by zero on a
    # constant slice. Then scale to 8-bit and convert to RGB.
    lo, hi = np.min(slice_img), np.max(slice_img)
    slice_norm = (slice_img - lo) / (hi - lo + 1e-8)
    return Image.fromarray(np.uint8(slice_norm * 255)).convert('RGB')


In [4]:
def save_result_to_json(result, output_dir, base_filename="inference_result"):
    """
    Write batch inference results to a new JSON file without overwriting any.

    Each sample's list of single-key dicts (as produced by
    ``M3Inference.inference``) is flattened into one dict that starts with an
    ``"id"`` field. When a key appears more than once for a sample (e.g.
    ``"VILA-M3"`` before and after an expert call), the bare key keeps the
    LAST value, as before, and all values are additionally kept, in order,
    under ``"<key>_all"``, so earlier turns are no longer silently dropped.

    The file is ``<base_filename>.json``; if that exists, ``_1``, ``_2``, ...
    suffixes are tried. Files are opened with mode ``"x"``, so an existing
    file is never truncated, even if it appears between the check and the write.

    Args:
        result (dict): Mapping ``sample_id -> list[dict]``.
        output_dir (str): Target directory; created if missing.
        base_filename (str): File name stem, without the ``.json`` extension.

    Returns:
        str: Path of the file that was written.
    """
    os.makedirs(output_dir, exist_ok=True)

    formatted_results = []
    for sample_id, info_list in result.items():
        combined = {"id": sample_id}
        history = {}
        for entry in info_list:
            for key, value in entry.items():
                combined[key] = value  # last value wins, as with dict.update
                history.setdefault(key, []).append(value)
        for key, values in history.items():
            if len(values) > 1:
                combined[f"{key}_all"] = values
        formatted_results.append(combined)

    counter = 0
    while True:
        suffix = f"_{counter}" if counter else ""
        filepath = os.path.join(output_dir, f"{base_filename}{suffix}.json")
        try:
            # "x" creates the file atomically and fails if it already exists.
            with open(filepath, "x", encoding="utf-8") as f:
                json.dump(formatted_results, f, ensure_ascii=False, indent=4)
            return filepath
        except FileExistsError:
            counter += 1

def load_json_data(json_path):
    """
    Read and parse a JSON file.

    The file is decoded as UTF-8 explicitly, instead of with the platform's
    default locale encoding, which can fail on non-ASCII text. This matches
    ``save_result_to_json``, which writes UTF-8 with ``ensure_ascii=False``.
    ``utf-8-sig`` also tolerates a leading byte-order mark, which
    ``json.load`` would otherwise reject.

    Args:
        json_path (str): Path to the JSON file.

    Returns:
        The parsed JSON value (for the dataset file, presumably a list of
        sample dicts).
    """
    with open(json_path, "r", encoding="utf-8-sig") as f:
        return json.load(f)

In [6]:
class M3Inference:
    """
    VILA-M3 inference wrapper with optional expert-model tool use.

    ``inference`` asks the model about one image with ``MODEL_CARDS`` prepended
    to the prompt. If one of ``ExpertBrats``, ``ExpertVista3D`` or
    ``ExpertTXRV`` reports (via ``mentioned_by``) that the reply selects it,
    that expert is run. When the expert returns an instruction, a second model
    turn is generated from the expert's response and instruction.
    """

    def __init__(self, model_path="MONAI/Llama3-VILA-M3-8B", conv_mode="llama_3",
                 expert_output_dir="/home/hufsaim/VLM/VLM/m3/demo/expert_result"):
        """
        Load the tokenizer, model and image processor onto ``cuda:0``.

        Args:
            model_path (str): Hub id or local path of the checkpoint; its last
                ``/``-separated component is used as the model name.
            conv_mode (str): Key into ``conv_templates`` used to build prompts.
            expert_output_dir (str): Directory passed to experts as
                ``output_dir`` (the default is the previously hard-coded path).
        """
        model_name = model_path.split("/")[-1]
        self.tokenizer, self.model, self.image_processor, _ = load_pretrained_model(model_path, model_name, device="cuda:0")
        self.conv_mode = conv_mode
        self.expert_output_dir = expert_output_dir

    @staticmethod
    def _is_nifti(path):
        """Return True if ``path`` ends with ``.nii`` or ``.nii.gz`` (case-sensitive)."""
        return path.endswith(('.nii', '.nii.gz'))

    @staticmethod
    def _resolve_slice_index(image_path, slice_index):
        """
        Return the axis-2 slice index that both the VLM image and the expert use.

        For NIfTI inputs, ``None`` or an out-of-range index becomes the middle
        slice ``(depth - 1) // 2``, the same fallback ``load_nifti_image``
        applies. This ensures the expert receives the slice the model saw.
        Non-NIfTI paths, unreadable headers and volumes with fewer than 3 dims
        return ``slice_index`` unchanged.
        """
        if not M3Inference._is_nifti(image_path):
            return slice_index
        try:
            shape = nib.load(image_path).shape
        except Exception as e:
            # load_nifti_image will raise a descriptive error for this file.
            print(f"[WARNING] Failed to load NIfTI file for slice index: {e}")
            return slice_index
        if len(shape) < 3:
            print(f"[WARNING] Unexpected shape {shape} for NIfTI file.")
            return slice_index

        max_idx = shape[2] - 1
        if slice_index is not None and not (0 <= slice_index <= max_idx):
            print(f"[WARNING] slice_index {slice_index} is out of bounds (max: {max_idx}). Using middle slice instead.")
            slice_index = None
        if slice_index is None:
            slice_index = max_idx // 2
        return slice_index

    @staticmethod
    def _select_expert():
        """Placeholder kept private; see _find_expert."""
        return None

    @staticmethod
    def _find_expert(output):
        """Return a fresh instance of the first expert whose ``mentioned_by(output)`` is true, else None."""
        for expert_cls in (ExpertBrats, ExpertVista3D, ExpertTXRV):
            expert = expert_cls()
            if expert.mentioned_by(output):
                return expert
        return None

    def _generate(self, conv, media_input, media_config, max_tokens, temperature, top_p):
        """
        Run one generation turn for the prepared conversation ``conv``.

        A new ``KeywordsStoppingCriteria`` is built from this turn's
        ``input_ids`` every time. The criteria object is constructed from the
        prompt ids (LLaVA's implementation records their length to locate the
        newly generated tokens), so reusing one built for a different prompt is
        unsafe.

        Returns:
            str: Decoded reply with surrounding whitespace and a trailing stop
            string removed.
        """
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        prompt_text = conv.get_prompt()
        input_ids = tokenizer_image_token(prompt_text, self.tokenizer, return_tensors="pt").unsqueeze(0).to(self.model.device)
        stopping_criteria = KeywordsStoppingCriteria([stop_str], self.tokenizer, input_ids)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                media=media_input,
                media_config=media_config,
                do_sample=(temperature > 0),
                temperature=temperature,
                top_p=top_p,
                max_new_tokens=max_tokens,
                use_cache=True,
                stopping_criteria=[stopping_criteria],
                pad_token_id=self.tokenizer.eos_token_id,
            )

        output = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
        if output.endswith(stop_str):
            output = output[:-len(stop_str)].strip()
        return output

    def inference(self, image_path, prompt, slice_index=None, max_tokens=1024, temperature=0.0, top_p=0.9):
        """
        Answer ``prompt`` about one image, optionally calling an expert model.

        Args:
            image_path (str): NIfTI (``.nii``/``.nii.gz``, sliced along axis 2)
                or any image PIL can open.
            prompt (str): User question.
            slice_index (int | None): Axis-2 slice for NIfTI inputs; None or an
                out-of-range value selects the middle slice.
            max_tokens (int): ``max_new_tokens`` for each generation turn.
            temperature (float): Sampling temperature; 0 means greedy decoding.
            top_p (float): Nucleus-sampling threshold.

        Returns:
            list[dict]: Single-key dicts, in order: ``{"USER": ...}``,
            ``{"image path": ...}``, ``{"VILA-M3": first reply}``. If an expert
            runs, these follow: ``{"Expert": response}``, optionally
            ``{"Expert Image Path": ...}``, and, when the expert returned an
            instruction, ``{"Expert": instruction}`` and
            ``{"VILA-M3": follow-up reply}``. If the expert step raises,
            ``{"Expert Error": message}`` is appended.
        """
        answer = [{"USER": prompt}, {"image path": image_path}]

        if self._is_nifti(image_path):
            # Resolve once so the image shown to the VLM and the slice given
            # to the expert are guaranteed to be the same.
            slice_index = self._resolve_slice_index(image_path, slice_index)
            image = load_nifti_image(image_path, slice_axis=2, slice_idx=slice_index)
        else:
            image = Image.open(image_path).convert('RGB')

        images_tensor = process_images([image], self.image_processor, self.model.config).to(self.model.device, dtype=torch.float16)
        media_input = {"image": [img for img in images_tensor]}
        media_config = {"image": {}}

        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], f"{MODEL_CARDS}\n<image>\n{prompt}")
        conv.append_message(conv.roles[1], "")
        output = self._generate(conv, media_input, media_config, max_tokens, temperature, top_p)
        answer.append({"VILA-M3": output})

        expert_model = self._find_expert(output)
        if expert_model is None:
            return answer

        try:
            expert_img_file = [image_path]
            expert_response, expert_image_path, instruction = expert_model.run(
                image_url=expert_img_file,
                input=output,
                output_dir=self.expert_output_dir,
                img_file=expert_img_file,
                slice_index=slice_index,
                prompt=prompt,
            )

            answer.append({"Expert": expert_response})
            if expert_image_path:
                answer.append({"Expert Image Path": expert_image_path})

            if instruction:
                answer.append({"Expert": instruction})
                # NOTE(review): the follow-up turn reuses the original image
                # (media_input), not expert_image_path -- confirm intended.
                follow_up = conv_templates[self.conv_mode].copy()
                follow_up.append_message(follow_up.roles[0], f"{expert_response}\n{instruction}\n<image>")
                follow_up.append_message(follow_up.roles[1], "")
                output = self._generate(follow_up, media_input, media_config, max_tokens, temperature, top_p)
                answer.append({"VILA-M3": output})

        except Exception as e:
            # Best-effort: keep the first-turn answer, but record the failure
            # in the result instead of only printing it.
            print(f"[ERROR] Expert model encountered an error: {e}")
            answer.append({"Expert Error": str(e)})

        return answer


In [7]:
def run_batch_inference(pairs, model_path=None):
    """
    Run VILA-M3 inference over a list of samples.

    Every sample is validated before the model is loaded. A malformed entry
    therefore fails immediately rather than after the slow checkpoint load, and
    an empty ``pairs`` returns ``{}`` without loading the model at all.

    Args:
        pairs (list of dict): Each dict has keys ``'id'``, ``'image_path'``
            (a str, or a non-empty list of which only the first file is used),
            ``'question'``, and optionally ``'slice_index'``.
        model_path (str | None): Checkpoint to load; ``None`` (default) uses the
            module-level ``MODEL_PATH``.

    Returns:
        dict: ``sample_id -> list[dict]`` as returned by
        ``M3Inference.inference``. A repeated ``id`` overwrites the earlier entry.

    Raises:
        KeyError: If a sample lacks ``'id'``, ``'image_path'`` or ``'question'``.
        ValueError: If a sample's ``image_path`` is an empty list.
    """
    # Normalise and validate all samples up front.
    jobs = []
    for pair in pairs:
        sample_id = pair["id"]
        image_path = pair["image_path"]
        if isinstance(image_path, list):
            if not image_path:
                raise ValueError(f"Empty image_path list (id: {sample_id})")
            image_path = image_path[0]  # only the first file is used
        jobs.append((sample_id, image_path, pair["question"], pair.get("slice_index", None)))

    inference_result = {}
    if not jobs:
        return inference_result

    inference_model = M3Inference(model_path=MODEL_PATH if model_path is None else model_path)

    for sample_id, image_path, question, slice_index in tqdm(jobs, desc="Processing Cases", leave=True):
        inference_result[sample_id] = inference_model.inference(
            image_path=image_path,
            prompt=question,
            slice_index=slice_index,
        )

    return inference_result

In [None]:
DATA_ROOT_DIR = "/home/hufsaim/VLM/data/for_test"


def _abs_under_data_root(path):
    """Return ``path`` unchanged if it is absolute, otherwise joined onto DATA_ROOT_DIR."""
    return path if os.path.isabs(path) else os.path.join(DATA_ROOT_DIR, path)


if __name__ == "__main__":
    results = []
    if not JSON_FILE:
        raise ValueError("JSON_FILE path must be specified.")

    # Turn each dataset record into an inference input, making relative image
    # paths (a single string or a list of strings) absolute under DATA_ROOT_DIR.
    pairs = []
    for record in load_json_data(JSON_FILE):
        record_id = record["id"]
        raw_path = record["image_path"]
        record_question = record["question"]
        record_slice = record.get("slice_index", None)  # optional; None if absent

        if isinstance(raw_path, list):
            resolved_path = [_abs_under_data_root(p) for p in raw_path]
        elif isinstance(raw_path, str):
            resolved_path = _abs_under_data_root(raw_path)
        else:
            raise TypeError(f"Unexpected image_path type: {type(raw_path)} (id: {record_id})")

        pairs.append({
            "id": record_id,
            "image_path": resolved_path,
            "question": record_question,
            "slice_index": record_slice,
        })

    # Run the model over all samples, then persist the results.
    results = run_batch_inference(pairs=pairs)
    save_result_to_json(results, OUTPUT_PATH)
