In [1]:
from IPython.display import Audio
import IPython

import soundfile as sf
import torch
import torchaudio
import librosa

from pathlib import Path
import museval
import numpy as np
import pandas as pd
from mir_eval.separation import bss_eval_sources
import scipy.signal
import tqdm
import warnings
warnings.filterwarnings("ignore")


N_PART = 4
N_FFT = 2047
SAMPLING_RATE = 22050

from u_net import UNet, padding

input_wav = '/userHome/userhome2/dahyun/voice/Singing_Voice_Synthesis/data/convertMUSDB/test/Al James - Schoolboy Facination.stem.wav'
output_dir = '/userHome/userhome2/dahyun/voice/Singing_Voice_Synthesis/U-net/new_unet/separates/source'
model_path = '/userHome/userhome2/dahyun/voice/Singing_Voice_Synthesis/U-net/new_unet/outputs/model-475.pth'

cuda_check = torch.cuda.is_available()
device = torch.device(f'cuda:1' if cuda_check else 'cpu')

print(device)

def median_nan(a):
    return np.median(a[~np.isnan(a)])

cuda:1


In [2]:
def separate(input_wav):
    with torch.no_grad():
        sound, sr = torchaudio.load(input_wav)
        if sr > SAMPLING_RATE:
            sound = torchaudio.functional.resample(sound, orig_freq=44100, new_freq=SAMPLING_RATE) # sampling rate fixing
        sound = sound[[0], :].to(device)

        window = torch.hann_window(N_FFT, device=device)

        # Convert it to power spectrogram, and pad it to make the number of
        # time frames to a multiple of 64 to be fed into U-NET
        sound_stft = torch.stft(sound, N_FFT, window=window, return_complex=False)
        sound_spec = sound_stft.pow(2).sum(-1).sqrt()
        sound_spec, (left, right) = padding(sound_spec)

        # Load the model
        model = UNet(N_PART)
        model.load_state_dict(torch.load(model_path))
        model.to(device)
        model.eval()

        right = sound_spec.size(2) - right
        mask = model(sound_spec).squeeze(0)[:, :, left:right]
        separated = mask.unsqueeze(3) * sound_stft
        
        #print(separated.type(torch.complex64)[:,:,:,0].shape)
        #separated = librosa.istft(separated.cpu().numpy(), n_fft=N_FFT, window='hann', length=sound.size(-1))
        # istft requires complex tensor // forced dtype transform
        separated = torch.istft(separated.type(torch.complex64)[:,:,:,0], N_FFT, window=window, length=sound.size(-1))
        separated = separated.cpu().numpy()
    
    return separated

In [9]:
tracks = []
result = pd.DataFrame(columns=['track','SDR','ISR', 'SIR', 'SAR'])

p = Path('/userHome/userhome2/dahyun/voice/Singing_Voice_Synthesis/data/MUSDB', 'test')
reference_dir = Path('/userHome/userhome2/dahyun/voice/Singing_Voice_Synthesis/data/MUSDB', 'test')
for track_path in tqdm.tqdm(p.iterdir(), disable=True):
    tracks.append(track_path)

for track in tqdm.tqdm(tracks):
    # seaparate
    input_file = str(Path(track, 'mixture.wav'))
    separated = separate(input_file)

    output_path = Path('/userHome/userhome2/dahyun/voice/Singing_Voice_Synthesis/U-net/new_unet/separates', 'estimates', Path(input_file).parent.name)
    output_path.mkdir(exist_ok=True, parents=True)

    # save to wav
    for i in range(4):
        sf.write(str(output_path) + '/source' + str(i)+ '.wav', separated.T[:,i], SAMPLING_RATE)

    # evaluation
    estdir = output_path
    refdir = Path(reference_dir, estdir.name)
    if refdir.exists():
        ref, sr = sf.read(str(Path(refdir, 'vocals' + '.wav')), always_2d=True)
        est, sr = sf.read(str(Path(estdir, 'source3' + '.wav')), always_2d=True)

        ref = ref[:,0][None, ...]
        est = est[None, ...]

        SDR, ISR, SIR, SAR = museval.evaluate(ref, est, win=sr, hop=sr)
        values = {
                'track':estdir.name,
                "SDR": median_nan(SDR[0]),
                "ISR": median_nan(ISR[0]),
                "SIR": median_nan(SIR[0]),
                "SAR": median_nan(SAR[0])
            }
        result.loc[result.shape[0]] = values

values = {
        'track':'sum',
        "SDR": result['SDR'].median(),
        "ISR": result['ISR'].median(),
        "SIR": result['SIR'].median(),
        "SAR": result['SAR'].median()
}
result.loc[result.shape[0]] = values
print(list((result.loc[result.shape[0] - 1])[1:]))
result.to_csv(str(output_dir)+'.csv',index=0)

100%|██████████| 50/50 [06:25<00:00,  7.70s/it]

[0.3391093366341078, 1.2319471705340366, inf, -0.6625178621691593]



