# Evaluation

> Functions for evaluating LLaVA models, including prediction generation and metric calculation.

In [1]:
#| default_exp evaluation

In [2]:
#| hide
from nbdev.showdoc import *

In [3]:
#| export
import sys
from pathlib import Path
import os
import gc
import json
import torch
import time
from tqdm.auto import tqdm
from typing import Union, Optional, Dict, Any # Added Dict, Any
import wandb # Added wandb import

from fastai.learner import Learner
from fastai.data.load import DataLoader
from transformers import AutoTokenizer
from transformers.modeling_outputs import CausalLMOutputWithPast

# The notebook is expected to be launched either from the project root or from a
# direct child directory (e.g. nbs/). `settings.ini` marks the root: step up one
# level only when the parent directory has it and the current directory does not.
project_root = Path(os.getcwd())
if (project_root.parent / 'settings.ini').exists() and not (project_root / 'settings.ini').exists():
    project_root = project_root.parent

# Put the resolved root at the front of sys.path (once) so the `llava` package imports below resolve.
project_root_str = str(project_root.resolve())
if project_root_str in sys.path:
    print(f"Project root already in sys.path: {project_root_str}")
else:
    print(f"Adding project root to sys.path: {project_root_str}")
    sys.path.insert(0, project_root_str)
    
# Import necessary llava components.
# If any of the project imports fails, the `except` branch below defines stand-ins
# under the SAME names so the rest of this module can still be defined and imported.
try:
    from llava.utils import load_config, init_wandb # Added init_wandb
    from llava.data.loading import get_test_dataloader, LLaVADataBlockStage2 # Assumes this function exists in 10_data_loading
    from llava.model.baseline import BaselineLLaVAModel # Or AdaptiveLLaVAModel later
    from llava.data.preprocessing import tokenizer, DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX_PLACEHOLDER # Import tokenizer & constants
    # PEFT is optional: `_peft_available` records whether `PeftModel` can be used in isinstance checks.
    try:
        from peft import PeftModel
        _peft_available = True
    except ImportError:
        PeftModel = None 
        _peft_available = False
    # from llava.training.stage2 import get_stage2_learner # Maybe needed for loading full learner state
except ImportError as e:
    print(f"Error importing llava modules: {e}. Make sure nbdev_export was run.")
    # Define placeholders if running standalone or during initial setup.
    # Each placeholder mirrors the name imported above; none of them does real work.
    def load_config(path): return {}
    def init_wandb(*args, **kwargs): pass
    def get_test_dataloader(config, dataset_name, dblock=None): raise NotImplementedError
    # Dummy model exposing the attributes/methods that `generate_predictions` checks for and calls
    # (vision_tower, projector, language_model, get_input_embeddings, encode_image, generate).
    class BaselineLLaVAModel(torch.nn.Module): 
        def __init__(self, *args, **kwargs): 
            super().__init__()
            self.dummy = torch.nn.Linear(1,1)
            self.image_token_index_marker = -200 # Define necessary attributes
            self.projector = torch.nn.Linear(10,10)
            # NOTE(review): the next two lines make the module reference itself; presumably this is
            # why `to()` and `eval()` are overridden below instead of using the nn.Module versions.
            self.language_model = self # Make model act like LLM for generate
            self.vision_tower = self # Dummy vision tower
            self.config = {} # Dummy config
        def encode_image(self, *args, **kwargs): return torch.randn(1,5,10) # Dummy image features B, P, D
        # Returns a freshly constructed (randomly initialised) embedding layer on every call.
        def get_input_embeddings(self): return torch.nn.Embedding(100, 10)
        def forward(self, *args, **kwargs): return CausalLMOutputWithPast(logits=torch.randn(1, 10, 100))
        # Random token ids of shape (1, max_new_tokens); defaults to 10 tokens when not specified.
        def generate(self, *args, **kwargs): return torch.randint(0, 100, (1, kwargs.get('max_new_tokens', 10)))
        def to(self, *args, **kwargs): return self # Avoid device errors in dummy
        # NOTE(review): returns None here, so `model.eval()` cannot be chained on the dummy.
        def eval(self): pass
    # Minimal tokenizer stand-in: only the attributes/methods used by this module are provided.
    class DummyTokenizer:
         def __init__(self): self.eos_token_id=1; self.pad_token_id=0
         def decode(self, *args, **kwargs): return "dummy decoded text"
    tokenizer = DummyTokenizer()
    DEFAULT_IMAGE_TOKEN = "<image>"
    IGNORE_INDEX = -100
    IMAGE_TOKEN_INDEX_PLACEHOLDER = -200
    # `run_evaluation` raises a RuntimeError when this placeholder is still None.
    LLaVADataBlockStage2 = None # Placeholder
    PeftModel = None
    _peft_available = False
    # def get_stage2_learner(config): raise NotImplementedError

Adding project root to sys.path: /workspace/llava


## Step 5.2: Prediction Generation

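This function takes a trained fastai `Learner` and a test `DataLoader`, generates a prediction for every sample, decodes the generated tokens into text, and saves the results to a JSON Lines file (one `{"id": ..., "prediction": ...}` object per line). It also logs efficiency metrics: average inference latency per sample and peak GPU memory.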

In [4]:
#| export
def _sample_id_for_index(test_dl, idx: int) -> str:
    """Best-effort lookup of a stable ID for dataset item `idx`.

    Returns the item's `sample_id` attribute when the dataset item exposes one,
    otherwise the fallback string ``"sample_<idx>"``. Any error while accessing
    the dataset is deliberately swallowed: a missing ID must never abort a long
    generation run.
    """
    fallback_id = f"sample_{idx}"
    try:
        dataset = getattr(test_dl, 'dataset', None)
        if dataset is not None and idx < len(dataset):
            return getattr(dataset[idx], 'sample_id', fallback_id)
    except Exception:
        pass  # Deliberate best-effort: fall back to the positional ID.
    return fallback_id


def generate_predictions(learner: Learner, 
                         test_dl: DataLoader, 
                         output_file: Union[str, Path], 
                         max_len: int = 200, # Max generation length
                         temperature: float = 0.2, # Generation temperature
                         top_p: Optional[float] = None, # Nucleus sampling top_p
                        ):
    """Generates predictions for a test dataloader and saves them to a JSON Lines file.
    Includes efficiency logging (latency, peak VRAM).

    Uses the HF generate method on the underlying language model component,
    manually preparing the combined image and text embeddings: the first image
    marker token of each prompt is replaced by the projected image features.

    Args:
        learner: Trained fastai Learner object containing the model.
                 Expected to have `learner.model`, `learner.dls.device`,
                 and potentially `learner.tokenizer` or `learner.dls.tokenizer`.
        test_dl: DataLoader for the test set. Batches are either dicts with
                 'pixel_values' and 'input_ids', or sequences whose first two
                 items are those tensors.
        output_file: Path to save the JSON Lines prediction file.
        max_len: Maximum number of new tokens to generate.
        temperature: Sampling temperature for generation. Set to 0 for greedy decoding.
        top_p: If set (and temperature > 0), use nucleus sampling with this top_p value.

    Returns:
        List of ``{"id": ..., "prediction": ...}`` dicts, one per written line.

    Raises:
        AttributeError: If no tokenizer can be found, or the model lacks one of the
            required components.
    """
    output_file = Path(output_file)
    output_file.parent.mkdir(parents=True, exist_ok=True)
    
    model = learner.model
    # Get tokenizer - check learner, then dls, then global scope
    if hasattr(learner, 'tokenizer') and learner.tokenizer is not None:
        tok = learner.tokenizer
    elif hasattr(learner.dls, 'tokenizer') and learner.dls.tokenizer is not None:
        tok = learner.dls.tokenizer
    else:
        global tokenizer # Use the global tokenizer imported/defined earlier
        if tokenizer is None:
             raise AttributeError("Tokenizer not found in learner, dls, or globally. Cannot decode predictions.")
        print("Warning: Using global tokenizer instance for generation.")
        tok = tokenizer
        
    # Ensure model components exist
    if not all(hasattr(model, attr) and getattr(model, attr) is not None 
               for attr in ['vision_tower', 'projector', 'language_model', 'get_input_embeddings', 'encode_image']):
        raise AttributeError("Model is missing required components (vision_tower, projector, language_model, etc.)")
        
    model.eval() # Set model to evaluation mode
    results = []
    total_time = 0.0
    num_samples = 0
    device = learner.dls.device # Get device from learner

    # Peak-memory statistics only exist for CUDA devices; asking for them with a CPU
    # device raises, so only track them when the learner actually runs on CUDA.
    track_cuda_memory = torch.cuda.is_available() and (
        device is None or torch.device(device).type == 'cuda')

    # --- Loop-invariant generation setup --- 
    marker = getattr(model, 'image_token_index_marker', IMAGE_TOKEN_INDEX_PLACEHOLDER)
    embed_tokens = model.get_input_embeddings()
    llm_component = model.language_model
    if _peft_available and isinstance(llm_component, PeftModel):
         llm_component = llm_component.base_model.model 

    use_sampling = temperature > 0
    base_gen_params = {
        "max_new_tokens": max_len,
        "eos_token_id": tok.eos_token_id,
        "pad_token_id": tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id,
        "do_sample": use_sampling,
        "num_beams": 1,
        # Greedy decoding keeps a neutral temperature of 1.0.
        "temperature": temperature if use_sampling else 1.0,
    }
    if use_sampling and top_p is not None:
        base_gen_params["top_p"] = top_p

    print(f"Generating predictions for {len(test_dl.dataset)} samples...")
    print(f"Saving predictions to: {output_file}")
    
    # Reset CUDA memory stats before generation loop for accurate peak measurement
    if track_cuda_memory:
        torch.cuda.reset_peak_memory_stats(device=device) 
        print("Reset CUDA peak memory stats before generation.")

    warned_non_dict_batch = False
    # Use context manager for file writing
    with torch.no_grad(), open(output_file, 'w') as f_out:
        # Iterate through batches
        for batch in tqdm(test_dl, desc="Generating Predictions"):
            start_time = time.perf_counter() # Start timing for this batch
            
            # Move batch items to appropriate device
            if isinstance(batch, dict):
                pixel_values = batch['pixel_values'].to(device)
                input_ids = batch['input_ids'].to(device) 
            else:
                if not warned_non_dict_batch:
                    print(f"Warning: Expected batch to be a dict, got {type(batch)}. Attempting to proceed assuming basic structure.")
                    warned_non_dict_batch = True
                pixel_values = batch[0].to(device)
                input_ids = batch[1].to(device)
            batch_size = pixel_values.shape[0]
            # Position of this batch's first sample. A running counter is used because
            # fastai's `offs`/`num_workers` are 0/1 in the main process and would map
            # every batch to index 0.
            # Assumes test_dl yields samples in dataset order (no shuffling) - TODO confirm.
            batch_start_idx = num_samples

            # --- Generation using underlying HF LLM --- 
            # 1. Encode images and project features
            image_features = model.encode_image(pixel_values) # (B, P, D_clip)
            if image_features is None: 
                print("Warning: Image encoding failed for batch. Skipping.")
                num_samples += batch_size # Account for skipped samples in latency calculation
                total_time += (time.perf_counter() - start_time) # Add time spent
                continue # Skip this batch
            projected_features = model.projector(image_features) # (B, P, D_llm)
            
            outputs_list = []
            # Process each sample in the batch individually for embedding preparation
            for i in range(batch_size):
                current_input_ids = input_ids[i:i+1] 
                current_proj_features = projected_features[i:i+1]
                
                image_token_indices = torch.where(current_input_ids[0] == marker)[0]
                
                if len(image_token_indices) == 0:
                    print(f"Warning: Image token marker {marker} not found in sample {batch_start_idx + i}. Skipping generation.")
                    outputs_list.append(torch.tensor([tok.eos_token_id], device=device)) 
                    continue
                
                image_token_start_index = image_token_indices[0].item()

                # Prepare prompt embeddings. The marker is not a valid vocabulary index, so it is
                # mapped to 0 before the embedding lookup; only the first marker position is then
                # replaced by the image features.
                input_ids_no_marker = current_input_ids.clone()
                input_ids_no_marker[input_ids_no_marker == marker] = 0 
                text_embeddings = embed_tokens(input_ids_no_marker)
                
                prompt_embeds = torch.cat([
                    text_embeddings[:, :image_token_start_index],
                    current_proj_features.to(text_embeddings.device, dtype=text_embeddings.dtype),
                    text_embeddings[:, image_token_start_index + 1:]
                ], dim=1)
                
                # NOTE(review): the mask is all ones, so any padding positions present in
                # `input_ids` are attended to like real tokens - confirm the test collate
                # does not pad, or trim the prompt before this point.
                prompt_attention_mask = torch.ones(prompt_embeds.shape[:2], dtype=torch.long, device=device)
                
                try:
                    output_ids_gen = llm_component.generate(
                        inputs_embeds=prompt_embeds,
                        attention_mask=prompt_attention_mask,
                        **base_gen_params,
                    )
                    # When only `inputs_embeds` is supplied, HF `generate` returns just the newly
                    # generated token ids (there are no prompt ids to echo back), so the output
                    # must NOT be sliced by the prompt length - doing so discards the answer.
                    outputs_list.append(output_ids_gen[0]) 
                except Exception as gen_e:
                     print(f"Error during generation for sample {batch_start_idx + i}: {gen_e}")
                     outputs_list.append(torch.tensor([tok.eos_token_id], device=device)) # Fallback
            # --- End Generation for Batch --- #

            total_time += (time.perf_counter() - start_time)

            # --- Decode and Save Results for the Batch --- 
            # `outputs_list` holds exactly one entry per sample, in batch order.
            for i, gen_ids in enumerate(outputs_list):
                decoded_text = tok.decode(gen_ids.cpu(), skip_special_tokens=True).strip()
                result_entry = {
                    "id": _sample_id_for_index(test_dl, batch_start_idx + i), 
                    "prediction": decoded_text,
                }
                f_out.write(json.dumps(result_entry) + '\n')
                results.append(result_entry)
            
            num_samples += batch_size # Update total sample count

    # --- Log Efficiency Metrics (Step 5.4) --- 
    avg_latency = (total_time / num_samples) * 1000 if num_samples > 0 else 0
    print(f"Finished generation. Saved {len(results)} predictions to {output_file}")
    print(f"Average inference latency: {avg_latency:.2f} ms/sample")
    
    if wandb.run is not None:
        wandb.log({"eval/avg_inference_latency_ms": avg_latency})

    if track_cuda_memory:
        peak_vram_gb = torch.cuda.max_memory_allocated(device=device) / (1024**3) 
        print(f"Peak Inference VRAM used: {peak_vram_gb:.2f} GB")
        if wandb.run is not None:
            wandb.log({"eval/peak_inference_vram_gb": peak_vram_gb})
        # Reset peak stats after logging for this run
        torch.cuda.reset_peak_memory_stats(device=device) 
        
    return results # Return list of predictions

## Step 5.3: Integrate External Evaluation Scripts (Placeholders)

Define placeholder functions for calling external evaluation scripts. These will be implemented fully in later steps.

In [5]:
#| export
def evaluate_vqa(preds_file: Union[str, Path], gt_file: Union[str, Path], **kwargs):
    """Placeholder for VQAv2 scoring: reports its inputs and returns a fixed dummy score.

    No file is read yet. The returned "overall" value is a hard-coded stand-in (0.5),
    not a measurement, and it is also logged to W&B when a run is active.

    Args:
        preds_file: Path to the JSON prediction file (expected format: [{'question_id': id, 'answer': prediction}]).
        gt_file: Path to the ground truth annotation file (e.g., VQA v2 format).
        **kwargs: Additional arguments for the VQA eval API (e.g., version). Currently unused.

    Returns:
        Dictionary containing VQA scores (currently always {'overall': 0.5}).
    """
    header_lines = (
        f"--- VQA Evaluation (Placeholder) ---",
        f"Predictions file: {preds_file}",
        f"Ground Truth file: {gt_file}",
    )
    for header_line in header_lines:
        print(header_line)

    # --- TODO: Implement actual VQA evaluation logic (Step 5.5) --- 
    # 1. Ensure predictions file is in the correct format for the official VQA eval tool.
    # 2. Import or call the official VQA evaluation script/library.
    #    (Requires downloading the VQA evaluation tools: https://visualqa.org/evaluation.html)
    # Example structure:
    # try:
    #     from vqa_eval_tools.vqa import VQA
    #     from vqa_eval_tools.vqaEval import VQAEval
    #     
    #     vqa_ann = VQA(gt_file, None) # Load annotations
    #     vqa_pred = vqa_ann.loadRes(preds_file, gt_file) # Load predictions
    #     
    #     vqa_eval = VQAEval(vqa_ann, vqa_pred, n=2) # n=2 is standard for VQA
    #     vqa_eval.evaluate()
    #     
    #     results = {
    #         "overall": vqa_eval.accuracy['overall'],
    #         "yes/no": vqa_eval.accuracy['perAnswerType']['yes/no'],
    #         "number": vqa_eval.accuracy['perAnswerType']['number'],
    #         "other": vqa_eval.accuracy['perAnswerType']['other'],
    #     }
    #     print(f"Calculated VQA Accuracy: {results}")
    # except ImportError:
    #     print("Error: VQA evaluation tools not found. Please install them.")
    #     results = {"overall": 0.0} # Return dummy score
    # except Exception as e:
    #     print(f"Error during VQA evaluation: {e}")
    #     results = {"overall": 0.0} # Return dummy score
    # -------------------------------------------------------------

    print("Actual VQA evaluation logic not yet implemented.")
    results = {"overall": 0.5}  # Hard-coded placeholder until Step 5.5 lands

    # Mirror the score to W&B only when a run is active.
    wandb_is_active = wandb.run is not None
    if wandb_is_active:
        overall_score = results.get("overall", 0.0)
        wandb.log({"eval/vqa_score_overall": overall_score})
        # Per-answer-type scores ("yes/no", "number", "other") can be logged the same way
        # once the real evaluation produces them.

    return results

#| export
def evaluate_textvqa(preds_file: Union[str, Path], gt_file: Union[str, Path], **kwargs):
    """Placeholder for TextVQA/DocVQA ANLS scoring: reports its inputs and returns a fixed dummy score.

    No file is read yet. The returned "anls" value is a hard-coded stand-in (0.4),
    not a measurement, and it is also logged to W&B when a run is active.

    Args:
        preds_file: Path to the JSON prediction file (format depends on benchmark, often list of dicts with id/pred).
        gt_file: Path to the ground truth annotation file (e.g., TextVQA JSON format).
        **kwargs: Additional arguments for ANLS calculation (e.g., case sensitivity). Currently unused.

    Returns:
        Dictionary containing ANLS score (currently always {'anls': 0.4}).
    """
    header_lines = (
        f"--- TextVQA/DocVQA ANLS Evaluation (Placeholder) ---",
        f"Predictions file: {preds_file}",
        f"Ground Truth file: {gt_file}",
    )
    for header_line in header_lines:
        print(header_line)

    # --- TODO: Implement actual ANLS calculation (Step 5.6) --- 
    # 1. Load predictions and ground truth data.
    # 2. Implement the ANLS metric calculation.
    #    Reference: https://rrc.cvc.uab.es/?ch=17&com=tasks (DocVQA 2021 Task 3)
    #    ANLS (Average Normalized Levenshtein Similarity) requires calculating Levenshtein distance
    #    and normalizing it based on string lengths, averaged over the dataset.
    #    Consider using existing libraries or implementing the formula carefully.
    # Example structure:
    # try:
    #     with open(preds_file, 'r') as f_pred, open(gt_file, 'r') as f_gt:
    #         predictions = json.load(f_pred) # Or load line by line if jsonl
    #         ground_truth = json.load(f_gt)
    #     
    #     # Preprocess/align predictions and ground truth based on IDs
    #     gt_dict = {item['questionId']: item['answers'] for item in ground_truth['data']}
    #     pred_dict = {item['id']: item['prediction'] for item in predictions} # Match ID key
    #     
    #     total_anls = 0
    #     count = 0
    #     for qid, prediction in pred_dict.items():
    #         if str(qid) in gt_dict: # Ensure ID matching (might need type conversion)
    #             gt_answers = gt_dict[str(qid)]
    #             # Calculate ANLS for this sample (max over multiple GT answers)
    #             sample_anls = calculate_single_anls(prediction, gt_answers)
    #             total_anls += sample_anls
    #             count += 1
    #             
    #     final_anls = total_anls / count if count > 0 else 0
    #     results = {"anls": final_anls}
    #     print(f"Calculated ANLS: {final_anls:.4f}")
    # except Exception as e:
    #     print(f"Error during ANLS calculation: {e}")
    #     results = {"anls": 0.0} # Return dummy score
    # -------------------------------------------------------------

    print("Actual ANLS calculation logic not yet implemented.")
    results = {"anls": 0.4}  # Hard-coded placeholder until Step 5.6 lands

    # Mirror the score to W&B only when a run is active.
    wandb_is_active = wandb.run is not None
    if wandb_is_active:
        anls_score = results.get("anls", 0.0)
        wandb.log({"eval/anls": anls_score})

    return results

# Helper function placeholder for ANLS calculation (to be implemented in Step 5.6)
# def calculate_single_anls(prediction: str, ground_truths: List[str]) -> float:
#     """Calculates the ANLS score for a single prediction against multiple ground truths."""
#     # TODO: Implement Levenshtein distance and ANLS formula
#     return 0.0 # Placeholder

## Evaluation Script (Placeholder Structure)

In [6]:
#| export
def run_evaluation(config_path: Union[str, Path], 
                   model_checkpoint_path: Optional[Union[str, Path]] = None, 
                   dataset_name: str = 'vqav2_test', # Example default
                   output_filename: Optional[str] = None,
                   **kwargs # Allow passing generation args like temp, top_p
                   ):
    """Runs the evaluation pipeline: load model, generate predictions, evaluate.

    Errors raised inside the pipeline are caught, printed with a traceback, and
    (when a W&B run is active) reported as a failed run; the function itself
    does not re-raise and returns None.
    
    Args:
        config_path: Path to the main configuration YAML.
        model_checkpoint_path: Path to the specific model checkpoint base name/directory to load 
                               (e.g., '/path/to/output/models/stage2_llava_lora'). 
                               If None, uses the path derived from config['paths']['stage2_model_weights'].
                               Expects associated files like '_projector_final.pth' and '_lora_adapters'/'_full.pth'.
        dataset_name: Name of the dataset split to evaluate (e.g., 'vqav2_test', 'textvqa_val').
                      This key should exist in `config['paths']` pointing to dataset info
                      (a mapping with an 'annotations' entry, relative to config['paths']['data_base']).
        output_filename: Name for the prediction output file (defaults based on model/dataset).
        **kwargs: Additional arguments passed to `generate_predictions`
                  (only 'max_len', 'temperature' and 'top_p' are forwarded).
    """
    # `DataLoaders` is needed for the Learner shell below and is not imported at module level.
    from fastai.data.core import DataLoaders

    print(f"--- Starting Evaluation Run --- ")
    print(f"Config: {config_path}")
    print(f"Dataset: {dataset_name}")
    
    config = load_config(config_path)
    output_dir = Path(config['paths']['output_dir'])
    eval_output_dir = output_dir / 'eval_results' / dataset_name
    eval_output_dir.mkdir(parents=True, exist_ok=True)
    models_dir = output_dir / 'models' # Where models are saved
    
    learner = None
    run = None # Initialize wandb run object
    try:
        # --- Initialize W&B --- 
        # Init W&B here so the VRAM/latency logs emitted by generate_predictions are captured
        if config.get('logging', {}).get('wandb', {}).get('enabled', False):
             # Create a unique run name for evaluation
             model_id_for_run = Path(model_checkpoint_path or config['paths']['stage2_model_weights']).stem
             run_name = f"eval_{model_id_for_run}_{dataset_name}_{time.strftime('%Y%m%d_%H%M%S')}"
             run = init_wandb(config, job_type="evaluation", run_name=run_name)
        
        # 1. Determine Model Path Base Name
        if model_checkpoint_path is None:
            model_base_name = Path(config['paths']['stage2_model_weights']).stem
            model_load_base = models_dir / model_base_name
            print(f"Using model base name from config: {model_base_name}")
        else:
            # If a specific path is given, use its stem as the base name
            model_checkpoint_path = Path(model_checkpoint_path)
            model_base_name = model_checkpoint_path.stem 
            # Assume the path points to the base (e.g., stage2_llava_lora), not the specific file
            model_load_base = model_checkpoint_path 
            print(f"Using provided model base path: {model_load_base}")
        
        # 2. Load Model Components Manually
        print("Loading model components for evaluation...")
        # Pass the config to the model constructor
        model = BaselineLLaVAModel(config=config) 
        
        # Load projector weights
        proj_weights_path = model_load_base.parent / (model_load_base.name + "_projector_final.pth")
        if proj_weights_path.exists():
            # NOTE(security): torch.load unpickles the file. Only load checkpoints from trusted
            # sources (or pass weights_only=True on torch versions that support it).
            model.projector.load_state_dict(torch.load(proj_weights_path, map_location='cpu'))
            print(f"Loaded projector weights from {proj_weights_path}")
        else:
            print(f"Warning: Projector weights not found at {proj_weights_path}. Using initial weights.")
            
        # Load LoRA adapters or full LLM weights
        use_lora_config = config.get('model', {}).get('peft', {}).get('use_lora', False)
        lora_adapter_path = model_load_base.parent / (model_load_base.name + "_lora_adapters")
        # full_model_path = model_load_base.parent / (model_load_base.name + ".pth") # Example for full save

        if use_lora_config:
            if lora_adapter_path.exists() and _peft_available and isinstance(model.language_model, PeftModel):
                print(f"Loading LoRA adapters from {lora_adapter_path}")
                # Assumes BaselineLLaVAModel's init already wrapped the LLM with get_peft_model,
                # so the saved adapter weights are loaded into the existing "default" adapter.
                try:
                    if hasattr(model.language_model, 'load_adapter'):
                         model.language_model.load_adapter(str(lora_adapter_path), adapter_name="default")
                         print("LoRA adapters loaded successfully.")
                    else:
                         print("Warning: Model's language_model does not have 'load_adapter' method. PEFT setup might be incorrect.")
                except Exception as e:
                    print(f"Error loading LoRA adapters: {e}. Model LLM may not be correctly wrapped or adapters mismatch.")
            else:
                print(f"Warning: LoRA enabled in config but adapters not found at {lora_adapter_path} or PEFT issue.")
        else:
             print("LoRA not used. Assuming full LLM fine-tuning.")
             print("Warning: Loading full fine-tuned LLM state is not fully implemented here yet. Model may use base weights.")
             # TODO: Implement loading full model state if needed (e.g., from SaveModelCallback output)
             # if full_model_path.exists():
             #    learner_state = torch.load(full_model_path, map_location='cpu')
             #    model.load_state_dict(learner_state['model'])
             #    print(f"Loaded full model state from {full_model_path}")
             
        # 3. Create Learner Shell for Convenience
        # Use a dummy dataloader just to hold the device context etc.
        dummy_dl = DataLoader(list(range(1)), bs=1) # Minimal dataloader
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        # Create a dummy DataLoaders object to satisfy Learner init
        dummy_dls = DataLoaders(dummy_dl, dummy_dl) 
        dummy_dls.device = device # Set device on DataLoaders
        learner = Learner(dls=dummy_dls, model=model, loss_func=lambda x,y: 0) # Dummy loss
        
        # Manually assign tokenizer if generate_predictions needs it from learner
        global tokenizer 
        if tokenizer is None:
             raise RuntimeError("Tokenizer not loaded, cannot proceed with evaluation.")
        learner.tokenizer = tokenizer 
        print(f"Minimal Learner shell created. Model moved to {device}.")
        
        # 4. Load Test Data
        print(f"Loading test data for {dataset_name}...")
        # Need to instantiate the correct DataBlock (e.g., Stage 2 for VQA/TextVQA)
        # Assuming Stage 2 structure is suitable for standard eval datasets
        if LLaVADataBlockStage2 is None:
             raise RuntimeError("LLaVADataBlockStage2 is not defined. Cannot load test data.")
        test_dl = get_test_dataloader(config, dataset_name, dblock=LLaVADataBlockStage2)
        if test_dl is None:
            raise RuntimeError(f"Failed to load test dataloader for {dataset_name}")

        # 5. Generate Predictions
        if output_filename is None:
            # Use the resolved model base name for the output file
            output_filename = f"preds_{model_base_name}_{dataset_name}.jsonl"
        preds_file = eval_output_dir / output_filename
        
        # Forward only the generation options that generate_predictions understands
        gen_kwargs = {k: v for k, v in kwargs.items() if k in ['max_len', 'temperature', 'top_p']}
        
        # generate_predictions now handles latency/VRAM logging
        generate_predictions(learner, test_dl, preds_file, **gen_kwargs)

        # 6. Run Evaluation Metrics
        print("Running evaluation metrics...")
        # Determine Ground Truth file path based on dataset_name.
        # The entry must be a mapping: a plain string would pass a substring check for
        # 'annotations' and then fail on indexing, so its type is checked explicitly.
        gt_config = config['paths'].get(dataset_name)
        if not isinstance(gt_config, dict) or 'annotations' not in gt_config:
            print(f"Warning: Ground truth annotation path not found for '{dataset_name}' in config. Cannot run metrics.")
        else:
            gt_file_path = Path(config['paths']['data_base']) / gt_config['annotations']
            if not gt_file_path.exists():
                 print(f"Warning: Ground truth file not found at {gt_file_path}. Cannot run metrics.")
            else:
                 # Call appropriate evaluation function based on dataset name convention
                 if 'vqav2' in dataset_name.lower():
                     vqa_results = evaluate_vqa(preds_file, gt_file_path)
                     print(f"VQAv2 Results: {vqa_results}")
                 elif 'textvqa' in dataset_name.lower():
                     anls_results = evaluate_textvqa(preds_file, gt_file_path)
                     print(f"TextVQA ANLS Results: {anls_results}")
                 # --- Add more evaluation cases here --- #
                 # elif 'docvqa' in dataset_name.lower():
                 #     docvqa_results = evaluate_textvqa(preds_file, gt_file_path) # Often uses ANLS too
                 #     print(f"DocVQA ANLS Results: {docvqa_results}")
                 # elif 'chartqa' in dataset_name.lower():
                 #     # chartqa_results = evaluate_chartqa(preds_file, gt_file_path)
                 #     print("ChartQA evaluation not implemented yet.")
                 # elif 'custom_eval' in dataset_name.lower():
                 #     # custom_results = evaluate_custom(preds_file, gt_file_path)
                 #     print("Custom eval set evaluation not implemented yet.")
                 else:
                     print(f"No specific evaluation script configured for dataset: {dataset_name}")

    except Exception as e:
        print(f"An error occurred during evaluation: {e}")
        import traceback
        traceback.print_exc()
        # Set exit code for W&B if error occurs
        if run: run.finish(exit_code=1)
    finally:
        # Clean up memory
        if learner is not None and hasattr(learner, 'model') and learner.model is not None:
            learner.model.to('cpu')
            del learner.model
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            print("Cleaned up model memory.")
        # Finish W&B run if active and not already finished due to error
        if run and wandb.run and wandb.run.id == run.id:
             run.finish()
             print("Finished W&B run.")
            
    print(f"--- Evaluation Run Complete --- ")

In [7]:
#| hide
# Run nbdev's export step so the `#| export` cells above are written out as library code.
import nbdev; nbdev.nbdev_export()