
Distill megatron - test Draft WIP#352

Closed
younesbelkada wants to merge 175 commits into bigscience-workshop:main from younesbelkada:distill_megatron

Conversation


younesbelkada commented Sep 28, 2022

An attempt to perform knowledge distillation using Megatron-DeepSpeed

Disclaimer: this is a very rough version of the code; the PR is here to compare the original code against this modified version. For now I don't plan to merge this PR.

Updates on 28.09.2022

This version is very ugly: I had to add a `student_`-prefixed argument to all Megatron modules, since arguments are retrieved directly from the global variable. The other solution would be to rewrite each class with a `Student` suffix - e.g. `GPTModelStudentPipe`. I preferred the first solution to get a quick working implementation.
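To make the prefix idea concrete, here is a minimal, dependency-free sketch of how a `student_` prefix could redirect lookups on a Megatron-style global args object. The helper name `get_model_arg` and the `SimpleNamespace` stand-in for Megatron's real `get_args()` are assumptions for illustration, not code from this PR.

```python
from types import SimpleNamespace

# Hypothetical global args object, mimicking Megatron's get_args():
# every student hyperparameter is duplicated with a "student_" prefix.
_GLOBAL_ARGS = SimpleNamespace(
    hidden_size=1024,
    num_layers=24,
    student_hidden_size=512,
    student_num_layers=6,
)

def get_args():
    """Stand-in for Megatron's global-args accessor."""
    return _GLOBAL_ARGS

def get_model_arg(name, student=False):
    """Resolve an argument, optionally redirected to its student_ variant."""
    args = get_args()
    return getattr(args, f"student_{name}" if student else name)

# Teacher modules read the plain names; student modules pass student=True.
teacher_layers = get_model_arg("num_layers")                 # 24
student_layers = get_model_arg("num_layers", student=True)   # 6
```

The appeal of this approach over a `Student`-suffixed class hierarchy is that every module keeps a single code path and only the argument resolution changes.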

The forward and backward passes seem to work for the student model; for now I am not computing the teacher's logits.
There are two possible solutions for that:

1. In distill_train_step, add a step where we retrieve the teacher's logits. In this case, would we need to change the deepspeed.PipelineEngine internals?
2. Store the embedding layer of the teacher model inside the student model and gather the teacher's last hidden states. Once gathered, apply the forward pass of the teacher's embedding layer to obtain the logits. (cc @thomasw21, as discussed offline)
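Whichever solution produces the teacher logits, the distillation step then compares them to the student logits. Below is a minimal, dependency-free sketch of the standard temperature-scaled KL distillation loss (Hinton-style); the function names are hypothetical and this is not the loss code from this PR.

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a flat list of logits."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)                      # subtract max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    z = sum(exps)
    return [e / z for e in exps]

def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """KL(teacher || student) on temperature-softened distributions,
    scaled by T^2 as in the classic distillation formulation."""
    p = softmax(teacher_logits, temperature)   # teacher: soft targets
    q = softmax(student_logits, temperature)   # student: predictions
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)
    return (temperature ** 2) * kl
```

When student and teacher agree exactly, the loss is zero; it grows as the student's distribution drifts from the teacher's soft targets.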

main TODOs

  • load the teacher model from a checkpoint - make sure to use a copy of the checkpoint
  • load the student model from a checkpoint - make sure to use a copy of the checkpoint
  • make the broadcasting of the teacher logits (or last hidden states) work
  • try on 176
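For the broadcasting TODO, the intended semantics can be sketched without torch.distributed: the rank holding the teacher's output (typically the last pipeline stage) sends it to every other rank. The function below is a hypothetical pure-Python simulation of what a `dist.broadcast(tensor, src=...)` call would achieve, not the actual implementation.

```python
def broadcast_from_last_stage(per_rank_value, src_rank):
    """Simulate the semantics of a collective broadcast for the
    'broadcast the teacher logits' TODO: after the call, every pipeline
    rank holds the value computed on the source rank.

    per_rank_value maps rank -> local value (None on ranks where the
    teacher logits were not computed).
    """
    payload = per_rank_value[src_rank]
    if payload is None:
        raise ValueError("source rank holds no teacher logits to broadcast")
    # In real code this would be torch.distributed.broadcast(tensor, src=src_rank);
    # here we simply copy the payload to every rank.
    return {rank: list(payload) for rank in per_rank_value}

# Ranks 0 and 1 start empty; rank 2 (the last stage) holds the logits.
synced = broadcast_from_last_stage({0: None, 1: None, 2: [0.1, 0.9]}, src_rank=2)
```

In the real pipeline-parallel setting the extra subtlety is that only the last teacher stage produces logits at all, so the source rank must be derived from the teacher's pipeline topology before the collective is issued.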

