Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions docs/features/PROMPTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,15 @@ See the section below on "Prompt Blending" for more information about how this w

### Cross-Attention Control ('prompt2prompt')

Generate an image with a given prompt and then paint over the image
using the `prompt2prompt` syntax to substitute words in the original
prompt for words in a new prompt. Based off [bloc97's
colab](https://github.com/bloc97/CrossAttentionControl).
Sometimes an image you generate is almost right, and you just want to
change one detail without affecting the rest. You could use a photo editor and inpainting
to overpaint the area, but that's a pain. Here's where `prompt2prompt`
comes in handy.

Generate an image with a given prompt, record the seed of the image,
and then use the `prompt2prompt` syntax to substitute words in the
original prompt for words in a new prompt. This works for `img2img` as well.


* `a ("fluffy cat").swap("smiling dog") eating a hotdog`.
* quotes optional: `a (fluffy cat).swap(smiling dog) eating a hotdog`.
Expand All @@ -125,6 +130,9 @@ colab](https://github.com/bloc97/CrossAttentionControl).
* Convenience option `shape_freedom` (0-1) to specify how much "freedom" Stable Diffusion should have to change the shape of the subject being swapped.
* `a (cat).swap(dog, shape_freedom=0.5) eating a hotdog`.

The `prompt2prompt` code is based off [bloc97's
colab](https://github.com/bloc97/CrossAttentionControl).

Note that `prompt2prompt` is not currently working with the runwayML
inpainting model, and may never work due to the way this model is set
up. If you attempt to use `prompt2prompt` you will get the original
Expand Down
2 changes: 1 addition & 1 deletion ldm/invoke/conditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n

conditioning = original_embeddings
edited_conditioning = edited_embeddings
print('got edit_opcodes', edit_opcodes, 'options', edit_options)
print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
cac_args = CrossAttentionControl.Arguments(
edited_conditioning = edited_conditioning,
edit_opcodes = edit_opcodes,
Expand Down
74 changes: 41 additions & 33 deletions ldm/invoke/prompt_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,10 +358,11 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
quoted_fragment = pp.Forward()
parenthesized_fragment = pp.Forward()
cross_attention_substitute = pp.Forward()
prompt_part = pp.Forward()

def make_text_fragment(x):
#print("### making fragment for", x)
if type(x[0]) is Fragment:
assert(False)
if type(x) is str:
return Fragment(x)
elif type(x) is pp.ParseResults or type(x) is list:
Expand Down Expand Up @@ -396,8 +397,10 @@ def build_escaped_word_parser_charbychar(escaped_chars_to_ignore: str):


def parse_fragment_str(x, in_quotes: bool=False, in_parens: bool=False):
#print(f"parsing fragment string \"{x}\"")
#print(f"parsing fragment string for {x}")
fragment_string = x[0]
#print(f"ppparsing fragment string \"{fragment_string}\"")

if len(fragment_string.strip()) == 0:
return Fragment('')

Expand All @@ -406,13 +409,16 @@ def parse_fragment_str(x, in_quotes: bool=False, in_parens: bool=False):
fragment_string = fragment_string.replace('"', '\\"')

#fragment_parser = pp.Group(pp.OneOrMore(attention | cross_attention_substitute | (greedy_word.set_parse_action(make_text_fragment))))
result = pp.Group(pp.MatchFirst([
pp.OneOrMore(prompt_part | quoted_fragment),
pp.Empty().set_parse_action(make_text_fragment) + pp.StringEnd()
])).set_name('rr').set_debug(False).parse_string(fragment_string)
#result = (pp.OneOrMore(attention | unquoted_word) + pp.StringEnd()).parse_string(x[0])
#print("parsed to", result)
return result
try:
result = pp.Group(pp.MatchFirst([
pp.OneOrMore(quoted_fragment | attention | unquoted_word).set_name('pf_str_qfuq'),
pp.Empty().set_parse_action(make_text_fragment) + pp.StringEnd()
])).set_name('blend-result').set_debug(False).parse_string(fragment_string)
#print("parsed to", result)
return result
except pp.ParseException as e:
print("parse_fragment_str couldn't parse prompt string:", e)
raise

quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"')
quoted_fragment.set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_name('quoted_fragment')
Expand All @@ -422,14 +428,21 @@ def parse_fragment_str(x, in_quotes: bool=False, in_parens: bool=False):
escaped_rparen = pp.Literal('\\)')#.set_parse_action(lambda x: ')')
escaped_backslash = pp.Literal('\\\\')#.set_parse_action(lambda x: '"')

empty = (
(lparen + pp.ZeroOrMore(pp.Word(string.whitespace)) + rparen) |
(quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty')


def not_ends_with_swap(x):
#print("trying to match:", x)
return not x[0].endswith('.swap')

unquoted_fragment = pp.Combine(pp.OneOrMore(
unquoted_word = pp.Combine(pp.OneOrMore(
escaped_rparen | escaped_lparen | escaped_quote | escaped_backslash |
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()')))
unquoted_fragment.set_parse_action(make_text_fragment).set_name('unquoted_fragment').set_debug(False)
(pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()') + (pp.NotAny(pp.Word('+') | pp.Word('-'))))
))

unquoted_word.set_parse_action(make_text_fragment).set_name('unquoted_word').set_debug(False)
#print(unquoted_fragment.parse_string("cat.swap(dog)"))

parenthesized_fragment << pp.Or([
Expand Down Expand Up @@ -510,15 +523,16 @@ def make_attention(x):
cross_attention_option_keyword = pp.Or([pp.Keyword("s_start"), pp.Keyword("s_end"), pp.Keyword("t_start"), pp.Keyword("t_end"), pp.Keyword("shape_freedom")])
cross_attention_option = pp.Group(cross_attention_option_keyword + pp.Literal("=").suppress() + number)
edited_fragment = pp.MatchFirst([
(lparen + rparen).set_parse_action(lambda x: Fragment('')),
lparen +
(quoted_fragment |
pp.Group(pp.OneOrMore(pp.Word(pp.printables, exclude_chars=string.whitespace + ',').set_parse_action(make_text_fragment)))
pp.Group(pp.ZeroOrMore(pp.Word(pp.printables, exclude_chars=string.whitespace + ',').set_parse_action(make_text_fragment)))
) +
pp.Dict(pp.OneOrMore(comma + cross_attention_option)) +
pp.Dict(pp.ZeroOrMore(comma + cross_attention_option)) +
rparen,
parenthesized_fragment
])
cross_attention_substitute << original_fragment + pp.Literal(".swap").suppress() + edited_fragment
cross_attention_substitute << original_fragment + pp.Literal(".swap").set_debug(False).suppress() + edited_fragment

original_fragment.set_name('original_fragment').set_debug(debug_cross_attention_control)
edited_fragment.set_name('edited_fragment').set_debug(debug_cross_attention_control)
Expand All @@ -533,24 +547,18 @@ def make_cross_attention_substitute(x):
cross_attention_substitute.set_parse_action(make_cross_attention_substitute)


# simple fragments of text
# use Or to match the longest
prompt_part << pp.MatchFirst([
cross_attention_substitute,
attention,
unquoted_fragment,
lparen + unquoted_fragment + rparen # matches case where user has +(term) and just deletes the +
])
prompt_part.set_debug(False)
prompt_part.set_name("prompt_part")

empty = (
(lparen + pp.ZeroOrMore(pp.Word(string.whitespace)) + rparen) |
(quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty')

# root prompt definition
prompt = (pp.OneOrMore(pp.Or([prompt_part, quoted_fragment, empty])) + pp.StringEnd()) \
.set_parse_action(lambda x: Prompt(x))
debug_root_prompt = False
prompt = (pp.OneOrMore(pp.Or([cross_attention_substitute.set_debug(debug_root_prompt),
attention.set_debug(debug_root_prompt),
quoted_fragment.set_debug(debug_root_prompt),
(lparen + (pp.ZeroOrMore(unquoted_word | pp.White().suppress()).leave_whitespace()) + rparen).set_name('parenthesized-uqw').set_debug(debug_root_prompt),
unquoted_word.set_debug(debug_root_prompt),
empty.set_parse_action(make_text_fragment).set_debug(debug_root_prompt)])
) + pp.StringEnd()) \
.set_name('prompt') \
.set_parse_action(lambda x: Prompt(x)) \
.set_debug(debug_root_prompt)

#print("parsing test:", prompt.parse_string("spaced eyes--"))
#print("parsing test:", prompt.parse_string("eyes--"))
Expand All @@ -567,7 +575,7 @@ def make_prompt_from_quoted_string(x):
if len(x_unquoted.strip()) == 0:
# print(' b : just an empty string')
return Prompt([Fragment('')])
# print(' b parsing ', c_unquoted)
#print(f' b parsing \'{x_unquoted}\'')
x_parsed = prompt.parse_string(x_unquoted)
#print(" quoted prompt was parsed to", type(x_parsed),":", x_parsed)
return x_parsed[0]
Expand Down
4 changes: 2 additions & 2 deletions ldm/models/diffusion/shared_invokeai_diffusion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from math import ceil
from typing import Callable, Optional
from typing import Callable, Optional, Union

import torch

Expand Down Expand Up @@ -54,7 +54,7 @@ def remove_cross_attention_control(self):
CrossAttentionControl.remove_cross_attention_control(self.model)

def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
unconditioning: torch.Tensor, conditioning: torch.Tensor,
unconditioning: Union[torch.Tensor,dict], conditioning: Union[torch.Tensor,dict],
unconditional_guidance_scale: float,
step_index: int=None
):
Expand Down
9 changes: 5 additions & 4 deletions tests/test_prompt_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def test_basic(self):
self.assertEqual(make_weighted_conjunction([("fire flames", 1)]), parse_prompt("fire flames"))
self.assertEqual(make_weighted_conjunction([("fire, flames", 1)]), parse_prompt("fire, flames"))
self.assertEqual(make_weighted_conjunction([("fire, flames , fire", 1)]), parse_prompt("fire, flames , fire"))
self.assertEqual(make_weighted_conjunction([("cat hot-dog eating", 1)]), parse_prompt("cat hot-dog eating"))

def test_attention(self):
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames)0.5"))
Expand Down Expand Up @@ -106,10 +107,7 @@ def assert_if_prompt_string_not_untouched(prompt):
#with self.assertRaises(pyparsing.ParseException):
with self.assertRaises(pyparsing.ParseException):
parse_prompt('a badly (formed +test prompt')
with self.assertRaises(pyparsing.ParseException):
parse_prompt('a badly (formed +test )prompt')
with self.assertRaises(pyparsing.ParseException):
parse_prompt('a badly (formed +test )prompt')
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a badly formed +test prompt',1)])]) , parse_prompt('a badly (formed +test )prompt'))
with self.assertRaises(pyparsing.ParseException):
parse_prompt('(((a badly (formed +test )prompt')
with self.assertRaises(pyparsing.ParseException):
Expand Down Expand Up @@ -394,6 +392,9 @@ def test_single(self):
# todo handle this
#self.assertEqual(make_basic_conjunction(['a badly formed +test prompt']),
# parse_prompt('a badly formed +test prompt'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
parse_prompt('a forest landscape "in winter".swap()'))
pass


Expand Down