Skip to content

Commit

Permalink
[ruby] Type & Singleton Split (#4655)
Browse files Browse the repository at this point in the history
1. Created classes and modules in pairs, where one is `Foo` (the "regular" class) and the other is `Foo<class>` (the singleton class). `module` singletons get the `final` keyword.
2. Define members for `@@` fields under `Foo<class>`
3. Define members with `dynamicTypeHintFullName`s set to the corresponding `self` methods and under `Foo<class>`. `Foo<class>` should have empty bindings to these `self` methods. An empty binding is one with `name==""` and `signature==""`
4. During object instantiations, e.g. `Foo.new`, add `Foo<class>` to the receiver's dynamic type hints, and `Foo` as the return type of `Call(new)`. The call linker may add an edge directly from `Foo.new` to `Foo.initialize`.

cc @AndreiDreyer 

Resolves #4652
  • Loading branch information
DavidBakerEffendi committed Jun 12, 2024
1 parent 75fb7a9 commit ed9671b
Show file tree
Hide file tree
Showing 14 changed files with 222 additions and 119 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,15 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As
callAst(assignment, Seq(lhs, rhs))
}

protected def memberForMethod(method: NewMethod): NewMember = {
NewMember().name(method.name).code(method.name).dynamicTypeHintFullName(Seq(method.fullName))
protected def memberForMethod(
method: NewMethod,
astParentType: Option[String] = None,
astParentFullName: Option[String] = None
): NewMember = {
val member = NewMember().name(method.name).code(method.name).dynamicTypeHintFullName(Seq(method.fullName))
astParentType.foreach(member.astParentType(_))
astParentFullName.foreach(member.astParentFullName(_))
member
}

protected val UnaryOperatorNames: Map[String, String] = Map(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,15 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {

protected def astForObjectInstantiation(node: RubyNode & ObjectInstantiation): Ast = {
val className = node.target.text
val methodName = XDefines.ConstructorMethodName
val callName = "new"
val methodName = Defines.Initialize
/*
We short-cut the call edge from `new` call to `initialize` method, however we keep the modelling of the receiver
as referring to the singleton class.
*/
val (receiverTypeFullName, fullName) = scope.tryResolveTypeReference(className) match {
case Some(typeMetaData) => typeMetaData.name -> s"${typeMetaData.name}:$methodName"
case None => XDefines.Any -> XDefines.DynamicCallUnknownFullName
case Some(typeMetaData) => s"${typeMetaData.name}<class>" -> s"${typeMetaData.name}:$methodName"
case None => XDefines.Any -> XDefines.DynamicCallUnknownFullName
}
/*
Similarly to some other frontends, we lower the constructor into two operations, e.g.,
Expand All @@ -221,7 +226,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
val tmp = SimpleIdentifier(Option(className))(node.span.spanStart(tmpGen.fresh))
def tmpIdentifier = {
val tmpAst = astForSimpleIdentifier(tmp)
tmpAst.root.collect { case x: NewIdentifier => x.typeFullName(receiverTypeFullName) }
tmpAst.root.collect { case x: NewIdentifier => x.typeFullName(receiverTypeFullName.stripSuffix("<class>")) }
tmpAst
}

Expand All @@ -248,7 +253,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
x.arguments.map(astForMethodCallArgument) :+ methodRef
}

val constructorCall = callNode(node, code(node), methodName, fullName, DispatchTypes.DYNAMIC_DISPATCH)
val constructorCall = callNode(node, code(node), callName, fullName, DispatchTypes.DYNAMIC_DISPATCH)
val constructorCallAst = callAst(constructorCall, argumentAsts, Option(tmpIdentifier))
scope.popScope()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th
protected def astForMethodDeclaration(node: MethodDeclaration, isClosure: Boolean = false): Seq[Ast] = {

// Special case constructor methods
val isInTypeDecl = scope.surroundingAstLabel.contains(NodeTypes.TYPE_DECL)
val isConstructor = node.methodName == "initialize" && isInTypeDecl
val methodName = if isConstructor then XDefines.ConstructorMethodName else node.methodName
val isInTypeDecl = scope.surroundingAstLabel.contains(NodeTypes.TYPE_DECL)
val isConstructor =
(node.methodName == Defines.Initialize || node.methodName == Defines.InitializeClass) && isInTypeDecl
val isSingletonConstructor = node.methodName == Defines.InitializeClass && isInTypeDecl
val methodName = if isSingletonConstructor then Defines.Initialize else node.methodName
// TODO: body could be a try
val fullName = computeMethodFullName(methodName)
val method = methodNode(
Expand All @@ -52,7 +54,9 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th
signature = None,
fileName = relativeFileName,
astParentType = scope.surroundingAstLabel,
astParentFullName = scope.surroundingScopeFullName
astParentFullName = scope.surroundingScopeFullName.map { tn =>
if isSingletonConstructor then s"$tn<class>" else tn
}
)

if (isConstructor) scope.pushNewScope(ConstructorScope(fullName))
Expand Down Expand Up @@ -80,7 +84,9 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th
relativeFileName,
code(node),
astParentType = scope.surroundingAstLabel.getOrElse("<empty>"),
astParentFullName = scope.surroundingScopeFullName.getOrElse("<empty>")
astParentFullName = scope.surroundingScopeFullName
.map { tn => if isSingletonConstructor then s"$tn<class>" else tn }
.getOrElse("<empty>")
),
typeRefNode(node, methodName, fullName),
methodRefNode(node, methodName, fullName, methodReturn.typeFullName)
Expand All @@ -94,13 +100,14 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th
val baseStmtBlockAst = astForMethodBody(node.body, optionalStatementList)
transformAsClosureBody(refs, baseStmtBlockAst)
} else {
if (methodName != XDefines.ConstructorMethodName && node.methodName != XDefines.StaticInitMethodName) {
if (methodName != Defines.Initialize && methodName != Defines.InitializeClass) {
astForMethodBody(node.body, optionalStatementList)
} else {
astForConstructorMethodBody(node.body, optionalStatementList)
}
}

// For yield statements where there isn't an explicit proc parameter
val anonProcParam = scope.anonProcParam.map { param =>
val paramNode = ProcParameter(param)(node.span.spanStart(s"&$param"))
val nextIndex =
Expand All @@ -118,18 +125,42 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th

val prefixMemberAst =
if isClosure || scope.isSurroundedByProgramScope then Ast() // program scope members are set elsewhere
else Ast(memberForMethod(method))
else {
// Singleton constructors that initialize @@ fields should have their members linked under the singleton class
val methodMember = scope.surroundingTypeFullName.map {
case x if isSingletonConstructor => s"$x<class>"
case x => x
} match {
case Some(astParentTfn) => memberForMethod(method, Option(NodeTypes.TYPE_DECL), Option(astParentTfn))
case None => memberForMethod(method)
}
if (isSingletonConstructor) {
diffGraph.addNode(methodMember)
Ast()
} else {
Ast(memberForMethod(method))
}
}
val prefixRefAssignAst = if isClosure then Ast() else createMethodRefPointer(method)
// For closures, we also want the method/type refs for upstream use
val suffixAsts = if isClosure then refs else refs.filter(_.root.exists(_.isInstanceOf[NewTypeDecl]))
val methodAsts = prefixMemberAst :: prefixRefAssignAst ::
methodAst(
val methodAst_ = {
val mAst = methodAst(
method,
parameterAsts ++ anonProcParam,
stmtBlockAst,
methodReturn,
modifiers.map(newModifierNode).toSeq
) :: suffixAsts
)
// AstLinker will link the singleton as the parent
if isSingletonConstructor then {
Ast.storeInDiffGraph(mAst, diffGraph)
Ast()
} else {
mAst
}
}
val methodAsts = prefixMemberAst :: prefixRefAssignAst :: methodAst_ :: suffixAsts
methodAsts.filterNot(_.root.isEmpty)
}

Expand Down Expand Up @@ -177,7 +208,7 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th

/** Creates the bindings between the method and its types. This is useful for resolving function pointers and imports.
*/
private def createMethodTypeBindings(method: NewMethod, refs: List[Ast]): Unit = {
protected def createMethodTypeBindings(method: NewMethod, refs: List[Ast]): Unit = {
refs.flatMap(_.root).collectFirst { case typeRef: NewTypeDecl =>
val bindingNode = newBindingNode("", "", method.fullName)
diffGraph.addEdge(typeRef, bindingNode, EdgeTypes.BINDS)
Expand Down Expand Up @@ -294,8 +325,9 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th
}

// This will link the type decl to the surrounding context via base overlays
val typeDeclAst = astForClassDeclaration(node).last
val Seq(_, typeDeclAst, singletonAsts) = astForClassDeclaration(node).take(3)
Ast.storeInDiffGraph(typeDeclAst, diffGraph)
Ast.storeInDiffGraph(singletonAsts, diffGraph)

typeDeclAst.nodes
.collectFirst { case typeDecl: NewTypeDecl =>
Expand All @@ -310,7 +342,7 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th
}

protected def astForSingletonMethodDeclaration(node: SingletonMethodDeclaration): Seq[Ast] = {
node.target match
node.target match {
case targetNode: SingletonMethodIdentifier =>
val fullName = computeMethodFullName(node.methodName)

Expand All @@ -321,7 +353,7 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th
val baseType = node.target.span.text
scope.surroundingTypeFullName.map(_.split("[.]").last) match {
case Some(typ) if typ == baseType =>
(scope.surroundingAstLabel, scope.surroundingTypeFullName, baseType, false)
(scope.surroundingAstLabel, scope.surroundingScopeFullName, baseType, false)
case Some(typ) =>
scope.tryResolveTypeReference(baseType) match {
case Some(typ) =>
Expand Down Expand Up @@ -366,6 +398,10 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th

scope.popScope()

// The member for these types refers to the singleton class
val member = memberForMethod(method, Option(NodeTypes.TYPE_DECL), astParentFullName.map(x => s"$x<class>"))
diffGraph.addNode(member)

val _methodAst =
methodAst(
method,
Expand All @@ -384,6 +420,7 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th
s"Target node type for singleton method declarations are not supported yet: ${targetNode.text} (${targetNode.getClass.getSimpleName}), skipping"
)
astForUnknown(node) :: Nil
}
}

private def createMethodRefPointer(method: NewMethod): Ast = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@ package io.joern.rubysrc2cpg.astcreation
import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.*
import io.joern.rubysrc2cpg.datastructures.{BlockScope, MethodScope, ModuleScope, TypeScope}
import io.joern.rubysrc2cpg.passes.Defines
import io.joern.x2cpg.utils.NodeBuilders.newModifierNode
import io.joern.x2cpg.{Ast, ValidationMode, Defines as XDefines}
import io.shiftleft.codepropertygraph.generated.nodes.{
NewCall,
NewFieldIdentifier,
NewIdentifier,
NewMethod,
NewTypeDecl,
NewTypeRef
}
import io.shiftleft.codepropertygraph.generated.{DispatchTypes, EvaluationStrategies, Operators}
import io.shiftleft.codepropertygraph.generated.{DispatchTypes, EvaluationStrategies, ModifierTypes, Operators}

import scala.collection.immutable.List

Expand Down Expand Up @@ -65,39 +67,77 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this:
inherits = inheritsFrom,
alias = None
)
/*
In Ruby, there are semantic differences between the ordinary class and singleton class (think "meta" class in
Python). Similar to how Java allows both static and dynamic methods/fields/etc. within the same type declaration,
Ruby allows `self` methods and @@ fields to be defined alongside ordinary methods and @ fields. However, both
classes are more dynamic and have separate behaviours in Ruby and we model it as such.
node match {
case _: ModuleDeclaration => scope.pushNewScope(ModuleScope(classFullName))
case _: TypeDeclaration => scope.pushNewScope(TypeScope(classFullName, List.empty))
To signify the singleton type, we add the <class> tag.
*/
val singletonTypeDecl = typeDecl.copy
.name(s"$className<class>")
.fullName(s"$classFullName<class>")
.inheritsFromTypeFullName(inheritsFrom.map(x => s"$x<class>"))

val (typeDeclModifiers, singletonModifiers) = node match {
case _: ModuleDeclaration =>
scope.pushNewScope(ModuleScope(classFullName))
(
ModifierTypes.VIRTUAL :: Nil map newModifierNode map Ast.apply,
ModifierTypes.VIRTUAL :: ModifierTypes.FINAL :: Nil map newModifierNode map Ast.apply
)
case _: TypeDeclaration =>
scope.pushNewScope(TypeScope(classFullName, List.empty))
(
ModifierTypes.VIRTUAL :: Nil map newModifierNode map Ast.apply,
ModifierTypes.VIRTUAL :: Nil map newModifierNode map Ast.apply
)
}

val classBody =
node.body.asInstanceOf[StatementList] // for now (bodyStatement is a superset of stmtList)

val classBodyAsts = classBody.statements.flatMap(astsForStatement) match {
val classBodyAsts = classBody.statements.flatMap {
case n: SingletonMethodDeclaration =>
val singletonMethodAst = astsForStatement(n)
// Create binding from singleton methods to singleton type decls
singletonMethodAst.flatMap(_.root).collectFirst { case n: NewMethod =>
createMethodTypeBindings(n, Ast(singletonTypeDecl) :: Nil)
}
// Method declaration remains in the normal type decl body
singletonMethodAst
case n => astsForStatement(n)
} match {
case bodyAsts if scope.shouldGenerateDefaultConstructor && this.parseLevel == AstParseLevel.FULL_AST =>
val bodyStart = classBody.span.spanStart()
val initBody = StatementList(List())(bodyStart)
val methodDecl = astForMethodDeclaration(
MethodDeclaration(XDefines.ConstructorMethodName, List(), initBody)(bodyStart)
)
val bodyStart = classBody.span.spanStart()
val initBody = StatementList(List())(bodyStart)
val methodDecl = astForMethodDeclaration(MethodDeclaration(Defines.Initialize, List(), initBody)(bodyStart))
methodDecl ++ bodyAsts
case bodyAsts => bodyAsts
}

val fieldMemberNodes = node match {
val (fieldTypeMemberNodes, fieldSingletonMemberNodes) = node match {
case classDecl: ClassDeclaration =>
classDecl.fields.map { x =>
val name = code(x)
Ast(memberNode(x, name, name, Defines.Any))
}
case _ => Seq.empty
classDecl.fields
.map { x =>
val name = code(x)
x.isInstanceOf[InstanceFieldIdentifier] -> Ast(memberNode(x, name, name, Defines.Any))
}
.partition(_._1)
case _ => Seq.empty -> Seq.empty
}

scope.popScope()
val prefixAst = createTypeRefPointer(typeDecl)
val typeDeclAsts = prefixAst :: Ast(typeDecl).withChildren(fieldMemberNodes).withChildren(classBodyAsts) :: Nil
typeDeclAsts.filterNot(_.root.isEmpty)
val prefixAst = createTypeRefPointer(typeDecl)
val typeDeclAst = Ast(typeDecl)
.withChildren(typeDeclModifiers)
.withChildren(fieldTypeMemberNodes.map(_._2))
.withChildren(classBodyAsts)
val singletonTypeDeclAst =
Ast(singletonTypeDecl).withChildren(singletonModifiers).withChildren(fieldSingletonMemberNodes.map(_._2))

prefixAst :: typeDeclAst :: singletonTypeDeclAst :: Nil filterNot (_.root.isEmpty)
}

private def createTypeRefPointer(typeDecl: NewTypeDecl): Ast = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import io.joern.x2cpg.Defines as XDefines
import io.joern.x2cpg.datastructures.{FieldLike, MethodLike, ProgramSummary, StubbedType, TypeLike}
import io.joern.x2cpg.typestub.{TypeStubMetaData, TypeStubUtil}
import org.slf4j.LoggerFactory
import io.joern.rubysrc2cpg.passes.Defines
import upickle.default.*

import java.io.{ByteArrayInputStream, InputStream}
Expand Down Expand Up @@ -165,7 +166,7 @@ case class RubyType(name: String, methods: List[RubyMethod], fields: List[RubyFi
}

def hasConstructor: Boolean = {
methods.exists(_.name == XDefines.ConstructorMethodName)
methods.exists(_.name == Defines.Initialize)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -887,16 +887,13 @@ class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] {

val (instanceFieldsInMethodDecls, classFieldsInMethodDecls) = partitionRubyFields(fieldsInMethodDecls)

val initializeMethod = methodDecls.collectFirst { x =>
x.methodName match
case "initialize" => x
}
val initializeMethod = methodDecls.collectFirst { case x if x.methodName == Defines.Initialize => x }

val initStmtListStatements = genSingleAssignmentStmtList(instanceFields, instanceFieldsInMethodDecls)
val clinitStmtList = genSingleAssignmentStmtList(classFields, classFieldsInMethodDecls)

val clinitMethod =
MethodDeclaration(XDefines.StaticInitMethodName, List.empty, StatementList(clinitStmtList)(stmtList.span))(
MethodDeclaration(Defines.InitializeClass, List.empty, StatementList(clinitStmtList)(stmtList.span))(
stmtList.span
)

Expand All @@ -914,7 +911,7 @@ class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] {
}
case None =>
val newInitMethod =
MethodDeclaration("initialize", List.empty, StatementList(initStmtListStatements)(stmtList.span))(
MethodDeclaration(Defines.Initialize, List.empty, StatementList(initStmtListStatements)(stmtList.span))(
stmtList.span
)
val initializers = newInitMethod :: clinitMethod :: Nil
Expand Down
Loading

0 comments on commit ed9671b

Please sign in to comment.