Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ RUN apt-get update && \

FROM build-requirements AS builder-base

ARG GO_TAGS=""
ARG GO_TAGS="auth"
ARG GRPC_BACKENDS
ARG MAKEFLAGS
ARG LD_FLAGS="-s -w"
Expand Down
21 changes: 21 additions & 0 deletions core/application/application.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/mudler/LocalAI/core/templates"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/xlog"
"gorm.io/gorm"
)

type Application struct {
Expand All @@ -22,6 +23,7 @@ type Application struct {
galleryService *services.GalleryService
agentJobService *services.AgentJobService
agentPoolService atomic.Pointer[services.AgentPoolService]
authDB *gorm.DB
watchdogMutex sync.Mutex
watchdogStop chan bool
p2pMutex sync.Mutex
Expand Down Expand Up @@ -74,6 +76,11 @@ func (a *Application) AgentPoolService() *services.AgentPoolService {
return a.agentPoolService.Load()
}

// AuthDB returns the auth database connection, or nil if auth is not enabled.
func (a *Application) AuthDB() *gorm.DB {
return a.authDB
}

// StartupConfig returns the original startup configuration (from env vars, before file loading)
func (a *Application) StartupConfig() *config.ApplicationConfig {
return a.startupConfig
Expand Down Expand Up @@ -118,9 +125,23 @@ func (a *Application) StartAgentPool() {
xlog.Error("Failed to create agent pool service", "error", err)
return
}
if a.authDB != nil {
aps.SetAuthDB(a.authDB)
}
if err := aps.Start(a.applicationConfig.Context); err != nil {
xlog.Error("Failed to start agent pool", "error", err)
return
}

// Wire per-user scoped services so collections, skills, and jobs are isolated per user
usm := services.NewUserServicesManager(
aps.UserStorage(),
a.applicationConfig,
a.modelLoader,
a.backendLoader,
a.templatesEvaluator,
)
aps.SetUserServicesManager(usm)

a.agentPoolService.Store(aps)
}
4 changes: 2 additions & 2 deletions core/application/config_file_watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHand
envF16 := appConfig.F16 == startupAppConfig.F16
envDebug := appConfig.Debug == startupAppConfig.Debug
envCORS := appConfig.CORS == startupAppConfig.CORS
envCSRF := appConfig.CSRF == startupAppConfig.CSRF
envCSRF := appConfig.DisableCSRF == startupAppConfig.DisableCSRF
envCORSAllowOrigins := appConfig.CORSAllowOrigins == startupAppConfig.CORSAllowOrigins
envP2PToken := appConfig.P2PToken == startupAppConfig.P2PToken
envP2PNetworkID := appConfig.P2PNetworkID == startupAppConfig.P2PNetworkID
Expand Down Expand Up @@ -313,7 +313,7 @@ func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHand
appConfig.CORS = *settings.CORS
}
if settings.CSRF != nil && !envCSRF {
appConfig.CSRF = *settings.CSRF
appConfig.DisableCSRF = *settings.CSRF
}
if settings.CORSAllowOrigins != nil && !envCORSAllowOrigins {
appConfig.CORSAllowOrigins = *settings.CORSAllowOrigins
Expand Down
67 changes: 67 additions & 0 deletions core/application/startup.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package application

import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"os"
Expand All @@ -10,6 +12,7 @@ import (
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/http/auth"
"github.com/mudler/LocalAI/core/services"
coreStartup "github.com/mudler/LocalAI/core/startup"
"github.com/mudler/LocalAI/internal"
Expand Down Expand Up @@ -81,6 +84,45 @@ func New(opts ...config.AppOption) (*Application, error) {
}
}

// Initialize auth database if auth is enabled
if options.Auth.Enabled {
// Auto-generate HMAC secret if not provided
if options.Auth.APIKeyHMACSecret == "" {
secretFile := filepath.Join(options.DataPath, ".hmac_secret")
secret, err := loadOrGenerateHMACSecret(secretFile)
if err != nil {
return nil, fmt.Errorf("failed to initialize HMAC secret: %w", err)
}
options.Auth.APIKeyHMACSecret = secret
}

authDB, err := auth.InitDB(options.Auth.DatabaseURL)
if err != nil {
return nil, fmt.Errorf("failed to initialize auth database: %w", err)
}
application.authDB = authDB
xlog.Info("Auth enabled", "database", options.Auth.DatabaseURL)

// Start session and expired API key cleanup goroutine
go func() {
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for {
select {
case <-options.Context.Done():
return
case <-ticker.C:
if err := auth.CleanExpiredSessions(authDB); err != nil {
xlog.Error("failed to clean expired sessions", "error", err)
}
if err := auth.CleanExpiredAPIKeys(authDB); err != nil {
xlog.Error("failed to clean expired API keys", "error", err)
}
}
}
}()
}

if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
xlog.Error("error installing models", "error", err)
}
Expand Down Expand Up @@ -434,6 +476,31 @@ func initializeWatchdog(application *Application, options *config.ApplicationCon
}
}

// loadOrGenerateHMACSecret loads an HMAC secret from the given file path,
// or generates a random 32-byte secret and persists it if the file doesn't exist.
func loadOrGenerateHMACSecret(path string) (string, error) {
data, err := os.ReadFile(path)
if err == nil {
secret := string(data)
if len(secret) >= 32 {
return secret, nil
}
}

b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", fmt.Errorf("failed to generate HMAC secret: %w", err)
}
secret := hex.EncodeToString(b)

if err := os.WriteFile(path, []byte(secret), 0600); err != nil {
return "", fmt.Errorf("failed to persist HMAC secret: %w", err)
}

xlog.Info("Generated new HMAC secret for API key hashing", "path", path)
return secret, nil
}

// migrateDataFiles moves persistent data files from the old config directory
// to the new data directory. Only moves files that exist in src but not in dst.
func migrateDataFiles(srcDir, dstDir string) {
Expand Down
59 changes: 57 additions & 2 deletions core/cli/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ type RunCMD struct {
Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"`
CORS bool `env:"LOCALAI_CORS,CORS" help:"" group:"api"`
CORSAllowOrigins string `env:"LOCALAI_CORS_ALLOW_ORIGINS,CORS_ALLOW_ORIGINS" group:"api"`
CSRF bool `env:"LOCALAI_CSRF" help:"Enables fiber CSRF middleware" group:"api"`
DisableCSRF bool `env:"LOCALAI_DISABLE_CSRF" help:"Disable CSRF middleware (enabled by default)" group:"api"`
UploadLimit int `env:"LOCALAI_UPLOAD_LIMIT,UPLOAD_LIMIT" default:"15" help:"Default upload-limit in MB" group:"api"`
APIKeys []string `env:"LOCALAI_API_KEY,API_KEY" help:"List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys" group:"api"`
DisableWebUI bool `env:"LOCALAI_DISABLE_WEBUI,DISABLE_WEBUI" default:"false" help:"Disables the web user interface. When set to true, the server will only expose API endpoints without serving the web interface" group:"api"`
Expand Down Expand Up @@ -121,6 +121,21 @@ type RunCMD struct {
AgentPoolCollectionDBPath string `env:"LOCALAI_AGENT_POOL_COLLECTION_DB_PATH" help:"Database path for agent collections" group:"agents"`
AgentHubURL string `env:"LOCALAI_AGENT_HUB_URL" default:"https://agenthub.localai.io" help:"URL for the agent hub where users can browse and download agent configurations" group:"agents"`

// Authentication
AuthEnabled bool `env:"LOCALAI_AUTH" default:"false" help:"Enable user authentication and authorization" group:"auth"`
AuthDatabaseURL string `env:"LOCALAI_AUTH_DATABASE_URL,DATABASE_URL" help:"Database URL for auth (postgres:// or file path for SQLite). Defaults to {DataPath}/database.db" group:"auth"`
GitHubClientID string `env:"GITHUB_CLIENT_ID" help:"GitHub OAuth App Client ID (auto-enables auth when set)" group:"auth"`
GitHubClientSecret string `env:"GITHUB_CLIENT_SECRET" help:"GitHub OAuth App Client Secret" group:"auth"`
OIDCIssuer string `env:"LOCALAI_OIDC_ISSUER" help:"OIDC issuer URL for auto-discovery" group:"auth"`
OIDCClientID string `env:"LOCALAI_OIDC_CLIENT_ID" help:"OIDC Client ID (auto-enables auth)" group:"auth"`
OIDCClientSecret string `env:"LOCALAI_OIDC_CLIENT_SECRET" help:"OIDC Client Secret" group:"auth"`
AuthBaseURL string `env:"LOCALAI_BASE_URL" help:"Base URL for OAuth callbacks (e.g. http://localhost:8080)" group:"auth"`
AuthAdminEmail string `env:"LOCALAI_ADMIN_EMAIL" help:"Email address to auto-promote to admin role" group:"auth"`
AuthRegistrationMode string `env:"LOCALAI_REGISTRATION_MODE" default:"open" help:"Registration mode: 'open' (default), 'approval', or 'invite' (invite code required)" group:"auth"`
DisableLocalAuth bool `env:"LOCALAI_DISABLE_LOCAL_AUTH" default:"false" help:"Disable local email/password registration and login (use with OAuth/OIDC-only setups)" group:"auth"`
AuthAPIKeyHMACSecret string `env:"LOCALAI_AUTH_HMAC_SECRET" help:"HMAC secret for API key hashing (auto-generated if empty)" group:"auth"`
DefaultAPIKeyExpiry string `env:"LOCALAI_DEFAULT_API_KEY_EXPIRY" help:"Default expiry for API keys (e.g. 90d, 1y; empty = no expiry)" group:"auth"`

Version bool
}

Expand Down Expand Up @@ -165,7 +180,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
config.WithBackendGalleries(r.BackendGalleries),
config.WithCors(r.CORS),
config.WithCorsAllowOrigins(r.CORSAllowOrigins),
config.WithCsrf(r.CSRF),
config.WithDisableCSRF(r.DisableCSRF),
config.WithThreads(r.Threads),
config.WithUploadLimitMB(r.UploadLimit),
config.WithApiKeys(r.APIKeys),
Expand Down Expand Up @@ -311,6 +326,46 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
opts = append(opts, config.WithAgentHubURL(r.AgentHubURL))
}

// Authentication
authEnabled := r.AuthEnabled || r.GitHubClientID != "" || r.OIDCClientID != ""
if authEnabled {
opts = append(opts, config.WithAuthEnabled(true))

dbURL := r.AuthDatabaseURL
if dbURL == "" {
dbURL = filepath.Join(r.DataPath, "database.db")
}
opts = append(opts, config.WithAuthDatabaseURL(dbURL))

if r.GitHubClientID != "" {
opts = append(opts, config.WithAuthGitHubClientID(r.GitHubClientID))
opts = append(opts, config.WithAuthGitHubClientSecret(r.GitHubClientSecret))
}
if r.OIDCClientID != "" {
opts = append(opts, config.WithAuthOIDCIssuer(r.OIDCIssuer))
opts = append(opts, config.WithAuthOIDCClientID(r.OIDCClientID))
opts = append(opts, config.WithAuthOIDCClientSecret(r.OIDCClientSecret))
}
if r.AuthBaseURL != "" {
opts = append(opts, config.WithAuthBaseURL(r.AuthBaseURL))
}
if r.AuthAdminEmail != "" {
opts = append(opts, config.WithAuthAdminEmail(r.AuthAdminEmail))
}
if r.AuthRegistrationMode != "" {
opts = append(opts, config.WithAuthRegistrationMode(r.AuthRegistrationMode))
}
if r.DisableLocalAuth {
opts = append(opts, config.WithAuthDisableLocalAuth(true))
}
if r.AuthAPIKeyHMACSecret != "" {
opts = append(opts, config.WithAuthAPIKeyHMACSecret(r.AuthAPIKeyHMACSecret))
}
if r.DefaultAPIKeyExpiry != "" {
opts = append(opts, config.WithAuthDefaultAPIKeyExpiry(r.DefaultAPIKeyExpiry))
}
}

if idleWatchDog || busyWatchDog {
opts = append(opts, config.EnableWatchDog)
if idleWatchDog {
Expand Down
Loading
Loading