Skip to content

Commit

Permalink
refactor: contants and errors
Browse files Browse the repository at this point in the history
  • Loading branch information
pyadav committed Feb 15, 2024
1 parent 29b04b2 commit ef2b619
Show file tree
Hide file tree
Showing 10 changed files with 106 additions and 59 deletions.
33 changes: 33 additions & 0 deletions gateway/internal/constants/contants.go
@@ -0,0 +1,33 @@
package constants

const (
MIMEApplicationJSON = "application/json"
MIMEApplicationProtobuf = "application/protobuf"
MIMEOctetStream = "application/octet-stream"
MIMEApplicationForm = "application/x-www-form-urlencoded"
MIMEMultipartForm = "multipart/form-data"
MIMETextEventStream = "text/event-stream"
MIMETextPlain = "text/plain"
MIMETextHTML = "text/html"
MIMEKeepAlive = "keep-alive"
)

const (
HeaderAuthorization = "Authorization"
HeaderCacheControl = "Cache-Control"
)

const (
XMSAPIKey = "X-MS-Api-Key"
XMSProvider = "X-MS-Provider"
XMSConfig = "X-MS-Config"
XMSCache = "X-MS-Cache"
XMSRequestId = "X-MS-Request-Id"
XMSTraceId = "X-MS-Trace-Id"
XMSRetryCount = "X-MS-Retry-count"
)

const (
XMSRetryAttemptCount = "X-MS-Retry-Attempt-count"
XMSCacheStatus = "X-MS-Cache-Status"
)
16 changes: 16 additions & 0 deletions gateway/internal/errors/errors.go
@@ -0,0 +1,16 @@
package errors

import (
"fmt"

"github.com/missingstudio/studio/backend/internal/constants"
"github.com/missingstudio/studio/common/errors"
)

var (
ErrProviderHeaderNotExit = errors.NewBadRequest(fmt.Sprintf("%s header is required", constants.XMSProvider))
ErrRequiredHeaderNotExit = errors.NewBadRequest(fmt.Sprintf("either %s or %s header is required", constants.XMSProvider, constants.XMSConfig))
ErrRateLimitExceeded = errors.NewForbidden("rate limit exceeded")
ErrUnauthenticated = errors.NewUnauthorized("unauthenticated")
ErrProviderNotFound = errors.NewNotFound("provider is not found")
)
6 changes: 4 additions & 2 deletions gateway/internal/interceptor/auth.go
Expand Up @@ -5,19 +5,21 @@ import (
"log/slog"

"connectrpc.com/connect"
"github.com/missingstudio/studio/backend/internal/constants"
"github.com/missingstudio/studio/backend/internal/errors"
)

// NewAPIKeyInterceptor returns interceptor which is checking if api key exits
func NewAPIKeyInterceptor(logger *slog.Logger) connect.UnaryInterceptorFunc {
return connect.UnaryInterceptorFunc(func(next connect.UnaryFunc) connect.UnaryFunc {
return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
apiHeader := req.Header().Get("X-MS-API-KEY")
apiHeader := req.Header().Get(constants.XMSAPIKey)
if apiHeader == "" {
logger.Info("request without api key",
"api_key", apiHeader,
"addr", req.Peer().Addr,
"endpoint", req.Spec().Procedure)
return nil, ErrUnauthenticated
return nil, errors.ErrUnauthenticated
}

return next(ctx, req)
Expand Down
35 changes: 35 additions & 0 deletions gateway/internal/interceptor/headers.go
@@ -0,0 +1,35 @@
package interceptor

import (
"context"

"connectrpc.com/connect"
"github.com/missingstudio/studio/backend/internal/constants"
"github.com/missingstudio/studio/backend/internal/errors"
"github.com/missingstudio/studio/backend/internal/providers"
)

func ProviderInterceptor() connect.UnaryInterceptorFunc {
interceptor := func(next connect.UnaryFunc) connect.UnaryFunc {
return connect.UnaryFunc(func(
ctx context.Context,
req connect.AnyRequest,
) (connect.AnyResponse, error) {
// Check if required headers are available
provider := req.Header().Get(constants.XMSProvider)
config := req.Header().Get(constants.XMSConfig)
if provider == "" || config == "" {
return nil, errors.ErrRequiredHeaderNotExit
}

// Check if provider has registered of not
_, ok := providers.ProviderFactories[provider]
if !ok {
return nil, errors.ErrProviderNotFound
}

return next(ctx, req)
})
}
return connect.UnaryInterceptorFunc(interceptor)
}
13 changes: 0 additions & 13 deletions gateway/internal/interceptor/interceptor.go

This file was deleted.

24 changes: 0 additions & 24 deletions gateway/internal/interceptor/provider.go

This file was deleted.

3 changes: 2 additions & 1 deletion gateway/internal/interceptor/ratelimit.go
Expand Up @@ -4,6 +4,7 @@ import (
"context"

"connectrpc.com/connect"
"github.com/missingstudio/studio/backend/internal/errors"
"github.com/missingstudio/studio/backend/internal/ratelimiter"
)

Expand All @@ -15,7 +16,7 @@ func RateLimiterInterceptor(rl *ratelimiter.RateLimiter) connect.UnaryIntercepto
) (connect.AnyResponse, error) {
key := "req_count"
if !rl.Limiter.Validate(key) {
return nil, ErrRateLimitExceeded
return nil, errors.ErrRateLimitExceeded
}

return next(ctx, req)
Expand Down
29 changes: 12 additions & 17 deletions gateway/internal/providers/providers.go
Expand Up @@ -2,46 +2,41 @@ package providers

import (
"context"
"fmt"
"net/http"

"github.com/missingstudio/studio/backend/internal/constants"
"github.com/missingstudio/studio/backend/internal/errors"
"github.com/missingstudio/studio/backend/internal/providers/anyscale"
"github.com/missingstudio/studio/backend/internal/providers/azure"
"github.com/missingstudio/studio/backend/internal/providers/base"
"github.com/missingstudio/studio/backend/internal/providers/deepinfra"
"github.com/missingstudio/studio/backend/internal/providers/openai"
"github.com/missingstudio/studio/backend/internal/providers/togetherai"
"github.com/missingstudio/studio/common/errors"
)

var (
ErrProviderHeaderNotExit = errors.New(fmt.Errorf("x-ms-provider provider header not available"))
ErrProviderNotFound = errors.NewNotFound("provider is not found")
)

type ProviderFactory interface {
Create(headers http.Header) (base.ProviderInterface, error)
}

var providerFactories = make(map[string]ProviderFactory)
var ProviderFactories = make(map[string]ProviderFactory)

func init() {
providerFactories["openai"] = openai.OpenAIProviderFactory{}
providerFactories["azure"] = azure.AzureProviderFactory{}
providerFactories["anyscale"] = anyscale.AnyscaleProviderFactory{}
providerFactories["deepinfra"] = deepinfra.DeepinfraProviderFactory{}
providerFactories["togetherai"] = togetherai.TogetherAIProviderFactory{}
ProviderFactories["openai"] = openai.OpenAIProviderFactory{}
ProviderFactories["azure"] = azure.AzureProviderFactory{}
ProviderFactories["anyscale"] = anyscale.AnyscaleProviderFactory{}
ProviderFactories["deepinfra"] = deepinfra.DeepinfraProviderFactory{}
ProviderFactories["togetherai"] = togetherai.TogetherAIProviderFactory{}
}

func GetProvider(ctx context.Context, headers http.Header) (base.ProviderInterface, error) {
providerName := headers.Get("x-ms-provider")
providerName := headers.Get(constants.XMSProvider)
if providerName == "" {
return nil, ErrProviderHeaderNotExit
return nil, errors.ErrProviderHeaderNotExit
}

providerFactory, ok := providerFactories[providerName]
providerFactory, ok := ProviderFactories[providerName]
if !ok {
return nil, ErrProviderNotFound
return nil, errors.ErrProviderNotFound
}

return providerFactory.Create(headers)
Expand Down
3 changes: 2 additions & 1 deletion gateway/main_test.go
Expand Up @@ -10,6 +10,7 @@ import (
"connectrpc.com/connect"

v1 "github.com/missingstudio/studio/backend/internal/api/v1"
"github.com/missingstudio/studio/backend/internal/errors"
"github.com/missingstudio/studio/backend/internal/interceptor"
llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
"github.com/missingstudio/studio/protos/pkg/llm/llmv1connect"
Expand Down Expand Up @@ -54,7 +55,7 @@ func TestGatewayServer(t *testing.T) {
_, err := client.ChatCompletions(context.Background(), req)

require.NotNil(t, err)
assert.True(t, strings.Contains(err.Error(), interceptor.ErrProviderHeaderNotExit.Error()))
assert.True(t, strings.Contains(err.Error(), errors.ErrRequiredHeaderNotExit.Error()))
}
})
}
3 changes: 2 additions & 1 deletion gateway/pkg/utils/headers.go
Expand Up @@ -10,6 +10,7 @@ import (

"connectrpc.com/connect"
"github.com/go-playground/validator/v10"
"github.com/missingstudio/studio/backend/internal/constants"
"github.com/missingstudio/studio/common/errors"
)

Expand All @@ -20,7 +21,7 @@ func isJSON(s string, v interface{}) bool {
}

func UnmarshalConfigHeaders(header http.Header, v interface{}) error {
msconfig := header.Get("x-ms-provider")
msconfig := header.Get(constants.XMSProvider)
if msconfig == "" && isJSON(msconfig, v) {
return ErrGatewayConfigHeaderNotValid
}
Expand Down

0 comments on commit ef2b619

Please sign in to comment.