# Adaptive Model Components

> Defines adaptive components for the LLaVA model, starting with the Adaptive Patcher interface.

In [1]:
#| default_exp model.adaptive

In [2]:
#| hide
from nbdev.showdoc import *

In [3]:
#| export
import sys
from pathlib import Path
import os

# Assumes the notebook is run from the project root or one level down (e.g., nbs/)
# Navigate up to the project root (where settings.ini or .git likely exists)
project_root = Path(os.getcwd())
# Simple check: If settings.ini is not in cwd, assume we are in nbs/ and go up one level
if not (project_root / 'settings.ini').exists() and (project_root.parent / 'settings.ini').exists():
    project_root = project_root.parent

project_root_str = str(project_root.resolve())

if project_root_str not in sys.path:
    print(f"Adding project root to sys.path: {project_root_str}")
    sys.path.insert(0, project_root_str)
else:
    # print(f"Project root already in sys.path: {project_root_str}") # Less verbose
    pass

Adding project root to sys.path: /workspace/llava


In [4]:
#| export
import torch
import torch.nn as nn
import numpy as np
from typing import Dict, Any, Optional, Tuple, List
from PIL import Image

## Step 6.1: Define Adaptive Patcher Interface

This base class defines the interface for any adaptive patching strategy. Subclasses will implement specific logic (e.g., variable resolution, attention-based patching).

In [5]:
#| export
class AdaptivePatcher(nn.Module):
    """Base interface for adaptive image patching modules.

    Subclasses should implement the `forward` method to dynamically process
    an input image (or its features) based on content or context (like text instructions)
    and return a structured representation of image features for the projector.
    """
    def __init__(self, config: Dict[str, Any]):
        """Initializes the Adaptive Patcher.

        Args:
            config: Dictionary containing configuration relevant to the patcher strategy.
        """
        super().__init__()
        self.config = config
        # Potentially load sub-modules or parameters based on config['strategy']

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None, # Existing features (e.g., from base processing)
        text_features: Optional[torch.Tensor] = None, # For text-guided strategies
        raw_image: Optional[Image.Image] = None, # Original image for properties like aspect ratio
        **kwargs
    ) -> Tuple[Optional[torch.Tensor], Optional[Dict[str, Any]]]:
        """Processes the input image adaptively or determines processing strategy.

        The exact inputs used and outputs returned depend on the specific strategy.
        For example, a strategy predictor might only need pooled features,
        while a variable resolution strategy primarily needs the raw image aspect ratio.
        A text-guided strategy needs text_features and potentially patch features.

        Args:
            pixel_values: Optional preprocessed image tensor (e.g., B x C x H x W).
                           Could represent the full image or specific patches.
            text_features: Optional tensor containing embeddings of the instruction text.
                           Needed for text-guided patching strategies.
            raw_image: Optional PIL image, potentially needed for calculating aspect ratio
                       or other properties not easily derived from pixel_values alone.
            **kwargs: Additional keyword arguments specific to the patching strategy.

        Returns:
            A tuple containing:
            - Optional[torch.Tensor]: Processed/selected image features, if the patcher
                                    directly outputs features. Shape might vary.
                                    Can be None if the patcher only outputs metadata.
            - Optional[Dict[str, Any]]: Metadata about the patching process or decision.
                                      (e.g., strategy used, number of patches, selected grid).
                                      Can be None if no metadata is generated.

        Raises:
            NotImplementedError: This base method must be implemented by subclasses.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")

In [6]:
show_doc(AdaptivePatcher)

---

### AdaptivePatcher

>      AdaptivePatcher (config:Dict[str,Any])

*Base interface for adaptive image patching modules.

    Subclasses should implement the `forward` method to dynamically process
    an input image (or its features) based on content or context (like text instructions)
    and return a structured representation of image features for the projector.*

## Step 6.2: Implement Variable Resolution Patcher Strategy

This patcher implements the variable resolution strategy inspired by LLaVA-NeXT. It analyzes the input image's aspect ratio and selects the optimal grid configuration from a predefined set (`image_grid_pinpoints`). The actual image resizing and patching happen elsewhere, based on the grid configuration returned by this patcher.

In [7]:
#| export
class VariableResolutionPatcher(AdaptivePatcher):
    """Adaptive patcher that selects an optimal grid resolution based on image aspect ratio.

    Inspired by the LLaVA-NeXT 'anyres' logic. It determines the target processing
    dimensions but does not perform the actual image processing or feature extraction.
    The main model's forward pass should use the output metadata to handle the image.
    """
    def __init__(self, config: Dict[str, Any]):
        """Initializes the VariableResolutionPatcher.

        Args:
            config: Dictionary containing configuration. Expected keys:
                    'model.adaptive_patcher.image_grid_pinpoints': List of [H, W] grids.
                    'model.vision_config.image_size': Base image size (e.g., 336).
                    'model.vision_config.patch_size': Patch size (e.g., 14).
        """
        super().__init__(config)
        # Default grid from LLaVA-NeXT / Project Spec
        default_grid = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
        self.image_grid_pinpoints = self.config.get('model', {}).get('adaptive_patcher', {}).get('image_grid_pinpoints', default_grid)
        # Add the base resolution to the grid options
        self.base_image_size = self.config.get('model', {}).get('vision_config', {}).get('image_size', 336)
        self.patch_size = self.config.get('model', {}).get('vision_config', {}).get('patch_size', 14)
        
        # Ensure base resolution is included as an option (implicitly or explicitly)
        base_grid = [self.base_image_size, self.base_image_size]
        if base_grid not in self.image_grid_pinpoints:
             self.image_grid_pinpoints.append(base_grid)
        print(f"Initialized VariableResolutionPatcher with grid options: {self.image_grid_pinpoints}")

    def select_best_resolution(self, original_height: int, original_width: int) -> Tuple[int, int]:
        """Selects the best grid resolution based on aspect ratio and minimizing waste.

        Args:
            original_height: Height of the raw input image.
            original_width: Width of the raw input image.

        Returns:
            Tuple[int, int]: The selected best grid dimensions (height, width).
        """
        original_aspect_ratio = original_height / original_width

        best_fit_grid = None
        min_wasted_pixels = float('inf')

        for grid_h, grid_w in self.image_grid_pinpoints:
            grid_aspect_ratio = grid_h / grid_w

            # Calculate the dimensions if the image were scaled to fit this grid
            # Scale based on the limiting dimension
            scale_h = grid_h / original_height
            scale_w = grid_w / original_width

            if original_aspect_ratio > grid_aspect_ratio: # Scale based on height
                scaled_h = grid_h
                scaled_w = int(original_width * scale_h)
            else: # Scale based on width
                scaled_w = grid_w
                scaled_h = int(original_height * scale_w)
            
            # Ensure scaled dimensions do not exceed grid dimensions (due to int conversion)
            scaled_h = min(scaled_h, grid_h)
            scaled_w = min(scaled_w, grid_w)
            
            # Calculate wasted area (pixels in the grid not covered by the scaled image)
            grid_area = grid_h * grid_w
            scaled_image_area = scaled_h * scaled_w
            wasted_pixels = grid_area - scaled_image_area

            # Prefer grids with less waste
            # Among grids with similar waste, the reference code doesn't specify tie-breaking.
            # Let's pick the first one encountered with the minimum waste.
            # A slightly better tie-breaker might be aspect ratio closeness, but min waste is simpler.
            if wasted_pixels < min_wasted_pixels:
                min_wasted_pixels = wasted_pixels
                best_fit_grid = (grid_h, grid_w)
            # Simple tie-breaking: if waste is equal, prefer larger area (less scaling down)
            elif wasted_pixels == min_wasted_pixels:
                 if best_fit_grid is None or (grid_h * grid_w > best_fit_grid[0] * best_fit_grid[1]):
                       best_fit_grid = (grid_h, grid_w)
                       
        # Fallback to base resolution if something went wrong
        if best_fit_grid is None:
            print(f"Warning: Could not determine best fit grid for H={original_height}, W={original_width}. Defaulting to base {self.base_image_size}x{self.base_image_size}.")
            best_fit_grid = (self.base_image_size, self.base_image_size)
            
        return best_fit_grid

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None, # Not used by this patcher
        text_features: Optional[torch.Tensor] = None, # Not used by this patcher
        raw_image: Optional[Image.Image] = None,
        **kwargs
    ) -> Tuple[None, Dict[str, Any]]: # Returns None for features, Dict for metadata
        """Determines the best grid resolution based on the raw image's aspect ratio.

        Args:
            raw_image: The original PIL Image object.
            pixel_values: Ignored by this patcher.
            text_features: Ignored by this patcher.
            **kwargs: Ignored.

        Returns:
            A tuple containing:
            - None: This patcher does not directly return processed features.
            - Dict[str, Any]: Metadata including:
                - 'selected_grid' (Tuple[int, int]): The chosen grid dimensions (H, W).
                - 'num_patches_h' (int): Number of patches vertically in the grid.
                - 'num_patches_w' (int): Number of patches horizontally in the grid.
                - 'total_patches' (int): Total number of patches in the selected grid.

        Raises:
            ValueError: If raw_image is not provided.
        """
        if raw_image is None:
            raise ValueError("VariableResolutionPatcher requires the 'raw_image' (PIL Image) input.")

        original_width, original_height = raw_image.size
        
        # Select the best grid based on aspect ratio
        selected_grid_h, selected_grid_w = self.select_best_resolution(original_height, original_width)
        
        # Calculate number of patches for the selected grid
        if self.patch_size <= 0:
            raise ValueError("Patch size must be positive.")
        num_patches_h = selected_grid_h // self.patch_size
        num_patches_w = selected_grid_w // self.patch_size
        total_patches = num_patches_h * num_patches_w
        
        metadata = {
            'strategy': 'variable_resolution',
            'selected_grid': (selected_grid_h, selected_grid_w),
            'num_patches_h': num_patches_h,
            'num_patches_w': num_patches_w,
            'total_patches': total_patches
        }
        
        # This patcher returns metadata for the main model to use, not processed features.
        return None, metadata

In [8]:
show_doc(VariableResolutionPatcher)

---

### VariableResolutionPatcher

>      VariableResolutionPatcher (config:Dict[str,Any])

*Adaptive patcher that selects an optimal grid resolution based on image aspect ratio.

Inspired by the LLaVA-NeXT 'anyres' logic. It determines the target processing
dimensions but does not perform the actual image processing or feature extraction.
The main model's forward pass should use the output metadata to handle the image.*

In [9]:
#| test
import PIL.Image

# Create dummy config
test_config = {
    'model': {
        'adaptive_patcher': {
             'image_grid_pinpoints': [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
        },
        'vision_config': {
            'image_size': 336,
            'patch_size': 14
        }
    }
}

# Instantiate the patcher
patcher = VariableResolutionPatcher(test_config)

# Test cases with different aspect ratios
test_images = {
    "square": PIL.Image.new('RGB', (600, 600)),
    "tall": PIL.Image.new('RGB', (400, 800)),
    "wide": PIL.Image.new('RGB', (800, 400)),
    "very_wide": PIL.Image.new('RGB', (1200, 400)),
    "very_tall": PIL.Image.new('RGB', (400, 1200)),
    "large_square": PIL.Image.new('RGB', (800, 800)), # Test selection of larger grids
}

expected_grids = {
    "square": (336, 336), # Closest default is base 1:1
    "tall": (336, 672), # 1:2
    "wide": (672, 336), # 2:1
    "very_wide": (1008, 336), # 3:1
    "very_tall": (336, 1008), # 1:3
    "large_square": (672, 672), # Prefers larger grid for less scaling if waste is equal
}

for name, img in test_images.items():
    _, metadata = patcher(raw_image=img)
    selected_grid = metadata['selected_grid']
    total_patches = metadata['total_patches']
    expected_grid = expected_grids[name]
    w, h = img.size
    print(f"Image ({h}, {w}), Ratio {h/w:.2f} -> Selected Grid: {selected_grid}, Patches: {metadata['num_patches_h']}x{metadata['num_patches_w']}={total_patches}")
    assert selected_grid == expected_grid, f"Test '{name}' failed. Expected {expected_grid}, got {selected_grid}"

# Test ValueError if raw_image is None
try:
    patcher(raw_image=None)
    assert False, "Should have raised ValueError when raw_image is None"
except ValueError:
    pass # Expected

print("Variable Resolution Patcher tests passed.")

Initialized VariableResolutionPatcher with grid options: [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008], [336, 336]]
Image (600, 600), Ratio 1.00 -> Selected Grid: (336, 336), Patches: 24x24=576
Image (400, 800), Ratio 0.50 -> Selected Grid: (336, 672), Patches: 24x48=1152
Image (800, 400), Ratio 2.00 -> Selected Grid: (672, 336), Patches: 48x24=1152
Image (1200, 400), Ratio 3.00 -> Selected Grid: (1008, 336), Patches: 72x24=1728
Image (400, 1200), Ratio 0.33 -> Selected Grid: (336, 1008), Patches: 24x72=1728
Image (800, 800), Ratio 1.00 -> Selected Grid: (672, 672), Patches: 48x48=2304
Variable Resolution Patcher tests passed.


## Step 6.4: Define Adaptive LLaVA Model

This step will involve creating `AdaptiveLLaVAModel` inheriting from `BaselineLLaVAModel` and integrating the chosen `AdaptivePatcher`.

In [10]:
# Placeholder for AdaptiveLLaVAModel definition (Step 6.4)
# from llava.model.baseline import BaselineLLaVAModel
#
# class AdaptiveLLaVAModel(BaselineLLaVAModel):
#     def __init__(self, config: Dict[str, Any]):
#         super().__init__(config)
#         patcher_strategy = config.get('model', {}).get('adaptive_patcher', {}).get('strategy')
#         if patcher_strategy == 'variable_resolution':
#             self.patcher = VariableResolutionPatcher(config)
#         # Add other strategies here
#         # elif patcher_strategy == 'predictor':
#         #     self.patcher = PredictorPatcher(config)
#         else:
#             print(f"Warning: Unknown or no adaptive patcher strategy specified ('{patcher_strategy}'). Using baseline behavior.")
#             self.patcher = None
#
#     # Override forward pass (Step 6.5)
#     def forward(self, ...):
#         # ... implementation ...
#         pass


In [11]:
#| hide
import nbdev; nbdev.nbdev_export()