
Add support for audio transcription pipelines in transformers #8464

Merged
merged 9 commits into mlflow:master on May 20, 2023

Conversation

BenWilson2 (Member)

Related Issues/PRs

#xxx

What changes are proposed in this pull request?

Add support for the AutomaticSpeechRecognitionPipeline type in mlflow.transformers

How is this patch tested?

  • Existing unit/integration tests
  • New unit/integration tests
  • Manual tests (describe details, including test results, below)

Manual testing in progress on 13.x runtime.

Does this PR change the documentation?

  • No. You can skip the rest of this section.
  • Yes. Make sure the changed pages / sections render correctly in the documentation preview.

Release Notes

Is this a user-facing change?

  • No. You can skip the rest of this section.
  • Yes. Give a description of this change to be included in the release notes for MLflow users.

Added support for AutomaticSpeechRecognitionPipelines (i.e., Whisper audio transcription) to the transformers flavor and added native support for the bytes type as input to pyfunc signature enforcement.

What component(s), interfaces, languages, and integrations does this PR affect?

Components

  • area/artifacts: Artifact stores and artifact logging
  • area/build: Build and test infrastructure for MLflow
  • area/docs: MLflow documentation pages
  • area/examples: Example code
  • area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
  • area/models: MLmodel format, model serialization/deserialization, flavors
  • area/recipes: Recipes, Recipe APIs, Recipe configs, Recipe Templates
  • area/projects: MLproject format, project running backends
  • area/scoring: MLflow Model server, model deployment tools, Spark UDFs
  • area/server-infra: MLflow Tracking server backend
  • area/tracking: Tracking Service, tracking client APIs, autologging

Interface

  • area/uiux: Front-end, user experience, plotting, JavaScript, JavaScript dev server
  • area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
  • area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
  • area/windows: Windows support

Language

  • language/r: R APIs and clients
  • language/java: Java APIs and clients
  • language/new: Proposals for new client languages

Integrations

  • integrations/azure: Azure and Azure ML integrations
  • integrations/sagemaker: SageMaker integrations
  • integrations/databricks: Databricks integrations

How should the PR be classified in the release notes? Choose one:

  • rn/breaking-change - The PR will be mentioned in the "Breaking Changes" section
  • rn/none - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
  • rn/feature - A new user-facing feature worth mentioning in the release notes
  • rn/bug-fix - A user-facing bug fix worth mentioning in the release notes
  • rn/documentation - A user-facing documentation change worth mentioning in the release notes

Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>
@pytest.fixture()
def sound_file_for_test():
    url = "https://www.nasa.gov/62282main_countdown_launch.wav"
    response = requests.get(url)
harupy (Member) commented May 19, 2023:

Suggested change
- response = requests.get(url)
+ response = requests.get(url)
+ response.raise_for_status()

to avoid encountering an unclear error when the request fails and the content attribute is not audio bytes.
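A minimal sketch of why the suggested guard matters: without `raise_for_status()`, a failed download silently yields non-audio bytes (typically an HTML error page). `FakeResponse` here is a hypothetical stand-in for `requests.Response`, used only so the failure mode can be shown without a network call.

```python
# FakeResponse is a hypothetical stand-in for requests.Response; the real
# raise_for_status() raises requests.HTTPError on 4xx/5xx status codes.
class FakeResponse:
    def __init__(self, status_code, content):
        self.status_code = status_code
        self.content = content

    def raise_for_status(self):
        # Mirror the requests behavior: raise on error status codes.
        if self.status_code >= 400:
            raise RuntimeError(f"HTTP {self.status_code}")


response = FakeResponse(404, b"<html>Not Found</html>")
try:
    response.raise_for_status()
    audio = response.content  # would otherwise be mistaken for audio bytes
except RuntimeError as exc:
    print("request failed early:", exc)
```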

BenWilson2 (Member Author):

Removed the requests logic; we're just going to load the static .wav file from a relative path resolved to the datasets folder.

@@ -316,6 +320,28 @@ def image_for_test():
    return dataset["test"]["image"][0]


@pytest.fixture()
def sound_file_for_test():
    url = "https://www.nasa.gov/62282main_countdown_launch.wav"
harupy (Member) commented May 19, 2023:

How large is this file? If it's small, can we include it in the repository?

BenWilson2 (Member Author):

It's a few MB. I'll add the raw bytes into tests/datasets and we can just use that instead of calling up NASA ;) Good call.
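The repository-local approach the thread settles on can be sketched like this; the helper name and the exact location of the static file are assumptions, since the merged fixture may be structured differently.

```python
import pathlib


def load_audio_bytes(path):
    """Read a static audio file from disk as raw bytes.

    Replaces the earlier requests-based download so tests never depend
    on an external server being reachable. The file is assumed to live
    under the repository's tests/datasets directory.
    """
    return pathlib.Path(path).read_bytes()
```

A fixture would then call this with a path built relative to the test module, e.g. `pathlib.Path(__file__).parent / "datasets" / "<file>.wav"`.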

@github-actions github-actions bot added area/models MLmodel format, model serialization/deserialization, flavors rn/feature Mention under Features in Changelogs. labels May 19, 2023
mlflow-automation (Collaborator) commented May 19, 2023:

Documentation preview for 7d023f2 will be available here when this CircleCI job completes successfully.


Comment on lines 2236 to 2237
encoded_sound_file = list(data[0].values())[0]
return decode_sound_file(encoded_sound_file)
Member:

Suggested change
- encoded_sound_file = list(data[0].values())[0]
- return decode_sound_file(encoded_sound_file)
+ encoded_audio = list(data[0].values())[0]
+ return decode_sound_file(encoded_audio)

Can we rename this variable because it's not a file?

BenWilson2 (Member Author):

great point. Changed!



# Acquire an audio file
audio_file = requests.get("https://www.nasa.gov/62283main_landing.wav").content
Member:

Suggested change
- audio_file = requests.get("https://www.nasa.gov/62283main_landing.wav").content
+ audio = requests.get("https://www.nasa.gov/62283main_landing.wav").content

nit

BenWilson2 (Member Author):

changed :)

    except binascii.Error:
        return False

def decode_sound_file(encoded):
Member:

Can we replace sound with audio for consistency?

BenWilson2 (Member Author):

refactored for consistency

Comment on lines 327 to 328
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_audio:
    tmp_audio.write(response.content)
harupy (Member) commented May 19, 2023:

Can we use pathlib.Path.write_bytes and the tmp_path fixture here so the temp file is deleted after running tests?

Suggested change
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_audio:
-     tmp_audio.write(response.content)
+ tmp_audio = tmp_path / "audio.wav"
+ tmp_audio.write_bytes(response.content)

BenWilson2 (Member Author):

Removed all of this logic and switched to pathlib.Path().read_bytes()

Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>
@@ -2443,6 +2444,8 @@ expected input to the model to ensure your inference request can be read properl
**** The mask syntax for the model that you've chosen is going to be specific to that model's implementation. Some are '[MASK]', while others are '<mask>'. Verify the expected syntax to avoid failed inference requests.

***** If using MLServer for realtime inference, a raw audio file in bytes format must be base64 encoded prior to submitting to the endpoint.
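The encoding step this note describes can be sketched with the standard library alone; the stub bytes below are illustrative, not a real .wav file, and the exact request payload format depends on the serving endpoint's documented schema.

```python
import base64

# Raw audio bytes as they would be read from a .wav file on disk.
# This is an illustrative stub, not a complete, playable file.
raw_audio = b"RIFF$\x00\x00\x00WAVEfmt "

# Base64-encode before submitting to a serving endpoint, since JSON
# request bodies cannot carry raw bytes directly.
encoded = base64.b64encode(raw_audio).decode("ascii")

# The server side reverses the transformation to recover the original bytes.
assert base64.b64decode(encoded) == raw_audio
```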
Collaborator:

Does this only apply to MLServer (https://pypi.org/project/mlserver-mlflow/ - Seldon), or does it apply to the MLflow Model Server in general?

BenWilson2 (Member Author):

uhhh whoops. Serving in general. Chalk this up to writing this docstring during a meeting in which someone brought up MLServer.

Comment on lines 44 to 52
print(transcription)

# Load the pipeline as a pyfunc with the audio file being encoded as base64
pyfunc_transcriber = mlflow.pyfunc.load_model(model_uri=model_info.model_uri)

pyfunc_transcription = pyfunc_transcriber.predict(base64.b64encode(audio).decode("ascii"))

# Note: if `return_timestamps` is set, the pyfunc return type is a JSON-encoded string.
print(pyfunc_transcription)
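Handling that JSON-encoded return value might look like the following; the payload shape and key names here are assumptions for illustration, not taken from the merged implementation.

```python
import json

# Hypothetical pyfunc output when `return_timestamps` is set: a
# JSON-encoded string rather than a plain transcription. The structure
# shown is an assumption for illustration only.
pyfunc_transcription = (
    '{"text": "10, 9, 8, liftoff", '
    '"chunks": [{"timestamp": [0.0, 2.5], "text": "10, 9, 8, liftoff"}]}'
)

# Callers decode the string before using the transcription or timestamps.
decoded = json.loads(pyfunc_transcription)
print("Pyfunc transcription:", decoded["text"])
```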
dbczumar (Collaborator) commented May 19, 2023:

Nit: Can we add some text in the print() before these transcriptions like "Whisper transcription" and "Pyfunc transcription"?

BenWilson2 (Member Author):

good idea!

Comment on lines 490 to 492
"accelerate",
"librosa",
"ffmpeg",
Collaborator:

If possible, can we add some inline comments explaining which libraries are required for which functionalities?

"`base64.b64encode(<audio data bytes>).decode('ascii')`"
) from e

if isinstance(data, list) and all(isinstance(element, dict) for element in data):
Collaborator:

Can we insert an inline comment displaying the structure of the data that falls into this case?
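A hedged illustration of the kind of inline comment being requested; the key name "audio" and the base64 payload are assumptions about how record-style input reaches this branch, not details from the merged code.

```python
import base64

# Data in this branch is assumed to arrive as a list of single-entry
# dicts, e.g. a pandas DataFrame converted to records:
#   [{"audio": "<base64-encoded audio string>"}]
# The key name is illustrative only.
data = [{"audio": base64.b64encode(b"raw audio bytes").decode("ascii")}]

if isinstance(data, list) and all(isinstance(element, dict) for element in data):
    # Pull the single value out of the first record and decode it back
    # to the raw bytes the pipeline expects.
    encoded_audio = list(data[0].values())[0]
    recovered = base64.b64decode(encoded_audio)
    print(recovered)
```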

dbczumar (Collaborator) left a comment:

LGTM! Thanks @BenWilson2 !

Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>
harupy (Member) left a comment:

LGTM!

@BenWilson2 BenWilson2 enabled auto-merge (squash) May 20, 2023 01:23
@BenWilson2 BenWilson2 merged commit 205babd into mlflow:master May 20, 2023
35 checks passed
@BenWilson2 BenWilson2 deleted the transformers-audio branch May 20, 2023 01:58
Labels
area/models MLmodel format, model serialization/deserialization, flavors rn/feature Mention under Features in Changelogs.