Merge pull request PaddlePaddle#7 from wanghaoshuang/distillation
Add fsp distillation strategy.
wanghaoshuang committed Jan 31, 2019
2 parents f7e24f3 + 14c9129 commit a836ec3
Showing 8 changed files with 441 additions and 8 deletions.
18 changes: 12 additions & 6 deletions python/paddle/fluid/contrib/slim/core/compress_pass.py
@@ -13,6 +13,7 @@
# limitations under the License.

from ....core import CPUPlace
from .... import io
from ....data_feeder import DataFeeder
from ..graph import get_executor, ImitationGraph
from config import ConfigFactory
@@ -41,6 +42,7 @@ def __init__(self,
train_reader=None,
eval_graph=None,
eval_reader=None,
teacher_graphs=None,
optimizer=None):
# The total number of epochs to be trained.
self.epoch = 0
@@ -57,6 +59,7 @@ def __init__(self,
self.eval_graph = eval_graph
self.eval_reader = eval_reader
self.executor = None
self.teacher_graphs = teacher_graphs
self.optimizer = optimizer

def run_eval_graph(self):
@@ -159,21 +162,23 @@ def config(self, config_file):

def _load_checkpoint(self, context):
if self.checkpoint:
exe = get_executor(context.train_graph, parallel=False)
fluid.io.load_persistables(
exe = get_executor(
context.train_graph, context.place, parallel=False)
io.load_persistables(
exe.exe,
self.checkpoint,
main_program=context.train_graph.program)
print("Loaded checkpoint from: {}".format(self.checkpoint))

def _save_checkpoint(self, context):
if context.epoch_id % 5 == 0 and self.model_save_dir:
model_path = os.path.join(self.model_save_dir,
str(context.epoch_id))
model_path = os.path.join(
self.model_save_dir,
str(context.epoch_id) + "_" + str(context.batch_id))
if not os.path.isdir(model_path):
os.makedirs(model_path)
exe = get_executor(context.train_graph, parallel=False)
fluid.io.save_persistables(
exe = get_executor(context.train_graph, context.place, False)
io.save_persistables(
exe.exe, model_path, main_program=context.train_graph.program)
print('Saved checkpoint to: {}'.format(model_path))

@@ -209,6 +214,7 @@ def run(self):
train_reader=self.train_reader,
eval_graph=self.eval_graph,
eval_reader=self.eval_reader,
teacher_graphs=self.teacher_graphs,
optimizer=self.optimizer)

self._load_checkpoint(context)
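A note on the checkpoint change above: `_save_checkpoint` now writes to a per-epoch, per-batch directory instead of a per-epoch one, and both load and save go through the relative `io` import with the executor bound to `context.place`. A minimal sketch of the resulting layout (all values hypothetical):

```python
import os

# Hypothetical values, for illustration only.
model_save_dir = './checkpoints'
epoch_id, batch_id = 3, 200

# Mirrors the path construction in _save_checkpoint above.
model_path = os.path.join(model_save_dir, str(epoch_id) + "_" + str(batch_id))
print(model_path)  # ./checkpoints/3_200
```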
91 changes: 91 additions & 0 deletions python/paddle/fluid/contrib/slim/demo/distillation/compress.py
@@ -0,0 +1,91 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.fluid as fluid
import paddle
import os
import sys
from resnet import *
from paddle.fluid.contrib.slim import CompressPass
from paddle.fluid.contrib.slim import build_compressor
from paddle.fluid.contrib.slim import ImitationGraph


class Model(object):
    def __init__(self):
pass

def compress(self):

img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
resnet50 = ResNet50()
predict = resnet50.net(img, class_dim=10)
eval_program = fluid.default_main_program().clone(for_test=False)
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(cost)

with fluid.program_guard(main_program=eval_program):
acc = fluid.layers.accuracy(input=predict, label=label)

optimizer = fluid.optimizer.SGD(0.001)
optimizer.minimize(avg_cost)

place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=500),
batch_size=32)
eval_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=1)

train_feed_list = {'img': img.name, 'label': label.name}
train_fetch_list = {'cost': avg_cost.name}
eval_feed_list = {'img': img.name, 'label': label.name}
eval_fetch_list = {'acc': acc.name}

# define teacher program
teacher_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(teacher_program, startup_program):
img = fluid.layers.data(
name='img', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
resnet101 = ResNet101()
predict = resnet101.net(img, class_dim=10)
exe.run(startup_program)

com_pass = CompressPass(
place,
fluid.global_scope(),
fluid.default_main_program(),
train_reader=train_reader,
train_feed_list=train_feed_list,
train_fetch_list=train_fetch_list,
eval_program=eval_program,
eval_reader=eval_reader,
eval_feed_list=eval_feed_list,
eval_fetch_list=eval_fetch_list,
teacher_programs=[teacher_program],
optimizer=optimizer)
com_pass.model_save_dir = './checkpoints'
com_pass.config('./config.yaml')
com_pass.run()


if __name__ == "__main__":
model = Model()
model.compress()
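Note that the demo keeps the teacher (ResNet-101) in its own `fluid.Program`, built under `program_guard`, while sharing the executor and `fluid.global_scope()` with the student; only the student program gets an optimizer. A small illustrative check (not part of the demo) that could be added before `com_pass.run()` to inspect the teacher graph:

```python
# Illustrative only: list the teacher program's parameters. Because
# exe.run(startup_program) ran in the same global scope, these variables
# are initialized alongside the student's parameters.
for param in teacher_program.global_block().all_parameters():
    print(param.name, param.shape)
```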
14 changes: 14 additions & 0 deletions python/paddle/fluid/contrib/slim/demo/distillation/config.yaml
@@ -0,0 +1,14 @@
version: 1.0
distillers:
    fsp_distiller:
        class: 'FSPDistiller'
strategies:
    fsp_distillation_strategy:
        class: 'FSPDistillationStrategy'
        distiller: 'fsp_distiller'
        start_epoch: 0
        end_epoch: 10
compress_pass:
    epoch: 10
    strategies:
        - fsp_distillation_strategy
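In this config, the strategy entry refers to the distiller it uses by name (`fsp_distiller`), `compress_pass.strategies` selects which strategies the pass runs, and the strategy is only active between `start_epoch` and `end_epoch`. A purely illustrative check of that wiring (assumes PyYAML is available; not part of the change):

```python
import yaml

# Illustrative only: every strategy listed under compress_pass should exist,
# and its 'distiller' key should refer to a defined distiller.
with open('./config.yaml') as f:
    cfg = yaml.safe_load(f)

for name in cfg['compress_pass']['strategies']:
    assert cfg['strategies'][name]['distiller'] in cfg['distillers']
```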
137 changes: 137 additions & 0 deletions python/paddle/fluid/contrib/slim/demo/distillation/resnet.py
@@ -0,0 +1,137 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import math

__all__ = ["ResNet", "ResNet50", "ResNet101", "ResNet152"]

train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}


class ResNet():
def __init__(self, layers=50):
self.params = train_parameters
self.layers = layers

def net(self, input, class_dim=1000):
layers = self.layers
supported_layers = [50, 101, 152]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers)

if layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
num_filters = [64, 128, 256, 512]

conv = self.conv_bn_layer(
input=input, num_filters=64, filter_size=7, stride=2, act='relu')
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')

for block in range(len(depth)):
for i in range(depth[block]):
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1)

pool = fluid.layers.pool2d(
input=conv, pool_size=7, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(input=pool,
size=class_dim,
act='softmax',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv,
stdv)))
return out

def conv_bn_layer(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
bias_attr=False)
return fluid.layers.batch_norm(input=conv, act=act)

def shortcut(self, input, ch_out, stride):
ch_in = input.shape[1]
if ch_in != ch_out or stride != 1:
return self.conv_bn_layer(input, ch_out, 1, stride)
else:
return input

def bottleneck_block(self, input, num_filters, stride):
conv0 = self.conv_bn_layer(
input=input, num_filters=num_filters, filter_size=1, act='relu')
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu')
conv2 = self.conv_bn_layer(
input=conv1, num_filters=num_filters * 4, filter_size=1, act=None)

short = self.shortcut(input, num_filters * 4, stride)

return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')


def ResNet50():
model = ResNet(layers=50)
return model


def ResNet101():
model = ResNet(layers=101)
return model


def ResNet152():
model = ResNet(layers=152)
return model
@@ -0,0 +1,41 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..core.strategy import Strategy
from ....framework import Program, program_guard, Parameter
from .... import layers
import numpy as np
import copy
import re

__all__ = ['FSPDistillationStrategy']


class FSPDistillationStrategy(Strategy):
def __init__(self, distiller=None, start_epoch=0, end_epoch=10):
super(FSPDistillationStrategy, self).__init__(start_epoch, end_epoch)
self.distiller = distiller
self.train_graph_backup = None

def on_epoch_begin(self, context):
if self.start_epoch == context.epoch_id:
self.train_graph_backup = context.train_graph
graph = self.distiller.distiller_graph(
context.eval_graph, context.teacher_graphs, context.optimizer,
context.place)
context.train_graph = graph

def on_epoch_end(self, context):
if context.epoch_id == (self.end_epoch - 1):
context.train_graph = self.train_graph_backup
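The strategy itself only swaps `context.train_graph` for a distillation graph built by the distiller at `start_epoch` and restores the original training graph after `end_epoch - 1`; `FSPDistiller` is defined elsewhere in this change and is not shown here. For readers unfamiliar with the technique, a minimal NumPy sketch (an assumption about the intent, not the distiller's actual code) of the FSP matrix from Yim et al., 2017, whose distance between teacher and student an FSP distiller minimizes:

```python
import numpy as np

def fsp_matrix(feat_a, feat_b):
    """FSP matrix of two feature maps shaped (c1, h, w) and (c2, h, w)."""
    c1, h, w = feat_a.shape
    c2 = feat_b.shape[0]
    a = feat_a.reshape(c1, h * w)
    b = feat_b.reshape(c2, h * w)
    return a.dot(b.T) / (h * w)  # shape (c1, c2)

def fsp_loss(student_pair, teacher_pair):
    """Mean squared distance between student and teacher FSP matrices.

    Each argument is a pair of feature maps taken from two layers with the
    same spatial size; the paired layers' channel counts must match between
    student and teacher so the two FSP matrices have the same shape.
    """
    g_s = fsp_matrix(*student_pair)
    g_t = fsp_matrix(*teacher_pair)
    return np.mean((g_s - g_t) ** 2)
```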
