Run this to generate mixture dataset

In [1]:
# !python3 mixture.py --vox2_dir data/vox2 --output_dir data

Speaker Seperation and Enhancement using SepFormer

Use the pre-trained SepFormer model to perform speaker separation and speech enhancement of each speaker on the created test set by analyzing these metrics: Signal to Interference Ratio (SIR), Signal to Artefacts Ratio (SAR), Signal to Distortion Ratio (SDR) and Perceptual Evaluation of Speech Quality (PESQ). 


In [1]:
import pandas as pd
import os
from sepformer import *
from data import *
import random
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm
INFO:speechbrain.utils.quirks:Applied quirks (see `speechbrain.utils.quirks`): [allow_tf32, disable_jit_profiling]
INFO:speechbrain.utils.quirks:Excluded quirks specified by the `SB_DISABLE_QUIRKS` environment (comma-separated list): []


In [2]:
sepformer = get_sepformer_model('cuda')

INFO:speechbrain.utils.fetching:Fetch hyperparams.yaml: Fetching from HuggingFace Hub 'speechbrain/sepformer-whamr' if not cached
INFO:speechbrain.utils.fetching:Fetch custom.py: Fetching from HuggingFace Hub 'speechbrain/sepformer-whamr' if not cached
INFO:speechbrain.utils.fetching:Fetch masknet.ckpt: Fetching from HuggingFace Hub 'speechbrain/sepformer-whamr' if not cached
INFO:speechbrain.utils.fetching:Fetch encoder.ckpt: Fetching from HuggingFace Hub 'speechbrain/sepformer-whamr' if not cached
INFO:speechbrain.utils.fetching:Fetch decoder.ckpt: Fetching from HuggingFace Hub 'speechbrain/sepformer-whamr' if not cached
INFO:speechbrain.utils.parameter_transfer:Loading pretrained files for: masknet, encoder, decoder


In [3]:
vox_mix = VoxCelebmix(data_dir='data/vox2_mix')

In [4]:
n_samples = 1_000

In [5]:
metrics_dir = {}

for i in tqdm(range(n_samples)):
    sample_idx = random.randint(0, len(vox_mix) - 1)
    wav_mix, wav_s1, wav_s2, sp1, sp2, _ = vox_mix[sample_idx]
    
    min_sz = min(wav_s1.shape[-1], wav_s2.shape[-1])
    
    if wav_s1.shape[-1] > min_sz:
        wav_s1 = wav_s1[:, :min_sz]
    if wav_s2.shape[-1] > min_sz:
        wav_s2 = wav_s2[:, :min_sz]

    pred_s1, pred_s2 = speaker_separation(sepformer, wav_mix)

    
    gt = torch.cat([wav_s1, wav_s2], dim=0).numpy()
    pred = torch.cat([pred_s1, pred_s2], dim=0).numpy()
    
    min_sz = min(gt.shape[-1], pred.shape[-1])
    
    if gt.shape[-1] > min_sz:
        gt = gt[:, :min_sz]
    if pred.shape[-1] > min_sz:
        pred = pred[:, :min_sz]

    sdr, sir, sar, pesq_scores = compute_separation_metrics(gt, pred, sample_rate=16_000)
    
    if sp1 not in metrics_dir.keys():
        metrics_dir[sp1] = [[sdr[0]], [sir[0]], [sar[0]], [pesq_scores[0]]]
    else:
        metrics_dir[sp1][0].append(sdr[0])
        metrics_dir[sp1][1].append(sir[0])
        metrics_dir[sp1][2].append(sar[0])
        metrics_dir[sp1][3].append(pesq_scores[0])
        
    if sp2 not in metrics_dir.keys():
        metrics_dir[sp2] = [[sdr[1]], [sir[1]], [sar[1]], [pesq_scores[1]]]
    else:
        metrics_dir[sp2][0].append(sdr[1])
        metrics_dir[sp2][1].append(sir[1])
        metrics_dir[sp2][2].append(sar[1])
        metrics_dir[sp2][3].append(pesq_scores[1])

100%|██████████| 1000/1000 [15:14<00:00,  1.09it/s]


In [11]:
df_data = []

for k, v in metrics_dir.items():
    row = {
        "Key": k,
        "SIR_MEAN": np.mean(v[1]), "SIR_VARIANCE": np.var(v[1]),
        "SAR_MEAN": np.mean(v[2]), "SAR_VARIANCE": np.var(v[2]),
        "SDR_MEAN": np.mean(v[0]), "SDR_VARIANCE": np.var(v[0]),
        "PESQ_MEAN": np.mean(v[3]), "PESQ_VARIANCE": np.var(v[3]),
    }
    df_data.append(row)

df = pd.DataFrame(df_data)

df.set_index("Key", inplace=True)


In [12]:
df.head()

Unnamed: 0_level_0,SIR_MEAN,SIR_VARIANCE,SAR_MEAN,SAR_VARIANCE,SDR_MEAN,SDR_VARIANCE,PESQ_MEAN,PESQ_VARIANCE
Key,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
id00154,10.52138,227.802958,-4.142282,18.451845,-8.231168,81.194939,1.110132,0.004038
id00866,6.617601,343.37111,-3.835985,26.908087,-10.55737,118.509879,1.121674,0.040492
id01066,8.324409,319.070619,-3.091958,26.17335,-8.877471,108.790636,1.117321,0.008584
id00419,0.870727,330.560557,-5.923366,48.3268,-14.635148,157.250794,1.071582,0.002424
id01892,-1.81669,254.610565,-8.194641,59.926393,-17.643697,88.353395,1.079743,0.008471


In [13]:
df

Unnamed: 0_level_0,SIR_MEAN,SIR_VARIANCE,SAR_MEAN,SAR_VARIANCE,SDR_MEAN,SDR_VARIANCE,PESQ_MEAN,PESQ_VARIANCE
Key,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
id00154,10.52138,227.802958,-4.142282,18.451845,-8.231168,81.194939,1.110132,0.004038
id00866,6.617601,343.37111,-3.835985,26.908087,-10.55737,118.509879,1.121674,0.040492
id01066,8.324409,319.070619,-3.091958,26.17335,-8.877471,108.790636,1.117321,0.008584
id00419,0.870727,330.560557,-5.923366,48.3268,-14.635148,157.250794,1.071582,0.002424
id01892,-1.81669,254.610565,-8.194641,59.926393,-17.643697,88.353395,1.079743,0.008471
id02577,0.987114,287.076045,-5.062195,42.158743,-13.372956,122.273652,1.122451,0.011619
id02019,0.602653,278.713769,-5.613182,22.46698,-13.967106,124.655204,1.079537,0.003499
id01593,5.8114,244.162415,-6.320831,28.311933,-11.891333,95.821173,1.108855,0.007911
id02542,-2.104947,321.684747,-4.849337,12.144909,-15.626547,100.231684,1.114204,0.007847
id03178,7.786317,282.144113,-3.075381,14.572132,-8.942457,89.394456,1.077445,0.001368


Further, use the above pre-trained and finetuned speaker identification model (obtained in II) to identify which enhanced speech corresponds to which speaker after speaker separation. Report the Rank-1 identification accuracy on both models. 


In [6]:
from lora_finetune import *
from evaluation import *

pretrained

In [7]:
model = init_model('models/checkpoints/wavlm_base_plus_nofinetune.pth')

Using cache found in /home/raid/.cache/torch/hub/s3prl_s3prl_main
INFO:s3prl.util.download:Requesting URL: https://huggingface.co/s3prl/converted_ckpts/resolve/main/wavlm_base_plus.pt
INFO:s3prl.util.download:Using URL's local file: /home/raid/.cache/s3prl/download/72cb34edf8a3724c720467cf40b77ad20b1b714b5f694e9db57f521467f9006b.wavlm_base_plus.pt
INFO:s3prl.upstream.wavlm.WavLM:WavLM Config: {'extractor_mode': 'default', 'encoder_layers': 12, 'encoder_embed_dim': 768, 'encoder_ffn_embed_dim': 3072, 'encoder_attention_heads': 12, 'activation_fn': 'gelu', 'layer_norm_first': False, 'conv_feature_layers': '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2', 'conv_bias': False, 'feature_grad_mult': 0.1, 'normalize': False, 'dropout': 0.1, 'attention_dropout': 0.1, 'activation_dropout': 0.0, 'encoder_layerdrop': 0.05, 'dropout_input': 0.1, 'dropout_features': 0.1, 'mask_length': 10, 'mask_prob': 0.8, 'mask_selection': 'static', 'mask_other': 0.0, 'no_mask_overlap': False, 'mask_min_space': 

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [9]:
model.to(device)

ECAPA_TDNN(
  (feature_extract): UpstreamExpert(
    (model): WavLM(
      (feature_extractor): ConvFeatureExtractionModel(
        (conv_layers): ModuleList(
          (0): Sequential(
            (0): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
            (1): Dropout(p=0.0, inplace=False)
            (2): Fp32GroupNorm(512, 512, eps=1e-05, affine=True)
            (3): GELU(approximate='none')
          )
          (1-4): 4 x Sequential(
            (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
            (1): Dropout(p=0.0, inplace=False)
            (2): GELU(approximate='none')
          )
          (5-6): 2 x Sequential(
            (0): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
            (1): Dropout(p=0.0, inplace=False)
            (2): GELU(approximate='none')
          )
        )
      )
      (post_extract_proj): Linear(in_features=512, out_features=768, bias=True)
      (dropout_input): Dropout(p=0.1, inplace=False)

In [9]:
speaker_dict = {}
for i in tqdm(range(len(vox_mix))):
    _, wav_1, wav_2, sp1, sp2, _ = vox_mix[i]
    
    if sp1 not in speaker_dict.keys():
        speaker_dict[sp1] = wav_1

    if sp2 not in speaker_dict.keys():
        speaker_dict[sp2] = wav_2

100%|██████████| 1000/1000 [00:03<00:00, 325.56it/s]


In [10]:
print(speaker_dict.keys())

dict_keys(['id00812', 'id01618', 'id01822', 'id00061', 'id03524', 'id01567', 'id02181', 'id00419', 'id01892', 'id00866', 'id02577', 'id02542', 'id02086', 'id01460', 'id02445', 'id01041', 'id00081', 'id01298', 'id00926', 'id02317', 'id00154', 'id01333', 'id02745', 'id02548', 'id01106', 'id00562', 'id02725', 'id03127', 'id02019', 'id02685', 'id01224', 'id01000', 'id01437', 'id01593', 'id00017', 'id01989', 'id03382', 'id03178', 'id02057', 'id02576', 'id02286', 'id03030', 'id00817', 'id01509', 'id02465', 'id03347', 'id01541', 'id03041', 'id01066', 'id01228'])


In [11]:
speaker_list = list(speaker_dict.keys())

In [12]:
embed_matrix = torch.zeros((len(speaker_dict), 256))

In [13]:
speaker_id = {}

In [14]:
for idx, kv in enumerate(speaker_dict.items()):
    k, wav = kv
    
    wav = torch.stack([wav, wav], dim=0)
    wav = wav.to(device)
    wav = wav.squeeze_(1)
    
    print(wav)
    embed = model(wav)[0, ...]
    
    print(embed[:10])
    
    print("!!!!!")
    embed_matrix[idx, :] = embed
    
    speaker_id[idx] = speaker_list.index(k)

tensor([[1.9165e-02, 1.9135e-02, 1.7029e-02,  ..., 6.1035e-05, 6.1035e-05,
         6.1035e-05],
        [1.9165e-02, 1.9135e-02, 1.7029e-02,  ..., 6.1035e-05, 6.1035e-05,
         6.1035e-05]], device='cuda:0')
tensor([-3.1662e-03,  9.9634e-04,  2.0860e-03,  4.8885e-04, -2.3666e-03,
        -7.4487e-04,  1.5126e-03,  2.2123e-03,  7.0878e-03, -8.3986e-05],
       device='cuda:0', grad_fn=<SliceBackward0>)
!!!!!
tensor([[-0.0110, -0.0226, -0.0211,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0110, -0.0226, -0.0211,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')
tensor([-3.1662e-03,  9.9634e-04,  2.0860e-03,  4.8885e-04, -2.3666e-03,
        -7.4487e-04,  1.5126e-03,  2.2123e-03,  7.0878e-03, -8.3986e-05],
       device='cuda:0', grad_fn=<SliceBackward0>)
!!!!!
tensor([[ 0.0388,  0.0301, -0.0053,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0388,  0.0301, -0.0053,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')
tensor([-3.1662e-03,  9.9634e-04,  2.0860e-03,  4.88

In [15]:
embed_matrix

tensor([[-0.0032,  0.0010,  0.0021,  ..., -0.0016, -0.0004, -0.0019],
        [-0.0032,  0.0010,  0.0021,  ..., -0.0016, -0.0004, -0.0019],
        [-0.0032,  0.0010,  0.0021,  ..., -0.0016, -0.0004, -0.0019],
        ...,
        [-0.0032,  0.0010,  0.0021,  ..., -0.0016, -0.0004, -0.0019],
        [-0.0032,  0.0010,  0.0021,  ..., -0.0016, -0.0004, -0.0019],
        [-0.0032,  0.0010,  0.0021,  ..., -0.0016, -0.0004, -0.0019]],
       grad_fn=<CopySlices>)

In [17]:
labels = []
pred_labels = []

In [None]:
for i in tqdm(range(n_samples)):
    sample_idx = random.randint(0, len(vox_mix) - 1)
    wav_mix, wav_s1, wav_s2, sp1, sp2, _ = vox_mix[sample_idx]
    
    min_sz = min(wav_s1.shape[-1], wav_s2.shape[-1])
    
    if wav_s1.shape[-1] > min_sz:
        wav_s1 = wav_s1[:, :min_sz]
    if wav_s2.shape[-1] > min_sz:
        wav_s2 = wav_s2[:, :min_sz]

    pred_s1, pred_s2 = speaker_separation(sepformer, wav_mix)

    
    gt = torch.cat([wav_s1, wav_s2], dim=0).numpy()
    pred = torch.cat([pred_s1, pred_s2], dim=0).numpy()
    
    min_sz = min(gt.shape[-1], pred.shape[-1])
    
    if gt.shape[-1] > min_sz:
        gt = gt[:, :min_sz]
    if pred.shape[-1] > min_sz:
        pred = pred[:, :min_sz]
        
    
    embeds = model(pred)
    
    cos_sim = F.cosine_similarity(embeds, embed_matrix, dim=1)
    
    cos_sim = cos_sim.cpu().numpy()
    
    idx = cos_sim.argmax()
    
    labels.append(speaker_id[idx])
    pred_labels.append(sp1)
    


    

In [24]:
print(acc(model, [labels, pred_labels]))

0.7204


finetuned

In [10]:
replace_linear_with_lora(model, device)

In [11]:
model.load_state_dict(torch.load('models/wavlm_base_lora_finetune.pth'), strict=False)

_IncompatibleKeys(missing_keys=['feature_weight', 'feature_extract.model.mask_emb', 'feature_extract.model.feature_extractor.conv_layers.0.0.weight', 'feature_extract.model.feature_extractor.conv_layers.0.2.weight', 'feature_extract.model.feature_extractor.conv_layers.0.2.bias', 'feature_extract.model.feature_extractor.conv_layers.1.0.weight', 'feature_extract.model.feature_extractor.conv_layers.2.0.weight', 'feature_extract.model.feature_extractor.conv_layers.3.0.weight', 'feature_extract.model.feature_extractor.conv_layers.4.0.weight', 'feature_extract.model.feature_extractor.conv_layers.5.0.weight', 'feature_extract.model.feature_extractor.conv_layers.6.0.weight', 'feature_extract.model.post_extract_proj.weight', 'feature_extract.model.post_extract_proj.bias', 'feature_extract.model.encoder.pos_conv.0.bias', 'feature_extract.model.encoder.pos_conv.0.weight_g', 'feature_extract.model.encoder.pos_conv.0.weight_v', 'feature_extract.model.encoder.layers.0.self_attn.grep_a', 'feature_ext

In [21]:
for idx, kv in enumerate(speaker_dict.items()):
    k, wav = kv
    
    wav = torch.stack([wav, wav], dim=0)
    wav = wav.to(device)
    wav = wav.squeeze_(1)
    
    print(wav)
    embed = model(wav)[0, ...]
    
    print(embed[:10])
    
    print("!!!!!")
    embed_matrix[idx, :] = embed
    
    speaker_id[idx] = speaker_list.index(k)

tensor([[1.9165e-02, 1.9135e-02, 1.7029e-02,  ..., 6.1035e-05, 6.1035e-05,
         6.1035e-05],
        [1.9165e-02, 1.9135e-02, 1.7029e-02,  ..., 6.1035e-05, 6.1035e-05,
         6.1035e-05]], device='cuda:0')
tensor([-3.1662e-03,  9.9634e-04,  2.0860e-03,  4.8885e-04, -2.3666e-03,
        -7.4487e-04,  1.5126e-03,  2.2123e-03,  7.0878e-03, -8.3986e-05],
       device='cuda:0', grad_fn=<SliceBackward0>)
!!!!!
tensor([[-0.0110, -0.0226, -0.0211,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0110, -0.0226, -0.0211,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')
tensor([-3.1662e-03,  9.9634e-04,  2.0860e-03,  4.8885e-04, -2.3666e-03,
        -7.4487e-04,  1.5126e-03,  2.2123e-03,  7.0878e-03, -8.3986e-05],
       device='cuda:0', grad_fn=<SliceBackward0>)
!!!!!
tensor([[ 0.0388,  0.0301, -0.0053,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0388,  0.0301, -0.0053,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')
tensor([-3.1662e-03,  9.9634e-04,  2.0860e-03,  4.88

In [22]:
embed_matrix

tensor([[-0.0032,  0.0010,  0.0021,  ..., -0.0016, -0.0004, -0.0019],
        [-0.0032,  0.0010,  0.0021,  ..., -0.0016, -0.0004, -0.0019],
        [-0.0032,  0.0010,  0.0021,  ..., -0.0016, -0.0004, -0.0019],
        ...,
        [-0.0032,  0.0010,  0.0021,  ..., -0.0016, -0.0004, -0.0019],
        [-0.0032,  0.0010,  0.0021,  ..., -0.0016, -0.0004, -0.0019],
        [-0.0032,  0.0010,  0.0021,  ..., -0.0016, -0.0004, -0.0019]],
       grad_fn=<CopySlices>)

In [23]:
labels = []
pred_labels = []

In [None]:
for i in tqdm(range(n_samples)):
    sample_idx = random.randint(0, len(vox_mix) - 1)
    wav_mix, wav_s1, wav_s2, sp1, sp2, _ = vox_mix[sample_idx]
    
    min_sz = min(wav_s1.shape[-1], wav_s2.shape[-1])
    
    if wav_s1.shape[-1] > min_sz:
        wav_s1 = wav_s1[:, :min_sz]
    if wav_s2.shape[-1] > min_sz:
        wav_s2 = wav_s2[:, :min_sz]

    pred_s1, pred_s2 = speaker_separation(sepformer, wav_mix)

    
    gt = torch.cat([wav_s1, wav_s2], dim=0).numpy()
    pred = torch.cat([pred_s1, pred_s2], dim=0).numpy()
    
    min_sz = min(gt.shape[-1], pred.shape[-1])
    
    if gt.shape[-1] > min_sz:
        gt = gt[:, :min_sz]
    if pred.shape[-1] > min_sz:
        pred = pred[:, :min_sz]
        
    
    embeds = model(pred)
    
    cos_sim = F.cosine_similarity(embeds, embed_matrix, dim=1)
    
    cos_sim = cos_sim.cpu().numpy()
    
    idx = cos_sim.argmax()
    
    labels.append(speaker_id[idx])
    pred_labels.append(sp1)
    


    

In [None]:
print(acc(model, [labels, pred_labels]))

0.7961
