From 8d8d683fe66bf387706f36791c11d6c256644440 Mon Sep 17 00:00:00 2001
From: Jake LoRocco
Date: Tue, 18 Nov 2025 09:25:24 -0500
Subject: [PATCH] feat: add lock to hf backend to prevent concurrent generation with conflicting weights

---
 mellea/backends/huggingface.py                 | 311 +++++++++++++++---
 test/backends/test_huggingface.py              |   2 +-
 .../test_huggingface_generation_lock.py        | 292 ++++++++++++++++
 3 files changed, 559 insertions(+), 46 deletions(-)
 create mode 100644 test/backends/test_huggingface_generation_lock.py

diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py
index e9407b53..da172535 100644
--- a/mellea/backends/huggingface.py
+++ b/mellea/backends/huggingface.py
@@ -12,6 +12,7 @@
 import functools
 import inspect
 import json
+import threading
 from collections.abc import Callable, Coroutine
 from copy import deepcopy
 from typing import TYPE_CHECKING, Any, cast
@@ -181,6 +182,8 @@ def __init__(
         # Adapters can be made known to the backend (added) and loaded.
         self._added_adapters: dict[str, LocalHFAdapter] = {}
         self._loaded_adapters: dict[str, LocalHFAdapter] = {}
+        self._generate_lock = HFGenerationLock(self)
+        """Necessary for generation since adapters alter the underlying model. Use '' with regular generation requests."""

     async def generate_from_context(
         self,
@@ -317,27 +320,27 @@ async def _generate_from_intrinsic(
         # so we will have to invalidate the cache on our side. This requires
         # us having specific caching for each Component/Message.

-        self.load_adapter(adapter.qualified_name)
-
-        # TODO: This modifies the underlying model. We should set a non-exclusive lock here.
-        # It should allow generate requests with the same adapter to proceed. This logic also
-        # needs to be added to the other generate functions.
-        self._model.set_adapter(adapter.qualified_name)
-
         generate_input, other_input = (
             granite_common.util.chat_completion_request_to_transformers_inputs(
                 rewritten, self._tokenizer, self._model
             )
         )

-        chat_response: Coroutine[Any, Any, granite_common.ChatCompletionResponse] = (
-            asyncio.to_thread(
-                granite_common.util.generate_with_transformers,
-                self._tokenizer,
-                self._model,
-                generate_input,
-                other_input,
-            )
+        def generate_intrinsic_with_lock(
+            *args, **kwargs
+        ) -> granite_common.ChatCompletionResponse:
+            with self._generate_lock.get_lock(adapter.qualified_name):
+                _assert_correct_adapters(adapter.qualified_name, self._model)
+                output = granite_common.util.generate_with_transformers(*args, **kwargs)  # type: ignore
+                _assert_correct_adapters(adapter.qualified_name, self._model)
+            return output
+
+        chat_response = asyncio.to_thread(
+            generate_intrinsic_with_lock,
+            self._tokenizer,
+            self._model,
+            generate_input,
+            other_input,
         )

         output = ModelOutputThunk(None)
@@ -369,7 +372,6 @@ async def granite_common_processing(
             input_ids=generate_input["input_tokens"],
         )

-        # TODO: Post-processing should release the lock for this generation.
output._post_process = functools.partial( self.post_processing, conversation=conversation, @@ -489,8 +491,15 @@ async def _generate_from_context_standard( # Filter out chat template-only options before passing to generate() generate_options = self._filter_chat_template_only_options(model_options) + def generate_with_lock(*args, **kwargs): + with self._generate_lock.get_lock(""): + _assert_correct_adapters("", self._model) + output = self._model.generate(*args, **kwargs) # type: ignore + _assert_correct_adapters("", self._model) + return output + chat_response = asyncio.to_thread( - self._model.generate, # type: ignore + generate_with_lock, input_ids, return_dict_in_generate=True, output_scores=True, @@ -664,42 +673,45 @@ async def generate_from_raw( self._device ) - if format is None: - outputs = await asyncio.to_thread( - self._model.generate, # type: ignore - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - return_dict_in_generate=True, - output_scores=True, - **self._make_backend_specific_and_remove(model_opts), - ) - else: + format_kwargs = {} + if format: + # outlines.generate.json always parses the resulting json into a python dict. + # We however want to keep it as a json string for later storing it in ModelOutputThunk schema: dict[str, Any] = format.model_json_schema() schema_json: str = json.dumps(schema) - regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema( + regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema( # type: ignore schema_json ) from outlines.models.transformers import TransformerTokenizer - from outlines.processors import RegexLogitsProcessor + from outlines.processors.structured import RegexLogitsProcessor from transformers import LogitsProcessorList - outputs = await asyncio.to_thread( - self._model.generate, # type: ignore - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - return_dict_in_generate=True, - output_scores=True, - logits_processor=LogitsProcessorList( - [ - RegexLogitsProcessor( - regex_str, tokenizer=TransformerTokenizer(self._tokenizer) - ) - ] - ), - **self._make_backend_specific_and_remove(model_opts), + format_kwargs["logits_processor"] = LogitsProcessorList( + [ + RegexLogitsProcessor( + regex_str, tokenizer=TransformerTokenizer(self._tokenizer) + ) + ] ) + def generate_raw_with_lock(*args, **kwargs): + with self._generate_lock.get_lock(""): + _assert_correct_adapters("", self._model) + output = self._model.generate(*args, **kwargs) # type: ignore + _assert_correct_adapters("", self._model) + return output + + outputs = await asyncio.to_thread( + generate_raw_with_lock, + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + return_dict_in_generate=True, + output_scores=True, + **self._make_backend_specific_and_remove(model_opts), + **format_kwargs, + ) + sequences_to_decode = [ sequence[inputs["input_ids"][i].size(0) :] # type: ignore for i, sequence in enumerate(outputs.sequences) @@ -853,13 +865,16 @@ def add_adapter(self, adapter: LocalHFAdapter): self._added_adapters[adapter.qualified_name] = adapter def load_adapter(self, adapter_qualified_name: str): - """Loads the given adapter for the backend. Must have previously been added.""" + """Loads the given adapter for the backend. Must have previously been added. 
Do not call this manually while generate requests are in flight; it is called automatically."""
         adapter = self._added_adapters.get(adapter_qualified_name, None)
         if adapter is None:
             raise ValueError(
                 f"could not load adapter {adapter_qualified_name} for backend {self}: adapter was not previously added"
             )
+        if self._loaded_adapters.get(adapter_qualified_name, None) is not None:
+            return  # Exit early since it's already loaded.
+
         try:
             adapter_kwargs = {}
@@ -880,7 +895,6 @@ def load_adapter(self, adapter_qualified_name: str):
         # Loading an adapter activates it. We disable adapters immediately after.
         # Prefer this over `.disable_adapters()`; the disable function doesn't always
         # seem to work.
-        self._model.set_adapter([])
         self._loaded_adapters[adapter.qualified_name] = adapter

     def unload_adapter(self, adapter_qualified_name: str):
@@ -957,3 +971,210 @@ def stepify(self, content: str, step_separator: str) -> list[str]:
             step.strip() for step in content.split(step_separator) if step.strip != ""
         ]
         return list_of_steps
+
+
+# The current implementation of this lock must be defined in this file because the backend requires a reference to
+# it and the lock requires an understanding of LocalHFBackend. Because generation also requires loading/activating the
+# correct adapter (or no adapter), integrating those load/activate checks into the lock's state changes made reasoning
+# and usage much easier. A similar approach could probably be implemented with multiple locks (which would keep this one
+# generic) but would potentially require users to do more work. If we eventually need to refactor this lock to support
+# other backends, we can do that at that point in time (since these APIs are all internal).
+class HFGenerationLock:
+    """A lock-like object. Used to prevent concurrent generation from different adapters on the same backend.
+
+    Note: Should only be used with `asyncio.to_thread` or `threading.Thread(...).start()`. It can block if called multiple
+    times from the same thread.
+    """
+
+    def __init__(self, backend: LocalHFBackend):
+        """A lock-like object. Used to prevent concurrent generation from different adapters on the same backend.
+
+        Notes:
+            - Should only be used with `asyncio.to_thread` or `threading.Thread(...).start()`. It can block if called multiple times from the same thread.
+            - This lock prioritizes acquirers with a state equal to the current state.
+            - Typically easiest to use with `with` syntax: `with lock.get_lock(state): ...`
+        """
+        self.backend = backend
+        """since adapter management is included in this lock, we set the backend at init"""
+
+        self.current_state: str = ""
+        """the current state of the lock; usually reflects the model/adapter name; the empty string is the base model"""
+
+        self.num_active: int = 0
+        """counts the number of active lock holders"""
+
+        # Include a timeout to ensure there are no deadlocks caused by waiting indefinitely. No deadlocks should
+        # occur since events are appended to the list before they attempt to acquire the lock. This means that even if
+        # they fail to acquire the lock, the release caller will set their event to stop waiting.
+        # Control flow scenarios:
+        # - Fail to acquire lock -> immediately wait -> release is called elsewhere -> notified and acquire lock
+        # - Fail to acquire lock -> release is called elsewhere -> wait but see that it's already good to go -> acquire lock
+        self.timeout: float | None = 5
+        """timeout in seconds to wait before trying to acquire the lock again"""
+
+        self.lock = threading.Lock()
+        """a lock to prevent concurrent modification of internal properties"""
+
+        self.events: list[threading.Event] = []
+        """a list of waiters; allows notifying single or multiple waiters"""
+
+    class GenLock:
+        """Necessary for `with` syntax. Enables not using try-finally syntax everywhere."""
+
+        def __init__(self, state: str, lock: HFGenerationLock) -> None:
+            """Necessary for `with` syntax. Enables not using try-finally syntax everywhere.
+
+            Args:
+                state: the state associated with this locking operation
+                lock: the parent lock associated with this locking operation
+            """
+            self.state = state
+            self.lock = lock
+
+        def __enter__(self):
+            """Acquire the lock with a given state."""
+            self.lock.acquire(self.state)
+            return self
+
+        def __exit__(self, exc_type, exc_val, exc_tb):
+            """Release the lock."""
+            self.lock.release()
+
+            # Re-raise the exception if needed.
+            if exc_val is not None:
+                raise exc_val
+
+    def get_lock(self, state):
+        """Used for with statements.
+
+        Examples:
+            >>> # in a LocalHFBackend
+            >>> state = adapter.qualified_name  # or "" for base model
+            >>> with self._generate_lock.get_lock(state):
+            ...     ...  # Generate requests here.
+        """
+        return self.GenLock(state, self)
+
+    def acquire(self, state: str):
+        """Acquire the 'lock'. Only call once per thread.
+
+        Args:
+            state: the adapter qualified name or "" for the base model
+        """
+        # Notifier for this acquirer.
+        event = threading.Event()
+        self.events.append(event)
+
+        while True:
+            if self.lock.acquire(False):
+                if self.current_state == state or self.num_active == 0:
+                    # Breaking from this loop and the below operations must be atomic; include acquiring the lock
+                    # as part of the condition.
+
+                    # When `self.current_state == state`, this means that other generation operations with the
+                    # current state are happening. There's no need to block this request.
+
+                    # When `self.num_active == 0`, this means there are no other generation requests happening. Allow
+                    # a single waiter to break and set the new state.
+                    break
+                else:
+                    # Have to acquire the lock to check the variables but immediately release it if comparisons fail.
+                    self.lock.release()
+
+            # Wait until notified of a release. Add a timeout just in case (see notes in init).
+            event.wait(self.timeout)
+
+            # Reset this waiter so that it patiently waits to be notified again if unable to break from the loop.
+            event.clear()
+
+        self.num_active += 1
+
+        # This waiter will never wait again. Remove its event.
+        self.events.remove(event)
+
+        if self.current_state != state:
+            assert self.num_active == 1, "only this waiter should be active"
+
+            # When swapping states, we need to make sure the correct adapter is set.
+            if state != "":
+                # Adapter.
+                try:
+                    # Ensure the adapter is loaded before setting it.
+                    self.backend.load_adapter(state)
+                    self.backend._model.set_adapter(state)
+                except Exception as e:
+                    # If something goes wrong, the internal state hasn't changed.
+                    # We also have to release the internal lock so that future requests can go through.
+                    self.lock.release()
+                    raise e
+            else:
+                # Base Model.
+                try:
+                    # We can't know if adapters have been loaded / set previously.
+ # This call will throw an exception if none have been. + self.backend._model.set_adapter([]) + except Exception: + pass + + # Wait to release the lock until the current_state is set to the new state value. + self.current_state = state + self.lock.release() + + # Notify all events. Some might be using the same model/adapter. + for event in self.events: + event.set() + else: + # Or, we immediately release the lock if we don't need to change current_state. + self.lock.release() + + def release(self): + """Release a single hold on the lock. Should only call once per `acquire()` call.""" + # Grab the internal lock to prevent concurrent modifications. + with self.lock: + self.num_active -= 1 + + assert self.num_active > -1, f"release on {self} called too many times" + + # Create a local var to track the number active. This lets us release the lock + # before notifying the single waiter. + snapshot_num_active = self.num_active + + # If there are no active holds on this lock, notify a single waiter if one exists. + # This also likely means that no waiters with states equal to the current_state exist; + # and a new current_state will be set. + if snapshot_num_active == 0: + if len(self.events) > 0: + self.events[0].set() + + def __str__(self) -> str: + """Stringify the HFGenerationLock.""" + return f"{self.current_state}: {self.num_active}" + + +def _assert_correct_adapters(expected_state: str, model: PreTrainedModel): + """When generating with a huggingface model and a hf generation lock, this can be used to ensure the correct adapters are active. + + Args: + expected_state: the current state of the lock + model: the model underlying the LocalHFBackend; this is the model the adapters are activated on + """ + try: + active = model.active_adapters() + + if expected_state == "": + assert len(active) == 0, ( + f'no adapters should be active if expected state is "", got "{active[0]}"' + ) + else: + assert len(active) == 1, ( + f'one adapter should be active if expected state is "{expected_state}"' + ) + assert active[0] == expected_state, ( + f'the active adapter "{active[0]}" doesn\'t match the expected state: "{expected_state}"' + ) + except ValueError: + # If no weights have been loaded, the model will raise a ValueError: + # `ValueError("No adapter loaded. Please load an adapter first.")` + assert expected_state == "", ( + 'expected state must be "" if no adapters have been loaded' + ) diff --git a/test/backends/test_huggingface.py b/test/backends/test_huggingface.py index c2a5497f..2cb00b9f 100644 --- a/test/backends/test_huggingface.py +++ b/test/backends/test_huggingface.py @@ -120,7 +120,7 @@ def test_constraint_lora_override_does_not_override_alora(session, backend): # the correct actions / results in it. 
assert isinstance(val_result.context, Context) assert isinstance(val_result.thunk, ModelOutputThunk) - assert isinstance(val_result.context.previous_node.node_data, ALoraRequirement) + assert isinstance(val_result.context.previous_node.node_data, ALoraRequirement) # type: ignore assert val_result.context.node_data is val_result.thunk backend.default_to_constraint_checking_alora = True diff --git a/test/backends/test_huggingface_generation_lock.py b/test/backends/test_huggingface_generation_lock.py new file mode 100644 index 00000000..f49c4c54 --- /dev/null +++ b/test/backends/test_huggingface_generation_lock.py @@ -0,0 +1,292 @@ +import asyncio +from copy import copy +import faulthandler +import random +import time +from typing import Any, Coroutine +from unittest.mock import Mock + +import pytest +import torch + +from mellea import MelleaSession +from mellea.backends.adapters.adapter import GraniteCommonAdapter +from mellea.backends.cache import SimpleLRUCache +from mellea.backends.formatter import TemplateFormatter +from mellea.backends.huggingface import HFGenerationLock, LocalHFBackend, _assert_correct_adapters +from mellea.backends.types import ModelOption +from mellea.stdlib.base import CBlock, ChatContext, Context, GenerateType, ModelOutputThunk +from mellea.stdlib.chat import Message +from mellea.stdlib.intrinsics.intrinsic import Intrinsic + + +@pytest.fixture(scope="module") +def backend(): + """Shared HuggingFace backend for all tests in this module.""" + backend = LocalHFBackend( + model_id="ibm-granite/granite-3.3-8b-instruct", + formatter=TemplateFormatter(model_id="ibm-granite/granite-4.0-tiny-preview"), + cache=SimpleLRUCache(5), + ) + backend.add_adapter( + GraniteCommonAdapter( + "requirement_check", base_model_name=backend.base_model_name + ) + ) + backend.add_adapter( + GraniteCommonAdapter( + "answerability", base_model_name=backend.base_model_name + ) + ) + return backend + + +@pytest.fixture(scope="function") +def session(backend): + """Fresh HuggingFace session for each test.""" + session = MelleaSession(backend, ctx=ChatContext()) + yield session + session.reset() + +@pytest.mark.qualitative +async def test_generate_with_lock(backend): + # Enable the faulthandler for this test. + faulthandler.enable(all_threads=True) + + # Create local versions of these objects so that mocking + # doesn't impact other functions. Don't do this in regular code, + # the copying is complex. + b: LocalHFBackend = copy(backend) + model = copy(b._model) + b._model = model + b._added_adapters = {} + b._loaded_adapters = {} + b._generate_lock = HFGenerationLock(b) + b.add_adapter( + GraniteCommonAdapter( + "requirement_check", base_model_name=b.base_model_name + ) + ) + b.add_adapter( + GraniteCommonAdapter( + "answerability", base_model_name=b.base_model_name + ) + ) + + memoized = dict() + gen_func = model.generate + def mock_func(input_ids, *args, **kwargs): + """Mocks the generate function. Must call `populate_mocked_dict` with each input that must be cached before using this.""" + for key, val in memoized.items(): + if torch.equal(key, input_ids): + time.sleep(random.uniform(.1, .5)) # Simulate a bit of work. + return val + assert False, "did not get a cached response" + + # Safely create the dict. 
+    def populate_mocked_dict(input_ids, *args, **kwargs):
+        """Generates the model output and adds it to the memoized dict."""
+        output = gen_func(input_ids, *args, **kwargs)  # type: ignore
+        memoized[input_ids] = output
+        return output
+
+    model.generate = Mock(side_effect=populate_mocked_dict)
+    assert not isinstance(backend._model, Mock), "mocking went wrong; backend fixture changed; other tests may fail"
+
+    # Set up the inputs.
+    ctx = ChatContext().add(Message("user", "hello"))
+    act = CBlock("hello")
+    raw_act = CBlock("goodb")
+    req_intrinsic = Intrinsic("requirement_check", {"requirement": "did nothing"})
+    answerability_intrinsic = Intrinsic("answerability")
+
+    def call_backend_generate():
+        """Helper function for generating outputs."""
+        return [
+            b.generate_from_context(act, ctx),
+            b.generate_from_context(req_intrinsic, ctx),
+            b.generate_from_context(answerability_intrinsic, ctx),
+            b.generate_from_raw([raw_act], ctx, model_options={ModelOption.MAX_NEW_TOKENS: 3})
+        ]
+
+    # Call once to populate the memoized mock.
+    outputs = await asyncio.gather(*call_backend_generate())
+    for output in outputs:
+        mot = output[0]
+        await mot.avalue()  # Ensure all values are computed.
+
+    # Use the memoized mock that errors if not precomputed.
+    model.generate = Mock(side_effect=mock_func)
+    count = 100  # Use a high number to try to put pressure on the lock and catch deadlocks.
+    coros: list[Coroutine[Any, Any, tuple[ModelOutputThunk, Context]]] = []
+    for _ in range(count):
+        coros.extend(call_backend_generate())
+
+    # Ensure no ordering effects are happening.
+    random.shuffle(coros)
+
+    outputs = await asyncio.gather(*coros)
+    for output in outputs:
+        mot = output[0]
+        await mot.avalue()  # Ensure all values get computed.
+
+    faulthandler.disable()
+
+
+@pytest.mark.qualitative
+async def test_generate_with_lock_does_not_block_when_awaiting_value(backend):
+    """This is a tricky test to set up.
+
+    Its purpose is to ensure that a long-running generation doesn't get blocked
+    when awaiting the `model_output_thunk.avalue()` of a different generation request.
+
+    This means that it is somewhat timing-dependent. The generation has to take long enough
+    to not instantly resolve but not longer than the timeout. Modify the parameters below to
+    fine-tune this.
+
+    If generation is taking too long, you could just increase the timeout, but that
+    causes the test to take longer to run. The best scenario is that the generation doesn't
+    resolve before awaiting the other `mot.avalue()` but resolves immediately after.
+    """
+    # Params to modify depending on speed.
+    token_generation_length = 100
+    timeout_in_seconds = 30
+
+    # Set up the inputs.
+    ctx = ChatContext().add(Message("user", "hello"))
+    act = CBlock("hello")
+    req_intrinsic = Intrinsic("requirement_check", {"requirement": "did nothing"})
+    answerability_intrinsic = Intrinsic("answerability")
+
+    # Create a few model output thunks:
+    # - a streaming generation that will take a long time to resolve.
+    # - a regular generation that should be able to happen while the streaming is happening.
+    # - two intrinsics that shouldn't be able to happen concurrently.
+ reg_mot_stream, _ = await backend.generate_from_context(act, ctx, model_options={ModelOption.STREAM: True, ModelOption.MAX_NEW_TOKENS: token_generation_length, "min_length": token_generation_length}) + reg_mot, _ = await backend.generate_from_context(act, ctx) + req_mot, _ = await backend.generate_from_context(req_intrinsic, ctx, model_options={ModelOption.STREAM: True}) + answerability_mot, _ = await backend.generate_from_context(answerability_intrinsic, ctx, model_options={ModelOption.STREAM: True}) + + # Ensure the stream is generating but not yet completing. + await reg_mot_stream.astream() + assert not reg_mot_stream.is_computed(), "generation completed too early, see test for more details" + + # Awaiting this shouldn't cause a deadlock. Add the timeout so the test can fail. + # If the test fails, this means that the streaming generation wasn't able to complete, + # most likely due to a deadlock caused by awaiting a generation that cannot complete until + # the streaming is done. + try: + async with asyncio.timeout(timeout_in_seconds): + await req_mot.avalue() + except Exception as e: + # The timeout could also be caused by the generation taking too long... be careful! + # We assume that if the streaming model output thunk is computed after getting its astream here, + # that it was a deadlock and not the generation taking too long (since the generation is now done). + await reg_mot_stream.astream() + if reg_mot_stream.is_computed(): + raise e + else: + raise Exception("timeout ended too early, see test for more details") + + for output in [reg_mot_stream, reg_mot, req_mot, answerability_mot]: + if not output.is_computed(): + await output.avalue() # Ensure everything gets computed. + +@pytest.mark.qualitative +async def test_error_during_generate_with_lock(backend): + # Create local versions of these objects so that mocking + # doesn't impact other functions. Don't do this in regular code, + # the copying is complex. + b: LocalHFBackend = copy(backend) + model = copy(b._model) + b._model = model + b._model.set_adapter([]) + b._added_adapters = {} + b._loaded_adapters = {} + b._generate_lock = HFGenerationLock(b) + b.add_adapter( + GraniteCommonAdapter( + "requirement_check", base_model_name=b.base_model_name + ) + ) + + regular_generate = b._model.generate + def generate_and_raise_exc(*args, **kwargs): + """Will generate like usual for the intrinsic request. Will fail for the regular generation request.""" + if "max_new_tokens" in kwargs: + return regular_generate(*args, **kwargs) # type: ignore + raise Exception("Oops!") + + b._model.generate = Mock(side_effect=generate_and_raise_exc) + assert not isinstance(backend._model, Mock), "mocking went wrong; backend fixture changed; other tests may fail" + + # Set up the inputs. + ctx = ChatContext().add(Message("user", "hello")) + act = CBlock("hello") + req_intrinsic = Intrinsic("requirement_check", {"requirement": "did nothing"}) + + reg_mot, _ = await b.generate_from_context(act, ctx) + req_mot, _ = await b.generate_from_context(req_intrinsic, ctx) + + with pytest.raises(Exception, match="Oops!"): + await reg_mot.avalue() + + await req_mot.avalue() + + +async def test_generation_lock(): + b = Mock(spec=LocalHFBackend) + b.load_adapter = Mock() + b._model = Mock() + b._model.set_adapter = Mock() + t = HFGenerationLock(b) + + assert t.backend is b + + # Typically don't use `as` syntax, but useful for asserting things here. 
+    state = ""
+    with t.get_lock(state) as l:
+        assert l.state == state
+        assert l.lock is t
+
+        assert t.current_state == state
+        assert t.num_active == 1
+
+    new_state = "new"
+    t.acquire(new_state)
+    assert t.current_state == new_state
+    assert t.num_active == 1
+    t.release()
+    assert t.current_state == new_state, "state only changes when re-acquiring the lock"
+    assert t.num_active == 0
+
+    assert str(t) == f"{new_state}: 0"
+
+
+def test_assert_correct_adapters():
+    model = Mock()
+
+    # Test scenarios with no active adapters.
+    model.active_adapters = Mock(return_value=[])
+    _assert_correct_adapters("", model)
+    with pytest.raises(AssertionError):
+        _assert_correct_adapters("new", model)
+
+    # Test scenarios with one active adapter.
+    model.active_adapters = Mock(return_value=["new"])
+    with pytest.raises(AssertionError):
+        _assert_correct_adapters("", model)
+    with pytest.raises(AssertionError):
+        _assert_correct_adapters("diff", model)
+    _assert_correct_adapters("new", model)
+
+    # Test scenarios when no adapters have been loaded.
+    model.active_adapters = Mock(side_effect=ValueError)
+    _assert_correct_adapters("", model)
+    with pytest.raises(AssertionError):
+        _assert_correct_adapters("new", model)
+
+
+if __name__ == "__main__":
+    import pytest
+
+    pytest.main([__file__])
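As a quick illustration of the acquire/release semantics outside of pytest, the following is a minimal sketch (not part of the patch) that drives `HFGenerationLock` directly from plain threads, mirroring the `Mock`-based setup used in `test_generation_lock` above. It assumes the import path introduced by this patch; the adapter name "requirement_check" is only a placeholder state string.

```python
import threading
import time
from unittest.mock import Mock

from mellea.backends.huggingface import HFGenerationLock, LocalHFBackend

# Mock backend, as in test_generation_lock: the lock only needs
# load_adapter() and _model.set_adapter() when it swaps states.
backend = Mock(spec=LocalHFBackend)
backend.load_adapter = Mock()
backend._model = Mock()
backend._model.set_adapter = Mock()

lock = HFGenerationLock(backend)

def fake_generate(state: str, results: list[str]) -> None:
    # Holders with the same state may overlap; a holder with a different
    # state waits until num_active drops to 0 and then swaps the state.
    with lock.get_lock(state):
        time.sleep(0.05)  # stand-in for model.generate(...)
        results.append(state)

results: list[str] = []
threads = [
    threading.Thread(target=fake_generate, args=(state, results))
    for state in ["", "", "requirement_check", "requirement_check", ""]
]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(len(results))        # 5: every caller eventually ran; no deadlock
print(lock.num_active)     # 0: every holder has released
print(lock.current_state)  # whichever state acquired last
```

Because the lock prioritizes acquirers whose state matches the current state, same-state callers tend to run back to back once the other holders drain, which keeps adapter swaps (the expensive `set_adapter` calls) to a minimum.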