diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index 3820782db..f7bcae26e 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -818,7 +818,18 @@ def inference_with_vad(self, input, input_len=None, **cfg): spk_embedding.cpu(), oracle_num=kwargs.get("preset_spk_num", None) ) # del result['spk_embedding'] - sv_output = postprocess(all_segments, None, labels, spk_embedding.cpu()) + # postprocess expects np.ndarray embeddings (per its type hint). + spk_embedding_np = spk_embedding.detach().cpu().numpy() + if kwargs.get("return_spk_center", False): + sv_output, spk_center = postprocess( + all_segments, None, labels, spk_embedding_np, return_spk_center=True + ) + # Per-speaker ERes2NetV2 centroids, indexed by the `spk` id in + # sentence_info. Kept on the result for downstream voiceprint use + # (the per-chunk spk_embedding below is still deleted to keep output small). + result["spk_embedding_center"] = spk_center + else: + sv_output = postprocess(all_segments, None, labels, spk_embedding_np) if self.spk_mode == "punc_segment" and "timestamp" not in result and "timestamps" not in result: logging.warning("No timestamps in ASR result (e.g. SenseVoice), falling back to vad_segment mode for speaker diarization.") self.spk_mode = "vad_segment" diff --git a/funasr/models/campplus/utils.py b/funasr/models/campplus/utils.py index e9f5eb4a1..d1f6972ac 100644 --- a/funasr/models/campplus/utils.py +++ b/funasr/models/campplus/utils.py @@ -138,8 +138,12 @@ def extract_feature(audio): def postprocess( - segments: list, vad_segments: list, labels: np.ndarray, embeddings: np.ndarray -) -> list: + segments: list, + vad_segments: list, + labels: np.ndarray, + embeddings: np.ndarray, + return_spk_center: bool = False, +) -> Union[list, tuple]: """Postprocess. Args: @@ -156,13 +160,6 @@ def postprocess( # merge the same speakers chronologically distribute_res = merge_seque(distribute_res) - # accquire speaker center - spk_embs = [] - for i in range(labels.max() + 1): - spk_emb = embeddings[labels == i].mean(0) - spk_embs.append(spk_emb) - spk_embs = np.stack(spk_embs) - def is_overlapped(t1, t2): """Is overlapped. @@ -184,6 +181,14 @@ def is_overlapped(t1, t2): # smooth the result distribute_res = smooth(distribute_res) + if return_spk_center: + # spk_embs[i] is the centroid (mean of clustered chunk embeddings) for + # corrected speaker label i, aligned with the `spk` ids in sentence_info. + # Computed lazily: only when the caller requests speaker centers. + spk_embs = np.stack( + [embeddings[labels == i].mean(0) for i in range(labels.max() + 1)] + ) + return distribute_res, spk_embs return distribute_res