# Inverse Drum Machine: Separation and Synthesis Demo

This notebook demonstrates how to use the models from the Inverse Drum Machine project to perform drum separation. It allows you to:

1.  **Load** different pre-trained models (including our method and baselines).
2.  **Process** tracks from the StemGMD dataset.
3.  **Apply** either direct synthesis or source separation via masking.
4.  **Save** the separated stems as audio files for evaluation and listening.

This code was used to generate the examples on the demo page at [https://bernardo-torres.github.io/projects/inverse-drum-machine/](https://bernardo-torres.github.io/projects/inverse-drum-machine/).

## Prerequisites

Before running this notebook, please ensure you have:

1.  **Installed Dependencies**: Make sure all required packages are installed.
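2.  **Downloaded the Dataset**: This demo uses the test tracks from the **StemGMD** dataset. Please download it and place it in the `data/StemGMD` directory. The `run_all` function also reads isolated single hits from `data/Stem_GMD_single_hits` (passed as `gt_sources_path`), which the NMFD baseline receives as its `sources`.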
3.  **Downloaded Pre-trained Models**: To use our trained model, download the checkpoint weights and place them under `logs/<model-name>/checkpoints/` (for example, `logs/idm-44-train-kits/checkpoints/`).
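Before running anything, it can help to confirm that the expected folders exist. The sketch below is a hypothetical helper (`missing_prerequisites` is not part of `idm`); it assumes the layout implied by this notebook: `data/StemGMD`, `data/Stem_GMD_single_hits` (the `gt_sources_path` used by `run_all`), and at least one `.ckpt` file under `logs/<model-name>/checkpoints/`, matching the checkpoint path printed when a model is loaded.

```python
from pathlib import Path


def missing_prerequisites(root: Path) -> list:
    """Return the expected prerequisite paths that are missing under `root`.

    Hypothetical helper; the layout is an assumption based on the paths used in this notebook.
    """
    missing = []
    # Dataset folders passed to get_dataset (dataset_root and gt_sources_path)
    for rel in (Path("data") / "StemGMD", Path("data") / "Stem_GMD_single_hits"):
        if not (root / rel).is_dir():
            missing.append(str(root / rel))
    # Any checkpoint stored as logs/<model-name>/checkpoints/*.ckpt
    if not any((root / "logs").glob("*/checkpoints/*.ckpt")):
        missing.append(str(root / "logs" / "<model-name>" / "checkpoints" / "*.ckpt"))
    return missing


problems = missing_prerequisites(Path("."))
print("All prerequisites found." if not problems else "Missing:\n  " + "\n  ".join(problems))
```

Run it from the repository root (the same place `ROOT_PATH` points to); an empty list means all three prerequisites are in place.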

## Setup

In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
from pathlib import Path
import torch
import numpy as np
from tqdm import tqdm
import librosa
import soundfile as sf
import copy

from idm.baselines.larsnet.larsnet import LarsNet
from idm.utils import get_normalizing_function
from idm.synthesis_conditioning.peak_picking import PeakPicking
from idm.utils import cpu_numpy
from idm.data.dataset import with_audio_settings
from idm.data.datamodule import get_dataset
from idm import eval_class_mapping, drum_kit_map_stemgmd
from idm.feature_extractor.stft import STFT
from idm.inference import estimate_masks, load_model, ROOT_PATH

inverse_kit_map = {v: k for k, v in drum_kit_map_stemgmd.items()}

### Quick Test
The following cell loads a model and a dataset split, then runs the model on the first four seconds of a single track as a sanity check.
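The quick test slices the first four seconds of a mono mix and adds batch and channel axes with `[None, None, :]`, so the model receives a `(batch, channels, samples)` tensor. The NumPy sketch below reproduces only that indexing on a synthetic waveform (random noise stands in for `dataset[0]["mix"]`); at 44.1 kHz, four seconds is 176400 samples, which matches the `torch.Size([176400])` printed by the cell.

```python
import numpy as np

EVAL_SR = 44100  # evaluation sample rate used throughout the notebook
SECONDS = 4      # length of the sanity-check excerpt

# Stand-in for dataset[0]["mix"]: a 10 s mono waveform of random noise
mix = np.random.default_rng(0).standard_normal(10 * EVAL_SR).astype(np.float32)

excerpt = mix[: SECONDS * EVAL_SR]    # first 4 s of audio
model_input = excerpt[None, None, :]  # add batch and channel axes
print(excerpt.shape, model_input.shape)  # (176400,) (1, 1, 176400)
```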

In [3]:

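# Pick one model. 'gt' and 'oracle' are not networks (load_model returns a string for them),
# so they only work inside `run_all` below, not in this quick test.
# MODEL_IDENTIFIER = 'gt'
# MODEL_IDENTIFIER = 'oracle'
# MODEL_IDENTIFIER = 'larsnet'
# MODEL_IDENTIFIER = 'larsnet-mono'
# MODEL_IDENTIFIER = 'nmfd_case1a'
MODEL_IDENTIFIER = 'idm-44-train-kits'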
# DATASET_SPLIT = 'test_test_kits'
# DATASET_SPLIT = 'test_train_kits'
DATASET_SPLIT = 'eval_session_train_kits'


OUTPUT_DIR = Path("demo/separation_outputs")
EVAL_SR = 44100
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model, name = load_model(MODEL_IDENTIFIER, DEVICE)
model.eval()
model.to(DEVICE)
model_sr = model.sampling_rate if hasattr(model, "sampling_rate") else EVAL_SR

print(f"Loading dataset split: {DATASET_SPLIT}")
dataset = get_dataset(
    sample_rate_target=EVAL_SR,
    root_path=ROOT_PATH,
    dataset_root=ROOT_PATH / "data" / "StemGMD",
    dataset_split=DATASET_SPLIT,
    version="full",
    normalize=False, # Normalize per-track later
)
class_names = dataset.all_possible_classes
print(f"Dataset has {len(dataset)} tracks and {len(class_names)} classes: {class_names}")
excerpt = dataset[0]["mix"][: 4 * model_sr]  # first 4 seconds of the first track
print(excerpt.shape)
model(excerpt.to(DEVICE)[None, None, :])  # add batch and channel axes

Loading model: idm-44-train-kits
Found checkpoint at: logs/idm-44-train-kits/checkpoints/val-epoch=518-global_step=0.ckpt




Loading dataset split: eval_session_train_kits
Dataset has 240 tracks and 9 classes: ['CY_CR' 'CY_RD' 'HH_CHH' 'HH_OHH' 'KD' 'SD' 'TT_HFT' 'TT_HMT' 'TT_LMT']
torch.Size([176400])


{'stems': tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 7.1568e-11,  2.0448e-11,  1.4825e-10,  ..., -4.4564e-04,
            9.4507e-05,  8.0401e-04],
          [ 2.8438e-10, -1.2797e-10, -1.7063e-10,  ..., -1.8026e-04,
           -3.0684e-04, -3.7834e-04],
          ...,
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00]]], device='cuda:0', grad_fn=<MulBackward0>),
 'output': tensor([[-1.0469e-08, -2.8224e-09, -1.4152e-08,  ..., -9.5351e-03,
          -7.5431e-03, -7.8883e-03]], device='cuda:0', grad_fn=<SumBackward1>),
 'samples': tensor([[[-1.1278e-02, -1.1294e-02, -5.3387e-03,  ...,  1.0655e-02,
            1.5195e-02,  1.6045e-02],
          [-2.

## Inference

The `run_all` function encapsulates the entire inference pipeline for multiple models. It iterates through selected tracks in the dataset, performs separation or synthesis, and saves the output audio to disk.
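The `track_indices` argument controls which tracks are processed and saved; if it is left as `None`, a default list of indices is used for each split.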
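Besides separating, `run_all` collapses the dataset's nine fine-grained classes into the evaluation classes of `eval_class_mapping["5_class"]` by summing every stem that maps to the same evaluation class, for both the ground truth and the model outputs. The NumPy sketch below shows that aggregation in isolation. The grouping used here (cymbals, hi-hats, kick, snare, toms) is an illustrative guess, not the actual contents of `eval_class_mapping`, and `map_stems` is a hypothetical helper.

```python
import numpy as np

# Hypothetical 9 -> 5 grouping; the real one is eval_class_mapping["5_class"]
class_mapping = {
    "CY_CR": "CY", "CY_RD": "CY",
    "HH_CHH": "HH", "HH_OHH": "HH",
    "KD": "KD", "SD": "SD",
    "TT_HFT": "TT", "TT_HMT": "TT", "TT_LMT": "TT",
}


def map_stems(stems, mapping):
    """Sum fine-grained stems into their evaluation classes.

    Returns (eval_classes, array of shape (n_eval_classes, n_samples)).
    """
    eval_classes = np.unique(list(mapping.values()))  # sorted, as np.unique does in run_all
    n_samples = next(iter(stems.values())).shape[-1]
    mapped = np.zeros((len(eval_classes), n_samples), dtype=np.float32)
    for name, stem in stems.items():
        idx = np.where(eval_classes == mapping[name])[0][0]
        mapped[idx] += stem  # several stems can land in the same class
    return eval_classes, mapped


rng = np.random.default_rng(0)
stems = {name: rng.standard_normal(8).astype(np.float32) for name in class_mapping}
eval_classes, mapped = map_stems(stems, class_mapping)
print(eval_classes)  # ['CY' 'HH' 'KD' 'SD' 'TT']
```

Because the mapping only regroups stems, the mapped stems still sum to the original mix.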

In [4]:

normalizing_fn = get_normalizing_function("maxabs")
class_mapping = eval_class_mapping["5_class"]

synth_or_masked = "masked"  # Default mode ("synth" or "masked"); each run_all call below passes its own value
alpha = 1.0  # Masking exponent

def run_all(DATASET_SPLIT, MODEL_IDENTIFIER, synth_or_masked="masked", track_indices=None):
    model, name = load_model(MODEL_IDENTIFIER, DEVICE)
    if hasattr(model, "eval"):  # 'gt' and 'oracle' are plain strings, not networks
        model.eval()
    model_sr = model.sampling_rate if hasattr(model, "sampling_rate") else EVAL_SR

    print(f"Loading dataset split: {DATASET_SPLIT}")
    dataset = get_dataset(
        sample_rate_target=EVAL_SR,
        dataset_split=DATASET_SPLIT,
        version="full",
        root_path=ROOT_PATH,
        dataset_root=ROOT_PATH / "data" / "StemGMD",
        gt_sources_path=ROOT_PATH / "data" / "Stem_GMD_single_hits",
        normalize=False, # Normalize per-track later
    )
    class_names = dataset.all_possible_classes
    print(f"Dataset has {len(dataset)} tracks and {len(class_names)} classes: {class_names}")
    if track_indices is None:
        # Some default values for each split
        if DATASET_SPLIT == "test_test_kits":
            TRACK_INDICES = [108, 219, 240, 256, 265, 268]
        elif DATASET_SPLIT == "test_train_kits":
            TRACK_INDICES = [163, 164, 316, 373, 389, 393]
        else:
            TRACK_INDICES = [108, 219, 240, 256, 265, 268]
    else:
        TRACK_INDICES = track_indices

    with torch.no_grad():
        for track_idx in tqdm(TRACK_INDICES):
            batch = dataset[track_idx]
            unique_classes = batch["all_possible_classes"]
            unique_eval_classes = np.unique(list(class_mapping.values()))

            stems = {k: v.to(DEVICE) for k, v in batch["stems"].items()}
            eval_mix = sum(stems.values())

            # Rebuild the model input mix with the channel layout the model expects (mono or stereo)
            with with_audio_settings(
                dataset, sample_rate_target=EVAL_SR, mono=model.mono if hasattr(model, "mono") else True
            ):
                _batch = dataset[track_idx]
                _stems = {k: v.to(DEVICE) for k, v in _batch["stems"].items()}
                input_mix = sum(_stems.values())
                del _batch, _stems

            # Normalize the ground-truth mix, and rescale the stems by the same factor so that
            # (for a scaling normalizer such as "maxabs") they still sum to the normalized mix
            norm_factor = torch.max(torch.abs(eval_mix)) / torch.max(torch.abs(normalizing_fn(eval_mix)))
            if norm_factor == 0:
                norm_factor = 1
            stems = {k: v / norm_factor for k, v in stems.items()}
            eval_mix = normalizing_fn(eval_mix)
            input_mix = normalizing_fn(input_mix)

            # Prepare model input (resample if necessary)
            if model_sr != EVAL_SR:
                input_mix = librosa.resample(cpu_numpy(input_mix), orig_sr=EVAL_SR, target_sr=model_sr)
                input_mix = torch.from_numpy(input_mix).to(DEVICE)


            # --- Run model ---
            if isinstance(model, str) and "gt" in MODEL_IDENTIFIER:
                tf = False
                map_model_outs = True
                train_classes = list(dataset.all_possible_classes)
                synth_stems = torch.stack([stems[clas] for clas in dataset.all_possible_classes], dim=0).unsqueeze(0)
                synth_or_masked = "synth"
            elif isinstance(model, str) and "oracle" in MODEL_IDENTIFIER:
                tf = False
                map_model_outs = True
                train_classes = list(dataset.all_possible_classes)
                synth_stems = torch.stack([stems[clas] for clas in dataset.all_possible_classes], dim=0).unsqueeze(0) 
                synth_or_masked = "masked"
            elif isinstance(model, LarsNet):
                tf = False
                map_model_outs = False
                train_classes = model.train_classes
                if input_mix.ndim == 1:
                    assert model.mono is True, "Model is stereo but input is mono"
                    # Let's duplicate the mono input to stereo
                    input_mix = input_mix.unsqueeze(0).repeat(2, 1)
                est_stems_dict = model(input_mix.unsqueeze(0))
                synth_stems = torch.stack([est_stems_dict[cls].mean(dim=0) for cls in unique_eval_classes], dim=0).unsqueeze(0)
            elif hasattr(model, "encoder"): # Assumes our trained model structure
                tf = False
                map_model_outs = True
                train_classes = model.train_classes
                encoder_outs = model.encoder(input_mix.unsqueeze(0))
                peak_picking_fn = PeakPicking(activation_rate=encoder_outs['activation_rate'], classes=model.train_classes)
                model.decoder.peak_picking_fn = peak_picking_fn
                outputs = model.decoder(**encoder_outs, extra_returns=["stems"])
                synth_stems = outputs["stems"]
            else:
                tf = True
                synth_or_masked="masked"
                map_model_outs = True
                train_classes = model.train_classes
                sources = batch["gt_sources"]
                onsets = batch["onsets_dict"]
                # Convert onsets to torch tensors
                for key in onsets:
                    onsets[key] = torch.tensor(onsets[key], device=DEVICE)
                sources = torch.stack([sources[inst] for inst in unique_classes], dim=0)
                output = model(
                    input_mix.unsqueeze(0),
                    sources=sources,
                    ext_onsets=onsets,
                    refit=True,
                    return_keys=["spec_output"],
                )
                synth_stems = output["spec_output"]


            # Resample waveform estimates back to the evaluation sample rate if necessary
            if model_sr != EVAL_SR and not tf:
                synth_stems = librosa.resample(cpu_numpy(synth_stems), orig_sr=model_sr, target_sr=EVAL_SR)
                synth_stems = torch.from_numpy(synth_stems).to(DEVICE)

            if not tf:
                synth_stems = synth_stems[..., : eval_mix.shape[-1]] 

            if map_model_outs:
                mapped_synth_stems = torch.zeros(
                    (synth_stems.shape[0], len(unique_eval_classes), *synth_stems.shape[2:]), device=synth_stems.device
                )
            else:
                mapped_synth_stems = synth_stems

            mapped_gt_stems = torch.zeros((len(unique_eval_classes), eval_mix.shape[-1]), device=eval_mix.device)
            for class_name in unique_classes:
                mapped_class = class_mapping[class_name]

                # Accumulate every fine-grained stem into its evaluation class
                eval_clas_idx = np.where(unique_eval_classes == mapped_class)[0][0]
                mapped_gt_stems[eval_clas_idx] += stems[class_name]
                if class_name not in train_classes:
                    continue
                idx = train_classes.index(class_name)
                if map_model_outs:  # If we have more train classes than eval classes
                    mapped_synth_stems[:, eval_clas_idx] += synth_stems[:, idx]

            synth_stems = mapped_synth_stems
            stems = mapped_gt_stems

            # In "masked" mode, use the estimates to build soft (Wiener-style) masks for the mixture STFT
            if synth_or_masked == "masked":
                if tf:  # The model already outputs TF-domain estimates (e.g., NMFD)
                    synth_tf = synth_stems
                    transform = copy.deepcopy(model.transform)
                    transform.magnitude = False
                else:
                    transform = STFT(n_fft=1024, hop_length=256, win_length=1024, center=True, magnitude=False)
                    synth_tf = transform(synth_stems)
                mix_tf = transform(eval_mix)
                masked_stems = estimate_masks(mix_tf, synth_tf, masking_type="wiener", alpha=alpha)
                synth_stems = transform.inverse(masked_stems, length=eval_mix.shape[-1])
            
            # --- Save to disk ---
            drum_kit = batch['drum_kit']
            FILE_KIT = f"{batch['audio_fn']}".split('.')[0] + f"_{inverse_kit_map[drum_kit]}"
            save_dir = OUTPUT_DIR / DATASET_SPLIT / FILE_KIT / MODEL_IDENTIFIER
            save_dir.mkdir(parents=True, exist_ok=True)

            # Save the normalized mixture once per track, next to the per-model folders
            sf.write(save_dir.parent / "mix.wav", eval_mix.cpu().numpy(), EVAL_SR)
            # Save every estimated stem
            for i, cls in enumerate(unique_eval_classes):
                est_stem = synth_stems[:, i].squeeze(0).cpu().numpy()

                sf.write(save_dir / f"{cls}_{synth_or_masked}.wav", est_stem, EVAL_SR)
    print("\n--- Inference Complete ---")
    print(f"Saved separated tracks to: {os.path.relpath(OUTPUT_DIR.resolve(), ROOT_PATH)}")
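In "masked" mode, `run_all` passes the mixture STFT and the STFTs of the estimates to `estimate_masks(..., masking_type="wiener", alpha=alpha)`. The NumPy sketch below shows a generic ratio mask of that kind; it is an assumption about what such a mask computes, not the implementation of `estimate_masks`, and `soft_masks` is a hypothetical name. Each source gets the mask |S_i|^alpha / sum_j |S_j|^alpha, which is applied to the complex mixture, so the mixture phase is reused and the masked stems add back up to the mixture.

```python
import numpy as np


def soft_masks(mix_tf, est_tf, alpha=1.0, eps=1e-10):
    """Generic ratio (Wiener-style) masking of a mixture spectrogram.

    mix_tf: complex array (freq, frames), the mixture STFT
    est_tf: complex array (n_sources, freq, frames), the model's estimates
    Returns complex masked stems with the same shape as est_tf.
    """
    power = np.abs(est_tf) ** alpha                            # |S_i|^alpha per source
    masks = power / (power.sum(axis=0, keepdims=True) + eps)   # masks sum to ~1 in every bin
    return masks * mix_tf[None]                                # real masks keep the mixture phase


rng = np.random.default_rng(0)
shape = (3, 513, 20)  # (sources, freq bins for n_fft=1024, frames)
est = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
mix = rng.standard_normal(shape[1:]) + 1j * rng.standard_normal(shape[1:])
masked = soft_masks(mix, est, alpha=1.0)
```

With `alpha=2` this becomes the power-ratio mask usually referred to as a Wiener filter; the notebook sets `alpha = 1.0`. The appeal of masking over direct synthesis is that every output is carved out of the real recording, so errors in the synthesized timbre only affect how energy is shared between stems.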

## Run Inference

Now, we execute the `run_all` function to generate the separated tracks.

The following cell will run inference for multiple models and configurations. You can comment out any run you don't want to execute. The results will be saved in the `demo/separation_outputs` directory. 
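The saving step of `run_all` writes one folder per track and drum kit, one sub-folder per model, and one WAV per evaluation class and mode, plus a `mix.wav` shared by all models. The sketch below rebuilds those paths, which can help when collecting files for listening or evaluation. `stem_path` is a hypothetical helper, and the file name, kit name, and class in the example are made up.

```python
from pathlib import Path


def stem_path(output_dir, split, audio_fn, kit_name, model_id, eval_class, mode):
    """Rebuild the path run_all uses for one estimated stem (hypothetical helper)."""
    track_kit = f"{audio_fn}".split(".")[0] + f"_{kit_name}"  # same naming as FILE_KIT
    return Path(output_dir) / split / track_kit / model_id / f"{eval_class}_{mode}.wav"


p = stem_path("demo/separation_outputs", "eval_session_train_kits",
              "track_001.wav", "kit_a", "idm-44-train-kits", "KD", "masked")
print(p.as_posix())
# demo/separation_outputs/eval_session_train_kits/track_001_kit_a/idm-44-train-kits/KD_masked.wav
mix_path = p.parent.parent / "mix.wav"  # the mixture sits one level above the model folders
```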

If you want to choose the tracks yourself, pass a list to `track_indices`, and first inspect the dataset to check how many tracks each split contains. Tracks are sorted by duration, so the last indices correspond to very long tracks (30 s or more).
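The default track selection inside `run_all` can be restated as a dictionary lookup with an explicit override. This is only an equivalent sketch of its `if`/`elif` chain, using a hypothetical helper name; the index lists are copied from `run_all`.

```python
# Default indices copied from run_all; any other split falls back to FALLBACK_INDICES
DEFAULT_TRACK_INDICES = {
    "test_test_kits": [108, 219, 240, 256, 265, 268],
    "test_train_kits": [163, 164, 316, 373, 389, 393],
}
FALLBACK_INDICES = [108, 219, 240, 256, 265, 268]


def resolve_track_indices(split, track_indices=None):
    """Return the explicit selection if given, else the split's defaults (hypothetical helper)."""
    if track_indices is not None:
        return list(track_indices)
    return list(DEFAULT_TRACK_INDICES.get(split, FALLBACK_INDICES))


# Example: a few short tracks from the start of a duration-sorted split
short_tracks = resolve_track_indices("eval_session_train_kits", range(0, 40, 10))
```

Returning a fresh list keeps callers from accidentally modifying the shared defaults.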


In [5]:
DATASET_SPLIT = "eval_session_train_kits"

# Mask-based runs
run_all(DATASET_SPLIT, "idm-44-train-kits", synth_or_masked="masked")
run_all(DATASET_SPLIT, "oracle", synth_or_masked="masked")
run_all(DATASET_SPLIT, "larsnet", synth_or_masked="masked")
run_all(DATASET_SPLIT, "larsnet-mono", synth_or_masked="masked")
run_all(DATASET_SPLIT, "nmfd_case1a", synth_or_masked="masked")

# Direct synthesis runs
run_all(DATASET_SPLIT, "gt", synth_or_masked="synth")
run_all(DATASET_SPLIT, "idm-44-train-kits", synth_or_masked="synth")
run_all(DATASET_SPLIT, "larsnet", synth_or_masked="synth")
run_all(DATASET_SPLIT, "larsnet-mono", synth_or_masked="synth")



Loading model: idm-44-train-kits
Found checkpoint at: logs/idm-44-train-kits/checkpoints/val-epoch=518-global_step=0.ckpt
Loading dataset split: eval_session_train_kits




Dataset has 240 tracks and 9 classes: ['CY_CR' 'CY_RD' 'HH_CHH' 'HH_OHH' 'KD' 'SD' 'TT_HFT' 'TT_HMT' 'TT_LMT']


100%|██████████| 6/6 [00:11<00:00,  1.98s/it]



--- Inference Complete ---
Saved separated tracks to: demo/separation_outputs
Loading model: oracle
Loading dataset split: eval_session_train_kits
Dataset has 240 tracks and 9 classes: ['CY_CR' 'CY_RD' 'HH_CHH' 'HH_OHH' 'KD' 'SD' 'TT_HFT' 'TT_HMT' 'TT_LMT']


100%|██████████| 6/6 [00:10<00:00,  1.78s/it]



--- Inference Complete ---
Saved separated tracks to: demo/separation_outputs
Loading model: larsnet
Loading UNet models...


CY pretrained_cymbals_unet: 100%|██████████| 5/5 [00:00<00:00,  6.03it/s]


Loading dataset split: eval_session_train_kits
Dataset has 240 tracks and 9 classes: ['CY_CR' 'CY_RD' 'HH_CHH' 'HH_OHH' 'KD' 'SD' 'TT_HFT' 'TT_HMT' 'TT_LMT']


100%|██████████| 6/6 [00:13<00:00,  2.18s/it]



--- Inference Complete ---
Saved separated tracks to: demo/separation_outputs
Loading model: larsnet-mono
Loading UNet models...


CY pretrained_cymbals_unet: 100%|██████████| 5/5 [00:00<00:00,  6.10it/s]


Loading dataset split: eval_session_train_kits
Dataset has 240 tracks and 9 classes: ['CY_CR' 'CY_RD' 'HH_CHH' 'HH_OHH' 'KD' 'SD' 'TT_HFT' 'TT_HMT' 'TT_LMT']


100%|██████████| 6/6 [00:12<00:00,  2.08s/it]



--- Inference Complete ---
Saved separated tracks to: demo/separation_outputs
Loading model: nmfd_case1a
Loading dataset split: eval_session_train_kits
Dataset has 240 tracks and 9 classes: ['CY_CR' 'CY_RD' 'HH_CHH' 'HH_OHH' 'KD' 'SD' 'TT_HFT' 'TT_HMT' 'TT_LMT']


 50%|█████     | 3/6 [00:08<00:07,  2.59s/it][rank: 0] Clamping onset frame 3122 to 3104. This may indicate a mismatch between activation_rate and n_frames.
 67%|██████▋   | 4/6 [00:10<00:05,  2.56s/it][rank: 0] Clamping onset frame 4100 to 4097. This may indicate a mismatch between activation_rate and n_frames.
100%|██████████| 6/6 [00:16<00:00,  2.76s/it]



--- Inference Complete ---
Saved separated tracks to: demo/separation_outputs
Loading model: gt
Loading dataset split: eval_session_train_kits
Dataset has 240 tracks and 9 classes: ['CY_CR' 'CY_RD' 'HH_CHH' 'HH_OHH' 'KD' 'SD' 'TT_HFT' 'TT_HMT' 'TT_LMT']


100%|██████████| 6/6 [00:09<00:00,  1.65s/it]



--- Inference Complete ---
Saved separated tracks to: demo/separation_outputs
Loading model: idm-44-train-kits
Found checkpoint at: logs/idm-44-train-kits/checkpoints/val-epoch=518-global_step=0.ckpt
Loading dataset split: eval_session_train_kits
Dataset has 240 tracks and 9 classes: ['CY_CR' 'CY_RD' 'HH_CHH' 'HH_OHH' 'KD' 'SD' 'TT_HFT' 'TT_HMT' 'TT_LMT']


100%|██████████| 6/6 [00:11<00:00,  1.92s/it]



--- Inference Complete ---
Saved separated tracks to: demo/separation_outputs
Loading model: larsnet
Loading UNet models...


CY pretrained_cymbals_unet: 100%|██████████| 5/5 [00:00<00:00,  5.67it/s]


Loading dataset split: eval_session_train_kits
Dataset has 240 tracks and 9 classes: ['CY_CR' 'CY_RD' 'HH_CHH' 'HH_OHH' 'KD' 'SD' 'TT_HFT' 'TT_HMT' 'TT_LMT']


100%|██████████| 6/6 [00:13<00:00,  2.20s/it]



--- Inference Complete ---
Saved separated tracks to: demo/separation_outputs
Loading model: larsnet-mono
Loading UNet models...


CY pretrained_cymbals_unet: 100%|██████████| 5/5 [00:00<00:00,  6.45it/s]


Loading dataset split: eval_session_train_kits
Dataset has 240 tracks and 9 classes: ['CY_CR' 'CY_RD' 'HH_CHH' 'HH_OHH' 'KD' 'SD' 'TT_HFT' 'TT_HMT' 'TT_LMT']


100%|██████████| 6/6 [00:12<00:00,  2.08s/it]


--- Inference Complete ---
Saved separated tracks to: demo/separation_outputs



