diff --git a/src/com/google/javascript/jscomp/Promises.java b/src/com/google/javascript/jscomp/Promises.java
index ed379878d4d..69ad2bfd54d 100644
--- a/src/com/google/javascript/jscomp/Promises.java
+++ b/src/com/google/javascript/jscomp/Promises.java
@@ -16,6 +16,7 @@
package com.google.javascript.jscomp;
+
import com.google.javascript.rhino.jstype.JSType;
import com.google.javascript.rhino.jstype.JSTypeNative;
import com.google.javascript.rhino.jstype.JSTypeRegistry;
@@ -96,4 +97,36 @@ static final JSType getResolvedType(JSTypeRegistry registry, JSType type) {
return type;
}
+
+ /**
+ * Synthesizes a type representing the legal types of a return expression within async code
+ * (i.e.`Promise` callbacks, async functions).
+ *
+ *
The return type will generally be a union but may not be, for example:
+ *
+ *
+ * - `!Promise` => `number|!IThenable`
+ *
- `number` => `number|!IThenable`
+ *
- `?` => `?`
+ *
- `*` => `*`
+ *
- `number|!Foo` => `number|!Foo|!IThenable`
+ *
- `!Foo|!IThenable` => `Foo|!IThenable`
+ *
- `!Promise>` => `!Foo|!IThenable`
+ *
+ *
+ * Note that this method may create an incorrect (but not really dangerous) type when supplied
+ * with types that are nonsensical in an async context, for example:
+ *
+ *
+ * - `number|!IThenable` => `number|string|!IThenable`
+ *
- `?IThenable` => `null|!Foo|!Thenable`
+ *
+ */
+ static final JSType createAsyncReturnableType(JSTypeRegistry registry, JSType maybeThenable) {
+ JSType parameterType = getResolvedType(registry, maybeThenable);
+ return registry.createUnionType(
+ parameterType,
+ registry.createTemplatizedType(
+ registry.getNativeObjectType(JSTypeNative.I_THENABLE_TYPE), parameterType));
+ }
}
diff --git a/src/com/google/javascript/jscomp/TypeCheck.java b/src/com/google/javascript/jscomp/TypeCheck.java
index 9b8882e7820..b60c0b31a95 100644
--- a/src/com/google/javascript/jscomp/TypeCheck.java
+++ b/src/com/google/javascript/jscomp/TypeCheck.java
@@ -2135,7 +2135,7 @@ private void visitFunction(NodeTraversal t, Node n) {
} else if (n.isAsyncFunction()) {
// An async function must return a Promise or supertype of Promise
JSType returnType = functionType.getReturnType();
- validator.expectValidAsyncReturnType(t, n, returnType.restrictByNotNullOrUndefined());
+ validator.expectValidAsyncReturnType(t, n, returnType);
}
}
@@ -2481,14 +2481,11 @@ private void visitImplicitReturnExpression(NodeTraversal t, Node exprNode) {
expectedReturnType = getNativeType(VOID_TYPE);
} else if (enclosingFunction.isAsyncFunction()) {
// Unwrap the async function's declared return type.
- expectedReturnType = Promises.getTemplateTypeOfThenable(typeRegistry, expectedReturnType);
+ expectedReturnType = Promises.createAsyncReturnableType(typeRegistry, expectedReturnType);
}
// Fetch the returned value's type
JSType actualReturnType = getJSType(exprNode);
- if (enclosingFunction.isAsyncFunction()) {
- actualReturnType = Promises.getResolvedType(typeRegistry, actualReturnType);
- }
validator.expectCanAssignTo(t, exprNode, actualReturnType, expectedReturnType,
"inconsistent return type");
@@ -2526,9 +2523,10 @@ private void visitReturn(NodeTraversal t, Node n) {
// e.g. if returnType is "Generator", make it just "string".
returnType = getTemplateTypeOfGenerator(returnType);
} else if (enclosingFunction.isAsyncFunction()) {
- // Unwrap the template variable from a async function's declared return type.
- // e.g. if returnType is "!Promise" or "!IThenable", make it just "string".
- returnType = Promises.getTemplateTypeOfThenable(typeRegistry, returnType);
+ // e.g. `!Promise` => `string|!IThenable`
+ // We transform the expected return type rather than the actual return type so that the
+ // extual return type is always reported to the user. This was felt to be clearer.
+ returnType = Promises.createAsyncReturnableType(typeRegistry, returnType);
} else if (returnType.isVoidType() && functionType.isConstructor()) {
// Allow constructors to use empty returns for flow control.
if (!n.hasChildren()) {
@@ -2547,11 +2545,6 @@ private void visitReturn(NodeTraversal t, Node n) {
valueNode = n;
} else {
actualReturnType = getJSType(valueNode);
- if (enclosingFunction.isAsyncFunction()) {
- // We want to treat `return Promise.resolve(1);` as if it were `return 1;` inside an async
- // function.
- actualReturnType = Promises.getResolvedType(typeRegistry, actualReturnType);
- }
}
// verifying
diff --git a/test/com/google/javascript/jscomp/TypeCheckNoTranspileTest.java b/test/com/google/javascript/jscomp/TypeCheckNoTranspileTest.java
index f57f417f240..85a7947fc67 100644
--- a/test/com/google/javascript/jscomp/TypeCheckNoTranspileTest.java
+++ b/test/com/google/javascript/jscomp/TypeCheckNoTranspileTest.java
@@ -213,12 +213,12 @@ public void testAsyncArrowWithCorrectBlocklessReturn() {
public void testAsyncArrowWithIncorrectBlocklessReturn() {
testTypes(
lines(
- "function takesPromiseProvider(/** function(): ?Promise */ getPromise) {}",
+ "function takesPromiseProvider(/** function(): !Promise */ getPromise) {}",
"takesPromiseProvider(async () => 1);"),
lines(
"inconsistent return type", // preserve newline
"found : number",
- "required: string"));
+ "required: (IThenable|string)"));
}
@Test
@@ -4093,22 +4093,20 @@ public void testAsyncReturnsPromise2() {
lines(
"inconsistent return type", // preserve newline
"found : number",
- "required: string"));
+ "required: (IThenable|string)"));
}
@Test
- public void testAsyncCanReturnNullablePromise() {
- // TODO(lharker): don't allow async functions to return null.
+ public void testAsyncFunctionCannotReturnNullablePromise() {
testTypesWithCommonExterns(
lines(
"/** @return {?Promise} */",
"async function getAString() {",
- " return 1;",
+ " return '';",
"}"),
lines(
- "inconsistent return type", // preserve newline
- "found : number",
- "required: string"));
+ "The return type of an async function must be a non-union supertype of Promise",
+ "found: (Promise|null)"));
}
@Test
@@ -4170,7 +4168,7 @@ public void testAsyncCanReturnIThenable1() {
lines(
"inconsistent return type", //
"found : number",
- "required: string"));
+ "required: (IThenable|string)"));
}
@Test
@@ -4185,8 +4183,8 @@ public void testAsyncReturnStatementIsResolved() {
"}"),
lines(
"inconsistent return type", // preserve newline
- "found : number",
- "required: string"));
+ "found : IThenable",
+ "required: (IThenable|string)"));
}
@Test