Merge pull request #1141 from newrelic/add-mistral
Add support for mistral bedrock model
hmstepanek committed May 17, 2024
2 parents 3e5be52 + ac6fe6d commit 93ad043
Showing 5 changed files with 462 additions and 10 deletions.
newrelic/hooks/external_botocore.py: 53 additions & 9 deletions
@@ -120,9 +120,11 @@ def create_chat_completion_message_event(
             "request_id": request_id,
             "span_id": span_id,
             "trace_id": trace_id,
-            "token_count": settings.ai_monitoring.llm_token_count_callback(request_model, content)
-            if settings.ai_monitoring.llm_token_count_callback
-            else None,
+            "token_count": (
+                settings.ai_monitoring.llm_token_count_callback(request_model, content)
+                if settings.ai_monitoring.llm_token_count_callback
+                else None
+            ),
             "role": message.get("role"),
             "completion_id": chat_completion_id,
             "sequence": index,
@@ -156,9 +158,11 @@ def create_chat_completion_message_event(
             "request_id": request_id,
             "span_id": span_id,
             "trace_id": trace_id,
-            "token_count": settings.ai_monitoring.llm_token_count_callback(request_model, content)
-            if settings.ai_monitoring.llm_token_count_callback
-            else None,
+            "token_count": (
+                settings.ai_monitoring.llm_token_count_callback(request_model, content)
+                if settings.ai_monitoring.llm_token_count_callback
+                else None
+            ),
             "role": message.get("role"),
             "completion_id": chat_completion_id,
             "sequence": index,
@@ -189,6 +193,14 @@ def extract_bedrock_titan_text_model_request(request_body, bedrock_attrs):
     return bedrock_attrs


+def extract_bedrock_mistral_text_model_request(request_body, bedrock_attrs):
+    request_body = json.loads(request_body)
+    bedrock_attrs["input_message_list"] = [{"role": "user", "content": request_body.get("prompt")}]
+    bedrock_attrs["request.max_tokens"] = request_body.get("max_tokens")
+    bedrock_attrs["request.temperature"] = request_body.get("temperature")
+    return bedrock_attrs
+
+
 def extract_bedrock_titan_text_model_response(response_body, bedrock_attrs):
     if response_body:
         response_body = json.loads(response_body)
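For reference, Mistral text models on Bedrock take a prompt-style request body, which is what the new request extractor parses. A minimal sketch exercising it (the prompt and parameter values are invented):

import json

from newrelic.hooks.external_botocore import extract_bedrock_mistral_text_model_request

request_body = json.dumps(
    {
        "prompt": "<s>[INST] What is the capital of France? [/INST]",
        "max_tokens": 256,
        "temperature": 0.7,
    }
)

attrs = extract_bedrock_mistral_text_model_request(request_body, bedrock_attrs={})
# attrs["input_message_list"] == [{"role": "user", "content": "<s>[INST] ... [/INST]"}]
# attrs["request.max_tokens"] == 256
# attrs["request.temperature"] == 0.7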
@@ -203,6 +215,18 @@ def extract_bedrock_titan_text_model_response(response_body, bedrock_attrs):
     return bedrock_attrs


+def extract_bedrock_mistral_text_model_response(response_body, bedrock_attrs):
+    if response_body:
+        response_body = json.loads(response_body)
+        outputs = response_body.get("outputs")
+        if outputs:
+            bedrock_attrs["response.choices.finish_reason"] = outputs[0]["stop_reason"]
+            bedrock_attrs["output_message_list"] = [
+                {"role": "assistant", "content": result["text"]} for result in outputs
+            ]
+    return bedrock_attrs
+
+
 def extract_bedrock_titan_text_model_streaming_response(response_body, bedrock_attrs):
     if response_body:
         if "outputText" in response_body:
@@ -214,6 +238,18 @@ def extract_bedrock_titan_text_model_streaming_response(response_body, bedrock_attrs):
     return bedrock_attrs


+def extract_bedrock_mistral_text_model_streaming_response(response_body, bedrock_attrs):
+    if response_body:
+        outputs = response_body.get("outputs")
+        if outputs:
+            bedrock_attrs["output_message_list"] = bedrock_attrs.get(
+                "output_message_list", [{"role": "assistant", "content": ""}]
+            )
+            bedrock_attrs["output_message_list"][0]["content"] += outputs[0].get("text", "")
+            bedrock_attrs["response.choices.finish_reason"] = outputs[0].get("stop_reason", None)
+    return bedrock_attrs
+
+
 def extract_bedrock_titan_embedding_model_request(request_body, bedrock_attrs):
     request_body = json.loads(request_body)
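Unlike its non-streaming counterpart, the streaming extractor receives each chunk as an already-decoded dict and is called once per chunk: it seeds a single assistant message on first use, appends each chunk's text, and overwrites the finish reason with the latest chunk's stop_reason. A sketch with invented chunks:

from newrelic.hooks.external_botocore import extract_bedrock_mistral_text_model_streaming_response

# Each element stands in for one already-decoded chunk of a streamed response.
chunks = [
    {"outputs": [{"text": "The capital of France "}]},
    {"outputs": [{"text": "is Paris.", "stop_reason": "stop"}]},
]

attrs = {}
for chunk in chunks:
    attrs = extract_bedrock_mistral_text_model_streaming_response(chunk, attrs)

# attrs["output_message_list"][0]["content"] == "The capital of France is Paris."
# attrs["response.choices.finish_reason"] == "stop"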
@@ -407,6 +443,12 @@ def extract_bedrock_cohere_model_streaming_response(response_body, bedrock_attrs):
         extract_bedrock_llama_model_response,
         extract_bedrock_llama_model_streaming_response,
     ),
+    (
+        "mistral",
+        extract_bedrock_mistral_text_model_request,
+        extract_bedrock_mistral_text_model_response,
+        extract_bedrock_mistral_text_model_streaming_response,
+    ),
 ]
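Registering this triple in MODEL_EXTRACTORS is what wires the new extractors up: the agent chooses an entry by prefix-matching the invoked Bedrock model id against the first element. A hypothetical sketch of that lookup (the dispatch code itself is outside this diff):

model = "mistral.mistral-7b-instruct-v0:2"  # example Bedrock model id
for model_type, request_extractor, response_extractor, stream_extractor in MODEL_EXTRACTORS:
    if model.startswith(model_type):  # "mistral...".startswith("mistral") -> True
        break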


@@ -698,9 +740,11 @@ def handle_embedding_event(transaction, bedrock_attrs):
         "id": embedding_id,
         "span_id": span_id,
         "trace_id": trace_id,
-        "token_count": settings.ai_monitoring.llm_token_count_callback(model, input)
-        if settings.ai_monitoring.llm_token_count_callback
-        else None,
+        "token_count": (
+            settings.ai_monitoring.llm_token_count_callback(model, input)
+            if settings.ai_monitoring.llm_token_count_callback
+            else None
+        ),
         "request_id": request_id,
         "duration": bedrock_attrs.get("duration", None),
         "request.model": model,
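All three token_count hunks in this diff are formatting-only: the conditional expression is wrapped in parentheses without changing behavior, and token_count is still None unless a token-count callback has been registered. A minimal sketch of registering one, assuming the agent's public newrelic.agent.set_llm_token_count_callback API, with a naive whitespace count standing in for a real tokenizer:

import newrelic.agent

def count_tokens(model, content):
    # Stand-in tokenizer: a real callback should return an int token count
    # for the given model and content (e.g. computed with a tokenizer library).
    return len(content.split())

newrelic.agent.set_llm_token_count_callback(count_tokens)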
