Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func main() {
panic(err)
}
// adjust batch size according to source db table
cfgCopy.BatchSize = src.AdjustBatchSizeAccordingToSourceDbTable()
cfgCopy.BatchSize = int64(src.AdjustBatchSizeAccordingToSourceDbTable())
w := worker.NewWorker(&cfgCopy, fmt.Sprintf("%s.%s", db, table), ig, src)
w.Run(ctx)
}
Expand Down
12 changes: 6 additions & 6 deletions config/conf_test.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@
"sourceUser": "root",
"sourcePass": "123456",
"sourceDB": "mydb",
"sourceTable": "t1",
"sourceQuery": "select * from mydb.t1",
"sourceTable": "test_table",
"sourceQuery": "select * from mydb.test_table",
"sourceWhereCondition": "id > 0",
"sourceSplitKey": "id",
"sourceSplitTimeKey": "",
"timeSplitUnit": "minute",
"databendDSN": "http://databend:databend@localhost:8000",
"databendTable": "testSync.t1",
"batchSize": 2,
"databendDSN": "http://databend:databend@localhost:8009",
"databendTable": "testSync.test_table",
"batchSize": 20000,
"batchMaxInterval": 30,
"userStage": "~",
"deleteAfterSync": false,
"deleteAfterSync": true,
"maxThread": 10
}
111 changes: 95 additions & 16 deletions source/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package source

import (
"database/sql"
"database/sql/driver"
"fmt"
"log"
"regexp"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -40,25 +42,25 @@ func NewMysqlSource(cfg *config.Config) (*MysqlSource, error) {

// AdjustBatchSizeAccordingToSourceDbTable has a concept called s, s = (maxKey - minKey) / sourceTableRowCount
// if s == 1 it means the data is uniform in the table, if s is much bigger than 1, it means the data is not uniform in the table
func (s *MysqlSource) AdjustBatchSizeAccordingToSourceDbTable() int64 {
func (s *MysqlSource) AdjustBatchSizeAccordingToSourceDbTable() uint64 {
minSplitKey, maxSplitKey, err := s.GetMinMaxSplitKey()
if err != nil {
return s.cfg.BatchSize
return uint64(s.cfg.BatchSize)
}
sourceTableRowCount, err := s.GetSourceReadRowsCount()
if err != nil {
return s.cfg.BatchSize
return uint64(s.cfg.BatchSize)
}
rangeSize := maxSplitKey - minSplitKey + 1
switch {
case int64(sourceTableRowCount) <= s.cfg.BatchSize:
return rangeSize
case rangeSize/int64(sourceTableRowCount) >= 10:
return s.cfg.BatchSize * 5
case rangeSize/int64(sourceTableRowCount) >= 100:
return s.cfg.BatchSize * 20
case rangeSize/uint64(sourceTableRowCount) >= 10:
return uint64(s.cfg.BatchSize * 5)
case rangeSize/uint64(sourceTableRowCount) >= 100:
return uint64(s.cfg.BatchSize * 20)
default:
return s.cfg.BatchSize
return uint64(s.cfg.BatchSize)
}
}

Expand All @@ -74,28 +76,41 @@ func (s *MysqlSource) GetSourceReadRowsCount() (int, error) {
return rowCount, nil
}

func (s *MysqlSource) GetMinMaxSplitKey() (int64, int64, error) {
rows, err := s.db.Query(fmt.Sprintf("select min(%s), max(%s) from %s.%s WHERE %s", s.cfg.SourceSplitKey,
s.cfg.SourceSplitKey, s.cfg.SourceDB, s.cfg.SourceTable, s.cfg.SourceWhereCondition))
func (s *MysqlSource) GetMinMaxSplitKey() (uint64, uint64, error) {
query := fmt.Sprintf("SELECT MIN(%s), MAX(%s) FROM %s.%s WHERE %s",
s.cfg.SourceSplitKey, s.cfg.SourceSplitKey,
s.cfg.SourceDB, s.cfg.SourceTable, s.cfg.SourceWhereCondition)

rows, err := s.db.Query(query)
if err != nil {
return 0, 0, err
}
defer rows.Close()

var minSplitKey, maxSplitKey sql.NullInt64
var minSplitKey, maxSplitKey interface{}
for rows.Next() {
err = rows.Scan(&minSplitKey, &maxSplitKey)
if err != nil {
return 0, 0, err
}
}

// Check if minSplitKey and maxSplitKey are valid (not NULL)
if !minSplitKey.Valid || !maxSplitKey.Valid {
// 处理 NULL
if minSplitKey == nil || maxSplitKey == nil {
return 0, 0, nil
}

return minSplitKey.Int64, maxSplitKey.Int64, nil
min64, err := toUint64(minSplitKey)
if err != nil {
return 0, 0, fmt.Errorf("failed to convert min value: %w", err)
}

max64, err := toUint64(maxSplitKey)
if err != nil {
return 0, 0, fmt.Errorf("failed to convert max value: %w", err)
}

return min64, max64, nil
}

func (s *MysqlSource) GetMinMaxTimeSplitKey() (string, string, error) {
Expand All @@ -117,6 +132,7 @@ func (s *MysqlSource) GetMinMaxTimeSplitKey() (string, string, error) {
}

func (s *MysqlSource) DeleteAfterSync() error {
logrus.Infof("DeleteAfterSync: %v", s.cfg.DeleteAfterSync)
if !s.cfg.DeleteAfterSync {
return nil
}
Expand All @@ -126,6 +142,8 @@ func (s *MysqlSource) DeleteAfterSync() error {
return err
}

logrus.Infof("dbTables: %v", dbTables)

for db, tables := range dbTables {
for _, table := range tables {
count, err := s.GetSourceReadRowsCount()
Expand Down Expand Up @@ -188,7 +206,9 @@ func (s *MysqlSource) QueryTableData(threadNum int, conditionSql string) ([][]in
switch columnType.DatabaseTypeName() {
case "INT", "SMALLINT", "TINYINT", "MEDIUMINT", "BIGINT":
scanArgs[i] = new(sql.NullInt64)
case "UNSIGNED INT", "UNSIGNED TINYINT", "UNSIGNED MEDIUMINT", "UNSIGNED BIGINT":
case "UNSIGNED BIGINT":
scanArgs[i] = new(NullUint64)
case "UNSIGNED INT", "UNSIGNED TINYINT", "UNSIGNED MEDIUMINT":
scanArgs[i] = new(sql.NullInt64)
case "FLOAT", "DOUBLE":
scanArgs[i] = new(sql.NullFloat64)
Expand Down Expand Up @@ -244,6 +264,12 @@ func (s *MysqlSource) QueryTableData(threadNum int, conditionSql string) ([][]in
} else {
row[i] = nil
}
case *NullUint64:
if v.Valid {
row[i] = v.Uint64
} else {
row[i] = nil
}
case *sql.NullBool:
if v.Valid {
row[i] = v.Bool
Expand Down Expand Up @@ -375,5 +401,58 @@ func (s *MysqlSource) GetDbTablesAccordingToSourceDbTables() (map[string][]strin
allDbTables[db] = append(allDbTables[db], tables...)
}
}
if s.cfg.SourceDB != "" && s.cfg.SourceTable != "" {
allDbTables[s.cfg.SourceDB] = append(allDbTables[s.cfg.SourceDB], s.cfg.SourceTable)
}
return allDbTables, nil
}

// NullUint64 represents a uint64 that may be NULL. It mirrors the shape of
// sql.NullInt64 (a value plus a Valid flag) for unsigned 64-bit columns,
// whose upper half of the range does not fit in an int64.
type NullUint64 struct {
	Uint64 uint64 // scanned value; meaningful only when Valid is true
	Valid  bool   // Valid is true if Uint64 is not NULL
}

// Scan implements the sql.Scanner interface.
//
// Accepted inputs:
//   - nil: the receiver becomes NULL (Valid == false, Uint64 == 0).
//   - uint64: stored as is.
//   - int64: converted with uint64(v); a negative value is therefore
//     reinterpreted as its two's-complement bit pattern rather than rejected.
//   - []byte / string: parsed as a base-10 unsigned 64-bit integer.
//
// Any other type, or text that does not parse, yields an error. On error the
// receiver is left in the NULL state so a caller that ignores the error never
// observes Valid == true alongside a meaningless value.
func (n *NullUint64) Scan(value interface{}) error {
	// Start from the NULL state; only a fully successful conversion below
	// flips Valid to true.
	n.Uint64, n.Valid = 0, false
	if value == nil {
		return nil
	}

	var (
		parsed uint64
		err    error
	)
	switch v := value.(type) {
	case uint64:
		parsed = v
	case int64:
		// Deliberate bit-for-bit conversion: negative inputs wrap around.
		parsed = uint64(v)
	case []byte:
		parsed, err = strconv.ParseUint(string(v), 10, 64)
	case string:
		parsed, err = strconv.ParseUint(v, 10, 64)
	default:
		return fmt.Errorf("cannot scan type %T into NullUint64", value)
	}
	if err != nil {
		return err
	}

	n.Uint64, n.Valid = parsed, true
	return nil
}

// Value implements the driver.Valuer interface. It returns nil when the
// receiver is NULL and the raw uint64 otherwise.
//
// NOTE(review): uint64 is not one of the standard driver.Value types listed
// by database/sql/driver, so passing a NullUint64 as a query argument relies
// on the driver accepting uint64 itself - confirm for every driver this type
// is used with.
func (n NullUint64) Value() (driver.Value, error) {
	if !n.Valid {
		return nil, nil
	}
	return n.Uint64, nil
}
43 changes: 28 additions & 15 deletions source/oracle.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,25 @@ type OracleSource struct {
statsRecorder *DatabendSourceStatsRecorder
}

func (p *OracleSource) AdjustBatchSizeAccordingToSourceDbTable() int64 {
func (p *OracleSource) AdjustBatchSizeAccordingToSourceDbTable() uint64 {
minSplitKey, maxSplitKey, err := p.GetMinMaxSplitKey()
if err != nil {
return p.cfg.BatchSize
return uint64(p.cfg.BatchSize)
}
sourceTableRowCount, err := p.GetSourceReadRowsCount()
if err != nil {
return p.cfg.BatchSize
return uint64(p.cfg.BatchSize)
}
rangeSize := maxSplitKey - minSplitKey + 1
switch {
case int64(sourceTableRowCount) <= p.cfg.BatchSize:
return rangeSize
case rangeSize/int64(sourceTableRowCount) >= 10:
return p.cfg.BatchSize * 5
case rangeSize/int64(sourceTableRowCount) >= 100:
return p.cfg.BatchSize * 20
case rangeSize/uint64(sourceTableRowCount) >= 10:
return uint64(p.cfg.BatchSize * 5)
case rangeSize/uint64(sourceTableRowCount) >= 100:
return uint64(p.cfg.BatchSize * 20)
default:
return p.cfg.BatchSize
return uint64(p.cfg.BatchSize)
}
}

Expand Down Expand Up @@ -111,32 +111,45 @@ func (p *OracleSource) GetSourceReadRowsCount() (int, error) {
return rowCount, nil
}

func (p *OracleSource) GetMinMaxSplitKey() (int64, int64, error) {
func (p *OracleSource) GetMinMaxSplitKey() (uint64, uint64, error) {
err := p.SwitchDatabase()
if err != nil {
return 0, 0, err
}
rows, err := p.db.Query(fmt.Sprintf("select COALESCE(min(%s),0), COALESCE(max(%s),0) from %s.%s WHERE %s",
p.cfg.SourceSplitKey, p.cfg.SourceSplitKey, p.cfg.SourceDB, p.cfg.SourceTable, p.cfg.SourceWhereCondition))

query := fmt.Sprintf("SELECT COALESCE(MIN(%s), 0), COALESCE(MAX(%s), 0) FROM %s.%s WHERE %s",
p.cfg.SourceSplitKey, p.cfg.SourceSplitKey,
p.cfg.SourceDB, p.cfg.SourceTable, p.cfg.SourceWhereCondition)

rows, err := p.db.Query(query)
if err != nil {
return 0, 0, err
}
defer rows.Close()

var minSplitKey, maxSplitKey sql.NullInt64
var minSplitKey, maxSplitKey interface{}
for rows.Next() {
err = rows.Scan(&minSplitKey, &maxSplitKey)
if err != nil {
return 0, 0, err
}
}

// Check if minSplitKey and maxSplitKey are valid (not NULL)
if !minSplitKey.Valid || !maxSplitKey.Valid {
if minSplitKey == nil || maxSplitKey == nil {
return 0, 0, nil
}

return minSplitKey.Int64, maxSplitKey.Int64, nil
min64, err := toUint64(minSplitKey)
if err != nil {
return 0, 0, fmt.Errorf("failed to convert min value: %w", err)
}

max64, err := toUint64(maxSplitKey)
if err != nil {
return 0, 0, fmt.Errorf("failed to convert max value: %w", err)
}

return min64, max64, nil
}

func (p *OracleSource) GetMinMaxTimeSplitKey() (string, string, error) {
Expand Down
Loading