In [None]:
#    Copyright 2023 Haotian Liu
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

In [None]:
# Mount Google Drive into the Colab filesystem at /content/drive.
from google.colab import drive
drive.mount("/content/drive")
# Install a global "ignore" filter so Python warnings are hidden in later cells.
import warnings
warnings.filterwarnings("ignore")

## 1.0 Install and import necessary libraries

In [None]:
# Upgrade pip itself and matplotlib (quiet mode).
!python -m pip install --upgrade pip -q
!pip install matplotlib -q -U

In [None]:
# Install the Hugging Face stack and helper packages used below.
# PEFT is installed from the GitHub main branch rather than a tagged release.
# NOTE(review): `loralab` may be a typo for `loralib` — confirm the intended package.
!pip install -q datasets
!pip install transformers -q -U
!pip install -q bitsandbytes sentencepiece accelerate loralab
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install hf_transfer -q -U
!pip install pickleshare -q

In [None]:
#%env HF_HUB_ENABLE_HF_TRANSFER=1

In [None]:
import os
# Clone the project repository into the working directory unless a
# ./CS45_2_S1_2024 directory is already there.
if not os.path.exists("./CS45_2_S1_2024"): # modify the path
    !git clone https://github.com/theon1130/CS45_2_S1_2024.git
else:
    print("CS45_2_S1_2024 already exists")

In [None]:
import os
if not os.path.exists("/workspace/LLaVA"):
    !git clone https://github.com/haotian-liu/LLaVA.git
else:
    print("LLaVA already exists")

Replace the `llava` folder in the cloned `LLaVA` repository with our `llava` folder from `CS45_2_S1_2024`.

In [None]:
# Switch the working directory to the LLaVA checkout; `!pwd` before and after
# shows the change. The `pip install -e .` calls below install from this directory.
!pwd
%cd ./LLaVA
!pwd

In [None]:
# Editable install of the package in the current directory.
!pip install -e . -q

In [None]:
# Upgrade protobuf and Pillow, install the "train" extra of the local package,
# then flash-attn (built without build isolation).
!pip install protobuf -q -U
!pip install --upgrade Pillow -q
!pip install -e ".[train]" -q
!pip install flash-attn --no-build-isolation -q

## 2.0 Load the model

In [None]:
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, PeftModel
from transformers import AutoProcessor, Trainer, BitsAndBytesConfig, TrainingArguments, AutoTokenizer
import torchvision.transforms as transforms

In [None]:
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from llava.utils import disable_torch_init
# Use the GPU when one is visible to torch, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Placeholder: replace with the local path of the downloaded checkpoint.
model_path = "your path to the downloaded model from hugging face"
# The model name is derived from the path; later cells use it to pick a
# conversation template by substring.
model_name = get_model_name_from_path(model_path)
print(model_name)

In [None]:
# Load tokenizer, model, image processor and context length in one call.
# The two commented-out arguments are optional switches left disabled.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path = model_path,
    model_base = None,
    model_name = model_name,
    #use_flash_attn=True,
#    cache_dir = ''
)

Load the LoRA adapter trained on VQA-RAD (VQARADQformer) and merge it into the model.

In [None]:
# Placeholder: set this to the location of the saved LoRA adapter.
peft_model_id = "your path"
# Attach the adapter to the current model, then merge its weights into the
# base layers so `model` is a plain (non-PEFT) model again.
# NOTE(review): this cell and the SLAKE cell below both rebind `model`; they
# look like alternatives rather than steps to run back to back — confirm.
model = PeftModel.from_pretrained(model, peft_model_id)
model = model.merge_and_unload()

Load the LoRA adapter trained on SLAKE (SLAKEQformer) and merge it into the model.

In [None]:
# Placeholder: set this to the location of the saved LoRA adapter.
peft_model_id = "your path"
# Attach the adapter to the current model, then merge its weights into the
# base layers so `model` is a plain (non-PEFT) model again.
model = PeftModel.from_pretrained(model, peft_model_id)
model = model.merge_and_unload()

In [None]:
# Dump the loaded objects for inspection, separated by rows of "=".
print(model)
print('='*100)
print(image_processor)
print('='*100)
print(tokenizer)
print('='*100)
# Context length returned by the loader next to the tokenizer's own limit.
print(context_len)
print(tokenizer.model_max_length)
print('='*100)
print(model.config)

## 3.0 Inference

In [None]:
import re
import torch
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision.transforms.functional import to_pil_image, to_tensor
from PIL import Image
import requests
from io import BytesIO


from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
    IGNORE_INDEX,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.utils import disable_torch_init
from llava.model.builder import load_pretrained_model
from llava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
)

def creat_prompt(qs, model, model_name=model_name, caption=None):
    """Build the conversation prompt for one image question.

    The image marker is inserted into ``qs`` (replacing ``IMAGE_PLACEHOLDER``
    when present, otherwise prepended on its own line), a conversation
    template is selected from ``model_name``, and the user turn plus an
    assistant turn are appended.

    Args:
        qs: Question text, optionally containing ``IMAGE_PLACEHOLDER``.
        model: Object whose ``config.mm_use_im_start_end`` flag decides whether
            the image token is wrapped in start/end tokens.
        model_name: Name matched case-insensitively against known substrings
            to choose the conversation template. The default is the value the
            global ``model_name`` had when this cell was run.
        caption: Optional assistant reply. A falsy value leaves the assistant
            turn empty.

    Returns:
        The prompt string produced by the conversation template.
    """
    placeholder_present = IMAGE_PLACEHOLDER in qs

    # Pick the image marker once; both branches below use the same one.
    if model.config.mm_use_im_start_end:
        image_marker = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_marker = DEFAULT_IMAGE_TOKEN

    if placeholder_present:
        qs = re.sub(IMAGE_PLACEHOLDER, image_marker, qs)
    else:
        qs = image_marker + "\n" + qs

    # First matching substring wins, so "v1.6-34b" has to be tested before "v1".
    template_by_substring = (
        ("llama-2", "llava_llama_2"),
        ("mistral", "mistral_instruct"),
        ("v1.6-34b", "chatml_direct"),
        ("v1", "llava_v1"),
        ("mpt", "mpt"),
    )
    lowered_name = model_name.lower()
    conv_mode = "llava_v0"
    for needle, template_name in template_by_substring:
        if needle in lowered_name:
            conv_mode = template_name
            break

    conversation = conv_templates[conv_mode].copy()
    user_role, assistant_role = conversation.roles[0], conversation.roles[1]
    conversation.append_message(user_role, qs)
    conversation.append_message(assistant_role, caption if caption else None)

    return conversation.get_prompt()

In [None]:
def image_parser(args):
    """Split ``args.image_file`` on ``args.sep`` and return the list of pieces."""
    joined_names = args.image_file
    return joined_names.split(args.sep)


def load_image(image_file, timeout=30):
    """Return ``image_file`` as an RGB ``PIL.Image.Image``.

    Args:
        image_file: An ``http://`` / ``https://`` URL, a local file path, or
            an already-loaded PIL image.
        timeout: Seconds to wait for the HTTP request when ``image_file`` is a
            URL.

    Returns:
        The image in RGB mode. A PIL image that is already RGB is returned
        unchanged.

    Raises:
        ValueError: If ``image_file`` is neither a string nor a PIL image.
        requests.HTTPError: If a URL download returns an error status.
    """
    if isinstance(image_file, str):
        # Require the full scheme: a bare "http" prefix would also match local
        # paths such as "http_cache/img.png".
        if image_file.startswith(("http://", "https://")):
            response = requests.get(image_file, timeout=timeout)
            # Fail here rather than letting PIL try to parse an error page.
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert("RGB")
        # convert() reads the pixel data, so the file can be closed right after.
        with Image.open(image_file) as opened:
            return opened.convert("RGB")
    if isinstance(image_file, Image.Image):
        # Match the string branches, which always hand back RGB.
        return image_file if image_file.mode == "RGB" else image_file.convert("RGB")
    raise ValueError(f"Unsupported image file type: {type(image_file)}")


def load_images(image_files):
    """Load every entry of ``image_files`` with ``load_image``, keeping the order."""
    return [load_image(entry) for entry in image_files]

In [None]:
def process_and_prepare_image(image_files, model, image_processor, device, dtype=torch.bfloat16):
    """Load images, run the image processor, and move the results to ``device``.

    Args:
        image_files: Iterable of inputs accepted by ``load_image``.
        model: Model whose ``config`` is passed to ``process_images``.
        image_processor: Processor passed to ``process_images``.
        device: Target device for the processed tensors.
        dtype: Target dtype for the processed tensors. Defaults to
            ``torch.bfloat16``, the value that used to be hard-coded.

    Returns:
        ``(tensors, sizes)``: a list with one moved tensor per element yielded
        by ``process_images``, and a list of each loaded image's ``.size``.
    """
    images = load_images(image_files)
    processed = process_images(images, image_processor, model.config)

    # Iterating works for a list of tensors and for a stacked tensor alike.
    # NOTE(review): assumes one element per input image — confirm.
    tensors_on_device = [tensor.to(device, dtype=dtype) for tensor in processed]

    image_sizes = [image.size for image in images]
    return tensors_on_device, image_sizes

In [None]:
from transformers import (
    AutoTokenizer,
    CLIPImageProcessor,
    CLIPVisionConfig,
    CLIPVisionModel,
    InstructBlipQFormerConfig,
    InstructBlipQFormerModel,
    InstructBlipProcessor,
)

# Processor whose Q-Former tokenizer output is used for the instruction text.
QformerProcessor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-flan-t5-xxl")

def preprocess_text(strs):
    """Tokenize instruction text for the Q-Former.

    Args:
        strs: A string or list of strings passed to ``QformerProcessor`` as
            ``text`` (padded, returned as PyTorch tensors).

    Returns:
        ``(qformer_input_ids, qformer_attention_mask)``, both moved to the
        device of the global ``model``.
    """
    encoded = QformerProcessor(text=strs, padding=True, return_tensors="pt")
    ids_on_device = encoded["qformer_input_ids"].to(model.device)
    mask_on_device = encoded["qformer_attention_mask"].to(model.device)
    return ids_on_device, mask_on_device


In [None]:
def eval_model(tokenizer, model, image_processor, context_len, image_files, qs, use_q,  sep=',', model_name=model_name, temperature=1.0, num_beams=1, max_new_tokens=512):
    """Generate an answer for one question about the given image(s).

    Args:
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        model: Model exposing ``generate``, ``device`` and ``config``.
        image_processor: Processor forwarded to ``process_and_prepare_image``.
        context_len: Accepted for interface compatibility; not used here.
        image_files: List of image inputs accepted by ``load_image``.
        qs: Question text. It is used for the Q-Former input and the prompt.
        use_q: Forwarded to ``model.generate`` unchanged.
        sep: Accepted for interface compatibility; not used here.
        model_name: Used by ``creat_prompt`` to pick the conversation template.
        temperature: Sampling temperature. Sampling is enabled only for a
            positive value other than 1.0; otherwise decoding is deterministic.
        num_beams: Forwarded to ``model.generate``.
        max_new_tokens: Forwarded to ``model.generate``.

    Returns:
        The first decoded sequence, stripped of surrounding whitespace.
        The prompt and the answer are also printed.
    """
    disable_torch_init()

    qformer_ids, qformerattention_mask = preprocess_text(qs)

    prompt = creat_prompt(qs, model, model_name)
    print(f"Prompt: {prompt}")

    images_tensor, image_sizes = process_and_prepare_image(image_files, model, image_processor, model.device) # image_files should be a str list

    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device)
    )

    # The default (1.0) keeps its previous meaning of "no sampling". A
    # temperature of 0 or below is also treated as deterministic decoding:
    # it used to switch sampling ON with a temperature the sampler rejects.
    do_sample = temperature > 0 and temperature != 1.0

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            qformer_ids,
            qformerattention_mask,
            images=images_tensor,
            image_sizes=image_sizes,
            use_q=use_q,
            do_sample=do_sample,
            # Only hand a non-neutral temperature over when it will be used.
            temperature=temperature if do_sample else 1.0,
            #top_p=top_p,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            use_cache=True,
        )

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    print(outputs)

    return outputs


## 4.0 Prepare the dataset

### 4.1 Prepare for ROCO

In [None]:
import pandas as pd
# Read the ROCO training CSV and turn each "name" entry into a full path by
# prepending a user-supplied prefix (placeholder string below).
train_df = pd.read_csv('/train/radiologytraindata.csv')
train_prefix = "Your path"
train_df["name"] = train_df["name"].apply(lambda x: train_prefix + x)
train_df.head()

In [None]:
# Same steps for the validation split.
# NOTE(review): the validation CSV is also named "radiologytraindata.csv" —
# confirm this is the intended file and not a copy-paste of the training name.
val_df = pd.read_csv('/validation/radiologytraindata.csv')
val_prefix = "Your paht"
val_df["name"] = val_df["name"].apply(lambda x: val_prefix + x)
val_df.head()

In [None]:
# The earlier import cells only bring in `load_dataset` from `datasets`, so
# `Dataset` has to be imported here or this cell raises NameError.
from datasets import Dataset

# Wrap the two pandas frames as Hugging Face datasets.
train_ds = Dataset.from_pandas(train_df)
eval_ds = Dataset.from_pandas(val_df)

In [None]:
# Show the summary of the training dataset.
print(train_ds)

In [None]:
# Show the summary of the evaluation dataset.
print(eval_ds)

prepare the dataset that will be used for finetuning

In [None]:
# Pool of interchangeable captioning instructions. The ROCO tokenization
# function below draws one at random for each caption.
concise_describe_instructions = [
    "Describe the image concisely.",
    "Provide a brief description of the given image.",
    "Offer a succinct explanation of the picture presented.",
    "Summarize the visual content of the image.",
    "Give a short and clear explanation of the given image.",
    "Share a concise interpretation of the image provided.",
    "Present a compact description of the photo's key features.",
    "Relay a brief, clear account of the picture shown.",
    "Render a clear and concise summary of the photo.",
    "Write a terse but informative summary of the provided picture.",
    "Briefly describe this image.",
]

In [None]:
import random
from torch.nn.utils.rnn import pad_sequence

def tokenize_and_create_label(example_batch, image_processor, tokenizer, model, model_name, device):
    """Turn a batch of ROCO rows into model inputs with caption-only labels.

    For every caption a random instruction from
    ``concise_describe_instructions`` is chosen and two prompts are built with
    ``creat_prompt``: one that stops at the empty assistant turn and one that
    also contains the caption. The second is the training sequence; the token
    length of the first marks where the supervised span begins.

    Args:
        example_batch: Mapping with a "name" list (image inputs accepted by
            ``load_image``) and a "caption" list.
        image_processor: Forwarded to ``process_and_prepare_image``.
        tokenizer: Tokenizer used for the prompts and for the padding id.
        model: Forwarded to the image and prompt helpers.
        model_name: Forwarded to ``creat_prompt``.
        device: Device the attention mask is moved to.

    Returns:
        dict with "input_ids", "qformer_inputids", "qfromer_attention_mask",
        "attention_mask", "labels", "images" and "image_sizes". The key
        spellings are kept exactly as they were.
        NOTE(review): presumably they match the keyword names the customised
        model expects — verify before renaming.
    """
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        # pad_sequence needs a concrete fill value. The mask and labels below
        # are built from true lengths, so the filler is never attended to or
        # trained on.
        unk_token_id = getattr(tokenizer, "unk_token_id", None)
        pad_token_id = unk_token_id if unk_token_id is not None else 0
    image_files = example_batch["name"]

    images_tensor, image_sizes = process_and_prepare_image(image_files, model, image_processor, model.device)

    tokenized_conversation_with_caption = []
    tokenized_conversation_without_caption = []
    query_list = []
    for caption in example_batch["caption"]:
        query = random.choice(concise_describe_instructions)
        query_list.append(query)
        prompt_without_caption = creat_prompt(query, model, model_name, None)
        prompt_with_caption = creat_prompt(query, model, model_name, caption)

        tokenized_without_caption = tokenizer_image_token(prompt_without_caption, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors="pt")
        tokenized_with_caption = tokenizer_image_token(prompt_with_caption, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors="pt")

        tokenized_conversation_without_caption.append(tokenized_without_caption)
        tokenized_conversation_with_caption.append(tokenized_with_caption)

    sequences = [tcwc.squeeze(0) for tcwc in tokenized_conversation_with_caption]
    input_ids = pad_sequence(sequences, batch_first=True, padding_value=pad_token_id)

    # Derive the mask from each sequence's real length. Comparing against the
    # pad id would also mask genuine tokens that share that id.
    sequence_lengths = torch.tensor([len(seq) for seq in sequences])
    positions = torch.arange(input_ids.shape[1])
    attention_mask = (positions.unsqueeze(0) < sequence_lengths.unsqueeze(1)).long().to(device)

    labels = torch.full_like(input_ids, fill_value=IGNORE_INDEX)
    for i, tcwc in enumerate(tokenized_conversation_without_caption):
        prompt_length = len(tcwc.squeeze(0))
        full_length = len(sequences[i])
        # Supervise only the caption tokens. Everything after full_length is
        # right-padding and must stay IGNORE_INDEX so it adds no loss.
        labels[i, prompt_length:full_length] = input_ids[i, prompt_length:full_length]

    qformer_ids_list, qformerattention_list = preprocess_text(query_list)

    inputs = {
        "input_ids": input_ids,
        "qformer_inputids": qformer_ids_list,
        "qfromer_attention_mask": qformerattention_list,
        "attention_mask": attention_mask,
        "labels": labels,
        "images": images_tensor,
        "image_sizes": image_sizes,
    }

    return inputs

def transform_batch(batch):
    """Single-argument adapter around ``tokenize_and_create_label``.

    The remaining arguments are taken from the notebook globals
    (``image_processor``, ``tokenizer``, ``model``, ``model_name``, ``device``).
    """
    return tokenize_and_create_label(
        example_batch=batch,
        image_processor=image_processor,
        tokenizer=tokenizer,
        model=model,
        model_name=model_name,
        device=device,
    )


# Register the adapter as the on-access transform of both ROCO datasets.
train_ds.set_transform(transform_batch)
eval_ds.set_transform(transform_batch)

### 4.2 Prepare for VQA-RAD

In [None]:
from datasets import load_dataset
# Load VQA-RAD from the Hugging Face Hub, caching it under ./VQA-RAD.
cache_dir = "./VQA-RAD"
VQA_RADdataset = load_dataset("flaviagiammarino/vqa-rad", cache_dir=cache_dir)

In [None]:
# Show the available splits and their columns.
print(VQA_RADdataset)

In [None]:
import json
# Load the VQA-RAD JSON file. Later cells read "question", "answer",
# "image_name" and "answer_type" from its entries.
path = "./VQA_RAD Dataset Public.json"
#path = "/workspace/VQA_RAD_Dataset_Public.json"
with open(path, 'r', encoding='utf-8') as file:
    vqa_data = json.load(file)


In [None]:
# Keep a handle on the hub dataset's test split.
test_data = VQA_RADdataset['test']

In [None]:
import re

def normalize_text(text):
    """Return ``text`` lower-cased with all whitespace removed.

    Non-string values are converted with ``str()`` first.
    """
    as_string = text if isinstance(text, str) else str(text)
    without_whitespace = re.sub(r'\s+', '', as_string)
    return without_whitespace.lower()

def add_image_name_and_answer_type(example):
    """Copy "image_name" and "answer_type" onto ``example`` from ``vqa_data``.

    The first ``vqa_data`` entry whose normalized question and answer both
    equal those of ``example`` is used.

    Returns:
        The same ``example`` mapping, with the two keys added.

    Raises:
        ValueError: If no entry of ``vqa_data`` matches.
    """
    wanted_question = normalize_text(example['question'])
    wanted_answer = normalize_text(example['answer'])

    matching_item = next(
        (
            item
            for item in vqa_data
            if normalize_text(item['question']) == wanted_question
            and normalize_text(item['answer']) == wanted_answer
        ),
        None,
    )
    if matching_item is None:
        raise ValueError(f"No matching question-answer pair found for question: {example['question']}, answer: {example['answer']}")

    example['image_name'] = matching_item['image_name']
    example['answer_type'] = matching_item['answer_type']
    return example

# Attach image name and answer type to every training row.
train_datas = VQA_RADdataset['train'].map(add_image_name_and_answer_type)
# Hold out 20% for evaluation. The fixed seed makes the split identical across
# kernel restarts; without it every run trains and validates on different rows.
train_ds, eval_ds = train_datas.train_test_split(test_size=0.2, seed=42).values()

In [None]:
# Attach image name and answer type to every test row as well.
modified_test_data = test_data.map(add_image_name_and_answer_type)

In [None]:
import os

# Instruction prefixes: "short" for closed questions, "long" for open ones.
pre_prompt = {
    "short": "Based on the image, respond to this question with a word or phrase: ",
    "long": "Based on the image, respond to this question with a short answer: "
}

def prepare_data(example):
    """Prefix ``example["question"]`` with the instruction for its answer type.

    An answer type containing "CLOSED" gets the "short" prefix and an answer
    type equal to "OPEN" gets the "long" prefix.

    Returns:
        The same ``example`` mapping with the question rewritten.

    Raises:
        ValueError: For any other answer type.
    """
    answer_type = example["answer_type"]

    if "CLOSED" in answer_type:
        prefix_key = "short"
    elif answer_type == "OPEN":
        prefix_key = "long"
    else:
        raise ValueError(f"Invalid answer type: {example['answer_type']}")

    example["question"] = pre_prompt[prefix_key] + example["question"]
    return example


# Rewrite the questions of both splits with their instruction prefix.
train_ds = train_ds.map(prepare_data)
eval_ds = eval_ds.map(prepare_data)

In [None]:
# Show the summary of the prepared training split.
print(train_ds)

In [None]:
# Show one prepared row.
print(train_ds[0])

In [None]:
import random
from torch.nn.utils.rnn import pad_sequence

def tokenize_and_create_label(example_batch, image_processor, tokenizer, model, model_name, device):
    """Turn a batch of VQA-RAD rows into model inputs with answer-only labels.

    For every question/answer pair two prompts are built with
    ``creat_prompt``: one that stops at the empty assistant turn and one that
    also contains the answer. The second is the training sequence; the token
    length of the first marks where the supervised span begins.

    Args:
        example_batch: Mapping with "image", "question" and "answer" lists.
            The "image" entries must be inputs accepted by ``load_image``.
        image_processor: Forwarded to ``process_and_prepare_image``.
        tokenizer: Tokenizer used for the prompts and for the padding id.
        model: Forwarded to the image and prompt helpers.
        model_name: Forwarded to ``creat_prompt``.
        device: Device the attention mask is moved to.

    Returns:
        dict with "input_ids", "qformer_inputids", "qfromer_attention_mask",
        "attention_mask", "labels", "images" and "image_sizes". The key
        spellings are kept exactly as they were.
        NOTE(review): presumably they match the keyword names the customised
        model expects — verify before renaming.
    """
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        # pad_sequence needs a concrete fill value. The mask and labels below
        # are built from true lengths, so the filler is never attended to or
        # trained on.
        unk_token_id = getattr(tokenizer, "unk_token_id", None)
        pad_token_id = unk_token_id if unk_token_id is not None else 0
    image_files = example_batch["image"]

    images_tensor, image_sizes = process_and_prepare_image(image_files, model, image_processor, model.device)

    tokenized_conversation_with_caption = []
    tokenized_conversation_without_caption = []
    query_list = []
    for query, answer in zip(example_batch["question"], example_batch["answer"]):

        query_list.append(query)
        prompt_without_caption = creat_prompt(query, model, model_name, None)
        prompt_with_caption = creat_prompt(query, model, model_name, answer)

        tokenized_without_caption = tokenizer_image_token(prompt_without_caption, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors="pt")
        tokenized_with_caption = tokenizer_image_token(prompt_with_caption, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors="pt")

        tokenized_conversation_without_caption.append(tokenized_without_caption)
        tokenized_conversation_with_caption.append(tokenized_with_caption)

    sequences = [tcwc.squeeze(0) for tcwc in tokenized_conversation_with_caption]
    input_ids = pad_sequence(sequences, batch_first=True, padding_value=pad_token_id)

    # Derive the mask from each sequence's real length. Comparing against the
    # pad id would also mask genuine tokens that share that id.
    sequence_lengths = torch.tensor([len(seq) for seq in sequences])
    positions = torch.arange(input_ids.shape[1])
    attention_mask = (positions.unsqueeze(0) < sequence_lengths.unsqueeze(1)).long().to(device)

    labels = torch.full_like(input_ids, fill_value=IGNORE_INDEX)
    for i, tcwc in enumerate(tokenized_conversation_without_caption):
        prompt_length = len(tcwc.squeeze(0))
        full_length = len(sequences[i])
        # Supervise only the answer tokens. Everything after full_length is
        # right-padding and must stay IGNORE_INDEX so it adds no loss.
        labels[i, prompt_length:full_length] = input_ids[i, prompt_length:full_length]

    qformer_ids_list, qformerattention_list = preprocess_text(query_list)

    inputs = {
        "input_ids": input_ids,
        "qformer_inputids": qformer_ids_list,
        "qfromer_attention_mask": qformerattention_list,
        "attention_mask": attention_mask,
        "labels": labels,
        "images": images_tensor,
        "image_sizes": image_sizes,
    }

    return inputs

def transform_batch(batch):
    """Single-argument adapter around ``tokenize_and_create_label``.

    The remaining arguments are taken from the notebook globals
    (``image_processor``, ``tokenizer``, ``model``, ``model_name``, ``device``).
    """
    return tokenize_and_create_label(
        example_batch=batch,
        image_processor=image_processor,
        tokenizer=tokenizer,
        model=model,
        model_name=model_name,
        device=device,
    )


In [None]:
# Register transform_batch as the on-access transform of both splits.
train_ds.set_transform(transform_batch)
eval_ds.set_transform(transform_batch)

### 4.3 Prepare for SLAKE

In [None]:
from datasets import load_dataset

# Load SLAKE from the Hugging Face Hub, caching it under ./SLAKE.
cache_dir = "./SLAKE"
slake_dataset = load_dataset("BoKelvin/SLAKE", cache_dir=cache_dir)

In [None]:
# Show the available splits and their columns.
print(slake_dataset)

In [None]:
# Keep only the rows whose question language is English, in every split.
train_ds = slake_dataset['train'].filter(lambda example: example['q_lang'] == 'en')
eval_ds = slake_dataset['validation'].filter(lambda example: example['q_lang'] == 'en')
test_ds = slake_dataset['test'].filter(lambda example: example['q_lang'] == 'en')

In [None]:
import os
# Directory that the relative SLAKE image names are resolved against.
path_pre = "./Slake1.0/imgs/"

# Instruction prefixes: "short" for closed questions, "long" for open ones.
pre_prompt = {
    "short": "Based on the image, respond to this question with a word or phrase: ",
    "long": "Based on the image, respond to this question with a short answer: "
}

def prepare_data(example):
    """Resolve the image path and prefix the question for one SLAKE row.

    ``example["img_name"]`` is joined onto ``path_pre`` first. Then the
    question receives the "short" prefix for answer type "CLOSED" or the
    "long" prefix for "OPEN".

    Returns:
        The same ``example`` mapping with both fields rewritten.

    Raises:
        ValueError: For any other answer type. The image path has already
            been rewritten at that point.
    """
    example['img_name'] = os.path.join(path_pre, example['img_name'])

    prefix_key_by_type = {"CLOSED": "short", "OPEN": "long"}
    answer_type = example["answer_type"]
    if answer_type not in prefix_key_by_type:
        raise ValueError(f"Invalid answer type: {example['answer_type']}")

    example["question"] = pre_prompt[prefix_key_by_type[answer_type]] + example["question"]
    return example


# Apply prepare_data to all three splits. Afterwards "img_name" holds the
# path joined onto path_pre and "question" carries its instruction prefix.
train_ds = train_ds.map(prepare_data)
eval_ds = eval_ds.map(prepare_data)
test_ds = test_ds.map(prepare_data)

In [None]:
import random
from torch.nn.utils.rnn import pad_sequence

def tokenize_and_create_label(example_batch, image_processor, tokenizer, model, model_name, device):
    """Turn a batch of SLAKE rows into model inputs with answer-only labels.

    For every question/answer pair two prompts are built with
    ``creat_prompt``: one that stops at the empty assistant turn and one that
    also contains the answer. The second is the training sequence; the token
    length of the first marks where the supervised span begins.

    Args:
        example_batch: Mapping with "img_name", "question" and "answer" lists.
            The "img_name" entries must be inputs accepted by ``load_image``.
        image_processor: Forwarded to ``process_and_prepare_image``.
        tokenizer: Tokenizer used for the prompts and for the padding id.
        model: Forwarded to the image and prompt helpers.
        model_name: Forwarded to ``creat_prompt``.
        device: Device the attention mask is moved to.

    Returns:
        dict with "input_ids", "qformer_inputids", "qfromer_attention_mask",
        "attention_mask", "labels", "images" and "image_sizes". The key
        spellings are kept exactly as they were.
        NOTE(review): presumably they match the keyword names the customised
        model expects — verify before renaming.
    """
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        # pad_sequence needs a concrete fill value. The mask and labels below
        # are built from true lengths, so the filler is never attended to or
        # trained on.
        unk_token_id = getattr(tokenizer, "unk_token_id", None)
        pad_token_id = unk_token_id if unk_token_id is not None else 0
    image_files = example_batch["img_name"]

    images_tensor, image_sizes = process_and_prepare_image(image_files, model, image_processor, model.device)

    tokenized_conversation_with_caption = []
    tokenized_conversation_without_caption = []
    query_list = []
    for query, answer in zip(example_batch["question"], example_batch["answer"]):

        query_list.append(query)
        prompt_without_caption = creat_prompt(query, model, model_name, None)
        prompt_with_caption = creat_prompt(query, model, model_name, answer)

        tokenized_without_caption = tokenizer_image_token(prompt_without_caption, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors="pt")
        tokenized_with_caption = tokenizer_image_token(prompt_with_caption, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors="pt")

        tokenized_conversation_without_caption.append(tokenized_without_caption)
        tokenized_conversation_with_caption.append(tokenized_with_caption)

    sequences = [tcwc.squeeze(0) for tcwc in tokenized_conversation_with_caption]
    input_ids = pad_sequence(sequences, batch_first=True, padding_value=pad_token_id)

    # Derive the mask from each sequence's real length. Comparing against the
    # pad id would also mask genuine tokens that share that id.
    sequence_lengths = torch.tensor([len(seq) for seq in sequences])
    positions = torch.arange(input_ids.shape[1])
    attention_mask = (positions.unsqueeze(0) < sequence_lengths.unsqueeze(1)).long().to(device)

    labels = torch.full_like(input_ids, fill_value=IGNORE_INDEX)
    for i, tcwc in enumerate(tokenized_conversation_without_caption):
        prompt_length = len(tcwc.squeeze(0))
        full_length = len(sequences[i])
        # Supervise only the answer tokens. Everything after full_length is
        # right-padding and must stay IGNORE_INDEX so it adds no loss.
        labels[i, prompt_length:full_length] = input_ids[i, prompt_length:full_length]

    qformer_ids_list, qformerattention_list = preprocess_text(query_list)

    inputs = {
        "input_ids": input_ids,
        "qformer_inputids": qformer_ids_list,
        "qfromer_attention_mask": qformerattention_list,
        "attention_mask": attention_mask,
        "labels": labels,
        "images": images_tensor,
        "image_sizes": image_sizes,
    }

    return inputs

def transform_batch(batch):
    """Single-argument adapter around ``tokenize_and_create_label``.

    The remaining arguments are taken from the notebook globals
    (``image_processor``, ``tokenizer``, ``model``, ``model_name``, ``device``).
    """
    return tokenize_and_create_label(
        example_batch=batch,
        image_processor=image_processor,
        tokenizer=tokenizer,
        model=model,
        model_name=model_name,
        device=device,
    )


In [None]:
# Register transform_batch as the on-access transform of both splits.
train_ds.set_transform(transform_batch)
eval_ds.set_transform(transform_batch)

## 5.0 Evaluation on benchmark

### 5.1 Eval VQA-RAD

In [None]:
# Show the summary of the annotated VQA-RAD test split.
print(modified_test_data)

In [None]:
import json

# Instruction prefixes per answer type, spelled exactly like the prefixes that
# prepare_data puts in front of the VQA-RAD training questions (including the
# space after the colon), so inference prompts match the fine-tuning prompts.
eval_prompt_by_type = {
    'OPEN': "Based on the image, respond to this question with a short answer: ",
    'CLOSED': "Based on the image, respond to this question with a word or phrase: ",
}

counter = 0
json_data = []

for row in modified_test_data:

    image = row['image']
    answer = row['answer']

    # The training-side prepare_data matches "CLOSED" by substring, which
    # suggests the labels are not always exact. Strip surrounding whitespace
    # so such rows are evaluated instead of aborting the whole run.
    answer_type = row['answer_type'].strip()
    if answer_type not in eval_prompt_by_type:
        raise ValueError(f"Invalid answer_type: {row['answer_type']}")
    qs = eval_prompt_by_type[answer_type] + row['question']

    generate_answer = eval_model(tokenizer, model, image_processor, context_len, [image], qs, use_q=True)

    print(f"{counter}.Img: {image}\n Question: {row['question']}\n Answer: {generate_answer}\n GT: {answer}")

    counter += 1

    # One record per test row; "answer_type" keeps the dataset's original value.
    new_json_data = {
        "image_name": row['image_name'],
        "question": row['question'],
        "prompt":qs,
        "generated": generate_answer,
        "answer": answer,
        "mode": "test",
        "answer_type": row['answer_type']
    }
    json_data.append(new_json_data)


# Persist all predictions for later scoring.
with open("./vqa_prediction_answer.json", "w") as json_file:
    json.dump(json_data, json_file, indent=4)


### 5.2 Eval SLAKE

Evaluate on the SLAKE test split.

In [None]:
# Directory that SLAKE image names are resolved against during evaluation.
path_pre = "./Slake1.0/imgs/"

In [None]:
import os
import json

# Instruction prefixes per answer type, spelled exactly like the prefixes that
# prepare_data puts in front of the SLAKE training questions (including the
# space after the colon), so inference prompts match the fine-tuning prompts.
eval_prompt_by_type = {
    'OPEN': "Based on the image, respond to this question with a short answer: ",
    'CLOSED': "Based on the image, respond to this question with a word or phrase: ",
}

counter = 0
json_data = []

for row in test_ds:

    answer = row['answer']

    if row['answer_type'] not in eval_prompt_by_type:
        raise ValueError(f"Invalid answer_type: {row['answer_type']}")
    prompt_prefix = eval_prompt_by_type[row['answer_type']]

    # test_ds may already have been mapped through prepare_data. In that case
    # "img_name" is already joined onto path_pre and "question" already has
    # its prefix. Undo both first so neither is applied twice; a doubled path
    # points at a file that does not exist.
    img_name = row['img_name']
    if img_name.startswith(path_pre):
        img_name = img_name[len(path_pre):]
    image_path = os.path.join(path_pre, img_name)

    question = row['question']
    if question.startswith(prompt_prefix):
        question = question[len(prompt_prefix):]
    qs = prompt_prefix + question

    generate_answer = eval_model(tokenizer, model, image_processor, context_len, [image_path], qs, use_q=True)

    print(f"{counter}.Img: {image_path}\n Question: {question}\n Answer: {generate_answer}\n GT: {answer}")

    counter += 1

    # One record per test row, with the raw image name and question.
    new_json_data = {
        "image_name": img_name,
        "question": question,
        "prompt":qs,
        "generated": generate_answer,
        "answer": answer,
        "mode": "test",
        "answer_type": row['answer_type'],
        'img_id' : row['img_id'],
        'qid' : row['qid']
    }
    json_data.append(new_json_data)


# Persist all predictions for later scoring. The path used to start with ",/"
# (a directory literally named ","), which made open() fail.
with open("./slake_prediction_answer.json", "w") as json_file:
    json.dump(json_data, json_file, indent=4)

## 6.0 LoRA

### 6.1 lora for VQA-RAD

In [None]:
# LoRA setup: rank 32, alpha 64, dropout 0.1, no bias training.
config = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.1,
    bias="none",
    # Layer-name suffixes that receive LoRA adapters.
    target_modules=[
        "q_proj", "v_proj", "k_proj",
        "up_proj", "down_proj", "gate_proj",
    ],
    # Modules kept trainable in full and saved together with the adapter.
    modules_to_save=["mm_projector", "query_tokens", "post_projection", "projection"],
)

# Wrap the model; from here on `model` is the PEFT wrapper.
model = get_peft_model(model, config)

In [None]:
# Explicitly mark the vision tower's query tokens as trainable. They are
# reached through the PEFT wrapper, hence the extra `base_model.model` levels.
model.base_model.model.model.vision_tower.query_tokens.requires_grad = True

In [None]:
# Show the wrapped module tree.
print(model)

In [None]:
# Report how many parameters are trainable.
model.print_trainable_parameters()

### 6.2 lora for SLAKE

In [None]:
# Show the module tree before applying the SLAKE LoRA configuration.
print(model)

In [None]:
# LoRA setup: rank 32, alpha 64, dropout 0.1, no bias training.
# NOTE(review): identical to the VQA-RAD configuration above; running both
# cells in one session would wrap the model twice — presumably only one of the
# two sections is meant to be executed per run.
config = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.1,
    bias="none",
    # Layer-name suffixes that receive LoRA adapters.
    target_modules=[
        "q_proj", "v_proj", "k_proj",
        "up_proj", "down_proj", "gate_proj",
    ],
    # Modules kept trainable in full and saved together with the adapter.
    modules_to_save=["mm_projector", "query_tokens", "post_projection", "projection"],
)

# Wrap the model; from here on `model` is the PEFT wrapper.
model = get_peft_model(model, config)

In [None]:
# Explicitly mark the vision tower's query tokens as trainable. They are
# reached through the PEFT wrapper, hence the extra `base_model.model` levels.
model.base_model.model.model.vision_tower.query_tokens.requires_grad = True

In [None]:
# Show the wrapped module tree.
print(model)

In [None]:
# Report how many parameters are trainable.
model.print_trainable_parameters()

## 7.0 Training with the Hugging Face Trainer

In [None]:
# Install Weights & Biases for experiment logging.
!pip install wandb

In [None]:
import wandb

In [None]:
import os
# Placeholders: replace before running. Do not commit a real API key with the
# notebook — prefer setting WANDB_API_KEY outside the notebook.
os.environ["WANDB_API_KEY"] = "your key"
os.environ["WANDB_PROJECT"] = "your project"

In [None]:
# Authenticate this session with Weights & Biases.
wandb.login()

In [None]:
# Install the Hugging Face Hub client.
!pip install huggingface_hub

In [None]:
# Print the Hub account the CLI is logged in as.
!huggingface-cli whoami

### 7.1 Training projector with ROCO

In [None]:
# Stage-1 setup: freeze every parameter first ...
for param in model.parameters():
    param.requires_grad = False

# ... then re-enable only the pieces trained on ROCO: the vision tower's
# projection, query tokens and post-projection, plus the multimodal projector.
# These attribute paths address the un-wrapped model (no PEFT `base_model`
# prefix), so this cell is meant to run before any get_peft_model call.
for param in model.model.vision_tower.projection.parameters():
    param.requires_grad = True

model.model.vision_tower.query_tokens.requires_grad = True

for param in model.model.vision_tower.post_projection.parameters():
    param.requires_grad = True

for param in model.model.mm_projector.parameters():
    param.requires_grad = True

In [None]:
from torch.optim import AdamW

# AdamW over all model parameters; frozen ones never receive gradients.
# `foreach=False` selects the per-tensor (non-fused-loop) implementation.
optimizer = AdamW(model.parameters(), lr=4e-5, foreach=False)

In [None]:
# Checkpoints go to /workspace/checkpoints/ROCO_<model_name>.
output_model_name = f"ROCO_{model_name}"

training_args = TrainingArguments(
    output_dir= "/workspace/checkpoints/" + output_model_name,
    learning_rate=4e-5,
    bf16=True,
    # 2 samples per step x 16 accumulation steps = 32 samples per optimizer
    # update on each device.
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=16,
    dataloader_pin_memory=False,
    save_total_limit=1,
    # Evaluate and checkpoint on the same 500-step schedule, so a checkpoint
    # exists for whichever evaluation turns out best.
    evaluation_strategy="steps",
    save_strategy="steps",
    eval_steps=500,
    save_steps=500,
    logging_steps=1,
    num_train_epochs=2,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    # Keep every key produced by the dataset transform; the Trainer must not
    # drop the custom ones.
    remove_unused_columns=False,
    push_to_hub=False,
    label_names=["labels"],
    report_to="wandb",
    load_best_model_at_end=True,
)

# Pass the hand-built optimizer; `None` lets the Trainer create the scheduler.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    optimizers=(optimizer, None)
)


In [None]:
# Make bfloat16 the default floating dtype for tensors created from here on.
torch.set_default_dtype(torch.bfloat16)

In [None]:
# Run the ROCO training loop.
trainer.train()

In [None]:
#trainer.push_to_hub()

In [None]:
# Save the trained model to a local directory.
new_model_dir = './ROCOQformer/'
trainer.save_model(new_model_dir)

In [None]:
# Grab the query tokens from the un-wrapped model path.
querytoken = model.model.vision_tower.query_tokens

In [None]:
print(type(querytoken))

In [None]:
# Store the query tokens in their own file.
torch.save(querytoken, './ROCOquery_tokens.pth')

In [None]:
# Close the Weights & Biases run.
wandb.finish()

### 7.2 Finetune with VQA-RAD

In [None]:
from torch.optim import AdamW

# AdamW over all model parameters; frozen ones never receive gradients.
optimizer = AdamW(model.parameters(), lr=2e-5, foreach=False)

In [None]:
# Checkpoints go to /workspace/checkpoints/vqarad<model_name>.
output_model_name = f"vqarad{model_name}"

training_args = TrainingArguments(
    output_dir= "/workspace/checkpoints/" + output_model_name,
    learning_rate=2e-5,
    bf16=True,
    # 2 samples per step x 16 accumulation steps = 32 samples per optimizer
    # update on each device.
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=16,
    dataloader_pin_memory=False,
    save_total_limit=1,
    # Evaluate and checkpoint on the same 25-step schedule.
    evaluation_strategy="steps",
    save_strategy="steps",
    eval_steps=25,
    save_steps=25,
    logging_steps=1,
    num_train_epochs=5,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    # Keep every key produced by the dataset transform.
    remove_unused_columns=False,
    push_to_hub=False,
    label_names=["labels"],
    report_to="wandb",
    load_best_model_at_end=True,
)

# Pass the hand-built optimizer; `None` lets the Trainer create the scheduler.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    optimizers=(optimizer, None)
)


In [None]:
# Make bfloat16 the default floating dtype for tensors created from here on.
torch.set_default_dtype(torch.bfloat16)

In [None]:
# Run the VQA-RAD fine-tuning loop.
trainer.train()

In [None]:
# Save the fine-tuned model to a local directory.
new_model_dir = './VQARADQformer/'
trainer.save_model(new_model_dir)

In [None]:
# Grab the query tokens through the PEFT wrapper path.
querytoken = model.base_model.model.model.vision_tower.query_tokens

In [None]:
print(type(querytoken))

In [None]:
print(querytoken)

In [None]:
# Store the query tokens in their own file. The path was missing the slash
# after the dot, which wrote a hidden file named ".VQARADquery_tokens.pth"
# instead of following the "./<NAME>query_tokens.pth" pattern used for the
# ROCO and SLAKE runs.
torch.save(querytoken, './VQARADquery_tokens.pth')

In [None]:
# Close the Weights & Biases run.
wandb.finish()

### 7.3 Finetune with SLAKE

In [None]:
from torch.optim import AdamW

# AdamW over all model parameters; frozen ones never receive gradients.
optimizer = AdamW(model.parameters(), lr=2e-5, foreach=False)

In [None]:
# Checkpoints go to /workspace/checkpoints/SLAKE<model_name>.
output_model_name = f"SLAKE{model_name}"

training_args = TrainingArguments(
    output_dir= "/workspace/checkpoints/" + output_model_name,
    learning_rate=2e-5,
    bf16=True,
    # 2 samples per step x 16 accumulation steps = 32 samples per optimizer
    # update on each device.
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=16,
    dataloader_pin_memory=False,
    save_total_limit=1,
    # Evaluate and checkpoint on the same 30-step schedule.
    evaluation_strategy="steps",
    save_strategy="steps",
    eval_steps=30,
    save_steps=30,
    logging_steps=1,
    num_train_epochs=5,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    # Keep every key produced by the dataset transform.
    remove_unused_columns=False,
    push_to_hub=False,
    label_names=["labels"],
    report_to="wandb",
    load_best_model_at_end=True,
)

# Pass the hand-built optimizer; `None` lets the Trainer create the scheduler.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    optimizers=(optimizer, None)
)


In [None]:
# Make bfloat16 the default floating dtype for tensors created from here on.
torch.set_default_dtype(torch.bfloat16)

In [None]:
# Run the SLAKE fine-tuning loop.
trainer.train()

In [None]:
# Save the fine-tuned model to a local directory.
new_model_dir = './SLAKEQformer/'
trainer.save_model(new_model_dir)

In [None]:
# Grab the query tokens through the PEFT wrapper path.
querytoken = model.base_model.model.model.vision_tower.query_tokens

In [None]:
print(type(querytoken))

In [None]:
print(querytoken)

In [None]:
# Store the query tokens in their own file.
torch.save(querytoken, './SLAKEquery_tokens.pth')

In [None]:
# Close the Weights & Biases run.
wandb.finish()