### installations, clones and imports

In [None]:
%%capture

!pip install -q transformers datasets peft bitsandbytes wandb trl

In [None]:
!git clone https://github.com/microsoft/LLaVA-Med.git LLaVA_Med

Cloning into 'LLaVA_Med'...
remote: Enumerating objects: 429, done.[K
remote: Counting objects: 100% (41/41), done.[K
remote: Compressing objects: 100% (36/36), done.[K
remote: Total 429 (delta 1), reused 31 (delta 1), pack-reused 388 (from 1)[K
Receiving objects: 100% (429/429), 77.09 MiB | 27.58 MiB/s, done.
Resolving deltas: 100% (122/122), done.


In [None]:
import os
os.chdir("/content/LLaVA_Med")

os.getcwd()

'/content/LLaVA_Med'

In [None]:
import warnings
warnings.filterwarnings("ignore")

import torch
from torch.utils.data.dataset import Dataset
from torch.nn.utils.rnn import pad_sequence
from datasets import load_dataset
from PIL import Image
from io import BytesIO
import requests
import json
import uuid

from transformers import Trainer, TrainingArguments
from peft import LoraConfig, LoraModel, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer

from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import Conversation
from llava.mm_utils import tokenizer_image_token, process_images
from llava.model.builder import load_pretrained_model
from llava.conversation import conv_templates

from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

### data preparation and loading

In [None]:
def process_and_save(dataset, output_folder, subset_name):
  subset_folder = os.path.join(output_folder, subset_name)
  image_subfolder = os.path.join(subset_folder, "images")

  if not os.path.exists(image_subfolder):
    os.makedirs(image_subfolder, exist_ok=True)

  if not os.path.exists(subset_folder):
    os.makedirs(subset_folder)

  json_data_list = []

  for item in dataset:
    if isinstance(item["image"], str):
      response = requests.get(item["image"])
      image = Image.open(BytesIO(response.content))
    else:
      image = item["image"]

    unique_id = str(uuid.uuid4())

    image_path = os.path.join(image_subfolder, f"{unique_id}.jpg")
    image.save(image_path)

    answers = item["answer"]
    formatted_answers = "".join(answers)

    json_data = {
        "id": unique_id,
        "image": f"{unique_id}.jpg",
        "conversations": [
            {
                "from": "human",
                "value": item["question"]
            },
            {
                "from": "gpt",
                "value": formatted_answers
            }
        ]
    }

    json_data_list.append(json_data)

  json_output_path = os.path.join(output_folder, subset_name, "dataset.json")
  with open(json_output_path, "w") as json_file:
    json.dump(json_data_list, json_file, indent=4)


def save_dataset(dataset_name, output_folder, subset_name):
  dataset = load_dataset(dataset_name)

  process_and_save(dataset[subset_name], output_folder, subset_name)

In [None]:
output_folder = "/content/drive/MyDrive/dataset"
save_dataset("flaviagiammarino/vqa-rad", output_folder, 'train')
save_dataset("flaviagiammarino/vqa-rad", output_folder, 'test')

README.md:   0%|          | 0.00/3.91k [00:00<?, ?B/s]

(…)-00000-of-00001-eb8844602202be60.parquet:   0%|          | 0.00/24.2M [00:00<?, ?B/s]

(…)-00000-of-00001-e5bc3d208bb4deeb.parquet:   0%|          | 0.00/10.3M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1793 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/451 [00:00<?, ? examples/s]

In [None]:
class VQARAD(Dataset):
  def __init__(self, split):
    super(VQARAD, self).__init__()
    self.split = split
    self.image_folder = f'/content/drive/MyDrive/dataset/{self.split}/images'
    self.paths = {
        'train': '/content/drive/MyDrive/dataset/train/dataset.json',
        'test': '/content/drive/MyDrive/dataset/test/dataset.json'
    }

    with open(self.paths[self.split], 'r') as f:
      self.dataset = json.load(f)

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, idx):
    item = self.dataset[idx]
    id = item['id']
    question = item['conversations'][0]['value']
    answer = item['conversations'][1]['value']
    image_path = item['image']
    image = Image.open(os.path.join(self.image_folder, image_path)).convert('RGB')

    return id, question, answer, image

In [None]:
class DataCollator:
  def __init__(self, tokenizer, split, conversation_template, pad_token_id, image_processor, model_config):
    self.tokenizer = tokenizer
    self.split = split
    self.conversation_template = conversation_template
    self.pad_token_id = pad_token_id
    self.image_processor = image_processor
    self.model_config = model_config

  def __call__(self, rows):
    if self.split == "train":
      return self._collate_train(rows)
    elif self.split == "test":
      return self._collate_test(rows)
    else:
      return ValueError(f"Invalid split: {self.split}")

  def _collate_train(self, rows):
    train_input_ids_list=[]
    train_labels_list=[]
    train_images=[]

    for row in rows:
      id, question, answer, image = row
      train_images.append(image)

      question = question.replace(DEFAULT_IMAGE_TOKEN, '').strip()
      question = DEFAULT_IMAGE_TOKEN + '\n' + question


      conv = self.conversation_template.copy()
      conv.append_message(conv.roles[0], question)
      conv.append_message(conv.roles[1], None)
      prefix = conv.get_prompt()

      conv = self.conversation_template.copy()
      conv.append_message(conv.roles[0], question)
      conv.append_message(conv.roles[1], answer)
      full = conv.get_prompt()

      prefix = tokenizer_image_token(prefix, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
      full = tokenizer_image_token(full, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")

      prefix_length = prefix.size(0)
      full_length = full.size(0)

      train_input_ids = full
      train_labels = full.clone()
      train_labels[:prefix_length] = -100

      train_input_ids_list.append(train_input_ids)
      train_labels_list.append(train_labels)

    pad_value = -114514
    train_input_ids = pad_sequence(train_input_ids_list, batch_first=True, padding_value=pad_value)
    train_labels = pad_sequence(train_labels_list, batch_first=True, padding_value=pad_value)
    train_attention_mask = (train_input_ids != pad_value).long()

    train_input_ids[train_input_ids == pad_value] = self.pad_token_id
    train_labels[train_labels == pad_value] = self.pad_token_id

    train_images = process_images(train_images, self.image_processor, self.model_config).to(torch.bfloat16)

    return {
        "input_ids": train_input_ids,
        "labels": train_labels,
        "attention_mask": train_attention_mask,
        "images": train_images
    }

  def _collate_test(self, rows):
    pass

In [None]:
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="microsoft/llava-med-v1.5-mistral-7b",
    model_base="llava-med-v1.5-mistral-7b",
    model_name="llava-med-v1.5-mistral-7b"
    )

vqa_rad_dataset_train = VQARAD(split="train")
vqa_rad_dataset_test = VQARAD(split="test")

conv = conv_templates["mistral_instruct"]

collate_fn = DataCollator(
    tokenizer=tokenizer,
    split="train",
    conversation_template=conv,
    pad_token_id=tokenizer.pad_token_id,
    image_processor=image_processor,
    model_config=model.config
)

tokenizer_config.json:   0%|          | 0.00/1.46k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/438 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.41k [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/73.2k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.93G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/262M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/4.76k [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Some weights of the model checkpoint at microsoft/llava-med-v1.5-mistral-7b were not used when initializing LlavaMistralForCausalLM: ['model.vision_tower.vision_tower.vision_model.embeddings.class_embedding', 'model.vision_tower.vision_tower.vision_model.embeddings.patch_embedding.weight', 'model.vision_tower.vision_tower.vision_model.embeddings.position_embedding.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.layer_norm1.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.layer_norm1.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.layer_norm2.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.layer_norm2.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.mlp.fc1.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.mlp.fc1.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.mlp.fc2.bias', 'model.vision_tower.vision_tower.vision_model.encoder.la

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.71G [00:00<?, ?B/s]

In [None]:
model.train()

model.gradient_checkpointing_enable()

model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=8,
    target_modules=["k_proj", "q_proj", "v_proj", "out_proj"],
    lora_alpha=16
)

peft_model = get_peft_model(model, lora_config, "default")
peft_model.print_trainable_parameters()

trainable params: 6,291,456 || all params: 7,572,510,720 || trainable%: 0.0831


In [None]:
training_args = TrainingArguments(
    output_dir="trained_llava-med",
    report_to="wandb",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    logging_steps=5,
    learning_rate=2e-5,
    logging_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=5,
    warmup_ratio=0.03,
    weight_decay=0.01,
    remove_unused_columns=True,
    gradient_checkpointing=True,
    fp16=False,
    bf16=True,
    optim='paged_adamw_8bit')

trainer = Trainer(
    model=peft_model,
    args=training_args,
    train_dataset=vqa_rad_dataset_train,
    data_collator=collate_fn
)

peft_model.config.use_cache = False
trainer.train()

Step,Training Loss
1793,0.4476
3586,0.3456
5379,0.2196
7172,0.1411
8965,0.0777


TrainOutput(global_step=8965, training_loss=0.24632015877174254, metrics={'train_runtime': 3628.8625, 'train_samples_per_second': 2.47, 'train_steps_per_second': 2.47, 'total_flos': 1.058735921163264e+16, 'train_loss': 0.24632015877174254, 'epoch': 5.0})

In [None]:
from transformers import AutoModelForCausalLM
from peft import PeftModel

trainer.save_model('/content/drive/MyDrive/lora_trained_model')
base_model = AutoModelForCausalLM.from_pretrained('microsoft/llava-med-v1.5-mistral-7b')

trained_model = PeftModel.from_pretrained(base_model, '/content/drive/MyDrive/lora_trained_model')
merged_trained_model = trained_model.merge_and_unload()
merged_trained_model.save_pretrained('/content/drive/MyDrive/merged_trained_model', push_to_hub=True, repo_id="onyekaokonji/llava-med-v1.5-mistral-7b-oo")

tokenizer.save_pretrained('/content/drive/MyDrive/merged_trained_model', push_to_hub=True, repo_id='onyekaokonji/llava-med-v1.5-mistral-7b-oo')

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Some weights of the model checkpoint at microsoft/llava-med-v1.5-mistral-7b were not used when initializing LlavaMistralForCausalLM: ['model.vision_tower.vision_tower.vision_model.embeddings.class_embedding', 'model.vision_tower.vision_tower.vision_model.embeddings.patch_embedding.weight', 'model.vision_tower.vision_tower.vision_model.embeddings.position_embedding.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.layer_norm1.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.layer_norm1.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.layer_norm2.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.layer_norm2.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.mlp.fc1.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.mlp.fc1.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.mlp.fc2.bias', 'model.vision_tower.vision_tower.vision_model.encoder.la

Upload 6 LFS files:   0%|          | 0/6 [00:00<?, ?it/s]

model-00005-of-00006.safetensors:   0%|          | 0.00/4.83G [00:00<?, ?B/s]

model-00004-of-00006.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00002-of-00006.safetensors:   0%|          | 0.00/4.90G [00:00<?, ?B/s]

model-00001-of-00006.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00003-of-00006.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00006-of-00006.safetensors:   0%|          | 0.00/4.33G [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

('/content/drive/MyDrive/merged_trained_model/tokenizer_config.json',
 '/content/drive/MyDrive/merged_trained_model/special_tokens_map.json',
 '/content/drive/MyDrive/merged_trained_model/tokenizer.model',
 '/content/drive/MyDrive/merged_trained_model/added_tokens.json',
 '/content/drive/MyDrive/merged_trained_model/tokenizer.json')

In [None]:
# trainer.model.save_pretrained("/content/drive/MyDrive/trained_modelss")