In [None]:
import os
import pandas as pd
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import openslide

# Define the dataset class
class WSI_QADataset(Dataset):
    """Pair whole-slide-image (WSI) thumbnails with question/answer text.

    Each item is a dict with keys ``"image"``, ``"question"`` and ``"answer"``.
    The image comes from the slide file
    ``<image_dir>/<patient_id><file_extension>``, read with OpenSlide.
    ``OpenSlide.get_thumbnail`` keeps the slide's aspect ratio, so the
    thumbnail is centred on a white canvas of exactly ``thumbnail_size``
    pixels. Every item then has the same size, which the default
    ``DataLoader`` collate function needs in order to stack a batch.
    """

    def __init__(self, image_dir, qa_pairs, transform=None,
                 thumbnail_size=(512, 512), file_extension=".svs"):
        """
        Args:
            image_dir (str): Directory containing WSI images.
            qa_pairs (list): List of tuples (patient_id, question, answer).
            transform (callable, optional): Optional transform applied to the
                (PIL) image before it is returned.
            thumbnail_size (tuple[int, int], optional): (width, height) of
                the returned image. Defaults to (512, 512).
            file_extension (str, optional): Suffix appended to patient_id to
                form the slide filename. Defaults to ".svs".
        """
        self.image_dir = image_dir
        self.qa_pairs = qa_pairs
        self.transform = transform
        self.thumbnail_size = tuple(thumbnail_size)
        self.file_extension = file_extension

    def __len__(self):
        return len(self.qa_pairs)

    def _load_thumbnail(self, patient_id):
        """Return an RGB PIL image of exactly ``thumbnail_size`` for a slide.

        Raises:
            FileNotFoundError: if the slide file for ``patient_id`` is missing.
        """
        image_path = os.path.join(
            self.image_dir, f"{patient_id}{self.file_extension}"
        )
        # OpenSlide does not raise FileNotFoundError for a missing path, so
        # check up front to give callers a clear, catchable error.
        if not os.path.isfile(image_path):
            raise FileNotFoundError(
                f"Image for patient_id {patient_id} not found at {image_path}"
            )

        # The context manager closes the slide; otherwise every item fetch
        # would leave an open slide handle behind.
        with openslide.OpenSlide(image_path) as slide:
            thumbnail = slide.get_thumbnail(self.thumbnail_size).convert("RGB")

        # get_thumbnail fits the slide inside thumbnail_size while keeping its
        # aspect ratio; pad to the exact size so batch items can be stacked.
        width, height = self.thumbnail_size
        canvas = Image.new("RGB", (width, height), (255, 255, 255))
        offset = ((width - thumbnail.width) // 2,
                  (height - thumbnail.height) // 2)
        canvas.paste(thumbnail, offset)
        return canvas

    def __getitem__(self, idx):
        """Return ``{"image", "question", "answer"}`` for the pair at ``idx``."""
        patient_id, question, answer = self.qa_pairs[idx]

        image = self._load_thumbnail(patient_id)

        # Apply any transformations if specified
        if self.transform is not None:
            image = self.transform(image)

        return {
            "image": image,
            "question": question,
            "answer": answer,
        }

# Preprocessing applied to each PIL image returned by the dataset.
from torchvision import transforms

# ToTensor turns the PIL image into a float tensor (scaled to [0, 1] for
# 8-bit images); Normalize then computes (x - 0.5) / 0.5 on each RGB channel.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# Example QA pairs, each a (patient_id, question, answer) tuple.
qa_pairs = [
    ("patient_id_1", "What drug was prescribed to the patient?", "Drug A"),
    ("patient_id_2", "What therapy is prescribed to the patient?", "Therapy B"),
    # Add more QA pairs as needed
]

# Directory holding the slide files (placeholder — point at real data).
image_dir = "/path/to/wsi/images"

# Wrap the pairs in the dataset and a shuffling loader of batch size 4.
dataset = WSI_QADataset(image_dir, qa_pairs, transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# Example: walk the loader and print each batch's text fields.
for batch in dataloader:
    images, questions, answers = (
        batch["image"],
        batch["question"],
        batch["answer"],
    )
    print(questions, answers)
