# Ultra-Sortformer: Extending NVIDIA Sortformer to N Speakers

Released checkpoints: 5spk model on Hugging Face · 8spk model on Hugging Face

This repository records the transition from a fixed 4-speaker cap to a configurable N-speaker Streaming Sortformer (N > 4). The pattern is always the same: grow the speaker head with SVD-based orthogonal initialization, keep base vs. new weights in separate modules, and fine-tune with split learning rates so behavior on 2–4 speaker audio stays stable while the extra dimensions learn. You can target any N supported by your data and VRAM; we publish checkpoints for N = 5 and N = 8 as reference points.



## Table of Contents

  1. Background
  2. Architecture Overview
  3. Extension Journey
  4. Benchmark
  5. Synthetic Training Data
  6. Training
  7. Requirements

## Background

NVIDIA's diar_streaming_sortformer_4spk-v2.1 is a streaming speaker diarization model based on the Sortformer architecture. It uses a FastConformer encoder (17 layers) followed by a Transformer encoder (18 layers) to produce per-frame speaker activity predictions. The final output layer is a single linear mapping from hidden states to four speaker probabilities.

The model supports real-time streaming diarization with a chunk-based speaker cache.

Problem: The public checkpoint is hard-limited to four simultaneous speakers. Scenes with more talkers are handled poorly if you merely relabel data without widening the output head.

Goal: Make max_num_of_spks = N with N > 4 a first-class training target, without sacrificing the 2–4 speaker regime. This README walks through the mechanics; concrete runs in the repo use N = 5 (smallest jump) and N = 8 (larger jump from the same base).


## Architecture Overview

```
Audio Input
    │
    ▼
┌─────────────────────────────┐
│  Preprocessor               │
└─────────────────────────────┘
    │
    ▼
┌──────────────────────────────┐
│  FastConformer Encoder       │  ← 17 layers, d_model=512
│  (NEST Encoder)              │    Subsampling factor: 8
└──────────────────────────────┘
    │
    ▼
┌──────────────────────────────┐
│  Transformer Encoder         │  ← 18 layers, d_model=192
└──────────────────────────────┘
    │
    ▼
┌──────────────────────────────┐
│  SortformerModules           │  ← Speaker Cache + Attention
│  (Streaming Speaker Cache)   │
└──────────────────────────────┘
    │
    ▼
┌──────────────────────────────┐
│  single_hidden_to_spks       │  ← Linear(192, N_spk)  ← KEY LAYER
└──────────────────────────────┘
    │
    ▼
  Per-frame speaker activity predictions  [batch, time, N_spk]
```

Stock NeMo ships a single `single_hidden_to_spks` layer (`Linear(192, N_spk)`). This repo splits it into `single_hidden_to_spks_base` + `single_hidden_to_spks_new` so the two parts can be trained with split learning rates (Steps 1–2). Going 4 → N speakers adds N − 4 new rows (or, in general, N_new rows on top of any N_base).
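As a toy shape check of the head described above (the numbers follow the diagram; the real model applies this inside `SortformerModules`, so this stand-alone `Linear` is illustrative only):

```python
import torch

# Stand-in for the speaker head: hidden size 192 (per the diagram),
# N_spk = 8 as in the released 8spk model.
hidden = torch.randn(2, 100, 192)      # [batch, frames, d_model]
head = torch.nn.Linear(192, 8)         # plays the role of single_hidden_to_spks
probs = torch.sigmoid(head(hidden))    # per-frame, per-speaker activity
assert probs.shape == (2, 100, 8)
```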


## Extension Journey

### Step 1: Output Layer Extension (4 → N)

Script: `scripts/extend_output_layer.py`

We treat the baseline as N_base = 4 speakers and grow the matrix to N = N_base + N_new. The first shipped milestone is N = 5 (N_new = 1); the N = 8 model adds N_new = 4 rows in one shot with the same procedure.

Random initialization for the new rows destroys accuracy on existing speakers. We instead use SVD-based orthogonal initialization so new logits start orthogonal to the subspace spanned by the original weights.

How it works

Let the existing weight matrix be W ∈ ℝ^{N_base×H} (H = 192):

```python
# SVD decomposition of existing weights
U, S, Vh = torch.linalg.svd(W, full_matrices=True)

# New speaker rows = right singular directions beyond the first N_base
# Vh[N_base], Vh[N_base+1], ... are orthogonal to the row space of W
new_row = Vh[N_base]  # first new speaker; repeat indexing for additional speakers

# Normalize to match typical row norms
avg_norm = W.norm(dim=1).mean()
new_row = new_row * (avg_norm / new_row.norm())
```

Repeat for each new speaker index until the head reaches the target N.

The extended checkpoint is saved in split form so optimizers can treat base and new rows differently:

```
single_hidden_to_spks_base  (N_base speakers)   ← frozen / low LR
single_hidden_to_spks_new   (N_new speakers)    ← higher LR
```
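The whole Step 1 procedure can be sketched as a standalone function (a minimal sketch: `extend_speaker_head` and its shapes are illustrative, not the actual API of `scripts/extend_output_layer.py`):

```python
import torch

def extend_speaker_head(W: torch.Tensor, n_new: int) -> torch.Tensor:
    """Append n_new rows orthogonal to the row space of W, norm-matched."""
    n_base, hidden = W.shape
    # full_matrices=True makes Vh (hidden, hidden); rows n_base.. are
    # orthogonal to the row space of W (assuming W has full row rank)
    _, _, Vh = torch.linalg.svd(W, full_matrices=True)
    new_rows = Vh[n_base:n_base + n_new]
    # Scale unit singular vectors to the average norm of the existing rows
    avg_norm = W.norm(dim=1).mean()
    new_rows = new_rows * (avg_norm / new_rows.norm(dim=1, keepdim=True))
    return torch.cat([W, new_rows], dim=0)

W = torch.randn(4, 192)                 # pretrained 4-speaker head weights
W_ext = extend_speaker_head(W, n_new=4) # grow to N = 8
assert W_ext.shape == (8, 192)
# New logits start orthogonal to the original head's row space:
assert torch.allclose(W @ W_ext[4:].T, torch.zeros(4, 4), atol=1e-3)
```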

### Step 2: Split Learning Rate Training

Key insight: a single learning rate on the whole expanded head tends to erase the old 2–4 speaker solution before the new dimensions have finished learning.

What we saw early on

  • Synthetic val_2spk–val_4spk quality dropped under uniform LR.
  • The network over-used the new capacity (e.g., predicting the full trained N on 3–4 speaker clips).

Mitigation: differential learning rates

| Component | Learning rate | Role |
| --- | --- | --- |
| `single_hidden_to_spks_base` (speakers 1…N_base) | 1e-5 | Preserve the pretrained head |
| `single_hidden_to_spks_new` (added speakers) | 1e-4 | Faster adaptation on new dimensions |
| Rest of the model | 1e-5 | Standard fine-tuning |

Implemented via `setup_optimizer_param_groups` in `sortformer_diar_models.py`:

```python
def setup_optimizer_param_groups(self):
    sm = self.sortformer_modules
    n_base = getattr(sm, 'n_base_spks', 0)
    new_lr = self._cfg.get('optim_new_lr', None)

    if n_base > 0 and new_lr is not None and hasattr(sm, 'single_hidden_to_spks_new'):
        # New head gets its own group with a higher LR; everything else
        # inherits the optimizer's default LR.
        new_params = list(sm.single_hidden_to_spks_new.parameters())
        new_param_ids = {id(p) for p in new_params}
        base_params = [p for p in self.parameters() if id(p) not in new_param_ids]
        self._optimizer_param_groups = [
            {"params": base_params},
            {"params": new_params, "lr": new_lr},
        ]
    else:
        # Fall back to NeMo's default single param group
        super().setup_optimizer_param_groups()
```
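Outside NeMo, the effect of these param groups can be illustrated with plain PyTorch (toy `Linear` modules stand in for the real head halves; the LRs match the table above):

```python
import torch

base_head = torch.nn.Linear(192, 4)  # stands in for single_hidden_to_spks_base
new_head = torch.nn.Linear(192, 4)   # stands in for single_hidden_to_spks_new

param_groups = [
    {"params": list(base_head.parameters())},            # inherits default LR
    {"params": list(new_head.parameters()), "lr": 1e-4}, # per-group override
]
opt = torch.optim.AdamW(param_groups, lr=1e-5)

# The base group uses the optimizer default; the new head trains 10x faster.
assert opt.param_groups[0]["lr"] == 1e-5
assert opt.param_groups[1]["lr"] == 1e-4
```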

### Step 3: Scaling to Larger N (Example: 8 Speakers)

The N = 8 release uses the same pipeline: start from NVIDIA 4-spk, extend the Sortformer head with orthogonal / split weights (N_base = 4, N_new = 4), then fine-tune with ~1e-5 on the bulk of parameters and ~1e-4 on single_hidden_to_spks_new on mixed synthetic + real meeting data.

Other N: keep max_num_of_spks, manifests, and n_base_spks consistent with your checkpoint; NeMo patches are listed under Training.


## Synthetic Training Data

`scripts/sentence_level_multispeaker_simulator.py` subclasses NeMo's `MultiSpeakerSimulator` (the same idea as `multispeaker_simulator.py`) and keeps NeMo's `data_simulator.yaml` session pipeline. Only `_build_sentence` is overridden: turns are built from whole manifest utterances, not word-aligned slices.

### Source Data

Single-speaker utterances come from the 다화자 음성합성 데이터 (Multi-speaker Speech Synthesis Dataset) on AI-Hub (NIA). It spans 3,400+ Korean speakers (ages 10s–60s) and roughly 10k hours of audio.

| Split | Approx. #Utterances | Language |
| --- | --- | --- |
| multispeaker_speech_synthesis_data/Training | 8,666,803 | Korean |
| multispeaker_speech_synthesis_data/Validation | 1,225,244 | Korean |

Build a NeMo-style JSON manifest listing `audio_filepath`, `speaker` (or a compatible id field), and optionally `text`, `words`, `alignments` for labels. The simulator groups rows by speaker id to sample per turn.
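A hedged sketch of writing such a manifest (JSON lines; paths and field values are placeholders, and `duration` is a field NeMo manifests commonly include but is not mandated by the description above):

```python
import json

rows = [
    {"audio_filepath": "/data/spk0001/utt_000.wav",
     "duration": 3.2, "speaker": "spk0001", "text": "..."},
    {"audio_filepath": "/data/spk0002/utt_017.wav",
     "duration": 5.8, "speaker": "spk0002", "text": "..."},
]

# NeMo manifests are JSON lines: one JSON object per line
with open("manifest.json", "w", encoding="utf-8") as f:
    for row in rows:
        f.write(json.dumps(row, ensure_ascii=False) + "\n")
```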

### Synthesis prerequisites

  1. Install NeMo (see Requirements). The simulator imports nemo from your environment; it looks for NeMo/tools/speech_data_simulator/conf/data_simulator.yaml only if a sibling NeMo/ directory exists next to this repo. If you use a pip-only install, pass --config_file pointing to that YAML (e.g. from a checkout or a copied file).
  2. System audio libraries: libsndfile1 and ffmpeg (listed under Requirements) are required for decoding/writing audio in practice.
  3. Generation forces CPU (CUDA_VISIBLE_DEVICES="") for stable runs without a working GPU stack.

### How it differs from stock NeMo

| Aspect | Stock `MultiSpeakerSimulator` | `SentenceLevelMultiSpeakerSimulator` |
| --- | --- | --- |
| Turn content | Word-aligned slices; reads audio in chunks up to `max_audio_read_sec` | One or more entire utterances per turn (mono, resampled to `sr`) |
| Turn length cap | Word-count target from `sentence_length_params` | `max_sentences_per_turn`: uniform 1…N utterances per turn (CLI default N = 3). If unset in YAML and not overridden, falls back to a negative binomial on utterance count (often too long; prefer an explicit N) |
| Optional YAML | Same | `session_params.max_turn_duration_sec` caps samples per turn when set |

### Synthesis configuration

  • Base config: NeMo/tools/speech_data_simulator/conf/data_simulator.yaml (or pass --config_file).
  • Important YAML knobs (not all exposed on CLI):
    • session_config.{num_speakers,num_sessions,session_length} — target speakers per session, session count, nominal duration in seconds.
    • session_params.{mean_silence,mean_overlap,...} — global silence/overlap means (per-session values vary).
    • speaker_enforcement.enforce_num_speakers — if true, NeMo may continue past session_length and pad the waveform until every speaker has spoken; real duration can exceed session_length. Set enforce_num_speakers: false in YAML if you need a hard cap at the cost of possibly missing speakers in a session.
    • sr, outputs.output_filename, augmentors, background noise, etc. — unchanged from NeMo.
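Putting the knobs above together, an illustrative YAML fragment (the structure follows NeMo's `data_simulator.yaml`; values are examples, not recommendations):

```yaml
session_config:
  num_speakers: 8
  num_sessions: 1000
  session_length: 180          # seconds (nominal; see enforcement note above)
session_params:
  mean_silence: 0.10           # per-session values vary around these means
  mean_overlap: 0.05
speaker_enforcement:
  enforce_num_speakers: false  # hard cap at session_length (may miss speakers)
```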

### Synthesis CLI

| Argument | Role |
| --- | --- |
| `--manifest_filepath` | Input NeMo JSON manifest (single-speaker rows with speaker id). |
| `--output_dir` | Output directory for `.wav`, `.rttm`, `.json`, `params.yaml`, etc. |
| `--config_file` | Optional override YAML (defaults to NeMo's `data_simulator.yaml`). |
| `--num_speakers` | Override `session_config.num_speakers`. |
| `--num_sessions` | Override `session_config.num_sessions`. |
| `--session_length` | Override nominal session length (seconds). |
| `--mean_silence` | Session mean silence ratio in [0, 1). |
| `--mean_overlap` | Session mean overlap ratio in [0, 1); invalid `mean_overlap_var` is clamped for stability. |
| `--max_sentences_per_turn` / `--max_sent` | Max utterances concatenated in one speaker turn; each run draws uniformly from 1…N (default N = 3). |

### Session length and speaker enforcement

session_length is a target timeline length in samples (session_length × sr). With enforce_num_speakers: true (NeMo default), the generator can extend the buffer so late speakers still get turns. For utterance-level simulation, combine --max_sent (small N) with YAML tuning (enforce_num_speakers, optional max_turn_duration_sec) if you need durations close to the nominal cap.
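For concreteness, the target buffer length for a 180-second session works out as follows (16 kHz is an assumed sample rate here, not a value stated in this README):

```python
sr = 16000                       # assumed sample rate (Hz)
session_length = 180             # nominal session length (seconds)
target_samples = session_length * sr
assert target_samples == 2_880_000
```

With `enforce_num_speakers: true`, the real buffer may grow beyond this target.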

### Synthesis outputs

Per session index i: multispeaker_session_i.wav, multispeaker_session_i.rttm, multispeaker_session_i.json (and CTM if enabled), plus a copied params.yaml under --output_dir. Merge session manifests into NeMo diarization train/val JSON with your own tooling. Under scripts/, this repo includes sentence_level_multispeaker_simulator.py (synthesis) and inference.py (minimal HF Sortformer diarize example).
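Merging can be as simple as concatenating the per-session JSON manifests. A self-contained sketch (the toy fixtures stand in for real simulator output; a real merge may also need path fix-ups or filtering):

```python
import glob
import json
import os

# Create two toy per-session manifests to stand in for simulator output
os.makedirs("synthetic_run", exist_ok=True)
for i in range(2):
    with open(f"synthetic_run/multispeaker_session_{i}.json", "w") as f:
        f.write(json.dumps({"audio_filepath": f"multispeaker_session_{i}.wav",
                            "offset": 0.0}) + "\n")

# Concatenate all session manifests into one training manifest
merged = []
for path in sorted(glob.glob("synthetic_run/multispeaker_session_*.json")):
    with open(path) as f:
        merged.extend(json.loads(line) for line in f if line.strip())

with open("train_manifest.json", "w") as f:
    for row in merged:
        f.write(json.dumps(row) + "\n")
```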

### Synthesis example

From the repository root (with NeMo installed and discoverable as above):

```shell
python scripts/sentence_level_multispeaker_simulator.py \
  --manifest_filepath /path/to/manifest.json \
  --output_dir /path/to/synthetic_run \
  --num_speakers 8 \
  --num_sessions 1000 \
  --session_length 180 \
  --mean_silence 0.10 \
  --mean_overlap 0.05 \
  --max_sent 3
```

Adjust paths, speaker count, session count, and overlap/silence means to match your experiment grid.

### Generated datasets (this project)

Synthetic grids for 2–8 speakers used two mean-overlap settings: ov0.05 (~5%) and ov0.15 (~15%, harder). Both used comparable mean silence (~10%); overlap and silence are session means, so per-session values vary.


## Benchmark

Tables and evaluation protocol: results/benchmark.md.

Note: More speakers can shift speaker-count behavior on short or low-speaker clips; read Spk_Count_Acc next to DER. Model cards on Hugging Face have more context.


## Training

Upstream NeMo does not ship the split head / split LR described below; patch your checkout (or use a fork that includes the same edits).

### NeMo Modifications

`nemo/collections/asr/models/sortformer_diar_models.py`

  • Added a `setup_optimizer_param_groups()` override for differential learning rates

`nemo/collections/asr/modules/sortformer_modules.py`

  • Added an `n_base_spks` parameter to enable split output layers (`single_hidden_to_spks_base` + `single_hidden_to_spks_new`)

### What we froze (ablation)

We compared three setups using NeMo's `freeze_encoder` / `freeze_transformer_encoder` flags on the FastConformer encoder (`encoder.*`) and the Transformer encoder (`transformer_encoder.*`):

| Setup | Conformer encoder | Transformer encoder |
| --- | --- | --- |
| Encoder frozen | frozen | trainable |
| Encoder + Transformer frozen | frozen | frozen |
| Full fine-tuning | trainable | trainable |

Full fine-tuning gave the smoothest, most reliable training loss decrease, so the released Ultra-Sortformer runs use no freezing (both stacks trainable), together with the split speaker head and split LR above.

### Training Configuration

Example keys (see also NeMo examples/speaker_tasks/diarization/conf/neural_diarizer/streaming_sortformer_diarizer_4spk-v2.yaml):

```yaml
model:
  max_num_of_spks: 6       # Set to your target N
  lr: 1e-5                 # Base learning rate
  # optim_new_lr: 1e-4     # Higher LR for single_hidden_to_spks_new (split-head fine-tuning)

  sortformer_modules:
    num_spks: ${model.max_num_of_spks}
    # n_base_spks: 4       # Base speaker count when using split output layers
```

## Requirements

System packages (Debian/Ubuntu; use sudo if you are not root):

```shell
sudo apt-get update && sudo apt-get install -y libsndfile1 ffmpeg
```

Python (recommended: a fresh virtual environment). nemo_toolkit[asr] pulls in PyTorch and most ASR dependencies declared by NeMo; install a CUDA-enabled PyTorch wheel first from pytorch.org if you train or run on GPU.

```shell
pip install Cython packaging
pip install "git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]"
```

Note: pip NeMo from NVIDIA does not include the Training patches (n_base_spks, split head, optimizer groups).

Optional extras:

  • pyannote.metrics — used for DER / benchmark-style evaluation in this project’s docs and scripts.
  • librosa — only if your own preprocessing or tooling uses it.
  • Sibling NeMo/ clone — handy for NeMo/examples/... and default data_simulator.yaml; otherwise pass --config_file to the simulator.

## License

Apache License 2.0
