In [1]:
import os, sys 
sys.path.append("../")
import torch 
import datetime
from pprint import pprint
import matplotlib.pyplot as plt
import numpy as np
import IPython.display as ipd
from argparse import Namespace
import torch.nn as nn 
import torch.nn.functional as F
import librosa
import argparse
import scipy 
from matplotlib.pyplot import cm
from tqdm import tqdm
import time
import comet_ml
import pandas as pd
from glob2 import glob
import copy
import IPython

import baseline.dataset_loaders.chime as chime
import baseline.dataset_loaders.libri1to3chime as libri1to3chime
import baseline.utils.mixture_consistency as mixture_consistency
import baseline.models.improved_sudormrf as improved_sudormrf
import baseline.metrics.dnnmos_metric as dnnmos_metric
import baseline.metrics.sisdr_metric as sisdr_metric

import pickle

# from __config__ import *
# plt.style.use('science')
# plt.style.use(['science','ieee','no-latex'])
# plt.style.reload_library()
# plt.style.use(['science', 'ieee'])

# os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(["0", "1"])
# os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(["2"])



# Dev and Eval sets of CHiME5 (single-speaker)

In [45]:
batch_size = 1
sample_rate = 16000
timelength = 4.
fixed_n_sources = 1
get_only_active_speakers = False
use_vad = False
random_order = True
n_samples = 250
time_samples = int(sample_rate * timelength)
model_type = 'teacher'

data_loader = chime.Dataset(
    sample_rate=sample_rate, fixed_n_sources=fixed_n_sources,
    timelength=timelength, augment=random_order,
    zero_pad=True, split='dev', get_only_active_speakers=get_only_active_speakers,
    normalize_audio=False, n_samples=n_samples, use_vad=use_vad)

val_chime_gen = data_loader.get_generator(
    batch_size=batch_size, num_workers=1) 

data_loader = chime.Dataset(
    sample_rate=sample_rate, fixed_n_sources=fixed_n_sources,
    timelength=timelength, augment=random_order,
    zero_pad=True, split='eval', get_only_active_speakers=get_only_active_speakers,
    normalize_audio=False, n_samples=n_samples, use_vad=use_vad)

eval_chime_gen = data_loader.get_generator(
    batch_size=batch_size, num_workers=1) 

def get_new_teacher(hparams, depth_growth):
    student = improved_sudormrf.SuDORMRF(
        out_channels=hparams["out_channels"],
        in_channels=hparams["in_channels"],
        num_blocks=int(depth_growth * hparams["num_blocks"]),
        upsampling_depth=hparams["upsampling_depth"],
        enc_kernel_size=hparams["enc_kernel_size"],
        enc_num_basis=hparams["enc_num_basis"],
        num_sources=2,
    )
    return student

hparams = {
    'out_channels': 256,
    'in_channels': 512,
    'num_blocks': 8,
    'upsampling_depth': 7,
    'enc_kernel_size': 81,
    'enc_num_basis': 512,
}

def get_new_student(hparams, depth_growth):
    student = improved_sudormrf.SuDORMRF(
        out_channels=hparams["out_channels"],
        in_channels=hparams["in_channels"],
        num_blocks=int(depth_growth * hparams["num_blocks"]),
        upsampling_depth=hparams["upsampling_depth"],
        enc_kernel_size=hparams["enc_kernel_size"],
        enc_num_basis=hparams["enc_num_basis"],
        num_sources=2,
    )
    return student

hparams = {
    'out_channels': 256,
    'in_channels': 512,
    'num_blocks': 8,
    'upsampling_depth': 7,
    'enc_kernel_size': 81,
    'enc_num_basis': 512,
}

if model_type == 'student':
    w_mix_con_chkpt = '../pretrained_checkpoints/chime_adapted_remixit_student_w_mixconsist.pt'
else:
    w_mix_con_chkpt = "../pretrained_checkpoints/libri1to3mix_supervised_teacher_w_mixconsist.pt"

w_mix_con_model = get_new_teacher(hparams, depth_growth=1)
w_mix_con_model.load_state_dict(torch.load(w_mix_con_chkpt))
#wo_mix_con_model = wo_mix_con_model.cuda()

<All keys matched successfully>

## CHiME (single-speaker)

In [46]:
# Extract proper numbers for val and test for DNS-MOS
eval_res_dic = {'sig_mos': [], 'bak_mos': [], 'ovr_mos': []}
w_mix_con_model.eval()

# Eval set
for cnt, mixture in tqdm(enumerate(eval_chime_gen)):
    input_mix = mixture.unsqueeze(1) #.cuda()
    input_mix_std = input_mix.std(-1, keepdim=True)
    input_mix_mean = input_mix.mean(-1, keepdim=True)
    input_mix = (input_mix - input_mix_mean) / (input_mix_std + 1e-9)

    with torch.no_grad():
        rec_sources_wavs = w_mix_con_model(input_mix) 
        rec_sources_wavs = mixture_consistency.apply(rec_sources_wavs, input_mix)
        
        new_mix = rec_sources_wavs[:, 0:1] + rec_sources_wavs[:, 1:]
        new_mix_std = new_mix.std(-1, keepdim=True)
        new_mix_mean = new_mix.mean(-1, keepdim=True)
        rec_sources_wavs = (rec_sources_wavs - new_mix_mean) / (new_mix_std + 1e-9)
        
#         rec_sources_wavs = (rec_sources_wavs * new_mix_std) + new_mix_mean
        teacher_est_active_speakers = rec_sources_wavs[:, 0:1].detach().cpu()
        teacher_est_noises = rec_sources_wavs[:, 1:].detach().cpu()

    mix, est_s, est_n = input_mix[0, 0].cpu().numpy(), teacher_est_active_speakers[0, 0].cpu().numpy(), teacher_est_noises[0, 0].cpu().numpy()
    
    # Msasure the DNSMOS
    dnsmos_val = dnnmos_metric.compute_dnsmos(est_s, fs=16000)
    for k, v in dnsmos_val.items():
        eval_res_dic[k].append(v)
    
    if cnt > n_samples:
        break

val_res_dic = {'sig_mos': [], 'bak_mos': [], 'ovr_mos': []}
wo_mix_con_model.eval()

# Dev set
for cnt, mixture in tqdm(enumerate(val_chime_gen)):
    input_mix = mixture.unsqueeze(1) #.cuda()
    input_mix_std = input_mix.std(-1, keepdim=True)
    input_mix_mean = input_mix.mean(-1, keepdim=True)
    input_mix = (input_mix - input_mix_mean) / (input_mix_std + 1e-9)

    with torch.no_grad():
        rec_sources_wavs = w_mix_con_model(input_mix)
        rec_sources_wavs = mixture_consistency.apply(rec_sources_wavs, input_mix)
        
        new_mix = rec_sources_wavs[:, 0:1] + rec_sources_wavs[:, 1:]
        new_mix_std = new_mix.std(-1, keepdim=True)
        new_mix_mean = new_mix.mean(-1, keepdim=True)
        rec_sources_wavs = (rec_sources_wavs - new_mix_mean) / (new_mix_std + 1e-9)
        
#         rec_sources_wavs = (rec_sources_wavs * new_mix_std) + new_mix_mean
        teacher_est_active_speakers = rec_sources_wavs[:, 0:1].detach().cpu()
        teacher_est_noises = rec_sources_wavs[:, 1:].detach().cpu()

    mix, est_s, est_n = input_mix[0, 0].cpu().numpy(), teacher_est_active_speakers[0, 0].cpu().numpy(), teacher_est_noises[0, 0].cpu().numpy()
    
    # Msasure the DNSMOS
    dnsmos_val = dnnmos_metric.compute_dnsmos(est_s, fs=16000)
    for k, v in dnsmos_val.items():
        val_res_dic[k].append(v)
    
    if cnt > n_samples:
        break



  tensor_wav = torch.tensor(


1it [00:01,  1.43s/it][A[A

2it [00:02,  1.38s/it][A[A

3it [00:03,  1.34s/it][A[A

4it [00:04,  1.23s/it][A[A

5it [00:06,  1.23s/it][A[A

6it [00:07,  1.23s/it][A[A

7it [00:08,  1.24s/it][A[A

8it [00:09,  1.14s/it][A[A

9it [00:10,  1.16s/it][A[A

10it [00:11,  1.18s/it][A[A

11it [00:13,  1.19s/it][A[A

12it [00:14,  1.20s/it][A[A

13it [00:15,  1.20s/it][A[A

14it [00:16,  1.21s/it][A[A

15it [00:18,  1.21s/it][A[A

16it [00:19,  1.15s/it][A[A

17it [00:20,  1.08s/it][A[A

18it [00:21,  1.14s/it][A[A

19it [00:22,  1.16s/it][A[A

20it [00:23,  1.18s/it][A[A

21it [00:24,  1.19s/it][A[A

22it [00:26,  1.20s/it][A[A

23it [00:27,  1.21s/it][A[A

24it [00:28,  1.19s/it][A[A

25it [00:29,  1.13s/it][A[A

26it [00:30,  1.06s/it][A[A

27it [00:31,  1.11s/it][A[A

28it [00:32,  1.15s/it][A[A

29it [00:34,  1.17s/it][A[A

30it [00:35,  1.10s/it][A[A

31it [00:35,  1.04s/it][A[A

32it [00:36,  1

KeyboardInterrupt: 

In [16]:
for name, this_dic in [('val', val_res_dic), ('eval', eval_res_dic)]:
    print(name)
    for k, v in this_dic.items():
        print(k, np.median(v))

val
sig_mos 3.1403073230397824
bak_mos 0.24783729047514014
ovr_mos 2.0211982832454853
eval
sig_mos 3.0753744058996473
bak_mos 0.2666276682215719
ovr_mos 2.0101867086121867


# Dev and Eval sets of LibriCHiME-5 

In [22]:
batch_size = 1
sample_rate = 16000
timelength = 4.0
fixed_n_sources = -1
split = 'dev'
min_or_max = 'min'
n_speakers_priors = [0.50, 0.25, 0.25]
time_samples = int(sample_rate * timelength)
random_order = True

hparams = {
    "rescale_to_input_mixture": False,
    "apply_mixture_consistency": True,
}

data_loader = libri1to3chime.Dataset(
    sample_rate=sample_rate, fixed_n_sources=fixed_n_sources,
    timelength=timelength, augment=random_order,
    zero_pad=True, min_or_max=min_or_max, split='dev',
    normalize_audio=False, n_samples=-1,
    n_speakers_priors=n_speakers_priors)

dev_libriChime_gen = data_loader.get_generator(
    batch_size=batch_size, num_workers=1)

data_loader = libri1to3chime.Dataset(
    sample_rate=sample_rate, fixed_n_sources=fixed_n_sources,
    timelength=timelength, augment=random_order,
    zero_pad=True, min_or_max=min_or_max, split='test',
    normalize_audio=False, n_samples=-1,
    n_speakers_priors=n_speakers_priors)

eval_libriChime_gen = data_loader.get_generator(
    batch_size=batch_size, num_workers=1)

def apply_output_transform(rec_sources_wavs, input_mix_std,
                           input_mix_mean, input_mom, hparams):
    if hparams["rescale_to_input_mixture"]:
        rec_sources_wavs = (rec_sources_wavs * input_mix_std) + input_mix_mean
    if hparams["apply_mixture_consistency"]:
        rec_sources_wavs = mixture_consistency.apply(rec_sources_wavs, input_mom)
    return rec_sources_wavs

## LibriChime

In [30]:
eval_libriChime_res_dic = {'sig_mos': [], 'bak_mos': [], 'ovr_mos': [], 'si_sdr': [], 'si_sdri': []}
dev_libriChime_res_dic = {'sig_mos': [], 'bak_mos': [], 'ovr_mos': [], 'si_sdr': [], 'si_sdri': []}

w_mix_con_model.eval()

cnt = 0
n_samples = 5

# Eval set
for speakers, noise in tqdm(eval_libriChime_gen):
    gt_speaker_mix = speakers.sum(1, keepdims=True) #.cuda()
    # noise = noise.cuda()

    input_mix = noise + gt_speaker_mix
    input_mix_std = input_mix.std(-1, keepdim=True)
    input_mix_mean = input_mix.mean(-1, keepdim=True)
    input_mix = (input_mix - input_mix_mean) / (input_mix_std + 1e-9)

    with torch.no_grad():
        rec_sources_wavs = w_mix_con_model(input_mix)
        rec_sources_wavs = apply_output_transform(
            rec_sources_wavs, input_mix_std, input_mix_mean, input_mix, hparams)
        teacher_est_active_speakers = rec_sources_wavs[:, 0:1]
        teacher_est_noises = rec_sources_wavs[:, 1:]
    
    sisdr = sisdr_metric.compute_sisdr(
        teacher_est_active_speakers.cpu().numpy(), gt_speaker_mix.cpu().numpy())
    
    mix_sisdr = sisdr - sisdr_metric.compute_sisdr(
        input_mix.cpu().numpy(), gt_speaker_mix.cpu().numpy())
    
    eval_libriChime_res_dic['si_sdr'].append(sisdr)
    eval_libriChime_res_dic['si_sdri'].append(mix_sisdr)
    
    cnt += 1
    
    if cnt > n_samples:
        break
        
cnt = 0

# Dev set
for speakers, noise in tqdm(dev_libriChime_gen):
    gt_speaker_mix = speakers.sum(1, keepdims=True) #.cuda()
    # noise = noise.cuda()

    input_mix = noise + gt_speaker_mix
    input_mix_std = input_mix.std(-1, keepdim=True)
    input_mix_mean = input_mix.mean(-1, keepdim=True)
    input_mix = (input_mix - input_mix_mean) / (input_mix_std + 1e-9)

    with torch.no_grad():
        rec_sources_wavs = w_mix_con_model(input_mix)
        rec_sources_wavs = apply_output_transform(
            rec_sources_wavs, input_mix_std, input_mix_mean, input_mix, hparams)
        teacher_est_active_speakers = rec_sources_wavs[:, 0:1]
        teacher_est_noises = rec_sources_wavs[:, 1:]
    
    sisdr = sisdr_metric.compute_sisdr(
        teacher_est_active_speakers.cpu().numpy(), gt_speaker_mix.cpu().numpy())
    
    mix_sisdr = sisdr - sisdr_metric.compute_sisdr(
        input_mix.cpu().numpy(), gt_speaker_mix.cpu().numpy())
    
    dev_libriChime_res_dic['si_sdr'].append(sisdr)
    dev_libriChime_res_dic['si_sdri'].append(mix_sisdr)
    
    cnt += 1
    
    if cnt > n_samples:
        break



  tensor_wav = torch.tensor(


  0%|                                                                                                                                                                                                                                                                                                                         | 1/3000 [00:00<35:17,  1.42it/s][A[A

  0%|▏                                                                                                                                                                                                                                                                                                                        | 2/3000 [00:00<26:20,  1.90it/s][A[A

  0%|▎                                                                                                                                                                                                                                                   

In [31]:
eval_libriChime_res_dic

{'sig_mos': [],
 'bak_mos': [],
 'ovr_mos': [],
 'si_sdr': [19.673168659210205,
  6.154797673225403,
  19.270875453948975,
  10.007598400115967,
  21.264519691467285,
  5.16119122505188],
 'si_sdri': [5.310730934143066,
  6.618217974901199,
  6.380683183670044,
  13.534106314182281,
  8.796828985214233,
  9.86205279827118]}

In [33]:
dev_libriChime_res_dic

{'sig_mos': [],
 'bak_mos': [],
 'ovr_mos': [],
 'si_sdr': [11.347124576568604,
  21.56254291534424,
  3.950425386428833,
  23.122360706329346,
  7.49070405960083,
  15.223684310913086],
 'si_sdri': [5.2526432275772095,
  7.2285425662994385,
  4.450291655957699,
  10.25989294052124,
  10.8815136551857,
  2.763533592224121]}