Skip to content

Commit

Permalink
rubysrc2cpg: fixing call-related .code, .line and .column properties (#…
Browse files Browse the repository at this point in the history
…2813)

* Switching RUBY's node builder to take ParserRuleContexts instead of TerminalNodes

* Fixing call-related .code, .line and .column properties
  • Loading branch information
xavierpinho committed Jun 5, 2023
1 parent 7767514 commit c4aae96
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,13 @@ package io.joern.rubysrc2cpg.astcreation
import io.joern.rubysrc2cpg.parser.RubyParser._
import io.joern.rubysrc2cpg.parser.{RubyLexer, RubyParser}
import io.joern.x2cpg.Ast.storeInDiffGraph
import io.joern.x2cpg.datastructures.Global
import io.joern.x2cpg.Defines.DynamicCallUnknownFullName
import io.joern.x2cpg.datastructures.Global
import io.joern.x2cpg.{Ast, AstCreatorBase, AstNodeBuilder}
import io.shiftleft.codepropertygraph.generated.{
ControlStructureTypes,
DispatchTypes,
ModifierTypes,
NodeTypes,
Operators
}
import io.shiftleft.codepropertygraph.generated.nodes._
import io.shiftleft.codepropertygraph.generated._
import org.antlr.v4.runtime.tree.TerminalNode
import org.antlr.v4.runtime.{CharStreams, CommonTokenStream, Token}
import org.antlr.v4.runtime.{CharStreams, CommonTokenStream, ParserRuleContext, Token}
import org.slf4j.LoggerFactory
import overflowdb.BatchedUpdate

Expand All @@ -25,7 +19,7 @@ import scala.jdk.CollectionConverters._

class AstCreator(filename: String, global: Global)
extends AstCreatorBase(filename)
with AstNodeBuilder[TerminalNode, AstCreator] {
with AstNodeBuilder[ParserRuleContext, AstCreator] {

object Defines {
val Any: String = "ANY"
Expand Down Expand Up @@ -83,13 +77,13 @@ class AstCreator(filename: String, global: Global)
}

private def createIdentiferWithScope(
node: TerminalNode,
ctx: ParserRuleContext,
name: String,
code: String,
typeFullName: String,
dynamicTypeHints: Seq[String]
): NewIdentifier = {
val newNode = identifierNode(node, name, code, typeFullName, dynamicTypeHints)
val newNode = identifierNode(ctx, name, code, typeFullName, dynamicTypeHints)
setIdentiferInScope(newNode)
newNode
}
Expand Down Expand Up @@ -146,10 +140,10 @@ class AstCreator(filename: String, global: Global)
diffGraph
}

protected def line(node: TerminalNode): Option[Integer] = Option(node.getSymbol.getLine)
protected def column(node: TerminalNode): Option[Integer] = Option(node.getSymbol.getCharPositionInLine)
protected def lineEnd(node: TerminalNode): Option[Integer] = None
protected def columnEnd(node: TerminalNode): Option[Integer] = None
protected def line(ctx: ParserRuleContext): Option[Integer] = Option(ctx.getStart.getLine)
protected def column(ctx: ParserRuleContext): Option[Integer] = Option(ctx.getStart.getCharPositionInLine)
protected def lineEnd(ctx: ParserRuleContext): Option[Integer] = Option(ctx.getStop.getLine)
protected def columnEnd(ctx: ParserRuleContext): Option[Integer] = Option(ctx.getStop.getCharPositionInLine)

private def registerType(typ: String): String = {
if (typ != Defines.Any) {
Expand All @@ -161,7 +155,7 @@ class AstCreator(filename: String, global: Global)
val terminalNode = ctx.children.asScala.map(_.asInstanceOf[TerminalNode]).head
val token = terminalNode.getSymbol
val variableName = token.getText
val node = createIdentiferWithScope(terminalNode, variableName, variableName, Defines.Any, List[String]())
val node = createIdentiferWithScope(ctx, variableName, variableName, Defines.Any, List[String]())
setIdentiferInScope(node)
Seq(Ast(node))
}
Expand Down Expand Up @@ -195,7 +189,7 @@ class AstCreator(filename: String, global: Global)
}
val varSymbol = localVar.getSymbol()
val node =
createIdentiferWithScope(localVar, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))
createIdentiferWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))
val yAst = Ast(node)

val callNode = NewCall()
Expand Down Expand Up @@ -747,8 +741,8 @@ class AstCreator(filename: String, global: Global)
val primaryAst = astForPrimaryContext(ctx.primary())
val localVar = ctx.CONSTANT_IDENTIFIER()
val varSymbol = localVar.getSymbol()
val node = createIdentiferWithScope(localVar, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))
val constAst = Ast(node)
val node = createIdentiferWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))
val constAst = Ast(node)

val callNode = NewCall()
.name(ctx.COLON2().getText)
Expand Down Expand Up @@ -1078,13 +1072,13 @@ class AstCreator(filename: String, global: Global)
}

def astForInvocationWithBlockOnlyPrimaryContext(ctx: InvocationWithBlockOnlyPrimaryContext): Seq[Ast] = {
val methodIdAst = astForMethodIdentifierContext(ctx.methodIdentifier())
val methodIdAst = astForMethodIdentifierContext(ctx.methodIdentifier(), ctx.getText)
val blockAst = astForBlockContext(ctx.block())
blockAst ++ methodIdAst
}

def astForInvocationWithParenthesesPrimaryContext(ctx: InvocationWithParenthesesPrimaryContext): Seq[Ast] = {
val methodIdAst = astForMethodIdentifierContext(ctx.methodIdentifier())
val methodIdAst = astForMethodIdentifierContext(ctx.methodIdentifier(), ctx.getText)
val parenAst = astForArgumentsWithParenthesesContext(ctx.argumentsWithParentheses())
val callNode = methodIdAst.head.nodes.filter(_.isInstanceOf[NewCall]).head.asInstanceOf[NewCall]
callNode.name(getActualMethodName(callNode.name))
Expand Down Expand Up @@ -1156,10 +1150,14 @@ class AstCreator(filename: String, global: Global)
}

def astForLiteralPrimaryContext(ctx: LiteralPrimaryContext): Seq[Ast] = {
val lineStart = line(ctx.literal())
val columnStart = column(ctx.literal())
if (ctx.literal().numericLiteral() != null) {
val text = ctx.getText
val node = NewLiteral()
.code(text)
.lineNumber(lineStart)
.columnNumber(columnStart)
.typeFullName(Defines.Number)
.dynamicTypeHintFullName(List(Defines.Number))
registerType(Defines.Number)
Expand All @@ -1168,13 +1166,17 @@ class AstCreator(filename: String, global: Global)
val text = ctx.getText
val node = NewLiteral()
.code(text)
.lineNumber(lineStart)
.columnNumber(columnStart)
.typeFullName(Defines.String)
.dynamicTypeHintFullName(List(Defines.String))
Seq(Ast(node))
} else if (ctx.literal().DOUBLE_QUOTED_STRING_CHARACTER_SEQUENCE() != null) {
val text = ctx.literal().DOUBLE_QUOTED_STRING_CHARACTER_SEQUENCE().getText
val node = NewLiteral()
.code(text)
.lineNumber(lineStart)
.columnNumber(columnStart)
.typeFullName(Defines.String)
.dynamicTypeHintFullName(List(Defines.String))
registerType(Defines.String)
Expand All @@ -1190,7 +1192,7 @@ class AstCreator(filename: String, global: Global)
astForDefinedMethodNameContext(ctx.definedMethodName())
}

def astForCallNode(localIdentifier: TerminalNode): Seq[Ast] = {
def astForCallNode(localIdentifier: TerminalNode, code: String): Seq[Ast] = {
val column = localIdentifier.getSymbol().getCharPositionInLine()
val line = localIdentifier.getSymbol().getLine()
val name = getActualMethodName(localIdentifier.getText)
Expand All @@ -1201,44 +1203,44 @@ class AstCreator(filename: String, global: Global)
.signature(localIdentifier.getText())
.typeFullName(DynamicCallUnknownFullName)
.dispatchType(DispatchTypes.STATIC_DISPATCH)
.code(localIdentifier.getText())
.code(code)
.lineNumber(line)
.columnNumber(column)
Seq(callAst(callNode))
}

def astForMethodOnlyIdentifier(ctx: MethodOnlyIdentifierContext): Seq[Ast] = {
if (ctx.LOCAL_VARIABLE_IDENTIFIER() != null) {
astForCallNode(ctx.LOCAL_VARIABLE_IDENTIFIER())
astForCallNode(ctx.LOCAL_VARIABLE_IDENTIFIER(), ctx.getText)
} else if (ctx.CONSTANT_IDENTIFIER() != null) {
astForCallNode(ctx.CONSTANT_IDENTIFIER())
astForCallNode(ctx.CONSTANT_IDENTIFIER(), ctx.getText)
} else {
Seq(Ast())
}
}

def astForMethodIdentifierContext(ctx: MethodIdentifierContext): Seq[Ast] = {
def astForMethodIdentifierContext(ctx: MethodIdentifierContext, code: String): Seq[Ast] = {
if (ctx.methodOnlyIdentifier() != null) {
astForMethodOnlyIdentifier(ctx.methodOnlyIdentifier())
} else if (ctx.LOCAL_VARIABLE_IDENTIFIER() != null) {
val localVar = ctx.LOCAL_VARIABLE_IDENTIFIER()
val varSymbol = localVar.getSymbol()
if (lookupIdentiferInScope(varSymbol.getText)) {
val node =
createIdentiferWithScope(localVar, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))
createIdentiferWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))
Seq(Ast(node))
} else {
astForCallNode(localVar)
astForCallNode(localVar, code)
}
} else if (ctx.CONSTANT_IDENTIFIER() != null) {
val localVar = ctx.CONSTANT_IDENTIFIER()
val varSymbol = localVar.getSymbol()
if (lookupIdentiferInScope(varSymbol.getText)) {
val node =
createIdentiferWithScope(localVar, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))
createIdentiferWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))
Seq(Ast(node))
} else {
astForCallNode(localVar)
astForCallNode(localVar, code)
}
} else {
Seq(Ast())
Expand All @@ -1264,7 +1266,7 @@ class AstCreator(filename: String, global: Global)

def astForMethodNameContext(ctx: MethodNameContext): Seq[Ast] = {
if (ctx.methodIdentifier() != null) {
astForMethodIdentifierContext(ctx.methodIdentifier())
astForMethodIdentifierContext(ctx.methodIdentifier(), ctx.getText)
} else if (ctx.operatorMethodName() != null) {
astForOperatorMethodNameContext(ctx.operatorMethodName())
} else if (ctx.keyword() != null) {
Expand Down Expand Up @@ -1295,13 +1297,13 @@ class AstCreator(filename: String, global: Global)
val localVar = ctx.LOCAL_VARIABLE_IDENTIFIER()
val varSymbol = localVar.getSymbol()
val node =
createIdentiferWithScope(localVar, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))
createIdentiferWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))
Seq(Ast(node))
} else if (ctx.CONSTANT_IDENTIFIER() != null) {
val localVar = ctx.CONSTANT_IDENTIFIER()
val varSymbol = localVar.getSymbol()
val node =
createIdentiferWithScope(localVar, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))
createIdentiferWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))
Seq(Ast(node))
} else {
Seq(Ast())
Expand Down Expand Up @@ -1376,7 +1378,7 @@ class AstCreator(filename: String, global: Global)
localVarList
.map(localVar => {
val varSymbol = localVar.getSymbol()
createIdentiferWithScope(localVar, varSymbol.getText, varSymbol.getText, Defines.Any, Seq[String](Defines.Any))
createIdentiferWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, Seq[String](Defines.Any))
val param = NewMethodParameterIn()
.name(varSymbol.getText)
.code(varSymbol.getText)
Expand Down Expand Up @@ -1557,7 +1559,7 @@ class AstCreator(filename: String, global: Global)
def astForSimpleScopedConstantReferencePrimaryContext(ctx: SimpleScopedConstantReferencePrimaryContext): Seq[Ast] = {
val localVar = ctx.CONSTANT_IDENTIFIER()
val varSymbol = localVar.getSymbol()
val node = createIdentiferWithScope(localVar, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))
val node = createIdentiferWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))

val callNode = NewCall()
.name(ctx.COLON2().getText)
Expand Down Expand Up @@ -1590,7 +1592,7 @@ class AstCreator(filename: String, global: Global)
case ctx: RubyParser.ArgsAndDoBlockAndMethodIdCommandWithDoBlockContext =>
val argsAsts = astForArgumentsWithoutParenthesesContext(ctx.argumentsWithoutParentheses())
val doBlockAsts = astForDoBlockContext(ctx.doBlock())
val methodIdAsts = astForMethodIdentifierContext(ctx.methodIdentifier())
val methodIdAsts = astForMethodIdentifierContext(ctx.methodIdentifier(), ctx.getText)
methodIdAsts ++ argsAsts ++ doBlockAsts
case ctx: RubyParser.PrimaryMethodArgsDoBlockCommandWithDoBlockContext =>
val argsAsts = astForArgumentsWithoutParenthesesContext(ctx.argumentsWithoutParentheses())
Expand Down Expand Up @@ -1790,7 +1792,7 @@ class AstCreator(filename: String, global: Global)
else return Seq(Ast())
}

val astNode = createIdentiferWithScope(node, ctx.getText, ctx.getText, Defines.Any, List(Defines.Any))
val astNode = createIdentiferWithScope(ctx, ctx.getText, ctx.getText, Defines.Any, List(Defines.Any))
Seq(Ast(astNode))
}

Expand Down Expand Up @@ -1906,7 +1908,7 @@ class AstCreator(filename: String, global: Global)
} else if (ctx.YIELD() != null) {
astForArgumentsWithoutParenthesesContext(ctx.argumentsWithoutParentheses())
} else if (ctx.methodIdentifier() != null) {
val methodIdentifierAsts = astForMethodIdentifierContext(ctx.methodIdentifier())
val methodIdentifierAsts = astForMethodIdentifierContext(ctx.methodIdentifier(), ctx.getText)
methodNameAsIdentiferQ.enqueue(methodIdentifierAsts.head)
val argsAsts = astForArgumentsWithoutParenthesesContext(ctx.argumentsWithoutParentheses())

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package io.joern.rubysrc2cpg.passes.ast

import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture
import io.shiftleft.semanticcpg.language._

class SimpleAstCreationPassTest extends RubyCode2CpgFixture {

"AST generation for simple fragments" should {

"have correct structure for a single command call" in {
val cpg = code("""puts 123""")

val List(commandCall) = cpg.call.l
val List(arg) = commandCall.argument.isLiteral.l

commandCall.code shouldBe "puts 123"
commandCall.lineNumber shouldBe Some(1)

arg.code shouldBe "123"
arg.lineNumber shouldBe Some(1)
arg.columnNumber shouldBe Some(5)
}
}

}

0 comments on commit c4aae96

Please sign in to comment.