Skip to content

Commit

Permalink
Update: Modify the resume training model loading and cancel the broadcast method.
Browse files Browse the repository at this point in the history
  • Loading branch information
chairc committed Jul 25, 2023
1 parent e58cd4e commit 19bc4d9
Showing 1 changed file with 11 additions and 29 deletions.
40 changes: 11 additions & 29 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ def train(rank=None, args=None):
dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo", rank=rank,
world_size=world_size)
# 设置设备ID
device = rank
torch.cuda.set_device(device=device)
device = torch.device("cuda", rank)
# 可能出现随机性错误,使用可减少cudnn随机性错误
# torch.backends.cudnn.deterministic = True
# 同步
Expand Down Expand Up @@ -122,33 +121,16 @@ def train(rank=None, args=None):
load_epoch = str(start_epoch - 1).zfill(3)
model_path = os.path.join(result_path, load_model_dir, f"model_{load_epoch}.pt")
optim_path = os.path.join(result_path, load_model_dir, f"optim_model_{load_epoch}.pt")
# 分布式恢复训练
if distributed:
# 使用主显卡
if dist.get_rank() == args.main_gpu:
model_weights_dict = torch.load(f=model_path)
model.load_state_dict(state_dict=model_weights_dict)
optim_weights_dict = torch.load(f=optim_path)
optimizer.load_state_dict(state_dict=optim_weights_dict)
logger.info(
msg=f"[{device}]: Successfully load model model_{load_epoch}.pt and optim_model_{load_epoch}.pt")
else:
NotImplementedError(
"Distributed computing loading model error, please check the main GPU configuration")
# 广播参数
dist.broadcast_object_list(object_list=[model, optimizer], src=args.main_gpu)
# 普通恢复训练
else:
model_dict = model.state_dict()
model_weights_dict = torch.load(f=model_path, map_location=device)
model_weights_dict = {k: v for k, v in model_weights_dict.items() if np.shape(model_dict[k]) == np.shape(v)}
model_dict.update(model_weights_dict)
model.load_state_dict(state_dict=OrderedDict(model_dict))
logger.info(msg=f"[{device}]: Successfully load model model_{load_epoch}.pt")
# 加载优化器参数
optim_weights_dict = torch.load(f=optim_path, map_location=device)
optimizer.load_state_dict(state_dict=optim_weights_dict)
logger.info(msg=f"[{device}]: Successfully load optimizer optim_model_{load_epoch}.pt")
model_dict = model.state_dict()
model_weights_dict = torch.load(f=model_path, map_location=device)
model_weights_dict = {k: v for k, v in model_weights_dict.items() if np.shape(model_dict[k]) == np.shape(v)}
model_dict.update(model_weights_dict)
model.load_state_dict(state_dict=OrderedDict(model_dict))
logger.info(msg=f"[{device}]: Successfully load model model_{load_epoch}.pt")
# 加载优化器参数
optim_weights_dict = torch.load(f=optim_path, map_location=device)
optimizer.load_state_dict(state_dict=optim_weights_dict)
logger.info(msg=f"[{device}]: Successfully load optimizer optim_model_{load_epoch}.pt")
else:
start_epoch = 0
if fp16:
Expand Down

0 comments on commit 19bc4d9

Please sign in to comment.