Skip to content

Commit

Permalink
Support computed columns in manual joins
Browse files Browse the repository at this point in the history
  • Loading branch information
System-Glitch committed Nov 3, 2022
1 parent c317f32 commit 1402f6e
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 38 deletions.
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,23 @@ filter.Operators["$cont"] = &filter.Operator{
}
```

### Manual joins

Manual joins are supported and won't clash with joins that are automatically generated by the library. That means that if needed, you can write something like described in the following piece of code.

```go
func Index(response *goyave.Response, request *goyave.Request) {
var users []*model.User

db := database.GetConnection().Joins("Relation")

paginator, tx := filter.Scope(db, request, &users)
if response.HandleDatabaseError(tx) {
response.JSON(http.StatusOK, paginator)
}
}
```

## License

`goyave.dev/filter` is MIT Licensed. Copyright (c) 2021 Jérémy LAMBERT (SystemGlitch)
14 changes: 2 additions & 12 deletions filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -444,19 +444,9 @@ func TestFilterScopeWithAlreadyExistingJoin(t *testing.T) {
},
},
},
"SELECT": {
Name: "SELECT",
Expression: clause.Select{
Columns: []clause.Column{
// Base model fields are not selected because in this test we only execute the filter scope, not the select scope
{Raw: true, Name: "`Relation`.`name` `Relation__name`"},
{Raw: true, Name: "`Relation`.`id` `Relation__id`"},
{Raw: true, Name: "`Relation`.`parent_id` `Relation__parent_id`"},
},
},
},
}
assert.Equal(t, expected, db.Statement.Clauses)
assert.Equal(t, expected["FROM"], db.Statement.Clauses["FROM"])
assert.Equal(t, expected["WHERE"], db.Statement.Clauses["WHERE"])
assert.Empty(t, db.Statement.Joins)
}

Expand Down
65 changes: 39 additions & 26 deletions join.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,33 +202,8 @@ func joinExists(stmt *gorm.Statement, join clause.Join) bool {
// This is used to avoid duplicate joins that produce ambiguous column names and to
// support computed columns.
func findStatementJoin(stmt *gorm.Statement, relation *schema.Relationship, join *clause.Join) bool {
for i, j := range stmt.Joins {
for _, j := range stmt.Joins {
if j.Name == join.Table.Alias {
columnStmt := gorm.Statement{
Table: join.Table.Alias,
Schema: relation.FieldSchema,
Selects: j.Selects,
Omits: j.Omits,
}
if len(columnStmt.Selects) == 0 {
columnStmt.Selects = []string{"*"}
}

selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
j.Selects = nil
j.Omits = []string{"*"}
for _, s := range relation.FieldSchema.DBNames {
if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
field := relation.FieldSchema.FieldsByDBName[s]
computed := field.StructField.Tag.Get("computed")
if computed != "" {
stmt.Selects = append(stmt.Selects, fmt.Sprintf("(%s) %s", strings.ReplaceAll(computed, clause.CurrentTable, quoteString(stmt, join.Table.Alias)), quoteString(stmt, join.Table.Alias+"__"+s)))
continue
}
stmt.Selects = append(stmt.Selects, fmt.Sprintf("%s.%s %s", quoteString(stmt, join.Table.Alias), quoteString(stmt, s), quoteString(stmt, join.Table.Alias+"__"+s)))
}
}
stmt.Joins[i] = j
return true
}
}
Expand All @@ -241,3 +216,41 @@ func quoteString(stmt *gorm.Statement, str string) string {
stmt.DB.Dialector.QuoteTo(writer, str)
return writer.String()
}

// processJoinsComputedColumns processes joins' Selects and Omit by adding them to the statement selects.
// Removes this information from the join afterwards to avoid Gorm reprocessing it.
// This is used to support computed columns with manual joins.
func processJoinsComputedColumns(stmt *gorm.Statement, sch *schema.Schema) {
for i, j := range stmt.Joins {
rel, ok := sch.Relationships.Relations[j.Name]
if !ok {
continue
}

columnStmt := gorm.Statement{
Table: j.Name,
Schema: rel.FieldSchema,
Selects: j.Selects,
Omits: j.Omits,
}
if len(columnStmt.Selects) == 0 {
columnStmt.Selects = []string{"*"}
}

selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
j.Selects = nil
j.Omits = []string{"*"}
for _, s := range rel.FieldSchema.DBNames {
if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
field := rel.FieldSchema.FieldsByDBName[s]
computed := field.StructField.Tag.Get("computed")
if computed != "" {
stmt.Selects = append(stmt.Selects, fmt.Sprintf("(%s) %s", strings.ReplaceAll(computed, clause.CurrentTable, quoteString(stmt, j.Name)), quoteString(stmt, j.Name+"__"+s)))
continue
}
stmt.Selects = append(stmt.Selects, fmt.Sprintf("%s.%s %s", quoteString(stmt, j.Name), quoteString(stmt, s), quoteString(stmt, j.Name+"__"+s)))
}
}
stmt.Joins[i] = j
}
}
7 changes: 7 additions & 0 deletions settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,13 @@ func (s *Settings) scopeCommon(db *gorm.DB, request *goyave.Request, dest interf
}
}

db.Scopes(func(tx *gorm.DB) *gorm.DB {
// Convert all joins' selects to support computed columns
// for manual Joins.
processJoinsComputedColumns(tx.Statement, schema)
return tx
})

return db, schema, hasJoins
}

Expand Down
75 changes: 75 additions & 0 deletions settings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1823,3 +1823,78 @@ func TestSettingsSelectWithExistingJoinAndComputedOmit(t *testing.T) {
assert.Equal(t, expected, db.Statement.Clauses)
assert.Empty(t, db.Statement.Joins)
}

func TestSettingsSelectWithExistingJoinAndComputedWithoutFiltering(t *testing.T) {
request := &goyave.Request{
Data: map[string]interface{}{
"per_page": 15,
},
Lang: "en-US",
}
db := openDryRunDB(t)

// Gorm will automatically select all the fields of the relation.
// We expect the computed fields to work properly.
db = db.Joins("Relation", db.Session(&gorm.Session{NewDB: true}).Where("Relation.id > ?", 0))

results := []*TestScopeModelWithComputed{}
paginator, db := Scope(db, request, results)

assert.NotNil(t, paginator)

expected := map[string]clause.Clause{
"FROM": {
Name: "FROM",
Expression: clause.From{
Joins: []clause.Join{
{
Type: clause.LeftJoin,
Table: clause.Table{
Name: "test_scope_relation_with_computeds",
Alias: "Relation",
},
ON: clause.Where{
Exprs: []clause.Expression{
clause.Eq{
Column: clause.Column{
Table: clause.CurrentTable,
Name: "relation_id",
},
Value: clause.Column{
Table: "Relation",
Name: "id",
},
},
clause.Expr{SQL: "Relation.id > ?", Vars: []interface{}{0}},
},
},
},
},
},
},
"LIMIT": {
Expression: clause.Limit{
Limit: &fifteen,
Offset: 0,
},
},
"SELECT": {
Name: "SELECT",
Expression: clause.Select{
Columns: []clause.Column{
{Raw: true, Name: "`Relation`.`a` `Relation__a`"},
{Raw: true, Name: "`Relation`.`b` `Relation__b`"},
{Raw: true, Name: "(UPPER(`Relation`.b)) `Relation__c`"},
{Raw: true, Name: "`Relation`.`id` `Relation__id`"},
{Raw: true, Name: "`test_scope_model_with_computeds`.`name`"},
{Raw: true, Name: "`test_scope_model_with_computeds`.`email`"},
{Raw: true, Name: "(UPPER(`test_scope_model_with_computeds`.name)) `computed`"},
{Raw: true, Name: "`test_scope_model_with_computeds`.`id`"},
{Raw: true, Name: "`test_scope_model_with_computeds`.`relation_id`"},
},
},
},
}
assert.Equal(t, expected, db.Statement.Clauses)
assert.Empty(t, db.Statement.Joins)
}

0 comments on commit 1402f6e

Please sign in to comment.