In [4]:
# Medical Image Captioning with BLIP - Fine-Tuning Notebook

"""
This notebook fine-tunes the pre-trained BLIP (Bootstrapped Language Image Pretraining) model
on a subset of the ROCOv2-radiology dataset for the task of medical image caption generation.

Main Steps:
- Load and preprocess a medical image-caption dataset.
- Load a pre-trained BLIP model and processor.
- Fine-tune the model on the medical dataset.
- Save the fine-tuned model for later use.

Fine-tuning helps adapt the general BLIP model to the specific domain of medical imaging,
improving its ability to generate accurate and medically relevant captions.
"""

'\nThis notebook fine-tunes the pre-trained BLIP (Bootstrapped Language Image Pretraining) model\non a subset of the ROCOv2-radiology dataset for the task of medical image caption generation.\n\nMain Steps:\n- Load and preprocess a medical image-caption dataset.\n- Load a pre-trained BLIP model and processor.\n- Fine-tune the model on the medical dataset.\n- Save the fine-tuned model for later use.\n\nFine-tuning helps adapt the general BLIP model to the specific domain of medical imaging,\nimproving its ability to generate accurate and medically relevant captions.\n'

In [8]:
# Step 1: Import Required Libraries
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.optim import AdamW
import matplotlib.pyplot as plt

In [9]:
# Step 2: Load Dataset (e.g., ROCOv2 Radiology Dataset)
# You can modify this part if you have your own dataset ready
print("Loading dataset...")
dataset = load_dataset("eltorio/ROCOv2-radiology")

# Use only a small subset for faster testing
train_data = dataset["train"].select(range(500))  # Select first 500 samples for initial fine-tuning

# Step 3: Load BLIP Model and Processor
print("Loading BLIP model...")
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

Loading dataset...


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Loading BLIP model...


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


In [10]:
# Step 3: Load BLIP Model and Processor
print("Loading BLIP model...")
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

Loading BLIP model...


BlipForConditionalGeneration(
  (vision_model): BlipVisionModel(
    (embeddings): BlipVisionEmbeddings(
      (patch_embedding): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (encoder): BlipEncoder(
      (layers): ModuleList(
        (0-11): 12 x BlipEncoderLayer(
          (self_attn): BlipAttention(
            (dropout): Dropout(p=0.0, inplace=False)
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (projection): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): BlipMLP(
            (activation_fn): GELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
    (post_layernorm): LayerNorm((768,), eps=1e-0

In [11]:
# Step 4: Data Preparation Function
def collate_fn(batch):
    images = [example["image"] for example in batch]
    captions = [example["caption"] for example in batch]
    inputs = processor(images=images, text=captions, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    return inputs

# Create DataLoader
train_loader = DataLoader(train_data, batch_size=8, shuffle=True, collate_fn=collate_fn)

In [12]:
# Step 5: Fine-Tuning Setup
optimizer = AdamW(model.parameters(), lr=5e-5)
epochs = 2


In [13]:
# Step 6: Fine-Tuning Loop
print("Starting fine-tuning...")
model.train()
for epoch in range(epochs):
    total_loss = 0
    for batch in train_loader:
        outputs = model(**batch, labels=batch["input_ids"])
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{epochs}, Average Loss: {avg_loss:.4f}")


Starting fine-tuning...
Epoch 1/2, Average Loss: 5.7754
Epoch 2/2, Average Loss: 2.1439


In [14]:
# Step 7: Save Fine-Tuned Model
model.save_pretrained("./blip-finetuned")
processor.save_pretrained("./blip-finetuned")

print("Fine-tuning completed and model saved.")


Fine-tuning completed and model saved.
