Skip to content

Commit

Permalink
fix crash in txt2img and img2img w/ inpainting models and perlin > 0 (#2544)
Browse files Browse the repository at this point in the history

- get_perlin_noise() was returning 9 channels; fixed code to return
noise for just the 4 image channels and not the mask ones.

- Closes Issue #2541
  • Loading branch information
lstein committed Feb 6, 2023
2 parents 05bb9e4 + 0240656 commit 633f702
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 42 deletions.
31 changes: 30 additions & 1 deletion ldm/invoke/generator/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,12 @@ def get_noise(self,width,height):

def get_perlin_noise(self, width, height):
    """Build a stack of 2-D Perlin noise planes, one per image channel.

    The number of planes is ``min(self.latent_channels, 4)``, so the extra
    (mask) channels of an inpainting model get no noise of their own.

    Args:
        width: Width handed to ``rand_perlin_2d`` as the second element of
            its size tuple.
        height: Height handed to ``rand_perlin_2d`` as the first element of
            its size tuple.

    Returns:
        The planes stacked along a new leading dimension and placed on
        ``self.model.device``.
    """
    target_device = self.model.device
    # Each plane is moved to the CPU before stacking when the model is on MPS
    # (presumably a workaround for an MPS limitation -- TODO confirm).
    fixdevice = 'cpu' if target_device.type == 'mps' else target_device
    # limit noise to only the diffusion image channels, not the mask channels
    input_channels = min(self.latent_channels, 4)
    planes = [
        rand_perlin_2d((height, width), (8, 8), device=target_device).to(fixdevice)
        for _ in range(input_channels)
    ]
    return torch.stack(planes, dim=0).to(target_device)

def new_seed(self):
Expand Down Expand Up @@ -341,3 +346,27 @@ def save_sample(self, sample, filepath):

def torch_dtype(self)->torch.dtype:
    """Return the torch dtype that matches ``self.precision``.

    ``'float16'`` selects ``torch.float16``; every other value selects
    ``torch.float32``.
    """
    if self.precision == 'float16':
        return torch.float16
    return torch.float32

# returns a tensor filled with random numbers from a normal distribution
def get_noise(self, width, height):
    """Return a tensor filled with random numbers from a normal distribution.

    The tensor has shape ``[1, C, height // f, width // f]`` where
    ``C = min(self.latent_channels, 4)`` and ``f = self.downsampling_factor``.
    Its dtype comes from ``self.torch_dtype()`` and it ends up on
    ``self.model.device``.

    When ``self.perlin > 0.0`` the Gaussian noise is linearly blended with
    the result of ``self.get_perlin_noise`` using ``self.perlin`` as the
    weight of the Perlin term.
    """
    device = self.model.device
    # limit noise to only the diffusion image channels, not the mask channels
    input_channels = min(self.latent_channels, 4)
    latent_height = height // self.downsampling_factor
    latent_width = width // self.downsampling_factor

    # Sample on the CPU and move afterwards when MPS noise was requested or
    # the model lives on MPS; otherwise sample directly on the model device.
    sample_on_cpu = self.use_mps_noise or device.type == 'mps'
    x = torch.randn([1, input_channels, latent_height, latent_width],
                    dtype=self.torch_dtype(),
                    device='cpu' if sample_on_cpu else device)
    if sample_on_cpu:
        x = x.to(device)

    if self.perlin > 0.0:
        perlin_noise = self.get_perlin_noise(latent_width, latent_height)
        x = (1-self.perlin)*x + self.perlin*perlin_noise
    return x
19 changes: 0 additions & 19 deletions ldm/invoke/generator/img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,22 +63,3 @@ def get_noise_like(self, like: torch.Tensor):
shape = like.shape
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
return x

def get_noise(self, width, height):
    """Return a tensor filled with random numbers from a normal distribution.

    The tensor has shape ``[1, C, height // f, width // f]`` where
    ``C = min(self.latent_channels, 4)`` and ``f = self.downsampling_factor``.
    Clamping the channel count keeps models with more than four latent
    channels (e.g. inpainting models, whose extra channels carry the mask)
    from receiving noise for channels that are not image channels.

    When ``self.perlin > 0.0`` the Gaussian noise is linearly blended with
    the result of ``self.get_perlin_noise`` using ``self.perlin`` as the
    weight of the Perlin term.
    """
    device = self.model.device
    # limit noise to only the diffusion image channels, not the mask channels
    input_channels = min(self.latent_channels, 4)
    latent_height = height // self.downsampling_factor
    latent_width = width // self.downsampling_factor

    # Sample on the CPU and move afterwards when MPS noise was requested or
    # the model lives on MPS; otherwise sample directly on the model device.
    sample_on_cpu = self.use_mps_noise or device.type == 'mps'
    x = torch.randn([1, input_channels, latent_height, latent_width],
                    dtype=self.torch_dtype(),
                    device='cpu' if sample_on_cpu else device)
    if sample_on_cpu:
        x = x.to(device)

    if self.perlin > 0.0:
        perlin_noise = self.get_perlin_noise(latent_width, latent_height)
        x = (1-self.perlin)*x + self.perlin*perlin_noise
    return x
22 changes: 0 additions & 22 deletions ldm/invoke/generator/txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,26 +51,4 @@ def make_image(x_T) -> PIL.Image.Image:
return make_image


# returns a tensor filled with random numbers from a normal distribution
def get_noise(self, width, height):
    """Return a tensor filled with random numbers from a normal distribution.

    The tensor has shape ``[1, C, height // f, width // f]`` where
    ``C = min(self.latent_channels, 4)`` and ``f = self.downsampling_factor``.
    Its dtype comes from ``self.torch_dtype()`` and it ends up on
    ``self.model.device``.

    When ``self.perlin > 0.0`` the Gaussian noise is linearly blended with
    the result of ``self.get_perlin_noise`` using ``self.perlin`` as the
    weight of the Perlin term.
    """
    device = self.model.device
    # limit noise to only the diffusion image channels, not the mask channels
    input_channels = min(self.latent_channels, 4)
    latent_height = height // self.downsampling_factor
    latent_width = width // self.downsampling_factor

    # One randn call for both paths: only the sampling device differs.
    # MPS (or an explicit use_mps_noise request) samples on the CPU first.
    sample_on_cpu = self.use_mps_noise or device.type == 'mps'
    x = torch.randn([1, input_channels, latent_height, latent_width],
                    dtype=self.torch_dtype(),
                    device='cpu' if sample_on_cpu else device)
    if sample_on_cpu:
        x = x.to(device)

    if self.perlin > 0.0:
        x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(latent_width, latent_height)
    return x

0 comments on commit 633f702

Please sign in to comment.