Skip to content

Commit

Permalink
feat: add wave data params check
Browse files: browse the repository at this point in the history
  • Loading branch information
nl8590687 committed May 26, 2022
1 parent 1508639 commit 80c2996
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 17 deletions.
89 changes: 73 additions & 16 deletions asrt_sdk/speech_recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,36 +98,68 @@ def __init__(self, host:str, port:str, protocol:str, sub_path:str=''):
raise Exception('Unsupport netword protocol `' + protocol +'`')
self._url_ = protocol + '://' + host + ':' + port
self.sub_path = sub_path
self._wav_data_max_length = 16000 * 2 * 16

def recognite(self, wav_data, frame_rate:int, channels:int, byte_width:int) -> AsrtApiResponse:
    '''
    Recognize a complete wav speech sequence as text.

    Posts the audio to the `/all` endpoint under `self._url_ + self.sub_path`
    and parses the JSON reply into an AsrtApiResponse.

    Args:
        wav_data: raw wave sample data; its length is checked against
            `self._wav_data_max_length`.
        frame_rate: sample rate passed through to the request body.
        channels: channel count passed through to the request body.
        byte_width: sample byte width passed through to the request body.

    Returns:
        The AsrtApiResponse filled from the server's JSON reply.

    Raises:
        Exception: if `wav_data` is too long, if sending the request fails
            or the server answers with a non-200 status code, or if the
            reply cannot be parsed.
    '''
    if len(wav_data) > self._wav_data_max_length:
        raise Exception('Too long wave sample byte length: `' + str(len(wav_data))
                        + "`, the max length is `" + str(self._wav_data_max_length) + "`")

    request_body = AsrtApiSpeechRequest(wav_data, frame_rate, channels, byte_width)
    headers = {'Content-Type': 'application/json'}
    http_request = get_http_session()
    try:
        response_object = http_request.post(self._url_ + self.sub_path + '/all',
                                            headers=headers,
                                            data=request_body.to_json())
        if response_object.status_code != 200:
            raise Exception("ASRT API server http statue code exception: " + str(response_object.status_code))
    except Exception as exception_info:
        # Chain the cause so the underlying traceback is preserved.
        raise Exception("Error to send speech recognition request to ASRT API server:" +
                        str(exception_info)) from exception_info

    try:
        response_body_dict = json.loads(response_object.text)
        response_body = AsrtApiResponse()
        response_body.from_json(**response_body_dict)
    except Exception as exception_info:
        raise Exception("Unormal data format is responsed by ASRT API server with HTTP protocol: " +
                        str(exception_info)) from exception_info
    return response_body


def recognite_speech(self, wav_data, frame_rate, channels, byte_width):
    '''
    Call the acoustic model to recognize a wav speech sequence as a pinyin sequence.

    Posts the audio to the `/speech` endpoint under
    `self._url_ + self.sub_path` and parses the JSON reply into an
    AsrtApiResponse.

    Args:
        wav_data: raw wave sample data; its length is checked against
            `self._wav_data_max_length`.
        frame_rate: sample rate passed through to the request body.
        channels: channel count passed through to the request body.
        byte_width: sample byte width passed through to the request body.

    Returns:
        The AsrtApiResponse filled from the server's JSON reply.

    Raises:
        Exception: if `wav_data` is too long, if sending the request fails
            or the server answers with a non-200 status code, or if the
            reply cannot be parsed.
    '''
    if len(wav_data) > self._wav_data_max_length:
        raise Exception('Too long wave sample byte length: `' + str(len(wav_data))
                        + "`, the max length is `" + str(self._wav_data_max_length) + "`")

    request_body = AsrtApiSpeechRequest(wav_data, frame_rate, channels, byte_width)
    headers = {'Content-Type': 'application/json'}
    http_request = get_http_session()
    try:
        response_object = http_request.post(self._url_ + self.sub_path + '/speech',
                                            headers=headers,
                                            data=request_body.to_json())
        if response_object.status_code != 200:
            raise Exception("ASRT API server http statue code exception: " + str(response_object.status_code))
    except Exception as exception_info:
        # Chain the cause so the underlying traceback is preserved.
        raise Exception("Error to send speech recognition request to ASRT API server:" +
                        str(exception_info)) from exception_info

    try:
        response_body_dict = json.loads(response_object.text)
        response_body = AsrtApiResponse()
        response_body.from_json(**response_body_dict)
    except Exception as exception_info:
        raise Exception("Unormal data format is responsed by ASRT API server with HTTP protocol: " +
                        str(exception_info)) from exception_info
    return response_body

def recognite_language(self, sequence_pinyin):
    '''
    Send a pinyin sequence to the `/language` endpoint and return the parsed reply.

    NOTE(review): presumably the server's language model converts the pinyin
    sequence to text; confirm against the ASRT API server documentation.

    Args:
        sequence_pinyin: the pinyin sequence wrapped into an
            AsrtApiLanguageRequest.

    Returns:
        The AsrtApiResponse filled from the server's JSON reply.

    Raises:
        Exception: if sending the request fails or the server answers with
            a non-200 status code, or if the reply cannot be parsed.
    '''
    request_body = AsrtApiLanguageRequest(sequence_pinyin)
    headers = {'Content-Type': 'application/json'}
    http_request = get_http_session()
    try:
        response_object = http_request.post(self._url_ + self.sub_path + '/language',
                                            headers=headers,
                                            data=request_body.to_json())
        if response_object.status_code != 200:
            raise Exception("ASRT API server http statue code exception: " + str(response_object.status_code))
    except Exception as exception_info:
        # Chain the cause so the underlying traceback is preserved.
        raise Exception("Error to send speech recognition request to ASRT API server:" +
                        str(exception_info)) from exception_info

    try:
        response_body_dict = json.loads(response_object.text)
        response_body = AsrtApiResponse()
        response_body.from_json(**response_body_dict)
    except Exception as exception_info:
        raise Exception("Unormal data format is responsed by ASRT API server with HTTP protocol: " +
                        str(exception_info)) from exception_info
    return response_body

def recognite_file(self, filename):
    '''
    Recognize the speech in a wav file, one segment at a time.

    The file is read with read_wav_datas and validated, then its sample
    data is cut into consecutive fixed-size chunks, each of which is sent
    to self.recognite.

    Args:
        filename: path of the wav file handed to read_wav_datas.

    Returns:
        A list holding one self.recognite result per chunk, in file order.
        The list is empty when the file contains no sample data.

    Raises:
        Exception: if the sample rate is not 16000, the channel count is
            not 1, or the byte width is not 2.
    '''
    wave_data = read_wav_datas(filename)
    str_data = wave_data.str_data
    frame_rate = wave_data.sample_rate
    if frame_rate != 16000:
        raise Exception('Unsupport wave sample rate `' + str(frame_rate) +'`')

    channels = wave_data.channels
    if channels != 1:
        raise Exception('Unsupport wave channels number `' + str(channels) +'`')

    byte_width = wave_data.byte_width
    if byte_width != 2:
        raise Exception('Unsupport wave byte width `' + str(byte_width) +'`')

    # 2 bytes per sample * 16000 samples per second * 10 = ten seconds of
    # the validated mono audio per request.
    duration = 2*16000*10
    # Stepping by the chunk size never yields an empty trailing chunk, even
    # when the data length is an exact multiple of `duration`.
    return [
        self.recognite(wav_data=str_data[start:start + duration],
                       frame_rate=frame_rate,
                       channels=channels,
                       byte_width=byte_width)
        for start in range(0, len(str_data), duration)
    ]
4 changes: 3 additions & 1 deletion client_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
FILENAME = 'A11_0.wav'
result = speech_recognizer.recognite_file(FILENAME)
print(result)
# recognite_file returns a list with one response per audio segment, so
# print each segment's result in turn.
for index, item in enumerate(result):
    print("第", index, "段:", item.result)


wave_data = asrt_sdk.read_wav_datas(FILENAME)
Expand Down

0 comments on commit 80c2996

Please sign in to comment.