Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: initial watchdog implementation #1341

Merged
merged 9 commits into from
Nov 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 15 additions & 1 deletion .env
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,18 @@ MODELS_PATH=/models
# LLAMACPP_PARALLEL=1

### Enable to run parallel requests
# PARALLEL_REQUESTS=true
# PARALLEL_REQUESTS=true

### Watchdog settings
###
# Enables watchdog to kill backends that are inactive for too long
# WATCHDOG_IDLE=true
#
# Enables watchdog to kill backends that are busy for too long
# WATCHDOG_BUSY=true
#
# Time in duration format (e.g. 1h30m) after which a backend is considered idle
# WATCHDOG_IDLE_TIMEOUT=5m
#
# Time in duration format (e.g. 1h30m) after which a backend is considered busy
# WATCHDOG_BUSY_TIMEOUT=5m
17 changes: 17 additions & 0 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/go-skynet/LocalAI/internal"
"github.com/go-skynet/LocalAI/metrics"
"github.com/go-skynet/LocalAI/pkg/assets"
"github.com/go-skynet/LocalAI/pkg/model"

"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
Expand Down Expand Up @@ -79,6 +80,22 @@ func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader,
options.Loader.StopAllGRPC()
}()

if options.WatchDog {
wd := model.NewWatchDog(
options.Loader,
options.WatchDogBusyTimeout,
options.WatchDogIdleTimeout,
options.WatchDogBusy,
options.WatchDogIdle)
options.Loader.SetWatchDog(wd)
go wd.Run()
go func() {
<-options.Context.Done()
log.Debug().Msgf("Context canceled, shutting down")
wd.Shutdown()
}()
}

return options, cl, nil
}

Expand Down
2 changes: 1 addition & 1 deletion api/localai/backend_monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func BackendMonitorEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error {
return fmt.Errorf("backend %s is not currently loaded", backendId)
}

status, rpcErr := model.GRPC(false).Status(context.TODO())
status, rpcErr := model.GRPC(false, nil).Status(context.TODO())
if rpcErr != nil {
log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error())
val, slbErr := bm.SampleLocalBackendProcess(backendId)
Expand Down
32 changes: 32 additions & 0 deletions api/options/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"embed"
"encoding/json"
"time"

"github.com/go-skynet/LocalAI/metrics"
"github.com/go-skynet/LocalAI/pkg/gallery"
Expand Down Expand Up @@ -38,6 +39,11 @@ type Option struct {

SingleBackend bool
ParallelBackendRequests bool

WatchDogIdle bool
WatchDogBusy bool
WatchDog bool
WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration
}

type AppOption func(*Option)
Expand All @@ -63,6 +69,32 @@ func WithCors(b bool) AppOption {
}
}

// Watchdog toggles. Every value below sets Option.WatchDog; the idle and busy
// variants additionally set the flag for their own check.
var (
	// EnableWatchDog sets Option.WatchDog without selecting a check.
	EnableWatchDog = func(opt *Option) {
		opt.WatchDog = true
	}

	// EnableWatchDogIdleCheck sets Option.WatchDog and Option.WatchDogIdle.
	EnableWatchDogIdleCheck = func(opt *Option) {
		opt.WatchDog, opt.WatchDogIdle = true, true
	}

	// EnableWatchDogBusyCheck sets Option.WatchDog and Option.WatchDogBusy.
	EnableWatchDogBusyCheck = func(opt *Option) {
		opt.WatchDog, opt.WatchDogBusy = true, true
	}
)

// SetWatchDogBusyTimeout returns an AppOption that stores t in
// Option.WatchDogBusyTimeout.
func SetWatchDogBusyTimeout(t time.Duration) AppOption {
	apply := func(opt *Option) {
		opt.WatchDogBusyTimeout = t
	}
	return apply
}

// SetWatchDogIdleTimeout returns an AppOption that stores t in
// Option.WatchDogIdleTimeout.
func SetWatchDogIdleTimeout(t time.Duration) AppOption {
	apply := func(opt *Option) {
		opt.WatchDogIdleTimeout = t
	}
	return apply
}

// EnableSingleBackend sets Option.SingleBackend to true.
var EnableSingleBackend = func(opt *Option) {
	opt.SingleBackend = true
}
Expand Down
47 changes: 47 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"path/filepath"
"strings"
"syscall"
"time"

api "github.com/go-skynet/LocalAI/api"
"github.com/go-skynet/LocalAI/api/backend"
Expand Down Expand Up @@ -154,6 +155,30 @@ func main() {
Usage: "List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys.",
EnvVars: []string{"API_KEY"},
},
&cli.BoolFlag{
Name: "enable-watchdog-idle",
Usage: "Enable watchdog for stopping idle backends. This will stop the backends if are in idle state for too long.",
EnvVars: []string{"WATCHDOG_IDLE"},
Value: false,
},
&cli.BoolFlag{
Name: "enable-watchdog-busy",
Usage: "Enable watchdog for stopping busy backends that exceed a defined threshold.",
EnvVars: []string{"WATCHDOG_BUSY"},
Value: false,
},
&cli.StringFlag{
Name: "watchdog-busy-timeout",
Usage: "Watchdog timeout. This will restart the backend if it crashes.",
EnvVars: []string{"WATCHDOG_BUSY_TIMEOUT"},
Value: "5m",
},
&cli.StringFlag{
Name: "watchdog-idle-timeout",
Usage: "Watchdog idle timeout. This will restart the backend if it crashes.",
EnvVars: []string{"WATCHDOG_IDLE_TIMEOUT"},
Value: "15m",
},
&cli.BoolFlag{
Name: "preload-backend-only",
Usage: "If set, the api is NOT launched, and only the preloaded models / backends are started. This is intended for multi-node setups.",
Expand Down Expand Up @@ -198,6 +223,28 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
options.WithUploadLimitMB(ctx.Int("upload-limit")),
options.WithApiKeys(ctx.StringSlice("api-keys")),
}

idleWatchDog := ctx.Bool("enable-watchdog-idle")
busyWatchDog := ctx.Bool("enable-watchdog-busy")
if idleWatchDog || busyWatchDog {
opts = append(opts, options.EnableWatchDog)
if idleWatchDog {
opts = append(opts, options.EnableWatchDogIdleCheck)
dur, err := time.ParseDuration(ctx.String("watchdog-idle-timeout"))
if err != nil {
return err
}
opts = append(opts, options.SetWatchDogIdleTimeout(dur))
}
if busyWatchDog {
opts = append(opts, options.EnableWatchDogBusyCheck)
dur, err := time.ParseDuration(ctx.String("watchdog-busy-timeout"))
if err != nil {
return err
}
opts = append(opts, options.SetWatchDogBusyTimeout(dur))
}
}
if ctx.Bool("parallel-requests") {
opts = append(opts, options.EnableParallelBackendRequests)
}
Expand Down
44 changes: 43 additions & 1 deletion pkg/grpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,22 @@ type Client struct {
parallel bool
sync.Mutex
opMutex sync.Mutex
wd WatchDog
}

func NewClient(address string, parallel bool) *Client {
// WatchDog is the hook a Client notifies around backend calls. When a Client
// holds a non-nil WatchDog, its call methods invoke Mark with the client's
// address before dialing and defer UnMark with the same address, so every
// Mark is paired with an UnMark when the call returns.
type WatchDog interface {
// Mark is called with the backend address when a call starts.
Mark(address string)
// UnMark is called with the backend address when that call returns.
UnMark(address string)
}

// NewClient builds a Client for the backend listening on address.
//
// parallel is stored as given. wd is only retained when enableWatchDog is
// true; otherwise the client's wd field stays nil, which the call methods
// treat as "no watchdog attached". The separate flag allows a caller to hand
// over an interface value wrapping a nil pointer and still end up with a
// truly nil field.
func NewClient(address string, parallel bool, wd WatchDog, enableWatchDog bool) *Client {
	c := &Client{
		address:  address,
		parallel: parallel,
	}
	if enableWatchDog {
		c.wd = wd
	}
	return c
}

Expand Down Expand Up @@ -79,6 +89,10 @@ func (c *Client) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
Expand All @@ -96,6 +110,10 @@ func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grp
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
Expand All @@ -113,6 +131,10 @@ func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grp
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
Expand All @@ -129,6 +151,10 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return err
Expand Down Expand Up @@ -164,6 +190,10 @@ func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest,
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
Expand All @@ -180,6 +210,10 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
Expand All @@ -196,6 +230,10 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
Expand Down Expand Up @@ -232,6 +270,10 @@ func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
Expand Down
8 changes: 4 additions & 4 deletions pkg/model/initializers.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
// Wait for the service to start up
ready := false
for i := 0; i < o.grpcAttempts; i++ {
if client.GRPC(o.parallelRequests).HealthCheck(context.Background()) {
if client.GRPC(o.parallelRequests, ml.wd).HealthCheck(context.Background()) {
log.Debug().Msgf("GRPC Service Ready")
ready = true
break
Expand All @@ -140,7 +140,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string

log.Debug().Msgf("GRPC: Loading model with options: %+v", options)

res, err := client.GRPC(o.parallelRequests).LoadModel(o.context, &options)
res, err := client.GRPC(o.parallelRequests, ml.wd).LoadModel(o.context, &options)
if err != nil {
return "", fmt.Errorf("could not load model: %w", err)
}
Expand All @@ -154,11 +154,11 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string

// resolveAddress returns a gRPC client for addr; the returned error is always nil.
// With parallel set, a new client is built on every call. Otherwise the client
// is created on first use and cached in ml.grpcClients under string(addr).
// NOTE(review): ml.grpcClients is read and written here without any locking
// inside this function — confirm that callers hold the loader's mutex.
func (ml *ModelLoader) resolveAddress(addr ModelAddress, parallel bool) (*grpc.Client, error) {
if parallel {
return addr.GRPC(parallel), nil
return addr.GRPC(parallel, ml.wd), nil
}

if _, ok := ml.grpcClients[string(addr)]; !ok {
ml.grpcClients[string(addr)] = addr.GRPC(parallel)
ml.grpcClients[string(addr)] = addr.GRPC(parallel, ml.wd)
}
return ml.grpcClients[string(addr)], nil
}
Expand Down
26 changes: 21 additions & 5 deletions pkg/model/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,17 @@ type ModelLoader struct {
models map[string]ModelAddress
grpcProcesses map[string]*process.Process
templates map[TemplateType]map[string]*template.Template
wd *WatchDog
}

// ModelAddress is the string address of a backend; its GRPC method turns it
// into a gRPC client for that address.
type ModelAddress string

func (m ModelAddress) GRPC(parallel bool) *grpc.Client {
return grpc.NewClient(string(m), parallel)
// GRPC builds a gRPC client for the address m.
//
// wd may be nil. Because wd is a pointer that gets converted to the client's
// WatchDog interface, a nil pointer would not compare equal to nil once
// wrapped; the explicit "wd != nil" flag tells NewClient whether to keep it.
func (m ModelAddress) GRPC(parallel bool, wd *WatchDog) *grpc.Client {
	return grpc.NewClient(string(m), parallel, wd, wd != nil)
}

func NewModelLoader(modelPath string) *ModelLoader {
Expand All @@ -79,10 +84,15 @@ func NewModelLoader(modelPath string) *ModelLoader {
templates: make(map[TemplateType]map[string]*template.Template),
grpcProcesses: make(map[string]*process.Process),
}

nml.initializeTemplateMap()
return nml
}

// SetWatchDog stores wd on the loader; clients the loader creates afterwards
// are handed this watchdog. Passing nil leaves the loader without one.
func (ml *ModelLoader) SetWatchDog(wd *WatchDog) {
ml.wd = wd
}

func (ml *ModelLoader) ExistsInModelPath(s string) bool {
return existsInPath(ml.ModelPath, s)
}
Expand Down Expand Up @@ -139,11 +149,17 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) (
// ShutdownModel holds ml.mu for the duration of the call and delegates to
// StopModel, returning its error.
func (ml *ModelLoader) ShutdownModel(modelName string) error {
ml.mu.Lock()
defer ml.mu.Unlock()

return ml.StopModel(modelName)
}

// StopModel reports an error when modelName is not a key of ml.models.
// deleteProcess is deferred at the top, so it is invoked on every return path,
// including the not-found one.
// StopModel does not lock ml.mu itself; ShutdownModel takes the lock before
// calling it.
func (ml *ModelLoader) StopModel(modelName string) error {
// NOTE(review): the value returned by this deferred deleteProcess call is discarded.
defer ml.deleteProcess(modelName)
if _, ok := ml.models[modelName]; !ok {
return fmt.Errorf("model %s not found", modelName)
}

return ml.deleteProcess(modelName)
return nil
//return ml.deleteProcess(modelName)
}

func (ml *ModelLoader) CheckIsLoaded(s string) ModelAddress {
Expand All @@ -153,7 +169,7 @@ func (ml *ModelLoader) CheckIsLoaded(s string) ModelAddress {
if c, ok := ml.grpcClients[s]; ok {
client = c
} else {
client = m.GRPC(false)
client = m.GRPC(false, ml.wd)
}

if !client.HealthCheck(context.Background()) {
Expand Down