98 changes: 30 additions & 68 deletions llvm/include/llvm/ADT/ImmutableSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/FoldingSet.h"
#include "llvm/ADT/IntrusiveRefCntPtr.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator.h"
#include "llvm/Support/Allocator.h"
Expand Down Expand Up @@ -357,6 +358,12 @@ class ImutAVLTree {
}
};

template <typename ImutInfo>
struct IntrusiveRefCntPtrInfo<ImutAVLTree<ImutInfo>> {
static void retain(ImutAVLTree<ImutInfo> *Tree) { Tree->retain(); }
static void release(ImutAVLTree<ImutInfo> *Tree) { Tree->release(); }
};

//===----------------------------------------------------------------------===//
// Immutable AVL-Tree Factory class.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -961,33 +968,14 @@ class ImmutableSet {
using TreeTy = ImutAVLTree<ValInfo>;

private:
TreeTy *Root;
IntrusiveRefCntPtr<TreeTy> Root;

public:
/// Constructs a set from a pointer to a tree root. In general one
/// should use a Factory object to create sets instead of directly
/// invoking the constructor, but there are cases where make this
/// constructor public is useful.
explicit ImmutableSet(TreeTy* R) : Root(R) {
if (Root) { Root->retain(); }
}

ImmutableSet(const ImmutableSet &X) : Root(X.Root) {
if (Root) { Root->retain(); }
}

~ImmutableSet() {
if (Root) { Root->release(); }
}

ImmutableSet &operator=(const ImmutableSet &X) {
if (Root != X.Root) {
if (X.Root) { X.Root->retain(); }
if (Root) { Root->release(); }
Root = X.Root;
}
return *this;
}
explicit ImmutableSet(TreeTy *R) : Root(R) {}

class Factory {
typename TreeTy::Factory F;
Expand Down Expand Up @@ -1016,7 +1004,7 @@ class ImmutableSet {
/// The memory allocated to represent the set is released when the
/// factory object that created the set is destroyed.
LLVM_NODISCARD ImmutableSet add(ImmutableSet Old, value_type_ref V) {
TreeTy *NewT = F.add(Old.Root, V);
TreeTy *NewT = F.add(Old.Root.get(), V);
return ImmutableSet(Canonicalize ? F.getCanonicalTree(NewT) : NewT);
}

Expand All @@ -1028,7 +1016,7 @@ class ImmutableSet {
/// The memory allocated to represent the set is released when the
/// factory object that created the set is destroyed.
LLVM_NODISCARD ImmutableSet remove(ImmutableSet Old, value_type_ref V) {
TreeTy *NewT = F.remove(Old.Root, V);
TreeTy *NewT = F.remove(Old.Root.get(), V);
return ImmutableSet(Canonicalize ? F.getCanonicalTree(NewT) : NewT);
}

Expand All @@ -1047,21 +1035,20 @@ class ImmutableSet {
}

bool operator==(const ImmutableSet &RHS) const {
return Root && RHS.Root ? Root->isEqual(*RHS.Root) : Root == RHS.Root;
return Root && RHS.Root ? Root->isEqual(*RHS.Root.get()) : Root == RHS.Root;
}

bool operator!=(const ImmutableSet &RHS) const {
return Root && RHS.Root ? Root->isNotEqual(*RHS.Root) : Root != RHS.Root;
return Root && RHS.Root ? Root->isNotEqual(*RHS.Root.get())
: Root != RHS.Root;
}

TreeTy *getRoot() {
if (Root) { Root->retain(); }
return Root;
return Root.get();
}

TreeTy *getRootWithoutRetain() const {
return Root;
}
TreeTy *getRootWithoutRetain() const { return Root.get(); }

/// isEmpty - Return true if the set contains no elements.
bool isEmpty() const { return !Root; }
Expand All @@ -1082,7 +1069,7 @@ class ImmutableSet {

using iterator = ImutAVLValueIterator<ImmutableSet>;

iterator begin() const { return iterator(Root); }
iterator begin() const { return iterator(Root.get()); }
iterator end() const { return iterator(); }

//===--------------------------------------------------===//
Expand All @@ -1092,7 +1079,7 @@ class ImmutableSet {
unsigned getHeight() const { return Root ? Root->getHeight() : 0; }

static void Profile(FoldingSetNodeID &ID, const ImmutableSet &S) {
ID.AddPointer(S.Root);
ID.AddPointer(S.Root.get());
}

void Profile(FoldingSetNodeID &ID) const { return Profile(ID, *this); }
Expand All @@ -1114,50 +1101,26 @@ class ImmutableSetRef {
using FactoryTy = typename TreeTy::Factory;

private:
TreeTy *Root;
IntrusiveRefCntPtr<TreeTy> Root;
FactoryTy *Factory;

public:
/// Constructs a set from a pointer to a tree root. In general one
/// should use a Factory object to create sets instead of directly
/// invoking the constructor, but there are cases where make this
/// constructor public is useful.
explicit ImmutableSetRef(TreeTy* R, FactoryTy *F)
: Root(R),
Factory(F) {
if (Root) { Root->retain(); }
}

ImmutableSetRef(const ImmutableSetRef &X)
: Root(X.Root),
Factory(X.Factory) {
if (Root) { Root->retain(); }
}

~ImmutableSetRef() {
if (Root) { Root->release(); }
}

ImmutableSetRef &operator=(const ImmutableSetRef &X) {
if (Root != X.Root) {
if (X.Root) { X.Root->retain(); }
if (Root) { Root->release(); }
Root = X.Root;
Factory = X.Factory;
}
return *this;
}
ImmutableSetRef(TreeTy *R, FactoryTy *F) : Root(R), Factory(F) {}

static ImmutableSetRef getEmptySet(FactoryTy *F) {
return ImmutableSetRef(0, F);
}

ImmutableSetRef add(value_type_ref V) {
return ImmutableSetRef(Factory->add(Root, V), Factory);
return ImmutableSetRef(Factory->add(Root.get(), V), Factory);
}

ImmutableSetRef remove(value_type_ref V) {
return ImmutableSetRef(Factory->remove(Root, V), Factory);
return ImmutableSetRef(Factory->remove(Root.get(), V), Factory);
}

/// Returns true if the set contains the specified value.
Expand All @@ -1166,20 +1129,19 @@ class ImmutableSetRef {
}

ImmutableSet<ValT> asImmutableSet(bool canonicalize = true) const {
return ImmutableSet<ValT>(canonicalize ?
Factory->getCanonicalTree(Root) : Root);
return ImmutableSet<ValT>(
canonicalize ? Factory->getCanonicalTree(Root.get()) : Root.get());
}

TreeTy *getRootWithoutRetain() const {
return Root;
}
TreeTy *getRootWithoutRetain() const { return Root.get(); }

bool operator==(const ImmutableSetRef &RHS) const {
return Root && RHS.Root ? Root->isEqual(*RHS.Root) : Root == RHS.Root;
return Root && RHS.Root ? Root->isEqual(*RHS.Root.get()) : Root == RHS.Root;
}

bool operator!=(const ImmutableSetRef &RHS) const {
return Root && RHS.Root ? Root->isNotEqual(*RHS.Root) : Root != RHS.Root;
return Root && RHS.Root ? Root->isNotEqual(*RHS.Root.get())
: Root != RHS.Root;
}

/// isEmpty - Return true if the set contains no elements.
Expand All @@ -1195,7 +1157,7 @@ class ImmutableSetRef {

using iterator = ImutAVLValueIterator<ImmutableSetRef>;

iterator begin() const { return iterator(Root); }
iterator begin() const { return iterator(Root.get()); }
iterator end() const { return iterator(); }

//===--------------------------------------------------===//
Expand All @@ -1205,7 +1167,7 @@ class ImmutableSetRef {
unsigned getHeight() const { return Root ? Root->getHeight() : 0; }

static void Profile(FoldingSetNodeID &ID, const ImmutableSetRef &S) {
ID.AddPointer(S.Root);
ID.AddPointer(S.Root.get());
}

void Profile(FoldingSetNodeID &ID) const { return Profile(ID, *this); }
Expand Down