In [1]:
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from transformers import BlipProcessor, BlipForConditionalGeneration
from tqdm import tqdm
import nltk
from PIL import Image
from nltk.translate.bleu_score import corpus_bleu
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.cider.cider import Cider

# Download necessary NLTK data
nltk.download('punkt')

# Dataset Class
class RadiologyDataset(Dataset):
    def __init__(self, caption_file, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.data = []

        with open(caption_file, 'r') as file:
            for line in file:
                parts = line.strip().split('\t')
                if len(parts) == 2:
                    image_name, caption = parts
                    image_path = os.path.join(image_dir, image_name + ".jpg")
                    if os.path.exists(image_path):
                        self.data.append((image_path, caption))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image_path, caption = self.data[idx]
        image = Image.open(image_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, caption

# Collate Function
def collate_fn(batch):
    images, captions = zip(*batch)
    return list(images), list(captions)

# Evaluation Function for BLEU, ROUGE, and CIDEr
def evaluate_metrics(references, candidates):
    # Compute BLEU score
    bleu_score = corpus_bleu(references, candidates)

    # Compute ROUGE score
    rouge = Rouge()
    rouge_score, _ = rouge.compute_score({i: [" ".join(ref)] for i, ref in enumerate(references)},
                                         {i: [" ".join(candidate)] for i, candidate in enumerate(candidates)})

    # Compute CIDEr score
    cider = Cider()
    cider_score, _ = cider.compute_score({i: [" ".join(ref)] for i, ref in enumerate(references)},
                                         {i: [" ".join(candidate)] for i, candidate in enumerate(candidates)})

    return bleu_score, rouge_score, cider_score

# Image Transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Create Datasets and DataLoaders
train_dataset = RadiologyDataset("data/train/radiology/captions.txt", "data/train/radiology/images", transform=transform)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=2, collate_fn=collate_fn)

val_dataset = RadiologyDataset("data/validation/radiology/captions.txt", "data/validation/radiology/images", transform=transform)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=2, collate_fn=collate_fn)

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package punkt to /home/dvasic/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [None]:

# Load BLIP Model for Conditional Generation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

# Set model to training mode
blip_model.train()

# Define the optimizer and criterion
optimizer = torch.optim.AdamW(blip_model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()

# Training Loop
num_epochs = 5

for epoch in range(num_epochs):
    print(f"Epoch [{epoch+1}/{num_epochs}]")

    # Training Phase
    blip_model.train()
    total_loss = 0
    for images, captions in tqdm(train_loader, desc="Training"):
        # Process images and captions
        inputs = blip_processor(images=images, text=captions, return_tensors="pt", padding=True, truncation=True).to(device)
        outputs = blip_model(**inputs, labels=inputs["input_ids"])

        # Compute loss
        loss = outputs.loss
        total_loss += loss.item()

        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Training Loss: {total_loss / len(train_loader):.4f}")

    # Validation Phase
    blip_model.eval()
    references, candidates = [], []
    with torch.no_grad():
        for images, ground_truth_captions in tqdm(val_loader, desc="Validation"):
            inputs = blip_processor(images=images, return_tensors="pt").to(device)

            # Generate captions using BLIP
            outputs = blip_model.generate(**inputs)
            generated_captions = blip_processor.batch_decode(outputs, skip_special_tokens=True)

            # Preprocess for BLEU, ROUGE, and CIDEr evaluation
            references.extend([[nltk.word_tokenize(caption.lower())] for caption in ground_truth_captions])
            candidates.extend([nltk.word_tokenize(gen_caption.lower()) for gen_caption in generated_captions])

    # Calculate evaluation metrics
    bleu_score, rouge_score, cider_score = evaluate_metrics(references, candidates)
    print(f"Validation BLEU: {bleu_score:.4f}, ROUGE: {rouge_score:.4f}, CIDEr: {cider_score:.4f}")

# Save the fine-tuned model
blip_model.save_pretrained("blip_finetuned")
blip_processor.save_pretrained("blip_finetuned")
print("Fine-tuned BLIP model saved.")

Epoch [1/5]


Training:   0%|          | 0/32707 [00:00<?, ?it/s]It looks like you are trying to rescale already rescaled images. If the input images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again.


: 