
Invalid syntax error when unpacking *moe_losses in python-3.7 #24

Closed

adammoody opened this issue Dec 18, 2021 · 3 comments

@adammoody

I am trying to use the new MoE support from DeepSpeed 0.5.8 on a system with Python 3.7.11. However, I get "invalid syntax" errors for statements like:

75: 3:   File "/path/to/megatron/model/language_model.py", line 408
75: 3:     return encoder_output, pooled_output, *moe_losses
75: 3:                                           ^
75: 3: SyntaxError: invalid syntax
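
As a quick sanity check that the construct itself (and not the surrounding code) is the problem, the probe below checks whether the running interpreter accepts bare *-unpacking in a return statement; Python 3.7 rejects it at compile time, and the restriction was only lifted in Python 3.8. (This is just an illustrative standalone check, not Megatron code.)

# Standalone check (not Megatron code): does this interpreter accept bare
# *-unpacking in a return statement? Python 3.7 raises SyntaxError at compile
# time; Python 3.8+ accepts it.
import sys

src = "def f(a, b, rest):\n    return a, b, *rest\n"
try:
    compile(src, "<probe>", "exec")
    print(sys.version_info[:3], "accepts bare *-unpacking in return")
except SyntaxError:
    print(sys.version_info[:3], "rejects it: SyntaxError, as seen above")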

It seems like I can work around those errors by wrapping the return values in a tuple(...) call. I hit this problem in at least the following places.

diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py
index 4eb983c..1efa1da 100644
--- a/megatron/model/gpt_model.py
+++ b/megatron/model/gpt_model.py
@@ -124,15 +124,23 @@ class GPTModel(MegatronModule):
             get_key_value=get_key_value)
 
         if self.post_process:
-            return post_language_model_processing(
+            #return post_language_model_processing(
+            #    lm_output, labels,
+            #    self.word_embeddings_weight(),
+            #    get_key_value,
+            #    self.parallel_output,
+            #    forward_method_parallel_output,
+            #    self.fp16_lm_cross_entropy), *moe_losses
+            return tuple(post_language_model_processing(
                 lm_output, labels,
                 self.word_embeddings_weight(),
                 get_key_value,
                 self.parallel_output,
                 forward_method_parallel_output,
-                self.fp16_lm_cross_entropy), *moe_losses
+                self.fp16_lm_cross_entropy), *moe_losses)
         else:
-            return lm_output, *moe_losses
+            #return lm_output, *moe_losses
+            return tuple(lm_output, *moe_losses)
 
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py
index cb27498..e7853e6 100644
--- a/megatron/model/language_model.py
+++ b/megatron/model/language_model.py
@@ -405,9 +405,11 @@ class TransformerLanguageModel(MegatronModule):
         # similarity between two sequences by average pooling
         if not self.add_decoder or output_enc_hidden:
             if self.add_pooler and self.post_process:
-                return encoder_output, pooled_output, *moe_losses
+                #return encoder_output, pooled_output, *moe_losses
+                return tuple(encoder_output, pooled_output, *moe_losses)
             else:
-                return encoder_output, *moe_losses
+                #return encoder_output, *moe_losses
+                return tuple(encoder_output, *moe_losses)
 
         # Decoder Embedding
         dec_embedding_output = self.embedding(dec_input_ids,
@@ -421,9 +423,11 @@ class TransformerLanguageModel(MegatronModule):
                                       enc_dec_attn_mask=enc_dec_attn_mask)
 
         if self.add_pooler and self.post_process:
-            return decoder_output, encoder_output, pooled_output, *moe_losses
+            #return decoder_output, encoder_output, pooled_output, *moe_losses
+            return tuple(decoder_output, encoder_output, pooled_output, *moe_losses)
         else:
-            return decoder_output, encoder_output, *moe_losses
+            #return decoder_output, encoder_output, *moe_losses
+            return tuple(decoder_output, encoder_output, *moe_losses)
 
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
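
One caveat on the tuple(...) spelling above: tuple() accepts at most one (iterable) argument, so a call like tuple(lm_output, *moe_losses) would raise a TypeError at runtime whenever moe_losses is non-empty. A parenthesized tuple display keeps the original structure and already parses on Python 3.7; a rough sketch with placeholder values (not necessarily what the merged PR ended up doing):

# Placeholder values, just to illustrate the two spellings.
lm_output = "lm_output"
moe_losses = [0.1, 0.2]

# tuple() takes at most one iterable argument, so this fails at runtime:
# tuple(lm_output, *moe_losses)   # TypeError: tuple expected at most 1 argument

# A parenthesized tuple with unpacking (PEP 448) is valid since Python 3.5 and
# preserves the (output, *losses) shape that callers expect:
result = (lm_output, *moe_losses)
print(result)  # ('lm_output', 0.1, 0.2)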
@awan-10

awan-10 commented Feb 24, 2022

@adammoody if the merged PR fixes this issue on your end, please close this issue. And thank you for the PR :)

@conglongli

I believe the merged PR fixes this issue, so closing this now but feel free to reopen if needed.

@adammoody
Author

Yes, I can confirm the merged PR fixed this issue. Thanks!
