Skip to content

Commit

Permalink
refactor: add factory pattern to create provider
Browse files Browse the repository at this point in the history
  • Loading branch information
pyadav committed Jan 25, 2024
1 parent 22f2ede commit aa0a6e8
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 12 deletions.
2 changes: 1 addition & 1 deletion mobius/internal/connectrpc/mux.go
Expand Up @@ -65,7 +65,7 @@ func (s *LLMServer) ChatCompletions(
) (*connect.Response[llmv1.CompletionResponse], error) {
log.Println("Request headers: ", req.Header())

provider, err := providers.NewLLMProvider(req.Header())
provider, err := providers.GetProvider(req.Header())
if err != nil {
return nil, connect.NewError(connect.CodeInternal, err)
}
Expand Down
16 changes: 16 additions & 0 deletions mobius/internal/providers/base/base.go
@@ -0,0 +1,16 @@
package base

import (
"context"

llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
)

type ProviderConfig struct {
BaseURL string
ChatCompletions string
}

type LLMProvider interface {
ChatCompilation(ctx context.Context, ra *llmv1.CompletionRequest) (*llmv1.CompletionResponse, error)
}
30 changes: 26 additions & 4 deletions mobius/internal/providers/openai/openai.go
Expand Up @@ -7,23 +7,45 @@ import (
"io"
"net/http"

"github.com/missingstudio/studio/backend/internal/providers/base"
llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
)

var OpenAIAPIURL = "https://api.openai.com/v1/chat/completions"
type OpenAIProviderFactory struct{}

type OpenAI struct {
func (f OpenAIProviderFactory) Create(token string) base.LLMProvider {
openAIProvider := NewOpenAIProvider(token, "https://api.openai.com")
return openAIProvider
}

type OpenAIProvider struct {
APIKey string
Config base.ProviderConfig
}

func NewOpenAIProvider(token string, baseURL string) *OpenAIProvider {
config := getOpenAIConfig(baseURL)
return &OpenAIProvider{
APIKey: token,
Config: config,
}
}

func getOpenAIConfig(baseURL string) base.ProviderConfig {
return base.ProviderConfig{
BaseURL: baseURL,
ChatCompletions: "/v1/chat/completions",
}
}

func (oai OpenAI) ChatCompilation(ctx context.Context, cr *llmv1.CompletionRequest) (*llmv1.CompletionResponse, error) {
func (oai OpenAIProvider) ChatCompilation(ctx context.Context, cr *llmv1.CompletionRequest) (*llmv1.CompletionResponse, error) {
payload, err := json.Marshal(cr)
if err != nil {
return nil, err
}

client := &http.Client{}
req, _ := http.NewRequestWithContext(ctx, "POST", OpenAIAPIURL, bytes.NewReader(payload))
req, _ := http.NewRequestWithContext(ctx, "POST", oai.Config.BaseURL+oai.Config.ChatCompletions, bytes.NewReader(payload))
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Authorization", "Bearer "+oai.APIKey)

Expand Down
21 changes: 14 additions & 7 deletions mobius/internal/providers/providers.go
@@ -1,27 +1,34 @@
package providers

import (
"context"
"errors"
"net/http"
"strings"

"connectrpc.com/connect"
"github.com/missingstudio/studio/backend/internal/providers/base"
"github.com/missingstudio/studio/backend/internal/providers/openai"
llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
)

type LLMProvider interface {
ChatCompilation(ctx context.Context, ra *llmv1.CompletionRequest) (*llmv1.CompletionResponse, error)
type ProviderFactory interface {
Create(token string) base.LLMProvider
}

func NewLLMProvider(headers http.Header) (LLMProvider, error) {
var providerFactories = make(map[string]ProviderFactory)

func init() {
providerFactories["openai"] = openai.OpenAIProviderFactory{}
}

func GetProvider(headers http.Header) (base.LLMProvider, error) {
provider := headers.Get("x-ms-provider")
if provider == "" {
providerFactory, ok := providerFactories[provider]
if !ok {
return nil, connect.NewError(connect.CodeNotFound, errors.New("provider not found"))
}

authHeader := headers.Get("Authorization")
accessToken := strings.Replace(authHeader, "Bearer ", "", 1)

return &openai.OpenAI{APIKey: accessToken}, nil
return providerFactory.Create(accessToken), nil
}

0 comments on commit aa0a6e8

Please sign in to comment.