Add Knowledge distillation to Tutorial.md (#4)
update Tutorial.md with knowledge distillation tasks
ljshou committed Apr 28, 2019
1 parent eaef926 commit b68048f
Showing 1 changed file: Tutorial.md (60 additions, 0 deletions)
@@ -12,6 +12,11 @@
* [Task 4: Regression](#task-4)
* [Task 5: Sentiment Analysis](#task-5)
* [Task 6: Question Paraphrase](#task-6)
* [Task 7: Knowledge Distillation for Model Compression](#task-7)
1. [Compression for Query Binary Classifier](#task-7.1)
2. [Compression for Text Matching Model](#task-7.2)
3. [Compression for Slot Filling Model](#task-7.3)
4. [Compression for MRC Model](#task-7.4)
* [Advanced Usage](#advanced-usage)
* [Extra Feature Support](#extra-feature)
* [Learning Rate Decay](#lr-decay)
@@ -382,6 +387,61 @@ This task is to determine whether a pair of questions are semantically equivalent.
*Tips: the model file and the training log can be found in the JSON config file's outputs/save_base_dir.*

### <span id="task-7">Task 7: Knowledge Distillation for Model Compression</span>

Knowledge distillation is a common method for compressing models in order to speed up inference. Here are some reference papers:
- [Distilling the Knowledge in a Neural Network](https://arxiv.org/abs/1503.02531)
- [Model Compression with Multi-Task Knowledge Distillation for Web-scale Question Answering System](https://arxiv.org/abs/1904.09636)

#### <span id="task-7.1">7.1: Compression for Query Binary Classifier</span>
This task trains a lightweight query regression model (the student) to learn from a heavy teacher model, such as a BERT-based query classifier. Training minimizes the difference between the student model's output score and the teacher model's output score.
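
To make the objective concrete, here is a minimal PyTorch-style sketch of one distillation step. The student network, feature sizes, and dummy tensors are hypothetical; in NeuronBlocks the actual architecture and training loop are assembled from the JSON config, so this illustrates the idea rather than the toolkit's implementation.

```python
import torch
import torch.nn as nn

# Hypothetical student; the real architecture (e.g. BiLSTMAttn+TextCNN)
# is defined in the NeuronBlocks JSON config instead.
student = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 1))
mse = nn.MSELoss()

def distillation_step(query_features, teacher_scores, optimizer):
    """One training step: regress the student's output score onto the
    teacher's soft score for the same batch of queries."""
    optimizer.zero_grad()
    student_scores = student(query_features).squeeze(-1)
    loss = mse(student_scores, teacher_scores)  # the score difference to minimize
    loss.backward()
    optimizer.step()
    return loss.item()

optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)
features = torch.randn(32, 128)   # dummy query features for a batch of 32
teacher_scores = torch.rand(32)   # dummy teacher soft scores in [0, 1]
print(distillation_step(features, teacher_scores, optimizer))
```
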
- ***Dataset***
  *PROJECT_ROOT/dataset/knowledge_distillation/query_binary_classifier*:
    * *train.tsv* and *valid.tsv*: two columns, namely **Query** and **Score**.
    **Score** is the output score of the heavy teacher model (a fine-tuned BERT-base model), which serves as the soft label for the student model to learn as knowledge.
    * *test.tsv*: two columns, namely **Query** and **Label**.
    **Label** is a binary value, where 0 means negative and 1 means positive.

You can also substitute your own dataset for compression-task training; a sketch of the expected TSV layout follows below.
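
As a reference for that layout, the sketch below reads the two-column TSV files. The `load_tsv` helper and the assumption that the files carry no header row are mine, so adjust if your data differs.

```python
import csv

# Assumed tab-separated layout (no header row):
#   train.tsv / valid.tsv -> Query <TAB> Score  (teacher's soft score)
#   test.tsv              -> Query <TAB> Label  (0 = negative, 1 = positive)
def load_tsv(path, value_type=float):
    rows = []
    with open(path, encoding="utf-8") as f:
        for query, value in csv.reader(f, delimiter="\t"):
            rows.append((query, value_type(value)))
    return rows

train = load_tsv("dataset/knowledge_distillation/query_binary_classifier/train.tsv")
test = load_tsv("dataset/knowledge_distillation/query_binary_classifier/test.tsv",
                value_type=int)
```
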

- ***Usage***

1. Train student model
```bash
cd PROJECT_ROOT
python train.py --conf_path=model_zoo/nlp_tasks/knowledge_distillation/conf_kdqbc_bilstmattn_cnn.json
```

2. Test student model
```bash
cd PROJECT_ROOT
python test.py --conf_path=models/kdqbc_bilstmattn_cnn/train/conf_kdqbc_bilstmattn_cnn.json --previous_model_path models/kdqbc_bilstmattn_cnn/train/model.nb --predict_output_path models/kdqbc_bilstmattn_cnn/test/test.tsv --test_data_path dataset/knowledge_distillation/query_binary_classifier/test.tsv
```

3. Calculate the AUC metric (see the sketch after these steps)
```bash
cd PROJECT_ROOT
python tools/AUC.py --input_file models/kdqbc_bilstmattn_cnn/test/test.tsv --predict_index 2 --label_index 1
```

*Tips: you can try different models by running different JSON config files.*
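
For intuition about step 3, here is a minimal sketch of the metric itself. It assumes, as the flags above suggest, that the gold label sits in column 1 and the predicted score in column 2 of the test output, and it uses scikit-learn rather than the actual `tools/AUC.py` script.

```python
import csv
from sklearn.metrics import roc_auc_score

labels, scores = [], []
with open("models/kdqbc_bilstmattn_cnn/test/test.tsv", encoding="utf-8") as f:
    for row in csv.reader(f, delimiter="\t"):
        labels.append(int(row[1]))    # --label_index 1: gold 0/1 label
        scores.append(float(row[2]))  # --predict_index 2: student's score
print("AUC: %.4f" % roc_auc_score(labels, scores))
```
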

- ***Result***

The student model's AUC is very close to the teacher's, while its inference is roughly 3.5X~4X faster.

|Model|AUC|
|-----|---|
|Teacher|0.9112|
|Student-BiLSTMAttn+TextCNN (NeuronBlocks)|0.8941|

*Tips: the model file and the training log can be found in the JSON config file's outputs/save_base_dir.*
#### <span id="task-7.2">7.2: Compression for Text Matching Model (ongoing)</span>
#### <span id="task-7.3">7.3: Compression for Slot Filling Model (ongoing)</span>
#### <span id="task-7.4">7.4: Compression for MRC (ongoing)</span>
## <span id="advanced-usage">Advanced Usage</span>
After building a model, the next goal is to train it to achieve good performance, which depends on both an expressive model and model-training tricks. NeuronBlocks provides several such tricks.
