From 9e187b741b5efa142bb61bdec2ec45d606de4498 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Fri, 14 Nov 2025 13:09:05 +0100 Subject: [PATCH 1/2] WIP: migrate to echo Signed-off-by: Ettore Di Giacinto --- core/cli/explorer.go | 6 +- core/cli/run.go | 2 +- core/http/app.go | 223 ++-- core/http/app_test.go | 46 +- .../endpoints/elevenlabs/soundgeneration.go | 18 +- core/http/endpoints/elevenlabs/tts.go | 23 +- core/http/endpoints/explorer/dashboard.go | 46 +- core/http/endpoints/jina/rerank.go | 19 +- core/http/endpoints/localai/backend.go | 50 +- .../http/endpoints/localai/backend_monitor.go | 90 +- core/http/endpoints/localai/detection.go | 16 +- core/http/endpoints/localai/edit_model.go | 67 +- .../http/endpoints/localai/edit_model_test.go | 31 +- core/http/endpoints/localai/gallery.go | 320 +++--- .../endpoints/localai/get_token_metrics.go | 12 +- core/http/endpoints/localai/import_model.go | 56 +- core/http/endpoints/localai/metrics.go | 93 +- core/http/endpoints/localai/p2p.go | 12 +- core/http/endpoints/localai/stores.go | 42 +- core/http/endpoints/localai/system.go | 8 +- core/http/endpoints/localai/tokenize.go | 16 +- core/http/endpoints/localai/tts.go | 18 +- core/http/endpoints/localai/vad.go | 18 +- core/http/endpoints/localai/video.go | 19 +- core/http/endpoints/localai/welcome.go | 18 +- core/http/endpoints/openai/chat.go | 265 ++--- core/http/endpoints/openai/completion.go | 168 ++-- core/http/endpoints/openai/edit.go | 18 +- core/http/endpoints/openai/embeddings.go | 16 +- core/http/endpoints/openai/image.go | 19 +- core/http/endpoints/openai/list.go | 13 +- core/http/endpoints/openai/mcp.go | 25 +- core/http/endpoints/openai/realtime.go | 50 +- core/http/endpoints/openai/transcription.go | 18 +- core/http/endpoints/openai/video.go | 20 +- core/http/explorer.go | 66 +- core/http/middleware/auth.go | 162 ++- core/http/middleware/request.go | 952 +++++++++--------- core/http/middleware/strippathprefix.go | 50 +- 
core/http/middleware/strippathprefix_test.go | 22 +- core/http/openai_videos_test.go | 14 +- core/http/render.go | 67 +- core/http/routes/elevenlabs.go | 18 +- core/http/routes/explorer.go | 10 +- core/http/routes/health.go | 14 +- core/http/routes/jina.go | 11 +- core/http/routes/localai.go | 116 ++- core/http/routes/openai.go | 166 +-- core/http/routes/ui.go | 73 +- core/http/routes/ui_api.go | 175 ++-- core/http/routes/ui_backend_gallery.go | 10 +- core/http/routes/ui_gallery.go | 10 +- core/http/utils/baseurl.go | 30 +- core/http/utils/baseurl_test.go | 14 +- go.mod | 23 +- go.sum | 29 +- 56 files changed, 2089 insertions(+), 1824 deletions(-) diff --git a/core/cli/explorer.go b/core/cli/explorer.go index b12735c735b8..cfebee07d983 100644 --- a/core/cli/explorer.go +++ b/core/cli/explorer.go @@ -48,10 +48,12 @@ func (e *ExplorerCMD) Run(ctx *cliContext.Context) error { appHTTP := http.Explorer(db) signals.RegisterGracefulTerminationHandler(func() { - if err := appHTTP.Shutdown(); err != nil { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := appHTTP.Shutdown(ctx); err != nil { log.Error().Err(err).Msg("error during shutdown") } }) - return appHTTP.Listen(e.Address) + return appHTTP.Start(e.Address) } diff --git a/core/cli/run.go b/core/cli/run.go index 560b2d8f2d08..9c21eaa4b08f 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -232,5 +232,5 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error { } }) - return appHTTP.Listen(r.Address) + return appHTTP.Start(r.Address) } diff --git a/core/http/app.go b/core/http/app.go index 916ea9600b81..731e69df565c 100644 --- a/core/http/app.go +++ b/core/http/app.go @@ -4,30 +4,23 @@ import ( "embed" "errors" "fmt" + "io/fs" "net/http" "os" "path/filepath" + "strings" - "github.com/dave-gray101/v2keyauth" - "github.com/gofiber/websocket/v2" + "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" 
"github.com/mudler/LocalAI/core/http/endpoints/localai" - "github.com/mudler/LocalAI/core/http/middleware" + httpMiddleware "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/routes" "github.com/mudler/LocalAI/core/application" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/services" - "github.com/gofiber/contrib/fiberzerolog" - "github.com/gofiber/fiber/v2" - "github.com/gofiber/fiber/v2/middleware/cors" - "github.com/gofiber/fiber/v2/middleware/csrf" - "github.com/gofiber/fiber/v2/middleware/favicon" - "github.com/gofiber/fiber/v2/middleware/filesystem" - "github.com/gofiber/fiber/v2/middleware/recover" - - // swagger handler "github.com/rs/zerolog/log" ) @@ -49,85 +42,85 @@ var embedDirStatic embed.FS // @in header // @name Authorization -func API(application *application.Application) (*fiber.App, error) { +func API(application *application.Application) (*echo.Echo, error) { + e := echo.New() - fiberCfg := fiber.Config{ - Views: renderEngine(), - BodyLimit: application.ApplicationConfig().UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB - // We disable the Fiber startup message as it does not conform to structured logging. 
- // We register a startup log line with connection information in the OnListen hook to keep things user friendly though - DisableStartupMessage: true, - // Override default error handler + // Set body limit + if application.ApplicationConfig().UploadLimitMB > 0 { + e.Use(middleware.BodyLimit(fmt.Sprintf("%dM", application.ApplicationConfig().UploadLimitMB))) } + // Set error handler if !application.ApplicationConfig().OpaqueErrors { - // Normally, return errors as JSON responses - fiberCfg.ErrorHandler = func(ctx *fiber.Ctx, err error) error { - // Status code defaults to 500 - code := fiber.StatusInternalServerError - - // Retrieve the custom status code if it's a *fiber.Error - var e *fiber.Error - if errors.As(err, &e) { - code = e.Code + e.HTTPErrorHandler = func(err error, c echo.Context) { + code := http.StatusInternalServerError + var he *echo.HTTPError + if errors.As(err, &he) { + code = he.Code + } + + // Handle 404 errors with HTML rendering when appropriate + if code == http.StatusNotFound { + notFoundHandler(c) + return } // Send custom error page - return ctx.Status(code).JSON( - schema.ErrorResponse{ - Error: &schema.APIError{Message: err.Error(), Code: code}, - }, - ) + c.JSON(code, schema.ErrorResponse{ + Error: &schema.APIError{Message: err.Error(), Code: code}, + }) } } else { - // If OpaqueErrors are required, replace everything with a blank 500. 
- fiberCfg.ErrorHandler = func(ctx *fiber.Ctx, _ error) error { - return ctx.Status(500).SendString("") + e.HTTPErrorHandler = func(err error, c echo.Context) { + code := http.StatusInternalServerError + var he *echo.HTTPError + if errors.As(err, &he) { + code = he.Code + } + c.NoContent(code) } } - router := fiber.New(fiberCfg) + // Set renderer + e.Renderer = renderEngine() - router.Use(middleware.StripPathPrefix()) + // Hide banner + e.HideBanner = true - if application.ApplicationConfig().MachineTag != "" { - router.Use(func(c *fiber.Ctx) error { - c.Response().Header.Set("Machine-Tag", application.ApplicationConfig().MachineTag) + // Middleware - StripPathPrefix must be registered early as it uses Rewrite which runs before routing + e.Pre(httpMiddleware.StripPathPrefix()) - return c.Next() + if application.ApplicationConfig().MachineTag != "" { + e.Use(func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + c.Response().Header().Set("Machine-Tag", application.ApplicationConfig().MachineTag) + return next(c) + } }) } - router.Use("/v1/realtime", func(c *fiber.Ctx) error { - if websocket.IsWebSocketUpgrade(c) { - // Returns true if the client requested upgrade to the WebSocket protocol - return c.Next() - } - - return nil - }) - - router.Hooks().OnListen(func(listenData fiber.ListenData) error { - scheme := "http" - if listenData.TLS { - scheme = "https" + // Custom logger middleware using zerolog + e.Use(func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + req := c.Request() + res := c.Response() + start := log.Logger.Info() + err := next(c) + start. + Str("method", req.Method). + Str("path", req.URL.Path). + Int("status", res.Status). + Msg("HTTP request") + return err } - log.Info().Str("endpoint", scheme+"://"+listenData.Host+":"+listenData.Port).Msg("LocalAI API is listening! 
Please connect to the endpoint for API documentation.") - return nil }) - // Have Fiber use zerolog like the rest of the application rather than it's built-in logger - logger := log.Logger - router.Use(fiberzerolog.New(fiberzerolog.Config{ - Logger: &logger, - })) - - // Default middleware config - + // Recover middleware if !application.ApplicationConfig().Debug { - router.Use(recover.New()) + e.Use(middleware.Recover()) } + // Metrics middleware if !application.ApplicationConfig().DisableMetrics { metricsService, err := services.NewLocalAIMetricsService() if err != nil { @@ -135,34 +128,40 @@ func API(application *application.Application) (*fiber.App, error) { } if metricsService != nil { - router.Use(localai.LocalAIMetricsAPIMiddleware(metricsService)) - router.Hooks().OnShutdown(func() error { - return metricsService.Shutdown() + e.Use(localai.LocalAIMetricsAPIMiddleware(metricsService)) + e.Server.RegisterOnShutdown(func() { + metricsService.Shutdown() }) } } + // Health Checks should always be exempt from auth, so register these first - routes.HealthRoutes(router) + routes.HealthRoutes(e) - kaConfig, err := middleware.GetKeyAuthConfig(application.ApplicationConfig()) - if err != nil || kaConfig == nil { + // Get key auth middleware + keyAuthMiddleware, err := httpMiddleware.GetKeyAuthConfig(application.ApplicationConfig()) + if err != nil { return nil, fmt.Errorf("failed to create key auth config: %w", err) } - httpFS := http.FS(embedDirStatic) - - router.Use(favicon.New(favicon.Config{ - URL: "/favicon.svg", - FileSystem: httpFS, - File: "static/favicon.svg", - })) + // Favicon handler + e.GET("/favicon.svg", func(c echo.Context) error { + data, err := embedDirStatic.ReadFile("static/favicon.svg") + if err != nil { + return c.NoContent(http.StatusNotFound) + } + c.Response().Header().Set("Content-Type", "image/svg+xml") + return c.Blob(http.StatusOK, "image/svg+xml", data) + }) - router.Use("/static", filesystem.New(filesystem.Config{ - Root: httpFS, - 
PathPrefix: "static", - Browse: true, - })) + // Static files - use fs.Sub to create a filesystem rooted at "static" + staticFS, err := fs.Sub(embedDirStatic, "static") + if err != nil { + return nil, fmt.Errorf("failed to create static filesystem: %w", err) + } + e.StaticFS("/static", staticFS) + // Generated content directories if application.ApplicationConfig().GeneratedContentDir != "" { os.MkdirAll(application.ApplicationConfig().GeneratedContentDir, 0750) audioPath := filepath.Join(application.ApplicationConfig().GeneratedContentDir, "audio") @@ -173,51 +172,53 @@ func API(application *application.Application) (*fiber.App, error) { os.MkdirAll(imagePath, 0750) os.MkdirAll(videoPath, 0750) - router.Static("/generated-audio", audioPath) - router.Static("/generated-images", imagePath) - router.Static("/generated-videos", videoPath) + e.Static("/generated-audio", audioPath) + e.Static("/generated-images", imagePath) + e.Static("/generated-videos", videoPath) } - // Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Filter property of the KeyAuth Configuration - router.Use(v2keyauth.New(*kaConfig)) + // Auth is applied to _all_ endpoints. No exceptions. 
Filtering out endpoints to bypass is the role of the Skipper property of the KeyAuth Configuration + e.Use(keyAuthMiddleware) + // CORS middleware if application.ApplicationConfig().CORS { - var c func(ctx *fiber.Ctx) error - if application.ApplicationConfig().CORSAllowOrigins == "" { - c = cors.New() - } else { - c = cors.New(cors.Config{AllowOrigins: application.ApplicationConfig().CORSAllowOrigins}) + corsConfig := middleware.CORSConfig{} + if application.ApplicationConfig().CORSAllowOrigins != "" { + corsConfig.AllowOrigins = strings.Split(application.ApplicationConfig().CORSAllowOrigins, ",") } - - router.Use(c) + e.Use(middleware.CORSWithConfig(corsConfig)) } + // CSRF middleware if application.ApplicationConfig().CSRF { log.Debug().Msg("Enabling CSRF middleware. Tokens are now required for state-modifying requests") - router.Use(csrf.New()) + e.Use(middleware.CSRF()) } - requestExtractor := middleware.NewRequestExtractor(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()) + requestExtractor := httpMiddleware.NewRequestExtractor(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()) + + routes.RegisterElevenLabsRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()) - routes.RegisterElevenLabsRoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()) - // Create opcache for tracking UI operations (used by both UI and LocalAI routes) var opcache *services.OpCache if !application.ApplicationConfig().DisableWebUI { opcache = services.NewOpCache(application.GalleryService()) } - - routes.RegisterLocalAIRoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache) - routes.RegisterOpenAIRoutes(router, requestExtractor, application) + + 
routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache) + routes.RegisterOpenAIRoutes(e, requestExtractor, application) if !application.ApplicationConfig().DisableWebUI { - routes.RegisterUIAPIRoutes(router, application.ModelConfigLoader(), application.ApplicationConfig(), application.GalleryService(), opcache) - routes.RegisterUIRoutes(router, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService()) + routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ApplicationConfig(), application.GalleryService(), opcache) + routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService()) } - routes.RegisterJINARoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()) + routes.RegisterJINARoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()) + + // Note: 404 handling is done via HTTPErrorHandler above, no need for catch-all route - // Define a custom 404 handler - // Note: keep this at the bottom! - router.Use(notFoundHandler) + // Log startup message + e.Server.RegisterOnShutdown(func() { + log.Info().Msg("LocalAI API server shutting down") + }) - return router, nil + return e, nil } diff --git a/core/http/app_test.go b/core/http/app_test.go index 2d4ff6d06571..bce9c56903bb 100644 --- a/core/http/app_test.go +++ b/core/http/app_test.go @@ -10,13 +10,14 @@ import ( "os" "path/filepath" "runtime" + "time" "github.com/mudler/LocalAI/core/application" "github.com/mudler/LocalAI/core/config" . 
"github.com/mudler/LocalAI/core/http" "github.com/mudler/LocalAI/core/schema" - "github.com/gofiber/fiber/v2" + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/system" @@ -25,6 +26,7 @@ import ( "gopkg.in/yaml.v3" openaigo "github.com/otiai10/openaigo" + "github.com/rs/zerolog/log" "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/jsonschema" ) @@ -266,7 +268,7 @@ const bertEmbeddingsURL = `https://gist.githubusercontent.com/mudler/0a080b166b8 var _ = Describe("API test", func() { - var app *fiber.App + var app *echo.Echo var client *openai.Client var client2 *openaigo.Client var c context.Context @@ -339,7 +341,11 @@ var _ = Describe("API test", func() { app, err = API(application) Expect(err).ToNot(HaveOccurred()) - go app.Listen("127.0.0.1:9090") + go func() { + if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed { + log.Error().Err(err).Msg("server error") + } + }() defaultConfig := openai.DefaultConfig(apiKey) defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" @@ -358,7 +364,9 @@ var _ = Describe("API test", func() { AfterEach(func(sc SpecContext) { cancel() if app != nil { - err := app.Shutdown() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err := app.Shutdown(ctx) Expect(err).ToNot(HaveOccurred()) } err := os.RemoveAll(tmpdir) @@ -547,7 +555,11 @@ var _ = Describe("API test", func() { app, err = API(application) Expect(err).ToNot(HaveOccurred()) - go app.Listen("127.0.0.1:9090") + go func() { + if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed { + log.Error().Err(err).Msg("server error") + } + }() defaultConfig := openai.DefaultConfig("") defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" @@ -566,7 +578,9 @@ var _ = Describe("API test", func() { AfterEach(func() { cancel() if app != nil { - err := app.Shutdown() + ctx, cancel := 
context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err := app.Shutdown(ctx) Expect(err).ToNot(HaveOccurred()) } err := os.RemoveAll(tmpdir) @@ -755,7 +769,11 @@ var _ = Describe("API test", func() { Expect(err).ToNot(HaveOccurred()) app, err = API(application) Expect(err).ToNot(HaveOccurred()) - go app.Listen("127.0.0.1:9090") + go func() { + if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed { + log.Error().Err(err).Msg("server error") + } + }() defaultConfig := openai.DefaultConfig("") defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" @@ -773,7 +791,9 @@ var _ = Describe("API test", func() { AfterEach(func() { cancel() if app != nil { - err := app.Shutdown() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err := app.Shutdown(ctx) Expect(err).ToNot(HaveOccurred()) } }) @@ -1006,7 +1026,11 @@ var _ = Describe("API test", func() { app, err = API(application) Expect(err).ToNot(HaveOccurred()) - go app.Listen("127.0.0.1:9090") + go func() { + if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed { + log.Error().Err(err).Msg("server error") + } + }() defaultConfig := openai.DefaultConfig("") defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" @@ -1022,7 +1046,9 @@ var _ = Describe("API test", func() { AfterEach(func() { cancel() if app != nil { - err := app.Shutdown() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err := app.Shutdown(ctx) Expect(err).ToNot(HaveOccurred()) } }) diff --git a/core/http/endpoints/elevenlabs/soundgeneration.go b/core/http/endpoints/elevenlabs/soundgeneration.go index 53da894e1c9b..4f6cfab8b65c 100644 --- a/core/http/endpoints/elevenlabs/soundgeneration.go +++ b/core/http/endpoints/elevenlabs/soundgeneration.go @@ -1,7 +1,9 @@ package elevenlabs import ( - "github.com/gofiber/fiber/v2" + "path/filepath" + + "github.com/labstack/echo/v4" 
"github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" @@ -15,17 +17,17 @@ import ( // @Param request body schema.ElevenLabsSoundGenerationRequest true "query params" // @Success 200 {string} binary "Response" // @Router /v1/sound-generation [post] -func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { +func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { - input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest) + input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest) if !ok || input.ModelID == "" { - return fiber.ErrBadRequest + return echo.ErrBadRequest } - cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || cfg == nil { - return fiber.ErrBadRequest + return echo.ErrBadRequest } log.Debug().Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Sound Generation Request about to be sent to backend") @@ -35,7 +37,7 @@ func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader if err != nil { return err } - return c.Download(filePath) + return c.Attachment(filePath, filepath.Base(filePath)) } } diff --git a/core/http/endpoints/elevenlabs/tts.go b/core/http/endpoints/elevenlabs/tts.go index dac5de70bf05..dc1356c767c6 100644 --- a/core/http/endpoints/elevenlabs/tts.go +++ b/core/http/endpoints/elevenlabs/tts.go @@ -1,13 +1,14 @@ package elevenlabs import ( + "path/filepath" + + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" 
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" - "github.com/mudler/LocalAI/pkg/model" - - "github.com/gofiber/fiber/v2" "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/pkg/model" "github.com/rs/zerolog/log" ) @@ -17,19 +18,19 @@ import ( // @Param request body schema.TTSRequest true "query params" // @Success 200 {string} binary "Response" // @Router /v1/text-to-speech/{voice-id} [post] -func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { +func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { - voiceID := c.Params("voice-id") + voiceID := c.Param("voice-id") - input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsTTSRequest) + input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsTTSRequest) if !ok || input.ModelID == "" { - return fiber.ErrBadRequest + return echo.ErrBadRequest } - cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || cfg == nil { - return fiber.ErrBadRequest + return echo.ErrBadRequest } log.Debug().Str("modelName", input.ModelID).Msg("elevenlabs TTS request received") @@ -38,6 +39,6 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig if err != nil { return err } - return c.Download(filePath) + return c.Attachment(filePath, filepath.Base(filePath)) } } diff --git a/core/http/endpoints/explorer/dashboard.go b/core/http/endpoints/explorer/dashboard.go index 3c8966819c9c..34b57fc6657c 100644 --- a/core/http/endpoints/explorer/dashboard.go +++ b/core/http/endpoints/explorer/dashboard.go @@ -2,28 +2,32 @@ package explorer import ( 
"encoding/base64" + "net/http" "sort" + "strings" - "github.com/gofiber/fiber/v2" + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/explorer" "github.com/mudler/LocalAI/core/http/utils" "github.com/mudler/LocalAI/internal" ) -func Dashboard() func(*fiber.Ctx) error { - return func(c *fiber.Ctx) error { - summary := fiber.Map{ +func Dashboard() echo.HandlerFunc { + return func(c echo.Context) error { + summary := map[string]interface{}{ "Title": "LocalAI API - " + internal.PrintableVersion(), "Version": internal.PrintableVersion(), "BaseURL": utils.BaseURL(c), } - if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 { + contentType := c.Request().Header.Get("Content-Type") + accept := c.Request().Header.Get("Accept") + if strings.Contains(contentType, "application/json") || (accept != "" && !strings.Contains(accept, "html")) { // The client expects a JSON response - return c.Status(fiber.StatusOK).JSON(summary) + return c.JSON(http.StatusOK, summary) } else { // Render index - return c.Render("views/explorer", summary) + return c.Render(http.StatusOK, "views/explorer", summary) } } } @@ -39,8 +43,8 @@ type Network struct { Token string `json:"token"` } -func ShowNetworks(db *explorer.Database) func(*fiber.Ctx) error { - return func(c *fiber.Ctx) error { +func ShowNetworks(db *explorer.Database) echo.HandlerFunc { + return func(c echo.Context) error { results := []Network{} for _, token := range db.TokenList() { networkData, exists := db.Get(token) // get the token data @@ -61,44 +65,44 @@ func ShowNetworks(db *explorer.Database) func(*fiber.Ctx) error { return len(results[i].Clusters) > len(results[j].Clusters) }) - return c.JSON(results) + return c.JSON(http.StatusOK, results) } } -func AddNetwork(db *explorer.Database) func(*fiber.Ctx) error { - return func(c *fiber.Ctx) error { +func AddNetwork(db *explorer.Database) echo.HandlerFunc { + return func(c echo.Context) error { request := 
new(AddNetworkRequest) - if err := c.BodyParser(request); err != nil { - return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"}) + if err := c.Bind(request); err != nil { + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Cannot parse JSON"}) } if request.Token == "" { - return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Token is required"}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token is required"}) } if request.Name == "" { - return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Name is required"}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Name is required"}) } if request.Description == "" { - return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Description is required"}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Description is required"}) } // TODO: check if token is valid, otherwise reject // try to decode the token from base64 _, err := base64.StdEncoding.DecodeString(request.Token) if err != nil { - return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid token"}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid token"}) } if _, exists := db.Get(request.Token); exists { - return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Token already exists"}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token already exists"}) } err = db.Set(request.Token, explorer.TokenData{Name: request.Name, Description: request.Description}) if err != nil { - return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Cannot add token"}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Cannot add token"}) } - return c.Status(fiber.StatusOK).JSON(fiber.Map{"message": "Token added"}) + return c.JSON(http.StatusOK, map[string]interface{}{"message": "Token added"}) } } diff --git 
a/core/http/endpoints/jina/rerank.go b/core/http/endpoints/jina/rerank.go index 7d9270247f4a..75d3de31410e 100644 --- a/core/http/endpoints/jina/rerank.go +++ b/core/http/endpoints/jina/rerank.go @@ -1,11 +1,12 @@ package jina import ( + "net/http" + + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" - - "github.com/gofiber/fiber/v2" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/grpc/proto" "github.com/mudler/LocalAI/pkg/model" @@ -17,17 +18,17 @@ import ( // @Param request body schema.JINARerankRequest true "query params" // @Success 200 {object} schema.JINARerankResponse "Response" // @Router /v1/rerank [post] -func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { +func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { - input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest) + input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest) if !ok || input.Model == "" { - return fiber.ErrBadRequest + return echo.ErrBadRequest } - cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || cfg == nil { - return fiber.ErrBadRequest + return echo.ErrBadRequest } log.Debug().Str("model", input.Model).Msg("JINA Rerank Request received") @@ -58,6 +59,6 @@ func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app response.Usage.TotalTokens = int(results.Usage.TotalTokens) response.Usage.PromptTokens = int(results.Usage.PromptTokens) - return c.Status(fiber.StatusOK).JSON(response) + return 
c.JSON(http.StatusOK, response) } } diff --git a/core/http/endpoints/localai/backend.go b/core/http/endpoints/localai/backend.go index e2f5f5635a21..80f47f658431 100644 --- a/core/http/endpoints/localai/backend.go +++ b/core/http/endpoints/localai/backend.go @@ -4,7 +4,7 @@ import ( "encoding/json" "fmt" - "github.com/gofiber/fiber/v2" + "github.com/labstack/echo/v4" "github.com/google/uuid" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" @@ -39,13 +39,13 @@ func CreateBackendEndpointService(galleries []config.Gallery, systemState *syste // @Summary Returns the job status // @Success 200 {object} services.GalleryOpStatus "Response" // @Router /backends/jobs/{uuid} [get] -func (mgs *BackendEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - status := mgs.backendApplier.GetStatus(c.Params("uuid")) +func (mgs *BackendEndpointService) GetOpStatusEndpoint() echo.HandlerFunc { + return func(c echo.Context) error { + status := mgs.backendApplier.GetStatus(c.Param("uuid")) if status == nil { return fmt.Errorf("could not find any status for ID") } - return c.JSON(status) + return c.JSON(200, status) } } @@ -53,9 +53,9 @@ func (mgs *BackendEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) erro // @Summary Returns all the jobs status progress // @Success 200 {object} map[string]services.GalleryOpStatus "Response" // @Router /backends/jobs [get] -func (mgs *BackendEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - return c.JSON(mgs.backendApplier.GetAllStatus()) +func (mgs *BackendEndpointService) GetAllStatusEndpoint() echo.HandlerFunc { + return func(c echo.Context) error { + return c.JSON(200, mgs.backendApplier.GetAllStatus()) } } @@ -64,11 +64,11 @@ func (mgs *BackendEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) err // @Param request body GalleryBackend true "query params" // @Success 200 {object} 
schema.BackendResponse "Response" // @Router /backends/apply [post] -func (mgs *BackendEndpointService) ApplyBackendEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { +func (mgs *BackendEndpointService) ApplyBackendEndpoint() echo.HandlerFunc { + return func(c echo.Context) error { input := new(GalleryBackend) // Get input data from the request body - if err := c.BodyParser(input); err != nil { + if err := c.Bind(input); err != nil { return err } @@ -82,7 +82,7 @@ func (mgs *BackendEndpointService) ApplyBackendEndpoint() func(c *fiber.Ctx) err Galleries: mgs.galleries, } - return c.JSON(schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", utils.BaseURL(c), uuid.String())}) + return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", utils.BaseURL(c), uuid.String())}) } } @@ -91,9 +91,9 @@ func (mgs *BackendEndpointService) ApplyBackendEndpoint() func(c *fiber.Ctx) err // @Param name path string true "Backend name" // @Success 200 {object} schema.BackendResponse "Response" // @Router /backends/delete/{name} [post] -func (mgs *BackendEndpointService) DeleteBackendEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - backendName := c.Params("name") +func (mgs *BackendEndpointService) DeleteBackendEndpoint() echo.HandlerFunc { + return func(c echo.Context) error { + backendName := c.Param("name") mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend, any]{ Delete: true, @@ -106,7 +106,7 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() func(c *fiber.Ctx) er return err } - return c.JSON(schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", utils.BaseURL(c), uuid.String())}) + return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", utils.BaseURL(c), uuid.String())}) } } @@ -114,13 +114,13 @@ func (mgs 
*BackendEndpointService) DeleteBackendEndpoint() func(c *fiber.Ctx) er // @Summary List all Backends // @Success 200 {object} []gallery.GalleryBackend "Response" // @Router /backends [get] -func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { +func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.SystemState) echo.HandlerFunc { + return func(c echo.Context) error { backends, err := gallery.ListSystemBackends(systemState) if err != nil { return err } - return c.JSON(backends.GetAll()) + return c.JSON(200, backends.GetAll()) } } @@ -129,14 +129,14 @@ func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.Syst // @Success 200 {object} []config.Gallery "Response" // @Router /backends/galleries [get] // NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents! -func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { +func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() echo.HandlerFunc { + return func(c echo.Context) error { log.Debug().Msgf("Listing backend galleries %+v", mgs.galleries) dat, err := json.Marshal(mgs.galleries) if err != nil { return err } - return c.Send(dat) + return c.Blob(200, "application/json", dat) } } @@ -144,12 +144,12 @@ func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() func(c *fiber. 
// @Summary List all available Backends // @Success 200 {object} []gallery.GalleryBackend "Response" // @Router /backends/available [get] -func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { +func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint(systemState *system.SystemState) echo.HandlerFunc { + return func(c echo.Context) error { backends, err := gallery.AvailableBackends(mgs.galleries, systemState) if err != nil { return err } - return c.JSON(backends) + return c.JSON(200, backends) } } diff --git a/core/http/endpoints/localai/backend_monitor.go b/core/http/endpoints/localai/backend_monitor.go index a1b93ac38bb9..18016c579220 100644 --- a/core/http/endpoints/localai/backend_monitor.go +++ b/core/http/endpoints/localai/backend_monitor.go @@ -1,45 +1,45 @@ -package localai - -import ( - "github.com/gofiber/fiber/v2" - "github.com/mudler/LocalAI/core/schema" - "github.com/mudler/LocalAI/core/services" -) - -// BackendMonitorEndpoint returns the status of the specified backend -// @Summary Backend monitor endpoint -// @Param request body schema.BackendMonitorRequest true "Backend statistics request" -// @Success 200 {object} proto.StatusResponse "Response" -// @Router /backend/monitor [get] -func BackendMonitorEndpoint(bm *services.BackendMonitorService) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - - input := new(schema.BackendMonitorRequest) - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return err - } - - resp, err := bm.CheckAndSample(input.Model) - if err != nil { - return err - } - return c.JSON(resp) - } -} - -// BackendShutdownEndpoint shuts down the specified backend -// @Summary Backend monitor endpoint -// @Param request body schema.BackendMonitorRequest true "Backend statistics request" -// @Router /backend/shutdown [post] -func BackendShutdownEndpoint(bm 
*services.BackendMonitorService) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - input := new(schema.BackendMonitorRequest) - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return err - } - - return bm.ShutdownModel(input.Model) - } -} +package localai + +import ( + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/services" +) + +// BackendMonitorEndpoint returns the status of the specified backend +// @Summary Backend monitor endpoint +// @Param request body schema.BackendMonitorRequest true "Backend statistics request" +// @Success 200 {object} proto.StatusResponse "Response" +// @Router /backend/monitor [get] +func BackendMonitorEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc { + return func(c echo.Context) error { + + input := new(schema.BackendMonitorRequest) + // Get input data from the request body + if err := c.Bind(input); err != nil { + return err + } + + resp, err := bm.CheckAndSample(input.Model) + if err != nil { + return err + } + return c.JSON(200, resp) + } +} + +// BackendShutdownEndpoint shuts down the specified backend +// @Summary Backend monitor endpoint +// @Param request body schema.BackendMonitorRequest true "Backend statistics request" +// @Router /backend/shutdown [post] +func BackendShutdownEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc { + return func(c echo.Context) error { + input := new(schema.BackendMonitorRequest) + // Get input data from the request body + if err := c.Bind(input); err != nil { + return err + } + + return bm.ShutdownModel(input.Model) + } +} diff --git a/core/http/endpoints/localai/detection.go b/core/http/endpoints/localai/detection.go index c4ab249110fe..796e8f852d21 100644 --- a/core/http/endpoints/localai/detection.go +++ b/core/http/endpoints/localai/detection.go @@ -1,7 +1,7 @@ package localai import ( - "github.com/gofiber/fiber/v2" + "github.com/labstack/echo/v4" 
"github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" @@ -16,17 +16,17 @@ import ( // @Param request body schema.DetectionRequest true "query params" // @Success 200 {object} schema.DetectionResponse "Response" // @Router /v1/detection [post] -func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { +func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { - input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.DetectionRequest) + input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.DetectionRequest) if !ok || input.Model == "" { - return fiber.ErrBadRequest + return echo.ErrBadRequest } - cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || cfg == nil { - return fiber.ErrBadRequest + return echo.ErrBadRequest } log.Debug().Str("image", input.Image).Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Detection") @@ -54,6 +54,6 @@ func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appC } } - return c.JSON(response) + return c.JSON(200, response) } } diff --git a/core/http/endpoints/localai/edit_model.go b/core/http/endpoints/localai/edit_model.go index 82500749f652..697238cf0e72 100644 --- a/core/http/endpoints/localai/edit_model.go +++ b/core/http/endpoints/localai/edit_model.go @@ -2,9 +2,11 @@ package localai import ( "fmt" + "io" + "net/http" "os" - "github.com/gofiber/fiber/v2" + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" httpUtils "github.com/mudler/LocalAI/core/http/utils" "github.com/mudler/LocalAI/internal" @@ -14,15 
+16,15 @@ import ( ) // GetEditModelPage renders the edit model page with current configuration -func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler { - return func(c *fiber.Ctx) error { - modelName := c.Params("name") +func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { + modelName := c.Param("name") if modelName == "" { response := ModelResponse{ Success: false, Error: "Model name is required", } - return c.Status(400).JSON(response) + return c.JSON(http.StatusBadRequest, response) } modelConfig, exists := cl.GetModelConfig(modelName) @@ -31,7 +33,7 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio Success: false, Error: "Model configuration not found", } - return c.Status(404).JSON(response) + return c.JSON(http.StatusNotFound, response) } modelConfigFile := modelConfig.GetModelConfigFile() @@ -40,7 +42,7 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio Success: false, Error: "Model configuration file not found", } - return c.Status(404).JSON(response) + return c.JSON(http.StatusNotFound, response) } configData, err := os.ReadFile(modelConfigFile) if err != nil { @@ -48,7 +50,7 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio Success: false, Error: "Failed to read configuration file: " + err.Error(), } - return c.Status(500).JSON(response) + return c.JSON(http.StatusInternalServerError, response) } // Render the edit page with the current configuration @@ -69,20 +71,20 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio Version: internal.PrintableVersion(), } - return c.Render("views/model-editor", templateData) + return c.Render(http.StatusOK, "views/model-editor", templateData) } } // EditModelEndpoint handles updating existing model configurations -func EditModelEndpoint(cl 
*config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler { - return func(c *fiber.Ctx) error { - modelName := c.Params("name") +func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { + modelName := c.Param("name") if modelName == "" { response := ModelResponse{ Success: false, Error: "Model name is required", } - return c.Status(400).JSON(response) + return c.JSON(http.StatusBadRequest, response) } modelConfig, exists := cl.GetModelConfig(modelName) @@ -91,17 +93,24 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati Success: false, Error: "Existing model configuration not found", } - return c.Status(404).JSON(response) + return c.JSON(http.StatusNotFound, response) } // Get the raw body - body := c.Body() + body, err := io.ReadAll(c.Request().Body) + if err != nil { + response := ModelResponse{ + Success: false, + Error: "Failed to read request body: " + err.Error(), + } + return c.JSON(http.StatusBadRequest, response) + } if len(body) == 0 { response := ModelResponse{ Success: false, Error: "Request body is empty", } - return c.Status(400).JSON(response) + return c.JSON(http.StatusBadRequest, response) } // Check content to see if it's a valid model config @@ -113,7 +122,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati Success: false, Error: "Failed to parse YAML: " + err.Error(), } - return c.Status(400).JSON(response) + return c.JSON(http.StatusBadRequest, response) } // Validate required fields @@ -122,7 +131,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati Success: false, Error: "Name is required", } - return c.Status(400).JSON(response) + return c.JSON(http.StatusBadRequest, response) } // Validate the configuration @@ -132,7 +141,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati Error: "Validation failed", 
Details: []string{"Configuration validation failed. Please check your YAML syntax and required fields."}, } - return c.Status(400).JSON(response) + return c.JSON(http.StatusBadRequest, response) } // Load the existing configuration @@ -142,7 +151,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati Success: false, Error: "Model configuration not trusted: " + err.Error(), } - return c.Status(404).JSON(response) + return c.JSON(http.StatusNotFound, response) } // Write new content to file @@ -151,7 +160,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati Success: false, Error: "Failed to write configuration file: " + err.Error(), } - return c.Status(500).JSON(response) + return c.JSON(http.StatusInternalServerError, response) } // Reload configurations @@ -160,7 +169,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati Success: false, Error: "Failed to reload configurations: " + err.Error(), } - return c.Status(500).JSON(response) + return c.JSON(http.StatusInternalServerError, response) } // Preload the model @@ -169,7 +178,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati Success: false, Error: "Failed to preload model: " + err.Error(), } - return c.Status(500).JSON(response) + return c.JSON(http.StatusInternalServerError, response) } // Return success response @@ -179,20 +188,20 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati Filename: configPath, Config: req, } - return c.JSON(response) + return c.JSON(200, response) } } // ReloadModelsEndpoint handles reloading model configurations from disk -func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler { - return func(c *fiber.Ctx) error { +func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { // Reload 
configurations if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil { response := ModelResponse{ Success: false, Error: "Failed to reload configurations: " + err.Error(), } - return c.Status(500).JSON(response) + return c.JSON(http.StatusInternalServerError, response) } // Preload the models @@ -201,7 +210,7 @@ func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applic Success: false, Error: "Failed to preload models: " + err.Error(), } - return c.Status(500).JSON(response) + return c.JSON(http.StatusInternalServerError, response) } // Return success response @@ -209,6 +218,6 @@ func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applic Success: true, Message: "Model configurations reloaded successfully", } - return c.Status(fiber.StatusOK).JSON(response) + return c.JSON(http.StatusOK, response) } } diff --git a/core/http/endpoints/localai/edit_model_test.go b/core/http/endpoints/localai/edit_model_test.go index 813e91301d89..6e4c7bf936f3 100644 --- a/core/http/endpoints/localai/edit_model_test.go +++ b/core/http/endpoints/localai/edit_model_test.go @@ -3,11 +3,12 @@ package localai_test import ( "bytes" "io" + "net/http" "net/http/httptest" "os" "path/filepath" - "github.com/gofiber/fiber/v2" + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" . "github.com/mudler/LocalAI/core/http/endpoints/localai" "github.com/mudler/LocalAI/pkg/system" @@ -40,33 +41,33 @@ var _ = Describe("Edit Model test", func() { //modelLoader := model.NewModelLoader(systemState, true) modelConfigLoader := config.NewModelConfigLoader(systemState.Model.ModelsPath) - // Define Fiber app. - app := fiber.New() - app.Put("/import-model", ImportModelEndpoint(modelConfigLoader, applicationConfig)) + // Define Echo app. 
+ app := echo.New() + app.POST("/import-model", ImportModelEndpoint(modelConfigLoader, applicationConfig)) requestBody := bytes.NewBufferString(`{"name": "foo", "backend": "foo", "model": "foo"}`) - req := httptest.NewRequest("PUT", "/import-model", requestBody) - resp, err := app.Test(req, 5000) - Expect(err).ToNot(HaveOccurred()) + req := httptest.NewRequest("POST", "/import-model", requestBody) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) - body, err := io.ReadAll(resp.Body) - defer resp.Body.Close() + body, err := io.ReadAll(rec.Body) Expect(err).ToNot(HaveOccurred()) Expect(string(body)).To(ContainSubstring("Model configuration created successfully")) - Expect(resp.StatusCode).To(Equal(fiber.StatusOK)) + Expect(rec.Code).To(Equal(http.StatusOK)) - app.Get("/edit-model/:name", EditModelEndpoint(modelConfigLoader, applicationConfig)) + app.GET("/edit-model/:name", GetEditModelPage(modelConfigLoader, applicationConfig)) requestBody = bytes.NewBufferString(`{"name": "foo", "parameters": { "model": "foo"}}`) req = httptest.NewRequest("GET", "/edit-model/foo", requestBody) - resp, _ = app.Test(req, 1) + rec = httptest.NewRecorder() + app.ServeHTTP(rec, req) - body, err = io.ReadAll(resp.Body) - defer resp.Body.Close() + body, err = io.ReadAll(rec.Body) Expect(err).ToNot(HaveOccurred()) Expect(string(body)).To(ContainSubstring(`"model":"foo"`)) - Expect(resp.StatusCode).To(Equal(fiber.StatusOK)) + Expect(rec.Code).To(Equal(http.StatusOK)) }) }) }) diff --git a/core/http/endpoints/localai/gallery.go b/core/http/endpoints/localai/gallery.go index d8079bcc5128..1938a9eb5bed 100644 --- a/core/http/endpoints/localai/gallery.go +++ b/core/http/endpoints/localai/gallery.go @@ -1,160 +1,160 @@ -package localai - -import ( - "encoding/json" - "fmt" - - "github.com/gofiber/fiber/v2" - "github.com/google/uuid" - "github.com/mudler/LocalAI/core/config" - "github.com/mudler/LocalAI/core/gallery" - 
"github.com/mudler/LocalAI/core/http/utils" - "github.com/mudler/LocalAI/core/schema" - "github.com/mudler/LocalAI/core/services" - "github.com/mudler/LocalAI/pkg/system" - "github.com/rs/zerolog/log" -) - -type ModelGalleryEndpointService struct { - galleries []config.Gallery - backendGalleries []config.Gallery - modelPath string - galleryApplier *services.GalleryService -} - -type GalleryModel struct { - ID string `json:"id"` - gallery.GalleryModel -} - -func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, systemState *system.SystemState, galleryApplier *services.GalleryService) ModelGalleryEndpointService { - return ModelGalleryEndpointService{ - galleries: galleries, - backendGalleries: backendGalleries, - modelPath: systemState.Model.ModelsPath, - galleryApplier: galleryApplier, - } -} - -// GetOpStatusEndpoint returns the job status -// @Summary Returns the job status -// @Success 200 {object} services.GalleryOpStatus "Response" -// @Router /models/jobs/{uuid} [get] -func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - status := mgs.galleryApplier.GetStatus(c.Params("uuid")) - if status == nil { - return fmt.Errorf("could not find any status for ID") - } - return c.JSON(status) - } -} - -// GetAllStatusEndpoint returns all the jobs status progress -// @Summary Returns all the jobs status progress -// @Success 200 {object} map[string]services.GalleryOpStatus "Response" -// @Router /models/jobs [get] -func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - return c.JSON(mgs.galleryApplier.GetAllStatus()) - } -} - -// ApplyModelGalleryEndpoint installs a new model to a LocalAI instance from the model gallery -// @Summary Install models to LocalAI. 
-// @Param request body GalleryModel true "query params" -// @Success 200 {object} schema.GalleryResponse "Response" -// @Router /models/apply [post] -func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - input := new(GalleryModel) - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return err - } - - uuid, err := uuid.NewUUID() - if err != nil { - return err - } - mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{ - Req: input.GalleryModel, - ID: uuid.String(), - GalleryElementName: input.ID, - Galleries: mgs.galleries, - BackendGalleries: mgs.backendGalleries, - } - - return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", utils.BaseURL(c), uuid.String())}) - } -} - -// DeleteModelGalleryEndpoint lets delete models from a LocalAI instance -// @Summary delete models to LocalAI. -// @Param name path string true "Model name" -// @Success 200 {object} schema.GalleryResponse "Response" -// @Router /models/delete/{name} [post] -func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - modelName := c.Params("name") - - mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{ - Delete: true, - GalleryElementName: modelName, - } - - uuid, err := uuid.NewUUID() - if err != nil { - return err - } - - return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", utils.BaseURL(c), uuid.String())}) - } -} - -// ListModelFromGalleryEndpoint list the available models for installation from the active galleries -// @Summary List installable models. 
-// @Success 200 {object} []gallery.GalleryModel "Response" -// @Router /models/available [get] -func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - - models, err := gallery.AvailableGalleryModels(mgs.galleries, systemState) - if err != nil { - log.Error().Err(err).Msg("could not list models from galleries") - return err - } - - log.Debug().Msgf("Available %d models from %d galleries\n", len(models), len(mgs.galleries)) - - m := []gallery.Metadata{} - - for _, mm := range models { - m = append(m, mm.Metadata) - } - - log.Debug().Msgf("Models %#v", m) - - dat, err := json.Marshal(m) - if err != nil { - return fmt.Errorf("could not marshal models: %w", err) - } - return c.Send(dat) - } -} - -// ListModelGalleriesEndpoint list the available galleries configured in LocalAI -// @Summary List all Galleries -// @Success 200 {object} []config.Gallery "Response" -// @Router /models/galleries [get] -// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents! 
-func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - log.Debug().Msgf("Listing model galleries %+v", mgs.galleries) - dat, err := json.Marshal(mgs.galleries) - if err != nil { - return err - } - return c.Send(dat) - } -} +package localai + +import ( + "encoding/json" + "fmt" + + "github.com/labstack/echo/v4" + "github.com/google/uuid" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/gallery" + "github.com/mudler/LocalAI/core/http/utils" + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/services" + "github.com/mudler/LocalAI/pkg/system" + "github.com/rs/zerolog/log" +) + +type ModelGalleryEndpointService struct { + galleries []config.Gallery + backendGalleries []config.Gallery + modelPath string + galleryApplier *services.GalleryService +} + +type GalleryModel struct { + ID string `json:"id"` + gallery.GalleryModel +} + +func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, systemState *system.SystemState, galleryApplier *services.GalleryService) ModelGalleryEndpointService { + return ModelGalleryEndpointService{ + galleries: galleries, + backendGalleries: backendGalleries, + modelPath: systemState.Model.ModelsPath, + galleryApplier: galleryApplier, + } +} + +// GetOpStatusEndpoint returns the job status +// @Summary Returns the job status +// @Success 200 {object} services.GalleryOpStatus "Response" +// @Router /models/jobs/{uuid} [get] +func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() echo.HandlerFunc { + return func(c echo.Context) error { + status := mgs.galleryApplier.GetStatus(c.Param("uuid")) + if status == nil { + return fmt.Errorf("could not find any status for ID") + } + return c.JSON(200, status) + } +} + +// GetAllStatusEndpoint returns all the jobs status progress +// @Summary Returns all the jobs status progress +// @Success 200 {object} 
map[string]services.GalleryOpStatus "Response" +// @Router /models/jobs [get] +func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() echo.HandlerFunc { + return func(c echo.Context) error { + return c.JSON(200, mgs.galleryApplier.GetAllStatus()) + } +} + +// ApplyModelGalleryEndpoint installs a new model to a LocalAI instance from the model gallery +// @Summary Install models to LocalAI. +// @Param request body GalleryModel true "query params" +// @Success 200 {object} schema.GalleryResponse "Response" +// @Router /models/apply [post] +func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() echo.HandlerFunc { + return func(c echo.Context) error { + input := new(GalleryModel) + // Get input data from the request body + if err := c.Bind(input); err != nil { + return err + } + + uuid, err := uuid.NewUUID() + if err != nil { + return err + } + mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{ + Req: input.GalleryModel, + ID: uuid.String(), + GalleryElementName: input.ID, + Galleries: mgs.galleries, + BackendGalleries: mgs.backendGalleries, + } + + return c.JSON(200, schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", utils.BaseURL(c), uuid.String())}) + } +} + +// DeleteModelGalleryEndpoint lets delete models from a LocalAI instance +// @Summary delete models to LocalAI. 
+// @Param name path string true "Model name" +// @Success 200 {object} schema.GalleryResponse "Response" +// @Router /models/delete/{name} [post] +func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() echo.HandlerFunc { + return func(c echo.Context) error { + modelName := c.Param("name") + + mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{ + Delete: true, + GalleryElementName: modelName, + } + + uuid, err := uuid.NewUUID() + if err != nil { + return err + } + + return c.JSON(200, schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", utils.BaseURL(c), uuid.String())}) + } +} + +// ListModelFromGalleryEndpoint list the available models for installation from the active galleries +// @Summary List installable models. +// @Success 200 {object} []gallery.GalleryModel "Response" +// @Router /models/available [get] +func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState *system.SystemState) echo.HandlerFunc { + return func(c echo.Context) error { + + models, err := gallery.AvailableGalleryModels(mgs.galleries, systemState) + if err != nil { + log.Error().Err(err).Msg("could not list models from galleries") + return err + } + + log.Debug().Msgf("Available %d models from %d galleries\n", len(models), len(mgs.galleries)) + + m := []gallery.Metadata{} + + for _, mm := range models { + m = append(m, mm.Metadata) + } + + log.Debug().Msgf("Models %#v", m) + + dat, err := json.Marshal(m) + if err != nil { + return fmt.Errorf("could not marshal models: %w", err) + } + return c.Blob(200, "application/json", dat) + } +} + +// ListModelGalleriesEndpoint list the available galleries configured in LocalAI +// @Summary List all Galleries +// @Success 200 {object} []config.Gallery "Response" +// @Router /models/galleries [get] +// NOTE: This is different (and much simpler!) than above! 
This JUST lists the model galleries that have been loaded, not their contents! +func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() echo.HandlerFunc { + return func(c echo.Context) error { + log.Debug().Msgf("Listing model galleries %+v", mgs.galleries) + dat, err := json.Marshal(mgs.galleries) + if err != nil { + return err + } + return c.Blob(200, "application/json", dat) + } +} diff --git a/core/http/endpoints/localai/get_token_metrics.go b/core/http/endpoints/localai/get_token_metrics.go index db00193a8499..1eb12693df8e 100644 --- a/core/http/endpoints/localai/get_token_metrics.go +++ b/core/http/endpoints/localai/get_token_metrics.go @@ -1,7 +1,7 @@ package localai import ( - "github.com/gofiber/fiber/v2" + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" @@ -21,17 +21,17 @@ import ( // @Success 200 {string} binary "generated audio/wav file" // @Router /v1/tokenMetrics [get] // @Router /tokenMetrics [get] -func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { +func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { input := new(schema.TokenMetricsRequest) // Get input data from the request body - if err := c.BodyParser(input); err != nil { + if err := c.Bind(input); err != nil { return err } - modelFile, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_NAME).(string) + modelFile, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_NAME).(string) if !ok || modelFile != "" { modelFile = input.Model log.Warn().Msgf("Model not found in context: %s", input.Model) @@ -52,6 +52,6 @@ func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a if err != nil { return err } - return 
c.JSON(response) + return c.JSON(200, response) } } diff --git a/core/http/endpoints/localai/import_model.go b/core/http/endpoints/localai/import_model.go index b6d34055390f..8e11bbc01161 100644 --- a/core/http/endpoints/localai/import_model.go +++ b/core/http/endpoints/localai/import_model.go @@ -3,12 +3,14 @@ package localai import ( "encoding/json" "fmt" + "io" + "net/http" "os" "path/filepath" "strings" - "github.com/gofiber/fiber/v2" "github.com/google/uuid" + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/gallery/importers" @@ -21,12 +23,12 @@ import ( ) // ImportModelURIEndpoint handles creating new model configurations from a URI -func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) fiber.Handler { - return func(c *fiber.Ctx) error { +func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) echo.HandlerFunc { + return func(c echo.Context) error { input := new(schema.ImportModelRequest) - if err := c.BodyParser(input); err != nil { + if err := c.Bind(input); err != nil { return err } @@ -61,7 +63,7 @@ func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.Appl BackendGalleries: appConfig.BackendGalleries, } - return c.JSON(schema.GalleryResponse{ + return c.JSON(200, schema.GalleryResponse{ ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", httpUtils.BaseURL(c), uuid.String()), }) @@ -69,22 +71,28 @@ func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.Appl } // ImportModelEndpoint handles creating new model configurations -func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler { - return func(c *fiber.Ctx) error { +func ImportModelEndpoint(cl 
*config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { // Get the raw body - body := c.Body() + body, err := io.ReadAll(c.Request().Body) + if err != nil { + response := ModelResponse{ + Success: false, + Error: "Failed to read request body: " + err.Error(), + } + return c.JSON(http.StatusBadRequest, response) + } if len(body) == 0 { response := ModelResponse{ Success: false, Error: "Request body is empty", } - return c.Status(400).JSON(response) + return c.JSON(http.StatusBadRequest, response) } // Check content type to determine how to parse - contentType := string(c.Context().Request.Header.ContentType()) + contentType := c.Request().Header.Get("Content-Type") var modelConfig config.ModelConfig - var err error if strings.Contains(contentType, "application/json") { // Parse JSON @@ -93,7 +101,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica Success: false, Error: "Failed to parse JSON: " + err.Error(), } - return c.Status(400).JSON(response) + return c.JSON(http.StatusBadRequest, response) } } else if strings.Contains(contentType, "application/x-yaml") || strings.Contains(contentType, "text/yaml") { // Parse YAML @@ -102,18 +110,18 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica Success: false, Error: "Failed to parse YAML: " + err.Error(), } - return c.Status(400).JSON(response) + return c.JSON(http.StatusBadRequest, response) } } else { // Try to auto-detect format - if strings.TrimSpace(string(body))[0] == '{' { + if len(body) > 0 && strings.TrimSpace(string(body))[0] == '{' { // Looks like JSON if err := json.Unmarshal(body, &modelConfig); err != nil { response := ModelResponse{ Success: false, Error: "Failed to parse JSON: " + err.Error(), } - return c.Status(400).JSON(response) + return c.JSON(http.StatusBadRequest, response) } } else { // Assume YAML @@ -122,7 +130,7 @@ func ImportModelEndpoint(cl 
*config.ModelConfigLoader, appConfig *config.Applica Success: false, Error: "Failed to parse YAML: " + err.Error(), } - return c.Status(400).JSON(response) + return c.JSON(http.StatusBadRequest, response) } } } @@ -133,7 +141,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica Success: false, Error: "Name is required", } - return c.Status(400).JSON(response) + return c.JSON(http.StatusBadRequest, response) } // Set defaults @@ -145,7 +153,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica Success: false, Error: "Invalid configuration", } - return c.Status(400).JSON(response) + return c.JSON(http.StatusBadRequest, response) } // Create the configuration file @@ -155,7 +163,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica Success: false, Error: "Model path not trusted: " + err.Error(), } - return c.Status(400).JSON(response) + return c.JSON(http.StatusBadRequest, response) } // Marshal to YAML for storage @@ -165,7 +173,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica Success: false, Error: "Failed to marshal configuration: " + err.Error(), } - return c.Status(500).JSON(response) + return c.JSON(http.StatusInternalServerError, response) } // Write the file @@ -174,7 +182,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica Success: false, Error: "Failed to write configuration file: " + err.Error(), } - return c.Status(500).JSON(response) + return c.JSON(http.StatusInternalServerError, response) } // Reload configurations if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil { @@ -182,7 +190,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica Success: false, Error: "Failed to reload configurations: " + err.Error(), } - return c.Status(500).JSON(response) + return c.JSON(http.StatusInternalServerError, response) } // Preload 
the model @@ -191,7 +199,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica Success: false, Error: "Failed to preload model: " + err.Error(), } - return c.Status(500).JSON(response) + return c.JSON(http.StatusInternalServerError, response) } // Return success response response := ModelResponse{ @@ -199,6 +207,6 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica Message: "Model configuration created successfully", Filename: filepath.Base(configPath), } - return c.JSON(response) + return c.JSON(200, response) } } diff --git a/core/http/endpoints/localai/metrics.go b/core/http/endpoints/localai/metrics.go index 8fcc0a7a6ee4..a5f08a7f6444 100644 --- a/core/http/endpoints/localai/metrics.go +++ b/core/http/endpoints/localai/metrics.go @@ -1,46 +1,47 @@ -package localai - -import ( - "time" - - "github.com/gofiber/fiber/v2" - "github.com/gofiber/fiber/v2/middleware/adaptor" - "github.com/mudler/LocalAI/core/services" - "github.com/prometheus/client_golang/prometheus/promhttp" -) - -// LocalAIMetricsEndpoint returns the metrics endpoint for LocalAI -// @Summary Prometheus metrics endpoint -// @Param request body config.Gallery true "Gallery details" -// @Router /metrics [get] -func LocalAIMetricsEndpoint() fiber.Handler { - return adaptor.HTTPHandler(promhttp.Handler()) -} - -type apiMiddlewareConfig struct { - Filter func(c *fiber.Ctx) bool - metricsService *services.LocalAIMetricsService -} - -func LocalAIMetricsAPIMiddleware(metrics *services.LocalAIMetricsService) fiber.Handler { - cfg := apiMiddlewareConfig{ - metricsService: metrics, - Filter: func(c *fiber.Ctx) bool { - return c.Path() == "/metrics" - }, - } - - return func(c *fiber.Ctx) error { - if cfg.Filter != nil && cfg.Filter(c) { - return c.Next() - } - path := c.Path() - method := c.Method() - - start := time.Now() - err := c.Next() - elapsed := float64(time.Since(start)) / float64(time.Second) - cfg.metricsService.ObserveAPICall(method, 
path, elapsed) - return err - } -} +package localai + +import ( + "time" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/services" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +// LocalAIMetricsEndpoint returns the metrics endpoint for LocalAI +// @Summary Prometheus metrics endpoint +// @Param request body config.Gallery true "Gallery details" +// @Router /metrics [get] +func LocalAIMetricsEndpoint() echo.HandlerFunc { + return echo.WrapHandler(promhttp.Handler()) +} + +type apiMiddlewareConfig struct { + Filter func(c echo.Context) bool + metricsService *services.LocalAIMetricsService +} + +func LocalAIMetricsAPIMiddleware(metrics *services.LocalAIMetricsService) echo.MiddlewareFunc { + cfg := apiMiddlewareConfig{ + metricsService: metrics, + Filter: func(c echo.Context) bool { + return c.Path() == "/metrics" + }, + } + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if cfg.Filter != nil && cfg.Filter(c) { + return next(c) + } + path := c.Path() + method := c.Request().Method + + start := time.Now() + err := next(c) + elapsed := float64(time.Since(start)) / float64(time.Second) + cfg.metricsService.ObserveAPICall(method, path, elapsed) + return err + } + } +} diff --git a/core/http/endpoints/localai/p2p.go b/core/http/endpoints/localai/p2p.go index bbcee8c801e1..afd7d048dc83 100644 --- a/core/http/endpoints/localai/p2p.go +++ b/core/http/endpoints/localai/p2p.go @@ -1,7 +1,7 @@ package localai import ( - "github.com/gofiber/fiber/v2" + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/p2p" "github.com/mudler/LocalAI/core/schema" @@ -11,10 +11,10 @@ import ( // @Summary Returns available P2P nodes // @Success 200 {object} []schema.P2PNodesResponse "Response" // @Router /api/p2p [get] -func ShowP2PNodes(appConfig *config.ApplicationConfig) func(*fiber.Ctx) error { +func ShowP2PNodes(appConfig *config.ApplicationConfig) 
echo.HandlerFunc { // Render index - return func(c *fiber.Ctx) error { - return c.JSON(schema.P2PNodesResponse{ + return func(c echo.Context) error { + return c.JSON(200, schema.P2PNodesResponse{ Nodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID)), FederatedNodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID)), }) @@ -25,6 +25,6 @@ func ShowP2PNodes(appConfig *config.ApplicationConfig) func(*fiber.Ctx) error { // @Summary Show the P2P token // @Success 200 {string} string "Response" // @Router /api/p2p/token [get] -func ShowP2PToken(appConfig *config.ApplicationConfig) func(*fiber.Ctx) error { - return func(c *fiber.Ctx) error { return c.Send([]byte(appConfig.P2PToken)) } +func ShowP2PToken(appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { return c.String(200, appConfig.P2PToken) } } diff --git a/core/http/endpoints/localai/stores.go b/core/http/endpoints/localai/stores.go index 303d943f6fc3..033334375228 100644 --- a/core/http/endpoints/localai/stores.go +++ b/core/http/endpoints/localai/stores.go @@ -1,7 +1,7 @@ package localai import ( - "github.com/gofiber/fiber/v2" + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/schema" @@ -9,11 +9,11 @@ import ( "github.com/mudler/LocalAI/pkg/store" ) -func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { +func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { input := new(schema.StoresSet) - if err := c.BodyParser(input); err != nil { + if err := c.Bind(input); err != nil { return err } @@ -28,20 +28,20 @@ func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi vals[i] = []byte(v) } - err = store.SetCols(c.Context(), sb, 
input.Keys, vals) + err = store.SetCols(c.Request().Context(), sb, input.Keys, vals) if err != nil { return err } - return c.Send(nil) + return c.NoContent(200) } } -func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { +func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { input := new(schema.StoresDelete) - if err := c.BodyParser(input); err != nil { + if err := c.Bind(input); err != nil { return err } @@ -51,19 +51,19 @@ func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationCo } defer sl.Close() - if err := store.DeleteCols(c.Context(), sb, input.Keys); err != nil { + if err := store.DeleteCols(c.Request().Context(), sb, input.Keys); err != nil { return err } - return c.Send(nil) + return c.NoContent(200) } } -func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { +func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { input := new(schema.StoresGet) - if err := c.BodyParser(input); err != nil { + if err := c.Bind(input); err != nil { return err } @@ -73,7 +73,7 @@ func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi } defer sl.Close() - keys, vals, err := store.GetCols(c.Context(), sb, input.Keys) + keys, vals, err := store.GetCols(c.Request().Context(), sb, input.Keys) if err != nil { return err } @@ -87,15 +87,15 @@ func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi res.Values[i] = string(v) } - return c.JSON(res) + return c.JSON(200, res) } } -func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { +func StoresFindEndpoint(sl 
*model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { input := new(schema.StoresFind) - if err := c.BodyParser(input); err != nil { + if err := c.Bind(input); err != nil { return err } @@ -105,7 +105,7 @@ func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConf } defer sl.Close() - keys, vals, similarities, err := store.Find(c.Context(), sb, input.Key, input.Topk) + keys, vals, similarities, err := store.Find(c.Request().Context(), sb, input.Key, input.Topk) if err != nil { return err } @@ -120,6 +120,6 @@ func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConf res.Values[i] = string(v) } - return c.JSON(res) + return c.JSON(200, res) } } diff --git a/core/http/endpoints/localai/system.go b/core/http/endpoints/localai/system.go index 349f97cf861c..a3831e18483a 100644 --- a/core/http/endpoints/localai/system.go +++ b/core/http/endpoints/localai/system.go @@ -1,7 +1,7 @@ package localai import ( - "github.com/gofiber/fiber/v2" + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/model" @@ -11,8 +11,8 @@ import ( // @Summary Show the LocalAI instance information // @Success 200 {object} schema.SystemInformationResponse "Response" // @Router /system [get] -func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(*fiber.Ctx) error { - return func(c *fiber.Ctx) error { +func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { availableBackends := []string{} loadedModels := ml.ListLoadedModels() for b := range appConfig.ExternalGRPCBackends { @@ -26,7 +26,7 @@ func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConf for _, m := range loadedModels { sysmodels = append(sysmodels, schema.SysInfoModel{ID: m.ID}) } - return c.JSON( + return 
c.JSON(200, schema.SystemInformationResponse{ Backends: availableBackends, Models: sysmodels, diff --git a/core/http/endpoints/localai/tokenize.go b/core/http/endpoints/localai/tokenize.go index cd12e50dfb7a..23eec48c7545 100644 --- a/core/http/endpoints/localai/tokenize.go +++ b/core/http/endpoints/localai/tokenize.go @@ -1,7 +1,7 @@ package localai import ( - "github.com/gofiber/fiber/v2" + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" @@ -14,22 +14,22 @@ import ( // @Param request body schema.TokenizeRequest true "Request" // @Success 200 {object} schema.TokenizeResponse "Response" // @Router /v1/tokenize [post] -func TokenizeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { - return func(ctx *fiber.Ctx) error { - input, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TokenizeRequest) +func TokenizeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { + input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TokenizeRequest) if !ok || input.Model == "" { - return fiber.ErrBadRequest + return echo.ErrBadRequest } - cfg, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || cfg == nil { - return fiber.ErrBadRequest + return echo.ErrBadRequest } tokenResponse, err := backend.ModelTokenize(input.Content, ml, *cfg, appConfig) if err != nil { return err } - return ctx.JSON(tokenResponse) + return c.JSON(200, tokenResponse) } } diff --git a/core/http/endpoints/localai/tts.go b/core/http/endpoints/localai/tts.go index 5c2f01dad4b9..61577d4eb7ef 100644 --- a/core/http/endpoints/localai/tts.go +++ 
b/core/http/endpoints/localai/tts.go @@ -1,12 +1,14 @@ package localai import ( + "path/filepath" + + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/pkg/model" - "github.com/gofiber/fiber/v2" "github.com/mudler/LocalAI/core/schema" "github.com/rs/zerolog/log" @@ -22,16 +24,16 @@ import ( // @Success 200 {string} binary "generated audio/wav file" // @Router /v1/audio/speech [post] // @Router /tts [post] -func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TTSRequest) +func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { + input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TTSRequest) if !ok || input.Model == "" { - return fiber.ErrBadRequest + return echo.ErrBadRequest } - cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || cfg == nil { - return fiber.ErrBadRequest + return echo.ErrBadRequest } log.Debug().Str("model", input.Model).Msg("LocalAI TTS Request received") @@ -59,6 +61,6 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig return err } - return c.Download(filePath) + return c.Attachment(filePath, filepath.Base(filePath)) } } diff --git a/core/http/endpoints/localai/vad.go b/core/http/endpoints/localai/vad.go index c3e310fe867b..6bd6e735df0d 100644 --- a/core/http/endpoints/localai/vad.go +++ b/core/http/endpoints/localai/vad.go @@ -1,7 +1,7 @@ package localai import ( - "github.com/gofiber/fiber/v2" + "github.com/labstack/echo/v4" 
"github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" @@ -16,26 +16,26 @@ import ( // @Param request body schema.VADRequest true "query params" // @Success 200 {object} proto.VADResponse "Response" // @Router /vad [post] -func VADEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VADRequest) +func VADEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { + input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VADRequest) if !ok || input.Model == "" { - return fiber.ErrBadRequest + return echo.ErrBadRequest } - cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || cfg == nil { - return fiber.ErrBadRequest + return echo.ErrBadRequest } log.Debug().Str("model", input.Model).Msg("LocalAI VAD Request received") - resp, err := backend.VAD(input, c.Context(), ml, appConfig, *cfg) + resp, err := backend.VAD(input, c.Request().Context(), ml, appConfig, *cfg) if err != nil { return err } - return c.JSON(resp) + return c.JSON(200, resp) } } diff --git a/core/http/endpoints/localai/video.go b/core/http/endpoints/localai/video.go index 68b8ec011b3a..5d2482b61022 100644 --- a/core/http/endpoints/localai/video.go +++ b/core/http/endpoints/localai/video.go @@ -12,14 +12,15 @@ import ( "strings" "time" + "github.com/labstack/echo/v4" "github.com/google/uuid" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" + "github.com/mudler/LocalAI/core/http/utils" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/backend" - 
"github.com/gofiber/fiber/v2" model "github.com/mudler/LocalAI/pkg/model" "github.com/rs/zerolog/log" ) @@ -64,18 +65,18 @@ func downloadFile(url string) (string, error) { // @Param request body schema.VideoRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /video [post] -func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VideoRequest) +func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { + input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VideoRequest) if !ok || input.Model == "" { log.Error().Msg("Video Endpoint - Invalid Input") - return fiber.ErrBadRequest + return echo.ErrBadRequest } - config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || config == nil { log.Error().Msg("Video Endpoint - Invalid Config") - return fiber.ErrBadRequest + return echo.ErrBadRequest } src := "" @@ -164,7 +165,7 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi return err } - baseURL := c.BaseURL() + baseURL := utils.BaseURL(c) fn, err := backend.VideoGeneration( height, @@ -216,6 +217,6 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi log.Debug().Msgf("Response: %s", jsonResult) // Return the prediction in the response body - return c.JSON(resp) + return c.JSON(200, resp) } } diff --git a/core/http/endpoints/localai/welcome.go b/core/http/endpoints/localai/welcome.go index 04f72743e34e..f571ba7d2fb6 100644 --- a/core/http/endpoints/localai/welcome.go +++ b/core/http/endpoints/localai/welcome.go @@ -1,7 
+1,9 @@ package localai import ( - "github.com/gofiber/fiber/v2" + "strings" + + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/http/utils" @@ -11,8 +13,8 @@ import ( ) func WelcomeEndpoint(appConfig *config.ApplicationConfig, - cl *config.ModelConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) func(*fiber.Ctx) error { - return func(c *fiber.Ctx) error { + cl *config.ModelConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) echo.HandlerFunc { + return func(c echo.Context) error { modelConfigs := cl.GetAllModelsConfigs() galleryConfigs := map[string]*gallery.ModelConfig{} @@ -40,7 +42,7 @@ func WelcomeEndpoint(appConfig *config.ApplicationConfig, // Get model statuses to display in the UI the operation in progress processingModels, taskTypes := opcache.GetStatus() - summary := fiber.Map{ + summary := map[string]interface{}{ "Title": "LocalAI API - " + internal.PrintableVersion(), "Version": internal.PrintableVersion(), "BaseURL": utils.BaseURL(c), @@ -54,12 +56,14 @@ func WelcomeEndpoint(appConfig *config.ApplicationConfig, "InstalledBackends": installedBackends, } - if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 { + contentType := c.Request().Header.Get("Content-Type") + accept := c.Request().Header.Get("Accept") + if strings.Contains(contentType, "application/json") || !strings.Contains(accept, "text/html") { // The client expects a JSON response - return c.Status(fiber.StatusOK).JSON(summary) + return c.JSON(200, summary) } else { // Render index - return c.Render("views/index", summary) + return c.Render(200, "views/index", summary) } } } diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index 6385997bf4fb..d121a081eb47 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -1,15 +1,12 @@ package openai import ( - 
"bufio" - "context" "encoding/json" "fmt" - "net" "time" - "github.com/gofiber/fiber/v2" "github.com/google/uuid" + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" @@ -20,68 +17,14 @@ import ( "github.com/mudler/LocalAI/pkg/model" "github.com/rs/zerolog/log" - "github.com/valyala/fasthttp" ) -// NOTE: this is a bad WORKAROUND! We should find a better way to handle this. -// Fasthttp doesn't support context cancellation from the caller -// for non-streaming requests, so we need to monitor the connection directly. -// Monitor connection for client disconnection during non-streaming requests -// We access the connection directly via c.Context().Conn() to monitor it -// during ComputeChoices execution, not after the response is sent -// see: https://github.com/mudler/LocalAI/pull/7187#issuecomment-3506720906 -func handleConnectionCancellation(c *fiber.Ctx, cancelFunc func(), requestCtx context.Context) { - var conn net.Conn = c.Context().Conn() - if conn == nil { - return - } - - go func() { - defer func() { - // Clear read deadline when goroutine exits - conn.SetReadDeadline(time.Time{}) - }() - - buf := make([]byte, 1) - // Use a short read deadline to periodically check if connection is closed - // Without a deadline, Read() would block indefinitely waiting for data - // that will never come (client is waiting for response, not sending more data) - ticker := time.NewTicker(100 * time.Millisecond) - defer ticker.Stop() - - for { - select { - case <-requestCtx.Done(): - // Request completed or was cancelled - exit goroutine - return - case <-ticker.C: - // Set a short deadline - if connection is closed, read will fail immediately - // If connection is open but no data, it will timeout and we check again - conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond)) - _, err := conn.Read(buf) - if err != nil { - // Check if it's a timeout (connection still 
open, just no data) - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - // Timeout is expected - connection is still open, just no data to read - // Continue the loop to check again - continue - } - // Connection closed or other error - cancel the context to stop gRPC call - log.Debug().Msgf("Calling cancellation function") - cancelFunc() - return - } - } - } - }() -} - // ChatEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/chat/create // @Summary Generate a chat completions for a given prompt and model. // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/chat/completions [post] -func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error { +func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) echo.HandlerFunc { var id, textContentToReturn string var created int @@ -235,21 +178,21 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator return err } - return func(c *fiber.Ctx) error { + return func(c echo.Context) error { textContentToReturn = "" id = uuid.New().String() created = int(time.Now().Unix()) - input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) + input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) if !ok || input.Model == "" { - return fiber.ErrBadRequest + return echo.ErrBadRequest } - extraUsage := c.Get("Extra-Usage", "") != "" + extraUsage := c.Request().Header.Get("Extra-Usage") != "" - config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || config == nil { - return 
fiber.ErrBadRequest + return echo.ErrBadRequest } log.Debug().Msgf("Chat endpoint configuration read: %+v", config) @@ -392,13 +335,10 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator case toStream: log.Debug().Msgf("Stream request received") - c.Context().SetContentType("text/event-stream") - //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8) - // c.Set("Content-Type", "text/event-stream") - c.Set("Cache-Control", "no-cache") - c.Set("Connection", "keep-alive") - c.Set("Transfer-Encoding", "chunked") - c.Set("X-Correlation-ID", id) + c.Response().Header().Set("Content-Type", "text/event-stream") + c.Response().Header().Set("Cache-Control", "no-cache") + c.Response().Header().Set("Connection", "keep-alive") + c.Response().Header().Set("X-Correlation-ID", id) responses := make(chan schema.OpenAIResponse) ended := make(chan error, 1) @@ -411,103 +351,101 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator } }() - c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { - usage := &schema.OpenAIUsage{} - toolsCalled := false - - LOOP: - for { - select { - case <-input.Context.Done(): - // Context was cancelled (client disconnected or request cancelled) - log.Debug().Msgf("Request context cancelled, stopping stream") + usage := &schema.OpenAIUsage{} + toolsCalled := false + + LOOP: + for { + select { + case <-input.Context.Done(): + // Context was cancelled (client disconnected or request cancelled) + log.Debug().Msgf("Request context cancelled, stopping stream") + input.Cancel() + break LOOP + case ev := <-responses: + if len(ev.Choices) == 0 { + log.Debug().Msgf("No choices in the response, skipping") + continue + } + usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it + if len(ev.Choices[0].Delta.ToolCalls) > 0 { + toolsCalled = true + } + respData, err := json.Marshal(ev) + if err != nil { + 
log.Debug().Msgf("Failed to marshal response: %v", err) + input.Cancel() + continue + } + log.Debug().Msgf("Sending chunk: %s", string(respData)) + _, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(respData)) + if err != nil { + log.Debug().Msgf("Sending chunk failed: %v", err) input.Cancel() + return err + } + c.Response().Flush() + case err := <-ended: + if err == nil { break LOOP - case ev := <-responses: - if len(ev.Choices) == 0 { - log.Debug().Msgf("No choices in the response, skipping") - continue - } - usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it - if len(ev.Choices[0].Delta.ToolCalls) > 0 { - toolsCalled = true - } - respData, err := json.Marshal(ev) - if err != nil { - log.Debug().Msgf("Failed to marshal response: %v", err) - input.Cancel() - continue - } - log.Debug().Msgf("Sending chunk: %s", string(respData)) - _, err = fmt.Fprintf(w, "data: %s\n\n", string(respData)) - if err != nil { - log.Debug().Msgf("Sending chunk failed: %v", err) - input.Cancel() - } - w.Flush() - case err := <-ended: - if err == nil { - break LOOP - } - log.Error().Msgf("Stream ended with error: %v", err) - - stopReason := FinishReasonStop - resp := &schema.OpenAIResponse{ - ID: id, - Created: created, - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. 
- Choices: []schema.Choice{ - { - FinishReason: &stopReason, - Index: 0, - Delta: &schema.Message{Content: "Internal error: " + err.Error()}, - }}, - Object: "chat.completion.chunk", - Usage: *usage, - } - respData, marshalErr := json.Marshal(resp) - if marshalErr != nil { - log.Error().Msgf("Failed to marshal error response: %v", marshalErr) - // Send a simple error message as fallback - w.WriteString("data: {\"error\":\"Internal error\"}\n\n") - } else { - w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) - } - w.WriteString("data: [DONE]\n\n") - w.Flush() + } + log.Error().Msgf("Stream ended with error: %v", err) - return + stopReason := FinishReasonStop + resp := &schema.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{ + { + FinishReason: &stopReason, + Index: 0, + Delta: &schema.Message{Content: "Internal error: " + err.Error()}, + }}, + Object: "chat.completion.chunk", + Usage: *usage, } - } + respData, marshalErr := json.Marshal(resp) + if marshalErr != nil { + log.Error().Msgf("Failed to marshal error response: %v", marshalErr) + // Send a simple error message as fallback + fmt.Fprintf(c.Response().Writer, "data: {\"error\":\"Internal error\"}\n\n") + } else { + fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData) + } + fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") + c.Response().Flush() - finishReason := FinishReasonStop - if toolsCalled && len(input.Tools) > 0 { - finishReason = FinishReasonToolCalls - } else if toolsCalled { - finishReason = FinishReasonFunctionCall + return nil } + } - resp := &schema.OpenAIResponse{ - ID: id, - Created: created, - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. 
- Choices: []schema.Choice{ - { - FinishReason: &finishReason, - Index: 0, - Delta: &schema.Message{}, - }}, - Object: "chat.completion.chunk", - Usage: *usage, - } - respData, _ := json.Marshal(resp) + finishReason := FinishReasonStop + if toolsCalled && len(input.Tools) > 0 { + finishReason = FinishReasonToolCalls + } else if toolsCalled { + finishReason = FinishReasonFunctionCall + } - w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) - w.WriteString("data: [DONE]\n\n") - w.Flush() - log.Debug().Msgf("Stream ended") - })) + resp := &schema.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{ + { + FinishReason: &finishReason, + Index: 0, + Delta: &schema.Message{}, + }}, + Object: "chat.completion.chunk", + Usage: *usage, + } + respData, _ := json.Marshal(resp) + fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData) + fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") + c.Response().Flush() + log.Debug().Msgf("Stream ended") return nil // no streaming mode @@ -589,9 +527,8 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator } - // NOTE: this is a workaround as fasthttp - // context cancellation does not fire in non-streaming requests - // handleConnectionCancellation(c, input.Cancel, input.Context) + // Echo properly supports context cancellation via c.Request().Context() + // No workaround needed! 
result, tokenUsage, err := ComputeChoices( input, @@ -628,7 +565,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator log.Debug().Msgf("Response: %s", respData) // Return the prediction in the response body - return c.JSON(resp) + return c.JSON(200, resp) } } } diff --git a/core/http/endpoints/openai/completion.go b/core/http/endpoints/openai/completion.go index fc9acb0c33d6..f4c5e3fe9580 100644 --- a/core/http/endpoints/openai/completion.go +++ b/core/http/endpoints/openai/completion.go @@ -1,24 +1,22 @@ package openai import ( - "bufio" "encoding/json" "errors" "fmt" "time" + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" - "github.com/gofiber/fiber/v2" "github.com/google/uuid" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/templates" "github.com/mudler/LocalAI/pkg/functions" "github.com/mudler/LocalAI/pkg/model" "github.com/rs/zerolog/log" - "github.com/valyala/fasthttp" ) // CompletionEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/completions @@ -26,7 +24,7 @@ import ( // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/completions [post] -func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { +func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc { process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error { tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool { created := int(time.Now().Unix()) @@ -64,22 
+62,25 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva return err } - return func(c *fiber.Ctx) error { + return func(c echo.Context) error { created := int(time.Now().Unix()) // Handle Correlation - id := c.Get("X-Correlation-ID", uuid.New().String()) - extraUsage := c.Get("Extra-Usage", "") != "" + id := c.Request().Header.Get("X-Correlation-ID") + if id == "" { + id = uuid.New().String() + } + extraUsage := c.Request().Header.Get("Extra-Usage") != "" - input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) + input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) if !ok || input.Model == "" { - return fiber.ErrBadRequest + return echo.ErrBadRequest } - config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || config == nil { - return fiber.ErrBadRequest + return echo.ErrBadRequest } if config.ResponseFormatMap != nil { @@ -97,15 +98,10 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva if input.Stream { log.Debug().Msgf("Stream request received") - c.Context().SetContentType("text/event-stream") - //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8) - //c.Set("Content-Type", "text/event-stream") - c.Set("Cache-Control", "no-cache") - c.Set("Connection", "keep-alive") - c.Set("Transfer-Encoding", "chunked") - } + c.Response().Header().Set("Content-Type", "text/event-stream") + c.Response().Header().Set("Cache-Control", "no-cache") + c.Response().Header().Set("Connection", "keep-alive") - if input.Stream { if len(config.PromptStrings) > 1 { return errors.New("cannot handle more than 1 `PromptStrings` when Streaming") } @@ -130,78 +126,78 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva ended <- process(id, predInput, input, config, ml, responses, 
extraUsage) }() - c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { - - LOOP: - for { - select { - case ev := <-responses: - if len(ev.Choices) == 0 { - log.Debug().Msgf("No choices in the response, skipping") - continue - } - respData, err := json.Marshal(ev) - if err != nil { - log.Debug().Msgf("Failed to marshal response: %v", err) - continue - } - - log.Debug().Msgf("Sending chunk: %s", string(respData)) - fmt.Fprintf(w, "data: %s\n\n", string(respData)) - w.Flush() - case err := <-ended: - if err == nil { - break LOOP - } - log.Error().Msgf("Stream ended with error: %v", err) - - stopReason := FinishReasonStop - errorResp := schema.OpenAIResponse{ - ID: id, - Created: created, - Model: input.Model, - Choices: []schema.Choice{ - { - Index: 0, - FinishReason: &stopReason, - Text: "Internal error: " + err.Error(), - }, - }, - Object: "text_completion", - } - errorData, marshalErr := json.Marshal(errorResp) - if marshalErr != nil { - log.Error().Msgf("Failed to marshal error response: %v", marshalErr) - // Send a simple error message as fallback - fmt.Fprintf(w, "data: {\"error\":\"Internal error\"}\n\n") - } else { - fmt.Fprintf(w, "data: %s\n\n", string(errorData)) - } - w.Flush() + LOOP: + for { + select { + case ev := <-responses: + if len(ev.Choices) == 0 { + log.Debug().Msgf("No choices in the response, skipping") + continue + } + respData, err := json.Marshal(ev) + if err != nil { + log.Debug().Msgf("Failed to marshal response: %v", err) + continue + } + + log.Debug().Msgf("Sending chunk: %s", string(respData)) + _, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(respData)) + if err != nil { + return err + } + c.Response().Flush() + case err := <-ended: + if err == nil { break LOOP } - } + log.Error().Msgf("Stream ended with error: %v", err) - stopReason := FinishReasonStop - resp := &schema.OpenAIResponse{ - ID: id, - Created: created, - Model: input.Model, // we have to return what the user sent here, due to 
OpenAI spec. - Choices: []schema.Choice{ - { - Index: 0, - FinishReason: &stopReason, + stopReason := FinishReasonStop + errorResp := schema.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, + Choices: []schema.Choice{ + { + Index: 0, + FinishReason: &stopReason, + Text: "Internal error: " + err.Error(), + }, }, - }, - Object: "text_completion", + Object: "text_completion", + } + errorData, marshalErr := json.Marshal(errorResp) + if marshalErr != nil { + log.Error().Msgf("Failed to marshal error response: %v", marshalErr) + // Send a simple error message as fallback + fmt.Fprintf(c.Response().Writer, "data: {\"error\":\"Internal error\"}\n\n") + } else { + fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(errorData)) + } + c.Response().Flush() + return nil } - respData, _ := json.Marshal(resp) + } + + stopReason := FinishReasonStop + resp := &schema.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. 
+ Choices: []schema.Choice{ + { + Index: 0, + FinishReason: &stopReason, + }, + }, + Object: "text_completion", + } + respData, _ := json.Marshal(resp) - w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) - w.WriteString("data: [DONE]\n\n") - w.Flush() - })) - return <-ended + fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData) + fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") + c.Response().Flush() + return nil } var result []schema.Choice @@ -257,6 +253,6 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva log.Debug().Msgf("Response: %s", jsonResult) // Return the prediction in the response body - return c.JSON(resp) + return c.JSON(200, resp) } } diff --git a/core/http/endpoints/openai/edit.go b/core/http/endpoints/openai/edit.go index 0cdeba09f4a0..8520bf1e233f 100644 --- a/core/http/endpoints/openai/edit.go +++ b/core/http/endpoints/openai/edit.go @@ -4,11 +4,11 @@ import ( "encoding/json" "time" + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" - "github.com/gofiber/fiber/v2" "github.com/google/uuid" "github.com/mudler/LocalAI/core/schema" @@ -23,20 +23,20 @@ import ( // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/edits [post] -func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { +func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc { - return func(c *fiber.Ctx) error { + return func(c echo.Context) error { - input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) + input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) if !ok || 
input.Model == "" { - return fiber.ErrBadRequest + return echo.ErrBadRequest } // Opt-in extra usage flag - extraUsage := c.Get("Extra-Usage", "") != "" + extraUsage := c.Request().Header.Get("Extra-Usage") != "" - config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || config == nil { - return fiber.ErrBadRequest + return echo.ErrBadRequest } log.Debug().Msgf("Edit Endpoint Input : %+v", input) @@ -98,6 +98,6 @@ func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator log.Debug().Msgf("Response: %s", jsonResult) // Return the prediction in the response body - return c.JSON(resp) + return c.JSON(200, resp) } } diff --git a/core/http/endpoints/openai/embeddings.go b/core/http/endpoints/openai/embeddings.go index 4154c435a5e9..7b75d1fd5d61 100644 --- a/core/http/endpoints/openai/embeddings.go +++ b/core/http/endpoints/openai/embeddings.go @@ -4,6 +4,7 @@ import ( "encoding/json" "time" + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" @@ -12,7 +13,6 @@ import ( "github.com/google/uuid" "github.com/mudler/LocalAI/core/schema" - "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" ) @@ -21,16 +21,16 @@ import ( // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/embeddings [post] -func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) +func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { + 
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) if !ok || input.Model == "" { - return fiber.ErrBadRequest + return echo.ErrBadRequest } - config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || config == nil { - return fiber.ErrBadRequest + return echo.ErrBadRequest } log.Debug().Msgf("Parameter Config: %+v", config) @@ -78,6 +78,6 @@ func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app log.Debug().Msgf("Response: %s", jsonResult) // Return the prediction in the response body - return c.JSON(resp) + return c.JSON(200, resp) } } diff --git a/core/http/endpoints/openai/image.go b/core/http/endpoints/openai/image.go index 91ecdd23ac97..14d571e36f7f 100644 --- a/core/http/endpoints/openai/image.go +++ b/core/http/endpoints/openai/image.go @@ -13,14 +13,15 @@ import ( "strings" "time" + "github.com/labstack/echo/v4" "github.com/google/uuid" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" + "github.com/mudler/LocalAI/core/http/utils" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/backend" - "github.com/gofiber/fiber/v2" model "github.com/mudler/LocalAI/pkg/model" "github.com/rs/zerolog/log" ) @@ -65,18 +66,18 @@ func downloadFile(url string) (string, error) { // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/images/generations [post] -func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) +func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { 
+ return func(c echo.Context) error { + input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) if !ok || input.Model == "" { log.Error().Msg("Image Endpoint - Invalid Input") - return fiber.ErrBadRequest + return echo.ErrBadRequest } - config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || config == nil { log.Error().Msg("Image Endpoint - Invalid Config") - return fiber.ErrBadRequest + return echo.ErrBadRequest } // Process input images (for img2img/inpainting) @@ -188,7 +189,7 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi return err } - baseURL := c.BaseURL() + baseURL := utils.BaseURL(c) // Use the first input image as src if available, otherwise use the original src inputSrc := src @@ -234,7 +235,7 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi log.Debug().Msgf("Response: %s", jsonResult) // Return the prediction in the response body - return c.JSON(resp) + return c.JSON(200, resp) } } diff --git a/core/http/endpoints/openai/list.go b/core/http/endpoints/openai/list.go index 6c0ffca04aa0..47501dd934f8 100644 --- a/core/http/endpoints/openai/list.go +++ b/core/http/endpoints/openai/list.go @@ -1,7 +1,7 @@ package openai import ( - "github.com/gofiber/fiber/v2" + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/services" @@ -12,14 +12,15 @@ import ( // @Summary List and describe the various models available in the API. 
// @Success 200 {object} schema.ModelsDataResponse "Response" // @Router /v1/models [get] -func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(ctx *fiber.Ctx) error { - return func(c *fiber.Ctx) error { +func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { // If blank, no filter is applied. - filter := c.Query("filter") + filter := c.QueryParam("filter") // By default, exclude any loose files that are already referenced by a configuration file. var policy services.LooseFilePolicy - if c.QueryBool("excludeConfigured", true) { + excludeConfigured := c.QueryParam("excludeConfigured") + if excludeConfigured == "" || excludeConfigured == "true" { policy = services.SKIP_IF_CONFIGURED } else { policy = services.ALWAYS_INCLUDE // This replicates current behavior. TODO: give more options to the user? @@ -41,7 +42,7 @@ func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, ap dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"}) } - return c.JSON(schema.ModelsDataResponse{ + return c.JSON(200, schema.ModelsDataResponse{ Object: "list", Data: dataModels, }) diff --git a/core/http/endpoints/openai/mcp.go b/core/http/endpoints/openai/mcp.go index efb3c6d29096..a91706f51d10 100644 --- a/core/http/endpoints/openai/mcp.go +++ b/core/http/endpoints/openai/mcp.go @@ -8,11 +8,11 @@ import ( "strings" "time" + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp" "github.com/mudler/LocalAI/core/http/middleware" - "github.com/gofiber/fiber/v2" "github.com/google/uuid" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/templates" @@ -26,24 +26,27 @@ import ( // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} 
schema.OpenAIResponse "Response" // @Router /mcp/v1/completions [post] -func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { +func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc { // We do not support streaming mode (Yet?) - return func(c *fiber.Ctx) error { + return func(c echo.Context) error { created := int(time.Now().Unix()) - ctx := c.Context() + ctx := c.Request().Context() // Handle Correlation - id := c.Get("X-Correlation-ID", uuid.New().String()) + id := c.Request().Header.Get("X-Correlation-ID") + if id == "" { + id = uuid.New().String() + } - input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) + input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) if !ok || input.Model == "" { - return fiber.ErrBadRequest + return echo.ErrBadRequest } - config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || config == nil { - return fiber.ErrBadRequest + return echo.ErrBadRequest } if config.MCP.Servers == "" && config.MCP.Stdio == "" { @@ -80,7 +83,7 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, ctxWithCancellation, cancel := context.WithCancel(ctx) defer cancel() - //handleConnectionCancellation(c, cancel, ctxWithCancellation) + // TODO: instead of connecting to the API, we should just wire this internally // and act like completion.go. 
// We can do this as cogito expects an interface and we can create one that @@ -147,6 +150,6 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, log.Debug().Msgf("Response: %s", jsonResult) // Return the prediction in the response body - return c.JSON(resp) + return c.JSON(200, resp) } } diff --git a/core/http/endpoints/openai/realtime.go b/core/http/endpoints/openai/realtime.go index c56e1037577f..f715bb2b4281 100644 --- a/core/http/endpoints/openai/realtime.go +++ b/core/http/endpoints/openai/realtime.go @@ -10,9 +10,11 @@ import ( "sync" "time" + "net/http" + "github.com/go-audio/audio" - "github.com/gofiber/fiber/v2" - "github.com/gofiber/websocket/v2" + "github.com/gorilla/websocket" + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/application" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/endpoints/openai/types" @@ -167,32 +169,50 @@ type Model interface { PredictStream(ctx context.Context, in *proto.PredictOptions, f func(*proto.Reply), opts ...grpc.CallOption) error } +var upgrader = websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true // Allow all origins + }, +} + // TODO: Implement ephemeral keys to allow these endpoints to be used -func RealtimeSessions(application *application.Application) fiber.Handler { - return func(ctx *fiber.Ctx) error { - return ctx.SendStatus(501) +func RealtimeSessions(application *application.Application) echo.HandlerFunc { + return func(c echo.Context) error { + return c.NoContent(501) } } -func RealtimeTranscriptionSession(application *application.Application) fiber.Handler { - return func(ctx *fiber.Ctx) error { - return ctx.SendStatus(501) +func RealtimeTranscriptionSession(application *application.Application) echo.HandlerFunc { + return func(c echo.Context) error { + return c.NoContent(501) } } -func Realtime(application *application.Application) fiber.Handler { - return websocket.New(registerRealtime(application)) +func 
Realtime(application *application.Application) echo.HandlerFunc { + return func(c echo.Context) error { + ws, err := upgrader.Upgrade(c.Response(), c.Request(), nil) + if err != nil { + return err + } + defer ws.Close() + + // Extract query parameters from Echo context before passing to websocket handler + model := c.QueryParam("model") + if model == "" { + model = "gpt-4o" + } + intent := c.QueryParam("intent") + + registerRealtime(application, model, intent)(ws) + return nil + } } -func registerRealtime(application *application.Application) func(c *websocket.Conn) { +func registerRealtime(application *application.Application, model, intent string) func(c *websocket.Conn) { return func(c *websocket.Conn) { evaluator := application.TemplatesEvaluator() log.Debug().Msgf("WebSocket connection established with '%s'", c.RemoteAddr().String()) - - model := c.Query("model", "gpt-4o") - - intent := c.Query("intent") if intent != "transcription" { sendNotImplemented(c, "Only transcription mode is supported which requires the intent=transcription parameter") } diff --git a/core/http/endpoints/openai/transcription.go b/core/http/endpoints/openai/transcription.go index 1811e131f3d8..c5fc9f35261c 100644 --- a/core/http/endpoints/openai/transcription.go +++ b/core/http/endpoints/openai/transcription.go @@ -7,13 +7,13 @@ import ( "path" "path/filepath" + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" model "github.com/mudler/LocalAI/pkg/model" - "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" ) @@ -24,19 +24,19 @@ import ( // @Param file formData file true "file" // @Success 200 {object} map[string]string "Response" // @Router /v1/audio/transcriptions [post] -func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { - return func(c 
*fiber.Ctx) error { - input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) +func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { + input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) if !ok || input.Model == "" { - return fiber.ErrBadRequest + return echo.ErrBadRequest } - config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || config == nil { - return fiber.ErrBadRequest + return echo.ErrBadRequest } - diarize := c.FormValue("diarize", "false") != "false" + diarize := c.FormValue("diarize") != "false" // retrieve the file data from the request file, err := c.FormFile("file") @@ -76,6 +76,6 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app log.Debug().Msgf("Trascribed: %+v", tr) // TODO: handle different outputs here - return c.Status(http.StatusOK).JSON(tr) + return c.JSON(http.StatusOK, tr) } } diff --git a/core/http/endpoints/openai/video.go b/core/http/endpoints/openai/video.go index 598e9e32b3c7..12c06ffe61ac 100644 --- a/core/http/endpoints/openai/video.go +++ b/core/http/endpoints/openai/video.go @@ -6,7 +6,7 @@ import ( "strconv" "strings" - "github.com/gofiber/fiber/v2" + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/endpoints/localai" "github.com/mudler/LocalAI/core/http/middleware" @@ -14,20 +14,24 @@ import ( model "github.com/mudler/LocalAI/pkg/model" ) -func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) +func VideoEndpoint(cl 
*config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { + input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) if !ok || input == nil { - return fiber.ErrBadRequest + return echo.ErrBadRequest } var raw map[string]interface{} - if body := c.Body(); len(body) > 0 { + body := make([]byte, 0) + if c.Request().Body != nil { + c.Request().Body.Read(body) + } + if len(body) > 0 { _ = json.Unmarshal(body, &raw) } // Build VideoRequest using shared mapper vr := MapOpenAIToVideo(input, raw) - // Place VideoRequest into locals so localai.VideoEndpoint can consume it - c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, vr) + // Place VideoRequest into context so localai.VideoEndpoint can consume it + c.Set(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, vr) // Delegate to existing localai handler return localai.VideoEndpoint(cl, ml, appConfig)(c) } diff --git a/core/http/explorer.go b/core/http/explorer.go index e3001f3a2c20..f405934aa889 100644 --- a/core/http/explorer.go +++ b/core/http/explorer.go @@ -1,48 +1,50 @@ package http import ( + "io/fs" "net/http" - "github.com/gofiber/fiber/v2" - "github.com/gofiber/fiber/v2/middleware/favicon" - "github.com/gofiber/fiber/v2/middleware/filesystem" + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/explorer" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/routes" + "github.com/rs/zerolog/log" ) -func Explorer(db *explorer.Database) *fiber.App { - - fiberCfg := fiber.Config{ - Views: renderEngine(), - // We disable the Fiber startup message as it does not conform to structured logging. 
- // We register a startup log line with connection information in the OnListen hook to keep things user friendly though - DisableStartupMessage: false, - // Override default error handler +func Explorer(db *explorer.Database) *echo.Echo { + e := echo.New() + + // Set renderer + e.Renderer = renderEngine() + + // Hide banner + e.HideBanner = true + + e.Pre(middleware.StripPathPrefix()) + routes.RegisterExplorerRoutes(e, db) + + // Favicon handler + e.GET("/favicon.svg", func(c echo.Context) error { + data, err := embedDirStatic.ReadFile("static/favicon.svg") + if err != nil { + return c.NoContent(http.StatusNotFound) + } + c.Response().Header().Set("Content-Type", "image/svg+xml") + return c.Blob(http.StatusOK, "image/svg+xml", data) + }) + + // Static files - use fs.Sub to create a filesystem rooted at "static" + staticFS, err := fs.Sub(embedDirStatic, "static") + if err != nil { + // Log error but continue - static files might not work + log.Error().Err(err).Msg("failed to create static filesystem") + } else { + e.StaticFS("/static", staticFS) } - app := fiber.New(fiberCfg) - - app.Use(middleware.StripPathPrefix()) - routes.RegisterExplorerRoutes(app, db) - - httpFS := http.FS(embedDirStatic) - - app.Use(favicon.New(favicon.Config{ - URL: "/favicon.svg", - FileSystem: httpFS, - File: "static/favicon.svg", - })) - - app.Use("/static", filesystem.New(filesystem.Config{ - Root: httpFS, - PathPrefix: "static", - Browse: true, - })) - // Define a custom 404 handler // Note: keep this at the bottom! 
- app.Use(notFoundHandler) + e.GET("/*", notFoundHandler) - return app + return e } diff --git a/core/http/middleware/auth.go b/core/http/middleware/auth.go index ceb8b68d358e..0a392d24829a 100644 --- a/core/http/middleware/auth.go +++ b/core/http/middleware/auth.go @@ -3,50 +3,109 @@ package middleware import ( "crypto/subtle" "errors" + "net/http" "strings" - "github.com/dave-gray101/v2keyauth" - "github.com/gofiber/fiber/v2" - "github.com/gofiber/fiber/v2/middleware/keyauth" + "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/utils" "github.com/mudler/LocalAI/core/schema" ) -// This file contains the configuration generators and handler functions that are used along with the fiber/keyauth middleware -// Currently this requires an upstream patch - and feature patches are no longer accepted to v2 -// Therefore `dave-gray101/v2keyauth` contains the v2 backport of the middleware until v3 stabilizes and we migrate. 
+var ErrMissingOrMalformedAPIKey = errors.New("missing or malformed API Key") -func GetKeyAuthConfig(applicationConfig *config.ApplicationConfig) (*v2keyauth.Config, error) { - customLookup, err := v2keyauth.MultipleKeySourceLookup([]string{"header:Authorization", "header:x-api-key", "header:xi-api-key", "cookie:token"}, keyauth.ConfigDefault.AuthScheme) - if err != nil { - return nil, err - } +// GetKeyAuthConfig returns Echo's KeyAuth middleware configuration +func GetKeyAuthConfig(applicationConfig *config.ApplicationConfig) (echo.MiddlewareFunc, error) { + // Create validator function + validator := getApiKeyValidationFunction(applicationConfig) + + // Create error handler + errorHandler := getApiKeyErrorHandler(applicationConfig) + + // Create Next function (skip middleware for certain requests) + skipper := getApiKeyRequiredFilterFunction(applicationConfig) - return &v2keyauth.Config{ - CustomKeyLookup: customLookup, - Next: getApiKeyRequiredFilterFunction(applicationConfig), - Validator: getApiKeyValidationFunction(applicationConfig), - ErrorHandler: getApiKeyErrorHandler(applicationConfig), - AuthScheme: "Bearer", + // Wrap it with our custom key lookup that checks multiple sources + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if len(applicationConfig.ApiKeys) == 0 { + return next(c) + } + + // Skip if skipper says so + if skipper != nil && skipper(c) { + return next(c) + } + + // Try to extract key from multiple sources + key, err := extractKeyFromMultipleSources(c) + if err != nil { + return errorHandler(err, c) + } + + // Validate the key + valid, err := validator(key, c) + if err != nil || !valid { + return errorHandler(ErrMissingOrMalformedAPIKey, c) + } + + // Store key in context for later use + c.Set("api_key", key) + + return next(c) + } }, nil } -func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) fiber.ErrorHandler { - return func(ctx *fiber.Ctx, err error) error { - if 
errors.Is(err, v2keyauth.ErrMissingOrMalformedAPIKey) { +// extractKeyFromMultipleSources checks multiple sources for the API key +// in order: Authorization header, x-api-key header, xi-api-key header, token cookie +func extractKeyFromMultipleSources(c echo.Context) (string, error) { + // Check Authorization header first + auth := c.Request().Header.Get("Authorization") + if auth != "" { + // Check for Bearer scheme + if strings.HasPrefix(auth, "Bearer ") { + return strings.TrimPrefix(auth, "Bearer "), nil + } + // If no Bearer prefix, return as-is (for backward compatibility) + return auth, nil + } + + // Check x-api-key header + if key := c.Request().Header.Get("x-api-key"); key != "" { + return key, nil + } + + // Check xi-api-key header + if key := c.Request().Header.Get("xi-api-key"); key != "" { + return key, nil + } + + // Check token cookie + cookie, err := c.Cookie("token") + if err == nil && cookie != nil && cookie.Value != "" { + return cookie.Value, nil + } + + return "", ErrMissingOrMalformedAPIKey +} + +func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) func(error, echo.Context) error { + return func(err error, c echo.Context) error { + if errors.Is(err, ErrMissingOrMalformedAPIKey) { if len(applicationConfig.ApiKeys) == 0 { - return ctx.Next() // if no keys are set up, any error we get here is not an error. + return nil // if no keys are set up, any error we get here is not an error. 
} - ctx.Set("WWW-Authenticate", "Bearer") + c.Response().Header().Set("WWW-Authenticate", "Bearer") if applicationConfig.OpaqueErrors { - return ctx.SendStatus(401) + return c.NoContent(http.StatusUnauthorized) } // Check if the request content type is JSON - contentType := string(ctx.Context().Request.Header.ContentType()) + contentType := c.Request().Header.Get("Content-Type") if strings.Contains(contentType, "application/json") { - return ctx.Status(401).JSON(schema.ErrorResponse{ + return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{ Error: &schema.APIError{ Message: "An authentication key is required", Code: 401, @@ -55,50 +114,69 @@ func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) fiber.Er }) } - return ctx.Status(401).Render("views/login", fiber.Map{ - "BaseURL": utils.BaseURL(ctx), + return c.Render(http.StatusUnauthorized, "views/login", map[string]interface{}{ + "BaseURL": utils.BaseURL(c), }) } if applicationConfig.OpaqueErrors { - return ctx.SendStatus(500) + return c.NoContent(http.StatusInternalServerError) } return err } } -func getApiKeyValidationFunction(applicationConfig *config.ApplicationConfig) func(*fiber.Ctx, string) (bool, error) { - +func getApiKeyValidationFunction(applicationConfig *config.ApplicationConfig) func(string, echo.Context) (bool, error) { if applicationConfig.UseSubtleKeyComparison { - return func(ctx *fiber.Ctx, apiKey string) (bool, error) { + return func(key string, c echo.Context) (bool, error) { if len(applicationConfig.ApiKeys) == 0 { return true, nil // If no keys are setup, accept everything } for _, validKey := range applicationConfig.ApiKeys { - if subtle.ConstantTimeCompare([]byte(apiKey), []byte(validKey)) == 1 { + if subtle.ConstantTimeCompare([]byte(key), []byte(validKey)) == 1 { return true, nil } } - return false, v2keyauth.ErrMissingOrMalformedAPIKey + return false, ErrMissingOrMalformedAPIKey } } - return func(ctx *fiber.Ctx, apiKey string) (bool, error) { + return func(key 
string, c echo.Context) (bool, error) { if len(applicationConfig.ApiKeys) == 0 { return true, nil // If no keys are setup, accept everything } for _, validKey := range applicationConfig.ApiKeys { - if apiKey == validKey { + if key == validKey { return true, nil } } - return false, v2keyauth.ErrMissingOrMalformedAPIKey + return false, ErrMissingOrMalformedAPIKey } } -func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig) func(*fiber.Ctx) bool { - if applicationConfig.DisableApiKeyRequirementForHttpGet { - return func(c *fiber.Ctx) bool { - if c.Method() != "GET" { +func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig) middleware.Skipper { + return func(c echo.Context) bool { + path := c.Request().URL.Path + + // Always skip authentication for static files + if strings.HasPrefix(path, "/static/") { + return true + } + + // Always skip authentication for generated content + if strings.HasPrefix(path, "/generated-audio/") || + strings.HasPrefix(path, "/generated-images/") || + strings.HasPrefix(path, "/generated-videos/") { + return true + } + + // Skip authentication for favicon + if path == "/favicon.svg" { + return true + } + + // Handle GET request exemptions if enabled + if applicationConfig.DisableApiKeyRequirementForHttpGet { + if c.Request().Method != http.MethodGet { return false } for _, rx := range applicationConfig.HttpGetExemptedEndpoints { @@ -106,8 +184,8 @@ func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig return true } } - return false } + + return false } - return func(c *fiber.Ctx) bool { return false } } diff --git a/core/http/middleware/request.go b/core/http/middleware/request.go index 06647ea57891..362feadc1677 100644 --- a/core/http/middleware/request.go +++ b/core/http/middleware/request.go @@ -1,470 +1,482 @@ -package middleware - -import ( - "context" - "encoding/json" - "fmt" - "strconv" - "strings" - - "github.com/google/uuid" - 
"github.com/mudler/LocalAI/core/config" - "github.com/mudler/LocalAI/core/schema" - "github.com/mudler/LocalAI/core/services" - "github.com/mudler/LocalAI/core/templates" - "github.com/mudler/LocalAI/pkg/functions" - "github.com/mudler/LocalAI/pkg/model" - "github.com/mudler/LocalAI/pkg/utils" - "github.com/valyala/fasthttp" - - "github.com/gofiber/fiber/v2" - "github.com/rs/zerolog/log" -) - -type correlationIDKeyType string - -// CorrelationIDKey to track request across process boundary -const CorrelationIDKey correlationIDKeyType = "correlationID" - -type RequestExtractor struct { - modelConfigLoader *config.ModelConfigLoader - modelLoader *model.ModelLoader - applicationConfig *config.ApplicationConfig -} - -func NewRequestExtractor(modelConfigLoader *config.ModelConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor { - return &RequestExtractor{ - modelConfigLoader: modelConfigLoader, - modelLoader: modelLoader, - applicationConfig: applicationConfig, - } -} - -const CONTEXT_LOCALS_KEY_MODEL_NAME = "MODEL_NAME" -const CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST" -const CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG" - -// TODO: Refactor to not return error if unchanged -func (re *RequestExtractor) setModelNameFromRequest(ctx *fiber.Ctx) { - model, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) - if ok && model != "" { - return - } - model = ctx.Params("model") - - if (model == "") && ctx.Query("model") != "" { - model = ctx.Query("model") - } - - if model == "" { - // Set model from bearer token, if available - bearer := strings.TrimLeft(ctx.Get("authorization"), "Bear ") // "Bearer " => "Bear" to please go-staticcheck. It looks dumb but we might as well take free performance on something called for nearly every request. 
- if bearer != "" { - exists, err := services.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE) - if err == nil && exists { - model = bearer - } - } - } - - ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, model) -} - -func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) fiber.Handler { - return func(ctx *fiber.Ctx) error { - re.setModelNameFromRequest(ctx) - localModelName, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) - if !ok || localModelName == "" { - ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName) - log.Debug().Str("defaultModelName", defaultModelName).Msg("context local model name not found, setting to default") - } - return ctx.Next() - } -} - -func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) fiber.Handler { - return func(ctx *fiber.Ctx) error { - re.setModelNameFromRequest(ctx) - localModelName := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) - if localModelName != "" { // Don't overwrite existing values - return ctx.Next() - } - - modelNames, err := services.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED) - if err != nil { - log.Error().Err(err).Msg("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()") - return ctx.Next() - } - - if len(modelNames) == 0 { - log.Warn().Msg("SetDefaultModelNameToFirstAvailable used with no matching models installed") - // This is non-fatal - making it so was breaking the case of direct installation of raw models - // return errors.New("this endpoint requires at least one model to be installed") - return ctx.Next() - } - - ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0]) - log.Debug().Str("first model name", modelNames[0]).Msg("context local model name not found, setting to the first model") - return ctx.Next() - } -} - -// TODO: If context and cancel above belong on all methods, move that part 
of above into here! -// Otherwise, it's in its own method below for now -func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) fiber.Handler { - return func(ctx *fiber.Ctx) error { - input := initializer() - if input == nil { - return fmt.Errorf("unable to initialize body") - } - if err := ctx.BodyParser(input); err != nil { - return fmt.Errorf("failed parsing request body: %w", err) - } - - // If this request doesn't have an associated model name, fetch it from earlier in the middleware chain - if input.ModelName(nil) == "" { - localModelName, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) - if ok && localModelName != "" { - log.Debug().Str("context localModelName", localModelName).Msg("overriding empty model name in request body with value found earlier in middleware chain") - input.ModelName(&localModelName) - } - } - - cfg, err := re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig) - - if err != nil { - log.Err(err) - log.Warn().Msgf("Model Configuration File not found for %q", input.ModelName(nil)) - } else if cfg.Model == "" && input.ModelName(nil) != "" { - log.Debug().Str("input.ModelName", input.ModelName(nil)).Msg("config does not include model, using input") - cfg.Model = input.ModelName(nil) - } - - ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) - ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) - - return ctx.Next() - } -} - -func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error { - input, ok := ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) - if !ok || input.Model == "" { - return fiber.ErrBadRequest - } - - cfg, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) - if !ok || cfg == nil { - return fiber.ErrBadRequest - } - - // Extract or generate the correlation ID - correlationID := ctx.Get("X-Correlation-ID", uuid.New().String()) - ctx.Set("X-Correlation-ID", correlationID) - - //c1, cancel 
:= context.WithCancel(re.applicationConfig.Context) - // Use the application context as parent to ensure cancellation on app shutdown - // We'll monitor the Fiber context separately and cancel our context when the request is canceled - c1, cancel := context.WithCancel(re.applicationConfig.Context) - // Monitor the Fiber context and cancel our context when it's canceled - // This ensures we respect request cancellation without causing panics - go func(fiberCtx *fasthttp.RequestCtx) { - if fiberCtx != nil { - <-fiberCtx.Done() - cancel() - } - }(ctx.Context()) - // Add the correlation ID to the new context - ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID) - - input.Context = ctxWithCorrelationID - input.Cancel = cancel - - err := mergeOpenAIRequestAndModelConfig(cfg, input) - if err != nil { - return err - } - - if cfg.Model == "" { - log.Debug().Str("input.Model", input.Model).Msg("replacing empty cfg.Model with input value") - cfg.Model = input.Model - } - - ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) - ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) - - return ctx.Next() -} - -func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error { - if input.Echo { - config.Echo = input.Echo - } - if input.TopK != nil { - config.TopK = input.TopK - } - if input.TopP != nil { - config.TopP = input.TopP - } - - if input.Backend != "" { - config.Backend = input.Backend - } - - if input.ClipSkip != 0 { - config.Diffusers.ClipSkip = input.ClipSkip - } - - if input.NegativePromptScale != 0 { - config.NegativePromptScale = input.NegativePromptScale - } - - if input.NegativePrompt != "" { - config.NegativePrompt = input.NegativePrompt - } - - if input.RopeFreqBase != 0 { - config.RopeFreqBase = input.RopeFreqBase - } - - if input.RopeFreqScale != 0 { - config.RopeFreqScale = input.RopeFreqScale - } - - if input.Grammar != "" { - config.Grammar = input.Grammar - } - - if input.Temperature != nil { - 
config.Temperature = input.Temperature - } - - if input.Maxtokens != nil { - config.Maxtokens = input.Maxtokens - } - - if input.ResponseFormat != nil { - switch responseFormat := input.ResponseFormat.(type) { - case string: - config.ResponseFormat = responseFormat - case map[string]interface{}: - config.ResponseFormatMap = responseFormat - } - } - - switch stop := input.Stop.(type) { - case string: - if stop != "" { - config.StopWords = append(config.StopWords, stop) - } - case []interface{}: - for _, pp := range stop { - if s, ok := pp.(string); ok { - config.StopWords = append(config.StopWords, s) - } - } - } - - if len(input.Tools) > 0 { - for _, tool := range input.Tools { - input.Functions = append(input.Functions, tool.Function) - } - } - - if input.ToolsChoice != nil { - var toolChoice functions.Tool - - switch content := input.ToolsChoice.(type) { - case string: - _ = json.Unmarshal([]byte(content), &toolChoice) - case map[string]interface{}: - dat, _ := json.Marshal(content) - _ = json.Unmarshal(dat, &toolChoice) - } - input.FunctionCall = map[string]interface{}{ - "name": toolChoice.Function.Name, - } - } - - // Decode each request's message content - imgIndex, vidIndex, audioIndex := 0, 0, 0 - for i, m := range input.Messages { - nrOfImgsInMessage := 0 - nrOfVideosInMessage := 0 - nrOfAudiosInMessage := 0 - - switch content := m.Content.(type) { - case string: - input.Messages[i].StringContent = content - case []interface{}: - dat, _ := json.Marshal(content) - c := []schema.Content{} - json.Unmarshal(dat, &c) - - textContent := "" - // we will template this at the end - - CONTENT: - for _, pp := range c { - switch pp.Type { - case "text": - textContent += pp.Text - //input.Messages[i].StringContent = pp.Text - case "video", "video_url": - // Decode content as base64 either if it's an URL or base64 text - base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL) - if err != nil { - log.Error().Msgf("Failed encoding video: %s", err) - continue CONTENT - 
} - input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff - vidIndex++ - nrOfVideosInMessage++ - case "audio_url", "audio": - // Decode content as base64 either if it's an URL or base64 text - base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL) - if err != nil { - log.Error().Msgf("Failed encoding audio: %s", err) - continue CONTENT - } - input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff - audioIndex++ - nrOfAudiosInMessage++ - case "input_audio": - // TODO: make sure that we only return base64 stuff - input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, pp.InputAudio.Data) - audioIndex++ - nrOfAudiosInMessage++ - case "image_url", "image": - // Decode content as base64 either if it's an URL or base64 text - base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL) - if err != nil { - log.Error().Msgf("Failed encoding image: %s", err) - continue CONTENT - } - - input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff - - imgIndex++ - nrOfImgsInMessage++ - } - } - - input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{ - TotalImages: imgIndex, - TotalVideos: vidIndex, - TotalAudios: audioIndex, - ImagesInMessage: nrOfImgsInMessage, - VideosInMessage: nrOfVideosInMessage, - AudiosInMessage: nrOfAudiosInMessage, - }, textContent) - } - } - - if input.RepeatPenalty != 0 { - config.RepeatPenalty = input.RepeatPenalty - } - - if input.FrequencyPenalty != 0 { - config.FrequencyPenalty = input.FrequencyPenalty - } - - if input.PresencePenalty != 0 { - config.PresencePenalty = input.PresencePenalty - } - - if input.Keep != 0 { - config.Keep = input.Keep - } - - if input.Batch != 0 { - config.Batch = input.Batch - } - - if input.IgnoreEOS { - 
config.IgnoreEOS = input.IgnoreEOS - } - - if input.Seed != nil { - config.Seed = input.Seed - } - - if input.TypicalP != nil { - config.TypicalP = input.TypicalP - } - - log.Debug().Str("input.Input", fmt.Sprintf("%+v", input.Input)) - - switch inputs := input.Input.(type) { - case string: - if inputs != "" { - config.InputStrings = append(config.InputStrings, inputs) - } - case []any: - for _, pp := range inputs { - switch i := pp.(type) { - case string: - config.InputStrings = append(config.InputStrings, i) - case []any: - tokens := []int{} - inputStrings := []string{} - for _, ii := range i { - switch ii := ii.(type) { - case int: - tokens = append(tokens, ii) - case float64: - tokens = append(tokens, int(ii)) - case string: - inputStrings = append(inputStrings, ii) - default: - log.Error().Msgf("Unknown input type: %T", ii) - } - } - config.InputToken = append(config.InputToken, tokens) - config.InputStrings = append(config.InputStrings, inputStrings...) - } - } - } - - // Can be either a string or an object - switch fnc := input.FunctionCall.(type) { - case string: - if fnc != "" { - config.SetFunctionCallString(fnc) - } - case map[string]interface{}: - var name string - n, exists := fnc["name"] - if exists { - nn, e := n.(string) - if e { - name = nn - } - } - config.SetFunctionCallNameString(name) - } - - switch p := input.Prompt.(type) { - case string: - config.PromptStrings = append(config.PromptStrings, p) - case []interface{}: - for _, pp := range p { - if s, ok := pp.(string); ok { - config.PromptStrings = append(config.PromptStrings, s) - } - } - } - - // If a quality was defined as number, convert it to step - if input.Quality != "" { - q, err := strconv.Atoi(input.Quality) - if err == nil { - config.Step = q - } - } - - if config.Validate() { - return nil - } - return fmt.Errorf("unable to validate configuration after merging") -} +package middleware + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" + + 
"github.com/google/uuid" + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/services" + "github.com/mudler/LocalAI/core/templates" + "github.com/mudler/LocalAI/pkg/functions" + "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/utils" + "github.com/rs/zerolog/log" +) + +type correlationIDKeyType string + +// CorrelationIDKey to track request across process boundary +const CorrelationIDKey correlationIDKeyType = "correlationID" + +type RequestExtractor struct { + modelConfigLoader *config.ModelConfigLoader + modelLoader *model.ModelLoader + applicationConfig *config.ApplicationConfig +} + +func NewRequestExtractor(modelConfigLoader *config.ModelConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor { + return &RequestExtractor{ + modelConfigLoader: modelConfigLoader, + modelLoader: modelLoader, + applicationConfig: applicationConfig, + } +} + +const CONTEXT_LOCALS_KEY_MODEL_NAME = "MODEL_NAME" +const CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST" +const CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG" + +// TODO: Refactor to not return error if unchanged +func (re *RequestExtractor) setModelNameFromRequest(c echo.Context) { + model, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) + if ok && model != "" { + return + } + model = c.Param("model") + + if model == "" { + model = c.QueryParam("model") + } + + if model == "" { + // Set model from bearer token, if available + auth := c.Request().Header.Get("Authorization") + bearer := strings.TrimPrefix(auth, "Bearer ") + if bearer != "" && bearer != auth { + exists, err := services.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE) + if err == nil && exists { + model = bearer + } + } + } + + c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, model) +} + +func (re *RequestExtractor) 
BuildConstantDefaultModelNameMiddleware(defaultModelName string) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + re.setModelNameFromRequest(c) + localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) + if !ok || localModelName == "" { + c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName) + log.Debug().Str("defaultModelName", defaultModelName).Msg("context local model name not found, setting to default") + } + return next(c) + } + } +} + +func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + re.setModelNameFromRequest(c) + localModelName := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) + if localModelName != "" { // Don't overwrite existing values + return next(c) + } + + modelNames, err := services.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED) + if err != nil { + log.Error().Err(err).Msg("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()") + return next(c) + } + + if len(modelNames) == 0 { + log.Warn().Msg("SetDefaultModelNameToFirstAvailable used with no matching models installed") + // This is non-fatal - making it so was breaking the case of direct installation of raw models + // return errors.New("this endpoint requires at least one model to be installed") + return next(c) + } + + c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0]) + log.Debug().Str("first model name", modelNames[0]).Msg("context local model name not found, setting to the first model") + return next(c) + } + } +} + +// TODO: If context and cancel above belong on all methods, move that part of above into here! 
+// Otherwise, it's in its own method below for now +func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + input := initializer() + if input == nil { + return echo.NewHTTPError(http.StatusBadRequest, "unable to initialize body") + } + if err := c.Bind(input); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed parsing request body: %v", err)) + } + + // If this request doesn't have an associated model name, fetch it from earlier in the middleware chain + if input.ModelName(nil) == "" { + localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) + if ok && localModelName != "" { + log.Debug().Str("context localModelName", localModelName).Msg("overriding empty model name in request body with value found earlier in middleware chain") + input.ModelName(&localModelName) + } + } + + cfg, err := re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig) + + if err != nil { + log.Err(err) + log.Warn().Msgf("Model Configuration File not found for %q", input.ModelName(nil)) + } else if cfg.Model == "" && input.ModelName(nil) != "" { + log.Debug().Str("input.ModelName", input.ModelName(nil)).Msg("config does not include model, using input") + cfg.Model = input.ModelName(nil) + } + + c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) + c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) + + return next(c) + } + } +} + +func (re *RequestExtractor) SetOpenAIRequest(c echo.Context) error { + input, ok := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) + if !ok || input.Model == "" { + return echo.ErrBadRequest + } + + cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + if !ok || cfg == nil { + return echo.ErrBadRequest + } + + // Extract or generate the correlation ID + correlationID := 
c.Request().Header.Get("X-Correlation-ID") + if correlationID == "" { + correlationID = uuid.New().String() + } + c.Response().Header().Set("X-Correlation-ID", correlationID) + + // Use the request context directly - Echo properly supports context cancellation! + // No need for workarounds like handleConnectionCancellation + reqCtx := c.Request().Context() + c1, cancel := context.WithCancel(re.applicationConfig.Context) + + // Cancel when request context is cancelled (client disconnects) + go func() { + select { + case <-reqCtx.Done(): + cancel() + case <-c1.Done(): + // Already cancelled + } + }() + + // Add the correlation ID to the new context + ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID) + + input.Context = ctxWithCorrelationID + input.Cancel = cancel + + err := mergeOpenAIRequestAndModelConfig(cfg, input) + if err != nil { + return err + } + + if cfg.Model == "" { + log.Debug().Str("input.Model", input.Model).Msg("replacing empty cfg.Model with input value") + cfg.Model = input.Model + } + + c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) + c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) + + return nil +} + +func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error { + if input.Echo { + config.Echo = input.Echo + } + if input.TopK != nil { + config.TopK = input.TopK + } + if input.TopP != nil { + config.TopP = input.TopP + } + + if input.Backend != "" { + config.Backend = input.Backend + } + + if input.ClipSkip != 0 { + config.Diffusers.ClipSkip = input.ClipSkip + } + + if input.NegativePromptScale != 0 { + config.NegativePromptScale = input.NegativePromptScale + } + + if input.NegativePrompt != "" { + config.NegativePrompt = input.NegativePrompt + } + + if input.RopeFreqBase != 0 { + config.RopeFreqBase = input.RopeFreqBase + } + + if input.RopeFreqScale != 0 { + config.RopeFreqScale = input.RopeFreqScale + } + + if input.Grammar != "" { + config.Grammar = input.Grammar + } + + if 
input.Temperature != nil { + config.Temperature = input.Temperature + } + + if input.Maxtokens != nil { + config.Maxtokens = input.Maxtokens + } + + if input.ResponseFormat != nil { + switch responseFormat := input.ResponseFormat.(type) { + case string: + config.ResponseFormat = responseFormat + case map[string]interface{}: + config.ResponseFormatMap = responseFormat + } + } + + switch stop := input.Stop.(type) { + case string: + if stop != "" { + config.StopWords = append(config.StopWords, stop) + } + case []interface{}: + for _, pp := range stop { + if s, ok := pp.(string); ok { + config.StopWords = append(config.StopWords, s) + } + } + } + + if len(input.Tools) > 0 { + for _, tool := range input.Tools { + input.Functions = append(input.Functions, tool.Function) + } + } + + if input.ToolsChoice != nil { + var toolChoice functions.Tool + + switch content := input.ToolsChoice.(type) { + case string: + _ = json.Unmarshal([]byte(content), &toolChoice) + case map[string]interface{}: + dat, _ := json.Marshal(content) + _ = json.Unmarshal(dat, &toolChoice) + } + input.FunctionCall = map[string]interface{}{ + "name": toolChoice.Function.Name, + } + } + + // Decode each request's message content + imgIndex, vidIndex, audioIndex := 0, 0, 0 + for i, m := range input.Messages { + nrOfImgsInMessage := 0 + nrOfVideosInMessage := 0 + nrOfAudiosInMessage := 0 + + switch content := m.Content.(type) { + case string: + input.Messages[i].StringContent = content + case []interface{}: + dat, _ := json.Marshal(content) + c := []schema.Content{} + json.Unmarshal(dat, &c) + + textContent := "" + // we will template this at the end + + CONTENT: + for _, pp := range c { + switch pp.Type { + case "text": + textContent += pp.Text + //input.Messages[i].StringContent = pp.Text + case "video", "video_url": + // Decode content as base64 either if it's an URL or base64 text + base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL) + if err != nil { + log.Error().Msgf("Failed encoding video: 
%s", err) + continue CONTENT + } + input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff + vidIndex++ + nrOfVideosInMessage++ + case "audio_url", "audio": + // Decode content as base64 either if it's an URL or base64 text + base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL) + if err != nil { + log.Error().Msgf("Failed encoding audio: %s", err) + continue CONTENT + } + input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff + audioIndex++ + nrOfAudiosInMessage++ + case "input_audio": + // TODO: make sure that we only return base64 stuff + input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, pp.InputAudio.Data) + audioIndex++ + nrOfAudiosInMessage++ + case "image_url", "image": + // Decode content as base64 either if it's an URL or base64 text + base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL) + if err != nil { + log.Error().Msgf("Failed encoding image: %s", err) + continue CONTENT + } + + input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff + + imgIndex++ + nrOfImgsInMessage++ + } + } + + input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{ + TotalImages: imgIndex, + TotalVideos: vidIndex, + TotalAudios: audioIndex, + ImagesInMessage: nrOfImgsInMessage, + VideosInMessage: nrOfVideosInMessage, + AudiosInMessage: nrOfAudiosInMessage, + }, textContent) + } + } + + if input.RepeatPenalty != 0 { + config.RepeatPenalty = input.RepeatPenalty + } + + if input.FrequencyPenalty != 0 { + config.FrequencyPenalty = input.FrequencyPenalty + } + + if input.PresencePenalty != 0 { + config.PresencePenalty = input.PresencePenalty + } + + if input.Keep != 0 { + config.Keep = input.Keep + } + + if input.Batch != 0 { + config.Batch = input.Batch 
+ } + + if input.IgnoreEOS { + config.IgnoreEOS = input.IgnoreEOS + } + + if input.Seed != nil { + config.Seed = input.Seed + } + + if input.TypicalP != nil { + config.TypicalP = input.TypicalP + } + + log.Debug().Str("input.Input", fmt.Sprintf("%+v", input.Input)) + + switch inputs := input.Input.(type) { + case string: + if inputs != "" { + config.InputStrings = append(config.InputStrings, inputs) + } + case []any: + for _, pp := range inputs { + switch i := pp.(type) { + case string: + config.InputStrings = append(config.InputStrings, i) + case []any: + tokens := []int{} + inputStrings := []string{} + for _, ii := range i { + switch ii := ii.(type) { + case int: + tokens = append(tokens, ii) + case float64: + tokens = append(tokens, int(ii)) + case string: + inputStrings = append(inputStrings, ii) + default: + log.Error().Msgf("Unknown input type: %T", ii) + } + } + config.InputToken = append(config.InputToken, tokens) + config.InputStrings = append(config.InputStrings, inputStrings...) 
+ } + } + } + + // Can be either a string or an object + switch fnc := input.FunctionCall.(type) { + case string: + if fnc != "" { + config.SetFunctionCallString(fnc) + } + case map[string]interface{}: + var name string + n, exists := fnc["name"] + if exists { + nn, e := n.(string) + if e { + name = nn + } + } + config.SetFunctionCallNameString(name) + } + + switch p := input.Prompt.(type) { + case string: + config.PromptStrings = append(config.PromptStrings, p) + case []interface{}: + for _, pp := range p { + if s, ok := pp.(string); ok { + config.PromptStrings = append(config.PromptStrings, s) + } + } + } + + // If a quality was defined as number, convert it to step + if input.Quality != "" { + q, err := strconv.Atoi(input.Quality) + if err == nil { + config.Step = q + } + } + + if config.Validate() { + return nil + } + return fmt.Errorf("unable to validate configuration after merging") +} diff --git a/core/http/middleware/strippathprefix.go b/core/http/middleware/strippathprefix.go index 5c45d55d3645..9de3c05d2f76 100644 --- a/core/http/middleware/strippathprefix.go +++ b/core/http/middleware/strippathprefix.go @@ -3,34 +3,42 @@ package middleware import ( "strings" - "github.com/gofiber/fiber/v2" + "github.com/labstack/echo/v4" ) -// StripPathPrefix returns a middleware that strips a path prefix from the request path. +// StripPathPrefix returns middleware that strips a path prefix from the request path. // The path prefix is obtained from the X-Forwarded-Prefix HTTP request header. -func StripPathPrefix() fiber.Handler { - return func(c *fiber.Ctx) error { - for _, prefix := range c.GetReqHeaders()["X-Forwarded-Prefix"] { - if prefix != "" { - path := c.Path() - pos := len(prefix) +// This must be registered as Pre middleware (using e.Pre()) to modify the path before routing. 
+func StripPathPrefix() echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + prefixes := c.Request().Header.Values("X-Forwarded-Prefix") + originalPath := c.Request().URL.Path - if prefix[pos-1] == '/' { - pos-- - } else { - prefix += "/" - } + for _, prefix := range prefixes { + if prefix != "" { + normalizedPrefix := prefix + if !strings.HasSuffix(prefix, "/") { + normalizedPrefix = prefix + "/" + } - if strings.HasPrefix(path, prefix) { - c.Path(path[pos:]) - break - } else if prefix[:pos] == path { - c.Redirect(prefix) - return nil + if strings.HasPrefix(originalPath, normalizedPrefix) { + // Update the request path by stripping the normalized prefix + c.Request().URL.Path = originalPath[len(normalizedPrefix):] + if c.Request().URL.Path == "" { + c.Request().URL.Path = "/" + } + // Store original path for BaseURL utility + c.Set("_original_path", originalPath) + break + } else if originalPath == prefix || originalPath == prefix+"/" { + // Redirect to prefix with trailing slash (use 302 to match test expectations) + return c.Redirect(302, normalizedPrefix) + } } } - } - return c.Next() + return next(c) + } } } diff --git a/core/http/middleware/strippathprefix_test.go b/core/http/middleware/strippathprefix_test.go index 529f815f71c0..a6b9fd431cae 100644 --- a/core/http/middleware/strippathprefix_test.go +++ b/core/http/middleware/strippathprefix_test.go @@ -4,24 +4,24 @@ import ( "net/http/httptest" "testing" - "github.com/gofiber/fiber/v2" + "github.com/labstack/echo/v4" "github.com/stretchr/testify/require" ) func TestStripPathPrefix(t *testing.T) { var actualPath string - app := fiber.New() + app := echo.New() - app.Use(StripPathPrefix()) + app.Pre(StripPathPrefix()) - app.Get("/hello/world", func(c *fiber.Ctx) error { - actualPath = c.Path() + app.GET("/hello/world", func(c echo.Context) error { + actualPath = c.Request().URL.Path return nil }) - app.Get("/", func(c *fiber.Ctx) error { - 
actualPath = c.Path() + app.GET("/", func(c echo.Context) error { + actualPath = c.Request().URL.Path return nil }) @@ -106,15 +106,15 @@ func TestStripPathPrefix(t *testing.T) { req.Header["X-Forwarded-Prefix"] = tc.prefixHeader } - resp, err := app.Test(req, -1) + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) - require.NoError(t, err) - require.Equal(t, tc.expectStatus, resp.StatusCode, "response status code") + require.Equal(t, tc.expectStatus, rec.Code, "response status code") if tc.expectStatus == 200 { require.Equal(t, tc.expectPath, actualPath, "rewritten path") } else if tc.expectStatus == 302 { - require.Equal(t, tc.expectPath, resp.Header.Get("Location"), "redirect location") + require.Equal(t, tc.expectPath, rec.Header().Get("Location"), "redirect location") } }) } diff --git a/core/http/openai_videos_test.go b/core/http/openai_videos_test.go index ef8168c2d2cd..60faada8f5ca 100644 --- a/core/http/openai_videos_test.go +++ b/core/http/openai_videos_test.go @@ -17,7 +17,7 @@ import ( pb "github.com/mudler/LocalAI/pkg/grpc/proto" "fmt" . "github.com/mudler/LocalAI/core/http" - "github.com/gofiber/fiber/v2" + "github.com/labstack/echo/v4" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" @@ -62,7 +62,7 @@ func (f *fakeAI) VAD(*pb.VADRequest) (pb.VADResponse, error) { return pb.VADResp var _ = Describe("OpenAI /v1/videos (embedded backend)", func() { var tmpdir string var appServer *application.Application - var app *fiber.App + var app *echo.Echo var ctx context.Context var cancel context.CancelFunc @@ -97,7 +97,9 @@ var _ = Describe("OpenAI /v1/videos (embedded backend)", func() { AfterEach(func() { cancel() if app != nil { - _ = app.Shutdown() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = app.Shutdown(ctx) } _ = os.RemoveAll(tmpdir) }) @@ -106,7 +108,11 @@ var _ = Describe("OpenAI /v1/videos (embedded backend)", func() { var err error app, err = API(appServer) Expect(err).ToNot(HaveOccurred()) - go app.Listen("127.0.0.1:9091") + go func() { + if err := app.Start("127.0.0.1:9091"); err != nil && err != http.ErrServerClosed { + // Log error if needed + } + }() // wait for server client := &http.Client{Timeout: 5 * time.Second} diff --git a/core/http/render.go b/core/http/render.go index 2f889f57e177..14f9884af4ae 100644 --- a/core/http/render.go +++ b/core/http/render.go @@ -4,11 +4,13 @@ import ( "embed" "fmt" "html/template" + "io" + "io/fs" "net/http" + "strings" "github.com/Masterminds/sprig/v3" - "github.com/gofiber/fiber/v2" - fiberhtml "github.com/gofiber/template/html/v2" + "github.com/labstack/echo/v4" "github.com/microcosm-cc/bluemonday" "github.com/mudler/LocalAI/core/http/utils" "github.com/mudler/LocalAI/core/schema" @@ -18,26 +20,67 @@ import ( //go:embed views/* var viewsfs embed.FS -func notFoundHandler(c *fiber.Ctx) error { +// TemplateRenderer is a custom template renderer for Echo +type TemplateRenderer struct { + templates *template.Template +} + +// Render renders a template document +func (t *TemplateRenderer) Render(w io.Writer, name string, data interface{}, c echo.Context) error { + return t.templates.ExecuteTemplate(w, name, data) +} + +func 
notFoundHandler(c echo.Context) error { // Check if the request accepts JSON - if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 { + contentType := c.Request().Header.Get("Content-Type") + accept := c.Request().Header.Get("Accept") + if strings.Contains(contentType, "application/json") || !strings.Contains(accept, "text/html") { // The client expects a JSON response - return c.Status(fiber.StatusNotFound).JSON(schema.ErrorResponse{ - Error: &schema.APIError{Message: "Resource not found", Code: fiber.StatusNotFound}, + return c.JSON(http.StatusNotFound, schema.ErrorResponse{ + Error: &schema.APIError{Message: "Resource not found", Code: http.StatusNotFound}, }) } else { // The client expects an HTML response - return c.Status(fiber.StatusNotFound).Render("views/404", fiber.Map{ + return c.Render(http.StatusNotFound, "views/404", map[string]interface{}{ "BaseURL": utils.BaseURL(c), }) } } -func renderEngine() *fiberhtml.Engine { - engine := fiberhtml.NewFileSystem(http.FS(viewsfs), ".html") - engine.AddFuncMap(sprig.FuncMap()) - engine.AddFunc("MDToHTML", markDowner) - return engine +func renderEngine() *TemplateRenderer { + // Parse all templates from embedded filesystem + tmpl := template.New("").Funcs(sprig.FuncMap()) + tmpl = tmpl.Funcs(template.FuncMap{ + "MDToHTML": markDowner, + }) + + // Recursively walk through embedded filesystem and parse all HTML templates + err := fs.WalkDir(viewsfs, "views", func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if !d.IsDir() && strings.HasSuffix(path, ".html") { + data, err := viewsfs.ReadFile(path) + if err == nil { + // Remove .html extension to get template name (e.g., "views/index.html" -> "views/index") + templateName := strings.TrimSuffix(path, ".html") + _, err := tmpl.New(templateName).Parse(string(data)) + if err != nil { + // If parsing fails, try parsing without explicit name (for templates with {{define}}) + 
tmpl.Parse(string(data)) + } + } + } + return nil + }) + if err != nil { + // Log error but continue - templates might still work + fmt.Printf("Error walking views directory: %v\n", err) + } + + return &TemplateRenderer{ + templates: tmpl, + } } func markDowner(args ...interface{}) template.HTML { diff --git a/core/http/routes/elevenlabs.go b/core/http/routes/elevenlabs.go index 96e132e94e86..90f73eec6417 100644 --- a/core/http/routes/elevenlabs.go +++ b/core/http/routes/elevenlabs.go @@ -1,7 +1,7 @@ package routes import ( - "github.com/gofiber/fiber/v2" + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/endpoints/elevenlabs" "github.com/mudler/LocalAI/core/http/middleware" @@ -9,21 +9,23 @@ import ( "github.com/mudler/LocalAI/pkg/model" ) -func RegisterElevenLabsRoutes(app *fiber.App, +func RegisterElevenLabsRoutes(app *echo.Echo, re *middleware.RequestExtractor, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) { // Elevenlabs - app.Post("/v1/text-to-speech/:voice-id", + ttsHandler := elevenlabs.TTSEndpoint(cl, ml, appConfig) + app.POST("/v1/text-to-speech/:voice-id", + ttsHandler, re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)), - re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsTTSRequest) }), - elevenlabs.TTSEndpoint(cl, ml, appConfig)) + re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsTTSRequest) })) - app.Post("/v1/sound-generation", + soundGenHandler := elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig) + app.POST("/v1/sound-generation", + soundGenHandler, re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_SOUND_GENERATION)), - re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsSoundGenerationRequest) }), - elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig)) + re.SetModelAndConfig(func() 
schema.LocalAIRequest { return new(schema.ElevenLabsSoundGenerationRequest) })) } diff --git a/core/http/routes/explorer.go b/core/http/routes/explorer.go index 960b476b8ffc..670bf67c42fd 100644 --- a/core/http/routes/explorer.go +++ b/core/http/routes/explorer.go @@ -1,13 +1,13 @@ package routes import ( - "github.com/gofiber/fiber/v2" + "github.com/labstack/echo/v4" coreExplorer "github.com/mudler/LocalAI/core/explorer" "github.com/mudler/LocalAI/core/http/endpoints/explorer" ) -func RegisterExplorerRoutes(app *fiber.App, db *coreExplorer.Database) { - app.Get("/", explorer.Dashboard()) - app.Post("/network/add", explorer.AddNetwork(db)) - app.Get("/networks", explorer.ShowNetworks(db)) +func RegisterExplorerRoutes(app *echo.Echo, db *coreExplorer.Database) { + app.GET("/", explorer.Dashboard()) + app.POST("/network/add", explorer.AddNetwork(db)) + app.GET("/networks", explorer.ShowNetworks(db)) } diff --git a/core/http/routes/health.go b/core/http/routes/health.go index f5a08e9baf37..5b03953733d8 100644 --- a/core/http/routes/health.go +++ b/core/http/routes/health.go @@ -1,13 +1,15 @@ package routes -import "github.com/gofiber/fiber/v2" +import ( + "github.com/labstack/echo/v4" +) -func HealthRoutes(app *fiber.App) { +func HealthRoutes(app *echo.Echo) { // Service health checks - ok := func(c *fiber.Ctx) error { - return c.SendStatus(200) + ok := func(c echo.Context) error { + return c.NoContent(200) } - app.Get("/healthz", ok) - app.Get("/readyz", ok) + app.GET("/healthz", ok) + app.GET("/readyz", ok) } diff --git a/core/http/routes/jina.go b/core/http/routes/jina.go index a55ca79f5597..b4fafbc57f50 100644 --- a/core/http/routes/jina.go +++ b/core/http/routes/jina.go @@ -1,24 +1,25 @@ package routes import ( + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/endpoints/jina" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" - "github.com/gofiber/fiber/v2" 
"github.com/mudler/LocalAI/pkg/model" ) -func RegisterJINARoutes(app *fiber.App, +func RegisterJINARoutes(app *echo.Echo, re *middleware.RequestExtractor, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) { // POST endpoint to mimic the reranking - app.Post("/v1/rerank", + rerankHandler := jina.JINARerankEndpoint(cl, ml, appConfig) + app.POST("/v1/rerank", + rerankHandler, re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_RERANK)), - re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.JINARerankRequest) }), - jina.JINARerankEndpoint(cl, ml, appConfig)) + re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.JINARerankRequest) })) } diff --git a/core/http/routes/localai.go b/core/http/routes/localai.go index 5d9ae821007d..9f44e2e6b27c 100644 --- a/core/http/routes/localai.go +++ b/core/http/routes/localai.go @@ -1,8 +1,7 @@ package routes import ( - "github.com/gofiber/fiber/v2" - "github.com/gofiber/swagger" + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/endpoints/localai" "github.com/mudler/LocalAI/core/http/middleware" @@ -11,9 +10,10 @@ import ( "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/internal" "github.com/mudler/LocalAI/pkg/model" + echoswagger "github.com/swaggo/echo-swagger" ) -func RegisterLocalAIRoutes(router *fiber.App, +func RegisterLocalAIRoutes(router *echo.Echo, requestExtractor *middleware.RequestExtractor, cl *config.ModelConfigLoader, ml *model.ModelLoader, @@ -21,13 +21,13 @@ func RegisterLocalAIRoutes(router *fiber.App, galleryService *services.GalleryService, opcache *services.OpCache) { - router.Get("/swagger/*", swagger.HandlerDefault) // default + router.GET("/swagger/*", echoswagger.WrapHandler) // default // LocalAI API endpoints if !appConfig.DisableGalleryEndpoint { // Import model page - router.Get("/import-model", func(c *fiber.Ctx) error { - 
return c.Render("views/model-editor", fiber.Map{ + router.GET("/import-model", func(c echo.Context) error { + return c.Render(200, "views/model-editor", map[string]interface{}{ "Title": "LocalAI - Import Model", "BaseURL": httpUtils.BaseURL(c), "Version": internal.PrintableVersion(), @@ -35,97 +35,103 @@ func RegisterLocalAIRoutes(router *fiber.App, }) // Edit model page - router.Get("/models/edit/:name", localai.GetEditModelPage(cl, appConfig)) + router.GET("/models/edit/:name", localai.GetEditModelPage(cl, appConfig)) modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.BackendGalleries, appConfig.SystemState, galleryService) - router.Post("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint()) - router.Post("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint()) + router.POST("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint()) + router.POST("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint()) - router.Get("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint(appConfig.SystemState)) - router.Get("/models/galleries", modelGalleryEndpointService.ListModelGalleriesEndpoint()) - router.Get("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint()) - router.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint()) + router.GET("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint(appConfig.SystemState)) + router.GET("/models/galleries", modelGalleryEndpointService.ListModelGalleriesEndpoint()) + router.GET("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint()) + router.GET("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint()) backendGalleryEndpointService := localai.CreateBackendEndpointService( appConfig.BackendGalleries, appConfig.SystemState, galleryService) - router.Post("/backends/apply", 
backendGalleryEndpointService.ApplyBackendEndpoint()) - router.Post("/backends/delete/:name", backendGalleryEndpointService.DeleteBackendEndpoint()) - router.Get("/backends", backendGalleryEndpointService.ListBackendsEndpoint(appConfig.SystemState)) - router.Get("/backends/available", backendGalleryEndpointService.ListAvailableBackendsEndpoint(appConfig.SystemState)) - router.Get("/backends/galleries", backendGalleryEndpointService.ListBackendGalleriesEndpoint()) - router.Get("/backends/jobs/:uuid", backendGalleryEndpointService.GetOpStatusEndpoint()) + router.POST("/backends/apply", backendGalleryEndpointService.ApplyBackendEndpoint()) + router.POST("/backends/delete/:name", backendGalleryEndpointService.DeleteBackendEndpoint()) + router.GET("/backends", backendGalleryEndpointService.ListBackendsEndpoint(appConfig.SystemState)) + router.GET("/backends/available", backendGalleryEndpointService.ListAvailableBackendsEndpoint(appConfig.SystemState)) + router.GET("/backends/galleries", backendGalleryEndpointService.ListBackendGalleriesEndpoint()) + router.GET("/backends/jobs/:uuid", backendGalleryEndpointService.GetOpStatusEndpoint()) // Custom model import endpoint - router.Post("/models/import", localai.ImportModelEndpoint(cl, appConfig)) + router.POST("/models/import", localai.ImportModelEndpoint(cl, appConfig)) // URI model import endpoint - router.Post("/models/import-uri", localai.ImportModelURIEndpoint(cl, appConfig, galleryService, opcache)) + router.POST("/models/import-uri", localai.ImportModelURIEndpoint(cl, appConfig, galleryService, opcache)) // Custom model edit endpoint - router.Post("/models/edit/:name", localai.EditModelEndpoint(cl, appConfig)) + router.POST("/models/edit/:name", localai.EditModelEndpoint(cl, appConfig)) // Reload models endpoint - router.Post("/models/reload", localai.ReloadModelsEndpoint(cl, appConfig)) + router.POST("/models/reload", localai.ReloadModelsEndpoint(cl, appConfig)) } - router.Post("/v1/detection", + detectionHandler := 
localai.DetectionEndpoint(cl, ml, appConfig) + router.POST("/v1/detection", + detectionHandler, requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_DETECTION)), - requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.DetectionRequest) }), - localai.DetectionEndpoint(cl, ml, appConfig)) + requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.DetectionRequest) })) - router.Post("/tts", + ttsHandler := localai.TTSEndpoint(cl, ml, appConfig) + router.POST("/tts", + ttsHandler, requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)), - requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }), - localai.TTSEndpoint(cl, ml, appConfig)) + requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) })) - vadChain := []fiber.Handler{ + vadHandler := localai.VADEndpoint(cl, ml, appConfig) + router.POST("/vad", + vadHandler, requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VAD)), - requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VADRequest) }), - localai.VADEndpoint(cl, ml, appConfig), - } - router.Post("/vad", vadChain...) - router.Post("/v1/vad", vadChain...) 
+ requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VADRequest) })) + router.POST("/v1/vad", + vadHandler, + requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VAD)), + requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VADRequest) })) // Stores - router.Post("/stores/set", localai.StoresSetEndpoint(ml, appConfig)) - router.Post("/stores/delete", localai.StoresDeleteEndpoint(ml, appConfig)) - router.Post("/stores/get", localai.StoresGetEndpoint(ml, appConfig)) - router.Post("/stores/find", localai.StoresFindEndpoint(ml, appConfig)) + router.POST("/stores/set", localai.StoresSetEndpoint(ml, appConfig)) + router.POST("/stores/delete", localai.StoresDeleteEndpoint(ml, appConfig)) + router.POST("/stores/get", localai.StoresGetEndpoint(ml, appConfig)) + router.POST("/stores/find", localai.StoresFindEndpoint(ml, appConfig)) if !appConfig.DisableMetrics { - router.Get("/metrics", localai.LocalAIMetricsEndpoint()) + router.GET("/metrics", localai.LocalAIMetricsEndpoint()) } - router.Post("/video", + videoHandler := localai.VideoEndpoint(cl, ml, appConfig) + router.POST("/video", + videoHandler, requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VIDEO)), - requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VideoRequest) }), - localai.VideoEndpoint(cl, ml, appConfig)) + requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VideoRequest) })) // Backend Statistics Module // TODO: Should these use standard middlewares? Refactor later, they are extremely simple. 
backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now - router.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService)) - router.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService)) + router.GET("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService)) + router.POST("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService)) // The v1/* urls are exactly the same as above - makes local e2e testing easier if they are registered. - router.Get("/v1/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService)) - router.Post("/v1/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService)) + router.GET("/v1/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService)) + router.POST("/v1/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService)) // p2p - router.Get("/api/p2p", localai.ShowP2PNodes(appConfig)) - router.Get("/api/p2p/token", localai.ShowP2PToken(appConfig)) + router.GET("/api/p2p", localai.ShowP2PNodes(appConfig)) + router.GET("/api/p2p/token", localai.ShowP2PToken(appConfig)) - router.Get("/version", func(c *fiber.Ctx) error { - return c.JSON(struct { + router.GET("/version", func(c echo.Context) error { + return c.JSON(200, struct { Version string `json:"version"` }{Version: internal.PrintableVersion()}) }) - router.Get("/system", localai.SystemInformations(ml, appConfig)) + router.GET("/system", localai.SystemInformations(ml, appConfig)) // misc - router.Post("/v1/tokenize", + tokenizeHandler := localai.TokenizeEndpoint(cl, ml, appConfig) + router.POST("/v1/tokenize", + tokenizeHandler, requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TOKENIZE)), - requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TokenizeRequest) }), - localai.TokenizeEndpoint(cl, ml, appConfig)) + 
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TokenizeRequest) })) } diff --git a/core/http/routes/openai.go b/core/http/routes/openai.go index e57d9e6a490d..d76baad50b03 100644 --- a/core/http/routes/openai.go +++ b/core/http/routes/openai.go @@ -1,7 +1,7 @@ package routes import ( - "github.com/gofiber/fiber/v2" + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/application" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/endpoints/localai" @@ -10,118 +10,172 @@ import ( "github.com/mudler/LocalAI/core/schema" ) -func RegisterOpenAIRoutes(app *fiber.App, +func RegisterOpenAIRoutes(app *echo.Echo, re *middleware.RequestExtractor, application *application.Application) { // openAI compatible API endpoint // realtime // TODO: Modify/disable the API key middleware for this endpoint to allow ephemeral keys created by sessions - app.Get("/v1/realtime", openai.Realtime(application)) - app.Post("/v1/realtime/sessions", openai.RealtimeTranscriptionSession(application)) - app.Post("/v1/realtime/transcription_session", openai.RealtimeTranscriptionSession(application)) + app.GET("/v1/realtime", openai.Realtime(application)) + app.POST("/v1/realtime/sessions", openai.RealtimeTranscriptionSession(application)) + app.POST("/v1/realtime/transcription_session", openai.RealtimeTranscriptionSession(application)) // chat - chatChain := []fiber.Handler{ + chatHandler := openai.ChatEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()) + chatMiddleware := []echo.MiddlewareFunc{ re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)), re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }), - re.SetOpenAIRequest, - openai.ChatEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()), + 
func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if err := re.SetOpenAIRequest(c); err != nil { + return err + } + return next(c) + } + }, } - app.Post("/v1/chat/completions", chatChain...) - app.Post("/chat/completions", chatChain...) + app.POST("/v1/chat/completions", chatHandler, chatMiddleware...) + app.POST("/chat/completions", chatHandler, chatMiddleware...) // edit - editChain := []fiber.Handler{ + editHandler := openai.EditEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()) + editMiddleware := []echo.MiddlewareFunc{ re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EDIT)), re.BuildConstantDefaultModelNameMiddleware("gpt-4o"), re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }), - re.SetOpenAIRequest, - openai.EditEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()), + func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if err := re.SetOpenAIRequest(c); err != nil { + return err + } + return next(c) + } + }, } - app.Post("/v1/edits", editChain...) - app.Post("/edits", editChain...) + app.POST("/v1/edits", editHandler, editMiddleware...) + app.POST("/edits", editHandler, editMiddleware...) 
// completion - completionChain := []fiber.Handler{ + completionHandler := openai.CompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()) + completionMiddleware := []echo.MiddlewareFunc{ re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_COMPLETION)), re.BuildConstantDefaultModelNameMiddleware("gpt-4o"), re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }), - re.SetOpenAIRequest, - openai.CompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()), + func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if err := re.SetOpenAIRequest(c); err != nil { + return err + } + return next(c) + } + }, } - app.Post("/v1/completions", completionChain...) - app.Post("/completions", completionChain...) - app.Post("/v1/engines/:model/completions", completionChain...) + app.POST("/v1/completions", completionHandler, completionMiddleware...) + app.POST("/completions", completionHandler, completionMiddleware...) + app.POST("/v1/engines/:model/completions", completionHandler, completionMiddleware...) 
// MCPcompletion - mcpCompletionChain := []fiber.Handler{ + mcpCompletionHandler := openai.MCPCompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()) + mcpCompletionMiddleware := []echo.MiddlewareFunc{ re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)), re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }), - re.SetOpenAIRequest, - openai.MCPCompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()), + func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if err := re.SetOpenAIRequest(c); err != nil { + return err + } + return next(c) + } + }, } - app.Post("/mcp/v1/chat/completions", mcpCompletionChain...) - app.Post("/mcp/chat/completions", mcpCompletionChain...) + app.POST("/mcp/v1/chat/completions", mcpCompletionHandler, mcpCompletionMiddleware...) + app.POST("/mcp/chat/completions", mcpCompletionHandler, mcpCompletionMiddleware...) // embeddings - embeddingChain := []fiber.Handler{ + embeddingHandler := openai.EmbeddingsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()) + embeddingMiddleware := []echo.MiddlewareFunc{ re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EMBEDDINGS)), re.BuildConstantDefaultModelNameMiddleware("gpt-4o"), re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }), - re.SetOpenAIRequest, - openai.EmbeddingsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()), + func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if err := re.SetOpenAIRequest(c); err != nil { + return err + } + return next(c) + } + }, } - app.Post("/v1/embeddings", embeddingChain...) 
- app.Post("/embeddings", embeddingChain...) - app.Post("/v1/engines/:model/embeddings", embeddingChain...) + app.POST("/v1/embeddings", embeddingHandler, embeddingMiddleware...) + app.POST("/embeddings", embeddingHandler, embeddingMiddleware...) + app.POST("/v1/engines/:model/embeddings", embeddingHandler, embeddingMiddleware...) - audioChain := []fiber.Handler{ + audioHandler := openai.TranscriptEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()) + audioMiddleware := []echo.MiddlewareFunc{ re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TRANSCRIPT)), re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }), - re.SetOpenAIRequest, - openai.TranscriptEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()), + func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if err := re.SetOpenAIRequest(c); err != nil { + return err + } + return next(c) + } + }, } // audio - app.Post("/v1/audio/transcriptions", audioChain...) - app.Post("/audio/transcriptions", audioChain...) + app.POST("/v1/audio/transcriptions", audioHandler, audioMiddleware...) + app.POST("/audio/transcriptions", audioHandler, audioMiddleware...) - audioSpeechChain := []fiber.Handler{ + audioSpeechHandler := localai.TTSEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()) + audioSpeechMiddleware := []echo.MiddlewareFunc{ re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)), re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }), - localai.TTSEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()), } - app.Post("/v1/audio/speech", - audioSpeechChain...) - app.Post("/audio/speech", audioSpeechChain...) 
+ app.POST("/v1/audio/speech", audioSpeechHandler, audioSpeechMiddleware...) + app.POST("/audio/speech", audioSpeechHandler, audioSpeechMiddleware...) // images - imageChain := []fiber.Handler{ + imageHandler := openai.ImageEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()) + imageMiddleware := []echo.MiddlewareFunc{ re.BuildConstantDefaultModelNameMiddleware("stablediffusion"), re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }), - re.SetOpenAIRequest, - openai.ImageEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()), + func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if err := re.SetOpenAIRequest(c); err != nil { + return err + } + return next(c) + } + }, } - app.Post("/v1/images/generations", - imageChain...) - app.Post("/images/generations", imageChain...) + app.POST("/v1/images/generations", imageHandler, imageMiddleware...) + app.POST("/images/generations", imageHandler, imageMiddleware...) // videos (OpenAI-compatible endpoints mapped to LocalAI video handler) - videoChain := []fiber.Handler{ + videoHandler := openai.VideoEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()) + videoMiddleware := []echo.MiddlewareFunc{ re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VIDEO)), re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }), - re.SetOpenAIRequest, - openai.VideoEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()), + func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if err := re.SetOpenAIRequest(c); err != nil { + return err + } + return next(c) + } + }, } // OpenAI-style create video endpoint - app.Post("/v1/videos", videoChain...) - app.Post("/v1/videos/generations", videoChain...) 
- app.Post("/videos", videoChain...) + app.POST("/v1/videos", videoHandler, videoMiddleware...) + app.POST("/v1/videos/generations", videoHandler, videoMiddleware...) + app.POST("/videos", videoHandler, videoMiddleware...) // List models - app.Get("/v1/models", openai.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())) - app.Get("/models", openai.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.GET("/v1/models", openai.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.GET("/models", openai.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())) } diff --git a/core/http/routes/ui.go b/core/http/routes/ui.go index 8d8b2cebaf7f..d866f71800c6 100644 --- a/core/http/routes/ui.go +++ b/core/http/routes/ui.go @@ -1,6 +1,7 @@ package routes import ( + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/http/endpoints/localai" @@ -8,11 +9,9 @@ import ( "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/internal" "github.com/mudler/LocalAI/pkg/model" - - "github.com/gofiber/fiber/v2" ) -func RegisterUIRoutes(app *fiber.App, +func RegisterUIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, @@ -21,11 +20,11 @@ func RegisterUIRoutes(app *fiber.App, // keeps the state of ops that are started from the UI var processingOps = services.NewOpCache(galleryService) - app.Get("/", localai.WelcomeEndpoint(appConfig, cl, ml, processingOps)) + app.GET("/", localai.WelcomeEndpoint(appConfig, cl, ml, processingOps)) // P2P - app.Get("/p2p", func(c *fiber.Ctx) error { - summary := fiber.Map{ + app.GET("/p2p/", func(c echo.Context) error { + summary := 
map[string]interface{}{ "Title": "LocalAI - P2P dashboard", "BaseURL": utils.BaseURL(c), "Version": internal.PrintableVersion(), @@ -37,7 +36,7 @@ func RegisterUIRoutes(app *fiber.App, } // Render index - return c.Render("views/p2p", summary) + return c.Render(200, "views/p2p", summary) }) // Note: P2P UI fragment routes (/p2p/ui/*) were removed @@ -50,15 +49,15 @@ func RegisterUIRoutes(app *fiber.App, registerBackendGalleryRoutes(app, appConfig, galleryService, processingOps) } - app.Get("/talk/", func(c *fiber.Ctx) error { + app.GET("/talk/", func(c echo.Context) error { modelConfigs, _ := services.ListModels(cl, ml, config.NoFilterFn, services.SKIP_IF_CONFIGURED) if len(modelConfigs) == 0 { // If no model is available redirect to the index which suggests how to install models - return c.Redirect(utils.BaseURL(c)) + return c.Redirect(302, utils.BaseURL(c)) } - summary := fiber.Map{ + summary := map[string]interface{}{ "Title": "LocalAI - Talk", "BaseURL": utils.BaseURL(c), "ModelsConfig": modelConfigs, @@ -68,16 +67,16 @@ func RegisterUIRoutes(app *fiber.App, } // Render index - return c.Render("views/talk", summary) + return c.Render(200, "views/talk", summary) }) - app.Get("/chat/", func(c *fiber.Ctx) error { + app.GET("/chat/", func(c echo.Context) error { modelConfigs := cl.GetAllModelsConfigs() modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY) if len(modelConfigs)+len(modelsWithoutConfig) == 0 { // If no model is available redirect to the index which suggests how to install models - return c.Redirect(utils.BaseURL(c)) + return c.Redirect(302, utils.BaseURL(c)) } modelThatCanBeUsed := "" galleryConfigs := map[string]*gallery.ModelConfig{} @@ -104,7 +103,7 @@ func RegisterUIRoutes(app *fiber.App, } } - summary := fiber.Map{ + summary := map[string]interface{}{ "Title": title, "BaseURL": utils.BaseURL(c), "ModelsWithoutConfig": modelsWithoutConfig, @@ -116,16 +115,16 @@ func RegisterUIRoutes(app *fiber.App, } // 
Render index - return c.Render("views/chat", summary) + return c.Render(200, "views/chat", summary) }) // Show the Chat page - app.Get("/chat/:model", func(c *fiber.Ctx) error { + app.GET("/chat/:model", func(c echo.Context) error { modelConfigs := cl.GetAllModelsConfigs() modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY) galleryConfigs := map[string]*gallery.ModelConfig{} - modelName := c.Params("model") + modelName := c.Param("model") var modelContextSize *int for _, m := range modelConfigs { @@ -139,7 +138,7 @@ func RegisterUIRoutes(app *fiber.App, } } - summary := fiber.Map{ + summary := map[string]interface{}{ "Title": "LocalAI - Chat with " + modelName, "BaseURL": utils.BaseURL(c), "ModelsConfig": modelConfigs, @@ -151,33 +150,33 @@ func RegisterUIRoutes(app *fiber.App, } // Render index - return c.Render("views/chat", summary) + return c.Render(200, "views/chat", summary) }) - app.Get("/text2image/:model", func(c *fiber.Ctx) error { + app.GET("/text2image/:model", func(c echo.Context) error { modelConfigs := cl.GetAllModelsConfigs() modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY) - summary := fiber.Map{ - "Title": "LocalAI - Generate images with " + c.Params("model"), + summary := map[string]interface{}{ + "Title": "LocalAI - Generate images with " + c.Param("model"), "BaseURL": utils.BaseURL(c), "ModelsConfig": modelConfigs, "ModelsWithoutConfig": modelsWithoutConfig, - "Model": c.Params("model"), + "Model": c.Param("model"), "Version": internal.PrintableVersion(), } // Render index - return c.Render("views/text2image", summary) + return c.Render(200, "views/text2image", summary) }) - app.Get("/text2image/", func(c *fiber.Ctx) error { + app.GET("/text2image/", func(c echo.Context) error { modelConfigs := cl.GetAllModelsConfigs() modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY) if len(modelConfigs)+len(modelsWithoutConfig) == 
0 { // If no model is available redirect to the index which suggests how to install models - return c.Redirect(utils.BaseURL(c)) + return c.Redirect(302, utils.BaseURL(c)) } modelThatCanBeUsed := "" @@ -191,7 +190,7 @@ func RegisterUIRoutes(app *fiber.App, } } - summary := fiber.Map{ + summary := map[string]interface{}{ "Title": title, "BaseURL": utils.BaseURL(c), "ModelsConfig": modelConfigs, @@ -201,33 +200,33 @@ func RegisterUIRoutes(app *fiber.App, } // Render index - return c.Render("views/text2image", summary) + return c.Render(200, "views/text2image", summary) }) - app.Get("/tts/:model", func(c *fiber.Ctx) error { + app.GET("/tts/:model", func(c echo.Context) error { modelConfigs := cl.GetAllModelsConfigs() modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY) - summary := fiber.Map{ - "Title": "LocalAI - Generate images with " + c.Params("model"), + summary := map[string]interface{}{ + "Title": "LocalAI - Generate audio with " + c.Param("model"), "BaseURL": utils.BaseURL(c), "ModelsConfig": modelConfigs, "ModelsWithoutConfig": modelsWithoutConfig, - "Model": c.Params("model"), + "Model": c.Param("model"), "Version": internal.PrintableVersion(), } // Render index - return c.Render("views/tts", summary) + return c.Render(200, "views/tts", summary) }) - app.Get("/tts/", func(c *fiber.Ctx) error { + app.GET("/tts/", func(c echo.Context) error { modelConfigs := cl.GetAllModelsConfigs() modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY) if len(modelConfigs)+len(modelsWithoutConfig) == 0 { // If no model is available redirect to the index which suggests how to install models - return c.Redirect(utils.BaseURL(c)) + return c.Redirect(302, utils.BaseURL(c)) } modelThatCanBeUsed := "" @@ -240,7 +239,7 @@ func RegisterUIRoutes(app *fiber.App, break } } - summary := fiber.Map{ + summary := map[string]interface{}{ "Title": title, "BaseURL": utils.BaseURL(c), "ModelsConfig": modelConfigs,
@@ -250,6 +249,6 @@ func RegisterUIRoutes(app *fiber.App, } // Render index - return c.Render("views/tts", summary) + return c.Render(200, "views/tts", summary) }) } diff --git a/core/http/routes/ui_api.go b/core/http/routes/ui_api.go index 3ea4852e08dd..7eac35705501 100644 --- a/core/http/routes/ui_api.go +++ b/core/http/routes/ui_api.go @@ -4,13 +4,14 @@ import ( "context" "fmt" "math" + "net/http" "net/url" "sort" "strconv" "strings" - "github.com/gofiber/fiber/v2" "github.com/google/uuid" + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/p2p" @@ -19,13 +20,13 @@ import ( ) // RegisterUIAPIRoutes registers JSON API routes for the web UI -func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) { +func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) { // Operations API - Get all current operations (models + backends) - app.Get("/api/operations", func(c *fiber.Ctx) error { + app.GET("/api/operations", func(c echo.Context) error { processingData, taskTypes := opcache.GetStatus() - operations := []fiber.Map{} + operations := []map[string]interface{}{} for galleryID, jobID := range processingData { taskType := "installation" if tt, ok := taskTypes[galleryID]; ok { @@ -88,7 +89,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig } } - operations = append(operations, fiber.Map{ + operations = append(operations, map[string]interface{}{ "id": galleryID, "name": displayName, "fullName": galleryID, @@ -118,20 +119,20 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig return operations[i]["id"].(string) < operations[j]["id"].(string) }) - return c.JSON(fiber.Map{ + return 
c.JSON(200, map[string]interface{}{ "operations": operations, }) }) // Cancel operation endpoint - app.Post("/api/operations/:jobID/cancel", func(c *fiber.Ctx) error { - jobID := strings.Clone(c.Params("jobID")) + app.POST("/api/operations/:jobID/cancel", func(c echo.Context) error { + jobID := c.Param("jobID") log.Debug().Msgf("API request to cancel operation: %s", jobID) err := galleryService.CancelOperation(jobID) if err != nil { log.Error().Err(err).Msgf("Failed to cancel operation: %s", jobID) - return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + return c.JSON(http.StatusBadRequest, map[string]interface{}{ "error": err.Error(), }) } @@ -139,22 +140,28 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig // Clean up opcache for cancelled operation opcache.DeleteUUID(jobID) - return c.JSON(fiber.Map{ + return c.JSON(200, map[string]interface{}{ "success": true, "message": "Operation cancelled", }) }) // Model Gallery APIs - app.Get("/api/models", func(c *fiber.Ctx) error { - term := c.Query("term") - page := c.Query("page", "1") - items := c.Query("items", "21") + app.GET("/api/models", func(c echo.Context) error { + term := c.QueryParam("term") + page := c.QueryParam("page") + if page == "" { + page = "1" + } + items := c.QueryParam("items") + if items == "" { + items = "21" + } models, err := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.SystemState) if err != nil { log.Error().Err(err).Msg("could not list models from galleries") - return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + return c.JSON(http.StatusInternalServerError, map[string]interface{}{ "error": err.Error(), }) } @@ -197,7 +204,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig } // Convert models to JSON-friendly format and deduplicate by ID - modelsJSON := make([]fiber.Map, 0, len(models)) + modelsJSON := make([]map[string]interface{}, 0, len(models)) seenIDs := make(map[string]bool) for _, m 
:= range models { @@ -223,7 +230,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig _, trustRemoteCodeExists := m.Overrides["trust_remote_code"] - modelsJSON = append(modelsJSON, fiber.Map{ + modelsJSON = append(modelsJSON, map[string]interface{}{ "id": modelID, "name": m.Name, "description": m.Description, @@ -250,7 +257,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig nextPage = totalPages } - return c.JSON(fiber.Map{ + return c.JSON(200, map[string]interface{}{ "models": modelsJSON, "repositories": appConfig.Galleries, "allTags": tags, @@ -264,12 +271,12 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig }) }) - app.Post("/api/models/install/:id", func(c *fiber.Ctx) error { - galleryID := strings.Clone(c.Params("id")) + app.POST("/api/models/install/:id", func(c echo.Context) error { + galleryID := c.Param("id") // URL decode the gallery ID (e.g., "localai%40model" -> "localai@model") galleryID, err := url.QueryUnescape(galleryID) if err != nil { - return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + return c.JSON(http.StatusBadRequest, map[string]interface{}{ "error": "invalid model ID", }) } @@ -277,7 +284,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig id, err := uuid.NewUUID() if err != nil { - return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + return c.JSON(http.StatusInternalServerError, map[string]interface{}{ "error": err.Error(), }) } @@ -300,18 +307,18 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig galleryService.ModelGalleryChannel <- op }() - return c.JSON(fiber.Map{ + return c.JSON(200, map[string]interface{}{ "jobID": uid, "message": "Installation started", }) }) - app.Post("/api/models/delete/:id", func(c *fiber.Ctx) error { - galleryID := strings.Clone(c.Params("id")) + app.POST("/api/models/delete/:id", func(c echo.Context) error { + galleryID := 
c.Param("id") // URL decode the gallery ID galleryID, err := url.QueryUnescape(galleryID) if err != nil { - return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + return c.JSON(http.StatusBadRequest, map[string]interface{}{ "error": "invalid model ID", }) } @@ -324,7 +331,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig id, err := uuid.NewUUID() if err != nil { - return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + return c.JSON(http.StatusInternalServerError, map[string]interface{}{ "error": err.Error(), }) } @@ -350,18 +357,18 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig cl.RemoveModelConfig(galleryName) }() - return c.JSON(fiber.Map{ + return c.JSON(200, map[string]interface{}{ "jobID": uid, "message": "Deletion started", }) }) - app.Post("/api/models/config/:id", func(c *fiber.Ctx) error { - galleryID := strings.Clone(c.Params("id")) + app.POST("/api/models/config/:id", func(c echo.Context) error { + galleryID := c.Param("id") // URL decode the gallery ID galleryID, err := url.QueryUnescape(galleryID) if err != nil { - return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + return c.JSON(http.StatusBadRequest, map[string]interface{}{ "error": "invalid model ID", }) } @@ -369,44 +376,44 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig models, err := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.SystemState) if err != nil { - return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + return c.JSON(http.StatusInternalServerError, map[string]interface{}{ "error": err.Error(), }) } model := gallery.FindGalleryElement(models, galleryID) if model == nil { - return c.Status(fiber.StatusNotFound).JSON(fiber.Map{ + return c.JSON(http.StatusNotFound, map[string]interface{}{ "error": "model not found", }) } config, err := gallery.GetGalleryConfigFromURL[gallery.ModelConfig](model.URL, 
appConfig.SystemState.Model.ModelsPath) if err != nil { - return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + return c.JSON(http.StatusInternalServerError, map[string]interface{}{ "error": err.Error(), }) } _, err = gallery.InstallModel(context.Background(), appConfig.SystemState, model.Name, &config, model.Overrides, nil, false) if err != nil { - return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + return c.JSON(http.StatusInternalServerError, map[string]interface{}{ "error": err.Error(), }) } - return c.JSON(fiber.Map{ + return c.JSON(200, map[string]interface{}{ "message": "Configuration file saved", }) }) - app.Get("/api/models/job/:uid", func(c *fiber.Ctx) error { - jobUID := strings.Clone(c.Params("uid")) + app.GET("/api/models/job/:uid", func(c echo.Context) error { + jobUID := c.Param("uid") status := galleryService.GetStatus(jobUID) if status == nil { // Job is queued but hasn't started processing yet - return c.JSON(fiber.Map{ + return c.JSON(200, map[string]interface{}{ "progress": 0, "message": "Operation queued", "galleryElementName": "", @@ -416,7 +423,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig }) } - response := fiber.Map{ + response := map[string]interface{}{ "progress": status.Progress, "message": status.Message, "galleryElementName": status.GalleryElementName, @@ -434,19 +441,25 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig response["completed"] = true } - return c.JSON(response) + return c.JSON(200, response) }) // Backend Gallery APIs - app.Get("/api/backends", func(c *fiber.Ctx) error { - term := c.Query("term") - page := c.Query("page", "1") - items := c.Query("items", "21") + app.GET("/api/backends", func(c echo.Context) error { + term := c.QueryParam("term") + page := c.QueryParam("page") + if page == "" { + page = "1" + } + items := c.QueryParam("items") + if items == "" { + items = "21" + } backends, err := 
gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.SystemState) if err != nil { log.Error().Err(err).Msg("could not list backends from galleries") - return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + return c.JSON(http.StatusInternalServerError, map[string]interface{}{ "error": err.Error(), }) } @@ -489,7 +502,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig } // Convert backends to JSON-friendly format and deduplicate by ID - backendsJSON := make([]fiber.Map, 0, len(backends)) + backendsJSON := make([]map[string]interface{}, 0, len(backends)) seenBackendIDs := make(map[string]bool) for _, b := range backends { @@ -513,7 +526,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig } } - backendsJSON = append(backendsJSON, fiber.Map{ + backendsJSON = append(backendsJSON, map[string]interface{}{ "id": backendID, "name": b.Name, "description": b.Description, @@ -538,7 +551,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig nextPage = totalPages } - return c.JSON(fiber.Map{ + return c.JSON(200, map[string]interface{}{ "backends": backendsJSON, "repositories": appConfig.BackendGalleries, "allTags": tags, @@ -552,12 +565,12 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig }) }) - app.Post("/api/backends/install/:id", func(c *fiber.Ctx) error { - backendID := strings.Clone(c.Params("id")) + app.POST("/api/backends/install/:id", func(c echo.Context) error { + backendID := c.Param("id") // URL decode the backend ID backendID, err := url.QueryUnescape(backendID) if err != nil { - return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + return c.JSON(http.StatusBadRequest, map[string]interface{}{ "error": "invalid backend ID", }) } @@ -565,7 +578,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig id, err := uuid.NewUUID() if err != nil { - return 
c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + return c.JSON(http.StatusInternalServerError, map[string]interface{}{ "error": err.Error(), }) } @@ -587,18 +600,18 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig galleryService.BackendGalleryChannel <- op }() - return c.JSON(fiber.Map{ + return c.JSON(200, map[string]interface{}{ "jobID": uid, "message": "Backend installation started", }) }) - app.Post("/api/backends/delete/:id", func(c *fiber.Ctx) error { - backendID := strings.Clone(c.Params("id")) + app.POST("/api/backends/delete/:id", func(c echo.Context) error { + backendID := c.Param("id") // URL decode the backend ID backendID, err := url.QueryUnescape(backendID) if err != nil { - return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + return c.JSON(http.StatusBadRequest, map[string]interface{}{ "error": "invalid backend ID", }) } @@ -611,7 +624,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig id, err := uuid.NewUUID() if err != nil { - return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + return c.JSON(http.StatusInternalServerError, map[string]interface{}{ "error": err.Error(), }) } @@ -635,19 +648,19 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig galleryService.BackendGalleryChannel <- op }() - return c.JSON(fiber.Map{ + return c.JSON(200, map[string]interface{}{ "jobID": uid, "message": "Backend deletion started", }) }) - app.Get("/api/backends/job/:uid", func(c *fiber.Ctx) error { - jobUID := strings.Clone(c.Params("uid")) + app.GET("/api/backends/job/:uid", func(c echo.Context) error { + jobUID := c.Param("uid") status := galleryService.GetStatus(jobUID) if status == nil { // Job is queued but hasn't started processing yet - return c.JSON(fiber.Map{ + return c.JSON(200, map[string]interface{}{ "progress": 0, "message": "Operation queued", "galleryElementName": "", @@ -657,7 +670,7 @@ func RegisterUIAPIRoutes(app 
*fiber.App, cl *config.ModelConfigLoader, appConfig }) } - response := fiber.Map{ + response := map[string]interface{}{ "progress": status.Progress, "message": status.Message, "galleryElementName": status.GalleryElementName, @@ -675,16 +688,16 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig response["completed"] = true } - return c.JSON(response) + return c.JSON(200, response) }) // System Backend Deletion API (for installed backends on index page) - app.Post("/api/backends/system/delete/:name", func(c *fiber.Ctx) error { - backendName := strings.Clone(c.Params("name")) + app.POST("/api/backends/system/delete/:name", func(c echo.Context) error { + backendName := c.Param("name") // URL decode the backend name backendName, err := url.QueryUnescape(backendName) if err != nil { - return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + return c.JSON(http.StatusBadRequest, map[string]interface{}{ "error": "invalid backend name", }) } @@ -693,24 +706,24 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig // Use the gallery package to delete the backend if err := gallery.DeleteBackendFromSystem(appConfig.SystemState, backendName); err != nil { log.Error().Err(err).Msgf("Failed to delete backend: %s", backendName) - return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + return c.JSON(http.StatusInternalServerError, map[string]interface{}{ "error": err.Error(), }) } - return c.JSON(fiber.Map{ + return c.JSON(200, map[string]interface{}{ "success": true, "message": "Backend deleted successfully", }) }) // P2P APIs - app.Get("/api/p2p/workers", func(c *fiber.Ctx) error { + app.GET("/api/p2p/workers", func(c echo.Context) error { nodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID)) - nodesJSON := make([]fiber.Map, 0, len(nodes)) + nodesJSON := make([]map[string]interface{}, 0, len(nodes)) for _, n := range nodes { - nodesJSON = append(nodesJSON, fiber.Map{ + 
nodesJSON = append(nodesJSON, map[string]interface{}{ "name": n.Name, "id": n.ID, "tunnelAddress": n.TunnelAddress, @@ -720,17 +733,17 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig }) } - return c.JSON(fiber.Map{ + return c.JSON(200, map[string]interface{}{ "nodes": nodesJSON, }) }) - app.Get("/api/p2p/federation", func(c *fiber.Ctx) error { + app.GET("/api/p2p/federation", func(c echo.Context) error { nodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID)) - nodesJSON := make([]fiber.Map, 0, len(nodes)) + nodesJSON := make([]map[string]interface{}, 0, len(nodes)) for _, n := range nodes { - nodesJSON = append(nodesJSON, fiber.Map{ + nodesJSON = append(nodesJSON, map[string]interface{}{ "name": n.Name, "id": n.ID, "tunnelAddress": n.TunnelAddress, @@ -740,12 +753,12 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig }) } - return c.JSON(fiber.Map{ + return c.JSON(200, map[string]interface{}{ "nodes": nodesJSON, }) }) - app.Get("/api/p2p/stats", func(c *fiber.Ctx) error { + app.GET("/api/p2p/stats", func(c echo.Context) error { workerNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID)) federatedNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID)) @@ -763,12 +776,12 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig } } - return c.JSON(fiber.Map{ - "workers": fiber.Map{ + return c.JSON(200, map[string]interface{}{ + "workers": map[string]interface{}{ "online": workersOnline, "total": len(workerNodes), }, - "federated": fiber.Map{ + "federated": map[string]interface{}{ "online": federatedOnline, "total": len(federatedNodes), }, diff --git a/core/http/routes/ui_backend_gallery.go b/core/http/routes/ui_backend_gallery.go index ca9c8c7651e0..52502cb613fa 100644 --- a/core/http/routes/ui_backend_gallery.go +++ b/core/http/routes/ui_backend_gallery.go @@ -1,17 +1,17 @@ package 
routes import ( - "github.com/gofiber/fiber/v2" + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/utils" "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/internal" ) -func registerBackendGalleryRoutes(app *fiber.App, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) { +func registerBackendGalleryRoutes(app *echo.Echo, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) { // Show the Backends page (all backends are loaded client-side via Alpine.js) - app.Get("/browse/backends", func(c *fiber.Ctx) error { - summary := fiber.Map{ + app.GET("/browse/backends", func(c echo.Context) error { + summary := map[string]interface{}{ "Title": "LocalAI - Backends", "BaseURL": utils.BaseURL(c), "Version": internal.PrintableVersion(), @@ -19,6 +19,6 @@ func registerBackendGalleryRoutes(app *fiber.App, appConfig *config.ApplicationC } // Render index - backends are now loaded via Alpine.js from /api/backends - return c.Render("views/backends", summary) + return c.Render(200, "views/backends", summary) }) } diff --git a/core/http/routes/ui_gallery.go b/core/http/routes/ui_gallery.go index f84aa7c3702d..d96b6faa53da 100644 --- a/core/http/routes/ui_gallery.go +++ b/core/http/routes/ui_gallery.go @@ -1,17 +1,17 @@ package routes import ( - "github.com/gofiber/fiber/v2" + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/utils" "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/internal" ) -func registerGalleryRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) { +func registerGalleryRoutes(app *echo.Echo, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache 
*services.OpCache) { - app.Get("/browse", func(c *fiber.Ctx) error { - summary := fiber.Map{ + app.GET("/browse/", func(c echo.Context) error { + summary := map[string]interface{}{ "Title": "LocalAI - Models", "BaseURL": utils.BaseURL(c), "Version": internal.PrintableVersion(), @@ -19,6 +19,6 @@ func registerGalleryRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConf } // Render index - models are now loaded via Alpine.js from /api/models - return c.Render("views/models", summary) + return c.Render(200, "views/models", summary) }) } diff --git a/core/http/utils/baseurl.go b/core/http/utils/baseurl.go index 9fe20f44140c..0046a6e6e357 100644 --- a/core/http/utils/baseurl.go +++ b/core/http/utils/baseurl.go @@ -3,22 +3,38 @@ package utils import ( "strings" - "github.com/gofiber/fiber/v2" + "github.com/labstack/echo/v4" ) // BaseURL returns the base URL for the given HTTP request context. // It takes into account that the app may be exposed by a reverse-proxy under a different protocol, host and path. // The returned URL is guaranteed to end with `/`. // The method should be used in conjunction with the StripPathPrefix middleware. 
-func BaseURL(c *fiber.Ctx) string { +func BaseURL(c echo.Context) string { path := c.Path() - origPath := c.OriginalURL() + origPath := c.Request().URL.Path - if path != origPath && strings.HasSuffix(origPath, path) { - pathPrefix := origPath[:len(origPath)-len(path)+1] + if path != origPath && strings.HasSuffix(origPath, path) && len(path) > 0 { + prefixLen := len(origPath) - len(path) + if prefixLen > 0 && prefixLen <= len(origPath) { + pathPrefix := origPath[:prefixLen] + if !strings.HasSuffix(pathPrefix, "/") { + pathPrefix += "/" + } - return c.BaseURL() + pathPrefix + scheme := "http" + if c.Request().TLS != nil { + scheme = "https" + } + host := c.Request().Host + return scheme + "://" + host + pathPrefix + } } - return c.BaseURL() + "/" + scheme := "http" + if c.Request().TLS != nil { + scheme = "https" + } + host := c.Request().Host + return scheme + "://" + host + "/" } diff --git a/core/http/utils/baseurl_test.go b/core/http/utils/baseurl_test.go index 1750285cf32e..d1d2070b586f 100644 --- a/core/http/utils/baseurl_test.go +++ b/core/http/utils/baseurl_test.go @@ -4,7 +4,7 @@ import ( "net/http/httptest" "testing" - "github.com/gofiber/fiber/v2" + "github.com/labstack/echo/v4" "github.com/stretchr/testify/require" ) @@ -26,22 +26,22 @@ func TestBaseURL(t *testing.T) { }, } { t.Run(tc.name, func(t *testing.T) { - app := fiber.New() + app := echo.New() actualURL := "" - app.Get(tc.prefix+"hello/world", func(c *fiber.Ctx) error { + app.GET(tc.prefix+"hello/world", func(c echo.Context) error { if tc.prefix != "/" { - c.Path("/hello/world") + c.Request().URL.Path = "/hello/world" } actualURL = BaseURL(c) return nil }) req := httptest.NewRequest("GET", tc.prefix+"hello/world", nil) - resp, err := app.Test(req, -1) + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) - require.NoError(t, err) - require.Equal(t, 200, resp.StatusCode, "response status code") + require.Equal(t, 200, rec.Code, "response status code") require.Equal(t, tc.expectURL, actualURL, 
"base URL") }) } diff --git a/go.mod b/go.mod index 464ec1553086..64af33f0df09 100644 --- a/go.mod +++ b/go.mod @@ -11,14 +11,12 @@ require ( github.com/alecthomas/kong v1.12.1 github.com/charmbracelet/glamour v0.10.0 github.com/containerd/containerd v1.7.29 - github.com/dave-gray101/v2keyauth v0.0.0-20240624150259-c45d584d25e2 github.com/ebitengine/purego v0.9.1 github.com/fsnotify/fsnotify v1.9.0 github.com/go-audio/wav v1.1.0 github.com/go-skynet/go-llama.cpp v0.0.0-20240314183750-6a8041ef6b46 github.com/gofiber/fiber/v2 v2.52.9 github.com/gofiber/swagger v1.1.1 - github.com/gofiber/template/html/v2 v2.1.3 github.com/gofiber/websocket/v2 v2.2.1 github.com/gofrs/flock v0.13.0 github.com/google/go-containerregistry v0.19.2 @@ -29,6 +27,7 @@ require ( github.com/jaypipes/ghw v0.20.0 github.com/joho/godotenv v1.5.1 github.com/klauspost/cpuid/v2 v2.3.0 + github.com/labstack/echo/v4 v4.13.4 github.com/libp2p/go-libp2p v0.43.0 github.com/lithammer/fuzzysearch v1.1.8 github.com/mholt/archiver/v3 v3.5.1 @@ -50,6 +49,7 @@ require ( github.com/shirou/gopsutil/v3 v3.24.5 github.com/streamer45/silero-vad-go v0.2.1 github.com/stretchr/testify v1.11.1 + github.com/swaggo/echo-swagger v1.4.1 github.com/swaggo/swag v1.16.6 github.com/testcontainers/testcontainers-go v0.40.0 github.com/tmc/langchaingo v0.1.14 @@ -65,6 +65,14 @@ require ( oras.land/oras-go/v2 v2.6.0 ) +require ( + github.com/fasthttp/websocket v1.5.3 // indirect + github.com/ghodss/yaml v1.0.0 // indirect + github.com/labstack/gommon v0.4.2 // indirect + github.com/swaggo/files/v2 v2.0.2 // indirect + github.com/valyala/fasttemplate v1.2.2 // indirect +) + require ( fyne.io/systray v1.11.1-0.20250603113521-ca66a66d8b58 // indirect github.com/BurntSushi/toml v1.5.0 // indirect @@ -78,7 +86,6 @@ require ( github.com/containerd/platforms v0.2.1 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect github.com/distribution/reference v0.6.0 // indirect - github.com/fasthttp/websocket v1.5.8 // indirect 
github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fredbi/uri v1.1.1 // indirect github.com/fyne-io/gl-js v0.2.0 // indirect @@ -143,9 +150,10 @@ require ( go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect go.uber.org/mock v0.5.2 // indirect - go.yaml.in/yaml/v2 v2.4.2 // indirect + go.yaml.in/yaml/v2 v2.4.2 go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/image v0.25.0 // indirect + golang.org/x/net v0.46.0 // indirect; indirect (for websocket) golang.org/x/oauth2 v0.30.0 // indirect golang.org/x/telemetry v0.0.0-20250908211612-aef8a434d053 // indirect golang.org/x/time v0.12.0 // indirect @@ -195,9 +203,6 @@ require ( github.com/go-openapi/jsonreference v0.21.0 // indirect github.com/go-openapi/spec v0.21.0 // indirect github.com/go-openapi/swag v0.23.0 // indirect - github.com/gofiber/contrib/fiberzerolog v1.0.3 - github.com/gofiber/template v1.8.3 // indirect - github.com/gofiber/utils v1.1.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/snappy v0.0.4 // indirect @@ -275,7 +280,6 @@ require ( github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 // indirect github.com/peterbourgon/diskv v2.0.1+incompatible // indirect - github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c // indirect github.com/pierrec/lz4/v4 v4.1.2 // indirect github.com/pkg/errors v0.9.1 github.com/pkoukk/tiktoken-go v0.1.6 // indirect @@ -296,8 +300,6 @@ require ( github.com/songgao/packets v0.0.0-20160404182456-549a10cd4091 // indirect github.com/spaolacci/murmur3 v1.1.0 // indirect github.com/spf13/cast v1.7.0 // indirect - github.com/swaggo/files/v2 v2.0.2 // indirect - github.com/tinylib/msgp v1.2.5 // indirect github.com/tklauser/go-sysconf v0.3.15 // indirect github.com/tklauser/numcpus v0.10.0 // indirect 
github.com/ulikunitz/xz v0.5.14 // indirect @@ -320,7 +322,6 @@ require ( golang.org/x/crypto v0.43.0 // indirect golang.org/x/exp v0.0.0-20250606033433-dcc06ee1d476 // indirect golang.org/x/mod v0.28.0 // indirect - golang.org/x/net v0.46.0 // indirect golang.org/x/sync v0.17.0 // indirect golang.org/x/sys v0.37.0 // indirect golang.org/x/term v0.36.0 // indirect diff --git a/go.sum b/go.sum index 69ffc9adc165..39141ac1c2f6 100644 --- a/go.sum +++ b/go.sum @@ -113,8 +113,6 @@ github.com/creachadair/otp v0.5.0 h1:q3Th7CXm2zlmCdBjw5tEPFOj4oWJMnVL5HXlq0sNKS0 github.com/creachadair/otp v0.5.0/go.mod h1:0kceI87EnYFNYSTL121goJVAnk3eJhaed9H0nMuJUkA= github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= -github.com/dave-gray101/v2keyauth v0.0.0-20240624150259-c45d584d25e2 h1:flLYmnQFZNo04x2NPehMbf30m7Pli57xwZ0NFqR/hb0= -github.com/dave-gray101/v2keyauth v0.0.0-20240624150259-c45d584d25e2/go.mod h1:NtWqRzAp/1tw+twkW8uuBenEVVYndEAZACWU3F3xdoQ= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= @@ -151,8 +149,8 @@ github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymF github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/fasthttp/websocket v1.5.8 h1:k5DpirKkftIF/w1R8ZzjSgARJrs54Je9YJK37DL/Ah8= -github.com/fasthttp/websocket v1.5.8/go.mod h1:d08g8WaT6nnyvg9uMm8K9zMYyDjfKyj3170AtPRuVU0= +github.com/fasthttp/websocket 
v1.5.3 h1:TPpQuLwJYfd4LJPXvHDYPMFWbLjsT91n3GpWtCQtdek= +github.com/fasthttp/websocket v1.5.3/go.mod h1:46gg/UBmTU1kUaTcwQXpUxtRwG2PvIZYeA8oL6vF3Fs= github.com/felixge/fgprof v0.9.3 h1:VvyZxILNuCiUCSXtPtYmmtGvb65nqXh2QFWc0Wpf2/g= github.com/felixge/fgprof v0.9.3/go.mod h1:RdbpDgzqYVh/T9fPELJyV7EYJuHB55UTEULNun8eiPw= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= @@ -177,6 +175,7 @@ github.com/fyne-io/image v0.1.1 h1:WH0z4H7qfvNUw5l4p3bC1q70sa5+YWVt6HCj7y4VNyA= github.com/fyne-io/image v0.1.1/go.mod h1:xrfYBh6yspc+KjkgdZU/ifUC9sPA5Iv7WYUBzQKK7JM= github.com/fyne-io/oksvg v0.2.0 h1:mxcGU2dx6nwjJsSA9PCYZDuoAcsZ/OuJlvg/Q9Njfo8= github.com/fyne-io/oksvg v0.2.0/go.mod h1:dJ9oEkPiWhnTFNCmRgEze+YNprJF7YRbpjgpWS4kzoI= +github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/gkampitakis/ciinfo v0.3.2 h1:JcuOPk8ZU7nZQjdUhctuhQofk7BGHuIy0c9Ez8BNhXs= github.com/gkampitakis/ciinfo v0.3.2/go.mod h1:1NIwaOcFChN4fa/B0hEBdAb6npDlFL8Bwx4dfRLRqAo= @@ -228,18 +227,10 @@ github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7Lk github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/gofiber/contrib/fiberzerolog v1.0.3 h1:Z97hA5bNfThtZjEYG12g9YcT8I/cmCikNgmE4uzFk0U= -github.com/gofiber/contrib/fiberzerolog v1.0.3/go.mod h1:0MD+NNFy0nZwiSo4dSVW7WwWVzOyuATNXwhJwgOP8uM= github.com/gofiber/fiber/v2 v2.52.9 h1:YjKl5DOiyP3j0mO61u3NTmK7or8GzzWzCFzkboyP5cw= github.com/gofiber/fiber/v2 v2.52.9/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw= github.com/gofiber/swagger v1.1.1 h1:FZVhVQQ9s1ZKLHL/O0loLh49bYB5l1HEAgxDlcTtkRA= github.com/gofiber/swagger v1.1.1/go.mod 
h1:vtvY/sQAMc/lGTUCg0lqmBL7Ht9O7uzChpbvJeJQINw= -github.com/gofiber/template v1.8.3 h1:hzHdvMwMo/T2kouz2pPCA0zGiLCeMnoGsQZBTSYgZxc= -github.com/gofiber/template v1.8.3/go.mod h1:bs/2n0pSNPOkRa5VJ8zTIvedcI/lEYxzV3+YPXdBvq8= -github.com/gofiber/template/html/v2 v2.1.3 h1:n1LYBtmr9C0V/k/3qBblXyMxV5B0o/gpb6dFLp8ea+o= -github.com/gofiber/template/html/v2 v2.1.3/go.mod h1:U5Fxgc5KpyujU9OqKzy6Kn6Qup6Tm7zdsISR+VpnHRE= -github.com/gofiber/utils v1.1.0 h1:vdEBpn7AzIUJRhe+CiTOJdUcTg4Q9RK+pEa0KPbLdrM= -github.com/gofiber/utils v1.1.0/go.mod h1:poZpsnhBykfnY1Mc0KeEa6mSHrS3dV0+oBWyeQmb2e0= github.com/gofiber/websocket/v2 v2.2.1 h1:C9cjxvloojayOp9AovmpQrk8VqvVnT8Oao3+IUygH7w= github.com/gofiber/websocket/v2 v2.2.1/go.mod h1:Ao/+nyNnX5u/hIFPuHl28a+NIkrqK7PRimyKaj4JxVU= github.com/gofrs/flock v0.13.0 h1:95JolYOvGMqeH31+FC7D2+uULf6mG61mEZ/A8dRYMzw= @@ -405,6 +396,12 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/labstack/echo/v4 v4.13.3 h1:pwhpCPrTl5qry5HRdM5FwdXnhXSLSY+WE+YQSeCaafY= +github.com/labstack/echo/v4 v4.13.3/go.mod h1:o90YNEeQWjDozo584l7AwhJMHN0bOC4tAfg+Xox9q5g= +github.com/labstack/echo/v4 v4.13.4 h1:oTZZW+T3s9gAu5L8vmzihV7/lkXGZuITzTQkTEhcXEA= +github.com/labstack/echo/v4 v4.13.4/go.mod h1:g63b33BZ5vZzcIUF8AtRH40DrTlXnx4UMC8rBdndmjQ= +github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0= +github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU= github.com/libp2p/go-buffer-pool v0.1.0 h1:oK4mSFcQz7cTQIfqbe4MIj9gLW+mnanjyFtc6cdF0Y8= github.com/libp2p/go-buffer-pool v0.1.0/go.mod h1:N+vh8gMqimBzdKkSMVuydVDq+UV5QTWy5HSiZacSbPg= github.com/libp2p/go-cidranger v1.1.0 h1:ewPN8EZ0dd1LSnrtuwd4709PXVcITVeuwbag38yPW7c= @@ 
-596,8 +593,6 @@ github.com/peterbourgon/diskv v2.0.1+incompatible h1:UBdAOUP5p4RWqPBg048CAvpKN+v github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU= github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 h1:Ii+DKncOVM8Cu1Hc+ETb5K+23HdAMvESYE3ZJ5b5cMI= github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5/go.mod h1:iIss55rKnNBTvrwdmkUpLnDpZoAHvWaiq5+iMmen4AE= -github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c h1:dAMKvw0MlJT1GshSTtih8C2gDs04w8dReiOGXrGLNoY= -github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM= github.com/pierrec/lz4/v4 v4.1.2 h1:qvY3YFXRQE/XB8MlLzJH7mSzBs74eA2gg52YTk6jUPM= github.com/pierrec/lz4/v4 v4.1.2/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pion/datachannel v1.5.10 h1:ly0Q26K1i6ZkGf42W7D4hQYR90pZwzFOjTq5AuCKk4o= @@ -774,6 +769,8 @@ github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/swaggo/echo-swagger v1.4.1 h1:Yf0uPaJWp1uRtDloZALyLnvdBeoEL5Kc7DtnjzO/TUk= +github.com/swaggo/echo-swagger v1.4.1/go.mod h1:C8bSi+9yH2FLZsnhqMZLIZddpUxZdBYuNHbtaS1Hljc= github.com/swaggo/files/v2 v2.0.2 h1:Bq4tgS/yxLB/3nwOMcul5oLEUKa877Ykgz3CJMVbQKU= github.com/swaggo/files/v2 v2.0.2/go.mod h1:TVqetIzZsO9OhHX1Am9sRf9LdrFZqoK49N37KON/jr0= github.com/swaggo/swag v1.16.6 h1:qBNcx53ZaX+M5dxVyTrgQ0PJ/ACK+NzhwcbieTt+9yI= @@ -789,8 +786,6 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= 
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/tinylib/msgp v1.2.5 h1:WeQg1whrXRFiZusidTQqzETkRpGjFjcIhW6uqWH09po= -github.com/tinylib/msgp v1.2.5/go.mod h1:ykjzy2wzgrlvpDCRc4LA8UXy6D8bzMSuAF3WD57Gok0= github.com/tklauser/go-sysconf v0.3.15 h1:VE89k0criAymJ/Os65CSn1IXaol+1wrsFHEB8Ol49K4= github.com/tklauser/go-sysconf v0.3.15/go.mod h1:Dmjwr6tYFIseJw7a3dRLJfsHAMXZ3nEnL/aZY+0IuI4= github.com/tklauser/numcpus v0.10.0 h1:18njr6LDBk1zuna922MgdjQuJFjrdppsZG60sHGfjso= @@ -807,6 +802,8 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.68.0 h1:v12Nx16iepr8r9ySOwqI+5RBJ/DqTxhOy1HrHoDFnok= github.com/valyala/fasthttp v1.68.0/go.mod h1:5EXiRfYQAoiO/khu4oU9VISC/eVY6JqmSpPJoHCKsz4= +github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= +github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= github.com/vbatts/tar-split v0.11.3 h1:hLFqsOLQ1SsppQNTMpkpPXClLDfC2A3Zgy9OUU+RVck= github.com/vbatts/tar-split v0.11.3/go.mod h1:9QlHN18E+fEH7RdG+QAJJcuya3rqT7eXSTY7wGrAokY= github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= From 4c7c4b00e087f4654a554e06b9f1fe911954539a Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Fri, 14 Nov 2025 15:45:25 +0100 Subject: [PATCH 2/2] tests Signed-off-by: Ettore Di Giacinto --- core/http/endpoints/explorer/dashboard.go | 4 +- core/http/endpoints/localai/backend.go | 6 +- core/http/endpoints/localai/edit_model.go | 2 +- .../http/endpoints/localai/edit_model_test.go | 23 +- core/http/endpoints/localai/gallery.go | 6 +- core/http/endpoints/localai/import_model.go | 2 +- core/http/endpoints/localai/video.go | 11 +- core/http/endpoints/localai/welcome.go | 8 +- core/http/endpoints/openai/image.go | 11 +- core/http/middleware/auth.go | 3 
+- core/http/{utils => middleware}/baseurl.go | 32 ++- core/http/middleware/baseurl_test.go | 58 +++++ core/http/middleware/middleware_suite_test.go | 13 ++ core/http/middleware/strippathprefix.go | 19 +- core/http/middleware/strippathprefix_test.go | 211 ++++++++++-------- core/http/render.go | 4 +- core/http/routes/localai.go | 3 +- core/http/routes/ui.go | 26 +-- core/http/routes/ui_backend_gallery.go | 4 +- core/http/routes/ui_gallery.go | 4 +- core/http/utils/baseurl_test.go | 48 ---- 21 files changed, 286 insertions(+), 212 deletions(-) rename core/http/{utils => middleware}/baseurl.go (64%) create mode 100644 core/http/middleware/baseurl_test.go create mode 100644 core/http/middleware/middleware_suite_test.go delete mode 100644 core/http/utils/baseurl_test.go diff --git a/core/http/endpoints/explorer/dashboard.go b/core/http/endpoints/explorer/dashboard.go index 34b57fc6657c..3c1e0ae91337 100644 --- a/core/http/endpoints/explorer/dashboard.go +++ b/core/http/endpoints/explorer/dashboard.go @@ -8,7 +8,7 @@ import ( "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/explorer" - "github.com/mudler/LocalAI/core/http/utils" + "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/internal" ) @@ -17,7 +17,7 @@ func Dashboard() echo.HandlerFunc { summary := map[string]interface{}{ "Title": "LocalAI API - " + internal.PrintableVersion(), "Version": internal.PrintableVersion(), - "BaseURL": utils.BaseURL(c), + "BaseURL": middleware.BaseURL(c), } contentType := c.Request().Header.Get("Content-Type") diff --git a/core/http/endpoints/localai/backend.go b/core/http/endpoints/localai/backend.go index 80f47f658431..4c692538cc61 100644 --- a/core/http/endpoints/localai/backend.go +++ b/core/http/endpoints/localai/backend.go @@ -8,7 +8,7 @@ import ( "github.com/google/uuid" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" - "github.com/mudler/LocalAI/core/http/utils" + 
"github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/pkg/system" @@ -82,7 +82,7 @@ func (mgs *BackendEndpointService) ApplyBackendEndpoint() echo.HandlerFunc { Galleries: mgs.galleries, } - return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", utils.BaseURL(c), uuid.String())}) + return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", middleware.BaseURL(c), uuid.String())}) } } @@ -106,7 +106,7 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() echo.HandlerFunc { return err } - return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", utils.BaseURL(c), uuid.String())}) + return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", middleware.BaseURL(c), uuid.String())}) } } diff --git a/core/http/endpoints/localai/edit_model.go b/core/http/endpoints/localai/edit_model.go index 697238cf0e72..4c59add22c31 100644 --- a/core/http/endpoints/localai/edit_model.go +++ b/core/http/endpoints/localai/edit_model.go @@ -8,7 +8,7 @@ import ( "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" - httpUtils "github.com/mudler/LocalAI/core/http/utils" + httpUtils "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/internal" "github.com/mudler/LocalAI/pkg/utils" diff --git a/core/http/endpoints/localai/edit_model_test.go b/core/http/endpoints/localai/edit_model_test.go index 6e4c7bf936f3..b354dbc2b249 100644 --- a/core/http/endpoints/localai/edit_model_test.go +++ b/core/http/endpoints/localai/edit_model_test.go @@ -2,6 +2,7 @@ package localai_test import ( "bytes" + "encoding/json" "io" "net/http" "net/http/httptest" @@ -16,6 +17,14 @@ import ( . 
"github.com/onsi/gomega" ) +// testRenderer is a simple renderer for tests that returns JSON +type testRenderer struct{} + +func (t *testRenderer) Render(w io.Writer, name string, data interface{}, c echo.Context) error { + // For tests, just return the data as JSON + return json.NewEncoder(w).Encode(data) +} + var _ = Describe("Edit Model test", func() { var tempDir string @@ -41,9 +50,12 @@ var _ = Describe("Edit Model test", func() { //modelLoader := model.NewModelLoader(systemState, true) modelConfigLoader := config.NewModelConfigLoader(systemState.Model.ModelsPath) - // Define Echo app. + // Define Echo app and register all routes upfront app := echo.New() + // Set up a simple renderer for the test + app.Renderer = &testRenderer{} app.POST("/import-model", ImportModelEndpoint(modelConfigLoader, applicationConfig)) + app.GET("/edit-model/:name", GetEditModelPage(modelConfigLoader, applicationConfig)) requestBody := bytes.NewBufferString(`{"name": "foo", "backend": "foo", "model": "foo"}`) @@ -57,16 +69,15 @@ var _ = Describe("Edit Model test", func() { Expect(string(body)).To(ContainSubstring("Model configuration created successfully")) Expect(rec.Code).To(Equal(http.StatusOK)) - app.GET("/edit-model/:name", GetEditModelPage(modelConfigLoader, applicationConfig)) - requestBody = bytes.NewBufferString(`{"name": "foo", "parameters": { "model": "foo"}}`) - - req = httptest.NewRequest("GET", "/edit-model/foo", requestBody) + req = httptest.NewRequest("GET", "/edit-model/foo", nil) rec = httptest.NewRecorder() app.ServeHTTP(rec, req) body, err = io.ReadAll(rec.Body) Expect(err).ToNot(HaveOccurred()) - Expect(string(body)).To(ContainSubstring(`"model":"foo"`)) + // The response contains the model configuration with backend field + Expect(string(body)).To(ContainSubstring(`"backend":"foo"`)) + Expect(string(body)).To(ContainSubstring(`"name":"foo"`)) Expect(rec.Code).To(Equal(http.StatusOK)) }) }) diff --git a/core/http/endpoints/localai/gallery.go 
b/core/http/endpoints/localai/gallery.go index 1938a9eb5bed..9a96fd1c2efc 100644 --- a/core/http/endpoints/localai/gallery.go +++ b/core/http/endpoints/localai/gallery.go @@ -8,7 +8,7 @@ import ( "github.com/google/uuid" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" - "github.com/mudler/LocalAI/core/http/utils" + "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/pkg/system" @@ -85,7 +85,7 @@ func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() echo.Handler BackendGalleries: mgs.backendGalleries, } - return c.JSON(200, schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", utils.BaseURL(c), uuid.String())}) + return c.JSON(200, schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", middleware.BaseURL(c), uuid.String())}) } } @@ -108,7 +108,7 @@ func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() echo.Handle return err } - return c.JSON(200, schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", utils.BaseURL(c), uuid.String())}) + return c.JSON(200, schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", middleware.BaseURL(c), uuid.String())}) } } diff --git a/core/http/endpoints/localai/import_model.go b/core/http/endpoints/localai/import_model.go index 8e11bbc01161..d44d11ff8deb 100644 --- a/core/http/endpoints/localai/import_model.go +++ b/core/http/endpoints/localai/import_model.go @@ -14,7 +14,7 @@ import ( "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/gallery/importers" - httpUtils "github.com/mudler/LocalAI/core/http/utils" + httpUtils "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/pkg/utils" 
diff --git a/core/http/endpoints/localai/video.go b/core/http/endpoints/localai/video.go index 5d2482b61022..7ffb6a3dcf09 100644 --- a/core/http/endpoints/localai/video.go +++ b/core/http/endpoints/localai/video.go @@ -7,16 +7,16 @@ import ( "fmt" "io" "net/http" + "net/url" "os" "path/filepath" "strings" "time" - "github.com/labstack/echo/v4" "github.com/google/uuid" + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" - "github.com/mudler/LocalAI/core/http/utils" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/backend" @@ -165,7 +165,7 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi return err } - baseURL := utils.BaseURL(c) + baseURL := middleware.BaseURL(c) fn, err := backend.VideoGeneration( height, @@ -202,7 +202,10 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi item.B64JSON = base64.StdEncoding.EncodeToString(data) } else { base := filepath.Base(output) - item.URL = baseURL + "/generated-videos/" + base + item.URL, err = url.JoinPath(baseURL, "generated-videos", base) + if err != nil { + return err + } } id := uuid.New().String() diff --git a/core/http/endpoints/localai/welcome.go b/core/http/endpoints/localai/welcome.go index f571ba7d2fb6..76f7f1a4a969 100644 --- a/core/http/endpoints/localai/welcome.go +++ b/core/http/endpoints/localai/welcome.go @@ -6,7 +6,7 @@ import ( "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" - "github.com/mudler/LocalAI/core/http/utils" + "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/internal" "github.com/mudler/LocalAI/pkg/model" @@ -45,7 +45,7 @@ func WelcomeEndpoint(appConfig *config.ApplicationConfig, summary := map[string]interface{}{ "Title": "LocalAI API - " + internal.PrintableVersion(), "Version": internal.PrintableVersion(), 
- "BaseURL": utils.BaseURL(c), + "BaseURL": middleware.BaseURL(c), "Models": modelsWithoutConfig, "ModelsConfig": modelConfigs, "GalleryConfig": galleryConfigs, @@ -58,7 +58,9 @@ func WelcomeEndpoint(appConfig *config.ApplicationConfig, contentType := c.Request().Header.Get("Content-Type") accept := c.Request().Header.Get("Accept") - if strings.Contains(contentType, "application/json") || !strings.Contains(accept, "text/html") { + // Default to HTML if Accept header is empty (browser behavior) + // Only return JSON if explicitly requested or Content-Type is application/json + if strings.Contains(contentType, "application/json") || (accept != "" && !strings.Contains(accept, "text/html")) { // The client expects a JSON response return c.JSON(200, summary) } else { diff --git a/core/http/endpoints/openai/image.go b/core/http/endpoints/openai/image.go index 14d571e36f7f..9e7aef1a07d6 100644 --- a/core/http/endpoints/openai/image.go +++ b/core/http/endpoints/openai/image.go @@ -7,17 +7,17 @@ import ( "fmt" "io" "net/http" + "net/url" "os" "path/filepath" "strconv" "strings" "time" - "github.com/labstack/echo/v4" "github.com/google/uuid" + "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" - "github.com/mudler/LocalAI/core/http/utils" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/backend" @@ -189,7 +189,7 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi return err } - baseURL := utils.BaseURL(c) + baseURL := middleware.BaseURL(c) // Use the first input image as src if available, otherwise use the original src inputSrc := src @@ -216,7 +216,10 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi item.B64JSON = base64.StdEncoding.EncodeToString(data) } else { base := filepath.Base(output) - item.URL = baseURL + "/generated-images/" + base + item.URL, err = url.JoinPath(baseURL, "generated-images", base) + if err != 
nil { + return err + } } result = append(result, *item) diff --git a/core/http/middleware/auth.go b/core/http/middleware/auth.go index 0a392d24829a..2538b795e992 100644 --- a/core/http/middleware/auth.go +++ b/core/http/middleware/auth.go @@ -9,7 +9,6 @@ import ( "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" "github.com/mudler/LocalAI/core/config" - "github.com/mudler/LocalAI/core/http/utils" "github.com/mudler/LocalAI/core/schema" ) @@ -115,7 +114,7 @@ func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) func(err } return c.Render(http.StatusUnauthorized, "views/login", map[string]interface{}{ - "BaseURL": utils.BaseURL(c), + "BaseURL": BaseURL(c), }) } if applicationConfig.OpaqueErrors { diff --git a/core/http/utils/baseurl.go b/core/http/middleware/baseurl.go similarity index 64% rename from core/http/utils/baseurl.go rename to core/http/middleware/baseurl.go index 0046a6e6e357..78a59289a81f 100644 --- a/core/http/utils/baseurl.go +++ b/core/http/middleware/baseurl.go @@ -1,4 +1,4 @@ -package utils +package middleware import ( "strings" @@ -14,6 +14,25 @@ func BaseURL(c echo.Context) string { path := c.Path() origPath := c.Request().URL.Path + // Check if StripPathPrefix middleware stored the original path + if storedPath, ok := c.Get("_original_path").(string); ok && storedPath != "" { + origPath = storedPath + } + + // Check X-Forwarded-Proto for scheme + scheme := "http" + if c.Request().Header.Get("X-Forwarded-Proto") == "https" { + scheme = "https" + } else if c.Request().TLS != nil { + scheme = "https" + } + + // Check X-Forwarded-Host for host + host := c.Request().Host + if forwardedHost := c.Request().Header.Get("X-Forwarded-Host"); forwardedHost != "" { + host = forwardedHost + } + if path != origPath && strings.HasSuffix(origPath, path) && len(path) > 0 { prefixLen := len(origPath) - len(path) if prefixLen > 0 && prefixLen <= len(origPath) { @@ -21,20 +40,9 @@ func BaseURL(c echo.Context) string { if 
!strings.HasSuffix(pathPrefix, "/") { pathPrefix += "/" } - - scheme := "http" - if c.Request().TLS != nil { - scheme = "https" - } - host := c.Request().Host return scheme + "://" + host + pathPrefix } } - scheme := "http" - if c.Request().TLS != nil { - scheme = "https" - } - host := c.Request().Host return scheme + "://" + host + "/" } diff --git a/core/http/middleware/baseurl_test.go b/core/http/middleware/baseurl_test.go new file mode 100644 index 000000000000..b0770b8eae41 --- /dev/null +++ b/core/http/middleware/baseurl_test.go @@ -0,0 +1,58 @@ +package middleware + +import ( + "net/http/httptest" + + "github.com/labstack/echo/v4" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("BaseURL", func() { + Context("without prefix", func() { + It("should return base URL without prefix", func() { + app := echo.New() + actualURL := "" + + // Register route - use the actual request path so routing works + routePath := "/hello/world" + app.GET(routePath, func(c echo.Context) error { + actualURL = BaseURL(c) + return nil + }) + + req := httptest.NewRequest("GET", "/hello/world", nil) + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + + Expect(rec.Code).To(Equal(200), "response status code") + Expect(actualURL).To(Equal("http://example.com/"), "base URL") + }) + }) + + Context("with prefix", func() { + It("should return base URL with prefix", func() { + app := echo.New() + actualURL := "" + + // Register route with the stripped path (after middleware removes prefix) + routePath := "/hello/world" + app.GET(routePath, func(c echo.Context) error { + // Simulate what StripPathPrefix middleware does - store original path + c.Set("_original_path", "/myprefix/hello/world") + // Modify the request path to simulate prefix stripping + c.Request().URL.Path = "/hello/world" + actualURL = BaseURL(c) + return nil + }) + + // Make request with stripped path (middleware would have already processed it) + req := httptest.NewRequest("GET", 
"/hello/world", nil) + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + + Expect(rec.Code).To(Equal(200), "response status code") + Expect(actualURL).To(Equal("http://example.com/myprefix/"), "base URL") + }) + }) +}) diff --git a/core/http/middleware/middleware_suite_test.go b/core/http/middleware/middleware_suite_test.go new file mode 100644 index 000000000000..0f40add2539d --- /dev/null +++ b/core/http/middleware/middleware_suite_test.go @@ -0,0 +1,13 @@ +package middleware_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestMiddleware(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Middleware test suite") +} diff --git a/core/http/middleware/strippathprefix.go b/core/http/middleware/strippathprefix.go index 9de3c05d2f76..451ccfe667ca 100644 --- a/core/http/middleware/strippathprefix.go +++ b/core/http/middleware/strippathprefix.go @@ -24,9 +24,22 @@ func StripPathPrefix() echo.MiddlewareFunc { if strings.HasPrefix(originalPath, normalizedPrefix) { // Update the request path by stripping the normalized prefix - c.Request().URL.Path = originalPath[len(normalizedPrefix):] - if c.Request().URL.Path == "" { - c.Request().URL.Path = "/" + newPath := originalPath[len(normalizedPrefix):] + if newPath == "" { + newPath = "/" + } + // Ensure path starts with / for proper routing + if !strings.HasPrefix(newPath, "/") { + newPath = "/" + newPath + } + // Update the URL path - Echo's router uses URL.Path for routing + c.Request().URL.Path = newPath + c.Request().URL.RawPath = "" + // Update RequestURI to match the new path (needed for proper routing) + if c.Request().URL.RawQuery != "" { + c.Request().RequestURI = newPath + "?" 
+ c.Request().URL.RawQuery + } else { + c.Request().RequestURI = newPath } // Store original path for BaseURL utility c.Set("_original_path", originalPath) diff --git a/core/http/middleware/strippathprefix_test.go b/core/http/middleware/strippathprefix_test.go index a6b9fd431cae..32c1c5d4af6a 100644 --- a/core/http/middleware/strippathprefix_test.go +++ b/core/http/middleware/strippathprefix_test.go @@ -2,120 +2,133 @@ package middleware import ( "net/http/httptest" - "testing" "github.com/labstack/echo/v4" - "github.com/stretchr/testify/require" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" ) -func TestStripPathPrefix(t *testing.T) { +var _ = Describe("StripPathPrefix", func() { + var app *echo.Echo var actualPath string + var appInitialized bool - app := echo.New() + BeforeEach(func() { + actualPath = "" + if !appInitialized { + app = echo.New() + app.Pre(StripPathPrefix()) - app.Pre(StripPathPrefix()) + app.GET("/hello/world", func(c echo.Context) error { + actualPath = c.Request().URL.Path + return nil + }) - app.GET("/hello/world", func(c echo.Context) error { - actualPath = c.Request().URL.Path - return nil + app.GET("/", func(c echo.Context) error { + actualPath = c.Request().URL.Path + return nil + }) + appInitialized = true + } }) - app.GET("/", func(c echo.Context) error { - actualPath = c.Request().URL.Path - return nil + Context("without prefix", func() { + It("should not modify path when no header is present", func() { + req := httptest.NewRequest("GET", "/hello/world", nil) + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + + Expect(rec.Code).To(Equal(200), "response status code") + Expect(actualPath).To(Equal("/hello/world"), "rewritten path") + }) + + It("should not modify root path when no header is present", func() { + req := httptest.NewRequest("GET", "/", nil) + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + + Expect(rec.Code).To(Equal(200), "response status code") + Expect(actualPath).To(Equal("/"), "rewritten 
path") + }) + + It("should not modify path when header does not match", func() { + req := httptest.NewRequest("GET", "/hello/world", nil) + req.Header["X-Forwarded-Prefix"] = []string{"/otherprefix/"} + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + + Expect(rec.Code).To(Equal(200), "response status code") + Expect(actualPath).To(Equal("/hello/world"), "rewritten path") + }) }) - for _, tc := range []struct { - name string - path string - prefixHeader []string - expectStatus int - expectPath string - }{ - { - name: "without prefix and header", - path: "/hello/world", - expectStatus: 200, - expectPath: "/hello/world", - }, - { - name: "without prefix and headers on root path", - path: "/", - expectStatus: 200, - expectPath: "/", - }, - { - name: "without prefix but header", - path: "/hello/world", - prefixHeader: []string{"/otherprefix/"}, - expectStatus: 200, - expectPath: "/hello/world", - }, - { - name: "with prefix but non-matching header", - path: "/prefix/hello/world", - prefixHeader: []string{"/otherprefix/"}, - expectStatus: 404, - }, - { - name: "with prefix and matching header", - path: "/myprefix/hello/world", - prefixHeader: []string{"/myprefix/"}, - expectStatus: 200, - expectPath: "/hello/world", - }, - { - name: "with prefix and 1st header matching", - path: "/myprefix/hello/world", - prefixHeader: []string{"/myprefix/", "/otherprefix/"}, - expectStatus: 200, - expectPath: "/hello/world", - }, - { - name: "with prefix and 2nd header matching", - path: "/myprefix/hello/world", - prefixHeader: []string{"/otherprefix/", "/myprefix/"}, - expectStatus: 200, - expectPath: "/hello/world", - }, - { - name: "with prefix and header not ending with slash", - path: "/myprefix/hello/world", - prefixHeader: []string{"/myprefix"}, - expectStatus: 200, - expectPath: "/hello/world", - }, - { - name: "with prefix and non-matching header not ending with slash", - path: "/myprefix-suffix/hello/world", - prefixHeader: []string{"/myprefix"}, - expectStatus: 404, 
- }, - { - name: "redirect when prefix does not end with a slash", - path: "/myprefix", - prefixHeader: []string{"/myprefix"}, - expectStatus: 302, - expectPath: "/myprefix/", - }, - } { - t.Run(tc.name, func(t *testing.T) { - actualPath = "" - req := httptest.NewRequest("GET", tc.path, nil) - if tc.prefixHeader != nil { - req.Header["X-Forwarded-Prefix"] = tc.prefixHeader - } + Context("with prefix", func() { + It("should return 404 when prefix does not match header", func() { + req := httptest.NewRequest("GET", "/prefix/hello/world", nil) + req.Header["X-Forwarded-Prefix"] = []string{"/otherprefix/"} + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + Expect(rec.Code).To(Equal(404), "response status code") + }) + + It("should strip matching prefix from path", func() { + req := httptest.NewRequest("GET", "/myprefix/hello/world", nil) + req.Header["X-Forwarded-Prefix"] = []string{"/myprefix/"} rec := httptest.NewRecorder() app.ServeHTTP(rec, req) - require.Equal(t, tc.expectStatus, rec.Code, "response status code") + Expect(rec.Code).To(Equal(200), "response status code") + Expect(actualPath).To(Equal("/hello/world"), "rewritten path") + }) + + It("should strip prefix when it matches the first header value", func() { + req := httptest.NewRequest("GET", "/myprefix/hello/world", nil) + req.Header["X-Forwarded-Prefix"] = []string{"/myprefix/", "/otherprefix/"} + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) - if tc.expectStatus == 200 { - require.Equal(t, tc.expectPath, actualPath, "rewritten path") - } else if tc.expectStatus == 302 { - require.Equal(t, tc.expectPath, rec.Header().Get("Location"), "redirect location") - } + Expect(rec.Code).To(Equal(200), "response status code") + Expect(actualPath).To(Equal("/hello/world"), "rewritten path") }) - } -} + + It("should strip prefix when it matches the second header value", func() { + req := httptest.NewRequest("GET", "/myprefix/hello/world", nil) + req.Header["X-Forwarded-Prefix"] = 
[]string{"/otherprefix/", "/myprefix/"} + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + + Expect(rec.Code).To(Equal(200), "response status code") + Expect(actualPath).To(Equal("/hello/world"), "rewritten path") + }) + + It("should strip prefix when header does not end with slash", func() { + req := httptest.NewRequest("GET", "/myprefix/hello/world", nil) + req.Header["X-Forwarded-Prefix"] = []string{"/myprefix"} + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + + Expect(rec.Code).To(Equal(200), "response status code") + Expect(actualPath).To(Equal("/hello/world"), "rewritten path") + }) + + It("should return 404 when prefix does not match header without trailing slash", func() { + req := httptest.NewRequest("GET", "/myprefix-suffix/hello/world", nil) + req.Header["X-Forwarded-Prefix"] = []string{"/myprefix"} + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + + Expect(rec.Code).To(Equal(404), "response status code") + }) + + It("should redirect when prefix does not end with a slash", func() { + req := httptest.NewRequest("GET", "/myprefix", nil) + req.Header["X-Forwarded-Prefix"] = []string{"/myprefix"} + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + + Expect(rec.Code).To(Equal(302), "response status code") + Expect(rec.Header().Get("Location")).To(Equal("/myprefix/"), "redirect location") + }) + }) +}) diff --git a/core/http/render.go b/core/http/render.go index 14f9884af4ae..569c77987720 100644 --- a/core/http/render.go +++ b/core/http/render.go @@ -12,7 +12,7 @@ import ( "github.com/Masterminds/sprig/v3" "github.com/labstack/echo/v4" "github.com/microcosm-cc/bluemonday" - "github.com/mudler/LocalAI/core/http/utils" + "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" "github.com/russross/blackfriday" ) @@ -42,7 +42,7 @@ func notFoundHandler(c echo.Context) error { } else { // The client expects an HTML response return c.Render(http.StatusNotFound, "views/404", map[string]interface{}{ 
- "BaseURL": utils.BaseURL(c), + "BaseURL": middleware.BaseURL(c), }) } } diff --git a/core/http/routes/localai.go b/core/http/routes/localai.go index 9f44e2e6b27c..7b1c003ca021 100644 --- a/core/http/routes/localai.go +++ b/core/http/routes/localai.go @@ -5,7 +5,6 @@ import ( "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/endpoints/localai" "github.com/mudler/LocalAI/core/http/middleware" - httpUtils "github.com/mudler/LocalAI/core/http/utils" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/internal" @@ -29,7 +28,7 @@ func RegisterLocalAIRoutes(router *echo.Echo, router.GET("/import-model", func(c echo.Context) error { return c.Render(200, "views/model-editor", map[string]interface{}{ "Title": "LocalAI - Import Model", - "BaseURL": httpUtils.BaseURL(c), + "BaseURL": middleware.BaseURL(c), "Version": internal.PrintableVersion(), }) }) diff --git a/core/http/routes/ui.go b/core/http/routes/ui.go index d866f71800c6..03cb3c9b7c90 100644 --- a/core/http/routes/ui.go +++ b/core/http/routes/ui.go @@ -5,7 +5,7 @@ import ( "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/http/endpoints/localai" - "github.com/mudler/LocalAI/core/http/utils" + "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/internal" "github.com/mudler/LocalAI/pkg/model" @@ -26,7 +26,7 @@ func RegisterUIRoutes(app *echo.Echo, app.GET("/p2p/", func(c echo.Context) error { summary := map[string]interface{}{ "Title": "LocalAI - P2P dashboard", - "BaseURL": utils.BaseURL(c), + "BaseURL": middleware.BaseURL(c), "Version": internal.PrintableVersion(), //"Nodes": p2p.GetAvailableNodes(""), //"FederatedNodes": p2p.GetAvailableNodes(p2p.FederatedID), @@ -54,12 +54,12 @@ func RegisterUIRoutes(app *echo.Echo, if len(modelConfigs) == 0 { // If no model is available redirect to the index which 
suggests how to install models - return c.Redirect(302, utils.BaseURL(c)) + return c.Redirect(302, middleware.BaseURL(c)) } summary := map[string]interface{}{ "Title": "LocalAI - Talk", - "BaseURL": utils.BaseURL(c), + "BaseURL": middleware.BaseURL(c), "ModelsConfig": modelConfigs, "Model": modelConfigs[0], @@ -76,7 +76,7 @@ func RegisterUIRoutes(app *echo.Echo, if len(modelConfigs)+len(modelsWithoutConfig) == 0 { // If no model is available redirect to the index which suggests how to install models - return c.Redirect(302, utils.BaseURL(c)) + return c.Redirect(302, middleware.BaseURL(c)) } modelThatCanBeUsed := "" galleryConfigs := map[string]*gallery.ModelConfig{} @@ -105,7 +105,7 @@ func RegisterUIRoutes(app *echo.Echo, summary := map[string]interface{}{ "Title": title, - "BaseURL": utils.BaseURL(c), + "BaseURL": middleware.BaseURL(c), "ModelsWithoutConfig": modelsWithoutConfig, "GalleryConfig": galleryConfigs, "ModelsConfig": modelConfigs, @@ -140,7 +140,7 @@ func RegisterUIRoutes(app *echo.Echo, summary := map[string]interface{}{ "Title": "LocalAI - Chat with " + modelName, - "BaseURL": utils.BaseURL(c), + "BaseURL": middleware.BaseURL(c), "ModelsConfig": modelConfigs, "GalleryConfig": galleryConfigs, "ModelsWithoutConfig": modelsWithoutConfig, @@ -159,7 +159,7 @@ func RegisterUIRoutes(app *echo.Echo, summary := map[string]interface{}{ "Title": "LocalAI - Generate images with " + c.Param("model"), - "BaseURL": utils.BaseURL(c), + "BaseURL": middleware.BaseURL(c), "ModelsConfig": modelConfigs, "ModelsWithoutConfig": modelsWithoutConfig, "Model": c.Param("model"), @@ -176,7 +176,7 @@ func RegisterUIRoutes(app *echo.Echo, if len(modelConfigs)+len(modelsWithoutConfig) == 0 { // If no model is available redirect to the index which suggests how to install models - return c.Redirect(302, utils.BaseURL(c)) + return c.Redirect(302, middleware.BaseURL(c)) } modelThatCanBeUsed := "" @@ -192,7 +192,7 @@ func RegisterUIRoutes(app *echo.Echo, summary := 
map[string]interface{}{ "Title": title, - "BaseURL": utils.BaseURL(c), + "BaseURL": middleware.BaseURL(c), "ModelsConfig": modelConfigs, "ModelsWithoutConfig": modelsWithoutConfig, "Model": modelThatCanBeUsed, @@ -209,7 +209,7 @@ func RegisterUIRoutes(app *echo.Echo, summary := map[string]interface{}{ "Title": "LocalAI - Generate images with " + c.Param("model"), - "BaseURL": utils.BaseURL(c), + "BaseURL": middleware.BaseURL(c), "ModelsConfig": modelConfigs, "ModelsWithoutConfig": modelsWithoutConfig, "Model": c.Param("model"), @@ -226,7 +226,7 @@ func RegisterUIRoutes(app *echo.Echo, if len(modelConfigs)+len(modelsWithoutConfig) == 0 { // If no model is available redirect to the index which suggests how to install models - return c.Redirect(302, utils.BaseURL(c)) + return c.Redirect(302, middleware.BaseURL(c)) } modelThatCanBeUsed := "" @@ -241,7 +241,7 @@ func RegisterUIRoutes(app *echo.Echo, } summary := map[string]interface{}{ "Title": title, - "BaseURL": utils.BaseURL(c), + "BaseURL": middleware.BaseURL(c), "ModelsConfig": modelConfigs, "ModelsWithoutConfig": modelsWithoutConfig, "Model": modelThatCanBeUsed, diff --git a/core/http/routes/ui_backend_gallery.go b/core/http/routes/ui_backend_gallery.go index 52502cb613fa..8f0a31351236 100644 --- a/core/http/routes/ui_backend_gallery.go +++ b/core/http/routes/ui_backend_gallery.go @@ -3,7 +3,7 @@ package routes import ( "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" - "github.com/mudler/LocalAI/core/http/utils" + "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/internal" ) @@ -13,7 +13,7 @@ func registerBackendGalleryRoutes(app *echo.Echo, appConfig *config.ApplicationC app.GET("/browse/backends", func(c echo.Context) error { summary := map[string]interface{}{ "Title": "LocalAI - Backends", - "BaseURL": utils.BaseURL(c), + "BaseURL": middleware.BaseURL(c), "Version": internal.PrintableVersion(), "Repositories": 
appConfig.BackendGalleries, } diff --git a/core/http/routes/ui_gallery.go b/core/http/routes/ui_gallery.go index d96b6faa53da..cdcba24520c1 100644 --- a/core/http/routes/ui_gallery.go +++ b/core/http/routes/ui_gallery.go @@ -3,7 +3,7 @@ package routes import ( "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" - "github.com/mudler/LocalAI/core/http/utils" + "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/internal" ) @@ -13,7 +13,7 @@ func registerGalleryRoutes(app *echo.Echo, cl *config.ModelConfigLoader, appConf app.GET("/browse/", func(c echo.Context) error { summary := map[string]interface{}{ "Title": "LocalAI - Models", - "BaseURL": utils.BaseURL(c), + "BaseURL": middleware.BaseURL(c), "Version": internal.PrintableVersion(), "Repositories": appConfig.Galleries, } diff --git a/core/http/utils/baseurl_test.go b/core/http/utils/baseurl_test.go deleted file mode 100644 index d1d2070b586f..000000000000 --- a/core/http/utils/baseurl_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package utils - -import ( - "net/http/httptest" - "testing" - - "github.com/labstack/echo/v4" - "github.com/stretchr/testify/require" -) - -func TestBaseURL(t *testing.T) { - for _, tc := range []struct { - name string - prefix string - expectURL string - }{ - { - name: "without prefix", - prefix: "/", - expectURL: "http://example.com/", - }, - { - name: "with prefix", - prefix: "/myprefix/", - expectURL: "http://example.com/myprefix/", - }, - } { - t.Run(tc.name, func(t *testing.T) { - app := echo.New() - actualURL := "" - - app.GET(tc.prefix+"hello/world", func(c echo.Context) error { - if tc.prefix != "/" { - c.Request().URL.Path = "/hello/world" - } - actualURL = BaseURL(c) - return nil - }) - - req := httptest.NewRequest("GET", tc.prefix+"hello/world", nil) - rec := httptest.NewRecorder() - app.ServeHTTP(rec, req) - - require.Equal(t, 200, rec.Code, "response status code") - require.Equal(t, tc.expectURL, 
actualURL, "base URL") - }) - } -}