Skip to content

Commit

Permalink
[unittests] Add coverage for CSV SolutionArray.save
Browse files Browse the repository at this point in the history
  • Loading branch information
ischoegl committed Jun 22, 2023
1 parent f1c795e commit 9c60e9e
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 24 deletions.
131 changes: 114 additions & 17 deletions test/python/test_composite.py
Expand Up @@ -461,7 +461,23 @@ def test_import_no_norm_data(self):
self.assertArrayNear(states.P, b.P)
self.assertArrayNear(states.X, b.X)

def test_write_csv(self):
def check_arrays(self, a, b, rtol=1e-8):
self.assertArrayNear(a.T, b.T, rtol=rtol)
self.assertArrayNear(a.P, b.P, rtol=rtol)
self.assertArrayNear(a.X, b.X, rtol=rtol)
for key in a.extra:
value = getattr(a, key)
if isinstance(value[0], str):
assert (getattr(b, key) == value).all()
else:
self.assertArrayNear(getattr(b, key), value, rtol=rtol)
if b.meta:
# not all output formats preserve metadata
for key, value in a.meta.items():
assert b.meta[key] == value

@pytest.mark.usefixtures("allow_deprecated")
def test_write_csv_legacy(self):
states = ct.SolutionArray(self.gas, 7)
states.TPX = np.linspace(300, 1000, 7), 2e5, 'H2:0.5, O2:0.4'
states.equilibrate('HP')
Expand All @@ -476,10 +492,9 @@ def test_write_csv(self):

b = ct.SolutionArray(self.gas)
b.read_csv(outfile)
self.assertArrayNear(states.T, b.T)
self.assertArrayNear(states.P, b.P)
self.assertArrayNear(states.X, b.X)
self.check_arrays(states, b)

@pytest.mark.usefixtures("allow_deprecated")
def test_write_csv_single_row(self):
gas = ct.Solution("gri30.yaml")
states = ct.SolutionArray(gas)
Expand All @@ -491,10 +506,9 @@ def test_write_csv_single_row(self):

b = ct.SolutionArray(gas)
b.read_csv(outfile)
self.assertArrayNear(states.T, b.T)
self.assertArrayNear(states.P, b.P)
self.assertArrayNear(states.X, b.X)
self.check_arrays(states, b)

@pytest.mark.usefixtures("allow_deprecated")
def test_write_csv_str_column(self):
states = ct.SolutionArray(self.gas, 3, extra={'spam': 'eggs'})

Expand All @@ -504,14 +518,103 @@ def test_write_csv_str_column(self):
b = ct.SolutionArray(self.gas, extra={'spam'})
b.read_csv(outfile)
self.assertEqual(list(states.spam), list(b.spam))
self.check_arrays(states, b)

@pytest.mark.usefixtures("allow_deprecated")
def test_write_csv_multidim_column(self):
states = ct.SolutionArray(self.gas, 3, extra={'spam': np.zeros((3, 5,))})

outfile = self.test_work_path / "solutionarray.csv"
with self.assertRaisesRegex(NotImplementedError, 'not supported'):
states.write_csv(outfile)

def test_write_csv(self):
outfile = self.test_work_path / "solutionarray_new.csv"
outfile.unlink(missing_ok=True)

arr = ct.SolutionArray(self.gas, 7)
arr.TPX = np.linspace(300, 1000, 7), 2e5, "H2:0.5, O2:0.4"
arr.equilibrate("HP")
arr.save(outfile, basis="mole")

with open(outfile, "r") as fid:
header = fid.readline()
assert "X_H2" in header.split(",")

b = ct.SolutionArray(self.gas)
b.read_csv(outfile)
self.check_arrays(arr, b)

with pytest.raises(ct.CanteraError, match="already exists"):
arr.save(outfile)

def test_write_csv_fancy(self):
outfile = self.test_work_path / "solutionarray_fancy.csv"
outfile.unlink(missing_ok=True)

extra = {"foo": range(7), "bar": range(7), "spam": "eggs"}
arr = ct.SolutionArray(self.gas, 7, extra=extra)
arr.TPX = np.linspace(300, 1000, 7), 2e5, "H2:0.5, O2:0.4"
arr.equilibrate("HP")
arr.save(outfile)

with open(outfile, "r") as fid:
header = fid.readline()
assert "Y_H2" in header.split(",")

b = ct.SolutionArray(self.gas)
b.read_csv(outfile)
self.check_arrays(arr, b)

def test_write_csv_escaped(self):
outfile = self.test_work_path / "solutionarray_escaped.csv"
outfile.unlink(missing_ok=True)

extra = {"foo": range(7), "bar": range(7), "spam,eggs": "a,b,"}
arr = ct.SolutionArray(self.gas, 7, extra=extra)
arr.TPX = np.linspace(300, 1000, 7), 2e5, "H2:0.5, O2:0.4"
arr.equilibrate("HP")
with pytest.warns(UserWarning, match="escaped"):
arr.save(outfile, basis="mass")

with open(outfile, "r") as fid:
header = fid.readline()
assert "Y_H2" in header.split(",")

b = ct.SolutionArray(self.gas)
if _pandas is None:
with pytest.raises(ValueError):
# np.genfromtxt does not support escaped characters
b.read_csv(outfile)
return

b.read_csv(outfile)
self.check_arrays(arr, b)

df = _pandas.read_csv(outfile)
b.from_pandas(df)
self.check_arrays(arr, b)

def test_write_csv_exceptions(self):
outfile = self.test_work_path / f"solutionarray_invalid.csv"
outfile.unlink(missing_ok=True)

arr = ct.SolutionArray(self.gas, (2, 5))
with pytest.raises(ct.CanteraError, match="only works for 1D SolutionArray"):
arr.save(outfile)

arr = ct.SolutionArray(self.gas, 10, extra={'spam"eggs': "foo"})
with pytest.raises(NotImplementedError, match="double quotes or line feeds"):
arr.save(outfile)

arr = ct.SolutionArray(self.gas, 10, extra={"foo": 'spam\neggs'})
with pytest.raises(NotImplementedError, match="double quotes or line feeds"):
arr.save(outfile)

arr = ct.SolutionArray(self.gas, 10)
with pytest.raises(ct.CanteraError, match="Invalid species basis"):
arr.save(outfile, basis="foo")

@utilities.unittest.skipIf(_pandas is None, "pandas is not installed")
def test_to_pandas(self):
states = ct.SolutionArray(self.gas, 7, extra={"props": range(7)})
Expand All @@ -524,7 +627,6 @@ def test_to_pandas(self):

@pytest.mark.skipif("native" not in ct.hdf_support(),
reason="Cantera compiled without HDF support")
@utilities.unittest.skipIf(_h5py is None, "h5py is not installed")
def test_write_hdf(self):
outfile = self.test_work_path / "solutionarray_fancy.h5"
outfile.unlink(missing_ok=True)
Expand All @@ -539,13 +641,7 @@ def test_write_hdf(self):

b = ct.SolutionArray(self.gas)
attr = b.restore(outfile, "group0")
self.assertArrayNear(states.T, b.T)
self.assertArrayNear(states.P, b.P)
self.assertArrayNear(states.X, b.X)
self.assertArrayNear(states.foo, b.foo)
self.assertArrayNear(states.bar, b.bar)
self.assertEqual(b.meta['spam'], 'eggs')
self.assertEqual(b.meta['hello'], 'world')
self.check_arrays(states, b)

@pytest.mark.skipif("native" not in ct.hdf_support(),
reason="Cantera compiled without HDF support")
Expand All @@ -564,7 +660,7 @@ def run_write_str_column(self, mode):

b = ct.SolutionArray(self.gas, extra={'spam'})
b.restore(outfile, "arr")
self.assertEqual(list(states.spam), list(b.spam))
self.check_arrays(states, b)

@pytest.mark.skipif("native" not in ct.hdf_support(),
reason="Cantera compiled without HDF support")
Expand All @@ -583,7 +679,7 @@ def run_write_multidim_column(self, mode):

b = ct.SolutionArray(self.gas, extra={'spam'})
b.restore(outfile, "arr")
self.assertArrayNear(states.spam, b.spam)
self.check_arrays(states, b)

@pytest.mark.skipif("native" not in ct.hdf_support(),
reason="Cantera compiled without HDF support")
Expand All @@ -604,6 +700,7 @@ def run_write_2d(self, mode):
b.restore(outfile, "arr")
assert b.shape == states.shape


class TestLegacyHDF(utilities.CanteraTest):
# Test SolutionArray legacy HDF file input
#
Expand Down
26 changes: 19 additions & 7 deletions test/python/test_onedim.py
Expand Up @@ -335,7 +335,7 @@ def run_restart(self, mode):
data = self.test_work_path / f"freeflame_restart.{mode}"
data.unlink(missing_ok=True)
if mode == "csv":
self.sim.write_csv(data)
self.sim.save(data, basis="mole")
else:
self.sim.save(data, group)

Expand Down Expand Up @@ -761,7 +761,8 @@ def test_save_restore_remove_species_yaml(self):
k1 = gas1.species_index(species)
self.assertArrayNear(Y1[k1], Y2[k2])

def test_write_csv(self):
@pytest.mark.usefixtures("allow_deprecated")
def test_write_csv_legacy(self):
filename = self.test_work_path / "onedim-write_csv.csv"
# In Python >= 3.8, this can be replaced by the missing_ok argument
if filename.is_file():
Expand All @@ -776,6 +777,19 @@ def test_write_csv(self):
k = self.gas.species_index('H2')
self.assertArrayNear(data.X[:, k], self.sim.X[k, :])

def test_write_csv(self):
filename = self.test_work_path / "onedim-save.csv"
filename.unlink(missing_ok=True)

self.create_sim(2e5, 350, 'H2:1.0, O2:2.0', mech="h2o2.yaml")
self.sim.save(filename, basis="mole")
data = ct.SolutionArray(self.gas)
data.read_csv(filename)
self.assertArrayNear(data.grid, self.sim.grid)
self.assertArrayNear(data.T, self.sim.T)
k = self.gas.species_index('H2')
self.assertArrayNear(data.X[:, k], self.sim.X[k, :])

@pytest.mark.usefixtures("allow_deprecated")
@utilities.unittest.skipIf("h5py" not in ct.hdf_support(), "h5py not installed")
def test_restore_legacy_hdf_h5py(self):
Expand Down Expand Up @@ -1139,7 +1153,7 @@ def test_mixture_averaged_rad(self, saveReference=False):
if filename.is_file():
filename.unlink()

self.sim.write_csv(filename) # check output
self.sim.save(filename, basis="mole") # check output
self.assertTrue(filename.is_file())
csv_data = np.genfromtxt(filename, dtype=float, delimiter=',', names=True)
self.assertIn('radiativeheatloss', csv_data.dtype.names)
Expand Down Expand Up @@ -1245,11 +1259,9 @@ def test_mixture_averaged(self, saveReference=False):
self.assertFalse(bad, bad)

filename = self.test_work_path / "CounterflowPremixedFlame-h2-mix.csv"
# In Python >= 3.8, this can be replaced by the missing_ok argument
if filename.is_file():
filename.unlink()
filename.unlink(missing_ok=True)

sim.write_csv(filename) # check output
sim.save(filename) # check output
self.assertTrue(filename.is_file())
csv_data = np.genfromtxt(filename, dtype=float, delimiter=',', names=True)
self.assertNotIn('qdot', csv_data.dtype.names)
Expand Down

0 comments on commit 9c60e9e

Please sign in to comment.