In [None]:
import numpy as np
import warnings
from IPython.display import Audio, display
from scipy.io import wavfile
import matplotlib.pyplot as plt
from scipy.signal import iirfilter, sosfilt, sosfilt_zi, butter

def plot_waveform(waveform, sr, other=None, channel=0, start=0, end=2000, show_fft=False):
    """
    Plot a waveform (and optionally a second one for comparison).

    Args:
        waveform (np.ndarray): Array of shape (channels, samples). A 1-D mono
            array of shape (samples,) is also accepted and treated as one channel.
        sr (int): Sample rate
        other (np.ndarray, optional): Second waveform to compare. It is laid out
            like ``waveform`` but may contain a different number of samples; it
            is then drawn against its own time and frequency axes.
        channel (int): Which channel to plot (default=0)
        start (int): Start sample index (values below 0 are clamped to 0)
        end (int): End sample index, exclusive (clamped to each signal's length)
        show_fft (bool): Whether to also plot frequency spectra
    """
    # Promote mono (samples,) input to (1, samples) so channel indexing works
    waveform = np.atleast_2d(waveform)
    if other is not None:
        other = np.atleast_2d(other)

    def _window(signal):
        """Return (time axis in seconds, samples) for the clamped [start, end) window."""
        stop = max(0, min(end, signal.shape[1]))
        first = min(max(start, 0), stop)
        return np.arange(first, stop) / sr, signal[channel, first:stop]

    plt.figure(figsize=(12, 5 if not show_fft else 10))

    # --- Time-domain plot ---
    plt.subplot(2 if show_fft else 1, 1, 1)
    t, segment = _window(waveform)
    plt.plot(t, segment, label="Original")
    if other is not None:
        # Each signal gets its own time axis so differing lengths never mismatch
        t_other, segment_other = _window(other)
        plt.plot(t_other, segment_other, label="Processed", alpha=0.8)
    plt.title(f"Waveform (Channel {channel})")
    plt.xlabel("Time [s]")
    plt.ylabel("Amplitude")
    plt.legend()

    # --- Frequency-domain plot ---
    if show_fft:
        # The spectrum is taken over the whole channel, not just the plotted window
        fft_orig = np.fft.rfft(waveform[channel])
        freqs = np.fft.rfftfreq(waveform.shape[1], 1/sr)
        plt.subplot(2, 1, 2)
        plt.semilogy(freqs, np.abs(fft_orig), label="Original")
        if other is not None:
            fft_proc = np.fft.rfft(other[channel])
            # Bin frequencies depend on the signal length, so compute them per signal
            freqs_proc = np.fft.rfftfreq(other.shape[1], 1/sr)
            plt.semilogy(freqs_proc, np.abs(fft_proc), label="Processed", alpha=0.5)
        plt.title("Frequency Spectrum")
        plt.xlabel("Frequency [Hz]")
        plt.ylabel("Magnitude")
        plt.legend()

    plt.tight_layout()
    plt.show()

# Original Source

In [None]:
og_len = 5000  # 5 seconds (presumably expressed in milliseconds; confirm against its users)
channels = 2  # Stereo audio
sr = 44100  # Expected sample rate in Hz; replaced below by the rate stored in the file
audio_path = "../resources/audio1.wav"

warnings.simplefilter("ignore", wavfile.WavFileWarning)
sr_loaded, y = wavfile.read(audio_path)  # y has shape (samples, channels) if stereo, (samples,) if mono

# Use the rate the file was actually recorded at, so playback speed and the
# time/frequency axes of every later plot are correct for non-44.1 kHz files.
sr = sr_loaded

# Shape to (channels, samples) and convert to float32 BEFORE taking abs():
# abs() on integer PCM overflows for the most negative value (e.g. int16 -32768).
# atleast_2d turns a mono (samples,) array into (1, samples).
waveform = np.atleast_2d(y.T).astype(np.float32)
peak = np.max(np.abs(waveform))
if peak > 0:  # an all-zero (silent) file is left as-is instead of dividing by zero
    waveform = waveform / peak  # normalize to a peak of 1.0
num_samples = waveform.shape[1]
print(f"Original length: {num_samples} samples, Sample rate: {sr_loaded} Hz")
display(Audio(waveform, rate=sr))

# Only the original exists at this point, so no comparison trace is drawn
plot_waveform(waveform, sr, channel=0, show_fft=True)


# Distortion with output LPF

In [None]:
# Drive applied to the signal before the waveshaper
distort_amount = 4

# Soft-clip the scaled signal with a tanh waveshaper
waveform_distorted = np.tanh(waveform * distort_amount)

# TODO: add lpf

# Rescale into [-1, 1] only when the peak exceeds full scale, to prevent clipping
max_val = np.abs(waveform_distorted).max()
if max_val > 1.0:
    waveform_distorted = waveform_distorted / max_val

# Listen to the result and compare it against the original
display(Audio(waveform_distorted, rate=sr))
plot_waveform(waveform, sr, other=waveform_distorted, channel=0, show_fft=True)


# Distortion with pre/post filters

In [None]:
# Drive applied to the signal before the waveshaper
distort_amount = 2.0

# Cutoff frequency shared by the pre- and post-distortion filters
cutoff_freq = 8000

# Shelf gains in dB: the boost and the cut mirror each other and scale with the drive
db_boost = 2 * distort_amount
db_cut = -2 * distort_amount

# Convert both dB values to linear amplitude factors (gain = 10^(dB / 20))
boost, cut = (10 ** (db / 20) for db in (db_boost, db_cut))

# high-shelf
def high_shelf_curve(freqs, cutoff, gain, slope=0.001):
    """
    Build a smooth high-shelf gain curve from a logistic (sigmoid) transition.

    The curve tends to 1 (unity gain) well below ``cutoff``, tends to ``gain``
    well above it, and equals ``1 + (gain - 1) / 2`` exactly at ``cutoff``.

    Args:
        freqs (array-like): Frequencies at which to evaluate the curve
        cutoff (float): Centre of the transition, in the same units as ``freqs``
        gain (float): Linear gain reached above the shelf
        slope (float): Steepness of the transition per unit of frequency;
            larger values give a sharper shelf (default=0.001)

    Returns:
        np.ndarray: Linear gain factor for each entry of ``freqs``
    """
    freqs = np.asarray(freqs)
    # For bins far below the cutoff exp() may overflow to inf; that correctly
    # yields unity gain, so only the overflow warning is silenced.
    with np.errstate(over="ignore"):
        return 1 + (gain - 1) / (1 + np.exp(-slope * (freqs - cutoff)))

# Transform each channel to the frequency domain so the shelves can be applied
# as per-bin gain curves
fft_wave = np.fft.rfft(waveform, axis=-1)
freqs = np.fft.rfftfreq(waveform.shape[-1], 1/sr)

# Gain curves: boost the highs before the nonlinearity, cut them again afterwards
pre_shelf = high_shelf_curve(freqs, cutoff_freq, boost)
post_shelf = high_shelf_curve(freqs, cutoff_freq, cut)

# Apply the pre-shelf in the frequency domain, then return to the time domain
fft_pre = fft_wave * pre_shelf[np.newaxis, :]
wave_pre = np.fft.irfft(fft_pre, n=waveform.shape[-1], axis=-1)

# Apply distortion via tanh function
waveform_distorted = np.tanh(distort_amount * wave_pre)

# Apply the post-shelf to the distorted signal in the frequency domain
fft_post = np.fft.rfft(waveform_distorted, axis=-1)
fft_post = fft_post * post_shelf[np.newaxis, :]
waveform_post = np.fft.irfft(fft_post, n=waveform.shape[-1], axis=-1)

# Normalize into [-1, 1] to prevent clipping. The result must be written back
# to waveform_post, because that is the array that is played and plotted below.
max_val = np.max(np.abs(waveform_post))
if max_val > 1.0:
    waveform_post = waveform_post / max_val

display(Audio(waveform_post, rate=sr))
plot_waveform(waveform, sr, other=waveform_post, channel=0, show_fft=True)