fix tie_weights skipping logic is not tied to model thread scope #44940
Qubitium wants to merge 11 commits into huggingface:main
Conversation
Humm, in general I'm not against it; it's indeed something I've thought about. However, more changes would be required to make `from_pretrained` thread-safe in general. For example, during initialization we sometimes monkey-patch torch with `guard_torch_init_functions`, similarly to `no_tie_weights`.
Also, note that multi-threaded model loading is in general an extremely unusual pattern; there's not really any reason to do it. Multi-processing is a much better fit if you really need to load models concurrently. But in the end, the bottleneck when loading is almost always I/O or device transfer, so concurrency will not really make things faster anyway, most likely the opposite. Also, by default loading already launches several threads to read weights from disk, one more reason to recommend multiprocessing over multithreading.
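To illustrate the multiprocessing recommendation, here is a minimal hedged sketch of loading models concurrently with `ProcessPoolExecutor`. The `load_model` function is a hypothetical stand-in (not the real `from_pretrained`); the point is that each worker is a separate interpreter, so class-level state like a monkeypatched method cannot race across loads. The `fork` start method is pinned so the sketch runs without a `__main__` guard on Linux.

```python
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor


def load_model(path):
    # Hypothetical stand-in for a real loader call (e.g. from_pretrained).
    # Each worker process has its own interpreter and its own copy of every
    # class, so there is no shared mutable state to race on.
    return f"loaded:{path}"


# Pin the "fork" start method (Linux) so workers inherit this module
# directly and the sketch stays runnable without a __main__ guard.
ctx = mp.get_context("fork")
with ProcessPoolExecutor(max_workers=2, mp_context=ctx) as pool:
    results = list(pool.map(load_model, ["model-a", "model-b"]))
print(results)
```

The trade-off, as noted in the later comment, is that processes cost you the shared address space and require IPC for coordination.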
```python
from .modeling_utils import PreTrainedModel


def empty_func(*args, **kwargs):
    pass


# Use an opaque scope token so nested or concurrent loads can identify only
# the models instantiated under this context manager.
state_token = _SKIP_TIE_WEIGHTS_SCOPE.set(object())
```
We should keep the functionality under this decorator. We can simply check the value set in the context var or something, instead of deferring to modeling_utils.
Ok. Let me check if it is doable to reduce the impact on modeling_utils and limit the changes to the decorator as much as possible. But I would still want to remove the very ugly `empty_func` monkeypatch currently in place.
@Cyrilvallez I simulated some threaded scenarios and keeping the monkeypatch is still not thread-safe, whether we keep the current `empty_func()` version or switch it to a wrapper that reads the new contextvar state.
The issue is that the contextvar is thread-local, but the monkeypatch itself is still a class-method override. So two overlapping `no_tie_weights()` calls in different threads can still race on reads/writes of the patched method.
I think the correct design is to keep `no_tie_weights()` as a scope marker only and make the real `tie_weights()` decide whether to skip. If we want this part of Transformers to behave cleanly under threads, I think we need to move away from the monkeypatch pattern and keep the state-based one instead.
Having a single state var just for this purpose may be overkill, and you have mentioned there are other parts that also need this type of state keeping, so there is an opportunity for a more general contextvar scope/state that stores all the different loading states as bits. But that should be another PR/bigger refactor that strips away all the monkeypatching during loading and uses pure state/scope reads and writes.
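The scope-marker design described above can be sketched as follows. All names here are illustrative (only `_SKIP_TIE_WEIGHTS_SCOPE` mirrors the PR's context var); the point is that `tie_weights()` itself consults a per-thread `ContextVar`, so no class-level monkeypatch exists to race on.

```python
import contextvars
import threading

# Mirrors the context var introduced in this PR; everything else is a
# simplified stand-in, not the real Transformers API.
_SKIP_TIE_WEIGHTS_SCOPE = contextvars.ContextVar("skip_tie_weights", default=None)


class TinyModel:
    def tie_weights(self):
        # Scope-marker design: the method checks the context var, which is
        # isolated per thread, instead of being swapped out on the class.
        if _SKIP_TIE_WEIGHTS_SCOPE.get() is not None:
            return "skipped"
        return "tied"


def run_in_skip_scope(fn):
    # Equivalent of entering no_tie_weights(): set an opaque token that is
    # visible only in this thread's execution context.
    token = _SKIP_TIE_WEIGHTS_SCOPE.set(object())
    try:
        return fn()
    finally:
        _SKIP_TIE_WEIGHTS_SCOPE.reset(token)


results = {}


def worker(name, skip):
    model = TinyModel()
    results[name] = run_in_skip_scope(model.tie_weights) if skip else model.tie_weights()


threads = [
    threading.Thread(target=worker, args=("skipping", True)),
    threading.Thread(target=worker, args=("tying", False)),
]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(results)  # each thread observes only its own scope marker
```

Because each thread has its own top-level context, the "skipping" thread's token never leaks into the "tying" thread, even when the two scopes overlap in time.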
This isn't some theoretical threading pattern on my side. I ran into it while implementing a real continuous-batching feature with 2 models in one process, 1 thread per model, on 2 separate GPUs. I am going to debut it soon (PyPI pkg); Transformers will be one of its dependencies. In that setup, threads are intentional but not for loading speed. A shared address space and simpler coordination are a lot more useful to me than splitting everything across processes and dealing with IPC.
What does this PR do?
Loading the same model path from 2 different threads (2 different instances) produces meta-device tensor issues: unloaded meta/empty embedding/lm-head weights that should not be empty post model load.
Cause:
`tie_weights()` code fails to execute because the `tie_weights` skipping logic/state is 1) not thread safe and 2) not tied to a model/thread ctx.
Fixed using `contextvars`, Python's execution-scoped state. Triggered with 2 threads loading Llama 3.2 1B Instruct at the same time.
Unit Test:
The new unit test uses a tiny dummy model with tied `embed_tokens`/`lm_head` weights and a checkpoint that deliberately omits `lm_head.weight`, which forces the final `tie_weights()` call during `from_pretrained` to actually run and succeed.
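As a rough illustration of what that test exercises, here is a hedged, framework-free miniature (plain Python lists standing in for tensors; none of this is the real test code): the checkpoint omits `lm_head.weight`, so the model only ends up with a usable lm-head if `tie_weights()` actually runs and shares the embedding weight.

```python
class TinyTiedModel:
    """Hypothetical miniature of a model with tied embed_tokens/lm_head weights."""

    def __init__(self):
        self.embed_tokens = {"weight": None}  # stand-in for a tensor parameter
        self.lm_head = {"weight": None}       # stays meta/empty until tied

    def load_state_dict(self, state_dict):
        # Load whatever keys the checkpoint provides; missing keys stay None,
        # mimicking a parameter left on the meta device.
        for key, value in state_dict.items():
            module, param = key.split(".", 1)
            getattr(self, module)[param] = value

    def tie_weights(self):
        # Share (not copy) the embedding weight, as weight tying does.
        self.lm_head["weight"] = self.embed_tokens["weight"]


# Checkpoint deliberately omits lm_head.weight, so tying must happen.
checkpoint = {"embed_tokens.weight": [1.0, 2.0, 3.0, 4.0]}
model = TinyTiedModel()
model.load_state_dict(checkpoint)
model.tie_weights()
print(model.lm_head["weight"])
```

If `tie_weights()` were skipped here (the bug this PR fixes), `lm_head["weight"]` would remain `None`, which is the empty/meta lm-head symptom described above.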
Who can review?
@ArthurZucker @Cyrilvallez
Documentation
@stevhliu Please add a section to the Transformers docs on `threading` and Python 3.13+ `nogil` aka free-threading, covering model loading/execution, noting that Transformers is primarily designed and tested as a single-process runtime and many parts are untested for or lacking proper threading support. Contributions welcome to move the pkg toward a thread-friendly world. I think users need to know this in 2026, with Python 3.13-3.15 all supporting free-threading with the GIL disabled. End-users need to lower their expectations when it comes to Transformers as a whole and threading.