# Few-Shot Medical Image Classification with RAG and VLMs

This lab introduces you to medical applications of retrieval-augmented generation (RAG). We will build a brain tumor classifier on top of a vision-language model (VLM). The model classifies each image from a prompt enriched with the 5 labeled images most similar to the analyzed image. To retrieve similar images, we will set up a Qdrant vector database of brain tumor MRI images. Each image is stored in the database as an embedding produced by MedImageInsight, a medical image encoder released by Microsoft.

This approach does not require any training of the classifier. The benefit of supplying the VLM with similar images is most apparent for smaller models.
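The prompt layout described above can be sketched with plain Python dictionaries in the OpenAI-style chat format that section 3.3 also uses: one system message, one user turn per retrieved labeled example, and a final user turn with the query image. This is a minimal illustration only; `build_few_shot_messages` is a hypothetical helper, and the base64 strings are placeholders rather than real images.

```python
def build_few_shot_messages(system_prompt, examples, query_b64, final_instruction):
    """Assemble a few-shot chat: system prompt, one user turn per labeled example, then the query.

    `examples` is a list of (label, base64_png) pairs, e.g. the 5 nearest neighbours
    returned by the vector database. All names here are illustrative.
    """
    def image_part(b64):
        # Images are passed inline as base64 data URLs
        return {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}}

    messages = [{"role": "system", "content": system_prompt}]
    for label, b64 in examples:
        messages.append({
            "role": "user",
            "content": [{"type": "text", "text": f"Example:\nclass: {label}"}, image_part(b64)],
        })
    # The image to classify always comes last
    messages.append({
        "role": "user",
        "content": [{"type": "text", "text": final_instruction}, image_part(query_b64)],
    })
    return messages


examples = [("Glioma", "AAA"), ("Normal", "BBB"), ("Glioma", "CCC"), ("Pituitary", "DDD"), ("Meningioma", "EEE")]
msgs = build_few_shot_messages("You are a medical expert.", examples, "QQQ", "Now classify the new image.")
print(len(msgs))  # 7: system + 5 examples + query
```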

### Key Concepts

* **Zero-Shot Learning**

A machine learning paradigm where models can recognize or classify objects they haven't been explicitly trained on. Particularly valuable in medical contexts where labeled data may be scarce. Relies on transferring knowledge from related tasks and understanding semantic relationships.


* **RAG (Retrieval-Augmented Generation)**

A hybrid approach that combines information retrieval from a knowledge base (in our case, a database of medical images) with the generation capabilities of large language models.

Enhances model performance by providing relevant context before making decisions. Improves reliability by grounding predictions in similar historical cases.


* **VLM (Vision Language Models)**

Advanced AI models that can process both images and text. Can understand and describe medical images in natural language. Enable more interpretable and explainable AI decisions in medical contexts.


* **Vector Databases**

Specialized databases that store high-dimensional vector representations (embeddings) of images. Enable efficient similarity search using distance metrics such as cosine similarity. Critical for retrieving relevant historical cases to support decision-making.
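The similarity search a vector database performs can be illustrated with a brute-force cosine-similarity ranking. This is a minimal pure-Python sketch, assuming tiny hand-made 3-dimensional vectors in place of real 1024-dimensional MedImageInsight embeddings. Dedicated vector databases such as Qdrant add indexing on top of this (Qdrant uses HNSW), so the sketch shows only the ranking criterion, not how Qdrant computes it.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k(query, stored, k=5):
    """Return the k (id, score) pairs most similar to `query`, best first.

    `stored` maps a point id to its embedding vector.
    """
    scored = [(pid, cosine_similarity(query, vec)) for pid, vec in stored.items()]
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:k]


stored = {
    1: [1.0, 0.0, 0.0],
    2: [0.9, 0.1, 0.0],
    3: [0.0, 1.0, 0.0],
    4: [0.0, 0.0, 1.0],
}
print(top_k([1.0, 0.05, 0.0], stored, k=2))
```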



### Laboratory Objectives
In this lab, you will:

1. Set up a RAG framework for medical image analysis
2. Create and manage a vector database of brain MRI scans
3. Implement training-free classification with state-of-the-art VLMs, using retrieved images as few-shot examples
4. Evaluate system performance using various metrics (homework).

### Prerequisites

* Basic understanding of Python programming
* Familiarity with deep learning concepts
* Knowledge of medical imaging basics

### Tools and Libraries

* Python 3.10+
* UMIE datasets for the images
* Transformers library for VLMs (via its `image-text-to-text` pipeline)
* Qdrant for vector database management
* MedImageInsight model for encoding the images stored in the vector database

## 1. Setup

### 1.1 Install required libraries

In [None]:
!pip install qdrant_client
!pip install einops transformers_stream_generator
!pip install datasets
!pip install transformers accelerate -U

### 1.2 Load the Brain Tumor Classification dataset from the UMIE datasets repo on Hugging Face

In [None]:
from datasets import load_dataset

dataset = load_dataset("lion-ai/umie_datasets", "brain_tumor_classification", split='train')

### 1.3 Split the dataset into train and validation sets, keeping all images of a study in the same split and stratifying by label

In [None]:
from collections import defaultdict
from typing import Dict, List, Optional, Union
import re

import numpy as np
import pandas as pd
from datasets import Dataset, DatasetDict, load_dataset
from sklearn.model_selection import train_test_split


def group_by_study_id(dataset: Dataset) -> Dict[str, List[Dict]]:
    """Group images by study_id."""
    study_groups = defaultdict(list)
    for item in dataset:
        study_groups[item["study_id"]].append(item)
    return dict(study_groups)


def get_study_label(study_group: List[Dict]) -> str:
    """Get the dominant source label for a study group."""
    label_counts: dict = defaultdict(int)
    for item in study_group:
        label = item["labels"]
        pattern = r'[a-zA-Z]+'
        label = ''.join(re.findall(pattern, label))
        label_counts[label] += 1
    return max(label_counts.items(), key=lambda x: x[1])[0]


def calculate_val_size(study_df: pd.DataFrame, val_size: Union[float, int], total_images: int) -> float:
    """
    Calculate validation set size as a proportion based on input specification.

    Args:
        study_df: DataFrame containing study information
        val_size: Either a float (proportion) or int (absolute number)
        total_images: Total number of images in the dataset

    Returns:
        Float representing the proportion for validation set
    """
    if isinstance(val_size, float):
        if not 0 < val_size < 1:
            raise ValueError("Validation proportion must be between 0 and 1")
        return val_size
    elif isinstance(val_size, int):
        if not 0 < val_size < total_images:
            raise ValueError(f"Validation count must be between 0 and {total_images}")
        return val_size / total_images
    else:
        raise ValueError("val_size must be either float (proportion) or int (count)")


def create_split(
    dataset: Dataset, 
    val_size: Union[float, int] = 0.2, 
    random_state: Optional[int] = 42
) -> DatasetDict:
    """
    Create train-val split while keeping same study_id together and stratifying based on source_labels.

    Args:
        dataset: Hugging Face dataset
        val_size: Either a float between 0-1 (proportion) or int (exact number of examples)
        random_state: Random seed for reproducibility

    Returns:
        DatasetDict containing train and validation splits
    """
    # Group data by study_id
    study_groups = group_by_study_id(dataset)

    # Create a list of (study_id, dominant_label, group_size) tuples
    study_info = [
        (study_id, get_study_label(group), len(group)) 
        for study_id, group in study_groups.items()
    ]

    # Convert to DataFrame for easier manipulation
    study_df = pd.DataFrame(study_info, columns=["study_id", "label", "group_size"])

    total_images = len(dataset)
    val_proportion = calculate_val_size(study_df, val_size, total_images)

    # Perform stratified split on study level
    train_studies, val_studies = train_test_split(
        study_df["study_id"],
        test_size=val_proportion,
        random_state=random_state,
        stratify=study_df["label"]
    )

    # Convert to a set for faster lookup
    train_studies_set = set(train_studies)

    # Read only the study_id column so images are not decoded while building the indices
    study_ids = dataset["study_id"]
    train_indices = [i for i, sid in enumerate(study_ids) if sid in train_studies_set]
    val_indices = [i for i, sid in enumerate(study_ids) if sid not in train_studies_set]

    # Create the splits using Dataset.select()
    train_dataset = dataset.select(train_indices)
    val_dataset = dataset.select(val_indices)

    # Print split statistics
    print("\nSplit Statistics:")
    print(f"Total studies: {len(study_groups)}")
    print(f"Total images: {total_images}")
    print(f"Train studies: {len(train_studies)}")
    print(f"Val studies: {len(val_studies)}")
    print(f"Train images: {len(train_dataset)}")
    print(f"Val images: {len(val_dataset)}")

    # Validate if using exact number
    if isinstance(val_size, int):
        print(f"\nRequested validation size: {val_size}")
        print(f"Actual validation size: {len(val_dataset)}")
        print("Note: Actual size might differ slightly from requested size due to study grouping")

    return DatasetDict({
        "train": train_dataset,
        "validation": val_dataset
    })


splits = create_split(dataset=dataset, val_size=100, random_state=42)
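Because the split is made at the study level, no `study_id` should appear in both splits; otherwise near-identical slices from one study could land on both sides and inflate validation scores. A minimal check, assuming only lists of study IDs; `study_overlap` is a hypothetical helper and the IDs below are made up.

```python
def study_overlap(train_study_ids, val_study_ids):
    """Return the set of study IDs present in both splits (should be empty)."""
    return set(train_study_ids) & set(val_study_ids)


train_ids = ["s1", "s1", "s2", "s3"]
val_ids = ["s4", "s4", "s5"]
leaked = study_overlap(train_ids, val_ids)
assert not leaked, f"studies in both splits: {sorted(leaked)}"
print("no study leakage")
```

In the notebook, the same check could be run as `study_overlap(splits["train"]["study_id"], splits["validation"]["study_id"])`.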


### 1.4 Copy the MedImageInsight model directory to the working directory

In [None]:
!cp -a /kaggle/input/medimageinsight/pytorch/default/1/. /kaggle/working/

### 1.5 Install the model requirements

In [None]:
!pip install -r requirements.txt

## 2. Setup RAG Framework

### 2.1 Create a medical image encoder that will produce the embeddings stored in the vector database

In [None]:
from medimageinsightmodel import MedImageInsight

In [None]:
vision_encoder = MedImageInsight(
    model_dir="2024.09.27",
    vision_model_name="medimageinsigt-v1.0.0.pt",
    language_model_name="language_model.pth"
)
vision_encoder.load_model()

### 2.2 Set up the vector database client (an in-memory Qdrant instance)

In [None]:
from qdrant_client import QdrantClient

client = QdrantClient(":memory:")

### 2.3 Labels used by the Brain Tumor Classification dataset (multiclass classification problem)

In [None]:
labels = [
    "Normal",
    "Glioma",
    "Meningioma",
    "Pituitary"
]

### 2.4 Helper functions for converting images to (optionally resized) base64 strings

In [None]:
import base64
from io import BytesIO
from PIL import Image
import numpy as np


def convert_to_base64(image, resize=None, max_size=512):
    # Check if the image is a numpy array
    if isinstance(image, np.ndarray):
        # Convert the numpy array to a PIL Image
        image = Image.fromarray(image)

    # Resize the image if a new size is specified
    if resize:
        if isinstance(resize, tuple) and len(resize) == 2:
            image = image.resize(resize, Image.LANCZOS)
        else:
            raise ValueError("Resize parameter must be a tuple of (width, height)")

    # Scale the image to fit within max_size while preserving aspect ratio
    if max_size:
        if isinstance(max_size, int) and max_size > 0:
            image.thumbnail((max_size, max_size), Image.LANCZOS)
        else:
            raise ValueError("Max size must be a positive integer")

    buffered = BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue())
    return img_str.decode("utf-8")

def resize_base64(base64_string, resize=None, max_size=512):
    # Decode the base64 string to image
    image_data = base64.b64decode(base64_string)
    image = Image.open(BytesIO(image_data))

    # Resize the image if a new size is specified
    if resize:
        if isinstance(resize, tuple) and len(resize) == 2:
            image = image.resize(resize, Image.LANCZOS)
        else:
            raise ValueError("Resize parameter must be a tuple of (width, height)")

    # Scale the image to fit within max_size while preserving aspect ratio
    if max_size:
        if isinstance(max_size, int) and max_size > 0:
            image.thumbnail((max_size, max_size), Image.LANCZOS)
        else:
            raise ValueError("Max size must be a positive integer")

    # Convert the resized image back to base64
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    resized_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

    return resized_base64
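The VLM messages in section 3.3 wrap these base64 strings in a `data:image/png;base64,` URL, while `resize_base64` expects the bare base64 payload. A small standard-library sketch of building such a URL and stripping the prefix again; `to_data_url` and `from_data_url` are hypothetical helpers, and the bytes are a placeholder rather than a real PNG.

```python
import base64


def to_data_url(raw_bytes: bytes, mime: str = "image/png") -> str:
    """Encode raw image bytes as a base64 data URL."""
    return f"data:{mime};base64,{base64.b64encode(raw_bytes).decode('utf-8')}"


def from_data_url(url: str) -> str:
    """Return the bare base64 payload of a data URL (the part after the comma)."""
    header, _, payload = url.partition(",")
    if not header.startswith("data:") or not header.endswith(";base64"):
        raise ValueError("not a base64 data URL")
    return payload


raw = b"\x89PNG placeholder bytes"
url = to_data_url(raw)
print(url[:22])  # data:image/png;base64,
assert base64.b64decode(from_data_url(url)) == raw
```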



### 2.5 Create a class for creating and managing the vector database of images and labels

In [None]:
import os

from qdrant_client import QdrantClient
from qdrant_client.models import Distance, PointStruct, VectorParams


class Retriever:
    def __init__(self, encoder, collection_name: str, n_images: int = 5, db_client: QdrantClient = client):
        self.encoder = encoder
        self.client = db_client  # defaults to the in-memory client from section 2.2
        self.collection_name = collection_name
        self.n_images = n_images

    def get_similar_imgs(self, img_str: list[str]):
        """Return the n_images stored images most similar to the first query image."""
        embedding = self.encoder.encode(img_str)[0]

        search_result = self.client.query_points(
            collection_name=self.collection_name,
            query=embedding,
            with_payload=True,
            limit=self.n_images,
        ).points

        return [{"id": p.id, "data": p.payload} for p in search_result]

    def _create_id(self, umie_path):
        # Drop the 4-character extension, map "_" to "0", strip leading zeros, parse as int
        return int(os.path.basename(umie_path)[0:-4].replace("_", "0").lstrip("0"))

    def _upload_batch(self, batch):
        img_paths = batch["umie_id"]
        labels = batch["labels"]

        # Encode images as base64 strings before embedding them with MedImageInsight
        images = [convert_to_base64(image) for image in batch["image"]]
        img_embeddings = self.encoder.encode(images)

        points = []
        seen_ids = set()

        for img_path, img_embedding, img, label in zip(
            img_paths, img_embeddings, images, labels
        ):
            try:
                point_id = self._create_id(img_path)
                if point_id in seen_ids:
                    # Skip the duplicate instead of discarding the rest of the batch
                    print(f"duplicate id {point_id}: {img_path}")
                    continue
                seen_ids.add(point_id)
                points.append(
                    PointStruct(
                        id=point_id,
                        vector=img_embedding,
                        payload={"label": label[0], "id": point_id, "img": img},
                    )
                )
            except Exception as e:
                print(f"error: {img_path} {e}")

        self.client.upsert(
            collection_name=self.collection_name, wait=True, points=points
        )

    def upload_dataset(self, dataset, batch_size=8, replace_collection=False, max_images=20):
        if replace_collection and self.client.collection_exists(self.collection_name):
            self.client.delete_collection(self.collection_name)

        if not self.client.collection_exists(self.collection_name):
            # size=1024 is the embedding dimension used for MedImageInsight vectors in this lab
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=1024, distance=Distance.COSINE),
            )

        # Demo cap: upload only the first max_images images (pass None to upload everything)
        if max_images is not None:
            dataset = dataset.select(range(min(max_images, len(dataset))))

        # Dataset.iter yields dict batches, so no dataset.map() side effects are needed
        for batch in dataset.iter(batch_size=batch_size):
            self._upload_batch(batch)
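`_create_id` turns a file name into a Qdrant integer ID by dropping the 4-character extension, replacing `_` with `0`, and stripping leading zeros. This mapping is not one-to-one, which is why the upload code checks for duplicate IDs. The sketch below reproduces the same logic on hypothetical file names and shows one alternative, assuming you are free to change the ID scheme: Qdrant also accepts UUIDs as point IDs, and `uuid.uuid5` derives a deterministic UUID from the path. `create_id` and `create_uuid` are illustrative helpers.

```python
import os
import uuid


def create_id(umie_path: str) -> int:
    """Same transformation as Retriever._create_id."""
    return int(os.path.basename(umie_path)[0:-4].replace("_", "0").lstrip("0"))


# Two different (hypothetical) file names collapse to the same integer ID
print(create_id("study/1_2.png"), create_id("study/102.png"))  # 102 102


def create_uuid(umie_path: str) -> str:
    """Deterministic UUID from the full path; collisions between paths are practically impossible."""
    return str(uuid.uuid5(uuid.NAMESPACE_URL, umie_path))


print(create_uuid("study/1_2.png") == create_uuid("study/102.png"))  # False
```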

### 2.6 Store the Brain Tumor Classification dataset in the vector database, embedding the images with the MedImageInsight encoder

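*We index the validation split here for demonstration purposes, to avoid using too much memory on Kaggle. Normally you should index the "train" split and query it with validation images: if you query the same split you indexed, the nearest match for each image is the image itself.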

In [None]:
retriever = Retriever(encoder=vision_encoder, collection_name="brain_tumor_classification", n_images=5)
retriever.upload_dataset(splits["validation"], batch_size=8, replace_collection=False)
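If the same split is both indexed and queried, the nearest neighbour of a query is usually the query image itself, which leaks its label into the prompt. A minimal sketch of filtering it out, assuming the result format returned by `Retriever.get_similar_imgs` (a list of `{"id": ..., "data": ...}` dicts) and that the query's point ID is known; `exclude_query` is a hypothetical helper. You would request `n_images + 1` neighbours so that 5 remain after filtering.

```python
def exclude_query(results, query_id, n_keep=5):
    """Drop retrieved points whose ID equals the query's ID, then keep the first n_keep."""
    return [r for r in results if r["id"] != query_id][:n_keep]


results = [
    {"id": 7, "data": {"label": "Glioma"}},  # the query image itself
    {"id": 3, "data": {"label": "Glioma"}},
    {"id": 9, "data": {"label": "Normal"}},
]
print([r["id"] for r in exclude_query(results, query_id=7, n_keep=2)])  # [3, 9]
```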

## 3. Create Medical Image Analyzer

### 3.1 Create the prompts. We want our vision-language model to classify the brain tumor type. Before it makes a prediction, we supply the VLM with the 5 most similar images retrieved from the database to ground its answer.

In [None]:
system_message = """You are a medical expert.
Analyze the MRI image and determine whether a tumor is present.
Select the appropriate class from {labels}."""

final_rag_message = """
Now, please analyze the new image and provide your classification.
You always need to classify.
Return only the result in the json format: {'y_pred': y_pred, 'explanation': explanation}.
The explanation should be brief."""
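The requested output `{'y_pred': y_pred, 'explanation': explanation}` uses single quotes, so it is a Python literal rather than strict JSON, and models may wrap it in extra text or code fences. A defensive parser, sketched with the standard library under the assumption that the reply contains one such dictionary; `parse_prediction` is a hypothetical helper that also maps the prediction onto the known class names case-insensitively and returns `None` for anything it cannot match.

```python
import ast
import json
import re

LABELS = ["Normal", "Glioma", "Meningioma", "Pituitary"]  # same classes as section 2.3


def parse_prediction(text, labels=LABELS):
    """Extract {'y_pred', 'explanation'} from a model reply; y_pred is None if unrecognised."""
    match = re.search(r"\{.*\}", text, flags=re.DOTALL)  # first '{' to last '}'
    parsed = {}
    if match:
        snippet = match.group(0)
        for loader in (json.loads, ast.literal_eval):  # strict JSON first, then Python literal
            try:
                parsed = loader(snippet)
                break
            except (ValueError, SyntaxError, TypeError):
                continue
    if not isinstance(parsed, dict):
        parsed = {}
    raw = str(parsed.get("y_pred", "")).strip().lower()
    y_pred = next((label for label in labels if label.lower() == raw), None)
    return {"y_pred": y_pred, "explanation": parsed.get("explanation", "")}


reply = "```json\n{'y_pred': 'glioma', 'explanation': 'Infiltrative mass in the left hemisphere.'}\n```"
print(parse_prediction(reply)["y_pred"])  # Glioma
```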

### 3.2 Set up the VLM. We are going to use Qwen2-VL-7B-Instruct

In [None]:
from transformers import pipeline
pipe = pipeline("image-text-to-text", model="Qwen/Qwen2-VL-7B-Instruct")

Alternatively, you can try VLMs hosted on Kaggle.

In [None]:
# from transformers import pipeline
# pipe = pipeline("text-generation","/kaggle/input/qwen/pytorch/vl/1", trust_remote_code=True)

### 3.3 Create a class that analyzes images by building a structured prompt and combining the retriever with the VLM

In [None]:
class ImageAnalyzer:
    def __init__(self, vlm_pipe, classes, modality: str = "MRI", add_classes: bool = True):
        self.pipe = vlm_pipe        # image-text-to-text pipeline from section 3.2
        self.classes = classes      # candidate labels from section 2.3
        self.modality = modality
        self.add_classes = add_classes

    def _create_base_prompt(self) -> str:
        return system_message.format(modality=self.modality, labels=self.classes)

    def _create_few_shot_message(self, similar_images):
        """One user message per retrieved example: its label followed by the image."""
        messages = []

        for example in similar_images:
            payload = example["data"]
            image = payload["img"] if "img" in payload else payload["image"]
            img = resize_base64(image)
            label = payload["label"]

            messages.append(
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": f"Example:\nclass: {label}"},
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/png;base64,{img}"},
                        },
                    ],
                }
            )

        return messages

    def _description_message(self, few_shot=False):
        prompt = self._create_base_prompt()
        if few_shot:
            prompt += "\nHere are some examples to learn from:"
        return [{"role": "system", "content": prompt}]

    def _create_rag_final_message(self, img_str):
        return {
            "role": "user",
            "content": [
                {"type": "text", "text": final_rag_message},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{img_str}"},
                },
            ],
        }

    def _get_message(self, img_str, similar_images=None):
        similar_images = similar_images or []
        messages = self._description_message(few_shot=bool(similar_images))
        messages.extend(self._create_few_shot_message(similar_images))
        messages.append(self._create_rag_final_message(img_str))
        return messages

    def _get_aux_imgs_info(self, similar_imgs):
        """Collect the IDs (and optionally labels) of the retrieved examples."""
        aux_images = {}
        for idx, example in enumerate(similar_imgs):
            aux_images[f"aux_{idx}_id"] = example["data"]["id"]
            if self.add_classes:
                aux_images[f"aux_{idx}_label"] = example["data"]["label"]
        return aux_images

    def analyze_img(self, img_str, similar_imgs, verbose=False, max_new_tokens=40):
        messages = self._get_message(img_str, similar_imgs)

        # Run the local VLM pipeline on the chat-formatted messages
        outputs = self.pipe(text=messages, max_new_tokens=max_new_tokens, return_full_text=False)
        result = outputs[0]["generated_text"]
        if verbose:
            print(result)

        return result



In [None]:
image_analyzer = ImageAnalyzer(vlm_pipe=pipe, classes=labels, modality="MRI")

# Homework
Test your system on the validation set. You can also try using different VLMs and different datasets from UMIE.
Compare the following metrics:

- Accuracy
- Confidence scores
- Classification speed
- Quality of explanations


Create a summary report of your findings:
```python
def evaluate_system(analyzer, test_data):
    results = {
        'accuracy': [],
        'confidence': [],
        'speed': [],
        'explanation_quality': []
    }
    return results
```
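Once each validation image has a true label, a parsed prediction, and a measured latency, accuracy and speed reduce to simple aggregates. A standard-library sketch, assuming records shaped as `(y_true, y_pred, seconds)` tuples collected by your own loop (for example with `time.perf_counter()` around the retrieval and VLM call); `summarize` is a hypothetical helper, and unparseable predictions (`None`) count as errors.

```python
from collections import Counter
from statistics import mean


def summarize(records):
    """records: list of (y_true, y_pred, seconds); y_pred may be None when parsing failed."""
    correct = sum(1 for y_true, y_pred, _ in records if y_pred == y_true)
    per_class_total = Counter(y_true for y_true, _, _ in records)
    per_class_correct = Counter(y_true for y_true, y_pred, _ in records if y_pred == y_true)
    return {
        "accuracy": correct / len(records),
        "per_class_accuracy": {c: per_class_correct[c] / n for c, n in per_class_total.items()},
        "mean_seconds": mean(seconds for _, _, seconds in records),
        "unparsed": sum(1 for _, y_pred, _ in records if y_pred is None),
    }


records = [
    ("Glioma", "Glioma", 1.2),
    ("Normal", "Glioma", 0.8),
    ("Pituitary", None, 1.0),
    ("Normal", "Normal", 1.0),
]
print(summarize(records)["accuracy"])  # 0.5
```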