BERT EXTRACTION: Unable to reproduce results on MNLI #52

Closed
ahmadrash opened this issue Mar 30, 2020 · 10 comments

Comments

@ahmadrash

Following the scripts, I trained a teacher model successfully, generated the extraction data, and ran knowledge distillation on the teacher logits. On evaluation on the dev set I am getting 0.319 eval accuracy.
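
For context, the distillation step fits the student to the teacher's soft labels. Below is a minimal NumPy sketch of that objective, not the repository's exact implementation (the real training code is run_classifier_distillation.py); the function names and the absence of a temperature term are my simplifications.

```python
# Minimal sketch of soft-label distillation, assuming the student is trained
# with a cross-entropy loss against the teacher's softmax outputs.
import numpy as np

def softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def distill_loss(student_logits, teacher_logits):
    # cross-entropy of the student's predictions against the teacher's soft labels
    p_teacher = softmax(teacher_logits)
    log_p_student = np.log(softmax(student_logits) + 1e-12)
    return -(p_teacher * log_p_student).sum(axis=-1).mean()

# toy example: 8 extraction queries, 3 MNLI classes
rng = np.random.default_rng(0)
print(distill_loss(rng.normal(size=(8, 3)), rng.normal(size=(8, 3))))
```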

@martiansideofthemoon

Hi Ahmad, thanks for your interest! An accuracy of 31.9% is worse than random guessing (MNLI has three classes, so a random guess gives ~33.3%). A few questions to help debug this:

  1. What is the class distribution of the extracted data? (A quick way to check this is sketched after this list.)
  2. What scheme did you use, RANDOM / WIKI?
  3. What was the dev set accuracy of the teacher model?
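
For question 1, a quick way to check the label balance of the extraction set is to count the teacher's argmax predictions. The snippet below is a rough helper, not part of the repository; the file name and the assumption that the last three TSV columns hold the teacher logits are mine, so adjust them to your data layout.

```python
# Rough check of the teacher-label distribution in the extraction data.
# ASSUMPTION: a TSV whose last three columns are the teacher's logits.
import csv
from collections import Counter

counts = Counter()
with open("train_distill.tsv") as f:            # hypothetical file name
    reader = csv.reader(f, delimiter="\t")
    next(reader)                                # skip header row, if any
    for row in reader:
        logits = [float(x) for x in row[-3:]]   # three MNLI classes
        counts[logits.index(max(logits))] += 1

total = sum(counts.values())
for label in sorted(counts):
    print(f"class {label}: {100.0 * counts[label] / total:.2f}%")
```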

@ahmadrash

Thanks a lot for the prompt response.

  1. The class distribution over the three classes is [26.76%, 26.31%, 46.93%].
  2. I used DATA_SCHEME="random_ed_k_uniform"
  3. Dev set accuracy of the teacher model is 0.851

@martiansideofthemoon

Hi Ahmad,
1, 2, and 3 look good to me. A few more follow-up questions:

  1. I guess you are using BERT-large?
  2. Are you using this file to train the student model? https://github.com/google-research/language/blob/master/language/bert_extraction/steal_bert_classifier/models/run_classifier_distillation.py
  3. Is the training loss decreasing? (Just confirming that the weight updates are happening; see the checkpoint check sketched after this list.)
  4. Does the same script work for SST-2 / SQuAD?
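
For question 3, one direct way to confirm that weight updates are happening is to compare a parameter tensor across two saved checkpoints. This is only a sketch: the checkpoint paths are placeholders, and the variable name follows the standard BERT naming, so verify it against get_variable_to_shape_map() for your run.

```python
# Compare one weight tensor across two checkpoints; a ~0 difference means the
# parameters are not being updated. Paths and step numbers are placeholders.
import numpy as np
import tensorflow as tf

ckpt_old = tf.train.load_checkpoint("output_dir/model.ckpt-0")
ckpt_new = tf.train.load_checkpoint("output_dir/model.ckpt-10000")

name = "bert/encoder/layer_0/attention/self/query/kernel"  # check with get_variable_to_shape_map()
diff = np.linalg.norm(ckpt_old.get_tensor(name) - ckpt_new.get_tensor(name))
print("L2 norm of weight change:", diff)
```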

@ahmadrash

Thanks Kalpesh,

  1. Yes, I am using BERT-large.

  2. Yes, I am using that file.

  3. I am attaching the loss curve from TensorBoard; it shows oscillations.
     [mnli_loss: training loss curve screenshot]

  4. I am still running the other experiments.

@martiansideofthemoon

martiansideofthemoon commented Mar 30, 2020

Regarding your curve, how many epochs are you training for, and what is your batch size? A loss of ~1.1 (≈ ln 3, the cross-entropy of a uniform guess over three classes) indicates nothing is being learnt, though I do see a strong decrease over roughly the first 10k steps. Also, what are your learning rate, optimizer, and learning rate schedule? Finally, what hardware are you using?

@ahmadrash

I am training it for 3 epochs with a batch size of 8 on an NVIDIA V100 GPU. The learning rate, optimizer, and schedule are the defaults in the script:

--learning_rate=3e-5
--warmup_proportion=0.1

The optimizer is the same as the default for BERT.
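
For reference, assuming the distillation script reuses BERT's optimization.py unchanged (AdamWeightDecayOptimizer), these flags imply linear warmup over the first 10% of steps followed by a linear decay of the learning rate to zero. A small sketch of that schedule; num_train_steps below is a placeholder (in practice it is 3 epochs over the extraction set).

```python
# Sketch of the default BERT learning-rate schedule implied by the flags above:
# linear warmup for warmup_proportion of the steps, then linear decay to zero.
def bert_lr(step, init_lr=3e-5, num_train_steps=100_000, warmup_proportion=0.1):
    warmup_steps = int(num_train_steps * warmup_proportion)
    if step < warmup_steps:
        return init_lr * step / warmup_steps           # linear warmup
    return init_lr * (1.0 - step / num_train_steps)    # polynomial decay, power = 1

for s in (0, 5_000, 10_000, 50_000, 100_000):
    print(s, f"{bert_lr(s):.2e}")
```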

@martiansideofthemoon

martiansideofthemoon commented Mar 30, 2020

I think the batch size might be the issue: learning is less stable on RANDOM data than on the original MNLI, and smaller batch sizes (hence noisier gradient estimates) can knock the model off the optimization path. I'd recommend trying a batch size of 32. If that doesn't fit on the GPU, you could try BERT-base or gradient accumulation (a toy sketch follows).
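
If a batch of 32 does not fit in memory and the script does not expose gradient accumulation, the idea is simply to average gradients over several micro-batches before applying one update. A toy NumPy sketch with a toy linear model (not the repository's TF code): 4 micro-batches of 8 give an effective batch of 32.

```python
# Toy sketch of gradient accumulation: 4 micro-batches of 8 -> effective batch 32.
import numpy as np

rng = np.random.default_rng(0)
w = np.zeros(16)                          # toy linear-model weights
lr, micro_bs, accum_steps = 1e-5, 8, 4

def grad(w, x, y):
    # gradient of mean squared error for a linear model y ~ x @ w
    return 2.0 * x.T @ (x @ w - y) / len(y)

for step in range(100):
    accum = np.zeros_like(w)
    for _ in range(accum_steps):          # accumulate over micro-batches
        x = rng.normal(size=(micro_bs, 16))
        y = rng.normal(size=micro_bs)
        accum += grad(w, x, y)
    w -= lr * (accum / accum_steps)       # one update with the averaged gradient
```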

Another thing you could try is learning rate decay. From your graph it is clear that the training loss decreases during the warmup phase, but after that the learning rate is too high and a bad gradient (from a small batch) can throw off the optimization. You could also simply try a smaller learning rate, maybe 1e-5.

@ahmadrash

Thanks a lot for the suggestions. I will try these and report back.

@ahmadrash

Thanks Kalpesh! I was able to get 78 on MNLI dev and 90 on SST-2 by reducing the learning rate to 1e-5. The loss curve is still not ideal, but it is much better than what we were seeing before.

@Jimntu

Jimntu commented Jul 25, 2022

Hi, I am a beginner in deep learning and have little experience implementing the code. May I ask how you drew the loss curve in TensorBoard? I would really appreciate it if you could help me!
