Skip to content

Commit

Permalink
[ruby] MemberAccess/MemberCall Handling (#4676)
Browse files Browse the repository at this point in the history
* `initialize` methods are under all types/modules with fields that need to be initialized, where `InstanceFields` under classes are considered `ClassFields` and places under the singleton type. e.g, `class Foo;SomeMember=1;end` is considered under the singleton.
* `MemberAccess` is now not always treated as a call. As per Ruby, if the first letter of the member is capitalized, then it is a field access, otherwise a function call.
* Calls (and qualified calls) have their base/targets recursively checked for if they need to be prepended with `self.` (i.e., not a local variable that has been declared in scope).
* Added initial `<body>` method to catch all arbitrary statements under a type declaration, requires follow-up work
* Fixed local variable scoping, where variables introduced in control structure blocks get fixed to the parent method and not the block itself.
  • Loading branch information
DavidBakerEffendi authored Jun 18, 2024
1 parent dd53acd commit 5a4c499
Show file tree
Hide file tree
Showing 27 changed files with 575 additions and 156 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,13 @@ class AstCreator(
scope.popScope()
val bodyAst = blockAst(block, statementAsts)
scope.popScope()
methodAst(methodNode_, Seq.empty, bodyAst, methodReturn, newModifierNode(ModifierTypes.MODULE) :: Nil)
methodAst(
methodNode_,
Seq.empty,
bodyAst,
methodReturn,
newModifierNode(ModifierTypes.MODULE) :: newModifierNode(ModifierTypes.VIRTUAL) :: Nil
)
}
.getOrElse(Ast())
}
Expand Down Expand Up @@ -131,7 +137,9 @@ class AstCreator(
diffGraph.addEdge(typeDeclNode_, bindingNode, EdgeTypes.BINDS)
diffGraph.addEdge(bindingNode, method, EdgeTypes.REF)

Ast(typeDeclNode_).withChild(Ast(newModifierNode(ModifierTypes.MODULE))).withChildren(members)
Ast(typeDeclNode_)
.withChildren(Ast(newModifierNode(ModifierTypes.MODULE)) :: Ast(newModifierNode(ModifierTypes.VIRTUAL)) :: Nil)
.withChildren(members)
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import io.shiftleft.codepropertygraph.generated.{
ControlStructureTypes,
DispatchTypes,
EdgeTypes,
NodeTypes,
Operators,
PropertyNames
}
Expand All @@ -32,6 +33,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
case node: IndexAccess => astForIndexAccess(node)
case node: SingleAssignment => astForSingleAssignment(node)
case node: AttributeAssignment => astForAttributeAssignment(node)
case node: TypeIdentifier => astForTypeIdentifier(node)
case node: RubyIdentifier => astForSimpleIdentifier(node)
case node: SimpleCall => astForSimpleCall(node)
case node: RequireCall => astForRequireCall(node)
Expand Down Expand Up @@ -131,9 +133,20 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
callAst(call, Seq(lhsAst, rhsAst))
}

// Member accesses are lowered as calls, i.e. `x.y` is the call of `y` of `x` without any arguments.
// Member accesses are checked in RubyNodeCreator, i.e. `x.y` is the call of `y` of `x` without any arguments.
// where x.Y is considered a constant access as Y is capitalized.
protected def astForMemberAccess(node: MemberAccess): Ast = {
astForMemberCall(MemberCall(node.target, node.op, node.memberName, List.empty)(node.span))
node.target match {
case x: SimpleIdentifier =>
val newTarget = scope.getSurroundingType(x.text).map(_.fullName) match {
case Some(surroundingType) =>
val typeName = surroundingType.split('.').last
TypeIdentifier(s"$surroundingType<class>")(x.span.spanStart(typeName))
case None => x
}
astForFieldAccess(node.copy(target = newTarget)(node.span))
case _ => astForFieldAccess(node)
}
}

/** Attempts to extract a type from the base of a member call.
Expand All @@ -154,27 +167,65 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
}
}

private def astForTypeIdentifier(node: TypeIdentifier): Ast = {
Ast(typeRefNode(node, code(node), node.typeFullName))
}

protected def astForMemberCall(node: MemberCall): Ast = {
// Use the scope type recovery to attempt to obtain a receiver type for the call
// TODO: Type recovery should potentially resolve this
val receiver = astForExpression(node.target)
val (receiverFullName, methodFullName) = receiver.root match {
case Some(x: NewMethodRef) => x.methodFullName -> x.methodFullName
case _ =>
typeFromCallTarget(node.target)
.map(x => x -> s"$x:${node.methodName}")
.getOrElse(XDefines.Any -> XDefines.DynamicCallUnknownFullName)
}
val argumentAsts = node.arguments.map(astForMethodCallArgument)

receiver.root.collect { case x: NewCall => x.typeFullName(methodFullName) }
val dispatchType =
if receiverFullName.startsWith(s"<${GlobalTypes.builtinPrefix}") then DispatchTypes.STATIC_DISPATCH
else DispatchTypes.DYNAMIC_DISPATCH
def createMemberCall(n: MemberCall): Ast = {
val baseAst = astForExpression(n.target) // this wil be something like self.Foo
val receiverAst = astForExpression(MemberAccess(n.target, ".", n.methodName)(n.span))
val builtinType = n.target match {
case MemberAccess(_: SelfIdentifier, _, memberName) if isBundledClass(memberName) =>
Option(prefixAsBundledType(memberName))
case x: TypeIdentifier if x.isBuiltin => Option(x.typeFullName)
case _ => None
}
val (receiverFullName, methodFullName) = receiverAst.nodes
.collectFirst {
case _ if builtinType.isDefined => builtinType.get -> s"${builtinType.get}:${n.methodName}"
case x: NewMethodRef => x.methodFullName -> x.methodFullName
case _ =>
(n.target match {
case ma: MemberAccess => scope.tryResolveTypeReference(ma.memberName).map(_.name)
case _ => typeFromCallTarget(n.target)
}).map(x => x -> s"$x:${n.methodName}")
.getOrElse(XDefines.Any -> XDefines.DynamicCallUnknownFullName)
}
.getOrElse(XDefines.Any -> XDefines.DynamicCallUnknownFullName)
val argumentAsts = n.arguments.map(astForMethodCallArgument)
val dispatchType =
if builtinType.isDefined then DispatchTypes.STATIC_DISPATCH
else DispatchTypes.DYNAMIC_DISPATCH

val call = callNode(n, code(n), n.methodName, methodFullName, dispatchType)
callAst(call, argumentAsts, base = Option(baseAst), receiver = Option(receiverAst))
}

val fieldAccessCall = callNode(node, code(node), node.methodName, methodFullName, dispatchType)
def determineMemberAccessBase(target: RubyNode): RubyNode = target match {
case MemberAccess(SelfIdentifier(), _, _) => target
case x: SimpleIdentifier =>
scope.getSurroundingType(x.text).map(_.fullName) match {
case Some(surroundingType) =>
val typeName = surroundingType.split('.').last
TypeIdentifier(s"$surroundingType<class>")(x.span.spanStart(typeName))
case None if scope.lookupVariable(x.text).isDefined => x
case None => MemberAccess(SelfIdentifier()(x.span.spanStart(Defines.Self)), ".", x.text)(x.span)
}
case x @ MemberAccess(ma, op, memberName) => x.copy(target = determineMemberAccessBase(ma))(x.span)
case _ => target
}

callAst(fieldAccessCall, argumentAsts, Option(receiver))
node.target match {
case x: SimpleIdentifier if isBundledClass(x.text) =>
createMemberCall(node.copy(target = TypeIdentifier(prefixAsBundledType(x.text))(x.span))(node.span))
case x: SimpleIdentifier =>
createMemberCall(node.copy(target = determineMemberAccessBase(x))(node.span))
case memAccess: MemberAccess =>
createMemberCall(node.copy(target = determineMemberAccessBase(memAccess))(node.span))
case x => createMemberCall(node)
}
}

protected def astForIndexAccess(node: IndexAccess): Ast = {
Expand Down Expand Up @@ -223,10 +274,15 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
val block = blockNode(node)
scope.pushNewScope(BlockScope(block))

val tmp = SimpleIdentifier(Option(className))(node.span.spanStart(tmpGen.fresh))
val tmpName = tmpGen.fresh
val tmpTypeHint = receiverTypeFullName.stripSuffix("<class>")
val tmp = SimpleIdentifier(Option(className))(node.span.spanStart(tmpName))
val tmpLocal = NewLocal().name(tmpName).code(tmpName).dynamicTypeHintFullName(Seq(tmpTypeHint))
scope.addToScope(tmpName, tmpLocal)

def tmpIdentifier = {
val tmpAst = astForSimpleIdentifier(tmp)
tmpAst.root.collect { case x: NewIdentifier => x.typeFullName(receiverTypeFullName.stripSuffix("<class>")) }
tmpAst.root.collect { case x: NewIdentifier => x.typeFullName(tmpTypeHint) }
tmpAst
}

Expand Down Expand Up @@ -255,10 +311,11 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {

val constructorCall = callNode(node, code(node), callName, fullName, DispatchTypes.DYNAMIC_DISPATCH)
val constructorCallAst = callAst(constructorCall, argumentAsts, Option(tmpIdentifier))
val retIdentifierAst = tmpIdentifier
scope.popScope()

// Assemble statements
blockAst(block, tmpAssignment :: constructorCallAst :: tmpIdentifier :: Nil)
blockAst(block, Ast(tmpLocal) :: tmpAssignment :: constructorCallAst :: retIdentifierAst :: Nil)
}

protected def astForSingleAssignment(node: SingleAssignment): Ast = {
Expand Down Expand Up @@ -375,7 +432,6 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {

protected def astForSimpleIdentifier(node: RubyNode & RubyIdentifier): Ast = {
val name = code(node)

if (isBundledClass(name)) {
val typeFullName = prefixAsBundledType(name)
Ast(typeRefNode(node, typeFullName, typeFullName))
Expand All @@ -384,7 +440,10 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
case Some(_) => handleVariableOccurrence(node)
case None if scope.tryResolveMethodInvocation(node.text).isDefined =>
astForSimpleCall(SimpleCall(node, List())(node.span))
case None => handleVariableOccurrence(node)
case None =>
astForMemberAccess(
MemberAccess(SelfIdentifier()(node.span.spanStart(Defines.Self)), ".", node.text)(node.span)
)
}
}
}
Expand Down Expand Up @@ -487,6 +546,8 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {

val block = blockNode(node)
scope.pushNewScope(BlockScope(block))
val tmpLocal = NewLocal().name(tmp).code(tmp)
scope.addToScope(tmp, tmpLocal)

val argumentAsts = node.elements.flatMap(elem =>
elem match
Expand All @@ -507,9 +568,10 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
val assignment =
callNode(node, code(node), Operators.assignment, Operators.assignment, DispatchTypes.STATIC_DISPATCH)
val tmpAssignment = callAst(assignment, tmpAst() :: Ast(hashInitCall) :: Nil)
val tmpRetAst = tmpAst(node.elements.lastOption)

scope.popScope()
blockAst(block, tmpAssignment +: argumentAsts :+ tmpAst(node.elements.lastOption))
blockAst(block, tmpAssignment +: argumentAsts :+ tmpRetAst)
}

protected def astForAssociationHash(node: Association, tmp: String): List[Ast] = {
Expand Down Expand Up @@ -652,7 +714,8 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
}

private def astForSelfIdentifier(node: SelfIdentifier): Ast = {
val thisIdentifier = identifierNode(node, "this", code(node), scope.surroundingTypeFullName.getOrElse(Defines.Any))
val thisIdentifier =
identifierNode(node, Defines.Self, code(node), scope.surroundingTypeFullName.getOrElse(Defines.Any))
Ast(thisIdentifier)
}

Expand All @@ -672,7 +735,6 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
.getOrElse(XDefines.DynamicCallUnknownFullName)
val argumentAsts = node.arguments.map(astForMethodCallArgument)
val call = callNode(node, code(node), methodName, methodFullName, DispatchTypes.DYNAMIC_DISPATCH)

callAst(call, argumentAsts, Some(receiverAst))
}

Expand Down Expand Up @@ -701,20 +763,9 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
else DispatchTypes.DYNAMIC_DISPATCH

val call = callNode(node, code(node), methodName, methodFullName, dispatchType)
val receiverAst = {
val fi = Ast(fieldIdentifierNode(node, call.name, call.name))
val self = Ast(identifierNode(node, Defines.Self, Defines.Self, receiverType))
val baseAccess = callNode(
node,
s"${Defines.Self}.${call.name}",
Operators.fieldAccess,
Operators.fieldAccess,
DispatchTypes.STATIC_DISPATCH,
None,
Option(Defines.Any)
)
callAst(baseAccess, Seq(self, fi))
}
val receiverAst = astForExpression(
MemberAccess(SelfIdentifier()(node.span.spanStart(Defines.Self)), ".", call.name)(node.span)
)
val baseAst = Ast(identifierNode(node, Defines.Self, Defines.Self, receiverType))
callAst(call, argumentAst, Option(baseAst), Option(receiverAst))
}
Expand Down Expand Up @@ -747,19 +798,35 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
Ast.storeInDiffGraph(typeDecl, diffGraph)

methodRef
case selfMethod: SingletonMethodDeclaration =>
// Last element is the method declaration, the prefix methods would be `foo = def foo (...)` pointers in other
// contexts, but this would be empty as a method call argument
val Seq(_, methodDeclAst) = astForSingletonMethodDeclaration(selfMethod)
scope.surroundingTypeFullName.foreach { tfn =>
methodDeclAst.root.collect { case m: NewMethod =>
m.astParentType(NodeTypes.TYPE_DECL).astParentFullName(s"$tfn<class>")
}
}
Ast.storeInDiffGraph(methodDeclAst, diffGraph)
scope.surroundingScopeFullName
.map(s => Ast(methodRefNode(node, selfMethod.span.text, s"$s:${selfMethod.methodName}", Defines.Any)))
.getOrElse(Ast())
case _ => astForExpression(node)
}

private def astForKeywordArgument(assoc: Association): Ast = {
val value = astForExpression(assoc.value)
astForExpression(assoc.key).root match
case Some(keyNode: NewIdentifier) =>
assoc.key match
case keyIdentifier: SimpleIdentifier =>
value.root.collectFirst { case x: ExpressionNew =>
x.argumentName_=(Option(keyNode.name))
x.argumentName_=(Option(keyIdentifier.text))
x.argumentIndex_=(-1)
}
value
case _ => astForExpression(assoc)
case _: StaticLiteral => astForExpression(assoc)
case x =>
logger.warn(s"Not explicitly handled argument association key of type ${x.getClass.getSimpleName}")
astForExpression(assoc)
}

protected def astForFieldAccess(node: MemberAccess): Ast = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th
}
)

val isSurroundedByProgramScope = scope.isSurroundedByProgramScope
if (isConstructor) scope.pushNewScope(ConstructorScope(fullName))
else scope.pushNewScope(MethodScope(fullName, procParamGen.fresh))

Expand Down Expand Up @@ -124,7 +125,7 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th
createMethodTypeBindings(method, refs)

val prefixMemberAst =
if isClosure || scope.isSurroundedByProgramScope then Ast() // program scope members are set elsewhere
if isClosure || isSurroundedByProgramScope then Ast() // program scope members are set elsewhere
else {
// Singleton constructors that initialize @@ fields should have their members linked under the singleton class
val methodMember = scope.surroundingTypeFullName.map {
Expand Down Expand Up @@ -407,7 +408,8 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th
method,
(thisParameterAst +: parameterAsts) ++ anonProcParam,
stmtBlockAst,
methodReturnNode(node, Defines.Any)
methodReturnNode(node, Defines.Any),
newModifierNode(ModifierTypes.VIRTUAL) :: Nil
)
if (addEdge) {
Ast.storeInDiffGraph(_methodAst, diffGraph)
Expand Down
Loading

0 comments on commit 5a4c499

Please sign in to comment.