Skip to content

Commit

Permalink
Add sample rate config option to gr.Audio() (#6826)
Browse files Browse the repository at this point in the history
* Fix two trimming bugs: the audio sample rate dropped to 8000 Hz after trimming, and the volume was amplified (causing clipping) each time trimming was performed

* Fix format

* add changeset

* add sample_rate param to waveform_options

* add changeset

* set WaveformOptions defaults

* formatting

* formatting

* add changeset

* audio

* changes

* add changeset

* tweak sample rate logic + docstring

* Tweak docstring

* formatting

* linting

* type tweak

* remove redundant None check

* tweak waveform lifecycle

* fix test

---------

Co-authored-by: tsukumi <tsukumijima@users.noreply.github.com>
Co-authored-by: Hannah <hannahblair@users.noreply.github.com>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
  • Loading branch information
5 people committed Jan 22, 2024
1 parent 44c53d9 commit e8b2d8b
Show file tree
Hide file tree
Showing 9 changed files with 53 additions and 21 deletions.
6 changes: 6 additions & 0 deletions .changeset/weak-streets-check.md
@@ -0,0 +1,6 @@
---
"@gradio/audio": minor
"gradio": minor
---

fix:Add sample rate config option to `gr.Audio()`
11 changes: 6 additions & 5 deletions gradio/components/audio.py
Expand Up @@ -30,13 +30,15 @@ class WaveformOptions:
show_recording_waveform: Whether to show the waveform when recording audio. Defaults to True.
show_controls: Whether to show the standard HTML audio player below the waveform when recording audio or playing recorded audio. Defaults to False.
skip_length: The percentage (between 0 and 100) of the audio to skip when clicking on the skip forward / skip backward buttons. Defaults to 5.
sample_rate: The output sample rate (in Hz) of the audio after editing. Defaults to 44100.
"""

waveform_color: str = "#9ca3af"
waveform_progress_color: str = "#f97316"
show_recording_waveform: bool = True
show_controls: bool = False
skip_length: int | float = 5
sample_rate: int = 44100


@document()
Expand Down Expand Up @@ -161,11 +163,10 @@ def __init__(
self.editable = editable
if waveform_options is None:
self.waveform_options = WaveformOptions()
self.waveform_options = (
WaveformOptions(**waveform_options)
if isinstance(waveform_options, dict)
else waveform_options
)
elif isinstance(waveform_options, dict):
self.waveform_options = WaveformOptions(**waveform_options)
else:
self.waveform_options = waveform_options
self.min_length = min_length
self.max_length = max_length
super().__init__(
Expand Down
3 changes: 2 additions & 1 deletion js/audio/Index.svelte
Expand Up @@ -104,7 +104,8 @@
dragToSeek: true,
normalize: true,
minPxPerSec: 20,
mediaControls: waveform_options.show_controls
mediaControls: waveform_options.show_controls,
sampleRate: waveform_options.sample_rate || 44100
};
const trim_region_settings = {
Expand Down
17 changes: 10 additions & 7 deletions js/audio/player/AudioPlayer.svelte
Expand Up @@ -112,13 +112,16 @@
mode = "";
const decodedData = waveform?.getDecodedData();
if (decodedData)
await process_audio(decodedData, start, end).then(
async (trimmedBlob: Uint8Array) => {
await dispatch_blob([trimmedBlob], "change");
waveform?.destroy();
create_waveform();
}
);
await process_audio(
decodedData,
start,
end,
waveform_settings.sampleRate
).then(async (trimmedBlob: Uint8Array) => {
await dispatch_blob([trimmedBlob], "change");
waveform?.destroy();
container.innerHTML = "";
});
dispatch("edit");
};
Expand Down
4 changes: 3 additions & 1 deletion js/audio/recorder/AudioRecorder.svelte
Expand Up @@ -82,7 +82,9 @@
timing = false;
clearInterval(interval);
const array_buffer = await blob.arrayBuffer();
const context = new AudioContext();
const context = new AudioContext({
sampleRate: waveform_settings.sampleRate
});
const audio_buffer = await context.decodeAudioData(array_buffer);
if (audio_buffer)
Expand Down
4 changes: 3 additions & 1 deletion js/audio/shared/audioBufferToWav.ts
Expand Up @@ -47,7 +47,9 @@ export function audioBufferToWav(audioBuffer: AudioBuffer): Uint8Array {
for (let i = 0; i < audioBuffer.numberOfChannels; i++) {
const channel = audioBuffer.getChannelData(i);
for (let j = 0; j < channel.length; j++) {
view.setInt16(offset, channel[j] * 0xffff, true);
// Scaling Float32 to Int16
const sample = Math.max(-1, Math.min(1, channel[j]));
view.setInt16(offset, sample * 0x7fff, true);
offset += 2;
}
}
Expand Down
1 change: 1 addition & 0 deletions js/audio/shared/types.ts
Expand Up @@ -5,4 +5,5 @@ export type WaveformOptions = {
skip_length?: number;
trim_region_color?: string;
show_recording_waveform?: boolean;
sample_rate?: number;
};
10 changes: 6 additions & 4 deletions js/audio/shared/utils.ts
@@ -1,5 +1,4 @@
import type WaveSurfer from "wavesurfer.js";
import Regions from "wavesurfer.js/dist/plugins/regions.js";
import { audioBufferToWav } from "./audioBufferToWav";

export interface LoadedParams {
Expand All @@ -18,11 +17,14 @@ export function blob_to_data_url(blob: Blob): Promise<string> {
export const process_audio = async (
audioBuffer: AudioBuffer,
start?: number,
end?: number
end?: number,
waveform_sample_rate?: number
): Promise<Uint8Array> => {
const audioContext = new AudioContext();
const audioContext = new AudioContext({
sampleRate: waveform_sample_rate || audioBuffer.sampleRate
});
const numberOfChannels = audioBuffer.numberOfChannels;
const sampleRate = audioBuffer.sampleRate;
const sampleRate = waveform_sample_rate || audioBuffer.sampleRate;

let trimmedLength = audioBuffer.length;
let startOffset = 0;
Expand Down
18 changes: 16 additions & 2 deletions test/test_components.py
Expand Up @@ -836,7 +836,14 @@ def test_component_functions(self, gradio_temp_dir):
"streamable": False,
"max_length": None,
"min_length": None,
"waveform_options": None,
"waveform_options": {
"sample_rate": 44100,
"show_controls": False,
"show_recording_waveform": True,
"skip_length": 5,
"waveform_color": "#9ca3af",
"waveform_progress_color": "#f97316",
},
"_selectable": False,
}
assert audio_input.preprocess(None) is None
Expand Down Expand Up @@ -881,7 +888,14 @@ def test_component_functions(self, gradio_temp_dir):
"format": "wav",
"streamable": False,
"sources": ["upload", "microphone"],
"waveform_options": None,
"waveform_options": {
"sample_rate": 44100,
"show_controls": False,
"show_recording_waveform": True,
"skip_length": 5,
"waveform_color": "#9ca3af",
"waveform_progress_color": "#f97316",
},
"_selectable": False,
}

Expand Down

0 comments on commit e8b2d8b

Please sign in to comment.