Skip to content
This repository has been archived by the owner on Jul 16, 2021. It is now read-only.

Commit

Permalink
Use and test the same MySQL connection that main.go uses (#1370)
Browse files Browse the repository at this point in the history
* Use mysql for tests

Test with the same database we use in prod.

- Each test gets its own database
- Drop test databases after they are done.

* Move sql.Open to a central place
* Set MySQL db flags
* Use testdb in directory storage
  • Loading branch information
gdbelvin committed Oct 25, 2019
1 parent 575344f commit 90e1d58
Show file tree
Hide file tree
Showing 13 changed files with 207 additions and 98 deletions.
19 changes: 5 additions & 14 deletions cmd/keytransparency-sequencer/main.go
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package main


import ( import (
"context" "context"
"database/sql"
"flag" "flag"
"fmt" "fmt"
"net" "net"
Expand Down Expand Up @@ -44,10 +43,10 @@ import (
pb "github.com/google/keytransparency/core/api/v1/keytransparency_go_proto" pb "github.com/google/keytransparency/core/api/v1/keytransparency_go_proto"
dir "github.com/google/keytransparency/core/directory" dir "github.com/google/keytransparency/core/directory"
spb "github.com/google/keytransparency/core/sequencer/sequencer_go_proto" spb "github.com/google/keytransparency/core/sequencer/sequencer_go_proto"
ktsql "github.com/google/keytransparency/impl/sql"
etcdelect "github.com/google/trillian/util/election2/etcd" etcdelect "github.com/google/trillian/util/election2/etcd"
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"


_ "github.com/go-sql-driver/mysql" // Set database engine.
_ "github.com/google/trillian/crypto/keys/der/proto" _ "github.com/google/trillian/crypto/keys/der/proto"
_ "github.com/google/trillian/merkle/coniks" // Register hasher _ "github.com/google/trillian/merkle/coniks" // Register hasher
_ "github.com/google/trillian/merkle/rfc6962" // Register hasher _ "github.com/google/trillian/merkle/rfc6962" // Register hasher
Expand All @@ -74,17 +73,6 @@ var (
batchSize = flag.Int("batch-size", 100, "Maximum number of mutations to process per map revision") batchSize = flag.Int("batch-size", 100, "Maximum number of mutations to process per map revision")
) )


func openDB() *sql.DB {
db, err := sql.Open("mysql", *serverDBPath)
if err != nil {
glog.Exitf("sql.Open(): %v", err)
}
if err := db.Ping(); err != nil {
glog.Exitf("db.Ping(): %v", err)
}
return db
}

// getElectionFactory returns an election factory based on flags, and a // getElectionFactory returns an election factory based on flags, and a
// function which releases the resources associated with the factory. // function which releases the resources associated with the factory.
func getElectionFactory() (election2.Factory, func()) { func getElectionFactory() (election2.Factory, func()) {
Expand Down Expand Up @@ -130,7 +118,10 @@ func main() {
} }


// Database tables // Database tables
sqldb := openDB() sqldb, err := ktsql.Open(*serverDBPath)
if err != nil {
glog.Exit(err)
}
defer sqldb.Close() defer sqldb.Close()


mutations, err := mutationstorage.New(sqldb) mutations, err := mutationstorage.New(sqldb)
Expand Down
19 changes: 5 additions & 14 deletions cmd/keytransparency-server/main.go
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package main


import ( import (
"context" "context"
"database/sql"
"flag" "flag"
"net/http" "net/http"


Expand All @@ -37,11 +36,11 @@ import (
"github.com/google/keytransparency/impl/sql/mutationstorage" "github.com/google/keytransparency/impl/sql/mutationstorage"


pb "github.com/google/keytransparency/core/api/v1/keytransparency_go_proto" pb "github.com/google/keytransparency/core/api/v1/keytransparency_go_proto"
ktsql "github.com/google/keytransparency/impl/sql"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth" grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"


_ "github.com/go-sql-driver/mysql" // Set database engine.
_ "github.com/google/trillian/crypto/keys/der/proto" _ "github.com/google/trillian/crypto/keys/der/proto"
) )


Expand All @@ -58,23 +57,15 @@ var (
revisionPageSize = flag.Int("revision-page-size", 10, "Max number of revisions to return at once") revisionPageSize = flag.Int("revision-page-size", 10, "Max number of revisions to return at once")
) )


func openDB() *sql.DB {
db, err := sql.Open("mysql", *serverDBPath)
if err != nil {
glog.Exitf("sql.Open(): %v", err)
}
if err := db.Ping(); err != nil {
glog.Exitf("db.Ping(): %v", err)
}
return db
}

func main() { func main() {
flag.Parse() flag.Parse()
ctx := context.Background() ctx := context.Background()


// Open Resources. // Open Resources.
sqldb := openDB() sqldb, err := ktsql.Open(*serverDBPath)
if err != nil {
glog.Exit(err)
}
defer sqldb.Close() defer sqldb.Close()


creds, err := credentials.NewServerTLSFromFile(*certFile, *keyFile) creds, err := credentials.NewServerTLSFromFile(*certFile, *keyFile)
Expand Down
30 changes: 16 additions & 14 deletions core/integration/storagetest/batch.go
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -27,15 +27,13 @@ import (
spb "github.com/google/keytransparency/core/sequencer/sequencer_go_proto" spb "github.com/google/keytransparency/core/sequencer/sequencer_go_proto"
) )


// Batcher writes batch definitions to storage. // batchStorageFactory returns a new database object, and a function for cleaning it up.
type Batcher = sequencer.Batcher type batchStorageFactory func(ctx context.Context, t *testing.T, dirID string) (sequencer.Batcher, func(context.Context))


type BatchStorageFactory func(ctx context.Context, t *testing.T, dirID string) Batcher type BatchStorageTest func(ctx context.Context, t *testing.T, f batchStorageFactory)

type BatchStorageTest func(ctx context.Context, t *testing.T, f BatchStorageFactory)


// RunBatchStorageTests runs all the batch storage tests against the provided map storage implementation. // RunBatchStorageTests runs all the batch storage tests against the provided map storage implementation.
func RunBatchStorageTests(t *testing.T, factory BatchStorageFactory) { func RunBatchStorageTests(t *testing.T, factory batchStorageFactory) {
ctx := context.Background() ctx := context.Background()
b := &BatchTests{} b := &BatchTests{}
for name, f := range map[string]BatchStorageTest{ for name, f := range map[string]BatchStorageTest{
Expand All @@ -52,19 +50,21 @@ func RunBatchStorageTests(t *testing.T, factory BatchStorageFactory) {
// BatchTests is a suite of tests to run against // BatchTests is a suite of tests to run against
type BatchTests struct{} type BatchTests struct{}


func (*BatchTests) TestNotFound(ctx context.Context, t *testing.T, f BatchStorageFactory) { func (*BatchTests) TestNotFound(ctx context.Context, t *testing.T, f batchStorageFactory) {
domainID := "testnotfounddir" domainID := "testnotfounddir"
b := f(ctx, t, domainID) b, done := f(ctx, t, domainID)
defer done(ctx)
_, err := b.ReadBatch(ctx, domainID, 0) _, err := b.ReadBatch(ctx, domainID, 0)
st := status.Convert(err) st := status.Convert(err)
if got, want := st.Code(), codes.NotFound; got != want { if got, want := st.Code(), codes.NotFound; got != want {
t.Errorf("ReadBatch(): %v, want %v", err, want) t.Errorf("ReadBatch(): %v, want %v", err, want)
} }
} }


func (*BatchTests) TestWriteBatch(ctx context.Context, t *testing.T, f BatchStorageFactory) { func (*BatchTests) TestWriteBatch(ctx context.Context, t *testing.T, f batchStorageFactory) {
domainID := "writebatchtest" domainID := "writebatchtest"
b := f(ctx, t, domainID) b, done := f(ctx, t, domainID)
defer done(ctx)
for _, tc := range []struct { for _, tc := range []struct {
rev int64 rev int64
wantErr bool wantErr bool
Expand All @@ -86,9 +86,10 @@ func (*BatchTests) TestWriteBatch(ctx context.Context, t *testing.T, f BatchStor
} }
} }


func (*BatchTests) TestReadBatch(ctx context.Context, t *testing.T, f BatchStorageFactory) { func (*BatchTests) TestReadBatch(ctx context.Context, t *testing.T, f batchStorageFactory) {
domainID := "readbatchtest" domainID := "readbatchtest"
b := f(ctx, t, domainID) b, done := f(ctx, t, domainID)
defer done(ctx)
for _, tc := range []struct { for _, tc := range []struct {
rev int64 rev int64
want *spb.MapMetadata want *spb.MapMetadata
Expand All @@ -115,9 +116,10 @@ func (*BatchTests) TestReadBatch(ctx context.Context, t *testing.T, f BatchStora
} }
} }


func (*BatchTests) TestHighestRev(ctx context.Context, t *testing.T, f BatchStorageFactory) { func (*BatchTests) TestHighestRev(ctx context.Context, t *testing.T, f batchStorageFactory) {
domainID := "writebatchtest" domainID := "writebatchtest"
b := f(ctx, t, domainID) b, done := f(ctx, t, domainID)
defer done(ctx)
for _, tc := range []struct { for _, tc := range []struct {
rev int64 rev int64
sources []*spb.MapMetadata_SourceSlice sources []*spb.MapMetadata_SourceSlice
Expand Down
12 changes: 7 additions & 5 deletions core/integration/storagetest/mutation_logs.go
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@ import (
pb "github.com/google/keytransparency/core/api/v1/keytransparency_go_proto" pb "github.com/google/keytransparency/core/api/v1/keytransparency_go_proto"
) )


type MutationLogsFactory func(ctx context.Context, t *testing.T, dirID string, logIDs ...int64) keyserver.MutationLogs // mutationLogsFactory returns a new database object, and a function for cleaning it up.
type mutationLogsFactory func(ctx context.Context, t *testing.T, dirID string, logIDs ...int64) (keyserver.MutationLogs, func(context.Context))


// RunMutationLogsTests runs all the tests against the provided storage implementation. // RunMutationLogsTests runs all the tests against the provided storage implementation.
func RunMutationLogsTests(t *testing.T, factory MutationLogsFactory) { func RunMutationLogsTests(t *testing.T, factory mutationLogsFactory) {
ctx := context.Background() ctx := context.Background()
b := &mutationLogsTests{} b := &mutationLogsTests{}
for name, f := range map[string]func(ctx context.Context, t *testing.T, f MutationLogsFactory){ for name, f := range map[string]func(ctx context.Context, t *testing.T, f mutationLogsFactory){
// TODO(gbelvin): Discover test methods via reflection. // TODO(gbelvin): Discover test methods via reflection.
"TestReadLog": b.TestReadLog, "TestReadLog": b.TestReadLog,
} { } {
Expand All @@ -51,10 +52,11 @@ func mustMarshal(t *testing.T, p proto.Message) []byte {
} }


// TestReadLog ensures that reads happen in atomic units of batch size. // TestReadLog ensures that reads happen in atomic units of batch size.
func (mutationLogsTests) TestReadLog(ctx context.Context, t *testing.T, newForTest MutationLogsFactory) { func (mutationLogsTests) TestReadLog(ctx context.Context, t *testing.T, newForTest mutationLogsFactory) {
directoryID := "TestReadLog" directoryID := "TestReadLog"
logID := int64(5) // Any log ID. logID := int64(5) // Any log ID.
m := newForTest(ctx, t, directoryID, logID) m, done := newForTest(ctx, t, directoryID, logID)
defer done(ctx)
// Write ten batches, three entries each. // Write ten batches, three entries each.
for i := byte(0); i < 10; i++ { for i := byte(0); i < 10; i++ {
entry := &pb.EntryUpdate{Mutation: &pb.SignedEntry{Entry: mustMarshal(t, &pb.Entry{Index: []byte{i}})}} entry := &pb.EntryUpdate{Mutation: &pb.SignedEntry{Entry: mustMarshal(t, &pb.Entry{Index: []byte{i}})}}
Expand Down
17 changes: 10 additions & 7 deletions core/integration/storagetest/mutation_logs_admin.go
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -24,13 +24,14 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
) )


type LogsAdminFactory func(ctx context.Context, t *testing.T, dirID string, logIDs ...int64) adminserver.LogsAdmin // logAdminFactory returns a new database object, and a function for cleaning it up.
type logAdminFactory func(ctx context.Context, t *testing.T, dirID string, logIDs ...int64) (adminserver.LogsAdmin, func(context.Context))


// RunLogsAdminTests runs all the admin tests against the provided storage implementation. // RunLogsAdminTests runs all the admin tests against the provided storage implementation.
func RunLogsAdminTests(t *testing.T, factory LogsAdminFactory) { func RunLogsAdminTests(t *testing.T, factory logAdminFactory) {
ctx := context.Background() ctx := context.Background()
b := &logsAdminTests{} b := &logsAdminTests{}
for name, f := range map[string]func(ctx context.Context, t *testing.T, f LogsAdminFactory){ for name, f := range map[string]func(ctx context.Context, t *testing.T, f logAdminFactory){
// TODO(gbelvin): Discover test methods via reflection. // TODO(gbelvin): Discover test methods via reflection.
"TestSetWritable": b.TestSetWritable, "TestSetWritable": b.TestSetWritable,
"TestListLogs": b.TestListLogs, "TestListLogs": b.TestListLogs,
Expand All @@ -41,15 +42,16 @@ func RunLogsAdminTests(t *testing.T, factory LogsAdminFactory) {


type logsAdminTests struct{} type logsAdminTests struct{}


func (logsAdminTests) TestSetWritable(ctx context.Context, t *testing.T, f LogsAdminFactory) { func (logsAdminTests) TestSetWritable(ctx context.Context, t *testing.T, f logAdminFactory) {
directoryID := "TestSetWritable" directoryID := "TestSetWritable"
m := f(ctx, t, directoryID, 1) m, done := f(ctx, t, directoryID, 1)
defer done(ctx)
if st := status.Convert(m.SetWritable(ctx, directoryID, 2, true)); st.Code() != codes.NotFound { if st := status.Convert(m.SetWritable(ctx, directoryID, 2, true)); st.Code() != codes.NotFound {
t.Errorf("SetWritable(non-existent logid): %v, want %v", st, codes.NotFound) t.Errorf("SetWritable(non-existent logid): %v, want %v", st, codes.NotFound)
} }
} }


func (logsAdminTests) TestListLogs(ctx context.Context, t *testing.T, f LogsAdminFactory) { func (logsAdminTests) TestListLogs(ctx context.Context, t *testing.T, f logAdminFactory) {
directoryID := "TestListLogs" directoryID := "TestListLogs"
for _, tc := range []struct { for _, tc := range []struct {
desc string desc string
Expand All @@ -64,7 +66,8 @@ func (logsAdminTests) TestListLogs(ctx context.Context, t *testing.T, f LogsAdmi
{desc: "multi", logIDs: []int64{1, 2, 3}, setWritable: map[int64]bool{1: true, 2: false}, wantLogIDs: []int64{1, 3}}, {desc: "multi", logIDs: []int64{1, 2, 3}, setWritable: map[int64]bool{1: true, 2: false}, wantLogIDs: []int64{1, 3}},
} { } {
t.Run(tc.desc, func(t *testing.T) { t.Run(tc.desc, func(t *testing.T) {
m := f(ctx, t, directoryID, tc.logIDs...) m, done := f(ctx, t, directoryID, tc.logIDs...)
defer done(ctx)
wantLogs := make(map[int64]bool) wantLogs := make(map[int64]bool)
for _, logID := range tc.wantLogIDs { for _, logID := range tc.wantLogIDs {
wantLogs[logID] = true wantLogs[logID] = true
Expand Down
1 change: 0 additions & 1 deletion go.mod
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ require (
github.com/kylelemons/godebug v1.1.0 github.com/kylelemons/godebug v1.1.0
github.com/lyft/protoc-gen-validate v0.1.0 // indirect github.com/lyft/protoc-gen-validate v0.1.0 // indirect
github.com/mattn/go-isatty v0.0.9 // indirect github.com/mattn/go-isatty v0.0.9 // indirect
github.com/mattn/go-sqlite3 v1.10.0
github.com/mwitkow/go-proto-validators v0.1.0 // indirect github.com/mwitkow/go-proto-validators v0.1.0 // indirect
github.com/prometheus/client_golang v1.1.0 github.com/prometheus/client_golang v1.1.0
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4 // indirect github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4 // indirect
Expand Down
1 change: 0 additions & 1 deletion impl/integration/env.go
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ import (


_ "github.com/google/trillian/merkle/coniks" // Register hasher _ "github.com/google/trillian/merkle/coniks" // Register hasher
_ "github.com/google/trillian/merkle/rfc6962" // Register hasher _ "github.com/google/trillian/merkle/rfc6962" // Register hasher
_ "github.com/mattn/go-sqlite3" // Use sqlite database for testing.
) )


var ( var (
Expand Down
31 changes: 13 additions & 18 deletions impl/sql/directory/storage_test.go
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -16,40 +16,35 @@ package directory


import ( import (
"context" "context"
"database/sql"
"testing" "testing"
"time" "time"


"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/google/keytransparency/core/directory" "github.com/google/keytransparency/core/directory"
tpb "github.com/google/trillian" "github.com/google/keytransparency/impl/sql/testdb"
"github.com/google/trillian/crypto/keyspb" "github.com/google/trillian/crypto/keyspb"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"


_ "github.com/mattn/go-sqlite3" tpb "github.com/google/trillian"
) )


func newStorage(t *testing.T) (s directory.Storage, close func()) { func newStorage(ctx context.Context, t *testing.T) (directory.Storage, func(context.Context)) {
t.Helper() t.Helper()
db, err := sql.Open("sqlite3", ":memory:") db, done := testdb.NewForTest(ctx, t)
if err != nil { s, err := NewStorage(db)
t.Fatalf("sql.Open(): %v", err)
}
closeFunc := func() { db.Close() }
s, err = NewStorage(db)
if err != nil { if err != nil {
closeFunc() done(ctx)
t.Fatalf("Failed to create adminstorage: %v", err) t.Fatalf("Failed to create adminstorage: %v", err)
} }
return s, closeFunc return s, done
} }


func TestList(t *testing.T) { func TestList(t *testing.T) {
ctx := context.Background() ctx := context.Background()
s, closeF := newStorage(t) s, done := newStorage(ctx, t)
defer closeF() defer done(ctx)
for _, tc := range []struct { for _, tc := range []struct {
directories []*directory.Directory directories []*directory.Directory
readDeleted bool readDeleted bool
Expand Down Expand Up @@ -105,8 +100,8 @@ func TestList(t *testing.T) {


func TestWriteReadDelete(t *testing.T) { func TestWriteReadDelete(t *testing.T) {
ctx := context.Background() ctx := context.Background()
s, closeF := newStorage(t) s, done := newStorage(ctx, t)
defer closeF() defer done(ctx)
for _, tc := range []struct { for _, tc := range []struct {
desc string desc string
d directory.Directory d directory.Directory
Expand Down Expand Up @@ -249,8 +244,8 @@ func TestWriteReadDelete(t *testing.T) {


func TestDelete(t *testing.T) { func TestDelete(t *testing.T) {
ctx := context.Background() ctx := context.Background()
s, closeF := newStorage(t) s, done := newStorage(ctx, t)
defer closeF() defer done(ctx)
for _, tc := range []struct { for _, tc := range []struct {
directoryID string directoryID string
}{ }{
Expand Down
Loading

0 comments on commit 90e1d58

Please sign in to comment.