From 266b6920dfac54ae53e713e4643e2f0a980cae90 Mon Sep 17 00:00:00 2001 From: Surya Asriadie Date: Mon, 22 Apr 2024 10:53:24 +0900 Subject: [PATCH] Fix query using table alias (#73) --- builder/buffer.go | 35 +++++++++++++++++++++++++---------- builder/query.go | 18 +++++++----------- builder/query_test.go | 30 ++++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 21 deletions(-) diff --git a/builder/buffer.go b/builder/buffer.go index e807a62..1f76e1a 100644 --- a/builder/buffer.go +++ b/builder/buffer.go @@ -131,10 +131,15 @@ func (b Buffer) escape(table, value string) string { return escapedValue.(string) } - var escaped_table string + table, alias := extractAlias(table) + var escapedTable string if table != "" { - if i := strings.Index(strings.ToLower(table), " as "); i > -1 { - return b.escape(table[:i], "") + " AS " + b.Quoter.ID(table[i+4:]) + if table != alias { + if value == "" { + return b.escape(table, "") + " AS " + b.Quoter.ID(alias) + } else { + escapedTable = b.Quoter.ID(alias) + } } if b.AllowTableSchema && strings.IndexByte(table, '.') >= 0 { parts := strings.Split(table, ".") @@ -142,24 +147,24 @@ func (b Buffer) escape(table, value string) string { part = strings.TrimSpace(part) parts[i] = b.Quoter.ID(part) } - escaped_table = strings.Join(parts, ".") + escapedTable = strings.Join(parts, ".") } else { - escaped_table = b.Quoter.ID(strings.ReplaceAll(table, ".", "_")) + escapedTable = b.Quoter.ID(strings.ReplaceAll(table, ".", "_")) } } if value == "" { - escapedValue = escaped_table + escapedValue = escapedTable } else if value == "*" { - escapedValue = escaped_table + ".*" + escapedValue = escapedTable + ".*" } else if len(value) > 0 && value[0] == UnescapeCharacter { escapedValue = value[1:] } else if _, err := strconv.Atoi(value); err == nil { escapedValue = value } else if i := strings.Index(strings.ToLower(value), " as "); i > -1 { - escapedValue = b.escape(table, value[:i]) + " AS " + b.Quoter.ID(value[i+4:]) + escapedValue = b.escape(alias, value[:i]) + " AS " + b.Quoter.ID(value[i+4:]) } else if start, end := strings.IndexRune(value, '('), strings.IndexRune(value, ')'); start >= 0 && end >= 0 && end > start { - escapedValue = value[:start+1] + b.escape(table, value[start+1:end]) + value[end:] + escapedValue = value[:start+1] + b.escape(alias, value[start+1:end]) + value[end:] } else { parts := strings.Split(value, ".") for i, part := range parts { @@ -171,7 +176,7 @@ func (b Buffer) escape(table, value string) string { } result := strings.Join(parts, ".") if len(parts) == 1 && table != "" { - result = escaped_table + "." + result + result = escapedTable + "." + result } escapedValue = result } @@ -228,3 +233,13 @@ func (bf BufferFactory) Create() Buffer { BoolFalseValue: bf.BoolFalseValue, } } + +// extract alias in the form of table as alias +// if no alias, table will be returned as alias +func extractAlias(input string) (string, string) { + if i := strings.Index(strings.ToLower(input), " as "); i > -1 { + return input[:i], input[i+4:] + } + + return input, input +} diff --git a/builder/query.go b/builder/query.go index a8136c5..001c3f3 100644 --- a/builder/query.go +++ b/builder/query.go @@ -105,20 +105,16 @@ func (q Query) WriteJoin(buffer *Buffer, table string, joins []rel.JoinQuery) { for _, join := range joins { var ( - from = join.From - to = join.To + _, sAlias = extractAlias(table) + jTable, jAlias = extractAlias(join.Table) + from = join.From + to = join.To ) - jtable := join.Table - // If join table has alias use that for filter conditions - if i := strings.Index(strings.ToLower(jtable), " as "); i > -1 { - jtable = jtable[i+4:] - } - // TODO: move this to core functionality, and infer join condition using assoc data. if join.Arguments == nil && (join.From == "" || join.To == "") { - from = table + "." + strings.TrimSuffix(join.Table, "s") + "_id" - to = jtable + ".id" + from = sAlias + "." + strings.TrimSuffix(jTable, "s") + "_id" + to = jAlias + ".id" } buffer.WriteByte(' ') @@ -133,7 +129,7 @@ func (q Query) WriteJoin(buffer *Buffer, table string, joins []rel.JoinQuery) { buffer.WriteEscape(to) if !join.Filter.None() { buffer.WriteString(" AND ") - q.Filter.Write(buffer, jtable, join.Filter, q) + q.Filter.Write(buffer, join.Table, join.Filter, q) } } diff --git a/builder/query_test.go b/builder/query_test.go index 3c33b18..23d1627 100644 --- a/builder/query_test.go +++ b/builder/query_test.go @@ -105,6 +105,36 @@ func TestQuery_Build(t *testing.T) { result: "SELECT `users`.* FROM `users` FOR UPDATE;", query: rel.From("users").Lock("FOR UPDATE"), }, + { + result: "SELECT `c`.`id`,`c`.`name` FROM `contacts` AS `c`;", + query: rel.Select("c.id", "c.name").From("contacts as c"), + }, + { + result: "SELECT MAX(`c`.`id`) FROM `contacts` AS `c`;", + query: rel.Select("MAX(id)").From("contacts as c"), + }, + { + result: "SELECT MAX(`c`.`id`) FROM `contacts` AS `c`;", + query: rel.Select("MAX(c.id)").From("contacts as c"), + }, + { + result: "SELECT MAX(`c`.`id`) AS `max_id` FROM `contacts` AS `c`;", + query: rel.Select("MAX(id) as max_id").From("contacts as c"), + }, + { + result: "SELECT MAX(`c`.`id`) AS `max_id` FROM `contacts` AS `c`;", + query: rel.Select("MAX(c.id) as max_id").From("contacts as c"), + }, + { + result: "SELECT `c`.`id`,`c`.`name` FROM `contacts` AS `c` JOIN `users` AS `u` ON `c`.`user_id`=`u`.`id` WHERE `u`.`active`=?;", + args: []any{true}, + query: rel.Select("c.id", "c.name").From("contacts as c").Join("users as u").Where(rel.Eq("u.active", true)), + }, + { + result: "SELECT `c`.`id`,`c`.`name` FROM `contacts` AS `c` JOIN `users` AS `u` ON `u`.`id`=`c`.`user_id` WHERE `u`.`active`=?;", + args: []any{true}, + query: rel.Select("c.id", "c.name").From("contacts as c").JoinOn("users as u", "u.id", "c.user_id").Where(rel.Eq("u.active", true)), + }, } for _, test := range tests {