# Evaluation

> Functions for evaluating LLaVA models, including prediction generation and metric calculation.

In [1]:
#| default_exp evaluation

In [2]:
#| hide
from nbdev.showdoc import *

In [3]:
#| export
import sys
from pathlib import Path
import os
import gc
import json
import torch
import time
from tqdm.auto import tqdm
from typing import Union, Optional, Dict, Any, List # Added Dict, Any, List
import wandb # Added wandb import

from fastai.learner import Learner
from fastai.data.load import DataLoader
from transformers import AutoTokenizer
from transformers.modeling_outputs import CausalLMOutputWithPast

# Add dependency for Levenshtein distance calculation
try:
    import Levenshtein
except ImportError:
    print("Warning: `pip install python-Levenshtein` is required for ANLS evaluation.")
    Levenshtein = None

# Locate the project root so `llava` is importable from this notebook/module.
# The current working directory is used unless it lacks settings.ini while its
# parent has one (the usual case when running from a subdirectory such as nbs/).
project_root = Path.cwd()
if (project_root.parent / 'settings.ini').exists() and not (project_root / 'settings.ini').exists():
    project_root = project_root.parent

# Prepend the resolved root to sys.path exactly once, reporting what happened.
project_root_str = str(project_root.resolve())
if project_root_str in sys.path:
    print(f"Project root already in sys.path: {project_root_str}")
else:
    print(f"Adding project root to sys.path: {project_root_str}")
    sys.path.insert(0, project_root_str)
    
# Import necessary llava components.
# If any import in this `try` fails (e.g. before `nbdev_export` has been run), the
# `except` branch rebinds every name below to a lightweight placeholder so that this
# module still imports. The placeholders are NOT functional models.
# NOTE(review): the except branch overrides all names, including any that had
# already imported successfully before the failing line.
try:
    from llava.utils import load_config, init_wandb # Added init_wandb
    from llava.data.loading import get_test_dataloader, LLaVADataBlockStage2 # Assumes this function exists in 10_data_loading
    from llava.model.baseline import BaselineLLaVAModel # Or AdaptiveLLaVAModel later
    from llava.data.preprocessing import tokenizer, DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX_PLACEHOLDER # Import tokenizer & constants
    # Import PEFT related things if needed for model loading check
    # (`generate_predictions` uses PeftModel to unwrap a LoRA-wrapped language model).
    try:
        from peft import PeftModel
        _peft_available = True
    except ImportError:
        PeftModel = None 
        _peft_available = False
    # from llava.training.stage2 import get_stage2_learner # Maybe needed for loading full learner state
except ImportError as e:
    print(f"Error importing llava modules: {e}. Make sure nbdev_export was run.")
    # Define placeholders if running standalone or during initial setup
    def load_config(path): return {}
    def init_wandb(*args, **kwargs): pass
    def get_test_dataloader(config, dataset_name, dblock=None): raise NotImplementedError
    # Stand-in model exposing the attributes/methods `generate_predictions` checks for.
    class BaselineLLaVAModel(torch.nn.Module): 
        def __init__(self, *args, **kwargs): 
            super().__init__()
            self.dummy = torch.nn.Linear(1,1)
            self.image_token_index_marker = -200 # Define necessary attributes
            self.projector = torch.nn.Linear(10,10)
            # NOTE(review): assigning a Module to its own attribute registers a
            # self-referential submodule; repr() of this dummy will likely recurse.
            self.language_model = self # Make model act like LLM for generate
            self.vision_tower = self # Dummy vision tower
            self.config = {} # Dummy config
        def encode_image(self, *args, **kwargs): return torch.randn(1,5,10) # Dummy image features B, P, D
        # Returns a freshly initialised embedding table on every call.
        def get_input_embeddings(self): return torch.nn.Embedding(100, 10)
        def forward(self, *args, **kwargs): return CausalLMOutputWithPast(logits=torch.randn(1, 10, 100))
        # Random token ids of shape (1, max_new_tokens).
        def generate(self, *args, **kwargs): return torch.randint(0, 100, (1, kwargs.get('max_new_tokens', 10)))
        def to(self, *args, **kwargs): return self # Avoid device errors in dummy
        # Unlike nn.Module.eval this returns None; callers in this file ignore the result.
        def eval(self): pass
    # Minimal tokenizer exposing the attributes used for generation and decoding.
    class DummyTokenizer:
         def __init__(self): self.eos_token_id=1; self.pad_token_id=0
         def decode(self, *args, **kwargs): return "dummy decoded text"
    tokenizer = DummyTokenizer()
    DEFAULT_IMAGE_TOKEN = "<image>"
    IGNORE_INDEX = -100
    IMAGE_TOKEN_INDEX_PLACEHOLDER = -200
    LLaVADataBlockStage2 = None # Placeholder
    # PEFT is reported unavailable on this path even if the `peft` package is installed.
    PeftModel = None
    _peft_available = False
    # def get_stage2_learner(config): raise NotImplementedError

Project root already in sys.path: /workspace/llava


## Step 5.2: Prediction Generation

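This function takes a trained fastai `Learner` and a test `DataLoader`, generates predictions, decodes them into text, and saves them to a JSON Lines file (one JSON object per line). It also logs efficiency metrics: average per-sample inference latency and, when running on CUDA, peak GPU memory.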

In [4]:
#| export
def _resolve_tokenizer(learner: Learner):
    """Return the decoding tokenizer: `learner.tokenizer`, else `learner.dls.tokenizer`, else the module-level `tokenizer`."""
    if getattr(learner, 'tokenizer', None) is not None:
        return learner.tokenizer
    if getattr(learner.dls, 'tokenizer', None) is not None:
        return learner.dls.tokenizer
    if tokenizer is None:
        raise AttributeError("Tokenizer not found in learner, dls, or globally. Cannot decode predictions.")
    print("Warning: Using global tokenizer instance for generation.")
    return tokenizer


def _extract_sample_id(item) -> Optional[Any]:
    """Return `item.sample_id` (LLaVASample-like) or `item['id']` (JSONL dict), else None."""
    if hasattr(item, 'sample_id'):
        return item.sample_id
    if isinstance(item, dict) and 'id' in item:
        return item['id']
    return None


def _lookup_sample_id(dataset, idx: int, default: str):
    """Best-effort lookup of the original sample ID at dataset position `idx`.

    Returns `default` when the dataset is missing, `idx` is out of range, or no ID
    can be extracted.
    """
    try:
        if dataset is None or idx >= len(dataset):
            return default
    except TypeError:  # dataset without __len__
        return default
    # Prefer raw, untransformed items when the dataset exposes them (fastai's
    # TfmdLists/Datasets keep them in `.items`): reading those does not run the
    # item transforms, so no image is loaded just to recover an ID.
    raw_items = getattr(dataset, 'items', None)
    if raw_items is not None:
        try:
            found = _extract_sample_id(raw_items[idx])
        except Exception:
            found = None
        if found is not None:
            return found
    try:
        found = _extract_sample_id(dataset[idx])
    except Exception:  # best effort: any failure falls back to the default ID
        found = None
    return default if found is None else found


def _build_prompt_embeds(model, input_ids_row, image_features_row, marker):
    """Splice projected image features into one prompt's text embeddings.

    The first occurrence of `marker` in `input_ids_row` (shape (1, L)) is replaced by
    the rows of `image_features_row` (assumed (1, P, D) — the projector output for this
    sample). Any further markers are embedded as token id 0. Returns a tensor of
    shape (1, L - 1 + P, D), or None if the marker is absent.
    """
    marker_positions = torch.where(input_ids_row[0] == marker)[0]
    if len(marker_positions) == 0:
        return None
    split_at = marker_positions[0].item()

    # The marker is not a real vocabulary id; map it to 0 so the embedding lookup is valid.
    safe_ids = input_ids_row.clone()
    safe_ids[safe_ids == marker] = 0
    text_embeddings = model.get_input_embeddings()(safe_ids)

    return torch.cat([
        text_embeddings[:, :split_at],
        image_features_row.to(text_embeddings.device, dtype=text_embeddings.dtype),
        text_embeddings[:, split_at + 1:],
    ], dim=1)


def generate_predictions(learner: Learner, 
                         test_dl: DataLoader, 
                         output_file: Union[str, Path], 
                         max_len: int = 200, # Max generation length
                         temperature: float = 0.2, # Generation temperature
                         top_p: Optional[float] = None, # Nucleus sampling top_p
                        ):
    """Generates predictions for a test dataloader and saves them to a JSON Lines file.
    Includes efficiency logging (latency, peak VRAM).

    Uses the HF `generate` method of the underlying language model. The prompt is
    passed as `inputs_embeds`, in which the projected image features replace the
    first image-marker token.

    Sample IDs are recovered by position (`num_samples + i`). `test_dl` is therefore
    assumed to iterate its dataset in order, without shuffling. If a dict batch
    carries an `attention_mask` with the same shape as `input_ids`, masked (padding)
    positions are dropped from each prompt before generation.

    Args:
        learner: Trained fastai Learner object containing the model.
                 Expected to have `learner.model`, `learner.dls.device`,
                 and potentially `learner.tokenizer` or `learner.dls.tokenizer`.
        test_dl: DataLoader for the test set.
        output_file: Path to save the JSON Lines prediction file.
        max_len: Maximum number of new tokens to generate.
        temperature: Sampling temperature for generation. Set to 0 for greedy decoding.
        top_p: If set (and temperature > 0), use nucleus sampling with this top_p value.

    Returns:
        List of `{"id": ..., "prediction": ...}` dicts, in the order they were written.

    Raises:
        AttributeError: if no tokenizer is found or the model lacks a required component.
    """
    output_file = Path(output_file)
    output_file.parent.mkdir(parents=True, exist_ok=True)

    model = learner.model
    tok = _resolve_tokenizer(learner)

    # Ensure model components exist
    required = ['vision_tower', 'projector', 'language_model', 'get_input_embeddings', 'encode_image']
    if not all(getattr(model, attr, None) is not None for attr in required):
        raise AttributeError("Model is missing required components (vision_tower, projector, language_model, etc.)")

    model.eval() # Set model to evaluation mode
    results = []
    total_time = 0.0
    num_samples = 0
    device = learner.dls.device # Get device from learner
    dataset = getattr(test_dl, 'dataset', None)
    marker = getattr(model, 'image_token_index_marker', IMAGE_TOKEN_INDEX_PLACEHOLDER)
    # Only touch CUDA memory stats when inference actually runs on a CUDA device.
    track_cuda = torch.cuda.is_available() and (device is None or torch.device(device).type == 'cuda')

    # Loop-invariant pieces of the generation call.
    llm_component = model.language_model
    if _peft_available and isinstance(llm_component, PeftModel):
        llm_component = llm_component.base_model.model
    gen_params = {
        "max_new_tokens": max_len,
        "eos_token_id": tok.eos_token_id,
        "pad_token_id": tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id,
        "do_sample": temperature > 0,
        "num_beams": 1,
    }
    if temperature > 0:
        gen_params["temperature"] = temperature
        if top_p is not None:
            gen_params["top_p"] = top_p
    else:
        gen_params["temperature"] = 1.0 # Neutral value; ignored when do_sample=False

    print(f"Generating predictions for {len(test_dl.dataset)} samples...")
    print(f"Saving predictions to: {output_file}")

    # Reset CUDA memory stats before generation loop for accurate peak measurement
    if track_cuda:
        torch.cuda.reset_peak_memory_stats(device=device) 
        print("Reset CUDA peak memory stats before generation.")

    with torch.no_grad(), open(output_file, 'w') as f_out:
        for batch in tqdm(test_dl, desc="Generating Predictions"):
            # Resolve IDs before starting the timer so dataset lookups do not inflate latency.
            if isinstance(batch, dict):
                pixel_values, input_ids = batch['pixel_values'], batch['input_ids']
                attention_mask = batch.get('attention_mask')
                batch_size = pixel_values.shape[0]
                batch_ids = [_lookup_sample_id(dataset, num_samples + i, f"sample_{num_samples + i}")
                             for i in range(batch_size)]
            else:
                print(f"Warning: Expected batch to be a dict, got {type(batch)}. Attempting to proceed assuming basic structure.")
                pixel_values, input_ids = batch[0], batch[1]
                attention_mask = None
                batch_size = pixel_values.shape[0]
                batch_ids = [f"unknown_{num_samples+i}" for i in range(batch_size)]

            start_time = time.perf_counter() # Start timing for this batch
            pixel_values = pixel_values.to(device)
            input_ids = input_ids.to(device)
            # Strip padding only when the mask lines up token-for-token with input_ids.
            use_mask = attention_mask is not None and attention_mask.shape == input_ids.shape
            if use_mask:
                attention_mask = attention_mask.to(device)

            # 1. Encode images and project features
            image_features = model.encode_image(pixel_values) # (B, P, D_clip)
            if image_features is None: 
                print("Warning: Image encoding failed for batch. Skipping.")
                num_samples += batch_size # Account for skipped samples in latency calculation
                total_time += (time.perf_counter() - start_time) # Add time spent
                continue # Skip this batch
            projected_features = model.projector(image_features) # (B, P, D_llm)

            # 2. Build each prompt individually (prompts differ in length) and generate.
            outputs_list = []
            for i in range(batch_size):
                row_ids = input_ids[i]
                if use_mask:
                    row_ids = row_ids[attention_mask[i].bool()]
                prompt_embeds = _build_prompt_embeds(model, row_ids.unsqueeze(0), projected_features[i:i+1], marker)
                if prompt_embeds is None:
                    print(f"Warning: Image token marker {marker} not found in sample {num_samples + i}. Skipping generation.")
                    outputs_list.append(torch.tensor([tok.eos_token_id], device=device)) 
                    continue

                prompt_len = prompt_embeds.shape[1]
                prompt_attention_mask = torch.ones(prompt_embeds.shape[:2], dtype=torch.long, device=prompt_embeds.device)
                try:
                    output_ids_gen = llm_component.generate(
                        inputs_embeds=prompt_embeds,
                        attention_mask=prompt_attention_mask,
                        **gen_params,
                    )
                    # HF `generate` called with `inputs_embeds` only returns just the new
                    # tokens (at most max_len). A longer output must echo the prompt, so
                    # strip that prefix only in that case.
                    if output_ids_gen.shape[1] > max_len:
                        output_ids_gen = output_ids_gen[:, prompt_len:]
                    outputs_list.append(output_ids_gen[0]) 
                except Exception as gen_e:
                    print(f"Error during generation for sample {num_samples + i}: {gen_e}")
                    outputs_list.append(torch.tensor([tok.eos_token_id], device=device)) # Fallback

            total_time += (time.perf_counter() - start_time)

            # 3. Decode and write each result in the batch
            for item_id, gen_ids in zip(batch_ids, outputs_list):
                decoded_text = tok.decode(gen_ids.cpu(), skip_special_tokens=True).strip()
                result_entry = {
                    "id": item_id,
                    "prediction": decoded_text,
                }
                f_out.write(json.dumps(result_entry) + '\n')
                results.append(result_entry)

            num_samples += batch_size # Update total sample count

    # --- Log Efficiency Metrics (Step 5.4) --- 
    avg_latency = (total_time / num_samples) * 1000 if num_samples > 0 else 0
    print(f"Finished generation. Saved {len(results)} predictions to {output_file}")
    print(f"Average inference latency: {avg_latency:.2f} ms/sample")

    if wandb.run is not None:
        wandb.log({"eval/avg_inference_latency_ms": avg_latency})

    if track_cuda:
        peak_vram_gb = torch.cuda.max_memory_allocated(device=device) / (1024**3) 
        print(f"Peak Inference VRAM used: {peak_vram_gb:.2f} GB")
        if wandb.run is not None:
            wandb.log({"eval/peak_inference_vram_gb": peak_vram_gb})
        # Reset peak stats after logging for this run
        torch.cuda.reset_peak_memory_stats(device=device) 

    return results # Return list of predictions

## Step 5.3 & 5.5: Integrate External Evaluation Scripts (Implement VQAv2)

This section defines wrappers around external evaluation tools. `evaluate_vqa` uses the official VQA evaluation API (`vqaTools` / `vqaEvaluation`).

In [5]:
#| export
def evaluate_vqa(preds_file: Union[str, Path], 
                 gt_file: Union[str, Path], 
                 ques_file: Optional[Union[str, Path]] = None, # Path to questions file if needed by API
                 **kwargs):
    """Evaluates VQAv2 predictions using the official VQA evaluation tools.

    The official `VQA.loadRes` reads predictions from a JSON *file* containing a list
    of `{'question_id': int, 'answer': str}`. The JSON Lines predictions are therefore
    reformatted and written to a temporary file before being handed to it. The official
    `loadRes` also checks that the predictions cover exactly the question IDs of the
    annotation file. A partial prediction set fails that check and is reported as an
    evaluation error.

    Args:
        preds_file: Path to the JSON Lines prediction file generated by `generate_predictions`
                    (expected format: {'id': question_id, 'prediction': answer}).
                    IDs that cannot be converted to int are skipped with a warning.
        gt_file: Path to the ground truth annotation file (e.g., v2_mscoco_val2014_annotations.json).
        ques_file: Path to the questions file (e.g., v2_OpenEnded_mscoco_val2014_questions.json). 
                   Required by the standard VQA eval tools.
        **kwargs: Additional arguments (unused currently).

    Returns:
        Dictionary containing VQA accuracies (the official tool reports them as percentages),
        e.g. {'overall': score, 'yes/no': ..., 'number': ..., 'other': ...}. An answer type
        absent from the evaluated set is reported as 0.0. Returns {'overall': 0.0} if
        evaluation fails.
    """
    import tempfile  # local import: only needed to hand predictions to the VQA API as a file

    print(f"--- VQA Evaluation --- ")
    print(f"Predictions file: {preds_file}")
    print(f"Ground Truth Annotation file: {gt_file}")
    print(f"Ground Truth Questions file: {ques_file}")
    results = {"overall": 0.0} # Default return value

    if ques_file is None:
        print("Error: Path to questions file (ques_file) is required for VQA evaluation.")
        return results

    try:
        # --- Load VQA Tools --- 
        # From https://visualqa.org/download.html: `vqaTools` lives in PythonHelperTools and
        # `vqaEvaluation` in PythonEvaluationTools; both directories must be importable.
        try:
            from vqaTools.vqa import VQA
            from vqaEvaluation.vqaEval import VQAEval
        except ImportError:
            print("Error: VQA evaluation tools (vqaTools/vqaEvaluation) not found in PYTHONPATH.")
            print("Please download from https://visualqa.org/download.html and add both PythonHelperTools and PythonEvaluationTools to your path.")
            return results

        # --- Load and Reformat Predictions --- 
        vqa_preds_formatted = []
        print(f"Loading and reformatting predictions from {preds_file}...")
        with open(preds_file, 'r') as f_pred:
            for line in f_pred:
                stripped = line.strip()
                if not stripped:
                    continue # Tolerate blank lines (e.g. a trailing newline)
                try:
                    pred_item = json.loads(stripped)
                    # VQA API expects integer question_id
                    vqa_preds_formatted.append({
                        'question_id': int(pred_item['id']),
                        'answer': pred_item['prediction'],
                    })
                except (json.JSONDecodeError, KeyError, ValueError, TypeError) as e:
                    print(f"Warning: Skipping invalid prediction line: {stripped}. Error: {e}")
        print(f"Loaded {len(vqa_preds_formatted)} predictions for evaluation.")
        if not vqa_preds_formatted:
            print("Error: No valid predictions found in the file.")
            return results

        # --- Run VQA Evaluation --- 
        print("Initializing VQA evaluation...")
        vqa_ann = VQA(str(gt_file), str(ques_file)) # Load annotations and questions
        with tempfile.TemporaryDirectory() as tmp_dir:
            res_file = Path(tmp_dir) / 'vqa_results.json'
            with open(res_file, 'w') as f_res:
                json.dump(vqa_preds_formatted, f_res)
            vqa_pred = vqa_ann.loadRes(str(res_file), str(ques_file)) # loadRes takes a file path

        vqa_eval = VQAEval(vqa_ann, vqa_pred, n=2) # n = decimal places used when rounding accuracies

        print("Running evaluation...")
        vqa_eval.evaluate() # Perform evaluation

        print("Evaluation complete. Results:")
        # The official VQAEval has no showEvals(); call it only if a fork provides one.
        if hasattr(vqa_eval, 'showEvals'):
            vqa_eval.showEvals()

        per_answer_type = vqa_eval.accuracy.get('perAnswerType', {})
        results = {
            "overall": vqa_eval.accuracy['overall'],
            "yes/no": per_answer_type.get('yes/no', 0.0),
            "number": per_answer_type.get('number', 0.0),
            "other": per_answer_type.get('other', 0.0),
        }
        print(f"Parsed VQA Accuracy: {results}")

    except FileNotFoundError as e:
        print(f"Error: Required file not found: {e}")
    except Exception as e:
        print(f"Error during VQA evaluation: {e}")
        import traceback
        traceback.print_exc()

    # Log all VQA metrics to W&B in a single step if a run is active
    if wandb.run is not None:
        wandb.log({
            "eval/vqa_score_overall": results.get("overall", 0.0),
            "eval/vqa_score_yes_no": results.get("yes/no", 0.0),
            "eval/vqa_score_number": results.get("number", 0.0),
            "eval/vqa_score_other": results.get("other", 0.0),
        })

    return results

## Step 5.6: Implement ANLS Calculation for TextVQA/DocVQA

In [6]:
#| export
def calculate_single_anls(prediction: str, ground_truths: List[str], threshold=0.5) -> float:
    """Calculates the Average Normalized Levenshtein Similarity (ANLS) for a single prediction.

    Based on the ANLS definition for ICDAR 2019 Robust Reading Challenge on Scene Text VQA.
    https://rrc.cvc.uab.es/?ch=11&com=tasks -> Task 3 -> Evaluation

    For each ground truth, NLD = edit_distance / max(len(pred), len(gt)). The similarity is
    1 - NLD when NLD < threshold and 0 otherwise. The best similarity over all ground
    truths is returned. Comparison is case-insensitive and ignores leading/trailing
    whitespace.

    Args:
        prediction: The predicted answer string.
        ground_truths: A list of ground truth answer strings.
        threshold: The Normalized Levenshtein Distance threshold (default: 0.5).
                   If NLD >= threshold, the similarity is 0.

    Returns:
        The ANLS score for this prediction (float between 0 and 1). Returns 0.0 if the
        Levenshtein library is unavailable or `ground_truths` is empty.
    """
    if Levenshtein is None:
        print("Warning: Levenshtein library not available. Returning 0 for ANLS.")
        return 0.0

    if not ground_truths: # Handle case with no ground truths
        return 0.0

    def _normalize(text) -> str:
        # Case- and edge-whitespace-insensitive comparison; None is treated as empty.
        return (text or '').strip().lower()

    pred_norm = _normalize(prediction)
    if not pred_norm:
        # An empty prediction only matches an empty ground truth.
        return 1.0 if any(not _normalize(gt) for gt in ground_truths) else 0.0

    best_similarity = 0.0
    for gt in ground_truths:
        gt_norm = _normalize(gt)
        if not gt_norm:
            continue # Non-empty prediction vs empty GT scores 0
        nld = Levenshtein.distance(pred_norm, gt_norm) / max(len(pred_norm), len(gt_norm))
        if nld < threshold: # NLD >= threshold scores 0 per the ST-VQA definition
            best_similarity = max(best_similarity, 1.0 - nld)

    return best_similarity

#| export
def evaluate_textvqa(preds_file: Union[str, Path], 
                     gt_file: Union[str, Path], 
                     anls_threshold=0.5,
                     **kwargs):
    """Evaluates TextVQA/DocVQA predictions using the ANLS metric.

    Ground truth is a JSON object with a 'data' list. Each entry carries a question ID
    ('question_id' as in TextVQA, or 'questionId' as in DocVQA) and an 'answers' list;
    entries lacking either are skipped. Predictions are a JSON Lines file with 'id'
    (question_id) and 'prediction'. IDs are compared as strings on both sides, so the
    integer IDs used in the GT files match integer or string prediction IDs.

    Args:
        preds_file: Path to the JSON Lines prediction file.
        gt_file: Path to the ground truth annotation file (e.g., TextVQA_0.5.1_val.json).
        anls_threshold: NLD threshold for ANLS calculation (default: 0.5).
        **kwargs: Additional arguments (unused).

    Returns:
        Dictionary containing ANLS score (e.g., {'anls': score}), averaged over the
        predictions that have a matching ground truth; {'anls': 0.0} if none match
        or evaluation fails.
    """
    print(f"--- TextVQA/DocVQA ANLS Evaluation --- ")
    print(f"Predictions file: {preds_file}")
    print(f"Ground Truth file: {gt_file}")

    if Levenshtein is None:
        print("Error: `pip install python-Levenshtein` is required for ANLS evaluation.")
        return {"anls": 0.0}

    results = {"anls": 0.0} # Default return value

    try:
        # Load ground truth
        with open(gt_file, 'r') as f_gt:
            ground_truth_data = json.load(f_gt)

        # Format ground truth as {str(question_id): [answer1, answer2, ...]}
        gt_dict = {}
        if isinstance(ground_truth_data, dict) and 'data' in ground_truth_data:
            for item in ground_truth_data['data']:
                qid = item.get('question_id', item.get('questionId'))
                answers = item.get('answers')
                if qid is None or answers is None:
                    continue # e.g. test-split entries without answers
                gt_dict[str(qid)] = answers
        else:
            print(f"Warning: Unexpected GT file format in {gt_file}. Expected a 'data' key.")

        # Load predictions
        predictions = []
        with open(preds_file, 'r') as f_pred:
            for line in f_pred:
                stripped = line.strip()
                if not stripped:
                    continue # Tolerate blank lines (e.g. a trailing newline)
                try:
                    predictions.append(json.loads(stripped))
                except json.JSONDecodeError:
                    print(f"Warning: Skipping invalid JSON line in predictions file: {stripped}")

        print(f"Loaded {len(gt_dict)} ground truth items and {len(predictions)} predictions.")

        # Calculate ANLS for each prediction
        total_anls = 0.0
        count = 0
        for pred_item in tqdm(predictions, desc="Calculating ANLS"):
            if not isinstance(pred_item, dict) or pred_item.get('id') is None:
                print(f"Warning: Prediction item missing 'id': {pred_item}")
                continue
            prediction = pred_item.get('prediction', '')

            gt_answers = gt_dict.get(str(pred_item['id']))
            if gt_answers is None:
                # Some benchmarks might have predictions for samples not in the target GT split
                continue

            total_anls += calculate_single_anls(prediction, gt_answers, threshold=anls_threshold)
            count += 1

        if count > 0:
            final_anls = total_anls / count
            results = {"anls": final_anls}
            print(f"Evaluated {count} samples.")
            print(f"Average Normalized Levenshtein Similarity (ANLS): {final_anls:.4f}")
        else:
            print("No matching predictions found for ground truth IDs. ANLS is 0.")
            results = {"anls": 0.0}

    except FileNotFoundError as e:
        print(f"Error: Required file not found: {e}")
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON: {e}")
    except Exception as e:
        print(f"Error during ANLS calculation: {e}")
        import traceback
        traceback.print_exc()

    # Log metric to W&B if active
    if wandb.run is not None:
        wandb.log({"eval/anls": results.get("anls", 0.0)})

    return results

#### Test ANLS Calculation

In [7]:
#| test
if Levenshtein:
    # Test cases for calculate_single_anls
    assert abs(calculate_single_anls("apple", ["apple"]) - 1.0) < 1e-6, "Test 1 Failed (Exact Match)"
    assert abs(calculate_single_anls("apple", ["aple"]) - (1 - 1/5)) < 1e-6, "Test 2 Failed (One deletion)"
    assert abs(calculate_single_anls("apple", ["banana"]) - 0.0) < 1e-6, "Test 3 Failed (High distance)"
    assert abs(calculate_single_anls("", ["apple"]) - 0.0) < 1e-6, "Test 4 Failed (Empty prediction)"
    assert abs(calculate_single_anls("apple", [""]) - 0.0) < 1e-6, "Test 5 Failed (Empty GT)"
    assert abs(calculate_single_anls("apple", ["APPLE"]) - 1.0) < 1e-6, "Test 6 Failed (Case difference)"
    assert abs(calculate_single_anls("apple", ["banana", "apple", "orange"]) - 1.0) < 1e-6, "Test 7 Failed (Multiple GT, one match)"
    # "apple pie" vs "apple": distance=4, max_len=9, NLD=4/9≈0.44 < 0.5, so similarity = 1 - 4/9 ≈ 0.56
    assert abs(calculate_single_anls("apple pie", ["apple"]) - (1 - 4/9)) < 1e-6, "Test 8 Failed (NLD < threshold)"
    print("ANLS calculation tests passed.")
else:
    # Print expected vs got if Levenshtein is not installed
    print(f"Test 1 (Exact Match): Expected 1.0, Got {calculate_single_anls('apple', ['apple'])} - Levenshtein not installed.")
    print(f"Test 2 (One deletion): Expected 0.8, Got {calculate_single_anls('apple', ['aple'])} - Levenshtein not installed.")
    print(f"Test 3 (High distance): Expected 0.0, Got {calculate_single_anls('apple', ['banana'])} - Levenshtein not installed.")
    print(f"Test 4 (Empty prediction): Expected 0.0, Got {calculate_single_anls('', ['apple'])} - Levenshtein not installed.")
    print(f"Test 5 (Empty GT): Expected 0.0, Got {calculate_single_anls('apple', [''])} - Levenshtein not installed.")
    print(f"Test 6 (Case difference): Expected 1.0, Got {calculate_single_anls('apple', ['APPLE'])} - Levenshtein not installed.")
    print(f"Test 7 (Multiple GT, one match): Expected 1.0, Got {calculate_single_anls('apple', ['banana', 'apple', 'orange'])} - Levenshtein not installed.")

Test 1 (Exact Match): Expected 1.0, Got 0.0 - Levenshtein not installed.
Test 2 (One deletion): Expected >0.8, Got 0.0 - Levenshtein not installed.
Test 3 (High distance): Expected 0.0, Got 0.0 - Levenshtein not installed.
Test 4 (Empty prediction): Expected 0.0, Got 0.0 - Levenshtein not installed.
Test 5 (Empty GT): Expected 0.0, Got 0.0 - Levenshtein not installed.
Test 6 (Case difference): Expected 1.0, Got 0.0 - Levenshtein not installed.
Test 7 (Multiple GT, one match): Expected 1.0, Got 0.0 - Levenshtein not installed.


## Evaluation Script

In [8]:
#| export
def run_evaluation(config_path: Union[str, Path], 
                   model_checkpoint_path: Optional[Union[str, Path]] = None, 
                   dataset_name: str = 'vqav2_test', # Example default
                   output_filename: Optional[str] = None,
                   **kwargs # Allow passing generation args like temp, top_p
                   ) -> Optional[Any]:
    """Runs the evaluation pipeline: load model, generate predictions, evaluate.

    Any exception raised along the way is caught at this boundary. It is printed with a
    traceback and, when W&B is enabled, the run is finished with exit code 1. The
    exception is not re-raised; the function returns None instead.

    Args:
        config_path: Path to the main configuration YAML.
        model_checkpoint_path: Path to the specific model checkpoint base name/directory to load 
                               (e.g., '/path/to/output/models/stage2_llava_lora'). 
                               If None, uses the path derived from config['paths']['stage2_model_weights'].
                               Expects associated files like '_projector_final.pth' and '_lora_adapters'/'_full.pth'.
        dataset_name: Name of the dataset split to evaluate (e.g., 'vqav2_test', 'textvqa_val').
                      This key should exist in `config['paths']` pointing to dataset info.
        output_filename: Name for the prediction output file (defaults based on model/dataset).
        **kwargs: Generation arguments for `generate_predictions`. Only `max_len`,
                  `temperature` and `top_p` are forwarded. Any other key is reported
                  and ignored.

    Returns:
        Whatever the dataset-specific evaluator returned (`evaluate_vqa` for VQAv2,
        `evaluate_textvqa` for TextVQA). Returns None when no evaluator ran or the run
        failed. If W&B is enabled and the result is a dict, its numeric entries are
        logged to the run under `eval/<dataset_name>/<key>`.
    """
    # `DataLoaders` is not imported by the module header cell; import it here so the
    # Learner shell below can be built.
    from fastai.data.core import DataLoaders
    import traceback

    print(f"--- Starting Evaluation Run --- ")
    print(f"Config: {config_path}")
    print(f"Dataset: {dataset_name}")

    config = load_config(config_path)
    output_dir = Path(config['paths']['output_dir'])
    eval_output_dir = output_dir / 'eval_results' / dataset_name
    eval_output_dir.mkdir(parents=True, exist_ok=True)
    models_dir = output_dir / 'models' # Where models are saved

    model = None    # Tracked separately so it is released even if the Learner is never built
    learner = None
    run = None      # W&B run object (only set when W&B logging is enabled)
    metrics = None
    try:
        # --- Initialize W&B ---
        if config.get('logging', {}).get('wandb', {}).get('enabled', False):
            model_id_for_run = Path(model_checkpoint_path or config['paths']['stage2_model_weights']).stem
            run_name = f"eval_{model_id_for_run}_{dataset_name}_{time.strftime('%Y%m%d_%H%M%S')}"
            run = init_wandb(config, job_type="evaluation", run_name=run_name)

        # Fail fast, before spending time loading the model.
        if tokenizer is None:
            raise RuntimeError("Tokenizer not loaded, cannot proceed with evaluation.")

        # 1. Determine Model Path Base Name
        if model_checkpoint_path is None:
            model_base_name = Path(config['paths']['stage2_model_weights']).stem
            model_load_base = models_dir / model_base_name
            print(f"Using model base name from config: {model_base_name}")
        else:
            model_checkpoint_path = Path(model_checkpoint_path)
            model_base_name = model_checkpoint_path.stem
            model_load_base = model_checkpoint_path
            print(f"Using provided model base path: {model_load_base}")

        # 2. Load Model Components (projector + LoRA adapters)
        model = _load_model_for_evaluation(config, model_load_base)

        # 3. Create Learner Shell for Convenience
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        model.eval()  # Inference mode: disables dropout for generation
        dummy_dl = DataLoader(list(range(1)), bs=1)
        dummy_dls = DataLoaders(dummy_dl, dummy_dl)
        dummy_dls.device = device
        learner = Learner(dls=dummy_dls, model=model, loss_func=lambda x, y: 0)
        learner.tokenizer = tokenizer
        print(f"Minimal Learner shell created. Model moved to {device}.")

        # 4. Load Test Data
        print(f"Loading test data for {dataset_name}...")
        if LLaVADataBlockStage2 is None:
            raise RuntimeError("LLaVADataBlockStage2 is not defined. Cannot load test data.")
        test_dl = get_test_dataloader(config, dataset_name, dblock=LLaVADataBlockStage2)
        if test_dl is None:
            raise RuntimeError(f"Failed to load test dataloader for {dataset_name}")

        # 5. Generate Predictions
        if output_filename is None:
            output_filename = f"preds_{model_base_name}_{dataset_name}.jsonl"
        preds_file = eval_output_dir / output_filename

        supported_gen_args = ('max_len', 'temperature', 'top_p')
        gen_kwargs = {k: v for k, v in kwargs.items() if k in supported_gen_args}
        ignored_args = sorted(set(kwargs) - set(gen_kwargs))
        if ignored_args:
            print(f"Warning: Ignoring unsupported generation arguments: {ignored_args}")

        generate_predictions(learner, test_dl, preds_file, **gen_kwargs)

        # 6. Run Evaluation Metrics
        print("Running evaluation metrics...")
        metrics = _compute_dataset_metrics(config, dataset_name, preds_file)

        # Log scalar metrics to W&B. Non-dict results and non-numeric values are skipped.
        if run is not None and isinstance(metrics, dict):
            scalar_metrics = {f"eval/{dataset_name}/{k}": v for k, v in metrics.items()
                              if isinstance(v, (int, float))}
            if scalar_metrics:
                run.log(scalar_metrics)

    except Exception as e:
        # Top-level boundary: report the failure rather than propagating it.
        print(f"An error occurred during evaluation: {e}")
        traceback.print_exc()
        metrics = None
        if run: run.finish(exit_code=1)
    finally:
        # Clean up memory, even if the failure happened before the Learner existed.
        if model is not None:
            model.to('cpu')  # Move parameters off the GPU first
            if learner is not None and getattr(learner, 'model', None) is not None:
                del learner.model
            learner = None
            model = None  # Drop the last local reference before collecting
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            print("Cleaned up model memory.")
        # Finish the W&B run only if it is still the active run. After
        # `run.finish(exit_code=1)` above, the check below is expected to fail.
        if run and wandb.run and wandb.run.id == run.id:
            run.finish()
            print("Finished W&B run.")

    print(f"--- Evaluation Run Complete --- ")
    return metrics


def _load_model_for_evaluation(config: Dict[str, Any], model_load_base: Path):
    """Build a `BaselineLLaVAModel` and load the projector/LoRA weights stored next to `model_load_base`.

    Projector weights are read from '<base>_projector_final.pth'. LoRA adapters are read
    from '<base>_lora_adapters'. A missing file, or an adapter that fails to load, only
    produces a printed warning. The model then keeps the weights
    `BaselineLLaVAModel(config=config)` gave it.
    """
    print("Loading model components for evaluation...")
    model = BaselineLLaVAModel(config=config)

    # Load projector weights
    proj_weights_path = model_load_base.parent / (model_load_base.name + "_projector_final.pth")
    if proj_weights_path.exists():
        model.projector.load_state_dict(torch.load(proj_weights_path, map_location='cpu'))
        print(f"Loaded projector weights from {proj_weights_path}")
    else:
        print(f"Warning: Projector weights not found at {proj_weights_path}. Using initial weights.")

    # Load LoRA adapters or full LLM weights
    use_lora_config = config.get('model', {}).get('peft', {}).get('use_lora', False)
    lora_adapter_path = model_load_base.parent / (model_load_base.name + "_lora_adapters")

    if not use_lora_config:
        print("LoRA not used. Assuming full LLM fine-tuning.")
        print("Warning: Loading full fine-tuned LLM state is not fully implemented here yet. Model may use base weights.")
    # Short-circuit order matters: `PeftModel` is only referenced when PEFT imported successfully.
    elif lora_adapter_path.exists() and _peft_available and isinstance(model.language_model, PeftModel):
        print(f"Loading LoRA adapters from {lora_adapter_path}")
        try:
            if hasattr(model.language_model, 'load_adapter'):
                model.language_model.load_adapter(str(lora_adapter_path), adapter_name="default")
                print("LoRA adapters loaded successfully.")
            else:
                print("Warning: Model's language_model does not have 'load_adapter' method. PEFT setup might be incorrect.")
        except Exception as e:
            # Best-effort: keep evaluating with the current LLM weights.
            print(f"Error loading LoRA adapters: {e}. Model LLM may not be correctly wrapped or adapters mismatch.")
    else:
        print(f"Warning: LoRA enabled in config but adapters not found at {lora_adapter_path} or PEFT issue.")
    return model


def _compute_dataset_metrics(config: Dict[str, Any], dataset_name: str, preds_file: Path):
    """Score `preds_file` against the ground truth configured under `config['paths'][dataset_name]`.

    Returns the evaluator's result. Returns None when ground-truth paths are missing or
    no evaluator is registered for `dataset_name`.
    """
    gt_config = config['paths'].get(dataset_name)
    if not gt_config or 'annotations' not in gt_config:
        print(f"Warning: Ground truth annotation path not found for '{dataset_name}' in config. Cannot run metrics.")
        return None

    data_base = Path(config['paths']['data_base'])
    gt_file_path = data_base / gt_config['annotations']

    # Locate the questions file (needed for VQA): explicit config entry first, else infer it.
    ques_file_path = None
    if 'questions' in gt_config:
        ques_file_path = data_base / gt_config['questions']
    elif 'vqav2' in dataset_name.lower() and 'annotations' in gt_file_path.name:
        # Infer the default VQA questions filename from the annotations filename.
        ques_filename = gt_file_path.name.replace('annotations', 'questions').replace('v2_mscoco', 'v2_OpenEnded_mscoco')
        ques_file_path_default = gt_file_path.parent / ques_filename
        if ques_file_path_default.exists():
            ques_file_path = ques_file_path_default
            print(f"Inferred VQA questions file path: {ques_file_path}")

    if not gt_file_path.exists():
        print(f"Warning: Ground truth file not found at {gt_file_path}. Cannot run metrics.")
        return None

    lowered_name = dataset_name.lower()
    if 'vqav2' in lowered_name:
        if ques_file_path and ques_file_path.exists():
            vqa_results = evaluate_vqa(preds_file, gt_file_path, ques_file=ques_file_path)
            print(f"VQAv2 Results: {vqa_results}")
            return vqa_results
        print(f"Warning: VQA questions file not found or specified ({ques_file_path}). Cannot run VQA evaluation.")
        return None
    if 'textvqa' in lowered_name:
        anls_results = evaluate_textvqa(preds_file, gt_file_path)
        print(f"TextVQA ANLS Results: {anls_results}")
        return anls_results
    # --- Add more evaluation cases here --- #
    print(f"No specific evaluation script configured for dataset: {dataset_name}")
    return None

In [9]:
#| hide
# Regenerate the library modules from every `#| export` cell in the notebooks.
import nbdev
nbdev.nbdev_export()