144 changes: 101 additions & 43 deletions mellea/backends/huggingface.py
@@ -12,6 +12,7 @@
import functools
import inspect
import json
import threading
from collections.abc import Callable, Coroutine
from copy import deepcopy
from typing import TYPE_CHECKING, Any, cast
@@ -182,6 +183,9 @@ def __init__(
self._added_adapters: dict[str, LocalHFAdapter] = {}
self._loaded_adapters: dict[str, LocalHFAdapter] = {}

self._generation_lock = threading.Lock()
"""Used to force generation requests to be non-concurrent. Necessary for preventing issues with adapters."""

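As an aside, a minimal, self-contained sketch of why the lock matters (the names below are illustrative and not part of this file): `asyncio.to_thread` runs blocking work in a thread pool, so two in-flight requests share the same model object, and only the lock keeps "switch adapter, then generate" atomic.

import asyncio
import threading

_lock = threading.Lock()
_active_adapter = None  # stands in for the shared model state mutated by set_adapter

def _locked_generate(adapter_name):
    global _active_adapter
    with _lock:
        _active_adapter = adapter_name  # "switch adapters"
        # Without the lock, a concurrent call could change _active_adapter here.
        return f"generated with {_active_adapter}"

async def main():
    # Both calls are dispatched to worker threads, but the lock serializes them,
    # so each generation runs with the adapter it selected.
    print(await asyncio.gather(
        asyncio.to_thread(_locked_generate, "adapter_a"),
        asyncio.to_thread(_locked_generate, "adapter_b"),
    ))

asyncio.run(main())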
async def generate_from_context(
self,
action: Component | CBlock,
@@ -245,6 +249,32 @@ async def generate_from_context
)
return mot, ctx.add(action).add(mot)

def _generate_with_adapter_lock(
self, adapter_name: str, generate_func: Callable, *args, **kwargs
):
"""Helper function for ensuring exclusive generation when adapters are present. Necessary to prevent generating with incorrect weights."""
with self._generation_lock:
if adapter_name != "":
self.load_adapter(adapter_name)
self._model.set_adapter(adapter_name)
else:
try:
# `._model.disable_adapters()` doesn't seem to actually disable them or
# remove them from the model's list of `.active_adapters()`.
self._model.set_adapter([])
except ValueError as e:
# If no weights have been loaded, the model will raise a ValueError:
# `ValueError("No adapter loaded. Please load an adapter first.")`
if "No adapter loaded" in str(e):
pass
else:
raise e

_assert_correct_adapters(adapter_name, self._model)
out = generate_func(*args, **kwargs)
_assert_correct_adapters(adapter_name, self._model)
return out

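For illustration, a hedged sketch of how a call site might use this helper; `backend` and `input_ids` are placeholders for a constructed LocalHFBackend and an already-tokenized prompt, and the adapter name is hypothetical:

def sketch_generate(backend, input_ids):
    # Generate with a previously added adapter's weights active.
    with_adapter = backend._generate_with_adapter_lock(
        "example_org/example_adapter",  # hypothetical qualified adapter name
        backend._model.generate,
        input_ids,
        max_new_tokens=32,
    )
    # Generate with the base model only; "" means no adapter is activated.
    base_only = backend._generate_with_adapter_lock(
        "",
        backend._model.generate,
        input_ids,
        max_new_tokens=32,
    )
    return with_adapter, base_only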
async def _generate_from_intrinsic(
self, action: Intrinsic, ctx: Context, *, model_options: dict[str, Any]
) -> ModelOutputThunk:
@@ -317,27 +347,21 @@ async def _generate_from_intrinsic(
# so we will have to invalidate the cache on our side. This requires
# us having specific caching for each Component/Message.

self.load_adapter(adapter.qualified_name)

# TODO: This modifies the underlying model. We should set a non-exclusive lock here.
# It should allow generate requests with the same adapter to proceed. This logic also
# needs to be added to the other generate functions.
self._model.set_adapter(adapter.qualified_name)

generate_input, other_input = (
granite_common.util.chat_completion_request_to_transformers_inputs(
rewritten, self._tokenizer, self._model
)
)

chat_response: Coroutine[Any, Any, granite_common.ChatCompletionResponse] = (
asyncio.to_thread(
granite_common.util.generate_with_transformers,
self._tokenizer,
self._model,
generate_input,
other_input,
)
chat_response = asyncio.to_thread(
self._generate_with_adapter_lock,
adapter.qualified_name,
granite_common.util.generate_with_transformers,
# Passed as args/kwargs to generate.
self._tokenizer,
self._model,
generate_input,
other_input,
)

output = ModelOutputThunk(None)
@@ -490,7 +514,10 @@ async def _generate_from_context_standard(
generate_options = self._filter_chat_template_only_options(model_options)

chat_response = asyncio.to_thread(
self._generate_with_adapter_lock,
"", # Empty for no adapters.
self._model.generate, # type: ignore
# Passed as args/kwargs to generate.
input_ids,
return_dict_in_generate=True,
output_scores=True,
@@ -664,42 +691,41 @@ async def generate_from_raw(
self._device
)

if format is None:
outputs = await asyncio.to_thread(
self._model.generate, # type: ignore
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
return_dict_in_generate=True,
output_scores=True,
**self._make_backend_specific_and_remove(model_opts),
)
else:
format_kwargs = {}
if format:
# outlines.generate.json always parses the resulting JSON into a Python dict.
# However, we want to keep it as a JSON string so it can later be stored in the ModelOutputThunk.
schema: dict[str, Any] = format.model_json_schema()
schema_json: str = json.dumps(schema)
regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema(
regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema( # type: ignore
schema_json
)

from outlines.models.transformers import TransformerTokenizer
from outlines.processors import RegexLogitsProcessor
from outlines.processors.structured import RegexLogitsProcessor
from transformers import LogitsProcessorList

outputs = await asyncio.to_thread(
self._model.generate, # type: ignore
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
return_dict_in_generate=True,
output_scores=True,
logits_processor=LogitsProcessorList(
[
RegexLogitsProcessor(
regex_str, tokenizer=TransformerTokenizer(self._tokenizer)
)
]
),
**self._make_backend_specific_and_remove(model_opts),
format_kwargs["logits_processor"] = LogitsProcessorList(
[
RegexLogitsProcessor(
regex_str, tokenizer=TransformerTokenizer(self._tokenizer)
)
]
)

outputs = await asyncio.to_thread(
self._generate_with_adapter_lock,
"", # Empty for no adapter.
self._model.generate, # type: ignore
# Passed as args/kwargs to generate.
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
return_dict_in_generate=True,
output_scores=True,
**self._make_backend_specific_and_remove(model_opts),
**format_kwargs,
)

sequences_to_decode = [
sequence[inputs["input_ids"][i].size(0) :] # type: ignore
for i, sequence in enumerate(outputs.sequences)
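For context, a condensed, self-contained sketch of the constrained-decoding path above, assuming a small Pydantic schema and an arbitrary causal LM checkpoint (the model id below is only an example); the outlines imports mirror the ones used in this file:

import json

from outlines_core.fsm.json_schema import build_regex_from_schema
from outlines.models.transformers import TransformerTokenizer
from outlines.processors.structured import RegexLogitsProcessor
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

class Answer(BaseModel):
    answer: str
    confidence: float

model_id = "ibm-granite/granite-3.3-2b-instruct"  # example checkpoint only
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Turn the Pydantic schema into a regex that only matches conforming JSON.
schema_json = json.dumps(Answer.model_json_schema())
regex_str = build_regex_from_schema(schema_json)

inputs = tokenizer("Answer in JSON: is the sky blue?", return_tensors="pt")
outputs = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=128,
    # Mask out tokens that would break the regex, so the output parses as Answer.
    logits_processor=LogitsProcessorList(
        [RegexLogitsProcessor(regex_str, tokenizer=TransformerTokenizer(tokenizer))]
    ),
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))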
@@ -853,7 +879,7 @@ def add_adapter(self, adapter: LocalHFAdapter):
self._added_adapters[adapter.qualified_name] = adapter

def load_adapter(self, adapter_qualified_name: str):
"""Loads the given adapter for the backend. Must have previously been added."""
"""Loads the given adapter for the backend. Must have previously been added. Do not call when generation requests are happening."""
adapter = self._added_adapters.get(adapter_qualified_name, None)
if adapter is None:
raise ValueError(
@@ -880,7 +906,7 @@ def load_adapter(self, adapter_qualified_name: str):
# Loading an adapter activates it. We disable adapters immediately after.
# Prefer this over `.disable_adapters()`; the disable function doesn't always
# seem to work.
self._model.set_adapter([])
self._model.disable_adapters()
self._loaded_adapters[adapter.qualified_name] = adapter

def unload_adapter(self, adapter_qualified_name: str):
Expand All @@ -906,6 +932,38 @@ def list_adapters(self) -> list[str]:
return list(self._loaded_adapters.keys())


def _assert_correct_adapters(expected_state: str, model: PreTrainedModel):
"""When generating with a huggingface model, this can be used to ensure the correct adapters are active.

Args:
expected_state: the qualified name of the adapter expected to be active, or "" if no adapter should be active
model: the model underlying the LocalHFBackend; this is the model the adapters are activated on
"""
try:
active = model.active_adapters()

if expected_state == "":
assert len(active) == 0, (
f'no adapters should be active if expected state is "", got "{active[0]}"'
)
else:
assert len(active) == 1, (
f'one adapter should be active if expected state is "{expected_state}"'
)
assert active[0] == expected_state, (
f'the active adapter "{active[0]}" doesn\'t match the expected state: "{expected_state}"'
)
except ValueError as e:
# If no weights have been loaded, the model will raise a ValueError:
# `ValueError("No adapter loaded. Please load an adapter first.")`
if "No adapter loaded" in str(e):
assert expected_state == "", (
f'got no adapters loaded but expected state is "{expected_state}"'
)
else:
raise e


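A small, hypothetical illustration of the assertion helper; `model` stands in for the backend's underlying PreTrainedModel with a PEFT adapter named "demo" already loaded:

def _sketch_adapter_checks(model):
    model.set_adapter("demo")
    _assert_correct_adapters("demo", model)  # passes: exactly the expected adapter is active
    model.set_adapter([])
    _assert_correct_adapters("", model)      # passes: no adapter is active
    model.set_adapter("demo")
    _assert_correct_adapters("", model)      # raises AssertionError: an adapter is unexpectedly active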
class HFProcessRewardModel(PRM, abc.ABC):
"""A Process Reward Model that works with a huggingface backend."""
