Model Distillation #1758

Merged 40 commits on Nov 26, 2021

Commits
7448d1b
initial commit
MichelBartels Nov 12, 2021
23c6542
Add latest docstring and tutorial changes
github-actions[bot] Nov 12, 2021
3f6495e
added comments and fixed bug
MichelBartels Nov 12, 2021
5a1785f
Merge branch 'distillation-refactored' of https://github.com/deepset-…
MichelBartels Nov 12, 2021
96bb479
fixed bugs, added benchmark and added documentation
MichelBartels Nov 15, 2021
f218a61
Add latest docstring and tutorial changes
github-actions[bot] Nov 15, 2021
d993a26
fix type: ignore comment
MichelBartels Nov 15, 2021
03b3aeb
Merge branch 'distillation-refactored' of https://github.com/deepset-…
MichelBartels Nov 15, 2021
ed3c4f2
fix logging in benchmark
MichelBartels Nov 15, 2021
80ea5aa
fixed distillation config
MichelBartels Nov 16, 2021
75cea33
Add latest docstring and tutorial changes
github-actions[bot] Nov 16, 2021
113f92f
added type annotations
MichelBartels Nov 16, 2021
4647562
Merge branch 'distillation-refactored' of https://github.com/deepset-…
MichelBartels Nov 16, 2021
3b2a021
fixed distillation loss calculation
MichelBartels Nov 16, 2021
3124ad6
added type annotations
MichelBartels Nov 17, 2021
2f46477
fixed distillation mse loss
MichelBartels Nov 17, 2021
5b2333b
improved model distillation benchmark config loading
MichelBartels Nov 17, 2021
3bb1b20
added temperature for model distillation
MichelBartels Nov 17, 2021
c045783
removed uncessary imports, added comments, added named parameter calls
MichelBartels Nov 17, 2021
143fafd
Add latest docstring and tutorial changes
github-actions[bot] Nov 17, 2021
c6439ee
added some more comments
MichelBartels Nov 17, 2021
1ae1937
Merge branch 'distillation-refactored' of https://github.com/deepset-…
MichelBartels Nov 17, 2021
86bc797
added distillation test
MichelBartels Nov 17, 2021
d3a3c16
fixed distillation test
MichelBartels Nov 18, 2021
98727b1
removed unnecessary import
MichelBartels Nov 18, 2021
e68d7a2
fix softmax dimension
MichelBartels Nov 23, 2021
42b8160
add grid search
MichelBartels Nov 24, 2021
acb755a
fix merge
MichelBartels Nov 24, 2021
57964dc
improved model distillation benchmark config
MichelBartels Nov 25, 2021
6bb3fcb
fixed model distillation hyperparameter search
MichelBartels Nov 25, 2021
6728aff
added doc strings and type hints for model distillation
MichelBartels Nov 26, 2021
212450b
Add latest docstring and tutorial changes
github-actions[bot] Nov 26, 2021
0167426
fixed type hints
MichelBartels Nov 26, 2021
a7f72d7
Merge branch 'distillation-refactored' of https://github.com/deepset-…
MichelBartels Nov 26, 2021
645c9b0
fixed type hints
MichelBartels Nov 26, 2021
b1fbaf3
fixed type hints
MichelBartels Nov 26, 2021
7f61f73
wrote out params instead of kwargs in DistillationDataSilo initializer
MichelBartels Nov 26, 2021
d1643b0
fixed type hints
MichelBartels Nov 26, 2021
786b1ae
fixed typo
MichelBartels Nov 26, 2021
3c90210
fixed typo
MichelBartels Nov 26, 2021
72 changes: 71 additions & 1 deletion docs/_src/api/api/reader.md
@@ -106,7 +106,7 @@ and that FARM includes no_answer in the sorted list of predictions.
#### train

```python
| train(data_dir: str, train_filename: str, dev_filename: Optional[str] = None, test_filename: Optional[str] = None, use_gpu: Optional[bool] = None, batch_size: int = 10, n_epochs: int = 2, learning_rate: float = 1e-5, max_seq_len: Optional[int] = None, warmup_proportion: float = 0.2, dev_split: float = 0, evaluate_every: int = 300, save_dir: Optional[str] = None, num_processes: Optional[int] = None, use_amp: str = None, checkpoint_root_dir: Path = Path("model_checkpoints"), checkpoint_every: Optional[int] = None, checkpoints_to_keep: int = 3)
| train(data_dir: str, train_filename: str, dev_filename: Optional[str] = None, test_filename: Optional[str] = None, use_gpu: Optional[bool] = None, batch_size: int = 10, n_epochs: int = 2, learning_rate: float = 1e-5, max_seq_len: Optional[int] = None, warmup_proportion: float = 0.2, dev_split: float = 0, evaluate_every: int = 300, save_dir: Optional[str] = None, num_processes: Optional[int] = None, use_amp: str = None, checkpoint_root_dir: Path = Path("model_checkpoints"), checkpoint_every: Optional[int] = None, checkpoints_to_keep: int = 3, caching: bool = False, cache_path: Path = Path("cache/data_silo"))
```

Fine-tune a model on a QA dataset. Options:
@@ -150,6 +150,76 @@ If any checkpoints are stored, a subsequent run of train() will resume training
checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created.
- `checkpoint_every`: save a train checkpoint after this many steps of training.
- `checkpoints_to_keep`: maximum number of train checkpoints to save.
- `caching`: Whether or not to use caching for the preprocessed dataset (see the example below)
- `cache_path`: Path to cache the preprocessed dataset

**Returns**:

None
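
To illustrate the two new caching parameters, here is a minimal, hypothetical usage sketch; the model name, dataset directory, and filenames are placeholders rather than values from this PR:
```python
from pathlib import Path
from haystack.nodes import FARMReader

reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2")  # placeholder model
reader.train(
    data_dir="squad2",                   # placeholder directory with SQuAD-style data
    train_filename="train.json",         # placeholder filename
    caching=True,                        # reuse the preprocessed dataset on later runs
    cache_path=Path("cache/data_silo"),  # where the preprocessed data silo is cached
)
```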

<a name="farm.FARMReader.distil_from"></a>
#### distil\_from

```python
| distil_from(teacher_model: "FARMReader", data_dir: str, train_filename: str, dev_filename: Optional[str] = None, test_filename: Optional[str] = None, use_gpu: Optional[bool] = None, student_batch_size: int = 10, teacher_batch_size: Optional[int] = None, n_epochs: int = 2, learning_rate: float = 1e-5, max_seq_len: Optional[int] = None, warmup_proportion: float = 0.2, dev_split: float = 0, evaluate_every: int = 300, save_dir: Optional[str] = None, num_processes: Optional[int] = None, use_amp: str = None, checkpoint_root_dir: Path = Path("model_checkpoints"), checkpoint_every: Optional[int] = None, checkpoints_to_keep: int = 3, caching: bool = False, cache_path: Path = Path("cache/data_silo"), distillation_loss_weight: float = 0.5, distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "kl_div", temperature: float = 1.0)
```

Fine-tune a model on a QA dataset using distillation. You need to provide a teacher model that is already fine-tuned on the dataset
and a student model that will be trained using the teacher's logits. The goal is to increase the accuracy of a lightweight student model
by using a more complex teacher.

**Example**
```python
student = FARMReader(model_name_or_path="prajjwal1/bert-medium")
teacher = FARMReader(model_name_or_path="deepset/bert-large-uncased-whole-word-masking-squad2")

student.distil_from(teacher, data_dir="squad2", train_filename="train.json", test_filename="dev.json",
learning_rate=3e-5, distillation_loss_weight=1.0, temperature=5)
```

Checkpoints can be stored by setting `checkpoint_every` to a custom number of steps.
If any checkpoints are stored, a subsequent run of train() will resume training from the latest available checkpoint.

**Arguments**:

- `teacher_model`: Model whose logits will be used to improve accuracy
- `data_dir`: Path to directory containing your training data in SQuAD style
- `train_filename`: Filename of training data
- `dev_filename`: Filename of dev / eval data
- `test_filename`: Filename of test data
- `dev_split`: Instead of specifying a dev_filename, you can also specify a ratio (e.g. 0.1) here
that gets split off from training data for eval.
- `use_gpu`: Whether to use GPU (if available)
- `student_batch_size`: Number of samples the student model receives in one batch for training
- `teacher_batch_size`: Number of samples the teacher model receives in one batch for distillation
- `n_epochs`: Number of iterations on the whole training data set
- `learning_rate`: Learning rate of the optimizer
- `max_seq_len`: Maximum text length (in tokens). Everything longer gets cut down.
- `warmup_proportion`: Proportion of training steps until maximum learning rate is reached.
Until that point the learning rate increases linearly; after that it decreases linearly again.
Options for different schedules are available in FARM.
- `evaluate_every`: Evaluate the model every X steps on the hold-out eval dataset
- `save_dir`: Path to store the final model
- `num_processes`: The number of processes for `multiprocessing.Pool` during preprocessing.
Set to value of 1 to disable multiprocessing. When set to 1, you cannot split away a dev set from train set.
Set to None to use all CPU cores minus one.
- `use_amp`: Optimization level of NVIDIA's automatic mixed precision (AMP). The higher the level, the faster the model.
Available options:
None (Don't use AMP)
"O0" (Normal FP32 training)
"O1" (Mixed Precision => Recommended)
"O2" (Almost FP16)
"O3" (Pure FP16).
See details on: https://nvidia.github.io/apex/amp.html
- `checkpoint_root_dir`: the Path of directory where all train checkpoints are saved. For each individual
checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created.
- `checkpoint_every`: save a train checkpoint after this many steps of training.
- `checkpoints_to_keep`: maximum number of train checkpoints to save.
- `caching`: Whether or not to use caching for the preprocessed dataset and teacher logits
- `cache_path`: Path to cache the preprocessed dataset and teacher logits
- `distillation_loss_weight`: The weight of the distillation loss. A higher weight means the teacher outputs are more important.
- `distillation_loss`: Specifies how teacher and student logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for KL divergence loss) or a callable loss function (needs to have the named parameters student_logits and teacher_logits)
- `temperature`: The temperature for distillation. A higher temperature softens the teacher's output distribution (less certainty), a lower temperature sharpens it, and a temperature of 1.0 leaves the outputs unchanged. See the sketch after this list for how these parameters typically interact.
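
The following minimal sketch shows how a temperature-scaled KL-divergence distillation loss is typically combined with the regular task loss. It illustrates the general technique rather than the exact implementation in this PR; in particular, the temperature-squared scaling is the common convention from the distillation literature and is an assumption here.
```python
import torch
import torch.nn.functional as F

def kl_div_distillation_loss(student_logits: torch.Tensor, teacher_logits: torch.Tensor,
                             temperature: float = 1.0) -> torch.Tensor:
    # Soften both distributions with the temperature before comparing them.
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    # KL divergence between the softened student and teacher distributions;
    # the temperature**2 factor (assumed convention) keeps gradient magnitudes comparable.
    return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * temperature ** 2

def combined_loss(task_loss: torch.Tensor, student_logits: torch.Tensor, teacher_logits: torch.Tensor,
                  distillation_loss_weight: float = 0.5, temperature: float = 1.0) -> torch.Tensor:
    # Weighted sum of the supervised QA loss and the distillation loss.
    distill = kl_div_distillation_loss(student_logits, teacher_logits, temperature)
    return (1 - distillation_loss_weight) * task_loss + distillation_loss_weight * distill
```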

**Returns**:

80 changes: 79 additions & 1 deletion haystack/modeling/data_handler/data_silo.py
@@ -4,6 +4,7 @@
import json
import logging
import random
import math
from contextlib import ExitStack
from functools import partial
from itertools import groupby
@@ -16,6 +17,9 @@
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from haystack.nodes import FARMReader
from haystack.modeling.data_handler.dataloader import NamedDataLoader
from haystack.modeling.data_handler.processor import Processor
from haystack.modeling.logger import MLFlowLogger as MlLogger
@@ -153,7 +157,8 @@ def _get_dataset(self, filename: Optional[Union[str, Path]], dicts: Optional[Lis
f"dictionaries to pytorch datasets."
)

results = map(partial(self._dataset_from_chunk, processor=self.processor), grouper(dicts, num_dicts)) # type: ignore
# temporary fix
results = map(partial(self._dataset_from_chunk, processor=self.processor), grouper(dicts, 1)) # type: ignore

datasets = []
problematic_ids_all = set()
@@ -722,3 +727,76 @@ def get_dict_checksum(payload_dict):
"""
checksum = hashlib.md5(json.dumps(payload_dict, sort_keys=True).encode("utf-8")).hexdigest()
return checksum

class DistillationDataSilo(DataSilo):
"""
This data silo runs a forward pass of the teacher model over the full data set for model distillation.
As this is done during preprocessing, it does not need to be repeated in each epoch and can be cached.
"""
def __init__(self, teacher_model: "FARMReader", teacher_batch_size: int, device: str, processor: Processor,
batch_size: int, eval_batch_size: Optional[int] = None, distributed: bool = False,
automatic_loading: bool = True, max_processes: int = 128, caching: bool = False, cache_path: Path = Path("cache/data_silo")):
self.teacher = teacher_model
self.teacher_batch_size = teacher_batch_size
self.device = device
max_processes = 1 # fix as long as multithreading is not working with teacher attribute
super().__init__(max_processes=max_processes, processor=processor, batch_size=batch_size, eval_batch_size=eval_batch_size,
distributed=distributed, automatic_loading=automatic_loading, caching=caching, cache_path=cache_path)

def _run_teacher(self, batch: List[List[torch.Tensor]], corresponding_chunks: List[int],
teacher_outputs: List[List[Tuple[torch.Tensor, ...]]], tensor_names: List[str]):
with torch.no_grad():
batch_transposed = zip(*batch) # transpose dimensions (from batch, features, ... to features, batch, ...)
batch_transposed_list = [torch.stack(b) for b in batch_transposed] # create tensors for each feature
batch_dict = {key: tensor.to(self.device) for key, tensor in zip(tensor_names, batch_transposed_list)} # create input dict
y = self.teacher.inferencer.model(**batch_dict)
y = [y.cpu() for y in y]

# grouping by chunk
for i, data in zip(corresponding_chunks, zip(*y)): # transpose back
teacher_outputs[i].append(data)
return

def _get_dataset(self, filename: Optional[Union[str, Path]], dicts: Optional[List[Dict]] = None):
concat_datasets, tensor_names = super()._get_dataset(filename, dicts)

batch = []
corresponding_chunks = [] # to be able to associate elements of batches with chunks (elements could be from multiple chunks)

teacher_outputs: List[List[Tuple[torch.Tensor, ...]]] = [] # list of teacher outputs grouped by chunk

# creating batches from chunks
for i, dataset in enumerate(tqdm(concat_datasets.datasets, desc="Doing forward pass on teacher model")):
teacher_outputs.append([])
for x in zip(*dataset.tensors): # loop through the samples of the chunk
batch.append(x)
corresponding_chunks.append(i)
if len(batch) == self.teacher_batch_size:
self._run_teacher(batch, corresponding_chunks, teacher_outputs, tensor_names) # doing forward pass on teacher model
batch = []
corresponding_chunks = []
if batch:
self._run_teacher(batch, corresponding_chunks, teacher_outputs, tensor_names)

# appending teacher outputs to original dataset
for dataset, teacher_output in zip(concat_datasets.datasets, teacher_outputs):
dataset.tensors += tuple(torch.stack(tensors) for tensors in zip(*teacher_output))
tensor_names.extend(["teacher_output_" + str(i) for i, _ in enumerate(zip(*teacher_output))])
concat_datasets = ConcatDataset(concat_datasets.datasets) # making sure metrics are updated
return concat_datasets, tensor_names

def _get_checksum(self):
"""
Get checksum based on a dict to ensure validity of cached DataSilo
"""
# the keys in this dict uniquely identify a given DataSilo
payload_dict = {
"train_filename": str(Path(self.processor.train_filename).absolute()),
"data_dir": str(self.processor.data_dir.absolute()),
"max_seq_len": self.processor.max_seq_len,
"dev_split": self.processor.dev_split,
"tasks": self.processor.tasks,
"teacher_name_or_path": self.teacher.pipeline_config["params"]["model_name_or_path"]
}
checksum = get_dict_checksum(payload_dict)
return checksum
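
For context, a hypothetical usage sketch of the new `DistillationDataSilo`, based only on the `__init__` signature above. The model names, batch sizes, and the `inferencer.processor` lookup are assumptions, and the processor is expected to already be configured with its data directory and train filename:
```python
from pathlib import Path
from haystack.nodes import FARMReader
from haystack.modeling.data_handler.data_silo import DistillationDataSilo

teacher = FARMReader(model_name_or_path="deepset/bert-large-uncased-whole-word-masking-squad2")
student = FARMReader(model_name_or_path="prajjwal1/bert-medium")

# The silo runs the teacher forward pass once during preprocessing, so the
# teacher logits can be cached together with the preprocessed dataset.
data_silo = DistillationDataSilo(
    teacher_model=teacher,
    teacher_batch_size=16,                   # assumed batch size for the teacher forward pass
    device="cuda",
    processor=student.inferencer.processor,  # assumed way to reuse the student's processor
    batch_size=10,
    caching=True,
    cache_path=Path("cache/data_silo"),
)
```
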
2 changes: 1 addition & 1 deletion haystack/modeling/training/__init__.py
@@ -1 +1 @@
from haystack.modeling.training.base import Trainer
from haystack.modeling.training.base import Trainer, DistillationTrainer