Commit

Added cpu fallback for the rng generator and fixed #20.
dbolya committed May 14, 2023
1 parent 2e37b87 commit dc08c1b
Showing 3 changed files with 17 additions and 7 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -17,4 +17,6 @@
 
 ## v0.1.3
 - **[2023.04.24]** Random perturbations now use a separate rng so it doesn't affect the rest of the diffusion process. Thanks @alihassanijr!
-- **[2023.04.25]** Fixed an issue with the separate rng on mps devices. (Fixes #27)
+- **[2023.04.25]** Fixed an issue with the separate rng on mps devices. (Fixes #27)
+- **[2023.05.14]** Added fallback to CPU for non-supported devices for the separate rng generator.
+- **[2023.05.14]** Defined `use_ada_layer_norm_zero` just in case for older diffuser versions. (Fixes #20)
9 changes: 6 additions & 3 deletions tomesd/patch.py
@@ -23,9 +23,7 @@ def compute_merge(x: torch.Tensor, tome_info: Dict[str, Any]) -> Tuple[Callable,
     if args["generator"] is None:
         args["generator"] = init_generator(x.device)
     elif args["generator"].device != x.device:
-        # MPS can use a cpu generator
-        if not (args["generator"].device.type == "cpu" and x.device.type == "mps"):
-            args["generator"] = init_generator(x.device)
+        args["generator"] = init_generator(x.device, fallback=args["generator"])
 
     # If the batch size is odd, then it's not possible for prompted and unprompted images to be in the same
     # batch, which causes artifacts with use_rand, so force it to be off.
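
An illustrative sketch (not part of the commit) of what the new branch does, assuming `tomesd` is importable; the tensor shape is arbitrary:

import torch
from tomesd.utils import init_generator

x = torch.randn(2, 4096, 320)        # stand-in for a latent tensor
gen = torch.Generator(device="cpu")  # a generator created earlier

# Mirrors the patched compute_merge branch: when the generator's device no
# longer matches the tensor's, fork a new generator, passing the old one as
# a fallback so devices without native generator support (e.g. mps) reuse
# it instead of hitting NotImplementedError.
if gen.device != x.device:
    gen = init_generator(x.device, fallback=gen)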
@@ -253,6 +251,11 @@ def apply_patch(
         if not hasattr(module, "disable_self_attn") and not is_diffusers:
             module.disable_self_attn = False
 
+        # Something needed for older versions of diffusers
+        if not hasattr(module, "use_ada_layer_norm_zero") and is_diffusers:
+            module.use_ada_layer_norm = False
+            module.use_ada_layer_norm_zero = False
+
     return model


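A hedged sketch of why the shim above helps (the class below is a hypothetical stand-in, not diffusers code): ToMe's patched block forward reads `use_ada_layer_norm_zero`, which older diffusers blocks never define, so accessing it would raise AttributeError. Backfilling defaults keeps one code path working on both versions.

class OldDiffusersBlock:  # hypothetical stand-in for an old BasicTransformerBlock
    pass

module = OldDiffusersBlock()

# Backfill the attributes the patched forward expects, as the commit does.
if not hasattr(module, "use_ada_layer_norm_zero"):
    module.use_ada_layer_norm = False
    module.use_ada_layer_norm_zero = False

assert module.use_ada_layer_norm_zero is False  # now safe to read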
11 changes: 8 additions & 3 deletions tomesd/utils.py
@@ -16,12 +16,17 @@ def isinstance_str(x: object, cls_name: str):
     return False
 
 
-def init_generator(device: torch.device):
+def init_generator(device: torch.device, fallback: torch.Generator=None):
     """
     Forks the current default random generator given device.
     """
-    if device.type == "cpu" or device.type == "mps": # MPS can use a cpu generator
+    if device.type == "cpu":
         return torch.Generator(device="cpu").set_state(torch.get_rng_state())
     elif device.type == "cuda":
         return torch.Generator(device=device).set_state(torch.cuda.get_rng_state())
-    raise NotImplementedError(f"Invalid/unsupported device. Expected `cpu`, `cuda`, or `mps`, got {device.type}.")
+    else:
+        if fallback is None:
+            return init_generator(torch.device("cpu"))
+        else:
+            return fallback
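
A short usage sketch of the new fallback behavior (assumes `tomesd` is importable; the `mps` call works even without an MPS build, since only the device type is inspected):

import torch
from tomesd.utils import init_generator

# cpu: forks the default cpu rng state into a fresh generator
g_cpu = init_generator(torch.device("cpu"))

# cuda: forks the current cuda rng state instead (guarded for portability)
if torch.cuda.is_available():
    g_cuda = init_generator(torch.device("cuda"))

# any other backend (e.g. mps): with no fallback, a fresh cpu generator is
# returned; with a fallback, that generator comes back unchanged
g_existing = torch.Generator(device="cpu")
assert init_generator(torch.device("mps"), fallback=g_existing) is g_existing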
