Skip to content

Commit

Permalink
feat: parse scim query filters
Browse files Browse the repository at this point in the history
- pass query to filter builder
- use switches to map supported fields
  • Loading branch information
BruceMacD committed Oct 13, 2022
1 parent 06c11c9 commit 12a12f6
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 74 deletions.
5 changes: 2 additions & 3 deletions internal/server/data/provideruser.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,11 @@ func ListProviderUsers(tx ReadTxn, providerID uid.ID, p *SCIMParameters) ([]mode
query.B("INNER JOIN providers ON provider_users.provider_id = providers.id AND providers.organization_id = ?", tx.OrganizationID())
query.B("WHERE provider_id = ?", providerID)
if p != nil && p.Filter != nil {
filter, err := filterSQL(*p.Filter)
query.B("AND (")
err := filterSQL(*p.Filter, query)
if err != nil {
return nil, fmt.Errorf("apply filter: %w", err)
}
query.B("AND (")
query.B(filter)
query.B(")")
}
query.B("ORDER BY email ASC")
Expand Down
91 changes: 50 additions & 41 deletions internal/server/data/scim.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,66 +4,75 @@ import (
"fmt"
"strings"

"github.com/infrahq/infra/internal/server/data/querybuilder"
"github.com/scim2/filter-parser/v2"
)

// supportedColumns maps SCIM input filters to provider user database columns
var supportedColumns = map[string]string{
"id": "identity_id",
"userName": "email",
"email": "email",
"name.givenName": "givenName",
"name.familyName": "familyName",
"active": "active",
}

func filterSQL(e filter.Expression) (string, error) {
func filterSQL(e filter.Expression, query *querybuilder.Query) error {
switch v := e.(type) {
case *filter.LogicalExpression:
l, err := filterSQL(v.Left)
err := filterSQL(v.Left, query)
if err != nil {
return "", fmt.Errorf("left: %w", err)
return fmt.Errorf("left: %w", err)
}
r, err := filterSQL(v.Right)
query.B(strings.ToUpper(string(v.Operator)))
err = filterSQL(v.Right, query)
if err != nil {
return "", fmt.Errorf("right: %w", err)
return fmt.Errorf("right: %w", err)
}
return fmt.Sprintf("%s %s %s", l, strings.ToUpper(string(v.Operator)), r), nil
return nil
case *filter.AttributeExpression:
comparison, err := sqlComparator(v.Operator, v.CompareValue)
err := sqlColumn(v.AttributePath, query)
if err != nil {
return "", fmt.Errorf("attribute comparator: %w", err)
return fmt.Errorf("attribute path: %w", err)
}
column, err := sqlColumn(v.AttributePath)
err = sqlComparator(v.Operator, v.CompareValue, query)
if err != nil {
return "", fmt.Errorf("attribute path: %w", err)
return fmt.Errorf("attribute comparator: %w", err)
}
return fmt.Sprintf("%s %s", column, comparison), nil
return nil
}
return "", fmt.Errorf("unable to parse filter, unrecognized format")
return fmt.Errorf("unable to parse filter, unrecognized format")
}

func sqlColumn(a filter.AttributePath) (string, error) {
if supportedColumns[a.String()] == "" {
return "", fmt.Errorf("unsupported filter attribute: %q", a)
// sqlColumns maps SCIM input filters to provider user database columns
func sqlColumn(a filter.AttributePath, query *querybuilder.Query) error {
switch a.String() {
case "id":
query.B("identity_id")
case "userName":
query.B("email")
case "email":
query.B("email")
case "name.givenName":
query.B("givenName")
case "name.familyName":
query.B("familyName")
case "active":
query.B("active")
default:
return fmt.Errorf("unsupported filter attribute: %q", a)
}
return supportedColumns[a.String()], nil
return nil
}

func sqlComparator(c filter.CompareOperator, compare any) (string, error) {
switch {
case c == filter.PR:
return "IS NOT NULL", nil
case c == filter.EQ:
return fmt.Sprintf("= '%s'", compare), nil
case c == filter.NE:
return fmt.Sprintf("!= '%s'", compare), nil
case c == filter.SW:
return fmt.Sprintf("LIKE '%s%%'", compare), nil
case c == filter.CO:
return fmt.Sprintf("LIKE '%%%s%%'", compare), nil
case c == filter.EW:
return fmt.Sprintf("LIKE '%%%s'", compare), nil
func sqlComparator(c filter.CompareOperator, compare any, query *querybuilder.Query) error {
switch c {
case filter.PR:
query.B("IS NOT NULL")
case filter.EQ:
query.B("= ?", compare)
case filter.NE:
query.B("!= ?", compare)
case filter.SW:
query.B("LIKE ?", compare.(string)+"%")
case filter.CO:
query.B("LIKE ?", "%"+compare.(string)+"%")
case filter.EW:
query.B("LIKE ?", "%"+compare.(string))
default:
return fmt.Errorf("upsupported comparator: %q", c)
}
return "", fmt.Errorf("upsupported comparator: %q", c)

return nil
}
74 changes: 44 additions & 30 deletions internal/server/data/scim_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,67 +3,80 @@ package data
import (
"testing"

"github.com/infrahq/infra/internal/server/data/querybuilder"
"github.com/scim2/filter-parser/v2"
"gotest.tools/v3/assert"
)

func TestFilterParser(t *testing.T) {
type testCase struct {
name string
expression string
expected string
name string
expression string
expectedQuery string
expectedArgs []any
}

testCases := []testCase{
{
name: "equality",
expression: "id eq \"a1234\"",
expected: "identity_id = 'a1234'",
name: "equality",
expression: "id eq \"a1234\"",
expectedQuery: " identity_id = ? ",
expectedArgs: []any{"a1234"},
},
{
name: "present",
expression: "userName pr",
expected: "email IS NOT NULL",
name: "present",
expression: "userName pr",
expectedQuery: " email IS NOT NULL ",
},
{
name: "not equal",
expression: "email ne \"hello@example.com\"",
expected: "email != 'hello@example.com'",
name: "not equal",
expression: "email ne \"hello@example.com\"",
expectedQuery: " email != ? ",
expectedArgs: []any{"hello@example.com"},
},
{
name: "starts with",
expression: "name.givenName sw \"S\"",
expected: "givenName LIKE 'S%'",
name: "starts with",
expression: "name.givenName sw \"S\"",
expectedQuery: " givenName LIKE ? ",
expectedArgs: []any{"S%"},
},
{
name: "contains",
expression: "name.familyName co \"S\"",
expected: "familyName LIKE '%S%'",
name: "contains",
expression: "name.familyName co \"S\"",
expectedQuery: " familyName LIKE ? ",
expectedArgs: []any{"%S%"},
},
{
name: "ends with",
expression: "userName ew \"S\"",
expected: "email LIKE '%S'",
name: "ends with",
expression: "userName ew \"S\"",
expectedQuery: " email LIKE ? ",
expectedArgs: []any{"%S"},
},
{
name: "logical and",
expression: "(email eq \"M\") and (email eq \"W\")",
expected: "email = 'M' AND email = 'W'",
name: "logical and",
expression: "(email eq \"M\") and (email eq \"W\")",
expectedQuery: " email = ? AND email = ? ",
expectedArgs: []any{"M", "W"},
},
{
name: "logical or",
expression: "(email eq \"M\") or (email eq \"W\")",
expected: "email = 'M' OR email = 'W'",
name: "logical or",
expression: "(email eq \"M\") or (email eq \"W\")",
expectedQuery: " email = ? OR email = ? ",
expectedArgs: []any{"M", "W"},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
exp, err := filter.ParseFilter([]byte(tc.expression))
assert.NilError(t, err)
result, err := filterSQL(exp)
query := querybuilder.New("")
err = filterSQL(exp, query)
assert.NilError(t, err)
assert.Equal(t, result, tc.expected)
assert.Equal(t, query.String(), tc.expectedQuery)
if tc.expectedArgs != nil {
assert.DeepEqual(t, query.Args, tc.expectedArgs)
}
})
}
}
Expand Down Expand Up @@ -92,7 +105,8 @@ func TestFilterParserError(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
exp, err := filter.ParseFilter([]byte(tc.expression))
assert.NilError(t, err)
_, err = filterSQL(exp)
query := querybuilder.New("")
err = filterSQL(exp, query)
assert.ErrorContains(t, err, tc.expectedErrMsg)
})
}
Expand Down

0 comments on commit 12a12f6

Please sign in to comment.