Skip to content

Sam2 add defaut task config.#426

Merged
chinazhangchao merged 9 commits into
mainfrom
chao/sam2
May 9, 2026
Merged

Sam2 add defaut task config.#426
chinazhangchao merged 9 commits into
mainfrom
chao/sam2

Conversation

@chinazhangchao
Copy link
Copy Markdown
Contributor

No description provided.

@chinazhangchao chinazhangchao changed the title Chao/sam2 Sam2 add defaut task config. Apr 30, 2026
@chinazhangchao chinazhangchao marked this pull request as ready for review April 30, 2026 08:59
@chinazhangchao chinazhangchao requested a review from a team as a code owner April 30, 2026 08:59
@DingmaomaoBJTU
Copy link
Copy Markdown
Collaborator

Design feedback: MODEL_TASK_DEFAULTS introduces a second parallel registry

sam.py now expresses the same fact — "sam2's canonical export task is mask-generation" — in two places:

# 1. MODEL_TASK_DEFAULTS
MODEL_TASK_DEFAULTS = {"sam2": "mask-generation"}

# 2. MODEL_CLASS_MAPPING (already implies this via its primary entry)
MODEL_CLASS_MAPPING = {
    ("sam2", "mask-generation"): SAM2MaskGeneration,  # same information
    ("sam2", "feature-extraction"): Sam2VisionEncoder,
}

The value of MODEL_TASK_DEFAULTS["sam2"] and the key of MODEL_CLASS_MAPPING's primary entry are the same string. Adding a new model now requires keeping two tables in sync, with no enforcement — it's easy to define MODEL_TASK_DEFAULTS["newmodel"] = "task-a" while forgetting to add ("newmodel", "task-a") to MODEL_CLASS_MAPPING, causing a silent wrong-class fallback.

Alternative: encode the default task as a None-sentinel entry in MODEL_CLASS_MAPPING

# sam.py — one table instead of two
MODEL_CLASS_MAPPING = {
    ("sam2", None): "mask-generation",         # None = default task for auto-detection
    ("sam2", "mask-generation"): SAM2MaskGeneration,
    ("sam2", "feature-extraction"): Sam2VisionEncoder,
    ...
}

_get_custom_model_class reads the sentinel before the class lookup:

def _get_custom_model_class(model_type: str, task: str) -> tuple[str, type | None]:
    default_task = MODEL_CLASS_MAPPING.get((model_type_normalized, None))
    resolved_task = default_task or task
    return resolved_task, MODEL_CLASS_MAPPING.get((model_type_normalized, resolved_task))

Benefits:

  • Single table; the sentinel entry structurally enforces that a matching class entry must exist
  • task.py no longer imports MODEL_TASK_DEFAULTS — all model-specific data flows through the existing MODEL_CLASS_MAPPING path
  • Adding a new model still only requires editing hf/newmodel.py

The only cost is that MODEL_CLASS_MAPPING's value type becomes str | type (sentinel entries hold a task string, others hold a class). Worth considering as a follow-up if this pattern is going to extend to other model families.

🤖 Generated with Claude Code

@chinazhangchao chinazhangchao requested a review from tezheng May 6, 2026 05:50
Comment thread src/winml/modelkit/loader/task.py Outdated
@chinazhangchao chinazhangchao merged commit bed41b6 into main May 9, 2026
9 checks passed
@chinazhangchao chinazhangchao deleted the chao/sam2 branch May 9, 2026 07:51
ssss141414 pushed a commit that referenced this pull request May 15, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants