18 changes: 13 additions & 5 deletions pina/model/layers/pod.py
@@ -4,7 +4,7 @@
import torch
from .stride import Stride
from .utils_convolution import optimizing

+import warnings

class PODBlock(torch.nn.Module):
    """
@@ -85,15 +85,15 @@ def scale_coefficients(self):
        """
        return self.__scale_coefficients

-    def fit(self, X):
+    def fit(self, X, randomized=True):
        """
        Set the POD basis by performing the singular value decomposition of the
        given tensor. If `self.scale_coefficients` is True, the coefficients
        are scaled after the projection to have zero mean and unit variance.

        :param torch.Tensor X: The tensor to be reduced.
        """
-        self._fit_pod(X)
+        self._fit_pod(X, randomized)

        if self.__scale_coefficients:
            self._fit_scaler(torch.matmul(self._basis, X.T))
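
As background, the fit path above amounts to an SVD of the snapshot matrix followed by an optional scaling of the modal coefficients to zero mean and unit variance. A standalone sketch in plain torch (names and shapes here are illustrative, not the PINA API):

    import torch

    X = torch.randn(20, 100)               # toy snapshots: 20 states with 100 dofs each
    rank = 5

    # POD basis: leading left singular vectors of the (dof, n_snapshots) matrix X.T
    U, S, V = torch.svd(X.T)                # deterministic full SVD
    basis = U[:, :rank].T                   # shape (rank, dof), as asserted for pod.basis in the tests

    coeffs = basis @ X.T                    # modal coefficients, shape (rank, n_snapshots)

    # optional scaling so each coefficient has zero mean and unit variance over the snapshots
    mean = coeffs.mean(dim=1, keepdim=True)
    std = coeffs.std(dim=1, keepdim=True)
    scaled = (coeffs - mean) / std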
@@ -112,16 +112,24 @@ def _fit_scaler(self, coeffs):
            "mean": torch.mean(coeffs, dim=1),
        }

-    def _fit_pod(self, X):
+    def _fit_pod(self, X, randomized):
        """
        Private method that computes the POD basis of the given tensor and
        stores it in the private member `_basis`.

        :param torch.Tensor X: The tensor to be reduced.
        """
+        if X.device.type == "mps":  # svd_lowrank not available for mps
Collaborator:

So here would be something like:

        if X.device.type == "mps":  # svd_lowrank not available for mps
            self._basis = torch.svd(X.T)[0].T
        else:
            self._basis = torch.svd_lowrank(X.T, q=X.shape[0])[0].T

We should put a warning on the doc for MPS users maybe?

Contributor Author:

Thank you for reviewing!

In my opinion, if we want to keep svd_lowrank to speed up the computations, then we should either set a seed or raise a Warning saying that the method is randomized and that the computed discretized basis may differ between runs. What you suggested with the randomized kwarg is also fine (I would include the Warning anyway, maybe).

Then, maybe for MPS users the Warning can be "The POD is computed using the standard SVD approach and this may slow down the computation".
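
For context, the reproducibility point above can be illustrated outside the PR: torch.svd_lowrank draws its random projection from the global RNG, so fixing the seed makes the randomized basis repeatable across runs. A minimal sketch (toy data, not PINA code):

    import torch

    X = torch.randn(10, 50)            # toy snapshot matrix: 10 snapshots, 50 dofs

    torch.manual_seed(0)               # seed the RNG used by the randomized SVD
    U1 = torch.svd_lowrank(X.T, q=X.shape[0])[0]

    torch.manual_seed(0)               # same seed -> same random projection
    U2 = torch.svd_lowrank(X.T, q=X.shape[0])[0]

    print(torch.allclose(U1, U2))      # True once the seed is fixed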

Collaborator:

Ok, then let's do it like this. We can put a warning at initialization, and if an MPS user is using the layer we raise another warning like the one you described :)

+            warnings.warn(
+                "svd_lowrank not available for mps, using svd instead. "
+                "This may slow down computations.", ResourceWarning
+            )
+            self._basis = torch.svd(X.T)[0].T
+        else:
-        self._basis = torch.svd_lowrank(X.T, q=X.shape[0])[0].T
+            if randomized:
+                warnings.warn("Considering a randomized algorithm to compute the POD basis")
+                self._basis = torch.svd_lowrank(X.T, q=X.shape[0])[0].T
+            else:
+                self._basis = torch.svd(X.T)[0].T

    def forward(self, X):
        """
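
For readers skimming the diff, a minimal usage sketch of the updated interface (the import path is assumed, and the positional constructor arguments mirror the calls in the tests below):

    import torch
    from pina.model.layers import PODBlock   # assumed import path for this layer

    snapshots = torch.randn(20, 100)         # 20 snapshots with 100 dofs each

    pod = PODBlock(5, True)                  # rank 5, scale coefficients (positional args as in the tests)
    pod.fit(snapshots, randomized=True)      # randomized SVD via svd_lowrank (default behaviour)
    coeffs = pod(snapshots)                  # project snapshots onto the POD basis
    full = pod.expand(coeffs)                # map the coefficients back to full order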
15 changes: 9 additions & 6 deletions tests/test_layers/test_pod.py
@@ -25,9 +25,10 @@ def test_fit(rank, scale):

@pytest.mark.parametrize("scale", [True, False])
@pytest.mark.parametrize("rank", [1, 2, 10])
-def test_fit(rank, scale):
+@pytest.mark.parametrize("randomized", [True, False])
+def test_fit(rank, scale, randomized):
    pod = PODBlock(rank, scale)
-    pod.fit(toy_snapshots)
+    pod.fit(toy_snapshots, randomized)
    n_snap = toy_snapshots.shape[0]
    dof = toy_snapshots.shape[1]
    assert pod.basis.shape == (rank, dof)
@@ -65,18 +66,20 @@ def test_forward():

@pytest.mark.parametrize("scale", [True, False])
@pytest.mark.parametrize("rank", [1, 2, 10])
-def test_expand(rank, scale):
+@pytest.mark.parametrize("randomized", [True, False])
+def test_expand(rank, scale, randomized):
    pod = PODBlock(rank, scale)
-    pod.fit(toy_snapshots)
+    pod.fit(toy_snapshots, randomized)
    c = pod(toy_snapshots)
    torch.testing.assert_close(pod.expand(c), toy_snapshots)
    torch.testing.assert_close(pod.expand(c[0]), toy_snapshots[0].unsqueeze(0))

@pytest.mark.parametrize("scale", [True, False])
@pytest.mark.parametrize("rank", [1, 2, 10])
def test_reduce_expand(rank, scale):
@pytest.mark.parametrize("randomized", [True, False])
def test_reduce_expand(rank, scale, randomized):
pod = PODBlock(rank, scale)
pod.fit(toy_snapshots)
pod.fit(toy_snapshots, randomized)
torch.testing.assert_close(
pod.expand(pod.reduce(toy_snapshots)),
toy_snapshots)