In [None]:
from init_notebook import *

import diffusers

from experiments.datasets import *

# Device handed to ClipSingleton.encode_text() when text prompts are embedded.
clip_device = "cpu"

In [None]:
# Image shape the UNet is configured for: SHAPE[0] is used as the channel
# count and SHAPE[-1] as the resolution (presumably (channels, height, width)).
SHAPE = (3, 32, 32)

class Module(nn.Module):
    """Conditioned image-to-image network built on ``diffusers.UNet2DModel``.

    The wrapped UNet is evaluated in a single forward pass: the second
    positional argument of the UNet call (the timestep in diffusers'
    ``UNet2DModel.forward`` - confirm against the installed version) is
    always ``0`` and the conditioning tensor is passed as the third one.
    """

    def __init__(self):
        super().__init__()

        channels = SHAPE[0]     # number of image channels, 3 for RGB images
        resolution = SHAPE[-1]  # the target image resolution

        # Two plain blocks on the high-resolution side, two attention blocks
        # on the low-resolution side; the decoder mirrors the encoder.
        encoder_blocks = ("DownBlock2D",) * 2 + ("AttnDownBlock2D",) * 2
        decoder_blocks = ("AttnUpBlock2D",) * 2 + ("UpBlock2D",) * 2

        unet_config = dict(
            sample_size=resolution,
            in_channels=channels,
            out_channels=channels,
            # "identity": the conditioning tensor is used as the class
            # embedding without a learned projection.
            class_embed_type="identity",
            act_fn="silu",
            layers_per_block=2,  # how many ResNet layers to use per UNet block
            # NOTE(review): 512 // 4 == 128; presumably written this way so
            # that the UNet's embedding width (4 x first block width in
            # diffusers) equals a 512-wide conditioning vector - confirm.
            block_out_channels=(512 // 4, 128, 128, 128),
            down_block_types=encoder_blocks,
            up_block_types=decoder_blocks,
        )

        # The attribute must stay named `model`: the parameter names stored
        # in saved state dicts are derived from it.
        self.model = diffusers.UNet2DModel(**unet_config)

    def forward(self, x, condition):
        """Run the UNet on ``x`` and return its ``.sample`` clamped to [0, 1].

        Args:
            x: batch of input images fed to the UNet.
            condition: conditioning tensor, forwarded to the UNet unchanged.

        Returns:
            The ``sample`` attribute of the UNet output, clamped to [0, 1].
        """
        unet_output = self.model(x, 0, condition)
        return unet_output.sample.clamp(0, 1)

# Build the network and restore the trained weights from a snapshot file.
checkpoint_path = (
    #"../checkpoints/super-res/unet_pix1k_srf-4_aa-True_act-silu/snapshot.pt"
    "../checkpoints/super-res/unet_pix60k_srf-4_aa-True_act-silu/snapshot.pt"
)

model = Module()
# map_location="cpu": without it torch.load() tries to restore every tensor
# onto the device it was saved from, which fails on a machine without that
# device. The model itself stays on the CPU in this notebook, so loading the
# snapshot onto the CPU is always sufficient.
model.load_state_dict(
    torch.load(checkpoint_path, map_location="cpu")["state_dict"]
)
# Inference only: switch layers such as dropout/normalization to eval mode.
model.eval()
print(f"params: {num_module_parameters(model):,}")

In [None]:
# Dataset of pixel-art images that also yields a CLIP embedding per image.
# Reuses SHAPE so the dataset resolution cannot drift away from the shape the
# model was configured with.
ds = PixelartDataset(SHAPE, with_clip_embedding=True)

In [None]:
# Round trip on dataset samples: shrink eight images to a quarter of their
# size, then let the model enlarge them again - twice in a row, so the second
# pass goes beyond the original resolution.
# (shuffle(13): presumably a fixed shuffle seed - confirm in PixelartDataset.)
samples, condition = ds.shuffle(13).sample(8)

# Named `low_res` rather than `input` so the builtin input() is not shadowed
# for the rest of the kernel session.
low_res = resize(samples, 1/4)
display(VF.to_pil_image(make_grid(low_res)))

with torch.no_grad():
    output = low_res
    for upscale_pass in range(2):
        # Enlarge by 4 with resize(), then let the model refine the result.
        output = model(resize(output, 4), condition)
        display(VF.to_pil_image(make_grid(output)))


In [None]:
# Grow images out of noise: start from one tiny 4x4 noise patch and
# repeatedly double its size, letting the model refine each enlarged version.
# Every batch entry starts from the same noise, so differences between the
# results come from the conditioning only.
noise = (torch.randn(1, SHAPE[0], 4, 4) * .3 + .3).clamp(0, 1)

condition = ds.sample(4)[1]
#condition = ClipSingleton.encode_text([
#    "cobblestone", "brick wall", "fire", "water",
#], device=clip_device)

# The batch size follows the number of conditions instead of repeating a
# hard-coded 4, so swapping in a different set of conditions keeps both
# tensors in sync. (`images` rather than `input`: avoids shadowing the
# builtin input().)
images = noise.repeat(condition.shape[0], 1, 1, 1)

with torch.no_grad():
    for step in range(5):
        output = model(resize(images, 2), condition)
        display(VF.to_pil_image(make_grid(output)))
        # Feed the refined result into the next doubling step.
        images = output

In [None]:
# Run one photo through the model under several text conditions.
# NOTE(review): absolute, machine-specific path - adjust before running
# elsewhere.
image_path = (
    #"/home/bergi/Pictures/bob/Bobdobbs_square.png"
    "/home/bergi/Pictures/photos/katjacam/101MSDCF/DSC00471.JPG"
)

# Context manager: the image file is closed as soon as the pixel data has
# been converted, instead of whenever the Image object gets collected.
with PIL.Image.open(image_path) as source_image:
    image = VF.to_tensor(source_image.convert("RGB"))

# Presumably shrinks the image so that its longest side is at most 32 pixels
# - confirm in image_maximum_size().
image = image_maximum_size(image, 32)

condition = ClipSingleton.encode_text([
    "cobblestone", "brick wall", "fire", "water",
], device=clip_device)
display(VF.to_pil_image(image))

# One copy of the image per text prompt.
image = image[None, :].repeat(condition.shape[0], 1, 1, 1)
with torch.no_grad():
    output = model(resize(image, 4), condition)
    display(VF.to_pil_image(make_grid(output)))
