Skip to content

Commit

Permalink
fix: anthropic system
Browse files Browse the repository at this point in the history
  • Loading branch information
tjbck committed Jun 17, 2024
1 parent a6daafe commit 0f610ca
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 16 deletions.
46 changes: 31 additions & 15 deletions examples/pipelines/providers/anthropic_manifold_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from pydantic import BaseModel
import requests

from utils.pipelines.main import pop_system_message


class Pipeline:
class Valves(BaseModel):
Expand Down Expand Up @@ -80,6 +82,8 @@ def pipe(
def stream_response(
self, model_id: str, messages: List[dict], body: dict
) -> Generator:
system_message, messages = pop_system_message(messages)

max_tokens = (
body.get("max_tokens") if body.get("max_tokens") is not None else 4096
)
Expand All @@ -91,14 +95,19 @@ def stream_response(
stop_sequences = body.get("stop") if body.get("stop") is not None else []

stream = self.client.messages.create(
model=model_id,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
stop_sequences=stop_sequences,
stream=True,
**{
"model": model_id,
**(
{"system": system_message} if system_message else {}
), # Add system message if it exists (optional)
"messages": messages,
"max_tokens": max_tokens,
"temperature": temperature,
"top_k": top_k,
"top_p": top_p,
"stop_sequences": stop_sequences,
"stream": True,
}
)

for chunk in stream:
Expand All @@ -108,6 +117,8 @@ def stream_response(
yield chunk.delta.text

def get_completion(self, model_id: str, messages: List[dict], body: dict) -> str:
system_message, messages = pop_system_message(messages)

max_tokens = (
body.get("max_tokens") if body.get("max_tokens") is not None else 4096
)
Expand All @@ -119,12 +130,17 @@ def get_completion(self, model_id: str, messages: List[dict], body: dict) -> str
stop_sequences = body.get("stop") if body.get("stop") is not None else []

response = self.client.messages.create(
model=model_id,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
stop_sequences=stop_sequences,
**{
"model": model_id,
**(
{"system": system_message} if system_message else {}
), # Add system message if it exists (optional)
"messages": messages,
"max_tokens": max_tokens,
"temperature": temperature,
"top_k": top_k,
"top_p": top_p,
"stop_sequences": stop_sequences,
}
)
return response.content[0].text
17 changes: 16 additions & 1 deletion utils/pipelines/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from schemas import OpenAIChatMessage

import inspect
from typing import get_type_hints, Literal
from typing import get_type_hints, Literal, Tuple


def stream_message_template(model: str, message: str):
Expand Down Expand Up @@ -47,6 +47,21 @@ def get_last_assistant_message(messages: List[dict]) -> str:
return None


def get_system_message(messages: List[dict]) -> dict:
    """Return the first message with role "system", or None when absent."""
    return next(
        (message for message in messages if message["role"] == "system"),
        None,
    )


def remove_system_message(messages: List[dict]) -> List[dict]:
    """Return a new list containing every message except system-role ones."""
    kept = []
    for message in messages:
        if message["role"] != "system":
            kept.append(message)
    return kept


def pop_system_message(messages: List[dict]) -> Tuple[dict, List[dict]]:
    """Split *messages* into (first system message or None, non-system messages).

    The input list is not mutated; the remaining messages keep their order.
    """
    system_message = None
    for message in messages:
        if message["role"] == "system":
            system_message = message
            break
    remaining = [message for message in messages if message["role"] != "system"]
    return system_message, remaining


def add_or_update_system_message(content: str, messages: List[dict]):
"""
Adds a new system message at the beginning of the messages list
Expand Down

0 comments on commit 0f610ca

Please sign in to comment.