# Fine-tune BLIP using Hugging Face `transformers`, `datasets`, `peft` 🤗 and `bitsandbytes`

Let's leverage recent advances from Parameter Efficient Fine-Tuning methods to fine-tune a large image to text model! We will show through this tutorial that it is possible to fine-tune a 3B scale model (~6GB in half-precision)

Here we will use a dummy dataset of [football players](https://huggingface.co/datasets/ybelkada/football-dataset) ⚽ that is uploaded on the Hub. The images have been manually selected together with the captions. 
Check the 🤗 [documentation](https://huggingface.co/docs/datasets/image_dataset) on how to create and upload your own image-text dataset.

## Set-up environment

## Load the image captioning dataset

Let's load the image captioning dataset, you just need few lines of code for that.

In [20]:
from datasets import load_dataset 
from PIL import Image
import pandas as pd
import numpy as np
import torch
import random
from transformers import AutoProcessor, Blip2ForConditionalGeneration
from time import time
import re

seed=777
random.seed(seed)


def convert_floats_to_two_decimals(text):
    # Expressão regular para capturar números float
    float_pattern = r'(\d+\.\d+)'  # Captura números decimais
    # Substituir os números float por floats com 2 casas decimais
    modified_text = re.sub(float_pattern, lambda x: f"{float(x.group(1)):.2f}", text)
    return modified_text

# random.seed(seed)
# np.random.seed(seed)
# torch.manual_seed(seed)
# torch.cuda.manual_seed_all(seed)

# # Garante que algumas operações sejam determinísticas
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False

#dataset = load_dataset("ybelkada/football-dataset", split="train")
#embds_path = './nodules_luna_lidc_3d_metadas_embeddings'
embds_path = './nodules_luna_lidc_3d_metadas_embeddings_isomorphic'
df = pd.read_csv('annotations_isomorphic_with_legends.csv') 

new_static_legends = []

for idx, row in df.iterrows():
    legend = str(row.static_legends)
    if 'Highly Unlikely' in legend:
        legend = legend.replace('Highly Unlikely', 'Remote')
    elif 'Moderately Unlikely' in legend:
        legend = legend.replace('Moderately Unlikely', 'Possible')
    elif 'Moderately Suspicious' in legend:
        legend = legend.replace('Moderately Suspicious', 'Doubtful')
    elif 'Highly Suspicious' in legend:
        legend = legend.replace('Highly Suspicious', 'Critical')
        #print("aqui")
#     print(legend)
    legend = convert_floats_to_two_decimals(legend)
    new_static_legends.append(legend)

df['static_legends'] = new_static_legends

print(df.shape)
#df = df.query("malignancy != 'Indeterminate'").copy()
print(df.shape)

# Pegue 80% dos dados como a primeira amostra
df_train = df.sample(frac=0.8, random_state=seed)

# Pegue os 20% restantes como a segunda amostra
df_test = df.drop(df_train.index)
#df_test.to_csv('annotations_test_with_legends_seed_777.csv')
print(df_train.shape)
print(df_test.shape)
is_training = False
is_evaluating = True

(1182, 15)
(1182, 15)
(946, 15)
(236, 15)


Let's retrieve the caption of the first example:

## Create PyTorch Dataset

Let's define below the dataset as well as the data collator!

In [21]:
from torch.utils.data import Dataset, DataLoader


class ImageCaptioningDataset(Dataset):
    def __init__(self, df, embds_path, processor):
        self.dataset = df
        self.processor = processor
        self.embds_path = embds_path

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset.iloc[idx]
        #print(f'{self.embds_path}/{item["filename"]}')
        image = np.load(f'{self.embds_path}/{item["filename"]}')
        #print(image_rgb.min(), image_rgb.max())
        #encoding = self.processor(images=image, padding="max_length", return_tensors="pt")
        # remove batch dimension
        #encoding = {k: v.squeeze() for k, v in encoding.items()}
        #encoding["text"] = item["text"]
        #text_inputs = self.processor.tokenizer(item['report_text'] ,return_tensors="pt")
        return {'embds': torch.from_numpy(image).squeeze(0), 'text': item['static_legends']}

def collate_fn(batch):
    # pad the input_ids and attention_mask
    processed_batch = {}
    for key in batch[0].keys():
        if key != "text":
            processed_batch[key] = torch.stack([example[key] for example in batch])
        else:
            text_inputs = processor.tokenizer([example["text"] for example in batch], padding=True, return_tensors="pt")
            processed_batch["input_ids"] = text_inputs["input_ids"]
            processed_batch["attention_mask"] = text_inputs["attention_mask"]
            # processed_batch["original_text"] = [example["text"] for example in batch]
            
    return processed_batch


In [22]:
class VisionModelOutput:
    def __init__(self, last_hidden_state):
        self.last_hidden_state = last_hidden_state

# Adaptador para ajustar a saída do ViT3D para o Q-Former
class ViT3DAdapter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.training = True

    def forward(self, pixel_values=None, **kwargs):
        if self.training:
            return pixel_values.unsqueeze(0)
        else:
            return VisionModelOutput(last_hidden_state=pixel_values)

In [23]:
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")

In [24]:

if is_training:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    
    #model = Blip2ForConditionalGeneration.from_pretrained("ybelkada/blip2-opt-2.7b-fp16-sharded", device_map="auto", load_in_8bit=True)
    #model = Blip2ForConditionalGeneration.from_pretrained("ybelkada/blip2-opt-2.7b-fp16-sharded", 
    # _map="auto", load_in_8bit=True)
    
    model = Blip2ForConditionalGeneration.from_pretrained(
        "ybelkada/blip2-opt-2.7b-fp16-sharded",
        device_map="auto",
        torch_dtype=torch.float32
    )
    
    dtype = next(model.parameters()).dtype
    print(dtype)
    
    adapted_vit3d = ViT3DAdapter().to(device)
    model.vision_model = adapted_vit3d

In [25]:
# Checar se o modelo está congelado
def check_frozen_layers(model):
    frozen_layers = []
    trainable_layers = []

    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable_layers.append(name)
        else:
            frozen_layers.append(name)
    
    print("Camadas congeladas:")
    for layer in frozen_layers:
        print(layer)
    
    print("\nCamadas treináveis:")
    for layer in trainable_layers:
        print(layer)

#check_frozen_layers(model)

Next we define our `LoraConfig` object. We explicitly tell 

In [26]:
if is_training:
    from peft import LoraConfig, get_peft_model
    import torch
    
    
    # Let's define the LoraConfig
    config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        target_modules=["q_proj", "k_proj"]
    )
    
    model = get_peft_model(model, config)
    model.print_trainable_parameters()
    
    # Congelar todas as camadas do modelo, exceto o Q-Former (apenas q_proj e k_proj)
    # for name, param in model.named_parameters():
    #     # Mantenha apenas as camadas do Q-Former como treináveis e cheque o tipo de dado
    #     if "qformer" in name:
    #         #if ('attention' in name or 'crossattention' in name) and torch.is_floating_point(param):
    #         param.requires_grad = True
    #     else:
    #         param.requires_grad = False
    
    # Congelar todas as camadas do modelo, exceto o Q-Former
    # for name, param in model.named_parameters():
    #     # Checar se o parâmetro faz parte do Q-Former e se é de ponto flutuante
    #     if "qformer" in name and torch.is_floating_point(param):
    #         param.requires_grad = True
    #     else:
    #         param.requires_grad = False
    
    
    # Configurar todos os parâmetros como treináveis, verificando se são de ponto flutuante
    # for param in model.parameters():
    #     if torch.is_floating_point(param):
    #         param.requires_grad = True
    
    # Exibir parâmetros treináveis para verificação
    print("Parâmetros treináveis no Q-Former:")
    for name, param in model.named_parameters():
        if param.requires_grad:
            # print(name)
            pass
    #model.print_trainable_parameters()


Now that we have loaded the processor, let's load the dataset and the dataloader:

In [27]:
train_dataset = ImageCaptioningDataset(df_train, embds_path, processor)
test_dataset = ImageCaptioningDataset(df_test, embds_path, processor)
# train_dataset_ = train_dataset_[:100]
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=4, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, shuffle=True, batch_size=4, collate_fn=collate_fn)
#train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=4)

In [28]:
for i in range(10):
    print(f"Sample {i+1}")    
    d = train_dataset[i]
    print("LIDC-LUNA Image embed: ", d['embds'])
    print("Text Report: ", d['text'])
    print("-"*150)

Sample 1
LIDC-LUNA Image embed:  tensor([[ 1.1043,  1.2375, -0.9351,  ..., -0.2086, -0.6986, -0.0975],
        [ 1.1425,  1.2203, -0.9613,  ..., -0.1956, -0.7696, -0.1241],
        [ 1.1953,  1.2541, -0.9522,  ..., -0.3007, -0.7236, -0.1164],
        ...,
        [ 1.1469,  1.2849, -1.0146,  ..., -0.2580, -0.7015, -0.0449],
        [ 1.1134,  1.2949, -0.9030,  ..., -0.2477, -0.7227, -0.0767],
        [ 1.1456,  1.3453, -0.9390,  ..., -0.2839, -0.7031, -0.0501]])
Text Report:  The lung nodule is characterized as follows: subtlety: Moderately Obvious, internalStructure: Soft Tissue, calcification: Absent, sphericity: Ovoid/Round, margin: Sharp, lobulation: Near Marked Lobulation, spiculation: No Spiculation, texture: Non-Solid/GGO, malignancy: Indeterminate, diameter: 5.70 mm, volume: 49.45 mm3.
------------------------------------------------------------------------------------------------------------------------------------------------------
Sample 2
LIDC-LUNA Image embed:  tensor([[ 1

## Train the model

Let's train the model! Run the simply the cell below for training the model

In [29]:
if is_training:
    import torch
    import os
    
    epochs = 20
    lr = 5e-4
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    # Diretório para salvar o melhor modelo
    checkpoint_dir = "./checkpoints"
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    best_model_path = os.path.join(checkpoint_dir, "best_model.pt")
    best_loss = float('inf')  # Define a melhor perda inicial como infinita
    
    for epoch in range(epochs):
        start_time = time()
        train_loss = 0
        total_train_samples = 0
        total_val_samples = 0
        validation_loss = 0
        model.train()
        model.vision_model.training = True
      
        for idx, batch in enumerate(train_dataloader):
            input_ids = batch.pop("input_ids").to(device)
            embds = batch.pop("embds").to(device, torch.float32)
            attention_mask = batch.pop("attention_mask").to(device)
    
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=embds, labels=input_ids)
            loss = outputs.loss
        
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
    
            train_loss += loss.item()
            total_train_samples += len(input_ids)
        
        model.eval()
        validation_loss = 0
        total_val_samples = 0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for idx, batch in enumerate(test_dataloader):
                input_ids = batch.pop("input_ids").to(device)
                embds = batch.pop("embds").to(device, torch.float32)
                attention_mask = batch.pop("attention_mask").to(device)
                model.vision_model.training = True
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=embds, labels=input_ids)
    
                loss = outputs.loss
                validation_loss += loss.item()
                total_val_samples += len(input_ids)
                model.vision_model.training = False
                
    
        avg_eval_loss = validation_loss / total_val_samples
        print(f"Validation Loss (Epoch {epoch}): {avg_eval_loss}")
    
        if avg_eval_loss < best_loss:
            best_loss = avg_eval_loss
            print(f"Saving best model with loss {best_loss}")
            model.save_pretrained(checkpoint_dir)  
            optimizer_state = {"optimizer": optimizer.state_dict(), "epoch": epoch}
            torch.save(optimizer_state, os.path.join(checkpoint_dir, "optimizer.pt")) 
    
    
        end_time = time()  
        epoch_duration = end_time - start_time  
    
        print(f'Epoch [{epoch+1}/{epochs}], \
                Train Loss: {(train_loss/total_train_samples):.4f}, \
                Val Loss: {(validation_loss/total_val_samples):.4f}, \
                Elapsed Time: {epoch_duration:.2f} sec'
             )

## Inference

Let's check the results on our train dataset

In [30]:
import re
import numpy as np

def get_features(text):
    # Expressão regular para capturar key:value
    pattern = r'(\w+):\s*([^,]+)'
    
    # Expressão regular para extrair números float
    float_pattern = r'\d+\.\d+'
    
    # Encontrar todas as correspondências
    key_value_pairs = re.findall(pattern, text)
    
    # Converter para um dicionário, ajustando diameter e volume para float
    key_value_dict = {}
    for key, value in key_value_pairs:
        if key == 'follows': continue
        value = value.strip()
        if key in ["diameter", "volume"]:
            # Extrair número float usando regex
            float_match = re.search(float_pattern, value)
            if float_match:
                value = float(float_match.group())
        key_value_dict[key] = value

    return key_value_dict

def clip_text(text):

    # Extrair a parte da string até "malignancy"
    pattern = r'(.*malignancy:\s[^,]+)'

    # Busca a parte que termina com "malignancy"
    match = re.search(pattern, text)
    if match:
        text_until_malignancy = match.group(1)
        #print(text_until_malignancy)
        return text_until_malignancy
    else:
        print("Nenhum valor até 'malignancy' foi encontrado.")
        return text


def calculate_diameter(volume):
    return (6 * volume / np.pi) ** (1 / 3)


In [31]:
if is_evaluating:
    import os
    from bert_score import score
    from rouge_score import rouge_scorer

    malig_gt_labels = []
    malig_pred_labels = []
    malig_map = {'Remote': 1, 'Possible': 2, 'Indeterminate': 3, 'Doubtful': 4, 'Critical': 5}
    #malig_map = {'Remote': 1, 'Possible': 2, 'Doubtful': 3, 'Critical': 4}
    gt_diameters = []
    predicted_diameters = []
    predicted_diameters_2 = []

    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    # Diretório para salvar o melhor modelo
    checkpoint_dir = "./checkpoints"
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    model = Blip2ForConditionalGeneration.from_pretrained(
        checkpoint_dir,
        device_map="auto",
        torch_dtype=torch.float32
    )
    
    adapted_vit3d = ViT3DAdapter().to(device)
    model.vision_model = adapted_vit3d
    # Colocar o modelo em modo de avaliação, se necessário
    
    model.eval()
    
    validation_loss = 0
    total_val_samples = 0
    correct = 0
    total = 0
    
    precisions = []
    recalls = []
    f1s = []

    rouge_1 = []
    rouge_2 = []
    rouge_l = []
    
    
    with torch.no_grad():
        for idx, batch in enumerate(test_dataloader):
            input_ids = batch.pop("input_ids").to(device)
            embds = batch.pop("embds").to(device, torch.float32)
            attention_mask = batch.pop("attention_mask").to(device)
            # original_text = batch.pop("original_text")
            model.vision_model.training = True
            # outputs = model(input_ids=input_ids, pixel_values=embds, labels=input_ids)
            outputs = model(input_ids=input_ids, pixel_values=embds, labels=input_ids)
    
            loss = outputs.loss
            validation_loss += loss.item()
            total_val_samples += len(input_ids)
            model.vision_model.training = False
            generated_ids = model.generate(pixel_values=embds, temperature=0.0, do_sample=False, max_length=100)
            generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)
            original_text = processor.batch_decode(input_ids, skip_special_tokens=True)
            
            for original, generated in zip(original_text, generated_caption):

                
                print("Original Text -->", original)
                print("\n")
                print("Generated Text -->", generated)
                print("\n")

                
                original_features = get_features(original)
                predicted_features = get_features(generated)

                print(predicted_features)
                
                original_value = original_features['malignancy']
                predicted_value = predicted_features['malignancy']
                            
                malig_gt_labels.append(malig_map[original_value])
                malig_pred_labels.append(malig_map[predicted_value])


                gt_diameters.append(original_features['diameter'])
                predicted_diameters.append(predicted_features['diameter'])
                predicted_diameters_2.append(calculate_diameter(predicted_features['volume']))
                
                print("Original malignancy -->", original_value)
                print("\n")
                print("Generated Text -->", predicted_value, original_features['diameter'], predicted_features['diameter'], calculate_diameter(predicted_features['volume']))
                print("\n")

                # Inicializa o calculador de ROUGE com as métricas desejadas
                scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
                
                # Calcula as pontuações ROUGE
                scores = scorer.score(clip_text(original), clip_text(generated))
                
                # Exibe as pontuações
                print("ROUGE-1:", scores['rouge1'])
                print("ROUGE-2:", scores['rouge2'])
                print("ROUGE-L:", scores['rougeL'])
            
                print("\n")

                rouge_1.append(scores['rouge1'])
                rouge_2.append(scores['rouge2'])
                rouge_l.append(scores['rougeL'])
                
    # print("Precision mean: ", torch.mean(precisions))
    # print("Recall mean: ", torch.mean(recalls))
    # print("F1 mean: ", torch.mean(f1s))

Loading checkpoint shards: 100%|██████████| 8/8 [00:05<00:00,  1.50it/s]


Original Text --> The lung nodule is characterized as follows: subtlety: Obvious, internalStructure: Soft Tissue, calcification: Absent, sphericity: Ovoid/Round, margin: Near Sharp, lobulation: Nearly No Lobulation, spiculation: Nearly No Spiculation, texture: Solid, malignancy: Doubtful, diameter: 13.64 mm, volume: 1482.34 mm3.


Generated Text --> The lung nodule is characterized as follows: subtlety: Obvious, internalStructure: Soft Tissue, calcification: Absent, sphericity: Ovoid/Round, margin: Near Sharp, lobulation: Nearly No Lobulation, spiculation: Nearly No Spiculation, texture: Solid, malignancy: Critical, diameter: 16.05 mm, volume: 1751.99 mm3.


{'internalStructure': 'Soft Tissue', 'calcification': 'Absent', 'sphericity': 'Ovoid/Round', 'margin': 'Near Sharp', 'lobulation': 'Nearly No Lobulation', 'spiculation': 'Nearly No Spiculation', 'texture': 'Solid', 'malignancy': 'Critical', 'diameter': 16.05, 'volume': 1751.99}
Original malignancy --> Doubtful


Generated Text --> 

In [32]:
import  numpy as np
print(f"Rouge 1 - Mean  {np.mean([f1.fmeasure for f1 in rouge_1])}")
print(f"Rouge 2 - Mean  {np.mean([f1.fmeasure for f1 in rouge_2])}")
print(f"Rouge L - Mean  {np.mean([f1.fmeasure for f1 in rouge_l])}")

# Definir os pesos para cada distância
weights = {0: 1.0, 1: 0.75, 2: 0.5, 3: 0.25, 4:0.125}
#weights = {0: 1.0, 1: 0.75, 2: 0.5, 3: 0.25}

# Calcular a acurácia ponderada
weighted_accuracy = sum([weights.get(abs(true - pred), 0) for true, pred in zip(malig_gt_labels, malig_pred_labels)]) / len(malig_gt_labels)

print("Acurácia Ponderada:", weighted_accuracy)


Rouge 1 - Mean  0.9063923729300132
Rouge 2 - Mean  0.8105952820195067
Rouge L - Mean  0.9055739339040524
Acurácia Ponderada: 0.8066737288135594


In [33]:
from sklearn.metrics import classification_report
import pandas as pd


# Rótulos personalizados para as classes
target_names = ['Remote', 'Possible', 'Indeterminate', 'Doubtful', 'Critical']

# Gera o relatório de classificação com rótulos personalizados
report = classification_report(malig_gt_labels, malig_pred_labels, target_names=target_names, output_dict=True)

# Converte para um DataFrame e renomeia as colunas
report_df = pd.DataFrame(report).transpose()
report_df.columns = ['Precisão', 'Revocação', 'F1-Score', 'Suporte']

print(report_df)

               Precisão  Revocação  F1-Score    Suporte
Remote         0.000000   0.000000  0.000000   34.00000
Possible       0.000000   0.000000  0.000000   40.00000
Indeterminate  0.427746   0.902439  0.580392   82.00000
Doubtful       0.000000   0.000000  0.000000   42.00000
Critical       0.523810   0.868421  0.653465   38.00000
accuracy       0.453390   0.453390  0.453390    0.45339
macro avg      0.190311   0.354172  0.246772  236.00000
weighted avg   0.232966   0.453390  0.306881  236.00000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [34]:
import math

def rmse(y_true, y_pred):
    
    # Passo 1: Calcular o erro quadrado para cada par de valores
    squared_errors = [(y - y_hat) ** 2 for y, y_hat in zip(y_true, y_pred)]
    
    # Passo 2: Calcular a média dos erros quadrados (MSE)
    mse = sum(squared_errors) / len(y_true)
    
    # Passo 3: Tirar a raiz quadrada do MSE para obter o RMSE
    rmse = math.sqrt(mse)
    
    return rmse


In [35]:
# Calcula o RMSE médio
resultado = rmse(gt_diameters, predicted_diameters)
print(f"RMSE DIAMETER1: {resultado} mm")


RMSE DIAMETER1: 2.5817479461843518 mm


In [36]:
# Calcula o RMSE médio
resultado = rmse(gt_diameters, predicted_diameters_2)
print(f"RMSE DIAMETER2: {resultado} mm")

RMSE DIAMETER2: 2.1910452302727306 mm


In [37]:
len(gt_diameters)

236

In [38]:
#print(gt_diameters)
print("\n")
#print(predicted_diameters)

for idx, (gt, pre) in enumerate(zip(gt_diameters, predicted_diameters_2)):
    print(f"{idx+1}:           \tGt: {gt:.2f},           \tPred: {pre:.2f},          \t Error: {abs(gt-pre):.2f} mm")



1:           	Gt: 13.64,           	Pred: 14.96,          	 Error: 1.32 mm
2:           	Gt: 5.46,           	Pred: 5.06,          	 Error: 0.40 mm
3:           	Gt: 8.21,           	Pred: 7.35,          	 Error: 0.86 mm
4:           	Gt: 6.10,           	Pred: 5.04,          	 Error: 1.06 mm
5:           	Gt: 5.51,           	Pred: 5.82,          	 Error: 0.31 mm
6:           	Gt: 17.69,           	Pred: 20.62,          	 Error: 2.93 mm
7:           	Gt: 4.34,           	Pred: 5.06,          	 Error: 0.72 mm
8:           	Gt: 6.44,           	Pred: 5.06,          	 Error: 1.38 mm
9:           	Gt: 17.16,           	Pred: 9.77,          	 Error: 7.39 mm
10:           	Gt: 8.85,           	Pred: 9.63,          	 Error: 0.78 mm
11:           	Gt: 10.91,           	Pred: 14.96,          	 Error: 4.05 mm
12:           	Gt: 7.07,           	Pred: 5.82,          	 Error: 1.25 mm
13:           	Gt: 4.21,           	Pred: 5.19,          	 Error: 0.98 mm
14:           	Gt: 6.97,           	Pr