diff --git a/examples/pipelines/providers/anthropic_manifold_pipeline.py b/examples/pipelines/providers/anthropic_manifold_pipeline.py
index 7ba9395..57b4745 100644
--- a/examples/pipelines/providers/anthropic_manifold_pipeline.py
+++ b/examples/pipelines/providers/anthropic_manifold_pipeline.py
@@ -17,6 +17,8 @@
 from pydantic import BaseModel
 import requests
 
+from utils.pipelines.main import pop_system_message
+
 
 class Pipeline:
     class Valves(BaseModel):
@@ -80,6 +82,8 @@ def pipe(
     def stream_response(
         self, model_id: str, messages: List[dict], body: dict
     ) -> Generator:
+        system_message, messages = pop_system_message(messages)
+
         max_tokens = (
             body.get("max_tokens") if body.get("max_tokens") is not None else 4096
         )
@@ -91,14 +95,19 @@ def stream_response(
         stop_sequences = body.get("stop") if body.get("stop") is not None else []
 
         stream = self.client.messages.create(
-            model=model_id,
-            messages=messages,
-            max_tokens=max_tokens,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            stop_sequences=stop_sequences,
-            stream=True,
+            **{
+                "model": model_id,
+                **(
+                    {"system": system_message["content"]} if system_message else {}
+                ),  # Add the system message if it exists (optional)
+                "messages": messages,
+                "max_tokens": max_tokens,
+                "temperature": temperature,
+                "top_k": top_k,
+                "top_p": top_p,
+                "stop_sequences": stop_sequences,
+                "stream": True,
+            }
         )
 
         for chunk in stream:
@@ -108,6 +117,8 @@
                 yield chunk.delta.text
 
     def get_completion(self, model_id: str, messages: List[dict], body: dict) -> str:
+        system_message, messages = pop_system_message(messages)
+
         max_tokens = (
             body.get("max_tokens") if body.get("max_tokens") is not None else 4096
         )
@@ -119,12 +130,17 @@ def get_completion(self, model_id: str, messages: List[dict], body: dict) -> str
         stop_sequences = body.get("stop") if body.get("stop") is not None else []
 
         response = self.client.messages.create(
-            model=model_id,
-            messages=messages,
-            max_tokens=max_tokens,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            stop_sequences=stop_sequences,
+            **{
+                "model": model_id,
+                **(
+                    {"system": system_message["content"]} if system_message else {}
+                ),  # Add the system message if it exists (optional)
+                "messages": messages,
+                "max_tokens": max_tokens,
+                "temperature": temperature,
+                "top_k": top_k,
+                "top_p": top_p,
+                "stop_sequences": stop_sequences,
+            }
         )
         return response.content[0].text
diff --git a/utils/pipelines/main.py b/utils/pipelines/main.py
index f5830b9..2d064d9 100644
--- a/utils/pipelines/main.py
+++ b/utils/pipelines/main.py
@@ -5,7 +5,7 @@
 from schemas import OpenAIChatMessage
 import inspect
 
-from typing import get_type_hints, Literal
+from typing import get_type_hints, Literal, Optional, Tuple
 
 
 def stream_message_template(model: str, message: str):
@@ -47,6 +47,21 @@ def get_last_assistant_message(messages: List[dict]) -> str:
     return None
 
 
+def get_system_message(messages: List[dict]) -> Optional[dict]:
+    for message in messages:
+        if message["role"] == "system":
+            return message
+    return None
+
+
+def remove_system_message(messages: List[dict]) -> List[dict]:
+    return [message for message in messages if message["role"] != "system"]
+
+
+def pop_system_message(messages: List[dict]) -> Tuple[Optional[dict], List[dict]]:
+    return get_system_message(messages), remove_system_message(messages)
+
+
 def add_or_update_system_message(content: str, messages: List[dict]):
     """
     Adds a new system message at the beginning of the messages list
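
Note (not part of the patch): a minimal, self-contained sketch of how the new `pop_system_message` helper behaves and why the patch builds the request kwargs with conditional `**{...}` unpacking. It inlines the three helpers added to utils/pipelines/main.py; the sample messages and model id are illustrative placeholders, not values from this repo.

```python
from typing import List, Optional, Tuple


def get_system_message(messages: List[dict]) -> Optional[dict]:
    # Return the first system message, or None if the list has none.
    for message in messages:
        if message["role"] == "system":
            return message
    return None


def remove_system_message(messages: List[dict]) -> List[dict]:
    # Keep every message except system ones.
    return [message for message in messages if message["role"] != "system"]


def pop_system_message(messages: List[dict]) -> Tuple[Optional[dict], List[dict]]:
    # Split a chat history into (system message, remaining messages).
    return get_system_message(messages), remove_system_message(messages)


messages = [
    {"role": "system", "content": "You are a terse assistant."},  # placeholder
    {"role": "user", "content": "Hello!"},
]

system_message, messages = pop_system_message(messages)
assert system_message == {"role": "system", "content": "You are a terse assistant."}
assert messages == [{"role": "user", "content": "Hello!"}]

# The conditional dict unpacking from the patch: the "system" key is added to
# the request kwargs only when a system message was actually found, so the
# Anthropic client is never handed system=None.
kwargs = {
    "model": "claude-3-haiku-20240307",  # placeholder model id
    **({"system": system_message["content"]} if system_message else {}),
    "messages": messages,
}
assert "system" in kwargs  # present here because the input had a system message
```

This split matters because Anthropic's Messages API takes the system prompt as a top-level `system` parameter rather than as a `{"role": "system"}` entry in `messages`, which is why the helper strips it from the message list before the request is assembled.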