diff --git a/pkg/memorystore/postgres.go b/pkg/memorystore/postgres.go index 4a26884f..f35d9703 100644 --- a/pkg/memorystore/postgres.go +++ b/pkg/memorystore/postgres.go @@ -463,6 +463,7 @@ func putSession( session := PgSession{SessionID: sessionID, Metadata: metadata} _, err := db.NewInsert(). Model(&session). + Column("uuid", "session_id", "created_at", "metadata"). On("CONFLICT (session_id) DO UPDATE"). Exec(ctx) if err != nil { @@ -504,8 +505,7 @@ func getSession( // putMessages stores a new or updates existing messages for a session. Existing // messages are determined by message UUID. Sessions are created if they do not -// exist. We also create new PgMessageVectorStore records for each new message. -// Embedding happens out of band. +// exist. func putMessages( ctx context.Context, db *bun.DB, @@ -533,7 +533,11 @@ func putMessages( pgMessages[i].SessionID = sessionID } - _, err = db.NewInsert().Model(&pgMessages).On("CONFLICT (uuid) DO UPDATE").Exec(ctx) + _, err = db.NewInsert(). + Model(&pgMessages). + Column("id", "created_at", "uuid", "session_id", "role", "content", "token_count", "metadata"). + On("CONFLICT (uuid) DO UPDATE"). + Exec(ctx) if err != nil { return nil, NewStorageError("failed to save memories to store", err) } diff --git a/pkg/memorystore/postgres_test.go b/pkg/memorystore/postgres_test.go index 8ab9f4c7..4ab5487f 100644 --- a/pkg/memorystore/postgres_test.go +++ b/pkg/memorystore/postgres_test.go @@ -298,6 +298,41 @@ func TestPutMessages(t *testing.T) { // Verify the upserted messages in the database verifyMessagesInDB(t, testCtx, testDB, sessionID, messages, resultMessages) }) + + t.Run( + "upsert messages with updated TokenCount without overwriting DeletedAt", + func(t *testing.T) { + // Get messages with UUIDs + messages, err := getMessages(testCtx, testDB, sessionID, 12, 0) + assert.NoError(t, err, "getMessages should not return an error") + + // Delete using deleteSession + err = deleteSession(testCtx, testDB, sessionID) + assert.NoError(t, err, "deleteSession should not return an error") + + messagesOnceDeleted, err := getMessages(testCtx, testDB, sessionID, 12, 0) + assert.NoError(t, err, "getMessages should not return an error") + + // confirm that no records were returned + assert.Equal(t, 0, len(messagesOnceDeleted), "getMessages should return 0 messages") + + // Update original messages with TokenCount values + for i := range messages { + messages[i].TokenCount = i + 1 + } + + // Call putMessages function to upsert the messages + _, err = putMessages(testCtx, testDB, sessionID, messages) + assert.NoError(t, err, "putMessages should not return an error") + + messagesInDB, err := getMessages(testCtx, testDB, sessionID, 12, 0) + assert.NoError(t, err, "getMessages should not return an error") + + // len(messagesInDB) should be 0 since the session was deleted + assert.Equal(t, 0, len(messagesInDB), "getMessages should return 0 messages") + }, + ) + } func verifyMessagesInDB(