Skip to content

Commit

Permalink
feat: add --single-active-backend to allow only one backend active at…
Browse files Browse the repository at this point in the history
… the time

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
  • Loading branch information
mudler committed Aug 18, 2023
1 parent 1079b18 commit c7516af
Show file tree
Hide file tree
Showing 13 changed files with 219 additions and 157 deletions.
16 changes: 2 additions & 14 deletions api/backend/embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,13 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
var inferenceModel interface{}
var err error

opts := []model.Option{
opts := modelOpts(c, o, []model.Option{
model.WithLoadGRPCLoadModelOpts(grpcOpts),
model.WithThreads(uint32(c.Threads)),
model.WithAssetDir(o.AssetsDestination),
model.WithModel(modelFile),
model.WithContext(o.Context),
}

if c.GRPC.Attempts != 0 {
opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts))
}

if c.GRPC.AttemptsSleepTime != 0 {
opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime))
}

for k, v := range o.ExternalGRPCBackends {
opts = append(opts, model.WithExternalBackend(k, v))
}
})

if c.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...)
Expand Down
16 changes: 2 additions & 14 deletions api/backend/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, c config.Config, o *options.Option) (func() error, error) {

opts := []model.Option{
opts := modelOpts(c, o, []model.Option{
model.WithBackendString(c.Backend),
model.WithAssetDir(o.AssetsDestination),
model.WithThreads(uint32(c.Threads)),
Expand All @@ -25,19 +25,7 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
CLIPSubfolder: c.Diffusers.ClipSubFolder,
CLIPSkip: int32(c.Diffusers.ClipSkip),
}),
}

if c.GRPC.Attempts != 0 {
opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts))
}

if c.GRPC.AttemptsSleepTime != 0 {
opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime))
}

for k, v := range o.ExternalGRPCBackends {
opts = append(opts, model.WithExternalBackend(k, v))
}
})

inferenceModel, err := loader.BackendLoader(
opts...,
Expand Down
16 changes: 2 additions & 14 deletions api/backend/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,25 +33,13 @@ func ModelInference(ctx context.Context, s string, loader *model.ModelLoader, c
var inferenceModel *grpc.Client
var err error

opts := []model.Option{
opts := modelOpts(c, o, []model.Option{
model.WithLoadGRPCLoadModelOpts(grpcOpts),
model.WithThreads(uint32(c.Threads)), // some models uses this to allocate threads during startup
model.WithAssetDir(o.AssetsDestination),
model.WithModel(modelFile),
model.WithContext(o.Context),
}

if c.GRPC.Attempts != 0 {
opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts))
}

if c.GRPC.AttemptsSleepTime != 0 {
opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime))
}

for k, v := range o.ExternalGRPCBackends {
opts = append(opts, model.WithExternalBackend(k, v))
}
})

if c.Backend != "" {
opts = append(opts, model.WithBackendString(c.Backend))
Expand Down
22 changes: 22 additions & 0 deletions api/backend/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,32 @@ import (
"path/filepath"

pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"

config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
)

func modelOpts(c config.Config, o *options.Option, opts []model.Option) []model.Option {
if o.SingleBackend {
opts = append(opts, model.WithSingleActiveBackend())
}

if c.GRPC.Attempts != 0 {
opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts))
}

if c.GRPC.AttemptsSleepTime != 0 {
opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime))
}

for k, v := range o.ExternalGRPCBackends {
opts = append(opts, model.WithExternalBackend(k, v))
}

return opts
}

func gRPCModelOpts(c config.Config) *pb.ModelOptions {
b := 512
if c.Batch != 0 {
Expand Down
16 changes: 3 additions & 13 deletions api/backend/transcript.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,14 @@ import (
)

func ModelTranscription(audio, language string, loader *model.ModelLoader, c config.Config, o *options.Option) (*api.Result, error) {
opts := []model.Option{

opts := modelOpts(c, o, []model.Option{
model.WithBackendString(model.WhisperBackend),
model.WithModel(c.Model),
model.WithContext(o.Context),
model.WithThreads(uint32(c.Threads)),
model.WithAssetDir(o.AssetsDestination),
}

if c.GRPC.Attempts != 0 {
opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts))
}

if c.GRPC.AttemptsSleepTime != 0 {
opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime))
}
for k, v := range o.ExternalGRPCBackends {
opts = append(opts, model.WithExternalBackend(k, v))
}
})

whisperModel, err := o.Loader.BackendLoader(opts...)
if err != nil {
Expand Down
10 changes: 3 additions & 7 deletions api/backend/tts.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"os"
"path/filepath"

api_config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
Expand Down Expand Up @@ -33,17 +34,12 @@ func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *opt
if bb == "" {
bb = model.PiperBackend
}
opts := []model.Option{
opts := modelOpts(api_config.Config{}, o, []model.Option{
model.WithBackendString(bb),
model.WithModel(modelFile),
model.WithContext(o.Context),
model.WithAssetDir(o.AssetsDestination),
}

for k, v := range o.ExternalGRPCBackends {
opts = append(opts, model.WithExternalBackend(k, v))
}

})
piperModel, err := o.Loader.BackendLoader(opts...)
if err != nil {
return "", nil, err
Expand Down
6 changes: 6 additions & 0 deletions api/options/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ type Option struct {
ExternalGRPCBackends map[string]string

AutoloadGalleries bool

SingleBackend bool
}

type AppOption func(*Option)
Expand All @@ -58,6 +60,10 @@ func WithCors(b bool) AppOption {
}
}

var EnableSingleBackend = func(o *Option) {
o.SingleBackend = true
}

var EnableGalleriesAutoload = func(o *Option) {
o.AutoloadGalleries = true
}
Expand Down
5 changes: 5 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ func main() {
Name: "debug",
EnvVars: []string{"DEBUG"},
},
&cli.BoolFlag{
Name: "single-active-backend",
EnvVars: []string{"SINGLE_ACTIVE_BACKEND"},
Usage: "Allow only one backend to be running.",
},
&cli.BoolFlag{
Name: "cors",
EnvVars: []string{"CORS"},
Expand Down
35 changes: 35 additions & 0 deletions pkg/grpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"io"
"sync"
"time"

pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
Expand All @@ -14,6 +15,8 @@ import (

type Client struct {
address string
busy bool
sync.Mutex
}

func NewClient(address string) *Client {
Expand All @@ -22,7 +25,21 @@ func NewClient(address string) *Client {
}
}

func (c *Client) IsBusy() bool {
c.Lock()
defer c.Unlock()
return c.busy
}

func (c *Client) setBusy(v bool) {
c.Lock()
c.busy = v
c.Unlock()
}

func (c *Client) HealthCheck(ctx context.Context) bool {
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
fmt.Println(err)
Expand All @@ -49,6 +66,8 @@ func (c *Client) HealthCheck(ctx context.Context) bool {
}

func (c *Client) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error) {
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
Expand All @@ -60,6 +79,8 @@ func (c *Client) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...
}

func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error) {
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
Expand All @@ -71,6 +92,8 @@ func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grp
}

func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error) {
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
Expand All @@ -81,6 +104,8 @@ func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grp
}

func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error {
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return err
Expand Down Expand Up @@ -110,6 +135,8 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun
}

func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) {
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
Expand All @@ -120,6 +147,8 @@ func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest,
}

func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) {
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
Expand All @@ -130,6 +159,8 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp
}

func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*api.Result, error) {
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
Expand Down Expand Up @@ -160,6 +191,8 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
}

func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) {
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
Expand All @@ -176,6 +209,8 @@ func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts
}

func (c *Client) Status(ctx context.Context) (*pb.StatusResponse, error) {
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
Expand Down

0 comments on commit c7516af

Please sign in to comment.