Skip to content

Commit

Permalink
fix: select with alias (#122)
Browse files Browse the repository at this point in the history
* fix: select with alias "AS"

* chore: add escape unit test

* chore: add cache load test

* chore: move check alias after cache check

* refactor: follow if pattern to reduce return

* fix: update behaviour

* docs: remove unused comments
  • Loading branch information
h4ckm03d committed Oct 2, 2020
1 parent 39a644b commit 3fb088c
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 3 deletions.
16 changes: 13 additions & 3 deletions adapter/sql/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,16 @@ func TestBuilder_Find_ordinal(t *testing.T) {
nil,
query.Select("id", "name"),
},
{
"SELECT \"id\" AS \"user_id\",\"name\" FROM \"users\";",
nil,
query.Select("id as user_id", "name"),
},
{
"SELECT \"id\" AS \"user_id\",\"name\" FROM \"users\";",
nil,
query.Select("id AS user_id", "name"),
},
{
"SELECT * FROM \"users\" JOIN \"transactions\" ON \"transactions\".\"id\"=\"users\".\"transaction_id\";",
nil,
Expand Down Expand Up @@ -781,15 +791,15 @@ func TestBuilder_Select(t *testing.T) {
fields: []string{"id", "name"},
},
{
result: "SELECT COUNT(*) AS count",
result: "SELECT COUNT(*) AS `count`",
fields: []string{"COUNT(*) AS count"},
},
{
result: "SELECT COUNT(`transactions`.*) AS count",
result: "SELECT COUNT(`transactions`.*) AS `count`",
fields: []string{"COUNT(transactions.*) AS count"},
},
{
result: "SELECT SUM(`transactions`.`total`) AS total",
result: "SELECT SUM(`transactions`.`total`) AS `total`",
fields: []string{"SUM(transactions.total) AS total"},
},
}
Expand Down
2 changes: 2 additions & 0 deletions adapter/sql/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ func Escape(config Config, field string) string {

if len(field) > 0 && field[0] == UnescapeCharacter {
escapedField = field[1:]
} else if i := strings.Index(strings.ToLower(field), " as "); i > -1 {
escapedField = Escape(config, field[:i]) + " AS " + Escape(config, field[i+4:])
} else if start, end := strings.IndexRune(field, '('), strings.IndexRune(field, ')'); start >= 0 && end >= 0 && end > start {
escapedField = field[:start+1] + Escape(config, field[start+1:end]) + field[end:]
} else if strings.HasSuffix(field, "*") {
Expand Down
42 changes: 42 additions & 0 deletions adapter/sql/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,45 @@ func TestToInt64(t *testing.T) {
assert.Equal(t, int64(1), toInt64(uint16(1)))
assert.Equal(t, int64(1), toInt64(uint8(1)))
}

func TestEscape(t *testing.T) {
config := Config{
Placeholder: "?",
EscapeChar: "`",
}

tests := []struct {
field string
result string
}{
{
field: "count(*) as count",
result: "count(*) AS `count`",
},
{
field: "user.address as home_address",
result: "`user`.`address` AS `home_address`",
},
{
field: "^FIELD(`gender`, \"male\") AS order",
result: "FIELD(`gender`, \"male\") AS order",
},
{
field: "*",
result: "*",
},
{
field: "user.*",
result: "`user`.*",
},
}
for _, test := range tests {
t.Run(test.result, func(t *testing.T) {
var (
result = Escape(config, test.field)
)

assert.Equal(t, test.result, result)
})
}
}

0 comments on commit 3fb088c

Please sign in to comment.