Model Distillation #1758
Conversation
…ai/haystack into distillation-refactored
Let's wait for the new results from the benchmark before merging the branch into master. Also, could you please add a small test case, maybe with tiny models and a small dataset, just to run the code as one of the tests and check, for example, that the weights of the student model change after training? We can talk about that in a call if you want.
haystack/modeling/training/base.py
Outdated
teacher_logits = [batch.pop(key) for key in keys]
logits = self.model.forward(**batch)
student_loss = self.model.logits_to_loss(logits=logits, global_step=self.global_step, **batch)
logit_difference_loss = self.distillation_loss_fn(logits[0], teacher_logits[0])
Let's use named arguments in method calls so that this line becomes:
logit_difference_loss = self.distillation_loss_fn(student_logits=logits[0], teacher_logits=teacher_logits[0])
As a result, the code won't break when the order of arguments in the implementation of distillation_loss_fn changes, and it's easier to read. (In Haystack's code base we use named arguments in almost every method that has multiple arguments for these reasons.)
Okay, I changed that.
haystack/modeling/training/base.py
Outdated
def _kl_div(self, student_logits, teacher_logits):
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
    return self.kl(student_log_probs, teacher_log_probs)
We could use KLDivLoss(reduction="batchmean", log_target=True) here directly instead of self.kl, and then there would be no need to define self.kl as it isn't used anywhere else. If you make this change, we can get rid of line 694, self.kl = KLDivLoss(reduction="batchmean", log_target=True), in the elif branch.
I have changed it so it now uses the functional API. I hope this also addresses the issue.
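For illustration only (this is not Haystack's actual code), the quantity computed by the functional KL divergence loss with log-space targets and "batchmean" reduction can be sketched in plain Python:

```python
import math

def log_softmax(logits):
    # Numerically stable log-softmax over a single row of logits.
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]

def kl_div_batchmean(student_log_probs, teacher_log_probs):
    # KL divergence with log-space targets, summed per example and averaged
    # over the batch, mirroring what
    # F.kl_div(..., reduction="batchmean", log_target=True) computes.
    total = 0.0
    for s_row, t_row in zip(student_log_probs, teacher_log_probs):
        for s, t in zip(s_row, t_row):
            total += math.exp(t) * (t - s)
    return total / len(student_log_probs)

student = [log_softmax([2.0, 1.0, 0.5])]
teacher = [log_softmax([2.0, 1.0, 0.5])]
print(kl_div_batchmean(student, teacher))  # identical distributions -> 0.0
```

The loss is zero when student and teacher logits induce the same distribution and grows as they diverge, which is why it works as a logit-matching signal.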
haystack/nodes/reader/farm.py
Outdated
# 2. Create a DataSilo that loads several datasets (train/dev/test), provides DataLoaders for them
# and calculates a few descriptive statistics of our datasets
data_silo = DataSilo(processor=processor, batch_size=batch_size, distributed=False, max_processes=num_processes)
if teacher_model:
There are very few new comments in the code. At this line it's definitely worth adding a comment on the if/else statement and its consequences. In general, I would like to see more comments in the code.
I have now added this comment and a few others.
@@ -0,0 +1,99 @@
from haystack.nodes import FARMReader
The FARMReader import is a duplicate here, and import os is not needed.
I have now removed the imports.
I have added the test.
Found one unused import. Other than that, I think it's ready to be merged depending on the benchmark results. 👍
Maybe we could check the distilled model's QA predictions in the test case rather than only the change of weights, but I guess the tiny dataset won't result in any big changes to the distilled model.
haystack/modeling/training/base.py
Outdated
@@ -9,6 +9,9 @@
from tqdm import tqdm
from pathlib import Path

from torch.nn import KLDivLoss, MSELoss
KLDivLoss is no longer used, so there is no need for this import.
I have now done a few tests and the best I could find so far was an increase of about 3 percentage points in EM when distilling deepset/bert-large-uncased-whole-word-masking-squad2 to prajjwal1/bert-medium compared to just training the student without distillation.
Great to see the performance improvements now that the bug with the distillation loss calculation is fixed. The PR is ready to merge from my side.
Looks good to me. Great job! 👍
Nice work! I only left a few comments around documentation. Feel free to merge once those are addressed.
@@ -722,3 +727,68 @@ def get_dict_checksum(payload_dict):
    """
    checksum = hashlib.md5(json.dumps(payload_dict, sort_keys=True).encode("utf-8")).hexdigest()
    return checksum


class DistillationDataSilo(DataSilo):
    def __init__(self, teacher_model: "FARMReader", teacher_batch_size: int, device, **kwargs):
It would be great to add a short docstring here explaining the need for a special data silo, plus type hints for all params.
I have added the type hint for device. Do you also want me to add type hints for kwargs (i.e. write them out)?
Yeah, I think it would be helpful to write them out, as it will enable autocomplete in the IDE, which can be helpful when users initialize a DistillationDataSilo but are not sure which params are expected.
        kwargs["max_processes"] = 1  # fix as long as multithreading is not working with teacher attribute
        super().__init__(**kwargs)

    def _run_teacher(self, batch, corresponding_chunks, teacher_outputs, tensor_names):
also here type hints would be helpful :)
I have added the type hints.
@@ -596,3 +602,111 @@ def _all_ranks_have_data(self, has_data: bool, step: Optional[int] = None):
            return False
        else:
            return True


class DistillationTrainer(Trainer):
A short docstring explaining the purpose of this class, and ideally a short code example, would be helpful for the docs.
I have added the docstring.
haystack/modeling/training/base.py
Outdated
        data_silo: DataSilo,
        epochs: int,
        n_gpu: int,
        device,
please add type hints :)
I have added the type hints. However, for lr_scheduler I needed to import a private class from pytorch. I'm not sure if that's desirable.
You're referring to something like from torch.optim.lr_scheduler import _LRScheduler?
Yes, I am.
If mypy doesn't complain about the annotation with _LRScheduler, which I think it doesn't, we can use it. Otherwise, a type: ignore comment is the next best solution here in my opinion.
@tholor What do you think?
Yes, the import looks okay to me - if that's the type we expect here it is what it is ;) .
If mypy complains we can drop it or add the ignore comment
haystack/modeling/training/base.py
Outdated
:param max_grad_norm: Max gradient norm for clipping, default 1.0, set to None to disable
:param distillation_loss_weight: The weight of the distillation loss
:param distillation_loss: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for KL divergence loss) or a callable loss function (needs to have named parameters student_logits and teacher_logits)
:param temperature: The temperature for distillation
please describe briefly the effects of the params (e.g. a higher temperature results in ...)
Okay, I have also added this description to the distil_from method.
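As a rough illustration of what such a description could say (the function below is a sketch, not Haystack's internals): dividing logits by a temperature T > 1 before the softmax flattens the resulting distribution, so the student also learns from the teacher's probabilities for non-argmax classes.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    # Scale logits by 1/T before a numerically stable softmax.
    # T > 1 softens (flattens) the distribution; T < 1 sharpens it.
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [4.0, 1.0, 0.5]
sharp = softmax_with_temperature(logits, temperature=1.0)
soft = softmax_with_temperature(logits, temperature=5.0)
# At the higher temperature, the top class keeps less probability mass
# and the other classes keep more.
```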
        """
        Fine-tune a model on a QA dataset using distillation. You need to provide a teacher model that is already fine-tuned on the dataset
        and a student model that will be trained using the teacher's logits. The idea of this is to increase the accuracy of a lightweight student model
        using a more complex teacher.
Can you give a short code example here including a reasonable combination of models (to show usage and clarify that the student can be another pretrained model)?
Okay
This adds model distillation as described in #1551.
A new method called distil_from is added to the FARMReader. This method takes mostly the same parameters as train, with additional parameters for the teacher model and for configuring distillation.
The classes DistillationTrainer and DistillationDataSilo are also introduced as they are used by distil_from. The DistillationDataSilo already computes the logits of the teacher meaning they won't have to be recomputed each epoch.
The implemented approach has the limitation that the character to token mappings of student and teacher tokenizer need to be the same as comparing logits would not make sense otherwise.
This PR also includes a benchmark that allows comparing teacher performance, student performance without distillation (baseline), and student performance with distillation. It can be configured similarly to the other benchmarks using a JSON file.
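For intuition only, one common way such a distillation_loss_weight parameter is used (whether Haystack uses exactly this convex combination is an assumption; the actual formula lives in DistillationTrainer) is to blend the student's task loss with the logit-matching loss:

```python
def combined_loss(student_loss, distillation_loss, distillation_loss_weight):
    # Hypothetical blending of the two loss terms: a weight of 0.0 reduces to
    # plain training on the labels, while 1.0 trains only on the teacher signal.
    return (1.0 - distillation_loss_weight) * student_loss \
        + distillation_loss_weight * distillation_loss

print(combined_loss(2.0, 4.0, 0.5))  # halfway between the two losses -> 3.0
```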