Skip to content

Commit

Permalink
Merge pull request #43 from koxudaxi/support_dataclass
Browse files Browse the repository at this point in the history
Support pydantic.dataclasses.dataclass
  • Loading branch information
koxudaxi committed Aug 18, 2019
2 parents 09b2ee7 + 4fb73f0 commit eae6516
Show file tree
Hide file tree
Showing 11 changed files with 81 additions and 157 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
* Refactor support for renaming fields for subclasses of `BaseModel`
* (If the field name is refactored from the model definition or `__init__` call keyword arguments, PyCharm will present a dialog offering the choice to automatically rename the keyword where it occurs in a model initialization call.
* Search related-fields by class attributes and keyword arguments of `__init__` with `Ctrl+B` and `Cmd+B`

#### pydantic.dataclasses.dataclass
Support same features as `pydantic.BaseModel`

## How to install:
### MarketPlace
Expand Down
3 changes: 2 additions & 1 deletion resources/META-INF/plugin.xml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
<h2>version 0.0.14</h2>
<p>Features</p>
<ul>
<li>Support pydantic.dataclasses.dataclass [#43] </li>
<li>Search related-fields by class attributes and keyword arguments of __init__. with Ctrl+B and Cmd+B [#42] </li>
</ul>
<h2>version 0.0.13</h2>
Expand Down Expand Up @@ -39,7 +40,7 @@
</li>
<li>pydantic.dataclasses.dataclass
<ul>
<li>The plugin has not supported dataclass yet.</li>
<li>Support same features as `pydantic.BaseModel`</li>
</ul>
</li>
</ul>
Expand Down
43 changes: 43 additions & 0 deletions src/com/koxudaxi/pydantic/Pydantic.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package com.koxudaxi.pydantic

import com.intellij.psi.util.QualifiedName
import com.jetbrains.python.psi.PyCallExpression
import com.jetbrains.python.psi.PyClass
import com.jetbrains.python.psi.PyKeywordArgument
import com.jetbrains.python.psi.PyReferenceExpression
import com.jetbrains.python.psi.resolve.PyResolveUtil
import com.jetbrains.python.psi.types.TypeEvalContext


fun getPyClassByPyKeywordArgument(pyKeywordArgument: PyKeywordArgument): PyClass? {
val pyCallExpression = pyKeywordArgument.parent?.parent as? PyCallExpression ?: return null
return pyCallExpression.callee?.reference?.resolve() as? PyClass ?: return null
}

fun isPydanticModel(pyClass: PyClass, context: TypeEvalContext? = null): Boolean {
return isSubClassOfPydanticBaseModel(pyClass, context) || isPydanticDataclass(pyClass)
}

fun isPydanticBaseModel(pyClass: PyClass): Boolean {
return pyClass.qualifiedName == "pydantic.main.BaseModel"
}

fun isSubClassOfPydanticBaseModel(pyClass: PyClass, context: TypeEvalContext?): Boolean {
return pyClass.isSubclass("pydantic.main.BaseModel", context)
}

fun isPydanticDataclass(pyClass: PyClass): Boolean {
val decorators = pyClass.decoratorList?.decorators ?: return false
for (decorator in decorators) {
val callee = (decorator.callee as? PyReferenceExpression) ?: continue

for (decoratorQualifiedName in PyResolveUtil.resolveImportedElementQNameLocally(callee)) {
if (decoratorQualifiedName == QualifiedName.fromDottedString("pydantic.dataclasses.dataclass")) return true
}
}
return false
}

fun isPydanticField(pyClass: PyClass, context: TypeEvalContext? = null): Boolean {
return pyClass.isSubclass("pydantic.schema.Schema", context) || pyClass.isSubclass("pydantic.field.Field", context)
}
11 changes: 0 additions & 11 deletions src/com/koxudaxi/pydantic/PydanticBaseModel.kt

This file was deleted.

11 changes: 5 additions & 6 deletions src/com/koxudaxi/pydantic/PydanticFieldRenameFactory.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ class PydanticFieldRenameFactory : AutomaticRenamerFactory {
when (element) {
is PyTargetExpression -> {
val pyClass = element.containingClass ?: return false
if (pyClass.isSubclass("pydantic.main.BaseModel", null)) return true
if (isPydanticModel(pyClass)) return true
}
is PyKeywordArgument -> {
val pyClass = getPyClassByPyKeywordArgument(element) ?: return false
if (pyClass.isSubclass("pydantic.main.BaseModel", null)) return true
if (isPydanticModel(pyClass)) return true
}
}
return false
Expand Down Expand Up @@ -68,15 +68,15 @@ class PydanticFieldRenameFactory : AutomaticRenamerFactory {
addClassAttributes(pyClass, elementName)
addKeywordArguments(pyClass, elementName)
pyClass.getAncestorClasses(null).forEach { ancestorClass ->
if (ancestorClass.qualifiedName != "pydantic.main.BaseModel") {
if (ancestorClass.isSubclass("pydantic.main.BaseModel", null) &&
if (!isPydanticBaseModel(ancestorClass)) {
if (isPydanticModel(ancestorClass) &&
!added.contains(ancestorClass)) {
addAllElement(ancestorClass, elementName, added)
}
}
}
PyClassInheritorsSearch.search(pyClass, true).forEach { inheritorsPyClass ->
if (inheritorsPyClass.qualifiedName != "pydantic.main.BaseModel" && !added.contains(inheritorsPyClass)) {
if (!isPydanticBaseModel(inheritorsPyClass) && !added.contains(inheritorsPyClass)) {
addAllElement(inheritorsPyClass, elementName, added)
}
}
Expand All @@ -93,7 +93,6 @@ class PydanticFieldRenameFactory : AutomaticRenamerFactory {
callee?.arguments?.forEach { argument ->
if (argument is PyKeywordArgument && argument.name == elementName) {
myElements.add(argument)

}
}
}
Expand Down
16 changes: 8 additions & 8 deletions src/com/koxudaxi/pydantic/PydanticFieldSearchExecutor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ import com.jetbrains.python.psi.PyTargetExpression
import com.jetbrains.python.psi.search.PyClassInheritorsSearch

private fun searchField(pyClass: PyClass, elementName: String, consumer: Processor<in PsiReference>): Boolean {
if (!pyClass.isSubclass("pydantic.main.BaseModel", null)) return false
if (!isPydanticModel(pyClass)) return false
val pyTargetExpression = pyClass.findClassAttribute(elementName, false, null) ?: return false
consumer.process(pyTargetExpression.reference)
return true
}

private fun searchKeywordArgument(pyClass: PyClass, elementName: String, consumer: Processor<in PsiReference>) {
if (!pyClass.isSubclass("pydantic.main.BaseModel", null)) return
if (!isPydanticModel(pyClass)) return
ReferencesSearch.search(pyClass as PsiElement).forEach { psiReference ->
val callee = PsiTreeUtil.getParentOfType(psiReference.element, PyCallExpression::class.java)
callee?.arguments?.forEach { argument ->
Expand All @@ -37,8 +37,8 @@ private fun searchDirectReferenceField(pyClass: PyClass, elementName: String, co
if (searchField(pyClass, elementName, consumer)) return true

pyClass.getAncestorClasses(null).forEach { ancestorClass ->
if (ancestorClass.qualifiedName != "pydantic.main.BaseModel") {
if (ancestorClass.isSubclass("pydantic.main.BaseModel", null)) {
if (!isPydanticBaseModel(ancestorClass)) {
if (isPydanticModel(ancestorClass)) {
if (searchDirectReferenceField(ancestorClass, elementName, consumer)) {
return true
}
Expand All @@ -54,12 +54,12 @@ private fun searchAllElementReference(pyClass: PyClass?, elementName: String, ad
searchField(pyClass, elementName, consumer)
searchKeywordArgument(pyClass, elementName, consumer)
pyClass.getAncestorClasses(null).forEach { ancestorClass ->
if (ancestorClass.qualifiedName != "pydantic.main.BaseModel" && !added.contains(ancestorClass)){
if (isPydanticBaseModel(ancestorClass) && !added.contains(ancestorClass)){
searchField(pyClass, elementName, consumer)
}
}
PyClassInheritorsSearch.search(pyClass, true).forEach { inheritorsPyClass ->
if (inheritorsPyClass.qualifiedName != "pydantic.main.BaseModel" && !added.contains(inheritorsPyClass)) {
if (!isPydanticBaseModel(inheritorsPyClass) && !added.contains(inheritorsPyClass)) {
searchAllElementReference(inheritorsPyClass, elementName, added, consumer)
}
}
Expand All @@ -72,13 +72,13 @@ class PydanticFieldSearchExecutor : QueryExecutorBase<PsiReference, ReferencesSe
is PyKeywordArgument -> run<RuntimeException> {
val elementName = element.name ?: return@run
val pyClass = getPyClassByPyKeywordArgument(element) ?: return@run
if (!pyClass.isSubclass("pydantic.main.BaseModel", null)) return@run
if (!isPydanticModel(pyClass)) return@run
searchDirectReferenceField(pyClass, elementName, consumer)
}
is PyTargetExpression -> run<RuntimeException> {
val elementName = element.name ?: return@run
val pyClass = element.containingClass ?: return@run
if (!pyClass.isSubclass("pydantic.main.BaseModel", null)) return@run
if (!isPydanticModel(pyClass)) return@run
searchAllElementReference(pyClass, elementName, mutableSetOf(), consumer)
}
}
Expand Down
69 changes: 0 additions & 69 deletions src/com/koxudaxi/pydantic/PydanticFieldStub.kt

This file was deleted.

19 changes: 0 additions & 19 deletions src/com/koxudaxi/pydantic/PydanticFieldStubType.kt

This file was deleted.

5 changes: 1 addition & 4 deletions src/com/koxudaxi/pydantic/PydanticInspection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,9 @@ import com.jetbrains.python.inspections.PyInspectionVisitor
import com.jetbrains.python.psi.PyCallExpression
import com.jetbrains.python.psi.PyClass
import com.jetbrains.python.psi.PyKeywordArgument
import com.jetbrains.python.psi.impl.PyClassImpl
import com.jetbrains.python.psi.impl.PyReferenceExpressionImpl
import com.jetbrains.python.psi.impl.PyStarArgumentImpl
import com.jetbrains.python.psi.impl.references.PyReferenceImpl
import com.jetbrains.python.psi.resolve.PyResolveContext
import com.jetbrains.python.psi.types.PyClassTypeImpl

class PydanticInspection : PyInspection() {

Expand All @@ -29,7 +26,7 @@ class PydanticInspection : PyInspection() {

if (node != null) {
val pyClass: PyClass = (node.callee?.reference as? PyReferenceImpl)?.resolve() as? PyClass ?: return
if (!pyClass.isSubclass("pydantic.main.BaseModel", myTypeEvalContext)) return
if (!isPydanticModel(pyClass, myTypeEvalContext)) return
if ((node.callee as PyReferenceExpressionImpl).isQualified) return
for (argument in node.arguments) {
if (argument is PyKeywordArgument) {
Expand Down
11 changes: 0 additions & 11 deletions src/com/koxudaxi/pydantic/PydanticStub.kt

This file was deleted.

47 changes: 20 additions & 27 deletions src/com/koxudaxi/pydantic/PydanticTypeProvider.kt
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
}

val current = currentType.pyClass
if (!current.isSubclass("pydantic.main.BaseModel", context)) return null
if (!isPydanticModel(current, context)) return null

current
.classAttributes
Expand All @@ -107,14 +107,12 @@ class PydanticTypeProvider : PyTypeProviderBase() {
ellipsis: PyNoneLiteralExpression,
context: TypeEvalContext,
pyClass: PyClass): PyCallableParameter? {
val stub = field.stub
val fieldStub = if (stub == null) PydanticFieldStubImpl.create(field) else stub.getCustomStub(PydanticFieldStub::class.java)
if (fieldStub != null && !fieldStub.initValue()) return null
if (fieldStub == null && field.annotationValue == null && !field.hasAssignedValue()) return null // skip fields that are invalid syntax

if (field.annotationValue == null && !field.hasAssignedValue()) return null // skip fields that are invalid syntax

val defaultValue = when {
pyClass.isSubclass("pydantic.env_settings.BaseSettings", context) -> ellipsis
else -> getDefaultValueForParameter(field, fieldStub, ellipsis, context)
else -> getDefaultValueForParameter(field, ellipsis, context)
}

return PyCallableParameterImpl.nonPsi(field.name,
Expand All @@ -130,36 +128,31 @@ class PydanticTypeProvider : PyTypeProviderBase() {
}

private fun getDefaultValueForParameter(field: PyTargetExpression,
fieldStub: PydanticFieldStub?,
ellipsis: PyNoneLiteralExpression,
context: TypeEvalContext): PyExpression? {
if (fieldStub == null) {
val value = field.findAssignedValue()
when {
value == null -> {
val annotation = (field.annotation?.value as? PySubscriptionExpressionImpl) ?: return null

when {
annotation.qualifier?.text == "Optional" -> return ellipsis
annotation.qualifier?.text == "Union" -> for (child in annotation.children) {
if (child is PyTupleExpression) {
for (type in child.children) {
if (type is PyNoneLiteralExpression) {
return ellipsis
}
val value = field.findAssignedValue()
when {
value == null -> {
val annotation = (field.annotation?.value as? PySubscriptionExpressionImpl) ?: return null

when {
annotation.qualifier?.text == "Optional" -> return ellipsis
annotation.qualifier?.text == "Union" -> for (child in annotation.children) {
if (child is PyTupleExpression) {
for (type in child.children) {
if (type is PyNoneLiteralExpression) {
return ellipsis
}
}
}
}
return value
}
field.hasAssignedValue() -> return getDefaultValueByAssignedValue(field, ellipsis, context)
else -> return null
return value
}
} else if (fieldStub.hasDefault() || fieldStub.hasDefaultFactory()) {
return ellipsis
field.hasAssignedValue() -> return getDefaultValueByAssignedValue(field, ellipsis, context)
else -> return null
}
return null
}

private fun getResolveElements(referenceExpression: PyReferenceExpression, context: TypeEvalContext): Array<ResolveResult> {
Expand All @@ -184,7 +177,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
.asSequence()
.forEach { it ->
val pyClass = PsiTreeUtil.getContextOfType(it, PyClass::class.java)
if (pyClass != null && pyClass.isSubclass("pydantic.schema.Schema", context)) {
if (pyClass != null && isPydanticField(pyClass, context)) {
val defaultValue = assignedValue.getKeywordArgument("default")
?: assignedValue.getArgument(0, PyExpression::class.java)
when {
Expand Down

0 comments on commit eae6516

Please sign in to comment.