In [None]:
import tifffile
import numpy as np
import matplotlib
matplotlib.use('TkAgg')

from matplotlib.widgets import Slider
from scipy.signal import savgol_filter
import matplotlib.pyplot as plt

In [None]:
# Input recording: a TIFF file on the shared lab drive, read by tifffile.imread in the load cell.
path = "Z:/Adam-Lab-Shared/Data/Michal_Rubin/Dendrites/AceM-neon/AcAx3/28-10-2024-acax3-s4-awake/fov1/vol_001/vol.tif"
# NOTE(review): not referenced by any other visible cell — presumably the output
# file for the cleaned movie; confirm it is used further down or remove it.
clean_path = "clean_movie.tif"
# Acquisition rate used to convert cutoff frequencies into sample counts and to
# build the FFT frequency axis — presumably frames per second (Hz); confirm.
SAMPLING_RATE = 1000

#### Load Data

In [None]:
print("Loading data...")
# Read the TIFF at `path` into an in-memory array.
movie = tifffile.imread(path)
# Flip the sign of every pixel value.
# NOTE(review): if imread returns an unsigned integer dtype, unary minus wraps
# modulo 2**bits instead of producing negative values — confirm the dtype, or
# cast to a float/signed type before negating.
movie = -movie

#### Cleaning Functions

In [None]:
def low_pass_filter(signal, low_freq_to_filter, sampling_rate=None):
    """Remove slow components from a trace and return the fast residual.

    Despite its name, this function returns the *high-pass* part of `signal`:
    a Savitzky-Golay smoothed copy (polynomial order 2) is computed over a
    window of roughly one period of `low_freq_to_filter` and subtracted from
    the input.

    Parameters
    ----------
    signal : array-like
        Trace to filter; the filter runs along the last axis.
    low_freq_to_filter : float
        Frequency whose period sets the smoothing window. Must be > 0.
    sampling_rate : float, optional
        Samples per unit time. Defaults to the notebook-level SAMPLING_RATE.

    Returns
    -------
    numpy.ndarray
        `signal` minus its smoothed version, same shape as `signal`.

    Raises
    ------
    ValueError
        If `low_freq_to_filter` is not positive, or the trace has fewer than
        three samples (too short for a second-order fit).

    Notes
    -----
    The window is forced to be odd and clamped to [3, len(signal)], because
    savgol_filter rejects windows longer than the data or not larger than the
    polynomial order. When clamping happens the effective cutoff differs from
    the one requested.
    """
    if low_freq_to_filter <= 0:
        raise ValueError("low_freq_to_filter must be positive")
    if sampling_rate is None:
        sampling_rate = SAMPLING_RATE

    polyorder = 2
    signal = np.asarray(signal)
    n_samples = signal.shape[-1]

    # Smallest and largest odd window lengths savgol_filter will accept here.
    min_window = polyorder + 1
    max_window = n_samples if n_samples % 2 == 1 else n_samples - 1
    if max_window < min_window:
        raise ValueError(
            f"signal needs at least {min_window} samples to be filtered, got {n_samples}"
        )

    # One period of the cutoff frequency, expressed in samples.
    smooth_time = 1 / low_freq_to_filter
    window_length = int(smooth_time * sampling_rate)
    if window_length % 2 == 0:
        window_length += 1
    # Both bounds are odd, so the clamped value stays odd.
    window_length = min(max(window_length, min_window), max_window)

    signal_smoothed = savgol_filter(signal, window_length=window_length, polyorder=polyorder)
    return signal - signal_smoothed

In [None]:
def compute_intensity(movie):
    """Return one value per frame: the mean of `movie` over axes 1 and 2 (axis 0 is kept)."""
    within_frame_axes = (1, 2)
    per_frame_mean = np.mean(movie, axis=within_frame_axes)
    return per_frame_mean

In [None]:
def detect_intensity_drops(intensity, window_length=11, polyorder=3, bad_frames_threshold=-3.5):
    """Return the indices where the intensity falls sharply below its local trend.

    The trend is a Savitzky-Golay fit of `intensity` (`window_length`,
    `polyorder`). A frame is reported when `intensity - trend` is strictly
    below `bad_frames_threshold`, which is an absolute value in the same
    units as `intensity`.
    """
    local_trend = savgol_filter(intensity, window_length=window_length, polyorder=polyorder)
    deviation = intensity - local_trend
    is_drop = deviation < bad_frames_threshold
    return np.nonzero(is_drop)[0]

def _plot_intensity_panel(ax, intensity, title, bad_frames=None):
    """Draw a mean-intensity trace on `ax`, optionally marking `bad_frames` with red stars."""
    ax.plot(np.arange(len(intensity)), intensity, 'k-')
    if bad_frames is not None:
        ax.plot(bad_frames, intensity[bad_frames], 'r*', label='Bad frames')
        ax.legend()
    ax.set_xlabel('Frame')
    ax.set_ylabel('Mean intensity')
    ax.set_title(title)


def clean_intensity_drops(movie, bad_frames_threshold=-3.5):
    """Delete frames whose mean intensity drops sharply, and plot a before/after check.

    Bad frames are found with detect_intensity_drops on the per-frame mean
    intensity and removed along axis 0 in a single pass. Detection is then run
    once more on the shortened movie; any frames still flagged are shown in the
    third panel and reported with a printed warning, but are not removed.

    Parameters
    ----------
    movie : numpy.ndarray
        Movie with frames along axis 0.
    bad_frames_threshold : float, optional
        Threshold forwarded to detect_intensity_drops (default -3.5).

    Returns
    -------
    numpy.ndarray
        Copy of `movie` without the frames flagged in the first pass.
    """
    fig, ax = plt.subplots(1, 3, figsize=(20, 4))

    # Panel 1: raw mean intensity of each frame.
    intensity = compute_intensity(movie)
    _plot_intensity_panel(ax[0], intensity, 'Mean intensity over time')

    # Panel 2: sudden drops in brightness --> noise.
    bad_frames = detect_intensity_drops(intensity, bad_frames_threshold=bad_frames_threshold)
    _plot_intensity_panel(ax[1], intensity, "Bad frame detection", bad_frames)

    movie_clean = np.delete(movie, bad_frames, axis=0)

    # Panel 3: re-run detection on the cleaned movie to see what is left.
    intensity_clean = compute_intensity(movie_clean)
    remaining_bad_frames = detect_intensity_drops(
        intensity_clean, bad_frames_threshold=bad_frames_threshold
    )
    _plot_intensity_panel(ax[2], intensity_clean, 'Mean intensity over time', remaining_bad_frames)

    plt.tight_layout()
    plt.show()

    # TODO choose bad_frames_threshold dynamically instead of relying on the default
    if len(remaining_bad_frames) > 0:
        print(f"Warning: {len(remaining_bad_frames)} bad frames remain after cleaning")

    return movie_clean

In [None]:
def compute_psd(signal):
    """Return the squared magnitude of the FFT of `signal`, for every frequency bin (no normalisation)."""
    return np.square(np.abs(np.fft.fft(signal)))

In [None]:
def regress_out_poly2(movie, intensity=None, scale_images=False):
    """Regress linear and quadratic time trends out of every pixel's time trace.

    Each pixel trace is fitted by least squares (via the pseudo-inverse) with a
    standardised linear-time regressor and a standardised quadratic-time
    regressor, plus `intensity` when given; the fitted part is subtracted.

    Parameters
    ----------
    movie : numpy.ndarray
        Array of shape (n_frames, n_row, n_col).
    intensity : array-like, optional
        Extra regressor with one value per frame. It is used as given.
    scale_images : bool, optional
        If True, show the fitted coefficient of each regressor as an image.

    Returns
    -------
    numpy.ndarray
        Residual movie, same shape as `movie`.

    Raises
    ------
    ValueError
        If `movie` is not 3-D, has fewer than 3 frames, or `intensity` does
        not have exactly one value per frame.

    Notes
    -----
    No constant regressor is included. Both time regressors are zero-mean, so
    when `intensity` is None each pixel's temporal mean stays in the residual.
    NOTE(review): `intensity` is not mean-centred here — confirm callers pass a
    (near) zero-mean trace, otherwise its coefficient also absorbs part of the
    pixel mean.
    """
    movie = np.asarray(movie)
    if movie.ndim != 3:
        raise ValueError(f"movie must be 3-D (frames, rows, cols), got {movie.ndim} dimensions")
    n_frames, n_row, n_col = movie.shape
    if n_frames < 3:
        # With fewer frames the standardised quadratic term has zero variance (division by 0).
        raise ValueError(f"movie needs at least 3 frames to fit a quadratic trend, got {n_frames}")

    # Create regressors: standardised time and standardised squared time.
    t = np.arange(n_frames)
    t_centered = (t - np.mean(t)) / np.std(t)
    t_quad = t_centered ** 2
    t_quad = (t_quad - np.mean(t_quad)) / np.std(t_quad)

    if intensity is None:
        regressors = [t_centered, t_quad]
    else:
        intensity = np.asarray(intensity)
        if intensity.shape != (n_frames,):
            raise ValueError(
                f"intensity must have shape ({n_frames},), got {intensity.shape}"
            )
        regressors = [intensity, t_centered, t_quad]
    X = np.vstack(regressors).T
    movie_flat = movie.reshape(n_frames, -1)

    # Fit linear regression to each pixel's trace (one column of coefficients per pixel).
    regression_coeff_per_pixel = np.linalg.pinv(X) @ movie_flat

    # Remove fitted trends from each trace.
    residuals = movie_flat - X @ regression_coeff_per_pixel
    movie_clean = residuals.reshape(movie.shape)

    if scale_images:
        n_regressors = X.shape[1]
        # Kept under its own name so the boolean `scale_images` flag is not overwritten.
        coeff_maps = regression_coeff_per_pixel.reshape(n_regressors, n_row, n_col)

        fig, axes = plt.subplots(1, n_regressors, figsize=(12, 4))
        for i, ax in enumerate(np.atleast_1d(axes)):
            im = ax.imshow(coeff_maps[i], cmap='gray')
            ax.set_title(f"Regressor {i + 1}")
            fig.colorbar(im, ax=ax)

        fig.suptitle("Scale Images")
        plt.show()

    return movie_clean

In [None]:
def choose_low_freq_to_filter(movie, initial_cutoff=3.0):
    """Interactively choose the cutoff frequency passed to low_pass_filter.

    Opens a figure with two panels — the power spectrum of the mean-intensity
    trace (raw vs. filtered) and the traces themselves — plus a slider ranging
    from 0.1 to 400 Hz in 0.1 steps. Moving the slider re-filters the trace and
    refreshes both panels.

    Parameters
    ----------
    movie : numpy.ndarray
        Movie whose per-frame mean intensity (compute_intensity) is analysed.
    initial_cutoff : float, optional
        Slider start value and cutoff used for the first plot.

    Returns
    -------
    float
        The slider's value at the time plt.show() returns.
    """
    intensity = compute_intensity(movie)
    n_frames = len(intensity)
    t = np.arange(n_frames)
    freq = np.fft.fftfreq(n_frames, d=1/SAMPLING_RATE)
    # First half of the FFT bins (the non-negative frequencies).
    n_half = n_frames // 2
    freq_half = freq[:n_half]

    # Set up figure and axes; leave room at the bottom for the slider.
    fig, (ax_psd, ax_time) = plt.subplots(2, 1, figsize=(14, 8))
    fig.subplots_adjust(bottom=0.25)

    # Initial filtered signal and PSD
    intensity_high_pass = low_pass_filter(intensity, low_freq_to_filter=initial_cutoff)
    intensity_psd = compute_psd(intensity)
    noise_psd = compute_psd(intensity_high_pass)

    # Plot PSD; only the filtered curve is kept, since it is the one updated later.
    ax_psd.semilogy(freq_half, intensity_psd[:n_half], label="Raw")
    filtered_psd_line, = ax_psd.semilogy(
        freq_half, noise_psd[:n_half], label=f"Filtered @ {initial_cutoff}Hz"
    )
    ax_psd.set_title('Intensity Power Spectrum')
    ax_psd.set_xlabel("Frequency (Hz)")
    ax_psd.set_ylabel("Power")
    ax_psd.legend()

    # Plot intensity traces
    ax_time.plot(t, intensity, 'k-', label='Original Intensity')
    filtered_trace_line, = ax_time.plot(t, intensity_high_pass, 'b-', label='High Pass Intensity')
    ax_time.set_title('Mean Intensity over time')
    ax_time.set_xlabel('Frame')
    ax_time.set_ylabel('Mean Intensity')
    ax_time.legend()

    # Slider setup
    ax_slider = fig.add_axes((0.25, 0.05, 0.5, 0.03))
    slider = Slider(ax_slider, 'Low freq cutoff (Hz)', 0.1, 400.0, valinit=initial_cutoff, valstep=0.1)

    def update(val):
        """Re-filter at the slider's current cutoff and refresh both panels."""
        low_freq = slider.val
        filtered = low_pass_filter(intensity, low_freq_to_filter=low_freq)
        new_psd = compute_psd(filtered)

        filtered_psd_line.set_ydata(new_psd[:n_half])
        filtered_psd_line.set_label(f"Filtered @ {low_freq:.1f}Hz")
        ax_psd.legend()

        filtered_trace_line.set_ydata(filtered)
        filtered_trace_line.set_label(f'High Pass Intensity @ {low_freq:.1f}Hz')
        ax_time.legend()

        # set_ydata does not touch the axis limits; recompute them so the
        # updated curves cannot end up outside the visible range.
        for axis in (ax_psd, ax_time):
            axis.relim()
            axis.autoscale_view()

        fig.canvas.draw_idle()

    slider.on_changed(update)
    plt.show()

    return slider.val

In [None]:
def clean_power_spectrum_noise(movie, low_freq_cutoff=None, drop_frames=10):
    """Regress the high-passed mean intensity and slow drift out of every pixel.

    The per-frame mean intensity is high-pass filtered with low_pass_filter,
    the first `drop_frames` frames are discarded from both the movie and that
    trace, and regress_out_poly2 then removes the trace together with linear
    and quadratic drift from each pixel.

    Parameters
    ----------
    movie : numpy.ndarray
        Movie with frames along axis 0.
    low_freq_cutoff : float, optional
        Cutoff passed to low_pass_filter. When None (default) it is chosen
        interactively with choose_low_freq_to_filter and printed, so the value
        can be passed explicitly on later runs to reproduce the result.
    drop_frames : int, optional
        Number of leading frames to discard (default 10).

    Returns
    -------
    numpy.ndarray
        Cleaned movie with `drop_frames` fewer frames than `movie`.

    Raises
    ------
    ValueError
        If `drop_frames` is negative.
    """
    if drop_frames < 0:
        # A negative value would silently keep only the last frames of the movie.
        raise ValueError("drop_frames must be non-negative")

    if low_freq_cutoff is None:
        low_freq_cutoff = choose_low_freq_to_filter(movie)
        print(f"Selected low-frequency cutoff: {low_freq_cutoff} Hz")

    intensity = compute_intensity(movie)
    intensity_high_pass = low_pass_filter(intensity, low_freq_to_filter=low_freq_cutoff)

    # Remove the first few frames where the filter doesn't work.
    # NOTE(review): only leading frames are dropped and the count does not depend
    # on the filter window — confirm this is enough for the cutoff in use.
    movie_clean = movie[drop_frames:]
    intensity_high_pass = intensity_high_pass[drop_frames:]

    # Regress out intensity, linear drift, and quadratic drift from each pixel's time trace
    movie_clean = regress_out_poly2(movie_clean, intensity=intensity_high_pass, scale_images=False)

    return movie_clean

In [None]:
def clean_movie_pipeline(movie_raw):
    """Run the full cleaning sequence on `movie_raw` and return the cleaned movie.

    Steps: clean_intensity_drops, then clean_power_spectrum_noise. Afterwards the
    mean-intensity traces of the raw and cleaned movies are plotted together and
    the standard deviation of the cleaned trace is printed.
    """
    without_drops = clean_intensity_drops(movie_raw)
    movie_clean = clean_power_spectrum_noise(without_drops)

    intensity_raw = compute_intensity(movie_raw)
    intensity_clean = compute_intensity(movie_clean)

    # Both cleaning steps remove frames, so the two traces differ in length and
    # are each drawn against their own frame index.
    fig, ax = plt.subplots(figsize=(12, 4))
    ax.plot(intensity_raw, label="Raw")
    ax.plot(intensity_clean, label="Drift Removed")
    ax.set_title("Drift Removal from Mean Intensity")
    ax.set_xlabel("Frame")
    ax.set_ylabel("Mean Intensity")
    ax.legend()
    fig.tight_layout()
    plt.show()

    print(f"Standard deviation of corrected trace: {np.std(intensity_clean):.4f}")

    return movie_clean

#### Clean Data

In [None]:
# Run the whole cleaning pipeline on the loaded movie.
# NOTE: this rebinds `movie` to the cleaned result, so re-running this cell
# cleans the already-cleaned movie; re-run the load cell first to start over.
movie = clean_movie_pipeline(movie)