Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

specify n_files for split in multisession_registration #633

Merged
merged 2 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 9 additions & 15 deletions studio/app/common/dataclass/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from studio.app.common.dataclass.base import BaseData
from studio.app.common.dataclass.utils import create_images_list
from studio.app.common.schemas.outputs import PlotMetaData
from studio.app.const import MAX_IMAGE_DATA_PART_SIZE
from studio.app.dir_path import DIRPATH


Expand Down Expand Up @@ -50,17 +49,12 @@ def __init__(
del data
gc.collect()

def split_image(self, output_dir: str):
def split_image(self, output_dir: str, n_files: int = 2):
assert n_files > 1, "n_files should be greater than 1"

image = self.data
size = image.nbytes
frames = image.shape[0]

if size > MAX_IMAGE_DATA_PART_SIZE:
frames_per_part = math.ceil(
frames / math.ceil(size / MAX_IMAGE_DATA_PART_SIZE)
)
else:
frames_per_part = frames // 2
frames_per_part = math.ceil(frames // n_files)

file_name = self.path[0] if isinstance(self.path, list) else self.path
name, ext = os.path.splitext(os.path.basename(file_name))
Expand All @@ -69,13 +63,13 @@ def split_image(self, output_dir: str):
_dir = join_filepath([output_dir, "image_split", name])
create_directory(_dir)

for t in np.arange(0, frames, frames_per_part):
_path = join_filepath([_dir, f"{name}_{t//frames_per_part}{ext}"])
for n in range(n_files):
_path = join_filepath([_dir, f"{name}_{n}{ext}"])
with tifffile.TiffWriter(_path, bigtiff=True) as tif:
if t == frames - 1:
tif.write(image[t:])
if n == n_files - 1:
tif.write(image[n * frames_per_part :])
else:
tif.write(image[t : t + frames_per_part])
tif.write(image[n * frames_per_part : (n + 1) * frames_per_part])
save_paths.append(_path)

return save_paths
Expand Down
2 changes: 0 additions & 2 deletions studio/app/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,3 @@ class FILETYPE:
NOT_DISPLAY_ARGS_LIST = ["params", "output_dir", "nwbfile", "kwargs"]

DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

MAX_IMAGE_DATA_PART_SIZE = 1_000_000_000 # 1GB
5 changes: 4 additions & 1 deletion studio/app/optinist/wrappers/caiman/cnmf_multisession.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,17 @@ def caiman_cnmf_multisession(

Ain = reshaped_params.pop("Ain", None)
roi_thr = reshaped_params.pop("roi_thr", None)
n_reg_files = reshaped_params.pop("n_reg_files", 2)
if n_reg_files < 2:
raise Exception(f"Set n_reg_files to a integer value gte 2. Now {n_reg_files}.")
reg_file_rate = reshaped_params.pop("reg_file_rate", 1.0)
if reg_file_rate > 1.0:
logger.warn(
f"reg_file_rate {reg_file_rate}, should be lte 1. Using 1.0 instead."
)
reg_file_rate = 1.0

split_image_paths = images.split_image(output_dir)
split_image_paths = images.split_image(output_dir, n_files=n_reg_files)
n_split_images = len(split_image_paths)

logger.info(f"image was split into {n_split_images} parts.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ data_params:

init_params:
# Ain: null # TBD: need to support 2D array type.
n_reg_files: 2 # number of files to be split.
reg_file_rate: 1.0 # threshold for the cell is detected in how many files
K: 4
gSig: [4, 4]
Expand Down
Loading