Skip to content

Commit

Permalink
feat(gateway): add additional provider API support
Browse files Browse the repository at this point in the history
- get provider by id
- get provider configs

Signed-off-by: Praveen Yadav <pyadav9678@gmail.com>
  • Loading branch information
pyadav committed Feb 19, 2024
1 parent 84ef294 commit 1087447
Show file tree
Hide file tree
Showing 17 changed files with 847 additions and 163 deletions.
12 changes: 7 additions & 5 deletions gateway/internal/api/v1/chatcompletions.go
Expand Up @@ -40,10 +40,10 @@ func (s *V1Handler) GetChatCompletions(
}

providerName := req.Header().Get(constants.XMSProvider)
connectionObj := models.Connection{}
connectionObj.Name = providerName
connectionObj.Headers = headerConfig

connectionObj := models.Connection{
Name: providerName,
Headers: headerConfig,
}
provider, err := s.providerService.GetProvider(connectionObj)
if err != nil {
return nil, errors.New(err)
Expand Down Expand Up @@ -74,7 +74,9 @@ func (s *V1Handler) GetChatCompletions(
}

ingesterdata := make(map[string]interface{})
ingesterdata["provider"] = provider.Name()
providerInfo := provider.Info()

ingesterdata["provider"] = providerInfo.Name
ingesterdata["model"] = data.Model
ingesterdata["latency"] = latency
ingesterdata["total_tokens"] = *data.Usage.TotalTokens
Expand Down
3 changes: 2 additions & 1 deletion gateway/internal/api/v1/models.go
Expand Up @@ -17,8 +17,9 @@ func (s *V1Handler) ListModels(ctx context.Context, req *connect.Request[llmv1.M
continue
}

providerName := provider.Name()
providerInfo := provider.Info()
providerModels := provider.Models()
providerName := providerInfo.Name

var models []*llmv1.Model
for _, val := range providerModels {
Expand Down
50 changes: 48 additions & 2 deletions gateway/internal/api/v1/providers.go
Expand Up @@ -2,23 +2,69 @@ package v1

import (
"context"
"encoding/json"

"connectrpc.com/connect"
"github.com/missingstudio/studio/backend/models"
"github.com/missingstudio/studio/common/errors"
llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/structpb"
)

func (s *V1Handler) ListProviders(ctx context.Context, req *connect.Request[emptypb.Empty]) (*connect.Response[llmv1.ProvidersResponse], error) {
providers := s.providerService.GetProviders()

data := []*llmv1.Provider{}
for name := range providers {
for _, provider := range providers {
providerInfo := provider.Info()
data = append(data, &llmv1.Provider{
Name: name,
Title: providerInfo.Title,
Name: providerInfo.Name,
Description: providerInfo.Description,
})
}

return connect.NewResponse(&llmv1.ProvidersResponse{
Providers: data,
}), nil
}

func (s *V1Handler) GetProviderById(ctx context.Context, req *connect.Request[llmv1.GetProviderRequest]) (*connect.Response[llmv1.GetProviderResponse], error) {
provider, err := s.providerService.GetProvider(models.Connection{Name: req.Msg.Id})
if err != nil {
return nil, errors.NewNotFound(err.Error())
}

info := provider.Info()
p := &llmv1.Provider{
Title: info.Title,
Name: info.Name,
Description: info.Description,
}

return connect.NewResponse(&llmv1.GetProviderResponse{
Provider: p,
}), nil
}

func (s *V1Handler) GetProviderConfig(ctx context.Context, req *connect.Request[llmv1.GetProviderConfigRequest]) (*connect.Response[llmv1.GetProviderConfigResponse], error) {
provider, err := s.providerService.GetProvider(models.Connection{Name: req.Msg.Id})
if err != nil {
return nil, errors.NewNotFound(err.Error())
}

configs := map[string]any{}
if err := json.Unmarshal(provider.Schema(), &configs); err != nil {
return nil, errors.NewInternalError(err.Error())
}

stConfigs, err := structpb.NewStruct(configs)
if err != nil {
return nil, errors.NewInternalError(err.Error())
}

return connect.NewResponse(&llmv1.GetProviderConfigResponse{
Config: stConfigs,
}), nil
}
13 changes: 9 additions & 4 deletions gateway/internal/mock/mock_provider.go
Expand Up @@ -5,17 +5,22 @@ import "github.com/missingstudio/studio/backend/internal/providers/base"
var _ base.IProvider = &providerMock{}

type providerMock struct {
name string
info base.ProviderInfo
config base.ProviderConfig
}

func NewProviderMock(name string) base.IProvider {
return &providerMock{
name: name,
info: base.ProviderInfo{Name: name},
}
}

func (p providerMock) Name() string {
return p.name
func (p providerMock) Info() base.ProviderInfo {
return p.info
}

func (p providerMock) Config() base.ProviderConfig {
return p.config
}

func (p providerMock) Schema() []byte {
Expand Down
21 changes: 17 additions & 4 deletions gateway/internal/providers/anyscale/base.go
Expand Up @@ -13,19 +13,32 @@ var schema []byte
var _ base.IProvider = &anyscaleProvider{}

type anyscaleProvider struct {
name string
info base.ProviderInfo
config base.ProviderConfig
conn models.Connection
}

func (anyscale anyscaleProvider) Name() string {
return anyscale.name
func (anyscale anyscaleProvider) Info() base.ProviderInfo {
return anyscale.info
}

func (anyscale anyscaleProvider) Config() base.ProviderConfig {
return anyscale.config
}

func (anyscale anyscaleProvider) Schema() []byte {
return schema
}

func getAnyscaleInfo() base.ProviderInfo {
return base.ProviderInfo{
Title: "Anyscale",
Name: "anyscale",
Description: `Anyscale Endpoints is a fast and scalable API to integrate OSS LLMs into your app.
Use our growing list of high performance models or deploy your own.`,
}
}

func getAnyscaleConfig(baseURL string) base.ProviderConfig {
return base.ProviderConfig{
BaseURL: baseURL,
Expand All @@ -37,7 +50,7 @@ func init() {
models.ProviderRegistry["anyscale"] = func(connection models.Connection) base.IProvider {
config := getAnyscaleConfig("https://api.endpoints.anyscale.com")
return &anyscaleProvider{
name: "Anyscale",
info: getAnyscaleInfo(),
config: config,
conn: connection,
}
Expand Down
29 changes: 16 additions & 13 deletions gateway/internal/providers/azure/base.go
Expand Up @@ -13,33 +13,36 @@ var schema []byte
var _ base.IProvider = &azureProvider{}

type azureProvider struct {
name string
info base.ProviderInfo
config base.ProviderConfig
conn models.Connection
}

func (az azureProvider) Name() string {
return az.name
func (anyscale azureProvider) Info() base.ProviderInfo {
return anyscale.info
}

func (az azureProvider) Config() base.ProviderConfig {
return az.config
}

func (az azureProvider) Schema() []byte {
return schema
}

func getAzureInfo() base.ProviderInfo {
return base.ProviderInfo{
Title: "Azure",
Name: "azure",
Description: "Azure OpenAI Service offers industry-leading coding and language AI models that you can fine-tune to your specific needs for a variety of use cases.",
}
}

func getAzureConfig() base.ProviderConfig {
return base.ProviderConfig{
BaseURL: "",
ChatCompletions: "/chat/completions",
}
}

func init() {
models.ProviderRegistry["azure"] = func(connection models.Connection) base.IProvider {
config := getAzureConfig()
return &azureProvider{
name: "Azure",
config: config,
conn: connection,
}
}
}
func init() {}
8 changes: 7 additions & 1 deletion gateway/internal/providers/base/base.go
Expand Up @@ -9,9 +9,15 @@ type ProviderConfig struct {
BaseURL string
ChatCompletions string
}
type ProviderInfo struct {
Title string
Name string
Description string
}

type IProvider interface {
Name() string
Info() ProviderInfo
Config() ProviderConfig
Models() []string
Schema() []byte
}
Expand Down
21 changes: 17 additions & 4 deletions gateway/internal/providers/deepinfra/base.go
Expand Up @@ -13,19 +13,32 @@ var schema []byte
var _ base.IProvider = &deepinfraProvider{}

type deepinfraProvider struct {
name string
info base.ProviderInfo
config base.ProviderConfig
conn models.Connection
}

func (deepinfra deepinfraProvider) Name() string {
return deepinfra.name
func (anyscale deepinfraProvider) Info() base.ProviderInfo {
return anyscale.info
}

func (deepinfra deepinfraProvider) Config() base.ProviderConfig {
return deepinfra.config
}

func (deepinfra deepinfraProvider) Schema() []byte {
return schema
}

func getDeepinfraInfo() base.ProviderInfo {
return base.ProviderInfo{
Title: "Deepinfra",
Name: "deepinfra",
Description: `Deep Infra offers 100+ machine learning models from Text-to-Image, Object-Detection,
Automatic-Speech-Recognition, Text-to-Text Generation, and more!`,
}
}

func getDeepinfraConfig(baseURL string) base.ProviderConfig {
return base.ProviderConfig{
BaseURL: baseURL,
Expand All @@ -37,7 +50,7 @@ func init() {
models.ProviderRegistry["deepinfra"] = func(connection models.Connection) base.IProvider {
config := getDeepinfraConfig("https://api.deepinfra.com/v1/openai")
return &deepinfraProvider{
name: "Deepinfra",
info: getDeepinfraInfo(),
config: config,
conn: connection,
}
Expand Down
20 changes: 16 additions & 4 deletions gateway/internal/providers/openai/base.go
Expand Up @@ -13,19 +13,31 @@ var schema []byte
var _ base.IProvider = &openAIProvider{}

type openAIProvider struct {
name string
info base.ProviderInfo
config base.ProviderConfig
conn models.Connection
}

func (oai openAIProvider) Name() string {
return oai.name
func (anyscale openAIProvider) Info() base.ProviderInfo {
return anyscale.info
}

func (oai openAIProvider) Config() base.ProviderConfig {
return oai.config
}

func (oai openAIProvider) Schema() []byte {
return schema
}

func getOpenAIInfo() base.ProviderInfo {
return base.ProviderInfo{
Title: "OpenAI",
Name: "openai",
Description: `OpenAI API platform offers latest models and guides for safety best practices.`,
}
}

func getOpenAIConfig(baseURL string) base.ProviderConfig {
return base.ProviderConfig{
BaseURL: baseURL,
Expand All @@ -37,7 +49,7 @@ func init() {
models.ProviderRegistry["openai"] = func(connection models.Connection) base.IProvider {
config := getOpenAIConfig("https://api.openai.com")
return &openAIProvider{
name: "OpenAI",
info: getOpenAIInfo(),
config: config,
conn: connection,
}
Expand Down
20 changes: 16 additions & 4 deletions gateway/internal/providers/togetherai/base.go
Expand Up @@ -13,19 +13,31 @@ var schema []byte
var _ base.IProvider = &togetherAIProvider{}

type togetherAIProvider struct {
name string
info base.ProviderInfo
config base.ProviderConfig
conn models.Connection
}

func (togetherAI togetherAIProvider) Name() string {
return togetherAI.name
func (anyscale togetherAIProvider) Info() base.ProviderInfo {
return anyscale.info
}

func (togetherAI togetherAIProvider) Config() base.ProviderConfig {
return togetherAI.config
}

func (togetherAI togetherAIProvider) Schema() []byte {
return schema
}

func getTogetherAIInfo() base.ProviderInfo {
return base.ProviderInfo{
Title: "Together AI",
Name: "togetherai",
Description: `Build gen AI models with Together AI. Benefit from the fastest and most cost-efficient tools and infra.`,
}
}

func getTogetherAIConfig(baseURL string) base.ProviderConfig {
return base.ProviderConfig{
BaseURL: baseURL,
Expand All @@ -37,7 +49,7 @@ func init() {
models.ProviderRegistry["togetherai"] = func(connection models.Connection) base.IProvider {
config := getTogetherAIConfig("https://api.together.xyz")
return &togetherAIProvider{
name: "Together AI",
info: getTogetherAIInfo(),
config: config,
conn: connection,
}
Expand Down

0 comments on commit 1087447

Please sign in to comment.