diff --git a/.gitignore b/.gitignore index b08f940..89e2123 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ bin node_modules .DS_Store *.log +.vscode .env .env.* diff --git a/mobius/.gitignore b/mobius/.gitignore index 445f2cd..0973e13 100644 --- a/mobius/.gitignore +++ b/mobius/.gitignore @@ -1,7 +1,9 @@ bin .DS_Store +.env .env.* +config.yaml *.log node_modules diff --git a/mobius/internal/api/v1/chatcompletions.go b/mobius/internal/api/v1/chatcompletions.go index e249c79..462b59f 100644 --- a/mobius/internal/api/v1/chatcompletions.go +++ b/mobius/internal/api/v1/chatcompletions.go @@ -14,14 +14,14 @@ func (s *V1Handler) ChatCompletions( ctx context.Context, req *connect.Request[llmv1.CompletionRequest], ) (*connect.Response[llmv1.CompletionResponse], error) { - provider, err := providers.GetProvider(ctx) + provider, err := providers.GetProvider(ctx, req.Header()) if err != nil { - return nil, errors.NewNotFound("provider not found") + return nil, errors.New(err) } completionProvider, ok := provider.(base.ChatCompilationInterface) if !ok { - return nil, errors.NewInternalError("not able to get chat compilation provider") + return nil, errors.NewInternalError("provider don't have chat compilation capabilities") } data, err := completionProvider.ChatCompilation(ctx, req.Msg) diff --git a/mobius/internal/providers/azure/azure.go b/mobius/internal/providers/azure/azure.go new file mode 100644 index 0000000..9f58248 --- /dev/null +++ b/mobius/internal/providers/azure/azure.go @@ -0,0 +1,12 @@ +package azure + +import ( + "context" + "errors" + + llmv1 "github.com/missingstudio/studio/protos/pkg/llm" +) + +func (az *AzureProvider) ChatCompilation(ctx context.Context, cr *llmv1.CompletionRequest) (*llmv1.CompletionResponse, error) { + return nil, errors.New("Not yet implemented") +} diff --git a/mobius/internal/providers/azure/base.go b/mobius/internal/providers/azure/base.go new file mode 100644 index 0000000..8a18099 --- /dev/null +++ b/mobius/internal/providers/azure/base.go @@ -0,0 +1,56 @@ +package azure + +import ( + "net/http" + "strings" + + "github.com/missingstudio/studio/backend/config" + "github.com/missingstudio/studio/backend/internal/providers/base" + "github.com/missingstudio/studio/common/errors" +) + +type AzureProviderFactory struct{} + +func (f AzureProviderFactory) Create(headers http.Header) (base.ProviderInterface, error) { + authorization := headers.Get(config.Authorization) + if authorization == "" { + return nil, errors.NewBadRequest("authorization header is required") + } + + authorizationKey := strings.Replace(authorization, "Bearer ", "", 1) + azureProvider := NewazureProvider(authorizationKey) + return azureProvider, nil +} + +type AzureHeaders struct { + APIKey string +} + +type AzureProvider struct { + Name string + Config base.ProviderConfig + AzureHeaders +} + +func NewazureProvider(apikey string) *AzureProvider { + config := getAzureConfig() + + return &AzureProvider{ + Name: "Azure AI", + AzureHeaders: AzureHeaders{ + APIKey: apikey, + }, + Config: config, + } +} + +func (az *AzureProvider) GetName() string { + return az.Name +} + +func getAzureConfig() base.ProviderConfig { + return base.ProviderConfig{ + BaseURL: "", + ChatCompletions: "/chat/completions", + } +} diff --git a/mobius/internal/providers/base/base.go b/mobius/internal/providers/base/base.go index 5a3b059..e7b3e3c 100644 --- a/mobius/internal/providers/base/base.go +++ b/mobius/internal/providers/base/base.go @@ -11,7 +11,9 @@ type ProviderConfig struct { ChatCompletions string } -type ProviderInterface interface{} +type ProviderInterface interface { + GetName() string +} type ChatCompilationInterface interface { ProviderInterface diff --git a/mobius/internal/providers/openai/base.go b/mobius/internal/providers/openai/base.go index 0e51bb9..f088e73 100644 --- a/mobius/internal/providers/openai/base.go +++ b/mobius/internal/providers/openai/base.go @@ -1,28 +1,51 @@ package openai import ( + "net/http" + "strings" + + "github.com/missingstudio/studio/backend/config" "github.com/missingstudio/studio/backend/internal/providers/base" + "github.com/missingstudio/studio/common/errors" ) -type OpenAIProvider struct { +type OpenAIProviderFactory struct{} + +func (f OpenAIProviderFactory) Create(headers http.Header) (base.ProviderInterface, error) { + authorization := headers.Get(config.Authorization) + if authorization == "" { + return nil, errors.NewBadRequest("authorization header is required") + } + + authorizationKey := strings.Replace(authorization, "Bearer ", "", 1) + openAIProvider := NewOpenAIProvider(authorizationKey, "https://api.openai.com") + return openAIProvider, nil +} + +type OpenAIHeaders struct { APIKey string +} + +type OpenAIProvider struct { + Name string Config base.ProviderConfig + OpenAIHeaders } func NewOpenAIProvider(apikey string, baseURL string) *OpenAIProvider { config := getOpenAIConfig(baseURL) return &OpenAIProvider{ - APIKey: apikey, + Name: "Open AI", + OpenAIHeaders: OpenAIHeaders{ + APIKey: apikey, + }, Config: config, } } -type OpenAIProviderFactory struct{} - -func (f OpenAIProviderFactory) Create(apikey string) base.ProviderInterface { - openAIProvider := NewOpenAIProvider(apikey, "https://api.openai.com") - return openAIProvider +func (oai *OpenAIProvider) GetName() string { + return oai.Name } func getOpenAIConfig(baseURL string) base.ProviderConfig { diff --git a/mobius/internal/providers/providers.go b/mobius/internal/providers/providers.go index 0243316..1a2b742 100644 --- a/mobius/internal/providers/providers.go +++ b/mobius/internal/providers/providers.go @@ -2,38 +2,36 @@ package providers import ( "context" - "errors" + "net/http" - "connectrpc.com/connect" "github.com/missingstudio/studio/backend/config" + "github.com/missingstudio/studio/backend/internal/providers/azure" "github.com/missingstudio/studio/backend/internal/providers/base" "github.com/missingstudio/studio/backend/internal/providers/openai" + "github.com/missingstudio/studio/common/errors" ) type ProviderFactory interface { - Create(token string) base.ProviderInterface + Create(headers http.Header) (base.ProviderInterface, error) } var providerFactories = make(map[string]ProviderFactory) func init() { providerFactories["openai"] = openai.OpenAIProviderFactory{} + providerFactories["azure"] = azure.AzureProviderFactory{} } -func GetProvider(ctx context.Context) (base.ProviderInterface, error) { +func GetProvider(ctx context.Context, headers http.Header) (base.ProviderInterface, error) { providerName, ok := ctx.Value(config.ProviderKey{}).(string) if !ok { - return nil, connect.NewError(connect.CodeNotFound, errors.New("failed to get provider")) - } - - authkey, ok := ctx.Value(config.AuthorizationKey{}).(string) - if !ok { - return nil, connect.NewError(connect.CodeNotFound, errors.New("failed to get access key")) + return nil, errors.NewBadRequest("provider is required from headers") } providerFactory, ok := providerFactories[providerName] if !ok { - return nil, connect.NewError(connect.CodeNotFound, errors.New("provider not found")) + return nil, errors.NewNotFound("provider is not available") } - return providerFactory.Create(authkey), nil + + return providerFactory.Create(headers) } diff --git a/mobius/pkg/utils/interceptor.go b/mobius/pkg/utils/interceptor.go index d00549c..8074aa1 100644 --- a/mobius/pkg/utils/interceptor.go +++ b/mobius/pkg/utils/interceptor.go @@ -3,7 +3,6 @@ package utils import ( "context" "errors" - "strings" "connectrpc.com/connect" "github.com/missingstudio/studio/backend/config" @@ -33,15 +32,6 @@ func ProviderInterceptor() connect.UnaryInterceptorFunc { } ctx = context.WithValue(ctx, config.ProviderKey{}, provider) - - authorization := req.Header().Get(config.Authorization) - if authorization == "" { - return nil, errors.New("Authorization header is required") - } - - authorizationKey := strings.Replace(authorization, "Bearer ", "", 1) - ctx = context.WithValue(ctx, config.AuthorizationKey{}, authorizationKey) - return next(ctx, req) }) }