Skip to content

Commit

Permalink
feat(gateway): add headers, provider and gatway from headers into con…
Browse files Browse the repository at this point in the history
…text

Signed-off-by: Praveen Yadav <pyadav9678@gmail.com>
  • Loading branch information
pyadav committed Feb 22, 2024
1 parent d27c8fc commit a7f6508
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 47 deletions.
10 changes: 2 additions & 8 deletions gateway/internal/api/v1/chatcompletions.go
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/missingstudio/studio/backend/internal/providers"
"github.com/missingstudio/studio/backend/internal/providers/base"
"github.com/missingstudio/studio/backend/models"
"github.com/missingstudio/studio/backend/pkg/httputil"
"github.com/missingstudio/studio/backend/pkg/utils"
"github.com/missingstudio/studio/common/errors"
llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
Expand Down Expand Up @@ -39,14 +40,7 @@ func (s *V1Handler) ChatCompletions(
return nil, errors.New(err)
}

// Convert headers into map[string]any + merge config
headerConfig := make(map[string]any)
for key, values := range req.Header() {
if len(values) > 0 {
headerConfig[key] = values[0]
}
}

headerConfig := httputil.GetContextWithHeaderConfig(ctx)
connectionObj := models.Connection{
Name: providerName,
Headers: headerConfig,
Expand Down
3 changes: 2 additions & 1 deletion gateway/internal/api/v1/v1.go
Expand Up @@ -47,9 +47,10 @@ func Register(d *api.Deps) (http.Handler, error) {
stdInterceptors := []connect.Interceptor{
validateInterceptor,
otelconnectInterceptor,
interceptor.NewLoggingInterceptor(d.Logger),
interceptor.HeadersInterceptor(),
interceptor.RateLimiterInterceptor(d.RateLimiter),
interceptor.RetryInterceptor(),
interceptor.NewLoggingInterceptor(d.Logger),
}

services := []*vanguard.Service{
Expand Down
1 change: 1 addition & 0 deletions gateway/internal/errors/errors.go
Expand Up @@ -13,4 +13,5 @@ var (
ErrRateLimitExceeded = errors.NewForbidden("rate limit exceeded")
ErrUnauthenticated = errors.NewUnauthorized("unauthenticated")
ErrProviderNotFound = errors.NewNotFound("provider is not found")
ErrGatewayConfigNotValid = errors.NewNotFound("gateway config is not valid")
)
28 changes: 21 additions & 7 deletions gateway/internal/interceptor/headers.go
Expand Up @@ -2,27 +2,41 @@ package interceptor

import (
"context"
"encoding/json"

"connectrpc.com/connect"
"github.com/missingstudio/studio/backend/internal/constants"
"github.com/missingstudio/studio/backend/internal/errors"
"github.com/missingstudio/studio/backend/internal/providers"
"github.com/missingstudio/studio/backend/pkg/httputil"
)

func ProviderInterceptor(ps *providers.Service) connect.UnaryInterceptorFunc {
func HeadersInterceptor() connect.UnaryInterceptorFunc {
interceptor := func(next connect.UnaryFunc) connect.UnaryFunc {
return connect.UnaryFunc(func(
ctx context.Context,
req connect.AnyRequest,
) (connect.AnyResponse, error) {
// Check if required headers are available
provider := req.Header().Get(constants.XMSProvider)
headerConfig := make(map[string]any)
for key, values := range req.Header() {
if len(values) > 0 {
headerConfig[key] = values[0]
}
}

var gc httputil.GatewayConfig
config := req.Header().Get(constants.XMSConfig)
if provider == "" && config == "" {
return nil, errors.ErrRequiredHeaderNotExit
provider := req.Header().Get(constants.XMSProvider)

if config != "" {
err := json.Unmarshal([]byte(config), &gc)
if err != nil {
return nil, errors.ErrGatewayConfigNotValid
}
}

// Check if provider has registered of not
ctx = httputil.SetContextWithProviderConfig(ctx, provider)
ctx = httputil.SetContextWithHeaderConfig(ctx, headerConfig)
ctx = httputil.SetContextWithGatewayConfig(ctx, &gc)
return next(ctx, req)
})
}
Expand Down
2 changes: 1 addition & 1 deletion gateway/main_test.go
Expand Up @@ -24,7 +24,7 @@ func TestGatewayServer(t *testing.T) {
mux := http.NewServeMux()
mux.Handle(llmv1connect.NewLLMServiceHandler(
&v1.V1Handler{},
connect.WithInterceptors(interceptor.ProviderInterceptor(nil)),
connect.WithInterceptors(interceptor.HeadersInterceptor()),
))

server := httptest.NewUnstartedServer(mux)
Expand Down
87 changes: 87 additions & 0 deletions gateway/pkg/httputil/context.go
@@ -0,0 +1,87 @@
package httputil

import (
"context"
)

type (
GatewayConfigContextKey struct{}
HeaderConfigContextKey struct{}
)

type headerConfig map[string]any

type RetryConfig struct {
// Number is the number of times to retry the request when a retryable
Number int32 `json:"number"`

// RetryOnStatusCodes is a flat list of http response status codes that are
// eligible for retry. This again should be feasible in any reasonable proxy.
OnStatusCodes []uint32 `json:"on_status_codes"`
}

type CacheConfig struct {
Mode string `json:"mode"`
MaxAge int32 `json:"max_age"`
}

type StrategyConfig struct {
Mode string `json:"mode"`
}

type GatewayConfig struct {
Name string `json:"name"`
ApiKey string `json:"api_key"`
VirtualKey string `json:"virtual_key"`
RetryConfig RetryConfig `json:"retry"`
CacheConfig CacheConfig `json:"cache"`
Targets []GatewayConfig `json:"targets"`
Metadata map[string]any `josn:"metadata"`
}

type Address struct {
City string `json:"city"`
Zip string `json:"zip"`
}

type Person struct {
Name string `json:"name"`
Age int `json:"age"`
Address Address `json:"address"`
}

func SetContextWithHeaderConfig(ctx context.Context, config headerConfig) context.Context {
return context.WithValue(ctx, HeaderConfigContextKey{}, config)
}

func GetContextWithHeaderConfig(ctx context.Context) headerConfig {
c, ok := ctx.Value(HeaderConfigContextKey{}).(headerConfig)
if !ok {
return nil
}
return c
}

func SetContextWithGatewayConfig(ctx context.Context, config *GatewayConfig) context.Context {
return context.WithValue(ctx, GatewayConfigContextKey{}, config)
}

func GetContextWithGatewayConfig(ctx context.Context) *GatewayConfig {
c, ok := ctx.Value(GatewayConfigContextKey{}).(*GatewayConfig)
if !ok {
return nil
}
return c
}

func SetContextWithProviderConfig(ctx context.Context, name string) context.Context {
return context.WithValue(ctx, HeaderConfigContextKey{}, name)
}

func GetContextWithProviderConfig(ctx context.Context) string {
name, ok := ctx.Value(HeaderConfigContextKey{}).(string)
if !ok {
return ""
}
return name
}
30 changes: 0 additions & 30 deletions protos/proto/llm/models.proto

This file was deleted.

0 comments on commit a7f6508

Please sign in to comment.