-
Notifications
You must be signed in to change notification settings - Fork 4
/
segment.py
337 lines (282 loc) · 9.97 KB
/
segment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
import argparse
from dataclasses import dataclass
from multiprocessing import cpu_count
from pathlib import Path
import numpy as np
import torch
import yaml
from torch.utils.data import DataLoader
from tqdm import tqdm
from constants import HIDDEN_SIZE, TARGET_SAMPLE_RATE
from data import FixedSegmentationDatasetNoTarget, segm_collate_fn
from eval import infer
from models import SegmentationFrameClassifer, prepare_wav2vec
@dataclass
class Segment:
    """A contiguous span of the frame-probability sequence.

    ``start`` and ``end`` are positions in that sequence (``end`` is
    exclusive). The properties express them in seconds by dividing by
    ``TARGET_SAMPLE_RATE`` and rounding to ``decimal`` places.
    """

    start: float  # position of the first frame of the span
    end: float  # position one past the last frame of the span
    probs: np.array  # frame probabilities covered by the span
    decimal: int = 4  # number of decimal places kept by the properties

    def _scaled(self, span: float) -> float:
        """Divide ``span`` by ``TARGET_SAMPLE_RATE`` and round the result."""
        return float(round(span / TARGET_SAMPLE_RATE, self.decimal))

    @property
    def duration(self):
        """Rounded length of the span (``end - start``) after scaling."""
        return self._scaled(self.end - self.start)

    @property
    def offset(self):
        """Rounded position of ``start`` after scaling."""
        return self._scaled(self.start)

    @property
    def offset_plus_duration(self):
        """Rounded sum of ``offset`` and ``duration``."""
        total = self.offset + self.duration
        return round(total, self.decimal)
def trim(sgm: Segment, threshold: float) -> Segment:
    """Shrink a segment to the span between its first and last confident frames.

    A frame counts as confident when its probability is greater than or
    equal to ``threshold``.

    Args:
        sgm (Segment): the segment to shrink
        threshold (float): probability threshold

    Returns:
        Segment: a new segment covering only the confident span, or an empty
            segment located at ``sgm.start`` when no frame reaches the threshold
    """
    confident = np.nonzero(sgm.probs >= threshold)[0]

    # nothing reaches the threshold: collapse to a zero-length segment
    if confident.size == 0:
        return Segment(sgm.start, sgm.start, np.empty([0]))

    first = confident[0]
    stop = confident[-1] + 1  # exclusive upper bound
    return Segment(sgm.start + first, sgm.start + stop, sgm.probs[first:stop])
def split_and_trim(
    sgm: Segment, split_idx: int, threshold: float
) -> tuple[Segment, Segment]:
    """Cut a segment in two at ``split_idx`` and trim both halves.

    The frame at ``split_idx`` itself is dropped: the first half holds the
    frames before it and the second half the frames after it.

    Args:
        sgm (Segment): the segment to cut
        split_idx (int): index (relative to ``sgm.probs``) of the cut point
        threshold (float): probability threshold passed on to ``trim``

    Returns:
        tuple[Segment, Segment]: the trimmed first and second halves
    """
    head_probs = sgm.probs[:split_idx]
    tail_probs = sgm.probs[split_idx + 1 :]

    head_end = sgm.start + len(head_probs)
    head = Segment(sgm.start, head_end, head_probs)
    # the second half starts right after the dropped split frame
    tail = Segment(head_end + 1, sgm.end, tail_probs)

    return trim(head, threshold), trim(tail, threshold)
def pdac(
    probs: np.array,
    max_segment_length: float,
    min_segment_length: float,
    threshold: float,
    not_strict: bool
) -> list[Segment]:
    """Segment an audio with the probabilistic Divide-and-Conquer algorithm.

    The whole audio is first trimmed and then split recursively. Split
    candidates are tried from the least probable frame upwards until a split
    is found whose two trimmed halves are both longer than
    ``min_segment_length``.

    Args:
        probs (np.array): the binary frame-level probabilities
            output by the segmentation-frame-classifier
        max_segment_length (float): the maximum length of a segment
        min_segment_length (float): the minimum length of a segment
        threshold (float): probability threshold
        not_strict (bool): whether segments longer than max are allowed

    Returns:
        list[Segment]: resulting segmentation
    """
    result = []

    def _divide(current):
        # short enough already: keep it as a final segment
        if current.duration < max_segment_length:
            result.append(current)
            return

        # candidate split points, from least to most probable frame
        for split_idx in np.argsort(current.probs):
            # in non-strict mode never split at a frame above the threshold;
            # keep the over-long segment instead
            if not_strict and current.probs[split_idx] > threshold:
                result.append(current)
                break

            left, right = split_and_trim(current, split_idx, threshold)
            both_long_enough = (
                left.duration > min_segment_length
                and right.duration > min_segment_length
            )
            if both_long_enough:
                _divide(left)
                _divide(right)
                break
        else:
            # every candidate was tried without finding an acceptable split
            if not_strict:
                result.append(current)
                return
            # strict mode: fall back to the halves of the last candidate tried
            # and keep only those longer than the minimum
            if left.duration > min_segment_length:
                _divide(left)
            if right.duration > min_segment_length:
                _divide(right)

    whole = trim(Segment(0, len(probs), probs), threshold)
    _divide(whole)

    return result
def update_yaml_content(
    yaml_content: list[dict], segments: list[Segment], wav_name: str
) -> list[dict]:
    """Append one yaml entry per segment of a wav file.

    ``yaml_content`` is modified in place and also returned.

    Args:
        yaml_content (list[dict]): segmentation entries collected so far
        segments (list[Segment]): segmentation of this wav produced by pdac
        wav_name (str): name of the wav file the segments belong to

    Returns:
        list[dict]: the same list, extended with the new entries
    """
    yaml_content.extend(
        {
            "duration": sgm.duration,
            "offset": sgm.offset,
            "rW": 0,
            "uW": 0,
            "speaker_id": "NA",
            "wav": wav_name,
        }
        for sgm in segments
    )
    return yaml_content
def _average_frame_probs(wav_path, args, wav2vec_model, sfc_model, device):
    """Run inference ``args.inference_times`` times on one wav and average the result."""
    # dataset holding the fixed-length segmentation of this wav
    dataset = FixedSegmentationDatasetNoTarget(
        wav_path, args.inference_segment_length, args.inference_times
    )
    n_workers = min(cpu_count() // 2, 4)

    accumulated = None
    for iteration in range(args.inference_times):
        # re-segment the wav for this iteration and wrap it in a dataloader
        dataset.fixed_length_segmentation(iteration)
        loader = DataLoader(
            dataset,
            batch_size=args.inference_batch_size,
            num_workers=n_workers,
            shuffle=False,
            drop_last=False,
            collate_fn=segm_collate_fn,
        )

        # frame probabilities for this iteration; the second output is unused
        probs, _ = infer(wav2vec_model, sfc_model, loader, device)

        if accumulated is None:
            accumulated = probs.copy()
        else:
            accumulated += probs

    accumulated /= args.inference_times
    return accumulated


def segment(args):
    """Segment every wav in a directory and save the result as a yaml file.

    Loads the checkpoint, builds the wav2vec 2.0 model and the segmentation
    frame classifier, averages the frame probabilities of each ``*.wav`` file
    in ``args.path_to_wavs`` over ``args.inference_times`` inference passes,
    applies ``pdac`` to them, and writes all segments to
    ``args.path_to_segmentation_yaml``.

    Args:
        args: parsed command-line arguments (see the parser in ``__main__``)
    """
    has_gpu = torch.cuda.device_count() > 0
    device = torch.device("cuda:0") if has_gpu else torch.device("cpu")

    checkpoint = torch.load(args.path_to_checkpoint, map_location=device)
    train_args = checkpoint["args"]

    # init wav2vec 2.0
    wav2vec_model = prepare_wav2vec(
        train_args.model_name,
        train_args.wav2vec_keep_layers,
        device,
    )

    # init segmentation frame classifier
    sfc_model = SegmentationFrameClassifer(
        d_model=HIDDEN_SIZE,
        n_transformer_layers=train_args.classifier_n_transformer_layers,
    ).to(device)
    sfc_model.load_state_dict(checkpoint["state_dict"])
    sfc_model.eval()

    yaml_content = []
    wav_paths = sorted(Path(args.path_to_wavs).glob("*.wav"))
    for wav_path in tqdm(wav_paths):
        sgm_frame_probs = _average_frame_probs(
            wav_path, args, wav2vec_model, sfc_model, device
        )

        segments = pdac(
            sgm_frame_probs,
            args.dac_max_segment_length,
            args.dac_min_segment_length,
            args.dac_threshold,
            args.not_strict
        )
        yaml_content = update_yaml_content(yaml_content, segments, wav_path.name)

    # make sure the output directory exists before writing
    path_to_segmentation_yaml = Path(args.path_to_segmentation_yaml)
    path_to_segmentation_yaml.parent.mkdir(parents=True, exist_ok=True)
    with path_to_segmentation_yaml.open("w") as f:
        yaml.dump(yaml_content, f, default_flow_style=True)

    print(
        f"Saved SHAS segmentation with max={args.dac_max_segment_length} & "
        f"min={args.dac_min_segment_length} at {path_to_segmentation_yaml}"
    )
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the segmentation.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--path_to_segmentation_yaml",
        "-yaml",
        type=str,
        required=True,
        help="absolute path to the yaml file to save the generated segmentation",
    )
    parser.add_argument(
        "--path_to_checkpoint",
        "-ckpt",
        type=str,
        required=True,
        help="absolute path to the audio-frame-classifier checkpoint",
    )
    # required: segment() builds a Path from this value, so leaving it out
    # would otherwise fail later with a TypeError instead of a usage error
    parser.add_argument(
        "--path_to_wavs",
        "-wavs",
        type=str,
        required=True,
        help="absolute path to the directory of the wav audios to be segmented",
    )
    parser.add_argument(
        "--inference_batch_size",
        "-bs",
        type=int,
        default=12,
        help="batch size (in examples) of inference with the audio-frame-classifier",
    )
    # NOTE: adjacent string literals are concatenated verbatim, so every
    # continued help line below ends with an explicit space.
    parser.add_argument(
        "--inference_segment_length",
        "-len",
        type=int,
        default=20,
        help="segment length (in seconds) of fixed-length segmentation during inference "
        "with audio-frame-classifier",
    )
    parser.add_argument(
        "--inference_times",
        "-n",
        type=int,
        default=1,
        help="how many times to apply inference on different fixed-length segmentations "
        "of each wav",
    )
    parser.add_argument(
        "--dac_max_segment_length",
        "-max",
        type=float,
        default=18,
        help="the segmentation algorithm splits until all segments are below this value "
        "(in seconds)",
    )
    parser.add_argument(
        "--dac_min_segment_length",
        "-min",
        type=float,
        default=0.2,
        help="a split by the algorithm is carried out only if the resulting two segments "
        "are above this value (in seconds)",
    )
    parser.add_argument(
        "--dac_threshold",
        "-thr",
        type=float,
        default=0.5,
        help="after each split by the algorithm, the resulting segments are trimmed to "
        "the first and last points that correspond to a probability above this value",
    )
    parser.add_argument(
        "--not_strict",
        action="store_true",
        help="whether segments longer than max are allowed. "
        "If this argument is used, respecting the classification threshold conditions (p > thr) "
        "is more important than the length conditions (len < max).",
    )
    args = parser.parse_args()

    segment(args)