
ConstantLengthDataset does not return the right length #943

Closed
edbeeching opened this issue Nov 2, 2023 · 7 comments

@edbeeching
Collaborator

We noticed that when training with longer sequence lengths and packing=True, the estimated steps for an epoch can be far lower than expected. cc @lewtun
Example:
[training-run screenshot omitted]
The root cause appears to be how the length of the ConstantLengthDataset is calculated: currently, it returns the length of the unpacked dataset.

# __len__ currently returns the number of raw (unpacked) rows, not the packed count:
return len(self.dataset)

Minimal example:

from datasets import load_dataset
from transformers import AutoTokenizer
from trl.trainer.utils import ConstantLengthDataset

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")

packed_dataset = ConstantLengthDataset(
    tokenizer=tokenizer,
    dataset=dataset,
    dataset_text_field="text",
    num_of_sequences=10000,
)

length = len(packed_dataset)  # length reported by __len__

actual_length = 0

for example in packed_dataset:  # count the packed sequences actually yielded
    actual_length += 1
    
print(f"{length=}")
print(f"{actual_length=}")
print(f"{actual_length==length=}")

Output:

length=9846
actual_length=4283
actual_length==length=False

Potential problems this causes

This may lead to the warmup steps and other step-dependent options (linear and cosine schedules) being calculated incorrectly.
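
To make the effect concrete, here is a rough back-of-the-envelope illustration using the reported and actual lengths from the example above; the batch size and warmup ratio are made-up placeholders:

# Made-up training settings; the two lengths are the values printed above.
reported_len = 9846   # len(packed_dataset), the unpacked length
actual_len = 4283     # packed sequences actually yielded
batch_size = 8        # assumed effective batch size
warmup_ratio = 0.1    # assumed

scheduled_steps = reported_len // batch_size   # what the scheduler is laid out over
executed_steps = actual_len // batch_size      # what actually runs in one epoch
warmup_steps = int(warmup_ratio * scheduled_steps)

print(f"{scheduled_steps=}, {executed_steps=}, {warmup_steps=}")
# scheduled_steps=1230, executed_steps=535, warmup_steps=123
# A linear/cosine decay planned over 1230 steps stops after only 535,
# so the learning rate never gets close to finishing its decay.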

Potential Solution

Perform the packing upfront in the __init__ method and return the length of the packed examples, then modify the __iter__ method to yield the precomputed packed sequences. This may cause issues with large datasets, small buffers, and the infinite option.
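
A rough sketch of that idea, for illustration only (PrePackedDataset and its internals are hypothetical, not the actual trl implementation, and it ignores buffering and the infinite option):

from torch.utils.data import Dataset


class PrePackedDataset(Dataset):
    """Hypothetical sketch: pack everything in __init__ so __len__ is exact."""

    def __init__(self, tokenizer, texts, seq_length=1024):
        eos = tokenizer.eos_token_id
        # Tokenize all texts, join them with EOS, then slice into fixed-size chunks.
        all_ids = []
        for text in texts:
            all_ids.extend(tokenizer(text)["input_ids"] + [eos])
        self.examples = [
            all_ids[i : i + seq_length]
            for i in range(0, len(all_ids) - seq_length + 1, seq_length)
        ]

    def __len__(self):
        # Exact, because the packing already happened in __init__.
        return len(self.examples)

    def __getitem__(self, idx):
        ids = self.examples[idx]
        return {"input_ids": ids, "labels": ids}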

@vwxyzjn
Contributor

vwxyzjn commented Nov 8, 2023

Thanks @edbeeching for the report! I think this is really a seq_length problem:

from datasets import load_dataset
from transformers import AutoTokenizer
from trl.trainer.utils import ConstantLengthDataset

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")
print(f"{len(dataset)=}")    

for seq_length in [512, 1024, 2048]:
    packed_dataset = ConstantLengthDataset(
        tokenizer=tokenizer,
        dataset=dataset,
        dataset_text_field="text",
        # num_of_sequences=10000,
        seq_length=seq_length,
    )
    actual_length = 0
    for example in packed_dataset:
        actual_length += 1
    print(f"{seq_length=}, {actual_length=}")

Output:

len(dataset)=9846
seq_length=512, actual_length=8563
seq_length=1024, actual_length=4281
seq_length=2048, actual_length=2141

The larger the seq_length, the more examples we can "pack" into a single sequence, and the smaller packed_dataset's actual length: 8563 × 512 ≈ 4281 × 1024 ≈ 2141 × 2048 ≈ 4.4M tokens in total.

@lvwerra
Member

lvwerra commented Nov 9, 2023

Yes, since we pack the dataset on the fly, we don't know in advance how many samples it will yield. We could add a precompute flag and run the processing in advance. This should be fine for smaller datasets for sure.

@tcapelle
Contributor

tcapelle commented Nov 14, 2023

I am also curious how the Zephyr fine-tune worked out, as it uses the cosine scheduler and packing, so the total number of training steps is not known in advance. In my experiments, when doing this, you end up with a wrong cosine schedule that looks like this:
[plot of the distorted cosine learning-rate schedule omitted]

@maneandrea
Contributor

Probably not terribly useful, but I have encountered the same problem in the case where the input dataset consists of strings longer than the desired seq_length (so that more examples are yielded by a single entry in self.dataset). In this case I get a lot of warnings of the form:

UserWarning: Length of IterableDataset <trl.trainer.utils.ConstantLengthDataset object at 0x7f873937a950> was reported to be 1 (when accessing len(dataloader)), but 3 samples have been fetched.
UserWarning: Length of IterableDataset <trl.trainer.utils.ConstantLengthDataset object at 0x7f873937a950> was reported to be 1 (when accessing len(dataloader)), but 4 samples have been fetched.
...

(the numbers are so low because this was a very small example for debugging)
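
For what it's worth, a toy way to reproduce that direction of the mismatch (the exact counts depend on the tokenizer, so the numbers are only indicative):

from datasets import Dataset
from transformers import AutoTokenizer
from trl.trainer.utils import ConstantLengthDataset

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
# A single entry that is much longer than seq_length, so it packs into several sequences.
tiny_dataset = Dataset.from_dict({"text": ["hello world " * 2000]})

packed = ConstantLengthDataset(
    tokenizer=tokenizer,
    dataset=tiny_dataset,
    dataset_text_field="text",
    seq_length=512,
)

print(f"{len(packed)=}")               # reports 1, the unpacked length
print(f"{sum(1 for _ in packed)=}")    # actually yields several packed sequences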

@alvarobartt
Member

alvarobartt commented Dec 21, 2023

Sharing this here in case it's useful for anyone 🤗 (forgot to share before)

It calculates the number of steps that correspond to one epoch when using ConstantLengthDataset. It's feasible if the dataset is not huge (otherwise it can be inefficient), but at least it gets the job done of computing the required steps before triggering the training, since using epoch: 3 may mess things up, as @tcapelle pointed out earlier in this thread.

https://gist.github.com/alvarobartt/d08888dd2660b6763421dd6b1142127c
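
The core idea of the gist, roughly (the training settings below are made-up placeholders; see the gist itself for the full details):

from datasets import load_dataset
from transformers import AutoTokenizer
from trl.trainer.utils import ConstantLengthDataset

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")

packed = ConstantLengthDataset(
    tokenizer=tokenizer,
    dataset=dataset,
    dataset_text_field="text",
    seq_length=2048,
)

# Iterate once before training to count the packed sequences, then derive max_steps.
packed_examples = sum(1 for _ in packed)
effective_batch_size = 16   # placeholder: per-device batch size * grad accumulation * world size
num_epochs = 3              # placeholder
steps_per_epoch = packed_examples // effective_batch_size
max_steps = steps_per_epoch * num_epochs
print(f"{packed_examples=}, {steps_per_epoch=}, {max_steps=}")
# Pass max_steps to TrainingArguments instead of relying on the epoch count.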

@lvwerra
Member

lvwerra commented Dec 21, 2023

I think these issues should have been fixed in #979: the packed dataset is now precomputed, and the length/epoch should match what is provided, at the cost of a small overhead at the beginning to process the dataset.


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
