# Stage 1 Training: Projector Pre-training

> Sets up and runs the first stage of LLaVA training, focusing on pre-training the projector module.

In [1]:
#| default_exp training.stage1

In [2]:
#| hide
from nbdev.showdoc import *

In [3]:
#| export
import sys
from pathlib import Path
import os
import gc # For memory cleanup
import argparse # For command-line execution


os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Assumes the notebook is run from the project root or one level down (e.g., nbs/)
# Navigate up to the project root (where settings.ini or .git likely exists)
project_root = Path(os.getcwd())
# Simple check: If settings.ini is not in cwd, assume we are in nbs/ and go up one level
if not (project_root / 'settings.ini').exists() and (project_root.parent / 'settings.ini').exists():
    project_root = project_root.parent

project_root_str = str(project_root.resolve())

if project_root_str not in sys.path:
    print(f"Adding project root to sys.path: {project_root_str}")
    sys.path.insert(0, project_root_str)
else:
    print(f"Project root already in sys.path: {project_root_str}")

Adding project root to sys.path: /workspace/llava


In [4]:
#| export
import torch
import warnings

from fastai.learner import Learner
from fastai.vision.all import * # For splitter

from fastai.callback.wandb import WandbCallback
from fastai.callback.schedule import fit_one_cycle
from fastai.callback.training import GradientAccumulation # Import GradientAccumulation
from fastai.callback.fp16 import MixedPrecision # Import MixedPrecision
from fastai.data.core import DataLoaders
from functools import partial
import wandb # Import wandb directly for cleanup
import json # For dummy data creation
import PIL.Image # For dummy data creation

from llava.utils import load_config, init_wandb
from llava.data.loading import get_stage1_dataloaders
from llava.model.baseline import BaselineLLaVAModel
from llava.training.core import LLaVALoss, LLaVAMixedPrecision, extract_loss_from_output, SafeGradientAccumulation

Project root already in sys.path: /workspace/llava
Project root already in sys.path: /workspace/llava
Loaded config from configs/config.yaml


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Successfully loaded CLIP image processor for: openai/clip-vit-large-patch14-336
CLIP Mean: [0.48145466, 0.4578275, 0.40821073]
CLIP Std: [0.26862954, 0.26130258, 0.27577711]
Fastai Normalize Transform: Normalize -- {'mean': tensor([[[[0.4815]],

         [[0.4578]],

         [[0.4082]]]], device='cuda:0'), 'std': tensor([[[[0.2686]],

         [[0.2613]],

         [[0.2758]]]], device='cuda:0'), 'axes': (0, 2, 3)}
(enc:2,dec:2)
Successfully loaded tokenizer for: lmsys/vicuna-7b-v1.5
Adding special token <image> to tokenizer.
Added 1 token(s). New vocab size: 32001
Using token ID for <image>: 32000
Tokenizer already has pad token: <unk> (ID: 0)
LLaVADataBlockStage1 defined.
LLaVADataBlockStage2 defined.
Project root already in sys.path: /workspace/llava
Project root already in sys.path: /workspace/llava


## Stage 1 Splitter Function

We need a custom splitter function for the fastai `Learner`. In Stage 1, we only want to train the `projector` module, keeping the `vision_tower` and `language_model` frozen.

In [5]:
#| export
def llava_stage1_splitter(model: BaselineLLaVAModel):
    """Splits the `BaselineLLaVAModel` parameters for Stage 1 training.
    
    Only the parameters of the `projector` module are marked as trainable for the optimizer.
    Vision tower is frozen. LLM's LoRA adapters (if any) remain trainable but won't be
    passed to the optimizer in Stage 1. Base LLM weights are frozen (by PEFT or explicitly).
    
    Args:
        model: An instance of `BaselineLLaVAModel`.
        
    Returns:
        A list containing a single parameter group for the projector.
    """
    if not hasattr(model, 'projector') or model.projector is None:
        raise AttributeError("Model does not have a 'projector' attribute or it is None.")
        
    print("Applying Stage 1 splitter: Ensuring only projector parameters are passed to optimizer.")
    
    trainable_params = []
    # Ensure projector parameters are trainable
    if hasattr(model, 'projector') and model.projector is not None:
        print("  - Setting projector parameters to require_grad=True.")
        for p in model.projector.parameters():
            p.requires_grad = True
        trainable_params.extend(list(model.projector.parameters()))
    
    # Ensure vision tower is frozen
    if hasattr(model, 'vision_tower') and model.vision_tower is not None:
        print("  - Setting vision_tower parameters to require_grad=False.")
        for p in model.vision_tower.parameters():
            p.requires_grad = False
            
    # For the language model in Stage 1:
    # - If PEFT/LoRA is used, get_peft_model has already frozen base weights and made adapters trainable.
    #   We do NOT want to turn off grads for LoRA adapters here.
    # - If no PEFT/LoRA, the base LLM should be frozen. This is handled in model.__init__
    #   by the `self.language_model.requires_grad_(False)` call if not QLoRA/LoRA.
    # So, the splitter's main job for Stage 1 regarding LLM is *not* to modify its requires_grad state
    # but to ensure only projector params are given to the optimizer.
    # The `model.__init__` handles initial freezing.

    if not trainable_params:
         raise ValueError("Splitter function resulted in no trainable parameters for Stage 1 (projector).")
    
    print(f"  - Stage 1 Splitter will provide {len(trainable_params)} projector parameters to the optimizer.")
    # The optimizer will only receive these parameters.
    return trainable_params # This was already correct, the issue was modifying LLM params.

In [None]:
show_doc(llava_stage1_splitter)

---

### llava_stage1_splitter

>      llava_stage1_splitter (model:llava.model.baseline.BaselineLLaVAModel)

*Splits the `BaselineLLaVAModel` parameters for Stage 1 training.
    
    Only the parameters of the `projector` module are marked as trainable.
    The `vision_tower` and `language_model` parameters will remain frozen.
    
    Args:
        model: An instance of `BaselineLLaVAModel`.
        
    Returns:
        A list containing a single parameter group for the projector.*

## Step 3.2 & 3.3: Setup Learner Configuration (Stage 1) with Optimization Callbacks

This function sets up the `Learner` object for Stage 1, including the model, data, loss function, optimizer, splitter, and standard callbacks (W&B, SaveModel). It now also includes the `GradientAccumulation` and `MixedPrecision` callbacks based on configuration.

In [6]:
#| export
def get_stage1_learner(config: dict) -> Learner:
    """Configures and returns a fastai Learner for Stage 1 projector pre-training.
       Includes optimization callbacks (GradientAccumulation, MixedPrecision) based on config.

    Args:
        config: The main configuration dictionary.

    Returns:
        A configured fastai Learner instance for Stage 1.

    Raises:
        RuntimeError: If DataLoaders or Model instantiation fails.
        FileNotFoundError: If specified data paths in config are incorrect.
        AttributeError: If the model is missing expected components (e.g., projector).
    """
    print("--- Setting up Stage 1 Learner ---")
    
    # 1. Load DataLoaders
    print("Loading Stage 1 DataLoaders...")
    try:
        dls = get_stage1_dataloaders(config)
    except (FileNotFoundError, Exception) as e:
        print(f"Error loading DataLoaders: {e}")
        raise RuntimeError("Failed to create Stage 1 DataLoaders. Check config paths and data availability.") from e
    if not dls:
        raise RuntimeError("Stage 1 DataLoaders object is None.")
    print(f"DataLoaders loaded. Train samples: {len(dls.train_ds)}, Valid samples: {len(dls.valid_ds)}")

    # 2. Instantiate Model
    print("Instantiating BaselineLLaVAModel...")
    try:
        model = BaselineLLaVAModel(config)
        # Ensure model components loaded successfully
        if model.vision_tower is None or model.language_model is None or model.projector is None:
             raise RuntimeError("BaselineLLaVAModel initialization failed: one or more components are None.")
        print("Model instantiated successfully.")
    except Exception as e:
        print(f"Error instantiating BaselineLLaVAModel: {e}")
        raise RuntimeError("Failed to instantiate baseline model.") from e

    # 3. Define Loss Function
    # loss_func = LLaVALoss()
    loss_func = extract_loss_from_output # Use the extractor function
    print(f"Loss function: {loss_func.__name__}")
    print(f"Loss function: {type(loss_func).__name__}")

    # 4. Define Optimizer
    # AdamW is generally preferred for transformer models
    lr = config.get('training', {}).get('learning_rate_stage1', 1e-4)
    wd = config.get('training', {}).get('weight_decay', 0.0)
    
    
    opt_func = partial(Adam, lr=lr, wd=wd, eps=1e-8) # Added eps for numerical stability
    print(f"Optimizer: Adam (lr={lr}, wd={wd})")

    # opt_func = AdamW # Pass the optimizer class directly
    # print(f"Optimizer: AdamW (lr will be set by Learner/fit_one_cycle, default wd={wd})") # 
    
    # 5. Define Splitter
    splitter = llava_stage1_splitter
    print(f"Parameter splitter: {splitter.__name__}")

    # 6. Define Callbacks
    cbs = []
    # Weights & Biases Logging
    if config.get('logging', {}).get('wandb', {}).get('enabled', False):
        # Initialize W&B run here before creating WandbCallback
        project_name = config.get('logging', {}).get('wandb', {}).get('project', 'llava-adaptive-patching')
        entity = config.get('logging', {}).get('wandb', {}).get('entity') # Optional
        run_name_prefix = config.get('logging', {}).get('wandb', {}).get('run_name_prefix', 'stage1')
        
        # Create a unique run name (init_wandb will create one if not passed, but we can pre-define)
        run_name = f"{run_name_prefix}_{Path(config['paths']['stage1_projector_weights']).stem}_{wandb.util.generate_id()}"
        
        # Init W&B Run
        init_wandb(config, job_type="stage1-training", run_name=run_name)
        
        # Add W&B Callback
        cbs.append(WandbCallback(log_preds=False, log_model=False)) # Don't log model via W&B callback, use SaveModelCallback
        print("Added WandbCallback.")
        
    # Model Saving (Only Projector Weights)
    output_dir = Path(config['paths']['output_dir'])
    output_dir.mkdir(parents=True, exist_ok=True) # Ensure output directory exists
    projector_weights_fname = Path(config['paths']['stage1_projector_weights']).stem # Get filename without extension
    save_cb = SaveModelCallback(
        monitor='valid_loss', 
        min_delta=0.001, # Avoid saving too often for tiny improvements
        fname=projector_weights_fname, # Saves as f'{fname}.pth' in learner.path/models/
        every_epoch=False, # Save only best based on monitor
        with_opt=False, # Don't save optimizer state for projector
        reset_on_fit=True # Ensures it checks from the start of training
    )
    cbs.append(save_cb)
    print(f"Added SaveModelCallback (saves best projector weights based on valid_loss to {output_dir/'models'/f'{projector_weights_fname}.pth'})")

    # --- Add Optimization Callbacks (Step 3.3 Implementation) --- 
    grad_accum_steps = config.get('training', {}).get('gradient_accumulation_steps', 1)
    if grad_accum_steps > 1:
        cbs.append(SafeGradientAccumulation(grad_accum_steps))
        print(f"Added SafeGradientAccumulation callback with {grad_accum_steps} steps.")
        
        # cbs.append(GradientAccumulation(grad_accum_steps))
        # print(f"Added GradientAccumulation callback with {grad_accum_steps} steps.")
    
    use_mixed_precision = config.get('training', {}).get('use_mixed_precision', False)
    qlora_enabled = config.get('model', {}).get('quantization', {}).get('load_in_4bit', False)
    if use_mixed_precision and not qlora_enabled: # Only add if explicitly enabled AND QLoRA is OFF
        cbs.append(LLaVAMixedPrecision()) # Or your current custom MixedPrecision
        print("Added MixedPrecision callback (QLoRA is disabled).")
    elif use_mixed_precision and qlora_enabled:
        print("QLoRA is enabled, MixedPrecision callback will be skipped as QLoRA handles its own precision.")
    
    # if use_mixed_precision:
    #     cbs.append(LLaVAMixedPrecision())
    #     print("Added MixedPrecision callback.")
    
    # --------------------------------------------------------------
    
    # 7. Create Learner
    try:
        learner = Learner(
            dls=dls,
            model=model,
            loss_func=loss_func,
            opt_func=opt_func,
            splitter=splitter,
            cbs=cbs,
            path=output_dir, # Set Learner path for saving models
            train_bn=False, # Avoid issues with frozen batch norm layers in LLM/Vision Tower
            # wd=wd
        )
            
    except Exception as e:
        print(f"Error creating Learner: {e}")
        # Clean up wandb run if initialized
        if wandb.run is not None:
            wandb.finish(exit_code=1)
            print("Finished W&B run due to error during Learner creation.")
        raise RuntimeError("Failed to create the Learner object.") from e
    
    print("--- Stage 1 Learner Setup Complete ---")
    return learner

In [11]:
show_doc(get_stage1_learner)

---

### get_stage1_learner

>      get_stage1_learner (config:dict)

*Configures and returns a fastai Learner for Stage 1 projector pre-training.
   Includes optimization callbacks (GradientAccumulation, MixedPrecision) based on config.

Args:
    config: The main configuration dictionary.

Returns:
    A configured fastai Learner instance for Stage 1.

Raises:
    RuntimeError: If DataLoaders or Model instantiation fails.
    FileNotFoundError: If specified data paths in config are incorrect.
    AttributeError: If the model is missing expected components (e.g., projector).*

#### Example Usage & Test (Learner Configuration)

In [7]:
#| test 
import gc

from fastai.callback.wandb import WandbCallback
from fastai.callback.training import GradientAccumulation
from fastai.callback.fp16 import MixedPrecision

# Add PEFT import block for _peft_available and PeftModel, similar to 20_model_baseline.ipynb
try:
    from peft import PeftModel # PeftModel is used for isinstance check
    _peft_available = True
except ImportError:
    print("Warning: peft library not found in test cell. LoRA-related checks might be affected.")
    PeftModel = None 
    _peft_available = False


try:
    # Load config
    config_path = '../configs/config.yaml'
    config = load_config(config_path)
    print(f"Loaded config from {config_path}")
    
    # --- Test Setup --- 
    # Create dummy data files and directories if they don't exist
    data_base = Path(config['paths']['data_base'])
    stage1_json_rel = Path(config['paths']['stage1_data'])
    stage1_img_rel = Path(config['paths']['stage1_images'])
    stage1_json_path = data_base / stage1_json_rel
    stage1_img_path = data_base / stage1_img_rel
    
    stage1_json_path.parent.mkdir(parents=True, exist_ok=True)
    stage1_img_path.mkdir(parents=True, exist_ok=True)
    
    if not stage1_json_path.exists() or stage1_json_path.stat().st_size < 10: # Check if exists and not empty
        print(f"Creating dummy Stage 1 JSON: {stage1_json_path}")
        # For parse_llava_jsonl which now expects JSON array or JSONL:
        # We'll create JSONL for this dummy data as it's simpler line-by-line.
        dummy_json_lines_content = [
            json.dumps({"id": "s1_001", "image": "dummy_img1.jpg", "conversations": [{"from": "human", "value": "<image>"}, {"from": "gpt", "value": "Dummy caption 1."}]}),
            json.dumps({"id": "s1_002", "image": "dummy_img2.png", "conversations": [{"from": "human", "value": "<image>Describe"}, {"from": "gpt", "value": "Dummy caption 2."}]}),
            json.dumps({"id": "s1_003", "image": "dummy_img1.jpg", "conversations": [{"from": "human", "value": "<image>"}, {"from": "gpt", "value": "Dummy cap 3."}]}),
            json.dumps({"id": "s1_004", "image": "dummy_img2.png", "conversations": [{"from": "human", "value": "<image>"}, {"from": "gpt", "value": "Dummy cap 4."}]})
        ]
        with open(stage1_json_path, 'w') as f:
            for line in dummy_json_lines_content:
                f.write(line + '\n') # Write as JSON Lines
    
    # Create dummy images (if they don't exist)
    try:
        img1_path = stage1_img_path / 'dummy_img1.jpg'
        img2_path = stage1_img_path / 'dummy_img2.png'
        if not img1_path.exists():
            PIL.Image.new('RGB', (60, 30), color = 'red').save(img1_path)
            print(f"Created dummy image: {img1_path}")
        if not img2_path.exists():
            PIL.Image.new('RGB', (60, 30), color = 'green').save(img2_path)
            print(f"Created dummy image: {img2_path}")
    except Exception as e:
        print(f"Warning: Could not create dummy image files: {e}")

    # Ensure output dir exists
    Path(config['paths']['output_dir']).mkdir(parents=True, exist_ok=True)
    (Path(config['paths']['output_dir']) / 'models').mkdir(parents=True, exist_ok=True) # for SaveModelCallback
    # -----------------
    
    # Check if W&B is enabled and potentially skip if entity isn't set
    wandb_enabled_original = config.get('logging', {}).get('wandb', {}).get('enabled', False) # Store original
    wandb_entity = config.get('logging', {}).get('wandb', {}).get('entity')
    if wandb_enabled_original and (wandb_entity is None or 'your_wandb_entity' in str(wandb_entity)): # Use str for safety
        print("Warning: W&B is enabled but entity is not set or is default. Disabling W&B for this test.")
        if 'logging' not in config: config['logging'] = {}
        if 'wandb' not in config['logging']: config['logging']['wandb'] = {}
        config['logging']['wandb']['enabled'] = False


    # Create learner
    stage1_learner = get_stage1_learner(config)
    print("Stage 1 Learner created successfully.")

    # Perform basic checks on the learner
    assert isinstance(stage1_learner, Learner)
    assert isinstance(stage1_learner.dls, DataLoaders)
    assert isinstance(stage1_learner.model, BaselineLLaVAModel)
    assert isinstance(stage1_learner.loss_func, LLaVALoss)
    
    # Initialize the optimizer to check param_groups
    stage1_learner.create_opt()
    assert stage1_learner.opt is not None, "Optimizer was not created."
    assert len(stage1_learner.opt.param_groups) == 1 # Only one group (projector) should be trainable
    
    # Check if only projector parameters are in the trainable group
    projector_params_set = set(stage1_learner.model.projector.parameters())
    opt_params_set = set(stage1_learner.opt.param_groups[0]['params'])
    assert opt_params_set == projector_params_set, "Optimizer parameter group does not match projector parameters."
    
    # Check parameter freezing status
    proj_params_count = 0
    frozen_params_count = 0
    for name, param in stage1_learner.model.named_parameters():
        if 'projector' in name:
            assert param.requires_grad == True, f"Projector parameter {name} is frozen but should be trainable."
            proj_params_count += param.numel()
        elif 'vision_tower' in name or 'language_model' in name:
             # If LoRA is applied to language_model, some sub-params might be trainable
             # This check assumes base LLM and vision tower are frozen for Stage 1
             is_lora_param = False
             if _peft_available and PeftModel is not None: # Check if PeftModel is defined
                 if 'language_model' in name and isinstance(stage1_learner.model.language_model, PeftModel):
                     # A more robust way to check if a param belongs to LoRA is by checking its name or if it requires_grad
                     # For Stage 1, the entire LLM (even if PeftModel wrapped) should have its base weights frozen.
                     # LoRA adapters themselves require_grad=True, but we are checking requires_grad == False for non-projector parts.
                     # This means if LoRA is active (it shouldn't be for Stage 1 logic), this assertion might be tricky.
                     # However, `llava_stage1_splitter` explicitly sets requires_grad=False for language_model and vision_tower.
                     pass # Let the main assertion handle it.

             # This assertion should hold because llava_stage1_splitter freezes vision_tower and language_model
             assert param.requires_grad == False, f"Parameter {name} should be frozen but is not. Its requires_grad is {param.requires_grad}."
             frozen_params_count += param.numel()


    print(f"Checked parameter freezing: {proj_params_count} projector params trainable, {frozen_params_count} other non-projector params frozen.")

    # Check callbacks
    has_save_cb = any(isinstance(cb, SaveModelCallback) for cb in stage1_learner.cbs)
    assert has_save_cb, "SaveModelCallback not found in learner callbacks."
    
    expect_grad_accum = config.get('training', {}).get('gradient_accumulation_steps', 1) > 1
    has_grad_accum = any(isinstance(cb, GradientAccumulation) for cb in stage1_learner.cbs)
    assert has_grad_accum == expect_grad_accum, f"GradientAccumulation presence mismatch (Expected: {expect_grad_accum}, Found: {has_grad_accum})"

    expect_mixed_precision = config.get('training', {}).get('use_mixed_precision', False)
    has_mixed_precision = any(isinstance(cb, MixedPrecision) for cb in stage1_learner.cbs)
    assert has_mixed_precision == expect_mixed_precision, f"MixedPrecision presence mismatch (Expected: {expect_mixed_precision}, Found: {has_mixed_precision})"

    expect_wandb_cb = config.get('logging', {}).get('wandb', {}).get('enabled', False)
    has_wandb_cb = any(isinstance(cb, WandbCallback) for cb in stage1_learner.cbs)
    assert has_wandb_cb == expect_wandb_cb, f"WandbCallback presence mismatch (Expected: {expect_wandb_cb}, Found: {has_wandb_cb})"
    
    print("Callback checks passed.")
    
    print("\nLearner summary:")
    stage1_learner.summary() # This will also create opt if not already created
    
    print("\nLearner setup test passed.")

except FileNotFoundError as e:
    print(f"Skipping test: FileNotFoundError - {e}")
    print("Ensure config, data, and model paths are correct and accessible.")
except ImportError as e:
    print(f"Skipping test: ImportError - {e}")
    print("Ensure all required libraries (transformers, fastai, etc.) are installed.")
except Exception as e:
    import traceback
    print(f"An error occurred during learner setup test: {e}")
    traceback.print_exc()
finally:
    # Restore original W&B config if modified for test
    if 'wandb_enabled_original' in locals() and 'logging' in config and 'wandb' in config['logging']:
        config['logging']['wandb']['enabled'] = wandb_enabled_original

    if 'stage1_learner' in locals() and stage1_learner is not None:
        if hasattr(stage1_learner, 'model') and stage1_learner.model is not None:
            if hasattr(stage1_learner.model, 'vision_tower') and stage1_learner.model.vision_tower is not None:
                stage1_learner.model.vision_tower.to('cpu')
            if hasattr(stage1_learner.model, 'language_model') and stage1_learner.model.language_model is not None:
                stage1_learner.model.language_model.to('cpu')
            if hasattr(stage1_learner.model, 'projector') and stage1_learner.model.projector is not None:
                stage1_learner.model.projector.to('cpu')
            del stage1_learner.model
        
        if hasattr(stage1_learner, 'opt') and stage1_learner.opt is not None:
            stage1_learner.opt.zero_grad() 
            del stage1_learner.opt
            stage1_learner.opt = None

        if hasattr(stage1_learner, 'cbs'):
            stage1_learner.cbs = [] 

        del stage1_learner
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print("Cleaned up stage1_learner and model")
    
    if wandb.run is not None:
        try:
             if wandb.run.id: 
                  wandb.finish()
                  print("Finished W&B run.")
        except Exception as e:
             print(f"Error finishing W&B run: {e}")

Loaded config from ../configs/config.yaml
--- Setting up Stage 1 Learner ---
Loading Stage 1 DataLoaders...
Creating Stage 1 DataLoaders with batch size: 2, num_workers: 16
Loading Training Stage '1' items from: /workspace/llava/data/llava_pretrain/llava_pretrain.jsonl
Assuming image paths relative to: /workspace/llava/data/llava_pretrain/images
Found 595375 samples for Training Stage '1'.
DataLoaders created successfully.
DataLoaders loaded. Train samples: 589422, Valid samples: 5953
Instantiating BaselineLLaVAModel...
Initializing Projector: Input Dim=1024, Output Dim=4096
Loading Vision Tower: openai/clip-vit-large-patch14-336...
Vision Tower loaded successfully.
Vision Tower weights frozen.
Loading Language Model: lmsys/vicuna-7b-v1.5...
QLoRA enabled: Loading LLM in 4-bit with compute dtype torch.float16.
  Setting device_map to {'': 0} for QLoRA.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Language Model loaded successfully.
Applying LoRA...
LoRA applied. Config: r=8, alpha=16, dropout=0.05, modules=['q_proj', 'v_proj']
trainable params: 4,194,304 || all params: 6,742,609,920 || trainable%: 0.0622
Resizing LLM token embeddings from 32000 to 32001 (tokenizer size)...


The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


LLM token embeddings resized.
Activation checkpointing is disabled in the configuration.
Moving vision_tower to cuda:0
Moving projector to cuda:0
Language model's first parameter is on device: cuda:0
Model instantiated successfully.
LLaVALoss initialized, ignoring index: -100
Loss function: LLaVALoss
Optimizer: AdamW (lr=1e-4, wd=0.0)
Parameter splitter: llava_stage1_splitter
Added SaveModelCallback (saves best projector weights based on valid_loss to /workspace/llava/output/models/stage1_projector.pth)
Added GradientAccumulation callback with 4 steps.
Added MixedPrecision callback.
--- Stage 1 Learner Setup Complete ---
Stage 1 Learner created successfully.
Applying Stage 1 splitter: Training only the projector.
Checked parameter freezing: 20979712 projector params trainable, 3808122880 other non-projector params frozen.
Callback checks passed.

Learner summary:


  self.autocast,self.learn.scaler,self.scales = autocast(dtype=dtype),GradScaler(**self.kwargs),L()
  self.autocast,self.learn.scaler,self.scales = autocast(dtype=dtype),GradScaler(**self.kwargs),L()



Learner setup test passed.
Cleaned up stage1_learner and model


## Step 3.4: Implement Stage 1 Training Script

This section defines the `train_stage1` function that orchestrates the actual training loop using the configured learner.

In [7]:
#| export
def train_stage1(config_path: str | Path):
    """Loads config, sets up Stage 1 learner, and runs training.
    
    Args:
        config_path: Path to the YAML configuration file.
    """
    print(f"--- Starting Stage 1 Training --- ")
    print(f"Loading configuration from: {config_path}")
    config = load_config(config_path)
    
    learner = None # Initialize learner to None for finally block
    try:
        # --- Get Learner (including optimization callbacks) --- 
        learner = get_stage1_learner(config)
        
        # --- Start Training --- 
        epochs = config.get('training', {}).get('num_epochs_stage1', 1)
        lr = config.get('training', {}).get('learning_rate_stage1', 1e-4)
        print(f"Starting training for {epochs} epochs with max_lr={lr}...")
        
        # Use fit_one_cycle (common practice)
        # You could also use learner.fit(epochs, lr=lr) or other fine-tuning methods
        learner.fit(epochs, lr=lr)
        
        print("Training finished.")
        
        # --- Save final projector weights explicitly --- 
        # SaveModelCallback saves the *best* model during training.
        # It might be useful to save the *final* projector state as well.
        output_dir = Path(config['paths']['output_dir'])
        final_projector_filename = Path(config['paths']['stage1_projector_weights']).stem + "_final.pth"
        final_save_path = output_dir / 'models' / final_projector_filename 
        print(f"Saving final projector state to: {final_save_path}")
        # Save only the projector's state_dict
        torch.save(learner.model.projector.state_dict(), final_save_path)
        print("Final projector weights saved.")

    except Exception as e:
        print(f"An error occurred during Stage 1 training: {e}")
        import traceback
        traceback.print_exc()
        # Potentially re-raise or handle cleanup
        raise e
    finally:
        # Clean up memory
        if learner is not None and hasattr(learner, 'model') and learner.model is not None:
            if hasattr(learner.model, 'vision_tower') and learner.model.vision_tower is not None:
                learner.model.vision_tower.to('cpu')
            if hasattr(learner.model, 'language_model') and learner.model.language_model is not None:
                learner.model.language_model.to('cpu')
            if hasattr(learner.model, 'projector') and learner.model.projector is not None:
                learner.model.projector.to('cpu')
            del learner.model
            learner.destroy() # Release learner resources
            del learner
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            print("Cleaned up learner and model memory.")
            
        # Ensure W&B run is finished if it was started
        if wandb.run is not None:
            try:
                if wandb.run.id: # Check if run is still active
                    wandb.finish()
                    print("Finished W&B run.")
            except Exception as e:
                 print(f"Error finishing W&B run: {e}")
            
    print(f"--- Stage 1 Training Complete --- ")

#### Example Usage (Execution)

In [9]:
#| export
# Add command-line execution block
if __name__ == "__main__" and "get_ipython" not in locals():
    parser = argparse.ArgumentParser(description="Run LLaVA Stage 1 Training")
    parser.add_argument("--config", type=str, default="../configs/config.yaml", 
                        help="Path to the configuration YAML file.")
    args = parser.parse_args()
    
    config_path = Path(args.config)
    if not config_path.is_file():
        print(f"Error: Config file not found at {config_path}")
        sys.exit(1)
        
    # Ensure the config path is absolute or relative to the script's execution dir
    # If running from the project root, '../configs/config.yaml' works.
    # If running from nbs/, 'configs/config.yaml' might be needed depending on cwd.
    # Using absolute paths or paths relative to a known root is safer.
    # Assuming execution from project root or script location within project:
    if not config_path.exists():
         # Try resolving relative to the script file itself if it doesn't exist relative to CWD
         script_dir = Path(__file__).parent.resolve()
         config_path = (script_dir / args.config).resolve()
         if not config_path.exists():
              print(f"Error: Config file not found at specified path or relative to script: {args.config}")
              sys.exit(1)

    try:
        train_stage1(config_path=config_path)
    except Exception as e:
        print(f"Stage 1 training failed: {e}")
        sys.exit(1)

In [8]:
#| hide
# Example of how to run this from within the notebook (for testing purposes)
# Requires dummy data to be set up as in the get_stage1_learner test cell
train_stage1('../configs/config.yaml') # Use '../' if running from nbs/

--- Starting Stage 1 Training --- 
Loading configuration from: ../configs/config.yaml
--- Setting up Stage 1 Learner ---
Loading Stage 1 DataLoaders...
Creating Stage 1 DataLoaders with batch size: 2, num_workers: 16
Loading Training Stage '1' items from: /workspace/llava/data/llava_pretrain/llava_pretrain.jsonl
Assuming image paths relative to: /workspace/llava/data/llava_pretrain/images
Found 595375 samples for Training Stage '1'.
DataLoaders created successfully.
DataLoaders loaded. Train samples: 589422, Valid samples: 5953
Instantiating BaselineLLaVAModel...
Initializing Projector: Input Dim=1024, Output Dim=4096
Loading Vision Tower: openai/clip-vit-large-patch14-336...
Vision Tower loaded successfully.
Vision Tower weights frozen.
Loading Language Model: lmsys/vicuna-7b-v1.5...
QLoRA enabled: Loading LLM in 4-bit with compute dtype torch.float16.
  Setting device_map to {'': 0} for QLoRA.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Language Model loaded successfully.
Applying LoRA...
LoRA applied. Config: r=8, alpha=16, dropout=0.05, modules=['q_proj', 'v_proj']
trainable params: 4,194,304 || all params: 6,742,609,920 || trainable%: 0.0622
Resizing LLM token embeddings from 32000 to 32001 (tokenizer size)...
LLM token embeddings resized.
Activation checkpointing is disabled in the configuration.
Moving vision_tower to cuda:0
Moving projector to cuda:0
Language model's first parameter is on device: cuda:0
Model instantiated successfully.
Loss function: extract_loss_from_output
Loss function: function
Optimizer: Adam (lr=1e-4, wd=0.0)
Parameter splitter: llava_stage1_splitter
Added SaveModelCallback (saves best projector weights based on valid_loss to /workspace/llava/output/models/stage1_projector.pth)
Added SafeGradientAccumulation callback with 16 steps.
--- Stage 1 Learner Setup Complete ---
Starting training for 1 epochs with max_lr=1e-4...
Applying Stage 1 splitter: Ensuring only projector parameters are pass

epoch,train_loss,valid_loss,time


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

An error occurred during Stage 1 training: CUDA out of memory. Tried to allocate 32.00 MiB. GPU 0 has a total capacity of 7.67 GiB of which 23.88 MiB is free. Process 950506 has 180.00 MiB memory in use. Process 1266111 has 180.00 MiB memory in use. Process 1434080 has 7.27 GiB memory in use. Of the allocated memory 6.96 GiB is allocated by PyTorch, and 122.37 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)


Traceback (most recent call last):
  File "/tmp/ipykernel_7741/314466185.py", line 24, in train_stage1
    learner.fit(epochs, lr=lr)
  File "/venv/main/lib/python3.10/site-packages/fastai/learner.py", line 272, in fit
    self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
  File "/venv/main/lib/python3.10/site-packages/fastai/learner.py", line 207, in _with_events
    try: self(f'before_{event_type}');  f()
  File "/venv/main/lib/python3.10/site-packages/fastai/learner.py", line 261, in _do_fit
    self._with_events(self._do_epoch, 'epoch', CancelEpochException)
  File "/venv/main/lib/python3.10/site-packages/fastai/learner.py", line 207, in _with_events
    try: self(f'before_{event_type}');  f()
  File "/venv/main/lib/python3.10/site-packages/fastai/learner.py", line 255, in _do_epoch
    self._do_epoch_train()
  File "/venv/main/lib/python3.10/site-packages/fastai/learner.py", line 247, in _do_epoch_train
    self._with_events(self.all_batches, 'train', C

AttributeError: destroy

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()