Skip to content

Commit

Permalink
Merge pull request #6627 from markotoplak/pca-eq-2
Browse files Browse the repository at this point in the history
Implement __eq__ and __hash__ for PCA (and family)
  • Loading branch information
janezd committed Nov 10, 2023
2 parents 0719286 + b9489a3 commit df20f2e
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 5 deletions.
66 changes: 63 additions & 3 deletions Orange/projection/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ def __setstate__(self, state):

class Projection:
def __init__(self, proj):
self.__dict__.update(proj.__dict__)
if proj is not None:
self.__dict__.update(proj.__dict__)
self.proj = proj

def transform(self, X):
Expand All @@ -119,27 +120,71 @@ def __call__(self, data):
def __repr__(self):
return self.name

def __eq__(self, other):
if self is other:
return True
return type(self) is type(other) \
and self.proj == other.proj

def __hash__(self):
return hash(self.proj)


class TransformDomain:
    """Callable that maps a data table into a projection's space.

    Two instances are equal when they have exactly the same type and
    their ``projection`` attributes compare equal. The hash is derived
    from ``projection`` and cached in ``_hash``; the cache is left out of
    the state returned by ``__getstate__`` and cleared by
    ``__setstate__``.
    """

    def __init__(self, projection):
        self.projection = projection
        # Lazily filled cache for __hash__.
        self._hash = None

    def __call__(self, data):
        """Return ``projection.transform(data.X)``.

        If ``data.domain`` differs from ``projection.pre_domain``, the
        data is first converted with ``data.transform(pre_domain)``.
        """
        if data.domain != self.projection.pre_domain:
            data = data.transform(self.projection.pre_domain)
        return self.projection.transform(data.X)

    def __eq__(self, other):
        if self is other:
            return True
        return type(self) is type(other) \
            and self.projection == other.projection

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Never reuse a hash that arrived with the state; recompute lazily.
        self._hash = None

    def __getstate__(self):
        state = self.__dict__.copy()
        # pop() rather than del: the cache entry may be absent, e.g. on
        # an instance whose __init__ was bypassed.
        state.pop("_hash", None)
        return state

    def __hash__(self):
        # getattr() tolerates instances that never had _hash assigned.
        cached = getattr(self, "_hash", None)
        if cached is None:
            cached = self._hash = hash(self.projection)
        return cached


class ComputeValueProjector(SharedComputeValue):
    """Compute-value that returns one column of a shared projected space.

    Parameters
    ----------
    projection : object, optional
        Deprecated. It is still stored on the instance, but passing a
        non-None value emits an ``OrangeDeprecationWarning``.
    feature : optional
        Column index used by ``compute`` to slice ``space``.
    transform : optional
        Forwarded unchanged to ``SharedComputeValue.__init__``.
    """

    def __init__(self, projection=None, feature=None, transform=None):
        super().__init__(transform)
        if projection is not None:
            # stacklevel=2 attributes the warning to the caller.
            warnings.warn("Argument projection is unused and will be removed.",
                          OrangeDeprecationWarning, stacklevel=2)
        self.projection = projection
        self.feature = feature
        self.transformed = None

    def compute(self, data, space):
        """Return column ``self.feature`` of ``space``; ``data`` is unused."""
        return space[:, self.feature]

    def __eq__(self, other):
        if self is other:
            return True
        return super().__eq__(other) \
            and self.projection == other.projection \
            and self.feature == other.feature \
            and self.transformed == other.transformed

    def __hash__(self):
        # Combines the superclass hash with the same fields __eq__ compares.
        return hash((super().__hash__(), self.projection, self.feature,
                     self.transformed))


class DomainProjection(Projection):
var_prefix = "C"
Expand All @@ -149,7 +194,7 @@ def __init__(self, proj, domain, n_components):

def proj_variable(i, name):
v = Orange.data.ContinuousVariable(
name, compute_value=ComputeValueProjector(self, i, transformer)
name, compute_value=ComputeValueProjector(feature=i, transform=transformer)
)
v.to_sql = LinearCombinationSql(
domain.attributes, self.components_[i, :],
Expand All @@ -176,6 +221,21 @@ def copy(self):
model.name = self.name
return model

def __eq__(self, other):
    """Compare by the superclass check plus the defining attributes.

    ``.domain`` is deliberately not compared: it is constructed only from
    ``.orig_domain``, ``.n_components`` and ``.proj`` (the latter handled
    by the superclass), and ``__hash__`` skips it for the same reason.
    """
    if self is other:
        return True
    base_equal = super().__eq__(other)
    return (
        base_equal
        and self.n_components == other.n_components
        and self.orig_domain == other.orig_domain
        and self.var_prefix == other.var_prefix
    )

def __hash__(self):
    """Hash by the superclass hash plus the attributes ``__eq__`` compares."""
    # hashing self.domain would cause infinite recursion;
    # because it is only constructed from .orig_domain, .n_components
    # and .proj (dealt with in the superclass), we do not need it
    key = (
        super().__hash__(),
        self.n_components,
        self.orig_domain,
        self.var_prefix,
    )
    return hash(key)


class LinearProjector(Projector):
name = "Linear Projection"
Expand Down
6 changes: 4 additions & 2 deletions Orange/projection/manifold.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,11 @@ def __init__(self, embedding: openTSNE.TSNEEmbedding, table: Table,

def proj_variable(i):
return self.embedding.domain[i].copy(
compute_value=ComputeValueProjector(self, i, transformer))
compute_value=ComputeValueProjector(feature=i, transform=transformer))

super().__init__(None)
self.name = "TSNE"

super().__init__(self)
self.embedding_ = embedding
self.embedding = table
self.pre_domain = pre_domain
Expand Down
30 changes: 30 additions & 0 deletions Orange/tests/test_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,33 @@ def test_max_components(self):
self.assertEqual(len(pca.explained_variance_ratio_), 20)
pca = PCA(n_components=10)(data)
self.assertEqual(len(pca.explained_variance_ratio_), 10)

def test_eq_hash(self):
    # Checks __eq__/__hash__ of fitted PCA models and of their domains:
    # two independent fits differ, and become equal once they are made
    # to share the same underlying projection object.
    d = np.random.RandomState(0).rand(20, 20)
    data = Table.from_numpy(None, d)
    p1 = PCA()(data)
    p2 = PCA()(data)
    np.testing.assert_equal(p1(data).X, p2(data).X)

    # even though the results are the same, these transformations
    # are different because the underlying PCA objects are different
    self.assertNotEqual(p1, p2)
    self.assertNotEqual(p1.domain, p2.domain)
    self.assertNotEqual(hash(p1), hash(p2))
    self.assertNotEqual(hash(p1.domain), hash(p2.domain))

    # copy projection: make p2 share p1's projection objects
    p2.domain[0].compute_value.compute_shared.projection = \
        p1.domain[0].compute_value.compute_shared.projection
    p2.proj = p1.proj
    # reset hash caches because the objects were hacked
    # pylint: disable=protected-access
    p1.domain._hash = None
    p2.domain._hash = None
    p1.domain[0].compute_value.compute_shared._hash = None
    p2.domain[0].compute_value.compute_shared._hash = None

    # with a shared projection, both models and their domains match
    self.assertEqual(p1, p2)
    self.assertEqual(p1.domain, p2.domain)
    self.assertEqual(hash(p1), hash(p2))
    self.assertEqual(hash(p1.domain), hash(p2.domain))

0 comments on commit df20f2e

Please sign in to comment.