Skip to content

Commit

Permalink
rubysrc2Cpg: Call node code field set (#2812)
Browse files Browse the repository at this point in the history
* Code set for some call nodes

* Code for binary expressions

* Formatting fix

* Used code

* UT for single LHS

* Binary expression code UT

* More unit tests

* More unit tests

* Name correction
  • Loading branch information
rahul-privado committed Jun 6, 2023
1 parent 472e594 commit 421c19a
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ class AstCreator(filename: String, global: Global)
val argsAsts = astForArgumentsContext(ctx.arguments())
val callNode = NewCall()
.name(Operators.indexAccess)
.code(Operators.indexAccess)
.code(ctx.getText)
.methodFullName(Operators.indexAccess)
.signature("")
.dispatchType(DispatchTypes.STATIC_DISPATCH)
Expand Down Expand Up @@ -531,7 +531,7 @@ class AstCreator(filename: String, global: Global)
}

def astForAdditiveExpressionContext(ctx: AdditiveExpressionContext): Seq[Ast] = {
astForBinaryExpression(ctx.expression(0), ctx.expression(1), ctx.op)
astForBinaryExpression(ctx.expression(0), ctx.expression(1), ctx.op, ctx.getText)
}

def astForIndexingArgumentsContext(ctx: IndexingArgumentsContext): Seq[Ast] = ctx match {
Expand Down Expand Up @@ -593,15 +593,15 @@ class AstCreator(filename: String, global: Global)
}

def astForBitwiseAndExpressionContext(ctx: BitwiseAndExpressionContext): Seq[Ast] = {
astForBinaryExpression(ctx.expression(0), ctx.expression(1), ctx.op)
astForBinaryExpression(ctx.expression(0), ctx.expression(1), ctx.op, ctx.getText)
}

def astForBitwiseOrExpressionContext(ctx: BitwiseOrExpressionContext): Seq[Ast] = {
astForBinaryExpression(ctx.expression(0), ctx.expression(1), ctx.op)
astForBinaryExpression(ctx.expression(0), ctx.expression(1), ctx.op, ctx.getText)
}

def astForBitwiseShiftExpressionContext(ctx: BitwiseShiftExpressionContext): Seq[Ast] = {
astForBinaryExpression(ctx.expression(0), ctx.expression(1), ctx.op)
astForBinaryExpression(ctx.expression(0), ctx.expression(1), ctx.op, ctx.getText)
}

def astForWhenArgumentContext(ctx: WhenArgumentContext): Seq[Ast] = {
Expand Down Expand Up @@ -840,7 +840,7 @@ class AstCreator(filename: String, global: Global)
}

def astForEqualityExpressionContext(ctx: EqualityExpressionContext): Seq[Ast] = {
astForBinaryExpression(ctx.expression(0), ctx.expression(1), ctx.op)
astForBinaryExpression(ctx.expression(0), ctx.expression(1), ctx.op, ctx.getText)
}

def astForGroupedLeftHandSideContext(ctx: GroupedLeftHandSideContext): Seq[Ast] = {
Expand Down Expand Up @@ -1186,6 +1186,7 @@ class AstCreator(filename: String, global: Global)
val line = localIdentifier.getSymbol().getLine()
val name = getActualMethodName(localIdentifier.getText)
val methodFullName = s"$filename:$name"

val callNode = NewCall()
.name(name)
.methodFullName(methodFullName)
Expand All @@ -1195,6 +1196,7 @@ class AstCreator(filename: String, global: Global)
.code(code)
.lineNumber(line)
.columnNumber(column)
.code(code)
Seq(callAst(callNode))
}

Expand Down Expand Up @@ -1462,14 +1464,14 @@ class AstCreator(filename: String, global: Global)
}

def astForMultiplicativeExpressionContext(ctx: MultiplicativeExpressionContext): Seq[Ast] = {
astForBinaryExpression(ctx.expression(0), ctx.expression(1), ctx.op)
astForBinaryExpression(ctx.expression(0), ctx.expression(1), ctx.op, ctx.getText)
}

def astForNotExpressionOrCommandContext(ctx: NotExpressionOrCommandContext): Seq[Ast] = {
val expAsts = astForExpressionOrCommandContext(ctx.expressionOrCommand())
val callNode = NewCall()
.name(ctx.NOT().getText)
.code(ctx.NOT().getText)
.code(ctx.getText)
.methodFullName(MethodFullNames.OperatorPrefix + ctx.NOT().getText)
.signature("")
.dispatchType(DispatchTypes.STATIC_DISPATCH)
Expand All @@ -1480,11 +1482,11 @@ class AstCreator(filename: String, global: Global)
}

def astForOperatorAndExpressionContext(ctx: OperatorAndExpressionContext): Seq[Ast] = {
astForBinaryExpression(ctx.expression(0), ctx.expression(1), ctx.op)
astForBinaryExpression(ctx.expression(0), ctx.expression(1), ctx.op, ctx.getText)
}

def astForOperatorOrExpressionContext(ctx: OperatorOrExpressionContext): Seq[Ast] = {
astForBinaryExpression(ctx.expression(0), ctx.expression(1), ctx.op)
astForBinaryExpression(ctx.expression(0), ctx.expression(1), ctx.op, ctx.getText)
}

def astForOrAndExpressionOrCommandContext(ctx: OrAndExpressionOrCommandContext): Seq[Ast] = {
Expand All @@ -1508,7 +1510,7 @@ class AstCreator(filename: String, global: Global)
val exponentExpressionAsts = astForExpressionContext(expressions.get(1))
val callNode = NewCall()
.name(ctx.STAR2().getText)
.code(ctx.STAR2().getText)
.code(ctx.getText)
.methodFullName(MethodFullNames.OperatorPrefix + ctx.STAR2().getText)
.signature("")
.dispatchType(DispatchTypes.STATIC_DISPATCH)
Expand All @@ -1520,22 +1522,27 @@ class AstCreator(filename: String, global: Global)

def astForRangeExpressionContext(ctx: RangeExpressionContext): Seq[Ast] = {
if (ctx.expression().size() == 2) {
astForBinaryExpression(ctx.expression(0), ctx.expression(1), ctx.op)
astForBinaryExpression(ctx.expression(0), ctx.expression(1), ctx.op, ctx.getText)
} else {
Seq(Ast())
}
}

def astForRelationalExpressionContext(ctx: RelationalExpressionContext): Seq[Ast] = {
astForBinaryExpression(ctx.expression(0), ctx.expression(1), ctx.op)
astForBinaryExpression(ctx.expression(0), ctx.expression(1), ctx.op, ctx.getText)
}

def astForBinaryExpression(lhs: ExpressionContext, rhs: ExpressionContext, operatorToken: Token): Seq[Ast] = {
def astForBinaryExpression(
lhs: ExpressionContext,
rhs: ExpressionContext,
operatorToken: Token,
code: String
): Seq[Ast] = {
val lhsExpressionAsts = astForExpressionContext(lhs)
val rhsExpressionAsts = astForExpressionContext(rhs)
val callNode = NewCall()
.name(operatorToken.getText)
.code(operatorToken.getText)
.code(code)
.methodFullName(MethodFullNames.OperatorPrefix + operatorToken.getText)
.signature("")
.dispatchType(DispatchTypes.STATIC_DISPATCH)
Expand Down Expand Up @@ -1684,7 +1691,7 @@ class AstCreator(filename: String, global: Global)
*/
val callNode = NewCall()
.name(ctx.op.getText)
.code(ctx.op.getText)
.code(ctx.getText)
.methodFullName(MethodFullNames.OperatorPrefix + ctx.op.getText)
.signature("")
.dispatchType(DispatchTypes.STATIC_DISPATCH)
Expand All @@ -1696,7 +1703,7 @@ class AstCreator(filename: String, global: Global)
} else {
val callNode = NewCall()
.name(ctx.op.getText)
.code(ctx.op.getText)
.code(ctx.getText)
.methodFullName(MethodFullNames.OperatorPrefix + ctx.op.getText)
.signature("")
.dispatchType(DispatchTypes.STATIC_DISPATCH)
Expand All @@ -1716,7 +1723,7 @@ class AstCreator(filename: String, global: Global)
*/
val callNode = NewCall()
.name(ctx.MINUS().getText)
.code(ctx.MINUS().getText)
.code(ctx.getText)
.methodFullName(MethodFullNames.OperatorPrefix + ctx.MINUS().getText)
.signature("")
.dispatchType(DispatchTypes.STATIC_DISPATCH)
Expand All @@ -1728,7 +1735,7 @@ class AstCreator(filename: String, global: Global)
} else {
val callNode = NewCall()
.name(ctx.MINUS().getText)
.code(ctx.MINUS().getText)
.code(ctx.getText)
.methodFullName(MethodFullNames.OperatorPrefix + ctx.MINUS().getText)
.signature("")
.dispatchType(DispatchTypes.STATIC_DISPATCH)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,103 @@ class SimpleAstCreationPassTest extends RubyCode2CpgFixture {
encoding.columnNumber shouldBe Some(5)
}
}

"Code field for simple fragments" should {

"have correct code for a single left had side call" in {
val cpg = code("array[n] = 10")
val List(callNode) = cpg.call.name("<operator>.indexAccess").l
callNode.code shouldBe "array[n]"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(5)
}

"have correct code for a binary expression" in {
val cpg = code("x+y")
val List(callNode) = cpg.call.l
callNode.code shouldBe "x+y"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(1)
}

"have correct code for a not expression" in {
val cpg = code("not y")
val List(callNode) = cpg.call.l
callNode.code shouldBe "not y"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(0)
}

"have correct code for a power expression" in {
val cpg = code("x**y")
val List(callNode) = cpg.call.l
callNode.code shouldBe "x**y"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(1)
}

"have correct code for a inclusive range expression" in {
val cpg = code("1..10")
val List(callNode) = cpg.call.l
callNode.code shouldBe "1..10"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(1)
}

"have correct code for a non-inclusive range expression" in {
val cpg = code("1...10")
val List(callNode) = cpg.call.l
callNode.code shouldBe "1...10"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(1)
}

"have correct code for a relational expression" in {
val cpg = code("x<y")
val List(callNode) = cpg.call.l
callNode.code shouldBe "x<y"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(1)
}

"have correct code for a unary exclamation expression" in {
val cpg = code("!y")
val List(callNode) = cpg.call.l
callNode.code shouldBe "!y"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(0)
}

"have correct code for a unary tilde expression" in {
val cpg = code("~y")
val List(callNode) = cpg.call.l
callNode.code shouldBe "~y"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(0)
}

"have correct code for a unary plus expression" in {
val cpg = code("+y")
val List(callNode) = cpg.call.l
callNode.code shouldBe "+y"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(0)
}

"have correct code for a unary minus expression" in {
val cpg = code("-y")
val List(callNode) = cpg.call.l
callNode.code shouldBe "-y"
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(0)
}

"have correct code for a call node" in {
val cpg = code("puts \"something\"")
val List(callNode) = cpg.call.l
callNode.code shouldBe "puts \"something\""
callNode.lineNumber shouldBe Some(1)
callNode.columnNumber shouldBe Some(0)
}
}
}

0 comments on commit 421c19a

Please sign in to comment.