diff --git a/docs/features/CLI.md b/docs/features/CLI.md index 85524f6fa9b..281f8375780 100644 --- a/docs/features/CLI.md +++ b/docs/features/CLI.md @@ -153,6 +153,7 @@ Here are the invoke> command that apply to txt2img: | --cfg_scale | -C | 7.5 | How hard to try to match the prompt to the generated image; any number greater than 1.0 works, but the useful range is roughly 5.0 to 20.0 | | --seed | -S | None | Set the random seed for the next series of images. This can be used to recreate an image generated previously.| | --sampler | -A| k_lms | Sampler to use. Use -h to get list of available samplers. | +| --karras_max | | 29 | When using k_* samplers, set the maximum number of steps before shifting from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts) This value is sticky. [29] | | --hires_fix | | | Larger images often have duplication artefacts. This option suppresses duplicates by generating the image at low res, and then using img2img to increase the resolution | | --png_compression <0-9> | -z<0-9> | 6 | Select level of compression for output files, from 0 (no compression) to 9 (max compression) | | --grid | -g | False | Turn on grid mode to return a single image combining all the images generated by this prompt | diff --git a/ldm/generate.py b/ldm/generate.py index 135ec9ca545..18d62ca24a5 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -176,6 +176,7 @@ def __init__( self.free_gpu_mem = free_gpu_mem self.size_matters = True # used to warn once about large image sizes and VRAM self.txt2mask = None + self.karras_max = None # Note that in previous versions, there was an option to pass the # device to Generate(). However the device was then ignored, so @@ -253,6 +254,7 @@ def prompt2image( variation_amount = 0.0, threshold = 0.0, perlin = 0.0, + karras_max = None, # these are specific to img2img and inpaint init_img = None, init_mask = None, @@ -331,7 +333,8 @@ def process_image(image,seed): strength = strength or self.strength self.seed = seed self.log_tokenization = log_tokenization - self.step_callback = step_callback + self.step_callback = step_callback + self.karras_max = karras_max with_variations = [] if with_variations is None else with_variations # will instantiate the model or return it from cache @@ -376,6 +379,11 @@ def process_image(image,seed): self.sampler_name = sampler_name self._set_sampler() + # bit of a hack to change the cached sampler's karras threshold to + # whatever the user asked for + if karras_max is not None and isinstance(self.sampler,KSampler): + self.sampler.adjust_settings(karras_max=karras_max) + tic = time.time() if self._has_cuda(): torch.cuda.reset_peak_memory_stats() @@ -815,26 +823,23 @@ def sample_to_image(self, samples): def _set_sampler(self): msg = f'>> Setting Sampler to {self.sampler_name}' + karras_max = self.karras_max # set in generate() call if self.sampler_name == 'plms': self.sampler = PLMSSampler(self.model, device=self.device) elif self.sampler_name == 'ddim': self.sampler = DDIMSampler(self.model, device=self.device) elif self.sampler_name == 'k_dpm_2_a': - self.sampler = KSampler( - self.model, 'dpm_2_ancestral', device=self.device - ) + self.sampler = KSampler(self.model, 'dpm_2_ancestral', device=self.device, karras_max=karras_max) elif self.sampler_name == 'k_dpm_2': - self.sampler = KSampler(self.model, 'dpm_2', device=self.device) + self.sampler = KSampler(self.model, 'dpm_2', device=self.device, karras_max=karras_max) elif self.sampler_name == 'k_euler_a': - self.sampler = KSampler( - self.model, 'euler_ancestral', device=self.device - ) + self.sampler = KSampler(self.model, 'euler_ancestral', device=self.device, karras_max=karras_max) elif self.sampler_name == 'k_euler': - self.sampler = KSampler(self.model, 'euler', device=self.device) + self.sampler = KSampler(self.model, 'euler', device=self.device, karras_max=karras_max) elif self.sampler_name == 'k_heun': - self.sampler = KSampler(self.model, 'heun', device=self.device) + self.sampler = KSampler(self.model, 'heun', device=self.device, karras_max=karras_max) elif self.sampler_name == 'k_lms': - self.sampler = KSampler(self.model, 'lms', device=self.device) + self.sampler = KSampler(self.model, 'lms', device=self.device, karras_max=karras_max) else: msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to plms' self.sampler = PLMSSampler(self.model, device=self.device) diff --git a/ldm/invoke/args.py b/ldm/invoke/args.py index e2302e4452b..2b00b5a9ceb 100644 --- a/ldm/invoke/args.py +++ b/ldm/invoke/args.py @@ -216,6 +216,8 @@ def dream_prompt_str(self,**kwargs): switches.append(f'-W {a["width"]}') switches.append(f'-H {a["height"]}') switches.append(f'-C {a["cfg_scale"]}') + if a['karras_max'] is not None: + switches.append(f'--karras_max {a["karras_max"]}') if a['perlin'] > 0: switches.append(f'--perlin {a["perlin"]}') if a['threshold'] > 0: @@ -669,7 +671,13 @@ def _create_dream_cmd_parser(self): default=6, choices=range(0,10), dest='png_compression', - help='level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.' + help='level of PNG compression, from 0 (none) to 9 (maximum). [6]' + ) + render_group.add_argument( + '--karras_max', + type=int, + default=None, + help="control the point at which the K* samplers will shift from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts). Set to 0 to use LatentDiffusion for all step values, and to a high value (e.g. 1000) to use Karras for all step values. [29]." ) img2img_group.add_argument( '-I', diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index ac0615b30cc..1693baade5b 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -12,6 +12,10 @@ extract_into_tensor, ) +# at this threshold, the scheduler will stop using the Karras +# noise schedule and start using the model's schedule +STEP_THRESHOLD = 29 + def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7): if threshold <= 0.0: return result @@ -60,6 +64,9 @@ def __init__(self, model, schedule='lms', device=None, **kwargs): self.sigmas = None self.ds = None self.s_in = None + self.karras_max = kwargs.get('karras_max',STEP_THRESHOLD) + if self.karras_max is None: + self.karras_max = STEP_THRESHOLD def forward(self, x, sigma, uncond, cond, cond_scale): x_in = torch.cat([x] * 2) @@ -98,8 +105,13 @@ def make_schedule( rho=7., device=self.device, ) - self.sigmas = self.model_sigmas - #self.sigmas = self.karras_sigmas + + if ddim_num_steps >= self.karras_max: + print(f'>> Ksampler using model noise schedule (steps > {self.karras_max})') + self.sigmas = self.model_sigmas + else: + print(f'>> Ksampler using karras noise schedule (steps <= {self.karras_max})') + self.sigmas = self.karras_sigmas # ALERT: We are completely overriding the sample() method in the base class, which # means that inpainting will not work. To get this to work we need to be able to diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py index ff705513f87..01193ed5a55 100644 --- a/ldm/models/diffusion/sampler.py +++ b/ldm/models/diffusion/sampler.py @@ -2,8 +2,8 @@ ldm.models.diffusion.sampler Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc - ''' + import torch import numpy as np from tqdm import tqdm @@ -411,3 +411,15 @@ def q_sample(self,x0,ts): return self.model.inner_model.q_sample(x0,ts) ''' return self.model.q_sample(x0,ts) + + def adjust_settings(self,**kwargs): + ''' + This is a catch-all method for adjusting any instance variables + after the sampler is instantiated. No type-checking performed + here, so use with care! + ''' + for k in kwargs.keys(): + try: + setattr(self,k,kwargs[k]) + except AttributeError: + print(f'** Warning: attempt to set unknown attribute {k} in sampler of type {type(self)}')