From 8ac6d0b02773d9f626c043886e3e52b94e33fd5b Mon Sep 17 00:00:00 2001 From: Rahul Godbole <119306799+rahul-privado@users.noreply.github.com> Date: Wed, 7 Jun 2023 18:54:01 +0530 Subject: [PATCH] rubysrc2Cpg: Operator and method corrections (#2825) * Operator name and UT correction * UT for index access * Unit test * Field corrections for method node * Test correction --- .../rubysrc2cpg/astcreation/AstCreator.scala | 26 +++-- .../ast/SimpleAstCreationPassTest.scala | 99 ++++++++++++++----- .../querying/IdentifierTests.scala | 2 +- 3 files changed, 97 insertions(+), 30 deletions(-) diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala index 0729a7144f8..053452c77ec 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala @@ -1028,9 +1028,9 @@ class AstCreator(filename: String, global: Global) val lhsExpressionAst = astForPrimaryContext(ctx.primary()) val rhsExpressionAst = astForIndexingArgumentsContext(ctx.indexingArguments()) val callNode = NewCall() - .name(ctx.LBRACK().getText + ctx.RBRACK().getText) + .name(Operators.indexAccess) .code(ctx.getText) - .methodFullName(MethodFullNames.OperatorPrefix + ctx.LBRACK().getText + ctx.RBRACK().getText) + .methodFullName(Operators.indexAccess) .signature("") .dispatchType(DispatchTypes.STATIC_DISPATCH) .typeFullName(Defines.Any) @@ -1250,13 +1250,19 @@ class AstCreator(filename: String, global: Global) def astForOperatorMethodNameContext(ctx: OperatorMethodNameContext): Seq[Ast] = { + /* + * This is for operator overloading for the class + */ val terminalNode = ctx.children.asScala.head .asInstanceOf[TerminalNode] + val name = ctx.getText + val methodFullName = s"$filename:$name" + val callNode = NewCall() - .name(ctx.getText) + .name(name) .code(ctx.getText) - .methodFullName(MethodFullNames.OperatorPrefix + ctx.getText) + .methodFullName(methodFullName) .signature("") .dispatchType(DispatchTypes.STATIC_DISPATCH) .typeFullName(Defines.Any) @@ -1280,7 +1286,7 @@ class AstCreator(filename: String, global: Global) val callNode = NewCall() .name(terminalNode.getText) .code(ctx.getText) - .methodFullName(MethodFullNames.OperatorPrefix + terminalNode.getText) + .methodFullName(terminalNode.getText) .signature("") .dispatchType(DispatchTypes.STATIC_DISPATCH) .typeFullName(Defines.Any) @@ -1404,15 +1410,21 @@ class AstCreator(filename: String, global: Global) val astBody = astForBodyStatementContext(ctx.bodyStatement()) popScope() - // TODO why is there a `callNode` here? + /* + * The method astForMethodNamePartContext() returns a call node in the AST. + * This is because it has been called from several places some of which need a call node. + * We will use fields from the call node to construct the method node. Post that, + * we will discard the call node since it is of no further use to us + */ val classPath = classStack.toList.mkString(".") + "." val methodNode = NewMethod() - .code(callNode.code) + .code(ctx.getText) .name(callNode.name) .fullName(s"$filename:${callNode.name}") .columnNumber(callNode.columnNumber) .lineNumber(callNode.lineNumber) + .lineNumberEnd(ctx.END().getSymbol.getLine) .filename(filename) callNode.methodFullName(classPath + callNode.name) diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/SimpleAstCreationPassTest.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/SimpleAstCreationPassTest.scala index efb3573dfb2..837fb266cc2 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/SimpleAstCreationPassTest.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/SimpleAstCreationPassTest.scala @@ -257,7 +257,7 @@ class SimpleAstCreationPassTest extends RubyCode2CpgFixture { literal.columnNumber shouldBe Some(0) } - "have correct code for a single left had side call" in { + "have correct structure for a single left had side call" in { val cpg = code("array[n] = 10") val List(callNode) = cpg.call.name(Operators.indexAccess).l callNode.code shouldBe "array[n]" @@ -265,7 +265,7 @@ class SimpleAstCreationPassTest extends RubyCode2CpgFixture { callNode.columnNumber shouldBe Some(5) } - "have correct code for a binary expression" in { + "have correct structure for a binary expression" in { val cpg = code("x+y") val List(callNode) = cpg.call.name(Operators.addition).l callNode.code shouldBe "x+y" @@ -273,7 +273,7 @@ class SimpleAstCreationPassTest extends RubyCode2CpgFixture { callNode.columnNumber shouldBe Some(1) } - "have correct code for a not expression" in { + "have correct structure for a not expression" in { val cpg = code("not y") val List(callNode) = cpg.call.name(Operators.not).l callNode.code shouldBe "not y" @@ -281,7 +281,7 @@ class SimpleAstCreationPassTest extends RubyCode2CpgFixture { callNode.columnNumber shouldBe Some(0) } - "have correct code for a power expression" in { + "have correct structure for a power expression" in { val cpg = code("x**y") val List(callNode) = cpg.call.name(Operators.exponentiation).l callNode.code shouldBe "x**y" @@ -289,7 +289,7 @@ class SimpleAstCreationPassTest extends RubyCode2CpgFixture { callNode.columnNumber shouldBe Some(1) } - "have correct code for a inclusive range expression" in { + "have correct structure for a inclusive range expression" in { val cpg = code("1..10") val List(callNode) = cpg.call.name(Operators.range).l callNode.code shouldBe "1..10" @@ -297,7 +297,7 @@ class SimpleAstCreationPassTest extends RubyCode2CpgFixture { callNode.columnNumber shouldBe Some(1) } - "have correct code for a non-inclusive range expression" in { + "have correct structure for a non-inclusive range expression" in { val cpg = code("1...10") val List(callNode) = cpg.call.name(Operators.range).l callNode.code shouldBe "1...10" @@ -305,7 +305,7 @@ class SimpleAstCreationPassTest extends RubyCode2CpgFixture { callNode.columnNumber shouldBe Some(1) } - "have correct code for a relational expression" in { + "have correct structure for a relational expression" in { val cpg = code("x> y") val List(callNode) = cpg.call.name(Operators.logicalShiftRight).l callNode.code shouldBe "x >> y" @@ -417,7 +417,7 @@ class SimpleAstCreationPassTest extends RubyCode2CpgFixture { callNode.columnNumber shouldBe Some(2) } - "have correct code for a shift left expression" in { + "have correct structure for a shift left expression" in { val cpg = code("x << y") val List(callNode) = cpg.call.name(Operators.shiftLeft).l callNode.code shouldBe "x << y" @@ -425,12 +425,67 @@ class SimpleAstCreationPassTest extends RubyCode2CpgFixture { callNode.columnNumber shouldBe Some(2) } - "have correct code for a compare expression" in { + "have correct structure for a compare expression" in { val cpg = code("x <=> y") val List(callNode) = cpg.call.name(Operators.compare).l callNode.code shouldBe "x <=> y" callNode.lineNumber shouldBe Some(1) callNode.columnNumber shouldBe Some(2) } + + "have correct structure for a indexing expression" in { + val cpg = code("def some_method(index)\n some_map[index]\nend") + val List(callNode) = cpg.call.name(Operators.indexAccess).l + callNode.code shouldBe "some_map[index]" + callNode.lineNumber shouldBe Some(2) + callNode.columnNumber shouldBe Some(9) + } + + "have correct structure for overloaded index operator method" in { + val cpg = code(""" + |class MyClass + |def [](key) + | @member_hash[key] + |end + |end + |""".stripMargin) + + val List(methodNode) = cpg.method.name("\\[]").l + methodNode.code shouldBe "def [](key)\n @member_hash[key]\nend" + methodNode.lineNumber shouldBe Some(3) + methodNode.lineNumberEnd shouldBe Some(5) + methodNode.columnNumber shouldBe Some(4) + } + + "have correct structure for overloaded equality operator method" in { + val cpg = code(""" + |class MyClass + |def ==(other) + | @my_member==other + |end + |end + |""".stripMargin) + + val List(methodNode) = cpg.method.name("==").l + methodNode.code shouldBe "def ==(other)\n @my_member==other\nend" + methodNode.lineNumber shouldBe Some(3) + methodNode.lineNumberEnd shouldBe Some(5) + methodNode.columnNumber shouldBe Some(4) + } + + "have correct structure for class method" in { + val cpg = code(""" + |class MyClass + |def some_method(param) + |end + |end + |""".stripMargin) + + val List(methodNode) = cpg.method.name("some_method").l + methodNode.code shouldBe "def some_method(param)\nend" + methodNode.lineNumber shouldBe Some(3) + methodNode.lineNumberEnd shouldBe Some(4) + methodNode.columnNumber shouldBe Some(4) + } } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/IdentifierTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/IdentifierTests.scala index f9378de553c..a9af291270a 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/IdentifierTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/IdentifierTests.scala @@ -300,7 +300,7 @@ class IdentifierTests extends RubyCode2CpgFixture { |""".stripMargin) "recognise all identifier and call nodes" in { - cpg.method.name("\\[]").size shouldBe 2 + cpg.method.name("\\[]").size shouldBe 1 cpg.method.name("\\[]=").size shouldBe 1 cpg.call.name(Operators.assignment).size shouldBe 3 cpg.method.name("initialize").size shouldBe 1