From 683b8131d50b2b5926e468bcf0115d756a9ee834 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Tue, 7 Jan 2020 12:00:11 -0800 Subject: [PATCH] [AutoDiff] Attribute gardening. Upstream changes from `tensorflow` branch: - https://github.com/apple/swift/pull/28932: deprecate `@differentiable(jvp:vjp)` arguments. - https://github.com/apple/swift/pull/29038: gardening. Additional gardening included. --- include/swift/AST/Attr.h | 28 +-- include/swift/AST/DiagnosticsParse.def | 13 +- include/swift/Parse/Parser.h | 11 +- lib/AST/Attr.cpp | 55 ++--- lib/Parse/ParseDecl.cpp | 100 +++++---- lib/Sema/TypeCheckAttr.cpp | 192 +++++++++--------- lib/Sema/TypeChecker.h | 23 ++- .../Parse/differentiable_attr_parse.swift | 67 ++++-- 8 files changed, 275 insertions(+), 214 deletions(-) diff --git a/include/swift/AST/Attr.h b/include/swift/AST/Attr.h index e4d624c0bd42e..04f34ae1271f5 100644 --- a/include/swift/AST/Attr.h +++ b/include/swift/AST/Attr.h @@ -1650,7 +1650,7 @@ class DifferentiableAttr final /// Whether this function is linear (optional). bool Linear; - /// The number of parsed parameters specified in 'wrt:'. + /// The number of parsed differentiability parameters specified in 'wrt:'. unsigned NumParsedParameters = 0; /// The JVP function. Optional JVP; @@ -1662,7 +1662,7 @@ class DifferentiableAttr final /// The VJP function (optional), resolved by the type checker if VJP name is /// specified. FuncDecl *VJPFunction = nullptr; - /// The differentiation parameters' indices, resolved by the type checker. + /// The differentiability parameter indices, resolved by the type checker. IndexSubset *ParameterIndices = nullptr; /// The trailing where clause (optional). TrailingWhereClause *WhereClause = nullptr; @@ -1720,7 +1720,7 @@ class DifferentiableAttr final ParameterIndices = parameterIndices; } - /// The parsed differentiation parameters, i.e. the list of parameters + /// The parsed differentiability parameters, i.e. the list of parameters /// specified in 'wrt:'. ArrayRef getParsedParameters() const { return {getTrailingObjects(), NumParsedParameters}; @@ -1771,15 +1771,15 @@ class DifferentiableAttr final /// /// The `@derivative(of:)` attribute also has an optional `wrt:` clause /// specifying the parameters that are differentiated "with respect to", i.e. -/// the differentiation parameters. The differentiation parameters must conform -/// to the `Differentiable` protocol. +/// the differentiability parameters. The differentiability parameters must +/// conform to the `Differentiable` protocol. /// -/// If the `wrt:` clause is unspecified, the differentiation parameters are +/// If the `wrt:` clause is unspecified, the differentiability parameters are /// inferred to be all parameters that conform to `Differentiable`. /// /// `@derivative(of:)` attribute type-checking verifies that the type of the /// derivative function declaration is consistent with the type of the -/// referenced original declaration and the differentiation parameters. +/// referenced original declaration and the differentiability parameters. /// /// Examples: /// @derivative(of: sin(_:)) @@ -1798,9 +1798,9 @@ class DerivativeAttr final DeclNameRefWithLoc OriginalFunctionName; /// The original function declaration, resolved by the type checker. AbstractFunctionDecl *OriginalFunction = nullptr; - /// The number of parsed parameters specified in 'wrt:'. + /// The number of parsed differentiability parameters specified in 'wrt:'. unsigned NumParsedParameters = 0; - /// The differentiation parameters' indices, resolved by the type checker. + /// The differentiability parameter indices, resolved by the type checker. IndexSubset *ParameterIndices = nullptr; /// The derivative function kind (JVP or VJP), resolved by the type checker. Optional Kind = None; @@ -1843,7 +1843,7 @@ class DerivativeAttr final } void setDerivativeKind(AutoDiffDerivativeFunctionKind kind) { Kind = kind; } - /// The parsed differentiation parameters, i.e. the list of parameters + /// The parsed differentiability parameters, i.e. the list of parameters /// specified in 'wrt:'. ArrayRef getParsedParameters() const { return {getTrailingObjects(), NumParsedParameters}; @@ -1872,7 +1872,7 @@ class DerivativeAttr final /// computed property declaration. /// /// The `@transpose(of:)` attribute also has a `wrt:` clause specifying the -/// parameters that are transposed "with respect to", i.e. the transposed +/// parameters that are transposed "with respect to", i.e. the linearity /// parameters. /// /// Examples: @@ -1892,9 +1892,9 @@ class TransposeAttr final DeclNameRefWithLoc OriginalFunctionName; /// The original function declaration, resolved by the type checker. AbstractFunctionDecl *OriginalFunction = nullptr; - /// The number of parsed parameters specified in 'wrt:'. + /// The number of parsed linearity parameters specified in 'wrt:'. unsigned NumParsedParameters = 0; - /// The transposed parameters' indices, resolved by the type checker. + /// The linearity parameter indices, resolved by the type checker. IndexSubset *ParameterIndices = nullptr; explicit TransposeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange, @@ -1927,7 +1927,7 @@ class TransposeAttr final OriginalFunction = decl; } - /// The parsed transposed parameters, i.e. the list of parameters specified in + /// The parsed linearity parameters, i.e. the list of parameters specified in /// 'wrt:'. ArrayRef getParsedParameters() const { return {getTrailingObjects(), NumParsedParameters}; diff --git a/include/swift/AST/DiagnosticsParse.def b/include/swift/AST/DiagnosticsParse.def index e4888fe373852..fabdc9682b4f0 100644 --- a/include/swift/AST/DiagnosticsParse.def +++ b/include/swift/AST/DiagnosticsParse.def @@ -1542,21 +1542,24 @@ ERROR(attr_implements_expected_member_name,PointsToFirstBadToken, "expected a member name as second parameter in '_implements' attribute", ()) // differentiable +// TODO(TF-1001): Remove diagnostic when deprecated `jvp:`, `vjp:` are removed. ERROR(attr_differentiable_expected_function_name,PointsToFirstBadToken, "expected a %0 function name", (StringRef)) ERROR(attr_differentiable_expected_parameter_list,PointsToFirstBadToken, "expected a list of parameters to differentiate with respect to", ()) +// TODO(TF-1001): Remove diagnostic when deprecated `jvp:`, `vjp:` are removed. ERROR(attr_differentiable_use_wrt_not_withrespectto,none, "use 'wrt:' to specify parameters to differentiate with respect to", ()) -ERROR(attr_differentiable_missing_label,PointsToFirstBadToken, - "missing label '%0:' in '@differentiable' attribute", (StringRef)) ERROR(attr_differentiable_expected_label,none, "expected either 'wrt:' or a function specifier label, e.g. 'jvp:', " "or 'vjp:'", ()) -ERROR(differentiable_attribute_expected_rparen,none, - "expected ')' in '@differentiable' attribute", ()) -ERROR(unexpected_argument_differentiable,none, +ERROR(attr_differentiable_unexpected_argument,none, "unexpected argument '%0' in '@differentiable' attribute", (StringRef)) +// TODO(TF-1001): Remove diagnostic when deprecated `jvp:`, `vjp:` are removed. +WARNING(attr_differentiable_jvp_vjp_deprecated_warning,none, + "'jvp:' and 'vjp:' arguments in '@differentiable' attribute are " + "deprecated; use '@derivative' attribute for derivative registration " + "instead", ()) // differentiation `wrt` parameters clause ERROR(expected_colon_after_label,PointsToFirstBadToken, diff --git a/include/swift/Parse/Parser.h b/include/swift/Parse/Parser.h index 21f3014298573..93dfc1cfe3f6c 100644 --- a/include/swift/Parse/Parser.h +++ b/include/swift/Parse/Parser.h @@ -1002,12 +1002,13 @@ class Parser { Optional &vjpSpec, TrailingWhereClause *&whereClause); - /// Parse a differentiation parameters clause, i.e. the 'wrt:' clause in - /// `@differentiable` and `@derivative` attributes. + /// Parse a differentiability parameters clause, i.e. the 'wrt:' clause in + /// `@differentiable`, `@derivative`, and `@transpose` attributes. + /// /// If `allowNamedParameters` is false, allow only index parameters and - /// 'self'. - bool parseDifferentiationParametersClause( - SmallVectorImpl ¶ms, StringRef attrName, + /// 'self'. Used for `@transpose` attributes. + bool parseDifferentiabilityParametersClause( + SmallVectorImpl ¶meters, StringRef attrName, bool allowNamedParameters = true); /// Parse the @derivative attribute. diff --git a/lib/AST/Attr.cpp b/lib/AST/Attr.cpp index bb6d7c47d8079..46d80f1c8c8d2 100644 --- a/lib/AST/Attr.cpp +++ b/lib/AST/Attr.cpp @@ -366,39 +366,46 @@ static void printShortFormAvailable(ArrayRef Attrs, Printer.printNewline(); } -/// Printing style for a differentiation parameter in a `wrt:` differentiation -/// parameters clause. Used for printing `@differentiable`, `@derivative`, and -/// `@transpose` attributes. -enum class DifferentiationParameterPrintingStyle { - /// Print parameter by name. +/// The kind of a parameter in a `wrt:` differentiation parameters clause: +/// either a differentiability parameter or a linearity parameter. Used for +/// printing `@differentiable`, `@derivative`, and `@transpose` attributes. +enum class DifferentiationParameterKind { + /// A differentiability parameter, printed by name. /// Used for `@differentiable` and `@derivative` attribute. - Name, - /// Print parameter by index. + Differentiability, + /// A linearity parameter, printed by index. /// Used for `@transpose` attribute. - Index + Linearity }; /// Returns the differentiation parameters clause string for the given function, -/// parameter indices, parsed parameters, . Use the parameter indices if -/// specified; otherwise, use the parsed parameters. +/// parameter indices, parsed parameters, and differentiation parameter kind. +/// Use the parameter indices if specified; otherwise, use the parsed +/// parameters. static std::string getDifferentiationParametersClauseString( - const AbstractFunctionDecl *function, IndexSubset *paramIndices, + const AbstractFunctionDecl *function, IndexSubset *parameterIndices, ArrayRef parsedParams, - DifferentiationParameterPrintingStyle style) { + DifferentiationParameterKind parameterKind) { assert(function); bool isInstanceMethod = function->isInstanceMember(); + bool isStaticMethod = function->isStatic(); std::string result; llvm::raw_string_ostream printer(result); // Use the parameter indices, if specified. - if (paramIndices) { - auto parameters = paramIndices->getBitVector(); + if (parameterIndices) { + auto parameters = parameterIndices->getBitVector(); auto parameterCount = parameters.count(); printer << "wrt: "; if (parameterCount > 1) printer << '('; // Check if differentiating wrt `self`. If so, manually print it first. - if (isInstanceMethod && parameters.test(parameters.size() - 1)) { + bool isWrtSelf = + (isInstanceMethod || + (isStaticMethod && + parameterKind == DifferentiationParameterKind::Linearity)) && + parameters.test(parameters.size() - 1); + if (isWrtSelf) { parameters.reset(parameters.size() - 1); printer << "self"; if (parameters.any()) @@ -406,11 +413,13 @@ static std::string getDifferentiationParametersClauseString( } // Print remaining differentiation parameters. interleave(parameters.set_bits(), [&](unsigned index) { - switch (style) { - case DifferentiationParameterPrintingStyle::Name: + switch (parameterKind) { + // Print differentiability parameters by name. + case DifferentiationParameterKind::Differentiability: printer << function->getParameters()->get(index)->getName().str(); break; - case DifferentiationParameterPrintingStyle::Index: + // Print linearity parameters by index. + case DifferentiationParameterKind::Linearity: printer << index; break; } @@ -487,7 +496,7 @@ static void printDifferentiableAttrArguments( if (!omitWrtClause) { auto diffParamsString = getDifferentiationParametersClauseString( original, attr->getParameterIndices(), attr->getParsedParameters(), - DifferentiationParameterPrintingStyle::Name); + DifferentiationParameterKind::Differentiability); // Check whether differentiation parameter clause is empty. // Handles edge case where resolved parameter indices are unset and // parsed parameters are empty. This case should never trigger for @@ -927,7 +936,7 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options, auto *derivative = cast(D); auto diffParamsString = getDifferentiationParametersClauseString( derivative, attr->getParameterIndices(), attr->getParsedParameters(), - DifferentiationParameterPrintingStyle::Name); + DifferentiationParameterKind::Differentiability); if (!diffParamsString.empty()) Printer << ", " << diffParamsString; Printer << ')'; @@ -942,7 +951,7 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options, auto *transpose = cast(D); auto transParamsString = getDifferentiationParametersClauseString( transpose, attr->getParameterIndices(), attr->getParsedParameters(), - DifferentiationParameterPrintingStyle::Index); + DifferentiationParameterKind::Linearity); if (!transParamsString.empty()) Printer << ", " << transParamsString; Printer << ')'; @@ -1510,11 +1519,11 @@ GenericEnvironment *DifferentiableAttr::getDerivativeGenericEnvironment( void DifferentiableAttr::print(llvm::raw_ostream &OS, const Decl *D, bool omitWrtClause, - bool omitAssociatedFunctions) const { + bool omitDerivativeFunctions) const { StreamPrinter P(OS); P << "@" << getAttrName(); printDifferentiableAttrArguments(this, P, PrintOptions(), D, omitWrtClause, - omitAssociatedFunctions); + omitDerivativeFunctions); } DerivativeAttr::DerivativeAttr(bool implicit, SourceLoc atLoc, diff --git a/lib/Parse/ParseDecl.cpp b/lib/Parse/ParseDecl.cpp index 5c85c9c749551..599c7c6bab428 100644 --- a/lib/Parse/ParseDecl.cpp +++ b/lib/Parse/ParseDecl.cpp @@ -792,7 +792,7 @@ Parser::parseImplementsAttribute(SourceLoc AtLoc, SourceLoc Loc) { /// /// \verbatim /// differentiable-attribute-arguments: -/// '(' (differentiation-params-clause ',')? +/// '(' (differentiability-params-clause ',')? /// (differentiable-attr-func-specifier ',')? /// differentiable-attr-func-specifier? /// where-clause? @@ -805,7 +805,7 @@ Parser::parseDifferentiableAttribute(SourceLoc atLoc, SourceLoc loc) { StringRef AttrName = "differentiable"; SourceLoc lParenLoc = loc, rParenLoc = loc; bool linear = false; - SmallVector params; + SmallVector parameters; Optional jvpSpec; Optional vjpSpec; TrailingWhereClause *whereClause = nullptr; @@ -813,8 +813,8 @@ Parser::parseDifferentiableAttribute(SourceLoc atLoc, SourceLoc loc) { // Parse '('. if (consumeIf(tok::l_paren, lParenLoc)) { // Parse @differentiable attribute arguments. - if (parseDifferentiableAttributeArguments(linear, params, jvpSpec, vjpSpec, - whereClause)) + if (parseDifferentiableAttributeArguments(linear, parameters, jvpSpec, + vjpSpec, whereClause)) return makeParserError(); // Parse ')'. if (!consumeIf(tok::r_paren, rParenLoc)) { @@ -824,10 +824,9 @@ Parser::parseDifferentiableAttribute(SourceLoc atLoc, SourceLoc loc) { } } - return ParserResult( - DifferentiableAttr::create(Context, /*implicit*/ false, atLoc, - SourceRange(loc, rParenLoc), linear, - params, jvpSpec, vjpSpec, whereClause)); + return ParserResult(DifferentiableAttr::create( + Context, /*implicit*/ false, atLoc, SourceRange(loc, rParenLoc), linear, + parameters, jvpSpec, vjpSpec, whereClause)); } // Attribute parsing error helper. @@ -847,29 +846,29 @@ static bool errorAndSkipUntilConsumeRightParen(Parser &P, StringRef attrName, return true; }; -/// Parse a differentiation parameters 'wrt:' clause, returning true on error. +/// Parse a differentiability parameters 'wrt:' clause, returning true on error. /// If `allowNamedParameters` is false, allow only index parameters and 'self'. /// /// \verbatim -/// differentiation-params-clause: -/// 'wrt' ':' (differentiation-param | differentiation-params) -/// differentiation-params: -/// '(' differentiation-param (',' differentiation-param)* ')' -/// differentiation-param: +/// differentiability-params-clause: +/// 'wrt' ':' (differentiability-param | differentiability-params) +/// differentiability-params: +/// '(' differentiability-param (',' differentiability-param)* ')' +/// differentiability-param: /// 'self' | identifier | [0-9]+ /// \endverbatim -bool Parser::parseDifferentiationParametersClause( - SmallVectorImpl ¶ms, StringRef attrName, +bool Parser::parseDifferentiabilityParametersClause( + SmallVectorImpl ¶meters, StringRef attrName, bool allowNamedParameters) { SyntaxParsingContext DiffParamsClauseContext( - SyntaxContext, SyntaxKind::DifferentiationParamsClause); + SyntaxContext, SyntaxKind::DifferentiationParamsClause); consumeToken(tok::identifier); if (!consumeIf(tok::colon)) { diagnose(Tok, diag::expected_colon_after_label, "wrt"); return errorAndSkipUntilConsumeRightParen(*this, attrName); } - // Function that parses a parameter into `params`. Returns true if error + // Function that parses a parameter into `parameters`. Returns true if error // occurred. auto parseParam = [&](bool parseTrailingComma = true) -> bool { SyntaxParsingContext DiffParamContext( @@ -886,8 +885,8 @@ bool Parser::parseDifferentiationParametersClause( if (parseIdentifier(paramName, paramLoc, diag::diff_params_clause_expected_parameter)) return true; - params.push_back(ParsedAutoDiffParameter::getNamedParameter( - paramLoc, paramName)); + parameters.push_back( + ParsedAutoDiffParameter::getNamedParameter(paramLoc, paramName)); break; } case tok::integer_literal: { @@ -896,13 +895,13 @@ bool Parser::parseDifferentiationParametersClause( paramNum, paramLoc, diag::diff_params_clause_expected_parameter)) return true; - params.push_back(ParsedAutoDiffParameter::getOrderedParameter( - paramLoc, paramNum)); + parameters.push_back( + ParsedAutoDiffParameter::getOrderedParameter(paramLoc, paramNum)); break; } case tok::kw_self: { paramLoc = consumeToken(tok::kw_self); - params.push_back(ParsedAutoDiffParameter::getSelfParameter(paramLoc)); + parameters.push_back(ParsedAutoDiffParameter::getSelfParameter(paramLoc)); break; } default: @@ -940,10 +939,9 @@ bool Parser::parseDifferentiationParametersClause( } bool Parser::parseDifferentiableAttributeArguments( - bool &linear, SmallVectorImpl ¶ms, + bool &linear, SmallVectorImpl ¶meters, Optional &jvpSpec, - Optional &vjpSpec, - TrailingWhereClause *&whereClause) { + Optional &vjpSpec, TrailingWhereClause *&whereClause) { StringRef AttrName = "differentiable"; // Parse trailing comma, if it exists, and check for errors. @@ -968,7 +966,7 @@ bool Parser::parseDifferentiableAttributeArguments( SyntaxParsingContext ContentContext( SyntaxContext, SyntaxKind::DifferentiableAttributeArguments); - // Parse optional differentiation parameters. + // Parse optional differentiability parameters. // Parse 'linear' label (optional). linear = false; if (isIdentifier(Tok, "linear")) { @@ -989,9 +987,9 @@ bool Parser::parseDifferentiableAttributeArguments( .fixItReplace(withRespectToRange, "wrt:"); return errorAndSkipUntilConsumeRightParen(*this, AttrName); } - // Parse differentiation parameters' clause. + // Parse the optional 'wrt' differentiability parameters clause. if (isIdentifier(Tok, "wrt")) { - if (parseDifferentiationParametersClause(params, AttrName)) + if (parseDifferentiabilityParametersClause(parameters, AttrName)) return true; // If no trailing comma or 'where' clause, terminate parsing arguments. if (Tok.isNot(tok::comma, tok::kw_where)) @@ -1005,8 +1003,8 @@ bool Parser::parseDifferentiableAttributeArguments( auto parseFuncSpec = [&](StringRef label, DeclNameRefWithLoc &result, bool &terminateParsingArgs) -> bool { // Parse label. - if (parseSpecificIdentifier(label, - diag::attr_differentiable_missing_label, label) || + if (parseSpecificIdentifier(label, diag::attr_missing_label, label, + AttrName) || parseToken(tok::colon, diag::expected_colon_after_label, label)) return true; // Parse the name of the function. @@ -1016,6 +1014,13 @@ bool Parser::parseDifferentiableAttributeArguments( { label }); result.Name = parseDeclNameRef(result.Loc, funcDiag, DeclNameFlag::AllowZeroArgCompoundNames | DeclNameFlag::AllowOperators); + // Emit warning for deprecated `jvp:` and `vjp:` arguments. + // TODO(TF-1001): Remove deprecated `jvp:` and `vjp:` arguments. + if (result.Loc.isValid()) { + diagnose(result.Loc.getStartLoc(), + diag::attr_differentiable_jvp_vjp_deprecated_warning) + .highlight(result.Loc.getSourceRange()); + } // If no trailing comma or 'where' clause, terminate parsing arguments. if (Tok.isNot(tok::comma, tok::kw_where)) terminateParsingArgs = true; @@ -1133,7 +1138,7 @@ static bool parseQualifiedDeclName(Parser &P, Diag<> nameParseError, /// /// \verbatim /// derivative-attribute-arguments: -/// '(' 'of' ':' qualified-decl-name (',' differentiation-params-clause)? +/// '(' 'of' ':' qualified-decl-name (',' differentiability-params-clause)? /// ')' /// \endverbatim ParserResult Parser::parseDerivativeAttribute(SourceLoc atLoc, @@ -1142,7 +1147,7 @@ ParserResult Parser::parseDerivativeAttribute(SourceLoc atLoc, SourceLoc lParenLoc = loc, rParenLoc = loc; TypeRepr *baseType = nullptr; DeclNameRefWithLoc original; - SmallVector params; + SmallVector parameters; // Parse trailing comma, if it exists, and check for errors. auto consumeIfTrailingComma = [&]() -> bool { @@ -1182,9 +1187,9 @@ ParserResult Parser::parseDerivativeAttribute(SourceLoc atLoc, } if (consumeIfTrailingComma()) return makeParserError(); - // Parse the optional 'wrt' differentiation parameters clause. + // Parse the optional 'wrt' differentiability parameters clause. if (isIdentifier(Tok, "wrt") && - parseDifferentiationParametersClause(params, AttrName)) + parseDifferentiabilityParametersClause(parameters, AttrName)) return makeParserError(); } // Parse ')'. @@ -1195,14 +1200,20 @@ ParserResult Parser::parseDerivativeAttribute(SourceLoc atLoc, } return ParserResult(DerivativeAttr::create( Context, /*implicit*/ false, atLoc, SourceRange(loc, rParenLoc), baseType, - original, params)); + original, parameters)); } /// Parse a `@transpose(of:)` attribute, returning true on error. /// /// \verbatim /// transpose-attribute-arguments: -/// '(' 'of' ':' qualified-decl-name (',' transposed-params-clause)? ')' +/// '(' 'of' ':' qualified-decl-name (',' linearity-params-clause)? ')' +/// linearity-params-clause: +/// 'wrt' ':' (linearity-param | linearity-params) +/// linearity-params: +/// '(' linearity-param (',' linearity-param)* ')' +/// linearity-param: +/// 'self' | [0-9]+ /// \endverbatim ParserResult Parser::parseTransposeAttribute(SourceLoc atLoc, SourceLoc loc) { @@ -1210,7 +1221,7 @@ ParserResult Parser::parseTransposeAttribute(SourceLoc atLoc, SourceLoc lParenLoc = loc, rParenLoc = loc; TypeRepr *baseType = nullptr; DeclNameRefWithLoc original; - SmallVector params; + SmallVector parameters; // Parse trailing comma, if it exists, and check for errors. auto consumeIfTrailingComma = [&]() -> bool { @@ -1251,10 +1262,10 @@ ParserResult Parser::parseTransposeAttribute(SourceLoc atLoc, } if (consumeIfTrailingComma()) return makeParserError(); - // Parse the optional 'wrt' transposed parameters clause. + // Parse the optional 'wrt' linearity parameters clause. if (Tok.is(tok::identifier) && Tok.getText() == "wrt" && - parseDifferentiationParametersClause(params, AttrName, - /*allowNamedParameters*/ false)) + parseDifferentiabilityParametersClause(parameters, AttrName, + /*allowNamedParameters*/ false)) return makeParserError(); } // Parse ')'. @@ -1265,7 +1276,7 @@ ParserResult Parser::parseTransposeAttribute(SourceLoc atLoc, } return ParserResult(TransposeAttr::create( Context, /*implicit*/ false, atLoc, SourceRange(loc, rParenLoc), baseType, - original, params)); + original, parameters)); } void Parser::parseObjCSelector(SmallVector &Names, @@ -2612,7 +2623,8 @@ static bool parseDifferentiableAttributeArgument(Parser &P, if (P.Tok.is(tok::l_paren)) { backtrack.cancelBacktrack(); if (emitDiagnostics) - P.diagnose(P.Tok, diag::differentiable_attribute_expected_rparen); + P.diagnose(P.Tok, diag::attr_expected_rparen, "@differentiable", + /*DeclModifier*/ false); return true; } return false; @@ -2628,7 +2640,7 @@ static bool parseDifferentiableAttributeArgument(Parser &P, if (argument.getText() != "linear") { if (emitDiagnostics) - P.diagnose(argument, diag::unexpected_argument_differentiable, + P.diagnose(argument, diag::attr_differentiable_unexpected_argument, argument.getText()); return true; } diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 03756c2a8bef1..9a2b2db26aa25 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -272,7 +272,7 @@ void AttributeChecker::visitTransparentAttr(TransparentAttr *attr) { if (!(isa(D) && D->isImplicit())) diagnoseAndRemoveAttr(attr, diag::transparent_in_classes_not_supported); } - + if (auto *VD = dyn_cast(D)) { // Stored properties and variables can't be transparent. if (VD->hasStorage()) @@ -339,7 +339,7 @@ void AttributeChecker::visitMutationAttr(DeclAttribute *attr) { } } } - + // Verify that we don't have a static function. if (FD->isStatic()) diagnoseAndRemoveAttr(attr, diag::static_functions_not_mutating); @@ -560,7 +560,7 @@ isAcceptableOutletType(Type type, bool &isArray, ASTContext &ctx) { if (type->isExistentialType()) return diag::iboutlet_nonobjc_protocol; - + // No other types are permitted. return diag::iboutlet_nonobject_type; } @@ -1430,7 +1430,7 @@ void AttributeChecker::visitSwiftNativeObjCRuntimeBaseAttr( attr->setInvalid(); return; } - + if (theClass->hasSuperclass()) { diagnose(attr->getLocation(), diag::swift_native_objc_runtime_base_not_on_root_class); @@ -1582,7 +1582,7 @@ void AttributeChecker::visitNSCopyingAttr(NSCopyingAttr *attr) { // Check the type. It must be an [unchecked]optional, weak, a normal // class, AnyObject, or classbound protocol. // It must conform to the NSCopying protocol. - + } void AttributeChecker::checkApplicationMainAttribute(DeclAttribute *attr, @@ -1594,7 +1594,7 @@ void AttributeChecker::checkApplicationMainAttribute(DeclAttribute *attr, UIApplicationMainClass, NSApplicationMainClass, }; - + unsigned applicationMainKind; if (isa(attr)) applicationMainKind = UIApplicationMainClass; @@ -1602,9 +1602,9 @@ void AttributeChecker::checkApplicationMainAttribute(DeclAttribute *attr, applicationMainKind = NSApplicationMainClass; else llvm_unreachable("not an ApplicationMain attr"); - + auto *CD = dyn_cast(D); - + // The applicant not being a class should have been diagnosed by the early // checker. if (!CD) return; @@ -1617,7 +1617,7 @@ void AttributeChecker::checkApplicationMainAttribute(DeclAttribute *attr, attr->setInvalid(); return; } - + // @XXApplicationMain classes must conform to the XXApplicationDelegate // protocol. auto *SF = cast(CD->getModuleScopeContext()); @@ -1646,7 +1646,7 @@ void AttributeChecker::checkApplicationMainAttribute(DeclAttribute *attr, if (attr->isInvalid()) return; - + // Register the class as the main class in the module. If there are multiples // they will be diagnosed. if (SF->registerMainClass(CD, attr->getLocation())) @@ -1969,7 +1969,7 @@ void AttributeChecker::visitFixedLayoutAttr(FixedLayoutAttr *attr) { if (isa(D)) { diagnose(attr->getLocation(), diag::fixed_layout_struct) .fixItReplace(attr->getRange(), "@frozen"); - } + } auto *VD = cast(D); @@ -2469,10 +2469,10 @@ void AttributeChecker::visitCustomAttr(CustomAttr *attr) { attr->setInvalid(); return; } - + return; } - + // If the nominal type is a function builder type, verify that D is a // function, storage with an explicit getter, or parameter of function type. if (nominal->getAttrs().hasAttribute()) { @@ -2942,7 +2942,7 @@ static bool conformsToDifferentiable(Type type, DeclContext *DC) { return !tanType.isNull() && !tanType->hasError(); }; -IndexSubset *TypeChecker::inferDifferentiationParameters( +IndexSubset *TypeChecker::inferDifferentiabilityParameters( AbstractFunctionDecl *AFD, GenericEnvironment *derivativeGenEnv) { auto &ctx = AFD->getASTContext(); auto *functionType = AFD->getInterfaceType()->castTo(); @@ -2981,7 +2981,7 @@ IndexSubset *TypeChecker::inferDifferentiationParameters( for (auto ¶m : functionType->getParams()) allParamTypes.push_back(param.getPlainType()); - // Set differentiation parameters. + // Set differentiability parameters. for (unsigned i : range(parameterBits.size())) if (isDifferentiableParam(i)) parameterBits.set(i); @@ -2989,14 +2989,15 @@ IndexSubset *TypeChecker::inferDifferentiationParameters( return IndexSubset::get(ctx, parameterBits); } -// Computes `IndexSubset` from the given parsed differentiation parameters -// (possibly empty) for the given function and derivative generic environment, -// then verifies that the parameter indices are valid. +// Computes the differentiability parameter indices from the given parsed +// differentiability parameters for the given original or derivative +// `AbstractFunctionDecl` and derivative generic environment. On error, emits +// diagnostics and returns `nullptr`. // - If parsed parameters are empty, infer parameter indices. // - Otherwise, build parameter indices from parsed parameters. // The attribute name/location are used in diagnostics. -static IndexSubset *computeDifferentiationParameters( - ArrayRef parsedWrtParams, +static IndexSubset *computeDifferentiabilityParameters( + ArrayRef parsedDiffParams, AbstractFunctionDecl *function, GenericEnvironment *derivativeGenEnv, StringRef attrName, SourceLoc attrLoc) { auto &ctx = function->getASTContext(); @@ -3036,13 +3037,14 @@ static IndexSubset *computeDifferentiationParameters( } } - // If parsed differentiation parameters are empty, infer parameter indices + // If parsed differentiability parameters are empty, infer parameter indices // from the function type. - if (parsedWrtParams.empty()) - return TypeChecker::inferDifferentiationParameters(function, - derivativeGenEnv); + if (parsedDiffParams.empty()) + return TypeChecker::inferDifferentiabilityParameters(function, + derivativeGenEnv); - // Otherwise, build parameter indices from parsed differentiation parameters. + // Otherwise, build parameter indices from parsed differentiability + // parameters. auto numUncurriedParams = functionType->getNumParams(); if (auto *resultFnType = functionType->getResult()->getAs()) { @@ -3050,17 +3052,17 @@ static IndexSubset *computeDifferentiationParameters( } llvm::SmallBitVector parameterBits(numUncurriedParams); int lastIndex = -1; - for (unsigned i : indices(parsedWrtParams)) { - auto paramLoc = parsedWrtParams[i].getLoc(); - switch (parsedWrtParams[i].getKind()) { + for (unsigned i : indices(parsedDiffParams)) { + auto paramLoc = parsedDiffParams[i].getLoc(); + switch (parsedDiffParams[i].getKind()) { case ParsedAutoDiffParameter::Kind::Named: { auto nameIter = llvm::find_if(params.getArray(), [&](ParamDecl *param) { - return param->getName() == parsedWrtParams[i].getName(); + return param->getName() == parsedDiffParams[i].getName(); }); // Parameter name must exist. if (nameIter == params.end()) { diags.diagnose(paramLoc, diag::diff_params_clause_param_name_unknown, - parsedWrtParams[i].getName()); + parsedDiffParams[i].getName()); return nullptr; } // Parameter names must be specified in the original order. @@ -3090,7 +3092,7 @@ static IndexSubset *computeDifferentiationParameters( break; } case ParsedAutoDiffParameter::Kind::Ordered: { - auto index = parsedWrtParams[i].getIndex(); + auto index = parsedDiffParams[i].getIndex(); if (index >= numParams) { diags.diagnose(paramLoc, diag::diff_params_clause_param_index_out_of_range); @@ -3111,50 +3113,53 @@ static IndexSubset *computeDifferentiationParameters( return IndexSubset::get(ctx, parameterBits); } -// Checks if the given `IndexSubset` instance is valid for the given function +// Checks if the given differentiability parameter indices are valid for the +// given original or derivative `AbstractFunctionDecl` and original function // type in the given derivative generic environment and module context. Returns // true on error. -// The parsed differentiation parameters and attribute location are used in +// +// The parsed differentiability parameters and attribute location are used in // diagnostics. -static bool checkDifferentiationParameters( - AbstractFunctionDecl *AFD, IndexSubset *indices, +static bool checkDifferentiabilityParameters( + AbstractFunctionDecl *AFD, IndexSubset *diffParamIndices, AnyFunctionType *functionType, GenericEnvironment *derivativeGenEnv, - ModuleDecl *module, ArrayRef parsedWrtParams, + ModuleDecl *module, ArrayRef parsedDiffParams, SourceLoc attrLoc) { auto &ctx = AFD->getASTContext(); auto &diags = ctx.Diags; - // Diagnose empty parameter indices. This occurs when no `wrt:` clause is - // declared and no differentiation parameters can be inferred. - if (indices->isEmpty()) { + // Diagnose empty differentiability indices. No differentiability parameters + // were resolved or inferred. + if (diffParamIndices->isEmpty()) { diags.diagnose(attrLoc, diag::diff_params_clause_no_inferred_parameters); return true; } - // Check that differentiation parameters have allowed types. - SmallVector wrtParamTypes; - autodiff::getSubsetParameterTypes(indices, functionType, wrtParamTypes); - for (unsigned i : range(wrtParamTypes.size())) { + // Check that differentiability parameters have allowed types. + SmallVector diffParamTypes; + autodiff::getSubsetParameterTypes(diffParamIndices, functionType, + diffParamTypes); + for (unsigned i : range(diffParamTypes.size())) { SourceLoc loc = - parsedWrtParams.empty() ? attrLoc : parsedWrtParams[i].getLoc(); - auto wrtParamType = wrtParamTypes[i]; + parsedDiffParams.empty() ? attrLoc : parsedDiffParams[i].getLoc(); + auto diffParamType = diffParamTypes[i]; // `inout` parameters are not yet supported. - if (wrtParamType->is()) { + if (diffParamType->is()) { diags.diagnose(loc, diag::diff_params_clause_cannot_diff_wrt_inout_parameter, - wrtParamType); + diffParamType); return true; } - if (!wrtParamType->hasTypeParameter()) - wrtParamType = wrtParamType->mapTypeOutOfContext(); + if (!diffParamType->hasTypeParameter()) + diffParamType = diffParamType->mapTypeOutOfContext(); if (derivativeGenEnv) - wrtParamType = derivativeGenEnv->mapTypeIntoContext(wrtParamType); + diffParamType = derivativeGenEnv->mapTypeIntoContext(diffParamType); else - wrtParamType = AFD->mapTypeIntoContext(wrtParamType); + diffParamType = AFD->mapTypeIntoContext(diffParamType); // Parameter must conform to `Differentiable`. - if (!conformsToDifferentiable(wrtParamType, AFD)) { + if (!conformsToDifferentiable(diffParamType, AFD)) { diags.diagnose(loc, diag::diff_params_clause_param_not_differentiable, - wrtParamType); + diffParamType); return true; } } @@ -3330,26 +3335,20 @@ static bool checkFunctionSignature( return checkFunctionSignature(requiredResultFnTy, candidateResultTy); }; -// Returns an `AnyFunctionType` with the same `ExtInfo` as `fnType`, but with -// the given parameters, result type, and generic signature. If the given -// generic signature is `null`, use the generic signature of `fnType`. +// Returns an `AnyFunctionType` from the given parameters, result type, and +// generic signature. static AnyFunctionType * -makeFunctionType(AnyFunctionType *fnType, - ArrayRef parameters, Type resultType, +makeFunctionType(ArrayRef parameters, Type resultType, GenericSignature genericSignature) { - if (!genericSignature) - if (auto *genericFunctionType = fnType->getAs()) - genericSignature = genericFunctionType->getGenericSignature(); if (genericSignature) - return GenericFunctionType::get(genericSignature, parameters, resultType, - fnType->getExtInfo()); - return FunctionType::get(parameters, resultType, fnType->getExtInfo()); + return GenericFunctionType::get(genericSignature, parameters, resultType); + return FunctionType::get(parameters, resultType); } // Computes the original function type corresponding to the given derivative -// function type. -AnyFunctionType * -getAutoDiffOriginalFunctionType(AnyFunctionType *derivativeFnTy) { +// function type. Used for `@derivative` attribute type-checking. +static AnyFunctionType * +getDerivativeOriginalFunctionType(AnyFunctionType *derivativeFnTy) { // Unwrap curry levels. At most, two parameter lists are necessary, for // curried method types with a `(Self)` parameter list. SmallVector curryLevels; @@ -3367,7 +3366,7 @@ getAutoDiffOriginalFunctionType(AnyFunctionType *derivativeFnTy) { "Expected derivative result to be a two-element tuple"); auto originalResult = derivativeResult->getElement(0).getType(); auto *originalType = makeFunctionType( - curryLevels.back(), curryLevels.back()->getParams(), originalResult, + curryLevels.back()->getParams(), originalResult, curryLevels.size() == 1 ? derivativeFnTy->getOptGenericSignature() : nullptr); @@ -3378,7 +3377,7 @@ getAutoDiffOriginalFunctionType(AnyFunctionType *derivativeFnTy) { unsigned i = pair.index(); AnyFunctionType *curryLevel = pair.value(); originalType = - makeFunctionType(curryLevel, curryLevel->getParams(), originalType, + makeFunctionType(curryLevel->getParams(), originalType, i == curryLevelsWithoutLast.size() - 1 ? derivativeFnTy->getOptGenericSignature() : nullptr); @@ -3466,7 +3465,7 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D, // Compute expected original function type and look up original function. auto *originalFnType = - getAutoDiffOriginalFunctionType(derivativeInterfaceType); + getDerivativeOriginalFunctionType(derivativeInterfaceType); // Returns true if the generic parameters in `source` satisfy the generic // requirements in `target`. @@ -3585,39 +3584,40 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D, } attr->setOriginalFunction(originalAFD); - // Get the resolved `wrt:` parameter indices. - auto *resolvedWrtParamIndices = attr->getParameterIndices(); + // Get the resolved differentiability parameter indices. + auto *resolvedDiffParamIndices = attr->getParameterIndices(); - // Get the parsed `wrt:` parameter indices, which have not yet been resolved. - // Parsed `wrt:` paraeter indices are defined only for parsed attributes. - auto parsedWrtParams = attr->getParsedParameters(); + // Get the parsed differentiability parameter indices, which have not yet been + // resolved. Parsed differentiability parameter indices are defined only for + // parsed attributes. + auto parsedDiffParams = attr->getParsedParameters(); - // If parameter indices are not resolved, compute them. - if (!resolvedWrtParamIndices) - resolvedWrtParamIndices = computeDifferentiationParameters( - parsedWrtParams, derivative, derivative->getGenericEnvironment(), + // If differentiability parameter indices are not resolved, compute them. + if (!resolvedDiffParamIndices) + resolvedDiffParamIndices = computeDifferentiabilityParameters( + parsedDiffParams, derivative, derivative->getGenericEnvironment(), attr->getAttrName(), attr->getLocation()); - if (!resolvedWrtParamIndices) + if (!resolvedDiffParamIndices) return true; - // Check if the `wrt:` parameter indices are valid. - if (checkDifferentiationParameters( - originalAFD, resolvedWrtParamIndices, originalFnType, + // Check if the differentiability parameter indices are valid. + if (checkDifferentiabilityParameters( + originalAFD, resolvedDiffParamIndices, originalFnType, derivative->getGenericEnvironment(), derivative->getModuleContext(), - parsedWrtParams, attr->getLocation())) + parsedDiffParams, attr->getLocation())) return true; - // Set the resolved `wrt:` parameter indices in the attribute. - attr->setParameterIndices(resolvedWrtParamIndices); + // Set the resolved differentiability parameter indices in the attribute. + attr->setParameterIndices(resolvedDiffParamIndices); - // Gather `wrt:` parameters. - SmallVector wrtParamTypes; - autodiff::getSubsetParameterTypes(resolvedWrtParamIndices, originalFnType, - wrtParamTypes); + // Gather differentiability parameters. + SmallVector diffParamTypes; + autodiff::getSubsetParameterTypes(resolvedDiffParamIndices, originalFnType, + diffParamTypes); - // Get the `TangentVector` associated types of the `wrt:` parameters. - auto wrtParamTanTypes = - map>(wrtParamTypes, [&](Type paramType) { + // Get the differentiability parameters' `TangentVector` associated types. + auto diffParamTanTypes = + map>(diffParamTypes, [&](Type paramType) { if (paramType->hasTypeParameter()) paramType = derivative->mapTypeIntoContext(paramType); auto conf = TypeChecker::conformsToProtocol(paramType, diffableProto, @@ -3650,14 +3650,14 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D, Type expectedFuncEltType; if (kind == AutoDiffDerivativeFunctionKind::JVP) { auto diffParams = map>( - wrtParamTanTypes, [&](TupleTypeElt elt) { + diffParamTanTypes, [&](TupleTypeElt elt) { return AnyFunctionType::Param(elt.getType()); }); expectedFuncEltType = FunctionType::get(diffParams, resultTanType); } else { expectedFuncEltType = FunctionType::get({AnyFunctionType::Param(resultTanType)}, - TupleType::get(wrtParamTanTypes, Ctx)); + TupleType::get(diffParamTanTypes, Ctx)); } expectedFuncEltType = expectedFuncEltType->mapTypeOutOfContext(); @@ -3695,7 +3695,7 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D, // Reject duplicate `@derivative` attributes. auto &derivativeAttrs = Ctx.DerivativeAttrs[std::make_tuple( - originalAFD, resolvedWrtParamIndices, kind)]; + originalAFD, resolvedDiffParamIndices, kind)]; derivativeAttrs.insert(attr); if (derivativeAttrs.size() > 1) { diags.diagnose(attr->getLocation(), diff --git a/lib/Sema/TypeChecker.h b/lib/Sema/TypeChecker.h index 13a5aac4f5ba2..f509f31abee51 100644 --- a/lib/Sema/TypeChecker.h +++ b/lib/Sema/TypeChecker.h @@ -1580,18 +1580,19 @@ class TypeChecker final { static DeclTypeCheckingSemantics getDeclTypeCheckingSemantics(ValueDecl *decl); - /// Creates an `IndexSubset` for the given function type, representing - /// all inferred differentiation parameters. Used by `@differentiable` and - /// `@derivative` attribute type-checking. - /// - /// The differentiation parameters are inferred to be: - /// - All parameters of the function type that conform to `Differentiable`. - /// - If the function type's result is a function type (i.e. it is a curried - /// method type), then also all parameters of the function result type that - /// conform to `Differentiable`. + /// Infers the differentiability parameter indices for the given + /// original or derivative `AbstractFunctionDecl`. + /// + /// The differentiability parameters are inferred to be: + /// - All parameters of the function that conform to `Differentiable`. + /// - If the function result type is a function type (i.e. the function has + /// a curried method type), then also all parameters of the function result + /// type that conform to `Differentiable`. + /// + /// Used by `@differentiable` and `@derivative` attribute type-checking. static IndexSubset * - inferDifferentiationParameters(AbstractFunctionDecl *AFD, - GenericEnvironment *derivativeGenEnv); + inferDifferentiabilityParameters(AbstractFunctionDecl *AFD, + GenericEnvironment *derivativeGenEnv); public: /// Require that the library intrinsics for working with Optional diff --git a/test/AutoDiff/Parse/differentiable_attr_parse.swift b/test/AutoDiff/Parse/differentiable_attr_parse.swift index 354d385ba3604..eb94ff59728ab 100644 --- a/test/AutoDiff/Parse/differentiable_attr_parse.swift +++ b/test/AutoDiff/Parse/differentiable_attr_parse.swift @@ -1,5 +1,7 @@ // RUN: %target-swift-frontend -parse -verify %s +// TODO(TF-1021): Remove "deprecated 'jvp:' and 'vjp:' argument" warnings. + /// Good struct Foo { @@ -7,21 +9,31 @@ struct Foo { var x: Float } +// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} @differentiable(vjp: foo(_:_:)) // okay func bar(_ x: Float, _: Float) -> Float { return 1 + x } +// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} +@differentiable(vjp: foo(_:_:)) // okay +// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} @differentiable(vjp: foo(_:_:) where T : FloatingPoint) // okay func bar(_ x: T, _: T) -> T { return 1 + x } +// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} +@differentiable(vjp: foo(_:_:)) // okay +// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} @differentiable(wrt: (self, x, y), vjp: foo(_:_:)) // okay func bar(_ x: Float, _ y: Float) -> Float { return 1 + x } +// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} +@differentiable(vjp: foo(_:_:)) // okay +// expected-warning @+1 2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} @differentiable(wrt: (self, x, y), jvp: bar, vjp: foo(_:_:)) // okay func bar(_ x: Float, _ y: Float) -> Float { return 1 + x @@ -55,6 +67,7 @@ func playWellWithOtherAttrs(_ x: Float, _: Float) -> Float { } @_transparent +// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} @differentiable(wrt: (self), vjp: _vjpSquareRoot) // okay public func squareRoot() -> Self { var lhs = self @@ -99,82 +112,104 @@ func two(x: Float, y: Float) -> Float { /// Bad -@differentiable(3) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} +// expected-error @+1 {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} +@differentiable(3) func bar(_ x: Float, _: Float) -> Float { return 1 + x } -@differentiable(foo(_:_:)) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} +// expected-error @+1 {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} +@differentiable(foo(_:_:)) func bar(_ x: Float, _: Float) -> Float { return 1 + x } -@differentiable(vjp: foo(_:_:), 3) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} +// expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} +// expected-error @+1 {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} +@differentiable(vjp: foo(_:_:), 3) func bar(_ x: Float, _: Float) -> Float { return 1 + x } -@differentiable(wrt: (x), foo(_:_:)) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} +// expected-error @+1 {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} +@differentiable(wrt: (x), foo(_:_:)) func bar(_ x: Float, _: Float) -> Float { return 1 + x } -@differentiable(wrt: x, y) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} +// expected-error @+1 {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} +@differentiable(wrt: x, y) func bar(_ x: Float, _ y: Float) -> Float { return 1 + x } -@differentiable(wrt: 0, 1) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} +// expected-error @+1 {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} +@differentiable(wrt: 0, 1) func two(x: Float, y: Float) -> Float { return x + y } -@differentiable(wrt: 0, y) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} +// expected-error @+1 {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} +@differentiable(wrt: 0, y) func two(x: Float, y: Float) -> Float { return x + y } -@differentiable(wrt: 0,) // expected-error {{unexpected ',' separator}} +// expected-error @+1 {{unexpected ',' separator}} +@differentiable(wrt: 0,) func two(x: Float, y: Float) -> Float { return x + y } -@differentiable(vjp: foo(_:_:) // expected-error {{expected ')' in 'differentiable' attribute}} +// expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} +// expected-error @+1 {{expected ')' in 'differentiable' attribute}} +@differentiable(vjp: foo(_:_:) func bar(_ x: Float, _: Float) -> Float { return 1 + x } -@differentiable(vjp: foo(_:_:) where T) // expected-error {{expected ':' or '==' to indicate a conformance or same-type requirement}} +// expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} +// expected-error @+1 {{expected ':' or '==' to indicate a conformance or same-type requirement}} +@differentiable(vjp: foo(_:_:) where T) func bar(_ x: T, _: T) -> T { return 1 + x } -@differentiable(,) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} +// expected-error @+1 {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} +@differentiable(,) func bar(_ x: Float, _: Float) -> Float { return 1 + x } -@differentiable(vjp: foo(_:_:),) // expected-error {{unexpected ',' separator}} +// expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} +// expected-error @+1 {{unexpected ',' separator}} +@differentiable(vjp: foo(_:_:),) func bar(_ x: Float, _: Float) -> Float { return 1 + x } -@differentiable(vjp: foo(_:_:), where T) // expected-error {{unexpected ',' separator}} +// expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} +// expected-error @+1 {{unexpected ',' separator}} +@differentiable(vjp: foo(_:_:), where T) func bar(_ x: T, _: T) -> T { return 1 + x } -@differentiable(wrt: x, linear) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} +// expected-error @+1 {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} +@differentiable(wrt: x, linear) func slope4(_ x: Float) -> Float { return 4 * x } -@differentiable(wrt: x, linear, vjp: const5) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} +// expected-error @+1 {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} +@differentiable(wrt: x, linear, vjp: const5) func slope5(_ x: Float) -> Float { return 5 * x } -@differentiable(wrt: x, vjp: const6, linear) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} +// expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} +// expected-error @+1 {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} +@differentiable(wrt: x, vjp: const6, linear) func slope5(_ x: Float) -> Float { return 6 * x }