Skip to content

Commit

Permalink
Fix query using table alias (#73)
Browse files Browse the repository at this point in the history
  • Loading branch information
Fs02 committed Apr 22, 2024
1 parent 177dd60 commit 266b692
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 21 deletions.
35 changes: 25 additions & 10 deletions builder/buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,35 +131,40 @@ func (b Buffer) escape(table, value string) string {
return escapedValue.(string)
}

var escaped_table string
table, alias := extractAlias(table)
var escapedTable string
if table != "" {
if i := strings.Index(strings.ToLower(table), " as "); i > -1 {
return b.escape(table[:i], "") + " AS " + b.Quoter.ID(table[i+4:])
if table != alias {
if value == "" {
return b.escape(table, "") + " AS " + b.Quoter.ID(alias)
} else {
escapedTable = b.Quoter.ID(alias)
}
}
if b.AllowTableSchema && strings.IndexByte(table, '.') >= 0 {
parts := strings.Split(table, ".")
for i, part := range parts {
part = strings.TrimSpace(part)
parts[i] = b.Quoter.ID(part)
}
escaped_table = strings.Join(parts, ".")
escapedTable = strings.Join(parts, ".")
} else {
escaped_table = b.Quoter.ID(strings.ReplaceAll(table, ".", "_"))
escapedTable = b.Quoter.ID(strings.ReplaceAll(table, ".", "_"))
}
}

if value == "" {
escapedValue = escaped_table
escapedValue = escapedTable
} else if value == "*" {
escapedValue = escaped_table + ".*"
escapedValue = escapedTable + ".*"
} else if len(value) > 0 && value[0] == UnescapeCharacter {
escapedValue = value[1:]
} else if _, err := strconv.Atoi(value); err == nil {
escapedValue = value
} else if i := strings.Index(strings.ToLower(value), " as "); i > -1 {
escapedValue = b.escape(table, value[:i]) + " AS " + b.Quoter.ID(value[i+4:])
escapedValue = b.escape(alias, value[:i]) + " AS " + b.Quoter.ID(value[i+4:])
} else if start, end := strings.IndexRune(value, '('), strings.IndexRune(value, ')'); start >= 0 && end >= 0 && end > start {
escapedValue = value[:start+1] + b.escape(table, value[start+1:end]) + value[end:]
escapedValue = value[:start+1] + b.escape(alias, value[start+1:end]) + value[end:]
} else {
parts := strings.Split(value, ".")
for i, part := range parts {
Expand All @@ -171,7 +176,7 @@ func (b Buffer) escape(table, value string) string {
}
result := strings.Join(parts, ".")
if len(parts) == 1 && table != "" {
result = escaped_table + "." + result
result = escapedTable + "." + result
}
escapedValue = result
}
Expand Down Expand Up @@ -228,3 +233,13 @@ func (bf BufferFactory) Create() Buffer {
BoolFalseValue: bf.BoolFalseValue,
}
}

// extract alias in the form of table as alias
// if no alias, table will be returned as alias
func extractAlias(input string) (string, string) {
if i := strings.Index(strings.ToLower(input), " as "); i > -1 {
return input[:i], input[i+4:]
}

return input, input
}
18 changes: 7 additions & 11 deletions builder/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,20 +105,16 @@ func (q Query) WriteJoin(buffer *Buffer, table string, joins []rel.JoinQuery) {

for _, join := range joins {
var (
from = join.From
to = join.To
_, sAlias = extractAlias(table)
jTable, jAlias = extractAlias(join.Table)
from = join.From
to = join.To
)

jtable := join.Table
// If join table has alias use that for filter conditions
if i := strings.Index(strings.ToLower(jtable), " as "); i > -1 {
jtable = jtable[i+4:]
}

// TODO: move this to core functionality, and infer join condition using assoc data.
if join.Arguments == nil && (join.From == "" || join.To == "") {
from = table + "." + strings.TrimSuffix(join.Table, "s") + "_id"
to = jtable + ".id"
from = sAlias + "." + strings.TrimSuffix(jTable, "s") + "_id"
to = jAlias + ".id"
}

buffer.WriteByte(' ')
Expand All @@ -133,7 +129,7 @@ func (q Query) WriteJoin(buffer *Buffer, table string, joins []rel.JoinQuery) {
buffer.WriteEscape(to)
if !join.Filter.None() {
buffer.WriteString(" AND ")
q.Filter.Write(buffer, jtable, join.Filter, q)
q.Filter.Write(buffer, join.Table, join.Filter, q)
}
}

Expand Down
30 changes: 30 additions & 0 deletions builder/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,36 @@ func TestQuery_Build(t *testing.T) {
result: "SELECT `users`.* FROM `users` FOR UPDATE;",
query: rel.From("users").Lock("FOR UPDATE"),
},
{
result: "SELECT `c`.`id`,`c`.`name` FROM `contacts` AS `c`;",
query: rel.Select("c.id", "c.name").From("contacts as c"),
},
{
result: "SELECT MAX(`c`.`id`) FROM `contacts` AS `c`;",
query: rel.Select("MAX(id)").From("contacts as c"),
},
{
result: "SELECT MAX(`c`.`id`) FROM `contacts` AS `c`;",
query: rel.Select("MAX(c.id)").From("contacts as c"),
},
{
result: "SELECT MAX(`c`.`id`) AS `max_id` FROM `contacts` AS `c`;",
query: rel.Select("MAX(id) as max_id").From("contacts as c"),
},
{
result: "SELECT MAX(`c`.`id`) AS `max_id` FROM `contacts` AS `c`;",
query: rel.Select("MAX(c.id) as max_id").From("contacts as c"),
},
{
result: "SELECT `c`.`id`,`c`.`name` FROM `contacts` AS `c` JOIN `users` AS `u` ON `c`.`user_id`=`u`.`id` WHERE `u`.`active`=?;",
args: []any{true},
query: rel.Select("c.id", "c.name").From("contacts as c").Join("users as u").Where(rel.Eq("u.active", true)),
},
{
result: "SELECT `c`.`id`,`c`.`name` FROM `contacts` AS `c` JOIN `users` AS `u` ON `u`.`id`=`c`.`user_id` WHERE `u`.`active`=?;",
args: []any{true},
query: rel.Select("c.id", "c.name").From("contacts as c").JoinOn("users as u", "u.id", "c.user_id").Where(rel.Eq("u.active", true)),
},
}

for _, test := range tests {
Expand Down

0 comments on commit 266b692

Please sign in to comment.