Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

javasrc2cpg: Handling Generic Types + Type Arguments #2655

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,7 @@ import com.github.javaparser.ast.Node.Parsedness
import com.github.javaparser.symbolsolver.JavaSymbolSolver
import com.github.javaparser.symbolsolver.resolution.typesolvers.JarTypeSolver
import com.github.javaparser.{JavaParser, ParserConfiguration}
import io.joern.javasrc2cpg.passes.{
AstCreationPass,
ConfigFileCreationPass,
JavaTypeHintCallLinker,
JavaTypeRecoveryPass,
TypeInferencePass
}
import io.joern.javasrc2cpg.passes._
import io.joern.javasrc2cpg.typesolvers.{CachingReflectionTypeSolver, EagerSourceTypeSolver, SimpleCombinedTypeSolver}
import io.joern.javasrc2cpg.util.Delombok.DelombokMode
import io.joern.javasrc2cpg.util.{Delombok, SourceRootFinder}
Expand Down Expand Up @@ -85,7 +79,11 @@ class JavaSrc2Cpg extends X2CpgFrontend[Config] {
val astCreationPass = new AstCreationPass(javaparserAsts.analysisAsts, config, cpg, symbolSolver)
astCreationPass.createAndApply()
new ConfigFileCreationPass(config.inputPath, cpg).createAndApply()
new TypeNodePass(astCreationPass.global.usedTypes.keys().asScala.toList, cpg).createAndApply()
new TypeNodePass(
astCreationPass.global.usedTypes.keys().asScala.toList,
cpg,
nodesWithGenericTypes = astCreationPass.global.nodesWithGenericTypes.asScala.toMap
).createAndApply()
new TypeInferencePass(cpg).createAndApply()
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
package io.joern.javasrc2cpg.passes

import com.github.javaparser.symbolsolver.JavaSymbolSolver
import io.joern.javasrc2cpg.{Config, JpAstWithMeta}
import io.joern.x2cpg.datastructures.{CodeTree, Global, TreeNode}
import io.shiftleft.codepropertygraph.Cpg
import io.shiftleft.codepropertygraph.generated.EdgeTypes
import io.shiftleft.codepropertygraph.generated.nodes.{NewNode, NewType, NewTypeDecl, NewTypeParameter}
import io.shiftleft.passes.ConcurrentWriterCpgPass
import io.joern.javasrc2cpg.{Config, JpAstWithMeta}
import io.joern.x2cpg.datastructures.Global
import org.slf4j.LoggerFactory

import scala.collection.mutable
import scala.jdk.CollectionConverters.MapHasAsScala

class AstCreationPass(asts: List[JpAstWithMeta], config: Config, cpg: Cpg, symbolSolver: JavaSymbolSolver)
extends ConcurrentWriterCpgPass[JpAstWithMeta](cpg) {

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package io.joern.javasrc2cpg.passes

import com.github.javaparser.ast.`type`.TypeParameter
import com.github.javaparser.ast.`type`.{ClassOrInterfaceType, Type, TypeParameter}
import com.github.javaparser.ast.{CompilationUnit, Node, NodeList, PackageDeclaration}
import com.github.javaparser.ast.body.{
AnnotationDeclaration,
Expand Down Expand Up @@ -145,12 +145,13 @@ import io.shiftleft.codepropertygraph.generated.nodes.{
NewNamespaceBlock,
NewNode,
NewReturn,
NewType,
NewTypeArgument,
NewTypeDecl,
NewTypeRef
}
import io.joern.x2cpg.{Ast, AstCreatorBase, Defines}
import io.joern.x2cpg.datastructures.Global
import io.joern.x2cpg.passes.frontend.TypeNodePass
import io.joern.x2cpg.datastructures.{Global, JavaTree, TreeNode}
import io.joern.x2cpg.utils.AstPropertiesUtil._
import io.joern.x2cpg.utils.NodeBuilders
import io.joern.x2cpg.AstNodeBuilder
Expand Down Expand Up @@ -1003,7 +1004,10 @@ class AstCreator(filename: String, javaParserAst: CompilationUnit, global: Globa
val expectedReturnType = Try(symbolSolver.toResolvedType(methodDeclaration.getType, classOf[ResolvedType])).toOption
val returnTypeFullName = expectedReturnType
.flatMap(typeInfoCalc.fullName)
.orElse(scopeStack.lookupVariableType(methodDeclaration.getTypeAsString, wildcardFallback = true))
.orElse(
scopeStack.lookupVariableType(methodDeclaration.getTypeAsString.takeWhile(_ != '<'), wildcardFallback = true)
)
.orElse(Option(s"${Defines.UnresolvedNamespace}.${methodDeclaration.getTypeAsString}"))
DavidBakerEffendi marked this conversation as resolved.
Show resolved Hide resolved

scopeStack.pushNewScope(MethodScope(ExpectedType(returnTypeFullName, expectedReturnType)))

Expand Down Expand Up @@ -1036,6 +1040,11 @@ class AstCreator(filename: String, javaParserAst: CompilationUnit, global: Globa
line(methodDeclaration.getType),
column(methodDeclaration.getType)
)
methodDeclaration.getType match {
case x: ClassOrInterfaceType if x.getTypeArguments.isPresent =>
global.nodesWithGenericTypes.put(methodReturn, astForGenericType(x))
case _ =>
}

val annotationAsts = methodDeclaration.getAnnotations.asScala.map(astForAnnotationExpr).toSeq

Expand Down Expand Up @@ -2080,6 +2089,41 @@ class AstCreator(filename: String, javaParserAst: CompilationUnit, global: Globa
}
}

private def typeToTypeArgument(x: Type): TreeNode = {
val typeWithoutGeneric = x.asString().takeWhile(_ != '<')
val typeFullName = typeInfoCalc
.fullName(x)
.orElse(scopeStack.lookupVariableType(typeWithoutGeneric))
.orElse(scopeStack.lookupVariableType(typeWithoutGeneric, wildcardFallback = true))
.getOrElse(s"${Defines.UnresolvedNamespace}.$typeWithoutGeneric")
x match {
case t: ClassOrInterfaceType if t.getTypeArguments.isPresent =>
TreeNode(typeFullName)
.withChildren(astForTypeArgument(t.getTypeArguments.get().asScala.toList))
case _ =>
TreeNode(typeFullName)
}
}

private def astForTypeArgument(xs: List[Type]): List[TreeNode] = xs match {
case head :: next => typeToTypeArgument(head) +: astForTypeArgument(next)
case Nil => List.empty
}

private def astForGenericType(x: ClassOrInterfaceType): JavaTree = {
val typeArguments =
if (x.getTypeArguments.isPresent)
astForTypeArgument(x.getTypeArguments.get().asScala.toList)
else List.empty
val typeWithoutGeneric = x.asString().takeWhile(_ != '<')
val typeFullName = typeInfoCalc
.fullName(x)
.orElse(scopeStack.lookupVariableType(typeWithoutGeneric))
.orElse(scopeStack.lookupVariableType(typeWithoutGeneric, wildcardFallback = true))
.getOrElse(s"${Defines.UnresolvedNamespace}.$typeWithoutGeneric")
new JavaTree(io.joern.x2cpg.datastructures.TreeNode(typeFullName).withChildren(typeArguments))
}

private def assignmentsForVarDecl(
variables: Iterable[VariableDeclarator],
lineNumber: Option[Integer],
Expand All @@ -2104,11 +2148,7 @@ class AstCreator(filename: String, javaParserAst: CompilationUnit, global: Globa
// Need the actual resolvedType here for when the RHS is a lambda expression.
val resolvedExpectedType = Try(symbolSolver.toResolvedType(variable.getType, classOf[ResolvedType])).toOption
val initializerAsts = astsForExpression(initializer, ExpectedType(typeFullName, resolvedExpectedType))

val typeName = typeFullName
.map(TypeNodePass.fullToShortName)
.getOrElse(s"${Defines.UnresolvedNamespace}.${variable.getTypeAsString}")
val code = s"$typeName $name = ${initializerAsts.rootCodeOrEmpty}"
val code = s"${variable.getTypeAsString} $name = ${initializerAsts.rootCodeOrEmpty}"

val callNode = newOperatorCallNode(Operators.assignment, code, typeFullName, lineNumber, columnNumber)

Expand All @@ -2119,6 +2159,11 @@ class AstCreator(filename: String, javaParserAst: CompilationUnit, global: Globa

case maybeCorrespNode =>
val identifier = identifierNode(variable, name, name, typeFullName.getOrElse(TypeConstants.Any))
variable.getType match {
case x: ClassOrInterfaceType if x.getTypeArguments.isPresent =>
global.nodesWithGenericTypes.put(identifier, astForGenericType(x))
case _ =>
}
Ast(identifier).withRefEdges(identifier, maybeCorrespNode.map(_.node).toList)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,11 @@ class GenericsTests extends JavaSrcCode2CpgFixture {
|public class Test extends Box<String> {}
|""".stripMargin)

"it should create the correct generic typeDecl name" in {
"it should create the correct generic typeDecls, each with a simple name and one with the arguments" in {
cpg.typeDecl.nameExact("Box").l match {
case decl :: Nil => decl.fullName shouldBe "Box"
case decl1 :: decl2 :: Nil =>
decl1.fullName shouldBe "Box"
decl2.fullName shouldBe "Box<java.lang.Object>"

case res => fail(s"Expected typeDecl Box but got $res")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ package io.joern.javasrc2cpg.querying

import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture
import io.joern.x2cpg.Defines
import io.joern.x2cpg.datastructures.TreeNode
import io.shiftleft.codepropertygraph.generated.DispatchTypes
import io.shiftleft.codepropertygraph.generated.nodes.{Identifier, Literal}
import io.shiftleft.codepropertygraph.generated.nodes.{Identifier, Literal, TypeArgument}
import io.shiftleft.semanticcpg.language._

import java.io.File
Expand Down Expand Up @@ -210,6 +211,10 @@ class JavaTypeRecoveryPassTests extends JavaSrcCode2CpgFixture(enableTypeRecover
|package net.javaguides.hibernate;
|
|import java.util.List;
|import java.util.Map;
|import java.lang.Integer;
|import java.lang.Long;
|import java.lang.String;
|
|import org.hibernate.Session;
|import org.hibernate.Transaction;
Expand All @@ -235,8 +240,11 @@ class JavaTypeRecoveryPassTests extends JavaSrcCode2CpgFixture(enableTypeRecover
| transaction.rollback();
| }
| }
|
| }
|
| public List<Map<String, Integer>> foo() {
| return new List<>();
| }
|}
|""".stripMargin,
Seq("net", "javaguides", "hibernate", "NamedQueryExample.java").mkString(File.separator)
Expand All @@ -254,6 +262,25 @@ class JavaTypeRecoveryPassTests extends JavaSrcCode2CpgFixture(enableTypeRecover
transaction.typeFullName shouldBe "org.hibernate.Transaction"
transaction.dynamicTypeHintFullName.contains("null")
}

"present type arguments to generic types if known" in {
// List
// | Long
val Some(totalStudents) = cpg.identifier.nameExact("totalStudents").headOption
val List(list) = totalStudents.evalTypeOut.l
list.name shouldBe "List"
list.fullName shouldBe "java.util.List<java.lang.Long>"
}

"present (nested) type arguments to method returns" in {
// List
// | Map
// | String | Integer
val Some(fooReturn) = cpg.method("foo").methodReturn.headOption
val List(list) = fooReturn.evalTypeOut.l
list.name shouldBe "List"
list.fullName shouldBe "java.util.List<java.util.Map<java.lang.String, java.lang.Integer>>"
}
}

}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,49 @@
package io.joern.x2cpg.datastructures

import io.shiftleft.codepropertygraph.generated.nodes.NewNode

import java.util.concurrent.ConcurrentHashMap

class Global {

val usedTypes: ConcurrentHashMap[String, Boolean] = new ConcurrentHashMap()

val nodesWithGenericTypes: ConcurrentHashMap[NewNode, CodeTree] = new ConcurrentHashMap()

}

case class TreeNode(value: String, children: List[TreeNode] = List.empty) {

def withChildren(children: List[TreeNode]): TreeNode = this.copy(children = this.children ++ children)

override def toString: String = value
}

abstract class CodeTree(val root: TreeNode) {

protected val separator: String
protected val lbracket: String
protected val rbracket: String

// Lazy load the code tree string
private lazy val treeString = _toString(List(root))

override def toString: String = treeString

private def _toString(xs: List[TreeNode]): String = xs match {
case head :: Nil if head.children.nonEmpty =>
head.toString + lbracket + _toString(head.children) + rbracket
case head :: next if head.children.nonEmpty =>
head.toString + lbracket + _toString(head.children) + rbracket + separator + _toString(next)
case head :: Nil => head.toString
case head :: next => head.toString + separator + _toString(next)
case Nil => ""
}

}

final class JavaTree(root: TreeNode) extends CodeTree(root) {
override protected val separator: String = ", "
override protected val lbracket: String = "<"
override protected val rbracket: String = ">"
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
package io.joern.x2cpg.passes.frontend

import io.joern.x2cpg.datastructures.{CodeTree, TreeNode}
import io.joern.x2cpg.passes.frontend.TypeNodePass.fullToShortName
import io.shiftleft.codepropertygraph.Cpg
import io.shiftleft.codepropertygraph.generated.nodes.NewType
import io.shiftleft.passes.{KeyPool, CpgPass}
import io.shiftleft.codepropertygraph.generated.{EdgeTypes, NodeTypes}
import io.shiftleft.codepropertygraph.generated.nodes.{NewNode, NewType, NewTypeDecl, NewTypeParameter}
import io.shiftleft.passes.{CpgPass, KeyPool}

import scala.collection.mutable

/** Creates a `TYPE` node for each type in `usedTypes`
*/
class TypeNodePass(usedTypes: List[String], cpg: Cpg, keyPool: Option[KeyPool] = None)
extends CpgPass(cpg, "types", keyPool) {
class TypeNodePass(
usedTypes: List[String],
cpg: Cpg,
keyPool: Option[KeyPool] = None,
nodesWithGenericTypes: Map[NewNode, CodeTree] = Map.empty
) extends CpgPass(cpg, "types", keyPool) {

override def run(diffGraph: DiffGraphBuilder): Unit = {

Expand All @@ -27,7 +35,54 @@ class TypeNodePass(usedTypes: List[String], cpg: Cpg, keyPool: Option[KeyPool] =
.typeDeclFullName(typeName)
diffGraph.addNode(node)
}

generateGenericTypes(diffGraph, nodesWithGenericTypes)
}

private def generateGenericTypes(diffGraph: DiffGraphBuilder, nodesWithGenericTypes: Map[NewNode, CodeTree]): Unit = {

def treeNodeToTypeParameter(x: TreeNode): NewTypeParameter = {
val typeParameter = NewTypeParameter().name(x.value).code(x.value)
diffGraph.addNode(typeParameter)
generateTypeParametersFromChildren(x.children).foreach(tp => diffGraph.addEdge(typeParameter, tp, EdgeTypes.AST))
typeParameter
}

def generateTypeParametersFromChildren(xs: List[TreeNode]): List[NewTypeParameter] = xs match {
case head :: Nil => List(treeNodeToTypeParameter(head))
case head :: next => treeNodeToTypeParameter(head) +: generateTypeParametersFromChildren(next)
case Nil => List.empty
}

def generateTypeNodeFromTree(tree: CodeTree): NewType = {
val fullName = tree.toString
val shortType = tree.root.value match {
case t if t.contains('.') && !t.endsWith(".") => t.substring(t.lastIndexOf('.') + 1)
case t => t
}
val typeNode = NewType().name(shortType).fullName(fullName).typeDeclFullName(fullName)
val typeDecl = NewTypeDecl()
.name(shortType)
.fullName(fullName)
.astParentType(NodeTypes.NAMESPACE_BLOCK)
.astParentFullName("ANY")
diffGraph.addNode(typeNode).addNode(typeDecl).addEdge(typeNode, typeDecl, EdgeTypes.REF)
// TODO: How to do TYPE->TYPE_ARGUMENT or TYPE_DECL->TYPE_PARAMETER?
//
// generateTypeParametersFromChildren(tree.root.children).foreach(ta =>
// diffGraph.addEdge(typeDecl, ta, EdgeTypes.AST)
// )
typeNode
}

val typeToNode = mutable.HashMap.empty[String, NewType]

nodesWithGenericTypes.foreach { case (node, tree) =>
val associatedTypeNode = typeToNode.getOrElseUpdate(tree.toString, generateTypeNodeFromTree(tree))
diffGraph.addEdge(node, associatedTypeNode, EdgeTypes.EVAL_TYPE)
}
}

}

object TypeNodePass {
Expand Down