diff --git a/README.md b/README.md index 9487db3..0fbe5c0 100644 --- a/README.md +++ b/README.md @@ -132,6 +132,66 @@ output } ``` +## Advanced Usage: Template SQL + +The SQL to be executed is rendered by pongo2, a Django-syntax like template-engine, once. +Therefore, the SQL to be specified can use template notation. +In CLI, `--var key=value` flag, in Lambda, `vars` key can be specified as a string hash, and variables can be passed at runtime to dynamically generate SQL by template notation. +For example, if you specify `--var relation=users --var limit=5` and execute the following template SQL, (environment variable `ENV=dev` is set) + +task_template.sql +```sql +CREATE DATABASE IF NOT EXISTS {{ must_env("ENV") }}_mysqlbatch; +USE {{ must_env("ENV") }}_mysqlbatch; +DROP TABLE IF EXISTS {{ var("relation","hoge") }}; +CREATE TABLE {{ var("relation","hoge") }} ( + id INTEGER auto_increment, + name VARCHAR(191), + age INTEGER, + PRIMARY KEY (`id`), + UNIQUE INDEX `name` (`name`) +); +INSERT IGNORE INTO {{ var("relation","hoge") }}(name) VALUES(CONCAT(SUBSTRING(MD5(RAND()), 1, 40),'@example.com')); +INSERT IGNORE INTO {{ var("relation","hoge") }}(name) SELECT (CONCAT(SUBSTRING(MD5(RAND()), 1, 40),'@example.com')) FROM {{ var("relation","hoge") }}; +{%- for i in %} +INSERT IGNORE INTO {{ var("relation","hoge") }}(name, age) SELECT (CONCAT(SUBSTRING(MD5(RAND()), 1, 40),'@example.com')),RAND() FROM {{ var("relation","hoge") }}; +{%- endfor %} +SELECT * FROM {{ var("relation","hoge") }} WHERE age is NOT NULL LIMIT {{ must_var("limit") }}; +``` + +The following SQL is executed. + +```sql +CREATE DATABASE IF NOT EXISTS dev_mysqlbatch; +USE dev_mysqlbatch; +DROP TABLE IF EXISTS users; +CREATE TABLE users ( + id INTEGER auto_increment, + name VARCHAR(191), + age INTEGER, + PRIMARY KEY (`id`), + UNIQUE INDEX `name` (`name`) +); +INSERT IGNORE INTO users(name) VALUES(CONCAT(SUBSTRING(MD5(RAND()), 1, 40),'@example.com')); +INSERT IGNORE INTO users(name) SELECT (CONCAT(SUBSTRING(MD5(RAND()), 1, 40),'@example.com')) FROM users; +INSERT IGNORE INTO users(name, age) SELECT (CONCAT(SUBSTRING(MD5(RAND()), 1, 40),'@example.com')),RAND() FROM users; +INSERT IGNORE INTO users(name, age) SELECT (CONCAT(SUBSTRING(MD5(RAND()), 1, 40),'@example.com')),RAND() FROM users; +INSERT IGNORE INTO users(name, age) SELECT (CONCAT(SUBSTRING(MD5(RAND()), 1, 40),'@example.com')),RAND() FROM users; +SELECT * FROM hoge WHERE age is NOT NULL LIMIT 5; +``` + +When executing with Lambda, the following JSON can be specified as the payload. + +```json +{ + "file": "./task_template.sql", + "vars": { + "relation": "users", + "limit": 5 + } +} +``` + ## License see [LICENSE](https://github.com/mashiike/mysqlbatch/blob/master/LICENSE) file. diff --git a/cmd/mysqlbatch/main.go b/cmd/mysqlbatch/main.go index a5b6c66..c7ee18d 100644 --- a/cmd/mysqlbatch/main.go +++ b/cmd/mysqlbatch/main.go @@ -27,10 +27,12 @@ var ( func main() { conf := mysqlbatch.NewDefaultConfig() var ( + vars flagx.StringSlice versionFlag = flag.Bool("v", false, "show version info") silentFlag = flag.Bool("s", false, "no output to console") detailFlag = flag.Bool("d", false, "output deteil for execute sql, -s has priority") enableBootstrapFlag = flag.Bool("enable-lambda-bootstrap", false, "if run on AWS Lambda, running as lambda bootstrap") + dumpRenderedSQLFlag = flag.Bool("dump-rendered-sql", false, "dump rendered sql") ) flag.StringVar(&conf.DSN, "dsn", "", "dsn format as [mysql://]user:pass@tcp(host:port)/dbname (default \"\")") flag.StringVar(&conf.User, "u", "root", "username (default root)") @@ -43,6 +45,7 @@ func main() { flag.StringVar(&conf.Host, "host", "", "") flag.StringVar(&conf.Location, "location", "", "timezone of mysql database system") flag.StringVar(&conf.PasswordSSMParameterName, "password-ssm-parameter-name", "", "pasword ssm parameter name") + flag.Var(&vars, "var", "set variable (format: key=value)") flag.VisitAll(flagx.EnvToFlagWithPrefix("MYSQLBATCH_")) flag.Parse() @@ -52,6 +55,9 @@ func main() { fmt.Printf("build date: %s\n", BuildDate) return } + if *dumpRenderedSQLFlag { + mysqlbatch.DefaultSQLDumper = os.Stderr + } conf.Database = os.Getenv("MYSQLBATCH_DATABASE") if flag.NArg() == 1 { conf.Database = flag.Arg(0) @@ -83,7 +89,16 @@ func main() { }) } } - if err := executer.ExecuteContext(ctx, os.Stdin); err != nil { + varsMap := make(map[string]string) + for _, v := range vars { + kv := strings.SplitN(v, "=", 2) + if len(kv) != 2 { + log.Printf("invalid var format: %s\n", v) + os.Exit(1) + } + varsMap[kv[0]] = kv[1] + } + if err := executer.ExecuteContext(ctx, os.Stdin, varsMap); err != nil { log.Println(err) os.Exit(1) } @@ -97,15 +112,16 @@ type handler struct { } type payload struct { - SQL string `json:"sql,omitempty"` - File string `json:"file,omitempty"` - DSN *string `json:"dsn,omitempty"` - User *string `json:"user,omitempty"` - Port *int `json:"port,omitempty"` - Host *string `json:"host,omitempty"` - Database *string `json:"database,omitempty"` - Location *string `json:"Location,omitempty"` - PasswordSSMParameterName *string `json:"password_ssm_parameter_name,omitempty"` + SQL string `json:"sql,omitempty"` + File string `json:"file,omitempty"` + DSN *string `json:"dsn,omitempty"` + User *string `json:"user,omitempty"` + Port *int `json:"port,omitempty"` + Host *string `json:"host,omitempty"` + Database *string `json:"database,omitempty"` + Location *string `json:"Location,omitempty"` + PasswordSSMParameterName *string `json:"password_ssm_parameter_name,omitempty"` + Vars map[string]string `json:"vars,omitempty"` } type response struct { @@ -173,7 +189,7 @@ func (h *handler) Invoke(ctx context.Context, p *payload) (*response, error) { Query: query, }) }) - if err := executer.ExecuteContext(ctx, query); err != nil { + if err := executer.ExecuteContext(ctx, query, p.Vars); err != nil { return nil, err } r := &response{ diff --git a/executer.go b/executer.go index a8daa95..d667690 100644 --- a/executer.go +++ b/executer.go @@ -2,18 +2,24 @@ package mysqlbatch import ( "bufio" + "bytes" "context" "database/sql" + "fmt" "io" + "os" "strings" "sync" "time" + "github.com/flosch/pongo2/v6" _ "github.com/go-sql-driver/mysql" "github.com/olekukonko/tablewriter" "github.com/pkg/errors" ) +var DefaultSQLDumper io.Writer = io.Discard + // Executer queries the DB. There is no parallelism type Executer struct { mu sync.Mutex @@ -65,15 +71,15 @@ func (e *Executer) Close() error { } // Execute SQL -func (e *Executer) Execute(queryReader io.Reader) error { - return e.ExecuteContext(context.Background(), queryReader) +func (e *Executer) Execute(queryReader io.Reader, vars map[string]string) error { + return e.ExecuteContext(context.Background(), queryReader, vars) } // ExecuteContext SQL execute with context.Context -func (e *Executer) ExecuteContext(ctx context.Context, queryReader io.Reader) error { +func (e *Executer) ExecuteContext(ctx context.Context, queryReader io.Reader, vars map[string]string) error { e.mu.Lock() defer e.mu.Unlock() - if err := e.executeContext(ctx, queryReader); err != nil { + if err := e.executeContext(ctx, queryReader, vars); err != nil { return err } return e.updateLastExecuteTime(ctx) @@ -87,8 +93,21 @@ func (e *Executer) updateLastExecuteTime(ctx context.Context) error { return errors.Wrap(row.Scan(&e.lastExecuteTime), "scan db time") } -func (e *Executer) executeContext(ctx context.Context, queryReader io.Reader) error { - scanner := NewQueryScanner(queryReader) +func (e *Executer) executeContext(ctx context.Context, queryReader io.Reader, vars map[string]string) error { + bs, err := io.ReadAll(queryReader) + if err != nil { + return err + } + tpl, err := pongo2.FromBytes(bs) + if err != nil { + return errors.Wrap(err, "parse query template failed") + } + var buf bytes.Buffer + if err := tpl.ExecuteWriter(e.newPongo2Ctx(ctx, vars), &buf); err != nil { + return errors.Wrap(err, "execute query template failed") + } + reader := io.TeeReader(&buf, DefaultSQLDumper) + scanner := NewQueryScanner(reader) for scanner.Scan() { select { case <-ctx.Done(): @@ -205,6 +224,70 @@ func (e *Executer) SetTableSelectHook(hook func(query, table string)) { } } +func (e *Executer) newPongo2Ctx(ctx context.Context, vars map[string]string) pongo2.Context { + pongo2Ctx := pongo2.Context{ + "var": func(key string, defaultValue string) string { + if v, ok := vars[key]; ok { + return v + } + return defaultValue + }, + "must_var": func(key string) (string, error) { + if v, ok := vars[key]; ok { + return v, nil + } + return "", errors.Errorf("variable %s is not defined", key) + }, + "env": func(key string, defaultValue string) string { + if v, ok := os.LookupEnv(key); ok { + return v + } + return defaultValue + }, + "must_env": func(key string) (string, error) { + if v, ok := os.LookupEnv(key); ok { + return v, nil + } + return "", errors.Errorf("environment variable %s is not defined", key) + }, + "range": func(args ...int) ([]int, error) { + if len(args) == 0 { + return nil, errors.New("range requires at least 1 argument, got 0") + } + if len(args) > 3 { + return nil, fmt.Errorf("range requires at most 3 arguments, got %d", len(args)) + } + var start, end, step int + switch len(args) { + case 1: + start = 0 + end = args[0] + step = 1 + case 2: + start = args[0] + end = args[1] + step = 1 + case 3: + start = args[0] + end = args[1] + step = args[2] + } + if step == 0 { + return nil, errors.New("range requires step != 0") + } + if (step > 0 && start > end) || (step < 0 && start < end) { + return nil, errors.New("range requires start <= end when step > 0, or start >= end when step < 0") + } + var result []int + for i := start; i < end; i += step { + result = append(result, i) + } + return result, nil + }, + } + return pongo2Ctx +} + // QueryScanner separate string by ; and delete newline type QueryScanner struct { *bufio.Scanner diff --git a/executer_test.go b/executer_test.go index 6c78eae..8ce3918 100644 --- a/executer_test.go +++ b/executer_test.go @@ -6,6 +6,7 @@ import ( _ "embed" "fmt" "log" + "os" "strings" "sync/atomic" "testing" @@ -94,6 +95,9 @@ func diffStr(a, b string) (string, bool) { //go:embed testdata/test.sql var testSQL []byte +//go:embed testdata/test_template.sql +var testTemplateSQL []byte + func TestExecuterExecute(t *testing.T) { conf := mysqlbatch.NewDefaultConfig() conf.Password = "mysqlbatch" @@ -112,7 +116,37 @@ func TestExecuterExecute(t *testing.T) { require.Equal(t, `SELECT * FROM users WHERE age is NOT NULL LIMIT 5`, query) require.Equal(t, 5, len(rows)) }) - err = e.Execute(bytes.NewReader(testSQL)) + err = e.Execute(bytes.NewReader(testSQL), nil) + require.NoError(t, err) + log.Println("LastExecuteTime:", e.LastExecuteTime()) + require.InDelta(t, time.Since(e.LastExecuteTime()), 0, float64(5*time.Minute)) + require.EqualValues(t, 1, count) +} + +func TestExecuterExecute__WithVars(t *testing.T) { + os.Setenv("ENV", "test") + mysqlbatch.DefaultSQLDumper = os.Stderr + conf := mysqlbatch.NewDefaultConfig() + conf.Password = "mysqlbatch" + conf.Location = "Asia/Tokyo" + e, err := mysqlbatch.New(context.Background(), conf) + require.NoError(t, err) + defer e.Close() + + e.SetTimeCheckQuery("SELECT NOW()") + e.SetTableSelectHook(func(query, table string) { + t.Log(query + "\n" + table + "\n") + }) + var count int32 + e.SetSelectHook(func(query string, columns []string, rows [][]string) { + atomic.AddInt32(&count, 1) + require.Equal(t, `SELECT * FROM users WHERE age is NOT NULL LIMIT 5`, query) + require.Equal(t, 5, len(rows)) + }) + err = e.Execute(bytes.NewReader(testTemplateSQL), map[string]string{ + "relation": "users", + "limit": "5", + }) require.NoError(t, err) log.Println("LastExecuteTime:", e.LastExecuteTime()) require.InDelta(t, time.Since(e.LastExecuteTime()), 0, float64(5*time.Minute)) diff --git a/go.mod b/go.mod index 01dfb9a..8c0ae30 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/aws/aws-sdk-go-v2/config v1.18.25 github.com/aws/aws-sdk-go-v2/service/ssm v1.36.0 github.com/aws/smithy-go v1.13.5 + github.com/flosch/pongo2/v6 v6.0.0 github.com/go-sql-driver/mysql v1.7.1 github.com/ken39arg/go-flagx v0.0.0-20220608183922-7cf7c6c0093c github.com/olekukonko/tablewriter v0.0.5 diff --git a/go.sum b/go.sum index 2d787c1..756876e 100644 --- a/go.sum +++ b/go.sum @@ -34,6 +34,8 @@ github.com/aws/smithy-go v1.13.5/go.mod h1:Tg+OJXh4MB2R/uN61Ko2f6hTZwB/ZYGOtib8J github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/flosch/pongo2/v6 v6.0.0 h1:lsGru8IAzHgIAw6H2m4PCyleO58I40ow6apih0WprMU= +github.com/flosch/pongo2/v6 v6.0.0/go.mod h1:CuDpFm47R0uGGE7z13/tTlt1Y6zdxvr2RLT5LJhsHEU= github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= @@ -44,8 +46,8 @@ github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGw github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/ken39arg/go-flagx v0.0.0-20220608183922-7cf7c6c0093c h1:jrKp5SY9Qt8lQmorJAksSYOIexZdkp7EREJgx4mX9XA= github.com/ken39arg/go-flagx v0.0.0-20220608183922-7cf7c6c0093c/go.mod h1:DNbx2/OnOT5GtlYTUF2xr4GZSunGDP1Wk0WO3mmaKz0= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= @@ -74,8 +76,8 @@ github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= diff --git a/testdata/test_template.sql b/testdata/test_template.sql new file mode 100644 index 0000000..5526cc1 --- /dev/null +++ b/testdata/test_template.sql @@ -0,0 +1,17 @@ +CREATE DATABASE IF NOT EXISTS {{ must_env("ENV") }}_mysqlbatch; +USE {{ must_env("ENV") }}_mysqlbatch; +DROP TABLE IF EXISTS {{ var("relation","hoge") }}; +CREATE TABLE {{ var("relation","hoge") }} ( + id INTEGER auto_increment, + name VARCHAR(191), + age INTEGER, + PRIMARY KEY (`id`), + UNIQUE INDEX `name` (`name`) +); +INSERT IGNORE INTO {{ var("relation","hoge") }}(name) VALUES(CONCAT(SUBSTRING(MD5(RAND()), 1, 40),'@example.com')); +INSERT IGNORE INTO {{ var("relation","hoge") }}(name) SELECT (CONCAT(SUBSTRING(MD5(RAND()), 1, 40),'@example.com')) FROM {{ var("relation","hoge") }}; +{%- for i in range(3) %} +INSERT IGNORE INTO {{ var("relation","hoge") }}(name, age) SELECT (CONCAT(SUBSTRING(MD5(RAND()), 1, 40),'@example.com')),RAND() FROM {{ var("relation","hoge") }}; +{%- endfor %} +SELECT * FROM {{ var("relation","hoge") }} WHERE age is NOT NULL LIMIT {{ must_var("limit") }}; +