Skip to content

Commit

Permalink
rubysrc2Cpg: Operator and method corrections (#2825)
Browse files Browse the repository at this point in the history
* Operator name and UT correction

* UT for index access

* Unit test

* Field corrections for method node

* Test correction
  • Loading branch information
rahul-privado committed Jun 7, 2023
1 parent c81a77d commit 8ac6d0b
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1028,9 +1028,9 @@ class AstCreator(filename: String, global: Global)
val lhsExpressionAst = astForPrimaryContext(ctx.primary())
val rhsExpressionAst = astForIndexingArgumentsContext(ctx.indexingArguments())
val callNode = NewCall()
.name(ctx.LBRACK().getText + ctx.RBRACK().getText)
.name(Operators.indexAccess)
.code(ctx.getText)
.methodFullName(MethodFullNames.OperatorPrefix + ctx.LBRACK().getText + ctx.RBRACK().getText)
.methodFullName(Operators.indexAccess)
.signature("")
.dispatchType(DispatchTypes.STATIC_DISPATCH)
.typeFullName(Defines.Any)
Expand Down Expand Up @@ -1250,13 +1250,19 @@ class AstCreator(filename: String, global: Global)

def astForOperatorMethodNameContext(ctx: OperatorMethodNameContext): Seq[Ast] = {

/*
* This is for operator overloading for the class
*/
val terminalNode = ctx.children.asScala.head
.asInstanceOf[TerminalNode]

val name = ctx.getText
val methodFullName = s"$filename:$name"

val callNode = NewCall()
.name(ctx.getText)
.name(name)
.code(ctx.getText)
.methodFullName(MethodFullNames.OperatorPrefix + ctx.getText)
.methodFullName(methodFullName)
.signature("")
.dispatchType(DispatchTypes.STATIC_DISPATCH)
.typeFullName(Defines.Any)
Expand All @@ -1280,7 +1286,7 @@ class AstCreator(filename: String, global: Global)
val callNode = NewCall()
.name(terminalNode.getText)
.code(ctx.getText)
.methodFullName(MethodFullNames.OperatorPrefix + terminalNode.getText)
.methodFullName(terminalNode.getText)
.signature("")
.dispatchType(DispatchTypes.STATIC_DISPATCH)
.typeFullName(Defines.Any)
Expand Down Expand Up @@ -1404,15 +1410,21 @@ class AstCreator(filename: String, global: Global)
val astBody = astForBodyStatementContext(ctx.bodyStatement())
popScope()

// TODO why is there a `callNode` here?
/*
* The method astForMethodNamePartContext() returns a call node in the AST.
* This is because it has been called from several places some of which need a call node.
* We will use fields from the call node to construct the method node. Post that,
* we will discard the call node since it is of no further use to us
*/

val classPath = classStack.toList.mkString(".") + "."
val methodNode = NewMethod()
.code(callNode.code)
.code(ctx.getText)
.name(callNode.name)
.fullName(s"$filename:${callNode.name}")
.columnNumber(callNode.columnNumber)
.lineNumber(callNode.lineNumber)
.lineNumberEnd(ctx.END().getSymbol.getLine)
.filename(filename)
callNode.methodFullName(classPath + callNode.name)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,180 +257,235 @@ class SimpleAstCreationPassTest extends RubyCode2CpgFixture {
literal.columnNumber shouldBe Some(0)
}

"have correct code for a single left had side call" in {
"have correct structure for a single left had side call" in {
val cpg = code("array[n] = 10")
val List(callNode) = cpg.call.name(Operators.indexAccess).l
callNode.code shouldBe "array[n]"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(5)
}

"have correct code for a binary expression" in {
"have correct structure for a binary expression" in {
val cpg = code("x+y")
val List(callNode) = cpg.call.name(Operators.addition).l
callNode.code shouldBe "x+y"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(1)
}

"have correct code for a not expression" in {
"have correct structure for a not expression" in {
val cpg = code("not y")
val List(callNode) = cpg.call.name(Operators.not).l
callNode.code shouldBe "not y"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(0)
}

"have correct code for a power expression" in {
"have correct structure for a power expression" in {
val cpg = code("x**y")
val List(callNode) = cpg.call.name(Operators.exponentiation).l
callNode.code shouldBe "x**y"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(1)
}

"have correct code for a inclusive range expression" in {
"have correct structure for a inclusive range expression" in {
val cpg = code("1..10")
val List(callNode) = cpg.call.name(Operators.range).l
callNode.code shouldBe "1..10"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(1)
}

"have correct code for a non-inclusive range expression" in {
"have correct structure for a non-inclusive range expression" in {
val cpg = code("1...10")
val List(callNode) = cpg.call.name(Operators.range).l
callNode.code shouldBe "1...10"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(1)
}

"have correct code for a relational expression" in {
"have correct structure for a relational expression" in {
val cpg = code("x<y")
val List(callNode) = cpg.call.name(Operators.lessThan).l
callNode.code shouldBe "x<y"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(1)
}

"have correct code for a unary exclamation expression" in {
"have correct structure for a unary exclamation expression" in {
val cpg = code("!y")
val List(callNode) = cpg.call.name(Operators.not).l
callNode.code shouldBe "!y"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(0)
}

"have correct code for a unary tilde expression" in {
"have correct structure for a unary tilde expression" in {
val cpg = code("~y")
val List(callNode) = cpg.call.name(Operators.not).l
callNode.code shouldBe "~y"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(0)
}

"have correct code for a unary plus expression" in {
"have correct structure for a unary plus expression" in {
val cpg = code("+y")
val List(callNode) = cpg.call.name(Operators.plus).l
callNode.code shouldBe "+y"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(0)
}

"have correct code for a unary minus expression" in {
"have correct structure for a unary minus expression" in {
val cpg = code("-y")
val List(callNode) = cpg.call.name(Operators.minus).l
callNode.code shouldBe "-y"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(0)
}

"have correct code for a call node" in {
"have correct structure for a call node" in {
val cpg = code("puts \"something\"")
val List(callNode) = cpg.call.l
callNode.code shouldBe "puts \"something\""
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(0)
}

"have correct code for a logical and expression" in {
"have correct structure for a logical and expression" in {
val cpg = code("x & y")
val List(callNode) = cpg.call.name(Operators.logicalAnd).l
callNode.code shouldBe "x & y"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(2)
}

"have correct code for a logical or with bar expression" in {
"have correct structure for a logical or with bar expression" in {
val cpg = code("x | y")
val List(callNode) = cpg.call.name(Operators.logicalOr).l
callNode.code shouldBe "x | y"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(2)
}

"have correct code for a logical or with carat expression" in {
"have correct structure for a logical or with carat expression" in {
val cpg = code("x ^ y")
val List(callNode) = cpg.call.name(Operators.logicalOr).l
callNode.code shouldBe "x ^ y"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(2)
}

"have correct code for a assignment expression" in {
"have correct structure for a assignment expression" in {
val cpg = code("x = y")
val List(callNode) = cpg.call.name(Operators.assignment).l
callNode.code shouldBe "="
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(2)
}

"have correct code for a equals expression" in {
"have correct structure for a equals expression" in {
val cpg = code("x == y")
val List(callNode) = cpg.call.name(Operators.equals).l
callNode.code shouldBe "x == y"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(2)
}

"have correct code for a division expression" in {
"have correct structure for a division expression" in {
val cpg = code("x / y")
val List(callNode) = cpg.call.name(Operators.division).l
callNode.code shouldBe "x / y"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(2)
}

"have correct code for a modulo expression" in {
"have correct structure for a modulo expression" in {
val cpg = code("x % y")
val List(callNode) = cpg.call.name(Operators.modulo).l
callNode.code shouldBe "x % y"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(2)
}

"have correct code for a shift right expression" in {
"have correct structure for a shift right expression" in {
val cpg = code("x >> y")
val List(callNode) = cpg.call.name(Operators.logicalShiftRight).l
callNode.code shouldBe "x >> y"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(2)
}

"have correct code for a shift left expression" in {
"have correct structure for a shift left expression" in {
val cpg = code("x << y")
val List(callNode) = cpg.call.name(Operators.shiftLeft).l
callNode.code shouldBe "x << y"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(2)
}

"have correct code for a compare expression" in {
"have correct structure for a compare expression" in {
val cpg = code("x <=> y")
val List(callNode) = cpg.call.name(Operators.compare).l
callNode.code shouldBe "x <=> y"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(2)
}

"have correct structure for a indexing expression" in {
val cpg = code("def some_method(index)\n some_map[index]\nend")
val List(callNode) = cpg.call.name(Operators.indexAccess).l
callNode.code shouldBe "some_map[index]"
callNode.lineNumber shouldBe Some(2)
callNode.columnNumber shouldBe Some(9)
}

"have correct structure for overloaded index operator method" in {
val cpg = code("""
|class MyClass
|def [](key)
| @member_hash[key]
|end
|end
|""".stripMargin)

val List(methodNode) = cpg.method.name("\\[]").l
methodNode.code shouldBe "def [](key)\n @member_hash[key]\nend"
methodNode.lineNumber shouldBe Some(3)
methodNode.lineNumberEnd shouldBe Some(5)
methodNode.columnNumber shouldBe Some(4)
}

"have correct structure for overloaded equality operator method" in {
val cpg = code("""
|class MyClass
|def ==(other)
| @my_member==other
|end
|end
|""".stripMargin)

val List(methodNode) = cpg.method.name("==").l
methodNode.code shouldBe "def ==(other)\n @my_member==other\nend"
methodNode.lineNumber shouldBe Some(3)
methodNode.lineNumberEnd shouldBe Some(5)
methodNode.columnNumber shouldBe Some(4)
}

"have correct structure for class method" in {
val cpg = code("""
|class MyClass
|def some_method(param)
|end
|end
|""".stripMargin)

val List(methodNode) = cpg.method.name("some_method").l
methodNode.code shouldBe "def some_method(param)\nend"
methodNode.lineNumber shouldBe Some(3)
methodNode.lineNumberEnd shouldBe Some(4)
methodNode.columnNumber shouldBe Some(4)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ class IdentifierTests extends RubyCode2CpgFixture {
|""".stripMargin)

"recognise all identifier and call nodes" in {
cpg.method.name("\\[]").size shouldBe 2
cpg.method.name("\\[]").size shouldBe 1
cpg.method.name("\\[]=").size shouldBe 1
cpg.call.name(Operators.assignment).size shouldBe 3
cpg.method.name("initialize").size shouldBe 1
Expand Down

0 comments on commit 8ac6d0b

Please sign in to comment.