In [1]:
import os
import warnings

import librosa

from tqdm import tqdm

import numpy as np

import torch
from torch import nn, Tensor
import torch.nn.functional as F

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from torchsummary import summary
from torchviz import make_dot

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class MUSDBDataset(Dataset):
    def __init__(self, data_dir: str):
        self.crop_size = 284672
        self.data_dir = os.path.join(data_dir, 'data_numpy')
        if not os.path.exists(self.data_dir) or \
            len([name for name in os.listdir(self.data_dir)]) < 500:
            print("Data has not been saved as numpy object. Converting...")
            if not os.path.exists(self.data_dir):
                os.makedirs(self.data_dir)
            self.convert_to_numpy(data_dir, self.data_dir)
        self.music_fulllist = self.get_filenames(self.data_dir)
        self.music_list, self.sep_list = self.separate_source(self.music_fulllist)

    def __len__(self):
        return len(self.music_list)

    def __getitem__(self, idx):
        base_music = self.music_list[idx]
        base_music = np.load(base_music)
        base_music = np.stack([base_music[:self.crop_size]])

        sep_music = self.sep_list[idx*4 : idx*4+4]
        sep_music = np.stack([np.load(idx)[:self.crop_size] for idx in sep_music])
        return base_music, sep_music

    def get_filenames(self, path):
        files_list = list()
        for filename in os.listdir(path):
            if not filename == "data_numpy":
                files_list.append(os.path.join(path, filename))
        return files_list

    def convert_to_numpy(self, music_dir, target_dir):
        warnings.filterwarnings('ignore')
        music_list = self.get_filenames(music_dir)
        for music in tqdm(music_list):
            outfile_name = music.split("/")[-1]
            outfile_name = target_dir + "/" + outfile_name
            arr, _ = librosa.load(music)
            np.save(outfile_name, arr)

    def separate_source(self, mus_list):
        warnings.filterwarnings('ignore')
        music_list = list()
        sep_list = list()
        for music in tqdm(mus_list):
            mus_type = music.split(".")[-3]
            if mus_type == '0':
                music_list.append(music)
            else:
                sep_list.append(music)

        return music_list, sep_list

In [3]:
ds = MUSDBDataset('/mnt/d/createdmusdb18/train')

100%|██████████| 500/500 [00:00<00:00, 1364445.02it/s]


In [4]:
train_dataloader = DataLoader(ds, batch_size=4)

In [10]:

class DownSampling(nn.Module):
    def __init__(self, in_ch=1, out_ch=24, kernel_size=15):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(in_ch, out_ch, kernel_size=kernel_size, padding=7),
            nn.LeakyReLU(inplace=True),
            nn.Conv1d(out_ch, out_ch, kernel_size=kernel_size, padding=7),
            nn.LeakyReLU(inplace=True),
        )
        
    def forward(self, x: Tensor):
        x = self.net(x)
        return x

In [11]:
class UpSampling(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode="linear", align_corners=True)
        self.conv = nn.Sequential(
            nn.Conv1d(in_ch, out_ch, kernel_size=kernel_size, padding=2),
            nn.LeakyReLU(inplace=True),
            nn.Conv1d(out_ch, out_ch, kernel_size=kernel_size, padding=2),
            nn.LeakyReLU(inplace=True),
        )
        
    def forward(self, x, x_back):
        x = self.upsample(x);
        x = torch.cat([x, x_back], axis=1)
        return self.conv(x)

In [98]:
class WaveUNet(nn.Module):
    def __init__(self, n_level=12, n_source=4):
        super().__init__()
        self.level = n_level
        
        self.layer_to_concat = []
        
        layers=[DownSampling(in_ch=1,out_ch=24,kernel_size=15)]
        
        for i in range(self.level-1):
            layers.append(DownSampling(in_ch=24*(i+1),out_ch=24*(i+2),kernel_size=15))
            
        # layers.append(DownSampling(in_ch=24*(self.level), out_ch=24*(self.level+1), kernel_size=15, decimate=False))
        layers.append(DownSampling(in_ch=24*(self.level), out_ch=24*(self.level+1), kernel_size=15))
            
        for i in range(self.level):
            layers.append(UpSampling(in_ch=24*(self.level+1-i) + 24*(self.level - i), out_ch=24*(self.level-i), kernel_size=5))
            
        self.net = nn.ModuleList(layers)
        self.separation = nn.Sequential(
            nn.Conv1d(25, n_source, kernel_size=1),
            nn.LeakyReLU(inplace=True),
            nn.Conv1d(n_source, n_source, kernel_size=1),
            nn.LeakyReLU(inplace=True),
        )
    
    def forward(self, x: Tensor):
        print("before in ", x.shape)
        self.layer_to_concat.append(x)
        for layer in self.net[0: self.level]:
            x = layer(x)
            print("conv ", x.shape)
            self.layer_to_concat.append(x);
            x = x[:, :, 1::2]
            print("decimate ", x.shape)
        x = self.net[self.level](x)
        print("middle out ", x.shape)
        self.layer_to_concat.append(x); 
        for i, layer in enumerate(self.net[self.level+1:]):
            print("before up ", x.shape)
            x = self.layer_to_concat[-1];
            x = layer(x, self.layer_to_concat[-1-i-1])
            print("after up ", x.shape)
            self.layer_to_concat[-1] = x
            
        x = torch.cat([self.layer_to_concat[0], x], axis=1)
        x = self.separation(x)
        
        return x

In [101]:
test = torch.rand((1, 1, 16384))

In [102]:
model = WaveUNet(n_level=12, n_source=4)

# out = model(torch.Tensor(ds[0][0]))
out = model(test)

before in  torch.Size([1, 1, 16384])
conv  torch.Size([1, 24, 16384])
decimate  torch.Size([1, 24, 8192])
conv  torch.Size([1, 48, 8192])
decimate  torch.Size([1, 48, 4096])
conv  torch.Size([1, 72, 4096])
decimate  torch.Size([1, 72, 2048])
conv  torch.Size([1, 96, 2048])
decimate  torch.Size([1, 96, 1024])
conv  torch.Size([1, 120, 1024])
decimate  torch.Size([1, 120, 512])
conv  torch.Size([1, 144, 512])
decimate  torch.Size([1, 144, 256])
conv  torch.Size([1, 168, 256])
decimate  torch.Size([1, 168, 128])
conv  torch.Size([1, 192, 128])
decimate  torch.Size([1, 192, 64])
conv  torch.Size([1, 216, 64])
decimate  torch.Size([1, 216, 32])
conv  torch.Size([1, 240, 32])
decimate  torch.Size([1, 240, 16])
conv  torch.Size([1, 264, 16])
decimate  torch.Size([1, 264, 8])
conv  torch.Size([1, 288, 8])
decimate  torch.Size([1, 288, 4])
middle out  torch.Size([1, 312, 4])
before up  torch.Size([1, 312, 4])
after up  torch.Size([1, 288, 8])
before up  torch.Size([1, 288, 8])
after up  torch.S

In [103]:
out.shape

torch.Size([1, 4, 16384])