Skip to content

Commit

Permalink
refactor(gateway): move prompt into core package
Browse files Browse the repository at this point in the history
Signed-off-by: Praveen Yadav <pyadav9678@gmail.com>
  • Loading branch information
pyadav committed Feb 29, 2024
1 parent 15a2ae4 commit 132b5a5
Show file tree
Hide file tree
Showing 9 changed files with 57 additions and 88 deletions.
2 changes: 1 addition & 1 deletion gateway/cmd/serve.go
Expand Up @@ -11,9 +11,9 @@ import (
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/missingstudio/studio/backend/config"
"github.com/missingstudio/studio/backend/core/connection"
"github.com/missingstudio/studio/backend/core/prompt"
"github.com/missingstudio/studio/backend/internal/api"
"github.com/missingstudio/studio/backend/internal/ingester"
"github.com/missingstudio/studio/backend/internal/prompt"
"github.com/missingstudio/studio/backend/internal/providers"
"github.com/missingstudio/studio/backend/internal/ratelimiter"
"github.com/missingstudio/studio/backend/internal/server"
Expand Down
2 changes: 1 addition & 1 deletion gateway/models/prompt.go → gateway/core/prompt/prompt.go
@@ -1,4 +1,4 @@
package models
package prompt

import "github.com/google/uuid"

Expand Down
Expand Up @@ -5,9 +5,16 @@ import (
"fmt"

"github.com/google/uuid"
"github.com/missingstudio/studio/backend/models"
)

type Repository interface {
GetAll(context.Context) ([]Prompt, error)
Upsert(context.Context, Prompt) (Prompt, error)
GetByID(context.Context, uuid.UUID) (Prompt, error)
GetByName(context.Context, string) (Prompt, error)
DeleteByID(context.Context, uuid.UUID) error
}

var _ Repository = &Service{}

type Service struct {
Expand All @@ -24,36 +31,36 @@ func (s *Service) DeleteByID(ctx context.Context, promptID uuid.UUID) error {
return s.promptRepo.DeleteByID(ctx, promptID)
}

func (s *Service) GetAll(ctx context.Context) ([]models.Prompt, error) {
func (s *Service) GetAll(ctx context.Context) ([]Prompt, error) {
prompts, err := s.promptRepo.GetAll(ctx)
if err != nil {
return nil, err
}
return prompts, nil
}

func (s *Service) GetByID(ctx context.Context, promptID uuid.UUID) (models.Prompt, error) {
func (s *Service) GetByID(ctx context.Context, promptID uuid.UUID) (Prompt, error) {
prompt, err := s.promptRepo.GetByID(ctx, promptID)
if err != nil {
return models.Prompt{}, err
return Prompt{}, err
}

return prompt, err
}

func (s *Service) GetByName(ctx context.Context, name string) (models.Prompt, error) {
func (s *Service) GetByName(ctx context.Context, name string) (Prompt, error) {
prompt, err := s.promptRepo.GetByName(ctx, name)
if err != nil {
return models.Prompt{}, err
return Prompt{}, err
}

return prompt, err
}

func (s *Service) Upsert(ctx context.Context, c models.Prompt) (models.Prompt, error) {
func (s *Service) Upsert(ctx context.Context, c Prompt) (Prompt, error) {
id, err := s.promptRepo.Upsert(ctx, c)
if err != nil {
return models.Prompt{}, fmt.Errorf("failed to save prompt: %w", err)
return Prompt{}, fmt.Errorf("failed to save prompt: %w", err)
}

return id, err
Expand Down
2 changes: 1 addition & 1 deletion gateway/internal/api/deps.go
Expand Up @@ -4,8 +4,8 @@ import (
"log/slog"

"github.com/missingstudio/studio/backend/core/connection"
"github.com/missingstudio/studio/backend/core/prompt"
"github.com/missingstudio/studio/backend/internal/ingester"
"github.com/missingstudio/studio/backend/internal/prompt"
"github.com/missingstudio/studio/backend/internal/providers"
"github.com/missingstudio/studio/backend/internal/ratelimiter"
)
Expand Down
16 changes: 9 additions & 7 deletions gateway/internal/api/v1/prompts.go
@@ -1,11 +1,12 @@
package v1

import (
"bytes"
"context"
"text/template"

"connectrpc.com/connect"
"github.com/missingstudio/studio/backend/internal/prompt"
"github.com/missingstudio/studio/backend/models"
"github.com/missingstudio/studio/backend/core/prompt"
"github.com/missingstudio/studio/common/errors"
promptv1 "github.com/missingstudio/studio/protos/pkg/prompt/v1"
"google.golang.org/protobuf/types/known/emptypb"
Expand Down Expand Up @@ -37,7 +38,7 @@ func (s *V1Handler) ListPrompts(ctx context.Context, req *connect.Request[emptyp
}

func (s *V1Handler) CreatePrompt(ctx context.Context, req *connect.Request[promptv1.CreatePromptRequest]) (*connect.Response[promptv1.CreatePromptResponse], error) {
prompt := models.Prompt{
prompt := prompt.Prompt{
Name: req.Msg.Name,
Description: req.Msg.Description,
Template: req.Msg.Template,
Expand Down Expand Up @@ -92,13 +93,14 @@ func (s *V1Handler) GetPromptValue(ctx context.Context, req *connect.Request[pro
return nil, errors.NewNotFound(err.Error())
}

prompt := prompt.NewPrompt(p.Template, req.Msg.Data.AsMap())
value, err := prompt.Run()
var buf bytes.Buffer
tmpl := template.Must(template.New("prompt").Parse(p.Template))
err = tmpl.Execute(&buf, req.Msg.Data.AsMap())
if err != nil {
return nil, errors.NewNotFound(err.Error())
return nil, errors.New(err)
}

return connect.NewResponse(&promptv1.GetPromptValueResponse{
Data: value,
Data: buf.String(),
}), nil
}
2 changes: 1 addition & 1 deletion gateway/internal/api/v1/v1.go
Expand Up @@ -9,10 +9,10 @@ import (
"connectrpc.com/validate"
"connectrpc.com/vanguard"
"github.com/missingstudio/studio/backend/core/connection"
"github.com/missingstudio/studio/backend/core/prompt"
"github.com/missingstudio/studio/backend/internal/api"
"github.com/missingstudio/studio/backend/internal/ingester"
"github.com/missingstudio/studio/backend/internal/interceptor"
"github.com/missingstudio/studio/backend/internal/prompt"
"github.com/missingstudio/studio/backend/internal/providers"
"github.com/missingstudio/studio/protos/pkg/llm/v1/llmv1connect"
"github.com/missingstudio/studio/protos/pkg/prompt/v1/promptv1connect"
Expand Down
39 changes: 0 additions & 39 deletions gateway/internal/prompt/prompt.go

This file was deleted.

8 changes: 4 additions & 4 deletions gateway/internal/storage/postgres/prompt.go
Expand Up @@ -6,7 +6,7 @@ import (
"time"

"github.com/google/uuid"
"github.com/missingstudio/studio/backend/models"
"github.com/missingstudio/studio/backend/core/prompt"
)

type PromptDB struct {
Expand All @@ -19,15 +19,15 @@ type PromptDB struct {
CreatedAt time.Time `db:"created_at"`
}

func (c PromptDB) ToPrompt() (models.Prompt, error) {
func (c PromptDB) ToPrompt() (prompt.Prompt, error) {
var unmarshalledMetadata map[string]any
if len(c.Metadata) > 0 {
if err := json.Unmarshal(c.Metadata, &unmarshalledMetadata); err != nil {
return models.Prompt{}, fmt.Errorf("failed to unmarshal connection metadata(%s): %w", c.ID.String(), err)
return prompt.Prompt{}, fmt.Errorf("failed to unmarshal connection metadata(%s): %w", c.ID.String(), err)
}
}

return models.Prompt{
return prompt.Prompt{
ID: c.ID,
Name: c.Name,
Description: c.Description,
Expand Down
51 changes: 25 additions & 26 deletions gateway/internal/storage/postgres/prompt_repository.go
Expand Up @@ -9,8 +9,7 @@ import (

"github.com/doug-martin/goqu/v9"
"github.com/google/uuid"
"github.com/missingstudio/studio/backend/internal/prompt"
"github.com/missingstudio/studio/backend/models"
"github.com/missingstudio/studio/backend/core/prompt"
"github.com/missingstudio/studio/backend/pkg/database"
)

Expand All @@ -26,66 +25,66 @@ func NewPromptRepository(dbc *database.Client) *PromptRepository {
}
}

func (c *PromptRepository) GetAll(ctx context.Context) ([]models.Prompt, error) {
func (c *PromptRepository) GetAll(ctx context.Context) ([]prompt.Prompt, error) {
query, params, err := dialect.From(TABLE_PROMPTS).ToSQL()
if err != nil {
return []models.Prompt{}, fmt.Errorf("%w: %s", queryErr, err)
return []prompt.Prompt{}, fmt.Errorf("%w: %s", queryErr, err)
}

var pms []PromptDB
if err = c.dbc.WithTimeout(ctx, TABLE_PROMPTS, "List", func(ctx context.Context) error {
return c.dbc.SelectContext(ctx, &pms, query, params...)
}); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return []models.Prompt{}, fmt.Errorf("%s", err)
return []prompt.Prompt{}, fmt.Errorf("%s", err)
}
return []models.Prompt{}, fmt.Errorf("%w: %s", dbErr, err)
return []prompt.Prompt{}, fmt.Errorf("%w: %s", dbErr, err)
}

var prompts []models.Prompt
var prompts []prompt.Prompt
for _, c := range pms {
prompt, err := c.ToPrompt()
p, err := c.ToPrompt()
if err != nil {
return []models.Prompt{}, fmt.Errorf("%w: %s", parseErr, err)
return []prompt.Prompt{}, fmt.Errorf("%w: %s", parseErr, err)
}

prompts = append(prompts, prompt)
prompts = append(prompts, p)
}

return prompts, nil
}

func (*PromptRepository) GetByID(ctx context.Context, connID uuid.UUID) (models.Prompt, error) {
func (*PromptRepository) GetByID(ctx context.Context, connID uuid.UUID) (prompt.Prompt, error) {
panic("unimplemented")
}

func (c *PromptRepository) GetByName(ctx context.Context, name string) (models.Prompt, error) {
func (c *PromptRepository) GetByName(ctx context.Context, name string) (prompt.Prompt, error) {
query, params, err := dialect.From(TABLE_PROMPTS).Where(goqu.Ex{"name": name}).ToSQL()
if err != nil {
return models.Prompt{}, err
return prompt.Prompt{}, err
}

var prompt PromptDB
var pdb PromptDB
if err = c.dbc.WithTimeout(ctx, TABLE_PROMPTS, "Get", func(ctx context.Context) error {
return c.dbc.QueryRowxContext(ctx, query, params...).StructScan(&prompt)
return c.dbc.QueryRowxContext(ctx, query, params...).StructScan(&pdb)
}); err != nil {
err = checkPostgresError(err)
switch {
case errors.Is(err, ErrDuplicateKey):
return models.Prompt{}, ErrConflict
return prompt.Prompt{}, ErrConflict
default:
return models.Prompt{}, err
return prompt.Prompt{}, err
}
}

return prompt.ToPrompt()
return pdb.ToPrompt()
}

// Upsert implements prompt.Repository.
func (c *PromptRepository) Upsert(ctx context.Context, conn models.Prompt) (models.Prompt, error) {
func (c *PromptRepository) Upsert(ctx context.Context, conn prompt.Prompt) (prompt.Prompt, error) {
marshaledMetadata, err := json.Marshal(conn.Metadata)
if err != nil {
return models.Prompt{}, fmt.Errorf("namespace metadata: %w: %s", parseErr, err)
return prompt.Prompt{}, fmt.Errorf("namespace metadata: %w: %s", parseErr, err)
}

query, params, err := dialect.Insert(TABLE_PROMPTS).Rows(
Expand All @@ -100,23 +99,23 @@ func (c *PromptRepository) Upsert(ctx context.Context, conn models.Prompt) (mode
"updated_at": goqu.L("now()"),
})).Returning(&PromptDB{}).ToSQL()
if err != nil {
return models.Prompt{}, fmt.Errorf("%w: %s", queryErr, err)
return prompt.Prompt{}, fmt.Errorf("%w: %s", queryErr, err)
}

var prompt PromptDB
var pdb PromptDB
if err = c.dbc.WithTimeout(ctx, TABLE_PROMPTS, "Upsert", func(ctx context.Context) error {
return c.dbc.QueryRowxContext(ctx, query, params...).StructScan(&prompt)
return c.dbc.QueryRowxContext(ctx, query, params...).StructScan(&pdb)
}); err != nil {
err = checkPostgresError(err)
switch {
case errors.Is(err, ErrDuplicateKey):
return models.Prompt{}, ErrConflict
return prompt.Prompt{}, ErrConflict
default:
return models.Prompt{}, err
return prompt.Prompt{}, err
}
}

return prompt.ToPrompt()
return pdb.ToPrompt()
}

func (*PromptRepository) DeleteByID(ctx context.Context, connID uuid.UUID) error {
Expand Down

0 comments on commit 132b5a5

Please sign in to comment.