Skip to content

Commit 2c34e8d

Browse files
committed
add detect database variant; vary backup locking on that basis
Signed-off-by: Avi Deitcher <avi@deitcher.net>
1 parent a97ed17 commit 2c34e8d

File tree

4 files changed

+209
-13
lines changed

4 files changed

+209
-13
lines changed

pkg/database/mysql/dump.go

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ import (
2525
"strings"
2626
"text/template"
2727
"time"
28+
29+
dbutil "github.com/databacker/mysql-backup/pkg/util/database"
2830
)
2931

3032
/*
@@ -177,16 +179,11 @@ func (data *Data) Dump() error {
177179

178180
// Lock all tables before dumping if present
179181
if data.LockTables && (len(tables) > 0 || len(views) > 0) {
180-
var b bytes.Buffer
181-
b.WriteString("LOCK TABLES ")
182-
for index, table := range append(tables, views...) {
183-
if index != 0 {
184-
b.WriteString(",")
185-
}
186-
b.WriteString("`" + table.Name() + "` READ /*!32311 LOCAL */")
182+
lockCommand, err := data.getBackupLockCommand(tables, views)
183+
if err != nil {
184+
return fmt.Errorf("failed to get lock command: %w", err)
187185
}
188-
189-
if _, err := data.Connection.Exec(b.String()); err != nil {
186+
if _, err := data.Connection.Exec(lockCommand); err != nil {
190187
return err
191188
}
192189

@@ -537,6 +534,36 @@ func (data *Data) getProceduresOrFunctionsCreateQueries(t string) ([]string, err
537534
return toGet, nil
538535
}
539536

537+
// getBackupLockCommand returns the SQL command to lock the tables for backup
538+
// It may vary depending on the database variant or version, so it is generated dynamically
539+
func (data *Data) getBackupLockCommand(tables, views []Table) (string, error) {
540+
dbVar, err := dbutil.DetectVariant(data.Connection)
541+
if err != nil {
542+
return "", fmt.Errorf("failed to determine database variant: %w", err)
543+
}
544+
var lockString string
545+
switch dbVar {
546+
case dbutil.VariantMariaDB:
547+
lockString = "LOCK TABLES"
548+
case dbutil.VariantMySQL:
549+
lockString = "LOCK TABLES"
550+
case dbutil.VariantPercona:
551+
// Percona just use the simple LOCK TABLES FOR BACKUP command
552+
return "LOCK TABLES FOR BACKUP", nil
553+
default:
554+
lockString = "LOCK TABLES"
555+
}
556+
var b bytes.Buffer
557+
b.WriteString(lockString + " ")
558+
for index, table := range append(tables, views...) {
559+
if index != 0 {
560+
b.WriteString(",")
561+
}
562+
b.WriteString("`" + table.Name() + "` READ /*!32311 LOCAL */")
563+
}
564+
return b.String(), nil
565+
}
566+
540567
func (meta *metaData) updateMetadata(data *Data) (err error) {
541568
var serverVersion sql.NullString
542569
err = data.tx.QueryRow("SELECT version()").Scan(&serverVersion)

pkg/util/database/const.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package database
2+
3+
type Variant string
4+
5+
const (
6+
// VariantMariaDB is the MariaDB variant of MySQL.
7+
VariantMariaDB Variant = "mariadb"
8+
// VariantMySQL is the MySQL variant of MySQL.
9+
VariantMySQL Variant = "mysql"
10+
// VariantPercona is the Percona variant of MySQL.
11+
VariantPercona Variant = "percona"
12+
)

pkg/util/database/detect.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package database
2+
3+
import (
4+
"database/sql"
5+
"fmt"
6+
"strings"
7+
)
8+
9+
// DetectVariant returns the variant of the database, which can affect some commands.
10+
// It uses several heuristics to determine the variant based on the version and comment.
11+
// None of this is 100% reliable, but it should work for most cases.
12+
func DetectVariant(conn *sql.DB) (Variant, error) {
13+
// Check @@version and @@version_comment
14+
var version, comment string
15+
err := conn.QueryRow("SELECT @@version, @@version_comment").Scan(&version, &comment)
16+
if err != nil {
17+
return "", fmt.Errorf("failed to query version: %w", err)
18+
}
19+
20+
versionLower := strings.ToLower(version)
21+
commentLower := strings.ToLower(comment)
22+
23+
// Heuristic 1: version string or comment
24+
switch {
25+
case strings.Contains(versionLower, "mariadb") || strings.Contains(commentLower, "mariadb"):
26+
return VariantMariaDB, nil
27+
case strings.Contains(commentLower, "percona"):
28+
return VariantPercona, nil
29+
case strings.Contains(commentLower, "mysql"):
30+
return VariantMySQL, nil
31+
}
32+
33+
// Heuristic 2: Check for Aria engine (MariaDB)
34+
var dummy string
35+
err = conn.QueryRow("SELECT 1 FROM information_schema.engines WHERE engine = 'Aria' LIMIT 1").Scan(&dummy)
36+
if err == nil {
37+
return VariantMariaDB, nil
38+
}
39+
40+
// Heuristic 3: Percona plugins
41+
err = conn.QueryRow("SELECT 1 FROM information_schema.plugins WHERE plugin_name LIKE '%percona%' LIMIT 1").Scan(&dummy)
42+
if err == nil {
43+
return VariantPercona, nil
44+
}
45+
46+
return VariantMySQL, nil
47+
}

test/backup_test.go

Lines changed: 114 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"bytes"
99
"compress/gzip"
1010
"context"
11+
"database/sql"
1112
"errors"
1213
"fmt"
1314
"io"
@@ -28,6 +29,8 @@ import (
2829
"github.com/databacker/mysql-backup/pkg/database"
2930
"github.com/databacker/mysql-backup/pkg/storage"
3031
"github.com/databacker/mysql-backup/pkg/storage/credentials"
32+
dbutil "github.com/databacker/mysql-backup/pkg/util/database"
33+
3134
"github.com/docker/docker/api/types"
3235
"github.com/docker/docker/api/types/container"
3336
"github.com/docker/docker/client"
@@ -47,6 +50,8 @@ const (
4750
mysqlRootPass = "root"
4851
smbImage = "mysqlbackup_smb_test:latest"
4952
mysqlImage = "mysql:8.2.0"
53+
mariaImage = "mariadb:11.8.2-noble"
54+
perconaImage = "percona:8.0.42-33"
5055
bucketName = "mybucket"
5156
)
5257

@@ -880,8 +885,57 @@ func populatePrePost(base string, targets []backupTarget) (err error) {
880885
return nil
881886
}
882887

888+
func startDatabase(dc *dockerContext, baseDir, image, name string) (containerPort, error) {
889+
resp, err := dc.cli.ImagePull(context.Background(), image, types.ImagePullOptions{})
890+
if err != nil {
891+
return containerPort{}, fmt.Errorf("failed to pull mysql image: %v", err)
892+
}
893+
io.Copy(os.Stdout, resp)
894+
resp.Close()
895+
896+
// start the mysql container; configure it for lots of debug logging, in case we need it
897+
mysqlConf := `
898+
[mysqld]
899+
log_error =/var/log/mysql/mysql_error.log
900+
general_log_file=/var/log/mysql/mysql.log
901+
general_log =1
902+
slow_query_log =1
903+
slow_query_log_file=/var/log/mysql/mysql_slow.log
904+
long_query_time =2
905+
log_queries_not_using_indexes = 1
906+
`
907+
if err := os.Mkdir(baseDir, 0o755); err != nil {
908+
return containerPort{}, fmt.Errorf("failed to create mysql base directory: %v", err)
909+
}
910+
confFile := filepath.Join(baseDir, "log.cnf")
911+
if err := os.WriteFile(confFile, []byte(mysqlConf), 0644); err != nil {
912+
return containerPort{}, fmt.Errorf("failed to write mysql config file: %v", err)
913+
}
914+
logDir := filepath.Join(baseDir, "mysql_logs")
915+
if err := os.Mkdir(logDir, 0755); err != nil {
916+
return containerPort{}, fmt.Errorf("failed to create mysql log directory: %v", err)
917+
}
918+
919+
// start mysql
920+
cid, port, err := dc.startContainer(
921+
image, name, "3306/tcp", []string{fmt.Sprintf("%s:/etc/mysql/conf.d/log.conf:ro", confFile), fmt.Sprintf("%s:/var/log/mysql", logDir)}, nil, []string{
922+
fmt.Sprintf("MYSQL_ROOT_PASSWORD=%s", mysqlRootPass),
923+
"MYSQL_DATABASE=tester",
924+
fmt.Sprintf("MYSQL_USER=%s", mysqlUser),
925+
fmt.Sprintf("MYSQL_PASSWORD=%s", mysqlPass),
926+
})
927+
if err != nil {
928+
return containerPort{}, fmt.Errorf("failed to start mysql container: %v", err)
929+
}
930+
return containerPort{name: name, id: cid, port: port}, nil
931+
}
932+
883933
func TestIntegration(t *testing.T) {
884934
syscall.Umask(0)
935+
dc, err := getDockerContext()
936+
if err != nil {
937+
t.Fatalf("failed to get docker client: %v", err)
938+
}
885939
t.Run("dump", func(t *testing.T) {
886940
var (
887941
err error
@@ -898,10 +952,6 @@ func TestIntegration(t *testing.T) {
898952
if err := os.Chmod(base, 0o777); err != nil {
899953
t.Fatalf("failed to chmod temp dir: %v", err)
900954
}
901-
dc, err := getDockerContext()
902-
if err != nil {
903-
t.Fatalf("failed to get docker client: %v", err)
904-
}
905955
backupFile := filepath.Join(base, "backup.sql")
906956
compactBackupFile := filepath.Join(base, "backup-compact.sql")
907957
if mysql, smb, s3, s3backend, err = setup(dc, base, backupFile, compactBackupFile); err != nil {
@@ -1033,4 +1083,64 @@ func TestIntegration(t *testing.T) {
10331083
})
10341084
})
10351085
})
1086+
t.Run("dbutil", func(t *testing.T) {
1087+
t.Run("detect", func(t *testing.T) {
1088+
// start all database variants
1089+
// wait for them to be ready
1090+
// then run the detect command on each of them
1091+
// then tear them down
1092+
1093+
// set up dirs
1094+
1095+
base := t.TempDir()
1096+
tests := []struct {
1097+
name string
1098+
image string
1099+
containerName string
1100+
variant dbutil.Variant
1101+
}{
1102+
{"mysql", mysqlImage, "mysql-detect", dbutil.VariantMySQL},
1103+
{"maria", mariaImage, "maria-detect", dbutil.VariantMariaDB},
1104+
{"percona", perconaImage, "maria-detect", dbutil.VariantMariaDB},
1105+
}
1106+
// tear down at the end
1107+
var cids []string
1108+
defer func() {
1109+
if err := teardown(dc, cids...); err != nil {
1110+
log.Errorf("failed to teardown test: %v", err)
1111+
}
1112+
}()
1113+
1114+
for _, tt := range tests {
1115+
t.Run(tt.name, func(t *testing.T) {
1116+
container, err := startDatabase(dc, filepath.Join(base, tt.name), tt.image, tt.containerName)
1117+
if err != nil {
1118+
t.Fatalf("failed to start mysql container: %v", err)
1119+
}
1120+
cids = append(cids, container.id)
1121+
if err = dc.waitForDBConnectionAndGrantPrivileges(container.id, mysqlRootUser, mysqlRootPass); err != nil {
1122+
return
1123+
}
1124+
dbconn := database.Connection{
1125+
User: mysqlRootUser,
1126+
Pass: mysqlRootPass,
1127+
Host: "localhost",
1128+
Port: container.port,
1129+
}
1130+
1131+
db, err := sql.Open("mysql", dbconn.MySQL())
1132+
if err != nil {
1133+
t.Fatalf("failed to open connection to database: %v", err)
1134+
}
1135+
v, err := dbutil.DetectVariant(db)
1136+
if err != nil {
1137+
t.Errorf("error detecting database variant: %v", err)
1138+
}
1139+
if v != tt.variant {
1140+
t.Errorf("expected database variant to be %s, got %s", tt.variant, v)
1141+
}
1142+
})
1143+
}
1144+
})
1145+
})
10361146
}

0 commit comments

Comments
 (0)