In [2]:
import os
import torchaudio

os.makedirs("./data", exist_ok=True)

dataset = torchaudio.datasets.LIBRISPEECH("./data", url="train-clean-100", download=True)
print(len(dataset), "fichiers audio")

100.0%


28539 fichiers audio


In [3]:
# Filtrer 10 premiers locuteurs et leurs 100 fichiers chacun (~10 h)
locuteurs = sorted({speaker for (_, _, _, speaker, _, _) in dataset})[:10]
subset = [(wave, sr, _, spk, _, _) for (wave, sr, _, spk, _, _) in dataset if spk in locuteurs][:1000]
print(len(subset), "extraits audio (~10h)")

1000 extraits audio (~10h)


In [4]:
import torchaudio.transforms as T
mfcc_transform = T.MFCC(sample_rate=16000, n_mfcc=40)
waveform, sr, _, speaker_id, _, _ = subset[0]
mfcc = mfcc_transform(waveform)
print(mfcc.shape)

torch.Size([1, 40, 158])




In [5]:
import torch
import torch.nn as nn

class SpeakerNet(nn.Module):
    def __init__(self, num_speakers):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 10 * 20, 256),
            nn.ReLU(),
            nn.Linear(256, num_speakers)
        )

    def forward(self, x):
        x = self.conv(x)
        return self.fc(x)

In [12]:
import torch
from torch.nn.utils.rnn import pad_sequence

# batch : liste de tuples (waveform [1,T], sr, ..., speaker, ...)
def collate_fn(batch):
    # On sépare vagues et labels
    waves = [item[0].squeeze(0).t() for item in batch]  # chaque wave [T]
    speakers = torch.tensor([item[3] for item in batch])
    # Padding : on aligne sur la plus grande longueur T_max
    padded_waves = pad_sequence(waves, batch_first=True)  # [B, T_max]
    # Retour au format [B,1,T_max]
    padded_waves = padded_waves.unsqueeze(1)
    # Extraction MFCC sur l'ensemble
    mfccs = mfcc_transform(padded_waves)  # [B, n_mfcc, T_feat]
    # On transpose si nécessaire selon l'architecture
    mfccs = mfccs.unsqueeze(1)  # [B,1,n_mfcc,T_feat]
    return mfccs, speakers

In [None]:
# Déterminer le nombre de locuteurs dans `dataset`
speaker_ids = {speaker for (_, _, _, speaker, _, _) in dataset}
num_speakers = len(speaker_ids)

from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
model = SpeakerNet(num_speakers=num_speakers)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for epoch in range(10):
    for wave, sr, _, speaker, _, _ in dataloader:
        # Extraction MFCC
        mfcc = mfcc_transform(wave).unsqueeze(1)  # [B,1,40,T]
        logits = model(mfcc)
        loss = criterion(logits, speaker)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch} – loss: {loss.item():.4f}")

ValueError: not enough values to unpack (expected 6, got 2)

In [None]:
speaker_ids = {speaker for (_, _, _, speaker, _, _) in dataset}
num_speakers = len(speaker_ids)

from torch.utils.data import DataLoader

# On réutilise la collate_fn définie en §2.4 pour le padding et extraction MFCC

dataloader = DataLoader(dataset,
                        batch_size=32,
                        shuffle=True,
                        collate_fn=collate_fn)

model = SpeakerNet(num_speakers=num_speakers)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Entraînement avec batch de MFCCs déjà préparés
def train(model, dataloader):
    model.train()
    for epoch in range(10):
        total_loss = 0.0
        for mfccs, speakers in dataloader:
            logits = model(mfccs)             # mfccs: [B,1,n_mfcc,T]
            loss = criterion(logits, speakers)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch} | avg loss: {avg_loss:.4f}")

# Lancer l'entraînement
train(model, dataloader)

RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [32, 1, 1, 40, 1314]

In [None]:
# Sauvegarde
torch.save(model.state_dict(), "pretrained_speaker.pth")

# Pour recharger ultérieurement
model = SpeakerNet(num_speakers=num_speakers)
model.load_state_dict(torch.load("pretrained_speaker.pth"))
model.eval()

In [None]:
from sklearn.metrics import confusion_matrix, classification_report
import torch

# Split 80/20 train/val
from torch.utils.data import random_split
train_len = int(0.8 * len(dataset))
val_len = len(dataset) - train_len
train_ds, val_ds = random_split(dataset, [train_len, val_len])
val_loader = DataLoader(val_ds, batch_size=32)

# Prédictions et vérité
y_true, y_pred = [], []
model.eval()
with torch.no_grad():
    for wave, sr, _, speaker, _, _ in val_loader:
        mfcc = mfcc_transform(wave).unsqueeze(1)
        logits = model(mfcc)
        preds = logits.argmax(dim=1)
        y_true.extend(speaker.tolist())
        y_pred.extend(preds.tolist())

# Matrice de confusion
cm = confusion_matrix(y_true, y_pred)
print("Matrice de confusion :")
print(cm)

# Rapport précision / rappel
report = classification_report(y_true, y_pred)
print("Rapport de classification :")
print(report)

In [None]:
import itertools
import matplotlib.pyplot as plt

# Affichage graphique de la matrice de confusion
plt.figure(figsize=(8, 6))
plt.imshow(cm, interpolation='nearest', aspect='auto')
plt.title('Matrice de confusion')
plt.colorbar()

# Etiquettes des axes (speaker IDs)
labels = sorted(set(y_true))
plt.xticks(range(len(labels)), labels, rotation=45)
plt.yticks(range(len(labels)), labels)

# Annotations
thresh = cm.max() / 2
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    plt.text(j, i, format(cm[i, j], 'd'),
             horizontalalignment='center',
             color='white' if cm[i, j] > thresh else 'black')

plt.ylabel('Vérité')
plt.xlabel('Prédiction')
plt.tight_layout()
plt.show()