diff --git a/backend/python/qwen-asr/backend.py b/backend/python/qwen-asr/backend.py index 2d1940afc053..196f8f439fb4 100644 --- a/backend/python/qwen-asr/backend.py +++ b/backend/python/qwen-asr/backend.py @@ -134,6 +134,156 @@ def LoadModel(self, request, context): return backend_pb2.Result(message="Model loaded successfully", success=True) + @staticmethod + def _is_cjk(ch): + """Check if a character is CJK (Chinese/Japanese/Korean).""" + cp = ord(ch) + return ( + 0x4E00 <= cp <= 0x9FFF # CJK Unified Ideographs + or 0x3400 <= cp <= 0x4DBF # Extension A + or 0x20000 <= cp <= 0x2A6DF # Extension B + or 0xF900 <= cp <= 0xFAFF # Compatibility Ideographs + or 0x3040 <= cp <= 0x309F # Hiragana + or 0x30A0 <= cp <= 0x30FF # Katakana + or 0xAC00 <= cp <= 0xD7AF # Hangul Syllables + ) + + @staticmethod + def _is_punct(ch): + """Check if a character is punctuation (no space before it).""" + import unicodedata + cat = unicodedata.category(ch) + return cat.startswith('P') + + @staticmethod + def _smart_join(tokens): + """Join tokens with spaces for non-CJK text, without spaces for CJK. + + Rules: + - Between two CJK chars: no space + - Between two non-CJK tokens: space + - Before punctuation: no space + - CJK adjacent to non-CJK: no space (smooth mixed-text transition) + """ + if not tokens: + return "" + result = [tokens[0]] + for token in tokens[1:]: + if not token: + continue + prev_ch = result[-1][-1] if result[-1] else '' + curr_ch = token[0] + # Punctuation never gets a space before it + if BackendServicer._is_punct(curr_ch): + result.append(token) + # CJK to CJK: no space + elif prev_ch and BackendServicer._is_cjk(prev_ch) and BackendServicer._is_cjk(curr_ch): + result.append(token) + # CJK adjacent to non-CJK or vice versa: no space + elif prev_ch and (BackendServicer._is_cjk(prev_ch) or BackendServicer._is_cjk(curr_ch)): + result.append(token) + # Both non-CJK (Latin, Cyrillic, etc.): add space + else: + result.append(' ' + token) + return "".join(result) + + @staticmethod + def _extract_word_info(ts): + """Return (start_sec, end_sec, text) from a ForcedAlignItem or tuple.""" + if hasattr(ts, 'start_time') and hasattr(ts, 'end_time') and hasattr(ts, 'text'): + return ( + float(ts.start_time) if ts.start_time is not None else 0.0, + float(ts.end_time) if ts.end_time is not None else 0.0, + str(ts.text) if ts.text else "", + ) + elif isinstance(ts, (list, tuple)) and len(ts) >= 3: + return ( + float(ts[0]) if ts[0] is not None else 0.0, + float(ts[1]) if ts[1] is not None else 0.0, + ts[2] if len(ts) > 2 and ts[2] is not None else "", + ) + return (0.0, 0.0, "") + + @staticmethod + def _compute_gap_threshold(time_stamps): + """Compute a gap threshold for sentence boundary detection. + + Uses the median inter-item gap multiplied by a factor, with a + minimum floor of 0.3s. Returns 0 if there are too few items. + """ + if len(time_stamps) < 2: + return 0.0 + gaps = [] + for i in range(1, len(time_stamps)): + prev_s, prev_e, _ = BackendServicer._extract_word_info(time_stamps[i - 1]) + curr_s, _, _ = BackendServicer._extract_word_info(time_stamps[i]) + gaps.append(curr_s - prev_e) + if not gaps: + return 0.0 + gaps.sort() + median = gaps[len(gaps) // 2] + # threshold = max(median * 4, 0.3s) + return max(median * 4, 0.3) + + def _build_segments(self, time_stamps, granularity): + """Build TranscriptSegment list from forced-aligner output. + + granularity: + - "word": one segment per aligned item (character / word) + - "segment" (default): merge consecutive items, splitting at + time gaps that exceed a dynamic threshold (sentence boundaries). + """ + if granularity == "word": + result = [] + for idx, ts in enumerate(time_stamps): + s, e, t = self._extract_word_info(ts) + result.append(backend_pb2.TranscriptSegment( + id=idx, + start=int(s * 1_000_000_000), + end=int(e * 1_000_000_000), + text=t, + )) + return result + + # segment mode — merge at time-gap boundaries + threshold = self._compute_gap_threshold(time_stamps) + result = [] + buf_text = [] + buf_start = None + buf_end = 0.0 + prev_end = None + + for ts in time_stamps: + s, e, t = self._extract_word_info(ts) + + # Detect sentence boundary via time gap + if prev_end is not None and (s - prev_end) >= threshold and buf_text: + result.append(backend_pb2.TranscriptSegment( + id=len(result), + start=int(buf_start * 1_000_000_000), + end=int(buf_end * 1_000_000_000), + text=self._smart_join(buf_text), + )) + buf_text = [] + buf_start = None + + if buf_start is None: + buf_start = s + buf_text.append(t) + buf_end = e + prev_end = e + + # flush remaining + if buf_text and buf_start is not None: + result.append(backend_pb2.TranscriptSegment( + id=len(result), + start=int(buf_start * 1_000_000_000), + end=int(buf_end * 1_000_000_000), + text=self._smart_join(buf_text), + )) + + return result + def AudioTranscription(self, request, context): result_segments = [] text = "" @@ -147,11 +297,22 @@ def AudioTranscription(self, request, context): if request.language and request.language.strip(): language = request.language.strip() - context = "" + ctx = "" if request.prompt and request.prompt.strip(): - context = request.prompt.strip() + ctx = request.prompt.strip() + + # Determine requested granularity (default: segment) + granularities = list(request.timestamp_granularities) if request.timestamp_granularities else [] + granularity = "word" if "word" in granularities else "segment" - results = self.model.transcribe(audio=audio_path, language=language, context=context) + has_aligner = getattr(self.model, 'forced_aligner', None) is not None + try: + results = self.model.transcribe( + audio=audio_path, language=language, context=ctx, + return_time_stamps=has_aligner, + ) + except TypeError: + results = self.model.transcribe(audio=audio_path, language=language, context=ctx) if not results: return backend_pb2.TranscriptResult(segments=[], text="") @@ -160,17 +321,7 @@ def AudioTranscription(self, request, context): text = r.text or "" if getattr(r, 'time_stamps', None) and len(r.time_stamps) > 0: - for idx, ts in enumerate(r.time_stamps): - start_ms = 0 - end_ms = 0 - seg_text = text - if isinstance(ts, (list, tuple)) and len(ts) >= 3: - start_ms = int(float(ts[0]) * 1000) if ts[0] is not None else 0 - end_ms = int(float(ts[1]) * 1000) if ts[1] is not None else 0 - seg_text = ts[2] if len(ts) > 2 and ts[2] is not None else "" - result_segments.append(backend_pb2.TranscriptSegment( - id=idx, start=start_ms, end=end_ms, text=seg_text - )) + result_segments = self._build_segments(r.time_stamps, granularity) else: if text: result_segments.append(backend_pb2.TranscriptSegment(