Bugs in GPT2 Inference Example #364

Open
JianzheXiao opened this issue Mar 13, 2024 · 3 comments

JianzheXiao commented Mar 13, 2024

  1. There is no MoE inference example among the examples. The https://www.deepspeed.ai/tutorials/mixture-of-experts-inference/ tutorial does link to generate_text.sh, but that script runs a plain GPT-2 model with num_expert=1.
  2. There are two bugs. The first is at line 82 of generate_samples_gpt.py:
    # group.add_argument("--local_rank", type=int, default=0, help='local_rank')
    You have to comment out these lines for the script to continue (see the minimal reproduction of the conflict below).
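For illustration only: the likely reason this line has to be commented out is that Megatron's argument builder already registers --local_rank, and argparse refuses to register the same option string twice. The snippet below is a minimal, self-contained reproduction of that kind of conflict, not the actual parser code from generate_samples_gpt.py.

```python
import argparse

parser = argparse.ArgumentParser()
# Megatron/DeepSpeed-style setup typically registers --local_rank once already.
parser.add_argument("--local_rank", type=int, default=0)

# Registering the same option string again, as the example script does, raises
# argparse.ArgumentError ("conflicting option string: --local_rank"), which is
# why commenting out the duplicate add_argument call lets the script continue.
try:
    parser.add_argument("--local_rank", type=int, default=0, help="local_rank")
except argparse.ArgumentError as exc:
    print(f"duplicate argument definition: {exc}")
```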

The second is in text_generation_utils.py, line 466:

output_tensor = model(tokens, position_ids, attention_mask, tokentype_ids=tokentype_ids, layer_past=layer_past, get_key_value=get_key_value, forward_method_parallel_output=forward_method_parallel_output)

Here the code passes layer_past and get_key_value to the model, but the example uses GPTModel from gpt_model.py, whose forward signature accepts tokentype_ids but none of the other keyword arguments above:

def forward(self, input_ids, position_ids, attention_mask, retriever_input_ids=None, retriever_position_ids=None, retriever_attn_mask=None, labels=None, tokentype_ids=None, inference_params=None, curriculum_seqlen=None):
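As a purely illustrative workaround (not the project's official fix), the call site in text_generation_utils.py could filter its keyword arguments against the model's actual forward signature. The helper below is hypothetical; it only sketches the idea and silently drops layer_past / get_key_value, which also gives up whatever KV-cache reuse those arguments were meant to enable.

```python
import inspect

def call_forward_with_supported_kwargs(model, *args, **kwargs):
    """Call the model with only the keyword arguments its forward() declares.

    Hypothetical helper: GPTModel.forward has no layer_past / get_key_value
    parameters, so those entries are discarded before the call.
    """
    accepted = inspect.signature(model.forward).parameters
    filtered = {k: v for k, v in kwargs.items() if k in accepted}
    return model(*args, **filtered)

# At text_generation_utils.py line 466 the call would then become:
# output_tensor = call_forward_with_supported_kwargs(
#     model, tokens, position_ids, attention_mask,
#     tokentype_ids=tokentype_ids,
#     layer_past=layer_past,
#     get_key_value=get_key_value,
#     forward_method_parallel_output=forward_method_parallel_output,
# )
```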

Is there a quick way to fix this problem?

nikit91 commented Mar 14, 2024

I am facing a similar problem. I managed to bypass this issue by loading the model with the method implemented at https://github.com/microsoft/Megatron-DeepSpeed/blob/main/tasks/eval_harness/evaluate.py#L410. However, I am now stuck at the point where I call deepspeed.init_inference on the model: it seems to be unable to transpose weights somewhere in the code, because it expects a normal MLP instead of an MoE layer, which I think points to a missing implementation. I am currently looking into https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/transformer/inference/moe_inference.py to see if it makes sense to use that somehow.
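For context, a stripped-down sketch of the flow described above might look like the following. The load_megatron_checkpoint helper is a stand-in for the loading logic in evaluate.py#L410, and the init_inference keyword arguments shown are only the commonly documented ones, so treat this as an assumption-laden outline rather than working MoE inference code.

```python
import torch
import deepspeed

def load_megatron_checkpoint(checkpoint_dir):
    # Stand-in for the loading path in tasks/eval_harness/evaluate.py#L410,
    # which builds GPTModel and restores the (MoE) weights from the checkpoint.
    raise NotImplementedError("placeholder for the evaluate.py loading logic")

model = load_megatron_checkpoint("/path/to/checkpoints")

# This is the call that currently fails for MoE checkpoints: the kernel-injection
# path assumes a dense MLP block and tries to transpose weights that, in an MoE
# layer, live inside per-expert sub-modules instead.
engine = deepspeed.init_inference(
    model,
    tensor_parallel={"tp_size": 1},   # tensor-parallel degree
    dtype=torch.half,
    replace_with_kernel_inject=True,
)
output_model = engine.module  # wrapped model, used as usual for generation
```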

nikit91 commented Mar 14, 2024

Investigating a bit further, I found that during inference DeepSpeed picks one of the predefined model containers here:
https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/utils.py#L24-L37

However, the container we need (DS_MegatronGPTMoEContainer, https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/containers/megatron_gpt_moe.py) is not mapped in that dictionary (or anywhere else, for that matter).
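To make the gap concrete, the mapping in module_inject/utils.py is essentially a policy-to-container dispatch dictionary. The sketch below mimics that pattern with placeholder class names (the real policy and container classes live under deepspeed/module_inject/containers) and shows the kind of entry that would have to exist before DS_MegatronGPTMoEContainer could ever be selected.

```python
# Illustrative only: a simplified version of the policy -> container dispatch
# performed in deepspeed/module_inject/utils.py. All class names here are
# placeholders standing in for the real policy/container classes.
class GPT2LayerPolicy: ...
class MegatronLayerPolicy: ...
class MegatronMoELayerPolicy: ...          # hypothetical MoE-aware policy

class DS_GPT2Container: ...
class DS_MegatronGPTContainer: ...
class DS_MegatronGPTMoEContainer: ...      # exists in containers/megatron_gpt_moe.py

POLICY_TO_CONTAINER = {
    GPT2LayerPolicy: DS_GPT2Container,
    MegatronLayerPolicy: DS_MegatronGPTContainer,
    # Missing today: without an entry like the one below, kernel injection can
    # never route an MoE transformer layer to the MoE-aware container.
    # MegatronMoELayerPolicy: DS_MegatronGPTMoEContainer,
}

def policy_to_container(policy_cls):
    try:
        return POLICY_TO_CONTAINER[policy_cls]
    except KeyError:
        raise ValueError(f"no inference container registered for {policy_cls.__name__}")
```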

@haoranlll

I'm facing a similar problem: I trained a gpt-125M-MoE64 model with the script ds_pretrain_gpt_125M_MoE64.sh (https://github.com/microsoft/Megatron-DeepSpeed/blob/main/examples_deepspeed/MoE/ds_pretrain_gpt_125M_MoE64.sh), but there is no usable script for inference with it. How can I run inference with the trained MoE model?
