In [1]:
import os
import torch
import pickle
from PIL import Image
from tqdm.auto import tqdm
import xml.etree.ElementTree as ET
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.cuda.amp import autocast, GradScaler

from transformers import (
    BlipProcessor, 
    BlipForConditionalGeneration, 
    AutoTokenizer, 
    AutoModelForCausalLM
)

In [2]:
import os
from PIL import Image
from torch.utils.data import Dataset

class ChestXrayDataset(Dataset):
    """
    Custom PyTorch Dataset for Chest X-Ray images and captions
    """
    def __init__(self, images_captions, processor, images_folder="/kaggle/input/chestxray-test/data/images_subset"):
        """
        Args:
            images_captions (dict): Dictionary mapping image filenames to captions.
            processor: Preprocessor that processes images and text (e.g., Hugging Face processor).
            images_folder (str): Path to the folder containing image files. Defaults to the Kaggle input path.
        """
        self.images_folder = images_folder
        self.images_captions = images_captions
        self.processor = processor
        self.image_files = list(images_captions.keys())
        
    def __len__(self):
        """
        Returns the total number of samples in the dataset.
        """
        return len(self.image_files)
    
    def __getitem__(self, idx):
        """
        Retrieves the processed image and caption pair at the specified index.
        Args:
            idx (int): Index of the sample to retrieve.
        Returns:
            dict: A dictionary containing processed image and caption tensors.
        """
        image_file = self.image_files[idx]
        image_path = os.path.join(self.images_folder, image_file)  # Use correct image path
        
        # Check if the image exists
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image {image_file} not found at {image_path}")
        
        # Load and preprocess the image
        image = Image.open(image_path).convert('RGB')
        caption = self.images_captions[image_file]
        
        # Prepare inputs using the processor
        inputs = self.processor(
            images=image, 
            text=caption, 
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=128
        )

        input_ids = inputs['input_ids']
        labels = input_ids.clone()  # Copy input_ids to use as labels
        inputs['labels'] = labels
        # Remove the batch dimension
        for key, value in inputs.items():
            inputs[key] = value.squeeze(0)  # Remove the batch dimension
            
        return inputs


In [3]:
def save_final_model(blip_model, blip_processor, output_dir):
    # Save the model weights and configuration
    model_path = f"{output_dir}/blip_model"
    blip_model.save_pretrained(model_path)
    
    # Save the processor (used for preprocessing)
    processor_path = f"{output_dir}/blip_processor"
    blip_processor.save_pretrained(processor_path)

    print(f"Model and processor saved at {model_path} and {processor_path}")

In [4]:
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
from transformers import get_scheduler

def train_blip_model(blip_model, blip_processor, images_captions, num_epochs, batch_size, learning_rate, device, output_dir):
    # Step 1: Prepare data (convert to dataset and dataloader)
    dataset = ChestXrayDataset(images_captions, blip_processor)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    # Step 2: Initialize optimizer and scheduler
    optimizer = AdamW(blip_model.parameters(), lr=learning_rate)

    # Step 3: Initialize gradient scaler for mixed precision
    scaler = GradScaler()

    # Step 4: Start training loop for specified epochs
    for epoch in range(num_epochs):
        total_loss = 0
        progress_bar = tqdm(dataloader, desc=f'Epoch {epoch + 1}/{num_epochs}')

        # Step 5: Iterate over batches in the dataloader
        for batch_idx, batch in enumerate(progress_bar):
            # Step 6: Move batch to device (GPU or CPU)
            batch = {k: v.to(device) for k, v in batch.items()}

            # Step 7: Forward pass through the model
            with autocast(dtype=torch.float16):
                outputs = blip_model(input_ids=batch['input_ids'], pixel_values=batch['pixel_values'], labels=batch['labels'])
                loss = outputs.loss  # Extract the loss from the model output

            # Step 8: Backward pass with gradient scaling
            scaler.scale(loss).backward()  # Scale the gradients before backward pass
            scaler.step(optimizer)  # Step the optimizer (update the parameters)
            scaler.update()  # Update the gradient scaler

            total_loss += loss.item()  # Accumulate loss for logging

            # Optional: Print and update progress bar
            progress_bar.set_postfix({'loss': total_loss / (batch_idx + 1)})

    save_final_model(blip_model, blip_processor, output_dir)
    print("Training complete!")


In [5]:
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Load the pickle file
    #dataset_root = "/kaggle/input/chestxray-processed/"
    #pickle_path = os.path.join(dataset_root, "captions.pkl")
    pickle_path = "/kaggle/input/chestxray-processed/medical_dataset.pkl"
    with open(pickle_path, "rb") as file:
        data = pickle.load(file)

    # Extract data from pickle file
    images_captions = data['images_captions']
    reports_with_images = data['reports_with_images']
    text_of_reports = data['text_of_reports']
    # Check the structure of the data
    print(f"Number of image-caption pairs: {len(images_captions)}")
    # Load BLIP models
    blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

    # Train the model
    train_blip_model(
        blip_model,
        blip_processor,
        images_captions,
        num_epochs=5,
        batch_size=8,
        learning_rate=2e-5,
        device='cuda',
        output_dir='./medical_blip_checkpoints'
    )
    
    # Optional: Save final models
    print("Saving final models...")
    output_dir = "./blip"
    os.makedirs(output_dir, exist_ok=True)
    save_final_model(blip_model, blip_processor, output_dir)
    print("Training complete!")

if __name__ == "__main__":
    main()

Number of image-caption pairs: 7326


preprocessor_config.json:   0%|          | 0.00/287 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/506 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/4.56k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/990M [00:00<?, ?B/s]

  scaler = GradScaler()
  self.pid = os.fork()
  with autocast(dtype=torch.float16):
We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
  self.pid = os.fork()
Epoch 1/5: 100%|██████████| 916/916 [07:56<00:00,  1.92it/s, loss=11.6]
Epoch 2/5: 100%|██████████| 916/916 [07:59<00:00,  1.91it/s, loss=11.6]
Epoch 3/5: 100%|██████████| 916/916 [07:59<00:00,  1.91it/s, loss=11.6]
Epoch 4/5: 100%|██████████| 916/916 [07:59<00:00,  1.91it/s, loss=11.6]
Epoch 5/5: 100%|██████████| 916/916 [07:59<00:00,  1.91it/s, loss=11.6]


Model and processor saved at ./medical_blip_checkpoints/blip_model and ./medical_blip_checkpoints/blip_processor
Training complete!
Saving final models...
Model and processor saved at ./blip/blip_model and ./blip/blip_processor
Training complete!
