
Large audio chunking for the existing ASR pipeline #14896

Merged
2 commits merged into huggingface:master on Jan 3, 2022

Conversation

@anton-l (Member) commented Dec 23, 2021

What does this PR do?

This adds audio chunking with fixed-size chunks as a first step towards enabling audio streaming in ASR pipelines (ref: #14250).

In this iteration there's no ffmpeg streaming or VAD, just simple slicing of inputs with padding, so that we can review the general pipeline from the modeling side.
Here's an illustration of the sliding window approach used for iterating over chunks:
[chunking diagram]

To see how this will roughly look for real-time inference (when we implement it on top), check out this (admittedly old) demo: https://huggingface.co/spaces/anton-l/youtube-subs-wav2vec/

Comment on lines 82 to 84
chunk_len=128080,
chunk_start_padding_len=16160,
chunk_end_padding_len=16160,
@anton-l (Member Author) commented Dec 23, 2021

This roughly corresponds to 8-second chunks with 1-sec padding on both sides.
The values are so "specific" to avoid border artifacts after convolutional feature extraction inside the Wav2Vec2 models. Basically, these input lengths can be evenly divided by conv_kernel size at each feature extractor layer.
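For reference, here is a minimal sketch of that arithmetic, assuming the default Wav2Vec2 conv_kernel/conv_stride values (the helper is illustrative, not part of the library):

# Standard conv output-length formula applied layer by layer;
# the kernel/stride lists are the Wav2Vec2 defaults.
def feat_extract_output_length(input_length, kernels=(10, 3, 3, 3, 3, 2, 2), strides=(5, 2, 2, 2, 2, 2, 2)):
    for kernel, stride in zip(kernels, strides):
        input_length = (input_length - kernel) // stride + 1
    return input_length

print(feat_extract_output_length(128080))  # 400 frames (~8 s at 20 ms per frame)
print(feat_extract_output_length(16160))   # 50 frames (~1 s)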

Contributor

Would be great to leave this as a comment

Contributor

Ok for me!

@@ -136,6 +165,22 @@ def _sanitize_parameters(self, **kwargs):
# No parameters on this pipeline right now
return {}, {}, {}

def run_single(self, inputs, preprocess_params, forward_params, postprocess_params):
@anton-l (Member Author)

Had to override this method to allow for generator outputs from self.preprocess(). Maybe there's a better way.
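Roughly, the override looks like this (a simplified sketch rather than the exact diff): preprocess() yields chunks, each chunk is forwarded separately, and postprocess() receives the collected outputs.

def run_single(self, inputs, preprocess_params, forward_params, postprocess_params):
    model_outputs = []
    # preprocess() is now a generator, so iterate and forward chunk by chunk
    for model_inputs in self.preprocess(inputs, **preprocess_params):
        model_outputs.append(self.forward(model_inputs, **forward_params))
    return self.postprocess(model_outputs, **postprocess_params)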

@@ -15,6 +15,7 @@
from typing import TYPE_CHECKING, Union

import numpy as np
import torch
Contributor

I think pipelines are framework-independent, so let's only add this import if torch is available.
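Something along these lines (a sketch; the exact import location of the availability check can vary between versions):

from transformers import is_torch_available

if is_torch_available():
    import torch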

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
audio = ds[40]["audio"]["array"]

n_repeats = 10
Contributor

Since it's a slow test, can we maybe really stress-test it here and do n_repeats = 100, which would correspond to ~3 minutes I think?
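Back-of-envelope for that suggestion (hypothetical numbers; the dummy clip is assumed to be roughly 1.8 s at 16 kHz):

sampling_rate = 16000
clip_seconds = len(audio) / sampling_rate
print(100 * clip_seconds / 60)  # ~3 minutes of audio for n_repeats = 100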

model_outputs.append(
chunk_output[self.chunk_output_padding_len : self.chunk_output_padding_len + self.chunk_output_len]
)
model_outputs = torch.cat(model_outputs)
@patrickvonplaten (Contributor) commented Dec 23, 2021

Ok for me for now, since we only allow PyTorch for the moment. Maybe we'll have to do some kind of if is_torch_available(): ... else: ... further down the road.

@Narsil what do you think?

@patrickvonplaten (Contributor) left a comment

Looks good to me in general and I think it's a good approach that allows for offline ASR for long sequences.

Think @Narsil knows the specifics better here

@Narsil (Contributor) left a comment

I think it's a well-intentioned PR, but IMHO we should probably refactor to pull everything out of the pipeline code, for several reasons. I actually started work on #14250 in a similar fashion and then scrapped everything. The reasons are:

  1. It doesn't play nicely with the auto-batching / DataLoader framework, which most likely explains some of the failing tests:
pipe = pipeline(..)
for out in pipe(dataset, batch_size=32):
    # do something with out

Using this allows users to adjust the batch_size relative to the hardware they have to maximize performance.

This unfortunately means:

  • No loop in _forward.
  • No loop/batch in preprocess.

PR #14225 could bring that back (it was the main reason I started that PR in the first place).

  2. The arguments are not in _sanitize_parameters, which means they will only be usable at pipeline initialization and not at call time. Historically there were a lot of issues there, so now all arguments are enabled in both places, which saves users from figuring out where the arguments should be defined (look at other pipelines to see how they are implemented; it's not a big change code-wise, and the current doc is correct).

  3. chunking, chunking_len, chunking_start_padding_len and chunk_end_padding_len don't make sense independently.
    For arguments, it is always easier for the user if there's no interaction between arguments and you can independently modify any of them. Here we could imagine having a single chunking argument that is either None or a triplet (len, start_padding, stop_padding), for instance. That would reduce a lot of confusion IMHO.

  4. chunking_start_padding_len is expressed in number of samples, which again depends on sampling_rate, which might be different for different models, meaning quality might change if you use a different model. I would much prefer start_padding_len_ms, for instance, since as a user it makes more sense to adjust in that space than in raw sample-count space.

  5. If we want to add webrtc (which exists in "Adding utilities to chunk large audio files and read directly from microphone" #14250), then we have to add a whole bunch of new arguments, which would conflict with these added ones.
    IMO it would make a whole lot more sense to instead choose the following form for users:

pipe = pipeline(...)
dataset = load_dataset(...)
chunk_dataset = ChunkDataset(dataset, length_ms, start_pad_ms, end_pad_ms)

for chunk in pipe(chunk_dataset):
    print(chunk)

In this way we could add VADChunkDataset(dataset, threshold=2, frame_size_ms=20) for instance quite orthogonally without cluttering the pipeline with a ton of arguments.

It also enables more complex stuff like ffmpeg_microphone where the audio samples actually overlap and some results are "temporary" (replaced with other chunks later).

It also enables fast feedback because you get the results as soon as they come in (without waiting when it's an hour long audio sample for instance).

The main drawback of this approach is that you don't have any information within chunk to know which file it comes from or which chunk it is. That might be important for recreating the end result as a user. But IMHO, it seems easier to start passing that information through the pipeline so that it is available within chunk, allowing the user to assemble chunks as they see fit.

I am happy to discuss anything that I might have overlooked in this analysis and why this PR might still be the right solution. Again this was also my first idea.

@anton-l (Member Author) commented Dec 23, 2021

@Narsil sorry, I only now remembered about the work in #14225!
Indeed, the ChunkPipeline API is much cleaner and I can adapt this PR to work with it.

Re: 2-4) totally agree, these points should be refactored as you suggest!
Re: 5) If I understand correctly, you suggest using ChunkDataset/VADChunkDataset only for offline datasets? Then we would still need to take parts of their chunking logic outside, to reuse them for streaming inputs

@Narsil (Contributor) commented Dec 23, 2021

@Narsil sorry, I only now remembered about the work in #14225! Indeed, the ChunkPipeline API is much cleaner and I can adapt this PR to work with it.

It got autoclosed so even I struggled to find it again yesterday :)

Re: 2-4) totally agree, these points should be refactored as you suggest!

Re: 5) If I understand correctly, you suggest using ChunkDataset/VADChunkDataset only for offline datasets? Then we would still need to take parts of their chunking logic outside, to reuse them for streaming inputs

Actually, we might be able to make them interoperable too.

Well, I am mentioning Dataset, but actually the pipeline works with any generator, so we can use that for streaming too! This is actually what ffmpeg_microphone does.

I like mentioning Dataset since a dataset has a fixed number of elements, meaning tqdm and the like can infer a nice progress bar and time estimates. But everything works quite the same with a generator except:

  • num_workers cannot be used with values > 1 (fetching from a generator from multiple threads is asking for trouble: you need to iterate over ALL objects on EVERY thread, and even if you skip some on some threads, the generator will most likely already consume resources and time)
  • No nice progress bar and time estimate with tqdm.

So we could imagine something like:

dataset = datasets.load_dataset(...)
vad_dataset = vad_cut(dataset, threshold=5, ...)
chunk_dataset = chunk_audio(dataset, chunk_len_ms=200)

or

microphone_generator = ffmpeg_microphone(...)
for chunk in pipe(chunk_audio(microphone_generator, chunk_len_ms=200)):
    print(chunk)

I am unsure it makes total sense and that we should make ALL of them interoperable and such, but there definitely could be a nice way to make those chunking iterators composable (just like torchvision.transforms can be, for instance).
It would be a very nice thing indeed.
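To make the idea concrete, a hypothetical chunk_audio generator could look like this (illustrative only; neither the name nor the helper exists in transformers):

import numpy as np

def chunk_audio(audio_iterable, chunk_len_ms=200, sampling_rate=16000):
    # Yield fixed-length numpy chunks from each audio array in the iterable.
    chunk_len = round(chunk_len_ms / 1000 * sampling_rate)
    for audio in audio_iterable:
        audio = np.asarray(audio, dtype=np.float32)
        for start in range(0, len(audio), chunk_len):
            yield audio[start : start + chunk_len]

Because it is just a generator, the same helper would work for a datasets column and for a microphone stream alike.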

@patrickvonplaten (Contributor) commented Dec 27, 2021

Thanks for the feedback @Narsil! I don't fully agree here - happy to discuss this a bit more (also in a call). Maybe I'm also not seeing something here.

I think it's a well-intentioned PR, but IMHO we should probably refactor to pull everything out of the pipeline code, for several reasons. I actually started work on #14250 in a similar fashion and then scrapped everything. The reasons are:

  1. It doesn't play nicely with the auto-batching / DataLoader framework, which most likely explains some of the failing tests:
pipe = pipeline(..)
for out in pipe(dataset, batch_size=32):
    # do something with out

Using this allows users to adjust the batch_size relative to the hardware they have to maximize performance.

This unfortunately means:

  • No loop in _forward.
  • No loop/batch in preprocess.

PR #14225 could bring that back (it was the main reason I started that PR in the first place).

Here, I don't really know whether it plays nicely with auto batching or not. I agree that it is very important to make it work nicely with auto batching, since I can see companies being interested in transcribing tons of audio files offline, and it would be nice to have that working fast.

However, for me it's at least equally important to make sure it's very simple for the user to transcribe a single audio file. The main applications for this feature in my opinion are:

  • demo widget. Orgs on the hub will want to demo their models to clients, internally, etc... we have already seen demand for this
  • should be easy to build a space with this feature to transcribe long audio files of e.g. Videos (YouTube, TED) on the fly

IMO those features are more important than auto batching / Data Loader. A necessary requirement to make chunking easy to use is that one doesn't have to wrap her/his audio file in some kind of data loader, wrapper, datasets, etc. I feel pretty strongly about making it possible to just do:

transcription = asr("<path/to/audio/file>", chunking_length_in_s=2)
  2. The arguments are not in _sanitize_parameters, which means they will only be usable at pipeline initialization and not at call time. Historically there were a lot of issues there, so now all arguments are enabled in both places, which saves users from figuring out where the arguments should be defined (look at other pipelines to see how they are implemented; it's not a big change code-wise, and the current doc is correct).

Agree. We can however easily change that.
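For illustration, a hedged sketch of what that change could look like (argument names here are illustrative and follow the discussion, not necessarily the merged code):

def _sanitize_parameters(self, chunk_length_s=None, stride_length_s=None, **kwargs):
    # Route the chunking arguments to preprocess() so they work both when
    # creating the pipeline and when calling it.
    preprocess_params = {}
    if chunk_length_s is not None:
        preprocess_params["chunk_length_s"] = chunk_length_s
    if stride_length_s is not None:
        preprocess_params["stride_length_s"] = stride_length_s
    return preprocess_params, {}, {}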

  3. chunking, chunking_len, chunking_start_padding_len and chunk_end_padding_len don't make sense independently.

For arguments, it is always easier for the user if there's no interaction between arguments and you can independently modify any of them. Here we could imagine having a single chunking argument that is either None or a triplet (len, start_padding, stop_padding), for instance. That would reduce a lot of confusion IMHO.

Disagree here. I really don't like tuple input arguments, as one never knows which index of the tuple stands for what, and this has to be looked up in the docs again every time someone uses the pipelines. We don't have many tuple args in transformers in general, but rather prefer "simple" args (in the config, functions, etc...).

I also think they are quite independent from each other - changing one arg "chunking_len" doesn't mean that "chunk_end_padding_len" has to be changed either.

My 5 cents here are:

  • 1.) Remove the "chunking" input arg. I don't like boolean flags either, and I think whether the input should be chunked or not should be controlled by "chunking_len": if "chunking_len" > 0, chunk; otherwise don't.
  • 2.) (nit) I would write out len as length, and IMO chunking_padding_left is easier to understand than "...start..."
  4. chunking_start_padding_len is expressed in number of samples, which again depends on sampling_rate, which might be different for different models, meaning quality might change if you use a different model. I would much prefer start_padding_len_ms, for instance, since as a user it makes more sense to adjust in that space than in raw sample-count space.

Agree. Nice observation!

  5. If we want to add webrtc (which exists in "Adding utilities to chunk large audio files and read directly from microphone" #14250), then we have to add a whole bunch of new arguments, which would conflict with these added ones.

Not sure I fully agree here. Why would those new arguments conflict with chunking_length? IMO, chunking_length should default to either 0 or None, i.e. be disabled. In the future we could imagine a vad="webrtc" argument or a vad=WebRTCVAD("<all_necessary_args_here>") argument, and I don't really see how that conflicts with "chunking_length" or chunking_length_left/right. "chunking_length" can still be used (and refer to the maximum allowed chunk length) and the other two arguments could simply be set to 0.

IMO it would make a whole lot more sense to instead choose the following form for users:

pipe = pipeline(...)
dataset = load_dataset(...)
chunk_dataset = ChunkDataset(dataset, length_ms, start_pad_ms, end_pad_ms)

for chunk in pipe(chunk_dataset):
    print(chunk)

In this way we could add VADChunkDataset(dataset, threshold=2, frame_size_ms=20) for instance quite orthogonally without cluttering the pipeline with a ton of arguments.

This looks clean, but it's not easy to understand for the user. What if I just have a single long audio file that I want to transcribe? Do I first need to put it in some kind of dataset format? Users won't make that effort, they'll simply stop at this point. This relates to 1.), and here it's much more important to make this feature easy to try out instead of having something super extendable/general. For me, pipelines always had the spirit of "2 lines of code is enough", and this starts to look much more complex for the user.

Also, given that we already have a bunch of input arguments for pipelines (generate() takes 50+ input arguments), I don't really see a problem with adding many new arguments; plus we can also mark something as experimental and change it later.

If you are very strongly against just adding input arguments @Narsil, maybe we can find a compromise where we offer some kind of object ChunkWithPadding that is required to have a chunk(...) method and can wrap all kinds of inputs, e.g.:

from transformers import pipeline, ChunkWithPadding

asr = pipeline("automatic_speech_recognition")

chunked_audio = ChunkWithPadding("path/to/audio/file", "<args>")

asr(chunked_audio)   # here inside we call `.chunk()` at some point

But I don't think that's very clean and I'd much rather prefer to do:

from transformers import pipeline, ChunkWithPadding

asr = pipeline("automatic_speech_recognition", chunking_length_in_s=10)

asr("/path/to/audio")

It also enables more complex stuff like ffmpeg_microphone where the audio samples actually overlap and some results are "temporary" (replaced with other chunks later).

It also enables fast feedback because you get the results as soon as they come in (without waiting when it's an hour long audio sample for instance).

The main drawback of this approach is that you don't have any information within chunk to know which file it comes from or which chunk it is. That might be important for recreating the end result as a user. But IMHO, it seems easier to start passing that information through the pipeline so that it is available within chunk, allowing the user to assemble chunks as they see fit.

I am happy to discuss anything that I might have overlooked in this analysis and why this PR might still be the right solution. Again this was also my first idea.

=> So to summarize, my by far biggest concern here is that the feature is too difficult for the user to use. I think we should think more about the user experience and not so much about making it super general and 100% clean on the inside. Pipelines are IMO the part of transformers where we can and should absorb complexity, so that the user has a very good experience at the cost of maybe some ugly code inside the pipelines.
It would be nice to focus here on a first solution that works well and has a nice user experience before thinking about how this feature could conflict with a feature that will potentially be added in the future. If I understand correctly, the "padding-chunk" approach works well in all kinds of settings and is also very lightweight; there are no necessary imports of other libraries, etc. WebRTC seems to work only about equally well, but it is more data-dependent (does it work well with noise, or in different languages?), it has an important dependency, and it needs an additional model to be loaded. So for me (if the above is correct), it's pretty clear that we should focus on the padding-chunking approach in a first step. We don't even know yet whether users will use this a lot or not. It's very good to also think about how this would pan out with a future WebRTC integration, but I also don't really see a problem here with the argument "chunking_length_in_s".
Regarding the technical details, I'm not really sure how to solve this, but I would also much rather add a small hack or a new design, etc., than force the user to wrap a single audio sample in some kind of dataset or generator.

Happy to jump on a call about this!

@Narsil (Contributor) commented Dec 27, 2021

Pretty big but important clarification I didn't get when I read the PR:

The purple part of the diagram comes from the real audio, and the pipeline can cut on the green boundaries on the ids tensor during decoding (before CTC decoding), meaning we should get pretty much exactly the same decoding as when running on the full audio (as long as the full chunk, green + purple, covers the whole word).

The name padding is slightly misleading to me, since it's usually called stride (at least in question-answering and vision). I imagined the purple part was supposed to be zeros, meaning it would help the edge of the green data but not nearly as well as with real data (and could be done outside of the pipeline, which would be much harder to do with real data, since once you do CTC you lose all that green/purple information).

  • padding: fill in with zeros
  • stride: make overlapping data within samples

This approach could definitely work and be integrated into the pipeline (since the green/purple information would otherwise be lost after CTC decoding).

  • It needs to be CTC-only: it's unlikely to be sound for generative audio models, since there's no 1-1 mapping of audio samples to IDs.
  • There are actually probably sane defaults for the stride length, since a word should rarely exceed 30s, and chunks of that size do fit on most user-grade GPUs. We could also do something like question-answering, where the stride is defined by default as 1/4 of max_length, if my memory serves correctly.

All in all, the current PR approach should work and can be done like it is right now (with the utils for VAD and the like kept external in the other PR). We just need to enable ChunkPipeline here to make it work properly with batch_size (which has since been merged).
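For intuition, here is a small numpy sketch of that stride idea (names are illustrative, not the pipeline's actual internals): each chunk borrows real neighbouring audio on both sides, and the borrowed parts are dropped again from the model outputs.

import numpy as np

def chunk_with_stride(audio, chunk_len, stride_left, stride_right):
    # Yields (chunk, left, right), where left/right count the borrowed samples.
    audio = np.asarray(audio)
    for start in range(0, len(audio), chunk_len):
        begin = max(0, start - stride_left)
        end = min(len(audio), start + chunk_len + stride_right)
        left = start - begin
        right = end - min(len(audio), start + chunk_len)
        yield audio[begin:end], left, right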

@patrickvonplaten (Contributor)
(Quoting @Narsil's clarification above in full.)

Thanks a lot for summarizing everything here! After the call I very much agree that stride_length_in_sec is a better name indeed

@Narsil force-pushed the add-naive-asr-chunking branch 5 times, most recently from 592536e to afc5e46 on December 28, 2021.
chunk_length_ms (`int`, *optional*, defaults to 0):
The input length for each chunk. If `0` then chunking is disabled (default). Only available for CTC
models.
stride_length_ms (`int`, *optional*, defaults to `chunk_length_ms / 6`):
Contributor

I would prefer here stride_length_s because:

  • sampling rate is in the format samples/s
  • we mostly deal with seconds for other speech parameters:
    max_duration_in_seconds: Optional[float] = field(
  • I think the value should be of type float no matter what. It's just not an "integer" value for me logically, and it might very well be that we have people who would like a stride of exactly 1024 samples. If this input could only be an integer, that would make it difficult to do so. I would favor a float input here and later throw an error if stride_length_s * sampling_rate is not an integer. Same for chunk_length_ms.

Contributor

we mostly deal with seconds for other speech parameters:

Hammer argument, consistency wins it!

I would favor an input of float here and later throw an error if stride_length_s * sampling_rate is not an integer.

We can't ever do that; there's no way floats can reliably fall onto integers. You have to purposefully round and cast to int to get something reliable. No biggie, but still. Check the good old 0.1 + 0.2 = 0.30000000000000004 in a Python console; similar issues arise all the time when converting to integers.

Same for chunk_length_ms
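A minimal sketch of that point: converting seconds to samples should round explicitly rather than rely on the product being an exact integer.

sampling_rate = 16000

stride_length_s = 0.1 + 0.2          # conceptually 0.3 s
print(stride_length_s == 0.3)        # False: 0.30000000000000004

# Checking (stride_length_s * sampling_rate).is_integer() would therefore reject
# perfectly reasonable values; rounding is the reliable option.
stride_samples = round(stride_length_s * sampling_rate)
print(stride_samples)                # 4800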

chunk_length_ms (`int`, *optional*, defaults to 0):
The input length for each chunk. If `0` then chunking is disabled (default). Only available for CTC
models.
stride_length_ms (`int`, *optional*, defaults to `chunk_length_ms / 6`):
@patrickvonplaten (Contributor) commented Dec 28, 2021

Also, I think it would be good to still have "left" and "right" options for the stride: later, when we want to use pipelines in "real" streaming mode (e.g. only buffers are sent to the pipeline), we would ideally want a long left stride (since this is past speech) but a shorter right stride, since there we have a trade-off between model performance and "real-time" display.

Contributor

Are you ok to leave that for a subsequent PR, when the need arises?

For streaming buffers, like a microphone, chunking would be done elsewhere, not here, since you don't want to wait for large chunks to start inference; 50-100ms tops in terms of latency to get that "instantaneous" feel. We may still want to expose the striding mechanism to get those "clean" reconstitutions, but the chunking IMO doesn't belong here. For files or any static data there's no real overhead in terms of "real-time".

Personally, I would follow the torch way here, meaning a single scalar argument means both sides are the same, and a tuple would mean (left, right). That means no extra argument (simpler), yet it's still doable when needed.

However I would leave that for a subsequent PR since I don't think it's that important right now (should be quite simple anyway).
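A hedged sketch of that torch-style convention (purely illustrative; the final argument handling may differ):

def resolve_stride(stride_length_s):
    # Accept either a single number (same stride on both sides) or a (left, right) pair.
    if isinstance(stride_length_s, (tuple, list)):
        stride_left_s, stride_right_s = stride_length_s
    else:
        stride_left_s = stride_right_s = stride_length_s
    return stride_left_s, stride_right_s

print(resolve_stride(1.0))         # (1.0, 1.0)
print(resolve_stride((2.0, 0.5)))  # (2.0, 0.5)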

@patrickvonplaten (Contributor) commented Dec 29, 2021

Definitely happy to leave that for a future PR, I'm just a bit worried here about backward compatibility.
I agree that for the use case enabled by the PR (offline chunked ASR), it's slightly better to have just a single "stride_length_ms" argument because it's not necessary to have different strides and fewer arguments is better for the user.

However, if we do need different left and right strides, then we would have to deprecate this argument in favor of stride_length_left_s and stride_length_right_s, no? Which is what I'm a bit worried about.

We would only need this once we implement the "online" streaming logic, however. Here is how I imagine it working:

    1. audio -> microphone -> buffer until chunk_length_s + stride_length_right_s is reached, e.g. 0.2s + 0.05s -> forward array to pipeline, e.g. buffer[0:4000] (at 16kHz) -> get output logits (or string) -> cut 0.05s of output logits (or string) -> store the cut output logits or ids and display the decoded text -> empty the buffer except for the right padding (0.05s)
    2. audio -> microphone buffer until chunk_length_s + stride_length_s is reached, e.g. 0.2s + 0.05s -> concatenate leftover buffer with new buffer -> forward array to pipeline (now including stride_length_left_s, e.g. 1s), so this time e.g. buffer[0:7200] -> cut 50ms of output logits (or string) -> overwrite stored output ids -> display
    3. ....
    4. ....
    5. This process would be repeated (the forwarded buffer keeps starting from [0: ...]) until left stride + chunk size + right stride > the input received so far. When this happens we would forward something like [1000:20000], not starting from [0: ...], to the pipeline, and then for the output we would have to retrieve the last stored output logits/ids, concatenate them with the current model output logits/ids, and display the decoded concatenated outputs.
    6. ....
    7. .... at some point the output logits or output ids get too large, so we'll need to completely throw away some stored logits, but I think this is only needed at sizes of a minute or so, which is probably too much to read anyway, so it should be rather simple to just not display some words anymore.

Now the interesting question is (and I'm open to both options here) whether the "cutting" logic for both the input (in case the input is larger than left stride + chunk size + right stride) and the output logits/ids (or strings, if it's not possible or too ugly to skip decoding in the pipeline) should or should not happen inside the pipeline.

a.) If it's not the responsibility of the pipeline, then IMO the pipeline should just receive an input (the cut buffer) and ideally output the logits or the output string, and something else is responsible for correctly cutting the logits.
b.) If however we think it is the responsibility of the pipeline to cut both the input buffer and the output logits, then I think we do need both left and right stride arguments.

If we want to go for b.), we would e.g. pass the arguments chunk_length_s=0.2 and stride_length_right_s=0.05 in step 1. above, and pass the arguments chunk_length_s=0.2, stride_length_right_s=0.05 and stride_length_left_s=1 for step 2.

Having written down all of this, I actually think a.) might make more sense here, since there does seem to be a lot of logic which shouldn't necessarily be in transformers -> so I think I'm good with just having a single stride_length_s argument.

Curious to hear your thoughts here @Narsil @anton-l

@patrickvonplaten (Contributor) left a comment

Thanks a lot for working on this PR - the logic looks good to me, and I think this PR nicely enables transcribing huge audio files in offline mode.

I left two comments above which I think we could iterate on :-)

Once this feature works nicely, I think we could think in a next step about how to extend this to the streaming case, where we don't have the full audio file from the very beginning but nevertheless would like to make use of stride to improve performance (in which case I think it would be nice to have more left stride than right stride).

@Narsil (Contributor) commented Dec 29, 2021

@patrickvonplaten Could you review #14250 too? (It needs a rebase, but I think it would be nice if both PRs land roughly at the same time.)

@patrickvonplaten (Contributor) left a comment

Happy to merge this as it is. @Narsil - wanna go ahead, and then we look at #14250?

@Narsil (Contributor) commented Jan 3, 2022

Let's do that.

@Narsil merged commit 38f95d1 into huggingface:master on Jan 3, 2022
@patrickvonplaten (Contributor)

Would be nice if we could make a short blog post showcasing how this feature works! Think we'll just need an hour-long audio clip or so (maybe a clean speech from someone, e.g. a US president) and then a couple of lines of code :-)

@patrickvonplaten (Contributor)

Probably don't even need the blog post. Think an entry here: https://discuss.huggingface.co/ would be enough

# Map the stride boundaries from sample space to token (logit-frame) space
# proportionally, then keep only the non-strided middle part.
token_n = tokens.shape[-1]
left_token = int(left / input_n * token_n)
right_token = int((input_n - right) / input_n * token_n) + 1
tokens = tokens[:, left_token:right_token]
Contributor

this is the cleanest way of doing it I think @Narsil

Contributor

If we can avoid it I would rather not copy the tensors.

stevhliu pushed a commit to stevhliu/transformers that referenced this pull request Jan 6, 2022
* Naive ASR chunking

* Fixing batching for ASR.

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Albertobegue pushed a commit to Albertobegue/transformers that referenced this pull request Jan 27, 2022
* Naive ASR chunking

* Fixing batching for ASR.

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>