In [None]:
from init_notebook import *

import diffusers

from experiments.datasets import *

# NOTE(review): not referenced by any cell visible here; presumably meant as the
# device for computing CLIP embeddings (cf. `with_clip_embedding=True` further
# down) — confirm it is actually consumed somewhere, otherwise remove it.
clip_device = "cpu"

In [None]:
# Image shape (channels, height, width) the UNet below is built for:
# SHAPE[0] becomes its in/out channel count and SHAPE[-1] its sample_size.
SHAPE = (3, 32, 32)

class Module(nn.Module):
    """Image-to-image UNet wrapper (used here with super-resolution checkpoints).

    Wraps a ``diffusers.UNet2DModel`` sized from ``SHAPE``. ``forward`` expects
    images in [0, 1] (presumably — it rescales them to [-1, 1]), always queries
    the UNet at timestep 0 and maps the prediction back to [0, 1] with clamping.
    """

    def __init__(self):
        super().__init__()
        channels, resolution = SHAPE[0], SHAPE[-1]

        # Two plain blocks followed by two attention blocks on the way down,
        # mirrored on the way up.
        down_blocks = ("DownBlock2D",) * 2 + ("AttnDownBlock2D",) * 2
        up_blocks = ("AttnUpBlock2D",) * 2 + ("UpBlock2D",) * 2

        # NOTE: the attribute name `model` is the prefix of every key in the
        # loaded checkpoint's state_dict; renaming it breaks `load_state_dict`.
        self.model = diffusers.UNet2DModel(
            sample_size=resolution,   # target image resolution
            in_channels=channels,     # 3 for RGB images
            out_channels=channels,
            act_fn="silu",
            layers_per_block=2,       # ResNet layers per UNet block
            # == (512 // 4, 128, 128, 128); must match the checkpoint
            block_out_channels=(128,) * 4,
            down_block_types=down_blocks,
            up_block_types=up_blocks,
            # class_embed_type="identity",
        )

    def forward(self, x, condition=None):
        """Map a batch of images to a batch of images clamped to [0, 1].

        ``condition`` is forwarded as the UNet's ``class_labels`` argument. The
        UNet is built without ``class_embed_type``, so presumably only ``None``
        is accepted here — verify before passing a condition.
        """
        unet_in = x * 2. - 1.
        unet_out = self.model(unet_in, 0, condition).sample
        return (unet_out * .5 + .5).clamp(0, 1)

# Presumably "srf-4" is the super-resolution factor (it matches `fac = 4` used
# below); the commented path is the factor-2 variant.
CHECKPOINT_PATH = "../checkpoints/super-res/unet_all_srf-4_aa-True_act-silu/snapshot.pt"
#CHECKPOINT_PATH = "../checkpoints/super-res/unet_all_srf-2_aa-True_act-silu/snapshot.pt"

model = Module()
# map_location="cpu": the model is never moved to a GPU in the cells shown and all
# inputs are CPU tensors, so load the weights straight onto the CPU. Without it, a
# snapshot saved from CUDA tensors fails to load on a machine without CUDA (and
# otherwise needlessly allocates GPU memory before copying into CPU parameters).
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu")["state_dict"])
model.eval()
print(f"params: {num_module_parameters(model):,}")

In [None]:
# Pixel-art images at the model's training shape (kept in sync with SHAPE rather
# than repeating the literal). NOTE(review): with `with_clip_embedding=True` each
# sample presumably also carries a CLIP embedding, but the cells shown only use
# element 0 of a sampled batch — consider disabling it if it is costly.
ds = PixelartDataset(SHAPE, with_clip_embedding=True)

In [None]:
# Round trip on real samples: downscale by `fac`, then let the model upscale.
# Grids shown: low-res input, one model pass, and a second pass applied to the
# model's own output upscaled by `fac` again.
fac = 4

# `lowres` rather than `input` so the builtin `input()` is not shadowed.
lowres = resize(ds.shuffle(13).sample(8)[0], 1 / fac)
display(VF.to_pil_image(make_grid(lowres)))

with torch.no_grad():
    output = model(resize(lowres, fac))
    display(VF.to_pil_image(make_grid(output)))

    # Second pass: the UNet is fully convolutional, so it also runs on the
    # larger (fac x) upscaled output.
    output = model(resize(output, fac))
    display(VF.to_pil_image(make_grid(output)))


In [None]:
# Start from tiny (4x4) random noise and repeatedly upscale by 2 with a model
# pass each time, so the model "invents" detail: 4 -> 8 -> 16 -> 32 -> 64 -> 128 px.
# Seeded so the displayed grids are reproducible on Restart & Run All.
NOISE_SEED = 42
noise_gen = torch.Generator().manual_seed(NOISE_SEED)
generated = (torch.randn(4, SHAPE[0], 4, 4, generator=noise_gen) * .3 + .3).clamp(0, 1)
with torch.no_grad():
    for _ in range(5):
        # NOTE(review): the loaded checkpoint is the srf-4 variant while this
        # cell upscales by 2 per step — confirm that is intended.
        generated = model(resize(generated, 2))
        display(VF.to_pil_image(make_grid(generated)))

In [None]:
# Upscale a single image twice by 4x, starting from a 32 px version
# (image_maximum_size presumably caps the longer side at 32 px).
# NOTE(review): absolute, machine-specific path — point this at your own image.
IMAGE_PATH = "/home/bergi/Pictures/bob/Bobdobbs_square.png"
#IMAGE_PATH = "/home/bergi/Pictures/photos/katjacam/101MSDCF/DSC00471.JPG"

# Context manager so the file handle opened by PIL is closed once the pixels
# have been converted to a tensor.
with PIL.Image.open(IMAGE_PATH) as pil_image:
    image = VF.to_tensor(pil_image.convert("RGB"))
image = image_maximum_size(image, 32)
display(VF.to_pil_image(image))
image = image.unsqueeze(0)  # add a batch dimension for the model
with torch.no_grad():
    for _ in range(2):
        output = model(resize(image, 4))
        display(VF.to_pil_image(make_grid(output)))
        image = output


In [None]:
# Feed cellular-automaton patterns to the model. Samples are presumably
# single-channel (N, H, W) states: add a channel axis and replicate it to the
# model's channel count, then upscale 4x before the first pass.
ds_ca = TotalCADataset((64, 64), seed=1, wrap=True, init_prob=.5, num_iterations=3)
ca_input = ds_ca.offset(10).sample(4)[0].unsqueeze(1).repeat(1, SHAPE[0], 1, 1).float()
ca_input = resize(ca_input, 4)
display(VF.to_pil_image(make_grid(ca_input)))
with torch.no_grad():
    for _ in range(1):
        output = model(ca_input)
        display(VF.to_pil_image(make_grid(output)))
        # Only consumed when the pass count above is raised: every further
        # pass works on a 2x upscale of the previous output.
        ca_input = resize(output, 2)


In [None]:
# Training-style data mix: random SHAPE[-1]-sized crops from four image sources,
# interleaved into one iterable dataset; shows a grid of 64 samples.
crop_size = SHAPE[-1]

pixelart_ds = (
    PixelartDataset(shape=(SHAPE[0], 32, 32))
    .offset(3000)
    .shuffle(23)
    .limit(20_000)
    .transform([VT.RandomCrop(crop_size)])
)
# presumably stored as float already — the /255 rescale below was disabled
photos_ds = (
    WrapDataset(TensorDataset(torch.load("../datasets/photos-64x64-bcr03.pt")))
    .transform([VT.RandomCrop(crop_size)])  # , lambda x: x.float() / 255.])
)
# uint8 data (per the file name): convert to float in [0, 1] after cropping
kali_ds = (
    WrapDataset(TensorDataset(torch.load("../datasets/kali-uint8-64x64.pt")))
    .transform([VT.RandomCrop(crop_size), lambda x: x.float() / 255.])
)
diverse_ds = (
    WrapDataset(TensorDataset(torch.load("../datasets/diverse-64x64-aug4.pt")))
    .transform([VT.RandomCrop(crop_size)])
)
# ifs_ds = (
#     WrapDataset(TensorDataset(torch.load("../datasets/ifs-1x128x128-uint8-1000x16.pt")))
#     .transform([VT.RandomCrop(crop_size), lambda x: set_image_channels(x, SHAPE[0]).float() / 255., VT.RandomInvert(p=.5)])
# )

ds = InterleaveIterableDataset((pixelart_ds, photos_ds, kali_ds, diverse_ds))
VF.to_pil_image(make_grid(ds.sample(64)[0]))