
# Mixtral 8x22B Truss

This is a Truss for the community edition of Mixtral 8x22B. This deployment is not optimized; if you would like a version with lower latency and higher throughput, please contact our team.

## Deployment

First, clone this repository and change into the model directory:

```sh
git clone https://github.com/basetenlabs/truss-examples/
cd truss-examples/mistral/mixtral-8x22b
```

Before deployment:

  1. Make sure you have a Baseten account and API key.
  2. Install the latest version of Truss: `pip install --upgrade truss`

With `mixtral-8x22b` as your working directory, you can deploy the model with:

```sh
truss push --publish
```

Paste your Baseten API key if prompted.

For more information, see Truss documentation.

## Hardware notes

You need four A100s to run Mixtral at fp16. If you need access to A100s, please contact us.
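The four-GPU requirement follows from a back-of-the-envelope weight-memory estimate. The sketch below assumes roughly 141B total parameters (the approximate published figure for Mixtral 8x22B) at 2 bytes per parameter in fp16, and ignores KV cache and activation overhead:

```python
# Rough fp16 memory estimate for Mixtral 8x22B.
# Assumes ~141B total parameters and 2 bytes per parameter;
# KV cache and activations add further overhead on top of this.
params_billion = 141
bytes_per_param = 2  # fp16
weights_gb = params_billion * bytes_per_param  # ~282 GB of weights alone

a100_gb = 80
gpus_needed = -(-weights_gb // a100_gb)  # ceiling division
print(f"~{weights_gb} GB of weights -> at least {gpus_needed}x A100 80GB")
```

Four 80 GB A100s provide 320 GB, which leaves some headroom beyond the weights for the KV cache.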

## Mixtral 8x22B API documentation

This section provides an overview of the Mixtral 8x22B API, its parameters, and how to use it. The API consists of a single route named `predict`, which you can invoke to generate text based on the provided prompt.

### API route: `predict`

The `predict` route is the primary method for generating text completions based on a given prompt. It takes several parameters:

  • `prompt`: The input text that you want the model to generate a response for.
  • `stream` (optional, default=True): A boolean determining whether the model should stream a response back. When `True`, the API returns generated text as it becomes available.
  • `max_tokens` (optional, default=128): The maximum number of tokens to generate.
  • `temperature` (optional, default=1.0): Controls the randomness of the generation. Higher temperatures produce more diverse and creative output; lower temperatures make it more deterministic.
  • `top_p` (optional, default=0.95): Nucleus sampling parameter; restricts sampling to the smallest set of tokens whose cumulative probability exceeds `top_p`.
  • `top_k` (optional, default=50): Restricts sampling to the `k` most likely tokens at each step.
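The defaults above can be captured in a small helper that builds the request payload; `build_payload` is a hypothetical name introduced here for illustration, assuming the defaults listed:

```python
# Documented defaults for the predict route; only "prompt" is required.
DEFAULTS = {
    "stream": True,
    "max_tokens": 128,
    "temperature": 1.0,
    "top_p": 0.95,
    "top_k": 50,
}

def build_payload(prompt, **overrides):
    """Build a predict-route payload, filling in documented defaults."""
    payload = {"prompt": prompt, **DEFAULTS}
    payload.update(overrides)
    return payload

payload = build_payload("What is mistral wind?", max_tokens=256, temperature=0.9)
```

Any parameter you pass explicitly overrides the default; omitted ones take the values documented above.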

### Example usage

```python
import os

import requests

# Replace the empty string with your model id below
model_id = ""
baseten_api_key = os.environ["BASETEN_API_KEY"]

data = {
    "prompt": "What is mistral wind?",
    "stream": True,
    "max_tokens": 256,
    "temperature": 0.9
}

# Call model endpoint
res = requests.post(
    f"https://model-{model_id}.api.baseten.co/production/predict",
    headers={"Authorization": f"Api-Key {baseten_api_key}"},
    json=data,
    stream=True
)

# Print the generated tokens as they get streamed
for content in res.iter_content():
    print(content.decode("utf-8"), end="", flush=True)
```