/
docstring_inference.py
77 lines (61 loc) · 2.58 KB
/
docstring_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from astroid import MANAGER, UseInferenceDefault, inference_tip, YES, InferenceError, nodes
from astroid.builder import AstroidBuilder
import xml.etree.ElementTree as etree
from docutils.core import publish_doctree
from grammar import parse_node
def register(linter):
    """Plugin hook that intentionally registers nothing on *linter*.

    This module installs its inference transforms on astroid's ``MANAGER``
    at import time, so there is no checker to add here.  Presumably this
    exists only so pylint can load the module as a plugin — confirm.

    :param linter: the linter instance passed by the caller; unused.
    """
def _docstring_fields(docstring):
    """Return the ``(name, body)`` pairs of every reST field in *docstring*.

    The docstring is parsed with docutils, converted to XML, and each
    ``field`` element that is a direct child of a ``field_list`` contributes
    one pair: the text of its ``field_name`` and the text of the first
    ``paragraph`` inside its ``field_body``.  A malformed field (no name, no
    paragraph, or a paragraph with no leading text because it starts with
    inline markup) is skipped on its own instead of ending the scan.

    NOTE(review): ``paragraph.text`` only holds the text that precedes the
    first inline element of the paragraph, so text after inline markup is
    not included in the returned body.

    :param docstring: raw docstring text (reStructuredText).
    :returns: list of ``(field_name, field_body)`` strings in document order.
    """
    # report_level/halt_level 5 ("none") keep docutils from writing reST
    # diagnostics to stderr or raising SystemMessage for a malformed
    # docstring, which would otherwise propagate out of the inference tip.
    doctree = publish_doctree(
        docstring,
        settings_overrides={"report_level": 5, "halt_level": 5},
    )
    root = etree.fromstring(doctree.asdom().toxml())
    fields = []
    for field in root.iterfind(".//field_list/field"):
        name = field.findtext("field_name")
        paragraph = field.find("field_body/paragraph")
        if not name or paragraph is None or paragraph.text is None:
            continue
        fields.append((name, paragraph.text))
    return fields


def infer_rtype(node, context=None):
    """Infer the result of a call from the callee's ``:rtype:`` docstring field.

    Each value that ``node.func`` infers to is examined in order; the first
    one whose docstring contains a field whose name starts with ``rtype``
    supplies the text handed to ``parse_node``.  Candidates without a
    docstring, and malformed fields, are skipped so that a later candidate
    or field can still match.

    :param node: the call node being inferred.
    :param context: optional astroid inference context; a clone of it is
        used to infer the callee, while the original is passed on to
        ``parse_node``.
    :returns: the result of ``parse_node(node, context, rtype_text)``.
    :raises UseInferenceDefault: if the callee cannot be inferred, one of its
        inferred values is ``YES``, or no candidate documents an ``rtype``.
    """
    context_copy = context.clone() if context is not None else None
    try:
        for inferred in node.func.infer(context_copy):
            if inferred is YES:
                # Give up on the docstring route as soon as any candidate is
                # uninferable; astroid's regular inference takes over.
                raise UseInferenceDefault()
            # Presumably not every inferred node type defines ``doc``;
            # getattr keeps such candidates from raising AttributeError.
            docstring = getattr(inferred, "doc", None)
            if not docstring:
                continue
            for field_name, field_body in _docstring_fields(docstring):
                if field_name.startswith("rtype"):
                    return parse_node(node, context, field_body)
    except InferenceError:
        # Failing to infer the callee should fall back to default inference
        # instead of escaping the tip as an inference failure.
        raise UseInferenceDefault()
    # found nothing
    raise UseInferenceDefault()


def infer_arg(node, context=None):
    """Infer a function parameter from the ``:type <name>:`` docstring field.

    :param node: the assigned-name node being inferred; only names whose
        parent is a ``nodes.Arguments`` belonging to a ``nodes.Function``
        are handled.
    :param context: astroid inference context, passed through to
        ``parse_node``.
    :returns: the result of ``parse_node(node, context, type_text)``.
    :raises UseInferenceDefault: if *node* is not a function parameter, the
        function has no docstring, or no usable ``type`` field names it.
    """
    arguments = node.parent
    if not isinstance(arguments, nodes.Arguments):
        raise UseInferenceDefault()
    func = arguments.parent
    if not isinstance(func, nodes.Function):
        raise UseInferenceDefault()
    docstring = func.doc
    if not docstring:
        raise UseInferenceDefault()
    # Compare word by word so extra whitespace in ``:type  x:`` still matches.
    wanted = ["type", node.name]
    for field_name, field_body in _docstring_fields(docstring):
        if field_name.split() == wanted:
            return parse_node(node, context, field_body)
    raise UseInferenceDefault()
# Hook the docstring-based inference functions into astroid at import time:
# call nodes get return types from ``:rtype:``, and assigned names get
# parameter types from ``:type <name>:``.  Registration order is preserved.
for _node_class, _infer in (
    (nodes.CallFunc, infer_rtype),
    (nodes.AssName, infer_arg),
):
    MANAGER.register_transform(_node_class, inference_tip(_infer))
# Drop the loop variables so the module namespace is unchanged.
del _node_class, _infer