**Code description (note: this code is AI-generated)**
Demucs audio separation was the main bottleneck of this project: it required a lot of GPU hours. This code takes the raw audio clips, separates each one into a vocals stem and a non-vocals stem, and uploads the results to Google Drive as zip archives. To speed things up it runs several Demucs batches concurrently (a thread pool in which each thread drives its own Demucs subprocess). This is the code for splitting the negative clips; similar code was used for the positive clips.


In [None]:
pip install -U demucs torchaudio

In [None]:
*

# ================= 📦 IMPORTS =================
import os
import shutil
import subprocess
import glob
import zipfile
from tqdm import tqdm
import time
import json
from datetime import datetime
import sys
import gc
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading

# ================= ⚙️ CONFIGURATION =================

# --- Google Drive locations ---
# Root folder with one sub-folder of input .zip archives per game (see GAME_FOLDERS).
INPUT_BASE_FOLDER = "/content/drive/MyDrive/negative_dataset_48k"
# Destination for the "<Game>_Separated" output folders and the checkpoint JSON.
OUTPUT_BASE_FOLDER = "/content/drive/MyDrive/separated_audio_negative"

# --- Batching / parallelism ---
OUTPUT_BATCH_SIZE = 1_000      # staged clip count that triggers a zip upload to Drive
DEMUCS_BATCH_SIZE = 1_000      # files handed to a single demucs subprocess
NUM_PARALLEL_BATCHES = 10      # worker threads = concurrent demucs subprocesses
NUM_JOBS = 2                   # value passed to demucs' -j option

# --- Local scratch space ---
LOCAL_STAGING_BASE = "/content/staging_negative"
CHECKPOINT_FILE = None         # resolved by setup_colab() once Drive is available

# --- Demucs model ---
MODEL_NAME = "htdemucs"
DEVICE = "cuda"                # run_pipeline() switches this to "cpu" if verify_gpu() fails

# --- Games ---
GAME_PROCESSING_ORDER = ["Valorant", "CS2", "Apex"]
# Input sub-folder name -> game label used for output folders and batch numbering.
GAME_FOLDERS = dict(
    valorant_set="Valorant",
    cs2_set="CS2",
    apex_set="Apex",
)

# Seconds of demucs timeout budgeted per file when computing a batch's timeout.
DEMUCS_TIMEOUT_PER_FILE = 30

# Thread-synchronisation locks.
checkpoint_lock = threading.Lock()   # serialises checkpoint writes
staging_lock = threading.Lock()      # serialises copies into the staging folder

# Substrings searched for (case-insensitively) in demucs stderr after a failed
# batch run, to recognise errors caused by individual bad input files.
BAD_FILE_ERRORS = [
    "AssertionError", "pad1d", "RuntimeError", "Invalid data",
    "corrupted", "Could not find a format", "Invalid audio",
    "Error opening", "No such file",
]

# ================= 🔧 SETUP =================
def setup_colab():
    """Mount Google Drive if needed and resolve the checkpoint location.

    Returns:
        True when INPUT_BASE_FOLDER exists; CHECKPOINT_FILE is then set to a
        JSON file inside OUTPUT_BASE_FOLDER. False (after printing an error)
        when the input folder is missing, leaving CHECKPOINT_FILE unchanged.
    """
    global CHECKPOINT_FILE

    print("üîå Setting up...")
    drive_missing = not os.path.exists("/content/drive")
    if drive_missing:
        from google.colab import drive
        drive.mount('/content/drive')

    input_present = os.path.exists(INPUT_BASE_FOLDER)
    if not input_present:
        print(f"‚ùå ERROR: Input folder not found: {INPUT_BASE_FOLDER}")
        return False

    CHECKPOINT_FILE = os.path.join(OUTPUT_BASE_FOLDER, "negative_parallel_checkpoint.json")
    print("‚úÖ Ready!")
    return True

def verify_gpu():
    """Probe CUDA availability with torch in a fresh Python subprocess.

    When the probe reports a device, prints its name and memory and returns
    True; otherwise prints a CPU-fallback notice and returns False.
    """
    probe_code = """
import torch
if torch.cuda.is_available():
    print(f"OK:{torch.cuda.get_device_name(0)}:{torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}")
else:
    print("NO_GPU")
"""
    probe = subprocess.run([sys.executable, "-c", probe_code],
                           capture_output=True, text=True)

    if "OK:" not in probe.stdout:
        print("‚ö†Ô∏è No GPU - using CPU")
        return False

    # Probe output format: "OK:<device name>:<memory in GB>"
    parts = probe.stdout.strip().split(":")
    print(f"‚úÖ GPU: {parts[1]} ({parts[2]} GB)")
    print(f"‚ö° Running {NUM_PARALLEL_BATCHES} parallel batches!")
    return True

def setup_folders():
    """Create OUTPUT_BASE_FOLDER and a "<Game>_Separated" sub-folder per game."""
    targets = [OUTPUT_BASE_FOLDER]
    targets += [os.path.join(OUTPUT_BASE_FOLDER, f"{g}_Separated") for g in GAME_PROCESSING_ORDER]
    for folder in targets:
        os.makedirs(folder, exist_ok=True)

# ================= 💾 CHECKPOINT =================
def _fresh_checkpoint():
    """Return a new, empty checkpoint containing every key the pipeline reads."""
    return {
        "processed_zips": [],
        "processed_clips": set(),
        "bad_files": set(),
        "batch_numbers": {"Valorant": 1, "CS2": 1, "Apex": 1},
        "total_processed": 0,
        "total_skipped": 0,
    }


def load_checkpoint(path=None):
    """Load the resume checkpoint, or return a fresh one.

    Args:
        path: JSON file to read. Defaults to the module-level CHECKPOINT_FILE
            (resolved by setup_colab()).

    Returns:
        A dict with keys processed_zips (list), processed_clips (set),
        bad_files (set), batch_numbers (dict: game -> next batch number),
        total_processed and total_skipped. Keys missing from the stored file
        are filled with their fresh defaults, so callers can index them
        directly. A missing file yields a fresh checkpoint. An unreadable or
        corrupt file also yields a fresh checkpoint, with a printed warning.
    """
    if path is None:
        path = CHECKPOINT_FILE
    if not path or not os.path.exists(path):
        return _fresh_checkpoint()

    try:
        with open(path, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers json.JSONDecodeError
        print(f"   WARNING: could not read checkpoint {path} ({e}); starting from scratch")
        return _fresh_checkpoint()
    if not isinstance(loaded, dict):
        print(f"   WARNING: checkpoint {path} is not a JSON object; starting from scratch")
        return _fresh_checkpoint()

    cp = _fresh_checkpoint()
    cp.update(loaded)
    # Merge per-game counters so a game absent from an older checkpoint starts at 1.
    cp["batch_numbers"] = {
        **_fresh_checkpoint()["batch_numbers"],
        **(loaded.get("batch_numbers") or {}),
    }
    cp["processed_zips"] = list(cp["processed_zips"] or [])
    # JSON stores these sets as lists; restore them for O(1) membership tests.
    cp["processed_clips"] = set(cp["processed_clips"] or [])
    cp["bad_files"] = set(cp["bad_files"] or [])
    print(f"üì• Checkpoint: {cp.get('total_processed', 0)} done, {len(cp.get('bad_files', []))} bad")
    return cp

def save_checkpoint(cp, path=None):
    """Persist checkpoint *cp* as JSON.

    Args:
        cp: checkpoint dict as returned by load_checkpoint(). Its
            processed_clips / bad_files sets are written as JSON lists and a
            "last_save" timestamp is added to the stored copy.
        path: destination file. Defaults to CHECKPOINT_FILE; when that is
            still None (setup_colab() has not run) nothing is written.

    The JSON is first written to "<path>.tmp" and then moved over *path* with
    os.replace. An interrupted write therefore leaves the previous checkpoint
    in place instead of a truncated file that load_checkpoint cannot parse
    (os.replace is atomic on POSIX filesystems; it is assumed to behave
    similarly on the Drive mount). Failures print a warning instead of
    raising, keeping checkpointing best-effort.
    """
    if path is None:
        path = CHECKPOINT_FILE
    if not path:
        return
    with checkpoint_lock:
        cp_save = dict(cp)
        cp_save["processed_clips"] = list(cp.get("processed_clips", set()))
        cp_save["bad_files"] = list(cp.get("bad_files", set()))
        cp_save["last_save"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        tmp_path = path + ".tmp"
        try:
            os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(cp_save, f)
            os.replace(tmp_path, path)
        except (OSError, TypeError, ValueError) as e:
            print(f"   WARNING: could not save checkpoint to {path}: {e}")

# ================= üöÄ PARALLEL BATCH PROCESSING =================
def process_single_batch(batch_id, file_list, staging_folder, pbar):
    """Process a single batch - called in parallel."""
    if not file_list:
        return [], [], batch_id

    temp_output = f"/content/demucs_p{batch_id}_{int(time.time())}"
    timeout = DEMUCS_TIMEOUT_PER_FILE * len(file_list) + 120

    cmd = [
        sys.executable, "-m", "demucs",
        "-d", DEVICE,
        "-j", str(NUM_JOBS),
        "--two-stems=vocals",
        "-n", MODEL_NAME,
        "--mp3", "--mp3-bitrate", "320",
        "-o", temp_output
    ]
    cmd.extend(file_list)

    successful = []
    bad_files = []

    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)

        if result.returncode == 0:
            # Collect outputs
            for file_path in file_list:
                name = os.path.splitext(os.path.basename(file_path))[0]
                vocals = os.path.join(temp_output, MODEL_NAME, name, "vocals.mp3")
                no_vocals = os.path.join(temp_output, MODEL_NAME, name, "no_vocals.mp3")

                if os.path.exists(vocals) and os.path.exists(no_vocals):
                    with staging_lock:
                        clip_folder = os.path.join(staging_folder, name)
                        os.makedirs(clip_folder, exist_ok=True)
                        shutil.copy2(vocals, os.path.join(clip_folder, "vocals.mp3"))
                        shutil.copy2(no_vocals, os.path.join(clip_folder, "no_vocals.mp3"))
                    successful.append(name)
                    pbar.update(1)
                else:
                    bad_files.append(name)
                    pbar.update(1)
        else:
            # Check for bad file errors
            stderr = result.stderr or ""
            if any(err.lower() in stderr.lower() for err in BAD_FILE_ERRORS):
                # Process one by one to find bad file
                for file_path in file_list:
                    name = os.path.splitext(os.path.basename(file_path))[0]
                    single_out = f"/content/demucs_single_{batch_id}_{int(time.time())}"

                    single_cmd = [
                        sys.executable, "-m", "demucs",
                        "-d", DEVICE, "-j", "1",
                        "--two-stems=vocals", "-n", MODEL_NAME,
                        "--mp3", "-o", single_out, file_path
                    ]

                    try:
                        single_result = subprocess.run(single_cmd, capture_output=True,
                                                       text=True, timeout=60)
                        if single_result.returncode == 0:
                            vocals = os.path.join(single_out, MODEL_NAME, name, "vocals.mp3")
                            no_vocals = os.path.join(single_out, MODEL_NAME, name, "no_vocals.mp3")
                            if os.path.exists(vocals) and os.path.exists(no_vocals):
                                with staging_lock:
                                    clip_folder = os.path.join(staging_folder, name)
                                    os.makedirs(clip_folder, exist_ok=True)
                                    shutil.copy2(vocals, os.path.join(clip_folder, "vocals.mp3"))
                                    shutil.copy2(no_vocals, os.path.join(clip_folder, "no_vocals.mp3"))
                                successful.append(name)
                            else:
                                bad_files.append(name)
                        else:
                            bad_files.append(name)
                    except:
                        bad_files.append(name)
                    finally:
                        if os.path.exists(single_out):
                            shutil.rmtree(single_out, ignore_errors=True)
                    pbar.update(1)
            else:
                for f in file_list:
                    bad_files.append(os.path.splitext(os.path.basename(f))[0])
                    pbar.update(1)

    except subprocess.TimeoutExpired:
        for f in file_list:
            bad_files.append(os.path.splitext(os.path.basename(f))[0])
            pbar.update(1)
    except Exception as e:
        for f in file_list:
            bad_files.append(os.path.splitext(os.path.basename(f))[0])
            pbar.update(1)
    finally:
        if os.path.exists(temp_output):
            shutil.rmtree(temp_output, ignore_errors=True)

    return successful, bad_files, batch_id


def process_parallel_batches(file_list, staging_folder, checkpoint):
    """Separate *file_list* in DEMUCS_BATCH_SIZE chunks running concurrently.

    Up to NUM_PARALLEL_BATCHES chunks run at once, each through
    process_single_batch in a worker thread, sharing one progress bar.
    Names of files reported bad are also added to checkpoint["bad_files"].

    Returns:
        (successful_clip_names, bad_clip_names) aggregated over all chunks.
        A chunk whose worker raised is reported and contributes to neither.
    """
    if not file_list:
        return [], []

    all_successful = []
    all_bad = []

    chunks = [file_list[i:i + DEMUCS_BATCH_SIZE]
              for i in range(0, len(file_list), DEMUCS_BATCH_SIZE)]

    # Create progress bar for all files in super-batch
    with tqdm(total=len(file_list), desc="   ‚ö° Processing", leave=False,
              bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
        with ThreadPoolExecutor(max_workers=NUM_PARALLEL_BATCHES) as executor:
            futures = [executor.submit(process_single_batch, idx, chunk, staging_folder, pbar)
                       for idx, chunk in enumerate(chunks)]

            for future in as_completed(futures):
                try:
                    successful, bad_files, _ = future.result()
                except Exception as e:
                    tqdm.write(f"   ‚ùå Batch error: {e}")
                    continue
                all_successful.extend(successful)
                all_bad.extend(bad_files)
                # Results are consumed on the calling thread, so the checkpoint
                # set is only mutated here, not from the workers.
                checkpoint["bad_files"].update(bad_files)

    return all_successful, all_bad


def count_staged(staging_folder):
    """Return how many clip sub-folders of *staging_folder* contain vocals.mp3.

    A staging folder that does not exist counts as 0.
    """
    if not os.path.exists(staging_folder):
        return 0
    return sum(
        1
        for entry in os.listdir(staging_folder)
        if os.path.isdir(os.path.join(staging_folder, entry))
        and os.path.exists(os.path.join(staging_folder, entry, "vocals.mp3"))
    )


def flush_to_drive(game, batch_num, staging_folder, checkpoint):
    """Upload every staged clip as one zip archive and reset the staging folder.

    Only sub-folders of *staging_folder* containing vocals.mp3 are uploaded.
    The archive goes to OUTPUT_BASE_FOLDER/<game>_Separated/ and holds
    <clip>/vocals.mp3 and <clip>/no_vocals.mp3 for each clip.

    Clip names are added to checkpoint["processed_clips"] only after the
    archive has been written completely. If writing fails, the partial archive
    is deleted, the checkpoint is not touched, and the clips stay staged for
    the next attempt.

    Returns:
        batch_num + 1 after a successful upload (also stored in
        checkpoint["batch_numbers"][game] and saved). Otherwise batch_num,
        i.e. when nothing was staged or the upload failed.
    """
    if not os.path.exists(staging_folder):
        return batch_num

    clips = [d for d in os.listdir(staging_folder)
             if os.path.isdir(os.path.join(staging_folder, d)) and
             os.path.exists(os.path.join(staging_folder, d, "vocals.mp3"))]

    if not clips:
        return batch_num

    print(f"\nüì§ [{game}] Uploading batch {batch_num} ({len(clips)} clips)...")

    game_output = os.path.join(OUTPUT_BASE_FOLDER, f"{game}_Separated")
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    zip_path = os.path.join(game_output, f"{game}_Negative_Batch_{batch_num}_{len(clips)}clips_{timestamp}.zip")

    try:
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
            for clip in clips:
                clip_path = os.path.join(staging_folder, clip)
                zf.write(os.path.join(clip_path, "vocals.mp3"), f"{clip}/vocals.mp3")
                zf.write(os.path.join(clip_path, "no_vocals.mp3"), f"{clip}/no_vocals.mp3")
        size_mb = os.path.getsize(zip_path) / (1024**2)
    except Exception as e:
        print(f"   ‚ùå Error: {e}")
        # Don't leave a truncated archive on Drive next to the complete ones.
        try:
            os.remove(zip_path)
        except OSError:
            pass
        return batch_num

    print(f"   ‚úÖ Saved: {os.path.basename(zip_path)} ({size_mb:.1f} MB)")

    # The archive is complete, so these clips are now safely recorded.
    checkpoint["processed_clips"].update(clips)
    checkpoint["batch_numbers"][game] = batch_num + 1

    # If removal partly fails, leftovers would only be re-uploaded, never lost.
    shutil.rmtree(staging_folder, ignore_errors=True)
    os.makedirs(staging_folder, exist_ok=True)

    save_checkpoint(checkpoint)
    gc.collect()

    return batch_num + 1


# ================= 🎮 MAIN =================
def _find_audio_files(root_dir):
    """Return paths of .m4a/.mp3/.wav/.flac/.ogg files under *root_dir*.

    Extensions match case-insensitively; names are sorted within each
    directory visited by os.walk.
    """
    audio_files = []
    for root, _, files in os.walk(root_dir):
        for f in sorted(files):
            if f.lower().endswith(('.m4a', '.mp3', '.wav', '.flac', '.ogg')):
                audio_files.append(os.path.join(root, f))
    return audio_files


def _commit_if_flushed(staging, checkpoint, pending_zips):
    """Move *pending_zips* into checkpoint["processed_zips"] once staging is empty.

    A zip whose clips may still sit in local staging must not be recorded as
    processed yet: if the runtime died before the upload, its clips would be
    lost and the zip would never be revisited. Empties *pending_zips* when it
    commits them.
    """
    if pending_zips and count_staged(staging) == 0:
        checkpoint["processed_zips"].extend(pending_zips)
        pending_zips.clear()


def run_pipeline():
    """Run the negative-clip separation pipeline, resuming from the checkpoint.

    For each game in GAME_PROCESSING_ORDER, every input zip is extracted and
    its audio files are separated in parallel super-batches. The stems are
    staged locally and uploaded to Drive with flush_to_drive() whenever at
    least OUTPUT_BATCH_SIZE clips are staged, plus once at the end of each
    game.

    Bookkeeping:
      * clips: flush_to_drive() adds them to processed_clips when it zips
        them, not here right after separation;
      * zips: recorded in processed_zips only once all of their clips have
        left staging.
    After a crash, the next run therefore re-separates any clips that never
    reached Drive instead of skipping them.
    """
    global DEVICE

    print("=" * 70)
    print("üöÄ‚ö° ULTRA-FAST PARALLEL DEMUCS PIPELINE")
    print("=" * 70)
    print(f"‚è∞ Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"‚ö° {NUM_PARALLEL_BATCHES} parallel batches √ó {DEMUCS_BATCH_SIZE} files = {NUM_PARALLEL_BATCHES * DEMUCS_BATCH_SIZE} files at once!")
    print(f"üéØ Expected speed: ~0.8-1.2 sec per file (5x faster!)")
    print("=" * 70)

    if not setup_colab():
        return

    if not verify_gpu():
        DEVICE = "cpu"

    setup_folders()
    checkpoint = load_checkpoint()

    # Make sure the bookkeeping containers exist with the types used below
    # (sets for O(1) membership tests), whatever the checkpoint contained.
    for key in ("processed_clips", "bad_files"):
        if not isinstance(checkpoint.get(key), set):
            checkpoint[key] = set(checkpoint.get(key) or [])
    checkpoint.setdefault("processed_zips", [])
    checkpoint.setdefault("batch_numbers", {})

    os.makedirs(LOCAL_STAGING_BASE, exist_ok=True)

    # Scan game folders
    game_zips = {g: [] for g in GAME_PROCESSING_ORDER}

    for folder_name, game_name in GAME_FOLDERS.items():
        folder_path = os.path.join(INPUT_BASE_FOLDER, folder_name)
        if os.path.exists(folder_path):
            zips = sorted(glob.glob(os.path.join(folder_path, "*.zip")))
            game_zips[game_name] = zips
            print(f"üì¶ {game_name}: {len(zips)} zips")

    print("=" * 70)

    total_processed = checkpoint.get("total_processed", 0)
    total_skipped = checkpoint.get("total_skipped", 0)
    super_batch_size = DEMUCS_BATCH_SIZE * NUM_PARALLEL_BATCHES

    for game in GAME_PROCESSING_ORDER:
        zips = game_zips[game]
        if not zips:
            continue

        print(f"\n{'=' * 70}")
        print(f"üéÆ {game.upper()} ({len(zips)} zips)")
        print(f"{'=' * 70}")

        staging = os.path.join(LOCAL_STAGING_BASE, game)
        os.makedirs(staging, exist_ok=True)

        batch_num = checkpoint["batch_numbers"].get(game, 1)
        # Fully separated zips whose clips may still be waiting in staging.
        pending_zips = []

        for src_zip in tqdm(zips, desc=f"[{game}]"):
            zip_name = os.path.basename(src_zip)

            if src_zip in checkpoint["processed_zips"]:
                continue

            extract_dir = f"/content/temp_neg_{game}"
            if os.path.exists(extract_dir):
                shutil.rmtree(extract_dir)
            os.makedirs(extract_dir)

            try:
                tqdm.write(f"\nüìÇ {zip_name}")
                with zipfile.ZipFile(src_zip, 'r') as zf:
                    zf.extractall(extract_dir)

                audio_files = _find_audio_files(extract_dir)

                # Skip clips already uploaded or known to be bad
                done_clips = checkpoint["processed_clips"]
                bad_clips = checkpoint["bad_files"]
                files_to_process = [
                    f for f in audio_files
                    if os.path.splitext(os.path.basename(f))[0] not in done_clips
                    and os.path.splitext(os.path.basename(f))[0] not in bad_clips
                ]

                skipped = len(audio_files) - len(files_to_process)
                tqdm.write(f"   üìÑ {len(audio_files)} files ({skipped} skipped)")

                if not files_to_process:
                    # Nothing from this zip goes through staging, so it is done now.
                    checkpoint["processed_zips"].append(src_zip)
                    continue

                # Process in parallel super-batches
                total_sb = (len(files_to_process) + super_batch_size - 1) // super_batch_size

                for i in range(0, len(files_to_process), super_batch_size):
                    super_batch = files_to_process[i:i + super_batch_size]
                    sb_idx = (i // super_batch_size) + 1

                    tqdm.write(f"   üöÄ Super-batch {sb_idx}/{total_sb} ({len(super_batch)} files in {NUM_PARALLEL_BATCHES} parallel)")

                    start_time = time.time()
                    successful, bad = process_parallel_batches(super_batch, staging, checkpoint)
                    elapsed = time.time() - start_time

                    per_file = elapsed / len(super_batch) if super_batch else 0

                    total_processed += len(successful)
                    total_skipped += len(bad)

                    # Successful clips are deliberately NOT added to
                    # processed_clips here; flush_to_drive() records them when
                    # it zips them for upload.

                    tqdm.write(f"   ‚úÖ {len(successful)} done ({per_file:.2f}s/file)" +
                              (f", ‚ùå {len(bad)} bad" if bad else ""))

                    # Flush if staging is full
                    if count_staged(staging) >= OUTPUT_BATCH_SIZE:
                        batch_num = flush_to_drive(game, batch_num, staging, checkpoint)
                        _commit_if_flushed(staging, checkpoint, pending_zips)

                    checkpoint["total_processed"] = total_processed
                    checkpoint["total_skipped"] = total_skipped
                    save_checkpoint(checkpoint)

                pending_zips.append(src_zip)
                _commit_if_flushed(staging, checkpoint, pending_zips)
                save_checkpoint(checkpoint)

            except zipfile.BadZipFile:
                tqdm.write(f"   ‚ùå Corrupted zip")
            except Exception as e:
                tqdm.write(f"   ‚ùå Error: {e}")
            finally:
                if os.path.exists(extract_dir):
                    shutil.rmtree(extract_dir, ignore_errors=True)

        if count_staged(staging) > 0:
            batch_num = flush_to_drive(game, batch_num, staging, checkpoint)
        _commit_if_flushed(staging, checkpoint, pending_zips)
        save_checkpoint(checkpoint)
        if pending_zips:
            print(f"   WARNING: [{game}] final upload failed; {len(pending_zips)} zip(s) "
                  "stay unmarked and will be re-scanned on the next run")

        print(f"‚úÖ [{game}] Done!")

    print("\n" + "=" * 70)
    print("üéâ COMPLETE!")
    print("=" * 70)
    print(f"‚úÖ Processed: {total_processed}")
    print(f"‚ùå Skipped: {total_skipped}")
    print(f"‚è∞ Finished: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    for g in GAME_PROCESSING_ORDER:
        out = os.path.join(OUTPUT_BASE_FOLDER, f"{g}_Separated")
        if os.path.exists(out):
            zips = glob.glob(os.path.join(out, "*.zip"))
            mb = sum(os.path.getsize(z) for z in zips) / (1024**2)
            print(f"   üéÆ {g}: {len(zips)} batches ({mb:.1f} MB)")

# ================= 🏁 RUN =================
# Start the full pipeline when this module/cell is executed directly.
if __name__ == "__main__":
    run_pipeline()