Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

javasrc2cpg: Handling Generic Types + Type Arguments #2655

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package io.joern.javasrc2cpg.passes

import com.github.javaparser.ast.`type`.TypeParameter
import com.github.javaparser.ast.`type`.{ClassOrInterfaceType, Type, TypeParameter}
import com.github.javaparser.ast.{CompilationUnit, Node, NodeList, PackageDeclaration}
import com.github.javaparser.ast.body.{
AnnotationDeclaration,
Expand Down Expand Up @@ -145,8 +145,11 @@ import io.shiftleft.codepropertygraph.generated.nodes.{
NewNamespaceBlock,
NewNode,
NewReturn,
NewType,
NewTypeArgument,
NewTypeDecl,
NewTypeRef
NewTypeRef,
NewUnknown
}
import io.joern.x2cpg.{Ast, AstCreatorBase, Defines}
import io.joern.x2cpg.datastructures.Global
Expand Down Expand Up @@ -698,7 +701,7 @@ class AstCreator(filename: String, javaParserAst: CompilationUnit, global: Globa

val modifiers = List(newModifierNode(ModifierTypes.CONSTRUCTOR), newModifierNode(ModifierTypes.PUBLIC))

methodAstWithAnnotations(constructorNode, Seq(thisAst), bodyAst, returnNode, modifiers)
methodAstWithAnnotations(constructorNode, Seq(thisAst), bodyAst, Ast(returnNode), modifiers)
}

private def astForEnumEntry(entry: EnumConstantDeclaration): Ast = {
Expand Down Expand Up @@ -820,7 +823,7 @@ class AstCreator(filename: String, javaParserAst: CompilationUnit, global: Globa
constructorNode,
thisAst :: parameterAsts,
bodyAst,
methodReturn,
Ast(methodReturn),
modifiers,
annotationAsts
)
Expand Down Expand Up @@ -1003,7 +1006,15 @@ class AstCreator(filename: String, javaParserAst: CompilationUnit, global: Globa
val expectedReturnType = Try(symbolSolver.toResolvedType(methodDeclaration.getType, classOf[ResolvedType])).toOption
val returnTypeFullName = expectedReturnType
.flatMap(typeInfoCalc.fullName)
.orElse(scopeStack.lookupVariableType(methodDeclaration.getTypeAsString, wildcardFallback = true))
.orElse(
scopeStack.lookupVariableType(methodDeclaration.getTypeAsString.takeWhile(_ != '<'), wildcardFallback = true)
)
.orElse(Option(s"${Defines.UnresolvedNamespace}.${methodDeclaration.getTypeAsString}"))
DavidBakerEffendi marked this conversation as resolved.
Show resolved Hide resolved
val typeNode = methodDeclaration.getType match {
case x: ClassOrInterfaceType if x.getTypeArguments.isPresent =>
astForGenericType(x)
case _ => Ast() // This will be created by some TypePass
}

scopeStack.pushNewScope(MethodScope(ExpectedType(returnTypeFullName, expectedReturnType)))

Expand Down Expand Up @@ -1036,14 +1047,18 @@ class AstCreator(filename: String, javaParserAst: CompilationUnit, global: Globa
line(methodDeclaration.getType),
column(methodDeclaration.getType)
)
val methodReturnAst = typeNode.root match {
case Some(t) => Ast(methodReturn).withEvalTypeEdge(methodReturn, t)
case None => Ast(methodReturn)
}

val annotationAsts = methodDeclaration.getAnnotations.asScala.map(astForAnnotationExpr).toSeq

val modifiers = modifiersForMethod(methodDeclaration)

scopeStack.popScope()

methodAstWithAnnotations(methodNode, thisAst ++ parameterAsts, bodyAst, methodReturn, modifiers, annotationAsts)
methodAstWithAnnotations(methodNode, thisAst ++ parameterAsts, bodyAst, methodReturnAst, modifiers, annotationAsts)
}

private def constructorReturnNode(constructorDeclaration: ConstructorDeclaration): NewMethodReturn = {
Expand Down Expand Up @@ -2080,6 +2095,45 @@ class AstCreator(filename: String, javaParserAst: CompilationUnit, global: Globa
}
}

private def typeToTypeArgument(x: Type): Ast = {
val typeWithoutGeneric = x.asString().takeWhile(_ != '<')
val typeFullName = typeInfoCalc
.fullName(x)
.orElse(scopeStack.lookupVariableType(typeWithoutGeneric))
.orElse(scopeStack.lookupVariableType(typeWithoutGeneric, wildcardFallback = true))
.getOrElse(typeWithoutGeneric)
x match {
case t: ClassOrInterfaceType if t.getTypeArguments.isPresent =>
Ast(NewTypeArgument().code(typeFullName).lineNumber(line(x)).columnNumber(column(x)))
.withChildren(astForTypeArgument(t.getTypeArguments.get().asScala.toList))
case _ =>
Ast(NewTypeArgument().code(typeFullName).lineNumber(line(x)).columnNumber(column(x)))
}
}

private def astForTypeArgument(xs: List[Type]): Seq[Ast] = xs match {
case head :: next => typeToTypeArgument(head) +: astForTypeArgument(next)
case Nil => Seq.empty
}

private def astForGenericType(x: ClassOrInterfaceType): Ast = {
val typeArguments =
if (x.getTypeArguments.isPresent)
astForTypeArgument(x.getTypeArguments.get().asScala.toList)
else Seq.empty
val typeWithoutGeneric = x.asString().takeWhile(_ != '<')
val typeFullName = typeInfoCalc
.fullName(x)
.orElse(scopeStack.lookupVariableType(typeWithoutGeneric))
.orElse(scopeStack.lookupVariableType(typeWithoutGeneric, wildcardFallback = true))
.getOrElse(typeWithoutGeneric)
val typeName = typeFullName match {
case t if t.contains(".") && !t.endsWith(".") => t.substring(t.lastIndexOf('.') + 1)
case t => t
}
Ast(NewType().name(typeName).fullName(typeFullName)).withChildren(typeArguments)
}

private def assignmentsForVarDecl(
variables: Iterable[VariableDeclarator],
lineNumber: Option[Integer],
Expand All @@ -2104,22 +2158,28 @@ class AstCreator(filename: String, javaParserAst: CompilationUnit, global: Globa
// Need the actual resolvedType here for when the RHS is a lambda expression.
val resolvedExpectedType = Try(symbolSolver.toResolvedType(variable.getType, classOf[ResolvedType])).toOption
val initializerAsts = astsForExpression(initializer, ExpectedType(typeFullName, resolvedExpectedType))

val typeName = typeFullName
.map(TypeNodePass.fullToShortName)
.getOrElse(s"${Defines.UnresolvedNamespace}.${variable.getTypeAsString}")
val code = s"$typeName $name = ${initializerAsts.rootCodeOrEmpty}"
val code = s"${variable.getTypeAsString} $name = ${initializerAsts.rootCodeOrEmpty}"

val callNode = newOperatorCallNode(Operators.assignment, code, typeFullName, lineNumber, columnNumber)

val typeNode = variable.getType match {
case x: ClassOrInterfaceType if x.getTypeArguments.isPresent =>
astForGenericType(x)
case _ => Ast() // This will be created by the TypeUsagePass
}

val targetAst = scopeStack.lookupVariable(name) match {
case Some(nodeTypeInfo) if nodeTypeInfo.isField && !nodeTypeInfo.isStatic =>
val thisType = scopeStack.getEnclosingTypeDecl.map(_.fullName)
fieldAccessAst(NameConstants.This, thisType, name, typeFullName, line(variable), column(variable))

case maybeCorrespNode =>
val identifier = identifierNode(variable, name, name, typeFullName.getOrElse(TypeConstants.Any))
Ast(identifier).withRefEdges(identifier, maybeCorrespNode.map(_.node).toList)
val identifier = identifierNode(variable, name, name, typeFullName.getOrElse(TypeConstants.Any))
val identifierAst = Ast(identifier).withRefEdges(identifier, maybeCorrespNode.map(_.node).toList)
typeNode.root match {
case Some(t) => identifierAst.withEvalTypeEdge(identifier, t)
case None => identifierAst
}
}

// Since all partial constructors will be dealt with here, don't pass them up.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package io.joern.javasrc2cpg.querying
import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture
import io.joern.x2cpg.Defines
import io.shiftleft.codepropertygraph.generated.DispatchTypes
import io.shiftleft.codepropertygraph.generated.nodes.{Identifier, Literal}
import io.shiftleft.codepropertygraph.generated.nodes.{Identifier, Literal, TypeArgument}
import io.shiftleft.semanticcpg.language._

import java.io.File
Expand Down Expand Up @@ -210,6 +210,10 @@ class JavaTypeRecoveryPassTests extends JavaSrcCode2CpgFixture(enableTypeRecover
|package net.javaguides.hibernate;
|
|import java.util.List;
|import java.util.Map;
|import java.lang.Integer;
|import java.lang.Long;
|import java.lang.String;
|
|import org.hibernate.Session;
|import org.hibernate.Transaction;
Expand All @@ -235,8 +239,11 @@ class JavaTypeRecoveryPassTests extends JavaSrcCode2CpgFixture(enableTypeRecover
| transaction.rollback();
| }
| }
|
| }
|
| public List<Map<String, Integer>> foo() {
| return new List<>();
| }
|}
|""".stripMargin,
Seq("net", "javaguides", "hibernate", "NamedQueryExample.java").mkString(File.separator)
Expand All @@ -254,6 +261,32 @@ class JavaTypeRecoveryPassTests extends JavaSrcCode2CpgFixture(enableTypeRecover
transaction.typeFullName shouldBe "org.hibernate.Transaction"
transaction.dynamicTypeHintFullName.contains("null")
}

"present type arguments to generic types if known" in {
// List
// | Long
val Some(totalStudents) = cpg.identifier.nameExact("totalStudents").headOption
val List(list) = totalStudents.evalTypeOut.l
list.name shouldBe "List"
list.fullName shouldBe "java.util.List"
val List(long) = list.astOut.l
long.code shouldBe "java.lang.Long"
}

"present (nested) type arguments to method returns" in {
// List
// | Map
// | String | Integer
val Some(fooReturn) = cpg.method("foo").methodReturn.headOption
val List(list) = fooReturn.evalTypeOut.l
list.name shouldBe "List"
list.fullName shouldBe "java.util.List"
val List(map) = list.astOut.collectAll[TypeArgument].l
map.code shouldBe "java.util.Map"
val List(string, integer) = map._astOut.collectAll[TypeArgument].l
string.code shouldBe "java.lang.String"
integer.code shouldBe "java.lang.Integer"
}
}

}
Expand Down
25 changes: 21 additions & 4 deletions joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Ast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ object Ast {
ast.bindsEdges.foreach { edge =>
diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.BINDS)
}

ast.evalTypeEdges.foreach { edge =>
diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.EVAL_TYPE)
}
}

/** For all `order` fields that are unset, derive the `order` field automatically by determining the position of the
Expand Down Expand Up @@ -86,7 +90,8 @@ case class Ast(
refEdges: collection.Seq[AstEdge] = Vector.empty,
bindsEdges: collection.Seq[AstEdge] = Vector.empty,
receiverEdges: collection.Seq[AstEdge] = Vector.empty,
argEdges: collection.Seq[AstEdge] = Vector.empty
argEdges: collection.Seq[AstEdge] = Vector.empty,
evalTypeEdges: collection.Seq[AstEdge] = Vector.empty
) {

def root: Option[NewNode] = nodes.headOption
Expand All @@ -107,7 +112,8 @@ case class Ast(
argEdges = argEdges ++ other.argEdges,
receiverEdges = receiverEdges ++ other.receiverEdges,
refEdges = refEdges ++ other.refEdges,
bindsEdges = bindsEdges ++ other.bindsEdges
bindsEdges = bindsEdges ++ other.bindsEdges,
evalTypeEdges = evalTypeEdges ++ other.evalTypeEdges
)
}

Expand All @@ -119,7 +125,8 @@ case class Ast(
argEdges = argEdges ++ other.argEdges,
receiverEdges = receiverEdges ++ other.receiverEdges,
refEdges = refEdges ++ other.refEdges,
bindsEdges = bindsEdges ++ other.bindsEdges
bindsEdges = bindsEdges ++ other.bindsEdges,
evalTypeEdges = evalTypeEdges ++ other.evalTypeEdges
)
}

Expand Down Expand Up @@ -154,6 +161,10 @@ case class Ast(
this.copy(receiverEdges = receiverEdges ++ List(AstEdge(src, dst)))
}

def withEvalTypeEdge(src: NewNode, dst: NewNode): Ast = {
this.copy(evalTypeEdges = evalTypeEdges ++ List(AstEdge(src, dst)))
}

def withArgEdge(src: NewNode, dst: NewNode): Ast = {
this.copy(argEdges = argEdges ++ List(AstEdge(src, dst)))
}
Expand Down Expand Up @@ -200,6 +211,10 @@ case class Ast(
this.copy(receiverEdges = receiverEdges ++ dsts.map(AstEdge(src, _)))
}

def withEvalTypeEdges(src: NewNode, dsts: List[NewNode]): Ast = {
this.copy(evalTypeEdges = evalTypeEdges ++ dsts.map(AstEdge(src, _)))
}

/** Returns a deep copy of the sub tree rooted in `node`. If `order` is set, then the `order` and `argumentIndex`
* fields of the new root node are set to `order`.
*/
Expand Down Expand Up @@ -229,14 +244,16 @@ case class Ast(
val newRefEdges = refEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst)))
val newBindsEdges = bindsEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst)))
val newReceiverEdges = receiverEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst)))
val newEvalTypeEdges = evalTypeEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst)))

Ast(newNode)
.copy(
argEdges = newArgEdges,
conditionEdges = newConditionEdges,
refEdges = newRefEdges,
bindsEdges = newBindsEdges,
receiverEdges = newReceiverEdges
receiverEdges = newReceiverEdges,
evalTypeEdges = newEvalTypeEdges
)
.withChildren(newChildren)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ abstract class AstCreatorBase(filename: String) {
methodReturn: NewMethodReturn,
modifiers: Seq[NewModifier] = Nil
): Ast =
methodAstWithAnnotations(method, parameters, body, methodReturn, modifiers, annotations = Nil)
methodAstWithAnnotations(method, parameters, body, Ast(methodReturn), modifiers, annotations = Nil)

/** Creates an AST that represents an entire method, including its content and with support for both method and
* parameter annotations.
Expand All @@ -62,7 +62,7 @@ abstract class AstCreatorBase(filename: String) {
method: NewMethod,
parameters: Seq[Ast],
body: Ast,
methodReturn: NewMethodReturn,
methodReturn: Ast,
modifiers: Seq[NewModifier] = Nil,
annotations: Seq[Ast] = Nil
): Ast =
Expand All @@ -71,7 +71,7 @@ abstract class AstCreatorBase(filename: String) {
.withChild(body)
.withChildren(modifiers.map(Ast(_)))
.withChildren(annotations)
.withChild(Ast(methodReturn))
.withChild(methodReturn)

/** Creates an AST that represents a method stub, containing information about the method, its parameters, and the
* return type.
Expand Down