Skip to content

Commit

Permalink
fix for #52. specify columns on insert/upsert to avoid overwriting de…
Browse files Browse the repository at this point in the history
…leted_at
  • Loading branch information
danielchalef committed May 18, 2023
1 parent b18c6ac commit 4300ad2
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 3 deletions.
10 changes: 7 additions & 3 deletions pkg/memorystore/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ func putSession(
session := PgSession{SessionID: sessionID, Metadata: metadata}
_, err := db.NewInsert().
Model(&session).
Column("uuid", "session_id", "created_at", "metadata").
On("CONFLICT (session_id) DO UPDATE").
Exec(ctx)
if err != nil {
Expand Down Expand Up @@ -504,8 +505,7 @@ func getSession(

// putMessages stores a new or updates existing messages for a session. Existing
// messages are determined by message UUID. Sessions are created if they do not
// exist. We also create new PgMessageVectorStore records for each new message.
// Embedding happens out of band.
// exist.
func putMessages(
ctx context.Context,
db *bun.DB,
Expand Down Expand Up @@ -533,7 +533,11 @@ func putMessages(
pgMessages[i].SessionID = sessionID
}

_, err = db.NewInsert().Model(&pgMessages).On("CONFLICT (uuid) DO UPDATE").Exec(ctx)
_, err = db.NewInsert().
Model(&pgMessages).
Column("id", "created_at", "uuid", "session_id", "role", "content", "token_count", "metadata").
On("CONFLICT (uuid) DO UPDATE").
Exec(ctx)
if err != nil {
return nil, NewStorageError("failed to save memories to store", err)
}
Expand Down
35 changes: 35 additions & 0 deletions pkg/memorystore/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,41 @@ func TestPutMessages(t *testing.T) {
// Verify the upserted messages in the database
verifyMessagesInDB(t, testCtx, testDB, sessionID, messages, resultMessages)
})

t.Run(
"upsert messages with updated TokenCount without overwriting DeletedAt",
func(t *testing.T) {
// Get messages with UUIDs
messages, err := getMessages(testCtx, testDB, sessionID, 12, 0)
assert.NoError(t, err, "getMessages should not return an error")

// Delete using deleteSession
err = deleteSession(testCtx, testDB, sessionID)
assert.NoError(t, err, "deleteSession should not return an error")

messagesOnceDeleted, err := getMessages(testCtx, testDB, sessionID, 12, 0)
assert.NoError(t, err, "getMessages should not return an error")

// confirm that no records were returned
assert.Equal(t, 0, len(messagesOnceDeleted), "getMessages should return 0 messages")

// Update original messages with TokenCount values
for i := range messages {
messages[i].TokenCount = i + 1
}

// Call putMessages function to upsert the messages
_, err = putMessages(testCtx, testDB, sessionID, messages)
assert.NoError(t, err, "putMessages should not return an error")

messagesInDB, err := getMessages(testCtx, testDB, sessionID, 12, 0)
assert.NoError(t, err, "getMessages should not return an error")

// len(messagesInDB) should be 0 since the session was deleted
assert.Equal(t, 0, len(messagesInDB), "getMessages should return 0 messages")
},
)

}

func verifyMessagesInDB(
Expand Down

0 comments on commit 4300ad2

Please sign in to comment.