Add support for audio transcription pipelines in transformers #8464
Conversation
Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>
```python
@pytest.fixture()
def sound_file_for_test():
    url = "https://www.nasa.gov/62282main_countdown_launch.wav"
    response = requests.get(url)
```
Suggested change:

```python
response = requests.get(url)
response.raise_for_status()
```
to avoid encountering an unclear error when the request fails and the `content` attribute is not audio bytes.
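As a sketch of why `raise_for_status` helps here (the helper name and the injectable `session` parameter are illustrative, not part of the PR):

```python
import requests


def fetch_test_audio(url, session=None):
    # Accepting an injectable session makes the helper easy to stub in tests.
    http = session if session is not None else requests
    response = http.get(url)
    # Fail fast with a descriptive HTTPError instead of silently treating
    # an error page's bytes as if they were audio data.
    response.raise_for_status()
    return response.content
```

On a 404, this raises `requests.HTTPError` immediately rather than handing non-audio bytes to the decoder downstream.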
Removed the requests logic; we now resolve a relative path into the datasets folder and load the static .wav file from there.
```
@@ -316,6 +320,28 @@ def image_for_test():
     return dataset["test"]["image"][0]


+@pytest.fixture()
+def sound_file_for_test():
+    url = "https://www.nasa.gov/62282main_countdown_launch.wav"
```
How large is this file? If it's small, can we include it in the repository?
It's a few MB. I'll add the raw bytes into `tests/datasets` and we can just use that instead of calling up NASA ;) Good call.
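A minimal sketch of what that fixture could look like after the change (the filename and directory layout here are assumptions, not the exact paths used in the PR):

```python
import pathlib

import pytest


def load_audio_bytes(path):
    # pathlib handles open/read/close in a single call and returns raw bytes.
    return pathlib.Path(path).read_bytes()


@pytest.fixture()
def sound_file_for_test():
    # Resolve the static .wav relative to this test module so the fixture
    # works regardless of the current working directory.
    return load_audio_bytes(pathlib.Path(__file__).parent / "datasets" / "audio.wav")
```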
Documentation preview for 7d023f2 will be available here when this CircleCI job completes successfully.
mlflow/transformers.py
Outdated
```python
encoded_sound_file = list(data[0].values())[0]
return decode_sound_file(encoded_sound_file)
```
Suggested change:

```python
encoded_audio = list(data[0].values())[0]
return decode_sound_file(encoded_audio)
```
Can we rename this variable because it's not a file?
great point. Changed!
examples/transformers/whisper.py
Outdated
```python
# Acquire an audio file
audio_file = requests.get("https://www.nasa.gov/62283main_landing.wav").content
```
Suggested change:

```python
audio = requests.get("https://www.nasa.gov/62283main_landing.wav").content
```
nit
changed :)
mlflow/transformers.py
Outdated
```python
    except binascii.Error:
        return False


def decode_sound_file(encoded):
```
Can we replace `sound` with `audio` for consistency?
refactored for consistency
```python
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_audio:
    tmp_audio.write(response.content)
```
Can we use `pathlib.Path.write_bytes` and the `tmp_path` fixture here so the temp file is deleted after running tests?
Suggested change:

```python
tmp_audio = tmp_path / "audio.wav"
tmp_audio.write_bytes(response.content)
```
Removed all of this logic and switched to `pathlib.Path().read_bytes()`.
Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>
docs/source/models.rst
Outdated
```
@@ -2443,6 +2444,8 @@ expected input to the model to ensure your inference request can be read properl
 \**** The mask syntax for the model that you've chosen is going to be specific to that model's implementation. Some are '[MASK]', while others are '<mask>'. Verify the expected syntax to
 avoid failed inference requests.

+\***** If using MLServer for realtime inference, a raw audio file in bytes format must be base64 encoded prior to submitting to the endpoint.
```
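A hedged sketch of that client-side encoding step (the helper name is illustrative; the round trip just shows the bytes survive the JSON-safe conversion):

```python
import base64


def encode_audio_for_serving(audio_bytes):
    # JSON payloads cannot carry raw bytes, so the audio must first be
    # base64-encoded into an ASCII-safe string before hitting the endpoint.
    return base64.b64encode(audio_bytes).decode("ascii")


# Round trip: decoding the encoded payload recovers the original bytes.
payload = encode_audio_for_serving(b"\x00\x01fake-wav-bytes")
assert base64.b64decode(payload) == b"\x00\x01fake-wav-bytes"
```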
Does this only apply to MLServer (https://pypi.org/project/mlserver-mlflow/ - Seldon), or does it apply to the MLflow Model Server in general?
Uhhh, whoops. Serving in general. Chalk this up to writing this docstring while in a meeting in which someone brought up MLServer.
examples/transformers/whisper.py
Outdated
```python
print(transcription)

# Load the pipeline as a pyfunc with the audio file being encoded as base64
pyfunc_transcriber = mlflow.pyfunc.load_model(model_uri=model_info.model_uri)

pyfunc_transcription = pyfunc_transcriber.predict(base64.b64encode(audio).decode("ascii"))

# Note: the pyfunc return type if `return_timestamps` is set is a JSON encoded string.
print(pyfunc_transcription)
```
Nit: Can we add some text in the `print()` before these transcriptions, like "Whisper transcription" and "Pyfunc transcription"?
good idea!
mlflow/ml-package-versions.yml
Outdated
"accelerate", | ||
"librosa", | ||
"ffmpeg", |
If possible, can we add some inline comments explaining which libraries are required for which functionalities?
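A sketch of what those annotations might look like (the per-library attributions below are my best understanding and should be double-checked against how the tests actually use each package):

```
"accelerate",  # enables large-model loading and device placement for pipelines
"librosa",     # audio loading/resampling used by audio inference tests
"ffmpeg",      # decodes audio files into waveforms for speech pipelines
```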
"`base64.b64encode(<audio data bytes>).decode('ascii')`" | ||
) from e | ||
|
||
if isinstance(data, list) and all(isinstance(element, dict) for element in data): |
Can we insert an inline comment displaying the structure of the data that falls into this case?
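For illustration, one shape of `data` that satisfies this branch (the column name is hypothetical; it comes from however the pyfunc DataFrame input was converted to records):

```python
# A single-row, single-column DataFrame converted to records arrives as a
# list of one-key dicts, with the value holding the base64-encoded audio.
data = [{"audio_file": "UklGRiQAAABXQVZF"}]

if isinstance(data, list) and all(isinstance(element, dict) for element in data):
    encoded_audio = list(data[0].values())[0]

assert encoded_audio == "UklGRiQAAABXQVZF"
```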
LGTM! Thanks @BenWilson2 !
Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>
LGTM!
Related Issues/PRs
#xxx

What changes are proposed in this pull request?
Add support for the `AutomaticSpeechRecognitionPipeline` type in `mlflow.transformers`.
How is this patch tested?
Manual testing in progress on 13.x runtime.
Does this PR change the documentation?
Release Notes
Is this a user-facing change?
Added support for AutomaticSpeechRecognitionPipelines (i.e., Whisper audio transcription) to the transformers flavor and added native support for the `bytes` type as input to `pyfunc` signature enforcement.

What component(s), interfaces, languages, and integrations does this PR affect?
Components
- `area/artifacts`: Artifact stores and artifact logging
- `area/build`: Build and test infrastructure for MLflow
- `area/docs`: MLflow documentation pages
- `area/examples`: Example code
- `area/model-registry`: Model Registry service, APIs, and the fluent client calls for Model Registry
- `area/models`: MLmodel format, model serialization/deserialization, flavors
- `area/recipes`: Recipes, Recipe APIs, Recipe configs, Recipe Templates
- `area/projects`: MLproject format, project running backends
- `area/scoring`: MLflow Model server, model deployment tools, Spark UDFs
- `area/server-infra`: MLflow Tracking server backend
- `area/tracking`: Tracking Service, tracking client APIs, autologging

Interface
- `area/uiux`: Front-end, user experience, plotting, JavaScript, JavaScript dev server
- `area/docker`: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
- `area/sqlalchemy`: Use of SQLAlchemy in the Tracking Service or Model Registry
- `area/windows`: Windows support

Language
- `language/r`: R APIs and clients
- `language/java`: Java APIs and clients
- `language/new`: Proposals for new client languages

Integrations
- `integrations/azure`: Azure and Azure ML integrations
- `integrations/sagemaker`: SageMaker integrations
- `integrations/databricks`: Databricks integrations

How should the PR be classified in the release notes? Choose one:
- `rn/breaking-change` - The PR will be mentioned in the "Breaking Changes" section
- `rn/none` - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
- `rn/feature` - A new user-facing feature worth mentioning in the release notes
- `rn/bug-fix` - A user-facing bug fix worth mentioning in the release notes
- `rn/documentation` - A user-facing documentation change worth mentioning in the release notes