Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions ldm/dream/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,13 @@ def _create_dream_cmd_parser(self):
type=str,
help='list of variations to apply, in the format `seed:weight,seed:weight,...'
)
render_group.add_argument(
'--use_mps_noise',
action='store_true',
dest='use_mps_noise',
help='Simulate noise on M1 systems to get the same results'
)

return parser

def format_metadata(**kwargs):
Expand Down
1 change: 1 addition & 0 deletions ldm/dream/generator/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(self, model, precision):
self.downsampling_factor = downsampling # BUG: should come from model or config
self.variation_amount = 0
self.with_variations = []
self.use_mps_noise = False

# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
def get_make_image(self,prompt,**kwargs):
Expand Down
2 changes: 1 addition & 1 deletion ldm/dream/generator/txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def make_image(x_T):
# returns a tensor filled with random numbers from a normal distribution
def get_noise(self,width,height):
device = self.model.device
if device.type == 'mps':
if self.use_mps_noise or device.type == 'mps':
x = torch.randn([1,
self.latent_channels,
height // self.downsampling_factor,
Expand Down
2 changes: 1 addition & 1 deletion ldm/dream/generator/txt2img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def get_noise(self,width,height,scale = True):
scaled_height = height

device = self.model.device
if device.type == 'mps':
if self.use_mps_noise or device.type == 'mps':
return torch.randn([1,
self.latent_channels,
scaled_height // self.downsampling_factor,
Expand Down
2 changes: 2 additions & 0 deletions ldm/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ def prompt2image(
# Set this True to handle KeyboardInterrupt internally
catch_interrupts = False,
hires_fix = False,
use_mps_noise = False,
**args,
): # eat up additional cruft
"""
Expand Down Expand Up @@ -386,6 +387,7 @@ def process_image(image,seed):

generator.set_variation(
self.seed, variation_amount, with_variations)
generator.use_mps_noise = use_mps_noise
results = generator.generate(
prompt,
iterations=iterations,
Expand Down