In [None]:
from init_notebook import *

import diffusers

from experiments.datasets import *

# Device on which CLIP encodes prompts and images (the diffusion pipeline is moved to "cuda" separately).
clip_device: str = "cpu"

In [None]:
class DDPMPipelineWithEmbedding(diffusers.DDPMPipeline):
    """
    ``diffusers.DDPMPipeline`` variant whose UNet receives an extra conditioning embedding.

    At every denoising step the UNet is called as ``self.unet(image, t, embedding)``.
    """

    @dataclass
    class CallbackArg:
        # Argument bundle passed to the per-step ``callback``.
        pipeline: "DDPMPipelineWithEmbedding"
        # sample after the scheduler step, in model space (inputs are mapped from [0, 1] via x * 2 - 1)
        image: torch.Tensor
        # zero-based step index within the (possibly offset / truncated) timestep list
        iteration: int
        # scheduler timestep of this step
        # NOTE(review): this is the element taken from ``scheduler.timesteps`` (a tensor), not a python int
        timestep: int

    @torch.no_grad()
    def __call__(
            self,
            embedding: Optional[torch.Tensor] = None,
            batch_size: int = 1,
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            num_inference_steps: int = 1000,
            size: Optional[int] = None,
            output_type: Optional[str] = "pt",
            return_dict: bool = True,
            callback: Optional[Callable[[CallbackArg], torch.Tensor]] = None,
    ) -> Union[diffusers.ImagePipelineOutput, Tuple]:
        """
        Generate images from gaussian noise, conditioned on ``embedding``.

        :param embedding: conditioning tensor of shape (N, D) or (D,). If None, a random
            (batch_size, 512) embedding is drawn. A batch smaller than ``batch_size`` is
            tiled (and truncated) to exactly ``batch_size`` rows; a larger one is truncated.
        :param batch_size: number of images to generate.
        :param generator: torch generator(s) for the initial noise, random embedding and scheduler.
        :param num_inference_steps: number of scheduler timesteps.
        :param size: optional spatial size overriding the UNet's ``sample_size``; it sets
            both height and width (square output) regardless of the ``sample_size`` form.
        :param output_type: "pt" (tensor), "pil" (list of PIL images) or anything else (numpy).
        :param return_dict: return ``ImagePipelineOutput`` if True, else a 1-tuple.
        :param callback: called after every scheduler step; its return value replaces the sample.
        :raises ValueError: if ``embedding`` is not 1- or 2-dimensional or has an empty batch.
        """
        from diffusers.utils.torch_utils import randn_tensor

        # Shape of the initial gaussian noise
        sample_size = self.unet.config.sample_size
        if isinstance(sample_size, int):
            height = width = size or sample_size
        else:
            # ``size`` replaces both dimensions, consistent with the int branch above
            height = size or sample_size[0]
            width = size or sample_size[1]
        image_shape = (batch_size, self.unet.config.in_channels, height, width)

        if self.device.type == "mps":
            # randn does not work reproducibly on mps
            image = randn_tensor(image_shape, generator=generator)
            image = image.to(self.device)
        else:
            image = randn_tensor(image_shape, generator=generator, device=self.device)

        if embedding is None:
            embedding = randn_tensor((batch_size, 512), generator=generator, device=self.device)
        else:
            embedding = embedding.to(self.device)
            if embedding.ndim == 1:
                embedding = embedding.unsqueeze(0)
            if embedding.ndim != 2:
                raise ValueError(f"`embedding` must have 2 dimensions, got {embedding.shape}")
            if embedding.shape[0] == 0:
                raise ValueError(f"`embedding` must contain at least one row, got {embedding.shape}")
            if embedding.shape[0] < batch_size:
                # ceil-division so a batch_size that is not a multiple is still fully covered
                repeats = -(-batch_size // embedding.shape[0])
                embedding = embedding.repeat(repeats, 1)
            if embedding.shape[0] > batch_size:
                embedding = embedding[:batch_size]

        # run_inference_on expects [0, 1] input and maps it back with x * 2 - 1
        return self.run_inference_on(
            image=image / 2. + .5,
            embedding=embedding,
            num_inference_steps=num_inference_steps,
            output_type=output_type,
            return_dict=return_dict,
            generator=generator,
            callback=callback,
        )

    @torch.no_grad()
    def run_inference_on(
            self,
            image: torch.Tensor,
            embedding: torch.Tensor,
            num_inference_steps: int = 1000,
            timestep_offset: int = 0,
            timestep_count: Optional[int] = None,
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            output_type: Optional[str] = "pt",
            return_dict: bool = True,
            callback: Optional[Callable[[CallbackArg], torch.Tensor]] = None,
    ):
        """
        Denoise a given starting image, conditioned on ``embedding``.

        :param image: starting image(s) in [0, 1], shape (B, C, H, W) or (C, H, W).
        :param embedding: conditioning tensor whose first dimension must equal B.
        :param num_inference_steps: passed to ``scheduler.set_timesteps``.
        :param timestep_offset: number of leading scheduler timesteps to skip.
        :param timestep_count: if given, run at most this many timesteps after the offset.
        :param generator: torch generator(s) passed to ``scheduler.step``.
        :param output_type: "pt" returns the tensor clamped to [0, 1]; otherwise a numpy
            array of shape (B, H, W, C), converted to PIL images if "pil".
        :param return_dict: return ``ImagePipelineOutput`` if True, else a 1-tuple.
        :param callback: called after every scheduler step with a ``CallbackArg``; its
            return value replaces the current sample.
        :raises ValueError: on wrong image rank or mismatching batch sizes.
        """
        image = image.to(self.device)

        if image.ndim == 3:
            image = image.unsqueeze(0)
        elif image.ndim != 4:
            raise ValueError(f"`image` must have 4 (or 3) dimensions, got {image.shape}")

        # [0, 1] -> model space [-1, 1]
        image = image * 2. - 1.

        if embedding.shape[0] != image.shape[0]:
            raise ValueError(f"batch-size of `embedding` must match `image`, expected {image.shape[0]}, got {embedding.shape}")
        embedding = embedding.to(self.device)

        # set step values
        self.scheduler.set_timesteps(num_inference_steps)

        timesteps = self.scheduler.timesteps[timestep_offset:]
        if timestep_count is not None:
            timesteps = timesteps[:timestep_count]

        for idx, t in enumerate(self.progress_bar(timesteps)):
            # 1. predict noise model_output
            model_output = self.unet(image, t, embedding).sample

            # 2. compute previous image: x_t -> x_t-1
            image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample

            if callback is not None:
                image = callback(self.CallbackArg(
                    pipeline=self, image=image, iteration=idx, timestep=t
                ))

        # model space -> [0, 1]
        image = (image / 2 + 0.5).clamp(0, 1)
        if output_type != "pt":
            image = image.cpu().permute(0, 2, 3, 1).numpy()
            if output_type == "pil":
                image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return diffusers.ImagePipelineOutput(images=image)


In [None]:
# Earlier checkpoint, kept for comparison (noted as "the good one"):
#   "../checkpoints/hug-diff/ddpm-07-clip-norm-mult5-clamp0.3/"
pipe = DDPMPipelineWithEmbedding.from_pretrained(
    "../checkpoints/hug-diff/ddpm-09-clip-norm-mult5-clamp0.3/",
).to("cuda")

In [None]:
# Inspect the scheduler that was loaded together with the checkpoint.
pipe.scheduler

In [None]:
# Swap in a DDPMScheduler built from library defaults.
# NOTE(review): this discards the scheduler config stored with the checkpoint.
# Options tried before: timestep_spacing="trailing", thresholding=True
pipe.scheduler = diffusers.schedulers.DDPMScheduler()

In [None]:
ds = PixelartDataset((4, 32, 32))

# A few local pictures, loaded as RGBA and resized to 32x32 like the dataset tiles.
some_images = [
    VF.resize(
        VF.to_tensor(PIL.Image.open(path).convert("RGBA")),
        (32, 32),
        VF.InterpolationMode.BILINEAR,
        antialias=True,
    )
    for path in (
        "/home/bergi/Pictures/csv-turing.png",
        "/home/bergi/Pictures/Abstract-Radiation-Hazard-Sign.jpg",
        "/home/bergi/Pictures/diffusion/cthulhu-11.jpeg",
    )
]

VF.to_pil_image(make_grid(some_images))

In [None]:
# Preview the image tensor of dataset item 9000, passed through resize(..., 3).
VF.to_pil_image(resize(ds[9000][0], 3))

In [None]:
def scale_embedding(embedding: torch.Tensor, clamp: float = 3., scale: float = 5.) -> torch.Tensor:
    """
    L2-normalize every row of ``embedding``, clamp it symmetrically and scale it.

    This is the preprocessing applied to every CLIP embedding before it is fed
    to the pipeline in this notebook.

    NOTE(review): after unit normalization every element lies in [-1, 1], so the
    default ``clamp=3.`` never changes anything. Confirm whether a smaller bound was
    intended; the checkpoint names mention "clamp0.3".

    :param embedding: tensor of shape (N, D); normalized along dim 1.
    :param clamp: symmetric clamp bound applied after normalization.
    :param scale: multiplier applied after clamping.
    :return: tensor with the same shape as ``embedding``.
    """
    # normalize divides by max(norm, eps): all-zero rows become zeros instead of NaN
    unit = torch.nn.functional.normalize(embedding, p=2., dim=1)
    return unit.clamp(-clamp, clamp) * scale

In [None]:
# Condition on CLIP *image* embeddings: the local pictures loaded above followed by
# hand-picked dataset items. (An earlier variant used the text embeddings of
# "fire", "cobblestone", "brick wall", "wizard", "face", multiplied by 2.)
prompt_embeds = scale_embedding(ClipSingleton.encode_image(
    [
        *some_images,
        *(
            ds[index][0] for index in (
                723,   # green sword
                823,   # green dragon
                3023,  # yellowish rocks
                3323,  # dog?
                8229,  # marble block
                8413,  # cables before white
                8523,  # brick wall, grass
                8840,  # rocks, grass
                9000,  # cobblestone, grass
            )
        ),
    ],
    device=clip_device,
))

# One 6x6 sample grid per embedding; Ctrl-C stops early and shows what is finished.
grid = []
try:
    for idx in range(prompt_embeds.shape[0]):
        embedding_batch = prompt_embeds[None, idx].repeat(6 * 6, 1)
        images = pipe(embedding=embedding_batch, batch_size=6 * 6, num_inference_steps=10).images
        grid.append(make_grid(images, nrow=6))
except KeyboardInterrupt:
    pass

if grid:
    display(VF.to_pil_image(make_grid(grid, nrow=3, padding=5)))

In [None]:
# One 6x6 sample grid per text prompt (prompts listed in rows of three).
prompt_embeds = scale_embedding(ClipSingleton.encode_text(
    [
        # materials / terrain / elements
        "cobblestone", "brick wall", "steel",
        "sand", "grass", "water",
        "fire", "ice", "air",
        "sky", "earth", "mountains",
        # creatures / people
        "creature", "dragon", "dog",
        "wizard", "man", "woman",
        # body parts / equipment
        "face", "hand", "feet",
        "sword", "axe", "armour",
    ],
    device=clip_device,
))

grid = []
try:
    for idx in range(prompt_embeds.shape[0]):
        embedding_batch = prompt_embeds[None, idx].repeat(6 * 6, 1)
        images = pipe(embedding=embedding_batch, batch_size=6 * 6, num_inference_steps=10).images
        grid.append(make_grid(images, nrow=6))
except KeyboardInterrupt:
    pass  # stop early, keep the grids finished so far

if grid:
    display(VF.to_pil_image(make_grid(grid, nrow=6, padding=5)))

In [None]:
prompt_embeds = scale_embedding(ClipSingleton.encode_text(["woman"], device=clip_device))

# 32 batches of 32 samples (100 steps each), shown as a single 32-column mosaic.
grid = []
try:
    for _ in range(32):
        images = pipe(
            embedding=prompt_embeds.repeat(32, 1),
            batch_size=32,
            num_inference_steps=100,
        ).images
        grid.extend(images)
except KeyboardInterrupt:
    pass

if grid:
    display(VF.to_pil_image(make_grid(grid, nrow=32, padding=0)))

In [None]:
def get_wang_patches(corner: bool = False):
    """
    Build a 2-color Wang tile set, its template and the template split into 32x32 patches.

    :param corner: use ``wangtiles.WangTiles2C`` (corner-based) instead of ``WangTiles2E``.
    :return: tuple ``(tiles, template, patches)``. ``template.image`` has been replaced
        by the post-processed version described below. ``patches`` stacks its 32x32
        patches along a new first dimension.
    """
    tile_class = wangtiles.WangTiles2C if corner else wangtiles.WangTiles2E
    tiles = tile_class()
    template = tiles.create_template((32, 32), fade=5., padding=0)

    # append a copy of channel 0 as an extra channel
    template.image = torch.concat([template.image, template.image[:1]], dim=0)
    # steep sigmoid around .985 turns the faded template into a near-binary mask
    template.image = torch.sigmoid((template.image - .985) * 400.)
    # force this region to fully on (presumably one specific tile position)
    template.image[:, 32:64, 64:96] = 1

    patches = torch.stack(list(iter_image_patches(template.image, (32, 32))))
    return tiles, template, patches

# Corner-based tile set; show its template patches in a grid with 4 per row.
wang_tiles, wang_template, wang_patches = get_wang_patches(corner=True)
VF.to_pil_image(make_grid(wang_patches, nrow=4))

In [None]:
# Preview dataset item 602 (used as the guide patch in wang_images_from_prompt).
VF.to_pil_image(ds[602][0])

In [None]:
def wang_images_from_prompt(
    prompt: str, batch_size: int = 1,#wang_patches.shape[0], 
    size: int = wang_template.shape[-1], steps: int = 100, offset: int = 0,
    random_shift: int = 1,
    brightness: float = .0,
    contrast: float = 1.,
    wrap_pad: int = 0,
    fac: float = 0.,
    blur: float = 0.,
):
    """
    Generate a Wang-tile atlas for ``prompt``, then display it and a random tile map rendered from it.

    During all but the last two steps, the first image of the batch is pulled towards a
    4x4 grid of dataset item 602. The per-pixel weight is ``.01 + .99 * template image``.

    :param prompt: text prompt encoded with CLIP.
    :param batch_size: number of images to generate; only image 0 is guided.
    :param size: spatial size of the generated images. It must match the template /
        guide-patch size (4x4 dataset tiles, presumably 128) for the guidance to broadcast.
    :param steps: number of inference steps.
    :param offset: number of leading timesteps to skip.
    :param random_shift: max per-axis shift applied on alternating steps (0 disables).
    :param brightness: offset added to the initial noise before clamping to [0, 1].
    :param contrast: multiplier of the initial gaussian noise.
    :param wrap_pad: currently unused; kept for call compatibility.
    :param fac: strength of the pull towards the guide patch per step.
    :param blur: blend factor towards a gaussian-blurred image (kernel 3, sigma 3) per step.
    """
    # guide image: 4x4 grid of one dataset item, mapped from [0, 1] to model space [-1, 1]
    demo_patch = make_grid([ds[602][0]] * 16, nrow=4, padding=0).to(pipe.device) * 2. - 1.
    wang_image = wang_template.image.to(pipe.device)

    last_shift = None

    def _callback(arg: DDPMPipelineWithEmbedding.CallbackArg):
        nonlocal last_shift
        if arg.iteration < steps - 2:
            if blur:
                arg.image = arg.image + blur * (VF.gaussian_blur(arg.image, 3, 3) - arg.image)
            # copy before the in-place edit so the scheduler's tensor is not modified
            arg.image = arg.image.clone()
            arg.image[0] += fac * (.01 + .99 * wang_image) * (demo_patch - arg.image[0])

        # shift on one step, shift back on the next
        if random_shift > 0:
            if last_shift is None:
                last_shift = (
                    random.randrange(-random_shift, random_shift + 1),
                    random.randrange(-random_shift, random_shift + 1),
                )
                return image_shift(arg.image, *last_shift)
            image = image_shift(arg.image, -last_shift[0], -last_shift[1])
            last_shift = None
            return image
        return arg.image

    prompt_embeds = scale_embedding(ClipSingleton.encode_text(
        [prompt] * batch_size, device=clip_device
    ))

    noise = (torch.randn(batch_size, 4, size, size) * contrast + brightness).clamp(0, 1)

    output = pipe.run_inference_on(
        noise, embedding=prompt_embeds, num_inference_steps=steps,
        timestep_offset=offset,
        callback=_callback,
    ).images

    # With an odd number of callback calls the last shift is never reverted; undo it so
    # the tiles line up with the template (assumes image_shift is a pure translation).
    if last_shift is not None:
        output = image_shift(output, -last_shift[0], -last_shift[1])

    display(VF.to_pil_image(resize(make_grid(output, nrow=4), 2)))

    template = wang_template.with_new_image(make_grid(output, nrow=4, padding=0))
    wang_map = wangtiles.wang_map_stochastic_scanline(wang_tiles, (9, 12))
    display(VF.to_pil_image(resize(template.render_map(wang_map), 2)))


# Other prompts tried: "wooden texture", "blood splattered ", "dark cobblestone",
# "black and white", "pixelart", "alien",
# "alien glittering magnified cobblestone background pattern"
wang_images_from_prompt(
    "hieroglyphs",
    steps=50,
    random_shift=0,
    brightness=0.,
    contrast=4.,
    fac=.04,
    # blur=0.01,
)

In [None]:
# Render a random stochastic-scanline tile map with the tile set and template from
# get_wang_patches(). This previously used the transient cell output `_340`, which
# does not exist after Restart & Run All.
wang_map = wangtiles.wang_map_stochastic_scanline(wang_tiles, (9, 12))
VF.to_pil_image(resize(wang_template.render_map(wang_map), 2))

## Random noise and random embeddings

In [None]:
# 256 independent noise images. The 16 random embeddings are repeated row after row,
# so each column of the 16-wide grid shares one embedding.
noise_input = .5 + torch.randn(16 * 16, 4, 32, 32)
embeddings = torch.randn(16, 512).repeat(16, 1).to(pipe.device) * .5
output = pipe.run_inference_on(noise_input, embedding=embeddings, num_inference_steps=20).images
VF.to_pil_image(make_grid(output, nrow=16))

In [None]:
# Alternative source: "/home/bergi/Pictures/Abstract-Radiation-Hazard-Sign.jpg"
some_image = VF.resize(
    VF.to_tensor(PIL.Image.open("/home/bergi/Pictures/kali2.png").convert("RGBA")),
    (64, 64),
)

In [None]:
def gen_images_prompt(
    prompt: str, batch_size: int = 8, size: int = 32, steps: int = 100, offset: int = 0,
    random_shift: int = 0,
    brightness: float = 0.,
):
    """
    Generate ``batch_size`` images for ``prompt`` with a live preview of intermediate steps.

    During all but the last step, the left 9 columns (RGB channels) of every image are
    overwritten with the matching strip of dataset item 8523. The model therefore has to
    continue that strip.

    :param prompt: text prompt encoded with CLIP.
    :param batch_size: number of images.
    :param size: spatial size. NOTE(review): it must equal the dataset image height
        (presumably 32) for the strip paste to broadcast.
    :param steps: number of inference steps.
    :param offset: if > 0, noise the initial image with ``scheduler.add_noise`` at
        ``timesteps[offset]`` and skip the first ``offset`` timesteps.
    :param random_shift: max random per-axis shift applied every step (0 disables).
    :param brightness: offset added to the initial gaussian noise.
    """
    image_widget = ImageWidget()
    display(image_widget)
    images = []

    preview_every = max(1, steps // 7)
    # Loop-invariant, so fetched once instead of on every step. It is converted from the
    # dataset's (presumably [0, 1]) range to model space with the same x * 2 - 1 mapping
    # that run_inference_on applies to its input.
    # NOTE(review): assumes ds[8523] is deterministic (no random augmentation).
    strip = (ds[8523][0][:3, :, :9] * 2. - 1.).to(pipe.device)

    def _callback(arg: DDPMPipelineWithEmbedding.CallbackArg):
        if arg.iteration % preview_every == 0:
            preview = (arg.image * .5 + .5).clamp(0, 1)
            images.extend(preview.cpu())
            image_widget.set_torch(resize(make_grid(images, nrow=batch_size), 2))

        if arg.iteration < steps - 1:
            # copy before the in-place paste so the scheduler's tensor is not modified
            arg.image = arg.image.clone()
            arg.image[:, :3, :, :9] = strip

        if random_shift > 0:
            return image_shift(
                arg.image,
                random.randrange(-random_shift, random_shift + 1),
                random.randrange(-random_shift, random_shift + 1),
            )
        return arg.image

    prompt_embeds = scale_embedding(ClipSingleton.encode_text(
        [prompt] * batch_size,
        device=clip_device
    ))

    init_image = torch.randn(batch_size, 4, size, size) + brightness

    if offset > 0:
        pipe.scheduler.set_timesteps(steps)
        init_image = pipe.scheduler.add_noise(
            init_image, torch.randn_like(init_image) + (.5 * (steps - offset) / steps),
            pipe.scheduler.timesteps[offset],
        )

    images.extend(init_image)
    output = pipe.run_inference_on(
        init_image, embedding=prompt_embeds, num_inference_steps=steps,
        timestep_offset=offset,
        callback=_callback,
    ).images

    display(VF.to_pil_image(resize(make_grid(output, nrow=batch_size), 2)))

gen_images_prompt(
    "blood stained brick wall",
    size=32,
    batch_size=12,
    steps=10,
    offset=0,
    random_shift=0,
    brightness=.2,
)

In [None]:
def gen_prompts(
    *prompts: str,
    steps: int = 20,
    count: int = 6 * 6,
    size: int = 32,
):
    """
    Render ``count`` samples for each text prompt and display one grid of per-prompt grids.

    Both the per-prompt grid and the outer grid are laid out with ceil(sqrt(n)) columns.
    A KeyboardInterrupt stops generation early, and the grids finished so far are still shown.
    """
    prompt_embeds = scale_embedding(ClipSingleton.encode_text(prompts, device=clip_device))
    inner_columns = int(math.ceil(math.sqrt(count)))
    outer_columns = int(math.ceil(math.sqrt(len(prompts))))

    per_prompt = []
    try:
        for prompt_index in range(len(prompts)):
            samples = pipe(
                embedding=prompt_embeds[None, prompt_index].repeat(count, 1),
                batch_size=count,
                num_inference_steps=steps,
                size=size,
            ).images
            per_prompt.append(make_grid(samples, nrow=inner_columns))
    except KeyboardInterrupt:
        pass

    if per_prompt:
        display(VF.to_pil_image(make_grid(per_prompt, nrow=outer_columns, padding=5)))

gen_prompts("depiction of a sword", "plants", size=32, steps=100)

In [None]:
def tilable_images_from_prompt(
    prompt: str, batch_size: int = 8, size: int = 32, steps: int = 100, offset: int = 0,
    random_shift: int = 1,
    brightness: float = .0,
    contrast: float = 1.,
):
    """
    Generate images for ``prompt`` while randomly shifting the sample on alternating steps,
    then display every result repeated 3x3 to inspect the seams.

    On one step the batch is shifted by a random offset in ``[-random_shift, random_shift]``
    per axis, and on the next step it is shifted back.
    NOTE(review): seamless tiling relies on ``image_shift`` wrapping around; confirm.

    :param prompt: text prompt encoded with CLIP.
    :param batch_size: number of images.
    :param size: spatial size of the images.
    :param steps: number of inference steps.
    :param offset: number of leading timesteps to skip.
    :param random_shift: max per-axis shift (0 disables shifting).
    :param brightness: offset added to the initial noise before clamping to [0, 1].
    :param contrast: multiplier of the initial gaussian noise.
    """
    last_shift = None

    def _callback(arg: DDPMPipelineWithEmbedding.CallbackArg):
        nonlocal last_shift
        if random_shift > 0:
            if last_shift is None:
                last_shift = (
                    random.randrange(-random_shift, random_shift + 1),
                    random.randrange(-random_shift, random_shift + 1),
                )
                return image_shift(arg.image, *last_shift)
            image = image_shift(arg.image, -last_shift[0], -last_shift[1])
            last_shift = None
            return image
        return arg.image

    prompt_embeds = scale_embedding(ClipSingleton.encode_text(
        [prompt] * batch_size, device=clip_device
    ))

    noise = (torch.randn(batch_size, 4, size, size) * contrast + brightness).clamp(0, 1)

    output = pipe.run_inference_on(
        noise, embedding=prompt_embeds, num_inference_steps=steps,
        timestep_offset=offset,
        callback=_callback,
    ).images

    # An odd number of callback calls leaves the last shift applied; revert it.
    if last_shift is not None:
        output = image_shift(output, -last_shift[0], -last_shift[1])

    # repeat each image 3x3 to make tiling seams visible
    output = output.repeat(1, 1, 3, 3)

    display(VF.to_pil_image(resize(make_grid(output, nrow=4), 1)))


# Other prompts tried: "blood splattered cobblestone", "pixelart"
tilable_images_from_prompt(
    "cobblestone pattern",
    size=64,
    batch_size=16,
    steps=50,
    random_shift=1,
    brightness=.4,
    contrast=1.,
)

In [None]:
def images_from_prompt2(
    prompt: str, batch_size: int = 8, size: int = 32, steps: int = 100, offset: int = 0,
    random_shift: int = 1,
    brightness: float = .0,
    contrast: float = 1.,
):
    """
    Generate images for ``prompt`` while pinning a square of the first image to a constant.

    At every step, rows/columns 10..19 of image 0 (all channels) are set to .5 in model
    space. Optionally the sample is shifted on alternating steps and shifted back.

    :param prompt: text prompt encoded with CLIP.
    :param batch_size: number of images; only image 0 receives the square.
    :param size: spatial size of the images (must be >= 20 for the full square).
    :param steps: number of inference steps.
    :param offset: number of leading timesteps to skip.
    :param random_shift: max per-axis shift (0 disables shifting).
    :param brightness: offset added to the initial noise before clamping to [0, 1].
    :param contrast: multiplier of the initial gaussian noise.
    """
    last_shift = None

    def _callback(arg: DDPMPipelineWithEmbedding.CallbackArg):
        nonlocal last_shift
        # copy before the in-place write so the scheduler's tensor is not modified
        arg.image = arg.image.clone()
        # The sample is 4-D (batch, channel, H, W). The previous index [:1, 10:20, 10:20]
        # sliced channels 10..19 of a 4-channel tensor and therefore wrote nothing.
        arg.image[:1, :, 10:20, 10:20] = .5
        if random_shift > 0:
            if last_shift is None:
                last_shift = (
                    random.randrange(-random_shift, random_shift + 1),
                    random.randrange(-random_shift, random_shift + 1),
                )
                return image_shift(arg.image, *last_shift)
            image = image_shift(arg.image, -last_shift[0], -last_shift[1])
            last_shift = None
            return image
        return arg.image

    prompt_embeds = scale_embedding(ClipSingleton.encode_text(
        [prompt] * batch_size, device=clip_device
    ))

    noise = (torch.randn(batch_size, 4, size, size) * contrast + brightness).clamp(0, 1)

    output = pipe.run_inference_on(
        noise, embedding=prompt_embeds, num_inference_steps=steps,
        timestep_offset=offset,
        callback=_callback,
    ).images

    # An odd number of callback calls leaves the last shift applied; revert it.
    if last_shift is not None:
        output = image_shift(output, -last_shift[0], -last_shift[1])

    display(VF.to_pil_image(resize(make_grid(output, nrow=4), 1)))


# Other prompt tried: "blood splattered cobblestone"
images_from_prompt2(
    "cobblestone pattern, high contrast",
    size=128,
    batch_size=16,
    steps=50,
    random_shift=0,
    brightness=.4,
    contrast=1.,
)

In [None]:
norm = True  # passed as `normalize=` to both CLIP encoders below

image_embedding = ClipSingleton.encode_image(
    VF.to_tensor(PIL.Image.open("/home/bergi/Pictures/bob/Bobdobbs.jpg")),
    normalize=norm,
)
text_embedding = ClipSingleton.encode_text(["bob dobbs"], normalize=norm)
embeds = torch.concat([image_embedding, text_embedding]).cpu()
embeds.shape

In [None]:
# One line per embedding row (image vs. text) over the embedding dimensions.
px.line(embeds.T)

In [None]:
from diffusers.models.embeddings import Timesteps
# 512-channel timestep embedding. The positional args are presumably
# (num_channels, flip_sin_to_cos, downscale_freq_shift); confirm against the diffusers version in use.
ts = Timesteps(512, True, 0)

In [None]:
# Timestep embeddings for t=5 and t=100, with the CLIP embeddings (clamped to ±.2, x10) added.
with torch.no_grad():
    clip_offset = embeds.cpu().clamp(-.2, .2) * 10.
    te = ts(torch.tensor([5., 100.])) + clip_offset
    display(px.line(te.T))

In [None]:
# Manual denoising experiment on one seeded noise image and four prompts.
# The predicted noise is subtracted fully at t=500 and t=250, then twice at half
# strength at t=100.
feature = scale_embedding(ClipSingleton.encode_text(["cobblestone", "sword", "grass", "sand"], device=clip_device))
noise_input = torch.randn(1, 4, 32, 32, generator=torch.manual_seed(23)).clamp(-1, 1).repeat(feature.shape[0], 1, 1, 1)


def _predict_noise(sample: torch.Tensor, timestep: int) -> torch.Tensor:
    """Run the UNet on ``sample`` at ``timestep`` for every prompt; the result is returned on CPU."""
    timesteps = torch.full((feature.shape[0],), timestep, dtype=torch.long).cuda()
    return pipe.unet(sample.cuda(), timesteps, feature.cuda()).sample.cpu()


with torch.no_grad():
    output = _predict_noise(noise_input, 500)
    restored1 = noise_input - output
    output2 = _predict_noise(restored1, 250)
    restored2 = restored1 - output2
    output3 = _predict_noise(restored2, 100)
    restored3 = restored2 - output3 / 2
    output4 = _predict_noise(restored3, 100)
    restored4 = restored3 - output4 / 2

# rows: input, then each (prediction, restored) pair; one column per prompt
grid = (torch.concat([
    noise_input,
    output, restored1,
    output2, restored2,
    output3, restored3,
    output4, restored4,
]) * .5 + .5).clamp(0, 1)
VF.to_pil_image(resize(make_grid(grid, nrow=noise_input.shape[0]), 2))