diff --git a/ldm/invoke/generator/base.py b/ldm/invoke/generator/base.py
index 4c73b997e79..21d6f271cab 100644
--- a/ldm/invoke/generator/base.py
+++ b/ldm/invoke/generator/base.py
@@ -137,17 +137,9 @@ def sample_to_image(self,samples)->Image.Image:
         Given samples returned from a sampler, converts
         it into a PIL Image
         """
-        x_samples = self.model.decode_first_stage(samples)
-        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
-        if len(x_samples) != 1:
-            raise Exception(
-                f'>> expected to get a single image, but got {len(x_samples)}')
-        x_sample = 255.0 * rearrange(
-            x_samples[0].cpu().numpy(), 'c h w -> h w c'
-        )
-        return Image.fromarray(x_sample.astype(np.uint8))
-
-    # write an approximate RGB image from latent samples for a single step to PNG
+        with torch.inference_mode():
+            image = self.model.decode_latents(samples)
+            return self.model.numpy_to_pil(image)[0]
 
     def repaste_and_color_correct(self, result: Image.Image, init_image: Image.Image, init_mask: Image.Image, mask_blur_radius: int = 8) -> Image.Image:
         if init_image is None or init_mask is None:
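
For context on the new call path: the hand-rolled clamp/rearrange conversion is replaced by delegating to the model wrapper's `decode_latents()` and `numpy_to_pil()`, which are diffusers-pipeline-style helpers. The sketch below shows roughly what that combination is expected to do, written against a bare `AutoencoderKL` rather than the InvokeAI model wrapper; the helper name `latents_to_pil`, the standalone `vae` argument, and the `0.18215` scaling constant are illustrative assumptions, not code from this PR.

```python
# Illustrative sketch only, not part of this diff.
# Approximates decode_latents() + numpy_to_pil() against a bare AutoencoderKL.
import numpy as np
import torch
from PIL import Image
from diffusers import AutoencoderKL


@torch.inference_mode()
def latents_to_pil(vae: AutoencoderKL, latents: torch.Tensor) -> list[Image.Image]:
    # Undo the Stable Diffusion latent scaling before decoding (0.18215 is the
    # usual SD VAE scale factor; assumed here, configured on the pipeline in practice).
    images = vae.decode(latents / 0.18215).sample
    # Map from [-1, 1] to [0, 1], mirroring the clamp in the removed code.
    images = (images / 2 + 0.5).clamp(0, 1)
    # NCHW float tensor -> NHWC uint8 numpy, one PIL image per batch entry.
    images = images.permute(0, 2, 3, 1).cpu().float().numpy()
    return [Image.fromarray((img * 255).round().astype(np.uint8)) for img in images]
```

The practical effect of the change is that batch handling, clamping, and channel reordering now live in one place on the model wrapper, and the explicit `torch.inference_mode()` guard keeps the decode out of autograd. `sample_to_image` still returns a single `PIL.Image`, taking element `[0]` of the decoded batch.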