Skip to content

Commit

Permalink
Merge branch 'main' into add/ci/clean-caches
Browse files Browse the repository at this point in the history
  • Loading branch information
lstein committed Jan 31, 2023
2 parents 229514a + 053d11f commit 6ce75bd
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 20 deletions.
4 changes: 2 additions & 2 deletions ldm/invoke/CLI.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,8 +786,8 @@ def _get_model_name(existing_names,completer,default_name:str='')->str:
model_name = input(f'Short name for this model [{default_name}]: ').strip()
if len(model_name)==0:
model_name = default_name
if not re.match('^[\w._+-]+$',model_name):
print('** model name must contain only words, digits and the characters "._+-" **')
if not re.match('^[\w._+:/-]+$',model_name):
print('** model name must contain only words, digits and the characters "._+:/-" **')
elif model_name != default_name and model_name in existing_names:
print(f'** the name {model_name} is already in use. Pick another.')
else:
Expand Down
12 changes: 9 additions & 3 deletions ldm/invoke/generator/diffusers_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,7 @@ def inpaint_from_embeddings(
init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))

init_image = init_image.to(device=device, dtype=latents_dtype)
mask = mask.to(device=device, dtype=latents_dtype)

if init_image.dim() == 3:
init_image = init_image.unsqueeze(0)
Expand All @@ -562,17 +563,22 @@ def inpaint_from_embeddings(

if mask.dim() == 3:
mask = mask.unsqueeze(0)
mask = tv_resize(mask, init_image_latents.shape[-2:], T.InterpolationMode.BILINEAR) \
latent_mask = tv_resize(mask, init_image_latents.shape[-2:], T.InterpolationMode.BILINEAR) \
.to(device=device, dtype=latents_dtype)

guidance: List[Callable] = []

if is_inpainting_model(self.unet):
# You'd think the inpainting model wouldn't be paying attention to the area it is going to repaint
# (that's why there's a mask!) but it seems to really want that blanked out.
masked_init_image = init_image * torch.where(mask < 0.5, 1, 0)
masked_latents = self.non_noised_latents_from_image(masked_init_image, device=device, dtype=latents_dtype)

# TODO: we should probably pass this in so we don't have to try/finally around setting it.
self.invokeai_diffuser.model_forward_callback = \
AddsMaskLatents(self._unet_forward, mask, init_image_latents)
AddsMaskLatents(self._unet_forward, latent_mask, masked_latents)
else:
guidance.append(AddsMaskGuidance(mask, init_image_latents, self.scheduler, noise))
guidance.append(AddsMaskGuidance(latent_mask, init_image_latents, self.scheduler, noise))

try:
result_latents, result_attention_maps = self.latents_from_embeddings(
Expand Down
23 changes: 9 additions & 14 deletions ldm/invoke/generator/txt2img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
'''

import math
from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error
from typing import Callable, Optional

import torch
from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error

from ldm.invoke.generator.base import Generator
from ldm.invoke.generator.diffusers_pipeline import trim_to_multiple_of, StableDiffusionGeneratorPipeline, \
Expand Down Expand Up @@ -128,18 +128,13 @@ def get_noise(self,width,height,scale = True):
scaled_width = width
scaled_height = height

device = self.model.device
device = self.model.device
channels = self.latent_channels
if channels == 9:
channels = 4 # we don't really want noise for all the mask channels
shape = (1, channels,
scaled_height // self.downsampling_factor, scaled_width // self.downsampling_factor)
if self.use_mps_noise or device.type == 'mps':
return torch.randn([1,
self.latent_channels,
scaled_height // self.downsampling_factor,
scaled_width // self.downsampling_factor],
dtype=self.torch_dtype(),
device='cpu').to(device)
return torch.randn(shape, dtype=self.torch_dtype(), device='cpu').to(device)
else:
return torch.randn([1,
self.latent_channels,
scaled_height // self.downsampling_factor,
scaled_width // self.downsampling_factor],
dtype=self.torch_dtype(),
device=device)
return torch.randn(shape, dtype=self.torch_dtype(), device=device)
2 changes: 1 addition & 1 deletion ldm/invoke/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def set_default_model(self,model_name:str) -> None:
Set the default model. The change will not take
effect until you call model_manager.commit()
'''
assert model_name in self.models,f"unknown model '{model_name}'"
assert model_name in self.model_names(), f"unknown model '{model_name}'"

config = self.config
for model in config:
Expand Down

0 comments on commit 6ce75bd

Please sign in to comment.