In [30]:
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import time

import tqdm

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.axes._axes import Axes

from scipy.io import wavfile
from scipy.signal import stft,istft

In [2]:
#Set the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


<font color="red" size=6><b>Meta paramètres</b></font>

Comme calculer les spectrogrammes prend du temps, nous pouvons les calculer une fois pour toute puis les sauvegarder sur le disque. Toutefois, <font color="red"><b>ceci triple l'espace occupé sur disque</b></font>: passant de 6.3 Go à 17.2 Go. Veuillez donc préciser le paramètre `SAVE_SPECTROGRAMS` selon si vous pouvez utiliser cet espace.

In [3]:
DIRECTORY = Path("source_separation")
SAVE_SPECTROGRAMS = True

# Chargement des signaux, Visualisation et Dataset

## Chargement des signaux et spectrogrammes

In [4]:
train_folder = DIRECTORY / "train"
train_small_folder = DIRECTORY / "train_small"
test_folder = DIRECTORY / "test"
get_path = lambda folder,i : folder / ("000"+str(i))[-4:]

datasets_sizes = {
    train_small_folder : 50,
    train_folder : 5000,
    test_folder : 2000
}

Tous les signaux ont la même fréquence d'échantillonage, même longueur; et donc les mêmes f et t échantillonés pour le Spectrogramme.

In [5]:
def check_all_same(folder=train_small_folder):
    f_ref = None
    t_ref = None
    for i in range(datasets_sizes[train_small_folder]):
        fe,signal = wavfile.read(get_path(train_small_folder,i) / "voice.wav")
        len_signal = len(signal)
        f_spec,t_spec,spec = stft(
            signal,fs=fe,
            nperseg=400,nfft=512,noverlap=100)
        if f_ref is None: f_ref = f_spec ; t_ref = t_spec
        assert np.allclose(f_spec,f_ref)
        assert np.allclose(t_spec,t_ref)
    return fe,f_ref,t_ref

fe,f_ref,t_ref = check_all_same()
get_spectrogram = lambda signal : stft(signal,fs=fe,nperseg=400,nfft=512,noverlap=100)[2]

In [12]:
def load_signal_folder(folder: Path,
        load_signals=True,
        load_spectrograms=True) -> dict[str,dict]:
    """
    Return a dictionary with 3 sub dicts: "voice", "noise" and "mix"; and an "SNR" key.
    Each sub dict has 3 keys: "filename", "signal" and "spectrogram" 
    (except if load_signals or load_spectrograms are set to False)
    """
    keys = ["voice","noise","mix"]
    res = dict((k,dict()) for k in keys)
    for f in folder.iterdir():
        assert f.is_file()
        if "voice" in f.name: key = "voice"
        elif "noise" in f.name: key = "noise"
        else: 
            key = "mix"
            if f.suffix == ".wav":
                res["SNR"] = f.name.removesuffix(".wav").split("_")[-1]
        if f.suffix == ".wav" and load_signals:
            fe,signal = wavfile.read(f)
            res[key]["filename"] = f.name
            res[key]["signal"] = signal
        elif f.suffix == ".pt" and SAVE_SPECTROGRAMS and load_spectrograms: 
            # when SAVE_SPECTROGRAMS is False, we shouldn't be able to load them 
            # to save time, otherwise it's cheating.
            res[key]["spectrogram"] = torch.load(f)
    # Create missing spectrograms
    if load_spectrograms:
        for key in keys:
            if "spectrogram" not in res[key]:
                assert load_signals
                spec = get_spectrogram(res[key]["signal"])
                res[key]["spectrogram"] = spec
                if SAVE_SPECTROGRAMS:
                    torch.save(torch.tensor(spec),folder / f"{key}_spectrogram.pt")
    return res


def remove_all_spectrograms():
    for folder in [train_folder,train_small_folder,test_folder]:
        for i in range(datasets_sizes[folder]):
            folder_i: Path = get_path(folder,i)
            for f in folder_i.iterdir():
                if "spectrogram" in f.name:
                    f.unlink()

Comme calculer les spectrogrammes prend du temps, nous pouvons les calculer une fois pour toute, en sauvegardant tous les spectrogrammes sur disque. Attention, <font color="red"><b>ceci triple l'espace occupé sur disque</b></font>: passant de 6.3 Go à 17.2 Go.
Pour les retirer utiliser: `remove_all_spectrograms()`.

In [7]:
def create_spectrograms(folder: Path):
    if SAVE_SPECTROGRAMS and not (folder/"0000"/"voice_spectrogram.pt").exists():
        for i in range(datasets_sizes[folder]):
            load_signal_folder(get_path(folder,i))

create_spectrograms(train_small_folder)
create_spectrograms(test_folder)
create_spectrograms(train_folder)

In [9]:
def compute_time(folder):
    time_load_signal = 0
    time_load_spec = 0
    time_spec = 0
    start = time.perf_counter()
    for i in range(datasets_sizes[folder]):
        fe,signal = wavfile.read(get_path(folder,i) / "voice.wav")
        time_load_signal += time.perf_counter() - start ; start = time.perf_counter()
        path_spec: Path = get_path(folder,i) / "voice_spectrogram.pt"
        if path_spec.exists():
            _ = torch.load(path_spec)
        time_load_spec += time.perf_counter() - start ; start = time.perf_counter()
        _ = get_spectrogram(signal)
        time_spec += time.perf_counter() - start ; start = time.perf_counter()

    print(f"Temps total pour load {datasets_sizes[folder]} signaux: {time_load_signal}")
    print(f"Temps total pour en calculer les spectrogrammes: {time_spec}")
    print(f"Comparé au temps pour charger les spectrogrammes: {time_load_spec}")

print("Première fois:")
compute_time(test_folder)
print("\nDeuxième fois:")
compute_time(test_folder)

Première fois:
Temps total pour load 2000 signaux: 7.638150982558727
Temps total pour en calculer les spectrogrammes: 9.375799618661404
Comparé au temps pour charger les spectrogrammes: 9.05385460332036

Deuxième fois:
Temps total pour load 2000 signaux: 0.47687792032957077
Temps total pour en calculer les spectrogrammes: 3.56374404206872
Comparé au temps pour charger les spectrogrammes: 1.2208664305508137


On voit que charger les spectrogrammes pré-calculés est plus rapide que de les calculer à chaque fois. On note aussi une grosse différence entre la première fois qu'un fichier est chargé et la seconde. Nous pensons que le système place les derniers fichiers chargés dans le cache (recharger le notebook n'y change rien, donc la différence n'apparait que la toute première fois).

## Dataset

In [10]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self,
            folder: Path,
            load_signals=True,
            load_spectrograms=False):
        self.folder = folder
        self.load_signals = load_signals
        self.load_spectrograms = load_spectrograms
            
    def __len__(self):
        return datasets_sizes[self.folder]
    
    def __getitem__(self,i):
        d = load_signal_folder(
            get_path(self.folder,i),
            load_signals=self.load_signals,
            load_spectrograms=self.load_spectrograms)
        ret = []
        for name in ["voice","noise","mix"]:
            if self.load_signals:
                ret.append(d[name]["signal"])
            if self.load_spectrograms:
                ret.append(d[name]["spectrogram"])
        ret.append(d["SNR"])
        return ret

Le Dataset peut contenir les signaux et/ou les spectrogrammes. De sorte à ne charger que le nécessaire. Exemple si on veut tout charger:

In [34]:
train_dataset = MyDataset(train_folder,load_signals=True,load_spectrograms=True)
train_dataloader = DataLoader(train_dataset,batch_size=32,shuffle=True)
print("Dataset length:",len(train_dataset))

Dataset length: 5000


In [None]:
for voice_signal,voice_spec,noise_signal,noise_spec,mix_signal,mix_spec,snr in train_dataloader:
    print("Signal's shape: ",voice_signal.shape)
    print("Spectrogram's shape: ",voice_spec.shape)
    print("SNRs :",snr)
    break

Signal's shape:  torch.Size([32, 80000])
Spectrogram's shape:  torch.Size([32, 257, 268])
SNRs : ('2', '-1', '4', '4', '2', '0', '0', '1', '2', '0', '-4', '1', '0', '2', '-4', '2', '4', '0', '1', '-2', '2', '-3', '1', '0', '-1', '1', '-4', '-4', '1', '4', '-2', '1')


# Conv-tasnet: surpassing ideal time–frequency magnitude masking for speech separation.
(Inspired from Luo et al., 2019 and from this popular Github repository : https://github.com/JusperLee/Conv-TasNet/tree/master/Conv_TasNet_Pytorch)

In [145]:
train_dataset = MyDataset(
    train_folder,
    load_signals=True,
    load_spectrograms=False)
train_dataloader = DataLoader(train_dataset,batch_size=32,shuffle=True)

print("Dataset length:",len(train_dataset))

Dataset length: 5000


In [158]:
for voice_signal,noise_signal,mix_signal,snr in train_dataloader:
    print("Spec's shape:",voice_spec.shape)
    print("SNRs:",snr)
    break

Spec's shape: torch.Size([32, 80000])
SNRs: ('4', '-4', '-3', '0', '-2', '-2', '-4', '-3', '0', '0', '-4', '3', '0', '1', '0', '-2', '-2', '2', '-4', '-3', '-1', '4', '3', '3', '2', '-3', '1', '-3', '-3', '0', '4', '4')


In [159]:
def count_n_param(model):
    return sum([p.numel() for p in model.parameters()])

In [160]:
class ConvTasNet(nn.Module):
    def __init__(self, 
                 num_sources=2, 
                 num_filters=512, 
                 kernel_size=16, 
                 stride=8, 
                 bottleneck_channels=128, 
                 num_blocks=8, 
                 num_repeats=3):
        super(ConvTasNet, self).__init__()

        # Encoder
        self.encoder = nn.Conv1d(32, num_filters, kernel_size, stride, padding=kernel_size//2, bias=False)

        # Separator
        self.separator = TemporalConvNet(num_filters, 
                                         bottleneck_channels, 
                                         num_blocks, 
                                         num_repeats, 
                                         num_sources)

        # Decoder
        self.decoder = nn.ConvTranspose1d(num_filters, 1, kernel_size, stride, padding=kernel_size//2, bias=False)

    def forward(self, mixture, print_shapes=False):
        """
        Args:
            mixture (torch.Tensor): Mixture audio signal, shape (batch, 1, time).

        Returns:
            torch.Tensor: Separated audio signals, shape (batch, num_sources, time).
        """
        # Encoder
        encoded = nn.functional.relu(self.encoder(mixture))

        # Separator
        masks = self.separator(encoded)
        if print_shapes: print(masks.shape)

        # Apply masks
        sources = masks * encoded.unsqueeze(0)

        # Decoder
        decoded_sources = torch.cat([
            self.decoder(sources[:, i]).unsqueeze(1)
            for i in range(sources.size(1))
        ], dim=1)

        if print_shapes: print(decoded_sources.shape)

        return masks, decoded_sources


class TemporalConvNet(nn.Module):
    def __init__(self, 
                 num_filters, 
                 bottleneck_channels, 
                 num_blocks, 
                 num_repeats, 
                 num_sources):
        super(TemporalConvNet, self).__init__()

        self.layers = nn.ModuleList()
        for r in range(num_repeats):
            for b in range(num_blocks):
                dilation = 2 ** b
                self.layers.append(TemporalBlock(num_filters, bottleneck_channels, dilation))

        self.mask_generator = nn.Conv1d(num_filters, num_filters * num_sources, kernel_size=1)
        self.num_sources = num_sources
        self.num_filters = num_filters

    def forward(self, x, print_shapes=False):
        """
        Args:
            x (torch.Tensor): Encoded mixture, shape (batch, num_filters, time).

        Returns:
            torch.Tensor: Masks for each source, shape (batch, num_sources, num_filters, time).
        """
        for layer in self.layers:
            x = layer(x)
            if print_shapes: print(x.shape)

        if print_shapes: print("-")
        masks = self.mask_generator(x)  # shape (batch, num_filters * num_sources, time)
        #masks = masks.view(masks.size(0), self.num_sources, self.num_filters, -1)  # shape (batch, num_sources, num_filters, time)
        masks = masks.view(1, self.num_sources, self.num_filters, masks.size(1))  # shape (batch, num_sources, num_filters, time)
        masks = nn.functional.softmax(masks, dim=1)  # Normalize across sources
        if print_shapes: print(masks.shape)

        return masks


class TemporalBlock(nn.Module):
    def __init__(self, num_filters, bottleneck_channels, dilation):
        super(TemporalBlock, self).__init__()

        self.layer_norm = nn.LayerNorm(num_filters)
        self.conv1x1 = nn.Conv1d(num_filters, bottleneck_channels, kernel_size=1, bias=False)
        
        self.depthwise_conv = nn.Conv1d(
            bottleneck_channels, 
            bottleneck_channels, 
            kernel_size=3, 
            stride=1, 
            padding=dilation, 
            dilation=dilation, 
            groups=bottleneck_channels, 
            bias=False
        )

        self.pointwise_conv = nn.Conv1d(bottleneck_channels, num_filters, kernel_size=1, bias=False)
        self.relu = nn.ReLU()

    def forward(self, x, print_shapes=False):
        """
        Args:
        x (torch.Tensor): Input tensor, shape (batch, num_filters, time).
        
        Returns:
        torch.Tensor: Output tensor, shape (batch, num_filters, time).
        """
        residual = x
        x = x.transpose(0, 1)
        x = self.layer_norm(x)
        x = x.transpose(0, 1)
        x = self.relu(self.conv1x1(x))
        x = self.relu(self.depthwise_conv(x))
        x = self.pointwise_conv(x)
        
        if print_shapes: print(x.shape)
                
        return x + residual


In [161]:
model = ConvTasNet()
print(model)
print("nombre de paramètres:", count_n_param(model))

mask,output = model(voice_signal,print_shapes=True)
mask, output = model(mix_signal,print_shapes=True)


ConvTasNet(
  (encoder): Conv1d(32, 512, kernel_size=(16,), stride=(8,), padding=(8,), bias=False)
  (separator): TemporalConvNet(
    (layers): ModuleList(
      (0): TemporalBlock(
        (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (conv1x1): Conv1d(512, 128, kernel_size=(1,), stride=(1,), bias=False)
        (depthwise_conv): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,), groups=128, bias=False)
        (pointwise_conv): Conv1d(128, 512, kernel_size=(1,), stride=(1,), bias=False)
        (relu): ReLU()
      )
      (1): TemporalBlock(
        (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (conv1x1): Conv1d(512, 128, kernel_size=(1,), stride=(1,), bias=False)
        (depthwise_conv): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,), groups=128, bias=False)
        (pointwise_conv): Conv1d(128, 512, kernel_size=(1,), stride=(1,), bias=False)
        (relu): ReLU()
      )
      (2

In [129]:
# Evaluation function
def evaluate_model(model, dataloader, criterion, device):
    """
    Evaluate the model on the given dataloader.
    
    Args:
        model (torch.nn.Module): The trained model.
        dataloader (DataLoader): DataLoader for the validation/test dataset.
        criterion (torch.nn.Module): Loss function.
        device (torch.device): Device to use for computation.
    
    Returns:
        float: Average loss over the dataset.
    """

    with torch.no_grad():
        for batch in dataloader:
            # Load data
            voice_signal, noise_signal, mix_signal, snr = batch
            
            # Convert mix_signal to mono-channel if needed
            mixture = mix_signal.unsqueeze(1).to(device)  # Shape: [batch, 1, time]
            
            # Stack voice and noise signals as sources
            sources = torch.stack([voice_signal, noise_signal], dim=1).to(device)  # Shape: [batch, num_sources, time]
            
            # Forward pass
            predicted_sources = model(mixture)
            
            # Compute loss
            loss = criterion(predicted_sources, sources)
            total_loss += loss.item()
    
    avg_loss = total_loss / len(dataloader)
    return avg_loss


In [130]:
# Test dataloader
test_dataset = MyDataset(test_folder, load_signals=False, load_spectrograms=True)
test_dataloader = DataLoader(test_dataset, batch_size=32)

In [163]:
# Hyperparameters
num_sources = 2
N = 256
L = 20
B = 256
H = 512
P = 3
X = 8
R = 4
num_epochs = 10
learning_rate = 0.001

# Initialize model, loss function, and optimizer
model = ConvTasNet(num_sources=num_sources, num_filters=N, kernel_size=L, stride=X, bottleneck_channels=H, num_blocks=B, num_repeats=R)
loss_fn = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters())

model.to(device) # Move model to the GPU

train_losses = []
validation_losses = []

# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    progress_bar = tqdm.tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}")
    
    # Iterate over batches
    for batch in progress_bar:
        # Load data
        voice_signal, noise_signal, mix_signal, snr = batch
        voice_signal, noise_signal, mix_signal = voice_signal.to(device), noise_signal.to(device), mix_signal.to(device)

        # Forward pass
        mask, output = model(mix_signal)
        pred_voice_spec = mask * mix_signal
        pred_noise_spec = (1-mask) * mix_signal

        # Compute loss
        loss = (loss_fn(pred_voice_spec,voice_spec)
                + loss_fn(pred_noise_spec,noise_spec))
            
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())
        
    # Save the train loss for this epoch
    print(f"Epoch {epoch + 1}/{num_epochs} - Loss: {epoch_loss / len(train_dataloader):.4f}")
    train_losses.append(epoch_loss / len(train_dataloader))

    # Validation phase
    val_loss = evaluate_model(model, test_dataloader, criterion, device)
    validation_losses.append(val_loss)
    print(f"Epoch {epoch + 1}/{num_epochs} - Validation Loss: {val_loss:.4f}")

# Save the trained model
torch.save(model.state_dict(), "conv_tasnet.pth")
print("Training complete! Model saved as 'conv_tasnet.pth'.")

Epoch 1/10:   0%|          | 0/157 [00:02<?, ?it/s]


TypeError: conv1d(): argument 'dilation' must be tuple of ints, but found element of type int at pos 1

In [None]:
# Tracé de la loss au cours des epochs
plt.plot(range(1, num_epochs + 1), train_losses, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss Over Epochs')
plt.show()