Skip to content

Commit

Permalink
Link args and annassign with type hints (#108)
Browse files Browse the repository at this point in the history
  • Loading branch information
felix-hilden committed Jun 8, 2022
1 parent cde0071 commit 21bd4f6
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 49 deletions.
1 change: 1 addition & 0 deletions docs/src/release_notes.rst
Expand Up @@ -15,6 +15,7 @@ Unreleased
:confval:`codeautolink_warn_on_failed_resolve` for debugging (:issue:`106`)
- Define extension environment version for Sphinx (:issue:`107`)
- Merge environments only when the extension is active (:issue:`107`)
- Link arguments and annotated assignment with type hints (:issue:`108`)

0.10.0 (2022-01-25)
-------------------
Expand Down
104 changes: 62 additions & 42 deletions src/sphinx_codeautolink/parse.py
Expand Up @@ -21,27 +21,6 @@ def parse_names(source: str, doctree_node) -> List['Name']:
return sum([split_access(a) for a in visitor.accessed], [])


@dataclass
class PendingAccess:
"""Pending name access."""

components: List[ast.AST]


@dataclass
class PendingAssign:
"""
Pending assign target.
`targets` represent the assignment targets.
If a single PendingAccess is found, it should be used to store the value
on the right hand side of the assignment. If multiple values are found,
they should overwrite any names in the current scope and not assign values.
"""

targets: Union[Optional[PendingAccess], List[Optional[PendingAccess]]]


@dataclass
class Component:
"""Name access component."""
Expand All @@ -61,6 +40,8 @@ def from_ast(cls, node):
elif isinstance(node, ast.Attribute):
name = node.attr
context = node.ctx.__class__.__name__.lower()
elif isinstance(node, ast.arg):
name = node.arg
elif isinstance(node, ast.Call):
name = NameBreak.call
else:
Expand All @@ -69,6 +50,27 @@ def from_ast(cls, node):
return cls(name, node.lineno, end_lineno, context)


@dataclass
class PendingAccess:
"""Pending name access."""

components: List[Component]


@dataclass
class PendingAssign:
"""
Pending assign target.
`targets` represent the assignment targets.
If a single PendingAccess is found, it should be used to store the value
on the right hand side of the assignment. If multiple values are found,
they should overwrite any names in the current scope and not assign values.
"""

targets: Union[Optional[PendingAccess], List[Optional[PendingAccess]]]


class NameBreak(str, Enum):
"""Elements that break name access chains."""

Expand Down Expand Up @@ -269,7 +271,7 @@ def _assign(self, local_name: str, components: List[Component]):
self.pseudo_scopes_stack[-1][local_name] = components

def _access(self, access: PendingAccess) -> Optional[Access]:
components = [Component.from_ast(n) for n in access.components]
components = access.components
prior = self.pseudo_scopes_stack[-1].get(components[0].name, None)

if prior is None:
Expand Down Expand Up @@ -306,7 +308,7 @@ def _resolve_assignment(self, assignment: Assignment):
continue

if len(target.components) == 1:
comp = Component.from_ast(target.components[0])
comp = target.components[0]
self._overwrite(comp.name)
if access is not None:
self._assign(comp.name, access.full_components)
Expand Down Expand Up @@ -408,24 +410,24 @@ def visit_ImportFrom(self, node: ast.ImportFrom):
self.visit_Import(node, prefix=node.module + '.')

@track_parents
def visit_Name(self, node):
def visit_Name(self, node: ast.Name):
"""Visit a Name node."""
return PendingAccess([node])
return PendingAccess([Component.from_ast(node)])

@track_parents
def visit_Attribute(self, node):
def visit_Attribute(self, node: ast.Attribute):
"""Visit an Attribute node."""
inner: Optional[PendingAccess] = self.visit(node.value)
if inner is not None:
inner.components.append(node)
inner.components.append(Component.from_ast(node))
return inner

@track_parents
def visit_Call(self, node: ast.Call):
"""Visit a Call node."""
inner: Optional[PendingAccess] = self.visit(node.func)
if inner is not None:
inner.components.append(node)
inner.components.append(Component.from_ast(node))
with self.reset_parents():
for arg in node.args + node.keywords:
self.visit(arg)
Expand Down Expand Up @@ -462,14 +464,22 @@ def visit_Assign(self, node: ast.Assign):
@track_parents
def visit_AnnAssign(self, node: ast.AnnAssign):
"""Visit an AnnAssign node."""
if node.value is not None:
value = self.visit(node.value)
target = self.visit(node.target)

with self.reset_parents():
self.visit(node.annotation)
value = self.visit(node.value) if node.value is not None else None
annot = self.visit(node.annotation)
if annot is not None:
if value is not None:
self._access(value)

annot.components.append(Component(
NameBreak.call,
node.annotation.lineno,
node.annotation.end_lineno,
'load',
))
value = annot

if node.value is not None:
target = self.visit(node.target)
if value is not None:
return Assignment([PendingAssign(target)], value)

def visit_AugAssign(self, node: ast.AugAssign):
Expand Down Expand Up @@ -528,31 +538,41 @@ def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]):
self._overwrite(node.name)
for dec in node.decorator_list:
self.visit(dec)
if node.returns is not None:
self.visit(node.returns)
for d in node.args.defaults + node.args.kw_defaults:
if d is None:
continue
self.visit(d)
args = self._get_args(node.args)
args += [node.args.vararg, node.args.kwarg]
for arg in args:
if arg is None or arg.annotation is None:
continue
self.visit(arg.annotation)

inner = self.__class__(self.doctree_node)
inner.pseudo_scopes_stack[0] = self.pseudo_scopes_stack[0].copy()
inner.outer_scopes_stack = list(self.outer_scopes_stack)
inner.outer_scopes_stack.append(self.pseudo_scopes_stack[0])

for arg in args:
if arg is None:
continue
inner._overwrite(arg.arg)
inner.visit(arg)
if node.returns is not None:
self.visit(node.returns)
for n in node.body:
inner.visit(n)
self.accessed.extend(inner.accessed)

@track_parents
def visit_arg(self, arg: ast.arg):
"""Handle function argument and its annotation."""
target = PendingAccess([Component.from_ast(arg)])
if arg.annotation is not None:
value = self.visit(arg.annotation)
if value is not None:
comp = Component(NameBreak.call, arg.lineno, arg.end_lineno, 'load')
value.components.append(comp)
else:
value = None
return Assignment([PendingAssign(target)], value)

def visit_Lambda(self, node: ast.Lambda):
"""Swap node order and separate inner scope."""
for d in node.args.defaults + node.args.kw_defaults:
Expand Down
18 changes: 18 additions & 0 deletions tests/parse/__init__.py
Expand Up @@ -61,6 +61,24 @@ def test_simple_import_then_access(self):
refs = [('lib', 'lib'), ('lib', 'lib')]
return s, refs

@refs_equal
def test_inside_list_literal(self):
s = 'import lib\n[lib]'
refs = [('lib', 'lib'), ('lib', 'lib')]
return s, refs

@refs_equal
def test_inside_subscript(self):
s = 'import lib\n0[lib]'
refs = [('lib', 'lib'), ('lib', 'lib')]
return s, refs

@refs_equal
def test_outside_subscript(self):
s = 'import lib\nlib[0]'
refs = [('lib', 'lib'), ('lib', 'lib')]
return s, refs

@refs_equal
def test_simple_import_then_attrib(self):
s = 'import lib\nlib.attr'
Expand Down
2 changes: 2 additions & 0 deletions tests/parse/_util.py
Expand Up @@ -8,6 +8,8 @@ def wrapper(self):
source, expected = func(self)
names = parse_names(source, doctree_node=None)
names = sorted(names, key=lambda name: name.lineno)
print('All names:')
[print(n) for n in names]
for n, e in zip(names, expected):
s = '.'.join(c for c in n.import_components)
assert s == e[0], f'Wrong import! Expected\n{e}\ngot\n{n}'
Expand Down
27 changes: 21 additions & 6 deletions tests/parse/assign.py
Expand Up @@ -78,27 +78,42 @@ def test_augassign_uses_and_assigns_imported(self):
return s, refs

@refs_equal
def test_annassign_uses_imported(self):
def test_annassign_overwrites_imported(self):
s = 'import a\na: b = 1\na'
refs = [('a', 'a')]
return s, refs

@refs_equal
def test_annassign_uses_and_assigns_imported(self):
s = 'import a\na: b = a\na'
refs = [('a', 'a'), ('a', 'a'), ('a', 'a')]
s = 'import a\nb: 1 = a\nb.c'
refs = [('a', 'a'), ('a', 'a'), ('a.c', 'b.c')]
return s, refs

@refs_equal
def test_annassign_uses_and_annotates_imported(self):
s = 'import a\nb: a = 1\nb.c'
refs = [('a', 'a'), ('a', 'a'), ('a.().c', 'b.c')]
return s, refs

@refs_equal
def test_annassign_prioritises_annotation(self):
s = 'import a, b\nc: a = b\nc.d'
# note that AnnAssign is executed from value -> annot -> target
refs = [('a', 'a'), ('b', 'b'), ('b', 'b'), ('a', 'a'), ('a.().d', 'c.d')]
return s, refs

@refs_equal
def test_annassign_why_would_anyone_do_this(self):
s = 'import a\na: a = a\na'
refs = [('a', 'a'), ('a', 'a'), ('a', 'a'), ('a', 'a')]
s = 'import a\na: a = a\na.b'
refs = [('a', 'a'), ('a', 'a'), ('a', 'a'), ('a.().b', 'a.b')]
return s, refs

@refs_equal
def test_annassign_without_value_overrides_annotation_but_not_linked(self):
# note that this is different from runtime behavior
# which does not overwrite the variable value
s = 'import a\na: b\na'
refs = [('a', 'a'), ('a', 'a')]
refs = [('a', 'a')]
return s, refs

@pytest.mark.skipif(
Expand Down
29 changes: 28 additions & 1 deletion tests/parse/scope.py
Expand Up @@ -42,11 +42,38 @@ def test_func_assigns_then_used_outside(self):
return s, refs

@refs_equal
def test_func_annotations_then_assigns(self):
def test_func_annotates_then_uses(self):
s = 'import a\ndef f(arg: a):\n arg.b'
refs = [('a', 'a'), ('a', 'a'), ('a.().b', 'arg.b')]
return s, refs

@refs_equal
def test_func_annotates_then_assigns(self):
s = 'import a\ndef f(arg: a) -> a:\n a = 1'
refs = [('a', 'a'), ('a', 'a'), ('a', 'a')]
return s, refs

@refs_equal
def test_func_annotates_as_generic_then_uses(self):
s = 'import a\ndef f(arg: a[0]):\n arg.b'
refs = [('a', 'a'), ('a', 'a')]
return s, refs

@refs_equal
def test_func_annotates_inside_generic_then_uses(self):
s = 'import a\ndef f(arg: b[a]):\n arg.b'
refs = [('a', 'a'), ('a', 'a')]
return s, refs

@pytest.mark.skipif(
sys.version_info < (3, 10), reason='Union syntax introduced in 3.10.'
)
@refs_equal
def test_func_annotates_union_then_uses(self):
s = 'import a\ndef f(arg: a | 1):\n arg.b'
refs = [('a', 'a'), ('a', 'a')]
return s, refs

@refs_equal
def test_func_kw_default_uses(self):
s = 'import a\ndef f(*_, c, b=a):\n pass'
Expand Down

0 comments on commit 21bd4f6

Please sign in to comment.