From cebfd830a440d2bfb6dc7e53680d5875863ee93e Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Wed, 18 Mar 2026 16:49:48 +0000 Subject: [PATCH 01/13] feat(ui): add users and authentication support Signed-off-by: Ettore Di Giacinto --- Dockerfile | 2 +- core/application/application.go | 21 + core/application/startup.go | 27 + core/cli/run.go | 35 + core/config/application_config.go | 60 + core/http/app.go | 44 +- core/http/auth/apikeys.go | 100 ++ core/http/auth/apikeys_test.go | 162 +++ core/http/auth/auth_suite_test.go | 15 + core/http/auth/db.go | 49 + core/http/auth/db_nosqlite.go | 13 + core/http/auth/db_sqlite.go | 12 + core/http/auth/db_test.go | 53 + core/http/auth/helpers_test.go | 155 +++ core/http/auth/middleware.go | 335 ++++++ core/http/auth/middleware_test.go | 305 +++++ core/http/auth/models.go | 103 ++ core/http/auth/oauth.go | 331 ++++++ core/http/auth/password.go | 14 + core/http/auth/permissions.go | 93 ++ core/http/auth/roles.go | 90 ++ core/http/auth/roles_test.go | 84 ++ core/http/auth/session.go | 87 ++ core/http/auth/session_test.go | 123 ++ core/http/auth/usage.go | 151 +++ core/http/auth/usage_test.go | 161 +++ .../endpoints/localai/agent_collections.go | 71 +- core/http/endpoints/localai/agent_jobs.go | 231 ++-- core/http/endpoints/localai/agent_skills.go | 88 +- core/http/endpoints/localai/agents.go | 85 +- core/http/middleware/trace.go | 10 +- core/http/middleware/usage.go | 148 +++ core/http/react-ui/src/App.css | 239 ++++ .../src/components/LoadingSpinner.jsx | 16 +- .../react-ui/src/components/RequireAdmin.jsx | 10 + .../react-ui/src/components/RequireAuth.jsx | 9 + .../src/components/RequireFeature.jsx | 10 + core/http/react-ui/src/components/Sidebar.jsx | 86 +- core/http/react-ui/src/components/Toast.jsx | 12 +- .../src/components/UserGroupSection.jsx | 156 +++ .../http/react-ui/src/context/AuthContext.jsx | 75 ++ core/http/react-ui/src/hooks/useUserMap.js | 29 + core/http/react-ui/src/main.jsx | 5 +- core/http/react-ui/src/pages/Account.jsx | 459 ++++++++ core/http/react-ui/src/pages/AgentJobs.jsx | 245 ++-- core/http/react-ui/src/pages/Agents.jsx | 45 +- core/http/react-ui/src/pages/Collections.jsx | 45 +- core/http/react-ui/src/pages/Home.jsx | 42 +- core/http/react-ui/src/pages/Login.jsx | 358 +++++- core/http/react-ui/src/pages/NotFound.jsx | 2 +- core/http/react-ui/src/pages/Skills.jsx | 52 +- core/http/react-ui/src/pages/Usage.jsx | 501 ++++++++ core/http/react-ui/src/pages/Users.jsx | 519 +++++++++ core/http/react-ui/src/router.jsx | 71 +- core/http/react-ui/src/utils/api.js | 70 +- core/http/routes/agents.go | 128 ++- core/http/routes/anthropic.go | 1 + core/http/routes/auth.go | 731 ++++++++++++ core/http/routes/auth_test.go | 807 +++++++++++++ core/http/routes/localai.go | 80 +- core/http/routes/openai.go | 5 + core/http/routes/openresponses.go | 1 + core/http/routes/ui.go | 17 +- core/http/routes/ui_api.go | 54 +- core/http/routes/ui_api_backends_test.go | 4 +- core/services/agent_jobs.go | 21 +- core/services/agent_pool.go | 1015 +++++++++++++++++ core/services/user_services.go | 183 +++ core/services/user_storage.go | 142 +++ docs/content/features/authentication.md | 292 +++++ docs/content/features/runtime-settings.md | 2 + docs/content/getting-started/quickstart.md | 5 +- docs/content/reference/api-errors.md | 11 +- docs/content/reference/cli-reference.md | 14 + go.mod | 20 +- go.sum | 40 +- 76 files changed, 9343 insertions(+), 544 deletions(-) create mode 100644 core/http/auth/apikeys.go create mode 100644 core/http/auth/apikeys_test.go create mode 100644 core/http/auth/auth_suite_test.go create mode 100644 core/http/auth/db.go create mode 100644 core/http/auth/db_nosqlite.go create mode 100644 core/http/auth/db_sqlite.go create mode 100644 core/http/auth/db_test.go create mode 100644 core/http/auth/helpers_test.go create mode 100644 core/http/auth/middleware.go create mode 100644 core/http/auth/middleware_test.go create mode 100644 core/http/auth/models.go create mode 100644 core/http/auth/oauth.go create mode 100644 core/http/auth/password.go create mode 100644 core/http/auth/permissions.go create mode 100644 core/http/auth/roles.go create mode 100644 core/http/auth/roles_test.go create mode 100644 core/http/auth/session.go create mode 100644 core/http/auth/session_test.go create mode 100644 core/http/auth/usage.go create mode 100644 core/http/auth/usage_test.go create mode 100644 core/http/middleware/usage.go create mode 100644 core/http/react-ui/src/components/RequireAdmin.jsx create mode 100644 core/http/react-ui/src/components/RequireAuth.jsx create mode 100644 core/http/react-ui/src/components/RequireFeature.jsx create mode 100644 core/http/react-ui/src/components/UserGroupSection.jsx create mode 100644 core/http/react-ui/src/context/AuthContext.jsx create mode 100644 core/http/react-ui/src/hooks/useUserMap.js create mode 100644 core/http/react-ui/src/pages/Account.jsx create mode 100644 core/http/react-ui/src/pages/Usage.jsx create mode 100644 core/http/react-ui/src/pages/Users.jsx create mode 100644 core/http/routes/auth.go create mode 100644 core/http/routes/auth_test.go create mode 100644 core/services/user_services.go create mode 100644 core/services/user_storage.go create mode 100644 docs/content/features/authentication.md diff --git a/Dockerfile b/Dockerfile index 17c783ec3ae7..4318398193b5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -256,7 +256,7 @@ RUN apt-get update && \ FROM build-requirements AS builder-base -ARG GO_TAGS="" +ARG GO_TAGS="auth" ARG GRPC_BACKENDS ARG MAKEFLAGS ARG LD_FLAGS="-s -w" diff --git a/core/application/application.go b/core/application/application.go index f1adc71449ed..c636be38f137 100644 --- a/core/application/application.go +++ b/core/application/application.go @@ -11,6 +11,7 @@ import ( "github.com/mudler/LocalAI/core/templates" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/xlog" + "gorm.io/gorm" ) type Application struct { @@ -22,6 +23,7 @@ type Application struct { galleryService *services.GalleryService agentJobService *services.AgentJobService agentPoolService atomic.Pointer[services.AgentPoolService] + authDB *gorm.DB watchdogMutex sync.Mutex watchdogStop chan bool p2pMutex sync.Mutex @@ -74,6 +76,11 @@ func (a *Application) AgentPoolService() *services.AgentPoolService { return a.agentPoolService.Load() } +// AuthDB returns the auth database connection, or nil if auth is not enabled. +func (a *Application) AuthDB() *gorm.DB { + return a.authDB +} + // StartupConfig returns the original startup configuration (from env vars, before file loading) func (a *Application) StartupConfig() *config.ApplicationConfig { return a.startupConfig @@ -118,9 +125,23 @@ func (a *Application) StartAgentPool() { xlog.Error("Failed to create agent pool service", "error", err) return } + if a.authDB != nil { + aps.SetAuthDB(a.authDB) + } if err := aps.Start(a.applicationConfig.Context); err != nil { xlog.Error("Failed to start agent pool", "error", err) return } + + // Wire per-user scoped services so collections, skills, and jobs are isolated per user + usm := services.NewUserServicesManager( + aps.UserStorage(), + a.applicationConfig, + a.modelLoader, + a.backendLoader, + a.templatesEvaluator, + ) + aps.SetUserServicesManager(usm) + a.agentPoolService.Store(aps) } diff --git a/core/application/startup.go b/core/application/startup.go index 0d69763b486e..76fe7313d8c2 100644 --- a/core/application/startup.go +++ b/core/application/startup.go @@ -10,6 +10,7 @@ import ( "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" + "github.com/mudler/LocalAI/core/http/auth" "github.com/mudler/LocalAI/core/services" coreStartup "github.com/mudler/LocalAI/core/startup" "github.com/mudler/LocalAI/internal" @@ -81,6 +82,32 @@ func New(opts ...config.AppOption) (*Application, error) { } } + // Initialize auth database if auth is enabled + if options.Auth.Enabled { + authDB, err := auth.InitDB(options.Auth.DatabaseURL) + if err != nil { + return nil, fmt.Errorf("failed to initialize auth database: %w", err) + } + application.authDB = authDB + xlog.Info("Auth enabled", "database", options.Auth.DatabaseURL) + + // Start session cleanup goroutine + go func() { + ticker := time.NewTicker(1 * time.Hour) + defer ticker.Stop() + for { + select { + case <-options.Context.Done(): + return + case <-ticker.C: + if err := auth.CleanExpiredSessions(authDB); err != nil { + xlog.Error("failed to clean expired sessions", "error", err) + } + } + } + }() + } + if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil { xlog.Error("error installing models", "error", err) } diff --git a/core/cli/run.go b/core/cli/run.go index 163797ac08aa..ead139a5723a 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -121,6 +121,15 @@ type RunCMD struct { AgentPoolCollectionDBPath string `env:"LOCALAI_AGENT_POOL_COLLECTION_DB_PATH" help:"Database path for agent collections" group:"agents"` AgentHubURL string `env:"LOCALAI_AGENT_HUB_URL" default:"https://agenthub.localai.io" help:"URL for the agent hub where users can browse and download agent configurations" group:"agents"` + // Authentication + AuthEnabled bool `env:"LOCALAI_AUTH" default:"false" help:"Enable user authentication and authorization" group:"auth"` + AuthDatabaseURL string `env:"LOCALAI_AUTH_DATABASE_URL,DATABASE_URL" help:"Database URL for auth (postgres:// or file path for SQLite). Defaults to {DataPath}/database.db" group:"auth"` + GitHubClientID string `env:"GITHUB_CLIENT_ID" help:"GitHub OAuth App Client ID (auto-enables auth when set)" group:"auth"` + GitHubClientSecret string `env:"GITHUB_CLIENT_SECRET" help:"GitHub OAuth App Client Secret" group:"auth"` + AuthBaseURL string `env:"LOCALAI_BASE_URL" help:"Base URL for OAuth callbacks (e.g. http://localhost:8080)" group:"auth"` + AuthAdminEmail string `env:"LOCALAI_ADMIN_EMAIL" help:"Email address to auto-promote to admin role" group:"auth"` + AuthRegistrationMode string `env:"LOCALAI_REGISTRATION_MODE" default:"open" help:"Registration mode: 'open' (default) or 'approval'" group:"auth"` + Version bool } @@ -311,6 +320,32 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error { opts = append(opts, config.WithAgentHubURL(r.AgentHubURL)) } + // Authentication + authEnabled := r.AuthEnabled || r.GitHubClientID != "" + if authEnabled { + opts = append(opts, config.WithAuthEnabled(true)) + + dbURL := r.AuthDatabaseURL + if dbURL == "" { + dbURL = filepath.Join(r.DataPath, "database.db") + } + opts = append(opts, config.WithAuthDatabaseURL(dbURL)) + + if r.GitHubClientID != "" { + opts = append(opts, config.WithAuthGitHubClientID(r.GitHubClientID)) + opts = append(opts, config.WithAuthGitHubClientSecret(r.GitHubClientSecret)) + } + if r.AuthBaseURL != "" { + opts = append(opts, config.WithAuthBaseURL(r.AuthBaseURL)) + } + if r.AuthAdminEmail != "" { + opts = append(opts, config.WithAuthAdminEmail(r.AuthAdminEmail)) + } + if r.AuthRegistrationMode != "" { + opts = append(opts, config.WithAuthRegistrationMode(r.AuthRegistrationMode)) + } + } + if idleWatchDog || busyWatchDog { opts = append(opts, config.EnableWatchDog) if idleWatchDog { diff --git a/core/config/application_config.go b/core/config/application_config.go index 74c3511a6594..357a6e33c93b 100644 --- a/core/config/application_config.go +++ b/core/config/application_config.go @@ -96,6 +96,20 @@ type ApplicationConfig struct { // Agent Pool (LocalAGI integration) AgentPool AgentPoolConfig + + // Authentication & Authorization + Auth AuthConfig +} + +// AuthConfig holds configuration for user authentication and authorization. +type AuthConfig struct { + Enabled bool + DatabaseURL string // "postgres://..." or file path for SQLite + GitHubClientID string + GitHubClientSecret string + BaseURL string // for OAuth callback URLs (e.g. "http://localhost:8080") + AdminEmail string // auto-promote to admin on login + RegistrationMode string // "open" (default), "approval", "invite" } // AgentPoolConfig holds configuration for the LocalAGI agent pool integration. @@ -150,6 +164,8 @@ func NewApplicationConfig(o ...AppOption) *ApplicationConfig { "/favicon.svg", "/readyz", "/healthz", + "/api/auth/", + "/assets/", }, } for _, oo := range o { @@ -711,6 +727,50 @@ func WithAgentHubURL(url string) AppOption { } } +// Auth options + +func WithAuthEnabled(enabled bool) AppOption { + return func(o *ApplicationConfig) { + o.Auth.Enabled = enabled + } +} + +func WithAuthDatabaseURL(url string) AppOption { + return func(o *ApplicationConfig) { + o.Auth.DatabaseURL = url + } +} + +func WithAuthGitHubClientID(clientID string) AppOption { + return func(o *ApplicationConfig) { + o.Auth.GitHubClientID = clientID + } +} + +func WithAuthGitHubClientSecret(clientSecret string) AppOption { + return func(o *ApplicationConfig) { + o.Auth.GitHubClientSecret = clientSecret + } +} + +func WithAuthBaseURL(baseURL string) AppOption { + return func(o *ApplicationConfig) { + o.Auth.BaseURL = baseURL + } +} + +func WithAuthAdminEmail(email string) AppOption { + return func(o *ApplicationConfig) { + o.Auth.AdminEmail = email + } +} + +func WithAuthRegistrationMode(mode string) AppOption { + return func(o *ApplicationConfig) { + o.Auth.RegistrationMode = mode + } +} + // ToConfigLoaderOptions returns a slice of ConfigLoader Option. // Some options defined at the application level are going to be passed as defaults for // all the configuration for the models. diff --git a/core/http/app.go b/core/http/app.go index 138515fb7edc..2403cfe6ceb3 100644 --- a/core/http/app.go +++ b/core/http/app.go @@ -14,6 +14,7 @@ import ( "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" + "github.com/mudler/LocalAI/core/http/auth" "github.com/mudler/LocalAI/core/http/endpoints/localai" httpMiddleware "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/routes" @@ -170,11 +171,9 @@ func API(application *application.Application) (*echo.Echo, error) { // Health Checks should always be exempt from auth, so register these first routes.HealthRoutes(e) - // Get key auth middleware - keyAuthMiddleware, err := httpMiddleware.GetKeyAuthConfig(application.ApplicationConfig()) - if err != nil { - return nil, fmt.Errorf("failed to create key auth config: %w", err) - } + // Build auth middleware: use the new auth.Middleware when auth is enabled or + // as a unified replacement for the legacy key-auth middleware. + authMiddleware := auth.Middleware(application.AuthDB(), application.ApplicationConfig()) // Favicon handler e.GET("/favicon.svg", func(c echo.Context) error { @@ -209,8 +208,14 @@ func API(application *application.Application) (*echo.Echo, error) { e.Static("/generated-videos", videoPath) } - // Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Skipper property of the KeyAuth Configuration - e.Use(keyAuthMiddleware) + // Initialize usage recording when auth DB is available + if application.AuthDB() != nil { + httpMiddleware.InitUsageRecorder(application.AuthDB()) + } + + // Auth is applied to _all_ endpoints. Filtering out endpoints to bypass is + // the role of the exempt-path logic inside the middleware. + e.Use(authMiddleware) // CORS middleware if application.ApplicationConfig().CORS { @@ -229,8 +234,25 @@ func API(application *application.Application) (*echo.Echo, error) { e.Use(middleware.CSRF()) } + // Admin middleware: enforces admin role when auth is enabled, no-op otherwise + var adminMiddleware echo.MiddlewareFunc + if application.AuthDB() != nil { + adminMiddleware = auth.RequireAdmin() + } else { + adminMiddleware = auth.NoopMiddleware() + } + + // Feature middlewares: per-feature access control + agentsMw := auth.RequireFeature(application.AuthDB(), auth.FeatureAgents) + skillsMw := auth.RequireFeature(application.AuthDB(), auth.FeatureSkills) + collectionsMw := auth.RequireFeature(application.AuthDB(), auth.FeatureCollections) + mcpJobsMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMCPJobs) + requestExtractor := httpMiddleware.NewRequestExtractor(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()) + // Register auth routes (login, callback, API keys, user management) + routes.RegisterAuthRoutes(e, application) + routes.RegisterElevenLabsRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()) // Create opcache for tracking UI operations (used by both UI and LocalAI routes) @@ -239,14 +261,14 @@ func API(application *application.Application) (*echo.Echo, error) { opcache = services.NewOpCache(application.GalleryService()) } - routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator(), application) - routes.RegisterAgentPoolRoutes(e, application) + routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator(), application, adminMiddleware, mcpJobsMw) + routes.RegisterAgentPoolRoutes(e, application, agentsMw, skillsMw, collectionsMw) routes.RegisterOpenAIRoutes(e, requestExtractor, application) routes.RegisterAnthropicRoutes(e, requestExtractor, application) routes.RegisterOpenResponsesRoutes(e, requestExtractor, application) if !application.ApplicationConfig().DisableWebUI { - routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application) - routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService()) + routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application, adminMiddleware) + routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), adminMiddleware) // Serve React SPA from / with SPA fallback via 404 handler reactFS, fsErr := fs.Sub(reactUI, "react-ui/dist") diff --git a/core/http/auth/apikeys.go b/core/http/auth/apikeys.go new file mode 100644 index 000000000000..f21878247818 --- /dev/null +++ b/core/http/auth/apikeys.go @@ -0,0 +1,100 @@ +package auth + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "fmt" + "time" + + "github.com/google/uuid" + "gorm.io/gorm" +) + +const ( + apiKeyPrefix = "lai-" + apiKeyRandBytes = 32 // 32 bytes = 64 hex chars + keyPrefixLen = 8 // display prefix length (from the random part) +) + +// GenerateAPIKey generates a new API key. Returns the plaintext key, +// its SHA-256 hash, and a display prefix. +func GenerateAPIKey() (plaintext, hash, prefix string, err error) { + b := make([]byte, apiKeyRandBytes) + if _, err := rand.Read(b); err != nil { + return "", "", "", fmt.Errorf("failed to generate API key: %w", err) + } + + randHex := hex.EncodeToString(b) + plaintext = apiKeyPrefix + randHex + hash = HashAPIKey(plaintext) + prefix = plaintext[:len(apiKeyPrefix)+keyPrefixLen] + + return plaintext, hash, prefix, nil +} + +// HashAPIKey returns the SHA-256 hex digest of the given plaintext key. +func HashAPIKey(plaintext string) string { + h := sha256.Sum256([]byte(plaintext)) + return hex.EncodeToString(h[:]) +} + +// CreateAPIKey generates and stores a new API key for the given user. +// Returns the plaintext key (shown once) and the database record. +func CreateAPIKey(db *gorm.DB, userID, name, role string) (string, *UserAPIKey, error) { + plaintext, hash, prefix, err := GenerateAPIKey() + if err != nil { + return "", nil, err + } + + record := &UserAPIKey{ + ID: uuid.New().String(), + UserID: userID, + Name: name, + KeyHash: hash, + KeyPrefix: prefix, + Role: role, + } + + if err := db.Create(record).Error; err != nil { + return "", nil, fmt.Errorf("failed to store API key: %w", err) + } + + return plaintext, record, nil +} + +// ValidateAPIKey looks up an API key by hashing the plaintext and searching +// the database. Returns the key record if found, or an error. +// Updates LastUsed on successful validation. +func ValidateAPIKey(db *gorm.DB, plaintext string) (*UserAPIKey, error) { + hash := HashAPIKey(plaintext) + + var key UserAPIKey + if err := db.Preload("User").Where("key_hash = ?", hash).First(&key).Error; err != nil { + return nil, fmt.Errorf("invalid API key") + } + + // Update LastUsed + now := time.Now() + db.Model(&key).Update("last_used", now) + + return &key, nil +} + +// ListAPIKeys returns all API keys for the given user (without plaintext). +func ListAPIKeys(db *gorm.DB, userID string) ([]UserAPIKey, error) { + var keys []UserAPIKey + if err := db.Where("user_id = ?", userID).Order("created_at DESC").Find(&keys).Error; err != nil { + return nil, err + } + return keys, nil +} + +// RevokeAPIKey deletes an API key. Only the owner can revoke their own key. +func RevokeAPIKey(db *gorm.DB, keyID, userID string) error { + result := db.Where("id = ? AND user_id = ?", keyID, userID).Delete(&UserAPIKey{}) + if result.RowsAffected == 0 { + return fmt.Errorf("API key not found or not owned by user") + } + return result.Error +} diff --git a/core/http/auth/apikeys_test.go b/core/http/auth/apikeys_test.go new file mode 100644 index 000000000000..3d1a72dfb216 --- /dev/null +++ b/core/http/auth/apikeys_test.go @@ -0,0 +1,162 @@ +//go:build auth + +package auth_test + +import ( + "strings" + + "github.com/mudler/LocalAI/core/http/auth" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "gorm.io/gorm" +) + +var _ = Describe("API Keys", func() { + var ( + db *gorm.DB + user *auth.User + ) + + BeforeEach(func() { + db = testDB() + user = createTestUser(db, "apikey@example.com", auth.RoleUser, "github") + }) + + Describe("GenerateAPIKey", func() { + It("returns key with 'lai-' prefix", func() { + plaintext, _, _, err := auth.GenerateAPIKey() + Expect(err).ToNot(HaveOccurred()) + Expect(plaintext).To(HavePrefix("lai-")) + }) + + It("returns consistent hash for same plaintext", func() { + plaintext, hash, _, err := auth.GenerateAPIKey() + Expect(err).ToNot(HaveOccurred()) + Expect(auth.HashAPIKey(plaintext)).To(Equal(hash)) + }) + + It("returns prefix for display", func() { + _, _, prefix, err := auth.GenerateAPIKey() + Expect(err).ToNot(HaveOccurred()) + Expect(prefix).To(HavePrefix("lai-")) + Expect(len(prefix)).To(Equal(12)) // "lai-" + 8 chars + }) + + It("generates unique keys", func() { + key1, _, _, _ := auth.GenerateAPIKey() + key2, _, _, _ := auth.GenerateAPIKey() + Expect(key1).ToNot(Equal(key2)) + }) + }) + + Describe("CreateAPIKey", func() { + It("stores hashed key in DB", func() { + plaintext, record, err := auth.CreateAPIKey(db, user.ID, "test key", auth.RoleUser) + Expect(err).ToNot(HaveOccurred()) + Expect(plaintext).To(HavePrefix("lai-")) + Expect(record.KeyHash).To(Equal(auth.HashAPIKey(plaintext))) + }) + + It("does not store plaintext in DB", func() { + plaintext, _, err := auth.CreateAPIKey(db, user.ID, "test key", auth.RoleUser) + Expect(err).ToNot(HaveOccurred()) + + var keys []auth.UserAPIKey + db.Find(&keys) + for _, k := range keys { + Expect(k.KeyHash).ToNot(Equal(plaintext)) + Expect(strings.Contains(k.KeyHash, "lai-")).To(BeFalse()) + } + }) + + It("inherits role from parameter", func() { + _, record, err := auth.CreateAPIKey(db, user.ID, "admin key", auth.RoleAdmin) + Expect(err).ToNot(HaveOccurred()) + Expect(record.Role).To(Equal(auth.RoleAdmin)) + }) + }) + + Describe("ValidateAPIKey", func() { + It("returns UserAPIKey for valid key", func() { + plaintext, _, err := auth.CreateAPIKey(db, user.ID, "valid key", auth.RoleUser) + Expect(err).ToNot(HaveOccurred()) + + found, err := auth.ValidateAPIKey(db, plaintext) + Expect(err).ToNot(HaveOccurred()) + Expect(found).ToNot(BeNil()) + Expect(found.UserID).To(Equal(user.ID)) + }) + + It("returns error for invalid key", func() { + _, err := auth.ValidateAPIKey(db, "lai-invalidkey12345678901234567890") + Expect(err).To(HaveOccurred()) + }) + + It("updates LastUsed timestamp", func() { + plaintext, record, err := auth.CreateAPIKey(db, user.ID, "used key", auth.RoleUser) + Expect(err).ToNot(HaveOccurred()) + Expect(record.LastUsed).To(BeNil()) + + _, err = auth.ValidateAPIKey(db, plaintext) + Expect(err).ToNot(HaveOccurred()) + + var updated auth.UserAPIKey + db.First(&updated, "id = ?", record.ID) + Expect(updated.LastUsed).ToNot(BeNil()) + }) + + It("loads associated user", func() { + plaintext, _, err := auth.CreateAPIKey(db, user.ID, "with user", auth.RoleUser) + Expect(err).ToNot(HaveOccurred()) + + found, err := auth.ValidateAPIKey(db, plaintext) + Expect(err).ToNot(HaveOccurred()) + Expect(found.User.ID).To(Equal(user.ID)) + Expect(found.User.Email).To(Equal("apikey@example.com")) + }) + }) + + Describe("ListAPIKeys", func() { + It("returns all keys for the user", func() { + auth.CreateAPIKey(db, user.ID, "key1", auth.RoleUser) + auth.CreateAPIKey(db, user.ID, "key2", auth.RoleUser) + + keys, err := auth.ListAPIKeys(db, user.ID) + Expect(err).ToNot(HaveOccurred()) + Expect(keys).To(HaveLen(2)) + }) + + It("does not return other users' keys", func() { + other := createTestUser(db, "other@example.com", auth.RoleUser, "github") + auth.CreateAPIKey(db, user.ID, "my key", auth.RoleUser) + auth.CreateAPIKey(db, other.ID, "other key", auth.RoleUser) + + keys, err := auth.ListAPIKeys(db, user.ID) + Expect(err).ToNot(HaveOccurred()) + Expect(keys).To(HaveLen(1)) + Expect(keys[0].Name).To(Equal("my key")) + }) + }) + + Describe("RevokeAPIKey", func() { + It("deletes the key record", func() { + plaintext, record, err := auth.CreateAPIKey(db, user.ID, "to revoke", auth.RoleUser) + Expect(err).ToNot(HaveOccurred()) + + err = auth.RevokeAPIKey(db, record.ID, user.ID) + Expect(err).ToNot(HaveOccurred()) + + _, err = auth.ValidateAPIKey(db, plaintext) + Expect(err).To(HaveOccurred()) + }) + + It("only allows owner to revoke their own key", func() { + _, record, err := auth.CreateAPIKey(db, user.ID, "mine", auth.RoleUser) + Expect(err).ToNot(HaveOccurred()) + + other := createTestUser(db, "attacker@example.com", auth.RoleUser, "github") + err = auth.RevokeAPIKey(db, record.ID, other.ID) + Expect(err).To(HaveOccurred()) + }) + }) +}) diff --git a/core/http/auth/auth_suite_test.go b/core/http/auth/auth_suite_test.go new file mode 100644 index 000000000000..c32c18ed19df --- /dev/null +++ b/core/http/auth/auth_suite_test.go @@ -0,0 +1,15 @@ +//go:build auth + +package auth_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestAuth(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Auth Suite") +} diff --git a/core/http/auth/db.go b/core/http/auth/db.go new file mode 100644 index 000000000000..f3b2f0d3866f --- /dev/null +++ b/core/http/auth/db.go @@ -0,0 +1,49 @@ +package auth + +import ( + "fmt" + "strings" + + "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +// InitDB initializes the auth database. If databaseURL starts with "postgres://" +// or "postgresql://", it connects to PostgreSQL; otherwise it treats the value +// as a SQLite file path (use ":memory:" for in-memory). +// SQLite support requires building with the "auth" build tag (CGO). +func InitDB(databaseURL string) (*gorm.DB, error) { + var dialector gorm.Dialector + + if strings.HasPrefix(databaseURL, "postgres://") || strings.HasPrefix(databaseURL, "postgresql://") { + dialector = postgres.Open(databaseURL) + } else { + d, err := openSQLiteDialector(databaseURL) + if err != nil { + return nil, err + } + dialector = d + } + + db, err := gorm.Open(dialector, &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + if err != nil { + return nil, fmt.Errorf("failed to open auth database: %w", err) + } + + if err := db.AutoMigrate(&User{}, &Session{}, &UserAPIKey{}, &UsageRecord{}, &UserPermission{}, &InviteCode{}); err != nil { + return nil, fmt.Errorf("failed to migrate auth tables: %w", err) + } + + // Create composite index on users(provider, subject) for fast OAuth lookups + if err := db.Exec("CREATE INDEX IF NOT EXISTS idx_users_provider_subject ON users(provider, subject)").Error; err != nil { + // Ignore error on postgres if index already exists + if !strings.Contains(err.Error(), "already exists") { + return nil, fmt.Errorf("failed to create composite index: %w", err) + } + } + + return db, nil +} diff --git a/core/http/auth/db_nosqlite.go b/core/http/auth/db_nosqlite.go new file mode 100644 index 000000000000..73233bf4543c --- /dev/null +++ b/core/http/auth/db_nosqlite.go @@ -0,0 +1,13 @@ +//go:build !auth + +package auth + +import ( + "fmt" + + "gorm.io/gorm" +) + +func openSQLiteDialector(path string) (gorm.Dialector, error) { + return nil, fmt.Errorf("SQLite auth database requires building with -tags auth (CGO); use DATABASE_URL with PostgreSQL instead") +} diff --git a/core/http/auth/db_sqlite.go b/core/http/auth/db_sqlite.go new file mode 100644 index 000000000000..5c13ecf05cc4 --- /dev/null +++ b/core/http/auth/db_sqlite.go @@ -0,0 +1,12 @@ +//go:build auth + +package auth + +import ( + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +func openSQLiteDialector(path string) (gorm.Dialector, error) { + return sqlite.Open(path), nil +} diff --git a/core/http/auth/db_test.go b/core/http/auth/db_test.go new file mode 100644 index 000000000000..234921c8adb0 --- /dev/null +++ b/core/http/auth/db_test.go @@ -0,0 +1,53 @@ +//go:build auth + +package auth_test + +import ( + "github.com/mudler/LocalAI/core/http/auth" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("InitDB", func() { + Context("SQLite", func() { + It("creates all tables with in-memory SQLite", func() { + db, err := auth.InitDB(":memory:") + Expect(err).ToNot(HaveOccurred()) + Expect(db).ToNot(BeNil()) + + // Verify tables exist + Expect(db.Migrator().HasTable(&auth.User{})).To(BeTrue()) + Expect(db.Migrator().HasTable(&auth.Session{})).To(BeTrue()) + Expect(db.Migrator().HasTable(&auth.UserAPIKey{})).To(BeTrue()) + }) + + It("is idempotent - running twice does not error", func() { + db, err := auth.InitDB(":memory:") + Expect(err).ToNot(HaveOccurred()) + + // Re-migrate on same DB should succeed + err = db.AutoMigrate(&auth.User{}, &auth.Session{}, &auth.UserAPIKey{}) + Expect(err).ToNot(HaveOccurred()) + }) + + It("creates composite index on users(provider, subject)", func() { + db, err := auth.InitDB(":memory:") + Expect(err).ToNot(HaveOccurred()) + + // Insert a user to verify the index doesn't prevent normal operations + user := &auth.User{ + ID: "test-1", + Provider: "github", + Subject: "12345", + Role: "admin", + Status: "active", + } + Expect(db.Create(user).Error).ToNot(HaveOccurred()) + + // Query using the indexed columns should work + var found auth.User + Expect(db.Where("provider = ? AND subject = ?", "github", "12345").First(&found).Error).ToNot(HaveOccurred()) + Expect(found.ID).To(Equal("test-1")) + }) + }) +}) diff --git a/core/http/auth/helpers_test.go b/core/http/auth/helpers_test.go new file mode 100644 index 000000000000..a55342a269c2 --- /dev/null +++ b/core/http/auth/helpers_test.go @@ -0,0 +1,155 @@ +//go:build auth + +package auth_test + +import ( + "net/http" + "net/http/httptest" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/auth" + . "github.com/onsi/gomega" + "gorm.io/gorm" +) + +// testDB creates an in-memory SQLite GORM instance with auto-migration. +func testDB() *gorm.DB { + db, err := auth.InitDB(":memory:") + Expect(err).ToNot(HaveOccurred()) + return db +} + +// createTestUser inserts a user directly into the DB for test setup. +func createTestUser(db *gorm.DB, email, role, provider string) *auth.User { + user := &auth.User{ + ID: generateTestID(), + Email: email, + Name: "Test User", + Provider: provider, + Subject: generateTestID(), + Role: role, + Status: "active", + } + err := db.Create(user).Error + Expect(err).ToNot(HaveOccurred()) + return user +} + +// createTestSession creates a session for a user, returns session ID. +func createTestSession(db *gorm.DB, userID string) string { + sessionID, err := auth.CreateSession(db, userID) + Expect(err).ToNot(HaveOccurred()) + return sessionID +} + +var testIDCounter int + +func generateTestID() string { + testIDCounter++ + return "test-id-" + string(rune('a'+testIDCounter)) +} + +// ok is a simple handler that returns 200 OK. +func ok(c echo.Context) error { + return c.String(http.StatusOK, "ok") +} + +// newAuthTestApp creates a minimal Echo app with the new auth middleware. +func newAuthTestApp(db *gorm.DB, appConfig *config.ApplicationConfig) *echo.Echo { + e := echo.New() + e.Use(auth.Middleware(db, appConfig)) + + // API routes (require auth) + e.GET("/v1/models", ok) + e.POST("/v1/chat/completions", ok) + e.GET("/api/settings", ok) + e.POST("/api/settings", ok) + + // Auth routes (exempt) + e.GET("/api/auth/status", ok) + e.GET("/api/auth/github/login", ok) + + // Static routes + e.GET("/app", ok) + e.GET("/app/*", ok) + + return e +} + +// newAdminTestApp creates an Echo app with admin-protected routes. +func newAdminTestApp(db *gorm.DB, appConfig *config.ApplicationConfig) *echo.Echo { + e := echo.New() + e.Use(auth.Middleware(db, appConfig)) + + // Regular routes + e.GET("/v1/models", ok) + e.POST("/v1/chat/completions", ok) + + // Admin-only routes + adminMw := auth.RequireAdmin() + e.POST("/api/settings", ok, adminMw) + e.POST("/models/apply", ok, adminMw) + e.POST("/backends/apply", ok, adminMw) + e.GET("/api/agents", ok, adminMw) + + // Trace/log endpoints (admin only) + e.GET("/api/traces", ok, adminMw) + e.POST("/api/traces/clear", ok, adminMw) + e.GET("/api/backend-logs", ok, adminMw) + e.GET("/api/backend-logs/:modelId", ok, adminMw) + + // Gallery/management reads (admin only) + e.GET("/api/operations", ok, adminMw) + e.GET("/api/models", ok, adminMw) + e.GET("/api/backends", ok, adminMw) + e.GET("/api/resources", ok, adminMw) + e.GET("/api/p2p/workers", ok, adminMw) + + // Agent task/job routes (admin only) + e.POST("/api/agent/tasks", ok, adminMw) + e.GET("/api/agent/tasks", ok, adminMw) + e.GET("/api/agent/jobs", ok, adminMw) + + // System info (admin only) + e.GET("/system", ok, adminMw) + e.GET("/backend/monitor", ok, adminMw) + + return e +} + +// doRequest performs an HTTP request against the given Echo app and returns the recorder. +func doRequest(e *echo.Echo, method, path string, opts ...func(*http.Request)) *httptest.ResponseRecorder { + req := httptest.NewRequest(method, path, nil) + req.Header.Set("Content-Type", "application/json") + for _, opt := range opts { + opt(req) + } + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + return rec +} + +func withBearerToken(token string) func(*http.Request) { + return func(req *http.Request) { + req.Header.Set("Authorization", "Bearer "+token) + } +} + +func withXApiKey(key string) func(*http.Request) { + return func(req *http.Request) { + req.Header.Set("x-api-key", key) + } +} + +func withSessionCookie(sessionID string) func(*http.Request) { + return func(req *http.Request) { + req.AddCookie(&http.Cookie{Name: "session", Value: sessionID}) + } +} + +func withTokenCookie(token string) func(*http.Request) { + return func(req *http.Request) { + req.AddCookie(&http.Cookie{Name: "token", Value: token}) + } +} diff --git a/core/http/auth/middleware.go b/core/http/auth/middleware.go new file mode 100644 index 000000000000..cb1b414568c8 --- /dev/null +++ b/core/http/auth/middleware.go @@ -0,0 +1,335 @@ +package auth + +import ( + "net/http" + "strings" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/schema" + "gorm.io/gorm" +) + +const ( + contextKeyUser = "auth_user" + contextKeyRole = "auth_role" +) + +// Middleware returns an Echo middleware that handles authentication. +// +// Resolution order: +// 1. If auth not enabled AND no legacy API keys → pass through +// 2. Skip auth for exempt paths (PathWithoutAuth + /api/auth/) +// 3. If auth enabled (db != nil): +// a. Try "session" cookie → DB lookup +// b. Try Authorization: Bearer → session ID, then user API key +// c. Try x-api-key / xi-api-key → user API key +// d. Try "token" cookie → legacy API key check +// e. Check all extracted keys against legacy ApiKeys → synthetic admin +// 4. If auth not enabled → delegate to legacy API key validation +// 5. If no auth found for /api/ or /v1/ paths → 401 +// 6. Otherwise pass through (static assets, UI pages, etc.) +func Middleware(db *gorm.DB, appConfig *config.ApplicationConfig) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + authEnabled := db != nil + hasLegacyKeys := len(appConfig.ApiKeys) > 0 + + // 1. No auth at all + if !authEnabled && !hasLegacyKeys { + return next(c) + } + + path := c.Request().URL.Path + exempt := isExemptPath(path, appConfig) + authenticated := false + + // 2. Try to authenticate (populates user in context if possible) + if authEnabled { + user := tryAuthenticate(c, db, appConfig) + if user != nil { + c.Set(contextKeyUser, user) + c.Set(contextKeyRole, user.Role) + authenticated = true + } + } + + // 3. Legacy API key validation (works whether auth is enabled or not) + if !authenticated && hasLegacyKeys { + key := extractKey(c) + if key != "" && isValidLegacyKey(key, appConfig) { + syntheticUser := &User{ + ID: "legacy-api-key", + Name: "API Key User", + Role: RoleAdmin, + } + c.Set(contextKeyUser, syntheticUser) + c.Set(contextKeyRole, RoleAdmin) + authenticated = true + } + } + + // 4. If authenticated or exempt path, proceed + if authenticated || exempt { + return next(c) + } + + // 5. Require auth for API paths + if isAPIPath(path) { + return authError(c, appConfig) + } + + // 6. Pass through for non-API paths when auth is DB-based + // (the React UI will redirect to login as needed) + if authEnabled && !hasLegacyKeys { + return next(c) + } + + // 7. Legacy behavior: if API keys are set, all paths require auth + if hasLegacyKeys { + // Check GET exemptions + if appConfig.DisableApiKeyRequirementForHttpGet && c.Request().Method == http.MethodGet { + for _, rx := range appConfig.HttpGetExemptedEndpoints { + if rx.MatchString(c.Path()) { + return next(c) + } + } + } + return authError(c, appConfig) + } + + return next(c) + } + } +} + +// RequireAdmin returns middleware that checks the user has admin role. +func RequireAdmin() echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + user := GetUser(c) + if user == nil { + return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{ + Error: &schema.APIError{ + Message: "Authentication required", + Code: http.StatusUnauthorized, + Type: "authentication_error", + }, + }) + } + if user.Role != RoleAdmin { + return c.JSON(http.StatusForbidden, schema.ErrorResponse{ + Error: &schema.APIError{ + Message: "Admin access required", + Code: http.StatusForbidden, + Type: "authorization_error", + }, + }) + } + return next(c) + } + } +} + +// NoopMiddleware returns a middleware that does nothing (pass-through). +// Used when auth is disabled to satisfy route registration that expects +// an admin middleware parameter. +func NoopMiddleware() echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return next + } +} + +// RequireFeature returns middleware that checks the user has access to the given feature. +// If no auth DB is provided, it passes through (backward compat). +// Admins always pass. Regular users must have the feature enabled in their permissions. +func RequireFeature(db *gorm.DB, feature string) echo.MiddlewareFunc { + if db == nil { + return NoopMiddleware() + } + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + user := GetUser(c) + if user == nil { + return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{ + Error: &schema.APIError{ + Message: "Authentication required", + Code: http.StatusUnauthorized, + Type: "authentication_error", + }, + }) + } + if user.Role == RoleAdmin { + return next(c) + } + if !HasFeatureAccess(db, user, feature) { + return c.JSON(http.StatusForbidden, schema.ErrorResponse{ + Error: &schema.APIError{ + Message: "feature not enabled for your account", + Code: http.StatusForbidden, + Type: "authorization_error", + }, + }) + } + return next(c) + } + } +} + +// GetUser returns the authenticated user from the echo context, or nil. +func GetUser(c echo.Context) *User { + u, ok := c.Get(contextKeyUser).(*User) + if !ok { + return nil + } + return u +} + +// GetUserRole returns the role of the authenticated user, or empty string. +func GetUserRole(c echo.Context) string { + role, _ := c.Get(contextKeyRole).(string) + return role +} + +// tryAuthenticate attempts to authenticate the request using the database. +func tryAuthenticate(c echo.Context, db *gorm.DB, appConfig *config.ApplicationConfig) *User { + // a. Session cookie + if cookie, err := c.Cookie(sessionCookie); err == nil && cookie.Value != "" { + if user := ValidateSession(db, cookie.Value); user != nil { + return user + } + } + + // b. Authorization: Bearer token + auth := c.Request().Header.Get("Authorization") + if strings.HasPrefix(auth, "Bearer ") { + token := strings.TrimPrefix(auth, "Bearer ") + + // Try as session ID first + if user := ValidateSession(db, token); user != nil { + return user + } + + // Try as user API key + if key, err := ValidateAPIKey(db, token); err == nil { + return &key.User + } + } + + // c. x-api-key / xi-api-key headers + for _, header := range []string{"x-api-key", "xi-api-key"} { + if key := c.Request().Header.Get(header); key != "" { + if apiKey, err := ValidateAPIKey(db, key); err == nil { + return &apiKey.User + } + } + } + + // d. token cookie (legacy) + if cookie, err := c.Cookie("token"); err == nil && cookie.Value != "" { + // Try as user API key + if key, err := ValidateAPIKey(db, cookie.Value); err == nil { + return &key.User + } + } + + return nil +} + +// extractKey extracts an API key from the request (all sources). +func extractKey(c echo.Context) string { + // Authorization header + auth := c.Request().Header.Get("Authorization") + if strings.HasPrefix(auth, "Bearer ") { + return strings.TrimPrefix(auth, "Bearer ") + } + if auth != "" { + return auth + } + + // x-api-key + if key := c.Request().Header.Get("x-api-key"); key != "" { + return key + } + + // xi-api-key + if key := c.Request().Header.Get("xi-api-key"); key != "" { + return key + } + + // token cookie + if cookie, err := c.Cookie("token"); err == nil && cookie.Value != "" { + return cookie.Value + } + + return "" +} + +// isValidLegacyKey checks if the key matches any configured API key. +func isValidLegacyKey(key string, appConfig *config.ApplicationConfig) bool { + for _, validKey := range appConfig.ApiKeys { + if key == validKey { + return true + } + } + return false +} + +// isExemptPath returns true if the path should skip authentication. +func isExemptPath(path string, appConfig *config.ApplicationConfig) bool { + // Auth endpoints are always public + if strings.HasPrefix(path, "/api/auth/") { + return true + } + + // Check configured exempt paths + for _, p := range appConfig.PathWithoutAuth { + if strings.HasPrefix(path, p) { + return true + } + } + + return false +} + +// isAPIPath returns true for paths that always require authentication. +func isAPIPath(path string) bool { + return strings.HasPrefix(path, "/api/") || + strings.HasPrefix(path, "/v1/") || + strings.HasPrefix(path, "/models/") || + strings.HasPrefix(path, "/backends/") || + strings.HasPrefix(path, "/backend/") || + strings.HasPrefix(path, "/tts") || + strings.HasPrefix(path, "/vad") || + strings.HasPrefix(path, "/video") || + strings.HasPrefix(path, "/stores/") || + strings.HasPrefix(path, "/system") +} + +// authError returns an appropriate error response. +func authError(c echo.Context, appConfig *config.ApplicationConfig) error { + c.Response().Header().Set("WWW-Authenticate", "Bearer") + + if appConfig.OpaqueErrors { + return c.NoContent(http.StatusUnauthorized) + } + + contentType := c.Request().Header.Get("Content-Type") + if strings.Contains(contentType, "application/json") { + return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{ + Error: &schema.APIError{ + Message: "An authentication key is required", + Code: http.StatusUnauthorized, + Type: "invalid_request_error", + }, + }) + } + + return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{ + Error: &schema.APIError{ + Message: "An authentication key is required", + Code: http.StatusUnauthorized, + Type: "invalid_request_error", + }, + }) +} diff --git a/core/http/auth/middleware_test.go b/core/http/auth/middleware_test.go new file mode 100644 index 000000000000..102e27ca7285 --- /dev/null +++ b/core/http/auth/middleware_test.go @@ -0,0 +1,305 @@ +//go:build auth + +package auth_test + +import ( + "net/http" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/auth" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "gorm.io/gorm" +) + +var _ = Describe("Auth Middleware", func() { + + Context("auth disabled, no API keys", func() { + var app *echo.Echo + + BeforeEach(func() { + appConfig := config.NewApplicationConfig() + app = newAuthTestApp(nil, appConfig) + }) + + It("passes through all requests", func() { + rec := doRequest(app, http.MethodGet, "/v1/models") + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("passes through POST requests", func() { + rec := doRequest(app, http.MethodPost, "/v1/chat/completions") + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + }) + + Context("auth disabled, API keys configured", func() { + var app *echo.Echo + const validKey = "sk-test-key-123" + + BeforeEach(func() { + appConfig := config.NewApplicationConfig() + appConfig.ApiKeys = []string{validKey} + app = newAuthTestApp(nil, appConfig) + }) + + It("returns 401 for request without key", func() { + rec := doRequest(app, http.MethodGet, "/v1/models") + Expect(rec.Code).To(Equal(http.StatusUnauthorized)) + }) + + It("passes with valid Bearer token", func() { + rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken(validKey)) + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("passes with valid x-api-key header", func() { + rec := doRequest(app, http.MethodGet, "/v1/models", withXApiKey(validKey)) + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("passes with valid token cookie", func() { + rec := doRequest(app, http.MethodGet, "/v1/models", withTokenCookie(validKey)) + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("returns 401 for invalid key", func() { + rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken("wrong-key")) + Expect(rec.Code).To(Equal(http.StatusUnauthorized)) + }) + }) + + Context("auth enabled with database", func() { + var ( + db *gorm.DB + app *echo.Echo + appConfig *config.ApplicationConfig + user *auth.User + ) + + BeforeEach(func() { + db = testDB() + appConfig = config.NewApplicationConfig() + app = newAuthTestApp(db, appConfig) + user = createTestUser(db, "user@example.com", auth.RoleUser, "github") + }) + + It("allows requests with valid session cookie", func() { + sessionID := createTestSession(db, user.ID) + rec := doRequest(app, http.MethodGet, "/v1/models", withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("allows requests with valid session as Bearer token", func() { + sessionID := createTestSession(db, user.ID) + rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken(sessionID)) + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("allows requests with valid user API key as Bearer token", func() { + plaintext, _, err := auth.CreateAPIKey(db, user.ID, "test", auth.RoleUser) + Expect(err).ToNot(HaveOccurred()) + + rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken(plaintext)) + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("allows requests with legacy API_KEY as admin bypass", func() { + appConfig.ApiKeys = []string{"legacy-key-123"} + app = newAuthTestApp(db, appConfig) + + rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken("legacy-key-123")) + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("returns 401 for expired session", func() { + sessionID := createTestSession(db, user.ID) + // Manually expire + db.Model(&auth.Session{}).Where("id = ?", sessionID). + Update("expires_at", "2020-01-01") + + rec := doRequest(app, http.MethodGet, "/v1/models", withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusUnauthorized)) + }) + + It("returns 401 for invalid session ID", func() { + rec := doRequest(app, http.MethodGet, "/v1/models", withSessionCookie("invalid-session-id")) + Expect(rec.Code).To(Equal(http.StatusUnauthorized)) + }) + + It("returns 401 for revoked API key", func() { + plaintext, record, err := auth.CreateAPIKey(db, user.ID, "to revoke", auth.RoleUser) + Expect(err).ToNot(HaveOccurred()) + + err = auth.RevokeAPIKey(db, record.ID, user.ID) + Expect(err).ToNot(HaveOccurred()) + + rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken(plaintext)) + Expect(rec.Code).To(Equal(http.StatusUnauthorized)) + }) + + It("skips auth for /api/auth/* paths", func() { + rec := doRequest(app, http.MethodGet, "/api/auth/status") + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("skips auth for PathWithoutAuth paths", func() { + rec := doRequest(app, http.MethodGet, "/healthz") + // healthz is not registered in our test app, so it'll be 404/405 but NOT 401 + Expect(rec.Code).ToNot(Equal(http.StatusUnauthorized)) + }) + + It("returns 401 for unauthenticated API requests", func() { + rec := doRequest(app, http.MethodGet, "/v1/models") + Expect(rec.Code).To(Equal(http.StatusUnauthorized)) + }) + + It("allows unauthenticated access to non-API paths when no legacy keys", func() { + rec := doRequest(app, http.MethodGet, "/app") + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + }) + + Describe("RequireAdmin", func() { + var ( + db *gorm.DB + appConfig *config.ApplicationConfig + ) + + BeforeEach(func() { + db = testDB() + appConfig = config.NewApplicationConfig() + }) + + It("passes for admin user", func() { + admin := createTestUser(db, "admin@example.com", auth.RoleAdmin, "github") + sessionID := createTestSession(db, admin.ID) + app := newAdminTestApp(db, appConfig) + + rec := doRequest(app, http.MethodPost, "/api/settings", withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("returns 403 for user role", func() { + user := createTestUser(db, "user@example.com", auth.RoleUser, "github") + sessionID := createTestSession(db, user.ID) + app := newAdminTestApp(db, appConfig) + + rec := doRequest(app, http.MethodPost, "/api/settings", withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusForbidden)) + }) + + It("returns 401 when no user in context", func() { + app := newAdminTestApp(db, appConfig) + + rec := doRequest(app, http.MethodPost, "/api/settings") + Expect(rec.Code).To(Equal(http.StatusUnauthorized)) + }) + + It("allows admin to access model management", func() { + admin := createTestUser(db, "admin@example.com", auth.RoleAdmin, "github") + sessionID := createTestSession(db, admin.ID) + app := newAdminTestApp(db, appConfig) + + rec := doRequest(app, http.MethodPost, "/models/apply", withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("blocks user from model management", func() { + user := createTestUser(db, "user@example.com", auth.RoleUser, "github") + sessionID := createTestSession(db, user.ID) + app := newAdminTestApp(db, appConfig) + + rec := doRequest(app, http.MethodPost, "/models/apply", withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusForbidden)) + }) + + It("allows user to access regular inference endpoints", func() { + user := createTestUser(db, "user@example.com", auth.RoleUser, "github") + sessionID := createTestSession(db, user.ID) + app := newAdminTestApp(db, appConfig) + + rec := doRequest(app, http.MethodPost, "/v1/chat/completions", withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("allows legacy API key (admin bypass) on admin routes", func() { + appConfig.ApiKeys = []string{"admin-key"} + app := newAdminTestApp(db, appConfig) + + rec := doRequest(app, http.MethodPost, "/api/settings", withBearerToken("admin-key")) + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("allows admin to access trace endpoints", func() { + admin := createTestUser(db, "admin2@example.com", auth.RoleAdmin, "github") + sessionID := createTestSession(db, admin.ID) + app := newAdminTestApp(db, appConfig) + + rec := doRequest(app, http.MethodGet, "/api/traces", withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusOK)) + + rec = doRequest(app, http.MethodGet, "/api/backend-logs", withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("blocks non-admin from trace endpoints", func() { + user := createTestUser(db, "user2@example.com", auth.RoleUser, "github") + sessionID := createTestSession(db, user.ID) + app := newAdminTestApp(db, appConfig) + + rec := doRequest(app, http.MethodGet, "/api/traces", withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusForbidden)) + + rec = doRequest(app, http.MethodGet, "/api/backend-logs", withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusForbidden)) + }) + + It("allows admin to access agent job endpoints", func() { + admin := createTestUser(db, "admin3@example.com", auth.RoleAdmin, "github") + sessionID := createTestSession(db, admin.ID) + app := newAdminTestApp(db, appConfig) + + rec := doRequest(app, http.MethodGet, "/api/agent/tasks", withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusOK)) + + rec = doRequest(app, http.MethodGet, "/api/agent/jobs", withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("blocks non-admin from agent job endpoints", func() { + user := createTestUser(db, "user3@example.com", auth.RoleUser, "github") + sessionID := createTestSession(db, user.ID) + app := newAdminTestApp(db, appConfig) + + rec := doRequest(app, http.MethodGet, "/api/agent/tasks", withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusForbidden)) + + rec = doRequest(app, http.MethodGet, "/api/agent/jobs", withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusForbidden)) + }) + + It("blocks non-admin from system/management endpoints", func() { + user := createTestUser(db, "user4@example.com", auth.RoleUser, "github") + sessionID := createTestSession(db, user.ID) + app := newAdminTestApp(db, appConfig) + + for _, path := range []string{"/api/operations", "/api/models", "/api/backends", "/api/resources", "/api/p2p/workers", "/system", "/backend/monitor"} { + rec := doRequest(app, http.MethodGet, path, withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusForbidden), "expected 403 for path: "+path) + } + }) + + It("allows admin to access system/management endpoints", func() { + admin := createTestUser(db, "admin4@example.com", auth.RoleAdmin, "github") + sessionID := createTestSession(db, admin.ID) + app := newAdminTestApp(db, appConfig) + + for _, path := range []string{"/api/operations", "/api/models", "/api/backends", "/api/resources", "/api/p2p/workers", "/system", "/backend/monitor"} { + rec := doRequest(app, http.MethodGet, path, withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusOK), "expected 200 for path: "+path) + } + }) + }) +}) diff --git a/core/http/auth/models.go b/core/http/auth/models.go new file mode 100644 index 000000000000..539a560f59f7 --- /dev/null +++ b/core/http/auth/models.go @@ -0,0 +1,103 @@ +package auth + +import ( + "database/sql/driver" + "encoding/json" + "fmt" + "time" +) + +// User represents an authenticated user. +type User struct { + ID string `gorm:"primaryKey;size:36"` + Email string `gorm:"size:255;index"` + Name string `gorm:"size:255"` + AvatarURL string `gorm:"size:512"` + Provider string `gorm:"size:50"` // "github", "oidc" + Subject string `gorm:"size:255"` // provider-specific user ID + PasswordHash string `json:"-"` // bcrypt hash, empty for OAuth-only users + Role string `gorm:"size:20;default:user"` + Status string `gorm:"size:20;default:active"` // "active", "pending" + CreatedAt time.Time + UpdatedAt time.Time +} + +// Session represents a user login session. +type Session struct { + ID string `gorm:"primaryKey;size:64"` // 64-char hex token + UserID string `gorm:"size:36;index"` + ExpiresAt time.Time + CreatedAt time.Time + User User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"` +} + +// UserAPIKey represents a user-generated API key for programmatic access. +type UserAPIKey struct { + ID string `gorm:"primaryKey;size:36"` + UserID string `gorm:"size:36;index"` + Name string `gorm:"size:255"` // user-provided label + KeyHash string `gorm:"size:64;uniqueIndex"` + KeyPrefix string `gorm:"size:12"` // first 8 chars of key for display + Role string `gorm:"size:20"` + CreatedAt time.Time + LastUsed *time.Time + User User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"` +} + +// PermissionMap is a flexible map of feature -> enabled, stored as JSON text. +// Known features: "agents", "skills", "collections", "mcp_jobs". +// New features can be added without schema changes. +type PermissionMap map[string]bool + +// Value implements driver.Valuer for GORM JSON serialization. +func (p PermissionMap) Value() (driver.Value, error) { + if p == nil { + return "{}", nil + } + b, err := json.Marshal(p) + if err != nil { + return nil, fmt.Errorf("failed to marshal PermissionMap: %w", err) + } + return string(b), nil +} + +// Scan implements sql.Scanner for GORM JSON deserialization. +func (p *PermissionMap) Scan(value any) error { + if value == nil { + *p = PermissionMap{} + return nil + } + var bytes []byte + switch v := value.(type) { + case string: + bytes = []byte(v) + case []byte: + bytes = v + default: + return fmt.Errorf("cannot scan %T into PermissionMap", value) + } + return json.Unmarshal(bytes, p) +} + +// InviteCode represents an admin-generated invitation for user registration. +type InviteCode struct { + ID string `gorm:"primaryKey;size:36"` + Code string `gorm:"uniqueIndex;not null;size:64"` + CreatedBy string `gorm:"size:36;not null"` + UsedBy *string `gorm:"size:36"` + UsedAt *time.Time + ExpiresAt time.Time `gorm:"not null;index"` + CreatedAt time.Time + Creator User `gorm:"foreignKey:CreatedBy"` + Consumer *User `gorm:"foreignKey:UsedBy"` +} + +// UserPermission stores per-user feature permissions. +type UserPermission struct { + ID string `gorm:"primaryKey;size:36"` + UserID string `gorm:"size:36;uniqueIndex"` + Permissions PermissionMap `gorm:"type:text"` + CreatedAt time.Time + UpdatedAt time.Time + User User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"` +} diff --git a/core/http/auth/oauth.go b/core/http/auth/oauth.go new file mode 100644 index 000000000000..f429b0e84269 --- /dev/null +++ b/core/http/auth/oauth.go @@ -0,0 +1,331 @@ +package auth + +import ( + "context" + "crypto/rand" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/google/uuid" + "github.com/labstack/echo/v4" + "golang.org/x/oauth2" + githubOAuth "golang.org/x/oauth2/github" + "gorm.io/gorm" +) + +// OAuthProvider holds the OAuth2 config and user-info fetch logic for a provider. +type OAuthProvider struct { + Config *oauth2.Config + UserInfoURL string + Name string +} + +// OAuthManager manages multiple OAuth providers. +type OAuthManager struct { + providers map[string]*OAuthProvider +} + +// NewOAuthManager creates an OAuthManager from the given AuthConfig. +func NewOAuthManager(baseURL, githubClientID, githubClientSecret string) (*OAuthManager, error) { + m := &OAuthManager{providers: make(map[string]*OAuthProvider)} + + if githubClientID != "" { + m.providers["github"] = &OAuthProvider{ + Name: "github", + Config: &oauth2.Config{ + ClientID: githubClientID, + ClientSecret: githubClientSecret, + Endpoint: githubOAuth.Endpoint, + RedirectURL: baseURL + "/api/auth/github/callback", + Scopes: []string{"user:email", "read:user"}, + }, + UserInfoURL: "https://api.github.com/user", + } + } + + return m, nil +} + +// Providers returns the list of configured provider names. +func (m *OAuthManager) Providers() []string { + names := make([]string, 0, len(m.providers)) + for name := range m.providers { + names = append(names, name) + } + return names +} + +// LoginHandler redirects the user to the OAuth provider's login page. +func (m *OAuthManager) LoginHandler(providerName string) echo.HandlerFunc { + return func(c echo.Context) error { + provider, ok := m.providers[providerName] + if !ok { + return c.JSON(http.StatusNotFound, map[string]string{"error": "unknown provider"}) + } + + state, err := generateState() + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to generate state"}) + } + + c.SetCookie(&http.Cookie{ + Name: "oauth_state", + Value: state, + Path: "/", + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + MaxAge: 600, // 10 minutes + }) + + // Store invite code in cookie if provided + if inviteCode := c.QueryParam("invite_code"); inviteCode != "" { + c.SetCookie(&http.Cookie{ + Name: "invite_code", + Value: inviteCode, + Path: "/", + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + MaxAge: 600, + }) + } + + url := provider.Config.AuthCodeURL(state) + return c.Redirect(http.StatusTemporaryRedirect, url) + } +} + +// CallbackHandler handles the OAuth callback, creates/updates the user, and +// creates a session. +func (m *OAuthManager) CallbackHandler(providerName string, db *gorm.DB, adminEmail, registrationMode string) echo.HandlerFunc { + return func(c echo.Context) error { + provider, ok := m.providers[providerName] + if !ok { + return c.JSON(http.StatusNotFound, map[string]string{"error": "unknown provider"}) + } + + // Validate state + stateCookie, err := c.Cookie("oauth_state") + if err != nil || stateCookie.Value == "" || stateCookie.Value != c.QueryParam("state") { + return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid OAuth state"}) + } + + // Clear state cookie + c.SetCookie(&http.Cookie{ + Name: "oauth_state", + Value: "", + Path: "/", + HttpOnly: true, + MaxAge: -1, + }) + + // Exchange code for token + code := c.QueryParam("code") + if code == "" { + return c.JSON(http.StatusBadRequest, map[string]string{"error": "missing authorization code"}) + } + + ctx, cancel := context.WithTimeout(c.Request().Context(), 30*time.Second) + defer cancel() + + token, err := provider.Config.Exchange(ctx, code) + if err != nil { + return c.JSON(http.StatusBadRequest, map[string]string{"error": "failed to exchange code: " + err.Error()}) + } + + // Fetch user info + userInfo, err := fetchGitHubUserInfo(ctx, token.AccessToken) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to fetch user info"}) + } + + // Retrieve invite code from cookie if present + var inviteCode string + if ic, err := c.Cookie("invite_code"); err == nil && ic.Value != "" { + inviteCode = ic.Value + // Clear the invite code cookie + c.SetCookie(&http.Cookie{ + Name: "invite_code", + Value: "", + Path: "/", + HttpOnly: true, + MaxAge: -1, + }) + } + + // Upsert user (with invite code support) + user, err := upsertUser(db, providerName, userInfo, adminEmail, registrationMode) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to create user"}) + } + + // For new users that are pending, check if they have a valid invite + if user.Status != "active" && inviteCode != "" { + if invite, err := ValidateInvite(db, inviteCode); err == nil { + user.Status = "active" + db.Model(user).Update("status", "active") + ConsumeInvite(db, invite, user.ID) + } + } + + if user.Status != "active" { + if registrationMode == "invite" { + return c.JSON(http.StatusForbidden, map[string]string{"error": "a valid invite code is required to register"}) + } + return c.JSON(http.StatusForbidden, map[string]string{"error": "account pending approval"}) + } + + // Maybe promote on login + MaybePromote(db, user, adminEmail) + + // Create session + sessionID, err := CreateSession(db, user.ID) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to create session"}) + } + + SetSessionCookie(c, sessionID) + return c.Redirect(http.StatusTemporaryRedirect, "/app") + } +} + +type githubUserInfo struct { + ID int `json:"id"` + Login string `json:"login"` + Name string `json:"name"` + Email string `json:"email"` + AvatarURL string `json:"avatar_url"` +} + +type githubEmail struct { + Email string `json:"email"` + Primary bool `json:"primary"` + Verified bool `json:"verified"` +} + +func fetchGitHubUserInfo(ctx context.Context, accessToken string) (*githubUserInfo, error) { + client := &http.Client{Timeout: 10 * time.Second} + + req, _ := http.NewRequestWithContext(ctx, "GET", "https://api.github.com/user", nil) + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json") + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + var info githubUserInfo + if err := json.Unmarshal(body, &info); err != nil { + return nil, err + } + + // If no public email, fetch from /user/emails + if info.Email == "" { + info.Email, _ = fetchGitHubPrimaryEmail(ctx, accessToken) + } + + return &info, nil +} + +func fetchGitHubPrimaryEmail(ctx context.Context, accessToken string) (string, error) { + client := &http.Client{Timeout: 10 * time.Second} + + req, _ := http.NewRequestWithContext(ctx, "GET", "https://api.github.com/user/emails", nil) + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json") + + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + var emails []githubEmail + if err := json.Unmarshal(body, &emails); err != nil { + return "", err + } + + for _, e := range emails { + if e.Primary && e.Verified { + return e.Email, nil + } + } + + // Fall back to first verified email + for _, e := range emails { + if e.Verified { + return e.Email, nil + } + } + + return "", fmt.Errorf("no verified email found") +} + +func upsertUser(db *gorm.DB, provider string, info *githubUserInfo, adminEmail, registrationMode string) (*User, error) { + subject := fmt.Sprintf("%d", info.ID) + + var user User + err := db.Where("provider = ? AND subject = ?", provider, subject).First(&user).Error + if err == nil { + // Existing user — update profile fields + user.Name = info.Name + user.AvatarURL = info.AvatarURL + if info.Email != "" { + user.Email = info.Email + } + db.Save(&user) + return &user, nil + } + + // New user + status := "active" + if registrationMode == "approval" || registrationMode == "invite" { + status = "pending" + } + + role := AssignRole(db, info.Email, adminEmail) + // First user is always active regardless of registration mode + if role == RoleAdmin { + status = "active" + } + + user = User{ + ID: uuid.New().String(), + Email: info.Email, + Name: info.Name, + AvatarURL: info.AvatarURL, + Provider: provider, + Subject: subject, + Role: role, + Status: status, + } + + if err := db.Create(&user).Error; err != nil { + return nil, err + } + + return &user, nil +} + +func generateState() (string, error) { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return "", err + } + return hex.EncodeToString(b), nil +} diff --git a/core/http/auth/password.go b/core/http/auth/password.go new file mode 100644 index 000000000000..4c88fedb7267 --- /dev/null +++ b/core/http/auth/password.go @@ -0,0 +1,14 @@ +package auth + +import "golang.org/x/crypto/bcrypt" + +// HashPassword returns a bcrypt hash of the given password. +func HashPassword(password string) (string, error) { + bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + return string(bytes), err +} + +// CheckPassword compares a bcrypt hash with a plaintext password. +func CheckPassword(hash, password string) bool { + return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) == nil +} diff --git a/core/http/auth/permissions.go b/core/http/auth/permissions.go new file mode 100644 index 000000000000..1710f364b951 --- /dev/null +++ b/core/http/auth/permissions.go @@ -0,0 +1,93 @@ +package auth + +import ( + "github.com/google/uuid" + "gorm.io/gorm" +) + +// Feature name constants — all code must use these, never bare strings. +const ( + FeatureAgents = "agents" + FeatureSkills = "skills" + FeatureCollections = "collections" + FeatureMCPJobs = "mcp_jobs" +) + +// AllFeatures lists all known features (used by UI and validation). +var AllFeatures = []string{FeatureAgents, FeatureSkills, FeatureCollections, FeatureMCPJobs} + +// GetUserPermissions returns the permission record for a user, creating a default +// (empty map = all disabled) if none exists. +func GetUserPermissions(db *gorm.DB, userID string) (*UserPermission, error) { + var perm UserPermission + err := db.Where("user_id = ?", userID).First(&perm).Error + if err == gorm.ErrRecordNotFound { + perm = UserPermission{ + ID: uuid.New().String(), + UserID: userID, + Permissions: PermissionMap{}, + } + if err := db.Create(&perm).Error; err != nil { + return nil, err + } + return &perm, nil + } + if err != nil { + return nil, err + } + return &perm, nil +} + +// UpdateUserPermissions upserts the permission map for a user. +func UpdateUserPermissions(db *gorm.DB, userID string, perms PermissionMap) error { + var perm UserPermission + err := db.Where("user_id = ?", userID).First(&perm).Error + if err == gorm.ErrRecordNotFound { + perm = UserPermission{ + ID: uuid.New().String(), + UserID: userID, + Permissions: perms, + } + return db.Create(&perm).Error + } + if err != nil { + return err + } + perm.Permissions = perms + return db.Save(&perm).Error +} + +// HasFeatureAccess returns true if the user is an admin or has the given feature enabled. +func HasFeatureAccess(db *gorm.DB, user *User, feature string) bool { + if user == nil { + return false + } + if user.Role == RoleAdmin { + return true + } + perm, err := GetUserPermissions(db, user.ID) + if err != nil { + return false + } + return perm.Permissions[feature] +} + +// GetPermissionMapForUser returns the effective permission map for a user. +// Admins get all features as true (virtual). +func GetPermissionMapForUser(db *gorm.DB, user *User) PermissionMap { + if user == nil { + return PermissionMap{} + } + if user.Role == RoleAdmin { + m := PermissionMap{} + for _, f := range AllFeatures { + m[f] = true + } + return m + } + perm, err := GetUserPermissions(db, user.ID) + if err != nil { + return PermissionMap{} + } + return perm.Permissions +} diff --git a/core/http/auth/roles.go b/core/http/auth/roles.go new file mode 100644 index 000000000000..c75457b06861 --- /dev/null +++ b/core/http/auth/roles.go @@ -0,0 +1,90 @@ +package auth + +import ( + "fmt" + "strings" + "time" + + "gorm.io/gorm" +) + +const ( + RoleAdmin = "admin" + RoleUser = "user" +) + +// AssignRole determines the role for a new user. +// First user in the database becomes admin. If adminEmail is set and matches, +// the user becomes admin. Otherwise, the user gets the "user" role. +func AssignRole(db *gorm.DB, email, adminEmail string) string { + var count int64 + db.Model(&User{}).Count(&count) + if count == 0 { + return RoleAdmin + } + + if adminEmail != "" && strings.EqualFold(email, adminEmail) { + return RoleAdmin + } + + return RoleUser +} + +// MaybePromote promotes a user to admin on login if their email matches +// adminEmail. It does not demote existing admins. Returns true if the user +// was promoted. +func MaybePromote(db *gorm.DB, user *User, adminEmail string) bool { + if user.Role == RoleAdmin { + return false + } + + if adminEmail != "" && strings.EqualFold(user.Email, adminEmail) { + user.Role = RoleAdmin + db.Model(user).Update("role", RoleAdmin) + return true + } + + return false +} + +// ValidateInvite checks that an invite code exists, is unused, and has not expired. +func ValidateInvite(db *gorm.DB, code string) (*InviteCode, error) { + var invite InviteCode + if err := db.Where("code = ?", code).First(&invite).Error; err != nil { + return nil, fmt.Errorf("invite code not found") + } + if invite.UsedBy != nil { + return nil, fmt.Errorf("invite code already used") + } + if time.Now().After(invite.ExpiresAt) { + return nil, fmt.Errorf("invite code expired") + } + return &invite, nil +} + +// ConsumeInvite marks an invite code as used by the given user. +func ConsumeInvite(db *gorm.DB, invite *InviteCode, userID string) { + now := time.Now() + invite.UsedBy = &userID + invite.UsedAt = &now + db.Save(invite) +} + +// NeedsInviteOrApproval returns true if registration gating applies for the given mode. +// Admins (first user or matching adminEmail) are never gated. +func NeedsInviteOrApproval(db *gorm.DB, email, adminEmail, registrationMode string) bool { + if registrationMode != "approval" && registrationMode != "invite" { + return false + } + // Admin email is never gated + if adminEmail != "" && strings.EqualFold(email, adminEmail) { + return false + } + // First user is never gated + var count int64 + db.Model(&User{}).Count(&count) + if count == 0 { + return false + } + return true +} diff --git a/core/http/auth/roles_test.go b/core/http/auth/roles_test.go new file mode 100644 index 000000000000..4b73c90a7775 --- /dev/null +++ b/core/http/auth/roles_test.go @@ -0,0 +1,84 @@ +//go:build auth + +package auth_test + +import ( + "github.com/mudler/LocalAI/core/http/auth" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "gorm.io/gorm" +) + +var _ = Describe("Roles", func() { + var db *gorm.DB + + BeforeEach(func() { + db = testDB() + }) + + Describe("AssignRole", func() { + It("returns admin for the first user (empty DB)", func() { + role := auth.AssignRole(db, "first@example.com", "") + Expect(role).To(Equal(auth.RoleAdmin)) + }) + + It("returns user for the second user", func() { + createTestUser(db, "first@example.com", auth.RoleAdmin, "github") + + role := auth.AssignRole(db, "second@example.com", "") + Expect(role).To(Equal(auth.RoleUser)) + }) + + It("returns admin when email matches adminEmail", func() { + createTestUser(db, "first@example.com", auth.RoleAdmin, "github") + + role := auth.AssignRole(db, "admin@example.com", "admin@example.com") + Expect(role).To(Equal(auth.RoleAdmin)) + }) + + It("is case-insensitive for admin email match", func() { + createTestUser(db, "first@example.com", auth.RoleAdmin, "github") + + role := auth.AssignRole(db, "Admin@Example.COM", "admin@example.com") + Expect(role).To(Equal(auth.RoleAdmin)) + }) + + It("returns user when email does not match adminEmail", func() { + createTestUser(db, "first@example.com", auth.RoleAdmin, "github") + + role := auth.AssignRole(db, "other@example.com", "admin@example.com") + Expect(role).To(Equal(auth.RoleUser)) + }) + }) + + Describe("MaybePromote", func() { + It("promotes user to admin when email matches", func() { + user := createTestUser(db, "promoted@example.com", auth.RoleUser, "github") + + promoted := auth.MaybePromote(db, user, "promoted@example.com") + Expect(promoted).To(BeTrue()) + Expect(user.Role).To(Equal(auth.RoleAdmin)) + + // Verify in DB + var dbUser auth.User + db.First(&dbUser, "id = ?", user.ID) + Expect(dbUser.Role).To(Equal(auth.RoleAdmin)) + }) + + It("does not promote when email does not match", func() { + user := createTestUser(db, "user@example.com", auth.RoleUser, "github") + + promoted := auth.MaybePromote(db, user, "admin@example.com") + Expect(promoted).To(BeFalse()) + Expect(user.Role).To(Equal(auth.RoleUser)) + }) + + It("does not demote an existing admin", func() { + user := createTestUser(db, "admin@example.com", auth.RoleAdmin, "github") + + promoted := auth.MaybePromote(db, user, "other@example.com") + Expect(promoted).To(BeFalse()) + Expect(user.Role).To(Equal(auth.RoleAdmin)) + }) + }) +}) diff --git a/core/http/auth/session.go b/core/http/auth/session.go new file mode 100644 index 000000000000..f1bab72a699c --- /dev/null +++ b/core/http/auth/session.go @@ -0,0 +1,87 @@ +package auth + +import ( + "crypto/rand" + "encoding/hex" + "fmt" + "net/http" + "time" + + "github.com/labstack/echo/v4" + "gorm.io/gorm" +) + +const ( + sessionDuration = 30 * 24 * time.Hour // 30 days + sessionIDBytes = 32 // 32 bytes = 64 hex chars + sessionCookie = "session" +) + +// CreateSession creates a new session for the given user, returning the +// session ID (64-char hex string). +func CreateSession(db *gorm.DB, userID string) (string, error) { + b := make([]byte, sessionIDBytes) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("failed to generate session ID: %w", err) + } + + sessionID := hex.EncodeToString(b) + + session := Session{ + ID: sessionID, + UserID: userID, + ExpiresAt: time.Now().Add(sessionDuration), + } + + if err := db.Create(&session).Error; err != nil { + return "", fmt.Errorf("failed to create session: %w", err) + } + + return sessionID, nil +} + +// ValidateSession looks up a session by ID and returns the associated user. +// Returns nil if the session is not found or expired. +func ValidateSession(db *gorm.DB, sessionID string) *User { + var session Session + if err := db.Preload("User").Where("id = ? AND expires_at > ?", sessionID, time.Now()).First(&session).Error; err != nil { + return nil + } + return &session.User +} + +// DeleteSession removes a session from the database. +func DeleteSession(db *gorm.DB, sessionID string) error { + return db.Where("id = ?", sessionID).Delete(&Session{}).Error +} + +// CleanExpiredSessions removes all sessions that have passed their expiry time. +func CleanExpiredSessions(db *gorm.DB) error { + return db.Where("expires_at < ?", time.Now()).Delete(&Session{}).Error +} + +// SetSessionCookie sets the session cookie on the response. +func SetSessionCookie(c echo.Context, sessionID string) { + cookie := &http.Cookie{ + Name: sessionCookie, + Value: sessionID, + Path: "/", + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + MaxAge: int(sessionDuration.Seconds()), + } + c.SetCookie(cookie) +} + +// ClearSessionCookie clears the session cookie. +func ClearSessionCookie(c echo.Context) { + cookie := &http.Cookie{ + Name: sessionCookie, + Value: "", + Path: "/", + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + MaxAge: -1, + } + c.SetCookie(cookie) +} diff --git a/core/http/auth/session_test.go b/core/http/auth/session_test.go new file mode 100644 index 000000000000..a58ac3c2fbeb --- /dev/null +++ b/core/http/auth/session_test.go @@ -0,0 +1,123 @@ +//go:build auth + +package auth_test + +import ( + "time" + + "github.com/mudler/LocalAI/core/http/auth" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "gorm.io/gorm" +) + +var _ = Describe("Sessions", func() { + var ( + db *gorm.DB + user *auth.User + ) + + BeforeEach(func() { + db = testDB() + user = createTestUser(db, "session@example.com", auth.RoleUser, "github") + }) + + Describe("CreateSession", func() { + It("creates a session with 64-char hex ID", func() { + sessionID, err := auth.CreateSession(db, user.ID) + Expect(err).ToNot(HaveOccurred()) + Expect(sessionID).To(HaveLen(64)) + }) + + It("sets expiry to approximately 30 days from now", func() { + sessionID, err := auth.CreateSession(db, user.ID) + Expect(err).ToNot(HaveOccurred()) + + var session auth.Session + db.First(&session, "id = ?", sessionID) + + expectedExpiry := time.Now().Add(30 * 24 * time.Hour) + Expect(session.ExpiresAt).To(BeTemporally("~", expectedExpiry, time.Minute)) + }) + + It("associates session with correct user", func() { + sessionID, err := auth.CreateSession(db, user.ID) + Expect(err).ToNot(HaveOccurred()) + + var session auth.Session + db.First(&session, "id = ?", sessionID) + Expect(session.UserID).To(Equal(user.ID)) + }) + }) + + Describe("ValidateSession", func() { + It("returns user for valid session", func() { + sessionID := createTestSession(db, user.ID) + + found := auth.ValidateSession(db, sessionID) + Expect(found).ToNot(BeNil()) + Expect(found.ID).To(Equal(user.ID)) + }) + + It("returns nil for non-existent session", func() { + found := auth.ValidateSession(db, "nonexistent-session-id") + Expect(found).To(BeNil()) + }) + + It("returns nil for expired session", func() { + sessionID := createTestSession(db, user.ID) + + // Manually expire the session + db.Model(&auth.Session{}).Where("id = ?", sessionID). + Update("expires_at", time.Now().Add(-1*time.Hour)) + + found := auth.ValidateSession(db, sessionID) + Expect(found).To(BeNil()) + }) + }) + + Describe("DeleteSession", func() { + It("removes the session from DB", func() { + sessionID := createTestSession(db, user.ID) + + err := auth.DeleteSession(db, sessionID) + Expect(err).ToNot(HaveOccurred()) + + found := auth.ValidateSession(db, sessionID) + Expect(found).To(BeNil()) + }) + + It("does not error on non-existent session", func() { + err := auth.DeleteSession(db, "nonexistent") + Expect(err).ToNot(HaveOccurred()) + }) + }) + + Describe("CleanExpiredSessions", func() { + It("removes expired sessions", func() { + sessionID := createTestSession(db, user.ID) + + // Manually expire the session + db.Model(&auth.Session{}).Where("id = ?", sessionID). + Update("expires_at", time.Now().Add(-1*time.Hour)) + + err := auth.CleanExpiredSessions(db) + Expect(err).ToNot(HaveOccurred()) + + var count int64 + db.Model(&auth.Session{}).Where("id = ?", sessionID).Count(&count) + Expect(count).To(Equal(int64(0))) + }) + + It("keeps active sessions", func() { + sessionID := createTestSession(db, user.ID) + + err := auth.CleanExpiredSessions(db) + Expect(err).ToNot(HaveOccurred()) + + var count int64 + db.Model(&auth.Session{}).Where("id = ?", sessionID).Count(&count) + Expect(count).To(Equal(int64(1))) + }) + }) +}) diff --git a/core/http/auth/usage.go b/core/http/auth/usage.go new file mode 100644 index 000000000000..08841a442dbc --- /dev/null +++ b/core/http/auth/usage.go @@ -0,0 +1,151 @@ +package auth + +import ( + "fmt" + "strings" + "time" + + "gorm.io/gorm" +) + +// UsageRecord represents a single API request's token usage. +type UsageRecord struct { + ID uint `gorm:"primaryKey;autoIncrement"` + UserID string `gorm:"size:36;index:idx_usage_user_time"` + UserName string `gorm:"size:255"` + Model string `gorm:"size:255;index"` + Endpoint string `gorm:"size:255"` + PromptTokens int64 + CompletionTokens int64 + TotalTokens int64 + Duration int64 // milliseconds + CreatedAt time.Time `gorm:"index:idx_usage_user_time"` +} + +// RecordUsage inserts a usage record. +func RecordUsage(db *gorm.DB, record *UsageRecord) error { + return db.Create(record).Error +} + +// UsageBucket is an aggregated time bucket for the dashboard. +type UsageBucket struct { + Bucket string `json:"bucket"` + Model string `json:"model"` + UserID string `json:"user_id,omitempty"` + UserName string `json:"user_name,omitempty"` + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + TotalTokens int64 `json:"total_tokens"` + RequestCount int64 `json:"request_count"` +} + +// UsageTotals is a summary of all usage. +type UsageTotals struct { + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + TotalTokens int64 `json:"total_tokens"` + RequestCount int64 `json:"request_count"` +} + +// periodToWindow returns the time window and SQL date format for a period. +func periodToWindow(period string, isSQLite bool) (time.Time, string) { + now := time.Now() + var since time.Time + var dateFmt string + + switch period { + case "day": + since = now.Add(-24 * time.Hour) + if isSQLite { + dateFmt = "strftime('%Y-%m-%d %H:00', created_at)" + } else { + dateFmt = "to_char(date_trunc('hour', created_at), 'YYYY-MM-DD HH24:00')" + } + case "week": + since = now.Add(-7 * 24 * time.Hour) + if isSQLite { + dateFmt = "strftime('%Y-%m-%d', created_at)" + } else { + dateFmt = "to_char(date_trunc('day', created_at), 'YYYY-MM-DD')" + } + case "all": + since = time.Time{} // zero time = no filter + if isSQLite { + dateFmt = "strftime('%Y-%m', created_at)" + } else { + dateFmt = "to_char(date_trunc('month', created_at), 'YYYY-MM')" + } + default: // "month" + since = now.Add(-30 * 24 * time.Hour) + if isSQLite { + dateFmt = "strftime('%Y-%m-%d', created_at)" + } else { + dateFmt = "to_char(date_trunc('day', created_at), 'YYYY-MM-DD')" + } + } + + return since, dateFmt +} + +func isSQLiteDB(db *gorm.DB) bool { + return strings.Contains(db.Dialector.Name(), "sqlite") +} + +// GetUserUsage returns aggregated usage for a single user. +func GetUserUsage(db *gorm.DB, userID, period string) ([]UsageBucket, error) { + sqlite := isSQLiteDB(db) + since, dateFmt := periodToWindow(period, sqlite) + + bucketExpr := fmt.Sprintf("%s as bucket", dateFmt) + + query := db.Model(&UsageRecord{}). + Select(bucketExpr+", model, "+ + "SUM(prompt_tokens) as prompt_tokens, "+ + "SUM(completion_tokens) as completion_tokens, "+ + "SUM(total_tokens) as total_tokens, "+ + "COUNT(*) as request_count"). + Where("user_id = ?", userID). + Group("bucket, model"). + Order("bucket ASC") + + if !since.IsZero() { + query = query.Where("created_at >= ?", since) + } + + var buckets []UsageBucket + if err := query.Find(&buckets).Error; err != nil { + return nil, err + } + return buckets, nil +} + +// GetAllUsage returns aggregated usage for all users (admin). Optional userID filter. +func GetAllUsage(db *gorm.DB, period, userID string) ([]UsageBucket, error) { + sqlite := isSQLiteDB(db) + since, dateFmt := periodToWindow(period, sqlite) + + bucketExpr := fmt.Sprintf("%s as bucket", dateFmt) + + query := db.Model(&UsageRecord{}). + Select(bucketExpr+", model, user_id, user_name, "+ + "SUM(prompt_tokens) as prompt_tokens, "+ + "SUM(completion_tokens) as completion_tokens, "+ + "SUM(total_tokens) as total_tokens, "+ + "COUNT(*) as request_count"). + Group("bucket, model, user_id, user_name"). + Order("bucket ASC") + + if !since.IsZero() { + query = query.Where("created_at >= ?", since) + } + + if userID != "" { + query = query.Where("user_id = ?", userID) + } + + var buckets []UsageBucket + if err := query.Find(&buckets).Error; err != nil { + return nil, err + } + return buckets, nil +} diff --git a/core/http/auth/usage_test.go b/core/http/auth/usage_test.go new file mode 100644 index 000000000000..0c3fa5df5846 --- /dev/null +++ b/core/http/auth/usage_test.go @@ -0,0 +1,161 @@ +//go:build auth + +package auth_test + +import ( + "time" + + "github.com/mudler/LocalAI/core/http/auth" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Usage", func() { + Describe("RecordUsage", func() { + It("inserts a usage record", func() { + db := testDB() + record := &auth.UsageRecord{ + UserID: "user-1", + UserName: "Test User", + Model: "gpt-4", + Endpoint: "/v1/chat/completions", + PromptTokens: 100, + CompletionTokens: 50, + TotalTokens: 150, + Duration: 1200, + CreatedAt: time.Now(), + } + err := auth.RecordUsage(db, record) + Expect(err).ToNot(HaveOccurred()) + Expect(record.ID).ToNot(BeZero()) + }) + }) + + Describe("GetUserUsage", func() { + It("returns aggregated usage for a specific user", func() { + db := testDB() + + // Insert records for two users + for i := 0; i < 3; i++ { + err := auth.RecordUsage(db, &auth.UsageRecord{ + UserID: "user-a", + UserName: "Alice", + Model: "gpt-4", + Endpoint: "/v1/chat/completions", + PromptTokens: 100, + TotalTokens: 150, + CreatedAt: time.Now(), + }) + Expect(err).ToNot(HaveOccurred()) + } + err := auth.RecordUsage(db, &auth.UsageRecord{ + UserID: "user-b", + UserName: "Bob", + Model: "gpt-4", + PromptTokens: 200, + TotalTokens: 300, + CreatedAt: time.Now(), + }) + Expect(err).ToNot(HaveOccurred()) + + buckets, err := auth.GetUserUsage(db, "user-a", "month") + Expect(err).ToNot(HaveOccurred()) + Expect(buckets).ToNot(BeEmpty()) + + // All returned buckets should be for user-a's model + totalPrompt := int64(0) + for _, b := range buckets { + totalPrompt += b.PromptTokens + } + Expect(totalPrompt).To(Equal(int64(300))) + }) + + It("filters by period", func() { + db := testDB() + + // Record in the past (beyond day window) + err := auth.RecordUsage(db, &auth.UsageRecord{ + UserID: "user-c", + UserName: "Carol", + Model: "gpt-4", + PromptTokens: 100, + TotalTokens: 100, + CreatedAt: time.Now().Add(-48 * time.Hour), + }) + Expect(err).ToNot(HaveOccurred()) + + // Record now + err = auth.RecordUsage(db, &auth.UsageRecord{ + UserID: "user-c", + UserName: "Carol", + Model: "gpt-4", + PromptTokens: 200, + TotalTokens: 200, + CreatedAt: time.Now(), + }) + Expect(err).ToNot(HaveOccurred()) + + // Day period should only include recent record + buckets, err := auth.GetUserUsage(db, "user-c", "day") + Expect(err).ToNot(HaveOccurred()) + totalPrompt := int64(0) + for _, b := range buckets { + totalPrompt += b.PromptTokens + } + Expect(totalPrompt).To(Equal(int64(200))) + + // Month period should include both + buckets, err = auth.GetUserUsage(db, "user-c", "month") + Expect(err).ToNot(HaveOccurred()) + totalPrompt = 0 + for _, b := range buckets { + totalPrompt += b.PromptTokens + } + Expect(totalPrompt).To(Equal(int64(300))) + }) + }) + + Describe("GetAllUsage", func() { + It("returns usage for all users", func() { + db := testDB() + + for _, uid := range []string{"user-x", "user-y"} { + err := auth.RecordUsage(db, &auth.UsageRecord{ + UserID: uid, + UserName: uid, + Model: "gpt-4", + PromptTokens: 100, + TotalTokens: 150, + CreatedAt: time.Now(), + }) + Expect(err).ToNot(HaveOccurred()) + } + + buckets, err := auth.GetAllUsage(db, "month", "") + Expect(err).ToNot(HaveOccurred()) + Expect(len(buckets)).To(BeNumerically(">=", 2)) + }) + + It("filters by user ID when specified", func() { + db := testDB() + + err := auth.RecordUsage(db, &auth.UsageRecord{ + UserID: "user-p", UserName: "Pat", Model: "gpt-4", + PromptTokens: 100, TotalTokens: 100, CreatedAt: time.Now(), + }) + Expect(err).ToNot(HaveOccurred()) + + err = auth.RecordUsage(db, &auth.UsageRecord{ + UserID: "user-q", UserName: "Quinn", Model: "gpt-4", + PromptTokens: 200, TotalTokens: 200, CreatedAt: time.Now(), + }) + Expect(err).ToNot(HaveOccurred()) + + buckets, err := auth.GetAllUsage(db, "month", "user-p") + Expect(err).ToNot(HaveOccurred()) + for _, b := range buckets { + Expect(b.UserID).To(Equal("user-p")) + } + }) + }) +}) diff --git a/core/http/endpoints/localai/agent_collections.go b/core/http/endpoints/localai/agent_collections.go index 49b6ea386dc8..439d75c3cef0 100644 --- a/core/http/endpoints/localai/agent_collections.go +++ b/core/http/endpoints/localai/agent_collections.go @@ -12,27 +12,54 @@ import ( func ListCollectionsEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - collections, err := svc.ListCollections() + userID := getUserID(c) + cols, err := svc.ListCollectionsForUser(userID) if err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } - return c.JSON(http.StatusOK, map[string]any{ - "collections": collections, - "count": len(collections), - }) + + resp := map[string]any{ + "collections": cols, + "count": len(cols), + } + + // Admin cross-user aggregation + if wantsAllUsers(c) { + usm := svc.UserServicesManager() + if usm != nil { + userIDs, _ := usm.ListAllUserIDs() + userGroups := map[string]any{} + for _, uid := range userIDs { + if uid == userID { + continue + } + userCols, err := svc.ListCollectionsForUser(uid) + if err != nil || len(userCols) == 0 { + continue + } + userGroups[uid] = map[string]any{"collections": userCols} + } + if len(userGroups) > 0 { + resp["user_groups"] = userGroups + } + } + } + + return c.JSON(http.StatusOK, resp) } } func CreateCollectionEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) var payload struct { Name string `json:"name"` } if err := c.Bind(&payload); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - if err := svc.CreateCollection(payload.Name); err != nil { + if err := svc.CreateCollectionForUser(userID, payload.Name); err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusCreated, map[string]string{"status": "ok", "name": payload.Name}) @@ -42,12 +69,13 @@ func CreateCollectionEndpoint(app *application.Application) echo.HandlerFunc { func UploadToCollectionEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) name := c.Param("name") file, err := c.FormFile("file") if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": "file required"}) } - if svc.CollectionEntryExists(name, file.Filename) { + if svc.CollectionEntryExistsForUser(userID, name, file.Filename) { return c.JSON(http.StatusConflict, map[string]string{"error": "entry already exists"}) } src, err := file.Open() @@ -55,7 +83,7 @@ func UploadToCollectionEndpoint(app *application.Application) echo.HandlerFunc { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } defer src.Close() - if err := svc.UploadToCollection(name, file.Filename, src); err != nil { + if err := svc.UploadToCollectionForUser(userID, name, file.Filename, src); err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -68,7 +96,8 @@ func UploadToCollectionEndpoint(app *application.Application) echo.HandlerFunc { func ListCollectionEntriesEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - entries, err := svc.ListCollectionEntries(c.Param("name")) + userID := getUserID(c) + entries, err := svc.ListCollectionEntriesForUser(userID, c.Param("name")) if err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) @@ -85,12 +114,13 @@ func ListCollectionEntriesEndpoint(app *application.Application) echo.HandlerFun func GetCollectionEntryContentEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) entryParam := c.Param("*") entry, err := url.PathUnescape(entryParam) if err != nil { entry = entryParam } - content, chunkCount, err := svc.GetCollectionEntryContent(c.Param("name"), entry) + content, chunkCount, err := svc.GetCollectionEntryContentForUser(userID, c.Param("name"), entry) if err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) @@ -107,6 +137,7 @@ func GetCollectionEntryContentEndpoint(app *application.Application) echo.Handle func SearchCollectionEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) var payload struct { Query string `json:"query"` MaxResults int `json:"max_results"` @@ -114,7 +145,7 @@ func SearchCollectionEndpoint(app *application.Application) echo.HandlerFunc { if err := c.Bind(&payload); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - results, err := svc.SearchCollection(c.Param("name"), payload.Query, payload.MaxResults) + results, err := svc.SearchCollectionForUser(userID, c.Param("name"), payload.Query, payload.MaxResults) if err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) @@ -131,7 +162,8 @@ func SearchCollectionEndpoint(app *application.Application) echo.HandlerFunc { func ResetCollectionEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - if err := svc.ResetCollection(c.Param("name")); err != nil { + userID := getUserID(c) + if err := svc.ResetCollectionForUser(userID, c.Param("name")); err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -144,13 +176,14 @@ func ResetCollectionEndpoint(app *application.Application) echo.HandlerFunc { func DeleteCollectionEntryEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) var payload struct { Entry string `json:"entry"` } if err := c.Bind(&payload); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - remaining, err := svc.DeleteCollectionEntry(c.Param("name"), payload.Entry) + remaining, err := svc.DeleteCollectionEntryForUser(userID, c.Param("name"), payload.Entry) if err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) @@ -167,6 +200,7 @@ func DeleteCollectionEntryEndpoint(app *application.Application) echo.HandlerFun func AddCollectionSourceEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) var payload struct { URL string `json:"url"` UpdateInterval int `json:"update_interval"` @@ -177,7 +211,7 @@ func AddCollectionSourceEndpoint(app *application.Application) echo.HandlerFunc if payload.UpdateInterval < 1 { payload.UpdateInterval = 60 } - if err := svc.AddCollectionSource(c.Param("name"), payload.URL, payload.UpdateInterval); err != nil { + if err := svc.AddCollectionSourceForUser(userID, c.Param("name"), payload.URL, payload.UpdateInterval); err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -190,13 +224,14 @@ func AddCollectionSourceEndpoint(app *application.Application) echo.HandlerFunc func RemoveCollectionSourceEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) var payload struct { URL string `json:"url"` } if err := c.Bind(&payload); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - if err := svc.RemoveCollectionSource(c.Param("name"), payload.URL); err != nil { + if err := svc.RemoveCollectionSourceForUser(userID, c.Param("name"), payload.URL); err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) @@ -207,12 +242,13 @@ func RemoveCollectionSourceEndpoint(app *application.Application) echo.HandlerFu func GetCollectionEntryRawFileEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) entryParam := c.Param("*") entry, err := url.PathUnescape(entryParam) if err != nil { entry = entryParam } - fpath, err := svc.GetCollectionEntryFilePath(c.Param("name"), entry) + fpath, err := svc.GetCollectionEntryFilePathForUser(userID, c.Param("name"), entry) if err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) @@ -226,7 +262,8 @@ func GetCollectionEntryRawFileEndpoint(app *application.Application) echo.Handle func ListCollectionSourcesEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - sources, err := svc.ListCollectionSources(c.Param("name")) + userID := getUserID(c) + sources, err := svc.ListCollectionSourcesForUser(userID, c.Param("name")) if err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) diff --git a/core/http/endpoints/localai/agent_jobs.go b/core/http/endpoints/localai/agent_jobs.go index c46a0208a10f..8ed20d7df446 100644 --- a/core/http/endpoints/localai/agent_jobs.go +++ b/core/http/endpoints/localai/agent_jobs.go @@ -8,19 +8,27 @@ import ( "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/application" "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/services" ) -// CreateTaskEndpoint creates a new agent task -// @Summary Create a new agent task -// @Description Create a new reusable agent task with prompt template and configuration -// @Tags agent-jobs -// @Accept json -// @Produce json -// @Param task body schema.Task true "Task definition" -// @Success 201 {object} map[string]string "Task created" -// @Failure 400 {object} map[string]string "Invalid request" -// @Failure 500 {object} map[string]string "Internal server error" -// @Router /api/agent/tasks [post] +// getJobService returns the job service for the current user. +// Falls back to the global service when no user is authenticated. +func getJobService(app *application.Application, c echo.Context) *services.AgentJobService { + userID := getUserID(c) + if userID == "" { + return app.AgentJobService() + } + svc := app.AgentPoolService() + if svc == nil { + return app.AgentJobService() + } + jobSvc, err := svc.JobServiceForUser(userID) + if err != nil { + return app.AgentJobService() + } + return jobSvc +} + func CreateTaskEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { var task schema.Task @@ -28,7 +36,7 @@ func CreateTaskEndpoint(app *application.Application) echo.HandlerFunc { return c.JSON(http.StatusBadRequest, map[string]string{"error": "Invalid request body: " + err.Error()}) } - id, err := app.AgentJobService().CreateTask(task) + id, err := getJobService(app, c).CreateTask(task) if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } @@ -37,18 +45,6 @@ func CreateTaskEndpoint(app *application.Application) echo.HandlerFunc { } } -// UpdateTaskEndpoint updates an existing task -// @Summary Update an agent task -// @Description Update an existing agent task -// @Tags agent-jobs -// @Accept json -// @Produce json -// @Param id path string true "Task ID" -// @Param task body schema.Task true "Updated task definition" -// @Success 200 {object} map[string]string "Task updated" -// @Failure 400 {object} map[string]string "Invalid request" -// @Failure 404 {object} map[string]string "Task not found" -// @Router /api/agent/tasks/{id} [put] func UpdateTaskEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { id := c.Param("id") @@ -57,7 +53,7 @@ func UpdateTaskEndpoint(app *application.Application) echo.HandlerFunc { return c.JSON(http.StatusBadRequest, map[string]string{"error": "Invalid request body: " + err.Error()}) } - if err := app.AgentJobService().UpdateTask(id, task); err != nil { + if err := getJobService(app, c).UpdateTask(id, task); err != nil { if err.Error() == "task not found: "+id { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -68,19 +64,10 @@ func UpdateTaskEndpoint(app *application.Application) echo.HandlerFunc { } } -// DeleteTaskEndpoint deletes a task -// @Summary Delete an agent task -// @Description Delete an agent task by ID -// @Tags agent-jobs -// @Produce json -// @Param id path string true "Task ID" -// @Success 200 {object} map[string]string "Task deleted" -// @Failure 404 {object} map[string]string "Task not found" -// @Router /api/agent/tasks/{id} [delete] func DeleteTaskEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { id := c.Param("id") - if err := app.AgentJobService().DeleteTask(id); err != nil { + if err := getJobService(app, c).DeleteTask(id); err != nil { if err.Error() == "task not found: "+id { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -91,33 +78,52 @@ func DeleteTaskEndpoint(app *application.Application) echo.HandlerFunc { } } -// ListTasksEndpoint lists all tasks -// @Summary List all agent tasks -// @Description Get a list of all agent tasks -// @Tags agent-jobs -// @Produce json -// @Success 200 {array} schema.Task "List of tasks" -// @Router /api/agent/tasks [get] func ListTasksEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { - tasks := app.AgentJobService().ListTasks() + jobSvc := getJobService(app, c) + tasks := jobSvc.ListTasks() + + // Admin cross-user aggregation + if wantsAllUsers(c) { + svc := app.AgentPoolService() + if svc != nil { + usm := svc.UserServicesManager() + if usm != nil { + userID := getUserID(c) + userIDs, _ := usm.ListAllUserIDs() + userGroups := map[string]any{} + for _, uid := range userIDs { + if uid == userID { + continue + } + userJobSvc, err := svc.JobServiceForUser(uid) + if err != nil { + continue + } + userTasks := userJobSvc.ListTasks() + if len(userTasks) == 0 { + continue + } + userGroups[uid] = map[string]any{"tasks": userTasks} + } + if len(userGroups) > 0 { + return c.JSON(http.StatusOK, map[string]any{ + "tasks": tasks, + "user_groups": userGroups, + }) + } + } + } + } + return c.JSON(http.StatusOK, tasks) } } -// GetTaskEndpoint gets a task by ID -// @Summary Get an agent task -// @Description Get an agent task by ID -// @Tags agent-jobs -// @Produce json -// @Param id path string true "Task ID" -// @Success 200 {object} schema.Task "Task details" -// @Failure 404 {object} map[string]string "Task not found" -// @Router /api/agent/tasks/{id} [get] func GetTaskEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { id := c.Param("id") - task, err := app.AgentJobService().GetTask(id) + task, err := getJobService(app, c).GetTask(id) if err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -126,16 +132,6 @@ func GetTaskEndpoint(app *application.Application) echo.HandlerFunc { } } -// ExecuteJobEndpoint executes a job -// @Summary Execute an agent job -// @Description Create and execute a new agent job -// @Tags agent-jobs -// @Accept json -// @Produce json -// @Param request body schema.JobExecutionRequest true "Job execution request" -// @Success 201 {object} schema.JobExecutionResponse "Job created" -// @Failure 400 {object} map[string]string "Invalid request" -// @Router /api/agent/jobs/execute [post] func ExecuteJobEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { var req schema.JobExecutionRequest @@ -147,7 +143,6 @@ func ExecuteJobEndpoint(app *application.Application) echo.HandlerFunc { req.Parameters = make(map[string]string) } - // Build multimedia struct from request var multimedia *schema.MultimediaAttachment if len(req.Images) > 0 || len(req.Videos) > 0 || len(req.Audios) > 0 || len(req.Files) > 0 { multimedia = &schema.MultimediaAttachment{ @@ -158,7 +153,7 @@ func ExecuteJobEndpoint(app *application.Application) echo.HandlerFunc { } } - jobID, err := app.AgentJobService().ExecuteJob(req.TaskID, req.Parameters, "api", multimedia) + jobID, err := getJobService(app, c).ExecuteJob(req.TaskID, req.Parameters, "api", multimedia) if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } @@ -172,19 +167,10 @@ func ExecuteJobEndpoint(app *application.Application) echo.HandlerFunc { } } -// GetJobEndpoint gets a job by ID -// @Summary Get an agent job -// @Description Get an agent job by ID -// @Tags agent-jobs -// @Produce json -// @Param id path string true "Job ID" -// @Success 200 {object} schema.Job "Job details" -// @Failure 404 {object} map[string]string "Job not found" -// @Router /api/agent/jobs/{id} [get] func GetJobEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { id := c.Param("id") - job, err := app.AgentJobService().GetJob(id) + job, err := getJobService(app, c).GetJob(id) if err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -193,16 +179,6 @@ func GetJobEndpoint(app *application.Application) echo.HandlerFunc { } } -// ListJobsEndpoint lists jobs with optional filtering -// @Summary List agent jobs -// @Description Get a list of agent jobs, optionally filtered by task_id and status -// @Tags agent-jobs -// @Produce json -// @Param task_id query string false "Filter by task ID" -// @Param status query string false "Filter by status (pending, running, completed, failed, cancelled)" -// @Param limit query int false "Limit number of results" -// @Success 200 {array} schema.Job "List of jobs" -// @Router /api/agent/jobs [get] func ListJobsEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { var taskID *string @@ -224,25 +200,50 @@ func ListJobsEndpoint(app *application.Application) echo.HandlerFunc { } } - jobs := app.AgentJobService().ListJobs(taskID, status, limit) + jobSvc := getJobService(app, c) + jobs := jobSvc.ListJobs(taskID, status, limit) + + // Admin cross-user aggregation + if wantsAllUsers(c) { + svc := app.AgentPoolService() + if svc != nil { + usm := svc.UserServicesManager() + if usm != nil { + userID := getUserID(c) + userIDs, _ := usm.ListAllUserIDs() + userGroups := map[string]any{} + for _, uid := range userIDs { + if uid == userID { + continue + } + userJobSvc, err := svc.JobServiceForUser(uid) + if err != nil { + continue + } + userJobs := userJobSvc.ListJobs(taskID, status, limit) + if len(userJobs) == 0 { + continue + } + userGroups[uid] = map[string]any{"jobs": userJobs} + } + if len(userGroups) > 0 { + return c.JSON(http.StatusOK, map[string]any{ + "jobs": jobs, + "user_groups": userGroups, + }) + } + } + } + } + return c.JSON(http.StatusOK, jobs) } } -// CancelJobEndpoint cancels a running job -// @Summary Cancel an agent job -// @Description Cancel a running or pending agent job -// @Tags agent-jobs -// @Produce json -// @Param id path string true "Job ID" -// @Success 200 {object} map[string]string "Job cancelled" -// @Failure 400 {object} map[string]string "Job cannot be cancelled" -// @Failure 404 {object} map[string]string "Job not found" -// @Router /api/agent/jobs/{id}/cancel [post] func CancelJobEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { id := c.Param("id") - if err := app.AgentJobService().CancelJob(id); err != nil { + if err := getJobService(app, c).CancelJob(id); err != nil { if err.Error() == "job not found: "+id { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -253,19 +254,10 @@ func CancelJobEndpoint(app *application.Application) echo.HandlerFunc { } } -// DeleteJobEndpoint deletes a job -// @Summary Delete an agent job -// @Description Delete an agent job by ID -// @Tags agent-jobs -// @Produce json -// @Param id path string true "Job ID" -// @Success 200 {object} map[string]string "Job deleted" -// @Failure 404 {object} map[string]string "Job not found" -// @Router /api/agent/jobs/{id} [delete] func DeleteJobEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { id := c.Param("id") - if err := app.AgentJobService().DeleteJob(id); err != nil { + if err := getJobService(app, c).DeleteJob(id); err != nil { if err.Error() == "job not found: "+id { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -276,52 +268,33 @@ func DeleteJobEndpoint(app *application.Application) echo.HandlerFunc { } } -// ExecuteTaskByNameEndpoint executes a task by name -// @Summary Execute a task by name -// @Description Execute an agent task by its name (convenience endpoint). Parameters can be provided in the request body as a JSON object with string values. -// @Tags agent-jobs -// @Accept json -// @Produce json -// @Param name path string true "Task name" -// @Param request body map[string]string false "Template parameters (JSON object with string values)" -// @Success 201 {object} schema.JobExecutionResponse "Job created" -// @Failure 400 {object} map[string]string "Invalid request" -// @Failure 404 {object} map[string]string "Task not found" -// @Router /api/agent/tasks/{name}/execute [post] func ExecuteTaskByNameEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { name := c.Param("name") var params map[string]string - // Try to bind parameters from request body - // If body is empty or invalid, use empty params if c.Request().ContentLength > 0 { if err := c.Bind(¶ms); err != nil { - // If binding fails, try to read as raw JSON body := make(map[string]interface{}) if err := c.Bind(&body); err == nil { - // Convert interface{} values to strings params = make(map[string]string) for k, v := range body { if str, ok := v.(string); ok { params[k] = str } else { - // Convert non-string values to string params[k] = fmt.Sprintf("%v", v) } } } else { - // If all binding fails, use empty params params = make(map[string]string) } } } else { - // No body provided, use empty params params = make(map[string]string) } - // Find task by name - tasks := app.AgentJobService().ListTasks() + jobSvc := getJobService(app, c) + tasks := jobSvc.ListTasks() var task *schema.Task for _, t := range tasks { if t.Name == name { @@ -334,7 +307,7 @@ func ExecuteTaskByNameEndpoint(app *application.Application) echo.HandlerFunc { return c.JSON(http.StatusNotFound, map[string]string{"error": "Task not found: " + name}) } - jobID, err := app.AgentJobService().ExecuteJob(task.ID, params, "api", nil) + jobID, err := jobSvc.ExecuteJob(task.ID, params, "api", nil) if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } diff --git a/core/http/endpoints/localai/agent_skills.go b/core/http/endpoints/localai/agent_skills.go index 0a9d998c4ac0..8071cad42d50 100644 --- a/core/http/endpoints/localai/agent_skills.go +++ b/core/http/endpoints/localai/agent_skills.go @@ -44,10 +44,38 @@ func skillsToResponses(skills []skilldomain.Skill) []skillResponse { func ListSkillsEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - skills, err := svc.ListSkills() + userID := getUserID(c) + skills, err := svc.ListSkillsForUser(userID) if err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } + + // Admin cross-user aggregation + if wantsAllUsers(c) { + usm := svc.UserServicesManager() + if usm != nil { + userIDs, _ := usm.ListAllUserIDs() + userGroups := map[string]any{} + for _, uid := range userIDs { + if uid == userID { + continue + } + userSkills, err := svc.ListSkillsForUser(uid) + if err != nil || len(userSkills) == 0 { + continue + } + userGroups[uid] = map[string]any{"skills": skillsToResponses(userSkills)} + } + resp := map[string]any{ + "skills": skillsToResponses(skills), + } + if len(userGroups) > 0 { + resp["user_groups"] = userGroups + } + return c.JSON(http.StatusOK, resp) + } + } + return c.JSON(http.StatusOK, skillsToResponses(skills)) } } @@ -55,7 +83,8 @@ func ListSkillsEndpoint(app *application.Application) echo.HandlerFunc { func GetSkillsConfigEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - cfg := svc.GetSkillsConfig() + userID := getUserID(c) + cfg := svc.GetSkillsConfigForUser(userID) return c.JSON(http.StatusOK, cfg) } } @@ -63,8 +92,9 @@ func GetSkillsConfigEndpoint(app *application.Application) echo.HandlerFunc { func SearchSkillsEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) query := c.QueryParam("q") - skills, err := svc.SearchSkills(query) + skills, err := svc.SearchSkillsForUser(userID, query) if err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } @@ -75,6 +105,7 @@ func SearchSkillsEndpoint(app *application.Application) echo.HandlerFunc { func CreateSkillEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) var payload struct { Name string `json:"name"` Description string `json:"description"` @@ -87,7 +118,7 @@ func CreateSkillEndpoint(app *application.Application) echo.HandlerFunc { if err := c.Bind(&payload); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - skill, err := svc.CreateSkill(payload.Name, payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata) + skill, err := svc.CreateSkillForUser(userID, payload.Name, payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata) if err != nil { if strings.Contains(err.Error(), "already exists") { return c.JSON(http.StatusConflict, map[string]string{"error": err.Error()}) @@ -101,7 +132,8 @@ func CreateSkillEndpoint(app *application.Application) echo.HandlerFunc { func GetSkillEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - skill, err := svc.GetSkill(c.Param("name")) + userID := getUserID(c) + skill, err := svc.GetSkillForUser(userID, c.Param("name")) if err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -112,6 +144,7 @@ func GetSkillEndpoint(app *application.Application) echo.HandlerFunc { func UpdateSkillEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) var payload struct { Description string `json:"description"` Content string `json:"content"` @@ -123,7 +156,7 @@ func UpdateSkillEndpoint(app *application.Application) echo.HandlerFunc { if err := c.Bind(&payload); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - skill, err := svc.UpdateSkill(c.Param("name"), payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata) + skill, err := svc.UpdateSkillForUser(userID, c.Param("name"), payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata) if err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) @@ -137,7 +170,8 @@ func UpdateSkillEndpoint(app *application.Application) echo.HandlerFunc { func DeleteSkillEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - if err := svc.DeleteSkill(c.Param("name")); err != nil { + userID := getUserID(c) + if err := svc.DeleteSkillForUser(userID, c.Param("name")); err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) @@ -147,9 +181,9 @@ func DeleteSkillEndpoint(app *application.Application) echo.HandlerFunc { func ExportSkillEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - // The wildcard param captures the path after /export/ + userID := getUserID(c) name := c.Param("*") - data, err := svc.ExportSkill(name) + data, err := svc.ExportSkillForUser(userID, name) if err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -162,6 +196,7 @@ func ExportSkillEndpoint(app *application.Application) echo.HandlerFunc { func ImportSkillEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) file, err := c.FormFile("file") if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": "file required"}) @@ -175,7 +210,7 @@ func ImportSkillEndpoint(app *application.Application) echo.HandlerFunc { if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - skill, err := svc.ImportSkill(data) + skill, err := svc.ImportSkillForUser(userID, data) if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } @@ -188,7 +223,8 @@ func ImportSkillEndpoint(app *application.Application) echo.HandlerFunc { func ListSkillResourcesEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - resources, skill, err := svc.ListSkillResources(c.Param("name")) + userID := getUserID(c) + resources, skill, err := svc.ListSkillResourcesForUser(userID, c.Param("name")) if err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -225,7 +261,8 @@ func ListSkillResourcesEndpoint(app *application.Application) echo.HandlerFunc { func GetSkillResourceEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - content, info, err := svc.GetSkillResource(c.Param("name"), c.Param("*")) + userID := getUserID(c) + content, info, err := svc.GetSkillResourceForUser(userID, c.Param("name"), c.Param("*")) if err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -245,6 +282,7 @@ func GetSkillResourceEndpoint(app *application.Application) echo.HandlerFunc { func CreateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) file, err := c.FormFile("file") if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": "file is required"}) @@ -262,7 +300,7 @@ func CreateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc if err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } - if err := svc.CreateSkillResource(c.Param("name"), path, data); err != nil { + if err := svc.CreateSkillResourceForUser(userID, c.Param("name"), path, data); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusCreated, map[string]string{"path": path}) @@ -272,13 +310,14 @@ func CreateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc func UpdateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) var payload struct { Content string `json:"content"` } if err := c.Bind(&payload); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - if err := svc.UpdateSkillResource(c.Param("name"), c.Param("*"), payload.Content); err != nil { + if err := svc.UpdateSkillResourceForUser(userID, c.Param("name"), c.Param("*"), payload.Content); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) @@ -288,7 +327,8 @@ func UpdateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc func DeleteSkillResourceEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - if err := svc.DeleteSkillResource(c.Param("name"), c.Param("*")); err != nil { + userID := getUserID(c) + if err := svc.DeleteSkillResourceForUser(userID, c.Param("name"), c.Param("*")); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) @@ -300,7 +340,8 @@ func DeleteSkillResourceEndpoint(app *application.Application) echo.HandlerFunc func ListGitReposEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - repos, err := svc.ListGitRepos() + userID := getUserID(c) + repos, err := svc.ListGitReposForUser(userID) if err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } @@ -311,13 +352,14 @@ func ListGitReposEndpoint(app *application.Application) echo.HandlerFunc { func AddGitRepoEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) var payload struct { URL string `json:"url"` } if err := c.Bind(&payload); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - repo, err := svc.AddGitRepo(payload.URL) + repo, err := svc.AddGitRepoForUser(userID, payload.URL) if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } @@ -328,6 +370,7 @@ func AddGitRepoEndpoint(app *application.Application) echo.HandlerFunc { func UpdateGitRepoEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) var payload struct { URL string `json:"url"` Enabled *bool `json:"enabled"` @@ -335,7 +378,7 @@ func UpdateGitRepoEndpoint(app *application.Application) echo.HandlerFunc { if err := c.Bind(&payload); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - repo, err := svc.UpdateGitRepo(c.Param("id"), payload.URL, payload.Enabled) + repo, err := svc.UpdateGitRepoForUser(userID, c.Param("id"), payload.URL, payload.Enabled) if err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) @@ -349,7 +392,8 @@ func UpdateGitRepoEndpoint(app *application.Application) echo.HandlerFunc { func DeleteGitRepoEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - if err := svc.DeleteGitRepo(c.Param("id")); err != nil { + userID := getUserID(c) + if err := svc.DeleteGitRepoForUser(userID, c.Param("id")); err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -362,7 +406,8 @@ func DeleteGitRepoEndpoint(app *application.Application) echo.HandlerFunc { func SyncGitRepoEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - if err := svc.SyncGitRepo(c.Param("id")); err != nil { + userID := getUserID(c) + if err := svc.SyncGitRepoForUser(userID, c.Param("id")); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusAccepted, map[string]string{"status": "syncing"}) @@ -372,7 +417,8 @@ func SyncGitRepoEndpoint(app *application.Application) echo.HandlerFunc { func ToggleGitRepoEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - repo, err := svc.ToggleGitRepo(c.Param("id")) + userID := getUserID(c) + repo, err := svc.ToggleGitRepoForUser(userID, c.Param("id")) if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } diff --git a/core/http/endpoints/localai/agents.go b/core/http/endpoints/localai/agents.go index 5226f7edfc78..01c1618998ca 100644 --- a/core/http/endpoints/localai/agents.go +++ b/core/http/endpoints/localai/agents.go @@ -12,6 +12,7 @@ import ( "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/application" + "github.com/mudler/LocalAI/core/http/auth" "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/pkg/utils" "github.com/mudler/LocalAGI/core/state" @@ -19,10 +20,32 @@ import ( agiServices "github.com/mudler/LocalAGI/services" ) +// getUserID extracts the scoped user ID from the request context. +// Returns empty string when auth is not active (backward compat). +func getUserID(c echo.Context) string { + user := auth.GetUser(c) + if user == nil { + return "" + } + return user.ID +} + +// isAdminUser returns true if the authenticated user has admin role. +func isAdminUser(c echo.Context) bool { + user := auth.GetUser(c) + return user != nil && user.Role == auth.RoleAdmin +} + +// wantsAllUsers returns true if the request has ?all_users=true and the user is admin. +func wantsAllUsers(c echo.Context) bool { + return c.QueryParam("all_users") == "true" && isAdminUser(c) +} + func ListAgentsEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - statuses := svc.ListAgents() + userID := getUserID(c) + statuses := svc.ListAgentsForUser(userID) agents := make([]string, 0, len(statuses)) for name := range statuses { agents = append(agents, name) @@ -38,6 +61,22 @@ func ListAgentsEndpoint(app *application.Application) echo.HandlerFunc { if hubURL := svc.AgentHubURL(); hubURL != "" { resp["agent_hub_url"] = hubURL } + + // Admin cross-user aggregation + if wantsAllUsers(c) { + grouped := svc.ListAllAgentsGrouped() + userGroups := map[string]any{} + for uid, agentList := range grouped { + if uid == userID || uid == "" { + continue + } + userGroups[uid] = map[string]any{"agents": agentList} + } + if len(userGroups) > 0 { + resp["user_groups"] = userGroups + } + } + return c.JSON(http.StatusOK, resp) } } @@ -45,11 +84,12 @@ func ListAgentsEndpoint(app *application.Application) echo.HandlerFunc { func CreateAgentEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) var cfg state.AgentConfig if err := c.Bind(&cfg); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - if err := svc.CreateAgent(&cfg); err != nil { + if err := svc.CreateAgentForUser(userID, &cfg); err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusCreated, map[string]string{"status": "ok"}) @@ -59,8 +99,9 @@ func CreateAgentEndpoint(app *application.Application) echo.HandlerFunc { func GetAgentEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) name := c.Param("name") - ag := svc.GetAgent(name) + ag := svc.GetAgentForUser(userID, name) if ag == nil { return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"}) } @@ -73,12 +114,13 @@ func GetAgentEndpoint(app *application.Application) echo.HandlerFunc { func UpdateAgentEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) name := c.Param("name") var cfg state.AgentConfig if err := c.Bind(&cfg); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - if err := svc.UpdateAgent(name, &cfg); err != nil { + if err := svc.UpdateAgentForUser(userID, name, &cfg); err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -91,8 +133,9 @@ func UpdateAgentEndpoint(app *application.Application) echo.HandlerFunc { func DeleteAgentEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) name := c.Param("name") - if err := svc.DeleteAgent(name); err != nil { + if err := svc.DeleteAgentForUser(userID, name); err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) @@ -102,8 +145,9 @@ func DeleteAgentEndpoint(app *application.Application) echo.HandlerFunc { func GetAgentConfigEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) name := c.Param("name") - cfg := svc.GetAgentConfig(name) + cfg := svc.GetAgentConfigForUser(userID, name) if cfg == nil { return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"}) } @@ -114,7 +158,8 @@ func GetAgentConfigEndpoint(app *application.Application) echo.HandlerFunc { func PauseAgentEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - if err := svc.PauseAgent(c.Param("name")); err != nil { + userID := getUserID(c) + if err := svc.PauseAgentForUser(userID, c.Param("name")); err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) @@ -124,7 +169,8 @@ func PauseAgentEndpoint(app *application.Application) echo.HandlerFunc { func ResumeAgentEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - if err := svc.ResumeAgent(c.Param("name")); err != nil { + userID := getUserID(c) + if err := svc.ResumeAgentForUser(userID, c.Param("name")); err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) @@ -134,8 +180,9 @@ func ResumeAgentEndpoint(app *application.Application) echo.HandlerFunc { func GetAgentStatusEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) name := c.Param("name") - history := svc.GetAgentStatus(name) + history := svc.GetAgentStatusForUser(userID, name) if history == nil { history = &state.Status{ActionResults: []coreTypes.ActionState{}} } @@ -162,8 +209,9 @@ func GetAgentStatusEndpoint(app *application.Application) echo.HandlerFunc { func GetAgentObservablesEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) name := c.Param("name") - history, err := svc.GetAgentObservables(name) + history, err := svc.GetAgentObservablesForUser(userID, name) if err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -177,8 +225,9 @@ func GetAgentObservablesEndpoint(app *application.Application) echo.HandlerFunc func ClearAgentObservablesEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) name := c.Param("name") - if err := svc.ClearAgentObservables(name); err != nil { + if err := svc.ClearAgentObservablesForUser(userID, name); err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]any{"Name": name, "cleared": true}) @@ -188,6 +237,7 @@ func ClearAgentObservablesEndpoint(app *application.Application) echo.HandlerFun func ChatWithAgentEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) name := c.Param("name") var payload struct { Message string `json:"message"` @@ -199,7 +249,7 @@ func ChatWithAgentEndpoint(app *application.Application) echo.HandlerFunc { if message == "" { return c.JSON(http.StatusBadRequest, map[string]string{"error": "Message cannot be empty"}) } - messageID, err := svc.Chat(name, message) + messageID, err := svc.ChatForUser(userID, name, message) if err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) @@ -216,8 +266,9 @@ func ChatWithAgentEndpoint(app *application.Application) echo.HandlerFunc { func AgentSSEEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) name := c.Param("name") - manager := svc.GetSSEManager(name) + manager := svc.GetSSEManagerForUser(userID, name) if manager == nil { return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"}) } @@ -243,8 +294,9 @@ func GetAgentConfigMetaEndpoint(app *application.Application) echo.HandlerFunc { func ExportAgentEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) name := c.Param("name") - data, err := svc.ExportAgent(name) + data, err := svc.ExportAgentForUser(userID, name) if err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -256,6 +308,7 @@ func ExportAgentEndpoint(app *application.Application) echo.HandlerFunc { func ImportAgentEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) // Try multipart form file first file, err := c.FormFile("file") @@ -269,7 +322,7 @@ func ImportAgentEndpoint(app *application.Application) echo.HandlerFunc { if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": "failed to read file"}) } - if err := svc.ImportAgent(data); err != nil { + if err := svc.ImportAgentForUser(userID, data); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusCreated, map[string]string{"status": "ok"}) @@ -284,7 +337,7 @@ func ImportAgentEndpoint(app *application.Application) echo.HandlerFunc { if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - if err := svc.ImportAgent(data); err != nil { + if err := svc.ImportAgentForUser(userID, data); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusCreated, map[string]string{"status": "ok"}) diff --git a/core/http/middleware/trace.go b/core/http/middleware/trace.go index 800b824c8789..22049083d266 100644 --- a/core/http/middleware/trace.go +++ b/core/http/middleware/trace.go @@ -2,15 +2,16 @@ package middleware import ( "bytes" - "github.com/emirpasic/gods/v2/queues/circularbuffer" "io" "net/http" "sort" "sync" "time" + "github.com/emirpasic/gods/v2/queues/circularbuffer" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/application" + "github.com/mudler/LocalAI/core/http/auth" "github.com/mudler/xlog" ) @@ -33,6 +34,8 @@ type APIExchange struct { Request APIExchangeRequest `json:"request"` Response APIExchangeResponse `json:"response"` Error string `json:"error,omitempty"` + UserID string `json:"user_id,omitempty"` + UserName string `json:"user_name,omitempty"` } var traceBuffer *circularbuffer.Queue[APIExchange] @@ -147,6 +150,11 @@ func TraceMiddleware(app *application.Application) echo.MiddlewareFunc { exchange.Error = handlerErr.Error() } + if user := auth.GetUser(c); user != nil { + exchange.UserID = user.ID + exchange.UserName = user.Name + } + select { case logChan <- exchange: default: diff --git a/core/http/middleware/usage.go b/core/http/middleware/usage.go new file mode 100644 index 000000000000..0af96a2a73d4 --- /dev/null +++ b/core/http/middleware/usage.go @@ -0,0 +1,148 @@ +package middleware + +import ( + "bytes" + "encoding/json" + "time" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/http/auth" + "github.com/mudler/xlog" + "gorm.io/gorm" +) + +var usageChan chan *auth.UsageRecord + +// InitUsageRecorder starts a background goroutine that writes usage records. +func InitUsageRecorder(db *gorm.DB) { + if db == nil { + return + } + usageChan = make(chan *auth.UsageRecord, 500) + go func() { + for record := range usageChan { + if err := auth.RecordUsage(db, record); err != nil { + xlog.Error("Failed to record usage", "error", err) + } + } + }() +} + +// usageResponseBody is the minimal structure we need from the response JSON. +type usageResponseBody struct { + Model string `json:"model"` + Usage *struct { + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + TotalTokens int64 `json:"total_tokens"` + } `json:"usage"` +} + +// UsageMiddleware extracts token usage from OpenAI-compatible response JSON +// and records it per-user. +func UsageMiddleware(db *gorm.DB) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if db == nil || usageChan == nil { + return next(c) + } + + startTime := time.Now() + + // Wrap response writer to capture body + resBody := new(bytes.Buffer) + origWriter := c.Response().Writer + mw := &bodyWriter{ + ResponseWriter: origWriter, + body: resBody, + } + c.Response().Writer = mw + + handlerErr := next(c) + + // Restore original writer + c.Response().Writer = origWriter + + // Only record on successful responses + if c.Response().Status < 200 || c.Response().Status >= 300 { + return handlerErr + } + + // Get authenticated user + user := auth.GetUser(c) + if user == nil { + return handlerErr + } + + // Try to parse usage from response + responseBytes := resBody.Bytes() + if len(responseBytes) == 0 { + return handlerErr + } + + // Check content type + ct := c.Response().Header().Get("Content-Type") + isJSON := ct == "" || ct == "application/json" || bytes.HasPrefix([]byte(ct), []byte("application/json")) + isSSE := bytes.HasPrefix([]byte(ct), []byte("text/event-stream")) + + if !isJSON && !isSSE { + return handlerErr + } + + var resp usageResponseBody + if isSSE { + last, ok := lastSSEData(responseBytes) + if !ok { + return handlerErr + } + if err := json.Unmarshal(last, &resp); err != nil { + return handlerErr + } + } else { + if err := json.Unmarshal(responseBytes, &resp); err != nil { + return handlerErr + } + } + + if resp.Usage == nil { + return handlerErr + } + + record := &auth.UsageRecord{ + UserID: user.ID, + UserName: user.Name, + Model: resp.Model, + Endpoint: c.Request().URL.Path, + PromptTokens: resp.Usage.PromptTokens, + CompletionTokens: resp.Usage.CompletionTokens, + TotalTokens: resp.Usage.TotalTokens, + Duration: time.Since(startTime).Milliseconds(), + CreatedAt: startTime, + } + + select { + case usageChan <- record: + default: + xlog.Warn("Usage channel full, dropping record") + } + + return handlerErr + } + } +} + +// lastSSEData returns the payload of the last "data: " line whose content is not "[DONE]". +func lastSSEData(b []byte) ([]byte, bool) { + prefix := []byte("data: ") + var last []byte + for _, line := range bytes.Split(b, []byte("\n")) { + line = bytes.TrimRight(line, "\r") + if bytes.HasPrefix(line, prefix) { + payload := line[len(prefix):] + if !bytes.Equal(payload, []byte("[DONE]")) { + last = payload + } + } + } + return last, last != nil +} diff --git a/core/http/react-ui/src/App.css b/core/http/react-ui/src/App.css index d0f44789baed..a8696a32bc0d 100644 --- a/core/http/react-ui/src/App.css +++ b/core/http/react-ui/src/App.css @@ -260,6 +260,92 @@ align-items: center; justify-content: space-between; gap: var(--spacing-xs); + flex-wrap: wrap; +} + +.sidebar-user { + display: flex; + align-items: center; + gap: var(--spacing-xs); + width: 100%; + padding: var(--spacing-xs) 0; + font-size: 0.75rem; + color: var(--color-text-secondary); + overflow: hidden; +} + +.sidebar-user-avatar { + width: 20px; + height: 20px; + border-radius: var(--radius-full); + flex-shrink: 0; +} + +.sidebar-user-avatar-icon { + font-size: 1.25rem; + color: var(--color-text-muted); + flex-shrink: 0; +} + +.sidebar-user-link { + display: flex; + align-items: center; + gap: var(--spacing-xs); + flex: 1; + min-width: 0; + background: none; + border: none; + padding: 2px var(--spacing-xs); + margin: -2px calc(-1 * var(--spacing-xs)); + border-radius: var(--radius-sm); + color: inherit; + font: inherit; + cursor: pointer; + transition: background var(--duration-fast), color var(--duration-fast); +} + +.sidebar-user-link:hover { + background: var(--color-bg-hover); + color: var(--color-text-primary); +} + +.sidebar-user-name { + flex: 1; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + text-align: left; +} + +.sidebar-logout-btn { + background: none; + border: none; + color: var(--color-text-muted); + cursor: pointer; + padding: 2px 4px; + border-radius: var(--radius-sm); + font-size: 0.75rem; + flex-shrink: 0; + transition: color var(--duration-fast); +} + +.sidebar-logout-btn:hover { + color: var(--color-error); +} + +.sidebar.collapsed .sidebar-user { + justify-content: center; +} + +.sidebar.collapsed .sidebar-user-link { + flex: 0; + margin: 0; + padding: 2px; +} + +.sidebar.collapsed .sidebar-user-name, +.sidebar.collapsed .sidebar-logout-btn { + display: none; } .sidebar-collapse-btn { @@ -446,6 +532,12 @@ transform: translateX(20px); } +.toast-exit { + opacity: 0; + transform: translateX(20px); + transition: opacity 150ms ease, transform 150ms ease; +} + .toast-success { background: rgba(20, 184, 166, 0.15); border: 1px solid rgba(20, 184, 166, 0.3); @@ -494,6 +586,14 @@ .spinner-md .spinner-ring { width: 24px; height: 24px; } .spinner-lg .spinner-ring { width: 40px; height: 40px; } +.spinner-logo { + animation: pulse 1.2s ease-in-out infinite; + object-fit: contain; +} +.spinner-sm .spinner-logo { width: 16px; height: 16px; } +.spinner-md .spinner-logo { width: 24px; height: 24px; } +.spinner-lg .spinner-logo { width: 40px; height: 40px; } + /* Model selector */ .model-selector { background: var(--color-bg-tertiary); @@ -623,6 +723,7 @@ max-width: 1200px; margin: 0 auto; width: 100%; + animation: fadeIn var(--duration-normal) var(--ease-default); } .page-header { @@ -651,6 +752,7 @@ .card:hover { border-color: var(--color-border-default); + box-shadow: var(--shadow-sm); } .card-grid { @@ -707,6 +809,10 @@ font-size: 0.8125rem; } +.btn:active:not(:disabled) { + transform: translateY(1px); +} + .btn:disabled { opacity: 0.5; cursor: not-allowed; @@ -727,6 +833,7 @@ } .input:focus { border-color: var(--color-border-strong); + box-shadow: 0 0 0 2px var(--color-primary-light); } .textarea { @@ -745,6 +852,7 @@ } .textarea:focus { border-color: var(--color-border-strong); + box-shadow: 0 0 0 2px var(--color-primary-light); } /* Code editor (syntax-highlighted textarea overlay) */ @@ -877,6 +985,7 @@ padding: var(--spacing-sm) var(--spacing-md); border-bottom: 1px solid var(--color-border-divider); color: var(--color-text-primary); + transition: background var(--duration-fast) var(--ease-default); } .table tr:last-child td { @@ -1039,10 +1148,131 @@ border-color: var(--color-primary-border); } +/* Login page */ +.login-page { + min-height: 100vh; + min-height: 100dvh; + background: var(--color-bg-primary); + display: flex; + align-items: center; + justify-content: center; + padding: var(--spacing-xl); +} + +.login-card { + width: 100%; + max-width: 400px; + padding: var(--spacing-xl); +} + +.login-header { + text-align: center; + margin-bottom: var(--spacing-xl); +} + +.login-logo { + width: 56px; + height: 56px; + margin-bottom: var(--spacing-md); +} + +.login-title { + font-size: 1.5rem; + font-weight: 700; + margin-bottom: var(--spacing-xs); + color: var(--color-text-primary); +} + +.login-subtitle { + color: var(--color-text-secondary); + font-size: 0.875rem; +} + +.login-alert { + padding: var(--spacing-sm) var(--spacing-md); + border-radius: var(--radius-md); + font-size: 0.8125rem; + margin-bottom: var(--spacing-md); +} + +.login-alert-error { + background: var(--color-error-light); + color: var(--color-error); + border: 1px solid var(--color-error-border); +} + +.login-alert-success { + background: var(--color-success-light); + color: var(--color-success); + border: 1px solid var(--color-success-border); +} + +.login-divider { + display: flex; + align-items: center; + gap: var(--spacing-md); + margin: var(--spacing-lg) 0; + color: var(--color-text-muted); + font-size: 0.8125rem; +} + +.login-divider::before, +.login-divider::after { + content: ''; + flex: 1; + height: 1px; + background: var(--color-border-subtle); +} + +.login-footer { + text-align: center; + margin-top: var(--spacing-md); + font-size: 0.8125rem; + color: var(--color-text-secondary); +} + +.login-link { + background: none; + border: none; + color: var(--color-primary); + cursor: pointer; + padding: 0; + font: inherit; +} + +.login-link:hover { + color: var(--color-primary-hover); +} + +.login-token-toggle { + margin-top: var(--spacing-lg); + text-align: center; +} + +.login-token-toggle > button { + background: none; + border: none; + color: var(--color-text-muted); + cursor: pointer; + font-size: 0.75rem; + padding: 0; + font: inherit; + font-size: 0.75rem; +} + +.login-token-toggle > button:hover { + color: var(--color-text-secondary); +} + +.login-token-form { + margin-top: var(--spacing-sm); +} + /* Empty state */ .empty-state { text-align: center; padding: var(--spacing-3xl, 4rem) var(--spacing-xl); + animation: fadeIn var(--duration-normal) var(--ease-default); } .empty-state-icon { @@ -2677,3 +2907,12 @@ background: var(--color-bg-secondary); border-top: 1px solid var(--color-border-subtle); } + +/* Reduced motion accessibility */ +@media (prefers-reduced-motion: reduce) { + *, *::before, *::after { + animation-duration: 0.01ms !important; + animation-iteration-count: 1 !important; + transition-duration: 0.01ms !important; + } +} diff --git a/core/http/react-ui/src/components/LoadingSpinner.jsx b/core/http/react-ui/src/components/LoadingSpinner.jsx index b1c1b46a2815..23f858abee7f 100644 --- a/core/http/react-ui/src/components/LoadingSpinner.jsx +++ b/core/http/react-ui/src/components/LoadingSpinner.jsx @@ -1,8 +1,22 @@ +import { useState } from 'react' +import { apiUrl } from '../utils/basePath' + export default function LoadingSpinner({ size = 'md', className = '' }) { const sizeClass = size === 'sm' ? 'spinner-sm' : size === 'lg' ? 'spinner-lg' : 'spinner-md' + const [imgFailed, setImgFailed] = useState(false) + return (
-
+ {imgFailed ? ( +
+ ) : ( + setImgFailed(true)} + /> + )}
) } diff --git a/core/http/react-ui/src/components/RequireAdmin.jsx b/core/http/react-ui/src/components/RequireAdmin.jsx new file mode 100644 index 000000000000..48169f57476d --- /dev/null +++ b/core/http/react-ui/src/components/RequireAdmin.jsx @@ -0,0 +1,10 @@ +import { Navigate } from 'react-router-dom' +import { useAuth } from '../context/AuthContext' + +export default function RequireAdmin({ children }) { + const { isAdmin, authEnabled, user, loading } = useAuth() + if (loading) return null + if (authEnabled && !user) return + if (!isAdmin) return + return children +} diff --git a/core/http/react-ui/src/components/RequireAuth.jsx b/core/http/react-ui/src/components/RequireAuth.jsx new file mode 100644 index 000000000000..57268961ab9f --- /dev/null +++ b/core/http/react-ui/src/components/RequireAuth.jsx @@ -0,0 +1,9 @@ +import { Navigate } from 'react-router-dom' +import { useAuth } from '../context/AuthContext' + +export default function RequireAuth({ children }) { + const { authEnabled, user, loading } = useAuth() + if (loading) return null + if (authEnabled && !user) return + return children +} diff --git a/core/http/react-ui/src/components/RequireFeature.jsx b/core/http/react-ui/src/components/RequireFeature.jsx new file mode 100644 index 000000000000..97823cb0f461 --- /dev/null +++ b/core/http/react-ui/src/components/RequireFeature.jsx @@ -0,0 +1,10 @@ +import { Navigate } from 'react-router-dom' +import { useAuth } from '../context/AuthContext' + +export default function RequireFeature({ feature, children }) { + const { isAdmin, hasFeature, authEnabled, user, loading } = useAuth() + if (loading) return null + if (authEnabled && !user) return + if (!isAdmin && !hasFeature(feature)) return + return children +} diff --git a/core/http/react-ui/src/components/Sidebar.jsx b/core/http/react-ui/src/components/Sidebar.jsx index c526481284d4..4d3f4c234bf9 100644 --- a/core/http/react-ui/src/components/Sidebar.jsx +++ b/core/http/react-ui/src/components/Sidebar.jsx @@ -1,19 +1,21 @@ import { useState, useEffect } from 'react' -import { NavLink } from 'react-router-dom' +import { NavLink, useNavigate } from 'react-router-dom' import ThemeToggle from './ThemeToggle' +import { useAuth } from '../context/AuthContext' import { apiUrl } from '../utils/basePath' const COLLAPSED_KEY = 'localai_sidebar_collapsed' const mainItems = [ { path: '/app', icon: 'fas fa-home', label: 'Home' }, - { path: '/app/models', icon: 'fas fa-download', label: 'Install Models' }, + { path: '/app/models', icon: 'fas fa-download', label: 'Install Models', adminOnly: true }, { path: '/app/chat', icon: 'fas fa-comments', label: 'Chat' }, { path: '/app/image', icon: 'fas fa-image', label: 'Images' }, { path: '/app/video', icon: 'fas fa-video', label: 'Video' }, { path: '/app/tts', icon: 'fas fa-music', label: 'TTS' }, { path: '/app/sound', icon: 'fas fa-volume-high', label: 'Sound' }, { path: '/app/talk', icon: 'fas fa-phone', label: 'Talk' }, + { path: '/app/usage', icon: 'fas fa-chart-bar', label: 'Usage', authOnly: true }, ] const agentItems = [ @@ -24,11 +26,12 @@ const agentItems = [ ] const systemItems = [ - { path: '/app/backends', icon: 'fas fa-server', label: 'Backends' }, - { path: '/app/traces', icon: 'fas fa-chart-line', label: 'Traces' }, - { path: '/app/p2p', icon: 'fas fa-circle-nodes', label: 'Swarm' }, - { path: '/app/manage', icon: 'fas fa-desktop', label: 'System' }, - { path: '/app/settings', icon: 'fas fa-cog', label: 'Settings' }, + { path: '/app/users', icon: 'fas fa-users', label: 'Users', adminOnly: true }, + { path: '/app/backends', icon: 'fas fa-server', label: 'Backends', adminOnly: true }, + { path: '/app/traces', icon: 'fas fa-chart-line', label: 'Traces', adminOnly: true }, + { path: '/app/p2p', icon: 'fas fa-circle-nodes', label: 'Swarm', adminOnly: true }, + { path: '/app/manage', icon: 'fas fa-desktop', label: 'System', adminOnly: true }, + { path: '/app/settings', icon: 'fas fa-cog', label: 'Settings', adminOnly: true }, ] function NavItem({ item, onClose, collapsed }) { @@ -53,6 +56,8 @@ export default function Sidebar({ isOpen, onClose }) { const [collapsed, setCollapsed] = useState(() => { try { return localStorage.getItem(COLLAPSED_KEY) === 'true' } catch (_) { return false } }) + const { isAdmin, authEnabled, user, logout, hasFeature } = useAuth() + const navigate = useNavigate() useEffect(() => { fetch(apiUrl('/api/features')).then(r => r.json()).then(setFeatures).catch(() => {}) @@ -67,6 +72,14 @@ export default function Sidebar({ isOpen, onClose }) { }) } + const visibleMainItems = mainItems.filter(item => { + if (item.adminOnly && !isAdmin) return false + if (item.authOnly && !authEnabled) return false + return true + }) + + const visibleSystemItems = systemItems.filter(item => !item.adminOnly || isAdmin) + return ( <> {isOpen &&
} @@ -89,24 +102,40 @@ export default function Sidebar({ isOpen, onClose }) {
+ ) +} + +function PasswordSection({ addToast }) { + const [currentPw, setCurrentPw] = useState('') + const [newPw, setNewPw] = useState('') + const [confirmPw, setConfirmPw] = useState('') + const [saving, setSaving] = useState(false) + + const handleSubmit = async (e) => { + e.preventDefault() + if (newPw !== confirmPw) { + addToast('Passwords do not match', 'error') + return + } + if (newPw.length < 8) { + addToast('New password must be at least 8 characters', 'error') + return + } + setSaving(true) + try { + await profileApi.changePassword(currentPw, newPw) + addToast('Password changed', 'success') + setCurrentPw('') + setNewPw('') + setConfirmPw('') + } catch (err) { + addToast(err.message, 'error') + } finally { + setSaving(false) + } + } + + return ( +
+
+ + setCurrentPw(e.target.value)} + placeholder="Enter current password" + disabled={saving} + required + /> +
+
+
+ + setNewPw(e.target.value)} + placeholder="At least 8 characters" + minLength={8} + disabled={saving} + required + /> +
+
+ + setConfirmPw(e.target.value)} + placeholder="Repeat new password" + disabled={saving} + required + /> +
+
+
+ +
+
+ ) +} + +function ApiKeysSection({ addToast }) { + const [keys, setKeys] = useState([]) + const [loading, setLoading] = useState(true) + const [creating, setCreating] = useState(false) + const [newKeyName, setNewKeyName] = useState('') + const [newKeyPlaintext, setNewKeyPlaintext] = useState(null) + const [revokingId, setRevokingId] = useState(null) + + const fetchKeys = useCallback(async () => { + setLoading(true) + try { + const data = await apiKeysApi.list() + setKeys(data.keys || []) + } catch (err) { + addToast(`Failed to load API keys: ${err.message}`, 'error') + } finally { + setLoading(false) + } + }, [addToast]) + + useEffect(() => { fetchKeys() }, [fetchKeys]) + + const handleCreate = async (e) => { + e.preventDefault() + if (!newKeyName.trim()) return + setCreating(true) + try { + const data = await apiKeysApi.create(newKeyName.trim()) + setNewKeyPlaintext(data.key) + setNewKeyName('') + await fetchKeys() + addToast('API key created', 'success') + } catch (err) { + addToast(`Failed to create API key: ${err.message}`, 'error') + } finally { + setCreating(false) + } + } + + const handleRevoke = async (id, name) => { + if (!window.confirm(`Revoke API key "${name}"? This cannot be undone.`)) return + setRevokingId(id) + try { + await apiKeysApi.revoke(id) + setKeys(prev => prev.filter(k => k.id !== id)) + addToast('API key revoked', 'success') + } catch (err) { + addToast(`Failed to revoke API key: ${err.message}`, 'error') + } finally { + setRevokingId(null) + } + } + + const copyToClipboard = (text) => { + if (navigator.clipboard?.writeText) { + navigator.clipboard.writeText(text).then( + () => addToast('Copied to clipboard', 'success'), + () => fallbackCopy(text), + ) + } else { + fallbackCopy(text) + } + } + + const fallbackCopy = (text) => { + const ta = document.createElement('textarea') + ta.value = text + ta.style.position = 'fixed' + ta.style.opacity = '0' + document.body.appendChild(ta) + ta.select() + try { + document.execCommand('copy') + addToast('Copied to clipboard', 'success') + } catch (_) { + addToast('Failed to copy', 'error') + } + document.body.removeChild(ta) + } + + return ( +
+ {/* Create key form */} +
+
+ + setNewKeyName(e.target.value)} + disabled={creating} + maxLength={64} + /> +
+ + + + {/* Newly created key banner */} + {newKeyPlaintext && ( +
+
+ + Copy now — this key won't be shown again +
+
+ + {newKeyPlaintext} + + + +
+
+ )} + + {/* Keys list */} + {loading ? ( +
+ +
+ ) : keys.length === 0 ? ( +
+ No API keys yet. Create one above to get programmatic access. +
+ ) : ( +
+ {keys.map(k => ( +
+ +
+
{k.name}
+
+ {k.keyPrefix}... · {formatDate(k.createdAt)} + {k.lastUsed && <> · last used {formatDate(k.lastUsed)}} +
+
+ +
+ ))} +
+ )} +
+ ) +} + +export default function Account() { + const { addToast } = useOutletContext() + const { authEnabled, user } = useAuth() + + if (!authEnabled) { + return ( +
+
+
+

Account unavailable

+

Authentication must be enabled to manage your account.

+
+
+ ) + } + + const isLocal = user?.provider === 'local' + + const sectionHeader = (icon, title) => ( +
+ + {title} +
+ ) + + return ( +
+
+

Account

+

Profile, credentials, and API keys

+
+ +
+
+ {sectionHeader('fas fa-user', 'Profile')} + +
+ + {isLocal && ( +
+ {sectionHeader('fas fa-lock', 'Password')} + +
+ )} + +
+ {sectionHeader('fas fa-key', 'API keys')} + +
+
+
+ ) +} diff --git a/core/http/react-ui/src/pages/AgentJobs.jsx b/core/http/react-ui/src/pages/AgentJobs.jsx index de85aa7c3357..8804f7a0dcfd 100644 --- a/core/http/react-ui/src/pages/AgentJobs.jsx +++ b/core/http/react-ui/src/pages/AgentJobs.jsx @@ -2,20 +2,27 @@ import { useState, useEffect, useCallback, useRef } from 'react' import { useNavigate, useOutletContext } from 'react-router-dom' import { agentJobsApi, modelsApi } from '../utils/api' import { useModels } from '../hooks/useModels' +import { useAuth } from '../context/AuthContext' +import { useUserMap } from '../hooks/useUserMap' import LoadingSpinner from '../components/LoadingSpinner' import { fileToBase64 } from '../utils/api' import Modal from '../components/Modal' +import UserGroupSection from '../components/UserGroupSection' export default function AgentJobs() { const { addToast } = useOutletContext() const navigate = useNavigate() const { models } = useModels() + const { isAdmin, authEnabled, user } = useAuth() + const userMap = useUserMap() const [activeTab, setActiveTab] = useState('tasks') const [tasks, setTasks] = useState([]) const [jobs, setJobs] = useState([]) const [loading, setLoading] = useState(true) const [jobFilter, setJobFilter] = useState('all') const [hasMCPModels, setHasMCPModels] = useState(false) + const [taskUserGroups, setTaskUserGroups] = useState(null) + const [jobUserGroups, setJobUserGroups] = useState(null) // Execute modal state const [executeModal, setExecuteModal] = useState(null) @@ -27,19 +34,45 @@ export default function AgentJobs() { const fileTypeRef = useRef('images') const fetchData = useCallback(async () => { + const allUsers = isAdmin && authEnabled try { const [t, j] = await Promise.allSettled([ - agentJobsApi.listTasks(), - agentJobsApi.listJobs(), + agentJobsApi.listTasks(allUsers), + agentJobsApi.listJobs(allUsers), ]) - if (t.status === 'fulfilled') setTasks(Array.isArray(t.value) ? t.value : []) - if (j.status === 'fulfilled') setJobs(Array.isArray(j.value) ? j.value : []) + if (t.status === 'fulfilled') { + const tv = t.value + // Handle wrapped response (admin) or flat array + if (Array.isArray(tv)) { + setTasks(tv) + setTaskUserGroups(null) + } else if (tv && tv.tasks) { + setTasks(Array.isArray(tv.tasks) ? tv.tasks : []) + setTaskUserGroups(tv.user_groups || null) + } else { + setTasks(Array.isArray(tv) ? tv : []) + setTaskUserGroups(null) + } + } + if (j.status === 'fulfilled') { + const jv = j.value + if (Array.isArray(jv)) { + setJobs(jv) + setJobUserGroups(null) + } else if (jv && jv.jobs) { + setJobs(Array.isArray(jv.jobs) ? jv.jobs : []) + setJobUserGroups(jv.user_groups || null) + } else { + setJobs(Array.isArray(jv) ? jv : []) + setJobUserGroups(null) + } + } } catch (err) { addToast(`Failed to load: ${err.message}`, 'error') } finally { setLoading(false) } - }, [addToast]) + }, [addToast, isAdmin, authEnabled]) useEffect(() => { fetchData() @@ -256,7 +289,7 @@ export default function AgentJobs() { {loading ? (
) : activeTab === 'tasks' ? ( - tasks.length === 0 ? ( + tasks.length === 0 && !taskUserGroups ? (

No tasks defined

@@ -266,73 +299,82 @@ export default function AgentJobs() {
) : ( -
- - - - - - - - - - - - - {tasks.map(task => ( - - - - + + + + + + ))} + +
NameDescriptionModelCronStatusActions
- navigate(`/app/agent-jobs/tasks/${task.id || task.name}`)} style={{ cursor: 'pointer', color: 'var(--color-primary)', fontWeight: 500 }}> - {task.name || task.id} - - - - {task.description || '-'} - - - {task.model ? ( - navigate(`/app/model-editor/${encodeURIComponent(task.model)}`)} style={{ cursor: 'pointer', color: 'var(--color-primary)', fontSize: '0.8125rem' }}> - {task.model} + <> + {taskUserGroups &&

Your Tasks

} + {tasks.length === 0 ? ( +

You have no tasks yet.

+ ) : ( +
+ + + + + + + + + + + + + {tasks.map(task => ( + + - + - - - - ))} - -
NameDescriptionModelCronStatusActions
+ navigate(`/app/agent-jobs/tasks/${task.id || task.name}`)} style={{ cursor: 'pointer', color: 'var(--color-primary)', fontWeight: 500 }}> + {task.name || task.id} - ) : '-'} - - {task.cron ? ( - - {task.cron} + + + {task.description || '-'} - ) : '-'} - - {task.enabled === false ? ( - Disabled - ) : ( - Enabled - )} - -
- - - -
-
-
+
+ {task.model ? ( + navigate(`/app/model-editor/${encodeURIComponent(task.model)}`)} style={{ cursor: 'pointer', color: 'var(--color-primary)', fontSize: '0.8125rem' }}> + {task.model} + + ) : '-'} + + {task.cron ? ( + + {task.cron} + + ) : '-'} + + {task.enabled === false ? ( + Disabled + ) : ( + Enabled + )} + +
+ + + +
+
+
+ )} + + ) ) : ( <> + {jobUserGroups &&

Your Jobs

} {/* Job History Controls */}
@@ -404,9 +446,76 @@ export default function AgentJobs() {
)} + )} + {activeTab === 'tasks' && taskUserGroups && ( + ( +
+ + + + + + + + + + {(items || []).map(task => ( + + + + + + ))} + +
NameDescriptionModel
{task.name || task.id}{task.description || '-'}{task.model || '-'}
+
+ )} + /> + )} + + {activeTab === 'jobs' && jobUserGroups && ( + ( +
+ + + + + + + + + + + {(items || []).map(job => ( + + + + + + + ))} + +
Job IDTaskStatusCreated
{job.id?.slice(0, 12)}...{job.task_id || '-'}{statusBadge(job.status)}{formatDate(job.created_at)}
+
+ )} + /> + )} + {/* Execute Task Modal */} {executeModal && ( setExecuteModal(null)}> diff --git a/core/http/react-ui/src/pages/Agents.jsx b/core/http/react-ui/src/pages/Agents.jsx index 73b25c5a4625..6dc4b3b24c99 100644 --- a/core/http/react-ui/src/pages/Agents.jsx +++ b/core/http/react-ui/src/pages/Agents.jsx @@ -1,21 +1,28 @@ import { useState, useEffect, useCallback, useMemo } from 'react' import { useNavigate, useOutletContext } from 'react-router-dom' import { agentsApi } from '../utils/api' +import { useAuth } from '../context/AuthContext' +import { useUserMap } from '../hooks/useUserMap' +import UserGroupSection from '../components/UserGroupSection' export default function Agents() { const { addToast } = useOutletContext() const navigate = useNavigate() + const { isAdmin, authEnabled, user } = useAuth() + const userMap = useUserMap() const [agents, setAgents] = useState([]) const [loading, setLoading] = useState(true) const [agentHubURL, setAgentHubURL] = useState('') const [search, setSearch] = useState('') + const [userGroups, setUserGroups] = useState(null) const fetchAgents = useCallback(async () => { try { - const data = await agentsApi.list() + const data = await agentsApi.list(isAdmin && authEnabled) const names = Array.isArray(data.agents) ? data.agents : [] const statuses = data.statuses || {} if (data.agent_hub_url) setAgentHubURL(data.agent_hub_url) + setUserGroups(data.user_groups || null) // Fetch observable counts for each agent const agentsWithCounts = await Promise.all( @@ -40,7 +47,7 @@ export default function Agents() { } finally { setLoading(false) } - }, [addToast]) + }, [addToast, isAdmin, authEnabled]) useEffect(() => { fetchAgents() @@ -187,7 +194,7 @@ export default function Agents() {
- ) : agents.length === 0 ? ( + ) : agents.length === 0 && !userGroups ? (

No agents configured

@@ -214,6 +221,7 @@ export default function Agents() {
) : ( <> + {userGroups &&

Your Agents

}
@@ -314,8 +322,39 @@ export default function Agents() {
)} + )} + + {userGroups && ( + ( +
+ + + + + + + + + {(items || []).map(a => ( + + + + + ))} + +
NameStatus
{a.name}{statusBadge(a.active ? 'active' : 'paused')}
+
+ )} + /> + )}
) } diff --git a/core/http/react-ui/src/pages/Collections.jsx b/core/http/react-ui/src/pages/Collections.jsx index d41c1d2c6008..4e65d74abe1a 100644 --- a/core/http/react-ui/src/pages/Collections.jsx +++ b/core/http/react-ui/src/pages/Collections.jsx @@ -1,25 +1,32 @@ import { useState, useEffect, useCallback } from 'react' import { useNavigate, useOutletContext } from 'react-router-dom' import { agentCollectionsApi } from '../utils/api' +import { useAuth } from '../context/AuthContext' +import { useUserMap } from '../hooks/useUserMap' +import UserGroupSection from '../components/UserGroupSection' export default function Collections() { const { addToast } = useOutletContext() const navigate = useNavigate() + const { isAdmin, authEnabled, user } = useAuth() + const userMap = useUserMap() const [collections, setCollections] = useState([]) const [loading, setLoading] = useState(true) const [newName, setNewName] = useState('') const [creating, setCreating] = useState(false) + const [userGroups, setUserGroups] = useState(null) const fetchCollections = useCallback(async () => { try { - const data = await agentCollectionsApi.list() + const data = await agentCollectionsApi.list(isAdmin && authEnabled) setCollections(Array.isArray(data.collections) ? data.collections : []) + setUserGroups(data.user_groups || null) } catch (err) { addToast(`Failed to load collections: ${err.message}`, 'error') } finally { setLoading(false) } - }, [addToast]) + }, [addToast, isAdmin, authEnabled]) useEffect(() => { fetchCollections() @@ -115,13 +122,18 @@ export default function Collections() {
- ) : collections.length === 0 ? ( + ) : collections.length === 0 && !userGroups ? (

No collections yet

Create a collection above to start building your knowledge base.

) : ( + <> + {userGroups &&

Your Collections

} + {collections.length === 0 ? ( +

You have no collections yet.

+ ) : (
{collections.map((collection) => { const name = typeof collection === 'string' ? collection : collection.name @@ -146,6 +158,33 @@ export default function Collections() { ) })}
+ )} + + )} + + {userGroups && ( + ( +
+ {(items || []).map((col) => { + const name = typeof col === 'string' ? col : col.name + return ( +
+
+ + {name} +
+
+ ) + })} +
+ )} + /> )}
) diff --git a/core/http/react-ui/src/pages/Home.jsx b/core/http/react-ui/src/pages/Home.jsx index 7fc479d6fcb3..c031324e0028 100644 --- a/core/http/react-ui/src/pages/Home.jsx +++ b/core/http/react-ui/src/pages/Home.jsx @@ -1,6 +1,7 @@ import { useState, useEffect, useRef, useCallback } from 'react' import { useNavigate, useOutletContext } from 'react-router-dom' import { apiUrl } from '../utils/basePath' +import { useAuth } from '../context/AuthContext' import ModelSelector from '../components/ModelSelector' import UnifiedMCPDropdown from '../components/UnifiedMCPDropdown' import { useResources } from '../hooks/useResources' @@ -25,6 +26,7 @@ const placeholderMessages = [ export default function Home() { const navigate = useNavigate() const { addToast } = useOutletContext() + const { isAdmin } = useAuth() const { resources } = useResources() const [configuredModels, setConfiguredModels] = useState(null) const configuredModelsRef = useRef(configuredModels) @@ -317,15 +319,19 @@ export default function Home() { {/* Quick links */}
- - - + {isAdmin && ( + <> + + + + + )} Documentation @@ -371,8 +377,8 @@ export default function Home() {
)} - ) : ( - /* No models installed wizard */ + ) : isAdmin ? ( + /* No models installed wizard (admin) */

No Models Installed

@@ -443,6 +449,20 @@ export default function Home() {
+ ) : ( + /* No models available (non-admin) */ +
+
+ LocalAI +

No Models Available

+

There are no models installed yet. Ask your administrator to set up models so you can start chatting.

+
+ +
)} + setConfirmDialog(null)} + />
) } diff --git a/core/http/react-ui/src/pages/Manage.jsx b/core/http/react-ui/src/pages/Manage.jsx index ac28a58ba6c3..04fbf1974bc5 100644 --- a/core/http/react-ui/src/pages/Manage.jsx +++ b/core/http/react-ui/src/pages/Manage.jsx @@ -1,6 +1,7 @@ import { useState, useEffect, useCallback } from 'react' import { useNavigate, useOutletContext, useSearchParams } from 'react-router-dom' import ResourceMonitor from '../components/ResourceMonitor' +import ConfirmDialog from '../components/ConfirmDialog' import { useModels } from '../hooks/useModels' import { backendControlApi, modelsApi, backendsApi, systemApi } from '../utils/api' @@ -21,6 +22,7 @@ export default function Manage() { const [backendsLoading, setBackendsLoading] = useState(true) const [reloading, setReloading] = useState(false) const [reinstallingBackends, setReinstallingBackends] = useState(new Set()) + const [confirmDialog, setConfirmDialog] = useState(null) const handleTabChange = (tab) => { setActiveTab(tab) @@ -55,27 +57,43 @@ export default function Manage() { fetchBackends() }, [fetchLoadedModels, fetchBackends]) - const handleStopModel = async (modelName) => { - if (!confirm(`Stop model ${modelName}?`)) return - try { - await backendControlApi.shutdown({ model: modelName }) - addToast(`Stopped ${modelName}`, 'success') - setTimeout(fetchLoadedModels, 500) - } catch (err) { - addToast(`Failed to stop: ${err.message}`, 'error') - } + const handleStopModel = (modelName) => { + setConfirmDialog({ + title: 'Stop Model', + message: `Stop model ${modelName}?`, + confirmLabel: 'Stop', + danger: true, + onConfirm: async () => { + setConfirmDialog(null) + try { + await backendControlApi.shutdown({ model: modelName }) + addToast(`Stopped ${modelName}`, 'success') + setTimeout(fetchLoadedModels, 500) + } catch (err) { + addToast(`Failed to stop: ${err.message}`, 'error') + } + }, + }) } - const handleDeleteModel = async (modelName) => { - if (!confirm(`Delete model ${modelName}? This cannot be undone.`)) return - try { - await modelsApi.deleteByName(modelName) - addToast(`Deleted ${modelName}`, 'success') - refetchModels() - fetchLoadedModels() - } catch (err) { - addToast(`Failed to delete: ${err.message}`, 'error') - } + const handleDeleteModel = (modelName) => { + setConfirmDialog({ + title: 'Delete Model', + message: `Delete model ${modelName}? This cannot be undone.`, + confirmLabel: 'Delete', + danger: true, + onConfirm: async () => { + setConfirmDialog(null) + try { + await modelsApi.deleteByName(modelName) + addToast(`Deleted ${modelName}`, 'success') + refetchModels() + fetchLoadedModels() + } catch (err) { + addToast(`Failed to delete: ${err.message}`, 'error') + } + }, + }) } const handleReload = async () => { @@ -106,15 +124,23 @@ export default function Manage() { } } - const handleDeleteBackend = async (name) => { - if (!confirm(`Delete backend ${name}?`)) return - try { - await backendsApi.deleteInstalled(name) - addToast(`Deleted backend ${name}`, 'success') - fetchBackends() - } catch (err) { - addToast(`Failed to delete backend: ${err.message}`, 'error') - } + const handleDeleteBackend = (name) => { + setConfirmDialog({ + title: 'Delete Backend', + message: `Delete backend ${name}?`, + confirmLabel: 'Delete', + danger: true, + onConfirm: async () => { + setConfirmDialog(null) + try { + await backendsApi.deleteInstalled(name) + addToast(`Deleted backend ${name}`, 'success') + fetchBackends() + } catch (err) { + addToast(`Failed to delete backend: ${err.message}`, 'error') + } + }, + }) } return ( @@ -379,6 +405,16 @@ export default function Manage() { )}
)} + + setConfirmDialog(null)} + /> ) } diff --git a/core/http/react-ui/src/pages/Models.jsx b/core/http/react-ui/src/pages/Models.jsx index 3e17d6bf3622..029cd71ff124 100644 --- a/core/http/react-ui/src/pages/Models.jsx +++ b/core/http/react-ui/src/pages/Models.jsx @@ -4,40 +4,16 @@ import { modelsApi } from '../utils/api' import { useOperations } from '../hooks/useOperations' import { useResources } from '../hooks/useResources' import SearchableSelect from '../components/SearchableSelect' +import ConfirmDialog from '../components/ConfirmDialog' import React from 'react' const LOADING_PHRASES = [ - { text: 'Rounding up the neural networks...', icon: 'fa-brain' }, - { text: 'Asking the models to line up nicely...', icon: 'fa-people-line' }, - { text: 'Convincing transformers to transform...', icon: 'fa-wand-magic-sparkles' }, - { text: 'Herding digital llamas...', icon: 'fa-horse' }, - { text: 'Downloading more RAM... just kidding', icon: 'fa-memory' }, - { text: 'Counting parameters... lost count at a billion', icon: 'fa-calculator' }, - { text: 'Untangling attention heads...', icon: 'fa-diagram-project' }, - { text: 'Warming up the GPUs...', icon: 'fa-fire' }, - { text: 'Teaching AI to sit and stay...', icon: 'fa-graduation-cap' }, - { text: 'Polishing the weights and biases...', icon: 'fa-gem' }, - { text: 'Stacking layers like pancakes...', icon: 'fa-layer-group' }, - { text: 'Negotiating with the token budget...', icon: 'fa-coins' }, - { text: 'Fetching models from the cloud mines...', icon: 'fa-cloud-arrow-down' }, - { text: 'Calibrating the vibe check algorithm...', icon: 'fa-gauge-high' }, - { text: 'Optimizing inference with good intentions...', icon: 'fa-bolt' }, - { text: 'Measuring GPU with a ruler...', icon: 'fa-ruler' }, - { text: 'Will it fit? Asking the VRAM oracle...', icon: 'fa-microchip' }, - { text: 'Playing Tetris with model layers...', icon: 'fa-cubes' }, - { text: 'Checking if we need more RGB...', icon: 'fa-rainbow' }, - { text: 'Squeezing tensors into memory...', icon: 'fa-compress' }, - { text: 'Whispering sweet nothings to CUDA cores...', icon: 'fa-heart' }, - { text: 'Asking the electrons to scoot over...', icon: 'fa-atom' }, - { text: 'Defragmenting the flux capacitor...', icon: 'fa-clock-rotate-left' }, - { text: 'Consulting the tensor gods...', icon: 'fa-hands-praying' }, - { text: 'Checking under the GPU\'s hood...', icon: 'fa-car' }, - { text: 'Seeing if the hamsters can run faster...', icon: 'fa-fan' }, - { text: 'Running very important math... carry the 1...', icon: 'fa-square-root-variable' }, - { text: 'Poking the memory bus gently...', icon: 'fa-bus' }, - { text: 'Bribing the scheduler with clock cycles...', icon: 'fa-stopwatch' }, - { text: 'Asking models to share their VRAM nicely...', icon: 'fa-handshake' }, + { text: 'Loading models...', icon: 'fa-brain' }, + { text: 'Fetching gallery...', icon: 'fa-download' }, + { text: 'Checking availability...', icon: 'fa-circle-check' }, + { text: 'Almost ready...', icon: 'fa-hourglass-half' }, + { text: 'Preparing gallery...', icon: 'fa-store' }, ] function GalleryLoader() { @@ -142,6 +118,7 @@ export default function Models() { const [backendFilter, setBackendFilter] = useState('') const [allBackends, setAllBackends] = useState([]) const debounceRef = useRef(null) + const [confirmDialog, setConfirmDialog] = useState(null) // Total GPU memory for "fits" check const totalGpuMemory = resources?.aggregate?.total_memory || 0 @@ -216,15 +193,24 @@ export default function Models() { } } - const handleDelete = async (modelId) => { - if (!confirm(`Delete model ${modelId}?`)) return - try { - await modelsApi.delete(modelId) - addToast(`Deleting ${modelId}...`, 'info') - fetchModels() - } catch (err) { - addToast(`Failed to delete: ${err.message}`, 'error') - } + const handleDelete = (modelId) => { + setConfirmDialog({ + title: 'Delete Model', + message: `Delete model ${modelId}?`, + confirmLabel: `Delete ${modelId}`, + danger: true, + onConfirm: async () => { + setConfirmDialog(null) + try { + await modelsApi.delete(modelId) + addToast(`Deleting ${modelId}...`, 'info') + fetchModels() + } catch (err) { + addToast(`Failed to delete: ${err.message}`, 'error') + } + }, + }) + return } // Clear local installing flags when operations finish (success or error) @@ -332,7 +318,19 @@ export default function Models() {

No models found

-

Try adjusting your search or filters

+

+ {search || filter || backendFilter + ? 'No models match your current search or filters.' + : 'The model gallery is empty.'} +

+ {(search || filter || backendFilter) && ( + + )}
) : (
@@ -535,6 +533,15 @@ export default function Models() {
)} + setConfirmDialog(null)} + /> ) } diff --git a/core/http/react-ui/src/pages/Skills.jsx b/core/http/react-ui/src/pages/Skills.jsx index a3edd033bf81..6d30a9383529 100644 --- a/core/http/react-ui/src/pages/Skills.jsx +++ b/core/http/react-ui/src/pages/Skills.jsx @@ -4,6 +4,7 @@ import { skillsApi } from '../utils/api' import { useAuth } from '../context/AuthContext' import { useUserMap } from '../hooks/useUserMap' import UserGroupSection from '../components/UserGroupSection' +import ConfirmDialog from '../components/ConfirmDialog' export default function Skills() { const { addToast } = useOutletContext() @@ -21,6 +22,7 @@ export default function Skills() { const [gitReposLoading, setGitReposLoading] = useState(false) const [gitReposAction, setGitReposAction] = useState(null) const [userGroups, setUserGroups] = useState(null) + const [confirmDialog, setConfirmDialog] = useState(null) const fetchSkills = useCallback(async () => { setLoading(true) @@ -67,14 +69,22 @@ export default function Skills() { }, [fetchSkills]) const deleteSkill = async (name, userId) => { - if (!window.confirm(`Delete skill "${name}"? This action cannot be undone.`)) return - try { - await skillsApi.delete(name, userId) - addToast(`Skill "${name}" deleted`, 'success') - fetchSkills() - } catch (err) { - addToast(err.message || 'Failed to delete skill', 'error') - } + setConfirmDialog({ + title: 'Delete Skill', + message: `Delete skill "${name}"? This action cannot be undone.`, + confirmLabel: 'Delete', + danger: true, + onConfirm: async () => { + setConfirmDialog(null) + try { + await skillsApi.delete(name, userId) + addToast(`Skill "${name}" deleted`, 'success') + fetchSkills() + } catch (err) { + addToast(err.message || 'Failed to delete skill', 'error') + } + }, + }) } const exportSkill = async (name, userId) => { @@ -173,15 +183,23 @@ export default function Skills() { } const deleteGitRepo = async (id) => { - if (!window.confirm('Remove this Git repository? Skills from it will no longer be available.')) return - try { - await skillsApi.deleteGitRepo(id) - await loadGitRepos() - fetchSkills() - addToast('Repo removed', 'success') - } catch (err) { - addToast(err.message || 'Remove failed', 'error') - } + setConfirmDialog({ + title: 'Remove Git Repository', + message: 'Remove this Git repository? Skills from it will no longer be available.', + confirmLabel: 'Remove', + danger: true, + onConfirm: async () => { + setConfirmDialog(null) + try { + await skillsApi.deleteGitRepo(id) + await loadGitRepos() + fetchSkills() + addToast('Repo removed', 'success') + } catch (err) { + addToast(err.message || 'Remove failed', 'error') + } + }, + }) } if (unavailable) { @@ -469,6 +487,16 @@ export default function Skills() { )} + setConfirmDialog(null)} + /> + {userGroups && ( setAllModels(true)} style={allNoneBtnStyle}>All -
- {(availableModels || []).map(m => ( - - ))} +
+ {(availableModels || []).map(m => { + const checked = (allowedModels.models || []).includes(m) + return ( + + ) + })} {(!availableModels || availableModels.length === 0) && ( No models available )} @@ -358,6 +352,7 @@ function InvitesTab({ addToast }) { const [invites, setInvites] = useState([]) const [loading, setLoading] = useState(true) const [creating, setCreating] = useState(false) + const [confirmDialog, setConfirmDialog] = useState(null) const fetchInvites = useCallback(async () => { setLoading(true) @@ -389,14 +384,22 @@ function InvitesTab({ addToast }) { } const handleRevoke = async (invite) => { - if (!window.confirm('Revoke this invite link?')) return - try { - await adminInvitesApi.delete(invite.id) - setInvites(prev => prev.filter(x => x.id !== invite.id)) - addToast('Invite revoked', 'success') - } catch (err) { - addToast(`Failed to revoke invite: ${err.message}`, 'error') - } + setConfirmDialog({ + title: 'Revoke Invite', + message: 'Revoke this invite link?', + confirmLabel: 'Revoke', + danger: true, + onConfirm: async () => { + setConfirmDialog(null) + try { + await adminInvitesApi.delete(invite.id) + setInvites(prev => prev.filter(x => x.id !== invite.id)) + addToast('Invite revoked', 'success') + } catch (err) { + addToast(`Failed to revoke invite: ${err.message}`, 'error') + } + }, + }) } const handleCopyUrl = (code) => { @@ -515,6 +518,15 @@ function InvitesTab({ addToast }) {
)} + setConfirmDialog(null)} + /> ) } @@ -529,6 +541,7 @@ export default function Users() { const [editingUser, setEditingUser] = useState(null) const [featureMeta, setFeatureMeta] = useState(null) const [availableModels, setAvailableModels] = useState([]) + const [confirmDialog, setConfirmDialog] = useState(null) const fetchUsers = useCallback(async () => { setLoading(true) @@ -600,14 +613,22 @@ export default function Users() { } const handleDelete = async (u) => { - if (!window.confirm(`Delete user "${u.name || u.email}"? This will also remove their sessions and API keys.`)) return - try { - await adminUsersApi.delete(u.id) - setUsers(prev => prev.filter(x => x.id !== u.id)) - addToast(`User deleted`, 'success') - } catch (err) { - addToast(`Failed to delete user: ${err.message}`, 'error') - } + setConfirmDialog({ + title: 'Delete User', + message: `Delete user "${u.name || u.email}"? This will also remove their sessions and API keys.`, + confirmLabel: 'Delete', + danger: true, + onConfirm: async () => { + setConfirmDialog(null) + try { + await adminUsersApi.delete(u.id) + setUsers(prev => prev.filter(x => x.id !== u.id)) + addToast(`User deleted`, 'success') + } catch (err) { + addToast(`Failed to delete user: ${err.message}`, 'error') + } + }, + }) } const filtered = users.filter(u => { @@ -775,6 +796,15 @@ export default function Users() { addToast={addToast} /> )} + setConfirmDialog(null)} + />
) } From b7f28e442a4843ac085e5731f928a226b77f5a65 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Wed, 18 Mar 2026 23:08:29 +0000 Subject: [PATCH 08/13] chore(ui): style improvements Signed-off-by: Ettore Di Giacinto --- core/http/react-ui/src/App.css | 89 +++++++++++---- core/http/react-ui/src/App.jsx | 9 +- .../react-ui/src/components/ConfirmDialog.jsx | 57 ++++++++-- core/http/react-ui/src/components/Modal.jsx | 85 ++++++++++++--- .../react-ui/src/components/ResourceCards.jsx | 3 + .../src/components/SearchableModelSelect.jsx | 9 +- .../src/components/SearchableSelect.jsx | 11 +- core/http/react-ui/src/components/Toast.jsx | 4 +- .../src/components/UserGroupSection.jsx | 18 +++- .../react-ui/src/contexts/ThemeContext.jsx | 11 +- core/http/react-ui/src/pages/Chat.jsx | 17 ++- core/http/react-ui/src/theme.css | 102 ++++++++---------- 12 files changed, 300 insertions(+), 115 deletions(-) diff --git a/core/http/react-ui/src/App.css b/core/http/react-ui/src/App.css index d06af9a2bf63..16132da70a2c 100644 --- a/core/http/react-ui/src/App.css +++ b/core/http/react-ui/src/App.css @@ -142,6 +142,7 @@ box-shadow: var(--shadow-sidebar); transition: width var(--duration-normal) var(--ease-default), transform var(--duration-normal) var(--ease-default); + will-change: transform; } .sidebar-overlay { @@ -244,6 +245,7 @@ flex: 1; overflow: hidden; text-overflow: ellipsis; + transition: opacity 150ms ease; } .nav-external { @@ -539,23 +541,23 @@ } .toast-success { - background: rgba(20, 184, 166, 0.15); - border: 1px solid rgba(20, 184, 166, 0.3); + background: var(--color-success-light); + border: 1px solid var(--color-success-border); color: var(--color-success); } .toast-error { - background: rgba(239, 68, 68, 0.15); - border: 1px solid rgba(239, 68, 68, 0.3); + background: var(--color-error-light); + border: 1px solid var(--color-error-border); color: var(--color-error); } .toast-warning { - background: rgba(245, 158, 11, 0.15); - border: 1px solid rgba(245, 158, 11, 0.3); + background: var(--color-warning-light); + border: 1px solid var(--color-warning-border); color: var(--color-warning); } .toast-info { - background: rgba(56, 189, 248, 0.15); - border: 1px solid rgba(56, 189, 248, 0.3); + background: var(--color-info-light); + border: 1px solid var(--color-info-border); color: var(--color-info); } @@ -747,12 +749,13 @@ border: 1px solid var(--color-border-subtle); border-radius: var(--radius-lg); padding: var(--spacing-md); - transition: border-color var(--duration-fast), box-shadow var(--duration-fast); + transition: border-color var(--duration-fast), box-shadow var(--duration-fast), transform var(--duration-fast); } .card:hover { border-color: var(--color-border-default); box-shadow: var(--shadow-sm); + transform: translateY(-1px); } .card-grid { @@ -773,7 +776,7 @@ font-weight: 500; cursor: pointer; border: none; - transition: all var(--duration-fast) var(--ease-default); + transition: background var(--duration-fast) var(--ease-default), color var(--duration-fast) var(--ease-default), border-color var(--duration-fast) var(--ease-default), box-shadow var(--duration-fast) var(--ease-default); text-decoration: none; } @@ -1410,6 +1413,37 @@ 50% { opacity: 0.5; } } +@keyframes messageSlideIn { + from { opacity: 0; transform: translateY(8px); } + to { opacity: 1; transform: translateY(0); } +} + +@keyframes dropdownIn { + from { opacity: 0; transform: translateY(-4px); } + to { opacity: 1; transform: translateY(0); } +} + +@keyframes completionGlow { + 0% { box-shadow: 0 0 0 0 rgba(59, 130, 246, 0.2); } + 50% { box-shadow: 0 0 0 4px rgba(59, 130, 246, 0.1); } + 100% { box-shadow: 0 0 0 0 rgba(59, 130, 246, 0); } +} + +/* Page route transitions */ +.page-transition { + animation: fadeIn 200ms ease; + display: flex; + flex-direction: column; + flex: 1; + min-height: 0; + min-width: 0; +} + +/* Completion glow on streaming finish */ +.chat-message-new .chat-message-content { + animation: completionGlow 600ms ease-out; +} + /* Chat-specific styles */ .chat-layout { display: flex; @@ -1468,12 +1502,13 @@ cursor: pointer; font-size: 0.8125rem; color: var(--color-text-secondary); - transition: all var(--duration-fast); + transition: background var(--duration-fast), color var(--duration-fast), transform var(--duration-fast); margin-bottom: 2px; } .chat-list-item:hover { background: var(--color-primary-light); + transform: translateX(2px); } .chat-list-item.active { @@ -1570,7 +1605,7 @@ gap: var(--spacing-sm); max-width: 80%; min-width: 0; - animation: fadeIn 200ms ease; + animation: messageSlideIn 250ms ease-out; } .chat-message-user { @@ -1740,7 +1775,7 @@ } .chat-input-wrapper:focus-within { border-color: var(--color-primary-border); - box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.12), 0 0 12px rgba(99, 102, 241, 0.06); + box-shadow: 0 0 0 3px rgba(59, 130, 246, 0.12), 0 0 12px rgba(59, 130, 246, 0.06); } .chat-attach-btn { @@ -1798,6 +1833,9 @@ opacity: 0.3; cursor: not-allowed; } +.chat-send-btn:active:not(:disabled) { + transform: scale(0.92); +} .chat-stop-btn { padding: var(--spacing-xs); @@ -2128,13 +2166,13 @@ overflow-y: auto; } .chat-activity-thinking { - border-left-color: rgba(99, 102, 241, 0.3); + border-left-color: rgba(59, 130, 246, 0.3); } .chat-activity-tool-call { - border-left-color: rgba(139, 92, 246, 0.3); + border-left-color: rgba(245, 158, 11, 0.3); } .chat-activity-tool-result { - border-left-color: rgba(20, 184, 166, 0.3); + border-left-color: rgba(34, 197, 94, 0.3); } /* Context window progress bar */ @@ -2230,6 +2268,7 @@ border: 1px solid var(--color-border-subtle); border-radius: var(--radius-md); box-shadow: var(--shadow-lg); + animation: dropdownIn 120ms ease-out; } .chat-mcp-dropdown-loading, .chat-mcp-dropdown-empty { @@ -2305,15 +2344,15 @@ background: var(--color-text-tertiary); } .chat-client-mcp-status-connected { - background: #22c55e; + background: var(--color-success); box-shadow: 0 0 4px rgba(34, 197, 94, 0.5); } .chat-client-mcp-status-connecting { - background: #f59e0b; + background: var(--color-warning); animation: pulse 1s infinite; } .chat-client-mcp-status-error { - background: #ef4444; + background: var(--color-error); } .chat-client-mcp-status-disconnected { background: var(--color-text-tertiary); @@ -2394,6 +2433,7 @@ transform: translateX(100%); transition: transform 250ms var(--ease-default); box-shadow: var(--shadow-lg); + will-change: transform; } .chat-settings-drawer.open { transform: translateX(0); @@ -2488,7 +2528,7 @@ /* Max tokens/sec badge */ .chat-max-tps-badge { - background: rgba(99, 102, 241, 0.15); + background: rgba(59, 130, 246, 0.15); color: var(--color-primary); padding: 1px 6px; border-radius: var(--radius-full); @@ -2542,7 +2582,7 @@ align-items: center; gap: 4px; padding: 2px 6px; - background: rgba(99, 102, 241, 0.1); + background: rgba(59, 130, 246, 0.1); border-radius: var(--radius-sm); font-size: 0.7rem; color: var(--color-text-secondary); @@ -3097,6 +3137,7 @@ padding: var(--spacing-lg); box-shadow: var(--shadow-lg); animation: slideUp 150ms ease; + will-change: transform, opacity; } @keyframes slideUp { from { opacity: 0; transform: translateY(8px); } @@ -3249,7 +3290,7 @@ } .home-input-container:focus-within { border-color: var(--color-primary-border); - box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.12), 0 0 12px rgba(99, 102, 241, 0.06); + box-shadow: 0 0 0 3px rgba(59, 130, 246, 0.12), 0 0 12px rgba(59, 130, 246, 0.06); } .home-textarea { width: 100%; @@ -3320,6 +3361,9 @@ opacity: 0.3; cursor: not-allowed; } +.home-send-btn:active:not(:disabled) { + transform: scale(0.92); +} /* Home quick links */ .home-quick-links { @@ -3347,6 +3391,7 @@ .home-link-btn:hover { border-color: var(--color-primary-border); color: var(--color-primary); + transform: translateY(-1px); } /* Home loaded models */ diff --git a/core/http/react-ui/src/App.jsx b/core/http/react-ui/src/App.jsx index 421441071c96..f06fe788dc11 100644 --- a/core/http/react-ui/src/App.jsx +++ b/core/http/react-ui/src/App.jsx @@ -29,6 +29,11 @@ export default function App() { return () => window.removeEventListener('sidebar-collapse', handler) }, []) + // Scroll to top on route change + useEffect(() => { + window.scrollTo(0, 0) + }, [location.pathname]) + const layoutClasses = [ 'app-layout', isChatRoute ? 'app-layout-chat' : '', @@ -51,7 +56,9 @@ export default function App() { LocalAI
- +
+ +
{!isChatRoute && (