Skip to content

Commit

Permalink
exp with backend
Browse files Browse the repository at this point in the history
  • Loading branch information
nauman-daw committed Mar 29, 2021
1 parent 7dd9fa7 commit ea9f39d
Showing 1 changed file with 76 additions and 0 deletions.
76 changes: 76 additions & 0 deletions speechbrain/processing/diarization.py
Expand Up @@ -1120,3 +1120,79 @@ def do_kmeans_clustering(

# logger.info("Completed diarizing " + rec_id)
write_rttm(lol, out_rttm_file)


def do_AHC(diary_obj, out_rttm_file, rec_id, k_oracle=4, p_val=0.3):
"""
Performs spectral clustering on embeddings. This function calls specific
clustering algorithms as per affinity.
Arguments
---------
diary_obj : StatObject_SB type
Contains embeddings in diary_obj.stat1 and segment IDs in diary_obj.segset.
out_rttm_file : str
Path of the output RTTM file.
rec_id : str
Recording ID for the recording under processing.
k : int
Number of speaker (None, if it has to be estimated).
pval : float
`pval` for prunning affinity matrix. Used only when number of speakers
are unknown. Note that this is just for experiment. Prefer Spectral clustering
for better clustering results.
"""

from sklearn.cluster import AgglomerativeClustering

if k_oracle is not None:
print("ORACLE SPKRs")
num_of_spk = k_oracle
clustering = AgglomerativeClustering(
n_clusters=num_of_spk, affinity="cosine", linkage="ward"
).fit(diary_obj.stat1)
labels = clustering.labels_
print("labels.shape (Ora.) = ", labels.shape)
else:
print("Using AHC threshold pval = ", p_val)
# Estimate num of using max eigen gap with `cos` affinity matrix.
# This is just for experimentation.
clustering = AgglomerativeClustering(
n_clusters=None,
affinity="cosine",
linkage="ward",
distance_threshold=p_val,
).fit(diary_obj.stat1)
labels = clustering.labels_
print("labels.shape (Est.) = ", labels.shape)

# Convert labels to speaker boundaries
subseg_ids = diary_obj.segset
lol = []

for i in range(labels.shape[0]):
spkr_id = rec_id + "_" + str(labels[i])

sub_seg = subseg_ids[i]

splitted = sub_seg.rsplit("_", 2)
rec_id = str(splitted[0])
sseg_start = float(splitted[1])
sseg_end = float(splitted[2])

a = [rec_id, sseg_start, sseg_end, spkr_id]
lol.append(a)

# Sorting based on start time of sub-segment
lol.sort(key=lambda x: float(x[1]))

# Merge and split in 2 simple steps: (i) Merge sseg of same speakers then (ii) split different speakers
# Step 1: Merge adjacent sub-segments that belong to same speaker (or cluster)
lol = merge_ssegs_same_speaker(lol)

# Step 2: Distribute duration of adjacent overlapping sub-segments belonging to different speakers (or cluster)
# Taking mid-point as the splitting time location.
lol = distribute_overlap(lol)

# logger.info("Completed diarizing " + rec_id)
write_rttm(lol, out_rttm_file)

0 comments on commit ea9f39d

Please sign in to comment.