In [None]:
import os
import random
import string
import threading
import json
from fastapi import FastAPI, HTTPException, Header
from pydantic import BaseModel
from typing import List, Optional
from groq import Groq

# FastAPI application; the route decorators in this cell register on it.
app = FastAPI()

# Upstream Groq client that the /chat/completions handler forwards requests to.
# os.environ.get returns None when GROQ_API_KEY is unset; nothing here checks
# for that case.
master_client = Groq(
    api_key=os.environ.get("GROQ_API_KEY"),
)

# Lock guarding reads and writes of api_keys.json. threading.Lock is not
# reentrant: code that already holds it must not call load_api_keys() or
# save_api_keys(), which acquire it themselves.
api_keys_lock = threading.Lock()

# Define the request and response models
# One chat message; the handler converts each instance to a dict and passes
# the list to the upstream completion call.
class Message(BaseModel):
    # Speaker label. Only required to be a string here; presumably values such
    # as "system" / "user" / "assistant" -- confirm against the upstream API.
    role: str
    # Text of the message.
    content: str

# Request body accepted by POST /chat/completions.
class ChatCompletionRequest(BaseModel):
    # Conversation to send upstream, in order.
    messages: List[Message]
    # Model identifier, passed to the upstream client unchanged.
    model: str

# Response body returned by POST /chat/completions.
class ChatCompletionResponse(BaseModel):
    # Message content of the first choice in the upstream completion.
    content: str

def generate_api_key(length=30):
    """Return a random alphanumeric API key.

    The key is used to authenticate clients, so characters are drawn with
    ``secrets`` (OS-backed CSPRNG) rather than ``random``, whose output is
    predictable and not suitable for security tokens.

    Args:
        length: Number of characters in the key. Zero or a negative value
            yields an empty string.

    Returns:
        A string of ``length`` characters taken from ASCII letters and digits.
    """
    # Imported locally so this function does not depend on the file's
    # top-level import list.
    import secrets

    characters = string.ascii_letters + string.digits
    return ''.join(secrets.choice(characters) for _ in range(length))

def load_api_keys():
    """Read api_keys.json and return its parsed contents.

    Returns an empty dict when the file does not exist. The whole check-and-read
    runs while holding api_keys_lock.
    """
    with api_keys_lock:
        if os.path.exists('api_keys.json'):
            with open('api_keys.json', 'r') as key_file:
                return json.load(key_file)
        # No key file has been written yet.
        return {}

def save_api_keys(api_keys):
    """Write the given key mapping to api_keys.json, replacing its contents.

    The file is opened and written while holding api_keys_lock.
    """
    # Lock is acquired first, then the file is opened -- same order as two
    # nested with-blocks.
    with api_keys_lock, open('api_keys.json', 'w') as key_file:
        json.dump(api_keys, key_file, indent=4)

def add_api_key(client_name):
    """Create a new API key for ``client_name`` and persist it.

    The read-modify-write of api_keys.json happens under a single acquisition
    of api_keys_lock. Loading and saving under two separate acquisitions would
    let two concurrent callers read the same snapshot, and the later write
    would silently drop the other caller's key.

    Args:
        client_name: Value stored in the file against the new key.

    Returns:
        The newly generated API key string.
    """
    new_key = generate_api_key()
    # The file access is inlined instead of calling load_api_keys() /
    # save_api_keys(): both acquire api_keys_lock, which is a non-reentrant
    # threading.Lock, so calling them while holding it would deadlock.
    # NOTE(review): this serializes threads within one process only; several
    # worker processes sharing the file would still race.
    with api_keys_lock:
        if os.path.exists('api_keys.json'):
            with open('api_keys.json', 'r') as f:
                api_keys = json.load(f)
        else:
            api_keys = {}
        api_keys[new_key] = client_name
        with open('api_keys.json', 'w') as f:
            json.dump(api_keys, f, indent=4)
    return new_key

@app.post("/generate_api_key")
async def generate_api_key_endpoint(client_name: str):
    """API endpoint to generate a new API key."""
    # NOTE(review): nothing in this handler authenticates the caller before a
    # key is issued -- confirm that is intended.
    return {"api_key": add_api_key(client_name)}

@app.post("/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(
    request: ChatCompletionRequest,
    api_key: Optional[str] = Header(None),
):
    # Proxy a chat request to the master Groq client on behalf of a client
    # holding one of the keys stored in api_keys.json.
    #
    # Responds 401 when the supplied header value is not a known key (a missing
    # header arrives as None, which is never a key), and 500 carrying the
    # exception text when anything in the upstream call or response handling
    # fails.
    known_keys = load_api_keys()
    if api_key not in known_keys:
        raise HTTPException(status_code=401, detail="Invalid API Key")

    try:
        # Convert the pydantic messages to plain dicts for the upstream call.
        payload = [message.dict() for message in request.messages]
        completion = master_client.chat.completions.create(
            messages=payload,
            model=request.model,
        )
        # Only the first choice is returned to the caller.
        reply = completion.choices[0].message.content
        return ChatCompletionResponse(content=reply)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


In [None]:
import os
from fastapi import FastAPI, HTTPException, Header
from pydantic import BaseModel
from typing import List, Optional
from cerebras.cloud.sdk import Cerebras
from groq import Groq

# FastAPI application for this cell; the /chat/completions route below
# registers on it. NOTE(review): if an earlier cell in the same session also
# bound `app`, this assignment replaces that binding -- confirm which app is
# meant to be served.
app = FastAPI()

# Define the request and response models
# One chat message; the handler converts each instance to a dict and passes
# the list to the selected service's completion call.
class Message(BaseModel):
    # Speaker label. Only required to be a string here; presumably values such
    # as "system" / "user" / "assistant" -- confirm against the service APIs.
    role: str
    # Text of the message.
    content: str

# Request body accepted by POST /chat/completions.
class ChatCompletionRequest(BaseModel):
    # Conversation to send to the selected service, in order.
    messages: List[Message]
    # Model identifier, passed to the service client unchanged.
    model: str
    # Lower-cased by the handler before use, so matching is case-insensitive.
    service_name: str  # Service to use: 'cerebras' or 'groq'

# Response body returned by POST /chat/completions.
class ChatCompletionResponse(BaseModel):
    # Text returned to the caller; declared as a string, so the handler must
    # supply a string value.
    content: str

# API keys for the supported services, read once from the environment when
# this code runs. os.environ.get yields None for a variable that is not set.
SERVICE_API_KEYS = {
    "cerebras": os.environ.get("CEREBRAS_API_KEY"),
    "groq": os.environ.get("GROQ_API_KEY"),
}

# Fail immediately if any key is missing or empty. Every listed service is
# checked, so one missing key stops execution even if only the other service
# would be used. The loop variables `service` and `api_key` remain bound at
# module level after the loop finishes.
for service, api_key in SERVICE_API_KEYS.items():
    if not api_key:
        raise Exception(f"API key for {service} not set in environment variables.")

def get_client(service_name: str):
    """Build a new SDK client for the named service.

    The name is matched case-insensitively. A fresh client object is
    constructed on every call using the key from SERVICE_API_KEYS.

    Raises:
        ValueError: If the lower-cased name is neither "cerebras" nor "groq".
    """
    normalized = service_name.lower()
    # Reject unknown names up front so the constructors below are the only
    # remaining outcomes.
    if normalized not in ("cerebras", "groq"):
        raise ValueError(f"Unsupported service_name: {normalized}")
    if normalized == "groq":
        return Groq(api_key=SERVICE_API_KEYS["groq"])
    return Cerebras(api_key=SERVICE_API_KEYS["cerebras"])

@app.post("/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(
    request: ChatCompletionRequest,
    api_key: Optional[str] = Header(None),
):
    # Send the chat request to the service named in request.service_name and
    # return the message content of the first choice.
    #
    # Responds 400 when a ValueError is raised (get_client raises one for an
    # unsupported service name) and 500, carrying the exception text, for any
    # other failure.
    #
    # NOTE(review): the api_key header is accepted but never checked in this
    # handler; it is kept so the signature stays unchanged.
    service_name = request.service_name.lower()

    try:
        # get_client raises ValueError for anything other than "cerebras" or
        # "groq", so no separate unsupported-service branch is needed here.
        client = get_client(service_name)
        messages = [message.dict() for message in request.messages]

        chat_completion = client.chat.completions.create(
            messages=messages,
            model=request.model,
        )
        # Both SDKs expose the reply text at choices[0].message.content.
        # Handing the whole completion object to the response model (as the
        # cerebras path previously did) cannot satisfy its `content: str`
        # field, so that path could never return successfully.
        content = chat_completion.choices[0].message.content

        return ChatCompletionResponse(content=content)

    except ValueError as ve:
        raise HTTPException(status_code=400, detail=str(ve)) from ve
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e