diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py
index e9407b53..2ccf22b7 100644
--- a/mellea/backends/huggingface.py
+++ b/mellea/backends/huggingface.py
@@ -12,6 +12,7 @@
 import functools
 import inspect
 import json
+import threading
 from collections.abc import Callable, Coroutine
 from copy import deepcopy
 from typing import TYPE_CHECKING, Any, cast
@@ -182,6 +183,9 @@ def __init__(
         self._added_adapters: dict[str, LocalHFAdapter] = {}
         self._loaded_adapters: dict[str, LocalHFAdapter] = {}
 
+        self._generation_lock = threading.Lock()
+        """Used to force generation requests to be non-concurrent. Necessary for preventing issues with adapters."""
+
     async def generate_from_context(
         self,
         action: Component | CBlock,
@@ -245,6 +249,32 @@ async def generate_from_context(
         )
         return mot, ctx.add(action).add(mot)
 
+    def _generate_with_adapter_lock(
+        self, adapter_name: str, generate_func: Callable, *args, **kwargs
+    ):
+        """Helper function for ensuring exclusive generation when adapters are present. Necessary to prevent generating with incorrect weights."""
+        with self._generation_lock:
+            if adapter_name != "":
+                self.load_adapter(adapter_name)
+                self._model.set_adapter(adapter_name)
+            else:
+                try:
+                    # `._model.disable_adapters()` doesn't seem to actually disable them or
+                    # remove them from the model's list of `.active_adapters()`.
+                    self._model.set_adapter([])
+                except ValueError as e:
+                    # If no weights have been loaded, the model will raise a ValueError:
+                    # `ValueError("No adapter loaded. Please load an adapter first.")`
+                    if "No adapter loaded" in str(e):
+                        pass
+                    else:
+                        raise e
+
+            _assert_correct_adapters(adapter_name, self._model)
+            out = generate_func(*args, **kwargs)
+            _assert_correct_adapters(adapter_name, self._model)
+            return out
+
     async def _generate_from_intrinsic(
         self, action: Intrinsic, ctx: Context, *, model_options: dict[str, Any]
     ) -> ModelOutputThunk:
@@ -317,27 +347,21 @@ async def _generate_from_intrinsic(
         # so we will have to invalidate the cache on our side. This requires
         # us having specific caching for each Component/Message.
 
-        self.load_adapter(adapter.qualified_name)
-
-        # TODO: This modifies the underlying model. We should set a non-exclusive lock here.
-        # It should allow generate requests with the same adapter to proceed. This logic also
-        # needs to be added to the other generate functions.
-        self._model.set_adapter(adapter.qualified_name)
-
         generate_input, other_input = (
             granite_common.util.chat_completion_request_to_transformers_inputs(
                 rewritten, self._tokenizer, self._model
             )
         )
 
-        chat_response: Coroutine[Any, Any, granite_common.ChatCompletionResponse] = (
-            asyncio.to_thread(
-                granite_common.util.generate_with_transformers,
-                self._tokenizer,
-                self._model,
-                generate_input,
-                other_input,
-            )
+        chat_response = asyncio.to_thread(
+            self._generate_with_adapter_lock,
+            adapter.qualified_name,
+            granite_common.util.generate_with_transformers,
+            # Passed as args/kwargs to generate.
+            self._tokenizer,
+            self._model,
+            generate_input,
+            other_input,
         )
 
         output = ModelOutputThunk(None)
@@ -490,7 +514,10 @@ async def _generate_from_context_standard(
         generate_options = self._filter_chat_template_only_options(model_options)
 
         chat_response = asyncio.to_thread(
+            self._generate_with_adapter_lock,
+            "",  # Empty for no adapters.
             self._model.generate,  # type: ignore
+            # Passed as args/kwargs to generate.
             input_ids,
             return_dict_in_generate=True,
             output_scores=True,
@@ -664,42 +691,41 @@ async def generate_from_raw(
             self._device
         )
 
-        if format is None:
-            outputs = await asyncio.to_thread(
-                self._model.generate,  # type: ignore
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-                return_dict_in_generate=True,
-                output_scores=True,
-                **self._make_backend_specific_and_remove(model_opts),
-            )
-        else:
+        format_kwargs = {}
+        if format:
+            # outlines.generate.json always parses the resulting json into a python dict.
+            # We however want to keep it as a json string for later storing it in ModelOutputThunk
             schema: dict[str, Any] = format.model_json_schema()
             schema_json: str = json.dumps(schema)
-            regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema(
+            regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema(  # type: ignore
                schema_json
            )
 
             from outlines.models.transformers import TransformerTokenizer
-            from outlines.processors import RegexLogitsProcessor
+            from outlines.processors.structured import RegexLogitsProcessor
             from transformers import LogitsProcessorList
 
-            outputs = await asyncio.to_thread(
-                self._model.generate,  # type: ignore
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-                return_dict_in_generate=True,
-                output_scores=True,
-                logits_processor=LogitsProcessorList(
-                    [
-                        RegexLogitsProcessor(
-                            regex_str, tokenizer=TransformerTokenizer(self._tokenizer)
-                        )
-                    ]
-                ),
-                **self._make_backend_specific_and_remove(model_opts),
+            format_kwargs["logits_processor"] = LogitsProcessorList(
+                [
+                    RegexLogitsProcessor(
+                        regex_str, tokenizer=TransformerTokenizer(self._tokenizer)
+                    )
+                ]
            )
+        outputs = await asyncio.to_thread(
+            self._generate_with_adapter_lock,
+            "",  # Empty for no adapter.
+            self._model.generate,  # type: ignore
+            # Passed as args/kwargs to generate.
+            input_ids=inputs["input_ids"],
+            attention_mask=inputs["attention_mask"],
+            return_dict_in_generate=True,
+            output_scores=True,
+            **self._make_backend_specific_and_remove(model_opts),
+            **format_kwargs,
+        )
+
         sequences_to_decode = [
             sequence[inputs["input_ids"][i].size(0) :]  # type: ignore
             for i, sequence in enumerate(outputs.sequences)
         ]
@@ -853,7 +879,7 @@ def add_adapter(self, adapter: LocalHFAdapter):
         self._added_adapters[adapter.qualified_name] = adapter
 
     def load_adapter(self, adapter_qualified_name: str):
-        """Loads the given adapter for the backend. Must have previously been added."""
+        """Loads the given adapter for the backend. Must have previously been added. Do not call when generation requests are happening."""
         adapter = self._added_adapters.get(adapter_qualified_name, None)
         if adapter is None:
             raise ValueError(
@@ -880,7 +906,7 @@ def load_adapter(self, adapter_qualified_name: str):
         # Loading an adapter activates it. We disable adapters immediately after.
-        # Prefer this over `.disable_adapters()`; the disable function doesn't always
-        # seem to work.
-        self._model.set_adapter([])
+        # `_generate_with_adapter_lock` re-activates the adapter when a generation
+        # request actually needs it.
+        self._model.disable_adapters()
         self._loaded_adapters[adapter.qualified_name] = adapter
 
     def unload_adapter(self, adapter_qualified_name: str):
@@ -906,6 +932,38 @@ def list_adapters(self) -> list[str]:
         return list(self._loaded_adapters.keys())
 
 
+def _assert_correct_adapters(expected_state: str, model: PreTrainedModel):
+    """When generating with a huggingface model, this can be used to ensure the correct adapters are active.
+
+    Args:
+        expected_state: the qualified name of the adapter expected to be active, or "" if no adapter should be active
+        model: the model underlying the LocalHFBackend; this is the model the adapters are activated on
+    """
+    try:
+        active = model.active_adapters()
+
+        if expected_state == "":
+            assert len(active) == 0, (
+                f'no adapters should be active if expected state is "", got "{active[0]}"'
+            )
+        else:
+            assert len(active) == 1, (
+                f'one adapter should be active if expected state is "{expected_state}"'
+            )
+            assert active[0] == expected_state, (
+                f'the active adapter "{active[0]}" doesn\'t match the expected state: "{expected_state}"'
+            )
+    except ValueError as e:
+        # If no weights have been loaded, the model will raise a ValueError:
+        # `ValueError("No adapter loaded. Please load an adapter first.")`
+        if "No adapter loaded" in str(e):
+            assert expected_state == "", (
+                f'got no adapters loaded but expected state is "{expected_state}"'
+            )
+        else:
+            raise e
+
+
 class HFProcessRewardModel(PRM, abc.ABC):
     """A Process Reward Model that works with a huggingface backend."""
 
diff --git a/test/backends/test_huggingface.py b/test/backends/test_huggingface.py
index c2a5497f..ba55b9f3 100644
--- a/test/backends/test_huggingface.py
+++ b/test/backends/test_huggingface.py
@@ -1,17 +1,26 @@
 import asyncio
+from copy import copy
+import faulthandler
+import random
+import time
+from typing import Any, Coroutine
+from unittest.mock import Mock
 
 import pydantic
 import pytest
+import torch
 from typing_extensions import Annotated
 
 from mellea import MelleaSession
 from mellea.backends.adapters.adapter import GraniteCommonAdapter
 from mellea.backends.cache import SimpleLRUCache
 from mellea.backends.formatter import TemplateFormatter
-from mellea.backends.huggingface import LocalHFBackend
+from mellea.backends.huggingface import LocalHFBackend, _assert_correct_adapters
 from mellea.backends.types import ModelOption
 from mellea.stdlib.base import (CBlock, ChatContext, Context, ModelOutputThunk,
                                 SimpleContext)
+from mellea.stdlib.chat import Message
+from mellea.stdlib.intrinsics.intrinsic import Intrinsic
 from mellea.stdlib.requirement import (ALoraRequirement, LLMaJRequirement,
                                        Requirement, ValidationResult,
                                        default_output_to_bool)
@@ -30,6 +39,11 @@ def backend():
             "requirement_check", base_model_name=backend.base_model_name
         )
     )
+    backend.add_adapter(
+        GraniteCommonAdapter(
+            "answerability", base_model_name=backend.base_model_name
+        )
+    )
     return backend
 
 
@@ -291,6 +305,210 @@ async def test_async_avalue(session):
     assert m1_final_val is not None
     assert m1_final_val == mot1.value
 
+@pytest.mark.qualitative
+async def test_generate_with_lock(backend):
+    # Enable the faulthandler for this test.
+    faulthandler.enable(all_threads=True)
+
+    # Create local versions of these objects so that mocking
+    # doesn't impact other functions. Don't do this in regular code,
+    # the copying is complex.
+    b: LocalHFBackend = copy(backend)
+    model = copy(b._model)
+    b._model = model
+    b._added_adapters = {}
+    b._loaded_adapters = {}
+    b.add_adapter(
+        GraniteCommonAdapter(
+            "requirement_check", base_model_name=b.base_model_name
+        )
+    )
+    b.add_adapter(
+        GraniteCommonAdapter(
+            "answerability", base_model_name=b.base_model_name
+        )
+    )
+
+    memoized = dict()
+    gen_func = model.generate
+    def mock_func(input_ids, *args, **kwargs):
+        """Mocks the generate function. Must call `populate_mocked_dict` with each input that must be cached before using this."""
+        for key, val in memoized.items():
+            if torch.equal(key, input_ids):
+                time.sleep(random.uniform(.1, .5))  # Simulate a bit of work.
+                return val
+        assert False, "did not get a cached response"
+
+    # Safely create the dict.
+    def populate_mocked_dict(input_ids, *args, **kwargs):
+        """Generates the model output and adds to the memoized dict."""
+        output = gen_func(input_ids, *args, **kwargs)  # type: ignore
+        memoized[input_ids] = output
+        return output
+
+    model.generate = Mock(side_effect=populate_mocked_dict)
+    assert not isinstance(backend._model, Mock), "mocking went wrong; backend fixture changed; other tests may fail"
+
+    # Set up the inputs.
+    ctx = ChatContext().add(Message("user", "hello"))
+    act = CBlock("hello")
+    raw_act = CBlock("goodb")
+    req_intrinsic = Intrinsic("requirement_check", {"requirement": "did nothing"})
+    answerability_intrinsic = Intrinsic("answerability")
+
+    def call_backend_generate():
+        """Helper function for generating outputs."""
+        return [
+            b.generate_from_context(act, ctx),
+            b.generate_from_context(req_intrinsic, ctx),
+            b.generate_from_context(answerability_intrinsic, ctx),
+            b.generate_from_raw([raw_act], ctx, model_options={ModelOption.MAX_NEW_TOKENS: 3})
+        ]
+
+    # Call once to populate the memoized mock.
+    outputs = await asyncio.gather(*call_backend_generate())
+    for output in outputs:
+        mot = output[0]
+        await mot.avalue()  # Ensure all values are computed.
+
+    # Use the memoized mock that errors if not precomputed.
+    model.generate = Mock(side_effect=mock_func)
+    count = 5  # Use a high number to try to put pressure on the lock and catch deadlocks.
+    coros: list[Coroutine[Any, Any, tuple[ModelOutputThunk, Context]]] = []
+    for _ in range(count):
+        coros.extend(call_backend_generate())
+
+    # Ensure no ordering effects are happening.
+    random.shuffle(coros)
+
+    outputs = await asyncio.gather(*coros)
+    for output in outputs:
+        mot = output[0]
+        await mot.avalue()  # Ensure all values get computed.
+
+    faulthandler.disable()
+
+@pytest.mark.qualitative
+async def test_generate_with_lock_does_not_block_when_awaiting_value(backend):
+    """This is a tricky test to set up.
+
+    Its purpose is to ensure that a long-running generation doesn't get blocked
+    when awaiting the `model_output_thunk.avalue()` of a different generation request.
+
+    This means that it is somewhat timing-dependent. The generation has to take long enough
+    to not instantly resolve but not longer than the timeout. Modify the parameters below to
+    fine-tune this.
+
+    If generation is taking too long, you could just increase the timeout, but that
+    causes the test to take longer to run. The best scenario is that the generation doesn't
+    resolve before awaiting the other `mot.avalue()` but resolves immediately after.
+    """
+    # Params to modify depending on speed.
+    token_generation_length = 100
+    timeout_in_seconds = 30
+
+    # Set up the inputs.
+    ctx = ChatContext().add(Message("user", "hello"))
+    act = CBlock("hello")
+    req_intrinsic = Intrinsic("requirement_check", {"requirement": "did nothing"})
+    answerability_intrinsic = Intrinsic("answerability")
+
+    # Create a few model output thunks:
+    # - a streaming generation that will take a long time to resolve.
+    # - a regular generation that should be able to happen while the streaming is happening.
+    # - two intrinsics that shouldn't be able to happen concurrently.
+    reg_mot_stream, _ = await backend.generate_from_context(act, ctx, model_options={ModelOption.STREAM: True, ModelOption.MAX_NEW_TOKENS: token_generation_length, "min_length": token_generation_length})
+    reg_mot, _ = await backend.generate_from_context(act, ctx)
+    req_mot, _ = await backend.generate_from_context(req_intrinsic, ctx, model_options={ModelOption.STREAM: True})
+    answerability_mot, _ = await backend.generate_from_context(answerability_intrinsic, ctx, model_options={ModelOption.STREAM: True})
+
+    # Ensure the stream is generating but not yet completing.
+    await reg_mot_stream.astream()
+    assert not reg_mot_stream.is_computed(), "generation completed too early, see test for more details"
+
+    # Awaiting this shouldn't cause a deadlock. Add the timeout so the test can fail.
+    # If the test fails, this means that the streaming generation wasn't able to complete,
+    # most likely due to a deadlock caused by awaiting a generation that cannot complete until
+    # the streaming is done.
+    try:
+        async with asyncio.timeout(timeout_in_seconds):
+            await req_mot.avalue()
+    except Exception as e:
+        # The timeout could also be caused by the generation taking too long... be careful!
+        # We assume that if the streaming model output thunk is computed after getting its astream here,
+        # that it was a deadlock and not the generation taking too long (since the generation is now done).
+        await reg_mot_stream.astream()
+        if reg_mot_stream.is_computed():
+            raise e
+        else:
+            raise Exception("generation did not finish before the timeout; see the test docstring for tuning")
+
+    for output in [reg_mot_stream, reg_mot, req_mot, answerability_mot]:
+        if not output.is_computed():
+            await output.avalue()  # Ensure everything gets computed.
+
+@pytest.mark.qualitative
+async def test_error_during_generate_with_lock(backend):
+    # Create local versions of these objects so that mocking
+    # doesn't impact other functions. Don't do this in regular code,
+    # the copying is complex.
+    b: LocalHFBackend = copy(backend)
+    model = copy(b._model)
+    b._model = model
+    b._model.set_adapter([])
+    b._added_adapters = {}
+    b._loaded_adapters = {}
+    b.add_adapter(
+        GraniteCommonAdapter(
+            "requirement_check", base_model_name=b.base_model_name
+        )
+    )
+
+    regular_generate = b._model.generate
+    def generate_and_raise_exc(*args, **kwargs):
+        """Will generate like usual for the intrinsic request. Will fail for the regular generation request."""
+        if "max_new_tokens" in kwargs:
+            return regular_generate(*args, **kwargs)  # type: ignore
+        raise Exception("Oops!")
+
+    b._model.generate = Mock(side_effect=generate_and_raise_exc)
+    assert not isinstance(backend._model, Mock), "mocking went wrong; backend fixture changed; other tests may fail"
+
+    # Set up the inputs.
+    ctx = ChatContext().add(Message("user", "hello"))
+    act = CBlock("hello")
+    req_intrinsic = Intrinsic("requirement_check", {"requirement": "did nothing"})
+
+    reg_mot, _ = await b.generate_from_context(act, ctx)
+    req_mot, _ = await b.generate_from_context(req_intrinsic, ctx)
+
+    with pytest.raises(Exception, match="Oops!"):
+        await reg_mot.avalue()
+
+    await req_mot.avalue()
+
+def test_assert_correct_adapters():
+    model = Mock()
+
+    # Test scenarios with no active adapters.
+    model.active_adapters = Mock(return_value=[])
+    _assert_correct_adapters("", model)
+    with pytest.raises(AssertionError):
+        _assert_correct_adapters("new", model)
+
+    # Test scenarios with one active adapter.
+    model.active_adapters = Mock(return_value=["new"])
+    with pytest.raises(AssertionError):
+        _assert_correct_adapters("", model)
+    with pytest.raises(AssertionError):
+        _assert_correct_adapters("diff", model)
+    _assert_correct_adapters("new", model)
+
+    # Test scenarios when no adapters have been loaded.
+    model.active_adapters = Mock(side_effect=ValueError("No adapter loaded. Please load an adapter first."))
+    _assert_correct_adapters("", model)  # This will fail if peft ever changes the error message.
+    with pytest.raises(AssertionError):
+        _assert_correct_adapters("new", model)
 
 if __name__ == "__main__":
     import pytest
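Below is a minimal, self-contained sketch of the locking pattern that `_generate_with_adapter_lock` introduces: every generation request funnels through a single `threading.Lock`, the active adapter is switched only while the lock is held, and the adapter state is checked before and after the wrapped generate call. `FakeModel` and `fake_generate` are hypothetical stand-ins for the PEFT-wrapped model and `self._model.generate`; only the control flow mirrors the patch.

import threading
from concurrent.futures import ThreadPoolExecutor


class FakeModel:
    """Hypothetical stand-in for the PEFT-wrapped transformers model."""

    def __init__(self) -> None:
        self.active: list[str] = []

    def set_adapter(self, name) -> None:
        # Mirrors the intent of peft's set_adapter: a name activates that adapter, [] deactivates all.
        self.active = [name] if isinstance(name, str) and name else []


_lock = threading.Lock()
_model = FakeModel()


def generate_with_adapter_lock(adapter_name: str, generate_func, *args, **kwargs):
    """Serialize generation so the active adapter cannot change mid-call."""
    with _lock:
        # Activate the requested adapter (or none) while holding the lock.
        _model.set_adapter(adapter_name if adapter_name else [])
        expected = [adapter_name] if adapter_name else []
        assert _model.active == expected
        out = generate_func(*args, **kwargs)
        # The adapter state must be unchanged when generation finishes.
        assert _model.active == expected
        return out


def fake_generate(prompt: str) -> str:
    # Stand-in for model.generate; records which adapters were active for this call.
    return f"{prompt} -> active={_model.active}"


if __name__ == "__main__":
    with ThreadPoolExecutor(max_workers=4) as pool:
        futures = [
            pool.submit(generate_with_adapter_lock, "", fake_generate, "plain"),
            pool.submit(generate_with_adapter_lock, "requirement_check", fake_generate, "req"),
            pool.submit(generate_with_adapter_lock, "answerability", fake_generate, "ans"),
        ]
        for f in futures:
            print(f.result())

Running the sketch prints each prompt alongside the adapter that was active for that call, regardless of thread scheduling, because the switch-generate-check sequence is serialized by the lock; this is the property the new tests exercise against the real backend.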