Skip to content

Commit

Permalink
Add C API function for getting mesh bins for rasterized plot (#2854)
Browse files Browse the repository at this point in the history
  • Loading branch information
paulromano committed Jan 23, 2024
1 parent fca4da4 commit e6a36ff
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 2 deletions.
48 changes: 47 additions & 1 deletion openmc/lib/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from ctypes import (c_int, c_int32, c_char_p, c_double, POINTER, Structure,
create_string_buffer, c_uint64, c_size_t)
from random import getrandbits
from typing import Optional, List, Tuple
from typing import Optional, List, Tuple, Sequence
from weakref import WeakValueDictionary

import numpy as np
Expand All @@ -13,6 +13,7 @@
from .core import _FortranObjectWithID
from .error import _error_handler
from .material import Material
from .plot import _Position

__all__ = ['RegularMesh', 'RectilinearMesh', 'CylindricalMesh', 'SphericalMesh', 'UnstructuredMesh', 'meshes']

Expand Down Expand Up @@ -43,6 +44,11 @@ class _MaterialVolume(Structure):
POINTER(c_int), POINTER(c_uint64)]
_dll.openmc_mesh_material_volumes.restype = c_int
_dll.openmc_mesh_material_volumes.errcheck = _error_handler
_dll.openmc_mesh_get_plot_bins.argtypes = [
c_int32, _Position, _Position, c_int, POINTER(c_int), POINTER(c_int32)
]
_dll.openmc_mesh_get_plot_bins.restype = c_int
_dll.openmc_mesh_get_plot_bins.errcheck = _error_handler
_dll.openmc_get_mesh_index.argtypes = [c_int32, POINTER(c_int32)]
_dll.openmc_get_mesh_index.restype = c_int
_dll.openmc_get_mesh_index.errcheck = _error_handler
Expand Down Expand Up @@ -203,6 +209,46 @@ def material_volumes(
])
return volumes

def get_plot_bins(
self,
origin: Sequence[float],
width: Sequence[float],
basis: str,
pixels: Sequence[int]
) -> np.ndarray:
"""Get mesh bin indices for a rasterized plot.
.. versionadded:: 0.14.1
Parameters
----------
origin : iterable of float
Origin of the plotting view. Should have length 3.
width : iterable of float
Width of the plotting view. Should have length 2.
basis : {'xy', 'xz', 'yz'}
Plotting basis.
pixels : iterable of int
Number of pixels in each direction. Should have length 2.
Returns
-------
2D numpy array with mesh bin indices corresponding to each pixel within
the plotting view.
"""
origin = _Position(*origin)
width = _Position(*width)
basis = {'xy': 1, 'xz': 2, 'yz': 3}[basis]
pixel_array = (c_int*2)(*pixels)
img_data = np.zeros((pixels[1], pixels[0]), dtype=np.dtype('int32'))

_dll.openmc_mesh_get_plot_bins(
self._index, origin, width, basis, pixel_array,
img_data.ctypes.data_as(POINTER(c_int32))
)
return img_data


class RegularMesh(Mesh):
"""RegularMesh stored internally.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
# Dependencies
'python_requires': '>=3.7',
'install_requires': [
'numpy>=1.9', 'h5py', 'scipy', 'ipython', 'matplotlib',
'numpy>=1.9', 'h5py', 'scipy<1.12', 'ipython', 'matplotlib',
'pandas', 'lxml', 'uncertainties'
],
'extras_require': {
Expand Down
58 changes: 58 additions & 0 deletions src/mesh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "openmc/message_passing.h"
#include "openmc/openmp_interface.h"
#include "openmc/particle_data.h"
#include "openmc/plot.h"
#include "openmc/random_dist.h"
#include "openmc/search.h"
#include "openmc/settings.h"
Expand Down Expand Up @@ -1898,6 +1899,63 @@ extern "C" int openmc_mesh_material_volumes(int32_t index, int n_sample,
return (n == -1) ? OPENMC_E_ALLOCATE : 0;
}

extern "C" int openmc_mesh_get_plot_bins(int32_t index, Position origin,
Position width, int basis, int* pixels, int32_t* data)
{
if (int err = check_mesh(index))
return err;
const auto& mesh = model::meshes[index].get();

int pixel_width = pixels[0];
int pixel_height = pixels[1];

// get pixel size
double in_pixel = (width[0]) / static_cast<double>(pixel_width);
double out_pixel = (width[1]) / static_cast<double>(pixel_height);

// setup basis indices and initial position centered on pixel
int in_i, out_i;
Position xyz = origin;
enum class PlotBasis { xy = 1, xz = 2, yz = 3 };
PlotBasis basis_enum = static_cast<PlotBasis>(basis);
switch (basis_enum) {
case PlotBasis::xy:
in_i = 0;
out_i = 1;
break;
case PlotBasis::xz:
in_i = 0;
out_i = 2;
break;
case PlotBasis::yz:
in_i = 1;
out_i = 2;
break;
default:
UNREACHABLE();
}

// set initial position
xyz[in_i] = origin[in_i] - width[0] / 2. + in_pixel / 2.;
xyz[out_i] = origin[out_i] + width[1] / 2. - out_pixel / 2.;

#pragma omp parallel
{
Position r = xyz;

#pragma omp for
for (int y = 0; y < pixel_height; y++) {
r[out_i] = xyz[out_i] - out_pixel * y;
for (int x = 0; x < pixel_width; x++) {
r[in_i] = xyz[in_i] + in_pixel * x;
data[pixel_width * y + x] = mesh->get_bin(r);
}
}
}

return 0;
}

//! Get the dimension of a regular mesh
extern "C" int openmc_regular_mesh_get_dimension(
int32_t index, int** dims, int* n)
Expand Down
25 changes: 25 additions & 0 deletions tests/unit_tests/test_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,31 @@ def test_regular_mesh(lib_init):
assert sum(f[1] for f in elem_vols) == pytest.approx(1.26 * 1.26, 1e-2)


def test_regular_mesh_get_plot_bins(lib_init):
mesh: openmc.lib.RegularMesh = openmc.lib.meshes[2]
mesh.dimension = (2, 2, 1)
mesh.set_parameters(lower_left=(-1.0, -1.0, -0.5),
upper_right=(1.0, 1.0, 0.5))

# Get bins for a plot view covering only a single mesh bin
mesh_bins = mesh.get_plot_bins((-0.5, -0.5, 0.), (0.1, 0.1), 'xy', (20, 20))
assert (mesh_bins == 0).all()
mesh_bins = mesh.get_plot_bins((0.5, 0.5, 0.), (0.1, 0.1), 'xy', (20, 20))
assert (mesh_bins == 3).all()

# Get bins for a plot view covering all mesh bins. Note that the y direction
# (first dimension) is flipped for plotting purposes
mesh_bins = mesh.get_plot_bins((0., 0., 0.), (2., 2.), 'xy', (20, 20))
assert (mesh_bins[:10, :10] == 2).all()
assert (mesh_bins[:10, 10:] == 3).all()
assert (mesh_bins[10:, :10] == 0).all()
assert (mesh_bins[10:, 10:] == 1).all()

# Get bins for a plot view outside of the mesh
mesh_bins = mesh.get_plot_bins((100., 100., 0.), (2., 2.), 'xy', (20, 20))
assert (mesh_bins == -1).all()


def test_rectilinear_mesh(lib_init):
mesh = openmc.lib.RectilinearMesh()
x_grid = [-10., 0., 10.]
Expand Down

0 comments on commit e6a36ff

Please sign in to comment.