From e2dae058666d54a1e2722d3dc3ba576da6ec5a03 Mon Sep 17 00:00:00 2001 From: Jack Conradson Date: Thu, 17 Nov 2016 13:39:08 -0800 Subject: [PATCH] Fix reserved variable availability in lambdas in Painless --- .../org/elasticsearch/painless/Locals.java | 77 ++++++++++------- .../painless/node/SFunction.java | 7 +- .../elasticsearch/painless/node/SSource.java | 2 +- .../elasticsearch/painless/LambdaTests.java | 83 ++++++++++++------- .../test/plan_a/20_scriptfield.yaml | 44 ++++++++++ 5 files changed, 147 insertions(+), 66 deletions(-) diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/Locals.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/Locals.java index b02ea085904ac..5588c943dd2a0 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/Locals.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/Locals.java @@ -36,7 +36,7 @@ * Tracks user defined methods and variables across compilation phases. */ public final class Locals { - + /** Reserved word: params map parameter */ public static final String PARAMS = "params"; /** Reserved word: Lucene scorer parameter */ @@ -53,25 +53,35 @@ public final class Locals { public static final String THIS = "#this"; /** Reserved word: unused */ public static final String DOC = "doc"; - - /** Map of always reserved keywords */ - public static final Set KEYWORDS = Collections.unmodifiableSet(new HashSet<>(Arrays.asList( - THIS,PARAMS,SCORER,DOC,VALUE,SCORE,CTX,LOOP + + /** Map of always reserved keywords for the main scope */ + public static final Set MAIN_KEYWORDS = Collections.unmodifiableSet(new HashSet<>(Arrays.asList( + THIS,PARAMS,SCORER,DOC,VALUE,SCORE,CTX,LOOP + ))); + + /** Map of always reserved keywords for a function scope */ + public static final Set FUNCTION_KEYWORDS = Collections.unmodifiableSet(new HashSet<>(Arrays.asList( + THIS,LOOP ))); - + + /** Map of always reserved keywords for a lambda scope */ + public static final Set LAMBDA_KEYWORDS = Collections.unmodifiableSet(new HashSet<>(Arrays.asList( + THIS,LOOP + ))); + /** Creates a new local variable scope (e.g. loop) inside the current scope */ public static Locals newLocalScope(Locals currentScope) { return new Locals(currentScope); } - - /** + + /** * Creates a new lambda scope inside the current scope *

* This is just like {@link #newFunctionScope}, except the captured parameters are made read-only. */ - public static Locals newLambdaScope(Locals programScope, Type returnType, List parameters, + public static Locals newLambdaScope(Locals programScope, Type returnType, List parameters, int captureCount, int maxLoopCounter) { - Locals locals = new Locals(programScope, returnType); + Locals locals = new Locals(programScope, returnType, LAMBDA_KEYWORDS); for (int i = 0; i < parameters.size(); i++) { Parameter parameter = parameters.get(i); // TODO: allow non-captures to be r/w: @@ -87,10 +97,10 @@ public static Locals newLambdaScope(Locals programScope, Type returnType, List

parameters, int maxLoopCounter) { - Locals locals = new Locals(programScope, returnType); + Locals locals = new Locals(programScope, returnType, FUNCTION_KEYWORDS); for (Parameter parameter : parameters) { locals.addVariable(parameter.location, parameter.type, parameter.name, false); } @@ -100,10 +110,10 @@ public static Locals newFunctionScope(Locals programScope, Type returnType, List } return locals; } - + /** Creates a new main method scope */ public static Locals newMainMethodScope(Locals programScope, boolean usesScore, boolean usesCtx, int maxLoopCounter) { - Locals locals = new Locals(programScope, Definition.OBJECT_TYPE); + Locals locals = new Locals(programScope, Definition.OBJECT_TYPE, MAIN_KEYWORDS); // This reference. Internal use only. locals.defineVariable(null, Definition.getType("Object"), THIS, true); @@ -137,16 +147,16 @@ public static Locals newMainMethodScope(Locals programScope, boolean usesScore, } return locals; } - + /** Creates a new program scope: the list of methods. It is the parent for all methods */ public static Locals newProgramScope(Collection methods) { - Locals locals = new Locals(null, null); + Locals locals = new Locals(null, null, null); for (Method method : methods) { locals.addMethod(method); } return locals; } - + /** Checks if a variable exists or not, in this scope or any parents. */ public boolean hasVariable(String name) { Variable variable = lookupVariable(null, name); @@ -158,7 +168,7 @@ public boolean hasVariable(String name) { } return false; } - + /** Accesses a variable. This will throw IAE if the variable does not exist */ public Variable getVariable(Location location, String name) { Variable variable = lookupVariable(location, name); @@ -170,7 +180,7 @@ public Variable getVariable(Location location, String name) { } throw location.createError(new IllegalArgumentException("Variable [" + name + "] is not defined.")); } - + /** Looks up a method. Returns null if the method does not exist. */ public Method getMethod(MethodKey key) { Method method = lookupMethod(key); @@ -182,23 +192,23 @@ public Method getMethod(MethodKey key) { } return null; } - + /** Creates a new variable. Throws IAE if the variable has already been defined (even in a parent) or reserved. */ public Variable addVariable(Location location, Type type, String name, boolean readonly) { if (hasVariable(name)) { throw location.createError(new IllegalArgumentException("Variable [" + name + "] is already defined.")); } - if (KEYWORDS.contains(name)) { + if (keywords.contains(name)) { throw location.createError(new IllegalArgumentException("Variable [" + name + "] is reserved.")); } return defineVariable(location, type, name, readonly); } - + /** Return type of this scope (e.g. int, if inside a function that returns int) */ public Type getReturnType() { return returnType; } - + /** Returns the top-level program scope. */ public Locals getProgramScope() { Locals locals = this; @@ -207,13 +217,15 @@ public Locals getProgramScope() { } return locals; } - + ///// private impl // parent scope private final Locals parent; // return type of this scope private final Type returnType; + // keywords for this scope + private final Set keywords; // next slot number to assign private int nextSlotNumber; // variable name -> variable @@ -225,15 +237,16 @@ public Locals getProgramScope() { * Create a new Locals */ private Locals(Locals parent) { - this(parent, parent.getReturnType()); + this(parent, parent.returnType, parent.keywords); } - + /** * Create a new Locals with specified return type */ - private Locals(Locals parent, Type returnType) { + private Locals(Locals parent, Type returnType, Set keywords) { this.parent = parent; this.returnType = returnType; + this.keywords = keywords; if (parent == null) { this.nextSlotNumber = 0; } else { @@ -262,7 +275,7 @@ private Method lookupMethod(MethodKey key) { return methods.get(key); } - + /** Defines a variable at this scope internally. */ private Variable defineVariable(Location location, Type type, String name, boolean readonly) { if (variables == null) { @@ -273,7 +286,7 @@ private Variable defineVariable(Location location, Type type, String name, boole nextSlotNumber += type.type.getSize(); return variable; } - + private void addMethod(Method method) { if (methods == null) { methods = new HashMap<>(); @@ -293,7 +306,7 @@ public static final class Variable { public final Type type; public final boolean readonly; private final int slot; - + public Variable(Location location, String name, Type type, int slot, boolean readonly) { this.location = location; this.name = name; @@ -301,12 +314,12 @@ public Variable(Location location, String name, Type type, int slot, boolean rea this.slot = slot; this.readonly = readonly; } - + public int getSlot() { return slot; } } - + public static final class Parameter { public final Location location; public final String name; diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SFunction.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SFunction.java index 44afe828ef24c..22c7c6d96b021 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SFunction.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SFunction.java @@ -53,9 +53,6 @@ */ public final class SFunction extends AStatement { public static final class FunctionReserved implements Reserved { - public static final String THIS = "#this"; - public static final String LOOP = "#loop"; - private int maxLoopCounter = 0; public void markReserved(String name) { @@ -63,7 +60,7 @@ public void markReserved(String name) { } public boolean isReserved(String name) { - return name.equals(THIS) || name.equals(LOOP); + return Locals.FUNCTION_KEYWORDS.contains(name); } @Override @@ -173,7 +170,7 @@ void analyze(Locals locals) { } if (reserved.getMaxLoopCounter() > 0) { - loop = locals.getVariable(null, FunctionReserved.LOOP); + loop = locals.getVariable(null, Locals.LOOP); } } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SSource.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SSource.java index a4cf1cc8eee5f..9d4a74d3cb3ef 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SSource.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SSource.java @@ -89,7 +89,7 @@ public void markReserved(String name) { @Override public boolean isReserved(String name) { - return Locals.KEYWORDS.contains(name); + return Locals.MAIN_KEYWORDS.contains(name); } public boolean usesScore() { diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/LambdaTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/LambdaTests.java index 6bb800eb92c96..4958942b2a267 100644 --- a/modules/lang-painless/src/test/java/org/elasticsearch/painless/LambdaTests.java +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/LambdaTests.java @@ -19,61 +19,64 @@ package org.elasticsearch.painless; +import java.util.HashMap; +import java.util.Map; + public class LambdaTests extends ScriptTestCase { public void testNoArgLambda() { assertEquals(1, exec("Optional.empty().orElseGet(() -> 1);")); } - + public void testNoArgLambdaDef() { assertEquals(1, exec("def x = Optional.empty(); x.orElseGet(() -> 1);")); } - + public void testLambdaWithArgs() { - assertEquals("short", exec("List l = new ArrayList(); l.add('looooong'); l.add('short'); " + assertEquals("short", exec("List l = new ArrayList(); l.add('looooong'); l.add('short'); " + "l.sort((a, b) -> a.length() - b.length()); return l.get(0)")); } - + public void testLambdaWithTypedArgs() { - assertEquals("short", exec("List l = new ArrayList(); l.add('looooong'); l.add('short'); " + assertEquals("short", exec("List l = new ArrayList(); l.add('looooong'); l.add('short'); " + "l.sort((String a, String b) -> a.length() - b.length()); return l.get(0)")); } - + public void testPrimitiveLambdas() { assertEquals(4, exec("List l = new ArrayList(); l.add(1); l.add(1); " + "return l.stream().mapToInt(x -> x + 1).sum();")); } - + public void testPrimitiveLambdasWithTypedArgs() { assertEquals(4, exec("List l = new ArrayList(); l.add(1); l.add(1); " + "return l.stream().mapToInt(int x -> x + 1).sum();")); } - + public void testPrimitiveLambdasDef() { assertEquals(4, exec("def l = new ArrayList(); l.add(1); l.add(1); " + "return l.stream().mapToInt(x -> x + 1).sum();")); } - + public void testPrimitiveLambdasWithTypedArgsDef() { assertEquals(4, exec("def l = new ArrayList(); l.add(1); l.add(1); " + "return l.stream().mapToInt(int x -> x + 1).sum();")); } - + public void testPrimitiveLambdasConvertible() { assertEquals(2, exec("List l = new ArrayList(); l.add(1); l.add(1); " + "return l.stream().mapToInt(byte x -> x).sum();")); } - + public void testPrimitiveArgs() { assertEquals(2, exec("int applyOne(IntFunction arg) { arg.apply(1) } applyOne(x -> x + 1)")); } - + public void testPrimitiveArgsTyped() { assertEquals(2, exec("int applyOne(IntFunction arg) { arg.apply(1) } applyOne(int x -> x + 1)")); } - + public void testPrimitiveArgsTypedOddly() { assertEquals(2L, exec("long applyOne(IntFunction arg) { arg.apply(1) } applyOne(long x -> x + 1)")); } @@ -85,7 +88,7 @@ public void testMultipleStatements() { public void testUnneededCurlyStatements() { assertEquals(2, exec("int applyOne(IntFunction arg) { arg.apply(1) } applyOne(x -> { x + 1 })")); } - + /** interface ignores return value */ public void testVoidReturn() { assertEquals(2, exec("List list = new ArrayList(); " @@ -94,7 +97,7 @@ public void testVoidReturn() { + "list.forEach(x -> list2.add(x));" + "return list[0]")); } - + /** interface ignores return value */ public void testVoidReturnDef() { assertEquals(2, exec("def list = new ArrayList(); " @@ -121,15 +124,15 @@ public void testLambdaInLoop() { "}" + "return sum;")); } - + public void testCapture() { assertEquals(5, exec("int x = 5; return Optional.empty().orElseGet(() -> x);")); } - + public void testTwoCaptures() { assertEquals("1test", exec("int x = 1; String y = 'test'; return Optional.empty().orElseGet(() -> x + y);")); } - + public void testCapturesAreReadOnly() { IllegalArgumentException expected = expectScriptThrows(IllegalArgumentException.class, () -> { exec("List l = new ArrayList(); l.add(1); l.add(1); " @@ -137,13 +140,13 @@ public void testCapturesAreReadOnly() { }); assertTrue(expected.getMessage().contains("is read-only")); } - + @AwaitsFix(bugUrl = "def type tracking") public void testOnlyCapturesAreReadOnly() { assertEquals(4, exec("List l = new ArrayList(); l.add(1); l.add(1); " + "return l.stream().mapToInt(x -> { x += 1; return x }).sum();")); } - + /** Lambda parameters shouldn't be able to mask a variable already in scope */ public void testNoParamMasking() { IllegalArgumentException expected = expectScriptThrows(IllegalArgumentException.class, () -> { @@ -156,31 +159,31 @@ public void testNoParamMasking() { public void testCaptureDef() { assertEquals(5, exec("int x = 5; def y = Optional.empty(); y.orElseGet(() -> x);")); } - + public void testNestedCapture() { assertEquals(1, exec("boolean x = false; int y = 1;" + "return Optional.empty().orElseGet(() -> x ? 5 : Optional.empty().orElseGet(() -> y));")); } - + public void testNestedCaptureParams() { assertEquals(2, exec("int foo(Function f) { return f.apply(1) }" + "return foo(x -> foo(y -> x + 1))")); } - + public void testWrongArity() { IllegalArgumentException expected = expectScriptThrows(IllegalArgumentException.class, () -> { exec("Optional.empty().orElseGet(x -> x);"); }); assertTrue(expected.getMessage().contains("Incorrect number of parameters")); } - + public void testWrongArityDef() { IllegalArgumentException expected = expectScriptThrows(IllegalArgumentException.class, () -> { exec("def y = Optional.empty(); return y.orElseGet(x -> x);"); }); assertTrue(expected.getMessage(), expected.getMessage().contains("Incorrect number of parameters")); } - + public void testWrongArityNotEnough() { IllegalArgumentException expected = expectScriptThrows(IllegalArgumentException.class, () -> { exec("List l = new ArrayList(); l.add(1); l.add(1); " @@ -188,7 +191,7 @@ public void testWrongArityNotEnough() { }); assertTrue(expected.getMessage().contains("Incorrect number of parameters")); } - + public void testWrongArityNotEnoughDef() { IllegalArgumentException expected = expectScriptThrows(IllegalArgumentException.class, () -> { exec("def l = new ArrayList(); l.add(1); l.add(1); " @@ -196,12 +199,36 @@ public void testWrongArityNotEnoughDef() { }); assertTrue(expected.getMessage().contains("Incorrect number of parameters")); } - + public void testLambdaInFunction() { assertEquals(5, exec("def foo() { Optional.empty().orElseGet(() -> 5) } return foo();")); } - + public void testLambdaCaptureFunctionParam() { assertEquals(5, exec("def foo(int x) { Optional.empty().orElseGet(() -> x) } return foo(5);")); } + + public void testReservedCapture() { + String compare = "boolean compare(Supplier s, def v) {s.get() == v}"; + assertEquals(true, exec(compare + "compare(() -> new ArrayList(), new ArrayList())")); + assertEquals(true, exec(compare + "compare(() -> { new ArrayList() }, new ArrayList())")); + + Map params = new HashMap<>(); + params.put("key", "value"); + params.put("number", 2); + + assertEquals(true, exec(compare + "compare(() -> { return params['key'] }, 'value')", params, true)); + assertEquals(false, exec(compare + "compare(() -> { return params['nokey'] }, 'value')", params, true)); + assertEquals(true, exec(compare + "compare(() -> { return params['nokey'] }, null)", params, true)); + assertEquals(true, exec(compare + "compare(() -> { return params['number'] }, 2)", params, true)); + assertEquals(false, exec(compare + "compare(() -> { return params['number'] }, 'value')", params, true)); + assertEquals(false, exec(compare + "compare(() -> { if (params['number'] == 2) { return params['number'] }" + + "else { return params['key'] } }, 'value')", params, true)); + assertEquals(true, exec(compare + "compare(() -> { if (params['number'] == 2) { return params['number'] }" + + "else { return params['key'] } }, 2)", params, true)); + assertEquals(true, exec(compare + "compare(() -> { if (params['number'] == 1) { return params['number'] }" + + "else { return params['key'] } }, 'value')", params, true)); + assertEquals(false, exec(compare + "compare(() -> { if (params['number'] == 1) { return params['number'] }" + + "else { return params['key'] } }, 2)", params, true)); + } } diff --git a/modules/lang-painless/src/test/resources/rest-api-spec/test/plan_a/20_scriptfield.yaml b/modules/lang-painless/src/test/resources/rest-api-spec/test/plan_a/20_scriptfield.yaml index b92012959d18f..bb3bcf4341edb 100644 --- a/modules/lang-painless/src/test/resources/rest-api-spec/test/plan_a/20_scriptfield.yaml +++ b/modules/lang-painless/src/test/resources/rest-api-spec/test/plan_a/20_scriptfield.yaml @@ -34,3 +34,47 @@ setup: x: "bbb" - match: { hits.hits.0.fields.bar.0: "aaabbb"} +--- +"Scripted Field Doing Compare": + - do: + search: + body: + script_fields: + bar: + script: + inline: "boolean compare(Supplier s, def v) {return s.get() == v;} + compare(() -> { return doc['foo'].value }, params.x);" + params: + x: "aaa" + + - match: { hits.hits.0.fields.bar.0: true} + - do: + search: + body: + script_fields: + bar: + script: + inline: "boolean compare(Supplier s, def v) {return s.get() == v;} + compare(() -> { return doc['foo'].value }, params.x);" + params: + x: "bbb" + + - match: { hits.hits.0.fields.bar.0: false} +--- +"Scripted Field with script error": + - do: + catch: request + search: + body: + script_fields: + bar: + script: + inline: "while (true) {}" + + - match: { error.root_cause.0.type: "script_exception" } + - match: { error.root_cause.0.reason: "compile error" } + - match: { error.caused_by.type: "script_exception" } + - match: { error.caused_by.reason: "compile error" } + - match: { error.caused_by.caused_by.type: "illegal_argument_exception" } + - match: { error.caused_by.caused_by.reason: "While loop has no escape." } +