Skip to content

Commit

Permalink
feat: add deepinfra provider support
Browse files Browse the repository at this point in the history
  • Loading branch information
pyadav committed Feb 1, 2024
1 parent 9ef2bce commit c2dd771
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 0 deletions.
67 changes: 67 additions & 0 deletions mobius/internal/providers/deepinfra/base.go
@@ -0,0 +1,67 @@
package deepinfra

import (
"net/http"
"strings"

"github.com/missingstudio/studio/backend/internal/providers/base"
"github.com/missingstudio/studio/backend/pkg/utils"
"github.com/missingstudio/studio/common/errors"
)

type DeepinfraProviderFactory struct{}

type DeepinfraHeaders struct {
APIKey string `validate:"required" json:"Authorization" error:"API key is required"`
}

func (deepinfra DeepinfraProviderFactory) Validate(headers http.Header) (*DeepinfraHeaders, error) {
var deepinfraHeaders DeepinfraHeaders
if err := utils.UnmarshalHeader(headers, &deepinfraHeaders); err != nil {
return nil, errors.New(err)
}

if err := utils.ValidateHeaders(deepinfraHeaders); err != nil {
return nil, err
}

return &deepinfraHeaders, nil
}

func (deepinfra DeepinfraProviderFactory) Create(headers http.Header) (base.ProviderInterface, error) {
deepinfraHeaders, err := deepinfra.Validate(headers)
if err != nil {
return nil, err
}

deepinfraHeaders.APIKey = strings.Replace(deepinfraHeaders.APIKey, "Bearer ", "", 1)
openAIProvider := NewDeepinfraProvider(*deepinfraHeaders)
return openAIProvider, nil
}

type DeepinfraProvider struct {
Name string
Config base.ProviderConfig
DeepinfraHeaders
}

func NewDeepinfraProvider(headers DeepinfraHeaders) *DeepinfraProvider {
config := getDeepinfraConfig("https://api.deepinfra.com/v1/openai")

return &DeepinfraProvider{
Name: "Deepinfra",
Config: config,
DeepinfraHeaders: headers,
}
}

func (deepinfra DeepinfraProvider) GetName() string {
return deepinfra.Name
}

func getDeepinfraConfig(baseURL string) base.ProviderConfig {
return base.ProviderConfig{
BaseURL: baseURL,
ChatCompletions: "/chat/completions",
}
}
39 changes: 39 additions & 0 deletions mobius/internal/providers/deepinfra/deepinfra.go
@@ -0,0 +1,39 @@
package deepinfra

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"

"github.com/missingstudio/studio/backend/pkg/requester"
llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
)

func (deepinfra *DeepinfraProvider) ChatCompilation(ctx context.Context, cr *llmv1.CompletionRequest) (*llmv1.CompletionResponse, error) {
payload, err := json.Marshal(cr)
if err != nil {
return nil, err
}

client := requester.NewHTTPClient()
requestURL := fmt.Sprintf("%s%s", deepinfra.Config.BaseURL, deepinfra.Config.ChatCompletions)
req, _ := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewReader(payload))

req.Header.Add("Content-Type", "application/json")
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", deepinfra.APIKey))

resp, err := client.Do(req)
if err != nil {
return nil, err
}

var data llmv1.CompletionResponse
err = json.Unmarshal(resp, &data)
if err != nil {
return nil, err
}

return &data, nil
}
2 changes: 2 additions & 0 deletions mobius/internal/providers/providers.go
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/missingstudio/studio/backend/internal/providers/anyscale"
"github.com/missingstudio/studio/backend/internal/providers/azure"
"github.com/missingstudio/studio/backend/internal/providers/base"
"github.com/missingstudio/studio/backend/internal/providers/deepinfra"
"github.com/missingstudio/studio/backend/internal/providers/openai"
"github.com/missingstudio/studio/common/errors"
)
Expand All @@ -22,6 +23,7 @@ func init() {
providerFactories["openai"] = openai.OpenAIProviderFactory{}
providerFactories["azure"] = azure.AzureProviderFactory{}
providerFactories["anyscale"] = anyscale.AnyscaleProviderFactory{}
providerFactories["deepinfra"] = deepinfra.DeepinfraProviderFactory{}
}

func GetProvider(ctx context.Context, headers http.Header) (base.ProviderInterface, error) {
Expand Down

0 comments on commit c2dd771

Please sign in to comment.