diff --git a/java/kotlin-extractor2/src/main/kotlin/KotlinFileExtractor.kt b/java/kotlin-extractor2/src/main/kotlin/KotlinFileExtractor.kt index 8234f9e1a87a..9668d74dab40 100644 --- a/java/kotlin-extractor2/src/main/kotlin/KotlinFileExtractor.kt +++ b/java/kotlin-extractor2/src/main/kotlin/KotlinFileExtractor.kt @@ -3515,18 +3515,6 @@ OLD: KE1 isNullable: Boolean? = false ) = isFunction(target, pkgName, className, { it == className }, fName, isNullable) - private fun isNumericFunction(target: IrFunction, fName: String): Boolean { - return isFunction(target, "kotlin", "Int", fName) || - isFunction(target, "kotlin", "Byte", fName) || - isFunction(target, "kotlin", "Short", fName) || - isFunction(target, "kotlin", "Long", fName) || - isFunction(target, "kotlin", "Float", fName) || - isFunction(target, "kotlin", "Double", fName) - } - - private fun isNumericFunction(target: IrFunction, vararg fNames: String) = - fNames.any { isNumericFunction(target, it) } - private fun isArrayType(typeName: String) = when (typeName) { "Array" -> true @@ -3747,11 +3735,6 @@ OLD: KE1 val type = useType(c.type) val id: Label = when (val targetName = target.name.asString()) { - "plus" -> { - val id = tw.getFreshIdLabel() - tw.writeExprs_addexpr(id, type.javaResult.id, parent, idx) - id - } "minus" -> { val id = tw.getFreshIdLabel() tw.writeExprs_subexpr(id, type.javaResult.id, parent, idx) diff --git a/java/kotlin-extractor2/src/main/kotlin/entities/Expression.kt b/java/kotlin-extractor2/src/main/kotlin/entities/Expression.kt index 4592059b1776..97a18aafb58a 100644 --- a/java/kotlin-extractor2/src/main/kotlin/entities/Expression.kt +++ b/java/kotlin-extractor2/src/main/kotlin/entities/Expression.kt @@ -3,6 +3,10 @@ package com.github.codeql import com.github.codeql.KotlinFileExtractor.StmtExprParent import org.jetbrains.kotlin.KtNodeTypes import org.jetbrains.kotlin.analysis.api.KaSession +import org.jetbrains.kotlin.analysis.api.resolution.KaSimpleFunctionCall +import org.jetbrains.kotlin.analysis.api.resolution.KaSuccessCallInfo +import org.jetbrains.kotlin.analysis.api.resolution.symbol +import org.jetbrains.kotlin.analysis.api.symbols.KaFunctionSymbol import org.jetbrains.kotlin.analysis.api.types.KaType import org.jetbrains.kotlin.lexer.KtTokens import org.jetbrains.kotlin.parsing.parseNumericLiteral @@ -211,6 +215,78 @@ OLD: KE1 } */ +private fun isFunction( + symbol: KaFunctionSymbol, + packageName: String, + className: String, + functionName: String +): Boolean { + + return symbol.callableId?.packageName?.asString() == packageName && + symbol.callableId?.className?.asString() == className && + symbol.callableId?.callableName?.asString() == functionName +} + +private fun isNumericFunction(target: KaFunctionSymbol, fName: String): Boolean { + return isFunction(target, "kotlin", "Int", fName) || + isFunction(target, "kotlin", "Byte", fName) || + isFunction(target, "kotlin", "Short", fName) || + isFunction(target, "kotlin", "Long", fName) || + isFunction(target, "kotlin", "Float", fName) || + isFunction(target, "kotlin", "Double", fName) +} + +/** + * Extracts a binary expression as either a binary expression or a function call. + * + * Overloaded operators are extracted as function calls. + * + * ``` + * data class Counter(val dayIndex: Int) { + * operator fun plus(increment: Int): Counter { + * return Counter(dayIndex + increment) + * } + * } + * ``` + * + * `Counter(1) + 3` is extracted as `Counter(1).plus(3)`. + * + */ +context(KaSession) +private fun KotlinFileExtractor.extractBinaryExpression( + expression: KtBinaryExpression, + callable: Label, + parent: StmtExprParent +) { + val op = expression.operationToken + val target = ((expression.resolveToCall() as? KaSuccessCallInfo)?.call as? KaSimpleFunctionCall)?.symbol + + when (op) { + KtTokens.PLUS -> { + if (target == null) { + TODO() + } + + if (isNumericFunction(target, "plus")) { + val id = tw.getFreshIdLabel() + val type = useType(expression.expressionType) + val exprParent = parent.expr(expression, callable) + tw.writeExprs_addexpr(id, type.javaResult.id, exprParent.parent, exprParent.idx) + tw.writeExprsKotlinType(id, type.kotlinResult.id) + + extractExprContext(id, tw.getLocation(expression), callable, exprParent.enclosingStmt) + extractExpressionExpr(expression.left!!, callable, id, 0, exprParent.enclosingStmt) + extractExpressionExpr(expression.right!!, callable, id, 1, exprParent.enclosingStmt) + } else { + TODO() + } + } + + else -> TODO() + } + +} + context(KaSession) private fun KotlinFileExtractor.extractExpression( e: KtExpression, @@ -225,6 +301,10 @@ private fun KotlinFileExtractor.extractExpression( extractExpression(e.baseExpression!!, callable, parent) } + is KtBinaryExpression -> { + extractBinaryExpression(e, callable, parent) + } + is KtIsExpression -> { val locId = tw.getLocation(e)