206 changes: 103 additions & 103 deletions ldm/generate.py
@@ -17,7 +17,7 @@
from omegaconf import OmegaConf
from PIL import Image, ImageOps
from torch import nn
-from pytorch_lightning import seed_everything
+from pytorch_lightning import seed_everything, logging

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
@@ -35,7 +35,7 @@
from ldm.generate import Generate

# Create an object with default values
-gr = Generate()
+gr = Generate('stable-diffusion-1.4')

# do the slow model initialization
gr.load_model()
@@ -79,16 +79,17 @@

The full list of arguments to Generate() is:
gr = Generate(
-iterations = <integer> // how many times to run the sampling (1)
-steps = <integer> // 50
-seed = <integer> // current system time
-sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
-grid = <boolean> // false
-width = <integer> // image width, multiple of 64 (512)
-height = <integer> // image height, multiple of 64 (512)
-cfg_scale = <float> // condition-free guidance scale (7.5)
-config = path to model configuraiton ('configs/stable-diffusion/v1-inference.yaml')
+# these values are set once and shouldn't be changed
+conf = path to configuration file ('configs/models.yaml')
+model = symbolic name of the model in the configuration file
+full_precision = False
+
+# this value is sticky and maintained between generation calls
+sampler_name = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
+
+# these are deprecated - use conf and model instead
+weights = path to model weights ('models/ldm/stable-diffusion-v1/model.ckpt')
+config = path to model configuration ('configs/stable-diffusion/v1-inference.yaml')
)

"""
@@ -101,66 +102,62 @@ class Generate:

def __init__(
self,
-iterations = 1,
-steps = 50,
-cfg_scale = 7.5,
-weights = 'models/ldm/stable-diffusion-v1/model.ckpt',
-config = 'configs/stable-diffusion/v1-inference.yaml',
-grid = False,
-width = 512,
-height = 512,
+model = 'stable-diffusion-1.4',
+conf = 'configs/models.yaml',
+embedding_path = None,
sampler_name = 'k_lms',
ddim_eta = 0.0, # deterministic
full_precision = False,
strength = 0.75, # default in scripts/img2img.py
seamless = False,
-embedding_path = None,
device_type = 'cuda',
-ignore_ctrl_c = False,
+# these are deprecated; if present they override values in the conf file
+weights = None,
+config = None,
):
-self.iterations = iterations
-self.width = width
-self.height = height
-self.steps = steps
-self.cfg_scale = cfg_scale
-self.weights = weights
-self.config = config
-self.sampler_name = sampler_name
-self.grid = grid
-self.ddim_eta = ddim_eta
-self.full_precision = True if choose_torch_device() == 'mps' else full_precision
-self.strength = strength
-self.seamless = seamless
-self.embedding_path = embedding_path
-self.device_type = device_type
-self.ignore_ctrl_c = ignore_ctrl_c # note, this logic probably doesn't belong here...
-self.model = None # empty for now
-self.sampler = None
-self.device = None
-self.generators = {}
-self.base_generator = None
-self.seed = None

-if device_type == 'cuda' and not torch.cuda.is_available():
-device_type = choose_torch_device()
-print(">> cuda not available, using device", device_type)
+models = OmegaConf.load(conf)
+mconfig = models[model]
+self.weights = mconfig.weights if weights is None else weights
+self.config = mconfig.config if config is None else config
+self.height = mconfig.height
+self.width = mconfig.width
+self.iterations = 1
+self.steps = 50
+self.cfg_scale = 7.5
+self.sampler_name = sampler_name
+self.ddim_eta = 0.0 # same seed always produces same image
+self.full_precision = True if choose_torch_device() == 'mps' else full_precision
+self.strength = 0.75
+self.seamless = False
+self.embedding_path = embedding_path
+self.model = None # empty for now
+self.sampler = None
+self.device = None
+self.session_peakmem = None
+self.generators = {}
+self.base_generator = None
+self.seed = None

+# Note that in previous versions, there was an option to pass the
+# device to Generate(). However the device was then ignored, so
+# it wasn't actually doing anything. This logic could be reinstated.
+device_type = choose_torch_device()
+self.device = torch.device(device_type)

# for VRAM usage statistics
-device_type = choose_torch_device()
-self.session_peakmem = torch.cuda.max_memory_allocated() if device_type == 'cuda' else None
+self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda() else None
transformers.logging.set_verbosity_error()

+# gets rid of annoying messages about random seed
+logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)

def prompt2png(self, prompt, outdir, **kwargs):
"""
Takes a prompt and an output directory, writes out the requested number
of PNG files, and returns an array of [[filename,seed],[filename,seed]...]
Optional named arguments are the same as those passed to Generate and prompt2image()
"""
results = self.prompt2image(prompt, **kwargs)
pngwriter = PngWriter(outdir)
prefix = pngwriter.unique_prefix()
outputs = []
for image, seed in results:
name = f'{prefix}.{seed}.png'
path = pngwriter.save_image_and_prompt_to_png(
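A minimal usage sketch for prompt2png(), following its docstring (the prompt text and output directory are illustrative; iterations is passed through **kwargs to prompt2image()):

    gr = Generate('stable-diffusion-1.4')
    results = gr.prompt2png('a sunny meadow', outdir='outputs/', iterations=2)
    for filename, seed in results:
        print(f'{filename} (seed {seed})')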
@@ -183,33 +180,35 @@ def prompt2image(
self,
# these are common
prompt,
iterations = None,
steps = None,
seed = None,
cfg_scale = None,
ddim_eta = None,
skip_normalize = False,
image_callback = None,
step_callback = None,
width = None,
height = None,
sampler_name = None,
seamless = False,
log_tokenization = False,
with_variations = None,
variation_amount = 0.0,
# these are specific to img2img and inpaint
init_img = None,
init_mask = None,
fit = False,
strength = None,
# these are specific to embiggen (which also relies on img2img args)
embiggen = None,
embiggen_tiles = None,
# these are specific to GFPGAN/ESRGAN
gfpgan_strength = 0,
save_original = False,
upscale = None,
+# Set this True to handle KeyboardInterrupt internally
+catch_interrupts = False,
**args,
): # eat up additional cruft
"""
@@ -262,10 +261,9 @@ def process_image(image,seed):
self.log_tokenization = log_tokenization
with_variations = [] if with_variations is None else with_variations

-model = (
-self.load_model()
-) # will instantiate the model or return it from cache

+# will instantiate the model or return it from cache
+model = self.load_model()

for m in model.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
m.padding_mode = 'circular' if seamless else m._orig_padding_mode
@@ -281,7 +279,6 @@ def process_image(image,seed):
(embiggen == None and embiggen_tiles == None) or ((embiggen != None or embiggen_tiles != None) and init_img != None)
), 'Embiggen requires an init/input image to be specified'

-# check this logic - doesn't look right
if len(with_variations) > 0 or variation_amount > 1.0:
assert seed is not None,\
'seed must be specified when using with_variations'
@@ -298,7 +295,7 @@ def process_image(image,seed):
self._set_sampler()

tic = time.time()
-if torch.cuda.is_available():
+if self._has_cuda():
torch.cuda.reset_peak_memory_stats()

results = list()
@@ -307,9 +304,9 @@ def process_image(image,seed):

try:
uc, c = get_uc_and_c(
-prompt, model=self.model,
+prompt, model =self.model,
skip_normalize=skip_normalize,
-log_tokens=self.log_tokenization
+log_tokens =self.log_tokenization
)

(init_image,mask_image) = self._make_images(init_img,init_mask, width, height, fit)
@@ -352,27 +349,25 @@ def process_image(image,seed):
save_original = save_original,
image_callback = image_callback)

-except KeyboardInterrupt:
-print('*interrupted*')
-if not self.ignore_ctrl_c:
-raise KeyboardInterrupt
-print(
-'>> Partial results will be returned; if --grid was requested, nothing will be returned.'
-)
except RuntimeError as e:
print(traceback.format_exc(), file=sys.stderr)
print('>> Could not generate image.')
+except KeyboardInterrupt:
+if catch_interrupts:
+print('**Interrupted** Partial results will be returned.')
+else:
+raise KeyboardInterrupt

toc = time.time()
print('>> Usage stats:')
print(
f'>> {len(results)} image(s) generated in', '%4.2fs' % (toc - tic)
)
-if torch.cuda.is_available() and self.device.type == 'cuda':
+if self._has_cuda():
print(
f'>> Max VRAM used for this generation:',
'%4.2fG.' % (torch.cuda.max_memory_allocated() / 1e9),
-'Current VRAM utilization:'
+'Current VRAM utilization:',
'%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
)

@@ -439,8 +434,7 @@ def load_model(self):
if self.model is None:
seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
try:
-config = OmegaConf.load(self.config)
-model = self._load_model_from_config(config, self.weights)
+model = self._load_model_from_config(self.config, self.weights)
if self.embedding_path is not None:
model.embedding_manager.load(
self.embedding_path, self.full_precision
@@ -541,8 +535,11 @@ def _set_sampler(self):

print(msg)

-def _load_model_from_config(self, config, ckpt):
-print(f'>> Loading model from {ckpt}')
+# Be warned: config is the path to the model config file, not the dream conf file!
+# Also note that we can get config and weights from self, so why do we need to
+# pass them as args?
+def _load_model_from_config(self, config, weights):
+print(f'>> Loading model from {weights}')

# for usage statistics
device_type = choose_torch_device()
@@ -551,10 +548,11 @@ def _load_model_from_config(self, config, ckpt):
tic = time.time()

# this does the work
-pl_sd = torch.load(ckpt, map_location='cpu')
-sd = pl_sd['state_dict']
-model = instantiate_from_config(config.model)
-m, u = model.load_state_dict(sd, strict=False)
+c = OmegaConf.load(config)
+pl_sd = torch.load(weights, map_location='cpu')
+sd = pl_sd['state_dict']
+model = instantiate_from_config(c.model)
+m, u = model.load_state_dict(sd, strict=False)

if self.full_precision:
print(
@@ -573,7 +571,7 @@ def _load_model_from_config(self, config, ckpt):
print(
f'>> Model loaded in', '%4.2fs' % (toc - tic)
)
-if device_type == 'cuda':
+if self._has_cuda():
print(
'>> Max VRAM used to load the model:',
'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
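To make the calling convention concrete, a short sketch of how load_model() above drives this helper (the paths shown are the defaults from the module docstring):

    # inside load_model():
    model = self._load_model_from_config(self.config, self.weights)
    # e.g. self.config  -> 'configs/stable-diffusion/v1-inference.yaml' (model config, not models.yaml)
    #      self.weights -> 'models/ldm/stable-diffusion-v1/model.ckpt'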
@@ -710,3 +708,5 @@ def _resolution_check(self, width, height, log=False):
return width, height, resize_needed


+def _has_cuda(self):
+return self.device.type == 'cuda'
2 changes: 1 addition & 1 deletion ldm/models/diffusion/ddim.py
@@ -225,7 +225,7 @@ def ddim_sampling(
total_steps = (
timesteps if ddim_use_original_steps else timesteps.shape[0]
)
-print(f'Running DDIM Sampling with {total_steps} timesteps')
+print(f'\nRunning DDIM Sampling with {total_steps} timesteps')

iterator = tqdm(
time_range,