<a href="https://colab.research.google.com/github/kth0522/AI_news/blob/main/vocal_separation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**라이브러리**

In [10]:
import os
import torch
import torchaudio
import torchaudio.transforms as T
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB_PLUS

**디바이스 설정 및 모델 선언**

In [9]:
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# torchaudio pipeline bundle: get_model() builds the separation model, and
# sample_rate is the rate the helper functions below assume for all audio.
bundle = HDEMUCS_HIGH_MUSDB_PLUS
sample_rate = bundle.sample_rate
model = bundle.get_model().to(device)

**Helper Functions**

In [13]:
def separate_sources(
    model,
    mix,
    segment=10.,
    overlap=0.1,
    device=None,
    sample_rate=None,
):
    """Run a source-separation model over ``mix`` in overlapping chunks.

    The first chunk spans ``segment * (1 + overlap)`` seconds.  Each later
    chunk starts ``overlap`` seconds before the previous one ends, so
    neighbouring chunks overlap by ``overlap`` seconds.  In that region the
    earlier chunk is faded out and the later one faded in linearly, so the
    two weights sum to one.  The chunk that reaches the end of ``mix`` is
    never faded out.

    Args:
        model: Separation model exposing ``model.sources``.  It is called as
            ``model(chunk)`` on a ``mix[:, :, start:end]`` slice, and its
            output must fit into ``final[:, :, :, start:end]`` below, i.e.
            ``(batch, len(model.sources), channels, frames)``.
        mix: Tensor of shape ``(batch, channels, length)``.
        segment: Base chunk length in seconds.
        overlap: Cross-fade length between neighbouring chunks, in seconds.
            It also enlarges every chunk by the fraction ``overlap`` of
            ``segment``.
        device: Device of the returned tensor; defaults to ``mix.device``.
        sample_rate: Sample rate of ``mix`` in Hz.  Defaults to the
            notebook's ``bundle.sample_rate``.

    Returns:
        Tensor of shape ``(batch, len(model.sources), channels, length)`` on
        ``device``.
    """
    if device is None:
        device = mix.device
    else:
        device = torch.device(device)
    if sample_rate is None:
        sample_rate = bundle.sample_rate

    batch, channels, length = mix.shape

    # round() instead of int(): truncating float products such as
    # 100 * 0.1 = 10.000000000000002 leaves start one sample short of the
    # fade-out and misaligns the cross-fade.
    chunk_len = round(sample_rate * segment * (1 + overlap))
    overlap_frames = round(overlap * sample_rate)
    fade = T.Fade(fade_in_len=0, fade_out_len=overlap_frames, fade_shape='linear')

    final = torch.zeros(batch, len(model.sources), channels, length, device=device)

    start = 0
    end = chunk_len
    while start < length:
        # The chunk that reaches the end of the signal has no successor to
        # cross-fade into, so it must not be faded out.  This also covers an
        # input that fits in a single (first) chunk.
        is_last = end >= length
        fade.fade_out_len = 0 if is_last else overlap_frames

        chunk = mix[:, :, start:end]
        with torch.no_grad():
            out = model(chunk)
        out = fade(out)
        final[:, :, :, start:end] += out.to(device)

        if is_last:
            break
        if start == 0:
            # Every chunk after the first fades in over the previous
            # chunk's fade-out region.
            fade.fade_in_len = overlap_frames
            start += chunk_len - overlap_frames
        else:
            start += chunk_len
        end += chunk_len
    return final

def separate_and_save_sources(wav_path, output_dir):
    """Extract the vocal stem of one audio file and save it as a WAV file.

    Steps:

    * Load the file and resample it to the model's ``sample_rate`` if
      needed.
    * Normalize it with the mean/std of the channel-averaged signal.
    * Separate it with ``separate_sources`` and undo the normalization.
    * Write the ``"vocals"`` stem to ``output_dir/<input base name>.wav``.

    If ``output_dir`` is the input's own directory, the input file is
    overwritten.

    Args:
        wav_path: Path of the audio file to process.
        output_dir: Existing directory that receives the vocal track.

    Raises:
        ValueError: If ``model.sources`` has no ``"vocals"`` entry.
    """
    waveform, sr = torchaudio.load(wav_path)
    waveform = waveform.to(device)
    # The model and the saved file both use ``sample_rate``; audio recorded
    # at another rate must be converted first, or it would be separated
    # (and saved) at the wrong speed.
    if sr != sample_rate:
        waveform = torchaudio.functional.resample(waveform, sr, sample_rate)
    # Hybrid Demucs is built for stereo input; duplicate a mono channel.
    if waveform.shape[0] == 1:
        waveform = waveform.repeat(2, 1)

    segment = 10
    overlap = 0.1

    # Normalize with the statistics of the channel-averaged signal and
    # undo it after separation.  Clamp std so silent input does not
    # divide by zero.
    ref = waveform.mean(0)
    ref_mean = ref.mean()
    ref_std = ref.std().clamp_min(1e-8)
    waveform = (waveform - ref_mean) / ref_std

    sources = separate_sources(
        model,
        waveform[None],
        device=device,
        segment=segment,
        overlap=overlap,
    )[0]
    sources = (sources * ref_std + ref_mean).cpu()

    file_name = os.path.splitext(os.path.basename(wav_path))[0]
    vocals = sources[model.sources.index("vocals")]

    print(file_name)
    output_path = os.path.join(output_dir, f"{file_name}.wav")
    torchaudio.save(output_path, vocals, sample_rate)



def process_all_wav_in_folder(input_folder_path, output_folder_path):
    """Extract vocals from every WAV file directly inside a folder.

    Only regular files whose name ends in ``.wav`` (any letter case) are
    processed; subdirectories are not searched.  Files are handled in
    sorted name order, and each is passed to ``separate_and_save_sources``.

    Args:
        input_folder_path: Folder containing the input ``.wav`` files.
        output_folder_path: Folder for the vocal tracks; created (including
            parents) if it does not exist.
    """
    # exist_ok avoids the check-then-create race of os.path.exists.
    os.makedirs(output_folder_path, exist_ok=True)

    for filename in sorted(os.listdir(input_folder_path)):
        wav_path = os.path.join(input_folder_path, filename)
        # Accept ".WAV" etc. too, and skip directories named like "x.wav".
        if filename.lower().endswith(".wav") and os.path.isfile(wav_path):
            separate_and_save_sources(wav_path, output_folder_path)


**실행** \
입력 폴더(`./audios`)에 wav 파일들을 넣고 실행하면, 분리된 보컬 트랙이 같은 파일명으로 출력 폴더(`./outputs`)에 저장됩니다.

In [14]:
# Folder scanned for input .wav files, and folder that receives the vocal tracks.
INPUT_DIRS, OUTPUT_DIRS = './audios', './outputs'

In [15]:
# Extract the vocals of every .wav file in INPUT_DIRS into OUTPUT_DIRS.
process_all_wav_in_folder(INPUT_DIRS, OUTPUT_DIRS)

mbn_ai_1
