From 135c62f1a4edd332abc19491c78c89ee4949b9cf Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Thu, 27 Oct 2022 10:14:44 +0200 Subject: [PATCH 1/2] fix issue with hot-dog, improve () suppression --- ldm/invoke/prompt_parser.py | 74 ++++++++++++++++++++----------------- tests/test_prompt_parser.py | 9 +++-- 2 files changed, 46 insertions(+), 37 deletions(-) diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py index 6709f48066b..7613c9d93ec 100644 --- a/ldm/invoke/prompt_parser.py +++ b/ldm/invoke/prompt_parser.py @@ -358,10 +358,11 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float) quoted_fragment = pp.Forward() parenthesized_fragment = pp.Forward() cross_attention_substitute = pp.Forward() - prompt_part = pp.Forward() def make_text_fragment(x): #print("### making fragment for", x) + if type(x[0]) is Fragment: + assert(False) if type(x) is str: return Fragment(x) elif type(x) is pp.ParseResults or type(x) is list: @@ -396,8 +397,10 @@ def build_escaped_word_parser_charbychar(escaped_chars_to_ignore: str): def parse_fragment_str(x, in_quotes: bool=False, in_parens: bool=False): - #print(f"parsing fragment string \"{x}\"") + #print(f"parsing fragment string for {x}") fragment_string = x[0] + #print(f"ppparsing fragment string \"{fragment_string}\"") + if len(fragment_string.strip()) == 0: return Fragment('') @@ -406,13 +409,16 @@ def parse_fragment_str(x, in_quotes: bool=False, in_parens: bool=False): fragment_string = fragment_string.replace('"', '\\"') #fragment_parser = pp.Group(pp.OneOrMore(attention | cross_attention_substitute | (greedy_word.set_parse_action(make_text_fragment)))) - result = pp.Group(pp.MatchFirst([ - pp.OneOrMore(prompt_part | quoted_fragment), - pp.Empty().set_parse_action(make_text_fragment) + pp.StringEnd() - ])).set_name('rr').set_debug(False).parse_string(fragment_string) - #result = (pp.OneOrMore(attention | unquoted_word) + pp.StringEnd()).parse_string(x[0]) - #print("parsed to", result) - return result + try: + result = pp.Group(pp.MatchFirst([ + pp.OneOrMore(quoted_fragment | attention | unquoted_word).set_name('pf_str_qfuq'), + pp.Empty().set_parse_action(make_text_fragment) + pp.StringEnd() + ])).set_name('blend-result').set_debug(False).parse_string(fragment_string) + #print("parsed to", result) + return result + except pp.ParseException as e: + print("parse_fragment_str couldn't parse prompt string:", e) + raise quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"') quoted_fragment.set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_name('quoted_fragment') @@ -422,14 +428,21 @@ def parse_fragment_str(x, in_quotes: bool=False, in_parens: bool=False): escaped_rparen = pp.Literal('\\)')#.set_parse_action(lambda x: ')') escaped_backslash = pp.Literal('\\\\')#.set_parse_action(lambda x: '"') + empty = ( + (lparen + pp.ZeroOrMore(pp.Word(string.whitespace)) + rparen) | + (quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty') + + def not_ends_with_swap(x): #print("trying to match:", x) return not x[0].endswith('.swap') - unquoted_fragment = pp.Combine(pp.OneOrMore( + unquoted_word = pp.Combine(pp.OneOrMore( escaped_rparen | escaped_lparen | escaped_quote | escaped_backslash | - pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()'))) - unquoted_fragment.set_parse_action(make_text_fragment).set_name('unquoted_fragment').set_debug(False) + (pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()') + (pp.NotAny(pp.Word('+') | pp.Word('-')))) + )) + + unquoted_word.set_parse_action(make_text_fragment).set_name('unquoted_word').set_debug(False) #print(unquoted_fragment.parse_string("cat.swap(dog)")) parenthesized_fragment << pp.Or([ @@ -510,15 +523,16 @@ def make_attention(x): cross_attention_option_keyword = pp.Or([pp.Keyword("s_start"), pp.Keyword("s_end"), pp.Keyword("t_start"), pp.Keyword("t_end"), pp.Keyword("shape_freedom")]) cross_attention_option = pp.Group(cross_attention_option_keyword + pp.Literal("=").suppress() + number) edited_fragment = pp.MatchFirst([ + (lparen + rparen).set_parse_action(lambda x: Fragment('')), lparen + (quoted_fragment | - pp.Group(pp.OneOrMore(pp.Word(pp.printables, exclude_chars=string.whitespace + ',').set_parse_action(make_text_fragment))) + pp.Group(pp.ZeroOrMore(pp.Word(pp.printables, exclude_chars=string.whitespace + ',').set_parse_action(make_text_fragment))) ) + - pp.Dict(pp.OneOrMore(comma + cross_attention_option)) + + pp.Dict(pp.ZeroOrMore(comma + cross_attention_option)) + rparen, parenthesized_fragment ]) - cross_attention_substitute << original_fragment + pp.Literal(".swap").suppress() + edited_fragment + cross_attention_substitute << original_fragment + pp.Literal(".swap").set_debug(False).suppress() + edited_fragment original_fragment.set_name('original_fragment').set_debug(debug_cross_attention_control) edited_fragment.set_name('edited_fragment').set_debug(debug_cross_attention_control) @@ -533,24 +547,18 @@ def make_cross_attention_substitute(x): cross_attention_substitute.set_parse_action(make_cross_attention_substitute) - # simple fragments of text - # use Or to match the longest - prompt_part << pp.MatchFirst([ - cross_attention_substitute, - attention, - unquoted_fragment, - lparen + unquoted_fragment + rparen # matches case where user has +(term) and just deletes the + - ]) - prompt_part.set_debug(False) - prompt_part.set_name("prompt_part") - - empty = ( - (lparen + pp.ZeroOrMore(pp.Word(string.whitespace)) + rparen) | - (quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty') - # root prompt definition - prompt = (pp.OneOrMore(pp.Or([prompt_part, quoted_fragment, empty])) + pp.StringEnd()) \ - .set_parse_action(lambda x: Prompt(x)) + debug_root_prompt = False + prompt = (pp.OneOrMore(pp.Or([cross_attention_substitute.set_debug(debug_root_prompt), + attention.set_debug(debug_root_prompt), + quoted_fragment.set_debug(debug_root_prompt), + (lparen + (pp.ZeroOrMore(unquoted_word | pp.White().suppress()).leave_whitespace()) + rparen).set_name('parenthesized-uqw').set_debug(debug_root_prompt), + unquoted_word.set_debug(debug_root_prompt), + empty.set_parse_action(make_text_fragment).set_debug(debug_root_prompt)]) + ) + pp.StringEnd()) \ + .set_name('prompt') \ + .set_parse_action(lambda x: Prompt(x)) \ + .set_debug(debug_root_prompt) #print("parsing test:", prompt.parse_string("spaced eyes--")) #print("parsing test:", prompt.parse_string("eyes--")) @@ -567,7 +575,7 @@ def make_prompt_from_quoted_string(x): if len(x_unquoted.strip()) == 0: # print(' b : just an empty string') return Prompt([Fragment('')]) - # print(' b parsing ', c_unquoted) + #print(f' b parsing \'{x_unquoted}\'') x_parsed = prompt.parse_string(x_unquoted) #print(" quoted prompt was parsed to", type(x_parsed),":", x_parsed) return x_parsed[0] diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py index 486265d2f50..4fd7616adef 100644 --- a/tests/test_prompt_parser.py +++ b/tests/test_prompt_parser.py @@ -32,6 +32,7 @@ def test_basic(self): self.assertEqual(make_weighted_conjunction([("fire flames", 1)]), parse_prompt("fire flames")) self.assertEqual(make_weighted_conjunction([("fire, flames", 1)]), parse_prompt("fire, flames")) self.assertEqual(make_weighted_conjunction([("fire, flames , fire", 1)]), parse_prompt("fire, flames , fire")) + self.assertEqual(make_weighted_conjunction([("cat hot-dog eating", 1)]), parse_prompt("cat hot-dog eating")) def test_attention(self): self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames)0.5")) @@ -106,10 +107,7 @@ def assert_if_prompt_string_not_untouched(prompt): #with self.assertRaises(pyparsing.ParseException): with self.assertRaises(pyparsing.ParseException): parse_prompt('a badly (formed +test prompt') - with self.assertRaises(pyparsing.ParseException): - parse_prompt('a badly (formed +test )prompt') - with self.assertRaises(pyparsing.ParseException): - parse_prompt('a badly (formed +test )prompt') + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a badly formed +test prompt',1)])]) , parse_prompt('a badly (formed +test )prompt')) with self.assertRaises(pyparsing.ParseException): parse_prompt('(((a badly (formed +test )prompt') with self.assertRaises(pyparsing.ParseException): @@ -394,6 +392,9 @@ def test_single(self): # todo handle this #self.assertEqual(make_basic_conjunction(['a badly formed +test prompt']), # parse_prompt('a badly formed +test prompt')) + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), + CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]), + parse_prompt('a forest landscape "in winter".swap()')) pass From 16e7cbdb3829dfc1cb52dde61ff1a32cef8ec77a Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 27 Oct 2022 08:30:09 -0400 Subject: [PATCH 2/2] tweaks to documentation and call signature for advanced prompting --- docs/features/PROMPTS.md | 16 ++++++++++++---- ldm/invoke/conditioning.py | 2 +- .../diffusion/shared_invokeai_diffusion.py | 4 ++-- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/docs/features/PROMPTS.md b/docs/features/PROMPTS.md index 1a3dcb5e9d1..fd8148f622a 100644 --- a/docs/features/PROMPTS.md +++ b/docs/features/PROMPTS.md @@ -110,10 +110,15 @@ See the section below on "Prompt Blending" for more information about how this w ### Cross-Attention Control ('prompt2prompt') -Generate an image with a given prompt and then paint over the image -using the `prompt2prompt` syntax to substitute words in the original -prompt for words in a new prompt. Based off [bloc97's -colab](https://github.com/bloc97/CrossAttentionControl). +Sometimes an image you generate is almost right, and you just want to +change one detail without affecting the rest. You could use a photo editor and inpainting +to overpaint the area, but that's a pain. Here's where `prompt2prompt` +comes in handy. + +Generate an image with a given prompt, record the seed of the image, +and then use the `prompt2prompt` syntax to substitute words in the +original prompt for words in a new prompt. This works for `img2img` as well. + * `a ("fluffy cat").swap("smiling dog") eating a hotdog`. * quotes optional: `a (fluffy cat).swap(smiling dog) eating a hotdog`. @@ -125,6 +130,9 @@ colab](https://github.com/bloc97/CrossAttentionControl). * Convenience option `shape_freedom` (0-1) to specify how much "freedom" Stable Diffusion should have to change the shape of the subject being swapped. * `a (cat).swap(dog, shape_freedom=0.5) eating a hotdog`. +The `prompt2prompt` code is based off [bloc97's +colab](https://github.com/bloc97/CrossAttentionControl). + Note that `prompt2prompt` is not currently working with the runwayML inpainting model, and may never work due to the way this model is set up. If you attempt to use `prompt2prompt` you will get the original diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py index 7365bc9a872..33be8c342e4 100644 --- a/ldm/invoke/conditioning.py +++ b/ldm/invoke/conditioning.py @@ -114,7 +114,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n conditioning = original_embeddings edited_conditioning = edited_embeddings - print('got edit_opcodes', edit_opcodes, 'options', edit_options) + print('>> got edit_opcodes', edit_opcodes, 'options', edit_options) cac_args = CrossAttentionControl.Arguments( edited_conditioning = edited_conditioning, edit_opcodes = edit_opcodes, diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py index e985417b2bf..14f08578c33 100644 --- a/ldm/models/diffusion/shared_invokeai_diffusion.py +++ b/ldm/models/diffusion/shared_invokeai_diffusion.py @@ -1,5 +1,5 @@ from math import ceil -from typing import Callable, Optional +from typing import Callable, Optional, Union import torch @@ -54,7 +54,7 @@ def remove_cross_attention_control(self): CrossAttentionControl.remove_cross_attention_control(self.model) def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor, - unconditioning: torch.Tensor, conditioning: torch.Tensor, + unconditioning: Union[torch.Tensor,dict], conditioning: Union[torch.Tensor,dict], unconditional_guidance_scale: float, step_index: int=None ):