In [None]:
import numpy as np
import warnings
from IPython.display import Audio, display
from scipy.io import wavfile
import matplotlib.pyplot as plt
from scipy.signal import find_peaks


def envelope_follower(signal, fs, attack=0.001, release=0.05):
    """Track the amplitude envelope of *signal* with an attack/release follower.

    Each sample of the rectified input ``|signal[n]|`` is blended into the
    previous envelope value with a one-pole coefficient: the attack
    coefficient while the input rises above the envelope, the release
    coefficient otherwise.

    Parameters
    ----------
    signal : array_like
        1-D input samples.
    fs : float
        Sample rate in Hz.
    attack, release : float
        Time constants in seconds; must be > 0.

    Returns
    -------
    numpy.ndarray
        Envelope with the same length as *signal*. Floating-point input keeps
        its dtype; integer input yields float64 so the smoothing is not
        truncated.
    """
    signal = np.asarray(signal)
    # An integer output array would truncate every smoothed value, and np.abs
    # overflows on the most negative integer, so rectify in a float dtype.
    dtype = signal.dtype if np.issubdtype(signal.dtype, np.floating) else np.float64
    rectified = np.abs(signal.astype(dtype, copy=False))
    env = np.empty_like(rectified)
    if len(rectified) == 0:
        return env

    # One-pole smoothing coefficients exp(-1 / (fs * tau)).
    alpha_a = np.exp(-1.0 / (fs * attack))
    alpha_r = np.exp(-1.0 / (fs * release))

    env[0] = rectified[0]  # start the envelope at the first rectified sample
    for n in range(1, len(rectified)):
        s = rectified[n]
        alpha = alpha_a if s > env[n - 1] else alpha_r
        env[n] = alpha * env[n - 1] + (1 - alpha) * s
    return env

def displace_transients(signal, peaks, shift_samples=200, win_size=400):
    """Move the material around each peak by ``shift_samples`` samples.

    For every index ``p`` in *peaks*, the span ``[p - win_size, p + win_size)``
    (clipped to the signal) is Hann-windowed. That span is faded out of its
    original location by multiplying with ``1 - window``. The windowed excerpt
    is then added back ``shift_samples`` later, or earlier when the shift is
    negative. Any part of a shifted excerpt that would land outside the
    signal is discarded rather than clamped.

    Parameters
    ----------
    signal : array_like
        1-D input samples.
    peaks : iterable of int
        Sample indices around which to move material.
    shift_samples : int
        Displacement in samples; may be negative.
    win_size : int
        Half-width of the excerpt taken around each peak.

    Returns
    -------
    numpy.ndarray
        Processed copy of *signal*. Floating-point input keeps its dtype;
        integer input yields float64.
    """
    signal = np.asarray(signal)
    # Work on a float copy: the in-place window multiply fails on integer arrays.
    dtype = signal.dtype if np.issubdtype(signal.dtype, np.floating) else np.float64
    out = signal.astype(dtype, copy=True)
    n_out = len(out)

    for p in peaks:
        start = max(0, p - win_size)
        end = min(len(signal), p + win_size)
        # The excerpt comes from the unmodified input, so overlapping peaks
        # each paste their original material.
        transient = signal[start:end]

        # Apply windowing to soften edges
        window = np.hanning(len(transient))
        transient_win = transient * window

        # Silence original region (crossfade out)
        out[start:end] *= (1 - window)

        # Paste shifted transient (crossfade in), cropping whatever falls
        # before sample 0 or past the end instead of clamping its position.
        dst = start + shift_samples
        src_lo = max(0, -dst)
        src_hi = min(len(transient_win), n_out - dst)
        if src_hi > src_lo:
            out[dst + src_lo:dst + src_hi] += transient_win[src_lo:src_hi]

    return out

def mix(dry, wet, wet_percent=0.5):
    """Linearly crossfade between *dry* and *wet*.

    ``wet_percent`` is a wet fraction, not a 0-100 percentage: 0.0 returns
    only *dry*, 1.0 returns only *wet*, and 0.5 gives an equal blend.
    """
    dry_gain = 1 - wet_percent
    dry_part = dry_gain * dry
    wet_part = wet_percent * wet
    return dry_part + wet_part

# NOTE(review): og_len and channels are not used below in this cell —
# confirm whether other cells rely on them before removing.
og_len = 5000 # 5 seconds
channels = 2  # Stereo audio

audio_path = "../../resources/audio1.wav"
# Silence scipy's WavFileWarning messages while reading the file.
warnings.simplefilter("ignore", wavfile.WavFileWarning)
sr, y = wavfile.read(audio_path)

# 8-bit PCM WAV data is unsigned and centred on 128; remove that offset
# so silence maps to 0 before scaling.
if y.dtype == np.uint8:
    y = y.astype(np.float32) - 128.0

# Convert to float32 and normalize to a peak magnitude of 1.0
y = y.astype(np.float32)
peak = np.max(np.abs(y))
if peak > 0:  # an all-zero (silent) file would otherwise turn into NaNs
    y = y / peak

# Convert to mono if stereo
if y.ndim > 1:
    y = np.mean(y, axis=1)

fs = sr
x = y

envelope = envelope_follower(x, fs)

# Envelope peaks at least 20 ms apart and above 5% of the envelope maximum.
peaks, _ = find_peaks(
    envelope,
    distance=int(fs*0.02),
    height=np.max(envelope)*0.05
)

# Try shifting forward (positive) or backward (negative)
y_shifted = displace_transients(x, peaks, shift_samples=400)
y_mixed = mix(x, y_shifted, wet_percent=0.5)

# --- Plot for sanity ---
plt.figure(figsize=(12,4))
plt.plot(x[:20000], label="Original")
plt.plot(y_mixed[:20000], label="Displaced")
plt.legend()
plt.title("Transient Displacer (Proof of Concept)")
plt.show()

# --- Listen ---
print("Original:")
display(Audio(x, rate=fs))
print("Displaced:")
display(Audio(y_mixed, rate=fs))
