In [None]:
import torch
import sys

# Environment gate: report the PyTorch/CUDA setup and stop the notebook early
# when no GPU is present or the GPU has less than 16 GB of total memory.
try:
    print(f"PyTorch version: {torch.__version__}")
    has_cuda = torch.cuda.is_available()
    print(f"CUDA available: {has_cuda}")

    if not has_cuda:
        print("ERROR: No GPU detected! This notebook requires GPU acceleration.")
        sys.exit(1)

    print(f"GPU: {torch.cuda.get_device_name(0)}")

    # Total device memory in decimal gigabytes (bytes / 1e9), read once and reused.
    total_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU Memory: {total_memory_gb:.2f} GB")

    # Check minimum memory requirement
    if total_memory_gb < 16:
        print(f"ERROR: Insufficient GPU memory. Minimum 16GB required, found {total_memory_gb:.1f}GB")
        sys.exit(1)

    # Check bfloat16 support: only the major compute capability is inspected,
    # and a value below 8 produces a warning rather than an exit.
    major, _minor = torch.cuda.get_device_capability()
    if major < 8:
        print("WARNING: GPU does not support bfloat16. Performance may be affected.")

except Exception as err:
    # sys.exit() raises SystemExit, which is not an Exception subclass, so the
    # deliberate exits above pass through this handler untouched.
    print(f"ERROR during GPU check: {err!s}")
    sys.exit(1)

In [None]:
# Install the notebook's dependencies with the kernel's own interpreter.
# A shell escape never raises on failure, so wrapping it in try/except cannot
# detect a failed install; running pip through subprocess lets us check the
# exit status and surface pip's error output.
import subprocess
import sys

REQUIRED_PACKAGES = [
    "accelerate",
    "peft",
    "transformers",
    "bitsandbytes",
    "datasets",
    "pillow",
    "tqdm",
    "evaluate",
    "trl",
    "scikit-learn",
    # Backends for evaluate.load("bertscore") and evaluate.load("rouge"),
    # which the evaluation section of this notebook calls.
    "bert-score",
    "rouge-score",
]

try:
    pip_result = subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q", *REQUIRED_PACKAGES],
        capture_output=True,
        text=True,
    )
    if pip_result.returncode != 0:
        # Prefer pip's own stderr; fall back to the exit status if it is empty.
        raise RuntimeError(
            pip_result.stderr.strip()
            or f"pip exited with status {pip_result.returncode}"
        )
    print("✓ All packages installed successfully")
except Exception as e:
    print(f"✗ Package installation failed: {str(e)}")
    raise

In [None]:
# System prompt shared by training and inference. The model is asked for a
# plain-text findings string, so the REPORT CONTENT section talks about "your
# response" rather than a named field (there is no structured output here).
SYS_PROMPT = """You are an experienced emergency radiologist analyzing imaging studies.

OUTPUT FORMAT:
You must respond with ONLY detailed findings as a string.

ANALYSIS APPROACH:
- Systematically examine the entire image for all abnormalities
- Report all identified lesions and pathological findings
- Be factual - if uncertain, describe what you observe without assuming
- Use professional radiological terminology
- Review the image multiple times if findings are ambiguous

REPORT CONTENT:
Your response should contain a complete radiological description including:
- Primary findings related to the clinical question
- Additional incidental findings or lesions
- Relevant negative findings if clinically important

EXAMPLES:

Example 1 - Chest X-ray with pneumonia:
Input: Chest X-ray, patient with cough and fever
Output: Consolidation in the right lower lobe consistent with pneumonia. No pleural effusion or pneumothorax. Heart size normal.

Example 2 - Normal chest X-ray:
Input: Chest X-ray, routine screening
Output: Clear lung fields bilaterally. No consolidation, pleural effusion, or pneumothorax. Cardiac silhouette within normal limits. No acute bony abnormalities.
"""

# User-turn template; `anatomy` and `subject` are filled in per example.
USR_PROMPT = """Generate a radiology report for this {anatomy} X-ray of {subject}."""

## Create and prepare the fine-tuning dataset

When fine-tuning LLMs, it is important to know your use case and the task you want to solve. This helps you create a dataset to fine-tune your model.

This notebook focuses on fine-tuning a MedGemma model to generate radiology reports for pediatric chest X-rays. We use the `costinstroie/xray-chest-ped-test` dataset which contains pediatric chest X-ray images with corresponding radiology reports.

### Loading the dataset
Load the dataset and display its structure

In [None]:
from datasets import load_dataset, Image

# Download the dataset, confirm the three expected splits exist, make sure the
# "image" column uses the datasets Image feature, and report split sizes.
try:
    print("Loading dataset...")
    data = load_dataset("costinstroie/xray-chest-ped-test")

    # Verify dataset structure
    if not all(split_name in data for split_name in ("train", "validation", "test")):
        raise ValueError("Dataset missing required splits (train/validation/test)")

    # Cast image column to proper type
    data = data.cast_column("image", Image())

    # Verify image data
    if not len(data["train"]):
        raise ValueError("Training dataset is empty")

    print(f"✓ Dataset loaded successfully")
    for split_label, split_key in (
        ("Train", "train"),
        ("Validation", "validation"),
        ("Test", "test"),
    ):
        print(f"  - {split_label} samples: {len(data[split_key])}")

except Exception as err:
    print(f"✗ Error loading dataset: {err!s}")
    raise

### Processing the dataset

We create a custom prompt that will be used to guide the model during fine-tuning. The prompt includes patient demographics and anatomy information. To prepare the dataset for fine-tuning, we will create a new column called "messages". This column will contain structured data representing a system message, user query (the prompt), and assistant response (the report).

Hugging Face TRL supports multimodal conversations. The important piece is the `{"type": "image"}` content entry inside the user message, which tells the processing class that it should load the image. The structure should follow:

```json
{
  "messages": [
    {"role": "system", "content": [{"type": "text", "text": "You are..."}]},
    {"role": "user", "content": [
      {"type": "text", "text": "..."},
      {"type": "image"}
    ]},
    {"role": "assistant", "content": [{"type": "text", "text": "..."}]}
  ]
}
```

In [None]:
from typing import Any


def format_data(example: dict[str, Any]) -> dict[str, Any]:
    """Attach a system/user/assistant conversation to one training example.

    The conversation is stored under ``example["messages"]`` and consists of
    the shared system prompt, a user turn (image placeholder followed by the
    text prompt built from ``age_group`` and ``gender``), and an assistant
    turn holding the ground-truth ``report``.

    Args:
        example: One dataset row. Must contain non-empty ``age_group``,
            ``gender``, ``report`` and ``image`` entries.

    Returns:
        The same ``example`` mapping, with the ``messages`` key added.

    Raises:
        ValueError: If a required field is missing, ``None``, or a
            blank/whitespace-only string.
    """
    try:
        # Validate required fields
        required_fields = ['age_group', 'gender', 'report', 'image']
        for field in required_fields:
            if field not in example:
                raise ValueError(f"Missing required field: {field}")
            if example[field] is None or (isinstance(example[field], str) and example[field].strip() == ""):
                raise ValueError(f"Empty field: {field}")

        # The anatomy is fixed to "chest"; only the subject description varies.
        prompt = USR_PROMPT.format(
            anatomy="chest",
            subject=f"{example['age_group']} {example['gender']}"
        )

        example["messages"] = [
            {
                "role": "system",
                "content": [{"type": "text", "text": SYS_PROMPT}]
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                    },
                    {
                        "type": "text",
                        "text": prompt,
                    },
                ],
            },
            {
                "role": "assistant",
                "content": [
                    {
                        "type": "text",
                        "text": example["report"],
                    },
                ],
            },
        ]
        return example

    except Exception as e:
        # Log which row failed (falls back to 'unknown' when there is no "id"
        # key), then re-raise so the surrounding map() call stops.
        print(f"✗ Error formatting example: {str(e)}")
        print(f"  Problematic example: {example.get('id', 'unknown')}")
        raise

# Run format_data over every split, then sanity-check the first training row:
# it must carry exactly three messages (system, user, assistant).
try:
    print("Formatting dataset...")
    formatted_data = data.map(format_data)
    print(f"✓ Dataset formatted successfully")

    # Verify formatted data
    first_train_row = formatted_data["train"][0]
    sample = first_train_row['messages']
    if not (sample and len(sample) == 3):
        raise ValueError("Formatted messages structure is incorrect")

    print("Sample formatted data:")
    print(sample)

except Exception as err:
    print(f"✗ Error during dataset formatting: {err!s}")
    raise

## Fine-tune MedGemma using TRL and the SFTTrainer

You are now ready to fine-tune your model. Hugging Face TRL [SFTTrainer](https://huggingface.co/docs/trl/sft_trainer) makes it straightforward to run supervised fine-tuning on open LLMs. The `SFTTrainer` is a subclass of the `Trainer` from the `transformers` library and supports all the same features, including logging, evaluation, and checkpointing, but adds quality-of-life features, including:

- Dataset formatting, including conversational and instruction formats
- Training on completions only, ignoring prompts
- Packing datasets for more efficient training
- Parameter-efficient fine-tuning (PEFT) support including QLoRA
- Preparing the model and tokenizer for conversational fine-tuning (such as adding special tokens)

### Loading the model and tokenizer

We use the Transformers library to load the MedGemma 4B Instruct model and its processor. The model is configured to use bfloat16 precision for efficient computation on GPUs.

**Note:** This guide requires a GPU which supports bfloat16 data type such as NVIDIA L4 or NVIDIA A100 and more than 16GB of memory.

The cell below first verifies that the GPU supports bfloat16 and reports the free GPU memory, then loads the model and its processor and switches the tokenizer to right-side padding for training.

In [None]:
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig

# Hub identifier of the base checkpoint that gets fine-tuned below.
model_id = "google/medgemma-4b-it"

# Load the base model in bfloat16 together with its processor. A GPU whose
# major compute capability is below 8 aborts the cell; low free memory only warns.
try:
    print("Loading model and processor...")

    ## Check if GPU supports bfloat16
    major, _minor = torch.cuda.get_device_capability()
    if major < 8:
        raise ValueError("GPU does not support bfloat16, please use a GPU that supports bfloat16.")

    # Check available GPU memory (first element of mem_get_info is read as free bytes)
    free_memory = torch.cuda.mem_get_info()[0] / 1e9  # in GB
    if free_memory < 20:
        print(f"WARNING: Low GPU memory available ({free_memory:.1f}GB). Training may fail.")

    model_kwargs = {
        "attn_implementation": "eager",
        "torch_dtype": torch.bfloat16,
        "device_map": "auto",
    }

    print("Loading model...")
    model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)

    print("Loading processor...")
    processor = AutoProcessor.from_pretrained(model_id, use_fast=True)

    # Use right padding to avoid issues during training
    processor.tokenizer.padding_side = "right"

    print("✓ Model and processor loaded successfully")

except Exception as err:
    message = str(err)
    print(f"✗ Error loading model: {message}")
    # An HTTP 401/403 code in the message is treated as an authentication problem.
    if any(status in message for status in ("401", "403")):
        print("ERROR: Authentication failed. Make sure you have accepted the model license and have proper Hugging Face credentials.")
    raise

### Setting up the model

To fine-tune the MedGemma 4B Instruct model efficiently, we will use Low-Rank Adaptation (LoRA), a parameter-efficient fine-tuning method.

LoRA allows us to adapt large models by training only a small number of additional parameters, significantly reducing computational costs while maintaining performance.

The `SFTTrainer` supports a built-in integration with `peft`, which makes it straightforward to efficiently tune LLMs using LoRA. You only need to create a `LoraConfig` and provide it to the trainer.

In [None]:
from peft import LoraConfig

# LoRA settings; this object is later handed to SFTTrainer as `peft_config`.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    # Adapter size and regularisation.
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    # Which modules receive adapters.
    target_modules="all-linear",
    # Modules passed through `modules_to_save` in addition to the adapters.
    modules_to_save=["lm_head", "embed_tokens"],
)

To handle both image and text inputs during training, we define a custom collation function. This function processes the dataset examples into a format suitable for the model, including tokenizing text and preparing image data.

Before you can start training, you need to define the hyperparameters you want to use in an `SFTConfig` and a custom `collate_fn` to handle the vision processing. The `collate_fn` converts the messages with text and images into a format that the model can understand.

In [None]:
from typing import Any


def collate_fn(examples: list[dict[str, Any]]):
    """Turn a list of formatted examples into one padded training batch.

    Each example's ``messages`` are rendered to text with the processor's chat
    template (no generation prompt) and paired with its ``image``. The
    processor then tokenizes/pads the texts and processes the images, and a
    ``labels`` tensor is derived from ``input_ids`` with padding and image
    tokens replaced by ``-100``.

    Relies on the module-level ``processor``.

    Args:
        examples: Dataset rows, each with an ``image`` entry and a non-empty
            ``messages`` list.

    Returns:
        The processor's batch object with an added ``labels`` entry.

    Raises:
        ValueError: If an example lacks ``image``/``messages`` or its
            ``messages`` value is empty or not a list.
    """
    try:
        texts = []
        images = []

        for example in examples:
            # Validate example structure
            if "image" not in example or "messages" not in example:
                raise ValueError("Example missing required fields (image or messages)")

            if not isinstance(example["messages"], list) or len(example["messages"]) == 0:
                raise ValueError("Messages field is empty or not a list")

            # One single-element image list per example.
            images.append([example["image"]])
            texts.append(
                processor.apply_chat_template(
                    example["messages"], add_generation_prompt=False, tokenize=False
                ).strip()
            )

        # Tokenize the texts and process the images
        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

        # Create labels from input_ids
        labels = batch["input_ids"].clone()

        # Mask special tokens that should not contribute to loss
        try:
            image_token_id = processor.tokenizer.convert_tokens_to_ids(
                processor.tokenizer.special_tokens_map["boi_token"]
            )
        except KeyError:
            # Fallback if boi_token not found.
            # NOTE(review): 262144 is assumed to be the model's image token id;
            # confirm against the tokenizer in use.
            image_token_id = 262144

        # Mask padding, image, and other special tokens. -100 is the value
        # written for every position that should be left out of the loss.
        labels[labels == processor.tokenizer.pad_token_id] = -100
        labels[labels == image_token_id] = -100
        labels[labels == 262144] = -100  # MedGemma specific image token ID

        batch["labels"] = labels
        return batch

    except Exception as e:
        print(f"✗ Error in collate function: {str(e)}")
        print(f"  Processing batch of {len(examples)} examples")
        raise

We use the SFTConfig class from the trl library to define the training arguments. These arguments control the fine-tuning process, including batch size, learning rate, and gradient accumulation steps.

The training configuration includes:
- 3 training epochs
- Batch size of 1 per device with gradient accumulation steps of 16 (effective batch size of 16)
- Gradient checkpointing to save memory
- AdamW optimizer with fused implementation for better performance
- Learning rate of 2e-4 based on QLoRA paper
- bfloat16 precision for efficient training
- Linear learning rate scheduler with 3% warmup

In [None]:
from trl import SFTConfig

# Training hyperparameters for SFTTrainer, grouped by concern.
args = SFTConfig(
    # Output location and Hub upload.
    output_dir="medgemma-4b-it-ped",
    push_to_hub=True,
    report_to="none",
    # Schedule and batching: 1 sample per device x 16 accumulation steps.
    num_train_epochs=3,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=16,
    # Memory savings.
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    bf16=True,
    # Optimisation.
    optim="adamw_torch_fused",
    learning_rate=2e-4,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="linear",
    # Logging, evaluation and checkpoint cadence.
    # NOTE(review): the 0.1 values are presumably read as a fraction of the
    # total training steps; confirm against the installed Transformers version.
    logging_steps=0.1,
    eval_strategy="steps",
    eval_steps=0.1,
    save_strategy="epoch",
    # Data handling: the custom collate_fn needs the raw columns, so dataset
    # preparation is skipped and unused columns are kept.
    dataset_kwargs={"skip_prepare_dataset": True},
    remove_unused_columns=False,
    label_names=["labels"],
)

The SFTTrainer simplifies the fine-tuning process by combining the model, dataset, data collator, training arguments, and LoRA configuration into a single workflow. This makes the process streamlined and user-friendly.

You now have every building block you need to create your `SFTTrainer` to start the training of your model.

In [None]:
from trl import SFTTrainer

# Build the SFTTrainer from the model, training arguments, formatted splits,
# LoRA config, processor and custom collator defined in earlier cells.
try:
    print("Initializing SFTTrainer...")

    # Validate datasets
    train_split = formatted_data["train"]
    if not len(train_split):
        raise ValueError("Training dataset is empty")

    has_validation = "validation" in formatted_data
    if has_validation and len(formatted_data["validation"]) == 0:
        print("WARNING: Validation dataset is empty")

    trainer = SFTTrainer(
        model=model,
        args=args,
        train_dataset=train_split,
        eval_dataset=formatted_data["validation"],
        peft_config=peft_config,
        processing_class=processor,
        data_collator=collate_fn,
    )

    print("✓ Trainer initialized successfully")
    print(f"  - Training samples: {len(train_split)}")
    if has_validation:
        print(f"  - Validation samples: {len(formatted_data['validation'])}")

except Exception as err:
    print(f"✗ Error initializing trainer: {err!s}")
    raise

### Model training

Before starting training, we perform a final memory availability check to ensure we have sufficient resources. The training process requires significant GPU memory, especially when using gradient accumulation.

Once the model, dataset, and training configurations are set up, we can begin the fine-tuning process. The SFTTrainer simplifies this step, allowing us to train the model with just a single command:

Start training by calling the `train()` method.

In [None]:
# Report GPU memory, then run training and print how long it took.
#
# The free-memory check is advisory only. By the time this cell runs, the base
# model has already been loaded onto the GPU, so "free" memory is total memory
# minus the model weights. A hard 20GB gate here would abort on GPUs that the
# first cell of this notebook accepts (16GB and up), so low memory produces a
# warning, matching the check in the model-loading cell.
try:
    print("Performing final memory check before training...")

    # Check available GPU memory
    free_memory_gb = torch.cuda.mem_get_info()[0] / 1e9
    total_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9

    print(f"Available GPU memory: {free_memory_gb:.1f}GB / {total_memory_gb:.1f}GB")

    # Recommended (not enforced) free memory for training
    min_required_gb = 20.0
    if free_memory_gb < min_required_gb:
        print(
            f"WARNING: Low GPU memory available ({free_memory_gb:.1f}GB free, "
            f"{min_required_gb}GB recommended). Training may fail."
        )

    print("Starting training...")
    print("This may take several hours depending on your GPU.")

    # Set up training progress monitoring
    import time
    start_time = time.time()

    trainer.train()

    # Calculate training duration
    duration = time.time() - start_time
    print(f"✓ Training completed in {duration/3600:.2f} hours")

except MemoryError as e:
    # Raised by Python itself when an allocation fails.
    print(f"✗ Memory check failed: {str(e)}")
    print("SUGGESTIONS:")
    print("  - Reduce batch size (per_device_train_batch_size)")
    print("  - Increase gradient accumulation steps")
    print("  - Use a smaller model variant")
    print("  - Free up GPU memory by closing other applications")
    raise
except KeyboardInterrupt:
    print("\n✗ Training interrupted by user")
    raise
except Exception as e:
    print(f"✗ Training failed: {str(e)}")
    if "CUDA out of memory" in str(e):
        print("ERROR: GPU memory exhausted. Try reducing batch size or using gradient accumulation.")
    raise

After the training is complete, the fine-tuned model can be saved locally and pushed to the Hugging Face Hub using the save_model() method.

**Note:** When using LoRA, you only train adapters and not the full model. This means when saving the model during training you only save the adapter weights and not the full model.

In [None]:
# Persist the trained model through the trainer, then confirm that the output
# directory exists and is not empty.
try:
    print("Saving model...")
    trainer.save_model()
    print(f"✓ Model saved to: {args.output_dir}")

    # Verify saved files
    import os
    output_path = args.output_dir
    has_saved_files = os.path.exists(output_path) and bool(os.listdir(output_path))
    if not has_saved_files:
        print(f"WARNING: No files found in output directory {output_path}")
    else:
        print(f"✓ Verified model files exist in {output_path}")

except Exception as err:
    message = str(err)
    print(f"✗ Error saving model: {message}")
    # An HTTP 401/403 code in the message is reported as a permission problem.
    if any(status in message for status in ("403", "401")):
        print("ERROR: Permission denied. Check your Hugging Face write access.")
    raise

## Test Model Inference and generate radiology reports

After the training is done, you'll want to evaluate and test your model. You can load different samples from the test dataset and evaluate the model on those samples.

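**Note:** Evaluating Generative AI models is not a trivial task since one input can have multiple correct outputs. This guide combines a manual comparison of individual reports with a few automated text-similarity metrics computed on a small sample; treat the numbers as a rough sanity check.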

### Model Evaluation

Before starting the evaluation, we remove the training setup to free up GPU memory and ensure a clean environment for testing

In [None]:
# Drop the training objects and hand their GPU memory back before evaluation.
try:
    print("Cleaning up memory...")

    import gc

    # Clean up model and trainer (this runs at cell top level, where locals()
    # is the notebook's global namespace)
    if 'model' in locals():
        del model
    if 'trainer' in locals():
        del trainer

    # Deleting the names only removes references. Objects kept alive by
    # reference cycles are freed by the garbage collector, so run it before
    # asking PyTorch to release its cached blocks.
    gc.collect()

    # Clear CUDA cache
    torch.cuda.empty_cache()

    # Verify memory cleanup
    free_memory = torch.cuda.mem_get_info()[0] / 1e9
    print(f"✓ Memory cleanup complete. Available GPU memory: {free_memory:.1f}GB")

except Exception as e:
    print(f"✗ Error during memory cleanup: {str(e)}")
    raise

### Setting up for model testing

We format the test dataset to match the input structure required by the model. This involves creating a "messages" column that contains the system message and user prompt for each example; the assistant turn is left out because the model has to generate it.

In [None]:
from typing import Any


def format_test_data(example: dict[str, Any]) -> dict[str, Any]:
    """Attach an inference-time conversation (system + user) to one example.

    Unlike the training formatter, no assistant turn is added: the stored
    ``messages`` end with the user turn (image placeholder followed by the
    text prompt built from ``age_group`` and ``gender``).

    Args:
        example: One dataset row containing ``age_group``, ``gender`` and
            ``image`` keys. Only the presence of the keys is checked here.

    Returns:
        The same ``example`` mapping, with the ``messages`` key added.

    Raises:
        ValueError: If one of the required keys is absent.
    """
    try:
        # Validate required fields
        required_fields = ['age_group', 'gender', 'image']
        for field in required_fields:
            if field not in example:
                raise ValueError(f"Missing required field: {field}")

        # The anatomy is fixed to "chest"; only the subject description varies.
        prompt = USR_PROMPT.format(
            anatomy="chest",
            subject=f"{example['age_group']} {example['gender']}"
        )

        example["messages"] = [
            {
                "role": "system",
                "content": [{"type": "text", "text": SYS_PROMPT}]
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                    },
                    {
                        "type": "text",
                        "text": prompt,
                    },
                ],
            },
        ]
        return example

    except Exception as e:
        print(f"✗ Error formatting test example: {str(e)}")
        raise

# Take the test split and add the inference-style "messages" column to it.
try:
    print("Preparing test data...")
    test_data = data["test"]

    if not len(test_data):
        raise ValueError("Test dataset is empty")

    test_data = test_data.map(format_test_data)
    prepared_count = len(test_data)
    print(f"✓ Test data prepared: {prepared_count} samples")

except Exception as err:
    print(f"✗ Error preparing test data: {err!s}")
    raise

### Model performance on the fine-tuned model

To evaluate the fine-tuned model's performance, we load the model with PEFT adapter and processor, configure the generation settings, and prepare the prompts and images for testing.

We will use both qualitative and quantitative metrics to evaluate the model:
- **Qualitative**: Manual comparison of ground truth vs generated reports
- **Quantitative**: Automated metrics including BLEU, ROUGE, and BERTScore

In [None]:
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor

# Load the fine-tuned model and processor from the training output directory,
# configure greedy decoding, build one prompt/image pair per test row, and
# load the BLEU / ROUGE / BERTScore metrics.
try:
    print("Loading fine-tuned model for evaluation...")

    # Standard-library helpers used only by this cell.
    import importlib
    import importlib.util
    import os
    import subprocess
    import sys

    # Check if output directory exists
    if not os.path.exists(args.output_dir):
        raise FileNotFoundError(f"Output directory not found: {args.output_dir}")

    model_kwargs = dict(
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    # Load the fine-tuned model
    print(f"Loading model from: {args.output_dir}")
    model = AutoModelForImageTextToText.from_pretrained(
        args.output_dir, **model_kwargs
    )

    # Start from the base model's generation config and switch to
    # deterministic decoding (sampling off, sampling filters cleared).
    from transformers import GenerationConfig
    gen_cfg = GenerationConfig.from_pretrained(model_id)
    gen_cfg.update(
        do_sample=False,
        top_k=None,
        top_p=None,
        cache_implementation="dynamic",
    )
    model.generation_config = gen_cfg

    # Load processor
    processor = AutoProcessor.from_pretrained(args.output_dir)
    tok = processor.tokenizer

    # Configure tokenization
    model.config.pad_token_id = tok.pad_token_id
    model.generation_config.pad_token_id = tok.pad_token_id

    def chat_to_prompt(chat_turns):
        """Render chat turns to a prompt string that ends with the generation prompt."""
        try:
            return processor.apply_chat_template(
                chat_turns,
                add_generation_prompt=True,
                tokenize=False
            )
        except Exception as e:
            print(f"✗ Error in chat_to_prompt: {str(e)}")
            raise

    # Prepare prompts and images.
    # Iterate row by row so each image is read once; indexing the whole
    # "image" column inside the loop would re-read the column on every pass.
    print("Preparing evaluation data...")
    prompts = []
    images = []

    for i, row in enumerate(test_data):
        try:
            prompts.append(chat_to_prompt(row["messages"]))
            images.append(row["image"])
        except Exception as e:
            print(f"✗ Error processing sample {i}: {str(e)}")
            raise

    if len(prompts) != len(images):
        raise ValueError(f"Mismatch between prompts ({len(prompts)}) and images ({len(images)})")

    print(f"✓ Evaluation data prepared: {len(prompts)} samples")

    # Install evaluation metrics if not already available. The ROUGE and
    # BERTScore metrics need their own backend packages in addition to
    # `evaluate`, so each importable module is checked individually.
    metric_packages = {
        "evaluate": "evaluate",
        "bert_score": "bert-score",
        "rouge_score": "rouge-score",
    }
    missing_packages = [
        pip_name
        for module_name, pip_name in metric_packages.items()
        if importlib.util.find_spec(module_name) is None
    ]
    if missing_packages:
        print("Installing evaluation metrics...")
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "-q", *missing_packages]
        )
        # Make the freshly installed packages visible to the running kernel.
        importlib.invalidate_caches()
        import evaluate
        print("✓ Evaluation metrics installed")
    else:
        import evaluate
        print("✓ Evaluation metrics library already installed")

    # Load evaluation metrics
    bleu = evaluate.load("bleu")
    rouge = evaluate.load("rouge")
    bertscore = evaluate.load("bertscore")

    print("✓ Evaluation metrics loaded successfully")

except Exception as e:
    print(f"✗ Error loading model for evaluation: {str(e)}")
    raise

The predict_one function takes a prompt and an image as input, processes them using the model's processor, and generates a response. The function ensures that the model's output is decoded into human-readable text.

We will use `predict_one` to generate responses for a small sample of the test dataset and then compute quantitative metrics to evaluate performance.

In [None]:
import torch
from typing import Union, Dict, Any, List
from transformers import AutoModelForImageTextToText, AutoProcessor

def predict_one(
    prompt,
    image,
    model,
    processor,
    *,
    device="cuda",
    dtype=torch.bfloat16,
    disable_compile=True,
    **gen_kwargs
) -> str:
    """Run generation for one prompt/image pair and return the decoded text.

    Args:
        prompt: Non-empty prompt string passed to the processor as ``text``.
        image: Image passed to the processor as ``images``; must not be None.
        model: Object exposing ``generate``.
        processor: Callable that encodes the inputs and exposes ``decode``.
        device: Target device for the encoded inputs.
        dtype: Dtype forwarded when moving the encoded inputs.
        disable_compile: Forwarded to ``model.generate``.
        **gen_kwargs: Extra keyword arguments forwarded to ``model.generate``.

    Returns:
        The decoded continuation, i.e. the generated ids after the prompt
        length, with special tokens skipped.

    Raises:
        ValueError: For an empty/non-string prompt, a ``None`` image, or an
            empty/whitespace-only generation. Any other error is re-raised
            after being printed.
    """
    try:
        # Validate inputs
        if not (prompt and isinstance(prompt, str)):
            raise ValueError("Prompt must be a non-empty string")

        if image is None:
            raise ValueError("Image cannot be None")

        # Encode text + image, then move the tensors to the target device
        encoded = processor(text=prompt, images=image, return_tensors="pt")
        inputs = encoded.to(device, dtype=dtype)

        # Number of prompt tokens; everything after this index is new output
        prompt_len = inputs["input_ids"].shape[-1]

        with torch.inference_mode():
            generated = model.generate(
                **inputs,
                disable_compile=disable_compile,
                **gen_kwargs
            )

        # Decode only the newly generated tail of the first sequence
        new_token_ids = generated[0, prompt_len:]
        result = processor.decode(new_token_ids, skip_special_tokens=True)

        if not result or result.isspace():
            raise ValueError("Model generated empty response")

        return result

    except Exception as err:
        message = str(err)
        print(f"✗ Prediction failed: {message}")
        if "CUDA out of memory" in message:
            print("ERROR: GPU memory exhausted during inference")
        raise

The next cell uses `predict_one` in two ways: it computes quantitative metrics over the first few test samples, and it prints a side-by-side comparison of the ground-truth and generated report for a single sample.

In [None]:
def evaluate_quantitative(prompts, images, references, num_samples=5):
    """Score generated reports against references with BLEU, ROUGE and BERTScore.

    Predictions are generated with ``predict_one`` for the first samples of
    ``prompts``/``images``. A sample whose prediction fails is reported and
    skipped. Relies on the module-level ``model``, ``processor``, ``bleu``,
    ``rouge`` and ``bertscore`` objects.

    Args:
        prompts: Prompt strings, index-aligned with ``images`` and ``references``.
        images: Images, one per prompt.
        references: Ground-truth report strings, one per prompt.
        num_samples: Upper bound on how many samples to evaluate. Fewer are
            used when the inputs are shorter.

    Returns:
        A dict with ``'bleu'`` (the BLEU score), ``'rouge'`` (the full ROUGE
        result) and ``'bertscore'`` (the full BERTScore result).

    Raises:
        ValueError: If no sample produced a prediction.
    """
    try:
        # Evaluate only as many samples as every input can supply, so the
        # progress output reports the real total and indexing never overruns.
        total = max(0, min(num_samples, len(prompts), len(images), len(references)))
        print(f"Running quantitative evaluation on {total} samples...")

        predictions = []
        sample_references = []

        # Generate predictions for selected samples
        for i in range(total):
            try:
                print(f"Processing sample {i+1}/{total}...")

                pred = predict_one(
                    prompt=prompts[i],
                    image=images[i],
                    model=model,
                    processor=processor,
                    max_new_tokens=500
                )

                predictions.append(pred)
                sample_references.append(references[i])

            except Exception as e:
                # Best effort: one failed sample must not abort the whole run.
                print(f"✗ Error processing sample {i}: {str(e)}")
                continue

        if len(predictions) == 0:
            raise ValueError("No successful predictions generated")

        # Compute BLEU score (each prediction gets a one-element reference list)
        bleu_result = bleu.compute(
            predictions=predictions,
            references=[[ref] for ref in sample_references]
        )

        # Compute ROUGE scores
        rouge_result = rouge.compute(
            predictions=predictions,
            references=sample_references
        )

        # Compute BERTScore
        bertscore_result = bertscore.compute(
            predictions=predictions,
            references=sample_references,
            lang="en"
        )

        # Display results
        print("\n" + "="*80)
        print("QUANTITATIVE EVALUATION RESULTS")
        print("="*80)

        print(f"BLEU Score: {bleu_result['bleu']:.4f}")
        print(f"ROUGE-1: {rouge_result['rouge1']:.4f}")
        print(f"ROUGE-2: {rouge_result['rouge2']:.4f}")
        print(f"ROUGE-L: {rouge_result['rougeL']:.4f}")
        # BERTScore values come back per sample; report their mean.
        for label, key in (("Precision", "precision"), ("Recall", "recall"), ("F1", "f1")):
            per_sample = bertscore_result[key]
            print(f"BERTScore {label}: {sum(per_sample) / len(per_sample):.4f}")

        return {
            'bleu': bleu_result['bleu'],
            'rouge': rouge_result,
            'bertscore': bertscore_result
        }

    except Exception as e:
        print(f"✗ Quantitative evaluation failed: {str(e)}")
        raise

# Run the metric-based evaluation on the first samples, then print a detailed
# ground-truth vs. generated comparison (and show the image) for one sample.
try:
    print("Running comprehensive evaluation...")

    # Get references (ground truth reports)
    references = test_data["report"][:len(prompts)]

    # Run quantitative evaluation
    metrics = evaluate_quantitative(prompts, images, references, num_samples=5)

    # Run qualitative evaluation on specific sample
    print("\n" + "="*80)
    print("QUALITATIVE EVALUATION")
    print("="*80)

    # Select sample for detailed comparison
    idx = 3

    # Validate sample
    if idx >= len(test_data):
        raise IndexError(f"Sample index {idx} out of range. Dataset has {len(test_data)} samples.")

    # Fetch the row once instead of indexing whole columns repeatedly.
    sample_row = test_data[idx]
    chat = sample_row["messages"]

    # Create prompt
    prompt = processor.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=False
    )

    # Run prediction
    print(f"Generating response for sample {idx}...")
    answer = predict_one(
        prompt=prompt,
        image=sample_row["image"],
        model=model,
        processor=processor,
        max_new_tokens=500
    )

    # Display results
    import textwrap
    print("\nGROUND TRUTH REPORT:")
    print(textwrap.fill(sample_row["report"], 80))
    print("\nMODEL GENERATED REPORT:")
    print(textwrap.fill(answer, 80))

    # Show image. A bare expression is only rendered when it is the last
    # statement of a cell, so call display() explicitly (IPython makes
    # `display` available in notebook kernels without an import).
    print("\nDisplaying X-ray image:")
    display(sample_row["image"])

    print("\n✓ Comprehensive evaluation completed successfully")

except Exception as e:
    print(f"✗ Evaluation failed: {str(e)}")
    raise