
Timestamps for Wav2Vec 2.0 models and/or ASR pipelines #15502

Closed
iskaj opened this issue Feb 3, 2022 · 10 comments · Fixed by #15687

Comments

@iskaj

iskaj commented Feb 3, 2022

🚀 Feature request

So the ASR pipeline (https://github.com/huggingface/transformers/blob/v4.16.2/src/transformers/pipelines/automatic_speech_recognition.py#L122) is great for leveraging Wav2Vec 2.0 for longer files. However, it does not allow us to get the timestamps of the words, i.e. when each word was spoken.

Motivation

This is relevant for many applications of ASR, such as automatic subtitling or anything else requiring this timing information. Since this information should be available somewhere "under the hood", it might be beneficial to many users to include it in the output. This is not necessarily specific to the pipelines; it could also apply to the general output of Wav2Vec 2.0 models.

Your contribution

I'm not yet that familiar with HF + Wav2Vec 2.0, but https://github.com/lumaku/ctc-segmentation is a useful GitHub page. I would be willing to help out though!

@LysandreJik
Member

@patrickvonplaten
Contributor

I very much agree - this would be a very welcome feature to add, and it actually shouldn't be too difficult for CTC.

For CTC we know exactly what characters are predicted at what time because we know the sampling_rate of the model and we know the context window of each output token id. E.g. the first token id corresponds to the first 320 input samples, which corresponds to 320 samples / 16,000 samples per second -> 0.02 seconds. The second token id vector then corresponds more or less to the window 0.02 - 0.04 seconds, and so on. We could then easily map each token id to a time window. Knowing the id of the word delimiter token, we can also easily retrieve the time window in which a word was spoken.
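As a rough illustration of this mapping (a sketch, assuming a total model stride of 320 samples and a 16 kHz feature extractor):

total_stride = 320        # product of the conv strides in the Wav2Vec 2.0 feature encoder
sampling_rate = 16_000    # samples per second
time_per_frame = total_stride / sampling_rate  # 320 / 16_000 = 0.02 s per output token id

def frame_to_time_window(frame_index):
    # frame 0 -> (0.0, 0.02), frame 1 -> (0.02, 0.04), ...
    return frame_index * time_per_frame, (frame_index + 1) * time_per_frame

print(frame_to_time_window(0))  # (0.0, 0.02)
print(frame_to_time_window(1))  # (0.02, 0.04)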

In terms of the implementation details, I think Wav2Vec2CTCTokenizer should be responsible for returning time stamps for words. We could do the following - add a method def retrieve_time_stamps(...) that takes the token_ids, stride (the stride tuple from the config) and the feature extractor's sampling_rate as input and retrieves a list of time stamps (one for each word) from it.

We could then also integrate this into the tokenizer's decode(...) and batch_decode(...) eventually.

@iskaj would you be interested in opening a PR for this? I think we could start by adding the following method to Wav2Vec2CTCTokenizer:

def retrieve_time_stamps(token_ids, stride, sampling_rate): 
     # 1. compute total stride: `total_stride = reduce(stride, multiply)`
     # 2. time_frame_per_logit_in_s = total_stride / sampling_rate
     # 3. now we need to find the first non-`pad_token_id` in token_ids, which represents the first start_id. Then the first `word_delimiter_token` represents the first end_id. The next non-`pad_token_id` represents the next start_id, the `word_delimiter_token` after it the next end_id, and so on. This can be done in a simple for loop.
     # 4. that's pretty much it -> then we can return a list of tuples which correspond to the time stamps of the returned words.
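
A minimal runnable sketch of that loop (not the final API; pad_token_id and word_delimiter_token_id are assumed to be passed in, e.g. taken from the tokenizer):

from functools import reduce
from operator import mul

def retrieve_time_stamps(token_ids, stride, sampling_rate, pad_token_id, word_delimiter_token_id):
    # total stride of the feature encoder, e.g. reduce(mul, (5, 2, 2, 2, 2, 2, 2)) = 320
    total_stride = reduce(mul, stride)
    time_per_frame = total_stride / sampling_rate  # seconds covered by one output token id

    time_stamps = []
    start_frame = None
    for i, token_id in enumerate(token_ids):
        if token_id == word_delimiter_token_id:
            if start_frame is not None:
                # the delimiter closes the current word
                time_stamps.append((start_frame * time_per_frame, i * time_per_frame))
                start_frame = None
        elif token_id != pad_token_id and start_frame is None:
            # first non-pad token after a delimiter starts a new word
            start_frame = i
    if start_frame is not None:
        # the last word may not be followed by a delimiter
        time_stamps.append((start_frame * time_per_frame, len(token_ids) * time_per_frame))
    return time_stamps

# e.g. with 0 = pad_token_id and 4 = word_delimiter_token_id:
print(retrieve_time_stamps([0, 7, 7, 0, 4, 9, 0], (5, 2, 2, 2, 2, 2, 2), 16_000, 0, 4))
# -> [(0.02, 0.08), (0.1, 0.14)] (up to float precision)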

Also interested in feedback from @anton-l @Narsil

@Narsil
Contributor

Narsil commented Feb 14, 2022

> I very much agree - this would be a very welcome feature to add and it actually shouldn't be too difficult for CTC.

I agree too, it's a very welcome feature.

The main concern I have is the actual implementation of this. Ideally, it would be finely controllable by users, because I can see (at least) a use case for video, where you want to add subtitles and you need to put timestamps at definite boundaries (most likely sentence boundaries and/or the length of the text).

The ideal way would be to be highly transparent and maybe quite noisy:

pipe = pipeline(..., add_timestamps=True) # crashes on non CTC

out = pipe(...)
# out = {"text": "ABCD", "timestamps": [0.01, 0.03, 0.03, 0.04]} 

Here I propose one float per character of the output. It's very noisy, but it still seems quite simple to use and gives everything needed for someone wanting fine control over timestamps. I imagine this float would correspond to the first TOKEN using that letter in the CTC context.

As for the implementation, it might be tedious to add this properly with chunking and striding (or not, I am not sure).
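
One possible way to reconcile this with chunking (purely a sketch, assuming the pipeline keeps track of each chunk's start offset in samples): the per-character timestamps of a chunk would just be shifted by that offset before the chunks are merged.

def shift_chunk_timestamps(char_timestamps, chunk_offset_samples, sampling_rate):
    # char_timestamps: list of floats (seconds), relative to the start of the chunk
    # chunk_offset_samples: where this chunk starts in the full audio stream
    offset_s = chunk_offset_samples / sampling_rate
    return [t + offset_s for t in char_timestamps]

# e.g. a chunk that starts 10 s into the file (160_000 samples at 16 kHz)
print(shift_chunk_timestamps([0.01, 0.03, 0.03, 0.04], 160_000, 16_000))
# -> approximately [10.01, 10.03, 10.03, 10.04]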

@iskaj
Author

iskaj commented Feb 14, 2022

I think what @Narsil proposed will work fine for what most people want indeed, so I agree with that sentiment. For me the interest lies in automatic subtitling and the noise in this solution would be fine. It is also nicely interpretable. I think it should work with the pipeline approach (chunking and striding), otherwise the purpose would be kind of lost right? I'm also not sure how that would work though...

In the future I might be interested in doing a pull request for this, but currently my priorities lie elsewhere. I hope I can help with this in the near future.

@patrickvonplaten
Contributor

@anton-l - thoughts on this?

@anton-l
Member

anton-l commented Feb 16, 2022

I think we can add an alternative to Wav2Vec2CTCTokenizer.decode() to add timestamps to each character pretty easily. Basically implement Wav2Vec2CTCTokenizer.decode_with_timestamps() that returns a structure like this:

{
    "text": "I AM HERE",
    "tokens": [
        {
            "token": "I",
            "time_start": <first_ctc_logit_index> * 0.02,
            "time_end": <last_ctc_logit_index> * 0.02 + 0.025,
            "probability": 0.7
        },
        {
            "token": " ",
            "time_start": <first_ctc_logit_index> * 0.02,
            "time_end": <last_ctc_logit_index> * 0.02 + 0.025,
            "probability": 0.4
        },
        {
            "token": "A",
            "time_start": <first_ctc_logit_index> * 0.02,
            "time_end": <last_ctc_logit_index> * 0.02 + 0.025,
            "probability": 0.6
        },
        ....
    ]
}

where 0.02 is the frame stride, and 0.025 is the frame width in seconds (could be calculated from stride and sample_rate like @patrickvonplaten suggested above).
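
For illustration, a small sketch of that computation (assuming the 0.02 s stride and 0.025 s frame width from above):

FRAME_STRIDE_S = 0.02   # 320 samples / 16_000 Hz per CTC output frame
FRAME_WIDTH_S = 0.025   # approximate receptive-field width of one frame

def token_times(first_ctc_logit_index, last_ctc_logit_index):
    time_start = first_ctc_logit_index * FRAME_STRIDE_S
    time_end = last_ctc_logit_index * FRAME_STRIDE_S + FRAME_WIDTH_S
    return time_start, time_end

# a token whose (repeated) logits span frames 12..15
print(token_times(12, 15))  # approximately (0.24, 0.325)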

Returning the word boundaries' offsets (whitespaces in this example) is also important for consistency IMO.

Probabilities are optional, but they would be pretty handy for downstream applications like forced alignment to filter out low-confidence segments, so we can add them as a bonus while we're at it:
(image taken from https://github.com/lumaku/ctc-segmentation)

Since the whole step can be contained inside the tokenizer, it shouldn't be a problem to add it inside AutomaticSpeechRecognitionPipeline.postprocess() and support all modes of streaming as well 🙂

@Narsil
Contributor

Narsil commented Feb 16, 2022

@anton-l ,

Fully agree that having both start and stop timings is better (and I'm fine with throwing probabilities in there).

I just realised, though: will that approach be possible with ctc_with_lm? Since we added a bunch of those recently, it would be nice if we could support it.

@anton-l
Member

anton-l commented Feb 16, 2022

@Narsil it's possible for ctc_with_lm, but we might need to create a wrapper for the pyctcdecode's decode_beams() function to create a common API.
It can use pretty much the same logic as the ordinary CTC tokenizer, just with full words instead of granular characters, because it doesn't return per-character logit indices:
(image from their tutorial)
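
For example, if decode_beams returns word-level frame indices as in that tutorial (a list of (word, (start_frame, end_frame)) pairs per beam - an assumption based on the tutorial output above), converting them to seconds would be straightforward; a sketch, assuming the usual 0.02 s per CTC frame:

FRAME_STRIDE_S = 0.02  # 320 samples / 16_000 Hz per CTC output frame

def word_frames_to_seconds(word_frames, frame_stride_s=FRAME_STRIDE_S):
    # word_frames: [(word, (start_frame, end_frame)), ...] as shown in pyctcdecode's tutorial
    return [
        {"word": word, "start": start * frame_stride_s, "end": end * frame_stride_s}
        for word, (start, end) in word_frames
    ]

# e.g. fed with the frames of the best beam returned by decoder.decode_beams(logits)
print(word_frames_to_seconds([("I", (5, 6)), ("AM", (10, 14)), ("HERE", (20, 28))]))
# -> word dicts with start/end in seconds, e.g. "AM" roughly from 0.2 s to 0.28 s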

@patrickvonplaten
Contributor

Interesting idea @anton-l!

I thought about it a bit - a couple of remarks:

1.) I don't think we should entangle the probabilities with the time stamps too much, to be honest, and would rather treat them as separate feature additions because:

  • Requiring time stamps doesn't mean that the user also needs probabilities, and vice versa.
  • In order to extract time stamps, for the normal Wav2Vec2Processor we need to work with the predicted ids IMO and not the logits. We group the ids and then know the boundaries from this. We cannot work directly with the logits for the normal Wav2Vec2Processor (for the ...WithLM it's a different story).
  • => So I'd prefer to keep this PR just for time stamps for now. It's nevertheless a good idea to think about how the design would work for the probs though.

2.) Instead of creating a new function decode_with_timestamps, I think it's more in line with the library to add an output_time_stamps flag to the decoding function.

3.) Not sure if the tokenizer or directly the processor should be responsible for the function. I'm tending more and more towards adding it to the processor.

Will do a draft PR and share it here.

patrickvonplaten linked a pull request on Feb 16, 2022 that will close this issue
@Narsil
Contributor

Narsil commented Feb 16, 2022

Also, this could definitely live directly in the pipeline if it's going to be the main/sole user.

Might make implementation easier.
