In [None]:
import numpy as np
import warnings
from IPython.display import Audio, display
from scipy.io import wavfile
import matplotlib.pyplot as plt
from scipy import signal
from scipy.signal import windows
import librosa, numpy as np
hann = signal.windows.hann

og_len = 5000 # 5 seconds
channels = 2  # Stereo audio
sr = 44100 # nominal stream rate; the rate actually stored in the file is `sr_loaded`

audio_path = "../../resources/vocals/withers1.wav"

# Suppress scipy's WavFileWarning while reading the file.
warnings.simplefilter("ignore", wavfile.WavFileWarning)
sr_loaded, y = wavfile.read(audio_path)

# Convert to float32 and shape to (channels, samples).
# atleast_2d keeps a mono file (1-D `y`) usable as a (1, samples) array, so the
# `.shape[1]` lookup below and the `ch, N = x.shape` unpacking in the effects work.
waveform = np.atleast_2d(y.T).astype(np.float32)

# Peak-normalize AFTER the float conversion so abs() cannot overflow on integer
# PCM, and leave an all-zero file untouched instead of dividing by zero.
_peak = float(np.max(np.abs(waveform))) if waveform.size else 0.0
if _peak > 0:
    waveform = waveform / np.float32(_peak)
num_samples = waveform.shape[1]

print("Original:")
# Play back at the rate read from the file rather than the hard-coded nominal
# rate, otherwise a non-44.1 kHz file plays at the wrong speed/pitch.
display(Audio(waveform, rate=sr_loaded))

In [None]:
def find_optimal_grain_start_autocorr(x, target_pos, grain_size, search_range, min_period=20, max_period=500):
    """
    Pick a grain start position near ``target_pos`` that should splice cleanly.

    A local period is estimated from the autocorrelation of the first
    ``2 * max_period`` samples of the analysis region (the lag with the largest
    autocorrelation value inside ``[min_period, max_period)``). Every candidate
    start in the search window is then scored and the best one is returned.

    Args:
        x: Input signal (1D array, single channel)
        target_pos: Target position to search around
        grain_size: Size of the grain in samples
        search_range: How far to search (±samples)
        min_period: Smallest lag considered when estimating the period (samples)
        max_period: Largest lag (exclusive) considered when estimating the period (samples)

    Returns:
        Start position as a Python int. Whenever the search cannot be run
        (window empty, or not enough samples for the autocorrelation) the
        result is ``target_pos`` clamped to ``[0, len(x) - grain_size - 1]``
        (0 if that upper bound is negative).
    """
    N = len(x)
    last_valid = N - grain_size - 1  # highest start position the search will consider

    def _clamp(pos):
        # Every fallback goes through here so callers never receive a start
        # that would slice outside the signal.
        return int(max(0, min(pos, last_valid)))

    # Clamp search bounds
    search_start = max(0, target_pos - search_range)
    search_end = min(last_valid, target_pos + search_range)

    if search_start >= search_end:
        return _clamp(target_pos)

    # Extract analysis region (search window plus one max_period of context each side)
    analysis_start = max(0, search_start - max_period)
    analysis_end = min(N, search_end + grain_size + max_period)
    segment = x[analysis_start:analysis_end]

    if len(segment) < max_period * 2:
        # Not enough material for the autocorrelation; previously this returned
        # target_pos unclamped, unlike the early return above.
        return _clamp(target_pos)

    # Autocorrelation of the head of the analysis region, non-negative lags only
    probe = segment[:max_period * 2]
    autocorr = np.correlate(probe, probe, mode='full')
    autocorr = autocorr[len(autocorr) // 2:]

    # Strongest lag inside [min_period, max_period); midpoint if that range is empty
    lag_window = autocorr[min_period:max_period]
    if len(lag_window) > 0:
        local_period = min_period + int(np.argmax(lag_window))
    else:
        local_period = (min_period + max_period) // 2
    # A zero period (possible when min_period == 0) would make the modulo and
    # division below produce NaN scores, so never go below one sample.
    local_period = max(1, local_period)
    half_period = local_period / 2

    best_pos = _clamp(target_pos)
    best_score = -np.inf

    # search_end <= N - grain_size - 1, so pos + grain_size is always a valid
    # index here and no in-loop bounds check is needed.
    for pos in range(search_start, search_end):
        # 1. Prefer starting where the sample is close to zero (reduces clicks)
        zero_score = 1.0 / (1.0 + abs(x[pos]))

        # 2. Prefer positions that are a whole multiple of the detected period
        period_offset = pos % local_period
        period_score = 1.0 - (min(period_offset, local_period - period_offset) / half_period)

        # 3. Prefer a small sample-to-sample step at both grain boundaries
        if pos > 0:
            start_deriv = abs(x[pos] - x[pos - 1])
            end_deriv = abs(x[pos + grain_size] - x[pos + grain_size - 1])
            smooth_score = 1.0 / (1.0 + start_deriv + end_deriv)
        else:
            # No previous sample at pos == 0; use a neutral value
            smooth_score = 0.5

        # Combined score
        score = zero_score * 0.3 + period_score * 0.5 + smooth_score * 0.2

        # Strict '>' keeps the earliest position on ties
        if score > best_score:
            best_score = score
            best_pos = pos

    return int(best_pos)

def grain_delay_stereo_coherent(
    x,
    sr,
    grain_ms=100,
    spray_ms=20,
    delay_ms=250,
    feedback=0.35,
    mix=1.0,
    density_hz=30,
    pitch_jitter_semitones=4.0,
    reverse_prob=0.2,
    stereo_spread=0.4,
    lp_cutoff_hz=10000.0,
    autocorr_search_ms=10.0,      # Autocorrelation search range in ms
    seed=None                     # RNG seed; None draws fresh entropy each call
):
    """
    Granular delay whose grain start positions are chosen by
    ``find_optimal_grain_start_autocorr`` instead of being purely random.

    Grains are launched every ``sr / density_hz`` samples, read from around the
    launch position, optionally pitch-shifted (linear interpolation), reversed
    and panned, Hann-windowed, and summed into a buffer ``delay_ms`` later.

    Args:
        x: Input audio (channels, samples)
        sr: Sample rate
        grain_ms: Grain size in milliseconds
        spray_ms: Random position jitter (±ms) applied before the alignment search
        delay_ms: Base delay time
        feedback: After each grain is written, the written span of the buffer is
            scaled by ``(1 + feedback)`` (see the note at the end of the loop)
        mix: Dry/wet mix
        density_hz: Grains per second (must be > 0)
        pitch_jitter_semitones: Per-grain random pitch offset range (±semitones);
            0 disables resampling
        reverse_prob: Probability of reversing a grain
        stereo_spread: Random pan range, applied only when the input has 2 channels
        lp_cutoff_hz: First-order lowpass cutoff on the wet signal. ``None`` or a
            value at/above Nyquist bypasses the filter.
        autocorr_search_ms: Search range (±ms) for the grain-start alignment
        seed: Optional seed for the random generator, for reproducible output

    Returns:
        float32 array with the same shape as ``x``; scaled down only if the
        mixed peak exceeds 1.

    Raises:
        ValueError: if ``density_hz`` is not positive or the grain is shorter
            than 2 samples.
    """
    ch, N = x.shape

    if density_hz <= 0:
        raise ValueError("density_hz must be positive")

    # Mono reference signal for the grain-start analysis
    x_mono = np.mean(x, axis=0) if ch > 1 else x[0]

    # Parameters
    grain_size = int(sr * grain_ms / 1000)
    if grain_size < 2:
        # The interpolation below needs at least two samples per grain
        raise ValueError("grain_ms is too short for this sample rate")
    window = hann(grain_size)
    spray = int(sr * spray_ms / 1000)
    base_delay = int(sr * delay_ms / 1000)
    launch_interval = max(1, int(sr / density_hz))
    autocorr_search = int(sr * autocorr_search_ms / 1000)

    # Lag bounds for period detection (2 kHz down to 80 Hz); kept >= 1 sample
    # so very low sample rates cannot produce a zero-length period.
    min_period = max(1, int(sr / 2000))
    max_period = max(min_period + 1, int(sr / 80))

    # Delay buffer (longer than the input so late grains have room)
    delay = np.zeros((ch, N + grain_size * 4), dtype=np.float32)

    # Lowpass on the wet path. butter() rejects a normalized cutoff >= 1, and a
    # lowpass at/above Nyquist passes everything anyway, so bypass in that case.
    nyquist = sr * 0.5
    if lp_cutoff_hz is not None and lp_cutoff_hz < nyquist:
        b, a = signal.butter(1, lp_cutoff_hz / nyquist, 'low')
    else:
        b, a = np.array([1.0]), np.array([1.0])

    rng = np.random.default_rng(seed)

    # Grain loop
    for n in range(0, N - grain_size, launch_interval):
        out_start = n + base_delay
        out_end = out_start + grain_size

        if out_start >= delay.shape[1]:
            break

        # Clamp grain end to buffer
        usable = grain_size
        if out_end > delay.shape[1]:
            usable = delay.shape[1] - out_start
            out_end = delay.shape[1]

        # Target grain position with spray
        target_pos = np.clip(
            n + rng.integers(-spray, spray + 1),
            0, N - grain_size - 1
        )

        # Refine the start position around the sprayed target
        gstart = find_optimal_grain_start_autocorr(
            x_mono,
            target_pos,
            grain_size,
            autocorr_search,
            min_period,
            max_period
        )
        gend = gstart + grain_size

        # Copy so the in-place panning below never touches the input
        grain = x[:, gstart:gend].copy()

        # Pitch modulation
        if pitch_jitter_semitones != 0:
            semis = rng.uniform(-pitch_jitter_semitones, pitch_jitter_semitones)
            pitch_ratio = 2.0 ** (semis / 12.0)
        else:
            pitch_ratio = 1.0

        # Resample by linear interpolation; read positions past the grain are
        # clamped to its last sample pair.
        if pitch_ratio != 1.0:
            idx = np.arange(usable) * pitch_ratio
            idx = np.clip(idx, 0, grain_size - 2)
            idx0 = idx.astype(np.int32)
            frac = idx - idx0
            pitched = (1 - frac)[None, :] * grain[:, idx0] + frac[None, :] * grain[:, idx0 + 1]
        else:
            pitched = grain[:, :usable]

        # Reverse grain
        if reverse_prob > 0 and rng.random() < reverse_prob:
            pitched = pitched[:, ::-1].copy()

        # Stereo pan with cos/sin channel gains
        if stereo_spread > 0 and ch == 2:
            pan = rng.uniform(-stereo_spread, stereo_spread)
            angle = (pan + 1) * 0.25 * np.pi
            L, R = np.cos(angle), np.sin(angle)
            pitched[0] *= L
            pitched[1] *= R

        # Window and add to delay buffer
        windowed = pitched[:, :usable] * window[:usable]
        delay[:, out_start:out_end] += windowed
        # NOTE(review): this multiplies everything already accumulated in the
        # span (including the grain just added) by (1 + feedback); it is a gain
        # that compounds with grain overlap, not a recirculating delay line.
        delay[:, out_start:out_end] += feedback * delay[:, out_start:out_end]

    # Filter and mix
    wet = signal.lfilter(b, a, delay[:, :N], axis=1).astype(np.float32)
    out = (1 - mix) * x + mix * wet

    # Normalize only when the mix would clip
    peak = np.max(np.abs(out))
    if peak > 1:
        out /= peak

    return out.astype(np.float32)

In [None]:
def spectral_grain_diffusion(
    x,
    sr,
    grain_ms=100,
    hop_ratio=0.25,           # Overlap ratio (0.25 = 75% overlap)
    blur_size=15,             # Magnitude blur kernel size (bins)
    blur_iterations=2,        # Multiple blur passes for smoother diffusion
    spectral_tilt_db=0.0,     # Tilt spectrum (+db = brighter, -db = darker)
    freeze_thresh=0.0,        # Raise bins below this magnitude to it (0 = off)
    preserve_transients=True, # Skip the blur on frames flagged as transients
    transient_thresh=2.0,     # Transient detection threshold (ratio)
    lowcut_hz=60.0,           # Preserve frequencies below this
    highcut_hz=16000.0,       # Preserve frequencies above this
    diffusion_amount=1.0,     # 0-1 blend of diffused vs original spectrum
    mix=1.0                   # Dry/wet mix
):
    """
    Overlap-add spectral processing that blurs magnitudes and keeps phase.

    Each windowed frame is transformed with an rFFT, its magnitude spectrum is
    smoothed with a moving-average kernel, and the frame is resynthesised with
    the original phase. The overlap-add result is divided by the accumulated
    squared window so that, with no blur applied, the input is reconstructed.

    Args:
        x: Input audio (channels, samples)
        sr: Sample rate
        grain_ms: Analysis/synthesis frame size (rounded up to an even length)
        hop_ratio: Hop size as ratio of the frame (lower = more overlap)
        blur_size: Moving-average kernel length in bins; values <= 1 disable
            the blur, values above the bin count are clamped to it
        blur_iterations: Number of blur passes
        spectral_tilt_db: Gain ramp from 0 dB at DC to this value at Nyquist
        freeze_thresh: Bins whose original magnitude is below this are set to it
        preserve_transients: Leave frames untouched by the blur when their mean
            magnitude exceeds ``transient_thresh`` times the previous frame's
        transient_thresh: Ratio used by the transient test
        lowcut_hz: Bins below this frequency keep their original magnitude
        highcut_hz: Bins at/above this frequency keep their original magnitude
        diffusion_amount: Blend of diffused vs original magnitudes
        mix: Final dry/wet mix

    Returns:
        float32 array with the same shape as ``x``; scaled down only if the
        mixed peak exceeds 1.

    Raises:
        ValueError: if the frame is shorter than 2 samples or ``hop_ratio`` is
            not positive.
    """
    ch, N = x.shape

    # Frame parameters
    grain_size = int(sr * grain_ms / 1000)
    # Ensure even size for FFT
    if grain_size % 2 != 0:
        grain_size += 1
    if grain_size < 2:
        raise ValueError("grain_ms is too short for this sample rate")
    if hop_ratio <= 0:
        raise ValueError("hop_ratio must be positive")

    # A hop of 0 would never advance the frame position
    hop_size = max(1, int(grain_size * hop_ratio))

    # sqrt-Hann is applied at analysis and again at synthesis
    window = np.sqrt(np.hanning(grain_size)).astype(np.float32)
    window_sq = window ** 2

    # Frequency bin indices for low/high cut
    freq_per_bin = sr / grain_size
    low_bin = int(lowcut_hz / freq_per_bin)
    high_bin = int(highcut_hz / freq_per_bin)
    num_bins = grain_size // 2 + 1

    # Spectral tilt curve (linear in dB across spectrum)
    if spectral_tilt_db != 0:
        tilt_curve = np.linspace(0, spectral_tilt_db, num_bins)
        tilt_curve = 10 ** (tilt_curve / 20)  # Convert dB to linear
    else:
        tilt_curve = np.ones(num_bins)

    # Moving-average kernel. Clamped to the bin count because np.convolve's
    # 'same' mode returns the length of the LONGER input, which would no longer
    # match the phase array. Length <= 1 is an identity, so skip it entirely.
    kernel_len = min(int(blur_size), num_bins)
    if kernel_len > 1:
        blur_kernel = np.ones(kernel_len, dtype=np.float32) / kernel_len
    else:
        blur_kernel = None

    # Pad input so the edges get full overlap
    pad_size = grain_size
    x_padded = np.pad(x, ((0, 0), (pad_size, pad_size)), mode='reflect')
    N_padded = x_padded.shape[1]

    # Every channel uses the same frame grid
    frame_starts = range(0, N_padded - grain_size + 1, hop_size)

    # Output buffer
    out = np.zeros_like(x_padded, dtype=np.float32)

    # Accumulate the squared window ONCE. Doing this inside the channel loop
    # counted it once per channel, so stereo output came out at half level.
    window_sum = np.zeros(N_padded, dtype=np.float32)
    for pos in frame_starts:
        window_sum[pos:pos + grain_size] += window_sq

    # Process each channel
    for c in range(ch):
        prev_mag = np.zeros(num_bins, dtype=np.float32)

        for frame_idx, pos in enumerate(frame_starts):
            # Extract and window frame
            frame = x_padded[c, pos:pos + grain_size] * window

            # FFT
            spec = np.fft.rfft(frame)
            mag = np.abs(spec)
            phase = np.angle(spec)

            # Store original for blending
            orig_mag = mag.copy()

            # === TRANSIENT DETECTION ===
            # Compare this frame's mean magnitude with the previous frame's
            is_transient = False
            if preserve_transients and frame_idx > 0:
                mag_ratio = np.mean(mag) / (np.mean(prev_mag) + 1e-10)
                if mag_ratio > transient_thresh:
                    is_transient = True

            prev_mag = orig_mag

            # === MAGNITUDE BLUR (DIFFUSION) ===
            if not is_transient:
                diffused_mag = mag.copy()

                if blur_kernel is not None:
                    for _ in range(blur_iterations):
                        diffused_mag = np.convolve(diffused_mag, blur_kernel, mode='same')

                # Preserve low frequencies (bass)
                if low_bin > 0:
                    diffused_mag[:low_bin] = orig_mag[:low_bin]

                # Preserve high frequencies
                if high_bin < num_bins:
                    diffused_mag[high_bin:] = orig_mag[high_bin:]

                # Blend original and diffused
                mag = orig_mag * (1 - diffusion_amount) + diffused_mag * diffusion_amount

            # === FREEZE LOW MAGNITUDES ===
            if freeze_thresh > 0:
                freeze_mask = orig_mag < freeze_thresh
                # Keep frozen bins at threshold level
                mag[freeze_mask] = freeze_thresh

            # === SPECTRAL TILT ===
            mag = mag * tilt_curve

            # Reconstruct with original phase
            spec_out = mag * np.exp(1j * phase)

            # IFFT
            frame_out = np.fft.irfft(spec_out, n=grain_size)

            # Apply synthesis window and overlap-add
            out[c, pos:pos + grain_size] += frame_out * window

    # Normalize by window sum (guarding samples no window reached)
    window_sum[window_sum < 1e-8] = 1.0
    out = out / window_sum

    # Remove padding
    out = out[:, pad_size:pad_size + N]

    # Mix
    result = (1 - mix) * x + mix * out

    # Normalize if clipping
    peak = np.max(np.abs(result))
    if peak > 1.0:
        result /= peak

    return result.astype(np.float32)


In [None]:
# Notebook-level grain settings, in milliseconds / linear gain.
grain_size_ms, spray_ms = 250, 10
feedback, mix = 0.35, 1.0

# From here on `sr` is the rate that was read from the audio file.
sr = sr_loaded

# Sample-domain equivalents of the settings above.
grain_size = int(sr * grain_size_ms / 1000)
spray = int(sr * spray_ms / 1000)
launch_interval = int(grain_size * 0.25)  # a quarter of a grain
window = hann(grain_size)

def grain_delay_stereo(
    x,
    sr,
    grain_ms=100,
    spray_ms=20,
    delay_ms=250,
    feedback=0.35,
    mix=1.0,
    density_hz=30,                     # grains per second
    pitch_jitter_semitones=4.0,        # random ± pitch
    reverse_prob=0.2,                  # 20% of grains reversed
    stereo_spread=0.4,                 # 0=mono, 1=wide random
    lp_cutoff_hz=10000.0,              # filter wet signal
    seed=None                          # RNG seed; None draws fresh entropy each call
):
    """
    Granular delay: scatter Hann-windowed grains of ``x`` into a delayed buffer.

    Grains are launched every ``sr / density_hz`` samples, read from a randomly
    jittered position near the launch point, resampled by linear interpolation
    for pitch jitter, optionally reversed and panned, then summed into a buffer
    ``delay_ms`` after the launch position.

    Args:
        x: Input audio (channels, samples)
        sr: Sample rate
        grain_ms: Grain size in milliseconds
        spray_ms: Random read-position jitter (±ms)
        delay_ms: Base delay time
        feedback: After each grain is written, the written span of the buffer is
            scaled by ``(1 + feedback)`` (see the note at the end of the loop)
        mix: Dry/wet mix
        density_hz: Grains per second (must be > 0)
        pitch_jitter_semitones: Per-grain random pitch offset range (±semitones)
        reverse_prob: Probability of reversing a grain
        stereo_spread: Random pan range, applied only when the input has 2 channels
        lp_cutoff_hz: First-order lowpass cutoff on the wet signal. ``None`` or a
            value at/above Nyquist bypasses the filter.
        seed: Optional seed for the random generator, for reproducible output

    Returns:
        float32 array with the same shape as ``x``; scaled down only if the
        mixed peak exceeds 1.

    Raises:
        ValueError: if ``density_hz`` is not positive or the grain is shorter
            than 2 samples.
    """
    ch, N = x.shape

    if density_hz <= 0:
        raise ValueError("density_hz must be positive")

    # grain size according to samplerate
    grain_size = int(sr * grain_ms / 1000)
    if grain_size < 2:
        # The interpolation below needs at least two samples per grain
        raise ValueError("grain_ms is too short for this sample rate")
    window = hann(grain_size)

    spray = int(sr * spray_ms / 1000)
    base_delay = int(sr * delay_ms / 1000)

    # samples between grain launches
    launch_interval = max(1, int(sr / density_hz))

    # delay buffer (bigger than input to allow long delay)
    delay = np.zeros((ch, N + grain_size * 4), dtype=np.float32)

    # Lowpass on the wet path. butter() rejects a normalized cutoff >= 1, and a
    # lowpass at/above Nyquist passes everything anyway, so bypass in that case.
    nyquist = sr * 0.5
    if lp_cutoff_hz is not None and lp_cutoff_hz < nyquist:
        b, a = signal.butter(1, lp_cutoff_hz / nyquist, 'low')
    else:
        b, a = np.array([1], float), np.array([1], float)

    rng = np.random.default_rng(seed)

    # Grain loop
    for n in range(0, N - grain_size, launch_interval):

        out_start = n + base_delay
        out_end = out_start + grain_size

        if out_start >= delay.shape[1]:
            break

        # clamp grain end to buffer end
        usable = grain_size
        if out_end > delay.shape[1]:
            usable = delay.shape[1] - out_start
            out_end = delay.shape[1]

        # Grain start/end (jittered, kept inside the signal)
        gstart = np.clip(
            n + rng.integers(-spray, spray + 1),
            0, N - grain_size - 1
        )
        gend = gstart + grain_size

        # extract grain (a view; the interpolation below builds a new array)
        grain = x[:, gstart:gend]

        # Pitch modulation
        if pitch_jitter_semitones != 0:
            semis = rng.uniform(-pitch_jitter_semitones, pitch_jitter_semitones)
            pitch_ratio = 2.0 ** (semis / 12.0)
        else:
            pitch_ratio = 1.0

        # resample for pitch (linear interpolation); read positions past the
        # grain are clamped to its last sample pair
        idx = np.arange(usable) * pitch_ratio
        idx = np.clip(idx, 0, grain_size - 2)
        idx0 = idx.astype(np.int32)
        frac = idx - idx0

        pitched = (1 - frac)[None, :] * grain[:, idx0] + frac[None, :] * grain[:, idx0 + 1]

        # Reverse grain according to probability
        if reverse_prob > 0 and rng.random() < reverse_prob:
            pitched = pitched[:, ::-1]

        # Stereo pan with cos/sin channel gains
        if stereo_spread > 0 and ch == 2:
            pan = rng.uniform(-stereo_spread, stereo_spread)
            angle = (pan + 1) * 0.25 * np.pi
            L = np.cos(angle)
            R = np.sin(angle)
            pitched[0] *= L
            pitched[1] *= R

        # Window
        windowed = pitched[:, :usable] * window[:usable]

        # Delay buffer
        delay[:, out_start:out_end] += windowed
        # NOTE(review): this multiplies everything already accumulated in the
        # span (including the grain just added) by (1 + feedback); it is a gain
        # that compounds with grain overlap, not a recirculating delay line.
        delay[:, out_start:out_end] += feedback * delay[:, out_start:out_end]

    # Filter and Mix
    wet = signal.lfilter(b, a, delay[:, :N], axis=1).astype(np.float32)
    out = (1 - mix) * x + mix * wet

    # normalize only when the mix would clip
    peak = np.max(np.abs(out))
    if peak > 1:
        out /= peak

    return out.astype(np.float32)


def spectral_smear(x, sr,
                   window_ms=15,           # window size
                   blur_amount=15,        # kernel size for blur
                   phase_randomize=0.7,   # 0-1, how much phase chaos
                   mix=1.0):
    """
    Smear a signal by blurring STFT magnitudes and partially randomizing phase.

    Frames of ``window_ms`` are taken at a quarter-window hop, Hann-windowed at
    analysis and synthesis, and overlap-added. The result is divided by 1.5,
    the nominal sum of squared Hann windows at that overlap.

    Args:
        x: Input audio (channels, samples)
        sr: Sample rate
        window_ms: Frame length in milliseconds
        blur_amount: Moving-average kernel length in bins; <= 1 disables the
            blur, values above the bin count are clamped to it
        phase_randomize: 0-1 blend between original phase and a uniform random
            phase, applied to the upper 80% of bins only
        mix: Dry/wet mix

    Returns:
        Array with the same shape as the input ``x``.

    Raises:
        ValueError: if the frame is shorter than 4 samples (the hop would be 0).
    """
    ch, N = x.shape

    win = int(sr * window_ms / 1000)
    if win < 4:
        raise ValueError("window_ms is too short for this sample rate")
    hop = win // 4  # 75% overlap
    window = np.hanning(win)

    # Zero-pad the tail so the last frame ends exactly at the end of the
    # buffer; a signal shorter than one frame is padded up to one frame.
    if N < win:
        pad = win - N
    else:
        pad = (-(N - win)) % hop
    if pad > 0:
        x = np.pad(x, ((0, 0), (0, pad)), mode='constant')
    total = x.shape[1]

    out = np.zeros_like(x, dtype=np.float32)
    rng = np.random.default_rng(42)  # consistent seed

    # Build the blur kernel once. Clamped to the bin count because np.convolve's
    # 'same' mode returns the length of the LONGER input.
    num_bins = win // 2 + 1
    kernel_len = min(int(blur_amount), num_bins)
    blur_kernel = np.ones(kernel_len) / kernel_len if kernel_len > 1 else None

    for c in range(ch):
        # '+ 1' includes the final frame at total - win that the padding above
        # was computed for; without it the last hop of output stayed at zero.
        for i in range(0, total - win + 1, hop):
            frame = x[c, i:i + win] * window

            spec = np.fft.rfft(frame)
            mag = np.abs(spec)
            phase = np.angle(spec)

            # Blur magnitudes
            if blur_kernel is not None:
                mag = np.convolve(mag, blur_kernel, mode='same')

            # Randomize phase (leave the lowest 20% of bins alone)
            if phase_randomize > 0:
                cutoff_bin = int(0.2 * len(phase))
                random_phase = rng.uniform(-np.pi, np.pi, len(phase) - cutoff_bin)
                phase[cutoff_bin:] = (phase[cutoff_bin:] * (1 - phase_randomize) +
                                      random_phase * phase_randomize)

            # Reconstruct
            spec_new = mag * np.exp(1j * phase)
            out_frame = np.fft.irfft(spec_new, n=win)

            out[c, i:i + win] += out_frame * window

    out = out[:, :N] / 1.5  # normalize for overlap

    return (1 - mix) * x[:, :out.shape[1]] + mix * out

def convolution_reverb(x, ir_path, mix=0.5, normalize=True, target_sr=None):
    """
    Convolve ``x`` with an impulse response loaded from ``ir_path``.

    Args:
        x: Input audio (channels, samples)
        ir_path: Path to the impulse-response audio file (read with librosa)
        mix: Dry/wet mix
        normalize: If True, scale the result down when its peak exceeds 1
        target_sr: Sample rate of ``x``; the IR is resampled to it when its own
            rate differs. When omitted, the notebook-level ``sr`` variable is
            used, which is what this function always did before.

    Returns:
        float32 array with the same shape as ``x``. The reverb tail beyond the
        end of the input is discarded.
    """
    ch, N = x.shape

    if target_sr is None:
        # Backward-compatible fallback to the module-level sample rate
        target_sr = sr

    # Load impulse response at its native rate, keeping its channels
    ir, ir_sr = librosa.load(ir_path, sr=None, mono=False)

    # Ensure IR is 2D
    if ir.ndim == 1:
        ir = ir[np.newaxis, :]  # mono IR -> (1, N)

    print(f"IR loaded: {ir.shape[0]} channels, {ir.shape[1]} samples ({ir.shape[1]/ir_sr:.2f}s)")

    # Resample IR if sample rates don't match
    if ir_sr != target_sr:
        print(f"Resampling IR from {ir_sr}Hz to {target_sr}Hz...")
        ir = np.array([
            librosa.resample(ir[c], orig_sr=ir_sr, target_sr=target_sr)
            for c in range(ir.shape[0])
        ])

    # Match channel count: per-channel IR when the counts agree, otherwise the
    # first IR channel is shared by every input channel.
    ir_ch = ir.shape[0]
    if ir_ch == ch:
        ir_index = list(range(ch))
    else:
        if ir_ch != 1:
            print(f"Warning: IR has {ir_ch} channels, audio has {ch}. Using first IR channel.")
        ir_index = [0] * ch

    wet = np.zeros((ch, N), dtype=np.float32)
    for c in range(ch):
        # Take the first N samples of the FULL convolution so the IR's first
        # sample lines up with the dry signal. mode='same' returns the centred
        # slice, which shifted the wet signal earlier by about half the IR length.
        wet[c] = signal.fftconvolve(x[c], ir[ir_index[c]], mode='full')[:N]

    # Mix dry/wet
    out = (1 - mix) * x + mix * wet

    # Normalize if needed
    if normalize:
        peak = np.max(np.abs(out))
        if peak > 1.0:
            out /= peak
            print(f"Normalized by {peak:.2f}x")

    return out.astype(np.float32)

# === PROCESS ===
# Basic granular delay: dense, short grains with no pitch jitter or panning.
processed = grain_delay_stereo(
    waveform,
    sr,
    grain_ms=75,
    spray_ms=15,
    delay_ms=75,
    feedback=1.0,
    mix=1.0,
    density_hz=200,

    reverse_prob=0.1,

    pitch_jitter_semitones=0.0,   # pitch jitter disabled

    stereo_spread=0.0,            # panning disabled

    lp_cutoff_hz=8000
)
print("Grains:")
display(Audio(processed, rate=sr))

# Smear the grain output in the spectral domain:
smeared = spectral_smear(processed, sr,
                         window_ms=2,
                         blur_amount=6,      # higher = more smear
                         phase_randomize=0.7, # higher = more chaos
                         mix=1.0)

print("Spectral Smear:")
display(Audio(smeared, rate=sr))

# Reverb is applied to the grain output, not to the smeared version.
reverbed = convolution_reverb(processed, "../../resources/ir/1st_baptist_nashville_balcony.wav", mix=0.4)
print("Convolution Reverb:")
display(Audio(reverbed, rate=sr))

# === AUTOCORRELATION-ALIGNED GRAIN PROCESSING ===
# Same settings as above (apart from spray and pan) but with grain start
# positions refined by find_optimal_grain_start_autocorr, for comparison.

# The label no longer mentions sinc resampling: the function only implements
# linear interpolation, and pitch jitter is switched off for this run anyway.
print("Phase-Coherent Grains (autocorrelation-aligned grain starts):")
processed_coherent = grain_delay_stereo_coherent(
    waveform,
    sr,
    grain_ms=75,
    spray_ms=25,
    delay_ms=75,
    feedback=1.0,
    mix=1.0,
    density_hz=200,
    reverse_prob=0.1,
    pitch_jitter_semitones=0.0,  # pitch jitter disabled, so no resampling happens
    stereo_spread=0.3,
    lp_cutoff_hz=8000,
    autocorr_search_ms=10.0      # Search ±10ms for optimal grain boundaries
)
print("Processed coherent:")
display(Audio(processed_coherent, rate=sr))


# Spectral diffusion is applied to the original waveform, not the grain output.
diffused = spectral_grain_diffusion(
    waveform,
    sr,
    grain_ms=125,              # Grain size
    hop_ratio=0.25,           # 75% overlap
    blur_size=50,             # Blur kernel size (higher = more diffuse)
    blur_iterations=3,        # Multiple passes for smoother result
    spectral_tilt_db=-2.0,    # Slight darkness
    freeze_thresh=0.0,        # No freeze
    preserve_transients=False, # Transient bypass off: every frame is diffused
    transient_thresh=1.8,     # Not used while preserve_transients is False
    lowcut_hz=80.0,           # Bins below this keep their original magnitude
    highcut_hz=12000.0,       # Bins at/above this keep their original magnitude
    diffusion_amount=0.95,    # Strong diffusion
    mix=1.0                   # Full wet
)
print("spectral_grain_diffusion:")
display(Audio(diffused, rate=sr))