Explanations, tradeoffs, and tips (detailed justification)

Why Demucs (recommended)

Demucs operates primarily in the time domain (htdemucs adds a spectrogram branch alongside the waveform branch), so the learned model handles phase implicitly. That tends to give perceptually better results than simple magnitude-only NMF, which reuses the mixture phase and produces artifacts. The pretrained Demucs models (for example htdemucs) are trained on music source separation data and in practice often carry over to speech with background music and other real-world mixes. Use a GPU if one is available; CPU inference works but is much slower.

Why chunking & overlap-add

When Demucs is applied to a whole file in one pass, very long inputs can exhaust GPU memory. Splitting the audio into overlapping windows and recombining the outputs by overlap-add bounds the memory needed per pass and leaves the sample rate untouched. A modest overlap (1 s is a good default) reduces artifacts at chunk boundaries, and dividing each output sample by the number of chunks that cover it keeps amplitudes consistent.
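
The overlap-add idea can be checked in isolation with NumPy. This is a minimal sketch, not the notebook's exact code: `chunk_starts` and `overlap_add` are hypothetical helper names, and the "separator" is replaced by the identity so that a correct reconstruction must return the input unchanged.

```python
import numpy as np

def chunk_starts(n_samples, step):
    """Start index of every window when advancing by `step` samples."""
    return list(range(0, n_samples, step))

def overlap_add(chunks, starts, n_samples):
    """Sum chunks at their start positions, then divide by the overlap count."""
    out = np.zeros(n_samples, dtype=np.float64)
    counts = np.zeros(n_samples, dtype=np.float64)
    for start, chunk in zip(starts, chunks):
        usable = min(len(chunk), n_samples - start)  # trim padding past the end
        out[start:start + usable] += chunk[:usable]
        counts[start:start + usable] += 1.0
    counts[counts == 0] = 1.0  # avoid division by zero in uncovered regions
    return out / counts

sr = 100                          # toy sample rate to keep the arrays small
signal = np.sin(np.linspace(0.0, 20.0, 1050))
win, step = 6 * sr, 5 * sr        # 6 s window, 1 s overlap

starts = chunk_starts(len(signal), step)
chunks = []
for s in starts:
    c = signal[s:s + win]
    c = np.pad(c, (0, win - len(c)))  # zero-pad the last, shorter chunk
    chunks.append(c)                  # identity stands in for the separator

rebuilt = overlap_add(chunks, starts, len(signal))
print("reconstruction matches input:", np.allclose(rebuilt, signal))
```

Uniform averaging is the simplest normalization. A crossfade window over the overlap region is a common alternative when chunk outputs disagree noticeably at the seams.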

Avoiding downgrades

We use ffmpeg to convert the input to 32-bit float WAV at the original sample rate (no resampling). Processing in float avoids the clipping and extra quantization steps that integer PCM would introduce. We also export as 32-bit float WAV (subtype='FLOAT'), so writing the stems adds no further loss. Convert back to 16-bit only when you or a target system require it.
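
To make the cost of a 16-bit round trip concrete, here is a small NumPy sketch. It uses a synthetic random signal (an assumption for illustration, not audio from this notebook) and measures the error introduced by scaling to int16 and back.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(-0.5, 0.5, 48000).astype(np.float32)  # 1 s of toy "audio"

# Round trip through 16-bit PCM: scale, round, clip, then scale back to float.
as_int16 = np.clip(np.round(x * 32767.0), -32768, 32767).astype(np.int16)
back = as_int16.astype(np.float32) / 32767.0

max_err = float(np.max(np.abs(back - x)))
print(f"max round-trip error: {max_err:.2e} (one 16-bit step is {1 / 32767:.2e})")
```

The error stays below one quantization step, which is small for a single conversion but accumulates if intermediate results are repeatedly written as 16-bit files.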

Alternative methods

Spleeter is fast and CPU-friendly but uses spectrogram-based NNs and often performs worse than Demucs on complex mixes. Useful if you need quick stems (2/4/5 stems).

Open-Unmix (UMX) and Conv-TasNet are other high-quality options for certain tasks. Demucs is a general strong first choice.

Model selection

Try multiple Demucs variants if you want: demucs, htdemucs, htdemucs_ft (fine-tuned), etc. The pretrained.get_model() call downloads the weights the first time a model is requested. Larger or ensembled models usually separate better but need more GPU memory and run longer.

Quality vs runtime

For the best quality, increase CHUNK_SECONDS (so the model sees more context per pass) and use the largest Demucs variant. Both increase GPU memory use and runtime.

For faster results, use a smaller model or Spleeter, but expect more bleed and artifacts.
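
A rough way to see how chunk length and overlap affect the amount of work is to count chunks. The sketch below mirrors the splitting loop used later in this notebook (one chunk per step, every chunk padded to the full window); `chunk_plan` is a hypothetical helper, and the ratio is only a proxy for runtime, not a measurement.

```python
import math

def chunk_plan(duration_s, chunk_s, overlap_s):
    """Return (number of chunks, seconds processed / seconds of audio)."""
    step_s = chunk_s - overlap_s
    if step_s <= 0:
        raise ValueError("chunk_s must be greater than overlap_s")
    n_chunks = math.ceil(duration_s / step_s)
    processed_s = n_chunks * chunk_s  # every chunk is padded to the full window
    return n_chunks, processed_s / duration_s

for chunk_s in (6, 10, 30):
    n, ratio = chunk_plan(60.0, chunk_s, 1.0)
    print(f"chunk={chunk_s:>2}s -> {n:>2} chunks, processed/actual = {ratio:.2f}")
```

For a 60 s file with a 1 s overlap, 6 s chunks give 12 passes through the model. Longer chunks need fewer passes, but each pass needs more memory.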

Labeling & post-filtering

Automatic labeling heuristics are approximate. For production, listen to the stems; then optionally run targeted denoisers (e.g., speech enhancement) on stems you identify as speech to further reduce residual background.
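
To illustrate why the spectral centroid is a useful first cue, here is a minimal NumPy sketch. It computes a single whole-signal centroid (the notebook's labeling cell uses librosa's frame-wise version instead) for two synthetic signals, a 60 Hz hum and white noise, both of which are assumptions chosen for illustration.

```python
import numpy as np

def spectral_centroid_hz(signal, sr):
    """Magnitude-weighted mean frequency over one FFT of the whole signal."""
    mag = np.abs(np.fft.rfft(signal))
    freqs = np.fft.rfftfreq(len(signal), d=1.0 / sr)
    return float(np.sum(freqs * mag) / (np.sum(mag) + 1e-12))

sr = 16000
t = np.arange(sr) / sr
hum = np.sin(2 * np.pi * 60.0 * t)                     # mains-style hum
noise = np.random.default_rng(0).standard_normal(sr)   # broadband noise

hum_centroid = spectral_centroid_hz(hum, sr)
noise_centroid = spectral_centroid_hz(noise, sr)
print(f"hum centroid:   {hum_centroid:8.1f} Hz")
print(f"noise centroid: {noise_centroid:8.1f} Hz")
```

The hum sits near its fundamental while broadband noise lands near the middle of the spectrum. Real stems fall in between, which is why the thresholds in the labeling cell are only rough guesses.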

Verifying results

I included waveform overlays and spectrograms so you can visually check separation quality quickly. Listen to the WAV outputs as final verification.
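
A numeric check can complement listening. The sketch below assumes that the separated stems should add up to roughly the mixture, which is an assumption that holds only approximately for learned separators. `residual_db` is a hypothetical helper, and the signals are synthetic stand-ins for real stems.

```python
import numpy as np

def residual_db(mixture, stems):
    """Power of (mixture - sum of stems) relative to the mixture, in dB."""
    n = min(len(mixture), *(len(s) for s in stems))
    residual = mixture[:n] - np.sum([s[:n] for s in stems], axis=0)
    mix_power = np.mean(mixture[:n] ** 2) + 1e-12
    res_power = np.mean(residual ** 2) + 1e-12
    return 10.0 * np.log10(res_power / mix_power)

sr = 16000
t = np.arange(sr) / sr
tone = np.sin(2 * np.pi * 220.0 * t)
hiss = 0.1 * np.random.default_rng(0).standard_normal(sr)
mixture = tone + hiss

complete = residual_db(mixture, [tone, hiss])  # all stems present
missing = residual_db(mixture, [tone])         # one stem left out
print(f"all stems:       {complete:7.1f} dB")
print(f"one stem missing:{missing:7.1f} dB")
```

A strongly negative value means the stems account for almost all of the mixture. A value close to 0 dB suggests that content was lost or that the files are misaligned.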

In [None]:
# -----------------------------------------------------------------------------
# 0. Settings (edit these)
# -----------------------------------------------------------------------------
INPUT_FILE = "/content/60s-20.m4a"   # path to your uploaded file in Colab
OUTPUT_DIR = "/content/separated_outputs"  # where separated stems and plots will be saved
USE_DEMUCS = True   # True -> Demucs (recommended). If False, uses Spleeter (fast) fallback.
DEMUC_MODEL = "htdemucs"  # recommended: "demucs" or "htdemucs". See the model selection notes above.
CHUNK_SECONDS = 6   # chunk length in seconds for long files (6 s is conservative; increase if you have GPU memory/RAM to spare)
CHUNK_OVERLAP = 1.0  # overlap in seconds between chunks to avoid boundary artifacts
PRESERVE_SR = True   # preserve original sample rate (recommended)
# -----------------------------------------------------------------------------

In [None]:
# -----------------------------------------------------------------------------
# 1. Install dependencies (run in Colab cell)
# Explanation:
#  - We install demucs for high-quality separation (uses PyTorch + GPU).
#  - We also install spleeter and librosa for fallback and plotting/analysis.
#  - ffmpeg for format conversion and so we always work with a WAV copy at exact sample rate.
# -----------------------------------------------------------------------------
!pip install -q demucs  # high-quality separator (recommended)
!pip install -q spleeter  # alternative (fast, but lower quality for some sources)
!pip install -q librosa soundfile matplotlib numpy scipy
!apt-get -qq update && apt-get -qq install -y ffmpeg

In [None]:
# ============================================================================
# 0. Clean + Install dependencies for Demucs separation (Colab-safe)
# ============================================================================
!pip install --upgrade pip setuptools wheel
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121  # CUDA 12.1 build (for GPUs)
!pip install demucs==4.0.0 --no-build-isolation
!pip install spleeter librosa soundfile matplotlib numpy scipy tqdm ffmpeg-python
!apt-get -qq update && apt-get -qq install -y ffmpeg


In [None]:
# -----------------------------------------------------------------------------
# 1. Configuration
# -----------------------------------------------------------------------------
INPUT_FILE = "/content/60s-20.m4a"
OUTPUT_DIR = "/content/separated_outputs"
USE_DEMUCS = True
DEMUC_MODEL = "htdemucs"
CHUNK_SECONDS = 6
CHUNK_OVERLAP = 1.0
PRESERVE_SR = True


In [None]:
# -----------------------------------------------------------------------------
# 2. Utilities: safe WAV conversion and helpers
# Explanation:
#  - Always work from a WAV copy converted by ffmpeg at the original sample rate to avoid
#    hidden codec/resampling issues.
#  - We keep the original sample rate; samples are stored as 32-bit float (ffmpeg -c:a pcm_f32le).
# -----------------------------------------------------------------------------
import os, subprocess, math, shutil
from pathlib import Path
import soundfile as sf
import numpy as np
import matplotlib.pyplot as plt

os.makedirs(OUTPUT_DIR, exist_ok=True)

def to_wav_preserve(input_path, out_wav=None):
    """
    Convert input audio to 32-bit float WAV preserving sample rate.
    Using 32-bit float avoids clipping during processing and keeps fidelity.
    """
    if out_wav is None:
        stem = Path(input_path).stem
        out_wav = f"/content/{stem}_converted.wav"
    # ffmpeg: -y overwrite output, -vn drop video, -map 0:a:0 select the first audio stream only
    # Use -c:a pcm_f32le to output 32-bit float WAV (preserves high fidelity)
    cmd = ["ffmpeg", "-y", "-i", str(input_path), "-vn", "-map", "0:a:0", "-c:a", "pcm_f32le", str(out_wav)]
    print("Running ffmpeg to convert to WAV (32-bit float):", " ".join(cmd))
    subprocess.run(cmd, check=True)
    return out_wav

# convert input to wav (preserve original's sample rate inside a float WAV)
wav_path = to_wav_preserve(INPUT_FILE)
print("Converted WAV path:", wav_path)

# read metadata
data, sr = sf.read(wav_path, dtype='float32')
print(f"Loaded converted WAV: duration={len(data)/sr:.2f}s, sr={sr}, samples={len(data)}")


In [None]:
# -----------------------------------------------------------------------------
# 3a. HIGH QUALITY Separation with Demucs (recommended)
# Explanation / justification:
#  - Demucs (Facebook research) models the waveform directly (htdemucs adds a spectrogram branch) and
#    often yields the best perceptual separation for music and multi-source audio. Phase is handled by
#    the model instead of being reused from the mixture, as happens with simple magnitude-based NMF.
#  - We process the file in chunks (CHUNK_SECONDS) with overlap (CHUNK_OVERLAP) to avoid OOM for long audio,
#    then overlap-add outputs. This keeps fidelity (no downsampling) and lets you process arbitrarily long files.
# -----------------------------------------------------------------------------
if USE_DEMUCS:
    # We call demucs via its Python API so we have programmatic control over models and chunking.
    import torch
    from demucs import pretrained
    from demucs.apply import apply_model
    from demucs.audio import AudioFile

    # Select model - "demucs" or "htdemucs" are common names; pick based on your GPU:
    MODEL = DEMUC_MODEL

    print("Loading Demucs model:", MODEL)
    # get_model downloads pretrained weights the first time (colab internet required)
    model = pretrained.get_model(MODEL)
    model.to('cuda' if torch.cuda.is_available() else 'cpu')
    model.eval()
    print("Model loaded on", next(model.parameters()).device)

    # chunking helpers
    def split_into_chunks(signal, sr, chunk_sec=CHUNK_SECONDS, overlap_sec=CHUNK_OVERLAP):
        step = int((chunk_sec - overlap_sec) * sr)
        win = int(chunk_sec * sr)
        if step <= 0:
            raise ValueError("chunk_sec must be > overlap_sec.")
        chunks = []
        idx = 0
        while idx < len(signal):
            chunk = signal[idx: idx + win]
            chunks.append((idx, chunk))
            idx += step
        return chunks, win, step

    def overlap_add_segments(segments, length, step):
        """
        segments: list of numpy arrays (separated audio for a chunk)
        length: full length expected
        step: hop between segments
        We'll reconstruct by summing overlapped areas and normalizing by overlap count.
        """
        # each segment has shape (n_sources, samples)
        n_sources = segments[0].shape[0]
        out = [np.zeros(length, dtype=np.float32) for _ in range(n_sources)]
        counts = np.zeros(length, dtype=np.float32)
        pos = 0
        for seg in segments:
            # the last chunk is zero-padded to the window size, so trim whatever runs past the end
            seg_len = min(seg.shape[1], length - pos)
            end = pos + seg_len
            for s in range(n_sources):
                out[s][pos:end] += seg[s, :seg_len]
            counts[pos:end] += 1.0
            pos += step
        # avoid division by zero
        counts[counts == 0] = 1.0
        for s in range(n_sources):
            out[s] /= counts
        return out

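    # Demucs expects input shaped (channels, samples), with model.audio_channels channels
    # at model.samplerate. Read the converted WAV and adapt the layout.
    full_audio, full_sr = sf.read(wav_path, dtype='float32')  # (samples,) or (samples, channels)
    if full_audio.ndim == 1:
        full_audio = full_audio[np.newaxis, :]  # shape (1, samples)
    else:
        # transpose to (channels, samples) for the demucs API
        full_audio = full_audio.T
    if full_audio.shape[0] == 1 and model.audio_channels > 1:
        # duplicate mono so the channel count matches what the model expects
        full_audio = np.repeat(full_audio, model.audio_channels, axis=0)
    if full_sr != model.samplerate:
        print(f"Warning: file is {full_sr} Hz but the model expects {model.samplerate} Hz; "
              "quality may suffer unless you resample first.")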

    print("Full audio shape (channels, samples):", full_audio.shape)

    chunks, win, step = split_into_chunks(full_audio[0], full_sr, CHUNK_SECONDS, CHUNK_OVERLAP)
    print(f"Splitting into {len(chunks)} chunks; window={win} samples, step={step} samples")

    separated_chunks = []  # each element: numpy array shape (n_sources, samples)
    for i, (start, chunk0) in enumerate(chunks):
        # slice all channels for this chunk
        chunk_slice = full_audio[:, start:start+win]
        # If chunk shorter than window, pad with zeros (keeps exact sample-rate)
        if chunk_slice.shape[1] < win:
            pad_width = win - chunk_slice.shape[1]
            chunk_slice = np.pad(chunk_slice, ((0,0),(0,pad_width)), mode='constant')

        # apply_model expects a tensor of shape (batch, channels, samples)
        with torch.no_grad():
            wav_tensor = torch.from_numpy(chunk_slice).unsqueeze(0)  # (1, channels, samples)
            wav_tensor = wav_tensor.to(next(model.parameters()).device)
            # apply_model returns a tensor of shape (batch, sources, channels, samples)
            est = apply_model(model, wav_tensor, device=next(model.parameters()).device, split=False, overlap=0)
            # move to CPU numpy
            est_np = est.cpu().numpy()[0]  # shape (sources, channels, samples)
            # Collapse each source to mono by averaging over the channel axis. To keep stereo,
            # store est_np as-is and extend the overlap-add and file writing to carry the channel axis.
            est_mono = est_np.mean(axis=1)  # shape (sources, samples)
            separated_chunks.append(est_mono.astype(np.float32))

        print(f"Processed chunk {i+1}/{len(chunks)} (start={start})")

    # overlap-add to reconstruct full-length separated sources (mono per source)
    segments = separated_chunks
    full_length = full_audio.shape[1]
    reconstructed_sources = overlap_add_segments(segments, full_length, step)

    # write each source to WAV (preserve sample rate, 32-bit float)
    out_paths = []
    for idx, src in enumerate(reconstructed_sources, start=1):
        outp = os.path.join(OUTPUT_DIR, f"{Path(INPUT_FILE).stem}_demucs_component_{idx}.wav")
        sf.write(outp, src, full_sr, subtype='FLOAT')
        out_paths.append(outp)
        print("Wrote:", outp)

    print("Demucs separation done. Outputs:", out_paths)

In [None]:
# -----------------------------------------------------------------------------
# 3b. Alternative: Spleeter (fast, but lower quality for some mixes)
# Explanation:
#  - Spleeter is faster and good for common music separations (2/4/5 stems).
#  - Use it if you need quick stems without GPU. It operates with fixed stem labels (vocals, drums, bass, other).
# -----------------------------------------------------------------------------
if not USE_DEMUCS:
    # Example: 4-stem separation (vocals, drums, bass, other)
    from spleeter.separator import Separator
    separator = Separator('spleeter:4stems')  # 2stems/4stems/5stems available
    out_dir = os.path.join(OUTPUT_DIR, "spleeter_out")
    os.makedirs(out_dir, exist_ok=True)
    separator.separate_to_file(wav_path, out_dir)
    print("Spleeter separation completed; check directory:", out_dir)


In [None]:
# -----------------------------------------------------------------------------
# 4. Visualize: overlayed waveforms and spectrogram of separated outputs
# Explanation:
#  - Plot normalized overlayed waveforms so you can visually compare sources without amplitude differences hiding detail.
#  - Also plot spectrograms for each source to inspect frequency/time content.
# -----------------------------------------------------------------------------
import librosa, librosa.display
import matplotlib.pyplot as plt

# Gather output files from the chosen method:
if USE_DEMUCS:
    out_files = out_paths
else:
    # gather spleeter outputs
    out_files = [str(p) for p in Path(OUTPUT_DIR).glob("**/*.wav")]

# load mixture for overlay plot:
mixture, sr = sf.read(wav_path, dtype='float32')
if mixture.ndim > 1:
    mixture_mono = mixture.mean(axis=1)
else:
    mixture_mono = mixture

# plot overlay
plt.figure(figsize=(14, 5))
t = np.arange(len(mixture_mono))/sr
plt.plot(t, mixture_mono / (np.max(np.abs(mixture_mono)) + 1e-9), alpha=0.25, linewidth=0.7, label="mixture (normalized)")
for i, fpath in enumerate(out_files):
    sig, _ = sf.read(fpath, dtype='float32')
    if sig.ndim > 1:
        sig = sig.mean(axis=1)
    # align length (a stem can be slightly shorter or longer than the mixture)
    sig = sig[:len(mixture_mono)]
    sig_norm = sig / (np.max(np.abs(sig)) + 1e-9)
    plt.plot(t[:len(sig_norm)], sig_norm, linewidth=0.9, label=f"component {i+1}")
plt.xlim(0, len(mixture_mono)/sr)
plt.xlabel("Time (s)")
plt.ylabel("Normalized amplitude")
plt.legend(loc='upper right')
plt.title("Overlayed waveforms (normalized) — mixture + separated components")
plt.show()

# spectrograms
n = len(out_files)
cols = 2
rows = math.ceil((n+1)/cols)
plt.figure(figsize=(14, 3*rows))
# mixture spectrogram
plt.subplot(rows, cols, 1)
D = np.abs(librosa.stft(mixture_mono, n_fft=2048, hop_length=512))
librosa.display.specshow(librosa.amplitude_to_db(D, ref=np.max), sr=sr, x_axis='time', y_axis='log')
plt.title("Mixture spectrogram")
plt.colorbar(format="%+2.0f dB")

for i, fpath in enumerate(out_files, start=1):
    plt.subplot(rows, cols, i+1)
    sig, _ = sf.read(fpath, dtype='float32')
    if sig.ndim > 1:
        sig = sig.mean(axis=1)
    D = np.abs(librosa.stft(sig, n_fft=2048, hop_length=512))
    librosa.display.specshow(librosa.amplitude_to_db(D, ref=np.max), sr=sr, x_axis='time', y_axis='log')
    plt.title(f"Component {i} spectrogram")
    plt.colorbar(format="%+2.0f dB")
plt.tight_layout()
plt.show()


In [None]:
# -----------------------------------------------------------------------------
# 5. (Optional) Auto-label components by simple heuristics
# Explanation:
#  - Use spectral centroid and RMS energy to guess if a stem is voice / music / hum / noise.
#  - Heuristics are imperfect; consider manual review for critical labeling.
# -----------------------------------------------------------------------------
def guess_label(signal, sr):
    # compute spectral centroid and RMS energy
    import librosa
    if signal.ndim > 1:
        signal = signal.mean(axis=1)
    centroid = librosa.feature.spectral_centroid(y=signal, sr=sr).mean()
    rms = librosa.feature.rms(y=signal).mean()
    # low-frequency hum: very low centroid with non-negligible energy
    if centroid < 200 and rms > 1e-5:
        return "low-hum/low-frequency noise"
    # speech heuristic: mid-range centroid and modest average level (thresholds are rough guesses)
    if 200 <= centroid < 3000 and rms < 0.02:
        return "possible speech/voice"
    # music: higher centroid (brighter content)
    if centroid >= 3000:
        return "music/bright content"
    return "unknown"

for f in out_files:
    s, s_sr = sf.read(f, dtype='float32')
    if s.ndim > 1:
        s_m = s.mean(axis=1)
    else:
        s_m = s
    # use the sample rate stored in the file, not a value left over from an earlier cell
    label = guess_label(s_m, s_sr)
    print(Path(f).name, "->", label)


In [None]:
# -----------------------------------------------------------------------------
# 6. Packaging outputs
# -----------------------------------------------------------------------------
# Create a zip with all separated stems and the plots (if you saved plots).
shutil.make_archive(OUTPUT_DIR, 'zip', OUTPUT_DIR)
print("Created archive:", OUTPUT_DIR + ".zip")
