
Enhance TextDataset for LM training #2202

Closed
r0mainK opened this issue Apr 8, 2021 · 3 comments
Labels
wontfix This will not be worked on

Comments

@r0mainK
Contributor

r0mainK commented Apr 8, 2021

Please add the appropriate label to this ticket: enhancement.

Is your feature/enhancement request related to a problem? Please describe.

Currently the TextDataset class in the language_model_trainer.py file is suboptimal and quite slow. This is due to the use of for loops instead of comprehensions. Furthermore, the way random case flipping is done is error-prone for unusual characters, an issue the code silently sweeps under the rug. Specifically, if random casing changes the length of a sentence, this is not caught, due to the if token >= tokens: break statement, and results in tokens being lost. Since we do it after expanding the vocabulary, it can also lead to tokens being UNKed for no reason, or not being added at all.
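To make the length-change hazard concrete, here is a small illustration (not project code): some Unicode characters map to a different number of characters under case conversion, so flipping case after tokenization can silently shift or drop tokens.

```python
# Illustrative only: Unicode case mapping is not length-preserving.
s = "straße"
upper = s.upper()  # 'ß' uppercases to 'SS'

print(len(s))      # 6
print(len(upper))  # 7 -- one character longer after the case change
```

Because the uppercased string is longer, any token indices computed before the case flip no longer line up with the flipped text.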

Describe the solution you'd like

I would like to refactor the code to make it faster and remove the casing errors mentioned above. Specifically, this would entail applying the case changes directly when reading the text, and then doing everything with comprehensions. Unless there is a reason to do otherwise, I would also like to remove the unused (to my knowledge) tokenize function, and merge the __getitem__ and charsplit methods, as I don't quite get why they are split.
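A minimal sketch of the kind of loop-to-comprehension refactor meant here (all names are illustrative, not the actual Flair code):

```python
# Hypothetical example: map characters to ids in one comprehension
# instead of an index-tracking for loop with a break statement.
char_to_ids = {b"a": 0, b"b": 1}  # toy vocabulary
unk_id = 2                        # id for out-of-vocabulary characters

line = "abX"
ids = [char_to_ids.get(ch.encode("utf-8"), unk_id) for ch in line]

print(ids)  # [0, 1, 2] -- 'X' is not in the vocabulary, so it maps to unk_id
```

The comprehension always consumes the whole line, so a line whose length changed under case flipping cannot lose characters the way an index-bounded loop can.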

Additional context

Although I did not directly use your trainer, as it was a bit too much for my needs, I initially was using a close replica of the TextDataset. Modifying it as I described reduced loading time by an order of magnitude, simplified the code, and removed all errors. For reference, here is a snippet of what my __getitem__ currently looks like:

        # read all lines, applying the random case change up-front
        with gzip.open(self.files[split_id]) as fin:
            lines = list(map(self.apply_random_case_change, jsonlines.Reader(fin)))
        if self.shuffle_lines:
            random.shuffle(lines)
        # map each character to its id, appending a delimiter after every line
        lines = (
            [self.char_to_ids.get(char.encode("utf-8"), self.unk_id) for char in line]
            + [self.delimiter_id]
            for line in lines
        )
        ids = torch.tensor([char_id for line in lines for char_id in line], dtype=torch.uint8)
        # a backward LM consumes the sequence in reverse
        if not self.is_forward_lm:
            ids = ids.flip(0)
@alanakbik
Collaborator

Hello @r0mainK looks very interesting. If you'd like to prepare a PR for this we'd appreciate it. Would be great to speed up training the LM!

@r0mainK
Contributor Author

r0mainK commented Apr 8, 2021

Hey @alanakbik I just prepared the PR for this, feel free to review whenever :)

@stale

stale bot commented Aug 6, 2021

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the wontfix This will not be worked on label Aug 6, 2021
@stale stale bot closed this as completed Aug 13, 2021