In [1]:
import os
import torch
import pickle
from PIL import Image
from tqdm.auto import tqdm
import xml.etree.ElementTree as ET

import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.cuda.amp import autocast, GradScaler

from transformers import (
    BlipProcessor, 
    BlipForConditionalGeneration, 
    AutoTokenizer, 
    AutoModelForCausalLM
)

In [2]:
import os
import xml.etree.ElementTree as ET
from tqdm import tqdm

# Define dataset paths
dataset_root = "/kaggle/input/chestxray-test/data"
reports_path = os.path.join(dataset_root, "reports_subset", "ecgen-radiology")
images_path = os.path.join(dataset_root, "images_subset")

# Function to process dataset
def clean_and_process_reports(reports_path, images_path):
    # Get the list of XML and PNG files
    report_files = sorted([f for f in os.listdir(reports_path) if f.endswith(".xml")])
    image_files = set([f for f in os.listdir(images_path) if f.endswith(".png")])  # Use a set for quick lookup
    
    # Initialize dictionaries and statistics
    clean_images_captions = {}
    clean_reports_with_images = {}
    clean_text_of_reports = {}
    stats = {
        'total_reports': len(report_files),
        'reports_with_no_image': 0,
        'reports_with_empty_sections': 0,
        'reports_with_no_impression': 0,
        'reports_with_no_findings': 0,
        'reports_processed': 0,
        'images_processed': 0
    }
    
    print("Processing reports...")
    for report_file in tqdm(report_files):
        try:
            # Parse the XML file
            report_path = os.path.join(reports_path, report_file)
            tree = ET.parse(report_path)
            root = tree.getroot()
            
            # Find associated images in the report
            images_in_report = root.findall("parentImage")
            if not images_in_report:
                stats['reports_with_no_image'] += 1
                continue
            
            # Extract findings and impression
            sections = root.find("MedlineCitation").find("Article").find("Abstract").findall("AbstractText")
            findings, impression = None, None
            for section in sections:
                label = section.get("Label")
                if label == "FINDINGS":
                    findings = section.text
                elif label == "IMPRESSION":
                    impression = section.text
            
            # Skip reports with no meaningful sections
            if not findings and not impression:
                stats['reports_with_empty_sections'] += 1
                continue
            
            # Create a caption for the report
            caption = (impression or "") + " " + (findings or "")
            caption = caption.strip()
            if len(caption.split()) < 10:  # Skip short captions
                continue
            
            # Validate and process associated images
            valid_images = []
            for image in images_in_report:
                image_id = f"{image.get('id')}.png"
                if image_id in image_files:  # Ensure the image exists in the directory
                    clean_images_captions[image_id] = caption
                    valid_images.append(image_id)
                    stats['images_processed'] += 1
            
            # Add the report if it has valid images
            if valid_images:
                clean_reports_with_images[report_file] = valid_images
                clean_text_of_reports[report_file] = caption
                stats['reports_processed'] += 1
        
        except Exception as e:
            print(f"Error processing report {report_file}: {e}")
            continue
    
    # Print statistics
    print("\nDataset Cleaning Statistics:")
    for key, value in stats.items():
        print(f"{key.replace('_', ' ').capitalize()}: {value}")
    
    return clean_images_captions, clean_reports_with_images, clean_text_of_reports


In [3]:
from torch.utils.data import Dataset
class ChestXrayDataset(Dataset):
    """
    Custom PyTorch Dataset for Chest X-Ray images and captions
    """
    def __init__(self, images_path, images_captions, processor):
        self.images_path = images_path
        self.images_captions = images_captions
        self.processor = processor
        self.image_files = list(images_captions.keys())
        
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        image_file = self.image_files[idx]
        image_path = os.path.join(self.images_path, image_file)
        
        # Load and preprocess image
        image = Image.open(image_path).convert('RGB')
        caption = self.images_captions[image_file]
        
        # Prepare inputs
        inputs = self.processor(
            images=image, 
            text=caption, 
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=128
        )
        
        # Remove batch dimension
        for k, v in inputs.items():
            inputs[k] = v.squeeze()
            
        return inputs

In [4]:
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import AutoTokenizer, AutoModelForCausalLM
def main():
    reports_path = os.path.join(dataset_root, "reports_subset", "ecgen-radiology")
    images_path = os.path.join(dataset_root, "images_subset")
    
    # Ensure paths exist
    os.makedirs(images_path, exist_ok=True)
    os.makedirs(reports_path, exist_ok=True)
    
    # Clean and process reports
    clean_images_captions, clean_reports_with_images, clean_text_of_reports = clean_and_process_reports(
        reports_path, 
        images_path
    )
    
    # Save processed dataset (optional)
    with open('medical_dataset.pkl', 'wb') as f:
        pickle.dump({
            'images_captions': clean_images_captions,
            'reports_with_images': clean_reports_with_images,
            'text_of_reports': clean_text_of_reports
        }, f)
if __name__ == "__main__":
    main()


Processing reports...


100%|██████████| 3955/3955 [00:12<00:00, 310.35it/s]


Dataset Cleaning Statistics:
Total reports: 3955
Reports with no image: 104
Reports with empty sections: 25
Reports with no impression: 0
Reports with no findings: 0
Reports processed: 3772
Images processed: 7326



