# Drum Sample Auto-Classifier - Complete Archive Edition (Optimized)

This notebook classifies large drum/percussion archives with a trained model while leaving the original archive untouched. Enhanced features include batching, a hook for parallel feature extraction, duplicate detection via content hashing, optional confidence filtering, flexible copy modes, and reproducible run metadata.

## 🔍 Key Enhancements
- **Batch inference** with configurable `BATCH_SIZE`
- **Parallel feature extraction hook** (`THREAD_WORKERS`; extraction currently runs sequentially within each batch)
- **MD5 hashing for de-duplication** across runs (optional)
- **Adjustable output mode**: copy, symlink, or skip materialization
- **Confidence threshold filtering**
- **Structure preservation**: mirror the archive's directory tree (`MIRROR_STRUCTURE`), or embed the relative path in flat output names (`PRESERVE_STRUCTURE`)
- **Caching & metadata** (summary, per-file results, error log)
- **Resource awareness** (optional system memory display)

## 🧩 Workflow Overview
1. Configure paths and parameters.
2. Discover audio files recursively (multi-format).
3. Preprocess (resample to `TARGET_SR`, pad or truncate to `TARGET_SAMPLES`, MFCC extraction).
4. Batch predict with loaded Keras model.
5. Deduplicate & materialize outputs (optional).
6. Summarize + persist metadata for reproducibility.
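
Step 4 splits the discovered files into consecutive fixed-size chunks, so the number of batches is the ceiling of files / `BATCH_SIZE`. A minimal sketch of the same list slicing the classification cell uses (`chunk` is a hypothetical helper name), checked against the 47,339 files found in this run:

```python
import math
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def chunk(items: Sequence[T], batch_size: int) -> List[Sequence[T]]:
    """Split items into consecutive slices of at most batch_size elements."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


files = list(range(47339))  # stand-in for the discovered paths
batches = chunk(files, 32)
print(len(batches), math.ceil(len(files) / 32), len(batches[-1]))  # 1480 1480 11
```

Only the final batch is short (47,339 = 1,479 × 32 + 11), which matches the "1480 batches of up to 32" reported by the classification cell.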

## ✅ Prerequisites
Run (in order) before this notebook if retraining:
1. `MFCC_Feature_Extractor.ipynb`
2. `Model1_Train.ipynb` or `Model2_Train.ipynb`
3. (Optional) `Model_Evaluation.ipynb` to compare models.

## ⚙️ Configuration Highlights
| Parameter | Purpose | Example |
|-----------|---------|---------|
| `ARCHIVE_PATH` | Source root (read-only) | `../complete_drum_archive` |
| `RUN_OUTPUT_DIR` | New classified run directory | `../ClassifiedArchive/run_<ts>` |
| `COPY_MODE` | Output behavior | `copy`, `symlink`, or `none` |
| `DEDUP_HASH` | Skip files with seen content | `True` |
| `CONFIDENCE_THRESHOLD` | Minimum probability to emit | `0.0–1.0` |
| `BATCH_SIZE` | Prediction batch size | `32` |
| `INSTRUMENT_NAMES` | Class label mapping | `[Crash,...,Tom]` |
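
A small pre-flight check can catch inconsistent settings before a multi-minute run. `validate_config` and `ALLOWED_COPY_MODES` are hypothetical names, not part of the notebook. The last check matters here: `TARGET_LABELS` is compared case-sensitively against the resolved label names, and this run resolved to generic `class_<idx>` labels.

```python
from typing import List, Optional, Sequence

ALLOWED_COPY_MODES = {"copy", "symlink", "none"}


def validate_config(copy_mode: str, confidence_threshold: float,
                    misc_confidence_threshold: float, batch_size: int,
                    target_labels: Optional[Sequence[str]],
                    labels: Sequence[str]) -> List[str]:
    """Return human-readable problems; an empty list means the settings look consistent."""
    problems = []
    if copy_mode not in ALLOWED_COPY_MODES:
        problems.append(f"COPY_MODE must be one of {sorted(ALLOWED_COPY_MODES)}, got {copy_mode!r}")
    for name, value in (("CONFIDENCE_THRESHOLD", confidence_threshold),
                        ("MISC_CONFIDENCE_THRESHOLD", misc_confidence_threshold)):
        if not 0.0 <= value <= 1.0:
            problems.append(f"{name} must lie in [0, 1], got {value}")
    if batch_size < 1:
        problems.append(f"BATCH_SIZE must be >= 1, got {batch_size}")
    if target_labels is not None:
        # Exact (case-sensitive) membership, matching the check in classify_batch.
        missing = [t for t in target_labels if t not in labels]
        if missing:
            problems.append(f"TARGET_LABELS not produced by the model: {missing}")
    return problems


generic = [f"class_{i}" for i in range(34)]
drums = ["Crash", "Hihat", "Kick", "Ride", "Snare", "Tom"]
for problem in validate_config("copy", 0.0, 0.50, 32, drums, generic):
    print("⚠️", problem)
```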

## 📦 Outputs
Each run creates:
- Classified files in per-instrument folders, mirrored under the original directory tree when `MIRROR_STRUCTURE = True` (skipped when `COPY_MODE = "none"`)
- `metadata/summary.json` (run-wide stats & config)
- `metadata/results.json` (per-file predictions)
- `metadata/errors.json` (only if failures occurred)
- Updated hash cache for dedupe if enabled.
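
`results.json` is a flat list of per-file records. Based on the keys the classification cell writes (`file`, `pred_label`, `confidence`, `probs`, `hash`, plus optional flags such as `skipped_duplicate`, `below_conf_threshold` and `output_path`), a run can be audited offline with standard-library Python. `summarize_records` is a hypothetical helper, and the inline records are made-up stand-ins rather than results from this archive.

```python
import json
from collections import Counter
from typing import Any, Dict, List


def summarize_records(records: List[Dict[str, Any]]) -> Counter:
    """Bucket per-file records by the flags the classification cell sets."""
    counts: Counter = Counter()
    for rec in records:
        if rec.get("skipped_duplicate"):
            counts["duplicate"] += 1
        elif "output_path" in rec:
            counts["emitted"] += 1
        elif rec.get("below_conf_threshold"):
            counts["below_threshold"] += 1
        else:
            counts["recorded_only"] += 1  # e.g. COPY_MODE = "none"
    return counts


records = json.loads("""[
  {"file": "a.wav", "pred_label": "class_10", "confidence": 0.91, "output_path": "out/a.wav"},
  {"file": "b.wav", "pred_label": "class_0", "confidence": 0.40, "skipped_duplicate": true},
  {"file": "c.wav", "pred_label": "class_3", "confidence": 0.70}
]""")
counts = summarize_records(records)
print(dict(counts))
```

Against a real run, load the list with `json.loads((RUN_OUTPUT_DIR / 'metadata' / 'results.json').read_text())` instead of the inline string.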

## 🔁 Reproducibility
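Every execution writes to a new timestamped run directory (`run_<YYYYmmdd_HHMMSS>`), so previous results are never overwritten. With `DEDUP_HASH` enabled, the MD5 digest of every successfully classified file is persisted to `.cache/archive_classifier/seen_hashes.txt`; a file whose bytes were already seen, earlier in the same run or in any previous run, is recorded with `skipped_duplicate` and not emitted again. The duplicate check happens after prediction, so duplicates are still decoded and scored; only their output is skipped. Because the cache persists, re-running on an already processed archive emits nothing until the cache file is deleted.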
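
The hash-cache round trip can be sketched with the standard library alone. This mirrors the chunked MD5 of `hash_file` and the one-digest-per-line cache format; `md5_of`, `load_seen`, `save_seen` and `new_files` are hypothetical names, and the two "audio" files are fake byte strings in a temporary directory.

```python
import hashlib
import tempfile
from pathlib import Path
from typing import Iterable, List, Set


def md5_of(path: Path, block_size: int = 65536) -> str:
    """Chunked MD5 of the raw file bytes (same approach as hash_file)."""
    hasher = hashlib.md5()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(block_size), b""):
            hasher.update(chunk)
    return hasher.hexdigest()


def load_seen(cache: Path) -> Set[str]:
    """Read one hex digest per line; a missing cache means nothing has been seen yet."""
    if not cache.exists():
        return set()
    return {line.strip() for line in cache.read_text().splitlines() if line.strip()}


def save_seen(cache: Path, seen: Set[str]) -> None:
    cache.write_text("\n".join(sorted(seen)))


def new_files(paths: Iterable[Path], seen: Set[str]) -> List[str]:
    """Return names of files whose content hash is unseen, updating `seen` in place."""
    fresh = []
    for p in paths:
        digest = md5_of(p)
        if digest not in seen:
            seen.add(digest)
            fresh.append(p.name)
    return fresh


with tempfile.TemporaryDirectory() as tmp_name:
    root = Path(tmp_name)
    (root / "kick_a.wav").write_bytes(b"RIFF fake kick")
    (root / "kick_copy.wav").write_bytes(b"RIFF fake kick")  # identical bytes, new name
    files = sorted(root.glob("*.wav"))
    cache = root / "seen_hashes.txt"

    seen = load_seen(cache)
    run1 = new_files(files, seen)               # first run: only one of the twins is new
    save_seen(cache, seen)
    run2 = new_files(files, load_seen(cache))   # second run: nothing is new

print(run1, run2)
```

Hashing raw bytes means a re-encoded or trimmed copy of the same sample is not treated as a duplicate; only byte-identical files are.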

---
Proceed to the next cell to load dependencies and configure runtime parameters.

In [1]:
import os
import glob
import json
import time
import hashlib
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Any, Tuple
import shutil

import numpy as np
import pandas as pd
import librosa
import librosa.display
import keras
import soundfile as sf
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm.auto import tqdm

# ---- Environment Introspection ----
# Fail gracefully if optional deps missing
try:
    import psutil  # optional system stats
except ImportError:
    psutil = None

print("✅ Imports loaded. Optional modules: psutil={}".format(psutil is not None))

✅ Imports loaded. Optional modules: psutil=True




In [None]:
# Configuration & File Discovery Utilities
# ========================================

# Mirror-mode enhancement: replicate the original directory tree and place classified samples
# into per-instrument subfolders under each original directory. Low-confidence or non-target
# predictions go into a "misc" subfolder. Pre-sorted leaves (directory name matches a target label)
# keep their files directly inside the mirrored directory, without a nested class subfolder;
# the files are still renamed with the predicted label and confidence.
# NOTE: Added dynamic label mapping + optional external mapping file.

# You can create a JSON file at ../models/label_mapping.json containing a list of class names.
# Example label_mapping.json:
# [
#   "Agogo", "Bell", "Bongo", ...
# ]

# Label resolution order: label_mapping.json (if its length equals the number of model outputs),
# then INSTRUMENT_NAMES (if its length matches), otherwise generic class_<idx> names.

# To restrict which classes get materialized to disk (e.g. only core drum kit pieces),
# set TARGET_LABELS = ["Crash","Hihat","Kick","Ride","Snare","Tom"] (case-sensitive).
# TARGET_LABELS must use names from the resolved label mapping: with generic class_<idx>
# names, no prediction matches a drum name and every file is routed to misc.
from collections import defaultdict

DEFAULT_ARCHIVE_PATH = Path('../complete_drum_archive')  # adjust if needed
DEFAULT_OUTPUT_ROOT = Path('../ClassifiedArchive')
MODELS_DIR = Path('../models')

# Supported formats prioritized by typical drum sample usage
SUPPORTED_FORMATS = [".wav", ".flac", ".aiff", ".aif", ".mp3"]  # remove ".mp3" to avoid audioread fallback warnings
DISABLE_MP3 = False  # set True to skip mp3 entirely (avoid audioread path)
MAX_FILES_PER_RUN = None          # None = process all
BATCH_SIZE = 32                   # in-memory inference batch size
N_MFCC = 40
TARGET_SR = 44100
TARGET_SAMPLES = 50_000
N_FFT = 2048
HOP_LENGTH = 512
NORMALIZE = True
THREAD_WORKERS = 8                # for parallel feature extraction (future enhancement hook)
COPY_MODE = "copy"                # one of: copy | symlink | none
PRESERVE_STRUCTURE = True         # legacy filename embedding (ignored if MIRROR_STRUCTURE=True)
MIRROR_STRUCTURE = True           # replicate directory tree and classify into subfolders
DEDUP_HASH = True                 # skip files with identical content hash
MIN_DURATION_SEC = 0.05           # skip extremely short blips
MISC_CONFIDENCE_THRESHOLD = 0.50  # confidence below this -> misc/
CONFIDENCE_THRESHOLD = 0.0        # absolute floor for emitting a file; lower-confidence predictions are still recorded (keep 0.0)
MISC_LABEL_NAME = 'misc'
LABEL_MAP_FILE = MODELS_DIR / 'label_mapping.json'  # optional; list of label names
TARGET_LABELS = ["Crash","Hihat","Kick","Ride","Snare","Tom"]  # subset to actually emit; set to None to allow all
CACHE_DIR = Path('.cache/archive_classifier')
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Legacy constant kept for backward compatibility; will be overwritten dynamically if mapping found
INSTRUMENT_NAMES = ["Crash", "Hihat", "Kick", "Ride", "Snare", "Tom"]  # fallback subset

ARCHIVE_PATH = DEFAULT_ARCHIVE_PATH
TIMESTAMP = datetime.now().strftime('%Y%m%d_%H%M%S')
RUN_OUTPUT_DIR = (DEFAULT_OUTPUT_ROOT / f'run_{TIMESTAMP}').resolve()
(RUN_OUTPUT_DIR / 'metadata').mkdir(parents=True, exist_ok=True)

print(f"Archive: {ARCHIVE_PATH}")
print(f"Run output: {RUN_OUTPUT_DIR}")

def discover_audio_files(archive_path: Path, formats: List[str], max_files=None) -> List[Path]:
    """Discover audio files with optional limit."""
    files: List[Path] = []
    for ext in formats:
        if DISABLE_MP3 and ext.lower() == '.mp3':
            continue
        files.extend(archive_path.rglob(f'*{ext}'))
    files.sort()
    if max_files is not None and len(files) > max_files:
        files = files[:max_files]
    return files

def filter_candidates(paths: List[Path]) -> List[Path]:
    filtered = []
    for p in paths:
        try:
            if p.stat().st_size == 0:
                continue
            filtered.append(p)
        except OSError:
            continue
    return filtered

audio_files = filter_candidates(discover_audio_files(ARCHIVE_PATH, SUPPORTED_FORMATS, MAX_FILES_PER_RUN))
print(f"Found {len(audio_files)} candidate audio files.")
for preview in audio_files[:5]:
    print('  •', preview.relative_to(ARCHIVE_PATH))

Archive: ../complete_drum_archive
Run output: /Users/Gilby/Projects/MLAudioClassifier/ClassifiedArchive/run_20251005_152941
Found 47339 candidate audio files.
  • Access/Access Virus - B/BassDrum_01.wav
  • Access/Access Virus - B/BassDrum_02.wav
  • Access/Access Virus - B/BassDrum_03.wav
  • Access/Access Virus - B/BassDrum_04.wav
  • Access/Access Virus - B/BassDrum_05.wav
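
`discover_audio_files` walks the archive once per extension, and on POSIX platforms pathlib's glob matching is case-sensitive by default, so a sample named `KICK.WAV` may not match `*.wav`. A minimal single-pass alternative, assuming the same `SUPPORTED_FORMATS` list, matches suffixes case-insensitively and folds in the empty-file filter; `discover_ci` is a hypothetical name, not the notebook's function.

```python
import tempfile
from pathlib import Path
from typing import Iterable, List, Optional


def discover_ci(root: Path, formats: Iterable[str], max_files: Optional[int] = None) -> List[Path]:
    """One recursive walk; match extensions case-insensitively and skip empty files."""
    wanted = {ext.lower() for ext in formats}
    found = sorted(
        p for p in root.rglob("*")
        if p.is_file() and p.suffix.lower() in wanted and p.stat().st_size > 0
    )
    return found if max_files is None else found[:max_files]


with tempfile.TemporaryDirectory() as tmp_name:
    root = Path(tmp_name)
    kit = root / "Kit"
    kit.mkdir()
    (kit / "Kick_01.WAV").write_bytes(b"x")   # upper-case extension
    (kit / "snare.wav").write_bytes(b"x")
    (kit / "empty.wav").write_bytes(b"")      # zero bytes: filtered out
    (kit / "notes.txt").write_text("not audio")
    found = [p.relative_to(root).as_posix() for p in discover_ci(root, [".wav", ".flac"])]

print(found)
```

The trade-off is that `rglob("*")` stats every file in the tree once, instead of relying on one glob per extension.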


In [3]:
# Model Loading & Feature Extraction
# ==================================

import warnings
warnings.filterwarnings("ignore", category=UserWarning, module='librosa')  # silence repetitive backend fallbacks

def load_latest_model(model_dir: Path, pattern='model1.keras') -> keras.Model:
    candidates = sorted(model_dir.glob(pattern))
    if not candidates:
        # fallback: any *.keras
        candidates = sorted(model_dir.glob('*.keras'))
    candidates = sorted(candidates, key=lambda p: p.stat().st_mtime, reverse=True)
    if not candidates:
        raise FileNotFoundError(f"No model files in {model_dir}")
    print(f"Loading model: {candidates[0]}")
    return keras.models.load_model(candidates[0])

def load_label_mapping(model_obj: keras.Model) -> List[str]:
    # Try explicit mapping file first
    if LABEL_MAP_FILE.exists():
        try:
            data = json.loads(LABEL_MAP_FILE.read_text())
            if isinstance(data, list) and len(data) == model_obj.output_shape[-1]:
                print(f"Using label mapping from {LABEL_MAP_FILE}")
                return data
            else:
                print(f"⚠️ label_mapping.json length mismatch ({len(data)} vs {model_obj.output_shape[-1]}). Ignoring.")
        except Exception as e:
            print(f"⚠️ Failed to parse label mapping file: {e}")
    # Fall back to INSTRUMENT_NAMES subset if compatible
    out_dim = model_obj.output_shape[-1]
    if len(INSTRUMENT_NAMES) == out_dim:
        print("Using fallback INSTRUMENT_NAMES as full mapping.")
        return INSTRUMENT_NAMES
    # Generic numbered classes
    print("Generating generic class_<idx> mapping.")
    return [f'class_{i}' for i in range(out_dim)]

def hash_file(path: Path, block_size=65536) -> str:
    hasher = hashlib.md5()
    with path.open('rb') as f:
        for chunk in iter(lambda: f.read(block_size), b''):
            hasher.update(chunk)
    return hasher.hexdigest()

def safe_load_audio(path: Path):
    """Robust loader: try soundfile first, fallback to librosa/audioread; return (y, sr) or raise."""
    try:
        y, sr = sf.read(path)
        if y.ndim > 1:
            y = np.mean(y, axis=1)
        if sr != TARGET_SR:
            y = librosa.resample(y, orig_sr=sr, target_sr=TARGET_SR)
            sr = TARGET_SR
        return y, sr
    except Exception:
        # fallback to librosa unified loader
        y, sr = librosa.load(path, sr=TARGET_SR, mono=True)
        return y, sr

def load_and_preprocess(path: Path) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Load audio and compute MFCC feature tensor plus metadata."""
    try:
        y, sr = safe_load_audio(path)
        if len(y) < MIN_DURATION_SEC * sr:
            return None, {'error': 'too_short'}
        if len(y) < TARGET_SAMPLES:
            y = librosa.util.fix_length(y, size=TARGET_SAMPLES)
        else:
            y = y[:TARGET_SAMPLES]
        mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=N_MFCC, n_fft=N_FFT, hop_length=HOP_LENGTH)
        if NORMALIZE:
            mfcc = librosa.util.normalize(mfcc)
        return mfcc, {'error': None, 'orig_sr': sr}
    except Exception as e:
        return None, {'error': str(e)}

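def batch_tensorize(feature_list: List[np.ndarray]) -> np.ndarray:
    """Stack per-file MFCCs into (batch, n_mfcc, time).

    A trailing channel axis is added only when the loaded model expects 4-D input
    (e.g. a Conv2D model); 3-D models receive the stacked array unchanged.
    """
    arr = np.stack(feature_list, axis=0)
    if arr.ndim == 3 and len(model.input_shape) == 4:
        arr = arr[..., np.newaxis]
    return arr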

model = load_latest_model(MODELS_DIR)
LABELS = load_label_mapping(model)
print(f"Model outputs {len(LABELS)} classes.")
print("First labels:", LABELS[:10])
INPUT_SHAPE = model.input_shape
print('Model input shape:', INPUT_SHAPE)

# Optional device memory stats
if psutil:
    vm = psutil.virtual_memory()
    print(f"System Memory: {vm.total/1e9:.2f} GB total / {vm.available/1e9:.2f} GB free")

Loading model: ../models/model1.keras
Generating generic class_<idx> mapping.
Model outputs 34 classes.
First labels: ['class_0', 'class_1', 'class_2', 'class_3', 'class_4', 'class_5', 'class_6', 'class_7', 'class_8', 'class_9']
Model input shape: (None, 40, 98)
System Memory: 17.18 GB total / 4.45 GB free
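
The reported input shape `(None, 40, 98)` follows from the preprocessing constants, assuming librosa's default `center=True` framing (the MFCC call above does not override it): the signal is padded by `n_fft // 2` samples on each side, which yields `1 + TARGET_SAMPLES // HOP_LENGTH` frames. A pure-arithmetic check, with `n_frames` as a hypothetical helper:

```python
def n_frames(n_samples: int, n_fft: int, hop_length: int, center: bool = True) -> int:
    """Frame count for librosa-style framing of a 1-D signal."""
    if center:
        n_samples += 2 * (n_fft // 2)  # centred framing pads n_fft // 2 on each side
    return 1 + (n_samples - n_fft) // hop_length


target_samples, target_sr, n_fft, hop_length, n_mfcc = 50_000, 44_100, 2048, 512, 40
frames = n_frames(target_samples, n_fft, hop_length)
print((n_mfcc, frames), round(target_samples / target_sr, 3))  # (40, 98) 1.134
```

Each example therefore covers about 1.13 s of audio; changing `TARGET_SAMPLES`, `HOP_LENGTH` or `center` changes the time axis and would no longer fit a model trained on 98 frames.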


In [None]:
# Optimized Parallel Classification & Output
# ==========================================

results: List[Dict[str, Any]] = []
errors: List[Dict[str, Any]] = []
error_causes: Dict[str,int] = {}
hash_cache_path = CACHE_DIR / 'seen_hashes.txt'
seen_hashes = set()
if DEDUP_HASH and hash_cache_path.exists():
    seen_hashes.update(h.strip() for h in hash_cache_path.read_text().splitlines() if h.strip())

def is_presorted_leaf(path: Path) -> bool:
    """A directory whose name already matches a target label (case-insensitive)."""
    if TARGET_LABELS is None:
        return False
    return path.is_dir() and path.name.lower() in {t.lower() for t in TARGET_LABELS}

def mirror_destination(src_file: Path, pred_label: str, conf: float) -> Path:
    """Return destination path under MIRROR_STRUCTURE rules."""
    relative_parent = src_file.parent.relative_to(ARCHIVE_PATH)
    parent_out_dir = RUN_OUTPUT_DIR / relative_parent
    # create parent mirror
    parent_out_dir.mkdir(parents=True, exist_ok=True)
    # If parent is already a pre-sorted leaf, just drop file straight inside (no nested class folder)
    if is_presorted_leaf(src_file.parent):
        out_dir = parent_out_dir
    else:
        # classification subfolder or misc
        out_dir = parent_out_dir / pred_label.lower()
        out_dir.mkdir(exist_ok=True)
    base = src_file.name
    name_no_ext, ext = os.path.splitext(base)
    new_name = f"{name_no_ext}__{pred_label.lower()}_{conf:.3f}{ext}"
    return out_dir / new_name

def ensure_output_subdir(label: str) -> Path:
    d = RUN_OUTPUT_DIR / label
    d.mkdir(parents=True, exist_ok=True)
    return d

def emit_file(pred_label: str, conf: float, src: Path):
    if MIRROR_STRUCTURE:
        dst = mirror_destination(src, pred_label, conf)
    else:
        relative_tag = src.relative_to(ARCHIVE_PATH) if PRESERVE_STRUCTURE else src.name
        relative_tag = str(relative_tag).replace('/', '_').replace('\\', '_')
        out_name = f"{pred_label.lower()}_{conf:.3f}_{relative_tag}"
        dst_dir = ensure_output_subdir(pred_label)
        dst = dst_dir / out_name
    if COPY_MODE == 'copy':
        shutil.copy2(src, dst)
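    elif COPY_MODE == 'symlink':
        # Skip if a file or (possibly dangling) link already exists at dst.
        if not (dst.exists() or dst.is_symlink()):
            # Use the absolute source path: a relative target (ARCHIVE_PATH is relative)
            # would be resolved against the link's own directory and dangle.
            os.symlink(src.resolve(), dst)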
    return dst

def classify_batch(batch_paths: List[Path]) -> None:
    feats = []
    valid_paths = []
    for p in batch_paths:
        feat, meta = load_and_preprocess(p)
        if feat is None:
            errors.append({'file': str(p), 'error': meta['error']})
            error_causes[meta['error']] = error_causes.get(meta['error'],0)+1
            continue
        feats.append(feat)
        valid_paths.append(p)
    if not feats:
        return
    X = batch_tensorize(feats)
    probs = model.predict(X, verbose=0)
    for i, p in enumerate(valid_paths):
        prob_vec = probs[i]
        label_idx = int(np.argmax(prob_vec))
        conf = float(np.max(prob_vec))
        pred_label = LABELS[label_idx] if label_idx < len(LABELS) else f'class_{label_idx}'
        file_rec = {
            'file': str(p),
            'pred_label': pred_label,
            'confidence': conf,
            'probs': prob_vec.tolist(),
            'hash': None
        }
        if DEDUP_HASH:
            h = hash_file(p)
            file_rec['hash'] = h
            if h in seen_hashes:
                file_rec['skipped_duplicate'] = True
                results.append(file_rec)
                continue
            seen_hashes.add(h)
        # Skip out-of-target for emission, but still record
        if TARGET_LABELS is not None and pred_label not in TARGET_LABELS:
            file_rec['filtered_out_of_target_set'] = True
            # treat as misc candidate if MIRROR_STRUCTURE: we still store under misc
            misc_label = MISC_LABEL_NAME
            if MIRROR_STRUCTURE and COPY_MODE != 'none':
                out_path = emit_file(misc_label, conf, p)
                file_rec['output_path'] = str(out_path)
                file_rec['relabelled_to_misc'] = True
            results.append(file_rec)
            continue
        # Low confidence -> misc bucket (mirror mode)
        emit_label = pred_label
        if MIRROR_STRUCTURE and conf < MISC_CONFIDENCE_THRESHOLD:
            emit_label = MISC_LABEL_NAME
            file_rec['relabelled_low_conf_to_misc'] = True
        if conf >= CONFIDENCE_THRESHOLD:
            if COPY_MODE != 'none':
                out_path = emit_file(emit_label, conf, p)
                file_rec['output_path'] = str(out_path)
        else:
            file_rec['below_conf_threshold'] = True
        results.append(file_rec)

# Chunk audio files
BATCHED = [audio_files[i:i+BATCH_SIZE] for i in range(0, len(audio_files), BATCH_SIZE)]
print(f"Processing {len(audio_files)} files in {len(BATCHED)} batches of up to {BATCH_SIZE}.")

start_time = time.time()
for batch in tqdm(BATCHED, desc='Classifying'):
    classify_batch(batch)
elapsed = time.time() - start_time
print(f"⏱️  Classification complete in {elapsed:.2f}s ({len(audio_files)/(elapsed+1e-9):.1f} files/sec)")

# Persist dedup hash cache
if DEDUP_HASH:
    with hash_cache_path.open('w') as f:
        f.write('\n'.join(sorted(seen_hashes)))
    print(f"✔️ Updated hash cache: {hash_cache_path}")

# Summaries (only target + misc if mirror mode)
successful = [r for r in results if r.get('pred_label')]
emitted = [r for r in results if 'output_path' in r]
class_counts: Dict[str,int] = {}
for r in emitted:
    lbl = Path(r['output_path']).parent.name if MIRROR_STRUCTURE else r['pred_label']
    class_counts[lbl] = class_counts.get(lbl, 0) + 1
print("Emitted distribution (folder-based):")
for k,v in sorted(class_counts.items(), key=lambda x: -x[1]):
    print(f"  {k:<20} {v}")
print(f"Errors: {len(errors)} | Emitted: {len(emitted)} | Total examined: {len(audio_files)}")
if error_causes:
    print("Top error causes:")
    for k,v in sorted(error_causes.items(), key=lambda x: -x[1])[:10]:
        print(f"  {k:<25} {v}")

# Save metadata
summary = {
    'timestamp': datetime.now().isoformat(),
    'archive_path': str(ARCHIVE_PATH.resolve()),
    'run_output_dir': str(RUN_OUTPUT_DIR),
    'total_examined': len(audio_files),
    'emitted': len(emitted),
    'errors': len(errors),
    'error_breakdown': error_causes,
    'emitted_distribution': class_counts,
    'config': {
        'batch_size': BATCH_SIZE,
        'n_mfcc': N_MFCC,
        'target_sr': TARGET_SR,
        'target_samples': TARGET_SAMPLES,
        'copy_mode': COPY_MODE,
        'mirror_structure': MIRROR_STRUCTURE,
        'misc_confidence_threshold': MISC_CONFIDENCE_THRESHOLD,
        'misc_label': MISC_LABEL_NAME,
        'preserve_structure': PRESERVE_STRUCTURE,
        'dedup_hash': DEDUP_HASH,
        'confidence_threshold': CONFIDENCE_THRESHOLD,
        'target_labels_subset': TARGET_LABELS,
        'label_space_size': len(LABELS)
    }
}

meta_dir = RUN_OUTPUT_DIR / 'metadata'
meta_dir.mkdir(exist_ok=True, parents=True)
with (meta_dir / 'summary.json').open('w') as f:
    json.dump(summary, f, indent=2)
with (meta_dir / 'results.json').open('w') as f:
    json.dump(results, f, indent=2)
if errors:
    with (meta_dir / 'errors.json').open('w') as f:
        json.dump(errors, f, indent=2)
print(f"📄 Saved summary + {len(results)} detailed records -> {meta_dir}")

Processing 47339 files in 1480 batches of up to 32.


  y, sr = librosa.load(path, sr=TARGET_SR, mono=True)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)
Exception ignored in: <function CFObject.__del__ at 0x13143cf40>
Traceback (most recent call last):
  File "/Users/Gilby/Projects/MLAudioClassifier/.venv/lib/python3.13/site-packages/audioread/macca.py", line 135, in __del__
    _corefoundation.CFRelease(self._obj)
AttributeError: 'CFURL' object has no attribute '_obj'
Exception ignored in: <function ExtAudioFile.__del__ at 0x13143d940>
Traceback (most recent call last):
  File "/Users/Gilby/Projects/MLAudioClassifier/.venv/lib/python3.13/site-packages/audioread/macca.py", line 336, in __del__
    self.close()
  File "/Users/Gilby/Projects/MLAudioClassifier/.venv/lib/python3.13/site-packages/audioread/macca.py", line 330, in close
    if not self.closed:
AttributeError: 'ExtAudioFile' object has no attribute 'closed'
  y, sr = librosa

⏱️  Classification complete in 231.76s (204.3 files/sec)
✔️ Updated hash cache: .cache/archive_classifier/seen_hashes.txt
Class distribution (emitted subset):
  class_10             9079
  class_0              7128
  class_12             6421
  class_3              6241
  class_9              2531
  class_13             1623
  class_6              1520
  class_23             1445
  class_11             1305
  class_17             1183
  class_1              881
  class_5              859
  class_20             737
  class_27             643
  class_28             585
  class_26             545
  class_30             449
  class_7              411
  class_33             357
  class_14             357
  class_4              350
  class_25             336
  class_15             244
  class_16             228
  class_2              202
  class_21             182
  class_32             175
  class_19             161
  class_24             145
  class_31             105
  class_29           
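
In mirror mode, the per-file routing in `classify_batch` reduces to a small decision: out-of-target predictions go to `misc`, predictions below `CONFIDENCE_THRESHOLD` are not emitted, and predictions below `MISC_CONFIDENCE_THRESHOLD` are relabelled to `misc`. The side-effect-free restatement below (`route` is a hypothetical name) is convenient for unit-testing thresholds; it also makes visible that with generic `class_<idx>` labels and the default drum-name `TARGET_LABELS`, every prediction takes the out-of-target branch.

```python
from typing import Dict, Optional, Sequence


def route(prob_vec: Sequence[float], labels: Sequence[str],
          target_labels: Optional[Sequence[str]], misc_threshold: float,
          floor: float = 0.0, misc_label: str = "misc") -> Dict[str, object]:
    """Return the predicted label and the folder label it would be emitted under (None = not emitted)."""
    idx = max(range(len(prob_vec)), key=prob_vec.__getitem__)  # first maximum, like np.argmax
    conf = float(prob_vec[idx])
    pred = labels[idx]
    if target_labels is not None and pred not in target_labels:
        return {"pred_label": pred, "emit_label": misc_label, "confidence": conf}
    if conf < floor:
        return {"pred_label": pred, "emit_label": None, "confidence": conf}
    emit = misc_label if conf < misc_threshold else pred
    return {"pred_label": pred, "emit_label": emit, "confidence": conf}


labels = ["Kick", "Snare", "Tom"]
targets = ["Kick", "Snare"]
print(route([0.8, 0.1, 0.1], labels, targets, 0.5)["emit_label"])    # confident, in target set
print(route([0.4, 0.35, 0.25], labels, targets, 0.5)["emit_label"])  # low confidence
print(route([0.1, 0.2, 0.7], labels, targets, 0.5)["emit_label"])    # outside target set
```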