This repository has been archived by the owner on Dec 2, 2023. It is now read-only.

Rename insert_grad_of, fixes #16
alexbw committed Nov 7, 2017
1 parent 8b96e52 commit 01f861a
Showing 9 changed files with 21 additions and 21 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -9,7 +9,7 @@
 
 setup(
     name='tangent',
-    version='0.1.2',
+    version='0.1.3',
     description=('Automatic differentiation using source code transformation '
                  'for Python'),
     long_description=readme,
10 changes: 5 additions & 5 deletions tangent/__init__.py
@@ -28,7 +28,7 @@
 from tangent.utils import balanced_eq
 from tangent.utils import copy
 from tangent.utils import grad_dot
-from tangent.utils import grad_of
+from tangent.utils import insert_grad_of
 from tangent.utils import init_grad
 from tangent.utils import pop
 from tangent.utils import pop_stack
@@ -45,25 +45,25 @@
 
 
 class RemoveWith(gast.NodeTransformer):
-  """A transformer that removes `with grad_of` statements."""
+  """A transformer that removes `with insert_grad_of` statements."""
 
   def visit_With(self, node):
-    if ast_.is_grad_of_statement(node):
+    if ast_.is_insert_grad_of_statement(node):
       return None
     else:
       return node
 
 
 def tangent(f):
-  """A decorator which removes the `with grad_of` statement.
+  """A decorator which removes the `with insert_grad_of` statement.
 
   This allows the function to be called as usual.
 
   Args:
     f: A function
   Returns:
-    A function with any `with grad_of` context managers removed.
+    A function with any `with insert_grad_of` context managers removed.
   """
   node = annotate.resolve_calls(f)
   RemoveWith().visit(node)
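
The decorator's effect is easiest to see in use: `RemoveWith` strips the `with insert_grad_of` blocks so the decorated function runs as plain Python. A minimal sketch, assuming only the names re-exported above (the body of the `with` block is illustrative):

    import tangent
    from tangent import insert_grad_of

    @tangent.tangent
    def f(x):
      y = x * x
      with insert_grad_of(x) as dx:
        # This block is only meaningful to tangent.grad; the decorator
        # removes it so a direct call works as usual.
        dx = dx * 0.9
      return y

    f(3.0)  # 9.0 -- the with block is gone, so no ValueError is raised

The `func = getattr(func, 'tangent', func)` line in tangent/grad_util.py below is the other half of this: `grad` retrieves the original, undecorated function so the `with` block is still visible to the autodiff transform.
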
10 changes: 5 additions & 5 deletions tangent/ast.py
@@ -112,20 +112,20 @@ def append_args(node, node_list):
   return ArgAppend(node_list).visit(node)
 
 
-def is_grad_of_statement(node):
-  """Check whether a context manager calls `grad_of`.
+def is_insert_grad_of_statement(node):
+  """Check whether a context manager calls `insert_grad_of`.
 
   Args:
     node: The context manager node.
   Returns:
-    Whether or not this node contains `grad_of` calls.
+    Whether or not this node contains `insert_grad_of` calls.
   Raises:
-    ValueError: If the `grad_of` calls are mixed with other calls.
+    ValueError: If the `insert_grad_of` calls are mixed with other calls.
   """
   tangent_calls = [anno.getanno(item.context_expr, 'func', None)
-                   is utils.grad_of for item in node.items]
+                   is utils.insert_grad_of for item in node.items]
   if all(tangent_calls):
     return True
   elif any(tangent_calls):
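
The `all`/`any` pair above implements a three-way outcome: every context manager is `insert_grad_of` (handle specially), none are (ordinary `with`), or a mixture (refuse). A standalone sketch of just that decision, with `calls` standing in for `tangent_calls`:

    def classify(calls):
      # calls: one boolean per context manager, True if it is insert_grad_of.
      if all(calls):
        return True      # e.g. `with insert_grad_of(x) as dx:`
      elif any(calls):
        raise ValueError('insert_grad_of calls are mixed with other calls')
      return False       # e.g. `with open(path) as f:`

    classify([True])         # True: treat the body as adjoint code
    classify([False])        # False: an ordinary with statement
    classify([True, False])  # raises ValueError
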
2 changes: 1 addition & 1 deletion tangent/fence.py
@@ -29,7 +29,7 @@
 
 def validate(node, source):
   """Call this function to validate an AST."""
-  # TODO: leaving strict checking off to support grad_of
+  # TODO: leaving strict checking off to support insert_grad_of
   lf = LanguageFence(source, strict=False)
   lf.visit(node)
   return node
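
A sketch of calling the validator, assuming `gast` (the AST library tangent uses) and the import path implied by the file name:

    import gast
    from tangent.fence import validate  # assumed import path

    source = 'def f(x):\n  return x * x\n'
    node = gast.parse(source)
    validate(node, source)  # returns node; LanguageFence raises on unsupported syntax
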
2 changes: 1 addition & 1 deletion tangent/grad_util.py
@@ -171,7 +171,7 @@ def grad(func,
     `preserve_result` is True, the function will also return the original
     result of `func`.
   """
-  # If the function had the with grad_of statements removed, retrieve them
+  # If the function had the with insert_grad_of statements removed, retrieve them
   func = getattr(func, 'tangent', func)
 
   # Take the gradient
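
For reference, `grad` is the user-facing entry point being touched here; a minimal sketch of the round trip (the numeric value follows from d(x·x)/dx = 2x):

    import tangent

    def f(x):
      return x * x

    df = tangent.grad(f)  # source-to-source reverse-mode derivative
    print(df(2.0))        # 4.0
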
4 changes: 2 additions & 2 deletions tangent/reverse_ad.py
@@ -358,8 +358,8 @@ def visit_While(self, node):
     return primal, adjoint
 
   def visit_With(self, node):
-    """Deal with the special with grad_of(x) statement."""
-    if ast_.is_grad_of_statement(node):
+    """Deal with the special with insert_grad_of(x) statement."""
+    if ast_.is_insert_grad_of_statement(node):
       primal = []
       adjoint = node.body
       if isinstance(adjoint[0], gast.With):
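
What `primal = []` and `adjoint = node.body` mean: the `with insert_grad_of` body contributes nothing to the forward pass and is spliced, as-is, into the backward pass. A small gast-level sketch of the node being matched (illustrative, not tangent's internal flow):

    import gast

    tree = gast.parse('with insert_grad_of(x) as dx:\n  dx = dx * 0.9\n')
    with_node = tree.body[0]      # a gast.With node
    print(len(with_node.items))   # 1 context manager: insert_grad_of(x)
    # visit_With emits nothing into the primal and keeps
    # with_node.body as adjoint statements.
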
6 changes: 3 additions & 3 deletions tangent/utils.py
@@ -517,10 +517,10 @@ def push_stack(stack, substack, op_id):
   stack.append(substack)
 
 
-def grad_of(var):
+def insert_grad_of(var):
   """The context manager that allows insertion of arbitrary adjoint code.
 
-  This function can be used as a context manager e.g. `with grad_of(x) as dx`
+  This function can be used as a context manager e.g. `with insert_grad_of(x) as dx`
   to write code that will be inserted in the adjoint while having access to the
   gradients of certain variables.
@@ -540,7 +540,7 @@ def grad_of(var):
   decorator and the code is actually run.
   """
   raise ValueError('use the tangent decorator for functions containing '
-                   'the `with grad_of` statement')
+                   'the `with insert_grad_of` statement')
 
 
 def grad_dot(dy, x1, x2):
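
The `raise` above is what users hit if they skip the decorator: the context manager only marks a splice point and cannot execute. A sketch of the failure mode, using only names confirmed by this diff:

    from tangent import insert_grad_of

    def g(x):
      with insert_grad_of(x) as dx:  # no @tangent.tangent decorator
        dx = dx * 2.0
      return x

    g(1.0)  # ValueError: use the tangent decorator for functions containing ...
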
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -25,7 +25,7 @@
     'cart2polar',
     'iterpower_with_nested_def',
     'fn_multiple_return',
-    'grad_of',
+    'insert_grad_of',
     '_trace_mul',
     '_nontrace_mul',
     'active_subscript',  # TODO: fix then remove from blacklist
4 changes: 2 additions & 2 deletions tests/functions.py
@@ -25,7 +25,7 @@
 import numpy as np
 
 import tangent
-from tangent import grad_of
+from tangent import insert_grad_of
 
 import tensorflow as tf
 
@@ -640,7 +640,7 @@ def cart2polar(a, b):
 
 def inlining_contextmanager(a):
   b = a * a
-  with grad_of(a) as g:
+  with insert_grad_of(a) as g:
     g = g * 0.9
   c = b * a
   return c
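
The renamed test doubles as an end-to-end check: differentiating it exercises both the generated adjoint and the user-inserted `g = g * 0.9`. A hedged sketch (module path assumed; the exact value depends on where tangent splices the scaling into the adjoint, so it is not asserted here):

    import tangent
    from tests.functions import inlining_contextmanager  # assumed importable

    dfda = tangent.grad(inlining_contextmanager)
    print(dfda(2.0))  # gradient of c = (a*a)*a with the 0.9 scaling spliced in
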
