Skip to content

Commit

Permalink
implement BeforeAppendModelHook to set UpdatedAt (#187)
Browse files Browse the repository at this point in the history
  • Loading branch information
danielchalef committed Sep 17, 2023
1 parent b2dbf77 commit 0822b30
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 5 deletions.
56 changes: 54 additions & 2 deletions pkg/store/postgres/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ type SessionSchema struct {
User *UserSchema `bun:"rel:belongs-to,join:user_id=user_id,on_delete:cascade" yaml:"-"`
}

var _ bun.BeforeAppendModelHook = (*SessionSchema)(nil)

func (s *SessionSchema) BeforeAppendModel(_ context.Context, query bun.Query) error {
if _, ok := query.(*bun.UpdateQuery); ok {
s.UpdatedAt = time.Now()
}
return nil
}

// BeforeCreateTable is a marker method to ensure uniform interface across all table models - used in table creation iterator
func (s *SessionSchema) BeforeCreateTable(
_ context.Context,
Expand All @@ -47,8 +56,6 @@ func (s *SessionSchema) BeforeCreateTable(
type MessageStoreSchema struct {
bun.BaseModel `bun:"table:message,alias:m" yaml:"-"`

// TODO: replace UUIDs with sortable ULIDs or UUIDv7s to avoid having to have both a UUID and an ID.
// see https://blog.daveallie.com/ulid-primary-keys
UUID uuid.UUID `bun:",pk,type:uuid,default:gen_random_uuid()" yaml:"uuid"`
// ID is used only for sorting / slicing purposes as we can't sort by CreatedAt for messages created simultaneously
ID int64 `bun:",autoincrement" yaml:"id,omitempty"`
Expand All @@ -63,6 +70,15 @@ type MessageStoreSchema struct {
Session *SessionSchema `bun:"rel:belongs-to,join:session_id=session_id,on_delete:cascade" yaml:"-"`
}

var _ bun.BeforeAppendModelHook = (*MessageStoreSchema)(nil)

func (s *MessageStoreSchema) BeforeAppendModel(_ context.Context, query bun.Query) error {
if _, ok := query.(*bun.UpdateQuery); ok {
s.UpdatedAt = time.Now()
}
return nil
}

func (s *MessageStoreSchema) BeforeCreateTable(
_ context.Context,
_ *bun.CreateTableQuery,
Expand All @@ -86,6 +102,15 @@ type MessageVectorStoreSchema struct {
Message *MessageStoreSchema `bun:"rel:belongs-to,join:message_uuid=uuid,on_delete:cascade"`
}

var _ bun.BeforeAppendModelHook = (*MessageVectorStoreSchema)(nil)

func (s *MessageVectorStoreSchema) BeforeAppendModel(_ context.Context, query bun.Query) error {
if _, ok := query.(*bun.UpdateQuery); ok {
s.UpdatedAt = time.Now()
}
return nil
}

func (s *MessageVectorStoreSchema) BeforeCreateTable(
_ context.Context,
_ *bun.CreateTableQuery,
Expand All @@ -109,6 +134,15 @@ type SummaryStoreSchema struct {
Message *MessageStoreSchema `bun:"rel:belongs-to,join:summary_point_uuid=uuid,on_delete:cascade"`
}

var _ bun.BeforeAppendModelHook = (*SummaryStoreSchema)(nil)

func (s *SummaryStoreSchema) BeforeAppendModel(_ context.Context, query bun.Query) error {
if _, ok := query.(*bun.UpdateQuery); ok {
s.UpdatedAt = time.Now()
}
return nil
}

func (s *SummaryStoreSchema) BeforeCreateTable(
_ context.Context,
_ *bun.CreateTableQuery,
Expand All @@ -129,6 +163,15 @@ func (s *DocumentCollectionSchema) BeforeCreateTable(
return nil
}

var _ bun.BeforeAppendModelHook = (*DocumentCollectionSchema)(nil)

func (s *DocumentCollectionSchema) BeforeAppendModel(_ context.Context, query bun.Query) error {
if _, ok := query.(*bun.UpdateQuery); ok {
s.UpdatedAt = time.Now()
}
return nil
}

// DocumentSchemaTemplate represents the schema template for Document tables.
// MessageEmbedding is manually added when createDocumentTable is run in order to set the correct dimensions.
// This means the embedding is not returned when querying using bun.
Expand All @@ -152,6 +195,15 @@ type UserSchema struct {
Metadata map[string]interface{} `bun:"type:jsonb,nullzero,json_use_number" yaml:"metadata,omitempty"`
}

var _ bun.BeforeAppendModelHook = (*UserSchema)(nil)

func (u *UserSchema) BeforeAppendModel(_ context.Context, query bun.Query) error {
if _, ok := query.(*bun.UpdateQuery); ok {
u.UpdatedAt = time.Now()
}
return nil
}

// BeforeCreateTable is a marker method to ensure uniform interface across all table models - used in table creation iterator
func (u *UserSchema) BeforeCreateTable(
_ context.Context,
Expand Down
46 changes: 43 additions & 3 deletions pkg/store/postgres/schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ package postgres

import (
"context"
"reflect"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/uptrace/bun"
)

func TestEnsurePostgresSchemaSetup(t *testing.T) {
Expand All @@ -26,13 +29,50 @@ func TestEnsurePostgresSchemaSetup(t *testing.T) {
}

func TestCreateDocumentTable(t *testing.T) {
ctx := context.Background()

collection := NewTestCollectionDAO(3)

tableName, err := generateDocumentTableName(&collection)
assert.NoError(t, err)

err = createDocumentTable(ctx, testDB, tableName, collection.EmbeddingDimensions)
err = createDocumentTable(testCtx, testDB, tableName, collection.EmbeddingDimensions)
assert.NoError(t, err)
}

func TestUpdatedAtIsSetAfterUpdate(t *testing.T) {
// Define a list of all schemas
schemas := []bun.BeforeAppendModelHook{
&SessionSchema{},
&MessageStoreSchema{},
&SummaryStoreSchema{},
&MessageVectorStoreSchema{},
&UserSchema{},
&DocumentCollectionSchema{},
}

// Iterate over all schemas
for _, schema := range schemas {
// Create a new instance of the schema
instance := reflect.New(reflect.TypeOf(schema).Elem()).Interface().(bun.BeforeAppendModelHook)

// Set the UpdatedAt field to a time far in the past
reflect.ValueOf(instance).
Elem().
FieldByName("UpdatedAt").
Set(reflect.ValueOf(time.Unix(0, 0)))

// Create a dummy UpdateQuery
updateQuery := &bun.UpdateQuery{}

// Call the BeforeAppendModel method, which should update the UpdatedAt field
err := instance.BeforeAppendModel(context.Background(), updateQuery)
assert.NoError(t, err)

// Check that the UpdatedAt field was updated
assert.True(
t,
reflect.ValueOf(instance).Elem().FieldByName("UpdatedAt").Interface().(time.Time).After(
time.Now().Add(-time.Minute),
),
)
}
}

0 comments on commit 0822b30

Please sign in to comment.