Skip to content

Commit

Permalink
fixed bug and added pad_masks when padding dicts
Browse files Browse the repository at this point in the history
  • Loading branch information
LynnSchmittwilken committed Feb 22, 2023
1 parent 1f217f3 commit c0e8cec
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions stimupy/utils/pad.py
Expand Up @@ -207,6 +207,10 @@ def pad_dict_by_visual_size(dct, padding, ppd, pad_value=0.0, keys=("img", "*mas
if key in keys:
img = dct[key]
if isinstance(img, np.ndarray):
# Add mask which indicates padded region
new_dict["pad_mask"] = pad_by_visual_size(np.zeros(img.shape), padding,
ppd, 1).astype(int)

if key.endswith("mask"):
img = pad_by_visual_size(img, padding, ppd, 0)
img = img.astype(int)
Expand All @@ -215,8 +219,8 @@ def pad_dict_by_visual_size(dct, padding, ppd, pad_value=0.0, keys=("img", "*mas
new_dict[key] = img

# Update visual_size and shape-keys
dct["visual_size"] = resolution.visual_size_from_shape_ppd(img.shape, ppd)
dct["shape"] = resolution.validate_shape(img.shape)
new_dict["visual_size"] = resolution.visual_size_from_shape_ppd(img.shape, ppd)
new_dict["shape"] = resolution.validate_shape(img.shape)
return new_dict


Expand Down Expand Up @@ -294,6 +298,10 @@ def pad_dict_by_shape(dct, padding, pad_value=0, keys=("img", "*mask")):
if key in keys:
img = dct[key]
if isinstance(img, np.ndarray):
# Add mask which indicates padded region
new_dict["pad_mask"] = np.pad(np.zeros(img.shape), padding, mode="constant",
constant_values=1).astype(int)

if key.endswith("mask"):
img = np.pad(img, padding, mode="constant", constant_values=0)
img = img.astype(int)
Expand All @@ -302,9 +310,9 @@ def pad_dict_by_shape(dct, padding, pad_value=0, keys=("img", "*mask")):
new_dict[key] = img

# Update visual_size and shape-keys
dct["shape"] = resolution.validate_shape(img.shape)
new_dict["shape"] = resolution.validate_shape(img.shape)
if "ppd" in dct.keys():
dct["visual_size"] = resolution.visual_size_from_shape_ppd(img.shape, dct["ppd"])
new_dict["visual_size"] = resolution.visual_size_from_shape_ppd(img.shape, dct["ppd"])
return new_dict


Expand Down Expand Up @@ -358,14 +366,18 @@ def pad_dict_to_shape(dct, shape, pad_value=0, keys=("img", "*mask")):
padding_before = padding_per_axis // 2
padding_after = padding_per_axis - padding_before
padding = np.stack([padding_before, padding_after]).T

# Add mask which indicates padded region
new_dict["pad_mask"] = pad_by_shape(np.zeros(img.shape), padding=padding,
pad_value=1).astype(int)

if key.endswith("mask"):
img = pad_by_shape(img, padding=padding, pad_value=0)
img = img.astype(int)
else:
img = pad_by_shape(img, padding=padding, pad_value=pad_value)
new_dict[key] = img

# Update visual_size and shape-keys
new_dict["shape"] = resolution.validate_shape(shape)
if "ppd" in dct.keys():
Expand Down

0 comments on commit c0e8cec

Please sign in to comment.