# Stage 1 Training: Projector Pre-training

> Sets up and runs the first stage of LLaVA training, focusing on pre-training the projector module.

In [None]:
#| default_exp training.stage1

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import sys
from pathlib import Path
import os

# Assumes the notebook is run from the project root or one level down (e.g., nbs/)
# Navigate up to the project root (where settings.ini or .git likely exists)
project_root = Path(os.getcwd())
# Simple check: If settings.ini is not in cwd, assume we are in nbs/ and go up one level
if not (project_root / 'settings.ini').exists() and (project_root.parent / 'settings.ini').exists():
    project_root = project_root.parent

project_root_str = str(project_root.resolve())

if project_root_str not in sys.path:
    print(f"Adding project root to sys.path: {project_root_str}")
    sys.path.insert(0, project_root_str)
else:
    print(f"Project root already in sys.path: {project_root_str}")

Adding project root to sys.path: /workspace/llava


In [None]:
#| export
import torch
from fastai.learner import Learner
from fastai.optimizer import AdamW
from fastai.callback.wandb import WandbCallback
from fastai.callback.schedule import fit_one_cycle
from fastai.callback.save import SaveModelCallback
from fastai.vision.all import params # For splitter
from fastai.data.core import DataLoaders

from llava.utils import load_config, init_wandb
from llava.data.loading import get_stage1_dataloaders
from llava.model.baseline import BaselineLLaVAModel
from llava.training.core import LLaVALoss

## Stage 1 Splitter Function

We need a custom splitter function for the fastai `Learner`. In Stage 1, we only want to train the `projector` module, keeping the `vision_tower` and `language_model` frozen.

In [None]:
#| export
def llava_stage1_splitter(model: BaselineLLaVAModel):
    """Splits the `BaselineLLaVAModel` parameters for Stage 1 training.
    
    Only the parameters of the `projector` module are marked as trainable.
    The `vision_tower` and `language_model` parameters will remain frozen.
    
    Args:
        model: An instance of `BaselineLLaVAModel`.
        
    Returns:
        A list containing a single parameter group for the projector.
    """
    if not hasattr(model, 'projector') or model.projector is None:
        raise AttributeError("Model does not have a 'projector' attribute or it is None.")
        
    print("Applying Stage 1 splitter: Training only the projector.")
    # Fastai's `params` function selects parameters from the given module(s)
    # Only parameters returned by the splitter are trained.
    trainable_params = list(model.projector.parameters())
    
    # Verify that projector parameters require grad
    # They should by default unless explicitly frozen, but good to check
    for p in trainable_params:
         p.requires_grad = True # Ensure they are trainable
            
    # Ensure other parts are frozen (should be done during model init, but double-check)
    if hasattr(model, 'vision_tower') and model.vision_tower is not None:
        for p in model.vision_tower.parameters():
            p.requires_grad = False
    if hasattr(model, 'language_model') and model.language_model is not None:
        # Note: If PEFT is somehow applied here (it shouldn't be for stage 1),
        # this would wrongly freeze LoRA adapters. This assumes base LLM is frozen.
        for p in model.language_model.parameters():
            p.requires_grad = False
            
    return [trainable_params]

## Step 3.2: Setup Learner Configuration (Stage 1)

In [None]:
#| export
def get_stage1_learner(config: dict) -> Learner:
    """Configures and returns a fastai Learner for Stage 1 projector pre-training.

    Args:
        config: The main configuration dictionary.

    Returns:
        A configured fastai Learner instance for Stage 1.
    """
    print("--- Setting up Stage 1 Learner ---")
    # 1. Load DataLoaders
    print("Loading Stage 1 DataLoaders...")
    dls = get_stage1_dataloaders(config)
    if not dls:
        raise RuntimeError("Failed to create Stage 1 DataLoaders.")
    print(f"DataLoaders loaded. Train samples: {len(dls.train_ds)}, Valid samples: {len(dls.valid_ds)}")

    # 2. Instantiate Model
    print("Instantiating BaselineLLaVAModel...")
    model = BaselineLLaVAModel(config)
    print("Model instantiated.")

    # 3. Define Loss Function
    loss_func = LLaVALoss()
    print(f"Loss function: {type(loss_func).__name__}")

    # 4. Define Optimizer
    # AdamW is generally preferred for transformer models
    lr = config.get('training', {}).get('learning_rate_stage1', 1e-4)
    wd = config.get('training', {}).get('weight_decay', 0.0)
    opt_func = partial(AdamW, lr=lr, wd=wd, eps=1e-8)
    print(f"Optimizer: AdamW (lr={lr}, wd={wd})")

    # 5. Define Splitter
    splitter = llava_stage1_splitter
    print(f"Parameter splitter: {splitter.__name__}")

    # 6. Define Callbacks
    cbs = []
    # Weights & Biases Logging
    if config.get('logging', {}).get('wandb', {}).get('enabled', False):
        # Initialize W&B run here before creating WandbCallback
        run_name = f"stage1_{config.get('project_name', 'llava')}_{Path(config['paths']['stage1_projector_weights']).stem}"
        init_wandb(config, job_type="stage1-training", run_name=run_name)
        cbs.append(WandbCallback(log_preds=False, log_model=False)) # Don't log model via W&B callback, use SaveModelCallback
        print("Added WandbCallback.")
        
    # Model Saving
    output_dir = Path(config['paths']['output_dir'])
    projector_weights_fname = Path(config['paths']['stage1_projector_weights']).stem # Get filename without extension
    save_cb = SaveModelCallback(
        monitor='valid_loss', 
        min_delta=0.001, # Avoid saving too often for tiny improvements
        fname=projector_weights_fname, # Saves as f'{fname}.pth'
        every_epoch=True, # Save at the end of every epoch regardless of improvement
        with_opt=False, # Don't save optimizer state for projector
        reset_on_fit=True # Ensures it checks from the start of training
    )
    cbs.append(save_cb)
    print(f"Added SaveModelCallback (saves projector weights to {output_dir / f'{projector_weights_fname}.pth'})")

    # --- Add Optimization Callbacks (Placeholder for Step 3.3) ---
    # Example placeholders, will be implemented in the next step
    # grad_accum_steps = config.get('training', {}).get('gradient_accumulation_steps', 1)
    # if grad_accum_steps > 1:
    #     from fastai.callback.training import GradientAccumulation
    #     cbs.append(GradientAccumulation(grad_accum_steps))
    #     print(f"Added GradientAccumulation callback (steps={grad_accum_steps})")
    
    # use_mixed_precision = config.get('training', {}).get('use_mixed_precision', False)
    # if use_mixed_precision:
    #     from fastai.callback.fp16 import MixedPrecision
    #     cbs.append(MixedPrecision())
    #     print("Added MixedPrecision callback.")
    # --------------------------------------------------------------
    
    # 7. Create Learner
    learner = Learner(
        dls=dls,
        model=model,
        loss_func=loss_func,
        opt_func=opt_func,
        splitter=splitter,
        cbs=cbs,
        path=output_dir # Set Learner path for saving models
    )
    
    print("--- Stage 1 Learner Setup Complete ---")
    return learner

#### Example Usage & Test (Learner Configuration)

In [None]:
#| eval: false
#| test 
import gc
try:
    # Load config
    config_path = '../configs/config.yaml'
    config = load_config(config_path)
    print(f"Loaded config from {config_path}")
    
    # Check if W&B is enabled and potentially skip if entity isn't set
    wandb_enabled = config.get('logging', {}).get('wandb', {}).get('enabled', False)
    wandb_entity = config.get('logging', {}).get('wandb', {}).get('entity')
    if wandb_enabled and (wandb_entity is None or 'your_wandb_entity' in wandb_entity):
        print("Warning: W&B is enabled but entity is not set or is default. Disabling W&B for this test.")
        config['logging']['wandb']['enabled'] = False

    # Create learner
    stage1_learner = get_stage1_learner(config)
    print("Stage 1 Learner created successfully.")

    # Perform basic checks on the learner
    assert isinstance(stage1_learner, Learner)
    assert isinstance(stage1_learner.dls, DataLoaders)
    assert isinstance(stage1_learner.model, BaselineLLaVAModel)
    assert isinstance(stage1_learner.loss_func, LLaVALoss)
    assert len(stage1_learner.opt.param_groups) == 1 # Only one group (projector) should be trainable
    
    # Check if only projector parameters are in the trainable group
    projector_params_set = set(stage1_learner.model.projector.parameters())
    opt_params_set = set(stage1_learner.opt.param_groups[0]['params'])
    assert opt_params_set == projector_params_set, "Optimizer parameter group does not match projector parameters."
    
    # Check if other parts are frozen (requires_grad=False)
    for name, param in stage1_learner.model.named_parameters():
        if 'projector' in name:
            assert param.requires_grad == True, f"Projector parameter {name} is frozen but should be trainable."
        elif 'vision_tower' in name or 'language_model' in name:
             assert param.requires_grad == False, f"Parameter {name} should be frozen but is not."
                
    print("\nLearner summary:")
    stage1_learner.summary()
    
    print("\nLearner setup test passed.")

except FileNotFoundError as e:
    print(f"Skipping test: FileNotFoundError - {e}")
    print("Ensure config, data, and model paths are correct.")
except ImportError as e:
    print(f"Skipping test: ImportError - {e}")
    print("Ensure all required libraries (transformers, fastai, etc.) are installed.")
except Exception as e:
    import traceback
    print(f"An error occurred during learner setup test: {e}")
    traceback.print_exc()
finally:
    # Clean up memory
    if 'stage1_learner' in locals() and stage1_learner is not None and stage1_learner.model is not None:
        if hasattr(stage1_learner.model, 'vision_tower') and stage1_learner.model.vision_tower is not None:
            stage1_learner.model.vision_tower.to('cpu')
        if hasattr(stage1_learner.model, 'language_model') and stage1_learner.model.language_model is not None:
            stage1_learner.model.language_model.to('cpu')
        if hasattr(stage1_learner.model, 'projector') and stage1_learner.model.projector is not None:
            stage1_learner.model.projector.to('cpu')
        del stage1_learner.model
        del stage1_learner
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print("Cleaned up stage1_learner and model")
    # Terminate wandb run if it was initialized
    if wandb.run is not None:
        wandb.finish()
        print("Finished W&B run.")

Loaded config from ../configs/config.yaml
--- Setting up Stage 1 Learner ---
Loading Stage 1 DataLoaders...
Creating Stage 1 DataLoaders with batch size: 8, num_workers: 4
Loading Stage 1 items from: /workspace/llava/data/llava_pretrain/llava_pretrain.jsonl
Assuming images relative to: /workspace/llava/data/llava_pretrain/images
Found 595375 samples for Stage 1.
DataLoaders created successfully.
DataLoaders loaded. Train samples: 589422, Valid samples: 5953
Instantiating BaselineLLaVAModel...
Initializing Projector: Input Dim=1024, Output Dim=4096
Loading Vision Tower: openai/clip-vit-large-patch14-336...


Some weights of the model checkpoint at openai/clip-vit-large-patch14-336 were not used when initializing CLIPVisionModel: ['text_model.encoder.layers.0.layer_norm1.bias', 'text_model.encoder.layers.0.layer_norm1.weight', 'text_model.encoder.layers.0.layer_norm2.bias', 'text_model.encoder.layers.0.layer_norm2.weight', 'text_model.encoder.layers.0.mlp.fc1.bias', 'text_model.encoder.layers.0.mlp.fc1.weight', 'text_model.encoder.layers.0.mlp.fc2.bias', 'text_model.encoder.layers.0.mlp.fc2.weight', 'text_model.encoder.layers.0.self_attn.k_proj.bias', 'text_model.encoder.layers.0.self_attn.k_proj.weight', 'text_model.encoder.layers.0.self_attn.out_proj.bias', 'text_model.encoder.layers.0.self_attn.out_proj.weight', 'text_model.encoder.layers.0.self_attn.q_proj.bias', 'text_model.encoder.layers.0.self_attn.q_proj.weight', 'text_model.encoder.layers.0.self_attn.v_proj.bias', 'text_model.encoder.layers.0.self_attn.v_proj.weight', 'text_model.encoder.layers.1.layer_norm1.bias', 'text_model.enco

Vision Tower loaded successfully.
Vision Tower weights frozen.
Loading Language Model: lmsys/vicuna-7b-v1.5...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Language Model loaded successfully.
Language Model weights frozen (will be partially unfrozen by PEFT if enabled).
Resizing LLM token embeddings from 32000 to 32001 (tokenizer size)...
LLM token embeddings resized.
Model instantiated.
LLaVALoss initialized, ignoring index: -100
Loss function: LLaVALoss
Optimizer: AdamW (lr=0.0001, wd=0.0)
Parameter splitter: llava_stage1_splitter


[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mnumb3r33[0m (team: [33mnumb3r33[0m). Use [1m`wandb login --relogin`[0m to force relogin


W&B run initialized: stage1_llava_stage1_projector (Project: adaptive_patching_vit, Entity: your_wandb_entity)
Track run at: https://wandb.ai/your_wandb_entity/adaptive_patching_vit/runs/19p7s2b4
Added WandbCallback.
Added SaveModelCallback (saves projector weights to /workspace/llava/output/stage1_projector.pth)
Applying Stage 1 splitter: Training only the projector.
--- Stage 1 Learner Setup Complete ---
Stage 1 Learner created successfully.
Learner summary:
BaselineLLaVAModel(
  (projector): LLaVAProjector(
    (model): Sequential(
      (0): Linear(in_features=1024, out_features=4096, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=4096, out_features=4096, bias=True)
    )
  )
  (vision_tower): CLIPVisionModel(
    (vision_model): CLIPVisionTransformer(
      (embeddings): CLIPVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
        (position_embedding): Embedding(577, 1024)
      )
      (p

wandb:                                                                                
wandb: Run `wandb offline` to turn off syncing.

## Step 3.4: Implement Stage 1 Training Script (Placeholder)

The actual training loop execution will be added here in a later step.

In [None]:
# Placeholder for train_stage1 function
# This function will load config, get learner, and call learner.fit_one_cycle()

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()