In [1]:
import torch
import torch as th
import torch.nn as nn
import torch.nn.functional as F

In [2]:
import torch
import torchaudio
import torch.nn as nn
from IPython.display import Audio
from torchaudio.transforms import Fade

In [3]:
from model.PM_Unet import Model_Unet

In [4]:
def separate_sources(
        model,
        mix,
        segment=10.,
        sample_rate = 44100,
        overlap=0.1,
        device='cpu',
):
    """
    Apply ``model`` to a mixture chunk by chunk and cross-fade the chunk outputs together.

    The mixture is processed in windows of ``segment * (1 + overlap)`` seconds.
    Consecutive windows overlap by ``int(overlap * sample_rate)`` samples, and a
    linear fade-in / fade-out over that overlap lets the overlapping outputs sum
    back to unit gain. The first window is not faded in. The window that reaches
    the end of the mixture (possibly the first one) is not faded out.

    Args:
        model: callable applied to each ``(batch, channels, time)`` chunk. Its
            output is added into a ``(batch, 4, channels, time)`` buffer, so it
            presumably returns one waveform per source (drums, bass, other,
            vocals) -- TODO confirm against the model definition.
        mix (torch.Tensor): mixture of shape ``(batch, channels, length)``.
        segment (float): nominal chunk length in seconds.
        sample_rate (int): sample rate of ``mix`` in Hz.
        overlap (float): used as a fraction of ``segment`` to lengthen each chunk,
            and as a duration in seconds for the cross-fade length.
            NOTE(review): these two readings differ; confirm which one is intended.
        device (torch.device, str, or None): if provided, device on which to
            execute the computation, otherwise `mix.device` is assumed.
            When `device` is different from `mix.device`, only local computations will
            be on `device`, while the entire tracks will be stored on `mix.device`.

    Returns:
        torch.Tensor: separated sources of shape ``(batch, 4, channels, length)``,
        stored on ``mix.device``.
    """
    if device is None:
        device = mix.device
    else:
        device = torch.device(device)

    batch, channels, length = mix.shape
    num_sources = 4  # drums, bass, other, vocals

    chunk_len = int(sample_rate * segment * (1 + overlap))
    # A single integer overlap is used for both the fade length and the chunk
    # stride. Otherwise the real overlap and the fade length can differ by one
    # sample, and the cross-fade no longer sums to 1.
    overlap_frames = int(overlap * sample_rate)
    fade = Fade(fade_in_len=0, fade_out_len=overlap_frames, fade_shape='linear')

    # The full-length result stays on the mixture's device; only chunks visit `device`.
    final = torch.zeros(batch, num_sources, channels, length, device=mix.device)

    start = 0
    end = chunk_len
    while start < length - overlap_frames:
        # The chunk that reaches the end of the mixture has no successor to
        # cross-fade with, so it must not be faded out. Checking before the
        # forward pass also covers a first chunk that already spans the whole mix.
        if end >= length:
            fade.fade_out_len = 0
        chunk = mix[:, :, start:end].to(device)
        with torch.no_grad():
            out = model(chunk)
        out = fade(out).to(mix.device)
        final[:, :, :, start:end] += out
        if start == 0:
            # Every later chunk starts `overlap_frames` before the previous one
            # ended, and fades in over exactly that overlap.
            fade.fade_in_len = overlap_frames
            start += chunk_len - overlap_frames
        else:
            start += chunk_len
        end += chunk_len
    return final


In [5]:
def track(SAMPLE_SONG, start_sec=30, end_sec=37):
    """
    Load an audio file, separate an excerpt of it into stems, and peak-normalize each stem.

    Relies on the notebook-global ``model`` (created in a later cell) and on
    ``separate_sources``.

    Args:
        SAMPLE_SONG (str): path of an audio file readable by ``torchaudio.load``.
        start_sec (float): start of the excerpt to separate, in seconds.
        end_sec (float): end of the excerpt to separate, in seconds.

    Returns:
        dict: ``'drums'``, ``'bass'``, ``'other'`` and ``'vocals'`` map to the
        separated excerpt for that source (one tensor per source, channels x time),
        with every channel scaled so that its absolute peak is 1.
        ``'original'`` maps to the full-length, standardized input waveform.
    """
    segment = 7
    overlap = 0.2

    waveform, sample_rate = torchaudio.load(SAMPLE_SONG)
    # NOTE(review): the model presumably expects 44.1 kHz audio; files at other
    # rates are not resampled here.

    # Standardize with the statistics of the mono downmix. The same statistics
    # undo the scaling on the model output below. A silent file has std 0, so
    # clamp it to avoid producing NaNs.
    ref = waveform.mean(0)
    ref_mean = ref.mean()
    ref_std = ref.std().clamp_min(1e-8)
    waveform = (waveform - ref_mean) / ref_std  # normalization

    # Slice the excerpt *after* normalization so the model sees the scale that
    # the de-normalization below assumes.
    mixture = waveform[:, int(sample_rate * start_sec):int(sample_rate * end_sec)]

    sources = separate_sources(
        model,
        mixture[None],
        segment=segment,
        sample_rate=sample_rate,
        overlap=overlap,
    )
    sources = sources * ref_std + ref_mean

    # Peak-normalize every (batch, source, channel) track. Use |x| so negative
    # peaks count, and clamp so a silent channel does not divide by zero.
    peak = sources.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8)
    sources = sources / peak

    sources_list = ['drums', 'bass', 'other', 'vocals']
    audios = dict(zip(sources_list, sources[0]))
    audios['original'] = waveform
    return audios

sample_rate = 44_100  # playback / processing sample rate in Hz

In [None]:
# Build the U-Net separator and load the pretrained weights on CPU.
model = Model_Unet(source=['drums', 'bass', 'other', 'vocals'], depth=4, channel=28)
# NOTE: torch.load unpickles the checkpoint; only load weight files from a trusted source.
model.load_state_dict(torch.load('model_weight_LSTM.pt', map_location=torch.device('cpu')))
# Inference only: put the network in evaluation mode, so that any dropout or
# batch-norm layers behave deterministically. The trailing ';' suppresses the repr.
model.eval();

In [None]:
# TODO: fill in the file name of the song to separate (the '.wav' extension is appended).
path_track = '' + '.wav'
# Pass the variable, not the literal string "path_track".
audios = track(path_track)

In [None]:
Audio(data=audios['vocals'], rate=44_100)  # listen to the separated vocal stem