fix(cpu): enable robust CPU support and address PR feedback #2958

Open
Manamama-Gemini-Cloud-AI-01 wants to merge 10 commits into modelscope:main from Manamama-Gemini-Cloud-AI-01:cpu-support-patch-v3

Conversation

@Manamama-Gemini-Cloud-AI-01

Contributor

This PR enables robust CPU support for the FunASR Nano real-time server.

Key changes:

  • Dynamically bypasses vLLM when the device is set to a non-CUDA device (e.g., 'cpu').
  • Implements robust inference result extraction to handle different return types between AutoModel (PyTorch) and AutoModelVLLM.
  • Improves tokenizer access safety and input buffer handling (AI review feedback).
  • Optimizes hallucination detection.
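The first two bullets can be sketched as follows. This is a minimal illustration under stated assumptions, not the PR's code: `should_use_vllm` and `extract_text` are hypothetical helper names, and the result shapes handled here (a plain string, a dict with a `"text"` key, an object with a `.text` attribute, or a list of those) are guesses at what the two backends might return.

```python
from types import SimpleNamespace


def should_use_vllm(device: str) -> bool:
    """Route to vLLM only for CUDA devices such as 'cuda' or 'cuda:0'."""
    return device.strip().lower().startswith("cuda")


def extract_text(results) -> str:
    """Best-effort text extraction from differently shaped results.

    The shapes handled here are assumptions, not a documented contract.
    """
    if results is None:
        return ""
    # Both backends are assumed to return a sequence with one entry per input.
    if isinstance(results, (list, tuple)):
        if not results:
            return ""
        results = results[0]
    if isinstance(results, str):
        return results
    if isinstance(results, dict):
        return str(results.get("text") or "")
    return str(getattr(results, "text", "") or "")


print(should_use_vllm("cpu"))                    # False
print(extract_text([{"text": "hello"}]))         # hello
print(extract_text(SimpleNamespace(text="hi")))  # hi
```

Keeping the shape handling in one function means the decoding paths do not need to know which backend produced the result.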

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces a streaming WebSocket server for Fun-ASR-Nano, featuring VAD segmentation, ASR decoding, speaker diarization, and hallucination detection. Key feedback focuses on optimizing the audio buffer concatenation to prevent O(N^2) copy overhead, correcting the parameter name hotwords to hotword for standard AutoModel compatibility, fixing the hallucination truncation logic to keep exactly one occurrence of repeated patterns, and improving robustness through safer attribute access and exception handling around tokenizer operations.

Comment on lines +249 to +271
    def add_audio(self, pcm_bytes):
        if len(pcm_bytes) % 2 != 0:
            pcm_bytes = pcm_bytes[:len(pcm_bytes) - (len(pcm_bytes) % 2)]
        audio_int16 = np.frombuffer(pcm_bytes, dtype=np.int16)
        audio_float = audio_int16.astype(np.float32) / 32768.0
        self.audio_buffer = np.concatenate([self.audio_buffer, audio_float])

        new_audio = self.audio_buffer[self.vad_fed_samples:]
        if len(new_audio) > 0:
            new_confirmed = self.vad.feed(torch.from_numpy(new_audio).float(), is_final=False)
            self.vad_fed_samples = len(self.audio_buffer)

            for seg in new_confirmed:
                seg_text = self._decode_segment(seg)
                self.prev_text = ""
                if not seg_text.strip():
                    continue
                self.locked_sentences.append({"text": seg_text, "start": int(seg[0]), "end": int(seg[1])})
                if self.spk_tracker:
                    s0 = int(seg[0] * self.sample_rate / 1000)
                    s1 = min(int(seg[1] * self.sample_rate / 1000), len(self.audio_buffer))
                    self.spk_tracker.assign_streaming(self.audio_buffer[s0:s1], seg[0]/1000, seg[1]/1000, self.locked_sentences[-1])
                logger.info(f"Locked: [{seg[0]}-{seg[1]}ms] \"{seg_text[:40]}\"")

Contributor

critical

Appending and concatenating the audio buffer on every single incoming WebSocket packet (which can be as frequent as every 10-20ms) results in an O(N^2) copy overhead. Over a long streaming session, this will cause severe CPU spikes, massive latency, and potential server freezes.

We can optimize this by keeping a list of chunks (self.audio_chunks) and only concatenating them on-demand when the contiguous audio_buffer is actually accessed (e.g., during decoding or speaker diarization). This is implemented cleanly using a property and setter.

    @property
    def audio_buffer(self):
        if getattr(self, '_audio_buffer_dirty', False):
            if self.audio_chunks:
                self._audio_buffer = np.concatenate([self._audio_buffer, *self.audio_chunks])
                self.audio_chunks = []
            self._audio_buffer_dirty = False
        return self._audio_buffer

    @audio_buffer.setter
    def audio_buffer(self, value):
        self._audio_buffer = value
        self.audio_chunks = []
        self._audio_buffer_dirty = False

    def add_audio(self, pcm_bytes):
        if len(pcm_bytes) % 2 != 0:
            pcm_bytes = pcm_bytes[:len(pcm_bytes) - (len(pcm_bytes) % 2)]
        if not pcm_bytes:
            return
        audio_int16 = np.frombuffer(pcm_bytes, dtype=np.int16)
        audio_float = audio_int16.astype(np.float32) / 32768.0
        
        self.audio_chunks.append(audio_float)
        self._audio_buffer_dirty = True

        new_confirmed = self.vad.feed(torch.from_numpy(audio_float).float(), is_final=False)
        self.vad_fed_samples += len(audio_float)

        for seg in new_confirmed:
            seg_text = self._decode_segment(seg)
            self.prev_text = ""
            if not seg_text.strip():
                continue
            self.locked_sentences.append({"text": seg_text, "start": int(seg[0]), "end": int(seg[1])})
            if self.spk_tracker:
                s0 = int(seg[0] * self.sample_rate / 1000)
                s1 = min(int(seg[1] * self.sample_rate / 1000), len(self.audio_buffer))
                self.spk_tracker.assign_streaming(self.audio_buffer[s0:s1], seg[0]/1000, seg[1]/1000, self.locked_sentences[-1])
            logger.info(f"Locked: [{seg[0]}-{seg[1]}ms] \"{seg_text[:40]}\"")
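To put rough numbers on the O(N^2) claim, the sketch below counts how many samples get copied under each strategy. The workload (16 kHz mono audio, 20 ms packets, a 10-minute session) is an assumption chosen for illustration, and both function names are hypothetical. The deferred figure is a best case with a single join: with the property above, each buffer access that follows new chunks still copies the whole buffer once.

```python
def copied_eager(num_packets: int, samples_per_packet: int) -> int:
    """Samples copied when the whole buffer is re-concatenated per packet."""
    # Packet k builds a new array holding k * samples_per_packet samples,
    # so the total is samples_per_packet * (1 + 2 + ... + num_packets).
    return samples_per_packet * num_packets * (num_packets + 1) // 2


def copied_deferred(num_packets: int, samples_per_packet: int) -> int:
    """Samples copied when chunks are collected and joined once."""
    return samples_per_packet * num_packets


# Assumed workload: 50 packets per second for 10 minutes, 320 samples each.
packets = 10 * 60 * 50
print(copied_eager(packets, 320))     # 144004800000
print(copied_deferred(packets, 320))  # 9600000
```

Under these assumptions the eager strategy moves roughly 15,000 times more data over the session, which is where the latency growth comes from.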

Comment on lines +320 to +324
    results = self.vllm_engine.generate(
        input=audio_tensor,
        hotwords=self.asr_kwargs.get("hotwords"),
        language=self.asr_kwargs.get("language"),
    )

Contributor

high

The standard AutoModel (used on non-CUDA devices such as CPU) expects the parameter name hotword (singular), not hotwords (plural), as documented in auto_model.py. A hotwords argument is ignored, so hotword customization fails on CPU.

Suggested change
      results = self.vllm_engine.generate(
          input=audio_tensor,
    -     hotwords=self.asr_kwargs.get("hotwords"),
    +     hotword=self.asr_kwargs.get("hotwords"),
          language=self.asr_kwargs.get("language"),
      )
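If both backends stay reachable from the same call site, one option is to pick the keyword name in a single place. The sketch below assumes the vLLM wrapper accepts `hotwords` (as the PR code passes it) and the standard AutoModel accepts `hotword` (as this comment states); `build_generate_kwargs` is a hypothetical helper, not part of FunASR.

```python
def build_generate_kwargs(asr_kwargs: dict, use_vllm: bool) -> dict:
    """Map the server's 'hotwords' setting onto the backend's keyword name."""
    # Assumption: vLLM wrapper -> 'hotwords', standard AutoModel -> 'hotword'.
    hotword_key = "hotwords" if use_vllm else "hotword"
    kwargs = {"language": asr_kwargs.get("language")}
    hotwords = asr_kwargs.get("hotwords")
    if hotwords is not None:
        kwargs[hotword_key] = hotwords
    return kwargs


settings = {"hotwords": "FunASR", "language": "en"}
print(build_generate_kwargs(settings, use_vllm=False))
# {'language': 'en', 'hotword': 'FunASR'}
```

The call site would then read `engine.generate(input=audio_tensor, **kwargs)` regardless of device.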

Comment on lines +385 to +389
    results = self.vllm_engine.generate(
        input=audio_tensor,
        hotwords=self.asr_kwargs.get("hotwords"),
        language=self.asr_kwargs.get("language"),
    )

Contributor

high

Similarly, standard AutoModel expects hotword instead of hotwords here.

Suggested change
      results = self.vllm_engine.generate(
          input=audio_tensor,
    -     hotwords=self.asr_kwargs.get("hotwords"),
    +     hotword=self.asr_kwargs.get("hotwords"),
          language=self.asr_kwargs.get("language"),
      )

Comment on lines +41 to +44
    if pos >= 0:
        end_pos = text.find(repeated, pos + len(repeated))
        if end_pos >= 0:
            return text[:end_pos + len(repeated)], True

Contributor

medium

The current implementation of hallucination truncation keeps two occurrences of the repeated pattern instead of one, because it slices up to end_pos + len(repeated). Slicing up to end_pos (the start of the second occurrence) will correctly keep exactly one occurrence of the pattern as intended by the docstring.

Suggested change
      if pos >= 0:
          end_pos = text.find(repeated, pos + len(repeated))
          if end_pos >= 0:
    -         return text[:end_pos + len(repeated)], True
    +         return text[:end_pos], True
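A small worked example makes the off-by-one-occurrence visible. The wrapper function and its `keep_second` flag are hypothetical and exist only to run both slicing variants side by side; the inner logic mirrors the quoted lines.

```python
def truncate_at_repeat(text: str, repeated: str, keep_second: bool):
    """Cut `text` at the second occurrence of `repeated`.

    keep_second=True reproduces the original slice (two occurrences kept);
    keep_second=False reproduces the suggested slice (one occurrence kept).
    """
    pos = text.find(repeated)
    if pos >= 0:
        end_pos = text.find(repeated, pos + len(repeated))
        if end_pos >= 0:
            cut = end_pos + len(repeated) if keep_second else end_pos
            return text[:cut], True
    return text, False


text = "ok go go go go"
print(truncate_at_repeat(text, "go ", keep_second=True))   # ('ok go go ', True)
print(truncate_at_repeat(text, "go ", keep_second=False))  # ('ok go ', True)
```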

Comment on lines +53 to +57
    pos = text.find(repeated)
    if pos >= 0:
        end_pos = text.find(repeated, pos + len(repeated))
        if end_pos >= 0:
            return text[:end_pos + len(repeated)], True

Contributor

medium

Similarly, this truncation keeps two occurrences of the repeated pattern instead of one. Slicing up to end_pos will correctly keep exactly one occurrence.

Suggested change
      pos = text.find(repeated)
      if pos >= 0:
          end_pos = text.find(repeated, pos + len(repeated))
          if end_pos >= 0:
    -         return text[:end_pos + len(repeated)], True
    +         return text[:end_pos], True

Comment on lines +354 to +358
    if hasattr(self.vllm_engine, '_engine'):
        tokenizer = self.vllm_engine._engine.tokenizer
    else:
        tokenizer = self.vllm_engine.kwargs.get("tokenizer")

Contributor

medium

Using getattr with a default value is safer to prevent potential AttributeError if self.vllm_engine does not have the kwargs attribute.

Suggested change
      if hasattr(self.vllm_engine, '_engine'):
          tokenizer = self.vllm_engine._engine.tokenizer
      else:
    -     tokenizer = self.vllm_engine.kwargs.get("tokenizer")
    +     tokenizer = getattr(self.vllm_engine, 'kwargs', {}).get("tokenizer")
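One edge case the `getattr` default does not cover is an engine whose `kwargs` attribute exists but is `None`. A slightly more defensive lookup might look like the sketch below; `find_tokenizer` is a hypothetical helper, and the stand-in engine objects are built with `SimpleNamespace` purely for illustration.

```python
from types import SimpleNamespace


def find_tokenizer(engine):
    """Return the engine's tokenizer, or None when it cannot be located."""
    inner = getattr(engine, "_engine", None)
    if inner is not None:
        return getattr(inner, "tokenizer", None)
    # `or {}` also handles a `kwargs` attribute that is present but None.
    return (getattr(engine, "kwargs", None) or {}).get("tokenizer")


print(find_tokenizer(SimpleNamespace()))                           # None
print(find_tokenizer(SimpleNamespace(kwargs=None)))                # None
print(find_tokenizer(SimpleNamespace(kwargs={"tokenizer": "t"})))  # t
```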

Comment on lines +359 to +369
    if tokenizer is not None:
        encoded = tokenizer.encode(text)
        if len(encoded) > 5:
            try:
                self.prev_text = tokenizer.decode(encoded[:-5], skip_special_tokens=True)
            except TypeError:
                self.prev_text = tokenizer.decode(encoded[:-5])
        else:
            self.prev_text = ""
    else:
        self.prev_text = ""

Contributor

medium

Wrapping the tokenizer's encode/decode operations in a try...except block prevents the entire WebSocket session from crashing if the tokenizer fails or has an unexpected interface.

Suggested change
    - if tokenizer is not None:
    -     encoded = tokenizer.encode(text)
    -     if len(encoded) > 5:
    -         try:
    -             self.prev_text = tokenizer.decode(encoded[:-5], skip_special_tokens=True)
    -         except TypeError:
    -             self.prev_text = tokenizer.decode(encoded[:-5])
    -     else:
    -         self.prev_text = ""
    - else:
    -     self.prev_text = ""
    + if tokenizer is not None:
    +     try:
    +         encoded = tokenizer.encode(text)
    +         if len(encoded) > 5:
    +             try:
    +                 self.prev_text = tokenizer.decode(encoded[:-5], skip_special_tokens=True)
    +             except TypeError:
    +                 self.prev_text = tokenizer.decode(encoded[:-5])
    +         else:
    +             self.prev_text = ""
    +     except Exception as e:
    +         logger.warning(f"Failed to encode/decode text with tokenizer: {e}")
    +         self.prev_text = ""
    + else:
    +     self.prev_text = ""
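The same guard can be exercised in isolation as a standalone function. The sketch below is an illustration only: `text_without_last_tokens`, `CharTokenizer`, and `BrokenTokenizer` are hypothetical names, and the toy tokenizers stand in for whatever tokenizer the engine actually exposes.

```python
import logging

logger = logging.getLogger(__name__)


def text_without_last_tokens(tokenizer, text: str, n: int = 5) -> str:
    """Decode `text` minus its last `n` tokens; return "" on any failure."""
    if tokenizer is None:
        return ""
    try:
        encoded = tokenizer.encode(text)
        if len(encoded) <= n:
            return ""
        try:
            return tokenizer.decode(encoded[:-n], skip_special_tokens=True)
        except TypeError:
            # Fallback for tokenizers whose decode() lacks that keyword.
            return tokenizer.decode(encoded[:-n])
    except Exception as exc:
        logger.warning("Tokenizer encode/decode failed: %s", exc)
        return ""


class CharTokenizer:
    """Toy tokenizer (one token per character), for illustration only."""

    def encode(self, text):
        return [ord(ch) for ch in text]

    def decode(self, ids):
        return "".join(chr(i) for i in ids)


class BrokenTokenizer:
    """Toy tokenizer that always fails, to exercise the outer guard."""

    def encode(self, text):
        raise RuntimeError("boom")


print(repr(text_without_last_tokens(CharTokenizer(), "hello world")))    # 'hello '
print(repr(text_without_last_tokens(BrokenTokenizer(), "hello world")))  # ''
```

Returning the value instead of assigning `self.prev_text` inside each branch keeps the fallback to an empty string in one place.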

@LauraGPT LauraGPT left a comment

Collaborator

Thanks for the contribution! A few concerns:

  1. The example files (client_python.py, demo_vllm.py, serve_realtime_ws.py) appear to be full replacements (+1015 additions, 0 deletions). This means the diff shows the entire file as new, making it very hard to review what actually changed. Could you rebase on the latest main so the diff shows only your actual modifications?

  2. The encoder_conf None check in fsmn_vad_streaming/model.py is a good fix — this part looks clean.

  3. Scope: The PR title says "CPU support" but it also includes hallucination detection optimization, tokenizer changes, and buffer handling. Could you split these into separate PRs? It makes reviewing much easier.

Please rebase on latest main so we can see the incremental changes. The CPU fallback (bypassing vLLM on non-CUDA) is a useful feature worth merging, but we need to review it as a clean diff.

LauraGPT added a commit that referenced this pull request Jun 9, 2026
Prevents TypeError when encoder_conf is not provided.
Cherry-picked from #2958.

2 participants