MUSDB18
==============

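This chapter shows how to build a `torch.utils.data.Dataset` from MUSDB18 for training a music demixing (source separation) model. Each item is a random fixed-length excerpt of a track's mixture together with the matching excerpts of its isolated sources.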

In [3]:
import random
import numpy as np
from torch.utils.data import Dataset


class RandomMUSDB18TrainDataset(Dataset):
    """Dataset yielding one randomly positioned, fixed-length excerpt per track.

    Item ``i`` is a ``(mixture, targets)`` pair cut from the same random window
    of track ``i``: ``mixture`` is the track's mixed audio and ``targets`` maps
    every source name in ``track.sources`` to that source's audio. All arrays
    are ``float32`` with time along axis 1 (the raw track audio is transposed).

    Args:
        musdb: Indexable collection of tracks (e.g. a ``musdb.DB``). Each track
            must expose ``.audio`` and ``.sources[name].audio``.
        sample_size: Excerpt length in samples.
    """

    def __init__(self, musdb, sample_size):
        self.musdb = musdb
        self.sample_size = sample_size

    def __len__(self) -> int:
        """Number of tracks; each track contributes one excerpt per epoch."""
        return len(self.musdb)

    def get_track(self, i, target=None):
        """Return the full audio of track ``i`` as a float32 array.

        Args:
            i: Track index into ``self.musdb``.
            target: Source name (a key of ``track.sources``), or ``None`` for
                the mixture.

        Returns:
            The transposed audio array, so time runs along axis 1.
        """
        track = self.musdb[i]
        audio = track.audio if target is None else track.sources[target].audio
        return audio.T.astype(np.float32)

    def __getitem__(self, i):
        """Return a random ``sample_size`` excerpt of track ``i`` and its sources.

        The same window is applied to the mixture and to every source, so the
        targets stay time-aligned with the mixture.

        Raises:
            ValueError: If the track is shorter than ``sample_size`` samples.
        """
        mixture = self.get_track(i)
        length = mixture.shape[1]
        if length < self.sample_size:
            raise ValueError(
                f"track {i} has {length} samples, fewer than "
                f"sample_size={self.sample_size}"
            )

        # random.randint includes both endpoints; length - sample_size is the
        # last start whose window still fits entirely inside the track.
        start = random.randint(0, length - self.sample_size)
        window = slice(start, start + self.sample_size)

        mixture = mixture[:, window]
        targets = {source: self.get_track(i, source)[:, window]
                   for source in self.musdb[i].sources.keys()}

        return mixture, targets

## Loading the training split and listening to an excerpt

In [7]:
import musdb

SAMPLE_RATE = 44100    # audio sample rate in Hz
EXCERPT_SECONDS = 3    # length of each random training excerpt, in seconds

# NOTE(review): per the musdb documentation, download=True fetches the short
# preview excerpts of MUSDB18 rather than the full dataset — confirm intended.
mus_train = musdb.DB(download=True, subsets='train', split='train')
train_dataset = RandomMUSDB18TrainDataset(
    mus_train, sample_size=SAMPLE_RATE * EXCERPT_SECONDS
)

In [11]:
from IPython.display import Audio, display

# Play one random excerpt: first the mixture, then each isolated source,
# each preceded by its name.
print('mixture')
mixture, source_dict = train_dataset[0]
display(Audio(mixture, rate=44100))

for source, clip in source_dict.items():
    print(source)
    display(Audio(clip, rate=44100))

mixture


vocals


drums


bass


other
