Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions src/winml/modelkit/loader/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,43 @@ def _detect_task_and_class_from_config(config: PretrainedConfig) -> tuple[str, t
"Please specify model_class explicitly."
)

# [3a] Per-model-type default task override.
# Some model families (e.g., SAM/SAM2) have an architecture class whose
# default TasksManager mapping ("feature-extraction") differs from the
# canonical export target ("mask-generation"). The default is encoded as
# a sentinel entry MODEL_CLASS_MAPPING[(model_type, None)] = <class>;
# we reverse-lookup the task name from the matching
# (model_type, default_task) -> same_class entry. This keeps the data in
# one table and structurally enforces that the matching class entry exists.
from ..models.hf import MODEL_CLASS_MAPPING

model_type_normalized = model_type.lower().replace("_", "-")
default_class = MODEL_CLASS_MAPPING.get((model_type_normalized, None))
if default_class is not None:
default_task = next(
(
t
for (mt, t), cls in MODEL_CLASS_MAPPING.items()
if mt == model_type_normalized and t is not None and cls is default_class
),
None,
)
if default_task is None:
raise ValueError(
f"MODEL_CLASS_MAPPING has ({model_type_normalized!r}, None) sentinel "
f"-> {default_class.__name__}, but no matching "
f"({model_type_normalized!r}, <task>) entry maps to that class. "
f"Add the corresponding (model_type, task) entry."
)
if default_task != task:
logger.info(
"Overriding auto-detected task %r with model-type default %r for %s",
task,
default_task,
model_type_normalized,
)
return default_task, default_class

# [4] Check specializations first (CLIP, SAM2, etc.) - highest priority
model_class = _get_custom_model_class(model_type, task)
if model_class:
Expand Down
10 changes: 8 additions & 2 deletions src/winml/modelkit/models/hf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,14 @@
from .zoedepth import ZoeDepthIOConfig as _ZoeDepthIOConfig # triggers registration


# Aggregated model class mappings: (model_type, task) -> HF model class
MODEL_CLASS_MAPPING: dict[tuple[str, str], type] = {
# Aggregated model class mappings: (model_type, task) -> HF model class.
#
# A sentinel entry with task=None encodes the per-model-type default task
# applied during auto-detection. Its value is the default class; the resolver
# reverse-looks-up the task name from the matching (model_type, default_task)
# entry. See sam.py for the canonical example (mask-generation default for
# SAM/SAM2).
MODEL_CLASS_MAPPING: dict[tuple[str, str | None], type] = {
**_BART_CLASS_MAPPING,
**_CLIP_CLASS_MAPPING,
**_MARIAN_CLASS_MAPPING,
Expand Down
29 changes: 21 additions & 8 deletions src/winml/modelkit/models/hf/sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,19 +359,31 @@ def forward(

# (model_type, task) -> HuggingFace model class
#
# Why SAM2 needs class mapping:
# TasksManager detects "feature-extraction" by default for Sam2VideoModel.
# We override to Sam2VisionEncoder for encoder-only export (loads parent
# Sam2VideoModel and extracts vision_encoder to get correct weights).
# "image-feature-extraction" routes perf pipeline to ImageDataset.
# Users wanting the full model use --task image-segmentation.

MODEL_CLASS_MAPPING: dict[tuple[str, str], type] = {
# A sentinel entry with task=None encodes the per-model-type default task
# applied during auto-detection (when the user does not pass --task). Its
# value is the default *class*; the resolver reverse-looks-up the task name
# from the matching (model_type, default_task) -> same_class entry. Encoding
# the default inside MODEL_CLASS_MAPPING — instead of a parallel
# MODEL_TASK_DEFAULTS table — keeps the data in one place and structurally
# enforces that a matching class entry must exist (else reverse lookup fails).
#
# Why SAM/SAM2 need this:
# TasksManager.infer_task_from_model() returns "feature-extraction" for
# SamModel / Sam2Model, but the canonical export target for these
# architectures is the mask-generation decoder wrapper. Encoder-only entries
# remain so users can opt in via --task feature-extraction /
# image-feature-extraction (the latter routes perf pipeline to ImageDataset).
# Users wanting the full encoder+decoder monolith use --task image-segmentation.

MODEL_CLASS_MAPPING: dict[tuple[str, str | None], type] = {
("sam", None): SAMMaskGeneration,
("sam", "mask-generation"): SAMMaskGeneration,
("sam2", None): SAM2MaskGeneration,
("sam2", "image-segmentation"): Sam2Model,
("sam2", "feature-extraction"): Sam2VisionEncoder,
("sam2", "image-feature-extraction"): Sam2VisionEncoder,
("sam2", "mask-generation"): SAM2MaskGeneration,
("sam2-video", None): SAM2MaskGeneration,
("sam2-video", "image-segmentation"): Sam2Model,
("sam2-video", "feature-extraction"): Sam2VisionEncoder,
("sam2-video", "image-feature-extraction"): Sam2VisionEncoder,
Expand Down Expand Up @@ -1023,6 +1035,7 @@ def outputs(self) -> dict[str, dict[int, str]]:


__all__ = [
"MODEL_CLASS_MAPPING",
"SAM2MaskGeneration",
"SAMMaskGeneration",
"Sam2IOConfig",
Expand Down
66 changes: 66 additions & 0 deletions tests/unit/loader/test_detect_task_and_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,69 @@ def test_fallback_to_arch_class_when_tasksmanager_fails(self):
assert task == "image-text-to-text"
# Should fallback to architecture class
assert resolved_class == BlipForConditionalGeneration


class TestModelTaskDefaultsOverride:
    """Tests for the per-model-type default-task auto-detection override.

    Some model families (e.g., SAM/SAM2) have an architecture class whose
    default TasksManager mapping ("feature-extraction") differs from the
    canonical export target ("mask-generation"). The default is encoded as a
    MODEL_CLASS_MAPPING[(model_type, None)] sentinel entry that biases
    auto-detection toward the right export configuration when --task is
    not provided.
    """

    def test_sam2_video_defaults_to_mask_generation(self):
        """Sam2Model on a sam2_video config auto-detects to mask-generation."""
        # Trigger HF model registrations (loads SAM sentinel entries)
        import winml.modelkit.models.hf  # noqa: F401
        from winml.modelkit.models.hf.sam import SAM2MaskGeneration

        config = MagicMock()
        config.architectures = ["Sam2Model"]
        config.model_type = "sam2_video"

        task, resolved_class = _detect_task_and_class_from_config(config)

        assert task == "mask-generation"
        assert resolved_class is SAM2MaskGeneration

    def test_sam_defaults_to_mask_generation(self):
        """SamModel on a sam config auto-detects to mask-generation."""
        import winml.modelkit.models.hf  # noqa: F401
        from winml.modelkit.models.hf.sam import SAMMaskGeneration

        config = MagicMock()
        config.architectures = ["SamModel"]
        config.model_type = "sam"

        task, resolved_class = _detect_task_and_class_from_config(config)

        assert task == "mask-generation"
        assert resolved_class is SAMMaskGeneration

    def test_model_type_underscore_normalized(self):
        """sam2_video (underscore) matches sam2-video (hyphen) in MODEL_CLASS_MAPPING."""
        from winml.modelkit.models.hf import MODEL_CLASS_MAPPING

        # Precondition: the sentinel is registered under the hyphenated key
        # only, so a successful lookup for the underscore spelling can only
        # come from the resolver normalizing "_" to "-".
        assert ("sam2-video", None) in MODEL_CLASS_MAPPING
        assert ("sam2_video", None) not in MODEL_CLASS_MAPPING

        config = MagicMock()
        config.architectures = ["Sam2Model"]
        config.model_type = "sam2_video"

        task, _ = _detect_task_and_class_from_config(config)
        assert task == "mask-generation"

    def test_no_override_for_unrelated_model(self):
        """Models without a (model_type, None) sentinel keep TasksManager-inferred task."""
        from winml.modelkit.models.hf import MODEL_CLASS_MAPPING

        # Precondition: resnet has no sentinel, so the override must not fire.
        assert ("resnet", None) not in MODEL_CLASS_MAPPING

        config = MagicMock()
        config.architectures = ["ResNetForImageClassification"]
        config.model_type = "resnet"

        task, resolved_class = _detect_task_and_class_from_config(config)

        assert task == "image-classification"
        # The previous check here was
        #   resolved_class is not ResNetForImageClassification or task == "image-classification"
        # which is always true once the task assertion above has passed.
        # Instead verify the thing this test is about: the resolved class did
        # not come from any per-model-type sentinel default.
        sentinel_defaults = {
            cls for (_, mapped_task), cls in MODEL_CLASS_MAPPING.items() if mapped_task is None
        }
        assert resolved_class is not None
        assert resolved_class not in sentinel_defaults
Loading