Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improv: Migrate Message Embedding column only if schema vector width != config #189

Merged
merged 1 commit into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 43 additions & 7 deletions pkg/store/postgres/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -403,24 +403,60 @@ func CreateSchema(
}
}

// check that the message embedding dimensions match the configured model
if err := checkMessageEmbeddingDims(ctx, appState, db); err != nil {
return fmt.Errorf("error checking message embedding dimensions: %w", err)
}

// apply migrations
if err := migrations.Migrate(ctx, db); err != nil {
return fmt.Errorf("failed to apply migrations: %w", err)
}

return nil
}

// checkMessageEmbeddingDims checks the dimensions of the message embedding column against the
// dimensions of the configured message embedding model. If they do not match, the column is dropped and
// recreated with the correct dimensions.
//
// It looks up the "message" embedding model via llms.GetEmbeddingModel, reads the current column
// width of the message_embedding table with getEmbeddingColumnWidth, and calls
// MigrateMessageEmbeddingDims only when the two values differ. A warning is logged before
// migrating because existing embedding vectors may be lost. Any failure is returned wrapped
// with context; otherwise nil is returned.
func checkMessageEmbeddingDims(ctx context.Context, appState *models.AppState, db *bun.DB) error {
model, err := llms.GetEmbeddingModel(appState, "message")
if err != nil {
return fmt.Errorf("error getting message embedding model: %w", err)
}
// NOTE(review): the next two lines look like the pre-change 1536 guard left over from the
// diff rendering (removed lines shown without markers) — confirm against the merged file.
// we keep this at 1536 for legacy reasons, despite the default now being 384
if model.Dimensions != 1536 {
// Read the width the column currently has in the database schema.
width, err := getEmbeddingColumnWidth(ctx, "message_embedding", db)
if err != nil {
return fmt.Errorf("error getting embedding column width: %w", err)
}

// Only migrate when the schema width differs from the configured model's dimensions.
if width != model.Dimensions {
log.Warnf(
"message embedding dimensions are %d, expected %d.\n migrating message embedding column width to %d. this may result in loss of existing embedding vectors",
width,
model.Dimensions,
model.Dimensions,
)
err := MigrateMessageEmbeddingDims(ctx, db, model.Dimensions)
if err != nil {
return fmt.Errorf("error migrating message embedding dimensions: %w", err)
}
}
return nil
}

// apply migrations
if err := migrations.Migrate(ctx, db); err != nil {
return fmt.Errorf("failed to apply migrations: %w", err)
// getEmbeddingColumnWidth returns the width of the embedding column in the provided table.
//
// The width is read from pg_attribute.atttypmod for the column named "embedding" of tableName,
// which is resolved to a relation through a ::regclass cast. If the query or scan fails, it
// returns 0 and the wrapped error.
func getEmbeddingColumnWidth(ctx context.Context, tableName string, db *bun.DB) (int, error) {
var width int
err := db.NewSelect().
Table("pg_attribute").
ColumnExpr("atttypmod"). // vector width is stored in atttypmod
Where("attrelid = ?::regclass", tableName).
Where("attname = 'embedding'").
Scan(ctx, &width)
if err != nil {
return 0, fmt.Errorf("error getting embedding column width: %w", err)
}

// NOTE(review): the bare "return nil" below does not match the (int, error) signature and
// looks like a removed diff line rendered without its marker — confirm against the merged file.
return nil
return width, nil
}

// MigrateMessageEmbeddingDims drops the old embedding column and creates a new one with the
Expand Down
28 changes: 28 additions & 0 deletions pkg/store/postgres/schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"testing"
"time"

"github.com/getzep/zep/pkg/llms"
"github.com/stretchr/testify/assert"
"github.com/uptrace/bun"
)
Expand Down Expand Up @@ -76,3 +77,30 @@ func TestUpdatedAtIsSetAfterUpdate(t *testing.T) {
)
}
}

// TestCheckMessageEmbeddingDims verifies that the message embedding column width can be read
// back from the schema, and that checkMessageEmbeddingDims migrates the column back to the
// configured model's dimensions when the two have drifted apart.
func TestCheckMessageEmbeddingDims(t *testing.T) {
	// Start from a freshly created schema.
	CleanDB(t, testDB)
	if err := CreateSchema(testCtx, appState, testDB); !assert.NoError(t, err) {
		return
	}
	// Leave a clean schema behind for other tests, even if an assertion below fails.
	t.Cleanup(func() {
		CleanDB(t, testDB)
		assert.NoError(t, CreateSchema(testCtx, appState, testDB))
	})

	// Get the embedding model; stop here on failure so model is never dereferenced
	// after a failed lookup.
	model, err := llms.GetEmbeddingModel(appState, "message")
	if !assert.NoError(t, err) {
		return
	}

	// Force the embedding column to a width that differs from the configured model.
	testWidth := model.Dimensions + 1
	err = MigrateMessageEmbeddingDims(testCtx, testDB, testWidth)
	assert.NoError(t, err)

	width, err := getEmbeddingColumnWidth(testCtx, "message_embedding", testDB)
	assert.NoError(t, err)
	// testify's Equal takes (expected, actual); keeping that order makes failure output accurate.
	assert.Equal(t, testWidth, width)

	// The check should detect the mismatch and migrate the column back to the model's width.
	err = checkMessageEmbeddingDims(testCtx, appState, testDB)
	assert.NoError(t, err)

	width, err = getEmbeddingColumnWidth(testCtx, "message_embedding", testDB)
	assert.NoError(t, err)
	assert.Equal(t, model.Dimensions, width)
}
Loading