*More details in this article: [Fine-tune a Multimodal Chat Model with Florence-2 on Your Computer](https://newsletter.kaitchup.com/p/fine-tune-a-multimodal-chat-model)*


This notebook shows how to fine-tune Florence-2 into a multimodal chat model. Using the ScienceQA dataset, the model learns to answer science questions using images.

It requires a 16 GB GPU.

*Note: This notebook reuses pieces of code proposed by Hugging Face in this blog post: [Fine-tuning Florence-2 - Microsoft's Cutting-edge Vision Language Models](https://huggingface.co/blog/finetune-florence2)*

First, install the following:

In [None]:
!pip install -q datasets flash_attn timm einops bitsandbytes accelerate

In [None]:
import torch
from PIL import Image
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoProcessor, get_scheduler
from tqdm import tqdm
import numpy as np
from bitsandbytes.optim import AdamW

Load the dataset, the model, and its processor.
Gradient checkpointing is enabled and the parameters of the image encoder are frozen.

In [None]:
data = load_dataset("derek-thomas/ScienceQA")

model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True).to("cuda")
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':True})
# Freeze the image encoder: PyTorch skips a parameter only when requires_grad is False.
for param in model.vision_tower.parameters():
    param.requires_grad = False

Generating train split:   0%|          | 0/12726 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/4241 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/4241 [00:00<?, ? examples/s]

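In PyTorch, a parameter is excluded from training only when its `requires_grad` flag is `False`. As a quick sanity check, the sketch below counts trainable and total parameters. `count_parameters` is a hypothetical helper, not part of the original notebook; it assumes an object that exposes `parameters()` the way `torch.nn.Module` does.

```python
def count_parameters(module):
    """Return (trainable, total) parameter counts for a torch.nn.Module-like object."""
    trainable = 0
    total = 0
    for param in module.parameters():
        n = param.numel()
        total += n
        # Only parameters with requires_grad=True receive gradients.
        if param.requires_grad:
            trainable += n
    return trainable, total
```

Calling `count_parameters(model.vision_tower)` after the freezing loop should report zero trainable parameters if the image encoder is really frozen.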

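A custom Dataset class formats each prompt and prepares the image. The prompt is the question prefixed with the `<MMCHAT>` tag, the target is the text of the correct choice, and examples without an image receive a small random placeholder image.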

In [None]:
class MMInstructDataset(Dataset):

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

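    def __getitem__(self, idx):
        example = self.data[idx]
        # Prompt: <MMCHAT> tag + question; the answer choices are not part of the prompt.
        question = "<MMCHAT>" + example['question'] + "\n\nAnswer:\n"
        # Target: text of the correct choice ('answer' holds its index).
        answer = example['choices'][int(example['answer'])]
        if example['image'] is not None:
            image = example['image'].convert("RGB")
        else:
            # Text-only example: substitute a 50x50 random placeholder image.
            zz = np.random.rand(50, 50)
            image = Image.fromarray(zz).convert("RGB")
        return question, answer, image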
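
To make the text side of the formatting concrete, here is a minimal sketch that mirrors the prompt and target construction of the Dataset class on a single record. `format_example` and the sample record are hypothetical, made up for illustration; only the field names (`question`, `choices`, `answer`) come from the code above.

```python
def format_example(example):
    """Build the (prompt, target) text pair the same way the Dataset class does."""
    prompt = "<MMCHAT>" + example["question"] + "\n\nAnswer:\n"
    # 'answer' is the index of the correct entry in 'choices'.
    target = example["choices"][int(example["answer"])]
    return prompt, target


# Hypothetical record with the same fields as a ScienceQA example.
sample = {
    "question": "Which of these is a solid?",
    "choices": ["water", "a brick", "steam"],
    "answer": 1,
}

prompt, target = format_example(sample)
print(repr(prompt))  # '<MMCHAT>Which of these is a solid?\n\nAnswer:\n'
print(target)        # a brick
```

Note that the model never sees the list of choices: it has to generate the answer text from the question and the image alone.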

In [None]:
def collate_fn(batch):
    questions, answers, images = zip(*batch)
    inputs = processor(text=list(questions), images=list(images), return_tensors="pt", padding=True).to("cuda")
    return inputs, answers

In [None]:
train_dataset = MMInstructDataset(data['train'])
val_dataset = MMInstructDataset(data['validation'])

batch_size = 1
gradient_accumulation_steps = 8
num_workers = 0

train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          collate_fn=collate_fn, num_workers=num_workers, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size,
                          collate_fn=collate_fn, num_workers=num_workers)
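
With `batch_size = 1` and `gradient_accumulation_steps = 8`, gradients from 8 examples are summed before each weight update. The sketch below works out what that means for the 12,726 training examples; `accumulation_plan` is a hypothetical helper, and it assumes one update per full group of batches, so leftover batches at the end of an epoch do not trigger their own update.

```python
def accumulation_plan(num_examples, batch_size, accumulation_steps, epochs):
    """Summarize how many optimizer updates gradient accumulation produces."""
    # Ceiling division: the last, smaller batch is kept (DataLoader default).
    batches_per_epoch = -(-num_examples // batch_size)
    # Only full groups of accumulation_steps batches trigger an update.
    updates_per_epoch = batches_per_epoch // accumulation_steps
    return {
        "effective_batch_size": batch_size * accumulation_steps,
        "batches_per_epoch": batches_per_epoch,
        "updates_per_epoch": updates_per_epoch,
        "total_updates": updates_per_epoch * epochs,
    }


plan = accumulation_plan(num_examples=12726, batch_size=1,
                         accumulation_steps=8, epochs=2)
print(plan)
```

Under these assumptions the effective batch size is 8 and there are 1,590 updates per epoch. This is also the number of times per epoch that a scheduler tied to the optimizer gets stepped.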

Training loop with gradient accumulation steps:

In [None]:

# 8-bit paged AdamW from bitsandbytes reduces the memory taken by optimizer states.
optimizer = AdamW(model.parameters(), lr=1e-6, optim_bits=8, is_paged=True)

epochs = 2
# The scheduler advances once per optimizer update, not once per batch.
num_training_steps = (epochs * len(train_loader)) // gradient_accumulation_steps

lr_scheduler = get_scheduler(name="linear", optimizer=optimizer,
                             num_warmup_steps=0, num_training_steps=num_training_steps)

for epoch in range(epochs):
    model.train()
    train_loss = 0
    for i, (inputs, answers) in enumerate(tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{epochs}")):
        input_ids = inputs["input_ids"]
        pixel_values = inputs["pixel_values"]
        labels = processor.tokenizer(text=answers, return_tensors="pt", padding=True, return_token_type_ids=False).input_ids.to("cuda")
        outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=labels)
        # Scale the loss so the accumulated gradient is an average over the group.
        loss = outputs.loss / gradient_accumulation_steps
        loss.backward()
        # Update the weights once every gradient_accumulation_steps batches.
        if (i + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            # Only the scaled loss of update steps is logged, so this figure
            # is not on the same scale as the validation loss.
            train_loss += loss.item()
    avg_train_loss = train_loss / len(train_loader)
    print(f"Average Training Loss: {avg_train_loss}")

    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Validation Epoch {epoch + 1}/{epochs}"):
            inputs, answers = batch
            input_ids = inputs["input_ids"]
            pixel_values = inputs["pixel_values"]
            labels = processor.tokenizer(text=answers, return_tensors="pt", padding=True, return_token_type_ids=False).input_ids.to("cuda")
            outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=labels)
            loss = outputs.loss
            val_loss += loss.item()

        print(val_loss / len(val_loader))


Training Epoch 1/2: 100%|██████████| 12726/12726 [3:04:26<00:00,  1.15it/s]


Average Training Loss: 0.03429693092450855


Validation Epoch 1/2: 100%|██████████| 4241/4241 [20:34<00:00,  3.44it/s]


1.8777813950052409


Training Epoch 2/2: 100%|██████████| 12726/12726 [3:04:24<00:00,  1.15it/s]


Average Training Loss: 0.02164359368673569


Validation Epoch 2/2: 100%|██████████| 4241/4241 [20:34<00:00,  3.44it/s]

1.5683350493919312
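
The training and validation figures above are not on the same scale. In the loop, the value added to `train_loss` is the loss already divided by `gradient_accumulation_steps`, it is added only on update steps (about one batch in eight), and the sum is then divided by the total number of batches. The sketch below is a rough rescaling under those assumptions; `rescale_logged_train_loss` is a hypothetical helper, and the result is only an estimate because it is based on a sample of the batches.

```python
gradient_accumulation_steps = 8


def rescale_logged_train_loss(logged, accumulation_steps):
    """Approximate the mean per-batch loss from the logged epoch average.

    Assumes the logged value sums (loss / accumulation_steps) on roughly one
    batch in every accumulation_steps and divides by the number of batches,
    which shrinks the figure by about accumulation_steps ** 2.
    """
    return logged * accumulation_steps ** 2


# Logged averages for epochs 1 and 2, copied from the output above.
logged_losses = [0.03429693092450855, 0.02164359368673569]
estimates = [rescale_logged_train_loss(x, gradient_accumulation_steps)
             for x in logged_losses]
for epoch, value in enumerate(estimates, start=1):
    print(f"Epoch {epoch}: approx. mean training loss {value:.2f}")
```

The rescaled values come out at roughly 2.2 and 1.4, the same order of magnitude as the validation losses of about 1.88 and 1.57.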





In [None]:
model.save_pretrained("./mmchat/")
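
Once the checkpoint is saved, it can be reloaded to answer new questions. The following is a minimal sketch and was not run as part of this notebook. `answer_question` is a hypothetical helper; it assumes the checkpoint in `./mmchat/` loads with `trust_remote_code=True`, reuses the original `microsoft/Florence-2-large` processor (the notebook does not save one), and builds the prompt exactly as during training. The `max_new_tokens` default is an arbitrary choice.

```python
def answer_question(question, image, model_dir="./mmchat/",
                    processor_id="microsoft/Florence-2-large",
                    max_new_tokens=64):
    """Generate an answer for one question and one PIL image (sketch)."""
    # Imported lazily so that defining the function needs no GPU stack.
    import torch
    from transformers import AutoModelForCausalLM, AutoProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForCausalLM.from_pretrained(
        model_dir, trust_remote_code=True).to(device).eval()
    # Assumption: the processor is unchanged, so the original one is reloaded.
    processor = AutoProcessor.from_pretrained(
        processor_id, trust_remote_code=True)

    # Same prompt format as in MMInstructDataset.
    prompt = "<MMCHAT>" + question + "\n\nAnswer:\n"
    inputs = processor(text=[prompt], images=[image.convert("RGB")],
                       return_tensors="pt").to(device)
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=max_new_tokens,
        )
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
```

A call would look like `answer_question("Which animal is shown?", Image.open("photo.jpg"))`, where the question and the file name are placeholders.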