Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add provider configurations and validation using json schema #6

Merged
merged 3 commits into from
Feb 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion gateway/cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ import (

"github.com/missingstudio/studio/backend/config"
"github.com/missingstudio/studio/backend/internal/api"
"github.com/missingstudio/studio/backend/internal/connections"
"github.com/missingstudio/studio/backend/internal/ingester"
"github.com/missingstudio/studio/backend/internal/providers"
"github.com/missingstudio/studio/backend/internal/ratelimiter"
"github.com/missingstudio/studio/backend/internal/server"
"github.com/missingstudio/studio/common/logger"
Expand All @@ -36,7 +38,11 @@ func Serve(cfg *config.Config) error {

rl := ratelimiter.NewRateLimiter(cfg.Ratelimiter, logger, cfg.Ratelimiter.Type, rdb)
ingester := ingester.GetIngesterWithDefault(ctx, cfg.Ingester, logger)
deps := api.NewDeps(logger, ingester, rl)

providerService := providers.NewService()
connectionService := connections.NewService()

deps := api.NewDeps(logger, ingester, rl, providerService, connectionService)

if err := server.Serve(ctx, logger, cfg.App, deps); err != nil {
logger.Error("error starting server", "error", err)
Expand Down
6 changes: 4 additions & 2 deletions gateway/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ require (
github.com/MakeNowJust/heredoc v1.0.0
github.com/go-playground/validator/v10 v10.18.0
github.com/google/go-cmp v0.6.0
github.com/google/uuid v1.4.0
github.com/hashicorp/consul/api v1.25.1
github.com/mcuadros/go-defaults v1.2.0
github.com/missingstudio/studio/common v0.0.0-00010101000000-000000000000
Expand All @@ -29,15 +30,18 @@ require (
github.com/spf13/cobra v1.8.0
github.com/spf13/pflag v1.0.5
github.com/stretchr/testify v1.8.4
github.com/xeipuuv/gojsonschema v1.2.0
github.com/zeebo/assert v1.3.1
golang.org/x/net v0.21.0
golang.org/x/text v0.14.0
google.golang.org/protobuf v1.32.0
gopkg.in/yaml.v2 v2.4.0
)

require (
buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.31.0-20230824200731-b9b8148056b9.1 // indirect
github.com/Azure/go-ansiterm v0.0.0-20170929234023-d6e3b3328b78 // indirect
github.com/Jeffail/gabs/v2 v2.7.0
github.com/Microsoft/go-winio v0.6.0 // indirect
github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect
github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230512164433-5d1fd1a340c9 // indirect
Expand Down Expand Up @@ -107,7 +111,6 @@ require (
github.com/subosito/gotenv v1.6.0 // indirect
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect
github.com/xeipuuv/gojsonschema v1.2.0 // indirect
github.com/zeebo/xxh3 v1.0.2 // indirect
go.opentelemetry.io/otel v1.19.0 // indirect
go.opentelemetry.io/otel/metric v1.19.0 // indirect
Expand All @@ -124,7 +127,6 @@ require (
google.golang.org/genproto/googleapis/api v0.0.0-20240102182953-50ed04b92917 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20231212172506-995d672761c0 // indirect
google.golang.org/grpc v1.61.0 // indirect
google.golang.org/protobuf v1.32.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
2 changes: 2 additions & 0 deletions gateway/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03
github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ=
github.com/InfluxCommunity/influxdb3-go v0.5.0 h1:s79+Sw1CFDG5xzaqZGK9gLkkfPffWUwN6cEerAGZfRU=
github.com/InfluxCommunity/influxdb3-go v0.5.0/go.mod h1:jodr5YDf5zQANV+N2bIaYWrW9J5epnnGVLJOJl005lM=
github.com/Jeffail/gabs/v2 v2.7.0 h1:Y2edYaTcE8ZpRsR2AtmPu5xQdFDIthFG0jYhu5PY8kg=
github.com/Jeffail/gabs/v2 v2.7.0/go.mod h1:dp5ocw1FvBBQYssgHsG7I1WYsiLRtkUaB1FEtSwvNUw=
github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ=
github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE=
github.com/Microsoft/go-winio v0.6.0 h1:slsWYD/zyx7lCXoZVlvQrj0hPTM1HI4+v1sIda2yDvg=
Expand Down
28 changes: 20 additions & 8 deletions gateway/internal/api/deps.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,32 @@ package api
import (
"log/slog"

"github.com/missingstudio/studio/backend/internal/connections"
"github.com/missingstudio/studio/backend/internal/ingester"
"github.com/missingstudio/studio/backend/internal/providers"
"github.com/missingstudio/studio/backend/internal/ratelimiter"
)

type Deps struct {
Logger *slog.Logger
Ingester ingester.Ingester
RateLimiter *ratelimiter.RateLimiter
Logger *slog.Logger
Ingester ingester.Ingester
RateLimiter *ratelimiter.RateLimiter
ProviderService *providers.Service
ConnectionService *connections.Service
}

func NewDeps(logger *slog.Logger, ingester ingester.Ingester, ratelimiter *ratelimiter.RateLimiter) Deps {
return Deps{
Logger: logger,
Ingester: ingester,
RateLimiter: ratelimiter,
func NewDeps(
logger *slog.Logger,
ingester ingester.Ingester,
ratelimiter *ratelimiter.RateLimiter,
ps *providers.Service,
cs *connections.Service,
) *Deps {
return &Deps{
Logger: logger,
Ingester: ingester,
RateLimiter: ratelimiter,
ProviderService: ps,
ConnectionService: cs,
}
}
35 changes: 26 additions & 9 deletions gateway/internal/api/v1/chatcompletions.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/missingstudio/studio/backend/internal/constants"
"github.com/missingstudio/studio/backend/internal/providers"
"github.com/missingstudio/studio/backend/internal/providers/base"
"github.com/missingstudio/studio/backend/models"
"github.com/missingstudio/studio/backend/pkg/utils"
"github.com/missingstudio/studio/common/errors"
llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
Expand All @@ -25,26 +26,42 @@ func (s *V1Handler) GetChatCompletions(
) (*connect.Response[llmv1.ChatCompletionResponse], error) {
startTime := time.Now()

providerName := req.Header().Get(constants.XMSProvider)
provider, err := providers.NewProvider(providerName, req.Header())
payload, err := json.Marshal(req.Msg)
if err != nil {
return nil, errors.New(err)
}

if err := provider.Validate(); err != nil {
// Convert headers into map[string]any
headerConfig := make(map[string]any)
for key, values := range req.Header() {
if len(values) > 0 {
headerConfig[key] = values[0]
}
}

providerName := req.Header().Get(constants.XMSProvider)
connectionObj := models.Connection{}
connectionObj.Name = providerName
connectionObj.Headers = headerConfig

provider, err := s.providerService.GetProvider(connectionObj)
if err != nil {
return nil, errors.New(err)
}

// Validate provider configs
err = providers.Validate(provider, map[string]any{
"headers": headerConfig,
})
if err != nil {
return nil, errors.NewBadRequest(err.Error())
}

chatCompletionProvider, ok := provider.(base.ChatCompletionInterface)
if !ok {
return nil, ErrChatCompletionNotSupported
}

payload, err := json.Marshal(req.Msg)
if err != nil {
return nil, errors.New(err)
}

resp, err := chatCompletionProvider.ChatCompletion(ctx, payload)
if err != nil {
return nil, errors.New(err)
Expand All @@ -57,7 +74,7 @@ func (s *V1Handler) GetChatCompletions(
}

ingesterdata := make(map[string]interface{})
ingesterdata["provider"] = provider.GetName()
ingesterdata["provider"] = provider.Name()
ingesterdata["model"] = data.Model
ingesterdata["latency"] = latency
ingesterdata["total_tokens"] = *data.Usage.TotalTokens
Expand Down
13 changes: 6 additions & 7 deletions gateway/internal/api/v1/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,21 @@ import (
"context"

"connectrpc.com/connect"
"github.com/missingstudio/studio/backend/internal/providers"
"github.com/missingstudio/studio/backend/models"
llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
)

func (s *V1Handler) ListModels(ctx context.Context, req *connect.Request[llmv1.ModelRequest]) (*connect.Response[llmv1.ModelResponse], error) {
plist := providers.Providers
allProviderModels := map[string]*llmv1.ProviderModels{}

for _, p := range plist {
providerfactory, err := providers.NewProvider(p, req.Header())
for name := range models.ProviderRegistry {
provider, err := s.providerService.GetProvider(models.Connection{Name: name})
if err != nil {
continue
}

providerName := providerfactory.GetName()
providerModels := providerfactory.GetModels()
providerName := provider.Name()
providerModels := provider.Models()

var models []*llmv1.Model
for _, val := range providerModels {
Expand All @@ -29,7 +28,7 @@ func (s *V1Handler) ListModels(ctx context.Context, req *connect.Request[llmv1.M
})
}

allProviderModels[p] = &llmv1.ProviderModels{
allProviderModels[name] = &llmv1.ProviderModels{
Name: providerName,
Models: models,
}
Expand Down
24 changes: 24 additions & 0 deletions gateway/internal/api/v1/providers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package v1

import (
"context"

"connectrpc.com/connect"
llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
"google.golang.org/protobuf/types/known/emptypb"
)

func (s *V1Handler) ListProviders(ctx context.Context, req *connect.Request[emptypb.Empty]) (*connect.Response[llmv1.ProvidersResponse], error) {
providers := s.providerService.GetProviders()

data := []*llmv1.Provider{}
for name := range providers {
data = append(data, &llmv1.Provider{
Name: name,
})
}

return connect.NewResponse(&llmv1.ProvidersResponse{
Providers: data,
}), nil
}
16 changes: 11 additions & 5 deletions gateway/internal/api/v1/v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,35 @@ import (
"connectrpc.com/validate"
"connectrpc.com/vanguard"
"github.com/missingstudio/studio/backend/internal/api"
"github.com/missingstudio/studio/backend/internal/connections"
"github.com/missingstudio/studio/backend/internal/ingester"
"github.com/missingstudio/studio/backend/internal/interceptor"
"github.com/missingstudio/studio/backend/internal/providers"
"github.com/missingstudio/studio/protos/pkg/llm/llmv1connect"
)

type V1Handler struct {
llmv1connect.UnimplementedLLMServiceHandler
ingester ingester.Ingester
ingester ingester.Ingester
providerService *providers.Service
connectionService *connections.Service
}

func NewHandlerV1(ingester ingester.Ingester) *V1Handler {
func NewHandlerV1(d *api.Deps) *V1Handler {
return &V1Handler{
ingester: ingester,
ingester: d.Ingester,
providerService: d.ProviderService,
connectionService: d.ConnectionService,
}
}

func Register(d api.Deps) (http.Handler, error) {
func Register(d *api.Deps) (http.Handler, error) {
validateInterceptor, err := validate.NewInterceptor()
if err != nil {
return nil, fmt.Errorf("failed to create validate interceptor: %w", err)
}

v1Handler := NewHandlerV1(d.Ingester)
v1Handler := NewHandlerV1(d)
otelconnectInterceptor, err := otelconnect.NewInterceptor(otelconnect.WithTrustRemote())
if err != nil {
return nil, fmt.Errorf("failed to create validate otel connect: %w", err)
Expand Down
7 changes: 7 additions & 0 deletions gateway/internal/connections/service.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package connections

type Service struct{}

func NewService() *Service {
return &Service{}
}
2 changes: 1 addition & 1 deletion gateway/internal/connectrpc/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"github.com/missingstudio/studio/protos/pkg/llm/llmv1connect"
)

func NewConnectMux(d api.Deps) (*http.ServeMux, error) {
func NewConnectMux(d *api.Deps) (*http.ServeMux, error) {
mux := http.NewServeMux()

compress1KB := connect.WithCompressMinBytes(1024)
Expand Down
18 changes: 9 additions & 9 deletions gateway/internal/mock/mock_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,26 @@ package mock

import "github.com/missingstudio/studio/backend/internal/providers/base"

var _ base.ProviderInterface = &providerMock{}
var _ base.IProvider = &providerMock{}

type providerMock struct {
Name string
name string
}

func NewProviderMock(name string) base.ProviderInterface {
func NewProviderMock(name string) base.IProvider {
return &providerMock{
Name: name,
name: name,
}
}

func (p providerMock) GetName() string {
return p.Name
func (p providerMock) Name() string {
return p.name
}

func (p providerMock) Validate() error {
return nil
func (p providerMock) Schema() []byte {
return []byte{}
}

func (*providerMock) GetModels() []string {
func (p providerMock) Models() []string {
return []string{}
}
16 changes: 13 additions & 3 deletions gateway/internal/providers/anyscale/anyscale.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,31 @@ import (
"fmt"
"net/http"

"github.com/missingstudio/studio/backend/models"
"github.com/missingstudio/studio/backend/pkg/requester"
)

func (anyscale *anyscaleProvider) ChatCompletion(ctx context.Context, payload []byte) (*http.Response, error) {
client := requester.NewHTTPClient()
requestURL := fmt.Sprintf("%s%s", anyscale.Config.BaseURL, anyscale.Config.ChatCompletions)
requestURL := fmt.Sprintf("%s%s", anyscale.config.BaseURL, anyscale.config.ChatCompletions)
req, _ := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewReader(payload))

connectionConfigMap := anyscale.conn.GetHeaders([]string{
models.AuthorizationHeader,
})

var authorizationHeader string
if val, ok := connectionConfigMap[models.AuthorizationHeader].(string); ok && val != "" {
authorizationHeader = val
}

req.Header.Add("Content-Type", "application/json")
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", anyscale.APIKey))
req.Header.Add("Authorization", authorizationHeader)

return client.SendRequestRaw(req)
}

func (anyscale *anyscaleProvider) GetModels() []string {
func (anyscale *anyscaleProvider) Models() []string {
return []string{
"meta-llama/Llama-2-7b-chat-hf",
"meta-llama/Llama-2-13b-chat-hf",
Expand Down
Loading