Skip to content

Commit

Permalink
feat(gateway): add gateway specific config using x-ms-config header
Browse files Browse the repository at this point in the history
Gatewaty specific config can be added using  header by passing a stringify json

For ex.
x-ms-config: {retry: { times: 2 }}
  • Loading branch information
pyadav committed Feb 5, 2024
1 parent ccc550f commit ae14bdc
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 11 deletions.
1 change: 1 addition & 0 deletions gateway/config/constant.go
Expand Up @@ -2,6 +2,7 @@ package config

const (
XMSProvider = "x-ms-provider"
XMSConfig = "x-ms-config"
Authorization = "Authorization"
)

Expand Down
13 changes: 12 additions & 1 deletion gateway/internal/interceptor/interceptor.go
Expand Up @@ -6,6 +6,8 @@ import (

"connectrpc.com/connect"
"github.com/missingstudio/studio/backend/config"
"github.com/missingstudio/studio/backend/internal/schema"
"github.com/missingstudio/studio/backend/pkg/utils"
"github.com/missingstudio/studio/common/resilience/retry"
)

Expand Down Expand Up @@ -48,7 +50,16 @@ func RetryInterceptor() connect.UnaryInterceptorFunc {
var err error
var response connect.AnyResponse

runner := retry.New(retry.Config{})
data := &schema.GatewayConfigHeaders{}
err = utils.UnmarshalConfigHeaders(req.Header(), data)
if err != nil {
return nil, err
}

runner := retry.New(retry.Config{
Times: int(data.Retry.Times),
})

err = runner.Run(ctx, func(ctx context.Context) error {
response, err = next(ctx, req)
if err != nil {
Expand Down
35 changes: 35 additions & 0 deletions gateway/internal/schema/gateway.go
@@ -0,0 +1,35 @@
package schema

import (
"time"
)

type CacheConfig struct {
// Mode specifies the type of cache with two possible modes: simple and semantic.
Mode string `json:"mode" default:"simple"`
// TTL (Time To Live) is the duration for which cache entries should be considered valid.
TTL time.Duration `json:"ttl"`
}

type RetryConfig struct {
Times int32 `json:"times" default:"1"`
// Status codes for retry
OnStatusCodes []string `json:"on_status_codes"`
}

type StrategyConfig struct {
Mode string `json:"mode" default:"fallback"`
// Status codes for retry
OnStatusCodes []string `json:"on_status_codes"`
}

type GatewayConfigHeaders struct {
// Virtual key is temporary key with configurations for the gateway
Provider string `json:"provider"`
VirtualKey string `json:"virtual_key"`
// Cache represents the cache configuration for the gateway.
Cache CacheConfig `json:"cache"`
Retry RetryConfig `json:"retry"`
Strategy StrategyConfig `json:"strategy"`
Providers []any `json:"providers"`
}
58 changes: 48 additions & 10 deletions gateway/pkg/utils/headers.go
@@ -1,33 +1,71 @@
package utils

import (
"encoding/json"
"fmt"
"net/http"
"reflect"
"strconv"
"strings"

"github.com/go-playground/validator/v10"
"github.com/missingstudio/studio/backend/config"
"github.com/missingstudio/studio/common/errors"
)

func isJSON(s string, v interface{}) bool {
return json.Unmarshal([]byte(s), v) == nil
}

func UnmarshalConfigHeaders(header http.Header, v interface{}) error {
msconfig := header.Get(config.XMSConfig)
if msconfig == "" && isJSON(msconfig, v) {
return errors.New(fmt.Errorf("x-ms-config header is not valid"))
}
return nil
}

// UnmarshalHeader unmarshals an http.Header into a struct
func UnmarshalHeader(header http.Header, v interface{}) error {
// Iterate over the fields in the struct
for i := 0; i < reflect.TypeOf(v).Elem().NumField(); i++ {
field := reflect.TypeOf(v).Elem().Field(i)
tag := field.Tag.Get("json") // Get the tag value

// If the tag is not empty, try to get the corresponding value from the header
if tag != "" {
value := header.Get(tag)
// Set the value in the struct field
reflect.ValueOf(v).Elem().Field(i).SetString(value)
fields := reflect.ValueOf(v).Elem()

for i := 0; i < fields.NumField(); i++ {
field := fields.Type().Field(i)
headerKey := field.Tag.Get("json")
defaultValue := field.Tag.Get("default")

if headerValue := header.Get(headerKey); headerValue != "" {
setFieldValue(fields.Field(i), headerValue)
} else if defaultValue != "" {
setFieldValue(fields.Field(i), defaultValue)
}
}

return nil
}

func setFieldValue(field reflect.Value, value string) {
switch field.Kind() {
case reflect.String:
field.SetString(value)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
intValue, err := strconv.ParseInt(value, 10, 64)
if err == nil {
field.SetInt(intValue)
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
uintValue, err := strconv.ParseUint(value, 10, 64)
if err == nil {
field.SetUint(uintValue)
}
case reflect.Float32, reflect.Float64:
floatValue, err := strconv.ParseFloat(value, 64)
if err == nil {
field.SetFloat(floatValue)
}
}
}

// ValidateHeaders is a generic function to validate any structure with the `validate` struct tag.
func ValidateHeaders(data interface{}) error {
validate := validator.New()
Expand Down

0 comments on commit ae14bdc

Please sign in to comment.