Skip to content

Commit

Permalink
Fix several issues in rd_region.py
Browse files Browse the repository at this point in the history
  • Loading branch information
eivindjahren committed Apr 22, 2024
1 parent 6acf43f commit df0f185
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 5 deletions.
10 changes: 5 additions & 5 deletions python/resdata/grid/rd_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,13 +288,13 @@ class ResdataRegion(BaseCClass):
_set_name = ResdataPrototype("void rd_region_set_name( rd_region , char*)")
_get_name = ResdataPrototype("char* rd_region_get_name( rd_region )")
_contains_ijk = ResdataPrototype(
"void rd_region_contains_ijk( rd_region , int , int , int)"
"bool rd_region_contains_ijk( rd_region , int , int , int)"
)
_contains_global = ResdataPrototype(
"void rd_region_contains_global( rd_region, int )"
"bool rd_region_contains_global( rd_region, int )"
)
_contains_active = ResdataPrototype(
"void rd_region_contains_active( rd_region , int )"
"bool rd_region_contains_active( rd_region , int )"
)
_equal = ResdataPrototype("bool rd_region_equal( rd_region , rd_region )")
_select_true = ResdataPrototype("void rd_region_select_true( rd_region , rd_kw)")
Expand Down Expand Up @@ -1054,7 +1054,7 @@ def scale_kw(self, rd_kw, scale, force_active=False):
def imul_kw(self, target_kw, other, force_active=False):
if isinstance(other, ResdataKW):
if target_kw.assert_binary(other):
self._imul_kw(target_kw, other)
self._imul_kw(target_kw, other, force_active)
else:
raise TypeError("Type mismatch")
else:
Expand All @@ -1063,7 +1063,7 @@ def imul_kw(self, target_kw, other, force_active=False):
def idiv_kw(self, target_kw, other, force_active=False):
if isinstance(other, ResdataKW):
if target_kw.assert_binary(other):
self._idiv_kw(target_kw, other)
self._idiv_kw(target_kw, other, force_active)
else:
raise TypeError("Type mismatch")
else:
Expand Down
73 changes: 73 additions & 0 deletions python/tests/rd_tests/test_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from resdata.grid import Grid, ResdataRegion
from resdata.util.util import IntVector
from tests import ResdataTest
from resdata.grid.faults import Layer


class RegionTest(ResdataTest):
Expand Down Expand Up @@ -417,3 +418,75 @@ def test_select_false(empty_region, all_true_kw):
false_region = empty_region.copy()
false_region.select_false(all_true_kw)
assert empty_region == false_region

def test_select_layer(empty_region, full_region):
layer = Layer(10,10)
empty_region.select_from_layer(layer, 0, 0)
assert empty_region == full_region

def test_iadd_kw_empty(empty_region, poro):
poro_copy = poro.copy()
poro.add(poro, mask=empty_region)
assert poro == poro_copy

def test_iadd_kw_full(full_region, poro):
poro_copy = poro.copy()
poro.add(poro, mask=full_region)
assert poro == poro_copy * 2

def test_iadd_kw_full(full_region, poro):
poro_copy = poro.copy()
poro.add(2.0, mask=full_region)
assert poro == poro_copy + 2

def test_isub_kw_full(full_region, poro):
poro_copy = poro.copy()
poro.sub(2.0, mask=full_region)
assert poro == poro_copy - 2

def test_imul_kw_full(full_region, poro):
poro_copy = poro.copy()
poro += 1.0
poro.mul(2.0, mask=full_region)
assert poro == (poro_copy + 1.0) * 2.0


def test_idiv_kw_full_scalar(full_region, poro):
poro_copy = poro.copy()
poro += 1.0
poro.div(2.0, mask=full_region)
assert poro == (poro_copy + 1.0) * 0.5

def test_idiv_kw_full(full_region, poro):
poro += 1.0
poro.div(poro, mask=full_region)
assert list(poro) == [1.0]*len(poro)

def test_mul_kw_full(full_region, poro):
poro += 1.0
poro.mul(poro, mask=full_region)
assert list(poro) == [4.0]*len(poro)

def test_copy_kw(full_region, empty_region, poro, grid):
poro_copy = grid.create_kw(
np.zeros((grid.nx, grid.ny, grid.nz), dtype=np.float32), "PORO", True
)
full_region.copy_kw(poro_copy, poro)
assert poro_copy == poro

def test_get_active_list(full_region, active_region):
assert full_region.get_active_list() == active_region.get_active_list()

def test_contains_ijk(full_region):
assert full_region.contains_ijk(0,0,0)

def test_contains_global(full_region):
assert full_region.contains_global(0)

def test_contains_global(full_region):
assert full_region.contains_active(0)

def test_get_set_name(full_region):
full_region.set_name("full")
assert full_region.get_name() == full_region.name
assert full_region.get_name() == "full"

0 comments on commit df0f185

Please sign in to comment.