# Adaptive Model Components

> Defines adaptive components for the LLaVA model, starting with the Adaptive Patcher interface.

In [1]:
#| default_exp model.adaptive

In [2]:
#| hide
from nbdev.showdoc import *

In [3]:
#| export
import sys
from pathlib import Path
import os

# Assumes the notebook is run from the project root or one level down (e.g., nbs/)
# Navigate up to the project root (where settings.ini or .git likely exists)
project_root = Path(os.getcwd())
# Simple check: If settings.ini is not in cwd, assume we are in nbs/ and go up one level
if not (project_root / 'settings.ini').exists() and (project_root.parent / 'settings.ini').exists():
    project_root = project_root.parent

project_root_str = str(project_root.resolve())

if project_root_str not in sys.path:
    print(f"Adding project root to sys.path: {project_root_str}")
    sys.path.insert(0, project_root_str)
else:
    # print(f"Project root already in sys.path: {project_root_str}") # Less verbose
    pass

Adding project root to sys.path: /workspace/llava


In [4]:
#| export
import torch
import torch.nn as nn
import numpy as np
from typing import Dict, Any, Optional, Tuple, List
from PIL import Image
from torch.nn.utils.rnn import pad_sequence # For padding variable length sequences
from transformers.modeling_outputs import CausalLMOutputWithPast # For type hints
import warnings


# Import base model and utilities
from llava.model.baseline import BaselineLLaVAModel
from llava.utils import load_config
from llava.data.preprocessing import tokenizer, IMAGE_TOKEN_INDEX_PLACEHOLDER, IGNORE_INDEX

Project root already in sys.path: /workspace/llava
Loaded config from configs/config.yaml


## Step 6.1: Define Adaptive Patcher Interface

This base class defines the interface for any adaptive patching strategy. Subclasses will implement specific logic (e.g., variable resolution, attention-based patching).

In [5]:
#| export
class AdaptivePatcher(nn.Module):
    """Base interface for adaptive image patching modules.

    Subclasses should implement the `forward` method to dynamically process
    an input image (or its features) based on content or context (like text instructions)
    and return a structured representation of image features for the projector.
    """
    def __init__(self, config: Dict[str, Any]):
        """Initializes the Adaptive Patcher.

        Args:
            config: Dictionary containing configuration relevant to the patcher strategy.
        """
        super().__init__()
        self.config = config
        # Potentially load sub-modules or parameters based on config['strategy']

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None, # Existing features (e.g., from base processing)
        text_features: Optional[torch.Tensor] = None, # For text-guided strategies
        raw_image: Optional[Image.Image] = None, # Original image for properties like aspect ratio
        **kwargs
    ) -> Tuple[Optional[torch.Tensor], Optional[Dict[str, Any]]]:
        """Processes the input image adaptively or determines processing strategy.

        The exact inputs used and outputs returned depend on the specific strategy.
        For example, a strategy predictor might only need pooled features,
        while a variable resolution strategy primarily needs the raw image aspect ratio.
        A text-guided strategy needs text_features and potentially patch features.

        Args:
            pixel_values: Optional preprocessed image tensor (e.g., B x C x H x W).
                           Could represent the full image or specific patches.
            text_features: Optional tensor containing embeddings of the instruction text.
                           Needed for text-guided patching strategies.
            raw_image: Optional PIL image, potentially needed for calculating aspect ratio
                       or other properties not easily derived from pixel_values alone.
            **kwargs: Additional keyword arguments specific to the patching strategy.

        Returns:
            A tuple containing:
            - Optional[torch.Tensor]: Processed/selected image features, if the patcher
                                    directly outputs features. Shape might vary.
                                    Can be None if the patcher only outputs metadata.
            - Optional[Dict[str, Any]]: Metadata about the patching process or decision.
                                      (e.g., strategy used, number of patches, selected grid).
                                      Can be None if no metadata is generated.

        Raises:
            NotImplementedError: This base method must be implemented by subclasses.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")

In [6]:
show_doc(AdaptivePatcher)

---

### AdaptivePatcher

>      AdaptivePatcher (config:Dict[str,Any])

*Base interface for adaptive image patching modules.

    Subclasses should implement the `forward` method to dynamically process
    an input image (or its features) based on content or context (like text instructions)
    and return a structured representation of image features for the projector.*

## Step 6.2: Implement Variable Resolution Patcher Strategy

This patcher implements the variable resolution strategy inspired by LLaVA-NeXT. It analyzes the input image's aspect ratio and selects the optimal grid configuration from a predefined set (`image_grid_pinpoints`). The actual image resizing and patching happen elsewhere, based on the grid configuration returned by this patcher.

In [7]:
#| export
class VariableResolutionPatcher(AdaptivePatcher):
    """Adaptive patcher that selects an optimal grid resolution based on image aspect ratio.

    Inspired by the LLaVA-NeXT 'anyres' logic. It determines the target processing
    dimensions but does not perform the actual image processing or feature extraction.
    The main model's forward pass should use the output metadata to handle the image.
    """
    def __init__(self, config: Dict[str, Any]):
        """Initializes the VariableResolutionPatcher.

        Args:
            config: Dictionary containing configuration. Expected keys:
                    'model.adaptive_patcher.image_grid_pinpoints': List of [H, W] grids.
                    'model.vision_config.image_size': Base image size (e.g., 336).
                    'model.vision_config.patch_size': Patch size (e.g., 14).
        """
        super().__init__(config)
        # Default grid from LLaVA-NeXT / Project Spec
        default_grid = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
        # Retrieve nested config safely
        patcher_config = self.config.get('model', {}).get('adaptive_patcher', {})
        vision_config = self.config.get('model', {}).get('vision_config', {})
        
        self.image_grid_pinpoints = patcher_config.get('image_grid_pinpoints', default_grid)
        self.base_image_size = vision_config.get('image_size', 336)
        self.patch_size = vision_config.get('patch_size', 14)
        
        # Ensure base resolution is included as an option (implicitly or explicitly)
        base_grid = [self.base_image_size, self.base_image_size]
        if base_grid not in self.image_grid_pinpoints:
             self.image_grid_pinpoints.append(base_grid)
        print(f"Initialized VariableResolutionPatcher with grid options: {self.image_grid_pinpoints}")

    def select_best_resolution(self, original_height: int, original_width: int) -> Tuple[int, int]:
        """Selects the best grid resolution based on aspect ratio and minimizing waste.

        Args:
            original_height: Height of the raw input image.
            original_width: Width of the raw input image.

        Returns:
            Tuple[int, int]: The selected best grid dimensions (height, width).
        """
        original_aspect_ratio = original_height / original_width

        best_fit_grid = None
        min_wasted_pixels = float('inf')

        for grid_h, grid_w in self.image_grid_pinpoints:
            grid_aspect_ratio = grid_h / grid_w

            # Calculate the dimensions if the image were scaled to fit this grid
            # Scale based on the limiting dimension
            scale_h = grid_h / original_height
            scale_w = grid_w / original_width

            if original_aspect_ratio > grid_aspect_ratio: # Scale based on height
                scaled_h = grid_h
                scaled_w = int(original_width * scale_h)
            else: # Scale based on width
                scaled_w = grid_w
                scaled_h = int(original_height * scale_w)
            
            # Ensure scaled dimensions do not exceed grid dimensions (due to int conversion)
            scaled_h = min(scaled_h, grid_h)
            scaled_w = min(scaled_w, grid_w)
            
            # Calculate wasted area (pixels in the grid not covered by the scaled image)
            grid_area = grid_h * grid_w
            scaled_image_area = scaled_h * scaled_w
            wasted_pixels = grid_area - scaled_image_area

            # Prefer grids with less waste
            # Among grids with similar waste, the reference code doesn't specify tie-breaking.
            # Let's pick the first one encountered with the minimum waste.
            # A slightly better tie-breaker might be aspect ratio closeness, but min waste is simpler.
            if wasted_pixels < min_wasted_pixels:
                min_wasted_pixels = wasted_pixels
                best_fit_grid = (grid_h, grid_w)
            # Simple tie-breaking: if waste is equal, prefer larger area (less scaling down)
            elif wasted_pixels == min_wasted_pixels:
                 if best_fit_grid is None or (grid_h * grid_w > best_fit_grid[0] * best_fit_grid[1]):
                       best_fit_grid = (grid_h, grid_w)
                       
        # Fallback to base resolution if something went wrong
        if best_fit_grid is None:
            print(f"Warning: Could not determine best fit grid for H={original_height}, W={original_width}. Defaulting to base {self.base_image_size}x{self.base_image_size}.")
            best_fit_grid = (self.base_image_size, self.base_image_size)
            
        return best_fit_grid

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None, # Not used by this patcher
        text_features: Optional[torch.Tensor] = None, # Not used by this patcher
        raw_image: Optional[Image.Image] = None,
        **kwargs
    ) -> Tuple[None, Dict[str, Any]]: # Returns None for features, Dict for metadata
        """Determines the best grid resolution based on the raw image's aspect ratio.

        Args:
            raw_image: The original PIL Image object.
            pixel_values: Ignored by this patcher.
            text_features: Ignored by this patcher.
            **kwargs: Ignored.

        Returns:
            A tuple containing:
            - None: This patcher does not directly return processed features.
            - Dict[str, Any]: Metadata including:
                - 'strategy': 'variable_resolution'
                - 'selected_grid' (Tuple[int, int]): The chosen grid dimensions (H, W).
                - 'num_patches_h' (int): Number of patches vertically in the grid.
                - 'num_patches_w' (int): Number of patches horizontally in the grid.
                - 'total_patches' (int): Total number of patches in the selected grid.

        Raises:
            ValueError: If raw_image is not provided or patch_size is invalid.
        """
        if raw_image is None:
            raise ValueError("VariableResolutionPatcher requires the 'raw_image' (PIL Image) input.")

        original_width, original_height = raw_image.size
        
        # Select the best grid based on aspect ratio
        selected_grid_h, selected_grid_w = self.select_best_resolution(original_height, original_width)
        
        # Calculate number of patches for the selected grid
        if self.patch_size <= 0:
            raise ValueError("Patch size must be positive.")
        num_patches_h = selected_grid_h // self.patch_size
        num_patches_w = selected_grid_w // self.patch_size
        total_patches = num_patches_h * num_patches_w
        
        metadata = {
            'strategy': 'variable_resolution',
            'selected_grid': (selected_grid_h, selected_grid_w),
            'num_patches_h': num_patches_h,
            'num_patches_w': num_patches_w,
            'total_patches': total_patches
        }
        
        # This patcher returns metadata for the main model to use, not processed features.
        return None, metadata

In [8]:
show_doc(VariableResolutionPatcher)

---

### VariableResolutionPatcher

>      VariableResolutionPatcher (config:Dict[str,Any])

*Adaptive patcher that selects an optimal grid resolution based on image aspect ratio.

Inspired by the LLaVA-NeXT 'anyres' logic. It determines the target processing
dimensions but does not perform the actual image processing or feature extraction.
The main model's forward pass should use the output metadata to handle the image.*

In [9]:
#| test
import PIL.Image

# Create dummy config
test_config = {
    'model': {
        'adaptive_patcher': {
             'image_grid_pinpoints': [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
        },
        'vision_config': {
            'image_size': 336,
            'patch_size': 14
        }
    }
}

# Instantiate the patcher
patcher = VariableResolutionPatcher(test_config)

# Test cases with different aspect ratios
test_images = {
    "square": PIL.Image.new('RGB', (600, 600)),
    "tall": PIL.Image.new('RGB', (400, 800)),
    "wide": PIL.Image.new('RGB', (800, 400)),
    "very_wide": PIL.Image.new('RGB', (1200, 400)),
    "very_tall": PIL.Image.new('RGB', (400, 1200)),
    "large_square": PIL.Image.new('RGB', (800, 800)), # Test selection of larger grids
}

expected_grids = {
    "square": (336, 336), # Closest default is base 1:1
    "tall": (336, 672), # 1:2
    "wide": (672, 336), # 2:1
    "very_wide": (1008, 336), # 3:1
    "very_tall": (336, 1008), # 1:3
    "large_square": (672, 672), # Prefers larger grid for less scaling if waste is equal
}

for name, img in test_images.items():
    _, metadata = patcher(raw_image=img)
    selected_grid = metadata['selected_grid']
    total_patches = metadata['total_patches']
    expected_grid = expected_grids[name]
    w, h = img.size
    print(f"Image ({h}, {w}), Ratio {h/w:.2f} -> Selected Grid: {selected_grid}, Patches: {metadata['num_patches_h']}x{metadata['num_patches_w']}={total_patches}")
    assert selected_grid == expected_grid, f"Test '{name}' failed. Expected {expected_grid}, got {selected_grid}"

# Test ValueError if raw_image is None
try:
    patcher(raw_image=None)
    assert False, "Should have raised ValueError when raw_image is None"
except ValueError:
    pass # Expected

print("Variable Resolution Patcher tests passed.")

Initialized VariableResolutionPatcher with grid options: [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008], [336, 336]]
Image (600, 600), Ratio 1.00 -> Selected Grid: (336, 336), Patches: 24x24=576
Image (400, 800), Ratio 0.50 -> Selected Grid: (336, 672), Patches: 24x48=1152
Image (800, 400), Ratio 2.00 -> Selected Grid: (672, 336), Patches: 48x24=1152
Image (1200, 400), Ratio 3.00 -> Selected Grid: (1008, 336), Patches: 72x24=1728
Image (400, 1200), Ratio 0.33 -> Selected Grid: (336, 1008), Patches: 24x72=1728
Image (800, 800), Ratio 1.00 -> Selected Grid: (672, 672), Patches: 48x48=2304
Variable Resolution Patcher tests passed.


## Step 6.4: Define Adaptive LLaVA Model

This step involves creating `AdaptiveLLaVAModel` inheriting from `BaselineLLaVAModel` and integrating the chosen `AdaptivePatcher`.

In [10]:
#| export
# Placeholder for other potential patcher implementations
# class PredictorPatcher(AdaptivePatcher): ...
# class TextGuidedPatcher(AdaptivePatcher): ...

# Mapping from strategy name in config to Patcher class
PATCHER_STRATEGIES = {
    'variable_resolution': VariableResolutionPatcher,
    # 'predictor': PredictorPatcher, # Add when implemented
    # 'text_guided': TextGuidedPatcher, # Add when implemented
}

class AdaptiveLLaVAModel(BaselineLLaVAModel):
    """LLaVA Model extended with an Adaptive Patcher module.

    Inherits from BaselineLLaVAModel and adds an adaptive patcher component
    based on the configuration. The forward pass needs to be overridden
    to incorporate the patcher's logic.
    """
    def __init__(self, config: Dict[str, Any]):
        """Initializes the Adaptive LLaVA model.

        Loads baseline components and instantiates the specified adaptive patcher.

        Args:
            config: The main configuration dictionary.
        """
        # Initialize baseline components (Vision Tower, LLM, Projector)
        super().__init__(config)
        
        self.patcher = None
        patcher_config = self.config.get('model', {}).get('adaptive_patcher', {})
        patcher_enabled = patcher_config.get('enabled', False)
        patcher_strategy = patcher_config.get('strategy')

        if patcher_enabled and patcher_strategy:
            if patcher_strategy in PATCHER_STRATEGIES:
                PatcherClass = PATCHER_STRATEGIES[patcher_strategy]
                try:
                    self.patcher = PatcherClass(config) # Pass the full config
                    print(f"Adaptive Patcher enabled with strategy: '{patcher_strategy}' ({PatcherClass.__name__})")
                except Exception as e:
                    print(f"Error initializing patcher '{patcher_strategy}': {e}. Disabling patcher.")
                    self.patcher = None
            else:
                print(f"Warning: Unknown adaptive patcher strategy '{patcher_strategy}'. Disabling patcher.")
                self.patcher = None
        else:
            print("Adaptive Patcher is disabled in the configuration.")

    # --- Step 6.5: Implement Adaptive Forward Pass --- 
    # Override the forward pass to integrate the patcher logic
    def forward(self, 
                pixel_values: torch.Tensor,
                input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None, 
                labels: Optional[torch.Tensor] = None,
                # Add raw_images potentially needed by patcher
                raw_images: Optional[List[Image.Image]] = None 
               ) -> CausalLMOutputWithPast: # Return type from transformers
        """Defines the forward pass of the Adaptive LLaVA model.

        If an adaptive patcher is enabled, it's called to determine patching strategy
        (metadata stored). The actual image features used currently come from the standard
        `pixel_values` input (base resolution), regardless of patcher output. This allows
        the structural integration without implementing complex variable feature handling yet.

        Args:
            pixel_values: Tensor of shape (batch_size, C, H, W) for the base resolution (e.g., 336x336).
            input_ids: Tensor of shape (batch_size, sequence_length) containing token IDs
                       and IMAGE_TOKEN_INDEX_PLACEHOLDER markers (-200).
            attention_mask: Optional tensor of shape (batch_size, sequence_length).
            labels: Optional tensor of shape (batch_size, sequence_length) corresponding
                    to input_ids (with -100 masking) for loss calculation.
            raw_images: Optional list of PIL Images for the batch, needed by some patchers.

        Returns:
            Output dictionary from the language model (transformers.CausalLMOutputWithPast).
        """
        # print("AdaptiveLLaVAModel forward pass called.") # Debug print
        patcher_metadata_batch = [None] * pixel_values.shape[0] # Initialize metadata list

        # --- Patcher Logic --- 
        if self.patcher is not None:
            if raw_images is None or len(raw_images) != pixel_values.shape[0]:
                 warnings.warn("Patcher is enabled but 'raw_images' not provided or length mismatch. Cannot run patcher logic.")
            else:
                # Iterate through batch to get metadata for each sample
                for i in range(pixel_values.shape[0]):
                    try:
                        _, patcher_metadata_sample = self.patcher(raw_image=raw_images[i])
                        patcher_metadata_batch[i] = patcher_metadata_sample
                        # print(f"Batch {i} Patcher metadata: {patcher_metadata_sample}") # Debug print
                    except Exception as e:
                        warnings.warn(f"Error running patcher for batch index {i}: {e}")
                # Store or log patcher_metadata_batch if needed
                # Currently, we log/print but don't change the feature processing path.
                self.current_patcher_metadata = patcher_metadata_batch # Store for potential inspection/logging
        
        # --- Baseline Feature Processing (Simplified path for now) --- 
        # 1. Encode Image (using standard pixel_values) & Project Features
        image_features = self.encode_image(pixel_values) # (B, P_base, D_clip)
        if image_features is None:
            raise RuntimeError("Image encoding failed.")
        projected_image_features = self.projector(image_features) # (B, P_base, D_llm)
        num_image_patches = projected_image_features.shape[1]

        # --- Prepare LLM inputs (Same as baseline) --- 
        input_ids_clone = input_ids.clone()
        input_ids_clone[input_ids_clone == self.image_token_index_marker] = 0 
        text_embeddings = self.get_input_embeddings()(input_ids_clone) 

        new_input_embeds = []
        new_labels = [] if labels is not None else None
        new_attention_mask = []

        for batch_idx in range(input_ids.shape[0]):
            image_token_indices = torch.where(input_ids[batch_idx] == self.image_token_index_marker)[0]
            if len(image_token_indices) == 0:
                warnings.warn(f"Image token placeholder {self.image_token_index_marker} not found in batch index {batch_idx}. Skipping image features.")
                new_input_embeds.append(text_embeddings[batch_idx])
                current_attention_mask = attention_mask[batch_idx] if attention_mask is not None else (input_ids[batch_idx] != tokenizer.pad_token_id).long()
                new_attention_mask.append(current_attention_mask)
                if new_labels is not None and labels is not None:
                    new_labels.append(labels[batch_idx])
                continue

            image_token_start_index = image_token_indices[0].item()

            text_emb_before = text_embeddings[batch_idx, :image_token_start_index]
            text_emb_after = text_embeddings[batch_idx, image_token_start_index + 1:]

            cur_new_embed = torch.cat([
                text_emb_before,
                projected_image_features[batch_idx].to(text_embeddings.device, dtype=text_embeddings.dtype),
                text_emb_after
            ], dim=0)
            new_input_embeds.append(cur_new_embed)

            current_attention_mask = attention_mask[batch_idx] if attention_mask is not None else (input_ids[batch_idx] != tokenizer.pad_token_id).long()
            mask_before = current_attention_mask[:image_token_start_index]
            mask_image = torch.ones(num_image_patches, dtype=torch.long, device=current_attention_mask.device)
            mask_after = current_attention_mask[image_token_start_index + 1:]
            cur_new_mask = torch.cat([
                mask_before,
                mask_image,
                mask_after
            ], dim=0)
            new_attention_mask.append(cur_new_mask)

            if new_labels is not None and labels is not None:
                label_before = labels[batch_idx, :image_token_start_index]
                label_image = torch.full((num_image_patches,), self.ignore_index, dtype=torch.long, device=labels.device)
                label_after = labels[batch_idx, image_token_start_index + 1:]
                cur_new_label = torch.cat([
                    label_before,
                    label_image,
                    label_after
                ], dim=0)
                new_labels.append(cur_new_label)

        # --- Padding (Same as baseline) --- 
        padded_input_embeds = pad_sequence(new_input_embeds, batch_first=True, padding_value=0.0)
        padded_attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)
        padded_labels = None
        if new_labels is not None:
            padded_labels = pad_sequence(new_labels, batch_first=True, padding_value=self.ignore_index)

        # --- Pass to LLM (Same as baseline) --- 
        outputs: CausalLMOutputWithPast = self.language_model(
            inputs_embeds=padded_input_embeds,
            attention_mask=padded_attention_mask,
            labels=padded_labels,
            return_dict=True
        )

        return outputs

In [11]:
show_doc(AdaptiveLLaVAModel)

---

### AdaptiveLLaVAModel

>      AdaptiveLLaVAModel (config:Dict[str,Any])

*LLaVA Model extended with an Adaptive Patcher module.

    Inherits from BaselineLLaVAModel and adds an adaptive patcher component
    based on the configuration. The forward pass needs to be overridden
    to incorporate the patcher's logic.*

In [12]:
#| test
import torch, gc
from transformers import AutoModelForCausalLM, CLIPVisionModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from PIL import Image

try:
    config_path = '../configs/config.yaml'
    test_config = load_config(config_path)
    print(f"Loaded config from {config_path}")

    # --- Test with Adaptive Patcher Enabled --- 
    if 'model' not in test_config: test_config['model'] = {}
    if 'adaptive_patcher' not in test_config['model']: test_config['model']['adaptive_patcher'] = {}
    test_config['model']['adaptive_patcher']['enabled'] = True
    test_config['model']['adaptive_patcher']['strategy'] = 'variable_resolution' # Test this strategy
    # Disable LoRA and Checkpointing for simplicity in this forward pass test
    if 'peft' not in test_config['model']: test_config['model']['peft'] = {}
    test_config['model']['peft']['use_lora'] = False 
    test_config['model']['use_activation_checkpointing'] = False

    # Instantiate the adaptive model
    test_adaptive_model = AdaptiveLLaVAModel(test_config)
    test_adaptive_model.eval() # Set to eval mode

    print("Running Adaptive Forward Pass Test...")
    # Prepare dummy inputs
    batch_size = 2
    seq_len = 15 # Short sequence for testing
    img_size = 336 # From config
    num_patches = 576 # (336/14)^2 - Base patches
    llm_hidden_dim = test_config['model']['projector']['output_dim']
    tokenizer_vocab_size = len(tokenizer)

    dummy_pixel_values = torch.randn(batch_size, 3, img_size, img_size)
    # Create input_ids with placeholder
    dummy_input_ids = torch.randint(1, tokenizer_vocab_size, (batch_size, seq_len), dtype=torch.long)
    placeholder_idx = 5 # Place the image token marker at index 5
    dummy_input_ids[:, placeholder_idx] = IMAGE_TOKEN_INDEX_PLACEHOLDER
    dummy_attention_mask = torch.ones_like(dummy_input_ids)
    # Create labels (copy input_ids, mask placeholder and potentially prompt)
    dummy_labels = dummy_input_ids.clone()
    dummy_labels[:, :placeholder_idx+1] = IGNORE_INDEX # Mask prompt + image token
    
    # Create dummy raw images
    dummy_raw_images = [
        Image.new('RGB', (600, 600)), # Square
        Image.new('RGB', (400, 800))  # Tall
    ]

    # Perform forward pass
    with torch.no_grad(): 
         outputs = test_adaptive_model(
            pixel_values=dummy_pixel_values,
            input_ids=dummy_input_ids,
            attention_mask=dummy_attention_mask,
            labels=dummy_labels,
            raw_images=dummy_raw_images # Pass raw images for patcher
         )

    # Check output type and attributes
    assert isinstance(outputs, CausalLMOutputWithPast)
    assert hasattr(outputs, 'logits')
    assert outputs.logits is not None
    assert hasattr(outputs, 'loss') # Should have loss since labels were provided
    assert outputs.loss is not None

    # Check output shapes (should match baseline for now)
    expected_seq_len = seq_len - 1 + num_patches
    expected_logits_shape = (batch_size, expected_seq_len, tokenizer_vocab_size)

    assert outputs.logits.shape == expected_logits_shape, \
        f"Expected logits shape {expected_logits_shape}, but got {outputs.logits.shape}"
    print(f"Output logits shape: {outputs.logits.shape}")
    print(f"Output loss: {outputs.loss}")
    
    # Check if patcher metadata was stored (optional check)
    assert hasattr(test_adaptive_model, 'current_patcher_metadata')
    assert len(test_adaptive_model.current_patcher_metadata) == batch_size
    assert test_adaptive_model.current_patcher_metadata[0]['selected_grid'] == (336, 336) # Expected for square image
    assert test_adaptive_model.current_patcher_metadata[1]['selected_grid'] == (336, 672) # Expected for tall image
    
    print("Adaptive Forward Pass test successful!")

except FileNotFoundError:
    print(f"Config file {config_path} not found. Skipping AdaptiveLLaVAModel test.")
except ImportError as e:
    print(f"Skipping test due to ImportError: {e}. (Likely `peft` is missing)")
except Exception as e:
    print(f"An error occurred during AdaptiveLLaVAModel test: {e}")
    import traceback
    traceback.print_exc()
finally:
    # Clean up
    if 'test_adaptive_model' in locals():
        if hasattr(test_adaptive_model, 'vision_tower') and test_adaptive_model.vision_tower is not None:
            test_adaptive_model.vision_tower.to('cpu')
        if hasattr(test_adaptive_model, 'language_model') and test_adaptive_model.language_model is not None:
            test_adaptive_model.language_model.to('cpu')
        if hasattr(test_adaptive_model, 'projector') and test_adaptive_model.projector is not None:
            test_adaptive_model.projector.to('cpu')
        del test_adaptive_model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print("Cleaned up test_adaptive_model")


Loaded config from ../configs/config.yaml
Initializing Projector: Input Dim=1024, Output Dim=4096
Loading Vision Tower: openai/clip-vit-large-patch14-336...


Some weights of the model checkpoint at openai/clip-vit-large-patch14-336 were not used when initializing CLIPVisionModel: ['text_model.encoder.layers.0.layer_norm1.bias', 'text_model.encoder.layers.0.layer_norm1.weight', 'text_model.encoder.layers.0.layer_norm2.bias', 'text_model.encoder.layers.0.layer_norm2.weight', 'text_model.encoder.layers.0.mlp.fc1.bias', 'text_model.encoder.layers.0.mlp.fc1.weight', 'text_model.encoder.layers.0.mlp.fc2.bias', 'text_model.encoder.layers.0.mlp.fc2.weight', 'text_model.encoder.layers.0.self_attn.k_proj.bias', 'text_model.encoder.layers.0.self_attn.k_proj.weight', 'text_model.encoder.layers.0.self_attn.out_proj.bias', 'text_model.encoder.layers.0.self_attn.out_proj.weight', 'text_model.encoder.layers.0.self_attn.q_proj.bias', 'text_model.encoder.layers.0.self_attn.q_proj.weight', 'text_model.encoder.layers.0.self_attn.v_proj.bias', 'text_model.encoder.layers.0.self_attn.v_proj.weight', 'text_model.encoder.layers.1.layer_norm1.bias', 'text_model.enco

Vision Tower loaded successfully.
Vision Tower weights frozen.
Loading Language Model: lmsys/vicuna-7b-v1.5...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Language Model loaded successfully.
Base Language Model weights frozen.
LoRA is disabled in the configuration.
LLM embedding size already matches tokenizer size. No resizing needed.
Activation checkpointing is disabled in the configuration.
Adaptive Patcher enabled with strategy: 'variable_resolution' (VariableResolutionPatcher)
Initialized VariableResolutionPatcher with grid options: [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008], [336, 336]]
Running Adaptive Forward Pass Test...
Batch 0 Patcher metadata: {'strategy': 'variable_resolution', 'selected_grid': (336, 336), 'num_patches_h': 24, 'num_patches_w': 24, 'total_patches': 576}
Batch 1 Patcher metadata: {'strategy': 'variable_resolution', 'selected_grid': (336, 672), 'num_patches_h': 24, 'num_patches_w': 48, 'total_patches': 1152}
Output logits shape: torch.Size([2, 586, 32001])
Output loss: tensor(10.4655, grad_fn=<NllLossBackward0>)
Adaptive Forward Pass test successful!
Cleaned up test_adaptive_model


In [13]:
#| hide
import nbdev; nbdev.nbdev_export()