Skip to content

Commit

Permalink
Add: When adding conditional generation, all class images can be gene…
Browse files Browse the repository at this point in the history
…rated at once. Multiple images of a single class can also be generated.
  • Loading branch information
chairc committed Sep 26, 2023
1 parent 5b15de6 commit a4c77b9
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion tools/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,11 @@ def generate(args):
cfg_scale = args.cfg_scale
model = UNet(num_classes=num_classes, device=device, image_size=image_size, act=act).to(device)
load_model_weight_initializer(model=model, weight_path=weight_path, device=device, is_train=False)
y = torch.Tensor([class_name] * num_images).long().to(device)
if class_name == -1:
y = torch.arange(num_classes).long().to(device)
num_images = num_classes
else:
y = torch.Tensor([class_name] * num_images).long().to(device)
x = diffusion.sample(model=model, n=num_images, labels=y, cfg_scale=cfg_scale)
else:
model = UNet(device=device, image_size=image_size, act=act).to(device)
Expand All @@ -97,6 +101,7 @@ def generate(args):
# Input image size (required)
parser.add_argument("--image_size", type=int, default=64)
# Number of generation images (required)
# if class name is `-1` and conditional `is` True, the model would output one image per class.
parser.add_argument("--num_images", type=int, default=8)
# Weight path (required)
parser.add_argument("--weight_path", type=str, default="/your/path/Defect-Diffusion-Model/weight/model.pt")
Expand All @@ -119,6 +124,7 @@ def generate(args):
# Number of classes (required)
parser.add_argument("--num_classes", type=int, default=10)
# Class name (required)
# if class name is `-1`, the model would output one image per class.
parser.add_argument("--class_name", type=int, default=0)
# classifier-free guidance interpolation weight, users can better generate model effect (recommend)
parser.add_argument("--cfg_scale", type=int, default=3)
Expand Down

0 comments on commit a4c77b9

Please sign in to comment.