Skip to content

Commit

Permalink
feat: support anyscale provider
Browse files Browse the repository at this point in the history
  • Loading branch information
pyadav committed Feb 1, 2024
1 parent 7ceedcf commit 9ef2bce
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 3 deletions.
40 changes: 40 additions & 0 deletions mobius/internal/providers/anyscale/anyscale.go
@@ -0,0 +1,40 @@
package anyscale

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"

"github.com/missingstudio/studio/backend/pkg/requester"
llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
)

func (anyscale *AnyscaleProvider) ChatCompilation(ctx context.Context, cr *llmv1.CompletionRequest) (*llmv1.CompletionResponse, error) {
payload, err := json.Marshal(cr)
if err != nil {
return nil, err
}

client := requester.NewHTTPClient()
requestURL := fmt.Sprintf("%s%s", anyscale.Config.BaseURL, anyscale.Config.ChatCompletions)
req, _ := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewReader(payload))

req.Header.Add("Content-Type", "application/json")
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", anyscale.APIKey))

resp, err := client.Do(req)
fmt.Println(resp)
if err != nil {
return nil, err
}

var data llmv1.CompletionResponse
err = json.Unmarshal(resp, &data)
if err != nil {
return nil, err
}

return &data, nil
}
67 changes: 67 additions & 0 deletions mobius/internal/providers/anyscale/base.go
@@ -0,0 +1,67 @@
package anyscale

import (
"net/http"
"strings"

"github.com/missingstudio/studio/backend/internal/providers/base"
"github.com/missingstudio/studio/backend/pkg/utils"
"github.com/missingstudio/studio/common/errors"
)

type AnyscaleProviderFactory struct{}

type AnyscaleHeaders struct {
APIKey string `validate:"required" json:"Authorization" error:"API key is required"`
}

func (anyscale AnyscaleProviderFactory) Validate(headers http.Header) (*AnyscaleHeaders, error) {
var anyscaleHeaders AnyscaleHeaders
if err := utils.UnmarshalHeader(headers, &anyscaleHeaders); err != nil {
return nil, errors.New(err)
}

if err := utils.ValidateHeaders(anyscaleHeaders); err != nil {
return nil, err
}

return &anyscaleHeaders, nil
}

func (anyscale AnyscaleProviderFactory) Create(headers http.Header) (base.ProviderInterface, error) {
anyscaleHeaders, err := anyscale.Validate(headers)
if err != nil {
return nil, err
}

anyscaleHeaders.APIKey = strings.Replace(anyscaleHeaders.APIKey, "Bearer ", "", 1)
openAIProvider := NewAnyscaleProvider(*anyscaleHeaders)
return openAIProvider, nil
}

type AnyscaleProvider struct {
Name string
Config base.ProviderConfig
AnyscaleHeaders
}

func NewAnyscaleProvider(headers AnyscaleHeaders) *AnyscaleProvider {
config := getAnyscaleConfig("https://api.endpoints.anyscale.com")

return &AnyscaleProvider{
Name: "Anyscale",
Config: config,
AnyscaleHeaders: headers,
}
}

func (anyscale AnyscaleProvider) GetName() string {
return anyscale.Name
}

func getAnyscaleConfig(baseURL string) base.ProviderConfig {
return base.ProviderConfig{
BaseURL: baseURL,
ChatCompletions: "/v1/chat/completions",
}
}
2 changes: 1 addition & 1 deletion mobius/internal/providers/azure/base.go
Expand Up @@ -50,7 +50,7 @@ func NewAzureProvider(headers AzureHeaders) *AzureProvider {
config := getAzureConfig()

return &AzureProvider{
Name: "Azure AI",
Name: "Azure",
Config: config,
AzureHeaders: headers,
}
Expand Down
4 changes: 2 additions & 2 deletions mobius/internal/providers/openai/base.go
Expand Up @@ -15,7 +15,7 @@ type OpenAIHeaders struct {
APIKey string `validate:"required" json:"Authorization" error:"API key is required"`
}

func (azf OpenAIProviderFactory) Validate(headers http.Header) (*OpenAIHeaders, error) {
func (oaif OpenAIProviderFactory) Validate(headers http.Header) (*OpenAIHeaders, error) {
var oaiHeaders OpenAIHeaders
if err := utils.UnmarshalHeader(headers, &oaiHeaders); err != nil {
return nil, errors.New(err)
Expand Down Expand Up @@ -49,7 +49,7 @@ func NewOpenAIProvider(headers OpenAIHeaders) *OpenAIProvider {
config := getOpenAIConfig("https://api.openai.com")

return &OpenAIProvider{
Name: "Open AI",
Name: "OpenAI",
Config: config,
OpenAIHeaders: headers,
}
Expand Down
2 changes: 2 additions & 0 deletions mobius/internal/providers/providers.go
Expand Up @@ -5,6 +5,7 @@ import (
"net/http"

"github.com/missingstudio/studio/backend/config"
"github.com/missingstudio/studio/backend/internal/providers/anyscale"
"github.com/missingstudio/studio/backend/internal/providers/azure"
"github.com/missingstudio/studio/backend/internal/providers/base"
"github.com/missingstudio/studio/backend/internal/providers/openai"
Expand All @@ -20,6 +21,7 @@ var providerFactories = make(map[string]ProviderFactory)
func init() {
providerFactories["openai"] = openai.OpenAIProviderFactory{}
providerFactories["azure"] = azure.AzureProviderFactory{}
providerFactories["anyscale"] = anyscale.AnyscaleProviderFactory{}
}

func GetProvider(ctx context.Context, headers http.Header) (base.ProviderInterface, error) {
Expand Down

0 comments on commit 9ef2bce

Please sign in to comment.