diff --git a/openrewrite/src/javascript/parser.ts b/openrewrite/src/javascript/parser.ts index 60800798..f13575d7 100644 --- a/openrewrite/src/javascript/parser.ts +++ b/openrewrite/src/javascript/parser.ts @@ -23,7 +23,7 @@ import { randomId, SourceFile } from "../core"; -import {binarySearch, compareTextSpans, getNextSibling, getPreviousSibling, TextSpan, hasFlowAnnotation, checkSyntaxErrors, isValidSurrogateRange} from "./parserUtils"; +import {binarySearch, compareTextSpans, getNextSibling, getPreviousSibling, TextSpan, hasFlowAnnotation, checkSyntaxErrors, isValidSurrogateRange, isStatement} from "./parserUtils"; import {JavaScriptTypeMapping} from "./typeMapping"; import path from "node:path"; import {ExpressionStatement, TypeTreeExpression} from "."; @@ -2626,10 +2626,7 @@ export class JavaScriptParserVisitor { visitExpressionStatement(node: ts.ExpressionStatement): J.Statement { const expression = this.visit(node.expression) as J.Expression; - if (expression instanceof J.MethodInvocation || expression instanceof J.NewClass || expression instanceof J.Unknown || - expression instanceof J.AssignmentOperation || expression instanceof J.Ternary || expression instanceof J.Empty || - expression instanceof JS.ExpressionStatement || expression instanceof J.Assignment || expression instanceof J.FieldAccess) { - // FIXME this is a hack we currently require because our `Expression` and `Statement` interfaces don't have any type guards + if (isStatement(expression)) { return expression as J.Statement; } return new JS.ExpressionStatement( @@ -2719,7 +2716,7 @@ export class JavaScriptParserVisitor { Markers.EMPTY, [node.initializer ? (ts.isVariableDeclarationList(node.initializer) ? this.rightPadded(this.visit(node.initializer), Space.EMPTY) : - this.rightPadded(new ExpressionStatement(randomId(), this.visit(node.initializer)), this.suffix(node.initializer))) : + this.rightPadded(ts.isStatement(node.initializer) ? this.visit(node.initializer) : new ExpressionStatement(randomId(), this.visit(node.initializer)), this.suffix(node.initializer))) : this.rightPadded(this.newJEmpty(), this.suffix(this.findChildNode(node, ts.SyntaxKind.OpenParenToken)!))], // to handle for (/*_*/; ; ); node.condition ? this.rightPadded(this.visit(node.condition), this.suffix(node.condition)) : this.rightPadded(this.newJEmpty(), this.suffix(this.findChildNode(node, ts.SyntaxKind.SemicolonToken)!)), // to handle for ( ;/*_*/; ); diff --git a/openrewrite/src/javascript/parserUtils.ts b/openrewrite/src/javascript/parserUtils.ts index 74ef20fb..9e41edb6 100644 --- a/openrewrite/src/javascript/parserUtils.ts +++ b/openrewrite/src/javascript/parserUtils.ts @@ -1,4 +1,158 @@ import * as ts from "typescript"; +import * as J from '../java' +import * as JS from "./tree"; + +const is_statements = [ + J.Assert, + J.Assignment, + J.AssignmentOperation, + J.Block, + J.Break, + J.Case, + J.ClassDeclaration, + J.Continue, + J.DoWhileLoop, + J.Empty, + J.EnumValueSet, + J.Erroneous, + J.FieldAccess, + J.ForEachLoop, + J.ForLoop, + J.If, + J.Import, + J.Label, + J.Lambda, + J.MethodDeclaration, + J.MethodInvocation, + J.NewClass, + J.Package, + J.Return, + J.Switch, + J.Synchronized, + J.Ternary, + J.Throw, + J.Try, + J.Unary, + J.Unknown, + J.VariableDeclarations, + J.WhileLoop, + J.Yield, + JS.ArrowFunction, + JS.BindingElement, + JS.Delete, + JS.Export, + JS.ExportAssignment, + JS.ExportDeclaration, + JS.FunctionDeclaration, + JS.JSForInLoop, + JS.JSForOfLoop, + JS.ImportAttribute, + JS.IndexSignatureDeclaration, + JS.JsAssignmentOperation, + JS.JsImport, + JS.JSMethodDeclaration, + JS.JSTry, + JS.JSVariableDeclarations, + JS.MappedType.KeysRemapping, + JS.MappedType.MappedTypeParameter, + JS.NamespaceDeclaration, + JS.PropertyAssignment, + JS.ScopedVariableDeclarations, + JS.TaggedTemplateExpression, + JS.TemplateExpression, + JS.TrailingTokenStatement, + JS.TypeDeclaration, + JS.Unary, + JS.Void, + JS.WithStatement, + JS.ExpressionStatement, + JS.StatementExpression +] + +const is_expressions = [ + J.AnnotatedType, + J.Annotation, + J.ArrayAccess, + J.ArrayType, + J.Assignment, + J.AssignmentOperation, + J.Binary, + J.ControlParentheses, + J.Empty, + J.Erroneous, + J.FieldAccess, + J.Identifier, + J.InstanceOf, + J.IntersectionType, + J.Lambda, + J.Literal, + J.MethodInvocation, + J.MemberReference, + J.NewArray, + J.NewClass, + J.NullableType, + J.ParameterizedType, + J.Parentheses, + J.ParenthesizedTypeTree, + J.Primitive, + J.SwitchExpression, + J.Ternary, + J.TypeCast, + J.Unary, + J.Unknown, + J.Wildcard, + JS.Alias, + JS.ArrayBindingPattern, + JS.ArrowFunction, + JS.Await, + JS.BindingElement, + JS.ConditionalType, + JS.DefaultType, + JS.Delete, + JS.ExportSpecifier, + JS.ExpressionWithTypeArguments, + JS.FunctionDeclaration, + JS.FunctionType, + JS.ImportType, + JS.IndexedAccessType, + JS.IndexedAccessType.IndexType, + JS.InferType, + JS.Intersection, + JS.JsAssignmentOperation, + JS.JsBinary, + JS.JsImportSpecifier, + JS.LiteralType, + JS.MappedType, + JS.NamedExports, + JS.NamedImports, + JS.ObjectBindingDeclarations, + JS.SatisfiesExpression, + JS.TaggedTemplateExpression, + JS.TemplateExpression, + JS.TrailingTokenStatement, + JS.Tuple, + JS.TypeInfo, + JS.TypeLiteral, + JS.TypeOf, + JS.TypeOperator, + JS.TypePredicate, + JS.TypeQuery, + JS.TypeTreeExpression, + JS.Unary, + JS.Union, + JS.Void, + JS.Yield, + JS.ExpressionStatement, + JS.StatementExpression +] + +export function isStatement(statement: J.J): statement is J.Statement { + return is_statements.some((cls: any) => statement instanceof cls); +} + +export function isExpression(expression: J.J): expression is J.Expression { + return is_expressions.some((cls: any) => expression instanceof cls); +} export function getNextSibling(node: ts.Node): ts.Node | null { const parent = node.parent; diff --git a/openrewrite/test/javascript/parser/function.test.ts b/openrewrite/test/javascript/parser/function.test.ts index c7c29032..766731b0 100644 --- a/openrewrite/test/javascript/parser/function.test.ts +++ b/openrewrite/test/javascript/parser/function.test.ts @@ -442,4 +442,15 @@ describe('function mapping', () => { `) ); }); + + test('function invocation', () => { + rewriteRun( + //language=typescript + typeScript(` + !function(e, t) { + console.log("This is an IIFE", e, t); + }("Hello", "World"); + `) + ); + }); }); diff --git a/openrewrite/test/javascript/parser/void.test.ts b/openrewrite/test/javascript/parser/void.test.ts index 1c29697c..023be5d0 100644 --- a/openrewrite/test/javascript/parser/void.test.ts +++ b/openrewrite/test/javascript/parser/void.test.ts @@ -11,8 +11,8 @@ describe('void operator mapping', () => { //language=typescript typeScript('void 1', cu => { const statement = cu.statements[0] as JS.ExpressionStatement; - expect(statement.expression).toBeInstanceOf(JS.Void); - const type = (statement.expression as JS.Void).type as JavaType.Primitive; + expect(statement).toBeInstanceOf(JS.Void); + const type = statement.type as JavaType.Primitive; expect(type.kind).toBe(JavaType.PrimitiveKind.Void); }) ); diff --git a/rewrite-javascript/src/main/java/org/openrewrite/javascript/JavaScriptVisitor.java b/rewrite-javascript/src/main/java/org/openrewrite/javascript/JavaScriptVisitor.java index 0df79b64..96c67fef 100644 --- a/rewrite-javascript/src/main/java/org/openrewrite/javascript/JavaScriptVisitor.java +++ b/rewrite-javascript/src/main/java/org/openrewrite/javascript/JavaScriptVisitor.java @@ -214,9 +214,6 @@ public J visitExpressionStatement(JS.ExpressionStatement statement, P p) { es = (JS.ExpressionStatement) temp; } J expression = visit(es.getExpression(), p); - if (expression instanceof Statement) { - return expression; - } es = es.withExpression((Expression) Objects.requireNonNull(expression)); return es; } @@ -528,9 +525,6 @@ public J visitStatementExpression(JS.StatementExpression expression, P p) { se = (JS.StatementExpression) temp; } J statement = visit(se.getStatement(), p); - if (statement instanceof Expression) { - return statement; - } se = se.withStatement((Statement) Objects.requireNonNull(statement)); return se; }