Skip to content

Commit

Permalink
Add CaseInsensitiveSort flag
Browse files Browse the repository at this point in the history
  • Loading branch information
System-Glitch committed Mar 22, 2024
1 parent 503ea69 commit 0655ce0
Show file tree
Hide file tree
Showing 5 changed files with 265 additions and 16 deletions.
5 changes: 4 additions & 1 deletion README.md
Expand Up @@ -66,7 +66,10 @@ settings := &filter.Settings[*model.User]{
// If not nil and not empty, and if the request is not providing any
// sort, the request will be sorted according to the `*Sort` defined in this slice.
// If `DisableSort` is enabled, this has no effect.
DefaultSort: []*Sort{{Field: "name", Order: SortDescending}}
DefaultSort: []*Sort{{Field: "name", Order: SortDescending}},

// If true, the sort will wrap the value in `LOWER()` if it's a string, resulting in `ORDER BY LOWER(column)`.
CaseInsensitiveSort: true,

FieldsSearch: []string{"a", "b"}, // Optional, the fields used for the search feature
SearchOperator: filter.Operators["$eq"], // Optional, operator used for the search feature, defaults to "$cont"
Expand Down
6 changes: 5 additions & 1 deletion settings.go
Expand Up @@ -96,6 +96,10 @@ type Settings[T any] struct {
DisableJoin bool
// DisableSearch ignore the "search" query if true.
DisableSearch bool

// CaseInsensitiveSort if true, the sort will wrap the value in `LOWER()` if it's a string,
// resulting in `ORDER BY LOWER(column)`.
CaseInsensitiveSort bool
}

// Blacklist definition of blacklisted relations and fields.
Expand Down Expand Up @@ -248,7 +252,7 @@ func (s *Settings[T]) scopeSort(db *gorm.DB, request *Request, schema *schema.Sc

if !s.DisableSort {
for _, sort := range sorts {
if scope := sort.Scope(s.Blacklist, schema); scope != nil {
if scope := sort.Scope(s.Blacklist, schema, s.CaseInsensitiveSort); scope != nil {
db = db.Scopes(scope)
}
}
Expand Down
55 changes: 55 additions & 0 deletions settings_test.go
Expand Up @@ -16,6 +16,7 @@ import (
)

var fifteen = 15
var ten = 10

type TestScopeRelation struct {
A string
Expand Down Expand Up @@ -2080,3 +2081,57 @@ func TestNewRequest(t *testing.T) {
}

}

func TestScopeWithCaseInsensitiveSort(t *testing.T) {
request := &Request{
Sort: typeutil.NewUndefined([]*Sort{{Field: "name", Order: SortAscending}}),
}
db := openDryRunDB(t)

results := []*TestScopeModel{}
settings := &Settings[*TestScopeModel]{
CaseInsensitiveSort: true,
}
paginator, err := settings.Scope(db, request, &results)
assert.NotNil(t, paginator)
assert.NoError(t, err)

expected := map[string]clause.Clause{
"ORDER BY": {
Name: "ORDER BY",
Expression: clause.OrderBy{
Columns: []clause.OrderByColumn{
{
Column: clause.Column{
Raw: true,
Name: "LOWER(`test_scope_models`.`name`)",
},
},
},
},
},
"FROM": {
Name: "FROM",
Expression: clause.From{},
},
"SELECT": {
Name: "SELECT",
Expression: clause.Select{
Columns: []clause.Column{
{Raw: true, Name: "`test_scope_models`.`name`"},
{Raw: true, Name: "`test_scope_models`.`email`"},
{Raw: true, Name: "(UPPER(`test_scope_models`.name)) `computed`"},
{Raw: true, Name: "`test_scope_models`.`id`"},
{Raw: true, Name: "`test_scope_models`.`relation_id`"},
},
},
},
"LIMIT": {
Expression: clause.Limit{
Limit: &ten,
Offset: 0,
},
},
}
assert.Equal(t, expected, paginator.DB.Statement.Clauses)
}
10 changes: 8 additions & 2 deletions sort.go
Expand Up @@ -27,7 +27,8 @@ const (
)

// Scope returns the GORM scope to use in order to apply sorting.
func (s *Sort) Scope(blacklist Blacklist, schema *schema.Schema) func(*gorm.DB) *gorm.DB {
// If caseInsensitive is true, the column is wrapped in a `LOWER()` function.
func (s *Sort) Scope(blacklist Blacklist, schema *schema.Schema, caseInsensitive bool) func(*gorm.DB) *gorm.DB {
field, sch, joinName := getField(s.Field, schema, &blacklist)
if field == nil {
return nil
Expand All @@ -51,9 +52,14 @@ func (s *Sort) Scope(blacklist Blacklist, schema *schema.Schema) func(*gorm.DB)
Raw: true,
Name: fmt.Sprintf("(%s)", strings.ReplaceAll(computed, clause.CurrentTable, tx.Statement.Quote(table))),
}
} else if caseInsensitive && getDataType(field) == DataTypeText {
column = clause.Column{
Raw: true,
Name: fmt.Sprintf("LOWER(%s.%s)", tx.Statement.Quote(table), tx.Statement.Quote(field.DBName)),
}
} else {
column = clause.Column{
Table: tableFromJoinName(sch.Table, joinName),
Table: table,
Name: field.DBName,
}
}
Expand Down

0 comments on commit 0655ce0

Please sign in to comment.