Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make seg_iter and hyp behave as expected in kws search (fixes #122) #306

Merged
merged 4 commits into from
Sep 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cython/test/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def test_config__iter(self):
self.assertEqual(default_len, len(config))
config["hmm"] = os.path.join(MODELDIR, "en-us", "en-us")
config["fsg"] = os.path.join(DATADIR, "goforward.fsg")
config["lm"] = None # It won't do this automatically
config["dict"] = os.path.join(DATADIR, "turtle.dic")
self.assertEqual(default_len, len(config))
for key in config:
Expand Down
28 changes: 22 additions & 6 deletions cython/test/kws_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ def test_kws(self):
stream = open(os.path.join(DATADIR, "goforward.raw"), "rb")

# Process audio chunk by chunk. When a keyphrase is detected, perform the action and restart the search.
decoder = Decoder(keyphrase="forward",
decoder = Decoder(kws=os.path.join(DATADIR, "goforward.kws"),
loglevel="INFO",
lm=None,
kws_threshold=1e20)
lm=None)
decoder.start_utt()
while True:
keywords = ["forward", "meters"]
while keywords:
buf = stream.read(1024)
if buf:
decoder.process_raw(buf, False, False)
decoder.process_raw(buf)
else:
break
if decoder.hyp() != None:
Expand All @@ -34,10 +34,26 @@ def test_kws(self):
print("Detected keyphrase, restarting search")
for seg in decoder.seg():
self.assertTrue(seg.end_frame > seg.start_frame)
self.assertEqual(seg.word, "forward")
self.assertEqual(seg.word, keywords.pop(0))
decoder.end_utt()
decoder.start_utt()
stream.close()
decoder.end_utt()

# Detect keywords in a batch utterance and make sure they show up in the right order.
stream = open(os.path.join(DATADIR, "goforward.raw"), "rb")
decoder.start_utt()
decoder.process_raw(stream.read(), full_utt=True)
decoder.end_utt()
print(
[
(seg.word, seg.prob, seg.start_frame, seg.end_frame)
for seg in decoder.seg()
]
)
self.assertEqual(decoder.hyp().hypstr, "forward meters")
self.assertEqual(["forward", "meters"],
[seg.word for seg in decoder.seg()])


if __name__ == "__main__":
Expand Down
2 changes: 2 additions & 0 deletions src/kws_detections.c
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ kws_detections_hyp_str(kws_detections_t *detections, int frame, int delay)

hyp_str = (char *)ckd_calloc(len, sizeof(char));
c = hyp_str;
detections->detect_list = glist_reverse(detections->detect_list);
for (gn = detections->detect_list; gn; gn = gnode_next(gn)) {
kws_detection_t *det = (kws_detection_t *)gnode_ptr(gn);
if (det->ef < frame - delay) {
Expand All @@ -115,6 +116,7 @@ kws_detections_hyp_str(kws_detections_t *detections, int frame, int delay)
c--;
*c = '\0';
}
detections->detect_list = glist_reverse(detections->detect_list);
return hyp_str;
}

27 changes: 18 additions & 9 deletions src/kws_search.c
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,14 @@ static void
kws_seg_free(ps_seg_t *seg)
{
kws_seg_t *itor = (kws_seg_t *)seg;
ckd_free(itor->detections);
ckd_free(itor);
}

static void
kws_seg_fill(kws_seg_t *itor)
{
kws_detection_t* detection = (kws_detection_t*)gnode_ptr(itor->detection);
kws_detection_t* detection = itor->detections[itor->pos];

itor->base.text = detection->keyphrase;
itor->base.wid = BAD_S3WID;
Expand All @@ -98,12 +99,7 @@ kws_seg_next(ps_seg_t *seg)
{
kws_seg_t *itor = (kws_seg_t *)seg;

gnode_t *detect_head = gnode_next(itor->detection);
while (detect_head != NULL && ((kws_detection_t*)gnode_ptr(detect_head))->ef > itor->last_frame)
detect_head = gnode_next(detect_head);
itor->detection = detect_head;

if (!itor->detection) {
if (++itor->pos == itor->n_detections) {
kws_seg_free(seg);
return NULL;
}
Expand All @@ -124,18 +120,31 @@ kws_search_seg_iter(ps_search_t * search)
kws_search_t *kwss = (kws_search_t *)search;
kws_seg_t *itor;
gnode_t *detect_head = kwss->detections->detect_list;
gnode_t *gn;

while (detect_head != NULL && ((kws_detection_t*)gnode_ptr(detect_head))->ef > kwss->frame - kwss->delay)
detect_head = gnode_next(detect_head);

if (!detect_head)
return NULL;

itor = (kws_seg_t *)ckd_calloc(1, sizeof(*itor));
/* Count and reverse them into the vector. */
itor = ckd_calloc(1, sizeof(*itor));
for (gn = detect_head; gn; gn = gnode_next(gn))
itor->n_detections++;
itor->detections = ckd_calloc(itor->n_detections, sizeof(*itor->detections));
itor->pos = itor->n_detections - 1;
for (gn = detect_head; gn; gn = gnode_next(gn))
itor->detections[itor->pos--] = (kws_detection_t *)gnode_ptr(gn);
itor->pos = 0;

/* Copy the detections here to get them in the right order and
* avoid undefined behaviour if recognition continues during
* iteration. */

itor->base.vt = &kws_segfuncs;
itor->base.search = search;
itor->base.lwf = 1.0;
itor->detection = detect_head;
itor->last_frame = kwss->frame - kwss->delay;
kws_seg_fill(itor);
return (ps_seg_t *)itor;
Expand Down
8 changes: 5 additions & 3 deletions src/kws_search.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,11 @@ extern "C" {
* Segmentation "iterator" for KWS history.
*/
typedef struct kws_seg_s {
ps_seg_t base; /**< Base structure. */
gnode_t *detection; /**< Keyphrase detection correspondent to segment. */
frame_idx_t last_frame; /**< Last frame to raise the detection */
ps_seg_t base; /**< Base structure. */
kws_detection_t **detections; /**< Vector of current detections. */
frame_idx_t last_frame; /**< Last frame to raise the detection. */
int n_detections; /**< Size of vector. */
int pos; /**< Position of iterator. */
} kws_seg_t;

typedef struct kws_keyphrase_s {
Expand Down
1 change: 1 addition & 0 deletions test/data/goforward.kws
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ something
bad line / here
just bad line /
forward
meters

non_existign_word

2 changes: 1 addition & 1 deletion test/unit/test_keyphrase.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@ main(int argc, char *argv[])
"hmm: \"" MODELDIR "/en-us/en-us\","
"kws: \"" DATADIR "/goforward.kws\","
"dict: \"" MODELDIR "/en-us/cmudict-en-us.dict\""));
return ps_decoder_test(config, "KEYPHRASE", "forward");
return ps_decoder_test(config, "KEYPHRASE", "forward meters");
}