<a href="https://colab.research.google.com/github/jarredou/Music-Source-Separation-Training-Colab-Inference/blob/main/Music_Source_Separation_Training_(Colab_Inference)_CustomModel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Colab inference for ZFTurbo's [Music-Source-Separation-Training](https://github.com/ZFTurbo/Music-Source-Separation-Training/)


<font size=1>*made by [jarredou](https://github.com/jarredou) & deton*</font>  
[![ko-fi](https://ko-fi.com/img/githubbutton_sm.svg)](https://ko-fi.com/Q5Q811R5YI)  

In [None]:
#@markdown #Gdrive connection
from google.colab import drive
drive.mount("/content/drive", force_remount=True)

In [None]:
#@markdown # Install

%cd /content
!git clone -b colab-inference https://github.com/jarredou/Music-Source-Separation-Training

#requirements fix by santilli_
req_text = """
mutagen==1.47.0
ml_collections==1.1.0
numpy>=1.26.0
pandas==2.2.2
scipy
tqdm
segmentation_models_pytorch==0.3.3
timm
audiomentations
pedalboard
omegaconf
beartype
rotary_embedding_torch==0.3.5
einops
# librosa==0.11.0
demucs #==4.0.0
# transformers==4.35.0
torchmetrics==0.11.4
spafe==0.3.2
protobuf
torch_audiomentations
asteroid==0.7.0
auraloss
torchseg
"""

with open("Music-Source-Separation-Training/requirements.txt", "w") as f:
    f.write(req_text)

!mkdir -p '/content/Music-Source-Separation-Training/ckpts'

print('Installing the dependencies... This will take a few minutes')
!pip install -r 'Music-Source-Separation-Training/requirements.txt' &> /dev/null

print('Installation is done!')

In [None]:
%cd '/content/Music-Source-Separation-Training/'
import os
import torch
import yaml
from urllib.parse import quote

class IndentDumper(yaml.Dumper):
    def increase_indent(self, flow=False, indentless=False):
        return super(IndentDumper, self).increase_indent(flow, False)


def tuple_constructor(loader, node):
    # Load the sequence of values from the YAML node
    values = loader.construct_sequence(node)
    # Return a tuple constructed from the sequence
    return tuple(values)

# Register the constructor with PyYAML
yaml.SafeLoader.add_constructor('tag:yaml.org,2002:python/tuple', tuple_constructor)



def conf_edit(config_path, chunk_size, overlap):
    with open(config_path, 'r') as f:
        data = yaml.load(f, Loader=yaml.SafeLoader)

    # handle cases where 'use_amp' is missing from the training section:
    training = data.setdefault('training', {})
    if 'use_amp' not in training:
      training['use_amp'] = True

    data['audio']['chunk_size'] = chunk_size
    data['inference']['num_overlap'] = overlap

    if data['inference']['batch_size'] == 1:
      data['inference']['batch_size'] = 2

    print("Using custom overlap and chunk_size values for roformer model:")
    print(f"overlap = {data['inference']['num_overlap']}")
    print(f"chunk_size = {data['audio']['chunk_size']}")
    print(f"batch_size = {data['inference']['batch_size']}")


    with open(config_path, 'w') as f:
        yaml.dump(data, f, default_flow_style=False, sort_keys=False, Dumper=IndentDumper, allow_unicode=True)

def download_file(url):
    # Encode the URL to handle spaces and special characters
    encoded_url = quote(url, safe=':/')

    path = 'ckpts'
    os.makedirs(path, exist_ok=True)
    # Use the unencoded name so the local file matches config_path / start_check_point
    filename = os.path.basename(url)
    file_path = os.path.join(path, filename)

    if os.path.exists(file_path):
        print(f"File '{filename}' already exists at '{path}'.")
        return

    try:
        torch.hub.download_url_to_file(encoded_url, file_path)
        print(f"File '{filename}' downloaded successfully")
    except Exception as e:
        print(f"Error downloading file '{filename}' from '{url}': {e}")





#@markdown # Separation
#@markdown #### Model config:
config_url = 'https://huggingface.co/becruily/mel-band-roformer-instrumental/resolve/main/config_instrumental_becruily.yaml' #@param {type:"string"}
ckpt_url = 'https://huggingface.co/becruily/mel-band-roformer-instrumental/resolve/main/mel_band_roformer_instrumental_becruily.ckpt' #@param {type:"string"}
model_type = 'mel_band_roformer' #@param ['mdx23c','bs_roformer', 'mel_band_roformer', 'bandit', 'bandit_v2', 'scnet', 'apollo', 'htdemucs', 'segm_models', 'torchseg', 'bs_mamba2']
#@markdown ---
#@markdown #### Separation config:
input_folder = '/content/drive/MyDrive/input' #@param {type:"string"}
output_folder = '/content/drive/MyDrive/output' #@param {type:"string"}
extract_instrumental = True #@param {type:"boolean"}
export_format = 'flac PCM_16' #@param ['wav FLOAT', 'flac PCM_16', 'flac PCM_24']
use_tta = False #@param {type:"boolean"}
#@markdown ---
#@markdown *Roformers custom config:*
overlap = 2 #@param {type:"slider", min:2, max:40, step:1}
chunk_size = "485100" #@param [352800, 485100] {allow-input: true}

if export_format.startswith('flac'):
    flac_file = True
    pcm_type = export_format.split(' ')[1]
else:
    flac_file = False
    pcm_type = None


if config_url != '' and ckpt_url != '':
    config_filename = os.path.basename(config_url)
    ckpt_filename = os.path.basename(ckpt_url)
    print(config_filename, ckpt_filename)
    config_path = f'ckpts/{config_filename}'
    start_check_point = f'ckpts/{ckpt_filename}'
    download_file(config_url)
    download_file(ckpt_url)
    if "roformer" in model_type:
      conf_edit(config_path, int(chunk_size), overlap)

!python inference.py \
    --model_type {model_type} \
    --config_path '{config_path}' \
    --start_check_point '{start_check_point}' \
    --input_folder '{input_folder}' \
    --store_dir '{output_folder}' \
    {('--extract_instrumental' if extract_instrumental else '')} \
    {('--flac_file' if flac_file else '')} \
    {('--use_tta' if use_tta else '')} \
    {('--pcm_type ' + pcm_type if pcm_type else '')}

**INST-Mel-Roformer v1/1x/2** has switched output file names: files labelled as vocals are actually instrumentals. (If you uncheck extract_instrumental for the v1e model, only one stem called "other" will be rendered, and it will be the instrumental.)<br><br>
**TTA** results in longer separation time: "it gives a little better SDR score but hard to tell if it's really audible in most cases". <br> TTA “means "test time augmentation", (...) it will do 3 passes on the audio file instead of 1. 1 pass will be with original audio. 1 will be with inverted stereo (L becomes R, R becomes L). 1 will be with phase inverted and then results are averaged for final output.” - jarredou
<br><br>
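
A minimal NumPy sketch of the three TTA passes described in the quote above. `model` is a hypothetical callable standing in for the network that inference.py runs (stereo array in, dict of stem arrays out); it is not the script's actual API. Each augmented estimate is undone (channels swapped back, polarity re-inverted) before averaging, so the three estimates line up sample for sample.

```python
import numpy as np


def separate_with_tta(model, mix):
    """Average a separation model over three test-time augmentation passes.

    `model` is a hypothetical callable: (2, n_samples) stereo array in,
    {stem_name: (2, n_samples) array} out.
    """
    # Pass 1: original audio.
    outputs = [model(mix)]

    # Pass 2: swapped stereo channels (L becomes R, R becomes L);
    # swap each estimate back so it lines up with the original.
    swapped = model(mix[::-1])
    outputs.append({stem: est[::-1] for stem, est in swapped.items()})

    # Pass 3: polarity-inverted audio; invert each estimate back.
    inverted = model(-mix)
    outputs.append({stem: -est for stem, est in inverted.items()})

    # Average the three aligned estimates, stem by stem.
    return {stem: np.mean([out[stem] for out in outputs], axis=0)
            for stem in outputs[0]}


# Toy "model" that is deliberately not left/right symmetric,
# so the TTA average differs from a single pass.
def toy_model(x):
    return {"vocals": x * np.array([[1.0], [0.5]])}


mix = np.ones((2, 4))
result = separate_with_tta(toy_model, mix)
print(result["vocals"][:, 0])  # roughly [0.833, 0.667]
```

Because the swapped and inverted passes are undone before averaging, a model that treats left/right and polarity symmetrically gives exactly the single-pass result; the average only changes things where the model's errors differ between passes, which is consistent with the small SDR gain mentioned above.
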
**Overlap**: higher values mean longer separation time. 4 is already a balanced value; 2 is fast, and some people still won't notice any difference. Normally there's no point going over 8.<br><br>
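
A rough way to see why overlap drives separation time. This sketch assumes the model runs on windows of `chunk_size` samples whose start positions are `chunk_size // overlap` samples apart (our reading of ZFTurbo's chunked inference; the padding added at the track edges is ignored), so the number of windows grows roughly linearly with overlap. At 44.1 kHz, the chunk_size presets 485100 and 352800 correspond to 11 s and 8 s windows. The helper names are hypothetical.

```python
SAMPLE_RATE = 44100  # assumption: 44.1 kHz input, as the chunk_size presets suggest


def chunk_starts(n_samples, chunk_size, overlap):
    """Start offsets of the windows sent to the model, assuming a hop of
    chunk_size // overlap samples (edge padding is ignored)."""
    step = chunk_size // overlap
    return list(range(0, n_samples, step))


def describe(n_samples, chunk_size, overlap):
    """Print and return the number of windows for one overlap setting."""
    step = chunk_size // overlap
    calls = len(chunk_starts(n_samples, chunk_size, overlap))
    print(f"overlap={overlap}: window={chunk_size / SAMPLE_RATE:.1f} s, "
          f"hop={step / SAMPLE_RATE:.2f} s, windows={calls}")
    return calls


track_len = 3 * 60 * SAMPLE_RATE  # a 3-minute track
for ov in (2, 4, 8):
    describe(track_len, 485100, ov)  # windows: 33, 66, 131
```

Doubling overlap roughly doubles the number of windows to process. Assuming batch_size only controls how many windows go through the GPU per forward pass, it changes speed and memory use but not the total amount of work.
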
If your separation can't start and "Total files found: 0" is shown, check the following: <br>1) The input must be a path to a folder containing audio files, not a direct path to an audio file.<br> 2) Paths in Colab are case-sensitive: with the default input_folder, name your folder "input", not "Input".<br> 3) Check that your Google Drive was mounted correctly. Open the file manager on the left and make sure your Drive folder is not empty. If it is empty, force a remount with the following line:

In [None]:
drive.mount("/content/drive", force_remount=True)
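
Points 1 to 3 of the troubleshooting list can also be checked from a code cell before re-running the separation. This is a minimal sketch, not part of the original notebook: `check_input_folder` is a hypothetical helper, and `AUDIO_EXTENSIONS` is an assumed list that may not match exactly what inference.py accepts.

```python
from pathlib import Path

# Assumption: extensions treated as audio here; adjust to match your inference.py version.
AUDIO_EXTENSIONS = {".wav", ".flac", ".mp3", ".m4a", ".ogg", ".aiff"}


def check_input_folder(input_folder):
    """Return the audio files directly inside input_folder, printing a hint
    for the usual causes of "Total files found: 0"."""
    path = Path(input_folder)

    # 1) A direct path to an audio file was given instead of its folder.
    if path.is_file():
        print(f"'{path}' is a file; set input_folder to '{path.parent}' instead.")
        return []

    if not path.is_dir():
        # 2) Paths are case-sensitive: look for a sibling differing only in case.
        parent = path.parent
        if parent.is_dir():
            for candidate in parent.iterdir():
                if candidate.is_dir() and candidate.name.lower() == path.name.lower():
                    print(f"'{path}' not found; did you mean '{candidate}'?")
                    return []
        # 3) Otherwise the folder (or the whole Drive mount) is missing.
        print(f"'{path}' does not exist; check that Google Drive is mounted.")
        return []

    files = sorted(p for p in path.iterdir()
                   if p.is_file() and p.suffix.lower() in AUDIO_EXTENSIONS)
    print(f"{len(files)} audio file(s) found in '{path}'.")
    return files


# Default input_folder from the Separation cell.
check_input_folder("/content/drive/MyDrive/input")
```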