Skip to content

Commit

Permalink
refactor: move combine_masks to components
Browse files Browse the repository at this point in the history
  • Loading branch information
JorisVincent committed Apr 3, 2023
1 parent 2d2cad2 commit 31dff51
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 18 deletions.
18 changes: 18 additions & 0 deletions stimupy/components/__init__.py
Expand Up @@ -184,6 +184,24 @@ def mask_regions(
}


def combine_masks(*masks):
# Initialize
combined_mask = np.zeros_like(masks[0])
for mask in masks:
# Check that masks have the same shape
if not mask.shape == combined_mask.shape:
raise ValueError("Not all masks have the same shape")

# Check that masks don't overlap
if (combined_mask & mask).any():
raise ValueError("Masks overlap")

# Combine: increase `mask`-idc by adding the current highest idx in combined_mask
combined_mask = np.where(mask, mask + combined_mask.max(), combined_mask)

return combined_mask


def draw_regions(mask, intensities, intensity_background=0.5):
"""Draw regions defined by mask, with given intensities
Expand Down
19 changes: 1 addition & 18 deletions stimupy/utils/utils.py
@@ -1,4 +1,5 @@
import copy

import numpy as np
import scipy.special as sp

Expand Down Expand Up @@ -434,21 +435,3 @@ def strip_dict(
if name in dct.keys():
new_dict[name] = dct[name]
return new_dict


def combine_masks(*masks):
# Initialize
combined_mask = np.zeros_like(masks[0])
for mask in masks:
# Check that masks have the same shape
if not mask.shape == combined_mask.shape:
raise ValueError("Not all masks have the same shape")

# Check that masks don't overlap
if (combined_mask & mask).any():
raise ValueError("Masks overlap")

# Combine: increase `mask`-idc by adding the current highest idx in combined_mask
combined_mask = np.where(mask, mask + combined_mask.max(), combined_mask)

return combined_mask

0 comments on commit 31dff51

Please sign in to comment.