Skip to content

Commit

Permalink
Update: Edited to english in 'train.py' and 'generate.py'.
Browse files Browse the repository at this point in the history
  • Loading branch information
chairc committed Aug 8, 2023
1 parent 788138d commit f2e9d7c
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 26 deletions.
64 changes: 38 additions & 26 deletions tools/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,38 +25,43 @@


def generate(args):
"""
Generating
:param args: Input parameters
:return: None
"""
logger.info(msg="Start generation.")
# 是否启用条件生成
# Enable conditional generation
conditional = args.conditional
# 采样器类别
# Sample type
sample = args.sample
# 生成名称
# Generation name
generate_name = args.generate_name
# 图片大小
# Image size
image_size = args.image_size
# 图片个数
# Number of images
num_images = args.num_images
# 权重路径
# Weight path
weight_path = args.weight_path
# 保存路径
# Saving path
result_path = args.result_path
# 设备初始化
# Run device initializer
device = device_initializer()
# 初始化扩散模型
# Initialize the diffusion model
if sample == "ddpm":
diffusion = DDPMDiffusion(img_size=image_size, device=device)
elif sample == "ddim":
diffusion = DDIMDiffusion(img_size=image_size, device=device)
else:
diffusion = DDPMDiffusion(img_size=image_size, device=device)
logger.warning(msg=f"[{device}]: Setting sample error, we has been automatically set to ddpm.")
# 模型初始化
# Initialize model
if conditional:
# 类别个数
# Number of classes
num_classes = args.num_classes
# 生成的类别名称
# Generation class name
class_name = args.class_name
# classifier-free guidance插值权重
# classifier-free guidance interpolation weight
cfg_scale = args.cfg_scale
model = UNet(num_classes=num_classes, device=device, image_size=image_size).to(device)
load_model_weight_initializer(model=model, weight_path=weight_path, device=device, is_train=False)
Expand All @@ -66,7 +71,8 @@ def generate(args):
model = UNet(device=device, image_size=image_size).to(device)
load_model_weight_initializer(model=model, weight_path=weight_path, device=device, is_train=False)
x = diffusion.sample(model=model, n=num_images)
# 如果不存在路径信息则只展示;存在则保存到指定路径并展示
# If there is no path information, it will only be displayed
# If it exists, it will be saved to the specified path and displayed
if result_path == "" or result_path is None:
plot_images(images=x)
else:
Expand All @@ -76,30 +82,36 @@ def generate(args):


if __name__ == "__main__":
# 生成模型参数
# Generating model parameters
# required: Must be set
# needed: Set as needed
# recommend: Recommend to set
parser = argparse.ArgumentParser()
# 生成名称
# Generation name (required)
parser.add_argument("--generate_name", type=str, default="df")
# 输入图像大小
# Input image size (required)
parser.add_argument("--image_size", type=int, default=64)
# 生成图片个数
# Number of generation images (required)
parser.add_argument("--num_images", type=int, default=8)
# 模型路径
# Weight path (required)
parser.add_argument("--weight_path", type=str, default="/your/path/Defect-Diffusion-Model/weight/model.pt")
# 保存路径
# Saving path (required)
parser.add_argument("--result_path", type=str, default="/your/path/Defect-Diffusion-Model/results/vis")
# 开启条件生成,若使用False则不需要设置该参数之后的参数
# Enable conditional generation (required)
# If enabled, you can modify the custom configuration.
# For more details, please refer to the boundary line at the bottom.
parser.add_argument("--conditional", type=bool, default=True)
# 采样器类别(必须设置)
# 不设置是为ddpm,可设置ddpm,ddim
# Set the sample type (required)
# If not set, the default is for 'ddpm'. You can set it to either 'ddpm' or 'ddim'.
# Option: ddpm/ddim
parser.add_argument("--sample", type=str, default="ddpm")

# ==========================开启条件生成分界线==========================
# 类别个数
# Number of classes (required)
parser.add_argument("--num_classes", type=int, default=10)
# 类别名称
# Class name (required)
parser.add_argument("--class_name", type=int, default=0)
# classifier-free guidance插值权重,用户更好生成模型效果
# classifier-free guidance interpolation weight, users can better generate model effect (recommend)
parser.add_argument("--cfg_scale", type=int, default=3)

args = parser.parse_args()
Expand Down
6 changes: 6 additions & 0 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
def train(rank=None, args=None):
"""
Training
:param rank: Device id
:param args: Input parameters
:return: None
"""
Expand Down Expand Up @@ -294,6 +295,11 @@ def train(rank=None, args=None):


def main(args):
"""
Main function
:param args: Input parameters
:return: None
"""
if args.distributed:
gpus = torch.cuda.device_count()
mp.spawn(train, args=(args,), nprocs=gpus)
Expand Down

0 comments on commit f2e9d7c

Please sign in to comment.