Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def process_audio(self, job_config: AzureJobConfig) -> List[mpf.AudioTrack]:
code = error_info.get('code')
msg = error_info.get('message')
raise mpf.DetectionError.DETECTION_FAILED.exception(
f'Transcripton failed with code "{code}" and message "{msg}".')
f'Transcription failed with code "{code}" and message "{msg}".')

transcription = self.acs.get_transcription(result)
logger.info('Speech-to-text processing complete')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class AcsServerInfo(NamedTuple):
blob_service_key: str
http_retry: mpf_util.HttpRetry
http_max_attempts: int
use_sas_auth: bool


class AzureConnection(object):
Expand All @@ -78,6 +79,7 @@ def __init__(self):
self.container_client = None
self.http_retry = None
self.http_max_attempts = None
self.use_sas_auth = False

def update_acs(self, server_info: AcsServerInfo):
self.url = server_info.url
Expand All @@ -88,6 +90,7 @@ def update_acs(self, server_info: AcsServerInfo):
'Content-Type': 'application/json',
}
self.http_retry = server_info.http_retry
self.use_sas_auth = server_info.use_sas_auth

logger.info('Retrieving valid transcription locales')
req = request.Request(
Expand Down Expand Up @@ -144,11 +147,10 @@ def upload_file_to_blob(self, audio_bytes, recording_id, blob_access_time):
'Uploading file to blob failed due to: ' + str(e)) from e

time_limit = timedelta(minutes=blob_access_time)
return '{url:s}/{recording_id:s}?{sas_url:s}'.format(
url=self.container_client.url,
recording_id=recording_id,
sas_url=self.generate_account_sas(time_limit)
)
blob_url = f'{self.container_client.url}/{recording_id}'
if self.use_sas_auth:
blob_url += '?' + self.generate_account_sas(time_limit)
return blob_url

def submit_batch_transcription(self, recording_url, job_name,
diarize, language, expiry):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def _add_job_properties(self, job_properties: Mapping[str, str]):
acs_blob_container_url = self._get_job_property_or_env_value('ACS_BLOB_CONTAINER_URL', job_properties)
acs_blob_service_key = self._get_job_property_or_env_value('ACS_BLOB_SERVICE_KEY', job_properties)
http_retry = mpf_util.HttpRetry.from_properties(job_properties)
use_sas_auth = mpf_util.get_property(job_properties, 'USE_SAS_AUTH', False)

http_max_attempts = mpf_util.get_property(
properties=job_properties,
Expand All @@ -96,7 +97,8 @@ def _add_job_properties(self, job_properties: Mapping[str, str]):
blob_container_url=acs_blob_container_url,
blob_service_key=acs_blob_service_key,
http_retry=http_retry,
http_max_attempts=http_max_attempts
http_max_attempts=http_max_attempts,
use_sas_auth=use_sas_auth
)

self.blob_access_time = mpf_util.get_property(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@
"description": "Trigger condition for this component when used as a downstream action, of the form `TRACK_PROPERTY=VALUE`. Example: 'SPEECH_DETECTOR=AZURESPEECH' will result in this action only being run if the feed-forward track has the property SPEECH_DETECTOR with the value AZURESPEECH.",
"type": "STRING",
"defaultValue": ""
},
{
"name": "USE_SAS_AUTH",
"description": "When true, a shared access signature (SAS) will be appended to the recording URL included in the transcription request.",
"type": "BOOLEAN",
"defaultValue": "FALSE"
}
]
}
Expand Down
34 changes: 28 additions & 6 deletions python/AzureSpeechDetection/tests/test_acs_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import os
import json
import unittest
from threading import Thread
import threading
from typing import ClassVar
from http.server import HTTPServer, SimpleHTTPRequestHandler
from unittest.mock import patch
Expand All @@ -53,6 +53,7 @@
outputs_url = url_prefix + 'outputs'
models_url = url_prefix + 'models'
container_url = 'https://account_name.blob.core.endpoint.suffix/container_name'
account_sas = '[sas_url]'


def get_test_properties(**extra_properties):
Expand Down Expand Up @@ -83,6 +84,7 @@ def tearDownClass(cls):

def tearDown(self):
    """Reset the shared mock server's per-test state after each test."""
    server = self.mock_server
    # The two resets are independent: turn the SAS flag back off and
    # replace the recorded job names with a fresh, empty set.
    server.sas_enabled = False
    server.jobs = set()

@staticmethod
def run_patched_jobs(comp, mode, *jobs):
Expand All @@ -98,18 +100,20 @@ def run_patched_jobs(comp, mode, *jobs):
patch.object(
comp.processor.acs,
'generate_account_sas',
return_value="[sas_url]"):
return_value=account_sas):
return list(map(detection_func, jobs))

def test_audio_file(self):
self.mock_server.sas_enabled = True
job = mpf.AudioJob(
job_name='test_audio',
data_uri=self._get_test_file('left.wav'),
start_time=0,
stop_time=-1,
job_properties=get_test_properties(
DIARIZE='FALSE',
LANGUAGE='en-US'
LANGUAGE='en-US',
USE_SAS_AUTH='TRUE'
),
media_properties={},
feed_forward_track=None
Expand Down Expand Up @@ -274,7 +278,19 @@ def __init__(self, base_local_path, base_url_path, server_address):
self.base_local_path = base_local_path
self.base_url_path = base_url_path
self.jobs = set()
Thread(target=self.serve_forever).start()
self._sas_enabled = False
self.sas_lock = threading.Lock()
threading.Thread(target=self.serve_forever).start()

@property
def sas_enabled(self):
    """Return the stored SAS flag, read while holding ``sas_lock``."""
    self.sas_lock.acquire()
    try:
        return self._sas_enabled
    finally:
        # Always release, even if the attribute read raises.
        self.sas_lock.release()


@sas_enabled.setter
def sas_enabled(self, value):
    """Store ``value`` as the SAS flag while holding ``sas_lock``."""
    self.sas_lock.acquire()
    try:
        self._sas_enabled = value
    finally:
        self.sas_lock.release()


class MockRequestHandler(SimpleHTTPRequestHandler):
Expand Down Expand Up @@ -382,7 +398,7 @@ def do_DELETE(self):
'The resource you are looking for has been removed, had its '
'name changed, or is temporarily unavailable.'.encode()
)
return


def do_POST(self):
self._validate_headers()
Expand Down Expand Up @@ -423,6 +439,13 @@ def do_POST(self):
):
raise Exception('Expected wordLevelTimestampsEnabled')

recording_url = data.get('contentUrls')[0]
if self.server.sas_enabled:
if not recording_url.endswith('?' + account_sas):
raise Exception('SAS enabled, but sas not found in recording URL.')
elif '?' in recording_url:
raise Exception('SAS disabled, but sas found in recording URL.')

self.send_response(202)

location = os.path.join(
Expand All @@ -433,7 +456,6 @@ def do_POST(self):
self.server.jobs.add(data.get('displayName'))
self.send_header('Location', origin + location)
self.end_headers()
return


if __name__ == '__main__':
Expand Down