diff --git a/ast/builtins.go b/ast/builtins.go index 17128961f0..661d65d3d1 100644 --- a/ast/builtins.go +++ b/ast/builtins.go @@ -25,7 +25,7 @@ var DefaultBuiltins = [...]*Builtin{ Count, Sum, Max, ToNumber, RegexMatch, - Concat, FormatInt, + Concat, FormatInt, IndexOf, Substring, Lower, Upper, Contains, StartsWith, EndsWith, } // BuiltinMap provides a convenient mapping of built-in names to @@ -196,6 +196,53 @@ var FormatInt = &Builtin{ TargetPos: []int{2}, } +// IndexOf returns the index of a substring contained inside a string, or -1 if the substring is not found. +var IndexOf = &Builtin{ + Name: Var("indexof"), + NumArgs: 3, + TargetPos: []int{2}, +} + +// Substring returns the portion of a string for a given start index and a length. +// If the length is less than zero, then substring returns the remainder of the string. +var Substring = &Builtin{ + Name: Var("substring"), + NumArgs: 4, + TargetPos: []int{3}, +} + +// Contains returns true if the search string is included in the base string. +var Contains = &Builtin{ + Name: Var("contains"), + NumArgs: 2, +} + +// StartsWith returns true if the base string begins with the search string. +var StartsWith = &Builtin{ + Name: Var("startswith"), + NumArgs: 2, +} + +// EndsWith returns true if the base string ends with the search string. +var EndsWith = &Builtin{ + Name: Var("endswith"), + NumArgs: 2, +} + +// Lower returns the input string but with all characters in lower-case. +var Lower = &Builtin{ + Name: Var("lower"), + NumArgs: 2, + TargetPos: []int{1}, +} + +// Upper returns the input string but with all characters in upper-case. +var Upper = &Builtin{ + Name: Var("upper"), + NumArgs: 2, + TargetPos: []int{1}, +} + // Builtin represents a built-in function supported by OPA. Every // built-in function is uniquely identified by a name. 
type Builtin struct { diff --git a/site/documentation/how-do-i-write-policies/index.md b/site/documentation/how-do-i-write-policies/index.md index b981f1cbcd..a51a955c31 100644 --- a/site/documentation/how-do-i-write-policies/index.md +++ b/site/documentation/how-do-i-write-policies/index.md @@ -914,18 +914,49 @@ false ### Strings ```ruby -> format_int(15.5, 16, x) -+-----+ -| x | -+-----+ -| "f" | -+-----+ > concat("/", ["", "foo", "bar", "baz"], x) +----------------+ | x | +----------------+ | "/foo/bar/baz" | +----------------+ +> contains("abcdef", "cde") +true +> endswith("abcdef", "def") +true +> format_int(15.5, 16, x) ++-----+ +| x | ++-----+ +| "f" | ++-----+ +> indexof("abcdefg", "cde", x) ++---+ +| x | ++---+ +| 2 | ++---+ +> lower("AbCdEf", x) ++----------+ +| x | ++----------+ +| "abcdef" | ++----------+ +> startswith("abcdef", "abcd") +true +> substring("abcdef", 2, 3, x) ++-------+ +| x | ++-------+ +| "cde" | ++-------+ +> upper("AbCdEf", x) ++----------+ +| x | ++----------+ +| "ABCDEF" | ++----------+ + ``` ## Examples diff --git a/site/documentation/references/language/index.md b/site/documentation/references/language/index.md index 3be5dba5da..78ebd298dc 100644 --- a/site/documentation/references/language/index.md +++ b/site/documentation/references/language/index.md @@ -27,44 +27,51 @@ complex types. 
| Built-in | Inputs | Description | | ------- |--------|-------------| -| ``x != y`` | 2 | x is not equal to y | -| ``x < y`` | 2 | x is less than y | -| ``x <= y`` | 2 | x is less than or equal to y | -| ``x > y`` | 2 | x is greater than y | -| ``x >= y`` | 2 | x is greater than or equal to y | +| ``x != y`` | 2 | ``x`` is not equal to ``y`` | +| ``x < y`` | 2 | ``x`` is less than ``y`` | +| ``x <= y`` | 2 | ``x`` is less than or equal to ``y`` | +| ``x > y`` | 2 | ``x`` is greater than ``y`` | +| ``x >= y`` | 2 | ``x`` is greater than or equal to ``y`` | ### Numbers | Built-in | Inputs | Description | | ------- |--------|-------------| -| ``plus(x, y, output)`` | 2 | x + y = output | -| ``minus(x, y, output)`` | 2 | x - y = output | -| ``mul(x, y, output)`` | 2 | x * y = output | -| ``div(x, y, output)`` | 2 | x / y = output | -| ``round(x, output)`` | 1 | output is x rounded to the nearest integer | -| ``abs(x, output)`` | 1 | output is the absolute value of x | +| ``plus(x, y, output)`` | 2 | ``x`` + ``y`` = ``output`` | +| ``minus(x, y, output)`` | 2 | ``x`` - ``y`` = ``output`` | +| ``mul(x, y, output)`` | 2 | ``x`` * ``y`` = ``output`` | +| ``div(x, y, output)`` | 2 | ``x`` / ``y`` = ``output`` | +| ``round(x, output)`` | 1 | ``output`` is ``x`` rounded to the nearest integer | +| ``abs(x, output)`` | 1 | ``output`` is the absolute value of ``x`` | ### Aggregates | Built-in | Inputs | Description | | ------- |--------|-------------| -| ``count(collection, output)`` | 1 | output is the length of the object, array, or set | -| ``sum(array_or_set, output)`` | 1 | output is the sum of the numbers in array or set | -| ``max(array_or_set, output)`` | 1 | output is the maximum value in the array or set | +| ``count(collection, output)`` | 1 | ``output`` is the length of the object, array, or set ``collection`` | +| ``sum(array_or_set, output)`` | 1 | ``output`` is the sum of the numbers in ``array_or_set`` | +| ``max(array_or_set, output)`` | 1 | ``output`` is the 
maximum value in ``array_or_set`` | ### Types | Built-in | Inputs | Description | | ------- |--------|-------------| -| ``to_number(x, output)`` | 1 | output is x converted to a number | +| ``to_number(x, output)`` | 1 | ``output`` is ``x`` converted to a number | ### Strings | Built-in | Inputs | Description | | ------- |--------|-------------| -| ``format_int(number, base, output)`` | 2 | output is string representation of number in given base | -| ``concat(join, array_or_set, output)`` | 2 | output is the result of concatenating the elements of array or set with the join string | -|``re_match(pattern, value)`` | 2 | true if the value matches the pattern | +| ``concat(join, array_or_set, output)`` | 2 | ``output`` is the result of concatenating the elements of ``array_or_set`` with the string ``join`` | +| ``contains(string, search)`` | 2 | true if ``string`` contains ``search`` | +| ``endswith(string, search)`` | 2 | true if ``string`` ends with ``search`` | +| ``format_int(number, base, output)`` | 2 | ``output`` is the string representation of ``number`` in the given ``base`` | +| ``indexof(string, search, output)`` | 2 | ``output`` is the index inside ``string`` where ``search`` first occurs, or -1 if ``search`` does not occur in ``string`` | +| ``lower(string, output)`` | 1 | ``output`` is ``string`` after converting to lower case | +| ``re_match(pattern, value)`` | 2 | true if ``value`` matches ``pattern`` | +| ``startswith(string, search)`` | 2 | true if ``string`` begins with ``search`` | +| ``substring(string, start, length, output)`` | 3 | ``output`` is the portion of ``string`` from index ``start`` and having a length of ``length``. If ``length`` is less than zero, ``output`` is the remainder of ``string`` starting at ``start``. 
| +| ``upper(string, output)`` | 1 | ``output`` is ``string`` after converting to upper case | ## Grammar diff --git a/topdown/builtins.go b/topdown/builtins.go index 55b5d2238b..c47736f6e7 100644 --- a/topdown/builtins.go +++ b/topdown/builtins.go @@ -44,6 +44,13 @@ var defaultBuiltinFuncs = map[ast.Var]BuiltinFunc{ ast.RegexMatch.Name: evalRegexMatch, ast.FormatInt.Name: evalFormatInt, ast.Concat.Name: evalConcat, + ast.IndexOf.Name: evalIndexOf, + ast.Substring.Name: evalSubstring, + ast.Contains.Name: evalContains, + ast.StartsWith.Name: evalStartsWith, + ast.EndsWith.Name: evalEndsWith, + ast.Upper.Name: evalUpper, + ast.Lower.Name: evalLower, } func init() { diff --git a/topdown/strings.go b/topdown/strings.go index e3fad7bd06..964365ffef 100644 --- a/topdown/strings.go +++ b/topdown/strings.go @@ -54,3 +54,154 @@ func evalConcat(ctx *Context, expr *ast.Expr, iter Iterator) error { ctx.Unbind(undo) return err } + +func evalIndexOf(ctx *Context, expr *ast.Expr, iter Iterator) error { + ops := expr.Terms.([]*ast.Term) + + base, err := ValueToString(ops[1].Value, ctx) + if err != nil { + return errors.Wrapf(err, "%v: base value must be a string", ast.IndexOf.Name) + } + + search, err := ValueToString(ops[2].Value, ctx) + if err != nil { + return errors.Wrapf(err, "%v: search value must be a string", ast.IndexOf.Name) + } + + index := ast.Number(strings.Index(base, search)) + + undo, err := evalEqUnify(ctx, index, ops[3].Value, nil, iter) + ctx.Unbind(undo) + return err +} + +// evalSubstring implements the substring built-in. The output operand is unified +// with the bytes of the base string beginning at the start index. A negative length +// selects the remainder of the string. A negative start index is an error; a start +// index or length that runs past the end of the string is clamped to the end. +func evalSubstring(ctx *Context, expr *ast.Expr, iter Iterator) error { + ops := expr.Terms.([]*ast.Term) + + base, err := ValueToString(ops[1].Value, ctx) + if err != nil { + return errors.Wrapf(err, "%v: base value must be a string", ast.Substring.Name) + } + + startIndex, err := ValueToInt(ops[2].Value, ctx) + if err != nil { + return errors.Wrapf(err, "%v: start index must be a number", ast.Substring.Name) + } + + l, err := ValueToInt(ops[3].Value, ctx) + if err != nil { + return 
errors.Wrapf(err, "%v: length must be a number", ast.Substring.Name) + } + + if startIndex < 0 { + return errors.Errorf("%v: start index must not be negative: illegal argument: %v", ast.Substring.Name, startIndex) + } + + // Clamp the window to the string so out-of-range arguments cannot panic. + size := int64(len(base)) + if startIndex > size { + startIndex = size + } + + end := size + if l >= 0 && l < size-startIndex { + end = startIndex + l + } + + s := ast.String(base[startIndex:end]) + + undo, err := evalEqUnify(ctx, s, ops[4].Value, nil, iter) + ctx.Unbind(undo) + return err +} + +func evalContains(ctx *Context, expr *ast.Expr, iter Iterator) error { + ops := expr.Terms.([]*ast.Term) + + base, err := ValueToString(ops[1].Value, ctx) + if err != nil { + return errors.Wrapf(err, "%v: base value must be a string", ast.Contains.Name) + } + + search, err := ValueToString(ops[2].Value, ctx) + if err != nil { + return errors.Wrapf(err, "%v: search must be a string", ast.Contains.Name) + } + + if strings.Contains(base, search) { + return iter(ctx) + } + return nil +} + +func evalStartsWith(ctx *Context, expr *ast.Expr, iter Iterator) error { + ops := expr.Terms.([]*ast.Term) + + base, err := ValueToString(ops[1].Value, ctx) + if err != nil { + return errors.Wrapf(err, "%v: base value must be a string", ast.StartsWith.Name) + } + + search, err := ValueToString(ops[2].Value, ctx) + if err != nil { + return errors.Wrapf(err, "%v: search must be a string", ast.StartsWith.Name) + } + + if strings.HasPrefix(base, search) { + return iter(ctx) + } + return nil +} + +func evalEndsWith(ctx *Context, expr *ast.Expr, iter Iterator) error { + ops := expr.Terms.([]*ast.Term) + + base, err := ValueToString(ops[1].Value, ctx) + if err != nil { + return errors.Wrapf(err, "%v: base value must be a string", ast.EndsWith.Name) + } + + search, err := ValueToString(ops[2].Value, ctx) + if err != nil { + return errors.Wrapf(err, "%v: search must be a string", ast.EndsWith.Name) + } + + if strings.HasSuffix(base, search) { + return iter(ctx) + } + return nil +} + +func evalLower(ctx *Context, expr *ast.Expr, iter Iterator) error { + ops := expr.Terms.([]*ast.Term) + + orig, err := ValueToString(ops[1].Value, ctx) + if err != nil { + return errors.Wrapf(err, "%v: original 
value must be a string", ast.Lower.Name) + } + + s := ast.String(strings.ToLower(orig)) + + undo, err := evalEqUnify(ctx, s, ops[2].Value, nil, iter) + ctx.Unbind(undo) + return err +} + +func evalUpper(ctx *Context, expr *ast.Expr, iter Iterator) error { + ops := expr.Terms.([]*ast.Term) + + orig, err := ValueToString(ops[1].Value, ctx) + if err != nil { + return errors.Wrapf(err, "%v: original value must be a string", ast.Upper.Name) + } + + s := ast.String(strings.ToUpper(orig)) + + undo, err := evalEqUnify(ctx, s, ops[2].Value, nil, iter) + ctx.Unbind(undo) + return err +} diff --git a/topdown/topdown.go b/topdown/topdown.go index 2798fbf790..914ec0ff65 100644 --- a/topdown/topdown.go +++ b/topdown/topdown.go @@ -6,6 +6,7 @@ package topdown import ( "fmt" + "math" "sync" "github.com/open-policy-agent/opa/ast" @@ -697,6 +698,19 @@ func ValueToFloat64(v ast.Value, ctx *Context) (float64, error) { return f, nil } +// ValueToInt returns the underlying Go value associated with an AST value as an int64. +// If the value is a reference, the reference is fetched from storage. An error is returned if the number is not a whole number. +func ValueToInt(v ast.Value, ctx *Context) (int64, error) { + x, err := ValueToFloat64(v, ctx) + if err != nil { + return 0, err + } + if x != math.Floor(x) { + return 0, fmt.Errorf("illegal argument: %v", v) + } + return int64(x), nil +} + // ValueToString returns the underlying Go value associated with an AST value. // If the value is a reference, the reference is fetched from storage. 
func ValueToString(v ast.Value, ctx *Context) (string, error) { diff --git a/topdown/topdown_test.go b/topdown/topdown_test.go index 1464a2f5fb..f374f67923 100644 --- a/topdown/topdown_test.go +++ b/topdown/topdown_test.go @@ -1040,6 +1040,30 @@ func TestTopDownStrings(t *testing.T) { {"concat: non-string err", []string{`p = x :- concat("/", ["", "foo", "bar", 0, "baz"], x)`}, fmt.Errorf("concat: input value must be array of strings: illegal argument: 0")}, {"concat: ref dest", []string{`p :- concat("", ["f", "o", "o"], c[0].x[2])`}, "true"}, {"concat: ref dest (2)", []string{`p :- not concat("", ["b", "a", "r"], c[0].x[2])`}, "true"}, + {"indexof", []string{`p = x :- indexof("abcdefgh", "cde", x)`}, "2"}, + {"indexof: not found", []string{`p = x :- indexof("abcdefgh", "xyz", x)`}, "-1"}, + {"indexof: error", []string{`p = x :- indexof("abcdefgh", 1, x)`}, fmt.Errorf("indexof: search value must be a string: illegal argument: 1")}, + {"substring", []string{`p = x :- substring("abcdefgh", 2, 3, x)`}, `"cde"`}, + {"substring: remainder", []string{`p = x :- substring("abcdefgh", 2, -1, x)`}, `"cdefgh"`}, + {"substring: error 1", []string{`p = x :- substring(17, "xyz", 3, x)`}, fmt.Errorf("substring: base value must be a string: illegal argument: 17")}, + {"substring: error 2", []string{`p = x :- substring("abcdefgh", "xyz", 3, x)`}, fmt.Errorf(`substring: start index must be a number: illegal argument: "xyz"`)}, + {"substring: error 3", []string{`p = x :- substring("abcdefgh", 2, "xyz", x)`}, fmt.Errorf(`substring: length must be a number: illegal argument: "xyz"`)}, + {"contains", []string{`p :- contains("abcdefgh", "defg")`}, "true"}, + {"contains: undefined", []string{`p :- contains("abcdefgh", "ac")`}, ""}, + {"contains: error 1", []string{`p :- contains(17, "ac")`}, fmt.Errorf(`contains: base value must be a string: illegal argument: 17`)}, + {"contains: error 2", []string{`p :- contains("abcdefgh", 17)`}, fmt.Errorf(`contains: search must be a string: illegal 
argument: 17`)}, + {"startswith", []string{`p :- startswith("abcdefgh", "abcd")`}, "true"}, + {"startswith: undefined", []string{`p :- startswith("abcdefgh", "bcd")`}, ""}, + {"startswith: error 1", []string{`p :- startswith(17, "bcd")`}, fmt.Errorf(`startswith: base value must be a string: illegal argument: 17`)}, + {"startswith: error 2", []string{`p :- startswith("abcdefgh", 17)`}, fmt.Errorf(`startswith: search must be a string: illegal argument: 17`)}, + {"endswith", []string{`p :- endswith("abcdefgh", "fgh")`}, "true"}, + {"endswith: undefined", []string{`p :- endswith("abcdefgh", "fg")`}, ""}, + {"endswith: error 1", []string{`p :- endswith(17, "bcd")`}, fmt.Errorf(`endswith: base value must be a string: illegal argument: 17`)}, + {"endswith: error 2", []string{`p :- endswith("abcdefgh", 17)`}, fmt.Errorf(`endswith: search must be a string: illegal argument: 17`)}, + {"lower", []string{`p = x :- lower("AbCdEf", x)`}, `"abcdef"`}, + {"lower error", []string{`p = x :- lower(true, x)`}, fmt.Errorf("lower: original value must be a string: illegal argument: true")}, + {"upper", []string{`p = x :- upper("AbCdEf", x)`}, `"ABCDEF"`}, + {"upper error", []string{`p = x :- upper(true, x)`}, fmt.Errorf("upper: original value must be a string: illegal argument: true")}, } data := loadSmallTestData()