Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 61 additions & 50 deletions scripts/dream.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,24 @@
# Just want to get the formatting look right for now.
output_cntr = 0


def main():
"""Initialize command-line parsers and the diffusion model"""
arg_parser = create_argv_parser()
opt = arg_parser.parse_args()

if opt.laion400m:
print('--laion400m flag has been deprecated. Please use --model laion400m instead.')
sys.exit(-1)
if opt.weights != 'model':
print('--weights argument has been deprecated. Please configure ./configs/models.yaml, and call it using --model instead.')
sys.exit(-1)

try:
models = OmegaConf.load(opt.config)
width = models[opt.model].width
height = models[opt.model].height
config = models[opt.model].config
models = OmegaConf.load(opt.config)
width = models[opt.model].width
height = models[opt.model].height
config = models[opt.model].config
weights = models[opt.model].weights
except (FileNotFoundError, IOError, KeyError) as e:
print(f'{e}. Aborting.')
Expand All @@ -58,18 +59,18 @@ def main():
# additional parameters will be added (or overriden) during
# the user input loop
t2i = Generate(
width = width,
height = height,
sampler_name = opt.sampler_name,
weights = weights,
full_precision = opt.full_precision,
config = config,
grid = opt.grid,
width=width,
height=height,
sampler_name=opt.sampler_name,
weights=weights,
full_precision=opt.full_precision,
config=config,
grid=opt.grid,
# this is solely for recreating the prompt
seamless = opt.seamless,
embedding_path = opt.embedding_path,
device_type = opt.device,
ignore_ctrl_c = opt.infile is None,
seamless=opt.seamless,
embedding_path=opt.embedding_path,
device_type=opt.device,
ignore_ctrl_c=opt.infile is None,
)

# make sure the output directory exists
Expand Down Expand Up @@ -113,8 +114,8 @@ def main():

def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
"""prompt/read/execute loop"""
done = False
path_filter = re.compile(r'[<>:"/\\|?*]')
done = False
path_filter = re.compile(r'[<>:"/\\|?*]')
last_results = list()

# os.pathconf is not available on Windows
Expand All @@ -134,7 +135,7 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
except KeyboardInterrupt:
done = True
continue

# skip empty lines
if not command.strip():
continue
Expand Down Expand Up @@ -183,15 +184,17 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
if len(opt.prompt) == 0:
print('Try again with a prompt!')
continue
if opt.init_img is not None and re.match('^-\\d+$',opt.init_img): # retrieve previous value!
# retrieve previous value!
if opt.init_img is not None and re.match('^-\\d+$', opt.init_img):
try:
opt.init_img = last_results[int(opt.init_img)][0]
print(f'>> Reusing previous image {opt.init_img}')
except IndexError:
print(f'>> No previous initial image at position {opt.init_img} found')
print(
f'>> No previous initial image at position {opt.init_img} found')
opt.init_img = None
continue

if opt.seed is not None and opt.seed < 0: # retrieve previous value!
try:
opt.seed = last_results[opt.seed][1]
Expand All @@ -201,12 +204,12 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
opt.seed = None
continue

do_grid = opt.grid or t2i.grid
do_grid = opt.grid or t2i.grid

if opt.with_variations is not None:
# shotgun parsing, woo
parts = []
broken = False # python doesn't have labeled loops...
broken = False # python doesn't have labeled loops...
for part in opt.with_variations.split(','):
seed_and_weight = part.split(':')
if len(seed_and_weight) != 2:
Expand Down Expand Up @@ -241,7 +244,7 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
subdir = subdir[:(path_max - 27 - len(os.path.abspath(outdir)))]
current_outdir = os.path.join(outdir, subdir)

print ('Writing files to directory: "' + current_outdir + '"')
print('Writing files to directory: "' + current_outdir + '"')

# make sure the output directory exists
if not os.path.exists(current_outdir):
Expand All @@ -253,9 +256,10 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
last_results = []
try:
file_writer = PngWriter(current_outdir)
prefix = file_writer.unique_prefix()
results = [] # list of filename, prompt pairs
grid_images = dict() # seed -> Image, only used if `do_grid`
prefix = file_writer.unique_prefix()
results = [] # list of filename, prompt pairs
grid_images = dict() # seed -> Image, only used if `do_grid`

def image_writer(image, seed, upscaled=False):
if do_grid:
grid_images[seed] = image
Expand All @@ -265,35 +269,41 @@ def image_writer(image, seed, upscaled=False):
else:
filename = f'{prefix}.{seed}.png'
if opt.variation_amount > 0:
iter_opt = argparse.Namespace(**vars(opt)) # copy
iter_opt = argparse.Namespace(**vars(opt)) # copy
this_variation = [[seed, opt.variation_amount]]
if opt.with_variations is None:
iter_opt.with_variations = this_variation
else:
iter_opt.with_variations = opt.with_variations + this_variation
iter_opt.variation_amount = 0
normalized_prompt = PromptFormatter(t2i, iter_opt).normalize_prompt()
normalized_prompt = PromptFormatter(
t2i, iter_opt).normalize_prompt()
metadata_prompt = f'{normalized_prompt} -S{iter_opt.seed}'
elif opt.with_variations is not None:
normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt()
metadata_prompt = f'{normalized_prompt} -S{opt.seed}' # use the original seed - the per-iteration value is the last variation-seed
normalized_prompt = PromptFormatter(
t2i, opt).normalize_prompt()
# use the original seed - the per-iteration value is the last variation-seed
metadata_prompt = f'{normalized_prompt} -S{opt.seed}'
else:
normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt()
normalized_prompt = PromptFormatter(
t2i, opt).normalize_prompt()
metadata_prompt = f'{normalized_prompt} -S{seed}'
path = file_writer.save_image_and_prompt_to_png(image, metadata_prompt, filename)
path = file_writer.save_image_and_prompt_to_png(
image, metadata_prompt, filename)
if (not upscaled) or opt.save_original:
# only append to results if we didn't overwrite an earlier output
results.append([path, metadata_prompt])
last_results.append([path,seed])
last_results.append([path, seed])

t2i.prompt2image(image_callback=image_writer, **vars(opt))

if do_grid and len(grid_images) > 0:
grid_img = make_grid(list(grid_images.values()))
grid_img = make_grid(list(grid_images.values()))
first_seed = last_results[0][1]
filename = f'{prefix}.{first_seed}.png'
filename = f'{prefix}.{first_seed}.png'
# TODO better metadata for grid images
normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt()
normalized_prompt = PromptFormatter(
t2i, opt).normalize_prompt()
metadata_prompt = f'{normalized_prompt} -S{first_seed} --grid -N{len(grid_images)}'
path = file_writer.save_image_and_prompt_to_png(
grid_img, metadata_prompt, filename
Expand All @@ -308,18 +318,16 @@ def image_writer(image, seed, upscaled=False):
print(e)
continue

print('\033[1mOutputs:\033[0m')
print('Outputs:')
log_path = os.path.join(current_outdir, 'dream_log.txt')
write_log_message(results, log_path)

print('goodbye!\033[0m')
print('goodbye!')


def get_next_command(infile=None) -> str: #command string
def get_next_command(infile=None) -> str: # command string
if infile is None:
print('\033[1m') # add some boldface
command = input('dream> ')
print('\033[0m',end='')
command = input('dream> ')
else:
command = infile.readline()
if not command:
Expand All @@ -329,6 +337,7 @@ def get_next_command(infile=None) -> str: #command string
print(f'#{command}')
return command


def dream_server_loop(t2i, host, port, outdir):
print('\n* --web was specified, starting web server...')
# Change working directory to the stable-diffusion directory
Expand All @@ -342,7 +351,8 @@ def dream_server_loop(t2i, host, port, outdir):
dream_server = ThreadingDreamServer((host, port))
print(">> Started Stable Diffusion dream server!")
if host == '0.0.0.0':
print(f"Point your browser at http://localhost:{port} or use the host's DNS name or IP address.")
print(
f"Point your browser at http://localhost:{port} or use the host's DNS name or IP address.")
else:
print(">> Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address.")
print(f">> Point your browser at http://{host}:{port}.")
Expand All @@ -361,13 +371,13 @@ def write_log_message(results, log_path):
log_lines = [f'{path}: {prompt}\n' for path, prompt in results]
for l in log_lines:
output_cntr += 1
print(f'\033[1m[{output_cntr}]\033[0m {l}',end='')
print(output_cntr)

with open(log_path, 'a', encoding='utf-8') as file:
file.writelines(log_lines)


SAMPLER_CHOICES=[
SAMPLER_CHOICES = [
'ddim',
'k_dpm_2_a',
'k_dpm_2',
Expand All @@ -378,6 +388,7 @@ def write_log_message(results, log_path):
'plms',
]


def create_argv_parser():
parser = argparse.ArgumentParser(
description="""Generate images using Stable Diffusion.
Expand Down Expand Up @@ -518,8 +529,8 @@ def create_argv_parser():
)
parser.add_argument(
'--config',
default ='configs/models.yaml',
help ='Path to configuration file for alternate models.',
default='configs/models.yaml',
help='Path to configuration file for alternate models.',
)
return parser

Expand Down