# How to Fine-Tune Multimodal Models (VLMs) with Hugging Face TRL

Vision-Language Models (VLMs) can handle a variety of multimodal tasks, including image captioning, visual question answering, and image-text matching, without additional training. However, to customize a model for a specific application, we may need to fine-tune it on our own data, either to achieve higher-quality results or to create a more efficient model for our use case.

## Define our multimodal use case

For most use cases, fine-tuning should not be the first option. We should evaluate pretrained models or API-based solutions before committing to fine-tuning our own model.

In this example, we want to fine-tune a model that can generate detailed product descriptions based on product images and basic metadata. This model will be integrated into our e-commerce platform to help sellers create more compelling listings. The goal is to reduce the time it takes to create product descriptions and improve their quality and consistency.

Existing models may already perform well on this use case, but we may want to tune one to our specific needs. This image-to-text generation task is well-suited for fine-tuning VLMs, as it requires understanding visual features and combining them with textual information to produce coherent and relevant descriptions.

We will use the [`amazon-product-descriptions-vlm`](https://huggingface.co/datasets/philschmid/amazon-product-descriptions-vlm) dataset for fine-tuning.

## Setup

In [None]:
# Install PyTorch & other libraries
%pip install "torch==2.4.0" tensorboard pillow

# Install Hugging Face libraries
%pip install  --upgrade \
  "transformers==4.45.1" \
  "datasets==3.0.1" \
  "accelerate==0.34.2" \
  "evaluate==0.4.3" \
  "bitsandbytes==0.44.0" \
  "trl==0.11.1" \
  "peft==0.13.0" \
  "qwen-vl-utils"

In [None]:
from huggingface_hub import notebook_login
notebook_login()

## Create and prepare dataset

The [`amazon-product-descriptions-vlm`](https://huggingface.co/datasets/philschmid/amazon-product-descriptions-vlm) dataset contains 1,350 Amazon products with titles, images, descriptions, and metadata. We want to fine-tune our model to generate product descriptions based on the images, title, and metadata. Therefore, we need to create a prompt that includes the title, metadata, and image; the completion is the product description from the dataset.

A typical conversational dataset format would look like:
```python
messages = [
    {'role': 'system', 'content': [{'type': 'text', 'text': 'You are a helpful..'}]},
    {'role': 'user', 'content': [
        {'type': 'text', 'text': 'How many dogs are in this image?'},
        {'type': 'image', 'image': <PIL.Image>}
    ]},
    {'role': 'assistant', 'content': [{'type': 'text', 'text': 'There are 3 dogs in the image.'}]}
]
```
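
As a quick sanity check on this structure, the following minimal sketch walks a `messages` list and counts its text and image parts. The helper `count_parts` and the `fake_image` placeholder are hypothetical names introduced here for illustration; a real sample would carry a `PIL.Image` instead.

```python
def count_parts(messages, part_type):
    """Count content parts of a given type ('text' or 'image') across all turns."""
    return sum(
        1
        for message in messages
        for part in message["content"]
        if part["type"] == part_type
    )


# Placeholder object standing in for a real PIL.Image (assumption for illustration).
fake_image = object()

example_messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
    {"role": "user", "content": [
        {"type": "text", "text": "How many dogs are in this image?"},
        {"type": "image", "image": fake_image},
    ]},
    {"role": "assistant", "content": [{"type": "text", "text": "There are 3 dogs in the image."}]},
]

roles = [message["role"] for message in example_messages]
n_images = count_parts(example_messages, "image")
n_texts = count_parts(example_messages, "text")
print(roles, n_images, n_texts)
```

A check like this can catch samples that are missing their image before they reach the processor.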

In [None]:
# Note: the image is not provided in the prompt; it is included as part of the `processor` inputs
prompt = """Create a short product description based on the provided ##PRODUCT NAME## and ##CATEGORY## and image.
Only return description. The description should be SEO optimized and for a better mobile search experience.

##PRODUCT NAME##: {product_name}
##CATEGORY##: {category}"""

system_message = "You are an expert product description writer for Amazon."

We also need to format our dataset to the conversational format:

In [None]:
from datasets import load_dataset

# convert dataset to OpenAI messages
def format_data(sample):
    return {
        "messages": [
            {
                'role': 'system',
                'content': [{'type': 'text', 'text': system_message}]
            },
            {
                'role': 'user',
                'content': [
                    {
                        'type': 'text',
                        'text': prompt.format(product_name=sample['Product Name'], category=sample['Category'])
                    },
                    {
                        'type': 'image',
                        'image': sample['image']
                    }
                ]
            },
            {
                'role': 'assistant',
                'content': [{'type': 'text', 'text': sample['description']}]
            }
        ]
    }


# load dataset
dataset_id = 'philschmid/amazon-product-descriptions-vlm'
dataset = load_dataset(
    dataset_id,
    split='train'
)

# Convert format
# Need to use a list comprehension to keep the PIL.Image type; `.map()` converts images to bytes
dataset = [format_data(sample) for sample in dataset]

In [None]:
dataset

In [None]:
dataset[0]

In [None]:
dataset[0]['messages']

## Fine-tune VLM using `trl` and the `SFTTrainer`

The `SFTTrainer` from `trl` makes it straightforward to run supervised fine-tuning on open LLMs and VLMs.

We will also use Q-LoRA, which uses quantization to reduce the memory footprint during fine-tuning without sacrificing much performance.

We will use the `Qwen2-VL-7B-Instruct` model for fine-tuning in this example, but other models can be used as well.
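
To get a rough feel for why 4-bit quantization matters at this model size, here is a back-of-the-envelope sketch of the memory needed for the weights alone. It assumes a round 7 billion parameters and ignores quantization constants, activations, optimizer state, the LoRA adapters, and any modules left unquantized, so treat the numbers as an order-of-magnitude estimate only. The helper `weight_memory_gib` is a hypothetical name.

```python
def weight_memory_gib(num_params, bits_per_param):
    """Approximate memory for the weights alone, in GiB."""
    return num_params * bits_per_param / 8 / 1024**3


# Assumption: a round 7e9 parameters; the real checkpoint differs somewhat.
num_params = 7_000_000_000

bf16_gib = weight_memory_gib(num_params, 16)  # 16 bits per weight
nf4_gib = weight_memory_gib(num_params, 4)    # 4 bits per weight

print(f"bf16 weights:  ~{bf16_gib:.1f} GiB")
print(f"4-bit weights: ~{nf4_gib:.1f} GiB")
```

Under these assumptions the 4-bit weights take about a quarter of the bf16 footprint, which is what leaves room for activations and adapter gradients on a single GPU.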

In [None]:
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq, BitsAndBytesConfig

model_id = 'Qwen/Qwen2-VL-7B-Instruct'

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.bfloat16
)

# load model and processor
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    device_map='auto',
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config
)

In [None]:
# Preview the chat-templated text for the first training example
text = processor.apply_chat_template(
    dataset[0]['messages'],
    tokenize=False,
    add_generation_prompt=False
)
text

In [None]:
from peft import LoraConfig

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=8,
    bias='none',
    target_modules=['q_proj', 'v_proj'],
    task_type='CAUSAL_LM'
)
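
To see how small the trainable adapter is compared with the base model, the sketch below counts the parameters LoRA adds for a given rank. Each adapted linear layer gains two matrices, `A` of shape `(r, in_features)` and `B` of shape `(out_features, r)`. The layer dimensions and layer count used here are illustrative assumptions, not values read from the actual model config; in practice, check `model.config` or call `print_trainable_parameters()` on the PEFT model.

```python
def lora_param_count(in_features, out_features, r):
    """Parameters added by one LoRA adapter: A is (r x in), B is (out x r)."""
    return r * in_features + out_features * r


# Assumed, illustrative dimensions (not read from the real model config).
hidden_size = 3584  # assumed input width of q_proj and v_proj
q_out = 3584        # assumed q_proj output width
v_out = 512         # assumed v_proj output width (grouped-query attention)
num_layers = 28     # assumed number of decoder layers
r = 8               # same rank as in the LoraConfig above

per_layer = lora_param_count(hidden_size, q_out, r) + lora_param_count(hidden_size, v_out, r)
total = per_layer * num_layers

print(f"LoRA parameters per layer: {per_layer:,}")
print(f"LoRA parameters in total:  {total:,}")
```

With these assumed shapes the adapter adds only a few million parameters, a tiny fraction of a 7B-parameter model.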

Next, we need to create a custom data collator that formats the inputs correctly and includes the image features. We will use the `process_vision_info` function from a utility package provided by the Qwen2 team. If we want to use other models, we need to check whether this approach works for them.

In [None]:
from trl import SFTConfig
from transformers import Qwen2VLProcessor
from qwen_vl_utils import process_vision_info


training_args = SFTConfig(
    output_dir='qwen2-7b-instruct-amazon-description',
    num_train_epochs=3,
    per_device_train_batch_size=4, # batch size per device
    gradient_accumulation_steps=8, # number of micro-batches to accumulate before each optimizer step
    gradient_checkpointing=True,
    optim='adamw_torch_fused',
    logging_steps=5,
    save_strategy='epoch',
    learning_rate=2e-4,
    bf16=True,
    tf32=True,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type='constant',
    push_to_hub=False,
    report_to='tensorboard',
    gradient_checkpointing_kwargs={'use_reentrant': False},  # use non-reentrant checkpointing
    dataset_text_field="",  # need a dummy field for collator
    dataset_kwargs={'skip_prepare_dataset': True},
    remove_unused_columns=False
)
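
The batch-size settings above interact, so it can help to work out the effective batch size and an approximate number of optimizer updates. The sketch below is plain arithmetic under the assumption of single-GPU training; the exact step count reported by the trainer may differ slightly depending on how the last partial accumulation window of each epoch is handled.

```python
import math

num_samples = 1350               # dataset size quoted earlier
per_device_batch_size = 4
gradient_accumulation_steps = 8
num_devices = 1                  # assumption: single-GPU training
num_train_epochs = 3

# Samples that contribute to each optimizer update.
effective_batch_size = per_device_batch_size * gradient_accumulation_steps * num_devices

# Rough estimate of optimizer updates, rounding the last partial batch up.
updates_per_epoch = math.ceil(num_samples / effective_batch_size)
total_updates = updates_per_epoch * num_train_epochs

print(f"effective batch size: {effective_batch_size}")
print(f"approx. updates per epoch: {updates_per_epoch}")
print(f"approx. total updates: {total_updates}")
```

With `logging_steps=5`, this estimate suggests a couple of dozen log entries over the whole run.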

In [None]:
# create a data collator to encode text and image pairs
def collate_fn(examples):
    # get the texts and images, and apply the chat template
    texts = [
        processor.apply_chat_template(example['messages'], tokenize=False)
        for example in examples
    ]
    image_inputs = [
        process_vision_info(example['messages'])[0]
        for example in examples
    ]

    # tokenize the texts and process the images
    batch = processor(
        text=texts,
        images=image_inputs,
        padding=True,
        return_tensors='pt'
    )

    # The labels are the input_ids,
    # and we need to mask the padding tokens in the loss computation
    labels = batch['input_ids'].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100

    # ignore the image token index in the loss computation (model specific)
    if isinstance(processor, Qwen2VLProcessor):
        # Qwen2-VL special token ids: <|vision_start|>, <|vision_end|>, <|image_pad|>
        image_tokens = [151652, 151653, 151655]
    else:
        image_tokens = [processor.tokenizer.convert_tokens_to_ids(processor.image_token)]

    for image_token_id in image_tokens:
        labels[labels == image_token_id] = -100

    batch['labels'] = labels

    return batch
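
To make the label-masking step concrete, here is a minimal sketch of the same idea using plain Python lists instead of tensors. The token IDs are made up for illustration, and `mask_labels` is a hypothetical helper, not part of `trl` or `transformers`. Positions set to `-100` are skipped by PyTorch's cross-entropy loss by default, so padding and image placeholder tokens do not contribute to the training signal.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default


def mask_labels(input_ids, pad_token_id, special_token_ids):
    """Copy input_ids into labels, replacing padding and special tokens with IGNORE_INDEX."""
    ignored = {pad_token_id, *special_token_ids}
    return [
        [IGNORE_INDEX if token_id in ignored else token_id for token_id in row]
        for row in input_ids
    ]


# Hypothetical IDs for illustration only: 0 = padding, 900 = image placeholder.
toy_batch = [
    [11, 900, 900, 12, 13],
    [21, 900, 22, 0, 0],
]
toy_labels = mask_labels(toy_batch, pad_token_id=0, special_token_ids=[900])
print(toy_labels)
# [[11, -100, -100, 12, 13], [21, -100, 22, -100, -100]]
```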

In [None]:
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collate_fn,
    dataset_text_field="", # need dummy value
    peft_config=peft_config,
    tokenizer=processor.tokenizer
)

In [None]:
trainer.train()
trainer.save_model(training_args.output_dir)

In [None]:
del model
del trainer
torch.cuda.empty_cache()

## Test model and run inference

We will first load the base model and let it generate a description for a random Amazon product. Then we will load our Q-LoRA adapter and let the adapted model generate a description for the same product.

Finally, we can merge the adapter into the base model to make it more efficient and run inference on the same product again.

In [None]:
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq

adapter_path = './qwen2-7b-instruct-amazon-description'

# load model
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    device_map='auto',
    torch_dtype=torch.float16
)

We select a random product from Amazon and prepare a `generate_description` function to generate a description for the product.

In [None]:
from qwen_vl_utils import process_vision_info

# sample from amazon
sample = {
  "product_name": "Hasbro Marvel Avengers-Serie Marvel Assemble Titan-Held, Iron Man, 30,5 cm Actionfigur",
  "category": "Toys & Games | Toy Figures & Playsets | Action Figures",
  "image": "https://m.media-amazon.com/images/I/81+7Up7IWyL._AC_SY300_SX300_.jpg"
}

# prepare message
messages = [{
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": sample["image"],
            },
            {"type": "text", "text": prompt.format(product_name=sample["product_name"], category=sample["category"])},
        ],
    }
]

In [None]:
def generate_description(sample, model, processor):
    messages = [
        {'role': 'system', 'content': [{'type': 'text', 'text': system_message}]},
        {'role': 'user', 'content': [
            {'type': 'text', 'text': prompt.format(product_name=sample['product_name'], category=sample['category'])},
            {'type': 'image', 'image': sample['image']}
        ]}
    ]

    # prepare for inference
    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        padding=True,
        return_tensors='pt'
    ).to(model.device)

    # Inference: generation of the output
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=256,
        top_p=1.0,
        do_sample=True,
        temperature=0.8
    )
    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs['input_ids'], generated_ids)]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
    )
    return output_text[0]
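
The trimming step in `generate_description` is needed because, for decoder-only models, `generate` returns the prompt tokens followed by the newly generated ones. The following minimal sketch reproduces that slicing on plain lists; the token IDs are hypothetical and `trim_prompt` is a helper name introduced only for this illustration.

```python
def trim_prompt(input_ids, generated_ids):
    """Drop the leading prompt tokens from each generated sequence."""
    return [out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids, generated_ids)]


# Hypothetical token IDs: the first three tokens of the output repeat the prompt.
toy_inputs = [[101, 102, 103]]
toy_outputs = [[101, 102, 103, 7, 8, 9]]

new_tokens = trim_prompt(toy_inputs, toy_outputs)
print(new_tokens)
# [[7, 8, 9]]
```

Without this step, the decoded text would start with the system message and prompt rather than the description itself.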

In [None]:
base_description = generate_description(sample, model, processor)
base_description

Now we load our adapter:

In [None]:
model.load_adapter(adapter_path)

ft_description = generate_description(sample, model, processor)
ft_description

In [None]:
import pandas as pd
from IPython.display import display, HTML

def compare_generations(base_gen, ft_gen):
    df = pd.DataFrame({
        'Base Generation': [base_gen],
        'Fine-tuned Generation': [ft_gen]
    })

    styled_df = df.style.set_properties(**{
        'text-align': 'left',
        'white-space': 'pre-wrap',
        'border': '1px solid black',
        'padding': '10px',
        'width': '250px',
        'overflow-wrap': 'break-word'
    })

    display(HTML(styled_df.to_html()))


compare_generations(base_description, ft_description)

## Merge LoRA adapter into the original model

When using Q-LoRA, we only train adapters and not the full model. This means that when we save the model during training, only the adapter weights are saved, not the full model. If we want to save the full model, which makes it easier to use with text generation inference, we can merge the adapter weights into the model weights using the `merge_and_unload` method and then save the model with the `save_pretrained` method.
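
Conceptually, merging folds the low-rank update into each adapted weight matrix: `W_merged = W + (lora_alpha / r) * B @ A`. The sketch below shows that arithmetic on a toy 2x2 matrix with nested lists. It is only an illustration of the standard LoRA formula with made-up values, not PEFT's actual implementation, which also has to deal with dtypes, devices, and layer types; `matmul` and `merge_lora` are hypothetical helpers.

```python
def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(weight, lora_a, lora_b, alpha, r):
    """Return W + (alpha / r) * B @ A as a new nested list."""
    scale = alpha / r
    delta = matmul(lora_b, lora_a)
    return [
        [w + scale * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(weight, delta)
    ]


# Toy 2x2 weight with a rank-1 adapter; all values are made up for illustration.
weight = [[1.0, 0.0], [0.0, 1.0]]
lora_a = [[1.0, 2.0]]      # shape (r, in_features) = (1, 2)
lora_b = [[0.5], [0.25]]   # shape (out_features, r) = (2, 1)

merged = merge_lora(weight, lora_a, lora_b, alpha=2, r=1)
print(merged)
# [[2.0, 2.0], [0.5, 2.0]]
```

After merging, the model has the same shape as the base model, so no adapter code path is needed at inference time.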

In [None]:
from peft import PeftModel
from transformers import AutoProcessor, AutoModelForVision2Seq

adapter_path = "./qwen2-7b-instruct-amazon-description"
base_model_id = "Qwen/Qwen2-VL-7B-Instruct"
merged_path = "merged"

# load base model
processor = AutoProcessor.from_pretrained(base_model_id)
model = AutoModelForVision2Seq.from_pretrained(
    base_model_id,
    low_cpu_mem_usage=True
)

# merge lora
peft_model = PeftModel.from_pretrained(
    model,
    adapter_path
)
merged_model = peft_model.merge_and_unload()
# save merged model
merged_model.save_pretrained(
    merged_path,
    safe_serialization=True,
    max_shard_size='2GB'
)

# don't forget to save processor to the merged model path
processor.save_pretrained(merged_path)