In [1]:
import os
import warnings

import librosa

from tqdm import tqdm

import numpy as np

import torch
from torch import nn, Tensor
import torch.nn.functional as F

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class MUSDBDataset(Dataset):
    """Dataset of (mixture, separated stems) waveform pairs cached as ``.npy`` files.

    On construction the audio files found directly in ``data_dir`` are decoded
    with librosa and cached under ``data_dir/data_numpy`` (only when that cache
    holds fewer than ``MIN_CACHED_FILES`` entries).

    Assumes cached files are named ``<track>.<stem>.<ext>.npy`` where stem ``0``
    is the mixture and every other stem index is a separated source, with
    ``STEMS_PER_TRACK`` sources per mixture -- TODO confirm against the data.

    Args:
        data_dir: Directory containing the source audio files.
        crop_size: Number of leading samples kept from every waveform.
    """

    # Below this many cached files the conversion step is (re)run.
    MIN_CACHED_FILES = 500
    # Number of separated-source files that follow each mixture in ``sep_list``.
    STEMS_PER_TRACK = 4

    def __init__(self, data_dir: str, crop_size: int = 284672):
        self.crop_size = crop_size
        self.data_dir = os.path.join(data_dir, 'data_numpy')
        os.makedirs(self.data_dir, exist_ok=True)
        if len(os.listdir(self.data_dir)) < self.MIN_CACHED_FILES:
            print("Data has not been saved as numpy object. Converting...")
            self.convert_to_numpy(data_dir, self.data_dir)
        self.music_fulllist = self.get_filenames(self.data_dir)
        self.music_list, self.sep_list = self.separate_source(self.music_fulllist)

    def __len__(self):
        """Return the number of mixtures."""
        return len(self.music_list)

    def __getitem__(self, idx):
        """Return ``(mixture, stems)`` for mixture number ``idx``.

        Returns:
            mixture: array of shape ``(1, n)`` and stems: array of shape
            ``(k, n)``, where ``n`` is at most ``crop_size`` and ``k`` is the
            number of stem files found for this index.
        """
        # A negative index would otherwise turn the stem slice into an empty one.
        if idx < 0:
            idx += len(self)

        mixture = np.load(self.music_list[idx])
        base_music = np.stack([mixture[:self.crop_size]])

        first = idx * self.STEMS_PER_TRACK
        stem_paths = self.sep_list[first:first + self.STEMS_PER_TRACK]
        sep_music = np.stack([np.load(path)[:self.crop_size] for path in stem_paths])
        return base_music, sep_music

    def get_filenames(self, path):
        """Return the sorted full paths of the entries in ``path``.

        The ``data_numpy`` cache directory itself is excluded. Sorting matters:
        ``os.listdir`` returns entries in arbitrary order, while ``__getitem__``
        relies on each mixture lining up with its consecutive stem files.
        """
        return [
            os.path.join(path, name)
            for name in sorted(os.listdir(path))
            if name != "data_numpy"
        ]

    def convert_to_numpy(self, music_dir, target_dir):
        """Decode every file in ``music_dir`` with librosa and save it into ``target_dir``.

        Uses librosa's default loading parameters; the returned sample rate is
        discarded. ``np.save`` appends ``.npy`` to the original file name.
        """
        # Silence decoder warnings only for the duration of the conversion
        # instead of disabling every warning for the rest of the session.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            for music in tqdm(self.get_filenames(music_dir)):
                arr, _ = librosa.load(music)
                np.save(os.path.join(target_dir, os.path.basename(music)), arr)

    def separate_source(self, mus_list):
        """Split paths into mixtures (stem ``0``) and separated sources.

        The stem index is the third dot-separated field from the end of the
        file name. Names with fewer than three fields are skipped, because a
        stray file would otherwise shift the mixture/stem alignment.

        Returns:
            ``(music_list, sep_list)`` in the same relative order as ``mus_list``.
        """
        music_list = []
        sep_list = []
        for music in mus_list:
            # Parse the file name only, so dots in directory names cannot interfere.
            fields = os.path.basename(music).split(".")
            if len(fields) < 3:
                continue
            if fields[-3] == '0':
                music_list.append(music)
            else:
                sep_list.append(music)

        return music_list, sep_list

In [3]:
# Training-data directory. Set the MUSDB_TRAIN_DIR environment variable to
# override the default so the notebook is not tied to one machine's layout.
MUSDB_TRAIN_DIR = os.environ.get('MUSDB_TRAIN_DIR', '/mnt/d/createdmusdb18/train')
ds = MUSDBDataset(MUSDB_TRAIN_DIR)

100%|██████████| 500/500 [00:00<00:00, 588261.43it/s]


In [4]:
# Batch the dataset four items at a time; all other loader options are defaults.
train_dataloader = DataLoader(dataset=ds, batch_size=4)

In [32]:

class DownSampling(nn.Module):
    """Encoder block: two Conv1d + LeakyReLU layers, optionally followed by 2x decimation.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels of both convolutions.
        kernel_size: Convolution kernel size. With an odd value the
            convolutions preserve the signal length.
        decimate: When True, keep every second sample along the last axis.
    """

    def __init__(self, in_ch=1, out_ch=24, kernel_size=15, decimate: bool = True):
        super().__init__()
        # NOTE(review): never used by forward(); kept only so existing
        # attribute access keeps working. Consider removing it.
        self.upsample = nn.Upsample(scale_factor=2, mode="linear", align_corners=True)
        # Derive the padding from the kernel size (7 for the default 15)
        # instead of hard-coding it, so other odd kernels keep the length too.
        pad = kernel_size // 2
        self.net = nn.Sequential(
            nn.Conv1d(in_ch, out_ch, kernel_size=kernel_size, padding=pad),
            nn.LeakyReLU(inplace=True),
            nn.Conv1d(out_ch, out_ch, kernel_size=kernel_size, padding=pad),
            nn.LeakyReLU(inplace=True),
        )
        self.decimate = decimate

    def forward(self, x: Tensor):
        """Apply the convolutions and, if enabled, drop every other sample.

        Accepts batched ``(N, C, L)`` or unbatched ``(C, L)`` input.
        """
        x = self.net(x)
        if self.decimate:
            # Index the last axis explicitly: `x[:, ::2]` would decimate the
            # channel axis whenever a batch dimension is present.
            x = x[..., ::2]
        return x

In [None]:
class UpSampling(nn.Module):
    """Decoder block: 2x linear upsampling followed by two Conv1d + LeakyReLU layers.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels of both convolutions.
        kernel_size: Convolution kernel size. With an odd value the
            convolutions preserve the signal length.

    NOTE(review): the original class defined no ``forward``, so it could not be
    called. The upsampling step here mirrors the ``nn.Upsample`` configuration
    that appears in ``DownSampling`` -- confirm this is the intended design.
    Concatenating skip connections is left to the caller.
    """

    def __init__(self, in_ch, out_ch, kernel_size):
        super().__init__()
        # Padding follows the kernel size (7 for a kernel of 15).
        pad = kernel_size // 2
        self.net = nn.Sequential(
            nn.Conv1d(in_ch, out_ch, kernel_size=kernel_size, padding=pad),
            nn.LeakyReLU(inplace=True),
            nn.Conv1d(out_ch, out_ch, kernel_size=kernel_size, padding=pad),
            nn.LeakyReLU(inplace=True),
        )

    def forward(self, x: Tensor):
        """Double the length of ``x`` along the last axis, then apply the convolutions.

        Accepts batched ``(N, C, L)`` or unbatched ``(C, L)`` input and returns
        a tensor with the same number of dimensions.
        """
        # Linear interpolation needs a 3-D input, so add a batch axis temporarily.
        unbatched = x.dim() == 2
        if unbatched:
            x = x.unsqueeze(0)
        x = F.interpolate(x, scale_factor=2, mode="linear", align_corners=True)
        x = self.net(x)
        return x.squeeze(0) if unbatched else x

In [80]:
class WaveUNet(nn.Module):
    """Encoder stack of ``n_level`` decimating blocks plus one non-decimating block.

    Channel counts grow by 24 per block: 1 -> 24 -> 48 -> ... -> 24 * (n_level + 1).

    Args:
        n_level: Number of decimating ``DownSampling`` blocks.

    Attributes:
        layer_to_concat: Outputs of every block from the most recent forward pass.
    """

    def __init__(self, n_level=12):
        super().__init__()
        self.level = n_level

        self.layer_to_concat = []

        layers = [DownSampling(in_ch=1, out_ch=24, kernel_size=15)]

        for i in range(self.level - 1):
            layers.append(DownSampling(in_ch=24 * (i + 1), out_ch=24 * (i + 2), kernel_size=15))

        # Final block keeps the temporal resolution (no decimation).
        layers.append(DownSampling(in_ch=24 * self.level, out_ch=24 * (self.level + 1),
                                   kernel_size=15, decimate=False))

        self.net = nn.ModuleList(layers)

    def forward(self, x: Tensor):
        """Run ``x`` through every block and return the last block's output."""
        # Start from a fresh list on every call; appending to the old one kept
        # the activations of all previous forward passes alive.
        self.layer_to_concat = []
        for layer in self.net:
            x = layer(x)
            self.layer_to_concat.append(x)

        return x

In [81]:
# Random dummy input: 1 x 16384 values drawn uniformly from [0, 1).
test = torch.rand(1, 16384)

In [82]:
# Instantiate the 12-level network and feed it the random dummy input.
model = WaveUNet(n_level=12)

# Alternative input, kept for reference: the first mixture of the dataset.
# out = model(torch.Tensor(ds[0][0]))
out = model(test)

1
2
3
4
5
6
7
8
9
10
11
12
13


In [64]:
# Inspect the shape of the network output.
out.shape

torch.Size([312, 4])

In [83]:
# Transposed convolution that halves the channel count with stride 2.
# The channel count is read from `out` (second-to-last axis) rather than
# hard-coded as 312, so it stays correct if the network depth changes.
n_channels = out.shape[-2]
upsample = nn.ConvTranspose1d(in_channels=n_channels, out_channels=n_channels // 2, kernel_size=2, stride=2)

In [84]:
# Re-check the network output shape before upsampling experiments.
out.shape

torch.Size([312, 4])

In [73]:
# Add a leading batch axis, then double the last axis by linear interpolation.
lin = F.interpolate(out.unsqueeze(0), scale_factor=2, mode='linear')

In [74]:
# Shape after interpolation (includes the added leading axis).
lin.shape

torch.Size([1, 312, 8])

In [69]:
# Apply the transposed convolution to the network output.
out2 = upsample(out)

In [70]:
# Shape after the transposed convolution.
out2.shape

torch.Size([156, 8])

In [25]:
# Shape of the first mixture returned by the dataset.
ds[0][0].shape

(1, 284672)

In [9]:
# NOTE(review): this cell carries execution count 9 and its recorded output
# differs from the other `out.shape` cells, so it appears to come from an
# earlier run -- re-run the notebook top to bottom to refresh it.
out.shape

torch.Size([24, 142322])