From baa55e6e132bac2e05426aab6e00ecdb14d8d65a Mon Sep 17 00:00:00 2001 From: dnovitski <54758025+dnovitski@users.noreply.github.com> Date: Fri, 22 May 2026 04:26:01 +0200 Subject: [PATCH] feat: merge-DML batching optimization for binlog apply Add --is-merge-dml-event flag that batches and deduplicates binlog DML events before applying them to the ghost table, significantly reducing SQL round-trips during high-write migrations. When enabled and the unique key is memory-comparable (numeric columns): - Deduplicates DML events by unique key (latest event wins) - Reduces INSERT+DELETE sequences to DELETE (safe against row-copy races) - Batches INSERTs/UPDATEs as multi-row REPLACE INTO - Batches DELETEs as DELETE WHERE (pk) IN (...) - Skips events beyond migration range (not yet copied by row-copy) - Disables merge for tables with secondary unique indexes Safety: strict numeric type validation in formatNumericValue prevents SQL injection. Type detection uses exact base-type parsing (not substring). Uses BuildColumnsPreparedValues for proper per-column conversion tokens. Original implementation by shaohoukun in PR #1378, adapted to current master's builder-pattern API with correctness and security hardening. Co-authored-by: shaohk Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- doc/command-line-flags.md | 15 ++ go/base/context.go | 2 + go/cmd/gh-ost/main.go | 1 + go/logic/applier.go | 344 +++++++++++++++++++++++++++++------ go/logic/applier_test.go | 370 ++++++++++++++++++++++++++++++++++++++ go/logic/inspect.go | 126 ++++++++++++- go/logic/migrator.go | 6 +- go/sql/builder.go | 14 +- go/sql/types.go | 30 ++++ 9 files changed, 843 insertions(+), 65 deletions(-) diff --git a/doc/command-line-flags.md b/doc/command-line-flags.md index 2012c8c4c..ae5e3e424 100644 --- a/doc/command-line-flags.md +++ b/doc/command-line-flags.md @@ -158,6 +158,21 @@ While the ongoing estimated number of rows is still heuristic, it's almost exact Without this parameter, migration is a _noop_: testing table creation and validity of migration, but not touching data. +### is-merge-dml-event + +When enabled, batched binlog DML events are merged in memory before applying them to the ghost table. Only effective when the migration unique key uses numeric column types (`int`, `bigint`, `decimal`, `float`, etc.). + +**Batching:** All DML events in a batch are grouped by type — inserts and updates become a single multi-row `REPLACE INTO`, deletes become a single `DELETE WHERE (pk) IN (...)`. + +**Deduplication:** Repeated changes to the same unique key within a batch collapse to the final state (last writer wins). + +**Range filtering:** When a binlog event's unique key value is beyond `MigrationIterationRangeMaxValues` but within `MigrationRangeMaxValues`, the event is skipped — that data will be synced by the row-copy chunk. Events beyond `MigrationRangeMaxValues` or below `MigrationIterationRangeMaxValues` are applied normally. + +Automatically disabled when: +- The chosen unique key contains non-numeric columns (TEXT, BLOB, JSON, etc.) +- The chosen unique key has nullable columns (NULL breaks comparison and dedup semantics) +- The table has multiple unique indexes (REPLACE semantics are unsafe with secondary unique constraints) + ### force-named-cut-over If given, a `cut-over` command must name the migrated table, or else ignored. diff --git a/go/base/context.go b/go/base/context.go index 617e5bb13..d765bd239 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -211,7 +211,9 @@ type MigrationContext struct { controlReplicasLagResult mysql.ReplicationLagResult TotalRowsCopied int64 TotalDMLEventsApplied int64 + TotalDMLEventsIgnored int64 DMLBatchSize int64 + IsMergeDMLEvents bool isThrottled bool throttleReason string throttleReasonHint ThrottleReasonHint diff --git a/go/cmd/gh-ost/main.go b/go/cmd/gh-ost/main.go index 567137fd5..eec62cf3d 100644 --- a/go/cmd/gh-ost/main.go +++ b/go/cmd/gh-ost/main.go @@ -108,6 +108,7 @@ func main() { exponentialBackoffMaxInterval := flag.Int64("exponential-backoff-max-interval", 64, "Maximum number of seconds to wait between attempts when performing various operations with exponential backoff.") chunkSize := flag.Int64("chunk-size", 1000, "amount of rows to handle in each iteration (allowed range: 10-100,000)") dmlBatchSize := flag.Int64("dml-batch-size", 10, "batch size for DML events to apply in a single transaction (range 1-1000)") + flag.BoolVar(&migrationContext.IsMergeDMLEvents, "is-merge-dml-event", false, "Merge DML Binlog Event") defaultRetries := flag.Int64("default-retries", 60, "Default number of retries for various operations before panicking") flag.BoolVar(&migrationContext.PanicOnWarnings, "panic-on-warnings", false, "Panic when SQL warnings are encountered when copying a batch indicating data loss") cutOverLockTimeoutSeconds := flag.Int64("cut-over-lock-timeout-seconds", 3, "Max number of seconds to hold locks on tables while attempting to cut-over (retry attempted when lock exceeds timeout) or attempting instant DDL") diff --git a/go/logic/applier.go b/go/logic/applier.go index b49e131b8..add78f33d 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -38,18 +38,24 @@ const ( var ErrNoCheckpointFound = errors.New("no checkpoint found in _ghk table") type dmlBuildResult struct { - query string - args []interface{} - rowsDelta int64 - err error + dml binlog.EventDML + query string + args []interface{} + uniqueKeyValues []interface{} + sharedColumnArgs []interface{} + rowsDelta int64 + err error } -func newDmlBuildResult(query string, args []interface{}, rowsDelta int64, err error) *dmlBuildResult { +func newDmlBuildResult(dml binlog.EventDML, query string, args []interface{}, uniqueKeyValues []interface{}, rowsDelta int64, sharedColumnArgs []interface{}, err error) *dmlBuildResult { return &dmlBuildResult{ - query: query, - args: args, - rowsDelta: rowsDelta, - err: err, + dml: dml, + query: query, + args: args, + uniqueKeyValues: uniqueKeyValues, + sharedColumnArgs: sharedColumnArgs, + rowsDelta: rowsDelta, + err: err, } } @@ -1492,39 +1498,260 @@ func (apl *Applier) updateModifiesUniqueKeyColumns(dmlEvent *binlog.BinlogDMLEve return "", false } +func (apl *Applier) extractUniqueKeyArgs(args []interface{}) []interface{} { + uniqueKeyArgs := make([]interface{}, 0, apl.migrationContext.UniqueKey.Columns.Len()) + for _, column := range apl.migrationContext.UniqueKey.Columns.Columns() { + tableOrdinal := apl.migrationContext.OriginalTableColumns.Ordinals[column.Name] + arg := column.ConvertArg(args[tableOrdinal]) + uniqueKeyArgs = append(uniqueKeyArgs, arg) + } + return uniqueKeyArgs +} + +func (apl *Applier) extractSharedColumnArgs(args []interface{}) []interface{} { + sharedArgs := make([]interface{}, 0, apl.migrationContext.SharedColumns.Len()) + for _, column := range apl.migrationContext.SharedColumns.Columns() { + tableOrdinal := apl.migrationContext.OriginalTableColumns.Ordinals[column.Name] + arg := column.ConvertArg(args[tableOrdinal]) + sharedArgs = append(sharedArgs, arg) + } + return sharedArgs +} + // buildDMLEventQuery creates a query to operate on the ghost table, based on an intercepted binlog // event entry on the original table. func (apl *Applier) buildDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) []*dmlBuildResult { switch dmlEvent.DML { case binlog.DeleteDML: - { - query, uniqueKeyArgs, err := apl.dmlDeleteQueryBuilder.BuildQuery(dmlEvent.WhereColumnValues.AbstractValues()) - return []*dmlBuildResult{newDmlBuildResult(query, uniqueKeyArgs, -1, err)} - } + query, uniqueKeyArgs, err := apl.dmlDeleteQueryBuilder.BuildQuery(dmlEvent.WhereColumnValues.AbstractValues()) + return []*dmlBuildResult{newDmlBuildResult(binlog.DeleteDML, query, uniqueKeyArgs, uniqueKeyArgs, -1, nil, err)} case binlog.InsertDML: - { - query, sharedArgs, err := apl.dmlInsertQueryBuilder.BuildQuery(dmlEvent.NewColumnValues.AbstractValues()) - return []*dmlBuildResult{newDmlBuildResult(query, sharedArgs, 1, err)} - } + query, sharedArgs, err := apl.dmlInsertQueryBuilder.BuildQuery(dmlEvent.NewColumnValues.AbstractValues()) + uniqueKeyArgs := apl.extractUniqueKeyArgs(dmlEvent.NewColumnValues.AbstractValues()) + return []*dmlBuildResult{newDmlBuildResult(binlog.InsertDML, query, sharedArgs, uniqueKeyArgs, 1, sharedArgs, err)} case binlog.UpdateDML: - { - if _, isModified := apl.updateModifiesUniqueKeyColumns(dmlEvent); isModified { - results := make([]*dmlBuildResult, 0, 2) - dmlEvent.DML = binlog.DeleteDML - results = append(results, apl.buildDMLEventQuery(dmlEvent)...) - dmlEvent.DML = binlog.InsertDML - results = append(results, apl.buildDMLEventQuery(dmlEvent)...) - return results - } - query, updateArgs, err := apl.dmlUpdateQueryBuilder.BuildQuery(dmlEvent.NewColumnValues.AbstractValues(), dmlEvent.WhereColumnValues.AbstractValues()) - args := sqlutils.Args() - args = append(args, updateArgs...) - return []*dmlBuildResult{newDmlBuildResult(query, args, 0, err)} + if _, isModified := apl.updateModifiesUniqueKeyColumns(dmlEvent); isModified { + results := make([]*dmlBuildResult, 0, 2) + originalDML := dmlEvent.DML + dmlEvent.DML = binlog.DeleteDML + results = append(results, apl.buildDMLEventQuery(dmlEvent)...) + dmlEvent.DML = binlog.InsertDML + results = append(results, apl.buildDMLEventQuery(dmlEvent)...) + dmlEvent.DML = originalDML + return results } + query, updateArgs, err := apl.dmlUpdateQueryBuilder.BuildQuery(dmlEvent.NewColumnValues.AbstractValues(), dmlEvent.WhereColumnValues.AbstractValues()) + sharedArgs := apl.extractSharedColumnArgs(dmlEvent.NewColumnValues.AbstractValues()) + uniqueKeyArgs := apl.extractUniqueKeyArgs(dmlEvent.WhereColumnValues.AbstractValues()) + args := sqlutils.Args() + args = append(args, updateArgs...) + return []*dmlBuildResult{newDmlBuildResult(binlog.UpdateDML, query, args, uniqueKeyArgs, 0, sharedArgs, err)} } return []*dmlBuildResult{newDmlBuildResultError(fmt.Errorf("unknown dml event type: %+v", dmlEvent.DML))} } +func (apl *Applier) generateBatchedDeleteQuery(uniqueKeyValuesList [][]string) string { + if len(uniqueKeyValuesList) == 0 { + return "" + } + + databaseName := sql.EscapeName(apl.migrationContext.DatabaseName) + tableName := sql.EscapeName(apl.migrationContext.GetGhostTableName()) + uniqueKeyColumnNames := apl.migrationContext.UniqueKey.Columns.Names() + for i := range uniqueKeyColumnNames { + uniqueKeyColumnNames[i] = sql.EscapeName(uniqueKeyColumnNames[i]) + } + + valueClauses := make([]string, 0, len(uniqueKeyValuesList)) + for _, uniqueKeyValues := range uniqueKeyValuesList { + if len(uniqueKeyValues) == 0 { + continue + } + if len(uniqueKeyColumnNames) == 1 { + valueClauses = append(valueClauses, uniqueKeyValues[0]) + continue + } + valueClauses = append(valueClauses, fmt.Sprintf("(%s)", strings.Join(uniqueKeyValues, ", "))) + } + if len(valueClauses) == 0 { + return "" + } + + return fmt.Sprintf(`delete /* gh-ost %s.%s */ from %s.%s where (%s) in (%s)`, + databaseName, tableName, + databaseName, tableName, + strings.Join(uniqueKeyColumnNames, ", "), + strings.Join(valueClauses, ", "), + ) +} + +func (apl *Applier) generateBatchedReplaceQuery(sharedColumnArgsList [][]interface{}) (string, []interface{}) { + if len(sharedColumnArgsList) == 0 { + return "", nil + } + + databaseName := sql.EscapeName(apl.migrationContext.DatabaseName) + tableName := sql.EscapeName(apl.migrationContext.GetGhostTableName()) + mappedSharedColumnNames := apl.migrationContext.MappedSharedColumns.Names() + for i := range mappedSharedColumnNames { + mappedSharedColumnNames[i] = sql.EscapeName(mappedSharedColumnNames[i]) + } + preparedValues := sql.BuildColumnsPreparedValues(apl.migrationContext.MappedSharedColumns) + colCount := len(preparedValues) + singleRowClause := "(" + strings.Join(preparedValues, ", ") + ")" + + valuesClauses := make([]string, 0, len(sharedColumnArgsList)) + allArgs := make([]interface{}, 0, len(sharedColumnArgsList)*colCount) + for _, rowArgs := range sharedColumnArgsList { + valuesClauses = append(valuesClauses, singleRowClause) + allArgs = append(allArgs, rowArgs...) + } + + query := fmt.Sprintf(`replace /* gh-ost %s.%s */ into %s.%s (%s) values %s`, + databaseName, tableName, + databaseName, tableName, + strings.Join(mappedSharedColumnNames, ", "), + strings.Join(valuesClauses, ", "), + ) + return query, allArgs +} + +func (apl *Applier) isIgnoreOverMaxChunkRangeEvent(uniqueKeyArgs []interface{}) (bool, error) { + if apl.migrationContext.MigrationRangeMaxValues == nil { + return false, nil + } + for order, uniqueKeyCol := range apl.migrationContext.UniqueKey.Columns.Columns() { + if uniqueKeyCol.CompareValueFunc == nil { + return false, nil + } + cmp, err := uniqueKeyCol.CompareValueFunc(uniqueKeyArgs[order], apl.migrationContext.MigrationRangeMaxValues.StringColumn(order)) + if err != nil { + return false, err + } + if cmp > 0 { + return true, nil + } + if cmp < 0 { + return false, nil + } + } + if apl.migrationContext.MigrationIterationRangeMaxValues == nil { + return false, nil + } + for order, uniqueKeyCol := range apl.migrationContext.UniqueKey.Columns.Columns() { + if uniqueKeyCol.CompareValueFunc == nil { + return false, nil + } + cmp, err := uniqueKeyCol.CompareValueFunc(uniqueKeyArgs[order], apl.migrationContext.MigrationIterationRangeMaxValues.StringColumn(order)) + if err != nil { + return false, err + } + if cmp > 0 { + return true, nil + } + if cmp < 0 { + return false, nil + } + } + return false, nil +} + +//nolint:unparam // ignored count kept for API symmetry with buildDMLEventQueriesMerged +func (apl *Applier) buildDMLEventQueriesAll(dmlEvents [](*binlog.BinlogDMLEvent)) ([]*dmlBuildResult, int64, int64, error) { + buildResults := make([]*dmlBuildResult, 0, len(dmlEvents)) + + for _, dmlEvent := range dmlEvents { + for _, buildResult := range apl.buildDMLEventQuery(dmlEvent) { + if buildResult.err != nil { + return nil, 0, 0, buildResult.err + } + buildResults = append(buildResults, buildResult) + } + } + + return buildResults, int64(len(dmlEvents)), 0, nil +} + +func (apl *Applier) buildDMLEventQueriesMerged(dmlEvents [](*binlog.BinlogDMLEvent)) ([]*dmlBuildResult, int64, int64, error) { + type mergedEntry struct { + result *dmlBuildResult + formattedKeyValues []string + } + + dmlMap := make(map[string]*mergedEntry) + const keySeparator = "#gho#" + var appliedEvents int64 + var ignoredEvents int64 + + for _, dmlEvent := range dmlEvents { + results := apl.buildDMLEventQuery(dmlEvent) + ignored := false + for _, buildResult := range results { + if buildResult.err != nil { + return nil, 0, 0, buildResult.err + } + if skip, err := apl.isIgnoreOverMaxChunkRangeEvent(buildResult.uniqueKeyValues); err != nil { + return nil, 0, 0, err + } else if skip { + ignored = true + break + } + } + if ignored { + ignoredEvents++ + continue + } + appliedEvents++ + for _, buildResult := range results { + formattedValues, err := apl.migrationContext.UniqueKey.FormatValues(buildResult.uniqueKeyValues) + if err != nil { + return nil, 0, 0, err + } + mapKey := strings.Join(formattedValues, keySeparator) + if existing, ok := dmlMap[mapKey]; ok && existing != nil && existing.result != nil && existing.result.dml == binlog.InsertDML && buildResult.dml == binlog.DeleteDML { + // Row was INSERT'd then DELETE'd in this batch. If row-copy already + // copied this row to ghost, the ghost has a stale copy. Emit DELETE + // to ensure ghost converges to the correct state. + dmlMap[mapKey] = &mergedEntry{result: buildResult, formattedKeyValues: formattedValues} + continue + } + dmlMap[mapKey] = &mergedEntry{result: buildResult, formattedKeyValues: formattedValues} + } + } + + deleteValuesList := make([][]string, 0) + insertArgsList := make([][]interface{}, 0) + updateArgsList := make([][]interface{}, 0) + for _, entry := range dmlMap { + if entry == nil || entry.result == nil { + continue + } + switch entry.result.dml { + case binlog.DeleteDML: + deleteValuesList = append(deleteValuesList, entry.formattedKeyValues) + case binlog.InsertDML: + insertArgsList = append(insertArgsList, entry.result.sharedColumnArgs) + case binlog.UpdateDML: + updateArgsList = append(updateArgsList, entry.result.sharedColumnArgs) + default: + return nil, 0, 0, fmt.Errorf("unsupported dml event %s", entry.result.dml) + } + } + + buildResults := make([]*dmlBuildResult, 0, 3) + if query := apl.generateBatchedDeleteQuery(deleteValuesList); query != "" { + buildResults = append(buildResults, newDmlBuildResult(binlog.DeleteDML, query, nil, nil, -1, nil, nil)) + } + if query, args := apl.generateBatchedReplaceQuery(insertArgsList); query != "" { + buildResults = append(buildResults, newDmlBuildResult(binlog.InsertDML, query, args, nil, 1, nil, nil)) + } + if query, args := apl.generateBatchedReplaceQuery(updateArgsList); query != "" { + buildResults = append(buildResults, newDmlBuildResult(binlog.UpdateDML, query, args, nil, 0, nil, nil)) + } + + return buildResults, appliedEvents, ignoredEvents, nil +} + // executeBatchWithWarningChecking executes a batch of DML statements with SHOW WARNINGS // interleaved after each statement to detect warnings from any statement in the batch. // This is used when PanicOnWarnings is enabled to ensure warnings from middle statements @@ -1626,9 +1853,10 @@ func (apl *Applier) executeBatchWithWarningChecking(ctx context.Context, tx *gos return totalDelta, nil } -// ApplyDMLEventQueries applies multiple DML queries onto the _ghost_ table -func (apl *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent)) error { +func (apl *Applier) applyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent), merged bool) error { var totalDelta int64 + var appliedEvents int64 + var ignoredEvents int64 ctx := context.Background() err := func() error { @@ -1653,16 +1881,25 @@ func (apl *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent)) e return err } - buildResults := make([]*dmlBuildResult, 0, len(dmlEvents)) - nArgs := 0 - for _, dmlEvent := range dmlEvents { - for _, buildResult := range apl.buildDMLEventQuery(dmlEvent) { - if buildResult.err != nil { - return rollback(buildResult.err) - } - nArgs += len(buildResult.args) - buildResults = append(buildResults, buildResult) + var buildResults []*dmlBuildResult + if merged { + buildResults, appliedEvents, ignoredEvents, err = apl.buildDMLEventQueriesMerged(dmlEvents) + } else { + buildResults, appliedEvents, ignoredEvents, err = apl.buildDMLEventQueriesAll(dmlEvents) + } + if err != nil { + return rollback(err) + } + if len(buildResults) == 0 { + if err := tx.Commit(); err != nil { + return err } + return nil + } + + nArgs := 0 + for _, buildResult := range buildResults { + nArgs += len(buildResult.args) } // When PanicOnWarnings is enabled, we need to check warnings after each statement @@ -1685,7 +1922,9 @@ func (apl *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent)) e for _, buildResult := range buildResults { for _, arg := range buildResult.args { nv := driver.NamedValue{Value: driver.Value(arg)} - nvc.CheckNamedValue(&nv) + if err := nvc.CheckNamedValue(&nv); err != nil { + return err + } multiArgs = append(multiArgs, nv) } @@ -1695,14 +1934,10 @@ func (apl *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent)) e res, err := ex.ExecContext(ctx, multiQueryBuilder.String(), multiArgs) if err != nil { - err = fmt.Errorf("%w; query=%s; args=%+v", err, multiQueryBuilder.String(), multiArgs) - return err + return fmt.Errorf("%w; query=%s; args=%+v", err, multiQueryBuilder.String(), multiArgs) } mysqlRes := res.(drivermysql.Result) - - // each DML is either a single insert (delta +1), update (delta +0) or delete (delta -1). - // multiplying by the rows actually affected (either 0 or 1) will give an accurate row delta for this DML event for i, rowsAffected := range mysqlRes.AllRowsAffected() { totalDelta += buildResults[i].rowsDelta * rowsAffected } @@ -1723,15 +1958,24 @@ func (apl *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent)) e if err != nil { return apl.migrationContext.Log.Errore(err) } - // no error - atomic.AddInt64(&apl.migrationContext.TotalDMLEventsApplied, int64(len(dmlEvents))) + atomic.AddInt64(&apl.migrationContext.TotalDMLEventsApplied, appliedEvents) + atomic.AddInt64(&apl.migrationContext.TotalDMLEventsIgnored, ignoredEvents) if apl.migrationContext.CountTableRows { atomic.AddInt64(&apl.migrationContext.RowsDeltaEstimate, totalDelta) } - apl.migrationContext.Log.Debugf("ApplyDMLEventQueries() applied %d events in one transaction", len(dmlEvents)) + apl.migrationContext.Log.Debugf("ApplyDMLEventQueries() applied %d events in one transaction", appliedEvents) return nil } +// ApplyDMLEventQueries applies multiple DML queries onto the _ghost_ table +func (apl *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent)) error { + return apl.applyDMLEventQueries(dmlEvents, false) +} + +func (apl *Applier) ApplyDMLEventQueriesMerged(dmlEvents [](*binlog.BinlogDMLEvent)) error { + return apl.applyDMLEventQueries(dmlEvents, true) +} + func (apl *Applier) Teardown() { apl.migrationContext.Log.Debugf("Tearing down...") apl.db.Close() diff --git a/go/logic/applier_test.go b/go/logic/applier_test.go index 6d7ba42f4..e4bf6b6fd 100644 --- a/go/logic/applier_test.go +++ b/go/logic/applier_test.go @@ -1548,3 +1548,373 @@ func TestApplier(t *testing.T) { } suite.Run(t, new(ApplierTestSuite)) } + +func TestApplierGenerateBatchedReplaceQuery(t *testing.T) { + columns := sql.NewColumnList([]string{"id", "name", "value"}) + + migrationContext := base.NewMigrationContext() + migrationContext.DatabaseName = "test" + migrationContext.OriginalTableName = "test" + migrationContext.OriginalTableColumns = columns + migrationContext.SharedColumns = columns + migrationContext.MappedSharedColumns = columns + migrationContext.UniqueKey = &sql.UniqueKey{ + Name: t.Name(), + Columns: *sql.NewColumnList([]string{"id"}), + } + + applier := NewApplier(migrationContext) + + t.Run("empty", func(t *testing.T) { + query, args := applier.generateBatchedReplaceQuery(nil) + require.Empty(t, query) + require.Nil(t, args) + }) + + t.Run("single row", func(t *testing.T) { + query, args := applier.generateBatchedReplaceQuery([][]interface{}{{1, "hello", 42}}) + require.Contains(t, query, "replace") + require.Contains(t, query, "`test`.`_test_gho`") + require.Contains(t, query, "(?, ?, ?)") + require.Len(t, args, 3) + require.Equal(t, 1, args[0]) + require.Equal(t, "hello", args[1]) + require.Equal(t, 42, args[2]) + }) + + t.Run("multiple rows", func(t *testing.T) { + query, args := applier.generateBatchedReplaceQuery([][]interface{}{ + {1, "a", 10}, + {2, "b", 20}, + {3, "c", 30}, + }) + require.Contains(t, query, "replace") + require.Contains(t, query, "(?, ?, ?), (?, ?, ?), (?, ?, ?)") + require.Len(t, args, 9) + }) +} + +func TestApplierGenerateBatchedDeleteQuery(t *testing.T) { + columns := sql.NewColumnList([]string{"id", "name"}) + + migrationContext := base.NewMigrationContext() + migrationContext.DatabaseName = "test" + migrationContext.OriginalTableName = "test" + migrationContext.OriginalTableColumns = columns + migrationContext.SharedColumns = columns + migrationContext.MappedSharedColumns = columns + migrationContext.UniqueKey = &sql.UniqueKey{ + Name: t.Name(), + Columns: *sql.NewColumnList([]string{"id"}), + } + + applier := NewApplier(migrationContext) + + t.Run("empty", func(t *testing.T) { + query := applier.generateBatchedDeleteQuery(nil) + require.Empty(t, query) + }) + + t.Run("single key single value", func(t *testing.T) { + query := applier.generateBatchedDeleteQuery([][]string{{"123"}}) + require.Contains(t, query, "delete") + require.Contains(t, query, "`test`.`_test_gho`") + require.Contains(t, query, "(`id`) in (123)") + }) + + t.Run("single key multiple values", func(t *testing.T) { + query := applier.generateBatchedDeleteQuery([][]string{{"1"}, {"2"}, {"3"}}) + require.Contains(t, query, "(`id`) in (1, 2, 3)") + }) + + t.Run("composite key", func(t *testing.T) { + migrationContext.UniqueKey = &sql.UniqueKey{ + Name: t.Name(), + Columns: *sql.NewColumnList([]string{"id", "tenant_id"}), + } + query := applier.generateBatchedDeleteQuery([][]string{{"1", "10"}, {"2", "20"}}) + require.Contains(t, query, "(`id`, `tenant_id`) in ((1, 10), (2, 20))") + }) +} + +func TestApplierIsIgnoreOverMaxChunkRangeEvent(t *testing.T) { + columns := sql.NewColumnList([]string{"id", "name"}) + columns.GetColumn("id").CompareValueFunc = func(a interface{}, b interface{}) (int, error) { + // Simple int comparison for testing + ai := a.(int) + bi, _ := fmt.Sscanf(fmt.Sprintf("%v", b), "%d", new(int)) + _ = bi + var bval int + fmt.Sscanf(fmt.Sprintf("%v", b), "%d", &bval) + if ai > bval { + return 1, nil + } + if ai < bval { + return -1, nil + } + return 0, nil + } + + migrationContext := base.NewMigrationContext() + migrationContext.DatabaseName = "test" + migrationContext.OriginalTableName = "test" + migrationContext.OriginalTableColumns = columns + migrationContext.SharedColumns = columns + migrationContext.MappedSharedColumns = columns + migrationContext.UniqueKey = &sql.UniqueKey{ + Name: t.Name(), + Columns: *sql.NewColumnList([]string{"id"}), + } + migrationContext.UniqueKey.Columns.GetColumn("id").CompareValueFunc = columns.GetColumn("id").CompareValueFunc + + applier := NewApplier(migrationContext) + + t.Run("nil range max - never ignore", func(t *testing.T) { + migrationContext.MigrationRangeMaxValues = nil + ignore, err := applier.isIgnoreOverMaxChunkRangeEvent([]interface{}{999}) + require.NoError(t, err) + require.False(t, ignore) + }) + + t.Run("within range - not ignored", func(t *testing.T) { + migrationContext.MigrationRangeMaxValues = sql.ToColumnValues([]interface{}{100}) + ignore, err := applier.isIgnoreOverMaxChunkRangeEvent([]interface{}{50}) + require.NoError(t, err) + require.False(t, ignore) + }) + + t.Run("beyond range - ignored", func(t *testing.T) { + migrationContext.MigrationRangeMaxValues = sql.ToColumnValues([]interface{}{100}) + migrationContext.MigrationIterationRangeMaxValues = sql.ToColumnValues([]interface{}{100}) + ignore, err := applier.isIgnoreOverMaxChunkRangeEvent([]interface{}{200}) + require.NoError(t, err) + require.True(t, ignore) + }) +} + +func TestApplierBuildDMLEventQueriesMerged(t *testing.T) { + columns := sql.NewColumnList([]string{"id", "name"}) + // Set up FormatValueFunc for numeric key + columns.GetColumn("id").FormatValueFunc = func(arg interface{}) (string, error) { + return fmt.Sprintf("%v", arg), nil + } + columns.GetColumn("id").CompareValueFunc = compareIntegralValues + + migrationContext := base.NewMigrationContext() + migrationContext.DatabaseName = "test" + migrationContext.OriginalTableName = "test" + migrationContext.OriginalTableColumns = columns + migrationContext.SharedColumns = columns + migrationContext.MappedSharedColumns = columns + migrationContext.UniqueKey = &sql.UniqueKey{ + Name: t.Name(), + Columns: *sql.NewColumnList([]string{"id"}), + IsMemoryComparable: true, + } + migrationContext.UniqueKey.Columns.GetColumn("id").FormatValueFunc = columns.GetColumn("id").FormatValueFunc + migrationContext.UniqueKey.Columns.GetColumn("id").CompareValueFunc = columns.GetColumn("id").CompareValueFunc + + applier := NewApplier(migrationContext) + applier.prepareQueries() + + t.Run("INSERT then DELETE same key emits DELETE", func(t *testing.T) { + events := []*binlog.BinlogDMLEvent{ + {DatabaseName: "test", DML: binlog.InsertDML, NewColumnValues: sql.ToColumnValues([]interface{}{1, "alice"})}, + {DatabaseName: "test", DML: binlog.DeleteDML, WhereColumnValues: sql.ToColumnValues([]interface{}{1, "alice"})}, + } + results, applied, ignored, err := applier.buildDMLEventQueriesMerged(events) + require.NoError(t, err) + require.Equal(t, int64(2), applied) + require.Equal(t, int64(0), ignored) + // Should emit a batched DELETE, not cancel + require.Len(t, results, 1) + require.Equal(t, binlog.DeleteDML, results[0].dml) + }) + + t.Run("DELETE then INSERT same key emits REPLACE", func(t *testing.T) { + events := []*binlog.BinlogDMLEvent{ + {DatabaseName: "test", DML: binlog.DeleteDML, WhereColumnValues: sql.ToColumnValues([]interface{}{1, "alice"})}, + {DatabaseName: "test", DML: binlog.InsertDML, NewColumnValues: sql.ToColumnValues([]interface{}{1, "bob"})}, + } + results, applied, ignored, err := applier.buildDMLEventQueriesMerged(events) + require.NoError(t, err) + require.Equal(t, int64(2), applied) + require.Equal(t, int64(0), ignored) + // Last write wins: INSERT overwrites DELETE → REPLACE + require.Len(t, results, 1) + require.Equal(t, binlog.InsertDML, results[0].dml) + }) + + t.Run("multiple UPDATEs same key emits last only", func(t *testing.T) { + events := []*binlog.BinlogDMLEvent{ + {DatabaseName: "test", DML: binlog.UpdateDML, NewColumnValues: sql.ToColumnValues([]interface{}{1, "v1"}), WhereColumnValues: sql.ToColumnValues([]interface{}{1, "v0"})}, + {DatabaseName: "test", DML: binlog.UpdateDML, NewColumnValues: sql.ToColumnValues([]interface{}{1, "v2"}), WhereColumnValues: sql.ToColumnValues([]interface{}{1, "v1"})}, + {DatabaseName: "test", DML: binlog.UpdateDML, NewColumnValues: sql.ToColumnValues([]interface{}{1, "v3"}), WhereColumnValues: sql.ToColumnValues([]interface{}{1, "v2"})}, + } + results, applied, _, err := applier.buildDMLEventQueriesMerged(events) + require.NoError(t, err) + require.Equal(t, int64(3), applied) + // Only one REPLACE query with last values + require.Len(t, results, 1) + require.Equal(t, binlog.UpdateDML, results[0].dml) + }) + + t.Run("UPDATE then DELETE same key emits DELETE", func(t *testing.T) { + events := []*binlog.BinlogDMLEvent{ + {DatabaseName: "test", DML: binlog.UpdateDML, NewColumnValues: sql.ToColumnValues([]interface{}{1, "v1"}), WhereColumnValues: sql.ToColumnValues([]interface{}{1, "v0"})}, + {DatabaseName: "test", DML: binlog.DeleteDML, WhereColumnValues: sql.ToColumnValues([]interface{}{1, "v1"})}, + } + results, applied, _, err := applier.buildDMLEventQueriesMerged(events) + require.NoError(t, err) + require.Equal(t, int64(2), applied) + require.Len(t, results, 1) + require.Equal(t, binlog.DeleteDML, results[0].dml) + }) + + t.Run("mixed keys deduplicate independently", func(t *testing.T) { + events := []*binlog.BinlogDMLEvent{ + {DatabaseName: "test", DML: binlog.InsertDML, NewColumnValues: sql.ToColumnValues([]interface{}{1, "alice"})}, + {DatabaseName: "test", DML: binlog.InsertDML, NewColumnValues: sql.ToColumnValues([]interface{}{2, "bob"})}, + {DatabaseName: "test", DML: binlog.DeleteDML, WhereColumnValues: sql.ToColumnValues([]interface{}{1, "alice"})}, + } + results, applied, _, err := applier.buildDMLEventQueriesMerged(events) + require.NoError(t, err) + require.Equal(t, int64(3), applied) + // Key 1: INSERT→DELETE = DELETE; Key 2: INSERT = REPLACE + // Should have both a DELETE batch and a REPLACE batch + hasDel, hasIns := false, false + for _, r := range results { + if r.dml == binlog.DeleteDML { + hasDel = true + } + if r.dml == binlog.InsertDML { + hasIns = true + } + } + require.True(t, hasDel, "expected a DELETE result for key 1") + require.True(t, hasIns, "expected an INSERT/REPLACE result for key 2") + }) + + t.Run("DELETE INSERT DELETE same key emits DELETE", func(t *testing.T) { + events := []*binlog.BinlogDMLEvent{ + {DatabaseName: "test", DML: binlog.DeleteDML, WhereColumnValues: sql.ToColumnValues([]interface{}{1, "a"})}, + {DatabaseName: "test", DML: binlog.InsertDML, NewColumnValues: sql.ToColumnValues([]interface{}{1, "b"})}, + {DatabaseName: "test", DML: binlog.DeleteDML, WhereColumnValues: sql.ToColumnValues([]interface{}{1, "b"})}, + } + results, applied, _, err := applier.buildDMLEventQueriesMerged(events) + require.NoError(t, err) + require.Equal(t, int64(3), applied) + require.Len(t, results, 1) + require.Equal(t, binlog.DeleteDML, results[0].dml) + }) + + t.Run("ignored events beyond range", func(t *testing.T) { + migrationContext.MigrationRangeMaxValues = sql.ToColumnValues([]interface{}{10}) + migrationContext.MigrationIterationRangeMaxValues = sql.ToColumnValues([]interface{}{10}) + defer func() { + migrationContext.MigrationRangeMaxValues = nil + migrationContext.MigrationIterationRangeMaxValues = nil + }() + + events := []*binlog.BinlogDMLEvent{ + {DatabaseName: "test", DML: binlog.InsertDML, NewColumnValues: sql.ToColumnValues([]interface{}{5, "ok"})}, + {DatabaseName: "test", DML: binlog.InsertDML, NewColumnValues: sql.ToColumnValues([]interface{}{99, "beyond"})}, + } + results, applied, ignored, err := applier.buildDMLEventQueriesMerged(events) + require.NoError(t, err) + require.Equal(t, int64(1), applied) + require.Equal(t, int64(1), ignored) + require.Len(t, results, 1) + require.Equal(t, binlog.InsertDML, results[0].dml) + }) + + t.Run("empty events returns empty results", func(t *testing.T) { + results, applied, ignored, err := applier.buildDMLEventQueriesMerged([]*binlog.BinlogDMLEvent{}) + require.NoError(t, err) + require.Equal(t, int64(0), applied) + require.Equal(t, int64(0), ignored) + require.Empty(t, results) + }) +} + +func TestInspectIsIntegerColumnType(t *testing.T) { + tests := []struct { + colType string + expected bool + }{ + {"int", true}, + {"integer", true}, + {"bigint", true}, + {"tinyint", true}, + {"smallint", true}, + {"mediumint", true}, + {"int(11)", true}, + {"bigint unsigned", true}, + {"tinyint(1)", true}, + {"point", false}, + {"multipoint", false}, + {"varchar(255)", false}, + {"text", false}, + {"decimal(10,2)", false}, + {"float", false}, + {"timestamp", false}, + } + for _, tt := range tests { + t.Run(tt.colType, func(t *testing.T) { + require.Equal(t, tt.expected, isIntegerColumnType(tt.colType), "isIntegerColumnType(%q)", tt.colType) + }) + } +} + +func TestInspectIsDecimalColumnType(t *testing.T) { + tests := []struct { + colType string + expected bool + }{ + {"decimal(10,2)", true}, + {"numeric(5,3)", true}, + {"float", true}, + {"double", true}, + {"float(7,4)", true}, + {"int", false}, + {"bigint", false}, + {"varchar(255)", false}, + {"point", false}, + {"timestamp", false}, + } + for _, tt := range tests { + t.Run(tt.colType, func(t *testing.T) { + require.Equal(t, tt.expected, isDecimalColumnType(tt.colType), "isDecimalColumnType(%q)", tt.colType) + }) + } +} + +func TestFormatNumericValue(t *testing.T) { + tests := []struct { + name string + arg interface{} + want string + wantErr bool + }{ + {"int", int(42), "42", false}, + {"int64", int64(-999), "-999", false}, + {"uint64 large", uint64(18446744073709551615), "18446744073709551615", false}, + {"float64", float64(3.14), "3.14", false}, + {"numeric string", "12345", "12345", false}, + {"decimal string", "99.99", "99.99", false}, + {"non-numeric string", "hello", "", true}, + {"sql injection string", "1 OR 1=1", "", true}, + {"nil", nil, "", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := formatNumericValue(tt.arg) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.want, got) + } + }) + } +} diff --git a/go/logic/inspect.go b/go/logic/inspect.go index 96aadd672..fb496da32 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -10,7 +10,9 @@ import ( gosql "database/sql" "errors" "fmt" + "math/big" "reflect" + "strconv" "strings" "sync/atomic" "time" @@ -146,7 +148,11 @@ func (isp *Inspector) inspectOriginalAndGhostTables() (err error) { for i, sharedUniqueKey := range sharedUniqueKeys { isp.applyColumnTypes(isp.migrationContext.DatabaseName, isp.migrationContext.OriginalTableName, &sharedUniqueKey.Columns) uniqueKeyIsValid := true + isMemoryComparable := true for _, column := range sharedUniqueKey.Columns.Columns() { + if column.FormatValueFunc == nil { + isMemoryComparable = false + } switch column.Type { case sql.FloatColumnType: { @@ -164,6 +170,7 @@ func (isp *Inspector) inspectOriginalAndGhostTables() (err error) { } if uniqueKeyIsValid { isp.migrationContext.UniqueKey = sharedUniqueKeys[i] + isp.migrationContext.UniqueKey.IsMemoryComparable = isMemoryComparable break } } @@ -698,6 +705,100 @@ func (isp *Inspector) CountTableRows(ctx context.Context) error { } // applyColumnTypes +// isIntegerColumnType returns true for MySQL integer column types. +func isIntegerColumnType(lowerColumnType string) bool { + baseType := strings.Split(lowerColumnType, "(")[0] + baseType = strings.Fields(baseType)[0] + switch baseType { + case "tinyint", "smallint", "mediumint", "int", "integer", "bigint": + return true + } + return false +} + +// isDecimalColumnType returns true for MySQL decimal/float column types. +func isDecimalColumnType(lowerColumnType string) bool { + baseType := strings.Split(lowerColumnType, "(")[0] + baseType = strings.Fields(baseType)[0] + switch baseType { + case "decimal", "numeric", "float", "double": + return true + } + return false +} + +func formatNumericValue(arg interface{}) (string, error) { + if arg == nil { + return "", fmt.Errorf("format numeric value: nil") + } + switch v := arg.(type) { + case int: + return strconv.FormatInt(int64(v), 10), nil + case int8: + return strconv.FormatInt(int64(v), 10), nil + case int16: + return strconv.FormatInt(int64(v), 10), nil + case int32: + return strconv.FormatInt(int64(v), 10), nil + case int64: + return strconv.FormatInt(v, 10), nil + case uint: + return strconv.FormatUint(uint64(v), 10), nil + case uint8: + return strconv.FormatUint(uint64(v), 10), nil + case uint16: + return strconv.FormatUint(uint64(v), 10), nil + case uint32: + return strconv.FormatUint(uint64(v), 10), nil + case uint64: + return strconv.FormatUint(v, 10), nil + case float32: + return strconv.FormatFloat(float64(v), 'g', -1, 32), nil + case float64: + return strconv.FormatFloat(v, 'g', -1, 64), nil + case string: + // Binlog may decode numeric columns as strings; validate it's numeric + if _, ok := new(big.Int).SetString(v, 10); ok { + return v, nil + } + if _, ok := new(big.Float).SetString(v); ok { + return v, nil + } + return "", fmt.Errorf("format numeric value: non-numeric string %q", v) + } + return "", fmt.Errorf("format numeric value: unsupported type %T", arg) +} + +func compareIntegralValues(a interface{}, b interface{}) (int, error) { + if a == nil || b == nil { + return 0, fmt.Errorf("compare integral values: nil argument") + } + left := new(big.Int) + if _, ok := left.SetString(fmt.Sprintf("%v", a), 10); !ok { + return 0, fmt.Errorf("compare integral values: cannot parse %v", a) + } + right := new(big.Int) + if _, ok := right.SetString(fmt.Sprintf("%v", b), 10); !ok { + return 0, fmt.Errorf("compare integral values: cannot parse %v", b) + } + return left.Cmp(right), nil +} + +func compareDecimalValues(a interface{}, b interface{}) (int, error) { + if a == nil || b == nil { + return 0, fmt.Errorf("compare decimal values: nil argument") + } + left, ok := new(big.Float).SetString(fmt.Sprintf("%v", a)) + if !ok { + return 0, fmt.Errorf("compare decimal values: cannot parse %v", a) + } + right, ok := new(big.Float).SetString(fmt.Sprintf("%v", b)) + if !ok { + return 0, fmt.Errorf("compare decimal values: cannot parse %v", b) + } + return left.Cmp(right), nil +} + func (isp *Inspector) applyColumnTypes(databaseName, tableName string, columnsLists ...*sql.ColumnList) error { query := ` select /* gh-ost */ * @@ -709,6 +810,7 @@ func (isp *Inspector) applyColumnTypes(databaseName, tableName string, columnsLi err := sqlutils.QueryRowsMap(isp.db, query, func(m sqlutils.RowMap) error { columnName := m.GetString("COLUMN_NAME") columnType := m.GetString("COLUMN_TYPE") + lowerColumnType := strings.ToLower(columnType) columnOctetLength := m.GetUint("CHARACTER_OCTET_LENGTH") isNullable := m.GetString("IS_NULLABLE") extra := m.GetString("EXTRA") @@ -722,29 +824,37 @@ func (isp *Inspector) applyColumnTypes(databaseName, tableName string, columnsLi column.Nullable = true } - if strings.Contains(columnType, "unsigned") { + if strings.Contains(lowerColumnType, "unsigned") { column.IsUnsigned = true } - if strings.Contains(columnType, "mediumint") { + if isIntegerColumnType(lowerColumnType) { + column.CompareValueFunc = compareIntegralValues + column.FormatValueFunc = formatNumericValue + } + if isDecimalColumnType(lowerColumnType) { + column.CompareValueFunc = compareDecimalValues + column.FormatValueFunc = formatNumericValue + } + if strings.Contains(lowerColumnType, "mediumint") { column.Type = sql.MediumIntColumnType } - if strings.Contains(columnType, "timestamp") { + if strings.Contains(lowerColumnType, "timestamp") { column.Type = sql.TimestampColumnType } - if strings.Contains(columnType, "datetime") { + if strings.Contains(lowerColumnType, "datetime") { column.Type = sql.DateTimeColumnType } - if strings.Contains(columnType, "json") { + if strings.Contains(lowerColumnType, "json") { column.Type = sql.JSONColumnType } - if strings.Contains(columnType, "float") { + if strings.Contains(lowerColumnType, "float") { column.Type = sql.FloatColumnType } - if strings.HasPrefix(columnType, "enum") { + if strings.HasPrefix(lowerColumnType, "enum") { column.Type = sql.EnumColumnType column.EnumValues = sql.ParseEnumValues(m.GetString("COLUMN_TYPE")) } - if strings.HasPrefix(columnType, "binary") { + if strings.HasPrefix(lowerColumnType, "binary") { column.Type = sql.BinaryColumnType column.BinaryOctetLength = columnOctetLength } diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 90fa8c509..a07f18291 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -1394,9 +1394,10 @@ func (mgtr *Migrator) printStatus(rule PrintStatusRule, writers ...io.Writer) { currentBinlogCoordinates := mgtr.eventsStreamer.GetCurrentBinlogCoordinates() - status := fmt.Sprintf("Copy: %d/%d %.1f%%; Applied: %d; Backlog: %d/%d; Time: %+v(total), %+v(copy); streamer: %+v; Lag: %.2fs, HeartbeatLag: %.2fs, State: %s; ETA: %s", + status := fmt.Sprintf("Copy: %d/%d %.1f%%; Applied: %d; Ignored: %d; Backlog: %d/%d; Time: %+v(total), %+v(copy); streamer: %+v; Lag: %.2fs, HeartbeatLag: %.2fs, State: %s; ETA: %s", totalRowsCopied, rowsEstimate, progressPct, atomic.LoadInt64(&mgtr.migrationContext.TotalDMLEventsApplied), + atomic.LoadInt64(&mgtr.migrationContext.TotalDMLEventsIgnored), len(mgtr.applyEventsQueue), cap(mgtr.applyEventsQueue), base.PrettifyDurationOutput(elapsedTime), base.PrettifyDurationOutput(mgtr.migrationContext.ElapsedRowCopyTime()), currentBinlogCoordinates.DisplayString(), @@ -1685,6 +1686,9 @@ func (mgtr *Migrator) onApplyEventStruct(eventStruct *applyEventStruct) error { } // Create a task to apply the DML event; this will be execute by executeWriteFuncs() var applyEventFunc tableWriteFunc = func() error { + if mgtr.migrationContext.IsMergeDMLEvents && mgtr.migrationContext.UniqueKey != nil && mgtr.migrationContext.UniqueKey.IsMemoryComparable && !mgtr.migrationContext.UniqueKey.HasNullable && len(mgtr.migrationContext.OriginalTableUniqueKeys) <= 1 { + return mgtr.applier.ApplyDMLEventQueriesMerged(dmlEvents) + } return mgtr.applier.ApplyDMLEventQueries(dmlEvents) } if err := mgtr.retryOperation(applyEventFunc); err != nil { diff --git a/go/sql/builder.go b/go/sql/builder.go index 6e41eb4e1..cbf0dd93c 100644 --- a/go/sql/builder.go +++ b/go/sql/builder.go @@ -48,7 +48,9 @@ func TruncateColumnName(name string, limit int) string { return truncatedName } -func buildColumnsPreparedValues(columns *ColumnList) []string { +// BuildColumnsPreparedValues returns a slice of prepared-statement placeholder tokens, +// one per column, using each column's type-specific conversion function. +func BuildColumnsPreparedValues(columns *ColumnList) []string { values := make([]string, columns.Len()) for i, column := range columns.Columns() { var token string @@ -127,7 +129,7 @@ func NewCheckpointQueryBuilder(databaseName, tableName string, uniqueKeyColumns if uniqueKeyColumns.Len() == 0 { return nil, fmt.Errorf("got 0 columns in BuildSetCheckpointInsertQuery") } - values := buildColumnsPreparedValues(uniqueKeyColumns) + values := BuildColumnsPreparedValues(uniqueKeyColumns) minUniqueColNames := []string{} maxUniqueColNames := []string{} for _, name := range uniqueKeyColumns.Names() { @@ -256,7 +258,7 @@ func BuildRangeComparison(columns []string, values []string, args []interface{}, } func BuildRangePreparedComparison(columns *ColumnList, args []interface{}, comparisonSign ValueComparisonSign) (result string, explodedArgs []interface{}, err error) { - values := buildColumnsPreparedValues(columns) + values := BuildColumnsPreparedValues(columns) return BuildRangeComparison(columns.Names(), values, args, comparisonSign) } @@ -324,8 +326,8 @@ func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName strin } func BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, mappedSharedColumns []string, uniqueKey string, uniqueKeyColumns *ColumnList, rangeStartArgs, rangeEndArgs []interface{}, includeRangeStartValues bool, transactionalTable bool, noWait bool) (result string, explodedArgs []interface{}, err error) { - rangeStartValues := buildColumnsPreparedValues(uniqueKeyColumns) - rangeEndValues := buildColumnsPreparedValues(uniqueKeyColumns) + rangeStartValues := BuildColumnsPreparedValues(uniqueKeyColumns) + rangeEndValues := BuildColumnsPreparedValues(uniqueKeyColumns) return BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, mappedSharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, includeRangeStartValues, transactionalTable, noWait) } @@ -563,7 +565,7 @@ func NewDMLInsertQueryBuilder(databaseName, tableName string, tableColumns, shar for i := range mappedSharedColumnNames { mappedSharedColumnNames[i] = EscapeName(mappedSharedColumnNames[i]) } - preparedValues := buildColumnsPreparedValues(mappedSharedColumns) + preparedValues := BuildColumnsPreparedValues(mappedSharedColumns) stmt := fmt.Sprintf(` insert /* gh-ost %s.%s */ ignore diff --git a/go/sql/types.go b/go/sql/types.go index 3f7cd1b78..0186ab45b 100644 --- a/go/sql/types.go +++ b/go/sql/types.go @@ -55,6 +55,8 @@ type Column struct { CharacterSetName string Nullable bool MySQLType string + CompareValueFunc func(a interface{}, b interface{}) (int, error) + FormatValueFunc func(a interface{}) (string, error) } func (cl *Column) convertArg(arg interface{}) interface{} { @@ -126,6 +128,11 @@ func (cl *Column) convertArg(arg interface{}) interface{} { return arg } +// ConvertArg applies type-specific conversion to the given argument value. +func (cl *Column) ConvertArg(arg interface{}) interface{} { + return cl.convertArg(arg) +} + func NewColumns(names []string) []Column { result := make([]Column, len(names)) for i := range names { @@ -290,6 +297,9 @@ type UniqueKey struct { Columns ColumnList HasNullable bool IsAutoIncrement bool + // IsMemoryComparable indicates all columns in this key have numeric types + // that can be safely compared and formatted in-memory for merge-DML batching. + IsMemoryComparable bool } // IsPrimary checks if this unique key is primary @@ -309,6 +319,26 @@ func (uk *UniqueKey) String() string { return fmt.Sprintf("%s: %s; has nullable: %+v", description, uk.Columns.Names(), uk.HasNullable) } +// FormatValues formats the given argument values as string representations suitable +// for use as map keys and SQL literals in batched merge-DML operations. +func (uk *UniqueKey) FormatValues(args []interface{}) ([]string, error) { + if len(args) != uk.Columns.Len() { + return nil, fmt.Errorf("unique key args count mismatch: got %d, want %d", len(args), uk.Columns.Len()) + } + values := make([]string, 0, len(args)) + for i, column := range uk.Columns.Columns() { + if column.FormatValueFunc == nil { + return nil, fmt.Errorf("column %s does not support format value", column.Name) + } + val, err := column.FormatValueFunc(args[i]) + if err != nil { + return nil, err + } + values = append(values, val) + } + return values, nil +} + type ColumnValues struct { abstractValues []interface{} ValuesPointers []interface{}