Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions core/cli/explorer.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,12 @@ func (e *ExplorerCMD) Run(ctx *cliContext.Context) error {
appHTTP := http.Explorer(db)

signals.RegisterGracefulTerminationHandler(func() {
if err := appHTTP.Shutdown(); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := appHTTP.Shutdown(ctx); err != nil {
log.Error().Err(err).Msg("error during shutdown")
}
})

return appHTTP.Listen(e.Address)
return appHTTP.Start(e.Address)
}
2 changes: 1 addition & 1 deletion core/cli/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,5 +232,5 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
}
})

return appHTTP.Listen(r.Address)
return appHTTP.Start(r.Address)
}
223 changes: 112 additions & 111 deletions core/http/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,23 @@
"embed"
"errors"
"fmt"
"io/fs"
"net/http"
"os"
"path/filepath"
"strings"

"github.com/dave-gray101/v2keyauth"
"github.com/gofiber/websocket/v2"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"

"github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/core/http/middleware"
httpMiddleware "github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/http/routes"

"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"

"github.com/gofiber/contrib/fiberzerolog"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
"github.com/gofiber/fiber/v2/middleware/csrf"
"github.com/gofiber/fiber/v2/middleware/favicon"
"github.com/gofiber/fiber/v2/middleware/filesystem"
"github.com/gofiber/fiber/v2/middleware/recover"

// swagger handler
"github.com/rs/zerolog/log"
)

Expand All @@ -49,120 +42,126 @@
// @in header
// @name Authorization

func API(application *application.Application) (*fiber.App, error) {
func API(application *application.Application) (*echo.Echo, error) {
e := echo.New()

fiberCfg := fiber.Config{
Views: renderEngine(),
BodyLimit: application.ApplicationConfig().UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
// We disable the Fiber startup message as it does not conform to structured logging.
// We register a startup log line with connection information in the OnListen hook to keep things user friendly though
DisableStartupMessage: true,
// Override default error handler
// Set body limit
if application.ApplicationConfig().UploadLimitMB > 0 {
e.Use(middleware.BodyLimit(fmt.Sprintf("%dM", application.ApplicationConfig().UploadLimitMB)))
}

// Set error handler
if !application.ApplicationConfig().OpaqueErrors {
// Normally, return errors as JSON responses
fiberCfg.ErrorHandler = func(ctx *fiber.Ctx, err error) error {
// Status code defaults to 500
code := fiber.StatusInternalServerError

// Retrieve the custom status code if it's a *fiber.Error
var e *fiber.Error
if errors.As(err, &e) {
code = e.Code
e.HTTPErrorHandler = func(err error, c echo.Context) {
code := http.StatusInternalServerError
var he *echo.HTTPError
if errors.As(err, &he) {
code = he.Code
}

// Handle 404 errors with HTML rendering when appropriate
if code == http.StatusNotFound {
notFoundHandler(c)

Check warning

Code scanning / gosec

Errors unhandled Warning

Errors unhandled
return
}

// Send custom error page
return ctx.Status(code).JSON(
schema.ErrorResponse{
Error: &schema.APIError{Message: err.Error(), Code: code},
},
)
c.JSON(code, schema.ErrorResponse{
Error: &schema.APIError{Message: err.Error(), Code: code},
})
Comment on lines +69 to +71

Check warning

Code scanning / gosec

Errors unhandled Warning

Errors unhandled
}
} else {
// If OpaqueErrors are required, replace everything with a blank 500.
fiberCfg.ErrorHandler = func(ctx *fiber.Ctx, _ error) error {
return ctx.Status(500).SendString("")
e.HTTPErrorHandler = func(err error, c echo.Context) {
code := http.StatusInternalServerError
var he *echo.HTTPError
if errors.As(err, &he) {
code = he.Code
}
c.NoContent(code)

Check warning

Code scanning / gosec

Errors unhandled Warning

Errors unhandled
}
}

router := fiber.New(fiberCfg)
// Set renderer
e.Renderer = renderEngine()

router.Use(middleware.StripPathPrefix())
// Hide banner
e.HideBanner = true

if application.ApplicationConfig().MachineTag != "" {
router.Use(func(c *fiber.Ctx) error {
c.Response().Header.Set("Machine-Tag", application.ApplicationConfig().MachineTag)
// Middleware - StripPathPrefix must be registered early as it uses Rewrite which runs before routing
e.Pre(httpMiddleware.StripPathPrefix())

return c.Next()
if application.ApplicationConfig().MachineTag != "" {
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
c.Response().Header().Set("Machine-Tag", application.ApplicationConfig().MachineTag)
return next(c)
}
})
}

router.Use("/v1/realtime", func(c *fiber.Ctx) error {
if websocket.IsWebSocketUpgrade(c) {
// Returns true if the client requested upgrade to the WebSocket protocol
return c.Next()
}

return nil
})

router.Hooks().OnListen(func(listenData fiber.ListenData) error {
scheme := "http"
if listenData.TLS {
scheme = "https"
// Custom logger middleware using zerolog
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
req := c.Request()
res := c.Response()
start := log.Logger.Info()
err := next(c)
start.
Str("method", req.Method).
Str("path", req.URL.Path).
Int("status", res.Status).
Msg("HTTP request")
return err
}
log.Info().Str("endpoint", scheme+"://"+listenData.Host+":"+listenData.Port).Msg("LocalAI API is listening! Please connect to the endpoint for API documentation.")
return nil
})

// Have Fiber use zerolog like the rest of the application rather than it's built-in logger
logger := log.Logger
router.Use(fiberzerolog.New(fiberzerolog.Config{
Logger: &logger,
}))

// Default middleware config

// Recover middleware
if !application.ApplicationConfig().Debug {
router.Use(recover.New())
e.Use(middleware.Recover())
}

// Metrics middleware
if !application.ApplicationConfig().DisableMetrics {
metricsService, err := services.NewLocalAIMetricsService()
if err != nil {
return nil, err
}

if metricsService != nil {
router.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
router.Hooks().OnShutdown(func() error {
return metricsService.Shutdown()
e.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
e.Server.RegisterOnShutdown(func() {
metricsService.Shutdown()

Check warning

Code scanning / gosec

Errors unhandled Warning

Errors unhandled
})
}
}

// Health Checks should always be exempt from auth, so register these first
routes.HealthRoutes(router)
routes.HealthRoutes(e)

kaConfig, err := middleware.GetKeyAuthConfig(application.ApplicationConfig())
if err != nil || kaConfig == nil {
// Get key auth middleware
keyAuthMiddleware, err := httpMiddleware.GetKeyAuthConfig(application.ApplicationConfig())
if err != nil {
return nil, fmt.Errorf("failed to create key auth config: %w", err)
}

httpFS := http.FS(embedDirStatic)

router.Use(favicon.New(favicon.Config{
URL: "/favicon.svg",
FileSystem: httpFS,
File: "static/favicon.svg",
}))
// Favicon handler
e.GET("/favicon.svg", func(c echo.Context) error {
data, err := embedDirStatic.ReadFile("static/favicon.svg")
if err != nil {
return c.NoContent(http.StatusNotFound)
}
c.Response().Header().Set("Content-Type", "image/svg+xml")
return c.Blob(http.StatusOK, "image/svg+xml", data)
})

router.Use("/static", filesystem.New(filesystem.Config{
Root: httpFS,
PathPrefix: "static",
Browse: true,
}))
// Static files - use fs.Sub to create a filesystem rooted at "static"
staticFS, err := fs.Sub(embedDirStatic, "static")
if err != nil {
return nil, fmt.Errorf("failed to create static filesystem: %w", err)
}
e.StaticFS("/static", staticFS)

// Generated content directories
if application.ApplicationConfig().GeneratedContentDir != "" {
os.MkdirAll(application.ApplicationConfig().GeneratedContentDir, 0750)
audioPath := filepath.Join(application.ApplicationConfig().GeneratedContentDir, "audio")
Expand All @@ -173,51 +172,53 @@
os.MkdirAll(imagePath, 0750)
os.MkdirAll(videoPath, 0750)

router.Static("/generated-audio", audioPath)
router.Static("/generated-images", imagePath)
router.Static("/generated-videos", videoPath)
e.Static("/generated-audio", audioPath)
e.Static("/generated-images", imagePath)
e.Static("/generated-videos", videoPath)
}

// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Filter property of the KeyAuth Configuration
router.Use(v2keyauth.New(*kaConfig))
// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Skipper property of the KeyAuth Configuration
e.Use(keyAuthMiddleware)

// CORS middleware
if application.ApplicationConfig().CORS {
var c func(ctx *fiber.Ctx) error
if application.ApplicationConfig().CORSAllowOrigins == "" {
c = cors.New()
} else {
c = cors.New(cors.Config{AllowOrigins: application.ApplicationConfig().CORSAllowOrigins})
corsConfig := middleware.CORSConfig{}
if application.ApplicationConfig().CORSAllowOrigins != "" {
corsConfig.AllowOrigins = strings.Split(application.ApplicationConfig().CORSAllowOrigins, ",")
}

router.Use(c)
e.Use(middleware.CORSWithConfig(corsConfig))
}

// CSRF middleware
if application.ApplicationConfig().CSRF {
log.Debug().Msg("Enabling CSRF middleware. Tokens are now required for state-modifying requests")
router.Use(csrf.New())
e.Use(middleware.CSRF())
}

requestExtractor := middleware.NewRequestExtractor(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
requestExtractor := httpMiddleware.NewRequestExtractor(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())

routes.RegisterElevenLabsRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())

routes.RegisterElevenLabsRoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())

// Create opcache for tracking UI operations (used by both UI and LocalAI routes)
var opcache *services.OpCache
if !application.ApplicationConfig().DisableWebUI {
opcache = services.NewOpCache(application.GalleryService())
}
routes.RegisterLocalAIRoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache)
routes.RegisterOpenAIRoutes(router, requestExtractor, application)

routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache)
routes.RegisterOpenAIRoutes(e, requestExtractor, application)
if !application.ApplicationConfig().DisableWebUI {
routes.RegisterUIAPIRoutes(router, application.ModelConfigLoader(), application.ApplicationConfig(), application.GalleryService(), opcache)
routes.RegisterUIRoutes(router, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService())
routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ApplicationConfig(), application.GalleryService(), opcache)
routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService())
}
routes.RegisterJINARoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
routes.RegisterJINARoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())

// Note: 404 handling is done via HTTPErrorHandler above, no need for catch-all route

// Define a custom 404 handler
// Note: keep this at the bottom!
router.Use(notFoundHandler)
// Log startup message
e.Server.RegisterOnShutdown(func() {
log.Info().Msg("LocalAI API server shutting down")
})

return router, nil
return e, nil
}
Loading
Loading