diff --git a/examples/plot_minimal_example.py b/examples/plot_minimal_example.py index 04ca4af..281cb21 100644 --- a/examples/plot_minimal_example.py +++ b/examples/plot_minimal_example.py @@ -29,11 +29,7 @@ my_settings.particles = 5000 my_settings.run_mode = "fixed source" # Create a DT point source -try: - source = openmc.IndependentSource() -except: - # work with older versions of openmc - source = openmc.Source() +source = openmc.IndependentSource() source.space = openmc.stats.Point((100, 0, 0)) source.angle = openmc.stats.Isotropic() source.energy = openmc.stats.Discrete([14e6], [1]) @@ -46,8 +42,9 @@ dimension=[40, 40, 40], ) mesh_filter = openmc.MeshFilter(mesh) +energy_filter = openmc.EnergyFilter([0, 1e6, 2e6]) mesh_tally_1 = openmc.Tally(name="mesh_tally") -mesh_tally_1.filters = [mesh_filter] +mesh_tally_1.filters = [mesh_filter, energy_filter] mesh_tally_1.scores = ["heating"] my_tallies.append(mesh_tally_1) diff --git a/src/openmc_regular_mesh_plotter/core.py b/src/openmc_regular_mesh_plotter/core.py index 6f219f8..b42b905 100644 --- a/src/openmc_regular_mesh_plotter/core.py +++ b/src/openmc_regular_mesh_plotter/core.py @@ -104,6 +104,7 @@ def plot_mesh_tally( if isinstance(tally, typing.Sequence): mesh_ids = [] for one_tally in tally: + _check_tally_for_energy_filters_with_multiple_bins(one_tally) mesh = one_tally.find_filter(filter_type=openmc.MeshFilter).mesh # TODO check the tallies use the same mesh mesh_ids.append(mesh.id) @@ -113,6 +114,7 @@ def plot_mesh_tally( ) else: mesh = tally.find_filter(filter_type=openmc.MeshFilter).mesh + _check_tally_for_energy_filters_with_multiple_bins(tally) if isinstance(mesh, openmc.CylindricalMesh): raise NotImplemented( @@ -299,9 +301,26 @@ def get_index_where(self, value: float, basis: str = "xy"): return slice_index +def _check_tally_for_energy_filters_with_multiple_bins(tally): + for current_filter in tally.filters: + if isinstance(current_filter, openmc.EnergyFilter): + if isinstance(current_filter, openmc.EnergyFilter): + if current_filter.num_bins > 1: + msg = ( + "An EnergyFilter was found on the tally with more " + "than a single bin. EnergyFilter.num_bins=" + f"{current_filter.num_bins}. Either reduce the number " + "of energy bins to 1 or remove the EnergyFilter to " + "plot this tally. EnergyFilter with more than 1 energy " + "bin are unsupported" + ) + raise ValueError(msg) + + def _get_tally_data( scaling_factor, mesh, basis, tally, value, volume_normalization, score, slice_index ): + # if score is not specified and tally has a single score then we know which score to use if score is None: if len(tally.scores) == 1: diff --git a/tests/test_units.py b/tests/test_units.py index b441cb8..df33329 100644 --- a/tests/test_units.py +++ b/tests/test_units.py @@ -197,4 +197,48 @@ def test_plot_two_mesh_tallies(model): assert plot.get_ylim() == (-3.0, 3.5) +def test_plot_with_energy_filters(model): + geometry = model.geometry + + mesh = openmc.RegularMesh().from_domain(geometry, dimension=[2, 20, 30]) + mesh_filter = openmc.MeshFilter(mesh) + energy_filter = openmc.EnergyFilter([0, 1e6, 2e6]) + mesh_tally_1 = openmc.Tally(name="mesh_tally") + mesh_tally_1.filters = [mesh_filter, energy_filter] + mesh_tally_1.scores = ["heating"] + + tallies = openmc.Tallies([mesh_tally_1]) + + model.tallies = tallies + + sp_filename = model.run() + with openmc.StatePoint(sp_filename) as statepoint: + tally_result_1 = statepoint.get_tally(name="mesh_tally") + + with pytest.raises(ValueError) as excinfo: + plot_mesh_tally(tally=tally_result_1) + msg = ( + "An EnergyFilter was found on the tally with more " + "than a single bin. EnergyFilter.num_bins=" + "2. Either reduce the number " + "of energy bins to 1 or remove the EnergyFilter to " + "plot this tally. EnergyFilter with more than 1 energy " + "bin are unsupported" + ) + assert str(excinfo.value) == msg + + energy_filter = openmc.EnergyFilter([0, 2e6]) + mesh_tally_1.filters = [mesh_filter, energy_filter] + + tallies = openmc.Tallies([mesh_tally_1]) + + model.tallies = tallies + + sp_filename = model.run() + with openmc.StatePoint(sp_filename) as statepoint: + tally_result_1 = statepoint.get_tally(name="mesh_tally") + + plot_mesh_tally(tally=tally_result_1) + + # todo catch errors when 2d mesh used and 1d axis selected for plotting'