In [None]:
import logging

import matplotlib.pyplot as plt
import numpy as np
import scipy.optimize
import scipy.signal
import scipy.stats

import param_estimator

# Dark theme for every figure drawn in this notebook.
plt.style.use(style="dark_background")

# Root logging at INFO; the "activation_learner" logger is turned up to DEBUG.
logging.basicConfig(level=logging.INFO)
logging.getLogger(name="activation_learner").setLevel(level=logging.DEBUG)

In [None]:
# make synthetic data
# A seeded generator makes the noise -- and therefore every downstream fit --
# reproducible under Restart Kernel -> Run All.
rng = np.random.default_rng(0)

tau = np.linspace(0, 1, 1000)

# time remap: a linear ramp plus small uniform noise and a slow sinusoidal wobble
timeremap = 2 * tau + 8
timeremap += (rng.random(len(tau)) - 0.5) * 0.1
timeremap += 0.5 * np.sin(tau * 5)

# volume: a fade-in / hold / fade-out envelope plus a sinusoid and uniform noise.
# The index ranges below assume the 1000-sample `tau` defined above.
volume = np.zeros_like(tau)
volume[500:800] = 1
volume[500:600] = np.linspace(0, 1, 100)  # fade in
volume[800:850] = np.linspace(1, 0, 50)  # fade out
volume += 0.8 * np.sin(tau * 10)
volume += (rng.random(len(tau)) - 0.5) * 0.5

data_fig, data_ax = plt.subplots()
data_ax.plot(tau, timeremap, label="timeremap")
data_ax.plot(tau, volume, label="volume")
data_ax.set_xlabel("tau")
data_ax.set_title("synthetic input")
data_ax.legend();

In [None]:
def dynamic_threshold_moving_average(signal, window_size):
    """Binarize a signal against a moving average of its own magnitude.

    Each sample is compared with the mean of ``abs(signal)`` over a centred
    window of ``window_size`` samples, computed with ``np.convolve`` in
    ``"same"`` mode (samples beyond the edges contribute zero).

    Returns an integer array holding 1 where ``signal`` is strictly greater
    than that local mean and 0 elsewhere. Because the local mean is built from
    magnitudes, negative samples always map to 0.
    """
    box_kernel = np.full(window_size, 1 / window_size)
    local_mean = np.convolve(np.abs(signal), box_kernel, mode="same")
    above = signal > local_mean
    return np.where(above, 1, 0)


def _ideal_gain(tau, tau0, tau1, tau2, tau3):
    """Evaluate a trapezoidal gain envelope at ``tau``.

    The envelope is 0 before ``tau0``, rises linearly to 1 at ``tau1``, stays
    at 1 until ``tau2``, falls linearly to 0 at ``tau3`` and is 0 afterwards.
    The breakpoints are meant to be ordered ``tau0 <= tau1 <= tau2 <= tau3``;
    nothing here (nor in the bounded fit that uses it) enforces that ordering.
    """
    return np.piecewise(
        tau,
        [
            tau < tau0,
            (tau >= tau0) & (tau < tau1),
            (tau >= tau1) & (tau < tau2),
            (tau >= tau2) & (tau < tau3),
            tau >= tau3,
        ],
        [
            0,
            lambda t: (t - tau0) / (tau1 - tau0),
            1,
            lambda t: 1 - (t - tau2) / (tau3 - tau2),
            0,
        ],
    )


def _longest_true_run(condition):
    """Return the slice covering the longest contiguous run of True in ``condition``.

    Raises ValueError if ``condition`` contains no True values.
    """
    # clump_masked reports contiguous masked stretches, so mask exactly where
    # the condition holds.
    runs = np.ma.clump_masked(np.ma.masked_array(condition, mask=condition))
    if not runs:
        raise ValueError("no samples exceed the playing threshold")
    return max(runs, key=lambda run: run.stop - run.start)


def estimate_highparams(
    tau, volume, timeremap, filter_size=0.1, plot=False, threshold=0.5
):
    """Estimate fade-in/fade-out timing and the track start.

    Steps:
      1. Normalize ``volume`` by its maximum and median-filter it over a
         window of roughly ``filter_size`` (in ``tau`` units).
      2. Take the longest contiguous run where the filtered volume exceeds
         ``threshold`` as the playing region and pad it by half its length on
         each side.
      3. Fit a trapezoidal 0 -> 1 -> 0 gain envelope to the normalized volume
         over the padded region.
      4. Linearly regress ``timeremap`` on ``tau`` over the fully-playing part
         (after the fade-in ends and before the fade-out starts) and solve for
         the ``tau`` where that line reaches 0.

    Parameters
    ----------
    tau : array_like
        Sample positions. Assumed uniformly spaced: only ``tau[1] - tau[0]``
        is used to size the median filter.
    volume : array_like
        Volume samples, same length as ``tau``; must have a positive maximum.
    timeremap : array_like
        Time-remap samples, same length as ``tau``.
    filter_size : float
        Median-filter window length in ``tau`` units.
    plot : bool
        If True, draw diagnostic plots and return the figure.
    threshold : float
        Level on the normalized, filtered volume above which a sample counts
        as playing.

    Returns
    -------
    tuple
        ``(track_start, fadein_start, fadein_stop, fadein_slope,
        fadeout_start, fadeout_stop, fadeout_slope, fig)``. ``track_start`` is
        the ``tau`` at which the fitted time-remap line crosses 0; slopes are
        gain change per unit ``tau``; ``fig`` is None unless ``plot`` is True.

    Raises
    ------
    ValueError
        If the volume has no positive maximum, no sample exceeds
        ``threshold``, fewer than two samples lie between the fitted fades, or
        the fitted time-remap line is flat. Errors from
        ``scipy.optimize.curve_fit`` propagate unchanged.
    """
    tau = np.asarray(tau, dtype=float)
    volume = np.asarray(volume, dtype=float)
    timeremap = np.asarray(timeremap, dtype=float)

    # normalize so the data, the threshold and the unit-height envelope share
    # the same scale
    peak = volume.max()
    if not peak > 0:
        raise ValueError("volume must have a positive maximum to be normalized")
    volume_norm = volume / peak

    # median filter over roughly `filter_size` tau units
    kernel_size = int(filter_size / (tau[1] - tau[0]))
    if kernel_size % 2 == 0:
        kernel_size += 1  # medfilt requires an odd kernel size
    volume_filt = scipy.signal.medfilt(volume_norm, kernel_size=kernel_size)

    # longest contiguous stretch where the filtered volume is above threshold
    playing = _longest_true_run(volume_filt > threshold)

    # pad by half the playing length on each side so the fit also sees the
    # quiet samples around both fades (stop_idx is inclusive)
    half_len = (playing.stop - playing.start) // 2
    start_idx = max(playing.start - half_len, 0)
    stop_idx = min(playing.stop + half_len, len(volume) - 1)
    fit_region = slice(start_idx, stop_idx + 1)
    tstart, tstop = tau[start_idx], tau[stop_idx]

    # fit the ideal gain to the normalized volume over the padded region; the
    # initial guess spreads the four breakpoints evenly across it
    span = tstop - tstart
    p, _ = scipy.optimize.curve_fit(
        _ideal_gain,
        tau[fit_region],
        volume_norm[fit_region],
        p0=[tstart, tstart + span / 3, tstart + span * 2 / 3, tstop],
        bounds=(tstart, tstop),
    )

    # fade bounds and slopes
    fadein_start, fadein_stop, fadeout_start, fadeout_stop = p
    fadein_slope = 1 / (fadein_stop - fadein_start)
    fadeout_slope = -1 / (fadeout_stop - fadeout_start)

    print(f"{fadein_start=:.2f} {fadein_stop=:.2f} {fadein_slope=:.2f}")
    print(f"{fadeout_start=:.2f} {fadeout_stop=:.2f} {fadeout_slope=:.2f}")

    # use only the part of the time remap where the track is fully playing
    mask = (tau > fadein_stop) & (tau < fadeout_start)
    if np.count_nonzero(mask) < 2:
        raise ValueError(
            "fewer than two samples between fade-in and fade-out; "
            "cannot fit the time remap"
        )
    timeremap_slope, timeremap_intercept, _, _, _ = scipy.stats.linregress(
        tau[mask], timeremap[mask]
    )
    if timeremap_slope == 0:
        raise ValueError("time remap is flat over the playing region")

    # tau at which the fitted time remap reaches 0
    track_start = -timeremap_intercept / timeremap_slope
    print(f"{track_start=:.2f}")

    fig = None
    if plot:
        fig, (ax_volume, ax_remap) = plt.subplots(2, 1, sharex=True)
        ax_volume.plot(tau, volume_norm, label="norm", alpha=0.5)
        ax_volume.plot(tau, volume_filt, label="filt")
        ax_volume.plot(tau, _ideal_gain(tau, *p), label="fit")
        ax_volume.set_title("volume")
        ax_volume.legend()

        ax_remap.plot(tau, timeremap, label="input", alpha=0.5)
        ax_remap.plot(tau[mask], timeremap[mask], label="masked")
        ax_remap.plot(tau, timeremap_slope * tau + timeremap_intercept, label="fit")
        ax_remap.set_title("time remap")
        ax_remap.set_xlabel("tau")
        ax_remap.legend()
        fig.tight_layout()

    return (
        track_start,
        fadein_start,
        fadein_stop,
        fadein_slope,
        fadeout_start,
        fadeout_stop,
        fadeout_slope,
        fig,
    )


# Run the estimator on the synthetic data with diagnostic plots enabled.
(
    track_start,
    fadein_start,
    fadein_stop,
    fadein_slope,
    fadeout_start,
    fadeout_stop,
    fadeout_slope,
    fig,
) = estimate_highparams(tau, volume, timeremap, plot=True)
# Figure.show() only warns on non-GUI backends (such as a notebook's inline
# backend); plt.show() renders through whichever backend is active.
plt.show()