LLaVA-Med source code: https://github.com/microsoft/LLaVA-Med \
LLaVA-Med v1.5 (Mistral-7B) model weights: https://huggingface.co/microsoft/llava-med-v1.5-mistral-7b

In [1]:
import os
import sys

In [2]:
# Append the local LLaVA-Med checkout (expected in the current working directory)
# to the system path so that the `llava` package can be imported from it
sys.path.append(os.path.join(os.getcwd(), "LLaVA-Med"))

In [3]:
# Set CUDA_VISIBLE_DEVICES to expose only device 0.
# This runs before `torch` is imported in the next cell; presumably it has to stay
# ahead of any CUDA initialisation to take effect — TODO confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Set CUDA_VISIBLE_DEVICES to expose devices 0, 1, 2, and 3
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

In [4]:
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, process_images

from PIL import Image
import math
from transformers import set_seed, logging
import json
import torch

  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [5]:
def split_list(lst, n):
    """Split a list into at most ``n`` contiguous, roughly equal-sized chunks.

    The chunk size is ``ceil(len(lst) / n)``, so every chunk except possibly
    the last has the same length. Because of the rounding up, fewer than ``n``
    chunks can come back (4 items with ``n=3`` gives 2 chunks of 2).

    Args:
        lst: Sequence to split; must support ``len`` and slicing.
        n: Requested number of chunks; expected to be a positive integer.

    Returns:
        A list of slices of ``lst`` in their original order. An empty ``lst``
        yields an empty list.
    """
    chunk_size = math.ceil(len(lst) / n)  # ceiling division
    if chunk_size == 0:
        # Empty input: range() rejects a zero step, so return early.
        return []
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    """Return the ``k``-th of the ``n`` shards produced by ``split_list``.

    ``split_list`` sizes chunks with a ceiling division, so it can return
    fewer than ``n`` chunks (for example 4 items with ``n=3`` gives only 2).
    A shard index that is valid for ``n`` but has no chunk gets an empty list
    instead of raising ``IndexError``.

    Args:
        lst: Sequence to shard.
        n: Total number of shards requested.
        k: Index of the shard to return.

    Returns:
        The ``k``-th chunk, or ``[]`` when ``k < n`` but no such chunk exists.

    Raises:
        IndexError: If ``k`` is outside the range of the ``n`` requested shards.
    """
    chunks = split_list(lst, n)
    if len(chunks) <= k < n:
        # Valid shard index, but the rounding left nothing for it.
        return []
    return chunks[k]

In [6]:
# Seed the random number generators through transformers' set_seed so that the
# sampled generation further down is repeatable
set_seed(0)
# Helper imported from llava.utils; presumably it skips default weight
# initialisation to speed up model creation — confirm against llava/utils.py
disable_torch_init()

In [7]:
# Hugging Face repository id of the LLaVA-Med checkpoint. expanduser leaves it
# unchanged; it only matters if this is swapped for a local "~/..." path.
model_path = os.path.expanduser("microsoft/llava-med-v1.5-mistral-7b")

In [8]:
# Derive the model name from the checkpoint path, then load the tokenizer,
# model, image processor and context length in one call. No base model is
# supplied (model_base is None).
model_name = get_model_name_from_path(model_path)
model_base = None
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path,
    model_base,
    model_name,
)

Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.85it/s]
Some weights of the model checkpoint at microsoft/llava-med-v1.5-mistral-7b were not used when initializing LlavaMistralForCausalLM: ['model.vision_tower.vision_tower.vision_model.encoder.layers.17.self_attn.v_proj.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.16.self_attn.out_proj.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.10.mlp.fc2.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.3.mlp.fc1.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.10.self_attn.k_proj.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.8.self_attn.out_proj.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.21.self_attn.q_proj.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.21.mlp.fc1.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.3.self_attn.k_proj.bias', 'model.vision_tower.vis

In [10]:
question_file = "llava_med_questions.jsonl"
# Read the JSON Lines file: one JSON object per line. The context manager
# closes the file handle as soon as parsing is done, and blank lines (such as
# a trailing newline at the end of the file) are skipped instead of crashing
# json.loads.
with open(os.path.expanduser(question_file), "r") as question_fh:
    questions = [json.loads(q) for q in question_fh if q.strip()]

In [11]:
# Shard selection: split the questions into `num_chunks` parts and keep part
# `chunk_idx`. With a single chunk, index 0 is the whole list.
num_chunks, chunk_idx = 1, 0
questions = get_chunk(lst=questions, n=num_chunks, k=chunk_idx)

In [12]:
answers_file = "./llava_med_answers.jsonl"
answers_file = os.path.expanduser(answers_file)
# os.path.dirname returns "" for a bare filename and os.makedirs("") raises,
# so only create the parent directory when the path actually names one.
answers_dir = os.path.dirname(answers_file)
if answers_dir:
    os.makedirs(answers_dir, exist_ok=True)

In [13]:
# Open the answers file for writing; mode "w" truncates any existing file.
# NOTE(review): the handle is never written to or closed in the visible cells —
# close it (or use a `with` block) once the answers have been written.
ans_file = open(answers_file, "w")

In [14]:
# Walk through a single example: take the first question record
line = questions[0]

In [15]:
# Unpack the fields of the question record
idx = line["question_id"]  # not used again in the visible cells
image_file = line["image"]  # joined with image_folder further down to locate the image
qs = line["text"]
cur_prompt = qs  # copy of the raw question text, taken before the image placeholder is prepended to qs

In [16]:
# Prepend the image placeholder to the question. The text is built from
# `cur_prompt` (the untouched question saved when the record was unpacked)
# rather than from `qs` itself, so re-running this cell does not stack a
# second placeholder on top of the first.
if model.config.mm_use_im_start_end:
    # Variant that wraps the placeholder in explicit start/end marker tokens.
    qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + cur_prompt
else:
    qs = DEFAULT_IMAGE_TOKEN + '\n' + cur_prompt

In [17]:
# Show the question with the image placeholder prepended
qs

'<image>\nExplain the image in detail'

In [18]:
# Key of the conversation template to look up in llava.conversation.conv_templates
conv_mode = "mistral_instruct"
# Work on a copy so that the messages appended below stay out of the template entry
conv = conv_templates[conv_mode].copy()

In [19]:
# The first role carries the question; the second role is added with no text,
# leaving the reply slot open for the model to fill.
# NOTE: re-running this cell appends both messages again — re-create `conv` first.
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)

In [20]:
# Inspect the conversation state after both messages were appended
conv

Conversation(system='', roles=('USER', 'ASSISTANT'), messages=[['USER', '<image>\nExplain the image in detail'], ['ASSISTANT', None]], offset=0, sep_style=<SeparatorStyle.LLAMA_2: 5>, sep='', sep2='</s>', version='llama_v2', skip_next=False)

In [21]:
# Render the conversation into the final prompt string and display it
prompt = conv.get_prompt()
prompt

'[INST] <image>\nExplain the image in detail [/INST]'

In [22]:
# Tokenise the prompt (IMAGE_TOKEN_INDEX is passed for the image placeholder),
# add a leading batch dimension and move the tensor to the GPU.
input_ids = (
    tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
    .unsqueeze(0)
    .cuda()
)

In [23]:
# Inspect the token ids (a batch of one, on the GPU)
input_ids

tensor([[    1,   733, 16289, 28793, 28705,  -200, 28705,    13,   966, 19457,
           272,  3469,   297,  8291,   733, 28748, 16289, 28793]],
       device='cuda:0')

In [24]:
# Directory holding the images referenced by the question records
image_folder="./pathology_images"
# Open the image named by the current question record
image = Image.open(os.path.join(image_folder, image_file))

In [25]:
# Preprocess the image with the model's image processor. process_images is given
# a one-element list here, and the first (only) result is kept.
image_tensor = process_images([image], image_processor, model.config)[0]

In [26]:
# Choose the stop string from the conversation template: the second separator
# for the TWO separator style, the primary separator for every other style.
if conv.sep_style == SeparatorStyle.TWO:
    stop_str = conv.sep2
else:
    stop_str = conv.sep
keywords = [stop_str]
# NOTE(review): this object is built here but is not passed to model.generate
# in the visible cells.
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

In [27]:
# Sampling temperature, kept in one place so that `do_sample` and `temperature`
# cannot drift apart. A value of 0 turns sampling off.
temperature = 0.2

# Generate the answer under inference mode (no autograd bookkeeping). The image
# tensor gets a batch dimension, is cast to half precision and moved to the GPU.
with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor.unsqueeze(0).half().cuda(),
        do_sample=temperature > 0,
        temperature=temperature,
        top_p=None,
        num_beams=1,
        # no_repeat_ngram_size=3,
        max_new_tokens=1024,
        use_cache=True)


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


In [28]:
# Decode the generated ids to text with special tokens dropped, take the only
# sequence in the batch and trim surrounding whitespace
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()

In [29]:
# Show the generated answer
outputs

'The image is a histopathological examination of a tissue sample, which is stained with hematoxylin and eosin (H&E). Histopathology is the study of diseased tissue under a microscope, and it helps in the diagnosis of various diseases, including cancer. In this case, the image is showing the presence of a tumor, which is an abnormal growth of cells that can invade surrounding tissues and potentially spread to other parts of the body.'

In [None]:
# A `!export ...` line runs in a throwaway subshell, so the variable never
# reaches the kernel or any later `!` command. Setting os.environ instead makes
# it visible to every subprocess started from this notebook. Empty entries are
# dropped so that an unset PYTHONPATH does not leave a trailing separator.
llava_med_root = "/data/mn27889/path-open-data/LLaVA-Med"
os.environ["PYTHONPATH"] = os.pathsep.join(
    entry for entry in (llava_med_root, os.environ.get("PYTHONPATH", "")) if entry
)

In [None]:
# Run the batch VQA evaluation script on the question file, on GPU 0, writing
# the answers to a JSONL file.
# NOTE(review): the script path and the question/answer paths are relative — this
# looks like it expects the working directory to be the LLaVA-Med checkout;
# confirm (or change directory first) before running.
!CUDA_VISIBLE_DEVICES=0 python llava/eval/model_vqa.py \
--model-path microsoft/llava-med-v1.5-mistral-7b \
--conv-mode mistral_instruct \
--image-folder=/data/mn27889/path-open-data/pathology_images \
--question-file ./../llava_med_questions.jsonl \
--answers-file ./../llava_med_answers.jsonl