Skip to content

Commit

Permalink
Add: Add seed initializer.
Browse files Browse the repository at this point in the history
  • Loading branch information
chairc committed Jul 31, 2023
1 parent 19bc4d9 commit 9189913
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
6 changes: 5 additions & 1 deletion tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from model.ddpm import Diffusion
from model.modules import EMA
from model.network import UNet
from utils.initializer import device_initializer
from utils.initializer import device_initializer, seed_initializer
from utils.utils import plot_images, save_images, get_dataset, setup_logging

logger = logging.getLogger(__name__)
Expand All @@ -40,6 +40,8 @@ def train(rank=None, args=None):
:return: None
"""
logger.info(msg=f"[{rank}]: Input params: {args}")
# 初始化种子
seed_initializer(seed_id=args.seed)
# 运行名称
run_name = args.run_name
# 输入图像大小
Expand Down Expand Up @@ -282,6 +284,8 @@ def main(args):
if __name__ == "__main__":
# 训练模型参数
parser = argparse.ArgumentParser()
# 设置初始化种子(必须设置)
parser.add_argument("--seed", type=int, default=0)
# 开启条件训练(必须设置)
# 若开启可修改自定义配置,详情参考最下面分界线
parser.add_argument("--conditional", type=bool, default=False)
Expand Down
17 changes: 17 additions & 0 deletions utils/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
@Author : chairc
@Site : https://github.com/chairc
"""
import random
import numpy as np
import torch
import logging
import coloredlogs
Expand Down Expand Up @@ -38,3 +40,18 @@ def device_initializer():
logger.warning(msg="The device is using cpu.")
device = torch.device(device="cpu")
return device


def seed_initializer(seed_id=0):
"""
初始化种子
:param seed_id: 种子id
:return:
"""
torch.manual_seed(seed_id)
torch.cuda.manual_seed_all(seed_id)
random.seed(seed_id)
np.random.seed(seed_id)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
logger.info(msg=f"The seed is initialized, and the seed ID is {seed_id}.")

0 comments on commit 9189913

Please sign in to comment.