diff --git a/gateway/cmd/serve.go b/gateway/cmd/serve.go index 96073a2..62d0228 100644 --- a/gateway/cmd/serve.go +++ b/gateway/cmd/serve.go @@ -11,9 +11,9 @@ import ( _ "github.com/jackc/pgx/v5/stdlib" "github.com/missingstudio/studio/backend/config" "github.com/missingstudio/studio/backend/core/connection" + "github.com/missingstudio/studio/backend/core/prompt" "github.com/missingstudio/studio/backend/internal/api" "github.com/missingstudio/studio/backend/internal/ingester" - "github.com/missingstudio/studio/backend/internal/prompt" "github.com/missingstudio/studio/backend/internal/providers" "github.com/missingstudio/studio/backend/internal/ratelimiter" "github.com/missingstudio/studio/backend/internal/server" diff --git a/gateway/models/prompt.go b/gateway/core/prompt/prompt.go similarity index 96% rename from gateway/models/prompt.go rename to gateway/core/prompt/prompt.go index 5496077..d792d4d 100644 --- a/gateway/models/prompt.go +++ b/gateway/core/prompt/prompt.go @@ -1,4 +1,4 @@ -package models +package prompt import "github.com/google/uuid" diff --git a/gateway/internal/prompt/service.go b/gateway/core/prompt/service.go similarity index 56% rename from gateway/internal/prompt/service.go rename to gateway/core/prompt/service.go index a0e7e93..ad9cc46 100644 --- a/gateway/internal/prompt/service.go +++ b/gateway/core/prompt/service.go @@ -5,9 +5,16 @@ import ( "fmt" "github.com/google/uuid" - "github.com/missingstudio/studio/backend/models" ) +type Repository interface { + GetAll(context.Context) ([]Prompt, error) + Upsert(context.Context, Prompt) (Prompt, error) + GetByID(context.Context, uuid.UUID) (Prompt, error) + GetByName(context.Context, string) (Prompt, error) + DeleteByID(context.Context, uuid.UUID) error +} + var _ Repository = &Service{} type Service struct { @@ -24,7 +31,7 @@ func (s *Service) DeleteByID(ctx context.Context, promptID uuid.UUID) error { return s.promptRepo.DeleteByID(ctx, promptID) } -func (s *Service) GetAll(ctx context.Context) ([]models.Prompt, error) { +func (s *Service) GetAll(ctx context.Context) ([]Prompt, error) { prompts, err := s.promptRepo.GetAll(ctx) if err != nil { return nil, err @@ -32,28 +39,28 @@ func (s *Service) GetAll(ctx context.Context) ([]models.Prompt, error) { return prompts, nil } -func (s *Service) GetByID(ctx context.Context, promptID uuid.UUID) (models.Prompt, error) { +func (s *Service) GetByID(ctx context.Context, promptID uuid.UUID) (Prompt, error) { prompt, err := s.promptRepo.GetByID(ctx, promptID) if err != nil { - return models.Prompt{}, err + return Prompt{}, err } return prompt, err } -func (s *Service) GetByName(ctx context.Context, name string) (models.Prompt, error) { +func (s *Service) GetByName(ctx context.Context, name string) (Prompt, error) { prompt, err := s.promptRepo.GetByName(ctx, name) if err != nil { - return models.Prompt{}, err + return Prompt{}, err } return prompt, err } -func (s *Service) Upsert(ctx context.Context, c models.Prompt) (models.Prompt, error) { +func (s *Service) Upsert(ctx context.Context, c Prompt) (Prompt, error) { id, err := s.promptRepo.Upsert(ctx, c) if err != nil { - return models.Prompt{}, fmt.Errorf("failed to save prompt: %w", err) + return Prompt{}, fmt.Errorf("failed to save prompt: %w", err) } return id, err diff --git a/gateway/internal/api/deps.go b/gateway/internal/api/deps.go index bd3635f..7753542 100644 --- a/gateway/internal/api/deps.go +++ b/gateway/internal/api/deps.go @@ -4,8 +4,8 @@ import ( "log/slog" "github.com/missingstudio/studio/backend/core/connection" + "github.com/missingstudio/studio/backend/core/prompt" "github.com/missingstudio/studio/backend/internal/ingester" - "github.com/missingstudio/studio/backend/internal/prompt" "github.com/missingstudio/studio/backend/internal/providers" "github.com/missingstudio/studio/backend/internal/ratelimiter" ) diff --git a/gateway/internal/api/v1/prompts.go b/gateway/internal/api/v1/prompts.go index b50ceb4..8df9cc2 100644 --- a/gateway/internal/api/v1/prompts.go +++ b/gateway/internal/api/v1/prompts.go @@ -1,11 +1,12 @@ package v1 import ( + "bytes" "context" + "text/template" "connectrpc.com/connect" - "github.com/missingstudio/studio/backend/internal/prompt" - "github.com/missingstudio/studio/backend/models" + "github.com/missingstudio/studio/backend/core/prompt" "github.com/missingstudio/studio/common/errors" promptv1 "github.com/missingstudio/studio/protos/pkg/prompt/v1" "google.golang.org/protobuf/types/known/emptypb" @@ -37,7 +38,7 @@ func (s *V1Handler) ListPrompts(ctx context.Context, req *connect.Request[emptyp } func (s *V1Handler) CreatePrompt(ctx context.Context, req *connect.Request[promptv1.CreatePromptRequest]) (*connect.Response[promptv1.CreatePromptResponse], error) { - prompt := models.Prompt{ + prompt := prompt.Prompt{ Name: req.Msg.Name, Description: req.Msg.Description, Template: req.Msg.Template, @@ -92,13 +93,14 @@ func (s *V1Handler) GetPromptValue(ctx context.Context, req *connect.Request[pro return nil, errors.NewNotFound(err.Error()) } - prompt := prompt.NewPrompt(p.Template, req.Msg.Data.AsMap()) - value, err := prompt.Run() + var buf bytes.Buffer + tmpl := template.Must(template.New("prompt").Parse(p.Template)) + err = tmpl.Execute(&buf, req.Msg.Data.AsMap()) if err != nil { - return nil, errors.NewNotFound(err.Error()) + return nil, errors.New(err) } return connect.NewResponse(&promptv1.GetPromptValueResponse{ - Data: value, + Data: buf.String(), }), nil } diff --git a/gateway/internal/api/v1/v1.go b/gateway/internal/api/v1/v1.go index a07c630..e95376c 100644 --- a/gateway/internal/api/v1/v1.go +++ b/gateway/internal/api/v1/v1.go @@ -9,10 +9,10 @@ import ( "connectrpc.com/validate" "connectrpc.com/vanguard" "github.com/missingstudio/studio/backend/core/connection" + "github.com/missingstudio/studio/backend/core/prompt" "github.com/missingstudio/studio/backend/internal/api" "github.com/missingstudio/studio/backend/internal/ingester" "github.com/missingstudio/studio/backend/internal/interceptor" - "github.com/missingstudio/studio/backend/internal/prompt" "github.com/missingstudio/studio/backend/internal/providers" "github.com/missingstudio/studio/protos/pkg/llm/v1/llmv1connect" "github.com/missingstudio/studio/protos/pkg/prompt/v1/promptv1connect" diff --git a/gateway/internal/prompt/prompt.go b/gateway/internal/prompt/prompt.go deleted file mode 100644 index bc09ff6..0000000 --- a/gateway/internal/prompt/prompt.go +++ /dev/null @@ -1,39 +0,0 @@ -package prompt - -import ( - "bytes" - "context" - "html/template" - - "github.com/google/uuid" - "github.com/missingstudio/studio/backend/models" -) - -type Repository interface { - GetAll(context.Context) ([]models.Prompt, error) - Upsert(context.Context, models.Prompt) (models.Prompt, error) - GetByID(context.Context, uuid.UUID) (models.Prompt, error) - GetByName(context.Context, string) (models.Prompt, error) - DeleteByID(context.Context, uuid.UUID) error -} - -type Prompt struct { - tmpl *template.Template - data map[string]any -} - -func NewPrompt(text string, data map[string]any) *Prompt { - return &Prompt{ - tmpl: template.Must(template.New("prompt").Parse(text)), - data: data, - } -} - -func (p *Prompt) Run() (string, error) { - var buf bytes.Buffer - err := p.tmpl.Execute(&buf, p.data) - if err != nil { - return "", err - } - return buf.String(), nil -} diff --git a/gateway/internal/storage/postgres/prompt.go b/gateway/internal/storage/postgres/prompt.go index 1035b66..782b764 100644 --- a/gateway/internal/storage/postgres/prompt.go +++ b/gateway/internal/storage/postgres/prompt.go @@ -6,7 +6,7 @@ import ( "time" "github.com/google/uuid" - "github.com/missingstudio/studio/backend/models" + "github.com/missingstudio/studio/backend/core/prompt" ) type PromptDB struct { @@ -19,15 +19,15 @@ type PromptDB struct { CreatedAt time.Time `db:"created_at"` } -func (c PromptDB) ToPrompt() (models.Prompt, error) { +func (c PromptDB) ToPrompt() (prompt.Prompt, error) { var unmarshalledMetadata map[string]any if len(c.Metadata) > 0 { if err := json.Unmarshal(c.Metadata, &unmarshalledMetadata); err != nil { - return models.Prompt{}, fmt.Errorf("failed to unmarshal connection metadata(%s): %w", c.ID.String(), err) + return prompt.Prompt{}, fmt.Errorf("failed to unmarshal connection metadata(%s): %w", c.ID.String(), err) } } - return models.Prompt{ + return prompt.Prompt{ ID: c.ID, Name: c.Name, Description: c.Description, diff --git a/gateway/internal/storage/postgres/prompt_repository.go b/gateway/internal/storage/postgres/prompt_repository.go index f993fd9..1230a0c 100644 --- a/gateway/internal/storage/postgres/prompt_repository.go +++ b/gateway/internal/storage/postgres/prompt_repository.go @@ -9,8 +9,7 @@ import ( "github.com/doug-martin/goqu/v9" "github.com/google/uuid" - "github.com/missingstudio/studio/backend/internal/prompt" - "github.com/missingstudio/studio/backend/models" + "github.com/missingstudio/studio/backend/core/prompt" "github.com/missingstudio/studio/backend/pkg/database" ) @@ -26,10 +25,10 @@ func NewPromptRepository(dbc *database.Client) *PromptRepository { } } -func (c *PromptRepository) GetAll(ctx context.Context) ([]models.Prompt, error) { +func (c *PromptRepository) GetAll(ctx context.Context) ([]prompt.Prompt, error) { query, params, err := dialect.From(TABLE_PROMPTS).ToSQL() if err != nil { - return []models.Prompt{}, fmt.Errorf("%w: %s", queryErr, err) + return []prompt.Prompt{}, fmt.Errorf("%w: %s", queryErr, err) } var pms []PromptDB @@ -37,55 +36,55 @@ func (c *PromptRepository) GetAll(ctx context.Context) ([]models.Prompt, error) return c.dbc.SelectContext(ctx, &pms, query, params...) }); err != nil { if errors.Is(err, sql.ErrNoRows) { - return []models.Prompt{}, fmt.Errorf("%s", err) + return []prompt.Prompt{}, fmt.Errorf("%s", err) } - return []models.Prompt{}, fmt.Errorf("%w: %s", dbErr, err) + return []prompt.Prompt{}, fmt.Errorf("%w: %s", dbErr, err) } - var prompts []models.Prompt + var prompts []prompt.Prompt for _, c := range pms { - prompt, err := c.ToPrompt() + p, err := c.ToPrompt() if err != nil { - return []models.Prompt{}, fmt.Errorf("%w: %s", parseErr, err) + return []prompt.Prompt{}, fmt.Errorf("%w: %s", parseErr, err) } - prompts = append(prompts, prompt) + prompts = append(prompts, p) } return prompts, nil } -func (*PromptRepository) GetByID(ctx context.Context, connID uuid.UUID) (models.Prompt, error) { +func (*PromptRepository) GetByID(ctx context.Context, connID uuid.UUID) (prompt.Prompt, error) { panic("unimplemented") } -func (c *PromptRepository) GetByName(ctx context.Context, name string) (models.Prompt, error) { +func (c *PromptRepository) GetByName(ctx context.Context, name string) (prompt.Prompt, error) { query, params, err := dialect.From(TABLE_PROMPTS).Where(goqu.Ex{"name": name}).ToSQL() if err != nil { - return models.Prompt{}, err + return prompt.Prompt{}, err } - var prompt PromptDB + var pdb PromptDB if err = c.dbc.WithTimeout(ctx, TABLE_PROMPTS, "Get", func(ctx context.Context) error { - return c.dbc.QueryRowxContext(ctx, query, params...).StructScan(&prompt) + return c.dbc.QueryRowxContext(ctx, query, params...).StructScan(&pdb) }); err != nil { err = checkPostgresError(err) switch { case errors.Is(err, ErrDuplicateKey): - return models.Prompt{}, ErrConflict + return prompt.Prompt{}, ErrConflict default: - return models.Prompt{}, err + return prompt.Prompt{}, err } } - return prompt.ToPrompt() + return pdb.ToPrompt() } // Upsert implements prompt.Repository. -func (c *PromptRepository) Upsert(ctx context.Context, conn models.Prompt) (models.Prompt, error) { +func (c *PromptRepository) Upsert(ctx context.Context, conn prompt.Prompt) (prompt.Prompt, error) { marshaledMetadata, err := json.Marshal(conn.Metadata) if err != nil { - return models.Prompt{}, fmt.Errorf("namespace metadata: %w: %s", parseErr, err) + return prompt.Prompt{}, fmt.Errorf("namespace metadata: %w: %s", parseErr, err) } query, params, err := dialect.Insert(TABLE_PROMPTS).Rows( @@ -100,23 +99,23 @@ func (c *PromptRepository) Upsert(ctx context.Context, conn models.Prompt) (mode "updated_at": goqu.L("now()"), })).Returning(&PromptDB{}).ToSQL() if err != nil { - return models.Prompt{}, fmt.Errorf("%w: %s", queryErr, err) + return prompt.Prompt{}, fmt.Errorf("%w: %s", queryErr, err) } - var prompt PromptDB + var pdb PromptDB if err = c.dbc.WithTimeout(ctx, TABLE_PROMPTS, "Upsert", func(ctx context.Context) error { - return c.dbc.QueryRowxContext(ctx, query, params...).StructScan(&prompt) + return c.dbc.QueryRowxContext(ctx, query, params...).StructScan(&pdb) }); err != nil { err = checkPostgresError(err) switch { case errors.Is(err, ErrDuplicateKey): - return models.Prompt{}, ErrConflict + return prompt.Prompt{}, ErrConflict default: - return models.Prompt{}, err + return prompt.Prompt{}, err } } - return prompt.ToPrompt() + return pdb.ToPrompt() } func (*PromptRepository) DeleteByID(ctx context.Context, connID uuid.UUID) error {