# Finetuning Granite Vision 3.1 2B with TRL

In this example, we will finetune IBM's [`Granite Vision 3.1 2B`](https://huggingface.co/ibm-granite/granite-vision-3.1-2b-preview) model. It is a lightweight model trained by finetuning a Granite language model on both image and text modalities.

We will finetune and evaluate the [`Granite Vision`](https://huggingface.co/ibm-granite/granite-vision-3.1-2b-preview) model using the [`Geometric Perception`](https://huggingface.co/datasets/euclid-multimodal/Geoperception) dataset, containing tasks that the model was not initially trained for. The **Geometric Perception** dataset provides images of various geometric diagrams, compiled from high-school textbooks, paired with question-answer pairs.

## Setup

In [None]:
!pip install -q git+https://github.com/huggingface/transformers.git
!pip install  -U -q trl datasets bitsandbytes peft accelerate
# Tested with transformers==4.49.0.dev0, trl==0.14.0, datasets==3.2.0, bitsandbytes==0.45.2, peft==0.14.0, accelerate==1.3.0

In [None]:
!pip install -q flash-attn --no-build-isolation

# Probe for the optional FlashAttention package and record the outcome in a flag
# that the model-loading cells read later.
try:
    import flash_attn
except ImportError:
    USE_FLASH_ATTENTION = False
    print("FlashAttention is not installed")
else:
    USE_FLASH_ATTENTION = True
    print("FlashAttention is installed")

## Load dataset

We will load the [`GeometricPerception`](https://huggingface.co/datasets/euclid-multimodal/Geoperception) dataset, which provides images of various geometric diagrams, compiled from popular high-school textbooks, paired with question-answer pairs.

We will use the original system prompt used during the model training.

In [None]:
# System prompt placed in the 'system' turn of every formatted conversation.
system_message = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."

For demo purposes, we will only train and evaluate on the Line Length Comparison task.

In [None]:
from datasets import load_dataset

# Hub identifier of the Geoperception dataset.
dataset_id = 'euclid-multimodal/Geoperception'
dataset = load_dataset(dataset_id)

# Keep only the rows whose predicate is the Line Length Comparison task.
dataset_LineComparison = dataset['train'].filter(
    lambda x: x['predicate'] == 'LineComparison'
)
# Split the filtered rows 50/50 into train and test sets; the fixed seed keeps
# the split reproducible across runs.
train_test = dataset_LineComparison.train_test_split(test_size=0.5, seed=111)

In [None]:
# Show the resulting train/test splits.
train_test

We will format the dataset into a chatbot structure, with the system message, image, user query, and answer for each interaction.

In [None]:
def format_data(sample):
    """Turn one dataset row into a three-turn chat conversation.

    Args:
        sample: Mapping that provides the 'image', 'question' and 'answer' entries.

    Returns:
        A list of three message dicts (system, user, assistant). The system
        turn carries the module-level ``system_message``; the user turn holds
        the image followed by the question; the assistant turn holds the answer.
    """
    system_turn = {
        'role': 'system',
        'content': [{'type': 'text', 'text': system_message}],
    }
    user_turn = {
        'role': 'user',
        'content': [
            {'type': 'image', 'image': sample['image']},
            {'type': 'text', 'text': sample['question']},
        ],
    }
    assistant_turn = {
        'role': 'assistant',
        'content': [{'type': 'text', 'text': sample['answer']}],
    }
    return [system_turn, user_turn, assistant_turn]

In [None]:
# Convert every row of both splits into the chat structure built by format_data.
train_dataset = [format_data(x) for x in train_test['train']]
test_dataset = [format_data(x) for x in train_test['test']]

In [None]:
# Inspect the first formatted training conversation.
train_dataset[0]

## Load model and check performance

In [None]:
from transformers import AutoProcessor, AutoModelForVision2Seq
import torch

# Hub identifier of the base checkpoint.
model_id = 'ibm-granite/granite-vision-3.1-2b-preview'
processor = AutoProcessor.from_pretrained(model_id)
# Load the model in bfloat16 with automatic device placement; FlashAttention 2
# is requested only when USE_FLASH_ATTENTION is true.
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    device_map='auto',
    torch_dtype=torch.bfloat16,
    _attn_implementation='flash_attention_2' if USE_FLASH_ATTENTION else None
)

In [None]:
# Pick the first formatted test conversation and display it.
test_idx = 0
sample = test_dataset[test_idx]
sample

In [None]:
# Display the image stored in the sample's user turn.
sample[1]['content'][0]['image']

We will create a method that takes the model, processor, and sample as inputs to generate the model's answer.

In [None]:
def generate_text_from_sample(model, processor, sample, max_new_tokens=100, device='cuda'):
    """Generate the model's answer for one formatted conversation.

    Args:
        model: Object exposing ``generate(**inputs, max_new_tokens=...)``.
        processor: Object providing ``apply_chat_template``, the call that
            builds model inputs from text and images, and ``batch_decode``.
        sample: Conversation as a list of turns; the first two turns form the
            prompt and the image is read from the first content item of the
            second turn.
        max_new_tokens: Upper bound on generated tokens, forwarded to ``generate``.
        device: Device the processed inputs are moved to.

    Returns:
        The decoded text of the first (only) generated sequence, with the
        prompt tokens stripped.
    """
    # Only the first two turns go into the prompt; the assistant answer is left out.
    prompt_turns = sample[:2]
    prompt = processor.apply_chat_template(prompt_turns, add_generation_prompt=True)

    # The processor is fed RGB images, so convert when the mode differs.
    picture = sample[1]['content'][0]['image']
    pictures = [picture if picture.mode == 'RGB' else picture.convert('RGB')]

    encoded = processor(text=prompt, images=pictures, return_tensors='pt').to(device)

    sequences = model.generate(**encoded, max_new_tokens=max_new_tokens)

    # Drop the leading prompt tokens so that only the newly generated part is decoded.
    completions = []
    for prompt_ids, sequence in zip(encoded.input_ids, sequences):
        completions.append(sequence[len(prompt_ids):])

    decoded = processor.batch_decode(
        completions,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return decoded[0]

In [None]:
# Generate an answer for the selected sample with the loaded model and display it.
output = generate_text_from_sample(model, processor, sample)
output

### Remove model and clean GPU

In [None]:
import gc
import time


def clear_memory():
    """Drop the notebook's large globals and release cached GPU memory.

    Removes the globals named ``inputs``, ``model``, ``processor``, ``trainer``,
    ``peft_model`` and ``bnb_config`` when they exist, runs garbage collection
    and empties the CUDA cache (with two-second pauses in between), then prints
    the allocated and reserved GPU memory in GB.
    """
    namespace = globals()
    for name in ("inputs", "model", "processor", "trainer", "peft_model", "bnb_config"):
        # pop with a default is a no-op for names that were never defined
        namespace.pop(name, None)
    time.sleep(2)

    # First collection pass, then hand cached blocks back to the CUDA driver.
    gc.collect()
    time.sleep(2)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    time.sleep(2)
    # Second collection pass after the cache has been emptied.
    gc.collect()
    time.sleep(2)

    print(f"GPU allocated memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"GPU reserved memory: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")


# Release the objects loaded so far and report GPU memory usage.
clear_memory()

## Finetune the model using TRL

### Load the quantized model for training

In [None]:
from transformers import BitsAndBytesConfig

# Switches: load the base model through a 4-bit quantization config (USE_QLORA)
# and train LoRA adapters (USE_LORA).
USE_QLORA = True
USE_LORA = True

if USE_QLORA:
    # 4-bit NF4 quantization with double quantization; compute dtype is bfloat16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type='nf4',
        bnb_4bit_compute_dtype=torch.bfloat16,
        llm_int8_skip_modules=['vision_tower', 'lm_head'], # skip problematic modules
        llm_int8_enable_fp32_cpu_offload=True
    )
else:
    # No quantization config: the model is loaded with the plain bfloat16 settings below.
    bnb_config = None


# Reload processor and model (both are removed from globals by clear_memory).
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    device_map='auto',
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
    _attn_implementation='flash_attention_2' if USE_FLASH_ATTENTION else None
)

### Set up Q-LoRA and SFTConfig

QLoRA allows efficient fine-tuning of large models by reducing the memory footprint. Like standard LoRA, it trains small low-rank adapter matrices while the base model stays frozen; in addition, QLoRA stores the frozen base model weights in 4-bit precision, leading to even lower memory usage and making it possible to fine-tune on smaller GPUs.

To boost efficiency, we can also leverage a paged optimizer or 8-bit optimizer during QLoRA implementation. This approach enhances memory efficiency and speeds up computations, making it ideal for optimizing our model without sacrificing performance.

In [None]:
if USE_LORA:
    from peft import LoraConfig, get_peft_model

    # LoRA settings: rank 8, alpha 8, 10% dropout, DoRA enabled, Gaussian init.
    # Targets every module under 'language_model' whose name contains '_proj'.
    peft_config = LoraConfig(
        r=8,
        lora_alpha=8,
        lora_dropout=0.1,
        target_modules=[name for name, _ in model.named_modules() if 'language_model' in name and '_proj' in name],
        use_dora=True,
        init_lora_weights='gaussian'
    )

    # apply peft model
    # NOTE(review): the adapter is attached with add_adapter/enable_adapters and
    # the model is then also wrapped by get_peft_model with the same config —
    # confirm that both steps are intended.
    model.add_adapter(peft_config)
    model.enable_adapters()
    model = get_peft_model(model, peft_config)

    # Report how many parameters will actually be trained.
    model.print_trainable_parameters()
else:
    # No adapters: peft_config stays None for the trainer.
    peft_config = None

In [None]:
from trl import SFTConfig

# Training hyperparameters for the supervised fine-tuning run.
training_args = SFTConfig(
    output_dir='./checkpoints/geoperception',
    num_train_epochs=1,
    # effective batch size per device = 8 * 2 accumulation steps
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    warmup_steps=10,
    learning_rate=1e-4,
    weight_decay=0.01,
    logging_steps=10,
    # checkpoint every 20 steps, keeping only the most recent one
    save_strategy='steps',
    save_steps=20,
    save_total_limit=1,
    optim='adamw_torch_fused',
    bf16=True,
    push_to_hub=False,
    report_to='none',
    # keep all columns and skip dataset preparation: the conversations are
    # turned into model inputs by the custom data collator instead
    remove_unused_columns=False,
    gradient_checkpointing=True,
    dataset_text_field='',
    dataset_kwargs={'skip_prepare_dataset': True}
)

### Train the model

We need a collator function to ensure that the data is correctly structured for the model during training. This function will handle the formatting and batching of our dataset inputs, ensuring the data is properly aligned for training.

In [None]:
def collate_fn(examples):
    """Build one padded training batch, with labels, from formatted conversations.

    Args:
        examples: List of conversations; each is a list of turns whose second
            turn holds the image as its first content item.

    Returns:
        The processor output for the batch with an added 'labels' entry: a copy
        of 'input_ids' where every position outside an assistant answer is set
        to -100.
    """
    # Render each conversation to a prompt string (no tokenization yet).
    texts = [
        processor.apply_chat_template(example, tokenize=False)
        for example in examples
    ]

    # Collect one image per example, converted to RGB when needed.
    image_inputs = []
    for example in examples:
        image = example[1]['content'][0]['image']
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image_inputs.append(image)

    # Tokenize the texts and preprocess the images together, padding the batch.
    batch = processor(
        text=texts,
        images=image_inputs,
        padding=True,
        return_tensors='pt'
    )

    labels = batch['input_ids'].clone()
    # Token ids of the marker that opens an assistant turn and of the end-of-text token.
    assistant_token = processor.tokenizer(
        '<|assistant|>',
        return_tensors='pt'
    )['input_ids'][0]
    # NOTE(review): the comparison against eos_token below assumes the tokenizer
    # returns a single id for this string — TODO confirm.
    eos_token = processor.tokenizer(
        '<|end_of_text|>',
        return_tensors='pt'
    )['input_ids'][0]

    # Walk each sequence once. Positions are masked with -100 (presumably the
    # loss's ignore index) until an assistant marker has been seen. The mask is
    # written before the flag is updated, so the last token of the marker itself
    # is masked, while the end-of-text token that closes the answer is kept.
    for i in range(batch['input_ids'].shape[0]):
        apply_loss = False
        for j in range(batch['input_ids'].shape[1]):
            if not apply_loss:
                labels[i][j] = -100
            # NOTE(review): this bound is stricter than the slice requires
            # (j >= len(assistant_token) - 1 would suffice); a marker ending
            # within the first positions is not detected — presumably it never
            # occurs there; verify.
            if (j >= len(assistant_token) + 1) and torch.all(
                batch['input_ids'][i][j + 1 - len(assistant_token) : j + 1] == assistant_token
            ):
                apply_loss = True
            if batch['input_ids'][i][j] == eos_token:
                apply_loss = False

    batch['labels'] = labels
    return batch

In [None]:
from trl import SFTTrainer

# Build the SFT trainer from the model, training arguments, formatted training
# conversations and the custom collator.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=collate_fn,
    peft_config=peft_config,
    # `processing_class` is the current name of this argument; the older
    # `tokenizer` keyword is deprecated in TRL.
    processing_class=processor.tokenizer
)

In [None]:
# Run the fine-tuning loop.
trainer.train()

In [None]:
# Save the trained model to the output directory set in the training arguments.
trainer.save_model(training_args.output_dir)

## Test the finetuned model

In [None]:
# Release the training objects and report GPU memory usage before reloading.
clear_memory()

In [None]:
# The processor comes from the base checkpoint; the model is loaded from the
# directory the trainer saved to.
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(
    training_args.output_dir,
    device_map='auto',
    torch_dtype=torch.bfloat16,
    _attn_implementation='flash_attention_2' if USE_FLASH_ATTENTION else None
)

if USE_LORA:
    from peft import PeftModel
    # Wrap the loaded model with the adapter stored in the same directory.
    # NOTE(review): the model above is already loaded from output_dir — confirm
    # the adapter is not applied twice.
    model = PeftModel.from_pretrained(model, training_args.output_dir)

In [None]:
# Pick the first formatted test conversation again and display it.
test_idx = 0
sample = test_dataset[test_idx]
sample

In [None]:
# Display the image stored in the sample's user turn.
sample[1]['content'][0]['image']

In [None]:
# Generate an answer for the same sample with the reloaded model and display it.
output = generate_text_from_sample(model, processor, sample)
output