In [None]:
!wget -c https://zenodo.org/records/3338373/files/musdb18hq.zip

In [None]:
!unzip musdb18hq.zip -d musdb18hq

In [1]:
import os
import torch
import torchaudio
from glob import glob
from datasets import Dataset, Features, Audio, Value
from IPython.display import Audio
from einops import rearrange

In [2]:
# Segment length in samples: 2**21 samples ≈ 47.6 s at 44.1 kHz.
N = 2**21


def list_tracks(split_dir="musdb18hq/train"):
    """Return the sorted names of the track folders under ``split_dir``.

    A folder counts as a track when it contains a ``mixture.wav``. Globbing
    that single file yields exactly one hit per track, so the result does not
    depend on how many stem files each track has or on the (filesystem-
    dependent) order in which glob lists them.
    """
    return sorted(
        os.path.basename(os.path.dirname(path))
        for path in glob(os.path.join(split_dir, "*", "mixture.wav"))
    )


tracks = list_tracks()

In [20]:
# Cut every track into fixed-length segments of N samples and write the segments
# that contain enough vocal energy to musdb_vss/<track>/{mixture,vocals}/<i>.wav.
MIN_VOCAL_LEVEL = 200  # threshold on the per-segment mean vocal magnitude, in raw int16 units

for track in tracks:
    # Load the mixture and the isolated vocal stem as raw integer PCM (no float normalization).
    mix_path = f"musdb18hq/train/{track}/mixture.wav"
    vocal_path = f"musdb18hq/train/{track}/vocals.wav"
    x, fs_mix = torchaudio.load(mix_path, normalize=False)
    v, fs = torchaudio.load(vocal_path, normalize=False)
    C, L = v.shape
    assert x.shape == v.shape, (track, x.shape, v.shape)
    # Check BOTH files' sample rates; previously the mixture's rate was overwritten unchecked.
    assert fs_mix == fs == 44100, (track, fs_mix, fs)
    assert C == 2, (track, C)
    assert x.dtype == v.dtype == torch.int16, (track, x.dtype, v.dtype)

    # A track no longer than half a segment would be mostly padding: skip it.
    if L <= N // 2:
        continue

    if (L % N) / N > 0.5:
        # The trailing partial segment is more than half full: zero-pad it to N samples.
        B = L // N + 1
        pad_length = B * N - L
        x = torch.nn.functional.pad(x, (0, pad_length))
        v = torch.nn.functional.pad(v, (0, pad_length))
    else:
        # The trailing partial segment is at most half full: drop it.
        B = L // N
        x = x[:, :B * N]
        v = v[:, :B * N]

    # Split into B non-overlapping segments of N samples each: (C, B*N) -> (B, C, N).
    x = rearrange(x, 'C (B N) -> B C N', B=B, N=N)
    v = rearrange(v, 'C (B N) -> B C N', B=B, N=N)

    # Vocal level per segment: L2 norm across channels at each sample, averaged over time.
    # Drop segments whose vocals are too quiet.
    p = v.to(torch.float).norm(dim=1).mean(dim=1)
    keep = p > MIN_VOCAL_LEVEL
    x = x[keep]
    v = v[keep]
    B = x.shape[0]
    if B == 0:
        continue  # nothing to write; don't create empty output folders

    # Save each kept (mixture, vocals) pair; segments are renumbered 0..B-1 after filtering.
    mix_dir = f"musdb_vss/{track}/mixture"
    vocal_dir = f"musdb_vss/{track}/vocals"
    os.makedirs(mix_dir, exist_ok=True)
    os.makedirs(vocal_dir, exist_ok=True)
    for i_seg in range(B):
        torchaudio.save(
            uri=f"{mix_dir}/{i_seg}.wav",
            src=x[i_seg],
            sample_rate=fs,
        )
        torchaudio.save(
            uri=f"{vocal_dir}/{i_seg}.wav",
            src=v[i_seg],
            sample_rate=fs,
        )