# Stage 2 Training: Instruction Fine-tuning

> Sets up and runs the second stage of LLaVA training: fine-tuning the projector and the LLM (optionally via LoRA adapters) on instruction-following data.

In [None]:
#| default_exp training.stage2

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import sys
from pathlib import Path
import os
import gc # For memory cleanup
import argparse # For command-line execution

# Locate the project root so the `llava` package is importable without installing it.
# Normally this is the current working directory; when running from nbs/ (no
# settings.ini here, but one in the parent directory) step up one level.
project_root = Path.cwd()
if (project_root.parent / 'settings.ini').exists() and not (project_root / 'settings.ini').exists():
    project_root = project_root.parent

project_root_str = str(project_root.resolve())

# Prepend the root only once; re-running the cell stays quiet.
if project_root_str not in sys.path:
    print(f"Adding project root to sys.path: {project_root_str}")
    sys.path.insert(0, project_root_str)

Adding project root to sys.path: /workspace/llava


In [None]:
#| export
import torch
import torch.nn as nn
from fastai.learner import Learner
from fastai.optimizer import AdamW
from fastai.callback.wandb import WandbCallback
from fastai.callback.schedule import fit_one_cycle
from fastai.callback.save import SaveModelCallback
from fastai.callback.training import GradientAccumulation
from fastai.callback.fp16 import MixedPrecision
from fastai.vision.all import params # For splitter
from fastai.data.core import DataLoaders
from functools import partial
import wandb # Import wandb directly for cleanup
import json # For dummy data creation
import PIL.Image # For dummy data creation
from typing import List, Optional

# Attempt to import peft, set flag
try:
    from peft import PeftModel
    _peft_available = True
except ImportError:
    print("Warning: peft library not found. LoRA functionality will be disabled.")
    PeftModel = None # Define as None if not available
    _peft_available = False

try:
    from llava.utils import load_config, init_wandb
    from llava.data.loading import get_stage2_dataloaders
    from llava.model.baseline import BaselineLLaVAModel
    from llava.training.core import LLaVALoss
except ImportError as e:
     print(f"Error importing llava modules: {e}")
     print("Ensure that nbdev_export has been run and the llava library is installed/accessible.")
     # In a script context, it's better to exit if core modules are missing
     if __name__ == "__main__" and "get_ipython" not in locals():
          sys.exit(1)

Project root already in sys.path: /workspace/llava


Loaded config from configs/config.yaml
Successfully loaded tokenizer for: lmsys/vicuna-7b-v1.5
Adding special token <image> to tokenizer.
Added 1 token(s). New vocab size: 32001
Using token ID for <image>: 32000
Tokenizer already has pad token: <unk> (ID: 0)
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Successfully loaded CLIP image processor for: openai/clip-vit-large-patch14-336
LLaVABatchTransform initialized. Image Token ID: 32000, Pad Token ID: 0
LLaVADataBlockStage2 defined.
LLaVALoss initialized, ignoring index: -100


## Stage 2 Splitter Function

In Stage 2, we fine-tune the `projector` and the `language_model`. If LoRA is enabled (`use_lora: true` in config), only the LoRA adapter parameters within the language model are trained. If LoRA is requested but the `peft` library is unavailable, the language model is kept frozen. If LoRA is disabled, the entire language model is trained. The `vision_tower` is always kept frozen.

In [None]:
#| export
def llava_stage2_splitter(model: BaselineLLaVAModel):
    """Splits the `BaselineLLaVAModel` parameters for Stage 2 training.

    Sets `requires_grad` on the model's components and returns the trainable
    parameters as a single optimizer parameter group:

    - `projector`: always trainable.
    - `language_model`:
        * LoRA requested and the model is a PEFT `PeftModel`: only the parameters
          that already require grad (as configured by PEFT) are collected.
        * LoRA requested but not applied (`peft` missing, or the language model
          is not a `PeftModel`): the whole LLM is frozen. The splitter never
          silently falls back to full fine-tuning.
        * LoRA not requested: the whole LLM is unfrozen and collected.
    - `vision_tower`: always frozen.

    Args:
        model: An instance of `BaselineLLaVAModel`. Its `config` dict is read for
            `model.peft.use_lora`.

    Returns:
        A one-element list containing the list of trainable parameters.

    Raises:
        ValueError: If no trainable parameters were collected.
    """
    projector_params = []
    llm_params = []

    print("Applying Stage 2 splitter...")

    # Projector parameters are always trained in Stage 2
    projector = getattr(model, 'projector', None)
    if projector is not None:
        print("  - Collecting projector parameters (trainable).")
        for p in projector.parameters():
            p.requires_grad = True
            projector_params.append(p)
    else:
        print("Warning: Model has no projector attribute.")

    # Handle Language Model parameters based on LoRA configuration
    language_model = getattr(model, 'language_model', None)
    if language_model is not None:
        use_lora = model.config.get('model', {}).get('peft', {}).get('use_lora', False)
        # `PeftModel` is None when peft is missing, so check availability first.
        lora_applied = _peft_available and isinstance(language_model, PeftModel)

        if use_lora and lora_applied:
            print("  - LoRA enabled: Collecting LLM adapter parameters (trainable).")
            # PEFT has already set requires_grad (adapters trainable, base frozen).
            llm_params.extend(p for p in language_model.parameters() if p.requires_grad)
        elif use_lora:
            # LoRA was requested but is not in effect. Freeze rather than unfreezing
            # the full LLM, which would silently change memory use and the experiment.
            if not _peft_available:
                print("Warning: LoRA configured but PEFT library not found. Cannot train LoRA adapters. Freezing LLM.")
            else:
                print("Warning: LoRA configured but the language model is not a PeftModel (adapters not applied). Freezing LLM.")
            for p in language_model.parameters():
                p.requires_grad = False
        else:
            print("  - LoRA disabled: Collecting all LLM parameters (trainable).")
            for p in language_model.parameters():
                p.requires_grad = True  # Ensure all LLM params are trainable if not using LoRA
                llm_params.append(p)
    else:
        print("Warning: Model has no language_model attribute.")

    # Vision Tower parameters are frozen
    vision_tower = getattr(model, 'vision_tower', None)
    if vision_tower is not None:
        print("  - Collecting vision tower parameters (frozen).")
        for p in vision_tower.parameters():
            p.requires_grad = False
    else:
        print("Warning: Model has no vision_tower attribute.")

    # De-duplicate by identity (insertion order kept) in case a parameter is shared
    # between components, so it is neither double-counted nor given to the optimizer twice.
    trainable_params = list({id(p): p for p in projector_params + llm_params}.values())
    frozen_count = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    trainable_count = sum(p.numel() for p in trainable_params)
    print(f"Splitter created groups: Trainable ({trainable_count} params), Frozen ({frozen_count} params)")

    if not trainable_params:
        raise ValueError("Splitter function resulted in no trainable parameters. Check model structure and config.")

    return [trainable_params]

In [None]:
# show_doc(llava_stage2_splitter) # Omitted for script execution

## Step 4.3: Set Up the Learner Configuration (Stage 2)

This function sets up the `Learner` object for Stage 2 instruction tuning.

In [None]:
#| export
def get_stage2_learner(config: dict) -> Learner:
    """Configures and returns a fastai Learner for Stage 2 Instruction Fine-tuning.

    Loads the Stage 2 DataLoaders and instantiates `BaselineLLaVAModel`, which
    handles LoRA based on the config. It then loads the Stage 1 projector weights
    when present and wires up the loss, the AdamW optimizer, the Stage 2 splitter
    and the callbacks (W&B, gradient accumulation, mixed precision).

    Args:
        config: The main configuration dictionary. Requires
            `paths.output_dir` and `paths.stage2_model_weights`. Optionally reads
            `paths.stage1_projector_weights`, `training.*` and `logging.wandb.*`.

    Returns:
        A configured fastai Learner instance for Stage 2.

    Raises:
        KeyError: If a required `paths` entry is missing from `config`.
        RuntimeError: If DataLoaders creation, model instantiation, loading an
            existing Stage 1 projector checkpoint, or Learner creation fails.
            A *missing* Stage 1 checkpoint only produces a warning.
    """
    print("--- Setting up Stage 2 Learner ---")
    paths_cfg = config['paths']
    training_cfg = config.get('training', {})
    wandb_cfg = config.get('logging', {}).get('wandb', {})

    output_dir = Path(paths_cfg['output_dir'])
    output_dir.mkdir(parents=True, exist_ok=True)
    # Resolve the Stage 2 weights name up front. A missing key then fails fast,
    # before the expensive model load and before any W&B run is started.
    stage2_model_name = Path(paths_cfg['stage2_model_weights']).stem

    # 1. Load Stage 2 DataLoaders
    print("Loading Stage 2 DataLoaders...")
    try:
        dls = get_stage2_dataloaders(config)
    except Exception as e:  # also covers FileNotFoundError for bad data paths
        print(f"Error loading Stage 2 DataLoaders: {e}")
        raise RuntimeError("Failed to create Stage 2 DataLoaders.") from e
    if not dls:
        raise RuntimeError("Stage 2 DataLoaders object is None.")
    print(f"DataLoaders loaded. Train samples: {len(dls.train_ds)}, Valid samples: {len(dls.valid_ds)}")

    # 2. Instantiate Model (handles LoRA based on config)
    print("Instantiating BaselineLLaVAModel for Stage 2...")
    try:
        model = BaselineLLaVAModel(config)
        if model.vision_tower is None or model.language_model is None or model.projector is None:
            raise RuntimeError("BaselineLLaVAModel initialization incomplete.")
        print("Model instantiated successfully.")
    except Exception as e:
        print(f"Error instantiating BaselineLLaVAModel: {e}")
        raise RuntimeError("Failed to instantiate baseline model for Stage 2.") from e

    # 3. Load Stage 1 Projector Weights (looked up under <output_dir>/models/)
    stage1_weights_fname = paths_cfg.get('stage1_projector_weights', 'stage1_projector.pth')
    stage1_weights_path = output_dir / 'models' / stage1_weights_fname
    print(f"Attempting to load Stage 1 projector weights from: {stage1_weights_path}")
    if stage1_weights_path.is_file():
        try:
            # Load onto CPU to avoid device mismatches. `weights_only=True` restricts
            # unpickling to tensors and plain containers, which is all a state dict needs.
            projector_state_dict = torch.load(stage1_weights_path, map_location='cpu', weights_only=True)
            model.projector.load_state_dict(projector_state_dict)
            print(f"Successfully loaded Stage 1 projector weights from {stage1_weights_path}")
        except Exception as e:
            print(f"Error loading Stage 1 projector weights: {e}")
            # A checkpoint that exists but cannot be loaded is an error, not a silent re-init.
            raise RuntimeError(f"Failed to load expected Stage 1 projector weights from {stage1_weights_path}") from e
    else:
        # A missing checkpoint is tolerated: the projector keeps its initial weights.
        print(f"Warning: Stage 1 projector weights not found at {stage1_weights_path}. Projector will use initial weights.")

    # 4. Define Loss Function
    loss_func = LLaVALoss()
    print(f"Loss function: {type(loss_func).__name__}")

    # 5. Define Optimizer
    lr = training_cfg.get('learning_rate_stage2', 2e-5)  # Lower LR for fine-tuning
    wd = training_cfg.get('weight_decay', 0.0)
    opt_func = partial(AdamW, lr=lr, wd=wd, eps=1e-8)
    print(f"Optimizer: AdamW (lr={lr}, wd={wd})")

    # 6. Define Splitter
    splitter = llava_stage2_splitter
    print(f"Parameter splitter: {splitter.__name__}")

    # 7. Define Callbacks
    cbs = []
    if wandb_cfg.get('enabled', False):
        # Only start a W&B run when a real (non-placeholder) entity is configured.
        wandb_entity = wandb_cfg.get('entity')
        if wandb_entity and 'your_wandb_entity' not in str(wandb_entity):
            run_name_prefix = wandb_cfg.get('run_name_prefix', 'stage2')
            run_name = f"{run_name_prefix}_{stage2_model_name}_{wandb.util.generate_id()}"
            init_wandb(config, job_type="stage2-training", run_name=run_name)
            cbs.append(WandbCallback(log_preds=False, log_model=False))
            print("Added WandbCallback.")
        else:
            print("W&B enabled in config, but entity not set or default. Skipping W&B init and callback.")

    # SaveModelCallback (best full-model checkpoint, with optimizer state) is configured
    # but intentionally NOT registered: for LoRA, train_stage2 saves the projector and
    # adapters explicitly. To enable it: cbs.append(save_cb)
    save_cb = SaveModelCallback(
        monitor='valid_loss',
        min_delta=0.001,
        fname=stage2_model_name,
        every_epoch=False,
        with_opt=True,
        reset_on_fit=True
    )
    print(f"SaveModelCallback is configured but commented out. Manual saving of adapters/projector in train_stage2 is preferred for LoRA.")

    # Optimization Callbacks
    grad_accum_steps = training_cfg.get('gradient_accumulation_steps', 1)
    if grad_accum_steps > 1:
        # fastai's GradientAccumulation takes `n_acc` = number of *samples* to accumulate
        # before stepping, not a number of batches. Convert "accumulate over N batches"
        # into a sample count, or accumulation is a no-op whenever N <= batch size.
        batch_size = getattr(dls.train, 'bs', None) or 1
        n_acc = grad_accum_steps * batch_size
        cbs.append(GradientAccumulation(n_acc=n_acc))
        print(f"Added GradientAccumulation callback with {grad_accum_steps} steps (n_acc={n_acc} samples).")

    if training_cfg.get('use_mixed_precision', False):
        cbs.append(MixedPrecision())
        print("Added MixedPrecision callback.")

    # 8. Create Learner
    try:
        learner = Learner(
            dls=dls,
            model=model,
            loss_func=loss_func,
            opt_func=opt_func,
            splitter=splitter,
            cbs=cbs,
            path=output_dir,
            train_bn=False  # Typically False for LLM fine-tuning
        )
    except Exception as e:
        print(f"Error creating Stage 2 Learner: {e}")
        # Do not leave a dangling W&B run behind if it was started above.
        if wandb.run is not None:
            wandb.finish(exit_code=1)
            print("Finished W&B run due to error during Learner creation.")
        raise RuntimeError("Failed to create the Stage 2 Learner object.") from e

    print("--- Stage 2 Learner Setup Complete ---")
    return learner

In [None]:
# show_doc(get_stage2_learner) # Omitted for script execution

#### Example Usage & Test (Learner Configuration - Stage 2)

In [None]:
#| test
import gc
from fastai.callback.save import SaveModelCallback
from fastai.callback.wandb import WandbCallback
from fastai.callback.training import GradientAccumulation
from fastai.callback.fp16 import MixedPrecision
from llava.model.baseline import LLaVAProjector, BaselineLLaVAModel # Import required classes

# Smoke test for `get_stage2_learner`: builds the Stage 2 learner on a tiny dummy
# dataset, then checks parameter freezing and callback wiring. Setup problems
# (missing files, missing `peft`, runtime/CUDA errors) are reported as skips.
try:
    # Resolve the config via the detected project root so the test works from
    # both the repository root and nbs/.
    config_path = str(project_root / 'configs' / 'config.yaml')
    config = load_config(config_path)
    print(f"Loaded config from {config_path}")

    # --- Test Setup ---
    output_dir = Path(config['paths']['output_dir'])
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / 'models').mkdir(parents=True, exist_ok=True)

    # Create dummy Stage 1 weights if they don't exist
    stage1_weights_fname = config['paths'].get('stage1_projector_weights', 'stage1_projector.pth')
    stage1_weights_path = output_dir / 'models' / stage1_weights_fname
    if not stage1_weights_path.is_file():
        print(f"Creating dummy Stage 1 projector weights: {stage1_weights_path}")
        # Create a dummy projector state dict matching the config
        proj_input_dim = config['model']['projector']['input_dim']
        proj_output_dim = config['model']['projector']['output_dim']
        dummy_projector = LLaVAProjector(proj_input_dim, proj_output_dim)
        torch.save(dummy_projector.state_dict(), stage1_weights_path)
        del dummy_projector # Clean up

    # Create dummy Stage 2 data if needed
    data_base = Path(config['paths']['data_base'])
    stage2_json_rel = Path(config['paths']['stage2_data'])
    stage1_img_rel = Path(config['paths']['stage1_images']) # Needed for dummy image paths
    stage1_img_path = data_base / stage1_img_rel
    stage2_json_path = data_base / stage2_json_rel
    stage2_json_path.parent.mkdir(parents=True, exist_ok=True)
    stage1_img_path.mkdir(parents=True, exist_ok=True) # Ensure image dir exists

    img1_rel_path_str = str(stage1_img_rel.name + '/dummy_img1.jpg')
    img2_rel_path_str = str(stage1_img_rel.name + '/dummy_img2.png')
    if not (stage1_img_path / 'dummy_img1.jpg').exists():
        PIL.Image.new('RGB', (60, 30), color = 'red').save(stage1_img_path / 'dummy_img1.jpg')
    if not (stage1_img_path / 'dummy_img2.png').exists():
        PIL.Image.new('RGB', (60, 30), color = 'green').save(stage1_img_path / 'dummy_img2.png')

    if not stage2_json_path.exists() or stage2_json_path.stat().st_size < 10:
        print(f"Creating dummy Stage 2 JSONL: {stage2_json_path}")
        dummy_jsonl_content = [
            {"id": "s2_001", "image": img1_rel_path_str, "conversations": [{"from": "human", "value": "<image>\nDescribe image."}, {"from": "gpt", "value": "It is a red object."}]},
            {"id": "s2_002", "image": img2_rel_path_str, "conversations": [{"from": "human", "value": "<image>\nIs it green?"}, {"from": "gpt", "value": "Yes, it appears green."}]},
        ]
        with open(stage2_json_path, 'w') as f:
            for item in dummy_jsonl_content:
                f.write(json.dumps(item) + '\n')
    # -----------------

    # Force LoRA on so the splitter's adapter path is exercised.
    config.setdefault('model', {}).setdefault('peft', {})['use_lora'] = True
    config['model']['use_activation_checkpointing'] = False # Keep disabled for test

    # Disable W&B for testing unless a real entity is set. This mirrors the truthiness
    # check in get_stage2_learner, so an empty-string entity also counts as "unset".
    wandb_cfg = config.get('logging', {}).get('wandb', {})
    wandb_entity = wandb_cfg.get('entity')
    if wandb_cfg.get('enabled', False) and (not wandb_entity or 'your_wandb_entity' in str(wandb_entity)):
        print("Warning: W&B is enabled but entity is not set or default. Disabling W&B for this test.")
        wandb_cfg['enabled'] = False  # `enabled` exists here, so this is the config's own dict

    # Create learner
    stage2_learner = get_stage2_learner(config)
    print("Stage 2 Learner created successfully.")

    # Basic checks
    assert isinstance(stage2_learner, Learner)
    assert isinstance(stage2_learner.dls, DataLoaders)
    assert isinstance(stage2_learner.model, BaselineLLaVAModel)
    assert isinstance(stage2_learner.loss_func, LLaVALoss)
    if stage2_learner.opt is None:
        stage2_learner.create_opt()  # the optimizer (and hence the splitter) may be created lazily
    assert len(stage2_learner.opt.param_groups) == 1 # Splitter should return one group

    # Check parameter freezing status based on splitter
    print("Checking parameter groups...")
    trainable_param_count = 0
    frozen_param_count = 0
    trainable_param_names = []
    frozen_param_names = []
    for name, param in stage2_learner.model.named_parameters():
        if param.requires_grad:
            trainable_param_count += param.numel()
            trainable_param_names.append(name)
        else:
            frozen_param_count += param.numel()
            frozen_param_names.append(name)

    print(f"  Group 0: {trainable_param_count} parameters (Trainable)")
    print(f"  Frozen: {frozen_param_count} parameters")
    print("Checking parameters in group 0 (first few):")
    for name in trainable_param_names[:10]:
        print(f"  - {name}: {stage2_learner.model.get_parameter(name).requires_grad}")

    is_projector_trainable = any('projector' in name for name in trainable_param_names)
    is_lora_trainable = any('lora_' in name for name in trainable_param_names)
    # The vision tower is frozen iff none of its parameters require grad.
    is_vision_frozen = not any('vision_tower' in name for name in trainable_param_names)
    # With LoRA, every trainable LLM parameter must be an adapter (or a PEFT
    # `modules_to_save` copy). Any other trainable LLM weight means base weights leaked in.
    is_base_llm_frozen = all(
        'lora_' in name or 'modules_to_save' in name
        for name in trainable_param_names if 'language_model' in name
    )

    assert is_projector_trainable, "Projector parameters not found in trainable group."

    if _peft_available and config['model']['peft']['use_lora']:
        assert is_lora_trainable, "LoRA adapter parameters not found in trainable group when LoRA is enabled."
        assert is_base_llm_frozen, "Base LLM parameters are not frozen when LoRA is enabled."
    else:
        is_llm_trainable = any('language_model' in name for name in trainable_param_names)
        # If PEFT is not available, LoRA cannot be applied, so the splitter must keep the LLM frozen.
        if not _peft_available and config['model']['peft']['use_lora']:
            assert not is_llm_trainable, "LLM parameters are trainable when LoRA was configured but PEFT is unavailable."
        else:
            # If LoRA is disabled, all LLM params should be trainable.
            assert is_llm_trainable, "LLM parameters not found in trainable group when LoRA is disabled."

    assert is_vision_frozen, "Vision tower parameters are not frozen."
    print("Checking frozen parameters (examples):")
    for name in frozen_param_names[:2]:
        print(f"  - {name}: {stage2_learner.model.get_parameter(name).requires_grad}")

    print("Splitter check passed: Trainable/Frozen status seems correct based on LoRA config.")

    # Check callbacks (SaveModelCallback is intentionally not registered by get_stage2_learner)
    expect_grad_accum = config.get('training', {}).get('gradient_accumulation_steps', 1) > 1
    has_grad_accum = any(isinstance(cb, GradientAccumulation) for cb in stage2_learner.cbs)
    expect_mixed_precision = config.get('training', {}).get('use_mixed_precision', False)
    has_mixed_precision = any(isinstance(cb, MixedPrecision) for cb in stage2_learner.cbs)
    # Expected W&B decision: same rule as get_stage2_learner (enabled + truthy, non-placeholder entity).
    wandb_cfg = config.get('logging', {}).get('wandb', {})
    wandb_entity = wandb_cfg.get('entity')
    wandb_init_decision = bool(
        wandb_cfg.get('enabled', False) and wandb_entity and 'your_wandb_entity' not in str(wandb_entity)
    )
    has_wandb_cb = any(isinstance(cb, WandbCallback) for cb in stage2_learner.cbs)

    assert has_grad_accum == expect_grad_accum
    assert has_mixed_precision == expect_mixed_precision
    assert has_wandb_cb == wandb_init_decision # Check if WandbCallback added matches decision
    print("Callback check passed.")

    print("\nStage 2 Learner setup test passed.")

except FileNotFoundError as e:
    print(f"Skipping test: FileNotFoundError - {e}")
except ImportError as e:
    print(f"Skipping test: ImportError - {e}. (Likely `peft` is missing)")
except RuntimeError as e:
    print(f"Skipping test: RuntimeError - {e}. (Likely CUDA OOM or setup issue)")
except Exception as e:
    import traceback
    print(f"An error occurred during Stage 2 learner setup test: {e}")
    traceback.print_exc()
finally:
    # Release GPU memory even if the test failed part-way. Cleanup errors are
    # reported, never raised, so they cannot mask the test outcome.
    if 'stage2_learner' in locals() and stage2_learner is not None:
        try:
            _s2_model = getattr(stage2_learner, 'model', None)
            if _s2_model is not None:
                # Move components to CPU before dropping references
                for _part_name in ('vision_tower', 'language_model', 'projector'):
                    _part = getattr(_s2_model, _part_name, None)
                    if _part is not None:
                        _part.to('cpu')
                _part = _s2_model = None
                del stage2_learner.model
            stage2_learner.opt = None  # drop optimizer state held on the device
            # `destroy()` is not part of every fastai Learner version; call it only if present.
            _destroy = getattr(stage2_learner, 'destroy', None)
            if callable(_destroy):
                _destroy()
            _destroy = None
        except Exception as cleanup_err:
            print(f"Warning: learner cleanup failed: {cleanup_err}")
        del stage2_learner
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print("Cleaned up stage2_learner and model")
    # Terminate wandb run if it was initialized and not already finished
    if wandb.run is not None:
        try:
            if wandb.run.id:
                wandb.finish()
        except Exception as e:
            print(f"Error finishing W&B run: {e}")

Loaded config from ../configs/config.yaml
Creating dummy Stage 1 projector weights: /workspace/llava/output/models/stage1_projector.pth
Initializing Projector: Input Dim=1024, Output Dim=4096
Creating dummy Stage 2 JSONL: /workspace/llava/data/llava_instruct_150k/llava_v1_5_mix665k.jsonl
--- Setting up Stage 2 Learner ---
Loading Stage 2 DataLoaders...
Creating Stage 2 DataLoaders with batch size: 4, num_workers: 4
Loading Stage 2 items from: /workspace/llava/data/llava_instruct_150k/llava_v1_5_mix665k.jsonl
Assuming images relative to: /workspace/llava/data
Found 2 samples for Stage 2.
Stage 2 DataLoaders created successfully.
DataLoaders loaded. Train samples: 1, Valid samples: 1
Instantiating BaselineLLaVAModel for Stage 2...
Initializing Projector: Input Dim=1024, Output Dim=4096
Loading Vision Tower: openai/clip-vit-large-patch14-336...


Some weights of the model checkpoint at openai/clip-vit-large-patch14-336 were not used when initializing CLIPVisionModel: ['text_model.encoder.layers.0.layer_norm1.bias', 'text_model.encoder.layers.0.layer_norm1.weight', 'text_model.encoder.layers.0.layer_norm2.bias', 'text_model.encoder.layers.0.layer_norm2.weight', 'text_model.encoder.layers.0.mlp.fc1.bias', 'text_model.encoder.layers.0.mlp.fc1.weight', 'text_model.encoder.layers.0.mlp.fc2.bias', 'text_model.encoder.layers.0.mlp.fc2.weight', 'text_model.encoder.layers.0.self_attn.k_proj.bias', 'text_model.encoder.layers.0.self_attn.k_proj.weight', 'text_model.encoder.layers.0.self_attn.out_proj.bias', 'text_model.encoder.layers.0.self_attn.out_proj.weight', 'text_model.encoder.layers.0.self_attn.q_proj.bias', 'text_model.encoder.layers.0.self_attn.q_proj.weight', 'text_model.encoder.layers.0.self_attn.v_proj.bias', 'text_model.encoder.layers.0.self_attn.v_proj.weight', 'text_model.encoder.layers.1.layer_norm1.bias', 'text_model.enco

Vision Tower loaded successfully.
Vision Tower weights frozen.
Loading Language Model: lmsys/vicuna-7b-v1.5...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Language Model loaded successfully.
Base Language Model weights frozen.
LLM embedding size already matches tokenizer size. No resizing needed.
Activation checkpointing is disabled in the configuration.
Model instantiated successfully.
Attempting to load Stage 1 projector weights from: /workspace/llava/output/models/stage1_projector.pth
Successfully loaded Stage 1 projector weights from /workspace/llava/output/models/stage1_projector.pth
Loss function: LLaVALoss
Optimizer: AdamW (lr=2e-05, wd=0.0)
Parameter splitter: llava_stage2_splitter
W&B enabled in config, but entity not set or default. Skipping W&B init and callback.
SaveModelCallback is configured but commented out. Manual saving of adapters/projector in train_stage2 is preferred for LoRA.
Added GradientAccumulation callback with 4 steps.
Added MixedPrecision callback.
Applying Stage 2 splitter...
  - Collecting projector parameters (trainable).
  - Collecting vision tower parameters (frozen).
Splitter created groups: Trainable (

## Step 4.5: Implement Stage 2 Training Script

This function orchestrates the Stage 2 training loop, including saving the final weights (projector + LoRA adapters).

In [None]:
#| export
def _save_stage2_weights(model: nn.Module, config: dict, models_dir: Path, weights_stem: str) -> None:
    """Save the trained projector and, when LoRA is active, the LoRA adapters.

    Args:
        model: The trained LLaVA model (expects `projector` / `language_model` attributes).
        config: The main configuration dictionary (reads `model.peft.use_lora`).
        models_dir: Directory receiving the saved artifacts.
        weights_stem: Filename stem derived from `paths.stage2_model_weights`.
    """
    # 1. Save Projector Weights
    projector_save_path = models_dir / (weights_stem + "_projector_final.pth")
    print(f"Saving final projector weights to: {projector_save_path}")
    projector = getattr(model, 'projector', None)
    if projector is not None:
        torch.save(projector.state_dict(), projector_save_path)
        print("Projector weights saved.")
    else:
        print("Warning: Cannot save projector weights, model has no projector.")

    # 2. Save LoRA Adapters (if LoRA was used)
    use_lora = config.get('model', {}).get('peft', {}).get('use_lora', False)
    language_model = getattr(model, 'language_model', None)
    if _peft_available and use_lora and isinstance(language_model, PeftModel):
        lora_save_dir = models_dir / (weights_stem + "_lora_adapters")
        print(f"Saving LoRA adapters to: {lora_save_dir}")
        try:
            language_model.save_pretrained(lora_save_dir)
            print("LoRA adapters saved successfully.")
        except Exception as e:
            # Best-effort: the projector is already on disk, so report and continue.
            print(f"Error saving LoRA adapters: {e}")
    elif not use_lora:
        print("LoRA was not enabled. Only projector weights saved explicitly.")
    else:  # LoRA enabled in config but PEFT not available or not applied
        print("LoRA configured but not applied (PEFT library issue?). Cannot save adapters.")


def _release_learner_memory(learner) -> None:
    """Best-effort release of a fastai Learner's model and optimizer memory.

    Moves the model's known sub-modules to CPU, drops the learner's references
    to the model and optimizer, then runs the garbage collector and clears the
    CUDA cache.
    """
    model = getattr(learner, 'model', None)
    if model is None:
        return
    for part_name in ('vision_tower', 'language_model', 'projector'):
        part = getattr(model, part_name, None)
        if part is not None:
            part.to('cpu')
    del learner.model
    learner.opt = None  # optimizer state may otherwise keep device memory alive
    # `destroy()` is not part of every fastai Learner version; call it only if present.
    destroy = getattr(learner, 'destroy', None)
    if callable(destroy):
        destroy()
    del model, part
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print("Cleaned up learner and model memory.")


def _finish_wandb_run(exit_code: int | None = None) -> None:
    """Finish the active W&B run, if any. Errors are reported, not raised."""
    if wandb.run is None:
        return
    try:
        if wandb.run.id:  # Check if run is still active
            wandb.finish(exit_code=exit_code)
            print("Finished W&B run.")
    except Exception as e:
        print(f"Error finishing W&B run: {e}")


def train_stage2(config_path: str | Path):
    """Loads config, sets up Stage 2 learner, runs training, and saves weights.

    Saves the projector weights and LoRA adapter weights (if used) separately
    under `<output_dir>/models/`. Device memory is released and any W&B run is
    finished even when training fails. On failure the run is marked with
    exit code 1 and the exception is re-raised.

    Args:
        config_path: Path to the YAML configuration file.

    Raises:
        KeyError: If `paths.output_dir` or `paths.stage2_model_weights` is missing.
        Exception: Any error raised during learner setup, training or saving is re-raised.
    """
    print("--- Starting Stage 2 Training --- ")
    print(f"Loading configuration from: {config_path}")
    config = load_config(config_path)
    output_dir = Path(config['paths']['output_dir'])
    models_dir = output_dir / 'models'
    models_dir.mkdir(parents=True, exist_ok=True)
    # Resolved before training so a missing key fails fast instead of after all epochs.
    weights_stem = Path(config['paths']['stage2_model_weights']).stem

    learner = None  # Initialize learner to None for finally block
    failed = False
    try:
        # --- Get Learner (handles loading Stage 1 weights, LoRA setup, etc.) ---
        learner = get_stage2_learner(config)

        # --- Start Training ---
        training_cfg = config.get('training', {})
        epochs = training_cfg.get('num_epochs_stage2', 3)
        lr = training_cfg.get('learning_rate_stage2', 2e-5)
        print(f"Starting training for {epochs} epochs with max_lr={lr}...")

        learner.fit_one_cycle(epochs, lr_max=lr)

        print("Training finished.")

        # --- Save final trained weights ---
        _save_stage2_weights(learner.model, config, models_dir, weights_stem)

    except Exception as e:
        failed = True
        print(f"An error occurred during Stage 2 training: {e}")
        import traceback
        traceback.print_exc()
        raise  # bare raise keeps the original traceback
    finally:
        # Cleanup must never replace the original exception or skip finishing W&B.
        if learner is not None:
            try:
                _release_learner_memory(learner)
            except Exception as cleanup_err:
                print(f"Warning: Failed to fully release learner memory: {cleanup_err}")
            learner = None
        _finish_wandb_run(exit_code=1 if failed else None)

    print("--- Stage 2 Training Complete --- ")

In [None]:
# show_doc(train_stage2) # Omitted for script execution

In [None]:
#| export
# Command-line execution block for Stage 2
if __name__ == "__main__" and "get_ipython" not in locals():
    arg_parser = argparse.ArgumentParser(description="Run LLaVA Stage 2 Training")
    # Intended to be launched from the project root.
    arg_parser.add_argument(
        "--config",
        type=str,
        default="configs/config.yaml",
        help="Path to the configuration YAML file (relative to project root or absolute).",
    )
    cli_args = arg_parser.parse_args()

    # pathlib semantics: joining an absolute path discards `project_root`,
    # so absolute config paths pass through unchanged.
    config_file_path = project_root / cli_args.config
    if config_file_path.is_file():
        try:
            train_stage2(config_path=config_file_path)
        except NotImplementedError as exc:
            print(f"Exiting: {exc}")
            sys.exit(0)
        except Exception as exc:
            print(f"Stage 2 training setup or execution failed: {exc}")
            import traceback
            traceback.print_exc()
            sys.exit(1)
    else:
        print(f"Error: Config file not found at {config_file_path}")
        sys.exit(1)

In [None]:
#| hide
import nbdev
nbdev.nbdev_export()  # write the notebooks' `#| export` cells out to the library modules