Skip to content

Commit

Permalink
fix: test serialization roundtrip with dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
slevang committed Dec 17, 2023
1 parent 5bcdf5c commit ff9a1be
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 0 deletions.
40 changes: 40 additions & 0 deletions tests/models/test_eof.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from copy import deepcopy

import numpy as np
import xarray as xr
import pytest
Expand Down Expand Up @@ -518,3 +520,41 @@ def test_save_load(dim, mock_data_array, tmp_path, engine):
original.inverse_transform(original.scores()),
loaded.inverse_transform(loaded.scores()),
)


@pytest.mark.parametrize(
"dim",
[
(("time",)),
(("lat", "lon")),
(("lon", "lat")),
],
)
def test_serialize_deserialize_dataarray(dim, mock_data_array):
"""Test roundtrip serialization when the model is fit on a DataArray."""
model = EOF()
model.fit(mock_data_array, dim)
dt = model.serialize()
rebuilt_model = EOF.deserialize(dt)
assert np.allclose(
model.transform(mock_data_array), rebuilt_model.transform(mock_data_array)
)


@pytest.mark.parametrize(
"dim",
[
(("time",)),
(("lat", "lon")),
(("lon", "lat")),
],
)
def test_serialize_deserialize_dataset(dim, mock_dataset):
"""Test roundtrip serialization when the model is fit on a Dataset."""
model = EOF()
model.fit(mock_dataset, dim)
dt = model.serialize()
rebuilt_model = EOF.deserialize(dt)
assert np.allclose(
model.transform(mock_dataset), rebuilt_model.transform(mock_dataset)
)
42 changes: 42 additions & 0 deletions tests/models/test_eof_rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,45 @@ def test_save_load(dim, mock_data_array, tmp_path, engine):
original.inverse_transform(original.scores()),
loaded.inverse_transform(loaded.scores()),
)


@pytest.mark.parametrize(
"dim",
[
(("time",)),
(("lat", "lon")),
(("lon", "lat")),
],
)
def test_serialize_deserialize_dataarray(dim, mock_data_array):
"""Test roundtrip serialization when the model is fit on a DataArray."""
model = EOF()
model.fit(mock_data_array, dim)
rotator = EOFRotator()
rotator.fit(model)
dt = rotator.serialize()
rebuilt_rotator = EOFRotator.deserialize(dt)
assert np.allclose(
rotator.transform(mock_data_array), rebuilt_rotator.transform(mock_data_array)
)


@pytest.mark.parametrize(
"dim",
[
(("time",)),
(("lat", "lon")),
(("lon", "lat")),
],
)
def test_serialize_deserialize_dataset(dim, mock_dataset):
"""Test roundtrip serialization when the model is fit on a Dataset."""
model = EOF()
model.fit(mock_dataset, dim)
rotator = EOFRotator()
rotator.fit(model)
dt = rotator.serialize()
rebuilt_rotator = EOFRotator.deserialize(dt)
assert np.allclose(
rotator.transform(mock_dataset), rebuilt_rotator.transform(mock_dataset)
)
38 changes: 38 additions & 0 deletions tests/models/test_mca.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,3 +432,41 @@ def test_save_load(dim, mock_data_array, tmp_path, engine):
original.inverse_transform(*original.scores()),
loaded.inverse_transform(*loaded.scores()),
)


@pytest.mark.parametrize(
"dim",
[
(("time",)),
(("lat", "lon")),
(("lon", "lat")),
],
)
def test_serialize_deserialize_dataarray(dim, mock_data_array):
"""Test roundtrip serialization when the model is fit on a DataArray."""
model = MCA()
model.fit(mock_data_array, mock_data_array, dim)
dt = model.serialize()
rebuilt_model = MCA.deserialize(dt)
assert np.allclose(
model.transform(mock_data_array), rebuilt_model.transform(mock_data_array)
)


@pytest.mark.parametrize(
"dim",
[
(("time",)),
(("lat", "lon")),
(("lon", "lat")),
],
)
def test_serialize_deserialize_dataset(dim, mock_dataset):
"""Test roundtrip serialization when the model is fit on a Dataset."""
model = MCA()
model.fit(mock_dataset, mock_dataset, dim)
dt = model.serialize()
rebuilt_model = MCA.deserialize(dt)
assert np.allclose(
model.transform(mock_dataset), rebuilt_model.transform(mock_dataset)
)
42 changes: 42 additions & 0 deletions tests/models/test_mca_rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,3 +304,45 @@ def test_save_load(dim, mock_data_array, tmp_path, engine):
original.inverse_transform(*original.scores()),
loaded.inverse_transform(*loaded.scores()),
)


@pytest.mark.parametrize(
"dim",
[
(("time",)),
(("lat", "lon")),
(("lon", "lat")),
],
)
def test_serialize_deserialize_dataarray(dim, mock_data_array):
"""Test roundtrip serialization when the model is fit on a DataArray."""
model = MCA()
model.fit(mock_data_array, mock_data_array, dim)
rotator = MCARotator()
rotator.fit(model)
dt = rotator.serialize()
rebuilt_rotator = MCARotator.deserialize(dt)
assert np.allclose(
rotator.transform(mock_data_array), rebuilt_rotator.transform(mock_data_array)
)


@pytest.mark.parametrize(
"dim",
[
(("time",)),
(("lat", "lon")),
(("lon", "lat")),
],
)
def test_serialize_deserialize_dataset(dim, mock_dataset):
"""Test roundtrip serialization when the model is fit on a Dataset."""
model = MCA()
model.fit(mock_dataset, mock_dataset, dim)
rotator = MCARotator()
rotator.fit(model)
dt = rotator.serialize()
rebuilt_rotator = MCARotator.deserialize(dt)
assert np.allclose(
rotator.transform(mock_dataset), rebuilt_rotator.transform(mock_dataset)
)

0 comments on commit ff9a1be

Please sign in to comment.