87 changes: 73 additions & 14 deletions mellea/backends/aloras/huggingface/granite_aloras.py
@@ -10,18 +10,36 @@


class HFConstraintAlora(HFAlora):
"""The [Requirement Checking ALora for Granite 3.2 8B](https://huggingface.co/ibm-granite/granite-3.2-8b-alora-requirement-check) checks if the specified requirement was satisfied by the most recent model generation. Only one requirement is checked at a time."""
"""The Requirement Checking ALora for Granite checks if the specified requirement was satisfied by the most recent model generation. Only one requirement is checked at a time.

Currently supports [Granite 3.2 8B](https://huggingface.co/ibm-granite/granite-3.2-8b-alora-requirement-check) and [Granite 3.3 8B](https://huggingface.co/ibm-granite/granite-3.3-8b-alora-requirement-check) by default.
"""

def __init__(
self,
name: str,
path_or_model_id: str,
generation_prompt: str,
backend: LocalHFBackend,
*,
constraint_prompt: str | None = None,
include_constraint_in_alora_offset: bool = False,
):
"""Initialize after checking that the backend is correct."""
assert backend._hf_model_id == "ibm-granite/granite-3.2-8b-instruct"
"""Initialize after checking that the backend is correct.

Args:
constraint_prompt: a template that the constraint can be interpolated into; can only have a single `{}` slot.
include_constraint_in_alora_offset: whether to include the constraint prompt in the alora offset
"""
super().__init__(name, path_or_model_id, generation_prompt, backend)

# Maintain default behavior.
if constraint_prompt is None:
constraint_prompt = "\nRequirement: {}<|end_of_text|>\n"

self._constraint_prompt = constraint_prompt
self._include_constraint_in_alora_offset = include_constraint_in_alora_offset

# We do a lot of logging for ALoras because this is an experimental feature. Maybe we should tag these log messages?
self._logger = FancyLogger.get_logger()

@@ -51,8 +69,10 @@ def _generate_using_cache(
self, cache_hit: HFAloraCacheInfo, constraint: str, force_yn: bool
) -> str:
assert self._backend.alora_model is not None

# Must tokenize the constraint here since the requirement isn't known at initialization.
constraint_tokens = self._backend._tokenizer(
f"\nRequirement: {constraint}<|end_of_text|>\n", return_tensors="pt"
self._constraint_prompt.format(constraint), return_tensors="pt"
).to(self._backend._device)

input_combined = {
@@ -74,7 +94,14 @@
),
}

alora_offsets = [self._generation_prompt_tokens["input_ids"].shape[1] - 1]
if not self._include_constraint_in_alora_offset:
alora_offsets = [self._generation_prompt_tokens["input_ids"].shape[1] - 1]
else:
alora_offsets = [
constraint_tokens["input_ids"].shape[1]
+ self._generation_prompt_tokens["input_ids"].shape[1]
- 2
]
self._logger.debug(
f"Prompt for cached aLoRA({self.name}):\n {self._backend._tokenizer.decode(input_combined['input_ids'][0])}"
)
@@ -136,7 +163,9 @@ def _generate_not_using_cache(

templatized = self._backend._tokenizer.apply_chat_template(chat, tokenize=False)
assert type(templatized) is str
templatized = templatized + f"\nRequirement: {constraint}<|end_of_text|>\n"

# Must tokenize the constraint here since the requirement isn't known at initialization.
templatized = templatized + self._constraint_prompt.format(constraint)

tokenized = self._backend._tokenizer(templatized, return_tensors="pt").to(
self._backend._device
@@ -156,7 +185,19 @@
),
}

alora_offsets = [self._generation_prompt_tokens["input_ids"].shape[1] - 1]
if not self._include_constraint_in_alora_offset:
alora_offsets = [self._generation_prompt_tokens["input_ids"].shape[1] - 1]
else:
# Get the constraint tokens separately so that we can calculate the alora offsets.
constraint_tokens = self._backend._tokenizer(
self._constraint_prompt.format(constraint), return_tensors="pt"
).to(self._backend._device)

alora_offsets = [
constraint_tokens["input_ids"].shape[1]
+ self._generation_prompt_tokens["input_ids"].shape[1]
- 2
]

self._logger.debug(
f"Prompt for non-cached aLoRA({self.name}):\n{self._backend._tokenizer.decode(input_combined['input_ids'][0])}"
@@ -200,11 +241,29 @@ def _generate_not_using_cache(

def add_granite_aloras(backend: LocalHFBackend):
"""Adds the IBM Granite "starter pack" ALoras to a backend."""
backend.add_alora(
HFConstraintAlora(
name="constraint",
path_or_model_id="ibm-granite/granite-3.2-8b-alora-requirement-check",
generation_prompt="<|start_of_role|>check_requirement<|end_of_role|>",
backend=backend,
if backend._hf_model_id == "ibm-granite/granite-3.2-8b-instruct":
backend.add_alora(
HFConstraintAlora(
name="constraint",
path_or_model_id="ibm-granite/granite-3.2-8b-alora-requirement-check",
generation_prompt="<|start_of_role|>check_requirement<|end_of_role|>",
backend=backend,
constraint_prompt="\nRequirement: {}<|end_of_text|>\n",
include_constraint_in_alora_offset=False,
)
)
elif backend._hf_model_id == "ibm-granite/granite-3.3-8b-instruct":
backend.add_alora(
HFConstraintAlora(
name="constraint",
path_or_model_id="ibm-granite/granite-3.3-8b-alora-requirement-check",
generation_prompt="<|start_of_role|>check_requirement<|end_of_role|>",
backend=backend,
constraint_prompt="\n<|start_of_role|>requirement<|end_of_role|>{}<|end_of_text|>\n",
include_constraint_in_alora_offset=True,
)
)
else:
raise ValueError(
f"cannot add_granite_aloras to unknown huggingface model_id / backend: {backend._hf_model_id}"
)
)
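With the routing above, the Granite 3.2 and 3.3 requirement-check aLoRAs each get their own constraint prompt and offset behavior, and unknown model ids raise a `ValueError`. A minimal usage sketch, under the assumption that `LocalHFBackend` takes the Hugging Face model id in its constructor (its import path and signature are not part of this diff):

```python
# Import paths and constructor signature are assumptions, not shown in this diff.
from mellea.backends.huggingface import LocalHFBackend
from mellea.backends.aloras.huggingface.granite_aloras import add_granite_aloras

# Hypothetical construction; the real LocalHFBackend signature may differ.
backend = LocalHFBackend(model_id="ibm-granite/granite-3.3-8b-instruct")

# Registers the 3.3 requirement-check aLoRA with the
# <|start_of_role|>requirement<|end_of_role|> constraint prompt and with the
# constraint included in the aLoRA offset.
add_granite_aloras(backend)

constraint_alora = backend.get_aloras()[0]
```

The `generate_using_strings` call exercised in the Hugging Face test below then runs the requirement check against the most recent generation.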
3 changes: 2 additions & 1 deletion test/backends/test_huggingface.py
@@ -33,7 +33,8 @@ def test_system_prompt(self):
def test_constraint_alora(self):
self.m.reset()
answer = self.m.instruct(
"Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa"
"Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa. Be concise and don't write code to answer the question.",
model_options={ModelOption.MAX_NEW_TOKENS: 300}, # Until aloras get a bit better, try not to abruptly end generation.
)
alora_output = self.backend.get_aloras()[0].generate_using_strings(
input="Find the difference between these two strings: aaaaaaaaaa aaaaabaaaa",
4 changes: 2 additions & 2 deletions test/backends/test_ollama.py
@@ -5,6 +5,7 @@
import json
from typing_extensions import Annotated
from mellea.backends.types import ModelOption
import pytest


class Test_SmokeTestComponents:
@@ -87,6 +88,7 @@ def test_generate_from_raw(self):

assert len(results) == len(prompts)

@pytest.mark.xfail(reason="ollama sometimes fails generated structured outputs")
def test_generate_from_raw_with_format(self):
prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"]

@@ -112,6 +114,4 @@ class Answer(pydantic.BaseModel):


if __name__ == "__main__":
import pytest

pytest.main([__file__])
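Moving the `pytest` import to module level lets the flaky structured-output test carry a `pytest.mark.xfail` marker, so it still runs but an occasional Ollama formatting failure no longer breaks the suite. A standalone sketch of the pattern (the test body is a stand-in, not the real Ollama test):

```python
import json
import pytest

@pytest.mark.xfail(reason="backend sometimes returns malformed JSON")
def test_structured_output_parses():
    # Stand-in response; the real test asks an Ollama model for JSON output.
    response = '{"answer": 2}'
    assert json.loads(response)["answer"] == 2
```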
5 changes: 2 additions & 3 deletions test/backends/test_types.py
@@ -18,7 +18,7 @@ def test_model_option_remove():
), "dict with removed special keys did not match expected"


def test_model_option_replace_to_common_opts(capfd):
def test_model_option_replace_to_common_opts(caplog):
model_opts = {
ModelOption.CONTEXT_WINDOW: 3,
ModelOption.TEMPERATURE: 1,
@@ -41,8 +41,7 @@
), "dict with replaced keys did not match expected"

# There should also be a logged message due to context_window key clashes.
out, _ = capfd.readouterr()
assert "old_key (context_size) to new_key (@@@context_window@@@): lost value associated with old_key (4) and kept original value of new_key (3)" in out, "expected log for conflicting keys not found"
assert "old_key (context_size) to new_key (@@@context_window@@@): lost value associated with old_key (4) and kept original value of new_key (3)" in caplog.text, f"expected log for conflicting keys not found in: {caplog.text}"


def test_model_option_replace_to_backend_specific():
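The `capfd` → `caplog` switch makes the assertion read the logger's records directly instead of whatever happens to reach stdout/stderr, which fits a `FancyLogger` that presumably routes through Python's `logging` module. A minimal standalone sketch of the pattern (function and logger names are hypothetical stand-ins):

```python
import logging

def merge_keys():
    # Stand-in for the code under test, which warns on conflicting keys.
    logging.getLogger("mellea").warning("lost value associated with old_key")

def test_conflict_is_logged(caplog):
    with caplog.at_level(logging.WARNING):
        merge_keys()
    # caplog.text aggregates all captured log records into one string.
    assert "lost value associated with old_key" in caplog.text
```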