diff --git a/tools/generate.py b/tools/generate.py index e70df69..567c538 100644 --- a/tools/generate.py +++ b/tools/generate.py @@ -70,7 +70,11 @@ def generate(args): cfg_scale = args.cfg_scale model = UNet(num_classes=num_classes, device=device, image_size=image_size, act=act).to(device) load_model_weight_initializer(model=model, weight_path=weight_path, device=device, is_train=False) - y = torch.Tensor([class_name] * num_images).long().to(device) + if class_name == -1: + y = torch.arange(num_classes).long().to(device) + num_images = num_classes + else: + y = torch.Tensor([class_name] * num_images).long().to(device) x = diffusion.sample(model=model, n=num_images, labels=y, cfg_scale=cfg_scale) else: model = UNet(device=device, image_size=image_size, act=act).to(device) @@ -97,6 +101,7 @@ def generate(args): # Input image size (required) parser.add_argument("--image_size", type=int, default=64) # Number of generation images (required) + # if class name is `-1` and conditional `is` True, the model would output one image per class. parser.add_argument("--num_images", type=int, default=8) # Weight path (required) parser.add_argument("--weight_path", type=str, default="/your/path/Defect-Diffusion-Model/weight/model.pt") @@ -119,6 +124,7 @@ def generate(args): # Number of classes (required) parser.add_argument("--num_classes", type=int, default=10) # Class name (required) + # if class name is `-1`, the model would output one image per class. parser.add_argument("--class_name", type=int, default=0) # classifier-free guidance interpolation weight, users can better generate model effect (recommend) parser.add_argument("--cfg_scale", type=int, default=3)