diff --git a/.circleci/config.yml b/.circleci/config.yml index 108ecfd8..1571843a 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,28 +1,37 @@ # Golang CircleCI 2.0 configuration file # # Check https://circleci.com/docs/2.0/language-go/ for more details -version: 2 +version: 2.1 +orbs: + codecov: codecov/codecov@3.1.1 jobs: - build-postgres-and-mysql: + build_and_tests: docker: # specify the version - - image: circleci/golang:1.13 - - - image: circleci/postgres:10.8-alpine - environment: # environment variables for primary container + - image: circleci/golang:1.16 + - image: circleci/postgres:12 + environment: POSTGRES_USER: jet POSTGRES_PASSWORD: jet POSTGRES_DB: jetdb + PGPORT: 50901 - - image: circleci/mysql:8.0.16 - command: [--default-authentication-plugin=mysql_native_password] + - image: circleci/mysql:8.0.27 + command: [ --default-authentication-plugin=mysql_native_password ] environment: MYSQL_ROOT_PASSWORD: jet MYSQL_DATABASE: dvds MYSQL_USER: jet MYSQL_PASSWORD: jet + MYSQL_TCP_PORT: 50902 - working_directory: /go/src/github.com/go-jet/jet + - image: circleci/mariadb:10.3 + command: [ '--default-authentication-plugin=mysql_native_password', '--port=50903' ] + environment: + MYSQL_ROOT_PASSWORD: jet + MYSQL_DATABASE: dvds + MYSQL_USER: jet + MYSQL_PASSWORD: jet environment: # environment variables for the build itself TEST_RESULTS: /tmp/test-results # path to where test results will be saved @@ -32,25 +41,22 @@ jobs: - run: name: Submodule init - command: | - git submodule init - git submodule update - cd ./tests/testdata && git fetch && git checkout master + command: cd tests && make checkout-testdata + - restore_cache: # restores saved cache if no changes are detected since last run + keys: + - go-mod-v4-{{ checksum "go.sum" }} - run: - name: Install dependencies - command: | - cd /go/src/github.com/go-jet/jet - go get github.com/jstemmer/go-junit-report - go build -o /home/circleci/.local/bin/jet ./cmd/jet/ + name: Install jet generator + command: cd tests && make install-jet-gen - run: name: Waiting for Postgres to be ready command: | for i in `seq 1 10`; do - nc -z localhost 5432 && echo Success && exit 0 + nc -z localhost 50901 && echo Success && exit 0 echo -n . sleep 1 done @@ -61,39 +67,71 @@ jobs: command: | for i in `seq 1 10`; do - nc -z 127.0.0.1 3306 && echo Success && exit 0 + nc -z 127.0.0.1 50902 && echo Success && exit 0 echo -n . sleep 1 done echo Failed waiting for MySQL && exit 1 + + - run: + name: Waiting for MariaDB to be ready + command: | + for i in `seq 1 10`; + do + nc -z 127.0.0.1 50903 && echo Success && exit 0 + echo -n . + sleep 1 + done + echo Failed waiting for MySQL && exit 1 + - run: name: Install MySQL CLI; command: | sudo apt-get --allow-releaseinfo-change update && sudo apt-get install default-mysql-client - run: - name: Create MySQL user and databases + name: Create MySQL/MariaDB user and test databases command: | - mysql -h 127.0.0.1 -u root -pjet -e "grant all privileges on *.* to 'jet'@'%';" - mysql -h 127.0.0.1 -u root -pjet -e "set global sql_mode = 'STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION';" - mysql -h 127.0.0.1 -u jet -pjet -e "create database test_sample" - mysql -h 127.0.0.1 -u jet -pjet -e "create database dvds2" + mysql -h 127.0.0.1 -P 50902 -u root -pjet -e "grant all privileges on *.* to 'jet'@'%';" + mysql -h 127.0.0.1 -P 50902 -u root -pjet -e "set global sql_mode = 'STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION';" + mysql -h 127.0.0.1 -P 50902 -u jet -pjet -e "create database test_sample" + mysql -h 127.0.0.1 -P 50902 -u jet -pjet -e "create database dvds2" + + mysql -h 127.0.0.1 -P 50903 -u root -pjet -e "grant all privileges on *.* to 'jet'@'%';" + mysql -h 127.0.0.1 -P 50903 -u root -pjet -e "set global sql_mode = 'STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION';" + mysql -h 127.0.0.1 -P 50903 -u jet -pjet -e "create database test_sample" + mysql -h 127.0.0.1 -P 50903 -u jet -pjet -e "create database dvds2" - run: - name: Init Postgres database - command: | - cd tests - go run ./init/init.go -testsuite all - cd .. + name: Init databases + command: | + cd tests + go run ./init/init.go -testsuite all + # to create test results report + - run: + name: Install go-junit-report + command: go install github.com/jstemmer/go-junit-report@latest - run: mkdir -p $TEST_RESULTS + # this will run all tests and exclude test files from code coverage report - - run: MY_SQL_SOURCE=MySQL go test -v ./... -covermode=atomic -coverpkg=github.com/go-jet/jet/postgres/...,github.com/go-jet/jet/mysql/...,github.com/go-jet/jet/sqlite/...,github.com/go-jet/jet/qrm/...,github.com/go-jet/jet/generator/...,github.com/go-jet/jet/internal/... -coverprofile=cover.out 2>&1 | go-junit-report > $TEST_RESULTS/results.xml + - run: | + go test -v ./... \ + -covermode=atomic \ + -coverpkg=github.com/go-jet/jet/v2/postgres/...,github.com/go-jet/jet/v2/mysql/...,github.com/go-jet/jet/v2/sqlite/...,github.com/go-jet/jet/v2/qrm/...,github.com/go-jet/jet/v2/generator/...,github.com/go-jet/jet/v2/internal/... \ + -coverprofile=cover.out 2>&1 | go-junit-report > $TEST_RESULTS/results.xml - - run: - name: Upload code coverage - command: bash <(curl -s https://codecov.io/bash) + # run mariaDB tests. No need to collect coverage, because coverage is already included with mysql tests + - run: MY_SQL_SOURCE=MariaDB go test -v ./tests/mysql/ + + - save_cache: + key: go-mod-v4-{{ checksum "go.sum" }} + paths: + - "/go/pkg/mod" + + - codecov/upload: + file: cover.out - store_artifacts: # Upload test summary for display in Artifacts: https://circleci.com/docs/2.0/artifacts/ path: /tmp/test-results @@ -101,69 +139,9 @@ jobs: - store_test_results: # Upload test results for display in Test Summary: https://circleci.com/docs/2.0/collect-test-data/ path: /tmp/test-results - build-mariadb: - docker: - # specify the version - - image: circleci/golang:1.13 - - - image: circleci/mariadb:10.3 - command: [--default-authentication-plugin=mysql_native_password] - environment: - MYSQL_ROOT_PASSWORD: jet - MYSQL_DATABASE: dvds - MYSQL_USER: jet - MYSQL_PASSWORD: jet - - working_directory: /go/src/github.com/go-jet/jet - - environment: # environment variables for the build itself - TEST_RESULTS: /tmp/test-results # path to where test results will be saved - - steps: - - checkout - - - run: - name: Submodule init - command: | - git submodule init - git submodule update - cd ./tests/testdata && git fetch && git checkout master - - - run: - name: Install dependencies - command: | - cd /go/src/github.com/go-jet/jet - go get github.com/jstemmer/go-junit-report - go build -o /home/circleci/.local/bin/jet ./cmd/jet/ - - - run: - name: Install MySQL CLI; - command: | - sudo apt-get --allow-releaseinfo-change update && sudo apt-get install default-mysql-client - - - run: - name: Init MariaDB database - command: | - mysql -h 127.0.0.1 -u root -pjet -e "grant all privileges on *.* to 'jet'@'%';" - mysql -h 127.0.0.1 -u root -pjet -e "set global sql_mode = 'STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION';" - mysql -h 127.0.0.1 -u jet -pjet -e "create database test_sample" - mysql -h 127.0.0.1 -u jet -pjet -e "create database dvds2" - - - run: - name: Init MariaDB database - command: | - cd tests - go run ./init/init.go -testsuite MariaDB - cd .. - - - run: - name: Run MariaDB tests - command: | - MY_SQL_SOURCE=MariaDB go test -v ./tests/mysql/ workflows: version: 2 build_and_test: jobs: - - build-postgres-and-mysql - - build-mariadb + - build_and_tests \ No newline at end of file diff --git a/.gitignore b/.gitignore index 153be121..4d7da458 100644 --- a/.gitignore +++ b/.gitignore @@ -19,4 +19,6 @@ gen .gentestdata .tests/testdata/ -.gen \ No newline at end of file +.gen +.docker +.env \ No newline at end of file diff --git a/README.md b/README.md index 89ac87fe..b53dfe1a 100644 --- a/README.md +++ b/README.md @@ -60,28 +60,26 @@ Use the command bellow to add jet as a dependency into `go.mod` project: $ go get -u github.com/go-jet/jet/v2 ``` -Jet generator can be installed in the following ways: +Jet generator can be installed in one of the following ways: -1) Install jet generator to GOPATH/bin folder: +1) (Go1.16+) Install jet generator using go install: ```sh - cd $GOPATH/src/ && GO111MODULE=off go get -u github.com/go-jet/jet/cmd/jet + go install github.com/go-jet/jet/v2/cmd/jet@latest ``` - *Make sure GOPATH/bin folder is added to the PATH environment variable.* -2) Install jet generator into specific folder: - +2) Install jet generator to GOPATH/bin folder: + ```sh + cd $GOPATH/src/ && GO111MODULE=off go get -u github.com/go-jet/jet/cmd/jet + ``` + +3) Install jet generator into specific folder: ```sh git clone https://github.com/go-jet/jet.git cd jet && go build -o dir_path ./cmd/jet ``` - *Make sure `dir_path` folder is added to the PATH environment variable.* +*Make sure that the destination folder is added to the PATH environment variable.* + -3) (Go1.16+) Install jet generator using go install: - ```sh - go install github.com/go-jet/jet/v2/cmd/jet@latest - ``` - *Jet generator is installed to the directory named by the GOBIN environment variable, - which defaults to $GOPATH/bin or $HOME/go/bin if the GOPATH environment variable is not set.* ### Quick Start For this quick start example we will use PostgreSQL sample _'dvd rental'_ database. Full database dump can be found in diff --git a/cmd/jet/main.go b/cmd/jet/main.go index fbbccaae..7e8e2999 100644 --- a/cmd/jet/main.go +++ b/cmd/jet/main.go @@ -3,7 +3,14 @@ package main import ( "flag" "fmt" + "github.com/go-jet/jet/v2/generator/metadata" sqlitegen "github.com/go-jet/jet/v2/generator/sqlite" + "github.com/go-jet/jet/v2/generator/template" + "github.com/go-jet/jet/v2/internal/jet" + "github.com/go-jet/jet/v2/internal/utils" + "github.com/go-jet/jet/v2/mysql" + postgres2 "github.com/go-jet/jet/v2/postgres" + "github.com/go-jet/jet/v2/sqlite" "os" "strings" @@ -27,34 +34,17 @@ var ( dbName string schemaName string + ignoreTables string + ignoreViews string + ignoreEnums string + destDir string ) func init() { - flag.StringVar(&source, "source", "", "Database system name (PostgreSQL, MySQL, MariaDB or SQLite)") - - flag.StringVar(&dsn, "dsn", "", "Data source name connection string (Example: postgresql://user@localhost:5432/otherdb?sslmode=trust)") - flag.StringVar(&host, "host", "", "Database host path (Example: localhost)") - flag.IntVar(&port, "port", 0, "Database port") - flag.StringVar(&user, "user", "", "Database user") - flag.StringVar(&password, "password", "", "The user’s password") - flag.StringVar(¶ms, "params", "", "Additional connection string parameters(optional)") - flag.StringVar(&dbName, "dbname", "", "Database name") - flag.StringVar(&schemaName, "schema", "public", `Database schema name. (default "public") (ignored for MySQL and MariaDB)`) - flag.StringVar(&sslmode, "sslmode", "disable", `Whether or not to use SSL(optional)(default "disable") (ignored for MySQL and MariaDB)`) - - flag.StringVar(&destDir, "path", "", "Destination dir for files generated.") -} - -func main() { + flag.StringVar(&source, "source", "", "Database system name (postgres, mysql, mariadb or sqlite)") - flag.Usage = func() { - _, _ = fmt.Fprint(os.Stdout, ` -Jet generator 2.6.0 - -Usage: - -dsn string - Data source name. Unified format for connecting to database. + flag.StringVar(&dsn, "dsn", "", `Data source name. Unified format for connecting to database. PostgreSQL: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING Example: postgresql://user:pass@localhost:5432/dbname @@ -63,65 +53,70 @@ Usage: mysql://jet:jet@tcp(localhost:3306)/dvds SQLite: https://www.sqlite.org/c3ref/open.html#urifilenameexamples Example: - file://path/to/database/file - -source string - Database system name (PostgreSQL, MySQL, MariaDB or SQLite) - -host string - Database host path (Example: localhost) - -port int - Database port - -user string - Database user - -password string - The user’s password - -dbname string - Database name - -params string - Additional connection string parameters(optional) - -schema string - Database schema name. (default "public") (ignored for MySQL, MariaDB and SQLite) - -sslmode string - Whether or not to use SSL(optional) (default "disable") (ignored for MySQL, MariaDB and SQLite) - -path string - Destination dir for files generated. - -Example commands: - - $ jet -source=PostgreSQL -dbname=jetdb -host=localhost -port=5432 -user=jet -password=jet -schema=dvds -path=./gen + file://path/to/database/file`) + flag.StringVar(&host, "host", "", "Database host path. Used only if dsn is not set. (Example: localhost)") + flag.IntVar(&port, "port", 0, "Database port. Used only if dsn is not set.") + flag.StringVar(&user, "user", "", "Database user. Used only if dsn is not set.") + flag.StringVar(&password, "password", "", "The user’s password. Used only if dsn is not set.") + flag.StringVar(&dbName, "dbname", "", "Database name. Used only if dsn is not set.") + flag.StringVar(&schemaName, "schema", "public", `Database schema name. Used only if dsn is not set. (default "public")(PostgreSQL only)`) + flag.StringVar(¶ms, "params", "", "Additional connection string parameters(optional). Used only if dsn is not set.") + flag.StringVar(&sslmode, "sslmode", "disable", `Whether or not to use SSL. Used only if dsn is not set. (optional)(default "disable")(PostgreSQL only)`) + flag.StringVar(&ignoreTables, "ignore-tables", "", `Comma-separated list of tables to ignore`) + flag.StringVar(&ignoreViews, "ignore-views", "", `Comma-separated list of views to ignore`) + flag.StringVar(&ignoreEnums, "ignore-enums", "", `Comma-separated list of enums to ignore`) + + flag.StringVar(&destDir, "path", "", "Destination dir for files generated.") +} + +func main() { + + flag.Usage = func() { + fmt.Println("Jet generator 2.7.0") + fmt.Println() + fmt.Println("Usage:") + + order := []string{ + "source", "dsn", "host", "port", "user", "password", "dbname", "schema", "params", "sslmode", + "path", + "ignore-tables", "ignore-views", "ignore-enums", + } + for _, name := range order { + flagEntry := flag.CommandLine.Lookup(name) + fmt.Printf(" -%s\n", flagEntry.Name) + fmt.Printf("\t%s\n", flagEntry.Usage) + } + + fmt.Println() + fmt.Println(`Example command: + $ jet -dsn=postgresql://jet:jet@localhost:5432/jetdb -schema=dvds -path=./gen $ jet -source=postgres -dsn="user=jet password=jet host=localhost port=5432 dbname=jetdb" -schema=dvds -path=./gen - $ jet -source=sqlite -dsn="file://path/to/sqlite/database/file" -schema=dvds -path=./gen -`) + $ jet -source=mysql -host=localhost -port=3306 -user=jet -password=jet -dbname=jetdb -path=./gen + $ jet -source=sqlite -dsn="file://path/to/sqlite/database/file" -path=./gen + `) } flag.Parse() - if dsn == "" { - // validations for separated connection flags. - if source == "" || host == "" || port == 0 || user == "" || dbName == "" { - printErrorAndExit("ERROR: required flag(s) missing") - } - } else { - if source == "" { - // try to get source from schema - source = detectSchema(dsn) - } - - // validations when dsn != "" - if source == "" { - printErrorAndExit("ERROR: required -source flag missing.") - } + if dsn == "" && (source == "" || host == "" || port == 0 || user == "" || dbName == "") { + printErrorAndExit("ERROR: required flag(s) missing") } + source := getSource() + ignoreTablesList := parseList(ignoreTables) + ignoreViewsList := parseList(ignoreViews) + ignoreEnumsList := parseList(ignoreEnums) + var err error - switch strings.ToLower(strings.TrimSpace(source)) { + switch source { case "postgresql", "postgres": if dsn != "" { err = postgresgen.GenerateDSN(dsn, schemaName, destDir) break } - genData := postgresgen.DBConnection{ + dbConn := postgresgen.DBConnection{ Host: host, Port: port, User: user, @@ -133,7 +128,11 @@ Example commands: SchemaName: schemaName, } - err = postgresgen.Generate(destDir, genData) + err = postgresgen.Generate( + destDir, + dbConn, + genTemplate(postgres2.Dialect, ignoreTablesList, ignoreViewsList, ignoreEnumsList), + ) case "mysql", "mysqlx", "mariadb": if dsn != "" { @@ -149,12 +148,24 @@ Example commands: DBName: dbName, } - err = mysqlgen.Generate(destDir, dbConn) + err = mysqlgen.Generate( + destDir, + dbConn, + genTemplate(mysql.Dialect, ignoreTablesList, ignoreViewsList, ignoreEnumsList), + ) case "sqlite": if dsn == "" { printErrorAndExit("ERROR: required -dsn flag missing.") } - err = sqlitegen.GenerateDSN(dsn, destDir) + err = sqlitegen.GenerateDSN( + dsn, + destDir, + genTemplate(sqlite.Dialect, ignoreTablesList, ignoreViewsList, ignoreEnumsList), + ) + + case "": + printErrorAndExit("ERROR: required -source or -dns flag missing.") + default: printErrorAndExit("ERROR: unknown data source " + source + ". Only postgres, mysql, mariadb and sqlite are supported.") } @@ -167,10 +178,19 @@ Example commands: func printErrorAndExit(error string) { fmt.Println("\n", error) + fmt.Println() flag.Usage() os.Exit(-2) } +func getSource() string { + if source != "" { + return strings.TrimSpace(strings.ToLower(source)) + } + + return detectSchema(dsn) +} + func detectSchema(dsn string) string { match := strings.SplitN(dsn, "://", 2) if len(match) < 2 { // not found @@ -183,5 +203,75 @@ func detectSchema(dsn string) string { return "sqlite" } - return match[0] + return strings.ToLower(match[0]) +} + +func parseList(list string) []string { + ret := strings.Split(list, ",") + + for i := 0; i < len(ret); i++ { + ret[i] = strings.ToLower(strings.TrimSpace(ret[i])) + } + + return ret +} + +func genTemplate(dialect jet.Dialect, ignoreTables []string, ignoreViews []string, ignoreEnums []string) template.Template { + + shouldSkipTable := func(table metadata.Table) bool { + return utils.StringSliceContains(ignoreTables, strings.ToLower(table.Name)) + } + + shouldSkipView := func(view metadata.Table) bool { + return utils.StringSliceContains(ignoreViews, strings.ToLower(view.Name)) + } + + shouldSkipEnum := func(enum metadata.Enum) bool { + return utils.StringSliceContains(ignoreEnums, strings.ToLower(enum.Name)) + } + + return template.Default(dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseModel(template.DefaultModel(). + UseTable(func(table metadata.Table) template.TableModel { + if shouldSkipTable(table) { + return template.TableModel{Skip: true} + } + return template.DefaultTableModel(table) + }). + UseView(func(view metadata.Table) template.ViewModel { + if shouldSkipView(view) { + return template.ViewModel{Skip: true} + } + return template.DefaultViewModel(view) + }). + UseEnum(func(enum metadata.Enum) template.EnumModel { + if shouldSkipEnum(enum) { + return template.EnumModel{Skip: true} + } + return template.DefaultEnumModel(enum) + }), + ). + UseSQLBuilder(template.DefaultSQLBuilder(). + UseTable(func(table metadata.Table) template.TableSQLBuilder { + if shouldSkipTable(table) { + return template.TableSQLBuilder{Skip: true} + } + return template.DefaultTableSQLBuilder(table) + }). + UseView(func(table metadata.Table) template.ViewSQLBuilder { + if shouldSkipView(table) { + return template.ViewSQLBuilder{Skip: true} + } + return template.DefaultViewSQLBuilder(table) + }). + UseEnum(func(enum metadata.Enum) template.EnumSQLBuilder { + if shouldSkipEnum(enum) { + return template.EnumSQLBuilder{Skip: true} + } + return template.DefaultEnumSQLBuilder(enum) + }), + ) + }) } diff --git a/doc.go b/doc.go index f27c7d8e..44a3e893 100644 --- a/doc.go +++ b/doc.go @@ -1,77 +1,156 @@ /* -Package jet is a framework for writing type-safe SQL queries in Go, with ability to easily convert database query -result into desired arbitrary object structure. +Package jet is a complete solution for efficient and high performance database access, consisting of type-safe SQL builder +with code generation and automatic query result data mapping. +Jet currently supports PostgreSQL, MySQL, MariaDB and SQLite. Future releases will add support for additional databases. Installation -Use the bellow command to add jet as a dependency into go.mod project: - $ go get github.com/go-jet/jet/v2 +Use the command bellow to add jet as a dependency into go.mod project: + $ go get -u github.com/go-jet/jet/v2 -Use the bellow command to add jet as a dependency into GOPATH project: - $ go get -u github.com/go-jet/jet +Jet generator can be installed in one of the following ways: -Install jet generator to GOPATH bin folder. This will allow generating jet files from the command line. - cd $GOPATH/src/ && GO111MODULE=off go get -u github.com/go-jet/jet/cmd/jet + 1) (Go1.16+) Install jet generator using go install: + go install github.com/go-jet/jet/v2/cmd/jet@latest + + 2) Install jet generator to GOPATH/bin folder: + cd $GOPATH/src/ && GO111MODULE=off go get -u github.com/go-jet/jet/cmd/jet + + 3) Install jet generator into specific folder: + git clone https://github.com/go-jet/jet.git + cd jet && go build -o dir_path ./cmd/jet + +Make sure that the destination folder is added to the PATH environment variable. -Make sure GOPATH bin folder is added to the PATH environment variable. Usage + Jet requires already defined database schema(with tables, enums etc), so that jet generator can generate SQL Builder and Model files. File generation is very fast, and can be added as every pre-build step. Sample command: - jet -source=PostgreSQL -host=localhost -port=5432 -user=jet -password=pass -dbname=jetdb -schema=dvds -path=./gen + jet -dsn=postgresql://user:pass@localhost:5432/jetdb -schema=dvds -path=./.gen -Then next step is to import generated SQL Builder and Model files and write SQL queries in Go: +Before we can write SQL queries in Go, we need to import generated SQL builder and model types: import . "some_path/.gen/jetdb/dvds/table" import "some_path/.gen/jetdb/dvds/model" -To write SQL queries for PostgreSQL import: - . "github.com/go-jet/jet/v2/postgres" +To write postgres SQL queries we import: + . "github.com/go-jet/jet/v2/postgres" // Dot import is used so that Go code resemble as much as native SQL. It is not mandatory. -To write SQL queries for MySQL and MariaDB import: - . "github.com/go-jet/jet/v2/mysql" -*Dot import is used so that Go code resemble as much as native SQL. Dot import is not mandatory. - -Write SQL: +Then we can write the SQL query: // sub-query - rRatingFilms := SELECT( - Film.FilmID, - Film.Title, - Film.Rating, - ). - FROM(Film). - WHERE(Film.Rating.EQ(enum.FilmRating.R)). - AsTable("rFilms") + rRatingFilms := + SELECT( + Film.FilmID, + Film.Title, + Film.Rating, + ).FROM( + Film, + ).WHERE( + Film.Rating.EQ(enum.FilmRating.R), + ).AsTable("rFilms") // export column from sub-query rFilmID := Film.FilmID.From(rRatingFilms) // main-query - query := SELECT( + stmt := + SELECT( Actor.AllColumns, FilmActor.AllColumns, rRatingFilms.AllColumns(), - ). - FROM( + ).FROM( rRatingFilms. - INNER_JOIN(FilmActor, FilmActor.FilmID.EQ(rFilmID)). - INNER_JOIN(Actor, Actor.ActorID.EQ(FilmActor.ActorID) - ). - ORDER_BY(rFilmID, Actor.ActorID) - -Store result into desired destination: + INNER_JOIN(FilmActor, FilmActor.FilmID.EQ(rFilmID)). + INNER_JOIN(Actor, Actor.ActorID.EQ(FilmActor.ActorID) + ).ORDER_BY( + rFilmID, + Actor.ActorID, + ) + +Now we can run the statement and store the result into desired destination: var dest []struct { model.Film Actors []model.Actor } - err := query.Query(db, &dest) - -Detail info about all features and use cases can be + err := stmt.Query(db, &dest) + +We can print a statement to see SQL query and arguments sent to postgres server: + fmt.Println(stmt.Sql()) + +Output: + SELECT "rFilms"."film.film_id" AS "film.film_id", + "rFilms"."film.title" AS "film.title", + "rFilms"."film.rating" AS "film.rating", + actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update", + film_actor.actor_id AS "film_actor.actor_id", + film_actor.film_id AS "film_actor.film_id", + film_actor.last_update AS "film_actor.last_update" + FROM ( + SELECT film.film_id AS "film.film_id", + film.title AS "film.title", + film.rating AS "film.rating" + FROM dvds.film + WHERE film.rating = 'R' + ) AS "rFilms" + INNER JOIN dvds.film_actor ON (film_actor.film_id = "rFilms"."film.film_id") + INNER JOIN dvds.actor ON (film_actor.actor_id = actor.actor_id) + WHERE "rFilms"."film.film_id" < $1 + ORDER BY "rFilms"."film.film_id" ASC, actor.actor_id ASC; + [50] + +If we print destination as json, we'll get: + + [ + { + "FilmID": 8, + "Title": "Airport Pollock", + "Rating": "R", + "Actors": [ + { + "ActorID": 55, + "FirstName": "Fay", + "LastName": "Kilmer", + "LastUpdate": "2013-05-26T14:47:57.62Z" + }, + { + "ActorID": 96, + "FirstName": "Gene", + "LastName": "Willis", + "LastUpdate": "2013-05-26T14:47:57.62Z" + }, + ... + ] + }, + { + "FilmID": 17, + "Title": "Alone Trip", + "Actors": [ + { + "ActorID": 3, + "FirstName": "Ed", + "LastName": "Chase", + "LastUpdate": "2013-05-26T14:47:57.62Z" + }, + { + "ActorID": 12, + "FirstName": "Karl", + "LastName": "Berry", + "LastUpdate": "2013-05-26T14:47:57.62Z" + }, + ... + ... + ] + +Detail info about all statements, features and use cases can be found at project wiki page - https://github.com/go-jet/jet/wiki. */ package jet diff --git a/generator/mysql/query_set.go b/generator/mysql/query_set.go index a409eb76..5847be48 100644 --- a/generator/mysql/query_set.go +++ b/generator/mysql/query_set.go @@ -20,7 +20,7 @@ WHERE table_schema = ? and table_type = ?; ` var tables []metadata.Table - err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables) + _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables) throw.OnError(err) for i := range tables { @@ -32,15 +32,14 @@ WHERE table_schema = ? and table_type = ?; func (m mySqlQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) []metadata.Column { query := ` -WITH primaryKeys AS ( - SELECT k.column_name - FROM information_schema.table_constraints t - JOIN information_schema.key_column_usage k USING(constraint_name,table_schema,table_name) - WHERE table_schema = ? AND table_name = ? AND t.constraint_type='PRIMARY KEY' -) SELECT COLUMN_NAME AS "column.Name", IS_NULLABLE = "YES" AS "column.IsNullable", - (EXISTS(SELECT 1 FROM primaryKeys AS pk WHERE pk.column_name = columns.column_name)) AS "column.IsPrimaryKey", + (EXISTS( + SELECT 1 + FROM information_schema.table_constraints t + JOIN information_schema.key_column_usage k USING(constraint_name,table_schema,table_name) + WHERE table_schema = ? AND table_name = ? AND t.constraint_type='PRIMARY KEY' AND k.column_name = columns.column_name + )) AS "column.IsPrimaryKey", IF (COLUMN_TYPE = 'tinyint(1)', 'boolean', IF (DATA_TYPE='enum', @@ -54,7 +53,7 @@ WHERE table_schema = ? AND table_name = ? ORDER BY ordinal_position; ` var columns []metadata.Column - err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName, schemaName, tableName}, &columns) + _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName, schemaName, tableName}, &columns) throw.OnError(err) return columns @@ -73,7 +72,7 @@ WHERE c.table_schema = ? AND DATA_TYPE = 'enum'; Values string } - err := qrm.Query(context.Background(), db, query, []interface{}{schemaName}, &queryResult) + _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName}, &queryResult) throw.OnError(err) var ret []metadata.Enum diff --git a/generator/postgres/query_set.go b/generator/postgres/query_set.go index e2fb9698..93e6ffb5 100644 --- a/generator/postgres/query_set.go +++ b/generator/postgres/query_set.go @@ -19,7 +19,7 @@ WHERE table_schema = $1 and table_type = $2; ` var tables []metadata.Table - err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables) + _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables) throw.OnError(err) for i := range tables { @@ -58,7 +58,7 @@ where table_schema = $1 and table_name = $2 order by ordinal_position; ` var columns []metadata.Column - err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName}, &columns) + _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName}, &columns) throw.OnError(err) return columns @@ -76,7 +76,7 @@ ORDER BY n.nspname, t.typname, e.enumsortorder;` var result []metadata.Enum - err := qrm.Query(context.Background(), db, query, []interface{}{schemaName}, &result) + _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName}, &result) throw.OnError(err) return result diff --git a/generator/sqlite/query_set.go b/generator/sqlite/query_set.go index e1d5e4d1..c11f2103 100644 --- a/generator/sqlite/query_set.go +++ b/generator/sqlite/query_set.go @@ -28,7 +28,7 @@ func (p sqliteQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableTy var tables []metadata.Table - err := qrm.Query(context.Background(), db, query, []interface{}{sqlTableType}, &tables) + _, err := qrm.Query(context.Background(), db, query, []interface{}{sqlTableType}, &tables) throw.OnError(err) for i := range tables { @@ -47,7 +47,7 @@ func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, t Pk int32 } - err := qrm.Query(context.Background(), db, query, []interface{}{tableName}, &columnInfos) + _, err := qrm.Query(context.Background(), db, query, []interface{}{tableName}, &columnInfos) throw.OnError(err) var columns []metadata.Column diff --git a/internal/jet/alias.go b/internal/jet/alias.go index 57f55cde..8693b13c 100644 --- a/internal/jet/alias.go +++ b/internal/jet/alias.go @@ -13,7 +13,12 @@ func newAlias(expression Expression, aliasName string) Projection { } func (a *alias) fromImpl(subQuery SelectTable) Projection { - column := NewColumnImpl(a.alias, "", nil) + // if alias is in the form "table.column", we break it into two parts so that ProjectionList.As(newAlias) can + // overwrite tableName with a new alias. This method is called only for exporting aliased custom columns. + // Generated columns have default aliasing. + tableName, columnName := extractTableAndColumnName(a.alias) + + column := NewColumnImpl(columnName, tableName, nil) column.subQuery = subQuery return &column diff --git a/internal/jet/bool_expression_test.go b/internal/jet/bool_expression_test.go index 765cdab5..c6cdbbb9 100644 --- a/internal/jet/bool_expression_test.go +++ b/internal/jet/bool_expression_test.go @@ -6,7 +6,6 @@ import ( func TestBoolExpressionEQ(t *testing.T) { assertClauseSerialize(t, table1ColBool.EQ(table2ColBool), "(table1.col_bool = table2.col_bool)") - assertClauseSerializeErr(t, table1ColBool.EQ(nil), "jet: rhs is nil for '=' operator") } func TestBoolExpressionNOT_EQ(t *testing.T) { diff --git a/internal/jet/clause.go b/internal/jet/clause.go index 446a5451..ce99a6b5 100644 --- a/internal/jet/clause.go +++ b/internal/jet/clause.go @@ -18,8 +18,9 @@ type ClauseWithProjections interface { // ClauseSelect struct type ClauseSelect struct { - Distinct bool - ProjectionList []Projection + Distinct bool + DistinctOnColumns []ColumnExpression + ProjectionList []Projection } // Projections returns list of projections for select clause @@ -36,6 +37,12 @@ func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder, o out.WriteString("DISTINCT") } + if len(s.DistinctOnColumns) > 0 { + out.WriteString("ON (") + SerializeColumnExpressions(s.DistinctOnColumns, statementType, out) + out.WriteByte(')') + } + if len(s.ProjectionList) == 0 { panic("jet: SELECT clause has to have at least one projection") } @@ -45,6 +52,7 @@ func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder, o // ClauseFrom struct type ClauseFrom struct { + Name string Tables []Serializer } @@ -54,7 +62,11 @@ func (f *ClauseFrom) Serialize(statementType StatementType, out *SQLBuilder, opt return } out.NewLine() - out.WriteString("FROM") + if f.Name != "" { + out.WriteString(f.Name) + } else { + out.WriteString("FROM") + } out.IncreaseIdent() for i, table := range f.Tables { diff --git a/internal/jet/date_expression.go b/internal/jet/date_expression.go index 27c20351..560688af 100644 --- a/internal/jet/date_expression.go +++ b/internal/jet/date_expression.go @@ -13,6 +13,8 @@ type DateExpression interface { LT_EQ(rhs DateExpression) BoolExpression GT(rhs DateExpression) BoolExpression GT_EQ(rhs DateExpression) BoolExpression + BETWEEN(min, max DateExpression) BoolExpression + NOT_BETWEEN(min, max DateExpression) BoolExpression ADD(rhs Interval) TimestampExpression SUB(rhs Interval) TimestampExpression @@ -54,6 +56,14 @@ func (d *dateInterfaceImpl) GT_EQ(rhs DateExpression) BoolExpression { return GtEq(d.parent, rhs) } +func (d *dateInterfaceImpl) BETWEEN(min, max DateExpression) BoolExpression { + return NewBetweenOperatorExpression(d.parent, min, max, false) +} + +func (d *dateInterfaceImpl) NOT_BETWEEN(min, max DateExpression) BoolExpression { + return NewBetweenOperatorExpression(d.parent, min, max, true) +} + func (d *dateInterfaceImpl) ADD(rhs Interval) TimestampExpression { return TimestampExp(Add(d.parent, rhs)) } diff --git a/internal/jet/expression.go b/internal/jet/expression.go index 14e6a9b5..e657f30a 100644 --- a/internal/jet/expression.go +++ b/internal/jet/expression.go @@ -1,5 +1,7 @@ package jet +import "fmt" + // Expression is common interface for all expressions. // Can be Bool, Int, Float, String, Date, Time, Timez, Timestamp or Timestampz expressions. type Expression interface { @@ -33,7 +35,8 @@ type ExpressionInterfaceImpl struct { } func (e *ExpressionInterfaceImpl) fromImpl(subQuery SelectTable) Projection { - return e.Parent + panic(fmt.Sprintf("jet: can't export unaliased expression subQuery: %s, expression: %s", + subQuery.Alias(), serializeToDefaultDebugString(e.Parent))) } // IS_NULL tests expression whether it is a NULL value. @@ -93,7 +96,7 @@ type binaryOperatorExpression struct { } // NewBinaryOperatorExpression creates new binaryOperatorExpression -func NewBinaryOperatorExpression(lhs, rhs Serializer, operator string, additionalParam ...Expression) *binaryOperatorExpression { +func NewBinaryOperatorExpression(lhs, rhs Serializer, operator string, additionalParam ...Expression) Expression { binaryExpression := &binaryOperatorExpression{ lhs: lhs, rhs: rhs, @@ -106,23 +109,10 @@ func NewBinaryOperatorExpression(lhs, rhs Serializer, operator string, additiona binaryExpression.ExpressionInterfaceImpl.Parent = binaryExpression - return binaryExpression + return complexExpr(binaryExpression) } func (c *binaryOperatorExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - if c.lhs == nil { - panic("jet: lhs is nil for '" + c.operator + "' operator") - } - if c.rhs == nil { - panic("jet: rhs is nil for '" + c.operator + "' operator") - } - - wrap := !contains(options, NoWrap) - - if wrap { - out.WriteString("(") - } - if serializeOverride := out.Dialect.OperatorSerializeOverride(c.operator); serializeOverride != nil { serializeOverrideFunc := serializeOverride(c.lhs, c.rhs, c.additionalParam) serializeOverrideFunc(statement, out, FallTrough(options)...) @@ -131,10 +121,6 @@ func (c *binaryOperatorExpression) serialize(statement StatementType, out *SQLBu out.WriteString(c.operator) c.rhs.serialize(statement, out, FallTrough(options)...) } - - if wrap { - out.WriteString(")") - } } // A prefix operator Expression @@ -145,27 +131,19 @@ type prefixExpression struct { operator string } -func newPrefixOperatorExpression(expression Expression, operator string) *prefixExpression { +func newPrefixOperatorExpression(expression Expression, operator string) Expression { prefixExpression := &prefixExpression{ expression: expression, operator: operator, } prefixExpression.ExpressionInterfaceImpl.Parent = prefixExpression - return prefixExpression + return complexExpr(prefixExpression) } func (p *prefixExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - out.WriteString("(") out.WriteString(p.operator) - - if p.expression == nil { - panic("jet: nil prefix expression in prefix operator " + p.operator) - } - p.expression.serialize(statement, out, FallTrough(options)...) - - out.WriteString(")") } // A postfix operator Expression @@ -188,11 +166,77 @@ func newPostfixOperatorExpression(expression Expression, operator string) *postf } func (p *postfixOpExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - if p.expression == nil { - panic("jet: nil prefix expression in postfix operator " + p.operator) + p.expression.serialize(statement, out, FallTrough(options)...) + out.WriteString(p.operator) +} + +type betweenOperatorExpression struct { + ExpressionInterfaceImpl + + expression Expression + notBetween bool + min Expression + max Expression +} + +// NewBetweenOperatorExpression creates new BETWEEN operator expression +func NewBetweenOperatorExpression(expression, min, max Expression, notBetween bool) BoolExpression { + newBetweenOperator := &betweenOperatorExpression{ + expression: expression, + notBetween: notBetween, + min: min, + max: max, } + newBetweenOperator.ExpressionInterfaceImpl.Parent = newBetweenOperator + + return BoolExp(complexExpr(newBetweenOperator)) +} + +func (p *betweenOperatorExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { p.expression.serialize(statement, out, FallTrough(options)...) + if p.notBetween { + out.WriteString("NOT") + } + out.WriteString("BETWEEN") + p.min.serialize(statement, out, FallTrough(options)...) + out.WriteString("AND") + p.max.serialize(statement, out, FallTrough(options)...) +} - out.WriteString(p.operator) +type complexExpression struct { + ExpressionInterfaceImpl + expressions Expression +} + +func complexExpr(expressions Expression) Expression { + complexExpression := &complexExpression{expressions: expressions} + complexExpression.ExpressionInterfaceImpl.Parent = complexExpression + + return complexExpression +} + +func (s *complexExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + if !contains(options, NoWrap) { + out.WriteString("(") + } + + s.expressions.serialize(statement, out, options...) // FallTrough here because complexExpression is just a wrapper + + if !contains(options, NoWrap) { + out.WriteString(")") + } +} + +type skipParenthesisWrap struct { + Expression +} + +func skipWrap(expression Expression) Expression { + return &skipParenthesisWrap{expression} +} + +// since the expression is a function parameter, there is no need to wrap it in parentheses +func (s *skipParenthesisWrap) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + s.Expression.serialize(statement, out, append(options, NoWrap)...) } diff --git a/internal/jet/expression_test.go b/internal/jet/expression_test.go index 7b7bed6d..74e6a059 100644 --- a/internal/jet/expression_test.go +++ b/internal/jet/expression_test.go @@ -4,10 +4,6 @@ import ( "testing" ) -func TestInvalidExpression(t *testing.T) { - assertClauseSerializeErr(t, table2Col3.ADD(nil), `jet: rhs is nil for '+' operator`) -} - func TestExpressionIS_NULL(t *testing.T) { assertClauseSerialize(t, table2Col3.IS_NULL(), "table2.col3 IS NULL") assertClauseSerialize(t, table2Col3.ADD(table2Col3).IS_NULL(), "(table2.col3 + table2.col3) IS NULL") diff --git a/internal/jet/float_expression.go b/internal/jet/float_expression.go index 20c8d3f3..3fb30fed 100644 --- a/internal/jet/float_expression.go +++ b/internal/jet/float_expression.go @@ -14,6 +14,8 @@ type FloatExpression interface { LT_EQ(rhs FloatExpression) BoolExpression GT(rhs FloatExpression) BoolExpression GT_EQ(rhs FloatExpression) BoolExpression + BETWEEN(min, max FloatExpression) BoolExpression + NOT_BETWEEN(min, max FloatExpression) BoolExpression ADD(rhs NumericExpression) FloatExpression SUB(rhs NumericExpression) FloatExpression @@ -60,6 +62,14 @@ func (n *floatInterfaceImpl) LT_EQ(rhs FloatExpression) BoolExpression { return LtEq(n.parent, rhs) } +func (n *floatInterfaceImpl) BETWEEN(min, max FloatExpression) BoolExpression { + return NewBetweenOperatorExpression(n.parent, min, max, false) +} + +func (n *floatInterfaceImpl) NOT_BETWEEN(min, max FloatExpression) BoolExpression { + return NewBetweenOperatorExpression(n.parent, min, max, true) +} + func (n *floatInterfaceImpl) ADD(rhs NumericExpression) FloatExpression { return FloatExp(Add(n.parent, rhs)) } diff --git a/internal/jet/func_expression.go b/internal/jet/func_expression.go index 9a647e96..3e40edfe 100644 --- a/internal/jet/func_expression.go +++ b/internal/jet/func_expression.go @@ -81,7 +81,7 @@ func LOG(floatExpression FloatExpression) FloatExpression { // ----------------- Aggregate functions -------------------// // AVG is aggregate function used to calculate avg value from numeric expression -func AVG(numericExpression NumericExpression) floatWindowExpression { +func AVG(numericExpression Expression) floatWindowExpression { return NewFloatWindowFunc("AVG", numericExpression) } @@ -594,7 +594,7 @@ type funcExpressionImpl struct { func NewFunc(name string, expressions []Expression, parent Expression) *funcExpressionImpl { funcExp := &funcExpressionImpl{ name: name, - expressions: expressions, + expressions: parameters(expressions), } if parent != nil { @@ -606,9 +606,22 @@ func NewFunc(name string, expressions []Expression, parent Expression) *funcExpr return funcExp } +func parameters(expressions []Expression) []Expression { + var ret []Expression + + for _, expression := range expressions { + if _, isStatement := expression.(Statement); isStatement { + ret = append(ret, expression) + } else { + ret = append(ret, skipWrap(expression)) + } + } + + return ret +} + // NewFloatWindowFunc creates new float function with name and expressions func newWindowFunc(name string, expressions ...Expression) windowExpression { - newFun := NewFunc(name, expressions, nil) windowExpr := newWindowExpression(newFun) newFun.ExpressionInterfaceImpl.Parent = windowExpr @@ -698,12 +711,12 @@ type integerFunc struct { } func newIntegerFunc(name string, expressions ...Expression) IntegerExpression { - floatFunc := &integerFunc{} + intFunc := &integerFunc{} - floatFunc.funcExpressionImpl = *NewFunc(name, expressions, floatFunc) - floatFunc.integerInterfaceImpl.parent = floatFunc + intFunc.funcExpressionImpl = *NewFunc(name, expressions, intFunc) + intFunc.integerInterfaceImpl.parent = intFunc - return floatFunc + return intFunc } // NewFloatWindowFunc creates new float function with name and expressions @@ -806,7 +819,7 @@ func newTimestampzFunc(name string, expressions ...Expression) *timestampzFunc { return timestampzFunc } -// Func can be used to call an custom or as of yet unsupported function in the database. +// Func can be used to call custom or unsupported database functions. func Func(name string, expressions ...Expression) Expression { return NewFunc(name, expressions, nil) } diff --git a/internal/jet/integer_expression.go b/internal/jet/integer_expression.go index ff2a0a0b..32d15e04 100644 --- a/internal/jet/integer_expression.go +++ b/internal/jet/integer_expression.go @@ -5,46 +5,29 @@ type IntegerExpression interface { Expression numericExpression - // Check if expression is equal to rhs EQ(rhs IntegerExpression) BoolExpression - // Check if expression is not equal to rhs NOT_EQ(rhs IntegerExpression) BoolExpression - // Check if expression is distinct from rhs IS_DISTINCT_FROM(rhs IntegerExpression) BoolExpression - // Check if expression is not distinct from rhs IS_NOT_DISTINCT_FROM(rhs IntegerExpression) BoolExpression - // Check if expression is less then rhs LT(rhs IntegerExpression) BoolExpression - // Check if expression is less then equal rhs LT_EQ(rhs IntegerExpression) BoolExpression - // Check if expression is greater then rhs GT(rhs IntegerExpression) BoolExpression - // Check if expression is greater then equal rhs GT_EQ(rhs IntegerExpression) BoolExpression + BETWEEN(min, max IntegerExpression) BoolExpression + NOT_BETWEEN(min, max IntegerExpression) BoolExpression - // expression + rhs ADD(rhs IntegerExpression) IntegerExpression - // expression - rhs SUB(rhs IntegerExpression) IntegerExpression - // expression * rhs MUL(rhs IntegerExpression) IntegerExpression - // expression / rhs DIV(rhs IntegerExpression) IntegerExpression - // expression % rhs MOD(rhs IntegerExpression) IntegerExpression - // expression ^ rhs POW(rhs IntegerExpression) IntegerExpression - // expression & rhs BIT_AND(rhs IntegerExpression) IntegerExpression - // expression | rhs BIT_OR(rhs IntegerExpression) IntegerExpression - // expression # rhs BIT_XOR(rhs IntegerExpression) IntegerExpression - // expression << rhs BIT_SHIFT_LEFT(shift IntegerExpression) IntegerExpression - // expression >> rhs BIT_SHIFT_RIGHT(shift IntegerExpression) IntegerExpression } @@ -85,6 +68,14 @@ func (i *integerInterfaceImpl) LT_EQ(rhs IntegerExpression) BoolExpression { return LtEq(i.parent, rhs) } +func (i *integerInterfaceImpl) BETWEEN(min, max IntegerExpression) BoolExpression { + return NewBetweenOperatorExpression(i.parent, min, max, false) +} + +func (i *integerInterfaceImpl) NOT_BETWEEN(min, max IntegerExpression) BoolExpression { + return NewBetweenOperatorExpression(i.parent, min, max, true) +} + func (i *integerInterfaceImpl) ADD(rhs IntegerExpression) IntegerExpression { return IntExp(Add(i.parent, rhs)) } diff --git a/internal/jet/integer_expression_test.go b/internal/jet/integer_expression_test.go index 79205454..a20981bc 100644 --- a/internal/jet/integer_expression_test.go +++ b/internal/jet/integer_expression_test.go @@ -99,3 +99,9 @@ func TestIntExpressionIntExp(t *testing.T) { assertClauseSerialize(t, IntExp(table1ColFloat.ADD(table2ColFloat)).ADD(Int(11)), "((table1.col_float + table2.col_float) + $1)", int64(11)) } + +func TestIntExpressionBetween(t *testing.T) { + assertClauseSerialize(t, table1ColInt.BETWEEN(Int(1), table1Col3), "(table1.col_int BETWEEN $1 AND table1.col3)", int64(1)) + assertClauseSerialize(t, table1ColInt.BETWEEN(Int(1), table1Col3).AND(table1ColBool), + "((table1.col_int BETWEEN $1 AND table1.col3) AND table1.col_bool)", int64(1)) +} diff --git a/internal/jet/logger.go b/internal/jet/logger.go index c900fc05..c177d9cf 100644 --- a/internal/jet/logger.go +++ b/internal/jet/logger.go @@ -1,6 +1,11 @@ package jet -import "context" +import ( + "context" + "runtime" + "strings" + "time" +) // PrintableStatement is a statement which sql query can be logged type PrintableStatement interface { @@ -8,7 +13,7 @@ type PrintableStatement interface { DebugSql() (query string) } -// LoggerFunc is a definition of a function user can implement to support automatic statement logging. +// LoggerFunc is a function user can implement to support automatic statement logging. type LoggerFunc func(ctx context.Context, statement PrintableStatement) var logger LoggerFunc @@ -17,3 +22,60 @@ var logger LoggerFunc func SetLoggerFunc(loggerFunc LoggerFunc) { logger = loggerFunc } + +func callLogger(ctx context.Context, statement Statement) { + if logger != nil { + logger(ctx, statement) + } +} + +// QueryInfo contains information about executed query +type QueryInfo struct { + Statement PrintableStatement + // Depending on how the statement is executed, RowsProcessed is: + // - Number of rows returned for Query() and QueryContext() methods + // - RowsAffected() for Exec() and ExecContext() methods + // - Always 0 for Rows() method. + RowsProcessed int64 + Duration time.Duration + Err error +} + +// QueryLoggerFunc is a function user can implement to retrieve more information about statement executed. +type QueryLoggerFunc func(ctx context.Context, info QueryInfo) + +var queryLoggerFunc QueryLoggerFunc + +// SetQueryLogger sets automatic query logging function. +func SetQueryLogger(loggerFunc QueryLoggerFunc) { + queryLoggerFunc = loggerFunc +} + +func callQueryLoggerFunc(ctx context.Context, info QueryInfo) { + if queryLoggerFunc != nil { + queryLoggerFunc(ctx, info) + } +} + +// Caller returns information about statement caller +func (q QueryInfo) Caller() (file string, line int, function string) { + skip := 4 + // depending on execution type (Query, QueryContext, Exec, ...) looped once or twice + for { + var pc uintptr + var ok bool + + pc, file, line, ok = runtime.Caller(skip) + if !ok { + return + } + + funcDetails := runtime.FuncForPC(pc) + if !strings.Contains(funcDetails.Name(), "github.com/go-jet/jet/v2/internal") { + function = funcDetails.Name() + return + } + + skip++ + } +} diff --git a/internal/jet/operators.go b/internal/jet/operators.go index 19173a62..b73a4518 100644 --- a/internal/jet/operators.go +++ b/internal/jet/operators.go @@ -173,3 +173,8 @@ func (c *caseOperatorImpl) serialize(statement StatementType, out *SQLBuilder, o out.WriteString("END)") } + +// DISTINCT operator can be used to return distinct values of expr +func DISTINCT(expr Expression) Expression { + return newPrefixOperatorExpression(expr, "DISTINCT") +} diff --git a/internal/jet/order_set_aggregate_functions.go b/internal/jet/order_set_aggregate_functions.go new file mode 100644 index 00000000..8ce5d1e1 --- /dev/null +++ b/internal/jet/order_set_aggregate_functions.go @@ -0,0 +1,60 @@ +package jet + +// MODE computes the most frequent value of the aggregated argument +func MODE() *OrderSetAggregateFunc { + return newOrderSetAggregateFunction("MODE", nil) +} + +// PERCENTILE_CONT computes a value corresponding to the specified fraction within the ordered set of +// aggregated argument values. This will interpolate between adjacent input items if needed. +func PERCENTILE_CONT(fraction FloatExpression) *OrderSetAggregateFunc { + return newOrderSetAggregateFunction("PERCENTILE_CONT", fraction) +} + +// PERCENTILE_DISC computes the first value within the ordered set of aggregated argument values whose position +// in the ordering equals or exceeds the specified fraction. The aggregated argument must be of a sortable type. +func PERCENTILE_DISC(fraction FloatExpression) *OrderSetAggregateFunc { + return newOrderSetAggregateFunction("PERCENTILE_DISC", fraction) +} + +// OrderSetAggregateFunc implementation of order set aggregate function +type OrderSetAggregateFunc struct { + name string + fraction FloatExpression + orderBy Window +} + +func newOrderSetAggregateFunction(name string, fraction FloatExpression) *OrderSetAggregateFunc { + return &OrderSetAggregateFunc{ + name: name, + fraction: fraction, + } +} + +// WITHIN_GROUP_ORDER_BY specifies ordered set of aggregated argument values +func (p *OrderSetAggregateFunc) WITHIN_GROUP_ORDER_BY(orderBy OrderByClause) Expression { + p.orderBy = ORDER_BY(orderBy) + return newOrderSetAggregateFuncExpression(*p) +} + +func newOrderSetAggregateFuncExpression(aggFunc OrderSetAggregateFunc) *orderSetAggregateFuncExpression { + ret := &orderSetAggregateFuncExpression{ + OrderSetAggregateFunc: aggFunc, + } + + ret.ExpressionInterfaceImpl.Parent = ret + + return ret +} + +type orderSetAggregateFuncExpression struct { + ExpressionInterfaceImpl + OrderSetAggregateFunc +} + +func (p *orderSetAggregateFuncExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + out.WriteString(p.name) + WRAP(p.fraction).serialize(statement, out, FallTrough(options)...) + out.WriteString("WITHIN GROUP") + p.orderBy.serialize(statement, out) +} diff --git a/internal/jet/projection.go b/internal/jet/projection.go index b85702fe..1b1c625e 100644 --- a/internal/jet/projection.go +++ b/internal/jet/projection.go @@ -1,5 +1,7 @@ package jet +import "strings" + // Projection is interface for all projection types. Types that can be part of, for instance SELECT clause. type Projection interface { serializeForProjection(statement StatementType, out *SQLBuilder) @@ -14,16 +16,68 @@ func SerializeForProjection(projection Projection, statementType StatementType, // ProjectionList is a redefined type, so that ProjectionList can be used as a Projection. type ProjectionList []Projection -func (cl ProjectionList) fromImpl(subQuery SelectTable) Projection { +func (pl ProjectionList) fromImpl(subQuery SelectTable) Projection { newProjectionList := ProjectionList{} - for _, projection := range cl { + for _, projection := range pl { newProjectionList = append(newProjectionList, projection.fromImpl(subQuery)) } return newProjectionList } -func (cl ProjectionList) serializeForProjection(statement StatementType, out *SQLBuilder) { - SerializeProjectionList(statement, cl, out) +func (pl ProjectionList) serializeForProjection(statement StatementType, out *SQLBuilder) { + SerializeProjectionList(statement, pl, out) +} + +// As will create new projection list where each column is wrapped with a new table alias. +// tableAlias should be in the form 'name' or 'name.*'. +// For instance: If projection list has a column 'Artist.Name', and tableAlias is 'Musician.*', returned projection list will +// have a column wrapped in alias 'Musician.Name'. +func (pl ProjectionList) As(tableAlias string) ProjectionList { + tableAlias = strings.TrimRight(tableAlias, ".*") + + newProjectionList := ProjectionList{} + + for _, projection := range pl { + switch p := projection.(type) { + case ProjectionList: + newProjectionList = append(newProjectionList, p.As(tableAlias)) + case ColumnExpression: + newProjectionList = append(newProjectionList, newAlias(p, tableAlias+"."+p.Name())) + case *alias: + newAlias := *p + _, columnName := extractTableAndColumnName(newAlias.alias) + newAlias.alias = tableAlias + "." + columnName + newProjectionList = append(newProjectionList, &newAlias) + } + } + + return newProjectionList +} + +// Except will create new projection list in which columns contained in excluded column names are removed +func (pl ProjectionList) Except(toExclude ...Column) ProjectionList { + excludedColumnList := UnwidColumnList(toExclude) + excludedColumnNames := map[string]bool{} + + for _, excludedColumn := range excludedColumnList { + excludedColumnNames[excludedColumn.Name()] = true + } + + var ret ProjectionList + + for _, projection := range pl { + switch p := projection.(type) { + case ProjectionList: + ret = append(ret, p.Except(toExclude...)) + case ColumnExpression: + if excludedColumnNames[p.Name()] { + continue + } + ret = append(ret, p) + } + } + + return ret } diff --git a/internal/jet/projection_test.go b/internal/jet/projection_test.go new file mode 100644 index 00000000..7728e15a --- /dev/null +++ b/internal/jet/projection_test.go @@ -0,0 +1,46 @@ +package jet + +import "testing" + +func TestProjectionAs(t *testing.T) { + projectionList := ProjectionList{ + table1Col3, + SUM(table1ColInt).AS("sum"), + SUM(table1ColInt).AS("table.sum"), + ProjectionList{ + table1ColBool, + AVG(table1ColInt).AS("avg"), + AVG(table1ColInt).AS("t.avg"), + }, + } + + aliasedProjectionList := projectionList.As("new_alias.*") + + assertProjectionSerialize(t, aliasedProjectionList, + `table1.col3 AS "new_alias.col3", +SUM(table1.col_int) AS "new_alias.sum", +SUM(table1.col_int) AS "new_alias.sum", +table1.col_bool AS "new_alias.col_bool", +AVG(table1.col_int) AS "new_alias.avg", +AVG(table1.col_int) AS "new_alias.avg"`) + + subQueryProjections := projectionList.fromImpl(NewSelectTable(nil, "subQuery")) + + assertProjectionSerialize(t, subQueryProjections, + `"subQuery"."table1.col3" AS "table1.col3", +"subQuery".sum AS "sum", +"subQuery"."table.sum" AS "table.sum", +"subQuery"."table1.col_bool" AS "table1.col_bool", +"subQuery".avg AS "avg", +"subQuery"."t.avg" AS "t.avg"`) + + aliasedSubQueryProjectionList := subQueryProjections.(ProjectionList).As("subAlias") + + assertProjectionSerialize(t, aliasedSubQueryProjectionList, + `"subQuery"."table1.col3" AS "subAlias.col3", +"subQuery".sum AS "subAlias.sum", +"subQuery"."table.sum" AS "subAlias.sum", +"subQuery"."table1.col_bool" AS "subAlias.col_bool", +"subQuery".avg AS "subAlias.avg", +"subQuery"."t.avg" AS "subAlias.avg"`) +} diff --git a/internal/jet/select_table.go b/internal/jet/select_table.go index 541992f9..c25fba36 100644 --- a/internal/jet/select_table.go +++ b/internal/jet/select_table.go @@ -2,38 +2,41 @@ package jet // SelectTable is interface for SELECT sub-queries type SelectTable interface { - Serializer + SerializerHasProjections Alias() string AllColumns() ProjectionList } type selectTableImpl struct { - selectStmt SerializerStatement - alias string + Statement SerializerHasProjections + alias string } // NewSelectTable func -func NewSelectTable(selectStmt SerializerStatement, alias string) selectTableImpl { - selectTable := selectTableImpl{selectStmt: selectStmt, alias: alias} +func NewSelectTable(selectStmt SerializerHasProjections, alias string) selectTableImpl { + selectTable := selectTableImpl{ + Statement: selectStmt, + alias: alias, + } + return selectTable } +func (s selectTableImpl) projections() ProjectionList { + return s.Statement.projections() +} + func (s selectTableImpl) Alias() string { return s.alias } func (s selectTableImpl) AllColumns() ProjectionList { - statementWithProjections, ok := s.selectStmt.(HasProjections) - if !ok { - return ProjectionList{} - } - - projectionList := statementWithProjections.projections().fromImpl(s) + projectionList := s.projections().fromImpl(s) return projectionList.(ProjectionList) } func (s selectTableImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - s.selectStmt.serialize(statement, out) + s.Statement.serialize(statement, out) out.WriteString("AS") out.WriteIdentifier(s.alias) @@ -52,7 +55,7 @@ func NewLateral(selectStmt SerializerStatement, alias string) SelectTable { func (s lateralImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { out.WriteString("LATERAL") - s.selectStmt.serialize(statement, out) + s.Statement.serialize(statement, out) out.WriteString("AS") out.WriteIdentifier(s.alias) diff --git a/internal/jet/statement.go b/internal/jet/statement.go index 1d050459..b2058017 100644 --- a/internal/jet/statement.go +++ b/internal/jet/statement.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "github.com/go-jet/jet/v2/qrm" + "time" ) //Statement is common interface for all statements(SELECT, INSERT, UPDATE, DELETE, LOCK) @@ -21,9 +22,9 @@ type Statement interface { // Destination can be either pointer to struct or pointer to a slice. // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. QueryContext(ctx context.Context, db qrm.DB, destination interface{}) error - //Exec executes statement over db connection/transaction without returning any rows. + // Exec executes statement over db connection/transaction without returning any rows. Exec(db qrm.DB) (sql.Result, error) - //Exec executes statement with context over db connection/transaction without returning any rows. + // ExecContext executes statement with context over db connection/transaction without returning any rows. ExecContext(ctx context.Context, db qrm.DB) (sql.Result, error) // Rows executes statements over db connection/transaction and returns rows Rows(ctx context.Context, db qrm.DB) (*Rows, error) @@ -51,6 +52,12 @@ type HasProjections interface { projections() ProjectionList } +// SerializerHasProjections interface is combination of Serializer and HasProjections interface +type SerializerHasProjections interface { + Serializer + HasProjections +} + // serializerStatementInterfaceImpl struct type serializerStatementInterfaceImpl struct { dialect Dialect @@ -78,12 +85,7 @@ func (s *serializerStatementInterfaceImpl) DebugSql() (query string) { } func (s *serializerStatementInterfaceImpl) Query(db qrm.DB, destination interface{}) error { - query, args := s.Sql() - ctx := context.Background() - - callLogger(ctx, s) - - return qrm.Query(ctx, db, query, args, destination) + return s.QueryContext(context.Background(), db, destination) } func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db qrm.DB, destination interface{}) error { @@ -91,15 +93,25 @@ func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db callLogger(ctx, s) - return qrm.Query(ctx, db, query, args, destination) -} + var rowsProcessed int64 + var err error -func (s *serializerStatementInterfaceImpl) Exec(db qrm.DB) (res sql.Result, err error) { - query, args := s.Sql() + duration := duration(func() { + rowsProcessed, err = qrm.Query(ctx, db, query, args, destination) + }) + + callQueryLoggerFunc(ctx, QueryInfo{ + Statement: s, + RowsProcessed: rowsProcessed, + Duration: duration, + Err: err, + }) - callLogger(context.Background(), s) + return err +} - return db.Exec(query, args...) +func (s *serializerStatementInterfaceImpl) Exec(db qrm.DB) (res sql.Result, err error) { + return s.ExecContext(context.Background(), db) } func (s *serializerStatementInterfaceImpl) ExecContext(ctx context.Context, db qrm.DB) (res sql.Result, err error) { @@ -107,7 +119,24 @@ func (s *serializerStatementInterfaceImpl) ExecContext(ctx context.Context, db q callLogger(ctx, s) - return db.ExecContext(ctx, query, args...) + duration := duration(func() { + res, err = db.ExecContext(ctx, query, args...) + }) + + var rowsAffected int64 + + if err == nil { + rowsAffected, _ = res.RowsAffected() + } + + callQueryLoggerFunc(ctx, QueryInfo{ + Statement: s, + RowsProcessed: rowsAffected, + Duration: duration, + Err: err, + }) + + return res, err } func (s *serializerStatementInterfaceImpl) Rows(ctx context.Context, db qrm.DB) (*Rows, error) { @@ -115,7 +144,18 @@ func (s *serializerStatementInterfaceImpl) Rows(ctx context.Context, db qrm.DB) callLogger(ctx, s) - rows, err := db.QueryContext(ctx, query, args...) + var rows *sql.Rows + var err error + + duration := duration(func() { + rows, err = db.QueryContext(ctx, query, args...) + }) + + callQueryLoggerFunc(ctx, QueryInfo{ + Statement: s, + Duration: duration, + Err: err, + }) if err != nil { return nil, err @@ -124,10 +164,12 @@ func (s *serializerStatementInterfaceImpl) Rows(ctx context.Context, db qrm.DB) return &Rows{rows}, nil } -func callLogger(ctx context.Context, statement Statement) { - if logger != nil { - logger(ctx, statement) - } +func duration(f func()) time.Duration { + start := time.Now() + + f() + + return time.Now().Sub(start) } // ExpressionStatement interfacess @@ -200,7 +242,7 @@ func (s *statementImpl) serialize(statement StatementType, out *SQLBuilder, opti } for _, clause := range s.Clauses { - clause.Serialize(statement, out, FallTrough(options)...) + clause.Serialize(s.statementType, out, FallTrough(options)...) } if contains(options, Ident) { diff --git a/internal/jet/string_expression.go b/internal/jet/string_expression.go index 3c568961..4e7efa62 100644 --- a/internal/jet/string_expression.go +++ b/internal/jet/string_expression.go @@ -13,6 +13,8 @@ type StringExpression interface { LT_EQ(rhs StringExpression) BoolExpression GT(rhs StringExpression) BoolExpression GT_EQ(rhs StringExpression) BoolExpression + BETWEEN(min, max StringExpression) BoolExpression + NOT_BETWEEN(min, max StringExpression) BoolExpression CONCAT(rhs Expression) StringExpression @@ -59,6 +61,14 @@ func (s *stringInterfaceImpl) LT_EQ(rhs StringExpression) BoolExpression { return LtEq(s.parent, rhs) } +func (s *stringInterfaceImpl) BETWEEN(min, max StringExpression) BoolExpression { + return NewBetweenOperatorExpression(s.parent, min, max, false) +} + +func (s *stringInterfaceImpl) NOT_BETWEEN(min, max StringExpression) BoolExpression { + return NewBetweenOperatorExpression(s.parent, min, max, true) +} + func (s *stringInterfaceImpl) CONCAT(rhs Expression) StringExpression { return newBinaryStringOperatorExpression(s.parent, rhs, StringConcatOperator) } diff --git a/internal/jet/time_expression.go b/internal/jet/time_expression.go index 4fd7047c..efd146f2 100644 --- a/internal/jet/time_expression.go +++ b/internal/jet/time_expression.go @@ -13,6 +13,8 @@ type TimeExpression interface { LT_EQ(rhs TimeExpression) BoolExpression GT(rhs TimeExpression) BoolExpression GT_EQ(rhs TimeExpression) BoolExpression + BETWEEN(min, max TimeExpression) BoolExpression + NOT_BETWEEN(min, max TimeExpression) BoolExpression ADD(rhs Interval) TimeExpression SUB(rhs Interval) TimeExpression @@ -54,6 +56,14 @@ func (t *timeInterfaceImpl) GT_EQ(rhs TimeExpression) BoolExpression { return GtEq(t.parent, rhs) } +func (t *timeInterfaceImpl) BETWEEN(min, max TimeExpression) BoolExpression { + return NewBetweenOperatorExpression(t.parent, min, max, false) +} + +func (t *timeInterfaceImpl) NOT_BETWEEN(min, max TimeExpression) BoolExpression { + return NewBetweenOperatorExpression(t.parent, min, max, true) +} + func (t *timeInterfaceImpl) ADD(rhs Interval) TimeExpression { return TimeExp(Add(t.parent, rhs)) } diff --git a/internal/jet/timestamp_expression.go b/internal/jet/timestamp_expression.go index f4cdd0b0..1013ce17 100644 --- a/internal/jet/timestamp_expression.go +++ b/internal/jet/timestamp_expression.go @@ -13,6 +13,8 @@ type TimestampExpression interface { LT_EQ(rhs TimestampExpression) BoolExpression GT(rhs TimestampExpression) BoolExpression GT_EQ(rhs TimestampExpression) BoolExpression + BETWEEN(min, max TimestampExpression) BoolExpression + NOT_BETWEEN(min, max TimestampExpression) BoolExpression ADD(rhs Interval) TimestampExpression SUB(rhs Interval) TimestampExpression @@ -54,6 +56,14 @@ func (t *timestampInterfaceImpl) GT_EQ(rhs TimestampExpression) BoolExpression { return GtEq(t.parent, rhs) } +func (t *timestampInterfaceImpl) BETWEEN(min, max TimestampExpression) BoolExpression { + return NewBetweenOperatorExpression(t.parent, min, max, false) +} + +func (t *timestampInterfaceImpl) NOT_BETWEEN(min, max TimestampExpression) BoolExpression { + return NewBetweenOperatorExpression(t.parent, min, max, true) +} + func (t *timestampInterfaceImpl) ADD(rhs Interval) TimestampExpression { return TimestampExp(Add(t.parent, rhs)) } diff --git a/internal/jet/timestampz_expression.go b/internal/jet/timestampz_expression.go index 0112a3c3..b8fe1035 100644 --- a/internal/jet/timestampz_expression.go +++ b/internal/jet/timestampz_expression.go @@ -13,6 +13,8 @@ type TimestampzExpression interface { LT_EQ(rhs TimestampzExpression) BoolExpression GT(rhs TimestampzExpression) BoolExpression GT_EQ(rhs TimestampzExpression) BoolExpression + BETWEEN(min, max TimestampzExpression) BoolExpression + NOT_BETWEEN(min, max TimestampzExpression) BoolExpression ADD(rhs Interval) TimestampzExpression SUB(rhs Interval) TimestampzExpression @@ -54,6 +56,14 @@ func (t *timestampzInterfaceImpl) GT_EQ(rhs TimestampzExpression) BoolExpression return GtEq(t.parent, rhs) } +func (t *timestampzInterfaceImpl) BETWEEN(min, max TimestampzExpression) BoolExpression { + return NewBetweenOperatorExpression(t.parent, min, max, false) +} + +func (t *timestampzInterfaceImpl) NOT_BETWEEN(min, max TimestampzExpression) BoolExpression { + return NewBetweenOperatorExpression(t.parent, min, max, true) +} + func (t *timestampzInterfaceImpl) ADD(rhs Interval) TimestampzExpression { return TimestampzExp(Add(t.parent, rhs)) } diff --git a/internal/jet/timez_expression.go b/internal/jet/timez_expression.go index d36ec809..8896dcec 100644 --- a/internal/jet/timez_expression.go +++ b/internal/jet/timez_expression.go @@ -13,6 +13,8 @@ type TimezExpression interface { LT_EQ(rhs TimezExpression) BoolExpression GT(rhs TimezExpression) BoolExpression GT_EQ(rhs TimezExpression) BoolExpression + BETWEEN(min, max TimezExpression) BoolExpression + NOT_BETWEEN(min, max TimezExpression) BoolExpression ADD(rhs Interval) TimezExpression SUB(rhs Interval) TimezExpression @@ -54,6 +56,14 @@ func (t *timezInterfaceImpl) GT_EQ(rhs TimezExpression) BoolExpression { return GtEq(t.parent, rhs) } +func (t *timezInterfaceImpl) BETWEEN(min, max TimezExpression) BoolExpression { + return NewBetweenOperatorExpression(t.parent, min, max, false) +} + +func (t *timezInterfaceImpl) NOT_BETWEEN(min, max TimezExpression) BoolExpression { + return NewBetweenOperatorExpression(t.parent, min, max, true) +} + func (t *timezInterfaceImpl) ADD(rhs Interval) TimezExpression { return TimezExp(Add(t.parent, rhs)) } diff --git a/internal/jet/utils.go b/internal/jet/utils.go index eab44030..524c2c53 100644 --- a/internal/jet/utils.go +++ b/internal/jet/utils.go @@ -3,6 +3,7 @@ package jet import ( "github.com/go-jet/jet/v2/internal/utils" "reflect" + "strings" ) // SerializeClauseList func @@ -33,7 +34,9 @@ func serializeExpressionList( out.WriteString(separator) } - expression.serialize(statement, out, options...) + if expression != nil { + expression.serialize(statement, out, options...) + } } } @@ -68,8 +71,8 @@ func SerializeColumnNames(columns []Column, out *SQLBuilder) { } } -// SerializeColumnExpressionNames func -func SerializeColumnExpressionNames(columns []ColumnExpression, statementType StatementType, +// SerializeColumnExpressions func +func SerializeColumnExpressions(columns []ColumnExpression, statementType StatementType, out *SQLBuilder, options ...SerializeOption) { for i, col := range columns { if i > 0 { @@ -84,6 +87,21 @@ func SerializeColumnExpressionNames(columns []ColumnExpression, statementType St } } +// SerializeColumnExpressionNames func +func SerializeColumnExpressionNames(columns []ColumnExpression, out *SQLBuilder) { + for i, col := range columns { + if i > 0 { + out.WriteString(", ") + } + + if col == nil { + panic("jet: nil column in columns list") + } + + out.WriteIdentifier(col.Name()) + } +} + // ExpressionListToSerializerList converts list of expressions to list of serializers func ExpressionListToSerializerList(expressions []Expression) []Serializer { var ret []Serializer @@ -229,3 +247,22 @@ func OptionalOrDefaultExpression(defaultExpression Expression, expression ...Exp return defaultExpression } + +func extractTableAndColumnName(alias string) (tableName string, columnName string) { + parts := strings.Split(alias, ".") + + if len(parts) >= 2 { + tableName = parts[0] + columnName = parts[1] + } else { + columnName = parts[0] + } + + return +} + +func serializeToDefaultDebugString(expr Serializer) string { + out := SQLBuilder{Dialect: defaultDialect, Debug: true} + expr.serialize(SelectStatementType, &out) + return out.Buff.String() +} diff --git a/internal/jet/with_statement.go b/internal/jet/with_statement.go index ab570679..783fa274 100644 --- a/internal/jet/with_statement.go +++ b/internal/jet/with_statement.go @@ -1,9 +1,12 @@ package jet +import "fmt" + // WITH function creates new with statement from list of common table expressions for specified dialect -func WITH(dialect Dialect, cte ...CommonTableExpressionDefinition) func(statement Statement) Statement { +func WITH(dialect Dialect, recursive bool, cte ...*CommonTableExpression) func(statement Statement) Statement { newWithImpl := &withImpl{ - ctes: cte, + recursive: recursive, + ctes: cte, serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{ dialect: dialect, statementType: WithStatementType, @@ -23,7 +26,8 @@ func WITH(dialect Dialect, cte ...CommonTableExpressionDefinition) func(statemen type withImpl struct { serializerStatementInterfaceImpl - ctes []CommonTableExpressionDefinition + recursive bool + ctes []*CommonTableExpression primaryStatement SerializerStatement } @@ -31,6 +35,10 @@ func (w withImpl) serialize(statement StatementType, out *SQLBuilder, options .. out.NewLine() out.WriteString("WITH") + if w.recursive { + out.WriteString("RECURSIVE") + } + for i, cte := range w.ctes { if i > 0 { out.WriteString(",") @@ -48,35 +56,55 @@ func (w withImpl) projections() ProjectionList { // CommonTableExpression contains information about a CTE. type CommonTableExpression struct { selectTableImpl + + NotMaterialized bool + Columns []ColumnExpression } // CTE creates new named CommonTableExpression -func CTE(name string) CommonTableExpression { - return CommonTableExpression{ - selectTableImpl: selectTableImpl{ - selectStmt: nil, - alias: name, - }, +func CTE(name string, columns ...ColumnExpression) CommonTableExpression { + cte := CommonTableExpression{ + selectTableImpl: NewSelectTable(nil, name), + Columns: columns, } + + for _, column := range cte.Columns { + column.setSubQuery(cte) + } + + return cte } func (c CommonTableExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - out.WriteIdentifier(c.alias) -} + if statement == WithStatementType { // serialize CTE definition + out.WriteIdentifier(c.alias) + if len(c.Columns) > 0 { + out.WriteByte('(') + SerializeColumnExpressionNames(c.Columns, out) + out.WriteByte(')') + } + out.WriteString("AS") -// AS returns sets definition for a CTE -func (c *CommonTableExpression) AS(statement SerializerStatement) CommonTableExpressionDefinition { - c.selectStmt = statement - return CommonTableExpressionDefinition{cte: c} -} + if c.NotMaterialized { + out.WriteString("NOT MATERIALIZED") + } + + if c.Statement == nil { + panic(fmt.Sprintf("jet: '%s' CTE is not defined", c.alias)) + } -// CommonTableExpressionDefinition contains implementation details of CTE -type CommonTableExpressionDefinition struct { - cte *CommonTableExpression + c.Statement.serialize(statement, out, FallTrough(options)...) + + } else { // serialize CTE in FROM clause + out.WriteIdentifier(c.alias) + } } -func (c CommonTableExpressionDefinition) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - out.WriteIdentifier(c.cte.alias) - out.WriteString("AS") - c.cte.selectStmt.serialize(statement, out, FallTrough(options)...) +// AllColumns returns list of all projections in the CTE +func (c CommonTableExpression) AllColumns() ProjectionList { + if len(c.Columns) > 0 { + return ColumnListToProjectionList(c.Columns) + } + + return c.selectTableImpl.AllColumns() } diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index c1419aa0..cac4a623 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -56,6 +56,12 @@ func PrintJson(v interface{}) { fmt.Println(string(jsonText)) } +// ToJSON converts v into json string +func ToJSON(v interface{}) string { + jsonText, _ := json.MarshalIndent(v, "", "\t") + return string(jsonText) +} + // AssertJSON check if data json output is the same as expectedJSON func AssertJSON(t *testing.T, data interface{}, expectedJSON string) { jsonData, err := json.MarshalIndent(data, "", "\t") diff --git a/mysql/delete_statement.go b/mysql/delete_statement.go index 7f0fe6f9..0d39cde5 100644 --- a/mysql/delete_statement.go +++ b/mysql/delete_statement.go @@ -6,6 +6,7 @@ import "github.com/go-jet/jet/v2/internal/jet" type DeleteStatement interface { Statement + USING(tables ...ReadableTable) DeleteStatement WHERE(expression BoolExpression) DeleteStatement ORDER_BY(orderByClauses ...OrderByClause) DeleteStatement LIMIT(limit int64) DeleteStatement @@ -15,6 +16,7 @@ type deleteStatementImpl struct { jet.SerializerStatement Delete jet.ClauseStatementBegin + Using jet.ClauseFrom Where jet.ClauseWhere OrderBy jet.ClauseOrderBy Limit jet.ClauseLimit @@ -22,10 +24,15 @@ type deleteStatementImpl struct { func newDeleteStatement(table Table) DeleteStatement { newDelete := &deleteStatementImpl{} - newDelete.SerializerStatement = jet.NewStatementImpl(Dialect, jet.DeleteStatementType, newDelete, &newDelete.Delete, - &newDelete.Where, &newDelete.OrderBy, &newDelete.Limit) + newDelete.SerializerStatement = jet.NewStatementImpl(Dialect, jet.DeleteStatementType, newDelete, + &newDelete.Delete, + &newDelete.Using, + &newDelete.Where, + &newDelete.OrderBy, + &newDelete.Limit) newDelete.Delete.Name = "DELETE FROM" + newDelete.Using.Name = "USING" newDelete.Delete.Tables = append(newDelete.Delete.Tables, table) newDelete.Where.Mandatory = true newDelete.Limit.Count = -1 @@ -33,6 +40,11 @@ func newDeleteStatement(table Table) DeleteStatement { return newDelete } +func (d *deleteStatementImpl) USING(tables ...ReadableTable) DeleteStatement { + d.Using.Tables = readableTablesToSerializerList(tables) + return d +} + func (d *deleteStatementImpl) WHERE(expression BoolExpression) DeleteStatement { d.Where.Condition = expression return d diff --git a/mysql/expressions.go b/mysql/expressions.go index b5857197..5ab9c1f7 100644 --- a/mysql/expressions.go +++ b/mysql/expressions.go @@ -87,7 +87,7 @@ var ( RawDate = jet.RawDate ) -// Func can be used to call an custom or as of yet unsupported function in the database. +// Func can be used to call custom or unsupported database functions. var Func = jet.Func // NewEnumValue creates new named enum value diff --git a/mysql/operators.go b/mysql/operators.go index 55855118..9591a920 100644 --- a/mysql/operators.go +++ b/mysql/operators.go @@ -7,3 +7,6 @@ var NOT = jet.NOT // BIT_NOT inverts every bit in integer expression result var BIT_NOT = jet.BIT_NOT + +// DISTINCT operator can be used to return distinct values of expr +var DISTINCT = jet.DISTINCT diff --git a/mysql/select_statement.go b/mysql/select_statement.go index ffb8054f..1c3a88a1 100644 --- a/mysql/select_statement.go +++ b/mysql/select_statement.go @@ -58,7 +58,7 @@ type SelectStatement interface { AsTable(alias string) SelectTable } -//SELECT creates new SelectStatement with list of projections +// SELECT creates new SelectStatement with list of projections func SELECT(projection Projection, projections ...Projection) SelectStatement { return newSelectStatement(nil, append([]Projection{projection}, projections...)) } @@ -106,10 +106,7 @@ func (s *selectStatementImpl) DISTINCT() SelectStatement { } func (s *selectStatementImpl) FROM(tables ...ReadableTable) SelectStatement { - s.From.Tables = nil - for _, table := range tables { - s.From.Tables = append(s.From.Tables, table) - } + s.From.Tables = readableTablesToSerializerList(tables) return s } @@ -189,3 +186,11 @@ func toJetFrameOffset(offset interface{}) jet.Serializer { return jet.FixedLiteral(offset) } + +func readableTablesToSerializerList(tables []ReadableTable) []jet.Serializer { + var ret []jet.Serializer + for _, table := range tables { + ret = append(ret, table) + } + return ret +} diff --git a/mysql/select_statement_test.go b/mysql/select_statement_test.go index 630f8006..37827d5f 100644 --- a/mysql/select_statement_test.go +++ b/mysql/select_statement_test.go @@ -147,10 +147,10 @@ func TestSelect_NOT_EXISTS(t *testing.T) { ))), ` SELECT table1.col_int AS "table1.col_int" FROM db.table1 -WHERE (NOT (EXISTS ( +WHERE NOT (EXISTS ( SELECT table2.col_int AS "table2.col_int" FROM db.table2 WHERE table1.col_int = table2.col_int - ))); + )); `) } diff --git a/mysql/select_table.go b/mysql/select_table.go index af9de27d..ad221934 100644 --- a/mysql/select_table.go +++ b/mysql/select_table.go @@ -13,7 +13,7 @@ type selectTableImpl struct { readableTableInterfaceImpl } -func newSelectTable(selectStmt jet.SerializerStatement, alias string) SelectTable { +func newSelectTable(selectStmt jet.SerializerHasProjections, alias string) SelectTable { subQuery := &selectTableImpl{ SelectTable: jet.NewSelectTable(selectStmt, alias), } diff --git a/mysql/types.go b/mysql/types.go index c82962fb..6e29b67f 100644 --- a/mysql/types.go +++ b/mysql/types.go @@ -24,4 +24,11 @@ type OrderByClause = jet.OrderByClause type GroupByClause = jet.GroupByClause // SetLogger sets automatic statement logging +// Deprecated: use SetQueryLogger instead. var SetLogger = jet.SetLoggerFunc + +// SetQueryLogger sets automatic query logging function. +var SetQueryLogger = jet.SetQueryLogger + +// QueryInfo contains information about executed query +type QueryInfo = jet.QueryInfo diff --git a/mysql/with_statement.go b/mysql/with_statement.go index 1afcf26a..ca608cbf 100644 --- a/mysql/with_statement.go +++ b/mysql/with_statement.go @@ -2,25 +2,65 @@ package mysql import "github.com/go-jet/jet/v2/internal/jet" -// CommonTableExpression contains information about a CTE. -type CommonTableExpression struct { +// CommonTableExpression defines set of interface methods for postgres CTEs +type CommonTableExpression interface { + SelectTable + + AS(statement jet.SerializerStatement) CommonTableExpression + // ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query. + ALIAS(alias string) SelectTable + + internalCTE() *jet.CommonTableExpression +} + +type commonTableExpression struct { readableTableInterfaceImpl jet.CommonTableExpression } // WITH function creates new WITH statement from list of common table expressions -func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement { - return jet.WITH(Dialect, cte...) +func WITH(cte ...CommonTableExpression) func(statement jet.Statement) Statement { + return jet.WITH(Dialect, false, toInternalCTE(cte)...) } -// CTE creates new named CommonTableExpression -func CTE(name string) CommonTableExpression { - cte := CommonTableExpression{ +// WITH_RECURSIVE function creates new WITH RECURSIVE statement from list of common table expressions +func WITH_RECURSIVE(cte ...CommonTableExpression) func(statement jet.Statement) Statement { + return jet.WITH(Dialect, true, toInternalCTE(cte)...) +} + +// CTE creates new named commonTableExpression +func CTE(name string, columns ...jet.ColumnExpression) CommonTableExpression { + cte := &commonTableExpression{ readableTableInterfaceImpl: readableTableInterfaceImpl{}, - CommonTableExpression: jet.CTE(name), + CommonTableExpression: jet.CTE(name, columns...), } - cte.parent = &cte + cte.parent = cte return cte } + +// AS is used to define a CTE query +func (c *commonTableExpression) AS(statement jet.SerializerStatement) CommonTableExpression { + c.CommonTableExpression.Statement = statement + return c +} + +func (c *commonTableExpression) internalCTE() *jet.CommonTableExpression { + return &c.CommonTableExpression +} + +// ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query. +func (c *commonTableExpression) ALIAS(name string) SelectTable { + return newSelectTable(c, name) +} + +func toInternalCTE(ctes []CommonTableExpression) []*jet.CommonTableExpression { + var ret []*jet.CommonTableExpression + + for _, cte := range ctes { + ret = append(ret, cte.internalCTE()) + } + + return ret +} diff --git a/postgres/clause.go b/postgres/clause.go index 3a23fd07..0953e268 100644 --- a/postgres/clause.go +++ b/postgres/clause.go @@ -52,7 +52,7 @@ func (o *onConflictClause) Serialize(statementType jet.StatementType, out *jet.S out.WriteString("ON CONFLICT") if len(o.indexExpressions) > 0 { out.WriteString("(") - jet.SerializeColumnExpressionNames(o.indexExpressions, statementType, out, jet.ShortName) + jet.SerializeColumnExpressions(o.indexExpressions, statementType, out, jet.ShortName) out.WriteString(")") } diff --git a/postgres/clause_test.go b/postgres/clause_test.go index 5602505f..28be3152 100644 --- a/postgres/clause_test.go +++ b/postgres/clause_test.go @@ -29,7 +29,7 @@ ON CONFLICT (col_bool) ON CONSTRAINT table_pkey DO NOTHING`) ) assertClauseSerialize(t, onConflict, ` ON CONFLICT (col_bool, col_float) WHERE (col_float + col_int) > col_float DO UPDATE - SET col_bool = $1, + SET col_bool = $1::boolean, col_int = $2 WHERE table2.col_float > $3`) } diff --git a/postgres/delete_statement.go b/postgres/delete_statement.go index 2bfbd8c4..e4ecc49b 100644 --- a/postgres/delete_statement.go +++ b/postgres/delete_statement.go @@ -6,8 +6,8 @@ import "github.com/go-jet/jet/v2/internal/jet" type DeleteStatement interface { jet.SerializerStatement + USING(tables ...ReadableTable) DeleteStatement WHERE(expression BoolExpression) DeleteStatement - RETURNING(projections ...jet.Projection) DeleteStatement } @@ -15,22 +15,32 @@ type deleteStatementImpl struct { jet.SerializerStatement Delete jet.ClauseStatementBegin + Using jet.ClauseFrom Where jet.ClauseWhere Returning jet.ClauseReturning } func newDeleteStatement(table WritableTable) DeleteStatement { newDelete := &deleteStatementImpl{} - newDelete.SerializerStatement = jet.NewStatementImpl(Dialect, jet.DeleteStatementType, newDelete, &newDelete.Delete, - &newDelete.Where, &newDelete.Returning) + newDelete.SerializerStatement = jet.NewStatementImpl(Dialect, jet.DeleteStatementType, newDelete, + &newDelete.Delete, + &newDelete.Using, + &newDelete.Where, + &newDelete.Returning) newDelete.Delete.Name = "DELETE FROM" newDelete.Delete.Tables = append(newDelete.Delete.Tables, table) + newDelete.Using.Name = "USING" newDelete.Where.Mandatory = true return newDelete } +func (d *deleteStatementImpl) USING(tables ...ReadableTable) DeleteStatement { + d.Using.Tables = readableTablesToSerializerList(tables) + return d +} + func (d *deleteStatementImpl) WHERE(expression BoolExpression) DeleteStatement { d.Where.Condition = expression return d diff --git a/postgres/dialect_test.go b/postgres/dialect_test.go index d98d8f3a..45ed7396 100644 --- a/postgres/dialect_test.go +++ b/postgres/dialect_test.go @@ -33,7 +33,7 @@ func TestExists(t *testing.T) { ).EQ(Bool(true)), `((EXISTS ( SELECT $1 -)) = $2)`, int64(1), true) +)) = $2::boolean)`, int64(1), true) assertProjectionSerialize(t, EXISTS( SELECT(Int(1)), diff --git a/postgres/expressions.go b/postgres/expressions.go index 8d0be889..c5c20653 100644 --- a/postgres/expressions.go +++ b/postgres/expressions.go @@ -100,7 +100,7 @@ var ( RawDate = jet.RawDate ) -// Func can be used to call an custom or as of yet unsupported function in the database. +// Func can be used to call custom or unsupported database functions. var Func = jet.Func // NewEnumValue creates new named enum value diff --git a/postgres/functions.go b/postgres/functions.go index 34f8370e..a20d1e11 100644 --- a/postgres/functions.go +++ b/postgres/functions.go @@ -336,3 +336,25 @@ func explicitLiteralCast(expresion Expression) jet.Expression { return expresion } + +// MODE computes the most frequent value of the aggregated argument +var MODE = jet.MODE + +// PERCENTILE_CONT computes a value corresponding to the specified fraction within the ordered set of +// aggregated argument values. This will interpolate between adjacent input items if needed. +func PERCENTILE_CONT(fraction FloatExpression) *jet.OrderSetAggregateFunc { + return jet.PERCENTILE_CONT(castFloatLiteral(fraction)) +} + +// PERCENTILE_DISC computes the first value within the ordered set of aggregated argument values whose position +// in the ordering equals or exceeds the specified fraction. The aggregated argument must be of a sortable type. +func PERCENTILE_DISC(fraction FloatExpression) *jet.OrderSetAggregateFunc { + return jet.PERCENTILE_DISC(castFloatLiteral(fraction)) +} + +func castFloatLiteral(fraction FloatExpression) FloatExpression { + if _, ok := fraction.(jet.LiteralExpression); ok { + return CAST(fraction).AS_DOUBLE() // to make postgres aware of the type + } + return fraction +} diff --git a/postgres/functions_test.go b/postgres/functions_test.go new file mode 100644 index 00000000..4190f703 --- /dev/null +++ b/postgres/functions_test.go @@ -0,0 +1,12 @@ +package postgres + +import "testing" + +func TestROW(t *testing.T) { + assertSerialize(t, ROW(SELECT(Int(1))), `ROW(( + SELECT $1 +))`) + assertSerialize(t, ROW(Int(1), SELECT(Int(2)), Float(11.11)), `ROW($1, ( + SELECT $2 +), $3)`) +} diff --git a/postgres/insert_statement_test.go b/postgres/insert_statement_test.go index ad687b5f..3ec333e0 100644 --- a/postgres/insert_statement_test.go +++ b/postgres/insert_statement_test.go @@ -165,7 +165,7 @@ VALUES ('one', 'two'), ('1', '2'), ('theta', 'beta') ON CONFLICT (col_bool) WHERE col_bool IS NOT FALSE DO UPDATE - SET col_bool = TRUE, + SET col_bool = TRUE::boolean, col_int = 1, (col1, col_bool) = ROW(2, 'two') WHERE table1.col1 > 2 @@ -191,7 +191,7 @@ INSERT INTO db.table1 (col1, col_bool) VALUES ('one', 'two'), ('1', '2') ON CONFLICT ON CONSTRAINT idk_primary_key DO UPDATE - SET col_bool = FALSE, + SET col_bool = FALSE::boolean, col_int = 1, (col1, col_bool) = ROW(2, 'two') WHERE table1.col1 > 2 diff --git a/postgres/interval_expression.go b/postgres/interval_expression.go index 6f5ab586..68d33bc1 100644 --- a/postgres/interval_expression.go +++ b/postgres/interval_expression.go @@ -41,6 +41,8 @@ type IntervalExpression interface { LT_EQ(rhs IntervalExpression) BoolExpression GT(rhs IntervalExpression) BoolExpression GT_EQ(rhs IntervalExpression) BoolExpression + BETWEEN(min, max IntervalExpression) BoolExpression + NOT_BETWEEN(min, max IntervalExpression) BoolExpression ADD(rhs IntervalExpression) IntervalExpression SUB(rhs IntervalExpression) IntervalExpression @@ -87,6 +89,14 @@ func (i *intervalInterfaceImpl) GT_EQ(rhs IntervalExpression) BoolExpression { return jet.GtEq(i.parent, rhs) } +func (i *intervalInterfaceImpl) BETWEEN(min, max IntervalExpression) BoolExpression { + return jet.NewBetweenOperatorExpression(i.parent, min, max, false) +} + +func (i *intervalInterfaceImpl) NOT_BETWEEN(min, max IntervalExpression) BoolExpression { + return jet.NewBetweenOperatorExpression(i.parent, min, max, true) +} + func (i *intervalInterfaceImpl) ADD(rhs IntervalExpression) IntervalExpression { return IntervalExp(jet.Add(i.parent, rhs)) } diff --git a/postgres/interval_expression_test.go b/postgres/interval_expression_test.go index f2fa9fe5..8d6e6474 100644 --- a/postgres/interval_expression_test.go +++ b/postgres/interval_expression_test.go @@ -67,7 +67,7 @@ func TestIntervalExpressionMethods(t *testing.T) { assertSerialize(t, table1ColInterval.EQ(INTERVAL(10, SECOND)), "(table1.col_interval = INTERVAL '10 SECOND')") assertSerialize(t, table1ColInterval.EQ(INTERVALd(11*time.Minute)), "(table1.col_interval = INTERVAL '11 MINUTE')") assertSerialize(t, table1ColInterval.EQ(INTERVALd(11*time.Minute)).EQ(Bool(false)), - "((table1.col_interval = INTERVAL '11 MINUTE') = $1)", false) + "((table1.col_interval = INTERVAL '11 MINUTE') = $1::boolean)", false) assertSerialize(t, table1ColInterval.NOT_EQ(table2ColInterval), "(table1.col_interval != table2.col_interval)") assertSerialize(t, table1ColInterval.IS_DISTINCT_FROM(table2ColInterval), "(table1.col_interval IS DISTINCT FROM table2.col_interval)") assertSerialize(t, table1ColInterval.IS_NOT_DISTINCT_FROM(table2ColInterval), "(table1.col_interval IS NOT DISTINCT FROM table2.col_interval)") diff --git a/postgres/literal.go b/postgres/literal.go index 8ee32352..e46b874d 100644 --- a/postgres/literal.go +++ b/postgres/literal.go @@ -6,35 +6,53 @@ import ( "github.com/go-jet/jet/v2/internal/jet" ) -// Bool creates new bool literal expression -var Bool = jet.Bool +// Bool is boolean literal constructor +func Bool(value bool) BoolExpression { + return CAST(jet.Bool(value)).AS_BOOL() +} // Int is constructor for 64 bit signed integer expressions literals. var Int = jet.Int // Int8 is constructor for 8 bit signed integer expressions literals. -var Int8 = jet.Int8 +func Int8(value int8) IntegerExpression { + return CAST(jet.Int8(value)).AS_SMALLINT() +} // Int16 is constructor for 16 bit signed integer expressions literals. -var Int16 = jet.Int16 +func Int16(value int16) IntegerExpression { + return CAST(jet.Int16(value)).AS_SMALLINT() +} // Int32 is constructor for 32 bit signed integer expressions literals. -var Int32 = jet.Int32 +func Int32(value int32) IntegerExpression { + return CAST(jet.Int32(value)).AS_INTEGER() +} // Int64 is constructor for 64 bit signed integer expressions literals. -var Int64 = jet.Int +func Int64(value int64) IntegerExpression { + return CAST(jet.Int(value)).AS_BIGINT() +} // Uint8 is constructor for 8 bit unsigned integer expressions literals. -var Uint8 = jet.Uint8 +func Uint8(value uint8) IntegerExpression { + return CAST(jet.Uint8(value)).AS_SMALLINT() +} // Uint16 is constructor for 16 bit unsigned integer expressions literals. -var Uint16 = jet.Uint16 +func Uint16(value uint16) IntegerExpression { + return CAST(jet.Uint16(value)).AS_INTEGER() +} // Uint32 is constructor for 32 bit unsigned integer expressions literals. -var Uint32 = jet.Uint32 +func Uint32(value uint32) IntegerExpression { + return CAST(jet.Uint32(value)).AS_BIGINT() +} // Uint64 is constructor for 64 bit unsigned integer expressions literals. -var Uint64 = jet.Uint64 +func Uint64(value uint64) IntegerExpression { + return CAST(jet.Uint64(value)).AS_BIGINT() +} // Float creates new float literal expression var Float = jet.Float diff --git a/postgres/literal_test.go b/postgres/literal_test.go index 52a15a0a..f95e4867 100644 --- a/postgres/literal_test.go +++ b/postgres/literal_test.go @@ -7,7 +7,7 @@ import ( ) func TestBool(t *testing.T) { - assertSerialize(t, Bool(false), `$1`, false) + assertSerialize(t, Bool(false), `$1::boolean`, false) } func TestInt(t *testing.T) { @@ -16,42 +16,42 @@ func TestInt(t *testing.T) { func TestInt8(t *testing.T) { val := int8(math.MinInt8) - assertSerialize(t, Int8(val), `$1`, val) + assertSerialize(t, Int8(val), `$1::smallint`, val) } func TestInt16(t *testing.T) { val := int16(math.MinInt16) - assertSerialize(t, Int16(val), `$1`, val) + assertSerialize(t, Int16(val), `$1::smallint`, val) } func TestInt32(t *testing.T) { val := int32(math.MinInt32) - assertSerialize(t, Int32(val), `$1`, val) + assertSerialize(t, Int32(val), `$1::integer`, val) } func TestInt64(t *testing.T) { val := int64(math.MinInt64) - assertSerialize(t, Int64(val), `$1`, val) + assertSerialize(t, Int64(val), `$1::bigint`, val) } func TestUint8(t *testing.T) { val := uint8(math.MaxUint8) - assertSerialize(t, Uint8(val), `$1`, val) + assertSerialize(t, Uint8(val), `$1::smallint`, val) } func TestUint16(t *testing.T) { val := uint16(math.MaxUint16) - assertSerialize(t, Uint16(val), `$1`, val) + assertSerialize(t, Uint16(val), `$1::integer`, val) } func TestUint32(t *testing.T) { val := uint32(math.MaxUint32) - assertSerialize(t, Uint32(val), `$1`, val) + assertSerialize(t, Uint32(val), `$1::bigint`, val) } func TestUint64(t *testing.T) { val := uint64(math.MaxUint64) - assertSerialize(t, Uint64(val), `$1`, val) + assertSerialize(t, Uint64(val), `$1::bigint`, val) } func TestFloat(t *testing.T) { diff --git a/postgres/operators.go b/postgres/operators.go index 04c8b236..dcce3e08 100644 --- a/postgres/operators.go +++ b/postgres/operators.go @@ -7,3 +7,6 @@ var NOT = jet.NOT // BIT_NOT inverts every bit in integer expression result var BIT_NOT = jet.BIT_NOT + +// DISTINCT operator can be used to return distinct values of expr +var DISTINCT = jet.DISTINCT diff --git a/postgres/select_statement.go b/postgres/select_statement.go index 8fb9cb6d..ff553fb6 100644 --- a/postgres/select_statement.go +++ b/postgres/select_statement.go @@ -44,7 +44,7 @@ type SelectStatement interface { jet.HasProjections Expression - DISTINCT() SelectStatement + DISTINCT(on ...jet.ColumnExpression) SelectStatement FROM(tables ...ReadableTable) SelectStatement WHERE(expression BoolExpression) SelectStatement GROUP_BY(groupByClauses ...GroupByClause) SelectStatement @@ -104,16 +104,14 @@ type selectStatementImpl struct { For jet.ClauseFor } -func (s *selectStatementImpl) DISTINCT() SelectStatement { +func (s *selectStatementImpl) DISTINCT(on ...jet.ColumnExpression) SelectStatement { s.Select.Distinct = true + s.Select.DistinctOnColumns = on return s } func (s *selectStatementImpl) FROM(tables ...ReadableTable) SelectStatement { - s.From.Tables = nil - for _, table := range tables { - s.From.Tables = append(s.From.Tables, table) - } + s.From.Tables = readableTablesToSerializerList(tables) return s } @@ -182,3 +180,11 @@ func toJetFrameOffset(offset int64) jet.Serializer { } return jet.FixedLiteral(offset) } + +func readableTablesToSerializerList(tables []ReadableTable) []jet.Serializer { + var ret []jet.Serializer + for _, table := range tables { + ret = append(ret, table) + } + return ret +} diff --git a/postgres/select_statement_test.go b/postgres/select_statement_test.go index c3af03bf..b487f90e 100644 --- a/postgres/select_statement_test.go +++ b/postgres/select_statement_test.go @@ -23,7 +23,7 @@ func TestSelectLiterals(t *testing.T) { assertStatementSql(t, SELECT(Int(1), Float(2.2), Bool(false)).FROM(table1), ` SELECT $1, $2, - $3 + $3::boolean FROM db.table1; `, int64(1), 2.2, false) } @@ -59,7 +59,7 @@ func TestSelectWhere(t *testing.T) { assertStatementSql(t, SELECT(table1ColInt).FROM(table1).WHERE(Bool(true)), ` SELECT table1.col_int AS "table1.col_int" FROM db.table1 -WHERE $1; +WHERE $1::boolean; `, true) assertStatementSql(t, SELECT(table1ColInt).FROM(table1).WHERE(table1ColInt.GT_EQ(Int(10))), ` SELECT table1.col_int AS "table1.col_int" @@ -80,7 +80,7 @@ func TestSelectHaving(t *testing.T) { assertStatementSql(t, SELECT(table3ColInt).FROM(table3).HAVING(table1ColBool.EQ(Bool(true))), ` SELECT table3.col_int AS "table3.col_int" FROM db.table3 -HAVING table1.col_bool = $1; +HAVING table1.col_bool = $1::boolean; `, true) } diff --git a/postgres/select_table.go b/postgres/select_table.go index e11b7cde..f3d680db 100644 --- a/postgres/select_table.go +++ b/postgres/select_table.go @@ -2,7 +2,7 @@ package postgres import "github.com/go-jet/jet/v2/internal/jet" -// SelectTable is interface for MySQL sub-queries +// SelectTable is interface for postgres sub-queries type SelectTable interface { readableTable jet.SelectTable @@ -13,7 +13,7 @@ type selectTableImpl struct { readableTableInterfaceImpl } -func newSelectTable(selectStmt jet.SerializerStatement, alias string) SelectTable { +func newSelectTable(selectStmt jet.SerializerHasProjections, alias string) SelectTable { subQuery := &selectTableImpl{ SelectTable: jet.NewSelectTable(selectStmt, alias), } diff --git a/postgres/types.go b/postgres/types.go index 6fed21b6..0f3ef706 100644 --- a/postgres/types.go +++ b/postgres/types.go @@ -23,5 +23,12 @@ type OrderByClause = jet.OrderByClause // GroupByClause interface to use as input for GROUP_BY type GroupByClause = jet.GroupByClause -// SetLogger sets automatic statement logging +// SetLogger sets automatic statement logging function +// Deprecated: use SetQueryLogger instead. var SetLogger = jet.SetLoggerFunc + +// SetQueryLogger sets automatic query logging function. +var SetQueryLogger = jet.SetQueryLogger + +// QueryInfo contains information about executed query +type QueryInfo = jet.QueryInfo diff --git a/postgres/update_statement.go b/postgres/update_statement.go index 58c5ba40..c13ffc6d 100644 --- a/postgres/update_statement.go +++ b/postgres/update_statement.go @@ -11,8 +11,9 @@ type UpdateStatement interface { SET(value interface{}, values ...interface{}) UpdateStatement MODEL(data interface{}) UpdateStatement + FROM(tables ...ReadableTable) UpdateStatement WHERE(expression BoolExpression) UpdateStatement - RETURNING(projections ...jet.Projection) UpdateStatement + RETURNING(projections ...Projection) UpdateStatement } type updateStatementImpl struct { @@ -21,6 +22,7 @@ type updateStatementImpl struct { Update jet.ClauseUpdate Set clauseSet SetNew jet.SetClauseNew + From jet.ClauseFrom Where jet.ClauseWhere Returning jet.ClauseReturning } @@ -31,6 +33,7 @@ func newUpdateStatement(table WritableTable, columns []jet.Column) UpdateStateme &update.Update, &update.Set, &update.SetNew, + &update.From, &update.Where, &update.Returning) @@ -61,6 +64,11 @@ func (u *updateStatementImpl) MODEL(data interface{}) UpdateStatement { return u } +func (u *updateStatementImpl) FROM(tables ...ReadableTable) UpdateStatement { + u.From.Tables = readableTablesToSerializerList(tables) + return u +} + func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement { u.Where.Condition = expression return u diff --git a/postgres/with_statement.go b/postgres/with_statement.go index 1795b3dc..698d6e3d 100644 --- a/postgres/with_statement.go +++ b/postgres/with_statement.go @@ -2,25 +2,73 @@ package postgres import "github.com/go-jet/jet/v2/internal/jet" -// CommonTableExpression contains information about a CTE. -type CommonTableExpression struct { +// CommonTableExpression defines set of interface methods for postgres CTEs +type CommonTableExpression interface { + SelectTable + + AS(statement jet.SerializerStatement) CommonTableExpression + AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression + // ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query. + ALIAS(alias string) SelectTable + + internalCTE() *jet.CommonTableExpression +} + +type commonTableExpression struct { readableTableInterfaceImpl jet.CommonTableExpression } // WITH function creates new WITH statement from list of common table expressions -func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement { - return jet.WITH(Dialect, cte...) +func WITH(cte ...CommonTableExpression) func(statement jet.Statement) Statement { + return jet.WITH(Dialect, false, toInternalCTE(cte)...) +} + +// WITH_RECURSIVE function creates new WITH RECURSIVE statement from list of common table expressions +func WITH_RECURSIVE(cte ...CommonTableExpression) func(statement jet.Statement) Statement { + return jet.WITH(Dialect, true, toInternalCTE(cte)...) } -// CTE creates new named CommonTableExpression -func CTE(name string) CommonTableExpression { - cte := CommonTableExpression{ +// CTE creates new named commonTableExpression +func CTE(name string, columns ...jet.ColumnExpression) CommonTableExpression { + cte := &commonTableExpression{ readableTableInterfaceImpl: readableTableInterfaceImpl{}, - CommonTableExpression: jet.CTE(name), + CommonTableExpression: jet.CTE(name, columns...), } - cte.parent = &cte + cte.parent = cte return cte } + +// AS is used to define a CTE query +func (c *commonTableExpression) AS(statement jet.SerializerStatement) CommonTableExpression { + c.CommonTableExpression.Statement = statement + return c +} + +// AS_NOT_MATERIALIZED is used to define not materialized CTE query +func (c *commonTableExpression) AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression { + c.CommonTableExpression.NotMaterialized = true + c.CommonTableExpression.Statement = statement + return c +} + +func (c *commonTableExpression) internalCTE() *jet.CommonTableExpression { + return &c.CommonTableExpression +} + +// ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query. +func (c *commonTableExpression) ALIAS(name string) SelectTable { + return newSelectTable(c, name) +} + +func toInternalCTE(ctes []CommonTableExpression) []*jet.CommonTableExpression { + var ret []*jet.CommonTableExpression + + for _, cte := range ctes { + ret = append(ret, cte.internalCTE()) + } + + return ret +} diff --git a/qrm/qrm.go b/qrm/qrm.go index 45024023..3731c687 100644 --- a/qrm/qrm.go +++ b/qrm/qrm.go @@ -17,7 +17,7 @@ var ErrNoRows = errors.New("qrm: no rows in result set") // using context `ctx` into destination `destPtr`. // Destination can be either pointer to struct or pointer to slice of structs. // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. -func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr interface{}) error { +func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr interface{}) (rowsProcessed int64, err error) { utils.MustBeInitializedPtr(db, "jet: db is nil") utils.MustBeInitializedPtr(destPtr, "jet: destination is nil") @@ -26,11 +26,11 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr destinationPtrType := reflect.TypeOf(destPtr) if destinationPtrType.Elem().Kind() == reflect.Slice { - _, err := queryToSlice(ctx, db, query, args, destPtr) + rowsProcessed, err := queryToSlice(ctx, db, query, args, destPtr) if err != nil { - return fmt.Errorf("jet: %w", err) + return rowsProcessed, fmt.Errorf("jet: %w", err) } - return nil + return rowsProcessed, nil } else if destinationPtrType.Elem().Kind() == reflect.Struct { tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType)) tempSliceValue := tempSlicePtrValue.Elem() @@ -38,16 +38,16 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr rowsProcessed, err := queryToSlice(ctx, db, query, args, tempSlicePtrValue.Interface()) if err != nil { - return fmt.Errorf("jet: %w", err) + return rowsProcessed, fmt.Errorf("jet: %w", err) } if rowsProcessed == 0 { - return ErrNoRows + return 0, ErrNoRows } // edge case when row result set contains only NULLs. if tempSliceValue.Len() == 0 { - return nil + return rowsProcessed, nil } structValue := reflect.ValueOf(destPtr).Elem() @@ -56,7 +56,7 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr if structValue.Type().AssignableTo(firstTempStruct.Type()) { structValue.Set(tempSliceValue.Index(0).Elem()) } - return nil + return rowsProcessed, nil } else { panic("jet: destination has to be a pointer to slice or pointer to struct") } @@ -87,7 +87,7 @@ func ScanOneRowToDest(rows *sql.Rows, destPtr interface{}) error { tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType)) tempSliceValue := tempSlicePtrValue.Elem() - _, err = mapRowToSlice(scanContext, "", tempSlicePtrValue, nil) + _, err = mapRowToSlice(scanContext, "", newTypeStack(), tempSlicePtrValue, nil) if err != nil { return fmt.Errorf("failed to map a row, %w", err) @@ -136,35 +136,32 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, err = rows.Scan(scanContext.row...) if err != nil { - return + return scanContext.rowNum, err } scanContext.rowNum++ - _, err = mapRowToSlice(scanContext, "", slicePtrValue, nil) + _, err = mapRowToSlice(scanContext, "", newTypeStack(), slicePtrValue, nil) if err != nil { - return + return scanContext.rowNum, err } } err = rows.Close() if err != nil { - return - } - - err = rows.Err() - - if err != nil { - return + return scanContext.rowNum, err } - rowsProcessed = scanContext.rowNum - - return + return scanContext.rowNum, rows.Err() } -func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue reflect.Value, field *reflect.StructField) (updated bool, err error) { +func mapRowToSlice( + scanContext *scanContext, + groupKey string, + typesVisited *typeStack, + slicePtrValue reflect.Value, + field *reflect.StructField) (updated bool, err error) { sliceElemType := getSliceElemType(slicePtrValue) @@ -184,12 +181,12 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue refl if ok { structPtrValue := getSliceElemPtrAt(slicePtrValue, index) - return mapRowToStruct(scanContext, groupKey, structPtrValue, field, true) + return mapRowToStruct(scanContext, groupKey, typesVisited, structPtrValue, field, true) } destinationStructPtr := newElemPtrValueForSlice(slicePtrValue) - updated, err = mapRowToStruct(scanContext, groupKey, destinationStructPtr, field) + updated, err = mapRowToStruct(scanContext, groupKey, typesVisited, destinationStructPtr, field) if err != nil { return @@ -228,10 +225,25 @@ func mapRowToBaseTypeSlice(scanContext *scanContext, slicePtrValue reflect.Value return } -func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue reflect.Value, parentField *reflect.StructField, onlySlices ...bool) (updated bool, err error) { +func mapRowToStruct( + scanContext *scanContext, + groupKey string, + typesVisited *typeStack, // to prevent circular dependency scan + structPtrValue reflect.Value, + parentField *reflect.StructField, + onlySlices ...bool, // small optimization, not to assign to already assigned struct fields +) (updated bool, err error) { + mapOnlySlices := len(onlySlices) > 0 structType := structPtrValue.Type().Elem() + if typesVisited.contains(&structType) { + return false, nil + } + + typesVisited.push(&structType) + defer typesVisited.pop() + typeInf := scanContext.getTypeInfo(structType, parentField) structValue := structPtrValue.Elem() @@ -248,7 +260,7 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re if fieldMap.complexType { var changed bool - changed, err = mapRowToDestinationValue(scanContext, groupKey, fieldValue, &field) + changed, err = mapRowToDestinationValue(scanContext, groupKey, typesVisited, fieldValue, &field) if err != nil { return @@ -295,7 +307,12 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re return } -func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest reflect.Value, structField *reflect.StructField) (updated bool, err error) { +func mapRowToDestinationValue( + scanContext *scanContext, + groupKey string, + typesVisited *typeStack, + dest reflect.Value, + structField *reflect.StructField) (updated bool, err error) { var destPtrValue reflect.Value @@ -309,7 +326,7 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest re } } - updated, err = mapRowToDestinationPtr(scanContext, groupKey, destPtrValue, structField) + updated, err = mapRowToDestinationPtr(scanContext, groupKey, typesVisited, destPtrValue, structField) if err != nil { return @@ -322,16 +339,21 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest re return } -func mapRowToDestinationPtr(scanContext *scanContext, groupKey string, destPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) { +func mapRowToDestinationPtr( + scanContext *scanContext, + groupKey string, + typesVisited *typeStack, + destPtrValue reflect.Value, + structField *reflect.StructField) (updated bool, err error) { utils.ValueMustBe(destPtrValue, reflect.Ptr, "jet: internal error. Destination is not pointer.") destValueKind := destPtrValue.Elem().Kind() if destValueKind == reflect.Struct { - return mapRowToStruct(scanContext, groupKey, destPtrValue, structField) + return mapRowToStruct(scanContext, groupKey, typesVisited, destPtrValue, structField) } else if destValueKind == reflect.Slice { - return mapRowToSlice(scanContext, groupKey, destPtrValue, structField) + return mapRowToSlice(scanContext, groupKey, typesVisited, destPtrValue, structField) } else { panic("jet: unsupported dest type: " + structField.Name + " " + structField.Type.String()) } diff --git a/qrm/scan_context.go b/qrm/scan_context.go index dbc4b877..61feb759 100644 --- a/qrm/scan_context.go +++ b/qrm/scan_context.go @@ -132,7 +132,7 @@ func (s *scanContext) getGroupKey(structType reflect.Type, structField *reflect. return s.constructGroupKey(groupKeyInfo) } - groupKeyInfo := s.getGroupKeyInfo(structType, structField) + groupKeyInfo := s.getGroupKeyInfo(structType, structField, newTypeStack()) s.groupKeyInfoCache[mapKey] = groupKeyInfo @@ -144,7 +144,7 @@ func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string { return fmt.Sprintf("|ROW:%d|", s.rowNum) } - groupKeys := []string{} + var groupKeys []string for _, index := range groupKeyInfo.indexes { cellValue := s.rowElem(index) @@ -153,7 +153,7 @@ func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string { groupKeys = append(groupKeys, subKey) } - subTypesGroupKeys := []string{} + var subTypesGroupKeys []string for _, subType := range groupKeyInfo.subTypes { subTypesGroupKeys = append(subTypesGroupKeys, s.constructGroupKey(subType)) } @@ -161,9 +161,20 @@ func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string { return groupKeyInfo.typeName + "(" + strings.Join(groupKeys, ",") + strings.Join(subTypesGroupKeys, ",") + ")" } -func (s *scanContext) getGroupKeyInfo(structType reflect.Type, parentField *reflect.StructField) groupKeyInfo { +func (s *scanContext) getGroupKeyInfo( + structType reflect.Type, + parentField *reflect.StructField, + typeVisited *typeStack) groupKeyInfo { + ret := groupKeyInfo{typeName: structType.Name()} + if typeVisited.contains(&structType) { + return ret + } + + typeVisited.push(&structType) + defer typeVisited.pop() + typeName := getTypeName(structType, parentField) primaryKeyOverwrites := parentFieldPrimaryKeyOverwrite(parentField) @@ -176,7 +187,7 @@ func (s *scanContext) getGroupKeyInfo(structType reflect.Type, parentField *refl continue } - subType := s.getGroupKeyInfo(fieldType, &field) + subType := s.getGroupKeyInfo(fieldType, &field, typeVisited) if len(subType.indexes) != 0 || len(subType.subTypes) != 0 { ret.subTypes = append(ret.subTypes, subType) diff --git a/qrm/type_stack.go b/qrm/type_stack.go new file mode 100644 index 00000000..235c06ea --- /dev/null +++ b/qrm/type_stack.go @@ -0,0 +1,40 @@ +package qrm + +import "reflect" + +type typeStack []*reflect.Type + +func newTypeStack() *typeStack { + stack := make(typeStack, 0, 20) + return &stack +} + +func (s *typeStack) isEmpty() bool { + return len(*s) == 0 +} + +func (s *typeStack) push(t *reflect.Type) { + *s = append(*s, t) +} + +func (s *typeStack) pop() bool { + if s.isEmpty() { + return false + } + *s = (*s)[:len(*s)-1] + return true +} + +func (s *typeStack) contains(t *reflect.Type) bool { + if s.isEmpty() { + return false + } + + for _, typ := range *s { + if *typ == *t { + return true + } + } + + return false +} diff --git a/sqlite/delete_statement.go b/sqlite/delete_statement.go index dee85c06..e9c06106 100644 --- a/sqlite/delete_statement.go +++ b/sqlite/delete_statement.go @@ -9,7 +9,7 @@ type DeleteStatement interface { WHERE(expression BoolExpression) DeleteStatement ORDER_BY(orderByClauses ...OrderByClause) DeleteStatement LIMIT(limit int64) DeleteStatement - RETURNING(projections ...jet.Projection) DeleteStatement + RETURNING(projections ...Projection) DeleteStatement } type deleteStatementImpl struct { diff --git a/sqlite/expressions.go b/sqlite/expressions.go index d1d47374..8d457354 100644 --- a/sqlite/expressions.go +++ b/sqlite/expressions.go @@ -90,7 +90,7 @@ var ( RawDate = jet.RawDate ) -// Func can be used to call an custom or as of yet unsupported function in the database. +// Func can be used to call custom or unsupported database functions. var Func = jet.Func // NewEnumValue creates new named enum value diff --git a/sqlite/insert_statement.go b/sqlite/insert_statement.go index 3912cc32..1db89b10 100644 --- a/sqlite/insert_statement.go +++ b/sqlite/insert_statement.go @@ -24,7 +24,6 @@ func newInsertStatement(table Table, columns []jet.Column) InsertStatement { newInsert.SerializerStatement = jet.NewStatementImpl(Dialect, jet.InsertStatementType, newInsert, &newInsert.Insert, &newInsert.ValuesQuery, - &newInsert.OnDuplicateKey, &newInsert.DefaultValues, &newInsert.OnConflict, &newInsert.Returning, @@ -40,12 +39,11 @@ func newInsertStatement(table Table, columns []jet.Column) InsertStatement { type insertStatementImpl struct { jet.SerializerStatement - Insert jet.ClauseInsert - ValuesQuery jet.ClauseValuesQuery - OnDuplicateKey onDuplicateKeyUpdateClause - DefaultValues jet.ClauseOptional - OnConflict onConflictClause - Returning jet.ClauseReturning + Insert jet.ClauseInsert + ValuesQuery jet.ClauseValuesQuery + DefaultValues jet.ClauseOptional + OnConflict onConflictClause + Returning jet.ClauseReturning } func (is *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement { @@ -65,11 +63,6 @@ func (is *insertStatementImpl) MODELS(data interface{}) InsertStatement { return is } -func (is *insertStatementImpl) ON_DUPLICATE_KEY_UPDATE(assigments ...ColumnAssigment) InsertStatement { - is.OnDuplicateKey = assigments - return is -} - func (is *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement { is.ValuesQuery.Query = selectStatement return is @@ -85,29 +78,6 @@ func (is *insertStatementImpl) RETURNING(projections ...jet.Projection) InsertSt return is } -type onDuplicateKeyUpdateClause []jet.ColumnAssigment - -// Serialize for SetClause -func (s onDuplicateKeyUpdateClause) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { - if len(s) == 0 { - return - } - out.NewLine() - out.WriteString("ON DUPLICATE KEY UPDATE") - out.IncreaseIdent(24) - - for i, assigment := range s { - if i > 0 { - out.WriteString(",") - out.NewLine() - } - - jet.Serialize(assigment, statementType, out, jet.ShortName.WithFallTrough(options)...) - } - - out.DecreaseIdent(24) -} - func (is *insertStatementImpl) ON_CONFLICT(indexExpressions ...jet.ColumnExpression) onConflict { is.OnConflict = onConflictClause{ insertStatement: is, diff --git a/sqlite/on_conflict_clause.go b/sqlite/on_conflict_clause.go index d131b9ea..1e2ec8f0 100644 --- a/sqlite/on_conflict_clause.go +++ b/sqlite/on_conflict_clause.go @@ -45,7 +45,7 @@ func (o *onConflictClause) Serialize(statementType jet.StatementType, out *jet.S out.WriteString("ON CONFLICT") if len(o.indexExpressions) > 0 { out.WriteString("(") - jet.SerializeColumnExpressionNames(o.indexExpressions, statementType, out, jet.ShortName) + jet.SerializeColumnExpressions(o.indexExpressions, statementType, out, jet.ShortName) out.WriteString(")") } diff --git a/sqlite/operators.go b/sqlite/operators.go index 8ebecbf4..58d9cd99 100644 --- a/sqlite/operators.go +++ b/sqlite/operators.go @@ -7,3 +7,6 @@ var NOT = jet.NOT // BIT_NOT inverts every bit in integer expression result var BIT_NOT = jet.BIT_NOT + +// DISTINCT operator can be used to return distinct values of expr +var DISTINCT = jet.DISTINCT diff --git a/sqlite/select_statement.go b/sqlite/select_statement.go index 4406dcd3..b5a75669 100644 --- a/sqlite/select_statement.go +++ b/sqlite/select_statement.go @@ -106,10 +106,7 @@ func (s *selectStatementImpl) DISTINCT() SelectStatement { } func (s *selectStatementImpl) FROM(tables ...ReadableTable) SelectStatement { - s.From.Tables = nil - for _, table := range tables { - s.From.Tables = append(s.From.Tables, table) - } + s.From.Tables = readableTablesToSerializerList(tables) return s } @@ -184,3 +181,11 @@ func toJetFrameOffset(offset interface{}) jet.Serializer { return jet.FixedLiteral(offset) } + +func readableTablesToSerializerList(tables []ReadableTable) []jet.Serializer { + var ret []jet.Serializer + for _, table := range tables { + ret = append(ret, table) + } + return ret +} diff --git a/sqlite/select_statement_test.go b/sqlite/select_statement_test.go index 0ba76f0f..a42fe06d 100644 --- a/sqlite/select_statement_test.go +++ b/sqlite/select_statement_test.go @@ -147,10 +147,10 @@ func TestSelect_NOT_EXISTS(t *testing.T) { ))), ` SELECT table1.col_int AS "table1.col_int" FROM db.table1 -WHERE (NOT (EXISTS ( +WHERE NOT (EXISTS ( SELECT table2.col_int AS "table2.col_int" FROM db.table2 WHERE table1.col_int = table2.col_int - ))); + )); `) } diff --git a/sqlite/select_table.go b/sqlite/select_table.go index 4117e064..9ac7f720 100644 --- a/sqlite/select_table.go +++ b/sqlite/select_table.go @@ -13,7 +13,7 @@ type selectTableImpl struct { readableTableInterfaceImpl } -func newSelectTable(selectStmt jet.SerializerStatement, alias string) SelectTable { +func newSelectTable(selectStmt jet.SerializerHasProjections, alias string) SelectTable { subQuery := &selectTableImpl{ SelectTable: jet.NewSelectTable(selectStmt, alias), } diff --git a/sqlite/types.go b/sqlite/types.go index 755be1d8..e06a3238 100644 --- a/sqlite/types.go +++ b/sqlite/types.go @@ -23,5 +23,12 @@ type OrderByClause = jet.OrderByClause // GroupByClause interface to use as input for GROUP_BY type GroupByClause = jet.GroupByClause -// SetLogger sets automatic statement logging +// SetLogger sets automatic statement logging. +// Deprecated: use SetQueryLogger instead. var SetLogger = jet.SetLoggerFunc + +// SetQueryLogger sets automatic query logging function. +var SetQueryLogger = jet.SetQueryLogger + +// QueryInfo contains information about executed query +type QueryInfo = jet.QueryInfo diff --git a/sqlite/update_statement.go b/sqlite/update_statement.go index 53cf72d1..c28819af 100644 --- a/sqlite/update_statement.go +++ b/sqlite/update_statement.go @@ -9,14 +9,16 @@ type UpdateStatement interface { SET(value interface{}, values ...interface{}) UpdateStatement MODEL(data interface{}) UpdateStatement + FROM(tables ...ReadableTable) UpdateStatement WHERE(expression BoolExpression) UpdateStatement - RETURNING(projections ...jet.Projection) UpdateStatement + RETURNING(projections ...Projection) UpdateStatement } type updateStatementImpl struct { jet.SerializerStatement Update jet.ClauseUpdate + From jet.ClauseFrom Set jet.SetClause SetNew jet.SetClauseNew Where jet.ClauseWhere @@ -29,6 +31,7 @@ func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement { &update.Update, &update.Set, &update.SetNew, + &update.From, &update.Where, &update.Returning) @@ -59,12 +62,17 @@ func (u *updateStatementImpl) MODEL(data interface{}) UpdateStatement { return u } +func (u *updateStatementImpl) FROM(tables ...ReadableTable) UpdateStatement { + u.From.Tables = readableTablesToSerializerList(tables) + return u +} + func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement { u.Where.Condition = expression return u } -func (u *updateStatementImpl) RETURNING(projections ...jet.Projection) UpdateStatement { +func (u *updateStatementImpl) RETURNING(projections ...Projection) UpdateStatement { u.Returning.ProjectionList = projections return u } diff --git a/sqlite/with_statement.go b/sqlite/with_statement.go index 7940dcd5..5375fffc 100644 --- a/sqlite/with_statement.go +++ b/sqlite/with_statement.go @@ -2,25 +2,73 @@ package sqlite import "github.com/go-jet/jet/v2/internal/jet" -// CommonTableExpression contains information about a CTE. -type CommonTableExpression struct { +// CommonTableExpression defines set of interface methods for postgres CTEs +type CommonTableExpression interface { + SelectTable + + AS(statement jet.SerializerStatement) CommonTableExpression + AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression + // ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query. + ALIAS(alias string) SelectTable + + internalCTE() *jet.CommonTableExpression +} + +type commonTableExpression struct { readableTableInterfaceImpl jet.CommonTableExpression } // WITH function creates new WITH statement from list of common table expressions -func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement { - return jet.WITH(Dialect, cte...) +func WITH(cte ...CommonTableExpression) func(statement jet.Statement) Statement { + return jet.WITH(Dialect, false, toInternalCTE(cte)...) +} + +// WITH_RECURSIVE function creates new WITH RECURSIVE statement from list of common table expressions +func WITH_RECURSIVE(cte ...CommonTableExpression) func(statement jet.Statement) Statement { + return jet.WITH(Dialect, true, toInternalCTE(cte)...) } -// CTE creates new named CommonTableExpression -func CTE(name string) CommonTableExpression { - cte := CommonTableExpression{ +// CTE creates new named commonTableExpression +func CTE(name string, columns ...jet.ColumnExpression) CommonTableExpression { + cte := &commonTableExpression{ readableTableInterfaceImpl: readableTableInterfaceImpl{}, - CommonTableExpression: jet.CTE(name), + CommonTableExpression: jet.CTE(name, columns...), } - cte.parent = &cte + cte.parent = cte return cte } + +// AS is used to define a CTE query +func (c *commonTableExpression) AS(statement jet.SerializerStatement) CommonTableExpression { + c.CommonTableExpression.Statement = statement + return c +} + +// AS_NOT_MATERIALIZED is used to define not materialized CTE query +func (c *commonTableExpression) AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression { + c.CommonTableExpression.NotMaterialized = true + c.CommonTableExpression.Statement = statement + return c +} + +func (c *commonTableExpression) internalCTE() *jet.CommonTableExpression { + return &c.CommonTableExpression +} + +// ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query. +func (c *commonTableExpression) ALIAS(name string) SelectTable { + return newSelectTable(c, name) +} + +func toInternalCTE(ctes []CommonTableExpression) []*jet.CommonTableExpression { + var ret []*jet.CommonTableExpression + + for _, cte := range ctes { + ret = append(ret, cte.internalCTE()) + } + + return ret +} diff --git a/tests/Makefile b/tests/Makefile new file mode 100644 index 00000000..632c3d21 --- /dev/null +++ b/tests/Makefile @@ -0,0 +1,62 @@ + + +setup: checkout-testdata docker-compose-up + +# checkout-testdata will checkout testdata from separate repository into git submodule. +checkout-testdata: + git submodule init + git submodule update + cd ./testdata && git fetch && git checkout master && git pull + +# docker-compose-up will download docker image for each of the databases listed in docker-compose.yaml file, and then it will initialize +# database with testdata retrieved in previous step. +# On the first run this action might take couple of minutes. Docker temp data are stored in .docker directory. +docker-compose-up: + docker-compose up + +init-all: + go run ./init/init.go -testsuite all + +init-postgres: + go run ./init/init.go -testsuite postgres + +init-mysql: + go run ./init/init.go -testsuite mysql + +init-mariadb: + go run ./init/init.go -testsuite mariadb + +init-sqlite: + go run ./init/init.go -testsuite sqlite + +# jet-gen will call generator on each of the test databases to generate sql builder and model files need to run the tests. +jet-gen-all: install-jet-gen jet-gen-postgres jet-gen-mysql jet-gen-mariadb jet-gen-sqlite + +install-jet-gen: + go build -o ${GOPATH}/bin/jet ../cmd/jet/ + +jet-gen-postgres: + jet -dsn=postgres://jet:jet@localhost:50901/jetdb?sslmode=disable -schema=dvds -path=./.gentestdata/ + jet -dsn=postgres://jet:jet@localhost:50901/jetdb?sslmode=disable -schema=chinook -path=./.gentestdata/ + jet -dsn=postgres://jet:jet@localhost:50901/jetdb?sslmode=disable -schema=chinook2 -path=./.gentestdata/ + jet -dsn=postgres://jet:jet@localhost:50901/jetdb?sslmode=disable -schema=test_sample -path=./.gentestdata/ + +jet-gen-mysql: + jet -source=mysql -dsn="jet:jet@tcp(localhost:50902)/dvds" -path=./.gentestdata/mysql + jet -source=mysql -dsn="jet:jet@tcp(localhost:50902)/dvds2" -path=./.gentestdata/mysql + jet -source=mysql -dsn="jet:jet@tcp(localhost:50902)/test_sample" -path=./.gentestdata/mysql + +jet-gen-mariadb: + jet -source=mariadb -dsn="jet:jet@tcp(localhost:50903)/dvds" -path=./.gentestdata/mysql + jet -source=mariadb -dsn="jet:jet@tcp(localhost:50903)/dvds2" -path=./.gentestdata/mysql + jet -source=mariadb -dsn="jet:jet@tcp(localhost:50903)/test_sample" -path=./.gentestdata/mysql + +jet-gen-sqlite: + jet -source=sqlite -dsn="./testdata/init/sqlite/chinook.db" -schema=dvds -path=./.gentestdata/sqlite/chinook + jet -source=sqlite -dsn="./testdata/init/sqlite/sakila.db" -schema=dvds -path=./.gentestdata/sqlite/sakila + jet -source=sqlite -dsn="./testdata/init/sqlite/test_sample.db" -schema=dvds -path=./.gentestdata/sqlite/test_sample + + +# docker-compose-cleanup will stop and remove test containers, volumes, and images. +cleanup: + docker-compose down --volumes diff --git a/tests/Readme.md b/tests/Readme.md new file mode 100644 index 00000000..097e5f30 --- /dev/null +++ b/tests/Readme.md @@ -0,0 +1,29 @@ + +# Integration tests + +This folder contains integration tests intended to test jet generator, statements and query result mapping with a running database. + +## How to run tests? + +Before we can run tests, we need to set up and initialize test databases. +To simplify the process there is a Makefile with a list of helper commands. +```shell +# We first need to checkout testdata from separate repository into git submodule, +# then download docker image for each of the databases listed in docker-compose.yaml file, and +# finally run and initialize databases with downloaded test data. +# Note that on the first run this command might take a couple of minutes. +make setup + +# When databases are ready, we can generate sql builder and model types for each of the test databases +make jet-gen-all +``` + +Then we can run the tests the usual way: +```shell +go test -v ./... +``` + +To removes test containers, volumes, and images: +```shell +make cleanup +``` \ No newline at end of file diff --git a/tests/dbconfig/dbconfig.go b/tests/dbconfig/dbconfig.go index ef89c1b6..bbf73f99 100644 --- a/tests/dbconfig/dbconfig.go +++ b/tests/dbconfig/dbconfig.go @@ -8,7 +8,7 @@ import ( // Postgres test database connection parameters const ( PgHost = "localhost" - PgPort = 5432 + PgPort = 50901 PgUser = "jet" PgPassword = "jet" PgDBName = "jetdb" @@ -19,14 +19,25 @@ var PostgresConnectString = fmt.Sprintf("host=%s port=%d user=%s password=%s dbn // MySQL test database connection parameters const ( - MySqLHost = "localhost" - MySQLPort = 3306 + MySqLHost = "127.0.0.1" + MySQLPort = 50902 MySQLUser = "jet" MySQLPassword = "jet" + + MariaDBHost = "127.0.0.1" + MariaDBPort = 50903 + MariaDBUser = "jet" + MariaDBPassword = "jet" ) -// MySQLConnectionString is MySQL driver connection string to test database -var MySQLConnectionString = fmt.Sprintf("%s:%s@tcp(%s:%d)/", MySQLUser, MySQLPassword, MySqLHost, MySQLPort) +// MySQLConnectionString is MySQL connection string for test database +func MySQLConnectionString(isMariaDB bool, dbName string) string { + if isMariaDB { + return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", MariaDBUser, MariaDBPassword, MariaDBHost, MariaDBPort, dbName) + } + + return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", MySQLUser, MySQLPassword, MySqLHost, MySQLPort, dbName) +} // sqllite var ( diff --git a/tests/docker-compose.yaml b/tests/docker-compose.yaml new file mode 100644 index 00000000..2e913f13 --- /dev/null +++ b/tests/docker-compose.yaml @@ -0,0 +1,39 @@ +version: '3' +services: + postgres: + image: postgres:14.1 + restart: always + environment: + - POSTGRES_USER=jet + - POSTGRES_PASSWORD=jet + - POSTGRES_DB=jetdb + ports: + - '50901:5432' + volumes: + - ./testdata/init/postgres:/docker-entrypoint-initdb.d + + mysql: + image: mysql:8.0.27 + command: ['--default-authentication-plugin=mysql_native_password', '--log_bin_trust_function_creators=1'] + restart: always + environment: + MYSQL_ROOT_PASSWORD: jet + MYSQL_USER: jet + MYSQL_PASSWORD: jet + ports: + - '50902:3306' + volumes: + - ./testdata/init/mysql:/docker-entrypoint-initdb.d + + mariadb: + image: mariadb:10.3.32 + command: ['--default-authentication-plugin=mysql_native_password', '--log_bin_trust_function_creators=1'] + restart: always + environment: + MYSQL_ROOT_PASSWORD: jet + MYSQL_USER: jet + MYSQL_PASSWORD: jet + ports: + - '50903:3306' + volumes: + - ./testdata/init/mysql:/docker-entrypoint-initdb.d diff --git a/tests/init/Readme.md b/tests/init/Readme.md new file mode 100644 index 00000000..65e61d9c --- /dev/null +++ b/tests/init/Readme.md @@ -0,0 +1,6 @@ + +The `init` command can be used to initialize test databases on the local host machine, if needed. +Update [dbconfig](../dbconfig/dbconfig.go) with your local database parameters. + +The recommended way to initialize test databases is by a docker container. +See tests [Readme.md](../Readme.md). \ No newline at end of file diff --git a/tests/init/init.go b/tests/init/init.go index aa04fb5f..c1c842ab 100644 --- a/tests/init/init.go +++ b/tests/init/init.go @@ -4,6 +4,7 @@ import ( "database/sql" "flag" "fmt" + "github.com/go-jet/jet/v2/generator/mysql" "github.com/go-jet/jet/v2/generator/sqlite" "github.com/go-jet/jet/v2/tests/internal/utils/repo" "io/ioutil" @@ -11,7 +12,6 @@ import ( "os/exec" "strings" - "github.com/go-jet/jet/v2/generator/mysql" "github.com/go-jet/jet/v2/generator/postgres" "github.com/go-jet/jet/v2/internal/utils/throw" "github.com/go-jet/jet/v2/tests/dbconfig" @@ -39,7 +39,7 @@ func main() { } if testSuite == "mysql" || testSuite == "mariadb" { - initMySQLDB() + initMySQLDB(testSuite == "mariadb") return } @@ -48,8 +48,9 @@ func main() { return } - initMySQLDB() initPostgresDB() + initMySQLDB(false) + initMySQLDB(true) initSQLiteDB() } @@ -62,7 +63,7 @@ func initSQLiteDB() { throw.OnError(err) } -func initMySQLDB() { +func initMySQLDB(isMariaDB bool) { mySQLDBs := []string{ "dvds", @@ -71,8 +72,20 @@ func initMySQLDB() { } for _, dbName := range mySQLDBs { - cmdLine := fmt.Sprintf("mysql -h 127.0.0.1 -u %s -p%s %s < %s", - dbconfig.MySQLUser, dbconfig.MySQLPassword, dbName, "./testdata/init/mysql/"+dbName+".sql") + host := dbconfig.MySqLHost + port := dbconfig.MySQLPort + user := dbconfig.MySQLUser + pass := dbconfig.MySQLPassword + + if isMariaDB { + host = dbconfig.MariaDBHost + port = dbconfig.MariaDBPort + user = dbconfig.MariaDBUser + pass = dbconfig.MariaDBPassword + } + + cmdLine := fmt.Sprintf("mysql -h %s -P %d -u %s -p%s %s < %s", host, port, user, pass, dbName, + "./testdata/init/mysql/"+dbName+".sql") fmt.Println(cmdLine) @@ -85,10 +98,10 @@ func initMySQLDB() { throw.OnError(err) err = mysql.Generate("./.gentestdata/mysql", mysql.DBConnection{ - Host: dbconfig.MySqLHost, - Port: dbconfig.MySQLPort, - User: dbconfig.MySQLUser, - Password: dbconfig.MySQLPassword, + Host: host, + Port: port, + User: user, + Password: pass, DBName: dbName, }) @@ -99,7 +112,7 @@ func initMySQLDB() { func initPostgresDB() { db, err := sql.Open("postgres", dbconfig.PostgresConnectString) if err != nil { - panic("Failed to connect to test db") + panic("Failed to connect to test db: " + err.Error()) } defer func() { err := db.Close() diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index 2132d7a1..428a0e64 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -31,10 +31,6 @@ func TestAllTypes(t *testing.T) { require.Equal(t, len(dest), 2) - if sourceIsMariaDB() { // MariaDB saves current timestamp in a case of NULL value insert - return - } - //testutils.PrintJson(dest) testutils.AssertJSON(t, dest, allTypesJson) } @@ -49,10 +45,6 @@ func TestAllTypesViewSelect(t *testing.T) { require.NoError(t, err) require.Equal(t, len(dest), 2) - if sourceIsMariaDB() { // MariaDB saves current timestamp in a case of NULL value insert - return - } - testutils.AssertJSON(t, dest, allTypesJson) } @@ -224,6 +216,8 @@ func TestFloatOperators(t *testing.T) { AllTypes.Numeric.LT(Float(34.56)).AS("lt2"), AllTypes.Numeric.GT(Float(124)).AS("gt1"), AllTypes.Numeric.GT(Float(34.56)).AS("gt2"), + AllTypes.Numeric.BETWEEN(Float(1.34), AllTypes.Decimal).AS("between"), + AllTypes.Numeric.NOT_BETWEEN(AllTypes.Decimal.MUL(Float(3)), Float(100.12)).AS("not_between"), TRUNC(AllTypes.Decimal.ADD(AllTypes.Decimal), Int(2)).AS("add1"), TRUNC(AllTypes.Decimal.ADD(Float(11.22)), Int(2)).AS("add2"), @@ -252,11 +246,9 @@ func TestFloatOperators(t *testing.T) { TRUNC(AllTypes.Decimal, Int(1)).AS("trunc"), ).LIMIT(2) - queryStr, _ := query.Sql() - - //fmt.Println(queryStr) + // fmt.Println(query.Sql()) - require.Equal(t, queryStr, strings.Replace(` + testutils.AssertStatementSql(t, query, strings.Replace(` SELECT (all_types.'numeric' = all_types.'numeric') AS "eq1", (all_types.'decimal' = ?) AS "eq2", (all_types.'real' = ?) AS "eq3", @@ -270,22 +262,24 @@ SELECT (all_types.'numeric' = all_types.'numeric') AS "eq1", (all_types.'numeric' < ?) AS "lt2", (all_types.'numeric' > ?) AS "gt1", (all_types.'numeric' > ?) AS "gt2", - TRUNCATE((all_types.'decimal' + all_types.'decimal'), ?) AS "add1", - TRUNCATE((all_types.'decimal' + ?), ?) AS "add2", - TRUNCATE((all_types.'decimal' - all_types.decimal_ptr), ?) AS "sub1", - TRUNCATE((all_types.'decimal' - ?), ?) AS "sub2", - TRUNCATE((all_types.'decimal' * all_types.decimal_ptr), ?) AS "mul1", - TRUNCATE((all_types.'decimal' * ?), ?) AS "mul2", - TRUNCATE((all_types.'decimal' / all_types.decimal_ptr), ?) AS "div1", - TRUNCATE((all_types.'decimal' / ?), ?) AS "div2", - TRUNCATE((all_types.'decimal' % all_types.decimal_ptr), ?) AS "mod1", - TRUNCATE((all_types.'decimal' % ?), ?) AS "mod2", + (all_types.'numeric' BETWEEN ? AND all_types.'decimal') AS "between", + (all_types.'numeric' NOT BETWEEN (all_types.'decimal' * ?) AND ?) AS "not_between", + TRUNCATE(all_types.'decimal' + all_types.'decimal', ?) AS "add1", + TRUNCATE(all_types.'decimal' + ?, ?) AS "add2", + TRUNCATE(all_types.'decimal' - all_types.decimal_ptr, ?) AS "sub1", + TRUNCATE(all_types.'decimal' - ?, ?) AS "sub2", + TRUNCATE(all_types.'decimal' * all_types.decimal_ptr, ?) AS "mul1", + TRUNCATE(all_types.'decimal' * ?, ?) AS "mul2", + TRUNCATE(all_types.'decimal' / all_types.decimal_ptr, ?) AS "div1", + TRUNCATE(all_types.'decimal' / ?, ?) AS "div2", + TRUNCATE(all_types.'decimal' % all_types.decimal_ptr, ?) AS "mod1", + TRUNCATE(all_types.'decimal' % ?, ?) AS "mod2", TRUNCATE(POW(all_types.'decimal', all_types.decimal_ptr), ?) AS "pow1", TRUNCATE(POW(all_types.'decimal', ?), ?) AS "pow2", TRUNCATE(ABS(all_types.'decimal'), ?) AS "abs", TRUNCATE(POWER(all_types.'decimal', ?), ?) AS "power", TRUNCATE(SQRT(all_types.'decimal'), ?) AS "sqrt", - TRUNCATE(POWER(all_types.'decimal', (? / ?)), ?) AS "cbrt", + TRUNCATE(POWER(all_types.'decimal', ? / ?), ?) AS "cbrt", CEIL(all_types.'real') AS "ceil", FLOOR(all_types.'real') AS "floor", ROUND(all_types.'decimal') AS "round1", @@ -316,61 +310,48 @@ func TestIntegerOperators(t *testing.T) { AllTypes.BigInt.EQ(AllTypes.BigInt).AS("eq1"), AllTypes.BigInt.EQ(Int(12)).AS("eq2"), - AllTypes.BigInt.NOT_EQ(AllTypes.BigIntPtr).AS("neq1"), AllTypes.BigInt.NOT_EQ(Int(12)).AS("neq2"), - AllTypes.BigInt.IS_DISTINCT_FROM(AllTypes.BigInt).AS("distinct1"), AllTypes.BigInt.IS_DISTINCT_FROM(Int(12)).AS("distinct2"), - AllTypes.BigInt.IS_NOT_DISTINCT_FROM(AllTypes.BigInt).AS("not distinct1"), AllTypes.BigInt.IS_NOT_DISTINCT_FROM(Int(12)).AS("not distinct2"), AllTypes.BigInt.LT(AllTypes.BigIntPtr).AS("lt1"), AllTypes.BigInt.LT(Int(65)).AS("lt2"), - AllTypes.BigInt.LT_EQ(AllTypes.BigIntPtr).AS("lte1"), AllTypes.BigInt.LT_EQ(Int(65)).AS("lte2"), - AllTypes.BigInt.GT(AllTypes.BigIntPtr).AS("gt1"), AllTypes.BigInt.GT(Int(65)).AS("gt2"), - AllTypes.BigInt.GT_EQ(AllTypes.BigIntPtr).AS("gte1"), AllTypes.BigInt.GT_EQ(Int(65)).AS("gte2"), + AllTypes.Integer.BETWEEN(Int(11), Int(200)).AS("between"), + AllTypes.Integer.NOT_BETWEEN(Int(66), Int(77)).AS("not_between"), AllTypes.BigInt.ADD(AllTypes.BigInt).AS("add1"), AllTypes.BigInt.ADD(Int(11)).AS("add2"), - AllTypes.BigInt.SUB(AllTypes.BigInt).AS("sub1"), AllTypes.BigInt.SUB(Int(11)).AS("sub2"), - AllTypes.BigInt.MUL(AllTypes.BigInt).AS("mul1"), AllTypes.BigInt.MUL(Int(11)).AS("mul2"), - AllTypes.BigInt.DIV(AllTypes.BigInt).AS("div1"), AllTypes.BigInt.DIV(Int(11)).AS("div2"), - AllTypes.BigInt.MOD(AllTypes.BigInt).AS("mod1"), AllTypes.BigInt.MOD(Int(11)).AS("mod2"), - AllTypes.SmallInt.POW(AllTypes.SmallInt.DIV(Int(3))).AS("pow1"), AllTypes.SmallInt.POW(Int(6)).AS("pow2"), AllTypes.SmallInt.BIT_AND(AllTypes.SmallInt).AS("bit_and1"), AllTypes.SmallInt.BIT_AND(AllTypes.SmallInt).AS("bit_and2"), - AllTypes.SmallInt.BIT_OR(AllTypes.SmallInt).AS("bit or 1"), AllTypes.SmallInt.BIT_OR(Int(22)).AS("bit or 2"), - AllTypes.SmallInt.BIT_XOR(AllTypes.SmallInt).AS("bit xor 1"), AllTypes.SmallInt.BIT_XOR(Int(11)).AS("bit xor 2"), - BIT_NOT(Int(-1).MUL(AllTypes.SmallInt)).AS("bit_not_1"), BIT_NOT(Int(-1).MUL(Int(11))).AS("bit_not_2"), AllTypes.SmallInt.BIT_SHIFT_LEFT(AllTypes.SmallInt.DIV(Int(2))).AS("bit shift left 1"), AllTypes.SmallInt.BIT_SHIFT_LEFT(Int(4)).AS("bit shift left 2"), - AllTypes.SmallInt.BIT_SHIFT_RIGHT(AllTypes.SmallInt.DIV(Int(5))).AS("bit shift right 1"), AllTypes.SmallInt.BIT_SHIFT_RIGHT(Int(1)).AS("bit shift right 2"), @@ -379,9 +360,9 @@ func TestIntegerOperators(t *testing.T) { CBRT(ABSi(AllTypes.BigInt)).AS("cbrt"), ).LIMIT(2) - //fmt.Println(query.Sql()) + // fmt.Println(query.Sql()) - testutils.AssertStatementSql(t, query, ` + testutils.AssertStatementSql(t, query, strings.ReplaceAll(` SELECT all_types.big_int AS "all_types.big_int", all_types.big_int_ptr AS "all_types.big_int_ptr", all_types.small_int AS "all_types.small_int", @@ -402,6 +383,8 @@ SELECT all_types.big_int AS "all_types.big_int", (all_types.big_int > ?) AS "gt2", (all_types.big_int >= all_types.big_int_ptr) AS "gte1", (all_types.big_int >= ?) AS "gte2", + (all_types.''integer'' BETWEEN ? AND ?) AS "between", + (all_types.''integer'' NOT BETWEEN ? AND ?) AS "not_between", (all_types.big_int + all_types.big_int) AS "add1", (all_types.big_int + ?) AS "add2", (all_types.big_int - all_types.big_int) AS "sub1", @@ -412,7 +395,7 @@ SELECT all_types.big_int AS "all_types.big_int", (all_types.big_int DIV ?) AS "div2", (all_types.big_int % all_types.big_int) AS "mod1", (all_types.big_int % ?) AS "mod2", - POW(all_types.small_int, (all_types.small_int DIV ?)) AS "pow1", + POW(all_types.small_int, all_types.small_int DIV ?) AS "pow1", POW(all_types.small_int, ?) AS "pow2", (all_types.small_int & all_types.small_int) AS "bit_and1", (all_types.small_int & all_types.small_int) AS "bit_and2", @@ -428,10 +411,10 @@ SELECT all_types.big_int AS "all_types.big_int", (all_types.small_int >> ?) AS "bit shift right 2", ABS(all_types.big_int) AS "abs", SQRT(ABS(all_types.big_int)) AS "sqrt", - POWER(ABS(all_types.big_int), (? / ?)) AS "cbrt" + POWER(ABS(all_types.big_int), ? / ?) AS "cbrt" FROM test_sample.all_types LIMIT ?; -`) +`, "''", "`")) var dest []struct { common.AllTypesIntegerExpResult `alias:"."` @@ -461,6 +444,8 @@ func TestStringOperators(t *testing.T) { AllTypes.Text.LT(String("Text")), AllTypes.Text.LT_EQ(AllTypes.VarCharPtr), AllTypes.Text.LT_EQ(String("Text")), + AllTypes.Text.BETWEEN(String("min"), String("max")), + AllTypes.Text.NOT_BETWEEN(AllTypes.VarChar, AllTypes.CharPtr), AllTypes.Text.CONCAT(String("text2")), AllTypes.Text.CONCAT(Int(11)), AllTypes.Text.LIKE(String("abc")), @@ -528,24 +513,21 @@ func TestTimeExpressions(t *testing.T) { AllTypes.TimePtr.NOT_EQ(AllTypes.Time), AllTypes.TimePtr.NOT_EQ(Time(20, 16, 6)), - AllTypes.Time.IS_DISTINCT_FROM(AllTypes.Time), AllTypes.Time.IS_DISTINCT_FROM(Time(19, 26, 6)), - AllTypes.Time.IS_NOT_DISTINCT_FROM(AllTypes.Time), AllTypes.Time.IS_NOT_DISTINCT_FROM(Time(18, 36, 6)), AllTypes.Time.LT(AllTypes.Time), AllTypes.Time.LT(Time(17, 46, 6)), - AllTypes.Time.LT_EQ(AllTypes.Time), AllTypes.Time.LT_EQ(Time(16, 56, 56)), - AllTypes.Time.GT(AllTypes.Time), AllTypes.Time.GT(Time(15, 16, 46)), - AllTypes.Time.GT_EQ(AllTypes.Time), AllTypes.Time.GT_EQ(Time(14, 26, 36)), + AllTypes.Time.BETWEEN(Time(11, 0, 30, 100), AllTypes.TimePtr), + AllTypes.Time.NOT_BETWEEN(AllTypes.TimePtr, AllTypes.Time.ADD(INTERVAL(2, HOUR))), AllTypes.Time.ADD(INTERVAL(10, MINUTE)), AllTypes.Time.ADD(INTERVALe(AllTypes.Integer, MINUTE)), @@ -583,6 +565,8 @@ SELECT CAST('20:34:58' AS TIME), all_types.time > CAST('15:16:46' AS TIME), all_types.time >= all_types.time, all_types.time >= CAST('14:26:36' AS TIME), + all_types.time BETWEEN CAST('11:00:30.0000001' AS TIME) AND all_types.time_ptr, + all_types.time NOT BETWEEN all_types.time_ptr AND (all_types.time + INTERVAL 2 HOUR), all_types.time + INTERVAL 10 MINUTE, all_types.time + INTERVAL all_types.''integer'' MINUTE, all_types.time + INTERVAL 3 HOUR, @@ -594,7 +578,7 @@ SELECT CAST('20:34:58' AS TIME), CURRENT_TIME(3) FROM test_sample.all_types; `, "''", "`", -1), "20:34:58", "23:06:06", "22:06:06.011", "21:06:06.011111", "20:16:06", - "19:26:06", "18:36:06", "17:46:06", "16:56:56", "15:16:46", "14:26:36") + "19:26:06", "18:36:06", "17:46:06", "16:56:56", "15:16:46", "14:26:36", "11:00:30.0000001") dest := []struct{}{} err := query.Query(db, &dest) @@ -608,27 +592,23 @@ func TestDateExpressions(t *testing.T) { AllTypes.Date.EQ(AllTypes.Date), AllTypes.Date.EQ(Date(2019, 6, 6)), - AllTypes.DatePtr.NOT_EQ(AllTypes.Date), AllTypes.DatePtr.NOT_EQ(Date(2019, 1, 6)), - AllTypes.Date.IS_DISTINCT_FROM(AllTypes.Date), AllTypes.Date.IS_DISTINCT_FROM(Date(2019, 2, 6)), - AllTypes.Date.IS_NOT_DISTINCT_FROM(AllTypes.Date), AllTypes.Date.IS_NOT_DISTINCT_FROM(Date(2019, 3, 6)), AllTypes.Date.LT(AllTypes.Date), AllTypes.Date.LT(Date(2019, 4, 6)), - AllTypes.Date.LT_EQ(AllTypes.Date), AllTypes.Date.LT_EQ(Date(2019, 5, 5)), - AllTypes.Date.GT(AllTypes.Date), AllTypes.Date.GT(Date(2019, 1, 4)), - AllTypes.Date.GT_EQ(AllTypes.Date), AllTypes.Date.GT_EQ(Date(2019, 2, 3)), + AllTypes.Date.BETWEEN(Date(2000, 2, 2), AllTypes.DatePtr), + AllTypes.Date.NOT_BETWEEN(AllTypes.DatePtr, Date(2000, 2, 2)), AllTypes.Date.ADD(INTERVAL("10:20.000100", MINUTE_MICROSECOND)), AllTypes.Date.ADD(INTERVALe(AllTypes.BigInt, MINUTE)), @@ -661,6 +641,8 @@ SELECT CAST('2009-11-17' AS DATE), all_types.date > CAST('2019-01-04' AS DATE), all_types.date >= all_types.date, all_types.date >= CAST('2019-02-03' AS DATE), + all_types.date BETWEEN CAST('2000-02-02' AS DATE) AND all_types.date_ptr, + all_types.date NOT BETWEEN all_types.date_ptr AND CAST('2000-02-02' AS DATE), all_types.date + INTERVAL '10:20.000100' MINUTE_MICROSECOND, all_types.date + INTERVAL all_types.big_int MINUTE, all_types.date + INTERVAL 15 HOUR, @@ -684,27 +666,23 @@ func TestDateTimeExpressions(t *testing.T) { query := AllTypes.SELECT( AllTypes.DateTime.EQ(AllTypes.DateTime), AllTypes.DateTime.EQ(dateTime), - AllTypes.DateTimePtr.NOT_EQ(AllTypes.DateTime), AllTypes.DateTimePtr.NOT_EQ(DateTime(2019, 6, 6, 10, 2, 46, 100*time.Millisecond)), - AllTypes.DateTime.IS_DISTINCT_FROM(AllTypes.DateTime), AllTypes.DateTime.IS_DISTINCT_FROM(dateTime), - AllTypes.DateTime.IS_NOT_DISTINCT_FROM(AllTypes.DateTime), AllTypes.DateTime.IS_NOT_DISTINCT_FROM(dateTime), AllTypes.DateTime.LT(AllTypes.DateTime), AllTypes.DateTime.LT(dateTime), - AllTypes.DateTime.LT_EQ(AllTypes.DateTime), AllTypes.DateTime.LT_EQ(dateTime), - AllTypes.DateTime.GT(AllTypes.DateTime), AllTypes.DateTime.GT(dateTime), - AllTypes.DateTime.GT_EQ(AllTypes.DateTime), AllTypes.DateTime.GT_EQ(dateTime), + AllTypes.DateTime.BETWEEN(AllTypes.DateTimePtr, AllTypes.TimestampPtr), + AllTypes.DateTime.NOT_BETWEEN(AllTypes.DateTimePtr, AllTypes.TimestampPtr), AllTypes.DateTime.ADD(INTERVAL("05:10:20.000100", HOUR_MICROSECOND)), AllTypes.DateTime.ADD(INTERVALe(AllTypes.BigInt, HOUR)), @@ -718,7 +696,7 @@ func TestDateTimeExpressions(t *testing.T) { NOW(1), ) - //Println(query.DebugSql()) + //fmt.Println(query.DebugSql()) testutils.AssertDebugStatementSql(t, query, ` SELECT all_types.date_time = all_types.date_time, @@ -737,6 +715,8 @@ SELECT all_types.date_time = all_types.date_time, all_types.date_time > CAST('2019-06-06 10:02:46' AS DATETIME), all_types.date_time >= all_types.date_time, all_types.date_time >= CAST('2019-06-06 10:02:46' AS DATETIME), + all_types.date_time BETWEEN all_types.date_time_ptr AND all_types.timestamp_ptr, + all_types.date_time NOT BETWEEN all_types.date_time_ptr AND all_types.timestamp_ptr, all_types.date_time + INTERVAL '05:10:20.000100' HOUR_MICROSECOND, all_types.date_time + INTERVAL all_types.big_int HOUR, all_types.date_time + INTERVAL 2 HOUR, @@ -761,27 +741,23 @@ func TestTimestampExpressions(t *testing.T) { query := AllTypes.SELECT( AllTypes.Timestamp.EQ(AllTypes.Timestamp), AllTypes.Timestamp.EQ(timestamp), - AllTypes.TimestampPtr.NOT_EQ(AllTypes.Timestamp), AllTypes.TimestampPtr.NOT_EQ(Timestamp(2019, 6, 6, 10, 2, 46, 100*time.Millisecond)), - AllTypes.Timestamp.IS_DISTINCT_FROM(AllTypes.Timestamp), AllTypes.Timestamp.IS_DISTINCT_FROM(timestamp), - AllTypes.Timestamp.IS_NOT_DISTINCT_FROM(AllTypes.Timestamp), AllTypes.Timestamp.IS_NOT_DISTINCT_FROM(timestamp), AllTypes.Timestamp.LT(AllTypes.Timestamp), AllTypes.Timestamp.LT(timestamp), - AllTypes.Timestamp.LT_EQ(AllTypes.Timestamp), AllTypes.Timestamp.LT_EQ(timestamp), - AllTypes.Timestamp.GT(AllTypes.Timestamp), AllTypes.Timestamp.GT(timestamp), - AllTypes.Timestamp.GT_EQ(AllTypes.Timestamp), AllTypes.Timestamp.GT_EQ(timestamp), + AllTypes.Timestamp.BETWEEN(AllTypes.DateTimePtr, AllTypes.TimestampPtr), + AllTypes.Timestamp.NOT_BETWEEN(AllTypes.DateTimePtr, AllTypes.TimestampPtr), AllTypes.Timestamp.ADD(INTERVAL("05:10:20.000100", HOUR_MICROSECOND)), AllTypes.Timestamp.ADD(INTERVALe(AllTypes.BigInt, HOUR)), @@ -814,6 +790,8 @@ SELECT all_types.timestamp = all_types.timestamp, all_types.timestamp > TIMESTAMP('2019-06-06 10:02:46'), all_types.timestamp >= all_types.timestamp, all_types.timestamp >= TIMESTAMP('2019-06-06 10:02:46'), + all_types.timestamp BETWEEN all_types.date_time_ptr AND all_types.timestamp_ptr, + all_types.timestamp NOT BETWEEN all_types.date_time_ptr AND all_types.timestamp_ptr, all_types.timestamp + INTERVAL '05:10:20.000100' HOUR_MICROSECOND, all_types.timestamp + INTERVAL all_types.big_int HOUR, all_types.timestamp + INTERVAL 2 HOUR, diff --git a/tests/mysql/delete_test.go b/tests/mysql/delete_test.go index cb7673c9..709ce1a2 100644 --- a/tests/mysql/delete_test.go +++ b/tests/mysql/delete_test.go @@ -4,6 +4,7 @@ import ( "context" "github.com/go-jet/jet/v2/internal/testutils" . "github.com/go-jet/jet/v2/mysql" + "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/table" "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/test_sample/model" . "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/test_sample/table" "github.com/stretchr/testify/require" @@ -91,3 +92,29 @@ func initForDeleteTest(t *testing.T) { testutils.AssertExec(t, stmt, db, 2) } + +func TestDeleteWithUsing(t *testing.T) { + tx := beginTx(t) + defer tx.Rollback() + + stmt := table.Rental.DELETE(). + USING( + table.Rental. + INNER_JOIN(table.Staff, table.Rental.StaffID.EQ(table.Staff.StaffID)), + table.Actor, + ). + WHERE( + table.Staff.StaffID.NOT_EQ(Int(2)). + AND(table.Rental.RentalID.LT(Int(100))), + ) + + testutils.AssertStatementSql(t, stmt, ` +DELETE FROM dvds.rental +USING dvds.rental + INNER JOIN dvds.staff ON (rental.staff_id = staff.staff_id), + dvds.actor +WHERE (staff.staff_id != ?) AND (rental.rental_id < ?); +`) + + testutils.AssertExec(t, stmt, tx) +} diff --git a/tests/mysql/generator_template_test.go b/tests/mysql/generator_template_test.go index e915e0fc..6af636b7 100644 --- a/tests/mysql/generator_template_test.go +++ b/tests/mysql/generator_template_test.go @@ -25,18 +25,30 @@ var defaultViewSQLBuilderFilePath = path.Join(tempTestDir, "dvds/view") var defaultEnumSQLBuilderFilePath = path.Join(tempTestDir, "dvds/enum") var defaultActorSQLBuilderFilePath = path.Join(tempTestDir, "dvds/table", "actor.go") -var dbConnection = mysql2.DBConnection{ - Host: dbconfig.MySqLHost, - Port: dbconfig.MySQLPort, - User: dbconfig.MySQLUser, - Password: dbconfig.MySQLPassword, - DBName: "dvds", +func dbConnection(dbName string) mysql2.DBConnection { + if sourceIsMariaDB() { + return mysql2.DBConnection{ + Host: dbconfig.MariaDBHost, + Port: dbconfig.MariaDBPort, + User: dbconfig.MariaDBUser, + Password: dbconfig.MariaDBPassword, + DBName: dbName, + } + } + + return mysql2.DBConnection{ + Host: dbconfig.MySqLHost, + Port: dbconfig.MySQLPort, + User: dbconfig.MySQLUser, + Password: dbconfig.MySQLPassword, + DBName: dbName, + } } func TestGeneratorTemplate_Schema_ChangePath(t *testing.T) { err := mysql2.Generate( tempTestDir, - dbConnection, + dbConnection("dvds"), template.Default(postgres2.Dialect). UseSchema(func(schemaMetaData metadata.Schema) template.Schema { return template.DefaultSchema(schemaMetaData).UsePath("new/schema/path") @@ -54,7 +66,7 @@ func TestGeneratorTemplate_Schema_ChangePath(t *testing.T) { func TestGeneratorTemplate_Model_SkipGeneration(t *testing.T) { err := mysql2.Generate( tempTestDir, - dbConnection, + dbConnection("dvds"), template.Default(postgres2.Dialect). UseSchema(func(schemaMetaData metadata.Schema) template.Schema { return template.DefaultSchema(schemaMetaData). @@ -75,7 +87,7 @@ func TestGeneratorTemplate_Model_SkipGeneration(t *testing.T) { func TestGeneratorTemplate_SQLBuilder_SkipGeneration(t *testing.T) { err := mysql2.Generate( tempTestDir, - dbConnection, + dbConnection("dvds"), template.Default(postgres2.Dialect). UseSchema(func(schemaMetaData metadata.Schema) template.Schema { return template.DefaultSchema(schemaMetaData). @@ -98,7 +110,7 @@ func TestGeneratorTemplate_Model_ChangePath(t *testing.T) { err := mysql2.Generate( tempTestDir, - dbConnection, + dbConnection("dvds"), template.Default(postgres2.Dialect). UseSchema(func(schemaMetaData metadata.Schema) template.Schema { return template.DefaultSchema(schemaMetaData). @@ -116,7 +128,7 @@ func TestGeneratorTemplate_SQLBuilder_ChangePath(t *testing.T) { err := mysql2.Generate( tempTestDir, - dbConnection, + dbConnection("dvds"), template.Default(postgres2.Dialect). UseSchema(func(schemaMetaData metadata.Schema) template.Schema { return template.DefaultSchema(schemaMetaData). @@ -137,7 +149,7 @@ func TestGeneratorTemplate_SQLBuilder_ChangePath(t *testing.T) { func TestGeneratorTemplate_Model_RenameFilesAndTypes(t *testing.T) { err := mysql2.Generate( tempTestDir, - dbConnection, + dbConnection("dvds"), template.Default(postgres2.Dialect). UseSchema(func(schemaMetaData metadata.Schema) template.Schema { return template.DefaultSchema(schemaMetaData). @@ -175,7 +187,7 @@ func TestGeneratorTemplate_Model_RenameFilesAndTypes(t *testing.T) { func TestGeneratorTemplate_Model_SkipTableAndEnum(t *testing.T) { err := mysql2.Generate( tempTestDir, - dbConnection, + dbConnection("dvds"), template.Default(postgres2.Dialect). UseSchema(func(schemaMetaData metadata.Schema) template.Schema { return template.DefaultSchema(schemaMetaData). @@ -203,7 +215,7 @@ func TestGeneratorTemplate_Model_SkipTableAndEnum(t *testing.T) { func TestGeneratorTemplate_SQLBuilder_SkipTableAndEnum(t *testing.T) { err := mysql2.Generate( tempTestDir, - dbConnection, + dbConnection("dvds"), template.Default(postgres2.Dialect). UseSchema(func(schemaMetaData metadata.Schema) template.Schema { return template.DefaultSchema(schemaMetaData). @@ -236,7 +248,7 @@ func TestGeneratorTemplate_SQLBuilder_SkipTableAndEnum(t *testing.T) { func TestGeneratorTemplate_SQLBuilder_ChangeTypeAndFileName(t *testing.T) { err := mysql2.Generate( tempTestDir, - dbConnection, + dbConnection("dvds"), template.Default(postgres2.Dialect). UseSchema(func(schemaMetaData metadata.Schema) template.Schema { return template.DefaultSchema(schemaMetaData). @@ -277,7 +289,7 @@ func TestGeneratorTemplate_Model_AddTags(t *testing.T) { err := mysql2.Generate( tempTestDir, - dbConnection, + dbConnection("dvds"), template.Default(postgres2.Dialect). UseSchema(func(schemaMetaData metadata.Schema) template.Schema { return template.DefaultSchema(schemaMetaData). @@ -318,7 +330,7 @@ func TestGeneratorTemplate_Model_AddTags(t *testing.T) { func TestGeneratorTemplate_Model_ChangeFieldTypes(t *testing.T) { err := mysql2.Generate( tempTestDir, - dbConnection, + dbConnection("dvds"), template.Default(postgres2.Dialect). UseSchema(func(schemaMetaData metadata.Schema) template.Schema { return template.DefaultSchema(schemaMetaData). @@ -361,7 +373,7 @@ func TestGeneratorTemplate_Model_ChangeFieldTypes(t *testing.T) { func TestGeneratorTemplate_SQLBuilder_ChangeColumnTypes(t *testing.T) { err := mysql2.Generate( tempTestDir, - dbConnection, + dbConnection("dvds"), template.Default(postgres2.Dialect). UseSchema(func(schemaMetaData metadata.Schema) template.Schema { return template.DefaultSchema(schemaMetaData). diff --git a/tests/mysql/generator_test.go b/tests/mysql/generator_test.go index 033f699d..a414df32 100644 --- a/tests/mysql/generator_test.go +++ b/tests/mysql/generator_test.go @@ -1,10 +1,10 @@ package mysql import ( - "fmt" "io/ioutil" "os" "os/exec" + "strconv" "testing" "github.com/go-jet/jet/v2/generator/mysql" @@ -19,13 +19,7 @@ const genTestDir3 = "./.gentestdata3/mysql" func TestGenerator(t *testing.T) { for i := 0; i < 3; i++ { - err := mysql.Generate(genTestDir3, mysql.DBConnection{ - Host: dbconfig.MySqLHost, - Port: dbconfig.MySQLPort, - User: dbconfig.MySQLUser, - Password: dbconfig.MySQLPassword, - DBName: "dvds", - }) + err := mysql.Generate(genTestDir3, dbConnection("dvds")) require.NoError(t, err) @@ -33,17 +27,11 @@ func TestGenerator(t *testing.T) { } for i := 0; i < 3; i++ { - dsn := fmt.Sprintf("%[1]s:%[2]s@tcp(%[3]s:%[4]d)/%[5]s", - dbconfig.MySQLUser, - dbconfig.MySQLPassword, - dbconfig.MySqLHost, - dbconfig.MySQLPort, - "dvds", - ) + dsn := dbconfig.MySQLConnectionString(sourceIsMariaDB(), "dvds") + err := mysql.GenerateDSN(dsn, genTestDir3) require.NoError(t, err) - assertGeneratedFiles(t) } @@ -55,8 +43,27 @@ func TestCmdGenerator(t *testing.T) { err := os.RemoveAll(genTestDir3) require.NoError(t, err) - cmd := exec.Command("jet", "-source=MySQL", "-dbname=dvds", "-host=localhost", "-port=3306", - "-user=jet", "-password=jet", "-path="+genTestDir3) + var cmd *exec.Cmd + + if sourceIsMariaDB() { + cmd = exec.Command("jet", + "-source=MariaDB", + "-dbname=dvds", + "-host="+dbconfig.MariaDBHost, + "-port="+strconv.Itoa(dbconfig.MariaDBPort), + "-user="+dbconfig.MariaDBUser, + "-password="+dbconfig.MariaDBPassword, + "-path="+genTestDir3) + } else { + cmd = exec.Command("jet", + "-source=MySQL", + "-dbname=dvds", + "-host="+dbconfig.MySqLHost, + "-port="+strconv.Itoa(dbconfig.MySQLPort), + "-user="+dbconfig.MySQLUser, + "-password="+dbconfig.MySQLPassword, + "-path="+genTestDir3) + } cmd.Stderr = os.Stderr cmd.Stdout = os.Stdout @@ -70,13 +77,7 @@ func TestCmdGenerator(t *testing.T) { require.NoError(t, err) // check that generation via DSN works - dsn := fmt.Sprintf("mysql://%[1]s:%[2]s@tcp(%[3]s:%[4]d)/%[5]s", - dbconfig.MySQLUser, - dbconfig.MySQLPassword, - dbconfig.MySqLHost, - dbconfig.MySQLPort, - "dvds", - ) + dsn := "mysql://" + dbconfig.MySQLConnectionString(sourceIsMariaDB(), "dvds") cmd = exec.Command("jet", "-dsn="+dsn, "-path="+genTestDir3) cmd.Stderr = os.Stderr @@ -84,9 +85,48 @@ func TestCmdGenerator(t *testing.T) { err = cmd.Run() require.NoError(t, err) +} - err = os.RemoveAll(genTestDirRoot) +func TestIgnoreTablesViewsEnums(t *testing.T) { + cmd := exec.Command("jet", + "-source=MySQL", + "-dbname=dvds", + "-host="+dbconfig.MySqLHost, + "-port="+strconv.Itoa(dbconfig.MySQLPort), + "-user="+dbconfig.MySQLUser, + "-password="+dbconfig.MySQLPassword, + "-ignore-tables=actor,ADDRESS,Category, city ,country,staff,store,rental", + "-ignore-views=actor_info,CUSTomER_LIST, film_list", + "-ignore-enums=film_list_rating,film_rating", + "-path="+genTestDir3) + + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout + + err := cmd.Run() + require.NoError(t, err) + + tableSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/table") + require.NoError(t, err) + testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "customer.go", "film.go", "film_actor.go", + "film_category.go", "film_text.go", "inventory.go", "language.go", "payment.go") + + viewSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/view") + require.NoError(t, err) + testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "nicer_but_slower_film_list.go", + "sales_by_film_category.go", "sales_by_store.go", "staff_list.go") + + enumFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/enum") + require.NoError(t, err) + testutils.AssertFileNamesEqual(t, enumFiles, "nicer_but_slower_film_list_rating.go") + + modelFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/model") require.NoError(t, err) + + testutils.AssertFileNamesEqual(t, modelFiles, + "customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go", + "payment.go", "nicer_but_slower_film_list_rating.go", "nicer_but_slower_film_list.go", "sales_by_film_category.go", + "sales_by_store.go", "staff_list.go") } func assertGeneratedFiles(t *testing.T) { diff --git a/tests/mysql/main_test.go b/tests/mysql/main_test.go index e2be933a..e04580d4 100644 --- a/tests/mysql/main_test.go +++ b/tests/mysql/main_test.go @@ -8,6 +8,7 @@ import ( "github.com/go-jet/jet/v2/tests/dbconfig" "github.com/stretchr/testify/require" "math/rand" + "runtime" "time" _ "github.com/go-sql-driver/mysql" @@ -36,7 +37,7 @@ func TestMain(m *testing.M) { defer profile.Start().Stop() var err error - db, err = sql.Open("mysql", dbconfig.MySQLConnectionString) + db, err = sql.Open("mysql", dbconfig.MySQLConnectionString(sourceIsMariaDB(), "")) if err != nil { panic("Failed to connect to test db" + err.Error()) } @@ -51,11 +52,21 @@ var loggedSQL string var loggedSQLArgs []interface{} var loggedDebugSQL string +var queryInfo jetmysql.QueryInfo +var callerFile string +var callerLine int +var callerFunction string + func init() { jetmysql.SetLogger(func(ctx context.Context, statement jetmysql.PrintableStatement) { loggedSQL, loggedSQLArgs = statement.Sql() loggedDebugSQL = statement.DebugSql() }) + + jetmysql.SetQueryLogger(func(ctx context.Context, info jetmysql.QueryInfo) { + queryInfo = info + callerFile, callerLine, callerFunction = info.Caller() + }) } func requireLogged(t *testing.T, statement postgres.Statement) { @@ -65,8 +76,29 @@ func requireLogged(t *testing.T, statement postgres.Statement) { require.Equal(t, loggedDebugSQL, statement.DebugSql()) } +func requireQueryLogged(t *testing.T, statement postgres.Statement, rowsProcessed int64) { + query, args := statement.Sql() + queryLogged, argsLogged := queryInfo.Statement.Sql() + + require.Equal(t, query, queryLogged) + require.Equal(t, args, argsLogged) + require.Equal(t, queryInfo.RowsProcessed, rowsProcessed) + + pc, file, _, _ := runtime.Caller(1) + funcDetails := runtime.FuncForPC(pc) + require.Equal(t, file, callerFile) + require.NotEmpty(t, callerLine) + require.Equal(t, funcDetails.Name(), callerFunction) +} + func skipForMariaDB(t *testing.T) { if sourceIsMariaDB() { t.SkipNow() } } + +func beginTx(t *testing.T) *sql.Tx { + tx, err := db.Begin() + require.NoError(t, err) + return tx +} diff --git a/tests/mysql/select_test.go b/tests/mysql/select_test.go index 5a88acf0..39f0e431 100644 --- a/tests/mysql/select_test.go +++ b/tests/mysql/select_test.go @@ -38,6 +38,7 @@ WHERE actor.actor_id = ?; testutils.AssertDeepEqual(t, actor, actor2) requireLogged(t, query) + requireQueryLogged(t, query, 1) } var actor2 = model.Actor{ @@ -60,9 +61,9 @@ SELECT actor.actor_id AS "actor.actor_id", FROM dvds.actor ORDER BY actor.actor_id; `) - dest := []model.Actor{} + var dest []model.Actor - err := query.Query(db, &dest) + err := query.QueryContext(context.Background(), db, &dest) require.NoError(t, err) @@ -73,6 +74,7 @@ ORDER BY actor.actor_id; //testutils.SaveJsonFile(dest, "mysql/testdata/all_actors.json") testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/all_actors.json") requireLogged(t, query) + requireQueryLogged(t, query, 200) } func TestSelectGroupByHaving(t *testing.T) { @@ -153,6 +155,68 @@ ORDER BY payment.customer_id, SUM(payment.amount) ASC; requireLogged(t, query) } +func TestAggregateFunctionDistinct(t *testing.T) { + stmt := SELECT( + Payment.CustomerID, + + COUNT(DISTINCT(Payment.Amount)).AS("distinct.count"), + SUM(DISTINCT(Payment.Amount)).AS("distinct.sum"), + AVG(DISTINCT(Payment.Amount)).AS("distinct.avg"), + MIN(DISTINCT(Payment.PaymentDate)).AS("distinct.first_payment_date"), + MAX(DISTINCT(Payment.PaymentDate)).AS("distinct.last_payment_date"), + ).FROM( + Payment, + ).WHERE( + Payment.CustomerID.EQ(Int(1)), + ).GROUP_BY( + Payment.CustomerID, + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT payment.customer_id AS "payment.customer_id", + COUNT(DISTINCT payment.amount) AS "distinct.count", + SUM(DISTINCT payment.amount) AS "distinct.sum", + AVG(DISTINCT payment.amount) AS "distinct.avg", + MIN(DISTINCT payment.payment_date) AS "distinct.first_payment_date", + MAX(DISTINCT payment.payment_date) AS "distinct.last_payment_date" +FROM dvds.payment +WHERE payment.customer_id = 1 +GROUP BY payment.customer_id; +`) + + type Distinct struct { + model.Payment + + Count int64 + Sum float64 + Avg float64 + FirstPaymentDate time.Time + LastPaymentDate time.Time + } + + var dest Distinct + + err := stmt.Query(db, &dest) + require.NoError(t, err) + testutils.AssertJSON(t, dest, ` +{ + "PaymentID": 0, + "CustomerID": 1, + "StaffID": 0, + "RentalID": null, + "Amount": 0, + "PaymentDate": "0001-01-01T00:00:00Z", + "LastUpdate": "0001-01-01T00:00:00Z", + "Count": 8, + "Sum": 38.92, + "Avg": 4.865, + "FirstPaymentDate": "2005-05-25T11:30:37Z", + "LastPaymentDate": "2005-08-22T20:03:46Z" +} +`) + +} + func TestSubQuery(t *testing.T) { rRatingFilms := Film. @@ -389,8 +453,6 @@ LIMIT ?; ). LIMIT(1000) - //fmt.Println(query.Sql()) - testutils.AssertStatementSql(t, query, expectedSQL, int64(1000)) var dest []struct { @@ -414,12 +476,7 @@ LIMIT ?; err := query.Query(db, &dest) require.NoError(t, err) - //require.Equal(t, len(dest), 1) - //require.Equal(t, len(dest[0].Films), 10) - //require.Equal(t, len(dest[0].Films[0].Actors), 10) - //testutils.SaveJsonFile(dest, "./mysql/testdata/lang_film_actor_inventory_rental.json") - testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/lang_film_actor_inventory_rental.json") } } diff --git a/tests/mysql/update_test.go b/tests/mysql/update_test.go index dc289245..ba628a1b 100644 --- a/tests/mysql/update_test.go +++ b/tests/mysql/update_test.go @@ -261,15 +261,22 @@ func TestUpdateExecContext(t *testing.T) { } func TestUpdateWithJoin(t *testing.T) { - query := table.Staff. - INNER_JOIN(table.Address, table.Address.AddressID.EQ(table.Staff.AddressID)). + tx := beginTx(t) + defer tx.Rollback() + + statement := table.Staff.INNER_JOIN(table.Address, table.Address.AddressID.EQ(table.Staff.AddressID)). UPDATE(table.Staff.LastName). - SET(String("New name")). + SET(String("New staff name")). WHERE(table.Staff.StaffID.EQ(Int(1))) - //fmt.Println(query.DebugSql()) + testutils.AssertStatementSql(t, statement, ` +UPDATE dvds.staff +INNER JOIN dvds.address ON (address.address_id = staff.address_id) +SET last_name = ? +WHERE staff.staff_id = ?; +`, "New staff name", int64(1)) - _, err := query.Exec(db) + _, err := statement.Exec(tx) require.NoError(t, err) } diff --git a/tests/mysql/with_test.go b/tests/mysql/with_test.go index ddc1d123..cc8dfd6f 100644 --- a/tests/mysql/with_test.go +++ b/tests/mysql/with_test.go @@ -149,7 +149,26 @@ func TestWITH_And_DELETE(t *testing.T) { ), ) - //fmt.Println(stmt.DebugSql()) + // fmt.Println(stmt.DebugSql()) + + testutils.AssertDebugStatementSql(t, stmt, strings.ReplaceAll(` +WITH payments_to_delete AS ( + SELECT payment.payment_id AS "payment.payment_id", + payment.customer_id AS "payment.customer_id", + payment.staff_id AS "payment.staff_id", + payment.rental_id AS "payment.rental_id", + payment.amount AS "payment.amount", + payment.payment_date AS "payment.payment_date", + payment.last_update AS "payment.last_update" + FROM dvds.payment + WHERE payment.amount < 0.5 +) +DELETE FROM dvds.payment +WHERE payment.payment_id IN ( + SELECT payments_to_delete.''payment.payment_id'' AS "payment.payment_id" + FROM payments_to_delete + ); +`, "''", "`")) tx, err := db.Begin() require.NoError(t, err) @@ -157,3 +176,119 @@ func TestWITH_And_DELETE(t *testing.T) { testutils.AssertExec(t, stmt, tx, 24) } + +func TestRecursiveWithStatement_Fibonacci(t *testing.T) { + // CTE columns are listed as part of CTE definition + n1 := IntegerColumn("n1") + fibN1 := IntegerColumn("fibN1") + nextFibN1 := IntegerColumn("nextFibN1") + fibonacci1 := CTE("fibonacci1", n1, fibN1, nextFibN1) + + // CTE columns are columns from non-recursive select + fibonacci2 := CTE("fibonacci2") + n2 := IntegerColumn("n2").From(fibonacci2) + fibN2 := IntegerColumn("fibN2").From(fibonacci2) + nextFibN2 := IntegerColumn("nextFibN2").From(fibonacci2) + + stmt := WITH_RECURSIVE( + fibonacci1.AS( + SELECT( + Int32(1), Int32(0), Int32(1), + ).UNION_ALL( + SELECT( + n1.ADD(Int(1)), nextFibN1, fibN1.ADD(nextFibN1), + ).FROM( + fibonacci1, + ).WHERE( + n1.LT(Int(20)), + ), + ), + ), + fibonacci2.AS( + SELECT( + Int32(1).AS(n2.Name()), + Int32(0).AS(fibN2.Name()), + Int32(1).AS(nextFibN2.Name()), + ).UNION_ALL( + SELECT( + n2.ADD(Int(1)), nextFibN2, fibN2.ADD(nextFibN2), + ).FROM( + fibonacci2, + ).WHERE( + n2.LT(Int(20)), + ), + ), + ), + )( + SELECT( + fibonacci1.AllColumns(), + fibonacci2.AllColumns(), + ).FROM( + fibonacci1.INNER_JOIN(fibonacci2, n1.EQ(n2)), + ).WHERE( + n1.EQ(Int(20)), + ), + ) + + // fmt.Println(stmt.Sql()) + + testutils.AssertStatementSql(t, stmt, strings.ReplaceAll(` +WITH RECURSIVE fibonacci1 (n1, ''fibN1'', ''nextFibN1'') AS ( + ( + SELECT ?, + ?, + ? + ) + UNION ALL + ( + SELECT fibonacci1.n1 + ?, + fibonacci1.''nextFibN1'' AS "nextFibN1", + fibonacci1.''fibN1'' + fibonacci1.''nextFibN1'' + FROM fibonacci1 + WHERE fibonacci1.n1 < ? + ) +),fibonacci2 AS ( + ( + SELECT ? AS "n2", + ? AS "fibN2", + ? AS "nextFibN2" + ) + UNION ALL + ( + SELECT fibonacci2.n2 + ?, + fibonacci2.''nextFibN2'' AS "nextFibN2", + fibonacci2.''fibN2'' + fibonacci2.''nextFibN2'' + FROM fibonacci2 + WHERE fibonacci2.n2 < ? + ) +) +SELECT fibonacci1.n1 AS "n1", + fibonacci1.''fibN1'' AS "fibN1", + fibonacci1.''nextFibN1'' AS "nextFibN1", + fibonacci2.n2 AS "n2", + fibonacci2.''fibN2'' AS "fibN2", + fibonacci2.''nextFibN2'' AS "nextFibN2" +FROM fibonacci1 + INNER JOIN fibonacci2 ON (fibonacci1.n1 = fibonacci2.n2) +WHERE fibonacci1.n1 = ?; +`, "''", "`")) + + var dest struct { + N1 int + FibN1 int + NextFibN1 int + + N2 int + FibN2 int + NextFibN2 int + } + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Equal(t, dest.N1, 20) + require.Equal(t, dest.FibN1, 4181) + require.Equal(t, dest.NextFibN1, 6765) + require.Equal(t, dest.N2, 20) + require.Equal(t, dest.FibN2, 4181) + require.Equal(t, dest.NextFibN2, 6765) +} diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 82ac82bf..405ec9ec 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -225,7 +225,7 @@ func TestExpressionOperators(t *testing.T) { query := AllTypes.SELECT( AllTypes.Integer.IS_NULL().AS("result.is_null"), AllTypes.DatePtr.IS_NOT_NULL().AS("result.is_not_null"), - AllTypes.SmallIntPtr.IN(Int(11), Int(22)).AS("result.in"), + AllTypes.SmallIntPtr.IN(Int8(11), Int8(22)).AS("result.in"), AllTypes.SmallIntPtr.IN(AllTypes.SELECT(AllTypes.Integer)).AS("result.in_select"), Raw("CURRENT_USER").AS("result.raw"), @@ -233,14 +233,16 @@ func TestExpressionOperators(t *testing.T) { Raw("#1 + all_types.integer + #2 + #1 + #3 + #4", RawArgs{"#1": 11, "#2": 22, "#3": 33, "#4": 44}).AS("result.raw_arg2"), - AllTypes.SmallIntPtr.NOT_IN(Int(11), Int(22), NULL).AS("result.not_in"), + AllTypes.SmallIntPtr.NOT_IN(Int(11), Int16(22), NULL).AS("result.not_in"), AllTypes.SmallIntPtr.NOT_IN(AllTypes.SELECT(AllTypes.Integer)).AS("result.not_in_select"), ).LIMIT(2) + //fmt.Println(query.Sql()) + testutils.AssertStatementSql(t, query, ` SELECT all_types.integer IS NULL AS "result.is_null", all_types.date_ptr IS NOT NULL AS "result.is_not_null", - (all_types.small_int_ptr IN ($1, $2)) AS "result.in", + (all_types.small_int_ptr IN ($1::smallint, $2::smallint)) AS "result.in", (all_types.small_int_ptr IN ( SELECT all_types.integer AS "all_types.integer" FROM test_sample.all_types @@ -248,14 +250,14 @@ SELECT all_types.integer IS NULL AS "result.is_null", (CURRENT_USER) AS "result.raw", ($3 + COALESCE(all_types.small_int_ptr, 0) + $4) AS "result.raw_arg", ($5 + all_types.integer + $6 + $5 + $7 + $8) AS "result.raw_arg2", - (all_types.small_int_ptr NOT IN ($9, $10, NULL)) AS "result.not_in", + (all_types.small_int_ptr NOT IN ($9, $10::smallint, NULL)) AS "result.not_in", (all_types.small_int_ptr NOT IN ( SELECT all_types.integer AS "all_types.integer" FROM test_sample.all_types )) AS "result.not_in_select" FROM test_sample.all_types LIMIT $11; -`, int64(11), int64(22), 78, 56, 11, 22, 33, 44, int64(11), int64(22), int64(2)) +`, int8(11), int8(22), 78, 56, 11, 22, 33, 44, int64(11), int16(22), int64(2)) var dest []struct { common.ExpressionTestResult `alias:"result.*"` @@ -359,6 +361,8 @@ func TestStringOperators(t *testing.T) { AllTypes.Text.LT(String("Text")), AllTypes.Text.LT_EQ(AllTypes.VarChar), AllTypes.Text.LT_EQ(String("Text")), + AllTypes.Text.BETWEEN(String("min"), String("max")), + AllTypes.Text.NOT_BETWEEN(AllTypes.VarChar, AllTypes.CharPtr), AllTypes.Text.CONCAT(String("text2")), AllTypes.Text.CONCAT(Int(11)), AllTypes.Text.LIKE(String("abc")), @@ -450,13 +454,13 @@ func TestBoolOperators(t *testing.T) { testutils.AssertStatementSql(t, query, ` SELECT (all_types.boolean = all_types.boolean_ptr) AS "EQ1", - (all_types.boolean = $1) AS "EQ2", + (all_types.boolean = $1::boolean) AS "EQ2", (all_types.boolean != all_types.boolean_ptr) AS "NEq1", - (all_types.boolean != $2) AS "NEq2", + (all_types.boolean != $2::boolean) AS "NEq2", (all_types.boolean IS DISTINCT FROM all_types.boolean_ptr) AS "distinct1", - (all_types.boolean IS DISTINCT FROM $3) AS "distinct2", + (all_types.boolean IS DISTINCT FROM $3::boolean) AS "distinct2", (all_types.boolean IS NOT DISTINCT FROM all_types.boolean_ptr) AS "not_distinct_1", - (all_types.boolean IS NOT DISTINCT FROM $4) AS "NOTDISTINCT2", + (all_types.boolean IS NOT DISTINCT FROM $4::boolean) AS "NOTDISTINCT2", all_types.boolean IS TRUE AS "ISTRUE", all_types.boolean IS NOT TRUE AS "isnottrue", all_types.boolean IS FALSE AS "is_False", @@ -511,24 +515,26 @@ func TestFloatOperators(t *testing.T) { AllTypes.Numeric.LT(Float(34.56)).AS("lt2"), AllTypes.Numeric.GT(Float(124)).AS("gt1"), AllTypes.Numeric.GT(Float(34.56)).AS("gt2"), - - TRUNC(AllTypes.Decimal.ADD(AllTypes.Decimal), Int(2)).AS("add1"), - TRUNC(AllTypes.Decimal.ADD(Float(11.22)), Int(2)).AS("add2"), - TRUNC(AllTypes.Decimal.SUB(AllTypes.DecimalPtr), Int(2)).AS("sub1"), - TRUNC(AllTypes.Decimal.SUB(Float(11.22)), Int(2)).AS("sub2"), - TRUNC(AllTypes.Decimal.MUL(AllTypes.DecimalPtr), Int(2)).AS("mul1"), - TRUNC(AllTypes.Decimal.MUL(Float(11.22)), Int(2)).AS("mul2"), - TRUNC(AllTypes.Decimal.DIV(AllTypes.DecimalPtr), Int(2)).AS("div1"), - TRUNC(AllTypes.Decimal.DIV(Float(11.22)), Int(2)).AS("div2"), - TRUNC(AllTypes.Decimal.MOD(AllTypes.DecimalPtr), Int(2)).AS("mod1"), - TRUNC(AllTypes.Decimal.MOD(Float(11.22)), Int(2)).AS("mod2"), - TRUNC(AllTypes.Decimal.POW(AllTypes.DecimalPtr), Int(2)).AS("pow1"), - TRUNC(AllTypes.Decimal.POW(Float(2.1)), Int(2)).AS("pow2"), - - TRUNC(ABSf(AllTypes.Decimal), Int(2)).AS("abs"), - TRUNC(POWER(AllTypes.Decimal, Float(2.1)), Int(2)).AS("power"), - TRUNC(SQRT(AllTypes.Decimal), Int(2)).AS("sqrt"), - TRUNC(CAST(CBRT(AllTypes.Decimal)).AS_DECIMAL(), Int(2)).AS("cbrt"), + AllTypes.Numeric.BETWEEN(Float(1.34), AllTypes.Decimal).AS("between"), + AllTypes.Numeric.NOT_BETWEEN(AllTypes.Decimal.MUL(Float(3)), Float(100.12)).AS("not_between"), + + TRUNC(AllTypes.Decimal.ADD(AllTypes.Decimal), Uint8(2)).AS("add1"), + TRUNC(AllTypes.Decimal.ADD(Float(11.22)), Int8(2)).AS("add2"), + TRUNC(AllTypes.Decimal.SUB(AllTypes.DecimalPtr), Uint16(2)).AS("sub1"), + TRUNC(AllTypes.Decimal.SUB(Float(11.22)), Int16(2)).AS("sub2"), + TRUNC(AllTypes.Decimal.MUL(AllTypes.DecimalPtr), Int16(2)).AS("mul1"), + TRUNC(AllTypes.Decimal.MUL(Float(11.22)), Int32(2)).AS("mul2"), + TRUNC(AllTypes.Decimal.DIV(AllTypes.DecimalPtr), Int32(2)).AS("div1"), + TRUNC(AllTypes.Decimal.DIV(Float(11.22)), Int8(2)).AS("div2"), + TRUNC(AllTypes.Decimal.MOD(AllTypes.DecimalPtr), Int8(2)).AS("mod1"), + TRUNC(AllTypes.Decimal.MOD(Float(11.22)), Int8(2)).AS("mod2"), + TRUNC(AllTypes.Decimal.POW(AllTypes.DecimalPtr), Int8(2)).AS("pow1"), + TRUNC(AllTypes.Decimal.POW(Float(2.1)), Int8(2)).AS("pow2"), + + TRUNC(ABSf(AllTypes.Decimal), Int8(2)).AS("abs"), + TRUNC(POWER(AllTypes.Decimal, Float(2.1)), Int8(2)).AS("power"), + TRUNC(SQRT(AllTypes.Decimal), Int16(2)).AS("sqrt"), + TRUNC(CAST(CBRT(AllTypes.Decimal)).AS_DECIMAL(), Int8(2)).AS("cbrt"), CEIL(AllTypes.Real).AS("ceil"), FLOOR(AllTypes.Real).AS("floor"), @@ -536,12 +542,12 @@ func TestFloatOperators(t *testing.T) { ROUND(AllTypes.Decimal, AllTypes.Integer).AS("round2"), SIGN(AllTypes.Real).AS("sign"), - TRUNC(AllTypes.Decimal, Int(1)).AS("trunc"), + TRUNC(AllTypes.Decimal, Int32(1)).AS("trunc"), ).LIMIT(2) - queryStr, _ := query.Sql() + //fmt.Println(query.Sql()) - require.Equal(t, queryStr, ` + testutils.AssertStatementSql(t, query, ` SELECT (all_types.numeric = all_types.numeric) AS "eq1", (all_types.decimal = $1) AS "eq2", (all_types.real = $2) AS "eq3", @@ -555,30 +561,32 @@ SELECT (all_types.numeric = all_types.numeric) AS "eq1", (all_types.numeric < $8) AS "lt2", (all_types.numeric > $9) AS "gt1", (all_types.numeric > $10) AS "gt2", - TRUNC((all_types.decimal + all_types.decimal), $11) AS "add1", - TRUNC((all_types.decimal + $12), $13) AS "add2", - TRUNC((all_types.decimal - all_types.decimal_ptr), $14) AS "sub1", - TRUNC((all_types.decimal - $15), $16) AS "sub2", - TRUNC((all_types.decimal * all_types.decimal_ptr), $17) AS "mul1", - TRUNC((all_types.decimal * $18), $19) AS "mul2", - TRUNC((all_types.decimal / all_types.decimal_ptr), $20) AS "div1", - TRUNC((all_types.decimal / $21), $22) AS "div2", - TRUNC((all_types.decimal % all_types.decimal_ptr), $23) AS "mod1", - TRUNC((all_types.decimal % $24), $25) AS "mod2", - TRUNC(POW(all_types.decimal, all_types.decimal_ptr), $26) AS "pow1", - TRUNC(POW(all_types.decimal, $27), $28) AS "pow2", - TRUNC(ABS(all_types.decimal), $29) AS "abs", - TRUNC(POWER(all_types.decimal, $30), $31) AS "power", - TRUNC(SQRT(all_types.decimal), $32) AS "sqrt", - TRUNC(CBRT(all_types.decimal)::decimal, $33) AS "cbrt", + (all_types.numeric BETWEEN $11 AND all_types.decimal) AS "between", + (all_types.numeric NOT BETWEEN (all_types.decimal * $12) AND $13) AS "not_between", + TRUNC(all_types.decimal + all_types.decimal, $14::smallint) AS "add1", + TRUNC(all_types.decimal + $15, $16::smallint) AS "add2", + TRUNC(all_types.decimal - all_types.decimal_ptr, $17::integer) AS "sub1", + TRUNC(all_types.decimal - $18, $19::smallint) AS "sub2", + TRUNC(all_types.decimal * all_types.decimal_ptr, $20::smallint) AS "mul1", + TRUNC(all_types.decimal * $21, $22::integer) AS "mul2", + TRUNC(all_types.decimal / all_types.decimal_ptr, $23::integer) AS "div1", + TRUNC(all_types.decimal / $24, $25::smallint) AS "div2", + TRUNC(all_types.decimal % all_types.decimal_ptr, $26::smallint) AS "mod1", + TRUNC(all_types.decimal % $27, $28::smallint) AS "mod2", + TRUNC(POW(all_types.decimal, all_types.decimal_ptr), $29::smallint) AS "pow1", + TRUNC(POW(all_types.decimal, $30), $31::smallint) AS "pow2", + TRUNC(ABS(all_types.decimal), $32::smallint) AS "abs", + TRUNC(POWER(all_types.decimal, $33), $34::smallint) AS "power", + TRUNC(SQRT(all_types.decimal), $35::smallint) AS "sqrt", + TRUNC(CBRT(all_types.decimal)::decimal, $36::smallint) AS "cbrt", CEIL(all_types.real) AS "ceil", FLOOR(all_types.real) AS "floor", ROUND(all_types.decimal) AS "round1", ROUND(all_types.decimal, all_types.integer) AS "round2", SIGN(all_types.real) AS "sign", - TRUNC(all_types.decimal, $34) AS "trunc" + TRUNC(all_types.decimal, $37::integer) AS "trunc" FROM test_sample.all_types -LIMIT $35; +LIMIT $38; `) var dest []struct { @@ -590,6 +598,7 @@ LIMIT $35; require.NoError(t, err) //testutils.PrintJson(dest) + // testutils.SaveJSONFile(dest, "./testdata/results/common/float_operators.json") testutils.AssertJSONFile(t, dest, "./testdata/results/common/float_operators.json") } @@ -602,62 +611,50 @@ func TestIntegerOperators(t *testing.T) { AllTypes.SmallIntPtr, AllTypes.BigInt.EQ(AllTypes.BigInt).AS("eq1"), - AllTypes.BigInt.EQ(Int(12)).AS("eq2"), - + AllTypes.BigInt.EQ(Int64(12)).AS("eq2"), AllTypes.BigInt.NOT_EQ(AllTypes.BigIntPtr).AS("neq1"), - AllTypes.BigInt.NOT_EQ(Int(12)).AS("neq2"), - + AllTypes.BigInt.NOT_EQ(Int64(12)).AS("neq2"), AllTypes.BigInt.IS_DISTINCT_FROM(AllTypes.BigInt).AS("distinct1"), - AllTypes.BigInt.IS_DISTINCT_FROM(Int(12)).AS("distinct2"), - + AllTypes.BigInt.IS_DISTINCT_FROM(Int32(12)).AS("distinct2"), AllTypes.BigInt.IS_NOT_DISTINCT_FROM(AllTypes.BigInt).AS("not distinct1"), - AllTypes.BigInt.IS_NOT_DISTINCT_FROM(Int(12)).AS("not distinct2"), + AllTypes.BigInt.IS_NOT_DISTINCT_FROM(Int32(12)).AS("not distinct2"), + AllTypes.Integer.BETWEEN(Int(11), Int(200)).AS("between"), + AllTypes.Integer.NOT_BETWEEN(Int(66), Int(77)).AS("not_between"), AllTypes.BigInt.LT(AllTypes.BigIntPtr).AS("lt1"), - AllTypes.BigInt.LT(Int(65)).AS("lt2"), - + AllTypes.BigInt.LT(Uint8(65)).AS("lt2"), AllTypes.BigInt.LT_EQ(AllTypes.BigIntPtr).AS("lte1"), - AllTypes.BigInt.LT_EQ(Int(65)).AS("lte2"), - + AllTypes.BigInt.LT_EQ(Uint16(65)).AS("lte2"), AllTypes.BigInt.GT(AllTypes.BigIntPtr).AS("gt1"), - AllTypes.BigInt.GT(Int(65)).AS("gt2"), - + AllTypes.BigInt.GT(Uint32(65)).AS("gt2"), AllTypes.BigInt.GT_EQ(AllTypes.BigIntPtr).AS("gte1"), - AllTypes.BigInt.GT_EQ(Int(65)).AS("gte2"), + AllTypes.BigInt.GT_EQ(Uint64(65)).AS("gte2"), AllTypes.BigInt.ADD(AllTypes.BigInt).AS("add1"), AllTypes.BigInt.ADD(Int(11)).AS("add2"), - AllTypes.BigInt.SUB(AllTypes.BigInt).AS("sub1"), - AllTypes.BigInt.SUB(Int(11)).AS("sub2"), - + AllTypes.BigInt.SUB(Int8(11)).AS("sub2"), AllTypes.BigInt.MUL(AllTypes.BigInt).AS("mul1"), - AllTypes.BigInt.MUL(Int(11)).AS("mul2"), - + AllTypes.BigInt.MUL(Int16(11)).AS("mul2"), AllTypes.BigInt.DIV(AllTypes.BigInt).AS("div1"), - AllTypes.BigInt.DIV(Int(11)).AS("div2"), - + AllTypes.BigInt.DIV(Int32(11)).AS("div2"), AllTypes.BigInt.MOD(AllTypes.BigInt).AS("mod1"), - AllTypes.BigInt.MOD(Int(11)).AS("mod2"), - - AllTypes.SmallInt.POW(AllTypes.SmallInt.DIV(Int(3))).AS("pow1"), - AllTypes.SmallInt.POW(Int(6)).AS("pow2"), + AllTypes.BigInt.MOD(Int64(11)).AS("mod2"), + AllTypes.SmallInt.POW(AllTypes.SmallInt.DIV(Int8(3))).AS("pow1"), + AllTypes.SmallInt.POW(Int8(6)).AS("pow2"), AllTypes.SmallInt.BIT_AND(AllTypes.SmallInt).AS("bit_and1"), AllTypes.SmallInt.BIT_AND(AllTypes.SmallInt).AS("bit_and2"), - AllTypes.SmallInt.BIT_OR(AllTypes.SmallInt).AS("bit or 1"), AllTypes.SmallInt.BIT_OR(Int(22)).AS("bit or 2"), - AllTypes.SmallInt.BIT_XOR(AllTypes.SmallInt).AS("bit xor 1"), AllTypes.SmallInt.BIT_XOR(Int(11)).AS("bit xor 2"), BIT_NOT(Int(-1).MUL(AllTypes.SmallInt)).AS("bit_not_1"), BIT_NOT(Int(-11)).AS("bit_not_2"), - AllTypes.SmallInt.BIT_SHIFT_LEFT(AllTypes.SmallInt.DIV(Int(2))).AS("bit shift left 1"), + AllTypes.SmallInt.BIT_SHIFT_LEFT(AllTypes.SmallInt.DIV(Int8(2))).AS("bit shift left 1"), AllTypes.SmallInt.BIT_SHIFT_LEFT(Int(4)).AS("bit shift left 2"), - AllTypes.SmallInt.BIT_SHIFT_RIGHT(AllTypes.SmallInt.DIV(Int(5))).AS("bit shift right 1"), AllTypes.SmallInt.BIT_SHIFT_RIGHT(Int(1)).AS("bit shift right 2"), @@ -666,7 +663,7 @@ func TestIntegerOperators(t *testing.T) { CBRT(ABSi(AllTypes.BigInt)).AS("cbrt"), ).LIMIT(2) - //fmt.Println(query.Sql()) + // fmt.Println(query.Sql()) testutils.AssertStatementSql(t, query, ` SELECT all_types.big_int AS "all_types.big_int", @@ -674,50 +671,52 @@ SELECT all_types.big_int AS "all_types.big_int", all_types.small_int AS "all_types.small_int", all_types.small_int_ptr AS "all_types.small_int_ptr", (all_types.big_int = all_types.big_int) AS "eq1", - (all_types.big_int = $1) AS "eq2", + (all_types.big_int = $1::bigint) AS "eq2", (all_types.big_int != all_types.big_int_ptr) AS "neq1", - (all_types.big_int != $2) AS "neq2", + (all_types.big_int != $2::bigint) AS "neq2", (all_types.big_int IS DISTINCT FROM all_types.big_int) AS "distinct1", - (all_types.big_int IS DISTINCT FROM $3) AS "distinct2", + (all_types.big_int IS DISTINCT FROM $3::integer) AS "distinct2", (all_types.big_int IS NOT DISTINCT FROM all_types.big_int) AS "not distinct1", - (all_types.big_int IS NOT DISTINCT FROM $4) AS "not distinct2", + (all_types.big_int IS NOT DISTINCT FROM $4::integer) AS "not distinct2", + (all_types.integer BETWEEN $5 AND $6) AS "between", + (all_types.integer NOT BETWEEN $7 AND $8) AS "not_between", (all_types.big_int < all_types.big_int_ptr) AS "lt1", - (all_types.big_int < $5) AS "lt2", + (all_types.big_int < $9::smallint) AS "lt2", (all_types.big_int <= all_types.big_int_ptr) AS "lte1", - (all_types.big_int <= $6) AS "lte2", + (all_types.big_int <= $10::integer) AS "lte2", (all_types.big_int > all_types.big_int_ptr) AS "gt1", - (all_types.big_int > $7) AS "gt2", + (all_types.big_int > $11::bigint) AS "gt2", (all_types.big_int >= all_types.big_int_ptr) AS "gte1", - (all_types.big_int >= $8) AS "gte2", + (all_types.big_int >= $12::bigint) AS "gte2", (all_types.big_int + all_types.big_int) AS "add1", - (all_types.big_int + $9) AS "add2", + (all_types.big_int + $13) AS "add2", (all_types.big_int - all_types.big_int) AS "sub1", - (all_types.big_int - $10) AS "sub2", + (all_types.big_int - $14::smallint) AS "sub2", (all_types.big_int * all_types.big_int) AS "mul1", - (all_types.big_int * $11) AS "mul2", + (all_types.big_int * $15::smallint) AS "mul2", (all_types.big_int / all_types.big_int) AS "div1", - (all_types.big_int / $12) AS "div2", + (all_types.big_int / $16::integer) AS "div2", (all_types.big_int % all_types.big_int) AS "mod1", - (all_types.big_int % $13) AS "mod2", - POW(all_types.small_int, (all_types.small_int / $14)) AS "pow1", - POW(all_types.small_int, $15) AS "pow2", + (all_types.big_int % $17::bigint) AS "mod2", + POW(all_types.small_int, all_types.small_int / $18::smallint) AS "pow1", + POW(all_types.small_int, $19::smallint) AS "pow2", (all_types.small_int & all_types.small_int) AS "bit_and1", (all_types.small_int & all_types.small_int) AS "bit_and2", (all_types.small_int | all_types.small_int) AS "bit or 1", - (all_types.small_int | $16) AS "bit or 2", + (all_types.small_int | $20) AS "bit or 2", (all_types.small_int # all_types.small_int) AS "bit xor 1", - (all_types.small_int # $17) AS "bit xor 2", - (~ ($18 * all_types.small_int)) AS "bit_not_1", + (all_types.small_int # $21) AS "bit xor 2", + (~ ($22 * all_types.small_int)) AS "bit_not_1", (~ -11) AS "bit_not_2", - (all_types.small_int << (all_types.small_int / $19)) AS "bit shift left 1", - (all_types.small_int << $20) AS "bit shift left 2", - (all_types.small_int >> (all_types.small_int / $21)) AS "bit shift right 1", - (all_types.small_int >> $22) AS "bit shift right 2", + (all_types.small_int << (all_types.small_int / $23::smallint)) AS "bit shift left 1", + (all_types.small_int << $24) AS "bit shift left 2", + (all_types.small_int >> (all_types.small_int / $25)) AS "bit shift right 1", + (all_types.small_int >> $26) AS "bit shift right 2", ABS(all_types.big_int) AS "abs", SQRT(ABS(all_types.big_int)) AS "sqrt", CBRT(ABS(all_types.big_int)) AS "cbrt" FROM test_sample.all_types -LIMIT $23; +LIMIT $27; `) var dest []struct { @@ -728,7 +727,7 @@ LIMIT $23; require.NoError(t, err) - //testutils.SaveJsonFile("./testdata/common/int_operators.json", dest) + //testutils.SaveJSONFile(dest, "./testdata/results/common/int_operators.json") //testutils.PrintJson(dest) testutils.AssertJSONFile(t, dest, "./testdata/results/common/int_operators.json") } @@ -759,21 +758,18 @@ func TestTimeExpression(t *testing.T) { AllTypes.Time.IS_DISTINCT_FROM(AllTypes.Time), AllTypes.Time.IS_DISTINCT_FROM(Time(23, 6, 6, 100)), - AllTypes.Time.IS_NOT_DISTINCT_FROM(AllTypes.Time), AllTypes.Time.IS_NOT_DISTINCT_FROM(Time(23, 6, 6, 200)), - AllTypes.Time.LT(AllTypes.Time), AllTypes.Time.LT(Time(23, 6, 6, 22)), - AllTypes.Time.LT_EQ(AllTypes.Time), AllTypes.Time.LT_EQ(Time(23, 6, 6, 33)), - AllTypes.Time.GT(AllTypes.Time), AllTypes.Time.GT(Time(23, 6, 6, 0)), - AllTypes.Time.GT_EQ(AllTypes.Time), AllTypes.Time.GT_EQ(Time(23, 6, 6, 1)), + AllTypes.Time.BETWEEN(Time(11, 0, 30, 100), TimeT(time.Now())), + AllTypes.Time.NOT_BETWEEN(AllTypes.TimePtr, AllTypes.Time.ADD(INTERVAL(2, HOUR))), AllTypes.Date.ADD(INTERVAL(1, HOUR)), AllTypes.Date.SUB(INTERVAL(1, MINUTE)), @@ -781,12 +777,20 @@ func TestTimeExpression(t *testing.T) { AllTypes.Time.SUB(INTERVAL(1, MINUTE)), AllTypes.Timez.ADD(INTERVAL(1, HOUR)), AllTypes.Timez.SUB(INTERVAL(1, MINUTE)), + AllTypes.Timez.BETWEEN(TimezT(time.Now()), AllTypes.TimezPtr), + AllTypes.Timez.NOT_BETWEEN(AllTypes.Timez, TimezT(time.Now())), AllTypes.Timestamp.ADD(INTERVAL(1, HOUR)), AllTypes.Timestamp.SUB(INTERVAL(1, MINUTE)), + AllTypes.Timestamp.BETWEEN(AllTypes.TimestampPtr, TimestampT(time.Now())), + AllTypes.Timestamp.NOT_BETWEEN(TimestampT(time.Now()), AllTypes.TimestampPtr), AllTypes.Timestampz.ADD(INTERVAL(1, HOUR)), AllTypes.Timestampz.SUB(INTERVAL(1, MINUTE)), + AllTypes.Timestamp.BETWEEN(AllTypes.TimestampPtr, TimestampT(time.Now())), + AllTypes.Timestamp.NOT_BETWEEN(AllTypes.TimestampPtr, TimestampT(time.Now())), AllTypes.Date.SUB(CAST(String("04:05:06")).AS_INTERVAL()), + AllTypes.Date.BETWEEN(Date(2000, 2, 2), DateT(time.Now())), + AllTypes.Date.NOT_BETWEEN(AllTypes.DatePtr, DateT(time.Now().Add(20*time.Hour))), CURRENT_DATE(), CURRENT_TIME(), @@ -847,6 +851,8 @@ func TestInterval(t *testing.T) { AllTypes.Interval.LT_EQ(AllTypes.IntervalPtr).EQ(AllTypes.BooleanPtr), AllTypes.Interval.GT(AllTypes.IntervalPtr).EQ(AllTypes.BooleanPtr), AllTypes.Interval.GT_EQ(AllTypes.IntervalPtr).EQ(AllTypes.BooleanPtr), + AllTypes.Interval.BETWEEN(INTERVAL(1, HOUR), INTERVAL(2, HOUR)), + AllTypes.Interval.NOT_BETWEEN(AllTypes.IntervalPtr, INTERVALd(30*time.Second)), AllTypes.Interval.ADD(AllTypes.IntervalPtr).EQ(INTERVALd(17*time.Second)), AllTypes.Interval.SUB(AllTypes.IntervalPtr).EQ(INTERVAL(100, MICROSECOND)), AllTypes.IntervalPtr.MUL(Int(11)).EQ(AllTypes.Interval), diff --git a/tests/postgres/chinook_db_test.go b/tests/postgres/chinook_db_test.go index 553f920a..684abc2c 100644 --- a/tests/postgres/chinook_db_test.go +++ b/tests/postgres/chinook_db_test.go @@ -35,6 +35,7 @@ ORDER BY "Album"."AlbumId" ASC; testutils.AssertDeepEqual(t, dest[1], album2) testutils.AssertDeepEqual(t, dest[len(dest)-1], album347) requireLogged(t, stmt) + requireQueryLogged(t, stmt, 347) } func TestJoinEverything(t *testing.T) { @@ -101,12 +102,341 @@ func TestJoinEverything(t *testing.T) { } } - err := stmt.Query(db, &dest) + testutils.AssertStatementSql(t, stmt, ` +SELECT "Artist"."ArtistId" AS "Artist.ArtistId", + "Artist"."Name" AS "Artist.Name", + "Album"."AlbumId" AS "Album.AlbumId", + "Album"."Title" AS "Album.Title", + "Album"."ArtistId" AS "Album.ArtistId", + "Track"."TrackId" AS "Track.TrackId", + "Track"."Name" AS "Track.Name", + "Track"."AlbumId" AS "Track.AlbumId", + "Track"."MediaTypeId" AS "Track.MediaTypeId", + "Track"."GenreId" AS "Track.GenreId", + "Track"."Composer" AS "Track.Composer", + "Track"."Milliseconds" AS "Track.Milliseconds", + "Track"."Bytes" AS "Track.Bytes", + "Track"."UnitPrice" AS "Track.UnitPrice", + "Genre"."GenreId" AS "Genre.GenreId", + "Genre"."Name" AS "Genre.Name", + "MediaType"."MediaTypeId" AS "MediaType.MediaTypeId", + "MediaType"."Name" AS "MediaType.Name", + "PlaylistTrack"."PlaylistId" AS "PlaylistTrack.PlaylistId", + "PlaylistTrack"."TrackId" AS "PlaylistTrack.TrackId", + "Playlist"."PlaylistId" AS "Playlist.PlaylistId", + "Playlist"."Name" AS "Playlist.Name", + "Invoice"."InvoiceId" AS "Invoice.InvoiceId", + "Invoice"."CustomerId" AS "Invoice.CustomerId", + "Invoice"."InvoiceDate" AS "Invoice.InvoiceDate", + "Invoice"."BillingAddress" AS "Invoice.BillingAddress", + "Invoice"."BillingCity" AS "Invoice.BillingCity", + "Invoice"."BillingState" AS "Invoice.BillingState", + "Invoice"."BillingCountry" AS "Invoice.BillingCountry", + "Invoice"."BillingPostalCode" AS "Invoice.BillingPostalCode", + "Invoice"."Total" AS "Invoice.Total", + "Customer"."CustomerId" AS "Customer.CustomerId", + "Customer"."FirstName" AS "Customer.FirstName", + "Customer"."LastName" AS "Customer.LastName", + "Customer"."Company" AS "Customer.Company", + "Customer"."Address" AS "Customer.Address", + "Customer"."City" AS "Customer.City", + "Customer"."State" AS "Customer.State", + "Customer"."Country" AS "Customer.Country", + "Customer"."PostalCode" AS "Customer.PostalCode", + "Customer"."Phone" AS "Customer.Phone", + "Customer"."Fax" AS "Customer.Fax", + "Customer"."Email" AS "Customer.Email", + "Customer"."SupportRepId" AS "Customer.SupportRepId", + "Employee"."EmployeeId" AS "Employee.EmployeeId", + "Employee"."LastName" AS "Employee.LastName", + "Employee"."FirstName" AS "Employee.FirstName", + "Employee"."Title" AS "Employee.Title", + "Employee"."ReportsTo" AS "Employee.ReportsTo", + "Employee"."BirthDate" AS "Employee.BirthDate", + "Employee"."HireDate" AS "Employee.HireDate", + "Employee"."Address" AS "Employee.Address", + "Employee"."City" AS "Employee.City", + "Employee"."State" AS "Employee.State", + "Employee"."Country" AS "Employee.Country", + "Employee"."PostalCode" AS "Employee.PostalCode", + "Employee"."Phone" AS "Employee.Phone", + "Employee"."Fax" AS "Employee.Fax", + "Employee"."Email" AS "Employee.Email", + "Manager"."EmployeeId" AS "Manager.EmployeeId", + "Manager"."LastName" AS "Manager.LastName", + "Manager"."FirstName" AS "Manager.FirstName", + "Manager"."Title" AS "Manager.Title", + "Manager"."ReportsTo" AS "Manager.ReportsTo", + "Manager"."BirthDate" AS "Manager.BirthDate", + "Manager"."HireDate" AS "Manager.HireDate", + "Manager"."Address" AS "Manager.Address", + "Manager"."City" AS "Manager.City", + "Manager"."State" AS "Manager.State", + "Manager"."Country" AS "Manager.Country", + "Manager"."PostalCode" AS "Manager.PostalCode", + "Manager"."Phone" AS "Manager.Phone", + "Manager"."Fax" AS "Manager.Fax", + "Manager"."Email" AS "Manager.Email" +FROM chinook."Artist" + LEFT JOIN chinook."Album" ON ("Artist"."ArtistId" = "Album"."ArtistId") + LEFT JOIN chinook."Track" ON ("Track"."AlbumId" = "Album"."AlbumId") + LEFT JOIN chinook."Genre" ON ("Genre"."GenreId" = "Track"."GenreId") + LEFT JOIN chinook."MediaType" ON ("MediaType"."MediaTypeId" = "Track"."MediaTypeId") + LEFT JOIN chinook."PlaylistTrack" ON ("PlaylistTrack"."TrackId" = "Track"."TrackId") + LEFT JOIN chinook."Playlist" ON ("Playlist"."PlaylistId" = "PlaylistTrack"."PlaylistId") + LEFT JOIN chinook."InvoiceLine" ON ("InvoiceLine"."TrackId" = "Track"."TrackId") + LEFT JOIN chinook."Invoice" ON ("Invoice"."InvoiceId" = "InvoiceLine"."InvoiceId") + LEFT JOIN chinook."Customer" ON ("Customer"."CustomerId" = "Invoice"."CustomerId") + LEFT JOIN chinook."Employee" ON ("Employee"."EmployeeId" = "Customer"."SupportRepId") + LEFT JOIN chinook."Employee" AS "Manager" ON ("Manager"."EmployeeId" = "Employee"."ReportsTo") +ORDER BY "Artist"."ArtistId", "Album"."AlbumId", "Track"."TrackId", "Genre"."GenreId", "MediaType"."MediaTypeId", "Playlist"."PlaylistId", "Invoice"."InvoiceId", "Customer"."CustomerId"; +`) + + err := stmt.QueryContext(context.Background(), db, &dest) require.NoError(t, err) require.Equal(t, len(dest), 275) testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/joined_everything.json") requireLogged(t, stmt) + requireQueryLogged(t, stmt, 9423) +} + +// default column aliases from sub-CTEs are bubbled up to the main query, +// cte name does not affect default column alias in main query +func TestSubQueryColumnAliasBubbling(t *testing.T) { + subQuery1 := SELECT( + Artist.AllColumns, + String("custom_column_1").AS("custom_column_1"), + ).FROM( + Artist, + ).ORDER_BY( + Artist.ArtistId.ASC(), + ).AsTable("subQuery1") + + subQuery2 := SELECT( + subQuery1.AllColumns(), + String("custom_column_2").AS("custom_column_2"), + ).FROM( + subQuery1, + ).AsTable("subQuery2") + + mainQuery := SELECT( + subQuery2.AllColumns(), // columns will have the same alias as in the sub-query + subQuery2.AllColumns().As("artist2.*"), // all column aliases will be changed to artist2.* + subQuery2.AllColumns().Except(Artist.Name).As("artist3.*"), + subQuery2.AllColumns().Except( + Artist.MutableColumns, + StringColumn("custom_column_1").From(subQuery2), // custom_column_1 appears with the same alias in subQuery2 + StringColumn("custom_column_2").From(subQuery2), + ).As("artist4.*"), + ).FROM( + subQuery2, + ) + + // fmt.Println(mainQuery.Sql()) + + testutils.AssertStatementSql(t, mainQuery, ` +SELECT "subQuery2"."Artist.ArtistId" AS "Artist.ArtistId", + "subQuery2"."Artist.Name" AS "Artist.Name", + "subQuery2".custom_column_1 AS "custom_column_1", + "subQuery2".custom_column_2 AS "custom_column_2", + "subQuery2"."Artist.ArtistId" AS "artist2.ArtistId", + "subQuery2"."Artist.Name" AS "artist2.Name", + "subQuery2".custom_column_1 AS "artist2.custom_column_1", + "subQuery2".custom_column_2 AS "artist2.custom_column_2", + "subQuery2"."Artist.ArtistId" AS "artist3.ArtistId", + "subQuery2".custom_column_1 AS "artist3.custom_column_1", + "subQuery2".custom_column_2 AS "artist3.custom_column_2", + "subQuery2"."Artist.ArtistId" AS "artist4.ArtistId" +FROM ( + SELECT "subQuery1"."Artist.ArtistId" AS "Artist.ArtistId", + "subQuery1"."Artist.Name" AS "Artist.Name", + "subQuery1".custom_column_1 AS "custom_column_1", + $1 AS "custom_column_2" + FROM ( + SELECT "Artist"."ArtistId" AS "Artist.ArtistId", + "Artist"."Name" AS "Artist.Name", + $2 AS "custom_column_1" + FROM chinook."Artist" + ORDER BY "Artist"."ArtistId" ASC + ) AS "subQuery1" + ) AS "subQuery2"; +`) + var dest []struct { + // subQuery2.AllColumns() + Artist1 struct { + model.Artist + + CustomColumn1 string + CustomColumn2 string + } + + // subQuery2.AllColumns().As("artist2.*") + Artist2 struct { + model.Artist `alias:"artist2.*"` + + CustomColumn1 string + CustomColumn2 string + } `alias:"artist2.*"` + + // subQuery2.AllColumns().Except(Artist.Name).As("artist3.*") + Artist3 struct { + model.Artist `alias:"artist3.*"` + + CustomColumn1 string + CustomColumn2 string + } `alias:"artist3.*"` + + // subQuery2.AllColumns().Except(...).As("artist4.*") + Artist4 struct { + model.Artist `alias:"artist4.*"` + + CustomColumn1 string + CustomColumn2 string + } `alias:"artist4.*"` + } + + err := mainQuery.Query(db, &dest) + require.NoError(t, err) + + // Artist1 + require.Len(t, dest, 275) + require.Equal(t, dest[0].Artist1.Artist, model.Artist{ + ArtistId: 1, + Name: testutils.StringPtr("AC/DC"), + }) + require.Equal(t, dest[0].Artist1.CustomColumn1, "custom_column_1") + require.Equal(t, dest[0].Artist1.CustomColumn2, "custom_column_2") + + // Artist2 + require.Equal(t, testutils.ToJSON(dest[0].Artist1), testutils.ToJSON(dest[0].Artist2)) + + // Artist3 + require.Equal(t, dest[0].Artist3.ArtistId, int32(1)) + require.Nil(t, dest[0].Artist3.Name) + require.Equal(t, dest[0].Artist3.CustomColumn1, "custom_column_1") + require.Equal(t, dest[0].Artist3.CustomColumn2, "custom_column_2") + + // Artist4 + require.Equal(t, dest[0].Artist3.Artist, dest[0].Artist4.Artist) + require.Equal(t, dest[0].Artist4.CustomColumn1, "") + require.Equal(t, dest[0].Artist4.CustomColumn2, "") +} + +func TestUnAliasedNamesPanicError(t *testing.T) { + subQuery1 := SELECT( + Artist.AllColumns, + Artist.Name.CONCAT(String("-musician")), //alias missing + ).FROM( + Artist, + ).ORDER_BY( + Artist.ArtistId.ASC(), + ).AsTable("subQuery1") + + require.Panics(t, func() { + SELECT( + subQuery1.AllColumns(), // panic, column not aliased + ).FROM( + subQuery1, + ) + }, "jet: can't export unaliased expression subQuery: subQuery1, expression: (\"Artist\".\"Name\" || '-musician')") +} + +func TestProjectionListReAliasing(t *testing.T) { + projectionList := ProjectionList{ + Track.GenreId, + SUM(Track.Milliseconds).AS("duration"), + MAX(Track.Milliseconds).AS("duration.max"), + } + + stmt := SELECT( + projectionList.As("genre_info"), + ).FROM( + Track, + ).WHERE( + Track.GenreId.LT(Int(5)), + ).GROUP_BY( + Track.GenreId, + ).ORDER_BY( + Track.GenreId, + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT "Track"."GenreId" AS "genre_info.GenreId", + SUM("Track"."Milliseconds") AS "genre_info.duration", + MAX("Track"."Milliseconds") AS "genre_info.max" +FROM chinook."Track" +WHERE "Track"."GenreId" < 5 +GROUP BY "Track"."GenreId" +ORDER BY "Track"."GenreId"; +`) + + type GenreInfo struct { + GenreID string + Duration int64 + Max int64 + } + + var dest []GenreInfo + + err := stmt.Query(db, &dest) + require.NoError(t, err) + + expectedSQL := ` +[ + { + "GenreID": "1", + "Duration": 368231326, + "Max": 1612329 + }, + { + "GenreID": "2", + "Duration": 37928199, + "Max": 907520 + }, + { + "GenreID": "3", + "Duration": 115846292, + "Max": 816509 + }, + { + "GenreID": "4", + "Duration": 77805478, + "Max": 558602 + } +] +` + testutils.AssertJSON(t, dest, expectedSQL) + + subQuery := stmt.AsTable("subQuery") + + mainStmt := SELECT( + subQuery.AllColumns().As("genre_information.*"), + ).FROM( + subQuery, + ) + + testutils.AssertDebugStatementSql(t, mainStmt, ` +SELECT "subQuery"."genre_info.GenreId" AS "genre_information.GenreId", + "subQuery"."genre_info.duration" AS "genre_information.duration", + "subQuery"."genre_info.max" AS "genre_information.max" +FROM ( + SELECT "Track"."GenreId" AS "genre_info.GenreId", + SUM("Track"."Milliseconds") AS "genre_info.duration", + MAX("Track"."Milliseconds") AS "genre_info.max" + FROM chinook."Track" + WHERE "Track"."GenreId" < 5 + GROUP BY "Track"."GenreId" + ORDER BY "Track"."GenreId" + ) AS "subQuery"; +`) + + type GenreInformation GenreInfo + var newDest []GenreInformation + + err = mainStmt.Query(db, &newDest) + require.NoError(t, err) + testutils.AssertJSON(t, dest, expectedSQL) } func TestSelfJoin(t *testing.T) { @@ -413,3 +743,53 @@ var album347 = model.Album{ Title: "Koyaanisqatsi (Soundtrack from the Motion Picture)", ArtistId: 275, } + +func TestAggregateFunc(t *testing.T) { + stmt := SELECT( + PERCENTILE_DISC(Float(0.1)).WITHIN_GROUP_ORDER_BY(Invoice.InvoiceId).AS("percentile_disc_1"), + PERCENTILE_DISC(Invoice.Total.DIV(Float(100))).WITHIN_GROUP_ORDER_BY(Invoice.InvoiceDate.ASC()).AS("percentile_disc_2"), + PERCENTILE_DISC(RawFloat("(select array_agg(s) from generate_series(0, 1, 0.2) as s)")). + WITHIN_GROUP_ORDER_BY(Invoice.BillingAddress.DESC()).AS("percentile_disc_3"), + + PERCENTILE_CONT(Float(0.3)).WITHIN_GROUP_ORDER_BY(Invoice.Total).AS("percentile_cont_1"), + PERCENTILE_CONT(Float(0.2)).WITHIN_GROUP_ORDER_BY(INTERVAL(1, HOUR).DESC()).AS("percentile_cont_int"), + + MODE().WITHIN_GROUP_ORDER_BY(Invoice.BillingPostalCode.DESC()).AS("mode_1"), + ).FROM( + Invoice, + ).GROUP_BY( + Invoice.Total, + ) + + testutils.AssertStatementSql(t, stmt, ` +SELECT PERCENTILE_DISC ($1::double precision) WITHIN GROUP (ORDER BY "Invoice"."InvoiceId") AS "percentile_disc_1", + PERCENTILE_DISC ("Invoice"."Total" / $2) WITHIN GROUP (ORDER BY "Invoice"."InvoiceDate" ASC) AS "percentile_disc_2", + PERCENTILE_DISC ((select array_agg(s) from generate_series(0, 1, 0.2) as s)) WITHIN GROUP (ORDER BY "Invoice"."BillingAddress" DESC) AS "percentile_disc_3", + PERCENTILE_CONT ($3::double precision) WITHIN GROUP (ORDER BY "Invoice"."Total") AS "percentile_cont_1", + PERCENTILE_CONT ($4::double precision) WITHIN GROUP (ORDER BY INTERVAL '1 HOUR' DESC) AS "percentile_cont_int", + MODE () WITHIN GROUP (ORDER BY "Invoice"."BillingPostalCode" DESC) AS "mode_1" +FROM chinook."Invoice" +GROUP BY "Invoice"."Total"; +`, 0.1, 100.0, 0.3, 0.2) + + var dest struct { + PercentileDisc1 string + PercentileDisc2 string + PercentileDisc3 string + PercentileCont1 string + Mode1 string + } + + err := stmt.Query(db, &dest) + + require.NoError(t, err) + testutils.AssertJSON(t, dest, ` +{ + "PercentileDisc1": "41", + "PercentileDisc2": "2009-01-19T00:00:00Z", + "PercentileDisc3": "{\"Via Degli Scipioni, 43\",\"Qe 7 Bloco G\",\"Berger Stra�e 10\",\"696 Osborne Street\",\"2211 W Berry Street\",\"1033 N Park Ave\"}", + "PercentileCont1": "0.99", + "Mode1": "X1A 1N6" +} +`) +} diff --git a/tests/postgres/delete_test.go b/tests/postgres/delete_test.go index 5115a3f7..47637e1e 100644 --- a/tests/postgres/delete_test.go +++ b/tests/postgres/delete_test.go @@ -4,6 +4,8 @@ import ( "context" "github.com/go-jet/jet/v2/internal/testutils" . "github.com/go-jet/jet/v2/postgres" + model2 "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model" + "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/table" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/model" . "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/table" "github.com/stretchr/testify/require" @@ -23,7 +25,14 @@ WHERE link.name IN ('Gmail', 'Outlook'); WHERE(Link.Name.IN(String("Gmail"), String("Outlook"))) testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook") - AssertExec(t, deleteStmt, 2) + + res, err := deleteStmt.ExecContext(context.Background(), db) + + require.NoError(t, err) + rows, err := res.RowsAffected() + require.NoError(t, err) + require.Equal(t, rows, int64(2)) + requireQueryLogged(t, deleteStmt, int64(2)) } func TestDeleteWithWhereAndReturning(t *testing.T) { @@ -103,3 +112,72 @@ func TestDeleteExecContext(t *testing.T) { require.Error(t, err, "context deadline exceeded") requireLogged(t, deleteStmt) } + +func TestDeleteFrom(t *testing.T) { + tx := beginTx(t) + defer tx.Rollback() + + stmt := table.Rental.DELETE(). + USING( + table.Staff. + INNER_JOIN(table.Store, table.Store.StoreID.EQ(table.Staff.StaffID)), + table.Actor, + ). + WHERE( + table.Staff.StaffID.EQ(table.Rental.StaffID). + AND(table.Staff.StaffID.EQ(Int(2))). + AND(table.Rental.RentalID.LT(Int(10))), + ). + RETURNING( + table.Rental.AllColumns, + table.Store.AllColumns, + ) + + testutils.AssertStatementSql(t, stmt, ` +DELETE FROM dvds.rental +USING dvds.staff + INNER JOIN dvds.store ON (store.store_id = staff.staff_id), + dvds.actor +WHERE ((staff.staff_id = rental.staff_id) AND (staff.staff_id = $1)) AND (rental.rental_id < $2) +RETURNING rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update", + store.store_id AS "store.store_id", + store.manager_staff_id AS "store.manager_staff_id", + store.address_id AS "store.address_id", + store.last_update AS "store.last_update"; +`) + + var dest []struct { + Rental model2.Rental + Store model2.Store + } + + err := stmt.Query(tx, &dest) + + require.NoError(t, err) + require.Len(t, dest, 3) + testutils.AssertJSON(t, dest[0], ` +{ + "Rental": { + "RentalID": 4, + "RentalDate": "2005-05-24T23:04:41Z", + "InventoryID": 2452, + "CustomerID": 333, + "ReturnDate": "2005-06-03T01:43:41Z", + "StaffID": 2, + "LastUpdate": "2006-02-16T02:30:53Z" + }, + "Store": { + "StoreID": 2, + "ManagerStaffID": 2, + "AddressID": 2, + "LastUpdate": "2006-02-15T09:57:12Z" + } +} +`) +} diff --git a/tests/postgres/generator_template_test.go b/tests/postgres/generator_template_test.go index 85dd01e4..e8f59c77 100644 --- a/tests/postgres/generator_template_test.go +++ b/tests/postgres/generator_template_test.go @@ -27,7 +27,7 @@ var defaultActorSQLBuilderFilePath = path.Join(tempTestDir, "jetdb/dvds/table", var dbConnection = postgres.DBConnection{ Host: dbconfig.PgHost, - Port: 5432, + Port: dbconfig.PgPort, User: dbconfig.PgUser, Password: dbconfig.PgPassword, DBName: dbconfig.PgDBName, diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index 77c8aee1..852745bd 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -7,6 +7,7 @@ import ( "os/exec" "path/filepath" "reflect" + "strconv" "testing" "github.com/go-jet/jet/v2/generator/postgres" @@ -52,8 +53,13 @@ func TestCmdGenerator(t *testing.T) { err := os.RemoveAll(genTestDir2) require.NoError(t, err) - cmd := exec.Command("jet", "-source=PostgreSQL", "-dbname=jetdb", "-host=localhost", "-port=5432", - "-user=jet", "-password=jet", "-schema=dvds", "-path="+genTestDir2) + cmd := exec.Command("jet", "-source=PostgreSQL", "-dbname=jetdb", "-host=localhost", + "-port="+strconv.Itoa(dbconfig.PgPort), + "-user=jet", + "-password=jet", + "-schema=dvds", + "-path="+genTestDir2) + cmd.Stderr = os.Stderr cmd.Stdout = os.Stdout @@ -86,6 +92,59 @@ func TestCmdGenerator(t *testing.T) { require.NoError(t, err) } +func TestGeneratorIgnoreTables(t *testing.T) { + err := os.RemoveAll(genTestDir2) + require.NoError(t, err) + + cmd := exec.Command("jet", + "-source=PostgreSQL", + "-host=localhost", + "-port="+strconv.Itoa(dbconfig.PgPort), + "-user=jet", + "-password=jet", + "-dbname=jetdb", + "-schema=dvds", + "-ignore-tables=actor,ADDRESS,country, Film , cITY,", + "-ignore-views=Actor_info, FILM_LIST ,staff_list", + "-ignore-enums=mpaa_rating", + "-path="+genTestDir2) + + fmt.Println(cmd.Args) + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout + + err = cmd.Run() + require.NoError(t, err) + + // Table SQL Builder files + tableSQLBuilderFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/table") + require.NoError(t, err) + + testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "category.go", + "customer.go", "film_actor.go", "film_category.go", "inventory.go", "language.go", + "payment.go", "rental.go", "staff.go", "store.go") + + // View SQL Builder files + viewSQLBuilderFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/view") + require.NoError(t, err) + + testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "nicer_but_slower_film_list.go", + "sales_by_film_category.go", "customer_list.go", "sales_by_store.go") + + // Enums SQL Builder files + _, err = ioutil.ReadDir("./.gentestdata2/jetdb/dvds/enum") + require.Error(t, err, "open ./.gentestdata2/jetdb/dvds/enum: no such file or directory") + + modelFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/model") + require.NoError(t, err) + + testutils.AssertFileNamesEqual(t, modelFiles, "category.go", + "customer.go", "film_actor.go", "film_category.go", "inventory.go", "language.go", + "payment.go", "rental.go", "staff.go", "store.go", + "nicer_but_slower_film_list.go", "sales_by_film_category.go", + "customer_list.go", "sales_by_store.go") +} + func TestGenerator(t *testing.T) { for i := 0; i < 3; i++ { diff --git a/tests/postgres/main_test.go b/tests/postgres/main_test.go index 4e8aade8..aa05e0fa 100644 --- a/tests/postgres/main_test.go +++ b/tests/postgres/main_test.go @@ -7,6 +7,7 @@ import ( "github.com/go-jet/jet/v2/tests/internal/utils/repo" "math/rand" "os" + "runtime" "testing" "time" @@ -59,11 +60,21 @@ var loggedSQL string var loggedSQLArgs []interface{} var loggedDebugSQL string +var queryInfo postgres.QueryInfo +var callerFile string +var callerLine int +var callerFunction string + func init() { postgres.SetLogger(func(ctx context.Context, statement postgres.PrintableStatement) { loggedSQL, loggedSQLArgs = statement.Sql() loggedDebugSQL = statement.DebugSql() }) + + postgres.SetQueryLogger(func(ctx context.Context, info postgres.QueryInfo) { + queryInfo = info + callerFile, callerLine, callerFunction = info.Caller() + }) } func requireLogged(t *testing.T, statement postgres.Statement) { @@ -73,6 +84,21 @@ func requireLogged(t *testing.T, statement postgres.Statement) { require.Equal(t, loggedDebugSQL, statement.DebugSql()) } +func requireQueryLogged(t *testing.T, statement postgres.Statement, rowsProcessed int64) { + query, args := statement.Sql() + queryLogged, argsLogged := queryInfo.Statement.Sql() + + require.Equal(t, query, queryLogged) + require.Equal(t, args, argsLogged) + require.Equal(t, queryInfo.RowsProcessed, rowsProcessed) + + pc, file, _, _ := runtime.Caller(1) + funcDetails := runtime.FuncForPC(pc) + require.Equal(t, file, callerFile) + require.NotEmpty(t, callerLine) + require.Equal(t, funcDetails.Name(), callerFunction) +} + func skipForPgxDriver(t *testing.T) { if isPgxDriver() { t.SkipNow() @@ -87,3 +113,9 @@ func isPgxDriver() bool { return false } + +func beginTx(t *testing.T) *sql.Tx { + tx, err := db.Begin() + require.NoError(t, err) + return tx +} diff --git a/tests/postgres/raw_statements_test.go b/tests/postgres/raw_statements_test.go index a193258c..4bbf90c5 100644 --- a/tests/postgres/raw_statements_test.go +++ b/tests/postgres/raw_statements_test.go @@ -9,7 +9,7 @@ import ( "github.com/go-jet/jet/v2/internal/testutils" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model" - model2 "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/test_sample/model" + model2 "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/model" . "github.com/go-jet/jet/v2/postgres" ) diff --git a/tests/postgres/scan_test.go b/tests/postgres/scan_test.go index ce3cc46b..61b7becc 100644 --- a/tests/postgres/scan_test.go +++ b/tests/postgres/scan_test.go @@ -62,7 +62,8 @@ func TestScanToValidDestination(t *testing.T) { t.Run("global query function scan", func(t *testing.T) { queryStr, args := oneInventoryQuery.Sql() dest := []struct{}{} - err := qrm.Query(nil, db, queryStr, args, &dest) + rowProcessed, err := qrm.Query(nil, db, queryStr, args, &dest) + require.Equal(t, rowProcessed, int64(1)) require.NoError(t, err) }) @@ -782,6 +783,7 @@ func TestRowsScan(t *testing.T) { require.NoError(t, err) requireLogged(t, stmt) + requireQueryLogged(t, stmt, 0) } func TestScanNumericToFloat(t *testing.T) { diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 96359296..b3d3e63d 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -48,6 +48,63 @@ WHERE actor.actor_id = 2; requireLogged(t, query) } +func TestSelectDistinctOn(t *testing.T) { + + stmt := SELECT( + Rental.StaffID, + Rental.CustomerID, + Rental.RentalID, + ).DISTINCT( + Rental.StaffID, + Rental.CustomerID, + ).FROM( + Rental, + ).WHERE( + Rental.CustomerID.LT(Int(2)), + ).ORDER_BY( + Rental.StaffID.ASC(), + Rental.CustomerID.ASC(), + Rental.RentalID.ASC(), + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT DISTINCT ON (rental.staff_id, rental.customer_id) rental.staff_id AS "rental.staff_id", + rental.customer_id AS "rental.customer_id", + rental.rental_id AS "rental.rental_id" +FROM dvds.rental +WHERE rental.customer_id < 2 +ORDER BY rental.staff_id ASC, rental.customer_id ASC, rental.rental_id ASC; +`) + + var dest []model.Rental + + err := stmt.Query(db, &dest) + require.NoError(t, err) + + testutils.AssertJSON(t, dest, ` +[ + { + "RentalID": 573, + "RentalDate": "0001-01-01T00:00:00Z", + "InventoryID": 0, + "CustomerID": 1, + "ReturnDate": null, + "StaffID": 1, + "LastUpdate": "0001-01-01T00:00:00Z" + }, + { + "RentalID": 76, + "RentalDate": "0001-01-01T00:00:00Z", + "InventoryID": 0, + "CustomerID": 1, + "ReturnDate": null, + "StaffID": 2, + "LastUpdate": "0001-01-01T00:00:00Z" + } +] +`) +} + func TestClassicSelect(t *testing.T) { expectedSQL := ` SELECT payment.payment_id AS "payment.payment_id", @@ -814,10 +871,10 @@ ORDER BY f1.film_id ASC; type F1 model.Film type F2 model.Film - theSameLengthFilms := []struct { + var theSameLengthFilms []struct { F1 F1 F2 F2 - }{} + } err := query.Query(db, &theSameLengthFilms) @@ -858,68 +915,124 @@ LIMIT 1000; Title2 string Length int16 } - films := []thesameLengthFilms{} + var films []thesameLengthFilms err := query.Query(db, &films) require.NoError(t, err) - //spew.Dump(films) - require.Equal(t, len(films), 1000) testutils.AssertDeepEqual(t, films[0], thesameLengthFilms{"Alien Center", "Iron Moon", 46}) } func TestSubQuery(t *testing.T) { - expectedQuery := ` -SELECT actor.actor_id AS "actor.actor_id", - actor.first_name AS "actor.first_name", - actor.last_name AS "actor.last_name", - actor.last_update AS "actor.last_update", - film_actor.actor_id AS "film_actor.actor_id", - film_actor.film_id AS "film_actor.film_id", - film_actor.last_update AS "film_actor.last_update", - "rFilms"."film.film_id" AS "film.film_id", - "rFilms"."film.title" AS "film.title", - "rFilms"."film.rating" AS "film.rating" -FROM dvds.actor - INNER JOIN dvds.film_actor ON (actor.actor_id = film_actor.film_id) - INNER JOIN ( - SELECT film.film_id AS "film.film_id", - film.title AS "film.title", - film.rating AS "film.rating" - FROM dvds.film - WHERE film.rating = 'R' - ) AS "rFilms" ON (film_actor.film_id = "rFilms"."film.film_id"); -` - - rRatingFilms := Film. + rRatingFilms := SELECT( Film.FilmID, Film.Title, Film.Rating, - ). - WHERE(Film.Rating.EQ(enum.MpaaRating.R)). - AsTable("rFilms") + ).FROM( + Film, + ).WHERE( + Film.Rating.EQ(enum.MpaaRating.R), + ).AsTable("rFilms") rFilmID := Film.FilmID.From(rRatingFilms) - query := Actor. - INNER_JOIN(FilmActor, Actor.ActorID.EQ(FilmActor.FilmID)). - INNER_JOIN(rRatingFilms, FilmActor.FilmID.EQ(rFilmID)). + stmt := SELECT( + rRatingFilms.AllColumns(), Actor.AllColumns, FilmActor.AllColumns, - rRatingFilms.AllColumns(), + ).FROM( + rRatingFilms. + INNER_JOIN(FilmActor, FilmActor.FilmID.EQ(rFilmID)). + INNER_JOIN(Actor, FilmActor.ActorID.EQ(Actor.ActorID)), + ).WHERE( + rFilmID.LT(Int(50)), + ).ORDER_BY( + rFilmID.ASC(), + Actor.ActorID.ASC(), ) - testutils.AssertDebugStatementSql(t, query, expectedQuery) + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT "rFilms"."film.film_id" AS "film.film_id", + "rFilms"."film.title" AS "film.title", + "rFilms"."film.rating" AS "film.rating", + actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update", + film_actor.actor_id AS "film_actor.actor_id", + film_actor.film_id AS "film_actor.film_id", + film_actor.last_update AS "film_actor.last_update" +FROM ( + SELECT film.film_id AS "film.film_id", + film.title AS "film.title", + film.rating AS "film.rating" + FROM dvds.film + WHERE film.rating = 'R' + ) AS "rFilms" + INNER JOIN dvds.film_actor ON (film_actor.film_id = "rFilms"."film.film_id") + INNER JOIN dvds.actor ON (film_actor.actor_id = actor.actor_id) +WHERE "rFilms"."film.film_id" < 50 +ORDER BY "rFilms"."film.film_id" ASC, actor.actor_id ASC; +`) - dest := []model.Actor{} + var dest []struct { + model.Film - err := query.Query(db, &dest) + Actors []model.Actor + } + err := stmt.Query(db, &dest) require.NoError(t, err) + require.Len(t, dest, 10) + + testutils.AssertJSON(t, dest[0], ` +{ + "FilmID": 8, + "Title": "Airport Pollock", + "Description": null, + "ReleaseYear": null, + "LanguageID": 0, + "RentalDuration": 0, + "RentalRate": 0, + "Length": null, + "ReplacementCost": 0, + "Rating": "R", + "LastUpdate": "0001-01-01T00:00:00Z", + "SpecialFeatures": null, + "Fulltext": "", + "Actors": [ + { + "ActorID": 55, + "FirstName": "Fay", + "LastName": "Kilmer", + "LastUpdate": "2013-05-26T14:47:57.62Z" + }, + { + "ActorID": 96, + "FirstName": "Gene", + "LastName": "Willis", + "LastUpdate": "2013-05-26T14:47:57.62Z" + }, + { + "ActorID": 110, + "FirstName": "Susan", + "LastName": "Davis", + "LastUpdate": "2013-05-26T14:47:57.62Z" + }, + { + "ActorID": 138, + "FirstName": "Lucille", + "LastName": "Dee", + "LastUpdate": "2013-05-26T14:47:57.62Z" + } + ] +} +`) + } func TestSelectFunctions(t *testing.T) { @@ -1078,6 +1191,66 @@ ORDER BY customer.customer_id, SUM(payment.amount) ASC; testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/customer_payment_sum.json") } +func TestAggregateFunctionDistinct(t *testing.T) { + stmt := SELECT( + Payment.CustomerID, + + COUNT(DISTINCT(Payment.Amount)).AS("distinct.count"), + SUM(DISTINCT(Payment.Amount)).AS("distinct.sum"), + AVG(DISTINCT(Payment.Amount)).AS("distinct.avg"), + MIN(DISTINCT(Payment.PaymentDate)).AS("distinct.first_payment_date"), + MAX(DISTINCT(Payment.PaymentDate)).AS("distinct.last_payment_date"), + ).FROM( + Payment, + ).WHERE( + Payment.CustomerID.EQ(Int(1)), + ).GROUP_BY( + Payment.CustomerID, + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT payment.customer_id AS "payment.customer_id", + COUNT(DISTINCT payment.amount) AS "distinct.count", + SUM(DISTINCT payment.amount) AS "distinct.sum", + AVG(DISTINCT payment.amount) AS "distinct.avg", + MIN(DISTINCT payment.payment_date) AS "distinct.first_payment_date", + MAX(DISTINCT payment.payment_date) AS "distinct.last_payment_date" +FROM dvds.payment +WHERE payment.customer_id = 1 +GROUP BY payment.customer_id; +`) + + type Distinct struct { + model.Payment + + Count int64 + Sum float64 + Avg float64 + FirstPaymentDate time.Time + LastPaymentDate time.Time + } + + var dest Distinct + + err := stmt.Query(db, &dest) + require.NoError(t, err) + testutils.AssertJSON(t, dest, ` +{ + "PaymentID": 0, + "CustomerID": 1, + "StaffID": 0, + "RentalID": 0, + "Amount": 0, + "PaymentDate": "0001-01-01T00:00:00Z", + "Count": 8, + "Sum": 38.92, + "Avg": 4.865, + "FirstPaymentDate": "2007-02-14T23:22:38.996577Z", + "LastPaymentDate": "2007-04-30T01:10:44.996577Z" +} +`) +} + func TestSelectGroupBy2(t *testing.T) { expectedSQL := ` SELECT customer.customer_id AS "customer.customer_id", @@ -1887,7 +2060,7 @@ SELECT customer.customer_id AS "customer.customer_id", customer.last_update AS "customer.last_update", customer.active AS "customer.active" FROM dvds.customer -WHERE ($1 AND (customer.customer_id = $2)) AND (customer.activebool = $3); +WHERE ($1::boolean AND (customer.customer_id = $2)) AND (customer.activebool = $3::boolean); `, true, int64(1), true) dest := []model.Customer{} @@ -2056,3 +2229,353 @@ FROM dvds.address; require.Len(t, dest, 603) }) } + +type FilmWrap struct { + model.Film + + Actors []ActorWrap +} + +type ActorWrap struct { + model.Actor + + Films []FilmWrap +} + +func TestRecursionScanNxM(t *testing.T) { + + stmt := SELECT( + Actor.AllColumns, + Film.AllColumns, + ).FROM( + Actor. + INNER_JOIN(FilmActor, Actor.ActorID.EQ(FilmActor.ActorID)). + INNER_JOIN(Film, Film.FilmID.EQ(FilmActor.FilmID)), + ).ORDER_BY( + Actor.ActorID, + Film.FilmID, + ).LIMIT(100) + + t.Run("film->actors", func(t *testing.T) { + var films []FilmWrap + err := stmt.Query(db, &films) + + require.NoError(t, err) + require.Len(t, films, 95) + testutils.AssertJSON(t, films[:2], ` +[ + { + "FilmID": 1, + "Title": "Academy Dinosaur", + "Description": "A Epic Drama of a Feminist And a Mad Scientist who must Battle a Teacher in The Canadian Rockies", + "ReleaseYear": 2006, + "LanguageID": 1, + "RentalDuration": 6, + "RentalRate": 0.99, + "Length": 86, + "ReplacementCost": 20.99, + "Rating": "PG", + "LastUpdate": "2013-05-26T14:50:58.951Z", + "SpecialFeatures": "{\"Deleted Scenes\",\"Behind the Scenes\"}", + "Fulltext": "'academi':1 'battl':15 'canadian':20 'dinosaur':2 'drama':5 'epic':4 'feminist':8 'mad':11 'must':14 'rocki':21 'scientist':12 'teacher':17", + "Actors": [ + { + "ActorID": 1, + "FirstName": "Penelope", + "LastName": "Guiness", + "LastUpdate": "2013-05-26T14:47:57.62Z", + "Films": null + } + ] + }, + { + "FilmID": 23, + "Title": "Anaconda Confessions", + "Description": "A Lacklusture Display of a Dentist And a Dentist who must Fight a Girl in Australia", + "ReleaseYear": 2006, + "LanguageID": 1, + "RentalDuration": 3, + "RentalRate": 0.99, + "Length": 92, + "ReplacementCost": 9.99, + "Rating": "R", + "LastUpdate": "2013-05-26T14:50:58.951Z", + "SpecialFeatures": "{Trailers,\"Deleted Scenes\"}", + "Fulltext": "'anaconda':1 'australia':18 'confess':2 'dentist':8,11 'display':5 'fight':14 'girl':16 'lacklustur':4 'must':13", + "Actors": [ + { + "ActorID": 1, + "FirstName": "Penelope", + "LastName": "Guiness", + "LastUpdate": "2013-05-26T14:47:57.62Z", + "Films": null + }, + { + "ActorID": 4, + "FirstName": "Jennifer", + "LastName": "Davis", + "LastUpdate": "2013-05-26T14:47:57.62Z", + "Films": null + } + ] + } +] +`) + + }) + + t.Run("actors->films", func(t *testing.T) { + var actors []ActorWrap + + err := stmt.Query(db, &actors) + + require.NoError(t, err) + require.Equal(t, len(actors), 5) + require.Equal(t, actors[0].ActorID, int32(1)) + require.Equal(t, actors[0].FirstName, "Penelope") + require.Len(t, actors[0].Films, 19) + testutils.AssertJSON(t, actors[0].Films[:2], ` +[ + { + "FilmID": 1, + "Title": "Academy Dinosaur", + "Description": "A Epic Drama of a Feminist And a Mad Scientist who must Battle a Teacher in The Canadian Rockies", + "ReleaseYear": 2006, + "LanguageID": 1, + "RentalDuration": 6, + "RentalRate": 0.99, + "Length": 86, + "ReplacementCost": 20.99, + "Rating": "PG", + "LastUpdate": "2013-05-26T14:50:58.951Z", + "SpecialFeatures": "{\"Deleted Scenes\",\"Behind the Scenes\"}", + "Fulltext": "'academi':1 'battl':15 'canadian':20 'dinosaur':2 'drama':5 'epic':4 'feminist':8 'mad':11 'must':14 'rocki':21 'scientist':12 'teacher':17", + "Actors": null + }, + { + "FilmID": 23, + "Title": "Anaconda Confessions", + "Description": "A Lacklusture Display of a Dentist And a Dentist who must Fight a Girl in Australia", + "ReleaseYear": 2006, + "LanguageID": 1, + "RentalDuration": 3, + "RentalRate": 0.99, + "Length": 92, + "ReplacementCost": 9.99, + "Rating": "R", + "LastUpdate": "2013-05-26T14:50:58.951Z", + "SpecialFeatures": "{Trailers,\"Deleted Scenes\"}", + "Fulltext": "'anaconda':1 'australia':18 'confess':2 'dentist':8,11 'display':5 'fight':14 'girl':16 'lacklustur':4 'must':13", + "Actors": null + } +] +`) + }) +} + +type StoreWrap struct { + model.Store + + Staffs []StaffWrap +} + +type StaffWrap struct { + model.Staff + + Store StoreWrap +} + +func TestRecursionScanNx1(t *testing.T) { + stmt := SELECT( + Store.AllColumns, + Staff.AllColumns, + ).FROM( + Store. + INNER_JOIN(Staff, Staff.StoreID.EQ(Store.StoreID)), + ).ORDER_BY( + Store.StoreID, + Staff.StaffID, + ) + + t.Run("store->staff", func(t *testing.T) { + var stores []StoreWrap + + err := stmt.Query(db, &stores) + + require.NoError(t, err) + require.Len(t, stores, 2) + + testutils.AssertJSON(t, stores, ` +[ + { + "StoreID": 1, + "ManagerStaffID": 1, + "AddressID": 1, + "LastUpdate": "2006-02-15T09:57:12Z", + "Staffs": [ + { + "StaffID": 1, + "FirstName": "Mike", + "LastName": "Hillyer", + "AddressID": 3, + "Email": "Mike.Hillyer@sakilastaff.com", + "StoreID": 1, + "Active": true, + "Username": "Mike", + "Password": "8cb2237d0679ca88db6464eac60da96345513964", + "LastUpdate": "2006-05-16T16:13:11.79328Z", + "Picture": "iVBORw0KWgo=", + "Store": { + "StoreID": 0, + "ManagerStaffID": 0, + "AddressID": 0, + "LastUpdate": "0001-01-01T00:00:00Z", + "Staffs": null + } + } + ] + }, + { + "StoreID": 2, + "ManagerStaffID": 2, + "AddressID": 2, + "LastUpdate": "2006-02-15T09:57:12Z", + "Staffs": [ + { + "StaffID": 2, + "FirstName": "Jon", + "LastName": "Stephens", + "AddressID": 4, + "Email": "Jon.Stephens@sakilastaff.com", + "StoreID": 2, + "Active": true, + "Username": "Jon", + "Password": "8cb2237d0679ca88db6464eac60da96345513964", + "LastUpdate": "2006-05-16T16:13:11.79328Z", + "Picture": null, + "Store": { + "StoreID": 0, + "ManagerStaffID": 0, + "AddressID": 0, + "LastUpdate": "0001-01-01T00:00:00Z", + "Staffs": null + } + } + ] + } +] +`) + }) + + t.Run("staff->store", func(t *testing.T) { + + var staffs []StaffWrap + + err := stmt.Query(db, &staffs) + require.NoError(t, err) + + testutils.AssertJSON(t, staffs, ` +[ + { + "StaffID": 1, + "FirstName": "Mike", + "LastName": "Hillyer", + "AddressID": 3, + "Email": "Mike.Hillyer@sakilastaff.com", + "StoreID": 1, + "Active": true, + "Username": "Mike", + "Password": "8cb2237d0679ca88db6464eac60da96345513964", + "LastUpdate": "2006-05-16T16:13:11.79328Z", + "Picture": "iVBORw0KWgo=", + "Store": { + "StoreID": 1, + "ManagerStaffID": 1, + "AddressID": 1, + "LastUpdate": "2006-02-15T09:57:12Z", + "Staffs": null + } + }, + { + "StaffID": 2, + "FirstName": "Jon", + "LastName": "Stephens", + "AddressID": 4, + "Email": "Jon.Stephens@sakilastaff.com", + "StoreID": 2, + "Active": true, + "Username": "Jon", + "Password": "8cb2237d0679ca88db6464eac60da96345513964", + "LastUpdate": "2006-05-16T16:13:11.79328Z", + "Picture": null, + "Store": { + "StoreID": 2, + "ManagerStaffID": 2, + "AddressID": 2, + "LastUpdate": "2006-02-15T09:57:12Z", + "Staffs": null + } + } +] +`) + }) +} + +// In parameterized statements integer literals, like Int(num), are replaced with a placeholders. For some expressions, +// postgres interpreter will not have enough information to deduce the type. If this is the case postgres returns an error. +// Int8, Int16, .... functions will add automatic type cast over placeholder, so type deduction is always possible. +func TestLiteralTypeDeduction(t *testing.T) { + stmt := SELECT( + SUM( + CASE().WHEN(Staff.Active.IS_TRUE()). + THEN(Int8(6)). // if Int8 and Int32 are replaced with Int, + ELSE(Int32(-1)), // execution of this statement will return an error + ).AS("num_passed"), + ).FROM(Staff) + + testutils.AssertStatementSql(t, stmt, ` +SELECT SUM((CASE WHEN staff.active IS TRUE THEN $1::smallint ELSE $2::integer END)) AS "num_passed" +FROM dvds.staff; +`) + + err := stmt.Query(db, &struct{}{}) + require.NoError(t, err) +} + +func GET_FILM_COUNT(lenFrom, lenTo IntegerExpression) IntegerExpression { + return IntExp(Func("dvds.get_film_count", lenFrom, lenTo)) +} + +func TestCustomFunctionCall(t *testing.T) { + stmt := SELECT( + GET_FILM_COUNT(Int(100), Int(120)).AS("film_count"), + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT dvds.get_film_count(100, 120) AS "film_count"; +`) + + var dest struct { + FilmCount int + } + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Equal(t, dest.FilmCount, 165) + + stmt2 := SELECT( + Raw("dvds.get_film_count(#1, #2)", RawArgs{"#1": 100, "#2": 120}).AS("film_count"), + ) + + err = stmt2.Query(db, &dest) + require.NoError(t, err) + require.Equal(t, dest.FilmCount, 165) + + stmt3 := RawStatement(` + SELECT dvds.get_film_count(#1, #2) AS "film_count";`, RawArgs{"#1": 100, "#2": 120}, + ) + + err = stmt3.Query(db, &dest) + require.NoError(t, err) + require.Equal(t, dest.FilmCount, 165) +} diff --git a/tests/postgres/update_test.go b/tests/postgres/update_test.go index 5ec44a11..6cde276b 100644 --- a/tests/postgres/update_test.go +++ b/tests/postgres/update_test.go @@ -4,6 +4,8 @@ import ( "context" "github.com/go-jet/jet/v2/internal/testutils" . "github.com/go-jet/jet/v2/postgres" + model2 "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model" + "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/table" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/model" . "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/table" "github.com/stretchr/testify/require" @@ -264,11 +266,13 @@ func TestUpdateWithModelData(t *testing.T) { expectedSQL := ` UPDATE test_sample.link SET (id, url, name, description) = (201, 'http://www.duckduckgo.com', 'DuckDuckGo', NULL) -WHERE link.id = 201; +WHERE link.id = 201::integer; ` testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201)) - AssertExec(t, stmt, 1) + _, err := stmt.Exec(db) + require.NoError(t, err) + requireQueryLogged(t, stmt, 1) } func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { @@ -291,7 +295,7 @@ func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { var expectedSQL = ` UPDATE test_sample.link SET (description, name, url) = (NULL, 'DuckDuckGo', 'http://www.duckduckgo.com') -WHERE link.id = 201; +WHERE link.id = 201::integer; ` testutils.AssertDebugStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int32(201)) @@ -371,6 +375,77 @@ func TestUpdateExecContext(t *testing.T) { require.Error(t, err, "context deadline exceeded") } +func TestUpdateFrom(t *testing.T) { + tx := beginTx(t) + defer tx.Rollback() + + stmt := table.Rental.UPDATE(). + SET( + table.Rental.RentalDate.SET(Timestamp(2020, 2, 2, 0, 0, 0)), + ). + FROM( + table.Staff. + INNER_JOIN(table.Store, table.Store.StoreID.EQ(table.Staff.StaffID)), + table.Actor, + ). + WHERE( + table.Staff.StaffID.EQ(table.Rental.StaffID). + AND(table.Staff.StaffID.EQ(Int(2))). + AND(table.Rental.RentalID.LT(Int(10))), + ). + RETURNING( + table.Rental.AllColumns.Except(table.Rental.LastUpdate), + table.Store.AllColumns.Except(table.Store.LastUpdate), + ) + + testutils.AssertStatementSql(t, stmt, ` +UPDATE dvds.rental +SET rental_date = $1::timestamp without time zone +FROM dvds.staff + INNER JOIN dvds.store ON (store.store_id = staff.staff_id), + dvds.actor +WHERE ((staff.staff_id = rental.staff_id) AND (staff.staff_id = $2)) AND (rental.rental_id < $3) +RETURNING rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + store.store_id AS "store.store_id", + store.manager_staff_id AS "store.manager_staff_id", + store.address_id AS "store.address_id"; +`) + + var dest []struct { + Rental model2.Rental + Store model2.Store + } + + err := stmt.Query(tx, &dest) + + require.NoError(t, err) + require.Len(t, dest, 3) + testutils.AssertJSON(t, dest[0], ` +{ + "Rental": { + "RentalID": 4, + "RentalDate": "2020-02-02T00:00:00Z", + "InventoryID": 2452, + "CustomerID": 333, + "ReturnDate": "2005-06-03T01:43:41Z", + "StaffID": 2, + "LastUpdate": "0001-01-01T00:00:00Z" + }, + "Store": { + "StoreID": 2, + "ManagerStaffID": 2, + "AddressID": 2, + "LastUpdate": "0001-01-01T00:00:00Z" + } +} +`) +} + func setupLinkTableForUpdateTest(t *testing.T) { cleanUpLinkTable(t) diff --git a/tests/postgres/with_test.go b/tests/postgres/with_test.go index 8a16fd4c..0b47e9aa 100644 --- a/tests/postgres/with_test.go +++ b/tests/postgres/with_test.go @@ -1,6 +1,8 @@ package postgres import ( + "context" + "fmt" "github.com/go-jet/jet/v2/internal/testutils" . "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/northwind/model" @@ -143,7 +145,7 @@ func TestWithStatementDeleteAndInsert(t *testing.T) { require.Equal(t, len(updateDiscontinuedPrice.AllColumns()[0].(ProjectionList)), 10) require.Equal(t, len(logDiscontinuedProducts.AllColumns()), 10) - //fmt.Println(stmt.Sql()) + // fmt.Println(stmt.Sql()) testutils.AssertStatementSql(t, stmt, ` WITH remove_discontinued_orders AS ( @@ -217,5 +219,650 @@ FROM log_discontinued; err = stmt.Query(tx, &resp) require.NoError(t, err) +} + +func TestRecursiveWithStatement_Fibonacci(t *testing.T) { + // CTE columns are listed as part of CTE definition + n1 := IntegerColumn("n1") + fibN1 := IntegerColumn("fibN1") + nextFibN1 := IntegerColumn("nextFibN1") + fibonacci1 := CTE("fibonacci1", n1, fibN1, nextFibN1) + + // CTE columns are columns from non-recursive select + fibonacci2 := CTE("fibonacci2") + n2 := IntegerColumn("n2").From(fibonacci2) + fibN2 := IntegerColumn("fibN2").From(fibonacci2) + nextFibN2 := IntegerColumn("nextFibN2").From(fibonacci2) + + stmt := WITH_RECURSIVE( + fibonacci1.AS( + SELECT( + Int32(1), Int32(0), Int32(1), + ).UNION_ALL( + SELECT( + n1.ADD(Int(1)), nextFibN1, fibN1.ADD(nextFibN1), + ).FROM( + fibonacci1, + ).WHERE( + n1.LT(Int(20)), + ), + ), + ), + fibonacci2.AS( + SELECT( + Int32(1).AS(n2.Name()), + Int32(0).AS(fibN2.Name()), + Int32(1).AS(nextFibN2.Name()), + ).UNION_ALL( + SELECT( + n2.ADD(Int(1)), nextFibN2, fibN2.ADD(nextFibN2), + ).FROM( + fibonacci2, + ).WHERE( + n2.LT(Int(20)), + ), + ), + ), + )( + SELECT( + fibonacci1.AllColumns(), + fibonacci2.AllColumns(), + ).FROM( + fibonacci1.INNER_JOIN(fibonacci2, n1.EQ(n2)), + ).WHERE( + n1.EQ(Int(20)), + ), + ) + + //fmt.Println(stmt.Sql()) + + testutils.AssertStatementSql(t, stmt, ` +WITH RECURSIVE fibonacci1 (n1, "fibN1", "nextFibN1") AS ( + ( + SELECT $1::integer, + $2::integer, + $3::integer + ) + UNION ALL + ( + SELECT fibonacci1.n1 + $4, + fibonacci1."nextFibN1" AS "nextFibN1", + fibonacci1."fibN1" + fibonacci1."nextFibN1" + FROM fibonacci1 + WHERE fibonacci1.n1 < $5 + ) +),fibonacci2 AS ( + ( + SELECT $6::integer AS "n2", + $7::integer AS "fibN2", + $8::integer AS "nextFibN2" + ) + UNION ALL + ( + SELECT fibonacci2.n2 + $9, + fibonacci2."nextFibN2" AS "nextFibN2", + fibonacci2."fibN2" + fibonacci2."nextFibN2" + FROM fibonacci2 + WHERE fibonacci2.n2 < $10 + ) +) +SELECT fibonacci1.n1 AS "n1", + fibonacci1."fibN1" AS "fibN1", + fibonacci1."nextFibN1" AS "nextFibN1", + fibonacci2.n2 AS "n2", + fibonacci2."fibN2" AS "fibN2", + fibonacci2."nextFibN2" AS "nextFibN2" +FROM fibonacci1 + INNER JOIN fibonacci2 ON (fibonacci1.n1 = fibonacci2.n2) +WHERE fibonacci1.n1 = $11; +`) + + var dest struct { + N1 int + FibN1 int + NextFibN1 int + + N2 int + FibN2 int + NextFibN2 int + } + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Equal(t, dest.N1, 20) + require.Equal(t, dest.FibN1, 4181) + require.Equal(t, dest.NextFibN1, 6765) + require.Equal(t, dest.N2, 20) + require.Equal(t, dest.FibN2, 4181) + require.Equal(t, dest.NextFibN2, 6765) +} + +// default column aliases from sub-queries are bubbled up to the main query, +// cte name does not affect default column alias in main query +func TestCTEColumnAliasBubbling(t *testing.T) { + cte1 := CTE("cte1") + cte2 := CTE("cte2") + + stmt := WITH( + cte1.AS( + SELECT( + Territories.AllColumns, + String("custom_column_1").AS("custom_column_1"), + ).FROM( + Territories, + ).ORDER_BY( + Territories.TerritoryID.ASC(), + ), + ), + cte2.AS( + SELECT( + cte1.AllColumns(), + String("custom_column_2").AS("custom_column_2"), + ).FROM( + cte1, + ), + ), + )( + SELECT( + cte2.AllColumns(), // columns will have the same alias as in CTEs + cte2.AllColumns().As("territories2.*"), // all column aliases will be changed to territories2.* + cte2.AllColumns().Except(Territories.RegionID, Territories.TerritoryDescription).As("territories3.*"), + cte2.AllColumns(). + Except( + Territories.MutableColumns, + StringColumn("custom_column_1").From(cte2), // custom_column_1 appears with the same alias in cte2 + StringColumn("custom_column_2").From(cte2), + ).As("territories4.*"), + ).FROM( + cte2, + ), + ) + + // fmt.Println(stmt.Sql()) + + testutils.AssertStatementSql(t, stmt, ` +WITH cte1 AS ( + SELECT territories.territory_id AS "territories.territory_id", + territories.territory_description AS "territories.territory_description", + territories.region_id AS "territories.region_id", + $1 AS "custom_column_1" + FROM northwind.territories + ORDER BY territories.territory_id ASC +),cte2 AS ( + SELECT cte1."territories.territory_id" AS "territories.territory_id", + cte1."territories.territory_description" AS "territories.territory_description", + cte1."territories.region_id" AS "territories.region_id", + cte1.custom_column_1 AS "custom_column_1", + $2 AS "custom_column_2" + FROM cte1 +) +SELECT cte2."territories.territory_id" AS "territories.territory_id", + cte2."territories.territory_description" AS "territories.territory_description", + cte2."territories.region_id" AS "territories.region_id", + cte2.custom_column_1 AS "custom_column_1", + cte2.custom_column_2 AS "custom_column_2", + cte2."territories.territory_id" AS "territories2.territory_id", + cte2."territories.territory_description" AS "territories2.territory_description", + cte2."territories.region_id" AS "territories2.region_id", + cte2.custom_column_1 AS "territories2.custom_column_1", + cte2.custom_column_2 AS "territories2.custom_column_2", + cte2."territories.territory_id" AS "territories3.territory_id", + cte2.custom_column_1 AS "territories3.custom_column_1", + cte2.custom_column_2 AS "territories3.custom_column_2", + cte2."territories.territory_id" AS "territories4.territory_id" +FROM cte2; +`) + + var dest []struct { + // cte2.AllColumns() + Territories1 struct { + model.Territories + + CustomColumn1 string + CustomColumn2 string + } + + // cte2.AllColumns().As("territories2.*") + Territories2 struct { + model.Territories `alias:"territories2.*"` + + CustomColumn1 string + CustomColumn2 string + } `alias:"territories2.*"` + + // cte2.AllColumns().Except(Territories.RegionID, Territories.TerritoryDescription).As("territories3.*") + Territories3 struct { + model.Territories `alias:"territories3.*"` + + CustomColumn1 string + CustomColumn2 string + } `alias:"territories3.*"` + + // cte2.AllColumns() ... .As("territories4.*") + Territories4 struct { + model.Territories `alias:"territories3.*"` + + CustomColumn1 string + CustomColumn2 string + } `alias:"territories4.*"` + } + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Len(t, dest, 53) + require.Equal(t, dest[0].Territories1.Territories, model.Territories{ + TerritoryID: "01581", + TerritoryDescription: "Westboro", + RegionID: 1, + }) + require.Equal(t, dest[0].Territories1.CustomColumn1, "custom_column_1") + require.Equal(t, dest[0].Territories1.CustomColumn2, "custom_column_2") + + // Territories2 + require.Equal(t, testutils.ToJSON(dest[0].Territories1), testutils.ToJSON(dest[0].Territories2)) + + // Territories3 + require.Equal(t, dest[0].Territories3.TerritoryID, dest[0].Territories1.TerritoryID) + require.Equal(t, dest[0].Territories3.RegionID, int16(0)) + require.Equal(t, dest[0].Territories3.TerritoryDescription, "") + require.Equal(t, dest[0].Territories1.CustomColumn1, dest[0].Territories3.CustomColumn1) + require.Equal(t, dest[0].Territories1.CustomColumn2, dest[0].Territories3.CustomColumn2) + + // Territories4 + require.Equal(t, dest[0].Territories3.Territories, dest[0].Territories4.Territories) + require.Equal(t, dest[0].Territories4.CustomColumn1, "") + require.Equal(t, dest[0].Territories4.CustomColumn2, "") +} + +func TestRecursiveWithStatement(t *testing.T) { + + subordinates := CTE("subordinates") + + stmt := WITH_RECURSIVE( + subordinates.AS( + SELECT( + Employees.AllColumns, + ).FROM( + Employees, + ).WHERE( + Employees.EmployeeID.EQ(Int(2)), + ).UNION( + SELECT( + Employees.AllColumns, + ).FROM( + Employees. + INNER_JOIN(subordinates, Employees.EmployeeID.From(subordinates).EQ(Employees.ReportsTo)), + ), + ), + ), + )( + SELECT( + subordinates.AllColumns(), + ).FROM( + subordinates, + ), + ) + + //fmt.Println(stmt.DebugSql()) + + type EmployeeWrap struct { + model.Employees + + Subordinates []*EmployeeWrap + } + + type employeeID = int16 + employeeMap := make(map[employeeID]*EmployeeWrap) + + rows, err := stmt.Rows(context.Background(), db) + require.NoError(t, err) + + var result *EmployeeWrap + + for rows.Next() { + var employeeModel model.Employees + err := rows.Scan(&employeeModel) + require.NoError(t, err) + + newEmployeeWrap := &EmployeeWrap{ + Employees: employeeModel, + } + + employeeMap[employeeModel.EmployeeID] = newEmployeeWrap + if result == nil { // top manager(always first row in the result) + result = newEmployeeWrap + continue + } + + if employee, ok := employeeMap[*employeeModel.ReportsTo]; ok { + employee.Subordinates = append(employee.Subordinates, newEmployeeWrap) + } + } + + require.NoError(t, rows.Err()) + require.NoError(t, rows.Close()) + + testutils.AssertJSON(t, *result, ` +{ + "EmployeeID": 2, + "LastName": "Fuller", + "FirstName": "Andrew", + "Title": "Vice President, Sales", + "TitleOfCourtesy": "Dr.", + "BirthDate": "1952-02-19T00:00:00Z", + "HireDate": "1992-08-14T00:00:00Z", + "Address": "908 W. Capital Way", + "City": "Tacoma", + "Region": "WA", + "PostalCode": "98401", + "Country": "USA", + "HomePhone": "(206) 555-9482", + "Extension": "3457", + "Photo": "", + "Notes": "Andrew received his BTS commercial in 1974 and a Ph.D. in international marketing from the University of Dallas in 1981. He is fluent in French and Italian and reads German. He joined the company as a sales representative, was promoted to sales manager in January 1992 and to vice president of sales in March 1993. Andrew is a member of the Sales Management Roundtable, the Seattle Chamber of Commerce, and the Pacific Rim Importers Association.", + "ReportsTo": null, + "PhotoPath": "http://accweb/emmployees/fuller.bmp", + "Subordinates": [ + { + "EmployeeID": 1, + "LastName": "Davolio", + "FirstName": "Nancy", + "Title": "Sales Representative", + "TitleOfCourtesy": "Ms.", + "BirthDate": "1948-12-08T00:00:00Z", + "HireDate": "1992-05-01T00:00:00Z", + "Address": "507 - 20th Ave. E.\\nApt. 2A", + "City": "Seattle", + "Region": "WA", + "PostalCode": "98122", + "Country": "USA", + "HomePhone": "(206) 555-9857", + "Extension": "5467", + "Photo": "", + "Notes": "Education includes a BA in psychology from Colorado State University in 1970. She also completed The Art of the Cold Call. Nancy is a member of Toastmasters International.", + "ReportsTo": 2, + "PhotoPath": "http://accweb/emmployees/davolio.bmp", + "Subordinates": null + }, + { + "EmployeeID": 3, + "LastName": "Leverling", + "FirstName": "Janet", + "Title": "Sales Representative", + "TitleOfCourtesy": "Ms.", + "BirthDate": "1963-08-30T00:00:00Z", + "HireDate": "1992-04-01T00:00:00Z", + "Address": "722 Moss Bay Blvd.", + "City": "Kirkland", + "Region": "WA", + "PostalCode": "98033", + "Country": "USA", + "HomePhone": "(206) 555-3412", + "Extension": "3355", + "Photo": "", + "Notes": "Janet has a BS degree in chemistry from Boston College (1984). She has also completed a certificate program in food retailing management. Janet was hired as a sales associate in 1991 and promoted to sales representative in February 1992.", + "ReportsTo": 2, + "PhotoPath": "http://accweb/emmployees/leverling.bmp", + "Subordinates": null + }, + { + "EmployeeID": 4, + "LastName": "Peacock", + "FirstName": "Margaret", + "Title": "Sales Representative", + "TitleOfCourtesy": "Mrs.", + "BirthDate": "1937-09-19T00:00:00Z", + "HireDate": "1993-05-03T00:00:00Z", + "Address": "4110 Old Redmond Rd.", + "City": "Redmond", + "Region": "WA", + "PostalCode": "98052", + "Country": "USA", + "HomePhone": "(206) 555-8122", + "Extension": "5176", + "Photo": "", + "Notes": "Margaret holds a BA in English literature from Concordia College (1958) and an MA from the American Institute of Culinary Arts (1966). She was assigned to the London office temporarily from July through November 1992.", + "ReportsTo": 2, + "PhotoPath": "http://accweb/emmployees/peacock.bmp", + "Subordinates": null + }, + { + "EmployeeID": 5, + "LastName": "Buchanan", + "FirstName": "Steven", + "Title": "Sales Manager", + "TitleOfCourtesy": "Mr.", + "BirthDate": "1955-03-04T00:00:00Z", + "HireDate": "1993-10-17T00:00:00Z", + "Address": "14 Garrett Hill", + "City": "London", + "Region": null, + "PostalCode": "SW1 8JR", + "Country": "UK", + "HomePhone": "(71) 555-4848", + "Extension": "3453", + "Photo": "", + "Notes": "Steven Buchanan graduated from St. Andrews University, Scotland, with a BSC degree in 1976. Upon joining the company as a sales representative in 1992, he spent 6 months in an orientation program at the Seattle office and then returned to his permanent post in London. He was promoted to sales manager in March 1993. Mr. Buchanan has completed the courses Successful Telemarketing and International Sales Management. He is fluent in French.", + "ReportsTo": 2, + "PhotoPath": "http://accweb/emmployees/buchanan.bmp", + "Subordinates": [ + { + "EmployeeID": 6, + "LastName": "Suyama", + "FirstName": "Michael", + "Title": "Sales Representative", + "TitleOfCourtesy": "Mr.", + "BirthDate": "1963-07-02T00:00:00Z", + "HireDate": "1993-10-17T00:00:00Z", + "Address": "Coventry House\\nMiner Rd.", + "City": "London", + "Region": null, + "PostalCode": "EC2 7JR", + "Country": "UK", + "HomePhone": "(71) 555-7773", + "Extension": "428", + "Photo": "", + "Notes": "Michael is a graduate of Sussex University (MA, economics, 1983) and the University of California at Los Angeles (MBA, marketing, 1986). He has also taken the courses Multi-Cultural Selling and Time Management for the Sales Professional. He is fluent in Japanese and can read and write French, Portuguese, and Spanish.", + "ReportsTo": 5, + "PhotoPath": "http://accweb/emmployees/davolio.bmp", + "Subordinates": null + }, + { + "EmployeeID": 7, + "LastName": "King", + "FirstName": "Robert", + "Title": "Sales Representative", + "TitleOfCourtesy": "Mr.", + "BirthDate": "1960-05-29T00:00:00Z", + "HireDate": "1994-01-02T00:00:00Z", + "Address": "Edgeham Hollow\\nWinchester Way", + "City": "London", + "Region": null, + "PostalCode": "RG1 9SP", + "Country": "UK", + "HomePhone": "(71) 555-5598", + "Extension": "465", + "Photo": "", + "Notes": "Robert King served in the Peace Corps and traveled extensively before completing his degree in English at the University of Michigan in 1992, the year he joined the company. After completing a course entitled Selling in Europe, he was transferred to the London office in March 1993.", + "ReportsTo": 5, + "PhotoPath": "http://accweb/emmployees/davolio.bmp", + "Subordinates": null + }, + { + "EmployeeID": 9, + "LastName": "Dodsworth", + "FirstName": "Anne", + "Title": "Sales Representative", + "TitleOfCourtesy": "Ms.", + "BirthDate": "1966-01-27T00:00:00Z", + "HireDate": "1994-11-15T00:00:00Z", + "Address": "7 Houndstooth Rd.", + "City": "London", + "Region": null, + "PostalCode": "WG2 7LT", + "Country": "UK", + "HomePhone": "(71) 555-4444", + "Extension": "452", + "Photo": "", + "Notes": "Anne has a BA degree in English from St. Lawrence College. She is fluent in French and German.", + "ReportsTo": 5, + "PhotoPath": "http://accweb/emmployees/davolio.bmp", + "Subordinates": null + } + ] + }, + { + "EmployeeID": 8, + "LastName": "Callahan", + "FirstName": "Laura", + "Title": "Inside Sales Coordinator", + "TitleOfCourtesy": "Ms.", + "BirthDate": "1958-01-09T00:00:00Z", + "HireDate": "1994-03-05T00:00:00Z", + "Address": "4726 - 11th Ave. N.E.", + "City": "Seattle", + "Region": "WA", + "PostalCode": "98105", + "Country": "USA", + "HomePhone": "(206) 555-1189", + "Extension": "2344", + "Photo": "", + "Notes": "Laura received a BA in psychology from the University of Washington. She has also completed a course in business French. She reads and writes French.", + "ReportsTo": 2, + "PhotoPath": "http://accweb/emmployees/davolio.bmp", + "Subordinates": null + } + ] +} +`) +} + +var suppliersWithFax = CTE("suppliers_fax").AS( + SELECT( + Suppliers.SupplierID, + Suppliers.ContactName, + Suppliers.Country, + ).FROM( + Suppliers, + ).WHERE(Suppliers.Fax.IS_NOT_NULL()), +) + +func SuppliersNotFromUSorAUS(suppliersCTE CommonTableExpression) CommonTableExpression { + return CTE("not_from_us_or_aus").AS( + SELECT( + suppliersCTE.AllColumns(), + ).FROM( + suppliersCTE, + ).WHERE( + Suppliers.Country.From(suppliersCTE).NOT_IN(String("US"), String("Australia")), + ), + ) +} + +func TestCTEReuse(t *testing.T) { + suppliersFilteredByCountry := SuppliersNotFromUSorAUS(suppliersWithFax) + supplierContactName := Suppliers.ContactName.From(suppliersFilteredByCountry) + + stmt := WITH( + suppliersWithFax, + suppliersFilteredByCountry, + )( + SELECT( + suppliersFilteredByCountry.AllColumns(), + ).FROM( + suppliersFilteredByCountry, + ).WHERE( + supplierContactName.NOT_EQ(String("John")), + ), + ) + + // fmt.Println(stmt.DebugSql()) + + testutils.AssertDebugStatementSql(t, stmt, ` +WITH suppliers_fax AS ( + SELECT suppliers.supplier_id AS "suppliers.supplier_id", + suppliers.contact_name AS "suppliers.contact_name", + suppliers.country AS "suppliers.country" + FROM northwind.suppliers + WHERE suppliers.fax IS NOT NULL +),not_from_us_or_aus AS ( + SELECT suppliers_fax."suppliers.supplier_id" AS "suppliers.supplier_id", + suppliers_fax."suppliers.contact_name" AS "suppliers.contact_name", + suppliers_fax."suppliers.country" AS "suppliers.country" + FROM suppliers_fax + WHERE suppliers_fax."suppliers.country" NOT IN ('US', 'Australia') +) +SELECT not_from_us_or_aus."suppliers.supplier_id" AS "suppliers.supplier_id", + not_from_us_or_aus."suppliers.contact_name" AS "suppliers.contact_name", + not_from_us_or_aus."suppliers.country" AS "suppliers.country" +FROM not_from_us_or_aus +WHERE not_from_us_or_aus."suppliers.contact_name" != 'John'; +`) + + var dest []model.Suppliers + + err := stmt.Query(db, &dest) + + require.NoError(t, err) + require.Len(t, dest, 11) +} + +func TestWitStatement_CTE_NotMaterialized(t *testing.T) { + orders1 := CTE("orders1") + orders1ID := Orders.OrderID.From(orders1) + orders2 := orders1.ALIAS("orders2") + orders2ID := Orders.OrderID.From(orders2) + + stmt := WITH( + orders1.AS_NOT_MATERIALIZED( + SELECT( + Orders.OrderID, + Orders.EmployeeID, + Orders.ShipCity, + ).FROM( + Orders, + ), + ), + )( + SELECT( + orders1.AllColumns().As("orders1.*"), + orders2.AllColumns().As("orders2.*"), + ).FROM( + orders1. + INNER_JOIN(orders2, orders1ID.EQ(orders2ID)), + ).WHERE( + orders1ID.LT(Int(10320)), + ), + ) + + // fmt.Println(stmt.Sql()) + + testutils.AssertStatementSql(t, stmt, ` +WITH orders1 AS NOT MATERIALIZED ( + SELECT orders.order_id AS "orders.order_id", + orders.employee_id AS "orders.employee_id", + orders.ship_city AS "orders.ship_city" + FROM northwind.orders +) +SELECT orders1."orders.order_id" AS "orders1.order_id", + orders1."orders.employee_id" AS "orders1.employee_id", + orders1."orders.ship_city" AS "orders1.ship_city", + orders2."orders.order_id" AS "orders2.order_id", + orders2."orders.employee_id" AS "orders2.employee_id", + orders2."orders.ship_city" AS "orders2.ship_city" +FROM orders1 + INNER JOIN orders1 AS orders2 ON (orders1."orders.order_id" = orders2."orders.order_id") +WHERE orders1."orders.order_id" < $1; +`) + + var dest []struct { + Orders1 model.Orders `alias:"orders1.*"` + Orders2 model.Orders `alias:"orders2.*"` + } + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Len(t, dest, 72) + fmt.Println(len(dest)) } diff --git a/tests/sqlite/alltypes_test.go b/tests/sqlite/alltypes_test.go index 523d959b..1d90bf22 100644 --- a/tests/sqlite/alltypes_test.go +++ b/tests/sqlite/alltypes_test.go @@ -347,10 +347,13 @@ func TestFloatOperators(t *testing.T) { AllTypes.Numeric.IS_NOT_DISTINCT_FROM(AllTypes.Numeric).AS("not_distinct1"), AllTypes.Decimal.IS_NOT_DISTINCT_FROM(Float(12)).AS("not_distinct2"), AllTypes.Real.IS_NOT_DISTINCT_FROM(Float(12.12)).AS("not_distinct3"), + AllTypes.Numeric.LT(Float(124)).AS("lt1"), AllTypes.Numeric.LT(Float(34.56)).AS("lt2"), AllTypes.Numeric.GT(Float(124)).AS("gt1"), AllTypes.Numeric.GT(Float(34.56)).AS("gt2"), + AllTypes.Numeric.BETWEEN(Float(1.34), AllTypes.Decimal).AS("between"), + AllTypes.Numeric.NOT_BETWEEN(AllTypes.Decimal.MUL(Float(3)), Float(100.12)).AS("not_between"), AllTypes.Decimal.ADD(AllTypes.Decimal).AS("add1"), AllTypes.Decimal.ADD(Float(11.22)).AS("add2"), @@ -395,6 +398,8 @@ SELECT (all_types.numeric = all_types.numeric) AS "eq1", (all_types.numeric < ?) AS "lt2", (all_types.numeric > ?) AS "gt1", (all_types.numeric > ?) AS "gt2", + (all_types.numeric BETWEEN ? AND all_types.decimal) AS "between", + (all_types.numeric NOT BETWEEN (all_types.decimal * ?) AND ?) AS "not_between", (all_types.decimal + all_types.decimal) AS "add1", (all_types.decimal + ?) AS "add2", (all_types.decimal - all_types.decimal_ptr) AS "sub1", @@ -441,40 +446,32 @@ func TestIntegerOperators(t *testing.T) { AllTypes.BigInt.EQ(AllTypes.BigInt).AS("eq1"), AllTypes.BigInt.EQ(Int(12)).AS("eq2"), - AllTypes.BigInt.NOT_EQ(AllTypes.BigIntPtr).AS("neq1"), AllTypes.BigInt.NOT_EQ(Int(12)).AS("neq2"), - AllTypes.BigInt.IS_DISTINCT_FROM(AllTypes.BigInt).AS("distinct1"), AllTypes.BigInt.IS_DISTINCT_FROM(Int(12)).AS("distinct2"), - AllTypes.BigInt.IS_NOT_DISTINCT_FROM(AllTypes.BigInt).AS("not distinct1"), AllTypes.BigInt.IS_NOT_DISTINCT_FROM(Int(12)).AS("not distinct2"), AllTypes.BigInt.LT(AllTypes.BigIntPtr).AS("lt1"), AllTypes.BigInt.LT(Int(65)).AS("lt2"), - AllTypes.BigInt.LT_EQ(AllTypes.BigIntPtr).AS("lte1"), AllTypes.BigInt.LT_EQ(Int(65)).AS("lte2"), - AllTypes.BigInt.GT(AllTypes.BigIntPtr).AS("gt1"), AllTypes.BigInt.GT(Int(65)).AS("gt2"), - AllTypes.BigInt.GT_EQ(AllTypes.BigIntPtr).AS("gte1"), AllTypes.BigInt.GT_EQ(Int(65)).AS("gte2"), + AllTypes.Integer.BETWEEN(Int(11), Int(200)).AS("between"), + AllTypes.Integer.NOT_BETWEEN(Int(66), Int(77)).AS("not_between"), AllTypes.BigInt.ADD(AllTypes.BigInt).AS("add1"), AllTypes.BigInt.ADD(Int(11)).AS("add2"), - AllTypes.BigInt.SUB(AllTypes.BigInt).AS("sub1"), AllTypes.BigInt.SUB(Int(11)).AS("sub2"), - AllTypes.BigInt.MUL(AllTypes.BigInt).AS("mul1"), AllTypes.BigInt.MUL(Int(11)).AS("mul2"), - AllTypes.BigInt.DIV(AllTypes.BigInt).AS("div1"), AllTypes.BigInt.DIV(Int(11)).AS("div2"), - AllTypes.BigInt.MOD(AllTypes.BigInt).AS("mod1"), AllTypes.BigInt.MOD(Int(11)).AS("mod2"), @@ -483,19 +480,15 @@ func TestIntegerOperators(t *testing.T) { AllTypes.SmallInt.BIT_AND(AllTypes.SmallInt).AS("bit_and1"), AllTypes.SmallInt.BIT_AND(AllTypes.SmallInt).AS("bit_and2"), - AllTypes.SmallInt.BIT_OR(AllTypes.SmallInt).AS("bit or 1"), AllTypes.SmallInt.BIT_OR(Int(22)).AS("bit or 2"), - AllTypes.SmallInt.BIT_XOR(AllTypes.SmallInt).AS("bit xor 1"), AllTypes.SmallInt.BIT_XOR(Int(11)).AS("bit xor 2"), - BIT_NOT(Int(-1).MUL(AllTypes.SmallInt)).AS("bit_not_1"), BIT_NOT(Int(-1).MUL(Int(11))).AS("bit_not_2"), AllTypes.SmallInt.BIT_SHIFT_LEFT(AllTypes.SmallInt.DIV(Int(2))).AS("bit shift left 1"), AllTypes.SmallInt.BIT_SHIFT_LEFT(Int(4)).AS("bit shift left 2"), - AllTypes.SmallInt.BIT_SHIFT_RIGHT(AllTypes.SmallInt.DIV(Int(5))).AS("bit shift right 1"), AllTypes.SmallInt.BIT_SHIFT_RIGHT(Int(1)).AS("bit shift right 2"), @@ -522,7 +515,8 @@ func TestIntegerOperators(t *testing.T) { require.Equal(t, *dest[0].BitXor2, int64(5)) require.Equal(t, *dest[0].BitShiftLeft1, int64(1792)) require.Equal(t, *dest[0].BitShiftRight2, int64(7)) - + require.Equal(t, *dest[0].Between, false) + require.Equal(t, *dest[0].NotBetween, true) } func TestStringOperators(t *testing.T) { @@ -540,6 +534,8 @@ func TestStringOperators(t *testing.T) { AllTypes.Text.LT(String("Text")), AllTypes.Text.LT_EQ(AllTypes.VarCharPtr), AllTypes.Text.LT_EQ(String("Text")), + AllTypes.Text.BETWEEN(String("min"), String("max")), + AllTypes.Text.NOT_BETWEEN(AllTypes.VarChar, AllTypes.CharPtr), AllTypes.Text.CONCAT(String("text2")), AllTypes.Text.CONCAT(Int(11)), AllTypes.Text.LIKE(String("abc")), @@ -717,27 +713,23 @@ func TestDateExpressions(t *testing.T) { AllTypes.Date.EQ(AllTypes.Date), AllTypes.Date.EQ(Date(2019, 6, 6)), - AllTypes.DatePtr.NOT_EQ(AllTypes.Date), AllTypes.DatePtr.NOT_EQ(Date(2019, 1, 6)), - AllTypes.Date.IS_DISTINCT_FROM(AllTypes.Date).AS("distinct1"), AllTypes.Date.IS_DISTINCT_FROM(Date(2008, 7, 4)).AS("distinct2"), - AllTypes.Date.IS_NOT_DISTINCT_FROM(AllTypes.Date), AllTypes.Date.IS_NOT_DISTINCT_FROM(Date(2019, 3, 6)), AllTypes.Date.LT(AllTypes.Date), AllTypes.Date.LT(Date(2019, 4, 6)), - AllTypes.Date.LT_EQ(AllTypes.Date), AllTypes.Date.LT_EQ(Date(2019, 5, 5)), - AllTypes.Date.GT(AllTypes.Date), AllTypes.Date.GT(Date(2019, 1, 4)), - AllTypes.Date.GT_EQ(AllTypes.Date), AllTypes.Date.GT_EQ(Date(2019, 2, 3)), + AllTypes.Date.BETWEEN(Date(2000, 2, 2), AllTypes.DatePtr), + AllTypes.Date.NOT_BETWEEN(AllTypes.DatePtr, Date(2000, 2, 2)), //AllTypes.Date.ADD(INTERVAL2(2, HOUR)), //AllTypes.Date.ADD(INTERVAL2(1, DAY, 7, MONTH)), @@ -790,12 +782,12 @@ func TestTimeExpressions(t *testing.T) { AllTypes.TimePtr.NOT_EQ(AllTypes.Time), AllTypes.TimePtr.NOT_EQ(Time(20, 16, 6)), - AllTypes.Time.IS_DISTINCT_FROM(AllTypes.Time), AllTypes.Time.IS_DISTINCT_FROM(Time(19, 26, 6)), - AllTypes.Time.IS_NOT_DISTINCT_FROM(AllTypes.Time), AllTypes.Time.IS_NOT_DISTINCT_FROM(Time(18, 36, 6)), + AllTypes.Time.BETWEEN(Time(11, 0, 30, 100), AllTypes.TimePtr), + AllTypes.Time.NOT_BETWEEN(AllTypes.TimePtr, TIME(time.Now())), AllTypes.Time.LT(AllTypes.Time), AllTypes.Time.LT(Time(17, 46, 6)), @@ -822,6 +814,8 @@ func TestTimeExpressions(t *testing.T) { CURRENT_TIME(), ) + //fmt.Println(query.DebugSql()) + var dest struct { Time1 string Time2 time.Time @@ -855,27 +849,23 @@ func TestDateTimeExpressions(t *testing.T) { AllTypes.DateTime.EQ(AllTypes.DateTime), AllTypes.DateTime.EQ(dateTime), - AllTypes.DateTimePtr.NOT_EQ(AllTypes.DateTime), AllTypes.DateTimePtr.NOT_EQ(DateTime(2019, 6, 6, 10, 2, 46, 100*time.Millisecond)), - AllTypes.DateTime.IS_DISTINCT_FROM(AllTypes.DateTime), AllTypes.DateTime.IS_DISTINCT_FROM(dateTime), - AllTypes.DateTime.IS_NOT_DISTINCT_FROM(AllTypes.DateTime), AllTypes.DateTime.IS_NOT_DISTINCT_FROM(dateTime), AllTypes.DateTime.LT(AllTypes.DateTime), AllTypes.DateTime.LT(dateTime), - AllTypes.DateTime.LT_EQ(AllTypes.DateTime), AllTypes.DateTime.LT_EQ(dateTime), - AllTypes.DateTime.GT(AllTypes.DateTime), AllTypes.DateTime.GT(dateTime), - AllTypes.DateTime.GT_EQ(AllTypes.DateTime), AllTypes.DateTime.GT_EQ(dateTime), + AllTypes.DateTime.BETWEEN(AllTypes.DateTimePtr, AllTypes.TimestampPtr), + AllTypes.DateTime.NOT_BETWEEN(AllTypes.DateTimePtr, AllTypes.TimestampPtr), //AllTypes.DateTime.ADD(INTERVAL("05:10:20.000100", HOUR_MICROSECOND)), //AllTypes.DateTime.ADD(INTERVALe(AllTypes.BigInt, HOUR)), diff --git a/tests/sqlite/generator_test.go b/tests/sqlite/generator_test.go index ac7ab5d3..b8280fdc 100644 --- a/tests/sqlite/generator_test.go +++ b/tests/sqlite/generator_test.go @@ -72,6 +72,39 @@ func TestCmdGenerator(t *testing.T) { require.NoError(t, err) } +func TestCmdGeneratorIgnoreTablesViewsEnums(t *testing.T) { + cmd := exec.Command("jet", + "-source=SQLite", + "-dsn=file://"+testDatabaseFilePath, + "-ignore-tables=actor,Address,CATEGORY , city ,film,rental,store", + "-ignore-views=customer_list, film_list,STAFF_LIst", + "-path="+genDestDir) + + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout + + err := cmd.Run() + require.NoError(t, err) + + tableSQLBuilderFiles, err := ioutil.ReadDir(genDestDir + "/table") + require.NoError(t, err) + testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "country.go", + "customer.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go", + "payment.go", "staff.go") + + viewSQLBuilderFiles, err := ioutil.ReadDir(genDestDir + "/view") + require.NoError(t, err) + testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "sales_by_film_category.go", + "sales_by_store.go") + + modelFiles, err := ioutil.ReadDir(genDestDir + "/model") + require.NoError(t, err) + + testutils.AssertFileNamesEqual(t, modelFiles, "country.go", + "customer.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go", + "payment.go", "staff.go", "sales_by_film_category.go", "sales_by_store.go") +} + func assertGeneratedFiles(t *testing.T) { // Table SQL Builder files tableSQLBuilderFiles, err := ioutil.ReadDir(genDestDir + "/table") diff --git a/tests/sqlite/main_test.go b/tests/sqlite/main_test.go index 710f7ad5..4eb274eb 100644 --- a/tests/sqlite/main_test.go +++ b/tests/sqlite/main_test.go @@ -5,12 +5,14 @@ import ( "database/sql" "fmt" "github.com/go-jet/jet/v2/internal/utils/throw" + "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/sqlite" "github.com/go-jet/jet/v2/tests/dbconfig" "github.com/stretchr/testify/require" "math/rand" "os" "os/exec" + "runtime" "strings" "testing" "time" @@ -63,11 +65,36 @@ var loggedSQL string var loggedSQLArgs []interface{} var loggedDebugSQL string +var queryInfo sqlite.QueryInfo +var callerFile string +var callerLine int +var callerFunction string + func init() { sqlite.SetLogger(func(ctx context.Context, statement sqlite.PrintableStatement) { loggedSQL, loggedSQLArgs = statement.Sql() loggedDebugSQL = statement.DebugSql() }) + + sqlite.SetQueryLogger(func(ctx context.Context, info sqlite.QueryInfo) { + queryInfo = info + callerFile, callerLine, callerFunction = info.Caller() + }) +} + +func requireQueryLogged(t *testing.T, statement postgres.Statement, rowsProcessed int64) { + query, args := statement.Sql() + queryLogged, argsLogged := queryInfo.Statement.Sql() + + require.Equal(t, query, queryLogged) + require.Equal(t, args, argsLogged) + require.Equal(t, queryInfo.RowsProcessed, rowsProcessed) + + pc, file, _, _ := runtime.Caller(1) + funcDetails := runtime.FuncForPC(pc) + require.Equal(t, file, callerFile) + require.NotEmpty(t, callerLine) + require.Equal(t, funcDetails.Name(), callerFunction) } func requireLogged(t *testing.T, statement sqlite.Statement) { diff --git a/tests/sqlite/select_test.go b/tests/sqlite/select_test.go index 95527f1e..657fb3a8 100644 --- a/tests/sqlite/select_test.go +++ b/tests/sqlite/select_test.go @@ -39,6 +39,7 @@ WHERE actor.actor_id = ?; testutils.AssertDeepEqual(t, actor, actor2) requireLogged(t, query) + requireQueryLogged(t, query, 1) } var actor2 = model.Actor{ @@ -63,7 +64,7 @@ ORDER BY actor.actor_id; `) dest := []model.Actor{} - err := query.Query(db, &dest) + err := query.QueryContext(context.Background(), db, &dest) require.NoError(t, err) @@ -73,6 +74,7 @@ ORDER BY actor.actor_id; //testutils.SaveJSONFile(dest, "./testdata/results/sqlite/all_actors.json") testutils.AssertJSONFile(t, dest, "./testdata/results/sqlite/all_actors.json") requireLogged(t, query) + requireQueryLogged(t, query, 200) } func TestSelectGroupByHaving(t *testing.T) { @@ -143,6 +145,67 @@ ORDER BY payment.customer_id, SUM(payment.amount) ASC; requireLogged(t, query) } +func TestAggregateFunctionDistinct(t *testing.T) { + stmt := SELECT( + Payment.CustomerID, + + COUNT(DISTINCT(Payment.Amount)).AS("distinct.count"), + SUM(DISTINCT(Payment.Amount)).AS("distinct.sum"), + AVG(DISTINCT(Payment.Amount)).AS("distinct.avg"), + MIN(DISTINCT(Payment.PaymentDate)).AS("distinct.first_payment_date"), + MAX(DISTINCT(Payment.PaymentDate)).AS("distinct.last_payment_date"), + ).FROM( + Payment, + ).WHERE( + Payment.CustomerID.EQ(Int(1)), + ).GROUP_BY( + Payment.CustomerID, + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT payment.customer_id AS "payment.customer_id", + COUNT(DISTINCT payment.amount) AS "distinct.count", + SUM(DISTINCT payment.amount) AS "distinct.sum", + AVG(DISTINCT payment.amount) AS "distinct.avg", + MIN(DISTINCT payment.payment_date) AS "distinct.first_payment_date", + MAX(DISTINCT payment.payment_date) AS "distinct.last_payment_date" +FROM payment +WHERE payment.customer_id = 1 +GROUP BY payment.customer_id; +`) + + type Distinct struct { + model.Payment + + Count int64 + Sum float64 + Avg float64 + FirstPaymentDate time.Time + LastPaymentDate time.Time + } + + var dest Distinct + + err := stmt.Query(db, &dest) + require.NoError(t, err) + testutils.AssertJSON(t, dest, ` +{ + "PaymentID": 0, + "CustomerID": 1, + "StaffID": 0, + "RentalID": null, + "Amount": 0, + "PaymentDate": "0001-01-01T00:00:00Z", + "LastUpdate": "0001-01-01T00:00:00Z", + "Count": 8, + "Sum": 38.92000000000001, + "Avg": 4.865000000000001, + "FirstPaymentDate": "2005-05-25T11:30:37Z", + "LastPaymentDate": "2005-08-22T20:03:46Z" +} +`) +} + func TestSubQuery(t *testing.T) { rRatingFilms := diff --git a/tests/sqlite/update_test.go b/tests/sqlite/update_test.go index 61135a8f..110c6590 100644 --- a/tests/sqlite/update_test.go +++ b/tests/sqlite/update_test.go @@ -2,6 +2,8 @@ package sqlite import ( "context" + model2 "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/table" "testing" "time" @@ -288,3 +290,77 @@ func TestUpdateContextDeadlineExceeded(t *testing.T) { _, err = updateStmt.ExecContext(ctx, tx) require.Error(t, err, "context deadline exceeded") } + +func TestUpdateFrom(t *testing.T) { + tx := beginDBTx(t) + defer tx.Rollback() + + stmt := table.Rental.UPDATE(). + SET( + table.Rental.RentalDate.SET(DateTime(2020, 2, 2, 0, 0, 0)), + ). + FROM( + table.Staff. + INNER_JOIN(table.Store, table.Store.StoreID.EQ(table.Staff.StaffID)), + ). + WHERE( + table.Staff.StaffID.EQ(table.Rental.StaffID). + AND(table.Staff.StaffID.EQ(Int(2))). + AND(table.Rental.RentalID.LT(Int(10))), + ). + RETURNING( + table.Rental.AllColumns.Except(table.Rental.LastUpdate), + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +UPDATE rental +SET rental_date = DATETIME('2020-02-02 00:00:00') +FROM staff + INNER JOIN store ON (store.store_id = staff.staff_id) +WHERE ((staff.staff_id = rental.staff_id) AND (staff.staff_id = 2)) AND (rental.rental_id < 10) +RETURNING rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id"; +`) + + var dest []model2.Rental + + err := stmt.Query(tx, &dest) + + require.NoError(t, err) + require.Len(t, dest, 3) + testutils.AssertJSON(t, dest, ` +[ + { + "RentalID": 4, + "RentalDate": "2020-02-02T00:00:00Z", + "InventoryID": 2452, + "CustomerID": 333, + "ReturnDate": "2005-06-03T01:43:41Z", + "StaffID": 2, + "LastUpdate": "0001-01-01T00:00:00Z" + }, + { + "RentalID": 7, + "RentalDate": "2020-02-02T00:00:00Z", + "InventoryID": 3995, + "CustomerID": 269, + "ReturnDate": "2005-05-29T20:34:53Z", + "StaffID": 2, + "LastUpdate": "0001-01-01T00:00:00Z" + }, + { + "RentalID": 8, + "RentalDate": "2020-02-02T00:00:00Z", + "InventoryID": 2346, + "CustomerID": 239, + "ReturnDate": "2005-05-27T23:33:46Z", + "StaffID": 2, + "LastUpdate": "0001-01-01T00:00:00Z" + } +] +`) +} diff --git a/tests/sqlite/with_test.go b/tests/sqlite/with_test.go index f2b623ab..92cd331e 100644 --- a/tests/sqlite/with_test.go +++ b/tests/sqlite/with_test.go @@ -232,3 +232,117 @@ FROM payment; err := stmt.Query(db, &dest) require.NoError(t, err) } + +func TestRecursiveWithStatement_Fibonacci(t *testing.T) { + // CTE columns are listed as part of CTE definition + n1 := IntegerColumn("n1") + fibN1 := IntegerColumn("fibN1") + nextFibN1 := IntegerColumn("nextFibN1") + fibonacci1 := CTE("fibonacci1", n1, fibN1, nextFibN1) + + // CTE columns are columns from non-recursive select + fibonacci2 := CTE("fibonacci2") + n2 := IntegerColumn("n2").From(fibonacci2) + fibN2 := IntegerColumn("fibN2").From(fibonacci2) + nextFibN2 := IntegerColumn("nextFibN2").From(fibonacci2) + + stmt := WITH_RECURSIVE( + fibonacci1.AS( + SELECT( + Int32(1), Int32(0), Int32(1), + ).UNION_ALL( + SELECT( + n1.ADD(Int(1)), nextFibN1, fibN1.ADD(nextFibN1), + ).FROM( + fibonacci1, + ).WHERE( + n1.LT(Int(20)), + ), + ), + ), + fibonacci2.AS( + SELECT( + Int32(1).AS(n2.Name()), + Int32(0).AS(fibN2.Name()), + Int32(1).AS(nextFibN2.Name()), + ).UNION_ALL( + SELECT( + n2.ADD(Int(1)), nextFibN2, fibN2.ADD(nextFibN2), + ).FROM( + fibonacci2, + ).WHERE( + n2.LT(Int(20)), + ), + ), + ), + )( + SELECT( + fibonacci1.AllColumns(), + fibonacci2.AllColumns(), + ).FROM( + fibonacci1.INNER_JOIN(fibonacci2, n1.EQ(n2)), + ).WHERE( + n1.EQ(Int(20)), + ), + ) + + //fmt.Println(stmt.Sql()) + + testutils.AssertStatementSql(t, stmt, strings.ReplaceAll(` +WITH RECURSIVE fibonacci1 (n1, ''fibN1'', ''nextFibN1'') AS ( + + SELECT ?, + ?, + ? + + UNION ALL + + SELECT fibonacci1.n1 + ?, + fibonacci1.''nextFibN1'' AS "nextFibN1", + fibonacci1.''fibN1'' + fibonacci1.''nextFibN1'' + FROM fibonacci1 + WHERE fibonacci1.n1 < ? +),fibonacci2 AS ( + + SELECT ? AS "n2", + ? AS "fibN2", + ? AS "nextFibN2" + + UNION ALL + + SELECT fibonacci2.n2 + ?, + fibonacci2.''nextFibN2'' AS "nextFibN2", + fibonacci2.''fibN2'' + fibonacci2.''nextFibN2'' + FROM fibonacci2 + WHERE fibonacci2.n2 < ? +) +SELECT fibonacci1.n1 AS "n1", + fibonacci1.''fibN1'' AS "fibN1", + fibonacci1.''nextFibN1'' AS "nextFibN1", + fibonacci2.n2 AS "n2", + fibonacci2.''fibN2'' AS "fibN2", + fibonacci2.''nextFibN2'' AS "nextFibN2" +FROM fibonacci1 + INNER JOIN fibonacci2 ON (fibonacci1.n1 = fibonacci2.n2) +WHERE fibonacci1.n1 = ?; +`, "''", "`")) + + var dest struct { + N1 int + FibN1 int + NextFibN1 int + + N2 int + FibN2 int + NextFibN2 int + } + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Equal(t, dest.N1, 20) + require.Equal(t, dest.FibN1, 4181) + require.Equal(t, dest.NextFibN1, 6765) + require.Equal(t, dest.N2, 20) + require.Equal(t, dest.FibN2, 4181) + require.Equal(t, dest.NextFibN2, 6765) +} diff --git a/tests/testdata b/tests/testdata index 946bc1e5..895bf576 160000 --- a/tests/testdata +++ b/tests/testdata @@ -1 +1 @@ -Subproject commit 946bc1e5d3e162154eade8b79ff915e4c4986efd +Subproject commit 895bf5760d055c717df77c3b872af276f34d06f1