From 5f99330bfecebfa8726bdf776043778ffce84144 Mon Sep 17 00:00:00 2001 From: Xianpan Zhou <1239068645@qq.com> Date: Mon, 5 Dec 2022 19:53:52 +0800 Subject: [PATCH 1/4] [Feature] Add tools to convert distill ckpt to student-only ckpt. --- .../convert_kd_ckpt_to_student.py | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 tools/model_converters/convert_kd_ckpt_to_student.py diff --git a/tools/model_converters/convert_kd_ckpt_to_student.py b/tools/model_converters/convert_kd_ckpt_to_student.py new file mode 100644 index 000000000..b6e15b790 --- /dev/null +++ b/tools/model_converters/convert_kd_ckpt_to_student.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +from pathlib import Path + +import torch + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Process a checkpoint to be published') + parser.add_argument('checkpoint', help='input checkpoint filename') + parser.add_argument( + '--inplace', action='store_true', help='replace origin ckpt') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + checkpoint = torch.load(args.checkpoint, map_location='cpu') + new_state_dict = dict() + + for key, value in checkpoint['state_dict'].items(): + if key.startswith('architecture.'): + new_key = key.replace('architecture.', '') + new_state_dict[new_key] = value + + checkpoint['state_dict'] = new_state_dict + + if args.inplace: + torch.save(checkpoint, args.checkpoint) + else: + ckpt_path = Path(args.checkpoint) + ckpt_name = ckpt_path.stem + ckpt_dir = ckpt_path.parent + new_ckpt_path = ckpt_dir / f'{ckpt_name}_student.pth' + torch.save(checkpoint, new_ckpt_path) + + +if __name__ == '__main__': + main() From fbc2fe9df4ef01876804a69c01bb7bc00f4f2e6a Mon Sep 17 00:00:00 2001 From: Xianpan Zhou <1239068645@qq.com> Date: Mon, 5 Dec 2022 20:32:48 +0800 Subject: [PATCH 2/4] fix bug. 
--- tools/model_converters/convert_kd_ckpt_to_student.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/model_converters/convert_kd_ckpt_to_student.py b/tools/model_converters/convert_kd_ckpt_to_student.py index b6e15b790..ed7d14eea 100644 --- a/tools/model_converters/convert_kd_ckpt_to_student.py +++ b/tools/model_converters/convert_kd_ckpt_to_student.py @@ -23,7 +23,7 @@ def main(): for key, value in checkpoint['state_dict'].items(): if key.startswith('architecture.'): new_key = key.replace('architecture.', '') - new_state_dict[new_key] = value + new_state_dict[new_key] = value checkpoint['state_dict'] = new_state_dict From 6e44e5e959a7660a5bb9250b91501fd1a9229583 Mon Sep 17 00:00:00 2001 From: Xianpan Zhou <1239068645@qq.com> Date: Mon, 5 Dec 2022 20:44:31 +0800 Subject: [PATCH 3/4] add --model-only to only save model. --- tools/model_converters/convert_kd_ckpt_to_student.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tools/model_converters/convert_kd_ckpt_to_student.py b/tools/model_converters/convert_kd_ckpt_to_student.py index ed7d14eea..c591d7fc8 100644 --- a/tools/model_converters/convert_kd_ckpt_to_student.py +++ b/tools/model_converters/convert_kd_ckpt_to_student.py @@ -9,6 +9,8 @@ def parse_args(): parser = argparse.ArgumentParser( description='Process a checkpoint to be published') parser.add_argument('checkpoint', help='input checkpoint filename') + parser.add_argument( + '--model-only', action='store_true', help='only save model') parser.add_argument( '--inplace', action='store_true', help='replace origin ckpt') args = parser.parse_args() @@ -25,6 +27,9 @@ def main(): new_key = key.replace('architecture.', '') new_state_dict[new_key] = value + if args.model_only: + checkpoint = dict() + checkpoint['state_dict'] = new_state_dict if args.inplace: From 8e8db24c5f8fd0e8ff9b4de2a48ef77b0971f31a Mon Sep 17 00:00:00 2001 From: Xianpan Zhou <1239068645@qq.com> Date: Tue, 6 Dec 2022 12:37:37 +0800 Subject: [PATCH 4/4] Make 
changes according to PR review. --- .../convert_kd_ckpt_to_student.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tools/model_converters/convert_kd_ckpt_to_student.py b/tools/model_converters/convert_kd_ckpt_to_student.py index c591d7fc8..e44f66d02 100644 --- a/tools/model_converters/convert_kd_ckpt_to_student.py +++ b/tools/model_converters/convert_kd_ckpt_to_student.py @@ -7,10 +7,9 @@ def parse_args(): parser = argparse.ArgumentParser( - description='Process a checkpoint to be published') + description='Convert KD checkpoint to student-only checkpoint') parser.add_argument('checkpoint', help='input checkpoint filename') - parser.add_argument( - '--model-only', action='store_true', help='only save model') + parser.add_argument('--out-path', help='save checkpoint path') parser.add_argument( '--inplace', action='store_true', help='replace origin ckpt') args = parser.parse_args() @@ -21,15 +20,15 @@ def main(): args = parse_args() checkpoint = torch.load(args.checkpoint, map_location='cpu') new_state_dict = dict() + new_meta = checkpoint['meta'] for key, value in checkpoint['state_dict'].items(): if key.startswith('architecture.'): new_key = key.replace('architecture.', '') new_state_dict[new_key] = value - if args.model_only: - checkpoint = dict() - + checkpoint = dict() + checkpoint['meta'] = new_meta checkpoint['state_dict'] = new_state_dict if args.inplace: @@ -37,7 +36,10 @@ def main(): else: ckpt_path = Path(args.checkpoint) ckpt_name = ckpt_path.stem - ckpt_dir = ckpt_path.parent + if args.out_path: + ckpt_dir = Path(args.out_path) + else: + ckpt_dir = ckpt_path.parent new_ckpt_path = ckpt_dir / f'{ckpt_name}_student.pth' torch.save(checkpoint, new_ckpt_path)