diff --git a/examples/pipelines/providers/groq_manifold_pipeline.py b/examples/pipelines/providers/groq_manifold_pipeline.py new file mode 100755 index 0000000..717f738 --- /dev/null +++ b/examples/pipelines/providers/groq_manifold_pipeline.py @@ -0,0 +1,122 @@ +from typing import List, Union, Generator, Iterator +from schemas import OpenAIChatMessage +from pydantic import BaseModel + +import os +import requests + + +class Pipeline: + class Valves(BaseModel): + GROQ_API_BASE_URL: str = "https://api.groq.com/openai/v1" + GROQ_API_KEY: str = "" + pass + + def __init__(self): + self.type = "manifold" + # Optionally, you can set the id and name of the pipeline. + # Best practice is to not specify the id so that it can be automatically inferred from the filename, so that users can install multiple versions of the same pipeline. + # The identifier must be unique across all pipelines. + # The identifier must be an alphanumeric string that can include underscores or hyphens. It cannot contain spaces, special characters, slashes, or backslashes. + self.id = "groq" + self.name = "Groq: " + + self.valves = self.Valves( + **{ + "GROQ_API_KEY": os.getenv( + "GROQ_API_KEY", "your-groq-api-key-here" + ) + } + ) + + self.pipelines = self.get_models() + pass + + async def on_startup(self): + # This function is called when the server is started. + print(f"on_startup:{__name__}") + pass + + async def on_shutdown(self): + # This function is called when the server is stopped. + print(f"on_shutdown:{__name__}") + pass + + async def on_valves_updated(self): + # This function is called when the valves are updated. + print(f"on_valves_updated:{__name__}") + self.pipelines = self.get_models() + pass + + def get_models(self): + if self.valves.GROQ_API_KEY: + try: + headers = {} + headers["Authorization"] = f"Bearer {self.valves.GROQ_API_KEY}" + headers["Content-Type"] = "application/json" + + r = requests.get( + f"{self.valves.GROQ_API_BASE_URL}/models", headers=headers + ) + + models = r.json() + return [ + { + "id": model["id"], + "name": model["name"] if "name" in model else model["id"], + } + for model in models["data"] + ] + + except Exception as e: + + print(f"Error: {e}") + return [ + { + "id": "error", + "name": "Could not fetch models from Groq, please update the API Key in the valves.", + }, + ] + else: + return [] + + def pipe( + self, user_message: str, model_id: str, messages: List[dict], body: dict + ) -> Union[str, Generator, Iterator]: + # This is where you can add your custom pipelines like RAG. + print(f"pipe:{__name__}") + + print(messages) + print(user_message) + + headers = {} + headers["Authorization"] = f"Bearer {self.valves.GROQ_API_KEY}" + headers["Content-Type"] = "application/json" + + payload = {**body, "model": model_id} + + if "user" in payload: + del payload["user"] + if "chat_id" in payload: + del payload["chat_id"] + if "title" in payload: + del payload["title"] + + print(payload) + + try: + r = requests.post( + url=f"{self.valves.GROQ_API_BASE_URL}/chat/completions", + json=payload, + headers=headers, + stream=True, + ) + + r.raise_for_status() + + if body["stream"]: + return r.iter_lines() + else: + return r.json() + except Exception as e: + return f"Error: {e}" \ No newline at end of file