In [None]:
import numpy as np
import matplotlib.pyplot as plt
import dask.array as da
from dask import compute, delayed
from dask.distributed import Client
from dask.diagnostics import ResourceProfiler
import nd2
from ddm.processing import radial_profile
from typing import Tuple

In [None]:
# client.close()

In [None]:
# Local dask client: 4 workers with 1 thread each and memory_limit='2GB',
# run inside this Python process rather than as separate processes (processes=False).
client = Client(
    n_workers=4,
    threads_per_worker=1,
    memory_limit='2GB',
    processes=False,
)
client

In [None]:
# Read in data with the generic reader. Format-specific readers
# (readLIF / readTIF / readND2 from ddm.data_handling) were used previously.
from ddm.data_handling import read_file

# Alternative inputs: swap one of these in to change the dataset.
#   dData = read_file('../tests/data/testData3series.lif', experiment=0)
#   dData = read_file('../data/A1_s3001.nd2')
#   dData = read_file('../tests/data/21-03-31_ddm_water_control_sample.tif')
#
# Dummy data for testing:
#   dData = np.random.randint(0, 256, size = (20, 512, 512))
#   dData = da.random.randint(0, 256, size=(5000, 512, 512), chunks=(10, 512, 512), dtype='uint8')
dData = read_file('../tests/data/testData10frames.nd2')

In [None]:
# Display the data returned by read_file in the previous cell.
dData

In [None]:
# import threading
# from typing import Tuple, cast

# def _dask_block(copy: bool, block_id: Tuple[int]) -> np.ndarray:
#         if isinstance(block_id, np.ndarray):
#             return
#         with threading.RLock():
#             was_closed = self.closed
#             if self.closed:
#                 self.open()
#             try:
#                 ncoords = len(self._coord_shape)
#                 idx = self._seq_index_from_coords(block_id[:ncoords])

#                 if idx == self._NO_IDX:
#                     if any(block_id):
#                         raise ValueError(
#                             f"Cannot get chunk {block_id} for single frame image."
#                         )
#                     idx = 0
#                 data = _get_frame(cast(int, idx))
#                 data = data.copy() if copy else data
#                 return data[(np.newaxis,) * ncoords]
#             finally:
#                 if was_closed:
#                     self.close()

# def _get_frame(self, index: int) -> np.ndarray:
#     frame = self._rdr._read_image(index)
#     frame.shape = self._raw_frame_shape
#     return frame.transpose((2, 0, 1, 3)).squeeze()

# def _seq_index_from_coords(
#         self, coords: Sequence
#     ) -> Union[Sequence[int], SupportsInt]:
#         if not self._coord_shape:
#             return self._NO_IDX
        # return np.ravel_multi_index(coords, self._coord_shape)


In [None]:
def _frame_chunks(n_frames, chunk_size):
    """Split ``n_frames`` into chunk lengths of ``chunk_size``, keeping any remainder.

    For example, 25 frames with ``chunk_size=10`` give ``(10, 10, 5)``. Zero
    frames give ``(0,)``, so the returned tuple is never empty.

    Raises
    ------
    ValueError
        If ``chunk_size`` is less than 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    n_full, remainder = divmod(n_frames, chunk_size)
    sizes = (chunk_size,) * n_full + ((remainder,) if remainder else ())
    return sizes or (0,)


def _read_frame_range(f, start, stop, frame_shape):
    """Read frames ``start`` .. ``stop - 1`` of ``f`` into one ``(stop - start, *frame_shape)`` array.

    If ``f`` is closed, it is opened for the duration of the read and closed
    again afterwards. This matches the open/close handling of nd2's own
    per-frame block reader.
    """
    was_closed = f.closed
    if was_closed:
        f.open()
    try:
        # _get_frame may squeeze singleton axes; reshape restores the declared frame shape.
        frames = [np.reshape(f._get_frame(i), frame_shape) for i in range(start, stop)]
    finally:
        if was_closed:
            f.close()
    if not frames:
        return np.empty((0, *frame_shape), dtype=f.dtype)
    return np.stack(frames).astype(f.dtype, copy=False)


def to_dask(f, chunk_size = 10, wrapper = False):
    """Convert an nd2 file to a dask array with ``chunk_size`` frames per chunk.

    Parameters
    ----------
    f : nd2.ND2File
        Source file. Frames are addressed by flat sequence index
        ``0 .. prod(f._coord_shape) - 1``. This is presumably acquisition order
        over all non-spatial coordinates (TODO confirm for multi-axis files).
    chunk_size : int, default 10
        Frames per dask chunk. The last chunk holds any remainder, so no frames
        are dropped.
    wrapper : bool, default False
        If True, return a ``ResourceBackedDaskArray``, which can re-open ``f``
        on compute.

    Returns
    -------
    dask array of shape ``(n_frames, *f._frame_shape)`` and dtype ``f.dtype``.

    Raises
    ------
    ValueError
        If ``chunk_size`` is less than 1.
    """
    n_frames = int(np.prod(f._coord_shape))
    frame_shape = tuple(int(s) for s in f._frame_shape)
    frame_chunks = _frame_chunks(n_frames, chunk_size)

    # Do not use f._dask_block here: it maps block index i to the single frame i
    # (a (1, ...) array), which does not match multi-frame chunks.
    def _read_block(block_id=None):
        # Every chunk except possibly the last holds chunk_size frames,
        # so block i starts at frame i * chunk_size.
        start = block_id[0] * chunk_size
        stop = min(start + chunk_size, n_frames)
        return _read_frame_range(f, start, stop, frame_shape)

    dask_arr = da.map_blocks(
        _read_block,
        chunks=(frame_chunks, *((s,) for s in frame_shape)),
        dtype=f.dtype,
    )

    if wrapper:
        from resource_backed_dask_array import ResourceBackedDaskArray

        # this subtype allows the dask array to re-open the underlying
        # nd2 file on compute.
        return ResourceBackedDaskArray.from_array(dask_arr, f)
    return dask_arr


In [None]:
# Open the ND2 file directly and wrap it as a lazily read dask array with
# chunk_size frames per chunk. This replaces the dData loaded via read_file above.
filename = '../data/A1_s3001.nd2'  # smaller alternative: '../tests/data/testData10frames.nd2'
f = nd2.ND2File(filename)
dData = to_dask(f, chunk_size=10)
# f is not closed here; call f.close() once computations on dData are finished.

In [None]:
# Display the dask array built from the ND2 file in the previous cell.
dData

### Observations
- Rechunking an nd2 file loaded through the `nd2` package incurs significant overhead: the array is rechunked before any computation is executed. By default, `nd2` creates chunks containing a single image (frame).

In [None]:
#method 1 - high q for A = 0 - rework to optimise the FFTs???
def findMeanSqFFT(dumData):
    """Time-averaged, normalised squared magnitude of each frame's 2-D FFT.

    Parameters
    ----------
    dumData : dask array
        Image stack. The FFT is taken over the last two axes and the mean over
        axis 0. The layout is assumed to be (frames, rows, cols); TODO confirm
        for other inputs.

    Returns
    -------
    Array of shape (rows, cols)
        ``mean_t(2 * |FFT2(frame_t)|**2)``, fftshifted so the zero-frequency
        term sits at the centre, divided by the number of pixels per frame
        (rows * cols).
    """
    sqFFT = 2 * da.abs(da.fft.fft2(dumData)) ** 2
    # Shift only the spatial axes. The default (axes=None) also rolls the frame
    # axis, which leaves the mean unchanged but adds needless work.
    sqFFT_shift = da.fft.fftshift(sqFFT, axes=(-2, -1))
    sqFFTmean = da.mean(sqFFT_shift, axis=0)
    # Normalise by the pixel count rows * cols. The old expression
    # np.shape(sqFFTmean[1]) took the shape of row 1 (a tuple), which
    # repeated the tuple instead of multiplying.
    n_rows, n_cols = np.shape(sqFFTmean)
    return sqFFTmean / (n_rows * n_cols)

def computeAB(sqFFTmean, b_range=(-100, -50)):
    """Split the radial profile of ``sqFFTmean`` into ``a`` and a constant background ``b``.

    Parameters
    ----------
    sqFFTmean : 2-D array
        Output of ``findMeanSqFFT``. The profile is taken about
        ``(shape[0] / 2, shape[1] / 2)``.
    b_range : (start, stop), default (-100, -50)
        Slice bounds (Python slice semantics, ``None`` allowed) of the radial
        bins averaged to estimate ``b``. The default reproduces the previous
        hard-coded window; adjust it for the array size.

    Returns
    -------
    (a, b)
        ``b`` is the mean of the profile over ``b_range``; ``a`` is the profile minus ``b``.

    Raises
    ------
    ValueError
        If ``b_range`` selects no bins (e.g. the profile is shorter than the window).
    """
    center = (np.shape(sqFFTmean)[0] / 2, np.shape(sqFFTmean)[1] / 2)
    sqFFTrad = radial_profile(sqFFTmean, center)
    window = sqFFTrad[b_range[0]:b_range[1]]
    # Previously an empty window silently produced b = nan (and a = all nan).
    if np.size(window) == 0:
        raise ValueError(
            f"b_range {b_range} selects no bins from a radial profile of length {len(sqFFTrad)}"
        )
    b = np.mean(window)
    a = sqFFTrad - b
    return (a, b)

In [None]:
# Build the dask graph for the time-averaged squared-FFT map of the whole stack.
# It is evaluated by .compute() in the next cell.
meanSQ = findMeanSqFFT(dData)
meanSQ

In [None]:
%%time
# Evaluate the lazy meanSQ graph; %%time (which must stay on the cell's first line) reports the cost.
out = meanSQ.compute()

In [None]:
# Radial profile of the computed map: b is the background, a is the profile minus b.
a, b = computeAB(out)

In [None]:
# Skip the lowest radial bins, presumably dominated by the zero-frequency peak (TODO confirm).
first_bin = 10
fig, ax = plt.subplots()
# Plot against the true bin index; plt.plot(a[10:]) alone would relabel bin 10 as 0.
ax.plot(np.arange(first_bin, len(a)), a[first_bin:])
ax.set(
    xlabel="radial bin index",
    ylabel="a (radial profile - b)",
    title=f"a vs radial bin (first {first_bin} bins skipped)",
)
plt.show()


In [None]:
# Full radial profile of a, including the lowest bins.
fig, ax = plt.subplots()
ax.plot(a)
ax.set(xlabel="radial bin index", ylabel="a (radial profile - b)", title="a vs radial bin (all bins)")
plt.show()

In [None]:
# Shut down the dask client created at the top of the notebook.
# client.restart() is kept as a commented-out alternative.
# client.restart()
client.close()

In [None]:
#method 2 - delay time is zero then A = real part of squared stuff
# fTFt = np.fft.fftshift(2*np.fft.fft2(dData2)*np.conj(np.fft.fft2(dData2, axes = (0,1))))
# aFull = np.real(np.mean(fTFt, axis = 0))
# c = radial_profile(aFull, (512, 512))
# pl.figure(0)
# #pl.plot(c[10:], color = 'r')
# pl.plot(A[10:], color = 'b')

# SqFFT = 2*np.abs(np.fft.fft2(dData2)**2)
# SqFFT = np.fft.fftshift(SqFFT)
# SqFFTmean = np.mean(SqFFT, axis = 0)
# SqFFTrad = radial_profile(SqFFTmean, (512, 512))

# d = SqFFTrad - c

# pl.figure(1)
# #pl.plot(d[10:], color = 'm')
# pl.axhline(B, color = 'c')