Fix model infrastructure edge cases #13709

Draft
taivu1998 wants to merge 1 commit into huggingface:main from taivu1998:tdv/issue-13655-model-infra

Summary

Fixes #13655.

This PR addresses several small model-infrastructure edge cases that can surface as runtime failures or state leaks:

  • restores AutoModel.config_name after model-index probing, config fallback, and exceptions
  • supports local PathLike model paths when building _diffusers_load_id
  • clears CacheMixin.cache_context() hook state even when the wrapped block raises
  • keeps FIR up/downsampling kernels on the input device and dtype for low-precision tensors
  • fixes XLA availability checks to call the helper and raise real ImportErrors
  • fixes enable_parallelism() to reject unavailable or uninitialized distributed setups
  • preserves unused kwargs when FlaxModelMixin.from_pretrained() receives an explicit config
  • defaults 1D sin/cos positional embedding calculations to torch.float32
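
The availability-check and enable_parallelism() fixes listed above can be illustrated with a minimal sketch. The names is_backend_available, buggy_guard, fixed_guard, and check_distributed are hypothetical stand-ins, not the helpers touched by this PR, and the choice of RuntimeError for the distributed guard is an assumption.

```python
def is_backend_available() -> bool:
    """Hypothetical availability helper; stands in for a real import probe."""
    return False


def buggy_guard() -> str:
    # Bug: a function object is always truthy, so this branch always runs,
    # even though the helper itself would return False.
    if is_backend_available:
        return "backend path taken"
    return "fallback"


def fixed_guard() -> str:
    # Fix: call the helper, and raise an exception instance. In Python 3,
    # `raise "some string"` fails with TypeError instead of the intended error.
    if not is_backend_available():
        raise ImportError("backend is required but not installed")
    return "backend path taken"


def check_distributed(available: bool, initialized: bool) -> None:
    # Reject when EITHER condition fails. Joining the two negated checks with
    # `and` would only reject when both fail.
    # RuntimeError is an assumption made for this sketch.
    if not available or not initialized:
        raise RuntimeError("distributed backend is unavailable or uninitialized")
```

With the buggy guard, the backend path is taken although the helper reports the backend as missing; the fixed guard raises ImportError in the same situation.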

Root Cause

The failures come from a few independent guard/state issues:

  • AutoModel.from_pretrained() temporarily mutated the class-level config_name without guaranteed restoration and joined load-id parts without stringifying path objects.
  • cache_context() reset hook context only on normal exit, so exceptions could leave stateful hooks bound to a stale context.
  • FIR helpers rebuilt kernels as float32 tensors after receiving low-precision inputs, which could cause dtype mismatches with bfloat16 or other low-precision activations/weights.
  • A few availability checks tested function objects instead of calling them, raised strings instead of exceptions, or used "and" where "or" was required for guard logic.
  • The Flax explicit-config path skipped initialization of unused_kwargs.
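
The state-leak pattern behind the config_name and cache_context() items can be sketched with try/finally. This is an illustration under stated assumptions, not the code in this PR: HypotheticalAutoModel, temporary_config_name, and probe_model_index are invented names, and the file names "config.json" and "model_index.json" are assumed values.

```python
from contextlib import contextmanager


class HypotheticalAutoModel:
    # Class-level attribute shared by every caller (assumed default value).
    config_name = "config.json"


@contextmanager
def temporary_config_name(cls, name):
    """Swap cls.config_name for the duration of the block, then restore it."""
    original = cls.config_name
    cls.config_name = name
    try:
        yield
    finally:
        # Runs on normal exit AND when the wrapped block raises, so the
        # class-level state cannot leak into later calls.
        cls.config_name = original


def probe_model_index(cls):
    # Simulates a probe that fails while the attribute is overridden.
    with temporary_config_name(cls, "model_index.json"):
        raise FileNotFoundError("model_index.json not found")
```

After probe_model_index(HypotheticalAutoModel) raises, HypotheticalAutoModel.config_name is back to its original value. The same shape applies to resetting hook context in a context manager that previously reset only on normal exit.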

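For the load-id issue, a minimal sketch of stringifying path objects before joining. build_load_id and the "|" separator are hypothetical and chosen only for illustration; the actual format of _diffusers_load_id is not shown in this description.

```python
import os
from pathlib import Path


def build_load_id(model_path, *extra_parts):
    """Join a model path and optional extra parts into one identifier string."""
    # os.fspath accepts str and any os.PathLike object; for a pathlib.Path it
    # returns the string form, so the join below never sees a non-str item.
    parts = [os.fspath(model_path)]
    # Skip parts that were not provided and stringify the rest.
    parts.extend(str(part) for part in extra_parts if part is not None)
    return "|".join(parts)


# Both a plain string and a Path are accepted.
load_id_from_str = build_load_id("unet", "fp16")
load_id_from_path = build_load_id(Path("unet"), "fp16")
```

Calling "|".join() directly on a list containing a Path raises TypeError, which is the kind of failure the stringification avoids.
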
Validation

Ran focused and broader checks locally:

  • pytest tests/models/test_modeling_utils.py -q
  • pytest tests/models/test_modeling_flax_utils.py -q
  • pytest tests/hooks/test_hooks.py -q -k cache_context_clears_stateful_hook_context_after_exception
  • pytest tests/models/test_attention_processor.py -q -k 'xla_flash_attention or is_torch_xla_version'
  • pytest tests/models/test_layers_utils.py -q -k 'sincos_pos_embed or fir_bfloat16 or fir_upsample_with_conv_bfloat16 or fir_downsample_with_conv_bfloat16'
  • pytest tests/models/test_models_auto.py -q -k 'local_model_index_pathlike or local_config_fallback or load_config_exception'
  • pytest tests/hooks/test_hooks.py -q
  • pytest tests/models/test_layers_utils.py -q
  • uvx ruff check ...
  • uvx ruff format --check ...
  • git diff --check

Development

Successfully merging this pull request may close these issues.

model_infrastructure model/pipeline review