Skip to content

Commit

Permalink
Fix wrong an error for a duplicate in config.
Browse files Browse the repository at this point in the history
  • Loading branch information
koxudaxi committed May 13, 2021
1 parent 738fb31 commit a63a182
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 18 deletions.
2 changes: 1 addition & 1 deletion resources/META-INF/plugin.xml
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@
<projectService
serviceImplementation="com.koxudaxi.pydantic.PydanticConfigService"/>
<projectService
serviceImplementation="com.koxudaxi.pydantic.PydanticVersionService"/>
serviceImplementation="com.koxudaxi.pydantic.PydanticCacheService"/>

<projectConfigurable groupId="tools" instance="com.koxudaxi.pydantic.PydanticConfigurable"/>
<postStartupActivity implementation="com.koxudaxi.pydantic.PydanticInitializer" order="last"/>
Expand Down
7 changes: 4 additions & 3 deletions src/com/koxudaxi/pydantic/Pydantic.kt
Original file line number Diff line number Diff line change
Expand Up @@ -339,10 +339,11 @@ fun getConfigValue(name: String, value: Any?, context: TypeEvalContext): Any? {
}
}

fun validateConfig(pyClass: PyClass): List<PsiElement>? {
fun validateConfig(pyClass: PyClass, context: TypeEvalContext): List<PsiElement>? {
val configClass = pyClass.nestedClasses.firstOrNull { it.isConfigClass } ?: return null

val allowedConfigKwargs = PydanticCacheService.getAllowedConfigKwargs(pyClass.project, context) ?: return null
val configKwargs = pyClass.superClassExpressions.filterIsInstance<PyKeywordArgument>()
.filter { allowedConfigKwargs.contains(it.name) }
.takeIf { it.isNotEmpty() } ?: return null

val results: MutableList<PsiElement> = configKwargs.toMutableList()
Expand All @@ -357,7 +358,7 @@ fun getConfig(
pydanticVersion: KotlinVersion? = null,
): HashMap<String, Any?> {
val config = hashMapOf<String, Any?>()
val version = pydanticVersion ?: PydanticVersionService.getVersion(pyClass.project, context)
val version = pydanticVersion ?: PydanticCacheService.getVersion(pyClass.project, context)
pyClass.getAncestorClasses(context)
.reversed()
.filter { isPydanticModel(it, false, context) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,17 @@ import com.jetbrains.python.psi.PyTargetExpression
import com.jetbrains.python.psi.impl.PyStringLiteralExpressionImpl
import com.jetbrains.python.psi.types.TypeEvalContext

class PydanticVersionService {
class PydanticCacheService {
private var version: KotlinVersion? = null
private var allowedConfigKwargs: Set<String>? = null

private fun getAllowedConfigKwargs(project: Project, context: TypeEvalContext): Set<String>? {
val baseConfig = getPydanticBaseConfig(project, context) ?: return null
return baseConfig.classAttributes
.mapNotNull { it.name }
.filterNot { it.startsWith("__") && it.endsWith("__") }
.toSet()
}
private fun getVersion(project: Project, context: TypeEvalContext): KotlinVersion? {
val version = getPsiElementByQualifiedName(VERSION_QUALIFIED_NAME, project, context) as? PyTargetExpression
?: return null
Expand All @@ -34,21 +42,31 @@ class PydanticVersionService {
return getVersion(project, context).apply { version = this }
}

private fun getOrAllowedConfigKwargs(project: Project, context: TypeEvalContext): Set<String>? {
if (allowedConfigKwargs != null) return allowedConfigKwargs
return getAllowedConfigKwargs(project, context).apply { allowedConfigKwargs = this }
}

private fun clear() {
version = null
allowedConfigKwargs = null
}

companion object {
fun getVersion(project: Project, context: TypeEvalContext): KotlinVersion? {
return getInstance(project).getOrPutVersion(project, context)
}

fun getAllowedConfigKwargs(project: Project, context: TypeEvalContext): Set<String>? {
return getInstance(project).getOrAllowedConfigKwargs(project, context)
}

fun clear(project: Project) {
return getInstance(project).clear()
}

private fun getInstance(project: Project): PydanticVersionService {
return ServiceManager.getService(project, PydanticVersionService::class.java)
private fun getInstance(project: Project): PydanticCacheService {
return ServiceManager.getService(project, PydanticCacheService::class.java)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class PydanticCompletionContributor : CompletionContributor() {
isDataclass: Boolean,
genericTypeMap: Map<PyGenericType, PyType>?,
) {
val pydanticVersion = PydanticVersionService.getVersion(pyClass.project, typeEvalContext)
val pydanticVersion = PydanticCacheService.getVersion(pyClass.project, typeEvalContext)
getClassVariables(pyClass, typeEvalContext)
.filter { it.name != null }
.filterNot { isUntouchedClass(it.findAssignedValue(), config, typeEvalContext) }
Expand Down
6 changes: 3 additions & 3 deletions src/com/koxudaxi/pydantic/PydanticInspection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,10 @@ class PydanticInspection : PyInspection() {
}

private fun inspectConfig(pyClass: PyClass) {
val pydanticVersion = PydanticVersionService.getVersion(pyClass.project, myTypeEvalContext)
val pydanticVersion = PydanticCacheService.getVersion(pyClass.project, myTypeEvalContext)
if (pydanticVersion?.isAtLeast(1, 8) != true) return
if (!isPydanticModel(pyClass, false, myTypeEvalContext)) return
validateConfig(pyClass)?.forEach {
validateConfig(pyClass, myTypeEvalContext)?.forEach {
registerProblem(it,
"Specifying config in two places is ambiguous, use either Config attribute or class kwargs",
ProblemHighlightType.GENERIC_ERROR)
Expand All @@ -134,7 +134,7 @@ class PydanticInspection : PyInspection() {
if (!isPydanticModel(pyClass, false, myTypeEvalContext)) return
val attributeName = (node.leftHandSideExpression as? PyTargetExpressionImpl)?.name ?: return
val config = getConfig(pyClass, myTypeEvalContext, true)
val version = PydanticVersionService.getVersion(pyClass.project, myTypeEvalContext)
val version = PydanticCacheService.getVersion(pyClass.project, myTypeEvalContext)
if (config["allow_mutation"] == false || (version?.isAtLeast(1, 8) == true && config["frozen"] == true)) {
registerProblem(node,
"Property \"${attributeName}\" defined in \"${pyClass.name}\" is read-only",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class PydanticPackageManagerListener : PyPackageManager.Listener {
private fun clearVersion(sdk: Sdk) {
ProjectManager.getInstance().openProjects
.filter { it.sdks.contains(sdk) }
.forEach { PydanticVersionService.clear(it) }
.forEach { PydanticCacheService.clear(it) }
}

override fun packagesRefreshed(sdk: Sdk) {
Expand Down
6 changes: 3 additions & 3 deletions src/com/koxudaxi/pydantic/PydanticTypeProvider.kt
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
private fun getRefTypeFromFieldName(name: String, context: TypeEvalContext, pyClass: PyClass): Ref<PyType>? {
val ellipsis = PyElementGenerator.getInstance(pyClass.project).createEllipsis()

val pydanticVersion = PydanticVersionService.getVersion(pyClass.project, context)
val pydanticVersion = PydanticCacheService.getVersion(pyClass.project, context)
return getRefTypeFromFieldNameInPyClass(name, pyClass, context, ellipsis, pydanticVersion)
?: pyClass.getAncestorClasses(context)
.filter { isPydanticModel(it, false, context) }
Expand Down Expand Up @@ -300,7 +300,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
): PydanticDynamicModelClassType? {
val project = pyFunction.project
val typed = getInstance(project).currentInitTyped
val pydanticVersion = PydanticVersionService.getVersion(pyFunction.project, context)
val pydanticVersion = PydanticCacheService.getVersion(pyFunction.project, context)
val collected = linkedMapOf<String, PydanticDynamicModel.Attribute>()
val newVersion = pydanticVersion == null || pydanticVersion.isAtLeast(1, 5)
val modelNameParameterName = if (newVersion) "__model_name" else "model_name"
Expand Down Expand Up @@ -478,7 +478,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
}
}
val genericTypeMap = getGenericTypeMap(pyClass, context, pyCallExpression)
val pydanticVersion = PydanticVersionService.getVersion(pyClass.project, context)
val pydanticVersion = PydanticCacheService.getVersion(pyClass.project, context)
val config = getConfig(pyClass, context, true)
for (currentType in StreamEx.of(clsType).append(pyClass.getAncestorTypes(context))) {
if (currentType !is PyClassType) continue
Expand Down
6 changes: 6 additions & 0 deletions testData/inspectionv18/configDuplicate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import abc

from pydantic import BaseModel


Expand All @@ -11,3 +13,7 @@ class F(BaseModel, allow_mutation=True):
class G(BaseModel):
class Config:
allow_mutation=True

class H(BaseModel, metaclass=abc.ABCMeta, <error descr="Specifying config in two places is ambiguous, use either Config attribute or class kwargs">allow_mutation=True</error>):
class <error descr="Specifying config in two places is ambiguous, use either Config attribute or class kwargs">Config</error>:
allow_mutation= True
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ open class PydanticPackageManagerListenerTest : PydanticTestCase() {
val context = TypeEvalContext.userInitiated(project, null)
val sdk = PythonSdkUtil.findPythonSdk(myFixture!!.module)!!

val pydanticVersion = PydanticVersionService.getVersion(project, context)
val pydanticVersion = PydanticCacheService.getVersion(project, context)
assertEquals(KotlinVersion(1, 0, 1), pydanticVersion)

BackgroundTaskUtil.syncPublisher(project, PyPackageManager.PACKAGE_MANAGER_TOPIC).packagesRefreshed(sdk)
invokeLater {
val privateVersionField = PydanticVersionService::class.java.getDeclaredField("version")
val privateVersionField = PydanticCacheService::class.java.getDeclaredField("version")
privateVersionField.trySetAccessible()
val pydanticVersionService = ServiceManager.getService(project, PydanticVersionService::class.java)
val pydanticVersionService = ServiceManager.getService(project, PydanticCacheService::class.java)
val actual = privateVersionField.get(pydanticVersionService)
assertNull(actual)
}
Expand Down

0 comments on commit a63a182

Please sign in to comment.