
Adaptation guidelines for Megatron v2.4 #61

Closed · ymjiang opened this issue Jul 12, 2021 · 6 comments
Labels: good first issue (Good for newcomers)

ymjiang (Contributor) commented Jul 12, 2021

Hi developers,

It seems that the current patch for v2.2 no longer works directly with v2.4. I tried migrating the code line by line, but here is the error log from runtime:

Traceback (most recent call last):
  File "/root/Megatron/pretrain_gpt.py", line 189, in <module>
    args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
  File "/root/Megatron/megatron/training.py", line 124, in pretrain
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
  File "/root/Megatron/megatron/training.py", line 323, in setup_model_and_optimizer
    model = get_model(model_provider_func)
  File "/root/Megatron/megatron/training.py", line 269, in get_model
    for model_module in model]
  File "/root/Megatron/megatron/training.py", line 269, in <listcomp>
    for model_module in model]
TypeError: __init__() takes 2 positional arguments but 4 were given

Is there any guideline for applying fmoefy to Megatron v2.4? Thanks.

laekov (Owner) commented Jul 12, 2021

That error is caused by a change to the interface of the forward_step function in this Megatron commit. You can update the patch_forward_step function in fmoe/megatron/balance.py to fix this issue.

laekov (Owner) commented Jul 12, 2021

I suppose the following modification should work.

diff --git a/fmoe/megatron/balance.py b/fmoe/megatron/balance.py
index 4a4f5db..a5769c7 100644
--- a/fmoe/megatron/balance.py
+++ b/fmoe/megatron/balance.py
@@ -84,9 +84,12 @@ def patch_forward_step(forward_step_func):
     if not get_args().balance_strategy:
         return forward_step_func
 
-    def forward_step_with_balance_loss(data_iterator, model, input_tensor):
+    def forward_step_with_balance_loss(data_iterator, model, input_tensor='not given'):
         args = get_args()
-        output = forward_step_func(data_iterator, model, input_tensor)
+        if input_tensor == 'not given':
+            output = forward_step_func(data_iterator, model)
+        else:
+            output = forward_step_func(data_iterator, model, input_tensor)
 
         if not is_pipeline_last_stage() or not args.balance_strategy or args.balance_strategy == 'naive':
             return output

@ymjiang can you try this and submit a PR if you make FastMoE work with Megatron 2.4?
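
A side note on the patch above: the string 'not given' acts as a sentinel so the wrapper can tell whether the caller supplied input_tensor at all (v2.2 passes it; v2.4 does not). A dedicated sentinel object makes the same dispatch a little safer, since no real argument a caller passes can ever collide with it. A minimal sketch of the dispatch only, with the balance-loss handling elided, not the shipped FastMoE code:

_NOT_GIVEN = object()  # sentinel: unlike a string, no caller-supplied value can equal it

def patch_forward_step(forward_step_func):
    def forward_step_with_balance_loss(data_iterator, model,
                                       input_tensor=_NOT_GIVEN):
        # Megatron v2.2 calls forward_step(data_iterator, model, input_tensor);
        # v2.4 calls forward_step(data_iterator, model). Dispatch accordingly.
        if input_tensor is _NOT_GIVEN:
            return forward_step_func(data_iterator, model)
        return forward_step_func(data_iterator, model, input_tensor)
    return forward_step_with_balance_loss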

xptree added the good first issue (Good for newcomers) label Jul 12, 2021
ymjiang (Contributor, Author) commented Jul 13, 2021

> I suppose the following modification should work. [diff quoted above]

Still having the same issue after applying this.

laekov (Owner) commented Jul 14, 2021

Try this? It seems there are a great many changes to its interface.

diff --git a/fmoe/megatron/distributed.py b/fmoe/megatron/distributed.py
index 9f8685e..ec1aa6d 100644
--- a/fmoe/megatron/distributed.py
+++ b/fmoe/megatron/distributed.py
@@ -23,7 +23,7 @@ class DistributedDataParallel(DistributedGroupedDataParallel):
     Fast MoE.
     """
 
-    def __init__(self, module):
+    def __init__(self, module, _1, _2):
         from megatron import mpu
 
         super().__init__(
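
For context, the TypeError in the original traceback reduces to a constructor signature mismatch; a standalone illustration of the failure mode and the fix (class and argument names here are made up, not Megatron's):

class OldStyleDDP:
    def __init__(self, module):               # accepts only the module
        self.module = module

class NewStyleDDP:
    def __init__(self, module, *extra_args):  # tolerates extra positionals
        self.module = module

# A v2.4-style call site passes two additional flags:
# OldStyleDDP("model", True, False)   # TypeError: __init__() takes 2
#                                     # positional arguments but 4 were given
ddp = NewStyleDDP("model", True, False)       # works; the extra flags are ignored

Accepting extra positionals (or the two placeholders _1 and _2, as in the diff) keeps the wrapper constructible from Megatron's get_model(); whatever the two flags control in v2.4 is then silently ignored, which may or may not matter for training behavior.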

ymjiang (Contributor, Author) commented Jul 14, 2021

Thanks @laekov. I'm afraid I won't be able to test it soon. I will come back and update once I have time.

laekov (Owner) commented May 30, 2023

Megatron-LM 2.5 is supported in FastMoE v1.0.1.

laekov closed this as completed May 30, 2023