Skip to content

Commit

Permalink
Support alias on schema (#64)
Browse files Browse the repository at this point in the history
* Support alias on a schema
  • Loading branch information
koxudaxi committed Sep 9, 2019
1 parent d238f50 commit 09a68e3
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 49 deletions.
7 changes: 6 additions & 1 deletion resources/META-INF/plugin.xml
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
<idea-plugin url="https://github.com/koxudaxi/pydantic-pycharm-plugin">
<id>com.koxudaxi.pydantic</id>
<name>Pydantic</name>
<version>0.0.17</version>
<version>0.0.18</version>
<vendor email="koaxudai@gmail.com">Koudai Aono @koxudaxi</vendor>
<change-notes><![CDATA[
<h2>version 0.0.18</h2>
<p>Features</p>
<ul>
<li>Support alias on Schema [#64] </li>
</ul>
<h2>version 0.0.17</h2>
<p>BugFixes</p>
<ul>
Expand Down
34 changes: 34 additions & 0 deletions src/com/koxudaxi/pydantic/Pydantic.kt
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
package com.koxudaxi.pydantic

import com.intellij.psi.ResolveResult
import com.intellij.psi.util.PsiTreeUtil
import com.intellij.psi.util.QualifiedName
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider
import com.jetbrains.python.psi.*
import com.jetbrains.python.psi.impl.PyCallExpressionImpl
import com.jetbrains.python.psi.impl.PyTargetExpressionImpl
import com.jetbrains.python.psi.resolve.PyResolveContext
import com.jetbrains.python.psi.resolve.PyResolveUtil
import com.jetbrains.python.psi.types.TypeEvalContext

Expand Down Expand Up @@ -66,3 +70,33 @@ internal fun getClassVariables(pyClass: PyClass, context: TypeEvalContext): Sequ
.asSequence()
.filterNot { PyTypingTypeProvider.isClassVar(it, context) }
}

internal fun getAliasedFieldName(field: PyTargetExpression, context: TypeEvalContext): String? {
val fieldName = field.name
val assignedValue = field.findAssignedValue() ?: return fieldName
val callee = (assignedValue as? PyCallExpressionImpl)?.callee ?: return fieldName
val referenceExpression = callee.reference?.element as? PyReferenceExpression ?: return fieldName


val resolveResults = getResolveElements(referenceExpression, context)
return PyUtil.filterTopPriorityResults(resolveResults)
.mapNotNull { PsiTreeUtil.getContextOfType(it, PyClass::class.java) }
.filter { isPydanticField(it, context) }
.mapNotNull {
when (val alias = assignedValue.getKeywordArgument("alias")) {
is StringLiteralExpression -> alias.stringValue
is PyReferenceExpression -> ((alias.reference.resolve() as? PyTargetExpressionImpl)
?.findAssignedValue() as? StringLiteralExpression)?.stringValue
//TODO Support dynamic assigned Value. eg: Schema(..., alias=get_alias_name(field_name))
else -> null
}
}
.firstOrNull() ?: fieldName
}


internal fun getResolveElements(referenceExpression: PyReferenceExpression, context: TypeEvalContext): Array<ResolveResult> {
val resolveContext = PyResolveContext.noImplicits().withTypeEvalContext(context)
return referenceExpression.getReference(resolveContext).multiResolve(false)

}
54 changes: 28 additions & 26 deletions src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@ import com.intellij.codeInsight.lookup.LookupElement
import com.intellij.codeInsight.lookup.LookupElementBuilder
import com.intellij.icons.AllIcons
import com.intellij.patterns.PlatformPatterns.psiElement
import com.intellij.psi.util.PsiTreeUtil
import com.intellij.psi.util.PsiTreeUtil.getParentOfType
import com.intellij.util.ProcessingContext
import com.jetbrains.python.PyTokenTypes
import com.jetbrains.python.codeInsight.completion.getTypeEvalContext
import com.jetbrains.python.documentation.PythonDocumentationProvider.*
import com.jetbrains.python.documentation.PythonDocumentationProvider.getTypeHint
import com.jetbrains.python.psi.*
import com.jetbrains.python.psi.resolve.PyResolveContext
import com.jetbrains.python.psi.types.PyClassType
Expand Down Expand Up @@ -38,7 +37,7 @@ class PydanticCompletionContributor : CompletionContributor() {

abstract val icon: Icon

abstract fun getLookupNameFromFieldName(fieldName: String): String
abstract fun getLookupNameFromFieldName(field: PyTargetExpression, context: TypeEvalContext): String

val typeProvider: PydanticTypeProvider = PydanticTypeProvider()

Expand All @@ -51,7 +50,7 @@ class PydanticCompletionContributor : CompletionContributor() {
val defaultValue = parameter?.defaultValue?.let {
if (parameter.defaultValue is PyNoneLiteralExpression && !isBaseSetting(pyClass, typeEvalContext)) {
"=None"
} else{
} else {
"=${parameter.defaultValueText}"
}
} ?: ""
Expand All @@ -69,11 +68,11 @@ class PydanticCompletionContributor : CompletionContributor() {
}
}

protected fun getPyClassByPyReferenceExpression(pyReferenceExpression: PyReferenceExpression, typeEvalContext: TypeEvalContext, parameters: CompletionParameters? = null, result: CompletionResultSet? = null): PyClass? {
protected fun getPyClassByPyReferenceExpression(pyReferenceExpression: PyReferenceExpression, typeEvalContext: TypeEvalContext, parameters: CompletionParameters?, result: CompletionResultSet?): PyClass? {
val resolveContext = PyResolveContext.defaultContext().withTypeEvalContext(typeEvalContext)
return pyReferenceExpression.multiFollowAssignmentsChain(resolveContext).mapNotNull {
when (val resolveElement = it.element) {
is PyClass -> {
is PyClass -> {
if (parameters != null && result != null) {
removeAllFieldElement(parameters, result, resolveElement, typeEvalContext, excludeFields)
null
Expand All @@ -83,16 +82,17 @@ class PydanticCompletionContributor : CompletionContributor() {
}
is PyCallExpression -> getPyClassByPyCallExpression(resolveElement)
is PyNamedParameter -> {
if ((parameters != null && result != null) && resolveElement.isSelf){
getParentOfType(resolveElement, PyFunction::class.java)
?.takeIf { it.modifier == PyFunction.Modifier.CLASSMETHOD }
?.takeIf { it.containingClass is PyClass }
?.let {
removeAllFieldElement(parameters, result, it.containingClass!!, typeEvalContext, excludeFields)
return null
}
}
getPyClassFromPyNamedParameter(resolveElement, typeEvalContext)}
if ((parameters != null && result != null) && resolveElement.isSelf) {
getParentOfType(resolveElement, PyFunction::class.java)
?.takeIf { it.modifier == PyFunction.Modifier.CLASSMETHOD }
?.takeIf { it.containingClass is PyClass }
?.let {
removeAllFieldElement(parameters, result, it.containingClass!!, typeEvalContext, excludeFields)
return null
}
}
getPyClassFromPyNamedParameter(resolveElement, typeEvalContext)
}
else -> null
}
}.firstOrNull()
Expand All @@ -105,7 +105,7 @@ class PydanticCompletionContributor : CompletionContributor() {
getClassVariables(pyClass, typeEvalContext)
.filter { it.name != null }
.forEach {
val elementName = getLookupNameFromFieldName(it.name!!)
val elementName = getLookupNameFromFieldName(it, typeEvalContext)
if (excludes == null || !excludes.contains(elementName)) {
val element = PrioritizedLookupElement.withGrouping(
LookupElementBuilder
Expand Down Expand Up @@ -138,9 +138,10 @@ class PydanticCompletionContributor : CompletionContributor() {
}
result.addAllElements(newElements.values)
}

protected fun removeAllFieldElement(parameters: CompletionParameters, result: CompletionResultSet,
pyClass: PyClass, typeEvalContext: TypeEvalContext,
excludes: HashSet<String>? = null) {
pyClass: PyClass, typeEvalContext: TypeEvalContext,
excludes: HashSet<String>) {

if (!isPydanticModel(pyClass)) return

Expand All @@ -155,7 +156,7 @@ class PydanticCompletionContributor : CompletionContributor() {

result.runRemainingContributors(parameters)
{ completionResult ->
if (completionResult.lookupElement.psiElement?.getIcon(0) == AllIcons.Nodes.Field ) {
if (completionResult.lookupElement.psiElement?.getIcon(0) == AllIcons.Nodes.Field) {
completionResult.lookupElement.lookupString
.takeIf { name -> !fieldElements.contains(name) && (excludes == null || !excludes.contains(name)) }
?.let { result.passResult(completionResult) }
Expand All @@ -167,8 +168,8 @@ class PydanticCompletionContributor : CompletionContributor() {
}

private object KeywordArgumentCompletionProvider : PydanticCompletionProvider() {
override fun getLookupNameFromFieldName(fieldName: String): String {
return "${fieldName}="
override fun getLookupNameFromFieldName(field: PyTargetExpression, context: TypeEvalContext): String {
return "${getAliasedFieldName(field, context)}="
}

override val icon: Icon = AllIcons.Nodes.Parameter
Expand All @@ -178,7 +179,7 @@ class PydanticCompletionContributor : CompletionContributor() {
val typeEvalContext = parameters.getTypeEvalContext()

val pyClass = when (val pyCallableElement = pyArgumentList.parent!!) {
is PyReferenceExpression -> getPyClassByPyReferenceExpression(pyCallableElement, typeEvalContext)
is PyReferenceExpression -> getPyClassByPyReferenceExpression(pyCallableElement, typeEvalContext, null, null)
?: return
is PyCallExpression -> getPyClassByPyCallExpression(pyCallableElement) ?: return
else -> return
Expand All @@ -196,16 +197,17 @@ class PydanticCompletionContributor : CompletionContributor() {
}

private object FieldCompletionProvider : PydanticCompletionProvider() {
override fun getLookupNameFromFieldName(fieldName: String): String {
return fieldName
override fun getLookupNameFromFieldName(field: PyTargetExpression, context: TypeEvalContext): String {
return field.name!!
}

override val icon: Icon = AllIcons.Nodes.Field

override fun addCompletions(parameters: CompletionParameters, context: ProcessingContext, result: CompletionResultSet) {
val typeEvalContext = parameters.getTypeEvalContext()
val pyClass = when (val instance = parameters.position.parent.firstChild) {
is PyReferenceExpression -> getPyClassByPyReferenceExpression(instance, typeEvalContext, parameters, result) ?: return
is PyReferenceExpression -> getPyClassByPyReferenceExpression(instance, typeEvalContext, parameters, result)
?: return
is PyCallExpression -> getPyClassByPyCallExpression(instance) ?: return
else -> return
}
Expand Down
34 changes: 14 additions & 20 deletions src/com/koxudaxi/pydantic/PydanticTypeProvider.kt
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
package com.koxudaxi.pydantic

import com.intellij.lang.ASTNode
import com.intellij.openapi.util.Ref
import com.intellij.psi.PsiElement
import com.intellij.psi.ResolveResult
import com.intellij.psi.util.PsiTreeUtil
import com.intellij.util.containers.isNullOrEmpty
import com.jetbrains.python.PyElementTypes.NONE_LITERAL_EXPRESSION
import com.jetbrains.python.PyNames
import com.jetbrains.python.psi.*
import com.jetbrains.python.psi.impl.PyCallExpressionImpl
Expand Down Expand Up @@ -136,7 +133,8 @@ class PydanticTypeProvider : PyTypeProviderBase() {
getTypeForParameter(field, context)
}

return PyCallableParameterImpl.nonPsi(field.name, typeForParameter, defaultValue)

return PyCallableParameterImpl.nonPsi(getAliasedFieldName(field, context), typeForParameter, defaultValue)
}

private fun getTypeForParameter(field: PyTargetExpression,
Expand Down Expand Up @@ -169,12 +167,6 @@ class PydanticTypeProvider : PyTypeProviderBase() {
}
}

private fun getResolveElements(referenceExpression: PyReferenceExpression, context: TypeEvalContext): Array<ResolveResult> {
val resolveContext = PyResolveContext.noImplicits().withTypeEvalContext(context)
return referenceExpression.getReference(resolveContext).multiResolve(false)

}

private fun getDefaultValueByAssignedValue(field: PyTargetExpression,
ellipsis: PyNoneLiteralExpression,
context: TypeEvalContext): PyExpression? {
Expand All @@ -189,18 +181,20 @@ class PydanticTypeProvider : PyTypeProviderBase() {

val resolveResults = getResolveElements(referenceExpression, context)
PyUtil.filterTopPriorityResults(resolveResults)
.forEach { it ->
val pyClass = PsiTreeUtil.getContextOfType(it, PyClass::class.java)
if (pyClass != null && isPydanticField(pyClass, context)) {
val defaultValue = assignedValue.getKeywordArgument("default")
?: assignedValue.getArgument(0, PyExpression::class.java)
return when {
defaultValue == null -> null
defaultValue.text == "..." -> null
else -> defaultValue
.mapNotNull { PsiTreeUtil.getContextOfType(it, PyClass::class.java) }
.any { isPydanticField(it, context) }.let {
return when {
it -> {
val defaultValue = assignedValue.getKeywordArgument("default")
?: assignedValue.getArgument(0, PyExpression::class.java)
when {
defaultValue == null -> null
defaultValue.text == "..." -> null
else -> defaultValue
}
}
else -> assignedValue
}
}
return assignedValue
}
}
11 changes: 10 additions & 1 deletion testData/completion/fieldSchema.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
from builtins import *
from pydantic import BaseModel, Schema

def get_alias():
return 'alias_c_id'
b_id = 'alias_b_id'
class A(BaseModel):
abc: str = Schema(...)
cde = Schema(str('abc'))
efg = Schema(default=str('abc'))
hij = Schema(default=...)

a_id: str = Schema(..., alias='alias_a_id')
b_id: str = Schema(..., alias=b_id)
c_id: str = Schema(..., alias=get_alias())
d_id: str = Schema(..., alias=)
e_id: str = Schema(..., alias=broken)
f_id: str = Schema(..., alias=123)
g_id: str = get_alias()
class B(A):
hij: str

Expand Down
22 changes: 22 additions & 0 deletions testData/completion/keywordArgumentSchema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from builtins import *
from pydantic import BaseModel, Schema

def get_alias():
return 'alias_c_id'
b_id: str = 'alias_b_id'
class A(BaseModel):
abc: str = Schema(...)
cde = Schema(str('abc'))
efg = Schema(default=str('abc'))
hij = Schema(default=...)
a_id: str = Schema(..., alias='alias_a_id')
b_id: str = Schema(..., alias=b_id)
c_id: str = Schema(..., alias=get_alias())
d_id: str = Schema(..., alias=)
e_id: str = Schema(..., alias=broken)
f_id: str = Schema(..., alias=123)
g_id: str = get_alias()
class B(A):
hij: str

A(<caret>)
27 changes: 26 additions & 1 deletion testSrc/com/koxudaxi/pydantic/PydanticCompletionTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ open class PydanticCompletionTest : PydanticTestCase() {
it!!.psiElement is PyTargetExpression
}.mapNotNull {

Pair(it!!.lookupString, LookupElementPresentation.renderElement(it).typeText)
Pair(it!!.lookupString, LookupElementPresentation.renderElement(it).typeText ?: "null")
}
assertEquals(fieldNames, actual)
}
Expand Down Expand Up @@ -331,9 +331,16 @@ open class PydanticCompletionTest : PydanticTestCase() {
fun testFieldSchema() {
doFieldTest(
listOf(
Pair("a_id","str A"),
Pair("abc", "str A"),
Pair("b_id","str A"),
Pair("c_id","str A"),
Pair("cde", "str=str('abc') A"),
Pair("d_id", "str A"),
Pair("e_id", "str A"),
Pair("efg", "str=str('abc') A"),
Pair("f_id", "str A"),
Pair("g_id", "str=get_alias() A"),
Pair("hij", "Any A"),
Pair("___slots__", "BaseModel")
)
Expand Down Expand Up @@ -375,4 +382,22 @@ open class PydanticCompletionTest : PydanticTestCase() {
)
)
}
fun testKeywordArgumentSchema() {
doFieldTest(
listOf(
Pair("abc=", "str A"),
Pair("alias_a_id=", "str A"),
Pair("alias_b_id=", "str A"),
Pair("c_id=", "str A"),
Pair("cde=", "str=str('abc') A"),
Pair("d_id=", "str A"),
Pair("e_id=", "str A"),
Pair("efg=", "str=str('abc') A"),
Pair("f_id=", "str A"),
Pair("g_id=", "str=get_alias() A"),
Pair("hij=", "Any A"),
Pair("b_id", "null")
)
)
}
}

0 comments on commit 09a68e3

Please sign in to comment.