
Model Distillation #1758

Merged
merged 40 commits on Nov 26, 2021

Conversation

MichelBartels
Contributor

This adds model distillation as described in #1551.
A new method called distil_from is added to the FARMReader. It takes mostly the same parameters as train, plus additional parameters for the teacher model and for configuring distillation.
The classes DistillationTrainer and DistillationDataSilo are also introduced, as they are used by distil_from. The DistillationDataSilo precomputes the teacher's logits, so they don't have to be recomputed each epoch.

The implemented approach has the limitation that the character-to-token mappings of the student and teacher tokenizers need to be the same, as comparing logits would not make sense otherwise.

This PR also includes a benchmark that compares teacher performance, student performance without distillation (baseline), and student performance with distillation. Like the other benchmarks, it can be configured via a JSON file.
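For illustration, a minimal sketch of the idea behind precomputing the teacher's logits in the data silo (the helper name, attribute access, and batching details here are assumptions, not the PR's actual implementation):

import torch
from torch.utils.data import DataLoader, Dataset

def precompute_teacher_logits(teacher_model: torch.nn.Module, dataset: Dataset,
                              tensor_names: list, batch_size: int = 8) -> torch.Tensor:
    # Run the frozen teacher once over the whole training set and cache its logits,
    # so they can be stored next to the dataset tensors instead of being recomputed every epoch.
    teacher_model.eval()
    cached = []
    with torch.no_grad():
        for batch in DataLoader(dataset, batch_size=batch_size):
            inputs = dict(zip(tensor_names, batch))
            outputs = teacher_model(**inputs)  # assumed to return the QA logits as the first element
            cached.append(outputs[0].cpu())
    return torch.cat(cached, dim=0)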

Member

@julian-risch julian-risch left a comment

Let's wait for the new results from the benchmark before merging the branch into master. Further, could you please add a small test case, maybe with tiny models and a small dataset, that runs the code as one of the tests and checks, for example, that the weights of the student model change after training? We can talk about that in a call if you want.
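A minimal sketch of what such a test could look like (the tiny model names, the reader.inferencer.model attribute path, and the exact distil_from arguments are assumptions for illustration):

import torch
from haystack.nodes import FARMReader

def test_distillation_changes_student_weights():
    student = FARMReader(model_name_or_path="prajjwal1/bert-tiny", use_gpu=False)
    teacher = FARMReader(model_name_or_path="prajjwal1/bert-mini", use_gpu=False)
    # Snapshot the student weights before distillation.
    weights_before = [p.detach().clone() for p in student.inferencer.model.parameters()]
    student.distil_from(teacher, data_dir="samples/squad", train_filename="tiny.json", n_epochs=1)
    weights_after = list(student.inferencer.model.parameters())
    # At least one parameter tensor should have changed after training.
    assert any(not torch.equal(before, after) for before, after in zip(weights_before, weights_after))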

teacher_logits = [batch.pop(key) for key in keys]
logits = self.model.forward(**batch)
student_loss = self.model.logits_to_loss(logits=logits, global_step=self.global_step, **batch)
logit_difference_loss = self.distillation_loss_fn(logits[0], teacher_logits[0])
Member

Let's use named arguments in method calls so that this line becomes:
logit_difference_loss = self.distillation_loss_fn(student_logits=logits[0], teacher_logits=teacher_logits[0])
As a result, the code won't break when the order of arguments in the implementation of distillation_loss_fn changes and it's easier to read the code. (In haystack's code base we use named arguments in almost every method that has multiple arguments for these reasons.)

Contributor Author

Okay, I changed that.

def _kl_div(self, student_logits, teacher_logits):
student_log_probs = F.log_softmax(student_logits, dim=-1)
teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
return self.kl(student_log_probs, teacher_log_probs)
Member

We could use KLDivLoss(reduction="batchmean", log_target=True) here directly instead of self.kl and then there would be no need to define self.kl as it isn't used anywhere else. If you make this change, we can get rid of line 694 self.kl = KLDivLoss(reduction="batchmean", log_target=True) in the elif branch.

Contributor Author

@MichelBartels MichelBartels Nov 17, 2021

I have changed it so that it now uses the functional API. I hope this also addresses the issue.
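For reference, a sketch of what the functional-API version could look like (not necessarily the exact code that was committed):

import torch.nn.functional as F

def _kl_div(self, student_logits, teacher_logits):
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
    # F.kl_div replaces the separate KLDivLoss module; log_target=True because both
    # arguments are log-probabilities.
    return F.kl_div(student_log_probs, teacher_log_probs, reduction="batchmean", log_target=True)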


# 2. Create a DataSilo that loads several datasets (train/dev/test), provides DataLoaders for them
# and calculates a few descriptive statistics of our datasets
data_silo = DataSilo(processor=processor, batch_size=batch_size, distributed=False, max_processes=num_processes)
if teacher_model:
Member

There are only very few new comments in the code. At this line, it's definitely worth adding a comment on the if/else statement and its consequences. In general, I would like to see more comments in the code.

Contributor Author

I have now added this comment and a few others.
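For context, a sketch of what the commented if/else could look like (the else branch and the teacher_batch_size and device variables are assumptions based on the surrounding discussion):

# If a teacher model is provided, use the DistillationDataSilo: it runs the teacher over the
# training data once and caches the teacher logits next to each batch for the distillation loss.
# Otherwise fall back to the regular DataSilo.
if teacher_model:
    data_silo = DistillationDataSilo(teacher_model=teacher_model, teacher_batch_size=teacher_batch_size,
                                     device=device, processor=processor, batch_size=batch_size,
                                     distributed=False, max_processes=num_processes)
else:
    data_silo = DataSilo(processor=processor, batch_size=batch_size, distributed=False, max_processes=num_processes)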

@@ -0,0 +1,99 @@
from haystack.nodes import FARMReader
Member

The FARMReader import is a duplicate here and import os is not needed.

Contributor Author

@MichelBartels MichelBartels Nov 17, 2021

I have now removed the imports.

@MichelBartels
Contributor Author

I have added the test.

Member

@julian-risch julian-risch left a comment

Found one unused import. Other than that, I think it's ready to be merged depending on the benchmark results. 👍
Maybe we could check the distilled model's QA predictions in the test case rather than only the change of weights, but I guess the tiny dataset won't result in any significant changes to the distilled model.

@@ -9,6 +9,9 @@
from tqdm import tqdm
from pathlib import Path

from torch.nn import KLDivLoss, MSELoss
Member

KLDivLoss is not used anymore so there is no need for this import.

@MichelBartels
Contributor Author

I have now done a few tests and the best I could find so far was an increase of about 3 percentage points in EM when distilling deepset/bert-large-uncased-whole-word-masking-squad2 to prajjwal1/bert-medium compared to just training the student without distillation.
Results of teacher
EM: 0.7890170976164407
F1: 0.8320508749659892
Top n accuracy: 0.9773435525983324
Results of student without distillation (baseline)
EM: 0.655689379263876
F1: 0.6964405123508122
Top n accuracy: 0.9535079592352396
Results of student with distillation (temperature: 5, distillation loss weight: 1)
EM: 0.6864313989724585
F1: 0.7276370837908053
Top n accuracy: 0.9530868356775878

@julian-risch
Member

Great to see the performance improvements now that the bug with the distillation loss calculation is fixed. The PR is ready to merge from my side.
@tholor could you please briefly have a look at this PR and maybe give some general feedback (not detailed) before we merge it? Thank you!

Member

@julian-risch julian-risch left a comment

Looks good to me. Great job! 👍

Member

@tholor tholor left a comment

Nice work! I only left a few comments around documentation. Feel free to merge once those are addressed.

@@ -722,3 +727,68 @@ def get_dict_checksum(payload_dict):
"""
checksum = hashlib.md5(json.dumps(payload_dict, sort_keys=True).encode("utf-8")).hexdigest()
return checksum

class DistillationDataSilo(DataSilo):
def __init__(self, teacher_model: "FARMReader", teacher_batch_size: int, device, **kwargs):
Member

Would be great to add a short docstring here explaining the need for a special data silo and type hints for all params

Contributor Author

I have added the type hint for device. Do you also want me to add type hints for kwargs (i.e. write them out)?

Member

Yeah, I think it would be helpful to write them out, as that enables autocomplete in the IDE, which helps when users initialize a DistillationDataSilo but are not sure which params are expected.
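A sketch of what the written-out signature could look like (the import paths and the set of DataSilo parameters listed here are assumptions; the actual class may expose more):

import torch
from haystack.modeling.data_handler.data_silo import DataSilo      # import path assumed
from haystack.modeling.data_handler.processor import Processor     # import path assumed

class DistillationDataSilo(DataSilo):
    def __init__(
        self,
        teacher_model: "FARMReader",
        teacher_batch_size: int,
        device: torch.device,
        processor: Processor,       # assumed DataSilo parameter, as used in distil_from
        batch_size: int,            # assumed DataSilo parameter
        distributed: bool = False,  # assumed DataSilo parameter
        max_processes: int = 1,     # kept at 1 while multiprocessing doesn't work with the teacher
    ):
        ...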

kwargs["max_processes"] = 1 # fix as long as multithreading is not working with teacher attribute
super().__init__(**kwargs)

def _run_teacher(self, batch, corresponding_chunks, teacher_outputs, tensor_names):
Member

also here type hints would be helpful :)

Contributor Author

I have added the type hints.
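For illustration, one possible type-hinted signature (the concrete types are guesses, not necessarily what was committed):

from typing import List
import torch

def _run_teacher(self, batch: dict, corresponding_chunks: List[int],
                 teacher_outputs: List[List[torch.Tensor]], tensor_names: List[str]) -> None:
    ...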

@@ -596,3 +602,111 @@ def _all_ranks_have_data(self, has_data: bool, step: Optional[int] = None):
return False
else:
return True

class DistillationTrainer(Trainer):
Member

A short docstring explaining the purpose of this class, and ideally a short code example, would be helpful for the docs.

Contributor Author

I have added the docstring.

data_silo: DataSilo,
epochs: int,
n_gpu: int,
device,
Member

please add type hints :)

Contributor Author

I have added the type hints. However, for lr_scheduler I needed to import a private class from pytorch. I'm not sure if that's desirable.

Member

You're referring to something like from torch.optim.lr_scheduler import _LRScheduler ?

Contributor Author

Yes, I am.

Member

If mypy doesn't complain about the annotation with _LRScheduler, which I think it doesn't, we can use it. Otherwise, type ignore is the next best solution here in my opinion.
@tholor What do you think?

Member

Yes, the import looks okay to me - if that's the type we expect here, it is what it is ;)
If mypy complains, we can drop it or add the ignore comment.
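The annotation under discussion would then look roughly like this (the parameter name and the Optional default are assumptions for illustration):

from typing import Optional
from torch.optim.lr_scheduler import _LRScheduler  # private, but it is the base class of PyTorch's LR schedulers

def configure(lr_scheduler: Optional[_LRScheduler] = None) -> None:  # hypothetical function, just to show the annotation
    ...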

:param max_grad_norm: Max gradient norm for clipping, default 1.0, set to None to disable
:param distillation_loss_weight: The weight of the distillation loss
:param distillation_loss: Specifies how teacher and student logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for KL divergence loss) or a callable loss function (needs to have named parameters student_logits and teacher_logits)
:param temperature: The temperature for distillation
Member

please describe briefly the effects of the params (e.g. a higher temperature results in ...)

Contributor Author

Okay, I have also added this description to the distil_from method.
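A sketch of what such descriptions could look like (the wording is illustrative, not the PR's final docstring):

:param temperature: The temperature for distillation. A higher temperature softens the teacher's output distribution, so the student also learns from the relative scores the teacher assigns to wrong answers; a temperature of 1 leaves the logits unchanged.
:param distillation_loss_weight: The weight of the distillation loss relative to the regular loss on the gold labels. A higher value makes the student match the teacher's logits more closely; a lower value puts more emphasis on the labeled training data.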

"""
Fine-tune a model on a QA dataset using distillation. You need to provide a teacher model that is already finetuned on the dataset
and a student model that will be trained using the teacher's logits. The idea of this is to increase the accuracy of a lightweight student model
using a more complex teacher.
Member

Can you give a short code example here including a reasonable combination of models (to show usage and clarify that the student can be another pretrained model)?

Contributor Author

Okay
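A sketch of the kind of example that could go into the docstring, using the model combination from the benchmark results above (the exact distil_from parameters are assumed to mirror train, and the data paths are placeholders):

from haystack.nodes import FARMReader

student = FARMReader(model_name_or_path="prajjwal1/bert-medium")
teacher = FARMReader(model_name_or_path="deepset/bert-large-uncased-whole-word-masking-squad2")
# Fine-tune the student on SQuAD-style data while distilling from the teacher's logits.
student.distil_from(teacher, data_dir="data/squad20", train_filename="train-v2.0.json",
                    temperature=5, distillation_loss_weight=1.0)
student.save(directory="my_distilled_model")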
