In [1]:
from pathlib import Path
import numpy as np
import torch
from tqdm import tqdm

%matplotlib ipympl

import ipywidgets as widgets
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.axes._axes import Axes

from IPython.display import Audio
# Audio(data=signal.T,rate=fe)
from scipy.io import wavfile
from scipy.signal import stft,istft

In [2]:
# Number of numbered example sub-folders (0000, 0001, ...) in each dataset split;
# used below to bound the example-index slider.
datasets_sizes = dict(
    train_small=50,
    train=5000,
    test=2000,
)

In [12]:
def load_signals(folder: Path) -> dict[str,tuple[str,int,np.ndarray]]:
    """
    Read every wav file of an example folder and file it under its role.

    A file whose name contains "voice" goes under the key "voice". Otherwise, a
    file whose name contains "noise" goes under "noise", and any other file goes
    under "mix". Each value is a triplet (file name, sampling rate fe, signal
    as returned by ``scipy.io.wavfile.read``). If two files map to the same role,
    the one seen later in ``folder.iterdir()`` wins.
    """
    def role_of(file_name: str) -> str:
        # "voice" is tested before "noise"; anything else is the mixture.
        for keyword in ("voice", "noise"):
            if keyword in file_name:
                return keyword
        return "mix"

    signals = {}
    for path in folder.iterdir():
        assert path.is_file()
        role = role_of(path.name)
        rate, data = wavfile.read(path)
        signals[role] = (path.name, rate, data)
    return signals


def visualize_signal(folder, axes=None, nperseg=400, nfft=512, noverlap=100):
    """
    Draw the spectrograms of the "voice", "noise" and "mix" signals of one example.

    Parameters
    ----------
    folder : Path
        Example folder, read with ``load_signals``.
    axes : sequence of 3 matplotlib Axes, optional
        Axes to draw into, in the order voice, noise, mix. Defaults to the
        module-level ``axs`` created by the figure-setup code.
    nperseg, nfft, noverlap : int
        STFT parameters forwarded to ``scipy.signal.stft``. The defaults are
        the values previously hard-coded here.

    Returns
    -------
    list
        The colorbar created for each axes, so the caller can remove them
        before drawing the next example.
    """
    if axes is None:
        axes = axs
    signals = load_signals(folder)
    colorbars = []
    for ax, name in zip(axes, ["voice", "noise", "mix"]):
        file_name, fe, signal = signals[name]
        f_spec, t_spec, spec = stft(
            signal, fs=fe,
            nperseg=nperseg, nfft=nfft, noverlap=noverlap)
        # |spec| is an amplitude, so its level in dB is 20*log10(|spec|).
        # Perfectly silent bins give -inf (divide warning silenced). They are
        # left as non-finite values for pcolormesh. They are also kept out of
        # the colour limits so that vmin cannot become -inf/NaN.
        with np.errstate(divide="ignore"):
            spec_dB = 20*np.log10(np.abs(spec))
        finite_dB = spec_dB[np.isfinite(spec_dB)]
        if finite_dB.size:
            vmin, vmax = np.percentile(finite_dB, [10, 99])
        else:  # entirely silent signal: let matplotlib pick the limits
            vmin = vmax = None
        plt_obj = ax.pcolormesh(t_spec, f_spec, spec_dB, vmin=vmin, vmax=vmax)
        colorbar = plt.colorbar(plt_obj, ax=ax)
        colorbars.append(colorbar)
        ax.set_title(file_name)
        ax.tick_params(axis='both', which='major', labelsize=4)
        colorbar.ax.tick_params(axis='both', which='major', labelsize=4)
        # With fs given, scipy.signal.stft returns times in seconds and
        # frequencies in Hz.
        ax.set_xlabel("Temps (s)", fontsize=4)
        ax.set_ylabel("Fréquence (Hz)", fontsize=4)
    return colorbars


folder = "train_small"
# Close any "viz" figure left over from a previous run of this cell.
plt.close("viz")
fig_viz = plt.figure("viz", figsize=(7, 2.5))
# Each signal gets one full GridSpec column. Keeping the GridSpec cells lets
# update() put every axes back in place after its colorbar is removed.
gs = gridspec.GridSpec(2, 3)
ax_slices = [(slice(None), col) for col in range(3)]
axs = [fig_viz.add_subplot(gs[cell]) for cell in ax_slices]
colorbars = []

@widgets.interact(i=(0,datasets_sizes[folder]-1,1))
def update(i=1):
    """
    Redraw the three spectrograms of example ``i`` of ``folder`` in the "viz" figure.

    Called by the ipywidgets slider whenever its value changes. It replaces the
    module-level ``colorbars`` list with the colorbars of the new drawing.
    """
    global colorbars
    fig_viz = plt.figure("viz")
    for ax in axs:
        ax.cla()
    # plt.colorbar(ax=ax) shrinks each axes to make room for its colorbar.
    # The previous colorbars are therefore deleted, and every axes is put back
    # on its original GridSpec cell before new colorbars are drawn.
    if colorbars:
        try:
            for colorbar in colorbars:
                fig_viz.delaxes(colorbar.ax)
        except Exception as err:
            print(f'got an error while removing the colorbars: {err!r}')
        # Forget the old colorbars even if removal failed, so a later call never
        # tries to delete axes that are already gone.
        colorbars = []
        gs = gridspec.GridSpec(2,3)
        for ax,sli in zip(axs,ax_slices):
            ax.set_position(gs[sli].get_position(fig_viz))
            ax.set_subplotspec(gs[sli])
    # Example folders are zero-padded to 4 digits (0000, 0001, ...).
    folder_i = Path(folder) / f"{i:04d}"
    colorbars = visualize_signal(folder_i)
    fig_viz.suptitle(f"Spectrogrammes pour {folder_i}")
    fig_viz.tight_layout()
    plt.show()
    


interactive(children=(IntSlider(value=1, description='i', max=49), Output()), _dom_classes=('widget-interact',…