Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize resolving pydantic class #658

Merged
merged 13 commits into from
Mar 1, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## Unreleased
- Fix wrong inspections when a model has a __call__ method [[#655](https://github.com/koxudaxi/pydantic-pycharm-plugin/pull/655)]
- Reduce unnecessary resolve in type providers [[#656](https://github.com/koxudaxi/pydantic-pycharm-plugin/pull/656)]
- Optimize resolving pydantic class [[#658](https://github.com/koxudaxi/pydantic-pycharm-plugin/pull/658)]

## 0.3.17 - 2022-12-16
- Support Union operator [[#602](https://github.com/koxudaxi/pydantic-pycharm-plugin/pull/602)]
Expand Down
73 changes: 43 additions & 30 deletions src/com/koxudaxi/pydantic/Pydantic.kt
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ val CUSTOM_BASE_MODEL_Q_NAMES = listOf(
val CUSTOM_MODEL_FIELD_Q_NAMES = listOf(
SQL_MODEL_FIELD_Q_NAME
)

val DATA_CLASS_Q_NAMES = listOf(DATA_CLASS_Q_NAME, DATA_CLASS_SHORT_Q_NAME)

val VERSION_QUALIFIED_NAME = QualifiedName.fromDottedString(VERSION_Q_NAME)

val BASE_CONFIG_QUALIFIED_NAME = QualifiedName.fromDottedString(BASE_CONFIG_Q_NAME)
Expand Down Expand Up @@ -135,27 +138,14 @@ const val CUSTOM_ROOT_FIELD = "__root__"

fun PyTypedElement.getType(context: TypeEvalContext): PyType? = context.getType(this)

fun getPyClassByPyCallExpression(
pyCallExpression: PyCallExpression,

fun getPydanticModelByPyKeywordArgument(
pyKeywordArgument: PyKeywordArgument,
includeDataclass: Boolean,
context: TypeEvalContext,
): PyClass? {
val callee = pyCallExpression.callee ?: return null
val pyType = when (val type = callee.getType(context)) {
is PyClass -> return type
is PyClassType -> type
else -> (callee.reference?.resolve() as? PyTypedElement)?.getType(context) ?: return null
}
return pyType.pyClassTypes.firstOrNull {
isPydanticModel(it.pyClass,
includeDataclass,
context)
}?.pyClass
}

fun getPyClassByPyKeywordArgument(pyKeywordArgument: PyKeywordArgument, context: TypeEvalContext): PyClass? {
val pyCallExpression = PsiTreeUtil.getParentOfType(pyKeywordArgument, PyCallExpression::class.java) ?: return null
return getPyClassByPyCallExpression(pyCallExpression, true, context)
return getPydanticPyClass(pyCallExpression, context, includeDataclass)
}

fun isPydanticModel(pyClass: PyClass, includeDataclass: Boolean, context: TypeEvalContext): Boolean {
Expand Down Expand Up @@ -228,6 +218,7 @@ internal val PyClass.isConfigClass: Boolean get() = name == "Config"

internal val PyFunction.isConStr: Boolean get() = qualifiedName == CON_STR_Q_NAME

internal val PyFunction.isPydanticDataclass: Boolean get() = qualifiedName in DATA_CLASS_Q_NAMES
internal fun isPydanticRegex(stringLiteralExpression: StringLiteralExpression): Boolean {
val pyKeywordArgument = stringLiteralExpression.parent as? PyKeywordArgument ?: return false
if (pyKeywordArgument.keyword != "regex") return false
Expand Down Expand Up @@ -270,14 +261,14 @@ private fun getAliasedFieldName(

fun getResolvedPsiElements(referenceExpression: PyReferenceExpression, context: TypeEvalContext): List<PsiElement> {
return RecursionManager.doPreventingRecursion(
Pair.create<PsiElement, TypeEvalContext>(
Pair.create(
referenceExpression,
context
), false
) {
PyUtil.multiResolveTopPriority(
referenceExpression,
PyResolveContext.defaultContext(context)
val resolveContext = PyResolveContext.defaultContext(context)
PyUtil.filterTopPriorityResults(
referenceExpression.getReference(resolveContext).multiResolve(false)
)
} ?: emptyList()
}
Expand Down Expand Up @@ -494,6 +485,9 @@ fun getPyClassByAttribute(pyPsiElement: PsiElement?): PyClass? {
return pyPsiElement?.parent?.parent as? PyClass
}

fun getPydanticModelByAttribute(pyPsiElement: PsiElement?, includeDataclass: Boolean, context: TypeEvalContext): PyClass? =
getPyClassByAttribute(pyPsiElement)?.takeIf { isPydanticModel(it, includeDataclass, context) }

fun createPyClassTypeImpl(qualifiedName: String, project: Project, context: TypeEvalContext): PyClassTypeImpl? {
var psiElement = getPsiElementByQualifiedName(QualifiedName.fromDottedString(qualifiedName), project, context)
if (psiElement == null) {
Expand All @@ -504,11 +498,13 @@ fun createPyClassTypeImpl(qualifiedName: String, project: Project, context: Type
return PyClassTypeImpl.createTypeByQName(psiElement, qualifiedName, false)
}

fun getPydanticPyClass(pyCallExpression: PyCallExpression, context: TypeEvalContext, includeDataclass: Boolean = false): PyClass? {
val pyClass = getPyClassByPyCallExpression(pyCallExpression, includeDataclass, context) ?: return null
if (!isPydanticModel(pyClass, includeDataclass, context)) return null
return pyClass
}
fun getPydanticPyClass(pyTypedElement: PyTypedElement, context: TypeEvalContext, includeDataclass: Boolean = false): PyClass? =
getPydanticPyClassType(pyTypedElement, context, includeDataclass)?.pyClass

fun getPydanticPyClassType(pyTypedElement: PyTypedElement, context: TypeEvalContext, includeDataclass: Boolean = false): PyClassType? =
context.getType(pyTypedElement)?.pyClassTypes?.firstOrNull {
isPydanticModel(it.pyClass, includeDataclass, context)
}

fun getAncestorPydanticModels(pyClass: PyClass, includeDataclass: Boolean, context: TypeEvalContext): List<PyClass> {
return pyClass.getAncestorClasses(context).filter { isPydanticModel(it, includeDataclass, context) }
Expand All @@ -535,15 +531,27 @@ fun addKeywordArgument(pyCallExpression: PyCallExpression, pyKeywordArgument: Py
}
}

val PyExpression.isKeywordArgument: Boolean get() =
this is PyKeywordArgument || (this as? PyStarArgument)?.isKeyword == true

fun getPydanticUnFilledArguments(
pydanticType: PyCallableType,
pyCallExpression: PyCallExpression,
context: TypeEvalContext,
isDataClass: Boolean
): List<PyCallableParameter> {
val currentArguments =
pyCallExpression.arguments.filter { it is PyKeywordArgument || (it as? PyStarArgument)?.isKeyword == true }
.mapNotNull { it.name }.toSet()
return pydanticType.getParameters(context)?.filterNot { currentArguments.contains(it.name) } ?: emptyList()
val parameters = pydanticType.getParameters(context)?.let { allParameters ->
if (isDataClass) {
pyCallExpression.arguments
.filterNot { it.isKeywordArgument }
.let { allParameters.drop(it.size) }
} else {
allParameters
}
} ?: listOf()

val currentArguments = pyCallExpression.arguments.filter { it.isKeywordArgument }.mapNotNull { it.name }.toSet()
return parameters.filterNot { currentArguments.contains(it.name) }
}

val PyCallableParameter.required: Boolean
Expand Down Expand Up @@ -659,3 +667,8 @@ fun getPydanticModelInit(pyClass: PyClass, context: TypeEvalContext): PyFunction

fun PyCallExpression.isDefinitionCallExpression(context: TypeEvalContext): Boolean =
this.callee?.reference?.resolve()?.let { it as? PyClass }?.getType(context)?.isDefinition == true

fun PyCallExpression.getPyCallableType(context: TypeEvalContext): PyCallableType? =
this.callee?.getType(context) as? PyCallableType
fun PyCallableType.getPydanticModel(includeDataclass: Boolean, context: TypeEvalContext): PyClass? =
this.getReturnType(context)?.pyClassTypes?.firstOrNull()?.pyClass?.takeIf { isPydanticModel(it,includeDataclass, context) }
12 changes: 6 additions & 6 deletions src/com/koxudaxi/pydantic/PydanticAnnotator.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,27 @@ import com.intellij.openapi.util.TextRange
import com.intellij.util.containers.nullize
import com.jetbrains.python.psi.PyCallExpression
import com.jetbrains.python.psi.PyStarArgument
import com.jetbrains.python.psi.types.PyCallableType
import com.jetbrains.python.psi.types.TypeEvalContext
import com.jetbrains.python.validation.PyAnnotator


class PydanticAnnotator : PyAnnotator() {
private val pydanticTypeProvider = PydanticTypeProvider()
override fun visitPyCallExpression(node: PyCallExpression) {
super.visitPyCallExpression(node)
annotatePydanticModelCallableExpression(node)
}

private fun annotatePydanticModelCallableExpression(pyCallExpression: PyCallExpression) {
val context = TypeEvalContext.codeAnalysis(pyCallExpression.project, pyCallExpression.containingFile)
if (!pyCallExpression.isDefinitionCallExpression(context)) return

val pyClass = getPydanticPyClass(pyCallExpression, context) ?: return
val pyClassType = pyCallExpression.getPyCallableType(context) ?: return
val pyClass = pyClassType.getPydanticModel(true, context) ?: return
if (!isPydanticModel(pyClass, true, context)) return
if (getPydanticModelInit(pyClass, context) != null) return
val pydanticType = pydanticTypeProvider.getPydanticTypeForClass(pyClass, context, true, pyCallExpression) ?: return
if (!pyCallExpression.isDefinitionCallExpression(context)) return

val unFilledArguments =
getPydanticUnFilledArguments(pydanticType, pyCallExpression, context).nullize()
getPydanticUnFilledArguments(pyClassType, pyCallExpression, context, pyClass.isPydanticDataclass).nullize()
?: return
holder.newSilentAnnotation(HighlightSeverity.INFORMATION).withFix(PydanticInsertArgumentsQuickFix(false))
.create()
Expand Down
13 changes: 4 additions & 9 deletions src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -313,11 +313,9 @@ class PydanticCompletionContributor : CompletionContributor() {
val typeEvalContext = parameters.getTypeEvalContext()
val pyTypedElement = parameters.position.parent?.firstChild as? PyTypedElement ?: return

val pyType = typeEvalContext.getType(pyTypedElement) ?: return

val pyClassType =
pyType.pyClassTypes.firstOrNull { isPydanticModel(it.pyClass, true, typeEvalContext) }
?: return
val pyClassType = getPydanticPyClassType(pyTypedElement, typeEvalContext, true) ?: return

val pyClass = pyClassType.pyClass
val config = getConfig(pyClass, typeEvalContext, true)
if (pyClassType.isDefinition) { // class
Expand Down Expand Up @@ -377,9 +375,8 @@ class PydanticCompletionContributor : CompletionContributor() {
) {
val configClass = getPyClassByAttribute(parameters.position.parent?.parent) ?: return
if (!configClass.isConfigClass) return
val pydanticModel = getPyClassByAttribute(configClass) ?: return
val typeEvalContext = parameters.getTypeEvalContext()
if (!isPydanticModel(pydanticModel, true, typeEvalContext)) return
if (getPydanticModelByAttribute(configClass,true, parameters.getTypeEvalContext()) == null) return


val definedSet = configClass.classAttributes
Expand All @@ -404,9 +401,7 @@ class PydanticCompletionContributor : CompletionContributor() {
context: ProcessingContext,
result: CompletionResultSet,
) {
val pydanticModel = getPyClassByAttribute(parameters.position.parent?.parent) ?: return
val typeEvalContext = parameters.getTypeEvalContext()
if (!isPydanticModel(pydanticModel, true, typeEvalContext)) return
val pydanticModel = getPydanticModelByAttribute(parameters.position.parent?.parent, true, parameters.getTypeEvalContext()) ?: return
if (pydanticModel.findNestedClass("Config", false) != null) return
val element = PrioritizedLookupElement.withGrouping(
LookupElementBuilder
Expand Down
49 changes: 6 additions & 43 deletions src/com/koxudaxi/pydantic/PydanticDataclassTypeProvider.kt
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
package com.koxudaxi.pydantic

import com.intellij.openapi.util.Ref
import com.intellij.psi.PsiElement
import com.jetbrains.python.codeInsight.stdlib.PyDataclassTypeProvider
import com.jetbrains.python.psi.*
import com.jetbrains.python.psi.impl.PyCallExpressionImpl
import com.jetbrains.python.psi.impl.PyCallExpressionNavigator
import com.jetbrains.python.psi.types.*

/**
Expand All @@ -20,18 +18,13 @@ import com.jetbrains.python.psi.types.*
*/
class PydanticDataclassTypeProvider : PyTypeProviderBase() {
private val pyDataclassTypeProvider = PyDataclassTypeProvider()
private val pydanticTypeProvider = PydanticTypeProvider()

override fun getReferenceType(
referenceTarget: PsiElement,
context: TypeEvalContext,
anchor: PsiElement?
): Ref<PyType>? {
return when {
referenceTarget is PyClass && referenceTarget.isPydanticDataclass ->
getPydanticDataclassType(referenceTarget, context, anchor as? PyCallExpression, true)
else ->null
}?.let { Ref.create(it) }
override fun getCallableType(callable: PyCallable, context: TypeEvalContext): PyType? {
if (callable is PyFunction && callable.isPydanticDataclass) {
// Drop fake dataclass return type
return PyCallableTypeImpl(callable.getParameters(context), null)
}
return super.getCallableType(callable, context)
}

internal fun getDataclassCallableType(
Expand All @@ -45,34 +38,4 @@ class PydanticDataclassTypeProvider : PyTypeProviderBase() {
callSite ?: PyCallExpressionImpl(referenceTarget.node)
)?.get() as? PyCallableType
}

private fun getPydanticDataclassType(
referenceTarget: PsiElement,
context: TypeEvalContext,
callSite: PyCallExpression?,
definition: Boolean,
): PyType? {
val dataclassCallableType = getDataclassCallableType(referenceTarget, context, callSite) ?: return null

val dataclassType = (dataclassCallableType).getReturnType(context) as? PyClassType ?: return null
if (!dataclassType.pyClass.isPydanticDataclass) return null
val ellipsis = PyElementGenerator.getInstance(referenceTarget.project).createEllipsis()
val injectedPyCallableType = PyCallableTypeImpl(
dataclassCallableType.getParameters(context)?.map {
when {
it.defaultValueText == "..." && it.defaultValue is PyNoneLiteralExpression ->
pydanticTypeProvider.injectDefaultValue(dataclassType.pyClass, it, ellipsis, null, context)
?: it

else -> it
}
}, dataclassType
)
val injectedDataclassType = (injectedPyCallableType).getReturnType(context) as? PyClassType ?: return null
return when {
callSite is PyCallExpression && definition -> injectedPyCallableType
definition -> injectedDataclassType.toClass()
else -> injectedDataclassType
}
}
}
6 changes: 3 additions & 3 deletions src/com/koxudaxi/pydantic/PydanticFieldRenameFactory.kt
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ class PydanticFieldRenameFactory : AutomaticRenamerFactory {
}
is PyKeywordArgument -> {
val context = TypeEvalContext.codeAnalysis(element.project, element.containingFile)
val pyClass = getPyClassByPyKeywordArgument(element, context) ?: return false
if (isPydanticModel(pyClass, true, context)) return true
return getPydanticModelByPyKeywordArgument(element, true,context) is PyClass
//
}
}
return false
Expand Down Expand Up @@ -64,7 +64,7 @@ class PydanticFieldRenameFactory : AutomaticRenamerFactory {
is PyKeywordArgument ->
element.name?.let { name ->
val context = TypeEvalContext.userInitiated(element.project, element.containingFile)
getPyClassByPyKeywordArgument(element, context)
getPydanticModelByPyKeywordArgument(element, true,context)
?.let { pyClass ->
addAllElement(pyClass, name, added, context)
}
Expand Down
3 changes: 1 addition & 2 deletions src/com/koxudaxi/pydantic/PydanticFieldSearchExecutor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ class PydanticFieldSearchExecutor : QueryExecutorBase<PsiReference, ReferencesSe
element.name
?.let { elementName ->
val context = TypeEvalContext.userInitiated(element.project, element.containingFile)
getPyClassByPyKeywordArgument(element, context)
?.takeIf { pyClass -> isPydanticModel(pyClass, true, context) }
getPydanticModelByPyKeywordArgument(element, true,context)
?.let { pyClass -> searchDirectReferenceField(pyClass, elementName, consumer, context) }
}
}
Expand Down
12 changes: 4 additions & 8 deletions src/com/koxudaxi/pydantic/PydanticInsertArgumentsQuickFix.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import com.intellij.psi.PsiFile
import com.intellij.util.IncorrectOperationException
import com.intellij.util.containers.nullize
import com.jetbrains.python.psi.*
import com.jetbrains.python.psi.types.PyCallableParameter
import com.jetbrains.python.psi.types.TypeEvalContext

class PydanticInsertArgumentsQuickFix(private val onlyRequired: Boolean) : LocalQuickFix, IntentionAction,
Expand Down Expand Up @@ -48,14 +47,11 @@ class PydanticInsertArgumentsQuickFix(private val onlyRequired: Boolean) : Local
if (originalElement !is PyCallExpression) return null
if (file !is PyFile) return null
val newEl = originalElement.copy() as PyCallExpression
val pyClass = getPydanticPyClass(originalElement, context, true) ?: return null
val pydanticType = if (pyClass.isPydanticDataclass) {
pydanticDataclassTypeProvider.getDataclassCallableType(pyClass, context, originalElement)
} else {
pydanticTypeProvider.getPydanticTypeForClass(pyClass, context, true, originalElement) ?: return null
} ?: return null
val pyCallableType = originalElement.getPyCallableType(context) ?: return null
val pyClass = pyCallableType.getReturnType(context)?.pyClassTypes?.firstOrNull()?.pyClass ?: return null
if (!isPydanticModel(pyClass, true, context)) return null
val unFilledArguments =
getPydanticUnFilledArguments(pydanticType, originalElement, context).let {
getPydanticUnFilledArguments(pyCallableType, originalElement, context, pyClass.isPydanticDataclass).let {
when {
onlyRequired -> it.filter { arguments -> arguments.required }
else -> it
Expand Down
Loading