In [None]:
from init_notebook import *

import diffusers

from experiments.datasets import *

In [None]:
class DDPMPipelineWithClasses(diffusers.DDPMPipeline):
    """
    Class-conditional variant of ``diffusers.DDPMPipeline``.

    The UNet is called as ``self.unet(image, t, class_labels)``, so it must accept
    integer class labels.

    - ``__call__`` samples a batch starting from Gaussian noise.
    - ``run_inference_on`` runs (a window of) the denoising schedule on a
      caller-supplied image, e.g. for image-to-image or patch-wise experiments.

    Images handed to ``run_inference_on`` are mapped with ``x * 2 - 1`` before
    denoising, and results are mapped back with ``(x / 2 + .5).clamp(0, 1)``.
    The public API therefore uses a [0, 1] value convention while the UNet sees [-1, 1].
    """

    @dataclass
    class CallbackArg:
        """Per-step argument passed to the ``callback`` of the pipeline."""

        # the pipeline running the denoising loop
        pipeline: "DDPMPipelineWithClasses"
        # current sample after the scheduler step, in model space ([-1, 1] convention)
        image: torch.Tensor
        # 0-based index of the step within the executed timestep window
        iteration: int
        # scheduler timestep of this step (an element of ``scheduler.timesteps``,
        # presumably a 0-d tensor rather than a Python int)
        timestep: int

    @torch.no_grad()
    def __call__(
            self,
            classes: Iterable[int] = (0,),
            batch_size: int = 1,
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            num_inference_steps: int = 1000,
            output_type: Optional[str] = "pt",
            return_dict: bool = True,
            callback: Optional[Callable[[CallbackArg], torch.Tensor]] = None,
    ) -> Union[diffusers.ImagePipelineOutput, Tuple]:
        """
        Sample images of the given classes, starting from Gaussian noise.

        :param classes: class labels, repeated cyclically to fill ``batch_size``
        :param batch_size: number of images to generate
        :param generator: optional torch generator(s) for reproducible noise
        :param num_inference_steps: number of scheduler steps
        :param output_type: ``"pt"`` (tensor), ``"pil"``, or anything else for numpy
        :param return_dict: return ``ImagePipelineOutput`` if True, else a 1-tuple
        :param callback: optional per-step hook, see ``run_inference_on``
        :raises ValueError: if ``classes`` is empty
        """
        from diffusers.utils.torch_utils import randn_tensor

        # Sample gaussian noise to begin loop
        sample_size = self.unet.config.sample_size
        if isinstance(sample_size, int):
            sample_size = (sample_size, sample_size)
        image_shape = (batch_size, self.unet.config.in_channels, *sample_size)

        if self.device.type == "mps":
            # randn does not work reproducibly on mps
            image = randn_tensor(image_shape, generator=generator)
            image = image.to(self.device)
        else:
            image = randn_tensor(image_shape, generator=generator, device=self.device)

        label_list = list(classes)
        if not label_list:
            # an empty label list could never fill the batch (previously an endless loop)
            raise ValueError("`classes` must contain at least one class label")
        # repeat the labels cyclically until the batch is filled
        label_list = [label_list[i % len(label_list)] for i in range(batch_size)]
        class_labels = torch.tensor(label_list, dtype=torch.long, device=self.device)

        # run_inference_on maps x -> x * 2 - 1, so pre-apply the inverse to keep
        # the initial sample standard normal in model space
        return self.run_inference_on(
            image=image / 2. + .5,
            class_labels=class_labels,
            num_inference_steps=num_inference_steps,
            output_type=output_type,
            return_dict=return_dict,
            generator=generator,
            callback=callback,
        )

    @torch.no_grad()
    def run_inference_on(
            self,
            image: torch.Tensor,
            class_labels: Union[int, Iterable, torch.LongTensor],
            num_inference_steps: int = 1000,
            timestep_offset: int = 0,
            timestep_count: Optional[int] = None,
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            output_type: Optional[str] = "pt",
            return_dict: bool = True,
            callback: Optional[Callable[[CallbackArg], torch.Tensor]] = None,
    ):
        """
        Denoise ``image`` over (a window of) the schedule, conditioned on ``class_labels``.

        :param image: (B, C, H, W) or (C, H, W) tensor; mapped with ``x * 2 - 1``
            before denoising
        :param class_labels: a single int / 0-d tensor (used for the whole batch),
            or one label per batch item
        :param num_inference_steps: passed to ``scheduler.set_timesteps``
        :param timestep_offset: number of leading scheduler timesteps to skip
        :param timestep_count: if given, run at most this many timesteps
        :param generator: passed to ``scheduler.step``
        :param output_type: ``"pt"``, ``"pil"``, or anything else for numpy (B, H, W, C)
        :param return_dict: return ``ImagePipelineOutput`` if True, else a 1-tuple
        :param callback: called after every step with a ``CallbackArg``. A returned
            tensor replaces the current sample; ``None`` leaves it unchanged.
        :raises ValueError: on wrong ``image`` rank or label/batch size mismatch
        """
        image = image.to(self.device)

        if image.ndim == 3:
            image = image.unsqueeze(0)
        elif image.ndim != 4:
            raise ValueError(f"`image` must have 4 (or 3) dimensions, got {image.shape}")

        image = image * 2. - 1.
        batch = image.shape[0]

        if isinstance(class_labels, torch.Tensor):
            # make sure user-supplied tensors live next to the UNet and are integer labels
            class_labels = class_labels.to(device=self.device, dtype=torch.long)
        elif isinstance(class_labels, int):
            class_labels = torch.tensor(class_labels, dtype=torch.long, device=self.device)
        else:
            class_labels = torch.tensor(list(class_labels), dtype=torch.long, device=self.device)

        if class_labels.ndim == 0:
            # a single label is used for every image in the batch
            class_labels = class_labels.reshape(1).repeat(batch)

        if class_labels.ndim != 1 or class_labels.shape[0] != batch:
            raise ValueError(f"batch-size of `class_labels` must match `image`, expected {image.shape[0]}, got {class_labels.shape}")

        # set step values, then select the requested window of the schedule
        self.scheduler.set_timesteps(num_inference_steps)
        timesteps = self.scheduler.timesteps[timestep_offset:]
        if timestep_count is not None:
            timesteps = timesteps[:timestep_count]

        for idx, t in enumerate(self.progress_bar(timesteps)):
            # 1. predict noise model_output
            model_output = self.unet(image, t, class_labels).sample

            # 2. compute previous image: x_t -> x_t-1
            image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample

            if callback is not None:
                result = callback(self.CallbackArg(
                    pipeline=self, image=image, iteration=idx, timestep=t
                ))
                # observing-only callbacks may return None; keep the sample then
                if result is not None:
                    image = result

        image = (image / 2 + 0.5).clamp(0, 1)
        if output_type != "pt":
            image = image.cpu().permute(0, 2, 3, 1).numpy()
            if output_type == "pil":
                image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return diffusers.ImagePipelineOutput(images=image)


In [None]:
# Load a trained checkpoint. Alternative (needs a `model` UNet in scope):
#   DDPMPipelineWithClasses(model, diffusers.DDPMScheduler(num_train_timesteps=100))
MODEL_PATH = "../ddpm-test-01/03"
# fall back to CPU so the notebook still runs on machines without CUDA
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
pipe = DDPMPipelineWithClasses.from_pretrained(MODEL_PATH).to(DEVICE)

In [None]:
# One 6x6 grid of samples for each of the first three classes.
# Interrupting the kernel (KeyboardInterrupt) stops early and still shows
# the grids that finished.
grid = []
try:
    for class_idx in range(3):
        samples = pipe(classes=[class_idx], batch_size=6 * 6, num_inference_steps=30).images
        grid.append(make_grid(samples, nrow=6))
except KeyboardInterrupt:
    pass

if len(grid) > 0:
    display(VF.to_pil_image(make_grid(grid, nrow=6, padding=5)))

In [None]:
# Denoise 16*16 random starting points, all conditioned on the class "other".
# NOTE(review): run_inference_on maps x -> x * 2 - 1, so this start is Gaussian
# noise with std 2 (not 1) in model space — confirm this is intended.
noise_input = torch.randn(16*16, 4, 32, 32) + .5
label = PixelartDataset.LABELS.index("other")
output = pipe.run_inference_on(noise_input, class_labels=[label] * noise_input.shape[0], num_inference_steps=20).images
VF.to_pil_image(make_grid(output, nrow=16))

In [None]:
# Source image for the image-to-image experiments below.
# NOTE: absolute, machine-specific path — point it at any local image to run the notebook.
SOME_IMAGE_PATH = "/home/bergi/Pictures/csv-turing.png"

# `with` closes the file handle that PIL.Image.open keeps open
with PIL.Image.open(SOME_IMAGE_PATH) as pil_image:
    some_image = VF.to_tensor(pil_image.convert("RGBA"))
# project helper `resize`; presumably scales by the given factor (here 1/32)
some_image = resize(some_image, 1/32)
print(some_image.shape)
VF.to_pil_image(some_image)

In [None]:
def gen_images(
        label: int,
        batch_size: int = 8,
        size: int = 32,
        steps: int = 100,
        offset: int = 0,
        source_image: Optional[torch.Tensor] = None,
):
    """
    Denoise ``batch_size`` colour-inverted copies of a source image, conditioned on ``label``.

    Intermediate samples are shown live in an ``ImageWidget``, and the final batch
    is displayed as a grid.

    :param label: class label used for every image in the batch
    :param batch_size: number of copies of the source image to process
    :param size: currently unused; kept for backwards compatibility
    :param steps: number of scheduler inference steps
    :param offset: number of leading scheduler timesteps to skip. When > 0, the source
        is first noised to the level of ``pipe.scheduler.timesteps[offset]``, which is
        the step where denoising resumes.
    :param source_image: (C, H, W) image tensor (values are clamped to [0, 1]);
        defaults to the global ``some_image`` loaded earlier in the notebook
    :return: the denoised batch, i.e. ``pipe.run_inference_on(...).images``
    :raises ValueError: if ``offset`` is not smaller than the number of timesteps
    """
    if source_image is None:
        source_image = some_image

    image_widget = ImageWidget()
    display(image_widget)
    images = []

    def _callback(arg: DDPMPipelineWithClasses.CallbackArg):
        # show roughly 7 snapshots of the denoising trajectory
        if arg.iteration % max(1, steps // 7) == 0:
            image = (arg.image * .5 + .5).clamp(0, 1)
            images.extend(image.cpu())
            image_widget.set_torch(resize(make_grid(images, nrow=batch_size), 2))

        return arg.image

    start = source_image.unsqueeze(0).repeat(batch_size, 1, 1, 1).clamp(0, 1)
    # invert the RGB channels; any further channels (e.g. alpha) are kept as they are
    start[:, :3] = 1. - start[:, :3]

    if offset > 0:
        pipe.scheduler.set_timesteps(steps)
        if offset >= len(pipe.scheduler.timesteps):
            raise ValueError(
                f"`offset` must be smaller than the number of timesteps "
                f"({len(pipe.scheduler.timesteps)}), got {offset}"
            )
        # noise level of the first timestep that run_inference_on will execute
        noise_timestep = pipe.scheduler.timesteps[offset].reshape(1)
        # add_noise works in model space; run_inference_on maps [0, 1] -> [-1, 1]
        # with x * 2 - 1, so convert there, add noise, and convert back
        noisy = pipe.scheduler.add_noise(start * 2. - 1., torch.randn_like(start), noise_timestep)
        start = noisy / 2. + .5

    images.extend(start)
    output = pipe.run_inference_on(
        start, class_labels=[label] * start.shape[0], num_inference_steps=steps,
        timestep_offset=offset,
        callback=_callback,
    ).images

    display(VF.to_pil_image(resize(make_grid(output), 1)))
    return output

output = gen_images(label=PixelartDataset.LABELS.index("wall"), size=32, batch_size=8, steps=20, offset=7)

In [None]:
# NOTE(review): re-displays whatever the global `output` currently holds, presumably
# upscaled 2x by the project helper `resize`. Which cell produced `output` depends
# on execution order.
VF.to_pil_image(resize(make_grid(output), 2))

In [None]:
# Denoise 8 copies of `some_image`, shifted down by .3, conditioned on class 0.
# (`shifted_input` instead of `input` so the builtin is not shadowed.)
shifted_input = some_image.unsqueeze(0).repeat(8, 1, 1, 1) - .3
output = pipe.run_inference_on(
    shifted_input,
    class_labels=[0] * shifted_input.shape[0],
    num_inference_steps=20,
)
VF.to_pil_image(resize(make_grid(output.images), 2))

### Process a large image in overlapping patches

In [None]:
def process_image_patches(
    image: torch.Tensor,
    label: int, batch_size: int = 16, size: int = 32, steps: int = 30,
    offset_stride: int = 2,
    timestep_count: int = 5,
    blend: float = .2,
):
    """
    Iteratively denoise a large image by running the pipeline on patches.

    For every scheduler offset ``0, offset_stride, 2 * offset_stride, ... < steps``:
    - the image is cut into ``(size, size)`` patches with stride ``size // 2`` (via
      ``iter_image_patches``);
    - each batch of patches runs ``timestep_count`` denoising steps starting at that offset;
    - the result is blended back into the image with weight ``blend``.

    :param image: (C, H, W) tensor; the caller's tensor is not modified
    :param label: class label used for every patch
    :param batch_size: number of patches per pipeline call
    :param size: patch edge length
    :param steps: number of scheduler inference steps
    :param offset_stride: step between consecutive timestep offsets
    :param timestep_count: number of denoising steps run per offset
    :param blend: weight of the denoised patch when mixing it into the image
    :return: processed image clamped to [0, 1], on the CPU
    """
    image_widget = ImageWidget()
    display(image_widget)
    image_widget.set_torch(image)

    # work on a private copy: `.to()` returns the same tensor when it already lives
    # on `pipe.device`, and the patch blending below writes in place
    image = image.to(pipe.device).clone()

    # set_progress_bar_config replaces the whole config; remember it so later
    # cells keep their progress bars
    previous_progress_config = dict(getattr(pipe, "_progress_bar_config", {}))
    pipe.set_progress_bar_config(disable=True)
    try:
        for timestep_offset in tqdm(range(0, steps, offset_stride)):
            for patches, positions in iter_image_patches(
                    image,
                    shape=(size, size),
                    stride=(size // 2, size // 2),
                    batch_size=batch_size,
                    with_pos=True,
                    verbose=False,
            ):
                patches = pipe.run_inference_on(
                    patches, class_labels=[label] * patches.shape[0],
                    num_inference_steps=steps,
                    timestep_offset=timestep_offset,
                    timestep_count=timestep_count,
                ).images

                for patch, pos in zip(patches, positions):
                    s1 = slice(pos[-2], pos[-2] + patch.shape[-2])
                    s2 = slice(pos[-1], pos[-1] + patch.shape[-1])
                    # move the image region a fraction `blend` towards the denoised patch
                    image[:, s1, s2] += blend * (patch - image[:, s1, s2])

            image_widget.set_torch(image.clamp(0, 1).cpu())
    finally:
        pipe.set_progress_bar_config(**previous_progress_config)

    return image.clamp(0, 1).cpu()

output = process_image_patches(
    torch.randn(4, 256, 256) + .4,
    label=PixelartDataset.LABELS.index("wall"),
    size=32, batch_size=16, steps=20,
)
VF.to_pil_image(resize(output, 2))

In [None]:
# Re-display the current global `output` (in execution order: the result of
# process_image_patches), presumably upscaled 2x by the project helper `resize`.
VF.to_pil_image(resize(output, 2))