Skip to content

Commit

Permalink
Add type-safety in filters
Browse files Browse the repository at this point in the history
  • Loading branch information
System-Glitch committed Mar 30, 2023
1 parent 1402f6e commit 3396308
Show file tree
Hide file tree
Showing 12 changed files with 307 additions and 64 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
go: [1.17, 1.18, 1.19]
go: ["1.17", "1.18", "1.19", "1.20"]
steps:
- uses: actions/checkout@v3
- uses: actions/setup-go@v3
Expand All @@ -23,7 +23,7 @@ jobs:
- name: Run tests
run: |
go test -v -race -coverprofile=coverage.txt -covermode=atomic -coverpkg=./... ./...
- if: ${{ matrix.go == 1.19 }}
- if: ${{ matrix.go == 1.20 }}
uses: shogo82148/actions-goveralls@v1
with:
path-to-profile: coverage.txt
Expand All @@ -36,5 +36,5 @@ jobs:
- name: Run lint
uses: golangci/golangci-lint-action@v3
with:
version: v1.50
version: v1.52
args: --timeout 5m
43 changes: 40 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ type MyModelWithStatus struct{

- Inputs are escaped to prevent SQL injections.
- Fields are pre-processed and clients cannot request fields that don't exist. This prevents database errors. If a non-existing field is required, it is simply ignored. The same goes for sorts and joins. It is not possible to request a relation that doesn't exist.
- Type-safety: in the same field pre-processing, the broad type of the field is checked against the database type (based on the model definition). This prevents database errors if the input cannot be converted to the column's type.
- Foreign keys are always selected in joins to ensure associations can be assigned to parent model.
- **Be careful** with bidirectional relations (for example an article is written by a user, and a user can have many articles). If you enabled both your models to preload these relations, the client can request them with an infinite depth (`Articles.User.Articles.User...`). To prevent this, it is advised to use **the relation blacklist** or **IsFinal** on the deepest requestable models. See the settings section for more details.

Expand All @@ -251,6 +252,27 @@ type MyModelWithStatus struct{
- Don't use `gorm.Model` and add the necessary fields manually. You get better control over json struct tags this way.
- Use pointers for nullable relations and nullable fields that implement `sql.Scanner` (such as `null.Time`).

### Filter type

For non-primitive types (such as `*null.Time`), you should always use the `filter_type` struct tag. This struct tag enforces the field's recognized broad type for the type-safety conversion.

Available broad types are:
- `text`
- `bool`
- `int`
- `uint`
- `float`
- `time`

**Example**
```go
type MyModel struct{
ID uint
// ...
StartDate null.Time `filter_type:"time"`
}
```

### Static conditions

If you want to add static conditions (not automatically defined by the library), it is advised to group them like so:
Expand Down Expand Up @@ -279,10 +301,25 @@ import (
// ...

filter.Operators["$cont"] = &filter.Operator{
Function: func(tx *gorm.DB, filter *filter.Filter, column string, dataType schema.DataType) *gorm.DB {
Function: func(tx *gorm.DB, f *filter.Filter, column string, dataType filter.DataType) *gorm.DB {
if dataType != schema.String {
return tx
}
query := column + " LIKE ?"
value := "%" + sqlutil.EscapeLike(filter.Args[0]) + "%"
return filter.Where(tx, query, value)
value := "%" + sqlutil.EscapeLike(f.Args[0]) + "%"
return f.Where(tx, query, value)
},
RequiredArguments: 1,
}

filter.Operators["$eq"] = &filter.Operator{
Function: func(tx *gorm.DB, f *filter.Filter, column string, dataType filter.DataType) *gorm.DB {
arg, ok := filter.ConvertToSafeType(f.Args[0], dataType)
if !ok {
return tx
}
query := fmt.Sprintf("%s = ?", column, op)
return f.Where(tx, query, arg)
},
RequiredArguments: 1,
}
Expand Down
7 changes: 6 additions & 1 deletion filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,12 @@ func (f *Filter) Scope(settings *Settings, sch *schema.Schema) (func(*gorm.DB) *
} else {
fieldExpr = table + "." + tx.Statement.Quote(field.DBName)
}
return f.Operator.Function(tx, f, fieldExpr, field.DataType)

dataType := getDataType(field)
if dataType == DataTypeUnsupported {
return tx
}
return f.Operator.Function(tx, f, fieldExpr, dataType)
}

return joinScope, conditionScope
Expand Down
4 changes: 2 additions & 2 deletions filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func TestFilterScope(t *testing.T) {
schema := &schema.Schema{
DBNames: []string{"name"},
FieldsByDBName: map[string]*schema.Field{
"name": {Name: "Name", DBName: "name"},
"name": {Name: "Name", DBName: "name", DataType: schema.String},
},
Table: "test_scope_models",
}
Expand Down Expand Up @@ -346,7 +346,7 @@ func TestFilterScopeWithJoinDontDuplicate(t *testing.T) {
Expression: clause.Where{
Exprs: []clause.Expression{
clause.Expr{SQL: "`Relation`.`name` = ?", Vars: []interface{}{"val1"}},
clause.Expr{SQL: "`Relation`.`id` > ?", Vars: []interface{}{"0"}},
clause.Expr{SQL: "`Relation`.`id` > ?", Vars: []interface{}{uint64(0)}},
},
},
},
Expand Down
4 changes: 2 additions & 2 deletions join.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func join(tx *gorm.DB, joinName string, sch *schema.Schema) *gorm.DB {
Table: clause.Table{Name: sch.Table, Alias: relation.Name},
ON: clause.Where{Exprs: exprs},
}
if !joinExists(tx.Statement, j) && !findStatementJoin(tx.Statement, relation, &j) {
if !joinExists(tx.Statement, j) && !findStatementJoin(tx.Statement, &j) {
joins = append(joins, j)
}
}
Expand Down Expand Up @@ -201,7 +201,7 @@ func joinExists(stmt *gorm.Statement, join clause.Join) bool {
// Removes this information from the join afterwards to avoid Gorm reprocessing it.
// This is used to avoid duplicate joins that produce ambiguous column names and to
// support computed columns.
func findStatementJoin(stmt *gorm.Statement, relation *schema.Relationship, join *clause.Join) bool {
func findStatementJoin(stmt *gorm.Statement, join *clause.Join) bool {
for _, j := range stmt.Joins {
if j.Name == join.Table.Alias {
return true
Expand Down
65 changes: 45 additions & 20 deletions operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,20 @@ import (
"fmt"

"gorm.io/gorm"
"gorm.io/gorm/schema"
"goyave.dev/goyave/v4/util/sqlutil"
)

// Operator used by filters to build the SQL query.
// The operator function modifies the GORM statement (most of the time by adding
// a WHERE condition) then returns the modified statement.
//
// Operators may need arguments (e.g. "$eq", equals needs a value to compare the field to);
// RequiredArguments define the minimum number of arguments a client must send in order to
// use this operator in a filter. RequiredArguments is checked during Filter parsing.
//
// Operators may return the given tx without change if they don't support the given dataType.
type Operator struct {
Function func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB
Function func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB
RequiredArguments uint8
}

Expand All @@ -30,31 +31,43 @@ var (
"$gte": {Function: basicComparison(">="), RequiredArguments: 1},
"$lte": {Function: basicComparison("<="), RequiredArguments: 1},
"$starts": {
Function: func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB {
Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB {
if dataType != DataTypeText {
return tx
}
query := column + " LIKE ?"
value := sqlutil.EscapeLike(filter.Args[0]) + "%"
return filter.Where(tx, query, value)
},
RequiredArguments: 1,
},
"$ends": {
Function: func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB {
Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB {
if dataType != DataTypeText {
return tx
}
query := column + " LIKE ?"
value := "%" + sqlutil.EscapeLike(filter.Args[0])
return filter.Where(tx, query, value)
},
RequiredArguments: 1,
},
"$cont": {
Function: func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB {
Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB {
if dataType != DataTypeText {
return tx
}
query := column + " LIKE ?"
value := "%" + sqlutil.EscapeLike(filter.Args[0]) + "%"
return filter.Where(tx, query, value)
},
RequiredArguments: 1,
},
"$excl": {
Function: func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB {
Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB {
if dataType != DataTypeText {
return tx
}
query := column + " NOT LIKE ?"
value := "%" + sqlutil.EscapeLike(filter.Args[0]) + "%"
return filter.Where(tx, query, value)
Expand All @@ -64,55 +77,67 @@ var (
"$in": {Function: multiComparison("IN"), RequiredArguments: 1},
"$notin": {Function: multiComparison("NOT IN"), RequiredArguments: 1},
"$isnull": {
Function: func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB {
Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB {
return filter.Where(tx, column+" IS NULL")
},
RequiredArguments: 0,
},
"$istrue": {
Function: func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB {
if dataType != schema.Bool {
Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB {
if dataType != DataTypeBool {
return tx
}
return filter.Where(tx, column+" IS TRUE")
},
RequiredArguments: 0,
},
"$isfalse": {
Function: func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB {
if dataType != schema.Bool {
Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB {
if dataType != DataTypeBool {
return tx
}
return filter.Where(tx, column+" IS FALSE")
},
RequiredArguments: 0,
},
"$notnull": {
Function: func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB {
Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB {
return filter.Where(tx, column+" IS NOT NULL")
},
RequiredArguments: 0,
},
"$between": {
Function: func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB {
Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB {
args, ok := ConvertArgsToSafeType(filter.Args[:2], dataType)
if !ok {
return tx
}
query := column + " BETWEEN ? AND ?"
return filter.Where(tx, query, filter.Args[0], filter.Args[1])
return filter.Where(tx, query, args...)
},
RequiredArguments: 2,
},
}
)

func basicComparison(op string) func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB {
return func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB {
func basicComparison(op string) func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB {
return func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB {
arg, ok := ConvertToSafeType(filter.Args[0], dataType)
if !ok {
return tx
}
query := fmt.Sprintf("%s %s ?", column, op)
return filter.Where(tx, query, filter.Args[0])
return filter.Where(tx, query, arg)
}
}

func multiComparison(op string) func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB {
return func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB {
func multiComparison(op string) func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB {
return func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB {
args, ok := ConvertArgsToSafeType(filter.Args, dataType)
if !ok {
return tx
}
query := fmt.Sprintf("%s %s ?", column, op)
return filter.Where(tx, query, filter.Args)
return filter.Where(tx, query, args)
}
}
Loading

0 comments on commit 3396308

Please sign in to comment.