Skip to content

Commit

Permalink
Add method for reduction classes
Browse files Browse the repository at this point in the history
  • Loading branch information
ndem0 committed Apr 22, 2021
1 parent 7fe59a3 commit 21d9f6f
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 19 deletions.
15 changes: 12 additions & 3 deletions ezyrb/pod.py
Expand Up @@ -59,13 +59,22 @@ def singular_values(self):
"""
return self._singular_values

def reduce(self, X):
def fit(self, X):
"""
Reduces the parameter Space by using the specified reduction method (default svd).
Create the reduced space for the given snapshots `X` using the
specified method
:type: numpy.ndarray
:param numpy.ndarray X: the input snapshots matrix (stored by column)
"""
self._modes, self._singular_values = self.__method(X)
return self

def reduce(self, X):
"""
Reduces the given snapshots.
:param numpy.ndarray X: the input snapshots matrix (stored by column).
"""
return self.modes.T.conj().dot(X)

def expand(self, X):
Expand Down
1 change: 1 addition & 0 deletions ezyrb/reducedordermodel.py
Expand Up @@ -21,6 +21,7 @@ def fit(self, *args, **kwargs):
:param \*args: additional parameters to pass to the `fit` method.
:param \**kwargs: additional parameters to pass to the `fit` method.
"""
self.reduction.fit(self.database.snapshots.T)
self.approximation.fit(
self.database.parameters,
self.reduction.reduce(self.database.snapshots.T).T, *args, **kwargs)
Expand Down
5 changes: 5 additions & 0 deletions ezyrb/reduction.py
Expand Up @@ -6,6 +6,11 @@


class Reduction(ABC):

@abstractmethod
def fit(self):
pass

@abstractmethod
def reduce(self):
pass
Expand Down
31 changes: 15 additions & 16 deletions tests/test_pod.py
Expand Up @@ -13,7 +13,7 @@ def test_constructor_empty(self):
a = POD()

def test_numpysvd(self):
A = POD('svd').reduce(snapshots)
A = POD('svd').fit(snapshots).reduce(snapshots)
assert np.allclose(A, poddb, rtol=1e-03, atol=1e-08) or np.allclose(
A,
-1 * poddb,
Expand All @@ -22,7 +22,7 @@ def test_numpysvd(self):
)

def test_correlation_matirix(self):
A = POD('correlation_matrix').reduce(snapshots)
A = POD('correlation_matrix').fit(snapshots).reduce(snapshots)
assert np.allclose(A, poddb, rtol=1e-03, atol=1e-08) or np.allclose(
A,
-1 * poddb,
Expand All @@ -31,7 +31,7 @@ def test_correlation_matirix(self):
)

def test_correlation_matirix_savemem(self):
A = POD('correlation_matrix', save_memory=True).reduce(snapshots)
A = POD('correlation_matrix', save_memory=True).fit(snapshots).reduce(snapshots)
assert np.allclose(A, poddb, rtol=1e-03, atol=1e-08) or np.allclose(
A,
-1 * poddb,
Expand All @@ -40,15 +40,14 @@ def test_correlation_matirix_savemem(self):
)

def test_randomized_svd(self):
A = POD('randomized_svd').reduce(snapshots)
A = POD('randomized_svd').fit(snapshots).reduce(snapshots)
np.testing.assert_allclose(np.absolute(A),
np.absolute(poddb),
rtol=1e-03,
atol=1e-08)

def test_singlular_values(self):
a = POD('svd')
a.reduce(snapshots)
a = POD('svd').fit(snapshots)
np.testing.assert_allclose(
a.singular_values,
np.array([887.15704, 183.2508, 84.11757, 26.40448]),
Expand All @@ -57,50 +56,50 @@ def test_singlular_values(self):

def test_modes(self):
a = POD('svd')
a.reduce(snapshots)
a.fit(snapshots)
np.testing.assert_allclose(a.modes, modes)

def test_truncation_01(self):
a = POD(method='svd', rank=0)
a.reduce(snapshots)
a.fit(snapshots)
assert a.singular_values.shape[0] == 1

def test_truncation_02(self):
a = POD(method='randomized_svd', rank=0)
a.reduce(snapshots)
a.fit(snapshots)
assert a.singular_values.shape[0] == 1

def test_truncation_03(self):
a = POD(method='correlation_matrix', rank=0)
a.reduce(snapshots)
a.fit(snapshots)
assert a.singular_values.shape[0] == 2

def test_truncation_04(self):
a = POD(method='svd', rank=3)
a.reduce(snapshots)
a.fit(snapshots)
assert a.singular_values.shape[0] == 3

def test_truncation_05(self):
a = POD(method='randomized_svd', rank=3)
a.reduce(snapshots)
a.fit(snapshots)
assert a.singular_values.shape[0] == 3

def test_truncation_06(self):
a = POD(method='correlation_matrix', rank=4)
a.reduce(snapshots)
a.fit(snapshots)
assert a.singular_values.shape[0] == 4

def test_truncation_07(self):
a = POD(method='svd', rank=0.8)
a.reduce(snapshots)
a.fit(snapshots)
assert a.singular_values.shape[0] == 1

def test_truncation_08(self):
a = POD(method='randomized_svd', rank=0.995)
a.reduce(snapshots)
a.fit(snapshots)
assert a.singular_values.shape[0] == 3

def test_truncation_09(self):
a = POD(method='correlation_matrix', rank=0.9999)
a.reduce(snapshots)
a.fit(snapshots)
assert a.singular_values.shape[0] == 2

0 comments on commit 21d9f6f

Please sign in to comment.