Skip to content

Commit

Permalink
Support UnstructuredMesh for IndependentSource (#2949)
Browse files Browse the repository at this point in the history
Co-authored-by: Jonathan Shimwell <drshimwell@gmail.com>
Co-authored-by: Paul Romano <paul.k.romano@gmail.com>
  • Loading branch information
3 people committed Apr 12, 2024
1 parent 4ba053c commit e77a524
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 19 deletions.
8 changes: 4 additions & 4 deletions openmc/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,10 +864,10 @@ def mesh(self, mesh):
cv.check_type('filter mesh', mesh, openmc.MeshBase)
self._mesh = mesh
if isinstance(mesh, openmc.UnstructuredMesh):
if mesh.volumes is None:
self.bins = []
else:
if mesh.has_statepoint_data:
self.bins = list(range(len(mesh.volumes)))
else:
self.bins = []
else:
self.bins = list(mesh.indices)

Expand Down Expand Up @@ -982,7 +982,7 @@ def from_xml_element(cls, elem, **kwargs):
if translation:
out.translation = [float(x) for x in translation.split()]
return out


class MeshBornFilter(MeshFilter):
"""Filter events by the mesh cell a particle originated from.
Expand Down
42 changes: 41 additions & 1 deletion openmc/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import warnings
from abc import ABC, abstractmethod
from collections.abc import Iterable
from functools import wraps
from math import pi, sqrt, atan2
from numbers import Integral, Real
from pathlib import Path
Expand Down Expand Up @@ -39,7 +40,8 @@ class MeshBase(IDManagerMixin, ABC):
bounding_box : openmc.BoundingBox
Axis-aligned bounding box of the mesh as defined by the upper-right and
lower-left coordinates.
indices : Iterable of tuple
An iterable of mesh indices for each mesh element, e.g. [(1, 1, 1), (2, 1, 1), ...]
"""

next_id = 1
Expand All @@ -66,6 +68,11 @@ def name(self, name: str):
def bounding_box(self) -> openmc.BoundingBox:
return openmc.BoundingBox(self.lower_left, self.upper_right)

@property
@abstractmethod
def indices(self):
pass

def __repr__(self):
string = type(self).__name__ + '\n'
string += '{0: <16}{1}{2}\n'.format('\tID', '=\t', self._id)
Expand Down Expand Up @@ -1914,6 +1921,17 @@ def _convert_to_cartesian(arr, origin: Sequence[float]):
return arr


def require_statepoint_data(func):
@wraps(func)
def wrapper(self: UnstructuredMesh, *args, **kwargs):
if not self._has_statepoint_data:
raise AttributeError(f'The "{func.__name__}" property requires '
'information about this mesh to be loaded '
'from a statepoint file.')
return func(self, *args, **kwargs)
return wrapper


class UnstructuredMesh(MeshBase):
"""A 3D unstructured mesh
Expand Down Expand Up @@ -1990,6 +2008,7 @@ def __init__(self, filename: PathLike, library: str, mesh_id: Optional[int] = No
self.library = library
self._output = False
self.length_multiplier = length_multiplier
self._has_statepoint_data = False

@property
def filename(self):
Expand All @@ -2010,6 +2029,7 @@ def library(self, lib: str):
self._library = lib

@property
@require_statepoint_data
def size(self):
return self._size

Expand All @@ -2028,6 +2048,7 @@ def output(self, val: bool):
self._output = val

@property
@require_statepoint_data
def volumes(self):
"""Return Volumes for every mesh cell if
populated by a StatePoint file
Expand All @@ -2046,26 +2067,32 @@ def volumes(self, volumes: typing.Iterable[Real]):
self._volumes = volumes

@property
@require_statepoint_data
def total_volume(self):
return np.sum(self.volumes)

@property
@require_statepoint_data
def vertices(self):
return self._vertices

@property
@require_statepoint_data
def connectivity(self):
return self._connectivity

@property
@require_statepoint_data
def element_types(self):
return self._element_types

@property
@require_statepoint_data
def centroids(self):
return np.array([self.centroid(i) for i in range(self.n_elements)])

@property
@require_statepoint_data
def n_elements(self):
if self._n_elements is None:
raise RuntimeError("No information about this mesh has "
Expand Down Expand Up @@ -2096,6 +2123,15 @@ def dimension(self):
def n_dimension(self):
return 3

@property
@require_statepoint_data
def indices(self):
return [(i,) for i in range(self.n_elements)]

@property
def has_statepoint_data(self) -> bool:
return self._has_statepoint_data

def __repr__(self):
string = super().__repr__()
string += '{: <16}=\t{}\n'.format('\tFilename', self.filename)
Expand All @@ -2106,13 +2142,16 @@ def __repr__(self):
return string

@property
@require_statepoint_data
def lower_left(self):
return self.vertices.min(axis=0)

@property
@require_statepoint_data
def upper_right(self):
return self.vertices.max(axis=0)

@require_statepoint_data
def centroid(self, bin: int):
"""Return the vertex averaged centroid of an element
Expand Down Expand Up @@ -2257,6 +2296,7 @@ def from_hdf5(cls, group: h5py.Group):
library = group['library'][()].decode()

mesh = cls(filename=filename, library=library, mesh_id=mesh_id)
mesh._has_statepoint_data = True
vol_data = group['volumes'][()]
mesh.volumes = np.reshape(vol_data, (vol_data.shape[0],))
mesh.n_elements = mesh.volumes.size
Expand Down
10 changes: 7 additions & 3 deletions openmc/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,9 +500,13 @@ def populate_xml_element(self, elem: ET.Element):
elem.set("mesh", str(self.mesh.id))

# write in the order of mesh indices
for idx in self.mesh.indices:
idx = tuple(i - 1 for i in idx)
elem.append(self.sources[idx].to_xml_element())
if isinstance(self.mesh, openmc.UnstructuredMesh):
for s in self.sources:
elem.append(s.to_xml_element())
else:
for idx in self.mesh.indices:
idx = tuple(i - 1 for i in idx)
elem.append(self.sources[idx].to_xml_element())

@classmethod
def from_xml_element(cls, elem: ET.Element, meshes) -> openmc.MeshSource:
Expand Down
19 changes: 19 additions & 0 deletions tests/unit_tests/test_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,3 +298,22 @@ def test_CylindricalMesh_get_indices_at_coords():
assert mesh.get_indices_at_coords([98, 200.1, 299]) == (0, 1, 0) # second angle quadrant
assert mesh.get_indices_at_coords([98, 199.9, 299]) == (0, 2, 0) # third angle quadrant
assert mesh.get_indices_at_coords([102, 199.1, 299]) == (0, 3, 0) # forth angle quadrant

def test_umesh_roundtrip(run_in_tmpdir, request):
umesh = openmc.UnstructuredMesh(request.path.parent / 'test_mesh_tets.e', 'moab')
umesh.output = True

# create a tally using this mesh
mf = openmc.MeshFilter(umesh)
tally = openmc.Tally()
tally.filters = [mf]
tally.scores = ['flux']

tallies = openmc.Tallies([tally])
tallies.export_to_xml()

xml_tallies = openmc.Tallies.from_xml()
xml_tally = xml_tallies[0]
xml_mesh = xml_tally.filters[0].mesh

assert umesh.id == xml_mesh.id
63 changes: 52 additions & 11 deletions tests/unit_tests/test_source_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,23 +208,29 @@ def test_roundtrip(run_in_tmpdir, model, request):
###################
# MeshSource tests
###################
@pytest.mark.parametrize('mesh_type', ('rectangular', 'cylindrical'))
def test_mesh_source_independent(run_in_tmpdir, mesh_type):
@pytest.fixture
def void_model():
"""
A void model containing a single box
"""
min, max = -10, 10
box = openmc.model.RectangularParallelepiped(
min, max, min, max, min, max, boundary_type='vacuum')
model = openmc.Model()

geometry = openmc.Geometry([openmc.Cell(region=-box)])
box = openmc.model.RectangularParallelepiped(*[-10, 10]*3, boundary_type='vacuum')
model.geometry = openmc.Geometry([openmc.Cell(region=-box)])

settings = openmc.Settings()
settings.particles = 100
settings.batches = 10
settings.run_mode = 'fixed source'
model.settings.particles = 100
model.settings.batches = 10
model.settings.run_mode = 'fixed source'

return model

model = openmc.Model(geometry=geometry, settings=settings)

@pytest.mark.parametrize('mesh_type', ('rectangular', 'cylindrical'))
def test_mesh_source_independent(run_in_tmpdir, void_model, mesh_type):
"""
A void model containing a single box
"""
model = void_model

# define a 2 x 2 x 2 mesh
if mesh_type == 'rectangular':
Expand Down Expand Up @@ -310,6 +316,41 @@ def test_mesh_source_independent(run_in_tmpdir, mesh_type):
assert mesh_source.strength == 1.0


@pytest.mark.parametrize("library", ('moab', 'libmesh'))
def test_umesh_source_independent(run_in_tmpdir, request, void_model, library):
import openmc.lib
# skip the test if the library is not enabled
if library == 'moab' and not openmc.lib._dagmc_enabled():
pytest.skip("DAGMC (and MOAB) mesh not enabled in this build.")

if library == 'libmesh' and not openmc.lib._libmesh_enabled():
pytest.skip("LibMesh is not enabled in this build.")

model = void_model

mesh_filename = Path(request.fspath).parent / "test_mesh_tets.e"
uscd_mesh = openmc.UnstructuredMesh(mesh_filename, library)
ind_source = openmc.IndependentSource()
n_elements = 12_000
model.settings.source = openmc.MeshSource(uscd_mesh, n_elements*[ind_source])
model.export_to_model_xml()
try:
openmc.lib.init()
openmc.lib.simulation_init()
sites = openmc.lib.sample_external_source(10)
openmc.lib.statepoint_write('statepoint.h5')
finally:
openmc.lib.finalize()

with openmc.StatePoint('statepoint.h5') as sp:
uscd_mesh = sp.meshes[uscd_mesh.id]

# ensure at least that all sites are inside the mesh
bounding_box = uscd_mesh.bounding_box
for site in sites:
assert site.r in bounding_box


def test_mesh_source_file(run_in_tmpdir):
# Creating a source file with a single particle
source_particle = openmc.SourceParticle(time=10.0)
Expand Down

0 comments on commit e77a524

Please sign in to comment.