This notebook is an example of training a speaker-embedding model.
I trained this model on a Kaggle GPU.
The LibriSpeech dataset can be downloaded from https://www.openslr.org/12
or accessed at: https://www.kaggle.com/datasets/hieugiaosu/librispeech

In [1]:
import numpy as np
import pandas as pd
import os
import gc
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import DataLoader,Dataset
import torch.optim as optim
import matplotlib.pyplot as plt
import librosa

In [2]:
# Sanity check: count the speaker directories in train-clean-100 (each entry is one speaker).
sum(1 for _entry in os.scandir('/kaggle/input/librispeech/train-clean-100/LibriSpeech/train-clean-100'))

251

In [3]:
# Build the train/test split.
# For every speaker, the last 3 utterance files (in sorted path order) go to the
# test set and the rest to the training set. Five test-clean speakers are added on
# top of the train-clean-100 speakers so they also get labels.
# Directory listings are sorted so the speaker->label mapping and the split are
# identical on every run; os.listdir order is arbitrary, which would silently
# change labels between runs and invalidate saved checkpoints.
TRAIN_ROOT = '/kaggle/input/librispeech/train-clean-100/LibriSpeech/train-clean-100'
TEST_ROOT = '/kaggle/input/librispeech/test-clean/LibriSpeech/test-clean'
addition_data_from_test = [1089,1188,121,1221,1284]


def _speaker_utterances(root, speaker, label):
    """Return sorted (label, path) pairs for every non-transcript file under root/speaker/*/."""
    speaker_dir = os.path.join(root, str(speaker))
    pairs = []
    for chapter in sorted(os.listdir(speaker_dir)):
        chapter_dir = os.path.join(speaker_dir, chapter)
        for name in sorted(os.listdir(chapter_dir)):
            # Skip the per-chapter transcript files; everything else is audio.
            if 'txt' not in name:
                pairs.append((label, os.path.join(chapter_dir, name)))
    return pairs


train_speakers = sorted(os.listdir(TRAIN_ROOT), key=int)
speaker_map = {int(speaker): idx for idx, speaker in enumerate(train_speakers)}
for i in addition_data_from_test:
    speaker_map[i] = len(speaker_map)

train_records = []
test_list = []
sources = [(TRAIN_ROOT, s) for s in train_speakers] + [(TEST_ROOT, s) for s in addition_data_from_test]
for root, speaker in sources:
    utterances = _speaker_utterances(root, speaker, speaker_map[int(speaker)])
    train_records.extend(utterances[:-3])
    test_list.extend(utterances[-3:])

# Build each frame once instead of concatenating inside the loop (quadratic).
train_df = pd.DataFrame(train_records, columns=['speaker','audio_file'])
test_df = pd.DataFrame(test_list, columns=['speaker','audio_file'])

In [4]:
def preprocess(sample_rate=16000,n_fft=400,hop_length=160,f_min=80,f_max=4000,n_mels=64):
    """Build a waveform -> squashed log-mel feature function.

    The returned callable applies a mel spectrogram with the given STFT / mel
    settings, then ``log(1 + mel)``, then ``tanh``. Because the mel power is
    non-negative, the output lies in [0, 1]. No gradients are tracked.

    Args:
        sample_rate: sample rate of the incoming waveform.
        n_fft, hop_length: STFT window size and hop, in samples.
        f_min, f_max: mel filterbank frequency range in Hz.
        n_mels: number of mel bins.

    Returns:
        A function mapping a waveform tensor to its squashed log-mel features.
    """
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        f_min=f_min,
        f_max=f_max,
        n_mels=n_mels,
    )

    def proc(audio):
        with torch.no_grad():
            return torch.tanh(torch.log(1 + mel_transform(audio)))

    return proc

In [6]:
class SpeakerEmbeddingDataset(Dataset):
    """Dataset yielding two fixed-length log-mel examples per utterance.

    Each row of ``data`` is split into its first and second half. Index ``2*i``
    returns the first half of row ``i``; index ``2*i + 1`` returns the second half.
    Each half is left-zero-padded or right-cropped to ``max_audio_seconds`` seconds
    and then passed through ``preprocess``.

    Args:
        data: DataFrame with ``speaker`` (integer label) and ``audio_file`` columns.
        sample_rate: target sample rate. Files stored at another rate are resampled.
        window_length: stored as samples but currently unused (kept for compatibility).
        max_chunk_num: currently unused (kept for compatibility).
        max_audio_seconds: length of every returned clip, in seconds.

    Items are ``(features, speaker)``, where ``speaker`` is a 1-element int64 tensor.
    """
    def __init__(self,data,sample_rate=16000,window_length=0.1,max_chunk_num = 50,max_audio_seconds=5):
        super().__init__()
        self.data = data
        self.sample_rate = sample_rate
        self.window_length = int(window_length*sample_rate)
        self.max_chunk_num = max_chunk_num
        # Previously hard-coded as 16000*5. Tie it to sample_rate so a
        # non-default rate still yields max_audio_seconds of audio.
        self.max_audio_length = int(sample_rate*max_audio_seconds)
        # Build the feature extractor for the same rate the audio is brought to.
        self.proc = preprocess(sample_rate=sample_rate)

    def __len__(self):
        return len(self.data)*2

    @staticmethod
    def _to_mono(audio):
        """Collapse a (channels, time) waveform to 1-D by averaging the channels."""
        if audio.dim() > 1:
            return audio.mean(dim=0)
        return audio

    @staticmethod
    def _fit_length(audio, length):
        """Left-pad with zeros, or right-crop, a 1-D waveform to exactly ``length`` samples."""
        n = audio.shape[0]
        if n < length:
            return torch.cat([audio.new_zeros(length - n), audio], dim=0)
        return audio[:length]

    def __getitem__(self,idx):
        row = self.data.iloc[idx//2]
        audio,rate = torchaudio.load(row['audio_file'])
        # squeeze() alone left stereo files 2-D, and the halving below then
        # sliced the channel axis instead of time.
        audio = self._to_mono(audio)
        if rate != self.sample_rate:
            audio = torchaudio.functional.resample(audio, rate, self.sample_rate)
        half = audio.shape[0]//2
        audio = audio[:half] if idx%2 == 0 else audio[half:]
        audio = self._fit_length(audio, self.max_audio_length)
        features = self.proc(audio)
        speaker = torch.tensor([int(row['speaker'])])
        return features,speaker

In [7]:
# Wrap both splits; each utterance yields two examples (one per half).
train_ds, test_ds = SpeakerEmbeddingDataset(train_df), SpeakerEmbeddingDataset(test_df)
TESTSIZE = len(test_ds)
print(len(train_ds), TESTSIZE)

56092 1536


In [8]:
class SpeakerEmbedding(nn.Module):
    """Transformer encoder that maps a feature sequence to a single embedding vector.

    Input ``x`` has shape (batch, input_dim, seq_len), e.g. mel bins x frames.
    Each frame goes through a 2-layer MLP to ``hidden_state_dim``. It is then
    concatenated (not added) with a sinusoidal positional encoding of the same
    width, giving a transformer width of ``2 * hidden_state_dim``. The output at
    the last time step is projected to ``embedding_dim``.

    Args:
        input_dim: features per frame (size of dim 1 of ``x``).
        hidden_state_dim: MLP width; also the positional-encoding width.
        embedding_dim: size of the returned embedding.
        nhead: attention heads per encoder layer (must divide ``2 * hidden_state_dim``).
        num_layers: number of transformer encoder layers.
    """
    def __init__(self,input_dim,hidden_state_dim,embedding_dim,nhead=2,num_layers=2):
        super().__init__()
        self.input_transform = nn.Sequential(
            nn.Linear(input_dim,hidden_state_dim),
            nn.ELU(),
            nn.Linear(hidden_state_dim,hidden_state_dim),
            nn.ELU()
        )

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=2*hidden_state_dim, nhead=nhead, activation='relu',batch_first=True),
            num_layers = num_layers
        )

        self.fc = nn.Linear(2*hidden_state_dim,embedding_dim)

        self.hidden_state_dim = hidden_state_dim
        self.embedding_dim = embedding_dim
        self.input_dim = input_dim

    def forward(self,x):
        """Return (batch, embedding_dim) embeddings for x of shape (batch, input_dim, seq_len)."""
        batch_size = x.shape[0]
        seq_len = x.shape[2]
        frames = torch.transpose(x,1,2)  # (batch, seq_len, input_dim)
        feats = self.input_transform(frames)
        pos = self.get_sinusoidal_positional_encoding(seq_len,self.hidden_state_dim,feats.device,feats.dtype)
        q = torch.cat([feats,pos.expand(batch_size,-1,-1)],dim = -1)
        att = self.transformer(q)
        # The output at the last time step serves as the utterance embedding.
        return self.fc(att[:,-1,:])

    def get_sinusoidal_positional_encoding(self, max_len, d_model,device=None,dtype=None):
        """Return a (max_len, d_model) sinusoidal positional encoding.

        Even columns hold sin and odd columns hold cos. It is built directly on
        ``device`` and works for odd ``d_model``; the previous version raised a
        shape mismatch when writing the cos columns.
        """
        position = torch.arange(max_len, device=device, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, device=device, dtype=torch.float32) * -(np.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pos_encoding = torch.zeros((max_len, d_model), device=device)
        pos_encoding[:, 0::2] = torch.sin(angles)
        # There are only floor(d_model / 2) odd columns.
        pos_encoding[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        if dtype is not None:
            pos_encoding = pos_encoding.to(dtype)
        return pos_encoding
        

In [9]:
class ArcFace(nn.Module):
    """Additive angular margin (ArcFace) classification head.

    This is my reimplementation of ArcFace. It produces logits meant to be fed
    to ``nn.CrossEntropyLoss``; it is not itself a loss.

    Embeddings and the class-centre rows of ``W`` are L2-normalised, giving
    cosines ``cos(theta_j)``. With labels, the target-class logit becomes
    ``s * cos(theta_y + m)`` and every other logit is ``s * cos(theta_j)``.
    Without labels, the raw (unscaled) cosines are returned.

    Args:
        numClasses: number of classes (rows of ``W``).
        embeddingSize: embedding dimension (columns of ``W``).
        margin: additive angular margin ``m``, in radians.
        scale: logit scale ``s``.
        eps: clamp margin keeping ``acos`` away from +/-1, where its gradient is infinite.
    """
    def __init__(self,numClasses,embeddingSize,margin,scale,eps=1e-6):
        super().__init__()
        self.numClasses = numClasses
        self.embeddingSize = embeddingSize
        self.m = margin
        self.s = scale
        self.eps = eps
        self.W = nn.Parameter(torch.empty(numClasses, embeddingSize))
        nn.init.xavier_normal_(self.W)

    def _cosine(self, embeddings):
        """Cosine similarity of each embedding to each class centre: (batch, numClasses)."""
        return F.linear(F.normalize(embeddings), F.normalize(self.W))

    def forward(self,embeddings,labels=None):
        """Return margin logits (scaled) when ``labels`` is given, else raw cosines."""
        cos = self._cosine(embeddings)
        if labels is None:
            return cos
        # Flatten so (batch,), (batch, 1) and a 0-d label (a batch of one after
        # .squeeze()) all work; the old labels.size(0) raised on 0-d input.
        idx = labels.reshape(-1, 1).long()
        cos_target = cos.gather(1, idx)
        theta = torch.acos(torch.clamp(cos_target, -1 + self.eps, 1 - self.eps))
        # NOTE(review): no special handling when theta + m > pi, where
        # cos(theta + m) stops decreasing in theta; matches the original behaviour.
        logits = cos.scatter(1, idx, torch.cos(theta + self.m))
        return self.s * logits

In [10]:
def count_parameters(model):
    """Return the total number of scalar parameters in ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [11]:
class Classifier(nn.Module):
    """Speaker classifier: a SpeakerEmbedding encoder followed by an ArcFace head.

    The defaults reproduce the original configuration: 64 mel bins in, 128
    hidden units, 256-d embeddings and 256 speaker classes (251 train-clean-100
    speakers plus 5 test-clean speakers), with margin 0.4 and scale 64.
    ``num_classes`` must match the size of the speaker-label mapping.
    """
    def __init__(self,input_dim=64,hidden_state_dim=128,embedding_dim=256,num_classes=256,margin=0.4,scale=64):
        super().__init__()
        self.emb = SpeakerEmbedding(input_dim,hidden_state_dim,embedding_dim)
        self.arcFace = ArcFace(num_classes,embedding_dim,margin,scale)

    def forward(self,x,labels=None):
        """Return scaled margin logits when ``labels`` is given, else raw cosine scores."""
        return self.arcFace(self.emb(x),labels)

In [12]:
# Seed torch's global RNG before building the model. This makes weight
# initialisation, and later draws from the global RNG such as DataLoader
# shuffling and dropout, reproducible on a fresh Run All.
SEED = 42
torch.manual_seed(SEED)
model = Classifier()

In [13]:
def plot_spectrogram(spec, title=None, ylabel='freq_bin', aspect='auto', xmax=None):
    """Show a 2-D spectrogram, converted to decibels, with a colour bar.

    ``spec`` is passed to ``torchaudio.functional.amplitude_to_DB`` with
    multiplier 10 and amin 1e-10, which presumes a power spectrogram tensor
    (freq_bin x frame). NOTE(review): the features produced by ``preprocess``
    are already log/tanh-squashed, so the "db" label may not apply to them.

    Args:
        spec: 2-D tensor to display.
        title: figure title; defaults to 'Spectrogram (db)'.
        ylabel: y-axis label.
        aspect: aspect mode forwarded to ``imshow``.
        xmax: if truthy, limits the x-axis to [0, xmax].
    """
    fig, ax = plt.subplots(1, 1)
    ax.set_title(title or 'Spectrogram (db)')
    ax.set_ylabel(ylabel)
    ax.set_xlabel('frame')
    spec_db = torchaudio.functional.amplitude_to_DB(spec, 10., 1e-10, 0)
    image = ax.imshow(spec_db, origin='lower', aspect=aspect)
    if xmax:
        ax.set_xlim((0, xmax))
    fig.colorbar(image, ax=ax)
    plt.show(block=False)

In [14]:
# Training batches are shuffled and the ragged final batch is dropped.
# Evaluation keeps every sample in a fixed order so accuracy covers the whole test set.
train_loader = DataLoader(dataset=train_ds, batch_size=256, shuffle=True, drop_last=True)
test_loader = DataLoader(dataset=test_ds, batch_size=256, shuffle=False, drop_last=False)

In [15]:
def _train_one_epoch(model, loader, optimizer, lossfn, device, epoch):
    """Run one optimisation pass over ``loader``, printing the loss of every batch."""
    model.train()
    for batch, (x, y) in enumerate(loader):
        x = x.to(device)
        # Flatten labels to (batch,). .squeeze() would give a 0-d tensor for a batch of one.
        y = y.to(device).reshape(-1)
        optimizer.zero_grad()
        loss = lossfn(model(x, y), y)
        loss.backward()
        optimizer.step()
        print(f"epoch: {epoch} batch {batch} loss {loss.detach().item()}")


def _evaluate(model, loader, device):
    """Return the top-1 accuracy of ``model(x)`` over ``loader`` (0.0 if the loader is empty)."""
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).reshape(-1)
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            seen += y.numel()
    return correct / seen if seen else 0.0


def train(model,epochs,train_loader,test_loader,multi_gpu=None,lr=5e-4,patience=5,save_path="vad.pth"):
    """Train ``model`` with cross-entropy and keep the best checkpoint by test accuracy.

    Each epoch trains with ``model(x, labels)`` and then evaluates with
    ``model(x)``, taking argmax over classes. ``model.state_dict()`` is saved to
    ``save_path`` whenever test accuracy ties or beats the best so far. Training
    stops after ``patience`` consecutive epochs without improvement.

    Args:
        model: module returning class scores; must accept an optional labels argument.
        epochs: maximum number of epochs.
        train_loader, test_loader: yield ``(x, y)`` with integer labels of shape (batch, 1) or (batch,).
        multi_gpu: ``None`` or ``True`` wraps the model in ``nn.DataParallel`` when more
            than one GPU is visible; any other value disables it.
        lr: Adam learning rate.
        patience: epochs without improvement before early stopping.
        save_path: checkpoint file. "vad.pth" is kept as the default for backward compatibility.

    Returns:
        The best test accuracy in [0, 1], or -1 if no epoch ran.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if multi_gpu is None or multi_gpu is True:
        multi_gpu = torch.cuda.device_count() > 1
    else:
        multi_gpu = False
        print('do not use multi gpu')
    model.to(device)
    # Create the optimizer after the move so it clearly tracks the on-device parameters.
    optimizer = optim.Adam(model.parameters(),lr=lr)
    lossfn = nn.CrossEntropyLoss()
    runner = nn.DataParallel(model) if multi_gpu else model
    best_acc = -1
    not_decrease = 0
    for epoch in range(epochs):
        _train_one_epoch(runner, train_loader, optimizer, lossfn, device, epoch)
        # Accuracy is over the samples actually evaluated, not a global TESTSIZE.
        accuracy = _evaluate(runner, test_loader, device)
        print(f"test accuracy: {accuracy*100}%")
        if accuracy >= best_acc:
            best_acc = accuracy
            not_decrease = 0
            # Save the unwrapped module so keys carry no "module." prefix.
            torch.save(model.state_dict(), save_path)
        else:
            not_decrease += 1
        if not_decrease >= patience:
            print('early stopping')
            break
    return best_acc

In [17]:
# Train for up to 50 epochs on a single device (DataParallel disabled).
train(model, epochs=50, train_loader=train_loader, test_loader=test_loader, multi_gpu=False)