-
Notifications
You must be signed in to change notification settings - Fork 160
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f96af2e
commit 51cf7d7
Showing
7 changed files
with
196 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
94 changes: 94 additions & 0 deletions
94
griptape/drivers/prompt_model/bedrock_mistral_prompt_model_driver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
from __future__ import annotations | ||
import json | ||
import itertools as it | ||
from typing import Optional | ||
from attr import define, field | ||
from griptape.artifacts import TextArtifact | ||
from griptape.utils import PromptStack | ||
from griptape.drivers import BasePromptModelDriver | ||
from griptape.tokenizers import BedrockMistralTokenizer | ||
from griptape.drivers import AmazonBedrockPromptDriver | ||
from griptape.utils import J2 | ||
|
||
|
||
@define
class BedrockMistralPromptModelDriver(BasePromptModelDriver):
    """Prompt model driver for Mistral instruct models hosted on Amazon Bedrock.

    Translates a Griptape `PromptStack` into the Mistral instruction format and
    builds the request body expected by the Bedrock `InvokeModel` API.
    """

    # Mistral's beginning-of-sequence / end-of-sequence markers.
    BOS_TOKEN = "<s>"
    EOS_TOKEN = "</s>"

    top_p: Optional[float] = field(default=None, kw_only=True)
    top_k: Optional[float] = field(default=None, kw_only=True)
    _tokenizer: BedrockMistralTokenizer = field(default=None, kw_only=True)
    prompt_driver: Optional[AmazonBedrockPromptDriver] = field(default=None, kw_only=True)

    @property
    def tokenizer(self) -> BedrockMistralTokenizer:
        """Returns the tokenizer for this driver.

        We need to pass the `model` field from the Prompt Driver to the
        Tokenizer. However, the Prompt Driver is not initialized until after
        the Prompt Model Driver is initialized. To resolve this, we make the
        `tokenizer` field a @property that is only initialized when it is first
        accessed. This ensures that by the time we need to initialize the
        Tokenizer, the Prompt Driver has already been initialized.

        See this thread for more information: https://github.com/griptape-ai/griptape/issues/244

        Returns:
            BedrockMistralTokenizer: The tokenizer for this driver.
        """
        if self._tokenizer is None:
            self._tokenizer = BedrockMistralTokenizer(model=self.prompt_driver.model)
        return self._tokenizer

    def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> str:
        """Converts a `PromptStack` to a string that can be used as the input to the model.

        Prompt structure adapted from
        https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format

        Args:
            prompt_stack: The `PromptStack` to convert.

        Returns:
            The Mistral-formatted prompt string.
        """
        system_input = next((i for i in prompt_stack.inputs if i.is_system()), None)
        non_system_inputs = [i for i in prompt_stack.inputs if not i.is_system()]

        # Mistral has no dedicated system role, so the system prompt is folded
        # into the first non-system turn. Guard against a stack containing only
        # a system input, which would otherwise raise an IndexError.
        if system_input is not None:
            if non_system_inputs:
                non_system_inputs[0].content = f"{system_input.content} {non_system_inputs[0].content}"
            else:
                return f"{self.BOS_TOKEN}[INST] {system_input.content} [/INST]"

        prompt_lines = [self.BOS_TOKEN]
        for prompt_input in non_system_inputs:
            if prompt_input.is_assistant():
                prompt_lines.append(f"{prompt_input.content}{self.EOS_TOKEN} ")
            else:
                prompt_lines.append(f"[INST] {prompt_input.content} [/INST]")

        return "".join(prompt_lines)

    def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
        """Builds the Bedrock `InvokeModel` request body for a `PromptStack`.

        Optional sampling parameters (`top_p`, `top_k`) are included only when
        explicitly set; `is not None` is used so a legitimate value of 0 is not
        silently dropped.
        """
        prompt = self.prompt_stack_to_model_input(prompt_stack)

        return {
            "prompt": prompt,
            "stop": self.tokenizer.stop_sequences,
            "temperature": self.prompt_driver.temperature,
            "max_tokens": self.prompt_driver.max_output_tokens(prompt),
            **({"top_p": self.top_p} if self.top_p is not None else {}),
            **({"top_k": self.top_k} if self.top_k is not None else {}),
        }

    def process_output(self, output: list[dict] | str | bytes) -> TextArtifact:
        """Extracts a `TextArtifact` from the raw model response body.

        Args:
            output: The response body; bytes when streaming, otherwise a JSON string.

        Raises:
            Exception: If the output is a list, or contains more than one completion.
        """
        # When streaming, the response body comes back as bytes.
        if isinstance(output, bytes):
            output = output.decode()
        elif isinstance(output, list):
            raise Exception("Invalid output format.")

        body = json.loads(output)
        outputs = body["outputs"]

        if len(outputs) == 1:
            return TextArtifact(outputs[0]["text"])
        else:
            raise Exception("Completion with more than one choice is not supported yet.")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from __future__ import annotations | ||
from attr import define, field | ||
from .simple_tokenizer import SimpleTokenizer | ||
|
||
|
||
@define()
class BedrockMistralTokenizer(SimpleTokenizer):
    """Tokenizer for Mistral models hosted on Amazon Bedrock.

    Relies on `SimpleTokenizer`'s character-count heuristic rather than a real
    vocabulary, with per-model-prefix input/output token limits.
    """

    # Rough heuristic: roughly 6 characters per token for Mistral models.
    DEFAULT_CHARACTERS_PER_TOKEN = 6
    # Max context window, keyed by model id prefix.
    MODEL_PREFIXES_TO_MAX_INPUT_TOKENS = {"mistral": 32000}
    # Max completion length, keyed by model id prefix.
    MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS = {"mistral.mistral-7b-instruct": 8192, "mistral.mixtral-8x7b-instruct": 4096}

    # Bedrock model id, e.g. "mistral.mistral-7b-instruct-v0:2".
    model: str = field(kw_only=True)
    characters_per_token: int = field(default=DEFAULT_CHARACTERS_PER_TOKEN, kw_only=True)
60 changes: 60 additions & 0 deletions
60
tests/unit/drivers/prompt_models/test_bedrock_mistral_prompt_model_driver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from unittest import mock | ||
import json | ||
import boto3 | ||
import pytest | ||
from griptape.utils import PromptStack | ||
from griptape.drivers import AmazonBedrockPromptDriver, BedrockMistralPromptModelDriver | ||
|
||
|
||
class TestBedrockMistralPromptModelDriver:
    @pytest.fixture(autouse=True)
    def mock_session(self, mocker):
        # Patch boto3.Session so no real AWS calls are made by any test.
        session_cls = mocker.patch("boto3.Session")

        session = mock.Mock()
        client = mock.Mock()
        response = mock.Mock()

        client.invoke_model.return_value = response
        session.client.return_value = client
        session_cls.return_value = session

        return session

    @pytest.fixture
    def driver(self):
        prompt_driver = AmazonBedrockPromptDriver(
            model="mistral.mistral-7b-instruct-v0:2",
            session=boto3.Session(region_name="us-east-1"),
            prompt_model_driver=BedrockMistralPromptModelDriver(top_p=0.9, top_k=250),
            temperature=0.12345,
        )
        return prompt_driver.prompt_model_driver

    @pytest.fixture
    def stack(self):
        prompt_stack = PromptStack()

        prompt_stack.add_system_input("system-input")
        prompt_stack.add_user_input("user-input")
        prompt_stack.add_assistant_input("assistant-input")
        prompt_stack.add_user_input("user-input")

        return prompt_stack

    def test_init(self, driver):
        assert driver.prompt_driver is not None

    def test_prompt_stack_to_model_input(self, driver, stack):
        result = driver.prompt_stack_to_model_input(stack)

        assert isinstance(result, str)
        assert result == "<s>[INST] system-input user-input [/INST]assistant-input</s> [INST] user-input [/INST]"

    def test_prompt_stack_to_model_params(self, driver, stack):
        params = driver.prompt_stack_to_model_params(stack)

        assert params["max_tokens"] == 8177
        assert params["temperature"] == 0.12345
        assert params["top_p"] == 0.9
        assert params["top_k"] == 250

    def test_process_output(self, driver):
        payload = json.dumps({"outputs": [{"text": "foobar", "stop": "reason"}]})

        assert driver.process_output(payload).value == "foobar"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import pytest | ||
from griptape.tokenizers import BedrockMistralTokenizer | ||
|
||
|
||
class TestBedrockMistralTokenizer:
    @pytest.fixture
    def tokenizer(self, request):
        # Indirectly parametrized: `request.param` carries the Bedrock model id.
        return BedrockMistralTokenizer(model=request.param)

    @pytest.mark.parametrize(
        "tokenizer,expected",
        [("mistral.mistral-7b-instruct", 31997), ("mistral.mixtral-8x7b-instruct", 31997)],
        indirect=["tokenizer"],
    )
    def test_input_tokens_left(self, tokenizer, expected):
        # "foo bar huzzah" is 14 chars -> 3 tokens at 6 chars/token; both
        # models share the 32000-token "mistral" input limit.
        assert tokenizer.count_input_tokens_left("foo bar huzzah") == expected

    @pytest.mark.parametrize(
        "tokenizer,expected",
        [("mistral.mistral-7b-instruct", 8189), ("mistral.mixtral-8x7b-instruct", 4093)],
        indirect=["tokenizer"],
    )
    def test_output_tokens_left(self, tokenizer, expected):
        # Output limits differ per model (8192 vs 4096), minus the 3-token prompt.
        assert tokenizer.count_output_tokens_left("foo bar huzzah") == expected