
Support skip iteration flag #177

Merged
stas00 merged 42 commits into main from skip-iterations
Nov 17, 2021

Conversation

@jaketae
Member

@jaketae jaketae commented Nov 3, 2021

This PR resolves #175.

  • Support relevant argument flag
  • Add continue logic to skip iterations
  • Update counter and number of consumed samples/tokens
  • Add tests
  • Logging
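For illustration, parsing such a flag into skip ranges might look like the sketch below. The argument name follows the --skip-train-iteration-range flag that appears later in this thread; the parsing helper itself is hypothetical, not the PR's code.

```python
import argparse
from collections import deque

def parse_skip_ranges(values):
    """Parse "start-end" strings into a sorted deque of [start, end] pairs."""
    ranges = []
    for value in values:
        start, end = (int(x) for x in value.split("-"))
        ranges.append([start, end])
    return deque(sorted(ranges))

parser = argparse.ArgumentParser()
parser.add_argument("--skip-train-iteration-range", nargs="+", default=None,
                    help='iteration ranges to skip, e.g. "2-4 7-11"')
args = parser.parse_args(["--skip-train-iteration-range", "7-11", "2-4"])
skip = parse_skip_ranges(args.skip_train_iteration_range)
print(list(skip))  # [[2, 4], [7, 11]]
```

A deque is convenient here because the training loop can popleft() each range as it is consumed.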

@jaketae
Member Author

jaketae commented Nov 4, 2021

@stas00 I have a few questions I'd like to discuss.

  1. I realized that Meg-DS keeps track of skipped_iter, which is also reported in stdout. Should we increment this when we skip iterations in the specified range? To me, skipped_iter should be distinct from what we're doing here, since we are deliberately skipping according to user demand, whereas skipped_iter is incremented if something goes wrong (i.e. if update_successful is False).

  2. I created (copied and pasted) code from other tests to check that providing arguments to the skip flag doesn't break. But this is a very minimal test. What are your thoughts on having something be logged out to stdout (e.g. skipped iterations from x to y), then call self.assertIn("skipped iterations from", cs.out)? Or perhaps would there be a better approach? I want to make sure that appropriate counters are updated as needed.

Thank you!


P.S. Very happy that CI is working (though it took 10+ minutes to finish)!

@jaketae jaketae marked this pull request as ready for review November 4, 2021 19:48
@stas00
Contributor

stas00 commented Nov 4, 2021

P.S. Very happy that CI is working (though it took 10+ minutes to finish)!

Yes! Other than the really slow instance booting, the main overhead is this: it takes about 5 minutes to compile the cuda kernels, which happens on every CI run. If you want to try to speed it up see:
#174

And CI only works on non-forked branches.

Contributor

@stas00 stas00 left a comment


Actually, let's look at what it does for the logic of skipped iterations - when it doesn't step, does it update any counters?

that would probably be the best guide for us.

@stas00
Contributor

stas00 commented Nov 4, 2021

@stas00 I have a few questions I'd like to discuss.

  1. I realized that Meg-DS keeps track of skipped_iter, which is also reported in stdout. Should we increment this when we skip iterations in the specified range? To me, skipped_iter should be distinct from what we're doing here, since we are deliberately skipping according to user demand, whereas skipped_iter is incremented if something goes wrong (i.e. if update_successful is False).

This is a good consideration. Here is my take on it:

this counter has a different logical purpose. If you were to log every 100 iterations, it tells you whether any iterations were skipped due to internal logic, so you know something was off - e.g. the loss scale was too big. I.e. this is the framework reporting to the user that it did something the user needs to know about.

With skip_train_iteration_range, the skipping is prescribed by the operator, and thus this is a different feature, albeit one that also skips iterations.

  2. I created (copied and pasted) code from other tests to check that providing arguments to the skip flag doesn't break. But this is a very minimal test. What are your thoughts on having something be logged out to stdout (e.g. skipped iterations from x to y), then call self.assertIn("skipped iterations from", cs.out)? Or perhaps would there be a better approach? I want to make sure that appropriate counters are updated as needed.

The logging is part of the spec. So absolutely yes.

But this is not sufficient for testing that it actually works. That will only test that the right branch of code was invoked.

As I wrote in the spec you actually want to test that the correct iteration x/ y is logged and iteration z/ y is not logged because z is on the skipped list. Does it make sense? This is the real test.
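A minimal sketch of that assertion pattern, using a stand-in training loop and stdout capture - the stub and capture mechanism here are illustrative, not the repo's actual test harness:

```python
import io
from contextlib import redirect_stdout

def run_training_stub(total_iters, skip_ranges):
    """Stand-in for a training run: log each iteration unless it is skipped."""
    for it in range(1, total_iters + 1):
        if any(start <= it <= end for start, end in skip_ranges):
            continue
        print(f"iteration {it}/ {total_iters}")

buf = io.StringIO()
with redirect_stdout(buf):
    run_training_stub(10, [(3, 5)])
out = buf.getvalue()

# the real test: a non-skipped iteration is logged, a skipped one is not
assert "iteration 2/ 10" in out
assert "iteration 4/ 10" not in out
```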

@jaketae
Member Author

jaketae commented Nov 5, 2021

@stas00 Some updates on the changes I've made since yesterday:

  1. Merge intervals: Although this is probably an edge case, I wrote some code that merges overlapping intervals (e.g., 1-5 2-7 produces [[1, 7]] instead of two intervals). I assume this will be useful when we get lazy and want to combine bad data intervals found in different experiments.
  2. Flush intervals: The edge case you mentioned. I used built-in binary search to find the earliest relevant interval, and "flush" everything before it as they are irrelevant. This happens before the while loop.
  3. Check condition: Instead of iteration <= end, I added start <= iteration <= end for more robust checking.

I will spend more time looking at logging and testing moving forward. Thank you!
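The interval-merge step in item 1 could look something like this standalone sketch (not the PR's actual implementation):

```python
def merge_intervals(intervals):
    """Merge overlapping [start, end] intervals, e.g. 1-5 and 2-7 -> [1, 7]."""
    merged = []
    for start, end in sorted(intervals):
        if merged and start <= merged[-1][1]:
            # overlaps the previous interval: extend it in place
            merged[-1][1] = max(merged[-1][1], end)
        else:
            merged.append([start, end])
    return merged

print(merge_intervals([[1, 5], [2, 7]]))  # [[1, 7]]
```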

@jaketae
Member Author

jaketae commented Nov 5, 2021

As I wrote in the spec you actually want to test that the correct iteration x/ y is logged and iteration z/ y is not logged because z is on the skipped list. Does it make sense? This is the real test.

@stas00 I've added logging and testing based on the specs and your earlier comment. I've also removed some repeats by using @parameterized.expand. Some considerations:

  1. In the test, skip iterations are hard-coded, and so are the asserts. Is this okay?
  2. Should validation occur within skip intervals? Say we validate every 5 iterations, and we skip 7-11. On the 10th iteration (which will be skipped), should we run validation as scheduled?
  3. The intricacy, as you noted in the spec, is to make sure that consumed_train_samples (and consumed_valid_samples) are updated properly. I don't think the current test accounts for this. Is there something we can check to make sure that these values have been updated as expected?

Thank you!

@stas00
Contributor

stas00 commented Nov 6, 2021

As I wrote in the spec you actually want to test that the correct iteration x/ y is logged and iteration z/ y is not logged because z is on the skipped list. Does it make sense? This is the real test.

@stas00 I've added logging and testing based on the specs and your earlier comment. I've also removed some repeats by using @parameterized.expand. Some considerations:

  1. In the test, skip iterations are hard-coded, and so are the asserts. Is this okay?

I'm not 100% sure what you are asking about - the exact format match rather than some sort of regex? No, that's fine.

  2. Should validation occur within skip intervals? Say we validate every 5 iterations, and we skip 7-11. On the 10th iteration (which will be skipped), should we run validation as scheduled?

I'd say let's not worry about it. This is not a normal functionality.

  3. The intricacy, as you noted in the spec, is to make sure that consumed_train_samples (and consumed_valid_samples) are updated properly. I don't think the current test accounts for this. Is there something we can check to make sure that these values have been updated as expected?

Of course, it's in the logs - e.g. consumed samples == consumed_train_samples - don't worry about valid.

 iteration     5285/  159576 | consumed samples:       135552 | elapsed time per iteration (ms): 21426.5 | learning rate: 3.752E-05 | global batch size:    48 | lm loss: 4.109760E+00 | loss scale: 131072.0 | grad norm: 76870.639 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
time (ms)
 iteration     5286/  159576 | consumed samples:       135600 | elapsed time per iteration (ms): 20871.4 | learning rate: 3.754E-05 | global batch size:    48 | lm loss: 4.163858E+00 | loss scale: 131072.0 | grad norm: 71007.567 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
time (ms)
 iteration     5287/  159576 | consumed samples:       135648 | elapsed time per iteration (ms): 20712.1 | learning rate: 3.755E-05 | global batch size:    48 | lm loss: 4.236997E+00 | loss scale: 131072.0 | grad norm: 69662.875 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
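For a programmatic check, the consumed-samples figure can be pulled out of such a log line with a regex - a sketch with a made-up helper name, not code from this repo:

```python
import re

LOG_LINE = (" iteration     5286/  159576 | consumed samples:       135600 | "
            "elapsed time per iteration (ms): 20871.4 |")

def consumed_samples(line):
    """Extract the 'consumed samples' counter from a Megatron-style log line."""
    m = re.search(r"consumed samples:\s+(\d+)", line)
    return int(m.group(1)) if m else None

print(consumed_samples(LOG_LINE))  # 135600
```

A test could then assert that this counter advances by the expected global batch size across consecutive logged iterations.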

@stas00
Contributor

stas00 commented Nov 11, 2021

Oh, I know what the problem is - args.curriculum_seqlen doesn't change in the skip code. ok figured it out.

@stas00
Contributor

stas00 commented Nov 11, 2021

OK, succeeded.

I left the debug code for now if we choose to change the logic before we merge this.

Please think about this one:

#177 (comment)

As it's critical we get it right.

@stas00
Contributor

stas00 commented Nov 11, 2021

I'm trying your PR on a live server to try to salvage the CL training.

We have another bug here, it fails if we skip the first iteration, since args.curriculum_seqlen is not yet set!

I fixed it with:

                if args.curriculum_learning and hasattr(args, "curriculum_seqlen"):
                    args.consumed_train_tokens += new_samples * args.curriculum_seqlen

new_samples = mpu.get_data_parallel_world_size() * \
              args.micro_batch_size * \
              get_num_microbatches()
args.consumed_train_samples += new_samples
Contributor

@conglongli conglongli Nov 11, 2021


@jaketae @stas00 I think we should not accumulate args.consumed_train_samples, args.consumed_train_tokens, iteration, and args.iteration here, for two reasons: 1) If we skip the data but still count the steps, samples, and tokens, it could lead to undesirable behavior for techniques that rely on these stats. For example, curriculum learning relies on the step count to calculate the current seqlen. 2) The DeepSpeed engine itself keeps counting the global step when the train step is called. So if we only increment the step on the user side without calling train step on the ds engine, it will generate a global step mismatch, which is also an issue.

Contributor


So this is the incarnation that I'm running at the moment on the CL experiment on JZ:

            # pop the next skip range off the deque
            start, end = args.skip_train_iteration_range.popleft()
            print_rank_0(f"RANGE: {start} {end}")
            print_rank_0(f"iteration {args.iteration}")
            print_rank_0(f"Skipped iterations {start} {end} due to --skip-iterations flag.")
            iteration_for_skipping = args.iteration
            # fast-forward the data iterator past the skipped range
            # without updating any counters
            while iteration_for_skipping + 1 <= end:
                try:
                    _ = next(train_data_iterator)
                except TypeError:
                    # train_data_iterator is None on ranks that don't load data
                    pass
                iteration_for_skipping += 1
            continue

@jaketae
Member Author

jaketae commented Nov 12, 2021

Reposting some potentially important conversations from Slack for documentation/future reference.

@stas00
Contributor

stas00 commented Nov 12, 2021

I backported this feature into the tr8-104B branch. 7a0158e

@jaketae
Member Author

jaketae commented Nov 14, 2021

@stas00 I see in the backported commit that you implemented the additional counter discussed in Slack. I'm wondering how it plays with checkpointing, etc. Let me know if there's something I can contribute (i.e. perhaps bringing in the changes in your backported commit into this branch before the final merge)!

@stas00
Contributor

stas00 commented Nov 14, 2021

Since the new version silently skips the data, that's all there is to it. No need to do anything else.

But the problem is that we don't know when the data was skipped. So if down the road we want to extract a sample, it will be incorrect.

So we either want to keep a different counter of the real number of samples drawn (which would require a lot of extra work) or adjust the sample_idxs_to_text.py script to support the --skip-train-iteration-range flag, which is probably simpler.

So perhaps let's do the latter?

Of course, our elaborate test will have to be cut down to something much simpler now.

@jaketae
Member Author

jaketae commented Nov 15, 2021

@stas00 I agree that the easiest way to examine samples would be to modify sample_idxs_to_text.py. I have some questions.

  1. Does CL affect anything at all? I know that seq_len progressively changes in the CL setup and was wondering if the current script accounts for this.
  2. Let's say the user wants to see a dump of data from iterations 10 to 12. But we had skipped iterations 2 to 4 in training. Does that mean the actual index that we want to retrieve is from 13 to 15?
  3. Can we assume that the skip iterations flag will only contain relevant intervals? That is, say the user wants iterations from 8 to 10. Is it possible that the flag contains irrelevant skips, such as 100 to 110?
  4. Would it make sense to work on modifying this script in a separate PR?

Thank you!

@stas00
Contributor

stas00 commented Nov 15, 2021

@stas00 I agree that the easiest way to examine samples would be to modify sample_idxs_to_text.py. I have some questions.

  1. Does CL affect anything at all? I know that seq_len progressively changes in the CL setup and was wondering if the current script accounts for this.

It has its own counters, so it doesn't affect anything, not even the general seq_len counter.

  2. Let's say the user wants to see a dump of data from iterations 10 to 12. But we had skipped iterations 2 to 4 in training. Does that mean the actual index that we want to retrieve is from 13 to 15?

Exactly
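That offset logic (requested iterations 10-12 with 2-4 skipped map to actual indices 13-15) could be sketched as follows - a hypothetical helper, not the actual sample_idxs_to_text.py code:

```python
def remap_iteration(requested, skip_ranges):
    """Shift a requested iteration index past any skipped ranges before it."""
    actual = requested
    for start, end in sorted(skip_ranges):
        if start <= actual:
            # everything in [start, end] was skipped, so shift past it
            actual += end - start + 1
    return actual

print([remap_iteration(i, [(2, 4)]) for i in (10, 11, 12)])  # [13, 14, 15]
```

Processing ranges in sorted order means each shift is applied cumulatively, so multiple skip ranges before the requested index compose correctly.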

  3. Can we assume that the skip iterations flag will only contain relevant intervals? That is, say the user wants iterations from 8 to 10. Is it possible that the flag contains irrelevant skips, such as 100 to 110?

Let's start with making simple assumptions and not worry about edge cases here.

  4. Would it make sense to work on modifying this script in a separate PR?

Yes please.

Bottom line let's finish this PR with:

  1. Let's overwrite the contents of the original while loop with the silent skip as in 7a0158e.

  2. Adjust the tests:

  • remove parameterized - as now we have just the base case
  • remove the last part of the test - we can't really test anything qualitatively anymore other than that it just runs. Well, perhaps we can test that we still get the full normal range of iterations logged - i.e., 1, 2, 3, ..., 10 - and none are missing.

  3. Create an issue to adjust sample_idxs_to_text.py to support the skipping using the exact same API (and code).

  4. Then one of us can implement that issue.

@jaketae
Member Author

jaketae commented Nov 16, 2021

@stas00 Pushed the updates + opened a new issue #189. Please feel free to edit or comment as you see fit. Thank you!


There seems to be an issue with deepspeed dependency. Traceback from GitHub Actions:

fatal: remote error: upload-pack: not our ref a105ae87388acc20630ea4c32930eb0d2b20d06e
fatal: the remote end hung up unexpectedly
Fetched in submodule path 'DeepSpeedExamples', but it did not contain a105ae87388acc20630ea4c32930eb0d2b20d06e. Direct fetching of that commit failed.

...
truncated
...

ERROR: Could not find a version that satisfies the requirement deepspeed (unavailable) (from versions: 0.3.1.dev1, 0.3.1.dev2, 0.3.1.dev3, 0.3.1.dev4, 0.3.1.dev5, 0.3.1.dev6, 0.3.1.dev7, 0.3.1.dev8, 0.3.1, 0.3.2, 0.3.3, 0.3.4, 0.3.5, 0.3.6, 0.3.7, 0.3.8, 0.3.9, 0.3.10, 0.3.11, 0.3.12, 0.3.13, 0.3.14, 0.3.15, 0.3.16, 0.4.0, 0.4.1, 0.4.2, 0.4.3, 0.4.4, 0.4.5, 0.5.0, 0.5.1, 0.5.2, 0.5.3, 0.5.4, 0.5.5, 0.5.6)
ERROR: No matching distribution found for deepspeed (unavailable)
Error: Process completed with exit code 1.

Is this something that just happens from time to time?

@stas00
Contributor

stas00 commented Nov 17, 2021

This is strange indeed. I will ask Jeff.

@stas00
Contributor

stas00 commented Nov 17, 2021

Jeff fixed the DSE issue, but we now discovered a new issue in deepspeed@master - should be fixed soon.

Waiting for the merge of deepspeedai/DeepSpeed#1569

@stas00 stas00 merged commit 106a9a6 into main Nov 17, 2021
@stas00 stas00 deleted the skip-iterations branch November 17, 2021 17:29
adammoody pushed a commit to adammoody/Megatron-DeepSpeed that referenced this pull request Aug 16, 2023


Development

Successfully merging this pull request may close these issues.

[new feature] skip iterations X-Y
