diff --git a/gateway/cmd/serve.go b/gateway/cmd/serve.go index 62d0228..5b1a266 100644 --- a/gateway/cmd/serve.go +++ b/gateway/cmd/serve.go @@ -69,7 +69,12 @@ func Serve(cfg *config.Config) error { promptService := prompt.NewService(promptRepository) providerService := providers.NewService() - deps := api.NewDeps(logger, ingester, rl, providerService, connectionService, promptService) + deps := api.NewDeps( + logger, ingester, rl, + providerService, + connectionService, + promptService, + ) if err := server.Serve(ctx, logger, cfg.App, deps); err != nil { logger.Error("error starting server", "error", err) diff --git a/gateway/models/chat.go b/gateway/core/chat/chat.go similarity index 99% rename from gateway/models/chat.go rename to gateway/core/chat/chat.go index 5613fef..8f79baf 100644 --- a/gateway/models/chat.go +++ b/gateway/core/chat/chat.go @@ -1,4 +1,4 @@ -package models +package chat type ResponseFormat struct { Type string `json:"type,omitempty"` diff --git a/gateway/core/connection/service.go b/gateway/core/connection/service.go index 3af4c78..b42c59e 100644 --- a/gateway/core/connection/service.go +++ b/gateway/core/connection/service.go @@ -27,12 +27,10 @@ func NewService(connectionRepo Repository) *Service { } } -// DeleteByID implements connection.Repository. func (s *Service) DeleteByID(ctx context.Context, connID uuid.UUID) error { return s.connectionRepo.DeleteByID(ctx, connID) } -// GetAll implements connection.Repository. func (s *Service) GetAll(ctx context.Context) ([]Connection, error) { conns, err := s.connectionRepo.GetAll(ctx) if err != nil { @@ -41,7 +39,6 @@ func (s *Service) GetAll(ctx context.Context) ([]Connection, error) { return conns, nil } -// GetByID implements connection.Repository. func (s *Service) GetByID(ctx context.Context, connID uuid.UUID) (Connection, error) { conn, err := s.connectionRepo.GetByID(ctx, connID) if err != nil { @@ -51,7 +48,6 @@ func (s *Service) GetByID(ctx context.Context, connID uuid.UUID) (Connection, er return conn, err } -// GetByName implements connection.Repository. func (s *Service) GetByName(ctx context.Context, name string) (Connection, error) { conn, err := s.connectionRepo.GetByName(ctx, name) if err != nil { @@ -61,7 +57,6 @@ func (s *Service) GetByName(ctx context.Context, name string) (Connection, error return conn, err } -// Upsert implements connection.Repository. func (s *Service) Upsert(ctx context.Context, c Connection) (Connection, error) { id, err := s.connectionRepo.Upsert(ctx, c) if err != nil { diff --git a/gateway/internal/api/v1/chatcompletions.go b/gateway/internal/api/v1/chatcompletions.go index a2135e9..6a6dc29 100644 --- a/gateway/internal/api/v1/chatcompletions.go +++ b/gateway/internal/api/v1/chatcompletions.go @@ -7,12 +7,12 @@ import ( "time" "connectrpc.com/connect" + "github.com/missingstudio/studio/backend/core/chat" "github.com/missingstudio/studio/backend/core/connection" "github.com/missingstudio/studio/backend/internal/constants" "github.com/missingstudio/studio/backend/internal/providers" "github.com/missingstudio/studio/backend/internal/providers/base" "github.com/missingstudio/studio/backend/internal/router" - "github.com/missingstudio/studio/backend/models" "github.com/missingstudio/studio/common/errors" llmv1 "github.com/missingstudio/studio/protos/pkg/llm/v1" ) @@ -86,13 +86,13 @@ func (s *V1Handler) ChatCompletions( return connect.NewResponse(chatCompletionResponseSchema), nil } -func (s *V1Handler) createChatCompletionRequestSchema(req *llmv1.ChatCompletionRequest) (*models.ChatCompletionRequest, error) { +func (s *V1Handler) createChatCompletionRequestSchema(req *llmv1.ChatCompletionRequest) (*chat.ChatCompletionRequest, error) { payload, err := json.Marshal(req) if err != nil { return nil, err } - data := &models.ChatCompletionRequest{} + data := &chat.ChatCompletionRequest{} err = json.Unmarshal(payload, data) if err != nil { return nil, err @@ -101,7 +101,7 @@ func (s *V1Handler) createChatCompletionRequestSchema(req *llmv1.ChatCompletionR return data, nil } -func (s *V1Handler) createChatCompletionResponseSchema(resp *models.ChatCompletionResponse) (*llmv1.ChatCompletionResponse, error) { +func (s *V1Handler) createChatCompletionResponseSchema(resp *chat.ChatCompletionResponse) (*llmv1.ChatCompletionResponse, error) { payload, err := json.Marshal(resp) if err != nil { return nil, err @@ -116,7 +116,7 @@ func (s *V1Handler) createChatCompletionResponseSchema(resp *models.ChatCompleti return data, nil } -func (s *V1Handler) sendMetrics(provider string, latency time.Duration, response *models.ChatCompletionResponse) { +func (s *V1Handler) sendMetrics(provider string, latency time.Duration, response *chat.ChatCompletionResponse) { ingesterdata := make(map[string]any) ingesterdata["provider"] = provider ingesterdata["latency"] = latency diff --git a/gateway/internal/providers/anyscale/anyscale.go b/gateway/internal/providers/anyscale/anyscale.go index a307cbd..5d89c16 100644 --- a/gateway/internal/providers/anyscale/anyscale.go +++ b/gateway/internal/providers/anyscale/anyscale.go @@ -7,12 +7,12 @@ import ( "fmt" "net/http" + "github.com/missingstudio/studio/backend/core/chat" "github.com/missingstudio/studio/backend/core/connection" - "github.com/missingstudio/studio/backend/models" "github.com/missingstudio/studio/backend/pkg/requester" ) -func (anyscale *anyscaleProvider) ChatCompletion(ctx context.Context, payload *models.ChatCompletionRequest) (*models.ChatCompletionResponse, error) { +func (anyscale *anyscaleProvider) ChatCompletion(ctx context.Context, payload *chat.ChatCompletionRequest) (*chat.ChatCompletionResponse, error) { client := requester.NewHTTPClient() rawPayload, err := json.Marshal(payload) @@ -32,7 +32,7 @@ func (anyscale *anyscaleProvider) ChatCompletion(ctx context.Context, payload *m return nil, err } - data := &models.ChatCompletionResponse{} + data := &chat.ChatCompletionResponse{} if err := json.NewDecoder(resp.Body).Decode(data); err != nil { return nil, err } diff --git a/gateway/internal/providers/azure/azure.go b/gateway/internal/providers/azure/azure.go index dc62a4d..6009fe2 100644 --- a/gateway/internal/providers/azure/azure.go +++ b/gateway/internal/providers/azure/azure.go @@ -4,11 +4,11 @@ import ( "context" "errors" + "github.com/missingstudio/studio/backend/core/chat" "github.com/missingstudio/studio/backend/internal/providers/openai" - "github.com/missingstudio/studio/backend/models" ) -func (az *azureProvider) ChatCompletion(ctx context.Context, payload *models.ChatCompletionRequest) (*models.ChatCompletionResponse, error) { +func (az *azureProvider) ChatCompletion(ctx context.Context, payload *chat.ChatCompletionRequest) (*chat.ChatCompletionResponse, error) { return nil, errors.New("Not yet implemented") } diff --git a/gateway/internal/providers/base/base.go b/gateway/internal/providers/base/base.go index 2470aaa..a0259b3 100644 --- a/gateway/internal/providers/base/base.go +++ b/gateway/internal/providers/base/base.go @@ -3,8 +3,8 @@ package base import ( "context" + "github.com/missingstudio/studio/backend/core/chat" "github.com/missingstudio/studio/backend/core/connection" - "github.com/missingstudio/studio/backend/models" ) type ProviderConfig struct { @@ -30,5 +30,5 @@ var ProviderRegistry = map[string]func(connection.Connection) IProvider{} type ChatCompletionInterface interface { IProvider - ChatCompletion(context.Context, *models.ChatCompletionRequest) (*models.ChatCompletionResponse, error) + ChatCompletion(context.Context, *chat.ChatCompletionRequest) (*chat.ChatCompletionResponse, error) } diff --git a/gateway/internal/providers/deepinfra/deepinfra.go b/gateway/internal/providers/deepinfra/deepinfra.go index fb416e7..cbe76e6 100644 --- a/gateway/internal/providers/deepinfra/deepinfra.go +++ b/gateway/internal/providers/deepinfra/deepinfra.go @@ -7,12 +7,12 @@ import ( "fmt" "net/http" + "github.com/missingstudio/studio/backend/core/chat" "github.com/missingstudio/studio/backend/core/connection" - "github.com/missingstudio/studio/backend/models" "github.com/missingstudio/studio/backend/pkg/requester" ) -func (deepinfra *deepinfraProvider) ChatCompletion(ctx context.Context, payload *models.ChatCompletionRequest) (*models.ChatCompletionResponse, error) { +func (deepinfra *deepinfraProvider) ChatCompletion(ctx context.Context, payload *chat.ChatCompletionRequest) (*chat.ChatCompletionResponse, error) { client := requester.NewHTTPClient() rawPayload, err := json.Marshal(payload) @@ -32,7 +32,7 @@ func (deepinfra *deepinfraProvider) ChatCompletion(ctx context.Context, payload return nil, err } - data := &models.ChatCompletionResponse{} + data := &chat.ChatCompletionResponse{} if err := json.NewDecoder(resp.Body).Decode(data); err != nil { return nil, err } diff --git a/gateway/internal/providers/openai/openai.go b/gateway/internal/providers/openai/openai.go index da6433a..93621c1 100644 --- a/gateway/internal/providers/openai/openai.go +++ b/gateway/internal/providers/openai/openai.go @@ -7,8 +7,8 @@ import ( "fmt" "net/http" + "github.com/missingstudio/studio/backend/core/chat" "github.com/missingstudio/studio/backend/core/connection" - "github.com/missingstudio/studio/backend/models" "github.com/missingstudio/studio/backend/pkg/requester" ) @@ -28,7 +28,7 @@ var OpenAIModels = []string{ "gpt-3.5-turbo-instruct", } -func (oai *openAIProvider) ChatCompletion(ctx context.Context, payload *models.ChatCompletionRequest) (*models.ChatCompletionResponse, error) { +func (oai *openAIProvider) ChatCompletion(ctx context.Context, payload *chat.ChatCompletionRequest) (*chat.ChatCompletionResponse, error) { client := requester.NewHTTPClient() rawPayload, err := json.Marshal(payload) @@ -48,7 +48,7 @@ func (oai *openAIProvider) ChatCompletion(ctx context.Context, payload *models.C return nil, err } - data := &models.ChatCompletionResponse{} + data := &chat.ChatCompletionResponse{} if err := json.NewDecoder(resp.Body).Decode(data); err != nil { return nil, err } diff --git a/gateway/internal/providers/togetherai/togetherai.go b/gateway/internal/providers/togetherai/togetherai.go index 910d1f1..7fdcccb 100644 --- a/gateway/internal/providers/togetherai/togetherai.go +++ b/gateway/internal/providers/togetherai/togetherai.go @@ -7,12 +7,12 @@ import ( "fmt" "net/http" + "github.com/missingstudio/studio/backend/core/chat" "github.com/missingstudio/studio/backend/core/connection" - "github.com/missingstudio/studio/backend/models" "github.com/missingstudio/studio/backend/pkg/requester" ) -func (ta *togetherAIProvider) ChatCompletion(ctx context.Context, payload *models.ChatCompletionRequest) (*models.ChatCompletionResponse, error) { +func (ta *togetherAIProvider) ChatCompletion(ctx context.Context, payload *chat.ChatCompletionRequest) (*chat.ChatCompletionResponse, error) { client := requester.NewHTTPClient() rawPayload, err := json.Marshal(payload) @@ -32,7 +32,7 @@ func (ta *togetherAIProvider) ChatCompletion(ctx context.Context, payload *model return nil, err } - data := &models.ChatCompletionResponse{} + data := &chat.ChatCompletionResponse{} if err := json.NewDecoder(resp.Body).Decode(data); err != nil { return nil, err } diff --git a/gateway/internal/storage/postgres/connection_repository.go b/gateway/internal/storage/postgres/connection_repository.go index a8f3503..f5eb803 100644 --- a/gateway/internal/storage/postgres/connection_repository.go +++ b/gateway/internal/storage/postgres/connection_repository.go @@ -27,12 +27,10 @@ func NewConnectionRepository(dbc *database.Client) *ConnectionRepository { } } -// DeleteByID implements connection.Repository. func (*ConnectionRepository) DeleteByID(ctx context.Context, connID uuid.UUID) error { panic("unimplemented") } -// GetAll implements connection.Repository. func (c *ConnectionRepository) GetAll(ctx context.Context) ([]connection.Connection, error) { query, params, err := dialect.From(TABLE_CONNECTIONS).ToSQL() if err != nil { @@ -62,12 +60,10 @@ func (c *ConnectionRepository) GetAll(ctx context.Context) ([]connection.Connect return connections, nil } -// GetByID implements connection.Repository. func (*ConnectionRepository) GetByID(ctx context.Context, connID uuid.UUID) (connection.Connection, error) { panic("unimplemented") } -// GetByID implements connection.Repository. func (c *ConnectionRepository) GetByName(ctx context.Context, name string) (connection.Connection, error) { query, params, err := dialect.From(TABLE_CONNECTIONS).Where(goqu.Ex{"name": name}).ToSQL() if err != nil { @@ -91,7 +87,6 @@ func (c *ConnectionRepository) GetByName(ctx context.Context, name string) (conn return connDb.ToConnection() } -// Upsert implements connection.Repository. func (c *ConnectionRepository) Upsert(ctx context.Context, conn connection.Connection) (connection.Connection, error) { marshaledConfig, err := json.Marshal(conn.Config) if err != nil { diff --git a/gateway/internal/storage/postgres/prompt_repository.go b/gateway/internal/storage/postgres/prompt_repository.go index 1230a0c..ca8189d 100644 --- a/gateway/internal/storage/postgres/prompt_repository.go +++ b/gateway/internal/storage/postgres/prompt_repository.go @@ -80,7 +80,6 @@ func (c *PromptRepository) GetByName(ctx context.Context, name string) (prompt.P return pdb.ToPrompt() } -// Upsert implements prompt.Repository. func (c *PromptRepository) Upsert(ctx context.Context, conn prompt.Prompt) (prompt.Prompt, error) { marshaledMetadata, err := json.Marshal(conn.Metadata) if err != nil {