This is a Truss for the community edition of Mixtral 8x22B. This is not an optimized model. If you would like a more optimized version with lower latency and higher throughput, please contact our team.
First, clone this repository:

```sh
git clone https://github.com/basetenlabs/truss-examples/
cd truss-examples/mistral/mixtral-8x22b
```
Before deployment:

- Make sure you have a Baseten account and API key.
- Install the latest version of Truss:

```sh
pip install --upgrade truss
```
With `mistral/mixtral-8x22b` as your working directory, you can deploy the model with:

```sh
truss push --publish
```

Paste your Baseten API key if prompted.
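Once the push completes, you can invoke the deployed model directly from the CLI. Treat this as a sketch rather than a guaranteed interface; the exact flags may vary between Truss versions:

```sh
# Send a test prompt to the deployed model (flags may vary by Truss version)
truss predict -d '{"prompt": "What is mistral wind?"}'
```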
For more information, see the Truss documentation.
You need four A100s to run Mixtral 8x22B at fp16. If you need access to A100s, please contact us.
This section provides an overview of the Mixtral 8x22B API, its parameters, and how to use it. The API consists of a single route named `predict`, which you can invoke to generate text based on the provided prompt.

The `predict` route is the primary method for generating text completions based on a given prompt. It takes several parameters:
- `prompt`: The input text that you want the model to generate a response for.
- `stream` (optional, default=True): A boolean determining whether the model should stream a response back. When `True`, the API returns generated text as it becomes available.
- `max_tokens` (optional, default=128): The maximum number of tokens to generate.
- `temperature` (optional, default=1.0): Controls the randomness of the output. Higher values produce more diverse and creative text; lower values make the output more deterministic.
- `top_p` (optional, default=0.95): Nucleus sampling parameter; generation only samples from the smallest set of tokens whose cumulative probability exceeds `top_p`.
- `top_k` (optional, default=50): Restricts sampling to the `top_k` most likely tokens at each step.
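As a quick sanity check, you can call the endpoint from the shell with curl. Treat this as a sketch: `MODEL_ID` stands in for your deployed model's ID, and `BASETEN_API_KEY` is assumed to be set in your environment:

```sh
# Hypothetical one-off request; substitute MODEL_ID with your model ID
curl -X POST "https://model-MODEL_ID.api.baseten.co/production/predict" \
  -H "Authorization: Api-Key $BASETEN_API_KEY" \
  -H "Content-Type: application/json" \
  -d '{"prompt": "What is mistral wind?", "max_tokens": 64, "stream": false}'
```

The Python example below makes the same call, but streams tokens back as they are generated: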
```python
import requests
import os

# Replace the empty string with your model ID below
model_id = ""
baseten_api_key = os.environ["BASETEN_API_KEY"]

data = {
    "prompt": "What is mistral wind?",
    "stream": True,
    "max_tokens": 256,
    "temperature": 0.9
}

# Call model endpoint
res = requests.post(
    f"https://model-{model_id}.api.baseten.co/production/predict",
    headers={"Authorization": f"Api-Key {baseten_api_key}"},
    json=data,
    stream=True
)

# Print the generated tokens as they get streamed
for content in res.iter_content():
    print(content.decode("utf-8"), end="", flush=True)
```
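If you prefer the whole completion in a single response, set `stream` to `False`. The exact response shape depends on the model code in this Truss, so the following is a sketch that assumes the non-streaming path returns JSON:

```python
import os
import requests

model_id = ""  # as above, replace with your model ID
baseten_api_key = os.environ["BASETEN_API_KEY"]

# Non-streaming call: the full completion arrives in one response
res = requests.post(
    f"https://model-{model_id}.api.baseten.co/production/predict",
    headers={"Authorization": f"Api-Key {baseten_api_key}"},
    json={"prompt": "What is mistral wind?", "stream": False, "max_tokens": 256},
)

# Assumed JSON payload; fall back to res.text if the model returns plain text
print(res.json())
```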