Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions core/application/agent_jobs.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package application

import (
"time"

"github.com/mudler/LocalAI/core/services"
"github.com/rs/zerolog/log"
)

// RestartAgentJobService restarts the agent job service with current ApplicationConfig settings
func (a *Application) RestartAgentJobService() error {
a.agentJobMutex.Lock()
defer a.agentJobMutex.Unlock()

// Stop existing service if running
if a.agentJobService != nil {
if err := a.agentJobService.Stop(); err != nil {
log.Warn().Err(err).Msg("Error stopping agent job service")
}
// Wait a bit for shutdown to complete
time.Sleep(200 * time.Millisecond)
}

// Create new service instance
agentJobService := services.NewAgentJobService(
a.ApplicationConfig(),
a.ModelLoader(),
a.ModelConfigLoader(),
a.TemplatesEvaluator(),
)

// Start the service
err := agentJobService.Start(a.ApplicationConfig().Context)
if err != nil {
log.Error().Err(err).Msg("Failed to start agent job service")
return err
}

a.agentJobService = agentJobService
log.Info().Msg("Agent job service restarted")
return nil
}

21 changes: 21 additions & 0 deletions core/application/application.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ type Application struct {
startupConfig *config.ApplicationConfig // Stores original config from env vars (before file loading)
templatesEvaluator *templates.Evaluator
galleryService *services.GalleryService
agentJobService *services.AgentJobService
watchdogMutex sync.Mutex
watchdogStop chan bool
p2pMutex sync.Mutex
p2pCtx context.Context
p2pCancel context.CancelFunc
agentJobMutex sync.Mutex
}

func newApplication(appConfig *config.ApplicationConfig) *Application {
Expand Down Expand Up @@ -53,6 +55,10 @@ func (a *Application) GalleryService() *services.GalleryService {
return a.galleryService
}

func (a *Application) AgentJobService() *services.AgentJobService {
return a.agentJobService
}

// StartupConfig returns the original startup configuration (from env vars, before file loading)
func (a *Application) StartupConfig() *config.ApplicationConfig {
return a.startupConfig
Expand All @@ -67,5 +73,20 @@ func (a *Application) start() error {

a.galleryService = galleryService

// Initialize agent job service
agentJobService := services.NewAgentJobService(
a.ApplicationConfig(),
a.ModelLoader(),
a.ModelConfigLoader(),
a.TemplatesEvaluator(),
)

err = agentJobService.Start(a.ApplicationConfig().Context)
if err != nil {
return err
}

a.agentJobService = agentJobService

return nil
}
7 changes: 7 additions & 0 deletions core/application/config_file_watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ func newConfigFileHandler(appConfig *config.ApplicationConfig) configFileHandler
if err != nil {
log.Error().Err(err).Str("file", "runtime_settings.json").Msg("unable to register config file handler")
}
// Note: agent_tasks.json and agent_jobs.json are handled by AgentJobService directly
// The service watches and reloads these files internally
return c
}

Expand Down Expand Up @@ -206,6 +208,7 @@ type runtimeSettings struct {
AutoloadGalleries *bool `json:"autoload_galleries,omitempty"`
AutoloadBackendGalleries *bool `json:"autoload_backend_galleries,omitempty"`
ApiKeys *[]string `json:"api_keys,omitempty"`
AgentJobRetentionDays *int `json:"agent_job_retention_days,omitempty"`
}

func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHandler {
Expand Down Expand Up @@ -234,6 +237,7 @@ func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHand
envFederated := appConfig.Federated == startupAppConfig.Federated
envAutoloadGalleries := appConfig.AutoloadGalleries == startupAppConfig.AutoloadGalleries
envAutoloadBackendGalleries := appConfig.AutoloadBackendGalleries == startupAppConfig.AutoloadBackendGalleries
envAgentJobRetentionDays := appConfig.AgentJobRetentionDays == startupAppConfig.AgentJobRetentionDays

if len(fileContent) > 0 {
var settings runtimeSettings
Expand Down Expand Up @@ -328,6 +332,9 @@ func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHand
// Replace all runtime keys with what's in runtime_settings.json
appConfig.ApiKeys = append(envKeys, runtimeKeys...)
}
if settings.AgentJobRetentionDays != nil && !envAgentJobRetentionDays {
appConfig.AgentJobRetentionDays = *settings.AgentJobRetentionDays
}

// If watchdog is enabled via file but not via env, ensure WatchDog flag is set
if !envWatchdogIdle && !envWatchdogBusy {
Expand Down
7 changes: 7 additions & 0 deletions core/application/startup.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ func loadRuntimeSettingsFromFile(options *config.ApplicationConfig) {
WatchdogBusyTimeout *string `json:"watchdog_busy_timeout,omitempty"`
SingleBackend *bool `json:"single_backend,omitempty"`
ParallelBackendRequests *bool `json:"parallel_backend_requests,omitempty"`
AgentJobRetentionDays *int `json:"agent_job_retention_days,omitempty"`
}

if err := json.Unmarshal(fileContent, &settings); err != nil {
Expand Down Expand Up @@ -289,6 +290,12 @@ func loadRuntimeSettingsFromFile(options *config.ApplicationConfig) {
options.ParallelBackendRequests = *settings.ParallelBackendRequests
}
}
if settings.AgentJobRetentionDays != nil {
// Only apply if current value is default (0), suggesting it wasn't set from env var
if options.AgentJobRetentionDays == 0 {
options.AgentJobRetentionDays = *settings.AgentJobRetentionDays
}
}
if !options.WatchDogIdle && !options.WatchDogBusy {
if settings.WatchdogEnabled != nil && *settings.WatchdogEnabled {
options.WatchDog = true
Expand Down
2 changes: 2 additions & 0 deletions core/cli/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ type RunCMD struct {
DisableGalleryEndpoint bool `env:"LOCALAI_DISABLE_GALLERY_ENDPOINT,DISABLE_GALLERY_ENDPOINT" help:"Disable the gallery endpoints" group:"api"`
MachineTag string `env:"LOCALAI_MACHINE_TAG,MACHINE_TAG" help:"Add Machine-Tag header to each response which is useful to track the machine in the P2P network" group:"api"`
LoadToMemory []string `env:"LOCALAI_LOAD_TO_MEMORY,LOAD_TO_MEMORY" help:"A list of models to load into memory at startup" group:"models"`
AgentJobRetentionDays int `env:"LOCALAI_AGENT_JOB_RETENTION_DAYS,AGENT_JOB_RETENTION_DAYS" default:"30" help:"Number of days to keep agent job history (default: 30)" group:"api"`

Version bool
}
Expand Down Expand Up @@ -129,6 +130,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
config.WithLoadToMemory(r.LoadToMemory),
config.WithMachineTag(r.MachineTag),
config.WithAPIAddress(r.Address),
config.WithAgentJobRetentionDays(r.AgentJobRetentionDays),
config.WithTunnelCallback(func(tunnels []string) {
tunnelEnvVar := strings.Join(tunnels, ",")
// TODO: this is very specific to llama.cpp, we should have a more generic way to set the environment variable
Expand Down
15 changes: 12 additions & 3 deletions core/config/application_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,18 @@ type ApplicationConfig struct {
TunnelCallback func(tunnels []string)

DisableRuntimeSettings bool

AgentJobRetentionDays int // Default: 30 days
}

type AppOption func(*ApplicationConfig)

func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
opt := &ApplicationConfig{
Context: context.Background(),
UploadLimitMB: 15,
Debug: true,
Context: context.Background(),
UploadLimitMB: 15,
Debug: true,
AgentJobRetentionDays: 30, // Default: 30 days
}
for _, oo := range o {
oo(opt)
Expand Down Expand Up @@ -333,6 +336,12 @@ func WithApiKeys(apiKeys []string) AppOption {
}
}

func WithAgentJobRetentionDays(days int) AppOption {
return func(o *ApplicationConfig) {
o.AgentJobRetentionDays = days
}
}

func WithEnforcedPredownloadScans(enforced bool) AppOption {
return func(o *ApplicationConfig) {
o.EnforcePredownloadScans = enforced
Expand Down
2 changes: 1 addition & 1 deletion core/http/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ func API(application *application.Application) (*echo.Echo, error) {
opcache = services.NewOpCache(application.GalleryService())
}

routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator())
routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator(), application)
routes.RegisterOpenAIRoutes(e, requestExtractor, application)
if !application.ApplicationConfig().DisableWebUI {
routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application)
Expand Down
167 changes: 167 additions & 0 deletions core/http/app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,41 @@ func postRequestResponseJSON[B1 any, B2 any](url string, reqJson *B1, respJson *
return json.Unmarshal(body, respJson)
}

func putRequestJSON[B any](url string, bodyJson *B) error {
payload, err := json.Marshal(bodyJson)
if err != nil {
return err
}

GinkgoWriter.Printf("PUT %s: %s\n", url, string(payload))

req, err := http.NewRequest("PUT", url, bytes.NewBuffer(payload))
if err != nil {
return err
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)

client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
return err
}

if resp.StatusCode < 200 || resp.StatusCode >= 400 {
return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body))
}

return nil
}

func postInvalidRequest(url string) (error, int) {

req, err := http.NewRequest("POST", url, bytes.NewBufferString("invalid request"))
Expand Down Expand Up @@ -1194,6 +1229,138 @@ parameters:
Expect(findRespBody.Similarities[i]).To(BeNumerically("<=", 1))
}
})

Context("Agent Jobs", Label("agent-jobs"), func() {
It("creates and manages tasks", func() {
// Create a task
taskBody := map[string]interface{}{
"name": "Test Task",
"description": "Test Description",
"model": "testmodel.ggml",
"prompt": "Hello {{.name}}",
"enabled": true,
}

var createResp map[string]interface{}
err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
Expect(err).ToNot(HaveOccurred())
Expect(createResp["id"]).ToNot(BeEmpty())
taskID := createResp["id"].(string)

// Get the task
var task schema.Task
resp, err := http.Get("http://127.0.0.1:9090/api/agent/tasks/" + taskID)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
body, _ := io.ReadAll(resp.Body)
json.Unmarshal(body, &task)
Expect(task.Name).To(Equal("Test Task"))

// List tasks
resp, err = http.Get("http://127.0.0.1:9090/api/agent/tasks")
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
var tasks []schema.Task
body, _ = io.ReadAll(resp.Body)
json.Unmarshal(body, &tasks)
Expect(len(tasks)).To(BeNumerically(">=", 1))

// Update task
taskBody["name"] = "Updated Task"
err = putRequestJSON("http://127.0.0.1:9090/api/agent/tasks/"+taskID, &taskBody)
Expect(err).ToNot(HaveOccurred())

// Verify update
resp, err = http.Get("http://127.0.0.1:9090/api/agent/tasks/" + taskID)
Expect(err).ToNot(HaveOccurred())
body, _ = io.ReadAll(resp.Body)
json.Unmarshal(body, &task)
Expect(task.Name).To(Equal("Updated Task"))

// Delete task
req, _ := http.NewRequest("DELETE", "http://127.0.0.1:9090/api/agent/tasks/"+taskID, nil)
req.Header.Set("Authorization", bearerKey)
resp, err = http.DefaultClient.Do(req)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
})

It("executes and monitors jobs", func() {
// Create a task first
taskBody := map[string]interface{}{
"name": "Job Test Task",
"model": "testmodel.ggml",
"prompt": "Say hello",
"enabled": true,
}

var createResp map[string]interface{}
err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
Expect(err).ToNot(HaveOccurred())
taskID := createResp["id"].(string)

// Execute a job
jobBody := map[string]interface{}{
"task_id": taskID,
"parameters": map[string]string{},
}

var jobResp schema.JobExecutionResponse
err = postRequestResponseJSON("http://127.0.0.1:9090/api/agent/jobs/execute", &jobBody, &jobResp)
Expect(err).ToNot(HaveOccurred())
Expect(jobResp.JobID).ToNot(BeEmpty())
jobID := jobResp.JobID

// Get job status
var job schema.Job
resp, err := http.Get("http://127.0.0.1:9090/api/agent/jobs/" + jobID)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
body, _ := io.ReadAll(resp.Body)
json.Unmarshal(body, &job)
Expect(job.ID).To(Equal(jobID))
Expect(job.TaskID).To(Equal(taskID))

// List jobs
resp, err = http.Get("http://127.0.0.1:9090/api/agent/jobs")
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
var jobs []schema.Job
body, _ = io.ReadAll(resp.Body)
json.Unmarshal(body, &jobs)
Expect(len(jobs)).To(BeNumerically(">=", 1))

// Cancel job (if still pending/running)
if job.Status == schema.JobStatusPending || job.Status == schema.JobStatusRunning {
req, _ := http.NewRequest("POST", "http://127.0.0.1:9090/api/agent/jobs/"+jobID+"/cancel", nil)
req.Header.Set("Authorization", bearerKey)
resp, err = http.DefaultClient.Do(req)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
}
})

It("executes task by name", func() {
// Create a task with a specific name
taskBody := map[string]interface{}{
"name": "Named Task",
"model": "testmodel.ggml",
"prompt": "Hello",
"enabled": true,
}

var createResp map[string]interface{}
err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
Expect(err).ToNot(HaveOccurred())

// Execute by name
paramsBody := map[string]string{"param1": "value1"}
var jobResp schema.JobExecutionResponse
err = postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks/Named Task/execute", &paramsBody, &jobResp)
Expect(err).ToNot(HaveOccurred())
Expect(jobResp.JobID).ToNot(BeEmpty())
})
})
})
})

Expand Down
Loading
Loading