Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion funasr/auto/auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,18 @@ def inference_with_vad(self, input, input_len=None, **cfg):
spk_embedding.cpu(), oracle_num=kwargs.get("preset_spk_num", None)
)
# del result['spk_embedding']
sv_output = postprocess(all_segments, None, labels, spk_embedding.cpu())
# postprocess expects np.ndarray embeddings (per its type hint).
spk_embedding_np = spk_embedding.detach().cpu().numpy()
if kwargs.get("return_spk_center", False):
sv_output, spk_center = postprocess(
all_segments, None, labels, spk_embedding_np, return_spk_center=True
)
# Per-speaker ERes2NetV2 centroids, indexed by the `spk` id in
# sentence_info. Kept on the result for downstream voiceprint use
# (the per-chunk spk_embedding below is still deleted to keep output small).
result["spk_embedding_center"] = spk_center
else:
sv_output = postprocess(all_segments, None, labels, spk_embedding_np)
if self.spk_mode == "punc_segment" and "timestamp" not in result and "timestamps" not in result:
logging.warning("No timestamps in ASR result (e.g. SenseVoice), falling back to vad_segment mode for speaker diarization.")
self.spk_mode = "vad_segment"
Expand Down
23 changes: 14 additions & 9 deletions funasr/models/campplus/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,12 @@ def extract_feature(audio):


def postprocess(
segments: list, vad_segments: list, labels: np.ndarray, embeddings: np.ndarray
) -> list:
segments: list,
vad_segments: list,
labels: np.ndarray,
embeddings: np.ndarray,
return_spk_center: bool = False,
) -> Union[list, tuple]:
"""Postprocess.

Args:
Expand All @@ -156,13 +160,6 @@ def postprocess(
# merge the same speakers chronologically
distribute_res = merge_seque(distribute_res)

# acquire speaker center
spk_embs = []
for i in range(labels.max() + 1):
spk_emb = embeddings[labels == i].mean(0)
spk_embs.append(spk_emb)
spk_embs = np.stack(spk_embs)

def is_overlapped(t1, t2):
"""Is overlapped.

Expand All @@ -184,6 +181,14 @@ def is_overlapped(t1, t2):
# smooth the result
distribute_res = smooth(distribute_res)

if return_spk_center:
# spk_embs[i] is the centroid (mean of clustered chunk embeddings) for
# corrected speaker label i, aligned with the `spk` ids in sentence_info.
# Computed lazily: only when the caller requests speaker centers.
spk_embs = np.stack(
[embeddings[labels == i].mean(0) for i in range(labels.max() + 1)]
)
return distribute_res, spk_embs
return distribute_res
Comment thread
phoenixray2000 marked this conversation as resolved.


Expand Down