In [None]:
import pims_nd2
import numpy as np
from matplotlib import pyplot as plt

In [None]:
from nd2_to_caiman import np_arr_from_nd2

In [None]:
# NOTE(review): absolute, machine-specific path -- adjust (or move into a config cell) before running elsewhere.
nd2_fpath = 'D:/PhD/Data/T386_MatlabTest/T386_20211202_green.nd2'

In [None]:
# Use the Qt backend: figures open in separate interactive windows rather than inline.
%matplotlib qt

In [None]:
# Reference copy of np_arr_from_nd2; the maintained version now lives in nd2_to_caiman.py.
def np_arr_from_nd2(nd2_fpath: str):
    """Load every time point of an .nd2 recording into one numpy array.

    Parameters
    ----------
    nd2_fpath : str
        Path to the .nd2 file, opened with ``pims_nd2.ND2_Reader``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(t, y, x)`` with the reader's ``pixel_type`` dtype, where
        ``frames_arr[i]`` is the i-th frame yielded by iterating the reader.
    """
    # NOTE(review): assumes the reader iterates over 't' and yields 2-D (y, x) frames --
    # set iter_axes / bundle_axes explicitly if a file has extra axes (channels, z).
    with pims_nd2.ND2_Reader(nd2_fpath) as nikon_file:  # TODO: also return metadata?
        sizes_dict = nikon_file.sizes
        # numpy images are (rows=y, cols=x); the old (t, x, y) order only worked for square frames.
        # Pass the dtype explicitly: np.zeros would otherwise default to float64.
        frames_arr = np.zeros((sizes_dict['t'], sizes_dict['y'], sizes_dict['x']),
                              dtype=nikon_file.pixel_type)

        # TODO: caiman may accept the reader directly (its frames are ndarray subclasses).
        for i_frame, frame in enumerate(nikon_file):
            # Copy the frame being iterated (indexing nikon_file[0] duplicated frame 0 everywhere).
            frames_arr[i_frame] = np.asarray(frame, dtype=nikon_file.pixel_type)
    return frames_arr

In [None]:
# Load the whole recording into memory as one array (np_arr_from_nd2 as most recently defined above).
nd2_data = np_arr_from_nd2(nd2_fpath)

In [None]:
# Show the first frame of the stack as a pseudo-colour image.
# Original note: figsize did not seem to take effect here -- TODO check the Qt window size.
fig, ax = plt.subplots(figsize=(18, 18))
ax.pcolormesh(nd2_data[0])
plt.show()

In [None]:
# Inspect the dtype of the loaded stack (last expression is shown as the cell output).
nd2_data.dtype

# Inspect the 2-D FFT (fft2) of a single frame

In [None]:
# Centred 2-D spectrum of the first frame (fftshift moves zero frequency to the middle).
freq_matrix = np.fft.fftshift(np.fft.fft2(nd2_data[0]))

In [None]:
# Log-magnitude of the centred spectrum of frame 0.
# plt.matshow() opens a *new* figure by default, which left the figsize figure empty;
# drawing onto explicit axes makes figsize apply to the plotted figure.
fig, ax = plt.subplots(figsize=(18, 18))
im = ax.matshow(np.log(np.abs(freq_matrix)))
fig.colorbar(im, ax=ax)
plt.show()

In [None]:
# Log-magnitude of the centred spectrum; bright outliers are the candidate noise spikes.
amplitude_image = np.log(np.absolute(freq_matrix))

In [None]:
win = 40  # half-size (in frequency bins) of the central window that is protected from filtering

In [None]:
# Boolean mask of candidate noise spikes: log-amplitude above the default threshold 10.8.
bright_spikes = np.greater(amplitude_image, 10.8)
# Blank float64 canvas on which the outline of the protected window is drawn.
rectangle_filter_boundary = np.zeros(shape=amplitude_image.shape, dtype=float)

In [None]:
# Mark the outline of the protected central window (c - win .. c + win) for plotting.
# end_x indexes axis 0 (rows) and end_y axis 1 (columns), matching the whole-stack cells
# below (end_x = shape[1], end_y = shape[2] of the 3-D stack). The previous assignment was
# swapped, which is only harmless for square frames.
end_x = amplitude_image.shape[0]
end_y = amplitude_image.shape[1]
row_lo, row_hi = round(end_x / 2 - win), round(end_x / 2 + win)
col_lo, col_hi = round(end_y / 2 - win), round(end_y / 2 + win)
# "+ 1" on the slice ends closes the outline at the (row_hi, col_hi) corner.
rectangle_filter_boundary[row_lo:row_hi + 1, col_lo] = 1
rectangle_filter_boundary[row_lo:row_hi + 1, col_hi] = 1
rectangle_filter_boundary[row_lo, col_lo:col_hi + 1] = 1
rectangle_filter_boundary[row_hi, col_lo:col_hi + 1] = 1

In [None]:
# Spikes that would actually be removed: everything except the protected central window.
filtered_spikes = bright_spikes.copy()
filtered_spikes[round(end_x / 2 - win):round(end_x / 2 + win),
                round(end_y / 2 - win):round(end_y / 2 + win)] = 0

In [None]:
# Comparison case: zero *every* bright spike, including those inside the central window.
freq_filtered = freq_matrix.copy()
freq_filtered[bright_spikes] = 0.0

In [None]:
# 2x3 overview of the spike detection on frame 0.
fig, axs = plt.subplots(2, 3, figsize=(20, 20))

mesh_ampl = axs[0, 0].pcolormesh(amplitude_image)
fig.colorbar(mesh_ampl, ax=axs[0, 0])  # Axes has no .colorbar(); go through the figure
axs[0, 0].set_title("log |FFT| (centred)")

axs[0, 1].pcolormesh(bright_spikes + rectangle_filter_boundary)
axs[0, 1].set_title("spikes above threshold + protected window")

axs[0, 2].pcolormesh(np.abs(freq_filtered))
axs[0, 2].set_title("|FFT| with all spikes removed")

axs[1, 0].pcolormesh(filtered_spikes)
axs[1, 0].set_title("spikes outside the protected window")

# plot() of a 2-D array draws one line per column, so .T gives one line per row.
axs[1, 1].plot(amplitude_image.T)
axs[1, 1].set_title("log |FFT| profile of every row")

# The difference is exactly zero wherever nothing was removed; np.ma.log masks those
# entries instead of producing -inf and divide-by-zero warnings.
axs[1, 2].pcolormesh(np.ma.log(np.abs(freq_matrix - freq_filtered)))
axs[1, 2].set_title("log |FFT| of removed components")

plt.show()

In [None]:
# Inspect the stack shape; axis 0 indexes frames (see the per-frame loop below).
nd2_data.shape

In [None]:
# Parameters for filtering the whole stack.
amplitude_threshold = 10.8  # log-amplitude above which an FFT coefficient counts as a spike
win = 40  # half-size of the protected central window
end_x, end_y = nd2_data.shape[1:3]  # frame extent along axes 1 and 2 of the stack

In [None]:
# Output stack for the filtered frames (float64, same shape as the raw data).
filtered_data = np.zeros(shape=nd2_data.shape, dtype=np.float64)

In [None]:
# Filter every frame: zero FFT spikes brighter than the threshold outside the protected
# central window, transform back, and store the magnitude in filtered_data.
for i_frame in range(nd2_data.shape[0]):
    gray_image = nd2_data[i_frame, :, :]  # previously nd2_data[0], i.e. frame 0 on every pass
    freq_image = np.fft.fftshift(np.fft.fft2(gray_image))
    ampl_image = np.log(np.abs(freq_image))
    bright_spikes = ampl_image > amplitude_threshold
    bright_spikes[round(end_x/2-win):round(end_x/2+win), round(end_y/2-win):round(end_y/2+win)] = 0
    freq_image[bright_spikes] = 0
    # Undo fftshift before inverting. Without it the result differs only by a unit-magnitude
    # phase factor, which np.abs below hid.
    filt_image = np.fft.ifft2(np.fft.ifftshift(freq_image))
    filtered_data[i_frame] = np.abs(filt_image)
    

# Test functions

In [None]:
import RippleNoiseRemoval as rnr

In [None]:
#filt = rnr.rnr_par()

In [None]:
import pims_nd2
import numpy as np
from matplotlib import pyplot as plt
from nd2_to_caiman import np_arr_from_nd2
from multiprocessing import Pool
import time

In [None]:
# Recording and filter parameters for the benchmark below.
# NOTE(review): absolute, machine-specific paths -- adjust before running elsewhere.
nd2_fpath = 'D:/PhD/Data/T386_MatlabTest/T386_20211202_green.nd2'
# nd2_fpath = 'D:/T301/T301_base_d1/T301_base_d1.180820.1614.nd2'  # alternative recording
amplitude_threshold = 10.8  # log-amplitude spike threshold
win = 40  # half-size of the protected central FFT window

nd2_data = np_arr_from_nd2(nd2_fpath)  # original note: dtype np.uint16

In [None]:
# Output buffers for the two benchmark runs. np.empty skips zero-filling (marginally faster),
# so every frame must be written before the buffers are read.
filtered_data = np.empty(nd2_data.shape, dtype=np.float64)
filtered_data_par = np.empty_like(filtered_data)

In [None]:
def rnr_frame(frame, win=40, amplitude_threshold=10.8):
    """Remove ripple noise from one 2-D frame by deleting bright FFT spikes.

    Coefficients of the centred (fftshift-ed) spectrum whose natural-log magnitude
    exceeds ``amplitude_threshold`` are set to zero, except inside a central window of
    half-size ``win`` around the spectrum centre, which is always kept.

    Parameters
    ----------
    frame : array_like
        2-D image.
    win : int, optional
        Half-size of the protected central window (default 40).
    amplitude_threshold : float, optional
        Log-magnitude threshold for spike detection (default 10.8).

    Returns
    -------
    numpy.ndarray
        float64 magnitude of the filtered image, same shape as ``frame``.
    """
    # Presumably imported locally so the function is self-contained when sent to pathos workers.
    import numpy as np

    freq_image = np.fft.fftshift(np.fft.fft2(frame))
    # Exactly-zero coefficients give log(0) = -inf, which is never a spike; silence the warning.
    with np.errstate(divide='ignore'):
        ampl_image = np.log(np.abs(freq_image))
    end_x = ampl_image.shape[0]
    end_y = ampl_image.shape[1]

    bright_spikes = ampl_image > amplitude_threshold
    # Clamp the window start at 0: a negative start would wrap around and protect the wrong
    # rows/columns when 2 * win exceeds the frame size.
    bright_spikes[max(0, round(end_x / 2 - win)):round(end_x / 2 + win),
                  max(0, round(end_y / 2 - win)):round(end_y / 2 + win)] = False
    freq_image[bright_spikes] = 0

    # ifftshift undoes the fftshift above before inverting the transform.
    filt_image = np.fft.ifft2(np.fft.ifftshift(freq_image))
    return np.abs(filt_image)  # dtype=np.float64

In [None]:
from pathos.multiprocessing import ProcessingPool as Pool

In [None]:
import os

# pathos ProcessingPool (imported above). Keep the original cap of 16 workers, but never
# request more than the logical CPUs reported by the OS (os.cpu_count() may return None).
p = Pool(min(16, os.cpu_count() or 1))

In [None]:
# Parallel run. perf_counter is monotonic and high-resolution, unlike time.time(), so it is
# the right clock for measuring elapsed intervals.
t0_par = time.perf_counter()
res_par = p.map(rnr_frame, nd2_data)
for i_frame, frame in enumerate(res_par):
    filtered_data_par[i_frame] = frame
t1_par = time.perf_counter()

In [None]:
# Elapsed seconds of the parallel run.
t1_par - t0_par

In [None]:
# Single-process reference run using RippleNoiseRemoval.rnr_frame (module imported as `rnr`).
# NOTE(review): `rnr` is rebound to an RNR instance in the "Final version" section; re-running
# this cell after that section would no longer call the module function.
t0 = time.perf_counter()
for i_frame in range(nd2_data.shape[0]):
    filtered_data[i_frame] = rnr.rnr_frame(nd2_data[i_frame, :, :], win, amplitude_threshold)
t1 = time.perf_counter()

In [None]:
# Compare elapsed seconds of the parallel and single-process runs.
print("Parallel pool: {}. Solo: {}".format(t1_par - t0_par, t1 - t0))

# Final(?) version using the RNR class

In [1]:
from RippleNoiseRemoval import RNR
from labrotation.file_handling import open_file
from time import time

In [2]:
# Build the RNR filter object and time its construction.
t0_init = time()
# NOTE(review): this rebinds `rnr`, shadowing the RippleNoiseRemoval module alias imported earlier
# (`import RippleNoiseRemoval as rnr`); earlier cells calling rnr.rnr_frame(...) may break if re-run.
rnr = RNR(40, 10.8, 4)  # parameters: win, amplitude_threshold, n_cores
print(time() - t0_init)

0.0


In [3]:
# Choose the .nd2 file (open_file presumably shows a file dialog -- confirm) and load it into RNR, timed.
t0_open = time()
nd2_fpath = open_file("Open .nd2 file for RNR!")
rnr.open_recording(nd2_fpath)
print(time() - t0_open)

  warn("Please call FramesSequenceND.__init__() at the start of the"


Opened recording 512x512, 577 frames. Initialized empty results array.
18.570338010787964


In [4]:
# Run the parallel RNR implementation and time it; fd_par holds its result.
t0_par = time()
fd_par = rnr.rnr_parallel()
t1_par = time()
print(t1_par - t0_par)

RNR with 4 threads completed.
37.609785079956055


In [5]:
# Run the single-thread RNR implementation and time it; fd_sin holds its result.
t0_single = time()
fd_sin = rnr.rnr_singlethread()
t1_single = time()
print(t1_single - t0_single)

RNR completed.
8.265014410018921


In [14]:
# numpy is not imported in this section, so import it here to keep the cell runnable after a restart.
import numpy as np

# Check that the single-thread and parallel results agree exactly.
# Vectorised count instead of a Python triple loop over every pixel.
par_stack = np.asarray(fd_par)
sin_stack = np.asarray(fd_sin)
assert par_stack.shape == sin_stack.shape, f"shape mismatch: {par_stack.shape} vs {sin_stack.shape}"
count_mismatch = int(np.count_nonzero(par_stack != sin_stack))  # pixels that differ between the two results
print(count_mismatch)  # 0 means two methods are same, as it should be.

0
