feat: add openai provider support for chat completion
pyadav committed Jan 25, 2024
1 parent 0d9b713 commit 22f2ede
Showing 7 changed files with 231 additions and 156 deletions.
21 changes: 12 additions & 9 deletions mobius/internal/connectrpc/mux.go
@@ -5,13 +5,13 @@ import (
 	"fmt"
 	"log"
 	"net/http"
-	"time"
 
 	"connectrpc.com/connect"
 	"connectrpc.com/grpchealth"
 	"connectrpc.com/grpcreflect"
 	"connectrpc.com/validate"
 	"connectrpc.com/vanguard"
+	"github.com/missingstudio/studio/backend/internal/providers"
 	llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
 	"github.com/missingstudio/studio/protos/pkg/llm/llmv1connect"
 )
@@ -65,12 +65,15 @@ func (s *LLMServer) ChatCompletions(
 ) (*connect.Response[llmv1.CompletionResponse], error) {
 	log.Println("Request headers: ", req.Header())
 
-	res := connect.NewResponse(&llmv1.CompletionResponse{
-		Id:      "1",
-		Object:  "chat.compilation",
-		Created: uint64(time.Now().Unix()),
-		Model:   "random",
-		Choices: []*llmv1.CompletionChoice{},
-	})
-	return res, nil
+	provider, err := providers.NewLLMProvider(req.Header())
+	if err != nil {
+		return nil, connect.NewError(connect.CodeInternal, err)
+	}
+
+	data, err := provider.ChatCompilation(ctx, req.Msg)
+	if err != nil {
+		return nil, connect.NewError(connect.CodeInternal, err)
+	}
+
+	return connect.NewResponse(data), nil
 }
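
Note: the handler now resolves a provider from the request headers and delegates the completion call to it, instead of returning a hard-coded stub. A minimal client-side sketch of exercising the new path — the generated constructor name (NewLLMServiceClient) and the empty request body are assumptions, not shown in this commit:

package main

import (
	"context"
	"log"
	"net/http"

	"connectrpc.com/connect"
	llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
	"github.com/missingstudio/studio/protos/pkg/llm/llmv1connect"
)

func main() {
	// Constructor name assumed; use whatever llmv1connect actually generates.
	client := llmv1connect.NewLLMServiceClient(http.DefaultClient, "http://localhost:8080")

	req := connect.NewRequest(&llmv1.CompletionRequest{})
	// These two headers drive provider selection in providers.NewLLMProvider.
	req.Header().Set("x-ms-provider", "openai")
	req.Header().Set("Authorization", "Bearer <OPENAI_API_KEY>")

	res, err := client.ChatCompletions(context.Background(), req)
	if err != nil {
		log.Fatal(err)
	}
	log.Println(res.Msg)
}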
48 changes: 48 additions & 0 deletions mobius/internal/providers/openai/openai.go
@@ -0,0 +1,48 @@
package openai

import (
	"bytes"
	"context"
	"encoding/json"
	"io"
	"net/http"

	llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
)

var OpenAIAPIURL = "https://api.openai.com/v1/chat/completions"

type OpenAI struct {
	APIKey string
}

func (oai OpenAI) ChatCompilation(ctx context.Context, cr *llmv1.CompletionRequest) (*llmv1.CompletionResponse, error) {
	payload, err := json.Marshal(cr)
	if err != nil {
		return nil, err
	}

	client := &http.Client{}
	req, _ := http.NewRequestWithContext(ctx, "POST", OpenAIAPIURL, bytes.NewReader(payload))
	req.Header.Add("Content-Type", "application/json")
	req.Header.Add("Authorization", "Bearer "+oai.APIKey)

	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	var data llmv1.CompletionResponse
	err = json.Unmarshal(body, &data)
	if err != nil {
		return nil, err
	}

	return &data, nil
}
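
Usage sketch (not part of the commit): the provider can also be driven directly. One caveat worth flagging: ChatCompilation never checks resp.StatusCode, so a non-200 error body from OpenAI is silently unmarshaled into an empty CompletionResponse rather than surfacing as an error.

// Hypothetical direct invocation of the new provider.
provider := openai.OpenAI{APIKey: os.Getenv("OPENAI_API_KEY")}

// CompletionRequest fields are not shown in this commit, so the
// request body is deliberately left empty here.
resp, err := provider.ChatCompilation(context.Background(), &llmv1.CompletionRequest{})
if err != nil {
	log.Fatal(err)
}
log.Println(resp.Id, resp.Model)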
27 changes: 27 additions & 0 deletions mobius/internal/providers/providers.go
@@ -0,0 +1,27 @@
package providers

import (
	"context"
	"errors"
	"net/http"
	"strings"

	"connectrpc.com/connect"
	"github.com/missingstudio/studio/backend/internal/providers/openai"
	llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
)

type LLMProvider interface {
	ChatCompilation(ctx context.Context, ra *llmv1.CompletionRequest) (*llmv1.CompletionResponse, error)
}

func NewLLMProvider(headers http.Header) (LLMProvider, error) {
	provider := headers.Get("x-ms-provider")
	if provider == "" {
		return nil, connect.NewError(connect.CodeNotFound, errors.New("provider not found"))
	}
	authHeader := headers.Get("Authorization")
	accessToken := strings.Replace(authHeader, "Bearer ", "", 1)

	return &openai.OpenAI{APIKey: accessToken}, nil
}
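
Note: the x-ms-provider header is only checked for presence — any non-empty value currently resolves to the OpenAI provider, with the bearer token reused as its API key. A sketch of the resolution:

// Hypothetical header set, mirroring what the mux receives.
h := http.Header{}
h.Set("x-ms-provider", "openai")
h.Set("Authorization", "Bearer sk-example")

p, err := providers.NewLLMProvider(h)
if err != nil {
	log.Fatal(err)
}
// p is an *openai.OpenAI whose APIKey is "sk-example".
_ = p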
19 changes: 19 additions & 0 deletions mobius/pkg/utils/interceptor.go
@@ -0,0 +1,19 @@
package utils

import (
	"context"

	"connectrpc.com/connect"
)

func NewLogInterceptor() connect.UnaryInterceptorFunc {
	interceptor := func(next connect.UnaryFunc) connect.UnaryFunc {
		return connect.UnaryFunc(func(
			ctx context.Context,
			req connect.AnyRequest,
		) (connect.AnyResponse, error) {
			return next(ctx, req)
		})
	}
	return connect.UnaryInterceptorFunc(interceptor)
}
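
The interceptor is a pass-through stub for now: it forwards every unary call to next unchanged. A sketch of how it would be attached when registering the service — connect.WithInterceptors is the real connect-go option, but the generated handler constructor name is an assumption:

interceptors := connect.WithInterceptors(utils.NewLogInterceptor())

// Handler constructor name assumed from the llmv1connect package.
path, handler := llmv1connect.NewLLMServiceHandler(&LLMServer{}, interceptors)

mux := http.NewServeMux()
mux.Handle(path, handler)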
249 changes: 117 additions & 132 deletions protos/pkg/llm/service.pb.go

Large diffs are not rendered by default.

5 changes: 2 additions & 3 deletions protos/proto/llm/service.proto
@@ -30,7 +30,7 @@ message Role {
 
 message ChatMessage {
   // role of the message author. One of "system", "user", "assistant".
-  Role role = 1;
+  string role = 1;
   // content of the message
   string content = 2;
 }
@@ -62,8 +62,7 @@ message CompletionChoice {
   // index of the choice in the list of choices.
   uint32 index = 1;
   // message generated by the model.
-  repeated ChatMessage messages = 2;
-  FinishReason finish_reason = 3;
+  ChatMessage message = 2;
 }
 
 message Usage {
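
Under the revised schema a choice carries a single message (with a plain string role) instead of a repeated message list plus a finish reason. A sketch of the new response shape in Go, using field names implied by the .proto above:

// Hypothetical response under the new CompletionChoice shape.
res := &llmv1.CompletionResponse{
	Choices: []*llmv1.CompletionChoice{
		{
			Index:   0,
			Message: &llmv1.ChatMessage{Role: "assistant", Content: "Hello!"},
		},
	},
}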
18 changes: 6 additions & 12 deletions protos/src/llm/service_pb.ts
@@ -102,9 +102,9 @@ export class ChatMessage extends Message<ChatMessage> {
   /**
    * role of the message author. One of "system", "user", "assistant".
    *
-   * @generated from field: llm.v1.Role role = 1;
+   * @generated from field: string role = 1;
    */
-  role?: Role;
+  role = "";
 
   /**
    * content of the message
@@ -121,7 +121,7 @@
   static readonly runtime: typeof proto3 = proto3;
   static readonly typeName = "llm.v1.ChatMessage";
   static readonly fields: FieldList = proto3.util.newFieldList(() => [
-    { no: 1, name: "role", kind: "message", T: Role },
+    { no: 1, name: "role", kind: "scalar", T: 9 /* ScalarType.STRING */ },
     { no: 2, name: "content", kind: "scalar", T: 9 /* ScalarType.STRING */ },
   ]);

@@ -298,14 +298,9 @@ export class CompletionChoice extends Message<CompletionChoice> {
   /**
    * message generated by the model.
    *
-   * @generated from field: repeated llm.v1.ChatMessage messages = 2;
-   */
-  messages: ChatMessage[] = [];
-
-  /**
-   * @generated from field: llm.v1.FinishReason finish_reason = 3;
+   * @generated from field: llm.v1.ChatMessage message = 2;
    */
-  finishReason = FinishReason.NULL;
+  message?: ChatMessage;
 
   constructor(data?: PartialMessage<CompletionChoice>) {
     super();
@@ -316,8 +311,7 @@
   static readonly typeName = "llm.v1.CompletionChoice";
   static readonly fields: FieldList = proto3.util.newFieldList(() => [
     { no: 1, name: "index", kind: "scalar", T: 13 /* ScalarType.UINT32 */ },
-    { no: 2, name: "messages", kind: "message", T: ChatMessage, repeated: true },
-    { no: 3, name: "finish_reason", kind: "enum", T: proto3.getEnumType(FinishReason) },
+    { no: 2, name: "message", kind: "message", T: ChatMessage },
   ]);
 
   static fromBinary(bytes: Uint8Array, options?: Partial<BinaryReadOptions>): CompletionChoice {