Skip to content

Commit

Permalink
[Fix] Fix the NoneType-object error during quantization (PaddlePaddle#1854)
Browse files Browse the repository at this point in the history
  • Loading branch information
juncaipeng committed Mar 14, 2022
1 parent 9732019 commit 86f51a3
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 151 deletions.
42 changes: 20 additions & 22 deletions slim/quant/qat_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,27 +30,25 @@

from paddleslim import QAT


def parse_args():
parser = argparse.ArgumentParser(description='Model export.')
parser.add_argument(
"--config",
dest="cfg",
help="The config file.",
default=None,
type=str,
required=True)
parser.add_argument(
'--save_dir',
dest='save_dir',
help='The directory for saving the exported model',
type=str,
default='./output')
parser.add_argument(
'--model_path',
dest='model_path',
help='The path of model for export',
type=str,
default=None)
parser.add_argument("--config",
dest="cfg",
help="The config file.",
default=None,
type=str,
required=True)
parser.add_argument('--save_dir',
dest='save_dir',
help='The directory for saving the exported model',
type=str,
default='./output')
parser.add_argument('--model_path',
dest='model_path',
help='The path of model for export',
type=str,
default=None)
parser.add_argument(
'--without_argmax',
dest='without_argmax',
Expand All @@ -72,15 +70,15 @@ def main(args):

skip_quant(net)
quantizer = QAT(config=quant_config)
quant_net = quantizer.quantize(net)
quantizer.quantize(net)
logger.info('Quantize the model successfully')

if args.model_path:
utils.load_entire_model(quant_net, args.model_path)
utils.load_entire_model(net, args.model_path)
logger.info('Loaded trained params of model successfully')

if not args.without_argmax or args.with_softmax:
new_net = SavedSegmentationNet(quant_net, args.without_argmax,
new_net = SavedSegmentationNet(net, args.without_argmax,
args.with_softmax)
else:
new_net = net
Expand Down
168 changes: 79 additions & 89 deletions slim/quant/qat_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,85 +37,77 @@

def parse_args():
parser = argparse.ArgumentParser(description='Model training')
parser.add_argument(
"--config", dest="cfg", help="The config file.", default=None, type=str)
parser.add_argument(
'--iters',
dest='iters',
help='iters for training',
type=int,
default=None)
parser.add_argument(
'--batch_size',
dest='batch_size',
help='Mini batch size of one gpu or cpu',
type=int,
default=None)
parser.add_argument(
'--learning_rate',
dest='learning_rate',
help='Learning rate',
type=float,
default=None)
parser.add_argument("--config",
dest="cfg",
help="The config file.",
default=None,
type=str)
parser.add_argument('--iters',
dest='iters',
help='iters for training',
type=int,
default=None)
parser.add_argument('--batch_size',
dest='batch_size',
help='Mini batch size of one gpu or cpu',
type=int,
default=None)
parser.add_argument('--learning_rate',
dest='learning_rate',
help='Learning rate',
type=float,
default=None)
parser.add_argument(
'--save_interval',
dest='save_interval',
help='How many iters to save a model snapshot once during training.',
type=int,
default=1000)
parser.add_argument(
'--resume_model',
dest='resume_model',
help='The path of resume model',
type=str,
default=None)
parser.add_argument(
'--save_dir',
dest='save_dir',
help='The directory for saving the model snapshot',
type=str,
default='./output')
parser.add_argument(
'--keep_checkpoint_max',
dest='keep_checkpoint_max',
help='Maximum number of checkpoints to save',
type=int,
default=5)
parser.add_argument(
'--num_workers',
dest='num_workers',
help='Num workers for data loader',
type=int,
default=0)
parser.add_argument(
'--do_eval',
dest='do_eval',
help='Eval while training',
action='store_true')
parser.add_argument(
'--log_iters',
dest='log_iters',
help='Display logging information at every log_iters',
default=10,
type=int)
parser.add_argument('--resume_model',
dest='resume_model',
help='The path of resume model',
type=str,
default=None)
parser.add_argument('--save_dir',
dest='save_dir',
help='The directory for saving the model snapshot',
type=str,
default='./output')
parser.add_argument('--keep_checkpoint_max',
dest='keep_checkpoint_max',
help='Maximum number of checkpoints to save',
type=int,
default=5)
parser.add_argument('--num_workers',
dest='num_workers',
help='Num workers for data loader',
type=int,
default=0)
parser.add_argument('--do_eval',
dest='do_eval',
help='Eval while training',
action='store_true')
parser.add_argument('--log_iters',
dest='log_iters',
help='Display logging information at every log_iters',
default=10,
type=int)
parser.add_argument(
'--use_vdl',
dest='use_vdl',
help='Whether to record the data to VisualDL during training',
action='store_true')
parser.add_argument(
'--seed',
dest='seed',
help='Set the random seed during training.',
default=None,
type=int)

parser.add_argument(
'--model_path',
dest='model_path',
help='The path of pretrained model',
type=str,
default=None)
parser.add_argument('--seed',
dest='seed',
help='Set the random seed during training.',
default=None,
type=int)

parser.add_argument('--model_path',
dest='model_path',
help='The path of pretrained model',
type=str,
default=None)

return parser.parse_args()

Expand Down Expand Up @@ -156,11 +148,10 @@ def main(args):
if not args.cfg:
raise RuntimeError('No configuration file specified.')

cfg = Config(
args.cfg,
learning_rate=args.learning_rate,
iters=args.iters,
batch_size=args.batch_size)
cfg = Config(args.cfg,
learning_rate=args.learning_rate,
iters=args.iters,
batch_size=args.batch_size)

train_dataset = cfg.train_dataset
if train_dataset is None:
Expand All @@ -187,24 +178,23 @@ def main(args):

skip_quant(model)
quantizer = QAT(config=quant_config)
quant_model = quantizer.quantize(model)
quantizer.quantize(model)
logger.info('Quantize the model successfully')

train(
quant_model,
train_dataset,
val_dataset=val_dataset,
optimizer=cfg.optimizer,
save_dir=args.save_dir,
iters=cfg.iters,
batch_size=cfg.batch_size,
resume_model=None,
save_interval=args.save_interval,
log_iters=args.log_iters,
num_workers=args.num_workers,
use_vdl=args.use_vdl,
losses=losses,
keep_checkpoint_max=args.keep_checkpoint_max)
train(model,
train_dataset,
val_dataset=val_dataset,
optimizer=cfg.optimizer,
save_dir=args.save_dir,
iters=cfg.iters,
batch_size=cfg.batch_size,
resume_model=None,
save_interval=args.save_interval,
log_iters=args.log_iters,
num_workers=args.num_workers,
use_vdl=args.use_vdl,
losses=losses,
keep_checkpoint_max=args.keep_checkpoint_max)


if __name__ == '__main__':
Expand Down
76 changes: 36 additions & 40 deletions slim/quant/qat_val.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,51 +55,48 @@ def parse_args():
parser = argparse.ArgumentParser(description='Model evaluation')

# params of evaluate
parser.add_argument(
"--config", dest="cfg", help="The config file.", default=None, type=str)
parser.add_argument(
'--model_path',
dest='model_path',
help='The path of model for evaluation',
type=str,
default=None)
parser.add_argument(
'--num_workers',
dest='num_workers',
help='Num workers for data loader',
type=int,
default=0)
parser.add_argument("--config",
dest="cfg",
help="The config file.",
default=None,
type=str)
parser.add_argument('--model_path',
dest='model_path',
help='The path of model for evaluation',
type=str,
default=None)
parser.add_argument('--num_workers',
dest='num_workers',
help='Num workers for data loader',
type=int,
default=0)

# augment for evaluation
parser.add_argument(
'--aug_eval',
dest='aug_eval',
help='Whether to use mulit-scales and flip augment for evaluation',
action='store_true')
parser.add_argument(
'--scales',
dest='scales',
nargs='+',
help='Scales for augment',
type=float,
default=1.0)
parser.add_argument(
'--flip_horizontal',
dest='flip_horizontal',
help='Whether to use flip horizontally augment',
action='store_true')
parser.add_argument(
'--flip_vertical',
dest='flip_vertical',
help='Whether to use flip vertically augment',
action='store_true')
parser.add_argument('--scales',
dest='scales',
nargs='+',
help='Scales for augment',
type=float,
default=1.0)
parser.add_argument('--flip_horizontal',
dest='flip_horizontal',
help='Whether to use flip horizontally augment',
action='store_true')
parser.add_argument('--flip_vertical',
dest='flip_vertical',
help='Whether to use flip vertically augment',
action='store_true')

# sliding window evaluation
parser.add_argument(
'--is_slide',
dest='is_slide',
help='Whether to evaluate by sliding window',
action='store_true')
parser.add_argument('--is_slide',
dest='is_slide',
help='Whether to evaluate by sliding window',
action='store_true')
parser.add_argument(
'--crop_size',
dest='crop_size',
Expand Down Expand Up @@ -167,18 +164,17 @@ def main(args):

skip_quant(model)
quantizer = QAT(config=quant_config)
quant_model = quantizer.quantize(model)
quantizer.quantize(model)
logger.info('Quantize the model successfully')

if args.model_path:
utils.load_entire_model(quant_model, args.model_path)
utils.load_entire_model(model, args.model_path)
logger.info('Loaded trained params of model successfully')

test_config = get_test_config(cfg, args)
config_check(cfg, val_dataset=val_dataset)

evaluate(
quant_model, val_dataset, num_workers=args.num_workers, **test_config)
evaluate(model, val_dataset, num_workers=args.num_workers, **test_config)


if __name__ == '__main__':
Expand Down

0 comments on commit 86f51a3

Please sign in to comment.