Skip to content

Commit

Permalink
Merge pull request #60 from MindFreeze/feat/groq_manifold
Browse files Browse the repository at this point in the history
feat: groq manifold pipeline
  • Loading branch information
tjbck committed Jun 5, 2024
2 parents c9766d6 + 47a7c82 commit b4ed391
Showing 1 changed file with 122 additions and 0 deletions.
122 changes: 122 additions & 0 deletions examples/pipelines/providers/groq_manifold_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from typing import List, Union, Generator, Iterator
from schemas import OpenAIChatMessage
from pydantic import BaseModel

import os
import requests


class Pipeline:
class Valves(BaseModel):
GROQ_API_BASE_URL: str = "https://api.groq.com/openai/v1"
GROQ_API_KEY: str = ""
pass

def __init__(self):
self.type = "manifold"
# Optionally, you can set the id and name of the pipeline.
# Best practice is to not specify the id so that it can be automatically inferred from the filename, so that users can install multiple versions of the same pipeline.
# The identifier must be unique across all pipelines.
# The identifier must be an alphanumeric string that can include underscores or hyphens. It cannot contain spaces, special characters, slashes, or backslashes.
self.id = "groq"
self.name = "Groq: "

self.valves = self.Valves(
**{
"GROQ_API_KEY": os.getenv(
"GROQ_API_KEY", "your-groq-api-key-here"
)
}
)

self.pipelines = self.get_models()
pass

async def on_startup(self):
# This function is called when the server is started.
print(f"on_startup:{__name__}")
pass

async def on_shutdown(self):
# This function is called when the server is stopped.
print(f"on_shutdown:{__name__}")
pass

async def on_valves_updated(self):
# This function is called when the valves are updated.
print(f"on_valves_updated:{__name__}")
self.pipelines = self.get_models()
pass

def get_models(self):
if self.valves.GROQ_API_KEY:
try:
headers = {}
headers["Authorization"] = f"Bearer {self.valves.GROQ_API_KEY}"
headers["Content-Type"] = "application/json"

r = requests.get(
f"{self.valves.GROQ_API_BASE_URL}/models", headers=headers
)

models = r.json()
return [
{
"id": model["id"],
"name": model["name"] if "name" in model else model["id"],
}
for model in models["data"]
]

except Exception as e:

print(f"Error: {e}")
return [
{
"id": "error",
"name": "Could not fetch models from Groq, please update the API Key in the valves.",
},
]
else:
return []

def pipe(
self, user_message: str, model_id: str, messages: List[dict], body: dict
) -> Union[str, Generator, Iterator]:
# This is where you can add your custom pipelines like RAG.
print(f"pipe:{__name__}")

print(messages)
print(user_message)

headers = {}
headers["Authorization"] = f"Bearer {self.valves.GROQ_API_KEY}"
headers["Content-Type"] = "application/json"

payload = {**body, "model": model_id}

if "user" in payload:
del payload["user"]
if "chat_id" in payload:
del payload["chat_id"]
if "title" in payload:
del payload["title"]

print(payload)

try:
r = requests.post(
url=f"{self.valves.GROQ_API_BASE_URL}/chat/completions",
json=payload,
headers=headers,
stream=True,
)

r.raise_for_status()

if body["stream"]:
return r.iter_lines()
else:
return r.json()
except Exception as e:
return f"Error: {e}"

0 comments on commit b4ed391

Please sign in to comment.