Skip to content

Commit

Permalink
refactor: validation process of required headers for providers.
Browse files Browse the repository at this point in the history
  • Loading branch information
pyadav committed Jan 31, 2024
1 parent d5226bb commit 6e27d37
Show file tree
Hide file tree
Showing 9 changed files with 117 additions and 33 deletions.
1 change: 1 addition & 0 deletions .gitignore
Expand Up @@ -2,6 +2,7 @@ bin
node_modules
.DS_Store
*.log
.vscode

.env
.env.*
Expand Down
2 changes: 2 additions & 0 deletions mobius/.gitignore
@@ -1,7 +1,9 @@
bin
.DS_Store

.env
.env.*
config.yaml
*.log

node_modules
Expand Down
6 changes: 3 additions & 3 deletions mobius/internal/api/v1/chatcompletions.go
Expand Up @@ -14,14 +14,14 @@ func (s *V1Handler) ChatCompletions(
ctx context.Context,
req *connect.Request[llmv1.CompletionRequest],
) (*connect.Response[llmv1.CompletionResponse], error) {
provider, err := providers.GetProvider(ctx)
provider, err := providers.GetProvider(ctx, req.Header())
if err != nil {
return nil, errors.NewNotFound("provider not found")
return nil, errors.New(err)
}

completionProvider, ok := provider.(base.ChatCompilationInterface)
if !ok {
return nil, errors.NewInternalError("not able to get chat compilation provider")
return nil, errors.NewInternalError("provider don't have chat compilation capabilities")
}

data, err := completionProvider.ChatCompilation(ctx, req.Msg)
Expand Down
12 changes: 12 additions & 0 deletions mobius/internal/providers/azure/azure.go
@@ -0,0 +1,12 @@
package azure

import (
"context"
"errors"

llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
)

func (az *AzureProvider) ChatCompilation(ctx context.Context, cr *llmv1.CompletionRequest) (*llmv1.CompletionResponse, error) {
return nil, errors.New("Not yet implemented")
}
56 changes: 56 additions & 0 deletions mobius/internal/providers/azure/base.go
@@ -0,0 +1,56 @@
package azure

import (
"net/http"
"strings"

"github.com/missingstudio/studio/backend/config"
"github.com/missingstudio/studio/backend/internal/providers/base"
"github.com/missingstudio/studio/common/errors"
)

type AzureProviderFactory struct{}

func (f AzureProviderFactory) Create(headers http.Header) (base.ProviderInterface, error) {
authorization := headers.Get(config.Authorization)
if authorization == "" {
return nil, errors.NewBadRequest("authorization header is required")
}

authorizationKey := strings.Replace(authorization, "Bearer ", "", 1)
azureProvider := NewazureProvider(authorizationKey)
return azureProvider, nil
}

type AzureHeaders struct {
APIKey string
}

type AzureProvider struct {
Name string
Config base.ProviderConfig
AzureHeaders
}

func NewazureProvider(apikey string) *AzureProvider {
config := getAzureConfig()

return &AzureProvider{
Name: "Azure AI",
AzureHeaders: AzureHeaders{
APIKey: apikey,
},
Config: config,
}
}

func (az *AzureProvider) GetName() string {
return az.Name
}

func getAzureConfig() base.ProviderConfig {
return base.ProviderConfig{
BaseURL: "",
ChatCompletions: "/chat/completions",
}
}
4 changes: 3 additions & 1 deletion mobius/internal/providers/base/base.go
Expand Up @@ -11,7 +11,9 @@ type ProviderConfig struct {
ChatCompletions string
}

type ProviderInterface interface{}
type ProviderInterface interface {
GetName() string
}

type ChatCompilationInterface interface {
ProviderInterface
Expand Down
37 changes: 30 additions & 7 deletions mobius/internal/providers/openai/base.go
@@ -1,28 +1,51 @@
package openai

import (
"net/http"
"strings"

"github.com/missingstudio/studio/backend/config"
"github.com/missingstudio/studio/backend/internal/providers/base"
"github.com/missingstudio/studio/common/errors"
)

type OpenAIProvider struct {
type OpenAIProviderFactory struct{}

func (f OpenAIProviderFactory) Create(headers http.Header) (base.ProviderInterface, error) {
authorization := headers.Get(config.Authorization)
if authorization == "" {
return nil, errors.NewBadRequest("authorization header is required")
}

authorizationKey := strings.Replace(authorization, "Bearer ", "", 1)
openAIProvider := NewOpenAIProvider(authorizationKey, "https://api.openai.com")
return openAIProvider, nil
}

type OpenAIHeaders struct {
APIKey string
}

type OpenAIProvider struct {
Name string
Config base.ProviderConfig
OpenAIHeaders
}

func NewOpenAIProvider(apikey string, baseURL string) *OpenAIProvider {
config := getOpenAIConfig(baseURL)

return &OpenAIProvider{
APIKey: apikey,
Name: "Open AI",
OpenAIHeaders: OpenAIHeaders{
APIKey: apikey,
},
Config: config,
}
}

type OpenAIProviderFactory struct{}

func (f OpenAIProviderFactory) Create(apikey string) base.ProviderInterface {
openAIProvider := NewOpenAIProvider(apikey, "https://api.openai.com")
return openAIProvider
func (oai *OpenAIProvider) GetName() string {
return oai.Name
}

func getOpenAIConfig(baseURL string) base.ProviderConfig {
Expand Down
22 changes: 10 additions & 12 deletions mobius/internal/providers/providers.go
Expand Up @@ -2,38 +2,36 @@ package providers

import (
"context"
"errors"
"net/http"

"connectrpc.com/connect"
"github.com/missingstudio/studio/backend/config"
"github.com/missingstudio/studio/backend/internal/providers/azure"
"github.com/missingstudio/studio/backend/internal/providers/base"
"github.com/missingstudio/studio/backend/internal/providers/openai"
"github.com/missingstudio/studio/common/errors"
)

type ProviderFactory interface {
Create(token string) base.ProviderInterface
Create(headers http.Header) (base.ProviderInterface, error)
}

var providerFactories = make(map[string]ProviderFactory)

func init() {
providerFactories["openai"] = openai.OpenAIProviderFactory{}
providerFactories["azure"] = azure.AzureProviderFactory{}
}

func GetProvider(ctx context.Context) (base.ProviderInterface, error) {
func GetProvider(ctx context.Context, headers http.Header) (base.ProviderInterface, error) {
providerName, ok := ctx.Value(config.ProviderKey{}).(string)
if !ok {
return nil, connect.NewError(connect.CodeNotFound, errors.New("failed to get provider"))
}

authkey, ok := ctx.Value(config.AuthorizationKey{}).(string)
if !ok {
return nil, connect.NewError(connect.CodeNotFound, errors.New("failed to get access key"))
return nil, errors.NewBadRequest("provider is required from headers")
}

providerFactory, ok := providerFactories[providerName]
if !ok {
return nil, connect.NewError(connect.CodeNotFound, errors.New("provider not found"))
return nil, errors.NewNotFound("provider is not available")
}
return providerFactory.Create(authkey), nil

return providerFactory.Create(headers)
}
10 changes: 0 additions & 10 deletions mobius/pkg/utils/interceptor.go
Expand Up @@ -3,7 +3,6 @@ package utils
import (
"context"
"errors"
"strings"

"connectrpc.com/connect"
"github.com/missingstudio/studio/backend/config"
Expand Down Expand Up @@ -33,15 +32,6 @@ func ProviderInterceptor() connect.UnaryInterceptorFunc {
}

ctx = context.WithValue(ctx, config.ProviderKey{}, provider)

authorization := req.Header().Get(config.Authorization)
if authorization == "" {
return nil, errors.New("Authorization header is required")
}

authorizationKey := strings.Replace(authorization, "Bearer ", "", 1)
ctx = context.WithValue(ctx, config.AuthorizationKey{}, authorizationKey)

return next(ctx, req)
})
}
Expand Down

0 comments on commit 6e27d37

Please sign in to comment.