unify dist download (PaddlePaddle#3867)
heavengate committed Aug 3, 2021
1 parent b63fe62 commit 73bbc91
Showing 2 changed files with 65 additions and 42 deletions.
41 changes: 3 additions & 38 deletions ppdet/utils/checkpoint.py
@@ -55,41 +55,6 @@ def _get_unique_endpoints(trainer_endpoints):
return unique_endpoints


def get_weights_path_dist(path):
env = os.environ
if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env:
trainer_id = int(env['PADDLE_TRAINER_ID'])
num_trainers = int(env['PADDLE_TRAINERS_NUM'])
if num_trainers <= 1:
path = get_weights_path(path)
else:
from ppdet.utils.download import map_path, WEIGHTS_HOME
weight_path = map_path(path, WEIGHTS_HOME)
lock_path = weight_path + '.lock'
if not os.path.exists(weight_path):
from paddle.distributed import ParallelEnv
unique_endpoints = _get_unique_endpoints(ParallelEnv()
.trainer_endpoints[:])
try:
os.makedirs(os.path.dirname(weight_path))
except OSError as e:
if e.errno != errno.EEXIST:
raise
with open(lock_path, 'w'): # touch
os.utime(lock_path, None)
if ParallelEnv().current_endpoint in unique_endpoints:
get_weights_path(path)
os.remove(lock_path)
else:
while os.path.exists(lock_path):
time.sleep(1)
path = weight_path
else:
path = get_weights_path(path)

return path


def _strip_postfix(path):
path, ext = os.path.splitext(path)
assert ext in ['', '.pdparams', '.pdopt', '.pdmodel'], \
@@ -99,7 +64,7 @@ def _strip_postfix(path):

def load_weight(model, weight, optimizer=None):
if is_url(weight):
weight = get_weights_path_dist(weight)
weight = get_weights_path(weight)

path = _strip_postfix(weight)
pdparam_path = path + '.pdparams'
@@ -205,7 +170,7 @@ def match(a, b):

def load_pretrain_weight(model, pretrain_weight):
if is_url(pretrain_weight):
pretrain_weight = get_weights_path_dist(pretrain_weight)
pretrain_weight = get_weights_path(pretrain_weight)

path = _strip_postfix(pretrain_weight)
if not (os.path.isdir(path) or os.path.isfile(path) or
@@ -251,4 +216,4 @@ def save_model(model, optimizer, save_dir, save_name, last_epoch):
state_dict = optimizer.state_dict()
state_dict['last_epoch'] = last_epoch
paddle.save(state_dict, save_path + ".pdopt")
logger.info("Save checkpoint: {}".format(save_dir))
logger.info("Save checkpoint: {}".format(save_dir))
66 changes: 62 additions & 4 deletions ppdet/utils/download.py
@@ -20,6 +20,7 @@
import os.path as osp
import sys
import yaml
import time
import shutil
import requests
import tqdm
@@ -29,6 +30,7 @@
import tarfile
import zipfile

from paddle.utils.download import _get_unique_endpoints
from ppdet.core.workspace import BASE_KEY
from .logger import setup_logger
from .voc_utils import create_list
@@ -144,8 +146,8 @@ def get_config_path(url):
cfg_url = parse_url(cfg_url)

# 3. download and decompress
cfg_fullname = _download(cfg_url, osp.dirname(CONFIGS_HOME))
_decompress(cfg_fullname)
cfg_fullname = _download_dist(cfg_url, osp.dirname(CONFIGS_HOME))
_decompress_dist(cfg_fullname)

# 4. check config file existing
if os.path.isfile(path):
@@ -281,12 +283,12 @@ def get_path(url, root_dir, md5sum=None, check_exist=True):
else:
os.remove(fullpath)

fullname = _download(url, root_dir, md5sum)
fullname = _download_dist(url, root_dir, md5sum)

# new weights format which postfix is 'pdparams' not
# need to decompress
if osp.splitext(fullname)[-1] not in ['.pdparams', '.yml']:
_decompress(fullname)
_decompress_dist(fullname)

return fullpath, False

@@ -381,6 +383,38 @@ def _download(url, path, md5sum=None):
return fullname


def _download_dist(url, path, md5sum=None):
env = os.environ
if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env:
trainer_id = int(env['PADDLE_TRAINER_ID'])
num_trainers = int(env['PADDLE_TRAINERS_NUM'])
if num_trainers <= 1:
return _download(url, path, md5sum)
else:
fname = osp.split(url)[-1]
fullname = osp.join(path, fname)
lock_path = fullname + '.download.lock'

if not osp.isdir(path):
os.makedirs(path)

if not osp.exists(fullname):
from paddle.distributed import ParallelEnv
unique_endpoints = _get_unique_endpoints(ParallelEnv()
.trainer_endpoints[:])
with open(lock_path, 'w'): # touch
os.utime(lock_path, None)
if ParallelEnv().current_endpoint in unique_endpoints:
_download(url, path, md5sum)
os.remove(lock_path)
else:
while os.path.exists(lock_path):
time.sleep(1)
return fullname
else:
return _download(url, path, md5sum)


def _check_exist_file_md5(filename, md5sum, url):
# if md5sum is None, and file to check is weights file,
# read md5um from url and check, else check md5sum directly
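
The new _download_dist serializes the download with a plain lock file: on each machine, only the first unique endpoint downloads, while the remaining trainers poll until the lock file disappears. A stripped-down sketch of that pattern, independent of PaddleDetection (run_once_per_machine is a made-up name for illustration):

    import os
    import time

    def run_once_per_machine(done_path, lock_path, do_work, is_leader):
        # Mirrors _download_dist/_decompress_dist above: every process touches
        # the lock, the leader does the work and removes the lock, the rest
        # wait until the lock file is gone.
        if os.path.exists(done_path):
            return
        with open(lock_path, 'w'):  # touch
            os.utime(lock_path, None)
        if is_leader:
            do_work()
            os.remove(lock_path)
        else:
            while os.path.exists(lock_path):
                time.sleep(1)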
@@ -458,6 +492,30 @@ def _decompress(fname):
os.remove(fname)


def _decompress_dist(fname):
env = os.environ
if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env:
trainer_id = int(env['PADDLE_TRAINER_ID'])
num_trainers = int(env['PADDLE_TRAINERS_NUM'])
if num_trainers <= 1:
_decompress(fname)
else:
lock_path = fname + '.decompress.lock'
from paddle.distributed import ParallelEnv
unique_endpoints = _get_unique_endpoints(ParallelEnv()
.trainer_endpoints[:])
with open(lock_path, 'w'): # touch
os.utime(lock_path, None)
if ParallelEnv().current_endpoint in unique_endpoints:
_decompress(fname)
os.remove(lock_path)
else:
while os.path.exists(lock_path):
time.sleep(1)
else:
_decompress(fname)


def _move_and_merge_tree(src, dst):
"""
Move src directory to dst, if dst is already exists,
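
Both helpers fall back to the single-process path when PADDLE_TRAINERS_NUM / PADDLE_TRAINER_ID are unset or report a single trainer, so non-distributed runs are unaffected. An illustrative check of how that branch is taken (not from this commit):

    import os

    env = os.environ
    distributed = ('PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env
                   and int(env['PADDLE_TRAINERS_NUM']) > 1)
    # When distributed is False, _download_dist/_decompress_dist simply call
    # _download/_decompress, exactly as before this commit.
    print(distributed)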
