In [None]:
import torch
from torch.utils.data import  DataLoader

# Dataset Construction

The notebook assumes the data uses the following file structure:
- Data
    -  train
    -  val
    -  test
    -  all_ir (contains full set of irs used for data augmentation)

The train, validation, and test folders have the following subfolders:
- ir: the irs used to augment the data within the directory
- mix_clean: two speaker mixtures
- mix_noise: two speaker mixtures with additive noise
- s1: first set of source speakers in the mixtures
- s2: second set of source speakers in the mixtures

The mixtures share filenames with their sources.

## Impulse Response Datasets

The following functions and commented-out code can be used to create sets of training and evaluation impulse responses (IRs). IRs are copied into each split's `ir` subfolder; this layout was chosen so that the torch-audiomentations `ApplyImpulseResponse` transform can be used directly. Impulse responses are selected at random.

All IRs are from the MIT McDermott survey data https://mcdermottlab.mit.edu/Reverb/IR_Survey.html

Uncomment the following cell to split the IRs into train/validation/test sets and copy each set into the corresponding `ir` directory.

In [None]:
#from sourcesep.utilities.utils import split_ir_files, copy_ir_files

# train_ir_dir = './Data/train/ir'
# test_ir_dir = './Data/test/ir'
# val_ir_dir = './Data/val/ir'
# source_ir_dir = './Data/all_ir'

# ir_split = split_ir_files("./Data/all_ir", .15, .15)
# train_ir_names = ir_split['train']
# test_ir_names = ir_split['test']
# val_ir_names = ir_split['val']

# copy_ir_files(source_ir_dir,train_ir_dir,train_ir_names)

# copy_ir_files(source_ir_dir,test_ir_dir,test_ir_names)

# copy_ir_files(source_ir_dir,val_ir_dir,val_ir_names)

## Audio DataLoader

PyTorch dataset used to load the audio clips. Because of JupyterLab memory limits, audio files are loaded lazily as batches are requested rather than all up front. Random cropping is applied on load (when `rand_crop=True`).

apply_ir = True will apply a random impulse response to the input data prior to fourier transformation.

In [None]:
from sourcesep.dataset import SourceSepDS

# Models

I approached the problem of noisy, reverberant speaker separation in two stages. The first network attempts to simultaneously de-reverberate and de-noise an input mixture of two speakers. The second network splits the speaker mixture into channels containing the different speakers.

The aim is to assess whether a speech enhancement network can help the time domain source separator generalize to noisy reverberant audio. For the sake of project scope, the implementations are limited to single channel, two speaker mixtures.

## Complex SkipConvNet

The network architecture imported below is based on the U-Net described in Kothapally et al. (2020): a U-Net with convolutional skip networks that process the intermediate representations produced by the encoder.

The primary difference between the paper and this network is that this implementation is a complex-valued U-Net that produces masks which are applied directly to the real and imaginary parts of the STFT coefficients. Complex batch normalization, convolution, transposed convolution, and activation functions are all implemented as described in Trabelsi et al. (2018).

The model also deviates from the design in the Kothapally paper in its output. The original network learns to synthesize an enhanced spectrogram, while in this implementation the network produces two masks that are applied to the input STFT to scale the real and imaginary parts of the coefficients. The justification is that this reduces the amount of training needed to produce outputs that are meaningful to the source separator. The method for producing complex masks mimics Ephrat et al. (2018), although a tanh activation function is used instead of a sigmoid.

I chose to implement this model as a complex network as opposed to a real-valued network that is applied to spectrograms, as phase does not need to be discarded when training. This technique allows us to avoid the use of additional techniques to learn, reconstruct, or incorporate the phase from the noisy reverberant mixture into the spectrogram that would be produced by a real-valued network.

In [None]:
from sourcesep.enhancer.model import ComplexSkipConvNet

#network parameter sets
# Every encoder stage uses a 5x5 kernel with stride 2 and every decoder stage a
# 2x2 kernel with stride 2. Only the channel counts, skip-block counts, the first
# encoder stage's padding ([1, 2]) and the last decoder stage's output padding
# ([1, 0]) differ, so the per-layer dicts are generated from compact tables.

def _enc_layer(out_ch, sk_bl, p=(2, 2)):
    """Return one encoder-stage spec: kernel 'k', stride 's', padding 'p', 'out_ch', skip blocks 'sk_bl'."""
    return {'k': torch.tensor([5, 5]), 's': torch.tensor([2, 2]), 'p': torch.tensor(list(p)),
            'out_ch': out_ch, 'sk_bl': sk_bl}


def _dec_layer(out_ch, op=(0, 0)):
    """Return one decoder-stage spec: kernel 'k', stride 's', output padding 'op', 'out_ch'."""
    return {'k': torch.tensor([2, 2]), 's': torch.tensor([2, 2]), 'op': torch.tensor(list(op)),
            'out_ch': out_ch}


def _build_layers(enc_table, dec_table):
    """Expand an encoder table of (out_ch, sk_bl) pairs and a decoder table of out_ch values.

    The first encoder stage gets padding [1, 2] and the last decoder stage gets
    output padding [1, 0]; every other stage uses the defaults above.
    Returns (encoder spec list, decoder spec list).
    """
    last_dec = len(dec_table) - 1
    enc = [_enc_layer(ch, sk, p=(1, 2) if idx == 0 else (2, 2))
           for idx, (ch, sk) in enumerate(enc_table)]
    dec = [_dec_layer(ch, op=(1, 0) if idx == last_dec else (0, 0))
           for idx, ch in enumerate(dec_table)]
    return enc, dec


#large model with halved channel count/skip block count
enc_params, dec_params = _build_layers(
    enc_table=[(32, 8), (64, 8), (128, 4), (256, 4), (256, 2), (256, 2), (256, 1), (256, 0)],
    dec_table=[256, 256, 256, 256, 128, 64, 32, 1],
)
skip_params = {'k': 3, 'p': 1, 's': 1}
unet_params_large = {'enc': enc_params, 'dec': dec_params, 'sk': skip_params}

#small network (the one used by the cells below)
enc_params, dec_params = _build_layers(
    enc_table=[(16, 8), (32, 8), (32, 4), (64, 4), (64, 2), (64, 2), (128, 1), (128, 0)],
    dec_table=[128, 64, 64, 64, 32, 32, 16, 1],
)
skip_params = {'k': 3, 'p': 1, 's': 1}
unet_params = {'enc': enc_params, 'dec': dec_params, 'sk': skip_params}

## Conv TasNet

The model imported below is a non-causal, two-speaker implementation of the convolutional TasNet (Conv-TasNet) proposed by Luo and Mesgarani (2018).

Network hyperparameters are taken directly from the smaller model configurations assessed in that paper.

In [None]:
from sourcesep.separator.model import ConvTasNet

## Joint Model

The two models are combined into a single object to simplify joint training.

In [None]:
from sourcesep.jointmodel import SourceSeparator

## Model Sizes

In [None]:
#  CTN: 1.5 million params , larger models cause cuda issues
# model = ConvTasNet()

# #Default CUNET: 3.4 million params
# model = ComplexSkipConvNet(unet_params)

# #default joint network 5.1 m params
# model = SourceSeparator(unet_params, sep_net=True) 

# Loss

The objective function used in training is the scale-invariant signal-to-noise ratio (SI-SNR) described in Luo and Mesgarani (2018). This objective is used to measure loss for both the source separator and the U-Net (measured using the sensitivity of the output).

As the scope of the project is limited to two speaker mixtures, the loss calculation can be grouped into two categories. 

The enhancer outputs a single channel, so SI-SNR is simply a measure of the error between the target signal and the predicted signal that accounts for the scale of the signal.

The source separator produces two channels for each input signal, which can lead to complications, as the network would also need to learn the order of the target single-speaker signals. Permutation invariance is handled by considering the pairwise SI-SNR values between the predicted channels and the ground-truth sources. The permutation of predictions and labels that maximizes the objective is used as the loss.

In [None]:
from sourcesep.loss import si_snr

# Train and Evaluation Loops

Two source separation systems were compared.

The first uses both networks; the second is the TasNet alone.

The joint system is trained in two stages: general training and then fine tuning. The first phase defines loss as the aggregate of the SI-SNR for the reconstructed signal produced by the u-net and the clean speaker mixture and the SI-SNR for the source separator and the clean speaker channels.

The approach aims to encourage the u-net to learn intermediate representations that are helpful in the speech enhancement task while also encouraging the network to learn the overarching goal of the system.

For fine tuning, all network layers of the u-net besides the output layer are frozen. The loss term is just the SI-SNR between the source separator output and the clean single speaker sources.

The aim of the second phase is to maximize the performance of the final output of the system and allow the u-net to learn high-level changes to its output that benefit overall source separation performance.

A TasNet is also trained as a baseline for the same number of epochs as the joint model, to see whether the speech enhancement network aids the source separation task.

In [None]:
from sourcesep.train import train, evaluate, train_ss, evaluate_ss

## UNet Only Training

Used for initial U-Net testing; not used in the report.

In [None]:
batch_size = 8

# Options shared by both splits; the splits differ only in cropping (and IR determinism for validation).
ds_common_kwargs = {'sep_net': True, 'ir_p': 1}
train_ds = SourceSepDS('./Data/train', rand_crop=True, **ds_common_kwargs)
val_ds = SourceSepDS('./Data/val', rand_crop=False, ir_determ=True, **ds_common_kwargs)

# Shuffle training batches only; validation is iterated in a fixed order.
train_dataloader = DataLoader(train_ds, batch_size, shuffle=True)
val_dataloader = DataLoader(val_ds, batch_size)

In [None]:
# Use the first CUDA GPU when available, otherwise the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# U-Net-only experiment (per the section heading): the joint model is built with sep_net=False.
model = SourceSeparator(unet_params, sep_net=False).to(device)

In [None]:
#load file
# map_location restores the checkpoint onto the current `device`, so weights saved
# on a CUDA machine can still be loaded when only the CPU is available.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
load_model = False
if load_model:
    model_file_name = 'testnew'
    path_to_pkl = './models/{}.pkl'.format(model_file_name)
    model.load_state_dict(torch.load(path_to_pkl, map_location=device))

In [None]:
# Adam with explicit betas/eps and no weight decay.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=0.,
)
# Halve the learning rate once the value passed to scheduler.step() stops
# decreasing for more than `patience` consecutive steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=.5, patience=1)
criterion = si_snr

In [None]:
epochs = 10
# Exploratory run: the four values train() returns (stored as train/valid losses
# and SI-SNRi histories in later cells) are discarded here.
_, _, _, _ = train(
    epochs,
    model,
    criterion,
    optimizer,
    scheduler,
    train_dataloader,
    val_dataloader,
    device,
    save_dir='./models/testnew.pkl',
    valid_freq=1,
    fine_tune=False,
    batch_eval=500,
    batch_loud=True,
)

## Joint Training

In [None]:
batch_size = 8

# Options shared by both splits; they differ only in cropping and IR determinism.
ds_common_kwargs = {'sep_net': True, 'ir_p': 1}
train_ds = SourceSepDS('./Data/train', rand_crop=True, ir_determ=False, **ds_common_kwargs)
val_ds = SourceSepDS('./Data/val', rand_crop=False, ir_determ=True, **ds_common_kwargs)

# Shuffle training batches only; validation is iterated in a fixed order.
train_dataloader = DataLoader(train_ds, batch_size, shuffle=True)
val_dataloader = DataLoader(val_ds, batch_size)

In [None]:
# Use the first CUDA GPU when available, otherwise the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# Full system: enhancer followed by the separator (sep_net=True).
model = SourceSeparator(unet_params, sep_net=True).to(device)

In [None]:
#load file
# map_location restores the checkpoint onto the current `device`, so weights saved
# on a CUDA machine can still be loaded when only the CPU is available.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
load_model = False
if load_model:
    model_file_name = 'joint_final'
    path_to_pkl = './models/{}.pkl'.format(model_file_name)
    model.load_state_dict(torch.load(path_to_pkl, map_location=device))

In [None]:
# Adam with explicit betas/eps and no weight decay.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=0.,
)
# Halve the learning rate once the value passed to scheduler.step() stops
# decreasing for more than `patience` consecutive steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=.5, patience=1)
criterion = si_snr

In [None]:
epochs = 4
# Stage 1: joint training of enhancer + separator (fine_tune=False).
# save_dir is the path that later cells load as 'joint_final'.
train_losses, valid_losses, train_si_snris, valid_si_snris = train(
    epochs,
    model,
    criterion,
    optimizer,
    scheduler,
    train_dataloader,
    val_dataloader,
    device,
    save_dir='./models/joint_final.pkl',
    valid_freq=4,
    fine_tune=False,
    batch_eval=500,
    batch_loud=False,
)

In [None]:
# Keep the joint-training histories under joint_* names, because the
# fine-tuning and baseline cells below rebind train_losses and friends.
joint_tl, joint_vl, joint_tsisnri, joint_vsisnri = (
    train_losses, valid_losses, train_si_snris, valid_si_snris)

## Fine Tuning

In [None]:
batch_size = 8

# Options shared by both splits; they differ only in cropping and IR determinism.
ds_common_kwargs = {'sep_net': True, 'ir_p': 1}
train_ds = SourceSepDS('./Data/train', rand_crop=True, ir_determ=False, **ds_common_kwargs)
val_ds = SourceSepDS('./Data/val', rand_crop=False, ir_determ=True, **ds_common_kwargs)

# Shuffle training batches only; validation is iterated in a fixed order.
train_dataloader = DataLoader(train_ds, batch_size, shuffle=True)
val_dataloader = DataLoader(val_ds, batch_size)

In [None]:
# Use the first CUDA GPU when available, otherwise the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# Fresh joint model; the next cell restores the stage-1 weights into it.
model = SourceSeparator(unet_params, sep_net=True).to(device)

In [None]:
#load file
# Restore the jointly-trained (stage-1) checkpoint that fine tuning starts from.
# map_location puts the tensors on the current `device`, so weights saved on a
# CUDA machine can still be loaded when only the CPU is available.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
load_model = True
if load_model:
    model_file_name = 'joint_final'
    path_to_pkl = './models/{}.pkl'.format(model_file_name)
    model.load_state_dict(torch.load(path_to_pkl, map_location=device))
    # The fine-tuning run saves to this same joint_final.pkl path, which would
    # overwrite these weights. Keep a copy for the "before fine tuning" evaluation.
    torch.save(model.state_dict(), './models/joint_preft.pkl')

In [None]:
# Adam with explicit betas/eps and no weight decay. Any parameters frozen for
# fine tuning receive no gradient, so Adam leaves them untouched.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=0.,
)
# Halve the learning rate once the value passed to scheduler.step() stops
# decreasing for more than `patience` consecutive steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=.5, patience=1)
criterion = si_snr

In [None]:
epochs = 2
# Stage 2: fine tuning (fine_tune=True). Note that save_dir is the same path as
# the stage-1 checkpoint, so the fine-tuned weights replace it on disk.
train_losses, valid_losses, train_si_snris, valid_si_snris = train(
    epochs,
    model,
    criterion,
    optimizer,
    scheduler,
    train_dataloader,
    val_dataloader,
    device,
    save_dir='./models/joint_final.pkl',
    valid_freq=2,
    fine_tune=True,
    batch_eval=200,
    batch_loud=False,
)

In [None]:
# Keep the fine-tuning histories under ft_* names before the baseline run rebinds them.
ft_tl, ft_vl, ft_tsisnri, ft_vsisnri = (
    train_losses, valid_losses, train_si_snris, valid_si_snris)

## Baseline Training

Training and evaluation loops for only the source separation network.

In [None]:
batch_size = 8

# Options shared by both splits; they differ only in cropping and IR determinism.
ds_common_kwargs = {'sep_net': True, 'ir_p': 1}
train_ds = SourceSepDS('./Data/train', rand_crop=True, ir_determ=False, **ds_common_kwargs)
val_ds = SourceSepDS('./Data/val', rand_crop=False, ir_determ=True, **ds_common_kwargs)

# Shuffle training batches only; validation is iterated in a fixed order.
train_dataloader = DataLoader(train_ds, batch_size, shuffle=True)
val_dataloader = DataLoader(val_ds, batch_size)

In [None]:
# Use the first CUDA GPU when available, otherwise the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# Baseline: the separator alone, with its default configuration.
bl_model = ConvTasNet().to(device)

In [None]:
# map_location restores the checkpoint onto the current `device`, so weights saved
# on a CUDA machine can still be loaded when only the CPU is available.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
load_model = False
if load_model:
    model_file_name = 'baseline_large_final'
    path_to_pkl = './models/{}.pkl'.format(model_file_name)
    bl_model.load_state_dict(torch.load(path_to_pkl, map_location=device))

In [None]:
# Adam with explicit betas/eps and no weight decay, over the baseline's parameters.
optimizer = torch.optim.Adam(
    bl_model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=0.,
)
# Halve the learning rate once the value passed to scheduler.step() stops
# decreasing for more than `patience` consecutive steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=.5, patience=1)
criterion = si_snr

In [None]:
epochs = 6
# Separator-only training loop; 6 epochs matches joint (4) + fine tuning (2).
train_losses, valid_losses, train_si_snris, valid_si_snris = train_ss(
    epochs,
    bl_model,
    criterion,
    optimizer,
    scheduler,
    train_dataloader,
    val_dataloader,
    device,
    save_dir='./models/baseline_large_final.pkl',
    valid_freq=2,
    batch_eval=500,
)

In [None]:
# Keep the baseline histories under bl_* names alongside the joint_* and ft_* ones.
bl_tl, bl_vl, bl_tsisnri, bl_vsisnri = (
    train_losses, valid_losses, train_si_snris, valid_si_snris)

# Test Set Performance

In [None]:
# Device and loss shared by every test-set evaluation below.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
criterion = si_snr

In [None]:
#load joint net (final, fine-tuned checkpoint)
# map_location restores the checkpoint onto the current `device`, so weights saved
# on a CUDA machine can still be loaded when only the CPU is available.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model = SourceSeparator(unet_params, sep_net=True).to(device)
load_model = True
if load_model:
    model_file_name = 'joint_final'
    path_to_pkl = './models/{}.pkl'.format(model_file_name)
    model.load_state_dict(torch.load(path_to_pkl, map_location=device))

In [None]:
# batch_size is otherwise only defined by the training sections; setting it here
# (to the same value) lets the test section run on a fresh kernel without them.
batch_size = 8
test_ds = SourceSepDS('./Data/test', sep_net=True, ir_p=1, rand_crop=False, ir_determ=True)
test_dataloader = DataLoader(test_ds, batch_size)

joint_loss_test, joint_si_snri_test, joint_si_snri_dn_test = evaluate(
    model, test_dataloader, criterion, device, batch_eval=1000, fine_tune=False)

In [None]:
# Test-set SI-SNR improvement of the full system and of its enhancer (denoiser) stage.
print(f'Joint Net si snri: {joint_si_snri_test}')
print(f'Denoiser si snri: {joint_si_snri_dn_test}')

In [None]:
#load baseline
bl_model = ConvTasNet().to(device)
load_model = True
if load_model:
    model_file_name = 'baseline_large_final'
    path_to_pkl = './models/{}.pkl'.format(model_file_name)
    bl_model.load_state_dict(torch.load(path_to_pkl))

In [None]:
# batch_size is otherwise only defined by the training sections; setting it here
# (to the same value) lets the test section run on a fresh kernel without them.
batch_size = 8
test_ds = SourceSepDS('./Data/test', sep_net=True, ir_p=1, rand_crop=False, ir_determ=True)
test_dataloader = DataLoader(test_ds, batch_size)

bl_loss_test, bl_si_snri_test = evaluate_ss(bl_model, test_dataloader, criterion, device)

In [None]:
# Test-set SI-SNR improvement of the separator-only baseline.
print(f'tasnet si snri: {bl_si_snri_test}')

Out of curiosity, this section evaluates the joint network before fine tuning. This model is not included in the upload, as the total submission would exceed the file size limit.

In [None]:
#load joint net checkpoint from *before* fine tuning
# The fine-tuning cell saves to ./models/joint_final.pkl, overwriting the stage-1
# weights. Loading 'joint_final' here would therefore evaluate the fine-tuned model
# a second time. The pre-fine-tuning weights must be kept under their own name.
# map_location restores the tensors onto the current `device`.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model = SourceSeparator(unet_params, sep_net=True).to(device)
load_model = True
if load_model:
    model_file_name = 'joint_preft'
    path_to_pkl = './models/{}.pkl'.format(model_file_name)
    try:
        state_dict = torch.load(path_to_pkl, map_location=device)
    except FileNotFoundError as err:
        raise FileNotFoundError(
            '{} not found: save a copy of the stage-1 joint checkpoint before running '
            'fine tuning, which overwrites ./models/joint_final.pkl'.format(path_to_pkl)
        ) from err
    model.load_state_dict(state_dict)

In [None]:
# batch_size is otherwise only defined by the training sections; setting it here
# (to the same value) lets this section run on a fresh kernel without them.
batch_size = 8
test_ds = SourceSepDS('./Data/test', sep_net=True, ir_p=1, rand_crop=False, ir_determ=True)
test_dataloader = DataLoader(test_ds, batch_size)

joint_preft_loss_test, joint_preft_si_snri_test, joint_preft_si_snri_dn_test = evaluate(
    model, test_dataloader, criterion, device, batch_eval=1000, fine_tune=False)

In [None]:
# Same report as for the final joint model, for the model evaluated above.
print(f'Joint Net si snri: {joint_preft_si_snri_test}')
print(f'Denoiser si snri: {joint_preft_si_snri_dn_test}')

# Case Study

In [None]:
from sourcesep.utils import evaluate_file

## Joint Model

In [None]:
#load join net 
model = SourceSeparator(unet_params, sep_net=True).to(device)
load_model = True
if load_model:
    model_file_name = 'joint_final'
    path_to_pkl = './models/{}.pkl'.format(model_file_name)
    model.load_state_dict(torch.load(path_to_pkl))

In [None]:
# Real recording, no synthetic reverb added (apply_ir=False); run through the full joint model.
# save_png=True presumably also writes the plots to disk — confirm in evaluate_file.
real_filepath = 'normal-charlie.wav'
evaluate_file(real_filepath, model, only_tas=False, apply_ir=False, save_png=True)

## TasNet Only

In [None]:
#load baseline
bl_model = ConvTasNet().to(device)
load_model = True
if load_model:
    model_file_name = 'baseline_large_final'
    path_to_pkl = './models/{}.pkl'.format(model_file_name)
    bl_model.load_state_dict(torch.load(path_to_pkl))

In [None]:
# Same real recording, separated by the TasNet baseline alone (only_tas=True).
real_filepath = 'normal-charlie.wav'
evaluate_file(real_filepath, bl_model, only_tas=True, apply_ir=False)