Improve some names #47

Merged 3 commits on May 9, 2022
38 changes: 18 additions & 20 deletions README.md
@@ -18,28 +18,26 @@
<img width="51%" src="/visualization.gif" title="Real-time diarization example" />
</p>

-## Demo
-
-You can visualize the real-time speaker diarization of an audio stream with the built-in demo script.
+## Getting started

### Stream a recorded conversation

```shell
-python -m diart.demo /path/to/audio.wav
+python -m diart.stream /path/to/audio.wav
```

### Stream from your microphone

```shell
-python -m diart.demo microphone
+python -m diart.stream microphone
```

-See `python -m diart.demo -h` for more information.
+See `python -m diart.stream -h` for more options.

## Build your own pipeline

Diart provides building blocks that can be combined to do speaker diarization on an audio stream.
-The streaming implementation is powered by [RxPY](https://github.com/ReactiveX/RxPY), but the `functional` module is completely independent.
+Streaming is powered by [RxPY](https://github.com/ReactiveX/RxPY), but the `blocks` module is completely independent and can be used separately.

### Example

@@ -50,16 +48,16 @@ import rx
import rx.operators as ops
import diart.operators as myops
from diart.sources import MicrophoneAudioSource
-import diart.functional as fn
+import diart.blocks as blocks

sample_rate = 16000
mic = MicrophoneAudioSource(sample_rate)

# Initialize independent modules
-segmentation = fn.FrameWiseModel("pyannote/segmentation")
-embedding = fn.ChunkWiseModel("pyannote/embedding")
-osp = fn.OverlappedSpeechPenalty(gamma=3, beta=10)
-normalization = fn.EmbeddingNormalization(norm=1)
+segmentation = blocks.FramewiseModel("pyannote/segmentation")
+embedding = blocks.ChunkwiseModel("pyannote/embedding")
+osp = blocks.OverlappedSpeechPenalty(gamma=3, beta=10)
+normalization = blocks.EmbeddingNormalization(norm=1)

# Reformat microphone stream. Defaults to 5s duration and 500ms shift
regular_stream = mic.stream.pipe(myops.regularize_stream(sample_rate))
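
The diff view collapses the rest of this example. For orientation, here is a sketch of how the remaining plumbing plausibly continues, assembled only from patterns visible later in this PR (`rx.zip` plus `ops.starmap` in `src/diart/pipelines.py`); the final subscription and the `mic.read()` call are assumptions, not part of the diff:

```python
# Continuation sketch (not in the diff). Branch per-frame speaker activations,
# then combine waveform and activations into normalized speaker embeddings,
# mirroring the stream wiring shown in src/diart/pipelines.py below.
segmentation_stream = regular_stream.pipe(ops.map(segmentation))
embedding_stream = rx.zip(regular_stream, segmentation_stream).pipe(
    ops.starmap(lambda wave, seg: (wave, osp(seg))),
    ops.starmap(embedding),
    ops.map(normalization),
)
# Assumed consumption: print embedding shapes as chunks arrive.
embedding_stream.subscribe(on_next=lambda emb: print(emb.shape))
mic.read()  # assumed to start the microphone source
```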
@@ -142,13 +140,13 @@

To obtain the best results, make sure to use the following hyper-parameters:

-Dataset | latency | tau | rho | delta
-------------|---------|--------|--------|------
-DIHARD III | any | 0.555 | 0.422 | 1.517
-AMI | any | 0.507 | 0.006 | 1.057
-VoxConverse | any | 0.576 | 0.915 | 0.648
-DIHARD II | 1s | 0.619 | 0.326 | 0.997
-DIHARD II | 5s | 0.555 | 0.422 | 1.517
+| Dataset     | latency | tau    | rho    | delta |
+|-------------|---------|--------|--------|-------|
+| DIHARD III  | any     | 0.555  | 0.422  | 1.517 |
+| AMI         | any     | 0.507  | 0.006  | 1.057 |
+| VoxConverse | any     | 0.576  | 0.915  | 0.648 |
+| DIHARD II   | 1s      | 0.619  | 0.326  | 0.997 |
+| DIHARD II   | 5s      | 0.555  | 0.422  | 1.517 |

`diart.benchmark` can quickly run and evaluate the pipeline, and even measure its real-time latency. For instance, for a DIHARD III configuration:

@@ -157,7 +155,7 @@ python -m diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422
```

`diart.benchmark` runs a faster inference and evaluation by pre-calculating model outputs in batches.
-More options about benchmarking can be found by running `python -m diart.benchmark -h`.
+See `python -m diart.benchmark -h` for more options.

For convenience and to facilitate future comparisons, we also provide the [expected outputs](/expected_outputs) of the paper implementation in RTTM format for every entry of Table 1 and Figure 5. This includes the VBx offline topline as well as our proposed online approach with latencies 500ms, 1s, 2s, 3s, 4s, and 5s.

10 changes: 10 additions & 0 deletions src/diart/argdoc.py
@@ -0,0 +1,10 @@
+STEP = "Sliding window step (in seconds)"
+LATENCY = "System latency (in seconds). STEP <= LATENCY <= CHUNK_DURATION"
+TAU = "Probability threshold to consider a speaker as active. 0 <= TAU <= 1"
+RHO = "Speech ratio threshold to decide if centroids are updated with a given speaker. 0 <= RHO <= 1"
+DELTA = "Embedding-to-centroid distance threshold to flag a speaker as known or new. 0 <= DELTA <= 2"
+GAMMA = "Parameter gamma for overlapped speech penalty"
+BETA = "Parameter beta for overlapped speech penalty"
+MAX_SPEAKERS = "Maximum number of speakers"
+GPU = "Run on GPU"
+OUTPUT = "Directory to store the system's output in RTTM format"
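
These constants exist so that every script renders identical help text while keeping its own defaults, as the `benchmark.py` and `stream.py` diffs below show. A minimal standalone sketch of the pattern (the parser here is illustrative, not part of the PR):

```python
# Illustrative only: one canonical description string, reused via f-strings,
# so per-script defaults can differ without the docs drifting apart.
import argparse

import diart.argdoc as argdoc

parser = argparse.ArgumentParser()
parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
parser.add_argument("--latency", default=1.0, type=float, help=f"{argdoc.LATENCY}. Defaults to 1.0")
args = parser.parse_args(["--latency", "2.0"])
print(args.step, args.latency)  # 0.5 2.0
```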
27 changes: 14 additions & 13 deletions src/diart/benchmark.py
@@ -5,26 +5,27 @@
from pyannote.database.util import load_rttm
from pyannote.metrics.diarization import DiarizationErrorRate

+import diart.argdoc as argdoc
import diart.operators as dops
import diart.sources as src
from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig
from diart.sinks import RTTMWriter

# Define script arguments
parser = argparse.ArgumentParser()
parser.add_argument("root", type=str, help="Directory with audio files <conversation>.(wav|flac|m4a|...)")
parser.add_argument("--reference", type=str, help="Directory with RTTM files <conversation>.rttm")
parser.add_argument("--step", default=0.5, type=float, help="Source sliding window step in seconds. Defaults to 0.5")
parser.add_argument("--latency", default=0.5, type=float, help="System latency in seconds. Defaults to 0.5")
parser.add_argument("--tau", default=0.5, type=float, help="Activity threshold tau active in [0,1]. Defaults to 0.5")
parser.add_argument("--rho", default=0.3, type=float, help="Speech ratio threshold rho update in [0,1]. Defaults to 0.3")
parser.add_argument("--delta", default=1, type=float, help="Maximum distance threshold delta new in [0,2]. Defaults to 1")
parser.add_argument("--gamma", default=3, type=float, help="Parameter gamma for overlapped speech penalty. Defaults to 3")
parser.add_argument("--beta", default=10, type=float, help="Parameter beta for overlapped speech penalty. Defaults to 10")
parser.add_argument("--max-speakers", default=20, type=int, help="Maximum number of identifiable speakers. Defaults to 20")
parser.add_argument("--batch-size", default=32, type=int, help="For segmentation and embedding pre-calculation. If lower than 2, run fully online and estimate real-time latency. Defaults to 32")
parser.add_argument("--output", type=str, help="Output directory to store RTTM files. Defaults to `root`")
parser.add_argument("--gpu", dest="gpu", action="store_true", help="Add this flag to run on GPU")
parser.add_argument("root", type=str, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)")
parser.add_argument("--reference", type=str, help="Optional. Directory with RTTM files CONVERSATION.rttm. Names must match audio files")
parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5")
parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5")
parser.add_argument("--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3")
parser.add_argument("--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1")
parser.add_argument("--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3")
parser.add_argument("--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10")
parser.add_argument("--max-speakers", default=20, type=int, help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20")
parser.add_argument("--batch-size", default=32, type=int, help="For segmentation and embedding pre-calculation. If BATCH_SIZE < 2, run fully online and estimate real-time latency. Defaults to 32")
parser.add_argument("--gpu", dest="gpu", action="store_true", help=argdoc.GPU)
parser.add_argument("--output", type=str, help=f"{argdoc.OUTPUT}. Defaults to `root`")
args = parser.parse_args()

args.root = Path(args.root)
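
The rest of the script is collapsed in the diff. Judging by the imports at the top (`load_rttm`, `DiarizationErrorRate`), the evaluation step plausibly looks like the following sketch; the variable names and aggregation are assumptions:

```python
# Hedged sketch of the collapsed evaluation step: score each produced RTTM
# against its reference with pyannote's DiarizationErrorRate.
if args.reference is not None:
    metric = DiarizationErrorRate()
    output_dir = Path(args.output) if args.output else args.root
    for ref_rttm in Path(args.reference).glob("*.rttm"):
        # load_rttm returns a {uri: Annotation} mapping; take the annotation
        reference = list(load_rttm(ref_rttm).values())[0]
        hypothesis = list(load_rttm(output_dir / ref_rttm.name).values())[0]
        metric(reference, hypothesis)
    print(f"Aggregated DER: {100 * abs(metric):.1f}%")
```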
4 changes: 2 additions & 2 deletions src/diart/functional.py → src/diart/blocks.py
@@ -42,7 +42,7 @@ def resolve_features(features: TemporalFeatures) -> torch.Tensor:
return data.float()


-class FrameWiseModel:
+class FramewiseModel:
def __init__(self, model: PipelineModel, device: Optional[torch.device] = None):
self.model = get_model(model)
self.model.eval()
@@ -88,7 +88,7 @@ def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures:
return output


-class ChunkWiseModel:
+class ChunkwiseModel:
def __init__(self, model: PipelineModel, device: Optional[torch.device] = None):
self.model = get_model(model)
self.model.eval()
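
The renames above pair with the new README claim that `blocks` "can be used separately". A hedged sketch of standalone use, with call signatures taken from how `pipelines.py` below invokes these classes; the file path and the (samples, channels) input layout are assumptions:

```python
# Assumed standalone usage of the renamed blocks, outside any Rx pipeline.
import torchaudio

import diart.blocks as blocks

segmentation = blocks.FramewiseModel("pyannote/segmentation")
embedding = blocks.ChunkwiseModel("pyannote/embedding")

waveform, sample_rate = torchaudio.load("/path/to/audio.wav")  # hypothetical file
chunk = waveform.T  # torchaudio yields (channels, samples); assume (samples, channels) here

activations = segmentation(chunk)           # frame-wise speaker activations
embeddings = embedding(chunk, activations)  # one embedding per active speaker
```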
22 changes: 11 additions & 11 deletions src/diart/pipelines.py
@@ -11,7 +11,7 @@
from pyannote.core import SlidingWindowFeature, SlidingWindow
from tqdm import tqdm

-from . import functional as fn
+from . import blocks
from . import operators as dops
from . import sources as src

@@ -77,17 +77,17 @@ def from_model_streams(
if source_duration is not None:
end_time = self.config.last_chunk_end_time(source_duration)
# Initialize clustering and aggregation modules
-clustering = fn.OnlineSpeakerClustering(
+clustering = blocks.OnlineSpeakerClustering(
self.config.tau_active,
self.config.rho_update,
self.config.delta_new,
"cosine",
self.config.max_speakers,
)
-aggregation = fn.DelayedAggregation(
+aggregation = blocks.DelayedAggregation(
self.config.step, self.config.latency, strategy="hamming", stream_end=end_time
)
-binarize = fn.Binarize(uri, self.config.tau_active)
+binarize = blocks.Binarize(uri, self.config.tau_active)

# Join segmentation and embedding streams to update a background clustering model
# while regulating latency and binarizing the output
@@ -102,7 +102,7 @@
)
# Add corresponding waveform to the output
if audio_chunk_stream is not None:
-window_selector = fn.DelayedAggregation(
+window_selector = blocks.DelayedAggregation(
self.config.step, self.config.latency, strategy="first", stream_end=end_time
)
waveform_stream = audio_chunk_stream.pipe(
@@ -117,8 +117,8 @@
class OnlineSpeakerDiarization:
def __init__(self, config: PipelineConfig):
self.config = config
-self.segmentation = fn.FrameWiseModel(config.segmentation, self.config.device)
-self.embedding = fn.ChunkWiseModel(config.embedding, self.config.device)
+self.segmentation = blocks.FramewiseModel(config.segmentation, self.config.device)
+self.embedding = blocks.ChunkwiseModel(config.embedding, self.config.device)
self.speaker_tracking = OnlineSpeakerTracking(config)
msg = "Invalid latency requested"
assert self.config.step <= self.config.latency <= self.duration, msg
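
The assertion encodes the same constraint documented in `argdoc.LATENCY` (STEP <= LATENCY <= CHUNK_DURATION). A sketch of a config that satisfies it, assuming `PipelineConfig` keyword arguments mirror the config attributes read in this file:

```python
# Hedged sketch: with 5s chunks, step=0.5 and latency=3.0 satisfy
# step <= latency <= duration, so the assertion above passes.
from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig

config = PipelineConfig(step=0.5, latency=3.0, tau_active=0.5, rho_update=0.3, delta_new=1.0)
pipeline = OnlineSpeakerDiarization(config)  # would raise AssertionError otherwise
```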
@@ -152,11 +152,11 @@ def from_source(
# Branch the stream to calculate chunk segmentation
segmentation_stream = regular_stream.pipe(ops.map(self.segmentation))
# Join audio and segmentation stream to calculate speaker embeddings
-osp = fn.OverlappedSpeechPenalty(gamma=self.config.gamma, beta=self.config.beta)
+osp = blocks.OverlappedSpeechPenalty(gamma=self.config.gamma, beta=self.config.beta)
embedding_stream = rx.zip(regular_stream, segmentation_stream).pipe(
ops.starmap(lambda wave, seg: (wave, osp(seg))),
ops.starmap(self.embedding),
-ops.map(fn.EmbeddingNormalization(norm=1))
+ops.map(blocks.EmbeddingNormalization(norm=1))
)
chunk_stream = regular_stream if output_waveform else None
return self.speaker_tracking.from_model_streams(
@@ -177,8 +177,8 @@ def from_file(
)

# Initialize pipeline modules
-osp = fn.OverlappedSpeechPenalty(self.config.gamma, self.config.beta)
-emb_norm = fn.EmbeddingNormalization(norm=1)
+osp = blocks.OverlappedSpeechPenalty(self.config.gamma, self.config.beta)
+emb_norm = blocks.EmbeddingNormalization(norm=1)

# Split audio into chunks
chunks = rearrange(
25 changes: 11 additions & 14 deletions src/diart/demo.py → src/diart/stream.py
@@ -4,6 +4,7 @@
import rx.operators as ops
import torch

+import diart.argdoc as argdoc
import diart.operators as dops
import diart.sources as src
from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig
@@ -12,21 +13,17 @@
# Define script arguments
parser = argparse.ArgumentParser()
parser.add_argument("source", type=str, help="Path to an audio file | 'microphone'")
parser.add_argument("--step", default=0.5, type=float, help="Source sliding window step")
parser.add_argument("--latency", default=0.5, type=float, help="System latency")
parser.add_argument("--tau", default=0.5, type=float, help="Activity threshold tau active")
parser.add_argument("--rho", default=0.3, type=float, help="Speech duration threshold rho update")
parser.add_argument("--delta", default=1, type=float, help="Maximum distance threshold delta new")
parser.add_argument("--gamma", default=3, type=float, help="Parameter gamma for overlapped speech penalty")
parser.add_argument("--beta", default=10, type=float, help="Parameter beta for overlapped speech penalty")
parser.add_argument("--max-speakers", default=20, type=int, help="Maximum number of identifiable speakers")
parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5")
parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5")
parser.add_argument("--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3")
parser.add_argument("--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1")
parser.add_argument("--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3")
parser.add_argument("--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10")
parser.add_argument("--max-speakers", default=20, type=int, help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20")
parser.add_argument("--no-plot", dest="no_plot", action="store_true", help="Skip plotting for faster inference")
parser.add_argument("--gpu", dest="gpu", action="store_true", help="Add this flag to run on GPU")
parser.add_argument(
"--output", type=str,
help="Output directory to store the RTTM. Defaults to home directory "
"if source is microphone or parent directory if source is a file"
)
parser.add_argument("--gpu", dest="gpu", action="store_true", help=argdoc.GPU)
parser.add_argument("--output", type=str, help=f"{argdoc.OUTPUT}. Defaults to home directory if SOURCE == 'microphone' or parent directory if SOURCE is a file")
args = parser.parse_args()

# Define online speaker diarization pipeline
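
The pipeline construction itself is collapsed below this comment. Given the parsed arguments and the imports above, it plausibly reads like the following sketch (the `PipelineConfig` keyword names are assumed to mirror the attributes used in `pipelines.py`):

```python
# Hedged sketch of the collapsed pipeline construction in stream.py.
config = PipelineConfig(
    step=args.step,
    latency=args.latency,
    tau_active=args.tau,
    rho_update=args.rho,
    delta_new=args.delta,
    gamma=args.gamma,
    beta=args.beta,
    max_speakers=args.max_speakers,
    device=torch.device("cuda") if args.gpu else None,  # assumed device handling
)
pipeline = OnlineSpeakerDiarization(config)
```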