In [1]:
from transformers import BlipForConditionalGeneration, BlipProcessor
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
from transformers import get_scheduler

In [2]:

# Locations of the fine-tuned BLIP checkpoint and of its matching processor
model_path = '/kaggle/input/blip-trained-on-medical-images/blip/blip_model'
processor_path = '/kaggle/input/blip-trained-on-medical-images/blip/blip_processor'

# Use the GPU when one is visible, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Restore the processor that turns raw inputs into model-ready tensors
blip_processor = BlipProcessor.from_pretrained(processor_path)

# Restore the saved weights and place the model on the selected device in one step
blip_model = BlipForConditionalGeneration.from_pretrained(model_path).to(device)

print("Model and processor loaded successfully!")


Model and processor loaded successfully!


In [3]:
from PIL import Image
import os
from tqdm import tqdm
import torch

def prepare_mistral_training_data(
    blip_model, 
    blip_processor, 
    images_captions,  # Dictionary with image filenames as keys and captions as values
    images_path,      # Base path where the images are stored
    device
):
    """
    Build (BLIP caption, reference caption) pairs for Mistral training.

    Each image named in ``images_captions`` is loaded from ``images_path``,
    captioned with ``blip_model`` and paired with its reference caption.

    Args:
        blip_model: Captioning model; must expose ``eval()`` and ``generate()``.
        blip_processor: Processor used to preprocess images and to decode the
            generated token ids back to text.
        images_captions: Mapping of image filename -> reference caption.
        images_path: Directory the filenames are resolved against.
        device: Device the preprocessed inputs are moved to before generation
            (presumably the same device the model lives on -- the caller is
            responsible for that).

    Returns:
        A list of dicts, one per successfully processed image, each with the
        keys ``'blip_caption'`` and ``'true_caption'``, in iteration order of
        ``images_captions``.

    Side effects:
        Switches ``blip_model`` to eval mode and prints one line for every
        image that had to be skipped.
    """
    training_pairs = []
    blip_model.eval()

    for img_file, true_caption in tqdm(images_captions.items()):
        # Construct the full path to the image
        image_path = os.path.join(images_path, img_file)

        # Load the image. The context manager closes the underlying file
        # deterministically; convert() returns an independent in-memory copy,
        # so the result stays usable after the file is closed.
        try:
            with Image.open(image_path) as raw_image:
                image = raw_image.convert('RGB')
        except FileNotFoundError:
            print(f"Image not found: {image_path}")
            continue
        except OSError as exc:
            # Corrupt, truncated or otherwise unreadable files surface as
            # OSError (Pillow's UnidentifiedImageError is a subclass). Skip
            # them instead of aborting a long run and losing all pairs so far.
            print(f"Could not read image {image_path}: {exc}")
            continue

        # Generate BLIP caption (beam search, gradients disabled)
        inputs = blip_processor(images=image, return_tensors="pt").to(device)
        with torch.no_grad():
            generated_ids = blip_model.generate(
                pixel_values=inputs["pixel_values"], 
                max_length=128, 
                num_beams=4, 
                early_stopping=True
            )
        blip_caption = blip_processor.decode(generated_ids[0], skip_special_tokens=True)

        # Create training pair
        training_pairs.append({
            'blip_caption': blip_caption,
            'true_caption': true_caption
        })

    return training_pairs

In [4]:
import pickle
# Inputs for this step: the processed dataset pickle and the image directory
pickle_path = "/kaggle/input/chestxray-processed/medical_dataset.pkl"
images_path = '/kaggle/input/chestxray-test/data/images_subset'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): pickle.load executes arbitrary code embedded in the file;
# only point this at a dataset whose origin is trusted.
with open(pickle_path, "rb") as file:
    data = pickle.load(file)

# Pull the individual collections out of the loaded dictionary
images_captions = data['images_captions']
reports_with_images = data['reports_with_images']
text_of_reports = data['text_of_reports']

# Quick sanity check on how much data was loaded
print(f"Number of image-caption pairs: {len(images_captions)}")

# Caption every image with BLIP and pair it with its reference caption
training_pairs = prepare_mistral_training_data(
    blip_model=blip_model,
    blip_processor=blip_processor,
    images_captions=images_captions,
    images_path=images_path,
    device=device,
)

Number of image-caption pairs: 7326


100%|██████████| 7326/7326 [30:24<00:00,  4.02it/s]


In [5]:
# Persist the generated pairs next to the notebook so later steps can reuse
# them without re-running BLIP over every image.
with open("training_pairs.pkl", "wb") as f:
    f.write(pickle.dumps(training_pairs))

print("Training pairs saved successfully!")

Training pairs saved successfully!
