In [4]:
import os, sys
import numpy as np
import torch
import torchaudio
import matplotlib.pyplot as plt
import IPython.display as ipd
from asteroid.metrics import get_metrics
from pprint import pprint
import time
import pickle
from tqdm import tqdm

os.environ['CUDA_VISIBLE_DEVICES'] = '3'
torch.cuda.empty_cache()

# Get the pretrained models
print("Pre-trained models available:")
for model_name in os.listdir('../../pretrained_models'):
    print(model_name)
    
def normalize_tensor_wav(wav_tensor, eps=1e-8, std=None):
    mean = wav_tensor.mean(-1, keepdim=True)
    if std is None:
        std = wav_tensor.std(-1, keepdim=True)
    return (wav_tensor - mean) / (std + eps)
    
anechoic_model_p = '../../pretrained_models/GroupCom_Sudormrf_U8_Bases512_WSJ02mix.pt'
anechoic_model_p = '../../pretrained_models/Improved_Sudormrf_U16_Bases512_WSJ02mix.pt'
anechoic_model_p = '../../pretrained_models/Improved_Sudormrf_U36_Bases2048_WSJ02mix.pt'
noisy_reverberant_model_p = '../../pretrained_models/Improved_Sudormrf_U16_Bases2048_WHAMRexclmark.pt'
noisy_reverberant_model_p = '../../pretrained_models/Improved_Sudormrf_U36_Bases4096_WHAMRexclmark.pt'

# Load the appropriate class modules
sys.path.append("../../")
import sudo_rm_rf.dnn.models.improved_sudormrf as improved_sudormrf
import sudo_rm_rf.dnn.models.groupcomm_sudormrf_v2 as sudormrf_gc_v2
import sudo_rm_rf.dnn.models.sepformer as sepformer
from speechbrain.pretrained import SepformerSeparation as sep_former_separator


Pre-trained models available:
Improved_Sudormrf_U16_Bases2048_WHAMRexclmark.pt
improved_sudo_epoch_500
GroupCom_Sudormrf_U8_Bases512_WSJ02mix.pt
Improved_Sudormrf_U16_Bases512_WSJ02mix.pt
Improved_Sudormrf_U36_Bases4096_WHAMRexclmark.pt
Improved_Sudormrf_U36_Bases2048_WSJ02mix.pt
.gitattributes


In [5]:
# get all files for wham or whamr!
whamr_test_folder_path = '/mnt/data/whamr/wav8k/min/tt'
wsj02mix_test_file_names = os.listdir(os.path.join(whamr_test_folder_path, 'mix_clean_anechoic')) 
whamrexcl_test_file_names = os.listdir(os.path.join(whamr_test_folder_path, 'mix_both_reverb')) 
wsj02mix_test_file_names = [os.path.join(whamr_test_folder_path, 'mix_clean_anechoic',name)
                            for name in wsj02mix_test_file_names]
whamrexcl_test_file_names = [os.path.join(whamr_test_folder_path, 'mix_both_reverb',name)
                             for name in wsj02mix_test_file_names]

def get_tensors_for_chosen_file(chosen_mixture_path):
    mixture, _ = torchaudio.load(chosen_mixture_path)
    chosen_filename = os.path.basename(chosen_mixture_path)
    ground_truth_sources = torch.tensor(np.array([
        torchaudio.load(os.path.join(whamr_test_folder_path,
                                     's1_anechoic', chosen_filename))[0].detach().numpy()[0],
        torchaudio.load(os.path.join(whamr_test_folder_path,
                                     's2_anechoic', chosen_filename))[0].detach().numpy()[0]
    ]))
    
    return mixture[:, :56000], ground_truth_sources[:, :56000]
    
chosen_file = '446o030h_0.13806_444c020w_-0.13806.wav'
chosen_mixture_path = os.path.join(whamr_test_folder_path, 'mix_clean_anechoic', chosen_file)
print(get_tensors_for_chosen_file(chosen_mixture_path)[0].shape,
get_tensors_for_chosen_file(chosen_mixture_path)[1].shape)

torch.Size([1, 39001]) torch.Size([2, 39001])


In [6]:
models_to_eval = [
#     {
#         'model_path': '../../pretrained_models/GroupCom_Sudormrf_U8_Bases512_WSJ02mix.pt',
#         'is_sudo_model': True,
#         'test_dataset': "WSJ02mix",
#     },
#     {
#         'model_path': '../../pretrained_models/Improved_Sudormrf_U16_Bases512_WSJ02mix.pt',
#         'is_sudo_model': True,
#         'test_dataset': "WSJ02mix",
#     },
#     {
#         'model_path': '../../pretrained_models/Improved_Sudormrf_U36_Bases2048_WSJ02mix.pt',
#         'is_sudo_model': True,
#         'test_dataset': "WSJ02mix",
#     },
#     {
#         'model_path': '../../pretrained_models/Improved_Sudormrf_U16_Bases2048_WHAMRexclmark.pt',
#         'is_sudo_model': True,
#         'test_dataset': "WHAMR!",
#     },
#     {
#         'model_path': '../../pretrained_models/Improved_Sudormrf_U36_Bases4096_WHAMRexclmark.pt',
#         'is_sudo_model': True,
#         'test_dataset': "WHAMR!",
#     },
    {
        'model_path': None,
        'is_sudo_model': False,
        'test_dataset': "WSJ02mix",
    }
]

def normalize_tensor_wav(wav_tensor, eps=1e-8, std=None):
    mean = wav_tensor.mean(-1, keepdim=True)
    if std is None:
        std = wav_tensor.std(-1, keepdim=True)
    return (wav_tensor - mean) / (std + eps)

def get_model(model_info, is_gpu=False):
    if model_info['model_path'] is None:
        model = sep_former_separator.from_hparams(source="speechbrain/sepformer-wsj02mix", 
                                       savedir='pretrained_models/sepformer-wsj02mix',
                                       run_opts={"device":"cuda"})
        model_name = "Sepformer"
    else:
        model = torch.load(model_info['model_path'])
        model_name = model_info['model_path'].split('/')[-1]
    return model, model_name

results_dic = {}

for model_info in models_to_eval:
    model, model_name = get_model(model_info, is_gpu=True)
    model.cuda()
    print("======================")
    print(f"Evaluating model: {model_name}")
    
    results_dic[model_name] = {}
    
    if model_info['test_dataset'] == "WSJ02mix":
        mixture_paths = wsj02mix_test_file_names
    else:
        mixture_paths = whamrexcl_test_file_names
    
    for chosen_mixture_path in tqdm(mixture_paths):
        input_mix, gt_sources = get_tensors_for_chosen_file(chosen_mixture_path)
        input_mix = input_mix.cuda()
        if model_info['is_sudo_model']:
            input_mix_std = input_mix.std(-1, keepdim=True)
            input_mix_mean = input_mix.mean(-1, keepdim=True)
            input_mix = (input_mix - input_mix_mean) / (input_mix_std + 1e-9)
            est_sources = model(input_mix.unsqueeze(1))
        else:
            est_sources = model(input_mix).permute(0, 2, 1)
            
        try:
            all_metrics_dic = get_metrics(
                input_mix.detach().cpu().numpy(),
                normalize_tensor_wav(gt_sources).detach().cpu().numpy(),
                normalize_tensor_wav(est_sources[0]).detach().cpu().numpy(),
                compute_permutation=True, sample_rate=8000, metrics_list='all')
        except Exception as e:
            print(e)
            continue
            
        for k, v in all_metrics_dic.items():
            if k not in results_dic[model_name]:
                results_dic[model_name][k] = [v]
            else:
                results_dic[model_name][k].append(v)
        if 'sisdri' in results_dic[model_name]:
            results_dic[model_name]['sisdri'].append(all_metrics_dic['si_sdr'] - all_metrics_dic['input_si_sdr'])
        else:
            results_dic[model_name]['sisdri'] = [all_metrics_dic['si_sdr'] - all_metrics_dic['input_si_sdr']]
            
    with open(f'{model_name}_sep_perf_models.pickle', 'wb') as handle:
        pickle.dump(results_dic, handle, protocol=pickle.HIGHEST_PROTOCOL)
    del model
    
pprint(results_dic)

  0%|          | 0/3000 [00:00<?, ?it/s]

Evaluating model: Sepformer


100%|██████████| 3000/3000 [52:28<00:00,  1.05s/it]


{'Sepformer': {'input_pesq': [1.6766878366470337,
                              1.7483115196228027,
                              1.562520682811737,
                              1.9775598645210266,
                              1.843642234802246,
                              1.6365641355514526,
                              1.5278839468955994,
                              1.7601330280303955,
                              1.5053852796554565,
                              1.5987656116485596,
                              1.6923589706420898,
                              1.5942891240119934,
                              1.6900184750556946,
                              1.5953938961029053,
                              1.5112732648849487,
                              1.5537854433059692,
                              1.6360238790512085,
                              1.664713442325592,
                              1.7904396653175354,
                              1.7848764061927795,
   

                              1.5911918878555298,
                              1.465947151184082,
                              1.6643107533454895,
                              1.6661397218704224,
                              1.7789828181266785,
                              1.7276665568351746,
                              1.798797607421875,
                              1.757010042667389,
                              1.5390795469284058,
                              1.6023471355438232,
                              1.9343315362930298,
                              1.7190608978271484,
                              1.6926793456077576,
                              1.818836748600006,
                              1.7318881750106812,
                              1.7854704856872559,
                              1.8855445384979248,
                              1.6471934914588928,
                              1.7983675599098206,
                              1.8591959476470947,
    

                             58.25514683341339,
                             48.78255603973078,
                             55.218695619936526,
                             55.02741654215838,
                             58.03688588455261,
                             62.22118910436211,
                             60.42563769559278,
                             73.69087154325311,
                             52.707723909466765,
                             58.02001386302284,
                             68.07843908421998,
                             66.86278627748722,
                             59.93248956574663,
                             67.83172090051647,
                             57.528975594032325,
                             54.35679163550055,
                             57.649191490032194,
                             80.79068763455571,
                             61.2115503552602,
                             71.03111976903412,
                             67.86104

                             0.10979137292226993,
                             0.1646058733235749,
                             0.6031191680743773,
                             0.19569267650086133,
                             0.37495394676678195,
                             0.05853891723712945,
                             -0.05049906330065812,
                             0.3486818895363444,
                             0.19242575018036245,
                             0.07310496007108719,
                             0.1911305583360281,
                             0.2839646459727283,
                             0.16801523071127744,
                             0.10836469044016406,
                             0.11540710952172173,
                             0.1470933547761314,
                             -0.03119034549987565,
                             -0.02721425194322835,
                             0.07389737358187376,
                             0.2757773689525004,
    

                                -0.1153261661529541,
                                -0.025089144706726074,
                                -0.12401127815246582,
                                -0.09855914860963821,
                                0.0469965934753418,
                                0.040952444076538086,
                                -0.07705375552177429,
                                -0.050697386264801025,
                                -0.02238941192626953,
                                -0.034198641777038574,
                                0.003329724073410034,
                                0.01146206259727478,
                                0.022200554609298706,
                                0.08928006887435913,
                                0.08370423316955566,
                                -0.06654071807861328,
                                0.0014142096042633057,
                                0.1855149269104004,
                                

                             0.37089084461614963,
                             0.35628381298626904,
                             0.6991521241968592,
                             0.18231641717660407,
                             0.3636783923046096,
                             0.004727597155003704,
                             0.03982576529778359,
                             -0.052436921802536576,
                             0.09161780248525919,
                             0.04456685344917932,
                             0.10997039495204075,
                             0.25756147403713103,
                             0.08018189156743466,
                             0.21807241547557843,
                             0.06551129264817804,
                             0.2453356624360008,
                             0.03730923508853756,
                             0.017353427713733716,
                             0.0894785627298269,
                             0.1099177107871736,
 

                              0.7804272777786505,
                              0.7366265705888755,
                              0.7758111618363899,
                              0.7156808084371409,
                              0.7206339706828946,
                              0.7146856146254663,
                              0.7062104641613249,
                              0.7372758801210777,
                              0.6960119512631342,
                              0.705209032820522,
                              0.7950079219003081,
                              0.7669969540951611,
                              0.7194966905832125,
                              0.7871603509713918,
                              0.745746839134828,
                              0.7109915754211448,
                              0.7551454455520077,
                              0.7419379532346044,
                              0.735232674660748,
                              0.7218802491298982,
   

                        3.9918991327285767,
                        3.9821527004241943,
                        3.983210802078247,
                        4.095175623893738,
                        4.015233039855957,
                        4.142617225646973,
                        4.028347730636597,
                        4.1027045249938965,
                        4.087995529174805,
                        4.076499700546265,
                        4.076997756958008,
                        4.097927808761597,
                        4.129901885986328,
                        4.029643774032593,
                        4.087713837623596,
                        4.068528175354004,
                        4.016219615936279,
                        4.077316761016846,
                        3.8880486488342285,
                        4.025991201400757,
                        3.9768654108047485,
                        3.9768069982528687,
                        4.169036388397217,
     

                       21.24006091726772,
                       24.705117518448283,
                       21.997931080646254,
                       20.982676664714784,
                       24.112314222528262,
                       24.284313069710294,
                       22.897491539192366,
                       22.744715012777455,
                       22.66217138993614,
                       23.25137047720841,
                       23.01725643034232,
                       22.71488190562235,
                       25.202301422922197,
                       24.098279335875404,
                       26.478368147728496,
                       25.484089544113804,
                       24.79524835075693,
                       25.086218692554418,
                       24.074450140878472,
                       25.061547893074398,
                       24.09109490474194,
                       23.86794747683117,
                       22.384543894270244,
                   

                       21.676146165135528,
                       22.95366020902363,
                       25.032194190581542,
                       22.76603053397033,
                       19.13363077331129,
                       22.35819092809489,
                       23.58000318232655,
                       24.138698141754354,
                       19.097826281646064,
                       22.88437498368247,
                       23.00047433377508,
                       24.254719569285633,
                       21.86818960902083,
                       22.9081355566624,
                       21.086782441097647,
                       25.096058149030505,
                       22.37015065939177,
                       21.632880510227203,
                       23.45821742930859,
                       23.402798906769362,
                       24.05569076655047,
                       22.834301808152027,
                       22.54162799521208,
                       22

                          23.72288227081299,
                          22.476022720336914,
                          23.19374179840088,
                          23.601656913757324,
                          20.243867874145508,
                          23.753806114196777,
                          23.992904663085938,
                          24.727031707763672,
                          22.138955116271973,
                          23.965737342834473,
                          23.86266326904297,
                          21.911300659179688,
                          23.42865562438965,
                          23.271307945251465,
                          23.278160095214844,
                          23.066304206848145,
                          24.044997215270996,
                          23.019139289855957,
                          24.420137405395508,
                          23.73466205596924,
                          23.54279613494873,
                          22.00745773315

                       23.69252762299811,
                       35.27565781433988,
                       34.46688876808203,
                       31.742690477751555,
                       35.92258847199304,
                       33.32846385055387,
                       33.03790515607262,
                       32.626685761105875,
                       36.72032702720795,
                       35.14524519400502,
                       19.880819877123116,
                       34.19240151314088,
                       36.738015766884814,
                       36.59728276153763,
                       34.78641290005184,
                       37.73697189895415,
                       28.62349540465251,
                       35.86289629176028,
                       35.18242853081189,
                       34.60595026996991,
                       53.00879261563289,
                       35.32839453511549,
                       35.53695021846114,
                       33.7878

                          23.125616177916527,
                          24.40843641757965,
                          20.593679904937744,
                          21.742018938064575,
                          21.115801334381104,
                          23.416221976280212,
                          24.491293907165527,
                          24.395819187164307,
                          23.12392497062683,
                          21.9000985622406,
                          20.42524242401123,
                          24.07942485809326,
                          22.156871914863586,
                          20.34332275390625,
                          21.747224904596806,
                          24.442603776231408,
                          21.98066806793213,
                          22.806352138519287,
                          21.723811149597168,
                          22.485957823693752,
                          22.931173130869865,
                          23.9934250339865

                        0.987150089352407,
                        0.9667626366032382,
                        0.9434615561308088,
                        0.9921281149148636,
                        0.9898853819243827,
                        0.9896895890572085,
                        0.9890787639931272,
                        0.9899957735571845,
                        0.9905436530525132,
                        0.9970672279006145,
                        0.9875598003793247,
                        0.9466117066015354,
                        0.9946073341076687,
                        0.9933128127447561,
                        0.9848613869110745,
                        0.9935928419465845,
                        0.9937901792620601,
                        0.9893619474512025,
                        0.9711173508883427,
                        0.9923990412831301,
                        0.9910227823197226,
                        0.9942587838849916,
                        0.9940009

In [23]:
# Perform the analysis for all the files
from glob2 import glob
for result_file_p in glob("*.pickle"):
    model_name = os.path.basename(result_file_p)
    print("--======= \n", model_name)
    with open(result_file_p, 'rb') as handle:
        this_res_dic = pickle.load(handle)
        
    for k, v in this_res_dic.items():
        acc = []
        for metric_name in ['sisdri', 'pesq', 'stoi']:
            metric_values = v[metric_name]
            if metric_name == "stoi":
                acc.append(round(np.array(metric_values).mean(), 3))
            if metric_name == "pesq":
                acc.append(round(np.array(metric_values).mean(), 2))
            if metric_name == "sisdri":
                acc.append(round(np.array(metric_values).mean(), 1))
        print(" | ".join([str(y) for y in acc]))
#         for metric_name, metric_values in v.items():
#             if metric_name == "stoi":
#                 print(metric_name, round(np.array(metric_values).mean(), 3))
#             if metric_name == "pesq":
#                 print(metric_name, round(np.array(metric_values).mean(), 2))
#             if metric_name == "sisdri":
#                 print(metric_name, round(np.array(metric_values).mean(), 1))

 Sepformer_sep_perf_models.pickle
22.5 | 4.0 | 0.988
 Improved_Sudormrf_U16_Bases512_WSJ02mix.pt_sep_perf_models.pickle
17.3 | 3.43 | 0.969
 GroupCom_Sudormrf_U8_Bases512_WSJ02mix.pt_sep_perf_models.pickle
13.1 | 2.8 | 0.937
 Improved_Sudormrf_U36_Bases4096_WHAMRexclmark.pt_sep_perf_models.pickle
13.6 | 3.04 | 0.949
 Improved_Sudormrf_U36_Bases2048_WSJ02mix.pt_sep_perf_models.pickle
19.5 | 3.73 | 0.979
 Improved_Sudormrf_U16_Bases2048_WHAMRexclmark.pt_sep_perf_models.pickle
12.1 | 2.79 | 0.933
