Lightning-AI/litserve

LitServe

High-throughput serving engine for AI models

✅ Batching     ✅ Streaming     ✅ Auto-GPU, multi-GPU     ✅ Multi-modal     ✅ PyTorch/JAX/TF     ✅ Full control     ✅ Auth


Deploy AI models Lightning fast ⚡

LitServe is a high-throughput serving engine for deploying AI models at scale. LitServe generates an API endpoint for a model, handles batching, streaming, autoscaling across CPU/GPUs and more.

Why we wrote LitServe:

  1. Work with any model: LLMs, vision, time series, and more.
  2. Keep a minimal, hackable code base with zero abstractions and no bloat.
  3. Build for enterprise scale, not demos.
  4. Be easy enough for researchers, yet scalable and hackable for engineers.
  5. Run on any hardware (GPU/TPU) automatically.
  6. Let you focus on model performance, not serving boilerplate.

If you're familiar with PyTorch Lightning, think of LitServe as Lightning for model serving, except that it supports every framework, including PyTorch, JAX, TensorFlow, and more.

Examples

Explore various examples that show different models deployed with LitServe:

  - Hello world: a hello world model
  - Llama 3 (8B) (LLM): deploy Llama 3
  - Any Hugging Face model (Text): deploy any Hugging Face model
  - Hugging Face BERT model (Text): deploy a model for tasks like text generation and more
  - OpenAI CLIP (Multimodal): deploy OpenAI CLIP for tasks like image understanding
  - OpenAI Whisper (Audio): deploy OpenAI Whisper for tasks like speech-to-text
  - Stable Diffusion 2 (Vision): deploy Stable Diffusion 2 for tasks like image generation

Install LitServe

Install LitServe via pip:

pip install litserve
Advanced install options  

Install the main branch:

pip install git+https://github.com/Lightning-AI/litserve.git@main

Install from source:

git clone https://github.com/Lightning-AI/litserve
cd litserve
pip install -e '.[all]'
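Whichever install route you use, you can confirm which version ended up in your environment with the standard library's importlib.metadata. This is a minimal sketch; the helper installed_version is hypothetical and not part of LitServe, and it only reads the installed distribution's metadata.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name="litserve"):
    """Return the installed version string, or None if the package is missing.

    installed_version is a hypothetical helper, not part of LitServe.
    """
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


found = installed_version()
print(f"litserve {found}" if found else "litserve is not installed")
```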

 

Get started

LitServe has a minimal API that scales to enterprise workloads while giving you full control.

  1. Subclass LitAPI to describe the inference process for your model(s).
  2. Enable specific optimizations (such as batching or streaming) on the LitServer.

Implement a server

Here's a hello world example:

# server.py
import litserve as ls

# STEP 1: DEFINE YOUR MODEL API
class SimpleLitAPI(ls.LitAPI):
    def setup(self, device):
        # Setup the model so it can be called in `predict`.
        self.model = lambda x: x**2

    def decode_request(self, request):
        # Convert the request payload to your model input.
        return request["input"]

    def predict(self, x):
        # Run the model on the input and return the output.
        return self.model(x)

    def encode_response(self, output):
        # Convert the model output to a response payload.
        return {"output": output}


# STEP 2: START THE SERVER
if __name__ == "__main__":
    api = SimpleLitAPI()
    server = ls.LitServer(api, accelerator="auto")
    server.run(port=8000)

Now run the server from the command line:

python server.py

Use the server

LitServe automatically generates a client when it starts. Use this client to test the server:

python client.py

Or query the server directly yourself:

import requests
response = requests.post("http://127.0.0.1:8000/predict", json={"input": 4.0})
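If you'd rather not depend on requests, the same call can be made with Python's standard library. A minimal sketch, assuming the hello world server above is listening on 127.0.0.1:8000; the helper names build_predict_request and send_predict are hypothetical.

```python
import json
import urllib.request


def build_predict_request(payload, url="http://127.0.0.1:8000/predict"):
    """Build a JSON POST request for the /predict endpoint (hypothetical helper)."""
    body = json.dumps(payload).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send_predict(payload, timeout=10):
    """Send the request and decode the JSON response (needs a running server)."""
    with urllib.request.urlopen(build_predict_request(payload), timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

With the server running, send_predict({"input": 4.0}) should return the same JSON body as the requests call above.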

The server expects clients to send a POST request to the /predict endpoint with a JSON payload. How that payload is structured is up to your LitAPI subclass; in the hello world example, decode_request reads the "input" key.
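Conceptually, each request flows through the LitAPI hooks in order: decode_request turns the JSON payload into model input, predict runs the model, and encode_response turns the output back into a JSON-serializable payload. The sketch below mimics that flow in plain Python so it runs without a server. handle_request is a hypothetical stand-in for the server, not LitServe's actual implementation, and SquareAPI mirrors SimpleLitAPI without subclassing ls.LitAPI.

```python
class SquareAPI:
    """Mirrors SimpleLitAPI's hooks without depending on litserve."""

    def setup(self, device):
        self.model = lambda x: x**2

    def decode_request(self, request):
        return request["input"]

    def predict(self, x):
        return self.model(x)

    def encode_response(self, output):
        return {"output": output}


def handle_request(api, payload):
    """Hypothetical stand-in for the server: run the hooks in order."""
    x = api.decode_request(payload)
    y = api.predict(x)
    return api.encode_response(y)


api = SquareAPI()
api.setup(device="cpu")  # setup runs before any request is handled
```

Calling handle_request(api, {"input": 4.0}) returns {"output": 16.0}, matching what the server sends back for the request above.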

 

Features

LitServe supports the following features:

  - Accelerators: CPU, GPU, multi-GPU, MPS
  - Auto-GPU: detects and automatically runs on all GPUs on a machine
  - Model types: LLMs, vision, time series, and any other model type
  - ML frameworks: PyTorch, JAX, TensorFlow, NumPy, and more
  - Batching
  - API authentication
  - Multiple models in a single API
  - Full request/response control
  - Automatic schema validation
  - Timeout handling
  - Disconnect handling
  - Streaming

Note

Our goal is not to jump on every hype train, but to support features that scale under the most demanding enterprise deployments.

Feature details

Explore each feature in detail:

Use accelerators automatically (GPUs, CPU, mps)  

LitServe automatically detects GPUs on a machine and uses them when available:

import litserve as ls
from litserve.examples import SimpleLitAPI

# Automatically selects the available accelerator
api = SimpleLitAPI()  # or your own ls.LitAPI subclass

# when running on GPUs, these are equivalent. It's best to let LitServe decide by not specifying it!
server = ls.LitServer(api)
server = ls.LitServer(api, accelerator="cuda")
server = ls.LitServer(api, accelerator="auto")

LitServer accepts an accelerator argument which defaults to "auto". It can also be explicitly set to "cpu", "cuda", or "mps" if you wish to manually control the device placement.
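One way to picture what "auto" does is a simple precedence rule: prefer CUDA, then Apple's MPS, then fall back to CPU. The sketch below encodes that rule with availability flags passed in as arguments so it runs anywhere. resolve_accelerator is hypothetical, and the precedence order is an assumption, not a guarantee about LitServe's internals.

```python
def resolve_accelerator(accelerator="auto", cuda_available=False, mps_available=False):
    """Hypothetical sketch of resolving accelerator="auto" to a concrete backend."""
    valid = {"auto", "cpu", "cuda", "mps"}
    if accelerator not in valid:
        raise ValueError(f"accelerator must be one of {sorted(valid)}, got {accelerator!r}")
    if accelerator != "auto":
        return accelerator  # an explicit choice is used as-is
    # Assumed precedence: CUDA first, then MPS, then CPU.
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"
```

With PyTorch installed, the flags would come from torch.cuda.is_available() and torch.backends.mps.is_available().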

The following example shows how to set the accelerator manually:

import litserve as ls
from litserve.examples import SimpleLitAPI

# Run on CUDA-supported GPUs
server = ls.LitServer(SimpleLitAPI(), accelerator="cuda")

# Run on Apple's Metal-powered GPUs
server = ls.LitServer(SimpleLitAPI(), accelerator="mps")

Serve on multiple GPUs

LitServer can coordinate serving across multiple GPUs.

LitServer accepts a devices argument which defaults to "auto". On multi-GPU machines, LitServe will run a copy of the model on each device detected on the machine.

The devices argument can also be explicitly set to the desired number of devices to use on the machine.

import litserve as ls
from litserve.examples import SimpleLitAPI

# Automatically selects the available accelerators
api = SimpleLitAPI()  # or your own ls.LitAPI subclass

# when running on a 4-GPU machine, these are equivalent.
# It's best to let LitServe decide by not specifying accelerator and devices!
server = ls.LitServer(api)
server = ls.LitServer(api, accelerator="cuda", devices=4)
server = ls.LitServer(api, accelerator="auto", devices="auto")

For example, here is how to run the API server on a 4-GPU machine, with a PyTorch model served on each GPU:

import torch, torch.nn as nn
import litserve as ls

class Linear(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)
        self.linear.weight.data.fill_(2.0)
        self.linear.bias.data.fill_(1.0)

    def forward(self, x):
        return self.linear(x)

class SimpleTorchAPI(ls.LitAPI):
    def setup(self, device):
        # move the model to the correct device
        # keep track of the device for moving data accordingly
        self.model = Linear().to(device)
        self.device = device

    def decode_request(self, request):
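        # get the input and create a 1D float tensor on the correct device
        # (float32 matches nn.Linear's default weights, even if the client sends an int)
        content = request["input"]
        return torch.tensor([content], dtype=torch.float32, device=self.device)

    def predict(self, x):
        # the model expects a batch dimension, so create it;
        # inference_mode disables autograd tracking while serving
        with torch.inference_mode():
            return self.model(x[None, :])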

    def encode_response(self, output):
        # float will take the output value directly onto CPU memory
        return {"output": float(output)}


if __name__ == "__main__":
    # accelerator="auto" (or "cuda"), devices="auto" (or 4) will lead to 4 workers serving
    # the model from "cuda:0", "cuda:1", "cuda:2", "cuda:3" respectively
    server = ls.LitServer(SimpleTorchAPI(), accelerator="auto", devices="auto")
    server.run(port=8000)

The devices argument can also be a list of the device IDs to run the model on:

import litserve as ls
from litserve.examples import SimpleTorchAPI

server = ls.LitServer(SimpleTorchAPI(), accelerator="cuda", devices=[0, 3])

Lastly, if the model is small, you can run multiple copies of it on the same device. The following loads two copies of the model on each of the 4 GPUs:

import litserve as ls
from litserve.examples import SimpleTorchAPI

server = ls.LitServer(SimpleTorchAPI(), accelerator="cuda", devices=4, workers_per_device=2)
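To make the interplay of devices and workers_per_device concrete, here is a sketch that expands those arguments into the device string each worker would be assigned. plan_workers is hypothetical; it assumes a GPU accelerator such as "cuda", treats an int as "the first N devices" and a list as explicit device IDs, and the ordering of the returned list is an assumption rather than documented LitServe behavior.

```python
def plan_workers(accelerator, devices, workers_per_device=1):
    """Hypothetical sketch: expand devices/workers_per_device into worker devices.

    Assumes a GPU accelerator such as "cuda".
    """
    if workers_per_device < 1:
        raise ValueError("workers_per_device must be >= 1")
    # An int means "the first N devices"; a list names explicit device IDs.
    ids = list(range(devices)) if isinstance(devices, int) else list(devices)
    device_names = [f"{accelerator}:{i}" for i in ids]
    # One worker per (device, replica) pair, each holding its own model copy.
    return [name for name in device_names for _ in range(workers_per_device)]
```

With devices=4 and workers_per_device=2, this yields eight workers, two per GPU, matching the example above.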

Timeouts and disconnections

The server will remove a queued request if the client requesting it disconnects.

You can configure a timeout (in seconds) after which clients will receive a 504 HTTP response (Gateway Timeout) indicating that their request has timed out.

For example, this is how you can configure the server with a timeout of 30 seconds per response.

import litserve as ls
from litserve.examples import SimpleLitAPI

server = ls.LitServer(SimpleLitAPI(), timeout=30)

This prevents requests from queuing up beyond the server's capacity to respond.

To disable the timeout for long-running tasks, set timeout=False or timeout=-1:

import litserve as ls
from litserve.examples import SimpleLitAPI

server = ls.LitServer(SimpleLitAPI(), timeout=False)
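The timeout behavior can be pictured as a deadline wrapped around each request's work. The sketch below models it with asyncio.wait_for. handle_with_timeout and slow_inference are hypothetical, the error body is made up for illustration, and this is not necessarily how LitServe implements timeouts internally.

```python
import asyncio


async def handle_with_timeout(work, timeout):
    """Hypothetical sketch: return (status, body) for an awaitable unit of work.

    timeout is in seconds; False or -1 disables it, following the README's convention.
    """
    if timeout is False or timeout == -1:
        return 200, await work
    try:
        return 200, await asyncio.wait_for(work, timeout)
    except asyncio.TimeoutError:
        # The client gets 504 Gateway Timeout instead of waiting indefinitely.
        return 504, {"error": "Request timed out"}


async def slow_inference(seconds, result):
    """Stand-in for model work that takes `seconds` to finish."""
    await asyncio.sleep(seconds)
    return result
```

For example, asyncio.run(handle_with_timeout(slow_inference(0.2, "done"), 0.05)) returns a 504 status, while passing timeout=False lets the same work finish.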

Use API key authentication

To secure the API behind an API key, set the LIT_SERVER_API_KEY environment variable when starting the server:

LIT_SERVER_API_KEY=supersecretkey python server.py

Clients must authenticate by sending the same API key in the X-API-Key HTTP header.
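The check itself amounts to comparing the X-API-Key header against the configured key. The sketch below is a hypothetical illustration, not LitServe's code: is_authorized and client_headers are made-up names, and treating a missing LIT_SERVER_API_KEY as "authentication off" is an assumption.

```python
import hmac
import os


def is_authorized(headers, env=None):
    """Hypothetical sketch of an X-API-Key check driven by LIT_SERVER_API_KEY."""
    env = os.environ if env is None else env
    expected = env.get("LIT_SERVER_API_KEY")
    if not expected:
        return True  # assumption: no key configured means authentication is off
    # HTTP header names are case-insensitive, so normalize before lookup.
    normalized = {k.lower(): v for k, v in headers.items()}
    provided = normalized.get("x-api-key", "")
    # compare_digest avoids leaking how many leading characters matched.
    return hmac.compare_digest(provided.encode(), expected.encode())


# A client simply sends the key alongside its JSON payload.
client_headers = {"X-API-Key": "supersecretkey", "Content-Type": "application/json"}
```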

Dynamic batching  

LitServe can combine individual requests into a batch to improve throughput. To enable batching, you need to set the max_batch_size argument to match the batch size that your model can handle and implement LitAPI.predict to process batched inputs.

import numpy as np
import litserve as ls

class SimpleBatchedAPI(ls.LitAPI):
    def setup(self, device) -> None:
        self.model = lambda x: x ** 2

    def decode_request(self, request):
        return np.asarray(request["input"])

    def predict(self, x):
        result = self.model(x)
        return result

    def encode_response(self, output):
        return {"output": output}

if __name__ == "__main__":
    api = SimpleBatchedAPI()
    server = ls.LitServer(api, max_batch_size=4, batch_timeout=0.05)
    server.run(port=8000)

You can control how long the server waits to aggregate requests into a batch with the batch_timeout argument. In the example above, the server waits up to 0.05 seconds to collect up to 4 requests, then runs predict on whatever it has collected.
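The "up to max_batch_size requests, or until batch_timeout expires" rule can be sketched with a standard-library queue. collect_batch is hypothetical and simplified; for instance, it starts the clock when called rather than when the first request arrives, which may differ from LitServe.

```python
import queue
import time


def collect_batch(q, max_batch_size, batch_timeout):
    """Hypothetical sketch: gather up to max_batch_size items within batch_timeout seconds."""
    batch = []
    deadline = time.monotonic() + batch_timeout
    while len(batch) < max_batch_size:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break  # time is up; ship whatever we have
        try:
            batch.append(q.get(timeout=remaining))
        except queue.Empty:
            break  # no more requests arrived before the deadline
    return batch
```

With six queued requests and max_batch_size=4, the first call returns four of them immediately and the next call returns the remaining two after the timeout.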

 

LitServe automatically stacks NumPy arrays and PyTorch tensors along the batch dimension before calling the LitAPI.predict method, and splits the output across requests afterward. You can customize this behavior by overriding the LitAPI.batch and LitAPI.unbatch methods to handle different data types.

import litserve as ls
from litserve.examples import SimpleBatchedAPI
import numpy as np

class CustomBatchedAPI(SimpleBatchedAPI):
    def batch(self, inputs):
        return np.stack(inputs)

    def unbatch(self, output):
        return list(output)

if __name__ == "__main__":
    api = CustomBatchedAPI()
    server = ls.LitServer(api, max_batch_size=4, batch_timeout=0.05)
    server.run(port=8000)
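Overriding batch and unbatch matters most when inputs can't be stacked as arrays, for example strings. The sketch below runs the full round trip in plain Python. UppercaseBatchedAPI and serve_batch are hypothetical: the class mirrors the LitAPI hook names without subclassing ls.LitAPI, and serve_batch is a simplified stand-in for what the server does with one batch.

```python
class UppercaseBatchedAPI:
    """Hypothetical stand-in for an ls.LitAPI subclass that batches strings."""

    def decode_request(self, request):
        return request["input"]

    def batch(self, inputs):
        # Strings can't be stacked like arrays, so keep the batch as a plain list.
        return list(inputs)

    def predict(self, batch):
        return [text.upper() for text in batch]

    def unbatch(self, output):
        # One output per request, in the same order as the inputs.
        return list(output)

    def encode_response(self, output):
        return {"output": output}


def serve_batch(api, requests):
    """Simplified stand-in for the server handling one batch of requests."""
    inputs = [api.decode_request(r) for r in requests]
    outputs = api.unbatch(api.predict(api.batch(inputs)))
    return [api.encode_response(o) for o in outputs]
```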

Stream long responses

LitServe can stream outputs from the model in real-time, such as returning text one word at a time from a language model.

To enable streaming, set LitServer(..., stream=True) and implement LitAPI.predict and LitAPI.encode_response as generators (Python functions that yield outputs).

For example, streaming long responses generated over time:

import json
import litserve as ls

class SimpleStreamAPI(ls.LitAPI):
    def setup(self, device) -> None:
        self.model = lambda x, y: x * y

    def decode_request(self, request):
        return request["input"]

    def predict(self, x):
        for i in range(10):
            yield self.model(x, i)

    def encode_response(self, output):
        for out in output:
            yield json.dumps({"output": out})


if __name__ == "__main__":
    api = SimpleStreamAPI()
    server = ls.LitServer(api, stream=True)
    server.run(port=8000)
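Because both hooks are generators, each chunk can be encoded and sent as soon as predict yields it, instead of waiting for all outputs. The sketch below records the order of events to show that interleaving. The functions mirror SimpleStreamAPI's hooks without litserve, and the "send" step is a hypothetical stand-in for writing a chunk to the client.

```python
import json

events = []


def predict(x):
    for i in range(3):
        events.append(f"predict {i}")
        yield x * i


def encode_response(outputs):
    for out in outputs:
        events.append(f"encode {out}")
        yield json.dumps({"output": out})


chunks = []
for chunk in encode_response(predict(2)):
    events.append("send")  # stand-in for streaming the chunk to the client
    chunks.append(chunk)
```

The recorded events alternate predict, encode, and send for each value, which is what lets clients see the first words of a long response early.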

 

Automatic schema validation

Define the request and response as Pydantic models to validate requests automatically.

from pydantic import BaseModel
import litserve as ls


class PredictRequest(BaseModel):
    input: float


class PredictResponse(BaseModel):
    output: float


class SimpleLitAPI(ls.LitAPI):
    def setup(self, device):
        self.model = lambda x: x**2

    def decode_request(self, request: PredictRequest) -> float:
        return request.input

    def predict(self, x):
        return self.model(x)

    def encode_response(self, output: float) -> PredictResponse:
        return PredictResponse(output=output)


if __name__ == "__main__":
    api = SimpleLitAPI()
    server = ls.LitServer(api, accelerator="auto")
    server.run(port=8000)
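To see what the request model accepts and rejects, you can exercise it directly with Pydantic, no server required. validate_payload is a hypothetical helper. Since LitServe builds on FastAPI, a payload that fails validation should be rejected before decode_request runs, typically with HTTP 422, though the exact response is best confirmed against your installed version.

```python
from pydantic import BaseModel, ValidationError


class PredictRequest(BaseModel):
    input: float


def validate_payload(payload):
    """Hypothetical helper: return (True, model) or (False, number_of_errors)."""
    try:
        return True, PredictRequest(**payload)
    except ValidationError as exc:
        return False, len(exc.errors())
```

An int such as {"input": 4} is coerced to 4.0, while {"input": "not a number"} or a missing "input" key fails validation.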

Contribute

LitServe is a community project accepting contributions. Let's make the world's most advanced AI inference engine.

Run tests

Use pytest to run tests locally.

First, install test dependencies:

pip install -r _requirements/test.txt

Run the tests

pytest tests

License

LitServe is released under the Apache 2.0 license. See the LICENSE file for details.
