Skip to content
This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

Commit

Permalink
Add roomserver tests (3/4) (#2447)
Browse files Browse the repository at this point in the history
* Add Room Aliases tests

* Add Rooms table test

* Move StateKeyTuplerSorter to the types package

* Add StateBlock tests
Some optimizations

* Add State Snapshot tests
Some optimization

* Return []int64 and convert to pq.Int64Array for postgres

* Move []types.EventNID back to rows.Next()

* Update tests, rename SelectRoomIDs
  • Loading branch information
S7evinK authored May 16, 2022
1 parent 6af3538 commit 05607d6
Show file tree
Hide file tree
Showing 22 changed files with 570 additions and 313 deletions.
6 changes: 3 additions & 3 deletions roomserver/storage/postgres/events_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,11 +264,11 @@ func (s *eventStatements) BulkSelectStateEventByNID(
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) {
tuples := stateKeyTupleSorter(stateKeyTuples)
tuples := types.StateKeyTupleSorter(stateKeyTuples)
sort.Sort(tuples)
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
eventTypeNIDArray, eventStateKeyNIDArray := tuples.TypesAndStateKeysAsArrays()
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByNIDStmt)
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray)
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), pq.Int64Array(eventTypeNIDArray), pq.Int64Array(eventStateKeyNIDArray))
if err != nil {
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions roomserver/storage/postgres/room_aliases_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,12 @@ type roomAliasesStatements struct {
deleteRoomAliasStmt *sql.Stmt
}

func createRoomAliasesTable(db *sql.DB) error {
func CreateRoomAliasesTable(db *sql.DB) error {
_, err := db.Exec(roomAliasesSchema)
return err
}

func prepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
func PrepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
s := &roomAliasesStatements{}

return s, sqlutil.StatementList{
Expand Down Expand Up @@ -108,8 +108,8 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
defer internal.CloseAndLogIfError(ctx, rows, "selectAliasesFromRoomID: rows.close() failed")

var aliases []string
var alias string
for rows.Next() {
var alias string
if err = rows.Scan(&alias); err != nil {
return nil, err
}
Expand Down
16 changes: 8 additions & 8 deletions roomserver/storage/postgres/rooms_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,12 @@ type roomStatements struct {
bulkSelectRoomNIDsStmt *sql.Stmt
}

func createRoomsTable(db *sql.DB) error {
func CreateRoomsTable(db *sql.DB) error {
_, err := db.Exec(roomsSchema)
return err
}

func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
s := &roomStatements{}

return s, sqlutil.StatementList{
Expand All @@ -117,16 +117,16 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
}.Prepare(db)
}

func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) {
func (s *roomStatements) SelectRoomIDsWithEvents(ctx context.Context, txn *sql.Tx) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
rows, err := stmt.QueryContext(ctx)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed")
var roomIDs []string
var roomID string
for rows.Next() {
var roomID string
if err = rows.Scan(&roomID); err != nil {
return nil, err
}
Expand Down Expand Up @@ -231,9 +231,9 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomVersionsForRoomNIDsStmt: rows.close() failed")
result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)
var roomNID types.RoomNID
var roomVersion gomatrixserverlib.RoomVersion
for rows.Next() {
var roomNID types.RoomNID
var roomVersion gomatrixserverlib.RoomVersion
if err = rows.Scan(&roomNID, &roomVersion); err != nil {
return nil, err
}
Expand All @@ -254,8 +254,8 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roo
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed")
var roomIDs []string
var roomID string
for rows.Next() {
var roomID string
if err = rows.Scan(&roomID); err != nil {
return nil, err
}
Expand All @@ -276,8 +276,8 @@ func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, ro
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed")
var roomNIDs []types.RoomNID
var roomNID types.RoomNID
for rows.Next() {
var roomNID types.RoomNID
if err = rows.Scan(&roomNID); err != nil {
return nil, err
}
Expand Down
53 changes: 10 additions & 43 deletions roomserver/storage/postgres/state_block_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"context"
"database/sql"
"fmt"
"sort"

"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
Expand Down Expand Up @@ -71,12 +70,12 @@ type stateBlockStatements struct {
bulkSelectStateBlockEntriesStmt *sql.Stmt
}

func createStateBlockTable(db *sql.DB) error {
func CreateStateBlockTable(db *sql.DB) error {
_, err := db.Exec(stateDataSchema)
return err
}

func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
func PrepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
s := &stateBlockStatements{}

return s, sqlutil.StatementList{
Expand All @@ -90,9 +89,9 @@ func (s *stateBlockStatements) BulkInsertStateData(
entries types.StateEntries,
) (id types.StateBlockNID, err error) {
entries = entries[:util.SortAndUnique(entries)]
var nids types.EventNIDs
for _, e := range entries {
nids = append(nids, e.EventNID)
nids := make(types.EventNIDs, entries.Len())
for i := range entries {
nids[i] = entries[i].EventNID
}
stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt)
err = stmt.QueryRowContext(
Expand All @@ -113,15 +112,15 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries(

results := make([][]types.EventNID, len(stateBlockNIDs))
i := 0
var stateBlockNID types.StateBlockNID
var result pq.Int64Array
for ; rows.Next(); i++ {
var stateBlockNID types.StateBlockNID
var result pq.Int64Array
if err = rows.Scan(&stateBlockNID, &result); err != nil {
return nil, err
}
r := []types.EventNID{}
for _, e := range result {
r = append(r, types.EventNID(e))
r := make([]types.EventNID, len(result))
for x := range result {
r[x] = types.EventNID(result[x])
}
results[i] = r
}
Expand All @@ -141,35 +140,3 @@ func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array {
}
return pq.Int64Array(nids)
}

type stateKeyTupleSorter []types.StateKeyTuple

func (s stateKeyTupleSorter) Len() int { return len(s) }
func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }

// Check whether a tuple is in the list. Assumes that the list is sorted.
func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool {
i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) })
return i < len(s) && s[i] == value
}

// List the unique eventTypeNIDs and eventStateKeyNIDs.
// Assumes that the list is sorted.
func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs pq.Int64Array, eventStateKeyNIDs pq.Int64Array) {
eventTypeNIDs = make(pq.Int64Array, len(s))
eventStateKeyNIDs = make(pq.Int64Array, len(s))
for i := range s {
eventTypeNIDs[i] = int64(s[i].EventTypeNID)
eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID)
}
eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))]
eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))]
return
}

type int64Sorter []int64

func (s int64Sorter) Len() int { return len(s) }
func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] }
func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
86 changes: 0 additions & 86 deletions roomserver/storage/postgres/state_block_table_test.go

This file was deleted.

10 changes: 4 additions & 6 deletions roomserver/storage/postgres/state_snapshot_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,12 @@ type stateSnapshotStatements struct {
bulkSelectStateBlockNIDsStmt *sql.Stmt
}

func createStateSnapshotTable(db *sql.DB) error {
func CreateStateSnapshotTable(db *sql.DB) error {
_, err := db.Exec(stateSnapshotSchema)
return err
}

func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
s := &stateSnapshotStatements{}

return s, sqlutil.StatementList{
Expand All @@ -95,12 +95,10 @@ func (s *stateSnapshotStatements) InsertState(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, nids types.StateBlockNIDs,
) (stateNID types.StateSnapshotNID, err error) {
nids = nids[:util.SortAndUnique(nids)]
var id int64
err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, nids.Hash(), int64(roomNID), stateBlockNIDsAsArray(nids)).Scan(&id)
err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, nids.Hash(), int64(roomNID), stateBlockNIDsAsArray(nids)).Scan(&stateNID)
if err != nil {
return 0, err
}
stateNID = types.StateSnapshotNID(id)
return
}

Expand All @@ -119,9 +117,9 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
defer rows.Close() // nolint: errcheck
results := make([]types.StateBlockNIDList, len(stateNIDs))
i := 0
var stateBlockNIDs pq.Int64Array
for ; rows.Next(); i++ {
result := &results[i]
var stateBlockNIDs pq.Int64Array
if err = rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil {
return nil, err
}
Expand Down
16 changes: 8 additions & 8 deletions roomserver/storage/postgres/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,19 +80,19 @@ func (d *Database) create(db *sql.DB) error {
if err := CreateEventsTable(db); err != nil {
return err
}
if err := createRoomsTable(db); err != nil {
if err := CreateRoomsTable(db); err != nil {
return err
}
if err := createStateBlockTable(db); err != nil {
if err := CreateStateBlockTable(db); err != nil {
return err
}
if err := createStateSnapshotTable(db); err != nil {
if err := CreateStateSnapshotTable(db); err != nil {
return err
}
if err := CreatePrevEventsTable(db); err != nil {
return err
}
if err := createRoomAliasesTable(db); err != nil {
if err := CreateRoomAliasesTable(db); err != nil {
return err
}
if err := CreateInvitesTable(db); err != nil {
Expand Down Expand Up @@ -128,23 +128,23 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
if err != nil {
return err
}
rooms, err := prepareRoomsTable(db)
rooms, err := PrepareRoomsTable(db)
if err != nil {
return err
}
stateBlock, err := prepareStateBlockTable(db)
stateBlock, err := PrepareStateBlockTable(db)
if err != nil {
return err
}
stateSnapshot, err := prepareStateSnapshotTable(db)
stateSnapshot, err := PrepareStateSnapshotTable(db)
if err != nil {
return err
}
prevEvents, err := PreparePrevEventsTable(db)
if err != nil {
return err
}
roomAliases, err := prepareRoomAliasesTable(db)
roomAliases, err := PrepareRoomAliasesTable(db)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion roomserver/storage/shared/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -1216,7 +1216,7 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin

// GetKnownRooms returns a list of all rooms we know about.
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
return d.RoomsTable.SelectRoomIDs(ctx, nil)
return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil)
}

// ForgetRoom sets a users room to forgotten
Expand Down
4 changes: 2 additions & 2 deletions roomserver/storage/sqlite3/events_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,9 @@ func (s *eventStatements) BulkSelectStateEventByNID(
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) {
tuples := stateKeyTupleSorter(stateKeyTuples)
tuples := types.StateKeyTupleSorter(stateKeyTuples)
sort.Sort(tuples)
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
eventTypeNIDArray, eventStateKeyNIDArray := tuples.TypesAndStateKeysAsArrays()
params := make([]interface{}, 0, len(eventNIDs)+len(eventTypeNIDArray)+len(eventStateKeyNIDArray))
selectOrig := strings.Replace(bulkSelectStateEventByNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1)
for _, v := range eventNIDs {
Expand Down
Loading

0 comments on commit 05607d6

Please sign in to comment.