From a0fcefbd1a0cf519ca17c3dfdc783e995ef9f6ca Mon Sep 17 00:00:00 2001 From: Jaz Date: Sun, 4 Jun 2023 10:13:48 -0700 Subject: [PATCH] Batch persist writes and emits --- cmd/bigsky/main.go | 2 +- events/dbpersist.go | 115 +++++++++-- events/dbpersist_test.go | 304 ++++++++++++++++++++++++++++ testing/integ_test.go | 60 +++--- testing/labelmaker_fakedata_test.go | 20 +- testing/pds_fakedata_test.go | 4 +- testing/utils.go | 116 +++++------ 7 files changed, 508 insertions(+), 113 deletions(-) create mode 100644 events/dbpersist_test.go diff --git a/cmd/bigsky/main.go b/cmd/bigsky/main.go index e6fa785ba..f04dd811c 100644 --- a/cmd/bigsky/main.go +++ b/cmd/bigsky/main.go @@ -173,7 +173,7 @@ func run(args []string) { repoman := repomgr.NewRepoManager(db, cstore, kmgr) - dbp, err := events.NewDbPersistence(db, cstore) + dbp, err := events.NewDbPersistence(db, cstore, nil) if err != nil { return fmt.Errorf("setting up db event persistence: %w", err) } diff --git a/events/dbpersist.go b/events/dbpersist.go index eaf0d9999..0402a4df1 100644 --- a/events/dbpersist.go +++ b/events/dbpersist.go @@ -18,6 +18,25 @@ import ( "gorm.io/gorm" ) +type PersistenceBatchItem struct { + Record *RepoEventRecord + Event *XRPCStreamEvent +} + +type BatchOptions struct { + MaxBatchSize int + MinBatchSize int + MaxTimeBetweenFlush time.Duration +} + +func DefaultBatchOptions() *BatchOptions { + return &BatchOptions{ + MaxBatchSize: 200, + MinBatchSize: 10, + MaxTimeBetweenFlush: 500 * time.Millisecond, + } +} + type DbPersistence struct { db *gorm.DB @@ -26,6 +45,10 @@ type DbPersistence struct { lk sync.Mutex broadcast func(*XRPCStreamEvent) + + batch []*PersistenceBatchItem + batchOptions BatchOptions + lastFlush time.Time } type RepoEventRecord struct { @@ -42,21 +65,95 @@ type RepoEventRecord struct { Ops []byte } -func NewDbPersistence(db *gorm.DB, cs *carstore.CarStore) (*DbPersistence, error) { +func NewDbPersistence(db *gorm.DB, cs *carstore.CarStore, batchOptions *BatchOptions) (*DbPersistence, error) { if err := db.AutoMigrate(&RepoEventRecord{}); err != nil { return nil, err } - return &DbPersistence{ - db: db, - cs: cs, - }, nil + if batchOptions == nil { + batchOptions = DefaultBatchOptions() + } + + p := DbPersistence{ + db: db, + cs: cs, + batchOptions: *batchOptions, + batch: []*PersistenceBatchItem{}, + } + + go func() { + for { + time.Sleep(100 * time.Millisecond) + p.lk.Lock() + if len(p.batch) > 0 && + (len(p.batch) >= p.batchOptions.MinBatchSize || + time.Since(p.lastFlush) >= p.batchOptions.MaxTimeBetweenFlush) { + p.lk.Unlock() + if err := p.FlushBatch(context.Background()); err != nil { + log.Errorf("failed to flush batch: %s", err) + } + } else { + p.lk.Unlock() + } + } + }() + + return &p, nil } func (p *DbPersistence) SetEventBroadcaster(brc func(*XRPCStreamEvent)) { p.broadcast = brc } +func (p *DbPersistence) FlushBatch(ctx context.Context) error { + p.lk.Lock() + defer p.lk.Unlock() + + records := make([]*RepoEventRecord, len(p.batch)) + for i, item := range p.batch { + records[i] = item.Record + } + + if err := p.db.CreateInBatches(records, 50).Error; err != nil { + return fmt.Errorf("failed to create records: %w", err) + } + + for i, item := range records { + e := p.batch[i].Event + e.RepoCommit.Seq = int64(item.Seq) + p.broadcast(e) + } + + p.batch = []*PersistenceBatchItem{} + p.lastFlush = time.Now() + + return nil +} + +func (p *DbPersistence) AddItemToBatch(ctx context.Context, rec *RepoEventRecord, evt *XRPCStreamEvent) error { + p.lk.Lock() + if p.batch == nil { + p.batch = []*PersistenceBatchItem{} + } + + if len(p.batch) >= p.batchOptions.MaxBatchSize { + p.lk.Unlock() + if err := p.FlushBatch(ctx); err != nil { + return fmt.Errorf("failed to flush batch at max size: %w", err) + } + p.lk.Lock() + } + + p.batch = append(p.batch, &PersistenceBatchItem{ + Record: rec, + Event: evt, + }) + + p.lk.Unlock() + + return nil +} + func (p *DbPersistence) Persist(ctx context.Context, e *XRPCStreamEvent) error { if e.RepoCommit == nil { return nil @@ -110,16 +207,10 @@ func (p *DbPersistence) Persist(ctx context.Context, e *XRPCStreamEvent) error { } rer.Ops = opsb - p.lk.Lock() - defer p.lk.Unlock() - if err := p.db.Create(&rer).Error; err != nil { + if err := p.AddItemToBatch(ctx, &rer, e); err != nil { return err } - e.RepoCommit.Seq = int64(rer.Seq) - - p.broadcast(e) - return nil } diff --git a/events/dbpersist_test.go b/events/dbpersist_test.go new file mode 100644 index 000000000..4a45a6430 --- /dev/null +++ b/events/dbpersist_test.go @@ -0,0 +1,304 @@ +package events_test + +import ( + "bytes" + "context" + "fmt" + "math/rand" + "os" + "path/filepath" + "sync" + "testing" + "time" + + atproto "github.com/bluesky-social/indigo/api/atproto" + "github.com/bluesky-social/indigo/api/bsky" + "github.com/bluesky-social/indigo/carstore" + "github.com/bluesky-social/indigo/events" + lexutil "github.com/bluesky-social/indigo/lex/util" + "github.com/bluesky-social/indigo/models" + "github.com/bluesky-social/indigo/pds" + "github.com/bluesky-social/indigo/repo" + "github.com/bluesky-social/indigo/repomgr" + intTesting "github.com/bluesky-social/indigo/testing" + "github.com/bluesky-social/indigo/util" + "github.com/ipfs/go-cid" + "github.com/ipfs/go-log/v2" + car "github.com/ipld/go-car" + "github.com/stretchr/testify/assert" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +func init() { + log.SetAllLoggers(log.LevelDebug) +} + +func BenchmarkDBPersist(b *testing.B) { + ctx := context.Background() + + db, _, cs, tempPath, err := setupDBs(b) + if err != nil { + b.Fatal(err) + } + + db.AutoMigrate(&pds.User{}) + db.AutoMigrate(&pds.Peering{}) + db.AutoMigrate(&models.ActorInfo{}) + + db.Create(&models.ActorInfo{ + Uid: 1, + Did: "did:example:123", + }) + + mgr := repomgr.NewRepoManager(db, cs, &util.FakeKeyManager{}) + + err = mgr.InitNewActor(ctx, 1, "alice", "did:example:123", "Alice", "", "") + if err != nil { + b.Fatal(err) + } + + _, cid, err := mgr.CreateRecord(ctx, 1, "app.bsky.feed.post", &bsky.FeedPost{ + Text: "hello world", + CreatedAt: time.Now().Format(util.ISO8601), + }) + if err != nil { + b.Fatal(err) + } + + defer os.RemoveAll(tempPath) + + // Initialize a DBPersister + dbp, err := events.NewDbPersistence(db, cs, nil) + if err != nil { + b.Fatal(err) + } + + // Create a bunch of events + evtman := events.NewEventManager(dbp) + + userRepoHead, err := mgr.GetRepoRoot(ctx, 1) + if err != nil { + b.Fatal(err) + } + + inEvts := make([]*events.XRPCStreamEvent, b.N) + for i := 0; i < b.N; i++ { + cidLink := lexutil.LexLink(cid) + headLink := lexutil.LexLink(userRepoHead) + inEvts[i] = &events.XRPCStreamEvent{ + RepoCommit: &atproto.SyncSubscribeRepos_Commit{ + Repo: "did:example:123", + Commit: headLink, + Ops: []*atproto.SyncSubscribeRepos_RepoOp{ + { + Action: "add", + Cid: &cidLink, + Path: "path1", + }, + }, + Time: time.Now().Format(util.ISO8601), + }, + } + } + + numRoutines := 5 + wg := sync.WaitGroup{} + + b.ResetTimer() + + errChan := make(chan error, numRoutines) + + // Add events in parallel + for i := 0; i < numRoutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < b.N; i++ { + err = evtman.AddEvent(ctx, inEvts[i]) + if err != nil { + errChan <- err + } + } + }() + } + + wg.Wait() + close(errChan) + + // Check for errors + for err := range errChan { + if err != nil { + b.Fatal(err) + } + } + + outEvtCount := 0 + expectedEvtCount := b.N * numRoutines + + // Flush manually + err = dbp.FlushBatch(ctx) + if err != nil { + b.Fatal(err) + } + + b.StopTimer() + + dbp.Playback(ctx, 0, func(evt *events.XRPCStreamEvent) error { + outEvtCount++ + return nil + }) + + if outEvtCount != expectedEvtCount { + b.Fatalf("expected %d events, got %d", expectedEvtCount, outEvtCount) + } +} + +func TestDBPersist(t *testing.T) { + if testing.Short() { + t.Skip("skipping BGS test in 'short' test mode") + } + assert := assert.New(t) + didr := intTesting.TestPLC(t) + p1 := intTesting.MustSetupPDS(t, "localhost:5155", ".tpds", didr) + p1.Run(t) + + b1 := intTesting.MustSetupBGS(t, "localhost:8231", didr) + b1.Run(t) + + p1.RequestScraping(t, b1) + + time.Sleep(time.Millisecond * 50) + + evts := b1.Events(t, -1) + defer evts.Cancel() + + bob := p1.MustNewUser(t, "bob.tpds") + alice := p1.MustNewUser(t, "alice.tpds") + + bp1 := bob.Post(t, "cats for cats") + ap1 := alice.Post(t, "no i like dogs") + + _ = bp1 + _ = ap1 + + fmt.Println("bob:", bob.DID()) + fmt.Println("alice:", alice.DID()) + + fmt.Println("event 1") + e1 := evts.Next() + assert.NotNil(e1.RepoCommit) + assert.Equal(e1.RepoCommit.Repo, bob.DID()) + + fmt.Println("event 2") + e2 := evts.Next() + assert.NotNil(e2.RepoCommit) + assert.Equal(e2.RepoCommit.Repo, alice.DID()) + + fmt.Println("event 3") + e3 := evts.Next() + assert.Equal(e3.RepoCommit.Repo, bob.DID()) + //assert.Equal(e3.RepoCommit.Ops[0].Kind, "createRecord") + + fmt.Println("event 4") + e4 := evts.Next() + assert.Equal(e4.RepoCommit.Repo, alice.DID()) + //assert.Equal(e4.RepoCommit.Ops[0].Kind, "createRecord") + + // playback + pbevts := b1.Events(t, 2) + defer pbevts.Cancel() + + fmt.Println("event 5") + pbe1 := pbevts.Next() + assert.Equal(*e3, *pbe1) +} + +func setupDBs(t testing.TB) (*gorm.DB, *gorm.DB, *carstore.CarStore, string, error) { + dir, err := os.MkdirTemp("", "integtest") + if err != nil { + return nil, nil, nil, "", err + } + + maindb, err := gorm.Open(sqlite.Open(filepath.Join(dir, "test.sqlite?cache=shared&mode=rwc"))) + if err != nil { + return nil, nil, nil, "", err + } + + tx := maindb.Exec("PRAGMA journal_mode=WAL;") + if tx.Error != nil { + return nil, nil, nil, "", tx.Error + } + + tx.Commit() + + cardb, err := gorm.Open(sqlite.Open(filepath.Join(dir, "car.sqlite"))) + if err != nil { + return nil, nil, nil, "", err + } + + cspath := filepath.Join(dir, "carstore") + if err := os.Mkdir(cspath, 0775); err != nil { + return nil, nil, nil, "", err + } + + cs, err := carstore.NewCarStore(cardb, cspath) + if err != nil { + return nil, nil, nil, "", err + } + + return maindb, cardb, cs, "", nil +} + +func randomFollows(t *testing.T, users []*intTesting.TestUser) { + for n := 0; n < 3; n++ { + for i, u := range users { + oi := rand.Intn(len(users)) + if i == oi { + continue + } + + u.Follow(t, users[oi].DID()) + } + } +} + +func socialSim(t *testing.T, users []*intTesting.TestUser, postiter, likeiter int) []*atproto.RepoStrongRef { + var posts []*atproto.RepoStrongRef + for i := 0; i < postiter; i++ { + for _, u := range users { + posts = append(posts, u.Post(t, intTesting.MakeRandomPost())) + } + } + + for i := 0; i < likeiter; i++ { + for _, u := range users { + u.Like(t, posts[rand.Intn(len(posts))]) + } + } + + return posts +} + +func commitFromSlice(t *testing.T, slice []byte, rcid cid.Cid) *repo.SignedCommit { + carr, err := car.NewCarReader(bytes.NewReader(slice)) + if err != nil { + t.Fatal(err) + } + + for { + blk, err := carr.Next() + if err != nil { + t.Fatal(err) + } + + if blk.Cid() == rcid { + + var sc repo.SignedCommit + if err := sc.UnmarshalCBOR(bytes.NewReader(blk.RawData())); err != nil { + t.Fatal(err) + } + return &sc + } + } +} diff --git a/testing/integ_test.go b/testing/integ_test.go index e022bf33d..e829d175e 100644 --- a/testing/integ_test.go +++ b/testing/integ_test.go @@ -25,11 +25,11 @@ func TestBGSBasic(t *testing.T) { t.Skip("skipping BGS test in 'short' test mode") } assert := assert.New(t) - didr := testPLC(t) - p1 := mustSetupPDS(t, "localhost:5155", ".tpds", didr) + didr := TestPLC(t) + p1 := MustSetupPDS(t, "localhost:5155", ".tpds", didr) p1.Run(t) - b1 := mustSetupBGS(t, "localhost:8231", didr) + b1 := MustSetupBGS(t, "localhost:8231", didr) b1.Run(t) p1.RequestScraping(t, b1) @@ -37,7 +37,7 @@ func TestBGSBasic(t *testing.T) { time.Sleep(time.Millisecond * 50) evts := b1.Events(t, -1) - defer evts.cancel() + defer evts.Cancel() bob := p1.MustNewUser(t, "bob.tpds") alice := p1.MustNewUser(t, "alice.tpds") @@ -73,14 +73,14 @@ func TestBGSBasic(t *testing.T) { // playback pbevts := b1.Events(t, 2) - defer pbevts.cancel() + defer pbevts.Cancel() fmt.Println("event 5") pbe1 := pbevts.Next() assert.Equal(*e3, *pbe1) } -func randomFollows(t *testing.T, users []*testUser) { +func randomFollows(t *testing.T, users []*TestUser) { for n := 0; n < 3; n++ { for i, u := range users { oi := rand.Intn(len(users)) @@ -93,11 +93,11 @@ func randomFollows(t *testing.T, users []*testUser) { } } -func socialSim(t *testing.T, users []*testUser, postiter, likeiter int) []*atproto.RepoStrongRef { +func socialSim(t *testing.T, users []*TestUser, postiter, likeiter int) []*atproto.RepoStrongRef { var posts []*atproto.RepoStrongRef for i := 0; i < postiter; i++ { for _, u := range users { - posts = append(posts, u.Post(t, makeRandomPost())) + posts = append(posts, u.Post(t, MakeRandomPost())) } } @@ -118,20 +118,20 @@ func TestBGSMultiPDS(t *testing.T) { assert := assert.New(t) _ = assert - didr := testPLC(t) - p1 := mustSetupPDS(t, "localhost:5185", ".pdsuno", didr) + didr := TestPLC(t) + p1 := MustSetupPDS(t, "localhost:5185", ".pdsuno", didr) p1.Run(t) - p2 := mustSetupPDS(t, "localhost:5186", ".pdsdos", didr) + p2 := MustSetupPDS(t, "localhost:5186", ".pdsdos", didr) p2.Run(t) - b1 := mustSetupBGS(t, "localhost:8281", didr) + b1 := MustSetupBGS(t, "localhost:8281", didr) b1.Run(t) p1.RequestScraping(t, b1) time.Sleep(time.Millisecond * 100) - var users []*testUser + var users []*TestUser for i := 0; i < 5; i++ { users = append(users, p1.MustNewUser(t, usernames[i]+".pdsuno")) } @@ -139,7 +139,7 @@ func TestBGSMultiPDS(t *testing.T) { randomFollows(t, users) socialSim(t, users, 10, 10) - var users2 []*testUser + var users2 []*TestUser for i := 0; i < 5; i++ { users2 = append(users2, p2.MustNewUser(t, usernames[i+5]+".pdsdos")) } @@ -182,24 +182,24 @@ func TestBGSMultiGap(t *testing.T) { //t.Skip("test too sleepy to run in CI for now") assert := assert.New(t) _ = assert - didr := testPLC(t) - p1 := mustSetupPDS(t, "localhost:5195", ".pdsuno", didr) + didr := TestPLC(t) + p1 := MustSetupPDS(t, "localhost:5195", ".pdsuno", didr) p1.Run(t) - p2 := mustSetupPDS(t, "localhost:5196", ".pdsdos", didr) + p2 := MustSetupPDS(t, "localhost:5196", ".pdsdos", didr) p2.Run(t) - b1 := mustSetupBGS(t, "localhost:8291", didr) + b1 := MustSetupBGS(t, "localhost:8291", didr) b1.Run(t) p1.RequestScraping(t, b1) time.Sleep(time.Millisecond * 50) - users := []*testUser{p1.MustNewUser(t, usernames[0]+".pdsuno")} + users := []*TestUser{p1.MustNewUser(t, usernames[0]+".pdsuno")} socialSim(t, users, 10, 0) - users2 := []*testUser{p2.MustNewUser(t, usernames[1]+".pdsdos")} + users2 := []*TestUser{p2.MustNewUser(t, usernames[1]+".pdsdos")} p2posts := socialSim(t, users2, 10, 0) @@ -239,11 +239,11 @@ func TestHandleChange(t *testing.T) { //t.Skip("test too sleepy to run in CI for now") assert := assert.New(t) _ = assert - didr := testPLC(t) - p1 := mustSetupPDS(t, "localhost:5385", ".pdsuno", didr) + didr := TestPLC(t) + p1 := MustSetupPDS(t, "localhost:5385", ".pdsuno", didr) p1.Run(t) - b1 := mustSetupBGS(t, "localhost:8391", didr) + b1 := MustSetupBGS(t, "localhost:8391", didr) b1.Run(t) p1.RequestScraping(t, b1) @@ -272,11 +272,11 @@ func TestBGSTakedown(t *testing.T) { assert := assert.New(t) _ = assert - didr := testPLC(t) - p1 := mustSetupPDS(t, "localhost:5151", ".tpds", didr) + didr := TestPLC(t) + p1 := MustSetupPDS(t, "localhost:5151", ".tpds", didr) p1.Run(t) - b1 := mustSetupBGS(t, "localhost:3231", didr) + b1 := MustSetupBGS(t, "localhost:3231", didr) b1.Run(t) p1.RequestScraping(t, b1) @@ -323,11 +323,11 @@ func TestRebase(t *testing.T) { t.Skip("skipping BGS test in 'short' test mode") } assert := assert.New(t) - didr := testPLC(t) - p1 := mustSetupPDS(t, "localhost:9155", ".tpds", didr) + didr := TestPLC(t) + p1 := MustSetupPDS(t, "localhost:9155", ".tpds", didr) p1.Run(t) - b1 := mustSetupBGS(t, "localhost:1531", didr) + b1 := MustSetupBGS(t, "localhost:1531", didr) b1.Run(t) p1.RequestScraping(t, b1) @@ -344,7 +344,7 @@ func TestRebase(t *testing.T) { time.Sleep(time.Millisecond * 100) evts1 := b1.Events(t, 0) - defer evts1.cancel() + defer evts1.Cancel() preRebaseEvts := evts1.WaitFor(5) fmt.Println(preRebaseEvts) diff --git a/testing/labelmaker_fakedata_test.go b/testing/labelmaker_fakedata_test.go index 1a26fb1c2..72bfaf83b 100644 --- a/testing/labelmaker_fakedata_test.go +++ b/testing/labelmaker_fakedata_test.go @@ -69,7 +69,7 @@ func testLabelMaker(t *testing.T) *labeler.Server { return lm } -func labelEvents(t *testing.T, lm *labeler.Server, since int64) *eventStream { +func labelEvents(t *testing.T, lm *labeler.Server, since int64) *EventStream { d := websocket.Dialer{} h := http.Header{} bgsHost := "localhost:1234" @@ -90,8 +90,8 @@ func labelEvents(t *testing.T, lm *labeler.Server, since int64) *eventStream { ctx, cancel := context.WithCancel(context.Background()) - es := &eventStream{ - cancel: cancel, + es := &EventStream{ + Cancel: cancel, } go func() { @@ -103,9 +103,9 @@ func labelEvents(t *testing.T, lm *labeler.Server, since int64) *eventStream { rsc := &events.RepoStreamCallbacks{ LabelLabels: func(evt *label.SubscribeLabels_Labels) error { fmt.Println("received event: ", evt.Seq) - es.lk.Lock() - es.events = append(es.events, &events.XRPCStreamEvent{LabelLabels: evt}) - es.lk.Unlock() + es.Lk.Lock() + es.Events = append(es.Events, &events.XRPCStreamEvent{LabelLabels: evt}) + es.Lk.Unlock() return nil }, } @@ -128,11 +128,11 @@ func TestLabelmakerBasic(t *testing.T) { assert := assert.New(t) _ = assert ctx := context.TODO() - didr := testPLC(t) - p1 := mustSetupPDS(t, "localhost:5115", ".tpds", didr) + didr := TestPLC(t) + p1 := MustSetupPDS(t, "localhost:5115", ".tpds", didr) p1.Run(t) - b1 := mustSetupBGS(t, "localhost:8322", didr) + b1 := MustSetupBGS(t, "localhost:8322", didr) b1.Run(t) p1.RequestScraping(t, b1) @@ -145,7 +145,7 @@ func TestLabelmakerBasic(t *testing.T) { time.Sleep(time.Millisecond * 50) evts := b1.Events(t, -1) - defer evts.cancel() + defer evts.Cancel() bob := p1.MustNewUser(t, "bob.tpds") alice := p1.MustNewUser(t, "alice.tpds") diff --git a/testing/pds_fakedata_test.go b/testing/pds_fakedata_test.go index 31c30ba7f..44ce2eb2b 100644 --- a/testing/pds_fakedata_test.go +++ b/testing/pds_fakedata_test.go @@ -59,8 +59,8 @@ func TestPDSFakedata(t *testing.T) { t.Skip("skipping PDS+fakedata test in 'short' test mode") } assert := assert.New(t) - plcc := testPLC(t) - pds := mustSetupPDS(t, "localhost:5159", ".test", plcc) + plcc := TestPLC(t) + pds := MustSetupPDS(t, "localhost:5159", ".test", plcc) pds.Run(t) time.Sleep(time.Millisecond * 50) diff --git a/testing/utils.go b/testing/utils.go index 9080dc750..c89f6ef2d 100644 --- a/testing/utils.go +++ b/testing/utils.go @@ -40,7 +40,7 @@ import ( "gorm.io/gorm" ) -type testPDS struct { +type TestPDS struct { dir string server *pds.Server plc *api.PLCServer @@ -50,7 +50,7 @@ type testPDS struct { shutdown func() } -func (tp *testPDS) Cleanup() { +func (tp *TestPDS) Cleanup() { if tp.shutdown != nil { tp.shutdown() } @@ -60,7 +60,7 @@ func (tp *testPDS) Cleanup() { } } -func mustSetupPDS(t *testing.T, host, suffix string, plc plc.PLCClient) *testPDS { +func MustSetupPDS(t *testing.T, host, suffix string, plc plc.PLCClient) *TestPDS { t.Helper() tpds, err := SetupPDS(host, suffix, plc) @@ -71,7 +71,7 @@ func mustSetupPDS(t *testing.T, host, suffix string, plc plc.PLCClient) *testPDS return tpds } -func SetupPDS(host, suffix string, plc plc.PLCClient) (*testPDS, error) { +func SetupPDS(host, suffix string, plc plc.PLCClient) (*TestPDS, error) { dir, err := os.MkdirTemp("", "integtest") if err != nil { return nil, err @@ -111,14 +111,14 @@ func SetupPDS(host, suffix string, plc plc.PLCClient) (*testPDS, error) { return nil, err } - return &testPDS{ + return &TestPDS{ dir: dir, server: srv, host: host, }, nil } -func (tp *testPDS) Run(t *testing.T) { +func (tp *TestPDS) Run(t *testing.T) { // TODO: rig this up so it t.Fatals if the RunAPI call fails immediately go func() { if err := tp.server.RunAPI(tp.host); err != nil { @@ -132,7 +132,7 @@ func (tp *testPDS) Run(t *testing.T) { } } -func (tp *testPDS) RequestScraping(t *testing.T, b *testBGS) { +func (tp *TestPDS) RequestScraping(t *testing.T, b *TestBGS) { t.Helper() c := &xrpc.Client{Host: "http://" + b.host} @@ -141,15 +141,15 @@ func (tp *testPDS) RequestScraping(t *testing.T, b *testBGS) { } } -type testUser struct { +type TestUser struct { handle string - pds *testPDS + pds *TestPDS did string client *xrpc.Client } -func (tp *testPDS) MustNewUser(t *testing.T, handle string) *testUser { +func (tp *TestPDS) MustNewUser(t *testing.T, handle string) *TestUser { t.Helper() u, err := tp.NewUser(handle) @@ -160,7 +160,7 @@ func (tp *testPDS) MustNewUser(t *testing.T, handle string) *testUser { return u } -func (tp *testPDS) NewUser(handle string) (*testUser, error) { +func (tp *TestPDS) NewUser(handle string) (*TestUser, error) { ctx := context.TODO() c := &xrpc.Client{ @@ -184,7 +184,7 @@ func (tp *testPDS) NewUser(handle string) (*testUser, error) { Did: out.Did, } - return &testUser{ + return &TestUser{ pds: tp, handle: out.Handle, client: c, @@ -192,7 +192,7 @@ func (tp *testPDS) NewUser(handle string) (*testUser, error) { }, nil } -func (u *testUser) Reply(t *testing.T, replyto, root *atproto.RepoStrongRef, body string) string { +func (u *TestUser) Reply(t *testing.T, replyto, root *atproto.RepoStrongRef, body string) string { t.Helper() ctx := context.TODO() @@ -215,11 +215,11 @@ func (u *testUser) Reply(t *testing.T, replyto, root *atproto.RepoStrongRef, bod return resp.Uri } -func (u *testUser) DID() string { +func (u *TestUser) DID() string { return u.did } -func (u *testUser) Post(t *testing.T, body string) *atproto.RepoStrongRef { +func (u *TestUser) Post(t *testing.T, body string) *atproto.RepoStrongRef { t.Helper() ctx := context.TODO() @@ -242,7 +242,7 @@ func (u *testUser) Post(t *testing.T, body string) *atproto.RepoStrongRef { } } -func (u *testUser) Like(t *testing.T, post *atproto.RepoStrongRef) { +func (u *TestUser) Like(t *testing.T, post *atproto.RepoStrongRef) { t.Helper() ctx := context.TODO() @@ -261,7 +261,7 @@ func (u *testUser) Like(t *testing.T, post *atproto.RepoStrongRef) { } -func (u *testUser) Follow(t *testing.T, did string) string { +func (u *TestUser) Follow(t *testing.T, did string) string { t.Helper() ctx := context.TODO() @@ -281,7 +281,7 @@ func (u *testUser) Follow(t *testing.T, did string) string { return resp.Uri } -func (u *testUser) GetFeed(t *testing.T) []*bsky.FeedDefs_FeedViewPost { +func (u *TestUser) GetFeed(t *testing.T) []*bsky.FeedDefs_FeedViewPost { t.Helper() ctx := context.TODO() @@ -293,7 +293,7 @@ func (u *testUser) GetFeed(t *testing.T) []*bsky.FeedDefs_FeedViewPost { return resp.Feed } -func (u *testUser) GetNotifs(t *testing.T) []*bsky.NotificationListNotifications_Notification { +func (u *TestUser) GetNotifs(t *testing.T) []*bsky.NotificationListNotifications_Notification { t.Helper() ctx := context.TODO() @@ -305,7 +305,7 @@ func (u *testUser) GetNotifs(t *testing.T) []*bsky.NotificationListNotifications return resp.Notifications } -func (u *testUser) ChangeHandle(t *testing.T, nhandle string) { +func (u *TestUser) ChangeHandle(t *testing.T, nhandle string) { t.Helper() ctx := context.TODO() @@ -316,7 +316,7 @@ func (u *testUser) ChangeHandle(t *testing.T, nhandle string) { } } -func (u *testUser) DoRebase(t *testing.T) { +func (u *TestUser) DoRebase(t *testing.T) { t.Helper() ctx := context.TODO() @@ -328,7 +328,7 @@ func (u *testUser) DoRebase(t *testing.T) { } } -func testPLC(t *testing.T) *plc.FakeDid { +func TestPLC(t *testing.T) *plc.FakeDid { // TODO: just do in memory... tdir, err := os.MkdirTemp("", "plcserv") if err != nil { @@ -342,12 +342,12 @@ func testPLC(t *testing.T) *plc.FakeDid { return plc.NewFakeDid(db) } -type testBGS struct { +type TestBGS struct { bgs *bgs.BGS host string } -func mustSetupBGS(t *testing.T, host string, didr plc.PLCClient) *testBGS { +func MustSetupBGS(t *testing.T, host string, didr plc.PLCClient) *TestBGS { tbgs, err := SetupBGS(host, didr) if err != nil { t.Fatal(err) @@ -356,7 +356,7 @@ func mustSetupBGS(t *testing.T, host string, didr plc.PLCClient) *testBGS { return tbgs } -func SetupBGS(host string, didr plc.PLCClient) (*testBGS, error) { +func SetupBGS(host string, didr plc.PLCClient) (*TestBGS, error) { dir, err := os.MkdirTemp("", "integtest") if err != nil { return nil, err @@ -389,7 +389,7 @@ func SetupBGS(host string, didr plc.PLCClient) (*testBGS, error) { notifman := notifs.NewNotificationManager(maindb, repoman.GetRecord) - dbpersist, err := events.NewDbPersistence(maindb, cs) + dbpersist, err := events.NewDbPersistence(maindb, cs, nil) if err != nil { return nil, err } @@ -412,13 +412,13 @@ func SetupBGS(host string, didr plc.PLCClient) (*testBGS, error) { return nil, err } - return &testBGS{ + return &TestBGS{ bgs: b, host: host, }, nil } -func (b *testBGS) Run(t *testing.T) { +func (b *TestBGS) Run(t *testing.T) { go func() { if err := b.bgs.Start(b.host); err != nil { fmt.Println(err) @@ -427,15 +427,15 @@ func (b *testBGS) Run(t *testing.T) { time.Sleep(time.Millisecond * 10) } -type eventStream struct { - lk sync.Mutex - events []*events.XRPCStreamEvent - cancel func() +type EventStream struct { + Lk sync.Mutex + Events []*events.XRPCStreamEvent + Cancel func() - cur int + Cur int } -func (b *testBGS) Events(t *testing.T, since int64) *eventStream { +func (b *TestBGS) Events(t *testing.T, since int64) *EventStream { d := websocket.Dialer{} h := http.Header{} @@ -455,8 +455,8 @@ func (b *testBGS) Events(t *testing.T, since int64) *eventStream { ctx, cancel := context.WithCancel(context.Background()) - es := &eventStream{ - cancel: cancel, + es := &EventStream{ + Cancel: cancel, } go func() { @@ -468,16 +468,16 @@ func (b *testBGS) Events(t *testing.T, since int64) *eventStream { rsc := &events.RepoStreamCallbacks{ RepoCommit: func(evt *atproto.SyncSubscribeRepos_Commit) error { fmt.Println("received event: ", evt.Seq, evt.Repo) - es.lk.Lock() - es.events = append(es.events, &events.XRPCStreamEvent{RepoCommit: evt}) - es.lk.Unlock() + es.Lk.Lock() + es.Events = append(es.Events, &events.XRPCStreamEvent{RepoCommit: evt}) + es.Lk.Unlock() return nil }, RepoHandle: func(evt *atproto.SyncSubscribeRepos_Handle) error { fmt.Println("received handle event: ", evt.Seq, evt.Did) - es.lk.Lock() - es.events = append(es.events, &events.XRPCStreamEvent{RepoHandle: evt}) - es.lk.Unlock() + es.Lk.Lock() + es.Events = append(es.Events, &events.XRPCStreamEvent{RepoHandle: evt}) + es.Lk.Unlock() return nil }, } @@ -489,31 +489,31 @@ func (b *testBGS) Events(t *testing.T, since int64) *eventStream { return es } -func (es *eventStream) Next() *events.XRPCStreamEvent { - defer es.lk.Unlock() +func (es *EventStream) Next() *events.XRPCStreamEvent { + defer es.Lk.Unlock() for { - es.lk.Lock() - if len(es.events) > es.cur { - es.cur++ - return es.events[es.cur-1] + es.Lk.Lock() + if len(es.Events) > es.Cur { + es.Cur++ + return es.Events[es.Cur-1] } - es.lk.Unlock() + es.Lk.Unlock() time.Sleep(time.Millisecond * 10) } } -func (es *eventStream) All() []*events.XRPCStreamEvent { - es.lk.Lock() - defer es.lk.Unlock() - out := make([]*events.XRPCStreamEvent, len(es.events)) - for i, e := range es.events { +func (es *EventStream) All() []*events.XRPCStreamEvent { + es.Lk.Lock() + defer es.Lk.Unlock() + out := make([]*events.XRPCStreamEvent, len(es.Events)) + for i, e := range es.Events { out[i] = e } return out } -func (es *eventStream) WaitFor(n int) []*events.XRPCStreamEvent { +func (es *EventStream) WaitFor(n int) []*events.XRPCStreamEvent { var out []*events.XRPCStreamEvent for i := 0; i < n; i++ { out = append(out, es.Next()) @@ -588,7 +588,7 @@ var words = []string{ "parrot", } -func makeRandomPost() string { +func MakeRandomPost() string { var out []string for i := 0; i < 20; i++ { out = append(out, words[mathrand.Intn(len(words))]) @@ -673,7 +673,7 @@ func RandFakeAtUri(collection, rkey string) string { return fmt.Sprintf("at://did:plc:%s/%s/%s", did, collection, rkey) } -func randAction() string { +func RandAction() string { v := mathrand.Intn(100) if v < 40 { return "post" @@ -696,7 +696,7 @@ func GenerateFakeRepo(r *repo.Repo, size int) (cid.Cid, error) { var root cid.Cid for i := 0; i < size; i++ { - switch randAction() { + switch RandAction() { case "post": _, _, err := r.CreateRecord(ctx, "app.bsky.feed.post", &bsky.FeedPost{ CreatedAt: time.Now().Format(bsutil.ISO8601),