Zoom-crop with the help of an example
jingxuanlim committed Apr 25, 2021
1 parent 9e507e0 commit 3c007df
Showing 1 changed file with 98 additions and 31 deletions.
129 changes: 98 additions & 31 deletions CircuitSeeker/mocorr.py
@@ -212,6 +212,7 @@ def applyTransform(
sitk.sitkNearestNeighbor, 0.0, moving_im.GetPixelID()
)
else:
aligned_to_fixed = False
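## no fixed reference was given, so the moving image is resampled onto its own grid (it serves as its own reference image)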
transformed_im = sitk.Resample(moving_im, moving_im, transform,
sitk.sitkNearestNeighbor, 0.0, moving_im.GetPixelID()
)
@@ -231,6 +232,8 @@ def applyTransform(
# take the middle elements from zoomed
plane_inds = planes_with_some_data(zoomed, 0.4)
resampled_slice = slice(plane_inds[0],plane_inds[0]+moving_nplanes)
print(zoomed.shape)
print(resampled_slice)

zoomed = zoomed[resampled_slice]
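## (illustration) planes_with_some_data is assumed here to return the indices of
## z-planes whose fraction of nonzero voxels exceeds the threshold (0.4 above);
## a minimal sketch of that assumed behaviour, not the actual CircuitSeeker helper:
def _planes_with_some_data_sketch(volume, threshold=0.4):
    nonzero_fraction = np.count_nonzero(volume, axis=(1, 2)) / np.prod(volume.shape[1:])
    return np.flatnonzero(nonzero_fraction > threshold)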

@@ -303,10 +306,12 @@ def motionCorrect(
fixed, fixed_vox, moving_vox,
write_path=None, dataset_path=None,
distributed_state=None, sigma=7,
transforms_dir=None, folder_slicer=None, pad_fixed=False,
transforms_dir=None, time_stride=1,
folder_slicer=None, pad_fixed=False,
params=None, t_chunksize=False, force_not_chunk=False,
correct_another=None,
t_indices=None,
resample_with_fixed=False,
slice_transformed=(slice(None), slice(None), slice(None)),
resume=False,
**kwargs,
@@ -333,6 +338,7 @@ def motionCorrect(
correct_another [da.Array]: dask array of another set of frames that the transformations will be applied to
t_indices [np.ndarray]: only transform select frames (only works when chunking and starting fresh)
slice_transformed [tuple]: slice the transformed image as early as possible, right after it is transformed, in an attempt to save memory
resample_with_fixed [bool]: whether or not to use the fixed image in the resampling step
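time_stride [int]: compute alignment parameters only for every time_stride-th frame (plus the last frame) and interpolate the remaining parameters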
"""

# set up the distributed environment
@@ -349,8 +355,9 @@ def motionCorrect(
ds.initializeClient()

# create (lazy) dask bag from all frames
frame_paths = csio.globPaths(folder, prefix, suffix)
nframes = len(frame_paths)
frames = csio.daskBagOfFilePaths(folder, prefix, suffix, slicer=folder_slicer)
nframes = frames.npartitions

# scale cluster carefully
if distributed_state is None:
@@ -367,9 +374,9 @@ def motionCorrect(
ddataset_path = delayed(dataset_path)

if params is None:
params = alignFramesToReference(frames, dfixed, dfixed_vox, dmoving_vox,
params = alignFramesToReference(frame_paths, dfixed, dfixed_vox, dmoving_vox,
sigma, ddataset_path,
transforms_dir=transforms_dir, resume=resume, pad_fixed=pad_fixed)
transforms_dir=transforms_dir, resume=resume, pad_fixed=pad_fixed, time_stride=time_stride)


# transform frames with params
@@ -393,7 +400,7 @@ def motionCorrect(


from analysis_toolbox.utils import find_files
actual_write_paths = find_files(write_path + '/', ext='h5', compute_path=True)['path']
actual_write_paths = find_files(write_path + '/', ext='h5', compute_paths=True)['path']

if resume:

@@ -447,21 +454,33 @@ def motionCorrect(
import zarr
frame_paths = zarr.open(frames_to_correct, mode='r')
else:
frame_paths = np.array(csio.globPaths(frame_to_correct, suffix='.h5', prefix='TM')) ## convert to array for array indexing
elif isinstance(frame_to_correct, da.Array):
frame_paths = np.array(csio.globPaths(frames_to_correct, suffix='.h5', prefix='TM')) ## convert to array for array indexing
elif isinstance(frames_to_correct, da.Array):
fileext = None
pass

resampled_slice_ref_index = 0
_, resampled_slice = applyTransform(frame_paths[resampled_slice_ref_index], moving_vox, params[resampled_slice_ref_index],
fixed=fixed, fixed_vox=fixed_vox, return_resampled_slice=True, dataset_path=dataset_path)
if resample_with_fixed:

resampled_slice_ref_index = 0
print(frame_paths[resampled_slice_ref_index])
resampling_fixed = fixed
_, resampled_slice = applyTransform(frame_paths[resampled_slice_ref_index], moving_vox, params[resampled_slice_ref_index],
fixed=resampling_fixed, fixed_vox=fixed_vox, return_resampled_slice=True, dataset_path=dataset_path, pad_fixed=pad_fixed)
print(f'Resampled slice: {resampled_slice}')
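## the resampled slice is computed once, from a single reference frame, and then
## reused for every chunk in the map_partitions call below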

else:

resampling_fixed = None

print(f'Resampling fixed: {resampling_fixed}')

## TODO: replace applyTransformToChunksOfFrames with applyTransformToAChunkOfFrames and use the former to wrap everything in this block
transformed = indices.map_partitions(applyTransformToChunksOfFrames,
frame_dir=frames_to_correct, params_path=transforms_dir + '/params.npy',
moving_vox=moving_vox, dataset_path=dataset_path,
resampled_slice=resampled_slice,
slice_transformed=slice_transformed, write_path=write_path,
fixed=dfixed, fixed_vox=dfixed_vox).to_delayed()
fixed=resampling_fixed, fixed_vox=dfixed_vox, pad_fixed=pad_fixed).to_delayed()
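## map_partitions applies the transform to each partition of the indices bag
## (one chunk of frames per partition); to_delayed() then exposes one delayed
## object per partition so results can be assembled lazily below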

if write_path is None:
indices_len = list(indices.map_partitions(lambda x: len(x)).compute())
@@ -472,6 +491,7 @@ def motionCorrect(
elif isinstance(frames_to_correct, da.Array):
example_image = frames_to_correct[0]

## TODO: the shape changes when pad_fixed is used
shape = example_image.shape
dtype = example_image.dtype

@@ -496,10 +516,23 @@ def motionCorrect(
else:
frames_to_correct = frames # dask bag of file paths

if resample_with_fixed:

resampled_slice_ref_index = 0
print(frame_paths[resampled_slice_ref_index])
resampling_fixed = dfixed
_, resampled_slice = applyTransform(frame_paths[resampled_slice_ref_index], moving_vox, params[resampled_slice_ref_index],
fixed=resampling_fixed, fixed_vox=dfixed_vox, return_resampled_slice=True, dataset_path=dataset_path, pad_fixed=pad_fixed)
print(f'Resampled slice: {resampled_slice}')

else:

resampling_fixed = None

# work on each frame separately -- better for parallelism
transformed = applyTransformToFrames(frames_to_correct, params, dmoving_vox, ddataset_path,
slice_transformed=slice_transformed, write_path=write_path,
fixed=dfixed, fixed_vox=dfixed_vox)
fixed=dfixed, fixed_vox=dfixed_vox, pad_fixed=pad_fixed)

# release resources
if distributed_state is None:
@@ -534,8 +567,9 @@ def runAlignFramesToReference(

if folder_slicer is None: folder_slicer = slice(len(files))
# create (lazy) dask bag from all frames
frame_paths = csio.globPaths(folder, prefix, suffix)
nframes = len(frame_paths)
frames = csio.daskBagOfFilePaths(folder, prefix, suffix, slicer=folder_slicer)
nframes = frames.npartitions

# scale cluster carefully
if distributed_state is None:
@@ -551,7 +585,7 @@ def runAlignFramesToReference(
dmoving_vox = delayed(moving_vox)
ddataset_path = delayed(dataset_path)

params = alignFramesToReference(frames, dfixed, dfixed_vox, dmoving_vox,
params = alignFramesToReference(frame_paths, dfixed, dfixed_vox, dmoving_vox,
sigma, ddataset_path,
transforms_dir=transforms_dir, pad_fixed=pad_fixed)

@@ -563,66 +597,98 @@ def runAlignFramesToReference(
return params


def alignFramesToReference(frames, dfixed, dfixed_vox, dmoving_vox,
sigma, ddataset_path,
def alignFramesToReference(frame_paths, dfixed, dfixed_vox, dmoving_vox,
sigma, ddataset_path, time_stride=1,
resume=True, transforms_dir=None, pad_fixed=False):
"""
frame_paths [list]: list of frame file paths
"""

from scipy.ndimage import percentile_filter, gaussian_filter1d

paths = list(frames)
expected_param_savepaths = [os.path.join(transforms_dir, os.path.splitext(os.path.basename(path))[0] + '_rigid.npy') for path in paths]
## I believe this requests workers
expected_param_savepaths = [os.path.join(transforms_dir, os.path.splitext(os.path.basename(path))[0] + '_rigid.npy') for path in frame_paths]

if resume:
from analysis_toolbox.utils import find_files
from tqdm.notebook import tqdm

actual_param_savepaths = find_files(transforms_dir + '/', grep='TM', ext='npy', compute_path=True)['path']
actual_param_savepaths = find_files(transforms_dir + '/', grep='TM', ext='npy', compute_paths=True)['path']

if len(actual_param_savepaths) == 0:

resume = False ## nothing to resume
savepaths = db.from_sequence(expected_param_savepaths, npartitions=len(expected_param_savepaths))
savepaths = expected_param_savepaths
framepaths = frame_paths

else:

missing_param_savepaths = np.setdiff1d(expected_param_savepaths, actual_param_savepaths)
savepaths = db.from_sequence(missing_param_savepaths, npartitions=len(missing_param_savepaths))
savepaths = missing_param_savepaths

missing_indices = np.array([np.where(np.array(expected_param_savepaths)==missing_param_savepath)[0][0] for missing_param_savepath in tqdm(missing_param_savepaths)])
paths = np.array(paths)[missing_indices]
frames = db.from_sequence(paths, npartitions=len(paths))
frame_paths = np.array(frame_paths)[missing_indices]
framepaths = frame_paths

else:

savepaths = db.from_sequence(expected_param_savepaths)

params = frames.map(lambda b,c,d,w,x,y,z: rigidAlignAndSave(w,b,x,y, dataset_path=z, pad_fixed=d, savepath=c),
w=dfixed, x=dfixed_vox, y=dmoving_vox, z=ddataset_path, d=pad_fixed, c=savepaths,
savepaths = expected_param_savepaths
framepaths = frame_paths

## subsample if requested
if not resume and time_stride > 1:
total_nframes = len(framepaths)
## calculate the frames to compute, taking care to include the last frame
## np.unique removes the duplicate if the last frame is already included in np.arange's output (i.e. it is repeated after appending)
time_slice = np.unique(np.append(np.arange(0, total_nframes, time_stride), total_nframes-1)).astype('int')
print(time_slice)
savepaths = np.array(savepaths)[time_slice]
framepaths = np.array(framepaths)[time_slice]
compute_nframes = len(framepaths)
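## e.g. time_stride=5 with total_nframes=12: np.arange(0, 12, 5) -> [0, 5, 10];
## appending 12-1=11 and taking np.unique gives time_slice = [0, 5, 10, 11]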

## convert paths to dask bags
savepaths_bag = db.from_sequence(savepaths, npartitions=len(savepaths))
framepaths_bag = db.from_sequence(framepaths, npartitions=len(framepaths))

params = framepaths_bag.map(lambda b,c,d,w,x,y,z: rigidAlignAndSave(w,b,x,y, dataset_path=z, pad_fixed=d, savepath=c),
w=dfixed, x=dfixed_vox, y=dmoving_vox, z=ddataset_path, d=pad_fixed, c=savepaths_bag,
).compute()
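## dask zips Bag arguments element-wise, so each frame path is paired with its
## matching savepath from savepaths_bag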

## if not resuming, then only rely on the params recently computed
if not resume:
params = np.array(list(params))
else:
## reload from files
## reload from files, because the latest run might not have all the params
params = np.stack([np.load(expected_param_savepath, allow_pickle=True) for expected_param_savepath in tqdm(expected_param_savepaths)])

# (weak) outlier removal and smoothing
params = percentile_filter(params, 50, footprint=np.ones((3,1)))
params = gaussian_filter1d(params, sigma, axis=0)
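## the (3,1) footprint makes the percentile filter a running median over 3
## consecutive frames, applied to each of the 6 rigid parameters independently;
## gaussian_filter1d then smooths each parameter trace along the time axis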

# interpolate
if not resume and time_stride > 1:
x = time_slice.copy()
y = np.mgrid[:6]
z = params.copy().T

from scipy.interpolate import interp2d
f = interp2d(x, y, z, kind='cubic')

new_x = np.arange(0, total_nframes, 1)
params = f(new_x, y).T
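## shapes, for reference: z = params.T is (6, compute_nframes); interp2d fits a
## cubic surface over (time, parameter index), and f(new_x, y).T evaluates it at
## every frame, giving back a (total_nframes, 6) parameter array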

# write transforms as matrices
if transforms_dir is not None:
if not os.path.exists(transforms_dir): os.makedirs(transforms_dir)
np.save(transforms_dir + '/params.npy', params)

for ind, p in enumerate(params):
transform = _parametersToRigidMatrix(p)
basename = os.path.splitext(os.path.basename(paths[ind]))[0]
path = os.path.join(transforms_dir, basename) + '_rigid.mat'
np.savetxt(path, transform)
basename = os.path.splitext(os.path.basename(expected_param_savepaths[ind]))[0]
matpath = os.path.join(transforms_dir, basename) + '.mat'
npypath = os.path.join(transforms_dir, basename) + '.npy'
np.savetxt(matpath, transform)
if not os.path.exists(npypath): np.save(npypath, p)

return params

@@ -664,6 +730,7 @@ def applyTransformToChunksOfFrames(indices, frame_dir, params_path, moving_vox,

## inconsistent between the two conditions
if fileext == '':
from analysis_toolbox.utils import change_root_dir_in_path
from functools import partial
change_root_dir_in_path_for_transformed = partial(change_root_dir_in_path, replace_with=write_path)
save_paths = np.vectorize(change_root_dir_in_path_for_transformed)(frame_paths)
