Skip to content

Commit

Permalink
feat(gateway): add router for provider failover
Browse files Browse the repository at this point in the history
  • Loading branch information
pyadav committed Feb 3, 2024
1 parent 248cff3 commit 2d6c5a9
Show file tree
Hide file tree
Showing 6 changed files with 198 additions and 0 deletions.
21 changes: 21 additions & 0 deletions gateway/internal/mock/mock_provider.go
@@ -0,0 +1,21 @@
package mock

import "github.com/missingstudio/studio/backend/internal/providers/base"

type ProviderMock struct {
Name string
}

func NewProviderMock(name string) base.ProviderInterface {
return &ProviderMock{
Name: name,
}
}

func (p ProviderMock) GetName() string {
return p.Name
}

func (p ProviderMock) Validate() error {
return nil
}
37 changes: 37 additions & 0 deletions gateway/internal/router/priority.go
@@ -0,0 +1,37 @@
package router

import (
"sync/atomic"

"github.com/missingstudio/studio/backend/internal/providers/base"
)

const (
Priority Strategy = "priority"
)

type PriorityRouter struct {
idx *atomic.Uint64
providers []base.ProviderInterface
}

func NewPriorityRouter(providers []base.ProviderInterface) *PriorityRouter {
return &PriorityRouter{
idx: &atomic.Uint64{},
providers: providers,
}
}

func (r *PriorityRouter) Iterator() ProviderIterator {
return r
}

func (r *PriorityRouter) Next() (base.ProviderInterface, error) {
idx := int(r.idx.Load())

// Todo: make a check for healthy provider
model := r.providers[idx]
r.idx.Add(1)

return model, nil
}
44 changes: 44 additions & 0 deletions gateway/internal/router/priority_test.go
@@ -0,0 +1,44 @@
package router_test

import (
"testing"

"github.com/missingstudio/studio/backend/internal/mock"
"github.com/missingstudio/studio/backend/internal/providers/base"
"github.com/missingstudio/studio/backend/internal/router"
"github.com/stretchr/testify/require"
)

func TestPriorityRouter(t *testing.T) {
type Provider struct {
Name string
}

type TestCase struct {
providers []Provider
expectedModelIDs []string
}

tests := map[string]TestCase{
"openai": {[]Provider{{"openai"}, {"anyscale"}, {"azure"}}, []string{"openai", "anyscale", "azure"}},
}

for name, tc := range tests {
t.Run(name, func(t *testing.T) {
providers := make([]base.ProviderInterface, 0, len(tc.providers))

for _, provider := range tc.providers {
providers = append(providers, mock.NewProviderMock(provider.Name))
}

routing := router.NewPriorityRouter(providers)
iterator := routing.Iterator()

for _, modelID := range tc.expectedModelIDs {
model, err := iterator.Next()
require.NoError(t, err)
require.Equal(t, modelID, model.GetName())
}
})
}
}
36 changes: 36 additions & 0 deletions gateway/internal/router/round_robin.go
@@ -0,0 +1,36 @@
package router

import (
"sync/atomic"

"github.com/missingstudio/studio/backend/internal/providers/base"
)

const (
RoundRobin Strategy = "roundrobin"
)

type RoundRobinRouter struct {
idx atomic.Uint64
providers []base.ProviderInterface
}

func NewRoundRobinRouter(providers []base.ProviderInterface) *RoundRobinRouter {
return &RoundRobinRouter{
providers: providers,
}
}

func (r *RoundRobinRouter) Iterator() ProviderIterator {
return r
}

func (r *RoundRobinRouter) Next() (base.ProviderInterface, error) {
providerLen := len(r.providers)

// Todo: make a check for healthy provider
idx := r.idx.Add(1) - 1
model := r.providers[idx%uint64(providerLen)]

return model, nil
}
45 changes: 45 additions & 0 deletions gateway/internal/router/round_robin_test.go
@@ -0,0 +1,45 @@
package router_test

import (
"testing"

"github.com/missingstudio/studio/backend/internal/mock"
"github.com/missingstudio/studio/backend/internal/providers/base"
"github.com/missingstudio/studio/backend/internal/router"
"github.com/stretchr/testify/require"
)

func TestRoundRobinRouter(t *testing.T) {
type Provider struct {
Name string
}

type TestCase struct {
providers []Provider
expectedModelIDs []string
}

tests := map[string]TestCase{
"public llms": {[]Provider{{"openai"}, {"anyscale"}, {"azure"}}, []string{"openai", "anyscale", "azure"}},
}

for name, tc := range tests {
t.Run(name, func(t *testing.T) {
providers := make([]base.ProviderInterface, 0, len(tc.providers))

for _, provider := range tc.providers {
providers = append(providers, mock.NewProviderMock(provider.Name))
}

routing := router.NewRoundRobinRouter(providers)
iterator := routing.Iterator()

// loop three times over the whole pool to check if we return back to the begging of the list
for _, providerName := range tc.expectedModelIDs {
provider, err := iterator.Next()
require.NoError(t, err)
require.Equal(t, providerName, provider.GetName())
}
})
}
}
15 changes: 15 additions & 0 deletions gateway/internal/router/router.go
@@ -0,0 +1,15 @@
package router

import (
"errors"

"github.com/missingstudio/studio/backend/internal/providers/base"
)

var ErrNoHealthyProviders = errors.New("no healthy providers found")

type Strategy string

type ProviderIterator interface {
Next() (base.ProviderInterface, error)
}

0 comments on commit 2d6c5a9

Please sign in to comment.