Add text prompt to inpaint mask support #1133

Merged
merged 4 commits into from
Oct 18, 2022
Binary file added docs/assets/still-life-inpainted.png
Binary file added docs/assets/still-life-scaled.jpg
28 changes: 27 additions & 1 deletion docs/features/CLI.md
Expand Up @@ -85,6 +85,7 @@ overridden on a per-prompt basis (see [List of prompt arguments](#list-of-prompt
| `--from_file <path>` | | `None` | Read list of prompts from a file. Use `-` to read from standard input |
| `--model <modelname>` | | `stable-diffusion-1.4` | Loads model specified in configs/models.yaml. Currently one of "stable-diffusion-1.4" or "laion400m" |
| `--full_precision` | `-F` | `False` | Run in slower full-precision mode. Needed for Macintosh M1/M2 hardware and some older video cards. |
| `--png_compression <0-9>` | `-z<0-9>` | 6 | Select level of compression for output files, from 0 (no compression) to 9 (max compression) |
| `--web` | | `False` | Start in web server mode |
| `--host <ip addr>` | | `localhost` | Which network interface web server should listen on. Set to 0.0.0.0 to listen on any. |
| `--port <port>` | | `9090` | Which port web server should listen for requests on. |
Expand Down Expand Up @@ -153,6 +154,7 @@ Here are the invoke> command that apply to txt2img:
| --seed <int> | -S<int> | None | Set the random seed for the next series of images. This can be used to recreate an image generated previously.|
| --sampler <sampler>| -A<sampler>| k_lms | Sampler to use. Use -h to get list of available samplers. |
| --hires_fix | | | Larger images often have duplication artefacts. This option suppresses duplicates by generating the image at low res, and then using img2img to increase the resolution |
| --png_compression <0-9> | -z<0-9> | 6 | Select level of compression for output files, from 0 (no compression) to 9 (max compression) |
| --grid | -g | False | Turn on grid mode to return a single image combining all the images generated by this prompt |
| --individual | -i | True | Turn off grid mode (deprecated; leave off --grid instead) |
| --outdir <path> | -o<path> | outputs/img_samples | Temporarily change the location of these images |
Expand Down Expand Up @@ -210,11 +212,35 @@ accepts additional options:
[Inpainting](./INPAINTING.md) for details.

inpainting accepts all the arguments used for txt2img and img2img, as
well as the --mask (-M) argument:
well as the --mask (-M) and --text_mask (-tm) arguments:

| Argument <img width="100" align="right"/> | Shortcut | Default | Description |
|--------------------|------------|---------------------|--------------|
| `--init_mask <path>` | `-M<path>` | `None` | Path to an image the same size as the initial_image, with areas for inpainting made transparent. |
| `--text_mask <prompt> [<float>]` | `-tm <prompt> [<float>]` | `None` | Create a mask from a text prompt describing part of the image. The optional float is the confidence threshold (defaults to 0.5). |

`--text_mask` (short form `-tm`) is a way to generate a mask using a
text description of the part of the image to replace. For example, if
you have an image of a breakfast plate with a bagel, toast and
scrambled eggs, you can selectively mask the bagel and replace it with
a piece of cake this way:

~~~
invoke> a piece of cake -I /path/to/breakfast.png -tm bagel
~~~

The algorithm uses <a
href="https://github.com/timojl/clipseg">clipseg</a> to classify
different regions of the image. The classifier produces a confidence
score for each region it identifies. Regions that score above 0.5 are
generally reliable, but if you are getting too much or too little
masking you can lower the threshold (to mask more) or raise it (to
mask less). In the example below, passing `-tm` a higher value of 0.6
demands a more stringent classification.

~~~
invoke> a piece of cake -I /path/to/breakfast.png -tm bagel 0.6
~~~

# Other Commands

Expand Down
41 changes: 40 additions & 1 deletion docs/features/INPAINTING.md
Expand Up @@ -34,7 +34,46 @@ original unedited image and the masked (partially transparent) image:
invoke> "man with cat on shoulder" -I./images/man.png -M./images/man-transparent.png
```

We are hoping to get rid of the need for this workaround in an upcoming release.
## **Masking using Text**

You can also create a mask using a text prompt to select the part of
the image you want to alter, using the <a
href="https://github.com/timojl/clipseg">clipseg</a> algorithm. This
works on any image, not just ones generated by InvokeAI.

The `--text_mask` (short form `-tm`) option takes up to two arguments. The
first argument is a text description of the part of the image you wish
to mask (paint over). If the text description contains a space, you must
surround it with quotation marks. The optional second argument is the
minimum threshold for the mask classifier's confidence score, described
in more detail below.

To see how this works in practice, here's an image of a still life
painting that I got off the web.

<img src="../assets/still-life-scaled.jpg">

You can selectively mask out the
orange and replace it with a baseball in this way:

~~~
invoke> a baseball -I /path/to/still_life.png -tm orange
~~~

<img src="../assets/still-life-inpainted.png">

The clipseg classifier produces a confidence score for each region it
identifies. Regions that score above 0.5 are generally reliable, but
if you are getting too much or too little masking you can lower the
threshold (to mask more) or raise it (to mask less). In the example
below, passing `-tm` a higher value of 0.6 insists on a tighter
mask. However, if you make it too high, the orange may not be picked
up at all!

~~~
invoke> a baseball -I /path/to/still_life.png -tm orange 0.6
~~~
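
To make the effect of the threshold concrete, here is a minimal sketch of turning a confidence map into a mask. It is an illustration only, not the clipseg or InvokeAI code: the 3x3 `scores` grid is made up, `confidence_to_mask` and `count_masked` are hypothetical helper names, and treating a score equal to the threshold as masked is an assumption.

```python
from typing import List


def confidence_to_mask(scores: List[List[float]], threshold: float = 0.5) -> List[List[bool]]:
    """Mark a pixel as masked when its confidence meets the threshold (assumed rule)."""
    return [[score >= threshold for score in row] for row in scores]


def count_masked(mask: List[List[bool]]) -> int:
    """Count how many pixels ended up inside the mask."""
    return sum(cell for row in mask for cell in row)


# Hypothetical confidence map for the prompt "orange"; real maps are image-sized.
scores = [
    [0.10, 0.55, 0.20],
    [0.58, 0.90, 0.65],
    [0.05, 0.52, 0.15],
]

loose = confidence_to_mask(scores, 0.5)      # default threshold
tight = confidence_to_mask(scores, 0.6)      # stricter: fewer pixels qualify
too_high = confidence_to_mask(scores, 0.95)  # nothing qualifies

print(count_masked(loose), count_masked(tight), count_masked(too_high))
```

Raising the threshold can only shrink the masked area, which is why a value that is too high leaves nothing to inpaint.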


### Inpainting is not changing the masked region enough!

Expand Down
1 change: 1 addition & 0 deletions environment-mac.yml
Expand Up @@ -57,6 +57,7 @@ dependencies:
- -e git+https://github.com/openai/CLIP.git@main#egg=clip
- -e git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k_diffusion
- -e git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan
- -e git+https://github.com/invoke-ai/clipseg.git#egg=clipseg
- -e .
variables:
PYTORCH_ENABLE_MPS_FALLBACK: 1
1 change: 1 addition & 0 deletions environment.yml
Expand Up @@ -37,4 +37,5 @@ dependencies:
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
- -e git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k_diffusion
- -e git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan
- -e git+https://github.com/invoke-ai/clipseg.git#egg=clipseg
- -e .
46 changes: 37 additions & 9 deletions ldm/generate.py
Expand Up @@ -34,7 +34,8 @@
from ldm.invoke.devices import choose_torch_device, choose_precision
from ldm.invoke.conditioning import get_uc_and_c
from ldm.invoke.model_cache import ModelCache

from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale

def fix_func(orig):
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
def new_func(*args, **kw):
Expand Down Expand Up @@ -188,6 +189,7 @@ def __init__(
self.esrgan = esrgan
self.free_gpu_mem = free_gpu_mem
self.size_matters = True # used to warn once about large image sizes and VRAM
self.txt2mask = None

# Note that in previous versions, there was an option to pass the
# device to Generate(). However the device was then ignored, so
Expand Down Expand Up @@ -266,6 +268,7 @@ def prompt2image(
# these are specific to img2img and inpaint
init_img = None,
init_mask = None,
text_mask = None,
fit = False,
strength = None,
init_color = None,
Expand Down Expand Up @@ -298,6 +301,8 @@ def prompt2image(
seamless // whether the generated image should tile
hires_fix // whether the Hires Fix should be applied during generation
init_img // path to an initial image
init_mask // path to a mask for the initial image
text_mask // a text string that will be used to guide clipseg generation of the init_mask
strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
facetool_strength // strength for GFPGAN/CodeFormer. 0.0 preserves image exactly, 1.0 replaces it completely
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
Expand Down Expand Up @@ -405,6 +410,7 @@ def process_image(image,seed):
width,
height,
fit=fit,
text_mask=text_mask,
)

# TODO: Hacky selection of operation to perform. Needs to be refactored.
Expand Down Expand Up @@ -620,17 +626,14 @@ def _make_images(
width,
height,
fit=False,
text_mask=None,
):
init_image = None
init_mask = None
if not img:
return None, None

image = self._load_img(
img,
width,
height,
)
image = self._load_img(img)

if image.width < self.width and image.height < self.height:
print(f'>> WARNING: img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions')
Expand All @@ -648,10 +651,12 @@ def _make_images(
init_image = self._create_init_image(image,width,height,fit=fit) # this returns a torch tensor

if mask:
mask_image = self._load_img(
mask, width, height) # this returns an Image
mask_image = self._load_img(mask) # this returns an Image
init_mask = self._create_init_mask(mask_image,width,height,fit=fit)

elif text_mask:
init_mask = self._txt2mask(image, text_mask, width, height, fit=fit)

return init_image, init_mask

def _make_base(self):
Expand Down Expand Up @@ -830,7 +835,7 @@ def _set_sampler(self):

print(msg)

def _load_img(self, img, width, height)->Image:
def _load_img(self, img)->Image:
if isinstance(img, Image.Image):
image = img
print(
Expand Down Expand Up @@ -892,6 +897,29 @@ def _image_to_mask(self, mask_image, invert=False) -> Image:
mask = ImageOps.invert(mask)
return mask

# TODO: The latter part of this method repeats code from _create_init_mask()
def _txt2mask(self, image:Image, text_mask:list, width, height, fit=True) -> Image:
prompt = text_mask[0]
confidence_level = text_mask[1] if len(text_mask)>1 else 0.5
if self.txt2mask is None:
self.txt2mask = Txt2Mask(device = self.device)

segmented = self.txt2mask.segment(image, prompt)
mask = segmented.to_mask(float(confidence_level))
mask = mask.convert('RGB')
# now we adjust the size
if fit:
mask = self._fit_image(mask, (width, height))
else:
mask = self._squeeze_image(mask)
mask = mask.resize((mask.width//downsampling, mask.height //
downsampling), resample=Image.Resampling.NEAREST)
mask = np.array(mask)
mask = mask.astype(np.float32) / 255.0
mask = mask[None].transpose(0, 3, 1, 2)
mask = torch.from_numpy(mask)
return mask.to(self.device)

def _has_transparency(self, image):
if image.info.get("transparency", None) is not None:
return True
Expand Down
24 changes: 24 additions & 0 deletions ldm/invoke/args.py
Expand Up @@ -379,6 +379,14 @@ def _create_arg_parser(self):
default='stable-diffusion-1.4',
help='Indicates which diffusion model to load. (currently "stable-diffusion-1.4" (default) or "laion400m")',
)
model_group.add_argument(
'--png_compression','-z',
type=int,
default=6,
choices=range(0,10),
dest='png_compression',
help='level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.'
)
model_group.add_argument(
'--sampler',
'-A',
Expand Down Expand Up @@ -650,6 +658,14 @@ def _create_dream_cmd_parser(self):
dest='save_intermediates',
help='Save every nth intermediate image into an "intermediates" directory within the output directory'
)
render_group.add_argument(
'--png_compression','-z',
type=int,
default=6,
choices=range(0,10),
dest='png_compression',
help='level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.'
)
img2img_group.add_argument(
'-I',
'--init_img',
Expand All @@ -662,6 +678,14 @@ def _create_dream_cmd_parser(self):
type=str,
help='Path to input mask for inpainting mode (supersedes width and height)',
)
img2img_group.add_argument(
'-tm',
'--text_mask',
nargs='+',
type=str,
help='Use the clipseg classifier to generate the mask area for inpainting. Provide a description of the area to mask ("a mug"), optionally followed by the confidence level threshold (0-1.0; defaults to 0.5).',
default=None,
)
img2img_group.add_argument(
'--init_color',
type=str,
Expand Down
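
The `nargs='+'` declaration in the `--text_mask` argument means `-tm` collects one or more strings into a list, and `_txt2mask` in `ldm/generate.py` treats the first element as the prompt and the optional second element as the threshold. A minimal standalone sketch of that flow, assuming a throwaway parser (`build_parser` and `split_text_mask` are hypothetical helper names, not part of InvokeAI):

```python
import argparse

DEFAULT_THRESHOLD = 0.5  # same fallback value used by _txt2mask in this PR


def build_parser() -> argparse.ArgumentParser:
    """Throwaway parser that declares only the --text_mask option."""
    parser = argparse.ArgumentParser(prog='invoke-sketch')
    parser.add_argument('-tm', '--text_mask', nargs='+', type=str, default=None)
    return parser


def split_text_mask(text_mask):
    """Split the parsed list into (prompt, threshold)."""
    prompt = text_mask[0]
    threshold = float(text_mask[1]) if len(text_mask) > 1 else DEFAULT_THRESHOLD
    return prompt, threshold


parser = build_parser()
opts = parser.parse_args(['-tm', 'bagel', '0.6'])
print(split_text_mask(opts.text_mask))  # ('bagel', 0.6)
```

Because every word after `-tm` becomes its own list element, an unquoted multi-word description such as `-tm fruit bowl` would put `bowl` in the threshold slot and fail the `float()` conversion. That is why descriptions containing spaces need quotation marks.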
1 change: 1 addition & 0 deletions ldm/invoke/generator/txt2img.py
Expand Up @@ -74,3 +74,4 @@ def get_noise(self,width,height):
if self.perlin > 0.0:
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
return x

4 changes: 2 additions & 2 deletions ldm/invoke/pngwriter.py
Expand Up @@ -33,13 +33,13 @@ def unique_prefix(self):

# saves image named _image_ to outdir/name, writing metadata from prompt
# returns full path of output
def save_image_and_prompt_to_png(self, image, dream_prompt, name, metadata=None):
def save_image_and_prompt_to_png(self, image, dream_prompt, name, metadata=None, compress_level=6):
path = os.path.join(self.outdir, name)
info = PngImagePlugin.PngInfo()
info.add_text('Dream', dream_prompt)
if metadata:
info.add_text('sd-metadata', json.dumps(metadata))
image.save(path, 'PNG', pnginfo=info)
image.save(path, 'PNG', pnginfo=info, compress_level=compress_level)
return path

def retrieve_metadata(self,img_basename):
Expand Down
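
PNG stores its pixel data as a zlib (deflate) stream, and the `compress_level` passed to `image.save()` selects the zlib effort level. The trade-off behind `--png_compression` can be sketched with the standard library alone. This is an illustration rather than Pillow's code path, and `deflate_sizes` and the synthetic `pixels` buffer are hypothetical:

```python
import zlib


def deflate_sizes(data: bytes, levels=range(0, 10)) -> dict:
    """Return the compressed size in bytes for each zlib level, 0 through 9."""
    return {level: len(zlib.compress(data, level)) for level in levels}


# Synthetic 64x64 RGB buffer filled with one colour; real images compress less well.
pixels = bytes([200, 180, 90]) * 64 * 64

sizes = deflate_sizes(pixels)
for level, size in sizes.items():
    print(f'level {level}: {size} bytes')
```

Level 0 stores the data uncompressed (slightly larger than the input because of framing), while higher levels spend more CPU time for smaller output. The image content is identical at every level, since PNG compression is lossless.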
2 changes: 2 additions & 0 deletions ldm/invoke/readline.py
Expand Up @@ -53,6 +53,8 @@
'--log_tokenization','-t',
'--hires_fix',
'--inpaint_replace','-r',
'--png_compression','-z',
'--text_mask','-tm',
'!fix','!fetch','!history','!search','!clear',
'!models','!switch','!import_model','!edit_model'
)
Expand Down