311 changes: 266 additions & 45 deletions mellea/backends/huggingface.py
@@ -12,6 +12,7 @@
import functools
import inspect
import json
import threading
from collections.abc import Callable, Coroutine
from copy import deepcopy
from typing import TYPE_CHECKING, Any, cast
@@ -181,6 +182,8 @@ def __init__(
# Adapters can be made known to the backend (added) and loaded.
self._added_adapters: dict[str, LocalHFAdapter] = {}
self._loaded_adapters: dict[str, LocalHFAdapter] = {}
self._generate_lock = HFGenerationLock(self)
"""Necessary for generation since adapters alter the underlying model. Use '' with regular generation requests."""

async def generate_from_context(
self,
@@ -317,27 +320,27 @@ async def _generate_from_intrinsic(
# so we will have to invalidate the cache on our side. This requires
# us having specific caching for each Component/Message.

self.load_adapter(adapter.qualified_name)

# TODO: This modifies the underlying model. We should set a non-exclusive lock here.
# It should allow generate requests with the same adapter to proceed. This logic also
# needs to be added to the other generate functions.
self._model.set_adapter(adapter.qualified_name)

generate_input, other_input = (
granite_common.util.chat_completion_request_to_transformers_inputs(
rewritten, self._tokenizer, self._model
)
)

chat_response: Coroutine[Any, Any, granite_common.ChatCompletionResponse] = (
asyncio.to_thread(
granite_common.util.generate_with_transformers,
self._tokenizer,
self._model,
generate_input,
other_input,
)
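# Wrap the blocking generate call so it runs while holding the generation lock:
# the adapter named by `adapter.qualified_name` stays active for the whole call,
# and the surrounding assertions verify no other thread swapped adapters mid-generation.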
def generate_intrinsic_with_lock(
*args, **kwargs
) -> granite_common.ChatCompletionResponse:
with self._generate_lock.get_lock(adapter.qualified_name):
_assert_correct_adapters(adapter.qualified_name, self._model)
output = granite_common.util.generate_with_transformers(*args, **kwargs) # type: ignore
_assert_correct_adapters(adapter.qualified_name, self._model)
return output

chat_response = asyncio.to_thread(
generate_intrinsic_with_lock,
self._tokenizer,
self._model,
generate_input,
other_input,
)

output = ModelOutputThunk(None)
@@ -369,7 +372,6 @@ async def granite_common_processing(
input_ids=generate_input["input_tokens"],
)

# TODO: Post-processing should release the lock for this generation.
output._post_process = functools.partial(
self.post_processing,
conversation=conversation,
@@ -489,8 +491,15 @@ async def _generate_from_context_standard(
# Filter out chat template-only options before passing to generate()
generate_options = self._filter_chat_template_only_options(model_options)

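# Same locking pattern as the intrinsic path, but with state "": the lock
# guarantees the base model (no active adapters) for the duration of the call.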
def generate_with_lock(*args, **kwargs):
with self._generate_lock.get_lock(""):
_assert_correct_adapters("", self._model)
output = self._model.generate(*args, **kwargs) # type: ignore
_assert_correct_adapters("", self._model)
return output

chat_response = asyncio.to_thread(
self._model.generate, # type: ignore
generate_with_lock,
input_ids,
return_dict_in_generate=True,
output_scores=True,
@@ -664,42 +673,45 @@ async def generate_from_raw(
self._device
)

if format is None:
outputs = await asyncio.to_thread(
self._model.generate, # type: ignore
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
return_dict_in_generate=True,
output_scores=True,
**self._make_backend_specific_and_remove(model_opts),
)
else:
format_kwargs = {}
if format:
# outlines.generate.json always parses the resulting JSON into a Python dict.
# However, we want to keep it as a JSON string for later storage in ModelOutputThunk.
schema: dict[str, Any] = format.model_json_schema()
schema_json: str = json.dumps(schema)
regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema(
regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema( # type: ignore
schema_json
)

from outlines.models.transformers import TransformerTokenizer
from outlines.processors import RegexLogitsProcessor
from outlines.processors.structured import RegexLogitsProcessor
from transformers import LogitsProcessorList

outputs = await asyncio.to_thread(
self._model.generate, # type: ignore
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
return_dict_in_generate=True,
output_scores=True,
logits_processor=LogitsProcessorList(
[
RegexLogitsProcessor(
regex_str, tokenizer=TransformerTokenizer(self._tokenizer)
)
]
),
**self._make_backend_specific_and_remove(model_opts),
format_kwargs["logits_processor"] = LogitsProcessorList(
[
RegexLogitsProcessor(
regex_str, tokenizer=TransformerTokenizer(self._tokenizer)
)
]
)

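# As above: hold the lock with state "" so raw generation always runs against
# the base model, even if another request loads an adapter concurrently.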
def generate_raw_with_lock(*args, **kwargs):
with self._generate_lock.get_lock(""):
_assert_correct_adapters("", self._model)
output = self._model.generate(*args, **kwargs) # type: ignore
_assert_correct_adapters("", self._model)
return output

outputs = await asyncio.to_thread(
generate_raw_with_lock,
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
return_dict_in_generate=True,
output_scores=True,
**self._make_backend_specific_and_remove(model_opts),
**format_kwargs,
)

sequences_to_decode = [
sequence[inputs["input_ids"][i].size(0) :] # type: ignore
for i, sequence in enumerate(outputs.sequences)
@@ -853,13 +865,16 @@ def add_adapter(self, adapter: LocalHFAdapter):
self._added_adapters[adapter.qualified_name] = adapter

def load_adapter(self, adapter_qualified_name: str):
"""Loads the given adapter for the backend. Must have previously been added."""
"""Loads the given adapter for the backend. Must have previously been added. Do not manually call while generate requests are happening; will be called automatically."""
adapter = self._added_adapters.get(adapter_qualified_name, None)
if adapter is None:
raise ValueError(
f"could not load adapter {adapter_qualified_name} for backend {self}: adapter was not previously added"
)

if self._loaded_adapters.get(adapter_qualified_name, None) is not None:
return # Exit early since it's already loaded.

try:
adapter_kwargs = {}

@@ -880,7 +895,6 @@ def load_adapter(self, adapter_qualified_name: str):
# Loading an adapter activates it. We disable adapters immediately after.
# Prefer this over `.disable_adapters()`; the disable function doesn't always
# seem to work.
self._model.set_adapter([])
self._loaded_adapters[adapter.qualified_name] = adapter

def unload_adapter(self, adapter_qualified_name: str):
@@ -957,3 +971,210 @@ def stepify(self, content: str, step_separator: str) -> list[str]:
step.strip() for step in content.split(step_separator) if step.strip() != ""
]
return list_of_steps


# The current implementation of this lock must be defined in this file because the backend requires a reference to
# it and the lock requires an understanding of LocalHFBackends. Because generation also requires loading/activating
# the correct adapter (or no adapter), integrating those load/activate checks into the lock's state changes made
# reasoning and usage much easier. A similar approach could probably be implemented with multiple locks (which would
# keep this one generic) but would potentially require users to do more work. If we eventually need to refactor this
# lock to support other backends, we can do that at that point in time (since these APIs are all internal).
class HFGenerationLock:
"""A lock-like object. Used to prevent concurrent generation from different adapters on the same backend.

Note: Should only be used with `asyncio.to_thread` or `threading.Thread(...).start()`. It can block if called multiple
times from the same thread.
"""

def __init__(self, backend: LocalHFBackend):
"""A lock-like object. Used to prevent concurrent generation from different adapters on the same backend.

Notes:
- Should only be used with `asyncio.to_thread` or `threading.Thread(...).start()`. It can block if called multiple times from the same thread.
- This lock prioritizes acquirers with a state equal to the current state.
- Typically easiest to use with `with` syntax: `with lock.get_lock(<state>): ...`
"""
self.backend = backend
"""since adapter management is included in this lock, we set the backend at init"""

self.current_state: str = ""
"""the current state of the lock; usually reflects the model/adapter name; empty string is base model"""

self.num_active: int = 0
"""counts the number of active lock holders"""

# Include a timeout to guard against deadlocks from waiting indefinitely. No deadlocks should occur,
# since each waiter appends its event to the list before attempting to acquire the lock; even if a
# waiter fails to acquire the lock, a later release call will set its event and end the wait.
# Control flow scenarios:
# - Fail to acquire lock -> immediately wait -> release is called elsewhere -> notified and acquire lock
# - Fail to acquire lock -> release is called elsewhere -> wait but see that it's already good to go -> acquire lock
self.timeout: float | None = 5
"""timeout in seconds to wait before trying to acquire the lock again"""

self.lock = threading.Lock()
"""a lock to prevent concurrent modification of internal properties"""

self.events: list[threading.Event] = []
"""a list of waiters; allows notifying single or multiple waiters"""

class GenLock:
"""Necessary for `with` syntax. Enables not using try-finally syntax everywhere."""

def __init__(self, state: str, lock: HFGenerationLock) -> None:
"""Necessary for `with` syntax. Enables not using try-finally syntax everywhere.

Args:
state: the state associated with this locking operation
lock: the parent lock associated with this locking operation
"""
self.state = state
self.lock = lock

def __enter__(self):
"""Acquire the lock with a given state."""
self.lock.acquire(self.state)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
"""Release the lock."""
self.lock.release()

# Re-raise the exception if needed.
if exc_val is not None:
raise exc_val

def get_lock(self, state):
"""Used for with statements.

Examples:
>>> # in a LocalHFBackend
>>> state = adapter.qualified_name # or "" for base model
>>> with self._generate_lock.get_lock(state):
... ... # Generate requests here.
"""
return self.GenLock(state, self)

def acquire(self, state: str):
"""Acquire the 'lock'. Only call once per thread.

Args:
state: the adapter qualified name or "" if the base model
"""
# Notifier for this acquirer.
event = threading.Event()
self.events.append(event)

while True:
if self.lock.acquire(False):
if self.current_state == state or self.num_active == 0:
# Breaking from this loop and the below operations must be atomic; include acquiring the lock
# as part of the condition.

# When `self.current_state == state`, this means that other generation operations with the
# current state are happening. There's no need to block this request.

# When `self.num_active == 0`, this means there's no other generation requests happening. Allow
# a single waiter to break and set the new state.
break
else:
# Have to acquire the lock to check the variables but immediately release it if comparisons fail.
self.lock.release()

# Wait until notified of a release. Add a timeout just in case (see notes in init).
event.wait(self.timeout)

# Reset this waiter so that it patiently waits to be notified again if unable to break from the loop.
event.clear()

self.num_active += 1

# This waiter will never wait again. Remove its event.
self.events.remove(event)

if self.current_state != state:
assert self.num_active == 1, "only this waiter should be active"

# When swapping states, we need to make sure the correct adapter is set.
if state != "":
# Adapter.
try:
# Ensure the adapter is loaded before setting it.
self.backend.load_adapter(state)
self.backend._model.set_adapter(state)
except Exception as e:
# If something goes wrong, the internal state hasn't changed.
# We also have to release the internal lock so that future requests can go through.
self.lock.release()
raise e
else:
# Base Model.
try:
# We can't know if adapters have been loaded / set previously.
# This call will throw an exception if none have been.
self.backend._model.set_adapter([])
except Exception:
pass

# Hold the internal lock until current_state has been set to the new state value.
self.current_state = state
self.lock.release()

# Notify all events. Some might be using the same model/adapter.
for event in self.events:
event.set()
else:
# Or, we immediately release the lock if we don't need to change current_state.
self.lock.release()

def release(self):
"""Release a single hold on the lock. Should only call once per `acquire()` call."""
# Grab the internal lock to prevent concurrent modifications.
with self.lock:
self.num_active -= 1

assert self.num_active > -1, f"release on {self} called too many times"

# Create a local var to track the number active. This lets us release the lock
# before notifying the single waiter.
snapshot_num_active = self.num_active

# If there are no active holds on this lock, notify a single waiter if one exists.
# This also likely means that no waiters with states equal to the current_state exist,
# and a new current_state will be set.
if snapshot_num_active == 0:
if len(self.events) > 0:
self.events[0].set()

def __str__(self) -> str:
"""Stringify the HFGenerationLock."""
return f"{self.current_state}: {self.num_active}"

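# Illustrative usage sketch (not part of this diff). Assumes a constructed
# LocalHFBackend named `backend` and an added adapter whose hypothetical
# qualified name is "granite/rewriter":
#
#     lock = HFGenerationLock(backend)
#
#     def run(state: str) -> None:
#         with lock.get_lock(state):
#             ...  # blocking generate call for `state`
#
#     threading.Thread(target=run, args=("",)).start()  # base model
#     threading.Thread(target=run, args=("",)).start()  # same state; proceeds concurrently
#     threading.Thread(target=run, args=("granite/rewriter",)).start()  # waits until num_active == 0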

def _assert_correct_adapters(expected_state: str, model: PreTrainedModel):
"""When generating with a huggingface model and a hf generation lock, this can be used to ensure the correct adapters are active.

Args:
expected_state: the current state of the lock
model: the model underlying the LocalHFBackend; this is the model the adapters are activated on
"""
try:
active = model.active_adapters()

if expected_state == "":
assert len(active) == 0, (
f'no adapters should be active if expected state is "", got "{active[0]}"'
)
else:
assert len(active) == 1, (
f'one adapter should be active if expected state is "{expected_state}"'
)
assert active[0] == expected_state, (
f'the active adapter "{active[0]}" doesn\'t match the expected state: "{expected_state}"'
)
except ValueError:
# If no weights have been loaded, the model will raise a ValueError:
# `ValueError("No adapter loaded. Please load an adapter first.")`
assert expected_state == "", (
'expected state must be "" if no adapters have been loaded'
)
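# Illustrative check (not part of this diff): inside a held lock, the assertion
# should pass both before and after the blocking call:
#
#     with self._generate_lock.get_lock(""):
#         _assert_correct_adapters("", self._model)  # nothing active
#         out = self._model.generate(...)
#         _assert_correct_adapters("", self._model)  # still nothing active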