In [None]:
import numpy as np
from dask import delayed

import dask.array as da
from numba import jit
import nd2
import matplotlib.pyplot as plt
import pyfftw
# Local imports
from ddm.processing import ddm_fftw, ddm_numpy, get_diff_images, radial_profile_jit

In [None]:
# from dask.distributed import Client
# client = Client()

In [None]:
# client.cluster

In [None]:
# client.close()

In [None]:
# Load dataset
# def load_data(file: str = "../data/A1_s3001.nd2", num_frames: str = 500):
#     data = nd2.imread(file, xarray=True, dask=False)
#     data = data[:num_frames,:,:]
#     return data.astype("float64")

# def load_data_delayed(file: str = "../data/A1_s3001.nd2", num_frames: str = 500):
#     data = nd2.imread(file, xarray=False, dask=True)
#     data = data[:num_frames,:,:]
#     return data.astype("float64")

def load_data(file: str = "../data/A1_s3001.nd2", num_frames: int = 5000):
    """Open an ND2 movie lazily and return its leading frames as float64.

    Parameters
    ----------
    file : str
        Path to the ``.nd2`` file.
    num_frames : int
        Number of leading frames (first axis) to keep.

    Returns
    -------
    xarray-style array
        The first ``num_frames`` frames from ``ND2File.to_xarray(delayed=True)``,
        cast to ``float64``.
    """
    # Keep a named handle; with delayed=True the returned array presumably
    # still reads from this file when it is eventually computed.
    handle = nd2.ND2File(file)
    return handle.to_xarray(delayed=True)[:num_frames, :, :].astype("float64")

In [None]:
# Helper functions
@delayed
def calc_fft(data, tau: int):
    """Accumulate the power spectrum of frame differences at lag ``tau`` (NumPy FFT).

    Parameters
    ----------
    data : array-like, shape (num_frames, height, width)
        Image stack; ``data[i]`` must yield a 2-D frame.
    tau : int
        Non-negative frame lag between the two subtracted images.

    Returns
    -------
    numpy.ndarray, shape (height, width)
        ``sum_j |FFT2(data[j + tau] - data[j])|**2`` over all valid ``j``,
        zero frequency shifted to the centre with ``np.fft.fftshift``.
        All zeros when ``tau >= num_frames``.

    Raises
    ------
    ValueError
        If ``tau`` is negative (it would silently wrap to the end of the stack).
    """
    num_frames, height, width = data.shape
    if tau < 0:
        raise ValueError(f"tau must be non-negative, got {tau}")

    # fft2 of a (height, width) frame is (height, width); allocating
    # (width, height) made ``+=`` fail for non-square images.
    image_fft_squared = np.zeros((height, width))
    for jj in range(num_frames - tau):
        # np.asarray drops any xarray wrapper so the subtraction is plain NumPy.
        image_diff = np.asarray(data[jj + tau]) - np.asarray(data[jj])
        image_fft_squared += np.abs(np.fft.fft2(image_diff)) ** 2

    return np.fft.fftshift(image_fft_squared)

@delayed
def calc_fft_pyfftw(data, tau: int, threads: int = 8):
    """Accumulate the power spectrum of frame differences at lag ``tau`` (pyFFTW).

    Parameters
    ----------
    data : array-like, shape (num_frames, height, width)
        Image stack; ``data[i]`` must yield a 2-D frame.
    tau : int
        Non-negative frame lag between the two subtracted images.
    threads : int, optional
        Number of FFTW threads (previously hard-coded to 8).

    Returns
    -------
    numpy.ndarray, shape (height, width)
        ``sum_j |FFT2(data[j + tau] - data[j])|**2`` over all valid ``j``,
        zero frequency shifted to the centre with ``np.fft.fftshift``.

    Raises
    ------
    ValueError
        If ``tau`` is negative.
    """
    num_frames, height, width = data.shape
    if tau < 0:
        raise ValueError(f"tau must be non-negative, got {tau}")

    # Match the (height, width) shape of each 2-D FFT; (width, height) broke
    # accumulation for non-square images.
    image_fft_squared = np.zeros((height, width))

    # NOTE: this cache only applies to the ``pyfftw.interfaces`` API; the
    # builder below is planned once here and reused for every difference.
    pyfftw.interfaces.cache.enable()
    # The transform is planned in complex64 (single precision), which bounds how
    # closely results can agree with the float64 NumPy path.
    input_template = pyfftw.empty_aligned((height, width), dtype='complex64')
    fft_object = pyfftw.builders.fft2(
        input_template,
        threads=threads,
        overwrite_input=True,
        planner_effort='FFTW_ESTIMATE',
    )

    for jj in range(num_frames - tau):
        image_diff = data[jj + tau] - data[jj]
        # fft_object returns its internal output buffer; np.abs makes a copy
        # before the next call overwrites it.
        image_fft_squared += np.abs(fft_object(image_diff)) ** 2

    return np.fft.fftshift(image_fft_squared)




@delayed
def calc_correlation(fft_shift: np.ndarray, tau: int, num_frames=None):
    """Normalise an accumulated difference spectrum by its number of frame pairs.

    Parameters
    ----------
    fft_shift : np.ndarray
        Summed ``|FFT(frame[j + tau] - frame[j])|**2`` spectrum, e.g. the
        output of ``calc_fft``.
    tau : int
        Frame lag used to build ``fft_shift``.
    num_frames : int, optional
        Total number of frames in the stack. When omitted, the module-level
        ``num_frames`` variable is used (the previous implicit behaviour, kept
        for backward compatibility); passing it explicitly is preferred.

    Returns
    -------
    np.ndarray
        ``fft_shift / (num_frames - tau)``.

    Raises
    ------
    NameError
        If ``num_frames`` is omitted and no module-level ``num_frames`` exists.
    ValueError
        If ``num_frames - tau`` is not positive (no frame pairs to average).
    """
    if num_frames is None:
        # Backward-compatible fallback to the hidden global the original read.
        try:
            num_frames = globals()["num_frames"]
        except KeyError:
            raise NameError(
                "calc_correlation needs `num_frames`; pass it explicitly"
            ) from None

    n_pairs = num_frames - tau
    if n_pairs <= 0:
        raise ValueError(
            f"need num_frames > tau to average, got num_frames={num_frames}, tau={tau}"
        )
    return fft_shift / n_pairs


In [None]:
def ddm(data, tau: int):
    """Build a lazy DDM pipeline: pyFFTW difference spectrum -> radial profile.

    Parameters
    ----------
    data : array-like, shape (num_frames, height, width)
        Image stack.
    tau : int
        Frame lag; must satisfy ``0 <= tau < num_frames``.

    Returns
    -------
    dask Delayed
        Lazy radial profile of the spectrum averaged over the
        ``num_frames - tau`` frame pairs; call ``.compute()`` to evaluate.

    Raises
    ------
    ValueError
        If ``tau`` is outside ``[0, num_frames)``; checked eagerly so the error
        appears before any graph is executed.
    """
    num_frames, height, width = data.shape
    if not 0 <= tau < num_frames:
        raise ValueError(f"tau must be in [0, {num_frames}), got {tau}")

    fft_shift = calc_fft_pyfftw(data, tau=tau)
    gTau = fft_shift / (num_frames - tau)

    # The spectrum is laid out (height, width), so the index grids must be too,
    # with the centre given in the same (row, column) order. x pairs with
    # centre[0] as in radial_profile_tensor -- assumed to be
    # radial_profile_jit's convention as well (TODO confirm).
    x, y = np.indices((height, width))
    gTau_radial = delayed(radial_profile_jit)(gTau, (height / 2., width / 2.), x, y)
    return gTau_radial

## Benchmark without optimization

In [None]:
# Lazily load the leading frames (load_data's default count) as float64.
data = load_data()

In [None]:
%%time
# Reference pipeline (presumably NumPy FFTs) at lag tau = 2; .compute() runs it.
gtau1 = ddm_numpy(data, 2)
r1 = gtau1.compute()
# plt.plot(r1)

In [None]:
%%time
# Same lag with the pyFFTW-based pipeline, for timing comparison.
gtau2 = ddm_fftw(data, 2)
r2 = gtau2.compute()
# plt.plot(r2)

In [None]:
# Both backends should agree within a relative tolerance of 1e-6.
np.testing.assert_allclose(r1, r2, rtol=1e-6)

## Refactoring of functions

In [None]:
%%time
tau = 2
num_frames, height, width = data.shape
# fft_shift = calc_fft(data, tau=tau)
fft_shift = calc_fft_pyfftw(data, tau=tau)
gTau = fft_shift/(num_frames-tau)

x, y = np.indices((width, height))

gTau_radial = delayed(radial_profile_jit)(gTau, (width/2., height/2.), x, y)
result = gTau_radial.compute()

In [None]:
# Inline pyFFTW pipeline vs. the ddm_numpy reference profile.
np.testing.assert_allclose(r1, result, rtol=1e-6)

In [None]:
%%time
# The refactored `ddm` helper should reproduce the inline pipeline above.
out = ddm(data, 2)
r4 = out.compute()

In [None]:
# ddm() vs. the ddm_numpy reference profile.
np.testing.assert_allclose(r1, r4, rtol=1e-6)

## Explore `map_blocks` for PyTorch compatibility

In [None]:
from torchvision import transforms
import torch

@delayed
def transform(img):
    """Lazily convert ``img`` to a torch tensor with torchvision's ``ToTensor``.

    NOTE(review): ``ToTensor`` treats a 3-D array as (H, W, C) and returns
    (C, H, W); for a (T, H, W) image stack that presumably yields a
    (W, T, H) tensor -- confirm this axis order is what callers expect.
    """
    as_array = np.asarray(img)
    pipeline = transforms.Compose([transforms.ToTensor()])
    return pipeline(as_array)

@delayed
def calc_fft_torch(data, tau: int):
    """Accumulate the power spectrum of frame differences at lag ``tau`` (torch FFT).

    Parameters
    ----------
    data : array-like, shape (num_frames, height, width)
        Image stack; converted to a float64 torch tensor.
    tau : int
        Non-negative frame lag between the two subtracted images.

    Returns
    -------
    torch.Tensor, shape (height, width), float64
        ``sum_j |FFT2(data[j + tau] - data[j])|**2`` over all valid ``j``,
        zero frequency shifted to the centre with ``torch.fft.fftshift``.

    Raises
    ------
    ValueError
        If ``tau`` is negative.
    """
    num_frames, height, width = data.shape
    if tau < 0:
        raise ValueError(f"tau must be non-negative, got {tau}")

    # Calling the @delayed ``transform`` here returned a lazy Delayed object,
    # not a tensor, so torch.fft.fft2 could not consume it. Convert eagerly
    # instead, keeping the (T, H, W) layout (ToTensor would permute the axes and
    # transpose every frame) and float64 so precision matches the NumPy path.
    frames = torch.as_tensor(np.asarray(data), dtype=torch.float64)
    image_fft_squared = torch.zeros((height, width), dtype=torch.float64)
    for jj in range(num_frames - tau):
        image_diff = frames[jj + tau] - frames[jj]
        image_fft_squared += torch.abs(torch.fft.fft2(image_diff)) ** 2

    return torch.fft.fftshift(image_fft_squared)


In [None]:
def radial_profile_tensor(data, centre, x, y):
    """Average ``data`` over integer-radius bins around ``centre`` using torch.

    Parameters
    ----------
    data : array-like or torch.Tensor
        2-D map to average; must have as many elements as ``x`` and ``y``.
    centre : sequence of two numbers
        Profile centre; ``centre[0]`` is compared with ``x`` and
        ``centre[1]`` with ``y``.
    x, y : array-like or torch.Tensor
        Index grids, e.g. ``x, y = np.indices(data.shape)``.

    Returns
    -------
    torch.Tensor
        Mean of ``data`` in each bin ``int(r)`` for ``r = 0 .. max radius``,
        where ``r = sqrt((x - centre[0])**2 + (y - centre[1])**2)``.
        Bins containing no pixels evaluate to 0/0 (NaN).
    """
    # The former @jit(nopython=True) wrapper cannot compile torch calls, and
    # the body mixed numpy methods (.astype) with a global ``tensor`` variable
    # where the ``torch`` module was meant.
    data = torch.as_tensor(data)
    x = torch.as_tensor(x)
    y = torch.as_tensor(y)

    # Truncate radii to integers so each pixel lands in a bin.
    r = torch.sqrt((x - centre[0]) ** 2 + (y - centre[1]) ** 2).to(torch.int64)
    r_flat = r.ravel()
    tbin = torch.bincount(r_flat, weights=data.ravel())
    nr = torch.bincount(r_flat)
    return tbin / nr

In [None]:
%%time
tensor = transform(data)


In [None]:
# Slice index 0 along axis 1. NOTE(review): after ToTensor's axis permutation
# this axis is presumably time, making this frame 0 transposed -- confirm.
tensor_0 = np.asarray(tensor[:,0,:])
tensor_0.shape

In [None]:
# Show the full tensor shape to check the axis order.
tensor.shape

In [None]:
%%time
tau = 2
num_frames, height, width = data.shape
fft_shift = calc_fft_torch(data, tau=tau)
gTau = fft_shift/(num_frames-tau)
x, y = np.indices((width, height))


gTau = np.asarray(gTau.compute())
gTau_radial = radial_profile_jit(gTau, (width/2., height/2.), x, y)
# result = gTau_radial.compute()
# fft_shift = fft_shift.compute()

In [None]:
# The torch-based profile should match the ddm_numpy reference within rtol 1e-6.
np.testing.assert_allclose(r1, gTau_radial, rtol=1e-6)