Skip to content

Commit

Permalink
rubysrc2Cpg: Operator name correction (#2819)
Browse files Browse the repository at this point in the history
* Operator name correction

* Operator corrections

* Bugfixes

* Moved tests to the test group

* Used Operators.assignment

* Ruby specific operators

* More unit tests

* More unit tests

* Test case for compare (spaceship operator)
  • Loading branch information
rahul-privado committed Jun 6, 2023
1 parent 760168d commit 46fa845
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,45 @@ class AstCreator(filename: String, global: Global)
diffGraph
}

object RubyOperators {
val none = "<operator>.none"
val patternMatch = "<operator>.patternMatch"
val notPatternMatch = "<operator>.notPatternMatch"
}
private def getOperatorName(token: Token): String = token.getType match {
case AMP => Operators.logicalAnd
case AMP2 => Operators.and
case ASSIGNMENT_OPERATOR => Operators.assignment
case BAR => Operators.logicalOr
case BAR2 => Operators.or
case CARET => Operators.logicalOr
case DOT2 => Operators.range
case DOT3 => Operators.range
case EMARK => Operators.not
case EMARKEQ => Operators.notEquals
case EMARKTILDE => RubyOperators.notPatternMatch
case EQ => Operators.assignment
case EQ2 => Operators.equals
case EQ3 => Operators.is
case EQTILDE => RubyOperators.patternMatch
case GT => Operators.greaterThan
case GT2 => Operators.logicalShiftRight
case GTEQ => Operators.greaterEqualsThan
case LT => Operators.lessThan
case LT2 => Operators.shiftLeft
case LTEQ => Operators.lessEqualsThan
case LTEQGT => Operators.compare
case MINUS => Operators.subtraction
case PERCENT => Operators.modulo
case PLUS => Operators.addition
case SLASH => Operators.division
case STAR => Operators.multiplication
case TILDE => Operators.not
case NOT => Operators.not
case STAR2 => Operators.exponentiation
case _ => RubyOperators.none
}

protected def line(ctx: ParserRuleContext): Option[Integer] = Option(ctx.getStart.getLine)
protected def column(ctx: ParserRuleContext): Option[Integer] = Option(ctx.getStart.getCharPositionInLine)
protected def lineEnd(ctx: ParserRuleContext): Option[Integer] = Option(ctx.getStop.getLine)
Expand Down Expand Up @@ -240,12 +279,13 @@ class AstCreator(filename: String, global: Global)
}

def astForSingleAssignmentExpressionContext(ctx: SingleAssignmentExpressionContext): Seq[Ast] = {
val rightAst = astForMultipleRightHandSideContext(ctx.multipleRightHandSide())
val leftAst = astForSingleLeftHandSideContext(ctx.singleLeftHandSide())
val rightAst = astForMultipleRightHandSideContext(ctx.multipleRightHandSide())
val leftAst = astForSingleLeftHandSideContext(ctx.singleLeftHandSide())
val operatorName = getOperatorName(ctx.op)
val callNode = NewCall()
.name(ctx.op.getText)
.name(operatorName)
.code(ctx.op.getText)
.methodFullName(MethodFullNames.OperatorPrefix + ctx.op.getText)
.methodFullName(operatorName)
.signature("")
.dispatchType(DispatchTypes.STATIC_DISPATCH)
.typeFullName(Defines.Any)
Expand Down Expand Up @@ -1419,12 +1459,13 @@ class AstCreator(filename: String, global: Global)
}

def astForMultipleAssignmentExpressionContext(ctx: MultipleAssignmentExpressionContext): Seq[Ast] = {
val lhsAsts = astForMultipleLeftHandSideContext(ctx.multipleLeftHandSide())
val rhsAsts = astForMultipleRightHandSideContext(ctx.multipleRightHandSide())
val lhsAsts = astForMultipleLeftHandSideContext(ctx.multipleLeftHandSide())
val rhsAsts = astForMultipleRightHandSideContext(ctx.multipleRightHandSide())
val operatorName = getOperatorName(ctx.EQ().getSymbol)
val callNode = NewCall()
.name(ctx.EQ().getText)
.name(operatorName)
.code(ctx.getText)
.methodFullName(MethodFullNames.OperatorPrefix + ctx.EQ().getText)
.methodFullName(operatorName)
.signature("")
.dispatchType(DispatchTypes.STATIC_DISPATCH)
.typeFullName(Defines.Any)
Expand All @@ -1438,11 +1479,12 @@ class AstCreator(filename: String, global: Global)
}

def astForNotExpressionOrCommandContext(ctx: NotExpressionOrCommandContext): Seq[Ast] = {
val expAsts = astForExpressionOrCommandContext(ctx.expressionOrCommand())
val expAsts = astForExpressionOrCommandContext(ctx.expressionOrCommand())
val operatorName = getOperatorName(ctx.NOT().getSymbol)
val callNode = NewCall()
.name(ctx.NOT().getText)
.name(operatorName)
.code(ctx.getText)
.methodFullName(MethodFullNames.OperatorPrefix + ctx.NOT().getText)
.methodFullName(operatorName)
.signature("")
.dispatchType(DispatchTypes.STATIC_DISPATCH)
.typeFullName(Defines.Any)
Expand All @@ -1460,12 +1502,13 @@ class AstCreator(filename: String, global: Global)
}

def astForOrAndExpressionOrCommandContext(ctx: OrAndExpressionOrCommandContext): Seq[Ast] = {
val lhsAsts = astForExpressionOrCommandContext(ctx.expressionOrCommand().get(0))
val rhsAsts = astForExpressionOrCommandContext(ctx.expressionOrCommand().get(1))
val lhsAsts = astForExpressionOrCommandContext(ctx.expressionOrCommand().get(0))
val rhsAsts = astForExpressionOrCommandContext(ctx.expressionOrCommand().get(1))
val operatorName = getOperatorName(ctx.op)
val callNode = NewCall()
.name(ctx.op.getText)
.name(operatorName)
.code(ctx.getText)
.methodFullName(MethodFullNames.OperatorPrefix + ctx.op.getText)
.methodFullName(operatorName)
.signature("")
.dispatchType(DispatchTypes.STATIC_DISPATCH)
.typeFullName(Defines.Any)
Expand All @@ -1478,10 +1521,11 @@ class AstCreator(filename: String, global: Global)
val expressions = ctx.expression()
val baseExpressionAsts = astForExpressionContext(expressions.get(0))
val exponentExpressionAsts = astForExpressionContext(expressions.get(1))
val operatorName = getOperatorName(ctx.STAR2().getSymbol)
val callNode = NewCall()
.name(ctx.STAR2().getText)
.name(operatorName)
.code(ctx.getText)
.methodFullName(MethodFullNames.OperatorPrefix + ctx.STAR2().getText)
.methodFullName(operatorName)
.signature("")
.dispatchType(DispatchTypes.STATIC_DISPATCH)
.typeFullName(Defines.Any)
Expand Down Expand Up @@ -1510,10 +1554,11 @@ class AstCreator(filename: String, global: Global)
): Seq[Ast] = {
val lhsExpressionAsts = astForExpressionContext(lhs)
val rhsExpressionAsts = astForExpressionContext(rhs)
val operatorName = getOperatorName(operatorToken)
val callNode = NewCall()
.name(operatorToken.getText)
.name(operatorName)
.code(code)
.methodFullName(MethodFullNames.OperatorPrefix + operatorToken.getText)
.methodFullName(operatorName)
.signature("")
.dispatchType(DispatchTypes.STATIC_DISPATCH)
.typeFullName(Defines.Any)
Expand Down Expand Up @@ -1659,10 +1704,11 @@ class AstCreator(filename: String, global: Global)
* This is incorrectly identified as a unary expression since the parser identifies the LHS as methodIdentifier
* PLUS is to be interpreted as a binary operator
*/
val operatorName = getOperatorName(ctx.op)
val callNode = NewCall()
.name(ctx.op.getText)
.name(operatorName)
.code(ctx.getText)
.methodFullName(MethodFullNames.OperatorPrefix + ctx.op.getText)
.methodFullName(operatorName)
.signature("")
.dispatchType(DispatchTypes.STATIC_DISPATCH)
.typeFullName(Defines.Any)
Expand All @@ -1671,10 +1717,19 @@ class AstCreator(filename: String, global: Global)
val lhsAst = methodNameAsIdentiferQ.dequeue()
Seq(callAst(callNode, Seq(lhsAst) ++ expressionAst))
} else {
val operatorName =
if (
ctx.op.getType == TILDE ||
ctx.op.getType == EMARK
) {
getOperatorName(ctx.op)
} else {
Operators.plus
}
val callNode = NewCall()
.name(ctx.op.getText)
.name(operatorName)
.code(ctx.getText)
.methodFullName(MethodFullNames.OperatorPrefix + ctx.op.getText)
.methodFullName(operatorName)
.signature("")
.dispatchType(DispatchTypes.STATIC_DISPATCH)
.typeFullName(Defines.Any)
Expand All @@ -1691,10 +1746,11 @@ class AstCreator(filename: String, global: Global)
* This is incorrectly identified as a unary expression since the parser identifies the LHS as methodIdentifier
* PLUS is to be interpreted as a binary operator
*/
val operatorName = Operators.subtraction
val callNode = NewCall()
.name(ctx.MINUS().getText)
.name(operatorName)
.code(ctx.getText)
.methodFullName(MethodFullNames.OperatorPrefix + ctx.MINUS().getText)
.methodFullName(operatorName)
.signature("")
.dispatchType(DispatchTypes.STATIC_DISPATCH)
.typeFullName(Defines.Any)
Expand All @@ -1703,10 +1759,11 @@ class AstCreator(filename: String, global: Global)
val lhsAst = methodNameAsIdentiferQ.dequeue()
Seq(callAst(callNode, Seq(lhsAst) ++ expressionAst))
} else {
val operatorName = Operators.minus
val callNode = NewCall()
.name(ctx.MINUS().getText)
.name(operatorName)
.code(ctx.getText)
.methodFullName(MethodFullNames.OperatorPrefix + ctx.MINUS().getText)
.methodFullName(operatorName)
.signature("")
.dispatchType(DispatchTypes.STATIC_DISPATCH)
.typeFullName(Defines.Any)
Expand Down
Loading

0 comments on commit 46fa845

Please sign in to comment.