Skip to content

Commit

Permalink
rubysrc2cpg: refactored pseudo-variables' AST (#2814)
Browse files Browse the repository at this point in the history
* Refactored astForPseudoVariableIdentifier

* AST Unit-tests for pseudoVariableIdentifiers

* scalafmt

* Removed TODOs

* scalafmt
  • Loading branch information
xavierpinho authored Jun 5, 2023
1 parent a23fa59 commit 98b032a
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -497,13 +497,13 @@ variableIdentifier
;

pseudoVariableIdentifier
: NIL
| TRUE
| FALSE
| SELF
| FILE__
| LINE__
| ENCODING__
: NIL # nilPseudoVariableIdentifier
| TRUE # truePseudoVariableIdentifier
| FALSE # falsePseudoVariableIdentifier
| SELF # selfPseudoVariableIdentifier
| FILE__ # filePseudoVariableIdentifier
| LINE__ # linePseudoVariableIdentifier
| ENCODING__ # encodingPseudoVariableIdentifier
;

scopedConstantReference
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.joern.rubysrc2cpg.astcreation
import io.joern.rubysrc2cpg.parser.RubyParser._
import io.joern.rubysrc2cpg.parser.{RubyLexer, RubyParser}
import io.joern.rubysrc2cpg.passes.Defines
import io.joern.x2cpg.Ast.storeInDiffGraph
import io.joern.x2cpg.Defines.DynamicCallUnknownFullName
import io.joern.x2cpg.datastructures.Global
Expand All @@ -19,20 +20,8 @@ import scala.jdk.CollectionConverters._

class AstCreator(filename: String, global: Global)
extends AstCreatorBase(filename)
with AstNodeBuilder[ParserRuleContext, AstCreator] {

object Defines {
val Any: String = "ANY"
val Number: String = "number"
val String: String = "string"
val Boolean: String = "boolean"
val Hash: String = "hash"
val Array: String = "array"
val Symbol: String = "symbol"
val ModifierRedo: String = "redo"
val ModifierRetry: String = "retry"
var ModifierNext: String = "next"
}
with AstNodeBuilder[ParserRuleContext, AstCreator]
with AstForPrimitivesCreator {

object MethodFullNames {
val OperatorPrefix = "<operator>."
Expand Down Expand Up @@ -76,12 +65,12 @@ class AstCreator(filename: String, global: Global)
scopeStack.top.varToIdentiferMap.contains(name)
}

private def createIdentiferWithScope(
protected def createIdentifierWithScope(
ctx: ParserRuleContext,
name: String,
code: String,
typeFullName: String,
dynamicTypeHints: Seq[String]
dynamicTypeHints: Seq[String] = Seq()
): NewIdentifier = {
val newNode = identifierNode(ctx, name, code, typeFullName, dynamicTypeHints)
setIdentiferInScope(newNode)
Expand Down Expand Up @@ -155,7 +144,7 @@ class AstCreator(filename: String, global: Global)
val terminalNode = ctx.children.asScala.map(_.asInstanceOf[TerminalNode]).head
val token = terminalNode.getSymbol
val variableName = token.getText
val node = createIdentiferWithScope(ctx, variableName, variableName, Defines.Any, List[String]())
val node = createIdentifierWithScope(ctx, variableName, variableName, Defines.Any, List[String]())
setIdentiferInScope(node)
Seq(Ast(node))
}
Expand Down Expand Up @@ -189,7 +178,7 @@ class AstCreator(filename: String, global: Global)
}
val varSymbol = localVar.getSymbol()
val node =
createIdentiferWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))
createIdentifierWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))
val yAst = Ast(node)

val callNode = NewCall()
Expand Down Expand Up @@ -741,8 +730,8 @@ class AstCreator(filename: String, global: Global)
val primaryAst = astForPrimaryContext(ctx.primary())
val localVar = ctx.CONSTANT_IDENTIFIER()
val varSymbol = localVar.getSymbol()
val node = createIdentiferWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))
val constAst = Ast(node)
val node = createIdentifierWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))
val constAst = Ast(node)

val callNode = NewCall()
.name(ctx.COLON2().getText)
Expand Down Expand Up @@ -1158,9 +1147,9 @@ class AstCreator(filename: String, global: Global)
.code(text)
.lineNumber(lineStart)
.columnNumber(columnStart)
.typeFullName(Defines.Number)
.dynamicTypeHintFullName(List(Defines.Number))
registerType(Defines.Number)
.typeFullName(Defines.Numeric)
.dynamicTypeHintFullName(List(Defines.Numeric))
registerType(Defines.Numeric)
Seq(Ast(node))
} else if (ctx.literal().SINGLE_QUOTED_STRING_LITERAL() != null) {
val text = ctx.getText
Expand Down Expand Up @@ -1227,7 +1216,7 @@ class AstCreator(filename: String, global: Global)
val varSymbol = localVar.getSymbol()
if (lookupIdentiferInScope(varSymbol.getText)) {
val node =
createIdentiferWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))
createIdentifierWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))
Seq(Ast(node))
} else {
astForCallNode(localVar, code)
Expand All @@ -1237,7 +1226,7 @@ class AstCreator(filename: String, global: Global)
val varSymbol = localVar.getSymbol()
if (lookupIdentiferInScope(varSymbol.getText)) {
val node =
createIdentiferWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))
createIdentifierWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))
Seq(Ast(node))
} else {
astForCallNode(localVar, code)
Expand Down Expand Up @@ -1297,13 +1286,13 @@ class AstCreator(filename: String, global: Global)
val localVar = ctx.LOCAL_VARIABLE_IDENTIFIER()
val varSymbol = localVar.getSymbol()
val node =
createIdentiferWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))
createIdentifierWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))
Seq(Ast(node))
} else if (ctx.CONSTANT_IDENTIFIER() != null) {
val localVar = ctx.CONSTANT_IDENTIFIER()
val varSymbol = localVar.getSymbol()
val node =
createIdentiferWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))
createIdentifierWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))
Seq(Ast(node))
} else {
Seq(Ast())
Expand Down Expand Up @@ -1378,7 +1367,7 @@ class AstCreator(filename: String, global: Global)
localVarList
.map(localVar => {
val varSymbol = localVar.getSymbol()
createIdentiferWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, Seq[String](Defines.Any))
createIdentifierWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, Seq[String](Defines.Any))
val param = NewMethodParameterIn()
.name(varSymbol.getText)
.code(varSymbol.getText)
Expand Down Expand Up @@ -1559,7 +1548,7 @@ class AstCreator(filename: String, global: Global)
def astForSimpleScopedConstantReferencePrimaryContext(ctx: SimpleScopedConstantReferencePrimaryContext): Seq[Ast] = {
val localVar = ctx.CONSTANT_IDENTIFIER()
val varSymbol = localVar.getSymbol()
val node = createIdentiferWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))
val node = createIdentifierWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))

val callNode = NewCall()
.name(ctx.COLON2().getText)
Expand Down Expand Up @@ -1780,27 +1769,21 @@ class AstCreator(filename: String, global: Global)
Seq(ast)
}

def astForPseudoVariableIdentifierContext(ctx: PseudoVariableIdentifierContext): Seq[Ast] = {
val node = {
if (ctx.TRUE() != null) { ctx.TRUE() }
else if (ctx.NIL() != null) { ctx.NIL() }
else if (ctx.FALSE() != null) { ctx.FALSE() }
else if (ctx.SELF() != null) { ctx.SELF() }
else if (ctx.FILE__() != null) { ctx.FILE__() }
else if (ctx.LINE__() != null) { ctx.LINE__() }
else if (ctx.ENCODING__() != null) { ctx.ENCODING__() }
else return Seq(Ast())
}

val astNode = createIdentiferWithScope(ctx, ctx.getText, ctx.getText, Defines.Any, List(Defines.Any))
Seq(Ast(astNode))
private def astForPseudoVariableIdentifierContext(ctx: PseudoVariableIdentifierContext): Ast = ctx match {
case ctx: NilPseudoVariableIdentifierContext => astForNilLiteral(ctx)
case ctx: TruePseudoVariableIdentifierContext => astForTrueLiteral(ctx)
case ctx: FalsePseudoVariableIdentifierContext => astForFalseLiteral(ctx)
case ctx: SelfPseudoVariableIdentifierContext => astForSelfPseudoIdentifier(ctx)
case ctx: FilePseudoVariableIdentifierContext => astForFilePseudoIdentifier(ctx)
case ctx: LinePseudoVariableIdentifierContext => astForLinePseudoIdentifier(ctx)
case ctx: EncodingPseudoVariableIdentifierContext => astForEncodingPseudoIdentifier(ctx)
}

def astForVariableRefenceContext(ctx: RubyParser.VariableReferenceContext): Seq[Ast] = {
if (ctx.variableIdentifier() != null) {
astForVariableIdentifierContext(ctx.variableIdentifier())
} else {
astForPseudoVariableIdentifierContext(ctx.pseudoVariableIdentifier())
Seq(astForPseudoVariableIdentifierContext(ctx.pseudoVariableIdentifier()))
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package io.joern.rubysrc2cpg.astcreation

import io.joern.rubysrc2cpg.parser.RubyParser
import io.joern.rubysrc2cpg.passes.Defines
import io.joern.x2cpg.Ast

trait AstForPrimitivesCreator { this: AstCreator =>

protected def astForNilLiteral(ctx: RubyParser.NilPseudoVariableIdentifierContext): Ast =
Ast(literalNode(ctx, ctx.getText, Defines.NilClass))

protected def astForTrueLiteral(ctx: RubyParser.TruePseudoVariableIdentifierContext): Ast =
Ast(literalNode(ctx, ctx.getText, Defines.TrueClass))

protected def astForFalseLiteral(ctx: RubyParser.FalsePseudoVariableIdentifierContext): Ast =
Ast(literalNode(ctx, ctx.getText, Defines.FalseClass))

protected def astForSelfPseudoIdentifier(ctx: RubyParser.SelfPseudoVariableIdentifierContext): Ast =
Ast(createIdentifierWithScope(ctx, ctx.getText, ctx.getText, Defines.Object))

protected def astForFilePseudoIdentifier(ctx: RubyParser.FilePseudoVariableIdentifierContext): Ast =
Ast(createIdentifierWithScope(ctx, ctx.getText, ctx.getText, Defines.String))

protected def astForLinePseudoIdentifier(ctx: RubyParser.LinePseudoVariableIdentifierContext): Ast =
Ast(createIdentifierWithScope(ctx, ctx.getText, ctx.getText, Defines.Integer))

protected def astForEncodingPseudoIdentifier(ctx: RubyParser.EncodingPseudoVariableIdentifierContext): Ast =
Ast(createIdentifierWithScope(ctx, ctx.getText, ctx.getText, Defines.Encoding))
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package io.joern.rubysrc2cpg.passes

object Defines {
val Any: String = "ANY"
val Object: String = "Object"

val NilClass: String = "NilClass"
val TrueClass: String = "TrueClass"
val FalseClass: String = "FalseClass"

val Numeric: String = "Numeric"
val Integer: String = "Integer"
val Float: String = "Float"

val String: String = "String"
val Symbol: String = "Symbol"

val Array: String = "Array"
val Hash: String = "Hash"

val Encoding: String = "Encoding"

// TODO: The following shall be moved out eventually.
val ModifierRedo: String = "redo"
val ModifierRetry: String = "retry"
var ModifierNext: String = "next"
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.joern.rubysrc2cpg.passes.ast

import io.joern.rubysrc2cpg.passes.Defines
import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture
import io.shiftleft.semanticcpg.language._

Expand All @@ -20,6 +21,95 @@ class SimpleAstCreationPassTest extends RubyCode2CpgFixture {
arg.lineNumber shouldBe Some(1)
arg.columnNumber shouldBe Some(5)
}
}

"have correct structure for an unsigned, decimal integer literal" ignore {
val cpg = code("123")
val List(literal) = cpg.literal.l
literal.typeFullName shouldBe Defines.Integer
literal.code shouldBe "123"
literal.lineNumber shouldBe Some(1)
literal.columnNumber shouldBe Some(1)
}

"have correct structure for a +integer, decimal literal" ignore {
val cpg = code("+1")
val List(literal) = cpg.literal.l
literal.typeFullName shouldBe Defines.Integer
literal.code shouldBe "+1"
literal.lineNumber shouldBe Some(1)
literal.columnNumber shouldBe Some(1)
}

"have correct structure for a -integer, decimal literal" ignore {
val cpg = code("-1")
val List(literal) = cpg.literal.l
literal.typeFullName shouldBe Defines.Integer
literal.code shouldBe "-1"
literal.lineNumber shouldBe Some(1)
literal.columnNumber shouldBe Some(1)
}

"have correct structure for `nil` literal" in {
val cpg = code("puts nil")
val List(literal) = cpg.literal.l
literal.typeFullName shouldBe Defines.NilClass
literal.code shouldBe "nil"
literal.lineNumber shouldBe Some(1)
literal.columnNumber shouldBe Some(5)
}

"have correct structure for `true` literal" in {
val cpg = code("puts true")
val List(literal) = cpg.literal.l
literal.typeFullName shouldBe Defines.TrueClass
literal.code shouldBe "true"
literal.lineNumber shouldBe Some(1)
literal.columnNumber shouldBe Some(5)
}

"have correct structure for `false` literal" in {
val cpg = code("puts false")
val List(literal) = cpg.literal.l
literal.typeFullName shouldBe Defines.FalseClass
literal.code shouldBe "false"
literal.lineNumber shouldBe Some(1)
literal.columnNumber shouldBe Some(5)
}

"have correct structure for `self` identifier" in {
val cpg = code("puts self")
val List(self) = cpg.identifier.l
self.typeFullName shouldBe Defines.Object
self.code shouldBe "self"
self.lineNumber shouldBe Some(1)
self.columnNumber shouldBe Some(5)
}

"have correct structure for `__FILE__` identifier" in {
val cpg = code("puts __FILE__")
val List(file) = cpg.identifier.l
file.typeFullName shouldBe Defines.String
file.code shouldBe "__FILE__"
file.lineNumber shouldBe Some(1)
file.columnNumber shouldBe Some(5)
}

"have correct structure for `__LINE__` identifier" in {
val cpg = code("puts __LINE__")
val List(line) = cpg.identifier.l
line.typeFullName shouldBe Defines.Integer
line.code shouldBe "__LINE__"
line.lineNumber shouldBe Some(1)
line.columnNumber shouldBe Some(5)
}

"have correct structure for `__ENCODING__` identifier" in {
val cpg = code("puts __ENCODING__")
val List(encoding) = cpg.identifier.l
encoding.typeFullName shouldBe Defines.Encoding
encoding.code shouldBe "__ENCODING__"
encoding.lineNumber shouldBe Some(1)
encoding.columnNumber shouldBe Some(5)
}
}
}

0 comments on commit 98b032a

Please sign in to comment.