diff --git a/CHANGELOG.md b/CHANGELOG.md index e7ebee91..a5eef50a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## Unreleased - Fix wrong inspections when a model has a __call__ method [[#655](https://github.com/koxudaxi/pydantic-pycharm-plugin/pull/655)] - Reduce unnecessary resolve in type providers [[#656](https://github.com/koxudaxi/pydantic-pycharm-plugin/pull/656)] +- Optimize resolving pydantic class [[#658](https://github.com/koxudaxi/pydantic-pycharm-plugin/pull/658)] ## 0.3.17 - 2022-12-16 - Support Union operator [[#602](https://github.com/koxudaxi/pydantic-pycharm-plugin/pull/602)] diff --git a/src/com/koxudaxi/pydantic/Pydantic.kt b/src/com/koxudaxi/pydantic/Pydantic.kt index f450af00..9038b6d3 100644 --- a/src/com/koxudaxi/pydantic/Pydantic.kt +++ b/src/com/koxudaxi/pydantic/Pydantic.kt @@ -68,6 +68,9 @@ val CUSTOM_BASE_MODEL_Q_NAMES = listOf( val CUSTOM_MODEL_FIELD_Q_NAMES = listOf( SQL_MODEL_FIELD_Q_NAME ) + +val DATA_CLASS_Q_NAMES = listOf(DATA_CLASS_Q_NAME, DATA_CLASS_SHORT_Q_NAME) + val VERSION_QUALIFIED_NAME = QualifiedName.fromDottedString(VERSION_Q_NAME) val BASE_CONFIG_QUALIFIED_NAME = QualifiedName.fromDottedString(BASE_CONFIG_Q_NAME) @@ -135,27 +138,14 @@ const val CUSTOM_ROOT_FIELD = "__root__" fun PyTypedElement.getType(context: TypeEvalContext): PyType? = context.getType(this) -fun getPyClassByPyCallExpression( - pyCallExpression: PyCallExpression, + +fun getPydanticModelByPyKeywordArgument( + pyKeywordArgument: PyKeywordArgument, includeDataclass: Boolean, context: TypeEvalContext, ): PyClass? { - val callee = pyCallExpression.callee ?: return null - val pyType = when (val type = callee.getType(context)) { - is PyClass -> return type - is PyClassType -> type - else -> (callee.reference?.resolve() as? PyTypedElement)?.getType(context) ?: return null - } - return pyType.pyClassTypes.firstOrNull { - isPydanticModel(it.pyClass, - includeDataclass, - context) - }?.pyClass -} - -fun getPyClassByPyKeywordArgument(pyKeywordArgument: PyKeywordArgument, context: TypeEvalContext): PyClass? { val pyCallExpression = PsiTreeUtil.getParentOfType(pyKeywordArgument, PyCallExpression::class.java) ?: return null - return getPyClassByPyCallExpression(pyCallExpression, true, context) + return getPydanticPyClass(pyCallExpression, context, includeDataclass) } fun isPydanticModel(pyClass: PyClass, includeDataclass: Boolean, context: TypeEvalContext): Boolean { @@ -228,6 +218,7 @@ internal val PyClass.isConfigClass: Boolean get() = name == "Config" internal val PyFunction.isConStr: Boolean get() = qualifiedName == CON_STR_Q_NAME +internal val PyFunction.isPydanticDataclass: Boolean get() = qualifiedName in DATA_CLASS_Q_NAMES internal fun isPydanticRegex(stringLiteralExpression: StringLiteralExpression): Boolean { val pyKeywordArgument = stringLiteralExpression.parent as? PyKeywordArgument ?: return false if (pyKeywordArgument.keyword != "regex") return false @@ -270,14 +261,14 @@ private fun getAliasedFieldName( fun getResolvedPsiElements(referenceExpression: PyReferenceExpression, context: TypeEvalContext): List { return RecursionManager.doPreventingRecursion( - Pair.create( + Pair.create( referenceExpression, context ), false ) { - PyUtil.multiResolveTopPriority( - referenceExpression, - PyResolveContext.defaultContext(context) + val resolveContext = PyResolveContext.defaultContext(context) + PyUtil.filterTopPriorityResults( + referenceExpression.getReference(resolveContext).multiResolve(false) ) } ?: emptyList() } @@ -494,6 +485,9 @@ fun getPyClassByAttribute(pyPsiElement: PsiElement?): PyClass? { return pyPsiElement?.parent?.parent as? PyClass } +fun getPydanticModelByAttribute(pyPsiElement: PsiElement?, includeDataclass: Boolean, context: TypeEvalContext): PyClass? = + getPyClassByAttribute(pyPsiElement)?.takeIf { isPydanticModel(it, includeDataclass, context) } + fun createPyClassTypeImpl(qualifiedName: String, project: Project, context: TypeEvalContext): PyClassTypeImpl? { var psiElement = getPsiElementByQualifiedName(QualifiedName.fromDottedString(qualifiedName), project, context) if (psiElement == null) { @@ -504,11 +498,13 @@ fun createPyClassTypeImpl(qualifiedName: String, project: Project, context: Type return PyClassTypeImpl.createTypeByQName(psiElement, qualifiedName, false) } -fun getPydanticPyClass(pyCallExpression: PyCallExpression, context: TypeEvalContext, includeDataclass: Boolean = false): PyClass? { - val pyClass = getPyClassByPyCallExpression(pyCallExpression, includeDataclass, context) ?: return null - if (!isPydanticModel(pyClass, includeDataclass, context)) return null - return pyClass -} +fun getPydanticPyClass(pyTypedElement: PyTypedElement, context: TypeEvalContext, includeDataclass: Boolean = false): PyClass? = + getPydanticPyClassType(pyTypedElement, context, includeDataclass)?.pyClass + +fun getPydanticPyClassType(pyTypedElement: PyTypedElement, context: TypeEvalContext, includeDataclass: Boolean = false): PyClassType? = + context.getType(pyTypedElement)?.pyClassTypes?.firstOrNull { + isPydanticModel(it.pyClass, includeDataclass, context) + } fun getAncestorPydanticModels(pyClass: PyClass, includeDataclass: Boolean, context: TypeEvalContext): List { return pyClass.getAncestorClasses(context).filter { isPydanticModel(it, includeDataclass, context) } @@ -535,15 +531,27 @@ fun addKeywordArgument(pyCallExpression: PyCallExpression, pyKeywordArgument: Py } } +val PyExpression.isKeywordArgument: Boolean get() = + this is PyKeywordArgument || (this as? PyStarArgument)?.isKeyword == true + fun getPydanticUnFilledArguments( pydanticType: PyCallableType, pyCallExpression: PyCallExpression, context: TypeEvalContext, + isDataClass: Boolean ): List { - val currentArguments = - pyCallExpression.arguments.filter { it is PyKeywordArgument || (it as? PyStarArgument)?.isKeyword == true } - .mapNotNull { it.name }.toSet() - return pydanticType.getParameters(context)?.filterNot { currentArguments.contains(it.name) } ?: emptyList() + val parameters = pydanticType.getParameters(context)?.let { allParameters -> + if (isDataClass) { + pyCallExpression.arguments + .filterNot { it.isKeywordArgument } + .let { allParameters.drop(it.size) } + } else { + allParameters + } + } ?: listOf() + + val currentArguments = pyCallExpression.arguments.filter { it.isKeywordArgument }.mapNotNull { it.name }.toSet() + return parameters.filterNot { currentArguments.contains(it.name) } } val PyCallableParameter.required: Boolean @@ -659,3 +667,8 @@ fun getPydanticModelInit(pyClass: PyClass, context: TypeEvalContext): PyFunction fun PyCallExpression.isDefinitionCallExpression(context: TypeEvalContext): Boolean = this.callee?.reference?.resolve()?.let { it as? PyClass }?.getType(context)?.isDefinition == true + +fun PyCallExpression.getPyCallableType(context: TypeEvalContext): PyCallableType? = + this.callee?.getType(context) as? PyCallableType +fun PyCallableType.getPydanticModel(includeDataclass: Boolean, context: TypeEvalContext): PyClass? = + this.getReturnType(context)?.pyClassTypes?.firstOrNull()?.pyClass?.takeIf { isPydanticModel(it,includeDataclass, context) } diff --git a/src/com/koxudaxi/pydantic/PydanticAnnotator.kt b/src/com/koxudaxi/pydantic/PydanticAnnotator.kt index cf80b4c0..11d22421 100644 --- a/src/com/koxudaxi/pydantic/PydanticAnnotator.kt +++ b/src/com/koxudaxi/pydantic/PydanticAnnotator.kt @@ -6,12 +6,12 @@ import com.intellij.openapi.util.TextRange import com.intellij.util.containers.nullize import com.jetbrains.python.psi.PyCallExpression import com.jetbrains.python.psi.PyStarArgument +import com.jetbrains.python.psi.types.PyCallableType import com.jetbrains.python.psi.types.TypeEvalContext import com.jetbrains.python.validation.PyAnnotator class PydanticAnnotator : PyAnnotator() { - private val pydanticTypeProvider = PydanticTypeProvider() override fun visitPyCallExpression(node: PyCallExpression) { super.visitPyCallExpression(node) annotatePydanticModelCallableExpression(node) @@ -19,14 +19,14 @@ class PydanticAnnotator : PyAnnotator() { private fun annotatePydanticModelCallableExpression(pyCallExpression: PyCallExpression) { val context = TypeEvalContext.codeAnalysis(pyCallExpression.project, pyCallExpression.containingFile) - if (!pyCallExpression.isDefinitionCallExpression(context)) return - - val pyClass = getPydanticPyClass(pyCallExpression, context) ?: return + val pyClassType = pyCallExpression.getPyCallableType(context) ?: return + val pyClass = pyClassType.getPydanticModel(true, context) ?: return + if (!isPydanticModel(pyClass, true, context)) return if (getPydanticModelInit(pyClass, context) != null) return - val pydanticType = pydanticTypeProvider.getPydanticTypeForClass(pyClass, context, true, pyCallExpression) ?: return + if (!pyCallExpression.isDefinitionCallExpression(context)) return val unFilledArguments = - getPydanticUnFilledArguments(pydanticType, pyCallExpression, context).nullize() + getPydanticUnFilledArguments(pyClassType, pyCallExpression, context, pyClass.isPydanticDataclass).nullize() ?: return holder.newSilentAnnotation(HighlightSeverity.INFORMATION).withFix(PydanticInsertArgumentsQuickFix(false)) .create() diff --git a/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt b/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt index 72d15897..7e860959 100644 --- a/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt +++ b/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt @@ -313,11 +313,9 @@ class PydanticCompletionContributor : CompletionContributor() { val typeEvalContext = parameters.getTypeEvalContext() val pyTypedElement = parameters.position.parent?.firstChild as? PyTypedElement ?: return - val pyType = typeEvalContext.getType(pyTypedElement) ?: return - val pyClassType = - pyType.pyClassTypes.firstOrNull { isPydanticModel(it.pyClass, true, typeEvalContext) } - ?: return + val pyClassType = getPydanticPyClassType(pyTypedElement, typeEvalContext, true) ?: return + val pyClass = pyClassType.pyClass val config = getConfig(pyClass, typeEvalContext, true) if (pyClassType.isDefinition) { // class @@ -377,9 +375,8 @@ class PydanticCompletionContributor : CompletionContributor() { ) { val configClass = getPyClassByAttribute(parameters.position.parent?.parent) ?: return if (!configClass.isConfigClass) return - val pydanticModel = getPyClassByAttribute(configClass) ?: return val typeEvalContext = parameters.getTypeEvalContext() - if (!isPydanticModel(pydanticModel, true, typeEvalContext)) return + if (getPydanticModelByAttribute(configClass,true, parameters.getTypeEvalContext()) == null) return val definedSet = configClass.classAttributes @@ -404,9 +401,7 @@ class PydanticCompletionContributor : CompletionContributor() { context: ProcessingContext, result: CompletionResultSet, ) { - val pydanticModel = getPyClassByAttribute(parameters.position.parent?.parent) ?: return - val typeEvalContext = parameters.getTypeEvalContext() - if (!isPydanticModel(pydanticModel, true, typeEvalContext)) return + val pydanticModel = getPydanticModelByAttribute(parameters.position.parent?.parent, true, parameters.getTypeEvalContext()) ?: return if (pydanticModel.findNestedClass("Config", false) != null) return val element = PrioritizedLookupElement.withGrouping( LookupElementBuilder diff --git a/src/com/koxudaxi/pydantic/PydanticDataclassTypeProvider.kt b/src/com/koxudaxi/pydantic/PydanticDataclassTypeProvider.kt index c9faa09b..fedf8e38 100644 --- a/src/com/koxudaxi/pydantic/PydanticDataclassTypeProvider.kt +++ b/src/com/koxudaxi/pydantic/PydanticDataclassTypeProvider.kt @@ -1,11 +1,9 @@ package com.koxudaxi.pydantic -import com.intellij.openapi.util.Ref import com.intellij.psi.PsiElement import com.jetbrains.python.codeInsight.stdlib.PyDataclassTypeProvider import com.jetbrains.python.psi.* import com.jetbrains.python.psi.impl.PyCallExpressionImpl -import com.jetbrains.python.psi.impl.PyCallExpressionNavigator import com.jetbrains.python.psi.types.* /** @@ -20,18 +18,13 @@ import com.jetbrains.python.psi.types.* */ class PydanticDataclassTypeProvider : PyTypeProviderBase() { private val pyDataclassTypeProvider = PyDataclassTypeProvider() - private val pydanticTypeProvider = PydanticTypeProvider() - override fun getReferenceType( - referenceTarget: PsiElement, - context: TypeEvalContext, - anchor: PsiElement? - ): Ref? { - return when { - referenceTarget is PyClass && referenceTarget.isPydanticDataclass -> - getPydanticDataclassType(referenceTarget, context, anchor as? PyCallExpression, true) - else ->null - }?.let { Ref.create(it) } + override fun getCallableType(callable: PyCallable, context: TypeEvalContext): PyType? { + if (callable is PyFunction && callable.isPydanticDataclass) { + // Drop fake dataclass return type + return PyCallableTypeImpl(callable.getParameters(context), null) + } + return super.getCallableType(callable, context) } internal fun getDataclassCallableType( @@ -45,34 +38,4 @@ class PydanticDataclassTypeProvider : PyTypeProviderBase() { callSite ?: PyCallExpressionImpl(referenceTarget.node) )?.get() as? PyCallableType } - - private fun getPydanticDataclassType( - referenceTarget: PsiElement, - context: TypeEvalContext, - callSite: PyCallExpression?, - definition: Boolean, - ): PyType? { - val dataclassCallableType = getDataclassCallableType(referenceTarget, context, callSite) ?: return null - - val dataclassType = (dataclassCallableType).getReturnType(context) as? PyClassType ?: return null - if (!dataclassType.pyClass.isPydanticDataclass) return null - val ellipsis = PyElementGenerator.getInstance(referenceTarget.project).createEllipsis() - val injectedPyCallableType = PyCallableTypeImpl( - dataclassCallableType.getParameters(context)?.map { - when { - it.defaultValueText == "..." && it.defaultValue is PyNoneLiteralExpression -> - pydanticTypeProvider.injectDefaultValue(dataclassType.pyClass, it, ellipsis, null, context) - ?: it - - else -> it - } - }, dataclassType - ) - val injectedDataclassType = (injectedPyCallableType).getReturnType(context) as? PyClassType ?: return null - return when { - callSite is PyCallExpression && definition -> injectedPyCallableType - definition -> injectedDataclassType.toClass() - else -> injectedDataclassType - } - } } diff --git a/src/com/koxudaxi/pydantic/PydanticFieldRenameFactory.kt b/src/com/koxudaxi/pydantic/PydanticFieldRenameFactory.kt index 947d327e..83a50fe6 100644 --- a/src/com/koxudaxi/pydantic/PydanticFieldRenameFactory.kt +++ b/src/com/koxudaxi/pydantic/PydanticFieldRenameFactory.kt @@ -25,8 +25,8 @@ class PydanticFieldRenameFactory : AutomaticRenamerFactory { } is PyKeywordArgument -> { val context = TypeEvalContext.codeAnalysis(element.project, element.containingFile) - val pyClass = getPyClassByPyKeywordArgument(element, context) ?: return false - if (isPydanticModel(pyClass, true, context)) return true + return getPydanticModelByPyKeywordArgument(element, true,context) is PyClass +// } } return false @@ -64,7 +64,7 @@ class PydanticFieldRenameFactory : AutomaticRenamerFactory { is PyKeywordArgument -> element.name?.let { name -> val context = TypeEvalContext.userInitiated(element.project, element.containingFile) - getPyClassByPyKeywordArgument(element, context) + getPydanticModelByPyKeywordArgument(element, true,context) ?.let { pyClass -> addAllElement(pyClass, name, added, context) } diff --git a/src/com/koxudaxi/pydantic/PydanticFieldSearchExecutor.kt b/src/com/koxudaxi/pydantic/PydanticFieldSearchExecutor.kt index 1cca22ae..770590a7 100644 --- a/src/com/koxudaxi/pydantic/PydanticFieldSearchExecutor.kt +++ b/src/com/koxudaxi/pydantic/PydanticFieldSearchExecutor.kt @@ -22,8 +22,7 @@ class PydanticFieldSearchExecutor : QueryExecutorBase val context = TypeEvalContext.userInitiated(element.project, element.containingFile) - getPyClassByPyKeywordArgument(element, context) - ?.takeIf { pyClass -> isPydanticModel(pyClass, true, context) } + getPydanticModelByPyKeywordArgument(element, true,context) ?.let { pyClass -> searchDirectReferenceField(pyClass, elementName, consumer, context) } } } diff --git a/src/com/koxudaxi/pydantic/PydanticInsertArgumentsQuickFix.kt b/src/com/koxudaxi/pydantic/PydanticInsertArgumentsQuickFix.kt index ae173d42..95c1270a 100644 --- a/src/com/koxudaxi/pydantic/PydanticInsertArgumentsQuickFix.kt +++ b/src/com/koxudaxi/pydantic/PydanticInsertArgumentsQuickFix.kt @@ -12,7 +12,6 @@ import com.intellij.psi.PsiFile import com.intellij.util.IncorrectOperationException import com.intellij.util.containers.nullize import com.jetbrains.python.psi.* -import com.jetbrains.python.psi.types.PyCallableParameter import com.jetbrains.python.psi.types.TypeEvalContext class PydanticInsertArgumentsQuickFix(private val onlyRequired: Boolean) : LocalQuickFix, IntentionAction, @@ -48,14 +47,11 @@ class PydanticInsertArgumentsQuickFix(private val onlyRequired: Boolean) : Local if (originalElement !is PyCallExpression) return null if (file !is PyFile) return null val newEl = originalElement.copy() as PyCallExpression - val pyClass = getPydanticPyClass(originalElement, context, true) ?: return null - val pydanticType = if (pyClass.isPydanticDataclass) { - pydanticDataclassTypeProvider.getDataclassCallableType(pyClass, context, originalElement) - } else { - pydanticTypeProvider.getPydanticTypeForClass(pyClass, context, true, originalElement) ?: return null - } ?: return null + val pyCallableType = originalElement.getPyCallableType(context) ?: return null + val pyClass = pyCallableType.getReturnType(context)?.pyClassTypes?.firstOrNull()?.pyClass ?: return null + if (!isPydanticModel(pyClass, true, context)) return null val unFilledArguments = - getPydanticUnFilledArguments(pydanticType, originalElement, context).let { + getPydanticUnFilledArguments(pyCallableType, originalElement, context, pyClass.isPydanticDataclass).let { when { onlyRequired -> it.filter { arguments -> arguments.required } else -> it diff --git a/src/com/koxudaxi/pydantic/PydanticInspection.kt b/src/com/koxudaxi/pydantic/PydanticInspection.kt index 71956a80..4d75cdfb 100644 --- a/src/com/koxudaxi/pydantic/PydanticInspection.kt +++ b/src/com/koxudaxi/pydantic/PydanticInspection.kt @@ -35,8 +35,8 @@ class PydanticInspection : PyInspection() { override fun visitPyFunction(node: PyFunction) { super.visitPyFunction(node) - val pyClass = getPyClassByAttribute(node) ?: return - if (!isPydanticModel(pyClass, true, myTypeEvalContext) || !node.isValidatorMethod) return + if (getPydanticModelByAttribute(node, true, myTypeEvalContext) == null) return + if (!node.isValidatorMethod) return val paramList = node.parameterList val params = paramList.parameters val firstParam = params.firstOrNull() @@ -188,23 +188,9 @@ class PydanticInspection : PyInspection() { val resolveContext = PyResolveContext.defaultContext(myTypeEvalContext) val pyCallable = pyCallExpression.multiResolveCalleeFunction(resolveContext).firstOrNull() ?: return if (pyCallable.asMethod()?.qualifiedName != "pydantic.main.BaseModel.from_orm") return - val type = - (pyCallExpression.node?.firstChildNode?.firstChildNode?.psi as? PyTypedElement)?.getType( - myTypeEvalContext - ) - ?: return - val pyClass = when (type) { - is PyClass -> type - is PyClassType -> type.pyClassTypes.firstOrNull { - isPydanticModel( - it.pyClass, - false, myTypeEvalContext - ) - }?.pyClass + val typedElement = pyCallExpression.node?.firstChildNode?.firstChildNode?.psi as? PyTypedElement ?: return + val pyClass = getPydanticPyClass(typedElement, myTypeEvalContext, false) ?: return - else -> null - } ?: return - if (!isPydanticModel(pyClass, false, myTypeEvalContext)) return val config = getConfig(pyClass, myTypeEvalContext, true) if (config["orm_mode"] != true) { registerProblem( @@ -229,11 +215,11 @@ class PydanticInspection : PyInspection() { } private fun inspectReadOnlyProperty(node: PyAssignmentStatement) { - val pyType = - (node.leftHandSideExpression?.firstChild as? PyTypedElement)?.getType(myTypeEvalContext) ?: return - if ((pyType as? PyClassTypeImpl)?.isDefinition == true) return - val pyClass = pyType.pyClassTypes.firstOrNull()?.pyClass ?: return - if (!isPydanticModel(pyClass, false, myTypeEvalContext)) return + val pyTypedElement = + node.leftHandSideExpression?.firstChild as? PyTypedElement ?: return + val pyClassType = getPydanticPyClassType(pyTypedElement, myTypeEvalContext, false) ?: return + if (pyClassType.isDefinition) return + val pyClass = pyClassType.pyClass val attributeName = (node.leftHandSideExpression as? PyTargetExpressionImpl)?.name ?: return val config = getConfig(pyClass, myTypeEvalContext, true) val version = PydanticCacheService.getVersion(pyClass.project, myTypeEvalContext) @@ -247,8 +233,7 @@ class PydanticInspection : PyInspection() { } private fun inspectWarnUntypedFields(node: PyAssignmentStatement) { - val pyClass = getPyClassByAttribute(node) ?: return - if (!isPydanticModel(pyClass, true, myTypeEvalContext)) return + if (getPydanticModelByAttribute(node, true, myTypeEvalContext) == null) return if (node.annotation != null) return if ((node.leftHandSideExpression as? PyTargetExpressionImpl)?.text?.isValidFieldName != true) return registerProblem( @@ -259,8 +244,8 @@ class PydanticInspection : PyInspection() { } private fun inspectCustomRootField(node: PyAssignmentStatement) { - val pyClass = getPyClassByAttribute(node) ?: return - if (!isPydanticModel(pyClass, false, myTypeEvalContext)) return + val pyClass = getPydanticModelByAttribute(node, false, myTypeEvalContext) ?: return + val fieldName = (node.leftHandSideExpression as? PyTargetExpressionImpl)?.text ?: return if (fieldName.startsWith('_')) return val rootModel = pyClass.findClassAttribute("__root__", true, myTypeEvalContext)?.containingClass ?: return @@ -302,8 +287,8 @@ class PydanticInspection : PyInspection() { } private fun inspectAnnotatedAssignedField(node: PyAssignmentStatement) { - val pyClass = getPyClassByAttribute(node) ?: return - if (!isPydanticModel(pyClass, true, myTypeEvalContext)) return + if (getPydanticModelByAttribute(node, true, myTypeEvalContext) == null) return + val fieldName = (node.leftHandSideExpression as? PyTargetExpressionImpl)?.text ?: return val assignedValue = node.assignedValue diff --git a/src/com/koxudaxi/pydantic/PydanticTypeCheckerInspection.kt b/src/com/koxudaxi/pydantic/PydanticTypeCheckerInspection.kt index 81f285bc..a38c6291 100644 --- a/src/com/koxudaxi/pydantic/PydanticTypeCheckerInspection.kt +++ b/src/com/koxudaxi/pydantic/PydanticTypeCheckerInspection.kt @@ -37,9 +37,8 @@ class PydanticTypeCheckerInspection : PyTypeCheckerInspection() { private val pydanticConfigService = PydanticConfigService.getInstance(holder!!.project) override fun visitPyCallExpression(node: PyCallExpression) { - val pyClass = getPyClassByPyCallExpression(node, true, myTypeEvalContext) - getPyClassByPyCallExpression(node, true, myTypeEvalContext) - if (pyClass is PyClass && isPydanticModel(pyClass, true, myTypeEvalContext)) { + val pyClass = getPydanticPyClass(node, myTypeEvalContext, true) + if (pyClass is PyClass) { checkCallSiteForPydantic(node) return } diff --git a/testData/inspection/acceptsOnlyKeywordArguments.py b/testData/inspection/acceptsOnlyKeywordArguments.py index a284f1ba..3642feb7 100644 --- a/testData/inspection/acceptsOnlyKeywordArguments.py +++ b/testData/inspection/acceptsOnlyKeywordArguments.py @@ -22,3 +22,17 @@ def __call__(self, *args, **kwargs): c = C(a='abc') c('a') + +@dataclass +class D(): + a: str + b: str + + +D('a') + + +class E(BaseModel): + pass + +E() diff --git a/testData/typeinspectionv18/dataclass.py b/testData/typeinspectionv18/dataclass.py index 994f6648..4dbfe83d 100644 --- a/testData/typeinspectionv18/dataclass.py +++ b/testData/typeinspectionv18/dataclass.py @@ -31,23 +31,23 @@ class ChildDataclass(MyDataclass): ChildDataclass(a=2, b='orange', c=4, d='cherry') -a: MyDataclass = MyDataclass() +a: MyDataclass = MyDataclass() b: Type[MyDataclass] = MyDataclass c: MyDataclass = MyDataclass -d: Type[MyDataclass] = MyDataclass() +d: Type[MyDataclass] = MyDataclass() -aa: Union[str, MyDataclass] = MyDataclass() +aa: Union[str, MyDataclass] = MyDataclass() bb: Union[str, Type[MyDataclass]] = MyDataclass cc: Union[str, MyDataclass] = MyDataclass -dd: Union[str, Type[MyDataclass]] = MyDataclass() +dd: Union[str, Type[MyDataclass]] = MyDataclass() -aaa: ChildDataclass = ChildDataclass() +aaa: ChildDataclass = ChildDataclass() bbb: Type[ChildDataclass] = ChildDataclass ccc: ChildDataclass = ChildDataclass -ddd: Type[ChildDataclass] = ChildDataclass() +ddd: Type[ChildDataclass] = ChildDataclass() e: str = MyDataclass(a='apple', b=1).a @@ -79,7 +79,7 @@ class ChildDataclass(MyDataclass): mm: str = ii.d def my_fn_1() -> MyDataclass: - return MyDataclass() + return MyDataclass() def my_fn_2() -> Type[MyDataclass]: return MyDataclass @@ -88,10 +88,10 @@ def my_fn_3() -> MyDataclass: return MyDataclass def my_fn_4() -> Type[MyDataclass]: - return MyDataclass() + return MyDataclass() def my_fn_5() -> Union[str, MyDataclass]: - return MyDataclass() + return MyDataclass() def my_fn_6() -> Type[str, MyDataclass]: return MyDataclass @@ -100,10 +100,10 @@ def my_fn_7() -> Union[str, MyDataclass]: return MyDataclass def my_fn_8() -> Union[str, Type[MyDataclass]]: - return MyDataclass() + return MyDataclass() def my_fn_9() -> ChildDataclass: - return ChildDataclass() + return ChildDataclass() def my_fn_10() -> Type[ChildDataclass]: return ChildDataclass @@ -112,10 +112,10 @@ def my_fn_11() -> ChildDataclass: return ChildDataclass def my_fn_12() -> Type[ChildDataclass]: - return ChildDataclass() + return ChildDataclass() def my_fn_13() -> Union[str, ChildDataclass]: - return ChildDataclass() + return ChildDataclass() def my_fn_14() -> Type[str, ChildDataclass]: return ChildDataclass @@ -124,4 +124,4 @@ def my_fn_7() -> Union[str, ChildDataclass]: return ChildDataclass def my_fn_8() -> Union[str, Type[ChildDataclass]]: - return ChildDataclass() + return ChildDataclass() diff --git a/testData/typeinspectionv18/sqlModel.py b/testData/typeinspectionv18/sqlModel.py index b4147894..46b3b0b2 100644 --- a/testData/typeinspectionv18/sqlModel.py +++ b/testData/typeinspectionv18/sqlModel.py @@ -13,6 +13,6 @@ class Hero(SQLModel, table=True): hero_2 = Hero(name="Spider-Boy", secret_name="Pedro Parqueador") hero_3 = Hero(name="Rusty-Man", secret_name="Tommy Sharp", age=48) -hero_4 = Hero(secret_name="test", ) +hero_4 = Hero(secret_name="test") hero_5 = Hero(name=123, secret_name=456, age="abc") \ No newline at end of file