Skip to content

Commit d3f13e9

Browse files
committed
[WIP] Make DifferentiableAttributeTypeCheckRequest cached.
Attempt to make `DifferentiableAttributeTypeCheckRequest` cached. Not sure how to implement `DifferentiableAttributeTypeCheckRequest::isCached` in the way that works with attributes in non-primary-files. Cross-module tests fail: Failing Tests (4): Swift(macosx-x86_64) :: AutoDiff/derived_differentiable.swift Swift(macosx-x86_64) :: AutoDiff/differentiable_attr_cross_module/main.swift Swift(macosx-x86_64) :: AutoDiff/sil_differentiability_witness_silgen.swift Swift(macosx-x86_64) :: AutoDiff/tbdgen.swift
1 parent 682958a commit d3f13e9

File tree

8 files changed

+34
-28
lines changed

8 files changed

+34
-28
lines changed

include/swift/AST/ASTTypeIDZone.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ SWIFT_TYPEID_NAMED(EnumDecl *, EnumDecl)
3737
SWIFT_TYPEID_NAMED(GenericParamList *, GenericParamList)
3838
SWIFT_TYPEID_NAMED(GenericTypeParamType *, GenericTypeParamType)
3939
SWIFT_TYPEID_NAMED(InfixOperatorDecl *, InfixOperatorDecl)
40+
// SWIFT_ENABLE_TENSORFLOW
41+
SWIFT_TYPEID_NAMED(IndexSubset *, IndexSubset)
42+
// SWIFT_ENABLE_TENSORFLOW END
4043
SWIFT_TYPEID_NAMED(IterableDeclContext *, IterableDeclContext)
4144
SWIFT_TYPEID_NAMED(ModuleDecl *, ModuleDecl)
4245
SWIFT_TYPEID_NAMED(NamedPattern *, NamedPattern)

include/swift/AST/ASTTypeIDs.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ enum class FunctionBuilderClosurePreCheck : uint8_t;
3232
class GenericParamList;
3333
class GenericSignature;
3434
class GenericTypeParamType;
35+
// SWIFT_ENABLE_TENSORFLOW
36+
class IndexSubset;
37+
// SWIFT_ENABLE_TENSORFLOW END
3538
class InfixOperatorDecl;
3639
class IterableDeclContext;
3740
class ModuleDecl;

include/swift/AST/Attr.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1541,7 +1541,7 @@ class DifferentiableAttr final
15411541
private llvm::TrailingObjects<DifferentiableAttr,
15421542
ParsedAutoDiffParameter> {
15431543
friend TrailingObjects;
1544-
friend class DifferentiableAttributeParameterIndicesRequest;
1544+
friend class DifferentiableAttributeTypeCheckRequest;
15451545

15461546
/// The declaration on which the `@differentiable` attribute is declared.
15471547
Decl *OriginalDeclaration = nullptr;

include/swift/AST/TypeCheckRequests.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1659,10 +1659,10 @@ class CompareDeclSpecializationRequest
16591659
};
16601660

16611661
// SWIFT_ENABLE_TENSORFLOW
1662-
class DifferentiableAttributeParameterIndicesRequest :
1663-
public SimpleRequest<DifferentiableAttributeParameterIndicesRequest,
1664-
IndexSubset *(DifferentiableAttr *, Decl *),
1665-
CacheKind::SeparatelyCached> {
1662+
class DifferentiableAttributeTypeCheckRequest :
1663+
public SimpleRequest<DifferentiableAttributeTypeCheckRequest,
1664+
IndexSubset *(DifferentiableAttr *),
1665+
CacheKind::Cached> {
16661666
public:
16671667
using SimpleRequest::SimpleRequest;
16681668

@@ -1671,14 +1671,14 @@ class DifferentiableAttributeParameterIndicesRequest :
16711671

16721672
// Evaluation.
16731673
llvm::Expected<IndexSubset *>
1674-
evaluate(Evaluator &evaluator, DifferentiableAttr *attr, Decl *decl) const;
1674+
evaluate(Evaluator &evaluator, DifferentiableAttr *attr) const;
16751675

16761676
public:
16771677
// Separate caching.
1678-
bool isCached() const { return true; }
1679-
Optional<IndexSubset *> getCachedResult() const;
1680-
void cacheResult(IndexSubset *value) const;
1678+
bool isCached() const;
16811679
};
1680+
1681+
void simple_display(llvm::raw_ostream &out, const IndexSubset *);
16821682
// SWIFT_ENABLE_TENSORFLOW END
16831683

16841684

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,8 @@ SWIFT_REQUEST(TypeChecker, CompareDeclSpecializationRequest,
3535
SWIFT_REQUEST(TypeChecker, DefaultDefinitionTypeRequest,
3636
Type(AssociatedTypeDecl *), Cached, NoLocationInfo)
3737
// SWIFT_ENABLE_TENSORFLOW
38-
SWIFT_REQUEST(TypeChecker, DifferentiableAttributeParameterIndicesRequest,
39-
IndexSubset *(DifferentiableAttr *, Decl *),
40-
SeparatelyCached, NoLocationInfo)
38+
SWIFT_REQUEST(TypeChecker, DifferentiableAttributeTypeCheckRequest,
39+
IndexSubset *(DifferentiableAttr *), Cached, NoLocationInfo)
4140
// SWIFT_ENABLE_TENSORFLOW END
4241
SWIFT_REQUEST(TypeChecker, DefaultTypeRequest,
4342
Type(KnownProtocolKind, const DeclContext *), SeparatelyCached,

lib/AST/Attr.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1525,8 +1525,8 @@ IndexSubset *DifferentiableAttr::getParameterIndices() const {
15251525
auto &ctx = getOriginalDeclaration()->getASTContext();
15261526
return evaluateOrDefault(
15271527
ctx.evaluator,
1528-
DifferentiableAttributeParameterIndicesRequest{
1529-
const_cast<DifferentiableAttr *>(this), getOriginalDeclaration()},
1528+
DifferentiableAttributeTypeCheckRequest{
1529+
const_cast<DifferentiableAttr *>(this)},
15301530
nullptr);
15311531
}
15321532

@@ -1535,8 +1535,8 @@ void DifferentiableAttr::setParameterIndices(IndexSubset *paramIndices) {
15351535
"Original declaration must have been resolved");
15361536
auto &ctx = getOriginalDeclaration()->getASTContext();
15371537
ctx.evaluator.cacheOutput(
1538-
DifferentiableAttributeParameterIndicesRequest{
1539-
const_cast<DifferentiableAttr *>(this), getOriginalDeclaration()},
1538+
DifferentiableAttributeTypeCheckRequest{
1539+
const_cast<DifferentiableAttr *>(this)},
15401540
std::move(paramIndices));
15411541
}
15421542

lib/AST/TypeCheckRequests.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,21 +1056,21 @@ void swift::simple_display(llvm::raw_ostream &out,
10561056
}
10571057

10581058
//----------------------------------------------------------------------------//
1059-
// DifferentiableAttributeParameterIndicesRequest computation.
1059+
// DifferentiableAttributeTypeCheckRequest computation.
10601060
//----------------------------------------------------------------------------//
10611061

1062-
Optional<IndexSubset *>
1063-
DifferentiableAttributeParameterIndicesRequest::getCachedResult() const {
1062+
bool DifferentiableAttributeTypeCheckRequest::isCached() const {
1063+
// FIXME: The challenge is to implement `isCached` correctly.
1064+
/*
10641065
auto *attr = std::get<0>(getStorage());
1065-
if (attr->hasComputedParameterIndices())
1066-
return attr->ParameterIndicesAndBit.getPointer();
1067-
return None;
1066+
return attr->hasComputedParameterIndices();
1067+
*/
1068+
return true;
10681069
}
10691070

1070-
void DifferentiableAttributeParameterIndicesRequest::cacheResult(
1071-
IndexSubset *parameterIndices) const {
1072-
auto *attr = std::get<0>(getStorage());
1073-
attr->ParameterIndicesAndBit.setPointerAndInt(parameterIndices, true);
1071+
void swift::simple_display(llvm::raw_ostream &out,
1072+
const IndexSubset *indexSubset) {
1073+
indexSubset->print(out);
10741074
}
10751075

10761076
//----------------------------------------------------------------------------//

lib/Sema/TypeCheckAttr.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3304,8 +3304,8 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
33043304
}
33053305

33063306
llvm::Expected<IndexSubset *>
3307-
DifferentiableAttributeParameterIndicesRequest::evaluate(
3308-
Evaluator &evaluator, DifferentiableAttr *attr, Decl *D) const {
3307+
DifferentiableAttributeTypeCheckRequest::evaluate(
3308+
Evaluator &evaluator, DifferentiableAttr *attr) const {
33093309
// Skip checking implicit `@differentiable` attributes. We currently assume
33103310
// that all implicit `@differentiable` attributes are valid.
33113311
// Motivation: some implicit attributes do not contain a where clause, and
@@ -3315,6 +3315,7 @@ DifferentiableAttributeParameterIndicesRequest::evaluate(
33153315
if (attr->isImplicit())
33163316
return nullptr;
33173317

3318+
auto *D = attr->getOriginalDeclaration();
33183319
auto &ctx = D->getASTContext();
33193320
auto &diags = ctx.Diags;
33203321
auto lookupConformance =

0 commit comments

Comments
 (0)