From 36a9f6f04622df277ce91991016210c1479935f9 Mon Sep 17 00:00:00 2001 From: ceci3 Date: Fri, 17 Sep 2021 16:07:22 +0800 Subject: [PATCH] update distill (#892) * polish distill --- demo/dygraph/dist/bert/README.md | 144 ++++++ demo/dygraph/dist/bert/distill_stage1.yaml | 20 + demo/dygraph/dist/bert/distill_stage2.yaml | 9 + demo/dygraph/dist/bert/run.sh | 20 + demo/dygraph/dist/bert/task_distill.py | 460 ++++++++++++++++++ paddleslim/dygraph/dist/__init__.py | 2 + paddleslim/dygraph/dist/distill.py | 319 ++++++------ paddleslim/dygraph/dist/distill_helpers.py | 41 ++ paddleslim/dygraph/dist/losses/__init__.py | 41 +- paddleslim/dygraph/dist/losses/basic_loss.py | 60 ++- .../dygraph/dist/losses/distillation_loss.py | 218 ++------- tests/dygraph/test_distill.py | 83 ++-- tests/dygraph/test_distillation_loss.py | 218 +++------ 13 files changed, 1057 insertions(+), 578 deletions(-) create mode 100644 demo/dygraph/dist/bert/README.md create mode 100644 demo/dygraph/dist/bert/distill_stage1.yaml create mode 100644 demo/dygraph/dist/bert/distill_stage2.yaml create mode 100644 demo/dygraph/dist/bert/run.sh create mode 100644 demo/dygraph/dist/bert/task_distill.py create mode 100644 paddleslim/dygraph/dist/distill_helpers.py diff --git a/demo/dygraph/dist/bert/README.md b/demo/dygraph/dist/bert/README.md new file mode 100644 index 0000000000000..d8f72b3792d5e --- /dev/null +++ b/demo/dygraph/dist/bert/README.md @@ -0,0 +1,144 @@ +# TinyBERT: Distilling BERT for Natural Language Understanding +以下是本例的简要目录结构及说明: +``` +. +├── task_distill.py # 在特定任务上下的蒸馏脚本 +└── README.md # 文档,本文件 +``` +## 简介 +本目录下的实验主要参考论文[《TinyBERT: Distilling BERT for Natural Language Understanding》](https://arxiv.org/abs/1909.10351)实现。 +TinyBERT中蒸馏的整体过程:首先进行通用蒸馏,然后用数据增强后的数据,在特定任务上进行蒸馏,本文主要进行了第二阶段的蒸馏,模型是利用第一阶段得到的通用小模型`tinybert-6l-768d-v2`进行初始化。 + +
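+
+A minimal sketch of that initialization (using the same PaddleNLP classes that `task_distill.py` below relies on; the task-dependent `num_classes` here is just an example for SST-2):
+
+```python
+from paddlenlp.transformers import TinyBertForSequenceClassification, TinyBertTokenizer
+
+# Initialize the stage-2 student from the stage-1 general-distillation checkpoint.
+student = TinyBertForSequenceClassification.from_pretrained(
+    "tinybert-6l-768d-v2", num_classes=2)  # num_classes follows the downstream GLUE task
+tokenizer = TinyBertTokenizer.from_pretrained("tinybert-6l-768d-v2")
+```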

+<p align="center">TinyBERT蒸馏流程图 (TinyBERT distillation pipeline)</p>
+ + +在模型蒸馏中,较大的模型(在本例中是BERT base)通常被称为教师模型,较小的模型(在本例中是层数为6的BERT,下文都称TinyBERT6)通常被称为学生模型。 +知识的蒸馏通常是通过让学生模型学习相关的蒸馏相损失函数实现,在本实验中,蒸馏的学习目标由两个部分组成,分别是中间层的蒸馏损失和预测层的蒸馏损失。其中,中间层的蒸馏包括对Embedding层的蒸馏、对每个Transformer layer输出的蒸馏、以及对每个Transformer中attention矩阵(softmax之前的结果)的蒸馏,三者均采用的是均方误差损失函数。而预测层蒸馏的学习目标则是学生模型输出的logits和教师模型输出的logits的交叉熵损失。 + +由于教师模型是12层,学生模型的层数少于教师模型的层数,因此需要选择一种layer mapping的方式。论文中采用了一种固定的映射方式,当学生模型的层数为教师模型的1/2时,学生第i层的attention矩阵,需要学习教师的第2i+1层的attention矩阵,Transformer layer输出同理。 + +实验分为两个大的训练过程:先对BERT-base进行微调,得到教师模型,再进行蒸馏的训练。其中,蒸馏过程也分为两个步骤:先对中间层进行蒸馏多个epochs(论文中针对具体任务可能是10、20或者30个),再对预测层蒸馏3个epochs。 + +需要注意的是,在使用不同教师模型时,`tinybert-6l-768d-v2`、`tinybert-4l-312d-v2`这两个v2版本的预训练模型中开放的从学生embedding输出、transformer中间层输出到教师相应输出的转换矩阵是每层独立的,而其他的`tinybert-6l-768d`、`tinybert-4l-312d`、`tinybert-6l-768d-zh`、`tinybert-4l-312-zh`则是多层之间的参数共用一个转换矩阵的。 + +### 安装PaddleNLP和Paddle +本教程基于PaddleNLP中BERT模型进行压缩,依赖PaddleNLP和Paddle。 + +```shell +pip install paddlenlp +pip install paddlepaddle_gpu +``` + +## 数据、预训练模型介绍及获取 + +本实验使用GLUE中数据集中的训练集作为训练语料,用数据集中的验证集评估模型的效果。 + +运行本目录下的实验,数据集会被自动下载到`paddlenlp.utils.env.DATA_HOME` 路径下,例如在linux系统下,对于GLUE中的QQP数据集,默认存储路径是`~/.paddlenlp/datasets/Glue/QQP`。 + +对于BERT的fine-tuning任务,本实验中使用了预训练模型`bert-base-uncased`。同样,这几个模型在训练时会被自动下载到`paddlenlp.utils.env.MODEL_HOME`路径下。例如,对于`bert-base-uncased`模型,在linux系统下,会被下载到`~/.paddlenlp/models/bert-base-uncased`下。 + +## 蒸馏实验过程 + +### 对BERT Fine-tuning得到教师模型 +首先需要对Pretrain-Model在实际的下游任务上进行Fine-tuning,得到需要压缩的模型。Fine-tuning流程参考[Fine-tuning教程](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/examples/language_model/bert/README.md) + +训练完成之后,可将训练效果最好的模型保存在本项目下的`pretrained_models/$TASK_NAME/`下。模型目录下有`model_config.json`, `model_state.pdparams`, `tokenizer_config.json`及`vocab.txt`这几个文件。 + +### 对TinyBERT在特定任务下蒸馏 + +先蒸馏中间层: + +```shell +export CUDA_VISIBLE_DEVICES=0 +export TASK_NAME=SST-2 +export TEACHER_DIR=./pretrained_models/SST-2/best_model_610 + +python task_distill.py \ + --model_type tinybert \ + --student_model_name_or_path tinybert-6l-768d-v2 \ + --task_name $TASK_NAME \ + --intermediate_distill \ + --max_seq_length 64 \ + --batch_size 32 \ + --T 1 \ + --teacher_model_type bert \ + --teacher_path $TEACHER_DIR \ + --learning_rate 5e-5 \ + --num_train_epochs 20 \ + --logging_steps 10 \ + --save_steps 10 \ + --output_dir ./tmp/$TASK_NAME/ \ + --distill_config ./distill_stage1.yaml \ + --device gpu + +``` + +其中参数释义如下: + +- `model_type` 学生模型类型,默认且目前仅支持tinybert。 +- `student_model_name_or_path` 中间层蒸馏后,学生模型存放的目录 +- `distill_config` 蒸馏配置文件 +- `max_seq_length` 表示最大句子长度,超过该长度将被截断。默认:128 +- `T` softmax的温度,用于对softmax做平滑,在训练中起到放大负标签效果的作用。默认:1 +- `teacher_model_type` 教师模型的类型,默认且目前仅支持bert +- `teacher_path` 教师Fine-tuned模型的目录 +- `output_dir` 学生模型存放的目录 +- `device` 表示运行该程序的设备,默认是gpu + +然后对预测层进行蒸馏: + +```shell + +export TEACHER_DIR=../pretrained_models/SST-2/best_model_610 + +python task_distill.py \ + --model_type tinybert \ + --student_model_name_or_path tmp/TASK_NAME best_inter_model \ + --task_name $TASK_NAME \ + --max_seq_length 64 \ + --batch_size 32 \ + --T 1 \ + --teacher_model_type bert \ + --teacher_path $TEACHER_DIR \ + --learning_rate 3e-5 \ + --num_train_epochs 3 \ + --logging_steps 10 \ + --save_steps 10 \ + --output_dir ./tmp/$TASK_NAME/ \ + --distill_config ./distill_stage2.yaml \ + --device gpu + +``` +其中参数释义如下: + +所有参数说明同上。 + +### 实验中使用的超参数 + +| | SST-2 | QQP | MRPC | CoLA | RTE | MNLI | QNLI | +| -------------------------------- | --------- | --------- | --------- | --------- | --------- | --------- | --------- | +| batch_size | 32 | 32 | 32 | 32 | 32 
| 32 | 32 | +| max_seq_length | 64 | 128 | 128 | 64 | 128 | 128 | 128 | +| max_epochs_of_intermediate_layer | 20 | 10 | 20 | 50 | 20 | 10 | 10 | +| max_epochs_of_prediction_layer | 3 | 3 | 3 | 3 | 3 | 3 | 3 | +| learning_rate(inter/pred) | 5e-5/3e-5 | 5e-5/3e-5 | 5e-5/3e-5 | 5e-5/3e-5 | 5e-5/3e-5 | 5e-5/3e-5 | 5e-5/3e-5 | + + + +## 蒸馏实验结果 + +本文档的实验基于TinyBERT的6层、hidden_size为768的通用蒸馏得到的模型,用未使用数据增强的原始数据集训练,并基于验证集进行评价。得到以下实验结果: + + +| | SST-2 | QQP(acc/f1) | MRPC(acc/f1) | CoLA | RTE | MNLI-m | MNLI-mm | QNLI | +| ----------------- | ----- | ----------- | ------------ | ----- | ----- | ------ | ------- | ----- | +| BERT-base | 93.00 | 90.58/87.35 | 88.23/91.67 | 59.56 | 73.65 | 84.42 | 84.83 | 91.78 | +| TinyBERT(6l-768d) | 93.00 | 91.13/88.20 | 88.48/91.91 | 52.64 | 72.94 | 84.57 | 84.63 | 91.36 | + + +## 参考文献 + +Jiao X, Yin Y, Shang L, et al. [TinyBERT: Distilling BERT for Natural Language Understanding](https://arxiv.org/abs/1909.10351)[J]. arXiv preprint arXiv:1909.10351v5, 2020. diff --git a/demo/dygraph/dist/bert/distill_stage1.yaml b/demo/dygraph/dist/bert/distill_stage1.yaml new file mode 100644 index 0000000000000..1899721a50234 --- /dev/null +++ b/demo/dygraph/dist/bert/distill_stage1.yaml @@ -0,0 +1,20 @@ +- DistillConfig: + loss_function: MSELoss + model_name_pairs: + - - student_0 + - teacher_0 + weight: 1.0 + - layers: + - layers_name: ['tinybert.embeddings', 'bert.embeddings'] + - layers_name: ['tinybert.encoder.layers.0', 'bert.encoder.layers.1'] + - layers_name: ['tinybert.encoder.layers.1', 'bert.encoder.layers.3'] + - layers_name: ['tinybert.encoder.layers.2', 'bert.encoder.layers.5'] + - layers_name: ['tinybert.encoder.layers.3', 'bert.encoder.layers.7'] + - layers_name: ['tinybert.encoder.layers.4', 'bert.encoder.layers.9'] + - layers_name: ['tinybert.encoder.layers.5', 'bert.encoder.layers.11'] + - layers_name: ['tinybert.encoder.layers.0.self_attn', 'bert.encoder.layers.1.self_attn'] + - layers_name: ['tinybert.encoder.layers.1.self_attn', 'bert.encoder.layers.3.self_attn'] + - layers_name: ['tinybert.encoder.layers.2.self_attn', 'bert.encoder.layers.5.self_attn'] + - layers_name: ['tinybert.encoder.layers.3.self_attn', 'bert.encoder.layers.7.self_attn'] + - layers_name: ['tinybert.encoder.layers.4.self_attn', 'bert.encoder.layers.9.self_attn'] + - layers_name: ['tinybert.encoder.layers.5.self_attn', 'bert.encoder.layers.11.self_attn'] diff --git a/demo/dygraph/dist/bert/distill_stage2.yaml b/demo/dygraph/dist/bert/distill_stage2.yaml new file mode 100644 index 0000000000000..6d448a78f05a9 --- /dev/null +++ b/demo/dygraph/dist/bert/distill_stage2.yaml @@ -0,0 +1,9 @@ +- DistillConfig: + loss_function: CELoss + model_name_pairs: + - - student_0 + - teacher_0 + weight: 1.0 + - layers: + - layers_name: ['classifier', 'classifier'] + temperature: 1.0 diff --git a/demo/dygraph/dist/bert/run.sh b/demo/dygraph/dist/bert/run.sh new file mode 100644 index 0000000000000..58e07166765fd --- /dev/null +++ b/demo/dygraph/dist/bert/run.sh @@ -0,0 +1,20 @@ +export CUDA_VISIBLE_DEVICES=0 +export TASK_NAME=SST-2 +export TEACHER_DIR=/root/work/Distill_PaddleSlim/PaddleNLP/examples/model_compression/tinybert/best_model_610 + +python3.7 task_distill.py \ + --model_type tinybert \ + --student_model_name_or_path tinybert-6l-768d-v2 \ + --task_name $TASK_NAME \ + --intermediate_distill \ + --max_seq_length 64 \ + --batch_size 32 \ + --T 1 \ + --teacher_model_type bert \ + --teacher_path $TEACHER_DIR \ + --learning_rate 5e-5 \ + --num_train_epochs 20 \ + --logging_steps 10 \ + --save_steps 10 \ 
+ --output_dir ./tmp/$TASK_NAME/ \ + --device gpu diff --git a/demo/dygraph/dist/bert/task_distill.py b/demo/dygraph/dist/bert/task_distill.py new file mode 100644 index 0000000000000..b76d453edbfe6 --- /dev/null +++ b/demo/dygraph/dist/bert/task_distill.py @@ -0,0 +1,460 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import os +import sys +import random +import time +import math +from functools import partial + +import numpy as np +import paddle +from paddle.io import DataLoader +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.metric import Accuracy + +from paddlenlp.datasets import load_dataset +from paddlenlp.data import Stack, Tuple, Pad, Dict +from paddlenlp.data.sampler import SamplerHelper +from paddlenlp.metrics import AccuracyAndF1, Mcc, PearsonAndSpearman +import paddlenlp.transformers as T +from paddleslim import Distill + +FORMAT = '%(asctime)s-%(levelname)s: %(message)s' +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) + +METRIC_CLASSES = { + "cola": Mcc, + "sst-2": Accuracy, + "mrpc": AccuracyAndF1, + "sts-b": PearsonAndSpearman, + "qqp": AccuracyAndF1, + "mnli": Accuracy, + "qnli": Accuracy, + "rte": Accuracy, +} + +MODEL_CLASSES = { + "bert": (T.BertForSequenceClassification, T.BertTokenizer), + "tinybert": (T.TinyBertForSequenceClassification, T.TinyBertTokenizer), +} + + +def parse_args(): + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument( + "--task_name", + default=None, + type=str, + required=True, + help="The name of the task to train selected in the list: " + + ", ".join(METRIC_CLASSES.keys()), ) + parser.add_argument( + "--model_type", + default="tinybert", + type=str, + required=True, + help="Model type selected in the list: " + + ", ".join(MODEL_CLASSES.keys()), ) + parser.add_argument( + "--teacher_model_type", + default="bert", + type=str, + required=True, + help="Model type selected in the list: " + + ", ".join(MODEL_CLASSES.keys()), ) + parser.add_argument( + "--student_model_name_or_path", + default=None, + type=str, + required=True, + help="Path to pre-trained model or shortcut name selected in the list: " + + ", ".join( + sum([ + list(classes[-1].pretrained_init_configuration.keys()) + for classes in MODEL_CLASSES.values() + ], [])), ) + parser.add_argument( + "--distill_config", + default=None, + type=str, + help="distill config file path") + parser.add_argument( + "--teacher_path", + default=None, + type=str, + required=True, + help="Path to pre-trained model.") + parser.add_argument( + "--output_dir", + default=None, + type=str, + required=True, + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--glue_dir", + default="/root/.paddlenlp/datasets/Glue/", + type=str, + required=False, + help="The Glue directory.", ) + parser.add_argument( + "--max_seq_length", + default=128, + type=int, 
+ help="The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded.", ) + parser.add_argument( + "--learning_rate", + default=1e-4, + type=float, + help="The initial learning rate for Adam.") + parser.add_argument( + "--num_train_epochs", + default=3, + type=int, + help="Total number of training epochs to perform.", ) + parser.add_argument( + "--logging_steps", + type=int, + default=100, + help="Log every X updates steps.") + parser.add_argument( + "--save_steps", + type=int, + default=100, + help="Save checkpoint every X updates steps.") + parser.add_argument( + "--batch_size", + default=32, + type=int, + help="Batch size per GPU/CPU for training.", ) + parser.add_argument( + "--T", + default=1, + type=int, + help="Temperature for softmax", ) + parser.add_argument( + "--use_aug", + action="store_true", + help="Whether to use augmentation data to train.", ) + parser.add_argument( + "--intermediate_distill", + action="store_true", + help="Whether distilling intermediate layers. If False, it means prediction layer distillation.", + ) + parser.add_argument( + "--weight_decay", + default=0.0, + type=float, + help="Weight decay if we apply some.") + parser.add_argument( + "--warmup_steps", + default=0, + type=int, + help="Linear warmup over warmup_steps. If > 0: Override warmup_proportion" + ) + parser.add_argument( + "--warmup_proportion", + default=0.1, + type=float, + help="Linear warmup proportion over total steps.") + parser.add_argument( + "--adam_epsilon", + default=1e-6, + type=float, + help="Epsilon for Adam optimizer.") + parser.add_argument( + "--max_steps", + default=-1, + type=int, + help="If > 0: set total number of training steps to perform. Override num_train_epochs.", + ) + parser.add_argument( + "--seed", default=42, type=int, help="random seed for initialization") + parser.add_argument( + "--device", + default="gpu", + type=str, + help="The device to select to train the model, is must be cpu/gpu/xpu.") + args = parser.parse_args() + return args + + +def set_seed(args): + # Use the same data seed(for data shuffle) for all procs to guarantee data + # consistency after sharding. + random.seed(args.seed) + np.random.seed(args.seed) + # Maybe different op seeds(for dropout) for different procs is better. 
By: + # `paddle.seed(args.seed + paddle.distributed.get_rank())` + paddle.seed(args.seed) + + +@paddle.no_grad() +def evaluate(model, metric, data_loader): + model.eval() + metric.reset() + for batch in data_loader: + input_ids, segment_ids, labels = batch + logits = model(input_ids, segment_ids) + correct = metric.compute(logits, labels) + metric.update(correct) + res = metric.accumulate() + if isinstance(metric, AccuracyAndF1): + print( + "acc: %s, precision: %s, recall: %s, f1: %s, acc and f1: %s, " % ( + res[0], + res[1], + res[2], + res[3], + res[4], ), + end='') + elif isinstance(metric, Mcc): + print("mcc: %s, " % (res[0]), end='') + elif isinstance(metric, PearsonAndSpearman): + print( + "pearson: %s, spearman: %s, pearson and spearman: %s, " % + (res[0], res[1], res[2]), + end='') + else: + print("acc: %s, " % (res), end='') + model.train() + return res[0] if isinstance(metric, (AccuracyAndF1, Mcc, + PearsonAndSpearman)) else res + + +def convert_example(example, + tokenizer, + label_list, + max_seq_length=512, + is_test=False): + """convert a glue example into necessary features""" + if not is_test: + # `label_list == None` is for regression task + label_dtype = "int64" if label_list else "float32" + # Get the label + label = example['labels'] + label = np.array([label], dtype=label_dtype) + # Convert raw text to feature + if (int(is_test) + len(example)) == 2: + example = tokenizer(example['sentence'], max_seq_len=max_seq_length) + else: + example = tokenizer( + example['sentence1'], + text_pair=example['sentence2'], + max_seq_len=max_seq_length) + + if not is_test: + return example['input_ids'], example['token_type_ids'], label + else: + return example['input_ids'], example['token_type_ids'] + + +def do_train(args): + paddle.set_device(args.device) + if paddle.distributed.get_world_size() > 1: + paddle.distributed.init_parallel_env() + + set_seed(args) + + args.task_name = args.task_name.lower() + metric_class = METRIC_CLASSES[args.task_name] + args.model_type = args.model_type.lower() + model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + if args.use_aug: + aug_data_file = os.path.join( + os.path.join(args.glue_dir, args.task_name), "train_aug.tsv"), + train_ds = load_dataset( + 'glue', args.task_name, data_files=aug_data_file) + else: + train_ds = load_dataset('glue', args.task_name, splits='train') + tokenizer = tokenizer_class.from_pretrained(args.student_model_name_or_path) + + trans_func = partial( + convert_example, + tokenizer=tokenizer, + label_list=train_ds.label_list, + max_seq_length=args.max_seq_length) + train_ds = train_ds.map(trans_func, lazy=True) + train_batch_sampler = paddle.io.DistributedBatchSampler( + train_ds, batch_size=args.batch_size, shuffle=True) + batchify_fn = lambda samples, fn=Tuple( + Pad(axis=0, pad_val=tokenizer.pad_token_id), # input + Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # segment + Stack(dtype="int64" if train_ds.label_list else "float32") # label + ): fn(samples) + train_data_loader = DataLoader( + dataset=train_ds, + batch_sampler=train_batch_sampler, + collate_fn=batchify_fn, + num_workers=0, + return_list=True) + if args.task_name == "mnli": + dev_ds_matched, dev_ds_mismatched = load_dataset( + 'glue', args.task_name, splits=["dev_matched", "dev_mismatched"]) + + dev_ds_matched = dev_ds_matched.map(trans_func, lazy=True) + dev_ds_mismatched = dev_ds_mismatched.map(trans_func, lazy=True) + dev_batch_sampler_matched = paddle.io.BatchSampler( + dev_ds_matched, batch_size=args.batch_size, shuffle=False) + 
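# MNLI provides matched and mismatched dev splits; a separate DataLoader is built for each and both are evaluated below. +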
dev_data_loader_matched = DataLoader( + dataset=dev_ds_matched, + batch_sampler=dev_batch_sampler_matched, + collate_fn=batchify_fn, + num_workers=0, + return_list=True) + dev_batch_sampler_mismatched = paddle.io.BatchSampler( + dev_ds_mismatched, batch_size=args.batch_size, shuffle=False) + dev_data_loader_mismatched = DataLoader( + dataset=dev_ds_mismatched, + batch_sampler=dev_batch_sampler_mismatched, + collate_fn=batchify_fn, + num_workers=0, + return_list=True) + else: + dev_ds = load_dataset('glue', args.task_name, splits='dev') + dev_ds = dev_ds.map(trans_func, lazy=True) + dev_batch_sampler = paddle.io.BatchSampler( + dev_ds, batch_size=args.batch_size, shuffle=False) + dev_data_loader = DataLoader( + dataset=dev_ds, + batch_sampler=dev_batch_sampler, + collate_fn=batchify_fn, + num_workers=0, + return_list=True) + + num_classes = 1 if train_ds.label_list == None else len(train_ds.label_list) + student = model_class.from_pretrained( + args.student_model_name_or_path, num_classes=num_classes) + teacher_model_class, _ = MODEL_CLASSES[args.teacher_model_type] + teacher = teacher_model_class.from_pretrained( + args.teacher_path, num_classes=num_classes) + teacher.eval() + + if paddle.distributed.get_world_size() > 1: + student = paddle.DataParallel(student, find_unused_parameters=True) + teacher = paddle.DataParallel(teacher, find_unused_parameters=True) + + if args.max_steps > 0: + num_training_steps = args.max_steps + num_train_epochs = math.ceil(num_training_steps / + len(train_data_loader)) + else: + num_training_steps = len(train_data_loader) * args.num_train_epochs + num_train_epochs = args.num_train_epochs + + warmup = args.warmup_steps if args.warmup_steps > 0 else args.warmup_proportion + + lr_scheduler = T.LinearDecayWithWarmup(args.learning_rate, + num_training_steps, warmup) + + # Generate parameter names needed to perform weight decay. + # All bias and LayerNorm parameters are excluded. 
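+ # `apply_decay_param_fun` passed to AdamW below restricts weight decay to the parameter names collected here.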
+ decay_params = [ + p.name for n, p in student.named_parameters() + if not any(nd in n for nd in ["bias", "norm"]) + ] + optimizer = paddle.optimizer.AdamW( + learning_rate=lr_scheduler, + beta1=0.9, + beta2=0.999, + epsilon=args.adam_epsilon, + parameters=student.parameters(), + weight_decay=args.weight_decay, + apply_decay_param_fun=lambda x: x in decay_params) + + metric = metric_class() + + pad_token_id = 0 + global_step = 0 + tic_train = time.time() + best_res = 0.0 + + assert os.path.exists( + args.distill_config), "distill file {} not exist.".format( + args.distill_config) + distill_model = Distill( + args.distill_config, student_models=[student], + teacher_models=[teacher]) + + for epoch in range(num_train_epochs): + for step, batch in enumerate(train_data_loader): + global_step += 1 + input_ids, segment_ids, labels = batch + loss, _, _ = distill_model(input_ids, segment_ids) + + loss.backward() + optimizer.step() + lr_scheduler.step() + optimizer.clear_grad() + if global_step % args.logging_steps == 0: + print( + "global step %d/%d, epoch: %d, batch: %d, rank_id: %s, loss: %f, lr: %.10f, speed: %.4f step/s" + % (global_step, num_training_steps, epoch, step, + paddle.distributed.get_rank(), loss, optimizer.get_lr(), + args.logging_steps / (time.time() - tic_train))) + tic_train = time.time() + if global_step % args.save_steps == 0 or global_step == num_training_steps: + tic_eval = time.time() + if args.task_name == "mnli": + res = evaluate(student, metric, dev_data_loader_matched) + evaluate(student, metric, dev_data_loader_mismatched) + print("eval done total : %s s" % (time.time() - tic_eval)) + else: + res = evaluate(student, metric, dev_data_loader) + print("eval done total : %s s" % (time.time() - tic_eval)) + if (best_res < res and global_step < num_training_steps or + global_step == num_training_steps + ) and paddle.distributed.get_rank() == 0: + if global_step < num_training_steps: + output_dir = os.path.join(args.output_dir, + "distill_model_%d.pdparams" % + (global_step)) + else: + output_dir = os.path.join( + args.output_dir, "distill_model_final.pdparams") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + # Need better way to get inner model of DataParallel + model_to_save = student._layers if isinstance( + student, paddle.DataParallel) else student + model_to_save.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + best_res = res + + if global_step >= num_training_steps: + return + + +def print_arguments(args): + """print arguments""" + print('----------- Configuration Arguments -----------') + for arg, value in sorted(vars(args).items()): + print('%s: %s' % (arg, value)) + print('------------------------------------------------') + + +if __name__ == "__main__": + args = parse_args() + print_arguments(args) + do_train(args) diff --git a/paddleslim/dygraph/dist/__init__.py b/paddleslim/dygraph/dist/__init__.py index e98e53df4d1ed..908380630023a 100644 --- a/paddleslim/dygraph/dist/__init__.py +++ b/paddleslim/dygraph/dist/__init__.py @@ -14,7 +14,9 @@ from . import distill from .distill import * +from .distill_helpers import * __all__ = [] __all__ += distill.__all__ +__all__ += distill_helpers.__all__ diff --git a/paddleslim/dygraph/dist/distill.py b/paddleslim/dygraph/dist/distill.py index 029230a785138..d067cd63a7cc6 100644 --- a/paddleslim/dygraph/dist/distill.py +++ b/paddleslim/dygraph/dist/distill.py @@ -17,207 +17,200 @@ from collections import namedtuple import paddle.nn as nn from . 
import losses +from .losses.basic_loss import BASIC_LOSS +from .distill_helpers import yaml2config -__all__ = ['Distill', 'AdaptorBase'] +__all__ = ['Distill'] class LayerConfig: + """ The key of config can be set""" + def __init__(self, - s_feature_idx, - t_feature_idx, - feature_type, + model_name_pairs, + layers_name, loss_function, weight=1.0, - align=False, - align_shape=None): - self.s_feature_idx = s_feature_idx - self.t_feature_idx = t_feature_idx - self.feature_type = feature_type - if loss_function in ['l1', 'l2', 'smooth_l1']: - self.loss_function = 'DistillationDistanceLoss' - elif loss_function in ['dml']: - self.loss_function = 'DistillationDMLLoss' - elif loss_function in ['rkl']: - self.loss_function = 'DistillationRKDLoss' - elif hasattr(losses, loss_function): - self.loss_function = loss_function - else: - raise NotImplementedError("loss function is not support!!!") + temperature=1.0, + align_params=None, + **loss_params): + self.model_name_pairs = model_name_pairs + self.layers_name = layers_name + if loss_function not in BASIC_LOSS.module_dict: + raise NotImplementedError("loss function {} is not support. " + "Support loss including {}".format( + loss_function, + BASIC_LOSS.module_dict.keys())) + self.loss_function = loss_function self.weight = weight - self.align = align - self.align_shape = align_shape - - -class AdaptorBase: - def __init__(self, model): - self.model = model - self.add_tensor = False - - def _get_activation(self, outs, name): + self.temperature = temperature + self.align_params = align_params + for k, v in loss_params.items(): + setattr(self, k, v) + + +def _add_hooks(model, outs, hook_layers_name): + """ + Get output by layer name. + models(nn.Layer): model need to be add hook. + outs(dict): save the middle outputs of model according to the name. + hook_layers_name(list): name of middle layers. + """ + + def _get_activation(outs, name): + ### TODO: need to support get input tensor + #outs[name] = {} def get_output_hook(layer, input, output): + #outs[name]["output"] = output + #outs[name]["input"] = input outs[name] = output return get_output_hook - def _add_distill_hook(self, outs, mapping_layers_name, layers_type): - """ - Get output by layer name. - outs(dict): save the middle outputs of model according to the name. - mapping_layers(list): name of middle layers. - layers_type(list): type of the middle layers to calculate distill loss. - """ - - ### TODO: support DP model - for idx, (n, m) in enumerate(self.model.named_sublayers()): - if n in mapping_layers_name: - midx = mapping_layers_name.index(n) - m.register_forward_post_hook( - self._get_activation(outs, layers_type[midx])) - - def mapping_layers(self): - raise NotImplementedError("function mapping_layers is not implemented") + ### TODO: support DP model + for idx, (n, m) in enumerate(model.named_sublayers()): + if n in hook_layers_name: + m.register_forward_post_hook(_get_activation(outs, n)) class Distill(nn.Layer): - ### TODO: support list of student model and teacher model - def __init__(self, distill_configs, student_models, teacher_models, - adaptors_S, adaptors_T): - super(Distill, self).__init__() - assert student_models.training, "The student model should be eval mode." + """ + Distill API. + distill_configs(list(dict) | path): the list of distill config. + student_models(list(nn.Layer)): the list of student model, the state of student model must be training mode. + teacher_models(list(nn.Layer)): the list of teacher model, the state of student model must be evaluate mode. 
+ return_model_outputs(bool): whether to return model output. Default: True. + """ - self._distill_configs = distill_configs + def __init__(self, + distill_configs, + student_models, + teacher_models, + return_model_outputs=True): + super(Distill, self).__init__() + if isinstance(student_models, nn.Layer): + student_models = [student_models] + if isinstance(teacher_models, nn.Layer): + teacher_models = [teacher_models] + for student_model in student_models: + assert student_model.training, "The student model should not be eval mode." + for teacher_model in teacher_models: + assert teacher_model.training is False, "The teacher model should be eval mode." + + if isinstance(distill_configs, list): + self._distill_configs = distill_configs + elif os.path.exists(distill_configs): + if distill_configs.endswith(".yaml"): + self._distill_configs = yaml2config(distill_configs) + else: + raise NotImplementedError("distill config file type error!") + else: + raise NotImplementedError("distill config error!") self._student_models = student_models self._teacher_models = teacher_models - self._adaptors_S = adaptors_S(self._student_models) - self._adaptors_T = adaptors_T(self._teacher_models) + self._return_model_outputs = return_model_outputs - self.stu_outs_dict, self.tea_outs_dict = self._prepare_outputs() - - self.configs = [] + self._loss_config_list = [] for c in self._distill_configs: - self.configs.append(LayerConfig(**c).__dict__) + self._transpose_config(c) - self.distill_idx = self._get_distill_idx() - - self._loss_config_list = [] - for c in self.configs: - loss_config = {} - loss_config[str(c['loss_function'])] = {} - loss_config[str(c['loss_function'])]['weight'] = c['weight'] - loss_config[str(c['loss_function'])]['key'] = c[ - 'feature_type'] + '_' + str(c['s_feature_idx']) + '_' + str(c[ - 't_feature_idx']) - ### TODO: support list of student models and teacher_models - loss_config[str(c['loss_function'])][ - 'model_name_pairs'] = [['student', 'teacher']] - self._loss_config_list.append(loss_config) + self._hook_layers = self._extract_hook_position() # use self._loss_config_list to create all loss object self.distill_loss = losses.CombinedLoss(self._loss_config_list) + self._output_tensor_dict = self._prepare_outputs() + + def parameters(self): + params = [] + for s_model in self._student_models: + params.extend(s_model.parameters()) + return params + + def _extract_hook_position(self): + """ extrat hook position according to config""" + model_hook_layers = {} + for config in self._loss_config_list: + model_name_pairs = config['model_name_pairs'] + layers_name = config['layers_name'] + for model_name_pair in model_name_pairs: + for idx, model_name in enumerate(model_name_pair): + if model_name not in model_hook_layers: + model_hook_layers[model_name] = [layers_name[idx]] + else: + model_hook_layers[model_name].append(layers_name[idx]) + for model_name, hook_layers in model_hook_layers.items(): + model_hook_layers[model_name] = list(set(hook_layers)) + return model_hook_layers + + def _transpose_config(self, config): + """ Transpose config to loss needed """ + global_config = {} + if 'model_name_pairs' not in config: + global_config['model_name_pairs'] = [['student_0', 'teacher_0']] + else: + if isinstance(config['model_name_pairs'][0], str): + config['model_name_pairs'] = [config['model_name_pairs']] + global_config['model_name_pairs'] = config['model_name_pairs'] + config.pop('model_name_pairs') + + for key in config.keys(): + if key != 'layers': + global_config[key] = config[key] + + 
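# Copy the shared settings (loss_function, weight, temperature, ...) into each per-layer entry before turning it into a LayerConfig. +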
for per_layer_config in config['layers']: + per_layer_config.update(global_config) + self._loss_config_list.append( + LayerConfig(**per_layer_config).__dict__) + def _prepare_outputs(self): """ Add hook to get the output tensor of target layer. - Returns: - stu_outs_dict(dict): the name and tensor for the student model, - such as {'hidden_0': tensor_0, ..} - tea_outs_dict(dict): the name and tensor for the teather model, - such as {'hidden_0': tensor_0, ..} """ - stu_outs_dict = collections.OrderedDict() - tea_outs_dict = collections.OrderedDict() - stu_outs_dict = self._prepare_hook(self._adaptors_S, stu_outs_dict) - tea_outs_dict = self._prepare_hook(self._adaptors_T, tea_outs_dict) - return stu_outs_dict, tea_outs_dict - - def _prepare_hook(self, adaptors, outs_dict): + outputs_tensor = {} + for idx, m in enumerate(self._student_models): + hook_layers = self._hook_layers['student_{}'.format(idx)] + stu_outs = collections.OrderedDict() + outputs_tensor['student_{}'.format(idx)] = self._prepare_hook( + m, hook_layers, stu_outs) + for idx, m in enumerate(self._teacher_models): + hook_layers = self._hook_layers['teacher_{}'.format(idx)] + tea_outs = collections.OrderedDict() + outputs_tensor['teacher_{}'.format(idx)] = self._prepare_hook( + m, hook_layers, tea_outs) + return outputs_tensor + + def _prepare_hook(self, model, hook_layers, outs_dict): """ Add hook. """ - mapping_layers = adaptors.mapping_layers() - for layer_type, layer in mapping_layers.items(): + for layer in hook_layers: if isinstance(layer, str): - adaptors._add_distill_hook(outs_dict, [layer], [layer_type]) + _add_hooks(model, outs_dict, layer) return outs_dict - def _get_distill_idx(self): - """ - For each feature_type, get the feature index in the student and teacher models. - Returns: - distill_idx(dict): the feature index for each feature_type, - such as {'hidden': [[0, 0], [1, 1]], 'out': [[0, 0]]} - """ - distill_idx = {} - for config in self._distill_configs: - if config['feature_type'] not in distill_idx: - distill_idx[config['feature_type']] = [[ - int(config['s_feature_idx']), int(config['t_feature_idx']) - ]] - else: - distill_idx[config['feature_type']].append([ - int(config['s_feature_idx']), int(config['t_feature_idx']) - ]) - return distill_idx - def forward(self, *inputs, **kwargs): - stu_batch_outs = self._student_models.forward(*inputs, **kwargs) - tea_batch_outs = self._teacher_models.forward(*inputs, **kwargs) - if not self._teacher_models.training: - tea_batch_outs = [i.detach() for i in tea_batch_outs] - - # get all target tensor - if self._adaptors_S.add_tensor == False: - self._adaptors_S.add_tensor = True - if self._adaptors_T.add_tensor == False: - self._adaptors_T.add_tensor = True - self.stu_outs_dict = self._get_model_intermediate_output( - self._adaptors_S, self.stu_outs_dict) - self.tea_outs_dict = self._get_model_intermediate_output( - self._adaptors_T, self.tea_outs_dict) - - distill_inputs = self._process_outputs() + students_batch_outs = [] + teachers_batch_outs = [] + for idx, student_model in enumerate(self._student_models): + stu_batch_outs = student_model.forward(*inputs, **kwargs) + students_batch_outs.append(stu_batch_outs) + for idx, teacher_model in enumerate(self._teacher_models): + tea_batch_outs = teacher_model.forward(*inputs, **kwargs) + if not teacher_model.training: + tea_batch_outs = [i.detach() for i in tea_batch_outs] + teachers_batch_outs.extend(tea_batch_outs) + + if len(self._student_models) == 1: + students_batch_outs = students_batch_outs[0] + if 
len(self._teacher_models) == 1: + teachers_batch_outs = teachers_batch_outs[0] ### batch is None just for now - distill_outputs = self.distill_loss(distill_inputs, None) + distill_outputs = self.distill_loss(self._output_tensor_dict, None) distill_loss = distill_outputs['loss'] - return stu_batch_outs, tea_batch_outs, distill_loss - - def _get_model_intermediate_output(self, adaptors, outs_dict): - """ - Use the adaptor get the target tensor. - Returns: - outs_dict(dict): the name and tensor for the target model, - such as {'hidden_0': tensor_0, ..} - """ - mapping_layers = adaptors.mapping_layers() - for layer_type, layer in mapping_layers.items(): - if isinstance(layer, str): - continue - outs_dict[layer_type] = layer - return outs_dict - - def _process_outputs(self): - """ - Process the target tensor to adapt for loss. - """ - ### TODO: support list of student models and teacher_models - final_distill_dict = { - "student": collections.OrderedDict(), - "teacher": collections.OrderedDict() - } - - for feature_type, dist_idx in self.distill_idx.items(): - for idx, idx_list in enumerate(dist_idx): - sidx, tidx = idx_list[0], idx_list[1] - stu_out = self.stu_outs_dict[feature_type + '_' + str(sidx)] - tea_out = self.tea_outs_dict[feature_type + '_' + str(tidx)] - if not self._student_models.training: - stu_out = stu_out.detach() - if not self._teacher_models.training: - tea_out = tea_out.detach() - - name_str = feature_type + '_' + str(sidx) + '_' + str(tidx) - final_distill_dict['student'][name_str] = stu_out - final_distill_dict['teacher'][name_str] = tea_out - return final_distill_dict + if self._return_model_outputs: + return distill_loss, students_batch_outs, teachers_batch_outs + else: + return distill_loss diff --git a/paddleslim/dygraph/dist/distill_helpers.py b/paddleslim/dygraph/dist/distill_helpers.py new file mode 100644 index 0000000000000..07c3532e754a4 --- /dev/null +++ b/paddleslim/dygraph/dist/distill_helpers.py @@ -0,0 +1,41 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import yaml + +__all__ = ['config2yaml'] + + +def yaml2config(yaml_path): + """ + convert yaml to dict config. + """ + final_configs = [] + f = open(yaml_path, 'r') + origin_configs = yaml.load(f, Loader=yaml.FullLoader) + f.close() + for configs in origin_configs: + configs = configs['DistillConfig'] + final_configs.extend(configs) + return final_configs + + +def config2yaml(configs, yaml_path): + """ + convert dict config to yaml. + """ + final_yaml = dict() + final_yaml['DistillConfig'] = configs + f = open(yaml_path, "w") + yaml.dump([final_yaml], f) + f.close() diff --git a/paddleslim/dygraph/dist/losses/__init__.py b/paddleslim/dygraph/dist/losses/__init__.py index ac2f1f365ad3a..d583a275b206f 100644 --- a/paddleslim/dygraph/dist/losses/__init__.py +++ b/paddleslim/dygraph/dist/losses/__init__.py @@ -19,18 +19,7 @@ from . import basic_loss from . 
import distillation_loss -from .basic_loss import L1Loss -from .basic_loss import L2Loss -from .basic_loss import SmoothL1Loss -from .basic_loss import CELoss -from .basic_loss import DMLLoss -from .basic_loss import DistanceLoss -from .basic_loss import RKdAngle, RkdDistance - -from .distillation_loss import DistillationDistanceLoss -from .distillation_loss import DistillationDMLLoss -from .distillation_loss import DistillationRKDLoss -from .distillation_loss import SegPairWiseLoss, SegChannelwiseLoss +from .distillation_loss import DistillationLoss class CombinedLoss(nn.Layer): @@ -40,13 +29,12 @@ class CombinedLoss(nn.Layer): loss_config_list: a config list used to build loss function. A demo is as follows, which is used to calculate dml loss between Student output and Teacher output. Parameter weight is needed for the loss weight. - - DistillationDMLLoss: + { loss_function: DMLLoss weight: 1.0 act: "softmax" - model_name_pairs: - - ["Student", "Teacher"] - Another example is {'DistillationDistanceLoss': {'weight': 1.0, - 'key': 'hidden_0_0', 'model_name_pairs': [['student', 'teacher']]} + model_name_pairs:["student_0", "teacher_0"]} + Another example is {loss_function: "MSELoss", 'weight': 1.0, + 'layers_name': ['conv0', 'conv0'], 'model_name_pairs': [['student', 'teacher']]} """ def __init__(self, loss_config_list=None): @@ -56,18 +44,14 @@ def __init__(self, loss_config_list=None): self.loss_weight = [] assert isinstance(loss_config_list, list), ( 'operator config should be a list') - supported_loss_list = basic_loss.__all__ + distillation_loss.__all__ for config in loss_config_list: - assert isinstance(config, - dict) and len(config) == 1, "yaml format error" - name = list(config)[0] - assert name in supported_loss_list, \ - "loss name must be in {} but got: {}".format(name, supported_loss_list) - param = config[name] - assert "weight" in param, "weight must be in param, but param just contains {}".format( - param.keys()) - self.loss_weight.append(param.pop("weight")) - self.loss_func.append(eval(name)(**param)) + assert isinstance( + config, dict), "config must be a dict, but now is {}".format( + type(config)) + assert "weight" in config, "weight must be in param, but param just contains {}".format( + config.keys()) + self.loss_weight.append(config.pop("weight")) + self.loss_func.append(DistillationLoss(**config)) def forward(self, input, batch, **kargs): loss_dict = {} @@ -82,6 +66,7 @@ def forward(self, input, batch, **kargs): for key in loss } loss_dict.update(loss) + if loss_dict == {}: loss_dict["loss"] = paddle.to_tensor(0.) else: diff --git a/paddleslim/dygraph/dist/losses/basic_loss.py b/paddleslim/dygraph/dist/losses/basic_loss.py index fda472ad8c922..a774fe6b3781a 100644 --- a/paddleslim/dygraph/dist/losses/basic_loss.py +++ b/paddleslim/dygraph/dist/losses/basic_loss.py @@ -20,11 +20,13 @@ from paddle.nn import MSELoss as L2Loss from paddle.nn import SmoothL1Loss -__all__ = [ - "CELoss", "DMLLoss", "DistanceLoss", "RKdAngle", "RkdDistance", "KLLoss" -] +from ....core import Registry +__all__ = ["BASIC_LOSS"] +BASIC_LOSS = Registry("basicloss") + +@BASIC_LOSS.register class CELoss(nn.Layer): """ CELoss: cross entropy loss @@ -78,6 +80,7 @@ def forward(self, x, label): return loss +@BASIC_LOSS.register class DMLLoss(nn.Layer): """ DMLLoss @@ -110,6 +113,7 @@ def forward(self, out1, out2): return loss +@BASIC_LOSS.register class KLLoss(nn.Layer): """ KLLoss. 
@@ -153,6 +157,7 @@ def forward(self, input, label): return loss +@BASIC_LOSS.register class DistanceLoss(nn.Layer): """ DistanceLoss @@ -191,6 +196,7 @@ def pdist(e, squared=False, eps=1e-12): return res +@BASIC_LOSS.register class RKdAngle(nn.Layer): """ RKdAngle loss, see https://arxiv.org/abs/1904.05068 @@ -218,6 +224,7 @@ def forward(self, student, teacher): return loss +@BASIC_LOSS.register class RkdDistance(nn.Layer): """ RkdDistance loss, see https://arxiv.org/abs/1904.05068 @@ -244,3 +251,50 @@ def forward(self, student, teacher): loss = F.smooth_l1_loss(d, t_d, reduction="mean") return loss + + +@BASIC_LOSS.register +class MSELoss(DistanceLoss): + """ + MSELoss: https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/MSELoss_cn.html#mseloss + """ + + def __init__(self, **kargs): + super().__init__(mode='l2', **kargs) + + +@BASIC_LOSS.register +class L1Loss(DistanceLoss): + """ + L1loss: https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/L1Loss_cn.html#l1loss + """ + + def __init__(self, **kargs): + super().__init__(mode='l1', **kargs) + + +@BASIC_LOSS.register +class SmoothL1Loss(DistanceLoss): + """ + SmoothL1Loss: https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/SmoothL1Loss_cn.html#smoothl1loss + """ + + def __init__(self, **kargs): + super().__init__(mode='smooth_l1', **kargs) + + +@BASIC_LOSS.register +class RKDLoss(nn.Layer): + """ + RKDLoss + """ + + def __init__(self, eps=1e-12): + super().__init__() + self.rkd_angle_loss_func = RKdAngle() + self.rkd_dist_func = RkdDistance(eps=eps) + + def forward(self, student, teacher): + angle_loss = self.rkd_angle_loss_func(student, teacher) + dist_loss = self.rkd_dist_func(student, teacher) + return angle_loss + dist_loss diff --git a/paddleslim/dygraph/dist/losses/distillation_loss.py b/paddleslim/dygraph/dist/losses/distillation_loss.py index 222ef0a5427be..87abce118bf15 100644 --- a/paddleslim/dygraph/dist/losses/distillation_loss.py +++ b/paddleslim/dygraph/dist/losses/distillation_loss.py @@ -15,210 +15,54 @@ import paddle import paddle.nn as nn -from .basic_loss import DMLLoss -from .basic_loss import DistanceLoss -from .basic_loss import RkdDistance -from .basic_loss import RKdAngle -from .basic_loss import KLLoss +from .basic_loss import BASIC_LOSS -__all__ = [ - "DistillationDMLLoss", - "DistillationDistanceLoss", - "DistillationRKDLoss", - "SegPairWiseLoss", - "SegChannelwiseLoss", -] +__all__ = ["DistillationLoss"] -class DistillationDMLLoss(DMLLoss): +class DistillationLoss(nn.Layer): """ - DistillationDMLLoss + DistillationLoss Args: model_name_pairs(list | tuple): model name pairs to extract submodel output. - act(string | None): activation function used to build dml loss. - axis(int): axis used to build activation function. - key(string | None): key of the tensor used to calculate loss if the submodel - output type is dict. - name(string): loss name. 
- """ - - def __init__(self, model_name_pairs=[], act=None, key=None, - name="loss_dml"): - super().__init__(act=act) - assert isinstance(model_name_pairs, list) - self.key = key - self.model_name_pairs = model_name_pairs - self.name = name - - def forward(self, predicts, batch): - loss_dict = dict() - for idx, pair in enumerate(self.model_name_pairs): - out1 = predicts[pair[0]] - out2 = predicts[pair[1]] - if self.key is not None: - out1 = out1[self.key] - out2 = out2[self.key] - loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1], - idx)] = super().forward(out1, out2) - return loss_dict - - -class DistillationDistanceLoss(DistanceLoss): - """ - DistillationDistanceLoss - Args: - mode: loss mode - model_name_pairs(list | tuple): model name pairs to extract submodel output. - such as [['student', 'teacher']] - key(string | None): key of the tensor used to calculate loss if the submodel. - such as 'hidden_0_0' - name(string): loss name. - kargs(dict): used to build corresponding loss function. + layers_name(list(string)): keys of the tensor used to calculate loss if the submodel. + loss_function(string): the name of loss function. + temperature(float): the temperature to compute distill loss. """ def __init__(self, - mode="l2", model_name_pairs=[], - key=None, - name="loss_distance", - **kargs): - super().__init__(mode=mode, **kargs) - assert isinstance(model_name_pairs, list) - self.key = key - self.model_name_pairs = model_name_pairs - self.name = name + "_" + mode - - def forward(self, predicts, batch): - loss_dict = dict() - for idx, pair in enumerate(self.model_name_pairs): - out1 = predicts[pair[0]] - out2 = predicts[pair[1]] - if self.key is not None: - out1 = out1[self.key] - out2 = out2[self.key] - loss = super().forward(out1, out2) - loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1], - idx)] = loss - return loss_dict - - -class DistillationRKDLoss(nn.Layer): - """ - DistillationRKDLoss - Args: - model_name_pairs(list | tuple): model name pairs to extract submodel output. - key(string | None): key of the tensor used to calculate loss if the submodel. - eps(float): epsilon for the pdist function for RkdDistance loss. - name(string): loss name. 
- """ - - def __init__(self, - model_name_pairs=[], - key=None, - eps=1e-12, - name="loss_rkd"): + layers_name=None, + loss_function=None, + temperature=1.0, + **params): super().__init__() self.model_name_pairs = model_name_pairs - self.key = key + self.layers_name = layers_name + self.loss_function = loss_function + self.temperature = temperature + self.align_params = params.pop( + 'align_params') if 'align_params' in params else None + if self.align_params is not None: + for attr, value in self.align_params.items(): + setattr(self, attr, value) - self.rkd_angle_loss_func = RKdAngle() - self.rkd_dist_func = RkdDistance(eps=eps) - self.name = name + self.loss_func = BASIC_LOSS.get(loss_function)(**params) def forward(self, predicts, batch): loss_dict = dict() for idx, pair in enumerate(self.model_name_pairs): out1 = predicts[pair[0]] out2 = predicts[pair[1]] - if self.key is not None: - out1 = out1[self.key] - out2 = out2[self.key] - loss_dict["{}_{}_{}_angle_{}".format(self.name, pair[0], pair[ - 1], idx)] = self.rkd_angle_loss_func(out1, out2) - - loss_dict["{}_{}_{}_dist_{}".format(self.name, pair[0], pair[ - 1], idx)] = self.rkd_dist_func(out1, out2) - return loss_dict - - -class SegPairWiseLoss(DistanceLoss): - """ - Segmentation pairwise loss, see https://arxiv.org/pdf/1903.04197.pdf - - Args: - model_name_pairs(list | tuple): model name pairs to extract submodel output. - key(string): key of the tensor used to calculate loss if the submodel - output type is dict. - mode(string, optional): loss mode. It supports l1, l2 and smooth_l1. Default: l2. - reduction(string, optional): the reduction params for F.kl_div. Default: mean. - name(string, optional): loss name. Default: seg_pair_wise_loss. - """ - - def __init__(self, - model_name_pairs=[], - key=None, - mode="l2", - reduction="mean", - name="seg_pair_wise_loss"): - super().__init__(mode=mode, reduction=reduction) - - assert isinstance(model_name_pairs, list) - assert key is not None - self.key = key - self.model_name_pairs = model_name_pairs - self.name = name - - self.pool1 = nn.AdaptiveAvgPool2D(output_size=[2, 2]) - self.pool2 = nn.AdaptiveAvgPool2D(output_size=[2, 2]) - - def forward(self, predicts, batch): - loss_dict = dict() - for idx, pair in enumerate(self.model_name_pairs): - out1 = predicts[pair[0]][self.key] - out2 = predicts[pair[1]][self.key] - - pool1 = self.pool1(out1) - pool2 = self.pool2(out2) - - loss_name = "{}_{}_{}_{}".format(self.name, pair[0], pair[1], idx) - loss_dict[loss_name] = super().forward(pool1, pool2) - return loss_dict - - -class SegChannelwiseLoss(KLLoss): - """ - Segmentation channel wise loss, see `Channel-wise Distillation for Semantic Segmentation`. - Args: - model_name_pairs(list | tuple): model name pairs to extract submodel output. - key(string): key of the tensor used to calculate loss if the submodel - output type is dict. - act(string, optional): activation function used for the input and label tensor. - Default: softmax. - axis(int, optional): the axis for the act. Default: -1. - reduction(str, optional): the reduction params for F.kl_div. Default: mean. - name(string, optional): loss name. Default: seg_ch_wise_loss. 
- """ - - def __init__(self, - model_name_pairs=[], - key=None, - act='softmax', - axis=-1, - reduction="mean", - name="seg_ch_wise_loss"): - super().__init__(act, axis, reduction) - - assert isinstance(model_name_pairs, list) - assert key is not None - self.model_name_pairs = model_name_pairs - self.key = key - self.name = name - - def forward(self, predicts, batch): - loss_dict = dict() - for idx, pair in enumerate(self.model_name_pairs): - out1 = predicts[pair[0]][self.key] - out2 = predicts[pair[1]][self.key] - loss_name = "{}_{}_{}_{}".format(self.name, pair[0], pair[1], idx) - loss_dict[loss_name] = super().forward(out1, out2) + if self.layers_name != None: + assert len(self.layers_name + ) == 2, "length of layers_name must be equal to 2." + out1 = out1[self.layers_name[0]] + out2 = out2[self.layers_name[1]] + if self.temperature != 1.0: + out1 = out1 / self.temperature + out2 = out2 / self.temperature + loss_dict["{}_{}_{}_{}_{}".format(self.loss_function, pair[0], pair[ + 1], self.layers_name[0] if self.layers_name != None else "0", \ + self.layers_name[1] if self.layers_name != None else "0")] = self.loss_func(out1, out2) return loss_dict diff --git a/tests/dygraph/test_distill.py b/tests/dygraph/test_distill.py index 84e270ba8c0ce..d3ebafafc1ea0 100644 --- a/tests/dygraph/test_distill.py +++ b/tests/dygraph/test_distill.py @@ -7,7 +7,7 @@ import paddle.nn as nn from paddle.vision.models import MobileNetV1 import paddle.vision.transforms as T -from paddleslim.dygraph.dist import Distill, AdaptorBase +from paddleslim.dygraph.dist import Distill, config2yaml from paddleslim.common.log_helper import get_logger _logger = get_logger( @@ -19,42 +19,30 @@ def setUp(self): self.s_model, self.t_model = self.prepare_model() self.t_model.eval() self.distill_configs = self.prepare_config() - self.adaptor = self.prepare_adaptor() def prepare_model(self): return MobileNetV1(), MobileNetV1() def prepare_config(self): distill_configs = [{ - 's_feature_idx': 0, - 't_feature_idx': 0, - 'feature_type': 'hidden', - 'loss_function': 'l2' + 'loss_function': 'MSELoss', + 'layers': [ + { + "layers_name": ["conv1", "conv1"] + }, + { + "layers_name": ["conv2_2", "conv2_2"] + }, + ] }, { - 's_feature_idx': 1, - 't_feature_idx': 1, - 'feature_type': 'hidden', - 'loss_function': 'l2' - }, { - 's_feature_idx': 0, - 't_feature_idx': 0, - 'feature_type': 'logits', - 'loss_function': 'l2' + 'loss_function': 'CELoss', + 'temperature': 1.0, + 'layers': [{ + "layers_name": ["fc", "fc"] + }, ] }] return distill_configs - def prepare_adaptor(self): - class Adaptor(AdaptorBase): - def mapping_layers(self): - mapping_layers = {} - mapping_layers['hidden_0'] = 'conv1' - mapping_layers['hidden_1'] = 'conv2_2' - mapping_layers['hidden_2'] = 'conv3_2' - mapping_layers['logits_0'] = 'fc' - return mapping_layers - - return Adaptor - def test_distill(self): transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])]) @@ -97,7 +85,7 @@ def train(model): for batch_id, data in enumerate(train_reader): img = paddle.to_tensor(data[0]) label = paddle.to_tensor(data[1]) - student_out, teacher_out, distill_loss = model(img) + distill_loss, student_out, teacher_out = model(img) loss = paddle.nn.functional.loss.cross_entropy(student_out, label) avg_loss = paddle.mean(loss) @@ -112,7 +100,7 @@ def train(model): self.s_model.train() distill_model = Distill(self.distill_configs, self.s_model, - self.t_model, self.adaptor, self.adaptor) + self.t_model) train(distill_model) @@ -136,31 +124,26 @@ def forward(self, x): return Model(), 
Model() - def prepare_adaptor(self): - class Adaptor(AdaptorBase): - def mapping_layers(self): - mapping_layers = {} - mapping_layers['hidden_1'] = 'conv2' - if self.add_tensor: - mapping_layers['hidden_0'] = self.model.conv1_out - mapping_layers['hidden_2'] = self.model.conv3_out - return mapping_layers - - return Adaptor - def prepare_config(self): distill_configs = [{ - 's_feature_idx': 0, - 't_feature_idx': 0, - 'feature_type': 'hidden', - 'loss_function': 'l2' + 'loss_function': 'MSELoss', + 'layers': [ + { + "layers_name": ["conv1", "conv1"] + }, + { + "layers_name": ["conv2", "conv3"] + }, + ] }, { - 's_feature_idx': 1, - 't_feature_idx': 2, - 'feature_type': 'hidden', - 'loss_function': 'l2' + 'loss_function': 'CELoss', + 'temperature': 1.0, + 'layers': [{ + "layers_name": ["fc", "fc"] + }, ] }] - return distill_configs + config2yaml(distill_configs, 'test.yaml') + return './test.yaml' if __name__ == '__main__': diff --git a/tests/dygraph/test_distillation_loss.py b/tests/dygraph/test_distillation_loss.py index 748cddba65e41..a32a34d2783a2 100644 --- a/tests/dygraph/test_distillation_loss.py +++ b/tests/dygraph/test_distillation_loss.py @@ -24,18 +24,14 @@ from paddleslim.dygraph.dist.losses import CombinedLoss # basic loss -from paddleslim.dygraph.dist.losses import DistanceLoss -from paddleslim.dygraph.dist.losses import CELoss -from paddleslim.dygraph.dist.losses import DMLLoss -from paddleslim.dygraph.dist.losses import RkdDistance -from paddleslim.dygraph.dist.losses import RKdAngle +from paddleslim.dygraph.dist.losses.basic_loss import DistanceLoss +from paddleslim.dygraph.dist.losses.basic_loss import CELoss +from paddleslim.dygraph.dist.losses.basic_loss import DMLLoss +from paddleslim.dygraph.dist.losses.basic_loss import RkdDistance +from paddleslim.dygraph.dist.losses.basic_loss import RKdAngle # distillation loss -from paddleslim.dygraph.dist.losses import DistillationDistanceLoss -from paddleslim.dygraph.dist.losses import DistillationRKDLoss -from paddleslim.dygraph.dist.losses import DistillationDMLLoss -from paddleslim.dygraph.dist.losses import SegPairWiseLoss -from paddleslim.dygraph.dist.losses import SegChannelwiseLoss +from paddleslim.dygraph.dist.losses import DistillationLoss import numpy as np @@ -70,14 +66,13 @@ def np_distance_loss(self, x, y, mode="l2", reduction="none"): out = np.sum(diff) return out - def dist_np_distance_loss( - self, - predicts, - mode="l2", - reduction="none", - model_name_pairs=(["", ""]), - key=None, - name="loss_distance", ): + def dist_np_distance_loss(self, + predicts, + loss_function=None, + mode="l2", + reduction="none", + model_name_pairs=(["", ""]), + key=None): loss_dict = dict() for idx, pair in enumerate(model_name_pairs): out1 = predicts[pair[0]] @@ -85,10 +80,12 @@ def dist_np_distance_loss( if key is not None: out1 = out1[key] out2 = out2[key] + else: + key = 0 loss = self.np_distance_loss( out1, out2, mode=mode, reduction=reduction) - loss_dict["{}_{}_{}_{}_{}".format(name, mode, pair[0], pair[1], - idx)] = loss + loss_dict["{}_{}_{}_{}_{}".format( + str(loss_function), pair[0], pair[1], key, key)] = loss return loss_dict @@ -120,7 +117,7 @@ def test_distillation_distance_loss(self, ): "student": paddle.rand(shape), "teacher": paddle.rand(shape), } - self.calc_distillation_distance_loss(predicts, pairs, key=None) + self.calc_distillation_distance_loss(predicts, pairs) predicts = { "student": { @@ -143,13 +140,15 @@ def calc_distillation_distance_loss(self, predicts, pairs, key=None): paddle.set_device(device) for 
reduction in reductions:
                 for mode in modes:
-                    loss_func = DistillationDistanceLoss(
+                    loss_func = DistillationLoss(
                         mode=mode,
+                        loss_function='DistanceLoss',
                         model_name_pairs=pairs,
-                        key=key,
+                        layers_name=[key, key] if key != None else None,
                         reduction=reduction)
                     np_result_dict = self.dist_np_distance_loss(
                         predicts,
+                        loss_function='DistanceLoss',
                         mode=mode,
                         reduction=reduction,
                         model_name_pairs=pairs,
@@ -358,12 +357,11 @@ def test_basic_dml_loss(self, ):
         np_loss = self.np_dml_loss(x, target)
         self.assertTrue(np.allclose(np_loss, pd_loss))
 
-    def dist_np_dml_loss(
-            self,
-            predicts,
-            model_name_pairs=(["", ""]),
-            key=None,
-            name="loss_dml", ):
+    def dist_np_dml_loss(self,
+                         predicts,
+                         loss_function=None,
+                         model_name_pairs=(["", ""]),
+                         key=None):
         loss_dict = dict()
         for idx, pair in enumerate(model_name_pairs):
             out1 = predicts[pair[0]]
@@ -371,8 +369,11 @@ def dist_np_dml_loss(
             if key is not None:
                 out1 = out1[key]
                 out2 = out2[key]
-            loss_dict["{}_{}_{}_{}".format(name, pair[0], pair[1],
-                                           idx)] = self.np_dml_loss(out1, out2)
+            else:
+                key = 0
+            loss_dict["{}_{}_{}_{}_{}".format(
+                str(loss_function), pair[0], pair[1], key,
+                key)] = self.np_dml_loss(out1, out2)
         return loss_dict
 
     def calc_distillation_dml_loss(self, predicts, pairs, key=None):
@@ -382,11 +383,19 @@ def calc_distillation_dml_loss(self, predicts, pairs, key=None):
 
         for device in devices:
             paddle.set_device(device)
-            loss_func = DistillationDMLLoss(
-                act="softmax", model_name_pairs=pairs, key=key)
+            loss_func = DistillationLoss(
+                act="softmax",
+                model_name_pairs=pairs,
+                loss_function='DMLLoss',
+                layers_name=[key, key] if key != None else None)
             np_result_dict = self.dist_np_dml_loss(
-                predicts, model_name_pairs=pairs, key=key)
+                predicts,
+                model_name_pairs=pairs,
+                loss_function='DMLLoss',
+                key=key)
             pd_result_dict = loss_func(predicts, None)
+            print(pd_result_dict.keys())
+            print(np_result_dict.keys())
             for k in np_result_dict:
                 pd_result = pd_result_dict[k].numpy()
                 np_result = np_result_dict[k]
@@ -526,7 +535,7 @@ def dist_np_rkd_loss(
             predicts,
             model_name_pairs=(["", ""]),
             key=None,
-            name="loss_rkd", ):
+            name="RKDLoss", ):
         loss_dict = dict()
         for idx, pair in enumerate(model_name_pairs):
             out1 = predicts[pair[0]]
@@ -534,11 +543,12 @@ def dist_np_rkd_loss(
             if key is not None:
                 out1 = out1[key]
                 out2 = out2[key]
-            loss_dict["{}_{}_{}_angle_{}".format(name, pair[0], pair[
-                1], idx)] = self.np_rkd_angle(out1, out2)
+            else:
+                key = 0
+            loss_dict["{}_{}_{}_{}_{}".format(name, pair[0], pair[
+                1], key, key)] = self.np_rkd_angle(
+                    out1, out2) + self.np_rkd_distance(out1, out2)
 
-            loss_dict["{}_{}_{}_dist_{}".format(name, pair[0], pair[
-                1], idx)] = self.np_rkd_distance(out1, out2)
         return loss_dict
 
     def calc_distillation_rkd_loss(self, predicts, pairs, key=None):
@@ -548,7 +558,10 @@ def calc_distillation_rkd_loss(self, predicts, pairs, key=None):
 
         for device in devices:
             paddle.set_device(device)
-            loss_func = DistillationRKDLoss(model_name_pairs=pairs, key=key)
+            loss_func = DistillationLoss(
+                model_name_pairs=pairs,
+                loss_function='RKDLoss',
+                layers_name=[key, key] if key != None else None)
             np_result_dict = self.dist_np_rkd_loss(
                 predicts, model_name_pairs=pairs, key=key)
             pd_result_dict = loss_func(predicts, None)
@@ -623,13 +636,12 @@ def np_dml_loss(self, x, target, act="softmax"):
                 log_soft_target, soft_x)) / 2.0
         return loss
 
-    def dist_np_dml_loss(
-            self,
-            predicts,
-            model_name_pairs=(["", ""]),
-            key=None,
-            act="softmax",
-            name="loss_dml", ):
+    def dist_np_dml_loss(self,
+                         predicts,
+                         model_name_pairs=(["", ""]),
+                         loss_function=None,
+                         key=None,
+                         act="softmax"):
         loss_dict = dict()
         for idx, pair in enumerate(model_name_pairs):
             out1 = predicts[pair[0]]
@@ -637,20 +649,24 @@ def dist_np_dml_loss(
             if key is not None:
                 out1 = out1[key]
                 out2 = out2[key]
-            loss_dict["{}_{}_{}_{}".format(name, pair[0], pair[1],
-                                           idx)] = self.np_dml_loss(out1, out2)
+            loss_dict["{}_{}_{}_{}_0".format(
+                str(loss_function), pair[0], pair[1], idx)] = self.np_dml_loss(
+                    out1, out2)
         return loss_dict
 
     def np_combined_loss(self, predicts, loss_cfg_list):
         # NOTE, dml is set as the list for combined loss
         loss_dict = dict()
         for idx, loss_func in enumerate(loss_cfg_list):
-            cfg = copy.deepcopy(loss_func["DistillationDMLLoss"])
+            cfg = copy.deepcopy(loss_func)
             weight = cfg.pop("weight")
             loss = self.dist_np_dml_loss(predicts, **cfg)
             if isinstance(loss, np.ndarray):
-                loss = {"loss_{}_{}".format(str(loss), idx): loss}
+                loss = {
+                    "{}_{}_{}".format(loss_func['loss_function'],
+                                      str(loss), idx): loss
+                }
             else:
                 loss = {
                     "{}_{}".format(key, idx): loss[key] * weight
@@ -677,12 +693,10 @@ def test_combined_loss(self, ):
             devices.append("gpu")
 
         loss_cfg_list = [{
-            "DistillationDMLLoss": {
-                "weight": 1.0,
-                "act": "softmax",
-                "model_name_pairs": pairs,
-                "key": None
-            }
+            "loss_function": "DMLLoss",
+            "weight": 1.0,
+            "act": "softmax",
+            "model_name_pairs": pairs
         }, ]
 
         for device in devices:
@@ -696,95 +710,5 @@ def test_combined_loss(self, ):
                 self.assertTrue(np.allclose(np_result, pd_result))
 
 
-class TestSegPairWiseLoss(unittest.TestCase):
-    def calculate_gt_loss(self, x, y):
-        pool_x = F.adaptive_avg_pool2d(x, [2, 2])
-        pool_y = F.adaptive_avg_pool2d(y, [2, 2])
-        loss = F.mse_loss(pool_x, pool_y)
-        return loss
-
-    def test_seg_pair_wise_loss(self):
-        shape = [1, 3, 10, 10]
-        x = paddle.rand(shape)
-        y = paddle.rand(shape)
-        model_name_pairs = [['student', 'teacher']]
-        key = 'hidden_0_0'
-
-        inputs = {
-            model_name_pairs[0][0]: {
-                key: x
-            },
-            model_name_pairs[0][1]: {
-                key: y
-            }
-        }
-        devices = ["cpu"]
-        if paddle.is_compiled_with_cuda():
-            devices.append("gpu")
-
-        for device in devices:
-            paddle.set_device(device)
-            loss_func = SegPairWiseLoss(model_name_pairs, key)
-            pd_loss_dict = loss_func(inputs, None)
-            pd_loss = pd_loss_dict['seg_pair_wise_loss_student_teacher_0']
-            gt_loss = self.calculate_gt_loss(x, y)
-            self.assertTrue(np.allclose(pd_loss.numpy(), gt_loss.numpy()))
-
-
-class TestSegChannelWiseLoss(unittest.TestCase):
-    def init(self):
-        self.act_name = None
-        self.act_func = None
-
-    def calculate_gt_loss(self, x, y, act=None):
-        if act is not None:
-            x = act(x)
-            y = act(y)
-        x = paddle.log(x)
-        loss = F.kl_div(x, y)
-        return loss
-
-    def test_seg_pair_wise_loss(self):
-        self.init()
-
-        shape = [1, 3, 10, 10]
-        x = paddle.rand(shape)
-        y = paddle.rand(shape)
-        model_name_pairs = [['student', 'teacher']]
-        key = 'hidden_0_0'
-
-        inputs = {
-            model_name_pairs[0][0]: {
-                key: x
-            },
-            model_name_pairs[0][1]: {
-                key: y
-            }
-        }
-        devices = ["cpu"]
-        if paddle.is_compiled_with_cuda():
-            devices.append("gpu")
-
-        for device in devices:
-            paddle.set_device(device)
-            loss_func = SegChannelwiseLoss(model_name_pairs, key,
-                                           self.act_name)
-            pd_loss_dict = loss_func(inputs, None)
-            pd_loss = pd_loss_dict['seg_ch_wise_loss_student_teacher_0']
-            gt_loss = self.calculate_gt_loss(x, y, self.act_func)
-            self.assertTrue(np.allclose(pd_loss.numpy(), gt_loss.numpy()))
-
-
-class TestSegChannelWiseLoss1(TestSegChannelWiseLoss):
-    def init(self):
-        self.act_name = "softmax"
-        self.act_func = F.softmax
-
-
-class TestSegChannelWiseLoss1(TestSegChannelWiseLoss):
-    def init(self):
-        self.act_name = "sigmoid"
-        self.act_func = F.sigmoid
-
-
 if __name__ == '__main__':
     unittest.main()
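
Note: the test hunks above replace the per-loss wrappers (`DistillationDistanceLoss`, `DistillationDMLLoss`, `DistillationRKDLoss`) with a single `DistillationLoss` configured through `loss_function` and `layers_name`. Below is a minimal usage sketch inferred from the updated tests; the import path and the exact result-key pattern (`<loss_function>_<student>_<teacher>_<layer>_<layer>`) are assumptions for illustration, not part of this patch.

```python
# Minimal sketch, not part of the patch: mirrors how the updated tests call
# the unified loss. The import path and key names below are assumptions.
import paddle
from paddleslim.dygraph.dist.losses import DistillationLoss  # assumed path

# Outputs keyed by model name, the same layout the tests build.
predicts = {
    "student": paddle.rand([32, 10]),
    "teacher": paddle.rand([32, 10]),
}

loss_func = DistillationLoss(
    model_name_pairs=[["student", "teacher"]],
    loss_function='DMLLoss',  # previously DistillationDMLLoss
    act="softmax",
    layers_name=None)  # or [key, key] to select a named sub-output

# The second argument (batch) is unused by these losses in the tests.
loss_dict = loss_func(predicts, None)
print(loss_dict)  # e.g. {'DMLLoss_student_teacher_0_0': Tensor(...)}
```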