Skip to content

Commit

Permalink
mylogical: Unify GTIDSet handling
Browse files Browse the repository at this point in the history
This change eliminates the multiple types for handdling GTIDSet flavors into a
single wrapper type. It is intended to make a future update to support
automatically switching in and out of time-based backfill easier to implement.

This change addresses a few cases where errors in GTID processing were being
discarded, leading to incorrect calls to the logical.Events API. Specifically,
DDL-only events reported from MariaDB would cause multiple OnBegin calls
without paired calls to OnCommit.

The conn.lastStamp field is removed and is instead communicated over the
channel. This was necessary to ensure correct resynchronization.

There are also some other spot cleanups to eliminate code warnings in GoLand.
  • Loading branch information
bobvawter committed Aug 1, 2022
1 parent c8d5e1c commit becd684
Show file tree
Hide file tree
Showing 9 changed files with 228 additions and 243 deletions.
4 changes: 1 addition & 3 deletions internal/source/mylogical/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,19 +130,17 @@ func (c *Config) Preflight() error {
c.password, _ = u.User.Password()
params := u.Query()
sslmode := params.Get("sslmode")
var tls *tls.Config

switch sslmode {
case "disable":
// tls configuration won't be set if we disable sslmode
case "require", "verify-ca", "verify-full":
tls, err = newClientTLSConfig(params, sslmode == "require", u.Hostname())
c.tlsConfig, err = newClientTLSConfig(params, sslmode == "require", u.Hostname())
if err != nil {
return err
}
default:
return errors.Errorf("invalid sslmode: %q", sslmode)
}
c.tlsConfig = tls
return nil
}
181 changes: 76 additions & 105 deletions internal/source/mylogical/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package mylogical

import (
"bytes"
"context"
"encoding/json"
"fmt"
Expand All @@ -41,12 +42,8 @@ import (
type conn struct {
// Columns, as ordered by the source database.
columns map[ident.Table][]types.ColData
// Key to set/retrieve state
consistentPointKey string
// Flavor is one of the mysql.MySQLFlavor or mysql.MariaDBFlavor constants
flavor string
// Last Stamp
lastStamp stamp.Stamp
// Map source ids to target tables.
relations map[uint64]ident.Table
// The configuration for opening replication connections.
Expand All @@ -67,38 +64,30 @@ const (

var _ logical.Dialect = (*conn)(nil)

func newStamp(flavor string) (stamp.Stamp, error) {
switch flavor {
case mysql.MySQLFlavor:
return newMySQLStamp(), nil
case mysql.MariaDBFlavor:
return newMariadbStamp(), nil
default:
return nil, errors.Errorf("Invalid flavor %s", flavor)
}
}

// Process implements logical.Dialect and receives a sequence of logical
// replication messages, or possibly a rollbackMessage.
func (c *conn) Process(
ctx context.Context, ch <-chan logical.Message, events logical.Events,
) error {
for {
// Perform context-aware read.
var msg logical.Message
select {
case msg = <-ch:
case <-ctx.Done():
return ctx.Err()
}
// This is the expected consistent point (i.e. transaction id) that
// we expect to see, given all previous messages on the wire. It is
// set, and reset, any time the upstream producer (re-)starts a read
// from the transaction log.
var streamCP *consistentPoint

for msg := range ch {
// Ensure that we resynchronize.
if logical.IsRollback(msg) {
if err := events.OnRollback(ctx, msg); err != nil {
return err
}
continue
}
var err error
// Resynchronize with the view of consumed transactions.
if nextStamp, ok := msg.(*consistentPoint); ok {
streamCP = nextStamp
continue
}
var ev, ok = msg.(replication.BinlogEvent)
if !ok {
return errors.Errorf("unexpected message %T", msg)
Expand All @@ -123,48 +112,61 @@ func (c *conn) Process(
// we expect a MariadbGTIDEvent with the GTID to begin the transaction
log.Tracef("processing %T", ev.Event)

EventProcessing:
switch e := ev.Event.(type) {
case *replication.XIDEvent:
// On commit should preserve the GTIDs so we can verify consistency,
// and restart the process from the last committed transaction.
log.Tracef("Commit")
err = events.OnCommit(ctx)
if err := events.OnCommit(ctx); err != nil {
return err
}

case *replication.GTIDEvent:
// A transaction is executed and committed on the source.
// This client transaction is assigned a GTID composed of the source's UUID
// and the smallest nonzero transaction sequence number not yet used on this server (GNO)
switch s := c.lastStamp.(type) {
case mySQLStamp:
u, _ := uuid.FromBytes(e.SID)
ns := fmt.Sprintf("%s:%d", u.String(), e.GNO)
a, err := mysql.ParseUUIDSet(ns)
if err == nil {
c.lastStamp = s.addMysqlGTIDSet(a)
}
default:
errors.Errorf("unexpected GTIDEvent for %T", s)
u, err := uuid.FromBytes(e.SID)
if err != nil {
return err
}
ns := fmt.Sprintf("%s:%d", u.String(), e.GNO)
toAdd, err := mysql.ParseUUIDSet(ns)
if err != nil {
return err
}
streamCP = streamCP.withMysqlGTIDSet(toAdd)

case *replication.MariadbGTIDEvent:
switch s := c.lastStamp.(type) {
case mariadbStamp:
a := e.GTID
c.lastStamp = s.addMariaGTIDSet(&a)
events.OnBegin(ctx, c.lastStamp)
default:
errors.Errorf("unexpected MariadbGTIDEvent for %T", s)
// We ignore events that won't have a terminating COMMIT
// events, e.g. schema changes.
// See flags section: https://mariadb.com/kb/en/gtid_event/
if e.IsStandalone() {
continue
}
var err error
streamCP, err = streamCP.withMariaGTIDSet(&e.GTID)
if err != nil {
return err
}
if err := events.OnBegin(ctx, streamCP); err != nil {
return err
}

case *replication.QueryEvent:
// Only supporting BEGIN
// DDL statement would also sent here.
log.Tracef("Query: %s %+v\n", e.Query, e.GSet)
if string(e.Query) == "BEGIN" {
err = events.OnBegin(ctx, c.lastStamp)
if bytes.Equal(e.Query, []byte("BEGIN")) {
if err := events.OnBegin(ctx, streamCP); err != nil {
return err
}
}

case *replication.TableMapEvent:
err = c.onRelation(e, events.GetTargetDB())
if err := c.onRelation(e); err != nil {
return err
}

case *replication.RowsEvent:
var operation mutationType
switch ev.Header.EventType {
Expand All @@ -175,51 +177,44 @@ func (c *conn) Process(
case replication.WRITE_ROWS_EVENTv0, replication.WRITE_ROWS_EVENTv1, replication.WRITE_ROWS_EVENTv2:
operation = insertMutation
default:
err = errors.Errorf("Operation not supported %s", ev.Header.EventType)
break EventProcessing
return errors.Errorf("Operation not supported %s", ev.Header.EventType)
}
mutationCount.With(prometheus.Labels{"type": operation.String()}).Inc()
err = c.onDataTuple(ctx, events, e, operation)
if err := c.onDataTuple(ctx, events, e, operation); err != nil {
return err
}

default:
err = errors.Errorf("unimplemented logical replication message %+v", e)
}
if err != nil {
return err
return errors.Errorf("unimplemented logical replication message %+v", e)
}
}
return nil
}

// ReadInto implements logical.Dialect, opens a replication connection,
// and writes supported events into the provided channel.
func (c *conn) ReadInto(ctx context.Context, ch chan<- logical.Message, state logical.State) error {
syncer := replication.NewBinlogSyncer(c.sourceConfig)
defer syncer.Close()
if state.GetConsistentPoint() == nil {
return errors.New("missing gtidset")
}
log.Tracef("ReadInto: %+v", state)
m, err := state.GetConsistentPoint().MarshalText()
if err != nil {
return errors.Wrap(err, "unable to parse gtidset")
}

gtidset, err := mysql.ParseGTIDSet(c.flavor, string(m))
if err != nil {
return errors.Wrap(err, "unable to parse gtidset")
cp := state.GetConsistentPoint()
if cp == nil {
return errors.New("missing gtidset")
}
streamer, err := syncer.StartSyncGTID(gtidset)

streamer, err := syncer.StartSyncGTID(cp.(*consistentPoint).AsGTIDSet())
if err != nil {
dialFailureCount.Inc()
return errors.WithStack(err)
return err
}
dialSuccessCount.Inc()

c.lastStamp, err = c.UnmarshalStamp([]byte(gtidset.String()))
if err != nil {
dialFailureCount.Inc()
return err
// Send the initial consistent point we're reading from.
select {
case ch <- cp.(*consistentPoint):
case <-ctx.Done():
return ctx.Err()
}

for ctx.Err() == nil {
ev, err := streamer.GetEvent(ctx)
if err != nil {
Expand Down Expand Up @@ -263,36 +258,14 @@ func (c *conn) ReadInto(ctx context.Context, ch chan<- logical.Message, state lo
return nil
}

// UnmarshalStamp decodes GTID Sets expressed as strings.
// Supports MySQL or MariaDB
// See https://dev.mysql.com/doc/refman/8.0/en/replication-gtids-concepts.html
// and https://mariadb.com/kb/en/gtid/
// Examples:
// MySQL: E11FA47-71CA-11E1-9E33-C80AA9429562:1-3:11:47-49
// MariaDB: 0-1-1
// UnmarshalStamp implements logical.Dialect. It delegates to
// consistentPoint.UnmarshalText.
func (c *conn) UnmarshalStamp(stamp []byte) (stamp.Stamp, error) {
log.Tracef("UnmarshalStamp %s", stamp)
s, err := mysql.ParseGTIDSet(c.flavor, string(stamp))
cp, err := newConsistentPoint(c.flavor)
if err != nil {
return nil, errors.Wrapf(err, "cannot unmarshal stamp %s", string(stamp))
}
switch c.flavor {
case mysql.MySQLFlavor:
ret, ok := s.(*mysql.MysqlGTIDSet)
if !ok {
return nil, errors.New("cannot unmarshal stamp " + string(stamp))
}
return mySQLStamp{gtidset: ret}, nil
case mysql.MariaDBFlavor:
ret, ok := s.(*mysql.MariadbGTIDSet)
if !ok {
return nil, errors.New("cannot unmarshal stamp " + string(stamp))
}
return mariadbStamp{gtidset: ret}, nil
default:
return nil, errors.New("invalid flavor")
return nil, err
}

return cp, cp.UnmarshalText(stamp)
}

func (c *conn) onDataTuple(
Expand Down Expand Up @@ -366,7 +339,7 @@ func (c *conn) onDataTuple(
// onRelation updates the source database namespace mappings.
// Columns names are only available if
// set global binlog_row_metadata = full;
func (c *conn) onRelation(msg *replication.TableMapEvent, targetDB ident.Ident) error {
func (c *conn) onRelation(msg *replication.TableMapEvent) error {
tbl := ident.NewTable(
ident.New(string(msg.Schema)),
ident.Public,
Expand Down Expand Up @@ -413,7 +386,7 @@ var (
// Based on the type of server it also verifies that the settings defined in the
// mySQLSystemSettings and mariaDBSystemSettings slices are correctly configured for the replication to work.
// It returns mysql.MariaDBFlavor or mysql.MySQLFlavor upon success.
func getFlavor(ctx context.Context, config *Config) (string, error) {
func getFlavor(config *Config) (string, error) {
addr := fmt.Sprintf("%s:%d", config.host, config.port)
c, err := client.Connect(addr, config.user, config.password, "", func(c *client.Conn) {
c.SetTLSConfig(config.tlsConfig)
Expand All @@ -434,15 +407,15 @@ func getFlavor(ctx context.Context, config *Config) (string, error) {
log.Infof("Version info: %s", version)
if strings.Contains(version, "mariadb") {
for _, v := range mariaDBSystemSettings {
err = checkSystemSetting(ctx, c, v[0], v[1])
err = checkSystemSetting(c, v[0], v[1])
if err != nil {
return "", err
}
}
return mysql.MariaDBFlavor, nil
} else if strings.Contains(version, "MySQL") {
for _, v := range mySQLSystemSettings {
err = checkSystemSetting(ctx, c, v[0], v[1])
err = checkSystemSetting(c, v[0], v[1])
if err != nil {
return "", err
}
Expand All @@ -453,9 +426,7 @@ func getFlavor(ctx context.Context, config *Config) (string, error) {
}
}

func checkSystemSetting(
ctx context.Context, c *client.Conn, variable string, expected string,
) error {
func checkSystemSetting(c *client.Conn, variable string, expected string) error {
res, err := c.Execute(fmt.Sprintf("select @@%s;", variable))
if err != nil {
return err
Expand Down
4 changes: 2 additions & 2 deletions internal/source/mylogical/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
)

func Test_mySqlStamp_Less(t *testing.T) {
a := assert.New(t)
tests := []struct {
name string
this string
Expand Down Expand Up @@ -53,6 +52,7 @@ func Test_mySqlStamp_Less(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
this, err := c.UnmarshalStamp([]byte(tt.this))
if !a.NoError(err) {
return
Expand All @@ -68,7 +68,6 @@ func Test_mySqlStamp_Less(t *testing.T) {
}

func Test_mariadbStamp_Less(t *testing.T) {
a := assert.New(t)
tests := []struct {
name string
this string
Expand All @@ -90,6 +89,7 @@ func Test_mariadbStamp_Less(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
this, err := c.UnmarshalStamp([]byte(tt.this))
if !a.NoError(err) {
return
Expand Down
Loading

0 comments on commit becd684

Please sign in to comment.