Skip to content

Commit

Permalink
[Enhancement] .dev Python files updated to get better performance and…
Browse files Browse the repository at this point in the history
… syntax (#2020)

* logger hooks samples updated

* [Docs] Details for WandBLoggerHook Added

* [Docs] lint test pass

* [Enhancement] .dev Python files updated to get better performance and quality

* [Docs] Details for WandBLoggerHook Added

* [Docs] lint test pass

* [Enhancement] .dev Python files updated to get better performance and quality

* [Enhancement] lint test passed

* [Enhancement] Change Some Line from Previous to Support Python<3.9

* Update .dev/gather_models.py

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
  • Loading branch information
Nourollah and MeowZheng committed Sep 14, 2022
1 parent ecd1ecb commit 31395a8
Show file tree
Hide file tree
Showing 12 changed files with 70 additions and 91 deletions.
10 changes: 4 additions & 6 deletions .dev/benchmark_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ def parse_args():
'-s', '--show', action='store_true', help='show results')
parser.add_argument(
'-d', '--device', default='cuda:0', help='Device used for inference')
args = parser.parse_args()
return args
return parser.parse_args()


def inference_model(config_name, checkpoint, args, logger=None):
Expand All @@ -66,11 +65,10 @@ def inference_model(config_name, checkpoint, args, logger=None):
0.5, 0.75, 1.0, 1.25, 1.5, 1.75
]
cfg.data.test.pipeline[1].flip = True
elif logger is None:
print(f'{config_name}: unable to start aug test', flush=True)
else:
if logger is not None:
logger.error(f'{config_name}: unable to start aug test')
else:
print(f'{config_name}: unable to start aug test', flush=True)
logger.error(f'{config_name}: unable to start aug test')

model = init_segmentor(cfg, checkpoint, device=args.device)
# test a single image
Expand Down
8 changes: 2 additions & 6 deletions .dev/check_urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,9 @@ def check_url(url):
Returns:
int, bool: status code and check flag.
"""
flag = True
r = requests.head(url)
status_code = r.status_code
if status_code == 403 or status_code == 404:
flag = False

flag = status_code not in [403, 404]
return status_code, flag


Expand All @@ -35,8 +32,7 @@ def parse_args():
type=str,
help='Select the model needed to check')

args = parser.parse_args()
return args
return parser.parse_args()


def main():
Expand Down
4 changes: 2 additions & 2 deletions .dev/gather_benchmark_evaluation_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ def parse_args():
continue

# Compare between new benchmark results and previous metrics
differential_results = dict()
new_metrics = dict()
differential_results = {}
new_metrics = {}
for record_metric_key in previous_metrics:
if record_metric_key not in metric['metric']:
raise KeyError('record_metric_key not exist, please '
Expand Down
6 changes: 3 additions & 3 deletions .dev/gather_benchmark_train_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ def parse_args():
print(f'log file error: {log_json_path}')
continue

differential_results = dict()
old_results = dict()
new_results = dict()
differential_results = {}
old_results = {}
new_results = {}
for metric_key in model_performance:
if metric_key in ['mIoU']:
metric = round(model_performance[metric_key] * 100, 2)
Expand Down
21 changes: 10 additions & 11 deletions .dev/gather_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def process_checkpoint(in_file, out_file):
# The hash code calculation and rename command differ on different system
# platform.
sha = calculate_file_sha256(out_file)
final_file = out_file.rstrip('.pth') + '-{}.pth'.format(sha[:8])
final_file = out_file.rstrip('.pth') + f'-{sha[:8]}.pth'
os.rename(out_file, final_file)

# Remove prefix and suffix
Expand All @@ -50,25 +50,23 @@ def get_final_iter(config):


def get_final_results(log_json_path, iter_num):
result_dict = dict()
result_dict = {}
last_iter = 0
with open(log_json_path, 'r') as f:
for line in f.readlines():
for line in f:
log_line = json.loads(line)
if 'mode' not in log_line.keys():
continue

# When evaluation, the 'iter' of new log json is the evaluation
# steps on single gpu.
flag1 = ('aAcc' in log_line) or (log_line['mode'] == 'val')
flag2 = (last_iter == iter_num - 50) or (last_iter == iter_num)
flag1 = 'aAcc' in log_line or log_line['mode'] == 'val'
flag2 = last_iter in [iter_num - 50, iter_num]
if flag1 and flag2:
result_dict.update({
key: log_line[key]
for key in RESULTS_LUT if key in log_line
})
return result_dict

last_iter = log_line['iter']


Expand Down Expand Up @@ -123,7 +121,7 @@ def main():
exp_dir = osp.join(work_dir, config_name)
# check whether the exps is finished
final_iter = get_final_iter(used_config)
final_model = 'iter_{}.pth'.format(final_iter)
final_model = f'iter_{final_iter}.pth'
model_path = osp.join(exp_dir, final_model)

# skip if the model is still training
Expand All @@ -135,7 +133,7 @@ def main():
log_json_paths = glob.glob(osp.join(exp_dir, '*.log.json'))
log_json_path = log_json_paths[0]
model_performance = None
for idx, _log_json_path in enumerate(log_json_paths):
for _log_json_path in log_json_paths:
model_performance = get_final_results(_log_json_path, final_iter)
if model_performance is not None:
log_json_path = _log_json_path
Expand All @@ -161,9 +159,10 @@ def main():
model_publish_dir = osp.join(collect_dir, config_name)

publish_model_path = osp.join(model_publish_dir,
config_name + '_' + model['model_time'])
f'{config_name}_' + model['model_time'])

trained_model_path = osp.join(work_dir, config_name,
'iter_{}.pth'.format(model['iters']))
f'iter_{model["iters"]}.pth')
if osp.exists(model_publish_dir):
for file in os.listdir(model_publish_dir):
if file.endswith('.pth'):
Expand Down
10 changes: 4 additions & 6 deletions .dev/generate_benchmark_evaluation_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ def parse_args():
default='.dev/benchmark_evaluation.sh',
help='path to save model benchmark script')

args = parser.parse_args()
return args
return parser.parse_args()


def process_model_info(model_info, work_dir):
Expand All @@ -30,10 +29,9 @@ def process_model_info(model_info, work_dir):
job_name = fname
checkpoint = model_info['checkpoint'].strip()
work_dir = osp.join(work_dir, fname)
if not isinstance(model_info['eval'], list):
evals = [model_info['eval']]
else:
evals = model_info['eval']
evals = model_info['eval'] if isinstance(model_info['eval'],
list) else [model_info['eval']]

eval = ' '.join(evals)
return dict(
config=config,
Expand Down
7 changes: 2 additions & 5 deletions .dev/generate_benchmark_train_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,11 @@ def main():
port = args.port
partition_name = 'PARTITION=$1'

commands = []
commands.append(partition_name)
commands.append('\n')
commands.append('\n')
commands = [partition_name, '\n', '\n']

with open(args.txt_path, 'r') as f:
model_cfgs = f.readlines()
for i, cfg in enumerate(model_cfgs):
for cfg in model_cfgs:
create_train_bash_info(commands, cfg, script_name, '$PARTITION',
port)
port += 1
Expand Down
8 changes: 2 additions & 6 deletions .dev/log_collector/log_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,11 @@
def parse_args():
parser = argparse.ArgumentParser(description='extract info from log.json')
parser.add_argument('config_dir')
args = parser.parse_args()
return args
return parser.parse_args()


def has_keyword(name: str, keywords: list):
for a_keyword in keywords:
if a_keyword in name:
return True
return False
return any(a_keyword in name for a_keyword in keywords)


def main():
Expand Down
3 changes: 1 addition & 2 deletions .dev/upload_modelzoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ def parse_args():
type=str,
default='mmsegmentation/v0.5',
help='destination folder')
args = parser.parse_args()
return args
return parser.parse_args()


def main():
Expand Down
6 changes: 5 additions & 1 deletion docs/en/tutorials/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,13 @@ log_config = dict( # config to register logger hook
hooks=[
dict(type='TextLoggerHook', by_epoch=False),
dict(type='TensorboardLoggerHook', by_epoch=False),
dict(type='MMSegWandbHook', by_epoch=False, init_kwargs={'entity': entity, 'project': project, 'config': cfg_dict}), # The Wandb logger is also supported, It requires `wandb` to be installed.
dict(type='MMSegWandbHook', by_epoch=False, # The Wandb logger is also supported, It requires `wandb` to be installed.
init_kwargs={'entity': "OpenMMLab", # The entity used to log on Wandb
'project': "MMSeg", # Project name in WandB
'config': cfg_dict}), # Check https://docs.wandb.ai/ref/python/init for more init arguments.
# MMSegWandbHook is mmseg implementation of WandbLoggerHook. ClearMLLoggerHook, DvcliveLoggerHook, MlflowLoggerHook, NeptuneLoggerHook, PaviLoggerHook, SegmindLoggerHook are also supported based on MMCV implementation.
])

dist_params = dict(backend='nccl') # Parameters to setup distributed training, the port can also be set.
log_level = 'INFO' # The level of logging.
load_from = None # load models as a pre-trained model from a given path. This will not resume training.
Expand Down
7 changes: 5 additions & 2 deletions docs/zh_cn/tutorials/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,13 @@ data = dict(
]))
log_config = dict( # 注册日志钩 (register logger hook) 的配置文件。
interval=50, # 打印日志的间隔
hooks=[
hooks=[ # 训练期间执行的钩子
dict(type='TextLoggerHook', by_epoch=False),
dict(type='TensorboardLoggerHook', by_epoch=False),
dict(type='MMSegWandbHook', by_epoch=False, init_kwargs={'entity': entity, 'project': project, 'config': cfg_dict}), # 同样支持 Wandb 日志
dict(type='MMSegWandbHook', by_epoch=False, # 还支持 Wandb 记录器,它需要安装 `wandb`。
init_kwargs={'entity': "OpenMMLab", # 用于登录wandb的实体
'project': "mmseg", # WandB中的项目名称
'config': cfg_dict}), # 检查 https://docs.wandb.ai/ref/python/init 以获取更多初始化参数
])

dist_params = dict(backend='nccl') # 用于设置分布式训练的参数,端口也同样可被设置。
Expand Down
71 changes: 30 additions & 41 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ def parse_line(line):
if line.startswith('-r '):
# Allow specifying requirements in other files
target = line.split(' ')[1]
for info in parse_require_file(target):
yield info
yield from parse_require_file(target)
else:
info = {'line': line}
if line.startswith('-e '):
Expand All @@ -58,7 +57,6 @@ def parse_line(line):
pat = '(' + '|'.join(['>=', '==', '>']) + ')'
parts = re.split(pat, line, maxsplit=1)
parts = [p.strip() for p in parts]

info['package'] = parts[0]
if len(parts) > 1:
op, rest = parts[1:]
Expand All @@ -69,31 +67,30 @@ def parse_line(line):
rest.split(';'))
info['platform_deps'] = platform_deps
else:
version = rest # NOQA
info['version'] = (op, version)
version = rest
info['version'] = op, version
yield info

def parse_require_file(fpath):
with open(fpath, 'r') as f:
for line in f.readlines():
line = line.strip()
if line and not line.startswith('#'):
for info in parse_line(line):
yield info
yield from parse_line(line)

def gen_packages_items():
if exists(require_fpath):
for info in parse_require_file(require_fpath):
parts = [info['package']]
if with_version and 'version' in info:
parts.extend(info['version'])
if not sys.version.startswith('3.4'):
# apparently package_deps are broken in 3.4
platform_deps = info.get('platform_deps')
if platform_deps is not None:
parts.append(';' + platform_deps)
item = ''.join(parts)
yield item
if not exists(require_fpath):
return
for info in parse_require_file(require_fpath):
parts = [info['package']]
if with_version and 'version' in info:
parts.extend(info['version'])
if not sys.version.startswith('3.4'):
platform_deps = info.get('platform_deps')
if platform_deps is not None:
parts.append(f';{platform_deps}')
item = ''.join(parts)
yield item

packages = list(gen_packages_items())
return packages
Expand All @@ -110,35 +107,28 @@ def add_mim_extension():
# parse installment mode
if 'develop' in sys.argv:
# installed by `pip install -e .`
if platform.system() == 'Windows':
# set `copy` mode here since symlink fails on Windows.
mode = 'copy'
else:
mode = 'symlink'
elif 'sdist' in sys.argv or 'bdist_wheel' in sys.argv or \
platform.system() == 'Windows':
# set `copy` mode here since symlink fails on Windows.
mode = 'copy' if platform.system() == 'Windows' else 'symlink'
elif 'sdist' in sys.argv or 'bdist_wheel' in sys.argv or platform.system(
) == 'Windows':
# installed by `pip install .`
# or create source distribution by `python setup.py sdist`
# set `copy` mode here since symlink fails with WinError on Windows.
mode = 'copy'
else:
return

filenames = ['tools', 'configs', 'model-index.yml']
repo_path = osp.dirname(__file__)
mim_path = osp.join(repo_path, 'mmseg', '.mim')
os.makedirs(mim_path, exist_ok=True)

for filename in filenames:
if osp.exists(filename):
src_path = osp.join(repo_path, filename)
tar_path = osp.join(mim_path, filename)

if osp.isfile(tar_path) or osp.islink(tar_path):
os.remove(tar_path)
elif osp.isdir(tar_path):
shutil.rmtree(tar_path)

if mode == 'symlink':
src_relpath = osp.relpath(src_path, osp.dirname(tar_path))
try:
Expand All @@ -149,20 +139,19 @@ def add_mim_extension():
# the error happens, the src file will be copied
mode = 'copy'
warnings.warn(
f'Failed to create a symbolic link for {src_relpath}, '
f'and it will be copied to {tar_path}')
else:
continue
f'Failed to create a symbolic link for {src_relpath},'
f' and it will be copied to {tar_path}')

if mode == 'copy':
if osp.isfile(src_path):
shutil.copyfile(src_path, tar_path)
elif osp.isdir(src_path):
shutil.copytree(src_path, tar_path)
else:
warnings.warn(f'Cannot copy file {src_path}.')
else:
continue
if mode != 'copy':
raise ValueError(f'Invalid mode {mode}')
if osp.isfile(src_path):
shutil.copyfile(src_path, tar_path)
elif osp.isdir(src_path):
shutil.copytree(src_path, tar_path)
else:
warnings.warn(f'Cannot copy file {src_path}.')


if __name__ == '__main__':
Expand Down

0 comments on commit 31395a8

Please sign in to comment.