Skip to content

Commit

Permalink
Fix sample rate issues (#153)
Browse files Browse the repository at this point in the history
* Add automatic sample rate detection in MicrophoneAudioSource. Fix resampling crash.

* Replace block_size by block_duration in audio source constructors
  • Loading branch information
juanmc2005 committed Oct 28, 2023
1 parent d43cacf commit 8299b70
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 38 deletions.
9 changes: 4 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ from diart.inference import StreamingInference
from diart.sinks import RTTMWriter

pipeline = SpeakerDiarization()
mic = MicrophoneAudioSource(pipeline.config.sample_rate)
mic = MicrophoneAudioSource()
inference = StreamingInference(pipeline, mic, do_plot=True)
inference.attach_observers(RTTMWriter(mic.uri, "/output/file.rttm"))
prediction = inference()
Expand Down Expand Up @@ -167,7 +167,7 @@ config = SpeakerDiarizationConfig(
embedding=MyEmbeddingModel()
)
pipeline = SpeakerDiarization(config)
mic = MicrophoneAudioSource(config.sample_rate)
mic = MicrophoneAudioSource()
inference = StreamingInference(pipeline, mic)
prediction = inference()
```
Expand Down Expand Up @@ -241,12 +241,11 @@ from diart.blocks import SpeakerSegmentation, OverlapAwareSpeakerEmbedding

segmentation = SpeakerSegmentation.from_pyannote("pyannote/segmentation")
embedding = OverlapAwareSpeakerEmbedding.from_pyannote("pyannote/embedding")
sample_rate = segmentation.model.sample_rate
mic = MicrophoneAudioSource(sample_rate)
mic = MicrophoneAudioSource()

stream = mic.stream.pipe(
# Reformat stream to 5s duration and 500ms shift
dops.rearrange_audio_stream(sample_rate=sample_rate),
dops.rearrange_audio_stream(sample_rate=segmentation.model.sample_rate),
ops.map(lambda wav: (wav, segmentation(wav))),
ops.starmap(embedding)
).subscribe(on_next=lambda emb: print(emb.shape))
Expand Down
3 changes: 0 additions & 3 deletions src/diart/blocks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,6 @@ def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]:
left = utils.get_padding_left(file_duration + right, self.duration)
return left, right

def optimal_block_size(self) -> int:
return int(np.rint(self.step * self.sample_rate))


class Pipeline(ABC):
@staticmethod
Expand Down
9 changes: 6 additions & 3 deletions src/diart/blocks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,15 @@ class Resample:
resample_rate: int
Sample rate of the output
"""
def __init__(self, sample_rate: int, resample_rate: int):
self.resample = T.Resample(sample_rate, resample_rate)
def __init__(self, sample_rate: int, resample_rate: int, device: Optional[torch.device] = None):
self.device = device
if self.device is None:
self.device = torch.device("cpu")
self.resample = T.Resample(sample_rate, resample_rate).to(self.device)
self.formatter = TemporalFeatureFormatter()

def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures:
wav = self.formatter.cast(waveform) # shape (batch, samples, 1)
wav = self.formatter.cast(waveform).to(self.device) # shape (batch, samples, 1)
with torch.no_grad():
resampled_wav = self.resample(wav.transpose(-1, -2)).transpose(-1, -2)
return self.formatter.restore_type(resampled_wav)
Expand Down
5 changes: 2 additions & 3 deletions src/diart/console/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,12 @@

def send_audio(ws: WebSocket, source: Text, step: float, sample_rate: int):
# Create audio source
block_size = int(np.rint(step * sample_rate))
source_components = source.split(":")
if source_components[0] != "microphone":
audio_source = src.FileAudioSource(source, sample_rate)
audio_source = src.FileAudioSource(source, sample_rate, block_duration=step)
else:
device = int(source_components[1]) if len(source_components) > 1 else None
audio_source = src.MicrophoneAudioSource(sample_rate, block_size, device)
audio_source = src.MicrophoneAudioSource(step, device)

# Encode audio, then send through websocket
audio_source.stream.pipe(
Expand Down
5 changes: 2 additions & 3 deletions src/diart/console/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,17 @@ def run():
pipeline = pipeline_class(config)

# Manage audio source
block_size = config.optimal_block_size()
source_components = args.source.split(":")
if source_components[0] != "microphone":
args.source = Path(args.source).expanduser()
args.output = args.source.parent if args.output is None else Path(args.output)
padding = config.get_file_padding(args.source)
audio_source = src.FileAudioSource(args.source, config.sample_rate, padding, block_size)
audio_source = src.FileAudioSource(args.source, config.sample_rate, padding, config.step)
pipeline.set_timestamp_shift(-padding[0])
else:
args.output = Path("~/").expanduser() if args.output is None else Path(args.output)
device = int(source_components[1]) if len(source_components) > 1 else None
audio_source = src.MicrophoneAudioSource(config.sample_rate, block_size, device)
audio_source = src.MicrophoneAudioSource(config.step, device)

# Run online inference
inference = StreamingInference(
Expand Down
12 changes: 8 additions & 4 deletions src/diart/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,18 +97,22 @@ def __init__(

self.stream = self.source.stream

# Rearrange stream to form sliding windows
self.stream = self.stream.pipe(
dops.rearrange_audio_stream(chunk_duration, step_duration, source.sample_rate),
)

# Dynamic resampling if the audio source isn't compatible
if sample_rate != self.source.sample_rate:
msg = f"Audio source has sample rate {self.source.sample_rate}, " \
f"but pipeline's is {sample_rate}. Will resample."
logging.warning(msg)
self.stream = self.stream.pipe(
ops.map(blocks.Resample(self.source.sample_rate, sample_rate))
ops.map(blocks.Resample(self.source.sample_rate, sample_rate, self.pipeline.config.device))
)

# Add rx operators to manage the inputs and outputs of the pipeline
# Form batches
self.stream = self.stream.pipe(
dops.rearrange_audio_stream(chunk_duration, step_duration, sample_rate),
ops.buffer_with_count(count=self.batch_size),
)

Expand Down Expand Up @@ -316,7 +320,7 @@ def run_single(
filepath,
pipeline.config.sample_rate,
padding,
pipeline.config.optimal_block_size(),
pipeline.config.step,
)
pipeline.set_timestamp_shift(-padding[0])
inference = StreamingInference(
Expand Down
44 changes: 27 additions & 17 deletions src/diart/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,23 +58,23 @@ class FileAudioSource(AudioSource):
padding: (float, float)
Left and right padding to add to the file (in seconds).
Defaults to (0, 0).
block_size: int
Number of samples per chunk emitted.
Defaults to 1000.
block_duration: float
Duration of each emitted chunk in seconds.
Defaults to 0.5 seconds.
"""
def __init__(
self,
file: FilePath,
sample_rate: int,
padding: Tuple[float, float] = (0, 0),
block_size: int = 1000,
block_duration: float = 0.5,
):
super().__init__(Path(file).stem, sample_rate)
self.loader = AudioLoader(self.sample_rate, mono=True)
self._duration = self.loader.get_duration(file)
self.file = file
self.resolution = 1 / self.sample_rate
self.block_size = block_size
self.block_size = int(np.rint(block_duration * self.sample_rate))
self.padding_start, self.padding_end = padding
self.is_closed = False

Expand Down Expand Up @@ -134,11 +134,9 @@ class MicrophoneAudioSource(AudioSource):
Parameters
----------
sample_rate: int
Sample rate for the emitted audio chunks.
block_size: int
Number of samples per chunk emitted.
Defaults to 1000.
block_duration: float
Duration of each emitted chunk in seconds.
Defaults to 0.5 seconds.
device: int | str | (int, str) | None
Device identifier compatible for the sounddevice stream.
If None, use the default device.
Expand All @@ -147,15 +145,27 @@ class MicrophoneAudioSource(AudioSource):

def __init__(
self,
sample_rate: int,
block_size: int = 1000,
block_duration: float = 0.5,
device: Optional[Union[int, Text, Tuple[int, Text]]] = None,
):
super().__init__("live_recording", sample_rate)
self.block_size = block_size
# Use the lowest supported sample rate
sample_rates = [16000, 32000, 44100, 48000]
best_sample_rate = None
for sr in sample_rates:
try:
sd.check_input_settings(device=device, samplerate=sr)
except Exception:
pass
else:
best_sample_rate = sr
break
super().__init__(f"input_device:{device}", best_sample_rate)

# Determine block size in samples and create input stream
self.block_size = int(np.rint(block_duration * self.sample_rate))
self._mic_stream = sd.InputStream(
channels=1,
samplerate=sample_rate,
samplerate=self.sample_rate,
latency=0,
blocksize=self.block_size,
callback=self._read_callback,
Expand Down Expand Up @@ -261,10 +271,10 @@ def __init__(
sample_rate: int,
streamer: StreamReader,
stream_index: Optional[int] = None,
block_size: int = 1000,
block_duration: float = 0.5,
):
super().__init__(uri, sample_rate)
self.block_size = block_size
self.block_size = int(np.rint(block_duration * self.sample_rate))
self._streamer = streamer
self._streamer.add_basic_audio_stream(
frames_per_chunk=self.block_size,
Expand Down

0 comments on commit 8299b70

Please sign in to comment.