[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/likelian/source-separation/blob/main/Wave-U-Net-Pytorch/predict.ipynb)

In [None]:
# Print the versions of the Python interpreter and the CUDA compiler (nvcc)
# found on this runtime's PATH, as a quick environment sanity check.
!python -V
!nvcc -V

In [None]:
!git clone https://github.com/likelian/source-separation.git

import os
os.chdir("source-separation/Wave-U-Net-Pytorch")

In [None]:
# !git pull
!python -m pip install -r requirements.txt

In [12]:
# Mount Google Drive at /content/drive (Colab only); the checkpoint path used
# in the final cell of this notebook lives under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [8]:
import argparse
import os

import data.utils
import model.utils as model_utils

from test import predict_song
from model.waveunet import Waveunet

In [9]:
class Args:
    """Attribute container holding the settings that ``main`` reads.

    Every option is stored as a plain instance attribute. The defaults are
    assigned in ``__init__``; the ``set_*`` helpers overwrite the runtime
    options (device flag, checkpoint path, input path, output folder).
    """

    def __init__(self):
        """Populate the instance with the default option values."""
        # Built per instance so mutable defaults (the instrument list) are
        # never shared between two Args objects.
        defaults = {
            "instruments": ["bass", "drums", "other", "vocals"],
            "cuda": False,
            "features": 32,
            "load_model": "checkpoints/waveunet/model",
            "batch_size": 4,
            "levels": 6,
            "depth": 1,
            "sr": 44100,
            "channels": 2,
            "kernel_size": 5,
            "output_size": 2.0,
            "strides": 4,
            "conv_type": "gn",
            "res": "fixed",
            "separate": 1,
            "feature_growth": "double",
            "input": os.path.join("audio_examples", "Cristina Vane - So Easy", "mix.mp3"),
            "output": None,
        }
        for option, value in defaults.items():
            setattr(self, option, value)

    def set_cuda(self, cuda_cond):
        """Store ``cuda_cond`` as the ``cuda`` flag."""
        self.cuda = cuda_cond

    def set_model_path(self, model_path):
        """Store ``model_path`` as the checkpoint location (``load_model``)."""
        self.load_model = model_path

    def set_input_path(self, input_path):
        """Store ``input_path`` as the audio file to process (``input``)."""
        self.input = input_path

    def set_output_path(self, output_path):
        """Store ``output_path`` as the destination folder (``output``)."""
        self.output = output_path


In [10]:
def main(args):
    """Separate one audio file with a Wave-U-Net checkpoint and save the results.

    Steps: build a ``Waveunet`` from the settings in ``args``, optionally move
    it to the GPU, load the checkpoint at ``args.load_model``, run
    ``predict_song`` on ``args.input`` and write one wav file per predicted
    entry.

    Args:
        args: Object exposing the attributes read below (``features``,
            ``levels``, ``feature_growth``, ``output_size``, ``sr``,
            ``channels``, ``instruments``, ``kernel_size``, ``depth``,
            ``strides``, ``conv_type``, ``res``, ``separate``, ``cuda``,
            ``load_model``, ``input``, ``output``). It is also handed to
            ``predict_song`` unchanged, which may read further attributes.

    Output files are named ``<input basename>_<key>.wav`` (the input's own
    extension is kept, e.g. ``mix.mp3_vocals.wav``). They are written to
    ``args.output``, or next to the input file when ``args.output`` is None.
    The destination folder is created if it does not exist yet.
    """
    # MODEL
    # Channel count per level: linear growth for "add", doubling otherwise.
    if args.feature_growth == "add":
        num_features = [args.features * i for i in range(1, args.levels + 1)]
    else:
        num_features = [args.features * 2 ** i for i in range(0, args.levels)]
    # Requested output length: output_size multiplied by the sample rate.
    target_outputs = int(args.output_size * args.sr)
    model = Waveunet(args.channels, num_features, args.channels, args.instruments, kernel_size=args.kernel_size,
                     target_output_size=target_outputs, depth=args.depth, strides=args.strides,
                     conv_type=args.conv_type, res=args.res, separate=args.separate)

    if args.cuda:
        model = model_utils.DataParallel(model)
        print("move model to gpu")
        model.cuda()

    print("Loading model from checkpoint " + str(args.load_model))
    state = model_utils.load_model(model, None, args.load_model, args.cuda)
    print('Step', state['step'])

    # Resolve and create the destination before the (slow) prediction, so a
    # bad output location fails early instead of after the work is done.
    # An empty dirname means "current directory" and needs no creation.
    output_folder = os.path.dirname(args.input) if args.output is None else args.output
    if output_folder:
        os.makedirs(output_folder, exist_ok=True)

    preds = predict_song(args, args.input, model)

    # preds is indexed by key below; presumably one entry per instrument.
    song_name = os.path.basename(args.input)
    for inst in preds.keys():
        target_path = os.path.join(output_folder, song_name + "_" + inst + ".wav")
        data.utils.write_wav(target_path, preds[inst], args.sr)


In [14]:
if __name__ == '__main__':
    # Checkpoint stored on the mounted Google Drive.
    model_path = "/content/drive/MyDrive/musdb18hq/pretrained_model_waveunet/model"

    # Start from the defaults and override only what this run needs.
    args = Args()
    args.set_model_path(model_path)
    args.set_input_path("./audio_examples/Cristina Vane - So Easy/mix.mp3")

    # Optional overrides, disabled for this run:
    # args.set_cuda(True)
    # args.set_output_path("./output")

    main(args)


Using valid convolutions with 97961 inputs and 88409 outputs
Loading model from checkpoint /content/drive/MyDrive/musdb18hq/pretrained_model_waveunet/model
Step 132065


