Skip to content
This repository was archived by the owner on Jun 26, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type GraphDB interface {
DeleteEdge(ctx context.Context, user User, ID string) error
NodeEdits(ctx context.Context, ID string) ([]*model.NodeEdit, error)
EdgeEdits(ctx context.Context, ID string) ([]*model.EdgeEdit, error)
NodeMatchFuzzy(ctx context.Context, substring string) ([]*model.Node, error)
}

type UserDB interface {
Expand Down
15 changes: 15 additions & 0 deletions db/db_mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

46 changes: 45 additions & 1 deletion db/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,31 @@ type PostgresDB struct {
}

func (pg *PostgresDB) init() (db.DB, error) {
return pg, pg.db.AutoMigrate(
// Auto-migrate the models
err := pg.db.AutoMigrate(
&Node{}, &Edge{}, &NodeEdit{}, &EdgeEdit{}, &AuthenticationToken{}, &User{}, &Role{},
)
if err != nil {
return nil, err
}
err = pg.db.Exec("CREATE EXTENSION IF NOT EXISTS pg_trgm;").Error
if err != nil {
return nil, err
}
// TODO(skep): only 'en' language is indexed right now! do it for all languages dynamically?!
err = pg.db.Exec(`
CREATE INDEX IF NOT EXISTS idx_nodes_description_text_trgm
ON nodes USING GIN ((description->>'en') gin_trgm_ops);
`).Error
if err != nil {
return nil, err
}
// At least 0.2 is needed since the typo "aplpe" has similarity of 0.2 for "Apple".
err = pg.db.Exec(`SET pg_trgm.similarity_threshold=0.2;`).Error
if err != nil {
return nil, err
}
return pg, nil
}

func removeArangoPrefix(s string) string {
Expand Down Expand Up @@ -626,3 +648,25 @@ func (pg *PostgresDB) EdgeEdits(ctx context.Context, ID string) ([]*model.EdgeEd
lang := middleware.CtxGetLanguage(ctx)
return NewConvertToModel(lang).EdgeEdits(edits), nil
}

func (pg *PostgresDB) NodeMatchFuzzy(ctx context.Context, substring string) ([]*model.Node, error) {
nodes := []Node{}
limit := 50 // TODO: adjust the limit
substring = strings.ToLower(substring)
err := pg.db.WithContext(ctx).
Select("*, similarity(description->>'en', ?) as sim", substring). // 'similarity' is pg_trgm operator
Where("(description->>'en') % ?", substring). // % is the similarity operator of pg_trgm
Order("sim DESC").
Limit(limit).
Find(&nodes).Error
if err != nil {
return nil, err
}
var result []*model.Node
converter := NewConvertToModel(middleware.CtxGetLanguage(ctx))
for _, n := range nodes {
result = append(result, converter.Node(n))
}

return result, nil
}
215 changes: 215 additions & 0 deletions db/postgres/postgres_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"context"
"fmt"
"reflect"
"strconv"
"testing"
"time"

Expand All @@ -27,6 +28,103 @@ func TestPostgresDB_NewPostgresDB(t *testing.T) {
assert.NoError(err)
}

func TestPostgresDB_Init(t *testing.T) {
for _, test := range []struct {
Name string
DropIndex bool
DropExtension bool
ExpectedError bool
ExpectedIndex bool
ExpectedExtension bool
}{
{
Name: "Initial setup with no prior index or extension",
DropIndex: false,
DropExtension: false,
ExpectedError: false,
ExpectedIndex: true,
ExpectedExtension: true,
},
{
Name: "Re-initialize with existing index and extension",
DropIndex: false,
DropExtension: false,
ExpectedError: false,
ExpectedIndex: true,
ExpectedExtension: true,
},
{
Name: "Re-initialize after dropping index",
DropIndex: true,
DropExtension: false,
ExpectedError: false,
ExpectedIndex: true,
ExpectedExtension: true,
},
{
Name: "Re-initialize after dropping extension",
DropIndex: false,
DropExtension: true,
ExpectedError: false,
ExpectedIndex: true,
ExpectedExtension: true,
},
{
Name: "Re-initialize after dropping both index and extension",
DropIndex: true,
DropExtension: true,
ExpectedError: false,
ExpectedIndex: true,
ExpectedExtension: true,
},
} {
t.Run(test.Name, func(t *testing.T) {
// Set up the database
pg := setupDB(t)
assert := assert.New(t)
// Optionally drop the index
if test.DropIndex {
err := pg.db.Exec("DROP INDEX IF EXISTS idx_nodes_description_text_trgm;").Error
assert.NoError(err)
}
// Optionally drop the extension
if test.DropExtension {
err := pg.db.Exec("DROP EXTENSION IF EXISTS pg_trgm CASCADE;").Error
assert.NoError(err)
}
// Re-initialize the database
_, err := pg.init()
if test.ExpectedError {
assert.Error(err)
return
} else {
assert.NoError(err)
}
// Verify that the extension exists
var extCount int
err = pg.db.Raw("SELECT COUNT(*) FROM pg_extension WHERE extname = 'pg_trgm';").Scan(&extCount).Error
assert.NoError(err)
if test.ExpectedExtension {
assert.Equal(1, extCount, "Extension pg_trgm should exist after init")
} else {
assert.Equal(0, extCount, "Extension pg_trgm should not exist")
}
// Verify that the index exists
var idxCount int
err = pg.db.Raw(`
SELECT COUNT(*) FROM pg_indexes
WHERE indexname = 'idx_nodes_description_text_trgm';
`).Scan(&idxCount).Error
assert.NoError(err)
if test.ExpectedIndex {
assert.Equal(1, idxCount, "Index idx_nodes_description_text_trgm should exist after init")
} else {
assert.Equal(0, idxCount, "Index idx_nodes_description_text_trgm should not exist")
}
})
}
}

func TestPostgresDB_CreateNode(t *testing.T) {
for _, test := range []struct {
Name string
Expand Down Expand Up @@ -1266,6 +1364,123 @@ func TestPostgresDB_EdgeEdits(t *testing.T) {
}
}

func TestPostgresDB_NodeMatchFuzzy(t *testing.T) {
// Define test cases
for _, test := range []struct {
Name string
Substring string
ExpectedNodeIDs []uint
ExpectedNodeDescs []string
NodesToCreate []Node
ExpError bool
}{
{
Name: "Exact Match",
Substring: "apple",
NodesToCreate: []Node{
{Description: db.Text{"en": "Apple"}},
{Description: db.Text{"en": "Banana"}},
{Description: db.Text{"en": "Grape"}},
},
ExpectedNodeIDs: []uint{1},
ExpectedNodeDescs: []string{"Apple"},
},
{
Name: "resist simple sql-injection",
Substring: "apple'",
NodesToCreate: []Node{
{Description: db.Text{"en": "Apple"}},
{Description: db.Text{"en": "Banana"}},
{Description: db.Text{"en": "Grape"}},
},
ExpectedNodeIDs: []uint{1},
ExpectedNodeDescs: []string{"Apple"},
},
{
Name: "Case Insensitive Match",
Substring: "banana",
NodesToCreate: []Node{
{Description: db.Text{"en": "Apple"}},
{Description: db.Text{"en": "BANANA"}},
{Description: db.Text{"en": "Grape"}},
},
ExpectedNodeIDs: []uint{2},
ExpectedNodeDescs: []string{"BANANA"},
},
{
Name: "Partial Match",
Substring: "app",
NodesToCreate: []Node{
{Description: db.Text{"en": "Apple"}},
{Description: db.Text{"en": "Application"}},
{Description: db.Text{"en": "Banana"}},
},
ExpectedNodeIDs: []uint{1, 2},
ExpectedNodeDescs: []string{"Apple", "Application"},
},
{
Name: "Fuzzy Match with Typo",
Substring: "copmuter", // Typo for "computer"
NodesToCreate: []Node{
{Description: db.Text{"en": "Computer"}},
{Description: db.Text{"en": "Computer Programming"}},
{Description: db.Text{"en": "Lol"}},
},
ExpectedNodeIDs: []uint{1, 2},
ExpectedNodeDescs: []string{"Computer", "Computer Programming"},
},
{
Name: "No Match",
Substring: "orange",
NodesToCreate: []Node{
{Description: db.Text{"en": "Apple"}},
{Description: db.Text{"en": "Banana"}},
{Description: db.Text{"en": "Grape"}},
},
ExpectedNodeIDs: []uint{},
ExpectedNodeDescs: []string{},
},
{
Name: "Multiple Matches with Order",
Substring: "berry",
NodesToCreate: []Node{
{Description: db.Text{"en": "Blueberry"}},
{Description: db.Text{"en": "Strawberry"}},
{Description: db.Text{"en": "Raspberry"}},
{Description: db.Text{"en": "Blackberry"}},
},
ExpectedNodeIDs: []uint{1, 4, 3, 2},
ExpectedNodeDescs: []string{"Blueberry", "Blackberry", "Raspberry", "Strawberry"},
},
} {
t.Run(test.Name, func(t *testing.T) {
// Set up the database
pg := setupDB(t)
ctx := middleware.TestingCtxNewWithLanguage(context.Background(), "en")
assert := assert.New(t)
for _, node := range test.NodesToCreate {
assert.NoError(pg.db.Create(&node).Error)
}
nodes, err := pg.NodeMatchFuzzy(ctx, test.Substring)
if test.ExpError {
assert.Error(err)
return
}
assert.NoError(err)
returnedNodeIDs := []uint{}
returnedNodeDescs := []string{}
for _, node := range nodes {
id, err := strconv.ParseUint(node.ID, 10, 64)
assert.NoError(err)
returnedNodeIDs = append(returnedNodeIDs, uint(id))
returnedNodeDescs = append(returnedNodeDescs, node.Description)
}
assert.Equal(test.ExpectedNodeIDs, returnedNodeIDs, "Node IDs should match")
assert.Equal(test.ExpectedNodeDescs, returnedNodeDescs, "Node descriptions should match")
})
}
}

// func TestPostgresDB_(t *testing.T) {
// for _, test := range []struct {
// Name string
Expand Down
8 changes: 8 additions & 0 deletions db/postgres/postgres_shared_testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,21 @@ func TESTONLY_SetupAndCleanup(t *testing.T) *PostgresDB {
pgdb, err := NewPostgresDB(TESTONLY_Config)
assert.NoError(err)
pg := pgdb.(*PostgresDB)
t.Cleanup(func() {
sqlDB, err := pg.db.DB()
if err == nil {
sqlDB.Close()
}
})
pg.db.Exec(`DROP TABLE IF EXISTS authentication_tokens CASCADE`)
pg.db.Exec(`DROP TABLE IF EXISTS users CASCADE`)
pg.db.Exec(`DROP TABLE IF EXISTS edge_edits CASCADE`)
pg.db.Exec(`DROP TABLE IF EXISTS edges CASCADE`)
pg.db.Exec(`DROP TABLE IF EXISTS node_edits CASCADE`)
pg.db.Exec(`DROP TABLE IF EXISTS nodes CASCADE`)
pg.db.Exec(`DROP TABLE IF EXISTS roles CASCADE`)
pg.db.Exec(`DROP INDEX IF EXISTS idx_nodes_description_text_trgm;`)
pg.db.Exec(`DROP EXTENSION IF EXISTS pg_trgm CASCADE;`)
pgdb, err = NewPostgresDB(TESTONLY_Config)
assert.NoError(err)
pg = pgdb.(*PostgresDB)
Expand Down
Loading
Loading