From a6fa3a8ad017195c6540bad4c2f10d46e54a7c89 Mon Sep 17 00:00:00 2001 From: phoenixray2000 Date: Sun, 7 Jun 2026 02:33:59 +0800 Subject: [PATCH 1/2] feat(spk): optionally return per-speaker embedding centroids Add a return_spk_center option so AutoModel.generate surfaces the per-speaker centroid embeddings (mean of clustered chunk embeddings) that diarization already computes in postprocess() but currently discards. Lets downstream speaker voiceprint / identity reuse them without re-embedding. Backward compatible: default off; postprocess return shape is unchanged unless return_spk_center=True. Co-Authored-By: Claude Opus 4.8 --- funasr/auto/auto_model.py | 11 ++++++++++- funasr/models/campplus/utils.py | 10 +++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index 3820782db..9789a04f0 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -818,7 +818,16 @@ def inference_with_vad(self, input, input_len=None, **cfg): spk_embedding.cpu(), oracle_num=kwargs.get("preset_spk_num", None) ) # del result['spk_embedding'] - sv_output = postprocess(all_segments, None, labels, spk_embedding.cpu()) + if kwargs.get("return_spk_center", False): + sv_output, spk_center = postprocess( + all_segments, None, labels, spk_embedding.cpu(), return_spk_center=True + ) + # Per-speaker ERes2NetV2 centroids, indexed by the `spk` id in + # sentence_info. Kept on the result for downstream voiceprint use + # (the per-chunk spk_embedding below is still deleted to keep output small). + result["spk_embedding_center"] = spk_center + else: + sv_output = postprocess(all_segments, None, labels, spk_embedding.cpu()) if self.spk_mode == "punc_segment" and "timestamp" not in result and "timestamps" not in result: logging.warning("No timestamps in ASR result (e.g. SenseVoice), falling back to vad_segment mode for speaker diarization.") self.spk_mode = "vad_segment" diff --git a/funasr/models/campplus/utils.py b/funasr/models/campplus/utils.py index e9f5eb4a1..a3314f745 100644 --- a/funasr/models/campplus/utils.py +++ b/funasr/models/campplus/utils.py @@ -138,7 +138,11 @@ def extract_feature(audio): def postprocess( - segments: list, vad_segments: list, labels: np.ndarray, embeddings: np.ndarray + segments: list, + vad_segments: list, + labels: np.ndarray, + embeddings: np.ndarray, + return_spk_center: bool = False, ) -> list: """Postprocess. @@ -184,6 +188,10 @@ def is_overlapped(t1, t2): # smooth the result distribute_res = smooth(distribute_res) + if return_spk_center: + # spk_embs[i] is the centroid (mean of clustered chunk embeddings) for + # corrected speaker label i, aligned with the `spk` ids in sentence_info. + return distribute_res, spk_embs return distribute_res From efc2347345a3422644cc5ca7fb9dc195765c919d Mon Sep 17 00:00:00 2001 From: phoenixray2000 Date: Sun, 7 Jun 2026 02:49:16 +0800 Subject: [PATCH 2/2] refactor(spk): address review on spk_embedding_center - pass np.ndarray (not torch.Tensor) to postprocess to match its type hint - update postprocess return hint to Union[list, tuple] - compute spk_embs lazily, only when return_spk_center=True Co-Authored-By: Claude Opus 4.8 --- funasr/auto/auto_model.py | 6 ++++-- funasr/models/campplus/utils.py | 13 +++++-------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index 9789a04f0..f7bcae26e 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -818,16 +818,18 @@ def inference_with_vad(self, input, input_len=None, **cfg): spk_embedding.cpu(), oracle_num=kwargs.get("preset_spk_num", None) ) # del result['spk_embedding'] + # postprocess expects np.ndarray embeddings (per its type hint). + spk_embedding_np = spk_embedding.detach().cpu().numpy() if kwargs.get("return_spk_center", False): sv_output, spk_center = postprocess( - all_segments, None, labels, spk_embedding.cpu(), return_spk_center=True + all_segments, None, labels, spk_embedding_np, return_spk_center=True ) # Per-speaker ERes2NetV2 centroids, indexed by the `spk` id in # sentence_info. Kept on the result for downstream voiceprint use # (the per-chunk spk_embedding below is still deleted to keep output small). result["spk_embedding_center"] = spk_center else: - sv_output = postprocess(all_segments, None, labels, spk_embedding.cpu()) + sv_output = postprocess(all_segments, None, labels, spk_embedding_np) if self.spk_mode == "punc_segment" and "timestamp" not in result and "timestamps" not in result: logging.warning("No timestamps in ASR result (e.g. SenseVoice), falling back to vad_segment mode for speaker diarization.") self.spk_mode = "vad_segment" diff --git a/funasr/models/campplus/utils.py b/funasr/models/campplus/utils.py index a3314f745..d1f6972ac 100644 --- a/funasr/models/campplus/utils.py +++ b/funasr/models/campplus/utils.py @@ -143,7 +143,7 @@ def postprocess( labels: np.ndarray, embeddings: np.ndarray, return_spk_center: bool = False, -) -> list: +) -> Union[list, tuple]: """Postprocess. Args: @@ -160,13 +160,6 @@ def postprocess( # merge the same speakers chronologically distribute_res = merge_seque(distribute_res) - # accquire speaker center - spk_embs = [] - for i in range(labels.max() + 1): - spk_emb = embeddings[labels == i].mean(0) - spk_embs.append(spk_emb) - spk_embs = np.stack(spk_embs) - def is_overlapped(t1, t2): """Is overlapped. @@ -191,6 +184,10 @@ def is_overlapped(t1, t2): if return_spk_center: # spk_embs[i] is the centroid (mean of clustered chunk embeddings) for # corrected speaker label i, aligned with the `spk` ids in sentence_info. + # Computed lazily: only when the caller requests speaker centers. + spk_embs = np.stack( + [embeddings[labels == i].mean(0) for i in range(labels.max() + 1)] + ) return distribute_res, spk_embs return distribute_res