diff --git a/python/AzureSpeechDetection/acs_speech_component/acs_speech_processor.py b/python/AzureSpeechDetection/acs_speech_component/acs_speech_processor.py index 6e310745..bda7d24b 100644 --- a/python/AzureSpeechDetection/acs_speech_component/acs_speech_processor.py +++ b/python/AzureSpeechDetection/acs_speech_component/acs_speech_processor.py @@ -323,7 +323,7 @@ def process_audio(self, job_config: AzureJobConfig) -> List[mpf.AudioTrack]: code = error_info.get('code') msg = error_info.get('message') raise mpf.DetectionError.DETECTION_FAILED.exception( - f'Transcripton failed with code "{code}" and message "{msg}".') + f'Transcription failed with code "{code}" and message "{msg}".') transcription = self.acs.get_transcription(result) logger.info('Speech-to-text processing complete') diff --git a/python/AzureSpeechDetection/acs_speech_component/azure_connect.py b/python/AzureSpeechDetection/acs_speech_component/azure_connect.py index 7e1c5bbd..d0f4c1a2 100644 --- a/python/AzureSpeechDetection/acs_speech_component/azure_connect.py +++ b/python/AzureSpeechDetection/acs_speech_component/azure_connect.py @@ -66,6 +66,7 @@ class AcsServerInfo(NamedTuple): blob_service_key: str http_retry: mpf_util.HttpRetry http_max_attempts: int + use_sas_auth: bool class AzureConnection(object): @@ -78,6 +79,7 @@ def __init__(self): self.container_client = None self.http_retry = None self.http_max_attempts = None + self.use_sas_auth = False def update_acs(self, server_info: AcsServerInfo): self.url = server_info.url @@ -88,6 +90,7 @@ def update_acs(self, server_info: AcsServerInfo): 'Content-Type': 'application/json', } self.http_retry = server_info.http_retry + self.use_sas_auth = server_info.use_sas_auth logger.info('Retrieving valid transcription locales') req = request.Request( @@ -144,11 +147,10 @@ def upload_file_to_blob(self, audio_bytes, recording_id, blob_access_time): 'Uploading file to blob failed due to: ' + str(e)) from e time_limit = timedelta(minutes=blob_access_time) - return '{url:s}/{recording_id:s}?{sas_url:s}'.format( - url=self.container_client.url, - recording_id=recording_id, - sas_url=self.generate_account_sas(time_limit) - ) + blob_url = f'{self.container_client.url}/{recording_id}' + if self.use_sas_auth: + blob_url += '?' + self.generate_account_sas(time_limit) + return blob_url def submit_batch_transcription(self, recording_url, job_name, diarize, language, expiry): diff --git a/python/AzureSpeechDetection/acs_speech_component/job_parsing.py b/python/AzureSpeechDetection/acs_speech_component/job_parsing.py index 140c9dc5..7d6517c5 100644 --- a/python/AzureSpeechDetection/acs_speech_component/job_parsing.py +++ b/python/AzureSpeechDetection/acs_speech_component/job_parsing.py @@ -82,6 +82,7 @@ def _add_job_properties(self, job_properties: Mapping[str, str]): acs_blob_container_url = self._get_job_property_or_env_value('ACS_BLOB_CONTAINER_URL', job_properties) acs_blob_service_key = self._get_job_property_or_env_value('ACS_BLOB_SERVICE_KEY', job_properties) http_retry = mpf_util.HttpRetry.from_properties(job_properties) + use_sas_auth = mpf_util.get_property(job_properties, 'USE_SAS_AUTH', False) http_max_attempts = mpf_util.get_property( properties=job_properties, @@ -96,7 +97,8 @@ def _add_job_properties(self, job_properties: Mapping[str, str]): blob_container_url=acs_blob_container_url, blob_service_key=acs_blob_service_key, http_retry=http_retry, - http_max_attempts=http_max_attempts + http_max_attempts=http_max_attempts, + use_sas_auth=use_sas_auth ) self.blob_access_time = mpf_util.get_property( diff --git a/python/AzureSpeechDetection/plugin-files/descriptor/descriptor.json b/python/AzureSpeechDetection/plugin-files/descriptor/descriptor.json index c1026990..2bde0933 100644 --- a/python/AzureSpeechDetection/plugin-files/descriptor/descriptor.json +++ b/python/AzureSpeechDetection/plugin-files/descriptor/descriptor.json @@ -91,6 +91,12 @@ "description": "Trigger condition for this component when used as a downstream action, of the form `TRACK_PROPERTY=VALUE`. Example: 'SPEECH_DETECTOR=AZURESPEECH' will result in this action only being run if the feed-forward track has the property SPEECH_DETECTOR with the value AZURESPEECH.", "type": "STRING", "defaultValue": "" + }, + { + "name": "USE_SAS_AUTH", + "description": "When true, a shared access signature (SAS) will be appended the recording URL included in the transcription request.", + "type": "BOOLEAN", + "defaultValue": "FALSE" } ] } diff --git a/python/AzureSpeechDetection/tests/test_acs_speech.py b/python/AzureSpeechDetection/tests/test_acs_speech.py index 96ab47ed..99fb3586 100644 --- a/python/AzureSpeechDetection/tests/test_acs_speech.py +++ b/python/AzureSpeechDetection/tests/test_acs_speech.py @@ -29,7 +29,7 @@ import os import json import unittest -from threading import Thread +import threading from typing import ClassVar from http.server import HTTPServer, SimpleHTTPRequestHandler from unittest.mock import patch @@ -53,6 +53,7 @@ outputs_url = url_prefix + 'outputs' models_url = url_prefix + 'models' container_url = 'https://account_name.blob.core.endpoint.suffix/container_name' +account_sas = '[sas_url]' def get_test_properties(**extra_properties): @@ -83,6 +84,7 @@ def tearDownClass(cls): def tearDown(self): self.mock_server.jobs = set() + self.mock_server.sas_enabled = False @staticmethod def run_patched_jobs(comp, mode, *jobs): @@ -98,10 +100,11 @@ def run_patched_jobs(comp, mode, *jobs): patch.object( comp.processor.acs, 'generate_account_sas', - return_value="[sas_url]"): + return_value=account_sas): return list(map(detection_func, jobs)) def test_audio_file(self): + self.mock_server.sas_enabled = True job = mpf.AudioJob( job_name='test_audio', data_uri=self._get_test_file('left.wav'), @@ -109,7 +112,8 @@ def test_audio_file(self): stop_time=-1, job_properties=get_test_properties( DIARIZE='FALSE', - LANGUAGE='en-US' + LANGUAGE='en-US', + USE_SAS_AUTH='TRUE' ), media_properties={}, feed_forward_track=None @@ -274,7 +278,19 @@ def __init__(self, base_local_path, base_url_path, server_address): self.base_local_path = base_local_path self.base_url_path = base_url_path self.jobs = set() - Thread(target=self.serve_forever).start() + self._sas_enabled = False + self.sas_lock = threading.Lock() + threading.Thread(target=self.serve_forever).start() + + @property + def sas_enabled(self): + with self.sas_lock: + return self._sas_enabled + + @sas_enabled.setter + def sas_enabled(self, value): + with self.sas_lock: + self._sas_enabled = value class MockRequestHandler(SimpleHTTPRequestHandler): @@ -382,7 +398,7 @@ def do_DELETE(self): 'The resource you are looking for has been removed, had its ' 'name changed, or is temporarily unavailable.'.encode() ) - return + def do_POST(self): self._validate_headers() @@ -423,6 +439,13 @@ def do_POST(self): ): raise Exception('Expected wordLevelTimestampsEnabled') + recording_url = data.get('contentUrls')[0] + if self.server.sas_enabled: + if not recording_url.endswith('?' + account_sas): + raise Exception('SAS enabled, but sas not found in recording URL.') + elif '?' in recording_url: + raise Exception('SAS disabled, but sas found in recording URL.') + self.send_response(202) location = os.path.join( @@ -433,7 +456,6 @@ def do_POST(self): self.server.jobs.add(data.get('displayName')) self.send_header('Location', origin + location) self.end_headers() - return if __name__ == '__main__':