In [1]:
# Python
import os
import csv
# Third party
import torch
import torch.nn as nn
import torch.optim as optim

import torchaudio
from torchaudio.transforms import Fade

import nussl
from nussl.ml import SeparationModel

from mir_eval import separation

# Self
from models.UNet import UNetSpect

SoX could not be found!

    If you do not have SoX, proceed here:
     - - - http://sox.sourceforge.net/ - - -

    If you do (or think that you should) have SoX, double-check your
    path variables.
    


In [120]:
def get_source_dirs(musdb18_directory):
    """Takes in a MUSDB18 directory and returns a list of strings of all the folders in the directory.
    If MUSDB18 is broken down into subfolders (e.g. test, train, and validation), the list of directories will go one layer deep.
    Useful for looping MUSDB18/<song_name>/mixture.wav or MUSDB18/test/<song_name>/mixture.wav.

    Args:
        directory (str): path to MUSDB18 dataset.

    Returns:
        list_of_dirs (list): A list of strings containing the relative path of each song in MUSDB18.
    """
    list_of_dirs = []
    if 'train' in os.listdir(musdb18_directory):
        for sub_folder in os.listdir(musdb18_directory):
            for song_name in os.listdir(musdb18_directory + sub_folder):
                if song_name not in ['bass','drums','vocals','other']:
                    list_of_dirs.append(musdb18_directory + sub_folder + '/' + song_name +'/')
    else:
        for song_name in os.listdir(musdb18_directory):
            list_of_dirs.append(musdb18_directory + song_name +'/')

    return list_of_dirs

In [121]:
def same_device(waveform):
    '''Moves waveform to cpu for bss_eval_sources.  Both waveforms need to be on the same device.
    Args:
        waveform (torch.Tensor): A torch.Tensor waveform.

    Returns:
        waveform (torch.Tensor): A torch.Tensor waveform.
    '''
    if waveform.get_device() == 0:
        waveform = waveform.cpu()
    return waveform

In [122]:
def bss_eval(reference_source, estimated_source):
    """Computes SDR, SIR, and SAR for a reference source and an estimated source.

    Args:
        reference_source (np.ndarray): A np.ndarray of the reference source.
        estimated_source (np.ndarray):  A np.ndarray of the estimated source.

    Returns:
        (sdr,sir,sar) (tuple): A tuple containing sdr, sir, and sar.
    """
    _eval = separation.bss_eval_sources(reference_source, estimated_source)
    sdr = _eval[0].mean()
    sir = _eval[1].mean()
    sar = _eval[2].mean()
    return (sdr, sir, sar)

In [144]:
def separate_sources(
        model,
        mix,
        segment=10.,
        overlap=0.1,
        device=None,
):
    """
    Apply model to a given mixture. Use fade, and add segments together in order to add model segment by segment.

    Args:
        segment (int): segment length in seconds
        device (torch.device, str, or None): if provided, device on which to
            execute the computation, otherwise `mix.device` is assumed.
            When `device` is different from `mix.device`, only local computations will
            be on `device`, while the entire tracks will be stored on `mix.device`.
    """
    if device is None:
        device = mix.device
    else:
        device = torch.device(device)

    batch, channels, length = mix.shape

    chunk_len = int(sample_rate * segment * (1 + overlap))
    start = 0
    end = chunk_len
    overlap_frames = overlap * sample_rate
    fade = Fade(fade_in_len=0, fade_out_len=int(overlap_frames), fade_shape='linear')

    final = torch.zeros(batch, len(model.sources), channels, length, device=device)

    while start < length - overlap_frames:
        chunk = mix[:, :, start:end]
        with torch.no_grad():
            out = model.forward(chunk)
        out = fade(out)
        final[:, :, :, start:end] += out
        if start == 0:
            fade.fade_in_len = int(overlap_frames)
            start += int(chunk_len - overlap_frames)
        else:
            start += chunk_len
        end += chunk_len
        if end >= length:
            fade.fade_out_len = 0
    return final

In [175]:
def evaluate_model(model,list_of_mixtures,segment=10.,overlap=0.1,device=None,):
    model_name = model.__class__.__name__
    print(model_name)

    
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)
    
    model = model.to(device)
    print(device)
    for idx, _dir in enumerate(list_of_mixtures):
        print(idx, _dir)
        reference_mixture, reference_sample_rate = torchaudio.load(_dir + 'mixture.wav')
        reference_mixture = reference_mixture.to(device)
        
        reference_mixture_mean = reference_mixture.mean(0)
        reference_mixture = (reference_mixture - reference_mixture_mean.mean()) / reference_mixture_mean.std()  # normalization
        
        sources = separate_sources(
            model,
            reference_mixture[None],
            device=device,
            segment=segment,
            overlap=overlap,
            )[0]
        sources = sources * reference_mixture_mean.std() + reference_mixture_mean.mean()
        sources_list = model.sources
        sources = list(sources)
        audios = dict(zip(sources_list, sources))

        for source in sources_list:
            print(source)

            original_mixture, original_sample_rate = torchaudio.load(_dir + source + '.wav')
            original_mixture = original_mixture.to(device)
            sdr, sir, sar = bss_eval(same_device(audios[source]),same_device(original_mixture))
            print(sdr, sir, sar)



    
    return None


In [176]:
model.sources

['drums', 'bass', 'other', 'vocals']

In [170]:
bundle = HDEMUCS_HIGH_MUSDB_PLUS
model = bundle.get_model()
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# model.to(device)
# print('')

In [177]:
list_of_dirs = get_source_dirs('data/MUSDB18HQ/')
evaluate_model(model,list_of_dirs[0:1],segment=10.,overlap=0.1,device='cpu')

HDemucs
cpu
0 data/MUSDB18HQ/test/Al James - Schoolboy Facination/
drums


TypeError: sum() received an invalid combination of arguments - got (axis=tuple, out=NoneType, ), but expected one of:
 * (*, torch.dtype dtype)
      didn't match because some of the keywords were incorrect: axis, out
 * (tuple of ints dim, bool keepdim, *, torch.dtype dtype)
 * (tuple of names dim, bool keepdim, *, torch.dtype dtype)


In [None]:
# mixture = 'data/MUSDB18HQ/train/A Classic Education - NightOwl/mixture.wav'
# waveform, sample_rate = torchaudio.load(mixture)  # replace SAMPLE_SONG with desired path for different song
# waveform = waveform.to(device)

# ref = waveform.mean(0)
# waveform = (waveform - ref.mean()) / ref.std()  # normalization
# model(waveform)

In [None]:
# model = UNetSpect().build()
# checkpoint = torch.load("checkpoints/best.model.pth")
# model = SeparationModel(checkpoint["config"])
# model.load_state_dict(checkpoint["state_dict"])
# model(audio_signal)

In [89]:
# mixture = 'data/MUSDB18HQ/train/A Classic Education - NightOwl\mixture.wav'
# waveform, sample_rate = torchaudio.load(mixture)  # replace SAMPLE_SONG with desired path for different song
# waveform = waveform.to(device)

# ref = waveform.mean(0)
# waveform = (waveform - ref.mean()) / ref.std()  # normalization


# separate_sources(model,
#         waveform[None],
#         segment=10,
#         overlap=0.1,
#         device=None,
#         sample_rate=sample_rate
# )

# sources = sources * ref.std() + ref.mean()

# sources_list = model.sources
# sources = list(sources)

# audios = dict(zip(sources_list, sources))