# Fine-Tuning a Vision Language Model (Qwen2-VL-7B) with the HuggingFace Ecosystem (TRL)

In this example, we will finetune a Vision Language Model (VLM) using the Transformer Reinforcement Learning library (`trl`).

We will finetune the [`Qwen2-VL-7B`](https://qwenlm.github.io/blog/qwen2-vl/) model on the [`ChartQA`](https://huggingface.co/datasets/HuggingFaceM4/ChartQA) dataset. This dataset contains images of various chart types, each with question-answer pairs, which makes it well suited to enhancing the model's visual question-answering capabilities.

## Setup

In [None]:
!pip install  -U -q git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/trl.git datasets bitsandbytes peft qwen-vl-utils wandb accelerate
# Tested with transformers==4.47.0.dev0, trl==0.12.0.dev0, datasets==3.0.2, bitsandbytes==0.44.1, peft==0.13.2, qwen-vl-utils==0.0.8, wandb==0.18.5, accelerate==1.0.1

In [None]:
!pip install -q torch==2.4.1+cu121 torchvision==0.19.1+cu121 torchaudio==2.4.1+cu121 --extra-index-url https://download.pytorch.org/whl/cu121

## Load dataset

Before loading the dataset, we define a system message for the VLM. We want the model to act as an expert in analyzing chart images and to give concise answers to questions about them.

In [None]:
system_message = """You are a Vision Language Model specialized in interpreting visual data from chart images.
Your task is to analyze the provided chart image and respond to queries with concise answers, usually a single word, number, or short phrase.
The charts include a variety of types (e.g., line charts, bar charts) and contain colors, labels, and text.
Focus on delivering accurate, succinct answers based on the visual information. Avoid additional explanation unless absolutely necessary."""

We will format the dataset into a chatbot structure for interaction. Each interaction will consist of a system message, followed by the image and the user's query, and the answer to the query.

In [None]:
def format_data(sample):
    return [
        {
            'role': 'system',
            'content': [{'type': 'text', 'text': system_message}]
        },
        {
            'role': 'user',
            'content': [
                {
                    'type': 'image',
                    'image': sample['image']
                },
                {
                    'type': 'text',
                    'text': sample['query']
                }
            ]
        },
        {
            'role': 'assistant',
            'content': [{'type': 'text', 'text': sample['label'][0]}]
        }
    ]

In [None]:
from datasets import load_dataset

dataset_id = 'HuggingFaceM4/ChartQA'
train_dataset, eval_dataset, test_dataset = load_dataset(
    dataset_id,
    split=['train[:10%]', 'val[:10%]', 'test[:10%]']
)

In [None]:
train_dataset

In [None]:
train_dataset[0]

Now we need to format the data using our chatbot template, which will allow us to set up the interactions appropriately for our model.

In [None]:
train_dataset = [format_data(sample) for sample in train_dataset]
eval_dataset = [format_data(sample) for sample in eval_dataset]
test_dataset = [format_data(sample) for sample in test_dataset]

In [None]:
train_dataset[0]
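
A quick way to sanity-check the formatted structure is to list the role and content types of each turn. The sketch below is only an illustration: `summarize_conversation` is a hypothetical helper, and `example` is a hand-written stand-in for one formatted sample, with a placeholder string where the real sample holds an image.

```python
def summarize_conversation(conversation):
    """Return (role, [content types]) pairs for one formatted sample."""
    return [
        (turn['role'], [part['type'] for part in turn['content']])
        for turn in conversation
    ]


# Hypothetical stand-in for one formatted sample (not real dataset content)
example = [
    {'role': 'system', 'content': [{'type': 'text', 'text': 'system prompt'}]},
    {'role': 'user', 'content': [
        {'type': 'image', 'image': '<image placeholder>'},
        {'type': 'text', 'text': 'What is the highest value?'},
    ]},
    {'role': 'assistant', 'content': [{'type': 'text', 'text': '42'}]},
]

summary = summarize_conversation(example)
print(summary)
```

Each formatted sample should show the same three turns in this order: a text-only system turn, a user turn with an image and a text part, and a text-only assistant turn.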

## Load model and check performance

We can always check the [WildVision Arena](https://huggingface.co/WildVision) or the [OpenVLM Leaderboard](https://huggingface.co/spaces/opencompass/open_vlm_leaderboard) to find the best performing VLMs.

In this example, we will use [`Qwen/Qwen2-VL-7B-Instruct`](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).

In [None]:
import torch
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor

model_id = 'Qwen/Qwen2-VL-7B-Instruct'

processor = Qwen2VLProcessor.from_pretrained(model_id)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id,
    device_map='auto',
    torch_dtype=torch.bfloat16
)

To evaluate the model's performance, we will use a sample from the dataset.

In [None]:
sample = train_dataset[0]
sample

To assess the VLM's raw understanding, we use the sample without the system message (and without the reference answer). The input is therefore just the user turn:

In [None]:
sample[1:2]
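
The slice works because each formatted sample is a plain Python list of three turns. A minimal sketch, using a hypothetical conversation with the same role order, shows which turns each slice keeps:

```python
# Hypothetical three-turn conversation with the same role order as a formatted sample
conversation = [
    {'role': 'system', 'content': 'system prompt'},
    {'role': 'user', 'content': 'image + question'},
    {'role': 'assistant', 'content': 'reference answer'},
]

user_only = conversation[1:2]    # drops the system prompt and the answer
with_system = conversation[:2]   # keeps the system prompt, drops the answer

user_roles = [turn['role'] for turn in user_only]
system_roles = [turn['role'] for turn in with_system]
print(user_roles, system_roles)
```

In both cases the assistant turn is excluded, so the reference answer never reaches the model at inference time.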

In [None]:
sample[1]['content'][0]['image']

We need a function to take the model, processor, and sample as inputs to generate the model's answer.

In [None]:
from qwen_vl_utils import process_vision_info

def generate_text_from_sample(model, processor, sample, max_new_tokens=1024, device='cuda'):
    # Prepare the text input by applying the chat template
    text_input = processor.apply_chat_template(
        sample[1:2], # use the sample without the system message
        tokenize=False,
        add_generation_prompt=True
    )

    # Process the inputs from the sample
    image_inputs, _ = process_vision_info(sample)

    # Prepare the inputs for the model
    model_inputs = processor(
        text=[text_input],
        images=image_inputs,
        return_tensors='pt'
    ).to(device)

    # Generate text with the model
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens
    )

    # Trim the generated ids to remove the input ids
    trimmed_generated_ids = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(model_inputs['input_ids'], generated_ids)
    ]

    # Decode the output text
    output_text = processor.batch_decode(
        trimmed_generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
    )

    return output_text[0]
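
The trimming step relies on `generate` returning the prompt tokens followed by the newly generated tokens for decoder-only models. The sketch below reproduces that step on plain Python lists; the token IDs are made up and `trim_prompt_tokens` is a hypothetical name used only for illustration.

```python
def trim_prompt_tokens(input_ids, generated_ids):
    """Drop the leading prompt tokens from each generated sequence."""
    return [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(input_ids, generated_ids)
    ]


# Made-up token IDs: each generated sequence starts with its prompt
prompt_batch = [[101, 7, 8, 9]]
generated_batch = [[101, 7, 8, 9, 55, 56, 2]]

new_tokens = trim_prompt_tokens(prompt_batch, generated_batch)
print(new_tokens)
```

Without this step, decoding would return the prompt text along with the answer.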

In [None]:
output = generate_text_from_sample(model, processor, sample)
output

### Remove model and clean GPU

In [None]:
import gc
import time


def clear_memory():
    # Delete variables if they exist in the current global scope
    if "inputs" in globals():
        del globals()["inputs"]
    if "model" in globals():
        del globals()["model"]
    if "processor" in globals():
        del globals()["processor"]
    if "trainer" in globals():
        del globals()["trainer"]
    if "peft_model" in globals():
        del globals()["peft_model"]
    if "bnb_config" in globals():
        del globals()["bnb_config"]
    time.sleep(2)

    # Garbage collection and clearing CUDA memory
    gc.collect()
    time.sleep(2)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    time.sleep(2)
    gc.collect()
    time.sleep(2)

    print(f"GPU allocated memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"GPU reserved memory: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")


clear_memory()

## Finetune the model using TRL

### Load the quantized model for training

In [None]:
from transformers import BitsAndBytesConfig

# bitsandbytes int-4 config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.bfloat16
)

processor = Qwen2VLProcessor.from_pretrained(model_id)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id,
    device_map='auto',
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config
)

### Set up Q-LoRA and SFTConfig

Q-LoRA enables efficient fine-tuning of large language models with a much smaller memory footprint than full fine-tuning. Standard LoRA reduces the number of trainable parameters by learning small low-rank update matrices while the base weights stay frozen. Q-LoRA goes a step further by also quantizing the frozen base model weights (here to 4-bit NF4, as set in `bnb_config` above), while the LoRA adapters themselves are trained in higher precision. This lowers memory requirements further, typically with little loss in quality.
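
As a rough back-of-the-envelope illustration of why loading the base weights in 4-bit matters, the sketch below estimates the memory needed for the weights alone. It assumes a nominal 7 billion parameters (not the model's exact count) and ignores quantization constants, activations, gradients, and optimizer state, so the numbers are indicative only.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Approximate memory for the weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


NOMINAL_PARAMS = 7e9  # assumption: nominal "7B" size, not the exact parameter count

bf16_gb = weight_memory_gb(NOMINAL_PARAMS, 16)  # 16 bits per weight
nf4_gb = weight_memory_gb(NOMINAL_PARAMS, 4)    # 4 bits per weight

print(f"bf16 weights: {bf16_gb:.1f} GB, 4-bit weights: {nf4_gb:.1f} GB")
```

Under these assumptions the 4-bit weights take about a quarter of the space of the bf16 weights.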

In [None]:
from peft import LoraConfig, get_peft_model

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=8,
    bias='none',
    target_modules=['q_proj', 'v_proj'],
    task_type='CAUSAL_LM'
)

# Apply PEFT model adaptation
peft_model = get_peft_model(model, peft_config)

peft_model.print_trainable_parameters()
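
To get a feel for why so few parameters end up trainable, the sketch below counts the LoRA parameters for a single linear layer and computes the standard LoRA scaling factor `lora_alpha / r`. The layer width of 4096 is a hypothetical value chosen for illustration, not the model's actual dimension; `r` and `lora_alpha` are the values from the config above.

```python
def lora_param_count(d_in, d_out, r):
    """Trainable parameters for one adapted linear layer.

    LoRA adds two matrices: A with shape (r, d_in) and B with shape (d_out, r).
    """
    return r * (d_in + d_out)


r, lora_alpha = 8, 16   # values from the LoraConfig above
d_in = d_out = 4096     # hypothetical layer width, for illustration only

adapter_params = lora_param_count(d_in, d_out, r)
full_params = d_in * d_out       # parameters in the frozen weight matrix
scaling = lora_alpha / r         # factor applied to the low-rank update

print(f"adapter: {adapter_params:,} params "
      f"({100 * adapter_params / full_params:.2f}% of the full matrix), "
      f"scaling: {scaling}")
```

For this hypothetical layer the adapter adds well under 1% of the parameters of the matrix it adapts.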

Supervised fine-tuning (SFT) trains the model on labeled examples, here chart images and questions paired with reference answers, so that it learns to respond in the form our use case needs. This tailors the model to visual question answering on charts. The training hyperparameters are set in `SFTConfig`.

In [None]:
from trl import SFTConfig

training_args = SFTConfig(
    output_dir='qwen2-7b-instruct-trl-sft-ChartQA',
    num_train_epochs=3,
    per_device_train_batch_size=4, # batch size for training
    per_device_eval_batch_size=4, # batch size for evaluation
    gradient_accumulation_steps=8, # steps to accumulate gradients
    gradient_checkpointing=True,

    # Optimizer
    optim='adamw_torch_fused',
    learning_rate=2e-4,
    lr_scheduler_type='constant',

    # Logging
    logging_steps=10,
    eval_steps=10,
    eval_strategy='steps',
    save_strategy='steps',
    save_steps=20,
    metric_for_best_model='eval_loss',
    greater_is_better=False,
    load_best_model_at_end=True,

    # Mixed precision
    bf16=True,
    tf32=True,
    max_grad_norm=0.3,
    warmup_ratio=0.03,

    push_to_hub=False,
    report_to='wandb',
    gradient_checkpointing_kwargs={'use_reentrant': False},
    dataset_text_field='',
    dataset_kwargs={'skip_prepare_dataset': True},
    remove_unused_columns=False
)
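
The batch size and gradient accumulation settings combine into an effective batch size per optimizer step. The sketch below works this out for the values above, assuming a single GPU; the dataset size of 1000 is a hypothetical number, and the step count is an approximation since the exact rounding depends on the trainer implementation.

```python
import math


def effective_batch_size(per_device_batch_size, grad_accum_steps, num_devices=1):
    """Number of samples contributing to each optimizer step."""
    return per_device_batch_size * grad_accum_steps * num_devices


def approx_optimizer_steps(num_samples, batch_size, num_epochs=1):
    """Approximate optimizer steps; exact rounding depends on the trainer."""
    return math.ceil(num_samples / batch_size) * num_epochs


# Values from the SFTConfig above; one GPU is an assumption
batch = effective_batch_size(per_device_batch_size=4, grad_accum_steps=8)

# Hypothetical dataset size, for illustration only
steps = approx_optimizer_steps(num_samples=1000, batch_size=batch, num_epochs=3)

print(f"effective batch size: {batch}, approx. optimizer steps: {steps}")
```

This helps when choosing `logging_steps`, `eval_steps`, and `save_steps`, which are all counted in optimizer steps.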

### Train the model

In [None]:
import wandb

wandb.init(
    project='qwen2-7b-instruct-trl-sft-ChartQA',
    name='qwen2-7b-instruct-trl-sft-ChartQA',
    config=training_args
)

We also need a collator function to properly retrieve and batch the data during the training procedure. This function will handle the formatting of our dataset inputs, ensuring they are correctly structured for the model.

In [None]:
def collate_fn(examples):
    texts = [
        processor.apply_chat_template(example, tokenize=False)
        for example in examples
    ]
    image_inputs = [
        process_vision_info(example)[0] for example in examples
    ]

    # Tokenize the texts and process the images
    batch = processor(
        text=texts,
        images=image_inputs,
        padding=True,
        return_tensors='pt'
    )

    # The labels are the input_ids, and we mask the padding tokens in the loss computation
    labels = batch['input_ids'].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100

    # Ignore the image token index in the loss computation (model specific)
    if isinstance(processor, Qwen2VLProcessor):
        image_tokens = [151652, 151653, 151655]  # Specific image token IDs for Qwen2VLProcessor
    else:
        image_tokens = [processor.tokenizer.convert_tokens_to_ids(processor.image_token)]  # Convert image token to ID

    # Mask image token IDs in the labels
    for image_token_id in image_tokens:
        labels[labels == image_token_id] = -100

    batch['labels'] = labels # Add labels to the batch

    return batch
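
The masking logic in the collator can be illustrated without tensors. The sketch below applies the same rule to plain Python lists: padding and image-related token IDs are replaced with `-100`, the label value that PyTorch's cross-entropy loss ignores by default. The pad ID of `0` and the short text token IDs are made up; the three image-related IDs are the ones used in the collator above.

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss by default


def mask_labels(input_ids, ignored_token_ids):
    """Copy input_ids into labels, replacing ignored tokens with IGNORE_INDEX."""
    return [
        [IGNORE_INDEX if tok in ignored_token_ids else tok for tok in seq]
        for seq in input_ids
    ]


PAD_ID = 0  # hypothetical pad token ID, for illustration only
image_token_ids = {151652, 151653, 151655}  # IDs used in the collator above

# Made-up batch: image tokens first, then text tokens, then padding
batch_input_ids = [
    [151652, 151655, 151655, 151653, 11, 12, 13],
    [151652, 151655, 151653, 21, 22, PAD_ID, PAD_ID],
]

labels = mask_labels(batch_input_ids, image_token_ids | {PAD_ID})
print(labels)
```

Only the text token positions keep their IDs, so only they contribute to the loss.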

We now define the `SFTTrainer`.

In [None]:
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=collate_fn,
    peft_config=peft_config,
    tokenizer=processor.tokenizer
)

In [None]:
trainer.train()

In [None]:
trainer.save_model(training_args.output_dir)

## Test the finetuned model

In [None]:
clear_memory()

In [None]:
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

processor = Qwen2VLProcessor.from_pretrained(model_id)

In [None]:
adapter_path = "sergiopaniego/qwen2-7b-instruct-trl-sft-ChartQA"
model.load_adapter(adapter_path)

In [None]:
sample = train_dataset[0]
sample[:2]

In [None]:
sample[1]['content'][0]['image']

In [None]:
output = generate_text_from_sample(model, processor, sample)
output

In [None]:
sample = test_dataset[0]
sample[:2]

In [None]:
sample[1]['content'][0]['image']

In [None]:
output = generate_text_from_sample(model, processor, sample)
output

## Compare finetuned model with base model + prompting

In [None]:
clear_memory()

In [None]:
# base model
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

processor = Qwen2VLProcessor.from_pretrained(model_id)

In [None]:
sample = train_dataset[0]
sample[:2]

In [None]:
text = processor.apply_chat_template(
    sample[:2], # we pass the system message to the base model this time
    tokenize=False,
    add_generation_prompt=True
)

image_inputs, _ = process_vision_info(train_dataset[0])

inputs = processor(
    text=[text],
    images=image_inputs,
    return_tensors="pt",
)

inputs = inputs.to("cuda")

generated_ids = model.generate(**inputs, max_new_tokens=1024)
generated_ids_trimmed = [
    out_ids[len(in_ids) :]
    for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]

output_text = processor.batch_decode(
    generated_ids_trimmed,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False
)

output_text[0]