In [None]:
import humanize

from astrocast.preparation import Delta, IO
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

In [None]:
# Demo recording used throughout this notebook.
test_path = Path("./quickstart.h5")
assert test_path.exists(), f"quickstart data not found: {test_path.resolve()}"

# Pixel coordinates inspected below; each pair indexes axes 1 and 2 as ``arr[:, x, y]``.
pixels = [(167, 153), (199, 306), (214, 109), (337, 72)]

In [None]:
io = IO()
inf = io.load(
    test_path,
    lazy=False,  # presumably reads the dataset into memory right away
    loc="inf/ch0",
)
inf.shape

In [None]:
# Delta object wrapping the raw recording; ``df`` holds the result of its "dF" method.
delta = Delta(inf)
df = delta.run(
    window=10,
    method="dF",
    compute=True,
)

In [None]:
# Compare the selected pixel traces before and after the Delta transform.
fig = delta.plot(
    pixels=pixels,
    show_original=True,
    separate_panels=True,
    twin_y_axis=True,
    figsize=(20, 10),
)

In [None]:
from scipy import signal

In [None]:
import random
import scipy.ndimage as ndi

# Gaussian pre-filter parameters (applied along every axis of ``inf``).
sigma = 2
radius = 3
inf_ = ndi.gaussian_filter(inf, sigma=sigma, radius=radius)

# Seeded generator so the peak-span colours are reproducible between runs.
rng = random.Random(0)

# Row 0 uses the raw traces, row 1 the smoothed ones. ``fig`` marks the detected
# peak spans; ``fig_2`` shows the trace with peaks removed and linearly interpolated.
fig, axx = plt.subplots(2, len(pixels), figsize=(25, 4), squeeze=False)
fig_2, axx_2 = plt.subplots(2, len(pixels), figsize=(25, 4), squeeze=False)
for ii, arr in enumerate([inf, inf_]):
    for i, (x, y) in enumerate(pixels):
        ax = axx[ii, i]
        ax_2 = axx_2[ii, i]  # matching axis in fig_2
        xy = arr[:, x, y]

        peak_x, res = signal.find_peaks(xy, prominence=0.1, wlen=100, distance=10, width=5, rel_height=0.9)
        ax.plot(xy, linestyle="--", color="black")
        ax.set_title(f"pixel ({x}, {y})")

        # Float copy of the trace so detected peaks can be blanked out with NaN
        # (an integer array cannot hold NaN).
        xy_copy = np.array(xy, dtype=float)
        for left, right in zip(res["left_ips"], res["right_ips"]):
            random_color = "#" + "".join(rng.choice("0123456789ABCDEF") for _ in range(6))
            ax.axvspan(left, right, color=random_color, alpha=0.3)
            # int() truncates the fractional peak edges toward zero.
            xy_copy[int(left):int(right)] = np.nan

        # Interpolate the removed sections from the remaining points. Dedicated names
        # keep the pixel coordinates ``x``/``y`` from being overwritten.
        keep_idx = np.flatnonzero(~np.isnan(xy_copy))
        if keep_idx.size == 0:
            continue  # every sample lies inside a peak; nothing to interpolate from
        f = np.interp(np.arange(len(xy_copy)), keep_idx, xy_copy[keep_idx])

        # plot the interpolated trace in fig_2
        ax_2.plot(f, linestyle="--", color="black")
        ax_2.set_title(f"pixel ({x}, {y}), peaks removed")



In [None]:
from scipy.interpolate import CubicSpline, Rbf
from numpy.polynomial.polynomial import Polynomial
from scipy.interpolate import pade  # works very poorly
import seaborn as sns

# Candidate gap-filling methods. Each maps (x, xp, fp) -> values at x, following the
# np.interp call convention. A Padé approximation (scipy.interpolate.pade) was also
# tried and dropped because it works very poorly.
interpolations = {
    "Linear Interpolation":                      np.interp,
    "Cubic Spline":                              lambda x, xp, fp: CubicSpline(xp, fp)(x),
    "Polynomial Interpolation":                  lambda x, xp, fp: Polynomial.fit(xp, fp, deg=4)(x),
    "Radial basis function (RBF) interpolation": lambda x, xp, fp: Rbf(xp, fp)(x),
    }

# squeeze=False keeps a 2-D axes array even when there is a single pixel.
fig, axx = plt.subplots(len(pixels), 1, figsize=(20, 3 * len(pixels)), squeeze=False)
colors = sns.color_palette("hls", len(interpolations))

for i, (x, y) in enumerate(pixels):

    ax = axx[i, 0]
    xy = inf_[:, x, y]
    ax.plot(xy, color="black", linestyle="-")

    peak_x, res = signal.find_peaks(xy, prominence=0.1, wlen=100, distance=10, width=5, rel_height=0.9)

    # Float copy of the trace so the detected peaks can be blanked out with NaN.
    xy_copy = np.array(xy, dtype=float)
    for left, right in zip(res["left_ips"], res["right_ips"]):
        # int() truncates the fractional peak edges toward zero.
        xy_copy[int(left):int(right)] = np.nan

    # get the valid (not NaN) points from the copy
    valid_x = np.flatnonzero(~np.isnan(xy_copy))
    valid_y = xy_copy[valid_x]

    # Evaluation positions are the same for every method; build them once.
    samples = np.arange(len(xy_copy))
    for color, (name, interp) in zip(colors, interpolations.items()):
        f_xy = interp(samples, valid_x, valid_y)
        ax.plot(f_xy, linestyle="--", color=color, label=name)

    ax.set_title(f"pixel ({x}, {y})")
    ax.legend()

In [None]:
from scipy.interpolate import RBFInterpolator


def get_interpolated(x_obs, y_obs, X, neighbors=None, smoothing=0.0, kernel='thin_plate_spline', epsilon=None,
                     degree=None):
    """Evaluate a 1-D RBF interpolant of the observations at new positions.

    Args:
        x_obs: 1-D sequence of observed sample positions.
        y_obs: Observed values, one per entry of ``x_obs``.
        X: 1-D sequence of positions to evaluate at (e.g. ``range(len(trace))``).
        neighbors, smoothing, kernel, epsilon, degree: Forwarded to
            ``scipy.interpolate.RBFInterpolator``.

    Returns:
        Array with one interpolated value per entry of ``X`` (1-D for 1-D ``y_obs``).
    """
    # RBFInterpolator expects coordinates shaped (n_points, n_dims); here n_dims == 1.
    x_obs = np.asarray(x_obs).reshape(-1, 1)
    X = np.asarray(X).reshape(-1, 1)
    # Values stay as given: 1-D values give a 1-D result directly, so no squeeze is
    # needed (squeezing turned a single evaluation point into a 0-d array).
    y_obs = np.asarray(y_obs)

    interpolator = RBFInterpolator(x_obs, y_obs, neighbors=neighbors, smoothing=smoothing, kernel=kernel,
                                   epsilon=epsilon, degree=degree)
    return interpolator(X)


# Per pixel: the left panel overlays the raw trace, the smoothed trace and an RBF
# baseline fitted through the smoothed trace with its peaks removed; the right panel
# shows raw minus baseline (left y-axis) against the raw trace (twin right y-axis).
fig, axx = plt.subplots(len(pixels), 2, figsize=(20, 3 * len(pixels)), squeeze=False)

for i, (x, y) in enumerate(pixels):

    ax0, ax1 = axx[i, :]
    xy = inf[:, x, y]
    xy_smooth = inf_[:, x, y]

    # find peaks on the smoothed trace
    peak_x, res = signal.find_peaks(xy_smooth, prominence=0.1, wlen=100, distance=10, width=5, rel_height=0.95)

    # Float copy of the smoothed trace with the detected peaks blanked out (NaN).
    xy_copy = np.array(xy_smooth, dtype=float)
    for left, right in zip(res["left_ips"], res["right_ips"]):
        # int() truncates the fractional peak edges toward zero.
        xy_copy[int(left):int(right)] = np.nan

    # positions at which the baseline is evaluated
    X = np.arange(len(xy_copy))

    # get the valid (not NaN) points from the copy
    valid_x = np.flatnonzero(~np.isnan(xy_copy))
    valid_y = xy_copy[valid_x]

    y_interpolated = get_interpolated(valid_x, valid_y, X, neighbors=50, smoothing=1)

    # plot curves
    ax0.plot(xy_smooth, color="black", linestyle="-", alpha=0.6, label="smoothed")
    ax0.plot(xy, color="black", linestyle="--", alpha=0.6, label="raw")
    ax0.plot(y_interpolated, color="red", linestyle="--", label="baseline")
    ax0.set_title(f"pixel ({x}, {y})")
    ax0.legend()
    ax0.grid(False)

    # plot delta against the raw trace on a twin axis
    ax1t = ax1.twinx()
    ax1t.grid(False)
    ax1t.plot(xy, linestyle="-", color="black", alpha=0.5)

    dxy = xy - y_interpolated
    ax1.plot(dxy, linestyle="--", color="black")
    ax1.set_title("raw - baseline")
    ax1.grid(False)


In [None]:
# %timeit inf2[:, x, y] = get_subtracted(inf[:, x, y], inf_[:, x, y], neighbors=50, kernel="thin_plate_spline")

In [None]:
# %timeit inf2[:, x, y] = get_subtracted(inf[:, x, y], inf_[:, x, y], neighbors=50, kernel="thin_plate_spline", smoothing=1)

In [None]:
from tqdm import tqdm
from scipy.interpolate import RBFInterpolator


def get_interpolated(x_obs, y_obs, X, neighbors=None, smoothing=0.0, kernel='thin_plate_spline', epsilon=None,
                     degree=None):
    """Evaluate a 1-D RBF interpolant of the observations at new positions.

    Args:
        x_obs: 1-D sequence of observed sample positions.
        y_obs: Observed values, one per entry of ``x_obs``.
        X: 1-D sequence of positions to evaluate at (e.g. ``range(len(trace))``).
        neighbors, smoothing, kernel, epsilon, degree: Forwarded to
            ``scipy.interpolate.RBFInterpolator``.

    Returns:
        Array with one interpolated value per entry of ``X`` (1-D for 1-D ``y_obs``).
    """
    # RBFInterpolator expects coordinates shaped (n_points, n_dims); here n_dims == 1.
    x_obs = np.asarray(x_obs).reshape(-1, 1)
    X = np.asarray(X).reshape(-1, 1)
    # Values stay as given: 1-D values give a 1-D result directly, so no squeeze is
    # needed (squeezing turned a single evaluation point into a 0-d array).
    y_obs = np.asarray(y_obs)

    interpolator = RBFInterpolator(x_obs, y_obs, neighbors=neighbors, smoothing=smoothing, kernel=kernel,
                                   epsilon=epsilon, degree=degree)
    return interpolator(X)


def get_subtracted(xy, xy_smooth, prominence=0.1, wlen=100, distance=10, width=5, rel_height=0.95,
                   neighbors=50, kernel="thin_plate_spline", **kwargs):
    """Subtract a peak-free RBF baseline from a single trace.

    Peaks are detected on ``xy_smooth`` with ``signal.find_peaks``. The samples in each
    peak's ``[int(left_ips), int(right_ips))`` span are dropped, the remaining smoothed
    samples are interpolated with ``get_interpolated``, and that baseline is subtracted
    from ``xy``.

    Args:
        xy: 1-D trace to correct.
        xy_smooth: Smoothed version of ``xy`` (same length), used both for peak
            detection and for fitting the baseline.
        prominence, wlen, distance, width, rel_height: Forwarded to ``signal.find_peaks``.
        neighbors, kernel, **kwargs: Forwarded to ``get_interpolated`` (e.g. ``smoothing``).

    Returns:
        ``xy`` minus the interpolated baseline.

    Note:
        If (almost) every sample falls inside a peak, too few points remain and
        scipy's RBFInterpolator raises.
    """
    _, res = signal.find_peaks(xy_smooth, prominence=prominence, wlen=wlen, distance=distance, width=width,
                               rel_height=rel_height)

    # Float copy of the smoothed trace with peaks blanked out (integer arrays cannot hold NaN).
    xy_copy = np.array(xy_smooth, dtype=float)
    for left, right in zip(res["left_ips"], res["right_ips"]):
        # int() truncates the fractional peak edges toward zero.
        xy_copy[int(left):int(right)] = np.nan

    # positions at which the baseline is evaluated
    X = range(len(xy_copy))

    # get the valid (not NaN) points from the copy
    valid_x = np.flatnonzero(~np.isnan(xy_copy))
    valid_y = xy_copy[valid_x]

    y_interpolated = get_interpolated(valid_x, valid_y, X, neighbors=neighbors, kernel=kernel, **kwargs)
    return xy - y_interpolated


# Use a floating output dtype so the float differences are not truncated when ``inf``
# is an integer array (float32/float64 inputs keep their dtype).
inf2 = np.zeros(inf.shape, dtype=np.result_type(inf.dtype, np.float32))
dim1, dim2, dim3 = inf.shape
for x in tqdm(range(dim2)):
    for y in range(dim3):
        inf2[:, x, y] = get_subtracted(inf[:, x, y], inf_[:, x, y])

# Remove any previous output first, as the other save cells in this notebook do.
path = Path("inf_new.h5")
if path.exists():
    path.unlink()

io = IO()
io.save(path, data=inf, loc="inf")
io.save(path, data=inf2, loc="delta")


In [None]:
def remove_peaks(arr, prominence=0.1, wlen=100, distance=10, width=5, rel_height=0.95):
    """Blank out detected peaks with NaN, in place, for every pixel trace of ``arr``.

    Each trace ``arr[:, x, y]`` is searched with ``signal.find_peaks``, and the samples
    in each peak's ``[int(left_ips), int(right_ips))`` span are set to NaN.

    Args:
        arr: 3-D floating-point array indexed as ``arr[:, x, y]``; modified in place.
        prominence, wlen, distance, width, rel_height: Forwarded to ``signal.find_peaks``.

    Raises:
        TypeError: If ``arr`` does not have a floating dtype (NaN cannot be stored).
    """
    if not np.issubdtype(arr.dtype, np.floating):
        raise TypeError(f"remove_peaks needs a floating-point array to store NaN, got dtype {arr.dtype}")

    _, n_x, n_y = arr.shape
    for x in tqdm(range(n_x)):
        for y in range(n_y):

            xy = arr[:, x, y]

            # find peaks
            _, res = signal.find_peaks(xy, prominence=prominence, wlen=wlen, distance=distance, width=width,
                                       rel_height=rel_height)

            # For a NumPy array ``xy`` is a view (not a copy), so this already writes into ``arr``.
            for left, right in zip(res["left_ips"], res["right_ips"]):
                # int() truncates the fractional peak edges toward zero.
                xy[int(left):int(right)] = np.nan

            # Explicit write-back keeps this correct for array-likes whose indexing returns a copy.
            arr[:, x, y] = xy


# remove_peaks overwrites peak samples of its argument in place, so run it on a copy
# and keep the smoothed ``inf_`` intact.
inf2 = inf_.copy()
%time remove_peaks(inf2)


In [None]:
path = Path("inf_new.h5")
path.unlink(missing_ok=True)  # start from a fresh output file

io = IO()
io.save(path, data=inf2, loc="inf")

In [None]:
# NOTE(review): this fresh copy replaces the NaN-masked ``inf2`` produced by
# remove_peaks in an earlier cell, and nothing in this cell reads it. It looks like a
# leftover that only costs memory; confirm no later cell relies on it before removing.
inf2 = inf_.copy()

from skimage.transform import rescale


def downsample_video(video_data, scale_factor):
    """
    Rescales every frame of a video independently, leaving the number of frames unchanged.

    Args:
        video_data: 3D numpy array (time, x, y); iterated frame by frame along axis 0.
        scale_factor: Factor applied to the spatial axes of each frame (e.g. 0.5 halves them).

    Returns:
        3D numpy array of the rescaled frames, stacked along axis 0.
    """
    # anti_aliasing smooths each frame before it is shrunk.
    frames = [rescale(frame, scale_factor, anti_aliasing=True) for frame in video_data]
    return np.array(frames)


def rescale_3d(data, scale_factor, anti_aliasing=True, preserve_range=False):
    """
    Rescales a 3D array along all three axes at once with ``skimage.transform.rescale``.

    Args:
        data: 3D numpy array, e.g. a (time, x, y) video. Every axis is rescaled,
            including time.
        scale_factor: Either a single float applied to all three axes, or a tuple of
            three floats, one per axis in the order of ``data``'s dimensions.
        anti_aliasing: Forwarded to ``rescale``; smooths the data before downsampling.
        preserve_range: Forwarded to ``rescale``. With the default (False), integer
            input is converted to skimage's normalized float range; pass True to keep
            the original intensity values.

    Returns:
        3D float numpy array with the rescaled data.
    """
    return rescale(data, scale_factor, anti_aliasing=anti_aliasing, preserve_range=preserve_range)


# Downscale the smoothed recording: the scalar scale_factor 0.25 is applied to every
# axis of ``inf_`` (time included).
# NOTE(review): presumably done so the 3D RBF interpolation below stays affordable.
# NOTE(review): the .copy() is likely redundant if rescale leaves its input untouched; confirm.
inf_small = rescale_3d(inf_.copy(), scale_factor=0.25)
display(inf_small.shape)

# remove_peaks blanks peaks with NaN in place, so run it on a copy and keep
# ``inf_small`` as the unmodified reference.
inf_small_nan = inf_small.copy()
%time remove_peaks(inf_small_nan)

In [None]:
import logging
import time
from scipy.interpolate import RBFInterpolator
import numpy as np
import humanize


def get_interpolated(xyz_obs, values, xyz_new, neighbors=None, smoothing=0.0, kernel='thin_plate_spline', epsilon=None,
                     degree=None):
    """
    Interpolates the values at new coordinates using RBF Interpolator.

    Args:
        xyz_obs: Coordinates of observed data points, shape [n_points, n_dims]
            (n_dims == 3 for (t, x, y) voxels). A 1-D sequence is treated as 1-D
            coordinates, i.e. shape [n_points, 1].
        values: Values at the observed data points (shape: [n_points]).
        xyz_new: Coordinates where interpolation is needed, shape [n_new_points, n_dims];
            1-D input is handled like ``xyz_obs``.
        neighbors, smoothing, kernel, epsilon, degree: Parameters for RBFInterpolator.
            ``smoothing`` defaults to 0.0 (exact interpolation), matching RBFInterpolator's
            own default; RBFInterpolator does not accept ``None`` here.

    Returns:
        Interpolated values at xyz_new (shape [n_new_points] for 1-D ``values``).
    """
    xyz_obs = np.asarray(xyz_obs)
    xyz_new = np.asarray(xyz_new)
    # RBFInterpolator requires 2-D coordinate arrays. Promote 1-D positions to a single
    # column so this definition also serves the (x_obs, y_obs, X) single-trace calls.
    if xyz_obs.ndim == 1:
        xyz_obs = xyz_obs[:, np.newaxis]
    if xyz_new.ndim == 1:
        xyz_new = xyz_new[:, np.newaxis]

    interpolator = RBFInterpolator(xyz_obs, values, neighbors=neighbors, smoothing=smoothing, kernel=kernel,
                                   epsilon=epsilon, degree=degree)
    return interpolator(xyz_new)


# Treat every voxel of the downscaled, peak-removed recording ``inf_small_nan`` as a
# sample at its (t, x, y) index, and fill in the NaN voxels from the remaining ones.
logging.info("creating coordinates")
# (n_voxels, 3) integer index grid in C order, so rows line up with reshape(-1) below.
coordinates = np.indices(inf_small.shape).reshape(3, -1).T
values = inf_small_nan.reshape(-1)

# Filter out the NaN values
logging.info("find NaN values")
valid_mask = ~np.isnan(values)
xyz_obs = coordinates[valid_mask]
values_obs = values[valid_mask]

# Evaluate at every voxel (valid ones included). With smoothing=1 the result is a
# smoothed baseline rather than an exact copy of the observed values. For
# thin_plate_spline, scipy documents epsilon as equivalent to rescaling ``smoothing``.
logging.info("interpolate values")
t0 = time.perf_counter()
interpolated_values = get_interpolated(xyz_obs, values_obs, coordinates, neighbors=50, smoothing=1,
                                       kernel='thin_plate_spline', epsilon=2)
display(humanize.naturaldelta(time.perf_counter() - t0))

# Reshape the interpolated values back to the recording's shape
logging.info("reshaping")
inf_interpolated = interpolated_values.reshape(inf_small_nan.shape)

# inf_interpolated is the 3D baseline; ``inf_small - inf_interpolated`` is saved as "delta".
path = Path("inf_new.h5")
if path.exists():
    path.unlink()

io = IO()
io.save(path, data=inf_small_nan, loc="inf")
io.save(path, data=inf_interpolated, loc="inf_int")
io.save(path, data=inf_small, loc="inf_ref")
io.save(path, data=inf_small - inf_interpolated, loc="delta")


In [None]:
from skimage.transform import rescale, resize

# Upsample the low-resolution baseline to the full recording shape and subtract it
# from the raw data.
baseline = resize(inf_interpolated, output_shape=inf.shape, anti_aliasing=True)
# NOTE: this rebinds ``delta``, which earlier in the notebook held the Delta object.
delta = inf - baseline

# Delete the previous output only after the new result has been computed.
path = Path("inf_new.h5")
if path.exists():
    path.unlink()

io = IO()
io.save(path, data=inf, loc="inf")
io.save(path, data=delta, loc="delta")

In [None]:
import importlib
import astrocast.preparation as prep

# Re-execute astrocast.preparation so local edits to it are picked up without restarting
# the kernel. Names imported earlier via ``from astrocast.preparation import ...``
# (Delta, IO) still refer to the pre-reload objects; use ``prep.<name>`` for the new ones.
importlib.reload(prep)

# NOTE(review): np.array(1) appears to be a placeholder used only to obtain a Delta
# instance; the real data (``inf``) is passed to _subtract_delta_rbf, which by naming
# convention is private API. Confirm the intended public entry point.
delta_obj = prep.Delta(np.array(1))
subtr = delta_obj._subtract_delta_rbf(inf, max_chunk_size_mb=32)