In [None]:
import numpy as np
import warnings
from IPython.display import Audio, display
from scipy.io import wavfile
import matplotlib.pyplot as plt
from scipy import signal
from scipy.signal import windows
import librosa, numpy as np
hann = signal.windows.hann

# --- Effect defaults (module-level knobs for interactive experimentation) ---
grain_size_ms = 250   # grain length in milliseconds
spray_ms = 10         # random grain-position jitter in milliseconds
feedback = 0.35
mix = 1.0

og_len = 5000 # 5 seconds
channels = 2  # Stereo audio
sr = 44100 # nominal stream rate; overwritten below with the file's actual rate

audio_path = "../../resources/vocals/withers1.wav"

# Silence scipy's WavFileWarning while reading the file.
warnings.simplefilter("ignore", wavfile.WavFileWarning)
sr_loaded, y = wavfile.read(audio_path)

# Convert to float32 shaped (channels, samples) and peak-normalize.
# - np.atleast_2d turns a mono file (1-D array from wavfile.read) into (1, samples)
#   so that waveform.shape[1] and the (channels, samples) processing code work.
# - The cast to float happens BEFORE abs(): on integer PCM, abs() of the most
#   negative value wraps around and would corrupt the measured peak.
# - A silent (all-zero) or empty file is left untouched instead of dividing by zero.
waveform = np.atleast_2d(y.T).astype(np.float32)
input_peak = float(np.max(np.abs(waveform))) if waveform.size else 0.0
if input_peak > 0.0:
    waveform /= input_peak  # normalize to the range [-1, 1]
num_samples = waveform.shape[1]

# Derive sample-domain sizes from the rate actually stored in the file.
sr = sr_loaded
grain_size = int(sr * grain_size_ms/1000)
launch_interval = int(grain_size * 0.25)
spray = int(sr * spray_ms/1000)
window = hann(grain_size)

In [None]:
def find_optimal_grain_start_autocorr(x, target_pos, grain_size, search_range, min_period=20, max_period=500):
    """
    Pick a grain start position near ``target_pos`` using a heuristic score.

    A local period is estimated as the lag of the largest autocorrelation value
    between ``min_period`` and ``max_period``, computed on the first
    ``2 * max_period`` samples of the analysis region. Every candidate start in
    ``[search_start, search_end)`` is then scored by a weighted sum of:
      - 0.3 * closeness of ``x[pos]`` to zero,
      - 0.5 * closeness of ``pos`` to a multiple of the estimated period,
      - 0.2 * smallness of the sample-to-sample step at both grain boundaries.
    The first candidate with the highest score wins.

    Args:
        x: Input signal (1D array, single channel)
        target_pos: Target position to search around
        grain_size: Size of the grain
        search_range: How far to search (±samples)
        min_period: Minimum period to detect (samples)
        max_period: Maximum period to detect (samples)

    Returns:
        Start position for the grain. Every fallback path returns ``target_pos``
        clamped to ``[0, len(x) - grain_size - 1]`` so the grain never starts
        at a negative index or runs past the end of ``x``.
    """
    N = len(x)

    # Last start index allowed by the bounds logic, and the clamped fallback
    # that is returned whenever the search cannot be carried out.
    last_start = N - grain_size - 1
    clamped_target = max(0, min(target_pos, last_start))

    # Clamp search bounds
    search_start = max(0, target_pos - search_range)
    search_end = min(last_start, target_pos + search_range)

    if search_start >= search_end:
        return clamped_target

    # Extract analysis region
    analysis_start = max(0, search_start - max_period)
    analysis_end = min(N, search_end + grain_size + max_period)
    segment = x[analysis_start:analysis_end]

    if len(segment) < max_period * 2:
        # Not enough material for the period estimate. Return the clamped
        # target (previously the raw target_pos was returned here, which could
        # be negative or leave fewer than grain_size samples after it).
        return clamped_target

    # Compute autocorrelation to find local period
    head = segment[:max_period * 2]
    autocorr = np.correlate(head, head, mode='full')
    autocorr = autocorr[len(autocorr)//2:]  # Take non-negative lags only

    # Lag of the largest autocorrelation value inside [min_period, max_period)
    default_period = (min_period + max_period) // 2
    if len(autocorr) > max_period:
        autocorr_search = autocorr[min_period:max_period]
        if len(autocorr_search) > 0:
            local_period = min_period + np.argmax(autocorr_search)
        else:
            local_period = default_period
    else:
        local_period = default_period

    # A period of 0 (possible when min_period is 0) would make the modulo and
    # division below degenerate into NaN scores; use the smallest valid period.
    if local_period < 1:
        local_period = 1

    best_pos = clamped_target
    best_score = -np.inf

    # search_end <= N - grain_size - 1, so pos + grain_size is always a valid
    # index below N - 1 inside this loop; no extra bounds check is needed.
    for pos in range(search_start, search_end):
        # Zero-crossing score (prefer starting near zero)
        zero_score = 1.0 / (1.0 + abs(x[pos]))

        # Period alignment score: 1 on a multiple of the period, 0 half-way between
        period_offset = pos % local_period
        period_score = 1.0 - (min(period_offset, local_period - period_offset) / (local_period / 2))

        # Boundary smoothness (small sample-to-sample step at start and end)
        if pos > 0:
            start_deriv = abs(x[pos] - x[pos - 1])
            end_deriv = abs(x[pos + grain_size] - x[pos + grain_size - 1])
            smooth_score = 1.0 / (1.0 + start_deriv + end_deriv)
        else:
            # No previous sample at pos == 0: use a neutral score
            smooth_score = 0.5

        # Combined score
        score = zero_score * 0.3 + period_score * 0.5 + smooth_score * 0.2

        if score > best_score:
            best_score = score
            best_pos = pos

    return best_pos

def best_offset(grainA, grainB):
    """
    Lag at which the cross-correlation of grainA and grainB is largest.

    The correlation is evaluated through the FFT: both inputs are zero-padded
    to twice their length, the spectrum of grainA is multiplied by the complex
    conjugate of the spectrum of grainB, and the real part of the inverse
    transform is searched for its maximum.

    Both grains are 1D sequences of the same length N (an assertion enforces
    the length match). The returned lag lies in -(N-1) .. N; lags found in the
    upper part of the padded result are reported as negative values.
    """
    N = len(grainA)
    assert len(grainB) == N

    # Transforming at length 2N zero-pads each grain at its end.
    fft_len = 2 * N
    spectrum_a = np.fft.fft(grainA, n=fft_len)
    spectrum_b = np.fft.fft(grainB, n=fft_len)

    # Cross-correlation = inverse transform of A * conj(B); keep the real part.
    xcorr = np.fft.ifft(spectrum_a * np.conj(spectrum_b)).real

    # Index of the correlation peak, unwrapped to a signed lag.
    lag = np.argmax(xcorr)
    return lag - fft_len if lag > N else lag

def grain_delay_stereo_coherent(
    x,
    sr,
    grain_ms=100,
    spray_ms=20,
    delay_ms=250,
    feedback=0.35,
    mix=1.0,
    density_hz=30,
    pitch_jitter_semitones=4.0,
    reverse_prob=0.2,
    stereo_spread=0.4,
    lp_cutoff_hz=10000.0,
    autocorr_search_ms=10.0,      # Autocorrelation search range in ms
    seed=None                     # Optional RNG seed for reproducible renders
):
    """
    Granular delay with autocorrelation-guided grain selection and alignment.

    For every grain launch the function:
      1. picks a read position (launch time plus random spray),
      2. refines it with find_optimal_grain_start_autocorr on a mono mix-down,
      3. if the delay buffer already holds signal at the write position,
         circularly shifts the grain by the lag returned by best_offset,
      4. applies random pitch jitter (linear interpolation), optional reversal
         and, for 2-channel input, a random cos/sin pan,
      5. multiplies by a Hann window and overlap-adds the grain ``delay_ms``
         after its launch time.
    After all grains are rendered, the wet buffer is recirculated with gain
    ``feedback`` every ``delay_ms``, low-pass filtered and mixed with the dry
    input.

    Args:
        x: Input audio (channels, samples)
        sr: Sample rate
        grain_ms: Grain size in milliseconds
        spray_ms: Random position jitter in milliseconds
        delay_ms: Base delay time in milliseconds
        feedback: Gain applied to the wet signal each time it is fed back
            ``delay_ms`` later (0 disables feedback)
        mix: Dry/wet mix (0 = dry only, 1 = wet only)
        density_hz: Grain launches per second
        pitch_jitter_semitones: Maximum random pitch offset per grain, in
            semitones (0 disables pitch jitter)
        reverse_prob: Probability of reversing a grain
        stereo_spread: Range of the random pan position (only used when the
            input has exactly 2 channels)
        lp_cutoff_hz: Cutoff of the first-order lowpass on the wet signal
            (None disables the filter)
        autocorr_search_ms: Search range for the grain-start refinement
        seed: Seed passed to numpy.random.default_rng; None (default) gives a
            different random render on every call

    Returns:
        float32 array with the same shape as ``x``; scaled down only if its
        peak exceeds 1.
    """
    ch, N = x.shape

    # Mono mix-down, used only for analysis (grain-start search and alignment)
    x_mono = np.mean(x, axis=0) if ch > 1 else x[0]

    # Parameters converted from ms / Hz to samples
    grain_size = int(sr * grain_ms / 1000)
    window = hann(grain_size)
    spray = int(sr * spray_ms / 1000)
    base_delay = int(sr * delay_ms / 1000)
    launch_interval = max(1, int(sr / density_hz))
    autocorr_search = int(sr * autocorr_search_ms / 1000)

    # Period detection bounds (80 Hz - 2 kHz)
    min_period = int(sr / 2000)  # 2kHz
    max_period = int(sr / 80)    # 80Hz

    # Wet (delay) buffer, with head-room past the end of the input
    delay = np.zeros((ch, N + grain_size * 4), dtype=np.float32)

    # Lowpass filter coefficients (identity filter when disabled)
    if lp_cutoff_hz is not None:
        b, a = signal.butter(1, lp_cutoff_hz / (sr * 0.5), 'low')
    else:
        b, a = np.array([1.0]), np.array([1.0])

    rng = np.random.default_rng(seed)

    # Grain loop: n is the launch time of each grain in input samples
    for n in range(0, N - grain_size, launch_interval):
        out_start = n + base_delay
        out_end = out_start + grain_size

        if out_start >= delay.shape[1]:
            break

        # Truncate the grain if it would run past the end of the buffer
        usable = grain_size
        if out_end > delay.shape[1]:
            usable = delay.shape[1] - out_start
            out_end = delay.shape[1]

        # Target grain position with spray
        target_pos = np.clip(
            n + rng.integers(-spray, spray + 1),
            0, N - grain_size - 1
        )

        # Refine the start position on the mono analysis signal
        gstart = find_optimal_grain_start_autocorr(
            x_mono,
            target_pos,
            grain_size,
            autocorr_search,
            min_period,
            max_period
        )
        gend = gstart + grain_size

        # Extract grain (copy, so in-place edits below never touch the input)
        grain = x[:, gstart:gend].copy()

        # Align the grain with what is already in the delay buffer at the
        # write position; only possible when a full-length reference exists.
        ref_start = out_start
        ref_end = min(ref_start + grain_size, delay.shape[1])

        if ref_end - ref_start == grain_size:
            delay_ref = delay[:, ref_start:ref_end].mean(axis=0)
            grain_mono = grain.mean(axis=0)

            # Only correlate if delay buffer has actual content (not silence)
            if np.max(np.abs(delay_ref)) > 1e-6:
                offset = best_offset(delay_ref, grain_mono)

                # Circular shift; the wrapped samples are tapered by the window
                if offset != 0:
                    grain = np.roll(grain, offset, axis=1)

        # Pitch modulation: random ratio within +/- pitch_jitter_semitones
        if pitch_jitter_semitones != 0:
            semis = rng.uniform(-pitch_jitter_semitones, pitch_jitter_semitones)
            pitch_ratio = 2.0 ** (semis / 12.0)
        else:
            pitch_ratio = 1.0

        # Resample the grain by linear interpolation
        if pitch_ratio != 1.0:
            idx = np.arange(usable) * pitch_ratio
            # Keep idx0 + 1 inside the grain
            idx = np.clip(idx, 0, grain_size - 2)
            idx0 = idx.astype(np.int32)
            frac = idx - idx0
            pitched = (1 - frac)[None, :] * grain[:, idx0] + frac[None, :] * grain[:, idx0 + 1]
        else:
            pitched = grain[:, :usable]

        # Reverse grain
        if reverse_prob > 0 and rng.random() < reverse_prob:
            pitched = pitched[:, ::-1].copy()

        # Stereo pan: map pan in [-1, 1] to an angle in [0, pi/2]
        if stereo_spread > 0 and ch == 2:
            pan = rng.uniform(-stereo_spread, stereo_spread)
            angle = (pan + 1) * 0.25 * np.pi
            L, R = np.cos(angle), np.sin(angle)
            pitched[0] *= L
            pitched[1] *= R

        # Window and overlap-add into the delay buffer
        windowed = pitched[:, :usable] * window[:usable]
        delay[:, out_start:out_end] += windowed

    # Feedback: wet[n] += feedback * wet[n - base_delay], applied block by
    # block so each block already contains the echoes of the previous one.
    # (Scaling the freshly written region by (1 + feedback) inside the grain
    # loop, as was done before, only compounds gain across overlapping grains
    # and never produces a repeat.)
    if feedback != 0 and base_delay > 0:
        total = delay.shape[1]
        for start in range(base_delay, total, base_delay):
            stop = min(start + base_delay, total)
            delay[:, start:stop] += feedback * delay[:, start - base_delay:stop - base_delay]

    # Filter and mix
    wet = signal.lfilter(b, a, delay[:, :N], axis=1).astype(np.float32)
    out = (1 - mix) * x + mix * wet

    # Normalize only when the result would clip
    peak = np.max(np.abs(out))
    if peak > 1:
        out /= peak

    return out.astype(np.float32)

#=============
# Beam Coherent example

def fast_coherent_grain_start(
    x_mono,
    target_pos,
    grain_size,
    search_radius,
    small_win=32,
    corr_threshold=0.1
):
    """
    Nudge a grain start within +/- search_radius samples of target_pos.

    A short, mean-removed template is taken at target_pos (length
    min(small_win, grain_size)). Every candidate start in range whose window
    fits inside the signal is mean-removed too and scored as
    dot(template, candidate) / (energy(candidate) + 1e-9). The start with the
    highest score is returned as an int, unless that score is below
    corr_threshold, in which case target_pos is returned unchanged.
    target_pos is also returned unchanged when the template does not fit
    inside the signal or no candidate is in bounds. No FFT is involved.
    """
    total = len(x_mono)
    win_len = min(small_win, grain_size)

    # Template must fit strictly inside the signal
    if target_pos + win_len >= total:
        return target_pos

    template = x_mono[target_pos:target_pos + win_len]
    template = template - template.mean()

    # Candidate starts around the target, restricted to valid windows
    starts = target_pos + np.arange(-search_radius, search_radius + 1)
    in_bounds = (starts >= 0) & (starts + win_len < total)
    starts = starts[in_bounds]
    if len(starts) == 0:
        return target_pos

    # Gather all candidate windows in one indexing step:
    # shape (num_candidates, win_len)
    windows = x_mono[starts[:, None] + np.arange(win_len)]
    windows = windows - windows.mean(axis=1, keepdims=True)

    # Similarity of each candidate with the template, divided by the
    # candidate's own energy (the epsilon avoids division by zero on silence)
    similarity = np.sum(template * windows, axis=1)
    similarity /= np.sum(windows * windows, axis=1) + 1e-9

    top = np.argmax(similarity)

    # Guard rail: keep the original position when the best match is weak
    if similarity[top] < corr_threshold:
        return target_pos

    return int(starts[top])

def grain_delay_stereo_beam_coherent(
    x,
    sr,
    grain_ms=100,
    spray_ms=20,
    delay_ms=250,
    feedback=0.35,
    mix=1.0,
    density_hz=30,
    pitch_jitter_semitones=4.0,
    reverse_prob=0.2,
    stereo_spread=0.4,
    lp_cutoff_hz=10000.0,
    autocorr_search_ms=6.0,    # small search window around target
    coherence_strength=1.0,    # 0..1, scales search radius
    seed=None,                 # optional RNG seed for reproducible renders
):
    """
    BEAM-style 'coherent' granular delay:
      - Each grain start is nudged by fast_coherent_grain_start, which works
        on a mono mix-down of the *input* only
      - The nudge is searched within
        +/- (autocorr_search_ms * coherence_strength) around the target
      - Weak matches are rejected (fixed score threshold of 0.12), in which
        case the un-nudged position is used
      - After rendering, the wet buffer is recirculated with gain ``feedback``
        every ``delay_ms``, low-pass filtered and mixed with the dry input

    Args:
      x: np.ndarray (channels, N)
      sr: sample rate
      grain_ms: grain size in ms (at least 8 samples are always used)
      spray_ms: random position jitter (before coherence)
      delay_ms: base delay
      feedback: gain applied to the wet signal each time it is fed back
        ``delay_ms`` later (0 disables feedback)
      mix: dry/wet (0 = dry only, 1 = wet only)
      density_hz: grain launches per second
      pitch_jitter_semitones: maximum random pitch offset per grain
      reverse_prob: probability of reversing a grain
      stereo_spread: range of the random pan position (2-channel input only)
      lp_cutoff_hz: first-order lowpass cutoff on the wet signal (None = off)
      autocorr_search_ms: +/- range around target to search for coherence
      coherence_strength: 0..1 scaling of search radius (0 = off)
      seed: seed passed to numpy.random.default_rng; None (default) gives a
        different random render on every call

    Returns:
      float32 array with the same shape as ``x``; scaled down only if its
      peak exceeds 1.
    """
    ch, N = x.shape

    # mono analysis signal
    x_mono = x.mean(axis=0) if ch > 1 else x[0]

    # core params
    grain_size = int(sr * grain_ms / 1000.0)
    grain_size = max(8, grain_size)
    window = np.hanning(grain_size).astype(np.float32)

    spray = int(sr * spray_ms / 1000.0)
    base_delay = int(sr * delay_ms / 1000.0)
    launch_interval = max(1, int(sr / density_hz))

    # coherent search radius in samples
    max_search = int(sr * autocorr_search_ms / 1000.0)
    search_radius = int(max_search * float(np.clip(coherence_strength, 0.0, 1.0)))

    # delay buffer, with head-room past the end of the input
    delay = np.zeros((ch, N + grain_size * 4), dtype=np.float32)

    # wet LPF (identity filter when disabled)
    if lp_cutoff_hz is not None:
        b, a = signal.butter(1, lp_cutoff_hz / (sr * 0.5), 'low')
    else:
        b, a = np.array([1.0]), np.array([1.0])

    rng = np.random.default_rng(seed)

    for n in range(0, N - grain_size, launch_interval):
        out_start = n + base_delay
        out_end = out_start + grain_size

        if out_start >= delay.shape[1]:
            break

        # truncate the grain if it would run past the end of the buffer
        usable = grain_size
        if out_end > delay.shape[1]:
            usable = delay.shape[1] - out_start
            out_end = delay.shape[1]

        # base position + spray
        target_pos = np.clip(
            n + (rng.integers(-spray, spray + 1) if spray > 0 else 0),
            0, N - grain_size - 1
        )

        # === COHERENT START (WITH GUARD RAILS) ===
        if search_radius > 0:
            gstart = fast_coherent_grain_start(
                x_mono,
                target_pos,
                grain_size,
                # Use the radius derived from autocorr_search_ms and
                # coherence_strength; a hard-coded radius here would make
                # both parameters ineffective.
                search_radius=search_radius,
                small_win=32,          # tiny local correlation window
                corr_threshold=0.12,   # reject weak matches
            )
        else:
            gstart = target_pos

        gend = gstart + grain_size

        # extract grain (astype copies, so the input is never modified)
        grain = x[:, gstart:gend].astype(np.float32)

        # pitch jitter: random ratio within +/- pitch_jitter_semitones
        if pitch_jitter_semitones != 0.0:
            semis = rng.uniform(-pitch_jitter_semitones, pitch_jitter_semitones)
            pitch_ratio = 2.0 ** (semis / 12.0)
        else:
            pitch_ratio = 1.0

        # resample by linear interpolation
        if pitch_ratio != 1.0:
            idx = np.arange(usable) * pitch_ratio
            # keep idx0 + 1 inside the grain
            idx = np.clip(idx, 0, grain_size - 2)
            idx0 = idx.astype(np.int32)
            frac = idx - idx0
            pitched = (1.0 - frac)[None, :] * grain[:, idx0] + frac[None, :] * grain[:, idx0 + 1]
        else:
            pitched = grain[:, :usable]

        # reverse
        if reverse_prob > 0.0 and rng.random() < reverse_prob:
            pitched = pitched[:, ::-1].copy()

        # stereo spread: map pan in [-1, 1] to an angle in [0, pi/2]
        if stereo_spread > 0.0 and ch == 2:
            pan = rng.uniform(-stereo_spread, stereo_spread)
            angle = (pan + 1.0) * 0.25 * np.pi
            L, R = np.cos(angle), np.sin(angle)
            pitched[0] *= L
            pitched[1] *= R

        # apply window and overlap-add into the delay buffer
        windowed = pitched[:, :usable] * window[:usable]
        delay[:, out_start:out_end] += windowed

    # Feedback: wet[n] += feedback * wet[n - base_delay], applied block by
    # block so each block already contains the echoes of the previous one.
    # (Scaling the freshly written region by (1 + feedback) inside the grain
    # loop, as was done before, only compounds gain across overlapping grains
    # and never produces a repeat.)
    if feedback != 0 and base_delay > 0:
        total = delay.shape[1]
        for start in range(base_delay, total, base_delay):
            stop = min(start + base_delay, total)
            delay[:, start:stop] += feedback * delay[:, start - base_delay:stop - base_delay]

    wet = signal.lfilter(b, a, delay[:, :N], axis=1).astype(np.float32)
    out = (1.0 - mix) * x + mix * wet

    # normalize only when the result would clip
    peak = np.max(np.abs(out))
    if peak > 1.0:
        out /= peak

    return out.astype(np.float32)

In [None]:
print("Original:")
display(Audio(waveform, rate=sr))

# Both renders share these settings, so the two calls differ only in the
# grain-selection strategy and in coherence_strength (beam version only).
shared_grain_settings = dict(
    grain_ms=25,
    spray_ms=5,
    delay_ms=150,
    feedback=0.5,
    mix=1.0,
    density_hz=150,
    pitch_jitter_semitones=0.3,
    reverse_prob=0.1,
    stereo_spread=0.9,
    lp_cutoff_hz=9000,
    autocorr_search_ms=6.0,  # search +/-6 ms around each target position
)

processed_coherent = grain_delay_stereo_coherent(
    waveform,
    sr,
    **shared_grain_settings,
)
print("Coherent with similar window search:")
display(Audio(processed_coherent, rate=sr))

processed_beam_coherent = grain_delay_stereo_beam_coherent(
    waveform,
    sr,
    coherence_strength=0.7,
    **shared_grain_settings,
)

print("Beam Coherent:")
display(Audio(processed_beam_coherent, rate=sr))
