Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow Django like template SQL #29

Merged
merged 4 commits into from
May 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,66 @@ output
}
```

## Advanced Usage: Template SQL

The SQL to be executed is rendered by pongo2, a Django-syntax like template-engine, once.
Therefore, the SQL to be specified can use template notation.
In CLI, `--var key=value` flag, in Lambda, `vars` key can be specified as a string hash, and variables can be passed at runtime to dynamically generate SQL by template notation.
For example, if you specify `--var relation=users --var limit=5` and execute the following template SQL, (environment variable `ENV=dev` is set)

task_template.sql
```sql
CREATE DATABASE IF NOT EXISTS {{ must_env("ENV") }}_mysqlbatch;
USE {{ must_env("ENV") }}_mysqlbatch;
DROP TABLE IF EXISTS {{ var("relation","hoge") }};
CREATE TABLE {{ var("relation","hoge") }} (
id INTEGER auto_increment,
name VARCHAR(191),
age INTEGER,
PRIMARY KEY (`id`),
UNIQUE INDEX `name` (`name`)
);
INSERT IGNORE INTO {{ var("relation","hoge") }}(name) VALUES(CONCAT(SUBSTRING(MD5(RAND()), 1, 40),'@example.com'));
INSERT IGNORE INTO {{ var("relation","hoge") }}(name) SELECT (CONCAT(SUBSTRING(MD5(RAND()), 1, 40),'@example.com')) FROM {{ var("relation","hoge") }};
{%- for i in %}
INSERT IGNORE INTO {{ var("relation","hoge") }}(name, age) SELECT (CONCAT(SUBSTRING(MD5(RAND()), 1, 40),'@example.com')),RAND() FROM {{ var("relation","hoge") }};
{%- endfor %}
SELECT * FROM {{ var("relation","hoge") }} WHERE age is NOT NULL LIMIT {{ must_var("limit") }};
```

The following SQL is executed.

```sql
CREATE DATABASE IF NOT EXISTS dev_mysqlbatch;
USE dev_mysqlbatch;
DROP TABLE IF EXISTS users;
CREATE TABLE users (
id INTEGER auto_increment,
name VARCHAR(191),
age INTEGER,
PRIMARY KEY (`id`),
UNIQUE INDEX `name` (`name`)
);
INSERT IGNORE INTO users(name) VALUES(CONCAT(SUBSTRING(MD5(RAND()), 1, 40),'@example.com'));
INSERT IGNORE INTO users(name) SELECT (CONCAT(SUBSTRING(MD5(RAND()), 1, 40),'@example.com')) FROM users;
INSERT IGNORE INTO users(name, age) SELECT (CONCAT(SUBSTRING(MD5(RAND()), 1, 40),'@example.com')),RAND() FROM users;
INSERT IGNORE INTO users(name, age) SELECT (CONCAT(SUBSTRING(MD5(RAND()), 1, 40),'@example.com')),RAND() FROM users;
INSERT IGNORE INTO users(name, age) SELECT (CONCAT(SUBSTRING(MD5(RAND()), 1, 40),'@example.com')),RAND() FROM users;
SELECT * FROM hoge WHERE age is NOT NULL LIMIT 5;
```

When executing with Lambda, the following JSON can be specified as the payload.

```json
{
"file": "./task_template.sql",
"vars": {
"relation": "users",
"limit": 5
}
}
```

## License

see [LICENSE](https://github.com/mashiike/mysqlbatch/blob/master/LICENSE) file.
Expand Down
38 changes: 27 additions & 11 deletions cmd/mysqlbatch/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@ var (
func main() {
conf := mysqlbatch.NewDefaultConfig()
var (
vars flagx.StringSlice
versionFlag = flag.Bool("v", false, "show version info")
silentFlag = flag.Bool("s", false, "no output to console")
detailFlag = flag.Bool("d", false, "output deteil for execute sql, -s has priority")
enableBootstrapFlag = flag.Bool("enable-lambda-bootstrap", false, "if run on AWS Lambda, running as lambda bootstrap")
dumpRenderedSQLFlag = flag.Bool("dump-rendered-sql", false, "dump rendered sql")
)
flag.StringVar(&conf.DSN, "dsn", "", "dsn format as [mysql://]user:pass@tcp(host:port)/dbname (default \"\")")
flag.StringVar(&conf.User, "u", "root", "username (default root)")
Expand All @@ -43,6 +45,7 @@ func main() {
flag.StringVar(&conf.Host, "host", "", "")
flag.StringVar(&conf.Location, "location", "", "timezone of mysql database system")
flag.StringVar(&conf.PasswordSSMParameterName, "password-ssm-parameter-name", "", "pasword ssm parameter name")
flag.Var(&vars, "var", "set variable (format: key=value)")
flag.VisitAll(flagx.EnvToFlagWithPrefix("MYSQLBATCH_"))
flag.Parse()

Expand All @@ -52,6 +55,9 @@ func main() {
fmt.Printf("build date: %s\n", BuildDate)
return
}
if *dumpRenderedSQLFlag {
mysqlbatch.DefaultSQLDumper = os.Stderr
}
conf.Database = os.Getenv("MYSQLBATCH_DATABASE")
if flag.NArg() == 1 {
conf.Database = flag.Arg(0)
Expand Down Expand Up @@ -83,7 +89,16 @@ func main() {
})
}
}
if err := executer.ExecuteContext(ctx, os.Stdin); err != nil {
varsMap := make(map[string]string)
for _, v := range vars {
kv := strings.SplitN(v, "=", 2)
if len(kv) != 2 {
log.Printf("invalid var format: %s\n", v)
os.Exit(1)
}
varsMap[kv[0]] = kv[1]
}
if err := executer.ExecuteContext(ctx, os.Stdin, varsMap); err != nil {
log.Println(err)
os.Exit(1)
}
Expand All @@ -97,15 +112,16 @@ type handler struct {
}

type payload struct {
SQL string `json:"sql,omitempty"`
File string `json:"file,omitempty"`
DSN *string `json:"dsn,omitempty"`
User *string `json:"user,omitempty"`
Port *int `json:"port,omitempty"`
Host *string `json:"host,omitempty"`
Database *string `json:"database,omitempty"`
Location *string `json:"Location,omitempty"`
PasswordSSMParameterName *string `json:"password_ssm_parameter_name,omitempty"`
SQL string `json:"sql,omitempty"`
File string `json:"file,omitempty"`
DSN *string `json:"dsn,omitempty"`
User *string `json:"user,omitempty"`
Port *int `json:"port,omitempty"`
Host *string `json:"host,omitempty"`
Database *string `json:"database,omitempty"`
Location *string `json:"Location,omitempty"`
PasswordSSMParameterName *string `json:"password_ssm_parameter_name,omitempty"`
Vars map[string]string `json:"vars,omitempty"`
}

type response struct {
Expand Down Expand Up @@ -173,7 +189,7 @@ func (h *handler) Invoke(ctx context.Context, p *payload) (*response, error) {
Query: query,
})
})
if err := executer.ExecuteContext(ctx, query); err != nil {
if err := executer.ExecuteContext(ctx, query, p.Vars); err != nil {
return nil, err
}
r := &response{
Expand Down
95 changes: 89 additions & 6 deletions executer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,24 @@ package mysqlbatch

import (
"bufio"
"bytes"
"context"
"database/sql"
"fmt"
"io"
"os"
"strings"
"sync"
"time"

"github.com/flosch/pongo2/v6"
_ "github.com/go-sql-driver/mysql"
"github.com/olekukonko/tablewriter"
"github.com/pkg/errors"
)

var DefaultSQLDumper io.Writer = io.Discard

// Executer queries the DB. There is no parallelism
type Executer struct {
mu sync.Mutex
Expand Down Expand Up @@ -65,15 +71,15 @@ func (e *Executer) Close() error {
}

// Execute SQL
func (e *Executer) Execute(queryReader io.Reader) error {
return e.ExecuteContext(context.Background(), queryReader)
func (e *Executer) Execute(queryReader io.Reader, vars map[string]string) error {
return e.ExecuteContext(context.Background(), queryReader, vars)
}

// ExecuteContext SQL execute with context.Context
func (e *Executer) ExecuteContext(ctx context.Context, queryReader io.Reader) error {
func (e *Executer) ExecuteContext(ctx context.Context, queryReader io.Reader, vars map[string]string) error {
e.mu.Lock()
defer e.mu.Unlock()
if err := e.executeContext(ctx, queryReader); err != nil {
if err := e.executeContext(ctx, queryReader, vars); err != nil {
return err
}
return e.updateLastExecuteTime(ctx)
Expand All @@ -87,8 +93,21 @@ func (e *Executer) updateLastExecuteTime(ctx context.Context) error {
return errors.Wrap(row.Scan(&e.lastExecuteTime), "scan db time")
}

func (e *Executer) executeContext(ctx context.Context, queryReader io.Reader) error {
scanner := NewQueryScanner(queryReader)
func (e *Executer) executeContext(ctx context.Context, queryReader io.Reader, vars map[string]string) error {
bs, err := io.ReadAll(queryReader)
if err != nil {
return err
}
tpl, err := pongo2.FromBytes(bs)
if err != nil {
return errors.Wrap(err, "parse query template failed")
}
var buf bytes.Buffer
if err := tpl.ExecuteWriter(e.newPongo2Ctx(ctx, vars), &buf); err != nil {
return errors.Wrap(err, "execute query template failed")
}
reader := io.TeeReader(&buf, DefaultSQLDumper)
scanner := NewQueryScanner(reader)
for scanner.Scan() {
select {
case <-ctx.Done():
Expand Down Expand Up @@ -205,6 +224,70 @@ func (e *Executer) SetTableSelectHook(hook func(query, table string)) {
}
}

func (e *Executer) newPongo2Ctx(ctx context.Context, vars map[string]string) pongo2.Context {
pongo2Ctx := pongo2.Context{
"var": func(key string, defaultValue string) string {
if v, ok := vars[key]; ok {
return v
}
return defaultValue
},
"must_var": func(key string) (string, error) {
if v, ok := vars[key]; ok {
return v, nil
}
return "", errors.Errorf("variable %s is not defined", key)
},
"env": func(key string, defaultValue string) string {
if v, ok := os.LookupEnv(key); ok {
return v
}
return defaultValue
},
"must_env": func(key string) (string, error) {
if v, ok := os.LookupEnv(key); ok {
return v, nil
}
return "", errors.Errorf("environment variable %s is not defined", key)
},
"range": func(args ...int) ([]int, error) {
if len(args) == 0 {
return nil, errors.New("range requires at least 1 argument, got 0")
}
if len(args) > 3 {
return nil, fmt.Errorf("range requires at most 3 arguments, got %d", len(args))
}
var start, end, step int
switch len(args) {
case 1:
start = 0
end = args[0]
step = 1
case 2:
start = args[0]
end = args[1]
step = 1
case 3:
start = args[0]
end = args[1]
step = args[2]
}
if step == 0 {
return nil, errors.New("range requires step != 0")
}
if (step > 0 && start > end) || (step < 0 && start < end) {
return nil, errors.New("range requires start <= end when step > 0, or start >= end when step < 0")
}
var result []int
for i := start; i < end; i += step {
result = append(result, i)
}
return result, nil
},
}
return pongo2Ctx
}

// QueryScanner separate string by ; and delete newline
type QueryScanner struct {
*bufio.Scanner
Expand Down
36 changes: 35 additions & 1 deletion executer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
_ "embed"
"fmt"
"log"
"os"
"strings"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -94,6 +95,9 @@ func diffStr(a, b string) (string, bool) {
//go:embed testdata/test.sql
var testSQL []byte

//go:embed testdata/test_template.sql
var testTemplateSQL []byte

func TestExecuterExecute(t *testing.T) {
conf := mysqlbatch.NewDefaultConfig()
conf.Password = "mysqlbatch"
Expand All @@ -112,7 +116,37 @@ func TestExecuterExecute(t *testing.T) {
require.Equal(t, `SELECT * FROM users WHERE age is NOT NULL LIMIT 5`, query)
require.Equal(t, 5, len(rows))
})
err = e.Execute(bytes.NewReader(testSQL))
err = e.Execute(bytes.NewReader(testSQL), nil)
require.NoError(t, err)
log.Println("LastExecuteTime:", e.LastExecuteTime())
require.InDelta(t, time.Since(e.LastExecuteTime()), 0, float64(5*time.Minute))
require.EqualValues(t, 1, count)
}

func TestExecuterExecute__WithVars(t *testing.T) {
os.Setenv("ENV", "test")
mysqlbatch.DefaultSQLDumper = os.Stderr
conf := mysqlbatch.NewDefaultConfig()
conf.Password = "mysqlbatch"
conf.Location = "Asia/Tokyo"
e, err := mysqlbatch.New(context.Background(), conf)
require.NoError(t, err)
defer e.Close()

e.SetTimeCheckQuery("SELECT NOW()")
e.SetTableSelectHook(func(query, table string) {
t.Log(query + "\n" + table + "\n")
})
var count int32
e.SetSelectHook(func(query string, columns []string, rows [][]string) {
atomic.AddInt32(&count, 1)
require.Equal(t, `SELECT * FROM users WHERE age is NOT NULL LIMIT 5`, query)
require.Equal(t, 5, len(rows))
})
err = e.Execute(bytes.NewReader(testTemplateSQL), map[string]string{
"relation": "users",
"limit": "5",
})
require.NoError(t, err)
log.Println("LastExecuteTime:", e.LastExecuteTime())
require.InDelta(t, time.Since(e.LastExecuteTime()), 0, float64(5*time.Minute))
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ require (
github.com/aws/aws-sdk-go-v2/config v1.18.25
github.com/aws/aws-sdk-go-v2/service/ssm v1.36.0
github.com/aws/smithy-go v1.13.5
github.com/flosch/pongo2/v6 v6.0.0
github.com/go-sql-driver/mysql v1.7.1
github.com/ken39arg/go-flagx v0.0.0-20220608183922-7cf7c6c0093c
github.com/olekukonko/tablewriter v0.0.5
Expand Down
6 changes: 4 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ github.com/aws/smithy-go v1.13.5/go.mod h1:Tg+OJXh4MB2R/uN61Ko2f6hTZwB/ZYGOtib8J
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/flosch/pongo2/v6 v6.0.0 h1:lsGru8IAzHgIAw6H2m4PCyleO58I40ow6apih0WprMU=
github.com/flosch/pongo2/v6 v6.0.0/go.mod h1:CuDpFm47R0uGGE7z13/tTlt1Y6zdxvr2RLT5LJhsHEU=
github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI=
github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
Expand All @@ -44,8 +46,8 @@ github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGw
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
github.com/ken39arg/go-flagx v0.0.0-20220608183922-7cf7c6c0093c h1:jrKp5SY9Qt8lQmorJAksSYOIexZdkp7EREJgx4mX9XA=
github.com/ken39arg/go-flagx v0.0.0-20220608183922-7cf7c6c0093c/go.mod h1:DNbx2/OnOT5GtlYTUF2xr4GZSunGDP1Wk0WO3mmaKz0=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
Expand Down Expand Up @@ -74,8 +76,8 @@ github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
Expand Down
Loading
Loading