In [11]:
import numpy as np
import torch
import tensorflow as tf
import nd2
from dask import delayed, compute

In [2]:
print('PyTorch: {}'.format(torch.__version__))
print('Numpy: {}'.format(np.__version__))
print('TensorFlow: {}'.format(tf.__version__))

PyTorch: 1.11.0
Numpy: 1.22.4
TensorFlow: 2.6.0


In [None]:
from tensorflow.python.client import device_lib

def get_available_gpus():
    local_device_protos = device_lib.list_local_devices()
    return [x.name for x in local_device_protos if x.device_type == 'GPU']
get_available_gpus()

In [None]:
torch.cuda.is_available()

In [3]:
def torch_fft2(data: np.ndarray):
    data2 = torch.fft.fft2(data)    
    return torch.fft.ifft2(data2)

def numpy_fft2(data: np.ndarray):
    data2 = np.fft.fft2(data)
    return np.fft.ifft2(data2)

def tf_fft2(data: np.ndarray):
    data2 = tf.signal.fft2d(data)
    return tf.signal.ifft2d(data2)

In [4]:
# Load dataset
def load_data_xarr(file:str = "../data/A1_s3001.nd2"):
    data = nd2.imread(file, xarray=True, dask=True)
    return data

def load_data(file:str = "../data/A1_s3001.nd2"):
    data = nd2.imread(file, xarray=False, dask=True)
    return data


In [5]:
data = load_data()
data

Unnamed: 0,Array,Chunk
Bytes,2.44 GiB,512.00 kiB
Shape,"(5000, 512, 512)","(1, 512, 512)"
Count,10000 Tasks,5000 Chunks
Type,uint16,numpy.ndarray
"Array Chunk Bytes 2.44 GiB 512.00 kiB Shape (5000, 512, 512) (1, 512, 512) Count 10000 Tasks 5000 Chunks Type uint16 numpy.ndarray",512  512  5000,

Unnamed: 0,Array,Chunk
Bytes,2.44 GiB,512.00 kiB
Shape,"(5000, 512, 512)","(1, 512, 512)"
Count,10000 Tasks,5000 Chunks
Type,uint16,numpy.ndarray


In [13]:
%%time
for image in data:
    np_out = delayed(numpy_fft2)(image)

compute(np_out)

CPU times: total: 2.58 s
Wall time: 2.91 s


(array([[28826.+3.40838469e-14j, 28820.+1.06581410e-13j,
         29725.+1.60094160e-13j, ..., 28283.-3.37160855e-14j,
         28167.+2.01505479e-13j, 28052.-1.70224945e-13j],
        [29245.-1.08579812e-13j, 29225.-2.13162821e-14j,
         29900.+1.55209179e-13j, ..., 28281.-4.57758831e-14j,
         27884.-7.62723218e-14j, 28485.-2.11691775e-13j],
        [29640.-1.40998324e-14j, 29191.-2.26485497e-13j,
         29794.-2.55351296e-14j, ..., 28367.-7.19077575e-14j,
         28463.+1.10911280e-13j, 28659.-1.03528297e-14j],
        ...,
        [30287.-3.21964677e-14j, 29906.-4.97379915e-14j,
         30071.+1.42774681e-13j, ..., 29575.+3.23838178e-14j,
         29295.+1.40443213e-13j, 29081.-2.63511435e-13j],
        [30248.+5.44009282e-15j, 30043.+2.56683563e-13j,
         30278.+2.46247467e-13j, ..., 29081.-3.63806207e-14j,
         28963.-1.25455202e-14j, 29016.-2.35506059e-13j],
        [30250.-2.55351296e-14j, 30241.-9.05941988e-14j,
         30250.+3.62487818e-14j, ..., 29147.+

In [None]:
data[0]

In [None]:
%%time
# tensor = torch.from_numpy(data)
for image in data:
    print(type(image))
    # torch_out = torch_fft2(image)

In [None]:
%%time
# tensor = tf.convert_to_tensor(data)
for image in data:
    tensor = tf.convert_to_tensor(image)
    tf_out = tf_fft2(tensor)

In [None]:
np.testing.assert_allclose(np_out, torch_out.numpy(), rtol=1e-5)
np.testing.assert_allclose(np_out, tf_out.numpy(), rtol=1e-5)