## Install

In [1]:
# !nvidia-smi

In [None]:
!pip install --upgrade transformers==4.37.2

In [None]:
# !pip install git+https://github.com/huggingface/transformers.git

In [None]:
# import transformers
# print(transformers.__version__)

## Download model

In [None]:
from huggingface_hub import snapshot_download

# Pre-fetch the checkpoint into the default Hugging Face cache.
# No cache_dir is passed on purpose: an empty string is not None, so it would
# redirect the download to a folder relative to the working directory, while
# the model is later loaded by repo id (presumably through the default cache —
# confirm in llava.model.builder), leaving that copy unused.
snapshot_download(repo_id="microsoft/llava-med-v1.5-mistral-7b",
                  resume_download=True,)

In [2]:
# !cd /content/models--liuhaotian--llava-v1.5-7b

## Initialize the model

In [3]:
# !pip install shortuuid

In [None]:
import torch
import os
import json
from tqdm import tqdm
# import shortuuid
import numpy as np

In [5]:
from google.colab import drive
# Mount Google Drive at /content/drive (Colab-only helper) so that files stored
# there can be reached through the local file system.
drive.mount('/content/drive')

Mounted at /content/drive


In [6]:
import sys
# Add the LLaVA-Med folder on Drive to the import path; presumably this is
# where the `llava` package imported in the next cell lives — confirm.
sys.path.append('/content/drive/MyDrive/LLaVA-Med')

In [7]:
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, process_images

In [8]:
from PIL import Image
import math
from transformers import set_seed, logging

In [9]:
# `logging` here is transformers.logging (imported above), not the stdlib
# module: only error-level messages from transformers are shown.
logging.set_verbosity_error()

In [None]:
def split_list(lst, n):
    """Split a list into at most ``n`` consecutive, roughly equal-sized chunks.

    Args:
        lst: The list (or other sliceable sequence) to partition.
        n: Desired number of chunks.

    Returns:
        A list of slices of ``lst`` in their original order. Every chunk except
        possibly the last holds ``ceil(len(lst) / n)`` items, so fewer than
        ``n`` chunks may be returned (4 items with ``n=3`` give 2 chunks).
        An empty ``lst`` gives ``[]``.

    Raises:
        ZeroDivisionError: If ``n`` is 0 and ``lst`` is not empty.
    """
    if len(lst) == 0:
        # ceil(0 / n) is 0, and range() rejects a step of 0.
        return []
    chunk_size = math.ceil(len(lst) / n)  # ceiling division, not floor
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]

def get_chunk(lst, n, k):
    """Return the ``k``-th chunk of ``split_list(lst, n)``.

    ``k`` is used as a plain list index, so negative values count from the end
    and an out-of-range ``k`` raises ``IndexError``.
    """
    return split_list(lst, n)[k]

In [11]:
# Fix the random seed (transformers.set_seed) so repeated runs are comparable.
set_seed(0)
# Model
# NOTE(review): presumably skips PyTorch's default weight initialisation to
# speed up loading — confirm in llava.utils.
disable_torch_init()
# Identifier of the checkpoint handed to load_pretrained_model.
model_path = 'microsoft/llava-med-v1.5-mistral-7b'
# expanduser only rewrites a leading '~', so this value passes through unchanged.
model_path = os.path.expanduser(model_path)
model_name = get_model_name_from_path(model_path)
# No separate base model is passed to load_pretrained_model.
model_base = None

In [None]:
# Load the tokenizer, model and image processor that later cells use as
# notebook-level globals (run_llava_med reads them directly).
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_base, model_name)

## Data Preparation

In [None]:
# Colab
dir = 'Data/'
json_path = dir + 'qa_dataset_test.json'
with open(json_path, "r") as f:
    qa_dataset_test = json.load(f)

In [14]:
def generate_text_image_pair(qa_dataset, outcome, image_dir):
    """Build one text prompt and one image path per QA record.

    Args:
        qa_dataset: Sequence of records. Each record is read through the keys
            ``'context'``, ``'question'`` (itself indexed by ``outcome``) and
            ``'image'``.
        outcome: Key selecting which question of each record is asked
            (e.g. ``'sepsis3'``).
        image_dir: Directory holding the image files. Only the base name of a
            record's ``'image'`` value is kept and joined onto this directory.

    Returns:
        A ``(prompts, image_paths)`` tuple of two lists aligned with
        ``qa_dataset``. ``image_paths[i]`` is ``None`` when record ``i`` has
        no image.
    """
    prompts = []
    image_paths = []

    # Generate user prompts for all QA pairs
    for qa_pair in tqdm(qa_dataset, total=len(qa_dataset)):
        # Text. A literal "<image>" marker sits between context and question.
        # NOTE(review): the last two fragments are concatenated without a space
        # or newline ("...answer.Answer the question...") — confirm this is
        # intended before changing it, since it alters the model input.
        user_text = (
            f"Based on the information collected during current ICU stay, \n{qa_pair['context']}\n<image>"
            f"{qa_pair['question'][outcome]}\n"
            f"Provide a confident and definitive answer."
            f"Answer the question using only yes or no without any additional explanation"
        )
        prompts.append(user_text)

        # Images: keep list positions aligned with the prompts by storing None
        # for records without an image.
        image_ref = qa_pair['image']
        if image_ref is not None:
            image_paths.append(os.path.join(image_dir, os.path.basename(image_ref)))
        else:
            image_paths.append(None)

    return prompts, image_paths

In [15]:
# Image folder inside the data directory.
image_dir = dir + 'jpg_test/'
# Key into each record's 'question' mapping; selects which question is asked.
outcome = 'sepsis3'

In [16]:
# Build the aligned prompt and image-path lists for the test set.
prompts, images = generate_text_image_pair(qa_dataset_test, outcome, image_dir)

100%|██████████| 14637/14637 [00:00<00:00, 550606.08it/s]


## Run

In [18]:
def run_llava_med(prompts,images,output_dir,outcome,start_idx=0,end_idx=None,
                  conv_mode="mistral_instruct",temperature=0,top_p = None,num_beams = 1,
                  checkpoint_every=1000):
    """Run LLaVA-Med on a slice of prompts/images and checkpoint the answers.

    Relies on the notebook-level ``model``, ``tokenizer`` and
    ``image_processor`` objects and on a CUDA device (inputs are moved with
    ``.cuda()``).

    Args:
        prompts: List of prompt strings. Any ``DEFAULT_IMAGE_TOKEN`` inside a
            prompt is removed and a single image token is prepended instead.
        images: List aligned with ``prompts``. Each entry is an image file
            path, or ``None`` to use a black 224x224 placeholder image.
        output_dir: Directory for the ``.npz`` checkpoints; created if missing.
        outcome: Label embedded in the checkpoint file names.
        start_idx: Start of the slice applied to ``prompts`` and ``images``.
        end_idx: End (exclusive) of that slice; ``None`` means "to the end".
        conv_mode: Key into ``conv_templates``.
        temperature: Passed to ``model.generate``; sampling is enabled only
            when it is greater than 0.
        top_p: Passed to ``model.generate``.
        num_beams: Passed to ``model.generate``.
        checkpoint_every: Save after every this many processed items, and
            always after the last one. Must be a positive integer.

    Returns:
        None. Results are written to
        ``llavamed_{outcome}_output_{start_idx}_{start_idx + done}.npz`` in
        ``output_dir``. Each file holds all outputs produced so far (not only
        the latest batch) under the key ``array``.

    Raises:
        ValueError: If ``checkpoint_every`` is not positive, or if the sliced
            image list is shorter than the sliced prompt list.
    """
    if checkpoint_every is None or checkpoint_every <= 0:
        raise ValueError(f"checkpoint_every must be a positive integer, got {checkpoint_every!r}")

    prompts_to_run = prompts[start_idx:end_idx]
    images_to_run = images[start_idx:end_idx]
    nfiles = len(prompts_to_run)
    if len(images_to_run) < nfiles:
        # Fail before any inference instead of part-way through the run.
        raise ValueError(
            f"{nfiles} prompts selected but only {len(images_to_run)} images")

    # Create the checkpoint directory up front so the first save cannot fail
    # after a long stretch of inference. An empty string means the current
    # directory, which makedirs would reject.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    outputs = []

    with tqdm(total=nfiles) as pbar:
        for idx in range(nfiles):
            image_file = images_to_run[idx]

            # Prompt: strip any image token from the text, then prepend exactly one.
            qs = prompts_to_run[idx].replace(DEFAULT_IMAGE_TOKEN, '').strip()
            if model.config.mm_use_im_start_end:
                qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
            else:
                qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

            conv = conv_templates[conv_mode].copy()
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

            # Image
            if image_file is not None:
                # resize() returns a new in-memory image, so the file handle
                # can be released straight away instead of piling up.
                with Image.open(image_file) as source_image:
                    image = source_image.resize((512, 512))
            else:
                # Black placeholder for records without an image.
                image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
            image_tensor = process_images([image], image_processor, model.config)[0]

            # A KeywordsStoppingCriteria used to be built here but was never
            # passed to generate(); it had no effect and is omitted.

            # Inference
            with torch.inference_mode():
                output_ids = model.generate(
                    input_ids,
                    images=image_tensor.unsqueeze(0).half().cuda(),
                    do_sample=temperature > 0,
                    temperature=temperature,
                    top_p=top_p,
                    num_beams=num_beams,
                    # no_repeat_ngram_size=3,
                    max_new_tokens=20,
                    use_cache=True)

            output = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
            outputs.append(output)
            pbar.update(1)

            # Save checkpoint (cumulative: everything generated so far).
            done = idx + 1
            if done % checkpoint_every == 0 or done == nfiles:
                output_path = os.path.join(output_dir, f"llavamed_{outcome}_output_{start_idx}_{start_idx+done}.npz")
                np.savez(output_path, array=np.array(outputs))
                print(f"Saved checkpoint: {output_path}")

In [None]:
# Directory in which run_llava_med writes its .npz checkpoint files.
output_dir = 'Outputs/'

In [None]:
# Run inference over every prompt (defaults: start_idx=0, end_idx=None).
run_llava_med(prompts,images,output_dir,outcome)