Commit

Add Mistral support to Bedrock

collindutter committed Mar 22, 2024
1 parent f96af2e commit 51cf7d7
Showing 7 changed files with 196 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `GoogleStructureConfig` for providing Structures with Google Prompt and Embedding Driver configuration.
- Support for `claude-3-opus`, `claude-3-sonnet`, and `claude-3-haiku` in `AnthropicPromptDriver`.
- Support for `anthropic.claude-3-sonnet-20240229-v1:0` and `anthropic.claude-3-haiku-20240307-v1:0` in `BedrockClaudePromptModelDriver`.
- Support for `mistral.mistral-7b-instruct-v0:2` and `mistral.mixtral-8x7b-instruct-v0:1` in `BedrockMistralPromptModelDriver`.
- `top_k` and `top_p` parameters in `AnthropicPromptDriver`.
- Added `AnthropicImageQueryDriver` for Claude-3 multi-modal models.
- Added `AmazonBedrockImageQueryDriver` along with `BedrockClaudeImageQueryDriverModel` for Claude-3 support in Bedrock.
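For reference, a minimal sketch of how the new driver is wired up, mirroring the test fixture added later in this commit (the region and sampling values are illustrative, not defaults):

import boto3
from griptape.drivers import AmazonBedrockPromptDriver, BedrockMistralPromptModelDriver

# Wire the Mistral model driver into the generic Bedrock prompt driver.
driver = AmazonBedrockPromptDriver(
    model="mistral.mistral-7b-instruct-v0:2",
    session=boto3.Session(region_name="us-east-1"),
    prompt_model_driver=BedrockMistralPromptModelDriver(top_p=0.9, top_k=250),
)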
2 changes: 2 additions & 0 deletions griptape/drivers/__init__.py
@@ -57,6 +57,7 @@
from .prompt_model.bedrock_claude_prompt_model_driver import BedrockClaudePromptModelDriver
from .prompt_model.bedrock_jurassic_prompt_model_driver import BedrockJurassicPromptModelDriver
from .prompt_model.bedrock_llama_prompt_model_driver import BedrockLlamaPromptModelDriver
from .prompt_model.bedrock_mistral_prompt_model_driver import BedrockMistralPromptModelDriver

from .image_generation_model.base_image_generation_model_driver import BaseImageGenerationModelDriver
from .image_generation_model.bedrock_stable_diffusion_image_generation_model_driver import (
@@ -140,6 +141,7 @@
"BedrockClaudePromptModelDriver",
"BedrockJurassicPromptModelDriver",
"BedrockLlamaPromptModelDriver",
"BedrockMistralPromptModelDriver",
"BaseImageGenerationModelDriver",
"BedrockStableDiffusionImageGenerationModelDriver",
"BedrockTitanImageGenerationModelDriver",
94 changes: 94 additions & 0 deletions griptape/drivers/prompt_model/bedrock_mistral_prompt_model_driver.py
@@ -0,0 +1,94 @@
from __future__ import annotations
import json
from typing import Optional
from attr import define, field
from griptape.artifacts import TextArtifact
from griptape.utils import PromptStack
from griptape.drivers import BasePromptModelDriver
from griptape.tokenizers import BedrockMistralTokenizer
from griptape.drivers import AmazonBedrockPromptDriver
from griptape.utils import J2


@define
class BedrockMistralPromptModelDriver(BasePromptModelDriver):
BOS_TOKEN = "<s>"
EOS_TOKEN = "</s>"

top_p: Optional[float] = field(default=None, kw_only=True)
top_k: Optional[int] = field(default=None, kw_only=True)
_tokenizer: Optional[BedrockMistralTokenizer] = field(default=None, kw_only=True)
prompt_driver: Optional[AmazonBedrockPromptDriver] = field(default=None, kw_only=True)

@property
def tokenizer(self) -> BedrockMistralTokenizer:
"""Returns the tokenizer for this driver.
We need to pass the `session` field from the Prompt Driver to the
Tokenizer. However, the Prompt Driver is not initialized until after
the Prompt Model Driver is initialized. To resolve this, we make the `tokenizer`
field a @property that is only initialized when it is first accessed.
This ensures that by the time we need to initialize the Tokenizer, the
Prompt Driver has already been initialized.
See this thread more more information: https://github.com/griptape-ai/griptape/issues/244
Returns:
BedrockLlamaTokenizer: The tokenizer for this driver.
"""
if self._tokenizer:
return self._tokenizer
else:
self._tokenizer = BedrockMistralTokenizer(model=self.prompt_driver.model)
return self._tokenizer

def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> str:
"""
Converts a `PromptStack` to a string that can be used as the input to the model.
Prompt structure adapted from https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format
Args:
prompt_stack: The `PromptStack` to convert.
"""
system_input = next((i for i in prompt_stack.inputs if i.is_system()), None)
non_system_inputs = [i for i in prompt_stack.inputs if not i.is_system()]
if system_input is not None:
non_system_inputs[0].content = f"{system_input.content} {non_system_inputs[0].content}"

prompt_lines = [self.BOS_TOKEN]
for prompt_input in non_system_inputs:
if prompt_input.is_assistant():
prompt_lines.append(f"{prompt_input.content}{self.EOS_TOKEN} ")
else:
prompt_lines.append(f"[INST] {prompt_input.content} [/INST]")

return "".join(prompt_lines)

def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
prompt = self.prompt_stack_to_model_input(prompt_stack)

return {
"prompt": prompt,
"stop": self.tokenizer.stop_sequences,
"temperature": self.prompt_driver.temperature,
"max_tokens": self.prompt_driver.max_output_tokens(prompt),
**({"top_p": self.top_p} if self.top_p else {}),
**({"top_k": self.top_k} if self.top_k else {}),
}
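# Example (illustrative): for the prompt stack used in the unit tests below,
# this returns a request body shaped like
#   {"prompt": "<s>[INST] ...", "stop": [...], "temperature": 0.12345,
#    "max_tokens": 8177, "top_p": 0.9, "top_k": 250}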

def process_output(self, output: list[dict] | str | bytes) -> TextArtifact:
# When streaming, the response body comes back as bytes.
if isinstance(output, bytes):
output = output.decode()
elif isinstance(output, list):
raise Exception("Invalid output format.")

body = json.loads(output)
outputs = body["outputs"]

if len(outputs) == 1:
return TextArtifact(outputs[0]["text"])
else:
raise Exception("Completion with more than one choice is not supported yet.")
2 changes: 2 additions & 0 deletions griptape/tokenizers/__init__.py
@@ -10,6 +10,7 @@
from griptape.tokenizers.bedrock_llama_tokenizer import BedrockLlamaTokenizer
from griptape.tokenizers.google_tokenizer import GoogleTokenizer
from griptape.tokenizers.voyageai_tokenizer import VoyageAiTokenizer
from griptape.tokenizers.bedrock_mistral_tokenizer import BedrockMistralTokenizer
from griptape.tokenizers.simple_tokenizer import SimpleTokenizer
from griptape.tokenizers.dummy_tokenizer import DummyTokenizer

@@ -27,6 +28,7 @@
"BedrockLlamaTokenizer",
"GoogleTokenizer",
"VoyageAiTokenizer",
"BedrockMistralTokenizer",
"SimpleTokenizer",
"DummyTokenizer",
]
13 changes: 13 additions & 0 deletions griptape/tokenizers/bedrock_mistral_tokenizer.py
@@ -0,0 +1,13 @@
from __future__ import annotations
from attr import define, field
from .simple_tokenizer import SimpleTokenizer


@define
class BedrockMistralTokenizer(SimpleTokenizer):
DEFAULT_CHARACTERS_PER_TOKEN = 6
MODEL_PREFIXES_TO_MAX_INPUT_TOKENS = {"mistral": 32000}
MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS = {"mistral.mistral-7b-instruct": 8192, "mistral.mixtral-8x7b-instruct": 4096}

model: str = field(kw_only=True)
characters_per_token: int = field(default=DEFAULT_CHARACTERS_PER_TOKEN, kw_only=True)
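The `SimpleTokenizer` base class estimates token counts from character length, so the limits above combine with the 6-characters-per-token default as in this sketch (assuming the character estimate rounds up, which is what the tests below expect):

import math

# "foo bar huzzah" is 14 characters -> ceil(14 / 6) = 3 estimated tokens.
tokens = math.ceil(len("foo bar huzzah") / 6)
print(32000 - tokens)  # 31997 input tokens left for mistral.* models
print(8192 - tokens)   # 8189 output tokens left for mistral-7b-instruct
print(4096 - tokens)   # 4093 output tokens left for mixtral-8x7b-instruct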
@@ -0,0 +1,60 @@
from unittest import mock
import json
import boto3
import pytest
from griptape.utils import PromptStack
from griptape.drivers import AmazonBedrockPromptDriver, BedrockMistralPromptModelDriver


class TestBedrockMistralPromptModelDriver:
@pytest.fixture(autouse=True)
def mock_session(self, mocker):
mock_session_class = mocker.patch("boto3.Session")

mock_session_object = mock.Mock()
mock_client = mock.Mock()
mock_response = mock.Mock()

mock_client.invoke_model.return_value = mock_response
mock_session_object.client.return_value = mock_client
mock_session_class.return_value = mock_session_object

return mock_session_object

@pytest.fixture
def driver(self):
return AmazonBedrockPromptDriver(
model="mistral.mistral-7b-instruct-v0:2",
session=boto3.Session(region_name="us-east-1"),
prompt_model_driver=BedrockMistralPromptModelDriver(top_p=0.9, top_k=250),
temperature=0.12345,
).prompt_model_driver

@pytest.fixture
def stack(self):
stack = PromptStack()

stack.add_system_input("system-input")
stack.add_user_input("user-input")
stack.add_assistant_input("assistant-input")
stack.add_user_input("user-input")

return stack

def test_init(self, driver):
assert driver.prompt_driver is not None

def test_prompt_stack_to_model_input(self, driver, stack):
model_input = driver.prompt_stack_to_model_input(stack)

assert isinstance(model_input, str)
assert model_input == "<s>[INST] system-input user-input [/INST]assistant-input</s> [INST] user-input [/INST]"

def test_prompt_stack_to_model_params(self, driver, stack):
assert driver.prompt_stack_to_model_params(stack)["max_tokens"] == 8177
assert driver.prompt_stack_to_model_params(stack)["temperature"] == 0.12345
assert driver.prompt_stack_to_model_params(stack)["top_p"] == 0.9
assert driver.prompt_stack_to_model_params(stack)["top_k"] == 250

def test_process_output(self, driver):
assert driver.process_output(json.dumps({"outputs": [{"text": "foobar", "stop": "reason"}]})).value == "foobar"
24 changes: 24 additions & 0 deletions tests/unit/tokenizers/test_bedrock_mistral_tokenizer.py
@@ -0,0 +1,24 @@
import pytest
from griptape.tokenizers import BedrockMistralTokenizer


class TestBedrockMistralTokenizer:
@pytest.fixture
def tokenizer(self, request):
return BedrockMistralTokenizer(model=request.param)

@pytest.mark.parametrize(
"tokenizer,expected",
[("mistral.mistral-7b-instruct", 31997), ("mistral.mixtral-8x7b-instruct", 31997)],
indirect=["tokenizer"],
)
def test_input_tokens_left(self, tokenizer, expected):
assert tokenizer.count_input_tokens_left("foo bar huzzah") == expected

@pytest.mark.parametrize(
"tokenizer,expected",
[("mistral.mistral-7b-instruct", 8189), ("mistral.mixtral-8x7b-instruct", 4093)],
indirect=["tokenizer"],
)
def test_output_tokens_left(self, tokenizer, expected):
assert tokenizer.count_output_tokens_left("foo bar huzzah") == expected
