Skip to content
Permalink
Browse files

Use and test the same MySQL connection that main.go uses (#1370)

* Use mysql for tests

Test with the same database we use in prod.

- Each test gets its own database
- Drop test databases after they are done.

* Move sql.Open to a central place
* Set MySQL db flags
* Use testdb in directory storage
  • Loading branch information
gdbelvin committed Oct 25, 2019
1 parent 575344f commit 90e1d58555c2e91a2f4207ab388c5f99af2abff8
@@ -16,7 +16,6 @@ package main

import (
"context"
"database/sql"
"flag"
"fmt"
"net"
@@ -44,10 +43,10 @@ import (
pb "github.com/google/keytransparency/core/api/v1/keytransparency_go_proto"
dir "github.com/google/keytransparency/core/directory"
spb "github.com/google/keytransparency/core/sequencer/sequencer_go_proto"
ktsql "github.com/google/keytransparency/impl/sql"
etcdelect "github.com/google/trillian/util/election2/etcd"
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"

_ "github.com/go-sql-driver/mysql" // Set database engine.
_ "github.com/google/trillian/crypto/keys/der/proto"
_ "github.com/google/trillian/merkle/coniks" // Register hasher
_ "github.com/google/trillian/merkle/rfc6962" // Register hasher
@@ -74,17 +73,6 @@ var (
batchSize = flag.Int("batch-size", 100, "Maximum number of mutations to process per map revision")
)

func openDB() *sql.DB {
db, err := sql.Open("mysql", *serverDBPath)
if err != nil {
glog.Exitf("sql.Open(): %v", err)
}
if err := db.Ping(); err != nil {
glog.Exitf("db.Ping(): %v", err)
}
return db
}

// getElectionFactory returns an election factory based on flags, and a
// function which releases the resources associated with the factory.
func getElectionFactory() (election2.Factory, func()) {
@@ -130,7 +118,10 @@ func main() {
}

// Database tables
sqldb := openDB()
sqldb, err := ktsql.Open(*serverDBPath)
if err != nil {
glog.Exit(err)
}
defer sqldb.Close()

mutations, err := mutationstorage.New(sqldb)
@@ -16,7 +16,6 @@ package main

import (
"context"
"database/sql"
"flag"
"net/http"

@@ -37,11 +36,11 @@ import (
"github.com/google/keytransparency/impl/sql/mutationstorage"

pb "github.com/google/keytransparency/core/api/v1/keytransparency_go_proto"
ktsql "github.com/google/keytransparency/impl/sql"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"

_ "github.com/go-sql-driver/mysql" // Set database engine.
_ "github.com/google/trillian/crypto/keys/der/proto"
)

@@ -58,23 +57,15 @@ var (
revisionPageSize = flag.Int("revision-page-size", 10, "Max number of revisions to return at once")
)

func openDB() *sql.DB {
db, err := sql.Open("mysql", *serverDBPath)
if err != nil {
glog.Exitf("sql.Open(): %v", err)
}
if err := db.Ping(); err != nil {
glog.Exitf("db.Ping(): %v", err)
}
return db
}

func main() {
flag.Parse()
ctx := context.Background()

// Open Resources.
sqldb := openDB()
sqldb, err := ktsql.Open(*serverDBPath)
if err != nil {
glog.Exit(err)
}
defer sqldb.Close()

creds, err := credentials.NewServerTLSFromFile(*certFile, *keyFile)
@@ -27,15 +27,13 @@ import (
spb "github.com/google/keytransparency/core/sequencer/sequencer_go_proto"
)

// Batcher writes batch definitions to storage.
type Batcher = sequencer.Batcher
// batchStorageFactory returns a new database object, and a function for cleaning it up.
type batchStorageFactory func(ctx context.Context, t *testing.T, dirID string) (sequencer.Batcher, func(context.Context))

type BatchStorageFactory func(ctx context.Context, t *testing.T, dirID string) Batcher

type BatchStorageTest func(ctx context.Context, t *testing.T, f BatchStorageFactory)
type BatchStorageTest func(ctx context.Context, t *testing.T, f batchStorageFactory)

// RunBatchStorageTests runs all the batch storage tests against the provided map storage implementation.
func RunBatchStorageTests(t *testing.T, factory BatchStorageFactory) {
func RunBatchStorageTests(t *testing.T, factory batchStorageFactory) {
ctx := context.Background()
b := &BatchTests{}
for name, f := range map[string]BatchStorageTest{
@@ -52,19 +50,21 @@ func RunBatchStorageTests(t *testing.T, factory BatchStorageFactory) {
// BatchTests is a suite of tests to run against
type BatchTests struct{}

func (*BatchTests) TestNotFound(ctx context.Context, t *testing.T, f BatchStorageFactory) {
func (*BatchTests) TestNotFound(ctx context.Context, t *testing.T, f batchStorageFactory) {
domainID := "testnotfounddir"
b := f(ctx, t, domainID)
b, done := f(ctx, t, domainID)
defer done(ctx)
_, err := b.ReadBatch(ctx, domainID, 0)
st := status.Convert(err)
if got, want := st.Code(), codes.NotFound; got != want {
t.Errorf("ReadBatch(): %v, want %v", err, want)
}
}

func (*BatchTests) TestWriteBatch(ctx context.Context, t *testing.T, f BatchStorageFactory) {
func (*BatchTests) TestWriteBatch(ctx context.Context, t *testing.T, f batchStorageFactory) {
domainID := "writebatchtest"
b := f(ctx, t, domainID)
b, done := f(ctx, t, domainID)
defer done(ctx)
for _, tc := range []struct {
rev int64
wantErr bool
@@ -86,9 +86,10 @@ func (*BatchTests) TestWriteBatch(ctx context.Context, t *testing.T, f BatchStor
}
}

func (*BatchTests) TestReadBatch(ctx context.Context, t *testing.T, f BatchStorageFactory) {
func (*BatchTests) TestReadBatch(ctx context.Context, t *testing.T, f batchStorageFactory) {
domainID := "readbatchtest"
b := f(ctx, t, domainID)
b, done := f(ctx, t, domainID)
defer done(ctx)
for _, tc := range []struct {
rev int64
want *spb.MapMetadata
@@ -115,9 +116,10 @@ func (*BatchTests) TestReadBatch(ctx context.Context, t *testing.T, f BatchStora
}
}

func (*BatchTests) TestHighestRev(ctx context.Context, t *testing.T, f BatchStorageFactory) {
func (*BatchTests) TestHighestRev(ctx context.Context, t *testing.T, f batchStorageFactory) {
domainID := "writebatchtest"
b := f(ctx, t, domainID)
b, done := f(ctx, t, domainID)
defer done(ctx)
for _, tc := range []struct {
rev int64
sources []*spb.MapMetadata_SourceSlice
@@ -25,13 +25,14 @@ import (
pb "github.com/google/keytransparency/core/api/v1/keytransparency_go_proto"
)

type MutationLogsFactory func(ctx context.Context, t *testing.T, dirID string, logIDs ...int64) keyserver.MutationLogs
// mutationLogsFactory returns a new database object, and a function for cleaning it up.
type mutationLogsFactory func(ctx context.Context, t *testing.T, dirID string, logIDs ...int64) (keyserver.MutationLogs, func(context.Context))

// RunMutationLogsTests runs all the tests against the provided storage implementation.
func RunMutationLogsTests(t *testing.T, factory MutationLogsFactory) {
func RunMutationLogsTests(t *testing.T, factory mutationLogsFactory) {
ctx := context.Background()
b := &mutationLogsTests{}
for name, f := range map[string]func(ctx context.Context, t *testing.T, f MutationLogsFactory){
for name, f := range map[string]func(ctx context.Context, t *testing.T, f mutationLogsFactory){
// TODO(gbelvin): Discover test methods via reflection.
"TestReadLog": b.TestReadLog,
} {
@@ -51,10 +52,11 @@ func mustMarshal(t *testing.T, p proto.Message) []byte {
}

// TestReadLog ensures that reads happen in atomic units of batch size.
func (mutationLogsTests) TestReadLog(ctx context.Context, t *testing.T, newForTest MutationLogsFactory) {
func (mutationLogsTests) TestReadLog(ctx context.Context, t *testing.T, newForTest mutationLogsFactory) {
directoryID := "TestReadLog"
logID := int64(5) // Any log ID.
m := newForTest(ctx, t, directoryID, logID)
m, done := newForTest(ctx, t, directoryID, logID)
defer done(ctx)
// Write ten batches, three entries each.
for i := byte(0); i < 10; i++ {
entry := &pb.EntryUpdate{Mutation: &pb.SignedEntry{Entry: mustMarshal(t, &pb.Entry{Index: []byte{i}})}}
@@ -24,13 +24,14 @@ import (
"google.golang.org/grpc/status"
)

type LogsAdminFactory func(ctx context.Context, t *testing.T, dirID string, logIDs ...int64) adminserver.LogsAdmin
// logAdminFactory returns a new database object, and a function for cleaning it up.
type logAdminFactory func(ctx context.Context, t *testing.T, dirID string, logIDs ...int64) (adminserver.LogsAdmin, func(context.Context))

// RunLogsAdminTests runs all the admin tests against the provided storage implementation.
func RunLogsAdminTests(t *testing.T, factory LogsAdminFactory) {
func RunLogsAdminTests(t *testing.T, factory logAdminFactory) {
ctx := context.Background()
b := &logsAdminTests{}
for name, f := range map[string]func(ctx context.Context, t *testing.T, f LogsAdminFactory){
for name, f := range map[string]func(ctx context.Context, t *testing.T, f logAdminFactory){
// TODO(gbelvin): Discover test methods via reflection.
"TestSetWritable": b.TestSetWritable,
"TestListLogs": b.TestListLogs,
@@ -41,15 +42,16 @@ func RunLogsAdminTests(t *testing.T, factory LogsAdminFactory) {

type logsAdminTests struct{}

func (logsAdminTests) TestSetWritable(ctx context.Context, t *testing.T, f LogsAdminFactory) {
func (logsAdminTests) TestSetWritable(ctx context.Context, t *testing.T, f logAdminFactory) {
directoryID := "TestSetWritable"
m := f(ctx, t, directoryID, 1)
m, done := f(ctx, t, directoryID, 1)
defer done(ctx)
if st := status.Convert(m.SetWritable(ctx, directoryID, 2, true)); st.Code() != codes.NotFound {
t.Errorf("SetWritable(non-existent logid): %v, want %v", st, codes.NotFound)
}
}

func (logsAdminTests) TestListLogs(ctx context.Context, t *testing.T, f LogsAdminFactory) {
func (logsAdminTests) TestListLogs(ctx context.Context, t *testing.T, f logAdminFactory) {
directoryID := "TestListLogs"
for _, tc := range []struct {
desc string
@@ -64,7 +66,8 @@ func (logsAdminTests) TestListLogs(ctx context.Context, t *testing.T, f LogsAdmi
{desc: "multi", logIDs: []int64{1, 2, 3}, setWritable: map[int64]bool{1: true, 2: false}, wantLogIDs: []int64{1, 3}},
} {
t.Run(tc.desc, func(t *testing.T) {
m := f(ctx, t, directoryID, tc.logIDs...)
m, done := f(ctx, t, directoryID, tc.logIDs...)
defer done(ctx)
wantLogs := make(map[int64]bool)
for _, logID := range tc.wantLogIDs {
wantLogs[logID] = true
1 go.mod
@@ -28,7 +28,6 @@ require (
github.com/kylelemons/godebug v1.1.0
github.com/lyft/protoc-gen-validate v0.1.0 // indirect
github.com/mattn/go-isatty v0.0.9 // indirect
github.com/mattn/go-sqlite3 v1.10.0
github.com/mwitkow/go-proto-validators v0.1.0 // indirect
github.com/prometheus/client_golang v1.1.0
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4 // indirect
@@ -53,7 +53,6 @@ import (

_ "github.com/google/trillian/merkle/coniks" // Register hasher
_ "github.com/google/trillian/merkle/rfc6962" // Register hasher
_ "github.com/mattn/go-sqlite3" // Use sqlite database for testing.
)

var (
@@ -16,40 +16,35 @@ package directory

import (
"context"
"database/sql"
"testing"
"time"

"github.com/golang/protobuf/proto"
"github.com/google/go-cmp/cmp"
"github.com/google/keytransparency/core/directory"
tpb "github.com/google/trillian"
"github.com/google/keytransparency/impl/sql/testdb"
"github.com/google/trillian/crypto/keyspb"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

_ "github.com/mattn/go-sqlite3"
tpb "github.com/google/trillian"
)

func newStorage(t *testing.T) (s directory.Storage, close func()) {
func newStorage(ctx context.Context, t *testing.T) (directory.Storage, func(context.Context)) {
t.Helper()
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("sql.Open(): %v", err)
}
closeFunc := func() { db.Close() }
s, err = NewStorage(db)
db, done := testdb.NewForTest(ctx, t)
s, err := NewStorage(db)
if err != nil {
closeFunc()
done(ctx)
t.Fatalf("Failed to create adminstorage: %v", err)
}
return s, closeFunc
return s, done
}

func TestList(t *testing.T) {
ctx := context.Background()
s, closeF := newStorage(t)
defer closeF()
s, done := newStorage(ctx, t)
defer done(ctx)
for _, tc := range []struct {
directories []*directory.Directory
readDeleted bool
@@ -105,8 +100,8 @@ func TestList(t *testing.T) {

func TestWriteReadDelete(t *testing.T) {
ctx := context.Background()
s, closeF := newStorage(t)
defer closeF()
s, done := newStorage(ctx, t)
defer done(ctx)
for _, tc := range []struct {
desc string
d directory.Directory
@@ -249,8 +244,8 @@ func TestWriteReadDelete(t *testing.T) {

func TestDelete(t *testing.T) {
ctx := context.Background()
s, closeF := newStorage(t)
defer closeF()
s, done := newStorage(ctx, t)
defer done(ctx)
for _, tc := range []struct {
directoryID string
}{

0 comments on commit 90e1d58

Please sign in to comment.
You can’t perform that action at this time.