Skip to content

Commit

Permalink
[AutoDiff] Attribute gardening.
Browse files Browse the repository at this point in the history
Upstream changes from `tensorflow` branch:

- apple#28932:
  deprecate `@differentiable(jvp:vjp)` arguments.

- apple#29038: gardening.

Additional gardening included.
  • Loading branch information
dan-zheng committed Jan 7, 2020
1 parent 3fa5011 commit 683b813
Show file tree
Hide file tree
Showing 8 changed files with 275 additions and 214 deletions.
28 changes: 14 additions & 14 deletions include/swift/AST/Attr.h
Expand Up @@ -1650,7 +1650,7 @@ class DifferentiableAttr final

/// Whether this function is linear (optional).
bool Linear;
/// The number of parsed parameters specified in 'wrt:'.
/// The number of parsed differentiability parameters specified in 'wrt:'.
unsigned NumParsedParameters = 0;
/// The JVP function.
Optional<DeclNameRefWithLoc> JVP;
Expand All @@ -1662,7 +1662,7 @@ class DifferentiableAttr final
/// The VJP function (optional), resolved by the type checker if VJP name is
/// specified.
FuncDecl *VJPFunction = nullptr;
/// The differentiation parameters' indices, resolved by the type checker.
/// The differentiability parameter indices, resolved by the type checker.
IndexSubset *ParameterIndices = nullptr;
/// The trailing where clause (optional).
TrailingWhereClause *WhereClause = nullptr;
Expand Down Expand Up @@ -1720,7 +1720,7 @@ class DifferentiableAttr final
ParameterIndices = parameterIndices;
}

/// The parsed differentiation parameters, i.e. the list of parameters
/// The parsed differentiability parameters, i.e. the list of parameters
/// specified in 'wrt:'.
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
Expand Down Expand Up @@ -1771,15 +1771,15 @@ class DifferentiableAttr final
///
/// The `@derivative(of:)` attribute also has an optional `wrt:` clause
/// specifying the parameters that are differentiated "with respect to", i.e.
/// the differentiation parameters. The differentiation parameters must conform
/// to the `Differentiable` protocol.
/// the differentiability parameters. The differentiability parameters must
/// conform to the `Differentiable` protocol.
///
/// If the `wrt:` clause is unspecified, the differentiation parameters are
/// If the `wrt:` clause is unspecified, the differentiability parameters are
/// inferred to be all parameters that conform to `Differentiable`.
///
/// `@derivative(of:)` attribute type-checking verifies that the type of the
/// derivative function declaration is consistent with the type of the
/// referenced original declaration and the differentiation parameters.
/// referenced original declaration and the differentiability parameters.
///
/// Examples:
/// @derivative(of: sin(_:))
Expand All @@ -1798,9 +1798,9 @@ class DerivativeAttr final
DeclNameRefWithLoc OriginalFunctionName;
/// The original function declaration, resolved by the type checker.
AbstractFunctionDecl *OriginalFunction = nullptr;
/// The number of parsed parameters specified in 'wrt:'.
/// The number of parsed differentiability parameters specified in 'wrt:'.
unsigned NumParsedParameters = 0;
/// The differentiation parameters' indices, resolved by the type checker.
/// The differentiability parameter indices, resolved by the type checker.
IndexSubset *ParameterIndices = nullptr;
/// The derivative function kind (JVP or VJP), resolved by the type checker.
Optional<AutoDiffDerivativeFunctionKind> Kind = None;
Expand Down Expand Up @@ -1843,7 +1843,7 @@ class DerivativeAttr final
}
void setDerivativeKind(AutoDiffDerivativeFunctionKind kind) { Kind = kind; }

/// The parsed differentiation parameters, i.e. the list of parameters
/// The parsed differentiability parameters, i.e. the list of parameters
/// specified in 'wrt:'.
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
Expand Down Expand Up @@ -1872,7 +1872,7 @@ class DerivativeAttr final
/// computed property declaration.
///
/// The `@transpose(of:)` attribute also has a `wrt:` clause specifying the
/// parameters that are transposed "with respect to", i.e. the transposed
/// parameters that are transposed "with respect to", i.e. the linearity
/// parameters.
///
/// Examples:
Expand All @@ -1892,9 +1892,9 @@ class TransposeAttr final
DeclNameRefWithLoc OriginalFunctionName;
/// The original function declaration, resolved by the type checker.
AbstractFunctionDecl *OriginalFunction = nullptr;
/// The number of parsed parameters specified in 'wrt:'.
/// The number of parsed linearity parameters specified in 'wrt:'.
unsigned NumParsedParameters = 0;
/// The transposed parameters' indices, resolved by the type checker.
/// The linearity parameter indices, resolved by the type checker.
IndexSubset *ParameterIndices = nullptr;

explicit TransposeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
Expand Down Expand Up @@ -1927,7 +1927,7 @@ class TransposeAttr final
OriginalFunction = decl;
}

/// The parsed transposed parameters, i.e. the list of parameters specified in
/// The parsed linearity parameters, i.e. the list of parameters specified in
/// 'wrt:'.
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
Expand Down
13 changes: 8 additions & 5 deletions include/swift/AST/DiagnosticsParse.def
Expand Up @@ -1542,21 +1542,24 @@ ERROR(attr_implements_expected_member_name,PointsToFirstBadToken,
"expected a member name as second parameter in '_implements' attribute", ())

// differentiable
// TODO(TF-1001): Remove diagnostic when deprecated `jvp:`, `vjp:` are removed.
ERROR(attr_differentiable_expected_function_name,PointsToFirstBadToken,
"expected a %0 function name", (StringRef))
ERROR(attr_differentiable_expected_parameter_list,PointsToFirstBadToken,
"expected a list of parameters to differentiate with respect to", ())
// TODO(TF-1001): Remove diagnostic when deprecated `jvp:`, `vjp:` are removed.
ERROR(attr_differentiable_use_wrt_not_withrespectto,none,
"use 'wrt:' to specify parameters to differentiate with respect to", ())
ERROR(attr_differentiable_missing_label,PointsToFirstBadToken,
"missing label '%0:' in '@differentiable' attribute", (StringRef))
ERROR(attr_differentiable_expected_label,none,
"expected either 'wrt:' or a function specifier label, e.g. 'jvp:', "
"or 'vjp:'", ())
ERROR(differentiable_attribute_expected_rparen,none,
"expected ')' in '@differentiable' attribute", ())
ERROR(unexpected_argument_differentiable,none,
ERROR(attr_differentiable_unexpected_argument,none,
"unexpected argument '%0' in '@differentiable' attribute", (StringRef))
// TODO(TF-1001): Remove diagnostic when deprecated `jvp:`, `vjp:` are removed.
WARNING(attr_differentiable_jvp_vjp_deprecated_warning,none,
"'jvp:' and 'vjp:' arguments in '@differentiable' attribute are "
"deprecated; use '@derivative' attribute for derivative registration "
"instead", ())

// differentiation `wrt` parameters clause
ERROR(expected_colon_after_label,PointsToFirstBadToken,
Expand Down
11 changes: 6 additions & 5 deletions include/swift/Parse/Parser.h
Expand Up @@ -1002,12 +1002,13 @@ class Parser {
Optional<DeclNameRefWithLoc> &vjpSpec,
TrailingWhereClause *&whereClause);

/// Parse a differentiation parameters clause, i.e. the 'wrt:' clause in
/// `@differentiable` and `@derivative` attributes.
/// Parse a differentiability parameters clause, i.e. the 'wrt:' clause in
/// `@differentiable`, `@derivative`, and `@transpose` attributes.
///
/// If `allowNamedParameters` is false, allow only index parameters and
/// 'self'.
bool parseDifferentiationParametersClause(
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName,
/// 'self'. Used for `@transpose` attributes.
bool parseDifferentiabilityParametersClause(
SmallVectorImpl<ParsedAutoDiffParameter> &parameters, StringRef attrName,
bool allowNamedParameters = true);

/// Parse the @derivative attribute.
Expand Down
55 changes: 32 additions & 23 deletions lib/AST/Attr.cpp
Expand Up @@ -366,51 +366,60 @@ static void printShortFormAvailable(ArrayRef<const DeclAttribute *> Attrs,
Printer.printNewline();
}

/// Printing style for a differentiation parameter in a `wrt:` differentiation
/// parameters clause. Used for printing `@differentiable`, `@derivative`, and
/// `@transpose` attributes.
enum class DifferentiationParameterPrintingStyle {
/// Print parameter by name.
/// The kind of a parameter in a `wrt:` differentiation parameters clause:
/// either a differentiability parameter or a linearity parameter. Used for
/// printing `@differentiable`, `@derivative`, and `@transpose` attributes.
enum class DifferentiationParameterKind {
/// A differentiability parameter, printed by name.
/// Used for `@differentiable` and `@derivative` attribute.
Name,
/// Print parameter by index.
Differentiability,
/// A linearity parameter, printed by index.
/// Used for `@transpose` attribute.
Index
Linearity
};

/// Returns the differentiation parameters clause string for the given function,
/// parameter indices, parsed parameters, . Use the parameter indices if
/// specified; otherwise, use the parsed parameters.
/// parameter indices, parsed parameters, and differentiation parameter kind.
/// Use the parameter indices if specified; otherwise, use the parsed
/// parameters.
static std::string getDifferentiationParametersClauseString(
const AbstractFunctionDecl *function, IndexSubset *paramIndices,
const AbstractFunctionDecl *function, IndexSubset *parameterIndices,
ArrayRef<ParsedAutoDiffParameter> parsedParams,
DifferentiationParameterPrintingStyle style) {
DifferentiationParameterKind parameterKind) {
assert(function);
bool isInstanceMethod = function->isInstanceMember();
bool isStaticMethod = function->isStatic();
std::string result;
llvm::raw_string_ostream printer(result);

// Use the parameter indices, if specified.
if (paramIndices) {
auto parameters = paramIndices->getBitVector();
if (parameterIndices) {
auto parameters = parameterIndices->getBitVector();
auto parameterCount = parameters.count();
printer << "wrt: ";
if (parameterCount > 1)
printer << '(';
// Check if differentiating wrt `self`. If so, manually print it first.
if (isInstanceMethod && parameters.test(parameters.size() - 1)) {
bool isWrtSelf =
(isInstanceMethod ||
(isStaticMethod &&
parameterKind == DifferentiationParameterKind::Linearity)) &&
parameters.test(parameters.size() - 1);
if (isWrtSelf) {
parameters.reset(parameters.size() - 1);
printer << "self";
if (parameters.any())
printer << ", ";
}
// Print remaining differentiation parameters.
interleave(parameters.set_bits(), [&](unsigned index) {
switch (style) {
case DifferentiationParameterPrintingStyle::Name:
switch (parameterKind) {
// Print differentiability parameters by name.
case DifferentiationParameterKind::Differentiability:
printer << function->getParameters()->get(index)->getName().str();
break;
case DifferentiationParameterPrintingStyle::Index:
// Print linearity parameters by index.
case DifferentiationParameterKind::Linearity:
printer << index;
break;
}
Expand Down Expand Up @@ -487,7 +496,7 @@ static void printDifferentiableAttrArguments(
if (!omitWrtClause) {
auto diffParamsString = getDifferentiationParametersClauseString(
original, attr->getParameterIndices(), attr->getParsedParameters(),
DifferentiationParameterPrintingStyle::Name);
DifferentiationParameterKind::Differentiability);
// Check whether differentiation parameter clause is empty.
// Handles edge case where resolved parameter indices are unset and
// parsed parameters are empty. This case should never trigger for
Expand Down Expand Up @@ -927,7 +936,7 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
auto *derivative = cast<AbstractFunctionDecl>(D);
auto diffParamsString = getDifferentiationParametersClauseString(
derivative, attr->getParameterIndices(), attr->getParsedParameters(),
DifferentiationParameterPrintingStyle::Name);
DifferentiationParameterKind::Differentiability);
if (!diffParamsString.empty())
Printer << ", " << diffParamsString;
Printer << ')';
Expand All @@ -942,7 +951,7 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
auto *transpose = cast<AbstractFunctionDecl>(D);
auto transParamsString = getDifferentiationParametersClauseString(
transpose, attr->getParameterIndices(), attr->getParsedParameters(),
DifferentiationParameterPrintingStyle::Index);
DifferentiationParameterKind::Linearity);
if (!transParamsString.empty())
Printer << ", " << transParamsString;
Printer << ')';
Expand Down Expand Up @@ -1510,11 +1519,11 @@ GenericEnvironment *DifferentiableAttr::getDerivativeGenericEnvironment(

void DifferentiableAttr::print(llvm::raw_ostream &OS, const Decl *D,
bool omitWrtClause,
bool omitAssociatedFunctions) const {
bool omitDerivativeFunctions) const {
StreamPrinter P(OS);
P << "@" << getAttrName();
printDifferentiableAttrArguments(this, P, PrintOptions(), D, omitWrtClause,
omitAssociatedFunctions);
omitDerivativeFunctions);
}

DerivativeAttr::DerivativeAttr(bool implicit, SourceLoc atLoc,
Expand Down

0 comments on commit 683b813

Please sign in to comment.