# Fine-tuning SmolVLM with TRL on a consumer GPU

Notebook based on: https://huggingface.co/learn/cookbook/fine_tuning_smol_vlm_sft_trl

Resources provided by Scientific Computing, University of Leipzig.

## 1. Install Dependencies

Let’s start by installing the essential libraries we’ll need for fine-tuning! 🚀


In [2]:
!pip install  -U -q transformers trl datasets bitsandbytes peft accelerate
# Tested with transformers==4.46.3, trl==0.12.1, datasets==3.1.0, bitsandbytes==0.45.0, peft==0.13.2, accelerate==1.1.1
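
Because the install command uses `-U`, the versions that end up installed may be newer than the tested ones listed in the comment above. The following is a minimal sketch for comparing the two, using only the standard library; the helper name `installed_versions` and the `TESTED_VERSIONS` dictionary are illustrative and not part of the original notebook.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions quoted in the install cell above.
TESTED_VERSIONS = {
    "transformers": "4.46.3",
    "trl": "0.12.1",
    "datasets": "3.1.0",
    "bitsandbytes": "0.45.0",
    "peft": "0.13.2",
    "accelerate": "1.1.1",
}


def installed_versions(packages):
    """Map each package name to its installed version, or None if it is missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, current in installed_versions(TESTED_VERSIONS).items():
    tested = TESTED_VERSIONS[name]
    note = "" if current == tested else "  <- differs from tested version"
    print(f"{name}: installed={current}, tested={tested}{note}")
```

A mismatch is not necessarily a problem, but it is a useful first thing to check if a later cell behaves differently from what is described here.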

In [3]:
!pip install flash-attn --no-build-isolation

Defaulting to user installation because normal site-packages is not writeable
Collecting flash-attn
  Using cached flash_attn-2.7.4.post1-cp312-cp312-linux_x86_64.whl
Collecting einops (from flash-attn)
  Using cached einops-0.8.1-py3-none-any.whl.metadata (13 kB)
Using cached einops-0.8.1-py3-none-any.whl (64 kB)
Installing collected packages: einops, flash-attn
Successfully installed einops-0.8.1 flash-attn-2.7.4.post1


Authenticate with your Hugging Face account to save and share your model directly from this notebook 🗝️.

In [4]:
import huggingface_hub
huggingface_hub.login()

## 2. Load Dataset 📁

We’ll create a system message to make the VLM act as an expert in detecting alchemic objects, giving concise JSON answers about historic book illustrations.

In [5]:
system_message = """You are a Vision Language Model specialized in interpreting alchemic objects in historic book illustrations.
Your task is to analyze the provided book page and respond in JSON format only. In total there are 12 classes of alchemic objects you should detect.
The classes are: ampullae, animal, cucurbitae, cucurbitae-ambix, ollae, cucurbitae-retorte, cucurbitae-rosenhut, furnace, human, mineral-metal, other-equipment, plant.
Respond in a well-structured JSON format in which you always name all of the 12 classes and then their number of occurences in the provided image.
An output should look like this, for example: 
{
  "ampullae": 0,
  "animal": 0,
  "cucurbitae": 0,
  "cucurbitae-ambix": 0,
  "cucurbitae-retorte": 0,
  "cucurbitae-rosenhut": 0,
  "furnace": 0,
  "human": 0,
  "mineral-metal": 0,
  "other-equipment": 0,
  "plant": 0,
  "ollae": 0
}
Focus on delivering accurate, succinct answers based on the visual information. 
Don't add any additional explanation."""

In [6]:
query = """Which alchemic objects are in the image?
Provide an answer in a well-structured JSON format with 12 classes (ampullae, animal, cucurbitae, cucurbitae-ambix,
ollae, cucurbitae-retorte, cucurbitae-rosenhut, furnace, human, mineral-metal,
other-equipment, plant) and the number of their occurences."""
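
Since both prompts ask for a fixed JSON structure, it can help to check model answers against that structure programmatically. The sketch below is one possible validator, not part of the original notebook: the name `parse_counts` and the rule that counts must be non-negative integers are assumptions made for illustration.

```python
import json

# The 12 classes named in the system message and the query.
CLASSES = [
    "ampullae", "animal", "cucurbitae", "cucurbitae-ambix", "ollae",
    "cucurbitae-retorte", "cucurbitae-rosenhut", "furnace", "human",
    "mineral-metal", "other-equipment", "plant",
]


def parse_counts(text):
    """Return a {class: count} dict if `text` is a valid answer, else None."""
    try:
        data = json.loads(text)
    except (json.JSONDecodeError, TypeError):
        return None
    # The answer must be an object that names exactly the 12 classes.
    if not isinstance(data, dict) or set(data) != set(CLASSES):
        return None
    # ASSUMPTION: counts are non-negative integers (bool is excluded explicitly,
    # because bool is a subclass of int in Python).
    for value in data.values():
        if isinstance(value, bool) or not isinstance(value, int) or value < 0:
            return None
    return data


example = json.dumps({name: 0 for name in CLASSES})
print(parse_counts(example) is not None)         # a well-formed answer
print(parse_counts("The image shows a furnace."))  # free text is rejected
```

Returning `None` for anything malformed keeps the check simple: an answer either matches the requested format or it does not.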

We’ll format the dataset into a chatbot structure, with the system message, image, user query, and answer for each interaction.

💡For more tips on using this model, check out the [Model Card](https://huggingface.co/HuggingFaceTB/SmolVLM-Instruct).

In [7]:
def format_data(sample):
    return [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": system_message
                }
            ],
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": sample["image"],
                },
                {
                    "type": "text",
                    "text": query,
                }
            ],
        },
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": sample["labels"]
                }
            ],
        },
    ]

In [8]:
from datasets import load_dataset

dataset_id = "maychnix/AlchObj"
train_dataset, eval_dataset, test_dataset = load_dataset(dataset_id, split=['train', 'valid', 'test'])

Let’s take a look at the dataset structure. Each sample includes an `image` and its `labels` (the expected answer), plus two further features, `idx` and `source`, that we won’t use.

In [9]:
train_dataset

Dataset({
    features: ['image', 'labels', 'idx', 'source'],
    num_rows: 434
})

Now, let’s format the data using the chatbot structure. This will set up the interactions for the model.

In [10]:
train_dataset = [format_data(sample) for sample in train_dataset]
eval_dataset = [format_data(sample) for sample in eval_dataset]
test_dataset = [format_data(sample) for sample in test_dataset]

In [11]:
len(train_dataset)

434

In [12]:
len(eval_dataset)

93

In [13]:
train_dataset[12]

[{'role': 'system',
  'content': [{'type': 'text',
    'text': 'You are a Vision Language Model specialized in interpreting alchemic objects in historic book illustrations.\nYour task is to analyze the provided book page and respond in JSON format only. In total there are 12 classes of alchemic objects you should detect.\nThe classes are: ampullae, animal, cucurbitae, cucurbitae-ambix, ollae, cucurbitae-retorte, cucurbitae-rosenhut, furnace, human, mineral-metal, other-equipment, plant.\nRespond in a well-structured JSON format in which you always name all of the 12 classes and then their number of occurences in the provided image.\nAn output should look like this, for example: \n{\n  "ampullae": 0,\n  "animal": 0,\n  "cucurbitae": 0,\n  "cucurbitae-ambix": 0,\n  "cucurbitae-retorte": 0,\n  "cucurbitae-rosenhut": 0,\n  "furnace": 0,\n  "human": 0,\n  "mineral-metal": 0,\n  "other-equipment": 0,\n  "plant": 0,\n  "ollae": 0\n}\nFocus on delivering accurate, succinct answers based on the

## 3. Load Model and Check Performance! 🤔


In [14]:
import torch
from transformers import Idefics3ForConditionalGeneration, AutoProcessor

model_id = "HuggingFaceTB/SmolVLM-Instruct"

RuntimeError: Failed to import transformers.models.idefics3.modeling_idefics3 because of the following error (look up to see its traceback):
/home/sc.uni-leipzig.de/cm77xisu/.local/lib/python3.12/site-packages/flash_attn_2_cuda.cpython-312-x86_64-linux-gnu.so: undefined symbol: _ZNK3c1011StorageImpl27throw_data_ptr_access_errorEv
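
An `undefined symbol` error like the one above usually indicates that the compiled flash-attn extension was built against a different PyTorch build than the one currently installed. The sketch below is a small diagnostic, not part of the original notebook: `flash_attn_status` and the `attn_implementation` variable are hypothetical names, and falling back to `"eager"` is an assumption about what is acceptable for your setup.

```python
import importlib
import importlib.util


def flash_attn_status():
    """Return "missing", "broken" or "ok" for the local flash-attn install."""
    if importlib.util.find_spec("flash_attn") is None:
        return "missing"
    try:
        # A binary mismatch with PyTorch surfaces here as an ImportError.
        importlib.import_module("flash_attn")
    except Exception:
        return "broken"
    return "ok"


status = flash_attn_status()
# Fall back to the default "eager" attention unless flash-attn imports cleanly.
attn_implementation = "flash_attention_2" if status == "ok" else "eager"
print(status, attn_implementation)
```

If the status is `"broken"`, this check only diagnoses the problem: as the traceback shows, the transformers import itself fails while the broken package is present, so flash-attn most likely has to be reinstalled against the installed PyTorch version or uninstalled before the cells below will run.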

Next, we’ll load the model and the processor to prepare for inference.

In [None]:
model = Idefics3ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    _attn_implementation="flash_attention_2"
)

processor = AutoProcessor.from_pretrained(model_id)

Let’s create a function that takes the model, processor, and sample as inputs and returns the model's answer. This streamlines inference and makes it easy to evaluate the VLM's performance.

In [None]:
def generate_text_from_sample(model, processor, sample, max_new_tokens=256, device="cuda"):
    # Prepare the text input by applying the chat template
    text_input = processor.apply_chat_template(
        sample[1:2],  # Use only the user turn (no system message, no reference answer)
        add_generation_prompt=True
    )

    image_inputs = []
    image = sample[1]['content'][0]['image']
    if image.mode != 'RGB':
        image = image.convert('RGB')
    image_inputs.append([image])

    # Prepare the inputs for the model
    model_inputs = processor(
        text=text_input,
        images=image_inputs,
        return_tensors="pt",
    ).to(device)  # Move inputs to the specified device

    # Generate text with the model
    generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)

    # Trim the generated ids to remove the input ids
    trimmed_generated_ids = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    # Decode the output text
    output_text = processor.batch_decode(
        trimmed_generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
    )

    return output_text[0]  # Return the first decoded output text

In [None]:
output = generate_text_from_sample(model, processor, train_dataset[42])
output

It seems the model can neither output clean JSON nor accurately understand the task. We can fine-tune it on more relevant data so that it better understands the context and gives more accurate responses.

**Remove Model and Clean GPU**

Before we proceed with training the model in the next section, let's clear the current variables and clean the GPU to free up resources.



In [None]:
import gc
import time

def clear_memory():
    # Delete variables if they exist in the current global scope
    if 'inputs' in globals(): del globals()['inputs']
    if 'model' in globals(): del globals()['model']
    if 'processor' in globals(): del globals()['processor']
    if 'trainer' in globals(): del globals()['trainer']
    if 'peft_model' in globals(): del globals()['peft_model']
    if 'bnb_config' in globals(): del globals()['bnb_config']
    time.sleep(2)

    # Garbage collection and clearing CUDA memory
    gc.collect()
    time.sleep(2)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    time.sleep(2)
    gc.collect()
    time.sleep(2)

    print(f"GPU allocated memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"GPU reserved memory: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

clear_memory()

## 4. Fine-Tune the Model using TRL


### 4.1 Load the Quantized Model for Training ⚙️

Next, we’ll load the quantized model using [bitsandbytes](https://huggingface.co/docs/bitsandbytes/main/en/index). If you want to learn more about quantization, check out [this blog post](https://huggingface.co/blog/merve/quantization) or [this one](https://www.maartengrootendorst.com/blog/quantization/).


In [None]:
from transformers import BitsAndBytesConfig

# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# Load the quantized model and the processor
model = Idefics3ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
    _attn_implementation="flash_attention_2",
)
processor = AutoProcessor.from_pretrained(model_id)
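
To get a feeling for what 4-bit loading saves, the weight memory can be estimated with simple arithmetic. This is a rough sketch only: the parameter count of about 2 billion is an assumption (check the model card for the exact figure), and the estimate ignores quantization constants, modules that stay unquantized, activations, and optimizer state.

```python
def weight_memory_gib(num_params, bits_per_param):
    """Approximate memory needed to store the weights alone, in GiB."""
    return num_params * bits_per_param / 8 / 1024**3


# ASSUMPTION: roughly 2 billion parameters; not taken from the notebook.
approx_params = 2_000_000_000

for label, bits in [("bf16", 16), ("nf4", 4)]:
    print(f"{label}: ~{weight_memory_gib(approx_params, bits):.2f} GiB")
```

For the real number on a loaded model, `model.get_memory_footprint()` from transformers reports the footprint in bytes.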

### 4.2 Set Up QLoRA and SFTConfig 🚀

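Next, we’ll configure [QLoRA](https://github.com/artidoro/qlora) for our training setup. QLoRA allows efficient fine-tuning of large models by reducing the memory footprint. It combines LoRA, which trains small low-rank adapter matrices instead of the full weights, with a quantized base model: the frozen base weights are stored in 4-bit precision, while the adapter weights stay in higher precision and are the only parameters that get updated. This leads to a much lower memory footprint than LoRA on a full-precision model.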

To reduce memory further, QLoRA setups can also use a **paged optimizer** or an **8-bit optimizer**, which offload or shrink the optimizer state. This notebook keeps the `adamw_torch_fused` optimizer in the training arguments.

In [None]:
from peft import LoraConfig, get_peft_model

# Configure LoRA
peft_config = LoraConfig(
    r=8,
    lora_alpha=8,
    lora_dropout=0.1,
    target_modules=['down_proj','o_proj','k_proj','q_proj','gate_proj','up_proj','v_proj'],
    use_dora=True,
    init_lora_weights="gaussian"
)

# Apply PEFT model adaptation
peft_model = get_peft_model(model, peft_config)

# Print trainable parameters
peft_model.print_trainable_parameters()
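
To make the output of `print_trainable_parameters()` easier to interpret, the sketch below counts what LoRA adds to a single linear layer. The layer size of 2048 x 2048 is hypothetical and not read from the model, and the DoRA term assumes one magnitude value per output feature.

```python
def lora_param_count(d_in, d_out, r, use_dora=False):
    """Trainable parameters added to one linear layer of shape (d_out, d_in)."""
    count = r * d_in + d_out * r  # A is (r, d_in), B is (d_out, r)
    if use_dora:
        # ASSUMPTION: DoRA adds a magnitude vector with one entry per output feature.
        count += d_out
    return count


r, lora_alpha = 8, 8          # values from the LoraConfig above
scaling = lora_alpha / r      # factor applied to the low-rank update

d_in = d_out = 2048           # HYPOTHETICAL layer size, for illustration only
full = d_in * d_out
added = lora_param_count(d_in, d_out, r, use_dora=True)

print(f"scaling factor: {scaling}")
print(f"full layer: {full} params, adapter: {added} params ({added / full:.2%})")
```

With `r=8` the adapter is a small fraction of each targeted layer, which is why only a small share of the model's parameters is reported as trainable.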

In [None]:
from trl import SFTConfig

# Configure training arguments using SFTConfig
training_args = SFTConfig(
    output_dir="smolvlm-instruct-trl-sft-AlchObj",
    overwrite_output_dir=True, # added (ran notebook multiple times)
    torch_empty_cache_steps=25, # added (oom error after ~70 steps)
    num_train_epochs=12,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    warmup_steps=50,
    learning_rate=1e-4,
    weight_decay=0.01,
    logging_steps=25,
    save_strategy="steps",
    save_steps=25,
    save_total_limit=1,
    optim="adamw_torch_fused",
    bf16=True,
    push_to_hub=True,
    report_to="tensorboard",
    remove_unused_columns=False,
    gradient_checkpointing=True,
    dataset_text_field="",
    dataset_kwargs={"skip_prepare_dataset": True},
)
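
The batch and schedule settings above interact, so it is worth working out roughly how many optimizer updates they produce. The sketch below does this arithmetic for the 434 training samples shown earlier; it assumes a single GPU, and the step count reported by the trainer may differ slightly depending on how the last partial accumulation of each epoch is handled.

```python
import math

num_samples = 434        # len(train_dataset), from the output above
per_device_batch = 4     # per_device_train_batch_size
grad_accum = 4           # gradient_accumulation_steps
num_epochs = 12          # num_train_epochs
warmup_steps = 50
num_devices = 1          # ASSUMPTION: training on a single GPU

effective_batch = per_device_batch * grad_accum * num_devices
updates_per_epoch = math.ceil(num_samples / effective_batch)
total_updates = updates_per_epoch * num_epochs

print(f"effective batch size: {effective_batch}")
print(f"approx. optimizer updates per epoch: {updates_per_epoch}")
print(f"approx. total optimizer updates: {total_updates}")
print(f"warmup share of training: {warmup_steps / total_updates:.1%}")
```

With `save_steps=25` and `save_total_limit=1`, a checkpoint is written every 25 updates and only the most recent one is kept.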

### 4.3 Training the Model 🏃

To ensure that the data is correctly structured for the model during training, we need to define a collator function. This function will handle the formatting and batching of our dataset inputs, ensuring the data is properly aligned for training.

👉 For more details, check out the official [TRL example scripts](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_smol_vlm.py).

In [None]:
image_token_id = processor.tokenizer.additional_special_tokens_ids[
            processor.tokenizer.additional_special_tokens.index("<image>")]

def collate_fn(examples):
    texts = [processor.apply_chat_template(example, tokenize=False) for example in examples]

    image_inputs = []
    for example in examples:
        image = example[1]['content'][0]['image']
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image_inputs.append([image])

    batch = processor(text=texts, images=image_inputs, return_tensors="pt", padding=True)
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100  # Mask padding tokens in labels
    labels[labels == image_token_id] = -100  # Mask image token IDs in labels

    batch["labels"] = labels

    return batch
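
The label masking in `collate_fn` can be illustrated without any framework. The sketch below uses plain Python lists and made-up token ids (`PAD_ID` and `IMAGE_ID` are hypothetical values, not the real ids from the processor) to show which positions end up excluded from the loss.

```python
IGNORE_INDEX = -100  # targets with this value are skipped by PyTorch's cross-entropy loss


def mask_labels(input_ids, ignored_token_ids):
    """Copy input ids into labels, replacing ignored tokens with IGNORE_INDEX."""
    return [
        [IGNORE_INDEX if token in ignored_token_ids else token for token in row]
        for row in input_ids
    ]


# HYPOTHETICAL token ids, for illustration only.
PAD_ID, IMAGE_ID = 0, 99

batch_input_ids = [
    [99, 99, 11, 12, 13],  # image tokens followed by text
    [99, 99, 21, 0, 0],    # a shorter sequence, padded to the same length
]

labels = mask_labels(batch_input_ids, {PAD_ID, IMAGE_ID})
for row in labels:
    print(row)
```

Note that, as in `collate_fn`, only padding and image tokens are masked. The system and user text stay in the labels, so the loss covers the whole conversation rather than only the assistant's answer.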

Now, we will define the [SFTTrainer](https://huggingface.co/docs/trl/sft_trainer), which is a wrapper around the [transformers.Trainer](https://huggingface.co/docs/transformers/main_classes/trainer) class and inherits its attributes and methods. This class simplifies the fine-tuning process by properly initializing the [PeftModel](https://huggingface.co/docs/peft/v0.6.0/package_reference/peft_model) when a [PeftConfig](https://huggingface.co/docs/peft/v0.6.0/en/package_reference/config#peft.PeftConfig) object is provided. By using `SFTTrainer`, we can efficiently manage the training workflow and ensure a smooth fine-tuning experience for our Vision Language Model.



In [None]:
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=collate_fn,
    peft_config=peft_config,
    tokenizer=processor.tokenizer,
)

Time to Train the Model! 🎉

In [None]:
trainer.train()

Let's save the results 💾

In [None]:
trainer.save_model(training_args.output_dir)

## 5. Testing the Fine-Tuned Model 🔍

Now that our Vision Language Model (VLM) is fine-tuned, it's time to evaluate its performance! In this section, we'll test the model on examples from the test split of the AlchObj dataset to see how accurately it counts alchemic objects in historic book illustrations. Let's dive into the results and see how well it performs! 🚀

Let's clean up the GPU memory to ensure optimal performance 🧹

In [None]:
clear_memory()

We will reload the base model using the same pipeline as before.

In [None]:
model = Idefics3ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    _attn_implementation="flash_attention_2",
)

processor = AutoProcessor.from_pretrained(model_id)

We will attach the trained adapter to the pretrained model. The adapter holds the weight updates learned during fine-tuning, so the base model gains the new behavior while its original parameters stay unchanged.

In [None]:
adapter_path = "maychnix/smolvlm-instruct-trl-sft-AlchObj"
model.load_adapter(adapter_path)

Let's evaluate the model on an unseen sample.


In [None]:
test_dataset[20][:2]

In [None]:
test_dataset[20][1]['content'][0]['image']

In [None]:
output = generate_text_from_sample(model, processor, test_dataset[20])
output

The model has successfully learned to respond to the queries as specified in the dataset. We've achieved our goal! 🎉✨
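
A single sample gives only a first impression, so a natural next step is to score the whole test split. The following is a minimal sketch of such a scoring function, not part of the original notebook: `score_predictions` is a hypothetical helper, and it assumes that the reference answers in the dataset are valid JSON strings.

```python
import json


def score_predictions(pairs):
    """Score (predicted_text, reference_text) pairs; references must be valid JSON."""
    total = exact = 0
    class_hits, class_total = {}, {}
    for predicted_text, reference_text in pairs:
        reference = json.loads(reference_text)
        try:
            predicted = json.loads(predicted_text)
        except json.JSONDecodeError:
            predicted = {}
        if not isinstance(predicted, dict):
            predicted = {}  # treat malformed answers as entirely wrong
        total += 1
        exact += predicted == reference
        for name, count in reference.items():
            class_total[name] = class_total.get(name, 0) + 1
            class_hits[name] = class_hits.get(name, 0) + (predicted.get(name) == count)
    per_class = {name: class_hits[name] / class_total[name] for name in class_total}
    return {"exact_match": exact / total if total else 0.0, "per_class_accuracy": per_class}


# Toy example with two classes only; real answers would name all 12.
toy_pairs = [
    ('{"furnace": 1, "human": 2}', '{"furnace": 1, "human": 2}'),
    ("not valid json", '{"furnace": 0, "human": 1}'),
]
print(score_predictions(toy_pairs))

# Possible usage with the objects defined in this notebook (not executed here):
# pairs = [(generate_text_from_sample(model, processor, sample),
#           sample[2]["content"][0]["text"]) for sample in test_dataset]
# print(score_predictions(pairs))
```

Exact match is a strict metric because a single wrong count fails the whole sample; the per-class accuracy shows which object types the model counts reliably.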