In [1]:
import torch
import os
import json
import nibabel as nib
import numpy as np
from tqdm import tqdm
from PIL import Image
from experts.expert_monai_brats import ExpertBrats
from experts.expert_monai_vista3d import ExpertVista3D
from experts.expert_torchxrayvision import ExpertTXRV
from llava.model.builder import load_pretrained_model
from llava.mm_utils import KeywordsStoppingCriteria, process_images, tokenizer_image_token
from llava.conversation import conv_templates, SeparatorStyle

`torch.utils._pytree._register_pytree_node` is deprecated. Please use `torch.utils._pytree.register_pytree_node` instead.
`torch.utils._pytree._register_pytree_node` is deprecated. Please use `torch.utils._pytree.register_pytree_node` instead.


[2025-05-20 22:22:14,311] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/home/hufsaim/anaconda3/envs/vila/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
/home/hufsaim/anaconda3/envs/vila/compiler_compat/ld: /home/hufsaim/cuda-12.8/lib64/libcufile.so: undefined reference to `std::runtime_error::~runtime_error()@GLIBCXX_3.4'
/home/hufsaim/anaconda3/envs/vila/compiler_compat/ld: /home/hufsaim/cuda-12.8/lib64/libcufile.so: undefined reference to `__gxx_personality_v0@CXXABI_1.3'
/home/hufsaim/anaconda3/envs/vila/compiler_compat/ld: /home/hufsaim/cuda-12.8/lib64/libcufile.so: undefined reference to `std::ostream::tellp()@GLIBCXX_3.4'
/home/hufsaim/anaconda3/envs/vila/compiler_compat/ld: /home/hufsaim/cuda-12.8/lib64/libcufile.so: undefined reference to `std::chrono::_V2::steady_clock::now()@GLIBCXX_3.4.19'
/home/hufsaim/anaconda3/envs/vila/compiler_compat/ld: /home/hufsaim/cuda-12.8/lib64/libcufile.so: undefined reference to `std::string::_M_replace_aux(unsigned long, unsigned long, unsigned long, cha

In [2]:
# Run configuration. Each value can be overridden with an environment variable
# so the notebook runs on another machine without code edits; the defaults are
# the original hard-coded values.
MODEL_PATH = os.environ.get("VILA_M3_MODEL_PATH", "MONAI/Llama3-VILA-M3-8B")  # HF repo id or local checkpoint dir
OUTPUT_PATH = os.environ.get("VILA_M3_OUTPUT_PATH", "/home/hufsaim/VLM/VLM/m3/demo/result_new")  # result JSON files go here
JSON_FILE = os.environ.get("VILA_M3_JSON_FILE", "/home/hufsaim/VLM/VLM/m3/demo/result_new/dataset_modalitycheck.json")  # list of {"id", "image_path", "question"}
SLICE_SAVE_PATH = os.environ.get("VILA_M3_SLICE_SAVE_PATH", "/home/hufsaim/VLM/VLM/m3/demo/sliced_images")  # PNG of each NIfTI middle slice

# Expert-model catalogue given to the VLM. M3Inference.inference builds every
# user turn as f"{MODEL_CARDS}\n<image>\n{prompt}", so the model may reply with
# an expert call of the form <NAME(args)> (see the last line of the string).
# NOTE(review): in the code visible in this notebook that reply is only
# recorded; ExpertBrats / ExpertVista3D / ExpertTXRV are imported but never
# invoked, and VISTA2D and HD-Glio have no corresponding expert import at all.
# NOTE(review): unlike the other cards, the HD-Glio Overview/Accuracy lines do
# not end with the ", " separator -- confirm whether this prompt-format
# difference is intentional. Any edit to this text changes the model input.
MODEL_CARDS = """Here is a list of available expert models:\n
<BRATS(args)> 
Modality: MRI, 
Task: segmentation, 
Overview: A pre-trained model for volumetric (3D) segmentation of brain tumor subregions from multimodal MRIs based on BraTS 2018 data, 
Accuracy: Tumor core (TC): 0.8559 - Whole tumor (WT): 0.9026 - Enhancing tumor (ET): 0.7905 - Average: 0.8518, 
Valid args are: None\n
<VISTA3D(args)> 
Modality: CT, 
Task: segmentation, 
Overview: domain-specialized interactive foundation model developed for segmenting and annotating human anatomies with precision, 
Accuracy: 127 organs: 0.792 Dice on average, 
Valid args are: 'everything', 'hepatic tumor', 'pancreatic tumor', 'lung tumor', 'bone lesion', 'organs', 'cardiovascular', 'gastrointestinal', 'skeleton', or 'muscles'\n
<VISTA2D(args)> 
Modality: cell imaging, 
Task: segmentation, 
Overview: model for cell segmentation, which was trained on a variety of cell imaging outputs, including brightfield, phase-contrast, fluorescence, confocal, or electron microscopy, 
Accuracy: Good accuracy across several cell imaging datasets, 
Valid args are: None\n
<CXR(args)> 
Modality: chest x-ray (CXR), 
Task: classification, 
Overview: pre-trained model which are trained on large cohorts of data, 
Accuracy: Good accuracy across several diverse chest x-rays datasets, 
Valid args are: None\n
<HD-Glio(args)>
Modality: MRI, 
Task: segmentation, 
Overview: A deep learning-based model designed for high-grade glioma segmentation in brain MR images. HD-Glio leverages ensemble 3D U-Net architectures and robust preprocessing including brain extraction, intensity normalization, and co-registration. It is tailored to identify and delineate tumor subregions (enhancing tumor, tumor core, and whole tumor) with high accuracy.
Accuracy: Tumor core (TC): 0.860 - Whole tumor (WT): 0.910 - Enhancing tumor (ET): 0.800 - Average: 0.857
Valid args are: None\n
Give the model <NAME(args)> when selecting a suitable expert model.\n"""

In [3]:
def load_nifti_image(nifti_path, slice_axis=2, sample_id="default"):
    """Load a NIfTI file, take its middle slice and return it as an RGB image.

    The slice is min-max normalised to 0..255 (over finite voxels) and also
    written to ``SLICE_SAVE_PATH/<sample_id>.png`` so the exact image shown to
    the model can be inspected afterwards.

    Args:
        nifti_path: Path to a ``.nii`` / ``.nii.gz`` file.
        slice_axis: Axis of the 3D volume along which the middle slice is
            taken; negative values count from the end. Ignored for 2D images.
        sample_id: Filename stem for the saved PNG.

    Returns:
        PIL.Image.Image in RGB mode.

    Raises:
        RuntimeError: If nibabel cannot load the file (original error chained).
        ValueError: If the image has fewer than two dimensions.
    """
    try:
        nifti_img = nib.load(nifti_path)
        volume = nifti_img.get_fdata()
    except Exception as e:
        raise RuntimeError(f"Failed to load NIfTI file: {nifti_path}, error: {e}") from e

    # A 4D+ image (extra channel / timepoint axes) would otherwise yield a 3D
    # "slice" that PIL reads as RGB/RGBA channels or rejects outright.
    # NOTE(review): keeps the first entry of each trailing axis -- confirm this
    # is an acceptable representative for the datasets in use.
    while volume.ndim > 3:
        volume = volume[..., 0]

    if volume.ndim == 3:
        slice_idx = volume.shape[slice_axis] // 2
        # np.take indexes the same axis the size came from, including negative axes.
        slice_img = np.take(volume, slice_idx, axis=slice_axis)
    elif volume.ndim == 2:
        slice_img = volume
    else:
        raise ValueError(f"NIfTI image must have at least 2 dimensions, got shape {volume.shape}: {nifti_path}")

    # Normalise using finite voxels only so a single NaN/inf does not turn the
    # whole slice into NaN; any remaining non-finite values map to 0 / 255.
    finite = np.isfinite(slice_img)
    if finite.any():
        lo = slice_img[finite].min()
        hi = slice_img[finite].max()
    else:
        lo = hi = 0.0
    slice_norm = (slice_img - lo) / (hi - lo + 1e-8)
    slice_norm = np.nan_to_num(np.clip(slice_norm, 0.0, 1.0), nan=0.0)
    # NOTE(review): no rotation/flip is applied, so array axis 0 becomes image
    # rows; anatomy may appear transposed relative to radiological convention.
    slice_image = Image.fromarray(np.uint8(slice_norm * 255)).convert('RGB')

    os.makedirs(SLICE_SAVE_PATH, exist_ok=True)
    save_path = os.path.join(SLICE_SAVE_PATH, f"{sample_id}.png")
    slice_image.save(save_path)

    return slice_image

In [4]:
def save_result_to_json(result, output_dir, base_filename="inference_result"):
    """Write inference results to a new JSON file without overwriting old runs.

    Each ``sample_id -> [dict, dict, ...]`` item of ``result`` is flattened into
    one object ``{"id": sample_id, **dict1, **dict2, ...}`` (later keys win).
    The target is ``<base_filename>.json`` or, if taken, the first free
    ``<base_filename>_<n>.json`` for n = 1, 2, ...

    Args:
        result: Mapping of sample id to a list of dicts.
        output_dir: Directory to write into; created if missing.
        base_filename: Filename stem.

    Returns:
        Path of the file that was written.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Format first so a malformed result does not leave an empty file behind.
    formatted_results = []
    for sample_id, info_list in result.items():
        combined = {"id": sample_id}
        for entry in info_list:
            combined.update(entry)
        formatted_results.append(combined)

    counter = 0
    while True:
        suffix = "" if counter == 0 else f"_{counter}"
        filepath = os.path.join(output_dir, f"{base_filename}{suffix}.json")
        try:
            # "x" = exclusive create: atomically refuses to overwrite, closing
            # the race between checking a name and opening it.
            f = open(filepath, "x", encoding="utf-8")
        except FileExistsError:
            counter += 1
            continue
        with f:
            json.dump(formatted_results, f, ensure_ascii=False, indent=4)
        return filepath


def load_json_data(json_path):
    """Parse a JSON file and return the decoded object.

    The file is read as UTF-8 explicitly (the JSON standard encoding, and what
    save_result_to_json writes) instead of the platform's locale default.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        return json.load(f)

In [5]:
class M3Inference:
    """Single-image question answering with a VILA-M3 checkpoint.

    Each ``inference`` call starts a fresh conversation from
    ``conv_templates[conv_mode]``. Its user turn is ``MODEL_CARDS``, an
    ``<image>`` placeholder and the question; the assistant reply is decoded
    greedily by default.
    """

    def __init__(self, model_path=MODEL_PATH, conv_mode="llama_3"):
        """Load tokenizer, model and image processor onto ``cuda:0``.

        Args:
            model_path: Hugging Face repo id or local checkpoint directory.
            conv_mode: Key into ``conv_templates`` selecting the chat template.
        """
        # rstrip so a trailing slash ("ckpt/") still yields "ckpt", not "".
        model_name = model_path.rstrip("/").split("/")[-1]
        self.tokenizer, self.model, self.image_processor, _ = load_pretrained_model(
            model_path, model_name, device="cuda:0"
        )
        self.conv_mode = conv_mode

    @staticmethod
    def _load_image(image_path, sample_id):
        """Return an RGB PIL image for a NIfTI volume or an ordinary image file."""
        # Case-insensitive, so ".NII.GZ" files are not handed to PIL.
        if image_path.lower().endswith(('.nii', '.nii.gz')):
            return load_nifti_image(image_path, sample_id=sample_id)
        # The context manager closes the file handle; convert() returns an
        # independent, fully loaded image.
        with Image.open(image_path) as img:
            return img.convert('RGB')

    def inference(self, image_path, prompt, sample_id="default",
                  temperature=0.0, top_p=0.9, max_new_tokens=1024):
        """Answer ``prompt`` about the image at ``image_path``.

        Args:
            image_path: NIfTI (``.nii``/``.nii.gz``) or PIL-readable image path.
            prompt: User question; ``MODEL_CARDS`` is prepended automatically.
            sample_id: Forwarded to ``load_nifti_image`` (names the saved slice).
            temperature: 0 (default) means greedy decoding; > 0 enables sampling.
            top_p: Nucleus-sampling threshold; only used when sampling.
            max_new_tokens: Generation length cap.

        Returns:
            ``[{"USER": prompt}, {"image path": image_path}, {"VILA-M3": reply}]``.
        """
        answer = [{"USER": prompt}, {"image path": image_path}]
        conv = conv_templates[self.conv_mode].copy()

        image = self._load_image(image_path, sample_id)
        images_tensor = process_images([image], self.image_processor, self.model.config).to(
            self.model.device, dtype=torch.float16
        )
        media_input = {"image": list(images_tensor)}
        media_config = {"image": {}}

        full_prompt = f"{MODEL_CARDS}\n<image>\n{prompt}"
        conv.append_message(conv.roles[0], full_prompt)
        conv.append_message(conv.roles[1], "")
        prompt_text = conv.get_prompt()
        input_ids = tokenizer_image_token(
            prompt_text, self.tokenizer, return_tensors="pt"
        ).unsqueeze(0).to(self.model.device)

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        stopping_criteria = KeywordsStoppingCriteria([stop_str], self.tokenizer, input_ids)

        # Pass do_sample explicitly so greedy decoding does not depend on the
        # checkpoint's generation_config; temperature/top_p only matter when sampling.
        if temperature > 0:
            sampling_kwargs = {"do_sample": True, "temperature": temperature, "top_p": top_p}
        else:
            sampling_kwargs = {"do_sample": False}

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                media=media_input,
                media_config=media_config,
                max_new_tokens=max_new_tokens,
                use_cache=True,
                stopping_criteria=[stopping_criteria],
                pad_token_id=self.tokenizer.eos_token_id,
                **sampling_kwargs,
            )

        output = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
        # Guard the empty separator: output[:-0] would wipe the whole reply.
        if stop_str and output.endswith(stop_str):
            output = output[:-len(stop_str)].strip()

        answer.append({"VILA-M3": output})
        return answer

In [6]:
def run_batch_inference(pairs, inference_model=None, continue_on_error=False):
    """Run the VLM on every (image, question) pair and collect the transcripts.

    Args:
        pairs: Iterable of dicts with keys "id", "image_path" (str, or a list
            of str of which only the first is used) and "question".
        inference_model: Object exposing
            ``inference(image_path=..., prompt=..., sample_id=...)``. Defaults
            to a freshly loaded ``M3Inference()``; pass one in to reuse an
            already-loaded model across calls.
        continue_on_error: If True, a sample whose inference raises is recorded
            with an "error" entry and the batch continues. If False (default),
            the exception propagates.

    Returns:
        Dict mapping sample id to the list returned by ``inference`` (or, for
        a failed sample, ``[{"USER"}, {"image path"}, {"error"}]``).

    Raises:
        ValueError: If a sample's image_path is an empty list.
    """
    if inference_model is None:
        inference_model = M3Inference()

    inference_result = {}
    for pair in tqdm(pairs, desc="Processing Cases"):
        sample_id = pair["id"]
        image_path = pair["image_path"]
        question = pair["question"]

        if isinstance(image_path, list):
            if not image_path:
                raise ValueError(f"Empty image_path list (id: {sample_id})")
            # The model takes one image per question; extra paths are ignored.
            image_path = image_path[0]

        try:
            inference_result[sample_id] = inference_model.inference(
                image_path=image_path,
                prompt=question,
                sample_id=sample_id
            )
        except Exception as e:
            if not continue_on_error:
                raise
            inference_result[sample_id] = [
                {"USER": question},
                {"image path": image_path},
                {"error": f"{type(e).__name__}: {e}"},
            ]

    return inference_result

In [7]:
DATA_ROOT_DIR = "/home/hufsaim/VLM/data/for_test"  # base for relative image paths in JSON_FILE


def _resolve_image_path(image_path, sample_id):
    """Join relative path(s) onto DATA_ROOT_DIR; accepts a str or a list of str."""
    if isinstance(image_path, list):
        return [os.path.join(DATA_ROOT_DIR, p) if not os.path.isabs(p) else p for p in image_path]
    if isinstance(image_path, str):
        return image_path if os.path.isabs(image_path) else os.path.join(DATA_ROOT_DIR, image_path)
    raise TypeError(f"Unexpected image_path type: {type(image_path)} (id: {sample_id})")


if __name__ == "__main__":
    if not JSON_FILE:
        raise ValueError("JSON_FILE path must be specified.")

    data = load_json_data(JSON_FILE)
    pairs = []
    for sample in data:
        sample_id = sample["id"]
        pairs.append({
            "id": sample_id,
            "image_path": _resolve_image_path(sample["image_path"], sample_id),
            "question": sample["question"],
        })

    # Fail fast: verify the image each case will actually read (the first one
    # for list entries) before spending minutes loading the 8B model.
    missing_ids = []
    for pair in pairs:
        path = pair["image_path"]
        if isinstance(path, list):
            path = path[0] if path else None
        if path is None or not os.path.isfile(path):
            missing_ids.append(pair["id"])
    if missing_ids:
        raise FileNotFoundError(
            f"{len(missing_ids)} image file(s) not found under {DATA_ROOT_DIR}; ids: {missing_ids[:10]}"
        )

    results = run_batch_inference(pairs=pairs)
    save_result_to_json(results, OUTPUT_PATH)

Fetching 21 files:   0%|          | 0/21 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Processing Cases:   0%|          | 0/48 [00:00<?, ?it/s]