Skip to content

Commit

Permalink
Init on creation, better transaction creation
Browse files Browse the repository at this point in the history
  • Loading branch information
mhutchinson committed Mar 23, 2023
1 parent 8b3fbc5 commit 10da4f9
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 46 deletions.
45 changes: 24 additions & 21 deletions distributor/cmd/internal/distributor/distributor.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,15 @@ type LogInfo struct {
// NewDistributor returns a distributor that will accept checkpoints from
// the given witnesses, for the given logs, and persist its state in the
// database provided. Callers must call Init() on the returned distributor.
func NewDistributor(ws map[string]note.Verifier, ls map[string]LogInfo, db *sql.DB) *Distributor {
return &Distributor{
// `ws` is a map from witness ID (verifier key name) to the note verifier.
// `ls` is a map from log ID (github.com/transparency-dev/formats/log.ID) to log info.
func NewDistributor(ws map[string]note.Verifier, ls map[string]LogInfo, db *sql.DB) (*Distributor, error) {
d := &Distributor{
ws: ws,
ls: ls,
db: db,
}
return d, d.init()
}

// Distributor persists witnessed checkpoints and allows querying of them.
Expand All @@ -57,22 +60,6 @@ type Distributor struct {
db *sql.DB
}

// Init ensures that the database is in good order. This must be called before
// any other method on this object. It is safe to call on subsequent runs of
// the application as it is idempotent.
func (d *Distributor) Init() error {
if _, err := d.db.Exec(`CREATE TABLE IF NOT EXISTS chkpts (
logID BLOB,
witID BLOB,
treeSize INTEGER,
chkpt BLOB,
PRIMARY KEY (logID, witID)
)`); err != nil {
return err
}
return nil
}

// GetLogs returns a list of all log IDs the distributor is aware of, sorted
// by the ID.
func (d *Distributor) GetLogs(ctx context.Context) ([]string, error) {
Expand All @@ -90,7 +77,7 @@ func (d *Distributor) GetCheckpointN(ctx context.Context, logID string, n uint32
if !ok {
return nil, fmt.Errorf("unknown log ID %q", logID)
}
tx, err := d.db.BeginTx(ctx, nil)
tx, err := d.db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
if err != nil {
return nil, fmt.Errorf("failed to begin transaction: %v", err)
}
Expand Down Expand Up @@ -135,7 +122,7 @@ func (d *Distributor) GetCheckpointN(ctx context.Context, logID string, n uint32

// GetCheckpointWitness gets the largest checkpoint for the log that was witnessed by the given witness.
func (d *Distributor) GetCheckpointWitness(ctx context.Context, logID, witID string) ([]byte, error) {
tx, err := d.db.BeginTx(ctx, nil)
tx, err := d.db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
if err != nil {
return nil, fmt.Errorf("failed to begin transaction: %v", err)
}
Expand Down Expand Up @@ -165,7 +152,7 @@ func (d *Distributor) Distribute(ctx context.Context, logID, witID string, nextR
// This is a valid checkpoint for this log for this witness
// Now find the previous checkpoint if one exists.

tx, err := d.db.BeginTx(ctx, nil)
tx, err := d.db.BeginTx(ctx, &sql.TxOptions{ReadOnly: false})
if err != nil {
return fmt.Errorf("failed to begin transaction: %v", err)
}
Expand Down Expand Up @@ -210,6 +197,22 @@ func (d *Distributor) Distribute(ctx context.Context, logID, witID string, nextR
return saveCheckpointFn()
}

// init ensures that the database is in good order. This must be called before
// any other method on this object. It is safe to call on subsequent runs of
// the application as it is idempotent.
func (d *Distributor) init() error {
if _, err := d.db.Exec(`CREATE TABLE IF NOT EXISTS chkpts (
logID BLOB,
witID BLOB,
treeSize INTEGER,
chkpt BLOB,
PRIMARY KEY (logID, witID)
)`); err != nil {
return err
}
return nil
}

// getLatestCheckpoint returns the latest checkpoint for the given log and witness pair.
// If no checkpoint is found then an error with status `codes.NotFound` will be returned,
// which allows callers to handle this case separately if needed.
Expand Down
60 changes: 35 additions & 25 deletions distributor/cmd/internal/distributor/distributor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,17 @@ func TestGetLogs(t *testing.T) {
}
for _, tC := range testCases {
t.Run(tC.desc, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sqlitedb, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("failed to open temporary in-memory DB: %v", err)
}
d := distributor.NewDistributor(ws, tC.logs, sqlitedb)
if err := d.Init(); err != nil {
t.Fatalf("Init(): %v", err)
d, err := distributor.NewDistributor(ws, tC.logs, sqlitedb)
if err != nil {
t.Fatalf("NewDistributor(): %v", err)
}
got, err := d.GetLogs(context.Background())
got, err := d.GetLogs(ctx)
if err != nil {
t.Errorf("GetLogs(): %v", err)
}
Expand Down Expand Up @@ -181,17 +183,19 @@ func TestDistributeLogAndWitnessMustMatchCheckpoint(t *testing.T) {
}
for _, tC := range testCases {
t.Run(tC.desc, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sqlitedb, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("failed to open temporary in-memory DB: %v", err)
}
d := distributor.NewDistributor(ws, ls, sqlitedb)
if err := d.Init(); err != nil {
t.Fatalf("Init(): %v", err)
d, err := distributor.NewDistributor(ws, ls, sqlitedb)
if err != nil {
t.Fatalf("NewDistributor(): %v", err)
}

logCP16 := tC.log.checkpoint(16, "16", tC.wit.signer)
err = d.Distribute(context.Background(), tC.reqLogID, tC.reqWitID, logCP16)
err = d.Distribute(ctx, tC.reqLogID, tC.reqWitID, logCP16)
if (err != nil) != tC.wantErr {
t.Errorf("unexpected error output (wantErr: %t): %v", tC.wantErr, err)
}
Expand Down Expand Up @@ -285,20 +289,22 @@ func TestDistributeEvolution(t *testing.T) {
}
for _, tC := range testCases {
t.Run(tC.desc, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sqlitedb, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("failed to open temporary in-memory DB: %v", err)
}
d := distributor.NewDistributor(ws, ls, sqlitedb)
if err := d.Init(); err != nil {
t.Fatalf("Init(): %v", err)
d, err := distributor.NewDistributor(ws, ls, sqlitedb)
if err != nil {
t.Fatalf("NewDistributor(): %v", err)
}
err = d.Distribute(context.Background(), "FooLog", "Whittle", logFoo.checkpoint(16, "16", witWhittle.signer))
err = d.Distribute(ctx, "FooLog", "Whittle", logFoo.checkpoint(16, "16", witWhittle.signer))
if err != nil {
t.Fatalf("Distribute(): %v", err)
}

err = d.Distribute(context.Background(), tC.log.Verifier.Name(), tC.wit.verifier.Name(), tC.log.checkpoint(tC.size, tC.hashSeed, tC.wit.signer))
err = d.Distribute(ctx, tC.log.Verifier.Name(), tC.wit.verifier.Name(), tC.log.checkpoint(tC.size, tC.hashSeed, tC.wit.signer))
if (err != nil) != tC.wantErr {
t.Errorf("unexpected error output (wantErr: %t): %v", tC.wantErr, err)
}
Expand Down Expand Up @@ -350,21 +356,23 @@ func TestGetCheckpointWitness(t *testing.T) {
}
for _, tC := range testCases {
t.Run(tC.desc, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sqlitedb, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("failed to open temporary in-memory DB: %v", err)
}
d := distributor.NewDistributor(ws, ls, sqlitedb)
if err := d.Init(); err != nil {
t.Fatalf("Init(): %v", err)
d, err := distributor.NewDistributor(ws, ls, sqlitedb)
if err != nil {
t.Fatalf("NewDistributor(): %v", err)
}
writeCP := logFoo.checkpoint(16, "16", witWhittle.signer)
err = d.Distribute(context.Background(), "FooLog", "Whittle", writeCP)
err = d.Distribute(ctx, "FooLog", "Whittle", writeCP)
if err != nil {
t.Fatalf("Distribute(): %v", err)
}

readCP, err := d.GetCheckpointWitness(context.Background(), tC.log.Verifier.Name(), tC.wit.verifier.Name())
readCP, err := d.GetCheckpointWitness(ctx, tC.log.Verifier.Name(), tC.wit.verifier.Name())
if (err != nil) != tC.wantErr {
t.Errorf("unexpected error output (wantErr: %t): %v", tC.wantErr, err)
}
Expand Down Expand Up @@ -477,26 +485,28 @@ func TestGetCheckpointN(t *testing.T) {
}
for _, tC := range testCases {
t.Run(tC.desc, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sqlitedb, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("failed to open temporary in-memory DB: %v", err)
}
d := distributor.NewDistributor(ws, ls, sqlitedb)
if err := d.Init(); err != nil {
t.Fatalf("Init(): %v", err)
d, err := distributor.NewDistributor(ws, ls, sqlitedb)
if err != nil {
t.Fatalf("NewDistributor(): %v", err)
}
if err := d.Distribute(context.Background(), "FooLog", "Whittle", logFoo.checkpoint(16, "16", witWhittle.signer)); err != nil {
if err := d.Distribute(ctx, "FooLog", "Whittle", logFoo.checkpoint(16, "16", witWhittle.signer)); err != nil {
t.Fatal(err)
}
if err := d.Distribute(context.Background(), "FooLog", "Waffle", logFoo.checkpoint(14, "14", witWaffle.signer)); err != nil {
if err := d.Distribute(ctx, "FooLog", "Waffle", logFoo.checkpoint(14, "14", witWaffle.signer)); err != nil {
t.Fatal(err)
}

if err := d.Distribute(context.Background(), tC.distLog.Verifier.Name(), tC.distWit.verifier.Name(), tC.distLog.checkpoint(tC.distSize, fmt.Sprintf("%d", tC.distSize), tC.distWit.signer)); err != nil {
if err := d.Distribute(ctx, tC.distLog.Verifier.Name(), tC.distWit.verifier.Name(), tC.distLog.checkpoint(tC.distSize, fmt.Sprintf("%d", tC.distSize), tC.distWit.signer)); err != nil {
t.Fatal(err)
}

cpRaw, err := d.GetCheckpointN(context.Background(), tC.reqLog, tC.reqN)
cpRaw, err := d.GetCheckpointN(ctx, tC.reqLog, tC.reqN)
if (err != nil) != tC.wantErr {
t.Fatalf("unexpected error output (wantErr: %t): %v", tC.wantErr, err)
}
Expand Down

0 comments on commit 10da4f9

Please sign in to comment.