In [None]:
from init_notebook import *

In [None]:
# Sample image as a float tensor in [0, 1] of shape (3, H, W).
# `.convert("RGB")` forces three channels (a grayscale, RGBA or CMYK file would otherwise
# break the `.view(2 * 3, ...)` in the histogram cell), and the `with` block closes the file.
with PIL.Image.open("/home/bergi/Pictures/__diverse/capitalism2.jpg") as _pil_image:
    img = VF.to_tensor(_pil_image.convert("RGB"))

In [None]:
from src.functional import soft_histogram
# Soft histograms of two 100x100 crops (rows 0-99 and 100-199 of the left 100 columns),
# computed per (crop, channel), averaged over the two crops and plotted.
images = torch.stack([img[:, :100, :100], img[:, 100:200, :100]]).view(2 * 3, 100, 100)
h = soft_histogram(images, 128, 0, 1, sigma=100).view(2, 3, 128).mean(0)
px.line(h.T)

In [None]:
from experiments.datasets.classic import _dataset
from src.datasets import WrapDataset
import torchvision
from functools import partial

def flowers102_dataset(
        train: bool,
        shape: Tuple[int, int, int] = (3, 96, 96),
        interpolation: bool = True,
        root: str = "~/prog/data/datasets/",
        download: bool = False,
) -> Dataset:
    """
    Flowers102 split wrapped so that every image is a resized/cropped tensor.

    :param train: use the "train" split if True, otherwise the "test" split
    :param shape: target shape; only the last two entries (H, W) are passed to
        `image_resize_crop`, the channel entry is not used here
    :param interpolation: BILINEAR resampling if True, NEAREST otherwise
    :param root: dataset root directory, passed straight to torchvision
    :param download: forwarded to torchvision; download the files if they are missing
    :return: a `WrapDataset` that applies `VF.to_tensor` and then the resize/crop
        (presumably to the image element of each (image, label) item -- verify in WrapDataset)
    """
    # chosen once here instead of on every item
    interpolation_mode = (
        VF.InterpolationMode.BILINEAR if interpolation else VF.InterpolationMode.NEAREST
    )
    ds = torchvision.datasets.Flowers102(
        root,
        split="train" if train else "test",
        download=download,
    )

    def cropper(item):
        return image_resize_crop(
            item,
            shape=shape[-2:],
            interpolation=interpolation_mode,
        )

    return WrapDataset(ds).transform([VF.to_tensor, cropper])

# Preview the first training image.
ds = flowers102_dataset(train=True)
VF.to_pil_image(ds[0][0])

In [None]:
# Preview grid of the first 8 x 8 = 64 images.
VF.to_pil_image(
    make_grid([ds[index][0] for index in range(64)])
)

In [None]:
# Load the trainer of the `noisediff-vit-class` experiment on the CPU
# (another config tried here: "../experiments/diffusion/blurdiffusion2.yml").
trainer = load_experiment_trainer(
    "../experiments/diffusion/noisediff-vit-class.yml", device="cpu",
)
trainer

In [None]:
# Restore the trainer's checkpoint (`trainer.checkpoint_path` shows which file).
trainer.load_checkpoint()

In [None]:
from experiments.diffusion.trainer import DiffusionModelInput

def generate(steps: int = 10, shape=(3, 32, 32), seed=None, mix: float = 1.):
    """
    Class-conditional denoising sweep: one grid column per class, one row per step.

    The start image is -1 on a border of width ``shape[-1] // 3`` and uniform random
    in [0, 1) inside. It is repeated once per class, noised with amount .5 and then
    denoised ``steps`` times. The displayed grid shows the state before every step,
    plus a final row with the result of the last step.

    :param steps: number of denoising steps
    :param shape: (channels, height, width) of the generated images
    :param seed: passed as ``generator=`` to ``torch.rand`` and as ``seed=`` to
        ``trainer.diffusion_sampler.add_noise``; presumably a ``torch.Generator`` or None
    :param mix: weight of the denoised image when blending it into the current one
        (1 = replace the current image)
    """
    def snapshot(batch):
        # map the model's [-1, 1] range to [0, 1] for display
        return list(batch.clamp(-1, 1) * .5 + .5)

    image_list = []
    with torch.no_grad():
        init_image = torch.ones((1, *shape)) * -1
        o = shape[-1] // 3
        # explicit end indices: `o:-o` would select an empty range when o == 0
        init_image[:, :, o:shape[-2] - o, o:shape[-1] - o] = torch.rand(
            (1, shape[0], shape[-2] - 2 * o, shape[-1] - 2 * o), generator=seed,
        )

        classes = trainer.num_class_logits
        images, noise_amounts = trainer.diffusion_sampler.add_noise(
            init_image.repeat(classes, 1, 1, 1).to(trainer.device),
            torch.ones(classes, 1).to(trainer.device) * .5,
            seed=seed,
        )
        # row i conditions on class i only: logit 10 at position i, 0 elsewhere
        class_logits = torch.tensor([
            [10 if i == j else 0 for j in range(classes)]
            for i in range(classes)
        ], dtype=torch.int).to(trainer.device)

        for _ in range(steps):
            image_list += snapshot(images)
            predicted_noise = trainer.model(DiffusionModelInput(images, noise_amounts, class_logits)).noise
            # NOTE: remove_noise is not given the noise amount here (generate_one/process pass it);
            # only the amount fed to the model shrinks by 10% per step
            denoised = trainer.diffusion_sampler.remove_noise(images, predicted_noise)
            images = images * (1 - mix) + mix * denoised
            noise_amounts = noise_amounts * .9

        # append the result of the last step too, so the final output is visible
        image_list += snapshot(images)

    display(VF.to_pil_image(make_grid(image_list, nrow=classes)))

# 10 denoising steps (the default) at 32x32 resolution.
generate(steps=10, shape=(3, 32, 32))

In [None]:
from experiments.datasets.classic import *
from experiments.diffusion.sampler import *
# CIFAR-10 via cifar10_dataset(False), wrapped so that tuple position 1 (the label)
# presumably becomes a class-logits tensor; keep the first 8 x 8 samples.
ds = ClassLogitsDataset(cifar10_dataset(False), tuple_position=1)
# index the dataset once per sample instead of once for the image and again for the label
samples = [ds[i] for i in range(8 * 8)]
images = torch.stack([sample[0] for sample in samples])
classes = torch.stack([sample[1] for sample in samples])
display(VF.to_pil_image(make_grid(images)))

In [None]:
# Look up the CIFAR-10 class names.
c = torchvision.datasets.CIFAR10(root="/home/bergi/prog/data/datasets")
c.classes

In [None]:
# Noise the 64 CIFAR samples heavily (amount .95), denoise them in one step and stack
# three grids: the noisy input, the result with each image's own class logits, and the
# result with every image conditioned on the class logits of sample 1.
with torch.no_grad():
    clean_images = (images * 2 - 1).to(trainer.device)
    noisy_images, amounts = trainer.diffusion_sampler.add_noise(
        clean_images, .95 * torch.ones(clean_images.shape[0], 1).to(clean_images),
    )
    conditionings = [
        classes.to(trainer.device),
        classes[1:2].repeat(classes.shape[0], 1).to(trainer.device),
    ]

    # clamp before display: to_pil_image multiplies by 255 and casts to uint8, and
    # that cast does not saturate out-of-range noisy values
    grids = [make_grid((noisy_images * .5 + .5).clamp(0, 1).cpu())]
    for class_logits in conditionings:
        noise = trainer.model(DiffusionModelInput(noisy_images, amounts, class_logits)).noise
        denoised = trainer.diffusion_sampler.remove_noise(noisy_images, noise)
        grids.append(make_grid((denoised * .5 + .5).clamp(0, 1).cpu()))

display(VF.to_pil_image(make_grid(grids)))

In [None]:
def generate_one(size=128, seed=None, steps=10, class_index: int = 0, mix: float = .9):
    """
    Iteratively refine a single class-conditioned image that starts from Gaussian noise.

    Every step does the following:
    - predict and remove the noise;
    - blend the denoised image into the current one with weight ``mix``;
    - re-noise with amount .5;
    - add extra Gaussian noise (std .1).

    The state after every 4th denoising step is shown in a 4-column grid.

    :param size: height and width of the square 3-channel image
    :param seed: passed as ``generator=`` to ``torch.randn`` and as ``seed=`` to
        ``trainer.diffusion_sampler.add_noise``; presumably a ``torch.Generator`` or None
    :param steps: number of refinement steps
    :param class_index: which class logit is set to 10 (all others are 0)
    :param mix: weight of the denoised image in the blend
    :raises ValueError: if ``class_index`` is not a valid class index
    """
    classes = trainer.num_class_logits
    if not 0 <= class_index < classes:
        raise ValueError(f"class_index must be in [0, {classes}), got {class_index}")

    with torch.no_grad():
        shape = (3, size, size)
        # a batch of ONE image, matching the single row of `class_logits`
        # (repeating the start image `classes` times mismatched that batch size)
        images, noise_amounts = trainer.diffusion_sampler.add_noise(
            torch.randn((1, *shape), generator=seed).to(trainer.device),
            torch.ones(1, 1).to(trainer.device) * .5,
            seed=seed,
        )
        images = images.clamp(-1, 1)
        class_logits = torch.tensor(
            [[10 if j == class_index else 0 for j in range(classes)]],
            dtype=torch.int,
        ).to(trainer.device)

        image_list = []
        for step in range(steps):
            predicted_noise = trainer.model(DiffusionModelInput(images, noise_amounts, class_logits)).noise
            denoised = trainer.diffusion_sampler.remove_noise(images, predicted_noise, noise_amounts)
            images = images * (1 - mix) + mix * denoised
            if step % 4 == 0:
                image_list += list(images.clamp(-1, 1) * .5 + .5)
            # keep the amount returned by add_noise so the next step is conditioned on it
            images, noise_amounts = trainer.diffusion_sampler.add_noise(
                images,
                torch.ones(1, 1).to(trainer.device) * .5,
                seed=seed,
            )
            images += .1 * torch.randn_like(images)

    display(VF.to_pil_image(make_grid(image_list, nrow=4)))

# 48 refinement steps; generate_one shows every 4th one.
generate_one(steps=48)

In [None]:
def process(filename, seed=None, steps=6, class_index: int = 0, mix: float = 1.):
    """
    Fully noise an existing image and let the model iteratively reconstruct it.

    Each step predicts and removes the noise and blends the result into the current
    image with weight ``mix``. Except after the last step, it then re-noises the image
    with amount .5. The single-column grid shows the state before every step plus the
    final reconstruction.

    :param filename: path of the image file; ``~`` is expanded
    :param seed: passed as ``seed=`` to ``trainer.diffusion_sampler.add_noise``
    :param steps: number of denoising steps
    :param class_index: which class logit is set to 10 (all others are 0)
    :param mix: weight of the denoised image in the blend (1 = replace)
    :raises ValueError: if ``class_index`` is not a valid class index
    """
    filename = Path(filename).expanduser()
    classes = trainer.num_class_logits
    if not 0 <= class_index < classes:
        raise ValueError(f"class_index must be in [0, {classes}), got {class_index}")

    def snapshot(batch):
        # map the model's [-1, 1] range to [0, 1] for display
        return list(batch.clamp(-1, 1) * .5 + .5)

    with torch.no_grad():
        # `with` closes the file; RGB values in [0, 1] are mapped to [-1, 1]
        with PIL.Image.open(filename) as pil_image:
            image = VF.to_tensor(pil_image.convert("RGB")) * 2. - 1.

        images, noise_amounts = trainer.diffusion_sampler.add_noise(
            image.unsqueeze(0).to(trainer.device),
            torch.ones(1, 1).to(trainer.device),
            seed=seed,
        )
        images = images.clamp(-1, 1)
        class_logits = torch.tensor(
            [[10 if j == class_index else 0 for j in range(classes)]],
            dtype=torch.int,
        ).to(trainer.device)

        image_list = []
        for step in range(steps):
            image_list += snapshot(images)

            predicted_noise = trainer.model(DiffusionModelInput(images, noise_amounts, class_logits)).noise
            denoised = trainer.diffusion_sampler.remove_noise(images, predicted_noise, noise_amounts)
            images = images * (1 - mix) + mix * denoised

            if step < steps - 1:
                # keep the amount returned by add_noise: after re-noising, the model must be
                # conditioned on this amount, not on the initial full-noise amount of 1
                images, noise_amounts = trainer.diffusion_sampler.add_noise(
                    images,
                    torch.ones(1, 1).to(trainer.device) * .5,
                    seed=seed,
                )

        # append the final reconstruction so the last step's output is visible
        image_list += snapshot(images)

    display(VF.to_pil_image(make_grid(image_list, nrow=1)))

# Other image tried: "/home/bergi/Pictures/__diverse/2NATO50thAnniversaryLogo01.jpg"
process(filename="/home/bergi/Pictures/__diverse/_1983018_orson_150.jpg")