Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Summary Entity Recognition #251

Merged
merged 4 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ extractors:
messages:
summarizer:
enabled: true
entities:
enabled: true
embeddings:
enabled: true
dimensions: 384
Expand Down
5 changes: 3 additions & 2 deletions config/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,9 @@ type DocumentExtractorsConfig struct {
}

type SummarizerConfig struct {
Enabled bool `mapstructure:"enabled"`
Embeddings EmbeddingsConfig `mapstructure:"embeddings"`
Enabled bool `mapstructure:"enabled"`
Embeddings EmbeddingsConfig `mapstructure:"embeddings"`
Entities EntityExtractorConfig `mapstructure:"entities"`
}

type CustomPromptsConfig struct {
Expand Down
12 changes: 6 additions & 6 deletions pkg/llms/embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,21 @@ func GetEmbeddingModel(
appState *models.AppState,
documentType string,
) (*models.EmbeddingModel, error) {
var config config.EmbeddingsConfig
var cfg config.EmbeddingsConfig

switch documentType {
case "message":
config = appState.Config.Extractors.Messages.Embeddings
cfg = appState.Config.Extractors.Messages.Embeddings
case "summary":
config = appState.Config.Extractors.Messages.Summarizer.Embeddings
cfg = appState.Config.Extractors.Messages.Summarizer.Embeddings
case "document":
config = appState.Config.Extractors.Documents.Embeddings
cfg = appState.Config.Extractors.Documents.Embeddings
default:
return nil, errors.New("invalid document type")
}

return &models.EmbeddingModel{
Service: config.Service,
Dimensions: config.Dimensions,
Service: cfg.Service,
Dimensions: cfg.Dimensions,
}, nil
}
4 changes: 2 additions & 2 deletions pkg/llms/embeddings_local.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ func embedTextsLocal(

url := appState.Config.NLP.ServerURL + endpoint

documents := make([]models.TextEmbedding, len(texts))
documents := make([]models.TextData, len(texts))
for i, text := range texts {
documents[i] = models.TextEmbedding{Text: text}
documents[i] = models.TextData{Text: text}
}
collection := models.TextEmbeddingCollection{
Embeddings: documents,
Expand Down
8 changes: 4 additions & 4 deletions pkg/models/embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ type EmbeddingModel struct {
IsNormalized bool `json:"normalized"`
}

type TextEmbedding struct {
type TextData struct {
TextUUID uuid.UUID `json:"uuid,omitempty"` // MemoryStore's unique ID associated with this text.
Text string `json:"text"`
Embedding []float32 `json:"embedding,omitempty"`
Language string `json:"language"`
}

type TextEmbeddingCollection struct {
UUID uuid.UUID `json:"uuid,omitempty"`
Name string `json:"name,omitempty"`
Embeddings []TextEmbedding `json:"documents"`
UUID uuid.UUID `json:"uuid,omitempty"`
Name string `json:"name,omitempty"`
Embeddings []TextData `json:"documents"`
}
16 changes: 10 additions & 6 deletions pkg/models/memorystore.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,15 @@ type MessageStorer interface {
sessionID string,
messages []Message,
isPrivileged bool) error
// PutMessageEmbeddings stores a collection of TextEmbedding for a given sessionID.
// PutMessageEmbeddings stores a collection of TextData for a given sessionID.
PutMessageEmbeddings(ctx context.Context,
appState *AppState,
sessionID string,
embeddings []TextEmbedding) error
// GetMessageEmbeddings retrieves a collection of TextEmbedding for a given sessionID.
embeddings []TextData) error
// GetMessageEmbeddings retrieves a collection of TextData for a given sessionID.
GetMessageEmbeddings(ctx context.Context,
appState *AppState,
sessionID string) ([]TextEmbedding, error)
sessionID string) ([]TextData, error)
}

type MemoryStorer interface {
Expand Down Expand Up @@ -145,9 +145,13 @@ type SummaryStorer interface {
appState *AppState,
sessionID string,
summary *Summary) error
// PutSummaryEmbedding stores a TextEmbedding for a given sessionID and Summary UUID.
// UpdateSummaryMetadata updates the metadata for a given Summary. The Summary UUID must be set.
UpdateSummaryMetadata(ctx context.Context,
appState *AppState,
summary *Summary) error
// PutSummaryEmbedding stores a TextData for a given sessionID and Summary UUID.
PutSummaryEmbedding(ctx context.Context,
appState *AppState,
sessionID string,
embedding *TextEmbedding) error
embedding *TextData) error
}
19 changes: 16 additions & 3 deletions pkg/models/tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,34 @@ import (
"github.com/google/uuid"
)

// TaskTopic is the string-typed name of a task topic. Using a distinct type
// instead of a bare string lets the compiler reject arbitrary strings where a
// topic is expected.
type TaskTopic string

// Task topics known to the application. Each value is the literal topic name.
const (
	// Topics whose names refer to individual messages.
	MessageEmbedderTopic   TaskTopic = "message_embedder"
	MessageIntentTopic     TaskTopic = "message_intent"
	MessageNerTopic        TaskTopic = "message_ner"
	MessageSummarizerTopic TaskTopic = "message_summarizer"
	MessageTokenCountTopic TaskTopic = "message_token_count"

	// Topics whose names refer to message summaries.
	MessageSummaryEmbedderTopic TaskTopic = "message_summary_embedder"
	MessageSummaryNERTopic      TaskTopic = "message_summary_ner"

	// Topics whose names refer to documents.
	DocumentEmbedderTopic TaskTopic = "document_embedder"
)

// Task is a unit of work that can be registered with a TaskRouter via AddTask.
type Task interface {
	// Execute is invoked with a context and an event message and reports
	// failure through its error result.
	// NOTE(review): presumably called once per message delivered on the topic
	// the task was registered for — confirm against the TaskRouter implementation.
	Execute(ctx context.Context, event *message.Message) error
	// HandleError is given an error for the task to deal with; it returns nothing.
	// NOTE(review): presumably called when Execute fails — confirm against the
	// TaskRouter implementation.
	HandleError(err error)
}

type TaskRouter interface {
Run(ctx context.Context) error
AddTask(ctx context.Context, name, taskType string, task Task)
AddTask(ctx context.Context, name string, taskType TaskTopic, task Task)
RunHandlers(ctx context.Context) error
IsRunning() bool
Close() error
}

type TaskPublisher interface {
Publish(taskType string, metadata map[string]string, payload any) error
Publish(taskType TaskTopic, metadata map[string]string, payload any) error
PublishMessage(metadata map[string]string, payload []MessageTask) error
Close() error
}
Expand All @@ -30,6 +43,6 @@ type MessageTask struct {
UUID uuid.UUID `json:"uuid"`
}

type MessageSummaryEmbeddingTask struct {
type MessageSummaryTask struct {
UUID uuid.UUID `json:"uuid"`
}
20 changes: 10 additions & 10 deletions pkg/server/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,16 @@ func Create(appState *models.AppState) *http.Server {
}
}

// @title Zep REST-like API
// @version 0.x
// @license.name Apache 2.0
// @license.url http://www.apache.org/licenses/LICENSE-2.0.html
// @BasePath /api/v1
// @schemes http https
// @securityDefinitions.apikey Bearer
// @in header
// @name Authorization
// @description Type "Bearer" followed by a space and JWT token.
// @title Zep REST-like API
// @version 0.x
// @license.name Apache 2.0
// @license.url http://www.apache.org/licenses/LICENSE-2.0.html
// @BasePath /api/v1
// @schemes http https
// @securityDefinitions.apikey Bearer
// @in header
// @name Authorization
// @description Type "Bearer" followed by a space and JWT token.
func setupRouter(appState *models.AppState) *chi.Mux {
maxRequestSize := appState.Config.Server.MaxRequestSize
if maxRequestSize == 0 {
Expand Down
2 changes: 1 addition & 1 deletion pkg/store/postgres/documents_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ func TestDocumentCollectionUpdateDocuments(t *testing.T) {
t,
updatedDoc.Embedding,
returnedDoc.Embedding,
"Metadata mismatch for TextEmbedding %s",
"Metadata mismatch for TextData %s",
i,
)
}
Expand Down
86 changes: 54 additions & 32 deletions pkg/store/postgres/memorystore.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,62 @@ func (pms *PostgresMemoryStore) GetSummaryList(
return summaries, nil
}

// PutSummary stores summary for the given sessionID and then publishes a
// MessageSummaryTask carrying the stored summary's UUID to the summary
// embedder topic and the summary NER topic, in that order.
//
// A failure to store the summary is returned as a storage error. A failure to
// publish is returned wrapped, with the failing topic named in the message so
// the two publishes can be told apart. Publishing stops at the first failure;
// the summary has already been written at that point and is not rolled back
// here.
func (pms *PostgresMemoryStore) PutSummary(
	ctx context.Context,
	appState *models.AppState,
	sessionID string,
	summary *models.Summary,
) error {
	retSummary, err := putSummary(ctx, pms.Client, sessionID, summary)
	if err != nil {
		return store.NewStorageError("failed to Create summary", err)
	}

	// The same task payload is sent to every summary topic.
	task := models.MessageSummaryTask{
		UUID: retSummary.UUID,
	}
	topics := []models.TaskTopic{
		models.MessageSummaryEmbedderTopic,
		models.MessageSummaryNERTopic,
	}
	for _, topic := range topics {
		// Build a fresh metadata map per publish so no map is shared between
		// the two published messages.
		err = appState.TaskPublisher.Publish(
			topic,
			map[string]string{
				"session_id": sessionID,
			},
			task,
		)
		if err != nil {
			return fmt.Errorf("MessageSummaryTask publish to %s failed: %w", topic, err)
		}
	}

	return nil
}

// UpdateSummaryMetadata updates the stored metadata for summary by delegating
// to updateSummaryMetadata with the store's client. The AppState argument is
// unused. Per the SummaryStorer interface comment, the Summary UUID must be set.
//
// Any error from the update is returned wrapped, so callers can still inspect
// the underlying cause with errors.Is / errors.As.
func (pms *PostgresMemoryStore) UpdateSummaryMetadata(ctx context.Context,
	_ *models.AppState,
	summary *models.Summary) error {
	// The updated summary returned by the helper is not needed here.
	if _, err := updateSummaryMetadata(ctx, pms.Client, summary); err != nil {
		// Separate the context from the wrapped error with ": " so the
		// chained message stays readable.
		return fmt.Errorf("failed to update summary metadata: %w", err)
	}

	return nil
}

func (pms *PostgresMemoryStore) PutSummaryEmbedding(
ctx context.Context,
_ *models.AppState,
sessionID string,
embedding *models.TextEmbedding,
embedding *models.TextData,
) error {
err := putSummaryEmbedding(ctx, pms.Client, sessionID, embedding)
if err != nil {
Expand Down Expand Up @@ -307,35 +358,6 @@ func (pms *PostgresMemoryStore) PutMemory(
return nil
}

// PutSummary stores summary for the given sessionID and then publishes a
// MessageSummaryEmbeddingTask carrying the stored summary's UUID to the
// "message_summary_embedder" topic, with the session ID as message metadata.
//
// A failure to store the summary is returned as a storage error; a failure to
// publish is returned wrapped. The summary has already been written when the
// publish is attempted and is not rolled back here.
func (pms *PostgresMemoryStore) PutSummary(
	ctx context.Context,
	appState *models.AppState,
	sessionID string,
	summary *models.Summary,
) error {
	stored, err := putSummary(ctx, pms.Client, sessionID, summary)
	if err != nil {
		return store.NewStorageError("failed to Create summary", err)
	}

	// Tell the summary embedder which summary was just written.
	metadata := map[string]string{"session_id": sessionID}
	payload := models.MessageSummaryEmbeddingTask{UUID: stored.UUID}

	pubErr := appState.TaskPublisher.Publish("message_summary_embedder", metadata, payload)
	if pubErr != nil {
		return fmt.Errorf("MessageSummaryEmbeddingTask publish failed: %w", pubErr)
	}

	return nil
}

func (pms *PostgresMemoryStore) PutMessageMetadata(
ctx context.Context,
_ *models.AppState,
Expand Down Expand Up @@ -371,7 +393,7 @@ func (pms *PostgresMemoryStore) Close() error {
func (pms *PostgresMemoryStore) PutMessageEmbeddings(ctx context.Context,
_ *models.AppState,
sessionID string,
embeddings []models.TextEmbedding,
embeddings []models.TextData,
) error {
if embeddings == nil {
return store.NewStorageError("nil embeddings received", nil)
Expand All @@ -391,7 +413,7 @@ func (pms *PostgresMemoryStore) PutMessageEmbeddings(ctx context.Context,
func (pms *PostgresMemoryStore) GetMessageEmbeddings(ctx context.Context,
_ *models.AppState,
sessionID string,
) ([]models.TextEmbedding, error) {
) ([]models.TextData, error) {
embeddings, err := getMessageEmbeddings(ctx, pms.Client, sessionID)
if err != nil {
return nil, store.NewStorageError("GetMessageEmbeddings failed to get embeddings", err)
Expand Down
21 changes: 12 additions & 9 deletions pkg/store/postgres/memorystore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func TestPutMessages(t *testing.T) {
resultMessages, err := putMessages(testCtx, testDB, sessionID, messages)
assert.NoError(t, err, "putMessages should not return an error")

verifyMessagesInDB(t, messages, resultMessages)
verifyMessagesInDB(t, messages, resultMessages, false)
})

t.Run("upsert messages with updated TokenCount", func(t *testing.T) {
Expand All @@ -128,7 +128,7 @@ func TestPutMessages(t *testing.T) {
upsertedMessages, err := putMessages(testCtx, testDB, sessionID, insertedMessages)
assert.NoError(t, err, "putMessages should not return an error")

verifyMessagesInDB(t, insertedMessages, upsertedMessages)
verifyMessagesInDB(t, insertedMessages, upsertedMessages, true)
})

t.Run(
Expand Down Expand Up @@ -174,6 +174,7 @@ func verifyMessagesInDB(
t *testing.T,
expectedMessages,
resultMessages []models.Message,
verifyUpdatedAt bool,
) {
assert.Equal(
t,
Expand Down Expand Up @@ -214,12 +215,14 @@ func verifyMessagesInDB(
resultMessages[i].Metadata,
"Expected Metadata to be equal",
)
assert.Less(
t,
resultMessages[i].CreatedAt,
resultMessages[i].UpdatedAt,
"CreatedAt should be less than UpdatedAt",
)
if verifyUpdatedAt {
assert.Less(
t,
resultMessages[i].CreatedAt,
resultMessages[i].UpdatedAt,
"CreatedAt should be less than UpdatedAt",
)
}
}
}

Expand Down Expand Up @@ -444,7 +447,7 @@ func TestPutEmbeddingsLocal(t *testing.T) {
}

// Create embeddings
embeddings := []models.TextEmbedding{
embeddings := []models.TextData{
{
TextUUID: resultMessages[0].UUID,
Text: resultMessages[0].Content,
Expand Down
Loading
Loading