Skip to content

Commit

Permalink
Add: Add the activation function in deploy.py.
Browse files Browse the repository at this point in the history
  • Loading branch information
chairc committed Sep 25, 2023
1 parent 7f88e91 commit 6529f80
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
2 changes: 1 addition & 1 deletion test/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def test_send_message(self):
Test local send message to deploy.py
:return: None
"""
test_json = {"conditional": True, "sample": "ddpm", "image_size": 64, "num_images": 2,
test_json = {"conditional": True, "sample": "ddpm", "image_size": 64, "num_images": 2, "act": "gelu",
"weight_path": "/your/test/model/path/test.pt",
"result_path": "/your/results/deploy",
"num_classes": 6, "class_name": 1, "cfg_scale": 3}
Expand Down
6 changes: 4 additions & 2 deletions tools/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def generate(parse_json_data):
image_size = parse_json_data["image_size"]
# Number of images
num_images = parse_json_data["num_images"] if parse_json_data["num_images"] >= 1 else 1
# Activation function
act = parse_json_data["act"]
# Weight path
weight_path = parse_json_data["weight_path"]
# Saving path
Expand All @@ -66,11 +68,11 @@ def generate(parse_json_data):
class_name = parse_json_data["class_name"]
# classifier-free guidance interpolation weight
cfg_scale = parse_json_data["cfg_scale"]
model = UNet(num_classes=num_classes, device=device, image_size=image_size).to(device)
model = UNet(num_classes=num_classes, device=device, image_size=image_size, act=act).to(device)
load_model_weight_initializer(model=model, weight_path=weight_path, device=device, is_train=False)
y = torch.Tensor([class_name]).long().to(device)
else:
model = UNet(device=device, image_size=image_size).to(device)
model = UNet(device=device, image_size=image_size, act=act).to(device)
load_model_weight_initializer(model=model, weight_path=weight_path, device=device, is_train=False)
y = None
cfg_scale = None
Expand Down

0 comments on commit 6529f80

Please sign in to comment.