Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix crash in txt2img and img2img w/ inpainting models and perlin > 0 #2544

Merged
merged 1 commit into from
Feb 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion ldm/invoke/generator/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,12 @@ def get_noise(self,width,height):

def get_perlin_noise(self,width,height):
fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device
noise = torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device)
# limit noise to only the diffusion image channels, not the mask channels
input_channels = min(self.latent_channels, 4)
noise = torch.stack([
rand_perlin_2d((height, width),
(8, 8),
device = self.model.device).to(fixdevice) for _ in range(input_channels)], dim=0).to(self.model.device)
return noise

def new_seed(self):
Expand Down Expand Up @@ -341,3 +346,27 @@ def save_sample(self, sample, filepath):

def torch_dtype(self)->torch.dtype:
return torch.float16 if self.precision == 'float16' else torch.float32

# returns a tensor filled with random numbers from a normal distribution
def get_noise(self,width,height):
device = self.model.device
# limit noise to only the diffusion image channels, not the mask channels
input_channels = min(self.latent_channels, 4)
if self.use_mps_noise or device.type == 'mps':
x = torch.randn([1,
input_channels,
height // self.downsampling_factor,
width // self.downsampling_factor],
dtype=self.torch_dtype(),
device='cpu').to(device)
else:
x = torch.randn([1,
input_channels,
height // self.downsampling_factor,
width // self.downsampling_factor],
dtype=self.torch_dtype(),
device=device)
if self.perlin > 0.0:
perlin_noise = self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
x = (1-self.perlin)*x + self.perlin*perlin_noise
return x
19 changes: 0 additions & 19 deletions ldm/invoke/generator/img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,22 +63,3 @@ def get_noise_like(self, like: torch.Tensor):
shape = like.shape
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
return x

def get_noise(self,width,height):
# copy of the Txt2Img.get_noise
device = self.model.device
if self.use_mps_noise or device.type == 'mps':
x = torch.randn([1,
self.latent_channels,
height // self.downsampling_factor,
width // self.downsampling_factor],
device='cpu').to(device)
else:
x = torch.randn([1,
self.latent_channels,
height // self.downsampling_factor,
width // self.downsampling_factor],
device=device)
if self.perlin > 0.0:
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
return x
22 changes: 0 additions & 22 deletions ldm/invoke/generator/txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,26 +51,4 @@ def make_image(x_T) -> PIL.Image.Image:
return make_image


# returns a tensor filled with random numbers from a normal distribution
def get_noise(self,width,height):
device = self.model.device
# limit noise to only the diffusion image channels, not the mask channels
input_channels = min(self.latent_channels, 4)
if self.use_mps_noise or device.type == 'mps':
x = torch.randn([1,
input_channels,
height // self.downsampling_factor,
width // self.downsampling_factor],
dtype=self.torch_dtype(),
device='cpu').to(device)
else:
x = torch.randn([1,
input_channels,
height // self.downsampling_factor,
width // self.downsampling_factor],
dtype=self.torch_dtype(),
device=device)
if self.perlin > 0.0:
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
return x