Skip to content
Merged
179 changes: 165 additions & 14 deletions backend/python/qwen-asr/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,156 @@ def LoadModel(self, request, context):

return backend_pb2.Result(message="Model loaded successfully", success=True)

@staticmethod
def _is_cjk(ch):
"""Check if a character is CJK (Chinese/Japanese/Korean)."""
cp = ord(ch)
return (
0x4E00 <= cp <= 0x9FFF # CJK Unified Ideographs
or 0x3400 <= cp <= 0x4DBF # Extension A
or 0x20000 <= cp <= 0x2A6DF # Extension B
or 0xF900 <= cp <= 0xFAFF # Compatibility Ideographs
or 0x3040 <= cp <= 0x309F # Hiragana
or 0x30A0 <= cp <= 0x30FF # Katakana
or 0xAC00 <= cp <= 0xD7AF # Hangul Syllables
)

@staticmethod
def _is_punct(ch):
"""Check if a character is punctuation (no space before it)."""
import unicodedata
cat = unicodedata.category(ch)
return cat.startswith('P')

@staticmethod
def _smart_join(tokens):
"""Join tokens with spaces for non-CJK text, without spaces for CJK.

Rules:
- Between two CJK chars: no space
- Between two non-CJK tokens: space
- Before punctuation: no space
- CJK adjacent to non-CJK: no space (smooth mixed-text transition)
"""
if not tokens:
return ""
result = [tokens[0]]
for token in tokens[1:]:
if not token:
continue
prev_ch = result[-1][-1] if result[-1] else ''
curr_ch = token[0]
# Punctuation never gets a space before it
if BackendServicer._is_punct(curr_ch):
result.append(token)
# CJK to CJK: no space
elif prev_ch and BackendServicer._is_cjk(prev_ch) and BackendServicer._is_cjk(curr_ch):
result.append(token)
# CJK adjacent to non-CJK or vice versa: no space
elif prev_ch and (BackendServicer._is_cjk(prev_ch) or BackendServicer._is_cjk(curr_ch)):
result.append(token)
# Both non-CJK (Latin, Cyrillic, etc.): add space
else:
result.append(' ' + token)
return "".join(result)

@staticmethod
def _extract_word_info(ts):
"""Return (start_sec, end_sec, text) from a ForcedAlignItem or tuple."""
if hasattr(ts, 'start_time') and hasattr(ts, 'end_time') and hasattr(ts, 'text'):
return (
float(ts.start_time) if ts.start_time is not None else 0.0,
float(ts.end_time) if ts.end_time is not None else 0.0,
str(ts.text) if ts.text else "",
)
elif isinstance(ts, (list, tuple)) and len(ts) >= 3:
return (
float(ts[0]) if ts[0] is not None else 0.0,
float(ts[1]) if ts[1] is not None else 0.0,
ts[2] if len(ts) > 2 and ts[2] is not None else "",
)
return (0.0, 0.0, "")

@staticmethod
def _compute_gap_threshold(time_stamps):
"""Compute a gap threshold for sentence boundary detection.

Uses the median inter-item gap multiplied by a factor, with a
minimum floor of 0.3s. Returns 0 if there are too few items.
"""
if len(time_stamps) < 2:
return 0.0
gaps = []
for i in range(1, len(time_stamps)):
prev_s, prev_e, _ = BackendServicer._extract_word_info(time_stamps[i - 1])
curr_s, _, _ = BackendServicer._extract_word_info(time_stamps[i])
gaps.append(curr_s - prev_e)
if not gaps:
return 0.0
gaps.sort()
median = gaps[len(gaps) // 2]
# threshold = max(median * 4, 0.3s)
return max(median * 4, 0.3)

def _build_segments(self, time_stamps, granularity):
"""Build TranscriptSegment list from forced-aligner output.

granularity:
- "word": one segment per aligned item (character / word)
- "segment" (default): merge consecutive items, splitting at
time gaps that exceed a dynamic threshold (sentence boundaries).
"""
if granularity == "word":
result = []
for idx, ts in enumerate(time_stamps):
s, e, t = self._extract_word_info(ts)
result.append(backend_pb2.TranscriptSegment(
id=idx,
start=int(s * 1_000_000_000),
end=int(e * 1_000_000_000),
text=t,
))
return result

# segment mode — merge at time-gap boundaries
threshold = self._compute_gap_threshold(time_stamps)
result = []
buf_text = []
buf_start = None
buf_end = 0.0
prev_end = None

for ts in time_stamps:
s, e, t = self._extract_word_info(ts)

# Detect sentence boundary via time gap
if prev_end is not None and (s - prev_end) >= threshold and buf_text:
result.append(backend_pb2.TranscriptSegment(
id=len(result),
start=int(buf_start * 1_000_000_000),
end=int(buf_end * 1_000_000_000),
text=self._smart_join(buf_text),
))
buf_text = []
buf_start = None

if buf_start is None:
buf_start = s
buf_text.append(t)
buf_end = e
prev_end = e

# flush remaining
if buf_text and buf_start is not None:
result.append(backend_pb2.TranscriptSegment(
id=len(result),
start=int(buf_start * 1_000_000_000),
end=int(buf_end * 1_000_000_000),
text=self._smart_join(buf_text),
))

return result

def AudioTranscription(self, request, context):
result_segments = []
text = ""
Expand All @@ -147,11 +297,22 @@ def AudioTranscription(self, request, context):
if request.language and request.language.strip():
language = request.language.strip()

context = ""
ctx = ""
if request.prompt and request.prompt.strip():
context = request.prompt.strip()
ctx = request.prompt.strip()

# Determine requested granularity (default: segment)
granularities = list(request.timestamp_granularities) if request.timestamp_granularities else []
granularity = "word" if "word" in granularities else "segment"

results = self.model.transcribe(audio=audio_path, language=language, context=context)
has_aligner = getattr(self.model, 'forced_aligner', None) is not None
try:
results = self.model.transcribe(
audio=audio_path, language=language, context=ctx,
return_time_stamps=has_aligner,
)
except TypeError:
results = self.model.transcribe(audio=audio_path, language=language, context=ctx)

if not results:
return backend_pb2.TranscriptResult(segments=[], text="")
Expand All @@ -160,17 +321,7 @@ def AudioTranscription(self, request, context):
text = r.text or ""

if getattr(r, 'time_stamps', None) and len(r.time_stamps) > 0:
for idx, ts in enumerate(r.time_stamps):
start_ms = 0
end_ms = 0
seg_text = text
if isinstance(ts, (list, tuple)) and len(ts) >= 3:
start_ms = int(float(ts[0]) * 1000) if ts[0] is not None else 0
end_ms = int(float(ts[1]) * 1000) if ts[1] is not None else 0
seg_text = ts[2] if len(ts) > 2 and ts[2] is not None else ""
result_segments.append(backend_pb2.TranscriptSegment(
id=idx, start=start_ms, end=end_ms, text=seg_text
))
result_segments = self._build_segments(r.time_stamps, granularity)
else:
if text:
result_segments.append(backend_pb2.TranscriptSegment(
Expand Down
Loading