diff --git a/java/kotlin-extractor2/src/main/kotlin/KotlinFileExtractor.kt b/java/kotlin-extractor2/src/main/kotlin/KotlinFileExtractor.kt index be8516087be5..0d9ba438a276 100644 --- a/java/kotlin-extractor2/src/main/kotlin/KotlinFileExtractor.kt +++ b/java/kotlin-extractor2/src/main/kotlin/KotlinFileExtractor.kt @@ -3002,14 +3002,7 @@ OLD: KE1 tw.writeExprsKotlinType(id, type.kotlinResult.id) binOp(id, dr, callable, enclosingStmt) } - isFunction(target, "kotlin", "Boolean", "not") -> { - val id = tw.getFreshIdLabel() - val type = useType(c.type) - tw.writeExprs_lognotexpr(id, type.javaResult.id, parent, idx) - tw.writeExprsKotlinType(id, type.kotlinResult.id) - unaryopDisp(id) - } - isNumericFunction(target, "inv", "unaryMinus", "unaryPlus") -> { + isNumericFunction(target, "inv") -> { val type = useType(c.type) val id: Label = when (val targetName = target.name.asString()) { @@ -3018,16 +3011,6 @@ OLD: KE1 tw.writeExprs_bitnotexpr(id, type.javaResult.id, parent, idx) id } - "unaryMinus" -> { - val id = tw.getFreshIdLabel() - tw.writeExprs_minusexpr(id, type.javaResult.id, parent, idx) - id - } - "unaryPlus" -> { - val id = tw.getFreshIdLabel() - tw.writeExprs_plusexpr(id, type.javaResult.id, parent, idx) - id - } else -> { logger.errorElement("Unhandled unary target name: $targetName", c) return diff --git a/java/kotlin-extractor2/src/main/kotlin/entities/Expression.kt b/java/kotlin-extractor2/src/main/kotlin/entities/Expression.kt index dd62df86d98c..3cc081261a7e 100644 --- a/java/kotlin-extractor2/src/main/kotlin/entities/Expression.kt +++ b/java/kotlin-extractor2/src/main/kotlin/entities/Expression.kt @@ -3,6 +3,7 @@ package com.github.codeql import com.github.codeql.KotlinFileExtractor.StmtExprParent import org.jetbrains.kotlin.KtNodeTypes import org.jetbrains.kotlin.analysis.api.KaSession +import org.jetbrains.kotlin.analysis.api.resolution.KaCompoundAccessCall import org.jetbrains.kotlin.analysis.api.resolution.KaSimpleFunctionCall import org.jetbrains.kotlin.analysis.api.resolution.KaSuccessCallInfo import org.jetbrains.kotlin.analysis.api.resolution.symbol @@ -10,6 +11,7 @@ import org.jetbrains.kotlin.analysis.api.symbols.KaFunctionSymbol import org.jetbrains.kotlin.analysis.api.types.KaType import org.jetbrains.kotlin.analysis.api.types.KaTypeNullability import org.jetbrains.kotlin.analysis.api.types.symbol +import org.jetbrains.kotlin.lexer.KtToken import org.jetbrains.kotlin.lexer.KtTokens import org.jetbrains.kotlin.name.CallableId import org.jetbrains.kotlin.name.ClassId @@ -241,13 +243,12 @@ private fun KaFunctionSymbol.hasMatchingNames( nullability == null } -private fun KaFunctionSymbol.hasName( +private fun KaFunctionSymbol?.hasName( packageName: String, className: String?, functionName: String ): Boolean { - - return this.hasMatchingNames( + return this != null && this.hasMatchingNames( CallableId( FqName(packageName), if (className == null) null else FqName(className), @@ -257,13 +258,51 @@ private fun KaFunctionSymbol.hasName( } private fun KaFunctionSymbol?.isNumericWithName(functionName: String): Boolean { - return this != null && - (this.hasName("kotlin", "Int", functionName) || - this.hasName("kotlin", "Byte", functionName) || - this.hasName("kotlin", "Short", functionName) || - this.hasName("kotlin", "Long", functionName) || - this.hasName("kotlin", "Float", functionName) || - this.hasName("kotlin", "Double", functionName)) + return this.hasName("kotlin", "Int", functionName) || + this.hasName("kotlin", "Byte", functionName) || + this.hasName("kotlin", "Short", functionName) || + this.hasName("kotlin", "Long", functionName) || + this.hasName("kotlin", "Float", functionName) || + this.hasName("kotlin", "Double", functionName) +} + +context(KaSession) +private fun KotlinFileExtractor.extractPrefixUnaryExpression( + expression: KtPrefixExpression, + callable: Label, + parent: StmtExprParent +) { + val op = expression.operationToken as? KtToken + val target = ((expression.resolveToCall() as? KaSuccessCallInfo)?.call as? KaSimpleFunctionCall)?.symbol + + if (op == KtTokens.PLUS && target.isNumericWithName("unaryPlus")) { + extractUnaryExpression(expression, callable, parent, tw::writeExprs_plusexpr) + } else if (op == KtTokens.MINUS && target.isNumericWithName("unaryMinus")) { + extractUnaryExpression(expression, callable, parent, tw::writeExprs_minusexpr) + } else if (op == KtTokens.EXCL && target.hasName("kotlin", "Boolean", "not")) { + extractUnaryExpression(expression, callable, parent, tw::writeExprs_lognotexpr) + } else { + TODO("Extract as method call") + } +} + +context(KaSession) +private fun KotlinFileExtractor.extractPostfixUnaryExpression( + expression: KtPostfixExpression, + callable: Label, + parent: StmtExprParent +) { + val op = expression.operationToken as? KtToken + val target = + ((expression.resolveToCall() as? KaSuccessCallInfo)?.call as? KaCompoundAccessCall)?.compoundAccess?.operationPartiallyAppliedSymbol?.symbol + + if (op == KtTokens.PLUSPLUS && target.isNumericWithName("inc")) { + extractUnaryExpression(expression, callable, parent, tw::writeExprs_postincexpr) + } else if (op == KtTokens.MINUSMINUS && target.isNumericWithName("dec")) { + extractUnaryExpression(expression, callable, parent, tw::writeExprs_postdecexpr) + } else { + TODO("Extract as method call") + } } context(KaSession) @@ -378,6 +417,29 @@ private fun KotlinFileExtractor.extractBinaryExpression( extractExpressionExpr(expression.right!!, callable, id, 1, exprParent.enclosingStmt) } +context(KaSession) +private fun KotlinFileExtractor.extractUnaryExpression( + expression: KtUnaryExpression, + callable: Label, + parent: StmtExprParent, + extractExpression: ( + id: Label, + typeid: Label, + parent: Label, + idx: Int + ) -> Unit +) { + val id = tw.getFreshIdLabel() + val type = useType(expression.expressionType) + val exprParent = parent.expr(expression, callable) + extractExpression(id, type.javaResult.id, exprParent.parent, exprParent.idx) + tw.writeExprsKotlinType(id, type.kotlinResult.id) + + extractExprContext(id, tw.getLocation(expression), callable, exprParent.enclosingStmt) + extractExpressionExpr(expression.baseExpression!!, callable, id, 0, exprParent.enclosingStmt) +} + + context(KaSession) private fun KotlinFileExtractor.extractExpression( e: KtExpression, @@ -399,6 +461,14 @@ private fun KotlinFileExtractor.extractExpression( extractExpression(e.selectorExpression!!, callable, parent) } + is KtPrefixExpression -> { + extractPrefixUnaryExpression(e, callable, parent) + } + + is KtPostfixExpression -> { + extractPostfixUnaryExpression(e, callable, parent) + } + is KtBinaryExpression -> { extractBinaryExpression(e, callable, parent) }