Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add API_KEY list support #877

Merged
merged 6 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
65 changes: 46 additions & 19 deletions api/api.go
Expand Up @@ -2,6 +2,7 @@ package api

import (
"errors"
"strings"

config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/localai"
Expand Down Expand Up @@ -89,6 +90,32 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
// Default middleware config
app.Use(recover.New())

// Auth middleware checking if API key is valid. If no API key is set, no auth is required.
auth := func(c *fiber.Ctx) error {
if len(options.ApiKeys) > 0 {
authHeader := c.Get("Authorization")
if authHeader == "" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"})
}
authHeaderParts := strings.Split(authHeader, " ")
if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"})
}

apiKey := authHeaderParts[1]
validApiKey := false
for _, key := range options.ApiKeys {
if apiKey == key {
validApiKey = true
}
}
if !validApiKey {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})
}
}
return c.Next()
}

if options.PreloadJSONModels != "" {
if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cm, options.Galleries); err != nil {
return nil, err
Expand Down Expand Up @@ -116,42 +143,42 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
galleryService := localai.NewGalleryService(options.Loader.ModelPath)
galleryService.Start(options.Context, cm)

app.Get("/version", func(c *fiber.Ctx) error {
app.Get("/version", auth, func(c *fiber.Ctx) error {
return c.JSON(struct {
Version string `json:"version"`
}{Version: internal.PrintableVersion()})
})

app.Post("/models/apply", localai.ApplyModelGalleryEndpoint(options.Loader.ModelPath, cm, galleryService.C, options.Galleries))
app.Get("/models/available", localai.ListModelFromGalleryEndpoint(options.Galleries, options.Loader.ModelPath))
app.Get("/models/jobs/:uuid", localai.GetOpStatusEndpoint(galleryService))
app.Post("/models/apply", auth, localai.ApplyModelGalleryEndpoint(options.Loader.ModelPath, cm, galleryService.C, options.Galleries))
app.Get("/models/available", auth, localai.ListModelFromGalleryEndpoint(options.Galleries, options.Loader.ModelPath))
app.Get("/models/jobs/:uuid", auth, localai.GetOpStatusEndpoint(galleryService))

// openAI compatible API endpoint

// chat
app.Post("/v1/chat/completions", openai.ChatEndpoint(cm, options))
app.Post("/chat/completions", openai.ChatEndpoint(cm, options))
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cm, options))
app.Post("/chat/completions", auth, openai.ChatEndpoint(cm, options))

// edit
app.Post("/v1/edits", openai.EditEndpoint(cm, options))
app.Post("/edits", openai.EditEndpoint(cm, options))
app.Post("/v1/edits", auth, openai.EditEndpoint(cm, options))
app.Post("/edits", auth, openai.EditEndpoint(cm, options))

// completion
app.Post("/v1/completions", openai.CompletionEndpoint(cm, options))
app.Post("/completions", openai.CompletionEndpoint(cm, options))
app.Post("/v1/engines/:model/completions", openai.CompletionEndpoint(cm, options))
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cm, options))
app.Post("/completions", auth, openai.CompletionEndpoint(cm, options))
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cm, options))

// embeddings
app.Post("/v1/embeddings", openai.EmbeddingsEndpoint(cm, options))
app.Post("/embeddings", openai.EmbeddingsEndpoint(cm, options))
app.Post("/v1/engines/:model/embeddings", openai.EmbeddingsEndpoint(cm, options))
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cm, options))
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cm, options))
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cm, options))

// audio
app.Post("/v1/audio/transcriptions", openai.TranscriptEndpoint(cm, options))
app.Post("/tts", localai.TTSEndpoint(cm, options))
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cm, options))
app.Post("/tts", auth, localai.TTSEndpoint(cm, options))

// images
app.Post("/v1/images/generations", openai.ImageEndpoint(cm, options))
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cm, options))

if options.ImageDir != "" {
app.Static("/generated-images", options.ImageDir)
Expand All @@ -170,8 +197,8 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
app.Get("/readyz", ok)

// models
app.Get("/v1/models", openai.ListModelsEndpoint(options.Loader, cm))
app.Get("/models", openai.ListModelsEndpoint(options.Loader, cm))
app.Get("/v1/models", auth, openai.ListModelsEndpoint(options.Loader, cm))
app.Get("/models", auth, openai.ListModelsEndpoint(options.Loader, cm))

// turn off any process that was started by GRPC if the context is canceled
go func() {
Expand Down
8 changes: 8 additions & 0 deletions api/options/options.go
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"embed"
"encoding/json"
"strings"

"github.com/go-skynet/LocalAI/pkg/gallery"
model "github.com/go-skynet/LocalAI/pkg/model"
Expand All @@ -23,6 +24,7 @@ type Option struct {
PreloadJSONModels string
PreloadModelsFromPath string
CORSAllowOrigins string
ApiKeys []string

Galleries []gallery.Gallery

Expand Down Expand Up @@ -184,3 +186,9 @@ func WithImageDir(imageDir string) AppOption {
o.ImageDir = imageDir
}
}

func WithApiKey(apiKeys string) AppOption {
return func(o *Option) {
o.ApiKeys = strings.Split(apiKeys, ",")
}
}
6 changes: 6 additions & 0 deletions main.go
Expand Up @@ -130,6 +130,11 @@ func main() {
EnvVars: []string{"UPLOAD_LIMIT"},
Value: 15,
},
&cli.StringFlag{
neboman11 marked this conversation as resolved.
Show resolved Hide resolved
Name: "api-keys",
Usage: "Comma delimited list of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys.",
EnvVars: []string{"API_KEY"},
},
},
Description: `
LocalAI is a drop-in replacement OpenAI API which runs inference locally.
Expand Down Expand Up @@ -167,6 +172,7 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
options.WithBackendAssets(backendAssets),
options.WithBackendAssetsOutput(ctx.String("backend-assets-path")),
options.WithUploadLimitMB(ctx.Int("upload-limit")),
options.WithApiKey(ctx.String("api-keys")),
}

externalgRPC := ctx.StringSlice("external-grpc-backends")
Expand Down