Skip to content

Commit

Permalink
Merge pull request #29 from mashiike/feature/template-sql
Browse files Browse the repository at this point in the history
allow Go template SQL
  • Loading branch information
mashiike committed May 29, 2023
2 parents 7159a38 + 36dc420 commit 9af7f5e
Show file tree
Hide file tree
Showing 7 changed files with 233 additions and 20 deletions.
60 changes: 60 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,66 @@ output
}
```

## Advanced Usage: Template SQL

The SQL to be executed is rendered by pongo2, a Django-syntax like template-engine, once.
Therefore, the SQL to be specified can use template notation.
In CLI, `--var key=value` flag, in Lambda, `vars` key can be specified as a string hash, and variables can be passed at runtime to dynamically generate SQL by template notation.
For example, if you specify `--var relation=users --var limit=5` and execute the following template SQL, (environment variable `ENV=dev` is set)

task_template.sql
```sql
CREATE DATABASE IF NOT EXISTS {{ must_env("ENV") }}_mysqlbatch;
USE {{ must_env("ENV") }}_mysqlbatch;
DROP TABLE IF EXISTS {{ var("relation","hoge") }};
CREATE TABLE {{ var("relation","hoge") }} (
id INTEGER auto_increment,
name VARCHAR(191),
age INTEGER,
PRIMARY KEY (`id`),
UNIQUE INDEX `name` (`name`)
);
INSERT IGNORE INTO {{ var("relation","hoge") }}(name) VALUES(CONCAT(SUBSTRING(MD5(RAND()), 1, 40),'@example.com'));
INSERT IGNORE INTO {{ var("relation","hoge") }}(name) SELECT (CONCAT(SUBSTRING(MD5(RAND()), 1, 40),'@example.com')) FROM {{ var("relation","hoge") }};
{%- for i in %}
INSERT IGNORE INTO {{ var("relation","hoge") }}(name, age) SELECT (CONCAT(SUBSTRING(MD5(RAND()), 1, 40),'@example.com')),RAND() FROM {{ var("relation","hoge") }};
{%- endfor %}
SELECT * FROM {{ var("relation","hoge") }} WHERE age is NOT NULL LIMIT {{ must_var("limit") }};
```

The following SQL is executed.

```sql
CREATE DATABASE IF NOT EXISTS dev_mysqlbatch;
USE dev_mysqlbatch;
DROP TABLE IF EXISTS users;
CREATE TABLE users (
id INTEGER auto_increment,
name VARCHAR(191),
age INTEGER,
PRIMARY KEY (`id`),
UNIQUE INDEX `name` (`name`)
);
INSERT IGNORE INTO users(name) VALUES(CONCAT(SUBSTRING(MD5(RAND()), 1, 40),'@example.com'));
INSERT IGNORE INTO users(name) SELECT (CONCAT(SUBSTRING(MD5(RAND()), 1, 40),'@example.com')) FROM users;
INSERT IGNORE INTO users(name, age) SELECT (CONCAT(SUBSTRING(MD5(RAND()), 1, 40),'@example.com')),RAND() FROM users;
INSERT IGNORE INTO users(name, age) SELECT (CONCAT(SUBSTRING(MD5(RAND()), 1, 40),'@example.com')),RAND() FROM users;
INSERT IGNORE INTO users(name, age) SELECT (CONCAT(SUBSTRING(MD5(RAND()), 1, 40),'@example.com')),RAND() FROM users;
SELECT * FROM hoge WHERE age is NOT NULL LIMIT 5;
```

When executing with Lambda, the following JSON can be specified as the payload.

```json
{
"file": "./task_template.sql",
"vars": {
"relation": "users",
"limit": 5
}
}
```

## License

see [LICENSE](https://github.com/mashiike/mysqlbatch/blob/master/LICENSE) file.
Expand Down
38 changes: 27 additions & 11 deletions cmd/mysqlbatch/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@ var (
func main() {
conf := mysqlbatch.NewDefaultConfig()
var (
vars flagx.StringSlice
versionFlag = flag.Bool("v", false, "show version info")
silentFlag = flag.Bool("s", false, "no output to console")
detailFlag = flag.Bool("d", false, "output deteil for execute sql, -s has priority")
enableBootstrapFlag = flag.Bool("enable-lambda-bootstrap", false, "if run on AWS Lambda, running as lambda bootstrap")
dumpRenderedSQLFlag = flag.Bool("dump-rendered-sql", false, "dump rendered sql")
)
flag.StringVar(&conf.DSN, "dsn", "", "dsn format as [mysql://]user:pass@tcp(host:port)/dbname (default \"\")")
flag.StringVar(&conf.User, "u", "root", "username (default root)")
Expand All @@ -43,6 +45,7 @@ func main() {
flag.StringVar(&conf.Host, "host", "", "")
flag.StringVar(&conf.Location, "location", "", "timezone of mysql database system")
flag.StringVar(&conf.PasswordSSMParameterName, "password-ssm-parameter-name", "", "pasword ssm parameter name")
flag.Var(&vars, "var", "set variable (format: key=value)")
flag.VisitAll(flagx.EnvToFlagWithPrefix("MYSQLBATCH_"))
flag.Parse()

Expand All @@ -52,6 +55,9 @@ func main() {
fmt.Printf("build date: %s\n", BuildDate)
return
}
if *dumpRenderedSQLFlag {
mysqlbatch.DefaultSQLDumper = os.Stderr
}
conf.Database = os.Getenv("MYSQLBATCH_DATABASE")
if flag.NArg() == 1 {
conf.Database = flag.Arg(0)
Expand Down Expand Up @@ -83,7 +89,16 @@ func main() {
})
}
}
if err := executer.ExecuteContext(ctx, os.Stdin); err != nil {
varsMap := make(map[string]string)
for _, v := range vars {
kv := strings.SplitN(v, "=", 2)
if len(kv) != 2 {
log.Printf("invalid var format: %s\n", v)
os.Exit(1)
}
varsMap[kv[0]] = kv[1]
}
if err := executer.ExecuteContext(ctx, os.Stdin, varsMap); err != nil {
log.Println(err)
os.Exit(1)
}
Expand All @@ -97,15 +112,16 @@ type handler struct {
}

type payload struct {
SQL string `json:"sql,omitempty"`
File string `json:"file,omitempty"`
DSN *string `json:"dsn,omitempty"`
User *string `json:"user,omitempty"`
Port *int `json:"port,omitempty"`
Host *string `json:"host,omitempty"`
Database *string `json:"database,omitempty"`
Location *string `json:"Location,omitempty"`
PasswordSSMParameterName *string `json:"password_ssm_parameter_name,omitempty"`
SQL string `json:"sql,omitempty"`
File string `json:"file,omitempty"`
DSN *string `json:"dsn,omitempty"`
User *string `json:"user,omitempty"`
Port *int `json:"port,omitempty"`
Host *string `json:"host,omitempty"`
Database *string `json:"database,omitempty"`
Location *string `json:"Location,omitempty"`
PasswordSSMParameterName *string `json:"password_ssm_parameter_name,omitempty"`
Vars map[string]string `json:"vars,omitempty"`
}

type response struct {
Expand Down Expand Up @@ -173,7 +189,7 @@ func (h *handler) Invoke(ctx context.Context, p *payload) (*response, error) {
Query: query,
})
})
if err := executer.ExecuteContext(ctx, query); err != nil {
if err := executer.ExecuteContext(ctx, query, p.Vars); err != nil {
return nil, err
}
r := &response{
Expand Down
95 changes: 89 additions & 6 deletions executer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,24 @@ package mysqlbatch

import (
"bufio"
"bytes"
"context"
"database/sql"
"fmt"
"io"
"os"
"strings"
"sync"
"time"

"github.com/flosch/pongo2/v6"
_ "github.com/go-sql-driver/mysql"
"github.com/olekukonko/tablewriter"
"github.com/pkg/errors"
)

var DefaultSQLDumper io.Writer = io.Discard

// Executer queries the DB. There is no parallelism
type Executer struct {
mu sync.Mutex
Expand Down Expand Up @@ -65,15 +71,15 @@ func (e *Executer) Close() error {
}

// Execute SQL
func (e *Executer) Execute(queryReader io.Reader) error {
return e.ExecuteContext(context.Background(), queryReader)
func (e *Executer) Execute(queryReader io.Reader, vars map[string]string) error {
return e.ExecuteContext(context.Background(), queryReader, vars)
}

// ExecuteContext SQL execute with context.Context
func (e *Executer) ExecuteContext(ctx context.Context, queryReader io.Reader) error {
func (e *Executer) ExecuteContext(ctx context.Context, queryReader io.Reader, vars map[string]string) error {
e.mu.Lock()
defer e.mu.Unlock()
if err := e.executeContext(ctx, queryReader); err != nil {
if err := e.executeContext(ctx, queryReader, vars); err != nil {
return err
}
return e.updateLastExecuteTime(ctx)
Expand All @@ -87,8 +93,21 @@ func (e *Executer) updateLastExecuteTime(ctx context.Context) error {
return errors.Wrap(row.Scan(&e.lastExecuteTime), "scan db time")
}

func (e *Executer) executeContext(ctx context.Context, queryReader io.Reader) error {
scanner := NewQueryScanner(queryReader)
func (e *Executer) executeContext(ctx context.Context, queryReader io.Reader, vars map[string]string) error {
bs, err := io.ReadAll(queryReader)
if err != nil {
return err
}
tpl, err := pongo2.FromBytes(bs)
if err != nil {
return errors.Wrap(err, "parse query template failed")
}
var buf bytes.Buffer
if err := tpl.ExecuteWriter(e.newPongo2Ctx(ctx, vars), &buf); err != nil {
return errors.Wrap(err, "execute query template failed")
}
reader := io.TeeReader(&buf, DefaultSQLDumper)
scanner := NewQueryScanner(reader)
for scanner.Scan() {
select {
case <-ctx.Done():
Expand Down Expand Up @@ -205,6 +224,70 @@ func (e *Executer) SetTableSelectHook(hook func(query, table string)) {
}
}

func (e *Executer) newPongo2Ctx(ctx context.Context, vars map[string]string) pongo2.Context {
pongo2Ctx := pongo2.Context{
"var": func(key string, defaultValue string) string {
if v, ok := vars[key]; ok {
return v
}
return defaultValue
},
"must_var": func(key string) (string, error) {
if v, ok := vars[key]; ok {
return v, nil
}
return "", errors.Errorf("variable %s is not defined", key)
},
"env": func(key string, defaultValue string) string {
if v, ok := os.LookupEnv(key); ok {
return v
}
return defaultValue
},
"must_env": func(key string) (string, error) {
if v, ok := os.LookupEnv(key); ok {
return v, nil
}
return "", errors.Errorf("environment variable %s is not defined", key)
},
"range": func(args ...int) ([]int, error) {
if len(args) == 0 {
return nil, errors.New("range requires at least 1 argument, got 0")
}
if len(args) > 3 {
return nil, fmt.Errorf("range requires at most 3 arguments, got %d", len(args))
}
var start, end, step int
switch len(args) {
case 1:
start = 0
end = args[0]
step = 1
case 2:
start = args[0]
end = args[1]
step = 1
case 3:
start = args[0]
end = args[1]
step = args[2]
}
if step == 0 {
return nil, errors.New("range requires step != 0")
}
if (step > 0 && start > end) || (step < 0 && start < end) {
return nil, errors.New("range requires start <= end when step > 0, or start >= end when step < 0")
}
var result []int
for i := start; i < end; i += step {
result = append(result, i)
}
return result, nil
},
}
return pongo2Ctx
}

// QueryScanner separate string by ; and delete newline
type QueryScanner struct {
*bufio.Scanner
Expand Down
36 changes: 35 additions & 1 deletion executer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
_ "embed"
"fmt"
"log"
"os"
"strings"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -94,6 +95,9 @@ func diffStr(a, b string) (string, bool) {
//go:embed testdata/test.sql
var testSQL []byte

//go:embed testdata/test_template.sql
var testTemplateSQL []byte

func TestExecuterExecute(t *testing.T) {
conf := mysqlbatch.NewDefaultConfig()
conf.Password = "mysqlbatch"
Expand All @@ -112,7 +116,37 @@ func TestExecuterExecute(t *testing.T) {
require.Equal(t, `SELECT * FROM users WHERE age is NOT NULL LIMIT 5`, query)
require.Equal(t, 5, len(rows))
})
err = e.Execute(bytes.NewReader(testSQL))
err = e.Execute(bytes.NewReader(testSQL), nil)
require.NoError(t, err)
log.Println("LastExecuteTime:", e.LastExecuteTime())
require.InDelta(t, time.Since(e.LastExecuteTime()), 0, float64(5*time.Minute))
require.EqualValues(t, 1, count)
}

func TestExecuterExecute__WithVars(t *testing.T) {
os.Setenv("ENV", "test")
mysqlbatch.DefaultSQLDumper = os.Stderr
conf := mysqlbatch.NewDefaultConfig()
conf.Password = "mysqlbatch"
conf.Location = "Asia/Tokyo"
e, err := mysqlbatch.New(context.Background(), conf)
require.NoError(t, err)
defer e.Close()

e.SetTimeCheckQuery("SELECT NOW()")
e.SetTableSelectHook(func(query, table string) {
t.Log(query + "\n" + table + "\n")
})
var count int32
e.SetSelectHook(func(query string, columns []string, rows [][]string) {
atomic.AddInt32(&count, 1)
require.Equal(t, `SELECT * FROM users WHERE age is NOT NULL LIMIT 5`, query)
require.Equal(t, 5, len(rows))
})
err = e.Execute(bytes.NewReader(testTemplateSQL), map[string]string{
"relation": "users",
"limit": "5",
})
require.NoError(t, err)
log.Println("LastExecuteTime:", e.LastExecuteTime())
require.InDelta(t, time.Since(e.LastExecuteTime()), 0, float64(5*time.Minute))
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ require (
github.com/aws/aws-sdk-go-v2/config v1.18.25
github.com/aws/aws-sdk-go-v2/service/ssm v1.36.4
github.com/aws/smithy-go v1.13.5
github.com/flosch/pongo2/v6 v6.0.0
github.com/go-sql-driver/mysql v1.7.1
github.com/ken39arg/go-flagx v0.0.0-20220608183922-7cf7c6c0093c
github.com/olekukonko/tablewriter v0.0.5
Expand Down
6 changes: 4 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ github.com/aws/smithy-go v1.13.5/go.mod h1:Tg+OJXh4MB2R/uN61Ko2f6hTZwB/ZYGOtib8J
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/flosch/pongo2/v6 v6.0.0 h1:lsGru8IAzHgIAw6H2m4PCyleO58I40ow6apih0WprMU=
github.com/flosch/pongo2/v6 v6.0.0/go.mod h1:CuDpFm47R0uGGE7z13/tTlt1Y6zdxvr2RLT5LJhsHEU=
github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI=
github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
Expand All @@ -41,8 +43,8 @@ github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGw
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
github.com/ken39arg/go-flagx v0.0.0-20220608183922-7cf7c6c0093c h1:jrKp5SY9Qt8lQmorJAksSYOIexZdkp7EREJgx4mX9XA=
github.com/ken39arg/go-flagx v0.0.0-20220608183922-7cf7c6c0093c/go.mod h1:DNbx2/OnOT5GtlYTUF2xr4GZSunGDP1Wk0WO3mmaKz0=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
Expand Down Expand Up @@ -71,8 +73,8 @@ github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
Expand Down
Loading

0 comments on commit 9af7f5e

Please sign in to comment.