# Stage 2 Training: Instruction Fine-tuning

> Sets up and runs the second stage of LLaVA training: fine-tuning the LLM (using LoRA) and the projector on instruction-following data.
> Also includes setup for training the Adaptive model variant and handling ablation configurations.

In [None]:
#| default_exp training.stage2

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import sys
from pathlib import Path
import os
import gc # For memory cleanup
import argparse # For command-line execution
import time # For timing
import traceback # For detailed error printing
import copy # For deep copying config

# Locate the project root so the `llava` package is importable.
# The root is the directory holding `settings.ini`; notebooks normally run from
# `nbs/`, one level below it, so step up once when only the parent has the file.
project_root = Path(os.getcwd())
if (project_root.parent / 'settings.ini').exists() and not (project_root / 'settings.ini').exists():
    project_root = project_root.parent

# When executed as a plain script (outside IPython) from a `scripts/` directory,
# the project root is one level up.
if __name__ == "__main__" and "get_ipython" not in locals() and project_root.name == 'scripts':
    project_root = project_root.parent

project_root_str = str(project_root.resolve())

# Prepend the root to sys.path (only once) so local imports take precedence.
if project_root_str not in sys.path:
    print(f"Adding project root to sys.path: {project_root_str}")
    sys.path.insert(0, project_root_str)

Adding project root to sys.path: /workspace/llava


In [None]:
#| export
import torch
import torch.nn as nn
from fastai.learner import Learner
from fastai.optimizer import AdamW
from fastai.callback.wandb import WandbCallback
from fastai.callback.schedule import fit_one_cycle
from fastai.callback.save import SaveModelCallback
from fastai.callback.training import GradientAccumulation
from fastai.callback.fp16 import MixedPrecision
from fastai.vision.all import params # For splitter
from fastai.text.all import Perplexity # Import Perplexity metric
from fastai.data.core import DataLoaders
from functools import partial
import wandb # Import wandb directly for cleanup
import json # For dummy data creation
import PIL.Image # For dummy data creation
from typing import List, Optional, Type # Added Type

# Attempt to import peft and record whether LoRA support is available.
# Only `PeftModel` is needed here. Requesting a name peft does not export
# (previously `save_adapter`) made this import raise ImportError even with peft
# installed, which silently disabled LoRA handling in the Stage 2 splitters.
try:
    from peft import PeftModel
    _peft_available = True
except ImportError:
    print("Warning: peft library not found. LoRA functionality will be disabled.", file=sys.stderr)
    # Placeholder; every isinstance(..., PeftModel) check is guarded by `_peft_available`.
    PeftModel = None
    _peft_available = False

try:
    from llava.utils import load_config, init_wandb
    from llava.data.loading import get_stage2_dataloaders
    from llava.model.baseline import BaselineLLaVAModel
    from llava.model.adaptive import AdaptiveLLaVAModel # Added Adaptive model
    from llava.training.core import LLaVALoss
except ImportError as e:
     print(f"Error importing llava modules: {e}")
     print("Ensure that nbdev_export has been run and the llava library is installed/accessible.")
     # In a script context, it's better to exit if core modules are missing
     if __name__ == "__main__" and "get_ipython" not in locals():
          sys.exit(1)

Project root already in sys.path: /workspace/llava




Loaded config from ../configs/config.yaml


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Successfully loaded CLIP image processor for: openai/clip-vit-large-patch14-336
CLIP Mean: [0.48145466, 0.4578275, 0.40821073]
CLIP Std: [0.26862954, 0.26130258, 0.27577711]
Fastai Normalize Transform: Normalize -- {'mean': tensor([[[[0.4815]],

         [[0.4578]],

         [[0.4082]]]]), 'std': tensor([[[[0.2686]],

         [[0.2613]],

         [[0.2758]]]]), 'axes': (0, 2, 3)}
(enc:1,dec:1)
Successfully loaded tokenizer for: lmsys/vicuna-7b-v1.5
Adding special token <image> to tokenizer.
Added 1 token(s). New vocab size: 32001
Using token ID for <image>: 32000
Tokenizer already has pad token: <unk> (ID: 0)
V1 template assistant role tokens: [1792, 29889]
LLaVABatchTransform initialized. Image Token ID: 32000, Pad Token ID: 0, Template: v1
LLaVADataBlockStage2 defined.
LLaVALoss initialized, ignoring index: -100


## Stage 2 Splitter Functions

In Stage 2, we fine-tune the `projector` and the `language_model`. If LoRA is enabled (`use_lora: true` in config), only the LoRA adapter parameters within the language model are trained. If LoRA is disabled, the entire language model is trained. The `vision_tower` is always kept frozen.

For the adaptive model, we also train the `patcher` module's parameters, but only those already marked trainable (`requires_grad=True`).

In [None]:
#| export
def llava_stage2_splitter(model: nn.Module):
    """Split `BaselineLLaVAModel` parameters for Stage 2 training.

    Trainable: the `projector`, plus from `language_model` either the LoRA
    adapter parameters (LoRA enabled and the LLM is a `PeftModel`) or every
    parameter (LoRA disabled). Frozen: the `vision_tower`.

    LoRA is read from `model.config['model']['peft']['use_lora']`.
    - LoRA requested but PEFT not importable: the LLM is frozen.
    - LoRA requested but the LLM is not a `PeftModel`: only LLM parameters that
      already have `requires_grad=True` are trained. The LLM is never fully
      unfrozen in this case.

    Side effect: sets `requires_grad` on the parameters of the components it
    handles, so autograd and the optimizer agree on what is trained.

    Args:
        model: An instance of `BaselineLLaVAModel`.

    Returns:
        A one-element list holding the list of trainable parameters
        (a single optimizer parameter group).

    Raises:
        ValueError: If no trainable parameters were collected.
    """
    projector_params = []
    llm_params = []

    print("Applying Stage 2 splitter (Baseline)..." + (" (PEFT Available)" if _peft_available else ""))

    # Projector parameters are always trained in Stage 2
    if hasattr(model, 'projector') and model.projector is not None:
        print("  - Collecting projector parameters (trainable).")
        for p in model.projector.parameters():
            p.requires_grad = True
        projector_params.extend(model.projector.parameters())
    else:
        print("Warning: Model has no projector attribute.")

    # Language model: LoRA adapters only, or the full model
    if hasattr(model, 'language_model') and model.language_model is not None:
        llm = model.language_model
        use_lora = model.config.get('model', {}).get('peft', {}).get('use_lora', False)

        if _peft_available and use_lora and isinstance(llm, PeftModel):
            print("  - LoRA enabled: Collecting LLM adapter parameters (trainable).")
            # PEFT keeps the base weights frozen and leaves only the adapters trainable.
            llm_params.extend(p for p in llm.parameters() if p.requires_grad)
        elif use_lora and not _peft_available:
            print("Warning: LoRA configured but PEFT library not found. Cannot train LoRA adapters. Freezing LLM.")
            for p in llm.parameters():
                p.requires_grad = False
        elif use_lora:
            # LoRA was requested but the LLM is not a PeftModel. Falling through to the
            # full fine-tuning branch would silently unfreeze every LLM weight, so keep
            # only the parameters the model already marked trainable.
            print("Warning: LoRA configured but language_model is not a PeftModel. "
                  "Training only LLM parameters already marked trainable.")
            llm_params.extend(p for p in llm.parameters() if p.requires_grad)
        else:
            print("  - LoRA disabled: Collecting all LLM parameters (trainable).")
            for p in llm.parameters():
                p.requires_grad = True  # Full fine-tuning of the LLM
            llm_params.extend(llm.parameters())
    else:
        print("Warning: Model has no language_model attribute.")

    # Vision tower parameters are always frozen in Stage 2
    if hasattr(model, 'vision_tower') and model.vision_tower is not None:
        print("  - Collecting vision tower parameters (frozen).")
        for p in model.vision_tower.parameters():
            p.requires_grad = False
    else:
        print("Warning: Model has no vision_tower attribute.")

    # Everything trainable goes into a single optimizer parameter group
    trainable_groups = projector_params + llm_params
    frozen_count = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    trainable_count = sum(p.numel() for p in trainable_groups)
    print(f"Splitter created groups: Trainable ({trainable_count} params), Frozen ({frozen_count} params)")

    if not trainable_groups:
        raise ValueError("Splitter function resulted in no trainable parameters. Check model structure and config.")

    return [trainable_groups]

In [None]:
#| export
def adaptive_llava_stage2_splitter(model: nn.Module):
    """Split `AdaptiveLLaVAModel` parameters for Stage 2 training.

    Trainable:
    - the `projector`;
    - from `language_model`, either the LoRA adapter parameters (LoRA enabled
      and the LLM is a `PeftModel`) or every parameter (LoRA disabled);
    - the `patcher` parameters that already have `requires_grad=True`.

    Frozen: the `vision_tower`.

    LoRA is read from `model.config['model']['peft']['use_lora']`.
    - LoRA requested but PEFT not importable: the LLM is frozen.
    - LoRA requested but the LLM is not a `PeftModel`: only LLM parameters that
      already have `requires_grad=True` are trained. The LLM is never fully
      unfrozen in this case.

    Side effect: sets `requires_grad` on the projector, LLM and vision tower
    parameters as described above.

    Args:
        model: An instance of `AdaptiveLLaVAModel`.

    Returns:
        A one-element list holding the list of trainable parameters
        (a single optimizer parameter group).

    Raises:
        ValueError: If no trainable parameters were collected.
    """
    projector_params = []
    llm_params = []
    patcher_params = []

    print("Applying Stage 2 splitter (Adaptive)..." + (" (PEFT Available)" if _peft_available else ""))

    # Projector parameters (always trained)
    if hasattr(model, 'projector') and model.projector is not None:
        print("  - Collecting projector parameters (trainable).")
        for p in model.projector.parameters():
            p.requires_grad = True
        projector_params.extend(model.projector.parameters())
    else:
        print("Warning: Model has no projector attribute.")

    # LLM parameters (LoRA adapters or the full model)
    if hasattr(model, 'language_model') and model.language_model is not None:
        llm = model.language_model
        use_lora = model.config.get('model', {}).get('peft', {}).get('use_lora', False)

        if _peft_available and use_lora and isinstance(llm, PeftModel):
            print("  - LoRA enabled: Collecting LLM adapter parameters (trainable).")
            # PEFT keeps the base weights frozen and leaves only the adapters trainable.
            llm_params.extend(p for p in llm.parameters() if p.requires_grad)
        elif use_lora and not _peft_available:
            print("Warning: LoRA configured but PEFT library not found. Freezing LLM.")
            for p in llm.parameters():
                p.requires_grad = False
        elif use_lora:
            # LoRA was requested but the LLM is not a PeftModel. Falling through to the
            # full fine-tuning branch would silently unfreeze every LLM weight, so keep
            # only the parameters the model already marked trainable.
            print("Warning: LoRA configured but language_model is not a PeftModel. "
                  "Training only LLM parameters already marked trainable.")
            llm_params.extend(p for p in llm.parameters() if p.requires_grad)
        else:
            print("  - LoRA disabled: Collecting all LLM parameters (trainable).")
            for p in llm.parameters():
                p.requires_grad = True  # Full fine-tuning of the LLM
            llm_params.extend(llm.parameters())
    else:
        print("Warning: Model has no language_model attribute.")

    # Patcher: only parameters already marked trainable are included
    if hasattr(model, 'patcher') and model.patcher is not None:
        patcher_params.extend(p for p in model.patcher.parameters() if p.requires_grad)
        if patcher_params:
            print("  - Collecting patcher parameters (trainable).")
        else:
            print("  - Patcher found, but has no trainable parameters.")
    else:
        print("  - No adaptive patcher found in model.")

    # Vision tower parameters are always frozen in Stage 2
    if hasattr(model, 'vision_tower') and model.vision_tower is not None:
        print("  - Collecting vision tower parameters (frozen).")
        for p in model.vision_tower.parameters():
            p.requires_grad = False
    else:
        print("Warning: Model has no vision_tower attribute.")

    # Everything trainable goes into a single optimizer parameter group
    trainable_groups = projector_params + llm_params + patcher_params
    frozen_count = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    trainable_count = sum(p.numel() for p in trainable_groups)
    print(f"Splitter created groups: Trainable ({trainable_count} params), Frozen ({frozen_count} params)")

    if not trainable_groups:
        raise ValueError("Splitter function resulted in no trainable parameters. Check model structure and config.")

    return [trainable_groups]

In [None]:
# show_doc(llava_stage2_splitter) # Omitted for script execution
# show_doc(adaptive_llava_stage2_splitter) # Omitted for script execution

## Step 4.3 & 7.1: Setup Learner Configuration (Stage 2 - Baseline & Adaptive)

These functions set up the `Learner` object for Stage 2 instruction tuning, handling both the baseline and adaptive model variants.

**Update for Step 8.1:** Modified `_get_stage2_learner_internal` to include the ablation name (taken from `ablation.force_patcher_strategy`) in the W&B run name when an ablation is active.

In [None]:
#| export
def get_stage2_learner(config: dict) -> Learner:
    """Return a Stage 2 instruction-tuning `Learner` for the baseline LLaVA model.

    Delegates to `_get_stage2_learner_internal`, selecting `BaselineLLaVAModel`
    as the model class and `llava_stage2_splitter` as the parameter splitter.
    Stage 1 projector weights are loaded when the checkpoint file exists.

    Args:
        config: Main configuration dictionary (paths, model, training, logging
            and ablation sections).

    Returns:
        A fastai `Learner` ready for Stage 2 training of the baseline model.

    Raises:
        RuntimeError: Propagated from the shared setup when the DataLoaders,
            the model, an existing Stage 1 projector checkpoint or the
            `Learner` cannot be created/loaded.
    """
    setup_kwargs = {
        'model_class': BaselineLLaVAModel,
        'splitter': llava_stage2_splitter,
    }
    return _get_stage2_learner_internal(config, **setup_kwargs)

#| export
def get_adaptive_stage2_learner(config: dict) -> Learner:
    """Return a Stage 2 instruction-tuning `Learner` for the adaptive LLaVA model.

    Delegates to `_get_stage2_learner_internal`, selecting `AdaptiveLLaVAModel`
    as the model class and `adaptive_llava_stage2_splitter` as the parameter
    splitter. That splitter also trains any trainable `patcher` parameters.
    Stage 1 projector weights are loaded when the checkpoint file exists.

    Args:
        config: Main configuration dictionary (paths, model, training, logging
            and ablation sections).

    Returns:
        A fastai `Learner` ready for Stage 2 training of the adaptive model.

    Raises:
        RuntimeError: Propagated from the shared setup when the DataLoaders,
            the model, an existing Stage 1 projector checkpoint or the
            `Learner` cannot be created/loaded.
    """
    setup_kwargs = {
        'model_class': AdaptiveLLaVAModel,
        'splitter': adaptive_llava_stage2_splitter,
    }
    return _get_stage2_learner_internal(config, **setup_kwargs)

#| export
# Internal function to handle common learner setup logic
def _get_stage2_learner_internal(config: dict, model_class: Type[nn.Module], splitter: callable) -> Learner:
    """Set up a Stage 2 `Learner` for the baseline or adaptive LLaVA model.

    Shared implementation behind `get_stage2_learner` and
    `get_adaptive_stage2_learner`. In order, it:
    - loads the Stage 2 DataLoaders;
    - instantiates `model_class(config)`;
    - loads Stage 1 projector weights from
      `<output_dir>/models/<stage1_projector_weights>` when that file exists;
    - builds the loss, AdamW optimizer, perplexity metric and callbacks
      (W&B, gradient accumulation, mixed precision) into a fastai `Learner`.

    Args:
        config: Main configuration dictionary. Reads `paths`, `model`,
            `training`, `logging.wandb` and `ablation.force_patcher_strategy`.
        model_class: Model class to instantiate with `config`.
        splitter: Parameter splitter handed to the `Learner`; it decides which
            parameters are trainable.

    Returns:
        The configured `Learner`.

    Raises:
        RuntimeError: If the DataLoaders, the model, an existing Stage 1
            projector checkpoint, or the `Learner` cannot be created/loaded.
        KeyError: If a required `config['paths']` entry is missing.
    """
    model_type_name = model_class.__name__
    is_adaptive = model_type_name == 'AdaptiveLLaVAModel'
    print(f"--- Setting up Stage 2 Learner ({model_type_name}) ---")
    output_dir = Path(config['paths']['output_dir'])
    output_dir.mkdir(parents=True, exist_ok=True)
    # An empty `ablation:` section in YAML loads as None; treat it as "no ablation".
    ablation_config = config.get('ablation') or {}
    ablation_name = ablation_config.get('force_patcher_strategy')  # e.g. 'baseline' or None
    training_cfg = config.get('training', {})
    wandb_cfg = config.get('logging', {}).get('wandb', {})

    # 1. Load Stage 2 DataLoaders
    print("Loading Stage 2 DataLoaders...")
    try:
        dls = get_stage2_dataloaders(config)
    except Exception as e:  # also covers FileNotFoundError for missing data files
        print(f"Error loading Stage 2 DataLoaders: {e}")
        raise RuntimeError("Failed to create Stage 2 DataLoaders.") from e
    if not dls:
        raise RuntimeError("Stage 2 DataLoaders object is None.")
    print(f"DataLoaders loaded. Train samples: {len(dls.train_ds)}, Valid samples: {len(dls.valid_ds)}")

    # 2. Instantiate Model
    print(f"Instantiating {model_type_name} for Stage 2..." + (f" (Ablation: {ablation_name})" if ablation_name else ""))
    try:
        # The config (possibly modified, e.g. by command-line overrides) drives model construction.
        model = model_class(config)
        if model.vision_tower is None or model.language_model is None or model.projector is None:
            raise RuntimeError(f"{model_type_name} initialization incomplete.")
        # The adaptive model should have built a patcher when the config enables one.
        if is_adaptive and config.get('model', {}).get('adaptive_patcher', {}).get('enabled', False):
            if getattr(model, 'patcher', None) is None:
                print("Warning: Adaptive patcher was enabled in config but failed to initialize in the model.")
        print("Model instantiated successfully.")
    except Exception as e:
        print(f"Error instantiating {model_type_name}: {e}")
        raise RuntimeError(f"Failed to instantiate {model_type_name} for Stage 2.") from e

    # 3. Load Stage 1 Projector Weights (optional: missing file only warns)
    stage1_weights_fname = config['paths'].get('stage1_projector_weights', 'stage1_projector.pth')
    stage1_weights_path = output_dir / 'models' / stage1_weights_fname
    print(f"Attempting to load Stage 1 projector weights from: {stage1_weights_path}")
    if stage1_weights_path.is_file():
        try:
            # NOTE: torch.load unpickles the file; only load checkpoints you trust.
            projector_state_dict = torch.load(stage1_weights_path, map_location='cpu')
            model.projector.load_state_dict(projector_state_dict)
            print(f"Successfully loaded Stage 1 projector weights from {stage1_weights_path}")
        except Exception as e:
            print(f"Error loading Stage 1 projector weights: {e}")
            raise RuntimeError(f"Failed to load expected Stage 1 projector weights from {stage1_weights_path}") from e
    else:
        print(f"Warning: Stage 1 projector weights not found at {stage1_weights_path}. Projector will use initial weights.")

    # 4. Define Loss Function
    loss_func = LLaVALoss()
    print(f"Loss function: {type(loss_func).__name__}")

    # 5. Define Optimizer
    lr = training_cfg.get('learning_rate_stage2', 2e-5)
    wd = training_cfg.get('weight_decay', 0.0)
    opt_func = partial(AdamW, lr=lr, wd=wd, eps=1e-8)
    print(f"Optimizer: AdamW (lr={lr}, wd={wd})")

    # 6. Splitter (passed in by the public wrapper)
    print(f"Parameter splitter: {splitter.__name__}")

    # Metrics
    metrics = [Perplexity()]
    print(f"Metrics: {[m.name for m in metrics]}")

    # 7. Define Callbacks
    cbs = []
    if wandb_cfg.get('enabled', False):
        wandb_entity = wandb_cfg.get('entity')
        if wandb_entity and 'your_wandb_entity' not in str(wandb_entity):
            run_name_prefix = wandb_cfg.get('run_name_prefix', 'stage2')
            # Run name encodes model type and ablation status
            model_tag = 'adaptive' if is_adaptive else 'baseline'
            ablation_tag = f"_abl-{ablation_name}" if ablation_name else ""
            stage2_model_name = Path(config['paths']['stage2_model_weights']).stem
            run_name = f"{run_name_prefix}_{model_tag}{ablation_tag}_{stage2_model_name}_{wandb.util.generate_id()}"
            init_wandb(config, job_type="stage2-training", run_name=run_name)
            cbs.append(WandbCallback(log_preds=False, log_model=False))
            print("Added WandbCallback.")
        else:
            print("W&B enabled in config, but entity not set or default. Skipping W&B init and callback.")

    # No SaveModelCallback is attached: the projector and adapters are saved
    # manually after training rather than as full Learner checkpoints.
    print("SaveModelCallback is configured but commented out. Manual saving of adapters/projector is preferred.")

    # Optimization Callbacks
    grad_accum_steps = training_cfg.get('gradient_accumulation_steps', 1)
    if grad_accum_steps > 1:
        # fastai's GradientAccumulation(n_acc) counts *samples*, not optimizer steps, so
        # convert steps to samples. Passing the step count directly meant no accumulation
        # at all whenever the batch size was >= the configured steps.
        n_acc = grad_accum_steps * dls.train.bs
        cbs.append(GradientAccumulation(n_acc=n_acc))
        print(f"Added GradientAccumulation callback with {grad_accum_steps} steps.")

    if training_cfg.get('use_mixed_precision', False):
        cbs.append(MixedPrecision())
        print("Added MixedPrecision callback.")

    # 8. Create Learner
    try:
        learner = Learner(
            dls=dls,
            model=model,
            loss_func=loss_func,
            opt_func=opt_func,
            splitter=splitter,
            cbs=cbs,
            metrics=metrics,
            path=output_dir,
            train_bn=False
        )
    except Exception as e:
        print(f"Error creating Stage 2 Learner ({model_type_name}): {e}")
        # Close a W&B run opened above so it is not left dangling.
        if wandb.run is not None: wandb.finish(exit_code=1)
        raise RuntimeError(f"Failed to create the Stage 2 Learner object ({model_type_name}).") from e

    print(f"--- Stage 2 Learner Setup Complete ({model_type_name}) ---")
    return learner

In [None]:
# show_doc(get_stage2_learner) # Omitted for script execution
# show_doc(get_adaptive_stage2_learner) # Omitted for script execution

#### Example Usage & Test (Learner Configuration - Stage 2)

In [None]:
#| test
import gc
import traceback
from fastai.callback.save import SaveModelCallback
from fastai.callback.wandb import WandbCallback
from fastai.callback.training import GradientAccumulation
from fastai.callback.fp16 import MixedPrecision
from fastai.text.all import Perplexity # Import Perplexity for testing
from llava.model.baseline import LLaVAProjector, BaselineLLaVAModel # Import required classes
from llava.model.adaptive import AdaptiveLLaVAModel # Import adaptive model

# Smoke test: build both Stage 2 learners (baseline and adaptive) on dummy data if needed.
# Heavy: each learner instantiates the vision tower and the language model.
baseline_learner = None
adaptive_learner = None

try:
    # Load config (relative path assumes the notebook runs from `nbs/`)
    config_path = '../configs/config.yaml'
    config = load_config(config_path)
    print(f"Loaded config from {config_path}")

    # --- Test Setup ---
    output_dir = Path(config['paths']['output_dir'])
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / 'models').mkdir(parents=True, exist_ok=True)

    # Create dummy Stage 1 weights if they don't exist
    stage1_weights_fname = config['paths'].get('stage1_projector_weights', 'stage1_projector.pth')
    stage1_weights_path = output_dir / 'models' / stage1_weights_fname
    if not stage1_weights_path.is_file():
        print(f"Creating dummy Stage 1 projector weights: {stage1_weights_path}")
        proj_input_dim = config['model']['projector']['input_dim']
        proj_output_dim = config['model']['projector']['output_dim']
        dummy_projector = LLaVAProjector(proj_input_dim, proj_output_dim)
        torch.save(dummy_projector.state_dict(), stage1_weights_path)
        del dummy_projector

    # Create dummy Stage 2 data if needed
    data_base = Path(config['paths']['data_base'])
    stage2_json_rel = Path(config['paths']['stage2_data'])
    stage1_img_rel = Path(config['paths']['stage1_images'])  # Dummy images are written here
    stage1_img_path = data_base / stage1_img_rel
    stage2_json_path = data_base / stage2_json_rel
    stage2_json_path.parent.mkdir(parents=True, exist_ok=True)
    stage1_img_path.mkdir(parents=True, exist_ok=True)

    # Image paths in the JSONL are resolved relative to `data_base`, so they must use the
    # full relative image directory. Using only its last component breaks for nested dirs.
    img1_rel_path_str = (stage1_img_rel / 'dummy_img1.jpg').as_posix()
    img2_rel_path_str = (stage1_img_rel / 'dummy_img2.png').as_posix()
    if not (stage1_img_path / 'dummy_img1.jpg').exists():
        PIL.Image.new('RGB', (60, 30), color='red').save(stage1_img_path / 'dummy_img1.jpg')
    if not (stage1_img_path / 'dummy_img2.png').exists():
        PIL.Image.new('RGB', (60, 30), color='green').save(stage1_img_path / 'dummy_img2.png')

    # Treat a missing or near-empty file as "no data" and write a two-sample JSONL
    if not stage2_json_path.exists() or stage2_json_path.stat().st_size < 10:
        print(f"Creating dummy Stage 2 JSONL: {stage2_json_path}")
        dummy_jsonl_content = [
            {"id": "s2_001", "image": img1_rel_path_str, "conversations": [{"from": "human", "value": "<image>\nDescribe image."}, {"from": "gpt", "value": "It is a red object."}]},
            {"id": "s2_002", "image": img2_rel_path_str, "conversations": [{"from": "human", "value": "<image>\nIs it green?"}, {"from": "gpt", "value": "Yes, it appears green."}]},
        ]
        with open(stage2_json_path, 'w') as f:
            for item in dummy_jsonl_content:
                f.write(json.dumps(item) + '\n')
    # -----------------

    # Modify config for test
    if 'model' not in config: config['model'] = {}
    if 'peft' not in config['model']: config['model']['peft'] = {}
    config['model']['peft']['use_lora'] = True  # Enable LoRA for splitter test
    config['model']['use_activation_checkpointing'] = False  # Keep disabled for test
    if 'adaptive_patcher' not in config['model']: config['model']['adaptive_patcher'] = {}
    config['model']['adaptive_patcher']['enabled'] = True  # Enable adaptive for test
    config['model']['adaptive_patcher']['strategy'] = 'variable_resolution'
    # Ensure ablation is not forced for this test
    if 'ablation' not in config: config['ablation'] = {}
    config['ablation']['force_patcher_strategy'] = None

    # Disable W&B for testing unless entity is properly set
    wandb_enabled = config.get('logging', {}).get('wandb', {}).get('enabled', False)
    wandb_entity = config.get('logging', {}).get('wandb', {}).get('entity')
    if wandb_enabled and (wandb_entity is None or 'your_wandb_entity' in str(wandb_entity)):
        print("Warning: W&B is enabled but entity is not set or default. Disabling W&B for this test.")
        if 'logging' not in config: config['logging'] = {}
        if 'wandb' not in config['logging']: config['logging']['wandb'] = {}
        config['logging']['wandb']['enabled'] = False

    # --- Test Baseline Learner ---
    baseline_learner = get_stage2_learner(config)
    print("Baseline Stage 2 Learner created successfully.")
    assert isinstance(baseline_learner, Learner)

    # --- Test Adaptive Learner ---
    adaptive_learner = get_adaptive_stage2_learner(config)
    print("Adaptive Stage 2 Learner created successfully.")
    assert isinstance(adaptive_learner, Learner)
    assert isinstance(adaptive_learner.model, AdaptiveLLaVAModel)
    assert adaptive_learner.model.patcher is not None  # Check patcher exists

    print("\nStage 2 Learner setup test passed.")

except FileNotFoundError as e:
    print(f"Skipping test: FileNotFoundError - {e}")
except ImportError as e:
    print(f"Skipping test: ImportError - {e}. (Likely `peft` is missing)")
except RuntimeError as e:
    print(f"Skipping test: RuntimeError - {e}. (Likely CUDA OOM or setup issue)")
except Exception as e:
    print(f"An error occurred during Stage 2 learner setup test: {e}")
    traceback.print_exc()
finally:
    def cleanup_learner(learner: Optional[Learner], name: str):
        """Move the learner's model components to CPU, drop its model and destroy it."""
        if learner is None:
            return
        try:
            if hasattr(learner, 'model') and learner.model is not None:
                for attr in ('vision_tower', 'language_model', 'projector', 'patcher'):
                    component = getattr(learner.model, attr, None)
                    if component is not None:
                        component.to('cpu')
                del learner.model
            learner.destroy()
            print(f"Cleaned up {name} learner and model memory.")
        except Exception as e:
            print(f"Error during {name} learner cleanup: {e}")

    cleanup_learner(baseline_learner, "baseline_learner")
    cleanup_learner(adaptive_learner, "adaptive_learner")
    # `cleanup_learner` only gets a reference; drop the module-level names as well so the
    # Learner objects (DataLoaders, callbacks, optimizer state) can actually be collected.
    baseline_learner = adaptive_learner = None

    gc.collect()
    if torch.cuda.is_available(): torch.cuda.empty_cache()

    if wandb.run is not None:
        try:
            if wandb.run.id: wandb.finish()
        except Exception as e:
            print(f"Error finishing W&B run: {e}")

Loaded config from ../configs/config.yaml
Creating dummy Stage 1 projector weights: /workspace/llava/output/models/stage1_projector.pth
Initializing Projector: Input Dim=1024, Output Dim=4096
Creating dummy Stage 2 JSONL: /workspace/llava/data/llava_instruct_150k/llava_v1_5_mix665k.jsonl
--- Setting up Stage 2 Learner (BaselineLLaVAModel) ---
Loading Stage 2 DataLoaders...
Creating Stage 2 DataLoaders with batch size: 4, num_workers: 4
Loading Stage 2 items from: /workspace/llava/data/llava_instruct_150k/llava_v1_5_mix665k.jsonl
Assuming images relative to: /workspace/llava/data
Found 2 samples for Stage 2.
Stage 2 DataLoaders created successfully.
DataLoaders loaded. Train samples: 1, Valid samples: 1
Instantiating BaselineLLaVAModel for Stage 2...
Initializing Projector: Input Dim=1024, Output Dim=4096
Loading Vision Tower: openai/clip-vit-large-patch14-336...


Some weights of the model checkpoint at openai/clip-vit-large-patch14-336 were not used when initializing CLIPVisionModel: ['text_model.encoder.layers.0.layer_norm1.bias', 'text_model.encoder.layers.0.layer_norm1.weight', 'text_model.encoder.layers.0.layer_norm2.bias', 'text_model.encoder.layers.0.layer_norm2.weight', 'text_model.encoder.layers.0.mlp.fc1.bias', 'text_model.encoder.layers.0.mlp.fc1.weight', 'text_model.encoder.layers.0.mlp.fc2.bias', 'text_model.encoder.layers.0.mlp.fc2.weight', 'text_model.encoder.layers.0.self_attn.k_proj.bias', 'text_model.encoder.layers.0.self_attn.k_proj.weight', 'text_model.encoder.layers.0.self_attn.out_proj.bias', 'text_model.encoder.layers.0.self_attn.out_proj.weight', 'text_model.encoder.layers.0.self_attn.q_proj.bias', 'text_model.encoder.layers.0.self_attn.q_proj.weight', 'text_model.encoder.layers.0.self_attn.v_proj.bias', 'text_model.encoder.layers.0.self_attn.v_proj.weight', 'text_model.encoder.layers.1.layer_norm1.bias', 'text_model.enco

Vision Tower loaded successfully.
Vision Tower weights frozen.
Loading Language Model: lmsys/vicuna-7b-v1.5...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Language Model loaded successfully.
Base Language Model weights frozen.




LLM embedding size already matches tokenizer size. No resizing needed.
Activation checkpointing is disabled in the configuration.
Model instantiated successfully.
Attempting to load Stage 1 projector weights from: /workspace/llava/output/models/stage1_projector.pth
Successfully loaded Stage 1 projector weights from /workspace/llava/output/models/stage1_projector.pth
Loss function: LLaVALoss
Optimizer: AdamW (lr=2e-05, wd=0.0)
Parameter splitter: llava_stage2_splitter
Metrics: ['perplexity']
W&B enabled in config, but entity not set or default. Skipping W&B init and callback.
SaveModelCallback is configured but commented out. Manual saving of adapters/projector is preferred.
Added GradientAccumulation callback with 4 steps.
Added MixedPrecision callback.
Applying Stage 2 splitter (Baseline)...
  - Collecting projector parameters (trainable).




  - Collecting vision tower parameters (frozen).
Splitter created groups: Trainable (33554432 params), Frozen (6891018240 params)
--- Stage 2 Learner Setup Complete (BaselineLLaVAModel) ---
Baseline Stage 2 Learner created successfully.

--- Setting up Stage 2 Learner (AdaptiveLLaVAModel) ---
Loading Stage 2 DataLoaders...
Creating Stage 2 DataLoaders with batch size: 4, num_workers: 4
Loading Stage 2 items from: /workspace/llava/data/llava_instruct_150k/llava_v1_5_mix665k.jsonl
Assuming images relative to: /workspace/llava/data
Found 2 samples for Stage 2.
Stage 2 DataLoaders created successfully.
DataLoaders loaded. Train samples: 1, Valid samples: 1
Instantiating AdaptiveLLaVAModel for Stage 2...
Initializing Projector: Input Dim=1024, Output Dim=4096
Loading Vision Tower: openai/clip-vit-large-patch14-336...


Some weights of the model checkpoint at openai/clip-vit-large-patch14-336 were not used when initializing CLIPVisionModel: ['text_model.encoder.layers.0.layer_norm1.bias', 'text_model.encoder.layers.0.layer_norm1.weight', 'text_model.encoder.layers.0.layer_norm2.bias', 'text_model.encoder.layers.0.layer_norm2.weight', 'text_model.encoder.layers.0.mlp.fc1.bias', 'text_model.encoder.layers.0.mlp.fc1.weight', 'text_model.encoder.layers.0.mlp.fc2.bias', 'text_model.encoder.layers.0.mlp.fc2.weight', 'text_model.encoder.layers.0.self_attn.k_proj.bias', 'text_model.encoder.layers.0.self_attn.k_proj.weight', 'text_model.encoder.layers.0.self_attn.out_proj.bias', 'text_model.encoder.layers.0.self_attn.out_proj.weight', 'text_model.encoder.layers.0.self_attn.q_proj.bias', 'text_model.encoder.layers.0.self_attn.q_proj.weight', 'text_model.encoder.layers.0.self_attn.v_proj.bias', 'text_model.encoder.layers.0.self_attn.v_proj.weight', 'text_model.encoder.layers.1.layer_norm1.bias', 'text_model.enco

Vision Tower loaded successfully.
Vision Tower weights frozen.
Loading Language Model: lmsys/vicuna-7b-v1.5...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Language Model loaded successfully.
Base Language Model weights frozen.




LLM embedding size already matches tokenizer size. No resizing needed.
Activation checkpointing is disabled in the configuration.
Adaptive Patcher enabled with strategy: 'variable_resolution' (VariableResolutionPatcher)
Initialized VariableResolutionPatcher with grid options: [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008], [336, 336]]
Model instantiated successfully.
Attempting to load Stage 1 projector weights from: /workspace/llava/output/models/stage1_projector.pth
Successfully loaded Stage 1 projector weights from /workspace/llava/output/models/stage1_projector.pth
Loss function: LLaVALoss
Optimizer: AdamW (lr=2e-05, wd=0.0)
Parameter splitter: adaptive_llava_stage2_splitter
Metrics: ['perplexity']
W&B enabled in config, but entity not set or default. Skipping W&B init and callback.
SaveModelCallback is configured but commented out. Manual saving of adapters/projector is preferred.
Added GradientAccumulation callback with 4 steps.
Added MixedPrecision callback.
Apply



  - Patcher found, but has no trainable parameters.
  - Collecting vision tower parameters (frozen).
Splitter created groups: Trainable (33554432 params), Frozen (6891018240 params)
--- Stage 2 Learner Setup Complete (AdaptiveLLaVAModel) ---
Adaptive Stage 2 Learner created successfully.

Stage 2 Learner setup test passed.
Cleaned up baseline_learner learner and model memory.
Cleaned up adaptive_learner learner and model memory.


## Step 4.5 & 7.2: Implement Stage 2 Training Script (Baseline & Adaptive)

This function orchestrates the Stage 2 training loop, including saving the final weights (projector + LoRA adapters). It can now handle both baseline and adaptive model training.

**Update for Step 8.1:** Modified `_train_stage2_internal` to include ablation name in the saved filenames if an ablation is active.

In [None]:
#| export
def train_stage2(config_path: str | Path, ablation_mode: Optional[str] = None):
    """Train the baseline LLaVA model for Stage 2 (instruction fine-tuning).

    Delegates to `_train_stage2_internal` with `get_stage2_learner` as the learner
    factory: the config is loaded, the baseline learner is built and trained,
    training time and peak VRAM are reported, and the projector and LoRA adapter
    weights (if LoRA was applied) are saved as separate files.

    Args:
        config_path: Path to the YAML configuration file.
        ablation_mode: Optional ablation mode (e.g. 'baseline'). When provided it
            takes precedence over the value in the config.
    """
    learner_factory = get_stage2_learner
    _train_stage2_internal(
        config_path,
        get_learner_func=learner_factory,
        ablation_mode=ablation_mode,
    )

#| export
def train_adaptive_stage2(config_path: str | Path, ablation_mode: Optional[str] = None):
    """Train the adaptive LLaVA model for Stage 2 (instruction fine-tuning).

    Delegates to `_train_stage2_internal` with `get_adaptive_stage2_learner` as the
    learner factory: the config is loaded, the adaptive learner is built and
    trained, training time and peak VRAM are reported, and the projector and LoRA
    adapter weights (if LoRA was applied) are saved as separate files.

    Args:
        config_path: Path to the YAML configuration file.
        ablation_mode: Optional ablation mode (e.g. 'baseline' to force the base
            patch grid). When provided it takes precedence over the config value.
    """
    learner_factory = get_adaptive_stage2_learner
    _train_stage2_internal(
        config_path,
        get_learner_func=learner_factory,
        ablation_mode=ablation_mode,
    )


#| export
# Internal training function
def _train_stage2_internal(config_path: str | Path, get_learner_func: callable, ablation_mode: Optional[str] = None):
    """Internal function to handle the Stage 2 training loop."""
    model_type_name = 'Adaptive' if get_learner_func == get_adaptive_stage2_learner else 'Baseline'
    print(f"--- Starting Stage 2 Training ({model_type_name} Model) --- ")
    start_run_time = time.time()
    print(f"Loading configuration from: {config_path}")
    config = load_config(config_path)
    
    # --- Handle Ablation Override --- 
    ablation_name = ablation_mode # Use CLI override if provided
    if ablation_name is None: # Otherwise use config setting
         ablation_name = config.get('ablation', {}).get('force_patcher_strategy')
    # Update config dictionary if CLI override was used or config has it
    if ablation_name:
         if 'ablation' not in config: config['ablation'] = {}
         config['ablation']['force_patcher_strategy'] = ablation_name
         if ablation_mode:
              print(f"Overriding ablation mode from command line: {ablation_mode}")
         else:
              print(f"Using ablation mode from config: {ablation_name}")
    # ----------------------------- 

    output_dir = Path(config['paths']['output_dir'])
    models_dir = output_dir / 'models'
    models_dir.mkdir(parents=True, exist_ok=True)
    
    learner = None # Initialize learner to None for finally block
    run = None # Initialize wandb run object
    try:
        # --- Get Learner (using the passed function) --- 
        # Pass the potentially modified config
        learner = get_learner_func(config)
        run = wandb.run # Get the active run object if W&B was initialized

        # --- Reset CUDA memory stats before training ---
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats() 
            print("Reset CUDA peak memory stats before training.")
            
        # --- Start Training --- 
        epochs = config.get('training', {}).get('num_epochs_stage2', 3)
        lr = config.get('training', {}).get('learning_rate_stage2', 2e-5)
        print(f"Starting training for {epochs} epochs with max_lr={lr}...")
        start_train_time = time.time()

        learner.fit_one_cycle(epochs, lr_max=lr)
        
        end_train_time = time.time()
        total_train_time_sec = end_train_time - start_train_time
        print(f"Training finished in {total_train_time_sec:.2f} seconds.")
        
        # --- Log Efficiency Metrics (Step 5.4) --- 
        if run:
            wandb.log({"train/stage2_total_training_time_sec": total_train_time_sec})
        if torch.cuda.is_available():
            peak_vram_gb = torch.cuda.max_memory_allocated() / (1024**3)
            print(f"Peak Training VRAM used (Stage 2): {peak_vram_gb:.2f} GB")
            if run: wandb.log({"train/stage2_peak_vram_gb": peak_vram_gb})
        
        # --- Save final trained weights --- 
        # Define specific save names based on model type and ablation status
        model_base_name = Path(config['paths']['stage2_model_weights']).stem
        ablation_tag = f"_abl-{ablation_name}" if ablation_name else "" # Use determined ablation name
        save_prefix = f"{model_base_name}_{model_type_name.lower()}{ablation_tag}"
        
        # 1. Save Projector Weights
        projector_save_path = models_dir / f"{save_prefix}_projector_final.pth"
        print(f"Saving final projector weights to: {projector_save_path}")
        if hasattr(learner.model, 'projector') and learner.model.projector is not None:
             torch.save(learner.model.projector.state_dict(), projector_save_path)
             print("Projector weights saved.")
        else: print("Warning: Cannot save projector weights, model has no projector.")
             
        # 2. Save LoRA Adapters (if LoRA was used)
        use_lora_config = config.get('model', {}).get('peft', {}).get('use_lora', False)
        lora_applied = _peft_available and use_lora_config and hasattr(learner.model, 'language_model') and isinstance(learner.model.language_model, PeftModel)
        if lora_applied:
            lora_save_dir = models_dir / f"{save_prefix}_lora_adapters"
            print(f"Saving LoRA adapters to: {lora_save_dir}")
            try:
                lora_save_dir.mkdir(parents=True, exist_ok=True)
                # Use save_pretrained method from PeftModel
                learner.model.language_model.save_pretrained(str(lora_save_dir))
                print("LoRA adapters saved successfully.")
            except Exception as e: print(f"Error saving LoRA adapters: {e}"); traceback.print_exc()
        elif use_lora_config: print("LoRA configured but not applied. Cannot save adapters.")
        else: print("LoRA was not enabled in config.")
        
        # 3. Save Patcher Weights (if adaptive and trainable)
        if model_type_name == 'Adaptive' and hasattr(learner.model, 'patcher') and learner.model.patcher is not None:
            patcher_params = list(learner.model.patcher.parameters())
            if any(p.requires_grad for p in patcher_params):
                patcher_save_path = models_dir / f"{save_prefix}_patcher_final.pth"
                print(f"Saving final patcher weights to: {patcher_save_path}")
                torch.save(learner.model.patcher.state_dict(), patcher_save_path)
                print("Patcher weights saved.")
            else:
                print("Adaptive patcher exists but has no trainable parameters to save.")

    except Exception as e:
        print(f"An error occurred during Stage 2 training ({model_type_name}): {e}")
        traceback.print_exc()
        if run: run.finish(exit_code=1)
        raise e
    finally:
        if learner is not None:
             cleanup_learner(learner, model_type_name.lower()) # Use helper cleanup
        if run and wandb.run and wandb.run.id == run.id: # Ensure we finish the correct run
            try: wandb.finish()
            except Exception as e: print(f"Error finishing W&B run: {e}")
            
    end_run_time = time.time()
    total_script_time = end_run_time - start_run_time
    print(f"Total Stage 2 script execution time ({model_type_name}): {total_script_time:.2f} seconds.")
    print(f"--- Stage 2 Training Complete ({model_type_name} Model) --- ")

# Helper cleanup function (to avoid repetition)
def cleanup_learner(learner: Optional[Learner], name: str):
    """Best-effort release of the memory held by a finished learner.

    Steps:
      1. Move the model's known sub-modules (vision tower, language model,
         projector, patcher) to CPU.
      2. Drop `learner.opt`, so the learner no longer keeps optimizer state
         alive while the caller still holds a reference to it.
      3. Delete `learner.model` and call `learner.destroy()`.
      4. Run `gc.collect()` and, when CUDA is available, `torch.cuda.empty_cache()`.

    Errors during cleanup are printed, never raised.

    Args:
        learner: The learner to clean up; `None` is a no-op.
        name: Label used in the printed status messages.
    """
    if learner is None:
        return
    try:
        model = getattr(learner, 'model', None)
        if model is not None:
            for attr in ('vision_tower', 'language_model', 'projector', 'patcher'):
                # Re-fetch instead of binding a loop variable, so no stray local
                # keeps a sub-module alive during gc below.
                if getattr(model, attr, None) is not None:
                    getattr(model, attr).to('cpu')
            del model  # drop the local reference before deleting the attribute
            del learner.model
        # The optimizer holds references to parameters and their state tensors;
        # release them so empty_cache() can actually return that memory.
        if getattr(learner, 'opt', None) is not None:
            learner.opt = None
        learner.destroy()
        print(f"Cleaned up {name} learner and model memory.")
    except Exception as e:
        print(f"Error during {name} learner cleanup: {e}")
    finally:
        # NOTE: `del learner` here would only unbind this function's local name;
        # the caller's reference is unaffected, so it is intentionally omitted.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

In [None]:
# show_doc(train_stage2) # Omitted for script execution
# show_doc(train_adaptive_stage2) # Omitted for script execution

#### Example Usage (Execution)

In [None]:
#| hide
# Example of how to run this from within the notebook (for testing purposes).
# Requires dummy data and Stage 1 weights to be set up as in the learner test cell.
# Runs the baseline model first, then the adaptive model with the 'baseline' ablation.

try:
    # Reduce epochs for testing
    _test_config = load_config('../configs/config.yaml')
    _test_config['training']['num_epochs_stage2'] = 1  # Just 1 epoch for testing
    # Enable adaptive patcher and LoRA
    _model_cfg = _test_config.setdefault('model', {})
    _model_cfg.setdefault('peft', {})['use_lora'] = True
    _patcher_cfg = _model_cfg.setdefault('adaptive_patcher', {})
    _patcher_cfg['enabled'] = True
    _patcher_cfg['strategy'] = 'variable_resolution'
    # Ensure ablation is OFF in the config; the adaptive run below passes it explicitly
    _test_config.setdefault('ablation', {})['force_patcher_strategy'] = None

    # Ensure dummy data exists (copied from learner test)
    output_dir = Path(_test_config['paths']['output_dir'])
    (output_dir / 'models').mkdir(parents=True, exist_ok=True)
    stage1_weights_fname = _test_config['paths'].get('stage1_projector_weights', 'stage1_projector.pth')
    stage1_weights_path = output_dir / 'models' / stage1_weights_fname
    if not stage1_weights_path.is_file():
        # Write a randomly initialised projector so the Stage 1 weight load succeeds.
        proj_input_dim = _test_config['model']['projector']['input_dim']
        proj_output_dim = _test_config['model']['projector']['output_dim']
        from llava.model.baseline import LLaVAProjector
        dummy_projector = LLaVAProjector(proj_input_dim, proj_output_dim)
        torch.save(dummy_projector.state_dict(), stage1_weights_path)
        del dummy_projector
    data_base = Path(_test_config['paths']['data_base'])
    stage2_json_rel = Path(_test_config['paths']['stage2_data'])
    stage1_img_rel = Path(_test_config['paths']['stage1_images'])
    stage1_img_path = data_base / stage1_img_rel
    stage2_json_path = data_base / stage2_json_rel
    stage2_json_path.parent.mkdir(parents=True, exist_ok=True)
    stage1_img_path.mkdir(parents=True, exist_ok=True)
    # The images are written under `data_base / stage1_img_rel`, and the data loader
    # logs that it resolves image paths relative to the data root. Use the full
    # relative directory, not just its last component, so a nested `stage1_images`
    # value (e.g. 'a/b') still points at the files created below.
    img1_rel_path_str = (stage1_img_rel / 'dummy_img1.jpg').as_posix()
    img2_rel_path_str = (stage1_img_rel / 'dummy_img2.png').as_posix()
    if not (stage1_img_path / 'dummy_img1.jpg').exists():
        PIL.Image.new('RGB', (60, 30), color='red').save(stage1_img_path / 'dummy_img1.jpg')
    if not (stage1_img_path / 'dummy_img2.png').exists():
        PIL.Image.new('RGB', (60, 30), color='green').save(stage1_img_path / 'dummy_img2.png')
    # Only (re)write the JSONL when it is missing or effectively empty.
    if not stage2_json_path.exists() or stage2_json_path.stat().st_size < 10:
        dummy_jsonl_content = [
            {"id": "s2_001", "image": img1_rel_path_str, "conversations": [{"from": "human", "value": "<image>\nDescribe image."}, {"from": "gpt", "value": "It is a red object."}]},
            {"id": "s2_002", "image": img2_rel_path_str, "conversations": [{"from": "human", "value": "<image>\nIs it green?"}, {"from": "gpt", "value": "Yes, it appears green."}]},
        ]
        with open(stage2_json_path, 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(item) + '\n' for item in dummy_jsonl_content)

    # Save the modified config for the training function to load
    import yaml
    temp_config_path = '../configs/temp_test_config.yaml'
    with open(temp_config_path, 'w', encoding='utf-8') as f_temp:
        yaml.dump(_test_config, f_temp)

    # Run baseline training first (ablation=None by default)
    print("--- Running Baseline Stage 2 Training --- ")
    train_stage2(temp_config_path)

    # Run adaptive training with forced baseline ablation
    print("\n--- Running Adaptive Stage 2 Training --- ")
    train_adaptive_stage2(temp_config_path, ablation_mode='baseline')

    # Clean up temp config (kept by default for inspection)
    # os.remove(temp_config_path)
    print("Stage 2 Training Test Run Complete.")
except FileNotFoundError as e:
    print(f"Skipping training test: FileNotFoundError - {e}")
except RuntimeError as e:
    print(f"Skipping training test: RuntimeError - {e}. (Likely CUDA OOM or model setup issue)")
except Exception as e:
    print(f"An error occurred during Stage 2 training test run: {e}")
    traceback.print_exc()

--- Running Baseline Stage 2 Training --- 
--- Starting Stage 2 Training (Baseline Model) --- 
Loading configuration from: ../configs/temp_test_config.yaml
Using ablation mode from config: None
--- Setting up Stage 2 Learner (BaselineLLaVAModel) ---
Loading Stage 2 DataLoaders...
Creating Stage 2 DataLoaders with batch size: 4, num_workers: 4
Loading Stage 2 items from: /workspace/llava/data/llava_instruct_150k/llava_v1_5_mix665k.jsonl
Assuming images relative to: /workspace/llava/data
Found 2 samples for Stage 2.
Stage 2 DataLoaders created successfully.
DataLoaders loaded. Train samples: 1, Valid samples: 1
Instantiating BaselineLLaVAModel for Stage 2...
Initializing Projector: Input Dim=1024, Output Dim=4096
Loading Vision Tower: openai/clip-vit-large-patch14-336...


Some weights of the model checkpoint at openai/clip-vit-large-patch14-336 were not used when initializing CLIPVisionModel: ['text_model.encoder.layers.0.layer_norm1.bias', 'text_model.encoder.layers.0.layer_norm1.weight', 'text_model.encoder.layers.0.layer_norm2.bias', 'text_model.encoder.layers.0.layer_norm2.weight', 'text_model.encoder.layers.0.mlp.fc1.bias', 'text_model.encoder.layers.0.mlp.fc1.weight', 'text_model.encoder.layers.0.mlp.fc2.bias', 'text_model.encoder.layers.0.mlp.fc2.weight', 'text_model.encoder.layers.0.self_attn.k_proj.bias', 'text_model.encoder.layers.0.self_attn.k_proj.weight', 'text_model.encoder.layers.0.self_attn.out_proj.bias', 'text_model.encoder.layers.0.self_attn.out_proj.weight', 'text_model.encoder.layers.0.self_attn.q_proj.bias', 'text_model.encoder.layers.0.self_attn.q_proj.weight', 'text_model.encoder.layers.0.self_attn.v_proj.bias', 'text_model.encoder.layers.0.self_attn.v_proj.weight', 'text_model.encoder.layers.1.layer_norm1.bias', 'text_model.enco

Vision Tower loaded successfully.
Vision Tower weights frozen.
Loading Language Model: lmsys/vicuna-7b-v1.5...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Language Model loaded successfully.
Base Language Model weights frozen.




LLM embedding size already matches tokenizer size. No resizing needed.
Activation checkpointing is disabled in the configuration.
Model instantiated successfully.
Attempting to load Stage 1 projector weights from: /workspace/llava/output/models/stage1_projector.pth
Successfully loaded Stage 1 projector weights from /workspace/llava/output/models/stage1_projector.pth
Loss function: LLaVALoss
Optimizer: AdamW (lr=2e-05, wd=0.0)
Parameter splitter: llava_stage2_splitter
Metrics: ['perplexity']
W&B enabled in config, but entity not set or default. Skipping W&B init and callback.
SaveModelCallback is configured but commented out. Manual saving of adapters/projector is preferred.
Added GradientAccumulation callback with 4 steps.
Added MixedPrecision callback.
Applying Stage 2 splitter (Baseline)...
  - Collecting projector parameters (trainable).




  - Collecting vision tower parameters (frozen).
Splitter created groups: Trainable (33554432 params), Frozen (6891018240 params)
--- Stage 2 Learner Setup Complete (BaselineLLaVAModel) ---
Reset CUDA peak memory stats before training.
Starting training for 1 epochs with max_lr=2e-05...






  if self.num_workers > 0 and warn:


Training finished in 6.75 seconds.
Peak Training VRAM used (Stage 2): 11.80 GB
Saving final projector weights to: /workspace/llava/output/models/stage2_llava_lora_baseline_projector_final.pth
Projector weights saved.
LoRA configured but not applied. Cannot save adapters.
Cleaned up baseline learner and model memory.
Total Stage 2 script execution time (Baseline): 35.79 seconds.
--- Stage 2 Training Complete (Baseline Model) --- 

--- Running Adaptive Stage 2 Training --- 
--- Starting Stage 2 Training (Adaptive Model) --- 
Loading configuration from: ../configs/temp_test_config.yaml
Using ablation mode from config: baseline
--- Setting up Stage 2 Learner (AdaptiveLLaVAModel) ---
Loading Stage 2 DataLoaders...
Creating Stage 2 DataLoaders with batch size: 4, num_workers: 4
Loading Stage 2 items from: /workspace/llava/data/llava_instruct_150k/llava_v1_5_mix665k.jsonl
Assuming images relative to: /workspace/llava/data
Found 2 samples for Stage 2.
Stage 2 DataLoaders created successfully.


Some weights of the model checkpoint at openai/clip-vit-large-patch14-336 were not used when initializing CLIPVisionModel: ['text_model.encoder.layers.0.layer_norm1.bias', 'text_model.encoder.layers.0.layer_norm1.weight', 'text_model.encoder.layers.0.layer_norm2.bias', 'text_model.encoder.layers.0.layer_norm2.weight', 'text_model.encoder.layers.0.mlp.fc1.bias', 'text_model.encoder.layers.0.mlp.fc1.weight', 'text_model.encoder.layers.0.mlp.fc2.bias', 'text_model.encoder.layers.0.mlp.fc2.weight', 'text_model.encoder.layers.0.self_attn.k_proj.bias', 'text_model.encoder.layers.0.self_attn.k_proj.weight', 'text_model.encoder.layers.0.self_attn.out_proj.bias', 'text_model.encoder.layers.0.self_attn.out_proj.weight', 'text_model.encoder.layers.0.self_attn.q_proj.bias', 'text_model.encoder.layers.0.self_attn.q_proj.weight', 'text_model.encoder.layers.0.self_attn.v_proj.bias', 'text_model.encoder.layers.0.self_attn.v_proj.weight', 'text_model.encoder.layers.1.layer_norm1.bias', 'text_model.enco

Vision Tower loaded successfully.
Vision Tower weights frozen.
Loading Language Model: lmsys/vicuna-7b-v1.5...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Language Model loaded successfully.
Base Language Model weights frozen.




LLM embedding size already matches tokenizer size. No resizing needed.
Activation checkpointing is disabled in the configuration.
Adaptive Patcher enabled with strategy: 'variable_resolution' (VariableResolutionPatcher)
Initialized VariableResolutionPatcher: Ablation active - FORCING BASELINE grid (336x336).
Model instantiated successfully.
Attempting to load Stage 1 projector weights from: /workspace/llava/output/models/stage1_projector.pth
Successfully loaded Stage 1 projector weights from /workspace/llava/output/models/stage1_projector.pth
Loss function: LLaVALoss
Optimizer: AdamW (lr=2e-05, wd=0.0)
Parameter splitter: adaptive_llava_stage2_splitter
Metrics: ['perplexity']
W&B enabled in config, but entity not set or default. Skipping W&B init and callback.
SaveModelCallback is configured but commented out. Manual saving of adapters/projector is preferred.
Added GradientAccumulation callback with 4 steps.
Added MixedPrecision callback.
Applying Stage 2 splitter (Adaptive)...
  - Col



  - Patcher found, but has no trainable parameters.
  - Collecting vision tower parameters (frozen).
Splitter created groups: Trainable (33554432 params), Frozen (6891018240 params)
--- Stage 2 Learner Setup Complete (AdaptiveLLaVAModel) ---
Reset CUDA peak memory stats before training.
Starting training for 1 epochs with max_lr=2e-05...






  if self.num_workers > 0 and warn:


Training finished in 6.66 seconds.
Peak Training VRAM used (Stage 2): 11.80 GB
Saving final projector weights to: /workspace/llava/output/models/stage2_llava_lora_adaptive_abl-baseline_projector_final.pth
Projector weights saved.
LoRA configured but not applied. Cannot save adapters.
Adaptive patcher exists but has no trainable parameters to save.
Cleaned up adaptive learner and model memory.
Total Stage 2 script execution time (Adaptive): 34.81 seconds.
--- Stage 2 Training Complete (Adaptive Model) --- 
Stage 2 Training Test Run Complete.


In [None]:
#| export
# Command-line execution block for Stage 2
if __name__ == "__main__" and "get_ipython" not in locals():
    parser = argparse.ArgumentParser(description="Run LLaVA Stage 2 Training")
    parser.add_argument("--config", type=str, default="configs/config.yaml",
                        help="Path to the configuration YAML file (relative to project root or absolute).")
    parser.add_argument("--model_type", type=str, default="baseline", choices=['baseline', 'adaptive'],
                        help="Choose model type to train: 'baseline' or 'adaptive'.")
    # `None` cannot be typed on the command line, so it must not be listed as a
    # choice; omitting the flag already yields the default of None.
    parser.add_argument("--ablation", type=str, default=None, choices=['baseline'],
                        help="Optional ablation mode (e.g., 'baseline' to force base grid). Overrides config.")
    args = parser.parse_args()

    # Resolve config path relative to project root (defined in the import section).
    # An absolute --config is kept as-is by pathlib's `/` operator.
    config_file_path = project_root / args.config

    if not config_file_path.is_file():
        print(f"Error: Config file not found at {config_file_path}")
        sys.exit(1)

    # The ablation only affects the adaptive model; describe it only when it is used.
    ablation_used = args.model_type == 'adaptive' and args.ablation
    try:
        if args.model_type == 'baseline':
            if args.ablation:
                print(f"Warning: --ablation '{args.ablation}' only applies to --model_type adaptive; ignoring it.")
            train_stage2(config_path=config_file_path, ablation_mode=None)
        elif args.model_type == 'adaptive':
            train_adaptive_stage2(config_path=config_file_path, ablation_mode=args.ablation)
        else:
            # Unreachable given argparse `choices`, kept as a defensive guard.
            print(f"Error: Invalid model_type '{args.model_type}'. Choose 'baseline' or 'adaptive'.")
            sys.exit(1)
    except NotImplementedError as e:
        print(f"Exiting: {e}")
        sys.exit(0)
    except Exception as e:
        print(f"Stage 2 training setup or execution failed for model type '{args.model_type}'{(' with ablation ' + args.ablation) if ablation_used else ''}: {e}")
        traceback.print_exc()
        # The W&B run is finished by the `finally` block in the training function.
        sys.exit(1)

In [None]:
#| hide
# Regenerate the Python modules from the notebooks' `#| export` cells.
from nbdev import nbdev_export
nbdev_export()