In [1]:
# Input recording: sample_callback.wav

import numpy as np
from scipy.io import wavfile
from scipy.signal import butter, ShortTimeFFT, filtfilt
from sklearn.preprocessing import minmax_scale

import matplotlib.pyplot as plt
plt.style.use('dark_background')  # heck yeah
%matplotlib widget

%load_ext autoreload
%autoreload 1

def normalize(x, range=(-1,1)):
    """Linearly rescale ``x`` so its global min and max land on ``range``.

    All elements are scaled together as one feature (a single min and max over
    the whole flattened array, not per row/column). The result is float32 and
    has the same shape as ``x``.
    """
    original_shape = x.shape
    flat = x.flatten()
    rescaled = minmax_scale(flat, feature_range=range)
    return rescaled.astype("float32").reshape(original_shape)

# Recording to analyse. `wav_path` was previously never assigned, so a fresh
# kernel raised NameError here.
# NOTE(review): filename taken from the attachment name at the top of this cell — confirm.
wav_path = "sample_callback.wav"
# fs: sample rate (Hz); audio: samples x channels for a multichannel file
fs, audio = wavfile.read(wav_path)

In [2]:
class AudioObject:
    def __init__(
        self,
        audio,
        fs,
        b,
        a,
        do_spectrogram=True,
    ):
        """
        b,a: Numerator (b) and denominator (a) polynomials of the IIR filter
        """
        self.audio = audio
        self.fs = fs

        self.audio_filt = filtfilt(b, a, audio)

        if do_spectrogram:
            self.make_spectrogram()

    def make_spectrogram(
        self,
        n=1024,
        overlap=1020,
        normalize_range=(0, 1),
    ):
        from scipy.signal.windows import hamming

        window = hamming(n)
        hop = n - overlap

        self.SFT = ShortTimeFFT(
            window,
            hop,
            fs,
            fft_mode="onesided",
        )

        spx = self.SFT.spectrogram(self.audio_filt)
        spx = np.log10(spx)

        if normalize_range is not None:
            spx = normalize(spx, normalize_range)

        self.spectrogram = spx

    def plot_spectrogram(self, **kwargs):
        return plot_spectrogram(self.spectrogram, self.SFT, **kwargs)


def plot_spectrogram(
    spectrogram: np.ndarray,
    SFT: np.ndarray,
    ax=None,
    cmap="bone",
    plot_kwargs={},
    x_offset_s=0,
):
    import matplotlib.pyplot as plt
    import numpy as np

    if ax is None:
        fig, ax = plt.subplots()

    # extent: times of audio signal (s) & frequencies (Hz). for correct axis labels
    extent = np.array(SFT.extent(SFT.hop * spectrogram.shape[1])).astype("float")
    extent[0:2] += x_offset_s  # offset x axis

    ax.imshow(
        spectrogram,
        origin="lower",
        aspect="auto",
        extent=extent,
        cmap=cmap,
    )

    ax.set(
        xlabel="Time (s)",
        ylabel="Frequency (Hz)",
        **plot_kwargs,
    )

    return ax

In [3]:
# Filter + spectrogram parameters
# Butterworth band-pass designed with N=8; for band-pass designs SciPy returns a
# filter of order 2*N = 16. AudioObject applies it with filtfilt (forward and
# backward), so the phase response is zero and the magnitude response is squared.
f_low, f_high = (500, 15000)
b, a = butter(8, [f_low, f_high], btype="bandpass", fs=fs)

# High-order filters in (b, a) form can come out numerically unstable. SciPy
# recommends second-order sections for high orders. Fail loudly here rather
# than let filtfilt run an unstable denominator.
assert np.all(np.abs(np.roots(a)) < 1), (
    "band-pass (b, a) design is unstable; consider output='sos' with sosfiltfilt"
)

n = 1024  # STFT window length (samples)
overlap = 1020  # samples shared by consecutive windows; hop = n - overlap

In [4]:
# Shared styling for every spectrogram panel. The frequency axis stops at the
# band-pass upper edge `f_high`, since content above it is attenuated by the filter.
plot_spectrogram_kwargs = {
    "cmap": "magma",
    "plot_kwargs": {
        "ylim": (0, f_high),
    },
}

In [5]:
# Channel 0: bird microphone. Spectrogram uses the config-cell `n` / `overlap`
# so it matches the time-index arithmetic used when plotting.
bird_audio = AudioObject(
    normalize(audio[:, 0]),
    fs,
    b,
    a,
    do_spectrogram=False,
)
bird_audio.make_spectrogram(n=n, overlap=overlap)

In [6]:
# Channel 3: speaker microphone. Spectrogram uses the config-cell `n` / `overlap`
# so it matches the time-index arithmetic used when plotting.
speaker_audio = AudioObject(
    normalize(audio[:, 3]),
    fs,
    b,
    a,
    do_spectrogram=False,
)
speaker_audio.make_spectrogram(n=n, overlap=overlap)

In [None]:
# Sanity check: value range of each spectrogram (np.min/np.max propagate NaN,
# so a NaN anywhere shows up here).
for name, audio_obj in (("bird", bird_audio), ("speaker", speaker_audio)):
    spec = audio_obj.spectrogram
    print("%s | Min: %.3f | Max: %.3f" % (name, np.min(spec), np.max(spec)))

In [None]:
# Spectral subtraction between the two mics: for each weight k, plot
# max(X - k*Y, 0) in both directions next to the raw spectrograms.
ks = [0.8, 0.9, 1]


def subtract(x, y, k):
    """Return ``max(x.spectrogram - k * y.spectrogram, 0)`` elementwise."""
    return np.maximum(x.spectrogram - k * y.spectrogram, 0)


snippet = [26, 30]  # only plot between these times (in seconds)

# Convert the snippet to column indices with the STFT that actually produced the
# spectrograms (columns are `hop` samples apart).
sft = bird_audio.SFT
st, en = (np.array(snippet) * bird_audio.fs / sft.hop).astype("int")
n_cols = bird_audio.spectrogram.shape[1]
assert 0 <= st < en <= n_cols, (
    f"snippet {snippet} s falls outside the {n_cols}-column spectrogram"
)

to_plot = {
    "bird mic": bird_audio.spectrogram,
    "speaker mic": speaker_audio.spectrogram,
}
for k in ks:
    to_plot[r"$B-%.2fS$" % (k)] = subtract(bird_audio, speaker_audio, k)
    to_plot[r"$S-%.2fB$" % (k)] = subtract(speaker_audio, bird_audio, k)

fig, axs = plt.subplots(len(ks) + 1, 2, figsize=[4 * (len(ks) + 1), 8])

for (t, x), ax in zip(to_plot.items(), axs.ravel()):
    plot_spectrogram(
        x[:, st:en],
        sft,
        ax=ax,
        # time of column `st` (after integer truncation), not the nominal snippet start
        x_offset_s=st * sft.delta_t,
        **plot_spectrogram_kwargs,
    )
    ax.set(
        title=t,
        xlabel=None,
        ylabel=None,
    )

fig.tight_layout()