# Adaptive Model Components

> Defines adaptive components for the LLaVA model, starting with the Adaptive Patcher interface.

In [None]:
#| default_exp model.adaptive

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import sys
from pathlib import Path
import os

# Locate the project root so that project-local packages can be imported.
# The notebook is expected to run either from the project root itself or from
# a direct subdirectory such as nbs/; the presence of `settings.ini` marks the root.
project_root = Path.cwd()
if not (project_root / 'settings.ini').exists():
    # cwd is not the root: step up one level, but only if the parent holds the marker file.
    if (project_root.parent / 'settings.ini').exists():
        project_root = project_root.parent

project_root_str = str(project_root.resolve())

# Prepend the root to sys.path only when it is missing, so re-running the cell
# neither duplicates the entry nor prints the message again.
if project_root_str not in sys.path:
    print(f"Adding project root to sys.path: {project_root_str}")
    sys.path.insert(0, project_root_str)

Adding project root to sys.path: /workspace/llava


In [None]:
#| export
import torch
import torch.nn as nn
from typing import Dict, Any, Optional, Tuple
from PIL import Image

## Step 6.1: Define Adaptive Patcher Interface

This base class defines the interface that every adaptive patching strategy must follow. Subclasses implement the strategy-specific logic (e.g., variable resolution, attention-based patching) by overriding `forward`.

In [None]:
#| export
class AdaptivePatcher(nn.Module):
    """Base interface for adaptive image patching modules.

    Subclasses should implement the `forward` method to dynamically process
    an input image (or its features) based on content or context (like text instructions)
    and return a structured representation of image features for the projector.
    """

    def __init__(self, config: Dict[str, Any]):
        """Store the configuration for the patching strategy.

        Args:
            config: Settings for the concrete patcher. The base class keeps a
                reference to this dictionary on `self.config` without copying
                or validating it.
        """
        super().__init__()
        # Concrete patchers may build sub-modules or parameters from this,
        # e.g. keyed on config['strategy'].
        self.config = config

    def forward(
        self,
        pixel_values: torch.Tensor,
        text_features: Optional[torch.Tensor] = None,
        raw_image: Optional[Image.Image] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
        """Adaptively turn an input image into features for the projector.

        The base implementation is a placeholder that always raises; the
        contract below is what concrete patchers are expected to honour.

        Args:
            pixel_values: Preprocessed image tensor, e.g. shaped
                (B, C, H, W). It may hold the full image or specific patches.
            text_features: Embeddings of the instruction text, e.g. shaped
                (B, S_txt, D_txt). Only needed by text-guided strategies.
            raw_image: The original PIL image, for properties such as the
                aspect ratio that are hard to recover from `pixel_values`.
            **kwargs: Extra, strategy-specific keyword arguments.

        Returns:
            A two-element tuple:

            - The processed image features ready for projection (for example
              selected patches, or a global + local combination), shaped
              something like (B, NumFeatures, D_vision).
            - Either None or a dictionary of metadata about the patching
              (e.g. strategy used, number of patches), useful for debugging
              or for conditional logic downstream.

        Raises:
            NotImplementedError: Always, unless a subclass overrides this method.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")

In [None]:
# Render the nbdev API documentation for AdaptivePatcher inline in the notebook.
show_doc(AdaptivePatcher)

---

### AdaptivePatcher

>      AdaptivePatcher (config:Dict[str,Any])

*Base interface for adaptive image patching modules.

    Subclasses should implement the `forward` method to dynamically process
    an input image (or its features) based on content or context (like text instructions)
    and return a structured representation of image features for the projector.*

In [None]:
#| hide
# Export the notebook cells marked `#| export` to the library's Python modules.
import nbdev; nbdev.nbdev_export()