Skip to content

Commit

Permalink
Merge 77edded into 62d479b
Browse files Browse the repository at this point in the history
  • Loading branch information
prisae committed Dec 8, 2020
2 parents 62d479b + 77edded commit 3457299
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 55 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@ recent versions
*latest*
--------

- All data is stored in the ``Survey``, not partly in ``Survey`` and partly
in ``Simulation``.
- Removed ``precision`` from ``skin_depth``, ``wavelength``,
``min_cell_width``; all in ``meshes``. It caused problems for high
frequencies.
- Deprecated ``collect_classes`` in ``io``.
- Expanded the ``what``-parameter in the ``Simulation``-class to include
properties related to the gradient.


*v0.15.2* : Bugfix deploy II
Expand Down
75 changes: 43 additions & 32 deletions emg3d/simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
# License for the specific language governing permissions and limitations under
# the License.

import warnings
import itertools
from copy import deepcopy

Expand Down Expand Up @@ -327,6 +328,12 @@ def to_dict(self, what='computed', copy=False):
out['gridding'] = self.gridding
out['solver_opts'] = self.solver_opts

# Clean unwanted data if plain.
if what == 'plain':
for key in ['synthetic', 'residual', 'weights']:
if key in out['survey']['data'].keys():
del out['survey']['data'][key]

# Put provided grids back on gridding.
if self.gridding == 'input':
out['gridding_opts'] = self._grid_single
Expand Down Expand Up @@ -354,27 +361,20 @@ def to_dict(self, what='computed', copy=False):

out['_input_nCz'] = self._input_nCz

# Get required properties.
store = []

if what == 'all':
store += ['_dict_grid', '_dict_model', '_dict_sfield',
'_dict_hfield']

# Store wanted dicts.
if what in ['computed', 'all']:
store += ['_dict_efield', '_dict_efield_info']
for name in ['_dict_efield', '_dict_efield_info', '_dict_hfield',
'_dict_bfield', '_dict_bfield_info']:
if hasattr(self, name):
out[name] = getattr(self, name)

# store dicts.
for name in store:
out[name] = getattr(self, name)
if what == 'all':
for name in ['_dict_grid', '_dict_model', '_dict_sfield']:
if hasattr(self, name):
out[name] = getattr(self, name)

# store data.
out['data'] = {}
# Store gradient and misfit.
if what in ['computed', 'results', 'all']:
for name in list(self.data.data_vars):
# These two are stored in the Survey instance.
if name not in ['observed', 'std']:
out['data'][name] = self.data.get(name)
out['gradient'] = self._gradient
out['misfit'] = self._misfit

Expand Down Expand Up @@ -434,7 +434,8 @@ def from_dict(cls, inp):

# Add existing derived/computed properties.
data = ['_dict_grid', '_dict_model', '_dict_sfield',
'_dict_hfield', '_dict_efield', '_dict_efield_info']
'_dict_hfield', '_dict_efield', '_dict_efield_info',
'_dict_bfield', '_dict_bfield_info']
for name in data:
if name in inp.keys():
values = inp.get(name)
Expand All @@ -457,9 +458,13 @@ def from_dict(cls, inp):
if name in inp.keys():
setattr(out, '_'+name, inp.get(name))

# Add stored data (synthetic, residual, etc).
for name in inp['data'].keys():
out.data[name] = out.data.observed*0+inp['data'][name]
# For backwards compatibility < v0.16.0; remove eventually.
if 'data' in inp.keys():
warnings.warn("Simulation-dict is outdated; store with new "
"version of `emg3d`.", FutureWarning)
for name in inp['data'].keys():
out.data[name] = out.data.observed*np.nan
out.data[name][...] = inp['data'][name]

return out

Expand Down Expand Up @@ -941,24 +946,30 @@ def clean(self, what='computed'):
if what not in ['computed', 'keepresults', 'all']:
raise TypeError(f"Unrecognized `what`: {what}")

clean = []

# Clean data/model/sfield-dicts.
if what in ['keepresults', 'all']:
clean += ['_dict_grid', '_dict_model', '_dict_sfield']
for name in ['_dict_grid', '_dict_model', '_dict_sfield']:
delattr(self, name)
setattr(self, name, self._dict_initiate)

# Clean field-dicts.
if what in ['computed', 'keepresults', 'all']:
clean += ['_dict_efield', '_dict_efield_info', '_dict_hfield']

# Clean dicts.
for name in clean:
delattr(self, name)
setattr(self, name, self._dict_initiate)
# These exist always and have to be initiated.
for name in ['_dict_efield', '_dict_efield_info', '_dict_hfield']:
delattr(self, name)
setattr(self, name, self._dict_initiate)

# These only exist with gradient; don't initiate them.
for name in ['_dict_bfield', '_dict_bfield_info']:
if hasattr(self, name):
delattr(self, name)

# Clean data.
if what in ['computed', 'all']:
for name in list(self.data.data_vars):
if name not in ['observed', 'std']:
del self.data[name]
for key in ['residual', 'weight']:
if key in self.data.keys():
del self.data[key]
self.data['synthetic'] = self.data.observed*np.nan
for name in ['_gradient', '_misfit']:
delattr(self, name)
Expand Down
34 changes: 16 additions & 18 deletions emg3d/surveys.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# License for the specific language governing permissions and limitations under
# the License.

import warnings
from copy import deepcopy
from dataclasses import dataclass

Expand Down Expand Up @@ -236,10 +237,10 @@ def to_dict(self, copy=False):
# Add frequencies.
out['frequencies'] = self.frequencies

# Add `observed` and `std`, if it exists.
out['data'] = {'observed': self.data.observed.data}
if 'std' in self.data.keys():
out['data']['std'] = self.data['std'].data
# Add data.
out['data'] = {}
for key in self.data.keys():
out['data'][key] = self.data[key].data

# Add `noise_floor` and `relative error`.
out['noise_floor'] = self.data.noise_floor
Expand Down Expand Up @@ -271,25 +272,22 @@ def from_dict(cls, inp):
"""
try:
# Backwards compatibility (emg3d < 0.13); remove eventually.
if 'observed' in inp.keys():
data = inp['observed']
new_format = False
else:
data = None
new_format = True

# Initiate survey.
out = cls(name=inp['name'], sources=inp['sources'],
receivers=inp['receivers'],
frequencies=inp['frequencies'], data=data,
frequencies=inp['frequencies'],
fixed=bool(inp['fixed']))

# Add all data (includes 'std')!
if new_format:
for key, value in inp['data'].items():
out._data[key] = out.data.observed*np.nan
out._data[key][...] = value
# Backwards compatibility (emg3d < 0.13); remove eventually.
if 'observed' in inp.keys():
inp['data'] = {'observed': inp['observed']}
warnings.warn("Survey-dict is outdated; store with new "
"version of `emg3d`.", FutureWarning)

# Add all data.
for key, value in inp['data'].items():
out._data[key] = out.data.observed*np.nan
out._data[key][...] = value

# v0.14.0 onwards.
if 'noise_floor' in inp.keys():
Expand Down
16 changes: 16 additions & 0 deletions tests/test_simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,22 @@ def test_input_gradient(self):
# Ensure the gradient has the shape of the model, not of the input.
assert grad.shape == self.model.shape

sim2 = simulation.to_dict(what='all', copy=True)
sim3 = simulation.to_dict(what='plain', copy=True)
assert 'residual' in sim2['survey']['data'].keys()
assert 'residual' not in sim3['survey']['data'].keys()

# Backwards compatibility
with pytest.warns(FutureWarning):
sim4 = simulation.to_dict()
sim4['data'] = {'synthetic': sim4['survey']['data']['observed']}
simulations.Simulation.from_dict(sim4)

simulation.clean('all') # Should remove 'residual', 'bfield-dicts'
sim5 = simulation.to_dict('all')
assert 'residual' not in sim5['survey']['data'].keys()
assert '_dict_bfield' not in sim5.keys()


@pytest.mark.skipif(xarray is None, reason="xarray not installed.")
def test_simulation_automatic():
Expand Down
11 changes: 6 additions & 5 deletions tests/test_surveys.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,11 +223,12 @@ def test_copy(self, tmpdir):
assert_allclose(srvy4.standard_deviation, srvy5.standard_deviation)

# Test backwards compatibility.
srvy5 = srvy5.to_dict()
srvy5['observed'] = srvy5['data']['observed']
del srvy5['data']
srvy6 = surveys.Survey.from_dict(srvy5)
assert_allclose(srvy5['observed'], srvy6.observed)
with pytest.warns(FutureWarning):
srvy5 = srvy5.to_dict()
srvy5['observed'] = srvy5['data']['observed']
del srvy5['data']
srvy6 = surveys.Survey.from_dict(srvy5)
assert_allclose(srvy5['observed'], srvy6.observed)


def test_PointDipole():
Expand Down

0 comments on commit 3457299

Please sign in to comment.