Skip to content

Commit

Permalink
refactor(gateway): move chat into core package
Browse files Browse the repository at this point in the history
Signed-off-by: Praveen Yadav <pyadav9678@gmail.com>
  • Loading branch information
pyadav committed Feb 29, 2024
1 parent 132b5a5 commit 4c95111
Show file tree
Hide file tree
Showing 12 changed files with 28 additions and 34 deletions.
7 changes: 6 additions & 1 deletion gateway/cmd/serve.go
Expand Up @@ -69,7 +69,12 @@ func Serve(cfg *config.Config) error {
promptService := prompt.NewService(promptRepository)

providerService := providers.NewService()
deps := api.NewDeps(logger, ingester, rl, providerService, connectionService, promptService)
deps := api.NewDeps(
logger, ingester, rl,
providerService,
connectionService,
promptService,
)

if err := server.Serve(ctx, logger, cfg.App, deps); err != nil {
logger.Error("error starting server", "error", err)
Expand Down
2 changes: 1 addition & 1 deletion gateway/models/chat.go → gateway/core/chat/chat.go
@@ -1,4 +1,4 @@
package models
package chat

type ResponseFormat struct {
Type string `json:"type,omitempty"`
Expand Down
5 changes: 0 additions & 5 deletions gateway/core/connection/service.go
Expand Up @@ -27,12 +27,10 @@ func NewService(connectionRepo Repository) *Service {
}
}

// DeleteByID implements connection.Repository.
func (s *Service) DeleteByID(ctx context.Context, connID uuid.UUID) error {
return s.connectionRepo.DeleteByID(ctx, connID)
}

// GetAll implements connection.Repository.
func (s *Service) GetAll(ctx context.Context) ([]Connection, error) {
conns, err := s.connectionRepo.GetAll(ctx)
if err != nil {
Expand All @@ -41,7 +39,6 @@ func (s *Service) GetAll(ctx context.Context) ([]Connection, error) {
return conns, nil
}

// GetByID implements connection.Repository.
func (s *Service) GetByID(ctx context.Context, connID uuid.UUID) (Connection, error) {
conn, err := s.connectionRepo.GetByID(ctx, connID)
if err != nil {
Expand All @@ -51,7 +48,6 @@ func (s *Service) GetByID(ctx context.Context, connID uuid.UUID) (Connection, er
return conn, err
}

// GetByName implements connection.Repository.
func (s *Service) GetByName(ctx context.Context, name string) (Connection, error) {
conn, err := s.connectionRepo.GetByName(ctx, name)
if err != nil {
Expand All @@ -61,7 +57,6 @@ func (s *Service) GetByName(ctx context.Context, name string) (Connection, error
return conn, err
}

// Upsert implements connection.Repository.
func (s *Service) Upsert(ctx context.Context, c Connection) (Connection, error) {
id, err := s.connectionRepo.Upsert(ctx, c)
if err != nil {
Expand Down
10 changes: 5 additions & 5 deletions gateway/internal/api/v1/chatcompletions.go
Expand Up @@ -7,12 +7,12 @@ import (
"time"

"connectrpc.com/connect"
"github.com/missingstudio/studio/backend/core/chat"
"github.com/missingstudio/studio/backend/core/connection"
"github.com/missingstudio/studio/backend/internal/constants"
"github.com/missingstudio/studio/backend/internal/providers"
"github.com/missingstudio/studio/backend/internal/providers/base"
"github.com/missingstudio/studio/backend/internal/router"
"github.com/missingstudio/studio/backend/models"
"github.com/missingstudio/studio/common/errors"
llmv1 "github.com/missingstudio/studio/protos/pkg/llm/v1"
)
Expand Down Expand Up @@ -86,13 +86,13 @@ func (s *V1Handler) ChatCompletions(
return connect.NewResponse(chatCompletionResponseSchema), nil
}

func (s *V1Handler) createChatCompletionRequestSchema(req *llmv1.ChatCompletionRequest) (*models.ChatCompletionRequest, error) {
func (s *V1Handler) createChatCompletionRequestSchema(req *llmv1.ChatCompletionRequest) (*chat.ChatCompletionRequest, error) {
payload, err := json.Marshal(req)
if err != nil {
return nil, err
}

data := &models.ChatCompletionRequest{}
data := &chat.ChatCompletionRequest{}
err = json.Unmarshal(payload, data)
if err != nil {
return nil, err
Expand All @@ -101,7 +101,7 @@ func (s *V1Handler) createChatCompletionRequestSchema(req *llmv1.ChatCompletionR
return data, nil
}

func (s *V1Handler) createChatCompletionResponseSchema(resp *models.ChatCompletionResponse) (*llmv1.ChatCompletionResponse, error) {
func (s *V1Handler) createChatCompletionResponseSchema(resp *chat.ChatCompletionResponse) (*llmv1.ChatCompletionResponse, error) {
payload, err := json.Marshal(resp)
if err != nil {
return nil, err
Expand All @@ -116,7 +116,7 @@ func (s *V1Handler) createChatCompletionResponseSchema(resp *models.ChatCompleti
return data, nil
}

func (s *V1Handler) sendMetrics(provider string, latency time.Duration, response *models.ChatCompletionResponse) {
func (s *V1Handler) sendMetrics(provider string, latency time.Duration, response *chat.ChatCompletionResponse) {
ingesterdata := make(map[string]any)
ingesterdata["provider"] = provider
ingesterdata["latency"] = latency
Expand Down
6 changes: 3 additions & 3 deletions gateway/internal/providers/anyscale/anyscale.go
Expand Up @@ -7,12 +7,12 @@ import (
"fmt"
"net/http"

"github.com/missingstudio/studio/backend/core/chat"
"github.com/missingstudio/studio/backend/core/connection"
"github.com/missingstudio/studio/backend/models"
"github.com/missingstudio/studio/backend/pkg/requester"
)

func (anyscale *anyscaleProvider) ChatCompletion(ctx context.Context, payload *models.ChatCompletionRequest) (*models.ChatCompletionResponse, error) {
func (anyscale *anyscaleProvider) ChatCompletion(ctx context.Context, payload *chat.ChatCompletionRequest) (*chat.ChatCompletionResponse, error) {
client := requester.NewHTTPClient()

rawPayload, err := json.Marshal(payload)
Expand All @@ -32,7 +32,7 @@ func (anyscale *anyscaleProvider) ChatCompletion(ctx context.Context, payload *m
return nil, err
}

data := &models.ChatCompletionResponse{}
data := &chat.ChatCompletionResponse{}
if err := json.NewDecoder(resp.Body).Decode(data); err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions gateway/internal/providers/azure/azure.go
Expand Up @@ -4,11 +4,11 @@ import (
"context"
"errors"

"github.com/missingstudio/studio/backend/core/chat"
"github.com/missingstudio/studio/backend/internal/providers/openai"
"github.com/missingstudio/studio/backend/models"
)

func (az *azureProvider) ChatCompletion(ctx context.Context, payload *models.ChatCompletionRequest) (*models.ChatCompletionResponse, error) {
func (az *azureProvider) ChatCompletion(ctx context.Context, payload *chat.ChatCompletionRequest) (*chat.ChatCompletionResponse, error) {
return nil, errors.New("Not yet implemented")
}

Expand Down
4 changes: 2 additions & 2 deletions gateway/internal/providers/base/base.go
Expand Up @@ -3,8 +3,8 @@ package base
import (
"context"

"github.com/missingstudio/studio/backend/core/chat"
"github.com/missingstudio/studio/backend/core/connection"
"github.com/missingstudio/studio/backend/models"
)

type ProviderConfig struct {
Expand All @@ -30,5 +30,5 @@ var ProviderRegistry = map[string]func(connection.Connection) IProvider{}

type ChatCompletionInterface interface {
IProvider
ChatCompletion(context.Context, *models.ChatCompletionRequest) (*models.ChatCompletionResponse, error)
ChatCompletion(context.Context, *chat.ChatCompletionRequest) (*chat.ChatCompletionResponse, error)
}
6 changes: 3 additions & 3 deletions gateway/internal/providers/deepinfra/deepinfra.go
Expand Up @@ -7,12 +7,12 @@ import (
"fmt"
"net/http"

"github.com/missingstudio/studio/backend/core/chat"
"github.com/missingstudio/studio/backend/core/connection"
"github.com/missingstudio/studio/backend/models"
"github.com/missingstudio/studio/backend/pkg/requester"
)

func (deepinfra *deepinfraProvider) ChatCompletion(ctx context.Context, payload *models.ChatCompletionRequest) (*models.ChatCompletionResponse, error) {
func (deepinfra *deepinfraProvider) ChatCompletion(ctx context.Context, payload *chat.ChatCompletionRequest) (*chat.ChatCompletionResponse, error) {
client := requester.NewHTTPClient()

rawPayload, err := json.Marshal(payload)
Expand All @@ -32,7 +32,7 @@ func (deepinfra *deepinfraProvider) ChatCompletion(ctx context.Context, payload
return nil, err
}

data := &models.ChatCompletionResponse{}
data := &chat.ChatCompletionResponse{}
if err := json.NewDecoder(resp.Body).Decode(data); err != nil {
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions gateway/internal/providers/openai/openai.go
Expand Up @@ -7,8 +7,8 @@ import (
"fmt"
"net/http"

"github.com/missingstudio/studio/backend/core/chat"
"github.com/missingstudio/studio/backend/core/connection"
"github.com/missingstudio/studio/backend/models"
"github.com/missingstudio/studio/backend/pkg/requester"
)

Expand All @@ -28,7 +28,7 @@ var OpenAIModels = []string{
"gpt-3.5-turbo-instruct",
}

func (oai *openAIProvider) ChatCompletion(ctx context.Context, payload *models.ChatCompletionRequest) (*models.ChatCompletionResponse, error) {
func (oai *openAIProvider) ChatCompletion(ctx context.Context, payload *chat.ChatCompletionRequest) (*chat.ChatCompletionResponse, error) {
client := requester.NewHTTPClient()

rawPayload, err := json.Marshal(payload)
Expand All @@ -48,7 +48,7 @@ func (oai *openAIProvider) ChatCompletion(ctx context.Context, payload *models.C
return nil, err
}

data := &models.ChatCompletionResponse{}
data := &chat.ChatCompletionResponse{}
if err := json.NewDecoder(resp.Body).Decode(data); err != nil {
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions gateway/internal/providers/togetherai/togetherai.go
Expand Up @@ -7,12 +7,12 @@ import (
"fmt"
"net/http"

"github.com/missingstudio/studio/backend/core/chat"
"github.com/missingstudio/studio/backend/core/connection"
"github.com/missingstudio/studio/backend/models"
"github.com/missingstudio/studio/backend/pkg/requester"
)

func (ta *togetherAIProvider) ChatCompletion(ctx context.Context, payload *models.ChatCompletionRequest) (*models.ChatCompletionResponse, error) {
func (ta *togetherAIProvider) ChatCompletion(ctx context.Context, payload *chat.ChatCompletionRequest) (*chat.ChatCompletionResponse, error) {
client := requester.NewHTTPClient()

rawPayload, err := json.Marshal(payload)
Expand All @@ -32,7 +32,7 @@ func (ta *togetherAIProvider) ChatCompletion(ctx context.Context, payload *model
return nil, err
}

data := &models.ChatCompletionResponse{}
data := &chat.ChatCompletionResponse{}
if err := json.NewDecoder(resp.Body).Decode(data); err != nil {
return nil, err
}
Expand Down
5 changes: 0 additions & 5 deletions gateway/internal/storage/postgres/connection_repository.go
Expand Up @@ -27,12 +27,10 @@ func NewConnectionRepository(dbc *database.Client) *ConnectionRepository {
}
}

// DeleteByID implements connection.Repository.
func (*ConnectionRepository) DeleteByID(ctx context.Context, connID uuid.UUID) error {
panic("unimplemented")
}

// GetAll implements connection.Repository.
func (c *ConnectionRepository) GetAll(ctx context.Context) ([]connection.Connection, error) {
query, params, err := dialect.From(TABLE_CONNECTIONS).ToSQL()
if err != nil {
Expand Down Expand Up @@ -62,12 +60,10 @@ func (c *ConnectionRepository) GetAll(ctx context.Context) ([]connection.Connect
return connections, nil
}

// GetByID implements connection.Repository.
func (*ConnectionRepository) GetByID(ctx context.Context, connID uuid.UUID) (connection.Connection, error) {
panic("unimplemented")
}

// GetByID implements connection.Repository.
func (c *ConnectionRepository) GetByName(ctx context.Context, name string) (connection.Connection, error) {
query, params, err := dialect.From(TABLE_CONNECTIONS).Where(goqu.Ex{"name": name}).ToSQL()
if err != nil {
Expand All @@ -91,7 +87,6 @@ func (c *ConnectionRepository) GetByName(ctx context.Context, name string) (conn
return connDb.ToConnection()
}

// Upsert implements connection.Repository.
func (c *ConnectionRepository) Upsert(ctx context.Context, conn connection.Connection) (connection.Connection, error) {
marshaledConfig, err := json.Marshal(conn.Config)
if err != nil {
Expand Down
1 change: 0 additions & 1 deletion gateway/internal/storage/postgres/prompt_repository.go
Expand Up @@ -80,7 +80,6 @@ func (c *PromptRepository) GetByName(ctx context.Context, name string) (prompt.P
return pdb.ToPrompt()
}

// Upsert implements prompt.Repository.
func (c *PromptRepository) Upsert(ctx context.Context, conn prompt.Prompt) (prompt.Prompt, error) {
marshaledMetadata, err := json.Marshal(conn.Metadata)
if err != nil {
Expand Down

0 comments on commit 4c95111

Please sign in to comment.