In [91]:
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import time

%matplotlib ipympl

import ipywidgets as widgets
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.axes._axes import Axes

from IPython.display import Audio
# Audio(data=signal.T,rate=fe)
from scipy.io import wavfile
from scipy.signal import stft,istft

<font color="red" size=6><b>Meta paramètres</b></font>

Comme calculer les spectrogrammes prend du temps, nous pouvons les calculer une fois pour toutes puis les sauvegarder sur le disque. Toutefois, <font color="red"><b>ceci triple l'espace occupé sur disque</b></font> : on passe de 6.3 Go à 17.2 Go. Veuillez donc régler le paramètre `SAVE_SPECTROGRAMS` selon que vous pouvez ou non consacrer cet espace disque.

In [2]:
# Root directory of the dataset; the "train", "train_small" and "test"
# sub-folders used below are resolved relative to it.
DIRECTORY = Path("source_separation")
# When True, spectrogram magnitudes are cached on disk as .pt files next to
# the wav files and re-loaded instead of being recomputed.
SAVE_SPECTROGRAMS = True

# Chargement des signaux, Visualisation et Dataset

## Chargement des signaux et spectrogrammes

In [3]:
# One sub-folder of DIRECTORY per dataset split.
train_folder = DIRECTORY / "train"
train_small_folder = DIRECTORY / "train_small"
test_folder = DIRECTORY / "test"
def get_path(folder, i):
    """
    Return the sub-folder of `folder` holding sample number `i`.

    Sample folders are named with the index zero-padded to 4 digits
    ("0000", "0001", ...). Indices that need more than 4 digits keep all
    their digits instead of being silently truncated to the last four
    (which would point at another sample's folder).
    """
    return folder / str(i).zfill(4)

# Number of samples per split; the code below iterates over
# range(datasets_sizes[folder]) to enumerate the sample folders.
datasets_sizes = {
    train_small_folder : 50,
    train_folder : 5000,
    test_folder : 2000
}

Tous les signaux ont la même fréquence d'échantillonnage et la même longueur, et donc les mêmes f et t échantillonnés pour le spectrogramme.

In [83]:
# Keyword arguments shared by the stft and istft calls below, so that
# analysis and synthesis use the same framing.
stft_kwargs = {
    "nperseg":400,
    "nfft":512,
    "noverlap":149 # 86 was the other value tried (see the U-Net remark further down)
}
def check_all_same(folder=train_small_folder):
    """
    Check that every "voice.wav" of `folder` has the same sampling rate and
    yields the same STFT frequency and time grids.

    Parameters:
        folder: dataset folder to scan (must be a key of `datasets_sizes`).

    Returns:
        (fe, f_ref, t_ref): the common sampling rate, and the frequency and
        time arrays returned by `stft` for the first sample.

    Raises:
        AssertionError: if the folder holds no sample, or if one sample
        differs from the first one.
    """
    fe_ref = None
    f_ref = None
    t_ref = None
    # Scan the requested folder (the parameter used to be ignored in favour
    # of train_small_folder).
    for i in range(datasets_sizes[folder]):
        fe,signal = wavfile.read(get_path(folder,i) / "voice.wav")
        f_spec,t_spec,_ = stft(signal,fs=fe,**stft_kwargs)
        if f_ref is None:
            fe_ref,f_ref,t_ref = fe,f_spec,t_spec
        assert fe == fe_ref, f"sample {i}: sampling rate {fe} != {fe_ref}"
        # Compare shapes first: np.allclose would broadcast or raise on a mismatch.
        assert f_spec.shape == f_ref.shape and np.allclose(f_spec,f_ref)
        assert t_spec.shape == t_ref.shape and np.allclose(t_spec,t_ref)
    assert f_ref is not None, f"no sample found in {folder}"
    return fe_ref,f_ref,t_ref

# Common sampling rate and STFT grids, reused by all the helpers below.
fe,f_ref,t_ref = check_all_same()

def get_spectrogram(signal):
    """Return the complex STFT matrix of `signal` (third output of scipy's `stft`)."""
    _freqs,_times,spec = stft(signal,fs=fe,**stft_kwargs)
    return spec

def reverse_spectrogram(spec):
    """
    Invert a complex STFT matrix back to a time signal with scipy's `istft`,
    keeping only the first 80000 samples of the reconstruction.
    """
    _times,signal = istft(spec,fe,**stft_kwargs)
    return signal[:80000]

# Check that the spectrogram is indeed invertible:
def test():
    """
    Pick a random training sample, compute its spectrogram, invert it and
    print the shapes together with whether the round trip matches the
    original signal.
    """
    idx = np.random.randint(0,datasets_sizes[train_folder])
    _rate,signal = wavfile.read(get_path(train_folder,idx)/"voice.wav")
    resignal = reverse_spectrogram(spec := get_spectrogram(signal))
    report = (
        ("Spectrogram shape:",spec.shape),
        ("Original signal",signal.shape),
        ("Reconstructed signal",resignal.shape),
        ("Allclose",np.allclose(signal,resignal)),
    )
    for label,value in report:
        print(label,value)
test()

Spectrogram shape: (257, 320)
Original signal (80000,)
Reconstructed signal (80000,)
Allclose True
(320,)


In [45]:
def torch_load(file):
    """
    Load a file saved with `torch.save`, using `weights_only=True` when the
    installed torch version supports it.

    Recent torch versions warn when `weights_only` is not given, while old
    versions do not know this keyword and raise a TypeError. Only that
    TypeError triggers the fallback to a plain `torch.load`: any other
    failure (missing file, corrupted data, content refused by the safe
    loader) is propagated instead of being silently retried without the
    safety flag.
    """
    try:
        return torch.load(file,weights_only=True)
    except TypeError:
        # The first attempt may have consumed part of a file-like object.
        if hasattr(file,"seek"):
            file.seek(0)
        return torch.load(file)

def load_signal_folder(folder: Path,
        load_signals=True,
        load_spectrograms=True) -> dict[str,dict]:
    """
    Load one sample folder.

    Return a dictionary with 3 sub dicts: "voice", "noise" and "mix"; and an
    "SNR" key (the last "_"-separated field of the mix wav file name, kept as
    a string).
    Each sub dict may hold:
      - "filename" and "signal": only when `load_signals` is True;
      - "spectrogram": only when `load_spectrograms` is True. It is always a
        torch tensor holding the STFT magnitude, whether it was read from a
        cached .pt file or computed on the fly.

    Cached .pt files are only used (and written) when SAVE_SPECTROGRAMS is
    True. When a spectrogram has to be computed and the signals were not
    requested, the wav file is read just for that computation.

    Raises:
        AssertionError: if the folder contains something that is not a file.
        KeyError: if a spectrogram must be computed but the wav file is missing.
    """
    keys = ["voice","noise","mix"]
    res = dict((k,dict()) for k in keys)
    wav_paths = {} # key -> wav file, used to compute missing spectrograms
    for f in folder.iterdir():
        assert f.is_file()
        if "voice" in f.name: key = "voice"
        elif "noise" in f.name: key = "noise"
        else:
            key = "mix"
            if f.suffix == ".wav":
                res["SNR"] = f.name.removesuffix(".wav").split("_")[-1]
        if f.suffix == ".wav":
            wav_paths[key] = f
            if load_signals:
                _,signal = wavfile.read(f)
                res[key]["filename"] = f.name
                res[key]["signal"] = signal
        elif f.suffix == ".pt" and SAVE_SPECTROGRAMS and load_spectrograms:
            # when SAVE_SPECTROGRAMS is False, we shouldn't be able to load them
            # to save time, otherwise it's cheating.
            res[key]["spectrogram"] = torch_load(f)
    # Create missing spectrograms
    if load_spectrograms:
        for key in keys:
            if "spectrogram" not in res[key]:
                if "signal" in res[key]:
                    signal = res[key]["signal"]
                else:
                    _,signal = wavfile.read(wav_paths[key])
                # Same type as the cached version: a torch tensor.
                spec_abs = torch.tensor(abs(get_spectrogram(signal)))
                res[key]["spectrogram"] = spec_abs
                if SAVE_SPECTROGRAMS:
                    torch.save(spec_abs,folder / f"{key}_spectrogram.pt")
    return res


def remove_all_spectrograms():
    """
    Delete, in every sample folder of the three dataset splits, each file
    whose name contains "spectrogram".
    """
    for dataset_dir in (train_folder,train_small_folder,test_folder):
        for idx in range(datasets_sizes[dataset_dir]):
            sample_dir: Path = get_path(dataset_dir,idx)
            for entry in sample_dir.iterdir():
                if "spectrogram" not in entry.name:
                    continue
                entry.unlink()

Comme calculer les spectrogrammes prend du temps, nous pouvons les calculer une fois pour toutes, en sauvegardant toutes les amplitudes des spectrogrammes sur disque (puisque l'apprentissage travaille avec les amplitudes, il n'y a pas besoin de sauvegarder la phase). Attention, <font color="red"><b>ceci double l'espace occupé sur disque</b></font> : on passe de 6.3 Go à 12.8 Go.
Pour les retirer utiliser: `remove_all_spectrograms()`.

In [46]:
def create_spectrograms(folder: Path):
    """
    Pre-compute and cache on disk the spectrograms of every sample of `folder`.

    Does nothing when SAVE_SPECTROGRAMS is False. Every sample folder is
    checked (not only the first one), so a run that was interrupted half-way
    is completed on the next call; folders that already hold their three
    "<key>_spectrogram.pt" files are skipped without loading anything.
    """
    if not SAVE_SPECTROGRAMS:
        return
    for i in range(datasets_sizes[folder]):
        folder_i = get_path(folder,i)
        already_cached = all(
            (folder_i / f"{key}_spectrogram.pt").exists()
            for key in ("voice","noise","mix"))
        if not already_cached:
            # load_signal_folder computes and saves the missing spectrograms
            load_signal_folder(folder_i)

create_spectrograms(train_small_folder)
create_spectrograms(test_folder)
create_spectrograms(train_folder)

In [47]:
def compute_time(folder):
    """
    Measure, over all the "voice" samples of `folder`, the cumulated time
    spent (1) reading the wav files, (2) loading the cached spectrograms when
    they exist and (3) recomputing the spectrograms; then print the three totals.
    """
    n_samples = datasets_sizes[folder]
    elapsed_wav = 0
    elapsed_cached = 0
    elapsed_stft = 0
    tick = time.perf_counter()
    for idx in range(n_samples):
        sample_dir = get_path(folder,idx)

        # 1. read the raw signal
        _rate,signal = wavfile.read(sample_dir / "voice.wav")
        elapsed_wav += time.perf_counter() - tick
        tick = time.perf_counter()

        # 2. load the cached spectrogram, if any
        cached_spec: Path = sample_dir / "voice_spectrogram.pt"
        if cached_spec.exists():
            torch_load(cached_spec)
        elapsed_cached += time.perf_counter() - tick
        tick = time.perf_counter()

        # 3. recompute the spectrogram from the signal
        get_spectrogram(signal)
        elapsed_stft += time.perf_counter() - tick
        tick = time.perf_counter()

    print(f"Temps total pour load {n_samples} signaux: {elapsed_wav}")
    print(f"Temps total pour en calculer les spectrogrammes: {elapsed_stft}")
    print(f"Comparé au temps pour charger les spectrogrammes: {elapsed_cached}")

# Run twice on the same folder to compare a first pass with a repeated one.
print("Première fois:")
compute_time(test_folder)
print("\nDeuxième fois:")
compute_time(test_folder)

Première fois:
Temps total pour load 2000 signaux: 0.7348379389877664
Temps total pour en calculer les spectrogrammes: 1.9839318960275705
Comparé au temps pour charger les spectrogrammes: 3.389195533981365

Deuxième fois:
Temps total pour load 2000 signaux: 0.20108442701166496
Temps total pour en calculer les spectrogrammes: 1.7352007419658548
Comparé au temps pour charger les spectrogrammes: 0.5324180929983413


On voit que, une fois les fichiers en cache (deuxième passage), charger les spectrogrammes pré-calculés est plus rapide que de les recalculer à chaque fois. On note aussi une grosse différence entre la première fois qu'un fichier est chargé et la seconde : j'imagine que le système place les derniers fichiers chargés dans le cache (recharger le notebook n'y change rien, donc la différence n'apparaît que la toute première fois).

## Visualization

<font color="green">Note : les versions très récentes de `ipympl` ont une erreur de frappe dans le code, avec une variable nommée "buttons" au lieu de "button". Ainsi, si vous avez une version instable de `ipympl`, il se peut que d'un coup la cellule interactive ci-dessous écrive des dizaines d'erreurs à la chaîne. Nous n'avons pas trouvé ce qui les déclenche, mais vous pouvez les ignorer (puisqu'elles n'empêchent pas la cellule de tourner) ; sinon vous pouvez simplement ouvrir le fichier d'où vient l'erreur (en cliquant sur la ligne d'erreur qui s'affiche des dizaines de fois), et changer "buttons" par "button".</font>

In [48]:
def visualize_signal(folder):
    """
    Draw the voice, noise and mix spectrograms of one sample folder on the
    three module-level axes `axs`, and return the list of created colorbars.

    The spectrograms are STFT magnitudes, hence the dB scale is
    20*log10(magnitude). Magnitudes are floored to the smallest positive
    float so that silent bins do not produce -inf. The axes come from scipy's
    `stft` called with the sampling rate in Hz: time in seconds, frequency in Hz.
    """
    signals = load_signal_folder(folder)
    colorbars = []
    for ax,name in zip(axs,["voice","noise","mix"]):
        # np.asarray accepts both torch tensors and numpy arrays
        magnitude = np.asarray(signals[name]["spectrogram"],dtype=float)
        magnitude = np.maximum(magnitude,np.finfo(float).tiny)
        spec_dB = 20*np.log10(magnitude)
        plt_obj = ax.pcolormesh(
            t_ref,f_ref,spec_dB,
            vmax=np.percentile(spec_dB,99),
            vmin=np.percentile(spec_dB,10))
        colorbars.append(plt.colorbar(plt_obj,ax=ax))
        ax.set_title(signals[name]["filename"])
        ax.tick_params(axis='both', which='major', labelsize=4)
        colorbars[-1].ax.tick_params(axis='both', which='major', labelsize=4)
        ax.set_xlabel("Temps (s)",fontsize=4)
        ax.set_ylabel("Fréquence (Hz)",fontsize=4)
    return colorbars


# Dataset split browsed by the interactive viewer below.
folder = train_small_folder
# Close a previously opened "viz" figure before creating a new one
# (presumably so that re-running the cell does not stack figures).
plt.close("viz")
fig_viz = plt.figure("viz",figsize=(7,2.5))
gs = gridspec.GridSpec(2,3) # to get good control of the color bars
# One axes per grid column, each spanning both rows.
ax_slices = [np.s_[:,i] for i in range(3)]
axs = [fig_viz.add_subplot(gs[sli]) for sli in ax_slices]
# Colorbars currently drawn; replaced at each call of `update` below.
colorbars = []

@widgets.interact(i=(0,datasets_sizes[folder]-1,1))
def update(i=1):
    """
    Slider callback: redraw the "viz" figure with the spectrograms of sample
    `i` of `folder`.

    The colorbars of the previous draw are removed and the three axes are put
    back on their original grid cells before drawing again. That clean-up is
    best-effort: a failure is reported (with its cause) but does not prevent
    the redraw.
    """
    global colorbars
    for ax in axs: ax.cla()
    fig_viz = plt.figure("viz")
    plt.title("")
    try:
        if colorbars != []:
            for colorbar in colorbars:
                fig_viz.delaxes(colorbar.ax)
            gs = gridspec.GridSpec(2,3)
            for ax,sli in zip(axs,ax_slices):
                ax.set_position(gs[sli].get_position(fig_viz))
                ax.set_subplotspec(gs[sli])
    except Exception as err:
        # Keep the viewer alive, but say what went wrong instead of hiding it.
        print('got an error:',repr(err))
    folder_i = get_path(folder,i)
    colorbars = visualize_signal(folder_i)
    fig_viz.suptitle(f"Spectrogrammes pour {Path(folder_i.parent.name)/folder_i.name}")
    fig_viz.tight_layout()
    plt.show()
    


interactive(children=(IntSlider(value=1, description='i', max=49), Output()), _dom_classes=('widget-interact',…

## Dataset

In [49]:
class MyDataset(torch.utils.data.Dataset):
    """
    Dataset over the sample folders of one split.

    Each item is a list holding, for "voice", "noise" and "mix" in that order,
    the signal (if `load_signals`) followed by the spectrogram (if
    `load_spectrograms`), and finally the SNR of the sample.
    """
    def __init__(self,
            folder: Path,
            load_signals=True,
            load_spectrograms=False):
        self.folder = folder
        self.load_signals = load_signals
        self.load_spectrograms = load_spectrograms

    def __len__(self):
        """Number of samples declared for this split in `datasets_sizes`."""
        return datasets_sizes[self.folder]

    def __getitem__(self,i):
        """Load sample `i` and flatten the requested fields into a list."""
        sample = load_signal_folder(
            get_path(self.folder,i),
            load_signals=self.load_signals,
            load_spectrograms=self.load_spectrograms)
        # Fields to extract for each source, in output order.
        wanted = []
        if self.load_signals:
            wanted.append("signal")
        if self.load_spectrograms:
            wanted.append("spectrogram")
        item = [
            sample[source][field]
            for source in ("voice","noise","mix")
            for field in wanted]
        item.append(sample["SNR"])
        return item

Le Dataset peut contenir les signaux et/ou les spectrogrammes, de sorte à ne charger que le nécessaire. Exemple si on veut tout charger :

In [50]:
# Demo: load everything. With both flags on, each batch holds, for voice,
# noise and mix in turn, the signal then the spectrogram, followed by the SNRs.
train_dataset = MyDataset(train_folder,load_signals=True,load_spectrograms=True)
train_dataloader = DataLoader(train_dataset,batch_size=32,shuffle=True)
print("Dataset length:",len(train_dataset))
for voice_signal,voice_spec,noise_signal,noise_spec,mix_signal,mix_spec,snr in train_dataloader:
    print("Signal's shape: ",voice_signal.shape)
    print("Spectrogram's shape: ",voice_spec.shape)
    print("SNRs :",snr)
    break # only the first batch is inspected

Dataset length: 5000
Signal's shape:  torch.Size([32, 80000])
Spectrogram's shape:  torch.Size([32, 257, 320])
SNRs : ('0', '1', '-1', '-1', '2', '0', '0', '1', '0', '3', '3', '3', '-4', '2', '-1', '2', '0', '4', '2', '1', '-3', '-3', '2', '-1', '0', '0', '3', '-1', '0', '-4', '0', '1')


# Seq2Seq using a U-Net: Singing Voice Separation With Deep U-Net Convolutional Networks

D'après le [papier de A. Jansson et al](https://openaccess.city.ac.uk/id/eprint/19289/1/7bb8d1600fba70dd79408775cd0c37a4ff62.pdf)

Et pour [l'architecture](https://github.com/phillipi/pix2pix/blob/master/models.lua) (qui vient du papier [Image-to-Image Translation with Conditional Adversarial Networks](https://arxiv.org/pdf/1611.07004))

In [51]:
# Spectrogram-only dataset for this section. Note that this rebinds the
# train_dataset / train_dataloader names created in the demo cell above.
train_dataset = MyDataset(
    train_folder,
    load_signals=False,
    load_spectrograms=True)
train_dataloader = DataLoader(train_dataset,batch_size=32,shuffle=True)

In [52]:
# Fetch a single batch; voice_spec, noise_spec, mix_spec and snr stay bound
# after the loop and are reused by later cells.
for voice_spec,noise_spec,mix_spec,snr in train_dataloader:
    print("Spec's shape:",voice_spec.shape)
    print("SNRs:",snr)
    break

Spec's shape: torch.Size([32, 257, 320])
SNRs: ('-2', '-1', '0', '0', '-4', '2', '1', '-1', '4', '1', '4', '4', '2', '-1', '3', '3', '0', '-4', '0', '-3', '0', '-2', '-2', '0', '0', '-2', '0', '3', '3', '-1', '-4', '-1')


<font color="green"><b>Remarque / Astuce</b></font>:
Le réseau UNet divise chaque dimension de la matrice par des puissances de 2 (ici $2^6=64$ au total), puis la remultiplie d'autant. Ainsi, pour garder la même dimension en sortie, il est nécessaire que chaque dimension de la matrice d'entrée soit un multiple de 64. C'est pourquoi, pour calculer les spectrogrammes, nous prenons `noverlap=149` (et non 100 comme indiqué dans le cours), ce qui donne $320 = 64 \times 5$ trames temporelles (la 2nde dimension). (`noverlap=86` donne $256$ mais n'est pas toujours inversible...). Concernant la 1ère dimension, elle vaut $257$ : nous ignorons la dernière valeur (ce qui ramène à $256 = 64 \times 4$), et en sortie le masque de la 257e fréquence recopie celui de la 256e, ce qui n'est pas idéal...

In [90]:
def change_dim_257_to_256(x):
    """Keep only the first 256 entries along the second-to-last axis of `x`."""
    kept_rows = slice(0,256)
    return x[...,kept_rows,:]

def change_dim_256_to_257(y):
    """
    Append, along the second-to-last axis of `y`, a copy of its last entry
    (so that axis grows by one).
    """
    last_row = y[...,-1:,:] # slicing with -1: keeps the axis
    return torch.cat((y,last_row),dim=-2)

In [149]:
class UNet(torch.nn.Module):
    """
    U-Net predicting a multiplicative mask for a magnitude spectrogram.

    Encoder: one initial stride-2 convolution followed by `n_levels` stride-2
    convolutions (LeakyReLU(0.2) before each, BatchNorm after all but the
    deepest one). Decoder: `n_levels` stride-2 transposed convolutions with
    BatchNorm, dropout on the first three, skip connections concatenated on
    the channel axis, then a final transposed convolution back to 1 channel.

    NOTE(review): the mask is the raw output of the last layer (no
    activation); the referenced paper appears to use a sigmoid there —
    confirm this is intended.
    """
    def __init__(self,n_levels=5):
        super().__init__()
        self.n_levels = n_levels
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.drop_out = nn.Dropout(0.5)
        ngf = 16
        # Channel count at each depth: ngf, 2*ngf, 4*ngf, ...
        channels = [ngf * 2**depth for depth in range(n_levels+1)]
        # Settings shared by every 5x5, stride-2 (transposed) convolution
        conv_kwargs = dict(kernel_size=(5,5),stride=2,padding=5//2)

        # Encoder
        self.down_convs = nn.ModuleList()
        self.down_norms = nn.ModuleList()
        self.first_conv = nn.Conv2d(1,ngf,5,2,2)
        for depth,(c_in,c_out) in enumerate(zip(channels[:-1],channels[1:])):
            self.down_convs.append(nn.Conv2d(c_in,c_out,**conv_kwargs))
            deepest = (depth == n_levels-1)
            if not deepest: # no normalisation on the bottleneck
                self.down_norms.append(nn.BatchNorm2d(c_out))

        # Decoder
        self.up_convs = nn.ModuleList()
        self.up_norms = nn.ModuleList()
        self.last_conv = nn.ConvTranspose2d(2*ngf,1,5,2,5//2,output_padding=1)
        for step in range(n_levels):
            depth = n_levels - step
            # Every step but the first receives a skip connection,
            # which doubles its number of input channels.
            c_in = channels[depth] if step == 0 else 2*channels[depth]
            c_out = channels[depth-1]
            self.up_convs.append(nn.ConvTranspose2d(
                c_in,c_out,output_padding=1,**conv_kwargs))
            self.up_norms.append(nn.BatchNorm2d(c_out))

    def forward(self,x: torch.Tensor,print_shapes=False):
        """
        Compute the mask and the masked input.

        Parameters:
            x: batch of spectrograms; the printed example uses shape (B,257,320).
            print_shapes: when True, print the tensor shape after each stage.

        Returns:
            (mask, output) where output = mask * x, both with the shape of x.
        """
        def show(tensor):
            if print_shapes: print(tensor.shape)

        # Preprocess: drop the last frequency row and add a channel axis
        feat = change_dim_257_to_256(x).unsqueeze(1) # => Shape (B,1,256,320)
        show(feat)

        # Encoder
        feat = self.first_conv(feat)
        skips = [feat]
        show(feat)
        for depth in range(self.n_levels):
            feat = self.down_convs[depth](self.leaky_relu(feat))
            if depth < self.n_levels-1:
                feat = self.down_norms[depth](feat)
                skips.append(feat)
            show(feat)

        if print_shapes: print("-")
        # Decoder
        for step in range(self.n_levels):
            if step > 0:
                feat = torch.cat([feat,skips.pop()],dim=1)
            feat = self.up_convs[step](torch.relu(feat))
            feat = self.up_norms[step](feat)
            if step < 3: # dropout only on the first three decoder steps
                feat = self.drop_out(feat)
            show(feat)
        feat = torch.cat([feat,skips.pop()],dim=1)
        feat = self.last_conv(torch.relu(feat))
        show(feat)

        # Postprocess: remove the channel axis and restore the dropped row
        mask = change_dim_256_to_257(feat.squeeze(1)) # => Shape (B,257,320)
        output = mask * x
        if print_shapes: print("Output shape:",output.shape)
        return mask, output


        

In [151]:
# Smoke test: run the network on the batch fetched earlier (voice_spec is
# left bound by the dataloader loop) and print the shape after each stage.
model = UNet()
mask,output = model(voice_spec,print_shapes=True)

torch.Size([32, 1, 256, 320])
torch.Size([32, 16, 128, 160])
torch.Size([32, 32, 64, 80])
torch.Size([32, 64, 32, 40])
torch.Size([32, 128, 16, 20])
torch.Size([32, 256, 8, 10])
torch.Size([32, 512, 4, 5])
-
torch.Size([32, 256, 8, 10])
torch.Size([32, 128, 16, 20])
torch.Size([32, 64, 32, 40])
torch.Size([32, 32, 64, 80])
torch.Size([32, 16, 128, 160])
torch.Size([32, 1, 256, 320])
Output shape: torch.Size([32, 257, 320])


In [139]:
# Scratch check: a stride-2 5x5 convolution followed by the matching
# transposed convolution (output_padding=1) gives back the input shape.
x = voice_spec
x = change_dim_257_to_256(x)
x = x.unsqueeze(1)
conv = torch.nn.Conv2d(1,1,5,2,5//2)
print("input shape",x.shape)
y = conv(x)
print("Output shape",y.shape)
unconv = torch.nn.ConvTranspose2d(1,1,5,2,2,output_padding=1)
z = unconv(y)
print("z: ",z.shape)

input shape torch.Size([32, 1, 256, 320])
Output shape torch.Size([32, 1, 128, 160])
z:  torch.Size([32, 1, 256, 320])


In [102]:
# Scratch check in 1D: a stride-2 convolution (kernel 3, padding 1) halves the length.
# Note: this rebinds x, conv and y from the previous cell.
x = torch.randn(32,1,128)
conv = torch.nn.Conv1d(1,1,3,2,1)
print("input shape",x.shape)
y = conv(x)
print("Output shape",y.shape)

input shape torch.Size([32, 1, 128])
Output shape torch.Size([32, 1, 64])


In [126]:
# Scratch check of a descending range: range(n,0,-1) yields n, n-1, ..., 1
# (the same form of range is used by the UNet decoder loops).
for i in range(10,0,-1):
    print(i)

10
9
8
7
6
5
4
3
2
1
