Skip to content

Commit

Permalink
Merge pull request #25 from bmorris3/caching-improvements
Browse files Browse the repository at this point in the history
Caching improvements
  • Loading branch information
Brett M. Morris committed Nov 24, 2016
2 parents 6b98006 + 6ae229e commit 84093f4
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 60 deletions.
83 changes: 29 additions & 54 deletions shampoo/reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
"""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import sys
import warnings
from multiprocessing.dummy import Pool as ThreadPool

Expand Down Expand Up @@ -79,25 +78,6 @@ def shift_peak(arr, shifts_xy):
int(shifts_xy[1]), axis=1)


def make_items_hashable(input_iterable):
"""
Take a list or tuple of objects, convert any items that are lists into
tuples to make them hashable.
Parameters
----------
input_iterable : list or tuple
Items to convert to hashable types
Returns
-------
hashable_tuple : tuple
``input_iterable`` made hashable
"""
return tuple([tuple(i) if isinstance(i, list) or isinstance(i, np.ndarray)
else i for i in input_iterable])


def _load_hologram(hologram_path):
"""
Load a hologram from path ``hologram_path`` using scikit-image and numpy.
Expand Down Expand Up @@ -221,12 +201,13 @@ def from_tif(cls, hologram_path, **kwargs):
return cls(hologram, **kwargs)

def reconstruct(self, propagation_distance,
plot_aberration_correction=False,
plot_fourier_peak=False,
cache=False, digital_phase_mask=None):
plot_aberration_correction=False, plot_fourier_peak=False,
cache=False):
"""
Wrapper around `~shampoo.reconstruction.Hologram.reconstruct_wave` for
caching.
Reconstruct the wave at ``propagation_distance``.
If ``cache`` is `True`, the reconstructed wave will be cached onto the
`~shampoo.reconstruction.Hologram` object for quick retrieval.
Parameters
----------
Expand All @@ -239,39 +220,37 @@ def reconstruct(self, propagation_distance,
of the hologram? Default is False.
cache : bool
Cache reconstructions onto the hologram object? Default is False.
digital_phase_mask : `~numpy.ndarray`
Digital phase mask, if you have one precomputed. Default is None.
Returns
-------
reconstructed_wave : `~shampoo.reconstruction.ReconstructedWave`
The reconstructed wave.
"""

if cache:
cache_key = make_items_hashable((propagation_distance,
self.wavelength, self.dx, self.dy))
# Cache dictionary is accessible by keys = propagation distances
cache_key = propagation_distance

# If this reconstruction is cached, get it.
if cache and cache_key in self.reconstructions:
reconstructed_wave = self.reconstructions[cache_key]
# If this reconstruction is cached, get it.
if cache_key in self.reconstructions:
reconstructed_wave = self.reconstructions[cache_key]

# If this reconstruction is not in the cache,
# or if the cache is turned off, do the reconstruction
elif (cache and cache_key not in self.reconstructions) or not cache:
reconstructed_wave = self.reconstruct_wave(propagation_distance, digital_phase_mask,
# If this reconstruction is not cached, calculate it and cache it
else:
reconstructed_wave = self._reconstruct(propagation_distance,
plot_aberration_correction=plot_aberration_correction,
plot_fourier_peak=plot_fourier_peak)
self.reconstructions[cache_key] = reconstructed_wave

# If this reconstruction should be cached and it is not:
if cache and cache_key not in self.reconstructions:
self.reconstructions[cache_key] = ReconstructedWave(reconstructed_wave)
else:
reconstructed_wave = self._reconstruct(propagation_distance,
plot_aberration_correction=plot_aberration_correction,
plot_fourier_peak=plot_fourier_peak)

return ReconstructedWave(reconstructed_wave)

def reconstruct_wave(self, propagation_distance, digital_phase_mask=None,
plot_aberration_correction=False,
plot_fourier_peak=False):
def _reconstruct(self, propagation_distance,
plot_aberration_correction=False,
plot_fourier_peak=False):
"""
Reconstruct wave from hologram stored in file ``hologram_path`` at
propagation distance ``propagation_distance``.
Expand All @@ -280,8 +259,6 @@ def reconstruct_wave(self, propagation_distance, digital_phase_mask=None,
----------
propagation_distance : float
Propagation distance [m]
digital_phase_mask : `~numpy.ndarray`
Use pre-calculated digital phase mask. Default is None.
plot_aberration_correction : bool
Plot the abberation correction visualization? Default is False.
plot_fourier_peak : bool
Expand Down Expand Up @@ -315,16 +292,14 @@ def reconstruct_wave(self, propagation_distance, digital_phase_mask=None,
# Calculate Fourier transform of impulse response function
G = self.fourier_trans_of_impulse_resp_func(propagation_distance)

# if digital_phase_mask is None, calculate one
if digital_phase_mask is None:
# Center the spectral peak
shifted_F_hologram = shift_peak(F_hologram * mask,
[self.n/2-x_peak, self.n/2-y_peak])
# Now calculate digital phase mask. First center the spectral peak:
shifted_F_hologram = shift_peak(F_hologram * mask,
[self.n/2-x_peak, self.n/2-y_peak])

# Apodize the result
psi = self.apodize(shifted_F_hologram * G)
digital_phase_mask = self.get_digital_phase_mask(psi,
plots=plot_aberration_correction)
# Apodize the result
psi = self.apodize(shifted_F_hologram * G)
digital_phase_mask = self.get_digital_phase_mask(psi,
plots=plot_aberration_correction)

# Reconstruct the image
psi = G * shift_peak(fft2(apodized_hologram * digital_phase_mask) * mask,
Expand Down
21 changes: 15 additions & 6 deletions shampoo/tests/test_hologram.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,19 +64,28 @@ def test_multiple_reconstructions():
At commit cc730bd and earlier, the Hologram.apodize function modified
the Hologram.hologram array every time Hologram.reconstruct was called.
This tests that that should not happen anymore.
Also test that the caching machinery is working.
"""

propagation_distances = [0.5, 0.8]
holo = Hologram(_example_hologram())
h_raw = holo.hologram.copy()
w1 = holo.reconstruct(0.5)
h_apodized1 = holo.hologram.copy()
w2 = holo.reconstruct(0.8)
h_apodized2 = holo.hologram.copy()
holograms = []

for d in propagation_distances:
w = holo.reconstruct(d, cache=True)
holograms.append(holo.hologram)

# check hologram doesn't get modified in place first time
assert np.all(h_raw == h_apodized1)
assert np.all(h_raw == holograms[0])

# check hologram doesn't get modified again
assert np.all(h_apodized1 == h_apodized2)
assert np.all(holograms[0] == holograms[1])

# check that the cached reconstructions exist
for d in propagation_distances:
assert d in holo.reconstructions


def test_nonsquare_hologram():
Expand Down

0 comments on commit 84093f4

Please sign in to comment.