From b105bcbcf799c6d172f1eeafd853a05a25b448bb Mon Sep 17 00:00:00 2001 From: Eric Fritz Date: Tue, 28 Apr 2020 09:04:13 -0500 Subject: [PATCH] Add batch inserter util for SQLite (#10201) --- internal/sqliteutil/batch_inserter.go | 109 ++++++++++++ internal/sqliteutil/batch_inserter_test.go | 182 +++++++++++++++++++++ 2 files changed, 291 insertions(+) create mode 100644 internal/sqliteutil/batch_inserter.go create mode 100644 internal/sqliteutil/batch_inserter_test.go diff --git a/internal/sqliteutil/batch_inserter.go b/internal/sqliteutil/batch_inserter.go new file mode 100644 index 00000000000..20dff7cbec8 --- /dev/null +++ b/internal/sqliteutil/batch_inserter.go @@ -0,0 +1,109 @@ +package sqliteutil + +import ( + "context" + "database/sql" + "fmt" + "strings" +) + +// BatchInserter batches insertions to a single table in a SQLite database. +// +// The benchmark tests provided in this package show that 50% more rows can be +// inserted in the same time it takes for them to be inserted individually within +// a transaction. +// +// BenchmarkSQLiteInsertion-8 40417 29440 ns/op +// BenchmarkSQLiteInsertionInTransaction-8 214681 5542 ns/op +// BenchmarkSQLiteInsertionWithBatchInserter-8 324998 3701 ns/op +type BatchInserter struct { + db Execable + numColumns int + maxBatchSize int + batch []interface{} + queryPrefix string + queryPlaceholders []string +} + +// MaxNumSqliteParameters is the maximum number of `?` placeholders that can be sent to SQLite in a single query without error. +const MaxNumSqliteParameters = 999 + +// Execable is the minimal common interface over sql.DB and sql.Tx required +// by BatchInserter. +type Execable interface { + // ExecContext executes a query without returning any rows. + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) +} + +// NewBatchInserter creates a new batch inserter for the given table and columns. At least one column name must be given: the batch size is computed by dividing MaxNumSqliteParameters by len(columnNames). 
+func NewBatchInserter(db Execable, tableName string, columnNames ...string) *BatchInserter { + numColumns := len(columnNames) + maxBatchSize := (MaxNumSqliteParameters / numColumns) * numColumns + + placeholders := make([]string, numColumns) + quotedColumnNames := make([]string, numColumns) + for i, columnName := range columnNames { + placeholders[i] = "?" + quotedColumnNames[i] = fmt.Sprintf(`"%s"`, columnName) + } + + queryPrefix := fmt.Sprintf(`INSERT INTO "%s" (%s) VALUES `, tableName, strings.Join(quotedColumnNames, ",")) + + queryPlaceholders := make([]string, maxBatchSize/numColumns) + for i := range queryPlaceholders { + queryPlaceholders[i] = fmt.Sprintf("(%s)", strings.Join(placeholders, ",")) + } + + return &BatchInserter{ + db: db, + numColumns: numColumns, + maxBatchSize: maxBatchSize, + batch: make([]interface{}, 0, maxBatchSize), + queryPrefix: queryPrefix, + queryPlaceholders: queryPlaceholders, + } +} + +// Insert enqueues the values of a single row for insertion. The given values must match up +// with the columnNames given at construction of the inserter. A full batch is flushed before returning. +func (bi *BatchInserter) Insert(ctx context.Context, values ...interface{}) error { + if len(values) != bi.numColumns { + return fmt.Errorf("expected %d values, got %d", bi.numColumns, len(values)) + } + + bi.batch = append(bi.batch, values...) + + if len(bi.batch) >= bi.maxBatchSize { + // Flush full batch + return bi.Flush(ctx) + } + + return nil +} + +// Flush ensures that all queued rows are inserted. This method must be invoked at the end +// of insertion to ensure that all records are flushed to the underlying Execable. +func (bi *BatchInserter) Flush(ctx context.Context) error { + if batch := bi.pop(); len(batch) > 0 { + // Create a query with enough placeholders to match the current batch size. This should + // generally be the full queryPlaceholders slice, except for the last call to Flush which + // may be a partial batch. 
+ query := bi.queryPrefix + strings.Join(bi.queryPlaceholders[:len(batch)/bi.numColumns], ",") + + if _, err := bi.db.ExecContext(ctx, query, batch...); err != nil { + return err + } + } + + return nil +} + +func (bi *BatchInserter) pop() (batch []interface{}) { + if len(bi.batch) < bi.maxBatchSize { + batch, bi.batch = bi.batch, bi.batch[:0] + return batch + } + + batch, bi.batch = bi.batch[:bi.maxBatchSize], bi.batch[bi.maxBatchSize:] + return batch +} diff --git a/internal/sqliteutil/batch_inserter_test.go b/internal/sqliteutil/batch_inserter_test.go new file mode 100644 index 00000000000..0d35eae9e9d --- /dev/null +++ b/internal/sqliteutil/batch_inserter_test.go @@ -0,0 +1,182 @@ +package sqliteutil + +import ( + "context" + "io/ioutil" + "os" + "path/filepath" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/jmoiron/sqlx" +) + +func init() { + SetLocalLibpath() + MustRegisterSqlite3WithPcre() +} + +func TestBatchInserter(t *testing.T) { + ctx := context.Background() + + var expectedValues [][]interface{} + for i := 0; i < 1000; i++ { + expectedValues = append(expectedValues, []interface{}{i, i + 1, i + 2, i + 3, i + 4}) + } + + withTestDB(t, func(db *sqlx.DB) error { + inserter := NewBatchInserter(db, "test", "col1", "col2", "col3", "col4", "col5") + for _, values := range expectedValues { + if err := inserter.Insert(ctx, values...); err != nil { + return err + } + } + + if err := inserter.Flush(ctx); err != nil { + return err + } + + rows, err := db.Query("SELECT col1, col2, col3, col4, col5 from test") + if err != nil { + return err + } + defer rows.Close() + + var values [][]interface{} + for rows.Next() { + var v1, v2, v3, v4, v5 int + if err := rows.Scan(&v1, &v2, &v3, &v4, &v5); err != nil { + return err + } + + values = append(values, []interface{}{v1, v2, v3, v4, v5}) + } + + if diff := cmp.Diff(expectedValues, values); diff != "" { + t.Errorf("unexpected table contents (-want +got):\n%s", diff) + } + + return nil + }) +} + +func 
BenchmarkSQLiteInsertion(b *testing.B) { + var expectedValues [][]interface{} + for i := 0; i < b.N; i++ { + expectedValues = append(expectedValues, []interface{}{i, i + 1, i + 2, i + 3, i + 4}) + } + + withTestDB(b, func(db *sqlx.DB) error { + b.ResetTimer() + + for _, values := range expectedValues { + if _, err := db.Exec("INSERT INTO test (col1, col2, col3, col4, col5) VALUES (?, ?, ?, ?, ?)", values...); err != nil { + return err + } + } + + return nil + }) +} + +func BenchmarkSQLiteInsertionInTransaction(b *testing.B) { + ctx := context.Background() + + var expectedValues [][]interface{} + for i := 0; i < b.N; i++ { + expectedValues = append(expectedValues, []interface{}{i, i + 1, i + 2, i + 3, i + 4}) + } + + withTestDB(b, func(db *sqlx.DB) error { + b.ResetTimer() + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return err + } + + for _, values := range expectedValues { + if _, err := tx.Exec("INSERT INTO test (col1, col2, col3, col4, col5) VALUES (?, ?, ?, ?, ?)", values...); err != nil { + return err + } + } + + if err := tx.Commit(); err != nil { + return err + } + + return nil + }) +} + +func BenchmarkSQLiteInsertionWithBatchInserter(b *testing.B) { + ctx := context.Background() + + var expectedValues [][]interface{} + for i := 0; i < b.N; i++ { + expectedValues = append(expectedValues, []interface{}{i, i + 1, i + 2, i + 3, i + 4}) + } + + withTestDB(b, func(db *sqlx.DB) error { + b.ResetTimer() + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return err + } + + inserter := NewBatchInserter(tx, "test", "col1", "col2", "col3", "col4", "col5") + for _, values := range expectedValues { + if err := inserter.Insert(ctx, values...); err != nil { + return err + } + } + + if err := inserter.Flush(ctx); err != nil { + return err + } + + if err := tx.Commit(); err != nil { + return err + } + + return nil + }) +} + +func withTestDB(t testing.TB, test func(db *sqlx.DB) error) { + tempDir, err := ioutil.TempDir("", "") + if err != nil { + 
t.Fatalf("unexpected error creating temp directory: %s", err) + } + defer os.RemoveAll(tempDir) + + db, err := sqlx.Open("sqlite3_with_pcre", filepath.Join(tempDir, "batch.db")) + if err != nil { + t.Fatalf("unexpected error opening database: %s", err) + } + + createTableQuery := ` + CREATE TABLE test ( + id integer primary key not null, + col1 integer not null, + col2 integer not null, + col3 integer not null, + col4 integer not null, + col5 integer not null + ) + ` + _, err1 := db.Exec(createTableQuery) + _, err2 := db.Exec("PRAGMA synchronous = OFF") + _, err3 := db.Exec("PRAGMA journal_mode = OFF") + + for _, err := range []error{err1, err2, err3} { + if err != nil { + t.Fatalf("unexpected error setting up database: %s", err) + } + } + + if err := test(db); err != nil { + t.Fatalf("unexpected error running test: %s", err) + } +}