# FLARE25-PaliGemma2: Inference Notebook [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/medfm-flare/FLARE25-PaliGemma/blob/main/inference.ipynb) [![Hugging Face Model](https://img.shields.io/badge/HuggingFace-yws0322%2Fflare25--paligemma2-blue?logo=huggingface)](https://huggingface.co/yws0322/flare25-paligemma2)

This notebook adapts the provided python script for inference into a step-by-step Jupyter Notebook environment.

The script performs the following main tasks:
1.  Parses arguments for model, data, and inference configuration.
2.  Defines utility functions for file handling and path validation.
3.  Includes functions to parse model output based on different task types (classification, detection, counting, etc.).
4.  Loads the base PaliGemma2 model and optionally applies LoRA fine-tuned weights.
5.  Runs predictions on a set of medical images with associated questions from JSON files.
6.  Saves the predictions to a JSON output file.

We will break down the script into logical sections and execute them step by step in this notebook.

## 1. Initial Setup and Library Imports

In this section, we install and import the necessary libraries required for running inference with the PaliGemma2 model. We also log in to the Hugging Face Hub to access models and processors that require authentication (e.g., if they are gated or hosted under a private repository).

In [1]:
# Install required libraries
!pip install --upgrade transformers peft bitsandbytes accelerate

Collecting bitsandbytes
  Downloading bitsandbytes-0.46.0-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.13.0->peft)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.13.0->peft)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.13.0->peft)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.13.0->peft)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.13.0->peft)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch>=1.13.0-

In [4]:
# Import required libraries
import os
import json
import argparse
import glob
import re
from PIL import Image
from tqdm import tqdm
import torch
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration, BitsAndBytesConfig
from peft import PeftModel
from huggingface_hub import login

In [5]:
# Hugging Face Authentication
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

## 2. Argument Configuration

The original script uses `argparse` to handle command-line arguments. In this notebook environment, we will define the arguments directly. You can modify the values in the following cell to configure the prediction process.

In [13]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [26]:
class Args:
    def __init__(self):
        self.base_model_id = "google/paligemma2-10b-pt-224"
        self.model_id = "yws0322/flare25-paligemma2"
        self.base_dataset_path = "/content/drive/MyDrive/flare/organized_dataset" # Update this path as needed
        self.validation_type = "public" # Choose 'hidden' or 'public'
        self.output_dir = "predictions" # Output directory for results
        self.output_filename = "predictions.json"
        self.max_new_tokens = 1024
        self.device = "cuda:0" # Or "cpu" if no GPU is available

args = Args()

print("Arguments configured:")
print(f"  Base Model ID: {args.base_model_id}")
print(f"  Model ID: {args.model_id}")
print(f"  Base Dataset Path: {args.base_dataset_path}")
print(f"  Validation Type: {args.validation_type}")
print(f"  Output Directory: {args.output_dir}")
print(f"  Output Filename: {args.output_filename}")
print(f"  Max New Tokens: {args.max_new_tokens}")
print(f"  Device: {args.device}")

Arguments configured:
  Base Model ID: google/paligemma2-10b-pt-224
  Model ID: yws0322/flare25-paligemma2
  Base Dataset Path: /content/drive/MyDrive/flare/organized_dataset
  Validation Type: public
  Output Directory: predictions
  Output Filename: predictions.json
  Max New Tokens: 1024
  Device: cuda:0


## 3. Utility Functions

This section includes utility functions from the original script to help with file handling and path validation.

In [27]:
def find_json_files(base_path):
    """Recursively find all JSON files in the specified directory."""
    json_files = []
    for root, dirs, files in os.walk(base_path):
        for file in files:
            if file.endswith('.json'):
                json_files.append(os.path.join(root, file))
    return json_files


def validate_paths(dataset_path):
    """Validate that required paths exist."""
    # Check dataset path
    if not os.path.exists(dataset_path):
        raise FileNotFoundError(f"Dataset path does not exist: {dataset_path}")

    # Find JSON files
    json_files = find_json_files(dataset_path)
    if not json_files:
        raise FileNotFoundError(f"No JSON files found in {dataset_path}")

    return json_files

## 4. Answer Parsing Functions

This section includes functions to parse the model's output based on the specific task type (classification, detection, etc.) to extract the final answer in the required format.

In [28]:
def parse_answer(output, task_type=None):
    """Parse model output based on task type to extract the final answer."""
    output = output.strip()

    # Remove common prefixes
    if "Please provide a clear and concise answer." in output:
        try:
            output = output.split("Please provide a clear and concise answer.")[-1].strip()
        except:
            pass

    # Remove leading newlines
    if "\n" in output:
        output = output.split("\n", 1)[-1].strip()

    # Task-specific parsing
    task_type = (task_type or "").strip().lower()

    if task_type == "classification":
        return _parse_classification(output)
    elif task_type == "multi-label classification":
        return _parse_multi_label_classification(output)
    elif task_type in ["detection", "instance_detection"]:
        return _parse_detection(output)
    elif task_type in ["cell counting", "regression", "counting"]:
        return _parse_numeric(output)
    elif task_type == "report generation":
        return output
    else:
        return output


def _parse_classification(output):
    """Parse classification task output."""
    lines = output.splitlines()
    if len(lines) >= 1:
        last_line = lines[-1].strip()
        return last_line
    return output

def _parse_multi_label_classification(output):
    """Parse multi-label classification task output."""
    lines = output.splitlines()
    labels = []
    for line in lines:
        for part in re.split(r'[;]', line):
            label = part.strip()
            if label:
                labels.append(label)
    return "; ".join(labels)


def _parse_detection(output):
    """Parse detection task output (JSON format expected)."""
    match = re.search(r'\{.*\}|\[.*\]', output, re.DOTALL)
    if match:
        try:
            parsed = json.loads(match.group())
            return json.dumps(parsed)
        except:
            return match.group()
    return output


def _parse_numeric(output):
    """Parse numeric task output (counting, regression)."""
    match = re.search(r'[-+]?[0-9]*\.?[0-9]+', output)
    if match:
        return match.group()
    return "0"

## 5. Model Loading

This section defines the function to load the base PaliGemma2 model and optionally apply the LoRA fine-tuned weights.

In [33]:
def load_model_and_processor(base_model_id, model_id, device="cuda:0"):
    """Load PaliGemma2 model and processor from Hugging Face or with optional LoRA weights."""
    print(f"Loading model: {model_id}")

    # Configure quantization
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4"
    )

    # Load model
    print(f"Loading base model: {base_model_id}")
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        base_model_id,
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
        device_map={"": 0},
        attn_implementation='eager',
    )
    # If model_id is provided, apply LoRA weights
    if model_id:
        model = PeftModel.from_pretrained(model, model_id)
        print(f"Loaded LoRA weights from {model_id}")

    # Load processor
    processor = PaliGemmaProcessor.from_pretrained(base_model_id)

    return model, processor

## 6. Prediction Function

This section defines the core function `predict_on_file` which processes a single JSON file containing samples (image paths and questions) and generates predictions using the loaded PaliGemma2 model.

In [30]:
def predict_on_file(input_file, model, processor, max_new_tokens=1024, device="cuda:0"):
    """Perform predictions on a single JSON file containing questions and images."""
    IMAGE_TOKEN = "<image>"

    # Load data
    with open(input_file) as f:
        val_data = json.load(f)

    print(f"Processing {len(val_data)} samples from {os.path.basename(input_file)}")

    # Process each sample
    for sample in tqdm(val_data, desc=f"Predicting {os.path.basename(input_file)}"):
        try:
            # Handle image loading
            img_field = sample["ImageName"]
            if isinstance(img_field, list):
                img_paths = img_field[:5]  # Limit to 5 images max
            else:
                img_paths = [img_field]

            # Load and validate images
            imgs = []
            for img_path in img_paths:
                full_path = os.path.join(os.path.dirname(input_file), img_path)
                try:
                    img = Image.open(full_path).convert("RGB")
                    imgs.append(img)
                except Exception as e:
                    print(f"Warning: Failed to load image {img_path}: {e}")
                    continue

            if not imgs:
                print(f"Warning: No valid images for sample, skipping")
                sample["Answer"] = "Error: No valid images"
                continue

            # Prepare input
            formatted_question = (
                "Analyze the given medical image and answer the following question:\n"
                f"Question: {sample['Question']}\n"
                "Please provide a clear and concise answer."
            )
            prefix = IMAGE_TOKEN * (processor.image_seq_length * len(imgs))
            input_text = f"{prefix}{processor.tokenizer.bos_token}{formatted_question}\n"

            # Process images and text
            pixel_values = processor.image_processor(imgs, return_tensors="pt")["pixel_values"].to(device)
            inputs = processor.tokenizer(
                input_text,
                return_tensors="pt",
                padding=True,
                truncation=True,
            ).to(device)

            # Generate prediction
            with torch.no_grad():
                generated_ids = model.generate(
                    input_ids=inputs.input_ids,
                    pixel_values=pixel_values,
                    max_new_tokens=max_new_tokens,
                    do_sample=False
                )

            # Decode output
            output = processor.tokenizer.batch_decode(
                generated_ids,
                skip_special_tokens=True
            )[0]

            # Parse answer based on task type
            parsed_answer = parse_answer(output, sample.get("TaskType", ""))
            sample["Answer"] = parsed_answer

        except Exception as e:
            print(f"Error processing sample: {e}")
            sample["Answer"] = f"Error: {str(e)}"

    return val_data

## 7. Run Predictions

This section contains the main execution logic to run predictions on all found JSON files in the specified dataset directory using the configured arguments and defined functions.

In [34]:
def run_predictions(args):
    """
    Main function to run predictions on all JSON files in the dataset directory.
    """
    # Construct full dataset path
    dataset_path = os.path.join(args.base_dataset_path, f"validation-{args.validation_type}")

    # Validate paths and find JSON files
    print("Validating paths and discovering files...")
    try:
        input_files = validate_paths(dataset_path)
        print(f"Found {len(input_files)} JSON files in {dataset_path}:")
        for file in input_files:
            print(f"  - {os.path.relpath(file, dataset_path)}")
    except FileNotFoundError as e:
        print(f"Error during path validation: {e}")
        print("Please ensure the base_dataset_path and validation_type are correctly set and the directory contains JSON files.")
        return 0 # Indicate no predictions were made due to error


    # Load model and processor
    print("\nLoading model and processor...")
    try:
        model, processor = load_model_and_processor(
            args.base_model_id,
            args.model_id,
            args.device
        )
    except Exception as e:
        print(f"Error loading model or processor: {e}")
        print("Please check the model_id, and device settings.")
        return 0 # Indicate no predictions were made due to error


    # Run predictions on all files
    print(f"\nRunning predictions...")
    all_predictions = []
    total_predictions_made = 0

    for input_file in input_files:
        try:
            predictions = predict_on_file(
                input_file,
                model,
                processor,
                args.max_new_tokens,
                args.device
            )
            all_predictions.extend(predictions)
            total_predictions_made += len(predictions)
        except Exception as e:
            print(f"Error running predictions on {os.path.basename(input_file)}: {e}")
            # Continue to the next file even if one fails


    # Save results
    print(f"\nSaving results...")
    output_dir = args.output_dir if args.output_dir else "."
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, args.output_filename)

    try:
        with open(output_file, "w") as f:
            json.dump(all_predictions, f, indent=2)
        print(f"Predictions saved to {output_file}")
    except Exception as e:
        print(f"Error saving predictions to {output_file}: {e}")


    return total_predictions_made

# Run the main prediction function
print("Starting the prediction process...")
try:
    prediction_count = run_predictions(args)
    print(f"\nSuccessfully completed predictions for {prediction_count} samples.")
except Exception as e:
    print(f"\nAn unexpected error occurred during the prediction process: {e}")

Starting the prediction process...
Validating paths and discovering files...
Found 12 JSON files in /content/drive/MyDrive/flare/organized_dataset/validation-public:
  - Xray/IU_XRay/IU_XRay_questions_val.json
  - Xray/chestdr/chestdr_questions_val.json
  - Endoscopy/endo/endo_questions_val.json
  - Clinical/neojaundice/neojaundice_questions_val.json
  - Mammography/CMMD/CMMD_questions_val.json
  - Retinography/retino/retino_questions_val.json
  - Ultrasound/BUSI-det/BUSI-det_questions_val.json
  - Ultrasound/BUSI/BUSI_questions_val.json
  - Ultrasound/BUS-UCLM/BUS-UCLM_questions_val.json
  - Ultrasound/BUS-UCLM-det/BUS-UCLM-det_questions_val.json
  - Microscopy/neurips22cell/neurips22cell_questions_val.json
  - Dermatology/bcn20000/bcn20000_questions_val.json

Loading model and processor...
Loading model: yws0322/flare25-paligemma2
Loading base model: google/paligemma2-10b-pt-224


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loaded LoRA weights from yws0322/flare25-paligemma2


preprocessor_config.json:   0%|          | 0.00/424 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/243k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/34.6M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/733 [00:00<?, ?B/s]


Running predictions...
Processing 1945 samples from IU_XRay_questions_val.json


Predicting IU_XRay_questions_val.json:   0%|          | 0/1945 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Predicting IU_XRay_questions_val.json:   4%|▍         | 81/1945 [04:21<1:40:12,  3.23s/it]


KeyboardInterrupt: 